Skip to content

Commit

Permalink
Mesh Virtualization (#12719)
Browse files Browse the repository at this point in the history
* #10608: add method to fetch ethernet coordinates from cluster

* #10608: add device virtualization btw MeshDevice and physical system

- This fixes #10608, #10419 by adding a logical 2D mesh to remap the
physical mesh onto a flattened 2d grid.
- A logical to physical coordinate translation map is now supplied
per supported system.
  • Loading branch information
cfjchu authored Sep 19, 2024
1 parent a19c57a commit 7738225
Show file tree
Hide file tree
Showing 31 changed files with 595 additions and 234 deletions.
16 changes: 3 additions & 13 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,9 +207,7 @@ def mesh_device(request, silicon_arch_name, silicon_arch_wormhole_b0, device_par

request.node.pci_ids = [ttnn.GetPCIeDeviceID(i) for i in device_ids[:num_devices_requested]]

mesh_device = ttnn.open_mesh_device(
mesh_shape, device_ids[:num_devices_requested], dispatch_core_type=get_dispatch_core_type(), **device_params
)
mesh_device = ttnn.open_mesh_device(mesh_shape, dispatch_core_type=get_dispatch_core_type(), **device_params)

logger.debug(f"multidevice with {mesh_device.get_num_devices()} devices is created")
yield mesh_device
Expand All @@ -235,9 +233,9 @@ def pcie_mesh_device(request, silicon_arch_name, silicon_arch_wormhole_b0, devic

mesh_device = ttnn.open_mesh_device(
ttnn.MeshShape(1, num_pcie_devices_requested),
device_ids[:num_pcie_devices_requested],
dispatch_core_type=get_dispatch_core_type(),
**device_params,
physical_device_ids=device_ids[:num_pcie_devices_requested],
)

logger.debug(f"multidevice with {mesh_device.get_num_devices()} devices is created")
Expand All @@ -256,17 +254,9 @@ def t3k_mesh_device(request, silicon_arch_name, silicon_arch_wormhole_b0, device

if ttnn.get_num_devices() < 8:
pytest.skip()
device_ids = [0, 4, 5, 1, 2, 6, 7, 3]
try:
num_devices_requested = min(request.param, len(device_ids))
except (ValueError, AttributeError):
num_devices_requested = len(device_ids)

request.node.pci_ids = [ttnn.GetPCIeDeviceID(i) for i in device_ids[:num_devices_requested]]

mesh_device = ttnn.open_mesh_device(
ttnn.MeshShape(1, num_devices_requested),
device_ids[:num_devices_requested],
ttnn.MeshShape(2, 4),
dispatch_core_type=get_dispatch_core_type(),
**device_params,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def test_LlamaModel_inference(
if t3k_mesh_device.get_num_devices() < n_devices and not emulated:
pytest.skip(f"Requires at {n_devices} devices to run")

compute_grid_size = t3k_mesh_device.get_device(0).compute_with_storage_grid_size()
compute_grid_size = t3k_mesh_device.compute_with_storage_grid_size()
if compute_grid_size.x < model_config["MAX_GRID_SIZE"][0] or compute_grid_size.y < model_config["MAX_GRID_SIZE"][1]:
pytest.skip(f"Requires grid size of at least {model_config['MAX_GRID_SIZE']} to run")

Expand Down
2 changes: 1 addition & 1 deletion models/demos/t3000/llama2_70b/tests/test_llama_perf.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ def test_Llama_perf_host(
if t3k_mesh_device.get_num_devices() < n_devices and not emulated:
pytest.skip(f"Requires at {n_devices} devices to run")

compute_grid_size = t3k_mesh_device.get_device(0).compute_with_storage_grid_size()
compute_grid_size = t3k_mesh_device.compute_with_storage_grid_size()
if compute_grid_size.x < model_config["MAX_GRID_SIZE"][0] or compute_grid_size.y < model_config["MAX_GRID_SIZE"][1]:
pytest.skip(f"Requires grid size of at least {model_config['MAX_GRID_SIZE']} to run")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def test_Llama_stress_test(
if t3k_mesh_device.get_num_devices() < n_devices and not emulated:
pytest.skip(f"Requires at {n_devices} devices to run")

compute_grid_size = t3k_mesh_device.get_device(0).compute_with_storage_grid_size()
compute_grid_size = t3k_mesh_device.compute_with_storage_grid_size()
if compute_grid_size.x < model_config["MAX_GRID_SIZE"][0] or compute_grid_size.y < model_config["MAX_GRID_SIZE"][1]:
pytest.skip(f"Requires grid size of at least {model_config['MAX_GRID_SIZE']} to run")

Expand Down
2 changes: 1 addition & 1 deletion models/demos/t3000/llama2_70b/tt/llama_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def check_mesh_device(t3k_mesh_device, model_config):
model_config["NUM_DEVICES"],
)

compute_grid_size = t3k_mesh_device.get_device(0).compute_with_storage_grid_size()
compute_grid_size = t3k_mesh_device.compute_with_storage_grid_size()
assert not (
compute_grid_size.x < model_config["MAX_GRID_SIZE"][0] or compute_grid_size.y < model_config["MAX_GRID_SIZE"][1]
), ("Requires grid size of at least %d to run", model_config["MAX_GRID_SIZE"])
Expand Down
2 changes: 1 addition & 1 deletion models/demos/tg/llama3_70b/tt/llama_attention_galaxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,7 +479,7 @@ def attn_mqa(
value_layer.deallocate(True)

program_config = ttnn.SDPAProgramConfig(
compute_with_storage_grid_size=self.mesh_device.get_device(0).compute_with_storage_grid_size(),
compute_with_storage_grid_size=self.mesh_device.compute_with_storage_grid_size(),
q_chunk_size=0, # unused
k_chunk_size=0, # unused
)
Expand Down
4 changes: 2 additions & 2 deletions models/demos/tg/llama3_70b/tt/llama_mlp_galaxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ def get_mlp_model_config(self):
ttnn.CoreRange(
ttnn.CoreCoord(0, 0),
ttnn.CoreCoord(
self.mesh_device.get_device(0).dram_grid_size().x - 1,
self.mesh_device.get_device(0).dram_grid_size().y - 1,
self.mesh_device.dram_grid_size().x - 1,
self.mesh_device.dram_grid_size().y - 1,
),
)
}
Expand Down
7 changes: 4 additions & 3 deletions models/utility_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -971,9 +971,10 @@ def get_devices_for_t3000(all_devices, num_devices):
if num_devices <= 4:
return all_devices[:num_devices]
elif num_devices == 8:
# TODO: Generalize this for different arch
hamiltonian_ring_indices = [0, 4, 5, 1, 2, 6, 7, 3]
return [all_devices[i] for i in hamiltonian_ring_indices]
# Temporary until we move request for ring order to CCL operations directly.
# This is better because we no longer need to manually manage the ring order.
ring_indices = ttnn.get_t3k_physical_device_ids_ring()
return [all_devices[i] for i in ring_indices]
else:
raise NotImplementedError("Only supports 1, 2, 3, 4, and 8 chip configurations!")

Expand Down
1 change: 1 addition & 0 deletions tests/scripts/tgg/run_tgg_unit_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ run_tgg_tests() {

TT_METAL_SLOW_DISPATCH_MODE=1 ./build/test/tt_metal/unit_tests_galaxy --gtest_filter="GalaxyFixture.*:TGGFixture.*"
./build/test/tt_metal/unit_tests_galaxy --gtest_filter="GalaxyFixture.*:TGGFixture.*"
pytest -s tests/ttnn/multichip_unit_tests/test_mesh_device_TGG.py
}

main() {
Expand Down
4 changes: 2 additions & 2 deletions tests/sweep_framework/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -190,11 +190,11 @@ def mesh_device_fixture():
assert ttnn.get_num_devices() >= 8, "Not T3000!"
device_ids = [0, 4, 5, 1, 2, 6, 7, 3]
device_ids = ttnn.get_t3k_physical_device_ids_ring()
num_devices_requested = len(device_ids)
mesh_device = ttnn.open_mesh_device(
ttnn.MeshShape(1, num_devices_requested), device_ids[:num_devices_requested]
ttnn.MeshShape(1, num_devices_requested),
)
print("ADD: Opened device mesh")
Expand Down
3 changes: 1 addition & 2 deletions tests/sweep_framework/sweeps/line_all_gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,8 @@ def invalidate_vector(test_vector) -> Tuple[bool, Optional[str]]:


def mesh_device_fixture():

assert ttnn.get_num_devices() >= 8, "Not T3000!"
device_ids = [0, 4, 5, 1, 2, 6, 7, 3]
device_ids = ttnn.get_t3k_physical_device_ids_ring()
num_devices_requested = len(device_ids)
mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(1, num_devices_requested), device_ids[:num_devices_requested])
print("ALL GATHER: Opened device mesh")
Expand Down
12 changes: 12 additions & 0 deletions tests/ttnn/multichip_unit_tests/test_mesh_device_TGG.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import torch
import pytest
import ttnn


@pytest.mark.parametrize("mesh_device", [pytest.param((8, 8), id="8x8_grid")], indirect=True)
def test_visualize_mesh_device(mesh_device):
ttnn.visualize_mesh_device(mesh_device)
10 changes: 4 additions & 6 deletions tests/ttnn/multichip_unit_tests/test_multidevice_TG.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,9 +210,7 @@ def test_galaxy_matmul_2d_fracture_dram_sharded(M, K, N, weights_dtype, mesh_sha
{
ttnn.CoreRange(
ttnn.CoreCoord(0, 0),
ttnn.CoreCoord(
mesh_device.get_device(0).dram_grid_size().x - 1, mesh_device.get_device(0).dram_grid_size().y - 1
),
ttnn.CoreCoord(mesh_device.dram_grid_size().x - 1, mesh_device.dram_grid_size().y - 1),
)
}
)
Expand Down Expand Up @@ -711,7 +709,7 @@ def test_fill_cache(

xt = x

compute_grid_size = mesh_device.get_device(0).compute_with_storage_grid_size()
compute_grid_size = mesh_device.compute_with_storage_grid_size()
num_cores = min(seq_len // 32 * num_heads, 32) # Always use max 32 cores for testing
mesh_shape = ttnn.CoreRangeSet(ttnn.num_cores_to_corerange_set(num_cores, compute_grid_size, True))
input_shard_spec = ttnn.ShardSpec(
Expand Down Expand Up @@ -783,7 +781,7 @@ def test_update_cache_decode(
x_new = torch.cat((x_new, torch.zeros(32 - num_users - batch_offset, num_heads, 1, head_dim)), dim=0)
assert x_new.shape[0] == 32, f"Expected x.shape[0] to be 32, got {x_new.shape[0]}"
xt = x_new.permute(2, 1, 0, 3)
compute_grid_size = mesh_device.get_device(0).compute_with_storage_grid_size()
compute_grid_size = mesh_device.compute_with_storage_grid_size()
num_cores = min(max(num_users, 32) // 32 * num_heads, compute_grid_size.x * compute_grid_size.y)
mesh_shape = ttnn.CoreRangeSet(ttnn.num_cores_to_corerange_set(num_cores, compute_grid_size, True))
input_shard_spec = ttnn.ShardSpec(
Expand Down Expand Up @@ -865,7 +863,7 @@ def run_test_sdpa_decode_single_iter(
sharded_in=False,
sharded_out=False,
):
compute_grid_size = mesh_device.get_device(0).compute_with_storage_grid_size()
compute_grid_size = mesh_device.compute_with_storage_grid_size()
if grid_size[0] > compute_grid_size.x or grid_size[1] > compute_grid_size.y:
pytest.skip(f"Need {grid_size} grid size to run this test but core grid is {compute_grid_size}")

Expand Down
32 changes: 7 additions & 25 deletions tests/ttnn/unit_tests/gtests/test_ccl_on_tg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,17 +55,8 @@ TEST(TGTests, TestAllGatherDeadlock) {
TT_FATAL(num_devices_in_tunnel == 4, "Expected Galaxy to have tunnel depth of 4");
TT_FATAL(num_mmio_devices * cluster_tunnel_count == 8, "Expected 8 tunnels in a Galaxy");

std::vector<chip_id_t> all_device_ids = {};
for (uint32_t mmio_idx = 0; mmio_idx < num_mmio_devices; mmio_idx++) {
auto tunnels_from_mmio = tt::Cluster::instance().get_tunnels_from_mmio_device(mmio_idx);
for (uint32_t tunnel_idx = 0; tunnel_idx < tunnels_from_mmio.size(); tunnel_idx++) {
auto remote_devices_in_tunnel = tunnels_from_mmio.at(tunnel_idx);
all_device_ids.insert(all_device_ids.end(), remote_devices_in_tunnel.begin(), remote_devices_in_tunnel.end());
}
}

// Create the device mesh: Grid size is <num_tunnels, tunnel_depth>.
auto mesh = ttnn::multi_device::open_mesh_device({cluster_tunnel_count * num_mmio_devices, num_devices_in_tunnel}, all_device_ids, 0, 0, 1, DispatchCoreType::WORKER);
auto mesh = ttnn::multi_device::open_mesh_device({cluster_tunnel_count * num_mmio_devices, num_devices_in_tunnel}, 0, 0, 1, DispatchCoreType::WORKER);

// Setup input data and output data containers
MemoryConfig mem_cfg = MemoryConfig{
Expand All @@ -87,7 +78,7 @@ TEST(TGTests, TestAllGatherDeadlock) {
// Iterate over each tunnel and run line all-gather multiple times.
// For each tunnel, send adversarial traffic to the first chip, that can hang the network if the CCL is not tagged.
for (uint32_t row = 0; row < 8; row++) {
auto devs = mesh.get_devices_on_row(row);
auto devs = mesh->get_devices_on_row(row);
std::vector<uint32_t> device_ids = {};
for (auto dev : devs) {
device_ids.push_back(dev->id());
Expand Down Expand Up @@ -146,26 +137,17 @@ TEST(TGTests, TestReduceScatterDeadlock) {
TT_FATAL(num_devices_in_tunnel == 4, "Expected Galaxy to have tunnel depth of 4");
TT_FATAL(num_mmio_devices * cluster_tunnel_count == 8, "Expected 8 tunnels in a Galaxy");

std::vector<chip_id_t> all_device_ids = {};
for (uint32_t mmio_idx = 0; mmio_idx < num_mmio_devices; mmio_idx++) {
auto tunnels_from_mmio = tt::Cluster::instance().get_tunnels_from_mmio_device(mmio_idx);
for (uint32_t tunnel_idx = 0; tunnel_idx < tunnels_from_mmio.size(); tunnel_idx++) {
auto remote_devices_in_tunnel = tunnels_from_mmio.at(tunnel_idx);
all_device_ids.insert(all_device_ids.end(), remote_devices_in_tunnel.begin(), remote_devices_in_tunnel.end());
}
}

// Create the device mesh: Grid size is <num_tunnels, tunnel_depth>.
auto mesh = ttnn::multi_device::open_mesh_device({cluster_tunnel_count * num_mmio_devices, num_devices_in_tunnel}, all_device_ids, 0, 0, 1, DispatchCoreType::WORKER);
auto mesh = ttnn::multi_device::open_mesh_device({cluster_tunnel_count * num_mmio_devices, num_devices_in_tunnel}, 0, 0, 1, DispatchCoreType::WORKER);
// Create the outer ring on which Reduce Scatter will be run. This allows us to verify that there are no deadlocks when we send CCLs to the
// first tunnel (forward path).
std::vector<Device*> ring_devices = mesh.get_devices_on_row(0); // Tunnel 0
std::vector<Device*> ring_devices_1 = mesh.get_devices_on_column(3); // Orthogonal to tunnel .. no deadlocks
std::vector<Device*> ring_devices = mesh->get_devices_on_row(0); // Tunnel 0
std::vector<Device*> ring_devices_1 = mesh->get_devices_on_column(3); // Orthogonal to tunnel .. no deadlocks
ring_devices_1 = std::vector<Device*>(ring_devices_1.begin() + 1, ring_devices_1.end());
std::vector<Device*> ring_devices_2 = mesh.get_devices_on_row(7); // Tunnel 7 .. potential deadlocks with lack of buffering
std::vector<Device*> ring_devices_2 = mesh->get_devices_on_row(7); // Tunnel 7 .. potential deadlocks with lack of buffering
std::reverse(ring_devices_2.begin(), ring_devices_2.end());
ring_devices_2 = std::vector<Device*>(ring_devices_2.begin() + 1, ring_devices_2.end());
std::vector<Device*> ring_devices_3 = mesh.get_devices_on_column(0); // Orthogonal to tunnel .. no deadlocks
std::vector<Device*> ring_devices_3 = mesh->get_devices_on_column(0); // Orthogonal to tunnel .. no deadlocks
std::reverse(ring_devices_3.begin(), ring_devices_3.end());
ring_devices_3 = std::vector<Device*>(ring_devices_3.begin() + 1, ring_devices_3.end() - 1);

Expand Down
13 changes: 7 additions & 6 deletions tests/ttnn/unit_tests/gtests/ttnn_test_fixtures.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,19 +67,20 @@ class T3kMultiDeviceFixture : public ::testing::Test {
if (num_devices < 8 or arch != tt::ARCH::WORMHOLE_B0) {
GTEST_SKIP() << "Skipping T3K Multi-Device test suite on non T3K machine.";
}
const auto T3K_DEVICE_IDS = DeviceIds{0, 4, 5, 1, 2, 6, 7, 3};
constexpr auto DEFAULT_NUM_COMMAND_QUEUES = 1;
mesh_device_ = std::make_unique<MeshDevice>(
MeshShape{1, num_devices},
T3K_DEVICE_IDS,
mesh_device_ = MeshDevice::create(
MeshShape{2, 4},
DEFAULT_L1_SMALL_SIZE,
DEFAULT_TRACE_REGION_SIZE,
DEFAULT_NUM_COMMAND_QUEUES,
DispatchCoreType::WORKER);
}

void TearDown() override { mesh_device_.reset(); }
std::unique_ptr<MeshDevice> mesh_device_;
void TearDown() override {
mesh_device_->close_devices();
mesh_device_.reset();
}
std::shared_ptr<MeshDevice> mesh_device_;
};

} // namespace ttnn::multi_device::test
9 changes: 4 additions & 5 deletions tests/ttnn/unit_tests/test_multi_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@ def test_mesh_device_open_close_explicit(silicon_arch_name, silicon_arch_wormhol
if num_pcie_devices <= 1:
pytest.skip("Requires multiple devices to run")

mesh_shape, device_ids = ttnn.MeshShape(2, 2), ttnn.get_pcie_device_ids()
multi_device = ttnn.open_mesh_device(mesh_shape, device_ids)
multi_device = ttnn.open_mesh_device(ttnn.MeshShape(2, 2))
ttnn.close_mesh_device(multi_device)


Expand All @@ -34,12 +33,12 @@ def test_multi_device_subset_mesh(silicon_arch_name, silicon_arch_wormhole_b0):
if num_pcie_devices <= 1:
pytest.skip("Requires multiple devices to run")

mesh_shape, device_ids = ttnn.MeshShape(1, 2), ttnn.get_pcie_device_ids()
multi_device = ttnn.open_mesh_device(mesh_shape, device_ids)
mesh_shape = ttnn.MeshShape(1, 2)
multi_device = ttnn.open_mesh_device(mesh_shape)
assert multi_device.get_num_devices() == 2
ttnn.close_mesh_device(multi_device)

multi_device = ttnn.open_mesh_device(mesh_shape, device_ids)
multi_device = ttnn.open_mesh_device(mesh_shape)
assert multi_device.get_num_devices() == 2
ttnn.close_mesh_device(multi_device)

Expand Down
5 changes: 5 additions & 0 deletions tt_metal/impl/device/mesh_configurations/N300.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
"logical_to_physical_coordinates": [
[[0, 0], [0, 0, 0, 0]], [[0, 1], [0, 1, 0, 0]]
]
}
6 changes: 6 additions & 0 deletions tt_metal/impl/device/mesh_configurations/T3000.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"logical_to_physical_coordinates": [
[[0, 0], [0, 0, 0, 0]], [[0, 1], [0, 1, 0, 0]], [[0, 2], [0, 2, 0, 0]], [[0, 3], [0, 3, 0, 0]],
[[1, 0], [1, 3, 0, 0]], [[1, 1], [1, 2, 0, 0]], [[1, 2], [1, 1, 0, 0]], [[1, 3], [1, 0, 0, 0]]
]
}
36 changes: 36 additions & 0 deletions tt_metal/impl/device/mesh_configurations/TG.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
{
"logical_to_physical_coordinates": [
[[0, 0], [7, 3, 0, 1]],
[[0, 1], [7, 2, 0, 1]],
[[0, 2], [7, 1, 0, 1]],
[[0, 3], [7, 0, 0, 1]],
[[1, 0], [6, 3, 0, 1]],
[[1, 1], [6, 2, 0, 1]],
[[1, 2], [6, 1, 0, 1]],
[[1, 3], [6, 0, 0, 1]],
[[2, 0], [5, 3, 0, 1]],
[[2, 1], [5, 2, 0, 1]],
[[2, 2], [5, 1, 0, 1]],
[[2, 3], [5, 0, 0, 1]],
[[3, 0], [4, 3, 0, 1]],
[[3, 1], [4, 2, 0, 1]],
[[3, 2], [4, 1, 0, 1]],
[[3, 3], [4, 0, 0, 1]],
[[4, 0], [3, 3, 0, 1]],
[[4, 1], [3, 2, 0, 1]],
[[4, 2], [3, 1, 0, 1]],
[[4, 3], [3, 0, 0, 1]],
[[5, 0], [2, 3, 0, 1]],
[[5, 1], [2, 2, 0, 1]],
[[5, 2], [2, 1, 0, 1]],
[[5, 3], [2, 0, 0, 1]],
[[6, 0], [1, 3, 0, 1]],
[[6, 1], [1, 2, 0, 1]],
[[6, 2], [1, 1, 0, 1]],
[[6, 3], [1, 0, 0, 1]],
[[7, 0], [0, 3, 0, 1]],
[[7, 1], [0, 2, 0, 1]],
[[7, 2], [0, 1, 0, 1]],
[[7, 3], [0, 0, 0, 1]]
]
}
Loading

0 comments on commit 7738225

Please sign in to comment.