From 76b691d80c6c5203c66365272ce246ac86e418f0 Mon Sep 17 00:00:00 2001 From: Dominic Kerr Date: Thu, 4 Apr 2024 01:42:25 +0100 Subject: [PATCH] Support pathlib.Path file paths when saving ONNX models (#19727) Co-authored-by: dominicgkerr --- src/lightning/pytorch/core/module.py | 2 +- tests/tests_pytorch/models/test_onnx.py | 12 ++++++++---- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/src/lightning/pytorch/core/module.py b/src/lightning/pytorch/core/module.py index 3075e8952b148..faeda00ce5aa9 100644 --- a/src/lightning/pytorch/core/module.py +++ b/src/lightning/pytorch/core/module.py @@ -1395,7 +1395,7 @@ def forward(self, x): input_sample = self._on_before_batch_transfer(input_sample) input_sample = self._apply_batch_transfer_handler(input_sample) - torch.onnx.export(self, input_sample, file_path, **kwargs) + torch.onnx.export(self, input_sample, str(file_path), **kwargs) self.train(mode) @torch.no_grad() diff --git a/tests/tests_pytorch/models/test_onnx.py b/tests/tests_pytorch/models/test_onnx.py index 9bb20579b7162..15d06355946fc 100644 --- a/tests/tests_pytorch/models/test_onnx.py +++ b/tests/tests_pytorch/models/test_onnx.py @@ -13,6 +13,7 @@ # limitations under the License. import operator import os +from pathlib import Path from unittest.mock import patch import numpy as np @@ -32,11 +33,14 @@ def test_model_saves_with_input_sample(tmp_path): """Test that ONNX model saves with input sample and size is greater than 3 MB.""" model = BoringModel() - trainer = Trainer(fast_dev_run=True) - trainer.fit(model) - - file_path = os.path.join(tmp_path, "model.onnx") input_sample = torch.randn((1, 32)) + + file_path = os.path.join(tmp_path, "os.path.onnx") + model.to_onnx(file_path, input_sample) + assert os.path.isfile(file_path) + assert os.path.getsize(file_path) > 4e2 + + file_path = Path(tmp_path) / "pathlib.onnx" model.to_onnx(file_path, input_sample) assert os.path.isfile(file_path) assert os.path.getsize(file_path) > 4e2