Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[JNI] Enables fabric handles for CUDA async memory pools #17526

Merged
merged 6 commits into from
Dec 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions java/src/main/java/ai/rapids/cudf/Rmm.java
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,8 @@ private static void setGlobalValsFromResource(RmmDeviceMemoryResource resource)
* {@link RmmAllocationMode#CUDA_DEFAULT},
* {@link RmmAllocationMode#POOL},
* {@link RmmAllocationMode#ARENA},
* {@link RmmAllocationMode#CUDA_ASYNC},
* {@link RmmAllocationMode#CUDA_ASYNC_FABRIC} and
* {@link RmmAllocationMode#CUDA_MANAGED_MEMORY}
* @param logConf How to do logging or null if you don't want to
* @param poolSize The initial pool size in bytes
Expand All @@ -221,6 +222,7 @@ public static synchronized void initialize(int allocationMode, LogConf logConf,
boolean isPool = (allocationMode & RmmAllocationMode.POOL) != 0;
boolean isArena = (allocationMode & RmmAllocationMode.ARENA) != 0;
boolean isAsync = (allocationMode & RmmAllocationMode.CUDA_ASYNC) != 0;
boolean isAsyncFabric = (allocationMode & RmmAllocationMode.CUDA_ASYNC_FABRIC) != 0;
boolean isManaged = (allocationMode & RmmAllocationMode.CUDA_MANAGED_MEMORY) != 0;

if (isAsync && isManaged) {
Expand All @@ -246,6 +248,9 @@ public static synchronized void initialize(int allocationMode, LogConf logConf,
} else if (isAsync) {
resource = new RmmLimitingResourceAdaptor<>(
new RmmCudaAsyncMemoryResource(poolSize, poolSize), poolSize, 512);
} else if (isAsyncFabric) {
resource = new RmmLimitingResourceAdaptor<>(
new RmmCudaAsyncMemoryResource(poolSize, poolSize, true), poolSize, 512);
} else if (isManaged) {
resource = new RmmManagedMemoryResource();
} else {
Expand Down Expand Up @@ -521,7 +526,6 @@ public static DeviceMemoryBuffer alloc(long size, Cuda.Stream stream) {

private static native long allocInternal(long size, long stream) throws RmmException;


static native void free(long ptr, long length, long stream) throws RmmException;

/**
Expand Down Expand Up @@ -562,7 +566,7 @@ static native long newArenaMemoryResource(long childHandle,

static native void releaseArenaMemoryResource(long handle);

// Creates the native cuda_async_memory_resource; returns an opaque handle.
// The stale two-arg overload was removed: its native symbol was replaced by
// this fabric-aware version, so invoking it would throw UnsatisfiedLinkError.
static native long newCudaAsyncMemoryResource(long size, long release, boolean fabric) throws RmmException;

static native void releaseCudaAsyncMemoryResource(long handle);

Expand All @@ -575,7 +579,6 @@ static native long newLoggingResourceAdaptor(long handle, int type, String path,

static native void releaseLoggingResourceAdaptor(long handle);


static native long newTrackingResourceAdaptor(long handle, long alignment) throws RmmException;

static native void releaseTrackingResourceAdaptor(long handle);
Expand Down
7 changes: 6 additions & 1 deletion java/src/main/java/ai/rapids/cudf/RmmAllocationMode.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -36,4 +36,9 @@ public class RmmAllocationMode {
* Use CUDA async suballocation strategy
*/
public static final int CUDA_ASYNC = 0x00000008;
/**
* Use CUDA async suballocation strategy with fabric handles that are
* peer accessible with read-write access
*/
public static final int CUDA_ASYNC_FABRIC = 0x00000010;
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -30,9 +30,20 @@ public class RmmCudaAsyncMemoryResource implements RmmDeviceMemoryResource {
* @param releaseThreshold size in bytes for when memory is released back to cuda
*/
public RmmCudaAsyncMemoryResource(long size, long releaseThreshold) {
  // Delegates to the three-arg constructor with fabric handles disabled,
  // preserving the pre-fabric default behavior for existing callers.
  this(size, releaseThreshold, false);
}

/**
 * Create a new async memory resource
 * @param size the initial size of the pool
 * @param releaseThreshold size in bytes for when memory is released back to cuda
 * @param fabric if true request peer read+write accessible fabric handles when
 *               creating the pool
 */
public RmmCudaAsyncMemoryResource(long size, long releaseThreshold, boolean fabric) {
  this.size = size;
  this.releaseThreshold = releaseThreshold;
  // Create the native pool exactly once. (A stale pre-fabric call was removed
  // here: it allocated a second native resource whose handle was immediately
  // overwritten and leaked.)
  handle = Rmm.newCudaAsyncMemoryResource(size, releaseThreshold, fabric);
}

@Override
Expand Down
20 changes: 15 additions & 5 deletions java/src/main/native/src/RmmJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -772,14 +772,24 @@ JNIEXPORT void JNICALL Java_ai_rapids_cudf_Rmm_releaseArenaMemoryResource(JNIEnv
CATCH_STD(env, )
}

// JNI binding for Rmm.newCudaAsyncMemoryResource(long, long, boolean).
// Builds an rmm::mr::cuda_async_memory_resource and returns its address as a
// jlong handle; returns 0 if construction throws (via CATCH_STD). The Java
// side owns the pointer and frees it via releaseCudaAsyncMemoryResource.
// (The stale pre-fabric 4-parameter-less signature left above this function
// by the diff was removed — it duplicated this declaration.)
JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Rmm_newCudaAsyncMemoryResource(
  JNIEnv* env, jclass clazz, jlong init, jlong release, jboolean fabric)
{
  try {
    cudf::jni::auto_set_device(env);

    // When we are using fabric, we need to set the memory access to be
    // read_write, in order for peer GPUs to have access to this memory.
    // Otherwise, choose default parameters (optional set to nullopt).
    auto [handle_type, prot_flag] =
      fabric
        ? std::pair{std::optional{
                      rmm::mr::cuda_async_memory_resource::allocation_handle_type::fabric},
                    std::optional{rmm::mr::cuda_async_memory_resource::access_flags::read_write}}
        : std::pair{std::nullopt, std::nullopt};

    // Raw new is intentional: ownership crosses the JNI boundary as a handle.
    auto ret = new rmm::mr::cuda_async_memory_resource(init, release, handle_type, prot_flag);

    return reinterpret_cast<jlong>(ret);
  }
  CATCH_STD(env, 0)
Expand Down
Loading