バイナリのモデル、パラメータファイルの読み込みに失敗していたバグを修正

This commit is contained in:
lltcggie 2015-12-27 04:14:34 +09:00
parent de24cc5d99
commit c13d6c1a8b

View File

@ -652,32 +652,32 @@ Waifu2x::eWaifu2xError Waifu2x::CreateZoomColorImage(const cv::Mat &float_image,
// processでcudnnが指定されなかった場合はcuDNNが呼び出されないように変更する // processでcudnnが指定されなかった場合はcuDNNが呼び出されないように変更する
Waifu2x::eWaifu2xError Waifu2x::ConstractNet(boost::shared_ptr<caffe::Net<float>> &net, const boost::filesystem::path &model_path, const boost::filesystem::path &param_path, const std::string &process) Waifu2x::eWaifu2xError Waifu2x::ConstractNet(boost::shared_ptr<caffe::Net<float>> &net, const boost::filesystem::path &model_path, const boost::filesystem::path &param_path, const std::string &process)
{ {
boost::filesystem::path caffemodel_path = param_path;
caffemodel_path += ".caffemodel";
boost::filesystem::path modelbin_path = model_path; boost::filesystem::path modelbin_path = model_path;
modelbin_path += ".protobin"; modelbin_path += ".protobin";
boost::filesystem::path caffemodel_path = param_path;
caffemodel_path += ".caffemodel";
caffe::NetParameter param; caffe::NetParameter param_model;
caffe::NetParameter param_caffemodel; caffe::NetParameter param_caffemodel;
const auto retModelBin = readProtoBinary(caffemodel_path, &param); const auto retModelBin = readProtoBinary(modelbin_path, &param_model);
const auto retParamBin = readProtoBinary(modelbin_path, &param_caffemodel); const auto retParamBin = readProtoBinary(caffemodel_path, &param_caffemodel);
if (retModelBin == eWaifu2xError_OK && retParamBin == eWaifu2xError_OK) if (retModelBin == eWaifu2xError_OK && retParamBin == eWaifu2xError_OK)
{ {
Waifu2x::eWaifu2xError ret; Waifu2x::eWaifu2xError ret;
ret = SetParameter(param, process); ret = SetParameter(param_model, process);
if (ret != eWaifu2xError_OK) if (ret != eWaifu2xError_OK)
return ret; return ret;
if (!caffe::UpgradeNetAsNeeded(caffemodel_path.string(), &param_caffemodel)) if (!caffe::UpgradeNetAsNeeded(caffemodel_path.string(), &param_caffemodel))
return Waifu2x::eWaifu2xError_FailedParseModelFile; return Waifu2x::eWaifu2xError_FailedParseModelFile;
net = boost::shared_ptr<caffe::Net<float>>(new caffe::Net<float>(param)); net = boost::shared_ptr<caffe::Net<float>>(new caffe::Net<float>(param_model));
net->CopyTrainedLayersFrom(param_caffemodel); net->CopyTrainedLayersFrom(param_caffemodel);
input_plane = param.input_dim(1); input_plane = param_model.input_dim(1);
} }
else else
{ {