Skip to content

Commit

Permalink
Add initializer to targetOpHandles
Browse files Browse the repository at this point in the history
  • Loading branch information
frankfliu committed Oct 2, 2022
1 parent 394ba89 commit 41a8811
Show file tree
Hide file tree
Showing 4 changed files with 125 additions and 60 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,31 @@
package ai.djl.tensorflow.engine;

import org.tensorflow.internal.c_api.TF_Graph;
import org.tensorflow.internal.c_api.TF_Operation;
import org.tensorflow.internal.c_api.TF_Session;
import org.tensorflow.internal.c_api.global.tensorflow;
import org.tensorflow.proto.framework.CollectionDef;
import org.tensorflow.proto.framework.MetaGraphDef;
import org.tensorflow.proto.framework.SignatureDef;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;

/** The wrapper class for native resources required for SavedModelBundle. */
public class SavedModelBundle implements AutoCloseable {

private static final String INIT_OP_SIGNATURE_KEY = "__saved_model_init_op";
private static final String MAIN_OP_COLLECTION_KEY = "saved_model_main_op";
private static final String LEGACY_INIT_OP_COLLECTION_KEY = "legacy_init_op";
private static final String TABLE_INITIALIZERS_COLLECTION_KEY = "table_initializer";

private TF_Graph graphHandle;
private TF_Session sessionHandle;
private MetaGraphDef metaGraphDef;
private TF_Operation[] targetOpHandles;
private AtomicBoolean closed;

public SavedModelBundle(
Expand All @@ -33,6 +47,61 @@ public SavedModelBundle(
this.sessionHandle = sessionHandle;
this.metaGraphDef = metaGraphDef;
closed = new AtomicBoolean(false);

Map<String, SignatureDef> functions = new ConcurrentHashMap<>();
metaGraphDef
.getSignatureDefMap()
.forEach(
(signatureName, signatureDef) -> {
if (!functions.containsKey(signatureName)) {
functions.put(signatureName, signatureDef);
}
});

List<TF_Operation> initOps = new ArrayList<>();
TF_Operation initOp = findInitOp(functions, metaGraphDef.getCollectionDefMap());
if (initOp != null) {
initOps.add(initOp);
}

if (metaGraphDef.containsCollectionDef(TABLE_INITIALIZERS_COLLECTION_KEY)) {
metaGraphDef
.getCollectionDefMap()
.get(TABLE_INITIALIZERS_COLLECTION_KEY)
.getNodeList()
.getValueList()
.forEach(
node -> {
initOps.add(tensorflow.TF_GraphOperationByName(graphHandle, node));
});
}
targetOpHandles = initOps.toArray(new TF_Operation[0]);
}

private TF_Operation findInitOp(
Map<String, SignatureDef> signatures, Map<String, CollectionDef> collections) {
SignatureDef initSig = signatures.get(INIT_OP_SIGNATURE_KEY);
if (initSig != null) {
String opName = initSig.getOutputsMap().get(INIT_OP_SIGNATURE_KEY).getName();
return tensorflow.TF_GraphOperationByName(graphHandle, opName);
}

CollectionDef initCollection;
if (collections.containsKey(MAIN_OP_COLLECTION_KEY)) {
initCollection = collections.get(MAIN_OP_COLLECTION_KEY);
} else {
initCollection = collections.get(LEGACY_INIT_OP_COLLECTION_KEY);
}

if (initCollection != null) {
CollectionDef.NodeList nodes = initCollection.getNodeList();
if (nodes.getValueCount() != 1) {
throw new IllegalArgumentException("Expected exactly one main op in saved model.");
}
String opName = nodes.getValue(0);
return tensorflow.TF_GraphOperationByName(graphHandle, opName);
}
return null;
}

/**
Expand Down Expand Up @@ -62,6 +131,10 @@ public MetaGraphDef getMetaGraphDef() {
return metaGraphDef;
}

TF_Operation[] getTargetOpHandles() {
return targetOpHandles;
}

/** {@inheritDoc} */
@Override
public void close() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ public TfSymbolBlock(SavedModelBundle bundle, String signatureDefKey) {
this.bundle = bundle;
graphHandle = bundle.getGraph();
sessionHandle = bundle.getSession();
targetOpHandles = bundle.getTargetOpHandles();
MetaGraphDef metaGraphDef = bundle.getMetaGraphDef();
Map<String, SignatureDef> signatureDefMap = metaGraphDef.getSignatureDefMap();
if (signatureDefMap.containsKey(signatureDefKey)) {
Expand All @@ -89,8 +90,6 @@ public TfSymbolBlock(SavedModelBundle bundle, String signatureDefKey) {
}
describeInput();
describeOutput();
// we don't use target for now
targetOpHandles = new TF_Operation[0];
}

/** {@inheritDoc} */
Expand Down Expand Up @@ -215,6 +214,11 @@ public final PairList<String, Shape> describeOutput() {
for (String key : keys) {
TensorInfo tensorInfo = outputsMap.get(key);
TensorShapeProto shapeProto = tensorInfo.getTensorShape();
// does not support string tensors
// if (tensorInfo.getDtype() ==
// org.tensorflow.proto.framework.DataType.DT_STRING) {
// continue;
// }
outputDescriptions.add(
key,
new Shape(
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,16 @@
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.testing.TestRequirements;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.NoopTranslator;
import ai.djl.translate.TranslateException;
import ai.djl.util.Pair;

Expand Down Expand Up @@ -59,14 +64,21 @@ public void testTfSSD() throws IOException, ModelException, TranslateException {
Predictor<Image, DetectedObjects> predictor = model.newPredictor()) {
Assert.assertEquals(model.describeInput().get(0).getValue(), new Shape(-1, -1, -1, 3));
for (Pair<String, Shape> pair : model.describeOutput()) {
if (pair.getKey().contains("label")) {
Assert.assertEquals(pair.getValue(), new Shape(-1, 1));
} else if (pair.getKey().contains("box")) {
Assert.assertEquals(pair.getValue(), new Shape(-1, 4));
} else if (pair.getKey().contains("score")) {
Assert.assertEquals(pair.getValue(), new Shape(-1, 1));
} else {
throw new IllegalStateException("Unexpected output name:" + pair.getKey());
switch (pair.getKey()) {
case "box":
case "detection_boxes":
Assert.assertEquals(pair.getValue(), new Shape(-1, 4));
break;
case "label":
case "score":
case "detection_class_entities":
case "detection_class_labels":
case "detection_class_names":
case "detection_scores":
Assert.assertEquals(pair.getValue(), new Shape(-1, 1));
break;
default:
throw new IllegalStateException("Unexpected output name: " + pair.getKey());
}
}

Expand All @@ -82,6 +94,32 @@ public void testTfSSD() throws IOException, ModelException, TranslateException {
}
}

@Test
public void testStringInputOutput() throws IOException, ModelException, TranslateException {
TestRequirements.notArm();

Criteria<NDList, NDList> criteria =
Criteria.builder()
.optApplication(Application.CV.OBJECT_DETECTION)
.setTypes(NDList.class, NDList.class)
.optArtifactId("ssd")
.optFilter("backbone", "mobilenet_v2")
.optEngine("TensorFlow")
.optProgress(new ProgressBar())
.optTranslator(new NoopTranslator())
.build();

try (ZooModel<NDList, NDList> model = criteria.loadModel();
Predictor<NDList, NDList> predictor = model.newPredictor();
NDManager manager = NDManager.newBaseManager()) {
NDArray array = manager.zeros(new Shape(1, 224, 224, 3));
NDList output = predictor.predict(new NDList(array));
Assert.assertEquals(output.size(), 5);
NDArray entities = output.get("detection_class_entities");
Assert.assertEquals(entities.getDataType(), DataType.STRING);
}
}

private static void saveBoundingBoxImage(Image img, DetectedObjects detection)
throws IOException {
Path outputDir = Paths.get("build/output");
Expand Down

0 comments on commit 41a8811

Please sign in to comment.