Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[example] change timeseries dataset source and add test #2109

Merged
merged 6 commits into from
Nov 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -77,14 +77,14 @@ public static float[] predict() throws IOException, TranslateException, ModelExc
try (ZooModel<TimeSeriesData, Forecast> model = criteria.loadModel();
Predictor<TimeSeriesData, Forecast> predictor = model.newPredictor();
NDManager manager = NDManager.newBaseManager(null, "MXNet")) {
TimeSeriesData data = getTimeSeriesData(manager, new URL(url));
TimeSeriesData input = getTimeSeriesData(manager, new URL(url));

// save data for plotting
NDArray target = data.get(FieldName.TARGET);
NDArray target = input.get(FieldName.TARGET);
target.setName("target");
saveNDArray(target);

Forecast forecast = predictor.predict(data);
Forecast forecast = predictor.predict(input);

// save data for plotting. Please see the corresponding python script from
// https://gist.github.com/Carkham/a5162c9298bc51fec648a458a3437008
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@

package ai.djl.examples.inference.timeseries;

import ai.djl.Application;
import ai.djl.ModelException;
import ai.djl.basicdataset.BasicDatasets;
import ai.djl.basicdataset.tabular.utils.DynamicBuffer;
import ai.djl.basicdataset.tabular.utils.Feature;
import ai.djl.inference.Predictor;
Expand All @@ -22,9 +24,13 @@
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.repository.Artifact;
import ai.djl.repository.MRL;
import ai.djl.repository.Repository;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.timeseries.Forecast;
import ai.djl.timeseries.SampleForecast;
import ai.djl.timeseries.TimeSeriesData;
import ai.djl.timeseries.dataset.FieldName;
import ai.djl.training.util.ProgressBar;
Expand All @@ -41,10 +47,12 @@
import java.io.BufferedInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStream;
import java.io.Reader;
import java.net.URL;
import java.nio.FloatBuffer;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.time.LocalDateTime;
Expand All @@ -70,17 +78,22 @@ public static void main(String[] args) throws IOException, TranslateException, M

public static Map<String, Float> predict()
throws IOException, TranslateException, ModelException {
// M5 Forecasting - Accuracy dataset requires manual download
String pathToData = "/Desktop/m5example/m5-forecasting-accuracy";
Path m5ForecastFile = Paths.get(System.getProperty("user.home") + pathToData);
NDManager manager = NDManager.newBaseManager(null, "MXNet");
M5Dataset dataset = M5Dataset.builder().setManager(manager).setRoot(m5ForecastFile).build();

// To use a local dataset, users can load the data as follows:
// Repository repository = Repository.newInstance("local_dataset",
// Paths.get("rootPath/m5-forecasting-accuracy"));
// Then add the setting `.optRepository(repository)` to the builder below
M5Dataset dataset = M5Dataset.builder().setManager(manager).build();

// The modelUrl can be replaced by a local model path, e.g.,
// String modelUrl = "rootPath/deepar.zip";
String modelUrl = "djl://ai.djl.mxnet/deepar/0.0.1/m5forecast";
int predictionLength = 4;
Criteria<TimeSeriesData, Forecast> criteria =
Criteria.builder()
.setTypes(TimeSeriesData.class, Forecast.class)
.optModelUrls("djl://ai.djl.mxnet/deepar/0.0.1/m5forecast")
.optModelUrls(modelUrl)
.optEngine("MXNet")
.optTranslatorFactory(new DeferredTranslatorFactory())
.optArgument("prediction_length", predictionLength)
Expand All @@ -104,22 +117,33 @@ public static Map<String, Float> predict()
input.setStartTime(LocalDateTime.parse("2011-01-29T00:00"));
input.setField(FieldName.TARGET, pastTarget);
Forecast forecast = predictor.predict(input);
// Here we focus on the metric Weighted Root Mean Squared Scaled Error (RMSSE) same
// as
// We focus on the metric Weighted Root Mean Squared Scaled Error (RMSSE) same as
// https://www.kaggle.com/competitions/m5-forecasting-accuracy/overview/evaluation
// The error is not small compared to the data values (sale amount). This is because
// the model is trained on sparse data with many zeros. This will be improved by
// aggregating/coarse graining the data which will appear in the next PR.
// TODO: coarse graining the data.
// aggregating/coarse graining the data. See https://github.com/Carkham/m5_blog
evaluator.aggregateMetrics(evaluator.getMetricsPerTs(gt, pastTarget, forecast));
progress.increment(1);

// save data for plotting. Please see the corresponding python script from
// https://gist.github.com/Carkham/a5162c9298bc51fec648a458a3437008
NDArray samples = ((SampleForecast) forecast).getSortedSamples();
samples.setName("samples");
saveNDArray(samples);
}

manager.close();
return evaluator.computeTotalMetrics();
}
}

/**
 * Writes the given array to {@code build/<name>.npz}, where {@code <name>} is the value returned
 * by {@code array.getName()}.
 *
 * <p>The output directory is created when it does not exist yet, so the example also works when
 * it is launched from a working directory that has no {@code build} folder.
 *
 * @param array the array to save; its name is used as the file name
 * @throws IOException if the output directory cannot be created or the file cannot be written
 */
private static void saveNDArray(NDArray array) throws IOException {
    Path outputDir = Paths.get("build");
    // Files.newOutputStream does not create missing parent directories; without this call the
    // method fails with NoSuchFileException on a fresh checkout or a different working directory.
    Files.createDirectories(outputDir);

    Path path = outputDir.resolve(array.getName() + ".npz");
    try (OutputStream os = Files.newOutputStream(path)) {
        new NDList(new NDList(array)).encode(os, true);
    }
}

/**
* M5 Forecasting - Accuracy from <a
* href="https://www.kaggle.com/competitions/m5-forecasting-accuracy">https://www.kaggle.com/competitions/m5-forecasting-accuracy</a>
Expand All @@ -146,7 +170,14 @@ private static final class M5Dataset implements Iterable<NDList>, Iterator<NDLis
}

private void prepare(Builder builder) throws IOException {
URL csvUrl = builder.root.resolve("weekly_sales_train_evaluation.csv").toUri().toURL();
MRL mrl = builder.getMrl();
Artifact artifact = mrl.getDefaultArtifact();
mrl.prepare(artifact, null);

Path root = mrl.getRepository().getResourceDirectory(artifact);
Path csvFile = root.resolve("weekly_sales_train_evaluation.csv");

URL csvUrl = csvFile.toUri().toURL();
try (Reader reader =
new InputStreamReader(
new BufferedInputStream(csvUrl.openStream()), StandardCharsets.UTF_8)) {
Expand Down Expand Up @@ -197,9 +228,17 @@ public static final class Builder {
NDManager manager;
List<Feature> target;
CSVFormat csvFormat;
Path root;

Repository repository;
String groupId;
String artifactId;
String version;

Builder() {
repository = BasicDatasets.REPOSITORY;
groupId = BasicDatasets.GROUP_ID;
artifactId = "m5forecast-unittest";
version = "1.0";
csvFormat =
CSVFormat.DEFAULT
.builder()
Expand All @@ -214,8 +253,8 @@ public static final class Builder {
}
}

public Builder setRoot(Path root) {
this.root = root;
public Builder optRepository(Repository repository) {
this.repository = repository;
return this;
}

Expand All @@ -227,6 +266,10 @@ public Builder setManager(NDManager manager) {
/**
 * Creates the dataset from the options collected by this builder.
 *
 * @return a new {@code M5Dataset} constructed from this builder
 */
public M5Dataset build() {
return new M5Dataset(this);
}

/**
 * Looks up the dataset resource in the configured repository.
 *
 * <p>The lookup uses the {@code Application.Tabular.ANY} application together with this
 * builder's group id, artifact id and version.
 *
 * @return the {@code MRL} returned by {@code repository.dataset(...)}
 */
MRL getMrl() {
return repository.dataset(Application.Tabular.ANY, groupId, artifactId, version);
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import ai.djl.Model;
import ai.djl.ModelException;
import ai.djl.basicdataset.BasicDatasets;
import ai.djl.basicdataset.tabular.utils.Feature;
import ai.djl.engine.Engine;
import ai.djl.examples.inference.timeseries.M5ForecastingDeepAR;
Expand All @@ -26,7 +27,6 @@
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Parameter;
import ai.djl.repository.Repository;
import ai.djl.timeseries.Forecast;
import ai.djl.timeseries.TimeSeriesData;
import ai.djl.timeseries.dataset.FieldName;
Expand Down Expand Up @@ -83,13 +83,6 @@ public static void main(String[] args) throws IOException, TranslateException, M
}

public static TrainingResult runExample(String[] args) throws IOException, TranslateException {
// use data path to create a custom repository
Repository repository =
Repository.newInstance(
"test",
Paths.get(
System.getProperty("user.home")
+ "/Desktop/m5-forecasting-accuracy"));

Arguments arguments = new Arguments().parseArgs(args);
try (Model model = Model.newInstance("deepar")) {
Expand All @@ -106,8 +99,7 @@ public static TrainingResult runExample(String[] args) throws IOException, Trans
int contextLength = trainingNetwork.getContextLength();

M5Forecast trainSet =
getDataset(
trainingTransformation, repository, contextLength, Dataset.Usage.TRAIN);
getDataset(trainingTransformation, contextLength, Dataset.Usage.TRAIN);

try (Trainer trainer = model.newTrainer(config)) {
trainer.setMetrics(new Metrics());
Expand Down Expand Up @@ -144,13 +136,6 @@ public static TrainingResult runExample(String[] args) throws IOException, Trans

public static Map<String, Float> predict(String outputDir)
throws IOException, TranslateException, ModelException {
Repository repository =
Repository.newInstance(
"test",
Paths.get(
System.getProperty("user.home")
+ "/Desktop/m5-forecasting-accuracy"));

try (Model model = Model.newInstance("deepar")) {
DeepARNetwork predictionNetwork = getDeepARModel(new NegativeBinomialOutput(), false);
model.setBlock(predictionNetwork);
Expand All @@ -159,7 +144,6 @@ public static Map<String, Float> predict(String outputDir)
M5Forecast testSet =
getDataset(
new ArrayList<>(),
repository,
predictionNetwork.getContextLength(),
Dataset.Usage.TEST);

Expand Down Expand Up @@ -262,17 +246,16 @@ private static DeepARNetwork getDeepARModel(
}

private static M5Forecast getDataset(
List<TimeSeriesTransform> transformation,
Repository repository,
int contextLength,
Dataset.Usage usage)
List<TimeSeriesTransform> transformation, int contextLength, Dataset.Usage usage)
throws IOException {
// In order to create a TimeSeriesDataset, you must specify the transformation of the data
// preprocessing
M5Forecast.Builder builder =
M5Forecast.builder()
.optUsage(usage)
.optRepository(repository)
.optRepository(BasicDatasets.REPOSITORY)
.optGroupId(BasicDatasets.GROUP_ID)
.optArtifactId("m5forecast-unittest")
.setTransformation(transformation)
.setContextLength(contextLength)
.setSampling(32, usage == Dataset.Usage.TRAIN);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,13 +131,10 @@ private static RandomAccessDataset getData(Dataset.Usage usage, int batchSize)
float[] mean = {0.485f, 0.456f, 0.406f};
float[] std = {0.229f, 0.224f, 0.225f};

// If users want to use local repository, then the dataset can be loaded as follows
// Repository repository = Repository.newInstance("banana", Paths.get(LOCAL_FOLDER/{train OR
// test}));
// FruitsFreshAndRotten dataset =
// FruitsFreshAndRotten.builder()
// .optRepository(repository)
// .build()
// To use a local dataset, users can load it as follows:
// Repository repository = Repository.newInstance("banana",
// Paths.get("local_data_root/banana/train"));
// Then add the setting `.optRepository(repository)` to the builder below
FruitsFreshAndRotten dataset =
FruitsFreshAndRotten.builder()
.optUsage(usage)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/*
* Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/

package ai.djl.examples.inference;

import ai.djl.ModelException;
import ai.djl.examples.inference.timeseries.AirPassengersDeepAR;
import ai.djl.examples.inference.timeseries.M5ForecastingDeepAR;
import ai.djl.testing.TestRequirements;
import ai.djl.translate.TranslateException;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.testng.Assert;
import org.testng.annotations.Test;

import java.io.IOException;
import java.util.Map;

/** Tests for the time series inference examples. */
public class TimeSeriesTest {

    private static final Logger logger = LoggerFactory.getLogger(TimeSeriesTest.class);

    /**
     * Runs the M5 forecasting example and checks that every expected metric is present in the
     * returned map.
     */
    @Test
    public void testM5Forecasting() throws ModelException, TranslateException, IOException {
        TestRequirements.engine("MXNet");

        Map<String, Float> result = M5ForecastingDeepAR.predict();

        String[] metricNames =
                new String[] {
                    "RMSSE",
                    "MSE",
                    "abs_error",
                    "abs_target_sum",
                    "abs_target_mean",
                    "MAPE",
                    "sMAPE",
                    "ND"
                };
        for (String metricName : metricNames) {
            // Name the missing metric so a failure can be diagnosed from the report alone.
            Assert.assertTrue(
                    result.containsKey(metricName),
                    "Missing metric '" + metricName + "', got: " + result.keySet());
        }
    }

    /** Runs the air passengers example and checks that it returns 12 values. */
    @Test
    public void testAirPassenger() throws ModelException, TranslateException, IOException {
        TestRequirements.engine("MXNet");

        float[] result = AirPassengersDeepAR.predict();
        logger.info("{}", result);

        Assert.assertEquals(result.length, 12, "Unexpected number of predicted values");
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/*
* Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/

package ai.djl.examples.training;

import ai.djl.testing.TestRequirements;
import ai.djl.training.TrainingResult;
import ai.djl.translate.TranslateException;

import org.testng.Assert;
import org.testng.annotations.Test;

import java.io.IOException;

/** Test for the {@code TrainTimeSeries} training example. */
public class TrainTimeSeriesTest {

    /**
     * Runs the training example with a fixed argument list, then checks that a result is returned
     * and that its training loss is below 10.
     */
    @Test
    public void testTrainTimeSeries() throws TranslateException, IOException {
        TestRequirements.engine("MXNet");

        TrainingResult trainingResult =
                TrainTimeSeries.runExample(new String[] {"-g", "1", "-e", "5", "-b", "32"});
        Assert.assertNotNull(trainingResult);

        float trainLoss = trainingResult.getTrainLoss();
        boolean lossBelowLimit = trainLoss < 10f;
        Assert.assertTrue(lossBelowLimit, "Loss: " + trainLoss);
    }
}