From 742a37d9dc3a4b2b550312c7c3a94f6ca9b86263 Mon Sep 17 00:00:00 2001 From: Austin Ho Date: Mon, 23 Sep 2024 23:35:43 +0000 Subject: [PATCH] Revert "Mesh Virtualization (#12719)" This reverts commit 7738225bedacbfbd5dedb35e1c37a54abc75c6c2. --- conftest.py | 16 +- .../llama2_70b/tests/test_llama_generation.py | 2 +- .../t3000/llama2_70b/tests/test_llama_perf.py | 2 +- .../tests/test_llama_stress_test.py | 2 +- .../demos/t3000/llama2_70b/tt/llama_common.py | 2 +- .../llama3_70b/tt/llama_attention_galaxy.py | 2 +- .../tg/llama3_70b/tt/llama_mlp_galaxy.py | 4 +- models/utility_functions.py | 7 +- tests/scripts/tgg/run_tgg_unit_tests.sh | 1 - tests/sweep_framework/README.md | 4 +- .../sweep_framework/sweeps/line_all_gather.py | 3 +- .../test_mesh_device_TGG.py | 12 - .../test_multidevice_TG.py | 10 +- .../ttnn/unit_tests/gtests/test_ccl_on_tg.cpp | 32 +- .../unit_tests/gtests/ttnn_test_fixtures.hpp | 13 +- tests/ttnn/unit_tests/test_multi_device.py | 9 +- .../impl/device/mesh_configurations/N300.json | 5 - .../device/mesh_configurations/T3000.json | 6 - .../impl/device/mesh_configurations/TG.json | 36 -- .../impl/device/mesh_configurations/TGG.json | 68 ---- .../device/mesh_configurations/device.json | 5 - tt_metal/impl/device/mesh_device.cpp | 375 ++++++------------ tt_metal/impl/device/mesh_device.hpp | 121 +----- tt_metal/llrt/tt_cluster.cpp | 10 - tt_metal/llrt/tt_cluster.hpp | 2 - ttnn/cpp/pybind11/multi_device.hpp | 31 +- ttnn/cpp/ttnn/events.cpp | 4 +- ttnn/cpp/ttnn/multi_device.cpp | 8 +- ttnn/cpp/ttnn/multi_device.hpp | 6 +- ttnn/ttnn/__init__.py | 7 +- ttnn/ttnn/multi_device.py | 24 +- 31 files changed, 234 insertions(+), 595 deletions(-) delete mode 100644 tests/ttnn/multichip_unit_tests/test_mesh_device_TGG.py delete mode 100644 tt_metal/impl/device/mesh_configurations/N300.json delete mode 100644 tt_metal/impl/device/mesh_configurations/T3000.json delete mode 100644 tt_metal/impl/device/mesh_configurations/TG.json delete mode 100644 tt_metal/impl/device/mesh_configurations/TGG.json delete mode 100644 tt_metal/impl/device/mesh_configurations/device.json diff --git a/conftest.py b/conftest.py index 36a8d35b6b5..75036c294d2 100644 --- a/conftest.py +++ b/conftest.py @@ -207,7 +207,9 @@ def mesh_device(request, silicon_arch_name, silicon_arch_wormhole_b0, device_par request.node.pci_ids = [ttnn.GetPCIeDeviceID(i) for i in device_ids[:num_devices_requested]] - mesh_device = ttnn.open_mesh_device(mesh_shape, dispatch_core_type=get_dispatch_core_type(), **device_params) + mesh_device = ttnn.open_mesh_device( + mesh_shape, device_ids[:num_devices_requested], dispatch_core_type=get_dispatch_core_type(), **device_params + ) logger.debug(f"multidevice with {mesh_device.get_num_devices()} devices is created") yield mesh_device @@ -233,9 +235,9 @@ def pcie_mesh_device(request, silicon_arch_name, silicon_arch_wormhole_b0, devic mesh_device = ttnn.open_mesh_device( ttnn.MeshShape(1, num_pcie_devices_requested), + device_ids[:num_pcie_devices_requested], dispatch_core_type=get_dispatch_core_type(), **device_params, - physical_device_ids=device_ids[:num_pcie_devices_requested], ) logger.debug(f"multidevice with {mesh_device.get_num_devices()} devices is created") @@ -254,9 +256,17 @@ def t3k_mesh_device(request, silicon_arch_name, silicon_arch_wormhole_b0, device if ttnn.get_num_devices() < 8: pytest.skip() + device_ids = [0, 4, 5, 1, 2, 6, 7, 3] + try: + num_devices_requested = min(request.param, len(device_ids)) + except (ValueError, AttributeError): + num_devices_requested = len(device_ids) + + request.node.pci_ids = [ttnn.GetPCIeDeviceID(i) for i in device_ids[:num_devices_requested]] mesh_device = ttnn.open_mesh_device( - ttnn.MeshShape(2, 4), + ttnn.MeshShape(1, num_devices_requested), + device_ids[:num_devices_requested], dispatch_core_type=get_dispatch_core_type(), **device_params, ) diff --git a/models/demos/t3000/llama2_70b/tests/test_llama_generation.py b/models/demos/t3000/llama2_70b/tests/test_llama_generation.py index 1d92fe1916f..efcc5f27b06 100644 --- a/models/demos/t3000/llama2_70b/tests/test_llama_generation.py +++ b/models/demos/t3000/llama2_70b/tests/test_llama_generation.py @@ -147,7 +147,7 @@ def test_LlamaModel_inference( if t3k_mesh_device.get_num_devices() < n_devices: pytest.skip(f"Requires at {n_devices} devices to run") - compute_grid_size = t3k_mesh_device.compute_with_storage_grid_size() + compute_grid_size = t3k_mesh_device.get_device(0).compute_with_storage_grid_size() if compute_grid_size.x < model_config["MAX_GRID_SIZE"][0] or compute_grid_size.y < model_config["MAX_GRID_SIZE"][1]: pytest.skip(f"Requires grid size of at least {model_config['MAX_GRID_SIZE']} to run") diff --git a/models/demos/t3000/llama2_70b/tests/test_llama_perf.py b/models/demos/t3000/llama2_70b/tests/test_llama_perf.py index b4e81524ec6..ebf9ad4edf7 100644 --- a/models/demos/t3000/llama2_70b/tests/test_llama_perf.py +++ b/models/demos/t3000/llama2_70b/tests/test_llama_perf.py @@ -310,7 +310,7 @@ def test_Llama_perf_host( if t3k_mesh_device.get_num_devices() < n_devices and not emulated: pytest.skip(f"Requires at {n_devices} devices to run") - compute_grid_size = t3k_mesh_device.compute_with_storage_grid_size() + compute_grid_size = t3k_mesh_device.get_device(0).compute_with_storage_grid_size() if compute_grid_size.x < model_config["MAX_GRID_SIZE"][0] or compute_grid_size.y < model_config["MAX_GRID_SIZE"][1]: pytest.skip(f"Requires grid size of at least {model_config['MAX_GRID_SIZE']} to run") diff --git a/models/demos/t3000/llama2_70b/tests/test_llama_stress_test.py b/models/demos/t3000/llama2_70b/tests/test_llama_stress_test.py index 621fcd2b3a7..33bf35a2685 100644 --- a/models/demos/t3000/llama2_70b/tests/test_llama_stress_test.py +++ b/models/demos/t3000/llama2_70b/tests/test_llama_stress_test.py @@ -143,7 +143,7 @@ def test_Llama_stress_test( if t3k_mesh_device.get_num_devices() < n_devices and not emulated: pytest.skip(f"Requires at {n_devices} devices to run") - compute_grid_size = t3k_mesh_device.compute_with_storage_grid_size() + compute_grid_size = t3k_mesh_device.get_device(0).compute_with_storage_grid_size() if compute_grid_size.x < model_config["MAX_GRID_SIZE"][0] or compute_grid_size.y < model_config["MAX_GRID_SIZE"][1]: pytest.skip(f"Requires grid size of at least {model_config['MAX_GRID_SIZE']} to run") diff --git a/models/demos/t3000/llama2_70b/tt/llama_common.py b/models/demos/t3000/llama2_70b/tt/llama_common.py index 760e0279656..c57f1a2a339 100644 --- a/models/demos/t3000/llama2_70b/tt/llama_common.py +++ b/models/demos/t3000/llama2_70b/tt/llama_common.py @@ -189,7 +189,7 @@ def check_mesh_device(t3k_mesh_device, model_config): model_config["NUM_DEVICES"], ) - compute_grid_size = t3k_mesh_device.compute_with_storage_grid_size() + compute_grid_size = t3k_mesh_device.get_device(0).compute_with_storage_grid_size() assert not ( compute_grid_size.x < model_config["MAX_GRID_SIZE"][0] or compute_grid_size.y < model_config["MAX_GRID_SIZE"][1] ), ("Requires grid size of at least %d to run", model_config["MAX_GRID_SIZE"]) diff --git a/models/demos/tg/llama3_70b/tt/llama_attention_galaxy.py b/models/demos/tg/llama3_70b/tt/llama_attention_galaxy.py index cb4da029cdb..7c051294668 100644 --- a/models/demos/tg/llama3_70b/tt/llama_attention_galaxy.py +++ b/models/demos/tg/llama3_70b/tt/llama_attention_galaxy.py @@ -472,7 +472,7 @@ def attn_mqa( value_layer.deallocate(True) program_config = ttnn.SDPAProgramConfig( - compute_with_storage_grid_size=self.mesh_device.compute_with_storage_grid_size(), + compute_with_storage_grid_size=self.mesh_device.get_device(0).compute_with_storage_grid_size(), q_chunk_size=0, # unused k_chunk_size=0, # unused ) diff --git a/models/demos/tg/llama3_70b/tt/llama_mlp_galaxy.py b/models/demos/tg/llama3_70b/tt/llama_mlp_galaxy.py index 4514be7af5f..ba4b465f189 100644 --- a/models/demos/tg/llama3_70b/tt/llama_mlp_galaxy.py +++ b/models/demos/tg/llama3_70b/tt/llama_mlp_galaxy.py @@ -54,8 +54,8 @@ def get_mlp_model_config(self, mode): ttnn.CoreRange( ttnn.CoreCoord(0, 0), ttnn.CoreCoord( - self.mesh_device.dram_grid_size().x - 1, - self.mesh_device.dram_grid_size().y - 1, + self.mesh_device.get_device(0).dram_grid_size().x - 1, + self.mesh_device.get_device(0).dram_grid_size().y - 1, ), ) } diff --git a/models/utility_functions.py b/models/utility_functions.py index 699237c1778..877249cd8c8 100644 --- a/models/utility_functions.py +++ b/models/utility_functions.py @@ -971,10 +971,9 @@ def get_devices_for_t3000(all_devices, num_devices): if num_devices <= 4: return all_devices[:num_devices] elif num_devices == 8: - # Temporary until we move request for ring order to CCL operations directly. - # This is better because we no longer need to manually manage the ring order. - ring_indices = ttnn.get_t3k_physical_device_ids_ring() - return [all_devices[i] for i in ring_indices] + # TODO: Generalize this for different arch + hamiltonian_ring_indices = [0, 4, 5, 1, 2, 6, 7, 3] + return [all_devices[i] for i in hamiltonian_ring_indices] else: raise NotImplementedError("Only supports 1, 2, 3, 4, and 8 chip configurations!") diff --git a/tests/scripts/tgg/run_tgg_unit_tests.sh b/tests/scripts/tgg/run_tgg_unit_tests.sh index aa5a2449042..fb58c9b9b21 100755 --- a/tests/scripts/tgg/run_tgg_unit_tests.sh +++ b/tests/scripts/tgg/run_tgg_unit_tests.sh @@ -7,7 +7,6 @@ run_tgg_tests() { TT_METAL_SLOW_DISPATCH_MODE=1 ./build/test/tt_metal/unit_tests_galaxy --gtest_filter="GalaxyFixture.*:TGGFixture.*" ./build/test/tt_metal/unit_tests_galaxy --gtest_filter="GalaxyFixture.*:TGGFixture.*" - pytest -s tests/ttnn/multichip_unit_tests/test_mesh_device_TGG.py } main() { diff --git a/tests/sweep_framework/README.md b/tests/sweep_framework/README.md index 0a7468523fa..85d779fea2a 100644 --- a/tests/sweep_framework/README.md +++ b/tests/sweep_framework/README.md @@ -190,11 +190,11 @@ def mesh_device_fixture(): assert ttnn.get_num_devices() >= 8, "Not T3000!" - device_ids = ttnn.get_t3k_physical_device_ids_ring() + device_ids = [0, 4, 5, 1, 2, 6, 7, 3] num_devices_requested = len(device_ids) mesh_device = ttnn.open_mesh_device( - ttnn.MeshShape(1, num_devices_requested), + ttnn.MeshShape(1, num_devices_requested), device_ids[:num_devices_requested] ) print("ADD: Opened device mesh") diff --git a/tests/sweep_framework/sweeps/line_all_gather.py b/tests/sweep_framework/sweeps/line_all_gather.py index 8172c701c62..f5d2ebeb25c 100644 --- a/tests/sweep_framework/sweeps/line_all_gather.py +++ b/tests/sweep_framework/sweeps/line_all_gather.py @@ -62,8 +62,9 @@ def invalidate_vector(test_vector) -> Tuple[bool, Optional[str]]: def mesh_device_fixture(): + assert ttnn.get_num_devices() >= 8, "Not T3000!" - device_ids = ttnn.get_t3k_physical_device_ids_ring() + device_ids = [0, 4, 5, 1, 2, 6, 7, 3] num_devices_requested = len(device_ids) mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(1, num_devices_requested), device_ids[:num_devices_requested]) print("ALL GATHER: Opened device mesh") diff --git a/tests/ttnn/multichip_unit_tests/test_mesh_device_TGG.py b/tests/ttnn/multichip_unit_tests/test_mesh_device_TGG.py deleted file mode 100644 index d5f4b8ed905..00000000000 --- a/tests/ttnn/multichip_unit_tests/test_mesh_device_TGG.py +++ /dev/null @@ -1,12 +0,0 @@ -# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. - -# SPDX-License-Identifier: Apache-2.0 - -import torch -import pytest -import ttnn - - -@pytest.mark.parametrize("mesh_device", [pytest.param((8, 8), id="8x8_grid")], indirect=True) -def test_visualize_mesh_device(mesh_device): - ttnn.visualize_mesh_device(mesh_device) diff --git a/tests/ttnn/multichip_unit_tests/test_multidevice_TG.py b/tests/ttnn/multichip_unit_tests/test_multidevice_TG.py index a4bee53b464..6112ff3d5a0 100644 --- a/tests/ttnn/multichip_unit_tests/test_multidevice_TG.py +++ b/tests/ttnn/multichip_unit_tests/test_multidevice_TG.py @@ -224,7 +224,9 @@ def test_galaxy_matmul_2d_fracture_dram_sharded(M, K, N, weights_dtype, mesh_sha { ttnn.CoreRange( ttnn.CoreCoord(0, 0), - ttnn.CoreCoord(mesh_device.dram_grid_size().x - 1, mesh_device.dram_grid_size().y - 1), + ttnn.CoreCoord( + mesh_device.get_device(0).dram_grid_size().x - 1, mesh_device.get_device(0).dram_grid_size().y - 1 + ), ) } ) @@ -735,7 +737,7 @@ def test_fill_cache( xt = x - compute_grid_size = mesh_device.compute_with_storage_grid_size() + compute_grid_size = mesh_device.get_device(0).compute_with_storage_grid_size() num_cores = min(seq_len // 32 * num_heads, 32) # Always use max 32 cores for testing mesh_shape = ttnn.CoreRangeSet(ttnn.num_cores_to_corerange_set(num_cores, compute_grid_size, True)) input_shard_spec = ttnn.ShardSpec( @@ -809,7 +811,7 @@ def test_update_cache_decode( x_new = torch.cat((x_new, torch.zeros(32 - num_users - batch_offset, num_heads, 1, head_dim)), dim=0) assert x_new.shape[0] == 32, f"Expected x.shape[0] to be 32, got {x_new.shape[0]}" xt = x_new.permute(2, 1, 0, 3) - compute_grid_size = mesh_device.compute_with_storage_grid_size() + compute_grid_size = mesh_device.get_device(0).compute_with_storage_grid_size() num_cores = min(max(num_users, 32) // 32 * num_heads, compute_grid_size.x * compute_grid_size.y) mesh_shape = ttnn.CoreRangeSet(ttnn.num_cores_to_corerange_set(num_cores, compute_grid_size, True)) input_shard_spec = ttnn.ShardSpec( @@ -891,7 +893,7 @@ def run_test_sdpa_decode_single_iter( sharded_in=False, sharded_out=False, ): - compute_grid_size = mesh_device.compute_with_storage_grid_size() + compute_grid_size = mesh_device.get_device(0).compute_with_storage_grid_size() if grid_size[0] > compute_grid_size.x or grid_size[1] > compute_grid_size.y: pytest.skip(f"Need {grid_size} grid size to run this test but core grid is {compute_grid_size}") diff --git a/tests/ttnn/unit_tests/gtests/test_ccl_on_tg.cpp b/tests/ttnn/unit_tests/gtests/test_ccl_on_tg.cpp index 51d1b6c8d9c..2c6059d78a8 100644 --- a/tests/ttnn/unit_tests/gtests/test_ccl_on_tg.cpp +++ b/tests/ttnn/unit_tests/gtests/test_ccl_on_tg.cpp @@ -55,8 +55,17 @@ TEST(TGTests, TestAllGatherDeadlock) { TT_FATAL(num_devices_in_tunnel == 4, "Expected Galaxy to have tunnel depth of 4"); TT_FATAL(num_mmio_devices * cluster_tunnel_count == 8, "Expected 8 tunnels in a Galaxy"); + std::vector all_device_ids = {}; + for (uint32_t mmio_idx = 0; mmio_idx < num_mmio_devices; mmio_idx++) { + auto tunnels_from_mmio = tt::Cluster::instance().get_tunnels_from_mmio_device(mmio_idx); + for (uint32_t tunnel_idx = 0; tunnel_idx < tunnels_from_mmio.size(); tunnel_idx++) { + auto remote_devices_in_tunnel = tunnels_from_mmio.at(tunnel_idx); + all_device_ids.insert(all_device_ids.end(), remote_devices_in_tunnel.begin(), remote_devices_in_tunnel.end()); + } + } + // Create the device mesh: Grid size is . - auto mesh = ttnn::multi_device::open_mesh_device({cluster_tunnel_count * num_mmio_devices, num_devices_in_tunnel}, 0, 0, 1, DispatchCoreType::WORKER); + auto mesh = ttnn::multi_device::open_mesh_device({cluster_tunnel_count * num_mmio_devices, num_devices_in_tunnel}, all_device_ids, 0, 0, 1, DispatchCoreType::WORKER); // Setup input data and output data containers MemoryConfig mem_cfg = MemoryConfig{ @@ -78,7 +87,7 @@ TEST(TGTests, TestAllGatherDeadlock) { // Iterate over each tunnel and run line all-gather multiple times. // For each tunnel, send adversarial traffic to the first chip, that can hang the network if the CCL is not tagged. for (uint32_t row = 0; row < 8; row++) { - auto devs = mesh->get_devices_on_row(row); + auto devs = mesh.get_devices_on_row(row); std::vector device_ids = {}; for (auto dev : devs) { device_ids.push_back(dev->id()); @@ -137,17 +146,26 @@ TEST(TGTests, TestReduceScatterDeadlock) { TT_FATAL(num_devices_in_tunnel == 4, "Expected Galaxy to have tunnel depth of 4"); TT_FATAL(num_mmio_devices * cluster_tunnel_count == 8, "Expected 8 tunnels in a Galaxy"); + std::vector all_device_ids = {}; + for (uint32_t mmio_idx = 0; mmio_idx < num_mmio_devices; mmio_idx++) { + auto tunnels_from_mmio = tt::Cluster::instance().get_tunnels_from_mmio_device(mmio_idx); + for (uint32_t tunnel_idx = 0; tunnel_idx < tunnels_from_mmio.size(); tunnel_idx++) { + auto remote_devices_in_tunnel = tunnels_from_mmio.at(tunnel_idx); + all_device_ids.insert(all_device_ids.end(), remote_devices_in_tunnel.begin(), remote_devices_in_tunnel.end()); + } + } + // Create the device mesh: Grid size is . - auto mesh = ttnn::multi_device::open_mesh_device({cluster_tunnel_count * num_mmio_devices, num_devices_in_tunnel}, 0, 0, 1, DispatchCoreType::WORKER); + auto mesh = ttnn::multi_device::open_mesh_device({cluster_tunnel_count * num_mmio_devices, num_devices_in_tunnel}, all_device_ids, 0, 0, 1, DispatchCoreType::WORKER); // Create the outer ring on which Reduce Scatter will be run. This allows us to verify that there are no deadlocks when we send CCLs to the // first tunnel (forward path). - std::vector ring_devices = mesh->get_devices_on_row(0); // Tunnel 0 - std::vector ring_devices_1 = mesh->get_devices_on_column(3); // Orthogonal to tunnel .. no deadlocks + std::vector ring_devices = mesh.get_devices_on_row(0); // Tunnel 0 + std::vector ring_devices_1 = mesh.get_devices_on_column(3); // Orthogonal to tunnel .. no deadlocks ring_devices_1 = std::vector(ring_devices_1.begin() + 1, ring_devices_1.end()); - std::vector ring_devices_2 = mesh->get_devices_on_row(7); // Tunnel 7 .. potential deadlocks with lack of buffering + std::vector ring_devices_2 = mesh.get_devices_on_row(7); // Tunnel 7 .. potential deadlocks with lack of buffering std::reverse(ring_devices_2.begin(), ring_devices_2.end()); ring_devices_2 = std::vector(ring_devices_2.begin() + 1, ring_devices_2.end()); - std::vector ring_devices_3 = mesh->get_devices_on_column(0); // Orthogonal to tunnel .. no deadlocks + std::vector ring_devices_3 = mesh.get_devices_on_column(0); // Orthogonal to tunnel .. no deadlocks std::reverse(ring_devices_3.begin(), ring_devices_3.end()); ring_devices_3 = std::vector(ring_devices_3.begin() + 1, ring_devices_3.end() - 1); diff --git a/tests/ttnn/unit_tests/gtests/ttnn_test_fixtures.hpp b/tests/ttnn/unit_tests/gtests/ttnn_test_fixtures.hpp index 9e779b7f0cc..a645c7876db 100644 --- a/tests/ttnn/unit_tests/gtests/ttnn_test_fixtures.hpp +++ b/tests/ttnn/unit_tests/gtests/ttnn_test_fixtures.hpp @@ -67,20 +67,19 @@ class T3kMultiDeviceFixture : public ::testing::Test { if (num_devices < 8 or arch != tt::ARCH::WORMHOLE_B0) { GTEST_SKIP() << "Skipping T3K Multi-Device test suite on non T3K machine."; } + const auto T3K_DEVICE_IDS = DeviceIds{0, 4, 5, 1, 2, 6, 7, 3}; constexpr auto DEFAULT_NUM_COMMAND_QUEUES = 1; - mesh_device_ = MeshDevice::create( - MeshShape{2, 4}, + mesh_device_ = std::make_unique( + MeshShape{1, num_devices}, + T3K_DEVICE_IDS, DEFAULT_L1_SMALL_SIZE, DEFAULT_TRACE_REGION_SIZE, DEFAULT_NUM_COMMAND_QUEUES, DispatchCoreType::WORKER); } - void TearDown() override { - mesh_device_->close_devices(); - mesh_device_.reset(); - } - std::shared_ptr mesh_device_; + void TearDown() override { mesh_device_.reset(); } + std::unique_ptr mesh_device_; }; } // namespace ttnn::multi_device::test diff --git a/tests/ttnn/unit_tests/test_multi_device.py b/tests/ttnn/unit_tests/test_multi_device.py index f1c2728857f..e286993f855 100644 --- a/tests/ttnn/unit_tests/test_multi_device.py +++ b/tests/ttnn/unit_tests/test_multi_device.py @@ -23,7 +23,8 @@ def test_mesh_device_open_close_explicit(silicon_arch_name, silicon_arch_wormhol if num_pcie_devices <= 1: pytest.skip("Requires multiple devices to run") - multi_device = ttnn.open_mesh_device(ttnn.MeshShape(2, 2)) + mesh_shape, device_ids = ttnn.MeshShape(2, 2), ttnn.get_pcie_device_ids() + multi_device = ttnn.open_mesh_device(mesh_shape, device_ids) ttnn.close_mesh_device(multi_device) @@ -33,12 +34,12 @@ def test_multi_device_subset_mesh(silicon_arch_name, silicon_arch_wormhole_b0): if num_pcie_devices <= 1: pytest.skip("Requires multiple devices to run") - mesh_shape = ttnn.MeshShape(1, 2) - multi_device = ttnn.open_mesh_device(mesh_shape) + mesh_shape, device_ids = ttnn.MeshShape(1, 2), ttnn.get_pcie_device_ids() + multi_device = ttnn.open_mesh_device(mesh_shape, device_ids) assert multi_device.get_num_devices() == 2 ttnn.close_mesh_device(multi_device) - multi_device = ttnn.open_mesh_device(mesh_shape) + multi_device = ttnn.open_mesh_device(mesh_shape, device_ids) assert multi_device.get_num_devices() == 2 ttnn.close_mesh_device(multi_device) diff --git a/tt_metal/impl/device/mesh_configurations/N300.json b/tt_metal/impl/device/mesh_configurations/N300.json deleted file mode 100644 index 9891c42cc11..00000000000 --- a/tt_metal/impl/device/mesh_configurations/N300.json +++ /dev/null @@ -1,5 +0,0 @@ -{ - "logical_to_physical_coordinates": [ - [[0, 0], [0, 0, 0, 0]], [[0, 1], [0, 1, 0, 0]] - ] -} diff --git a/tt_metal/impl/device/mesh_configurations/T3000.json b/tt_metal/impl/device/mesh_configurations/T3000.json deleted file mode 100644 index 2c62209d01f..00000000000 --- a/tt_metal/impl/device/mesh_configurations/T3000.json +++ /dev/null @@ -1,6 +0,0 @@ -{ - "logical_to_physical_coordinates": [ - [[0, 0], [0, 0, 0, 0]], [[0, 1], [0, 1, 0, 0]], [[0, 2], [0, 2, 0, 0]], [[0, 3], [0, 3, 0, 0]], - [[1, 0], [1, 3, 0, 0]], [[1, 1], [1, 2, 0, 0]], [[1, 2], [1, 1, 0, 0]], [[1, 3], [1, 0, 0, 0]] - ] -} diff --git a/tt_metal/impl/device/mesh_configurations/TG.json b/tt_metal/impl/device/mesh_configurations/TG.json deleted file mode 100644 index 63860eff6b6..00000000000 --- a/tt_metal/impl/device/mesh_configurations/TG.json +++ /dev/null @@ -1,36 +0,0 @@ -{ - "logical_to_physical_coordinates": [ - [[0, 0], [7, 3, 0, 1]], - [[0, 1], [7, 2, 0, 1]], - [[0, 2], [7, 1, 0, 1]], - [[0, 3], [7, 0, 0, 1]], - [[1, 0], [6, 3, 0, 1]], - [[1, 1], [6, 2, 0, 1]], - [[1, 2], [6, 1, 0, 1]], - [[1, 3], [6, 0, 0, 1]], - [[2, 0], [5, 3, 0, 1]], - [[2, 1], [5, 2, 0, 1]], - [[2, 2], [5, 1, 0, 1]], - [[2, 3], [5, 0, 0, 1]], - [[3, 0], [4, 3, 0, 1]], - [[3, 1], [4, 2, 0, 1]], - [[3, 2], [4, 1, 0, 1]], - [[3, 3], [4, 0, 0, 1]], - [[4, 0], [3, 3, 0, 1]], - [[4, 1], [3, 2, 0, 1]], - [[4, 2], [3, 1, 0, 1]], - [[4, 3], [3, 0, 0, 1]], - [[5, 0], [2, 3, 0, 1]], - [[5, 1], [2, 2, 0, 1]], - [[5, 2], [2, 1, 0, 1]], - [[5, 3], [2, 0, 0, 1]], - [[6, 0], [1, 3, 0, 1]], - [[6, 1], [1, 2, 0, 1]], - [[6, 2], [1, 1, 0, 1]], - [[6, 3], [1, 0, 0, 1]], - [[7, 0], [0, 3, 0, 1]], - [[7, 1], [0, 2, 0, 1]], - [[7, 2], [0, 1, 0, 1]], - [[7, 3], [0, 0, 0, 1]] - ] -} diff --git a/tt_metal/impl/device/mesh_configurations/TGG.json b/tt_metal/impl/device/mesh_configurations/TGG.json deleted file mode 100644 index 904acf6758d..00000000000 --- a/tt_metal/impl/device/mesh_configurations/TGG.json +++ /dev/null @@ -1,68 +0,0 @@ -{ - "logical_to_physical_coordinates": [ - [[0, 0], [7, 3, 0, 1]], - [[0, 1], [7, 2, 0, 1]], - [[0, 2], [7, 1, 0, 1]], - [[0, 3], [7, 0, 0, 1]], - [[1, 0], [6, 3, 0, 1]], - [[1, 1], [6, 2, 0, 1]], - [[1, 2], [6, 1, 0, 1]], - [[1, 3], [6, 0, 0, 1]], - [[2, 0], [5, 3, 0, 1]], - [[2, 1], [5, 2, 0, 1]], - [[2, 2], [5, 1, 0, 1]], - [[2, 3], [5, 0, 0, 1]], - [[3, 0], [4, 3, 0, 1]], - [[3, 1], [4, 2, 0, 1]], - [[3, 2], [4, 1, 0, 1]], - [[3, 3], [4, 0, 0, 1]], - [[4, 0], [3, 3, 0, 1]], - [[4, 1], [3, 2, 0, 1]], - [[4, 2], [3, 1, 0, 1]], - [[4, 3], [3, 0, 0, 1]], - [[5, 0], [2, 3, 0, 1]], - [[5, 1], [2, 2, 0, 1]], - [[5, 2], [2, 1, 0, 1]], - [[5, 3], [2, 0, 0, 1]], - [[6, 0], [1, 3, 0, 1]], - [[6, 1], [1, 2, 0, 1]], - [[6, 2], [1, 1, 0, 1]], - [[6, 3], [1, 0, 0, 1]], - [[7, 0], [0, 3, 0, 1]], - [[7, 1], [0, 2, 0, 1]], - [[7, 2], [0, 1, 0, 1]], - [[7, 3], [0, 0, 0, 1]], - [[0, 4], [0, 0, 0, 2]], - [[0, 5], [0, 1, 0, 2]], - [[0, 6], [0, 2, 0, 2]], - [[0, 7], [0, 3, 0, 2]], - [[1, 4], [1, 0, 0, 2]], - [[1, 5], [1, 1, 0, 2]], - [[1, 6], [1, 2, 0, 2]], - [[1, 7], [1, 3, 0, 2]], - [[2, 4], [2, 0, 0, 2]], - [[2, 5], [2, 1, 0, 2]], - [[2, 6], [2, 2, 0, 2]], - [[2, 7], [2, 3, 0, 2]], - [[3, 4], [3, 0, 0, 2]], - [[3, 5], [3, 1, 0, 2]], - [[3, 6], [3, 2, 0, 2]], - [[3, 7], [3, 3, 0, 2]], - [[4, 4], [4, 0, 0, 2]], - [[4, 5], [4, 1, 0, 2]], - [[4, 6], [4, 2, 0, 2]], - [[4, 7], [4, 3, 0, 2]], - [[5, 4], [5, 0, 0, 2]], - [[5, 5], [5, 1, 0, 2]], - [[5, 6], [5, 2, 0, 2]], - [[5, 7], [5, 3, 0, 2]], - [[6, 4], [6, 0, 0, 2]], - [[6, 5], [6, 1, 0, 2]], - [[6, 6], [6, 2, 0, 2]], - [[6, 7], [6, 3, 0, 2]], - [[7, 4], [7, 0, 0, 2]], - [[7, 5], [7, 1, 0, 2]], - [[7, 6], [7, 2, 0, 2]], - [[7, 7], [7, 3, 0, 2]] - ] - } diff --git a/tt_metal/impl/device/mesh_configurations/device.json b/tt_metal/impl/device/mesh_configurations/device.json deleted file mode 100644 index ea1a34365f0..00000000000 --- a/tt_metal/impl/device/mesh_configurations/device.json +++ /dev/null @@ -1,5 +0,0 @@ -{ - "logical_to_physical_coordinates": [ - [[0, 0], [0, 0, 0, 0]] - ] -} diff --git a/tt_metal/impl/device/mesh_device.cpp b/tt_metal/impl/device/mesh_device.cpp index e90d4a8925e..e43ddc6c4fe 100644 --- a/tt_metal/impl/device/mesh_device.cpp +++ b/tt_metal/impl/device/mesh_device.cpp @@ -2,302 +2,186 @@ // // SPDX-License-Identifier: Apache-2.0 -#include "tt_metal/impl/device/mesh_device.hpp" - #include -#include -#include "device/tt_cluster_descriptor_types.h" -#include "tt_metal/common/logger.hpp" -#include "tt_metal/detail/tt_metal.hpp" -#include "tt_metal/host_api.hpp" +#include "tt_metal/impl/device/mesh_device.hpp" #include "tt_metal/impl/device/mesh_device_view.hpp" +#include "tt_metal/host_api.hpp" +#include "tt_metal/detail/tt_metal.hpp" -namespace tt::tt_metal { - -using LogicalCoordinate = Coordinate; -using PhysicalCoordinate = eth_coord_t; - -static std::string get_config_path(const std::string& filename) { - std::string root_path = getenv("TT_METAL_HOME") ? getenv("TT_METAL_HOME") : "./"; - return root_path + "/tt_metal/impl/device/mesh_configurations/" + filename; -} - -static std::map load_translation_map(const std::string& filename, const std::string& key) { - std::ifstream file(filename); - if (!file.is_open()) { - throw std::runtime_error("Unable to open file: " + filename); - } - - nlohmann::json j; - try { - file >> j; - } catch (const nlohmann::json::parse_error& e) { - throw std::runtime_error("JSON parsing error in file " + filename + ": " + e.what()); - } - if (!j.contains(key)) { - throw std::runtime_error("Key '" + key + "' not found in JSON file: " + filename); - } +namespace tt::tt_metal { - std::map result; - for (const auto& mapping : j[key]) { - if (mapping.size() != 2 || mapping[0].size() != 2 || mapping[1].size() != 4) { - throw std::runtime_error("Invalid coordinate format in JSON file: " + filename); +MeshDevice::MeshDevice(const MeshShape& mesh_shape, const DeviceIds &device_ids, size_t l1_small_size, size_t trace_region_size, size_t num_command_queues, DispatchCoreType dispatch_core_type) + : mesh_shape(mesh_shape) +{ + auto [num_rows, num_cols] = mesh_shape; + auto num_requested_devices = num_rows * num_cols; + auto num_available_devices = tt::tt_metal::GetNumAvailableDevices(); + TT_ASSERT(num_requested_devices <= num_available_devices, "Requested more devices than available"); + TT_ASSERT(num_requested_devices <= device_ids.size(), "User provided insufficient number of device_ids for MeshDevice"); + + this->is_galaxy_ = tt::Cluster::instance().is_galaxy_cluster(); + if (this->is_galaxy_) { + // Temp solution until we add algorithmic way to determine chip connectivity + // Map col to tunnel depth and row to tunnel count + int cluster_tunnel_depth = tt::Cluster::instance().get_mmio_device_max_tunnel_depth(0); + int cluster_tunnel_count = tt::Cluster::instance().get_mmio_device_tunnel_count(0); + int num_mmio_devices = tt::Cluster::instance().number_of_pci_devices(); + TT_FATAL(num_cols <= cluster_tunnel_depth and num_rows <= cluster_tunnel_count * num_mmio_devices, "Unsupported Galaxy mesh shape"); + + DeviceIds galaxy_device_ids; + for (int mmio_device_id = 0; mmio_device_id < num_mmio_devices; mmio_device_id++) { + auto tunnels_from_mmio = tt::Cluster::instance().get_tunnels_from_mmio_device(mmio_device_id); + for (uint32_t t = 0; t < tunnels_from_mmio.size(); t++) { + if (galaxy_device_ids.size() == num_requested_devices) { + break; + } + int col_idx = 0; + for (uint32_t ts = 1; ts < tunnels_from_mmio[t].size(); ts++) { + galaxy_device_ids.push_back(tunnels_from_mmio[t][ts]); + col_idx ++; + if (col_idx == num_cols) { + break; + } + } + } + } + managed_devices = tt::tt_metal::detail::CreateDevices(galaxy_device_ids, num_command_queues, l1_small_size, trace_region_size, dispatch_core_type); + for (int i = 0; i < num_requested_devices; i++) { + mesh_devices.emplace_back(device_ids[i], managed_devices.at(galaxy_device_ids[i])); + } + this->view = std::make_unique(*this); + } else { + managed_devices = tt::tt_metal::detail::CreateDevices(device_ids, num_command_queues, l1_small_size, trace_region_size, dispatch_core_type); + for (int i = 0; i < num_requested_devices; i++) { + mesh_devices.emplace_back(device_ids[i], managed_devices.at(device_ids[i])); } - result.emplace(LogicalCoordinate{mapping[0][0], mapping[0][1]}, PhysicalCoordinate{mapping[1][1], mapping[1][0], mapping[1][2], mapping[1][3]}); } - return result; -} - -MeshShape SystemMesh::get_system_mesh_shape(std::size_t system_num_devices) { - const std::unordered_map system_mesh_to_shape = { - {1, MeshShape{1, 1}}, // single-device - {2, MeshShape{1, 2}}, // N300 - {8, MeshShape{2, 4}}, // T3000; as ring to match existing tests - {32, MeshShape{8, 4}}, // TG - {64, MeshShape{8, 8}}, // TGG - }; - TT_FATAL(system_mesh_to_shape.contains(system_num_devices), "Unsupported number of devices: {}", system_num_devices); - auto shape = system_mesh_to_shape.at(system_num_devices); - log_debug(LogMetal, "Logical SystemMesh Shape: {}x{}", shape.first, shape.second); - return shape; -} - -std::map SystemMesh::get_system_mesh_translation_map(std::size_t system_num_devices) { - const std::unordered_map system_mesh_translation_map = { - {1, "device.json"}, - {2, "N300.json"}, - {8, "T3000.json"}, - {32, "TG.json"}, - {64, "TGG.json"}, - }; - TT_FATAL(system_mesh_translation_map.contains(system_num_devices), "Unsupported number of devices: {}", system_num_devices); - auto translation_config_file = get_config_path(system_mesh_translation_map.at(system_num_devices)); - return load_translation_map(translation_config_file, "logical_to_physical_coordinates"); -} - -bool SystemMesh::is_system_mesh_initialized() const { - return this->physical_coordinate_to_device_id.size() > 0; -} - -SystemMesh& SystemMesh::instance() { - static SystemMesh instance; - if (!instance.is_system_mesh_initialized()) { - instance.initialize(); + for (const auto& [dev_id, dev]: mesh_devices) { + log_debug(tt::LogMetal, "TTNN Dev {}: Metal Dev {}", dev_id, dev->id()); } - return instance; } -void SystemMesh::initialize() { - this->physical_device_id_to_coordinate = tt::Cluster::instance().get_user_chip_ethernet_coordinates(); - for (const auto& [chip_id, physical_coordinate] : this->physical_device_id_to_coordinate) { - this->physical_coordinate_to_device_id.emplace(physical_coordinate, chip_id); - } - // Initialize the system mesh shape and translation map - auto num_devices = physical_coordinate_to_device_id.size(); - this->logical_mesh_shape = SystemMesh::get_system_mesh_shape(num_devices); - this->logical_to_physical_coordinates = SystemMesh::get_system_mesh_translation_map(num_devices); -} -const MeshShape& SystemMesh::get_shape() const { return this->logical_mesh_shape; } -std::size_t SystemMesh::get_num_devices() const { - auto [num_rows, num_cols] = this->get_shape(); - return num_rows * num_cols; +MeshDevice::~MeshDevice() { + if (not managed_devices.empty()) { + close_devices(); + } } -std::vector SystemMesh::get_mapped_physical_device_ids(const MeshDeviceConfig& config) const { - std::vector physical_device_ids; - auto [system_mesh_rows, system_mesh_cols] = this->get_shape(); - auto [requested_rows, requested_cols] = config.mesh_shape; - auto [row_offset, col_offset] = config.offset; - - for (int row = 0; row < requested_rows; row++) { - for (int col = 0; col < requested_cols; col++) { - auto logical_device_id = (row + row_offset) * system_mesh_cols + (col + col_offset); - auto logical_coordinate = Coordinate{logical_device_id / system_mesh_cols, logical_device_id % system_mesh_cols}; - auto physical_coordinate = this->logical_to_physical_coordinates.at(logical_coordinate); - auto physical_device_id = this->physical_coordinate_to_device_id.at(physical_coordinate); - physical_device_ids.push_back(physical_device_id); - - log_debug(LogMetal, "Logical device ID: {}, Logical coordinate: {}, Physical coordinate: {}, Physical device ID: {}", - logical_device_id, logical_coordinate, physical_coordinate, physical_device_id); +Device* MeshDevice::get_device(int logical_device_id) const { + for (const auto& [device_id, device] : mesh_devices) { + if (device_id == logical_device_id) { + return device; } } - return physical_device_ids; + TT_THROW("User has provided an invalid device index"); } -std::vector SystemMesh::map_mesh_device( - std::shared_ptr mesh_device, - size_t num_command_queues, - size_t l1_small_size, - size_t trace_region_size, - DispatchCoreType dispatch_core_type, - const std::pair& offset, - const std::vector& user_provided_physical_device_ids) { - - auto [requested_num_rows, requested_num_cols] = mesh_device->shape(); - auto [max_num_rows, max_num_cols] = this->logical_mesh_shape; - auto [row_offset, col_offset] = offset; - - log_debug(LogMetal, "Mapping MeshDevice ({}x{}) with offset: {}, {}", requested_num_rows, requested_num_cols, row_offset, col_offset); - TT_FATAL(requested_num_rows <= max_num_rows, "Requested too many rows: {} > {}", requested_num_rows, max_num_rows); - TT_FATAL(requested_num_rows*requested_num_cols <= max_num_rows*max_num_cols, "Requested submesh is too big: {}x{}", requested_num_rows, requested_num_cols); - - this->assigned_mesh_device_devices.insert({mesh_device->get_mesh_id(), mesh_device}); - - auto physical_device_ids = user_provided_physical_device_ids.empty() ? - this->get_mapped_physical_device_ids(MeshDeviceConfig{mesh_device->shape(), offset}) : - user_provided_physical_device_ids; - - this->opened_devices[mesh_device->get_mesh_id()] = tt::tt_metal::detail::CreateDevices( - physical_device_ids, num_command_queues, l1_small_size, trace_region_size, dispatch_core_type); - - std::vector mapped_devices; - for (auto physical_device_id : physical_device_ids) { - auto mapped_device = this->opened_devices[mesh_device->get_mesh_id()].at(physical_device_id); - mapped_devices.push_back(mapped_device); - this->assigned_devices[mesh_device->get_mesh_id()].push_back(physical_device_id); - this->assigned_physical_id_to_device.insert({physical_device_id, mapped_device}); +std::vector MeshDevice::get_devices() const +{ + std::vector devices; + for (const auto& [device_id, device] : mesh_devices) { + devices.push_back(device); } - return mapped_devices; + return devices; } -void SystemMesh::unmap_mesh_device(const std::shared_ptr& mesh_device) { - auto mesh_id = mesh_device->get_mesh_id(); - - // Clean up all state related to this virtual mesh - this->assigned_mesh_device_devices.erase(mesh_id); - - // Remove the devices from assigned_physical_id_to_device - for (auto physical_id : this->assigned_devices.at(mesh_id)) { - this->assigned_physical_id_to_device.erase(physical_id); +Device* MeshDevice::get_device(int row_idx, int col_idx) const { + if (not is_galaxy_) { + TT_THROW("Non-galaxy device mesh does not currently support indexing over rows and columns of a logical 2D mesh."); } - this->assigned_devices.erase(mesh_id); - // Close the devices - tt::tt_metal::detail::CloseDevices(this->opened_devices.at(mesh_id)); - this->opened_devices.erase(mesh_id); + TT_FATAL( + this->num_rows() != 0 and this->num_cols() != 0, + "#10419, Current device mesh does not support indexing by row or col indices."); + TT_FATAL(row_idx >= 0 and row_idx < this->num_rows(), "Invalid row index."); + TT_FATAL(col_idx >= 0 and col_idx < this->num_cols(), "Invalid col index."); + int idx = row_idx * this->num_cols() + col_idx; + return this->mesh_devices[idx].second; } -static MeshDeviceID generate_unique_mesh_id() { - static std::atomic next_id{0}; - return next_id++; +std::vector MeshDevice::get_devices_on_row(int row_idx) const { + if (not is_galaxy_) { + TT_THROW("Non-galaxy device mesh does not currently support indexing over rows and columns of a logical 2D mesh."); + } + return this->view->get_devices_on_row(row_idx); } -MeshDevice::MeshDevice(const MeshShape& mesh_device_shape) : mesh_device_shape(mesh_device_shape), mesh_id(generate_unique_mesh_id()) {} - -std::shared_ptr MeshDevice::create( - const MeshShape& mesh_device_shape, - size_t l1_small_size, - size_t trace_region_size, - size_t num_command_queues, - DispatchCoreType dispatch_core_type, - const std::pair& offset, - const std::vector& user_provided_physical_device_ids) -{ - auto mesh_device = std::make_shared(mesh_device_shape); - mesh_device->initialize(l1_small_size, trace_region_size, num_command_queues, dispatch_core_type, offset, user_provided_physical_device_ids); - - return mesh_device; +std::vector MeshDevice::get_devices_on_column(int col_idx) const { + if (not is_galaxy_) { + TT_THROW("Non-galaxy device mesh does not currently support indexing over rows and columns of a logical 2D mesh."); + } + return this->view->get_devices_on_column(col_idx); } -void MeshDevice::initialize( - size_t l1_small_size, - size_t trace_region_size, - size_t num_command_queues, - DispatchCoreType dispatch_core_type, - const std::pair& offset, - const std::vector& physical_device_ids) +const DeviceIds MeshDevice::get_device_ids() const { - auto [num_rows, num_cols] = this->shape(); - auto num_requested_devices = num_rows * num_cols; - auto num_available_devices = tt::tt_metal::GetNumAvailableDevices(); - TT_FATAL( - num_requested_devices <= num_available_devices, - "User has requested more devices than available: {} requested, {} available", - num_requested_devices, num_available_devices); - - auto& instance = SystemMesh::instance(); - this->devices = instance.map_mesh_device( - shared_from_this(), num_command_queues, l1_small_size, trace_region_size, dispatch_core_type, offset, physical_device_ids); - this->primary_view = std::make_unique(*this); - - for (int device_index = 0; device_index < this->devices.size(); device_index++) { - this->physical_id_to_device_index.insert({this->devices[device_index]->id(), device_index}); + DeviceIds device_ids; + for (const auto& [device_id, device] : mesh_devices) { + device_ids.push_back(device_id); } + return device_ids; } -MeshDevice::~MeshDevice() { - if (not this->devices.empty()) { - this->close_devices(); - } +int MeshDevice::num_devices() const +{ + return mesh_devices.size(); } -Device* MeshDevice::get_device_index(int logical_device_id) const { - TT_FATAL(logical_device_id >= 0 and logical_device_id < num_devices(), "Invalid device index"); - return this->devices.at(logical_device_id); +CoreCoord MeshDevice::compute_with_storage_grid_size() const { + return mesh_devices.at(0).second->compute_with_storage_grid_size(); } -Device* MeshDevice::get_device(int physical_device_id) const { - return this->devices.at(this->physical_id_to_device_index.at(physical_device_id)); +CoreCoord MeshDevice::dram_grid_size() const { + return mesh_devices.at(0).second->dram_grid_size(); } -std::vector MeshDevice::get_devices() const { return this->devices; } - -Device* MeshDevice::get_device(int row_idx, int col_idx) const { - return this->get_device_index(row_idx * num_cols() + col_idx); +tt::ARCH MeshDevice::arch() const { + return mesh_devices.at(0).second->arch(); } -std::vector MeshDevice::get_devices_on_row(int row_idx) const { - return this->primary_view->get_devices_on_row(row_idx); +int MeshDevice::num_rows() const +{ + return this->mesh_shape.first; } -std::vector MeshDevice::get_devices_on_column(int col_idx) const { - return this->primary_view->get_devices_on_column(col_idx); +int MeshDevice::num_cols() const +{ + return this->mesh_shape.second; } -const DeviceIds MeshDevice::get_device_ids() const { - DeviceIds device_ids; - for (auto device : this->get_devices()) { - device_ids.push_back(device->id()); - } - return device_ids; +MeshShape MeshDevice::shape() const +{ + return this->mesh_shape; } -int MeshDevice::num_devices() const { return num_rows() * num_cols(); } - -CoreCoord MeshDevice::compute_with_storage_grid_size() const { return get_device_index(0)->compute_with_storage_grid_size(); } - -CoreCoord MeshDevice::dram_grid_size() const { return get_device_index(0)->dram_grid_size(); } - -tt::ARCH MeshDevice::arch() const { return get_device_index(0)->arch(); } - -int MeshDevice::num_rows() const { return this->mesh_device_shape.first; } - -int MeshDevice::num_cols() const { return this->mesh_device_shape.second; } - -MeshShape MeshDevice::shape() const { return this->mesh_device_shape; } - void MeshDevice::close_devices() { - SystemMesh::instance().unmap_mesh_device(shared_from_this()); - this->devices.clear(); - this->physical_id_to_device_index.clear(); - this->primary_view.reset(); + tt::tt_metal::detail::CloseDevices(managed_devices); + mesh_devices.clear(); + managed_devices.clear(); } std::string MeshDevice::to_string() const { - return fmt::format("MeshDevice({}x{} grid, {} devices)", this->num_rows(), this->num_cols(), this->num_devices()); + return fmt::format("MeshDevice({}x{} grid, {} devices)", + this->num_rows(), + this->num_cols(), + this->num_devices()); } -std::shared_ptr MeshDevice::get_view() const { return this->primary_view; } - -std::shared_ptr MeshDevice::get_view() { return this->primary_view; } +std::shared_ptr MeshDevice::get_view() const { + return this->view; +} -MeshDeviceID MeshDevice::get_mesh_id() const { return this->mesh_id; } +std::shared_ptr MeshDevice::get_view() { + return this->view; +} -std::ostream& operator<<(std::ostream& os, const MeshDevice& mesh_device) { return os << mesh_device.to_string(); } +std::ostream& operator<<(std::ostream& os, const MeshDevice& mesh_device) { + return os << mesh_device.to_string(); +} bool validate_worker_modes(const std::vector& workers) { bool worker_modes_match = true; @@ -308,13 +192,4 @@ bool validate_worker_modes(const std::vector& workers) { return worker_modes_match; } -std::vector get_t3k_physical_device_ids_ring() { - auto& instance = SystemMesh::instance(); - auto num_devices = instance.get_num_devices(); - TT_FATAL(num_devices == 8, "T3000 ring topology only works with 8 devices"); - - auto physical_device_ids = instance.get_mapped_physical_device_ids(MeshDeviceConfig{instance.get_shape(), MeshOffset{0, 0}}); - return physical_device_ids; -} - -} // namespace tt::tt_metal +} // namespace tt::tt_metal diff --git a/tt_metal/impl/device/mesh_device.hpp b/tt_metal/impl/device/mesh_device.hpp index 940110973cc..6ba0c336aaf 100644 --- a/tt_metal/impl/device/mesh_device.hpp +++ b/tt_metal/impl/device/mesh_device.hpp @@ -4,106 +4,28 @@ #pragma once -#include #include -#include #include +#include +#include -#include "mesh_device_view.hpp" #include "tt_metal/impl/device/device.hpp" #include "tt_metal/impl/device/mesh_device_view.hpp" namespace tt::tt_metal { using DeviceIds = std::vector; -using MeshDeviceID = std::size_t; -using MeshOffset = std::pair; class MeshDeviceView; -struct MeshDeviceConfig { +class MeshDevice +{ +public: MeshShape mesh_shape; - MeshOffset offset; -}; + std::map managed_devices; + std::vector> mesh_devices; + std::shared_ptr view; -// SystemMesh creates a virtualization over the physical devices in the system. -// It creates a logical 2D-mesh of devices and manages the mapping between logical and physical device coordinates. -// It is responsible for the assignment of devices in a MeshDevice to physical devices, and the creation and deletion of -// device resources. -class SystemMesh { - private: - using LogicalCoordinate = Coordinate; - using PhysicalCoordinate = eth_coord_t; - - // Keep track of the devices that were opened so we can close them later. We shouldn't - // to keep track of this but DevicePool seems to open all devices associated with an MMIO device id - std::unordered_map> opened_devices; - std::unordered_map> assigned_devices; - std::unordered_map> assigned_mesh_device_devices; - std::unordered_map assigned_physical_id_to_device; - - // Logical mesh shape and coordinates - MeshShape logical_mesh_shape; - std::map logical_to_physical_coordinates; - - // Handling of physical coordinates - std::unordered_map physical_coordinate_to_device_id; - std::unordered_map physical_device_id_to_coordinate; - - SystemMesh() = default; - SystemMesh(const SystemMesh &) = delete; - SystemMesh &operator=(const SystemMesh &) = delete; - SystemMesh(SystemMesh &&) = delete; - SystemMesh &operator=(SystemMesh &&) = delete; - - static MeshShape get_system_mesh_shape(std::size_t system_num_devices); - static std::map get_system_mesh_translation_map( - std::size_t system_num_devices); - - bool is_system_mesh_initialized() const; - - public: - static SystemMesh &instance(); - - void initialize(); - - // Return the shape of the logical mesh - const MeshShape &get_shape() const; - std::size_t get_num_devices() const; - - // Get the physical device IDs mapped to a MeshDevice - std::vector get_mapped_physical_device_ids(const MeshDeviceConfig &config) const; - - // Map MeshDevice to physical devices - std::vector map_mesh_device( - std::shared_ptr mesh_device, - size_t num_command_queues, - size_t l1_small_size, - size_t trace_region_size, - DispatchCoreType dispatch_core_type, - const std::pair &offset = {0, 0}, - const std::vector &physical_device_ids = {}); - - // Unmap MeshDevice, releasing the associated physical devices. - void unmap_mesh_device(const std::shared_ptr &mesh_device); -}; - -class MeshDevice : public std::enable_shared_from_this { - MeshDeviceID mesh_id; - MeshShape mesh_device_shape; - std::shared_ptr primary_view; - std::vector devices; - std::unordered_map physical_id_to_device_index; - - void initialize( - size_t l1_small_size, - size_t trace_region_size, - size_t num_command_queues, - DispatchCoreType dispatch_core_type, - const std::pair &offset, - const std::vector &physical_device_ids); - - public: - MeshDevice(const MeshShape &mesh_device_shape); + MeshDevice(const MeshShape &mesh_shape, const DeviceIds &device_ids, size_t l1_small_size, size_t trace_region_size, size_t num_command_queues, DispatchCoreType dispatch_core_type); ~MeshDevice(); MeshDevice(const MeshDevice &) = delete; @@ -112,9 +34,8 @@ class MeshDevice : public std::enable_shared_from_this { MeshDevice(MeshDevice &&) = delete; MeshDevice &operator=(MeshDevice &&) = delete; - std::vector get_devices() const; - Device *get_device_index(int logical_device_id) const; - Device *get_device(int physical_device_id) const; + std::vector get_devices() const; + Device *get_device(int logical_device_id) const; Device *get_device(int row_idx, int col_idx) const; std::vector get_devices_on_row(int row_idx) const; std::vector get_devices_on_column(int col_idx) const; @@ -137,20 +58,12 @@ class MeshDevice : public std::enable_shared_from_this { std::shared_ptr get_view(); std::string to_string() const; - MeshDeviceID get_mesh_id() const; - - static std::shared_ptr create( - const MeshShape &mesh_device_shape, - size_t l1_small_size, - size_t trace_region_size, - size_t num_command_queues, - DispatchCoreType dispatch_core_type, - const std::pair &offset = {0, 0}, - const std::vector &physical_device_ids = {}); + + private: + bool is_galaxy_; }; -std::ostream &operator<<(std::ostream &os, const MeshDevice &mesh_device); -bool validate_worker_modes(const std::vector &workers); -std::vector get_t3k_physical_device_ids_ring(); +std::ostream& operator<<(std::ostream& os, const MeshDevice& mesh_device); +bool validate_worker_modes(const std::vector& workers); -} // namespace tt::tt_metal +} // namespace tt::tt_metal diff --git a/tt_metal/llrt/tt_cluster.cpp b/tt_metal/llrt/tt_cluster.cpp index 1483bd38bbd..f8139ce26f5 100644 --- a/tt_metal/llrt/tt_cluster.cpp +++ b/tt_metal/llrt/tt_cluster.cpp @@ -325,16 +325,6 @@ tt_device &Cluster::get_driver(chip_id_t device_id) const { return *(this->mmio_device_id_to_driver_.at(mmio_device_id)); } -std::unordered_map Cluster::get_user_chip_ethernet_coordinates() const { - auto user_chip_ethernet_coordinates = this->cluster_desc_->get_chip_locations(); - if (this->is_galaxy_cluster()) { - std::erase_if(user_chip_ethernet_coordinates, [this](const auto& entry) { - return this->cluster_desc_->get_board_type(entry.first) != BoardType::GALAXY; - }); - } - return user_chip_ethernet_coordinates; -} - const metal_SocDescriptor &Cluster::get_soc_desc(chip_id_t chip) const { if (this->sdesc_per_chip_.find(chip) == this->sdesc_per_chip_.end()) { TT_THROW( diff --git a/tt_metal/llrt/tt_cluster.hpp b/tt_metal/llrt/tt_cluster.hpp index 5738163f02d..a28aa2750f8 100644 --- a/tt_metal/llrt/tt_cluster.hpp +++ b/tt_metal/llrt/tt_cluster.hpp @@ -57,8 +57,6 @@ class Cluster { } } - std::unordered_map get_user_chip_ethernet_coordinates() const; - size_t number_of_devices() const { return this->cluster_desc_->get_number_of_chips(); } size_t number_of_pci_devices() const { return this->cluster_desc_->get_chips_with_mmio().size(); } diff --git a/ttnn/cpp/pybind11/multi_device.hpp b/ttnn/cpp/pybind11/multi_device.hpp index 70d9755d040..7b9fbff52a8 100644 --- a/ttnn/cpp/pybind11/multi_device.hpp +++ b/ttnn/cpp/pybind11/multi_device.hpp @@ -16,36 +16,20 @@ namespace ttnn { namespace multi_device { -void py_module_types(py::module& module) { py::class_>(module, "MeshDevice"); } +void py_module_types(py::module& module) { py::class_(module, "MeshDevice"); } void py_module(py::module& module) { - auto py_mesh_device = static_cast>>(module.attr("MeshDevice")); + auto py_mesh_device = static_cast>(module.attr("MeshDevice")); py_mesh_device .def( - py::init([](const MeshShape& mesh_device_shape, - size_t l1_small_size, - size_t trace_region_size, - size_t num_command_queues, - DispatchCoreType dispatch_core_type, - const std::pair& offset, - const std::vector& physical_device_ids) { - return MeshDevice::create( - mesh_device_shape, - l1_small_size, - trace_region_size, - num_command_queues, - dispatch_core_type, - offset, - physical_device_ids); - }), + py::init, size_t, size_t, size_t, DispatchCoreType>(), py::kw_only(), py::arg("mesh_shape"), + py::arg("device_ids"), py::arg("l1_small_size"), py::arg("trace_region_size"), py::arg("num_command_queues"), - py::arg("dispatch_core_type"), - py::arg("offset"), - py::arg("physical_device_ids")) + py::arg("dispatch_core_type")) .def("get_num_devices", &MeshDevice::num_devices) .def("get_device_ids", &MeshDevice::get_device_ids) .def( @@ -122,11 +106,11 @@ void py_module(py::module& module) { &open_mesh_device, py::kw_only(), py::arg("mesh_shape"), + py::arg("device_ids"), py::arg("l1_small_size"), py::arg("trace_region_size"), py::arg("num_command_queues"), - py::arg("dispatch_core_type"), - py::arg("physical_device_ids")); + py::arg("dispatch_core_type")); module.def("close_mesh_device", &close_mesh_device, py::arg("mesh_device"), py::kw_only()); module.def( @@ -163,7 +147,6 @@ void py_module(py::module& module) { )doc"); module.def("get_device_tensors", &get_device_tensors, py::arg("tensor"), py::kw_only()); module.def("aggregate_as_tensor", &aggregate_as_tensor, py::arg("tensors"), py::kw_only()); - module.def("get_t3k_physical_device_ids_ring", &tt::tt_metal::get_t3k_physical_device_ids_ring); } } // namespace multi_device diff --git a/ttnn/cpp/ttnn/events.cpp b/ttnn/cpp/ttnn/events.cpp index 5f207b1cfdf..f37efb94aa7 100644 --- a/ttnn/cpp/ttnn/events.cpp +++ b/ttnn/cpp/ttnn/events.cpp @@ -11,11 +11,11 @@ namespace ttnn::events { MultiDeviceEvent::MultiDeviceEvent(MeshDevice* mesh_device) { TT_ASSERT(mesh_device != nullptr, "Must provide a valid mesh_device when initializing an event on multiple devices."); - auto devices = mesh_device->get_devices(); + auto& devices = mesh_device->mesh_devices; this->events = std::vector>(devices.size()); for (int event_idx = 0; event_idx < devices.size(); event_idx++) { this->events[event_idx] = std::make_shared(); - this->events[event_idx]->device = devices[event_idx]; + this->events[event_idx]->device = devices[event_idx].second; } } diff --git a/ttnn/cpp/ttnn/multi_device.cpp b/ttnn/cpp/ttnn/multi_device.cpp index 7fa5f9e0d65..fe77a9f5fe4 100644 --- a/ttnn/cpp/ttnn/multi_device.cpp +++ b/ttnn/cpp/ttnn/multi_device.cpp @@ -12,12 +12,12 @@ namespace ttnn::multi_device { -std::shared_ptr open_mesh_device(const MeshShape& mesh_shape, size_t l1_small_size, size_t trace_region_size, size_t num_command_queues, DispatchCoreType dispatch_core_type, const std::pair& offset) { - return MeshDevice::create(mesh_shape, l1_small_size, trace_region_size, num_command_queues, dispatch_core_type, offset); +MeshDevice open_mesh_device(const MeshShape& mesh_shape, const DeviceIds& device_ids, size_t l1_small_size, size_t trace_region_size, size_t num_command_queues, DispatchCoreType dispatch_core_type) { + return MeshDevice(mesh_shape, device_ids, l1_small_size, trace_region_size, num_command_queues, dispatch_core_type); } -void close_mesh_device(const std::shared_ptr& mesh_device) { - mesh_device->close_devices(); +void close_mesh_device(MeshDevice &multi_device) { + multi_device.close_devices(); } std::vector get_device_tensors(const ttnn::Tensor& tensor) { diff --git a/ttnn/cpp/ttnn/multi_device.hpp b/ttnn/cpp/ttnn/multi_device.hpp index d7db05721bc..b3821e2b5cb 100644 --- a/ttnn/cpp/ttnn/multi_device.hpp +++ b/ttnn/cpp/ttnn/multi_device.hpp @@ -15,15 +15,13 @@ using Device = ttnn::Device; namespace ttnn { namespace multi_device { -std::shared_ptr open_mesh_device(const MeshShape& mesh_shape, size_t l1_small_size, size_t trace_region_size, size_t num_command_queues, DispatchCoreType dispatch_core_type, const std::pair& offset = {0, 0}); -void close_mesh_device(const std::shared_ptr& mesh_device); +MeshDevice open_mesh_device(const MeshShape& mesh_shape, const DeviceIds& device_ids, size_t l1_small_size, size_t trace_region_size, size_t num_command_queues, DispatchCoreType dispatch_core_type); +void close_mesh_device(MeshDevice &multi_device); std::vector get_device_tensors(const ttnn::Tensor& tensor); Tensor aggregate_as_tensor(std::vector& tensor_shards); -std::vector get_t3k_physical_device_ids_ring(); - } // namespace multi_device using namespace multi_device; diff --git a/ttnn/ttnn/__init__.py b/ttnn/ttnn/__init__.py index 0192b2b1673..e5423eb21ee 100644 --- a/ttnn/ttnn/__init__.py +++ b/ttnn/ttnn/__init__.py @@ -93,12 +93,7 @@ def manage_config(name, value): logger.debug(f"Restored ttnn.CONFIG.{name} to {original_value}") -from ttnn._ttnn.multi_device import ( - get_device_tensor, - get_device_tensors, - aggregate_as_tensor, - get_t3k_physical_device_ids_ring, -) +from ttnn._ttnn.multi_device import get_device_tensor, get_device_tensors, aggregate_as_tensor from ttnn._ttnn.events import create_event, record_event, wait_for_event diff --git a/ttnn/ttnn/multi_device.py b/ttnn/ttnn/multi_device.py index 2f49af86422..263666f6398 100644 --- a/ttnn/ttnn/multi_device.py +++ b/ttnn/ttnn/multi_device.py @@ -5,7 +5,7 @@ import contextlib import functools -from typing import List, Dict, Optional, Callable, Tuple, Optional, Callable, Union, List +from typing import List, Dict, Optional, Callable, Tuple, Optional, Callable, Union import ttnn @@ -134,36 +134,26 @@ def get_device_ids() -> List[int]: def open_mesh_device( mesh_shape: ttnn.MeshShape, + device_ids: List[int], l1_small_size: int = ttnn._ttnn.device.DEFAULT_L1_SMALL_SIZE, trace_region_size: int = ttnn._ttnn.device.DEFAULT_TRACE_REGION_SIZE, num_command_queues: int = 1, dispatch_core_type: int = DispatchCoreType.WORKER, - offset: Tuple[int, int] = (0, 0), - physical_device_ids: List[int] = [], ): """ - Open a mesh device with the specified configuration. - - Args: - mesh_shape (ttnn.MeshShape): The shape of the mesh device. - l1_small_size (int, optional): Size of the L1 small memory. Defaults to ttnn._ttnn.device.DEFAULT_L1_SMALL_SIZE. - trace_region_size (int, optional): Size of the trace region. Defaults to ttnn._ttnn.device.DEFAULT_TRACE_REGION_SIZE. - num_command_queues (int, optional): Number of command queues. Defaults to 1. - dispatch_core_type (int, optional): Type of dispatch core. Defaults to DispatchCoreType.WORKER. - offset (Tuple[int, int], optional): Offset in logical mesh coordinates for the mesh device. Defaults to (0, 0). - - Returns: - ttnn._ttnn.multi_device.MeshDevice: The opened mesh device. + open_mesh_device(mesh_shape: ttnn.MeshShape, device_ids: int) -> ttnn.MeshDevice: + Open a device with the given device_id. If the device is already open, return the existing device. """ + assert len(device_ids) > 0 + return ttnn._ttnn.multi_device.MeshDevice( mesh_shape=mesh_shape.as_tuple(), + device_ids=device_ids, l1_small_size=l1_small_size, trace_region_size=trace_region_size, num_command_queues=num_command_queues, dispatch_core_type=dispatch_core_type, - offset=offset, - physical_device_ids=physical_device_ids, )