#9486: revert from ttnn to tt_lib all_gather in llama2_70 t3k
ayerofieiev-tt committed Jun 21, 2024
1 parent ef22a4a · commit 4b52249
Showing 4 changed files with 19 additions and 19 deletions.
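Every hunk below applies the same two-line substitution: the `ttnn.all_gather` call (keyword argument `memory_config`) is swapped back to `tt_lib.tensor.all_gather` (keyword argument `output_mem_config`), while `dim` and `num_links` stay the same. A minimal before/after sketch of that pattern follows. It assumes the `tt_lib` and `ttnn` packages are importable on a T3000 setup; the helper name `run_all_gather` and the `tensors`/`model_config` parameters are illustrative placeholders, not code from this commit.

# Illustrative sketch only; both call forms are copied verbatim from the hunks below.
import tt_lib
import ttnn  # only needed for the pre-revert form


def run_all_gather(tensors, model_config, use_ttnn=False):
    if use_ttnn:
        # Pre-revert call site (what this commit removes).
        return ttnn.all_gather(
            tensors,
            dim=3,
            num_links=model_config["ALL_GATHER_NUM_LINKS"],
            memory_config=model_config["L1_MEMCFG"],
        )
    # Post-revert call site (what this commit restores).
    return tt_lib.tensor.all_gather(
        tensors,
        dim=3,
        num_links=model_config["ALL_GATHER_NUM_LINKS"],
        output_mem_config=model_config["L1_MEMCFG"],
    )

The only differences between the two forms are the op namespace and the memory-config keyword; everything else at each call site is untouched.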
models/demos/t3000/llama2_70b/tt/llama_attention_optimized.py (8 changes: 4 additions & 4 deletions)
@@ -649,11 +649,11 @@ def attn_selfout(
         if self.emulated:
             attn_output = tt_all_gather_torch(attn_output, dim=-1)
         else:
-            attn_output = ttnn.all_gather(
+            attn_output = tt_lib.tensor.all_gather(
                 attn_output,
                 dim=3,
                 num_links=self.model_config["ALL_GATHER_NUM_LINKS"],
-                memory_config=self.model_config["L1_MEMCFG"],
+                output_mem_config=self.model_config["L1_MEMCFG"],
             )

         for i in range(len(attn_output)):
@@ -911,11 +911,11 @@ def prefill_attn_selfout(self, attn_output: List[tt_lib.tensor.Tensor]) -> List[
         if self.emulated:
             attn_output = tt_all_gather_torch(attn_output, dim=-1)
         else:
-            attn_output = ttnn.all_gather(
+            attn_output = tt_lib.tensor.all_gather(
                 attn_output,
                 dim=3,
                 num_links=self.model_config["ALL_GATHER_NUM_LINKS"],
-                memory_config=self.model_config["DRAM_MEMCFG"],
+                output_mem_config=self.model_config["DRAM_MEMCFG"],
             )

         for i in range(len(attn_output)):
models/demos/t3000/llama2_70b/tt/llama_decoder_optimized.py (16 changes: 8 additions & 8 deletions)
@@ -305,11 +305,11 @@ def decode_forward(
         if self.emulated:
             xs_replicated = tt_all_gather_torch(xs_replicated, dim=-1)
         else:
-            xs_replicated = ttnn.all_gather(
+            xs_replicated = tt_lib.tensor.all_gather(
                 xs_replicated,
                 dim=3,
                 num_links=self.model_config["ALL_GATHER_NUM_LINKS"],
-                memory_config=self.model_config["L1_MEMCFG"],
+                output_mem_config=self.model_config["L1_MEMCFG"],
             )

         for i in range(self.num_devices):
@@ -360,11 +360,11 @@ def decode_forward(
         if self.emulated:
             attn_resid_replicated = tt_all_gather_torch(attn_resid_replicated, dim=-1)
         else:
-            attn_resid_replicated = ttnn.all_gather(
+            attn_resid_replicated = tt_lib.tensor.all_gather(
                 attn_resid_replicated,
                 dim=3,
                 num_links=self.model_config["ALL_GATHER_NUM_LINKS"],
-                memory_config=self.model_config["L1_MEMCFG"],
+                output_mem_config=self.model_config["L1_MEMCFG"],
             )

         for i in range(self.num_devices):
@@ -480,11 +480,11 @@ def prefill_forward(
         if self.emulated:
             xs_replicated = tt_all_gather_torch(xs_replicated, dim=-1)
         else:
-            xs_replicated = ttnn.all_gather(
+            xs_replicated = tt_lib.tensor.all_gather(
                 xs_replicated,
                 dim=3,
                 num_links=self.model_config["ALL_GATHER_NUM_LINKS"],
-                memory_config=self.model_config["DRAM_MEMCFG"],
+                output_mem_config=self.model_config["DRAM_MEMCFG"],
             )

         attn_norm_interleaved = self.sharded_rmsnorm(xs_replicated, self.norm_eps, self.attn_norm_list)
@@ -515,11 +515,11 @@ def prefill_forward(
         if self.emulated:
             attn_resid_replicated = tt_all_gather_torch(attn_resid_replicated, dim=-1)
         else:
-            attn_resid_replicated = ttnn.all_gather(
+            attn_resid_replicated = tt_lib.tensor.all_gather(
                 attn_resid_replicated,
                 dim=3,
                 num_links=self.model_config["ALL_GATHER_NUM_LINKS"],
-                memory_config=self.model_config["L1_MEMCFG"],
+                output_mem_config=self.model_config["L1_MEMCFG"],
             )

         ffn_norm_interleaved = self.sharded_rmsnorm(attn_resid_replicated, self.norm_eps, self.ffn_norm_list)
models/demos/t3000/llama2_70b/tt/llama_mlp_optimized.py (6 changes: 3 additions & 3 deletions)
@@ -210,7 +210,7 @@ def prefill_forward(self, x: List[tt_lib.tensor.Tensor]) -> List[tt_lib.tensor.T
         if self.emulated:
             hidden_states = tt_all_gather_torch(hidden_states, dim=-1)
         else:
-            hidden_states = ttnn.all_gather(
+            hidden_states = tt_lib.tensor.all_gather(
                 hidden_states,
                 dim=3,
                 num_links=self.model_config["ALL_GATHER_NUM_LINKS"],
@@ -270,11 +270,11 @@ def decode_forward(self, x: List[tt_lib.tensor.Tensor]) -> List[tt_lib.tensor.Te
         if self.emulated:
             hidden_states = tt_all_gather_torch(hidden_states, dim=-1)
         else:
-            hidden_states = ttnn.all_gather(
+            hidden_states = tt_lib.tensor.all_gather(
                 hidden_states,
                 dim=3,
                 num_links=self.model_config["ALL_GATHER_NUM_LINKS"],
-                memory_config=self.model_config["L1_MEMCFG"],
+                output_mem_config=self.model_config["L1_MEMCFG"],
             )

         # Put AllGather results in L1 Sharded
models/demos/t3000/llama2_70b/tt/llama_model_optimized.py (8 changes: 4 additions & 4 deletions)
@@ -368,11 +368,11 @@ def decode_forward(
         if self.emulated:
             xs = tt_all_gather_torch(xs, dim=-1)
         else:
-            xs = ttnn.all_gather(
+            xs = tt_lib.tensor.all_gather(
                 xs,
                 dim=3,
                 num_links=self.model_config["ALL_GATHER_NUM_LINKS"],
-                memory_config=self.model_config["L1_MEMCFG"],
+                output_mem_config=self.model_config["L1_MEMCFG"],
             )

         ## Duplicate layernorm
@@ -492,11 +492,11 @@ def prefill_forward(
         if self.emulated:
             xs = tt_all_gather_torch(xs, dim=-1)
         else:
-            xs = ttnn.all_gather(
+            xs = tt_lib.tensor.all_gather(
                 xs,
                 dim=3,
                 num_links=self.model_config["ALL_GATHER_NUM_LINKS"],
-                memory_config=self.model_config["DRAM_MEMCFG"],
+                output_mem_config=self.model_config["DRAM_MEMCFG"],
             )

         ## Duplicate layernorm
