From c32bc74fab9124b336d9ec7364d208edd4292f3c Mon Sep 17 00:00:00 2001 From: cfjchu Date: Mon, 16 Sep 2024 17:05:38 +0000 Subject: [PATCH] #10608: add device virtualization btw MeshDevice and physical system - This fixes #10608, #10419 by adding a logical 2D mesh to remap the physical mesh onto a flattened 2d grid. - A logical to physical coordinate translation map is now supplied per supported system. --- conftest.py | 16 +- .../llama2_70b/tests/test_llama_generation.py | 2 +- .../t3000/llama2_70b/tests/test_llama_perf.py | 2 +- .../tests/test_llama_stress_test.py | 2 +- .../demos/t3000/llama2_70b/tt/llama_common.py | 2 +- .../llama3_70b/tt/llama_attention_galaxy.py | 2 +- .../tg/llama3_70b/tt/llama_mlp_galaxy.py | 4 +- models/utility_functions.py | 7 +- tests/scripts/tgg/run_tgg_unit_tests.sh | 1 + tests/sweep_framework/README.md | 4 +- .../sweep_framework/sweeps/line_all_gather.py | 3 +- .../test_mesh_device_TGG.py | 12 + .../test_multidevice_TG.py | 10 +- .../ttnn/unit_tests/gtests/test_ccl_on_tg.cpp | 32 +- .../unit_tests/gtests/ttnn_test_fixtures.hpp | 13 +- tests/ttnn/unit_tests/test_multi_device.py | 9 +- .../impl/device/mesh_configurations/N300.json | 5 + .../device/mesh_configurations/T3000.json | 6 + .../impl/device/mesh_configurations/TG.json | 36 ++ .../impl/device/mesh_configurations/TGG.json | 68 ++++ .../device/mesh_configurations/device.json | 5 + tt_metal/impl/device/mesh_device.cpp | 375 ++++++++++++------ tt_metal/impl/device/mesh_device.hpp | 121 +++++- tt_metal/llrt/tt_cluster.cpp | 2 +- ttnn/cpp/pybind11/multi_device.hpp | 31 +- ttnn/cpp/ttnn/events.cpp | 4 +- ttnn/cpp/ttnn/multi_device.cpp | 8 +- ttnn/cpp/ttnn/multi_device.hpp | 6 +- ttnn/ttnn/__init__.py | 7 +- ttnn/ttnn/multi_device.py | 24 +- 30 files changed, 584 insertions(+), 235 deletions(-) create mode 100644 tests/ttnn/multichip_unit_tests/test_mesh_device_TGG.py create mode 100644 tt_metal/impl/device/mesh_configurations/N300.json create mode 100644 tt_metal/impl/device/mesh_configurations/T3000.json create mode 100644 tt_metal/impl/device/mesh_configurations/TG.json create mode 100644 tt_metal/impl/device/mesh_configurations/TGG.json create mode 100644 tt_metal/impl/device/mesh_configurations/device.json diff --git a/conftest.py b/conftest.py index 75036c294d2..36a8d35b6b5 100644 --- a/conftest.py +++ b/conftest.py @@ -207,9 +207,7 @@ def mesh_device(request, silicon_arch_name, silicon_arch_wormhole_b0, device_par request.node.pci_ids = [ttnn.GetPCIeDeviceID(i) for i in device_ids[:num_devices_requested]] - mesh_device = ttnn.open_mesh_device( - mesh_shape, device_ids[:num_devices_requested], dispatch_core_type=get_dispatch_core_type(), **device_params - ) + mesh_device = ttnn.open_mesh_device(mesh_shape, dispatch_core_type=get_dispatch_core_type(), **device_params) logger.debug(f"multidevice with {mesh_device.get_num_devices()} devices is created") yield mesh_device @@ -235,9 +233,9 @@ def pcie_mesh_device(request, silicon_arch_name, silicon_arch_wormhole_b0, devic mesh_device = ttnn.open_mesh_device( ttnn.MeshShape(1, num_pcie_devices_requested), - device_ids[:num_pcie_devices_requested], dispatch_core_type=get_dispatch_core_type(), **device_params, + physical_device_ids=device_ids[:num_pcie_devices_requested], ) logger.debug(f"multidevice with {mesh_device.get_num_devices()} devices is created") @@ -256,17 +254,9 @@ def t3k_mesh_device(request, silicon_arch_name, silicon_arch_wormhole_b0, device if ttnn.get_num_devices() < 8: pytest.skip() - device_ids = [0, 4, 5, 1, 2, 6, 7, 3] - try: - num_devices_requested = min(request.param, len(device_ids)) - except (ValueError, AttributeError): - num_devices_requested = len(device_ids) - - request.node.pci_ids = [ttnn.GetPCIeDeviceID(i) for i in device_ids[:num_devices_requested]] mesh_device = ttnn.open_mesh_device( - ttnn.MeshShape(1, num_devices_requested), - device_ids[:num_devices_requested], + ttnn.MeshShape(2, 4), dispatch_core_type=get_dispatch_core_type(), **device_params, ) diff --git a/models/demos/t3000/llama2_70b/tests/test_llama_generation.py b/models/demos/t3000/llama2_70b/tests/test_llama_generation.py index 48f6fe2e34c..253e550008d 100644 --- a/models/demos/t3000/llama2_70b/tests/test_llama_generation.py +++ b/models/demos/t3000/llama2_70b/tests/test_llama_generation.py @@ -147,7 +147,7 @@ def test_LlamaModel_inference( if t3k_mesh_device.get_num_devices() < n_devices and not emulated: pytest.skip(f"Requires at {n_devices} devices to run") - compute_grid_size = t3k_mesh_device.get_device(0).compute_with_storage_grid_size() + compute_grid_size = t3k_mesh_device.compute_with_storage_grid_size() if compute_grid_size.x < model_config["MAX_GRID_SIZE"][0] or compute_grid_size.y < model_config["MAX_GRID_SIZE"][1]: pytest.skip(f"Requires grid size of at least {model_config['MAX_GRID_SIZE']} to run") diff --git a/models/demos/t3000/llama2_70b/tests/test_llama_perf.py b/models/demos/t3000/llama2_70b/tests/test_llama_perf.py index d7f9e674b5d..d9d10d7c04d 100644 --- a/models/demos/t3000/llama2_70b/tests/test_llama_perf.py +++ b/models/demos/t3000/llama2_70b/tests/test_llama_perf.py @@ -316,7 +316,7 @@ def test_Llama_perf_host( if t3k_mesh_device.get_num_devices() < n_devices and not emulated: pytest.skip(f"Requires at {n_devices} devices to run") - compute_grid_size = t3k_mesh_device.get_device(0).compute_with_storage_grid_size() + compute_grid_size = t3k_mesh_device.compute_with_storage_grid_size() if compute_grid_size.x < model_config["MAX_GRID_SIZE"][0] or compute_grid_size.y < model_config["MAX_GRID_SIZE"][1]: pytest.skip(f"Requires grid size of at least {model_config['MAX_GRID_SIZE']} to run") diff --git a/models/demos/t3000/llama2_70b/tests/test_llama_stress_test.py b/models/demos/t3000/llama2_70b/tests/test_llama_stress_test.py index 3488a7e50cc..fc71590add5 100644 --- a/models/demos/t3000/llama2_70b/tests/test_llama_stress_test.py +++ b/models/demos/t3000/llama2_70b/tests/test_llama_stress_test.py @@ -145,7 +145,7 @@ def test_Llama_stress_test( if t3k_mesh_device.get_num_devices() < n_devices and not emulated: pytest.skip(f"Requires at {n_devices} devices to run") - compute_grid_size = t3k_mesh_device.get_device(0).compute_with_storage_grid_size() + compute_grid_size = t3k_mesh_device.compute_with_storage_grid_size() if compute_grid_size.x < model_config["MAX_GRID_SIZE"][0] or compute_grid_size.y < model_config["MAX_GRID_SIZE"][1]: pytest.skip(f"Requires grid size of at least {model_config['MAX_GRID_SIZE']} to run") diff --git a/models/demos/t3000/llama2_70b/tt/llama_common.py b/models/demos/t3000/llama2_70b/tt/llama_common.py index 78aec8ecab0..39aff936e33 100644 --- a/models/demos/t3000/llama2_70b/tt/llama_common.py +++ b/models/demos/t3000/llama2_70b/tt/llama_common.py @@ -192,7 +192,7 @@ def check_mesh_device(t3k_mesh_device, model_config): model_config["NUM_DEVICES"], ) - compute_grid_size = t3k_mesh_device.get_device(0).compute_with_storage_grid_size() + compute_grid_size = t3k_mesh_device.compute_with_storage_grid_size() assert not ( compute_grid_size.x < model_config["MAX_GRID_SIZE"][0] or compute_grid_size.y < model_config["MAX_GRID_SIZE"][1] ), ("Requires grid size of at least %d to run", model_config["MAX_GRID_SIZE"]) diff --git a/models/demos/tg/llama3_70b/tt/llama_attention_galaxy.py b/models/demos/tg/llama3_70b/tt/llama_attention_galaxy.py index 0888550e944..6723c2b1e2c 100644 --- a/models/demos/tg/llama3_70b/tt/llama_attention_galaxy.py +++ b/models/demos/tg/llama3_70b/tt/llama_attention_galaxy.py @@ -448,7 +448,7 @@ def attn_mqa( value_layer.deallocate(True) program_config = ttnn.SDPAProgramConfig( - compute_with_storage_grid_size=self.mesh_device.get_device(0).compute_with_storage_grid_size(), + compute_with_storage_grid_size=self.mesh_device.compute_with_storage_grid_size(), q_chunk_size=0, # unused k_chunk_size=0, # unused ) diff --git a/models/demos/tg/llama3_70b/tt/llama_mlp_galaxy.py b/models/demos/tg/llama3_70b/tt/llama_mlp_galaxy.py index 05950e220d0..67d1e3a1918 100644 --- a/models/demos/tg/llama3_70b/tt/llama_mlp_galaxy.py +++ b/models/demos/tg/llama3_70b/tt/llama_mlp_galaxy.py @@ -55,8 +55,8 @@ def get_mlp_model_config(self): ttnn.CoreRange( ttnn.CoreCoord(0, 0), ttnn.CoreCoord( - self.mesh_device.get_device(0).dram_grid_size().x - 1, - self.mesh_device.get_device(0).dram_grid_size().y - 1, + self.mesh_device.dram_grid_size().x - 1, + self.mesh_device.dram_grid_size().y - 1, ), ) } diff --git a/models/utility_functions.py b/models/utility_functions.py index 4f217b42284..c158ff65025 100644 --- a/models/utility_functions.py +++ b/models/utility_functions.py @@ -971,9 +971,10 @@ def get_devices_for_t3000(all_devices, num_devices): if num_devices <= 4: return all_devices[:num_devices] elif num_devices == 8: - # TODO: Generalize this for different arch - hamiltonian_ring_indices = [0, 4, 5, 1, 2, 6, 7, 3] - return [all_devices[i] for i in hamiltonian_ring_indices] + # Temporary until we move request for ring order to CCL operations directly. + # This is better because we no longer need to manually manage the ring order. + ring_indices = ttnn.get_t3k_physical_device_ids_ring() + return [all_devices[i] for i in ring_indices] else: raise NotImplementedError("Only supports 1, 2, 3, 4, and 8 chip configurations!") diff --git a/tests/scripts/tgg/run_tgg_unit_tests.sh b/tests/scripts/tgg/run_tgg_unit_tests.sh index fb58c9b9b21..aa5a2449042 100755 --- a/tests/scripts/tgg/run_tgg_unit_tests.sh +++ b/tests/scripts/tgg/run_tgg_unit_tests.sh @@ -7,6 +7,7 @@ run_tgg_tests() { TT_METAL_SLOW_DISPATCH_MODE=1 ./build/test/tt_metal/unit_tests_galaxy --gtest_filter="GalaxyFixture.*:TGGFixture.*" ./build/test/tt_metal/unit_tests_galaxy --gtest_filter="GalaxyFixture.*:TGGFixture.*" + pytest -s tests/ttnn/multichip_unit_tests/test_mesh_device_TGG.py } main() { diff --git a/tests/sweep_framework/README.md b/tests/sweep_framework/README.md index 85d779fea2a..0a7468523fa 100644 --- a/tests/sweep_framework/README.md +++ b/tests/sweep_framework/README.md @@ -190,11 +190,11 @@ def mesh_device_fixture(): assert ttnn.get_num_devices() >= 8, "Not T3000!" - device_ids = [0, 4, 5, 1, 2, 6, 7, 3] + device_ids = ttnn.get_t3k_physical_device_ids_ring() num_devices_requested = len(device_ids) mesh_device = ttnn.open_mesh_device( - ttnn.MeshShape(1, num_devices_requested), device_ids[:num_devices_requested] + ttnn.MeshShape(1, num_devices_requested), ) print("ADD: Opened device mesh") diff --git a/tests/sweep_framework/sweeps/line_all_gather.py b/tests/sweep_framework/sweeps/line_all_gather.py index f5d2ebeb25c..8172c701c62 100644 --- a/tests/sweep_framework/sweeps/line_all_gather.py +++ b/tests/sweep_framework/sweeps/line_all_gather.py @@ -62,9 +62,8 @@ def invalidate_vector(test_vector) -> Tuple[bool, Optional[str]]: def mesh_device_fixture(): - assert ttnn.get_num_devices() >= 8, "Not T3000!" - device_ids = [0, 4, 5, 1, 2, 6, 7, 3] + device_ids = ttnn.get_t3k_physical_device_ids_ring() num_devices_requested = len(device_ids) mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(1, num_devices_requested), device_ids[:num_devices_requested]) print("ALL GATHER: Opened device mesh") diff --git a/tests/ttnn/multichip_unit_tests/test_mesh_device_TGG.py b/tests/ttnn/multichip_unit_tests/test_mesh_device_TGG.py new file mode 100644 index 00000000000..d5f4b8ed905 --- /dev/null +++ b/tests/ttnn/multichip_unit_tests/test_mesh_device_TGG.py @@ -0,0 +1,12 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import torch +import pytest +import ttnn + + +@pytest.mark.parametrize("mesh_device", [pytest.param((8, 8), id="8x8_grid")], indirect=True) +def test_visualize_mesh_device(mesh_device): + ttnn.visualize_mesh_device(mesh_device) diff --git a/tests/ttnn/multichip_unit_tests/test_multidevice_TG.py b/tests/ttnn/multichip_unit_tests/test_multidevice_TG.py index 919dcc59524..ac7a996f7b9 100644 --- a/tests/ttnn/multichip_unit_tests/test_multidevice_TG.py +++ b/tests/ttnn/multichip_unit_tests/test_multidevice_TG.py @@ -210,9 +210,7 @@ def test_galaxy_matmul_2d_fracture_dram_sharded(M, K, N, weights_dtype, mesh_sha { ttnn.CoreRange( ttnn.CoreCoord(0, 0), - ttnn.CoreCoord( - mesh_device.get_device(0).dram_grid_size().x - 1, mesh_device.get_device(0).dram_grid_size().y - 1 - ), + ttnn.CoreCoord(mesh_device.dram_grid_size().x - 1, mesh_device.dram_grid_size().y - 1), ) } ) @@ -711,7 +709,7 @@ def test_fill_cache( xt = x - compute_grid_size = mesh_device.get_device(0).compute_with_storage_grid_size() + compute_grid_size = mesh_device.compute_with_storage_grid_size() num_cores = min(seq_len // 32 * num_heads, 32) # Always use max 32 cores for testing mesh_shape = ttnn.CoreRangeSet(ttnn.num_cores_to_corerange_set(num_cores, compute_grid_size, True)) input_shard_spec = ttnn.ShardSpec( @@ -783,7 +781,7 @@ def test_update_cache_decode( x_new = torch.cat((x_new, torch.zeros(32 - num_users - batch_offset, num_heads, 1, head_dim)), dim=0) assert x_new.shape[0] == 32, f"Expected x.shape[0] to be 32, got {x_new.shape[0]}" xt = x_new.permute(2, 1, 0, 3) - compute_grid_size = mesh_device.get_device(0).compute_with_storage_grid_size() + compute_grid_size = mesh_device.compute_with_storage_grid_size() num_cores = min(max(num_users, 32) // 32 * num_heads, compute_grid_size.x * compute_grid_size.y) mesh_shape = ttnn.CoreRangeSet(ttnn.num_cores_to_corerange_set(num_cores, compute_grid_size, True)) input_shard_spec = ttnn.ShardSpec( @@ -865,7 +863,7 @@ def run_test_sdpa_decode_single_iter( sharded_in=False, sharded_out=False, ): - compute_grid_size = mesh_device.get_device(0).compute_with_storage_grid_size() + compute_grid_size = mesh_device.compute_with_storage_grid_size() if grid_size[0] > compute_grid_size.x or grid_size[1] > compute_grid_size.y: pytest.skip(f"Need {grid_size} grid size to run this test but core grid is {compute_grid_size}") diff --git a/tests/ttnn/unit_tests/gtests/test_ccl_on_tg.cpp b/tests/ttnn/unit_tests/gtests/test_ccl_on_tg.cpp index fe2b76abc96..912814fa76c 100644 --- a/tests/ttnn/unit_tests/gtests/test_ccl_on_tg.cpp +++ b/tests/ttnn/unit_tests/gtests/test_ccl_on_tg.cpp @@ -55,17 +55,8 @@ TEST(TGTests, TestAllGatherDeadlock) { TT_FATAL(num_devices_in_tunnel == 4, "Expected Galaxy to have tunnel depth of 4"); TT_FATAL(num_mmio_devices * cluster_tunnel_count == 8, "Expected 8 tunnels in a Galaxy"); - std::vector all_device_ids = {}; - for (uint32_t mmio_idx = 0; mmio_idx < num_mmio_devices; mmio_idx++) { - auto tunnels_from_mmio = tt::Cluster::instance().get_tunnels_from_mmio_device(mmio_idx); - for (uint32_t tunnel_idx = 0; tunnel_idx < tunnels_from_mmio.size(); tunnel_idx++) { - auto remote_devices_in_tunnel = tunnels_from_mmio.at(tunnel_idx); - all_device_ids.insert(all_device_ids.end(), remote_devices_in_tunnel.begin(), remote_devices_in_tunnel.end()); - } - } - // Create the device mesh: Grid size is . - auto mesh = ttnn::multi_device::open_mesh_device({cluster_tunnel_count * num_mmio_devices, num_devices_in_tunnel}, all_device_ids, 0, 0, 1, DispatchCoreType::WORKER); + auto mesh = ttnn::multi_device::open_mesh_device({cluster_tunnel_count * num_mmio_devices, num_devices_in_tunnel}, 0, 0, 1, DispatchCoreType::WORKER); // Setup input data and output data containers MemoryConfig mem_cfg = MemoryConfig{ @@ -87,7 +78,7 @@ TEST(TGTests, TestAllGatherDeadlock) { // Iterate over each tunnel and run line all-gather multiple times. // For each tunnel, send adversarial traffic to the first chip, that can hang the network if the CCL is not tagged. for (uint32_t row = 0; row < 8; row++) { - auto devs = mesh.get_devices_on_row(row); + auto devs = mesh->get_devices_on_row(row); std::vector device_ids = {}; for (auto dev : devs) { device_ids.push_back(dev->id()); @@ -146,26 +137,17 @@ TEST(TGTests, TestReduceScatterDeadlock) { TT_FATAL(num_devices_in_tunnel == 4, "Expected Galaxy to have tunnel depth of 4"); TT_FATAL(num_mmio_devices * cluster_tunnel_count == 8, "Expected 8 tunnels in a Galaxy"); - std::vector all_device_ids = {}; - for (uint32_t mmio_idx = 0; mmio_idx < num_mmio_devices; mmio_idx++) { - auto tunnels_from_mmio = tt::Cluster::instance().get_tunnels_from_mmio_device(mmio_idx); - for (uint32_t tunnel_idx = 0; tunnel_idx < tunnels_from_mmio.size(); tunnel_idx++) { - auto remote_devices_in_tunnel = tunnels_from_mmio.at(tunnel_idx); - all_device_ids.insert(all_device_ids.end(), remote_devices_in_tunnel.begin(), remote_devices_in_tunnel.end()); - } - } - // Create the device mesh: Grid size is . - auto mesh = ttnn::multi_device::open_mesh_device({cluster_tunnel_count * num_mmio_devices, num_devices_in_tunnel}, all_device_ids, 0, 0, 1, DispatchCoreType::WORKER); + auto mesh = ttnn::multi_device::open_mesh_device({cluster_tunnel_count * num_mmio_devices, num_devices_in_tunnel}, 0, 0, 1, DispatchCoreType::WORKER); // Create the outer ring on which Reduce Scatter will be run. This allows us to verify that there are no deadlocks when we send CCLs to the // first tunnel (forward path). - std::vector ring_devices = mesh.get_devices_on_row(0); // Tunnel 0 - std::vector ring_devices_1 = mesh.get_devices_on_column(3); // Orthogonal to tunnel .. no deadlocks + std::vector ring_devices = mesh->get_devices_on_row(0); // Tunnel 0 + std::vector ring_devices_1 = mesh->get_devices_on_column(3); // Orthogonal to tunnel .. no deadlocks ring_devices_1 = std::vector(ring_devices_1.begin() + 1, ring_devices_1.end()); - std::vector ring_devices_2 = mesh.get_devices_on_row(7); // Tunnel 7 .. potential deadlocks with lack of buffering + std::vector ring_devices_2 = mesh->get_devices_on_row(7); // Tunnel 7 .. potential deadlocks with lack of buffering std::reverse(ring_devices_2.begin(), ring_devices_2.end()); ring_devices_2 = std::vector(ring_devices_2.begin() + 1, ring_devices_2.end()); - std::vector ring_devices_3 = mesh.get_devices_on_column(0); // Orthogonal to tunnel .. no deadlocks + std::vector ring_devices_3 = mesh->get_devices_on_column(0); // Orthogonal to tunnel .. no deadlocks std::reverse(ring_devices_3.begin(), ring_devices_3.end()); ring_devices_3 = std::vector(ring_devices_3.begin() + 1, ring_devices_3.end() - 1); diff --git a/tests/ttnn/unit_tests/gtests/ttnn_test_fixtures.hpp b/tests/ttnn/unit_tests/gtests/ttnn_test_fixtures.hpp index a645c7876db..9e779b7f0cc 100644 --- a/tests/ttnn/unit_tests/gtests/ttnn_test_fixtures.hpp +++ b/tests/ttnn/unit_tests/gtests/ttnn_test_fixtures.hpp @@ -67,19 +67,20 @@ class T3kMultiDeviceFixture : public ::testing::Test { if (num_devices < 8 or arch != tt::ARCH::WORMHOLE_B0) { GTEST_SKIP() << "Skipping T3K Multi-Device test suite on non T3K machine."; } - const auto T3K_DEVICE_IDS = DeviceIds{0, 4, 5, 1, 2, 6, 7, 3}; constexpr auto DEFAULT_NUM_COMMAND_QUEUES = 1; - mesh_device_ = std::make_unique( - MeshShape{1, num_devices}, - T3K_DEVICE_IDS, + mesh_device_ = MeshDevice::create( + MeshShape{2, 4}, DEFAULT_L1_SMALL_SIZE, DEFAULT_TRACE_REGION_SIZE, DEFAULT_NUM_COMMAND_QUEUES, DispatchCoreType::WORKER); } - void TearDown() override { mesh_device_.reset(); } - std::unique_ptr mesh_device_; + void TearDown() override { + mesh_device_->close_devices(); + mesh_device_.reset(); + } + std::shared_ptr mesh_device_; }; } // namespace ttnn::multi_device::test diff --git a/tests/ttnn/unit_tests/test_multi_device.py b/tests/ttnn/unit_tests/test_multi_device.py index e286993f855..f1c2728857f 100644 --- a/tests/ttnn/unit_tests/test_multi_device.py +++ b/tests/ttnn/unit_tests/test_multi_device.py @@ -23,8 +23,7 @@ def test_mesh_device_open_close_explicit(silicon_arch_name, silicon_arch_wormhol if num_pcie_devices <= 1: pytest.skip("Requires multiple devices to run") - mesh_shape, device_ids = ttnn.MeshShape(2, 2), ttnn.get_pcie_device_ids() - multi_device = ttnn.open_mesh_device(mesh_shape, device_ids) + multi_device = ttnn.open_mesh_device(ttnn.MeshShape(2, 2)) ttnn.close_mesh_device(multi_device) @@ -34,12 +33,12 @@ def test_multi_device_subset_mesh(silicon_arch_name, silicon_arch_wormhole_b0): if num_pcie_devices <= 1: pytest.skip("Requires multiple devices to run") - mesh_shape, device_ids = ttnn.MeshShape(1, 2), ttnn.get_pcie_device_ids() - multi_device = ttnn.open_mesh_device(mesh_shape, device_ids) + mesh_shape = ttnn.MeshShape(1, 2) + multi_device = ttnn.open_mesh_device(mesh_shape) assert multi_device.get_num_devices() == 2 ttnn.close_mesh_device(multi_device) - multi_device = ttnn.open_mesh_device(mesh_shape, device_ids) + multi_device = ttnn.open_mesh_device(mesh_shape) assert multi_device.get_num_devices() == 2 ttnn.close_mesh_device(multi_device) diff --git a/tt_metal/impl/device/mesh_configurations/N300.json b/tt_metal/impl/device/mesh_configurations/N300.json new file mode 100644 index 00000000000..9891c42cc11 --- /dev/null +++ b/tt_metal/impl/device/mesh_configurations/N300.json @@ -0,0 +1,5 @@ +{ + "logical_to_physical_coordinates": [ + [[0, 0], [0, 0, 0, 0]], [[0, 1], [0, 1, 0, 0]] + ] +} diff --git a/tt_metal/impl/device/mesh_configurations/T3000.json b/tt_metal/impl/device/mesh_configurations/T3000.json new file mode 100644 index 00000000000..2c62209d01f --- /dev/null +++ b/tt_metal/impl/device/mesh_configurations/T3000.json @@ -0,0 +1,6 @@ +{ + "logical_to_physical_coordinates": [ + [[0, 0], [0, 0, 0, 0]], [[0, 1], [0, 1, 0, 0]], [[0, 2], [0, 2, 0, 0]], [[0, 3], [0, 3, 0, 0]], + [[1, 0], [1, 3, 0, 0]], [[1, 1], [1, 2, 0, 0]], [[1, 2], [1, 1, 0, 0]], [[1, 3], [1, 0, 0, 0]] + ] +} diff --git a/tt_metal/impl/device/mesh_configurations/TG.json b/tt_metal/impl/device/mesh_configurations/TG.json new file mode 100644 index 00000000000..63860eff6b6 --- /dev/null +++ b/tt_metal/impl/device/mesh_configurations/TG.json @@ -0,0 +1,36 @@ +{ + "logical_to_physical_coordinates": [ + [[0, 0], [7, 3, 0, 1]], + [[0, 1], [7, 2, 0, 1]], + [[0, 2], [7, 1, 0, 1]], + [[0, 3], [7, 0, 0, 1]], + [[1, 0], [6, 3, 0, 1]], + [[1, 1], [6, 2, 0, 1]], + [[1, 2], [6, 1, 0, 1]], + [[1, 3], [6, 0, 0, 1]], + [[2, 0], [5, 3, 0, 1]], + [[2, 1], [5, 2, 0, 1]], + [[2, 2], [5, 1, 0, 1]], + [[2, 3], [5, 0, 0, 1]], + [[3, 0], [4, 3, 0, 1]], + [[3, 1], [4, 2, 0, 1]], + [[3, 2], [4, 1, 0, 1]], + [[3, 3], [4, 0, 0, 1]], + [[4, 0], [3, 3, 0, 1]], + [[4, 1], [3, 2, 0, 1]], + [[4, 2], [3, 1, 0, 1]], + [[4, 3], [3, 0, 0, 1]], + [[5, 0], [2, 3, 0, 1]], + [[5, 1], [2, 2, 0, 1]], + [[5, 2], [2, 1, 0, 1]], + [[5, 3], [2, 0, 0, 1]], + [[6, 0], [1, 3, 0, 1]], + [[6, 1], [1, 2, 0, 1]], + [[6, 2], [1, 1, 0, 1]], + [[6, 3], [1, 0, 0, 1]], + [[7, 0], [0, 3, 0, 1]], + [[7, 1], [0, 2, 0, 1]], + [[7, 2], [0, 1, 0, 1]], + [[7, 3], [0, 0, 0, 1]] + ] +} diff --git a/tt_metal/impl/device/mesh_configurations/TGG.json b/tt_metal/impl/device/mesh_configurations/TGG.json new file mode 100644 index 00000000000..904acf6758d --- /dev/null +++ b/tt_metal/impl/device/mesh_configurations/TGG.json @@ -0,0 +1,68 @@ +{ + "logical_to_physical_coordinates": [ + [[0, 0], [7, 3, 0, 1]], + [[0, 1], [7, 2, 0, 1]], + [[0, 2], [7, 1, 0, 1]], + [[0, 3], [7, 0, 0, 1]], + [[1, 0], [6, 3, 0, 1]], + [[1, 1], [6, 2, 0, 1]], + [[1, 2], [6, 1, 0, 1]], + [[1, 3], [6, 0, 0, 1]], + [[2, 0], [5, 3, 0, 1]], + [[2, 1], [5, 2, 0, 1]], + [[2, 2], [5, 1, 0, 1]], + [[2, 3], [5, 0, 0, 1]], + [[3, 0], [4, 3, 0, 1]], + [[3, 1], [4, 2, 0, 1]], + [[3, 2], [4, 1, 0, 1]], + [[3, 3], [4, 0, 0, 1]], + [[4, 0], [3, 3, 0, 1]], + [[4, 1], [3, 2, 0, 1]], + [[4, 2], [3, 1, 0, 1]], + [[4, 3], [3, 0, 0, 1]], + [[5, 0], [2, 3, 0, 1]], + [[5, 1], [2, 2, 0, 1]], + [[5, 2], [2, 1, 0, 1]], + [[5, 3], [2, 0, 0, 1]], + [[6, 0], [1, 3, 0, 1]], + [[6, 1], [1, 2, 0, 1]], + [[6, 2], [1, 1, 0, 1]], + [[6, 3], [1, 0, 0, 1]], + [[7, 0], [0, 3, 0, 1]], + [[7, 1], [0, 2, 0, 1]], + [[7, 2], [0, 1, 0, 1]], + [[7, 3], [0, 0, 0, 1]], + [[0, 4], [0, 0, 0, 2]], + [[0, 5], [0, 1, 0, 2]], + [[0, 6], [0, 2, 0, 2]], + [[0, 7], [0, 3, 0, 2]], + [[1, 4], [1, 0, 0, 2]], + [[1, 5], [1, 1, 0, 2]], + [[1, 6], [1, 2, 0, 2]], + [[1, 7], [1, 3, 0, 2]], + [[2, 4], [2, 0, 0, 2]], + [[2, 5], [2, 1, 0, 2]], + [[2, 6], [2, 2, 0, 2]], + [[2, 7], [2, 3, 0, 2]], + [[3, 4], [3, 0, 0, 2]], + [[3, 5], [3, 1, 0, 2]], + [[3, 6], [3, 2, 0, 2]], + [[3, 7], [3, 3, 0, 2]], + [[4, 4], [4, 0, 0, 2]], + [[4, 5], [4, 1, 0, 2]], + [[4, 6], [4, 2, 0, 2]], + [[4, 7], [4, 3, 0, 2]], + [[5, 4], [5, 0, 0, 2]], + [[5, 5], [5, 1, 0, 2]], + [[5, 6], [5, 2, 0, 2]], + [[5, 7], [5, 3, 0, 2]], + [[6, 4], [6, 0, 0, 2]], + [[6, 5], [6, 1, 0, 2]], + [[6, 6], [6, 2, 0, 2]], + [[6, 7], [6, 3, 0, 2]], + [[7, 4], [7, 0, 0, 2]], + [[7, 5], [7, 1, 0, 2]], + [[7, 6], [7, 2, 0, 2]], + [[7, 7], [7, 3, 0, 2]] + ] + } diff --git a/tt_metal/impl/device/mesh_configurations/device.json b/tt_metal/impl/device/mesh_configurations/device.json new file mode 100644 index 00000000000..ea1a34365f0 --- /dev/null +++ b/tt_metal/impl/device/mesh_configurations/device.json @@ -0,0 +1,5 @@ +{ + "logical_to_physical_coordinates": [ + [[0, 0], [0, 0, 0, 0]] + ] +} diff --git a/tt_metal/impl/device/mesh_device.cpp b/tt_metal/impl/device/mesh_device.cpp index e43ddc6c4fe..e90d4a8925e 100644 --- a/tt_metal/impl/device/mesh_device.cpp +++ b/tt_metal/impl/device/mesh_device.cpp @@ -2,186 +2,302 @@ // // SPDX-License-Identifier: Apache-2.0 +#include "tt_metal/impl/device/mesh_device.hpp" + #include +#include -#include "tt_metal/impl/device/mesh_device.hpp" -#include "tt_metal/impl/device/mesh_device_view.hpp" -#include "tt_metal/host_api.hpp" +#include "device/tt_cluster_descriptor_types.h" +#include "tt_metal/common/logger.hpp" #include "tt_metal/detail/tt_metal.hpp" - +#include "tt_metal/host_api.hpp" +#include "tt_metal/impl/device/mesh_device_view.hpp" namespace tt::tt_metal { -MeshDevice::MeshDevice(const MeshShape& mesh_shape, const DeviceIds &device_ids, size_t l1_small_size, size_t trace_region_size, size_t num_command_queues, DispatchCoreType dispatch_core_type) - : mesh_shape(mesh_shape) -{ - auto [num_rows, num_cols] = mesh_shape; - auto num_requested_devices = num_rows * num_cols; - auto num_available_devices = tt::tt_metal::GetNumAvailableDevices(); - TT_ASSERT(num_requested_devices <= num_available_devices, "Requested more devices than available"); - TT_ASSERT(num_requested_devices <= device_ids.size(), "User provided insufficient number of device_ids for MeshDevice"); - - this->is_galaxy_ = tt::Cluster::instance().is_galaxy_cluster(); - if (this->is_galaxy_) { - // Temp solution until we add algorithmic way to determine chip connectivity - // Map col to tunnel depth and row to tunnel count - int cluster_tunnel_depth = tt::Cluster::instance().get_mmio_device_max_tunnel_depth(0); - int cluster_tunnel_count = tt::Cluster::instance().get_mmio_device_tunnel_count(0); - int num_mmio_devices = tt::Cluster::instance().number_of_pci_devices(); - TT_FATAL(num_cols <= cluster_tunnel_depth and num_rows <= cluster_tunnel_count * num_mmio_devices, "Unsupported Galaxy mesh shape"); - - DeviceIds galaxy_device_ids; - for (int mmio_device_id = 0; mmio_device_id < num_mmio_devices; mmio_device_id++) { - auto tunnels_from_mmio = tt::Cluster::instance().get_tunnels_from_mmio_device(mmio_device_id); - for (uint32_t t = 0; t < tunnels_from_mmio.size(); t++) { - if (galaxy_device_ids.size() == num_requested_devices) { - break; - } - int col_idx = 0; - for (uint32_t ts = 1; ts < tunnels_from_mmio[t].size(); ts++) { - galaxy_device_ids.push_back(tunnels_from_mmio[t][ts]); - col_idx ++; - if (col_idx == num_cols) { - break; - } - } - } - } - managed_devices = tt::tt_metal::detail::CreateDevices(galaxy_device_ids, num_command_queues, l1_small_size, trace_region_size, dispatch_core_type); - for (int i = 0; i < num_requested_devices; i++) { - mesh_devices.emplace_back(device_ids[i], managed_devices.at(galaxy_device_ids[i])); - } - this->view = std::make_unique(*this); - } else { - managed_devices = tt::tt_metal::detail::CreateDevices(device_ids, num_command_queues, l1_small_size, trace_region_size, dispatch_core_type); - for (int i = 0; i < num_requested_devices; i++) { - mesh_devices.emplace_back(device_ids[i], managed_devices.at(device_ids[i])); - } +using LogicalCoordinate = Coordinate; +using PhysicalCoordinate = eth_coord_t; + +static std::string get_config_path(const std::string& filename) { + std::string root_path = getenv("TT_METAL_HOME") ? getenv("TT_METAL_HOME") : "./"; + return root_path + "/tt_metal/impl/device/mesh_configurations/" + filename; +} + +static std::map load_translation_map(const std::string& filename, const std::string& key) { + std::ifstream file(filename); + if (!file.is_open()) { + throw std::runtime_error("Unable to open file: " + filename); } - for (const auto& [dev_id, dev]: mesh_devices) { - log_debug(tt::LogMetal, "TTNN Dev {}: Metal Dev {}", dev_id, dev->id()); + nlohmann::json j; + try { + file >> j; + } catch (const nlohmann::json::parse_error& e) { + throw std::runtime_error("JSON parsing error in file " + filename + ": " + e.what()); + } + + if (!j.contains(key)) { + throw std::runtime_error("Key '" + key + "' not found in JSON file: " + filename); + } + + std::map result; + for (const auto& mapping : j[key]) { + if (mapping.size() != 2 || mapping[0].size() != 2 || mapping[1].size() != 4) { + throw std::runtime_error("Invalid coordinate format in JSON file: " + filename); + } + result.emplace(LogicalCoordinate{mapping[0][0], mapping[0][1]}, PhysicalCoordinate{mapping[1][1], mapping[1][0], mapping[1][2], mapping[1][3]}); } + + return result; } +MeshShape SystemMesh::get_system_mesh_shape(std::size_t system_num_devices) { + const std::unordered_map system_mesh_to_shape = { + {1, MeshShape{1, 1}}, // single-device + {2, MeshShape{1, 2}}, // N300 + {8, MeshShape{2, 4}}, // T3000; as ring to match existing tests + {32, MeshShape{8, 4}}, // TG + {64, MeshShape{8, 8}}, // TGG + }; + TT_FATAL(system_mesh_to_shape.contains(system_num_devices), "Unsupported number of devices: {}", system_num_devices); + auto shape = system_mesh_to_shape.at(system_num_devices); + log_debug(LogMetal, "Logical SystemMesh Shape: {}x{}", shape.first, shape.second); + return shape; +} -MeshDevice::~MeshDevice() { - if (not managed_devices.empty()) { - close_devices(); +std::map SystemMesh::get_system_mesh_translation_map(std::size_t system_num_devices) { + const std::unordered_map system_mesh_translation_map = { + {1, "device.json"}, + {2, "N300.json"}, + {8, "T3000.json"}, + {32, "TG.json"}, + {64, "TGG.json"}, + }; + TT_FATAL(system_mesh_translation_map.contains(system_num_devices), "Unsupported number of devices: {}", system_num_devices); + auto translation_config_file = get_config_path(system_mesh_translation_map.at(system_num_devices)); + return load_translation_map(translation_config_file, "logical_to_physical_coordinates"); +} + +bool SystemMesh::is_system_mesh_initialized() const { + return this->physical_coordinate_to_device_id.size() > 0; +} + +SystemMesh& SystemMesh::instance() { + static SystemMesh instance; + if (!instance.is_system_mesh_initialized()) { + instance.initialize(); + } + return instance; +} +void SystemMesh::initialize() { + this->physical_device_id_to_coordinate = tt::Cluster::instance().get_user_chip_ethernet_coordinates(); + for (const auto& [chip_id, physical_coordinate] : this->physical_device_id_to_coordinate) { + this->physical_coordinate_to_device_id.emplace(physical_coordinate, chip_id); } + + // Initialize the system mesh shape and translation map + auto num_devices = physical_coordinate_to_device_id.size(); + this->logical_mesh_shape = SystemMesh::get_system_mesh_shape(num_devices); + this->logical_to_physical_coordinates = SystemMesh::get_system_mesh_translation_map(num_devices); +} + +const MeshShape& SystemMesh::get_shape() const { return this->logical_mesh_shape; } +std::size_t SystemMesh::get_num_devices() const { + auto [num_rows, num_cols] = this->get_shape(); + return num_rows * num_cols; } -Device* MeshDevice::get_device(int logical_device_id) const { - for (const auto& [device_id, device] : mesh_devices) { - if (device_id == logical_device_id) { - return device; +std::vector SystemMesh::get_mapped_physical_device_ids(const MeshDeviceConfig& config) const { + std::vector physical_device_ids; + auto [system_mesh_rows, system_mesh_cols] = this->get_shape(); + auto [requested_rows, requested_cols] = config.mesh_shape; + auto [row_offset, col_offset] = config.offset; + + for (int row = 0; row < requested_rows; row++) { + for (int col = 0; col < requested_cols; col++) { + auto logical_device_id = (row + row_offset) * system_mesh_cols + (col + col_offset); + auto logical_coordinate = Coordinate{logical_device_id / system_mesh_cols, logical_device_id % system_mesh_cols}; + auto physical_coordinate = this->logical_to_physical_coordinates.at(logical_coordinate); + auto physical_device_id = this->physical_coordinate_to_device_id.at(physical_coordinate); + physical_device_ids.push_back(physical_device_id); + + log_debug(LogMetal, "Logical device ID: {}, Logical coordinate: {}, Physical coordinate: {}, Physical device ID: {}", + logical_device_id, logical_coordinate, physical_coordinate, physical_device_id); } } - TT_THROW("User has provided an invalid device index"); + return physical_device_ids; } -std::vector MeshDevice::get_devices() const -{ - std::vector devices; - for (const auto& [device_id, device] : mesh_devices) { - devices.push_back(device); +std::vector SystemMesh::map_mesh_device( + std::shared_ptr mesh_device, + size_t num_command_queues, + size_t l1_small_size, + size_t trace_region_size, + DispatchCoreType dispatch_core_type, + const std::pair& offset, + const std::vector& user_provided_physical_device_ids) { + + auto [requested_num_rows, requested_num_cols] = mesh_device->shape(); + auto [max_num_rows, max_num_cols] = this->logical_mesh_shape; + auto [row_offset, col_offset] = offset; + + log_debug(LogMetal, "Mapping MeshDevice ({}x{}) with offset: {}, {}", requested_num_rows, requested_num_cols, row_offset, col_offset); + TT_FATAL(requested_num_rows <= max_num_rows, "Requested too many rows: {} > {}", requested_num_rows, max_num_rows); + TT_FATAL(requested_num_rows*requested_num_cols <= max_num_rows*max_num_cols, "Requested submesh is too big: {}x{}", requested_num_rows, requested_num_cols); + + this->assigned_mesh_device_devices.insert({mesh_device->get_mesh_id(), mesh_device}); + + auto physical_device_ids = user_provided_physical_device_ids.empty() ? + this->get_mapped_physical_device_ids(MeshDeviceConfig{mesh_device->shape(), offset}) : + user_provided_physical_device_ids; + + this->opened_devices[mesh_device->get_mesh_id()] = tt::tt_metal::detail::CreateDevices( + physical_device_ids, num_command_queues, l1_small_size, trace_region_size, dispatch_core_type); + + std::vector mapped_devices; + for (auto physical_device_id : physical_device_ids) { + auto mapped_device = this->opened_devices[mesh_device->get_mesh_id()].at(physical_device_id); + mapped_devices.push_back(mapped_device); + this->assigned_devices[mesh_device->get_mesh_id()].push_back(physical_device_id); + this->assigned_physical_id_to_device.insert({physical_device_id, mapped_device}); } - return devices; + return mapped_devices; } -Device* MeshDevice::get_device(int row_idx, int col_idx) const { - if (not is_galaxy_) { - TT_THROW("Non-galaxy device mesh does not currently support indexing over rows and columns of a logical 2D mesh."); +void SystemMesh::unmap_mesh_device(const std::shared_ptr& mesh_device) { + auto mesh_id = mesh_device->get_mesh_id(); + + // Clean up all state related to this virtual mesh + this->assigned_mesh_device_devices.erase(mesh_id); + + // Remove the devices from assigned_physical_id_to_device + for (auto physical_id : this->assigned_devices.at(mesh_id)) { + this->assigned_physical_id_to_device.erase(physical_id); } + this->assigned_devices.erase(mesh_id); - TT_FATAL( - this->num_rows() != 0 and this->num_cols() != 0, - "#10419, Current device mesh does not support indexing by row or col indices."); - TT_FATAL(row_idx >= 0 and row_idx < this->num_rows(), "Invalid row index."); - TT_FATAL(col_idx >= 0 and col_idx < this->num_cols(), "Invalid col index."); - int idx = row_idx * this->num_cols() + col_idx; - return this->mesh_devices[idx].second; + // Close the devices + tt::tt_metal::detail::CloseDevices(this->opened_devices.at(mesh_id)); + this->opened_devices.erase(mesh_id); } -std::vector MeshDevice::get_devices_on_row(int row_idx) const { - if (not is_galaxy_) { - TT_THROW("Non-galaxy device mesh does not currently support indexing over rows and columns of a logical 2D mesh."); - } - return this->view->get_devices_on_row(row_idx); +static MeshDeviceID generate_unique_mesh_id() { + static std::atomic next_id{0}; + return next_id++; } -std::vector MeshDevice::get_devices_on_column(int col_idx) const { - if (not is_galaxy_) { - TT_THROW("Non-galaxy device mesh does not currently support indexing over rows and columns of a logical 2D mesh."); - } - return this->view->get_devices_on_column(col_idx); +MeshDevice::MeshDevice(const MeshShape& mesh_device_shape) : mesh_device_shape(mesh_device_shape), mesh_id(generate_unique_mesh_id()) {} + +std::shared_ptr MeshDevice::create( + const MeshShape& mesh_device_shape, + size_t l1_small_size, + size_t trace_region_size, + size_t num_command_queues, + DispatchCoreType dispatch_core_type, + const std::pair& offset, + const std::vector& user_provided_physical_device_ids) +{ + auto mesh_device = std::make_shared(mesh_device_shape); + mesh_device->initialize(l1_small_size, trace_region_size, num_command_queues, dispatch_core_type, offset, user_provided_physical_device_ids); + + return mesh_device; } -const DeviceIds MeshDevice::get_device_ids() const +void MeshDevice::initialize( + size_t l1_small_size, + size_t trace_region_size, + size_t num_command_queues, + DispatchCoreType dispatch_core_type, + const std::pair& offset, + const std::vector& physical_device_ids) { - DeviceIds device_ids; - for (const auto& [device_id, device] : mesh_devices) { - device_ids.push_back(device_id); + auto [num_rows, num_cols] = this->shape(); + auto num_requested_devices = num_rows * num_cols; + auto num_available_devices = tt::tt_metal::GetNumAvailableDevices(); + TT_FATAL( + num_requested_devices <= num_available_devices, + "User has requested more devices than available: {} requested, {} available", + num_requested_devices, num_available_devices); + + auto& instance = SystemMesh::instance(); + this->devices = instance.map_mesh_device( + shared_from_this(), num_command_queues, l1_small_size, trace_region_size, dispatch_core_type, offset, physical_device_ids); + this->primary_view = std::make_unique(*this); + + for (int device_index = 0; device_index < this->devices.size(); device_index++) { + this->physical_id_to_device_index.insert({this->devices[device_index]->id(), device_index}); } - return device_ids; } -int MeshDevice::num_devices() const -{ - return mesh_devices.size(); +MeshDevice::~MeshDevice() { + if (not this->devices.empty()) { + this->close_devices(); + } } -CoreCoord MeshDevice::compute_with_storage_grid_size() const { - return mesh_devices.at(0).second->compute_with_storage_grid_size(); +Device* MeshDevice::get_device_index(int logical_device_id) const { + TT_FATAL(logical_device_id >= 0 and logical_device_id < num_devices(), "Invalid device index"); + return this->devices.at(logical_device_id); } -CoreCoord MeshDevice::dram_grid_size() const { - return mesh_devices.at(0).second->dram_grid_size(); +Device* MeshDevice::get_device(int physical_device_id) const { + return this->devices.at(this->physical_id_to_device_index.at(physical_device_id)); } -tt::ARCH MeshDevice::arch() const { - return mesh_devices.at(0).second->arch(); +std::vector MeshDevice::get_devices() const { return this->devices; } + +Device* MeshDevice::get_device(int row_idx, int col_idx) const { + return this->get_device_index(row_idx * num_cols() + col_idx); } -int MeshDevice::num_rows() const -{ - return this->mesh_shape.first; +std::vector MeshDevice::get_devices_on_row(int row_idx) const { + return this->primary_view->get_devices_on_row(row_idx); } -int MeshDevice::num_cols() const -{ - return this->mesh_shape.second; +std::vector MeshDevice::get_devices_on_column(int col_idx) const { + return this->primary_view->get_devices_on_column(col_idx); } -MeshShape MeshDevice::shape() const -{ - return this->mesh_shape; +const DeviceIds MeshDevice::get_device_ids() const { + DeviceIds device_ids; + for (auto device : this->get_devices()) { + device_ids.push_back(device->id()); + } + return device_ids; } +int MeshDevice::num_devices() const { return num_rows() * num_cols(); } + +CoreCoord MeshDevice::compute_with_storage_grid_size() const { return get_device_index(0)->compute_with_storage_grid_size(); } + +CoreCoord MeshDevice::dram_grid_size() const { return get_device_index(0)->dram_grid_size(); } + +tt::ARCH MeshDevice::arch() const { return get_device_index(0)->arch(); } + +int MeshDevice::num_rows() const { return this->mesh_device_shape.first; } + +int MeshDevice::num_cols() const { return this->mesh_device_shape.second; } + +MeshShape MeshDevice::shape() const { return this->mesh_device_shape; } + void MeshDevice::close_devices() { - tt::tt_metal::detail::CloseDevices(managed_devices); - mesh_devices.clear(); - managed_devices.clear(); + SystemMesh::instance().unmap_mesh_device(shared_from_this()); + this->devices.clear(); + this->physical_id_to_device_index.clear(); + this->primary_view.reset(); } std::string MeshDevice::to_string() const { - return fmt::format("MeshDevice({}x{} grid, {} devices)", - this->num_rows(), - this->num_cols(), - this->num_devices()); + return fmt::format("MeshDevice({}x{} grid, {} devices)", this->num_rows(), this->num_cols(), this->num_devices()); } -std::shared_ptr MeshDevice::get_view() const { - return this->view; -} +std::shared_ptr MeshDevice::get_view() const { return this->primary_view; } -std::shared_ptr MeshDevice::get_view() { - return this->view; -} +std::shared_ptr MeshDevice::get_view() { return this->primary_view; } -std::ostream& operator<<(std::ostream& os, const MeshDevice& mesh_device) { - return os << mesh_device.to_string(); -} +MeshDeviceID MeshDevice::get_mesh_id() const { return this->mesh_id; } + +std::ostream& operator<<(std::ostream& os, const MeshDevice& mesh_device) { return os << mesh_device.to_string(); } bool validate_worker_modes(const std::vector& workers) { bool worker_modes_match = true; @@ -192,4 +308,13 @@ bool validate_worker_modes(const std::vector& workers) { return worker_modes_match; } -} // namespace tt::tt_metal +std::vector get_t3k_physical_device_ids_ring() { + auto& instance = SystemMesh::instance(); + auto num_devices = instance.get_num_devices(); + TT_FATAL(num_devices == 8, "T3000 ring topology only works with 8 devices"); + + auto physical_device_ids = instance.get_mapped_physical_device_ids(MeshDeviceConfig{instance.get_shape(), MeshOffset{0, 0}}); + return physical_device_ids; +} + +} // namespace tt::tt_metal diff --git a/tt_metal/impl/device/mesh_device.hpp b/tt_metal/impl/device/mesh_device.hpp index 6ba0c336aaf..940110973cc 100644 --- a/tt_metal/impl/device/mesh_device.hpp +++ b/tt_metal/impl/device/mesh_device.hpp @@ -4,28 +4,106 @@ #pragma once -#include -#include #include +#include #include +#include +#include "mesh_device_view.hpp" #include "tt_metal/impl/device/device.hpp" #include "tt_metal/impl/device/mesh_device_view.hpp" namespace tt::tt_metal { using DeviceIds = std::vector; +using MeshDeviceID = std::size_t; +using MeshOffset = std::pair; class MeshDeviceView; -class MeshDevice -{ -public: +struct MeshDeviceConfig { MeshShape mesh_shape; - std::map managed_devices; - std::vector> mesh_devices; - std::shared_ptr view; + MeshOffset offset; +}; - MeshDevice(const MeshShape &mesh_shape, const DeviceIds &device_ids, size_t l1_small_size, size_t trace_region_size, size_t num_command_queues, DispatchCoreType dispatch_core_type); +// SystemMesh creates a virtualization over the physical devices in the system. +// It creates a logical 2D-mesh of devices and manages the mapping between logical and physical device coordinates. +// It is responsible for the assignment of devices in a MeshDevice to physical devices, and the creation and deletion of +// device resources. +class SystemMesh { + private: + using LogicalCoordinate = Coordinate; + using PhysicalCoordinate = eth_coord_t; + + // Keep track of the devices that were opened so we can close them later. We shouldn't + // to keep track of this but DevicePool seems to open all devices associated with an MMIO device id + std::unordered_map> opened_devices; + std::unordered_map> assigned_devices; + std::unordered_map> assigned_mesh_device_devices; + std::unordered_map assigned_physical_id_to_device; + + // Logical mesh shape and coordinates + MeshShape logical_mesh_shape; + std::map logical_to_physical_coordinates; + + // Handling of physical coordinates + std::unordered_map physical_coordinate_to_device_id; + std::unordered_map physical_device_id_to_coordinate; + + SystemMesh() = default; + SystemMesh(const SystemMesh &) = delete; + SystemMesh &operator=(const SystemMesh &) = delete; + SystemMesh(SystemMesh &&) = delete; + SystemMesh &operator=(SystemMesh &&) = delete; + + static MeshShape get_system_mesh_shape(std::size_t system_num_devices); + static std::map get_system_mesh_translation_map( + std::size_t system_num_devices); + + bool is_system_mesh_initialized() const; + + public: + static SystemMesh &instance(); + + void initialize(); + + // Return the shape of the logical mesh + const MeshShape &get_shape() const; + std::size_t get_num_devices() const; + + // Get the physical device IDs mapped to a MeshDevice + std::vector get_mapped_physical_device_ids(const MeshDeviceConfig &config) const; + + // Map MeshDevice to physical devices + std::vector map_mesh_device( + std::shared_ptr mesh_device, + size_t num_command_queues, + size_t l1_small_size, + size_t trace_region_size, + DispatchCoreType dispatch_core_type, + const std::pair &offset = {0, 0}, + const std::vector &physical_device_ids = {}); + + // Unmap MeshDevice, releasing the associated physical devices. + void unmap_mesh_device(const std::shared_ptr &mesh_device); +}; + +class MeshDevice : public std::enable_shared_from_this { + MeshDeviceID mesh_id; + MeshShape mesh_device_shape; + std::shared_ptr primary_view; + std::vector devices; + std::unordered_map physical_id_to_device_index; + + void initialize( + size_t l1_small_size, + size_t trace_region_size, + size_t num_command_queues, + DispatchCoreType dispatch_core_type, + const std::pair &offset, + const std::vector &physical_device_ids); + + public: + MeshDevice(const MeshShape &mesh_device_shape); ~MeshDevice(); MeshDevice(const MeshDevice &) = delete; @@ -34,8 +112,9 @@ class MeshDevice MeshDevice(MeshDevice &&) = delete; MeshDevice &operator=(MeshDevice &&) = delete; - std::vector get_devices() const; - Device *get_device(int logical_device_id) const; + std::vector get_devices() const; + Device *get_device_index(int logical_device_id) const; + Device *get_device(int physical_device_id) const; Device *get_device(int row_idx, int col_idx) const; std::vector get_devices_on_row(int row_idx) const; std::vector get_devices_on_column(int col_idx) const; @@ -58,12 +137,20 @@ class MeshDevice std::shared_ptr get_view(); std::string to_string() const; - - private: - bool is_galaxy_; + MeshDeviceID get_mesh_id() const; + + static std::shared_ptr create( + const MeshShape &mesh_device_shape, + size_t l1_small_size, + size_t trace_region_size, + size_t num_command_queues, + DispatchCoreType dispatch_core_type, + const std::pair &offset = {0, 0}, + const std::vector &physical_device_ids = {}); }; -std::ostream& operator<<(std::ostream& os, const MeshDevice& mesh_device); -bool validate_worker_modes(const std::vector& workers); +std::ostream &operator<<(std::ostream &os, const MeshDevice &mesh_device); +bool validate_worker_modes(const std::vector &workers); +std::vector get_t3k_physical_device_ids_ring(); -} // namespace tt::tt_metal +} // namespace tt::tt_metal diff --git a/tt_metal/llrt/tt_cluster.cpp b/tt_metal/llrt/tt_cluster.cpp index 43d47b182ae..1483bd38bbd 100644 --- a/tt_metal/llrt/tt_cluster.cpp +++ b/tt_metal/llrt/tt_cluster.cpp @@ -329,7 +329,7 @@ std::unordered_map Cluster::get_user_chip_ethernet_coord auto user_chip_ethernet_coordinates = this->cluster_desc_->get_chip_locations(); if (this->is_galaxy_cluster()) { std::erase_if(user_chip_ethernet_coordinates, [this](const auto& entry) { - return this->cluster_desc_->get_board_type(entry.first) != BoardType::GALAXY; // need to fix this + return this->cluster_desc_->get_board_type(entry.first) != BoardType::GALAXY; }); } return user_chip_ethernet_coordinates; diff --git a/ttnn/cpp/pybind11/multi_device.hpp b/ttnn/cpp/pybind11/multi_device.hpp index 7b9fbff52a8..70d9755d040 100644 --- a/ttnn/cpp/pybind11/multi_device.hpp +++ b/ttnn/cpp/pybind11/multi_device.hpp @@ -16,20 +16,36 @@ namespace ttnn { namespace multi_device { -void py_module_types(py::module& module) { py::class_(module, "MeshDevice"); } +void py_module_types(py::module& module) { py::class_>(module, "MeshDevice"); } void py_module(py::module& module) { - auto py_mesh_device = static_cast>(module.attr("MeshDevice")); + auto py_mesh_device = static_cast>>(module.attr("MeshDevice")); py_mesh_device .def( - py::init, size_t, size_t, size_t, DispatchCoreType>(), + py::init([](const MeshShape& mesh_device_shape, + size_t l1_small_size, + size_t trace_region_size, + size_t num_command_queues, + DispatchCoreType dispatch_core_type, + const std::pair& offset, + const std::vector& physical_device_ids) { + return MeshDevice::create( + mesh_device_shape, + l1_small_size, + trace_region_size, + num_command_queues, + dispatch_core_type, + offset, + physical_device_ids); + }), py::kw_only(), py::arg("mesh_shape"), - py::arg("device_ids"), py::arg("l1_small_size"), py::arg("trace_region_size"), py::arg("num_command_queues"), - py::arg("dispatch_core_type")) + py::arg("dispatch_core_type"), + py::arg("offset"), + py::arg("physical_device_ids")) .def("get_num_devices", &MeshDevice::num_devices) .def("get_device_ids", &MeshDevice::get_device_ids) .def( @@ -106,11 +122,11 @@ void py_module(py::module& module) { &open_mesh_device, py::kw_only(), py::arg("mesh_shape"), - py::arg("device_ids"), py::arg("l1_small_size"), py::arg("trace_region_size"), py::arg("num_command_queues"), - py::arg("dispatch_core_type")); + py::arg("dispatch_core_type"), + py::arg("physical_device_ids")); module.def("close_mesh_device", &close_mesh_device, py::arg("mesh_device"), py::kw_only()); module.def( @@ -147,6 +163,7 @@ void py_module(py::module& module) { )doc"); module.def("get_device_tensors", &get_device_tensors, py::arg("tensor"), py::kw_only()); module.def("aggregate_as_tensor", &aggregate_as_tensor, py::arg("tensors"), py::kw_only()); + module.def("get_t3k_physical_device_ids_ring", &tt::tt_metal::get_t3k_physical_device_ids_ring); } } // namespace multi_device diff --git a/ttnn/cpp/ttnn/events.cpp b/ttnn/cpp/ttnn/events.cpp index f37efb94aa7..5f207b1cfdf 100644 --- a/ttnn/cpp/ttnn/events.cpp +++ b/ttnn/cpp/ttnn/events.cpp @@ -11,11 +11,11 @@ namespace ttnn::events { MultiDeviceEvent::MultiDeviceEvent(MeshDevice* mesh_device) { TT_ASSERT(mesh_device != nullptr, "Must provide a valid mesh_device when initializing an event on multiple devices."); - auto& devices = mesh_device->mesh_devices; + auto devices = mesh_device->get_devices(); this->events = std::vector>(devices.size()); for (int event_idx = 0; event_idx < devices.size(); event_idx++) { this->events[event_idx] = std::make_shared(); - this->events[event_idx]->device = devices[event_idx].second; + this->events[event_idx]->device = devices[event_idx]; } } diff --git a/ttnn/cpp/ttnn/multi_device.cpp b/ttnn/cpp/ttnn/multi_device.cpp index fe77a9f5fe4..7fa5f9e0d65 100644 --- a/ttnn/cpp/ttnn/multi_device.cpp +++ b/ttnn/cpp/ttnn/multi_device.cpp @@ -12,12 +12,12 @@ namespace ttnn::multi_device { -MeshDevice open_mesh_device(const MeshShape& mesh_shape, const DeviceIds& device_ids, size_t l1_small_size, size_t trace_region_size, size_t num_command_queues, DispatchCoreType dispatch_core_type) { - return MeshDevice(mesh_shape, device_ids, l1_small_size, trace_region_size, num_command_queues, dispatch_core_type); +std::shared_ptr open_mesh_device(const MeshShape& mesh_shape, size_t l1_small_size, size_t trace_region_size, size_t num_command_queues, DispatchCoreType dispatch_core_type, const std::pair& offset) { + return MeshDevice::create(mesh_shape, l1_small_size, trace_region_size, num_command_queues, dispatch_core_type, offset); } -void close_mesh_device(MeshDevice &multi_device) { - multi_device.close_devices(); +void close_mesh_device(const std::shared_ptr& mesh_device) { + mesh_device->close_devices(); } std::vector get_device_tensors(const ttnn::Tensor& tensor) { diff --git a/ttnn/cpp/ttnn/multi_device.hpp b/ttnn/cpp/ttnn/multi_device.hpp index b3821e2b5cb..d7db05721bc 100644 --- a/ttnn/cpp/ttnn/multi_device.hpp +++ b/ttnn/cpp/ttnn/multi_device.hpp @@ -15,13 +15,15 @@ using Device = ttnn::Device; namespace ttnn { namespace multi_device { -MeshDevice open_mesh_device(const MeshShape& mesh_shape, const DeviceIds& device_ids, size_t l1_small_size, size_t trace_region_size, size_t num_command_queues, DispatchCoreType dispatch_core_type); -void close_mesh_device(MeshDevice &multi_device); +std::shared_ptr open_mesh_device(const MeshShape& mesh_shape, size_t l1_small_size, size_t trace_region_size, size_t num_command_queues, DispatchCoreType dispatch_core_type, const std::pair& offset = {0, 0}); +void close_mesh_device(const std::shared_ptr& mesh_device); std::vector get_device_tensors(const ttnn::Tensor& tensor); Tensor aggregate_as_tensor(std::vector& tensor_shards); +std::vector get_t3k_physical_device_ids_ring(); + } // namespace multi_device using namespace multi_device; diff --git a/ttnn/ttnn/__init__.py b/ttnn/ttnn/__init__.py index e5423eb21ee..0192b2b1673 100644 --- a/ttnn/ttnn/__init__.py +++ b/ttnn/ttnn/__init__.py @@ -93,7 +93,12 @@ def manage_config(name, value): logger.debug(f"Restored ttnn.CONFIG.{name} to {original_value}") -from ttnn._ttnn.multi_device import get_device_tensor, get_device_tensors, aggregate_as_tensor +from ttnn._ttnn.multi_device import ( + get_device_tensor, + get_device_tensors, + aggregate_as_tensor, + get_t3k_physical_device_ids_ring, +) from ttnn._ttnn.events import create_event, record_event, wait_for_event diff --git a/ttnn/ttnn/multi_device.py b/ttnn/ttnn/multi_device.py index 263666f6398..2f49af86422 100644 --- a/ttnn/ttnn/multi_device.py +++ b/ttnn/ttnn/multi_device.py @@ -5,7 +5,7 @@ import contextlib import functools -from typing import List, Dict, Optional, Callable, Tuple, Optional, Callable, Union +from typing import List, Dict, Optional, Callable, Tuple, Optional, Callable, Union, List import ttnn @@ -134,26 +134,36 @@ def get_device_ids() -> List[int]: def open_mesh_device( mesh_shape: ttnn.MeshShape, - device_ids: List[int], l1_small_size: int = ttnn._ttnn.device.DEFAULT_L1_SMALL_SIZE, trace_region_size: int = ttnn._ttnn.device.DEFAULT_TRACE_REGION_SIZE, num_command_queues: int = 1, dispatch_core_type: int = DispatchCoreType.WORKER, + offset: Tuple[int, int] = (0, 0), + physical_device_ids: List[int] = [], ): """ - open_mesh_device(mesh_shape: ttnn.MeshShape, device_ids: int) -> ttnn.MeshDevice: + Open a mesh device with the specified configuration. - Open a device with the given device_id. If the device is already open, return the existing device. - """ - assert len(device_ids) > 0 + Args: + mesh_shape (ttnn.MeshShape): The shape of the mesh device. + l1_small_size (int, optional): Size of the L1 small memory. Defaults to ttnn._ttnn.device.DEFAULT_L1_SMALL_SIZE. + trace_region_size (int, optional): Size of the trace region. Defaults to ttnn._ttnn.device.DEFAULT_TRACE_REGION_SIZE. + num_command_queues (int, optional): Number of command queues. Defaults to 1. + dispatch_core_type (int, optional): Type of dispatch core. Defaults to DispatchCoreType.WORKER. + offset (Tuple[int, int], optional): Offset in logical mesh coordinates for the mesh device. Defaults to (0, 0). + Returns: + ttnn._ttnn.multi_device.MeshDevice: The opened mesh device. + + """ return ttnn._ttnn.multi_device.MeshDevice( mesh_shape=mesh_shape.as_tuple(), - device_ids=device_ids, l1_small_size=l1_small_size, trace_region_size=trace_region_size, num_command_queues=num_command_queues, dispatch_core_type=dispatch_core_type, + offset=offset, + physical_device_ids=physical_device_ids, )