diff --git a/tests/cpp/common/test_host_device_vector.cu b/tests/cpp/common/test_host_device_vector.cu index 6aca78c81455..fb25ac5500c0 100644 --- a/tests/cpp/common/test_host_device_vector.cu +++ b/tests/cpp/common/test_host_device_vector.cu @@ -210,7 +210,6 @@ TEST(HostDeviceVector, Reshard) { vec.Shard(devices); ASSERT_EQ(vec.DeviceSize(0), h_vec.size()); ASSERT_EQ(vec.Size(), h_vec.size()); - auto span = vec.DeviceSpan(0); // sync to device PlusOne(&vec); vec.Reshard(GPUDistribution::Empty()); @@ -287,38 +286,36 @@ TEST(HostDeviceVector, MGPU_Reshard) { LOG(WARNING) << "Not testing in multi-gpu environment."; return; } + SetCudaSetDeviceHandler(SetDevice); - std::vector h_vec (2345); - for (size_t i = 0; i < h_vec.size(); ++i) { - h_vec[i] = i; - } - HostDeviceVector vec (h_vec); + size_t n = 1001; + int n_devices = 2; + auto distribution = GPUDistribution::Block(GPUSet::Range(0, n_devices)); + std::vector starts{0, 501}; + std::vector sizes{501, 500}; - // Data size for each device. - std::vector devices_size (devices.Size()); + HostDeviceVector v; + InitHostDeviceVector(n, distribution, &v); + CheckDevice(&v, starts, sizes, 0, GPUAccess::kRead); + PlusOne(&v); + CheckDevice(&v, starts, sizes, 1, GPUAccess::kWrite); + CheckHost(&v, GPUAccess::kRead); + CheckHost(&v, GPUAccess::kWrite); - // From CPU to GPUs. - vec.Shard(devices); - for (size_t i = 0; i < devices.Size(); ++i) { - auto span = vec.DeviceSpan(i); // sync to device - } - PlusOne(&vec); + auto distribution1 = GPUDistribution::Overlap(GPUSet::Range(0, n_devices), 1); + v.Reshard(distribution1); - // Reshard is allowed for already sharded vector. - vec.Reshard(GPUDistribution::Overlap(devices, 7)); - size_t total_size = 0; for (size_t i = 0; i < devices.Size(); ++i) { - total_size += vec.DeviceSize(i); - devices_size[i] = vec.DeviceSize(i); + auto span = v.DeviceSpan(i); // sync to device } - size_t overlap = 7 * (devices.Size() - 1); - ASSERT_EQ(total_size, h_vec.size() + overlap); - ASSERT_EQ(total_size, vec.Size() + overlap); - auto h_vec_1 = vec.HostVector(); - for (size_t i = 0; i < h_vec_1.size(); ++i) { - ASSERT_EQ(h_vec_1.at(i), i + 1); - } + std::vector starts1{0, 500}; + std::vector sizes1{501, 501}; + CheckDevice(&v, starts1, sizes1, 1, GPUAccess::kWrite); + CheckHost(&v, GPUAccess::kRead); + CheckHost(&v, GPUAccess::kWrite); + + SetCudaSetDeviceHandler(nullptr); } #endif