Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[onnx] Adds fp16 and bfp16 support for OnnxRuntime #3281

Merged
merged 1 commit into from
Jun 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,30 +36,33 @@ public static OnnxTensor toTensor(
if (shape.size() == 0) {
throw new UnsupportedOperationException("OnnxRuntime doesn't support 0 length tensor.");
}
if (data instanceof ByteBuffer) {
data = dataType.asDataType((ByteBuffer) data);
}
long[] sh = shape.getShape();
try {
switch (dataType) {
case FLOAT32:
return OnnxTensor.createTensor(env, (FloatBuffer) data, sh);
return OnnxTensor.createTensor(env, asFloatBuffer(data), sh);
case FLOAT64:
return OnnxTensor.createTensor(env, (DoubleBuffer) data, sh);
return OnnxTensor.createTensor(env, asDoubleBuffer(data), sh);
case FLOAT16:
return OnnxTensor.createTensor(
env, (ByteBuffer) data, sh, OnnxJavaType.FLOAT16);
case BFLOAT16:
return OnnxTensor.createTensor(
env, (ByteBuffer) data, sh, OnnxJavaType.BFLOAT16);
case INT32:
return OnnxTensor.createTensor(env, (IntBuffer) data, sh);
return OnnxTensor.createTensor(env, asIntBuffer(data), sh);
case INT64:
return OnnxTensor.createTensor(env, (LongBuffer) data, sh);
return OnnxTensor.createTensor(env, asLongBuffer(data), sh);
case INT8:
return OnnxTensor.createTensor(env, (ByteBuffer) data, sh, OnnxJavaType.INT8);
case UINT8:
return OnnxTensor.createTensor(env, (ByteBuffer) data, sh, OnnxJavaType.UINT8);
case BOOLEAN:
return OnnxTensor.createTensor(env, (ByteBuffer) data, sh, OnnxJavaType.BOOL);
case STRING:
throw new UnsupportedOperationException(
"Use toTensor(OrtEnvironment env, String[] inputs, Shape shape)"
+ " instead.");
case BOOLEAN:
case FLOAT16:
default:
throw new UnsupportedOperationException("Data type not supported: " + dataType);
}
Expand All @@ -81,6 +84,10 @@ public static DataType toDataType(OnnxJavaType javaType) {
switch (javaType) {
case FLOAT:
return DataType.FLOAT32;
case FLOAT16:
return DataType.FLOAT16;
case BFLOAT16:
return DataType.BFLOAT16;
case DOUBLE:
return DataType.FLOAT64;
case INT8:
Expand All @@ -101,4 +108,32 @@ public static DataType toDataType(OnnxJavaType javaType) {
throw new UnsupportedOperationException("type is not supported: " + javaType);
}
}

private static FloatBuffer asFloatBuffer(Buffer data) {
if (data instanceof ByteBuffer) {
return ((ByteBuffer) data).asFloatBuffer();
}
return (FloatBuffer) data;
}

private static DoubleBuffer asDoubleBuffer(Buffer data) {
if (data instanceof ByteBuffer) {
return ((ByteBuffer) data).asDoubleBuffer();
}
return (DoubleBuffer) data;
}

private static IntBuffer asIntBuffer(Buffer data) {
if (data instanceof ByteBuffer) {
return ((ByteBuffer) data).asIntBuffer();
}
return (IntBuffer) data;
}

private static LongBuffer asLongBuffer(Buffer data) {
if (data instanceof ByteBuffer) {
return ((ByteBuffer) data).asLongBuffer();
}
return (LongBuffer) data;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.onnxruntime.zoo.tabular.softmax_regression.IrisFlower;
import ai.djl.repository.zoo.Criteria;
Expand Down Expand Up @@ -143,6 +144,42 @@ public void testNDArray() throws OrtException {

float[][] value = (float[][]) ((OrtNDArray) ones).getTensor().getValue();
Assert.assertEquals(value[0], new float[] {1, 1});

array = manager.create(new Shape(1), DataType.BOOLEAN);
Assert.assertEquals(array.getDataType(), DataType.BOOLEAN);

array = manager.create(new Shape(1), DataType.FLOAT16);
Assert.assertEquals(array.getDataType(), DataType.FLOAT16);

array = manager.create(new Shape(1), DataType.BFLOAT16);
Assert.assertEquals(array.getDataType(), DataType.BFLOAT16);

array = manager.create(new double[] {0});
Assert.assertEquals(array.getDataType(), DataType.FLOAT64);

array = manager.create(new Shape(1), DataType.FLOAT64);
Assert.assertEquals(array.getDataType(), DataType.FLOAT64);

array = manager.create(new Shape(1), DataType.INT8);
Assert.assertEquals(array.getDataType(), DataType.INT8);

array = manager.create(new Shape(1), DataType.UINT8);
Assert.assertEquals(array.getDataType(), DataType.UINT8);

array = manager.create(new int[] {0});
Assert.assertEquals(array.getDataType(), DataType.INT32);

array = manager.create(new Shape(1), DataType.INT32);
Assert.assertEquals(array.getDataType(), DataType.INT32);

array = manager.create(new long[] {0L});
Assert.assertEquals(array.getDataType(), DataType.INT64);

array = manager.create(new Shape(1), DataType.INT64);
Assert.assertEquals(array.getDataType(), DataType.INT64);

Assert.assertThrows(() -> manager.create(new Shape(0), DataType.FLOAT32));
Assert.assertThrows(() -> manager.create(new Shape(1), DataType.UINT32));
}
}

Expand Down
Loading