Skip to content

Commit

Permalink
Move tile rotation to top of IPU Jacobi loop body.
Browse files Browse the repository at this point in the history
This allows one on-tile copy to be optimized out, saving an additional 10% of
cycles.
  • Loading branch information
balancap committed Oct 19, 2023
1 parent b891bd9 commit d415805
Showing 1 changed file with 12 additions and 12 deletions.
24 changes: 12 additions & 12 deletions tessellate_ipu/linalg/tile_linalg_jacobi.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,11 +221,12 @@ def ipu_jacobi_eigh_body(idx: Array, inputs: Tuple[TileShardedArray, ...]) -> Tu
halfN = Apcols.shape[0]

with jax.named_scope("jacobi_eigh"):
# with jax.named_scope("Apqcols_rotation"):
# Apcols, Aqcols = tile_rotate_columns(Apcols, Aqcols)
# with jax.named_scope("Vpqcols_rotation"):
# Vpcols, Vqcols = tile_rotate_columns(Vpcols, Vqcols)
# Apcols, Aqcols, Vpcols, Vqcols = tile_data_barrier(Apcols, Aqcols, Vpcols, Vqcols)
with jax.named_scope("Apqcols_rotation"):
Apcols, Aqcols = tile_rotate_columns(Apcols, Aqcols)
with jax.named_scope("Vpqcols_rotation"):
Vpcols, Vqcols = tile_rotate_columns(Vpcols, Vqcols)
# Barrier, to make sure we sync both sets of tiles A and V and force fused comms.
Apcols, Aqcols, Vpcols, Vqcols = tile_data_barrier(Apcols, Aqcols, Vpcols, Vqcols)

# Sharded constant with p/q indices to ignore in second update stage.
with jax.named_scope("rotset_index_ignored"):
Expand Down Expand Up @@ -274,13 +275,12 @@ def ipu_jacobi_eigh_body(idx: Array, inputs: Tuple[TileShardedArray, ...]) -> Tu
Vqcols,
)

# Barrier, to make sure we sync both sets of tiles A and V.
Apcols, Aqcols, Vpcols, Vqcols = tile_data_barrier(Apcols, Aqcols, Vpcols, Vqcols)
# Move columns between tiles following Jacobi rotation pattern. 2*N comms per tile.
with jax.named_scope("Apqcols_rotation"):
Apcols, Aqcols = tile_rotate_columns(Apcols, Aqcols)
with jax.named_scope("Vpqcols_rotation"):
Vpcols, Vqcols = tile_rotate_columns(Vpcols, Vqcols)
# Apcols, Aqcols, Vpcols, Vqcols = tile_data_barrier(Apcols, Aqcols, Vpcols, Vqcols)
# # Move columns between tiles following Jacobi rotation pattern. 2*N comms per tile.
# with jax.named_scope("Apqcols_rotation"):
# Apcols, Aqcols = tile_rotate_columns(Apcols, Aqcols)
# with jax.named_scope("Vpqcols_rotation"):
# Vpcols, Vqcols = tile_rotate_columns(Vpcols, Vqcols)
return Apcols, Aqcols, Vpcols, Vqcols


Expand Down

0 comments on commit d415805

Please sign in to comment.