Skip to content

Commit

Permalink
Fix build failure on GPU (#1279)
Browse files Browse the repository at this point in the history
Change-Id: I6a916d24668d3ed04cb7c43ca70455a5d049b0d6
  • Loading branch information
frankfliu authored Oct 8, 2021
1 parent 8629668 commit ff659d8
Show file tree
Hide file tree
Showing 15 changed files with 53 additions and 24 deletions.
9 changes: 9 additions & 0 deletions api/src/main/java/ai/djl/Device.java
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,15 @@ public int getDeviceId() {
return deviceId;
}

/**
* Returns if the {@code Device} is GPU.
*
* @return if the {@code Device} is GPU.
*/
public boolean isGpu() {
return Type.GPU.equals(deviceType);
}

/** {@inheritDoc} */
@Override
public String toString() {
Expand Down
2 changes: 1 addition & 1 deletion api/src/main/java/ai/djl/util/cuda/CudaUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ public static String getComputeCapability(int device) {
* @throws IllegalArgumentException if {@link Device} is not GPU device or does not exist
*/
public static MemoryUsage getGpuMemory(Device device) {
if (!Device.Type.GPU.equals(device.getDeviceType())) {
if (!device.isGpu()) {
throw new IllegalArgumentException("Only GPU device is allowed.");
}

Expand Down
1 change: 0 additions & 1 deletion api/src/test/java/ai/djl/util/PlatformTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ public void testPlatform() throws IOException {
Assert.assertEquals(system.getClassifier(), "linux-x86_64");
Assert.assertEquals(system.getOsPrefix(), "linux");
Assert.assertEquals(system.getOsArch(), "x86_64");
Assert.assertNull(system.getCudaArch());

url = createPropertyFile("version=1.8.0\nplaceholder=true");
Platform platform = Platform.fromUrl(url);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ public static int getGpuCount() {
}

public static long[] getGpuMemory(Device device) {
if (!Device.Type.GPU.equals(device.getDeviceType())) {
if (!device.isGpu()) {
throw new IllegalArgumentException("Only GPU device is allowed.");
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ public void load(Path modelPath, String prefix, Map<String, ?> options)
try {
Device device = manager.getDevice();
OrtSession session;
if (Device.Type.GPU.equals(device.getDeviceType())) {
if (device.isGpu()) {
OrtSession.SessionOptions sessionOptions = new OrtSession.SessionOptions();
sessionOptions.addCUDA(manager.getDevice().getDeviceId());
session = env.createSession(modelFile.toString(), sessionOptions);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ public void set(Buffer data) {
if (data.isDirect() && data instanceof ByteBuffer) {
// If NDArray is on the GPU, it is native code responsibility to control the data life
// cycle
if (!Device.Type.GPU.equals(getDevice().getDeviceType())) {
if (!getDevice().isGpu()) {
dataRef = (ByteBuffer) data;
}
JniUtils.set(this, (ByteBuffer) data);
Expand All @@ -227,7 +227,7 @@ public void set(Buffer data) {
BaseNDManager.copyBuffer(data, buf);

// If NDArray is on the GPU, it is native code responsibility to control the data life cycle
if (!Device.Type.GPU.equals(getDevice().getDeviceType())) {
if (!getDevice().isGpu()) {
dataRef = buf;
}
JniUtils.set(this, buf);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ public static PtNDArray createNdFromByteBuffer(
new int[] {PtDeviceType.toDeviceType(device), device.getDeviceId()},
false);

if (layout == 1 || layout == 2 || Device.Type.GPU.equals(device.getDeviceType())) {
if (layout == 1 || layout == 2 || device.isGpu()) {
// MKLDNN & COO & GPU device will explicitly make a copy in native code
// so we don't want to hold a reference on Java side
return new PtNDArray(manager, handle);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,10 @@ public ByteBuffer toByteBuffer() {
/** {@inheritDoc} */
@Override
public void set(Buffer data) {
if (getDevice().isGpu()) {
// TODO: Implement set for GPU
throw new UnsupportedOperationException("GPU Tensor cannot be modified after creation");
}
int size = Math.toIntExact(getShape().size());
BaseNDManager.validateBufferSize(data, getDataType(), size);
if (data instanceof ByteBuffer) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ public NDArray create(Shape shape, DataType dataType) {
// initialize with scalar 0
return create(0f).toType(dataType, false);
}
TFE_TensorHandle handle = JavacppUtils.createEmptyTFETensor(shape, dataType);
TFE_TensorHandle handle =
JavacppUtils.createEmptyTFETensor(shape, dataType, getEagerSession(), device);
return new TfNDArray(this, handle);
}

Expand All @@ -84,12 +85,15 @@ public TfNDArray create(Buffer data, Shape shape, DataType dataType) {
BaseNDManager.validateBufferSize(data, dataType, size);
if (data.isDirect() && data instanceof ByteBuffer) {
TFE_TensorHandle handle =
JavacppUtils.createTFETensorFromByteBuffer((ByteBuffer) data, shape, dataType);
JavacppUtils.createTFETensorFromByteBuffer(
(ByteBuffer) data, shape, dataType, getEagerSession(), device);
return new TfNDArray(this, handle);
}
ByteBuffer buf = allocateDirect(size * dataType.getNumOfBytes());
copyBuffer(data, buf);
TFE_TensorHandle handle = JavacppUtils.createTFETensorFromByteBuffer(buf, shape, dataType);
TFE_TensorHandle handle =
JavacppUtils.createTFETensorFromByteBuffer(
buf, shape, dataType, getEagerSession(), device);
return new TfNDArray(this, handle);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ public static Shape getShape(TFE_TensorHandle handle) {
}
}

public static TF_Tensor createEmptyTFTensor(Shape shape, DataType dataType) {
private static TF_Tensor createEmptyTFTensor(Shape shape, DataType dataType) {
int dType = TfDataType.toTf(dataType);
long[] dims = shape.getShape();
long numBytes = dataType.getNumOfBytes() * shape.size();
Expand All @@ -260,12 +260,16 @@ public static TF_Tensor createEmptyTFTensor(Shape shape, DataType dataType) {
}

@SuppressWarnings({"unchecked", "try"})
public static TFE_TensorHandle createEmptyTFETensor(Shape shape, DataType dataType) {
public static TFE_TensorHandle createEmptyTFETensor(
Shape shape, DataType dataType, TFE_Context eagerSessionHandle, Device device) {
try (PointerScope ignored = new PointerScope()) {
TF_Tensor tensor = createEmptyTFTensor(shape, dataType);
TF_Status status = TF_Status.newStatus();
TFE_TensorHandle handle = AbstractTFE_TensorHandle.newTensor(tensor, status);
status.throwExceptionIfNotOK();
if (device.isGpu()) {
return toDevice(handle, eagerSessionHandle, device);
}
return handle.retainReference();
}
}
Expand Down Expand Up @@ -303,7 +307,11 @@ public static Pair<TF_Tensor, TFE_TensorHandle> createStringTensor(

@SuppressWarnings({"unchecked", "try"})
public static TFE_TensorHandle createTFETensorFromByteBuffer(
ByteBuffer buf, Shape shape, DataType dataType) {
ByteBuffer buf,
Shape shape,
DataType dataType,
TFE_Context eagerSessionHandle,
Device device) {
int dType = TfDataType.toTf(dataType);
long[] dims = shape.getShape();
long numBytes;
Expand All @@ -320,6 +328,9 @@ public static TFE_TensorHandle createTFETensorFromByteBuffer(
TF_Status status = TF_Status.newStatus();
TFE_TensorHandle handle = AbstractTFE_TensorHandle.newTensor(tensor, status);
status.throwExceptionIfNotOK();
if (device.isGpu()) {
return toDevice(handle, eagerSessionHandle, device);
}
return handle.retainReference();
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,14 @@ public void testNDArray() {
array.toStringArray()[1].getBytes(StandardCharsets.UTF_8), buf2.array());

array = manager.zeros(new Shape(2));
final NDArray b = array;
float[] expected = {2, 3};
array.set(expected);
Assert.assertEquals(array.toFloatArray(), expected);
if (array.getDevice().isGpu()) {
Assert.assertThrows(UnsupportedOperationException.class, () -> b.set(expected));
} else {
array.set(expected);
Assert.assertEquals(array.toFloatArray(), expected);
}

Assert.assertThrows(
IllegalArgumentException.class,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ public TrtNDManager newBaseManager() {
public TrtNDManager newBaseManager(Device device) {
// Only support GPU for now
device = device == null ? defaultDevice() : device;
if (!Device.Type.GPU.equals(device.getDeviceType())) {
if (!device.isGpu()) {
throw new IllegalArgumentException("TensorRT only support GPU");
}
return TrtNDManager.getSystemManager().newSubManager(device);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
*/
package ai.djl.tensorrt.engine;

import ai.djl.Device;
import ai.djl.engine.Engine;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
Expand All @@ -31,7 +30,7 @@ public void testNDArray() {
} catch (Exception ignore) {
throw new SkipException("Your os configuration doesn't support TensorRT.");
}
if (!Device.Type.GPU.equals(engine.defaultDevice().getDeviceType())) {
if (!engine.defaultDevice().isGpu()) {
throw new SkipException("TensorRT only support GPU.");
}
try (NDManager manager = TrtNDManager.getSystemManager().newSubManager()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
*/
package ai.djl.tensorrt.integration;

import ai.djl.Device;
import ai.djl.ModelException;
import ai.djl.engine.Engine;
import ai.djl.inference.Predictor;
Expand Down Expand Up @@ -50,7 +49,7 @@ public void testTrtOnnx() throws ModelException, IOException, TranslateException
} catch (Exception ignore) {
throw new SkipException("Your os configuration doesn't support TensorRT.");
}
if (!Device.Type.GPU.equals(engine.defaultDevice().getDeviceType())) {
if (!engine.defaultDevice().isGpu()) {
throw new SkipException("TensorRT only support GPU.");
}
Criteria<float[], float[]> criteria =
Expand All @@ -76,7 +75,7 @@ public void testTrtUff() throws ModelException, IOException, TranslateException
} catch (Exception ignore) {
throw new SkipException("Your os configuration doesn't support TensorRT.");
}
if (!Device.Type.GPU.equals(engine.defaultDevice().getDeviceType())) {
if (!engine.defaultDevice().isGpu()) {
throw new SkipException("TensorRT only support GPU.");
}
List<String> synset =
Expand Down Expand Up @@ -113,7 +112,7 @@ public void testSerializedEngine() throws ModelException, IOException, Translate
} catch (Exception ignore) {
throw new SkipException("Your os configuration doesn't support TensorRT.");
}
if (!Device.Type.GPU.equals(engine.defaultDevice().getDeviceType())) {
if (!engine.defaultDevice().isGpu()) {
throw new SkipException("TensorRT only support GPU.");
}
Criteria<float[], float[]> criteria =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,7 @@ public class Arguments {
threads = Integer.parseInt(cmd.getOptionValue("threads"));
Engine eng = Engine.getEngine(engine);
Device[] devices = eng.getDevices(maxGpus);
String deviceType = devices[0].getDeviceType();
if (Device.Type.GPU.equals(deviceType)) {
if (devices[0].isGpu()) {
// one thread per GPU
if (threads <= 0) {
threads = devices.length;
Expand Down

0 comments on commit ff659d8

Please sign in to comment.