From 3703b69c3b513b5d8971e978f3334ccc16543c9d Mon Sep 17 00:00:00 2001 From: Carkham <60054018+Carkham@users.noreply.github.com> Date: Fri, 28 Oct 2022 23:27:51 +0800 Subject: [PATCH] [timeseries] add some basic block and deepAR model (#2027) * feature: add TimeSeriesDataset and training transform * feature: some basic block and deepar model * feature: add train example * feature: add m5-demo and air passengers demo Co-authored-by: Carkham <1302112560@qq.com> Co-authored-by: Frank Liu Co-authored-by: KexinFeng --- api/src/main/java/ai/djl/Application.java | 10 + .../java/ai/djl/mxnet/zoo/MxModelZoo.java | 2 + .../ai/djl/mxnet/deepar/metadata.json | 66 ++ .../timeseries/AirPassengersDeepAR.java | 128 ++++ .../M5ForecastingDeepAR.java} | 64 +- .../inference/timeseries/package-info.java | 15 + .../examples/training/TrainTimeSeries.java | 300 +++++++++ extensions/timeseries/M5DEMO.md | 303 +++++++++ extensions/timeseries/README.md | 184 +++++- .../djl/timeseries/block/FeatureEmbedder.java | 172 +++++ .../timeseries/block/FeatureEmbedding.java | 121 ++++ .../ai/djl/timeseries/block/MeanScaler.java | 110 ++++ .../ai/djl/timeseries/block/NopScaler.java | 71 ++ .../java/ai/djl/timeseries/block/Scaler.java | 96 +++ .../ai/djl/timeseries/block/package-info.java | 15 + .../ai/djl/timeseries/dataset/M5Forecast.java | 3 +- .../timeseries/dataset/TimeSeriesDataset.java | 16 +- .../distribution/NegativeBinomial.java | 63 +- .../output/NegativeBinomialOutput.java | 24 +- .../ai/djl/timeseries/evaluator/Rmsse.java | 125 ++++ .../timeseries/evaluator/package-info.java | 15 + .../model/deepar/DeepARNetwork.java | 616 ++++++++++++++++++ .../model/deepar/DeepARPredictionNetwork.java | 117 ++++ .../model/deepar/DeepARTrainingNetwork.java | 107 +++ .../timeseries/model/deepar/package-info.java | 15 + .../ai/djl/timeseries/block/BlockTest.java | 185 ++++++ .../ai/djl/timeseries/block/package-info.java | 15 + .../distribution/DistributionTest.java | 20 +- 
.../djl/timeseries/evaluator/RmsseTest.java | 49 ++ .../timeseries/evaluator/package-info.java | 15 + .../ai/djl/timeseries/model/DeepARTest.java | 312 +++++++++ .../ai/djl/timeseries/model/package-info.java | 15 + 32 files changed, 3264 insertions(+), 105 deletions(-) create mode 100644 engines/mxnet/mxnet-model-zoo/src/test/resources/mlrepo/model/timeseries/forecasting/ai/djl/mxnet/deepar/metadata.json create mode 100644 examples/src/main/java/ai/djl/examples/inference/timeseries/AirPassengersDeepAR.java rename examples/src/main/java/ai/djl/examples/inference/{DeepARTimeSeries.java => timeseries/M5ForecastingDeepAR.java} (85%) create mode 100644 examples/src/main/java/ai/djl/examples/inference/timeseries/package-info.java create mode 100644 examples/src/main/java/ai/djl/examples/training/TrainTimeSeries.java create mode 100644 extensions/timeseries/M5DEMO.md create mode 100644 extensions/timeseries/src/main/java/ai/djl/timeseries/block/FeatureEmbedder.java create mode 100644 extensions/timeseries/src/main/java/ai/djl/timeseries/block/FeatureEmbedding.java create mode 100644 extensions/timeseries/src/main/java/ai/djl/timeseries/block/MeanScaler.java create mode 100644 extensions/timeseries/src/main/java/ai/djl/timeseries/block/NopScaler.java create mode 100644 extensions/timeseries/src/main/java/ai/djl/timeseries/block/Scaler.java create mode 100644 extensions/timeseries/src/main/java/ai/djl/timeseries/block/package-info.java create mode 100644 extensions/timeseries/src/main/java/ai/djl/timeseries/evaluator/Rmsse.java create mode 100644 extensions/timeseries/src/main/java/ai/djl/timeseries/evaluator/package-info.java create mode 100644 extensions/timeseries/src/main/java/ai/djl/timeseries/model/deepar/DeepARNetwork.java create mode 100644 extensions/timeseries/src/main/java/ai/djl/timeseries/model/deepar/DeepARPredictionNetwork.java create mode 100644 extensions/timeseries/src/main/java/ai/djl/timeseries/model/deepar/DeepARTrainingNetwork.java create mode 100644 
extensions/timeseries/src/main/java/ai/djl/timeseries/model/deepar/package-info.java create mode 100644 extensions/timeseries/src/test/java/ai/djl/timeseries/block/BlockTest.java create mode 100644 extensions/timeseries/src/test/java/ai/djl/timeseries/block/package-info.java create mode 100644 extensions/timeseries/src/test/java/ai/djl/timeseries/evaluator/RmsseTest.java create mode 100644 extensions/timeseries/src/test/java/ai/djl/timeseries/evaluator/package-info.java create mode 100644 extensions/timeseries/src/test/java/ai/djl/timeseries/model/DeepARTest.java create mode 100644 extensions/timeseries/src/test/java/ai/djl/timeseries/model/package-info.java diff --git a/api/src/main/java/ai/djl/Application.java b/api/src/main/java/ai/djl/Application.java index cfd8b41b328..be245eca7e2 100644 --- a/api/src/main/java/ai/djl/Application.java +++ b/api/src/main/java/ai/djl/Application.java @@ -314,4 +314,14 @@ public interface Audio { /** Any audio application, including those in {@link Audio}. */ Application ANY = new Application("audio"); } + + /** The common set of applications for timeseries extension. */ + public interface TimeSeries { + + /** + * An application that take a past target vector with corresponding feature and predicts a + * probability distribution based on it. 
+ */ + Application FORECASTING = new Application("timeseries/forecasting"); + } } diff --git a/engines/mxnet/mxnet-model-zoo/src/main/java/ai/djl/mxnet/zoo/MxModelZoo.java b/engines/mxnet/mxnet-model-zoo/src/main/java/ai/djl/mxnet/zoo/MxModelZoo.java index 30f7884e030..292f9a0519f 100644 --- a/engines/mxnet/mxnet-model-zoo/src/main/java/ai/djl/mxnet/zoo/MxModelZoo.java +++ b/engines/mxnet/mxnet-model-zoo/src/main/java/ai/djl/mxnet/zoo/MxModelZoo.java @@ -14,6 +14,7 @@ import ai.djl.Application.CV; import ai.djl.Application.NLP; +import ai.djl.Application.TimeSeries; import ai.djl.mxnet.engine.MxEngine; import ai.djl.repository.Repository; import ai.djl.repository.zoo.ModelZoo; @@ -53,6 +54,7 @@ public class MxModelZoo extends ModelZoo { addModel(REPOSITORY.model(CV.ACTION_RECOGNITION, GROUP_ID, "action_recognition", "0.0.1")); addModel(REPOSITORY.model(NLP.QUESTION_ANSWER, GROUP_ID, "bertqa", "0.0.1")); addModel(REPOSITORY.model(NLP.WORD_EMBEDDING, GROUP_ID, "glove", "0.0.2")); + addModel(REPOSITORY.model(TimeSeries.FORECASTING, GROUP_ID, "deepar", "0.0.1")); } /** {@inheritDoc} */ diff --git a/engines/mxnet/mxnet-model-zoo/src/test/resources/mlrepo/model/timeseries/forecasting/ai/djl/mxnet/deepar/metadata.json b/engines/mxnet/mxnet-model-zoo/src/test/resources/mlrepo/model/timeseries/forecasting/ai/djl/mxnet/deepar/metadata.json new file mode 100644 index 00000000000..2eede49adae --- /dev/null +++ b/engines/mxnet/mxnet-model-zoo/src/test/resources/mlrepo/model/timeseries/forecasting/ai/djl/mxnet/deepar/metadata.json @@ -0,0 +1,66 @@ +{ + "metadataVersion": "0.2", + "resourceType": "model", + "application": "timeseries/forecasting", + "groupId": "ai.djl.mxnet", + "artifactId": "resnest", + "name": "deepar", + "description": "DeepAR model for timeseries forecasting", + "website": "http://www.djl.ai/engines/mxnet/model-zoo", + "licenses": { + "apache": { + "name": "The Apache License, Version 2.0", + "url": "https://www.apache.org/licenses/LICENSE-2.0" + } + }, + 
"artifacts": [ + { + "version": "0.0.1", + "snapshot": false, + "name": "airpassengers", + "properties": { + "dataset": "airpassengers" + }, + "arguments": { + "prediction_length": 12, + "freq": "M", + "use_feat_dynamic_real": false, + "use_feat_static_cat": false, + "use_feat_static_real": false, + "translatorFactory": "ai.djl.timeseries.translator.DeepARTranslatorFactory" + }, + "files": { + "model": { + "uri": "0.0.1/airpassengers.zip", + "sha1Hash": "1c99cdaefb79c3e63bc7ff1965b0fb2ba45e96c3", + "name": "", + "size": 106895 + } + } + }, + { + "version": "0.0.1", + "snapshot": false, + "name": "m5forecast", + "properties": { + "dataset": "m5forecast" + }, + "arguments": { + "prediction_length": 4, + "freq": "W", + "use_feat_dynamic_real": false, + "use_feat_static_cat": false, + "use_feat_static_real": false, + "translatorFactory": "ai.djl.timeseries.translator.DeepARTranslatorFactory" + }, + "files": { + "model": { + "uri": "0.0.1/m5forecast.zip", + "sha1Hash": "e251628df3a246911479de0ed36762515a5df241", + "name": "", + "size": 96363 + } + } + } + ] +} diff --git a/examples/src/main/java/ai/djl/examples/inference/timeseries/AirPassengersDeepAR.java b/examples/src/main/java/ai/djl/examples/inference/timeseries/AirPassengersDeepAR.java new file mode 100644 index 00000000000..c178e1a94e3 --- /dev/null +++ b/examples/src/main/java/ai/djl/examples/inference/timeseries/AirPassengersDeepAR.java @@ -0,0 +1,128 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions + * and limitations under the License. + */ + +package ai.djl.examples.inference.timeseries; + +import ai.djl.ModelException; +import ai.djl.inference.Predictor; +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.NDManager; +import ai.djl.repository.zoo.Criteria; +import ai.djl.repository.zoo.ZooModel; +import ai.djl.timeseries.Forecast; +import ai.djl.timeseries.SampleForecast; +import ai.djl.timeseries.TimeSeriesData; +import ai.djl.timeseries.dataset.FieldName; +import ai.djl.training.util.ProgressBar; +import ai.djl.translate.DeferredTranslatorFactory; +import ai.djl.translate.TranslateException; + +import com.google.gson.GsonBuilder; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.io.InputStreamReader; +import java.io.OutputStream; +import java.io.Reader; +import java.net.URL; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.time.LocalDateTime; +import java.time.ZoneId; +import java.util.Date; + +public final class AirPassengersDeepAR { + + private static final Logger logger = LoggerFactory.getLogger(AirPassengersDeepAR.class); + + private AirPassengersDeepAR() {} + + public static void main(String[] args) throws IOException, TranslateException, ModelException { + float[] results = predict(); + logger.info("{}", results); + } + + public static float[] predict() throws IOException, TranslateException, ModelException { + Criteria criteria = + Criteria.builder() + .setTypes(TimeSeriesData.class, Forecast.class) + .optModelUrls("djl://ai.djl.mxnet/deepar/0.0.1/airpassengers") + .optEngine("MXNet") + .optTranslatorFactory(new DeferredTranslatorFactory()) + .optArgument("prediction_length", 12) + .optArgument("freq", "M") + .optArgument("use_feat_dynamic_real", false) + .optArgument("use_feat_static_cat", false) + 
.optArgument("use_feat_static_real", false) + .optProgress(new ProgressBar()) + .build(); + + String url = "https://resources.djl.ai/test-models/mxnet/timeseries/air_passengers.json"; + + try (ZooModel model = criteria.loadModel(); + Predictor predictor = model.newPredictor(); + NDManager manager = NDManager.newBaseManager(null, "MXNet")) { + TimeSeriesData data = getTimeSeriesData(manager, new URL(url)); + + // save data for plotting + NDArray target = data.get(FieldName.TARGET); + target.setName("target"); + saveNDArray(target); + + Forecast forecast = predictor.predict(data); + + // save data for plotting. Please see the corresponding python script from + // https://gist.github.com/Carkham/a5162c9298bc51fec648a458a3437008 + NDArray samples = ((SampleForecast) forecast).getSortedSamples(); + samples.setName("samples"); + saveNDArray(samples); + return forecast.mean().toFloatArray(); + } + } + + private static TimeSeriesData getTimeSeriesData(NDManager manager, URL url) throws IOException { + try (Reader reader = new InputStreamReader(url.openStream(), StandardCharsets.UTF_8)) { + AirPassengers passengers = + new GsonBuilder() + .setDateFormat("yyyy-MM") + .create() + .fromJson(reader, AirPassengers.class); + + LocalDateTime start = + passengers.start.toInstant().atZone(ZoneId.systemDefault()).toLocalDateTime(); + NDArray target = manager.create(passengers.target); + TimeSeriesData data = new TimeSeriesData(10); + data.setStartTime(start); + data.setField(FieldName.TARGET, target); + return data; + } + } + + private static void saveNDArray(NDArray array) throws IOException { + Path path = Paths.get("build").resolve(array.getName() + ".npz"); + try (OutputStream os = Files.newOutputStream(path)) { + new NDList(new NDList(array)).encode(os, true); + } + } + + private static final class AirPassengers { + + Date start; + float[] target; + } +} diff --git a/examples/src/main/java/ai/djl/examples/inference/DeepARTimeSeries.java 
b/examples/src/main/java/ai/djl/examples/inference/timeseries/M5ForecastingDeepAR.java similarity index 85% rename from examples/src/main/java/ai/djl/examples/inference/DeepARTimeSeries.java rename to examples/src/main/java/ai/djl/examples/inference/timeseries/M5ForecastingDeepAR.java index 5c80b0237cf..d43c16e79ce 100644 --- a/examples/src/main/java/ai/djl/examples/inference/DeepARTimeSeries.java +++ b/examples/src/main/java/ai/djl/examples/inference/timeseries/M5ForecastingDeepAR.java @@ -11,7 +11,7 @@ * and limitations under the License. */ -package ai.djl.examples.inference; +package ai.djl.examples.inference.timeseries; import ai.djl.ModelException; import ai.djl.basicdataset.tabular.utils.DynamicBuffer; @@ -27,8 +27,8 @@ import ai.djl.timeseries.Forecast; import ai.djl.timeseries.TimeSeriesData; import ai.djl.timeseries.dataset.FieldName; -import ai.djl.timeseries.translator.DeepARTranslator; import ai.djl.training.util.ProgressBar; +import ai.djl.translate.DeferredTranslatorFactory; import ai.djl.translate.TranslateException; import ai.djl.util.Progress; @@ -55,17 +55,16 @@ import java.util.Map; import java.util.concurrent.ConcurrentHashMap; -public final class DeepARTimeSeries { +public final class M5ForecastingDeepAR { - private static final Logger logger = LoggerFactory.getLogger(DeepARTimeSeries.class); + private static final Logger logger = LoggerFactory.getLogger(M5ForecastingDeepAR.class); - private DeepARTimeSeries() {} + private M5ForecastingDeepAR() {} public static void main(String[] args) throws IOException, TranslateException, ModelException { - logger.info("model: DeepAR"); Map metrics = predict(); for (Map.Entry entry : metrics.entrySet()) { - logger.info(String.format("metric: %s:\t%.2f", entry.getKey(), entry.getValue())); + logger.info("{}", String.format("metric: %s:\t%.2f", entry.getKey(), entry.getValue())); } } @@ -74,25 +73,21 @@ public static Map predict() // M5 Forecasting - Accuracy dataset requires manual download String pathToData 
= "/Desktop/m5example/m5-forecasting-accuracy"; Path m5ForecastFile = Paths.get(System.getProperty("user.home") + pathToData); - NDManager manager = NDManager.newBaseManager(); + NDManager manager = NDManager.newBaseManager(null, "MXNet"); M5Dataset dataset = M5Dataset.builder().setManager(manager).setRoot(m5ForecastFile).build(); - String modelUrl = "https://resources.djl.ai/test-models/mxnet/timeseries/deepar.zip"; - Map arguments = new ConcurrentHashMap<>(); - int predictionLength = 28; - arguments.put("prediction_length", predictionLength); - arguments.put("freq", "D"); - arguments.put("use_" + FieldName.FEAT_DYNAMIC_REAL.name().toLowerCase(), false); - arguments.put("use_" + FieldName.FEAT_STATIC_CAT.name().toLowerCase(), false); - arguments.put("use_" + FieldName.FEAT_STATIC_REAL.name().toLowerCase(), false); - - DeepARTranslator.Builder builder = DeepARTranslator.builder(arguments); - DeepARTranslator translator = builder.build(); + int predictionLength = 4; Criteria criteria = Criteria.builder() .setTypes(TimeSeriesData.class, Forecast.class) - .optModelUrls(modelUrl) - .optTranslator(translator) + .optModelUrls("djl://ai.djl.mxnet/deepar/0.0.1/m5forecast") + .optEngine("MXNet") + .optTranslatorFactory(new DeferredTranslatorFactory()) + .optArgument("prediction_length", predictionLength) + .optArgument("freq", "W") + .optArgument("use_feat_dynamic_real", "false") + .optArgument("use_feat_static_cat", "false") + .optArgument("use_feat_static_real", "false") .optProgress(new ProgressBar()) .build(); @@ -119,6 +114,8 @@ public static Map predict() evaluator.aggregateMetrics(evaluator.getMetricsPerTs(gt, pastTarget, forecast)); progress.increment(1); } + + manager.close(); return evaluator.computeTotalMetrics(); } } @@ -143,14 +140,13 @@ private static final class M5Dataset implements Iterable, Iterator(); - for (int i = 1; i <= 1941; i++) { - target.add(new Feature("d_" + i, true)); + for (int i = 1; i <= 277; i++) { + target.add(new Feature("w_" + i, true)); 
} } @@ -234,7 +230,8 @@ public M5Dataset build() { } } - private static final class M5Evaluator { + /** An evaluator that calculates performance metrics. */ + public static final class M5Evaluator { private float[] quantiles; Map totalMetrics; Map totalNum; @@ -252,15 +249,14 @@ public Map getMetricsPerTs( new ConcurrentHashMap<>((8 + quantiles.length * 2) * 3 / 2); NDArray meanFcst = forecast.mean(); NDArray medianFcst = forecast.median(); - NDArray target = NDArrays.concat(new NDList(pastTarget, gtTarget), -1); - NDArray successiveDiff = target.get("1:").sub(target.get(":-1")); - successiveDiff = successiveDiff.square(); - successiveDiff = successiveDiff.get(":{}", -forecast.getPredictionLength()); - NDArray denom = successiveDiff.mean(); + NDArray meanSquare = gtTarget.sub(meanFcst).square().mean(); + NDArray scaleDenom = gtTarget.get("1:").sub(gtTarget.get(":-1")).square().mean(); + + NDArray rmsse = meanSquare.div(scaleDenom).sqrt(); + rmsse = NDArrays.where(scaleDenom.eq(0), rmsse.onesLike(), rmsse); - NDArray num = gtTarget.sub(meanFcst).square().mean(); - retMetrics.put("RMSSE", num.getFloat() / denom.getFloat()); + retMetrics.put("RMSSE", rmsse.getFloat()); retMetrics.put("MSE", gtTarget.sub(meanFcst).square().mean().getFloat()); retMetrics.put("abs_error", gtTarget.sub(medianFcst).abs().sum().getFloat()); diff --git a/examples/src/main/java/ai/djl/examples/inference/timeseries/package-info.java b/examples/src/main/java/ai/djl/examples/inference/timeseries/package-info.java new file mode 100644 index 00000000000..1af3c428485 --- /dev/null +++ b/examples/src/main/java/ai/djl/examples/inference/timeseries/package-info.java @@ -0,0 +1,15 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. 
A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ + +/** Contains examples of time series forecasting. */ +package ai.djl.examples.inference.timeseries; diff --git a/examples/src/main/java/ai/djl/examples/training/TrainTimeSeries.java b/examples/src/main/java/ai/djl/examples/training/TrainTimeSeries.java new file mode 100644 index 00000000000..65c926c1691 --- /dev/null +++ b/examples/src/main/java/ai/djl/examples/training/TrainTimeSeries.java @@ -0,0 +1,300 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. 
+ */ + +package ai.djl.examples.training; + +import ai.djl.Model; +import ai.djl.ModelException; +import ai.djl.basicdataset.tabular.utils.Feature; +import ai.djl.engine.Engine; +import ai.djl.examples.inference.timeseries.M5ForecastingDeepAR; +import ai.djl.examples.training.util.Arguments; +import ai.djl.inference.Predictor; +import ai.djl.metric.Metrics; +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.NDManager; +import ai.djl.ndarray.types.Shape; +import ai.djl.nn.Parameter; +import ai.djl.repository.Repository; +import ai.djl.timeseries.Forecast; +import ai.djl.timeseries.TimeSeriesData; +import ai.djl.timeseries.dataset.FieldName; +import ai.djl.timeseries.dataset.M5Forecast; +import ai.djl.timeseries.dataset.TimeFeaturizers; +import ai.djl.timeseries.distribution.DistributionLoss; +import ai.djl.timeseries.distribution.output.DistributionOutput; +import ai.djl.timeseries.distribution.output.NegativeBinomialOutput; +import ai.djl.timeseries.evaluator.Rmsse; +import ai.djl.timeseries.model.deepar.DeepARNetwork; +import ai.djl.timeseries.timefeature.TimeFeature; +import ai.djl.timeseries.transform.TimeSeriesTransform; +import ai.djl.timeseries.translator.DeepARTranslator; +import ai.djl.training.DefaultTrainingConfig; +import ai.djl.training.EasyTrain; +import ai.djl.training.Trainer; +import ai.djl.training.TrainingResult; +import ai.djl.training.dataset.Batch; +import ai.djl.training.dataset.Dataset; +import ai.djl.training.initializer.XavierInitializer; +import ai.djl.training.listener.SaveModelTrainingListener; +import ai.djl.training.listener.TrainingListener; +import ai.djl.training.util.ProgressBar; +import ai.djl.translate.TranslateException; +import ai.djl.util.Progress; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.nio.file.Paths; +import java.time.LocalDateTime; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import 
java.util.concurrent.ConcurrentHashMap; + +/** An example of training a deepar timeseries model. */ +public final class TrainTimeSeries { + + private static final Logger logger = LoggerFactory.getLogger(TrainTimeSeries.class); + private static String freq = "W"; + private static int predictionLength = 4; + private static LocalDateTime startTime = LocalDateTime.parse("2011-01-29T00:00"); + + private TrainTimeSeries() {} + + public static void main(String[] args) throws IOException, TranslateException, ModelException { + TrainTimeSeries.runExample(args); + Map metrics = predict("build/model"); + for (Map.Entry entry : metrics.entrySet()) { + logger.info(String.format("metric: %s:\t%.2f", entry.getKey(), entry.getValue())); + } + } + + public static TrainingResult runExample(String[] args) throws IOException, TranslateException { + // use data path to create a custom repository + Repository repository = + Repository.newInstance( + "test", + Paths.get( + System.getProperty("user.home") + + "/Desktop/m5-forecasting-accuracy")); + + Arguments arguments = new Arguments().parseArgs(args); + try (Model model = Model.newInstance("deepar")) { + // specify the model distribution output, for M5 case, NegativeBinomial best describe it + DistributionOutput distributionOutput = new NegativeBinomialOutput(); + DefaultTrainingConfig config = setupTrainingConfig(arguments, distributionOutput); + + NDManager manager = model.getNDManager(); + DeepARNetwork trainingNetwork = getDeepARModel(distributionOutput, true); + model.setBlock(trainingNetwork); + + List trainingTransformation = + trainingNetwork.createTrainingTransformation(manager); + int contextLength = trainingNetwork.getContextLength(); + + M5Forecast trainSet = + getDataset( + trainingTransformation, repository, contextLength, Dataset.Usage.TRAIN); + + try (Trainer trainer = model.newTrainer(config)) { + trainer.setMetrics(new Metrics()); + + int historyLength = trainingNetwork.getHistoryLength(); + Shape[] inputShapes = new 
Shape[9]; + // (N, num_cardinality) + inputShapes[0] = new Shape(1, 5); + // (N, num_real) if use_feat_stat_real else (N, 1) + inputShapes[1] = new Shape(1, 1); + // (N, history_length, num_time_feat + num_age_feat) + inputShapes[2] = + new Shape( + 1, + historyLength, + TimeFeature.timeFeaturesFromFreqStr(freq).size() + 1); + inputShapes[3] = new Shape(1, historyLength); + inputShapes[4] = new Shape(1, historyLength); + inputShapes[5] = new Shape(1, historyLength); + inputShapes[6] = + new Shape( + 1, + predictionLength, + TimeFeature.timeFeaturesFromFreqStr(freq).size() + 1); + inputShapes[7] = new Shape(1, predictionLength); + inputShapes[8] = new Shape(1, predictionLength); + trainer.initialize(inputShapes); + + EasyTrain.fit(trainer, arguments.getEpoch(), trainSet, null); + return trainer.getTrainingResult(); + } + } + } + + public static Map predict(String outputDir) + throws IOException, TranslateException, ModelException { + Repository repository = + Repository.newInstance( + "test", + Paths.get( + System.getProperty("user.home") + + "/Desktop/m5-forecasting-accuracy")); + + try (Model model = Model.newInstance("deepar")) { + DeepARNetwork predictionNetwork = getDeepARModel(new NegativeBinomialOutput(), false); + model.setBlock(predictionNetwork); + model.load(Paths.get(outputDir)); + + M5Forecast testSet = + getDataset( + new ArrayList<>(), + repository, + predictionNetwork.getContextLength(), + Dataset.Usage.TEST); + + Map arguments = new ConcurrentHashMap<>(); + arguments.put("prediction_length", predictionLength); + arguments.put("freq", freq); + arguments.put("use_" + FieldName.FEAT_DYNAMIC_REAL.name().toLowerCase(), false); + arguments.put("use_" + FieldName.FEAT_STATIC_CAT.name().toLowerCase(), true); + arguments.put("use_" + FieldName.FEAT_STATIC_REAL.name().toLowerCase(), false); + DeepARTranslator translator = DeepARTranslator.builder(arguments).build(); + + M5ForecastingDeepAR.M5Evaluator evaluator = + new M5ForecastingDeepAR.M5Evaluator(0.5f, 
0.67f, 0.95f, 0.99f); + Progress progress = new ProgressBar(); + progress.reset("Inferring", testSet.size()); + try (Predictor predictor = model.newPredictor(translator)) { + for (Batch batch : testSet.getData(model.getNDManager().newSubManager())) { + NDList data = batch.getData(); + NDArray target = data.head(); + NDArray featStaticCat = data.get(1); + + NDArray gt = target.get(":, {}:", -predictionLength); + NDArray pastTarget = target.get(":, :{}", -predictionLength); + + NDList gtSplit = gt.split(batch.getSize()); + NDList pastTargetSplit = pastTarget.split(batch.getSize()); + NDList featStaticCatSplit = featStaticCat.split(batch.getSize()); + + List batchInput = new ArrayList<>(batch.getSize()); + for (int i = 0; i < batch.getSize(); i++) { + TimeSeriesData input = new TimeSeriesData(10); + input.setStartTime(startTime); + input.setField(FieldName.TARGET, pastTargetSplit.get(i).squeeze(0)); + input.setField( + FieldName.FEAT_STATIC_CAT, featStaticCatSplit.get(i).squeeze(0)); + batchInput.add(input); + } + List forecasts = predictor.batchPredict(batchInput); + for (int i = 0; i < forecasts.size(); i++) { + evaluator.aggregateMetrics( + evaluator.getMetricsPerTs( + gtSplit.get(i).squeeze(0), + pastTargetSplit.get(i).squeeze(0), + forecasts.get(i))); + } + progress.increment(batch.getSize()); + batch.close(); + } + return evaluator.computeTotalMetrics(); + } + } + } + + private static DefaultTrainingConfig setupTrainingConfig( + Arguments arguments, DistributionOutput distributionOutput) { + String outputDir = arguments.getOutputDir(); + SaveModelTrainingListener listener = new SaveModelTrainingListener(outputDir); + listener.setSaveModelCallback( + trainer -> { + TrainingResult result = trainer.getTrainingResult(); + Model model = trainer.getModel(); + float rmsse = result.getValidateEvaluation("RMSSE"); + model.setProperty("RMSSE", String.format("%.5f", rmsse)); + model.setProperty("Loss", String.format("%.5f", result.getValidateLoss())); + }); + + return new 
DefaultTrainingConfig(new DistributionLoss("Loss", distributionOutput)) + .addEvaluator(new Rmsse(distributionOutput)) + .optDevices(Engine.getInstance().getDevices(arguments.getMaxGpus())) + .optInitializer(new XavierInitializer(), Parameter.Type.WEIGHT) + .addTrainingListeners(TrainingListener.Defaults.logging(outputDir)) + .addTrainingListeners(listener); + } + + /** + * Create the deepar model with specified distribution output. + * + * @param distributionOutput the distribution output + * @param training if training create trainingNetwork else predictionNetwork + * @return deepar model + */ + private static DeepARNetwork getDeepARModel( + DistributionOutput distributionOutput, boolean training) { + // here is feat_static_cat's cardinality which depend on your dataset + List cardinality = new ArrayList<>(); + cardinality.add(3); + cardinality.add(10); + cardinality.add(3); + cardinality.add(7); + cardinality.add(3049); + + DeepARNetwork.Builder builder = + DeepARNetwork.builder() + .setCardinality(cardinality) + .setFreq(freq) + .setPredictionLength(predictionLength) + .optDistrOutput(distributionOutput) + .optUseFeatStaticCat(true); + return training ? builder.buildTrainingNetwork() : builder.buildPredictionNetwork(); + } + + private static M5Forecast getDataset( + List transformation, + Repository repository, + int contextLength, + Dataset.Usage usage) + throws IOException { + // In order to create a TimeSeriesDataset, you must specify the transformation of the data + // preprocessing + M5Forecast.Builder builder = + M5Forecast.builder() + .optUsage(usage) + .optRepository(repository) + .setTransformation(transformation) + .setContextLength(contextLength) + .setSampling(32, usage == Dataset.Usage.TRAIN); + + int maxWeek = usage == Dataset.Usage.TRAIN ? 
273 : 277; + for (int i = 1; i <= maxWeek; i++) { + builder.addFeature("w_" + i, FieldName.TARGET); + } + + M5Forecast m5Forecast = + builder.addFeature("state_id", FieldName.FEAT_STATIC_CAT) + .addFeature("store_id", FieldName.FEAT_STATIC_CAT) + .addFeature("cat_id", FieldName.FEAT_STATIC_CAT) + .addFeature("dept_id", FieldName.FEAT_STATIC_CAT) + .addFeature("item_id", FieldName.FEAT_STATIC_CAT) + .addFieldFeature( + FieldName.START, + new Feature( + "date", + TimeFeaturizers.getConstantTimeFeaturizer(startTime))) + .build(); + m5Forecast.prepare(new ProgressBar()); + return m5Forecast; + } +} diff --git a/extensions/timeseries/M5DEMO.md b/extensions/timeseries/M5DEMO.md new file mode 100644 index 00000000000..4c12b6f6a6a --- /dev/null +++ b/extensions/timeseries/M5DEMO.md @@ -0,0 +1,303 @@ +# DJL timeseries package applied on M5forecasting dataset + +The timeseries package contains components and tools for building time series models and +inferring on pretrained models using DJL. + +Now it contains: + +- Translator for preprocess and postprocess data, and also includes the corresponding data transform modules. +- Components for building and training new probabilistic prediction models. +- Time series data loading and processing. +- A M5-forecast dataset. +- Two pre-trained model. +- A pre-built trainable model (DeepAR). + +## M5 Forecasting data + +[M5 Forecasting competition]([M5 Forecasting - Accuracy | Kaggle](https://www.kaggle.com/competitions/m5-forecasting-accuracy/overview/description)) +goal is to forecast future sales at Walmart based on hierarchical sales in the states of California, +Texas, and Wisconsin. It provides information on daily sales, product attributes, prices, and calendars. + +> Notes: Taking into account the model training performance, we sum the sales every 7 days, +> coarse-grained the data, so that the model can better learn the time series information. 
+> **After downloading the dataset from [M5 Forecasting competition](https://www.kaggle.com/competitions/m5-forecasting-accuracy/overview/description), +> you can use our [script](https://gist.github.com/Carkham/a5162c9298bc51fec648a458a3437008) +> to get coarse-grained data. This script will create "weekly_xxx.csv" files representing weekly +> data in the dataset directory you specify.** + +## DeepAR model + +DeepAR forecasting algorithm is a supervised learning algorithm for forecasting scalar +(one-dimensional) time series using recurrent neural networks (RNN). + +Unlike traditional time series forecasting models, DeepAR estimates the future probability +distribution of time series based on the past. In retail businesses, probabilistic demand +forecasting is critical to delivering the right inventory at the right time and in the right place. + +Therefore, we choose the sales data set in the real scene as an example to describe how to use +the timeseries package for forecasting + +### Metrics + +We use the following metrics to evaluate the performance of the DeepAR model in the +[M5 Forecasting competition](https://www.kaggle.com/competitions/m5-forecasting-accuracy/overview/description). + +``` +> [INFO ] - metric: Coverage[0.99]: 0.92 +> [INFO ] - metric: Coverage[0.67]: 0.51 +> [INFO ] - metric: abs_target_sum: 1224665.00 +> [INFO ] - metric: abs_target_mean: 10.04 +> [INFO ] - metric: NRMSE: 0.84 +> [INFO ] - metric: RMSE: 8.40 +> [INFO ] - metric: RMSSE: 1.00 +> [INFO ] - metric: abs_error: 14.47 +> [INFO ] - metric: QuantileLoss[0.67]: 18.23 +> [INFO ] - metric: QuantileLoss[0.99]: 103.07 +> [INFO ] - metric: QuantileLoss[0.50]: 9.49 +> [INFO ] - metric: QuantileLoss[0.95]: 66.69 +> [INFO ] - metric: Coverage[0.95]: 0.87 +> [INFO ] - metric: Coverage[0.50]: 0.33 +> [INFO ] - metric: MSE: 70.64 +``` + +As you can see, our pretrained model has some effect on the data prediction of item value. And +some metrics can basically meet expectations. 
For example, **RMSSE**, which is a measure of the +relative error between the predicted value and the actual value. 1.00 means that the model can +reflect the changes of the time series data to a certain extent. + +## Run the M5 Forecasting example + +### Define your dataset + +In order to realize the preprocessing of time series data, we define the `TimeSeriesData` as +the input of the Translator, which is used to store the feature fields and perform corresponding +transformations. + +So for your own dataset, you need to customize the way you get the data and put it into +`TimeSeriesData` as the input to the translator. + +For M5 dataset we have: + +```java +private static final class M5Dataset implements Iterable, Iterator { + + // coarse-grained data + private static String fileName = "weekly_sales_train_evaluation.csv"; + + private NDManager manager; + private List target; + private List csvRecords; + private long size; + private long current; + + M5Dataset(Builder builder) { + manager = builder.manager; + target = builder.target; + try { + prepare(builder); + } catch (Exception e) { + throw new AssertionError( + String.format("Failed to read m5-forecast-accuracy/%s.", fileName), e); + } + size = csvRecords.size(); + } + + /** Load data into CSVRecords */ + private void prepare(Builder builder) throws IOException { + URL csvUrl = builder.root.resolve(fileName).toUri().toURL(); + try (Reader reader = + new InputStreamReader( + new BufferedInputStream(csvUrl.openStream()), StandardCharsets.UTF_8)) { + CSVParser csvParser = new CSVParser(reader, builder.csvFormat); + csvRecords = csvParser.getRecords(); + } + } + + @Override + public boolean hasNext() { + return current < size; + } + + @Override + public NDList next() { + NDList data = getRowFeatures(manager, current, target); + current++; + return data; + } + + public static Builder builder() { + return new Builder(); + } + + /** Get string data of selected cell from index row in CSV file and create NDArray to 
save */ + private NDList getRowFeatures(NDManager manager, long index, List selected) { + DynamicBuffer bb = new DynamicBuffer(); + for (Feature feature : selected) { + String name = feature.getName(); + String value = getCell(index, name); + feature.getFeaturizer().featurize(bb, value); + } + FloatBuffer buf = bb.getBuffer(); + return new NDList(manager.create(buf, new Shape(bb.getLength()))); + } + + private String getCell(long rowIndex, String featureName) { + CSVRecord record = csvRecords.get(Math.toIntExact(rowIndex)); + return record.get(featureName); + } + + @Override + public Iterator iterator() { + return this; + } + + public static final class Builder { + + NDManager manager; + List target; + CSVFormat csvFormat; + Path root; + + Builder() { + csvFormat = + CSVFormat.DEFAULT + .builder() + .setHeader() + .setSkipHeaderRecord(true) + .setIgnoreHeaderCase(true) + .setTrim(true) + .build(); + target = new ArrayList<>(); + for (int i = 1; i <= 277; i++) { + target.add(new Feature("w_" + i, true)); + } + } + + public Builder setRoot(Path root) { + this.root = root; + return this; + } + + public Builder setManager(NDManager manager) { + this.manager = manager; + return this; + } + + public M5Dataset build() { + return new M5Dataset(this); + } + } + } +``` + +### Prepare dataset + +Set your own dataset path. + +```java +Path m5ForecastFile = Paths.get("/YOUR PATH/m5-forecasting-accuracy"); +NDManager manager = NDManager.newBaseManager(); +M5Dataset dataset = M5Dataset.builder().setManager(manager).setRoot(m5ForecastFile).build(); +``` + +### Config your translator + +`DeepARTranslator` provides support for data preprocessing and postprocessing for probabilistic +prediction models. Referring to GluonTS, our translator can perform corresponding preprocessing +on `TimeseriesData` containing data according to different parameters to obtain the input of +the network model. And post-processing the output of the network to get the prediction result. 
+ +For DeepAR models, you must set the following arguments. + +```java +Logger logger = LoggerFactory.getLogger(TimeSeriesDemo.class); +String freq = "W"; +int predictionLength = 4; +LocalDateTime startTime = LocalDateTime.parse("2011-01-29T00:00"); + +Map arguments = new ConcurrentHashMap<>(); + +arguments.put("prediction_length", predictionLength); +arguments.put("freq", freq); // The predicted frequency contains units and values + +// Parameters from DeepAR in GluonTS +arguments.put("use_" + FieldName.FEAT_DYNAMIC_REAL.name().toLowerCase(), false); +arguments.put("use_" + FieldName.FEAT_STATIC_CAT.name().toLowerCase(), false); +arguments.put("use_" + FieldName.FEAT_STATIC_REAL.name().toLowerCase(), false); +``` + +For any other GluonTS model, you can quickly develop your own translator using the classes +in `transform` modules (e.g. `TransformerTranslator`). + +### Load your own model from the local file system + +At this step, you need to construct the `Criteria` API, which is used as search criteria to look +for a ZooModel. In this application, you can customize your local pretrained model path +(local directory or an archive file containing `.params` and `symbol.json`) +with `optModelPath()`. The following code snippet loads the model with the file +path: `/YOUR PATH/deepar.zip`. + +```java +DeepARTranslator translator = DeepARTranslator.builder(arguments).build(); +Criteria criteria = + Criteria.builder() + .setTypes(TimeSeriesData.class, Forecast.class) + .optModelPath(Paths.get("/YOUR PATH/deepar.zip")) + .optTranslator(translator) + .optProgress(new ProgressBar()) + .build(); +``` + +### Inference + +Now, you are ready to use the model bundled with the translator created above to run inference. + +Since we need to generate features based on dates and make predictions with reference to the +context, for each `TimeSeriesData` you must set the values of its **`StartTime`** and **`TARGET`** fields.
+ +```java +try (ZooModel model = criteria.loadModel(); + Predictor predictor = model.newPredictor()) { + data = dataset.next(); + NDArray array = data.singletonOrThrow(); + TimeSeriesData input = new TimeSeriesData(10); + input.setStartTime(startTime); // start time of prediction + input.setField(FieldName.TARGET, array); // target value through whole context length + Forecast forecast = predictor.predict(input); + saveResult(forecast); // save result and plot it with python. + } +} +``` + +### Results + +The `Forecast` objects contain all the sample paths in the form of `NDArray` +with dimension `(numSamples, predictionLength)`, and the start date of the forecast. You can access +all this information by simply invoking the corresponding function. + +You can summarize the sample paths by computing statistics, including the mean and quantile, for each step +in the prediction window. + +```java +logger.info("Mean of the prediction windows:\n" + forecast.mean().toDebugString()); +logger.info("0.5-quantile(Median) of the prediction windows:\n" + forecast.quantile("0.5").toDebugString()); +``` + +``` +> [INFO ] - Mean of the prediction windows: +> ND: (4) cpu() float32 +> [5.97, 6.1 , 5.9 , 6.11] +> +> [INFO ] - 0.5-quantile(Median) of the prediction windows: +> ND: (4) cpu() float32 +> [6., 5., 5., 6.] +``` + +We visualize the forecast result with mean, prediction intervals, etc. + +![m5_forecast_0](https://resources.djl.ai/images/timeseries/m5_forecast_0.jpg) + +### Metrics + +Here we compute aggregate performance metrics in the +[source code](../../examples/src/main/java/ai/djl/examples/inference/timeseries/M5ForecastingDeepAR.java) diff --git a/extensions/timeseries/README.md b/extensions/timeseries/README.md index 1f26aa49f1d..7405ee670e8 100644 --- a/extensions/timeseries/README.md +++ b/extensions/timeseries/README.md @@ -2,7 +2,17 @@ This module contains the time series model support extension with [GluonTS](https://github.com/awslabs/gluonts).
-Right now, the package provides the `BaseTimeSeriesTranslator` and transform package that allows you to do inference from your pre-trained time series model. +Right now, the package provides the `BaseTimeSeriesTranslator` and transform package that allows +you to do inference from your pre-trained time series model. + +Now it contains: + +- Translator for preprocess and postprocess data, and also includes the corresponding data transform modules. +- Components for building and training new probabilistic prediction models. +- Time series data loading and processing. +- A M5-forecast dataset. +- Two pre-trained model. +- A pre-built trainable model (DeepAR). ## Module structure @@ -10,13 +20,16 @@ Right now, the package provides the `BaseTimeSeriesTranslator` and transform pac An abstract class representing the forecast result. -It contains the distribution of the results, the start date of the forecast, the frequency of the time series, etc. User can get all these information by simply invoking the corresponding attribute. +It contains the distribution of the results, the start date of the forecast, the frequency of the +time series, etc. User can get all these information by simply invoking the corresponding attribute. -- `SampleForecast` extends the `Forecast` that contain all the sample paths in the form of `NDArray`. User can query the prediction results by accessing the data in the samples. +- `SampleForecast` extends the `Forecast` that contain all the sample paths in the form of `NDArray`. +User can query the prediction results by accessing the data in the samples. ### TimeSeriesData -The data entry for managing timing data in preprocessing as an input to the transform method. It contains a key-value pair list mapping from the time series name field to `NDArray`. +The data entry for managing timing data in preprocessing as an input to the transform method. +It contains a key-value pair list mapping from the time series name field to `NDArray`. 
### dataset + @@ -31,43 +44,60 @@ This module contains all the methods for generating time features from the predi ### transform -In general, it gets the `TimeSeriesData` and transform it to another `TimeSeriesData` that can possibly contain more fields. It can be done by defining a set of of "actions" to the raw dataset in training or just invoking at translator in inference. +In general, it gets the `TimeSeriesData` and transform it to another `TimeSeriesData` that can +possibly contain more fields. It can be done by defining a set of "actions" to the raw dataset +in training or just invoking at translator in inference. This action usually create some additional features or transform an existing feature. #### convert - Convert -- Convert the array shape to the preprocessing. -- VstackFeatures.java -- vstack the inputs name field of the `TimeSeriesData`. We make it implement the `TimeSeriesTransform` interface for **training feature.** +- VstackFeatures.java -- vstack the inputs name field of the `TimeSeriesData`. We make it implement +the `TimeSeriesTransform` interface for **training feature.** #### feature - Feature -- Add time features to the preprocessing. -- AddAgeFeature -- Creates the `FEAT_DYNAMIC_AGE` name field in the `TimeSeriesData`. Adds a feature that its value is small for distant past timestamps and it monotonically increases the more we approach the current timestamp. We make it implement the `TimeSeriesTransform` interface for **training feature.** -- AddObservedValueIndicator -- Creates the `OBSERVED_VALUES` name field in the `TimeSeriesData`. Adds a feature that equals to 1 if the value is observed and 0 if the value is missing. We make it implement the `TimeSeriesTransform` interface for **training feature.** -- AddTimeFeature -- Creates the `FEAT_TIME` name field in the `TimeSeriesData`. Adds a feature that its value is based on the different prediction frequencies.
We make it implement the `TimeSeriesTransform` interface for **training feature.** +- AddAgeFeature -- Creates the `FEAT_DYNAMIC_AGE` name field in the `TimeSeriesData`. Adds a +feature that its value is small for distant past timestamps and it monotonically increases +the more we approach the current timestamp. We make it implement the `TimeSeriesTransform` +interface for **training feature.** +- AddObservedValueIndicator -- Creates the `OBSERVED_VALUES` name field in the `TimeSeriesData`. +Adds a feature that equals to 1 if the value is observed and 0 if the value is missing. +We make it implement the `TimeSeriesTransform` interface for **training feature.** +- AddTimeFeature -- Creates the `FEAT_TIME` name field in the `TimeSeriesData`. Adds a feature +that its value is based on the different prediction frequencies. We make it implement the +`TimeSeriesTransform` interface for **training feature.** #### field -- Field -- Process key-value data entry to the preprocessing. It usually add or remove the feature in the `TimeSeriesData`. -- RemoveFields -- Remove the input name field. We make it implement the `TimeSeriesTransform` interface for **training feature.** -- SelectField -- Only keep input name fields. We make it implement the `TimeSeriesTransform` interface for **training feature.** -- SetField -- Set the input name field with `NDArray`. We make it implement the `TimeSeriesTransform` interface for **training feature.** +- Field -- Process key-value data entry to the preprocessing. It usually add or remove the +feature in the `TimeSeriesData`. +- RemoveFields -- Remove the input name field. We make it implement the `TimeSeriesTransform` +interface for **training feature.** +- SelectField -- Only keep input name fields. We make it implement the `TimeSeriesTransform` +interface for **training feature.** +- SetField -- Set the input name field with `NDArray`. 
We make it implement the +`TimeSeriesTransform` interface for **training feature.** #### split - Split -- Split time series data for training and inferring to the preprocessing. -- InstanceSplit -- Split time series data with the slice from `Sampler` for training and inferring to the preprocessing. We make it implement the `TimeSeriesTransform` interface for **training feature.** +- InstanceSplit -- Split time series data with the slice from `Sampler` for training and inferring +to the preprocessing. We make it implement the `TimeSeriesTransform` interface for **training feature.** ### InstanceSampler Sample index for splitting based on training or inferring. -`PredictionSampler` extends `InstanceSampler` for the prediction including test and valid. It would return the end of the time series bound as the dividing line between the future and past. +`PredictionSampler` extends `InstanceSampler` for the prediction including test and valid. +It would return the end of the time series bound as the dividing line between the future and past. ### translator -Existing time series model translators and corresponding factories. Now we have developed `DeepARTranslator` and `TransformerTranslator` for users. +Existing time series model translators and corresponding factories. Now we have developed +`DeepARTranslator` and `TransformerTranslator` for users. The following pseudocode demonstrates how to create a `DeepARTranslator` with `arguments`. @@ -79,15 +109,127 @@ The following pseudocode demonstrates how to create a `DeepARTranslator` with `a DeepARTranslator translator = builder.build(); ``` -If you want to customize your own time series model translator, you can easily use the transform package for your data preprocess. +If you want to customize your own time series model translator, you can easily use the transform +package for your data preprocess. See [examples](src/test/java/ai/djl/timeseries/translator/DeepARTranslatorTest.java) for more details. 
-We plan to add the following features in the future: +## Simple Example + +To demonstrate how to use the timeseries package, we trained a DeepAR model on a simple dataset +and used it for prediction. This dataset contains monthly air passenger numbers from 1949 to 1960. +We will train on the first 9 years of data and predict the last 36 months of data. + +### Define Data + +In order to realize the preprocessing of time series data, we define the `TimeSeriesData` as the +input of the Translator, which is used to store the feature fields and perform corresponding +transformations. + +Here we define how to get `TimeSeriesData` from the dataset. + + +```java +public static class AirPassengers { + + private Path path; + private AirPassengerData data; + + public AirPassengers(Path path) { + this.path = path; + prepare(); + } + + public TimeSeriesData get(NDManager manager) { + LocalDateTime start = + data.start.toInstant().atZone(ZoneId.systemDefault()).toLocalDateTime(); + NDArray target = manager.create(data.target); + TimeSeriesData ret = new TimeSeriesData(10); + // A TimeSeriesData must contain start time and target value. + ret.setStartTime(start); + ret.setField(FieldName.TARGET, target); + return ret; + } + + /** prepare the file data */ + private void prepare() { + Path filePath = path.resolve("test").resolve("data.json.gz"); + try { + URL url = filePath.toUri().toURL(); + try (GZIPInputStream is = new GZIPInputStream(url.openStream())) { + Reader reader = new InputStreamReader(is); + data = + new GsonBuilder() + .setDateFormat("yyyy-MM") + .create() + .fromJson(reader, AirPassengerData.class); + } + } catch (IOException e) { + throw new IllegalArgumentException("Invalid url: " + filePath, e); + } + } + + private static class AirPassengerData { + Date start; + float[] target; + } +} +``` + +### Predict + +In djl we need to define `Translator` to help us with data pre- and post-processing. 
+ +```java +public static float[] predict() throws IOException, TranslateException, ModelException { + Map arguments = new ConcurrentHashMap<>(); + // set parameter + arguments.put("prediction_length", 12); + arguments.put("freq", "M"); + arguments.put("use_" + FieldName.FEAT_DYNAMIC_REAL.name().toLowerCase(), false); + arguments.put("use_" + FieldName.FEAT_STATIC_CAT.name().toLowerCase(), false); + arguments.put("use_" + FieldName.FEAT_STATIC_REAL.name().toLowerCase(), false); + + // build translator + DeepARTranslator translator = DeepARTranslator.builder(arguments).build(); + + // create criteria + Criteria criteria = + Criteria.builder() + .setTypes(TimeSeriesData.class, Forecast.class) + .optModelPath(Paths.get(modelUrl)) + .optTranslator(translator) + .optProgress(new ProgressBar()) + .build(); + + // load model + try (ZooModel model = criteria.loadModel(); + Predictor predictor = model.newPredictor()) { + NDManager manager = model.getNDManager(); + + AirPassengers ap = new AirPassengers(Paths.get("Not implemented")); + TimeSeriesData data = ap.get(manager); + + // prediction + Forecast forecast = predictor.predict(data); + + return forecast.mean().toFloatArray(); + } +} + +``` +### Visualize + +![simple_forecast](https://resources.djl.ai/images/timeseries/simple_forecast.png) + +Note that the prediction results are displayed in the form of probability distributions, and the +shaded areas represent different prediction intervals. + +Since djl doesn't support drawing yet, you can find our script for visualization +[here](https://gist.github.com/Carkham/a5162c9298bc51fec648a458a3437008). -- a `TimeSeriesDataset`class to support creating data entry and transforming raw csv data like in TimeSeries. -- Many time series models that can be trained in djl. -- ...... 
+The **full source code** for this example is available +[here](../../examples/src/main/java/ai/djl/examples/inference/timeseries/AirPassengersDeepAR.java) ## Documentation diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/block/FeatureEmbedder.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/block/FeatureEmbedder.java new file mode 100644 index 00000000000..21b2efed0fe --- /dev/null +++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/block/FeatureEmbedder.java @@ -0,0 +1,172 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ + +package ai.djl.timeseries.block; + +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDArrays; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.NDManager; +import ai.djl.ndarray.types.DataType; +import ai.djl.ndarray.types.Shape; +import ai.djl.nn.AbstractBlock; +import ai.djl.training.ParameterStore; +import ai.djl.util.PairList; + +import java.util.ArrayList; +import java.util.List; + +/** Embed a sequence of categorical features. 
*/ +public class FeatureEmbedder extends AbstractBlock { + + private List cardinalities; + private List embeddingDims; + private List embedders; + private int numFeatures; + + FeatureEmbedder(Builder builder) { + cardinalities = builder.cardinalities; + embeddingDims = builder.embeddingDims; + numFeatures = cardinalities.size(); + embedders = new ArrayList<>(); + for (int i = 0; i < cardinalities.size(); i++) { + embedders.add(createEmbedding(i, cardinalities.get(i), embeddingDims.get(i))); + } + } + + /** {@inheritDoc} */ + @Override + protected NDList forwardInternal( + ParameterStore parameterStore, + NDList inputs, + boolean training, + PairList params) { + // Categorical features with shape: (N,T,C) or (N,C), where C is the number of categorical + // features. + NDArray features = inputs.singletonOrThrow(); + + NDList catFeatureSlices; + if (numFeatures > 1) { + // slice the last dimension, giving an array of length numFeatures with shape (N,T) or + // (N) + catFeatureSlices = features.split(numFeatures, features.getShape().dimension() - 1); + } else { + catFeatureSlices = new NDList(features); + } + + NDList output = new NDList(); + for (int i = 0; i < numFeatures; i++) { + FeatureEmbedding embed = embedders.get(i); + NDArray catFeatureSlice = catFeatureSlices.get(i); + catFeatureSlice = catFeatureSlice.squeeze(-1); + output.add( + embed.forward(parameterStore, new NDList(catFeatureSlice), training, params) + .singletonOrThrow()); + } + return new NDList(NDArrays.concat(output, -1)); + } + + /** {@inheritDoc} */ + @Override + public Shape[] getOutputShapes(Shape[] inputShapes) { + Shape inputShape = inputShapes[0]; + Shape[] embedInputShapes = {inputShape.slice(0, inputShape.dimension() - 1)}; + long embedSizes = 0; + for (FeatureEmbedding embed : embedders) { + embedSizes += embed.getOutputShapes(embedInputShapes)[0].tail(); + } + return new Shape[] {inputShape.slice(0, inputShape.dimension() - 1).add(embedSizes)}; + } + + /** {@inheritDoc} */ + @Override + 
protected void initializeChildBlocks( + NDManager manager, DataType dataType, Shape... inputShapes) { + for (FeatureEmbedding embed : embedders) { + embed.initialize(manager, dataType, inputShapes); + } + } + + private FeatureEmbedding createEmbedding(int i, int c, int d) { + FeatureEmbedding embedding = + FeatureEmbedding.builder().setNumEmbeddings(c).setEmbeddingSize(d).build(); + addChildBlock(String.format("cat_%d_embedding", i), embedding); + return embedding; + } + + /** + * Return a builder to build an {@code FeatureEmbedder}. + * + * @return a new builder + */ + public static Builder builder() { + return new Builder(); + } + + /** The builder to construct a {@link FeatureEmbedder} type of {@link ai.djl.nn.Block}. */ + public static final class Builder { + + private List cardinalities; + private List embeddingDims; + + /** + * Set the cardinality for each categorical feature. + * + * @param cardinalities the cardinality for each categorical feature + * @return this Builder + */ + public Builder setCardinalities(List cardinalities) { + this.cardinalities = cardinalities; + return this; + } + + /** + * Set the number of dimensions to embed each categorical feature. + * + * @param embeddingDims number of dimensions to embed each categorical feature + * @return this Builder + */ + public Builder setEmbeddingDims(List embeddingDims) { + this.embeddingDims = embeddingDims; + return this; + } + + /** + * Return the constructed {@code FeatureEmbedder}. 
+ * + * @return the constructed {@code FeatureEmbedder} + */ + public FeatureEmbedder build() { + if (cardinalities.isEmpty()) { + throw new IllegalArgumentException( + "Length of 'cardinalities' list must be greater than zero"); + } + if (cardinalities.size() != embeddingDims.size()) { + throw new IllegalArgumentException( + "Length of `cardinalities` and `embedding_dims` should match"); + } + for (int c : cardinalities) { + if (c <= 0) { + throw new IllegalArgumentException("Elements of `cardinalities` should be > 0"); + } + } + for (int d : embeddingDims) { + if (d <= 0) { + throw new IllegalArgumentException( + "Elements of `embedding_dims` should be > 0"); + } + } + return new FeatureEmbedder(this); + } + } +} diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/block/FeatureEmbedding.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/block/FeatureEmbedding.java new file mode 100644 index 00000000000..edc4b580573 --- /dev/null +++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/block/FeatureEmbedding.java @@ -0,0 +1,121 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. 
+ */ + +package ai.djl.timeseries.block; + +import ai.djl.Device; +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.types.Shape; +import ai.djl.ndarray.types.SparseFormat; +import ai.djl.nn.AbstractBlock; +import ai.djl.nn.Parameter; +import ai.djl.nn.core.Embedding; +import ai.djl.training.ParameterStore; +import ai.djl.util.PairList; + +/** An implement of nn.embedding. */ +public final class FeatureEmbedding extends AbstractBlock { + + private static final String EMBEDDING_PARAM_NAME = "embedding"; + + private int embeddingSize; + private int numEmbeddings; + + private Parameter embedding; + + FeatureEmbedding(Builder builder) { + embeddingSize = builder.embeddingSize; + numEmbeddings = builder.numEmbeddings; + embedding = + addParameter( + Parameter.builder() + .setName(EMBEDDING_PARAM_NAME) + .setType(Parameter.Type.WEIGHT) + .optShape(new Shape(numEmbeddings, embeddingSize)) + .build()); + } + + /** {@inheritDoc} */ + @Override + protected NDList forwardInternal( + ParameterStore parameterStore, + NDList inputs, + boolean training, + PairList params) { + NDArray input = inputs.singletonOrThrow(); + Device device = input.getDevice(); + NDArray weight = parameterStore.getValue(embedding, device, training); + return Embedding.embedding(input, weight, SparseFormat.DENSE); + } + + /** {@inheritDoc} */ + @Override + public Shape[] getOutputShapes(Shape[] inputShapes) { + return new Shape[] {inputShapes[0].addAll(new Shape(embeddingSize))}; + } + + /** + * Return a builder to build an {@code FeatureEmbedding}. + * + * @return a new builder + */ + public static Builder builder() { + return new Builder(); + } + + /** The builder to construct a {@link FeatureEmbedding} type of {@link ai.djl.nn.Block}. */ + public static final class Builder { + + private int embeddingSize; + private int numEmbeddings; + + /** + * Sets the size of the embeddings. 
+ * + * @param embeddingSize the size of the embeddings + * @return this Builder + */ + public Builder setEmbeddingSize(int embeddingSize) { + this.embeddingSize = embeddingSize; + return this; + } + + /** + * Sets the size of the dictionary of embeddings. + * + * @param numEmbeddings the size of the dictionary of embeddings + * @return this Builder + */ + public Builder setNumEmbeddings(int numEmbeddings) { + this.numEmbeddings = numEmbeddings; + return this; + } + + /** + * Return the constructed {@code FeatureEmbedding}. + * + * @return the constructed {@code FeatureEmbedding} + */ + public FeatureEmbedding build() { + if (numEmbeddings <= 0) { + throw new IllegalArgumentException( + "You must specify the dictionary Size for the embedding."); + } + if (embeddingSize == 0) { + throw new IllegalArgumentException("You must specify the embedding size"); + } + return new FeatureEmbedding(this); + } + } +} diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/block/MeanScaler.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/block/MeanScaler.java new file mode 100644 index 00000000000..dab1367564c --- /dev/null +++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/block/MeanScaler.java @@ -0,0 +1,110 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. 
+ */ + +package ai.djl.timeseries.block; + +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDArrays; +import ai.djl.ndarray.NDList; +import ai.djl.training.ParameterStore; +import ai.djl.util.PairList; + +/** + * A class computes a scaling factor as the weighted average absolute value along dimension {@code + * dim}, and scales the data accordingly. + */ +public class MeanScaler extends Scaler { + + private float minimumScale; + + MeanScaler(Builder builder) { + super(builder); + minimumScale = builder.minimumScale; + } + + /** {@inheritDoc} */ + @Override + protected NDList forwardInternal( + ParameterStore parameterStore, + NDList inputs, + boolean training, + PairList params) { + NDArray data = inputs.get(0); + NDArray weights = inputs.get(1); + + NDArray totalWeight = weights.sum(new int[] {dim}); + NDArray weightedSum = data.abs().mul(weights).sum(new int[] {dim}); + + NDArray totalObserved = totalWeight.sum(new int[] {0}); + NDArray denominator = NDArrays.maximum(totalObserved, 1f); + NDArray defaultScale = weightedSum.sum(new int[] {0}).div(denominator); + + denominator = NDArrays.maximum(totalWeight, 1f); + NDArray scale = weightedSum.div(denominator); + + scale = + NDArrays.maximum( + minimumScale, + NDArrays.where( + weightedSum.gt(weightedSum.zerosLike()), + scale, + defaultScale.mul(totalWeight.onesLike()))) + .expandDims(dim); + + return new NDList(data.div(scale), keepDim ? scale : scale.squeeze(dim)); + } + + /** + * Create a builder to build a {@code MeanScaler}. + * + * @return a new builder + */ + public static Builder builder() { + return new Builder(); + } + + /** The builder to construct a {@code MeanScaler}. */ + public static final class Builder extends ScalerBuilder { + + private float minimumScale = 1e-10f; + + Builder() {} + + /** + * Sets the minimum scalar of the data. 
+ * + * @param minimumScale the minimum value + * @return this Builder + */ + public Builder optMinimumScale(float minimumScale) { + this.minimumScale = minimumScale; + return this; + } + + /** {@inheritDoc} */ + @Override + protected Builder self() { + return this; + } + + /** + * Return the constructed {@code MeanScaler}. + * + * @return the constructed {@code MeanScaler} + */ + public MeanScaler build() { + validate(); + return new MeanScaler(this); + } + } +} diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/block/NopScaler.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/block/NopScaler.java new file mode 100644 index 00000000000..4fda9c52326 --- /dev/null +++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/block/NopScaler.java @@ -0,0 +1,71 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ + +package ai.djl.timeseries.block; + +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.training.ParameterStore; +import ai.djl.util.PairList; + +/** + * A class assigns a scaling factor equal to 1 along dimension {@code dim}, and therefore applies no + * scaling to the input data. 
+ */ +public class NopScaler extends Scaler { + + NopScaler(Builder builder) { + super(builder); + } + + /** {@inheritDoc} */ + @Override + protected NDList forwardInternal( + ParameterStore parameterStore, + NDList inputs, + boolean training, + PairList params) { + NDArray data = inputs.get(0); + NDArray scale = data.onesLike().mean(new int[] {dim}, keepDim); + return new NDList(data, scale); + } + + /** + * Create a builder to build a {@code NopScaler}. + * + * @return a new builder + */ + public static Builder builder() { + return new Builder(); + } + + /** The builder to construct a {@code NopScaler}. */ + public static final class Builder extends ScalerBuilder { + + /** {@inheritDoc} */ + @Override + protected Builder self() { + return this; + } + + /** + * Return constructed {@code NOPScaler}. + * + * @return the constructed {@code NOPScaler} + */ + public NopScaler build() { + validate(); + return new NopScaler(this); + } + } +} diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/block/Scaler.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/block/Scaler.java new file mode 100644 index 00000000000..e9175a1ee2b --- /dev/null +++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/block/Scaler.java @@ -0,0 +1,96 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. 
+ */ + +package ai.djl.timeseries.block; + +import ai.djl.ndarray.types.Shape; +import ai.djl.nn.AbstractBlock; +import ai.djl.util.Preconditions; + +/** An abstract class used to scale data. */ +public abstract class Scaler extends AbstractBlock { + + private static final byte VERSION = 1; + + protected int dim; + protected boolean keepDim; + + Scaler(ScalerBuilder builder) { + super(VERSION); + dim = builder.dim; + keepDim = builder.keepDim; + } + + /** {@inheritDoc} */ + @Override + public Shape[] getOutputShapes(Shape[] inputShapes) { + Shape inputShape = inputShapes[0]; + Shape outputShape = new Shape(); + for (int i = 0; i < inputShape.dimension(); i++) { + if (i != dim) { + outputShape = outputShape.add(inputShape.get(i)); + } else { + if (keepDim) { + outputShape = outputShape.add(1L); + } + } + } + return new Shape[] {inputShape, outputShape}; + } + + /** + * A builder to extend for all classes extend the {@link Scaler}. + * + * @param the concrete builder type + */ + public abstract static class ScalerBuilder> { + + protected int dim; + protected boolean keepDim; + + /** + * Set the dim to scale. + * + * @param dim which dim to scale + * @return this Builder + */ + public T setDim(int dim) { + this.dim = dim; + return self(); + } + + /** + * Set whether to keep dim. Defaults to false; + * + * @param keepDim whether to keep dim + * @return this Builder + */ + public T optKeepDim(boolean keepDim) { + this.keepDim = keepDim; + return self(); + } + + /** + * Validates that the required arguments are set. 
+ * + * @throws IllegalArgumentException if the required arguments are illegal + */ + protected void validate() { + Preconditions.checkArgument( + dim > 0, + "Cannot compute scale along dim = 0 (batch dimension), please provide dim > 0"); + } + + protected abstract T self(); + } +} diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/block/package-info.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/block/package-info.java new file mode 100644 index 00000000000..dc63372687d --- /dev/null +++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/block/package-info.java @@ -0,0 +1,15 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ + +/** Contains the basic block classes. 
*/ +package ai.djl.timeseries.block; diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/dataset/M5Forecast.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/dataset/M5Forecast.java index a379f76a07d..fa2f6546aaf 100644 --- a/extensions/timeseries/src/main/java/ai/djl/timeseries/dataset/M5Forecast.java +++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/dataset/M5Forecast.java @@ -264,7 +264,8 @@ private void parseFeatures() { Reader reader = new InputStreamReader(is, StandardCharsets.UTF_8)) { mf = JsonUtils.GSON.fromJson(reader, M5Features.class); } catch (IOException e) { - throw new AssertionError("Failed to read m5forecast.json from classpath", e); + throw new AssertionError( + "Failed to read m5forecast_parser.json from classpath", e); } } } diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/dataset/TimeSeriesDataset.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/dataset/TimeSeriesDataset.java index a963800f90a..bec33de59d9 100644 --- a/extensions/timeseries/src/main/java/ai/djl/timeseries/dataset/TimeSeriesDataset.java +++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/dataset/TimeSeriesDataset.java @@ -54,17 +54,27 @@ public Record get(NDManager manager, long index) { TimeSeriesData data = getTimeSeriesData(manager, index); if (transformation.isEmpty()) { - // For inference + // For inference with translator return new Record(data.toNDList(), new NDList()); } data = apply(manager, data); - if (!data.contains("PAST_" + FieldName.TARGET) - || !data.contains("FUTURE_" + FieldName.TARGET)) { + + // For both training and prediction + if (!data.contains("PAST_" + FieldName.TARGET)) { throw new IllegalArgumentException( "Transformation must include InstanceSampler to split data into past and future" + " part"); } + if (!data.contains("FUTURE_" + FieldName.TARGET)) { + // Warning: We do not recommend using TimeSeriesDataset directly to generate the + // inference input, using Translator 
instead + // For prediction without translator, we don't need labels and corresponding + // FUTURE_TARGET. + return new Record(data.toNDList(), new NDList()); + } + + // For training, we must have the FUTURE_TARGET label to compute Loss. NDArray contextTarget = data.get("PAST_" + FieldName.TARGET).get("{}:", -contextLength + 1); NDArray futureTarget = data.get("FUTURE_" + FieldName.TARGET); NDList label = new NDList(contextTarget.concat(futureTarget, 0)); diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/NegativeBinomial.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/NegativeBinomial.java index 8565babc244..179f07818d5 100644 --- a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/NegativeBinomial.java +++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/NegativeBinomial.java @@ -22,49 +22,55 @@ * *

The distribution of the number of successes in a sequence of independent Bernoulli trials. * - *

Two arguments for this distribution. {@code mu} mean of the distribution, {@code alpha} the - * inverse number of negative Bernoulli trials to stop + *

Two arguments for this distribution. {@code total_count} non-negative number of negative + * Bernoulli trials to stop, {@code logits} Event log-odds for probabilities of success */ public final class NegativeBinomial extends Distribution { - private NDArray mu; - private NDArray alpha; + private NDArray totalCount; + private NDArray logits; NegativeBinomial(Builder builder) { - mu = builder.distrArgs.get("mu"); - alpha = builder.distrArgs.get("alpha"); + totalCount = builder.distrArgs.get("total_count"); + logits = builder.distrArgs.get("logits"); } /** {@inheritDoc} */ @Override public NDArray logProb(NDArray target) { + NDArray logUnnormalizedProb = + totalCount.mul(logSigmoid(logits.mul(-1))).add(target.mul(logSigmoid(logits))); - NDArray alphaInv = alpha.getNDArrayInternal().rdiv(1); - NDArray alphaTimesMu = alpha.mul(mu); - - return target.mul(alphaTimesMu.div(alphaTimesMu.add(1)).log()) - .sub(alphaInv.mul(alphaTimesMu.add(1).log())) - .add(target.add(alphaInv).gammaln()) - .sub(target.add(1.).gammaln()) - .sub(alphaInv.gammaln()); + NDArray logNormalization = + totalCount + .add(target) + .gammaln() + .mul(-1) + .add(target.add(1).gammaln()) + .add(totalCount.gammaln()); + return logUnnormalizedProb.sub(logNormalization); } /** {@inheritDoc} */ @Override public NDArray sample(int numSamples) { - NDManager manager = mu.getManager(); - NDArray expandedMu = numSamples > 0 ? mu.expandDims(0).repeat(0, numSamples) : mu; - NDArray expandedAlpha = numSamples > 0 ? alpha.expandDims(0).repeat(0, numSamples) : alpha; + NDManager manager = totalCount.getManager(); + NDArray expandedTotalCount = + numSamples > 0 ? totalCount.expandDims(0).repeat(0, numSamples) : totalCount; + NDArray expandedLogits = + numSamples > 0 ? 
logits.expandDims(0).repeat(0, numSamples) : logits; - NDArray r = expandedAlpha.getNDArrayInternal().rdiv(1f); - NDArray theta = expandedAlpha.mul(expandedMu); - return manager.samplePoisson(manager.sampleGamma(r, theta)); + return manager.samplePoisson(manager.sampleGamma(expandedTotalCount, expandedLogits.exp())); } /** {@inheritDoc} */ @Override public NDArray mean() { - return mu; + return totalCount.mul(logits.exp()); + } + + private NDArray logSigmoid(NDArray x) { + return x.mul(-1).exp().add(1).getNDArrayInternal().rdiv(1).log(); } /** @@ -83,17 +89,18 @@ public static final class Builder extends DistributionBuilder { @Override public Distribution build() { Preconditions.checkArgument( - distrArgs.contains("mu"), "NegativeBinomial's args must contain mu."); + distrArgs.contains("total_count"), + "NegativeBinomial's args must contain total_count."); Preconditions.checkArgument( - distrArgs.contains("alpha"), "NegativeBinomial's args must contain alpha."); + distrArgs.contains("logits"), "NegativeBinomial's args must contain logits."); // We cannot scale using the affine transformation since negative binomial should return // integers. Instead we scale the parameters. 
if (scale != null) { - NDArray mu = distrArgs.get("mu"); - mu = mu.mul(scale); - mu.setName("mu"); - distrArgs.remove("mu"); - distrArgs.add(mu); + NDArray logits = distrArgs.get("logits"); + // NDArray.add is not in-place; reassign so the scale adjustment is kept + logits = logits.add(scale.log()); + logits.setName("logits"); + distrArgs.remove("logits"); + distrArgs.add(logits); } return new NegativeBinomial(this); } diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/NegativeBinomialOutput.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/NegativeBinomialOutput.java index 84222fe9741..9d3bd44ea1b 100644 --- a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/NegativeBinomialOutput.java +++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/NegativeBinomialOutput.java @@ -25,24 +25,26 @@ */ public final class NegativeBinomialOutput extends DistributionOutput { - /** Construct a negative binomial output with two arguments, {@code mu} and {@code alpha}. */ + /** + * Construct a negative binomial output with two arguments, {@code total_count} and {@code + * logits}. 
+ */ public NegativeBinomialOutput() { argsDim = new PairList<>(2); - argsDim.add("mu", 1); - argsDim.add("alpha", 1); + argsDim.add("total_count", 1); + argsDim.add("logits", 1); } /** {@inheritDoc} */ @Override public NDList domainMap(NDList arrays) { - NDArray mu = arrays.get(0); - NDArray alpha = arrays.get(1); - mu = mu.getNDArrayInternal().softPlus().maximum(Float.MIN_VALUE).squeeze(-1); - alpha = alpha.getNDArrayInternal().softPlus().maximum(Float.MIN_VALUE).squeeze(-1); - // TODO: make setName() must be implemented - mu.setName("mu"); - alpha.setName("alpha"); - return new NDList(mu, alpha); + NDArray totalCount = arrays.get(0); + NDArray logits = arrays.get(1); + totalCount = totalCount.getNDArrayInternal().softPlus().squeeze(-1); + logits = logits.squeeze(-1); + totalCount.setName("total_count"); + logits.setName("logits"); + return new NDList(totalCount, logits); } /** {@inheritDoc} */ diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/evaluator/Rmsse.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/evaluator/Rmsse.java new file mode 100644 index 00000000000..5b642285c3e --- /dev/null +++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/evaluator/Rmsse.java @@ -0,0 +1,125 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. 
+ */ + +package ai.djl.timeseries.evaluator; + +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDArrays; +import ai.djl.ndarray.NDList; +import ai.djl.timeseries.distribution.output.DistributionOutput; +import ai.djl.training.evaluator.Evaluator; +import ai.djl.util.Pair; + +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +/** A class used to calculate Root Mean Squared Scaled Error. */ +public class Rmsse extends Evaluator { + + private DistributionOutput distributionOutput; + private int axis; + private Map totalLoss; + + /** + * Creates an evaluator that computes Root Mean Squared Scaled Error across axis 1. + * + *

Please refer to https://www.kaggle.com/competitions/m5-forecasting-accuracy/overview/evaluation + * for more details. + * + * @param distributionOutput the {@link DistributionOutput} to construct the target distribution + */ + public Rmsse(DistributionOutput distributionOutput) { + this("RMSSE", 1, distributionOutput); + } + + /** + * Creates an evaluator that computes Root Mean Squared Scaled Error across axis 1. + * + *

Please referring https://www.kaggle.com/competitions/m5-forecasting-accuracy/overview/evaluation + * for more details. + * + * @param name the name of the evaluator, default is "RMSSE" + * @param axis the axis that represent time length in prediction, default 1 + * @param distributionOutput the {@link DistributionOutput} to construct the target distribution + */ + public Rmsse(String name, int axis, DistributionOutput distributionOutput) { + super(name); + this.axis = axis; + this.distributionOutput = distributionOutput; + totalLoss = new ConcurrentHashMap<>(); + } + + protected Pair evaluateHelper(NDList labels, NDList predictions) { + NDArray label = labels.head(); + NDArray prediction = + distributionOutput.distributionBuilder().setDistrArgs(predictions).build().mean(); + + checkLabelShapes(label, prediction); + NDArray meanSquare = label.sub(prediction).square().mean(new int[] {axis}); + NDArray scaleDenom = + label.get(":, 1:").sub(label.get(":, :-1")).square().mean(new int[] {axis}); + + NDArray rmsse = meanSquare.div(scaleDenom).sqrt(); + rmsse = NDArrays.where(scaleDenom.eq(0), rmsse.onesLike(), rmsse); + long total = rmsse.countNonzero().getLong(); + + return new Pair<>(total, rmsse); + } + + /** {@inheritDoc} */ + @Override + public NDArray evaluate(NDList labels, NDList predictions) { + return evaluateHelper(labels, predictions).getValue(); + } + + /** {@inheritDoc} */ + @Override + public void addAccumulator(String key) { + totalInstances.put(key, 0L); + totalLoss.put(key, 0f); + } + + /** {@inheritDoc} */ + @Override + public void updateAccumulator(String key, NDList labels, NDList predictions) { + Pair update = evaluateHelper(labels, predictions); + totalInstances.compute(key, (k, v) -> v + update.getKey()); + totalLoss.compute( + key, + (k, v) -> { + try (NDArray array = update.getValue().sum()) { + return v + array.getFloat(); + } + }); + } + + /** {@inheritDoc} */ + @Override + public void resetAccumulator(String key) { + 
totalInstances.compute(key, (k, v) -> 0L); + totalLoss.compute(key, (k, v) -> 0f); + } + + /** {@inheritDoc} */ + @Override + public float getAccumulator(String key) { + Long total = totalInstances.get(key); + if (total == null || total == 0) { + return Float.NaN; + } + + return (float) totalLoss.get(key) / totalInstances.get(key); + } +} diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/evaluator/package-info.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/evaluator/package-info.java new file mode 100644 index 00000000000..584359892cc --- /dev/null +++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/evaluator/package-info.java @@ -0,0 +1,15 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ + +/** Contains the evaluator classes. */ +package ai.djl.timeseries.evaluator; diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/model/deepar/DeepARNetwork.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/model/deepar/DeepARNetwork.java new file mode 100644 index 00000000000..4a60b8dc0a0 --- /dev/null +++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/model/deepar/DeepARNetwork.java @@ -0,0 +1,616 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. 
A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ + +package ai.djl.timeseries.model.deepar; + +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDArrays; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.NDManager; +import ai.djl.ndarray.types.DataType; +import ai.djl.ndarray.types.Shape; +import ai.djl.nn.AbstractBlock; +import ai.djl.nn.Block; +import ai.djl.nn.recurrent.LSTM; +import ai.djl.timeseries.block.FeatureEmbedder; +import ai.djl.timeseries.block.MeanScaler; +import ai.djl.timeseries.block.NopScaler; +import ai.djl.timeseries.block.Scaler; +import ai.djl.timeseries.dataset.FieldName; +import ai.djl.timeseries.distribution.output.DistributionOutput; +import ai.djl.timeseries.distribution.output.StudentTOutput; +import ai.djl.timeseries.timefeature.Lag; +import ai.djl.timeseries.timefeature.TimeFeature; +import ai.djl.timeseries.transform.ExpectedNumInstanceSampler; +import ai.djl.timeseries.transform.InstanceSampler; +import ai.djl.timeseries.transform.PredictionSplitSampler; +import ai.djl.timeseries.transform.TimeSeriesTransform; +import ai.djl.timeseries.transform.convert.VstackFeatures; +import ai.djl.timeseries.transform.feature.AddAgeFeature; +import ai.djl.timeseries.transform.feature.AddObservedValuesIndicator; +import ai.djl.timeseries.transform.feature.AddTimeFeature; +import ai.djl.timeseries.transform.field.RemoveFields; +import ai.djl.timeseries.transform.field.SelectField; +import ai.djl.timeseries.transform.field.SetField; +import ai.djl.timeseries.transform.split.InstanceSplit; +import ai.djl.training.ParameterStore; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; +import 
java.util.List; + +/** + * Implements the DeepAR model. + * + *

This closely follows the Salinas et al. 2020 and + * its gluonts implementation. + */ +public abstract class DeepARNetwork extends AbstractBlock { + + private static final String[] TRAIN_INPUT_FIELDS = { + FieldName.FEAT_STATIC_CAT.name(), + FieldName.FEAT_STATIC_REAL.name(), + "PAST_" + FieldName.FEAT_TIME.name(), + "PAST_" + FieldName.TARGET.name(), + "PAST_" + FieldName.OBSERVED_VALUES.name(), + "PAST_" + FieldName.IS_PAD.name(), + "FUTURE_" + FieldName.FEAT_TIME.name(), + "FUTURE_" + FieldName.TARGET.name(), + "FUTURE_" + FieldName.OBSERVED_VALUES.name() + }; + + private static final String[] PRED_INPUT_FIELDS = { + FieldName.FEAT_STATIC_CAT.name(), + FieldName.FEAT_STATIC_REAL.name(), + "PAST_" + FieldName.FEAT_TIME.name(), + "PAST_" + FieldName.TARGET.name(), + "PAST_" + FieldName.OBSERVED_VALUES.name(), + "FUTURE_" + FieldName.FEAT_TIME.name(), + "PAST_" + FieldName.IS_PAD.name() + }; + + protected String freq; + protected int historyLength; + protected int contextLength; + protected int predictionLength; + + protected boolean useFeatDynamicReal; + protected boolean useFeatStaticCat; + protected boolean useFeatStaticReal; + + protected DistributionOutput distrOutput; + protected List cardinality; + protected List embeddingDimension; + protected List lagsSeq; + protected int numParallelSamples; + + protected FeatureEmbedder embedder; + protected Block paramProj; + + protected LSTM rnn; + protected Scaler scaler; + + DeepARNetwork(Builder builder) { + freq = builder.freq; + predictionLength = builder.predictionLength; + contextLength = builder.contextLength != 0 ? 
builder.contextLength : predictionLength; + distrOutput = builder.distrOutput; + cardinality = builder.cardinality; + + useFeatStaticReal = builder.useFeatStaticReal; + useFeatDynamicReal = builder.useFeatDynamicReal; + useFeatStaticCat = builder.useFeatStaticCat; + numParallelSamples = builder.numParallelSamples; + + paramProj = addChildBlock("param_proj", distrOutput.getArgsProj()); + if (builder.embeddingDimension != null || builder.cardinality == null) { + embeddingDimension = builder.embeddingDimension; + } else { + embeddingDimension = new ArrayList<>(); + for (int cat : cardinality) { + embeddingDimension.add(Math.min(50, (cat + 1) / 2)); + } + } + lagsSeq = builder.lagsSeq == null ? Lag.getLagsForFreq(builder.freq) : builder.lagsSeq; + historyLength = contextLength + lagsSeq.stream().max(Comparator.naturalOrder()).get(); + embedder = + addChildBlock( + "feature_embedder", + FeatureEmbedder.builder() + .setCardinalities(cardinality) + .setEmbeddingDims(embeddingDimension) + .build()); + if (builder.scaling) { + scaler = + addChildBlock( + "scaler", + MeanScaler.builder() + .setDim(1) + .optKeepDim(true) + .optMinimumScale(1e-10f) + .build()); + } else { + scaler = + addChildBlock("scaler", NopScaler.builder().setDim(1).optKeepDim(true).build()); + } + + rnn = + addChildBlock( + "rnn_lstm", + LSTM.builder() + .setNumLayers(builder.numLayers) + .setStateSize(builder.hiddenSize) + .optDropRate(builder.dropRate) + .optBatchFirst(true) + .optReturnState(true) + .build()); + } + + /** {@inheritDoc} */ + @Override + protected void initializeChildBlocks( + NDManager manager, DataType dataType, Shape... 
inputShapes) { + + Shape targetShape = inputShapes[3].slice(2); + Shape contextShape = new Shape(1, contextLength).addAll(targetShape); + scaler.initialize(manager, dataType, contextShape, contextShape); + long scaleSize = scaler.getOutputShapes(new Shape[] {contextShape, contextShape})[1].get(1); + + embedder.initialize(manager, dataType, inputShapes[0]); + long embeddedCatSize = embedder.getOutputShapes(new Shape[] {inputShapes[0]})[0].get(1); + + Shape inputShape = new Shape(1, contextLength * 2L - 1).addAll(targetShape); + Shape lagsShape = inputShape.add(lagsSeq.size()); + long featSize = inputShapes[2].get(2) + embeddedCatSize + inputShapes[1].get(1) + scaleSize; + Shape rnnInputShape = + lagsShape.slice(0, lagsShape.dimension() - 1).add(lagsShape.tail() + featSize); + rnn.initialize(manager, dataType, rnnInputShape); + + Shape rnnOutShape = rnn.getOutputShapes(new Shape[] {rnnInputShape})[0]; + paramProj.initialize(manager, dataType, rnnOutShape); + } + + /** + * Applies the underlying RNN to the provided target data and covariates. + * + * @param ps the parameter store + * @param inputs the input NDList + * @param training true for a training forward pass + * @return a {@link NDList} containing arguments of the output distribution, scaling factor, raw + * output of rnn, static input of rnn, output state of rnn + */ + protected NDList unrollLaggedRnn(ParameterStore ps, NDList inputs, boolean training) { + try (NDManager scope = inputs.getManager().newSubManager()) { + scope.tempAttachAll(inputs); + + NDArray featStaticCat = inputs.get(0); + NDArray featStaticReal = inputs.get(1); + NDArray pastTimeFeat = inputs.get(2); + NDArray pastTarget = inputs.get(3); + NDArray pastObservedValues = inputs.get(4); + NDArray futureTimeFeat = inputs.get(5); + NDArray futureTarget = inputs.size() > 6 ? 
inputs.get(6) : null; + + NDArray context = pastTarget.get(":,{}:", -contextLength); + NDArray observedContext = pastObservedValues.get(":,{}:", -contextLength); + NDArray scale = + scaler.forward(ps, new NDList(context, observedContext), training).get(1); + + NDArray priorSequence = pastTarget.get(":,:{}", -contextLength).div(scale); + NDArray sequence = + futureTarget != null + ? context.concat(futureTarget.get(":, :-1"), 1).div(scale) + : context.div(scale); + + NDArray embeddedCat = + embedder.forward(ps, new NDList(featStaticCat), training).singletonOrThrow(); + NDArray staticFeat = + NDArrays.concat(new NDList(embeddedCat, featStaticReal, scale.log()), 1); + NDArray expandedStaticFeat = + staticFeat.expandDims(1).repeat(1, sequence.getShape().get(1)); + + NDArray timeFeat = + futureTimeFeat != null + ? pastTimeFeat + .get(":, {}:", -contextLength + 1) + .concat(futureTimeFeat, 1) + : pastTimeFeat.get(":, {}:", -contextLength + 1); + + NDArray features = expandedStaticFeat.concat(timeFeat, -1); + NDArray lags = laggedSequenceValues(lagsSeq, priorSequence, sequence); + + NDArray rnnInput = lags.concat(features, -1); + + NDList outputs = rnn.forward(ps, new NDList(rnnInput), training); + NDArray output = outputs.get(0); + NDArray hiddenState = outputs.get(1); + NDArray cellState = outputs.get(2); + + NDList params = paramProj.forward(ps, new NDList(output), training); + + scale.setName("scale"); + output.setName("output"); + staticFeat.setName("static_feat"); + hiddenState.setName("hidden_state"); + cellState.setName("cell_state"); + return scope.ret( + params.addAll(new NDList(scale, output, staticFeat, hiddenState, cellState))); + } + } + + /** + * Construct an {@link NDArray} of lagged values from a given sequence. 
+ * + * @param indices indices of lagged observations + * @param priorSequence the input sequence prior to the time range for which the output is + * required + * @param sequence the input sequence in the time range where the output is required + * @return the lagged values + */ + protected NDArray laggedSequenceValues( + List indices, NDArray priorSequence, NDArray sequence) { + if (Collections.max(indices) > (int) priorSequence.getShape().get(1)) { + throw new IllegalArgumentException( + String.format( + "lags cannot go further than prior sequence length, found lag %d while" + + " prior sequence is only %d-long", + Collections.max(indices), priorSequence.getShape().get(1))); + } + try (NDManager scope = NDManager.subManagerOf(priorSequence)) { + scope.tempAttachAll(priorSequence, sequence); + NDArray fullSequence = priorSequence.concat(sequence, 1); + + NDList lagsValues = new NDList(indices.size()); + for (int lagIndex : indices) { + long begin = -lagIndex - sequence.getShape().get(1); + long end = -lagIndex; + lagsValues.add( + end < 0 + ? fullSequence.get(":, {}:{}", begin, end) + : fullSequence.get(":, {}:", begin)); + } + + NDArray lags = NDArrays.stack(lagsValues, -1); + return scope.ret(lags.reshape(lags.getShape().get(0), lags.getShape().get(1), -1)); + } + } + + /** + * Return the context length. + * + * @return the context length + */ + public int getContextLength() { + return contextLength; + } + + /** + * Return the history length. + * + * @return the history length + */ + public int getHistoryLength() { + return historyLength; + } + + /** + * Construct a training transformation of deepar model. 
+ * + * @param manager the {@link NDManager} to create value + * @return the transformation + */ + public List createTrainingTransformation(NDManager manager) { + List transformation = createTransformation(manager); + + InstanceSampler sampler = new ExpectedNumInstanceSampler(0, 0, predictionLength, 1.0); + transformation.add( + new InstanceSplit( + FieldName.TARGET, + FieldName.IS_PAD, + FieldName.START, + FieldName.FORECAST_START, + sampler, + historyLength, + predictionLength, + new FieldName[] {FieldName.FEAT_TIME, FieldName.OBSERVED_VALUES}, + distrOutput.getValueInSupport())); + + transformation.add(new SelectField(TRAIN_INPUT_FIELDS)); + return transformation; + } + + /** + * Construct a prediction transformation of deepar model. + * + * @param manager the {@link NDManager} to create value + * @return the transformation + */ + public List createPredictionTransformation(NDManager manager) { + List transformation = createTransformation(manager); + + InstanceSampler sampler = PredictionSplitSampler.newValidationSplitSampler(); + transformation.add( + new InstanceSplit( + FieldName.TARGET, + FieldName.IS_PAD, + FieldName.START, + FieldName.FORECAST_START, + sampler, + historyLength, + predictionLength, + new FieldName[] {FieldName.FEAT_TIME, FieldName.OBSERVED_VALUES}, + distrOutput.getValueInSupport())); + + transformation.add(new SelectField(PRED_INPUT_FIELDS)); + return transformation; + } + + private List createTransformation(NDManager manager) { + List transformation = new ArrayList<>(); + + List removeFieldNames = new ArrayList<>(); + removeFieldNames.add(FieldName.FEAT_DYNAMIC_CAT); + if (!useFeatStaticReal) { + removeFieldNames.add(FieldName.FEAT_STATIC_REAL); + } + if (!useFeatDynamicReal) { + removeFieldNames.add(FieldName.FEAT_DYNAMIC_REAL); + } + + transformation.add(new RemoveFields(removeFieldNames)); + if (!useFeatStaticCat) { + transformation.add( + new SetField(FieldName.FEAT_STATIC_CAT, manager.zeros(new Shape(1)))); + } + if 
(!useFeatDynamicReal) { + transformation.add( + new SetField(FieldName.FEAT_STATIC_REAL, manager.zeros(new Shape(1)))); + } + + transformation.add( + new AddObservedValuesIndicator(FieldName.TARGET, FieldName.OBSERVED_VALUES)); + + transformation.add( + new AddTimeFeature( + FieldName.START, + FieldName.TARGET, + FieldName.FEAT_TIME, + TimeFeature.timeFeaturesFromFreqStr(freq), + predictionLength, + freq)); + + transformation.add( + new AddAgeFeature(FieldName.TARGET, FieldName.FEAT_AGE, predictionLength, true)); + + FieldName[] inputFields; + if (!useFeatDynamicReal) { + inputFields = new FieldName[] {FieldName.FEAT_TIME, FieldName.FEAT_AGE}; + } else { + inputFields = + new FieldName[] { + FieldName.FEAT_TIME, FieldName.FEAT_AGE, FieldName.FEAT_DYNAMIC_REAL + }; + } + transformation.add(new VstackFeatures(FieldName.FEAT_TIME, inputFields)); + + return transformation; + } + + /** + * Create a builder to build a {@code DeepARTrainingNetwork} or {@code DeepARPredictionNetwork}. + * + * @return a new builder + */ + public static Builder builder() { + return new Builder(); + } + + /** + * The builder to construct a {@code DeepARTrainingNetwork} or {@code DeepARPredictionNetwork}. + * type of {@link ai.djl.nn.Block}. + */ + public static final class Builder { + + private String freq; + private int contextLength; + private int predictionLength; + private int numParallelSamples = 100; + private int numLayers = 2; + private int hiddenSize = 40; + private float dropRate = 0.1f; + + private boolean useFeatDynamicReal; + private boolean useFeatStaticCat; + private boolean useFeatStaticReal; + private boolean scaling = true; + + private DistributionOutput distrOutput = new StudentTOutput(); + private List cardinality; + private List embeddingDimension; + private List lagsSeq; + + /** + * Set the prediction frequency. 
+ * + * @param freq the frequency + * @return this builder + */ + public Builder setFreq(String freq) { + this.freq = freq; + return this; + } + + /** + * Set the prediction length. + * + * @param predictionLength the prediction length + * @return this builder + */ + public Builder setPredictionLength(int predictionLength) { + this.predictionLength = predictionLength; + return this; + } + + /** + * Set the cardinality for static categorical feature. + * + * @param cardinality the cardinality + * @return this builder + */ + public Builder setCardinality(List cardinality) { + this.cardinality = cardinality; + return this; + } + + /** + * Set the optional {@link DistributionOutput} default {@link StudentTOutput}. + * + * @param distrOutput the {@link DistributionOutput} + * @return this builder + */ + public Builder optDistrOutput(DistributionOutput distrOutput) { + this.distrOutput = distrOutput; + return this; + } + + /** + * Set the optional context length. + * + * @param contextLength the context length + * @return this builder + */ + public Builder optContextLength(int contextLength) { + this.contextLength = contextLength; + return this; + } + + /** + * Set the optional number parallel samples. + * + * @param numParallelSamples the num parallel samples + * @return this builder + */ + public Builder optNumParallelSamples(int numParallelSamples) { + this.numParallelSamples = numParallelSamples; + return this; + } + + /** + * Set the optional number of rnn layers. + * + * @param numLayers the number of rnn layers + * @return this builder + */ + public Builder optNumLayers(int numLayers) { + this.numLayers = numLayers; + return this; + } + + /** + * Set the optional number of rnn hidden size. + * + * @param hiddenSize the number of rnn hidden size + * @return this builder + */ + public Builder optHiddenSize(int hiddenSize) { + this.hiddenSize = hiddenSize; + return this; + } + + /** + * Set the optional number of rnn drop rate. 
+ * + * @param dropRate the number of rnn drop rate + * @return this builder + */ + public Builder optDropRate(float dropRate) { + this.dropRate = dropRate; + return this; + } + + /** + * Set the optional embedding dimension. + * + * @param embeddingDimension the embedding dimension + * @return this builder + */ + public Builder optEmbeddingDimension(List embeddingDimension) { + this.embeddingDimension = embeddingDimension; + return this; + } + + /** + * Set the optional lags sequence, default generate from frequency. + * + * @param lagsSeq the lags sequence + * @return this builder + */ + public Builder optLagsSeq(List lagsSeq) { + this.lagsSeq = lagsSeq; + return this; + } + + /** + * Set whether to use dynamic real feature. + * + * @param useFeatDynamicReal whether to use dynamic real feature + * @return this builder + */ + public Builder optUseFeatDynamicReal(boolean useFeatDynamicReal) { + this.useFeatDynamicReal = useFeatDynamicReal; + return this; + } + + /** + * Set whether to use static categorical feature. + * + * @param useFeatStaticCat whether to use static categorical feature + * @return this builder + */ + public Builder optUseFeatStaticCat(boolean useFeatStaticCat) { + this.useFeatStaticCat = useFeatStaticCat; + return this; + } + + /** + * Set whether to use static real feature. + * + * @param useFeatStaticReal whether to use static real feature + * @return this builder + */ + public Builder optUseFeatStaticReal(boolean useFeatStaticReal) { + this.useFeatStaticReal = useFeatStaticReal; + return this; + } + + /** + * Build a {@link DeepARTrainingNetwork} block. + * + * @return the {@link DeepARTrainingNetwork} block. + */ + public DeepARTrainingNetwork buildTrainingNetwork() { + return new DeepARTrainingNetwork(this); + } + + /** + * Build a {@link DeepARPredictionNetwork} block. + * + * @return the {@link DeepARPredictionNetwork} block. 
+ */ + public DeepARPredictionNetwork buildPredictionNetwork() { + return new DeepARPredictionNetwork(this); + } + } +} diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/model/deepar/DeepARPredictionNetwork.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/model/deepar/DeepARPredictionNetwork.java new file mode 100644 index 00000000000..02013c18718 --- /dev/null +++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/model/deepar/DeepARPredictionNetwork.java @@ -0,0 +1,117 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ + +package ai.djl.timeseries.model.deepar; + +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDArrays; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.types.Shape; +import ai.djl.timeseries.distribution.Distribution; +import ai.djl.training.ParameterStore; +import ai.djl.util.PairList; + +/** A deepar implements for prediction. 
*/ +public class DeepARPredictionNetwork extends DeepARNetwork { + + DeepARPredictionNetwork(Builder builder) { + super(builder); + } + + /** {@inheritDoc} */ + @Override + protected NDList forwardInternal( + ParameterStore parameterStore, + NDList inputs, + boolean training, + PairList params) { + NDList unrollInputs = + new NDList( + inputs.get(0), // feat_static_cat + inputs.get(1), // feat_static_real + inputs.get(2), // past_time_feat + inputs.get(3), // past_target + inputs.get(4), // past_observed_value + inputs.get(5).get(":, :1") // future_time_feat + ); + NDList unrollOutput = unrollLaggedRnn(parameterStore, unrollInputs, training); + NDList state = new NDList(unrollOutput.get("hidden_state"), unrollOutput.get("cell_state")); + String[] argNames = distrOutput.getArgsArray(); + + NDList repeatedArgs = new NDList(argNames.length); + for (String argName : distrOutput.getArgsArray()) { + NDArray repeatedArg = unrollOutput.get(argName).repeat(0, numParallelSamples); + repeatedArg.setName(argName); + repeatedArgs.add(repeatedArg); + } + + NDArray repeatedScale = unrollOutput.get("scale").repeat(0, numParallelSamples); + NDArray repeatedStaticFeat = + unrollOutput.get("static_feat").repeat(0, numParallelSamples).expandDims(1); + NDArray repeatedPastTarget = inputs.get(3).repeat(0, numParallelSamples).div(repeatedScale); + NDArray repeatedTimeFeat = inputs.get(5).repeat(0, numParallelSamples); + + NDList repeatedState = new NDList(state.size()); + for (NDArray s : state) { + repeatedState.add(s.repeat(1, numParallelSamples)); + } + + Distribution distr = outputDistribution(repeatedArgs, repeatedScale, 1); + NDArray nextSample = distr.sample(); + NDList futureSamples = new NDList(predictionLength); + futureSamples.add(nextSample); + for (int k = 1; k < predictionLength; k++) { + NDArray scaledNextSample = nextSample.div(repeatedScale); + NDArray nextFeatures = + repeatedStaticFeat.concat(repeatedTimeFeat.get(":, {}:{}", k, k + 1), -1); + NDArray nextLags = 
laggedSequenceValues(lagsSeq, repeatedPastTarget, scaledNextSample); + NDArray rnnInput = nextLags.concat(nextFeatures, -1); + + NDList outputs = + rnn.forward( + parameterStore, new NDList(rnnInput).addAll(repeatedState), training); + NDArray output = outputs.get(0); + repeatedState = outputs.subNDList(1); + + repeatedPastTarget = repeatedPastTarget.concat(scaledNextSample, 1); + + repeatedArgs = paramProj.forward(parameterStore, new NDList(output), training); + distr = outputDistribution(repeatedArgs, repeatedScale, 0); + nextSample = distr.sample(); + futureSamples.add(nextSample); + } + + NDArray futureSamplesConcat = NDArrays.concat(futureSamples, 1); + return new NDList(futureSamplesConcat.reshape(-1, numParallelSamples, predictionLength)); + } + + /** {@inheritDoc} */ + @Override + public Shape[] getOutputShapes(Shape[] inputShapes) { + long batchSize = inputShapes[0].head(); + return new Shape[] {new Shape(batchSize, numParallelSamples, predictionLength)}; + } + + private Distribution outputDistribution(NDList params, NDArray scale, int trailingN) { + NDList slicedParams = params; + if (trailingN > 0) { + slicedParams = new NDList(params.size()); + for (NDArray p : params) { + NDArray slicedP = p.get(":, {}:", -trailingN); + slicedP.setName(p.getName()); + slicedParams.add(slicedP); + } + } + return distrOutput.distributionBuilder().setDistrArgs(slicedParams).optScale(scale).build(); + } +} diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/model/deepar/DeepARTrainingNetwork.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/model/deepar/DeepARTrainingNetwork.java new file mode 100644 index 00000000000..88577953351 --- /dev/null +++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/model/deepar/DeepARTrainingNetwork.java @@ -0,0 +1,107 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). 
You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ + +package ai.djl.timeseries.model.deepar; + +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.types.Shape; +import ai.djl.training.ParameterStore; +import ai.djl.util.PairList; + +/** A deepar implements for training. */ +public final class DeepARTrainingNetwork extends DeepARNetwork { + + DeepARTrainingNetwork(Builder builder) { + super(builder); + } + + /** {@inheritDoc} */ + @Override + protected NDList forwardInternal( + ParameterStore parameterStore, + NDList inputs, + boolean training, + PairList params) { + NDArray featStaticCat = inputs.get(0); + NDArray featStaticReal = inputs.get(1); + NDArray pastTimeFeat = inputs.get(2); + NDArray pastTarget = inputs.get(3); + NDArray pastObservedValues = inputs.get(4); + // NDArray pastIsPad = inputs.get(5); + NDArray futureTimeFeat = inputs.get(6); + NDArray futureTarget = inputs.get(7); + NDArray futureObservedValues = inputs.get(8); + + NDList unrollOutput = + unrollLaggedRnn( + parameterStore, + new NDList( + featStaticCat, + featStaticReal, + pastTimeFeat, + pastTarget, + pastObservedValues, + futureTimeFeat, + futureTarget), + training); + + NDArray observedValues = + pastObservedValues + .get(":, {}:", -contextLength + 1) + .concat(futureObservedValues, 1); + observedValues.setName("loss_weights"); + + String[] argNames = distrOutput.getArgsArray(); + NDList ret = new NDList(argNames.length + 2); // args + scale + loss_weights + + for (String argName : argNames) { + ret.add(unrollOutput.get(argName)); + } + ret.add(unrollOutput.get("scale")); + 
ret.add(observedValues); + return ret; + } + + /** {@inheritDoc} */ + @Override + public Shape[] getOutputShapes(Shape[] inputShapes) { + Shape targetShape = inputShapes[3].slice(2); + Shape contextShape = new Shape(1, contextLength).addAll(targetShape); + Shape scaleShape = scaler.getOutputShapes(new Shape[] {contextShape, contextShape})[1]; + long scaleSize = scaleShape.get(1); + + long embeddedCatSize = embedder.getOutputShapes(new Shape[] {inputShapes[0]})[0].get(1); + + Shape inputShape = new Shape(1, contextLength * 2L - 1).addAll(targetShape); + Shape lagsShape = inputShape.add(lagsSeq.size()); + long featSize = inputShapes[2].get(2) + embeddedCatSize + inputShapes[1].get(1) + scaleSize; + Shape rnnInputShape = + lagsShape.slice(0, lagsShape.dimension() - 1).add(lagsShape.tail() + featSize); + + Shape rnnOutShape = rnn.getOutputShapes(new Shape[] {rnnInputShape})[0]; + Shape[] argShapes = paramProj.getOutputShapes(new Shape[] {rnnOutShape}); + + long[] observedValueShape = new long[inputShapes[8].dimension()]; + System.arraycopy( + inputShapes[8].getShape(), 0, observedValueShape, 0, observedValueShape.length); + observedValueShape[1] += contextLength - 1; + Shape lossWeightsShape = new Shape(observedValueShape); + + Shape[] ret = new Shape[argShapes.length + 2]; + System.arraycopy(argShapes, 0, ret, 0, argShapes.length); + ret[argShapes.length] = scaleShape; + ret[argShapes.length + 1] = lossWeightsShape; + return ret; + } +} diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/model/deepar/package-info.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/model/deepar/package-info.java new file mode 100644 index 00000000000..61a397b0f1e --- /dev/null +++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/model/deepar/package-info.java @@ -0,0 +1,15 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). 
You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ + +/** Contains blocks for deepar models. */ +package ai.djl.timeseries.model.deepar; diff --git a/extensions/timeseries/src/test/java/ai/djl/timeseries/block/BlockTest.java b/extensions/timeseries/src/test/java/ai/djl/timeseries/block/BlockTest.java new file mode 100644 index 00000000000..0457044845d --- /dev/null +++ b/extensions/timeseries/src/test/java/ai/djl/timeseries/block/BlockTest.java @@ -0,0 +1,185 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. 
+ */ + +package ai.djl.timeseries.block; + +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.NDManager; +import ai.djl.ndarray.types.DataType; +import ai.djl.ndarray.types.Shape; +import ai.djl.nn.Parameter; +import ai.djl.testing.Assertions; +import ai.djl.training.ParameterStore; +import ai.djl.training.initializer.Initializer; + +import org.testng.Assert; +import org.testng.annotations.Test; + +import java.util.Arrays; +import java.util.List; + +public class BlockTest { + + @Test + public void testFeature() { + int n = 10; + int t = 20; + + // single static feat + Shape inputShape = new Shape(n, 1); + List cardinalities = Arrays.asList(50); + List embeddingDims = Arrays.asList(10); + testFeatureEmbedder(inputShape, cardinalities, embeddingDims); + + // single dynamic feat + inputShape = new Shape(n, t, 1); + cardinalities = Arrays.asList(2); + testFeatureEmbedder(inputShape, cardinalities, embeddingDims); + + // multiple static feat + inputShape = new Shape(n, 4); + cardinalities = Arrays.asList(50, 50, 50, 50); + embeddingDims = Arrays.asList(10, 20, 30, 40); + testFeatureEmbedder(inputShape, cardinalities, embeddingDims); + + // multiple dynamic features + inputShape = new Shape(n, t, 3); + cardinalities = Arrays.asList(30, 30, 30); + embeddingDims = Arrays.asList(10, 20, 30); + testFeatureEmbedder(inputShape, cardinalities, embeddingDims); + } + + @Test + public void testScaler() { + try (NDManager manager = NDManager.newBaseManager()) { + ParameterStore ps = new ParameterStore(manager, false); + + Scaler scaler = MeanScaler.builder().setDim(1).build(); + NDArray target = manager.randomNormal(new Shape(5, 30)); + NDArray observed = manager.zeros(new Shape(5, 30)); + NDArray expScale = manager.full(new Shape(5), 1e-10f); + Assert.assertEquals( + scaler.getOutputShapes(new Shape[] {target.getShape(), observed.getShape()})[1], + expScale.getShape()); + assertOutput( + scaler.forward(ps, new NDList(target, observed), false), + 
target, + expScale, + 1, + false); + + scaler = MeanScaler.builder().setDim(2).optMinimumScale(1e-6f).build(); + target = manager.randomNormal(new Shape(5, 3, 30)); + observed = manager.zeros(new Shape(5, 3, 30)); + expScale = manager.full(new Shape(5, 3), 1e-6f); + Assert.assertEquals( + scaler.getOutputShapes(new Shape[] {target.getShape(), observed.getShape()})[1], + expScale.getShape()); + assertOutput( + scaler.forward(ps, new NDList(target, observed), false), + target, + expScale, + 2, + false); + + scaler = MeanScaler.builder().setDim(1).optKeepDim(true).build(); + target = manager.randomNormal(new Shape(5, 30, 1)); + observed = manager.zeros(new Shape(5, 30, 1)); + expScale = manager.full(new Shape(5, 1, 1), 1e-10f); + Assert.assertEquals( + scaler.getOutputShapes(new Shape[] {target.getShape(), observed.getShape()})[1], + expScale.getShape()); + assertOutput( + scaler.forward(ps, new NDList(target, observed), false), + target, + expScale, + 1, + true); + + scaler = NopScaler.builder().setDim(1).build(); + target = manager.randomNormal(new Shape(10, 20, 30)); + observed = manager.zeros(new Shape(10, 20, 30)); + expScale = manager.ones(new Shape(10, 30)); + Assert.assertEquals( + scaler.getOutputShapes(new Shape[] {target.getShape(), observed.getShape()})[1], + expScale.getShape()); + assertOutput( + scaler.forward(ps, new NDList(target, observed), false), + target, + expScale, + 1, + false); + + scaler = NopScaler.builder().setDim(1).optKeepDim(true).build(); + target = manager.randomNormal(new Shape(10, 20, 30)); + observed = manager.ones(new Shape(10, 20, 30)); + expScale = manager.ones(new Shape(10, 1, 30)); + Assert.assertEquals( + scaler.getOutputShapes(new Shape[] {target.getShape(), observed.getShape()})[1], + expScale.getShape()); + assertOutput( + scaler.forward(ps, new NDList(target, observed), false), + target, + expScale, + 1, + true); + } + } + + private void assertOutput( + NDList scalerOutput, NDArray target, NDArray expScale, int dim, 
boolean keepDim) { + NDArray actTargetScaled = scalerOutput.get(0); + NDArray actScale = scalerOutput.get(1); + Assertions.assertAlmostEquals(actScale, expScale); + + NDArray expTargetScaled; + if (keepDim) { + expTargetScaled = target.div(expScale); + } else { + expTargetScaled = target.div(expScale.expandDims(dim)); + } + Assertions.assertAlmostEquals(actTargetScaled, expTargetScaled); + } + + private void testFeatureEmbedder( + Shape inputShape, List cardinalities, List embeddingDims) { + try (NDManager manager = NDManager.newBaseManager()) { + Shape outputShape = + inputShape + .slice(0, inputShape.dimension() - 1) + .add(embeddingDims.stream().mapToInt(Integer::intValue).sum()); + + FeatureEmbedder embedder = + FeatureEmbedder.builder() + .setCardinalities(cardinalities) + .setEmbeddingDims(embeddingDims) + .build(); + + embedder.setInitializer(Initializer.ONES, Parameter.Type.WEIGHT); + embedder.initialize(manager, DataType.FLOAT32, inputShape); + + int expParamsLen = embedder.getParameters().keys().size(); + int actParamsLen = embeddingDims.size(); + Assert.assertEquals(expParamsLen, actParamsLen); + + ParameterStore ps = new ParameterStore(manager, true); + NDArray actOutput = + embedder.forward(ps, new NDList(manager.ones(inputShape)), true) + .singletonOrThrow(); + NDArray expOutput = manager.ones(outputShape); + Assert.assertEquals(actOutput.getShape(), expOutput.getShape()); + Assertions.assertAlmostEquals(actOutput, expOutput); + } + } +} diff --git a/extensions/timeseries/src/test/java/ai/djl/timeseries/block/package-info.java b/extensions/timeseries/src/test/java/ai/djl/timeseries/block/package-info.java new file mode 100644 index 00000000000..f58021f963a --- /dev/null +++ b/extensions/timeseries/src/test/java/ai/djl/timeseries/block/package-info.java @@ -0,0 +1,15 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). 
You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ + +/** Contains tests for the basic block module. */ +package ai.djl.timeseries.block; \ No newline at end of file diff --git a/extensions/timeseries/src/test/java/ai/djl/timeseries/distribution/DistributionTest.java b/extensions/timeseries/src/test/java/ai/djl/timeseries/distribution/DistributionTest.java index 051f7f4743b..8708fad6627 100644 --- a/extensions/timeseries/src/test/java/ai/djl/timeseries/distribution/DistributionTest.java +++ b/extensions/timeseries/src/test/java/ai/djl/timeseries/distribution/DistributionTest.java @@ -25,16 +25,19 @@ public class DistributionTest { @Test public void testNegativeBinomial() { try (NDManager manager = NDManager.newBaseManager()) { - NDArray mu = manager.create(new float[] {1000f, 1f}); - NDArray alpha = manager.create(new float[] {1f, 2f}); - mu.setName("mu"); - alpha.setName("alpha"); + NDArray totalCount = manager.create(new float[] {5f, 1f}); + NDArray logits = manager.create(new float[] {0.1f, 2f}); + totalCount.setName("total_count"); + logits.setName("logits"); Distribution negativeBinomial = - NegativeBinomial.builder().setDistrArgs(new NDList(mu, alpha)).build(); + NegativeBinomial.builder().setDistrArgs(new NDList(totalCount, logits)).build(); - NDArray expected = manager.create(new float[] {-6.9098f, -1.6479f}); - NDArray real = negativeBinomial.logProb(manager.create(new float[] {1f, 1f})); + NDArray expected = manager.create(new float[] {-2.3027f, -2.2539f}); + NDArray real = negativeBinomial.logProb(manager.create(new float[] {2f, 1f})); Assertions.assertAlmostEquals(real, 
expected); + + NDArray samplesMean = negativeBinomial.sample(100000).mean(new int[] {0}); + Assertions.assertAlmostEquals(samplesMean, negativeBinomial.mean(), 2e-2f, 2e-2f); } } @@ -53,6 +56,9 @@ public void testStudentT() { NDArray expected = manager.create(new float[] {-0.9779f, -1.6940f}); NDArray real = studentT.logProb(manager.create(new float[] {1000f, -1000f})); Assertions.assertAlmostEquals(real, expected); + + NDArray samplesMean = studentT.sample(100000).mean(new int[] {0}); + Assertions.assertAlmostEquals(samplesMean, studentT.mean(), 2e-2f, 2e-2f); } } } diff --git a/extensions/timeseries/src/test/java/ai/djl/timeseries/evaluator/RmsseTest.java b/extensions/timeseries/src/test/java/ai/djl/timeseries/evaluator/RmsseTest.java new file mode 100644 index 00000000000..71f42eafdae --- /dev/null +++ b/extensions/timeseries/src/test/java/ai/djl/timeseries/evaluator/RmsseTest.java @@ -0,0 +1,49 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. 
+ */ + +package ai.djl.timeseries.evaluator; + +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.NDManager; +import ai.djl.timeseries.distribution.output.NegativeBinomialOutput; + +import org.testng.Assert; +import org.testng.annotations.Test; + +public class RmsseTest { + + public static void main(String[] args) { + RmsseTest t = new RmsseTest(); + t.testRmsse(); + } + + @Test + public void testRmsse() { + try (NDManager manager = NDManager.newBaseManager()) { + NDArray totalCount = manager.create(new float[] {1f, 2f}).expandDims(0); + NDArray logits = manager.create(new float[] {1f, 2f}).log().expandDims(0); + totalCount.setName("total_count"); + logits.setName("logits"); + NDList prediction = new NDList(totalCount, logits); + NDList label = new NDList(manager.create(new float[] {3f, 4f}).expandDims(0)); + + Rmsse rmsse = new Rmsse(new NegativeBinomialOutput()); + rmsse.addAccumulator(""); + rmsse.updateAccumulator("", label, prediction); + float rmsseValue = rmsse.getAccumulator(""); + float expectedRmsse = 1.414213562373095f; + Assert.assertEquals(rmsseValue, expectedRmsse); + } + } +} diff --git a/extensions/timeseries/src/test/java/ai/djl/timeseries/evaluator/package-info.java b/extensions/timeseries/src/test/java/ai/djl/timeseries/evaluator/package-info.java new file mode 100644 index 00000000000..a8bbb8689c3 --- /dev/null +++ b/extensions/timeseries/src/test/java/ai/djl/timeseries/evaluator/package-info.java @@ -0,0 +1,15 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions + * and limitations under the License. + */ + +/** Contains tests for the evaluator module. */ +package ai.djl.timeseries.evaluator; \ No newline at end of file diff --git a/extensions/timeseries/src/test/java/ai/djl/timeseries/model/DeepARTest.java b/extensions/timeseries/src/test/java/ai/djl/timeseries/model/DeepARTest.java new file mode 100644 index 00000000000..08cccb9e99c --- /dev/null +++ b/extensions/timeseries/src/test/java/ai/djl/timeseries/model/DeepARTest.java @@ -0,0 +1,312 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. 
+ */ + +package ai.djl.timeseries.model; + +import ai.djl.Model; +import ai.djl.basicdataset.BasicDatasets; +import ai.djl.basicdataset.tabular.utils.Feature; +import ai.djl.engine.Engine; +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.NDManager; +import ai.djl.ndarray.types.DataType; +import ai.djl.ndarray.types.Shape; +import ai.djl.nn.Parameter; +import ai.djl.timeseries.dataset.FieldName; +import ai.djl.timeseries.dataset.M5Forecast; +import ai.djl.timeseries.dataset.TimeFeaturizers; +import ai.djl.timeseries.distribution.DistributionLoss; +import ai.djl.timeseries.distribution.output.DistributionOutput; +import ai.djl.timeseries.distribution.output.NegativeBinomialOutput; +import ai.djl.timeseries.distribution.output.StudentTOutput; +import ai.djl.timeseries.model.deepar.DeepARNetwork; +import ai.djl.timeseries.timefeature.TimeFeature; +import ai.djl.timeseries.transform.TimeSeriesTransform; +import ai.djl.training.DefaultTrainingConfig; +import ai.djl.training.EasyTrain; +import ai.djl.training.ParameterStore; +import ai.djl.training.Trainer; +import ai.djl.training.TrainingConfig; +import ai.djl.training.dataset.Batch; +import ai.djl.training.dataset.Dataset; +import ai.djl.training.initializer.Initializer; +import ai.djl.translate.Batchifier; +import ai.djl.translate.TranslateException; +import ai.djl.util.PairList; + +import org.testng.Assert; +import org.testng.annotations.Test; + +import java.io.IOException; +import java.time.LocalDateTime; +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +public class DeepARTest { + + private static int predictionLength = 4; + private static String freq = "D"; + + public static void main(String[] args) throws TranslateException, IOException { + DeepARTest test = new DeepARTest(); + test.testPredictionTransformation(); + } + + @Test + public void testTrainingNetwork() { + try (Model model = 
Model.newInstance("deepar")) { + DistributionOutput distributionOutput = new NegativeBinomialOutput(); + TrainingConfig config = + new DefaultTrainingConfig(new DistributionLoss("Loss", distributionOutput)) + .optDevices(Engine.getInstance().getDevices()); + + NDManager manager = model.getNDManager(); + DeepARNetwork deepARTraining = getDeepARModel(distributionOutput, true); + model.setBlock(deepARTraining); + try (Trainer trainer = model.newTrainer(config)) { + int batchSize = 1; + int historyLength = deepARTraining.getHistoryLength(); + Shape[] inputShapes = getTrainingInputShapes(batchSize, historyLength); + trainer.initialize(inputShapes); + + NDList inputs = + new NDList( + Stream.of(inputShapes) + .map(manager::ones) + .collect(Collectors.toList())); + NDArray label = + manager.ones( + new Shape( + batchSize, + deepARTraining.getContextLength() + predictionLength - 1)); + Batch batch = + new Batch( + manager.newSubManager(), + inputs, + new NDList(label), + batchSize, + Batchifier.STACK, + Batchifier.STACK, + 0, + 0); + PairList parameters = deepARTraining.getParameters(); + EasyTrain.trainBatch(trainer, batch); + trainer.step(); + + Assert.assertEquals( + parameters.get(0).getValue().getArray().getShape(), new Shape(1, 40)); + Assert.assertEquals( + parameters.get(1).getValue().getArray().getShape(), new Shape(1)); + Assert.assertEquals( + parameters.get(2).getValue().getArray().getShape(), new Shape(1, 40)); + Assert.assertEquals( + parameters.get(3).getValue().getArray().getShape(), new Shape(1)); + Assert.assertEquals( + parameters.get(4).getValue().getArray().getShape(), new Shape(5, 3)); + } + } + } + + @Test + public void testPredictionNetwork() { + DeepARNetwork deepAR = getDeepARModel(new NegativeBinomialOutput(), false); + try (NDManager manager = NDManager.newBaseManager()) { + int batchSize = 1; + int historyLength = deepAR.getHistoryLength(); + Shape[] inputShapes = getPredictionInputShapes(batchSize, historyLength); + + 
deepAR.setInitializer(Initializer.ONES, Parameter.Type.WEIGHT); + deepAR.initialize(manager, DataType.FLOAT32, inputShapes); + + ParameterStore ps = new ParameterStore(manager, true); + NDArray actOutput = + deepAR.forward( + ps, + new NDList( + Stream.of(inputShapes) + .map(manager::ones) + .collect(Collectors.toList())), + false) + .singletonOrThrow(); + NDArray expOutput = manager.ones(new Shape(batchSize, 100, predictionLength)); + Assert.assertEquals(actOutput.getShape(), expOutput.getShape()); + } + } + + @Test + public void testOutputShapes() { + DeepARNetwork deepARTraining = getDeepARModel(new NegativeBinomialOutput(), true); + DeepARNetwork deepARPrediction = getDeepARModel(new StudentTOutput(), false); + try (NDManager manager = NDManager.newBaseManager()) { + int batchSize = 1; + int historyLength = deepARTraining.getHistoryLength(); + Shape[] trainingInputShapes = getTrainingInputShapes(batchSize, historyLength); + Shape[] predictionInputShapes = getPredictionInputShapes(batchSize, historyLength); + + deepARTraining.setInitializer(Initializer.ONES, Parameter.Type.WEIGHT); + deepARTraining.initialize(manager, DataType.FLOAT32, trainingInputShapes); + + deepARPrediction.setInitializer(Initializer.ONES, Parameter.Type.WEIGHT); + deepARPrediction.initialize(manager, DataType.FLOAT32, predictionInputShapes); + + Shape[] trainingOutputShapes = deepARTraining.getOutputShapes(trainingInputShapes); + Shape[] predictionOutputShapes = + deepARPrediction.getOutputShapes(predictionInputShapes); + + int contextLength = deepARTraining.getContextLength(); + // Distribution param shape (batch_size, context_length - 1 + prediction_length, + // arg_shape) + Assert.assertEquals( + trainingOutputShapes[0], + new Shape(batchSize, contextLength - 1 + predictionLength)); + Assert.assertEquals( + trainingOutputShapes[1], + new Shape(batchSize, contextLength - 1 + predictionLength)); + // scale shape + Assert.assertEquals(trainingOutputShapes[2], new Shape(batchSize, 1)); + // 
loss weights shape + Assert.assertEquals( + trainingOutputShapes[3], + new Shape(batchSize, contextLength - 1 + predictionLength)); + + // prediction sample shape (batch_size, prediction_length - 1) + Assert.assertEquals( + predictionOutputShapes[0], new Shape(batchSize, 100, predictionLength)); + } + } + + @Test + public void testTrainingTransformation() throws IOException, TranslateException { + DeepARNetwork deepAR = getDeepARModel(new NegativeBinomialOutput(), true); + try (NDManager manager = NDManager.newBaseManager()) { + List trainingTransformation = + deepAR.createTrainingTransformation(manager); + + int batchSize = 32; + M5Forecast m5Forecast = getDataset(batchSize, trainingTransformation); + Batch batch = m5Forecast.getData(manager).iterator().next(); + Assert.assertEquals(batch.getData().size(), 9); + Assert.assertEquals(batch.getLabels().size(), 1); + + Shape[] actInputShapes = + batch.getData().stream().map(NDArray::getShape).toArray(Shape[]::new); + Assert.assertEquals( + actInputShapes, getTrainingInputShapes(batchSize, deepAR.getHistoryLength())); + } + } + + @Test + public void testPredictionTransformation() throws IOException, TranslateException { + DeepARNetwork deepAR = getDeepARModel(new StudentTOutput(), false); + try (NDManager manager = NDManager.newBaseManager()) { + List predictionTransformation = + deepAR.createPredictionTransformation(manager); + + int batchSize = 32; + M5Forecast m5Forecast = getDataset(batchSize, predictionTransformation); + Batch batch = m5Forecast.getData(manager).iterator().next(); + Assert.assertEquals(batch.getData().size(), 7); + Assert.assertEquals(batch.getLabels().size(), 0); + + Shape[] actInputShapes = + batch.getData().stream().map(NDArray::getShape).toArray(Shape[]::new); + Assert.assertEquals( + actInputShapes, getPredictionInputShapes(batchSize, deepAR.getHistoryLength())); + } + } + + private DeepARNetwork getDeepARModel(DistributionOutput distributionOutput, boolean isTrain) { + // here is 
feat_static_cat's cardinality which depend on your dataset + List cardinality = Arrays.asList(5); + + DeepARNetwork.Builder builder = + DeepARNetwork.builder() + .setCardinality(cardinality) + .setFreq(freq) + .setPredictionLength(predictionLength) + .optDistrOutput(distributionOutput); + return isTrain ? builder.buildTrainingNetwork() : builder.buildPredictionNetwork(); + } + + private Shape[] getTrainingInputShapes(int batchSize, int historyLength) { + return new Shape[] { + new Shape(batchSize, 1), + new Shape(batchSize, 1), + new Shape( + batchSize, historyLength, TimeFeature.timeFeaturesFromFreqStr(freq).size() + 1), + new Shape(batchSize, historyLength), + new Shape(batchSize, historyLength), + new Shape(batchSize, historyLength), + new Shape( + batchSize, + predictionLength, + TimeFeature.timeFeaturesFromFreqStr(freq).size() + 1), + new Shape(batchSize, predictionLength), + new Shape(batchSize, predictionLength), + }; + } + + private Shape[] getPredictionInputShapes(int batchSize, int historyLength) { + return new Shape[] { + new Shape(batchSize, 1), + new Shape(batchSize, 1), + new Shape( + batchSize, historyLength, TimeFeature.timeFeaturesFromFreqStr(freq).size() + 1), + new Shape(batchSize, historyLength), + new Shape(batchSize, historyLength), + new Shape( + batchSize, + predictionLength, + TimeFeature.timeFeaturesFromFreqStr(freq).size() + 1), + new Shape(batchSize, historyLength) + }; + } + + private M5Forecast getDataset(int batchSize, List transforms) + throws TranslateException, IOException { + M5Forecast.Builder builder = + M5Forecast.builder() + .optUsage(Dataset.Usage.TEST) + .optUsage(Dataset.Usage.TEST) + .optRepository(BasicDatasets.REPOSITORY) + .optGroupId(BasicDatasets.GROUP_ID) + .optArtifactId("m5forecast-unittest") + .setTransformation(transforms) + .setContextLength(predictionLength) + .setSampling(batchSize, true); + List features = builder.getAvailableFeatures(); + Assert.assertEquals(features.size(), 5); + for (int i = 1; i <= 277; 
i++) { + builder.addFeature("w_" + i, FieldName.TARGET); + } + M5Forecast m5Forecast = + builder.addFeature("state_id", FieldName.FEAT_STATIC_CAT) + .addFeature("store_id", FieldName.FEAT_STATIC_CAT) + .addFeature("cat_id", FieldName.FEAT_STATIC_CAT) + .addFeature("dept_id", FieldName.FEAT_STATIC_CAT) + .addFeature("item_id", FieldName.FEAT_STATIC_CAT) + .addFieldFeature( + FieldName.START, + new Feature( + "date", + TimeFeaturizers.getConstantTimeFeaturizer( + LocalDateTime.parse("2011-01-29T00:00")))) + .build(); + + m5Forecast.prepare(); + return m5Forecast; + } +} diff --git a/extensions/timeseries/src/test/java/ai/djl/timeseries/model/package-info.java b/extensions/timeseries/src/test/java/ai/djl/timeseries/model/package-info.java new file mode 100644 index 00000000000..9a4cd89bb3d --- /dev/null +++ b/extensions/timeseries/src/test/java/ai/djl/timeseries/model/package-info.java @@ -0,0 +1,15 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ + +/** Contains tests for the model module. */ +package ai.djl.timeseries.model; \ No newline at end of file