diff --git a/common/waifu2x.cpp b/common/waifu2x.cpp index b9c65ad..396db4b 100644 --- a/common/waifu2x.cpp +++ b/common/waifu2x.cpp @@ -1204,6 +1204,83 @@ Waifu2x::eWaifu2xError Waifu2x::ReconstructFloatMat(const bool isJpeg, const wai return eWaifu2xError_OK; } +Waifu2x::eWaifu2xError Waifu2x::Reconstruct(const bool isJpeg, const waifu2xCancelFunc cancel_func, const cv::Mat &in, cv::Mat &out) +{ + Waifu2x::eWaifu2xError ret; + + cv::Mat brfm; + ret = BeforeReconstructFloatMatProcess(in, brfm); + if (ret != eWaifu2xError_OK) + return ret; + + cv::Mat reconstruct_image; + if (!use_tta) // 普通に処理 + { + ret = ReconstructFloatMat(isJpeg, cancel_func, brfm, reconstruct_image); + if (ret != eWaifu2xError_OK) + return ret; + } + else // Test-Time Augmentation Mode + { + const auto RotateClockwise90 = [](cv::Mat &mat) + { + cv::transpose(mat, mat); + cv::flip(mat, mat, 1); + }; + + const auto RotateClockwise90N = [RotateClockwise90](cv::Mat &mat, const int rotateNum) + { + for (int i = 0; i < rotateNum; i++) + RotateClockwise90(mat); + }; + + const auto RotateCounterclockwise90 = [](cv::Mat &mat) + { + cv::transpose(mat, mat); + cv::flip(mat, mat, 0); + }; + + const auto RotateCounterclockwise90N = [RotateCounterclockwise90](cv::Mat &mat, const int rotateNum) + { + for (int i = 0; i < rotateNum; i++) + RotateCounterclockwise90(mat); + }; + + cv::Mat ri[8]; + for (int i = 0; i < 8; i++) + { + cv::Mat in(brfm.clone()); + + const int rotateNum = i % 4; + RotateClockwise90N(in, rotateNum); + + if (i >= 4) + cv::flip(in, in, 1); // 垂直軸反転 + + ret = ReconstructFloatMat(isJpeg, cancel_func, in, in); + if (ret != eWaifu2xError_OK) + return ret; + + if (i >= 4) + cv::flip(in, in, 1); // 垂直軸反転 + + RotateCounterclockwise90N(in, rotateNum); + + ri[i] = in; + } + + reconstruct_image = ri[0]; + for (int i = 1; i < 8; i++) + reconstruct_image += ri[i]; + + reconstruct_image /= 8.0; + } + + out = reconstruct_image; + + return eWaifu2xError_OK; +} + Waifu2x::eWaifu2xError Waifu2x::AfterReconstructFloatMatProcess(const cv::Mat &floatim, const cv::Mat &in, cv::Mat &out) { cv::Size_ image_size = in.size(); @@ -1310,76 +1387,11 @@ Waifu2x::eWaifu2xError Waifu2x::waifu2x(const std::string &input_file, const std if (ret != eWaifu2xError_OK) return ret; - cv::Mat brfm; - ret = BeforeReconstructFloatMatProcess(float_image, brfm); + cv::Mat reconstruct_image; + ret = Reconstruct(isJpeg, cancel_func, float_image, reconstruct_image); if (ret != eWaifu2xError_OK) return ret; - cv::Mat reconstruct_image; - if (!use_tta) // 普通に処理 - { - ret = ReconstructFloatMat(isJpeg, cancel_func, brfm, reconstruct_image); - if (ret != eWaifu2xError_OK) - return ret; - } - else // Test-Time Augmentation Mode - { - const auto RotateClockwise90 = [](cv::Mat &mat) - { - cv::transpose(mat, mat); - cv::flip(mat, mat, 1); - }; - - const auto RotateClockwise90N = [RotateClockwise90](cv::Mat &mat, const int rotateNum) - { - for (int i = 0; i < rotateNum; i++) - RotateClockwise90(mat); - }; - - const auto RotateCounterclockwise90 = [](cv::Mat &mat) - { - cv::transpose(mat, mat); - cv::flip(mat, mat, 0); - }; - - const auto RotateCounterclockwise90N = [RotateCounterclockwise90](cv::Mat &mat, const int rotateNum) - { - for (int i = 0; i < rotateNum; i++) - RotateCounterclockwise90(mat); - }; - - cv::Mat ri[8]; - for (int i = 0; i < 8; i++) - { - cv::Mat in(brfm.clone()); - - const int rotateNum = i % 4; - RotateClockwise90N(in, rotateNum); - - if(i >= 4) - cv::flip(in, in, 1); // 垂直軸反転 - - ret = ReconstructFloatMat(isJpeg, cancel_func, in, in); - if (ret != eWaifu2xError_OK) - return ret; - - if (i >= 4) - cv::flip(in, in, 1); // 垂直軸反転 - - RotateCounterclockwise90N(in, rotateNum); - - ri[i] = in; - } - - reconstruct_image = ri[0]; - for (int i = 1; i < 8; i++) - reconstruct_image += ri[i]; - - reconstruct_image /= 8.0; - } - - brfm.release(); - cv::Mat process_image; ret = AfterReconstructFloatMatProcess(float_image, reconstruct_image, process_image); if (ret != eWaifu2xError_OK) diff --git a/common/waifu2x.h b/common/waifu2x.h index 33e0313..d1b4459 100644 --- a/common/waifu2x.h +++ b/common/waifu2x.h @@ -110,6 +110,7 @@ private: eWaifu2xError BeforeReconstructFloatMatProcess(const cv::Mat &in, cv::Mat &out); eWaifu2xError ReconstructFloatMat(const bool isJpeg, const waifu2xCancelFunc cancel_func, const cv::Mat &in, cv::Mat &out); + eWaifu2xError Reconstruct(const bool isJpeg, const waifu2xCancelFunc cancel_func, const cv::Mat &in, cv::Mat &out); eWaifu2xError AfterReconstructFloatMatProcess(const cv::Mat &floatim, const cv::Mat &in, cv::Mat &out); public: