
ONNX Support #553

Open
falknerdominik opened this issue Nov 17, 2021 · 9 comments
Labels: feature request


falknerdominik commented Nov 17, 2021

Is your feature request related to a current problem? Please describe.
Using models on other PCs / inside containers is a hassle and is currently (as of reading the documentation) only supported via pickling. Pickling can be error-prone and produces big files (with a lot of unneeded data), > 1 MB for non-baseline models.

Describe proposed solution
Support ONNX (like scikit-learn does), which allows saving and loading the model.

Describe potential alternatives
Provide your own way to serialize / deserialize a model (probably not the best idea)
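For context, the scikit-learn workflow referenced above looks roughly like this via the skl2onnx converter. A minimal sketch; the model, data, and shapes are made up for illustration:

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType

# Dummy training data, just for illustration
X_train = np.random.rand(100, 4).astype(np.float32)
y_train = np.random.rand(100).astype(np.float32)
model = LinearRegression().fit(X_train, y_train)

# Declare the input signature ("input" is an arbitrary name; None = dynamic batch)
onnx_model = convert_sklearn(
    model, initial_types=[("input", FloatTensorType([None, 4]))]
)
with open("model.onnx", "wb") as f:
    f.write(onnx_model.SerializeToString())
```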

@falknerdominik falknerdominik added the triage Issue waiting for triaging label Nov 17, 2021
@hrzn hrzn added feature request Use this label to request a new feature and removed triage Issue waiting for triaging labels Nov 27, 2021
@hrzn
Contributor

hrzn commented Nov 27, 2021

Thanks for the suggestion @falknerdominik - indeed, it would be nice. We are also very open to receiving pull requests, if you want to contribute something along these lines. I think most of the work here should revolve around the TorchForecastingModel class.
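As a rough sketch of where such a contribution might hook in: every TorchForecastingModel wraps a PyTorch Lightning module (exposed as model.model), and Lightning already ships a LightningModule.to_onnx() method, so a darts-level export could plausibly be a thin wrapper around it. The export_onnx helper below is hypothetical and does not exist in darts:

```python
# Hypothetical helper, not an existing darts API: a darts-level ONNX
# export could delegate to LightningModule.to_onnx()
def export_onnx(darts_model, path, input_sample):
    # darts_model.model is the underlying PyTorch Lightning module;
    # extra kwargs are forwarded to torch.onnx.export
    darts_model.model.to_onnx(path, input_sample=input_sample, export_params=True)
```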

@pelekhs

pelekhs commented Dec 13, 2021

Hi! This is a very interesting request indeed, but I am afraid I don't have the required experience to contribute, as it would be my first time. If anyone is interested in a collaboration, however, I could join! By the way, I have developed some darts models for a project which will use ONNX as its interoperability standard. Is there currently any way to deploy those models as ONNX (even if it is not the most efficient one)? @falknerdominik is it possible to store darts models as pkl files and then convert them to ONNX? And would they be fully functional? Does ONNX accept darts.TimeSeries as (example) input? Thanks in advance!

@falknerdominik
Author

@hrzn do you have a starting point for this? Might be interested in implementing part of it.

@pelekhs ONNX models are saved as a simple file which only allows you to do minimal inference (training data etc. is normally not saved, and the model cannot be retrained as far as I know). The nice thing is that it can be deployed easily, and the ONNX foundation provides Docker images which can load any ONNX file and serve inference.

Inputs/outputs have to be described manually (although we can probably specify the output automatically).
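For what it's worth, once a file is exported, onnxruntime can report the input/output signature it declares, so that part can at least be inspected programmatically. A minimal sketch ("model.onnx" is a placeholder path):

```python
import onnxruntime

# Load the exported file and print the declared inputs and outputs
session = onnxruntime.InferenceSession("model.onnx")
for inp in session.get_inputs():
    print("input:", inp.name, inp.shape, inp.type)
for out in session.get_outputs():
    print("output:", out.name, out.shape, out.type)
```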

@madtoinou
Collaborator

After answering a comment on Gitter, I noticed that some features implemented recently could be reused for this feature.

I am bumping this feature up in the backlog; it should become a low-hanging fruit.

@madtoinou madtoinou self-assigned this May 8, 2023
@BlackFireAlex

Hello,
Could somebody provide an example of how to get the tensor dimensions for a TFTModel?

@madtoinou
Collaborator

TFTModel uses the MixedCovariatesSequentialDataset by default, you can look at the PLMixedCovariatesModule._get_batch_prediction() method to see how the covariates are handled before passing them to the model (there is an example for NHiTSModel in #1521).
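For illustration, here is a rough (untested) sketch of how the corresponding input tensors could be assembled from a fitted TFTModel, assuming the tuple order yielded by the MixedCovariatesSequentialDataset (which may differ across darts versions):

```python
import torch

# Assumption: `model` is a fitted TFTModel; the train_sample tuple order
# follows the MixedCovariatesSequentialDataset
(
    past_target,
    past_covariates,
    historic_future_covariates,
    future_covariates,
    static_covariates,
    future_target,
) = [torch.Tensor(x).unsqueeze(0) if x is not None else None for x in model.train_sample]

# Mirror PLMixedCovariatesModule._process_input_batch: concatenate the
# past inputs along the component dimension (dim=2)
x_past = torch.cat(
    [t for t in [past_target, past_covariates, historic_future_covariates] if t is not None],
    dim=2,
)
input_sample = [x_past, future_covariates, static_covariates]
print([t.shape if t is not None else None for t in input_sample])
```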

I started working on a PR to close this issue but am encountering some difficulties when loading the ONNX model and trying to run inference. If anyone with experience in ONNX (or who has managed to make it work for darts models) wants to pick it up, go ahead!

@JoonasHL

Hello, has there been any work done on ONNX export functionality for the darts models (the PyTorch Lightning ones)? If not, does someone know if there is a workaround for exporting NBeats/TCN to the ONNX format?

@madtoinou
Collaborator

Hi @JoonasHL,

I made a bit of progress, but other features took priority over this one. It is still on the roadmap; I need to find some time for it (or I am happy to let someone else take this over).

I wrote a workaround in #1521, but the preparation for inference once in ONNX format is not straightforward yet. Since NHiTSModel, NBEATSModel and TCNModel are all PastCovariatesModels, the code snippet can pretty much be used as-is.

@JoonasHL

JoonasHL commented Dec 20, 2023

Thanks @madtoinou. I finally got it to work for my use case. I also encountered some issues when running inference with the prepared ONNX model. Reading the onnxruntime forum, it seems there is limited support for the double type. What I did was cast the model and the input data to float32, and using your example in #1521, I got it to work.

sample code:

```python
import torch
import onnxruntime
from darts.models import NBEATSModel

model = NBEATSModel( ... )  # fill in your model parameters
model.fit( ... )            # fill in your training series

# Grab the underlying PyTorch Lightning module and cast it to float32,
# since onnxruntime has limited support for double
model_for_onnx = model.model.to(torch.float32)

# Rebuild an input sample from the training data
dim_component = 2
(
    past_target,
    past_covariates,
    future_past_covariates,
    static_covariates,
) = [torch.Tensor(x).unsqueeze(0) if x is not None else None for x in model.train_sample]

n_past_covs = (
    past_covariates.shape[dim_component] if past_covariates is not None else 0
)

# Concatenate target and past covariates along the component dimension
input_past = torch.cat(
    [ds for ds in [past_target, past_covariates] if ds is not None],
    dim=dim_component,
)

# Cast the inputs to float32 as well (guard against absent static covariates)
input_sample = [
    input_past.float(),
    static_covariates.float() if static_covariates is not None else None,
]
model_for_onnx.to_onnx("test_export.onnx", input_sample=input_sample)

# Run inference on the exported model
onnx_model_path = "test_export.onnx"
onnx_session = onnxruntime.InferenceSession(onnx_model_path)

# Dummy float32 input: one series of length 30 (input_chunk_length) with 1 component
input_np_float = torch.tensor(
    [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20,
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]
).reshape(1, 30, 1).float().numpy()

input_name = onnx_session.get_inputs()[0].name
output_name = onnx_session.get_outputs()[0].name
result = onnx_session.run([output_name], {input_name: input_np_float})

print("Output shape:", result[0].shape)
print("Predictions:", result[0])
```
