From 6aefd6317ca8711095e8a1556abf33c26d3d4132 Mon Sep 17 00:00:00 2001 From: lbushi25 <113361374+lbushi25@users.noreply.github.com> Date: Fri, 18 Nov 2022 16:08:33 -0500 Subject: [PATCH] [SYCL] Improve ODS negative filter implementation (#7453) The negative filter implementation for ONEAPI_DEVICE_SELECTOR uses a map to keep track of blacklisted devices. The keys used by this map were originally device addresses in a vector container which are not very robust because vectors can potentially move their data to other locations and the device addresses could change thus invalidating the blacklist map. Even though in the source code the resizing of the vector only happens after we are done with the blacklist, you never know what tricks the compiler might pull on us. We use device numbers instead which are unique for each device in a platform and do not change during the function execution. --- sycl/source/detail/platform_impl.cpp | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/sycl/source/detail/platform_impl.cpp b/sycl/source/detail/platform_impl.cpp index 1ea82843050b3..a7532dc9c494e 100644 --- a/sycl/source/detail/platform_impl.cpp +++ b/sycl/source/detail/platform_impl.cpp @@ -185,7 +185,9 @@ static std::vector filterDeviceFilter(std::vector &PiDevices, // used in the ONEAPI_DEVICE_SELECTOR implemenation. It cannot be placed // in the if statement above because it will then be out of scope in the rest // of the function - std::map Blacklist; + std::map Blacklist; + // original indices keeps track of the device numbers of the chosen + // devices and is whats returned by the function std::vector original_indices; std::vector &Plugins = RT::initialize(); @@ -223,14 +225,14 @@ static std::vector filterDeviceFilter(std::vector &PiDevices, // Last, match the device_num entry if (!Filter.DeviceNum || DeviceNum == Filter.DeviceNum.value()) { if constexpr (is_ods_target) { // dealing with ODS filters - if (!Blacklist[&Device]) { // ensure it is not blacklisted + if (!Blacklist[DeviceNum]) { // ensure it is not blacklisted if (!Filter.IsNegativeTarget) { // is filter positive? PiDevices[InsertIDx++] = Device; original_indices.push_back(DeviceNum); } else { // Filter is negative and the device matches the filter so // blacklist the device. - Blacklist[&Device] = true; + Blacklist[DeviceNum] = true; } } } else { // dealing with SYCL_DEVICE_FILTER @@ -243,14 +245,14 @@ static std::vector filterDeviceFilter(std::vector &PiDevices, } else if (FilterDevType == DeviceType) { if (!Filter.DeviceNum || DeviceNum == Filter.DeviceNum.value()) { if constexpr (is_ods_target) { - if (!Blacklist[&Device]) { + if (!Blacklist[DeviceNum]) { if (!Filter.IsNegativeTarget) { PiDevices[InsertIDx++] = Device; original_indices.push_back(DeviceNum); } else { // Filter is negative and the device matches the filter so // blacklist the device. - Blacklist[&Device] = true; + Blacklist[DeviceNum] = true; } } } else {