diff --git a/sycl/plugins/opencl/pi_opencl.cpp b/sycl/plugins/opencl/pi_opencl.cpp index e7bd37ceb4dea..734a3eda6d58b 100644 --- a/sycl/plugins/opencl/pi_opencl.cpp +++ b/sycl/plugins/opencl/pi_opencl.cpp @@ -865,14 +865,24 @@ pi_result piKernelGetSubGroupInfo(pi_kernel kernel, pi_device device, std::shared_ptr implicit_input_value; if (param_name == PI_KERNEL_MAX_SUB_GROUP_SIZE && !input_value) { // OpenCL needs an input value for PI_KERNEL_MAX_SUB_GROUP_SIZE so if no - // value is given we use the max work item sizes of the device to avoid - // truncation of max sub-group size. - implicit_input_value = std::shared_ptr(new size_t[3]); - pi_result pi_ret_err = piDeviceGetInfo( - device, PI_DEVICE_INFO_MAX_WORK_ITEM_SIZES, 3 * sizeof(size_t), - implicit_input_value.get(), nullptr); + // value is given we use the max work item size of the device in the first + // dimention to avoid truncation of max sub-group size. + pi_uint32 max_dims = 0; + pi_result pi_ret_err = + piDeviceGetInfo(device, PI_DEVICE_INFO_MAX_WORK_ITEM_DIMENSIONS, + sizeof(pi_uint32), &max_dims, nullptr); if (pi_ret_err != PI_SUCCESS) return pi_ret_err; + std::shared_ptr WGSizes{new size_t[max_dims]}; + pi_ret_err = + piDeviceGetInfo(device, PI_DEVICE_INFO_MAX_WORK_ITEM_SIZES, + max_dims * sizeof(size_t), WGSizes.get(), nullptr); + if (pi_ret_err != PI_SUCCESS) + return pi_ret_err; + for (size_t i = 1; i < max_dims; ++i) + WGSizes.get()[i] = 1; + implicit_input_value = std::move(WGSizes); + input_value_size = max_dims * sizeof(size_t); input_value = implicit_input_value.get(); }