From aba4f717fb2184d5417cad9832689a1b4e4c7253 Mon Sep 17 00:00:00 2001 From: lltcggie Date: Wed, 6 Jul 2016 22:33:09 +0900 Subject: [PATCH] =?UTF-8?q?cuDNN=E3=81=AE=E3=82=A2=E3=83=AB=E3=82=B4?= =?UTF-8?q?=E3=83=AA=E3=82=BA=E3=83=A0=E3=83=87=E3=83=BC=E3=82=BF=E8=AA=AD?= =?UTF-8?q?=E3=81=BF=E6=9B=B8=E3=81=8D=E3=81=AE=E4=BE=8B=E5=A4=96=E5=AF=BE?= =?UTF-8?q?=E7=AD=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- common/waifu2x.cpp | 48 +++++++++++++++++++++++++++++----------------- 1 file changed, 30 insertions(+), 18 deletions(-) diff --git a/common/waifu2x.cpp b/common/waifu2x.cpp index 55032e2..9725f63 100644 --- a/common/waifu2x.cpp +++ b/common/waifu2x.cpp @@ -318,12 +318,19 @@ private: fclose(fp); - CcuDNNAlgorithmElement elm; - msgpack::unpack(sbuf.data(), sbuf.size()).get().convert(elm); - sbuf.clear(); + try + { + CcuDNNAlgorithmElement elm; + msgpack::unpack(sbuf.data(), sbuf.size()).get().convert(elm); + sbuf.clear(); - const uint64_t key = InfoToKey(kernel_w, kernel_h, pad_w, pad_h, stride_w, stride_h, batch_size); - mAlgoEmlMap[key] = std::move(elm); + const uint64_t key = InfoToKey(kernel_w, kernel_h, pad_w, pad_h, stride_w, stride_h, batch_size); + mAlgoEmlMap[key] = std::move(elm); + } + catch (...) + { + boost::filesystem::remove(SavePath); + } return true; } @@ -373,22 +380,27 @@ public: auto &eml = p.second; if (eml.IsModefy()) { - msgpack::sbuffer sbuf; - msgpack::pack(sbuf, eml); - - uint8_t kernel_w, kernel_h, pad_w, pad_h, stride_w, stride_h; - uint16_t batch_size; - eml.GetLayerData(kernel_w, kernel_h, pad_w, pad_h, stride_w, stride_h, batch_size); - - const std::string SavePath = GetDataPath(kernel_w, kernel_h, pad_w, pad_h, stride_w, stride_h, batch_size); - FILE *fp = fopen(SavePath.c_str(), "wb"); - if (fp) + try { - fwrite(sbuf.data(), 1, sbuf.size(), fp); - fclose(fp); + msgpack::sbuffer sbuf; + msgpack::pack(sbuf, eml); - eml.Saved(); + uint8_t kernel_w, kernel_h, pad_w, pad_h, stride_w, stride_h; + uint16_t batch_size; + eml.GetLayerData(kernel_w, kernel_h, pad_w, pad_h, stride_w, stride_h, batch_size); + + const std::string SavePath = GetDataPath(kernel_w, kernel_h, pad_w, pad_h, stride_w, stride_h, batch_size); + FILE *fp = fopen(SavePath.c_str(), "wb"); + if (fp) + { + fwrite(sbuf.data(), 1, sbuf.size(), fp); + fclose(fp); + + eml.Saved(); + } } + catch(...) + {} } } }