diff --git a/examples/src/main/java/ai/djl/examples/inference/timeseries/AirPassengersDeepAR.java b/examples/src/main/java/ai/djl/examples/inference/timeseries/AirPassengersDeepAR.java index c178e1a94e3..5ee5fd1a23f 100644 --- a/examples/src/main/java/ai/djl/examples/inference/timeseries/AirPassengersDeepAR.java +++ b/examples/src/main/java/ai/djl/examples/inference/timeseries/AirPassengersDeepAR.java @@ -77,14 +77,14 @@ public static float[] predict() throws IOException, TranslateException, ModelExc try (ZooModel model = criteria.loadModel(); Predictor predictor = model.newPredictor(); NDManager manager = NDManager.newBaseManager(null, "MXNet")) { - TimeSeriesData data = getTimeSeriesData(manager, new URL(url)); + TimeSeriesData input = getTimeSeriesData(manager, new URL(url)); // save data for plotting - NDArray target = data.get(FieldName.TARGET); + NDArray target = input.get(FieldName.TARGET); target.setName("target"); saveNDArray(target); - Forecast forecast = predictor.predict(data); + Forecast forecast = predictor.predict(input); // save data for plotting. Please see the corresponding python script from // https://gist.github.com/Carkham/a5162c9298bc51fec648a458a3437008 diff --git a/examples/src/main/java/ai/djl/examples/inference/timeseries/M5ForecastingDeepAR.java b/examples/src/main/java/ai/djl/examples/inference/timeseries/M5ForecastingDeepAR.java index d43c16e79ce..46346d08a50 100644 --- a/examples/src/main/java/ai/djl/examples/inference/timeseries/M5ForecastingDeepAR.java +++ b/examples/src/main/java/ai/djl/examples/inference/timeseries/M5ForecastingDeepAR.java @@ -13,7 +13,9 @@ package ai.djl.examples.inference.timeseries; +import ai.djl.Application; import ai.djl.ModelException; +import ai.djl.basicdataset.BasicDatasets; import ai.djl.basicdataset.tabular.utils.DynamicBuffer; import ai.djl.basicdataset.tabular.utils.Feature; import ai.djl.inference.Predictor; @@ -22,9 +24,13 @@ import ai.djl.ndarray.NDList; import ai.djl.ndarray.NDManager; import ai.djl.ndarray.types.Shape; +import ai.djl.repository.Artifact; +import ai.djl.repository.MRL; +import ai.djl.repository.Repository; import ai.djl.repository.zoo.Criteria; import ai.djl.repository.zoo.ZooModel; import ai.djl.timeseries.Forecast; +import ai.djl.timeseries.SampleForecast; import ai.djl.timeseries.TimeSeriesData; import ai.djl.timeseries.dataset.FieldName; import ai.djl.training.util.ProgressBar; @@ -41,10 +47,12 @@ import java.io.BufferedInputStream; import java.io.IOException; import java.io.InputStreamReader; +import java.io.OutputStream; import java.io.Reader; import java.net.URL; import java.nio.FloatBuffer; import java.nio.charset.StandardCharsets; +import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; import java.time.LocalDateTime; @@ -70,17 +78,22 @@ public static void main(String[] args) throws IOException, TranslateException, M public static Map predict() throws IOException, TranslateException, ModelException { - // M5 Forecasting - Accuracy dataset requires manual download - String pathToData = "/Desktop/m5example/m5-forecasting-accuracy"; - Path m5ForecastFile = Paths.get(System.getProperty("user.home") + pathToData); NDManager manager = NDManager.newBaseManager(null, "MXNet"); - M5Dataset dataset = M5Dataset.builder().setManager(manager).setRoot(m5ForecastFile).build(); + // To use local dataset, users can load data as follows + // Repository repository = Repository.newInstance("local_dataset", + // Paths.get("rootPath/m5-forecasting-accuracy")); + // Then add the setting `.optRepository(repository)` to the builder below + M5Dataset dataset = M5Dataset.builder().setManager(manager).build(); + + // The modelUrl can be replaced by local model path. E.g., + // String modelUrl = "rootPath/deepar.zip"; + String modelUrl = "djl://ai.djl.mxnet/deepar/0.0.1/m5forecast"; int predictionLength = 4; Criteria criteria = Criteria.builder() .setTypes(TimeSeriesData.class, Forecast.class) - .optModelUrls("djl://ai.djl.mxnet/deepar/0.0.1/m5forecast") + .optModelUrls(modelUrl) .optEngine("MXNet") .optTranslatorFactory(new DeferredTranslatorFactory()) .optArgument("prediction_length", predictionLength) @@ -104,15 +117,19 @@ public static Map predict() input.setStartTime(LocalDateTime.parse("2011-01-29T00:00")); input.setField(FieldName.TARGET, pastTarget); Forecast forecast = predictor.predict(input); - // Here we focus on the metric Weighted Root Mean Squared Scaled Error (RMSSE) same - // as + // We focus on the metric Weighted Root Mean Squared Scaled Error (RMSSE) same as // https://www.kaggle.com/competitions/m5-forecasting-accuracy/overview/evaluation // The error is not small compared to the data values (sale amount). This is because // The model is trained on a sparse data with many zeros. This will be improved by - // aggregating/coarse graining the data which will appear in the next PR. - // TODO: coarse graining the data. + // aggregating/coarse graining the data. See https://github.com/Carkham/m5_blog evaluator.aggregateMetrics(evaluator.getMetricsPerTs(gt, pastTarget, forecast)); progress.increment(1); + + // save data for plotting. Please see the corresponding python script from + // https://gist.github.com/Carkham/a5162c9298bc51fec648a458a3437008 + NDArray samples = ((SampleForecast) forecast).getSortedSamples(); + samples.setName("samples"); + saveNDArray(samples); } manager.close(); @@ -120,6 +137,13 @@ public static Map predict() } } + private static void saveNDArray(NDArray array) throws IOException { + Path path = Paths.get("build").resolve(array.getName() + ".npz"); + try (OutputStream os = Files.newOutputStream(path)) { + new NDList(new NDList(array)).encode(os, true); + } + } + /** * M5 Forecasting - Accuracy from https://www.kaggle.com/competitions/m5-forecasting-accuracy @@ -146,7 +170,14 @@ private static final class M5Dataset implements Iterable, Iterator target; CSVFormat csvFormat; - Path root; + + Repository repository; + String groupId; + String artifactId; + String version; Builder() { + repository = BasicDatasets.REPOSITORY; + groupId = BasicDatasets.GROUP_ID; + artifactId = "m5forecast-unittest"; + version = "1.0"; csvFormat = CSVFormat.DEFAULT .builder() @@ -214,8 +253,8 @@ public static final class Builder { } } - public Builder setRoot(Path root) { - this.root = root; + public Builder optRepository(Repository repository) { + this.repository = repository; return this; } @@ -227,6 +266,10 @@ public Builder setManager(NDManager manager) { public M5Dataset build() { return new M5Dataset(this); } + + MRL getMrl() { + return repository.dataset(Application.Tabular.ANY, groupId, artifactId, version); + } } } diff --git a/examples/src/main/java/ai/djl/examples/training/TrainTimeSeries.java b/examples/src/main/java/ai/djl/examples/training/TrainTimeSeries.java index 65c926c1691..e0143ed524b 100644 --- a/examples/src/main/java/ai/djl/examples/training/TrainTimeSeries.java +++ b/examples/src/main/java/ai/djl/examples/training/TrainTimeSeries.java @@ -15,6 +15,7 @@ import ai.djl.Model; import ai.djl.ModelException; +import ai.djl.basicdataset.BasicDatasets; import ai.djl.basicdataset.tabular.utils.Feature; import ai.djl.engine.Engine; import ai.djl.examples.inference.timeseries.M5ForecastingDeepAR; @@ -26,7 +27,6 @@ import ai.djl.ndarray.NDManager; import ai.djl.ndarray.types.Shape; import ai.djl.nn.Parameter; -import ai.djl.repository.Repository; import ai.djl.timeseries.Forecast; import ai.djl.timeseries.TimeSeriesData; import ai.djl.timeseries.dataset.FieldName; @@ -83,13 +83,6 @@ public static void main(String[] args) throws IOException, TranslateException, M } public static TrainingResult runExample(String[] args) throws IOException, TranslateException { - // use data path to create a custom repository - Repository repository = - Repository.newInstance( - "test", - Paths.get( - System.getProperty("user.home") - + "/Desktop/m5-forecasting-accuracy")); Arguments arguments = new Arguments().parseArgs(args); try (Model model = Model.newInstance("deepar")) { @@ -106,8 +99,7 @@ public static TrainingResult runExample(String[] args) throws IOException, Trans int contextLength = trainingNetwork.getContextLength(); M5Forecast trainSet = - getDataset( - trainingTransformation, repository, contextLength, Dataset.Usage.TRAIN); + getDataset(trainingTransformation, contextLength, Dataset.Usage.TRAIN); try (Trainer trainer = model.newTrainer(config)) { trainer.setMetrics(new Metrics()); @@ -144,13 +136,6 @@ public static TrainingResult runExample(String[] args) throws IOException, Trans public static Map predict(String outputDir) throws IOException, TranslateException, ModelException { - Repository repository = - Repository.newInstance( - "test", - Paths.get( - System.getProperty("user.home") - + "/Desktop/m5-forecasting-accuracy")); - try (Model model = Model.newInstance("deepar")) { DeepARNetwork predictionNetwork = getDeepARModel(new NegativeBinomialOutput(), false); model.setBlock(predictionNetwork); @@ -159,7 +144,6 @@ public static Map predict(String outputDir) M5Forecast testSet = getDataset( new ArrayList<>(), - repository, predictionNetwork.getContextLength(), Dataset.Usage.TEST); @@ -262,17 +246,16 @@ private static DeepARNetwork getDeepARModel( } private static M5Forecast getDataset( - List transformation, - Repository repository, - int contextLength, - Dataset.Usage usage) + List transformation, int contextLength, Dataset.Usage usage) throws IOException { // In order to create a TimeSeriesDataset, you must specify the transformation of the data // preprocessing M5Forecast.Builder builder = M5Forecast.builder() .optUsage(usage) - .optRepository(repository) + .optRepository(BasicDatasets.REPOSITORY) + .optGroupId(BasicDatasets.GROUP_ID) + .optArtifactId("m5forecast-unittest") .setTransformation(transformation) .setContextLength(contextLength) .setSampling(32, usage == Dataset.Usage.TRAIN); diff --git a/examples/src/main/java/ai/djl/examples/training/transferlearning/TransferFreshFruit.java b/examples/src/main/java/ai/djl/examples/training/transferlearning/TransferFreshFruit.java index 79c730abf00..2f2b85c8e67 100644 --- a/examples/src/main/java/ai/djl/examples/training/transferlearning/TransferFreshFruit.java +++ b/examples/src/main/java/ai/djl/examples/training/transferlearning/TransferFreshFruit.java @@ -131,13 +131,10 @@ private static RandomAccessDataset getData(Dataset.Usage usage, int batchSize) float[] mean = {0.485f, 0.456f, 0.406f}; float[] std = {0.229f, 0.224f, 0.225f}; - // If users want to use local repository, then the dataset can be loaded as follows - // Repository repository = Repository.newInstance("banana", Paths.get(LOCAL_FOLDER/{train OR - // test})); - // FruitsFreshAndRotten dataset = - // FruitsFreshAndRotten.builder() - // .optRepository(repository) - // .build() + // To use local dataset, users can load it as follows + // Repository repository = Repository.newInstance("banana", + // Paths.get("local_data_root/banana/train")); + // Then add the setting `.optRepository(repository)` to the builder below FruitsFreshAndRotten dataset = FruitsFreshAndRotten.builder() .optUsage(usage) diff --git a/examples/src/test/java/ai/djl/examples/inference/TimeSeriesTest.java b/examples/src/test/java/ai/djl/examples/inference/TimeSeriesTest.java new file mode 100644 index 00000000000..f6b315cee36 --- /dev/null +++ b/examples/src/test/java/ai/djl/examples/inference/TimeSeriesTest.java @@ -0,0 +1,65 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ + +package ai.djl.examples.inference; + +import ai.djl.ModelException; +import ai.djl.examples.inference.timeseries.AirPassengersDeepAR; +import ai.djl.examples.inference.timeseries.M5ForecastingDeepAR; +import ai.djl.testing.TestRequirements; +import ai.djl.translate.TranslateException; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.testng.Assert; +import org.testng.annotations.Test; + +import java.io.IOException; +import java.util.Map; + +public class TimeSeriesTest { + + private static final Logger logger = LoggerFactory.getLogger(TimeSeriesTest.class); + + @Test + public void testM5Forecasting() throws ModelException, TranslateException, IOException { + TestRequirements.engine("MXNet"); + + Map result = M5ForecastingDeepAR.predict(); + + String[] metricNames = + new String[] { + "RMSSE", + "MSE", + "abs_error", + "abs_target_sum", + "abs_target_mean", + "MAPE", + "sMAPE", + "ND" + }; + for (String metricName : metricNames) { + Assert.assertTrue(result.containsKey(metricName)); + } + } + + @Test + public void testAirPassenger() throws ModelException, TranslateException, IOException { + TestRequirements.engine("MXNet"); + + float[] result = AirPassengersDeepAR.predict(); + logger.info("{}", result); + + Assert.assertEquals(result.length, 12); + } +} diff --git a/examples/src/test/java/ai/djl/examples/training/TrainTimeSeriesTest.java b/examples/src/test/java/ai/djl/examples/training/TrainTimeSeriesTest.java new file mode 100644 index 00000000000..b92c1bb8249 --- /dev/null +++ b/examples/src/test/java/ai/djl/examples/training/TrainTimeSeriesTest.java @@ -0,0 +1,37 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ + +package ai.djl.examples.training; + +import ai.djl.testing.TestRequirements; +import ai.djl.training.TrainingResult; +import ai.djl.translate.TranslateException; + +import org.testng.Assert; +import org.testng.annotations.Test; + +import java.io.IOException; + +public class TrainTimeSeriesTest { + + @Test + public void testTrainTimeSeries() throws TranslateException, IOException { + TestRequirements.engine("MXNet"); + + String[] args = new String[] {"-g", "1", "-e", "5", "-b", "32"}; + TrainingResult result = TrainTimeSeries.runExample(args); + Assert.assertNotNull(result); + float loss = result.getTrainLoss(); + Assert.assertTrue(loss < 10f, "Loss: " + loss); + } +}