upresnet10_3.prototxt生成スクリプトが正しくないネットワークを吐き出していたのを修正、スクリプトのファイル名を変更

This commit is contained in:
lltcggie 2018-10-25 03:34:27 +09:00
parent 706ee0d32b
commit 65ad2de604

View File

@ -81,35 +81,35 @@ def DeConv(name, bottom, num_output, kernel_size, stride = 1, pad = 0, nobias =
def Sigmoid(name, bottom):
top_name = name
top_name = name + '_sigmoid'
# ReLU
sigmoid_layer = caffe_pb2.LayerParameter()
sigmoid_layer.name = name + '_sigmoid'
sigmoid_layer.type = 'Sigmoid'
sigmoid_layer.bottom.extend([top_name])
sigmoid_layer.bottom.extend([bottom])
sigmoid_layer.top.extend([top_name])
return sigmoid_layer
def Relu(name, bottom):
top_name = name
top_name = name + '_relu'
# ReLU
relu_layer = caffe_pb2.LayerParameter()
relu_layer.name = name + '_relu'
relu_layer.type = 'ReLU'
relu_layer.bottom.extend([top_name])
relu_layer.bottom.extend([bottom])
relu_layer.top.extend([top_name])
return relu_layer
def LeakyRelu(name, bottom, negative_slope):
top_name = name
top_name = name + '_relu'
# LeakyRelu
relu_layer = caffe_pb2.LayerParameter()
relu_layer.name = name + '_relu'
relu_layer.type = 'ReLU'
relu_layer.relu_param.negative_slope = negative_slope
relu_layer.bottom.extend([top_name])
relu_layer.bottom.extend([bottom])
relu_layer.top.extend([top_name])
return relu_layer
@ -122,11 +122,12 @@ def ConvLeakyRelu(name, bottom, num_output, kernel_size, stride = 1, pad = 0, ne
def GlobalAvgPool(name, bottom, stride = 1, pad = 0):
top_name = name + '_globalavgpool'
layer = caffe_pb2.LayerParameter()
layer.name = name + '_globalavgpool'
layer.type = 'Pooling'
layer.bottom.extend([bottom])
layer.top.extend([name])
layer.top.extend([top_name])
layer.pooling_param.pool = caffe_pb2.PoolingParameter.AVE
layer.pooling_param.stride = stride
layer.pooling_param.pad = pad
@ -240,15 +241,6 @@ def main(args):
with open(args.output, 'w') as f:
f.write(pb.text_format.MessageToString(model))
# caffe.set_mode_cpu()
# net = caffe.Net(args.output, caffe.TEST)
# input_data = np.random.random_sample(net.blobs['input'].data.shape)
# net.blobs['input'].data[...] = input_data
# ret = net.forward()
# print input_data
# print ret
# print input_data.shape
# print ret['/conv_post'].shape
if __name__ == '__main__':
parser = ArgumentParser()