From b959e2ce6ed99b58494560781d0e17a967914047 Mon Sep 17 00:00:00 2001 From: Frank Liu Date: Wed, 7 Aug 2024 21:24:35 -0700 Subject: [PATCH] [example] Enable PyTorch for some training example --- .../djl/examples/training/TrainCaptchaTest.java | 17 ++++++++++++++++- .../training/TrainMnistWithLSTMTest.java | 10 +--------- .../djl/examples/training/TrainResNetTest.java | 1 - .../training/TrainSentimentAnalysisTest.java | 1 + .../examples/training/TrainTimeSeriesTest.java | 11 +---------- 5 files changed, 19 insertions(+), 21 deletions(-) diff --git a/examples/src/test/java/ai/djl/examples/training/TrainCaptchaTest.java b/examples/src/test/java/ai/djl/examples/training/TrainCaptchaTest.java index a2ccbbf4357..5c01e1cc8c4 100644 --- a/examples/src/test/java/ai/djl/examples/training/TrainCaptchaTest.java +++ b/examples/src/test/java/ai/djl/examples/training/TrainCaptchaTest.java @@ -27,7 +27,22 @@ public class TrainCaptchaTest { public void testTrainCaptcha() throws IOException, TranslateException { TestRequirements.linux(); - // TODO: PyTorch -- cuDNN error: CUDNN_STATUS_VERSION_MISMATCH + // TODO: PyTorch + /* + ai.djl.engine.EngineException: index 11 is out of bounds for dimension 1 with size 11 + at app//ai.djl.pytorch.jni.PyTorchLibrary.torchGather(Native Method) + at app//ai.djl.pytorch.jni.JniUtils.pick(JniUtils.java:581) + at app//ai.djl.pytorch.jni.JniUtils.indexAdv(JniUtils.java:417) + at app//ai.djl.pytorch.engine.PtNDArrayIndexer.get(PtNDArrayIndexer.java:74) + at app//ai.djl.ndarray.NDArray.get(NDArray.java:614) + at app//ai.djl.ndarray.NDArray.get(NDArray.java:603) + at app//ai.djl.training.loss.SoftmaxCrossEntropyLoss.evaluate(SoftmaxCrossEntropyLoss.java:86) + at app//ai.djl.training.loss.IndexLoss.evaluate(IndexLoss.java:55) + at app//ai.djl.training.loss.AbstractCompositeLoss.evaluate(AbstractCompositeLoss.java:68) + at app//ai.djl.training.EasyTrain.trainSplit(EasyTrain.java:124) + at app//ai.djl.training.EasyTrain.trainBatch(EasyTrain.java:110) + at app//ai.djl.training.EasyTrain.fit(EasyTrain.java:58) + */ String[] args = new String[] {"-g", "1", "-e", "1", "-m", "2", "--engine", "MXNet"}; TrainingResult result = TrainCaptcha.runExample(args); Assert.assertNotNull(result); diff --git a/examples/src/test/java/ai/djl/examples/training/TrainMnistWithLSTMTest.java b/examples/src/test/java/ai/djl/examples/training/TrainMnistWithLSTMTest.java index 7693e195434..714c25a9e00 100644 --- a/examples/src/test/java/ai/djl/examples/training/TrainMnistWithLSTMTest.java +++ b/examples/src/test/java/ai/djl/examples/training/TrainMnistWithLSTMTest.java @@ -12,7 +12,6 @@ */ package ai.djl.examples.training; -import ai.djl.engine.Engine; import ai.djl.training.TrainingResult; import ai.djl.translate.TranslateException; @@ -25,14 +24,7 @@ public class TrainMnistWithLSTMTest { @Test public void testTrainMnistWithLSTM() throws IOException, TranslateException { - String[] args; - Engine engine = Engine.getEngine("PyTorch"); - if (engine.getGpuCount() > 0) { - // TODO: PyTorch -- cuDNN error: CUDNN_STATUS_VERSION_MISMATCH - args = new String[] {"-g", "1", "-e", "1", "-m", "2", "--engine", "MXNet"}; - } else { - args = new String[] {"-g", "1", "-e", "1", "-m", "2"}; - } + String[] args = {"-g", "1", "-e", "1", "-m", "2"}; TrainingResult result = TrainMnistWithLSTM.runExample(args); Assert.assertNotNull(result); } diff --git a/examples/src/test/java/ai/djl/examples/training/TrainResNetTest.java b/examples/src/test/java/ai/djl/examples/training/TrainResNetTest.java index cde2e175251..9b249df7c6a 100644 --- a/examples/src/test/java/ai/djl/examples/training/TrainResNetTest.java +++ b/examples/src/test/java/ai/djl/examples/training/TrainResNetTest.java @@ -36,7 +36,6 @@ public void testTrainResNet() throws ModelException, IOException, TranslateExcep // Limit max 4 gpu for cifar10 training to make it converge faster. // and only train 10 batch for unit test. - // only MXNet support symbolic model String[] args = {"-e", "2", "-g", "4", "-m", "10", "-p"}; TrainingResult result = TrainResnetWithCifar10.runExample(args); diff --git a/examples/src/test/java/ai/djl/examples/training/TrainSentimentAnalysisTest.java b/examples/src/test/java/ai/djl/examples/training/TrainSentimentAnalysisTest.java index c5ef28a5581..7a0bb9fbdf7 100644 --- a/examples/src/test/java/ai/djl/examples/training/TrainSentimentAnalysisTest.java +++ b/examples/src/test/java/ai/djl/examples/training/TrainSentimentAnalysisTest.java @@ -29,6 +29,7 @@ public void testTrainSentimentAnalysis() TestRequirements.nightly(); TestRequirements.gpu("MXNet", 1); + // TODO: Add a PyTorch Glove model to model zoo String[] args = {"-e", "1", "-g", "1", "--engine", "MXNet"}; TrainSentimentAnalysis.runExample(args); } diff --git a/examples/src/test/java/ai/djl/examples/training/TrainTimeSeriesTest.java b/examples/src/test/java/ai/djl/examples/training/TrainTimeSeriesTest.java index 19cfdc3c143..f68efe39504 100644 --- a/examples/src/test/java/ai/djl/examples/training/TrainTimeSeriesTest.java +++ b/examples/src/test/java/ai/djl/examples/training/TrainTimeSeriesTest.java @@ -13,7 +13,6 @@ package ai.djl.examples.training; -import ai.djl.engine.Engine; import ai.djl.training.TrainingResult; import ai.djl.translate.TranslateException; @@ -26,15 +25,7 @@ public class TrainTimeSeriesTest { @Test public void testTrainTimeSeries() throws TranslateException, IOException { - String[] args; - Engine engine = Engine.getEngine("PyTorch"); - if (engine.getGpuCount() > 0) { - // TODO: PyTorch -- cuDNN error: CUDNN_STATUS_VERSION_MISMATCH - args = new String[] {"-g", "1", "-e", "5", "-b", "32", "--engine", "MXNet"}; - } else { - args = new String[] {"-g", "1", "-e", "5", "-b", "32"}; - } - + String[] args = {"-g", "1", "-e", "5", "-b", "32"}; TrainingResult result = TrainTimeSeries.runExample(args); Assert.assertNotNull(result); float loss = result.getTrainLoss();