waifu2x-caffe/common/waifu2x.cpp

743 lines
20 KiB
C++
Raw Permalink Normal View History

2015-05-29 01:47:26 +09:00
#include "waifu2x.h"
#include "stImage.h"
#include "cNet.h"
2015-05-29 01:47:26 +09:00
#include <caffe/caffe.hpp>
#include <cudnn.h>
#include <mutex>
2016-05-11 22:55:16 +09:00
#include <opencv2/core.hpp>
2015-05-29 01:47:26 +09:00
#include <tclap/CmdLine.h>
#include <boost/filesystem.hpp>
#include <boost/algorithm/string.hpp>
#include <chrono>
#include <cuda_runtime.h>
2015-05-29 01:47:26 +09:00
#include <boost/iostreams/stream.hpp>
#include <boost/iostreams/device/file_descriptor.hpp>
#include <fcntl.h>
#include <zlib.h>
#ifdef _MSC_VER
#include <io.h>
#endif
//#if defined(WIN32) || defined(WIN64)
//#include <Windows.h>
//#endif
2015-05-29 01:47:26 +09:00
// OpenCV version digits concatenated into one string literal (stringized with OpenCV's
// CVAUX_STR); used below to form versioned MSVC library names such as "opencv_core" <ver> ".lib".
#define CV_VERSION_STR CVAUX_STR(CV_MAJOR_VERSION) CVAUX_STR(CV_MINOR_VERSION) CVAUX_STR(CV_SUBMINOR_VERSION)

// Build mode: Debug builds link the "d"-suffixed OpenCV libraries.
#if defined(_DEBUG)
#define CV_EXT_STR "d.lib"
#else
#define CV_EXT_STR ".lib"
#endif
#ifdef _MSC_VER
2016-05-11 22:55:16 +09:00
#pragma comment(lib, "opencv_core" CV_VERSION_STR CV_EXT_STR)
#pragma comment(lib, "opencv_imgcodecs" CV_VERSION_STR CV_EXT_STR)
#pragma comment(lib, "opencv_imgproc" CV_VERSION_STR CV_EXT_STR)
//#pragma comment(lib, "IlmImf" CV_EXT_STR)
//#pragma comment(lib, "libjasper" CV_EXT_STR)
//#pragma comment(lib, "libjpeg" CV_EXT_STR)
//#pragma comment(lib, "libpng" CV_EXT_STR)
//#pragma comment(lib, "libtiff" CV_EXT_STR)
//#pragma comment(lib, "libwebp" CV_EXT_STR)
#pragma comment(lib, "libopenblas.dll.a")
#pragma comment(lib, "cudart.lib")
#pragma comment(lib, "curand.lib")
#pragma comment(lib, "cublas.lib")
#pragma comment(lib, "cudnn.lib")
2015-05-29 01:47:26 +09:00
#ifdef _DEBUG
#pragma comment(lib, "caffe-d.lib")
#pragma comment(lib, "proto-d.lib")
#pragma comment(lib, "libboost_system-vc120-mt-gd-1_59.lib")
#pragma comment(lib, "libboost_thread-vc120-mt-gd-1_59.lib")
#pragma comment(lib, "libboost_filesystem-vc120-mt-gd-1_59.lib")
#pragma comment(lib, "glogd.lib")
#pragma comment(lib, "gflagsd.lib")
2015-05-29 01:47:26 +09:00
#pragma comment(lib, "libprotobufd.lib")
#pragma comment(lib, "libhdf5_hl_D.lib")
#pragma comment(lib, "libhdf5_D.lib")
#pragma comment(lib, "zlibstaticd.lib")
#pragma comment(lib, "libboost_iostreams-vc120-mt-gd-1_59.lib")
2015-05-29 01:47:26 +09:00
#else
#pragma comment(lib, "caffe.lib")
#pragma comment(lib, "proto.lib")
#pragma comment(lib, "libboost_system-vc120-mt-1_59.lib")
#pragma comment(lib, "libboost_thread-vc120-mt-1_59.lib")
#pragma comment(lib, "libboost_filesystem-vc120-mt-1_59.lib")
#pragma comment(lib, "glog.lib")
#pragma comment(lib, "gflags.lib")
2015-05-29 01:47:26 +09:00
#pragma comment(lib, "libprotobuf.lib")
#pragma comment(lib, "libhdf5_hl.lib")
#pragma comment(lib, "libhdf5.lib")
#pragma comment(lib, "zlibstatic.lib")
#pragma comment(lib, "libboost_iostreams-vc120-mt-1_59.lib")
#endif
2015-05-29 01:47:26 +09:00
#endif
// Magnification produced by a single pass of the scale network.
// TODO: make this depend on the model's magnification instead of being fixed.
const int ScaleBase = 2;

// Padding added around the input image.
const int OuterPadding = 0;

// Minimum required CUDA version; compared against cudaRuntimeGetVersion() in can_use_CUDA().
const int MinCudaDriverVersion = 7050;

// One-time guards: caffe::GlobalInit in init_liblary(), can_use_cuDNN() and can_use_CUDA().
static std::once_flag waifu2x_once_flag;
static std::once_flag waifu2x_cudnn_once_flag;
static std::once_flag waifu2x_cuda_once_flag;
2015-05-29 01:47:26 +09:00
// Program path recorded from argv[0] by init_liblary(); GetModeDirPath() uses its parent
// directory as a fallback location for relative model directories.
std::string Waifu2x::ExeDir;
2015-12-27 06:39:13 +09:00
#ifndef CUDA_CHECK_WAIFU2X
// Evaluates `condition` (a CUDA runtime call) exactly once and throws the resulting
// cudaError_t by value when it is not cudaSuccess.
#define CUDA_CHECK_WAIFU2X(condition) \
	do { \
		const cudaError_t waifu2x_cuda_status = (condition); \
		if (waifu2x_cuda_status != cudaSuccess) { \
			throw waifu2x_cuda_status; \
		} \
	} while (false)
#endif
2015-05-29 01:47:26 +09:00
// Frees pinned host memory (allocated with cudaHostAlloc) and resets the pointer to
// nullptr. A null pointer is left alone, so the macro is safe to repeat.
#define CUDA_HOST_SAFE_FREE(ptr) \
	do { \
		if ((ptr) != nullptr) { \
			cudaFreeHost(ptr); \
			(ptr) = nullptr; \
		} \
	} while (false)

// delete[]s an array allocated with new[] and resets the pointer to nullptr.
// A null pointer is left alone, so the macro is safe to repeat.
#define SAFE_DELETE_WAIFU2X(ptr) \
	do { \
		if ((ptr) != nullptr) { \
			delete[] (ptr); \
			(ptr) = nullptr; \
		} \
	} while (false)
2015-07-10 04:09:22 +09:00
namespace
{
	// Installs a no-op OpenCV error callback (cv::redirectError) when this translation
	// unit's statics are initialized, via the file-scope instance below.
	// NOTE(review): this only replaces OpenCV's error-reporting hook; whether cv::error
	// still throws afterwards is decided by OpenCV — confirm for the version in use.
	class IgnoreErrorCV
	{
	public:
		IgnoreErrorCV()
		{
			cv::redirectError(&IgnoreErrorCV::handleError);
		}

	private:
		// Callback matching OpenCV's error-callback signature; ignores every argument.
		static int handleError(int /*status*/, const char* /*func_name*/,
			const char* /*err_msg*/, const char* /*file_name*/,
			int /*line*/, void* /*userdata*/)
		{
			return 0;
		}
	};

	IgnoreErrorCV g_IgnoreErrorCV;
}
// Checks whether CUDA is usable. The check runs once (std::call_once) and the result is cached.
// Returns:
//   eWaifu2xCudaError_NotFind    - no driver, a version query failed, or device 0 cannot be queried
//   eWaifu2xCudaError_OldVersion - runtime older than MinCudaDriverVersion, or newer than the driver
//   eWaifu2xCudaError_OldDevice  - device 0 has compute capability major < 2
//   eWaifu2xCudaError_OK         - otherwise
Waifu2x::eWaifu2xCudaError Waifu2x::can_use_CUDA()
{
	static eWaifu2xCudaError CudaFlag = eWaifu2xCudaError_NotFind;

	std::call_once(waifu2x_cuda_once_flag, [&]()
	{
		int driverVersion = 0;
		if (cudaDriverGetVersion(&driverVersion) != cudaSuccess || driverVersion <= 0)
		{
			CudaFlag = eWaifu2xCudaError_NotFind;
			return;
		}

		int runtimeVersion = 0;
		if (cudaRuntimeGetVersion(&runtimeVersion) != cudaSuccess)
		{
			CudaFlag = eWaifu2xCudaError_NotFind;
			return;
		}

		if (runtimeVersion < MinCudaDriverVersion || driverVersion < runtimeVersion)
		{
			CudaFlag = eWaifu2xCudaError_OldVersion;
			return;
		}

		// The query can fail (e.g. no CUDA device present); prop would then be
		// uninitialized, so only read it on success.
		cudaDeviceProp prop;
		if (cudaGetDeviceProperties(&prop, 0) != cudaSuccess)
		{
			CudaFlag = eWaifu2xCudaError_NotFind;
			return;
		}

		CudaFlag = (prop.major >= 2) ? eWaifu2xCudaError_OK : eWaifu2xCudaError_OldDevice;
	});

	return CudaFlag;
}
2015-05-29 01:47:26 +09:00
// Checks whether cuDNN is usable. The check runs once (std::call_once) and the result is cached.
// Only implemented for Windows (WIN32/WIN64): it loads CUDNN_DLL_NAME dynamically, requires
// cudnnGetVersion() >= 3000, and verifies that a handle can be created and destroyed.
// On other platforms the result stays eWaifu2xcuDNNError_NotFind.
Waifu2x::eWaifu2xcuDNNError Waifu2x::can_use_cuDNN()
{
	static eWaifu2xcuDNNError cuDNNFlag = eWaifu2xcuDNNError_NotFind;

	std::call_once(waifu2x_cudnn_once_flag, [&]()
	{
#if defined(WIN32) || defined(WIN64)
		const HMODULE dll = LoadLibrary(TEXT(CUDNN_DLL_NAME));
		if (dll == NULL)
			return; // library not found: keep eWaifu2xcuDNNError_NotFind

		typedef cudnnStatus_t(__stdcall * CreateFn)(cudnnHandle_t *);
		typedef cudnnStatus_t(__stdcall * DestroyFn)(cudnnHandle_t);
		typedef uint64_t(__stdcall * GetVersionFn)();

		const CreateFn createFn = reinterpret_cast<CreateFn>(GetProcAddress(dll, "cudnnCreate"));
		const DestroyFn destroyFn = reinterpret_cast<DestroyFn>(GetProcAddress(dll, "cudnnDestroy"));
		const GetVersionFn getVersionFn = reinterpret_cast<GetVersionFn>(GetProcAddress(dll, "cudnnGetVersion"));

		if (createFn == nullptr || destroyFn == nullptr || getVersionFn == nullptr)
		{
			cuDNNFlag = eWaifu2xcuDNNError_NotFind;
		}
		else if (getVersionFn() < 3000)
		{
			cuDNNFlag = eWaifu2xcuDNNError_OldVersion;
		}
		else
		{
			// Destroy is only attempted when creation succeeded.
			cudnnHandle_t handle;
			const bool created = createFn(&handle) == CUDNN_STATUS_SUCCESS;
			if (created && destroyFn(handle) == CUDNN_STATUS_SUCCESS)
				cuDNNFlag = eWaifu2xcuDNNError_OK;
			else
				cuDNNFlag = eWaifu2xcuDNNError_CannotCreate;
		}

		FreeLibrary(dll);
#endif
	});

	return cuDNNFlag;
}
// Library-wide initialization.
// Records argv[0] in ExeDir (used by GetModeDirPath()) and runs caffe::GlobalInit exactly
// once per process. A missing argv / argv[0] is tolerated.
void Waifu2x::init_liblary(int argc, char** argv)
{
	const bool hasProgramName = argc > 0 && argv != nullptr && argv[0] != nullptr;
	if (hasProgramName)
		ExeDir = argv[0];

	std::call_once(waifu2x_once_flag, [hasProgramName, argv]()
	{
		// Fallback name so glog/Caffe never receive a null program name when the
		// host (e.g. a library caller) passes no argv.
		static char fallbackName[] = "waifu2x";

		// GlobalInit takes argc/argv by pointer and may modify them, so pass a private
		// one-element copy holding only the program name rather than the caller's flags.
		int tmpargc = 1;
		char* tmpargvv[] = { hasProgramName ? argv[0] : fallbackName };
		char** tmpargv = tmpargvv;

		// Initializes glog (and the rest of Caffe's global state)
		caffe::GlobalInit(&tmpargc, &tmpargv);
	});
}
// Counterpart of init_liblary(); there is currently nothing to tear down.
void Waifu2x::quit_liblary()
{
}
2015-07-10 04:09:22 +09:00
// Creates an uninitialized converter; Init() must succeed before waifu2x() can be used.
// mInputPlane, mMaxNetOffset and mHasNoiseScale are also given defined defaults here
// (Init() overwrites them) so they are never indeterminate.
Waifu2x::Waifu2x() : mIsInited(false), mNoiseLevel(0), mIsCuda(false), mOutputBlock(nullptr), mOutputBlockSize(0),
	mInputPlane(0), mMaxNetOffset(0), mHasNoiseScale(false)
{}
2015-07-10 04:09:22 +09:00
// Releases the networks and the output buffer (delegates to Destroy()).
Waifu2x::~Waifu2x()
{
	this->Destroy();
}
2016-07-03 17:13:02 +09:00
// Builds the networks needed for `mode` from the model set in `model_dir`.
//   mode        : conversion to perform (noise / scale / noise+scale / auto-scale).
//   noise_level : selects the "noise<N>_..." model files.
//   model_dir   : model set directory; relative paths are resolved by GetModeDirPath().
//   process     : "cpu" runs Caffe on the CPU. "gpu" requires can_use_CUDA() and is upgraded
//                 to "cudnn" when can_use_cuDNN() succeeds. Any value other than "cpu"
//                 selects Caffe GPU mode.
// Returns eWaifu2xError_OK on success (or if already initialized), otherwise the failing
// step's error code. Any exception is reported as eWaifu2xError_InvalidParameter.
Waifu2x::eWaifu2xError Waifu2x::Init(const eWaifu2xModelType mode, const int noise_level,
	const boost::filesystem::path &model_dir, const std::string &process)
{
	if (mIsInited)
		return Waifu2x::eWaifu2xError_OK;

	try
	{
		Waifu2x::eWaifu2xError ret;

		std::string Process = process;
		if (Process == "gpu")
		{
			if (can_use_CUDA() != eWaifu2xCudaError_OK)
				return Waifu2x::eWaifu2xError_FailedCudaCheck;

			// Use cuDNN when it is available
			if (can_use_cuDNN() == eWaifu2xcuDNNError_OK)
				Process = "cudnn";
		}

		mMode = mode;
		mNoiseLevel = noise_level;
		mProcess = Process;

		const boost::filesystem::path mode_dir_path(GetModeDirPath(model_dir));
		if (!boost::filesystem::exists(mode_dir_path))
			return Waifu2x::eWaifu2xError_FailedOpenModelFile;

		if (mProcess == "cpu")
		{
			caffe::Caffe::set_mode(caffe::Caffe::CPU);
			mIsCuda = false;
		}
		else
		{
			caffe::Caffe::set_mode(caffe::Caffe::GPU);
			mIsCuda = true;
		}

		mInputPlane = 0;
		mMaxNetOffset = 0;

		const boost::filesystem::path info_path = GetInfoPath(mode_dir_path);

		stInfo info;
		ret = cNet::GetInfo(info_path, info);
		if (ret != Waifu2x::eWaifu2xError_OK)
			return ret;

		mHasNoiseScale = info.has_noise_scale;
		mInputPlane = info.channels;

		// Denoising network
		if (mode == eWaifu2xModelTypeNoise || mode == eWaifu2xModelTypeNoiseScale || mode == eWaifu2xModelTypeAutoScale)
		{
			std::string base_name;
			mNoiseNet.reset(new cNet);

			eWaifu2xModelType Mode = mode;
			if (info.has_noise_scale) // denoise and upscale in a single pass
			{
				// The combined denoise+scale net must be built as eWaifu2xModelTypeNoiseScale
				Mode = eWaifu2xModelTypeNoiseScale;
				base_name = "noise" + std::to_string(noise_level) + "_scale2.0x_model";
			}
			else // denoise only
			{
				Mode = eWaifu2xModelTypeNoise;
				base_name = "noise" + std::to_string(noise_level) + "_model";
			}

			const boost::filesystem::path model_path = mode_dir_path / (base_name + ".prototxt");
			const boost::filesystem::path param_path = mode_dir_path / (base_name + ".json");

			ret = mNoiseNet->ConstractNet(Mode, model_path, param_path, info, mProcess);
			if (ret != Waifu2x::eWaifu2xError_OK)
				return ret;

			mMaxNetOffset = mNoiseNet->GetNetOffset();
		}

		// Scaling network. With a noise+scale model it is still needed, because the alpha
		// channel is upscaled with mScaleNet (see ReconstructNoiseScale).
		if (info.has_noise_scale || mode == eWaifu2xModelTypeScale || mode == eWaifu2xModelTypeNoiseScale || mode == eWaifu2xModelTypeAutoScale)
		{
			const std::string base_name = "scale2.0x_model";

			const boost::filesystem::path model_path = mode_dir_path / (base_name + ".prototxt");
			const boost::filesystem::path param_path = mode_dir_path / (base_name + ".json");

			mScaleNet.reset(new cNet);

			ret = mScaleNet->ConstractNet(eWaifu2xModelTypeScale, model_path, param_path, info, mProcess);
			if (ret != Waifu2x::eWaifu2xError_OK)
				return ret;

			assert(mInputPlane == 0 || mInputPlane == mScaleNet->GetInputPlane());

			mMaxNetOffset = std::max(mScaleNet->GetNetOffset(), mMaxNetOffset);
		}

		mIsInited = true;
	}
	catch (...)
	{
		return Waifu2x::eWaifu2xError_InvalidParameter;
	}

	return Waifu2x::eWaifu2xError_OK;
}
// Resolves a model directory path.
// An absolute path is returned unchanged. A relative path is first resolved against the
// current working directory. If that does not exist and ExeDir (argv[0]) is absolute, it is
// resolved against the executable's directory instead. The returned path may not exist;
// callers check it with boost::filesystem::exists().
boost::filesystem::path Waifu2x::GetModeDirPath(const boost::filesystem::path &model_dir)
{
	boost::filesystem::path mode_dir_path(model_dir);
	if (!mode_dir_path.is_absolute()) // turn a relative model_dir into an absolute path
	{
		// First look relative to the current directory
		mode_dir_path = boost::filesystem::absolute(model_dir);
		if (!boost::filesystem::exists(mode_dir_path) && !ExeDir.empty())
		{
			// Not found: infer the executable's folder from argv[0] and look there.
			// parent_path() replaces the deprecated branch_path() (same result).
			const boost::filesystem::path a0(ExeDir);
			if (a0.is_absolute())
				mode_dir_path = a0.parent_path() / model_dir;
		}
	}

	return mode_dir_path;
}
// Path of the model set's description file ("info.json") inside a model directory.
boost::filesystem::path Waifu2x::GetInfoPath(const boost::filesystem::path &mode_dir_path)
{
	return mode_dir_path / "info.json";
}
// Converts the image file `input_file` and writes the result to `output_file`.
// The magnification comes from CalcScaleRatio() (ratio, then target width, then target
// height) and is forced to 1.0 when the current mode does not scale.
// Returns eWaifu2xError_NotInitialized before a successful Init(), and
// eWaifu2xError_InvalidParameter for a non-positive/NaN magnification. Otherwise it
// returns the first error from loading, reconstruction (including eWaifu2xError_Cancel
// via cancel_func) or saving.
Waifu2x::eWaifu2xError Waifu2x::waifu2x(const boost::filesystem::path &input_file, const boost::filesystem::path &output_file,
	const boost::optional<double> scale_ratio, const boost::optional<int> scale_width, const boost::optional<int> scale_height,
	const waifu2xCancelFunc cancel_func, const int crop_w, const int crop_h,
	const boost::optional<int> output_quality, const int output_depth, const bool use_tta,
	const int batch_size)
{
	if (!mIsInited)
		return Waifu2x::eWaifu2xError_NotInitialized;

	stImage image;
	Waifu2x::eWaifu2xError ret = image.Load(input_file);
	if (ret != Waifu2x::eWaifu2xError_OK)
		return ret;

	image.Preprocess(mInputPlane, mMaxNetOffset);

	// Auto-scale mode only denoises when the image asks for it
	const bool isReconstructNoise = mMode == eWaifu2xModelTypeNoise || mMode == eWaifu2xModelTypeNoiseScale || (mMode == eWaifu2xModelTypeAutoScale && image.RequestDenoise());
	const bool isReconstructScale = mMode == eWaifu2xModelTypeScale || mMode == eWaifu2xModelTypeNoiseScale || mMode == eWaifu2xModelTypeAutoScale;

	double Factor = CalcScaleRatio(scale_ratio, scale_width, scale_height, image);
	if (!isReconstructScale)
		Factor = 1.0;

	// e.g. scale_ratio = 0 or a non-positive target size: cannot be honored, and would
	// otherwise reach log() and the final resize.
	if (!(Factor > 0.0))
		return Waifu2x::eWaifu2xError_InvalidParameter;

	ret = ReconstructImage(Factor, crop_w, crop_h, use_tta, batch_size, isReconstructNoise, isReconstructScale, cancel_func, image);
	if (ret != Waifu2x::eWaifu2xError_OK)
		return ret;

	image.Postprocess(mInputPlane, Factor, output_depth);

	ret = image.Save(output_file, output_quality);
	if (ret != Waifu2x::eWaifu2xError_OK)
		return ret;

	return Waifu2x::eWaifu2xError_OK;
}
// Converts an in-memory image and writes the 8-bit result into `dest`.
//   factor     : magnification (ignored, i.e. 1.0, when the current mode does not scale).
//   source     : input pixels, `width` x `height`, `in_channel` channels, row pitch `in_stride` bytes.
//   dest       : output buffer with row pitch `out_stride` bytes; it must hold the
//                processed image's rows.
//   out_channel: NOTE(review) not consulted here. The bytes written per pixel follow the
//                processed image's channel count; confirm callers size `dest` accordingly.
// Returns eWaifu2xError_NotInitialized before Init(), eWaifu2xError_InvalidParameter for a
// non-positive/NaN factor, or the first error from loading/reconstruction.
Waifu2x::eWaifu2xError Waifu2x::waifu2x(const double factor, const void* source, void* dest, const int width, const int height,
	const int in_channel, const int in_stride, const int out_channel, const int out_stride,
	const int crop_w, const int crop_h, const bool use_tta, const int batch_size)
{
	if (!mIsInited)
		return Waifu2x::eWaifu2xError_NotInitialized;

	const bool isReconstructNoise = mMode == eWaifu2xModelTypeNoise || mMode == eWaifu2xModelTypeNoiseScale;
	const bool isReconstructScale = mMode == eWaifu2xModelTypeScale || mMode == eWaifu2xModelTypeNoiseScale || mMode == eWaifu2xModelTypeAutoScale;

	const double Factor = isReconstructScale ? factor : 1.0;
	if (!(Factor > 0.0))
		return Waifu2x::eWaifu2xError_InvalidParameter;

	stImage image;
	Waifu2x::eWaifu2xError ret = image.Load(source, width, height, in_channel, in_stride);
	if (ret != Waifu2x::eWaifu2xError_OK)
		return ret;

	image.Preprocess(mInputPlane, mMaxNetOffset);

	ret = ReconstructImage(Factor, crop_w, crop_h, use_tta, batch_size, isReconstructNoise, isReconstructScale, nullptr, image);
	if (ret != Waifu2x::eWaifu2xError_OK)
		return ret;

	image.Postprocess(mInputPlane, Factor, 8);
	cv::Mat out_image = image.GetEndImage();
	image.Clear();

	// Copy into the output array row by row. Only the pixel bytes of each row
	// (cols * elemSize) are copied. out_image may be non-continuous, with step > row
	// bytes; copying the whole step could write past the end of a tightly packed `dest`.
	const size_t rowBytes = out_image.cols * out_image.elemSize();
	for (int y = 0; y < out_image.rows; y++)
		memcpy(static_cast<uint8_t *>(dest) + out_stride * y, out_image.ptr(y), rowBytes);

	return Waifu2x::eWaifu2xError_OK;
}
2016-07-03 21:55:20 +09:00
// Determines the requested magnification. Precedence: explicit ratio, then target width,
// then target height. Returns 1.0 when none is given.
double Waifu2x::CalcScaleRatio(const boost::optional<double> scale_ratio, const boost::optional<int> scale_width, const boost::optional<int> scale_height,
	const stImage &image)
{
	if (scale_ratio)
		return *scale_ratio;

	if (scale_width)
		return image.GetScaleFromWidth(*scale_width);

	// A target height must be measured against the image height (previously this
	// mistakenly called GetScaleFromWidth).
	if (scale_height)
		return image.GetScaleFromHeight(*scale_height);

	return 1.0;
}
// Runs the network passes on `image`.
// The optional denoise step runs first; with a combined noise+scale model it also performs
// one upscale. Then enough 2x scale passes are run to reach `factor`.
// Returns eWaifu2xError_Cancel when cancel_func reports cancellation between passes, or
// the first error from a network pass.
Waifu2x::eWaifu2xError Waifu2x::ReconstructImage(const double factor, const int crop_w, const int crop_h, const bool use_tta, const int batch_size,
	const bool isReconstructNoise, const bool isReconstructScale, const Waifu2x::waifu2xCancelFunc cancel_func, stImage &image)
{
	Waifu2x::eWaifu2xError ret;

	// Magnification still to be produced by the scale passes
	double Factor = factor;

	if (isReconstructNoise)
	{
		if (!mHasNoiseScale) // denoise only
		{
			cv::Mat im;
			cv::Size_<int> size;
			image.GetScalePaddingedRGB(im, size, mNoiseNet->GetNetOffset(), OuterPadding, crop_w, crop_h, 1);

			ret = ProcessNet(mNoiseNet, crop_w, crop_h, use_tta, batch_size, im);
			if (ret != Waifu2x::eWaifu2xError_OK)
				return ret;

			image.SetReconstructedRGB(im, size, 1);
		}
		else // denoise and upscale in one pass; that pass already contributes the net's inner scale
		{
			ret = ReconstructNoiseScale(crop_w, crop_h, use_tta, batch_size, cancel_func, image);
			if (ret != Waifu2x::eWaifu2xError_OK)
				return ret;

			Factor /= mNoiseNet->GetInnerScale();
		}
	}

	if (cancel_func && cancel_func())
		return Waifu2x::eWaifu2xError_Cancel;

	if (isReconstructScale)
	{
		// Number of ScaleBase-x passes needed to reach at least Factor.
		// - Factor <= 1 (or NaN) needs no pass; this also keeps log() away from 0/negative
		//   values, whose -inf/NaN result cannot be converted to int.
		// - The small epsilon stops floating-point error in log(Factor)/log(ScaleBase) from
		//   rounding an exact power (e.g. 8.0) up to one extra pass.
		const double kPassEpsilon = 1e-9;
		int scaleNum = 0;
		if (Factor > 1.0)
			scaleNum = static_cast<int>(ceil(log(Factor) / log(static_cast<double>(ScaleBase)) - kPassEpsilon));

		for (int i = 0; i < scaleNum; i++)
		{
			// Allow cancellation between (expensive) scale passes
			if (i > 0 && cancel_func && cancel_func())
				return Waifu2x::eWaifu2xError_Cancel;

			ret = ReconstructScale(crop_w, crop_h, use_tta, batch_size, cancel_func, image);
			if (ret != Waifu2x::eWaifu2xError_OK)
				return ret;
		}
	}

	return Waifu2x::eWaifu2xError_OK;
}
// One pass of the scale network over `image`: the alpha channel (if any) first, then the
// RGB planes. Each is padded for the net, reconstructed, and written back into `image`.
// Returns the first error reported by ReconstructByNet.
Waifu2x::eWaifu2xError Waifu2x::ReconstructScale(const int crop_w, const int crop_h, const bool use_tta, const int batch_size,
	const Waifu2x::waifu2xCancelFunc cancel_func, stImage &image)
{
	Waifu2x::eWaifu2xError result = Waifu2x::eWaifu2xError_OK;

	if (image.HasAlpha())
	{
		cv::Mat alpha;
		cv::Size_<int> alphaSize;
		image.GetScalePaddingedA(alpha, alphaSize, mScaleNet->GetNetOffset(), OuterPadding, crop_w, crop_h,
			mScaleNet->GetScale() / mScaleNet->GetInnerScale());

		result = ReconstructByNet(mScaleNet, crop_w, crop_h, use_tta, batch_size, cancel_func, alpha);
		if (result != Waifu2x::eWaifu2xError_OK)
			return result;

		image.SetReconstructedA(alpha, alphaSize, mScaleNet->GetInnerScale());
	}

	cv::Mat rgb;
	cv::Size_<int> rgbSize;
	image.GetScalePaddingedRGB(rgb, rgbSize, mScaleNet->GetNetOffset(), OuterPadding, crop_w, crop_h,
		mScaleNet->GetScale() / mScaleNet->GetInnerScale());

	result = ReconstructByNet(mScaleNet, crop_w, crop_h, use_tta, batch_size, cancel_func, rgb);
	if (result == Waifu2x::eWaifu2xError_OK)
		image.SetReconstructedRGB(rgb, rgbSize, mScaleNet->GetInnerScale());

	return result;
}
2016-07-03 17:13:02 +09:00
// One pass of the combined denoise+scale network. The RGB planes go through mNoiseNet.
// The alpha channel (if any) is not denoised; it is only upscaled with mScaleNet.
// Returns the first error reported by ReconstructByNet.
Waifu2x::eWaifu2xError Waifu2x::ReconstructNoiseScale(const int crop_w, const int crop_h, const bool use_tta, const int batch_size,
	const Waifu2x::waifu2xCancelFunc cancel_func, stImage &image)
{
	Waifu2x::eWaifu2xError result = Waifu2x::eWaifu2xError_OK;

	if (image.HasAlpha())
	{
		// No denoising for the alpha channel: scale it only
		cv::Mat alpha;
		cv::Size_<int> alphaSize;
		image.GetScalePaddingedA(alpha, alphaSize, mScaleNet->GetNetOffset(), OuterPadding, crop_w, crop_h,
			mScaleNet->GetScale() / mScaleNet->GetInnerScale());

		result = ReconstructByNet(mScaleNet, crop_w, crop_h, use_tta, batch_size, cancel_func, alpha);
		if (result != Waifu2x::eWaifu2xError_OK)
			return result;

		image.SetReconstructedA(alpha, alphaSize, mScaleNet->GetInnerScale());
	}

	cv::Mat rgb;
	cv::Size_<int> rgbSize;
	image.GetScalePaddingedRGB(rgb, rgbSize, mNoiseNet->GetNetOffset(), OuterPadding, crop_w, crop_h,
		mNoiseNet->GetScale() / mNoiseNet->GetInnerScale());

	result = ReconstructByNet(mNoiseNet, crop_w, crop_h, use_tta, batch_size, cancel_func, rgb);
	if (result == Waifu2x::eWaifu2xError_OK)
		image.SetReconstructedRGB(rgb, rgbSize, mNoiseNet->GetInnerScale());

	return result;
}
// Runs `net` over `im` in place.
// Without TTA this is a single ProcessNet call. With TTA (Test-Time Augmentation) the net
// runs on 8 variants (4 rotations, each with and without a horizontal mirror). Each result
// is transformed back and the 8 results are averaged.
// Returns eWaifu2xError_Cancel if cancel_func reports cancellation between TTA passes
// (`im` is left unchanged), or the first ProcessNet error.
Waifu2x::eWaifu2xError Waifu2x::ReconstructByNet(std::shared_ptr<cNet> net, const int crop_w, const int crop_h, const bool use_tta, const int batch_size,
	const Waifu2x::waifu2xCancelFunc cancel_func, cv::Mat &im)
{
	Waifu2x::eWaifu2xError ret;

	if (!use_tta) // single pass
	{
		ret = ProcessNet(net, crop_w, crop_h, use_tta, batch_size, im);
		if (ret != Waifu2x::eWaifu2xError_OK)
			return ret;
	}
	else // Test-Time Augmentation Mode
	{
		const auto RotateClockwise90 = [](cv::Mat &mat)
		{
			cv::transpose(mat, mat);
			cv::flip(mat, mat, 1);
		};

		const auto RotateClockwise90N = [RotateClockwise90](cv::Mat &mat, const int rotateNum)
		{
			for (int i = 0; i < rotateNum; i++)
				RotateClockwise90(mat);
		};

		const auto RotateCounterclockwise90 = [](cv::Mat &mat)
		{
			cv::transpose(mat, mat);
			cv::flip(mat, mat, 0);
		};

		const auto RotateCounterclockwise90N = [RotateCounterclockwise90](cv::Mat &mat, const int rotateNum)
		{
			for (int i = 0; i < rotateNum; i++)
				RotateCounterclockwise90(mat);
		};

		cv::Mat reconstruct_image;
		for (int i = 0; i < 8; i++)
		{
			// Each variant is a full network run; let the caller abort between them
			if (cancel_func && cancel_func())
				return Waifu2x::eWaifu2xError_Cancel;

			cv::Mat in(im.clone());

			const int rotateNum = i % 4;
			RotateClockwise90N(in, rotateNum);

			if (i >= 4)
				cv::flip(in, in, 1); // mirror horizontally

			ret = ProcessNet(net, crop_w, crop_h, use_tta, batch_size, in);
			if (ret != Waifu2x::eWaifu2xError_OK)
				return ret;

			// Undo the augmentation in reverse order
			if (i >= 4)
				cv::flip(in, in, 1); // mirror back

			RotateCounterclockwise90N(in, rotateNum);

			if (i == 0)
				reconstruct_image = in;
			else
				reconstruct_image += in;
		}

		reconstruct_image /= 8.0;

		im = reconstruct_image;
	}

	return Waifu2x::eWaifu2xError_OK;
}
// Runs `net` over `im` in place, using the shared output block buffer (grown on demand).
// In CUDA mode the buffer is pinned host memory (cudaHostAlloc / cudaFreeHost). Otherwise
// it is a plain new[] array (delete[]). This matches how Destroy() releases it, and lets
// CPU mode work without a CUDA runtime.
// Throws the cudaError_t if cudaHostAlloc fails (CUDA_CHECK_WAIFU2X).
Waifu2x::eWaifu2xError Waifu2x::ProcessNet(std::shared_ptr<cNet> net, const int crop_w, const int crop_h, const bool use_tta, const int batch_size, cv::Mat &im)
{
	// Required output buffer size in bytes (passed to cudaHostAlloc as a byte count)
	const auto OutputMemorySize = net->GetOutputMemorySize(crop_w, crop_h, OuterPadding, batch_size);
	if (OutputMemorySize > mOutputBlockSize)
	{
		if (mIsCuda)
			CUDA_HOST_SAFE_FREE(mOutputBlock);
		else
			SAFE_DELETE_WAIFU2X(mOutputBlock);

		// The old buffer is gone; keep the recorded size consistent in case the
		// allocation below throws, so a later call does not reuse a null buffer.
		mOutputBlockSize = 0;

		if (mIsCuda)
		{
			CUDA_CHECK_WAIFU2X(cudaHostAlloc(&mOutputBlock, OutputMemorySize, cudaHostAllocDefault));
		}
		else
		{
			// Round the byte count up to whole floats
			const size_t elementCount = static_cast<size_t>((OutputMemorySize + sizeof(float) - 1) / sizeof(float));
			mOutputBlock = new float[elementCount];
		}

		mOutputBlockSize = OutputMemorySize;
	}

	return net->ReconstructImage(use_tta, crop_w, crop_h, OuterPadding, batch_size, mOutputBlock, im, im);
}
void Waifu2x::Destroy()
{
mNoiseNet.reset();
mScaleNet.reset();
if (mIsCuda)
{
CUDA_HOST_SAFE_FREE(mOutputBlock);
}
else
{
SAFE_DELETE_WAIFU2X(mOutputBlock);
}
mIsInited = false;
2015-05-29 01:47:26 +09:00
}
2015-06-03 03:01:56 +09:00
// Process string chosen by Init(): the caller's `process`, with "gpu" replaced by "cudnn"
// when cuDNN is usable.
const std::string& Waifu2x::used_process() const
{
	const std::string &chosen = mProcess;
	return chosen;
}
// Returns the model name that cNet::GetModelName reads from the model directory's info
// file, or an empty string if the model directory cannot be found.
std::string Waifu2x::GetModelName(const boost::filesystem::path & model_dir)
{
	const boost::filesystem::path mode_dir_path(GetModeDirPath(model_dir));
	if (!boost::filesystem::exists(mode_dir_path))
		return std::string();

	// Use the shared helper so the info file location is defined in one place
	const boost::filesystem::path info_path = GetInfoPath(mode_dir_path);
	return cNet::GetModelName(info_path);
}