cuDNNが使うランタイムよりCUDAランタイムのバージョンが古くないかのチェックを入れた

This commit is contained in:
lltcggie 2020-09-05 16:49:04 +09:00
parent c365977a1e
commit 3812b90f68
3 changed files with 23 additions and 6 deletions

View File

@ -465,21 +465,34 @@ Waifu2x::eWaifu2xcuDNNError Waifu2x::can_use_cuDNN()
typedef cudnnStatus_t(CUDNNWINAPI* cudnnCreateType)(cudnnHandle_t *); typedef cudnnStatus_t(CUDNNWINAPI* cudnnCreateType)(cudnnHandle_t *);
typedef cudnnStatus_t(CUDNNWINAPI* cudnnDestroyType)(cudnnHandle_t); typedef cudnnStatus_t(CUDNNWINAPI* cudnnDestroyType)(cudnnHandle_t);
typedef size_t(CUDNNWINAPI* cudnnGetVersionType)(); typedef size_t(CUDNNWINAPI* cudnnGetVersionType)();
typedef size_t(CUDNNWINAPI* cudnnGetCudartVersionType)();
cudnnCreateType cudnnCreateFunc = (cudnnCreateType)GetProcAddress(hModule, "cudnnCreate"); cudnnCreateType cudnnCreateFunc = (cudnnCreateType)GetProcAddress(hModule, "cudnnCreate");
cudnnDestroyType cudnnDestroyFunc = (cudnnDestroyType)GetProcAddress(hModule, "cudnnDestroy"); cudnnDestroyType cudnnDestroyFunc = (cudnnDestroyType)GetProcAddress(hModule, "cudnnDestroy");
cudnnGetVersionType cudnnGetVersionFunc = (cudnnGetVersionType)GetProcAddress(hModule, "cudnnGetVersion"); cudnnGetVersionType cudnnGetVersionFunc = (cudnnGetVersionType)GetProcAddress(hModule, "cudnnGetVersion");
if (cudnnCreateFunc != nullptr && cudnnDestroyFunc != nullptr && cudnnGetVersionFunc != nullptr) cudnnGetCudartVersionType cudnnGetCudartVersionFunc = (cudnnGetCudartVersionType)GetProcAddress(hModule, "cudnnGetCudartVersion");
if (cudnnCreateFunc != nullptr && cudnnDestroyFunc != nullptr && cudnnGetVersionFunc != nullptr && cudnnGetCudartVersionFunc != nullptr)
{ {
if (cudnnGetVersionFunc() >= CUDNN_REQUIRE_VERION) if (cudnnGetVersionFunc() >= CUDNN_REQUIRE_VERION)
{ {
cudnnHandle_t h; int runtimeVersion;
if (cudnnCreateFunc(&h) == CUDNN_STATUS_SUCCESS) if (cudaRuntimeGetVersion(&runtimeVersion) == cudaSuccess)
{ {
if (cudnnDestroyFunc(h) == CUDNN_STATUS_SUCCESS) if (cudnnGetCudartVersionFunc() >= runtimeVersion)
cuDNNFlag = eWaifu2xcuDNNError_OK; {
cudnnHandle_t h;
if (cudnnCreateFunc(&h) == CUDNN_STATUS_SUCCESS)
{
if (cudnnDestroyFunc(h) == CUDNN_STATUS_SUCCESS)
cuDNNFlag = eWaifu2xcuDNNError_OK;
else
cuDNNFlag = eWaifu2xcuDNNError_CannotCreate;
}
else
cuDNNFlag = eWaifu2xcuDNNError_CannotCreate;
}
else else
cuDNNFlag = eWaifu2xcuDNNError_CannotCreate; cuDNNFlag = eWaifu2xcuDNNError_OldCudaVersion;
} }
else else
cuDNNFlag = eWaifu2xcuDNNError_CannotCreate; cuDNNFlag = eWaifu2xcuDNNError_CannotCreate;

View File

@ -122,6 +122,7 @@ public:
eWaifu2xcuDNNError_NotFind, eWaifu2xcuDNNError_NotFind,
eWaifu2xcuDNNError_OldVersion, eWaifu2xcuDNNError_OldVersion,
eWaifu2xcuDNNError_CannotCreate, eWaifu2xcuDNNError_CannotCreate,
eWaifu2xcuDNNError_OldCudaVersion,
}; };
typedef std::function<bool()> waifu2xCancelFunc; typedef std::function<bool()> waifu2xCancelFunc;

View File

@ -2737,6 +2737,9 @@ void DialogEvent::CheckCUDNN(HWND hWnd, WPARAM wParam, LPARAM lParam, LPVOID lpD
case Waifu2x::eWaifu2xcuDNNError_CannotCreate: case Waifu2x::eWaifu2xcuDNNError_CannotCreate:
MessageBox(dh, langStringList.GetString(L"MessagecuDNNCannotCreateError").c_str(), langStringList.GetString(L"MessageTitleResult").c_str(), MB_OK | MB_ICONERROR); MessageBox(dh, langStringList.GetString(L"MessagecuDNNCannotCreateError").c_str(), langStringList.GetString(L"MessageTitleResult").c_str(), MB_OK | MB_ICONERROR);
break; break;
case Waifu2x::eWaifu2xcuDNNError_OldCudaVersion:
MessageBox(dh, langStringList.GetString(L"MessageCudaOldVersionError").c_str(), langStringList.GetString(L"MessageTitleResult").c_str(), MB_OK | MB_ICONERROR);
break;
default: default:
MessageBox(dh, langStringList.GetString(L"MessagecuDNNDefautlError").c_str(), langStringList.GetString(L"MessageTitleResult").c_str(), MB_OK | MB_ICONERROR); MessageBox(dh, langStringList.GetString(L"MessagecuDNNDefautlError").c_str(), langStringList.GetString(L"MessageTitleResult").c_str(), MB_OK | MB_ICONERROR);
} }