diff --git a/api/src/main/java/ai/djl/ndarray/NDArray.java b/api/src/main/java/ai/djl/ndarray/NDArray.java index 758c052537f..f7a307be806 100644 --- a/api/src/main/java/ai/djl/ndarray/NDArray.java +++ b/api/src/main/java/ai/djl/ndarray/NDArray.java @@ -371,6 +371,15 @@ default boolean[] toBooleanArray() { return ret; } + /** + * Converts this {@code NDArray} to a String array. + * + *
This method is only applicable to the String typed NDArray and not for printing purpose
+ *
+ * @return Array of Strings
+ */
+ String[] toStringArray();
+
/**
* Converts this {@code NDArray} to a Number array based on its {@link DataType}.
*
diff --git a/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java b/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java
index 8afb9ca3ec0..56ffa83c8d1 100644
--- a/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java
+++ b/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java
@@ -135,6 +135,12 @@ default NDArray stopGradient() {
throw new UnsupportedOperationException(UNSUPPORTED_MSG);
}
+ /** {@inheritDoc} */
+ @Override
+ default String[] toStringArray() {
+ throw new UnsupportedOperationException(UNSUPPORTED_MSG);
+ }
+
/** {@inheritDoc} */
@Override
default ByteBuffer toByteBuffer() {
diff --git a/dlr/dlr-engine/src/main/java/ai/djl/dlr/engine/DlrEngine.java b/dlr/dlr-engine/src/main/java/ai/djl/dlr/engine/DlrEngine.java
index e28426e048f..e9620edbb00 100644
--- a/dlr/dlr-engine/src/main/java/ai/djl/dlr/engine/DlrEngine.java
+++ b/dlr/dlr-engine/src/main/java/ai/djl/dlr/engine/DlrEngine.java
@@ -47,6 +47,11 @@ static Engine newInstance() {
}
private Engine getAlternativeEngine() {
+ boolean disableAlternative =
+ Boolean.parseBoolean(System.getProperty("ai.djl.dlr.disable_alternative", "false"));
+ if (disableAlternative) {
+ return null;
+ }
if (alternativeEngine == null) {
Engine engine = Engine.getInstance();
if (engine.getRank() < getRank()) {
diff --git a/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java b/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java
index e8ecde776a9..c1277dca17a 100644
--- a/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java
+++ b/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java
@@ -268,11 +268,18 @@ public boolean hasGradient() {
return hasGradient;
}
+ /** {@inheritDoc} */
@Override
public NDArray stopGradient() {
return manager.invoke("stop_gradient", this, null);
}
+ /** {@inheritDoc} */
+ @Override
+ public String[] toStringArray() {
+ throw new UnsupportedOperationException("String NDArray is not supported!");
+ }
+
/** {@inheritDoc} */
@Override
public ByteBuffer toByteBuffer() {
diff --git a/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngine.java b/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngine.java
index 5396d0e4e2d..1dde554e258 100644
--- a/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngine.java
+++ b/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngine.java
@@ -57,6 +57,12 @@ public int getRank() {
}
private Engine getAlternativeEngine() {
+ boolean disableAlternative =
+ Boolean.parseBoolean(
+ System.getProperty("ai.djl.onnx.disable_alternative", "false"));
+ if (disableAlternative) {
+ return null;
+ }
if (alternativeEngine == null) {
Engine engine = Engine.getInstance();
if (engine.getRank() < getRank()) {
diff --git a/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtNDArray.java b/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtNDArray.java
index 42556a788ca..806fb96714e 100644
--- a/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtNDArray.java
+++ b/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtNDArray.java
@@ -13,12 +13,16 @@
package ai.djl.onnxruntime.engine;
import ai.djl.Device;
+import ai.djl.engine.EngineException;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrayAdapter;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.onnxruntime.OnnxTensor;
+import ai.onnxruntime.OrtException;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
import java.util.Arrays;
import java.util.UUID;
@@ -117,20 +121,35 @@ public void detach() {
manager = OrtNDManager.getSystemManager();
}
+ /** {@inheritDoc} */
+ @Override
+ public String[] toStringArray() {
+ try {
+ return (String[]) tensor.getValue();
+ } catch (OrtException e) {
+ throw new EngineException(e);
+ }
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public ByteBuffer toByteBuffer() {
+ return tensor.getByteBuffer().order(ByteOrder.nativeOrder());
+ }
+
/** {@inheritDoc} */
@Override
public String toString() {
if (isClosed) {
return "This array is already closed";
}
- return "ND: "
- + getShape()
- + ' '
- + getDevice()
- + ' '
- + getDataType()
- + '\n'
- + Arrays.toString(toArray());
+ String arrStr;
+ if (getDataType() == DataType.STRING) {
+ arrStr = Arrays.toString(toStringArray());
+ } else {
+ arrStr = Arrays.toString(toArray());
+ }
+ return "ND: " + getShape() + ' ' + getDevice() + ' ' + getDataType() + '\n' + arrStr;
}
/** {@inheritDoc} */
diff --git a/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtUtils.java b/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtUtils.java
index a1a0408f99c..90d8e200e98 100644
--- a/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtUtils.java
+++ b/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtUtils.java
@@ -98,6 +98,8 @@ public static DataType toDataType(OnnxJavaType javaType) {
return DataType.BOOLEAN;
case UNKNOWN:
return DataType.UNKNOWN;
+ case STRING:
+ return DataType.STRING;
default:
throw new UnsupportedOperationException("type is not supported: " + javaType);
}
diff --git a/onnxruntime/onnxruntime-engine/src/test/java/ai/djl/onnxruntime/engine/OrtTest.java b/onnxruntime/onnxruntime-engine/src/test/java/ai/djl/onnxruntime/engine/OrtTest.java
index dbcbe7013cb..4d8df458605 100644
--- a/onnxruntime/onnxruntime-engine/src/test/java/ai/djl/onnxruntime/engine/OrtTest.java
+++ b/onnxruntime/onnxruntime-engine/src/test/java/ai/djl/onnxruntime/engine/OrtTest.java
@@ -73,6 +73,7 @@ public void testOrt() throws TranslateException, ModelException, IOException {
public void testStringTensor()
throws MalformedModelException, ModelNotFoundException, IOException,
TranslateException {
+ System.setProperty("ai.djl.onnx.disable_alternative", "true");
Criteria Currently, only TorchScript file are supported
+ *
+ * @param modelStream the stream of the model file
+ * @throws IOException model loading error
+ */
+ public void load(InputStream modelStream) throws IOException {
+ block = JniUtils.loadModule((PtNDManager) manager, modelStream, manager.getDevice(), false);
+ }
+
private Path findModelFile(String prefix) {
if (Files.isRegularFile(modelDir)) {
Path file = modelDir;
diff --git a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java
index dedf1988561..1a9e6613797 100644
--- a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java
+++ b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java
@@ -211,6 +211,12 @@ public ByteBuffer toByteBuffer() {
return JniUtils.getByteBuffer(this);
}
+ /** {@inheritDoc} */
+ @Override
+ public String[] toStringArray() {
+ throw new UnsupportedOperationException("String NDArray is not supported!");
+ }
+
/** {@inheritDoc} */
@Override
public void set(Buffer data) {
diff --git a/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/PtModelTest.java b/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/PtModelTest.java
new file mode 100644
index 00000000000..cda9b86d8dd
--- /dev/null
+++ b/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/PtModelTest.java
@@ -0,0 +1,43 @@
+/*
+ * Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+package ai.djl.pytorch.integration;
+
+import ai.djl.Model;
+import ai.djl.inference.Predictor;
+import ai.djl.ndarray.NDArray;
+import ai.djl.ndarray.NDList;
+import ai.djl.ndarray.types.Shape;
+import ai.djl.pytorch.engine.PtModel;
+import ai.djl.translate.NoopTranslator;
+import ai.djl.translate.TranslateException;
+import java.io.IOException;
+import java.net.URL;
+import org.testng.Assert;
+import org.testng.annotations.Test;
+
+public class PtModelTest {
+
+ @Test
+ public void testLoadFromStream() throws IOException, TranslateException {
+ URL url =
+ new URL("https://djl-ai.s3.amazonaws.com/resources/test-models/traced_resnet18.pt");
+ try (PtModel model = (PtModel) Model.newInstance("test model")) {
+ model.load(url.openStream());
+ try (Predictor