Skip to content

Commit

Permalink
[SYCL] Relax kernel bundle device check to allow descendent devices (#…
Browse files Browse the repository at this point in the history
…7334)

Change kernel_bundle_impl constructors to treat descendent devices of
context members as valid in accordance with SYCL 2020.
  • Loading branch information
sergey-semenov authored Nov 17, 2022
1 parent 4cc1094 commit a782779
Show file tree
Hide file tree
Showing 5 changed files with 114 additions and 34 deletions.
21 changes: 21 additions & 0 deletions sycl/source/detail/context_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,27 @@ class context_impl {
/// Returns true if and only if context contains the given device.
bool hasDevice(std::shared_ptr<detail::device_impl> Device) const;

/// Returns true if and only if the device can be used within this context.
/// For OpenCL this is currently equivalent to hasDevice, for other backends
/// it returns true if the device is either a member of the context or a
/// descendant of a member.
bool isDeviceValid(DeviceImplPtr Device) {
// OpenCL does not support using descendants of context members within that
// context yet.
// TODO remove once this limitation is lifted
if (!is_host() && getPlugin().getBackend() == backend::opencl)
return hasDevice(Device);

while (!hasDevice(Device)) {
if (Device->isRootDevice())
return false;
Device = detail::getSyclObjImpl(
Device->get_info<info::device::parent_device>());
}

return true;
}

/// Given a PiDevice, returns the matching shared_ptr<device_impl>
/// within this context. May return nullptr if no match discovered.
DeviceImplPtr findMatchingDeviceImpl(RT::PiDevice &DevicePI) const;
Expand Down
6 changes: 2 additions & 4 deletions sycl/source/detail/kernel_bundle_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,9 @@ namespace detail {

static bool checkAllDevicesAreInContext(const std::vector<device> &Devices,
const context &Context) {
const std::vector<device> &ContextDevices = Context.get_devices();
return std::all_of(
Devices.begin(), Devices.end(), [&ContextDevices](const device &Dev) {
return ContextDevices.end() !=
std::find(ContextDevices.begin(), ContextDevices.end(), Dev);
Devices.begin(), Devices.end(), [&Context](const device &Dev) {
return getSyclObjImpl(Context)->isDeviceValid(getSyclObjImpl(Dev));
});
}

Expand Down
25 changes: 2 additions & 23 deletions sycl/source/detail/queue_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class queue_impl {

ContextImplPtr DefaultContext = detail::getSyclObjImpl(
Device->get_platform().ext_oneapi_get_default_context());
if (isValidDevice(DefaultContext, Device))
if (DefaultContext->isDeviceValid(Device))
return DefaultContext;
return detail::getSyclObjImpl(
context{createSyclObjFromImpl<device>(Device), {}, {}});
Expand Down Expand Up @@ -104,7 +104,7 @@ class queue_impl {
"Queue cannot be constructed with both of "
"discard_events and enable_profiling.");
}
if (!isValidDevice(Context, Device)) {
if (!Context->isDeviceValid(Device)) {
if (!Context->is_host() &&
Context->getPlugin().getBackend() == backend::opencl)
throw sycl::invalid_object_error(
Expand Down Expand Up @@ -486,27 +486,6 @@ class queue_impl {
}

protected:
/// Helper function for checking whether a device is either a member of a
/// context or a descendnant of its member.
/// \return True iff the device or its parent is a member of the context.
static bool isValidDevice(const ContextImplPtr &Context,
DeviceImplPtr Device) {
// OpenCL does not support creating a queue with a descendant of a device
// from the given context yet.
// TODO remove once this limitation is lifted
if (!Context->is_host() &&
Context->getPlugin().getBackend() == backend::opencl)
return Context->hasDevice(Device);

while (!Context->hasDevice(Device)) {
if (Device->isRootDevice())
return false;
Device = detail::getSyclObjImpl(
Device->get_info<info::device::parent_device>());
}
return true;
}

/// Performs command group submission to the queue.
///
/// \param CGF is a function object containing command group.
Expand Down
78 changes: 78 additions & 0 deletions sycl/unittests/SYCL2020/KernelBundle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
//
//===----------------------------------------------------------------------===//

#include <detail/device_impl.hpp>
#include <detail/kernel_bundle_impl.hpp>
#include <sycl/sycl.hpp>

Expand Down Expand Up @@ -459,3 +460,80 @@ TEST(KernelBundle, EmptyDevicesKernelBundleLinkException) {
FAIL() << "Unexpected exception was thrown in sycl::link.";
}
}

pi_device ParentDevice = nullptr;
pi_platform PiPlatform = nullptr;

pi_result redefinedDeviceGetInfoAfter(pi_device device,
pi_device_info param_name,
size_t param_value_size,
void *param_value,
size_t *param_value_size_ret) {
if (param_name == PI_DEVICE_INFO_PARTITION_PROPERTIES) {
if (param_value) {
auto *Result =
reinterpret_cast<pi_device_partition_property *>(param_value);
*Result = PI_DEVICE_PARTITION_EQUALLY;
}
if (param_value_size_ret)
*param_value_size_ret = sizeof(pi_device_partition_property);
} else if (param_name == PI_DEVICE_INFO_MAX_COMPUTE_UNITS) {
auto *Result = reinterpret_cast<pi_uint32 *>(param_value);
*Result = 2;
} else if (param_name == PI_DEVICE_INFO_PARENT_DEVICE) {
auto *Result = reinterpret_cast<pi_device *>(param_value);
*Result = (device == ParentDevice) ? nullptr : ParentDevice;
} else if (param_name == PI_DEVICE_INFO_PLATFORM) {
auto *Result = reinterpret_cast<pi_platform *>(param_value);
*Result = PiPlatform;
}
return PI_SUCCESS;
}

pi_result redefinedDevicePartitionAfter(
pi_device device, const pi_device_partition_property *properties,
pi_uint32 num_devices, pi_device *out_devices, pi_uint32 *out_num_devices) {
if (out_devices) {
for (size_t I = 0; I < num_devices; ++I) {
out_devices[I] = reinterpret_cast<pi_device>(1000 + I);
}
}
if (out_num_devices)
*out_num_devices = num_devices;
return PI_SUCCESS;
}

TEST(KernelBundle, DescendentDevice) {
// Mock a non-OpenCL plugin since use of descendent devices of context members
// is not supported there yet.
sycl::unittest::PiMock Mock(sycl::backend::level_zero);

sycl::platform Plt = Mock.getPlatform();

PiPlatform = sycl::detail::getSyclObjImpl(Plt)->getHandleRef();

Mock.redefineAfter<sycl::detail::PiApiKind::piDeviceGetInfo>(
redefinedDeviceGetInfoAfter);
Mock.redefineAfter<sycl::detail::PiApiKind::piDevicePartition>(
redefinedDevicePartitionAfter);

const sycl::device Dev = Mock.getPlatform().get_devices()[0];
ParentDevice = sycl::detail::getSyclObjImpl(Dev)->getHandleRef();
sycl::context Ctx{Dev};
sycl::device Subdev =
Dev.create_sub_devices<sycl::info::partition_property::partition_equally>(
2)[0];

sycl::queue Queue{Ctx, Subdev};

sycl::kernel_bundle<sycl::bundle_state::executable> KernelBundle =
sycl::get_kernel_bundle<sycl::bundle_state::executable>(Ctx, {Subdev});

sycl::kernel Kernel =
KernelBundle.get_kernel(sycl::get_kernel_id<TestKernel>());

sycl::kernel_bundle<sycl::bundle_state::executable> RetKernelBundle =
Kernel.get_kernel_bundle();

EXPECT_EQ(KernelBundle, RetKernelBundle);
}
18 changes: 11 additions & 7 deletions sycl/unittests/helpers/PiMock.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -192,10 +192,12 @@ class PiMock {
/// within the given context. A separate platform instance will be
/// held by the PiMock instance.
///
PiMock() {
/// \param Backend is the backend type to mock, intended for testing backend
/// specific runtime logic.
PiMock(backend Backend = backend::opencl) {
// Create new mock plugin platform and plugin handles
// Note: Mock plugin will be generated if it has not been yet.
MPlatformImpl = GetMockPlatformImpl();
MPlatformImpl = GetMockPlatformImpl(Backend);
std::shared_ptr<detail::plugin> NewPluginPtr;
{
const detail::plugin &OriginalPiPlugin = MPlatformImpl->getPlugin();
Expand Down Expand Up @@ -328,7 +330,9 @@ class PiMock {
/// in the global handler. Additionally, all existing plugins will be removed
/// and unloaded to avoid them being accidentally picked up by tests using
/// selectors.
static void EnsureMockPluginInitialized() {
/// \param Backend is the backend type to mock, intended for testing backend
/// specific runtime logic.
static void EnsureMockPluginInitialized(backend Backend = backend::opencl) {
// Only initialize the plugin once.
if (MMockPluginPtr)
return;
Expand All @@ -346,8 +350,7 @@ class PiMock {
RT::PiPlugin{"pi.ver.mock", "plugin.ver.mock", /*Targets=*/nullptr,
getProxyMockedFunctionPointers()});

// FIXME: which backend to pass here? does it affect anything?
MMockPluginPtr = std::make_unique<detail::plugin>(RTPlugin, backend::opencl,
MMockPluginPtr = std::make_unique<detail::plugin>(RTPlugin, Backend,
/*Library=*/nullptr);
Plugins.push_back(*MMockPluginPtr);
}
Expand All @@ -357,8 +360,9 @@ class PiMock {
/// platform_impl from it.
///
/// \return a shared_ptr to a platform_impl created from the mock PI plugin.
static std::shared_ptr<sycl::detail::platform_impl> GetMockPlatformImpl() {
EnsureMockPluginInitialized();
static std::shared_ptr<sycl::detail::platform_impl>
GetMockPlatformImpl(backend Backend) {
EnsureMockPluginInitialized(Backend);

pi_uint32 NumPlatforms = 0;
MMockPluginPtr->call_nocheck<detail::PiApiKind::piPlatformsGet>(
Expand Down

0 comments on commit a782779

Please sign in to comment.