#13136: Use ring as common topology across all ops
Aswinmcw committed Sep 30, 2024
1 parent 13a3885 commit 12a5db9
Showing 7 changed files with 60 additions and 41 deletions.
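
For orientation (not part of the commit): the change below flips the C++ default topology of ttnn.all_gather from Linear to Ring, so every line-topology call site in this diff now requests Linear explicitly. A minimal, hedged Python sketch of the resulting caller-side choice; the mesh-open helpers and tensor shapes are assumptions, while the keyword names mirror the diff.

import torch
import ttnn

# Assumption: helper names for opening/closing a mesh; any already-opened TG mesh works.
mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(8, 4))

torch_tensor = torch.rand((1, 1, 32, 32 * 32), dtype=torch.bfloat16)
ttnn_tensor = ttnn.from_torch(
    torch_tensor,
    dtype=ttnn.bfloat16,
    layout=ttnn.TILE_LAYOUT,
    mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=3),
)
ttnn_tensor = ttnn.to_device(ttnn_tensor, mesh_device)

# Relies on the new default: a ring all-gather, appropriate for a closed loop of devices.
ring_out = ttnn.all_gather(ttnn_tensor, dim=3, num_links=1)

# Pins the previous behaviour; this is what the TG call sites below now do explicitly.
line_out = ttnn.all_gather(ttnn_tensor, dim=3, num_links=1, topology=ttnn.Topology.Linear)

ttnn.close_mesh_device(mesh_device)
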
18 changes: 16 additions & 2 deletions models/demos/tg/llama3_70b/tt/llama_common.py
@@ -36,7 +36,12 @@ def tt_all_reduce(input_tensor, mesh_device, cluster_axis, dim=0, num_links=2, m
input_tensor = ttnn.to_memory_config(input_tensor, ttnn.DRAM_MEMORY_CONFIG)

gathered_tensor = ttnn.all_gather(
input_tensor, dim, num_links=num_links, cluster_axis=cluster_axis, mesh_device=mesh_device
input_tensor,
dim,
num_links=num_links,
cluster_axis=cluster_axis,
mesh_device=mesh_device,
topology=ttnn.Topology.Linear,
)
reduced_tensors = ttnn.experimental.fast_reduce_nc(
gathered_tensor, dims=[dim], output=None, compute_kernel_config=None
@@ -49,7 +54,14 @@ def tt_all_gather(input_tensor, mesh_device, cluster_axis, dim, num_links=2, mem
# Ensure the input tensor is in the correct memory configuration
input_tensor = ttnn.to_memory_config(input_tensor, ttnn.DRAM_MEMORY_CONFIG)

return ttnn.all_gather(input_tensor, dim, num_links=num_links, cluster_axis=cluster_axis, mesh_device=mesh_device)
return ttnn.all_gather(
input_tensor,
dim,
num_links=num_links,
cluster_axis=cluster_axis,
mesh_device=mesh_device,
topology=ttnn.Topology.Linear,
)


def tt_sharded_all_reduce(input_tensor, mesh_device, cluster_axis, dim=0, num_links=2, memory_config=None):
@@ -60,6 +72,7 @@ def tt_sharded_all_reduce(input_tensor, mesh_device, cluster_axis, dim=0, num_li
cluster_axis=cluster_axis,
mesh_device=mesh_device,
memory_config=memory_config,
topology=ttnn.Topology.Linear,
)
# Fast_reduce_nc does not support sharded memory configuration, convert to interleaved
gathered_tensor = ttnn.to_memory_config(gathered_tensor, ttnn.L1_MEMORY_CONFIG)
@@ -79,4 +92,5 @@ def tt_sharded_all_gather(input_tensor, mesh_device, cluster_axis, dim, num_link
cluster_axis=cluster_axis,
mesh_device=mesh_device,
memory_config=memory_config,
topology=ttnn.Topology.Linear,
)
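
A hedged usage sketch for the helpers above (not part of the commit): the signatures come from the hunks shown, while mesh_device and hidden_states are assumed to exist. Because the helpers now hard-code topology=ttnn.Topology.Linear, their callers in the llama3 model need no changes.

# Illustrative only: exercising the helpers defined above on an assumed TG mesh.
# `mesh_device` is an already-opened 8x4 mesh and `hidden_states` a ttnn tensor
# sharded across it; both are assumptions rather than part of this commit.
gathered = tt_all_gather(hidden_states, mesh_device, cluster_axis=0, dim=3, num_links=2)
reduced = tt_all_reduce(hidden_states, mesh_device, cluster_axis=1, dim=0, num_links=2)
# Each helper passes topology=ttnn.Topology.Linear internally, so the new Ring
# default never reaches the TG line topology.
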
27 changes: 23 additions & 4 deletions tests/ttnn/multichip_unit_tests/test_multidevice_TG.py
@@ -1237,7 +1237,9 @@ def test_device_line_all_gather_8x1(mesh_device):
full_tensor, mesh_mapper=ShardTensor2dMesh(mesh_device, mesh_shape=(rows, cols), dims=(-2, -1))
)
ttnn_tensor = ttnn.to_device(ttnn_tensor, mesh_device)
ttnn_tensor = ttnn.all_gather(ttnn_tensor, dim=2, cluster_axis=0, mesh_device=mesh_device, num_links=1)
ttnn_tensor = ttnn.all_gather(
ttnn_tensor, dim=2, cluster_axis=0, mesh_device=mesh_device, num_links=1, topology=ttnn.Topology.Linear
)

device_tensors: typing.List[ttnn.Tensor] = ttnn.get_device_tensors(ttnn_tensor)
for index, device_tensor in enumerate(device_tensors):
@@ -1277,7 +1279,14 @@ def test_device_line_all_gather_8x4_data(mesh_device, cluster_axis: int, dim: in
full_tensor, mesh_mapper=ShardTensor2dMesh(mesh_device, mesh_shape=(rows, cols), dims=(-2, -1))
)
ttnn_tensor = ttnn.to_device(ttnn_tensor, mesh_device)
ttnn_tensor = ttnn.all_gather(ttnn_tensor, dim=dim, cluster_axis=cluster_axis, mesh_device=mesh_device, num_links=1)
ttnn_tensor = ttnn.all_gather(
ttnn_tensor,
dim=dim,
cluster_axis=cluster_axis,
mesh_device=mesh_device,
num_links=1,
topology=ttnn.Topology.Linear,
)

device_tensors: typing.List[ttnn.Tensor] = ttnn.get_device_tensors(ttnn_tensor)

@@ -1416,7 +1425,9 @@ def test_line_all_gather_column_major(mesh_device):
)
ttnn_tensor = ttnn.to_device(ttnn_tensor, mesh_device)
ttnn.visualize_mesh_device(mesh_device, tensor=ttnn_tensor)
ttnn_tensor = ttnn.all_gather(ttnn_tensor, dim=3, cluster_axis=0, mesh_device=mesh_device, num_links=1)
ttnn_tensor = ttnn.all_gather(
ttnn_tensor, dim=3, cluster_axis=0, mesh_device=mesh_device, num_links=1, topology=ttnn.Topology.Linear
)
tt_outputs = ttnn.to_torch(ttnn_tensor, mesh_composer=ListMeshToTensor(mesh_device))
for output in tt_outputs[1:]:
assert output.shape == (1, 1, 32, 32 * 8)
@@ -1456,7 +1467,14 @@ def test_device_line_all_gather_8x4_data(mesh_device, cluster_axis: int, dim: in
full_tensor, mesh_mapper=ShardTensor2dMesh(mesh_device, mesh_shape=(rows, cols), dims=(-2, -1))
)
ttnn_tensor = ttnn.to_device(ttnn_tensor, mesh_device)
ttnn_tensor = ttnn.all_gather(ttnn_tensor, dim=dim, cluster_axis=cluster_axis, mesh_device=mesh_device, num_links=1)
ttnn_tensor = ttnn.all_gather(
ttnn_tensor,
dim=dim,
cluster_axis=cluster_axis,
mesh_device=mesh_device,
num_links=1,
topology=ttnn.Topology.Linear,
)


@pytest.mark.parametrize("mesh_device", [pytest.param((8, 4), id="8x4_grid")], indirect=True)
@@ -1548,6 +1566,7 @@ def test_sharded_distributed_layernorm(mesh_device, input_width, input_height, c
cluster_axis=1,
mesh_device=mesh_device,
memory_config=gathered_stats_sharded_memory_config,
topology=ttnn.Topology.Linear,
)

tt_output_tensor = ttnn.rms_norm_post_all_gather(
7 changes: 6 additions & 1 deletion tests/ttnn/unit_tests/operations/test_all_gather.py
@@ -88,6 +88,7 @@ def run_with_trace(
memory_config=output_mem_config,
num_workers=n_worker,
num_buffers_per_channel=n_buffer,
topology=all_gather_topology,
)
for d in devices:
ttnn.synchronize_device(d)
@@ -103,6 +104,7 @@
memory_config=output_mem_config,
num_workers=n_worker,
num_buffers_per_channel=n_buffer,
topology=all_gather_topology,
)
ttnn.end_trace_capture(t3k_mesh_device, trace_id, cq_id=0)
for d in devices:
@@ -157,7 +159,9 @@ def run_all_gather_impl(

input_tensor_mesh = ttnn.aggregate_as_tensor(tt_input_tensors)
for i in range(num_iters):
tt_out_tensor = ttnn.all_gather(input_tensor_mesh, dim, num_links=num_links, memory_config=mem_config)
tt_out_tensor = ttnn.all_gather(
input_tensor_mesh, dim, num_links=num_links, memory_config=mem_config, topology=all_gather_topology
)

for d in devices:
ttnn.synchronize_device(d)
@@ -1127,6 +1131,7 @@ def run_all_gather_sharded(
memory_config=output_mem_config,
num_workers=n_worker,
num_buffers_per_channel=n_buffer,
topology=all_gather_topology,
)
## Wait for completion
for d in devices:
@@ -134,7 +134,12 @@ def run_line_all_gather_on_TG_with_mesh_tensor_along_rows(
# ttnn.visualize_mesh_device(mesh_device, tensor=ttnn_tensor)
for _ in range(num_iters):
ttnn_tensor_out = ttnn.all_gather(
ttnn_tensor, dim=dim, cluster_axis=cluster_axis, mesh_device=mesh_device, num_links=num_links
ttnn_tensor,
dim=dim,
cluster_axis=cluster_axis,
mesh_device=mesh_device,
num_links=num_links,
topology=ttnn.Topology.Linear,
)

concat_dims = (3, 2) if cluster_axis == 0 else (2, 3)
2 changes: 1 addition & 1 deletion ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather.hpp
@@ -30,7 +30,7 @@ struct ExecuteAllGather {
const std::optional<ttnn::MemoryConfig>& memory_config = std::nullopt,
const std::optional<size_t> num_workers = std::nullopt,
const std::optional<size_t> num_buffers_per_channel = std::nullopt,
const ttnn::ccl::Topology topology = ttnn::ccl::Topology::Linear);
const ttnn::ccl::Topology topology = ttnn::ccl::Topology::Ring);
};

} // namespace ccl
39 changes: 7 additions & 32 deletions ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather_pybind.cpp
@@ -68,7 +68,7 @@ void bind_all_gather(pybind11::module& module, const ccl_operation_t& operation,
py::arg("memory_config") = std::nullopt,
py::arg("num_workers") = std::nullopt,
py::arg("num_buffers_per_channel") = std::nullopt,
py::arg("topology") = ttnn::ccl::Topology::Linear});
py::arg("topology") = ttnn::ccl::Topology::Ring});
}

} // namespace detail
@@ -85,35 +85,7 @@ void py_bind_all_gather(pybind11::module& module) {
Args:
* :attr:`input_tensor` (ttnn.Tensor): multi-device tensor
* :attr:`dim` (int)
Keyword Args:
* :attr:`num_links` (int): Number of links to use for the all-gather operation.
* :attr:`memory_config` (Optional[ttnn.MemoryConfig]): Memory configuration for the operation.
* :attr:`num_workers` (int): Number of workers to use for the operation.
* :attr:`num_buffers_per_channel` (int): Number of buffers per channel to use for the operation.
* :attr:`topology`: Topology to be used for the operation. Allowable options are Linear and Ring
Example:
>>> tensor = ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16), device=device)
>>> output = ttnn.all_gather(tensor, dim=0)
)doc");

/*
detail::bind_line_all_gather(
module,
ttnn::line_all_gather,
R"doc(line_all_gather(input_tensor: ttnn.Tensor, dim: int, *, num_links: int = 1, memory_config: Optional[ttnn.MemoryConfig] = None) -> ttnn.Tensor
Performs an all-gather operation on multi-device :attr:`input_tensor` across all devices.
Args:
* :attr:`input_tensor` (ttnn.Tensor):
multi-device tensor
* :attr:`dim` (int):
Dimension to perform the all-gather operation on.
After the all-gather operation, the size of the :attr:`dim`
dimension will be larger by a factor of the number of devices in the line.
* Following are applicable only for Linear Topology
* :attr:`cluster_axis` (int):
Provided a MeshTensor, the axis corresponding to MeshDevice
to perform the line-all-gather operation on.
@@ -123,13 +95,16 @@ void py_bind_all_gather(pybind11::module& module) {
Keyword Args:
* :attr:`num_links` (int): Number of links to use for the all-gather operation.
* :attr:`memory_config` (Optional[ttnn.MemoryConfig]): Memory configuration for the operation.
* :attr:`num_workers` (int): Number of workers to use for the operation.
* :attr:`num_buffers_per_channel` (int): Number of buffers per channel to use for the operation.
* :attr:`topology`: Topology to be used for the operation. Allowable options are Linear and Ring
Example:
>>> tensor = ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16), device=device)
>>> output = ttnn.all_gather(tensor, dim=0, topology=ttnn.Topology.Linear)
>>> output = ttnn.all_gather(tensor, dim=0)
)doc");*/
)doc");
}

} // namespace ttnn::operations::ccl
@@ -220,6 +220,7 @@ Tensor all_gather(
const std::optional<size_t> user_defined_num_buffers_per_channel,
const ttnn::ccl::Topology topology) {

TT_FATAL(topology == ttnn::ccl::Topology::Linear, "This API currently supports only the Linear topology");
const auto mesh_view = mesh_device.get_view();
std::size_t num_devices = (cluster_axis == 0) ? mesh_view->num_rows() : mesh_view->num_cols();

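
The new TT_FATAL above guards the cluster_axis overload, which implements only a line all-gather; with Ring now the global default, callers have to request Linear themselves. A hedged sketch of that calling convention (tensor and mesh names are assumptions; the keywords mirror the diff):

# Assumed inputs: `ttnn_tensor` is a mesh-sharded tensor on an opened TG `mesh_device`.
# Omitting `topology` would fall back to the new Ring default, which this overload
# rejects, so the cluster_axis call sites in this commit pass Linear explicitly.
out = ttnn.all_gather(
    ttnn_tensor,
    dim=2,
    cluster_axis=0,
    mesh_device=mesh_device,
    num_links=1,
    topology=ttnn.Topology.Linear,
)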
