Skip to content

Commit

Permalink
add tribal knowledge decomps for newer torch versions
Browse files Browse the repository at this point in the history
Once we migrate default torch version these will be hit by existing
tests.
  • Loading branch information
dan-garvey committed Feb 6, 2024
1 parent 66f79ab commit 42821b3
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions core/shark_turbine/dynamo/passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,11 @@
torch.ops.aten._scaled_dot_product_flash_attention.default,
]

# These decompositions either didnt exist or weren't required for 2.1.0
if torch.__version__ > "2.1.0":
DEFAULT_DECOMPOSITIONS.append(torch.ops.aten._scaled_dot_product_flash_attention_for_cpu)
DEFAULT_DECOMPOSITIONS.append(torch.ops.aten.unbind_int)


def apply_decompositions(
gm: torch.fx.GraphModule,
Expand Down

0 comments on commit 42821b3

Please sign in to comment.