Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clean up buffers in case AssertionError #13262

Merged
merged 20 commits into from
May 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions java/src/main/java/ai/rapids/cudf/BitVectorHelper.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/*
*
* Copyright (c) 2019, NVIDIA CORPORATION.
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -71,7 +71,7 @@ private static void shiftSrcLeftAndWriteToDst(HostMemoryBuffer src, HostMemoryBu
/**
* This method returns the length in bytes needed to represent X number of rows
* e.g. getValidityLengthInBytes(5) => 1 byte
* getLengthInBytes(7) => 1 byte
* getValidityLengthInBytes(7) => 1 byte
abellina marked this conversation as resolved.
Show resolved Hide resolved
* getValidityLengthInBytes(14) => 2 bytes
*/
static long getValidityLengthInBytes(long rows) {
Expand Down
7 changes: 5 additions & 2 deletions java/src/main/java/ai/rapids/cudf/ColumnVector.java
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,10 @@ public ColumnVector(DType type, long rows, Optional<Long> nullCount,
incRefCountInternal(true);
}

private static OffHeapState makeOffHeap(DType type, long rows, Optional<Long> nullCount,
/**
* This method is internal and exposed purely for testing purposes
*/
static OffHeapState makeOffHeap(DType type, long rows, Optional<Long> nullCount,
abellina marked this conversation as resolved.
Show resolved Hide resolved
DeviceMemoryBuffer dataBuffer, DeviceMemoryBuffer validityBuffer,
DeviceMemoryBuffer offsetBuffer, List<DeviceMemoryBuffer> toClose, long[] childHandles) {
long viewHandle = initViewHandle(type, (int)rows, nullCount.orElse(UNKNOWN_NULL_COUNT).intValue(),
Expand All @@ -141,7 +144,7 @@ private static OffHeapState makeOffHeap(DType type, long rows, Optional<Long> nu
* @param offsetBuffer a host buffer required for strings and string categories. The column
* vector takes ownership of the buffer. Do not use the buffer after calling
* this.
* @param toClose List of buffers to track adn close once done, usually in case of children
* @param toClose List of buffers to track and close once done, usually in case of children
* @param childHandles array of longs for child column view handles.
*/
public ColumnVector(DType type, long rows, Optional<Long> nullCount,
Expand Down
59 changes: 51 additions & 8 deletions java/src/main/java/ai/rapids/cudf/ColumnView.java
Original file line number Diff line number Diff line change
Expand Up @@ -43,32 +43,50 @@ public class ColumnView implements AutoCloseable, BinaryOperable {
protected final ColumnVector.OffHeapState offHeap;

/**
* Constructs a Column View given a native view address
* Constructs a Column View given a native view address. This asserts that if the ColumnView is
* of nested-type it doesn't contain non-empty nulls
* @param address the view handle
* @throws AssertionError if the address points to a nested-type view with non-empty nulls
*/
ColumnView(long address) {
this.viewHandle = address;
this.type = DType.fromNative(ColumnView.getNativeTypeId(viewHandle), ColumnView.getNativeTypeScale(viewHandle));
this.rows = ColumnView.getNativeRowCount(viewHandle);
this.nullCount = ColumnView.getNativeNullCount(viewHandle);
this.offHeap = null;
AssertEmptyNulls.assertNullsAreEmpty(this);
try {
razajafri marked this conversation as resolved.
Show resolved Hide resolved
AssertEmptyNulls.assertNullsAreEmpty(this);
} catch (AssertionError ae) {
// offHeap state is null, so there is nothing to clean in offHeap
// delete ColumnView to avoid memory leak
deleteColumnView(viewHandle);
viewHandle = 0;
throw ae;
}
}


/**
* Intended to be called from ColumnVector when it is being constructed. Because state creates a
* cudf::column_view instance and will close it in all cases, we don't want to have to double
* close it.
* close it. This asserts that if the offHeapState is of nested-type it doesn't contain non-empty nulls
* @param state the state this view is based off of.
* @throws AssertionError if offHeapState points to a nested-type view with non-empty nulls
*/
protected ColumnView(ColumnVector.OffHeapState state) {
offHeap = state;
viewHandle = state.getViewHandle();
type = DType.fromNative(ColumnView.getNativeTypeId(viewHandle), ColumnView.getNativeTypeScale(viewHandle));
rows = ColumnView.getNativeRowCount(viewHandle);
nullCount = ColumnView.getNativeNullCount(viewHandle);
AssertEmptyNulls.assertNullsAreEmpty(this);
try {
AssertEmptyNulls.assertNullsAreEmpty(this);
} catch (AssertionError ae) {
// cleanup offHeap
offHeap.clean(false);
abellina marked this conversation as resolved.
Show resolved Hide resolved
viewHandle = 0;
throw ae;
}
}

/**
Expand Down Expand Up @@ -649,8 +667,14 @@ public final ColumnVector ifElse(Scalar trueValue, Scalar falseValue) {
public final ColumnVector[] slice(int... indices) {
long[] nativeHandles = slice(this.getNativeView(), indices);
ColumnVector[] columnVectors = new ColumnVector[nativeHandles.length];
for (int i = 0; i < nativeHandles.length; i++) {
columnVectors[i] = new ColumnVector(nativeHandles[i]);
try {
for (int i = 0; i < nativeHandles.length; i++) {
columnVectors[i] = new ColumnVector(nativeHandles[i]);
nativeHandles[i] = 0;
}
} catch (Throwable t) {
cleanupColumnViews(nativeHandles, columnVectors);
throw t;
}
return columnVectors;
}
Expand Down Expand Up @@ -788,12 +812,31 @@ public final ColumnVector[] split(int... indices) {
public ColumnView[] splitAsViews(int... indices) {
long[] nativeHandles = split(this.getNativeView(), indices);
ColumnView[] columnViews = new ColumnView[nativeHandles.length];
for (int i = 0; i < nativeHandles.length; i++) {
columnViews[i] = new ColumnView(nativeHandles[i]);
try {
for (int i = 0; i < nativeHandles.length; i++) {
columnViews[i] = new ColumnView(nativeHandles[i]);
nativeHandles[i] = 0;
}
} catch (Throwable t) {
cleanupColumnViews(nativeHandles, columnViews);
throw t;
}
return columnViews;
}

static void cleanupColumnViews(long[] nativeHandles, ColumnView[] columnViews) {
for (ColumnView columnView: columnViews) {
if (columnView != null) {
columnView.close();
}
}
for (long nativeHandle: nativeHandles) {
if (nativeHandle != 0) {
deleteColumnView(nativeHandle);
}
}
}

/**
* Create a new vector of "normalized" values, where:
* 1. All representations of NaN (and -NaN) are replaced with the normalized NaN value
Expand Down
38 changes: 19 additions & 19 deletions java/src/main/java/ai/rapids/cudf/Table.java
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ public Table(long[] cudfColumns) {
try {
for (int i = 0; i < cudfColumns.length; i++) {
this.columns[i] = new ColumnVector(cudfColumns[i]);
cudfColumns[i] = 0;
}
long[] views = new long[columns.length];
for (int i = 0; i < columns.length; i++) {
Expand All @@ -95,13 +96,7 @@ public Table(long[] cudfColumns) {
nativeHandle = createCudfTableView(views);
this.rows = columns[0].getRowCount();
} catch (Throwable t) {
for (int i = 0; i < cudfColumns.length; i++) {
if (this.columns[i] != null) {
this.columns[i].close();
} else {
ColumnVector.deleteCudfColumn(cudfColumns[i]);
}
}
ColumnView.cleanupColumnViews(cudfColumns, this.columns);
abellina marked this conversation as resolved.
Show resolved Hide resolved
throw t;
}
}
Expand Down Expand Up @@ -3396,8 +3391,14 @@ public static GatherMap mixedLeftAntiJoinGatherMap(Table leftKeys, Table rightKe
public ColumnVector[] convertToRows() {
long[] ptrs = convertToRows(nativeHandle);
ColumnVector[] ret = new ColumnVector[ptrs.length];
for (int i = 0; i < ptrs.length; i++) {
ret[i] = new ColumnVector(ptrs[i]);
try {
for (int i = 0; i < ptrs.length; i++) {
ret[i] = new ColumnVector(ptrs[i]);
ptrs[i] = 0;
}
} catch (Throwable t) {
ColumnView.cleanupColumnViews(ptrs, ret);
throw t;
}
return ret;
}
Expand Down Expand Up @@ -3479,8 +3480,14 @@ public ColumnVector[] convertToRows() {
public ColumnVector[] convertToRowsFixedWidthOptimized() {
long[] ptrs = convertToRowsFixedWidthOptimized(nativeHandle);
ColumnVector[] ret = new ColumnVector[ptrs.length];
for (int i = 0; i < ptrs.length; i++) {
ret[i] = new ColumnVector(ptrs[i]);
try {
for (int i = 0; i < ptrs.length; i++) {
ret[i] = new ColumnVector(ptrs[i]);
ptrs[i] = 0;
}
} catch (Throwable t) {
ColumnView.cleanupColumnViews(ptrs, ret);
throw t;
}
return ret;
}
Expand Down Expand Up @@ -3552,14 +3559,7 @@ public static Table fromPackedTable(ByteBuffer metadata, DeviceMemoryBuffer data
}
result = new Table(columns);
} catch (Throwable t) {
for (int i = 0; i < columns.length; i++) {
if (columns[i] != null) {
columns[i].close();
}
if (columnViewAddresses[i] != 0) {
ColumnView.deleteColumnView(columnViewAddresses[i]);
}
}
ColumnView.cleanupColumnViews(columnViewAddresses, columns);
throw t;
}

Expand Down
50 changes: 49 additions & 1 deletion java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Supplier;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -6677,6 +6678,54 @@ void testApplyBooleanMaskFromListOfStructure() {
}
}

@Test
void testColumnViewWithNonEmptyNullsIsCleared() {
abellina marked this conversation as resolved.
Show resolved Hide resolved
List<Integer> list0 = Arrays.asList(1, 2, 3);
List<Integer> list1 = Arrays.asList(4, 5, null);
List<Integer> list2 = Arrays.asList(7, 8, 9);
List<Integer> list3 = null;
try (ColumnVector input = ColumnVectorTest.makeListsColumn(DType.INT32, list0, list1, list2, list3);
BaseDeviceMemoryBuffer baseValidityBuffer = input.getDeviceBufferFor(BufferType.VALIDITY);
BaseDeviceMemoryBuffer baseOffsetBuffer = input.getDeviceBufferFor(BufferType.OFFSET);
HostMemoryBuffer newValidity = HostMemoryBuffer.allocate(BitVectorHelper.getValidityAllocationSizeInBytes(4))) {

newValidity.copyFromDeviceBuffer(baseValidityBuffer);
// we are setting list1 with 3 elements to null. This will result in a non-empty null in the
// ColumnView at index 1
BitVectorHelper.setNullAt(newValidity, 1);
// validityBuffer will be closed by offHeapState later
DeviceMemoryBuffer validityBuffer = DeviceMemoryBuffer.allocate(BitVectorHelper.getValidityAllocationSizeInBytes(4));
try {
// offsetBuffer will be closed by offHeapState later
DeviceMemoryBuffer offsetBuffer = DeviceMemoryBuffer.allocate(baseOffsetBuffer.getLength());
try {
validityBuffer.copyFromHostBuffer(newValidity);
offsetBuffer.copyFromMemoryBuffer(0, baseOffsetBuffer, 0,
baseOffsetBuffer.length, Cuda.DEFAULT_STREAM);

// The new offHeapState will have 2 nulls, one null at index 4 from the original ColumnVector
// the other at index 1 which is non-empty
ColumnVector.OffHeapState offHeapState = ColumnVector.makeOffHeap(input.type, input.rows, Optional.of(2L),
null, validityBuffer, offsetBuffer,
null, Arrays.stream(input.getChildColumnViews()).mapToLong((c) -> c.viewHandle).toArray());
try {
new ColumnView(offHeapState);
} catch (AssertionError ae) {
assert offHeapState.isClean();
}
} catch (Exception e) {
if (!offsetBuffer.closed) {
offsetBuffer.close();
}
}
} catch (Exception e) {
if (!validityBuffer.closed) {
validityBuffer.close();
}
}
}
}

@Test
public void testEventHandlerIsCalledForEachClose() {
final AtomicInteger onClosedWasCalled = new AtomicInteger(0);
Expand All @@ -6700,5 +6749,4 @@ public void testEventHandlerIsNotCalledIfNotSet() {
}
assertEquals(0, onClosedWasCalled.get());
}

}