Skip to content

Commit

Permalink
Add torch.no_grad() guard on export (pytorch#2373)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#2373

For some models where the grad is enabled (like LLava model), we need to disable grad before exporting. This PR is to add this guard when using export script. It would not affect the models that grad is not enabled.

Reviewed By: cccclai

Differential Revision: D54812718

fbshipit-source-id: 639efa5a77839d8d5de52cfa6cb125c883ba54dc
  • Loading branch information
Martin Yuan authored and facebook-github-bot committed Mar 12, 2024
1 parent 7675073 commit 35a847e
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion examples/portable/scripts/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import argparse
import logging

import torch

from executorch.exir.capture import EdgeCompileConfig, ExecutorchBackendConfig

from ...models import MODEL_NAME_TO_MODEL
Expand Down Expand Up @@ -75,4 +77,5 @@ def main() -> None:


if __name__ == "__main__":
main() # pragma: no cover
with torch.no_grad():
main() # pragma: no cover

0 comments on commit 35a847e

Please sign in to comment.