cuDNNが使うランタイムよりCUDAランタイムのバージョンが古くないかのチェックを入れた

This commit is contained in:
lltcggie 2020-09-05 16:49:04 +09:00
parent c365977a1e
commit 3812b90f68
3 changed files with 23 additions and 6 deletions

View File

@ -465,13 +465,20 @@ Waifu2x::eWaifu2xcuDNNError Waifu2x::can_use_cuDNN()
typedef cudnnStatus_t(CUDNNWINAPI* cudnnCreateType)(cudnnHandle_t *);
typedef cudnnStatus_t(CUDNNWINAPI* cudnnDestroyType)(cudnnHandle_t);
typedef size_t(CUDNNWINAPI* cudnnGetVersionType)();
typedef size_t(CUDNNWINAPI* cudnnGetCudartVersionType)();
cudnnCreateType cudnnCreateFunc = (cudnnCreateType)GetProcAddress(hModule, "cudnnCreate");
cudnnDestroyType cudnnDestroyFunc = (cudnnDestroyType)GetProcAddress(hModule, "cudnnDestroy");
cudnnGetVersionType cudnnGetVersionFunc = (cudnnGetVersionType)GetProcAddress(hModule, "cudnnGetVersion");
if (cudnnCreateFunc != nullptr && cudnnDestroyFunc != nullptr && cudnnGetVersionFunc != nullptr)
cudnnGetCudartVersionType cudnnGetCudartVersionFunc = (cudnnGetCudartVersionType)GetProcAddress(hModule, "cudnnGetCudartVersion");
if (cudnnCreateFunc != nullptr && cudnnDestroyFunc != nullptr && cudnnGetVersionFunc != nullptr && cudnnGetCudartVersionFunc != nullptr)
{
if (cudnnGetVersionFunc() >= CUDNN_REQUIRE_VERION)
{
int runtimeVersion;
if (cudaRuntimeGetVersion(&runtimeVersion) == cudaSuccess)
{
if (cudnnGetCudartVersionFunc() >= runtimeVersion)
{
cudnnHandle_t h;
if (cudnnCreateFunc(&h) == CUDNN_STATUS_SUCCESS)
@ -484,6 +491,12 @@ Waifu2x::eWaifu2xcuDNNError Waifu2x::can_use_cuDNN()
else
cuDNNFlag = eWaifu2xcuDNNError_CannotCreate;
}
else
cuDNNFlag = eWaifu2xcuDNNError_OldCudaVersion;
}
else
cuDNNFlag = eWaifu2xcuDNNError_CannotCreate;
}
else
cuDNNFlag = eWaifu2xcuDNNError_OldVersion;
}

View File

@ -122,6 +122,7 @@ public:
eWaifu2xcuDNNError_NotFind,
eWaifu2xcuDNNError_OldVersion,
eWaifu2xcuDNNError_CannotCreate,
eWaifu2xcuDNNError_OldCudaVersion,
};
typedef std::function<bool()> waifu2xCancelFunc;

View File

@ -2737,6 +2737,9 @@ void DialogEvent::CheckCUDNN(HWND hWnd, WPARAM wParam, LPARAM lParam, LPVOID lpD
case Waifu2x::eWaifu2xcuDNNError_CannotCreate:
MessageBox(dh, langStringList.GetString(L"MessagecuDNNCannotCreateError").c_str(), langStringList.GetString(L"MessageTitleResult").c_str(), MB_OK | MB_ICONERROR);
break;
case Waifu2x::eWaifu2xcuDNNError_OldCudaVersion:
MessageBox(dh, langStringList.GetString(L"MessageCudaOldVersionError").c_str(), langStringList.GetString(L"MessageTitleResult").c_str(), MB_OK | MB_ICONERROR);
break;
default:
MessageBox(dh, langStringList.GetString(L"MessagecuDNNDefautlError").c_str(), langStringList.GetString(L"MessageTitleResult").c_str(), MB_OK | MB_ICONERROR);
}