Test-Time Augmentation Mode実装

This commit is contained in:
lltcggie 2015-11-19 01:50:11 +09:00
parent 9d32a969bf
commit fb6f80970a
3 changed files with 153 additions and 31 deletions

View File

@ -815,7 +815,7 @@ Waifu2x::eWaifu2xError Waifu2x::ReconstructImage(boost::shared_ptr<caffe::Net<fl
}
Waifu2x::eWaifu2xError Waifu2x::init(int argc, char** argv, const std::string &Mode, const int NoiseLevel, const double ScaleRatio, const std::string &ModelDir, const std::string &Process,
const int CropSize, const int BatchSize)
const bool UseTTA, const int CropSize, const int BatchSize)
{
Waifu2x::eWaifu2xError ret;
@ -832,6 +832,7 @@ Waifu2x::eWaifu2xError Waifu2x::init(int argc, char** argv, const std::string &M
scale_ratio = ScaleRatio;
model_dir = ModelDir;
process = Process;
use_tta = UseTTA;
crop_size = CropSize;
batch_size = BatchSize;
@ -1028,29 +1029,19 @@ Waifu2x::eWaifu2xError Waifu2x::WriteMat(const cv::Mat &im, const std::string &o
return eWaifu2xError_FailedOpenOutputFile;
}
Waifu2x::eWaifu2xError Waifu2x::waifu2x(const std::string &input_file, const std::string &output_file,
const waifu2xCancelFunc cancel_func)
Waifu2x::eWaifu2xError Waifu2x::BeforeReconstructFloatMatProcess(const cv::Mat &in, cv::Mat &out)
{
Waifu2x::eWaifu2xError ret;
if (!is_inited)
return eWaifu2xError_NotInitialized;
cv::Mat float_image;
ret = LoadMat(float_image, input_file);
if (ret != eWaifu2xError_OK)
return ret;
cv::Mat im;
if (input_plane == 1)
CreateBrightnessImage(float_image, im);
CreateBrightnessImage(in, im);
else
{
std::vector<cv::Mat> planes;
cv::split(float_image, planes);
cv::split(in, planes);
if (float_image.channels() == 4)
if (in.channels() == 4)
planes.resize(3);
// BGRからRGBにする
@ -1058,13 +1049,19 @@ Waifu2x::eWaifu2xError Waifu2x::waifu2x(const std::string &input_file, const std
cv::merge(planes, im);
}
out = im;
return eWaifu2xError_OK;
}
Waifu2x::eWaifu2xError Waifu2x::ReconstructFloatMat(const bool isJpeg, const waifu2xCancelFunc cancel_func, const cv::Mat &in, cv::Mat &out)
{
Waifu2x::eWaifu2xError ret;
cv::Mat im(in);
cv::Size_<int> image_size = im.size();
const boost::filesystem::path ip(input_file);
const boost::filesystem::path ipext(ip.extension());
const bool isJpeg = boost::iequals(ipext.string(), ".jpg") || boost::iequals(ipext.string(), ".jpeg");
const bool isReconstructNoise = mode == "noise" || mode == "noise_scale" || (mode == "auto_scale" && isJpeg);
const bool isReconstructScale = mode == "scale" || mode == "noise_scale";
@ -1084,7 +1081,6 @@ Waifu2x::eWaifu2xError Waifu2x::waifu2x(const std::string &input_file, const std
return eWaifu2xError_Cancel;
const int scale2 = ceil(log2(scale_ratio));
const double shrinkRatio = scale_ratio / std::pow(2.0, (double)scale2);
if (isReconstructScale)
{
@ -1105,18 +1101,24 @@ Waifu2x::eWaifu2xError Waifu2x::waifu2x(const std::string &input_file, const std
if (cancel_func && cancel_func())
return eWaifu2xError_Cancel;
out = im;
return eWaifu2xError_OK;
}
Waifu2x::eWaifu2xError Waifu2x::AfterReconstructFloatMatProcess(const cv::Mat &floatim, const cv::Mat &in, cv::Mat &out)
{
cv::Size_<int> image_size = in.size();
cv::Mat process_image;
if (input_plane == 1)
{
// 再構築した輝度画像とCreateZoomColorImage()で作成した色情報をマージして通常の画像に変換し、書き込む
std::vector<cv::Mat> color_planes;
CreateZoomColorImage(float_image, image_size, color_planes);
CreateZoomColorImage(floatim, image_size, color_planes);
float_image.release();
color_planes[0] = im;
im.release();
color_planes[0] = in;
cv::Mat converted_image;
cv::merge(color_planes, converted_image);
@ -1128,7 +1130,7 @@ Waifu2x::eWaifu2xError Waifu2x::waifu2x(const std::string &input_file, const std
else
{
std::vector<cv::Mat> planes;
cv::split(im, planes);
cv::split(in, planes);
// RGBからBGRに直す
std::swap(planes[0], planes[2]);
@ -1137,10 +1139,10 @@ Waifu2x::eWaifu2xError Waifu2x::waifu2x(const std::string &input_file, const std
}
cv::Mat alpha;
if (float_image.channels() == 4)
if (floatim.channels() == 4)
{
std::vector<cv::Mat> planes;
cv::split(float_image, planes);
cv::split(floatim, planes);
alpha = planes[3];
cv::resize(alpha, alpha, image_size, 0.0, 0.0, cv::INTER_CUBIC);
@ -1164,10 +1166,117 @@ Waifu2x::eWaifu2xError Waifu2x::waifu2x(const std::string &input_file, const std
cv::merge(planes, process_image);
}
const int scale2 = ceil(log2(scale_ratio));
const double shrinkRatio = scale_ratio / std::pow(2.0, (double)scale2);
const cv::Size_<int> ns(image_size.width * shrinkRatio, image_size.height * shrinkRatio);
if (image_size.width != ns.width || image_size.height != ns.height)
cv::resize(process_image, process_image, ns, 0.0, 0.0, cv::INTER_LINEAR);
out = process_image;
return eWaifu2xError_OK;
}
Waifu2x::eWaifu2xError Waifu2x::waifu2x(const std::string &input_file, const std::string &output_file,
const waifu2xCancelFunc cancel_func)
{
Waifu2x::eWaifu2xError ret;
if (!is_inited)
return eWaifu2xError_NotInitialized;
const boost::filesystem::path ip(input_file);
const boost::filesystem::path ipext(ip.extension());
const bool isJpeg = boost::iequals(ipext.string(), ".jpg") || boost::iequals(ipext.string(), ".jpeg");
cv::Mat float_image;
ret = LoadMat(float_image, input_file);
if (ret != eWaifu2xError_OK)
return ret;
cv::Mat brfm;
ret = BeforeReconstructFloatMatProcess(float_image, brfm);
if (ret != eWaifu2xError_OK)
return ret;
cv::Mat reconstruct_image;
if (!use_tta) // 普通に処理
{
ret = ReconstructFloatMat(isJpeg, cancel_func, brfm, reconstruct_image);
if (ret != eWaifu2xError_OK)
return ret;
}
else // Test-Time Augmentation Mode
{
const auto RotateClockwise90 = [](cv::Mat &mat)
{
cv::transpose(mat, mat);
cv::flip(mat, mat, 1);
};
const auto RotateClockwise90N = [RotateClockwise90](cv::Mat &mat, const int rotateNum)
{
for (int i = 0; i < rotateNum; i++)
RotateClockwise90(mat);
};
const auto RotateCounterclockwise90 = [](cv::Mat &mat)
{
cv::transpose(mat, mat);
cv::flip(mat, mat, 0);
};
const auto RotateCounterclockwise90N = [RotateCounterclockwise90](cv::Mat &mat, const int rotateNum)
{
for (int i = 0; i < rotateNum; i++)
RotateCounterclockwise90(mat);
};
// Test-Time Augmentation: reconstruct the image in 8 orientations
// (4 clockwise rotations, each with and without a mirror), undo the
// transform on every result and average the 8 outputs.
//
// The leftover debug dumps (cv::imwrite of "0.png" .. "5.png" into the
// current working directory on every pass) have been removed. The results are
// summed as they are produced, in the same order as before, so that the 8
// full-size outputs no longer have to be kept alive at the same time.
const int ttaVariantCount = 8;
for (int i = 0; i < ttaVariantCount; i++)
{
    // Private copy: the transforms below operate in place.
    cv::Mat in(brfm.clone());

    const int rotateNum = i % 4;
    const bool mirrored = i >= 4;

    RotateClockwise90N(in, rotateNum);
    if (mirrored)
        cv::flip(in, in, 1); // mirror around the vertical axis

    ret = ReconstructFloatMat(isJpeg, cancel_func, in, in);
    if (ret != eWaifu2xError_OK)
        return ret;

    // Undo the augmentation in reverse order so the result lines up with the input orientation.
    if (mirrored)
        cv::flip(in, in, 1); // mirror around the vertical axis
    RotateCounterclockwise90N(in, rotateNum);

    if (i == 0)
        reconstruct_image = in;
    else
        reconstruct_image += in;
}

reconstruct_image /= static_cast<double>(ttaVariantCount);
}
brfm.release();
cv::Mat process_image;
ret = AfterReconstructFloatMatProcess(float_image, reconstruct_image, process_image);
if (ret != eWaifu2xError_OK)
return ret;
float_image.release();
cv::Mat write_iamge;
process_image.convertTo(write_iamge, CV_8U, 255.0);
process_image.release();

View File

@ -90,6 +90,8 @@ private:
float *dummy_data;
float *output_block;
bool use_tta;
private:
static eWaifu2xError LoadMat(cv::Mat &float_image, const std::string &input_file);
static eWaifu2xError LoadMatBySTBI(cv::Mat &float_image, const std::string &input_file);
@ -103,6 +105,10 @@ private:
eWaifu2xError ReconstructImage(boost::shared_ptr<caffe::Net<float>> net, cv::Mat &im);
eWaifu2xError WriteMat(const cv::Mat &im, const std::string &output_file);
// Converts a loaded float image (`in`) into the image handed to the reconstruction step (`out`).
// When input_plane == 1 a brightness image is built with CreateBrightnessImage(); otherwise the
// channels are split, only the first three are kept when there are four, and they are merged
// back (the implementation comment says they are reordered from BGR to RGB).
eWaifu2xError BeforeReconstructFloatMatProcess(const cv::Mat &in, cv::Mat &out);
// Runs the reconstruction selected by `mode` on `in` and stores the result in `out`.
// `isJpeg` matters only for mode "auto_scale", where it enables the noise-reduction pass.
// `cancel_func` may be empty; when it returns true the function returns eWaifu2xError_Cancel.
eWaifu2xError ReconstructFloatMat(const bool isJpeg, const waifu2xCancelFunc cancel_func, const cv::Mat &in, cv::Mat &out);
// Builds the final image (`out`) from the reconstructed image (`in`).
// `floatim` is the originally loaded image: it supplies the color planes when input_plane == 1
// and the alpha plane when it has four channels. The result is resized by
// scale_ratio / 2^ceil(log2(scale_ratio)) when that changes the image size.
eWaifu2xError AfterReconstructFloatMatProcess(const cv::Mat &floatim, const cv::Mat &in, cv::Mat &out);
public:
Waifu2x();
~Waifu2x();
@ -113,7 +119,7 @@ public:
// mode: noise or scale or noise_scale or auto_scale
// process: cpu or gpu or cudnn
eWaifu2xError init(int argc, char** argv, const std::string &mode, const int noise_level, const double scale_ratio, const std::string &model_dir, const std::string &process,
const int crop_size = 128, const int batch_size = 1);
const bool use_tta = false, const int crop_size = 128, const int batch_size = 1);
void destroy();

View File

@ -107,6 +107,13 @@ int main(int argc, char** argv)
"input batch size", false,
1, "int", cmd);
// -t / --tta: optional integer argument restricted to the values 0 and 1 (default 0).
std::vector<int> cmdTTAConstraintV;
cmdTTAConstraintV.push_back(0);
cmdTTAConstraintV.push_back(1);
TCLAP::ValuesConstraint<int> cmdTTAConstraint(cmdTTAConstraintV);
TCLAP::ValueArg<int> cmdTTALevel("t", "tta", "8x slower and slightly high quality",
false, 0, &cmdTTAConstraint, cmd);
// definition of command line argument : end
TCLAP::Arg::enableIgnoreMismatched();
@ -237,7 +244,7 @@ int main(int argc, char** argv)
Waifu2x::eWaifu2xError ret;
Waifu2x w;
ret = w.init(argc, argv, cmdMode.getValue(), cmdNRLevel.getValue(), cmdScaleRatio.getValue(), cmdModelPath.getValue(), cmdProcess.getValue(),
ret = w.init(argc, argv, cmdMode.getValue(), cmdNRLevel.getValue(), cmdScaleRatio.getValue(), cmdModelPath.getValue(), cmdProcess.getValue(), cmdTTALevel.getValue() == 1,
cmdCropSizeFile.getValue(), cmdBatchSizeFile.getValue());
switch (ret)
{