mirror of
https://github.com/lltcggie/waifu2x-caffe.git
synced 2025-06-25 21:22:47 +00:00
upresnet10_3.prototxt生成スクリプトが正しくないネットワークを吐き出していたのを修正、スクリプトのファイル名を変更
This commit is contained in:
parent
706ee0d32b
commit
65ad2de604
@ -81,35 +81,35 @@ def DeConv(name, bottom, num_output, kernel_size, stride = 1, pad = 0, nobias =
|
|||||||
|
|
||||||
|
|
||||||
def Sigmoid(name, bottom):
|
def Sigmoid(name, bottom):
|
||||||
top_name = name
|
top_name = name + '_sigmoid'
|
||||||
# ReLU
|
# ReLU
|
||||||
sigmoid_layer = caffe_pb2.LayerParameter()
|
sigmoid_layer = caffe_pb2.LayerParameter()
|
||||||
sigmoid_layer.name = name + '_sigmoid'
|
sigmoid_layer.name = name + '_sigmoid'
|
||||||
sigmoid_layer.type = 'Sigmoid'
|
sigmoid_layer.type = 'Sigmoid'
|
||||||
sigmoid_layer.bottom.extend([top_name])
|
sigmoid_layer.bottom.extend([bottom])
|
||||||
sigmoid_layer.top.extend([top_name])
|
sigmoid_layer.top.extend([top_name])
|
||||||
return sigmoid_layer
|
return sigmoid_layer
|
||||||
|
|
||||||
|
|
||||||
def Relu(name, bottom):
|
def Relu(name, bottom):
|
||||||
top_name = name
|
top_name = name + '_relu'
|
||||||
# ReLU
|
# ReLU
|
||||||
relu_layer = caffe_pb2.LayerParameter()
|
relu_layer = caffe_pb2.LayerParameter()
|
||||||
relu_layer.name = name + '_relu'
|
relu_layer.name = name + '_relu'
|
||||||
relu_layer.type = 'ReLU'
|
relu_layer.type = 'ReLU'
|
||||||
relu_layer.bottom.extend([top_name])
|
relu_layer.bottom.extend([bottom])
|
||||||
relu_layer.top.extend([top_name])
|
relu_layer.top.extend([top_name])
|
||||||
return relu_layer
|
return relu_layer
|
||||||
|
|
||||||
|
|
||||||
def LeakyRelu(name, bottom, negative_slope):
|
def LeakyRelu(name, bottom, negative_slope):
|
||||||
top_name = name
|
top_name = name + '_relu'
|
||||||
# LeakyRelu
|
# LeakyRelu
|
||||||
relu_layer = caffe_pb2.LayerParameter()
|
relu_layer = caffe_pb2.LayerParameter()
|
||||||
relu_layer.name = name + '_relu'
|
relu_layer.name = name + '_relu'
|
||||||
relu_layer.type = 'ReLU'
|
relu_layer.type = 'ReLU'
|
||||||
relu_layer.relu_param.negative_slope = negative_slope
|
relu_layer.relu_param.negative_slope = negative_slope
|
||||||
relu_layer.bottom.extend([top_name])
|
relu_layer.bottom.extend([bottom])
|
||||||
relu_layer.top.extend([top_name])
|
relu_layer.top.extend([top_name])
|
||||||
return relu_layer
|
return relu_layer
|
||||||
|
|
||||||
@ -122,11 +122,12 @@ def ConvLeakyRelu(name, bottom, num_output, kernel_size, stride = 1, pad = 0, ne
|
|||||||
|
|
||||||
|
|
||||||
def GlobalAvgPool(name, bottom, stride = 1, pad = 0):
|
def GlobalAvgPool(name, bottom, stride = 1, pad = 0):
|
||||||
|
top_name = name + '_globalavgpool'
|
||||||
layer = caffe_pb2.LayerParameter()
|
layer = caffe_pb2.LayerParameter()
|
||||||
layer.name = name + '_globalavgpool'
|
layer.name = name + '_globalavgpool'
|
||||||
layer.type = 'Pooling'
|
layer.type = 'Pooling'
|
||||||
layer.bottom.extend([bottom])
|
layer.bottom.extend([bottom])
|
||||||
layer.top.extend([name])
|
layer.top.extend([top_name])
|
||||||
layer.pooling_param.pool = caffe_pb2.PoolingParameter.AVE
|
layer.pooling_param.pool = caffe_pb2.PoolingParameter.AVE
|
||||||
layer.pooling_param.stride = stride
|
layer.pooling_param.stride = stride
|
||||||
layer.pooling_param.pad = pad
|
layer.pooling_param.pad = pad
|
||||||
@ -240,15 +241,6 @@ def main(args):
|
|||||||
with open(args.output, 'w') as f:
|
with open(args.output, 'w') as f:
|
||||||
f.write(pb.text_format.MessageToString(model))
|
f.write(pb.text_format.MessageToString(model))
|
||||||
|
|
||||||
# caffe.set_mode_cpu()
|
|
||||||
# net = caffe.Net(args.output, caffe.TEST)
|
|
||||||
# input_data = np.random.random_sample(net.blobs['input'].data.shape)
|
|
||||||
# net.blobs['input'].data[...] = input_data
|
|
||||||
# ret = net.forward()
|
|
||||||
# print input_data
|
|
||||||
# print ret
|
|
||||||
# print input_data.shape
|
|
||||||
# print ret['/conv_post'].shape
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = ArgumentParser()
|
parser = ArgumentParser()
|
Loading…
x
Reference in New Issue
Block a user