モデルの短縮名をinfo.jsonから取得するようにした

This commit is contained in:
lltcggie 2016-07-03 15:36:32 +09:00
parent 9fa4d905fd
commit 9052df543f
5 changed files with 134 additions and 62 deletions

View File

@ -184,6 +184,45 @@ Waifu2x::eWaifu2xError cNet::ConstractNet(const boost::filesystem::path &model_p
return Waifu2x::eWaifu2xError_OK;
}
namespace
{
Waifu2x::eWaifu2xError ReadJson(const boost::filesystem::path &info_path, rapidjson::Document &d, std::vector<char> &jsonBuf)
{
try
{
boost::iostreams::stream<boost::iostreams::file_descriptor_source> is;
try
{
is.open(info_path, std::ios_base::in | std::ios_base::binary);
}
catch (...)
{
return Waifu2x::eWaifu2xError_FailedOpenModelFile;
}
if (!is)
return Waifu2x::eWaifu2xError_FailedOpenModelFile;
const size_t size = is.seekg(0, std::ios::end).tellg();
is.seekg(0, std::ios::beg);
jsonBuf.resize(size + 1);
is.read(jsonBuf.data(), jsonBuf.size());
jsonBuf[jsonBuf.size() - 1] = '\0';
d.Parse(jsonBuf.data());
}
catch (...)
{
return Waifu2x::eWaifu2xError_FailedParseModelFile;
}
return Waifu2x::eWaifu2xError_OK;
}
};
Waifu2x::eWaifu2xError cNet::LoadInfoFromJson(const boost::filesystem::path &info_path)
{
rapidjson::Document d;
@ -191,29 +230,11 @@ Waifu2x::eWaifu2xError cNet::LoadInfoFromJson(const boost::filesystem::path &inf
try
{
boost::iostreams::stream<boost::iostreams::file_descriptor_source> is;
Waifu2x::eWaifu2xError ret;
try
{
is.open(info_path, std::ios_base::in | std::ios_base::binary);
}
catch (...)
{
return Waifu2x::eWaifu2xError_FailedOpenModelFile;
}
if (!is)
return Waifu2x::eWaifu2xError_FailedOpenModelFile;
const size_t size = is.seekg(0, std::ios::end).tellg();
is.seekg(0, std::ios::beg);
jsonBuf.resize(size + 1);
is.read(jsonBuf.data(), jsonBuf.size());
jsonBuf[jsonBuf.size() - 1] = '\0';
d.Parse(jsonBuf.data());
ret = ReadJson(info_path, d, jsonBuf);
if (ret != Waifu2x::eWaifu2xError_OK)
return ret;
const bool resize = d.HasMember("resize") && d["resize"].GetBool() ? true : false;
const auto name = d["name"].GetString();
@ -708,3 +729,29 @@ Waifu2x::eWaifu2xError cNet::ReconstructImage(const bool UseTTA, const int crop_
return Waifu2x::eWaifu2xError_OK;
}
std::string cNet::GetModelName(const boost::filesystem::path &info_path)
{
rapidjson::Document d;
std::vector<char> jsonBuf;
std::string str;
try
{
Waifu2x::eWaifu2xError ret;
ret = ReadJson(info_path, d, jsonBuf);
if (ret != Waifu2x::eWaifu2xError_OK)
return str;
const auto name = d["name"].GetString();
str = name;
}
catch (...)
{
}
return str;
}

View File

@ -1,5 +1,6 @@
#pragma once
#include <string>
#include "waifu2x.h"
@ -34,4 +35,6 @@ public:
int GetOutputMemorySize(const int crop_w, const int crop_h, const int outer_padding, const int batch_size) const;
Waifu2x::eWaifu2xError ReconstructImage(const bool UseTTA, const int crop_w, const int crop_h, const int outer_padding, const int batch_size, float *inputBlockBuf, float *outputBlockBuf, const cv::Mat &inMat, cv::Mat &outMat);
static std::string GetModelName(const boost::filesystem::path &info_path);
};

View File

@ -285,19 +285,7 @@ Waifu2x::eWaifu2xError Waifu2x::Init(const std::string &mode, const int noise_le
const auto cuDNNCheckEndTime = std::chrono::system_clock::now();
boost::filesystem::path mode_dir_path(model_dir);
if (!mode_dir_path.is_absolute()) // model_dirが相対パスなら絶対パスに直す
{
// まずはカレントディレクトリ下にあるか探す
mode_dir_path = boost::filesystem::absolute(model_dir);
if (!boost::filesystem::exists(mode_dir_path) && !ExeDir.empty()) // 無かったらargv[0]から実行ファイルのあるフォルダを推定し、そのフォルダ下にあるか探す
{
boost::filesystem::path a0(ExeDir);
if (a0.is_absolute())
mode_dir_path = a0.branch_path() / model_dir;
}
}
const boost::filesystem::path mode_dir_path(GetModeDirPath(model_dir));
if (!boost::filesystem::exists(mode_dir_path))
return Waifu2x::eWaifu2xError_FailedOpenModelFile;
@ -317,13 +305,14 @@ Waifu2x::eWaifu2xError Waifu2x::Init(const std::string &mode, const int noise_le
// TODO: ノイズ除去と拡大を同時に行うネットワークへの対処を考える
const boost::filesystem::path info_path = GetInfoPath(mode_dir_path);
if (mode == "noise" || mode == "noise_scale" || mode == "auto_scale")
{
const std::string base_name = "noise" + std::to_string(noise_level) + "_model";
const boost::filesystem::path model_path = mode_dir_path / (base_name + ".prototxt");
const boost::filesystem::path param_path = mode_dir_path / (base_name + ".json");
const boost::filesystem::path info_path = mode_dir_path / "info.json";
mNoiseNet.reset(new cNet);
@ -341,7 +330,6 @@ Waifu2x::eWaifu2xError Waifu2x::Init(const std::string &mode, const int noise_le
const boost::filesystem::path model_path = mode_dir_path / (base_name + ".prototxt");
const boost::filesystem::path param_path = mode_dir_path / (base_name + ".json");
const boost::filesystem::path info_path = mode_dir_path / "info.json";
mScaleNet.reset(new cNet);
@ -369,6 +357,31 @@ Waifu2x::eWaifu2xError Waifu2x::Init(const std::string &mode, const int noise_le
return Waifu2x::eWaifu2xError_OK;
}
boost::filesystem::path Waifu2x::GetModeDirPath(const boost::filesystem::path &model_dir)
{
boost::filesystem::path mode_dir_path(model_dir);
if (!mode_dir_path.is_absolute()) // model_dirが相対パスなら絶対パスに直す
{
// まずはカレントディレクトリ下にあるか探す
mode_dir_path = boost::filesystem::absolute(model_dir);
if (!boost::filesystem::exists(mode_dir_path) && !ExeDir.empty()) // 無かったらargv[0]から実行ファイルのあるフォルダを推定し、そのフォルダ下にあるか探す
{
boost::filesystem::path a0(ExeDir);
if (a0.is_absolute())
mode_dir_path = a0.branch_path() / model_dir;
}
}
return mode_dir_path;
}
boost::filesystem::path Waifu2x::GetInfoPath(const boost::filesystem::path &mode_dir_path)
{
const boost::filesystem::path info_path = mode_dir_path / "info.json";
return info_path;
}
Waifu2x::eWaifu2xError Waifu2x::waifu2x(const boost::filesystem::path &input_file, const boost::filesystem::path &output_file,
const double factor, const waifu2xCancelFunc cancel_func, const int crop_w, const int crop_h,
const boost::optional<int> output_quality, const int output_depth, const bool use_tta,
@ -652,3 +665,14 @@ const std::string& Waifu2x::used_process() const
{
return mProcess;
}
std::string Waifu2x::GetModelName(const boost::filesystem::path & model_dir)
{
const boost::filesystem::path mode_dir_path(GetModeDirPath(model_dir));
if (!boost::filesystem::exists(mode_dir_path))
return std::string();
const boost::filesystem::path info_path = mode_dir_path / "info.json";
return cNet::GetModelName(info_path);
}

View File

@ -87,7 +87,8 @@ private:
size_t mOutputBlockSize;
private:
eWaifu2xError ReconstructImage(boost::shared_ptr<caffe::Net<float>> net, const int reconstructed_scale, cv::Mat &im);
static boost::filesystem::path GetModeDirPath(const boost::filesystem::path &model_dir);
static boost::filesystem::path GetInfoPath(const boost::filesystem::path &model_dir);
Waifu2x::eWaifu2xError ReconstructImage(const double factor, const int crop_w, const int crop_h, const bool use_tta, const int batch_size,
const bool isReconstructNoise, const bool isReconstructScale, const Waifu2x::waifu2xCancelFunc cancel_func, stImage &image);
@ -131,4 +132,6 @@ public:
void Destroy();
const std::string& used_process() const;
static std::string GetModelName(const boost::filesystem::path &model_dir);
};

View File

@ -4,6 +4,7 @@
#include <iostream>
#include <fstream>
#include <algorithm>
#include <codecvt>
#include <cblas.h>
#include <dlgs.h>
#include <boost/tokenizer.hpp>
@ -104,27 +105,21 @@ std::vector<int> CommonDivisorList(const int N)
tstring DialogEvent::AddName() const
{
tstring addstr;
tstring addstr(TEXT("("));
addstr += TEXT("(");
switch (modelType)
const std::string ModelName = Waifu2x::GetModelName(model_dir);
#ifdef UNICODE
{
case eModelTypeRGB:
addstr += TEXT("RGB");
break;
std::wstring_convert<std::codecvt_utf8<wchar_t>, wchar_t> cv;
const std::wstring wModelName = cv.from_bytes(ModelName);
case eModelTypePhoto:
addstr += TEXT("Photo");
break;
case eModelTypeY:
addstr += TEXT("Y");
break;
case eModelTypeUpConvRGB:
addstr += TEXT("UpConvRGB");
break;
addstr += wModelName;
}
#else
addstr += ModelName;
#endif
addstr += TEXT(")");
addstr += TEXT("(");
@ -276,20 +271,20 @@ bool DialogEvent::SyncMember(const bool NotSyncCropSize, const bool silent)
break;
case 1:
model_dir = TEXT("models/anime_style_art");
modelType = eModelTypeY;
break;
case 2:
model_dir = TEXT("models/photo");
modelType = eModelTypePhoto;
break;
case 3:
case 2:
model_dir = TEXT("models/upconv_7_anime_style_art_rgb");
modelType = eModelTypeUpConvRGB;
break;
case 3:
model_dir = TEXT("models/anime_style_art");
modelType = eModelTypeY;
break;
default:
break;
}
@ -1817,7 +1812,7 @@ void DialogEvent::Create(HWND hWnd, WPARAM wParam, LPARAM lParam, LPVOID lpData)
index = 0;
else if (modelType == eModelTypePhoto)
index = 1;
else if (modelType == eModelTypeY)
else if (modelType == eModelTypeUpConvRGB)
index = 2;
else
index = 3;