Skip to content

Commit

Permalink
error on exporting ScriptModule (pytorch#135302)
Browse files Browse the repository at this point in the history
Test Plan: added test

Differential Revision: D62279179

Pull Request resolved: pytorch#135302
Approved by: https://github.com/yushangdi
  • Loading branch information
avikchaudhuri authored and pytorchmergebot committed Sep 6, 2024
1 parent ad29a2c commit de74aaf
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 0 deletions.
19 changes: 19 additions & 0 deletions test/export/test_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -892,6 +892,25 @@ def forward(self, x):
torch.allclose(ep.module()(torch.zeros(2, 3)), torch.ones(2, 3) * 21)
)

def test_export_script_module(self):
class Foo(torch.nn.Module):
def forward(self, rv: torch.Tensor, t: torch.Tensor):
i = t.item()
return rv + i

foo = Foo()
foo_script = torch.jit.script(foo)
inp = (torch.zeros(3, 4), torch.tensor(7))

with self.assertRaisesRegex(
ValueError, "Exporting a ScriptModule is not supported"
):
export(foo_script, inp)

from torch._export.converter import TS2EPConverter

TS2EPConverter(foo_script, inp).convert()

def test_torch_fn(self):
class M1(torch.nn.Module):
def __init__(self) -> None:
Expand Down
12 changes: 12 additions & 0 deletions torch/export/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,12 @@ def export_for_training(
raise ValueError(
f"Expected `mod` to be an instance of `torch.nn.Module`, got {type(mod)}."
)
if isinstance(mod, torch.jit.ScriptModule):
raise ValueError(
"Exporting a ScriptModule is not supported. "
"Maybe try converting your ScriptModule to an ExportedProgram "
"using `TS2EPConverter(mod, args, kwargs).convert()` instead."
)
return _export_for_training(
mod,
args,
Expand Down Expand Up @@ -255,6 +261,12 @@ def export(
raise ValueError(
f"Expected `mod` to be an instance of `torch.nn.Module`, got {type(mod)}."
)
if isinstance(mod, torch.jit.ScriptModule):
raise ValueError(
"Exporting a ScriptModule is not supported. "
"Maybe try converting your ScriptModule to an ExportedProgram "
"using `TS2EPConverter(mod, args, kwargs).convert()` instead."
)
return _export(
mod,
args,
Expand Down

0 comments on commit de74aaf

Please sign in to comment.