Skip to content

Commit

Permalink
address review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
rongou committed Apr 30, 2019
1 parent f375e85 commit b1b4498
Showing 1 changed file with 23 additions and 26 deletions.
49 changes: 23 additions & 26 deletions tests/cpp/common/test_host_device_vector.cu
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,6 @@ TEST(HostDeviceVector, Reshard) {
vec.Shard(devices);
ASSERT_EQ(vec.DeviceSize(0), h_vec.size());
ASSERT_EQ(vec.Size(), h_vec.size());
auto span = vec.DeviceSpan(0); // sync to device
PlusOne(&vec);

vec.Reshard(GPUDistribution::Empty());
Expand Down Expand Up @@ -287,38 +286,36 @@ TEST(HostDeviceVector, MGPU_Reshard) {
LOG(WARNING) << "Not testing in multi-gpu environment.";
return;
}
SetCudaSetDeviceHandler(SetDevice);

std::vector<int> h_vec (2345);
for (size_t i = 0; i < h_vec.size(); ++i) {
h_vec[i] = i;
}
HostDeviceVector<int> vec (h_vec);
size_t n = 1001;
int n_devices = 2;
auto distribution = GPUDistribution::Block(GPUSet::Range(0, n_devices));
std::vector<size_t> starts{0, 501};
std::vector<size_t> sizes{501, 500};

// Data size for each device.
std::vector<size_t> devices_size (devices.Size());
HostDeviceVector<int> v;
InitHostDeviceVector(n, distribution, &v);
CheckDevice(&v, starts, sizes, 0, GPUAccess::kRead);
PlusOne(&v);
CheckDevice(&v, starts, sizes, 1, GPUAccess::kWrite);
CheckHost(&v, GPUAccess::kRead);
CheckHost(&v, GPUAccess::kWrite);

// From CPU to GPUs.
vec.Shard(devices);
for (size_t i = 0; i < devices.Size(); ++i) {
auto span = vec.DeviceSpan(i); // sync to device
}
PlusOne(&vec);
auto distribution1 = GPUDistribution::Overlap(GPUSet::Range(0, n_devices), 1);
v.Reshard(distribution1);

// Reshard is allowed for already sharded vector.
vec.Reshard(GPUDistribution::Overlap(devices, 7));
size_t total_size = 0;
for (size_t i = 0; i < devices.Size(); ++i) {
total_size += vec.DeviceSize(i);
devices_size[i] = vec.DeviceSize(i);
auto span = v.DeviceSpan(i); // sync to device
}
size_t overlap = 7 * (devices.Size() - 1);
ASSERT_EQ(total_size, h_vec.size() + overlap);
ASSERT_EQ(total_size, vec.Size() + overlap);

auto h_vec_1 = vec.HostVector();
for (size_t i = 0; i < h_vec_1.size(); ++i) {
ASSERT_EQ(h_vec_1.at(i), i + 1);
}
std::vector<size_t> starts1{0, 500};
std::vector<size_t> sizes1{501, 501};
CheckDevice(&v, starts1, sizes1, 1, GPUAccess::kWrite);
CheckHost(&v, GPUAccess::kRead);
CheckHost(&v, GPUAccess::kWrite);

SetCudaSetDeviceHandler(nullptr);
}
#endif

Expand Down

0 comments on commit b1b4498

Please sign in to comment.