-
Notifications
You must be signed in to change notification settings - Fork 346
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
use _preserve_ops for to_edge_transform_and_lower (#4273)
Summary: Pull Request resolved: #4273 ## Motivation `run_decompositions()` has a new preserve_ops functionality which allows us to specify which ops we want to refrain from decomposing. This is super helpful for the to_edge_transform_and_lower api because it allows us to preserve decomposition that occur beyond the first level. For example consider LSTM. when exported using torch.export, we see a torch.ops.aten.LSTM() operator in the graph. When running decompositions this is decomposed into linear, and then further decomposed into addmm. Since the linear op is produced from decomposing LSTM and does not exist until after we run_decompositions(), we can not perform our trick of changing the name space to prevent its decomposition. However, now using `_preserve_ops=(torch.ops.aten.linear.default,)` we are able to prevent this second layer decomposition. ## API Implementation Change So in the implementation we do two passes. The first pass is we run_decompositions preserving all aten ops specified by our partitioners using `_preserve_ops`. On our second pass, we further filter which aten ops should be preserved by using the check_op_fn given to us by partitioners. We then use our namespace trick to prevent the decomposition of all aten ops which pass our check_op_fn. ## Testing Changes To strengthen our tests, I first change the functionality of the NonDecompPartitioner. We partition only pre-decomp aten ops. And each of these ops live within their own delegate (this allows us to have a 1:1 mapping for call_delegate and pre_decomp aten nodes). In testing, this will allow us to ensure that the number of ops which are to preserved is correct by counting the number of delegates calls. In testing we then count the number of aten ops which should correctly be preserved. And then check after the fact that all these ops are 1. No longer in the graph after to_edge_transform_and_lower 2. Each of these preserved ops are transformed into a call_delegate node Reviewed By: tarun292 Differential Revision: D59786323 fbshipit-source-id: 7ea946e0d5afc8ebddd26913f6e843305116ad3b
- Loading branch information
1 parent
b448254
commit c3357e1
Showing
3 changed files
with
148 additions
and
36 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters