From b4a6148c9239dbc95b3a76f16acfd9457a8a36f6 Mon Sep 17 00:00:00 2001 From: Matthias Cremon Date: Fri, 27 Sep 2024 16:30:13 -0700 Subject: [PATCH] Migrate from capture_pre_autograd_graph to torch.export.export_for_training (#5730) Summary: Pull Request resolved: https://github.com/pytorch/executorch/pull/5730 As titled. The `capture_pre_autograd_graph` API is deprecated. Reviewed By: hsharma35 Differential Revision: D63541800 fbshipit-source-id: 7b830ae55a5dff8bf61be0470f54302c2ab461d8 --- backends/cadence/aot/compiler.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/backends/cadence/aot/compiler.py b/backends/cadence/aot/compiler.py index e1494f8d20..fe8fc72124 100644 --- a/backends/cadence/aot/compiler.py +++ b/backends/cadence/aot/compiler.py @@ -30,7 +30,6 @@ ) from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform from executorch.exir import EdgeCompileConfig, EdgeProgramManager, to_edge -from torch._export import capture_pre_autograd_graph from torch.ao.quantization.pt2e.export_utils import model_is_exported from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e @@ -58,7 +57,7 @@ def convert_pt2( """ # Export with dynamo - model_gm = capture_pre_autograd_graph(model, inputs) + model_gm = torch.export.export_for_training(model, inputs).module() if model_gm_has_SDPA(model_gm): # pyre-fixme[6] # Decompose SDPA