diff --git a/common/waifu2x.cpp b/common/waifu2x.cpp index 1943f3b..fa76d4f 100644 --- a/common/waifu2x.cpp +++ b/common/waifu2x.cpp @@ -71,13 +71,35 @@ bool can_use_cuDNN() // 画像を読み込んで値を0.0f〜1.0fの範囲に変換 eWaifu2xError LoadImage(cv::Mat &float_image, const std::string &input_file) { - cv::Mat original_image = cv::imread(input_file, cv::IMREAD_COLOR); + cv::Mat original_image = cv::imread(input_file, cv::IMREAD_UNCHANGED); if (original_image.empty()) return eWaifu2xError_FailedOpenInputFile; - original_image.convertTo(float_image, CV_32F, 1.0 / 255.0); + cv::Mat convert; + original_image.convertTo(convert, CV_32F, 1.0 / 255.0); original_image.release(); + if (convert.channels() == 1) + cv::cvtColor(convert, convert, cv::COLOR_GRAY2BGR); + else if (convert.channels() == 4) + { + // アルファチャンネル付きだったら背景を1(白)として画像合成する + + std::vector planes; + cv::split(convert, planes); + + cv::Mat w2 = planes[3]; + cv::Mat w1 = 1.0 - planes[3]; + + planes[0] = planes[0].mul(w2) + w1; + planes[1] = planes[1].mul(w2) + w1; + planes[2] = planes[2].mul(w2) + w1; + + cv::merge(planes, convert); + } + + float_image = convert; + return eWaifu2xError_OK; } @@ -546,6 +568,17 @@ eWaifu2xError waifu2x(int argc, char** argv, const std::vector color_planes; CreateZoomColorImage(float_image, image_size, color_planes); + + cv::Mat alpha; + if (float_image.channels() == 4) + { + std::vector planes; + cv::split(float_image, planes); + alpha = planes[3]; + + cv::resize(alpha, alpha, image_size, 0.0, 0.0, cv::INTER_CUBIC); + } + float_image.release(); color_planes[0] = im; @@ -559,6 +592,24 @@ eWaifu2xError waifu2x(int argc, char** argv, const std::vector planes; + cv::split(process_image, planes); + process_image.release(); + + planes.push_back(alpha); + + cv::Mat w2 = planes[3]; + + planes[0] = (planes[0] - 1.0).mul(1.0 / w2) + 1.0; + planes[1] = (planes[1] - 1.0).mul(1.0 / w2) + 1.0; + planes[2] = (planes[2] - 1.0).mul(1.0 / w2) + 1.0; + + cv::merge(planes, process_image); + } + const cv::Size_ ns(image_size.width * shrinkRatio, image_size.height * shrinkRatio); if (image_size.width != ns.width || image_size.height != ns.height) cv::resize(process_image, process_image, ns, 0.0, 0.0, cv::INTER_LINEAR);