From b0fe73a97ea4310238c0d9a4d3db618b39b93a36 Mon Sep 17 00:00:00 2001 From: Frank Liu Date: Tue, 8 Jun 2021 17:56:19 -0700 Subject: [PATCH] [pytorch] Update load model jupyter notebook (#1002) Change-Id: I1889aa93d2002e6ce02c740d2d1d3517bf586760 --- jupyter/load_pytorch_model.ipynb | 25 ++++++++++--------------- 1 file changed, 10 insertions(+), 15 deletions(-) diff --git a/jupyter/load_pytorch_model.ipynb b/jupyter/load_pytorch_model.ipynb index 6f16e2cd3cc..6a9e0c7073a 100644 --- a/jupyter/load_pytorch_model.ipynb +++ b/jupyter/load_pytorch_model.ipynb @@ -121,18 +121,15 @@ "metadata": {}, "outputs": [], "source": [ - "Pipeline pipeline = new Pipeline();\n", - "pipeline.add(new Resize(256))\n", - " .add(new CenterCrop(224, 224))\n", - " .add(new ToTensor())\n", - " .add(new Normalize(\n", - " new float[] {0.485f, 0.456f, 0.406f},\n", - " new float[] {0.229f, 0.224f, 0.225f}));\n", - "\n", "Translator translator = ImageClassificationTranslator.builder()\n", - " .setPipeline(pipeline)\n", - " .optApplySoftmax(true)\n", - " .build();" + " .addTransform(new Resize(256))\n", + " .addTransform(new CenterCrop(224, 224))\n", + " .addTransform(new ToTensor())\n", + " .addTransform(new Normalize(\n", + " new float[] {0.485f, 0.456f, 0.406f},\n", + " new float[] {0.229f, 0.224f, 0.225f}))\n", + " .optApplySoftmax(true)\n", + " .build();" ] }, { @@ -141,9 +138,7 @@ "source": [ "## Step 3: Load your model\n", "\n", - "Next, we will set the model zoo location to the `build/pytorch_models` directory we saved the model to. You can also create your own [`Repository`](https://javadoc.io/static/ai.djl/repository/0.11.0/index.html?ai/djl/repository/Repository.html) to avoid manually managing files.\n", - "\n", - "Next, we add some search criteria to find the resnet18 model and load it." + "Next, we add some search criteria to find the resnet18 model and load it. In this case, we need to tell `Criteria` where to locate the model by calling `.optModelPath()` API." ] }, { @@ -226,7 +221,7 @@ "mimetype": "text/x-java-source", "name": "Java", "pygments_lexer": "java", - "version": "11.0.11+9-LTS" + "version": "14.0.2+12" }, "pycharm": { "stem_cell": {