UpResNet10のcaffemodelをwaifu2x-chainerのUpResNet10モデルから生成するスクリプト追加、生成したUpResNet10モデルとwaifu2x-chainerのUpResNet10モデルの出力を比較するスクリプト追加

This commit is contained in:
lltcggie 2018-10-25 03:36:01 +09:00
parent 65ad2de604
commit 1e3c4796d8
2 changed files with 124 additions and 0 deletions

View File

@ -0,0 +1,63 @@
import os
import os.path as osp
import sys
import google.protobuf as pb
from argparse import ArgumentParser
import numpy as np
import shutil
import caffe
from caffe.proto import caffe_pb2
sys.path.append('waifu2x-chainer')
from lib import srcnn
import chainer
def main():
caffe.set_mode_cpu()
model_name = 'UpResNet10'
model_dir = 'waifu2x-chainer/models/{}'.format(model_name.lower())
model_class = srcnn.archs[model_name]
for filename in os.listdir(model_dir):
basename, ext = os.path.splitext(filename)
if ext == '.npz':
model_path = os.path.join(model_dir, filename)
print(model_path)
channels = 3 if 'rgb' in filename else 1
model = model_class(channels)
chainer.serializers.load_npz(model_path, model)
model.to_cpu()
params = {}
for path, param in model.namedparams():
params[path] = param.array
net = caffe.Net('upresnet10_3.prototxt', caffe.TEST)
for key in net.params:
l = len(net.params[key])
net.params[key][0].data[...] = params[key + '/W']
if l >= 2:
net.params[key][1].data[...] = params[key + '/b']
input_data = np.empty(net.blobs['input'].data.shape, dtype=np.float32)
input_data[...] = np.random.random_sample(net.blobs['input'].data.shape)
net.blobs['input'].data[...] = input_data
ret = net.forward()
input_data = np.empty(net.blobs['input'].data.shape, dtype=np.float32)
input_data[...] = np.random.random_sample(net.blobs['input'].data.shape)
net.blobs['input'].data[...] = input_data
ret = net.forward()
batch_y = model(input_data)
print(batch_y.array - ret['/conv_post'])
if __name__ == '__main__':
caffe.init_log(3)
main()

View File

@ -0,0 +1,61 @@
import os
import os.path as osp
import sys
import google.protobuf as pb
from argparse import ArgumentParser
import numpy as np
import shutil
import caffe
from caffe.proto import caffe_pb2
sys.path.append('waifu2x-chainer')
from lib import srcnn
import chainer
fname_convert_table = {
'anime_style_noise0_scale_rgb': 'noise0_scale2.0x_model',
'anime_style_noise1_scale_rgb': 'noise1_scale2.0x_model',
'anime_style_noise2_scale_rgb': 'noise2_scale2.0x_model',
'anime_style_noise3_scale_rgb': 'noise3_scale2.0x_model',
'anime_style_scale_rgb': 'scale2.0x_model',
}
def main():
caffe.set_mode_cpu()
model_name = 'UpResNet10'
model_dir = 'waifu2x-chainer/models/{}'.format(model_name.lower())
model_class = srcnn.archs[model_name]
for filename in os.listdir(model_dir):
basename, ext = os.path.splitext(filename)
if ext == '.npz':
model_path = os.path.join(model_dir, filename)
print(model_path)
channels = 3 if 'rgb' in filename else 1
model = model_class(channels)
size = 64 + model.offset
data = np.zeros((1, channels, size, size), dtype=np.float32)
x = chainer.Variable(data)
chainer.serializers.load_npz(model_path, model)
params = {}
for path, param in model.namedparams():
params[path] = param.array
net = caffe.Net('upresnet10_3.prototxt', caffe.TEST)
for key in net.params:
l = len(net.params[key])
net.params[key][0].data[...] = params[key + '/W']
if l >= 2:
net.params[key][1].data[...] = params[key + '/b']
prototxt_path = '{}.prototxt'.format(fname_convert_table[basename])
caffemodel_path = '{}.json.caffemodel'.format(fname_convert_table[basename])
net.save(caffemodel_path)
shutil.copy('upresnet10_3.prototxt', prototxt_path)
if __name__ == '__main__':
caffe.init_log(3)
main()