mirror of
https://github.com/lltcggie/waifu2x-caffe.git
synced 2025-06-26 05:32:47 +00:00
Test-Time Augmentation Mode実装
This commit is contained in:
parent
9d32a969bf
commit
fb6f80970a
@ -815,7 +815,7 @@ Waifu2x::eWaifu2xError Waifu2x::ReconstructImage(boost::shared_ptr<caffe::Net<fl
|
||||
}
|
||||
|
||||
Waifu2x::eWaifu2xError Waifu2x::init(int argc, char** argv, const std::string &Mode, const int NoiseLevel, const double ScaleRatio, const std::string &ModelDir, const std::string &Process,
|
||||
const int CropSize, const int BatchSize)
|
||||
const bool UseTTA, const int CropSize, const int BatchSize)
|
||||
{
|
||||
Waifu2x::eWaifu2xError ret;
|
||||
|
||||
@ -832,6 +832,7 @@ Waifu2x::eWaifu2xError Waifu2x::init(int argc, char** argv, const std::string &M
|
||||
scale_ratio = ScaleRatio;
|
||||
model_dir = ModelDir;
|
||||
process = Process;
|
||||
use_tta = UseTTA;
|
||||
|
||||
crop_size = CropSize;
|
||||
batch_size = BatchSize;
|
||||
@ -1028,29 +1029,19 @@ Waifu2x::eWaifu2xError Waifu2x::WriteMat(const cv::Mat &im, const std::string &o
|
||||
return eWaifu2xError_FailedOpenOutputFile;
|
||||
}
|
||||
|
||||
Waifu2x::eWaifu2xError Waifu2x::waifu2x(const std::string &input_file, const std::string &output_file,
|
||||
const waifu2xCancelFunc cancel_func)
|
||||
Waifu2x::eWaifu2xError Waifu2x::BeforeReconstructFloatMatProcess(const cv::Mat &in, cv::Mat &out)
|
||||
{
|
||||
Waifu2x::eWaifu2xError ret;
|
||||
|
||||
if (!is_inited)
|
||||
return eWaifu2xError_NotInitialized;
|
||||
|
||||
cv::Mat float_image;
|
||||
ret = LoadMat(float_image, input_file);
|
||||
if (ret != eWaifu2xError_OK)
|
||||
return ret;
|
||||
|
||||
cv::Mat im;
|
||||
if (input_plane == 1)
|
||||
CreateBrightnessImage(float_image, im);
|
||||
CreateBrightnessImage(in, im);
|
||||
else
|
||||
{
|
||||
|
||||
std::vector<cv::Mat> planes;
|
||||
cv::split(float_image, planes);
|
||||
cv::split(in, planes);
|
||||
|
||||
if (float_image.channels() == 4)
|
||||
if (in.channels() == 4)
|
||||
planes.resize(3);
|
||||
|
||||
// BGRからRGBにする
|
||||
@ -1058,13 +1049,19 @@ Waifu2x::eWaifu2xError Waifu2x::waifu2x(const std::string &input_file, const std
|
||||
|
||||
cv::merge(planes, im);
|
||||
}
|
||||
|
||||
out = im;
|
||||
|
||||
return eWaifu2xError_OK;
|
||||
}
|
||||
|
||||
Waifu2x::eWaifu2xError Waifu2x::ReconstructFloatMat(const bool isJpeg, const waifu2xCancelFunc cancel_func, const cv::Mat &in, cv::Mat &out)
|
||||
{
|
||||
Waifu2x::eWaifu2xError ret;
|
||||
|
||||
cv::Mat im(in);
|
||||
cv::Size_<int> image_size = im.size();
|
||||
|
||||
const boost::filesystem::path ip(input_file);
|
||||
const boost::filesystem::path ipext(ip.extension());
|
||||
|
||||
const bool isJpeg = boost::iequals(ipext.string(), ".jpg") || boost::iequals(ipext.string(), ".jpeg");
|
||||
|
||||
const bool isReconstructNoise = mode == "noise" || mode == "noise_scale" || (mode == "auto_scale" && isJpeg);
|
||||
const bool isReconstructScale = mode == "scale" || mode == "noise_scale";
|
||||
|
||||
@ -1084,7 +1081,6 @@ Waifu2x::eWaifu2xError Waifu2x::waifu2x(const std::string &input_file, const std
|
||||
return eWaifu2xError_Cancel;
|
||||
|
||||
const int scale2 = ceil(log2(scale_ratio));
|
||||
const double shrinkRatio = scale_ratio / std::pow(2.0, (double)scale2);
|
||||
|
||||
if (isReconstructScale)
|
||||
{
|
||||
@ -1105,18 +1101,24 @@ Waifu2x::eWaifu2xError Waifu2x::waifu2x(const std::string &input_file, const std
|
||||
if (cancel_func && cancel_func())
|
||||
return eWaifu2xError_Cancel;
|
||||
|
||||
out = im;
|
||||
|
||||
return eWaifu2xError_OK;
|
||||
}
|
||||
|
||||
Waifu2x::eWaifu2xError Waifu2x::AfterReconstructFloatMatProcess(const cv::Mat &floatim, const cv::Mat &in, cv::Mat &out)
|
||||
{
|
||||
cv::Size_<int> image_size = in.size();
|
||||
|
||||
cv::Mat process_image;
|
||||
if (input_plane == 1)
|
||||
{
|
||||
// 再構築した輝度画像とCreateZoomColorImage()で作成した色情報をマージして通常の画像に変換し、書き込む
|
||||
|
||||
std::vector<cv::Mat> color_planes;
|
||||
CreateZoomColorImage(float_image, image_size, color_planes);
|
||||
CreateZoomColorImage(floatim, image_size, color_planes);
|
||||
|
||||
float_image.release();
|
||||
|
||||
color_planes[0] = im;
|
||||
im.release();
|
||||
color_planes[0] = in;
|
||||
|
||||
cv::Mat converted_image;
|
||||
cv::merge(color_planes, converted_image);
|
||||
@ -1128,7 +1130,7 @@ Waifu2x::eWaifu2xError Waifu2x::waifu2x(const std::string &input_file, const std
|
||||
else
|
||||
{
|
||||
std::vector<cv::Mat> planes;
|
||||
cv::split(im, planes);
|
||||
cv::split(in, planes);
|
||||
|
||||
// RGBからBGRに直す
|
||||
std::swap(planes[0], planes[2]);
|
||||
@ -1137,10 +1139,10 @@ Waifu2x::eWaifu2xError Waifu2x::waifu2x(const std::string &input_file, const std
|
||||
}
|
||||
|
||||
cv::Mat alpha;
|
||||
if (float_image.channels() == 4)
|
||||
if (floatim.channels() == 4)
|
||||
{
|
||||
std::vector<cv::Mat> planes;
|
||||
cv::split(float_image, planes);
|
||||
cv::split(floatim, planes);
|
||||
alpha = planes[3];
|
||||
|
||||
cv::resize(alpha, alpha, image_size, 0.0, 0.0, cv::INTER_CUBIC);
|
||||
@ -1164,10 +1166,117 @@ Waifu2x::eWaifu2xError Waifu2x::waifu2x(const std::string &input_file, const std
|
||||
cv::merge(planes, process_image);
|
||||
}
|
||||
|
||||
const int scale2 = ceil(log2(scale_ratio));
|
||||
const double shrinkRatio = scale_ratio / std::pow(2.0, (double)scale2);
|
||||
|
||||
const cv::Size_<int> ns(image_size.width * shrinkRatio, image_size.height * shrinkRatio);
|
||||
if (image_size.width != ns.width || image_size.height != ns.height)
|
||||
cv::resize(process_image, process_image, ns, 0.0, 0.0, cv::INTER_LINEAR);
|
||||
|
||||
out = process_image;
|
||||
|
||||
return eWaifu2xError_OK;
|
||||
}
|
||||
|
||||
Waifu2x::eWaifu2xError Waifu2x::waifu2x(const std::string &input_file, const std::string &output_file,
|
||||
const waifu2xCancelFunc cancel_func)
|
||||
{
|
||||
Waifu2x::eWaifu2xError ret;
|
||||
|
||||
if (!is_inited)
|
||||
return eWaifu2xError_NotInitialized;
|
||||
|
||||
const boost::filesystem::path ip(input_file);
|
||||
const boost::filesystem::path ipext(ip.extension());
|
||||
|
||||
const bool isJpeg = boost::iequals(ipext.string(), ".jpg") || boost::iequals(ipext.string(), ".jpeg");
|
||||
|
||||
cv::Mat float_image;
|
||||
ret = LoadMat(float_image, input_file);
|
||||
if (ret != eWaifu2xError_OK)
|
||||
return ret;
|
||||
|
||||
cv::Mat brfm;
|
||||
ret = BeforeReconstructFloatMatProcess(float_image, brfm);
|
||||
if (ret != eWaifu2xError_OK)
|
||||
return ret;
|
||||
|
||||
cv::Mat reconstruct_image;
|
||||
if (!use_tta) // 普通に処理
|
||||
{
|
||||
ret = ReconstructFloatMat(isJpeg, cancel_func, brfm, reconstruct_image);
|
||||
if (ret != eWaifu2xError_OK)
|
||||
return ret;
|
||||
}
|
||||
else // Test-Time Augmentation Mode
|
||||
{
|
||||
const auto RotateClockwise90 = [](cv::Mat &mat)
|
||||
{
|
||||
cv::transpose(mat, mat);
|
||||
cv::flip(mat, mat, 1);
|
||||
};
|
||||
|
||||
const auto RotateClockwise90N = [RotateClockwise90](cv::Mat &mat, const int rotateNum)
|
||||
{
|
||||
for (int i = 0; i < rotateNum; i++)
|
||||
RotateClockwise90(mat);
|
||||
};
|
||||
|
||||
const auto RotateCounterclockwise90 = [](cv::Mat &mat)
|
||||
{
|
||||
cv::transpose(mat, mat);
|
||||
cv::flip(mat, mat, 0);
|
||||
};
|
||||
|
||||
const auto RotateCounterclockwise90N = [RotateCounterclockwise90](cv::Mat &mat, const int rotateNum)
|
||||
{
|
||||
for (int i = 0; i < rotateNum; i++)
|
||||
RotateCounterclockwise90(mat);
|
||||
};
|
||||
|
||||
cv::Mat ri[8];
|
||||
for (int i = 0; i < 8; i++)
|
||||
{
|
||||
cv::Mat in(brfm.clone());
|
||||
|
||||
cv::imwrite("0.png", in * 255.0);
|
||||
|
||||
const int rotateNum = i % 4;
|
||||
RotateClockwise90N(in, rotateNum);
|
||||
cv::imwrite("1.png", in * 255.0);
|
||||
|
||||
if(i >= 4)
|
||||
cv::flip(in, in, 1); // 垂直軸反転
|
||||
cv::imwrite("2.png", in * 255.0);
|
||||
|
||||
ret = ReconstructFloatMat(isJpeg, cancel_func, in, in);
|
||||
if (ret != eWaifu2xError_OK)
|
||||
return ret;
|
||||
cv::imwrite("3.png", in * 255.0);
|
||||
if (i >= 4)
|
||||
cv::flip(in, in, 1); // 垂直軸反転
|
||||
cv::imwrite("4.png", in * 255.0);
|
||||
RotateCounterclockwise90N(in, rotateNum);
|
||||
cv::imwrite("5.png", in * 255.0);
|
||||
ri[i] = in;
|
||||
}
|
||||
|
||||
reconstruct_image = ri[0];
|
||||
for (int i = 1; i < 8; i++)
|
||||
reconstruct_image += ri[i];
|
||||
|
||||
reconstruct_image /= 8.0;
|
||||
}
|
||||
|
||||
brfm.release();
|
||||
|
||||
cv::Mat process_image;
|
||||
ret = AfterReconstructFloatMatProcess(float_image, reconstruct_image, process_image);
|
||||
if (ret != eWaifu2xError_OK)
|
||||
return ret;
|
||||
|
||||
float_image.release();
|
||||
|
||||
cv::Mat write_iamge;
|
||||
process_image.convertTo(write_iamge, CV_8U, 255.0);
|
||||
process_image.release();
|
||||
|
@ -90,6 +90,8 @@ private:
|
||||
float *dummy_data;
|
||||
float *output_block;
|
||||
|
||||
bool use_tta;
|
||||
|
||||
private:
|
||||
static eWaifu2xError LoadMat(cv::Mat &float_image, const std::string &input_file);
|
||||
static eWaifu2xError LoadMatBySTBI(cv::Mat &float_image, const std::string &input_file);
|
||||
@ -103,6 +105,10 @@ private:
|
||||
eWaifu2xError ReconstructImage(boost::shared_ptr<caffe::Net<float>> net, cv::Mat &im);
|
||||
eWaifu2xError WriteMat(const cv::Mat &im, const std::string &output_file);
|
||||
|
||||
eWaifu2xError BeforeReconstructFloatMatProcess(const cv::Mat &in, cv::Mat &out);
|
||||
eWaifu2xError ReconstructFloatMat(const bool isJpeg, const waifu2xCancelFunc cancel_func, const cv::Mat &in, cv::Mat &out);
|
||||
eWaifu2xError AfterReconstructFloatMatProcess(const cv::Mat &floatim, const cv::Mat &in, cv::Mat &out);
|
||||
|
||||
public:
|
||||
Waifu2x();
|
||||
~Waifu2x();
|
||||
@ -113,7 +119,7 @@ public:
|
||||
// mode: noise or scale or noise_scale or auto_scale
|
||||
// process: cpu or gpu or cudnn
|
||||
eWaifu2xError init(int argc, char** argv, const std::string &mode, const int noise_level, const double scale_ratio, const std::string &model_dir, const std::string &process,
|
||||
const int crop_size = 128, const int batch_size = 1);
|
||||
const bool use_tta = false, const int crop_size = 128, const int batch_size = 1);
|
||||
|
||||
void destroy();
|
||||
|
||||
|
@ -107,6 +107,13 @@ int main(int argc, char** argv)
|
||||
"input batch size", false,
|
||||
1, "int", cmd);
|
||||
|
||||
std::vector<int> cmdTTAConstraintV;
|
||||
cmdTTAConstraintV.push_back(0);
|
||||
cmdTTAConstraintV.push_back(1);
|
||||
TCLAP::ValuesConstraint<int> cmdTTAConstraint(cmdTTAConstraintV);
|
||||
TCLAP::ValueArg<int> cmdTTALevel("t", "tta", "8x slower and slightly high quality",
|
||||
false, 0, &cmdTTAConstraint, cmd);
|
||||
|
||||
// definition of command line argument : end
|
||||
|
||||
TCLAP::Arg::enableIgnoreMismatched();
|
||||
@ -237,7 +244,7 @@ int main(int argc, char** argv)
|
||||
|
||||
Waifu2x::eWaifu2xError ret;
|
||||
Waifu2x w;
|
||||
ret = w.init(argc, argv, cmdMode.getValue(), cmdNRLevel.getValue(), cmdScaleRatio.getValue(), cmdModelPath.getValue(), cmdProcess.getValue(),
|
||||
ret = w.init(argc, argv, cmdMode.getValue(), cmdNRLevel.getValue(), cmdScaleRatio.getValue(), cmdModelPath.getValue(), cmdProcess.getValue(), cmdTTALevel.getValue() == 1,
|
||||
cmdCropSizeFile.getValue(), cmdBatchSizeFile.getValue());
|
||||
switch (ret)
|
||||
{
|
||||
|
Loading…
x
Reference in New Issue
Block a user