モデルのおすすめのCropSizeを設定する機能追加、UpResNet10モデルにおすすめのCropSize設定

CropSizeによって出力が変わるモデル用の設定
This commit is contained in:
lltcggie 2018-10-25 04:15:53 +09:00
parent a17ad89116
commit 6effe0dfaa
6 changed files with 59 additions and 29 deletions

View File

@ -1,3 +1,3 @@
{"name":"UpResNet10","arch_name":"upresnet10","has_noise_scale":true,"channels":3, {"name":"UpResNet10","arch_name":"upresnet10","has_noise_scale":true,"channels":3,
"scale_factor":2,"offset":26 "scale_factor":2,"offset":26,"recommended_crop_size":38
} }

View File

@ -174,7 +174,7 @@ cNet::cNet() : mModelScale(0), mInnerScale(0), mNetOffset(0), mInputPlane(0), mH
cNet::~cNet() cNet::~cNet()
{} {}
Waifu2x::eWaifu2xError cNet::GetInfo(const boost::filesystem::path & info_path, stInfo &info) Waifu2x::eWaifu2xError cNet::GetInfo(const boost::filesystem::path & info_path, Waifu2x::stInfo &info)
{ {
rapidjson::Document d; rapidjson::Document d;
std::vector<char> jsonBuf; std::vector<char> jsonBuf;
@ -191,11 +191,13 @@ Waifu2x::eWaifu2xError cNet::GetInfo(const boost::filesystem::path & info_path,
const auto arch_name = d["arch_name"].GetString(); const auto arch_name = d["arch_name"].GetString();
const bool has_noise_scale = d.HasMember("has_noise_scale") && d["has_noise_scale"].GetBool() ? true : false; const bool has_noise_scale = d.HasMember("has_noise_scale") && d["has_noise_scale"].GetBool() ? true : false;
const int channels = d["channels"].GetInt(); const int channels = d["channels"].GetInt();
const int recommended_crop_size = d.HasMember("recommended_crop_size") ? d["recommended_crop_size"].GetInt() : -1;
info.name = name; info.name = name;
info.arch_name = arch_name; info.arch_name = arch_name;
info.has_noise_scale = has_noise_scale; info.has_noise_scale = has_noise_scale;
info.channels = channels; info.channels = channels;
info.recommended_crop_size = recommended_crop_size;
if (d.HasMember("offset")) if (d.HasMember("offset"))
{ {
@ -261,7 +263,7 @@ Waifu2x::eWaifu2xError cNet::GetInfo(const boost::filesystem::path & info_path,
// モデルファイルからネットワークを構築 // モデルファイルからネットワークを構築
// processでcudnnが指定されなかった場合はcuDNNが呼び出されないように変更する // processでcudnnが指定されなかった場合はcuDNNが呼び出されないように変更する
Waifu2x::eWaifu2xError cNet::ConstractNet(const Waifu2x::eWaifu2xModelType mode, const boost::filesystem::path &model_path, const boost::filesystem::path &param_path, const stInfo &info, const std::string &process) Waifu2x::eWaifu2xError cNet::ConstractNet(const Waifu2x::eWaifu2xModelType mode, const boost::filesystem::path &model_path, const boost::filesystem::path &param_path, const Waifu2x::stInfo &info, const std::string &process)
{ {
Waifu2x::eWaifu2xError ret; Waifu2x::eWaifu2xError ret;
@ -321,11 +323,11 @@ Waifu2x::eWaifu2xError cNet::ConstractNet(const Waifu2x::eWaifu2xModelType mode,
return Waifu2x::eWaifu2xError_OK; return Waifu2x::eWaifu2xError_OK;
} }
void cNet::LoadParamFromInfo(const Waifu2x::eWaifu2xModelType mode, const stInfo &info) void cNet::LoadParamFromInfo(const Waifu2x::eWaifu2xModelType mode, const Waifu2x::stInfo &info)
{ {
mModelScale = 2; // TODO: 動的に設定するようにする mModelScale = 2; // TODO: 動的に設定するようにする
stInfo::stParam param; Waifu2x::stInfo::stParam param;
switch (mode) switch (mode)
{ {
@ -824,7 +826,7 @@ std::string cNet::GetModelName(const boost::filesystem::path &info_path)
{ {
Waifu2x::eWaifu2xError ret; Waifu2x::eWaifu2xError ret;
stInfo info; Waifu2x::stInfo info;
ret = GetInfo(info_path, info); ret = GetInfo(info_path, info);
if (ret != Waifu2x::eWaifu2xError_OK) if (ret != Waifu2x::eWaifu2xError_OK)
return std::string(); return std::string();

View File

@ -4,24 +4,6 @@
#include "waifu2x.h" #include "waifu2x.h"
struct stInfo
{
struct stParam
{
int scale_factor;
int offset;
};
std::string name;
std::string arch_name;
bool has_noise_scale;
int channels;
stParam noise;
stParam scale;
stParam noise_scale;
};
class cNet class cNet
{ {
private: private:
@ -36,7 +18,7 @@ private:
bool mHasNoiseScaleModel; bool mHasNoiseScaleModel;
private: private:
void LoadParamFromInfo(const Waifu2x::eWaifu2xModelType mode, const stInfo &info); void LoadParamFromInfo(const Waifu2x::eWaifu2xModelType mode, const Waifu2x::stInfo &info);
Waifu2x::eWaifu2xError LoadParameterFromJson(const boost::filesystem::path &model_path, const boost::filesystem::path &param_path Waifu2x::eWaifu2xError LoadParameterFromJson(const boost::filesystem::path &model_path, const boost::filesystem::path &param_path
, const boost::filesystem::path &modelbin_path, const boost::filesystem::path &caffemodel_path, const std::string &process); , const boost::filesystem::path &modelbin_path, const boost::filesystem::path &caffemodel_path, const std::string &process);
Waifu2x::eWaifu2xError SetParameter(caffe::NetParameter &param, const std::string &process) const; Waifu2x::eWaifu2xError SetParameter(caffe::NetParameter &param, const std::string &process) const;
@ -45,9 +27,9 @@ public:
cNet(); cNet();
~cNet(); ~cNet();
static Waifu2x::eWaifu2xError GetInfo(const boost::filesystem::path &info_path, stInfo &info); static Waifu2x::eWaifu2xError GetInfo(const boost::filesystem::path &info_path, Waifu2x::stInfo &info);
Waifu2x::eWaifu2xError ConstractNet(const Waifu2x::eWaifu2xModelType mode, const boost::filesystem::path &model_path, const boost::filesystem::path &param_path, const stInfo &info, const std::string &process); Waifu2x::eWaifu2xError ConstractNet(const Waifu2x::eWaifu2xModelType mode, const boost::filesystem::path &model_path, const boost::filesystem::path &param_path, const Waifu2x::stInfo &info, const std::string &process);
int GetInputPlane() const; int GetInputPlane() const;
int GetInnerScale() const; int GetInnerScale() const;

View File

@ -1149,3 +1149,14 @@ std::string Waifu2x::GetModelName(const boost::filesystem::path & model_dir)
return cNet::GetModelName(info_path); return cNet::GetModelName(info_path);
} }
bool Waifu2x::GetInfo(const boost::filesystem::path &model_dir, stInfo &info)
{
const boost::filesystem::path mode_dir_path(GetModeDirPath(model_dir));
if (!boost::filesystem::exists(mode_dir_path))
return false;
const boost::filesystem::path info_path = mode_dir_path / "info.json";
return cNet::GetInfo(info_path, info) == Waifu2x::eWaifu2xError_OK;
}

View File

@ -59,6 +59,25 @@ public:
class Waifu2x class Waifu2x
{ {
public: public:
struct stInfo
{
struct stParam
{
int scale_factor;
int offset;
};
std::string name;
std::string arch_name;
bool has_noise_scale;
int channels;
int recommended_crop_size;
stParam noise;
stParam scale;
stParam noise_scale;
};
enum eWaifu2xModelType enum eWaifu2xModelType
{ {
eWaifu2xModelTypeNoise = 0, eWaifu2xModelTypeNoise = 0,
@ -183,4 +202,5 @@ public:
const std::string& used_process() const; const std::string& used_process() const;
static std::string GetModelName(const boost::filesystem::path &model_dir); static std::string GetModelName(const boost::filesystem::path &model_dir);
static bool GetInfo(const boost::filesystem::path &model_dir, stInfo &info);
}; };

View File

@ -527,6 +527,18 @@ void DialogEvent::SetCropSizeList(const boost::filesystem::path & input_path)
} }
), list.end()); ), list.end());
bool isRecommendedCropSize = false;
Waifu2x::stInfo info;
if (Waifu2x::GetInfo(model_dir, info) && info.recommended_crop_size > 0)
{
tstring str(to_tstring(info.recommended_crop_size));
SendMessage(hcrop, CB_ADDSTRING, 0, (LPARAM)str.c_str());
isRecommendedCropSize = true;
}
if (list.size() > 0)
SendMessage(hcrop, CB_ADDSTRING, 0, (LPARAM)TEXT("-----------------------"));
int mindiff = INT_MAX; int mindiff = INT_MAX;
int defaultIndex = -1; int defaultIndex = -1;
for (int i = 0; i < list.size(); i++) for (int i = 0; i < list.size(); i++)
@ -534,13 +546,13 @@ void DialogEvent::SetCropSizeList(const boost::filesystem::path & input_path)
const int n = list[i]; const int n = list[i];
tstring str(to_tstring(n)); tstring str(to_tstring(n));
SendMessage(hcrop, CB_ADDSTRING, 0, (LPARAM)str.c_str()); const int index = SendMessage(hcrop, CB_ADDSTRING, 0, (LPARAM)str.c_str());
const int diff = abs(DefaultCommonDivisor - n); const int diff = abs(DefaultCommonDivisor - n);
if (DefaultCommonDivisorRange.first <= n && n <= DefaultCommonDivisorRange.second && diff < mindiff) if (DefaultCommonDivisorRange.first <= n && n <= DefaultCommonDivisorRange.second && diff < mindiff)
{ {
mindiff = diff; mindiff = diff;
defaultIndex = i; defaultIndex = index;
} }
} }
@ -565,6 +577,9 @@ void DialogEvent::SetCropSizeList(const boost::filesystem::path & input_path)
if (defaultIndex == -1) if (defaultIndex == -1)
defaultIndex = defaultListIndex; defaultIndex = defaultListIndex;
if(isRecommendedCropSize)
defaultIndex = 0;
if (GetWindowTextLength(hcrop) == 0) if (GetWindowTextLength(hcrop) == 0)
SendMessage(hcrop, CB_SETCURSEL, defaultIndex, 0); SendMessage(hcrop, CB_SETCURSEL, defaultIndex, 0);
} }