diff --git a/common/waifu2x.cpp b/common/waifu2x.cpp index 2cb1a55..ac7faff 100644 --- a/common/waifu2x.cpp +++ b/common/waifu2x.cpp @@ -465,21 +465,34 @@ Waifu2x::eWaifu2xcuDNNError Waifu2x::can_use_cuDNN() typedef cudnnStatus_t(CUDNNWINAPI* cudnnCreateType)(cudnnHandle_t *); typedef cudnnStatus_t(CUDNNWINAPI* cudnnDestroyType)(cudnnHandle_t); typedef size_t(CUDNNWINAPI* cudnnGetVersionType)(); + typedef size_t(CUDNNWINAPI* cudnnGetCudartVersionType)(); cudnnCreateType cudnnCreateFunc = (cudnnCreateType)GetProcAddress(hModule, "cudnnCreate"); cudnnDestroyType cudnnDestroyFunc = (cudnnDestroyType)GetProcAddress(hModule, "cudnnDestroy"); cudnnGetVersionType cudnnGetVersionFunc = (cudnnGetVersionType)GetProcAddress(hModule, "cudnnGetVersion"); - if (cudnnCreateFunc != nullptr && cudnnDestroyFunc != nullptr && cudnnGetVersionFunc != nullptr) + cudnnGetCudartVersionType cudnnGetCudartVersionFunc = (cudnnGetCudartVersionType)GetProcAddress(hModule, "cudnnGetCudartVersion"); + if (cudnnCreateFunc != nullptr && cudnnDestroyFunc != nullptr && cudnnGetVersionFunc != nullptr && cudnnGetCudartVersionFunc != nullptr) { if (cudnnGetVersionFunc() >= CUDNN_REQUIRE_VERION) { - cudnnHandle_t h; - if (cudnnCreateFunc(&h) == CUDNN_STATUS_SUCCESS) + int runtimeVersion; + if (cudaRuntimeGetVersion(&runtimeVersion) == cudaSuccess) { - if (cudnnDestroyFunc(h) == CUDNN_STATUS_SUCCESS) - cuDNNFlag = eWaifu2xcuDNNError_OK; + if (cudnnGetCudartVersionFunc() >= runtimeVersion) + { + cudnnHandle_t h; + if (cudnnCreateFunc(&h) == CUDNN_STATUS_SUCCESS) + { + if (cudnnDestroyFunc(h) == CUDNN_STATUS_SUCCESS) + cuDNNFlag = eWaifu2xcuDNNError_OK; + else + cuDNNFlag = eWaifu2xcuDNNError_CannotCreate; + } + else + cuDNNFlag = eWaifu2xcuDNNError_CannotCreate; + } else - cuDNNFlag = eWaifu2xcuDNNError_CannotCreate; + cuDNNFlag = eWaifu2xcuDNNError_OldCudaVersion; } else cuDNNFlag = eWaifu2xcuDNNError_CannotCreate; diff --git a/common/waifu2x.h b/common/waifu2x.h index e41eacd..c70a554 100644 --- a/common/waifu2x.h +++ b/common/waifu2x.h @@ -122,6 +122,7 @@ public: eWaifu2xcuDNNError_NotFind, eWaifu2xcuDNNError_OldVersion, eWaifu2xcuDNNError_CannotCreate, + eWaifu2xcuDNNError_OldCudaVersion, }; typedef std::function waifu2xCancelFunc; diff --git a/waifu2x-caffe-gui/MainDialog.cpp b/waifu2x-caffe-gui/MainDialog.cpp index f9dd1bc..d22194e 100644 --- a/waifu2x-caffe-gui/MainDialog.cpp +++ b/waifu2x-caffe-gui/MainDialog.cpp @@ -2737,6 +2737,9 @@ void DialogEvent::CheckCUDNN(HWND hWnd, WPARAM wParam, LPARAM lParam, LPVOID lpD case Waifu2x::eWaifu2xcuDNNError_CannotCreate: MessageBox(dh, langStringList.GetString(L"MessagecuDNNCannotCreateError").c_str(), langStringList.GetString(L"MessageTitleResult").c_str(), MB_OK | MB_ICONERROR); break; + case Waifu2x::eWaifu2xcuDNNError_OldCudaVersion: + MessageBox(dh, langStringList.GetString(L"MessageCudaOldVersionError").c_str(), langStringList.GetString(L"MessageTitleResult").c_str(), MB_OK | MB_ICONERROR); + break; default: MessageBox(dh, langStringList.GetString(L"MessagecuDNNDefautlError").c_str(), langStringList.GetString(L"MessageTitleResult").c_str(), MB_OK | MB_ICONERROR); }