Skip to content

Commit

Permalink
[SYCL][PI/CL] Check device version/extensions rather than platform ve…
Browse files Browse the repository at this point in the history
…rsion/extensions (#6795)

For OpenCL backends currently piProgramCreate() queries the platform
version (CL_PLATFORM_VERSION) and platform extensions
(CL_PLATFORM_EXTENSIONS) to check whether we're capable of running on
top of a particular OpenCL backend.
    
However, there might be platforms where the supported device version is
lower than the platform version or where not all devices do support the
same extensions and hence some extensions supported by a particular
device are not reported in the platform extensions.

In particular for CL_PLATFORM_EXTENSIONS the OpenCL specification says:
"[...] Each extension that is supported by all devices associated with
this platform must be reported here."
    
In 3.4.1 Mixed Version Support the specification also says: "[...] The
version returned corresponds to the highest version of the OpenCL
specification for which the device is conformant, but is not higher than
the platform version."
    
Hence, check for the device version and extensions rather than the
platform version and extensions in piProgramCreate().

Signed-off-by: Danilo Krummrich <dakr@redhat.com>
  • Loading branch information
Danilo Krummrich authored Sep 28, 2022
1 parent 1f8d90f commit 9f89247
Show file tree
Hide file tree
Showing 2 changed files with 208 additions and 47 deletions.
164 changes: 117 additions & 47 deletions sycl/plugins/opencl/pi_opencl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,71 @@ pi_result piPluginGetLastError(char **message) {
return ErrorMessageCode;
}

static cl_int getPlatformVersion(cl_platform_id plat,
OCLV::OpenCLVersion &version) {
cl_int ret_err = CL_INVALID_VALUE;

size_t platVerSize = 0;
ret_err =
clGetPlatformInfo(plat, CL_PLATFORM_VERSION, 0, nullptr, &platVerSize);

std::string platVer(platVerSize, '\0');
ret_err = clGetPlatformInfo(plat, CL_PLATFORM_VERSION, platVerSize,
platVer.data(), nullptr);

if (ret_err != CL_SUCCESS)
return ret_err;

version = OCLV::OpenCLVersion(platVer);
if (!version.isValid())
return CL_INVALID_PLATFORM;

return ret_err;
}

static cl_int getDeviceVersion(cl_device_id dev, OCLV::OpenCLVersion &version) {
cl_int ret_err = CL_INVALID_VALUE;

size_t devVerSize = 0;
ret_err = clGetDeviceInfo(dev, CL_DEVICE_VERSION, 0, nullptr, &devVerSize);

std::string devVer(devVerSize, '\0');
ret_err = clGetDeviceInfo(dev, CL_DEVICE_VERSION, devVerSize, devVer.data(),
nullptr);

if (ret_err != CL_SUCCESS)
return ret_err;

version = OCLV::OpenCLVersion(devVer);
if (!version.isValid())
return CL_INVALID_DEVICE;

return ret_err;
}

static cl_int checkDeviceExtensions(cl_device_id dev,
const std::vector<std::string> &exts,
bool &supported) {
cl_int ret_err = CL_INVALID_VALUE;

size_t extSize = 0;
ret_err = clGetDeviceInfo(dev, CL_DEVICE_EXTENSIONS, 0, nullptr, &extSize);

std::string extStr(extSize, '\0');
ret_err = clGetDeviceInfo(dev, CL_DEVICE_EXTENSIONS, extSize, extStr.data(),
nullptr);

if (ret_err != CL_SUCCESS)
return ret_err;

supported = true;
for (const std::string &ext : exts)
if (!(supported = (extStr.find(ext) != std::string::npos)))
break;

return ret_err;
}

// USM helper function to get an extension function pointer
template <const char *FuncName, typename T>
static pi_result getExtFuncFromContext(pi_context context, T *fptr) {
Expand Down Expand Up @@ -215,17 +280,18 @@ pi_result piDeviceGetInfo(pi_device device, pi_device_info paramName,
case PI_DEVICE_INFO_ATOMIC_MEMORY_SCOPE_CAPABILITIES:
return PI_ERROR_INVALID_VALUE;
case PI_DEVICE_INFO_ATOMIC_64: {
size_t extSize;
cl_bool result = clGetDeviceInfo(
cast<cl_device_id>(device), CL_DEVICE_EXTENSIONS, 0, nullptr, &extSize);
std::string extStr(extSize, '\0');
result = clGetDeviceInfo(cast<cl_device_id>(device), CL_DEVICE_EXTENSIONS,
extSize, &extStr.front(), nullptr);
if (extStr.find("cl_khr_int64_base_atomics") == std::string::npos ||
extStr.find("cl_khr_int64_extended_atomics") == std::string::npos)
result = false;
else
result = true;
cl_int ret_err = CL_SUCCESS;
cl_bool result = CL_FALSE;
bool supported = false;

ret_err = checkDeviceExtensions(
cast<cl_device_id>(device),
{"cl_khr_int64_base_atomics", "cl_khr_int64_extended_atomics"},
supported);
if (ret_err != CL_SUCCESS)
return static_cast<pi_result>(ret_err);

result = supported;
std::memcpy(paramValue, &result, sizeof(cl_bool));
return PI_SUCCESS;
}
Expand Down Expand Up @@ -402,18 +468,6 @@ pi_result piQueueCreate(pi_context context, pi_device device,

CHECK_ERR_SET_NULL_RET(ret_err, queue, ret_err);

size_t platVerSize;
ret_err = clGetPlatformInfo(curPlatform, CL_PLATFORM_VERSION, 0, nullptr,
&platVerSize);

CHECK_ERR_SET_NULL_RET(ret_err, queue, ret_err);

std::string platVer(platVerSize, '\0');
ret_err = clGetPlatformInfo(curPlatform, CL_PLATFORM_VERSION, platVerSize,
&platVer.front(), nullptr);

CHECK_ERR_SET_NULL_RET(ret_err, queue, ret_err);

// Check that unexpected bits are not set.
assert(!(properties &
~(PI_QUEUE_OUT_OF_ORDER_EXEC_MODE_ENABLE |
Expand All @@ -425,9 +479,12 @@ pi_result piQueueCreate(pi_context context, pi_device device,
CL_QUEUE_OUT_OF_ORDER_EXEC_MODE_ENABLE | CL_QUEUE_PROFILING_ENABLE |
CL_QUEUE_ON_DEVICE | CL_QUEUE_ON_DEVICE_DEFAULT;

if (platVer.find("OpenCL 1.0") != std::string::npos ||
platVer.find("OpenCL 1.1") != std::string::npos ||
platVer.find("OpenCL 1.2") != std::string::npos) {
OCLV::OpenCLVersion version;
ret_err = getPlatformVersion(curPlatform, version);

CHECK_ERR_SET_NULL_RET(ret_err, queue, ret_err);

if (version >= OCLV::V2_0) {
*queue = cast<pi_queue>(clCreateCommandQueue(
cast<cl_context>(context), cast<cl_device_id>(device),
cast<cl_command_queue_properties>(properties) & SupportByOpenCL,
Expand Down Expand Up @@ -482,38 +539,51 @@ pi_result piProgramCreate(pi_context context, const void *il, size_t length,

CHECK_ERR_SET_NULL_RET(ret_err, res_program, CL_INVALID_CONTEXT);

size_t devVerSize;
ret_err = clGetPlatformInfo(curPlatform, CL_PLATFORM_VERSION, 0, nullptr,
&devVerSize);
std::string devVer(devVerSize, '\0');
ret_err = clGetPlatformInfo(curPlatform, CL_PLATFORM_VERSION, devVerSize,
&devVer.front(), nullptr);
OCLV::OpenCLVersion platVer;
ret_err = getPlatformVersion(curPlatform, platVer);

CHECK_ERR_SET_NULL_RET(ret_err, res_program, CL_INVALID_CONTEXT);

pi_result err = PI_SUCCESS;
if (devVer.find("OpenCL 1.0") == std::string::npos &&
devVer.find("OpenCL 1.1") == std::string::npos &&
devVer.find("OpenCL 1.2") == std::string::npos &&
devVer.find("OpenCL 2.0") == std::string::npos) {
if (platVer >= OCLV::V2_1) {

/* Make sure all devices support CL 2.1 or newer as well. */
for (cl_device_id dev : devicesInCtx) {
OCLV::OpenCLVersion devVer;

ret_err = getDeviceVersion(dev, devVer);
CHECK_ERR_SET_NULL_RET(ret_err, res_program, CL_INVALID_CONTEXT);

/* If the device does not support CL 2.1 or greater, we need to make sure
* it supports the cl_khr_il_program extension.
*/
if (devVer < OCLV::V2_1) {
bool supported = false;

ret_err = checkDeviceExtensions(dev, {"cl_khr_il_program"}, supported);
CHECK_ERR_SET_NULL_RET(ret_err, res_program, CL_INVALID_CONTEXT);

if (!supported)
return cast<pi_result>(CL_INVALID_OPERATION);
}
}
if (res_program != nullptr)
*res_program = cast<pi_program>(clCreateProgramWithIL(
cast<cl_context>(context), il, length, cast<cl_int *>(&err)));
return err;
}

size_t extSize;
ret_err = clGetPlatformInfo(curPlatform, CL_PLATFORM_EXTENSIONS, 0, nullptr,
&extSize);
std::string extStr(extSize, '\0');
ret_err = clGetPlatformInfo(curPlatform, CL_PLATFORM_EXTENSIONS, extSize,
&extStr.front(), nullptr);
/* If none of the devices conform with CL 2.1 or newer make sure they all
* support the cl_khr_il_program extension.
*/
for (cl_device_id dev : devicesInCtx) {
bool supported = false;

if (ret_err != CL_SUCCESS ||
extStr.find("cl_khr_il_program") == std::string::npos) {
if (res_program != nullptr)
*res_program = nullptr;
return cast<pi_result>(CL_INVALID_CONTEXT);
ret_err = checkDeviceExtensions(dev, {"cl_khr_il_program"}, supported);
CHECK_ERR_SET_NULL_RET(ret_err, res_program, CL_INVALID_CONTEXT);

if (!supported)
return cast<pi_result>(CL_INVALID_OPERATION);
}

using apiFuncT =
Expand Down
91 changes: 91 additions & 0 deletions sycl/plugins/opencl/pi_opencl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,102 @@
#ifndef PI_OPENCL_HPP
#define PI_OPENCL_HPP

#include <climits>
#include <regex>
#include <string>

// This version should be incremented for any change made to this file or its
// corresponding .cpp file.
#define _PI_OPENCL_PLUGIN_VERSION 1

#define _PI_OPENCL_PLUGIN_VERSION_STRING \
_PI_PLUGIN_VERSION_STRING(_PI_OPENCL_PLUGIN_VERSION)

namespace OCLV {
class OpenCLVersion {
protected:
unsigned int major;
unsigned int minor;

public:
OpenCLVersion() : major(0), minor(0) {}

OpenCLVersion(unsigned int major, unsigned int minor)
: major(major), minor(minor) {
if (!isValid())
major = minor = 0;
}

OpenCLVersion(const char *version) : OpenCLVersion(std::string(version)) {}

OpenCLVersion(const std::string &version) : major(0), minor(0) {
/* The OpenCL specification defines the full version string as
* 'OpenCL<space><major_version.minor_version><space><platform-specific
* information>' for platforms and as
* 'OpenCL<space><major_version.minor_version><space><vendor-specific
* information>' for devices.
*/
std::regex rx("OpenCL ([0-9]+)\\.([0-9]+)");
std::smatch match;

if (std::regex_search(version, match, rx) && (match.size() == 3)) {
major = strtoul(match[1].str().c_str(), nullptr, 10);
minor = strtoul(match[2].str().c_str(), nullptr, 10);

if (!isValid())
major = minor = 0;
}
}

bool operator==(const OpenCLVersion &v) const {
return major == v.major && minor == v.minor;
}

bool operator!=(const OpenCLVersion &v) const { return !(*this == v); }

bool operator<(const OpenCLVersion &v) const {
if (major == v.major)
return minor < v.minor;

return major < v.major;
}

bool operator>(const OpenCLVersion &v) const { return v < *this; }

bool operator<=(const OpenCLVersion &v) const {
return (*this < v) || (*this == v);
}

bool operator>=(const OpenCLVersion &v) const {
return (*this > v) || (*this == v);
}

bool isValid() const {
switch (major) {
case 0:
return false;
case 1:
case 2:
return minor <= 2;
case UINT_MAX:
return false;
default:
return minor != UINT_MAX;
}
}

int getMajor() const { return major; }
int getMinor() const { return minor; }
};

inline const OpenCLVersion V1_0(1, 0);
inline const OpenCLVersion V1_1(1, 1);
inline const OpenCLVersion V1_2(1, 2);
inline const OpenCLVersion V2_0(2, 0);
inline const OpenCLVersion V2_1(2, 1);
inline const OpenCLVersion V2_2(2, 2);
inline const OpenCLVersion V3_0(3, 0);

} // namespace OCLV

#endif // PI_OPENCL_HPP

0 comments on commit 9f89247

Please sign in to comment.