TimeSeries API PyTorch Engine support #3259

Closed
keklol5050 opened this issue Jun 17, 2024 · 0 comments · Fixed by #3262
Labels
bug Something isn't working


The documentation (https://djl.ai/extensions/timeseries/docs/forecast_with_M5_data.html) states that the TimeSeries API also supports the PyTorch engine, but training a model with the PyTorch engine fails with the following error:
Exception in thread "main" java.lang.UnsupportedOperationException: Not implemented yet.
    at ai.djl.pytorch.engine.PtNDArray.gammaln(PtNDArray.java:845)
    at ai.djl.timeseries.distribution.NegativeBinomial.logProb(NegativeBinomial.java:47)
    at ai.djl.timeseries.distribution.DistributionLoss.evaluate(DistributionLoss.java:54)
    at ai.djl.training.EasyTrain.trainSplit(EasyTrain.java:124)
    at ai.djl.training.EasyTrain.trainBatch(EasyTrain.java:110)
    at ai.djl.training.EasyTrain.fit(EasyTrain.java:58)
    at com.crypto.analysis.main.ex.NNew.train(NNew.java:134)
    at com.crypto.analysis.main.ex.NNew.main(NNew.java:63)

```java
public static void train(Path outputDir) {
    try (Model model = Model.newInstance("DeepAR");
            NDManager manager = NDManager.newBaseManager()) {

        DistributionOutput distributionOutput = new NegativeBinomialOutput();
        DefaultTrainingConfig config = setupTrainingConfig(distributionOutput);
        DeepARNetwork trainingNetwork = getDeepARModel(freq, predictionLength, distributionOutput, true);

        model.setBlock(trainingNetwork);

        List<TimeSeriesTransform> trainingTransformation = trainingNetwork.createTrainingTransformation(manager);
        int contextLength = trainingNetwork.getContextLength();

        NDataset trainSet = getDataset(trainingTransformation, contextLength);

        try (Trainer trainer = model.newTrainer(config)) {
            trainer.setMetrics(new Metrics());

            System.out.println("+++++" + trainSet.availableSize());
            int historyLength = trainingNetwork.getHistoryLength();
            Shape[] inputShapes = new Shape[4];
            // (N, num_cardinality)
            inputShapes[0] = new Shape(1, 1);
            // (N, num_real) if use_feat_stat_real else (N, 1)
            inputShapes[1] = new Shape(1, 1);

            inputShapes[2] =
                    new Shape(
                            1,
                            historyLength,
                            TimeFeature.timeFeaturesFromFreqStr("M").size() + 1);
            inputShapes[3] = new Shape(1, historyLength);
            trainer.initialize(inputShapes);

            EasyTrain.fit(trainer, 1500, trainSet, null);
            logger.info(String.valueOf(trainer.getTrainingResult()));

            model.setProperty("Epochs", "1500");
            model.save(outputDir, "DeepAR");
        }

    } catch (TranslateException | IOException e) {
        throw new RuntimeException(e);
    }
}

private static DeepARNetwork getDeepARModel(String freq, int predictionLength,
                                            DistributionOutput distributionOutput, boolean training) {

    List<Integer> cardinality = new ArrayList<>();
    cardinality.add(allSize-testSize);

    DeepARNetwork.Builder builder =
            DeepARNetwork.builder()
                    .setFreq(freq)
                    .setPredictionLength(predictionLength)
                    .setCardinality(cardinality)
                    .optDistrOutput(distributionOutput)
                    .optUseFeatDynamicReal(false)
                    .optUseFeatStaticCat(false)
                    .optUseFeatStaticReal(false);
    return training ? builder.buildTrainingNetwork() : builder.buildPredictionNetwork();
}

private static DefaultTrainingConfig setupTrainingConfig(DistributionOutput distributionOutput) {
    return new DefaultTrainingConfig(new DistributionLoss("Loss", distributionOutput))
            .addEvaluator(new Rmsse(distributionOutput))
            .optInitializer(new XavierInitializer(), Parameter.Type.WEIGHT)
            .addTrainingListeners(TrainingListener.Defaults.logging());
}
```
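
For context on why the trace ends in `gammaln`: the negative binomial log-probability needs the log-gamma function, so `NegativeBinomial.logProb` cannot be evaluated on an engine where `gammaln` is missing. The following is a minimal plain-Java sketch of that computation. It is not DJL code: the class name `GammalnSketch`, both method names, the Lanczos approximation, and the mean/dispersion (`mu`, `alpha`) parameterization are assumptions made for illustration only.

```java
/** Illustrative sketch only: scalar log-gamma and negative binomial log-probability. */
final class GammalnSketch {

    // Lanczos coefficients (g = 7, 9 terms), a commonly used approximation.
    private static final double[] LANCZOS = {
        0.99999999999980993,
        676.5203681218851,
        -1259.1392167224028,
        771.32342877765313,
        -176.61502916214059,
        12.507343278686905,
        -0.13857109526572012,
        9.9843695780195716e-6,
        1.5056327351493116e-7
    };

    /** Natural logarithm of Gamma(x), intended for x > 0. */
    static double logGamma(double x) {
        if (x < 0.5) {
            // Reflection formula: Gamma(x) * Gamma(1 - x) = pi / sin(pi * x)
            return Math.log(Math.PI / Math.sin(Math.PI * x)) - logGamma(1.0 - x);
        }
        double z = x - 1.0;
        double sum = LANCZOS[0];
        for (int i = 1; i < LANCZOS.length; i++) {
            sum += LANCZOS[i] / (z + i);
        }
        double t = z + 7.5;
        return 0.5 * Math.log(2.0 * Math.PI) + (z + 0.5) * Math.log(t) - t + Math.log(sum);
    }

    /**
     * Log-probability of count k under a negative binomial with mean mu and
     * dispersion alpha (assumed parameterization, with r = 1 / alpha).
     */
    static double negativeBinomialLogProb(double k, double mu, double alpha) {
        double r = 1.0 / alpha;
        double alphaMu = alpha * mu;
        return k * Math.log(alphaMu / (1.0 + alphaMu))
                - r * Math.log1p(alphaMu)
                // The three log-gamma terms are where a gammaln operation is required.
                + logGamma(k + r) - logGamma(k + 1.0) - logGamma(r);
    }

    public static void main(String[] args) {
        System.out.println(logGamma(5.0));
        System.out.println(negativeBinomialLogProb(2.0, 3.0, 0.5));
    }
}
```

Under these assumptions, a distribution output whose log-probability does not involve log-gamma terms would not reach this call path, which may be worth trying as a workaround until the operator is available on the PyTorch engine.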