2015-05-29 01:47:26 +09:00
|
|
|
|
#include "waifu2x.h"
|
2016-07-03 13:37:26 +09:00
|
|
|
|
#include "stImage.h"
|
|
|
|
|
#include "cNet.h"
|
2015-05-29 01:47:26 +09:00
|
|
|
|
#include <caffe/caffe.hpp>
|
|
|
|
|
#include <cudnn.h>
|
|
|
|
|
#include <mutex>
|
2016-05-11 22:55:16 +09:00
|
|
|
|
#include <opencv2/core.hpp>
|
2016-07-08 23:35:51 +09:00
|
|
|
|
#include <opencv2/imgproc.hpp>
|
2015-05-29 01:47:26 +09:00
|
|
|
|
#include <tclap/CmdLine.h>
|
|
|
|
|
#include <boost/filesystem.hpp>
|
|
|
|
|
#include <boost/algorithm/string.hpp>
|
2015-06-02 01:04:20 +09:00
|
|
|
|
#include <chrono>
|
2016-07-06 18:41:13 +09:00
|
|
|
|
#include <unordered_map>
|
2015-06-05 01:40:03 +09:00
|
|
|
|
#include <cuda_runtime.h>
|
2015-05-29 01:47:26 +09:00
|
|
|
|
|
2015-12-06 18:48:37 +09:00
|
|
|
|
#include <boost/iostreams/stream.hpp>
|
|
|
|
|
#include <boost/iostreams/device/file_descriptor.hpp>
|
2016-07-06 18:41:13 +09:00
|
|
|
|
#include <msgpack.hpp>
|
2016-07-03 13:37:26 +09:00
|
|
|
|
|
2015-12-06 18:48:37 +09:00
|
|
|
|
#include <fcntl.h>
|
2016-04-29 13:57:23 +09:00
|
|
|
|
#include <zlib.h>
|
2015-12-06 18:48:37 +09:00
|
|
|
|
#ifdef _MSC_VER
|
|
|
|
|
#include <io.h>
|
|
|
|
|
#endif
|
|
|
|
|
|
2016-07-03 13:37:26 +09:00
|
|
|
|
//#if defined(WIN32) || defined(WIN64)
|
|
|
|
|
//#include <Windows.h>
|
|
|
|
|
//#endif
|
2015-05-29 01:47:26 +09:00
|
|
|
|
|
2015-12-26 18:09:43 +09:00
|
|
|
|
#define CV_VERSION_STR CVAUX_STR(CV_MAJOR_VERSION) CVAUX_STR(CV_MINOR_VERSION) CVAUX_STR(CV_SUBMINOR_VERSION)
|
|
|
|
|
|
|
|
|
|
// Build mode (selects the debug or release library-name suffix)
|
|
|
|
|
#ifdef _DEBUG
|
|
|
|
|
#define CV_EXT_STR "d.lib"
|
|
|
|
|
#else
|
|
|
|
|
#define CV_EXT_STR ".lib"
|
|
|
|
|
#endif
|
|
|
|
|
|
2015-12-06 18:48:37 +09:00
|
|
|
|
#ifdef _MSC_VER
|
2016-05-11 22:55:16 +09:00
|
|
|
|
|
|
|
|
|
#pragma comment(lib, "opencv_core" CV_VERSION_STR CV_EXT_STR)
|
|
|
|
|
#pragma comment(lib, "opencv_imgcodecs" CV_VERSION_STR CV_EXT_STR)
|
|
|
|
|
#pragma comment(lib, "opencv_imgproc" CV_VERSION_STR CV_EXT_STR)
|
2018-10-17 03:55:18 +09:00
|
|
|
|
#pragma comment(lib, "IlmImf" CV_EXT_STR)
|
|
|
|
|
#pragma comment(lib, "ippicvmt.lib")
|
|
|
|
|
#pragma comment(lib, "libjasper" CV_EXT_STR)
|
|
|
|
|
#pragma comment(lib, "libjpeg-turbo" CV_EXT_STR)
|
|
|
|
|
#pragma comment(lib, "libpng" CV_EXT_STR)
|
|
|
|
|
#pragma comment(lib, "libtiff" CV_EXT_STR)
|
|
|
|
|
#pragma comment(lib, "libwebp" CV_EXT_STR)
|
|
|
|
|
#pragma comment(lib, "zlib" CV_EXT_STR)
|
2016-05-11 22:55:16 +09:00
|
|
|
|
|
|
|
|
|
#pragma comment(lib, "libopenblas.dll.a")
|
|
|
|
|
#pragma comment(lib, "cudart.lib")
|
|
|
|
|
#pragma comment(lib, "curand.lib")
|
|
|
|
|
#pragma comment(lib, "cublas.lib")
|
|
|
|
|
#pragma comment(lib, "cudnn.lib")
|
|
|
|
|
|
2015-05-29 01:47:26 +09:00
|
|
|
|
#ifdef _DEBUG
|
2015-12-04 00:48:55 +09:00
|
|
|
|
#pragma comment(lib, "caffe-d.lib")
|
2018-10-17 03:55:18 +09:00
|
|
|
|
#pragma comment(lib, "caffeproto-d.lib")
|
|
|
|
|
#pragma comment(lib, "libprotobufd.lib")
|
2015-12-04 00:48:55 +09:00
|
|
|
|
#pragma comment(lib, "glogd.lib")
|
|
|
|
|
#pragma comment(lib, "gflagsd.lib")
|
2018-10-17 03:55:18 +09:00
|
|
|
|
#pragma comment(lib, "libboost_system-vc140-mt-gd-1_61.lib")
|
|
|
|
|
#pragma comment(lib, "boost_thread-vc140-mt-gd-1_61.lib")
|
|
|
|
|
#pragma comment(lib, "boost_filesystem-vc140-mt-gd-1_61.lib")
|
|
|
|
|
#pragma comment(lib, "boost_iostreams-vc140-mt-gd-1_61.lib")
|
|
|
|
|
//#pragma comment(lib, "zlibstaticd.lib")
|
|
|
|
|
|
2015-12-26 18:09:43 +09:00
|
|
|
|
|
2015-05-29 01:47:26 +09:00
|
|
|
|
#else
|
2015-12-04 00:48:55 +09:00
|
|
|
|
#pragma comment(lib, "caffe.lib")
|
2018-10-17 03:55:18 +09:00
|
|
|
|
#pragma comment(lib, "caffeproto.lib")
|
|
|
|
|
#pragma comment(lib, "libprotobuf.lib")
|
2015-12-04 00:48:55 +09:00
|
|
|
|
#pragma comment(lib, "glog.lib")
|
|
|
|
|
#pragma comment(lib, "gflags.lib")
|
2018-10-17 03:55:18 +09:00
|
|
|
|
#pragma comment(lib, "libboost_system-vc140-mt-1_61.lib")
|
|
|
|
|
#pragma comment(lib, "boost_thread-vc140-mt-1_61.lib")
|
|
|
|
|
#pragma comment(lib, "boost_filesystem-vc140-mt-1_61.lib")
|
|
|
|
|
#pragma comment(lib, "boost_iostreams-vc140-mt-1_61.lib")
|
2015-12-06 18:48:37 +09:00
|
|
|
|
#endif
|
2015-05-29 01:47:26 +09:00
|
|
|
|
#endif
|
|
|
|
|
|
2016-07-03 13:37:26 +09:00
|
|
|
|
const int ScaleBase = 2; // TODO: <20><><EFBFBD>f<EFBFBD><66><EFBFBD>̊g<CC8A>嗦<EFBFBD>ɂ<EFBFBD><C982><EFBFBD><EFBFBD>ĉςł<CF82><C582><EFBFBD><EFBFBD>悤<EFBFBD>ɂ<EFBFBD><C982><EFBFBD>
|
2015-05-29 01:47:26 +09:00
|
|
|
|
|
2016-07-03 13:37:26 +09:00
|
|
|
|
// <20><><EFBFBD>͉摜<CD89>ɒlj<C992><C789><EFBFBD><EFBFBD><EFBFBD><EFBFBD>p<EFBFBD>f<EFBFBD>B<EFBFBD><42><EFBFBD>O
|
|
|
|
|
const int OuterPadding = 0;
|
2015-05-29 01:47:26 +09:00
|
|
|
|
|
2015-06-08 03:34:42 +09:00
|
|
|
|
// <20>Œ<EFBFBD><C592><EFBFBD><EFBFBD>K<EFBFBD>v<EFBFBD><76>CUDA<44>h<EFBFBD><68><EFBFBD>C<EFBFBD>o<EFBFBD>[<5B>̃o<CC83>[<5B>W<EFBFBD><57><EFBFBD><EFBFBD>
|
2015-12-27 06:42:05 +09:00
|
|
|
|
const int MinCudaDriverVersion = 7050;
|
2015-06-08 03:34:42 +09:00
|
|
|
|
|
2015-06-03 03:01:56 +09:00
|
|
|
|
static std::once_flag waifu2x_once_flag;
|
|
|
|
|
static std::once_flag waifu2x_cudnn_once_flag;
|
2015-06-08 03:34:42 +09:00
|
|
|
|
static std::once_flag waifu2x_cuda_once_flag;
|
2015-05-29 01:47:26 +09:00
|
|
|
|
|
2016-07-03 13:37:26 +09:00
|
|
|
|
std::string Waifu2x::ExeDir;
|
|
|
|
|
|
2015-12-27 06:39:13 +09:00
|
|
|
|
|
2015-06-05 01:40:03 +09:00
|
|
|
|
// Evaluates `condition` (an expression yielding cudaError_t) exactly once and throws
// the resulting cudaError_t value when it is not cudaSuccess.
// The argument is parenthesized and the temporary carries a macro-specific name so that
// an argument which itself mentions a variable called `error` is not shadowed by it.
#ifndef CUDA_CHECK_WAIFU2X
#define CUDA_CHECK_WAIFU2X(condition) \
	do { \
		const cudaError_t cuda_check_waifu2x_error_ = (condition); \
		if (cuda_check_waifu2x_error_ != cudaSuccess) throw cuda_check_waifu2x_error_; \
	} while (0)
#endif
|
2015-05-29 01:47:26 +09:00
|
|
|
|
|
2015-06-05 01:40:03 +09:00
|
|
|
|
// If `ptr` is non-null, passes it to cudaFreeHost() and then resets it to nullptr.
// `ptr` must be a modifiable lvalue; it is parenthesized so compound expressions
// (e.g. a member access or dereference) expand safely.
#define CUDA_HOST_SAFE_FREE(ptr) \
	do { \
		if ((ptr) != nullptr) { \
			cudaFreeHost(ptr); \
			(ptr) = nullptr; \
		} \
	} while (0)
|
|
|
|
|
|
|
|
|
|
// If `ptr` is non-null, releases it with `delete []` and resets it to nullptr.
// Despite the name this is an ARRAY delete: only use it on memory obtained from `new T[n]`.
// `ptr` must be a modifiable lvalue; it is parenthesized so compound expressions expand safely.
#define SAFE_DELETE_WAIFU2X(ptr) \
	do { \
		if ((ptr) != nullptr) { \
			delete [] (ptr); \
			(ptr) = nullptr; \
		} \
	} while (0)
|
|
|
|
|
|
2015-07-10 04:09:22 +09:00
|
|
|
|
namespace
|
|
|
|
|
{
|
|
|
|
|
class IgnoreErrorCV
|
|
|
|
|
{
|
|
|
|
|
private:
|
|
|
|
|
static int handleError(int status, const char* func_name,
|
|
|
|
|
const char* err_msg, const char* file_name,
|
|
|
|
|
int line, void* userdata)
|
|
|
|
|
{
|
|
|
|
|
return 0;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
public:
|
|
|
|
|
IgnoreErrorCV()
|
|
|
|
|
{
|
|
|
|
|
cv::redirectError(handleError);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
IgnoreErrorCV g_IgnoreErrorCV;
|
2016-06-08 11:11:36 +09:00
|
|
|
|
|
|
|
|
|
class CudaDeviceSet
|
|
|
|
|
{
|
|
|
|
|
private:
|
|
|
|
|
int orgDevice;
|
|
|
|
|
bool mIsSet;
|
|
|
|
|
|
|
|
|
|
public:
|
|
|
|
|
CudaDeviceSet(const std::string &process, const int devno) : orgDevice(0), mIsSet(false)
|
|
|
|
|
{
|
|
|
|
|
if (process == "gpu" || process == "cudnn")
|
|
|
|
|
{
|
|
|
|
|
int count = 0;
|
|
|
|
|
if (cudaGetDeviceCount(&count) != CUDA_SUCCESS)
|
|
|
|
|
return;
|
|
|
|
|
|
|
|
|
|
if (devno >= count || count < 0)
|
|
|
|
|
return;
|
|
|
|
|
|
|
|
|
|
if (cudaGetDevice(&orgDevice) != CUDA_SUCCESS)
|
|
|
|
|
return;
|
|
|
|
|
|
|
|
|
|
if (cudaSetDevice(devno) != CUDA_SUCCESS)
|
|
|
|
|
return;
|
|
|
|
|
|
|
|
|
|
mIsSet = true;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
~CudaDeviceSet()
|
|
|
|
|
{
|
|
|
|
|
if (mIsSet)
|
|
|
|
|
cudaSetDevice(orgDevice);
|
|
|
|
|
}
|
|
|
|
|
};
|
2015-07-10 04:09:22 +09:00
|
|
|
|
}
|
|
|
|
|
|
2016-07-06 18:41:13 +09:00
|
|
|
|
// Algorithm table for ONE layer configuration (kernel / pad / stride / batch size).
// It maps (num_input, num_output, width, height) to an 8-bit algorithm id and remembers
// whether the table has changed since it was last marked as saved.
// The table and the layer configuration are serialized through MSGPACK_DEFINE.
class CcuDNNAlgorithmElement
{
private:
	typedef std::unordered_map<uint64_t, uint8_t> AlgoMap;

	AlgoMap mAlgo;
	bool mIsModefy; // true when mAlgo changed since the last Saved()

	uint8_t kernel_w;
	uint8_t kernel_h;
	uint8_t pad_w;
	uint8_t pad_h;
	uint8_t stride_w;
	uint8_t stride_h;
	uint16_t batch_size;

private:
	// Packs the four 16-bit values into one 64-bit key, 16 bits per field.
	// (The previous 8-bit shifts let the fields overlap, so e.g. (1,2,256,0) and (1,3,0,0)
	// produced the same key.)
	// NOTE: entries written by older builds used the overlapping layout (always < 2^40).
	// With num_input >= 1 the new keys are >= 2^48, so stale entries are never matched;
	// they are simply ignored and the algorithm is looked up again.
	static uint64_t InfoToKey(uint16_t num_input, uint16_t num_output, uint16_t width, uint16_t height)
	{
		return
			static_cast<uint64_t>(num_input) << 16 * 3 |
			static_cast<uint64_t>(num_output) << 16 * 2 |
			static_cast<uint64_t>(width) << 16 * 1 |
			static_cast<uint64_t>(height);
	}

public:
	// All members start out zeroed so that a default-constructed element never exposes
	// indeterminate layer data.
	// No destructor is declared on purpose: an empty user-declared destructor would
	// suppress the implicit move operations and turn `map[key] = std::move(elm)` into a copy.
	CcuDNNAlgorithmElement() : mIsModefy(false),
		kernel_w(0), kernel_h(0), pad_w(0), pad_h(0), stride_w(0), stride_h(0), batch_size(0)
	{}

	// Stores the layer configuration this table belongs to.
	void SetLayerData(uint8_t kernel_w, uint8_t kernel_h, uint8_t pad_w, uint8_t pad_h, uint8_t stride_w, uint8_t stride_h, uint16_t batch_size)
	{
		this->kernel_w = kernel_w;
		this->kernel_h = kernel_h;
		this->pad_w = pad_w;
		this->pad_h = pad_h;
		this->stride_w = stride_w;
		this->stride_h = stride_h;
		this->batch_size = batch_size;
	}

	// Copies the stored layer configuration into the output parameters.
	void GetLayerData(uint8_t &kernel_w, uint8_t &kernel_h, uint8_t &pad_w, uint8_t &pad_h, uint8_t &stride_w, uint8_t &stride_h, uint16_t &batch_size) const
	{
		kernel_w = this->kernel_w;
		kernel_h = this->kernel_h;
		pad_w = this->pad_w;
		pad_h = this->pad_h;
		stride_w = this->stride_w;
		stride_h = this->stride_h;
		batch_size = this->batch_size;
	}

	// Returns the stored algorithm id (0..255) for the given shape, or -1 when none is stored.
	int GetAlgorithm(uint16_t num_input, uint16_t num_output, uint16_t width, uint16_t height) const
	{
		const auto it = mAlgo.find(InfoToKey(num_input, num_output, width, height));
		return it != mAlgo.end() ? static_cast<int>(it->second) : -1;
	}

	// Stores `algo` for the given shape; flags the table as modified only when the
	// entry is new or its value actually changes.
	void SetAlgorithm(uint8_t algo, uint16_t num_input, uint16_t num_output, uint16_t width, uint16_t height)
	{
		const uint64_t key = InfoToKey(num_input, num_output, width, height);

		const auto it = mAlgo.find(key);
		if (it == mAlgo.end())
		{
			mAlgo[key] = algo;
			mIsModefy = true;
		}
		else if (it->second != algo)
		{
			it->second = algo;
			mIsModefy = true;
		}
	}

	// True when the table changed since the last call to Saved().
	bool IsModefy() const
	{
		return mIsModefy;
	}

	// Marks the current contents as persisted.
	void Saved()
	{
		mIsModefy = false;
	}

	MSGPACK_DEFINE(mAlgo, kernel_w, kernel_h, pad_w, pad_h, stride_w, stride_h, batch_size);
};
|
|
|
|
|
|
|
|
|
|
class CcuDNNAlgorithm
|
|
|
|
|
{
|
|
|
|
|
private:
|
|
|
|
|
typedef std::unordered_map<uint64_t, CcuDNNAlgorithmElement> AlgoEmlMap;
|
|
|
|
|
|
|
|
|
|
AlgoEmlMap mAlgoEmlMap;
|
|
|
|
|
std::string mDataPath;
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
static uint64_t InfoToKey(uint8_t kernel_w, uint8_t kernel_h, uint8_t pad_w, uint8_t pad_h, uint8_t stride_w, uint8_t stride_h, uint16_t batch_size)
|
|
|
|
|
{
|
|
|
|
|
return
|
|
|
|
|
(uint64_t)kernel_w << 8 * 7 | (uint64_t)kernel_h << 8 * 6 |
|
|
|
|
|
(uint64_t)pad_w << 8 * 5 | (uint64_t)pad_h << 8 * 4 |
|
|
|
|
|
(uint64_t)stride_w << 8 * 3 | (uint64_t)stride_h << 8 * 2 |
|
|
|
|
|
(uint64_t)batch_size;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::string GetDataPath(uint8_t kernel_w, uint8_t kernel_h, uint8_t pad_w, uint8_t pad_h, uint8_t stride_w, uint8_t stride_h, uint16_t batch_size) const
|
|
|
|
|
{
|
|
|
|
|
std::string SavePath = mDataPath;
|
|
|
|
|
SavePath +=
|
|
|
|
|
std::to_string(kernel_w) + "x" + std::to_string(kernel_h) + " " +
|
|
|
|
|
std::to_string(pad_w) + "x" + std::to_string(pad_w) + " " +
|
|
|
|
|
std::to_string(stride_w) + "x" + std::to_string(stride_w) + " " +
|
|
|
|
|
std::to_string(batch_size);
|
|
|
|
|
SavePath += ".dat";
|
|
|
|
|
|
|
|
|
|
return SavePath;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool Load(uint8_t kernel_w, uint8_t kernel_h, uint8_t pad_w, uint8_t pad_h, uint8_t stride_w, uint8_t stride_h, uint16_t batch_size)
|
|
|
|
|
{
|
|
|
|
|
const std::string SavePath = GetDataPath(kernel_w, kernel_h, pad_w, pad_h, stride_w, stride_h, batch_size);
|
|
|
|
|
|
|
|
|
|
std::vector<char> sbuf;
|
|
|
|
|
|
|
|
|
|
FILE *fp = fopen(SavePath.c_str(), "rb");
|
|
|
|
|
if (!fp)
|
|
|
|
|
return false;
|
|
|
|
|
|
|
|
|
|
fseek(fp, 0, SEEK_END);
|
|
|
|
|
const auto size = ftell(fp);
|
|
|
|
|
fseek(fp, 0, SEEK_SET);
|
|
|
|
|
|
|
|
|
|
sbuf.resize(size);
|
|
|
|
|
|
|
|
|
|
if (fread(sbuf.data(), 1, sbuf.size(), fp) != sbuf.size())
|
|
|
|
|
{
|
|
|
|
|
fclose(fp);
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fclose(fp);
|
|
|
|
|
|
2016-07-06 22:33:09 +09:00
|
|
|
|
try
|
|
|
|
|
{
|
|
|
|
|
CcuDNNAlgorithmElement elm;
|
|
|
|
|
msgpack::unpack(sbuf.data(), sbuf.size()).get().convert(elm);
|
|
|
|
|
sbuf.clear();
|
2016-07-06 18:41:13 +09:00
|
|
|
|
|
2016-07-06 22:33:09 +09:00
|
|
|
|
const uint64_t key = InfoToKey(kernel_w, kernel_h, pad_w, pad_h, stride_w, stride_h, batch_size);
|
|
|
|
|
mAlgoEmlMap[key] = std::move(elm);
|
|
|
|
|
}
|
|
|
|
|
catch (...)
|
|
|
|
|
{
|
|
|
|
|
boost::filesystem::remove(SavePath);
|
|
|
|
|
}
|
2016-07-06 18:41:13 +09:00
|
|
|
|
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
public:
|
|
|
|
|
CcuDNNAlgorithm()
|
|
|
|
|
{}
|
|
|
|
|
|
|
|
|
|
~CcuDNNAlgorithm()
|
|
|
|
|
{
|
|
|
|
|
Save();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int GetAlgorithm(uint16_t num_input, uint16_t num_output, uint16_t batch_size,
|
|
|
|
|
uint16_t width, uint16_t height, uint16_t kernel_w, uint16_t kernel_h, uint16_t pad_w, uint16_t pad_h, uint16_t stride_w, uint16_t stride_h)
|
|
|
|
|
{
|
|
|
|
|
const uint64_t key = InfoToKey(kernel_w, kernel_h, pad_w, pad_h, stride_w, stride_h, batch_size);
|
|
|
|
|
const auto it = mAlgoEmlMap.find(key);
|
|
|
|
|
if (it != mAlgoEmlMap.end())
|
|
|
|
|
{
|
|
|
|
|
const auto &elm = it->second;
|
|
|
|
|
return elm.GetAlgorithm(num_input, num_output, width, height);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (Load(kernel_w, kernel_h, pad_w, pad_h, stride_w, stride_h, batch_size))
|
|
|
|
|
return mAlgoEmlMap[key].GetAlgorithm(num_input, num_output, width, height);
|
|
|
|
|
|
|
|
|
|
return -1;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void SetAlgorithm(int algo, uint16_t num_input, uint16_t num_output, uint16_t batch_size,
|
|
|
|
|
uint16_t width, uint16_t height, uint16_t kernel_w, uint16_t kernel_h, uint16_t pad_w, uint16_t pad_h, uint16_t stride_w, uint16_t stride_h)
|
|
|
|
|
{
|
|
|
|
|
if (algo < 0 || algo > 255)
|
|
|
|
|
return;
|
|
|
|
|
|
|
|
|
|
const uint64_t key = InfoToKey(kernel_w, kernel_h, pad_w, pad_h, stride_w, stride_h, batch_size);
|
|
|
|
|
auto &eml = mAlgoEmlMap[key];
|
|
|
|
|
eml.SetAlgorithm(algo, num_input, num_output, width, height);
|
|
|
|
|
eml.SetLayerData(kernel_w, kernel_h, pad_w, pad_h, stride_w, stride_h, batch_size);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void Save()
|
|
|
|
|
{
|
|
|
|
|
for (auto &p : mAlgoEmlMap)
|
|
|
|
|
{
|
|
|
|
|
auto &eml = p.second;
|
|
|
|
|
if (eml.IsModefy())
|
|
|
|
|
{
|
2016-07-06 22:33:09 +09:00
|
|
|
|
try
|
2016-07-06 18:41:13 +09:00
|
|
|
|
{
|
2016-07-06 22:33:09 +09:00
|
|
|
|
msgpack::sbuffer sbuf;
|
|
|
|
|
msgpack::pack(sbuf, eml);
|
|
|
|
|
|
|
|
|
|
uint8_t kernel_w, kernel_h, pad_w, pad_h, stride_w, stride_h;
|
|
|
|
|
uint16_t batch_size;
|
|
|
|
|
eml.GetLayerData(kernel_w, kernel_h, pad_w, pad_h, stride_w, stride_h, batch_size);
|
|
|
|
|
|
|
|
|
|
const std::string SavePath = GetDataPath(kernel_w, kernel_h, pad_w, pad_h, stride_w, stride_h, batch_size);
|
|
|
|
|
FILE *fp = fopen(SavePath.c_str(), "wb");
|
|
|
|
|
if (fp)
|
|
|
|
|
{
|
|
|
|
|
fwrite(sbuf.data(), 1, sbuf.size(), fp);
|
|
|
|
|
fclose(fp);
|
|
|
|
|
|
|
|
|
|
eml.Saved();
|
|
|
|
|
}
|
2016-07-06 18:41:13 +09:00
|
|
|
|
}
|
2016-07-06 22:33:09 +09:00
|
|
|
|
catch(...)
|
|
|
|
|
{}
|
2016-07-06 18:41:13 +09:00
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void SetDataPath(std::string path)
|
|
|
|
|
{
|
|
|
|
|
mDataPath = path;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
CcuDNNAlgorithm g_ConvCcuDNNAlgorithm;
|
|
|
|
|
CcuDNNAlgorithm g_DeconvCcuDNNAlgorithm;
|
|
|
|
|
|
|
|
|
|
|
2016-07-03 13:37:26 +09:00
|
|
|
|
// Check whether CUDA can be used
|
|
|
|
|
// Determines once per process (std::call_once) whether CUDA is usable and returns the
// cached verdict on every call:
//  - eWaifu2xCudaError_NotFind    : a version/property query failed or no driver version is reported
//  - eWaifu2xCudaError_OldVersion : runtime older than MinCudaDriverVersion, or driver older than runtime
//  - eWaifu2xCudaError_OldDevice  : device 0 has compute capability below 3.5
//  - eWaifu2xCudaError_OK         : otherwise
Waifu2x::eWaifu2xCudaError Waifu2x::can_use_CUDA()
{
	static eWaifu2xCudaError CudaFlag = eWaifu2xCudaError_NotFind;
	std::call_once(waifu2x_cuda_once_flag, [&]()
	{
		// CudaFlag is already NotFind, so the failure paths only need to return.
		int driverVersion = 0;
		if (cudaDriverGetVersion(&driverVersion) != cudaSuccess || driverVersion <= 0)
			return;

		int runtimeVersion = 0;
		if (cudaRuntimeGetVersion(&runtimeVersion) != cudaSuccess)
			return;

		if (runtimeVersion < MinCudaDriverVersion || driverVersion < runtimeVersion)
		{
			CudaFlag = eWaifu2xCudaError_OldVersion;
			return;
		}

		// Only device 0 is inspected. The result is checked so that `prop` is never read
		// when the query failed (it would be uninitialized).
		cudaDeviceProp prop;
		if (cudaGetDeviceProperties(&prop, 0) != cudaSuccess)
			return;

		// Compute capability 3.5 or newer.
		const bool isSupported = (prop.major == 3 && prop.minor >= 5) || prop.major >= 4;
		CudaFlag = isSupported ? eWaifu2xCudaError_OK : eWaifu2xCudaError_OldDevice;
	});

	return CudaFlag;
}
|
|
|
|
|
|
2015-05-29 01:47:26 +09:00
|
|
|
|
// Check whether cuDNN can be used. Currently Windows only
|
2015-06-03 13:51:48 +09:00
|
|
|
|
// Determines once per process (std::call_once) whether cuDNN is usable and returns the
// cached verdict. The probe only exists when WIN32 or WIN64 is defined; elsewhere the
// result stays eWaifu2xcuDNNError_NotFind.
// The probe loads the cuDNN DLL, resolves cudnnCreate / cudnnDestroy / cudnnGetVersion,
// checks the version against CUDNN_REQUIRE_VERION and creates and destroys one handle.
Waifu2x::eWaifu2xcuDNNError Waifu2x::can_use_cuDNN()
{
	static eWaifu2xcuDNNError cuDNNFlag = eWaifu2xcuDNNError_NotFind;
	std::call_once(waifu2x_cudnn_once_flag, [&]()
	{
#if defined(WIN32) || defined(WIN64)
		HMODULE hModule = LoadLibrary(TEXT(CUDNN_DLL_NAME));
		if (hModule == NULL)
			return; // DLL not available: keep NotFind

		typedef cudnnStatus_t(CUDNNWINAPI* cudnnCreateType)(cudnnHandle_t *);
		typedef cudnnStatus_t(CUDNNWINAPI* cudnnDestroyType)(cudnnHandle_t);
		typedef size_t(CUDNNWINAPI* cudnnGetVersionType)();

		const cudnnCreateType createFn = reinterpret_cast<cudnnCreateType>(GetProcAddress(hModule, "cudnnCreate"));
		const cudnnDestroyType destroyFn = reinterpret_cast<cudnnDestroyType>(GetProcAddress(hModule, "cudnnDestroy"));
		const cudnnGetVersionType getVersionFn = reinterpret_cast<cudnnGetVersionType>(GetProcAddress(hModule, "cudnnGetVersion"));

		if (createFn == nullptr || destroyFn == nullptr || getVersionFn == nullptr)
		{
			cuDNNFlag = eWaifu2xcuDNNError_NotFind;
		}
		else if (getVersionFn() >= CUDNN_REQUIRE_VERION)
		{
			// A handle must be both creatable and destroyable; destroy is only attempted
			// after a successful create.
			cudnnHandle_t h;
			const bool isCreated = createFn(&h) == CUDNN_STATUS_SUCCESS;
			if (isCreated && destroyFn(h) == CUDNN_STATUS_SUCCESS)
				cuDNNFlag = eWaifu2xcuDNNError_OK;
			else
				cuDNNFlag = eWaifu2xcuDNNError_CannotCreate;
		}
		else
		{
			cuDNNFlag = eWaifu2xcuDNNError_OldVersion;
		}

		FreeLibrary(hModule);
#endif
	});

	return cuDNNFlag;
}
|
|
|
|
|
|
2016-07-03 13:37:26 +09:00
|
|
|
|
void Waifu2x::init_liblary(int argc, char** argv)
|
2015-06-08 03:34:42 +09:00
|
|
|
|
{
|
2016-07-03 13:37:26 +09:00
|
|
|
|
if (argc > 0)
|
|
|
|
|
ExeDir = argv[0];
|
2015-06-08 03:34:42 +09:00
|
|
|
|
|
2016-07-03 13:37:26 +09:00
|
|
|
|
std::call_once(waifu2x_once_flag, [argc, argv]()
|
|
|
|
|
{
|
|
|
|
|
assert(argc >= 1);
|
2015-06-08 03:34:42 +09:00
|
|
|
|
|
2016-07-03 13:37:26 +09:00
|
|
|
|
int tmpargc = 1;
|
|
|
|
|
char* tmpargvv[] = {argv[0]};
|
|
|
|
|
char** tmpargv = tmpargvv;
|
|
|
|
|
// glog<6F><67><EFBFBD>̏<EFBFBD><CC8F><EFBFBD><EFBFBD><EFBFBD>
|
|
|
|
|
caffe::GlobalInit(&tmpargc, &tmpargv);
|
|
|
|
|
});
|
2015-12-04 00:48:55 +09:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void Waifu2x::quit_liblary()
|
2016-07-06 18:41:13 +09:00
|
|
|
|
{
|
|
|
|
|
g_ConvCcuDNNAlgorithm.Save();
|
|
|
|
|
g_DeconvCcuDNNAlgorithm.Save();
|
2016-07-10 20:09:46 +09:00
|
|
|
|
|
2016-07-11 00:21:39 +09:00
|
|
|
|
//caffe::GlobalFinalize();
|
2016-07-06 18:41:13 +09:00
|
|
|
|
}
|
2015-12-04 00:48:55 +09:00
|
|
|
|
|
2016-07-10 20:09:46 +09:00
|
|
|
|
// Per-thread shutdown hook. Currently a no-op: its only statement is commented out.
void Waifu2x::quit_thread_liblary()
{
	// Disabled in the original source:
	//caffe::ThreadFinalize();
}
|
2015-07-10 04:09:22 +09:00
|
|
|
|
|
2016-07-03 22:23:54 +09:00
|
|
|
|
// Creates an uninitialized instance; Init() must succeed before it can be used.
Waifu2x::Waifu2x()
	: mIsInited(false)
	, mNoiseLevel(0)
	, mIsCuda(false)
	, mOutputBlock(nullptr)
	, mOutputBlockSize(0)
	, mGPUNo(0)
{
}
|
2015-07-10 04:09:22 +09:00
|
|
|
|
|
2016-07-03 13:37:26 +09:00
|
|
|
|
// Releases everything held by the instance by delegating to Destroy().
Waifu2x::~Waifu2x()
{
	this->Destroy();
}
|
2015-12-06 00:55:45 +09:00
|
|
|
|
|
2016-07-03 17:13:02 +09:00
|
|
|
|
Waifu2x::eWaifu2xError Waifu2x::Init(const eWaifu2xModelType mode, const int noise_level,
|
2016-07-03 22:23:54 +09:00
|
|
|
|
const boost::filesystem::path &model_dir, const std::string &process, const int GPUNo)
|
2016-07-03 13:37:26 +09:00
|
|
|
|
{
|
|
|
|
|
Waifu2x::eWaifu2xError ret;
|
2015-12-06 00:55:45 +09:00
|
|
|
|
|
2016-07-03 13:37:26 +09:00
|
|
|
|
if (mIsInited)
|
|
|
|
|
return Waifu2x::eWaifu2xError_OK;
|
2015-12-06 00:55:45 +09:00
|
|
|
|
|
2016-07-03 13:37:26 +09:00
|
|
|
|
try
|
2015-12-06 00:55:45 +09:00
|
|
|
|
{
|
2016-07-03 13:37:26 +09:00
|
|
|
|
std::string Process = process;
|
|
|
|
|
const auto cuDNNCheckStartTime = std::chrono::system_clock::now();
|
2015-12-06 00:55:45 +09:00
|
|
|
|
|
2016-07-03 13:37:26 +09:00
|
|
|
|
if (Process == "gpu")
|
2015-12-06 00:55:45 +09:00
|
|
|
|
{
|
2016-07-03 13:37:26 +09:00
|
|
|
|
if (can_use_CUDA() != eWaifu2xCudaError_OK)
|
|
|
|
|
return Waifu2x::eWaifu2xError_FailedCudaCheck;
|
|
|
|
|
// cuDNN<4E><4E><EFBFBD>g<EFBFBD><67><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>Ȃ<EFBFBD>cuDNN<4E><4E><EFBFBD>g<EFBFBD><67>
|
|
|
|
|
else if (can_use_cuDNN() == eWaifu2xcuDNNError_OK)
|
|
|
|
|
Process = "cudnn";
|
2015-12-06 00:55:45 +09:00
|
|
|
|
}
|
|
|
|
|
|
2016-07-03 13:37:26 +09:00
|
|
|
|
mMode = mode;
|
|
|
|
|
mNoiseLevel = noise_level;
|
|
|
|
|
mProcess = Process;
|
2016-07-03 22:23:54 +09:00
|
|
|
|
mGPUNo = GPUNo;
|
2015-12-06 00:55:45 +09:00
|
|
|
|
|
2016-07-03 13:37:26 +09:00
|
|
|
|
const auto cuDNNCheckEndTime = std::chrono::system_clock::now();
|
2015-12-06 18:48:37 +09:00
|
|
|
|
|
2016-07-07 01:25:04 +09:00
|
|
|
|
if (Process == "cudnn")
|
2016-07-06 18:41:13 +09:00
|
|
|
|
{
|
2016-07-07 01:25:04 +09:00
|
|
|
|
// exe<78>̃f<CC83>B<EFBFBD><42><EFBFBD>N<EFBFBD>g<EFBFBD><67><EFBFBD><EFBFBD>cuDNN<4E>̃A<CC83><41><EFBFBD>S<EFBFBD><53><EFBFBD>Y<EFBFBD><59><EFBFBD>f<EFBFBD>[<5B>^<5E>ۑ<EFBFBD>
|
|
|
|
|
boost::filesystem::path cudnn_data_base_dir_path(ExeDir);
|
|
|
|
|
if (cudnn_data_base_dir_path.is_relative())
|
|
|
|
|
cudnn_data_base_dir_path = boost::filesystem::system_complete(cudnn_data_base_dir_path);
|
2016-07-06 18:41:13 +09:00
|
|
|
|
|
2016-07-07 01:25:04 +09:00
|
|
|
|
if (!boost::filesystem::is_directory(cudnn_data_base_dir_path))
|
|
|
|
|
cudnn_data_base_dir_path = cudnn_data_base_dir_path.branch_path();
|
2016-07-06 18:41:13 +09:00
|
|
|
|
|
2016-07-07 01:25:04 +09:00
|
|
|
|
if (!boost::filesystem::exists(cudnn_data_base_dir_path))
|
2016-07-06 18:41:13 +09:00
|
|
|
|
{
|
2016-07-07 01:25:04 +09:00
|
|
|
|
// exe<78>̃f<CC83>B<EFBFBD><42><EFBFBD>N<EFBFBD>g<EFBFBD><67><EFBFBD><EFBFBD><EFBFBD>擾<EFBFBD>ł<EFBFBD><C582>Ȃ<EFBFBD><C882><EFBFBD><EFBFBD>J<CE83><4A><EFBFBD><EFBFBD><EFBFBD>g<EFBFBD>f<EFBFBD>B<EFBFBD><42><EFBFBD>N<EFBFBD>g<EFBFBD><67><EFBFBD>ɕۑ<C995>
|
|
|
|
|
|
|
|
|
|
cudnn_data_base_dir_path = boost::filesystem::current_path();
|
|
|
|
|
|
|
|
|
|
if (cudnn_data_base_dir_path.is_relative())
|
|
|
|
|
cudnn_data_base_dir_path = boost::filesystem::system_complete(cudnn_data_base_dir_path);
|
|
|
|
|
|
|
|
|
|
if (!boost::filesystem::exists(cudnn_data_base_dir_path))
|
|
|
|
|
cudnn_data_base_dir_path = "./";
|
2016-07-06 18:41:13 +09:00
|
|
|
|
}
|
|
|
|
|
|
2016-07-07 01:25:04 +09:00
|
|
|
|
if (boost::filesystem::exists(cudnn_data_base_dir_path))
|
2016-07-06 18:41:13 +09:00
|
|
|
|
{
|
2016-07-07 01:25:04 +09:00
|
|
|
|
const boost::filesystem::path cudnn_data_dir_path(cudnn_data_base_dir_path / "cudnn_data");
|
|
|
|
|
|
|
|
|
|
bool isOK = false;
|
|
|
|
|
if (boost::filesystem::exists(cudnn_data_dir_path))
|
|
|
|
|
isOK = true;
|
|
|
|
|
|
|
|
|
|
if (!isOK)
|
|
|
|
|
{
|
|
|
|
|
boost::system::error_code error;
|
|
|
|
|
const bool result = boost::filesystem::create_directory(cudnn_data_dir_path, error);
|
|
|
|
|
if (result && !error)
|
|
|
|
|
isOK = true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (isOK)
|
2016-07-06 18:41:13 +09:00
|
|
|
|
{
|
2016-07-07 01:25:04 +09:00
|
|
|
|
cudaDeviceProp prop;
|
|
|
|
|
if (cudaGetDeviceProperties(&prop, mGPUNo) == cudaSuccess)
|
|
|
|
|
{
|
|
|
|
|
std::string conv_filename(prop.name);
|
|
|
|
|
conv_filename += " conv ";
|
2016-07-06 18:41:13 +09:00
|
|
|
|
|
2016-07-07 01:25:04 +09:00
|
|
|
|
std::string deconv_filename(prop.name);
|
|
|
|
|
deconv_filename += " deconv ";
|
2016-07-06 18:41:13 +09:00
|
|
|
|
|
2016-07-07 01:25:04 +09:00
|
|
|
|
const boost::filesystem::path conv_data_path = cudnn_data_dir_path / conv_filename;
|
|
|
|
|
const boost::filesystem::path deconv_data_path = cudnn_data_dir_path / deconv_filename;
|
2016-07-06 18:41:13 +09:00
|
|
|
|
|
2016-07-07 01:25:04 +09:00
|
|
|
|
g_ConvCcuDNNAlgorithm.SetDataPath(conv_data_path.string());
|
|
|
|
|
g_DeconvCcuDNNAlgorithm.SetDataPath(deconv_data_path.string());
|
|
|
|
|
}
|
2016-07-06 18:41:13 +09:00
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
2016-07-03 15:36:32 +09:00
|
|
|
|
const boost::filesystem::path mode_dir_path(GetModeDirPath(model_dir));
|
2016-07-03 13:37:26 +09:00
|
|
|
|
if (!boost::filesystem::exists(mode_dir_path))
|
|
|
|
|
return Waifu2x::eWaifu2xError_FailedOpenModelFile;
|
2015-12-06 18:48:37 +09:00
|
|
|
|
|
2016-07-03 22:23:54 +09:00
|
|
|
|
CudaDeviceSet devset(process, mGPUNo);
|
2015-12-06 18:48:37 +09:00
|
|
|
|
|
2016-07-03 13:37:26 +09:00
|
|
|
|
if (mProcess == "cpu")
|
2015-12-06 18:48:37 +09:00
|
|
|
|
{
|
2016-07-03 13:37:26 +09:00
|
|
|
|
caffe::Caffe::set_mode(caffe::Caffe::CPU);
|
|
|
|
|
mIsCuda = false;
|
|
|
|
|
}
|
|
|
|
|
else
|
|
|
|
|
{
|
|
|
|
|
caffe::Caffe::set_mode(caffe::Caffe::GPU);
|
|
|
|
|
mIsCuda = true;
|
2015-12-06 18:48:37 +09:00
|
|
|
|
}
|
2015-05-29 01:47:26 +09:00
|
|
|
|
|
2016-07-06 18:41:13 +09:00
|
|
|
|
caffe::Caffe::SetGetcuDNNAlgorithmFunc(GetcuDNNAlgorithm);
|
|
|
|
|
caffe::Caffe::SetSetcuDNNAlgorithmFunc(SetcuDNNAlgorithm);
|
|
|
|
|
|
2016-07-03 13:37:26 +09:00
|
|
|
|
mInputPlane = 0;
|
|
|
|
|
mMaxNetOffset = 0;
|
2015-12-27 07:20:55 +09:00
|
|
|
|
|
2016-07-03 15:36:32 +09:00
|
|
|
|
const boost::filesystem::path info_path = GetInfoPath(mode_dir_path);
|
2015-12-27 07:20:55 +09:00
|
|
|
|
|
2016-07-03 17:13:02 +09:00
|
|
|
|
stInfo info;
|
|
|
|
|
ret = cNet::GetInfo(info_path, info);
|
|
|
|
|
if (ret != Waifu2x::eWaifu2xError_OK)
|
|
|
|
|
return ret;
|
2015-05-29 01:47:26 +09:00
|
|
|
|
|
2019-01-16 02:44:33 +09:00
|
|
|
|
mHasNoiseScaleOnly = info.has_noise_scale;
|
2016-07-03 17:13:02 +09:00
|
|
|
|
mInputPlane = info.channels;
|
2015-06-01 23:45:40 +09:00
|
|
|
|
|
2019-01-16 02:44:33 +09:00
|
|
|
|
if (mode == eWaifu2xModelTypeNoise && info.has_noise_only) // <20>m<EFBFBD>C<EFBFBD>Y<EFBFBD><59><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>m<C283>C<EFBFBD>Y<EFBFBD><59><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>f<EFBFBD><66><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>݂<EFBFBD><DD82><EFBFBD><EFBFBD>̂ł<CC82><C582><EFBFBD><EFBFBD>m<CE83>C<EFBFBD>Y<EFBFBD><59><EFBFBD><EFBFBD><EFBFBD>X<EFBFBD>P<EFBFBD>[<5B><><EFBFBD><EFBFBD><EFBFBD>f<EFBFBD><66><EFBFBD>͎g<CD8E><67><EFBFBD>Ȃ<EFBFBD><C882>悤<EFBFBD>ɂ<EFBFBD><C982><EFBFBD>
|
|
|
|
|
mHasNoiseScaleOnly = false;
|
|
|
|
|
|
2016-07-03 17:13:02 +09:00
|
|
|
|
if (mode == eWaifu2xModelTypeNoise || mode == eWaifu2xModelTypeNoiseScale || mode == eWaifu2xModelTypeAutoScale)
|
2016-07-03 13:37:26 +09:00
|
|
|
|
{
|
2016-07-03 17:13:02 +09:00
|
|
|
|
std::string base_name;
|
2015-06-01 23:45:40 +09:00
|
|
|
|
|
2016-07-03 17:13:02 +09:00
|
|
|
|
mNoiseNet.reset(new cNet);
|
2015-06-01 23:45:40 +09:00
|
|
|
|
|
2016-07-03 17:13:02 +09:00
|
|
|
|
eWaifu2xModelType Mode = mode;
|
2018-11-23 23:49:22 +09:00
|
|
|
|
if (mHasNoiseScaleOnly) // <20>m<EFBFBD>C<EFBFBD>Y<EFBFBD><59><EFBFBD><EFBFBD><EFBFBD>Ɗg<C68A><67><EFBFBD><EFBFBD><F093AF8E>ɍs<C98D><73>
|
2016-07-03 17:13:02 +09:00
|
|
|
|
{
|
|
|
|
|
// <20>m<EFBFBD>C<EFBFBD>Y<EFBFBD><59><EFBFBD><EFBFBD><EFBFBD>g<EFBFBD><67><EFBFBD>l<EFBFBD>b<EFBFBD>g<EFBFBD>̍\<5C>z<EFBFBD><7A>eWaifu2xModelTypeNoiseScale<6C><65><EFBFBD>w<EFBFBD>肷<EFBFBD><E882B7><EFBFBD>K<EFBFBD>v<EFBFBD><76><EFBFBD><EFBFBD><EFBFBD><EFBFBD>
|
|
|
|
|
Mode = eWaifu2xModelTypeNoiseScale;
|
|
|
|
|
base_name = "noise" + std::to_string(noise_level) + "_scale2.0x_model";
|
|
|
|
|
}
|
|
|
|
|
else // <20>m<EFBFBD>C<EFBFBD>Y<EFBFBD><59><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>
|
|
|
|
|
{
|
|
|
|
|
Mode = eWaifu2xModelTypeNoise;
|
|
|
|
|
base_name = "noise" + std::to_string(noise_level) + "_model";
|
|
|
|
|
}
|
2015-06-01 23:45:40 +09:00
|
|
|
|
|
2016-07-03 13:37:26 +09:00
|
|
|
|
const boost::filesystem::path model_path = mode_dir_path / (base_name + ".prototxt");
|
|
|
|
|
const boost::filesystem::path param_path = mode_dir_path / (base_name + ".json");
|
2015-06-01 23:45:40 +09:00
|
|
|
|
|
2016-07-03 17:13:02 +09:00
|
|
|
|
ret = mNoiseNet->ConstractNet(Mode, model_path, param_path, info, mProcess);
|
2016-07-03 13:37:26 +09:00
|
|
|
|
if (ret != Waifu2x::eWaifu2xError_OK)
|
|
|
|
|
return ret;
|
2015-05-29 01:47:26 +09:00
|
|
|
|
|
2016-07-03 13:37:26 +09:00
|
|
|
|
mMaxNetOffset = mNoiseNet->GetNetOffset();
|
|
|
|
|
}
|
2015-07-10 04:09:22 +09:00
|
|
|
|
|
2019-01-16 02:32:49 +09:00
|
|
|
|
// <20>g<EFBFBD>傪<EFBFBD>K<EFBFBD>v<EFBFBD>ȏꍇ<C88F>̓<EFBFBD><CD83>`<60><><EFBFBD><EFBFBD><EFBFBD>l<EFBFBD><6C><EFBFBD>̊g<CC8A><67><EFBFBD>̂<EFBFBD><CC82>߂<EFBFBD>mScaleNet<65><74><EFBFBD>\<5C>z<EFBFBD><7A><EFBFBD><EFBFBD><EFBFBD>K<EFBFBD>v<EFBFBD><76><EFBFBD><EFBFBD><EFBFBD><EFBFBD>
|
|
|
|
|
if (mode == eWaifu2xModelTypeScale || mode == eWaifu2xModelTypeNoiseScale || mode == eWaifu2xModelTypeAutoScale)
|
2016-07-03 13:37:26 +09:00
|
|
|
|
{
|
|
|
|
|
const std::string base_name = "scale2.0x_model";
|
2015-07-10 04:09:22 +09:00
|
|
|
|
|
2016-07-03 13:37:26 +09:00
|
|
|
|
const boost::filesystem::path model_path = mode_dir_path / (base_name + ".prototxt");
|
|
|
|
|
const boost::filesystem::path param_path = mode_dir_path / (base_name + ".json");
|
2015-07-10 04:09:22 +09:00
|
|
|
|
|
2016-07-03 13:37:26 +09:00
|
|
|
|
mScaleNet.reset(new cNet);
|
2015-07-10 04:09:22 +09:00
|
|
|
|
|
2016-07-03 17:13:02 +09:00
|
|
|
|
ret = mScaleNet->ConstractNet(eWaifu2xModelTypeScale, model_path, param_path, info, mProcess);
|
2016-07-03 13:37:26 +09:00
|
|
|
|
if (ret != Waifu2x::eWaifu2xError_OK)
|
|
|
|
|
return ret;
|
2015-07-10 04:09:22 +09:00
|
|
|
|
|
2016-07-03 13:37:26 +09:00
|
|
|
|
assert(mInputPlane == 0 || mInputPlane == mScaleNet->GetInputPlane());
|
2015-07-10 04:09:22 +09:00
|
|
|
|
|
2016-07-03 13:37:26 +09:00
|
|
|
|
mMaxNetOffset = std::max(mScaleNet->GetNetOffset(), mMaxNetOffset);
|
|
|
|
|
}
|
|
|
|
|
else
|
2015-07-10 04:09:22 +09:00
|
|
|
|
{
|
2016-07-03 13:37:26 +09:00
|
|
|
|
|
2015-07-10 04:09:22 +09:00
|
|
|
|
}
|
|
|
|
|
|
2016-07-03 13:37:26 +09:00
|
|
|
|
mIsInited = true;
|
|
|
|
|
}
|
|
|
|
|
catch (...)
|
2015-07-10 04:09:22 +09:00
|
|
|
|
{
|
2016-07-03 13:37:26 +09:00
|
|
|
|
return Waifu2x::eWaifu2xError_InvalidParameter;
|
2015-07-10 04:09:22 +09:00
|
|
|
|
}
|
|
|
|
|
|
2016-07-03 13:37:26 +09:00
|
|
|
|
return Waifu2x::eWaifu2xError_OK;
|
2015-07-10 04:09:22 +09:00
|
|
|
|
}
|
|
|
|
|
|
2016-07-03 15:36:32 +09:00
|
|
|
|
// Resolves a model directory to the path that should actually be opened.
//
// - An absolute model_dir is returned unchanged.
// - A relative model_dir is first made absolute against the current working
//   directory. If that location does not exist and ExeDir is non-empty and
//   absolute, model_dir is instead resolved against the parent of ExeDir.
//
// The returned path is not guaranteed to exist; callers check that themselves.
// NOTE(review): ExeDir is defined elsewhere; the original (garbled) comment
// mentions argv[0], so it presumably holds the executable's path - confirm.
boost::filesystem::path Waifu2x::GetModeDirPath(const boost::filesystem::path &model_dir)
{
	// Absolute paths are used exactly as given.
	if (model_dir.is_absolute())
		return model_dir;

	// Relative path: try the current working directory first.
	boost::filesystem::path mode_dir_path = boost::filesystem::absolute(model_dir);
	if (boost::filesystem::exists(mode_dir_path) || ExeDir.empty())
		return mode_dir_path;

	// Not found there: fall back to the directory that contains ExeDir.
	// parent_path() replaces the deprecated branch_path() alias (same result).
	const boost::filesystem::path a0(ExeDir);
	if (a0.is_absolute())
		mode_dir_path = a0.parent_path() / model_dir;

	return mode_dir_path;
}
|
|
|
|
|
|
2016-07-03 15:36:32 +09:00
|
|
|
|
// Returns the path of the "info.json" file located directly inside mode_dir_path.
// No existence check is performed.
boost::filesystem::path Waifu2x::GetInfoPath(const boost::filesystem::path &mode_dir_path)
{
	return mode_dir_path / "info.json";
}
|
|
|
|
|
|
2016-07-03 13:37:26 +09:00
|
|
|
|
// Converts one image file: load -> preprocess -> denoise/upscale -> postprocess -> save.
//
// input_file / output_file : source and destination image paths.
// scale_ratio / scale_width / scale_height : optional scaling targets, combined by
//     CalcScaleRatio(). They are ignored (factor forced to 1) when the current mode
//     does not include upscaling.
// cancel_func    : optional callback; forwarded to the reconstruction passes.
// crop_w, crop_h : tile size used when feeding the network.
// output_quality : optional quality value forwarded to stImage::Save().
// output_depth   : bit depth forwarded to stImage::Postprocess().
// use_tta        : enables test-time augmentation.
// batch_size     : batch size forwarded to the network.
//
// Returns eWaifu2xError_NotInitialized when the instance has not been initialised,
// otherwise the first non-OK error reported by a stage, or eWaifu2xError_OK.
Waifu2x::eWaifu2xError Waifu2x::waifu2x(const boost::filesystem::path &input_file, const boost::filesystem::path &output_file,
	const boost::optional<double> scale_ratio, const boost::optional<int> scale_width, const boost::optional<int> scale_height,
	const waifu2xCancelFunc cancel_func, const int crop_w, const int crop_h,
	const boost::optional<int> output_quality, const int output_depth, const bool use_tta,
	const int batch_size)
{
	if (!mIsInited)
		return Waifu2x::eWaifu2xError_NotInitialized;

	stImage image;
	Waifu2x::eWaifu2xError ret = image.Load(input_file);
	if (ret != Waifu2x::eWaifu2xError_OK)
		return ret;

	image.Preprocess(mInputPlane, mMaxNetOffset);

	// In auto-scale mode the image itself decides whether denoising is wanted.
	const bool isReconstructNoise = mMode == eWaifu2xModelTypeNoise || mMode == eWaifu2xModelTypeNoiseScale || (mMode == eWaifu2xModelTypeAutoScale && image.RequestDenoise());
	const bool isReconstructScale = mMode == eWaifu2xModelTypeScale || mMode == eWaifu2xModelTypeNoiseScale || mMode == eWaifu2xModelTypeAutoScale;

	auto factor = CalcScaleRatio(scale_ratio, scale_width, scale_height, image);
	if (!isReconstructScale)
		factor = Factor(1.0, 1.0); // no upscaling in this mode

	ret = ReconstructImage(factor, crop_w, crop_h, use_tta, batch_size, isReconstructNoise, isReconstructScale, cancel_func, image);
	if (ret != Waifu2x::eWaifu2xError_OK)
		return ret;

	// An explicit width AND height pins the final size; otherwise the factor decides it.
	if (scale_width && scale_height)
		image.Postprocess(mInputPlane, *scale_width, *scale_height, output_depth);
	else
		image.Postprocess(mInputPlane, factor, output_depth);

	return image.Save(output_file, output_quality);
}
|
|
|
|
|
|
2016-07-03 13:37:26 +09:00
|
|
|
|
// Converts an in-memory image and writes the result into a caller-provided buffer.
//
// factor       : requested scale factor (forced to 1 when the current mode does not upscale).
// source       : input pixels, handed to stImage::Load() together with width, height,
//                in_channel and in_stride.
// dest         : output buffer; rows are written out_stride bytes apart.
// in_channel / out_channel : accepted combinations are 3->3, 4->4, 3->4, 4->3 and 1->1;
//                anything else yields eWaifu2xError_InvalidParameter.
// out_stride   : distance in bytes between output rows. At most out_stride bytes are
//                written per row, and never more than one produced row contains.
// crop_w, crop_h, use_tta, batch_size : forwarded to the reconstruction passes.
//
// The result is always post-processed at 8-bit depth. No cancel callback is used.
// NOTE(review): dest must be large enough for the scaled image; its size cannot be
// checked here.
Waifu2x::eWaifu2xError Waifu2x::waifu2x(const double factor, const void* source, void* dest, const int width, const int height,
	const int in_channel, const int in_stride, const int out_channel, const int out_stride,
	const int crop_w, const int crop_h, const bool use_tta, const int batch_size)
{
	if (!mIsInited)
		return Waifu2x::eWaifu2xError_NotInitialized;

	// Pick the colour conversion that maps the internal image to the requested
	// output layout; -1 means "no conversion" (single-channel in and out).
	int cvrSetting = -1;
	if (in_channel == 3 && out_channel == 3)
		cvrSetting = CV_BGR2RGB;
	else if (in_channel == 4 && out_channel == 4)
		cvrSetting = CV_BGRA2RGBA;
	else if (in_channel == 3 && out_channel == 4)
		cvrSetting = CV_BGR2RGBA;
	else if (in_channel == 4 && out_channel == 3)
		cvrSetting = CV_BGRA2RGB;
	else if (!(in_channel == 1 && out_channel == 1))
		return Waifu2x::eWaifu2xError_InvalidParameter;

	stImage image;
	Waifu2x::eWaifu2xError ret = image.Load(source, width, height, in_channel, in_stride);
	if (ret != Waifu2x::eWaifu2xError_OK)
		return ret;

	image.Preprocess(mInputPlane, mMaxNetOffset);

	// Unlike the file-based overload, auto-scale mode never triggers denoising here.
	const bool isReconstructNoise = mMode == eWaifu2xModelTypeNoise || mMode == eWaifu2xModelTypeNoiseScale;
	const bool isReconstructScale = mMode == eWaifu2xModelTypeScale || mMode == eWaifu2xModelTypeNoiseScale || mMode == eWaifu2xModelTypeAutoScale;

	Factor nowFactor = Factor(factor, 1.0);
	if (!isReconstructScale)
		nowFactor = Factor(1.0, 1.0);

	ret = ReconstructImage(nowFactor, crop_w, crop_h, use_tta, batch_size, isReconstructNoise, isReconstructScale, nullptr, image);
	if (ret != Waifu2x::eWaifu2xError_OK)
		return ret;

	image.Postprocess(mInputPlane, nowFactor, 8);

	cv::Mat out_bgr_image = image.GetEndImage();
	image.Clear();

	cv::Mat out_image;
	if (cvrSetting >= 0)
		cv::cvtColor(out_bgr_image, out_image, cvrSetting); // swap the channel order back (adding/dropping alpha as requested)
	else
		out_image = out_bgr_image;
	out_bgr_image.release();

	// Copy the result row by row into the caller's buffer.
	// The copy length is clamped to the bytes one produced row really holds, so a
	// padded out_stride no longer reads past the end of the source row/buffer, and
	// the destination offset is advanced as a pointer instead of an int product
	// that could overflow for large images. A non-positive stride copies nothing.
	if (out_stride > 0)
	{
		const size_t dst_stride = static_cast<size_t>(out_stride);
		const size_t row_bytes = static_cast<size_t>(out_image.cols) * out_image.elemSize();
		const size_t copy_bytes = dst_stride < row_bytes ? dst_stride : row_bytes;

		uint8_t *dst_row = static_cast<uint8_t *>(dest);
		for (int y = 0; y < out_image.rows; y++, dst_row += dst_stride)
			memcpy(dst_row, out_image.ptr(y), copy_bytes);
	}

	return Waifu2x::eWaifu2xError_OK;
}
|
|
|
|
|
|
2017-04-29 12:41:42 +09:00
|
|
|
|
// Derives the scale factor to apply from the optional user targets.
//
// Priority:
//   1. scale_ratio, when set, is used directly.
//   2. When both scale_width and scale_height are set, the larger of the two
//      per-axis factors is used, so the upscaled image covers the requested
//      box in both dimensions.
//   3. scale_width alone, then scale_height alone.
//   4. Otherwise a factor of 1.
Factor Waifu2x::CalcScaleRatio(const boost::optional<double> scale_ratio, const boost::optional<int> scale_width, const boost::optional<int> scale_height,
	const stImage &image)
{
	if (scale_ratio)
		return Factor(*scale_ratio, 1.0);

	if (scale_width && scale_height)
	{
		const auto byWidth = image.GetScaleFromWidth(*scale_width);
		// Fixed: the height target must be measured against the image height
		// (this previously called GetScaleFromWidth with the height value).
		const auto byHeight = image.GetScaleFromHeight(*scale_height);

		return byWidth.toDouble() >= byHeight.toDouble() ? byWidth : byHeight;
	}

	if (scale_width)
		return image.GetScaleFromWidth(*scale_width);

	if (scale_height)
		return image.GetScaleFromHeight(*scale_height);

	return Factor(1.0, 1.0);
}
|
2016-02-04 17:38:27 +09:00
|
|
|
|
|
2016-07-06 18:41:13 +09:00
|
|
|
|
int Waifu2x::GetcuDNNAlgorithm(const char * layer_name, int num_input, int num_output, int batch_size,
|
|
|
|
|
int width, int height, int kernel_w, int kernel_h, int pad_w, int pad_h, int stride_w, int stride_h)
|
|
|
|
|
{
|
2016-07-11 00:40:47 +09:00
|
|
|
|
// g_ConvCcuDNNAlgorithm<68><6D>g_DeconvCcuDNNAlgorithm<68><6D><EFBFBD>t<EFBFBD>ɂȂ<C982><C882>Ă<EFBFBD><C482>܂<EFBFBD><DC82>Ă<EFBFBD><C482>邪<EFBFBD>A<EFBFBD>t<EFBFBD>@<40>C<EFBFBD><43><EFBFBD><EFBFBD><EFBFBD>ɂ<EFBFBD><C982><EFBFBD><EFBFBD>e<EFBFBD><65><EFBFBD><EFBFBD><EFBFBD>Ȃ<EFBFBD><C882>̂ƌ݊<C68C><DD8A><EFBFBD><EFBFBD><EFBFBD><EFBFBD>Ȃ<EFBFBD><C882>Ȃ<EFBFBD><C882>̂ł<CC82><C582>̂܂d<DC8E>l<EFBFBD>Ƃ<EFBFBD><C682><EFBFBD>
|
2016-07-06 18:41:13 +09:00
|
|
|
|
if (strcmp(layer_name, "Deconvolution") == 0)
|
|
|
|
|
return g_ConvCcuDNNAlgorithm.GetAlgorithm(num_input, num_output, batch_size, width, height, kernel_w, kernel_h, pad_w, pad_h, stride_w, stride_h);
|
|
|
|
|
else if (strcmp(layer_name, "Convolution") == 0)
|
|
|
|
|
return g_DeconvCcuDNNAlgorithm.GetAlgorithm(num_input, num_output, batch_size, width, height, kernel_w, kernel_h, pad_w, pad_h, stride_w, stride_h);
|
|
|
|
|
|
|
|
|
|
return -1;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void Waifu2x::SetcuDNNAlgorithm(int algo, const char * layer_name, int num_input, int num_output, int batch_size,
|
|
|
|
|
int width, int height, int kernel_w, int kernel_h, int pad_w, int pad_h, int stride_w, int stride_h)
|
|
|
|
|
{
|
2016-07-11 00:40:47 +09:00
|
|
|
|
// g_ConvCcuDNNAlgorithm<68><6D>g_DeconvCcuDNNAlgorithm<68><6D><EFBFBD>t<EFBFBD>ɂȂ<C982><C882>Ă<EFBFBD><C482>܂<EFBFBD><DC82>Ă<EFBFBD><C482>邪<EFBFBD>A<EFBFBD>t<EFBFBD>@<40>C<EFBFBD><43><EFBFBD><EFBFBD><EFBFBD>ɂ<EFBFBD><C982><EFBFBD><EFBFBD>e<EFBFBD><65><EFBFBD><EFBFBD><EFBFBD>Ȃ<EFBFBD><C882>̂ƌ݊<C68C><DD8A><EFBFBD><EFBFBD><EFBFBD><EFBFBD>Ȃ<EFBFBD><C882>Ȃ<EFBFBD><C882>̂ł<CC82><C582>̂܂d<DC8E>l<EFBFBD>Ƃ<EFBFBD><C682><EFBFBD>
|
2016-07-06 18:41:13 +09:00
|
|
|
|
if (strcmp(layer_name, "Deconvolution") == 0)
|
|
|
|
|
return g_ConvCcuDNNAlgorithm.SetAlgorithm(algo, num_input, num_output, batch_size, width, height, kernel_w, kernel_h, pad_w, pad_h, stride_w, stride_h);
|
|
|
|
|
else if (strcmp(layer_name, "Convolution") == 0)
|
|
|
|
|
return g_DeconvCcuDNNAlgorithm.SetAlgorithm(algo, num_input, num_output, batch_size, width, height, kernel_w, kernel_h, pad_w, pad_h, stride_w, stride_h);
|
|
|
|
|
}
|
|
|
|
|
|
2017-04-29 12:41:42 +09:00
|
|
|
|
// Runs the denoise and/or upscale passes over `image`.
//
// Denoise stage (only when isReconstructNoise):
//   - mHasNoiseScaleOnly set   : the combined pass ReconstructNoiseScale() is run and
//     the remaining factor is adjusted via MultiDenominator(mNoiseNet->GetInnerScale())
//     (the commented-out original read "nowFactor /= GetInnerScale()").
//   - mHasNoiseScaleOnly clear : the RGB planes are run through mNoiseNet at scale 1.
// Upscale stage (only when isReconstructScale):
//   ReconstructScale() is repeated ceil(log(remaining factor) / log(ScaleBase)) times.
//
// cancel_func is polled once between the two stages and is also forwarded to the
// individual passes. Returns the first non-OK error, eWaifu2xError_Cancel when the
// callback asks to stop, or eWaifu2xError_OK.
Waifu2x::eWaifu2xError Waifu2x::ReconstructImage(const Factor factor, const int crop_w, const int crop_h, const bool use_tta, const int batch_size,
	const bool isReconstructNoise, const bool isReconstructScale, const Waifu2x::waifu2xCancelFunc cancel_func, stImage &image)
{
	Factor remainingFactor = factor;

	if (isReconstructNoise && mHasNoiseScaleOnly)
	{
		// Denoise and upscale in a single network pass.
		const Waifu2x::eWaifu2xError status = ReconstructNoiseScale(crop_w, crop_h, use_tta, batch_size, cancel_func, image);
		if (status != Waifu2x::eWaifu2xError_OK)
			return status;

		remainingFactor = remainingFactor.MultiDenominator(mNoiseNet->GetInnerScale());
	}
	else if (isReconstructNoise)
	{
		// Denoise only: the image size is unchanged (scale 1).
		cv::Mat rgb;
		cv::Size_<int> rgbSize;
		image.GetScalePaddingedRGB(rgb, rgbSize, mNoiseNet->GetNetOffset(), OuterPadding, crop_w, crop_h, 1);

		const Waifu2x::eWaifu2xError status = ReconstructByNet(mNoiseNet, crop_w, crop_h, use_tta, batch_size, cancel_func, rgb);
		if (status != Waifu2x::eWaifu2xError_OK)
			return status;

		image.SetReconstructedRGB(rgb, rgbSize, 1);
	}

	if (cancel_func && cancel_func())
		return Waifu2x::eWaifu2xError_Cancel;

	// Number of ScaleBase-sized upscale passes needed to reach the remaining factor.
	const int passCount = ceil(log(remainingFactor.toDouble()) / log(ScaleBase));

	if (!isReconstructScale)
		return Waifu2x::eWaifu2xError_OK;

	for (int pass = 0; pass < passCount; ++pass)
	{
		const Waifu2x::eWaifu2xError status = ReconstructScale(crop_w, crop_h, use_tta, batch_size, cancel_func, image);
		if (status != Waifu2x::eWaifu2xError_OK)
			return status;
	}

	return Waifu2x::eWaifu2xError_OK;
}
|
|
|
|
|
|
2016-07-03 13:37:26 +09:00
|
|
|
|
// Performs one upscale pass with mScaleNet.
// The alpha plane (when the image has one) is processed first, then the RGB planes.
// Each plane set is fetched padded from `image`, run through the network and written
// back. Returns the first non-OK error from the network, otherwise eWaifu2xError_OK.
Waifu2x::eWaifu2xError Waifu2x::ReconstructScale(const int crop_w, const int crop_h, const bool use_tta, const int batch_size,
	const Waifu2x::waifu2xCancelFunc cancel_func, stImage &image)
{
	if (image.HasAlpha())
	{
		cv::Mat alpha;
		cv::Size_<int> alphaSize;
		image.GetScalePaddingedA(alpha, alphaSize, mScaleNet->GetNetOffset(), OuterPadding, crop_w, crop_h, mScaleNet->GetScale() / mScaleNet->GetInnerScale());

		const Waifu2x::eWaifu2xError alphaStatus = ReconstructByNet(mScaleNet, crop_w, crop_h, use_tta, batch_size, cancel_func, alpha);
		if (alphaStatus != Waifu2x::eWaifu2xError_OK)
			return alphaStatus;

		image.SetReconstructedA(alpha, alphaSize, mScaleNet->GetInnerScale());
	}

	cv::Mat rgb;
	cv::Size_<int> rgbSize;
	image.GetScalePaddingedRGB(rgb, rgbSize, mScaleNet->GetNetOffset(), OuterPadding, crop_w, crop_h, mScaleNet->GetScale() / mScaleNet->GetInnerScale());

	const Waifu2x::eWaifu2xError rgbStatus = ReconstructByNet(mScaleNet, crop_w, crop_h, use_tta, batch_size, cancel_func, rgb);
	if (rgbStatus != Waifu2x::eWaifu2xError_OK)
		return rgbStatus;

	image.SetReconstructedRGB(rgb, rgbSize, mScaleNet->GetInnerScale());

	return Waifu2x::eWaifu2xError_OK;
}
|
|
|
|
|
|
2016-07-03 17:13:02 +09:00
|
|
|
|
// Performs one combined denoise + upscale pass.
// The RGB planes go through mNoiseNet. The alpha plane (when present) is only
// upscaled, using mScaleNet - no denoising is applied to alpha.
// Returns the first non-OK error from the network, otherwise eWaifu2xError_OK.
Waifu2x::eWaifu2xError Waifu2x::ReconstructNoiseScale(const int crop_w, const int crop_h, const bool use_tta, const int batch_size,
	const Waifu2x::waifu2xCancelFunc cancel_func, stImage &image)
{
	if (image.HasAlpha())
	{
		// Alpha channel: upscale only.
		cv::Mat alpha;
		cv::Size_<int> alphaSize;
		image.GetScalePaddingedA(alpha, alphaSize, mScaleNet->GetNetOffset(), OuterPadding, crop_w, crop_h, mScaleNet->GetScale() / mScaleNet->GetInnerScale());

		const Waifu2x::eWaifu2xError alphaStatus = ReconstructByNet(mScaleNet, crop_w, crop_h, use_tta, batch_size, cancel_func, alpha);
		if (alphaStatus != Waifu2x::eWaifu2xError_OK)
			return alphaStatus;

		image.SetReconstructedA(alpha, alphaSize, mScaleNet->GetInnerScale());
	}

	// Colour planes: denoise and upscale together.
	cv::Mat rgb;
	cv::Size_<int> rgbSize;
	image.GetScalePaddingedRGB(rgb, rgbSize, mNoiseNet->GetNetOffset(), OuterPadding, crop_w, crop_h, mNoiseNet->GetScale() / mNoiseNet->GetInnerScale());

	const Waifu2x::eWaifu2xError rgbStatus = ReconstructByNet(mNoiseNet, crop_w, crop_h, use_tta, batch_size, cancel_func, rgb);
	if (rgbStatus != Waifu2x::eWaifu2xError_OK)
		return rgbStatus;

	image.SetReconstructedRGB(rgb, rgbSize, mNoiseNet->GetInnerScale());

	return Waifu2x::eWaifu2xError_OK;
}
|
|
|
|
|
|
2016-07-03 17:13:02 +09:00
|
|
|
|
|
2016-07-03 13:37:26 +09:00
|
|
|
|
// Runs `im` through `net`, optionally with test-time augmentation (TTA).
//
// Without TTA the image is processed once, in place.
// With TTA the image is processed eight times - four quarter-turn rotations,
// each with and without a horizontal flip. Every result is transformed back to
// the original orientation, and the eight results are averaged into `im`.
//
// NOTE(review): cancel_func is accepted but not consulted inside this function.
// Returns the first non-OK error from ProcessNet(), otherwise eWaifu2xError_OK.
Waifu2x::eWaifu2xError Waifu2x::ReconstructByNet(std::shared_ptr<cNet> net, const int crop_w, const int crop_h, const bool use_tta, const int batch_size,
	const Waifu2x::waifu2xCancelFunc cancel_func, cv::Mat &im)
{
	// Plain processing.
	if (!use_tta)
		return ProcessNet(net, crop_w, crop_h, use_tta, batch_size, im);

	// Rotates `mat` clockwise by `times` quarter turns.
	const auto rotateClockwise = [](cv::Mat &mat, const int times)
	{
		for (int turn = 0; turn < times; ++turn)
		{
			cv::transpose(mat, mat);
			cv::flip(mat, mat, 1);
		}
	};

	// Rotates `mat` counter-clockwise by `times` quarter turns.
	const auto rotateCounterclockwise = [](cv::Mat &mat, const int times)
	{
		for (int turn = 0; turn < times; ++turn)
		{
			cv::transpose(mat, mat);
			cv::flip(mat, mat, 0);
		}
	};

	cv::Mat accumulated;
	for (int variant = 0; variant < 8; ++variant)
	{
		const int quarterTurns = variant % 4;
		const bool mirrored = variant >= 4;

		// Build the augmented input from a fresh copy of the source.
		cv::Mat work(im.clone());
		rotateClockwise(work, quarterTurns);
		if (mirrored)
			cv::flip(work, work, 1); // horizontal flip

		// An odd number of quarter turns swaps width and height, so the crop
		// dimensions are swapped to match.
		const bool axesSwapped = (quarterTurns % 2) != 0;
		const int cw = axesSwapped ? crop_h : crop_w;
		const int ch = axesSwapped ? crop_w : crop_h;

		const Waifu2x::eWaifu2xError status = ProcessNet(net, cw, ch, use_tta, batch_size, work);
		if (status != Waifu2x::eWaifu2xError_OK)
			return status;

		// Undo the augmentation in reverse order.
		if (mirrored)
			cv::flip(work, work, 1); // horizontal flip
		rotateCounterclockwise(work, quarterTurns);

		if (variant == 0)
			accumulated = work;
		else
			accumulated += work;
	}

	// Average of the eight variants.
	accumulated /= 8.0;

	im = accumulated;

	return Waifu2x::eWaifu2xError_OK;
}
|
|
|
|
|
|
2016-07-03 13:37:26 +09:00
|
|
|
|
// Runs a single network pass over `im` (in place), growing the shared output
// staging buffer (mOutputBlock) first when the pass needs more room than is
// currently allocated.
//
// The buffer is pinned host memory (cudaHostAlloc) when mIsCuda is set and a
// plain heap array otherwise. It is only ever grown here, never shrunk.
// Returns whatever cNet::ReconstructImage() reports.
//
// NOTE(review): the same OutputMemorySize value is used as a byte count for
// cudaHostAlloc and as an element count for new float[] - confirm which unit
// GetOutputMemorySize() returns.
Waifu2x::eWaifu2xError Waifu2x::ProcessNet(std::shared_ptr<cNet> net, const int crop_w, const int crop_h, const bool use_tta, const int batch_size, cv::Mat &im)
{
	Waifu2x::eWaifu2xError ret;

	CudaDeviceSet devset(mProcess, mGPUNo);

	const auto OutputMemorySize = net->GetOutputMemorySize(crop_w, crop_h, OuterPadding, batch_size);
	if (OutputMemorySize > mOutputBlockSize)
	{
		// The recorded size is reset as soon as the old buffer is released, so
		// that a failed allocation below cannot leave a stale size behind and
		// make a later, smaller request skip allocation.
		if (mIsCuda)
		{
			CUDA_HOST_SAFE_FREE(mOutputBlock);
			mOutputBlockSize = 0;
			CUDA_CHECK_WAIFU2X(cudaHostAlloc(&mOutputBlock, OutputMemorySize, cudaHostAllocDefault));
		}
		else
		{
			SAFE_DELETE_WAIFU2X(mOutputBlock);
			mOutputBlockSize = 0;
			mOutputBlock = new float[OutputMemorySize];
		}

		mOutputBlockSize = OutputMemorySize;
	}

	ret = net->ReconstructImage(use_tta, crop_w, crop_h, OuterPadding, batch_size, mOutputBlock, im, im);

	return ret;
}
|
|
|
|
|
|
2016-07-03 13:37:26 +09:00
|
|
|
|
void Waifu2x::Destroy()
|
2016-02-06 00:52:49 +09:00
|
|
|
|
{
|
2016-07-03 22:23:54 +09:00
|
|
|
|
CudaDeviceSet devset(mProcess, mGPUNo);
|
2016-02-06 00:52:49 +09:00
|
|
|
|
|
2016-07-03 13:37:26 +09:00
|
|
|
|
mNoiseNet.reset();
|
|
|
|
|
mScaleNet.reset();
|
2016-02-06 00:52:49 +09:00
|
|
|
|
|
2016-07-03 13:37:26 +09:00
|
|
|
|
if (mIsCuda)
|
2016-02-06 00:52:49 +09:00
|
|
|
|
{
|
2016-07-03 13:37:26 +09:00
|
|
|
|
CUDA_HOST_SAFE_FREE(mOutputBlock);
|
2016-02-06 00:52:49 +09:00
|
|
|
|
}
|
2016-07-03 13:37:26 +09:00
|
|
|
|
else
|
2016-02-06 00:52:49 +09:00
|
|
|
|
{
|
2016-07-03 13:37:26 +09:00
|
|
|
|
SAFE_DELETE_WAIFU2X(mOutputBlock);
|
2016-02-06 00:52:49 +09:00
|
|
|
|
}
|
2015-06-02 01:04:20 +09:00
|
|
|
|
|
2016-07-03 13:37:26 +09:00
|
|
|
|
mIsInited = false;
|
2015-05-29 01:47:26 +09:00
|
|
|
|
}
|
2015-06-03 03:01:56 +09:00
|
|
|
|
|
|
|
|
|
const std::string& Waifu2x::used_process() const
|
|
|
|
|
{
|
2016-07-03 13:37:26 +09:00
|
|
|
|
return mProcess;
|
2015-12-27 07:20:55 +09:00
|
|
|
|
}
|
|
|
|
|
|
2016-07-03 15:36:32 +09:00
|
|
|
|
// Returns the model name recorded in the model directory's "info.json".
// model_dir is resolved with GetModeDirPath(); an empty string is returned when
// the resolved directory does not exist.
std::string Waifu2x::GetModelName(const boost::filesystem::path & model_dir)
{
	const boost::filesystem::path resolved_dir(GetModeDirPath(model_dir));

	if (boost::filesystem::exists(resolved_dir))
	{
		const boost::filesystem::path info_file = resolved_dir / "info.json";
		return cNet::GetModelName(info_file);
	}

	return std::string();
}
|
2018-10-25 04:15:53 +09:00
|
|
|
|
|
|
|
|
|
// Loads the model directory's "info.json" into `info`.
// model_dir is resolved with GetModeDirPath(). Returns false when the resolved
// directory does not exist or when cNet::GetInfo() reports anything other than
// eWaifu2xError_OK; returns true otherwise.
bool Waifu2x::GetInfo(const boost::filesystem::path &model_dir, stInfo &info)
{
	const boost::filesystem::path resolved_dir(GetModeDirPath(model_dir));

	if (boost::filesystem::exists(resolved_dir))
	{
		const boost::filesystem::path info_file = resolved_dir / "info.json";
		return cNet::GetInfo(info_file, info) == Waifu2x::eWaifu2xError_OK;
	}

	return false;
}
|