diff --git a/sycl/plugins/opencl/pi_opencl.cpp b/sycl/plugins/opencl/pi_opencl.cpp index 734a3eda6d58b..07fd25da52f35 100644 --- a/sycl/plugins/opencl/pi_opencl.cpp +++ b/sycl/plugins/opencl/pi_opencl.cpp @@ -89,6 +89,71 @@ pi_result piPluginGetLastError(char **message) { return ErrorMessageCode; } +static cl_int getPlatformVersion(cl_platform_id plat, + OCLV::OpenCLVersion &version) { + cl_int ret_err = CL_INVALID_VALUE; + + size_t platVerSize = 0; + ret_err = + clGetPlatformInfo(plat, CL_PLATFORM_VERSION, 0, nullptr, &platVerSize); + + std::string platVer(platVerSize, '\0'); + ret_err = clGetPlatformInfo(plat, CL_PLATFORM_VERSION, platVerSize, + platVer.data(), nullptr); + + if (ret_err != CL_SUCCESS) + return ret_err; + + version = OCLV::OpenCLVersion(platVer); + if (!version.isValid()) + return CL_INVALID_PLATFORM; + + return ret_err; +} + +static cl_int getDeviceVersion(cl_device_id dev, OCLV::OpenCLVersion &version) { + cl_int ret_err = CL_INVALID_VALUE; + + size_t devVerSize = 0; + ret_err = clGetDeviceInfo(dev, CL_DEVICE_VERSION, 0, nullptr, &devVerSize); + + std::string devVer(devVerSize, '\0'); + ret_err = clGetDeviceInfo(dev, CL_DEVICE_VERSION, devVerSize, devVer.data(), + nullptr); + + if (ret_err != CL_SUCCESS) + return ret_err; + + version = OCLV::OpenCLVersion(devVer); + if (!version.isValid()) + return CL_INVALID_DEVICE; + + return ret_err; +} + +static cl_int checkDeviceExtensions(cl_device_id dev, + const std::vector &exts, + bool &supported) { + cl_int ret_err = CL_INVALID_VALUE; + + size_t extSize = 0; + ret_err = clGetDeviceInfo(dev, CL_DEVICE_EXTENSIONS, 0, nullptr, &extSize); + + std::string extStr(extSize, '\0'); + ret_err = clGetDeviceInfo(dev, CL_DEVICE_EXTENSIONS, extSize, extStr.data(), + nullptr); + + if (ret_err != CL_SUCCESS) + return ret_err; + + supported = true; + for (const std::string &ext : exts) + if (!(supported = (extStr.find(ext) != std::string::npos))) + break; + + return ret_err; +} + // USM helper function to get an extension function pointer template static pi_result getExtFuncFromContext(pi_context context, T *fptr) { @@ -215,17 +280,18 @@ pi_result piDeviceGetInfo(pi_device device, pi_device_info paramName, case PI_DEVICE_INFO_ATOMIC_MEMORY_SCOPE_CAPABILITIES: return PI_ERROR_INVALID_VALUE; case PI_DEVICE_INFO_ATOMIC_64: { - size_t extSize; - cl_bool result = clGetDeviceInfo( - cast(device), CL_DEVICE_EXTENSIONS, 0, nullptr, &extSize); - std::string extStr(extSize, '\0'); - result = clGetDeviceInfo(cast(device), CL_DEVICE_EXTENSIONS, - extSize, &extStr.front(), nullptr); - if (extStr.find("cl_khr_int64_base_atomics") == std::string::npos || - extStr.find("cl_khr_int64_extended_atomics") == std::string::npos) - result = false; - else - result = true; + cl_int ret_err = CL_SUCCESS; + cl_bool result = CL_FALSE; + bool supported = false; + + ret_err = checkDeviceExtensions( + cast(device), + {"cl_khr_int64_base_atomics", "cl_khr_int64_extended_atomics"}, + supported); + if (ret_err != CL_SUCCESS) + return static_cast(ret_err); + + result = supported; std::memcpy(paramValue, &result, sizeof(cl_bool)); return PI_SUCCESS; } @@ -402,18 +468,6 @@ pi_result piQueueCreate(pi_context context, pi_device device, CHECK_ERR_SET_NULL_RET(ret_err, queue, ret_err); - size_t platVerSize; - ret_err = clGetPlatformInfo(curPlatform, CL_PLATFORM_VERSION, 0, nullptr, - &platVerSize); - - CHECK_ERR_SET_NULL_RET(ret_err, queue, ret_err); - - std::string platVer(platVerSize, '\0'); - ret_err = clGetPlatformInfo(curPlatform, CL_PLATFORM_VERSION, platVerSize, - &platVer.front(), nullptr); - - CHECK_ERR_SET_NULL_RET(ret_err, queue, ret_err); - // Check that unexpected bits are not set. assert(!(properties & ~(PI_QUEUE_OUT_OF_ORDER_EXEC_MODE_ENABLE | @@ -425,9 +479,12 @@ pi_result piQueueCreate(pi_context context, pi_device device, CL_QUEUE_OUT_OF_ORDER_EXEC_MODE_ENABLE | CL_QUEUE_PROFILING_ENABLE | CL_QUEUE_ON_DEVICE | CL_QUEUE_ON_DEVICE_DEFAULT; - if (platVer.find("OpenCL 1.0") != std::string::npos || - platVer.find("OpenCL 1.1") != std::string::npos || - platVer.find("OpenCL 1.2") != std::string::npos) { + OCLV::OpenCLVersion version; + ret_err = getPlatformVersion(curPlatform, version); + + CHECK_ERR_SET_NULL_RET(ret_err, queue, ret_err); + + if (version >= OCLV::V2_0) { *queue = cast(clCreateCommandQueue( cast(context), cast(device), cast(properties) & SupportByOpenCL, @@ -482,38 +539,51 @@ pi_result piProgramCreate(pi_context context, const void *il, size_t length, CHECK_ERR_SET_NULL_RET(ret_err, res_program, CL_INVALID_CONTEXT); - size_t devVerSize; - ret_err = clGetPlatformInfo(curPlatform, CL_PLATFORM_VERSION, 0, nullptr, - &devVerSize); - std::string devVer(devVerSize, '\0'); - ret_err = clGetPlatformInfo(curPlatform, CL_PLATFORM_VERSION, devVerSize, - &devVer.front(), nullptr); + OCLV::OpenCLVersion platVer; + ret_err = getPlatformVersion(curPlatform, platVer); CHECK_ERR_SET_NULL_RET(ret_err, res_program, CL_INVALID_CONTEXT); pi_result err = PI_SUCCESS; - if (devVer.find("OpenCL 1.0") == std::string::npos && - devVer.find("OpenCL 1.1") == std::string::npos && - devVer.find("OpenCL 1.2") == std::string::npos && - devVer.find("OpenCL 2.0") == std::string::npos) { + if (platVer >= OCLV::V2_1) { + + /* Make sure all devices support CL 2.1 or newer as well. */ + for (cl_device_id dev : devicesInCtx) { + OCLV::OpenCLVersion devVer; + + ret_err = getDeviceVersion(dev, devVer); + CHECK_ERR_SET_NULL_RET(ret_err, res_program, CL_INVALID_CONTEXT); + + /* If the device does not support CL 2.1 or greater, we need to make sure + * it supports the cl_khr_il_program extension. + */ + if (devVer < OCLV::V2_1) { + bool supported = false; + + ret_err = checkDeviceExtensions(dev, {"cl_khr_il_program"}, supported); + CHECK_ERR_SET_NULL_RET(ret_err, res_program, CL_INVALID_CONTEXT); + + if (!supported) + return cast(CL_INVALID_OPERATION); + } + } if (res_program != nullptr) *res_program = cast(clCreateProgramWithIL( cast(context), il, length, cast(&err))); return err; } - size_t extSize; - ret_err = clGetPlatformInfo(curPlatform, CL_PLATFORM_EXTENSIONS, 0, nullptr, - &extSize); - std::string extStr(extSize, '\0'); - ret_err = clGetPlatformInfo(curPlatform, CL_PLATFORM_EXTENSIONS, extSize, - &extStr.front(), nullptr); + /* If none of the devices conform with CL 2.1 or newer make sure they all + * support the cl_khr_il_program extension. + */ + for (cl_device_id dev : devicesInCtx) { + bool supported = false; - if (ret_err != CL_SUCCESS || - extStr.find("cl_khr_il_program") == std::string::npos) { - if (res_program != nullptr) - *res_program = nullptr; - return cast(CL_INVALID_CONTEXT); + ret_err = checkDeviceExtensions(dev, {"cl_khr_il_program"}, supported); + CHECK_ERR_SET_NULL_RET(ret_err, res_program, CL_INVALID_CONTEXT); + + if (!supported) + return cast(CL_INVALID_OPERATION); } using apiFuncT = diff --git a/sycl/plugins/opencl/pi_opencl.hpp b/sycl/plugins/opencl/pi_opencl.hpp index 53dbd2a540590..179d0566c3088 100644 --- a/sycl/plugins/opencl/pi_opencl.hpp +++ b/sycl/plugins/opencl/pi_opencl.hpp @@ -17,6 +17,10 @@ #ifndef PI_OPENCL_HPP #define PI_OPENCL_HPP +#include +#include +#include + // This version should be incremented for any change made to this file or its // corresponding .cpp file. #define _PI_OPENCL_PLUGIN_VERSION 1 @@ -24,4 +28,91 @@ #define _PI_OPENCL_PLUGIN_VERSION_STRING \ _PI_PLUGIN_VERSION_STRING(_PI_OPENCL_PLUGIN_VERSION) +namespace OCLV { +class OpenCLVersion { +protected: + unsigned int major; + unsigned int minor; + +public: + OpenCLVersion() : major(0), minor(0) {} + + OpenCLVersion(unsigned int major, unsigned int minor) + : major(major), minor(minor) { + if (!isValid()) + major = minor = 0; + } + + OpenCLVersion(const char *version) : OpenCLVersion(std::string(version)) {} + + OpenCLVersion(const std::string &version) : major(0), minor(0) { + /* The OpenCL specification defines the full version string as + * 'OpenCL' for platforms and as + * 'OpenCL' for devices. + */ + std::regex rx("OpenCL ([0-9]+)\\.([0-9]+)"); + std::smatch match; + + if (std::regex_search(version, match, rx) && (match.size() == 3)) { + major = strtoul(match[1].str().c_str(), nullptr, 10); + minor = strtoul(match[2].str().c_str(), nullptr, 10); + + if (!isValid()) + major = minor = 0; + } + } + + bool operator==(const OpenCLVersion &v) const { + return major == v.major && minor == v.minor; + } + + bool operator!=(const OpenCLVersion &v) const { return !(*this == v); } + + bool operator<(const OpenCLVersion &v) const { + if (major == v.major) + return minor < v.minor; + + return major < v.major; + } + + bool operator>(const OpenCLVersion &v) const { return v < *this; } + + bool operator<=(const OpenCLVersion &v) const { + return (*this < v) || (*this == v); + } + + bool operator>=(const OpenCLVersion &v) const { + return (*this > v) || (*this == v); + } + + bool isValid() const { + switch (major) { + case 0: + return false; + case 1: + case 2: + return minor <= 2; + case UINT_MAX: + return false; + default: + return minor != UINT_MAX; + } + } + + int getMajor() const { return major; } + int getMinor() const { return minor; } +}; + +inline const OpenCLVersion V1_0(1, 0); +inline const OpenCLVersion V1_1(1, 1); +inline const OpenCLVersion V1_2(1, 2); +inline const OpenCLVersion V2_0(2, 0); +inline const OpenCLVersion V2_1(2, 1); +inline const OpenCLVersion V2_2(2, 2); +inline const OpenCLVersion V3_0(3, 0); + +} // namespace OCLV + #endif // PI_OPENCL_HPP