diff --git a/common/waifu2x.cpp b/common/waifu2x.cpp index 4120f69..dc6f058 100644 --- a/common/waifu2x.cpp +++ b/common/waifu2x.cpp @@ -815,7 +815,7 @@ Waifu2x::eWaifu2xError Waifu2x::ReconstructImage(boost::shared_ptr planes; - cv::split(float_image, planes); + cv::split(in, planes); - if (float_image.channels() == 4) + if (in.channels() == 4) planes.resize(3); // BGRからRGBにする @@ -1058,13 +1049,19 @@ Waifu2x::eWaifu2xError Waifu2x::waifu2x(const std::string &input_file, const std cv::merge(planes, im); } + + out = im; + + return eWaifu2xError_OK; +} + +Waifu2x::eWaifu2xError Waifu2x::ReconstructFloatMat(const bool isJpeg, const waifu2xCancelFunc cancel_func, const cv::Mat &in, cv::Mat &out) +{ + Waifu2x::eWaifu2xError ret; + + cv::Mat im(in); cv::Size_ image_size = im.size(); - const boost::filesystem::path ip(input_file); - const boost::filesystem::path ipext(ip.extension()); - - const bool isJpeg = boost::iequals(ipext.string(), ".jpg") || boost::iequals(ipext.string(), ".jpeg"); - const bool isReconstructNoise = mode == "noise" || mode == "noise_scale" || (mode == "auto_scale" && isJpeg); const bool isReconstructScale = mode == "scale" || mode == "noise_scale"; @@ -1084,7 +1081,6 @@ Waifu2x::eWaifu2xError Waifu2x::waifu2x(const std::string &input_file, const std return eWaifu2xError_Cancel; const int scale2 = ceil(log2(scale_ratio)); - const double shrinkRatio = scale_ratio / std::pow(2.0, (double)scale2); if (isReconstructScale) { @@ -1105,18 +1101,24 @@ Waifu2x::eWaifu2xError Waifu2x::waifu2x(const std::string &input_file, const std if (cancel_func && cancel_func()) return eWaifu2xError_Cancel; + out = im; + + return eWaifu2xError_OK; +} + +Waifu2x::eWaifu2xError Waifu2x::AfterReconstructFloatMatProcess(const cv::Mat &floatim, const cv::Mat &in, cv::Mat &out) +{ + cv::Size_ image_size = in.size(); + cv::Mat process_image; if (input_plane == 1) { // 再構築した輝度画像とCreateZoomColorImage()で作成した色情報をマージして通常の画像に変換し、書き込む std::vector color_planes; - CreateZoomColorImage(float_image, image_size, color_planes); + CreateZoomColorImage(floatim, image_size, color_planes); - float_image.release(); - - color_planes[0] = im; - im.release(); + color_planes[0] = in; cv::Mat converted_image; cv::merge(color_planes, converted_image); @@ -1128,7 +1130,7 @@ Waifu2x::eWaifu2xError Waifu2x::waifu2x(const std::string &input_file, const std else { std::vector planes; - cv::split(im, planes); + cv::split(in, planes); // RGBからBGRに直す std::swap(planes[0], planes[2]); @@ -1137,10 +1139,10 @@ Waifu2x::eWaifu2xError Waifu2x::waifu2x(const std::string &input_file, const std } cv::Mat alpha; - if (float_image.channels() == 4) + if (floatim.channels() == 4) { std::vector planes; - cv::split(float_image, planes); + cv::split(floatim, planes); alpha = planes[3]; cv::resize(alpha, alpha, image_size, 0.0, 0.0, cv::INTER_CUBIC); @@ -1164,10 +1166,117 @@ Waifu2x::eWaifu2xError Waifu2x::waifu2x(const std::string &input_file, const std cv::merge(planes, process_image); } + const int scale2 = ceil(log2(scale_ratio)); + const double shrinkRatio = scale_ratio / std::pow(2.0, (double)scale2); + const cv::Size_ ns(image_size.width * shrinkRatio, image_size.height * shrinkRatio); if (image_size.width != ns.width || image_size.height != ns.height) cv::resize(process_image, process_image, ns, 0.0, 0.0, cv::INTER_LINEAR); + out = process_image; + + return eWaifu2xError_OK; +} + +Waifu2x::eWaifu2xError Waifu2x::waifu2x(const std::string &input_file, const std::string &output_file, + const waifu2xCancelFunc cancel_func) +{ + Waifu2x::eWaifu2xError ret; + + if (!is_inited) + return eWaifu2xError_NotInitialized; + + const boost::filesystem::path ip(input_file); + const boost::filesystem::path ipext(ip.extension()); + + const bool isJpeg = boost::iequals(ipext.string(), ".jpg") || boost::iequals(ipext.string(), ".jpeg"); + + cv::Mat float_image; + ret = LoadMat(float_image, input_file); + if (ret != eWaifu2xError_OK) + return ret; + + cv::Mat brfm; + ret = BeforeReconstructFloatMatProcess(float_image, brfm); + if (ret != eWaifu2xError_OK) + return ret; + + cv::Mat reconstruct_image; + if (!use_tta) // 普通に処理 + { + ret = ReconstructFloatMat(isJpeg, cancel_func, brfm, reconstruct_image); + if (ret != eWaifu2xError_OK) + return ret; + } + else // Test-Time Augmentation Mode + { + const auto RotateClockwise90 = [](cv::Mat &mat) + { + cv::transpose(mat, mat); + cv::flip(mat, mat, 1); + }; + + const auto RotateClockwise90N = [RotateClockwise90](cv::Mat &mat, const int rotateNum) + { + for (int i = 0; i < rotateNum; i++) + RotateClockwise90(mat); + }; + + const auto RotateCounterclockwise90 = [](cv::Mat &mat) + { + cv::transpose(mat, mat); + cv::flip(mat, mat, 0); + }; + + const auto RotateCounterclockwise90N = [RotateCounterclockwise90](cv::Mat &mat, const int rotateNum) + { + for (int i = 0; i < rotateNum; i++) + RotateCounterclockwise90(mat); + }; + + cv::Mat ri[8]; + for (int i = 0; i < 8; i++) + { + cv::Mat in(brfm.clone()); + + cv::imwrite("0.png", in * 255.0); + + const int rotateNum = i % 4; + RotateClockwise90N(in, rotateNum); + cv::imwrite("1.png", in * 255.0); + + if(i >= 4) + cv::flip(in, in, 1); // 垂直軸反転 + cv::imwrite("2.png", in * 255.0); + + ret = ReconstructFloatMat(isJpeg, cancel_func, in, in); + if (ret != eWaifu2xError_OK) + return ret; + cv::imwrite("3.png", in * 255.0); + if (i >= 4) + cv::flip(in, in, 1); // 垂直軸反転 + cv::imwrite("4.png", in * 255.0); + RotateCounterclockwise90N(in, rotateNum); + cv::imwrite("5.png", in * 255.0); + ri[i] = in; + } + + reconstruct_image = ri[0]; + for (int i = 1; i < 8; i++) + reconstruct_image += ri[i]; + + reconstruct_image /= 8.0; + } + + brfm.release(); + + cv::Mat process_image; + ret = AfterReconstructFloatMatProcess(float_image, reconstruct_image, process_image); + if (ret != eWaifu2xError_OK) + return ret; + + float_image.release(); + cv::Mat write_iamge; process_image.convertTo(write_iamge, CV_8U, 255.0); process_image.release(); diff --git a/common/waifu2x.h b/common/waifu2x.h index 738849a..f5e9fad 100644 --- a/common/waifu2x.h +++ b/common/waifu2x.h @@ -90,6 +90,8 @@ private: float *dummy_data; float *output_block; + bool use_tta; + private: static eWaifu2xError LoadMat(cv::Mat &float_image, const std::string &input_file); static eWaifu2xError LoadMatBySTBI(cv::Mat &float_image, const std::string &input_file); @@ -103,6 +105,10 @@ private: eWaifu2xError ReconstructImage(boost::shared_ptr> net, cv::Mat &im); eWaifu2xError WriteMat(const cv::Mat &im, const std::string &output_file); + eWaifu2xError BeforeReconstructFloatMatProcess(const cv::Mat &in, cv::Mat &out); + eWaifu2xError ReconstructFloatMat(const bool isJpeg, const waifu2xCancelFunc cancel_func, const cv::Mat &in, cv::Mat &out); + eWaifu2xError AfterReconstructFloatMatProcess(const cv::Mat &floatim, const cv::Mat &in, cv::Mat &out); + public: Waifu2x(); ~Waifu2x(); @@ -113,7 +119,7 @@ public: // mode: noise or scale or noise_scale or auto_scale // process: cpu or gpu or cudnn eWaifu2xError init(int argc, char** argv, const std::string &mode, const int noise_level, const double scale_ratio, const std::string &model_dir, const std::string &process, - const int crop_size = 128, const int batch_size = 1); + const bool use_tta = false, const int crop_size = 128, const int batch_size = 1); void destroy(); diff --git a/waifu2x-caffe/Source.cpp b/waifu2x-caffe/Source.cpp index 5420978..fe24de6 100644 --- a/waifu2x-caffe/Source.cpp +++ b/waifu2x-caffe/Source.cpp @@ -107,6 +107,13 @@ int main(int argc, char** argv) "input batch size", false, 1, "int", cmd); + std::vector cmdTTAConstraintV; + cmdTTAConstraintV.push_back(0); + cmdTTAConstraintV.push_back(1); + TCLAP::ValuesConstraint cmdTTAConstraint(cmdTTAConstraintV); + TCLAP::ValueArg cmdTTALevel("t", "tta", "8x slower and slightly high quality", + false, 0, &cmdTTAConstraint, cmd); + // definition of command line argument : end TCLAP::Arg::enableIgnoreMismatched(); @@ -237,7 +244,7 @@ int main(int argc, char** argv) Waifu2x::eWaifu2xError ret; Waifu2x w; - ret = w.init(argc, argv, cmdMode.getValue(), cmdNRLevel.getValue(), cmdScaleRatio.getValue(), cmdModelPath.getValue(), cmdProcess.getValue(), + ret = w.init(argc, argv, cmdMode.getValue(), cmdNRLevel.getValue(), cmdScaleRatio.getValue(), cmdModelPath.getValue(), cmdProcess.getValue(), cmdTTALevel.getValue() == 1, cmdCropSizeFile.getValue(), cmdBatchSizeFile.getValue()); switch (ret) {