Skip to content

Commit

Permalink
[Runtime][ROCm] Enable ROCm host memory support (#17037)
Browse files Browse the repository at this point in the history
This PR enables the ROCMHost memory support in ROCm device API.
  • Loading branch information
MasterJH5574 authored May 30, 2024
1 parent 291c047 commit 08b32a7
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 6 deletions.
3 changes: 2 additions & 1 deletion src/runtime/ndarray.cc
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,8 @@ void NDArray::CopyFromTo(const DLTensor* from, DLTensor* to, TVMStreamHandle str

ICHECK(from->device.device_type == to->device.device_type || from->device.device_type == kDLCPU ||
to->device.device_type == kDLCPU || from->device.device_type == kDLCUDAHost ||
to->device.device_type == kDLCUDAHost)
to->device.device_type == kDLCUDAHost || from->device.device_type == kDLROCMHost ||
to->device.device_type == kDLROCMHost)
<< "Can not copy across different device types directly. From device type: "
<< from->device.device_type << " to device type: " << to->device.device_type;

Expand Down
40 changes: 35 additions & 5 deletions src/runtime/rocm/rocm_device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -144,16 +144,26 @@ class ROCMDeviceAPI final : public DeviceAPI {
*rv = value;
}
void* AllocDataSpace(Device dev, size_t nbytes, size_t alignment, DLDataType type_hint) final {
ROCM_CALL(hipSetDevice(dev.device_id));
ICHECK_EQ(256 % alignment, 0U) << "ROCM space is aligned at 256 bytes";
void* ret;
ROCM_CALL(hipMalloc(&ret, nbytes));
if (dev.device_type == kDLROCMHost) {
VLOG(1) << "allocating " << nbytes << "bytes on host";
ROCM_CALL(hipHostMalloc(&ret, nbytes));
} else {
ROCM_CALL(hipSetDevice(dev.device_id));
VLOG(1) << "allocating " << nbytes << " bytes on device";
ROCM_CALL(hipMalloc(&ret, nbytes));
}
return ret;
}

void FreeDataSpace(Device dev, void* ptr) final {
ROCM_CALL(hipSetDevice(dev.device_id));
ROCM_CALL(hipFree(ptr));
if (dev.device_type == kDLROCMHost) {
ROCM_CALL(hipHostFree(ptr));
} else {
ROCM_CALL(hipSetDevice(dev.device_id));
ROCM_CALL(hipFree(ptr));
}
}

void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, size_t size,
Expand All @@ -162,6 +172,21 @@ class ROCMDeviceAPI final : public DeviceAPI {
hipStream_t hip_stream = static_cast<hipStream_t>(stream);
from = static_cast<const char*>(from) + from_offset;
to = static_cast<char*>(to) + to_offset;

if (dev_from.device_type == kDLROCMHost) {
dev_from.device_type = kDLCPU;
}

if (dev_to.device_type == kDLROCMHost) {
dev_to.device_type = kDLCPU;
}

// In case there is a copy from host mem to host mem */
if (dev_to.device_type == kDLCPU && dev_from.device_type == kDLCPU) {
memcpy(to, from, size);
return;
}

if (dev_from.device_type == kDLROCM && dev_to.device_type == kDLROCM) {
ROCM_CALL(hipSetDevice(dev_from.device_id));
if (dev_from.device_id == dev_to.device_id) {
Expand Down Expand Up @@ -210,7 +235,7 @@ class ROCMDeviceAPI final : public DeviceAPI {
private:
static void GPUCopy(const void* from, void* to, size_t size, hipMemcpyKind kind,
hipStream_t stream) {
if (stream != 0) {
if (stream != nullptr) {
ROCM_CALL(hipMemcpyAsync(to, from, size, kind, stream));
} else {
ROCM_CALL(hipMemcpy(to, from, size, kind));
Expand All @@ -229,6 +254,11 @@ TVM_REGISTER_GLOBAL("device_api.rocm").set_body([](TVMArgs args, TVMRetValue* rv
*rv = static_cast<void*>(ptr);
});

TVM_REGISTER_GLOBAL("device_api.rocm_host").set_body([](TVMArgs args, TVMRetValue* rv) {
DeviceAPI* ptr = ROCMDeviceAPI::Global();
*rv = static_cast<void*>(ptr);
});

class ROCMTimerNode : public TimerNode {
public:
virtual void Start() {
Expand Down

0 comments on commit 08b32a7

Please sign in to comment.