From 1e3c4796d873bd690db64148fe5d9d3720619179 Mon Sep 17 00:00:00 2001 From: lltcggie Date: Thu, 25 Oct 2018 03:36:01 +0900 Subject: [PATCH] =?UTF-8?q?UpResNet10=E3=81=AEcaffemodel=E3=82=92waifu2x-c?= =?UTF-8?q?hainer=E3=81=AEUpResNet10=E3=83=A2=E3=83=87=E3=83=AB=E3=81=8B?= =?UTF-8?q?=E3=82=89=E7=94=9F=E6=88=90=E3=81=99=E3=82=8B=E3=82=B9=E3=82=AF?= =?UTF-8?q?=E3=83=AA=E3=83=97=E3=83=88=E8=BF=BD=E5=8A=A0=E3=80=81=E7=94=9F?= =?UTF-8?q?=E6=88=90=E3=81=97=E3=81=9FUpResNet10=E3=83=A2=E3=83=87?= =?UTF-8?q?=E3=83=AB=E3=81=A8waifu2x-chainer=E3=81=AEUpResNet10=E3=83=A2?= =?UTF-8?q?=E3=83=87=E3=83=AB=E3=81=AE=E5=87=BA=E5=8A=9B=E3=82=92=E6=AF=94?= =?UTF-8?q?=E8=BC=83=E3=81=99=E3=82=8B=E3=82=B9=E3=82=AF=E3=83=AA=E3=83=97?= =?UTF-8?q?=E3=83=88=E8=BF=BD=E5=8A=A0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- appendix/check_diff_upresnet10.py | 63 +++++++++++++++++++++++++++ appendix/gen_caffemodel_upresnet10.py | 61 ++++++++++++++++++++++++++ 2 files changed, 124 insertions(+) create mode 100644 appendix/check_diff_upresnet10.py create mode 100644 appendix/gen_caffemodel_upresnet10.py diff --git a/appendix/check_diff_upresnet10.py b/appendix/check_diff_upresnet10.py new file mode 100644 index 0000000..692f5e0 --- /dev/null +++ b/appendix/check_diff_upresnet10.py @@ -0,0 +1,63 @@ +import os +import os.path as osp +import sys +import google.protobuf as pb +from argparse import ArgumentParser +import numpy as np +import shutil + +import caffe +from caffe.proto import caffe_pb2 + +sys.path.append('waifu2x-chainer') +from lib import srcnn +import chainer + + +def main(): + caffe.set_mode_cpu() + + model_name = 'UpResNet10' + model_dir = 'waifu2x-chainer/models/{}'.format(model_name.lower()) + model_class = srcnn.archs[model_name] + + for filename in os.listdir(model_dir): + basename, ext = os.path.splitext(filename) + if ext == '.npz': + model_path = os.path.join(model_dir, filename) + print(model_path) + channels = 3 if 'rgb' in filename else 1 + model = model_class(channels) + chainer.serializers.load_npz(model_path, model) + + model.to_cpu() + + params = {} + for path, param in model.namedparams(): + params[path] = param.array + + net = caffe.Net('upresnet10_3.prototxt', caffe.TEST) + for key in net.params: + l = len(net.params[key]) + net.params[key][0].data[...] = params[key + '/W'] + if l >= 2: + net.params[key][1].data[...] = params[key + '/b'] + + input_data = np.empty(net.blobs['input'].data.shape, dtype=np.float32) + input_data[...] = np.random.random_sample(net.blobs['input'].data.shape) + + net.blobs['input'].data[...] = input_data + ret = net.forward() + + input_data = np.empty(net.blobs['input'].data.shape, dtype=np.float32) + input_data[...] = np.random.random_sample(net.blobs['input'].data.shape) + + net.blobs['input'].data[...] = input_data + ret = net.forward() + + batch_y = model(input_data) + print(batch_y.array - ret['/conv_post']) + +if __name__ == '__main__': + caffe.init_log(3) + main() diff --git a/appendix/gen_caffemodel_upresnet10.py b/appendix/gen_caffemodel_upresnet10.py new file mode 100644 index 0000000..0c5a1cb --- /dev/null +++ b/appendix/gen_caffemodel_upresnet10.py @@ -0,0 +1,61 @@ +import os +import os.path as osp +import sys +import google.protobuf as pb +from argparse import ArgumentParser +import numpy as np +import shutil + +import caffe +from caffe.proto import caffe_pb2 + +sys.path.append('waifu2x-chainer') +from lib import srcnn +import chainer + +fname_convert_table = { + 'anime_style_noise0_scale_rgb': 'noise0_scale2.0x_model', + 'anime_style_noise1_scale_rgb': 'noise1_scale2.0x_model', + 'anime_style_noise2_scale_rgb': 'noise2_scale2.0x_model', + 'anime_style_noise3_scale_rgb': 'noise3_scale2.0x_model', + 'anime_style_scale_rgb': 'scale2.0x_model', +} + +def main(): + caffe.set_mode_cpu() + + model_name = 'UpResNet10' + model_dir = 'waifu2x-chainer/models/{}'.format(model_name.lower()) + model_class = srcnn.archs[model_name] + + for filename in os.listdir(model_dir): + basename, ext = os.path.splitext(filename) + if ext == '.npz': + model_path = os.path.join(model_dir, filename) + print(model_path) + channels = 3 if 'rgb' in filename else 1 + model = model_class(channels) + size = 64 + model.offset + data = np.zeros((1, channels, size, size), dtype=np.float32) + x = chainer.Variable(data) + chainer.serializers.load_npz(model_path, model) + + params = {} + for path, param in model.namedparams(): + params[path] = param.array + + net = caffe.Net('upresnet10_3.prototxt', caffe.TEST) + for key in net.params: + l = len(net.params[key]) + net.params[key][0].data[...] = params[key + '/W'] + if l >= 2: + net.params[key][1].data[...] = params[key + '/b'] + + prototxt_path = '{}.prototxt'.format(fname_convert_table[basename]) + caffemodel_path = '{}.json.caffemodel'.format(fname_convert_table[basename]) + net.save(caffemodel_path) + shutil.copy('upresnet10_3.prototxt', prototxt_path) + +if __name__ == '__main__': + caffe.init_log(3) + main()