From 9f89247295495b15f10176fbb43c2fe756077505 Mon Sep 17 00:00:00 2001 From: Danilo Krummrich Date: Wed, 28 Sep 2022 14:56:00 +0200 Subject: [PATCH] [SYCL][PI/CL] Check device version/extensions rather than platform version/extensions (#6795) For OpenCL backends, piProgramCreate() currently queries the platform version (CL_PLATFORM_VERSION) and platform extensions (CL_PLATFORM_EXTENSIONS) to check whether we're capable of running on top of a particular OpenCL backend. However, there might be platforms where the supported device version is lower than the platform version, or where not all devices support the same extensions, so that some extensions supported by a particular device are not reported in the platform extensions. In particular, for CL_PLATFORM_EXTENSIONS the OpenCL specification says: "[...] Each extension that is supported by all devices associated with this platform must be reported here." In section 3.4.1, Mixed Version Support, the specification also says: "[...] The version returned corresponds to the highest version of the OpenCL specification for which the device is conformant, but is not higher than the platform version." Hence, check the device version and extensions rather than the platform version and extensions in piProgramCreate(). 
Signed-off-by: Danilo Krummrich --- sycl/plugins/opencl/pi_opencl.cpp | 164 +++++++++++++++++++++--------- sycl/plugins/opencl/pi_opencl.hpp | 91 +++++++++++++++++ 2 files changed, 208 insertions(+), 47 deletions(-) diff --git a/sycl/plugins/opencl/pi_opencl.cpp b/sycl/plugins/opencl/pi_opencl.cpp index 734a3eda6d58b..07fd25da52f35 100644 --- a/sycl/plugins/opencl/pi_opencl.cpp +++ b/sycl/plugins/opencl/pi_opencl.cpp @@ -89,6 +89,71 @@ pi_result piPluginGetLastError(char **message) { return ErrorMessageCode; } /* getPlatformVersion: queries CL_PLATFORM_VERSION for plat and parses the string into version. Returns the clGetPlatformInfo error of the second query if it fails, CL_INVALID_PLATFORM if the parsed version is not valid, and the query result otherwise. NOTE(review): the result of the first (size) query is overwritten without being checked. */ +static cl_int getPlatformVersion(cl_platform_id plat, + OCLV::OpenCLVersion &version) { + cl_int ret_err = CL_INVALID_VALUE; + + size_t platVerSize = 0; + ret_err = + clGetPlatformInfo(plat, CL_PLATFORM_VERSION, 0, nullptr, &platVerSize); + + std::string platVer(platVerSize, '\0'); + ret_err = clGetPlatformInfo(plat, CL_PLATFORM_VERSION, platVerSize, + platVer.data(), nullptr); + + if (ret_err != CL_SUCCESS) + return ret_err; + + version = OCLV::OpenCLVersion(platVer); + if (!version.isValid()) + return CL_INVALID_PLATFORM; + + return ret_err; +} + /* getDeviceVersion: same scheme as getPlatformVersion, but queries CL_DEVICE_VERSION for dev and reports CL_INVALID_DEVICE when the parsed version is not valid. NOTE(review): the result of the size query is likewise overwritten unchecked. */ +static cl_int getDeviceVersion(cl_device_id dev, OCLV::OpenCLVersion &version) { + cl_int ret_err = CL_INVALID_VALUE; + + size_t devVerSize = 0; + ret_err = clGetDeviceInfo(dev, CL_DEVICE_VERSION, 0, nullptr, &devVerSize); + + std::string devVer(devVerSize, '\0'); + ret_err = clGetDeviceInfo(dev, CL_DEVICE_VERSION, devVerSize, devVer.data(), + nullptr); + + if (ret_err != CL_SUCCESS) + return ret_err; + + version = OCLV::OpenCLVersion(devVer); + if (!version.isValid()) + return CL_INVALID_DEVICE; + + return ret_err; +} + /* checkDeviceExtensions: reads CL_DEVICE_EXTENSIONS for dev and sets supported to true only if every entry of exts is found in that string; supported is left untouched when the query fails. NOTE(review): the lookup is a plain substring search, so a longer extension name that merely contains a requested name would also match - confirm whether token-wise matching is needed. The result of the size query is overwritten unchecked. */ +static cl_int checkDeviceExtensions(cl_device_id dev, + const std::vector &exts, + bool &supported) { + cl_int ret_err = CL_INVALID_VALUE; + + size_t extSize = 0; + ret_err = clGetDeviceInfo(dev, CL_DEVICE_EXTENSIONS, 0, nullptr, &extSize); + + std::string extStr(extSize, '\0'); + ret_err = clGetDeviceInfo(dev, CL_DEVICE_EXTENSIONS, extSize, extStr.data(), + nullptr); + + if (ret_err != 
CL_SUCCESS) + return ret_err; + + supported = true; + for (const std::string &ext : exts) + if (!(supported = (extStr.find(ext) != std::string::npos))) + break; + + return ret_err; +} + // USM helper function to get an extension function pointer template static pi_result getExtFuncFromContext(pi_context context, T *fptr) { @@ -215,17 +280,18 @@ pi_result piDeviceGetInfo(pi_device device, pi_device_info paramName, case PI_DEVICE_INFO_ATOMIC_MEMORY_SCOPE_CAPABILITIES: return PI_ERROR_INVALID_VALUE; case PI_DEVICE_INFO_ATOMIC_64: { - size_t extSize; - cl_bool result = clGetDeviceInfo( - cast(device), CL_DEVICE_EXTENSIONS, 0, nullptr, &extSize); - std::string extStr(extSize, '\0'); - result = clGetDeviceInfo(cast(device), CL_DEVICE_EXTENSIONS, - extSize, &extStr.front(), nullptr); - if (extStr.find("cl_khr_int64_base_atomics") == std::string::npos || - extStr.find("cl_khr_int64_extended_atomics") == std::string::npos) - result = false; - else - result = true; + cl_int ret_err = CL_SUCCESS; + cl_bool result = CL_FALSE; + bool supported = false; + + ret_err = checkDeviceExtensions( + cast(device), + {"cl_khr_int64_base_atomics", "cl_khr_int64_extended_atomics"}, + supported); + if (ret_err != CL_SUCCESS) + return static_cast(ret_err); + + result = supported; std::memcpy(paramValue, &result, sizeof(cl_bool)); return PI_SUCCESS; } @@ -402,18 +468,6 @@ pi_result piQueueCreate(pi_context context, pi_device device, CHECK_ERR_SET_NULL_RET(ret_err, queue, ret_err); - size_t platVerSize; - ret_err = clGetPlatformInfo(curPlatform, CL_PLATFORM_VERSION, 0, nullptr, - &platVerSize); - - CHECK_ERR_SET_NULL_RET(ret_err, queue, ret_err); - - std::string platVer(platVerSize, '\0'); - ret_err = clGetPlatformInfo(curPlatform, CL_PLATFORM_VERSION, platVerSize, - &platVer.front(), nullptr); - - CHECK_ERR_SET_NULL_RET(ret_err, queue, ret_err); - // Check that unexpected bits are not set. 
assert(!(properties & ~(PI_QUEUE_OUT_OF_ORDER_EXEC_MODE_ENABLE | @@ -425,9 +479,12 @@ pi_result piQueueCreate(pi_context context, pi_device device, CL_QUEUE_OUT_OF_ORDER_EXEC_MODE_ENABLE | CL_QUEUE_PROFILING_ENABLE | CL_QUEUE_ON_DEVICE | CL_QUEUE_ON_DEVICE_DEFAULT; - if (platVer.find("OpenCL 1.0") != std::string::npos || - platVer.find("OpenCL 1.1") != std::string::npos || - platVer.find("OpenCL 1.2") != std::string::npos) { + OCLV::OpenCLVersion version; + ret_err = getPlatformVersion(curPlatform, version); + + CHECK_ERR_SET_NULL_RET(ret_err, queue, ret_err); + + /* NOTE(review): the removed check entered this clCreateCommandQueue branch when the platform version string contained OpenCL 1.0, 1.1 or 1.2; version >= OCLV::V2_0 selects the complementary set of versions, so the condition looks inverted (presumably version < OCLV::V2_0 was intended) - confirm before merging. */ if (version >= OCLV::V2_0) { *queue = cast(clCreateCommandQueue( cast(context), cast(device), cast(properties) & SupportByOpenCL, @@ -482,38 +539,51 @@ pi_result piProgramCreate(pi_context context, const void *il, size_t length, CHECK_ERR_SET_NULL_RET(ret_err, res_program, CL_INVALID_CONTEXT); - size_t devVerSize; - ret_err = clGetPlatformInfo(curPlatform, CL_PLATFORM_VERSION, 0, nullptr, - &devVerSize); - std::string devVer(devVerSize, '\0'); - ret_err = clGetPlatformInfo(curPlatform, CL_PLATFORM_VERSION, devVerSize, - &devVer.front(), nullptr); + OCLV::OpenCLVersion platVer; + ret_err = getPlatformVersion(curPlatform, platVer); CHECK_ERR_SET_NULL_RET(ret_err, res_program, CL_INVALID_CONTEXT); pi_result err = PI_SUCCESS; - if (devVer.find("OpenCL 1.0") == std::string::npos && - devVer.find("OpenCL 1.1") == std::string::npos && - devVer.find("OpenCL 1.2") == std::string::npos && - devVer.find("OpenCL 2.0") == std::string::npos) { + if (platVer >= OCLV::V2_1) { + + /* Make sure all devices support CL 2.1 or newer as well. */ + for (cl_device_id dev : devicesInCtx) { + OCLV::OpenCLVersion devVer; + + ret_err = getDeviceVersion(dev, devVer); + CHECK_ERR_SET_NULL_RET(ret_err, res_program, CL_INVALID_CONTEXT); + + /* If the device does not support CL 2.1 or greater, we need to make sure + * it supports the cl_khr_il_program extension. 
+ */ + if (devVer < OCLV::V2_1) { + bool supported = false; + + ret_err = checkDeviceExtensions(dev, {"cl_khr_il_program"}, supported); + CHECK_ERR_SET_NULL_RET(ret_err, res_program, CL_INVALID_CONTEXT); + + if (!supported) + return cast(CL_INVALID_OPERATION); + } + } if (res_program != nullptr) *res_program = cast(clCreateProgramWithIL( cast(context), il, length, cast(&err))); return err; } - size_t extSize; - ret_err = clGetPlatformInfo(curPlatform, CL_PLATFORM_EXTENSIONS, 0, nullptr, - &extSize); - std::string extStr(extSize, '\0'); - ret_err = clGetPlatformInfo(curPlatform, CL_PLATFORM_EXTENSIONS, extSize, - &extStr.front(), nullptr); + /* If none of the devices conform with CL 2.1 or newer make sure they all + * support the cl_khr_il_program extension. + */ + for (cl_device_id dev : devicesInCtx) { + bool supported = false; - if (ret_err != CL_SUCCESS || - extStr.find("cl_khr_il_program") == std::string::npos) { - if (res_program != nullptr) - *res_program = nullptr; - return cast(CL_INVALID_CONTEXT); + ret_err = checkDeviceExtensions(dev, {"cl_khr_il_program"}, supported); + CHECK_ERR_SET_NULL_RET(ret_err, res_program, CL_INVALID_CONTEXT); + + if (!supported) + /* NOTE(review): the removed lines above reset *res_program to nullptr before returning when the extension was missing; this return leaves *res_program untouched and reports CL_INVALID_OPERATION instead of CL_INVALID_CONTEXT - confirm both changes are intended. */ return cast(CL_INVALID_OPERATION); } using apiFuncT = diff --git a/sycl/plugins/opencl/pi_opencl.hpp b/sycl/plugins/opencl/pi_opencl.hpp index 53dbd2a540590..179d0566c3088 100644 --- a/sycl/plugins/opencl/pi_opencl.hpp +++ b/sycl/plugins/opencl/pi_opencl.hpp @@ -17,6 +17,10 @@ #ifndef PI_OPENCL_HPP #define PI_OPENCL_HPP +#include +#include +#include + // This version should be incremented for any change made to this file or its // corresponding .cpp file. 
#define _PI_OPENCL_PLUGIN_VERSION 1

#define _PI_OPENCL_PLUGIN_VERSION_STRING                                       \
  _PI_PLUGIN_VERSION_STRING(_PI_OPENCL_PLUGIN_VERSION)

namespace OCLV {
/// An OpenCL major.minor version that can be parsed from a version string and
/// compared against other versions.
///
/// Every constructor stores 0.0 when the requested version does not pass
/// isValid(), so a default-constructed object and a failed parse compare
/// equal.
class OpenCLVersion {
protected:
  unsigned int major;
  unsigned int minor;

public:
  /// Creates the invalid version 0.0.
  OpenCLVersion() : major(0), minor(0) {}

  /// Creates the version \p major.\p minor, or 0.0 if that pair is not valid.
  OpenCLVersion(unsigned int major, unsigned int minor)
      : major(major), minor(minor) {
    // The parameters shadow the data members, so the reset has to go through
    // this-> to actually clear the stored version.
    if (!isValid())
      this->major = this->minor = 0;
  }

  /// Parses a C string; a null pointer is treated like an empty string and
  /// yields 0.0.
  OpenCLVersion(const char *version)
      : OpenCLVersion(version ? std::string(version) : std::string()) {}

  /// Parses the first "OpenCL <major>.<minor>" occurrence in \p version.
  /// Yields 0.0 if there is no such occurrence or the parsed pair is not
  /// valid.
  OpenCLVersion(const std::string &version) : major(0), minor(0) {
    /* The OpenCL specification defines the full version string as
     * 'OpenCL<space><major_version.minor_version><space><platform-specific
     * information>' for platforms and as
     * 'OpenCL<space><major_version.minor_version><space><vendor-specific
     * information>' for devices.
     */
    // Compiled once; the pattern never changes between calls.
    static const std::regex rx("OpenCL ([0-9]+)\\.([0-9]+)");
    std::smatch match;

    if (std::regex_search(version, match, rx) && (match.size() == 3)) {
      major = toComponent(match[1].str());
      minor = toComponent(match[2].str());

      if (!isValid())
        major = minor = 0;
    }
  }

  bool operator==(const OpenCLVersion &v) const {
    return major == v.major && minor == v.minor;
  }

  bool operator!=(const OpenCLVersion &v) const { return !(*this == v); }

  /// Orders by major version first and by minor version second.
  bool operator<(const OpenCLVersion &v) const {
    if (major == v.major)
      return minor < v.minor;

    return major < v.major;
  }

  bool operator>(const OpenCLVersion &v) const { return v < *this; }

  bool operator<=(const OpenCLVersion &v) const {
    return (*this < v) || (*this == v);
  }

  bool operator>=(const OpenCLVersion &v) const {
    return (*this > v) || (*this == v);
  }

  /// Major versions 1 and 2 accept minor versions 0 through 2. Major versions
  /// 0 and UINT_MAX are never valid. Any other major version accepts every
  /// minor version except UINT_MAX.
  bool isValid() const {
    switch (major) {
    case 0:
      return false;
    case 1:
    case 2:
      return minor <= 2;
    case UINT_MAX:
      return false;
    default:
      return minor != UINT_MAX;
    }
  }

  int getMajor() const { return major; }
  int getMinor() const { return minor; }

private:
  /// Converts a run of decimal digits to unsigned int. Values that do not fit
  /// saturate to UINT_MAX (which isValid() rejects) instead of being
  /// truncated into a small, valid-looking number.
  static unsigned int toComponent(const std::string &digits) {
    const unsigned long value = strtoul(digits.c_str(), nullptr, 10);
    return value > UINT_MAX ? UINT_MAX : static_cast<unsigned int>(value);
  }
};

inline const OpenCLVersion V1_0(1, 0);
inline const OpenCLVersion V1_1(1, 1);
inline const OpenCLVersion V1_2(1, 2);
inline const OpenCLVersion V2_0(2, 0);
inline const OpenCLVersion V2_1(2, 1);
inline const OpenCLVersion V2_2(2, 2);
inline const OpenCLVersion V3_0(3, 0);

} // namespace OCLV