diff --git a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/SavedModelBundle.java b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/SavedModelBundle.java index 331e1ff5257..3eec14fbc28 100644 --- a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/SavedModelBundle.java +++ b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/SavedModelBundle.java @@ -14,17 +14,31 @@ package ai.djl.tensorflow.engine; import org.tensorflow.internal.c_api.TF_Graph; +import org.tensorflow.internal.c_api.TF_Operation; import org.tensorflow.internal.c_api.TF_Session; +import org.tensorflow.internal.c_api.global.tensorflow; +import org.tensorflow.proto.framework.CollectionDef; import org.tensorflow.proto.framework.MetaGraphDef; +import org.tensorflow.proto.framework.SignatureDef; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicBoolean; /** The wrapper class for native resources required for SavedModelBundle. */ public class SavedModelBundle implements AutoCloseable { + private static final String INIT_OP_SIGNATURE_KEY = "__saved_model_init_op"; + private static final String MAIN_OP_COLLECTION_KEY = "saved_model_main_op"; + private static final String LEGACY_INIT_OP_COLLECTION_KEY = "legacy_init_op"; + private static final String TABLE_INITIALIZERS_COLLECTION_KEY = "table_initializer"; + private TF_Graph graphHandle; private TF_Session sessionHandle; private MetaGraphDef metaGraphDef; + private TF_Operation[] targetOpHandles; private AtomicBoolean closed; public SavedModelBundle( @@ -33,6 +47,61 @@ public SavedModelBundle( this.sessionHandle = sessionHandle; this.metaGraphDef = metaGraphDef; closed = new AtomicBoolean(false); + + Map functions = new ConcurrentHashMap<>(); + metaGraphDef + .getSignatureDefMap() + .forEach( + (signatureName, signatureDef) -> { + if (!functions.containsKey(signatureName)) { + functions.put(signatureName, signatureDef); + } + }); + + List initOps = new ArrayList<>(); + TF_Operation initOp = findInitOp(functions, metaGraphDef.getCollectionDefMap()); + if (initOp != null) { + initOps.add(initOp); + } + + if (metaGraphDef.containsCollectionDef(TABLE_INITIALIZERS_COLLECTION_KEY)) { + metaGraphDef + .getCollectionDefMap() + .get(TABLE_INITIALIZERS_COLLECTION_KEY) + .getNodeList() + .getValueList() + .forEach( + node -> { + initOps.add(tensorflow.TF_GraphOperationByName(graphHandle, node)); + }); + } + targetOpHandles = initOps.toArray(new TF_Operation[0]); + } + + private TF_Operation findInitOp( + Map signatures, Map collections) { + SignatureDef initSig = signatures.get(INIT_OP_SIGNATURE_KEY); + if (initSig != null) { + String opName = initSig.getOutputsMap().get(INIT_OP_SIGNATURE_KEY).getName(); + return tensorflow.TF_GraphOperationByName(graphHandle, opName); + } + + CollectionDef initCollection; + if (collections.containsKey(MAIN_OP_COLLECTION_KEY)) { + initCollection = collections.get(MAIN_OP_COLLECTION_KEY); + } else { + initCollection = collections.get(LEGACY_INIT_OP_COLLECTION_KEY); + } + + if (initCollection != null) { + CollectionDef.NodeList nodes = initCollection.getNodeList(); + if (nodes.getValueCount() != 1) { + throw new IllegalArgumentException("Expected exactly one main op in saved model."); + } + String opName = nodes.getValue(0); + return tensorflow.TF_GraphOperationByName(graphHandle, opName); + } + return null; } /** @@ -62,6 +131,10 @@ public MetaGraphDef getMetaGraphDef() { return metaGraphDef; } + TF_Operation[] getTargetOpHandles() { + return targetOpHandles; + } + /** {@inheritDoc} */ @Override public void close() { diff --git a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfSymbolBlock.java b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfSymbolBlock.java index 4112a5e55a3..2453e58c25d 100644 --- a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfSymbolBlock.java +++ b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfSymbolBlock.java @@ -67,6 +67,7 @@ public TfSymbolBlock(SavedModelBundle bundle, String signatureDefKey) { this.bundle = bundle; graphHandle = bundle.getGraph(); sessionHandle = bundle.getSession(); + targetOpHandles = bundle.getTargetOpHandles(); MetaGraphDef metaGraphDef = bundle.getMetaGraphDef(); Map signatureDefMap = metaGraphDef.getSignatureDefMap(); if (signatureDefMap.containsKey(signatureDefKey)) { @@ -89,8 +90,6 @@ public TfSymbolBlock(SavedModelBundle bundle, String signatureDefKey) { } describeInput(); describeOutput(); - // we don't use target for now - targetOpHandles = new TF_Operation[0]; } /** {@inheritDoc} */ diff --git a/engines/tensorflow/tensorflow-engine/src/test/java/ai/djl/tensorflow/integration/ModelLoadingTest.java b/engines/tensorflow/tensorflow-engine/src/test/java/ai/djl/tensorflow/integration/ModelLoadingTest.java deleted file mode 100644 index e6d1e0e22bf..00000000000 --- a/engines/tensorflow/tensorflow-engine/src/test/java/ai/djl/tensorflow/integration/ModelLoadingTest.java +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance - * with the License. A copy of the License is located at - * - * http://aws.amazon.com/apache2.0/ - * - * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES - * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions - * and limitations under the License. - */ -package ai.djl.tensorflow.integration; - -import ai.djl.ModelException; -import ai.djl.inference.Predictor; -import ai.djl.ndarray.NDArray; -import ai.djl.ndarray.NDList; -import ai.djl.ndarray.NDManager; -import ai.djl.ndarray.types.DataType; -import ai.djl.repository.zoo.Criteria; -import ai.djl.repository.zoo.ZooModel; -import ai.djl.testing.TestRequirements; -import ai.djl.translate.TranslateException; - -import org.testng.Assert; -import org.testng.annotations.Test; - -import java.io.IOException; - -public class ModelLoadingTest { - - @Test - public void loadModelWithStringTensor() throws ModelException, IOException, TranslateException { - TestRequirements.nightly(); - String url = "https://resources.djl.ai/test-models/tensorflow/string_tensor.zip"; - Criteria criteria = - Criteria.builder().setTypes(NDList.class, NDList.class).optModelUrls(url).build(); - - try (ZooModel model = criteria.loadModel(); - Predictor predictor = model.newPredictor(); - NDManager manager = NDManager.newBaseManager()) { - NDArray array = manager.create("Test1"); - NDList output = predictor.predict(new NDList(array)); - Assert.assertEquals(output.size(), 2); - NDArray str = output.get("corrected_query"); - Assert.assertEquals(str.getDataType(), DataType.STRING); - } - } -} diff --git a/engines/tensorflow/tensorflow-model-zoo/src/test/java/ai/djl/tensorflow/integration/modality/cv/TfSsdTest.java b/engines/tensorflow/tensorflow-model-zoo/src/test/java/ai/djl/tensorflow/integration/modality/cv/TfSsdTest.java index 649ca193b5f..3e8937de7af 100644 --- a/engines/tensorflow/tensorflow-model-zoo/src/test/java/ai/djl/tensorflow/integration/modality/cv/TfSsdTest.java +++ b/engines/tensorflow/tensorflow-model-zoo/src/test/java/ai/djl/tensorflow/integration/modality/cv/TfSsdTest.java @@ -19,11 +19,16 @@ import ai.djl.modality.cv.Image; import ai.djl.modality.cv.ImageFactory; import ai.djl.modality.cv.output.DetectedObjects; +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.NDManager; +import ai.djl.ndarray.types.DataType; import ai.djl.ndarray.types.Shape; import ai.djl.repository.zoo.Criteria; import ai.djl.repository.zoo.ZooModel; import ai.djl.testing.TestRequirements; import ai.djl.training.util.ProgressBar; +import ai.djl.translate.NoopTranslator; import ai.djl.translate.TranslateException; import ai.djl.util.Pair; @@ -59,14 +64,21 @@ public void testTfSSD() throws IOException, ModelException, TranslateException { Predictor predictor = model.newPredictor()) { Assert.assertEquals(model.describeInput().get(0).getValue(), new Shape(-1, -1, -1, 3)); for (Pair pair : model.describeOutput()) { - if (pair.getKey().contains("label")) { - Assert.assertEquals(pair.getValue(), new Shape(-1, 1)); - } else if (pair.getKey().contains("box")) { - Assert.assertEquals(pair.getValue(), new Shape(-1, 4)); - } else if (pair.getKey().contains("score")) { - Assert.assertEquals(pair.getValue(), new Shape(-1, 1)); - } else { - throw new IllegalStateException("Unexpected output name:" + pair.getKey()); + switch (pair.getKey()) { + case "box": + case "detection_boxes": + Assert.assertEquals(pair.getValue(), new Shape(-1, 4)); + break; + case "label": + case "score": + case "detection_class_entities": + case "detection_class_labels": + case "detection_class_names": + case "detection_scores": + Assert.assertEquals(pair.getValue(), new Shape(-1, 1)); + break; + default: + throw new IllegalStateException("Unexpected output name: " + pair.getKey()); } } @@ -82,6 +94,32 @@ public void testTfSSD() throws IOException, ModelException, TranslateException { } } + @Test + public void testStringInputOutput() throws IOException, ModelException, TranslateException { + TestRequirements.notArm(); + + Criteria criteria = + Criteria.builder() + .optApplication(Application.CV.OBJECT_DETECTION) + .setTypes(NDList.class, NDList.class) + .optArtifactId("ssd") + .optFilter("backbone", "mobilenet_v2") + .optEngine("TensorFlow") + .optProgress(new ProgressBar()) + .optTranslator(new NoopTranslator()) + .build(); + + try (ZooModel model = criteria.loadModel(); + Predictor predictor = model.newPredictor(); + NDManager manager = NDManager.newBaseManager()) { + NDArray array = manager.zeros(new Shape(1, 224, 224, 3)); + NDList output = predictor.predict(new NDList(array)); + Assert.assertEquals(output.size(), 5); + NDArray entities = output.get("detection_class_entities"); + Assert.assertEquals(entities.getDataType(), DataType.STRING); + } + } + private static void saveBoundingBoxImage(Image img, DetectedObjects detection) throws IOException { Path outputDir = Paths.get("build/output");