Skip to content

Commit

Permalink
[JNI] rmm based pinned pool (#15219)
Browse files Browse the repository at this point in the history
Part of #14782.

This PR removes our old implementation of the Java-based pinned memory pool and replaces it with a JNI layer on top of `rmm::pool_memory_resource<rmm::pinned_host_memory_resource>`.

This PR does NOT set the default cuIO pinned host resource. That will happen in #15079, after this PR goes in. We'll need a follow-on PR to change the `PinnedMemoryPool.initialize` method, adding an argument to set the cuIO pinned host resource.

I have run with this, and with versions of it that are shared with cuIO, and I can't find regressions in NDS at SF3K.

Note that we don't align anymore on our side. RMM is doing the same alignment we were doing before, using `std::max_align_t`.

Note also that the RMM pool doesn't have a quick way to report its current size. We had some tests that asserted on this, and I have removed those asserts. If we would like to get them back, I am happy to work with RMM to figure out how to do that.

Authors:
  - Alessandro Bellina (https://github.com/abellina)

Approvers:
  - Jim Brennan (https://github.com/jbrennan333)

URL: #15219
  • Loading branch information
abellina authored Mar 5, 2024
1 parent d53df8c commit 176f75b
Show file tree
Hide file tree
Showing 5 changed files with 108 additions and 269 deletions.
281 changes: 29 additions & 252 deletions java/src/main/java/ai/rapids/cudf/PinnedMemoryPool.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/*
*
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -22,109 +22,46 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Comparator;
import java.util.Iterator;
import java.util.Objects;
import java.util.Optional;
import java.util.SortedSet;
import java.util.TreeSet;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

/**
* This provides a pool of pinned memory similar to what RMM does for device memory.
* This is the JNI interface to a rmm::pool_memory_resource<rmm::pinned_host_memory_resource>.
*/
public final class PinnedMemoryPool implements AutoCloseable {
private static final Logger log = LoggerFactory.getLogger(PinnedMemoryPool.class);
private static final long ALIGNMENT = ColumnView.hostPaddingSizeInBytes();

// These static fields should only ever be accessed when class-synchronized.
// Do NOT use singleton_ directly! Use the getSingleton accessor instead.
private static volatile PinnedMemoryPool singleton_ = null;
private static Future<PinnedMemoryPool> initFuture = null;

private final long totalPoolSize;
private final long pinnedPoolBase;
private final SortedSet<MemorySection> freeHeap = new TreeSet<>(new SortedByAddress());
private int numAllocatedSections = 0;
private long availableBytes;

/** Orders {@link MemorySection} instances by ascending {@code size}. */
private static class SortedBySize implements Comparator<MemorySection> {
  @Override
  public int compare(MemorySection left, MemorySection right) {
    final long leftSize = left.size;
    final long rightSize = right.size;
    return Long.compare(leftSize, rightSize);
  }
}

/** Orders {@link MemorySection} instances by ascending {@code baseAddress}. */
private static class SortedByAddress implements Comparator<MemorySection> {
  @Override
  public int compare(MemorySection left, MemorySection right) {
    final long leftBase = left.baseAddress;
    final long rightBase = right.baseAddress;
    return Long.compare(leftBase, rightBase);
  }
}

/**
 * A mutable range of memory described by a base address and a size in bytes.
 * A section can be merged with a directly adjacent section or have a prefix
 * carved off of it.
 */
private static class MemorySection {
  private long baseAddress;
  private long size;

  MemorySection(long baseAddress, long size) {
    this.baseAddress = baseAddress;
    this.size = size;
  }

  /**
   * Reports whether {@code other} touches this section, i.e. one of the two ends
   * exactly where the other one begins. The outcome is also logged at trace level.
   */
  boolean canCombine(MemorySection other) {
    final boolean otherEndsWhereThisStarts = other.baseAddress + other.size == baseAddress;
    final boolean thisEndsWhereOtherStarts = baseAddress + size == other.baseAddress;
    final boolean adjacent = otherEndsWhereThisStarts || thisEndsWhereOtherStarts;
    log.trace("CAN {} COMBINE WITH {} ? {}", this, other, adjacent);
    return adjacent;
  }

  /**
   * Grows this section so that it also covers {@code other}. Adjacency is only
   * checked through an assertion, so callers are expected to have verified it.
   */
  void combineWith(MemorySection other) {
    assert canCombine(other);
    log.trace("COMBINING {} AND {}", this, other);
    // The merged range starts at whichever of the two sections begins first.
    if (other.baseAddress < baseAddress) {
      baseAddress = other.baseAddress;
    }
    size += other.size;
    log.trace("COMBINED TO {}\n", this);
  }

  /**
   * Carves the first {@code newSize} bytes off the front of this section.
   *
   * @param newSize number of bytes to split off; asserted to be smaller than the current size
   * @return a new section covering the carved-off prefix; this section keeps the remainder
   */
  MemorySection splitOff(long newSize) {
    assert this.size > newSize;
    final long prefixBase = baseAddress;
    baseAddress = prefixBase + newSize;
    size -= newSize;
    return new MemorySection(prefixBase, newSize);
  }

  @Override
  public String toString() {
    final String startHex = Long.toHexString(baseAddress);
    final String endHex = Long.toHexString(baseAddress + size);
    return "PINNED: " + size + " bytes (0x" + startHex + " to 0x" + endHex + ")";
  }
}
private long poolHandle;
private long poolSize;

private static final class PinnedHostBufferCleaner extends MemoryBuffer.MemoryBufferCleaner {
private MemorySection section;
private long address;
private final long origLength;

PinnedHostBufferCleaner(MemorySection section, long length) {
this.section = section;
PinnedHostBufferCleaner(long address, long length) {
this.address = address;
origLength = length;
}

@Override
protected synchronized boolean cleanImpl(boolean logErrorIfNotClean) {
boolean neededCleanup = false;
long origAddress = 0;
if (section != null) {
origAddress = section.baseAddress;
if (address != -1) {
origAddress = address;
try {
PinnedMemoryPool.freeInternal(section);
PinnedMemoryPool.freeInternal(address, origLength);
} finally {
// Always mark the resource as freed even if an exception is thrown.
// We cannot know how far it progressed before the exception, and
// therefore it is unsafe to retry.
section = null;
address = -1;
}
neededCleanup = true;
}
Expand All @@ -137,7 +74,7 @@ protected synchronized boolean cleanImpl(boolean logErrorIfNotClean) {

@Override
public boolean isClean() {
return section == null;
return address == -1;
}
}

Expand All @@ -161,16 +98,8 @@ private static PinnedMemoryPool getSingleton() {
return singleton_;
}

/**
 * Hands {@code section} back to the singleton pool.
 *
 * @throws NullPointerException if {@code getSingleton()} returns null
 */
private static void freeInternal(MemorySection section) {
  final PinnedMemoryPool pool = Objects.requireNonNull(getSingleton());
  pool.free(section);
}

/**
* Used to indicate that memory was allocated from a reservation. This primarily is for
* keeping track of outstanding allocations.
*/
private static void reserveAllocInternal(MemorySection section) {
Objects.requireNonNull(getSingleton()).reserveAllocHappened(section);
private static void freeInternal(long address, long origLength) {
Objects.requireNonNull(getSingleton()).free(address, origLength);
}

/**
Expand Down Expand Up @@ -209,12 +138,14 @@ public static boolean isInitialized() {
}

/**
* Shut down the pool of memory. If there are outstanding allocations this may fail.
* Shut down the RMM pool_memory_resource, nulling out our reference. Any allocation
* or free that is in flight will fail after this.
*/
public static synchronized void shutdown() {
PinnedMemoryPool pool = getSingleton();
if (pool != null) {
pool.close();
pool = null;
}
initFuture = null;
singleton_ = null;
Expand All @@ -235,21 +166,6 @@ public static HostMemoryBuffer tryAllocate(long bytes) {
return result;
}

/**
 * Factory method that attempts to reserve pinned host memory.
 *
 * @param bytes size in bytes to reserve
 * @return the new reservation, or null if no pool singleton is available or the
 *         pool could not provide a reservation of that size.
 */
public static HostMemoryReservation tryReserve(long bytes) {
  final PinnedMemoryPool pool = getSingleton();
  if (pool == null) {
    return null;
  }
  return pool.tryReserveInternal(bytes);
}

/**
* Factory method to create a host buffer but preferably pointing to pinned memory.
* It is not guaranteed that the returned buffer will be pointer to pinned memory.
Expand All @@ -276,26 +192,13 @@ public static HostMemoryBuffer allocate(long bytes) {
return allocate(bytes, DefaultHostMemoryAllocator.get());
}

/**
 * Reports how many bytes are currently free in the pinned memory pool.
 *
 * @return free memory in bytes, or 0 when no pool singleton is available
 */
public static long getAvailableBytes() {
  final PinnedMemoryPool pool = getSingleton();
  return pool == null ? 0 : pool.getAvailableBytesInternal();
}

/**
* Get the number of bytes that the pinned memory pool was allocated with.
*/
public static long getTotalPoolSizeBytes() {
PinnedMemoryPool pool = getSingleton();
if (pool != null) {
return pool.getTotalPoolSizeInternal();
return pool.poolSize;
}
return 0;
}
Expand All @@ -306,157 +209,31 @@ private PinnedMemoryPool(long poolSize, int gpuId) {
Cuda.setDevice(gpuId);
Cuda.freeZero();
}
this.totalPoolSize = poolSize;
this.pinnedPoolBase = Cuda.hostAllocPinned(poolSize);
freeHeap.add(new MemorySection(pinnedPoolBase, poolSize));
this.availableBytes = poolSize;
this.poolHandle = Rmm.newPinnedPoolMemoryResource(poolSize, poolSize);
this.poolSize = poolSize;
}

@Override
public void close() {
assert numAllocatedSections == 0 : "Leaked " + numAllocatedSections + " pinned allocations";
Cuda.freePinned(pinnedPoolBase);
Rmm.releasePinnedPoolMemoryResource(this.poolHandle);
this.poolHandle = -1;
}

/**
* Pads a length of bytes to the alignment the CPU wants in the worst case. This helps to
* calculate the size needed for a reservation if there are multiple buffers.
* @param bytes the size in bytes
* @return the new padded size in bytes.
* This makes an attempt to allocate pinned memory, and if the pinned memory allocation fails
* it will return null, instead of throw.
*/
public static long padToCpuAlignment(long bytes) {
return ((bytes + ALIGNMENT - 1) / ALIGNMENT) * ALIGNMENT;
}

/**
 * Takes a section of at least {@code bytes} (after padding with
 * {@code padToCpuAlignment}) out of the free heap, using the first free section,
 * in the heap's iteration order, that is large enough.
 *
 * @param bytes requested size in bytes, before padding
 * @param what label used only in the debug log messages
 * @return the section that was taken, or null if the heap is empty or no free
 *         section is large enough
 */
private synchronized MemorySection tryGetInternal(long bytes, String what) {
  if (freeHeap.isEmpty()) {
    log.debug("No free pinned memory left");
    return null;
  }

  final long needed = padToCpuAlignment(bytes);

  // First fit: stop at the first free section that can hold the padded request.
  MemorySection candidate = null;
  for (MemorySection free : freeHeap) {
    if (free.size >= needed) {
      candidate = free;
      break;
    }
  }

  if (candidate == null) {
    if (log.isDebugEnabled()) {
      final long largestSize = freeHeap.stream().mapToLong(s -> s.size).max().orElse(0);
      log.debug("Insufficient pinned memory. {} needed, {} found", needed, largestSize);
    }
    return null;
  }

  log.debug("{} {}/{} bytes pinned from {} FREE COUNT {} OUTSTANDING COUNT {}",
      what, bytes, needed, candidate, freeHeap.size(), numAllocatedSections);

  // Take the section out of the heap before splitOff changes its base address,
  // which is what the heap is ordered by.
  freeHeap.remove(candidate);
  final MemorySection allocated;
  if (candidate.size != needed) {
    allocated = candidate.splitOff(needed);
    // Whatever is left of the candidate goes back into the free heap.
    freeHeap.add(candidate);
  } else {
    allocated = candidate;
  }

  numAllocatedSections++;
  availableBytes -= allocated.size;
  log.debug("{} {} free {} outstanding {}", what, allocated, freeHeap, numAllocatedSections);
  return allocated;
}

private synchronized HostMemoryBuffer tryAllocateInternal(long bytes) {
MemorySection allocated = tryGetInternal(bytes, "allocate");
if (allocated == null) {
long allocated = Rmm.allocFromPinnedPool(this.poolHandle, bytes);
if (allocated == -1) {
return null;
} else {
return new HostMemoryBuffer(allocated.baseAddress, bytes,
return new HostMemoryBuffer(allocated, bytes,
new PinnedHostBufferCleaner(allocated, bytes));
}
}

/**
 * A {@link HostMemoryReservation} backed by a single {@link MemorySection}.
 * Buffers are carved off the front of the section until it is used up or the
 * reservation is closed.
 */
private class PinnedReservation implements HostMemoryReservation {
  // Remaining reserved space; null once it has been fully handed out or freed.
  private MemorySection section;

  public PinnedReservation(MemorySection section) {
    this.section = section;
  }

  /** Same as {@link #allocate(long)}; {@code preferPinned} is not consulted. */
  @Override
  public synchronized HostMemoryBuffer allocate(long bytes, boolean preferPinned) {
    return allocate(bytes);
  }

  /**
   * Allocates {@code bytes} out of the reserved section.
   *
   * @throws OutOfMemoryError if the reservation is exhausted or smaller than {@code bytes}
   */
  @Override
  public synchronized HostMemoryBuffer allocate(long bytes) {
    final long reserved = section == null ? 0 : section.size;
    if (section == null || reserved < bytes) {
      throw new OutOfMemoryError("Reservation didn't have enough space " + bytes + " / " +
          reserved);
    }

    final long alignedSize = padToCpuAlignment(bytes);
    final MemorySection allocated;
    // section.size >= bytes is already guaranteed by the check above.
    if (section.size > alignedSize) {
      allocated = section.splitOff(alignedSize);
      PinnedMemoryPool.reserveAllocInternal(allocated);
    } else {
      // The request consumes what is left of the reservation.
      // No need for reserveAllocInternal because the original section is already tracked
      allocated = section;
      section = null;
    }

    final PinnedHostBufferCleaner cleaner = new PinnedHostBufferCleaner(allocated, bytes);
    return new HostMemoryBuffer(allocated.baseAddress, bytes, cleaner);
  }

  /** Returns any unused part of the reservation to the pool. */
  @Override
  public synchronized void close() throws Exception {
    if (section == null) {
      return;
    }
    try {
      PinnedMemoryPool.freeInternal(section);
    } finally {
      // Always mark the resource as freed even if an exception is thrown.
      // We cannot know how far it progressed before the exception, and
      // therefore it is unsafe to retry.
      section = null;
    }
  }
}

/**
 * Tries to take {@code bytes} from the pool and wrap them in a reservation.
 *
 * @return the reservation, or null if {@code tryGetInternal} could not provide a section
 */
private HostMemoryReservation tryReserveInternal(long bytes) {
  final MemorySection reserved = tryGetInternal(bytes, "allocate");
  return reserved == null ? null : new PinnedReservation(reserved);
}

/**
 * Puts {@code section} back into the free heap, first merging it with every
 * free section it is adjacent to, and updates the bookkeeping counters.
 */
private synchronized void free(MemorySection section) {
  log.debug("Freeing {} with {} outstanding {}", section, freeHeap, numAllocatedSections);
  // Account for the bytes before merging, because merging grows section.size.
  availableBytes += section.size;

  for (Iterator<MemorySection> neighbors = freeHeap.iterator(); neighbors.hasNext(); ) {
    final MemorySection candidate = neighbors.next();
    if (!section.canCombine(candidate)) {
      continue;
    }
    // The candidate is absorbed into the freed section, so drop it from the heap.
    neighbors.remove();
    section.combineWith(candidate);
  }

  freeHeap.add(section);
  numAllocatedSections--;
  log.debug("After freeing {} outstanding {}", freeHeap, numAllocatedSections);
}

/**
 * Counts one more outstanding allocation, ignoring null sections and sections
 * whose size is not positive.
 */
private synchronized void reserveAllocHappened(MemorySection section) {
  if (section == null || section.size <= 0) {
    return;
  }
  numAllocatedSections++;
}

/** Returns the current value of {@code availableBytes}, read while holding this pool's monitor. */
private synchronized long getAvailableBytesInternal() {
  return availableBytes;
}

private long getTotalPoolSizeInternal() {
return this.totalPoolSize;
private synchronized void free(long address, long size) {
Rmm.freeFromPinnedPool(this.poolHandle, address, size);
}
}
Loading

0 comments on commit 176f75b

Please sign in to comment.