Skip to content

Commit

Permalink
initial (#3)
Browse files Browse the repository at this point in the history
Co-authored-by: Dmitry Razdoburdin <>
  • Loading branch information
razdoburdin authored Sep 7, 2023
1 parent 22f6d0e commit 06ebdc8
Showing 1 changed file with 20 additions and 4 deletions.
24 changes: 20 additions & 4 deletions tests/python-oneapi/test_oneapi_updaters.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from hypothesis import given, strategies, assume, settings, note

import sys
import os
sys.path.append("tests/python")
import testing as tm

Expand Down Expand Up @@ -45,7 +46,22 @@ def test_oneapi_hist(self, param, num_rounds, dataset):
@given(tm.dataset_strategy.filter(lambda x: x.name != "empty"), strategies.integers(0, 1))
@settings(deadline=None)
def test_specified_device_id_oneapi_update(self, dataset, device_id):
param = {'updater': 'grow_quantile_histmaker_oneapi', 'device_id': device_id}
param = dataset.set_params(param)
result = train_result(param, dataset.get_dmat(), 10)
assert tm.non_increasing(result['train'][dataset.metric])
# Read the list of sycl-devicese
sycl_ls = os.popen('sycl-ls').read()
devices = sycl_ls.split('\n')

# Test should launch only on gpu
# Find gpus in the list of devices
# and use the id in the list insteard of device_id
target_device_type = "opencl:gpu"
found_devices = 0
for idx in range(len(devices)):
if len(devices[idx]) >= len(target_device_type):
if devices[idx][1:1+len(target_device_type)] == target_device_type:
if (found_devices == device_id):
param = {'updater': 'grow_quantile_histmaker_oneapi', 'device_id': idx}
param = dataset.set_params(param)
result = train_result(param, dataset.get_dmat(), 10)
assert tm.non_increasing(result['train'][dataset.metric])
else:
found_devices += 1

0 comments on commit 06ebdc8

Please sign in to comment.