Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mesh Virtualization #12719

Merged
merged 2 commits into from
Sep 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 3 additions & 13 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,9 +207,7 @@ def mesh_device(request, silicon_arch_name, silicon_arch_wormhole_b0, device_par

request.node.pci_ids = [ttnn.GetPCIeDeviceID(i) for i in device_ids[:num_devices_requested]]

mesh_device = ttnn.open_mesh_device(
mesh_shape, device_ids[:num_devices_requested], dispatch_core_type=get_dispatch_core_type(), **device_params
)
mesh_device = ttnn.open_mesh_device(mesh_shape, dispatch_core_type=get_dispatch_core_type(), **device_params)

logger.debug(f"multidevice with {mesh_device.get_num_devices()} devices is created")
yield mesh_device
Expand All @@ -235,9 +233,9 @@ def pcie_mesh_device(request, silicon_arch_name, silicon_arch_wormhole_b0, devic

mesh_device = ttnn.open_mesh_device(
ttnn.MeshShape(1, num_pcie_devices_requested),
device_ids[:num_pcie_devices_requested],
dispatch_core_type=get_dispatch_core_type(),
**device_params,
physical_device_ids=device_ids[:num_pcie_devices_requested],
)

logger.debug(f"multidevice with {mesh_device.get_num_devices()} devices is created")
Expand All @@ -256,17 +254,9 @@ def t3k_mesh_device(request, silicon_arch_name, silicon_arch_wormhole_b0, device

if ttnn.get_num_devices() < 8:
pytest.skip()
device_ids = [0, 4, 5, 1, 2, 6, 7, 3]
try:
num_devices_requested = min(request.param, len(device_ids))
except (ValueError, AttributeError):
num_devices_requested = len(device_ids)

request.node.pci_ids = [ttnn.GetPCIeDeviceID(i) for i in device_ids[:num_devices_requested]]

mesh_device = ttnn.open_mesh_device(
ttnn.MeshShape(1, num_devices_requested),
device_ids[:num_devices_requested],
ttnn.MeshShape(2, 4),
dispatch_core_type=get_dispatch_core_type(),
**device_params,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def test_LlamaModel_inference(
if t3k_mesh_device.get_num_devices() < n_devices and not emulated:
pytest.skip(f"Requires at {n_devices} devices to run")

compute_grid_size = t3k_mesh_device.get_device(0).compute_with_storage_grid_size()
compute_grid_size = t3k_mesh_device.compute_with_storage_grid_size()
if compute_grid_size.x < model_config["MAX_GRID_SIZE"][0] or compute_grid_size.y < model_config["MAX_GRID_SIZE"][1]:
pytest.skip(f"Requires grid size of at least {model_config['MAX_GRID_SIZE']} to run")

Expand Down
2 changes: 1 addition & 1 deletion models/demos/t3000/llama2_70b/tests/test_llama_perf.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ def test_Llama_perf_host(
if t3k_mesh_device.get_num_devices() < n_devices and not emulated:
pytest.skip(f"Requires at {n_devices} devices to run")

compute_grid_size = t3k_mesh_device.get_device(0).compute_with_storage_grid_size()
compute_grid_size = t3k_mesh_device.compute_with_storage_grid_size()
if compute_grid_size.x < model_config["MAX_GRID_SIZE"][0] or compute_grid_size.y < model_config["MAX_GRID_SIZE"][1]:
pytest.skip(f"Requires grid size of at least {model_config['MAX_GRID_SIZE']} to run")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def test_Llama_stress_test(
if t3k_mesh_device.get_num_devices() < n_devices and not emulated:
pytest.skip(f"Requires at {n_devices} devices to run")

compute_grid_size = t3k_mesh_device.get_device(0).compute_with_storage_grid_size()
compute_grid_size = t3k_mesh_device.compute_with_storage_grid_size()
if compute_grid_size.x < model_config["MAX_GRID_SIZE"][0] or compute_grid_size.y < model_config["MAX_GRID_SIZE"][1]:
pytest.skip(f"Requires grid size of at least {model_config['MAX_GRID_SIZE']} to run")

Expand Down
2 changes: 1 addition & 1 deletion models/demos/t3000/llama2_70b/tt/llama_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def check_mesh_device(t3k_mesh_device, model_config):
model_config["NUM_DEVICES"],
)

compute_grid_size = t3k_mesh_device.get_device(0).compute_with_storage_grid_size()
compute_grid_size = t3k_mesh_device.compute_with_storage_grid_size()
assert not (
compute_grid_size.x < model_config["MAX_GRID_SIZE"][0] or compute_grid_size.y < model_config["MAX_GRID_SIZE"][1]
), ("Requires grid size of at least %d to run", model_config["MAX_GRID_SIZE"])
Expand Down
2 changes: 1 addition & 1 deletion models/demos/tg/llama3_70b/tt/llama_attention_galaxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,7 +479,7 @@ def attn_mqa(
value_layer.deallocate(True)

program_config = ttnn.SDPAProgramConfig(
compute_with_storage_grid_size=self.mesh_device.get_device(0).compute_with_storage_grid_size(),
compute_with_storage_grid_size=self.mesh_device.compute_with_storage_grid_size(),
q_chunk_size=0, # unused
k_chunk_size=0, # unused
)
Expand Down
4 changes: 2 additions & 2 deletions models/demos/tg/llama3_70b/tt/llama_mlp_galaxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ def get_mlp_model_config(self):
ttnn.CoreRange(
ttnn.CoreCoord(0, 0),
ttnn.CoreCoord(
self.mesh_device.get_device(0).dram_grid_size().x - 1,
self.mesh_device.get_device(0).dram_grid_size().y - 1,
self.mesh_device.dram_grid_size().x - 1,
self.mesh_device.dram_grid_size().y - 1,
),
)
}
Expand Down
7 changes: 4 additions & 3 deletions models/utility_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -971,9 +971,10 @@ def get_devices_for_t3000(all_devices, num_devices):
if num_devices <= 4:
return all_devices[:num_devices]
elif num_devices == 8:
# TODO: Generalize this for different arch
hamiltonian_ring_indices = [0, 4, 5, 1, 2, 6, 7, 3]
return [all_devices[i] for i in hamiltonian_ring_indices]
# Temporary until we move request for ring order to CCL operations directly.
# This is better because we no longer need to manually manage the ring order.
ring_indices = ttnn.get_t3k_physical_device_ids_ring()
return [all_devices[i] for i in ring_indices]
else:
raise NotImplementedError("Only supports 1, 2, 3, 4, and 8 chip configurations!")

Expand Down
1 change: 1 addition & 0 deletions tests/scripts/tgg/run_tgg_unit_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ run_tgg_tests() {

TT_METAL_SLOW_DISPATCH_MODE=1 ./build/test/tt_metal/unit_tests_galaxy --gtest_filter="GalaxyFixture.*:TGGFixture.*"
./build/test/tt_metal/unit_tests_galaxy --gtest_filter="GalaxyFixture.*:TGGFixture.*"
pytest -s tests/ttnn/multichip_unit_tests/test_mesh_device_TGG.py
}

main() {
Expand Down
4 changes: 2 additions & 2 deletions tests/sweep_framework/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -190,11 +190,11 @@ def mesh_device_fixture():

assert ttnn.get_num_devices() >= 8, "Not T3000!"

device_ids = [0, 4, 5, 1, 2, 6, 7, 3]
device_ids = ttnn.get_t3k_physical_device_ids_ring()
num_devices_requested = len(device_ids)

mesh_device = ttnn.open_mesh_device(
ttnn.MeshShape(1, num_devices_requested), device_ids[:num_devices_requested]
ttnn.MeshShape(1, num_devices_requested),
)

print("ADD: Opened device mesh")
Expand Down
3 changes: 1 addition & 2 deletions tests/sweep_framework/sweeps/line_all_gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,8 @@ def invalidate_vector(test_vector) -> Tuple[bool, Optional[str]]:


def mesh_device_fixture():

assert ttnn.get_num_devices() >= 8, "Not T3000!"
device_ids = [0, 4, 5, 1, 2, 6, 7, 3]
device_ids = ttnn.get_t3k_physical_device_ids_ring()
num_devices_requested = len(device_ids)
mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(1, num_devices_requested), device_ids[:num_devices_requested])
print("ALL GATHER: Opened device mesh")
Expand Down
12 changes: 12 additions & 0 deletions tests/ttnn/multichip_unit_tests/test_mesh_device_TGG.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import torch
import pytest
import ttnn


@pytest.mark.parametrize("mesh_device", [pytest.param((8, 8), id="8x8_grid")], indirect=True)
def test_visualize_mesh_device(mesh_device):
ttnn.visualize_mesh_device(mesh_device)
10 changes: 4 additions & 6 deletions tests/ttnn/multichip_unit_tests/test_multidevice_TG.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,9 +210,7 @@ def test_galaxy_matmul_2d_fracture_dram_sharded(M, K, N, weights_dtype, mesh_sha
{
ttnn.CoreRange(
ttnn.CoreCoord(0, 0),
ttnn.CoreCoord(
mesh_device.get_device(0).dram_grid_size().x - 1, mesh_device.get_device(0).dram_grid_size().y - 1
),
ttnn.CoreCoord(mesh_device.dram_grid_size().x - 1, mesh_device.dram_grid_size().y - 1),
)
}
)
Expand Down Expand Up @@ -711,7 +709,7 @@ def test_fill_cache(

xt = x

compute_grid_size = mesh_device.get_device(0).compute_with_storage_grid_size()
compute_grid_size = mesh_device.compute_with_storage_grid_size()
num_cores = min(seq_len // 32 * num_heads, 32) # Always use max 32 cores for testing
mesh_shape = ttnn.CoreRangeSet(ttnn.num_cores_to_corerange_set(num_cores, compute_grid_size, True))
input_shard_spec = ttnn.ShardSpec(
Expand Down Expand Up @@ -783,7 +781,7 @@ def test_update_cache_decode(
x_new = torch.cat((x_new, torch.zeros(32 - num_users - batch_offset, num_heads, 1, head_dim)), dim=0)
assert x_new.shape[0] == 32, f"Expected x.shape[0] to be 32, got {x_new.shape[0]}"
xt = x_new.permute(2, 1, 0, 3)
compute_grid_size = mesh_device.get_device(0).compute_with_storage_grid_size()
compute_grid_size = mesh_device.compute_with_storage_grid_size()
num_cores = min(max(num_users, 32) // 32 * num_heads, compute_grid_size.x * compute_grid_size.y)
mesh_shape = ttnn.CoreRangeSet(ttnn.num_cores_to_corerange_set(num_cores, compute_grid_size, True))
input_shard_spec = ttnn.ShardSpec(
Expand Down Expand Up @@ -865,7 +863,7 @@ def run_test_sdpa_decode_single_iter(
sharded_in=False,
sharded_out=False,
):
compute_grid_size = mesh_device.get_device(0).compute_with_storage_grid_size()
compute_grid_size = mesh_device.compute_with_storage_grid_size()
if grid_size[0] > compute_grid_size.x or grid_size[1] > compute_grid_size.y:
pytest.skip(f"Need {grid_size} grid size to run this test but core grid is {compute_grid_size}")

Expand Down
32 changes: 7 additions & 25 deletions tests/ttnn/unit_tests/gtests/test_ccl_on_tg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,17 +55,8 @@ TEST(TGTests, TestAllGatherDeadlock) {
TT_FATAL(num_devices_in_tunnel == 4, "Expected Galaxy to have tunnel depth of 4");
TT_FATAL(num_mmio_devices * cluster_tunnel_count == 8, "Expected 8 tunnels in a Galaxy");

std::vector<chip_id_t> all_device_ids = {};
for (uint32_t mmio_idx = 0; mmio_idx < num_mmio_devices; mmio_idx++) {
auto tunnels_from_mmio = tt::Cluster::instance().get_tunnels_from_mmio_device(mmio_idx);
for (uint32_t tunnel_idx = 0; tunnel_idx < tunnels_from_mmio.size(); tunnel_idx++) {
auto remote_devices_in_tunnel = tunnels_from_mmio.at(tunnel_idx);
all_device_ids.insert(all_device_ids.end(), remote_devices_in_tunnel.begin(), remote_devices_in_tunnel.end());
}
}

// Create the device mesh: Grid size is <num_tunnels, tunnel_depth>.
auto mesh = ttnn::multi_device::open_mesh_device({cluster_tunnel_count * num_mmio_devices, num_devices_in_tunnel}, all_device_ids, 0, 0, 1, DispatchCoreType::WORKER);
auto mesh = ttnn::multi_device::open_mesh_device({cluster_tunnel_count * num_mmio_devices, num_devices_in_tunnel}, 0, 0, 1, DispatchCoreType::WORKER);

// Setup input data and output data containers
MemoryConfig mem_cfg = MemoryConfig{
Expand All @@ -87,7 +78,7 @@ TEST(TGTests, TestAllGatherDeadlock) {
// Iterate over each tunnel and run line all-gather multiple times.
// For each tunnel, send adversarial traffic to the first chip, that can hang the network if the CCL is not tagged.
for (uint32_t row = 0; row < 8; row++) {
auto devs = mesh.get_devices_on_row(row);
auto devs = mesh->get_devices_on_row(row);
std::vector<uint32_t> device_ids = {};
for (auto dev : devs) {
device_ids.push_back(dev->id());
Expand Down Expand Up @@ -146,26 +137,17 @@ TEST(TGTests, TestReduceScatterDeadlock) {
TT_FATAL(num_devices_in_tunnel == 4, "Expected Galaxy to have tunnel depth of 4");
TT_FATAL(num_mmio_devices * cluster_tunnel_count == 8, "Expected 8 tunnels in a Galaxy");

std::vector<chip_id_t> all_device_ids = {};
for (uint32_t mmio_idx = 0; mmio_idx < num_mmio_devices; mmio_idx++) {
auto tunnels_from_mmio = tt::Cluster::instance().get_tunnels_from_mmio_device(mmio_idx);
for (uint32_t tunnel_idx = 0; tunnel_idx < tunnels_from_mmio.size(); tunnel_idx++) {
auto remote_devices_in_tunnel = tunnels_from_mmio.at(tunnel_idx);
all_device_ids.insert(all_device_ids.end(), remote_devices_in_tunnel.begin(), remote_devices_in_tunnel.end());
}
}

// Create the device mesh: Grid size is <num_tunnels, tunnel_depth>.
auto mesh = ttnn::multi_device::open_mesh_device({cluster_tunnel_count * num_mmio_devices, num_devices_in_tunnel}, all_device_ids, 0, 0, 1, DispatchCoreType::WORKER);
auto mesh = ttnn::multi_device::open_mesh_device({cluster_tunnel_count * num_mmio_devices, num_devices_in_tunnel}, 0, 0, 1, DispatchCoreType::WORKER);
// Create the outer ring on which Reduce Scatter will be run. This allows us to verify that there are no deadlocks when we send CCLs to the
// first tunnel (forward path).
std::vector<Device*> ring_devices = mesh.get_devices_on_row(0); // Tunnel 0
std::vector<Device*> ring_devices_1 = mesh.get_devices_on_column(3); // Orthogonal to tunnel .. no deadlocks
std::vector<Device*> ring_devices = mesh->get_devices_on_row(0); // Tunnel 0
std::vector<Device*> ring_devices_1 = mesh->get_devices_on_column(3); // Orthogonal to tunnel .. no deadlocks
ring_devices_1 = std::vector<Device*>(ring_devices_1.begin() + 1, ring_devices_1.end());
std::vector<Device*> ring_devices_2 = mesh.get_devices_on_row(7); // Tunnel 7 .. potential deadlocks with lack of buffering
std::vector<Device*> ring_devices_2 = mesh->get_devices_on_row(7); // Tunnel 7 .. potential deadlocks with lack of buffering
std::reverse(ring_devices_2.begin(), ring_devices_2.end());
ring_devices_2 = std::vector<Device*>(ring_devices_2.begin() + 1, ring_devices_2.end());
std::vector<Device*> ring_devices_3 = mesh.get_devices_on_column(0); // Orthogonal to tunnel .. no deadlocks
std::vector<Device*> ring_devices_3 = mesh->get_devices_on_column(0); // Orthogonal to tunnel .. no deadlocks
std::reverse(ring_devices_3.begin(), ring_devices_3.end());
ring_devices_3 = std::vector<Device*>(ring_devices_3.begin() + 1, ring_devices_3.end() - 1);

Expand Down
13 changes: 7 additions & 6 deletions tests/ttnn/unit_tests/gtests/ttnn_test_fixtures.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,19 +67,20 @@ class T3kMultiDeviceFixture : public ::testing::Test {
if (num_devices < 8 or arch != tt::ARCH::WORMHOLE_B0) {
GTEST_SKIP() << "Skipping T3K Multi-Device test suite on non T3K machine.";
}
const auto T3K_DEVICE_IDS = DeviceIds{0, 4, 5, 1, 2, 6, 7, 3};
constexpr auto DEFAULT_NUM_COMMAND_QUEUES = 1;
mesh_device_ = std::make_unique<MeshDevice>(
MeshShape{1, num_devices},
T3K_DEVICE_IDS,
mesh_device_ = MeshDevice::create(
MeshShape{2, 4},
DEFAULT_L1_SMALL_SIZE,
DEFAULT_TRACE_REGION_SIZE,
DEFAULT_NUM_COMMAND_QUEUES,
DispatchCoreType::WORKER);
}

void TearDown() override { mesh_device_.reset(); }
std::unique_ptr<MeshDevice> mesh_device_;
void TearDown() override {
mesh_device_->close_devices();
mesh_device_.reset();
}
std::shared_ptr<MeshDevice> mesh_device_;
};

} // namespace ttnn::multi_device::test
9 changes: 4 additions & 5 deletions tests/ttnn/unit_tests/test_multi_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@ def test_mesh_device_open_close_explicit(silicon_arch_name, silicon_arch_wormhol
if num_pcie_devices <= 1:
pytest.skip("Requires multiple devices to run")

mesh_shape, device_ids = ttnn.MeshShape(2, 2), ttnn.get_pcie_device_ids()
multi_device = ttnn.open_mesh_device(mesh_shape, device_ids)
multi_device = ttnn.open_mesh_device(ttnn.MeshShape(2, 2))
ttnn.close_mesh_device(multi_device)


Expand All @@ -34,12 +33,12 @@ def test_multi_device_subset_mesh(silicon_arch_name, silicon_arch_wormhole_b0):
if num_pcie_devices <= 1:
pytest.skip("Requires multiple devices to run")

mesh_shape, device_ids = ttnn.MeshShape(1, 2), ttnn.get_pcie_device_ids()
multi_device = ttnn.open_mesh_device(mesh_shape, device_ids)
mesh_shape = ttnn.MeshShape(1, 2)
multi_device = ttnn.open_mesh_device(mesh_shape)
assert multi_device.get_num_devices() == 2
ttnn.close_mesh_device(multi_device)

multi_device = ttnn.open_mesh_device(mesh_shape, device_ids)
multi_device = ttnn.open_mesh_device(mesh_shape)
assert multi_device.get_num_devices() == 2
ttnn.close_mesh_device(multi_device)

Expand Down
5 changes: 5 additions & 0 deletions tt_metal/impl/device/mesh_configurations/N300.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
"logical_to_physical_coordinates": [
[[0, 0], [0, 0, 0, 0]], [[0, 1], [0, 1, 0, 0]]
]
}
6 changes: 6 additions & 0 deletions tt_metal/impl/device/mesh_configurations/T3000.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"logical_to_physical_coordinates": [
[[0, 0], [0, 0, 0, 0]], [[0, 1], [0, 1, 0, 0]], [[0, 2], [0, 2, 0, 0]], [[0, 3], [0, 3, 0, 0]],
[[1, 0], [1, 3, 0, 0]], [[1, 1], [1, 2, 0, 0]], [[1, 2], [1, 1, 0, 0]], [[1, 3], [1, 0, 0, 0]]
]
}
36 changes: 36 additions & 0 deletions tt_metal/impl/device/mesh_configurations/TG.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
{
"logical_to_physical_coordinates": [
[[0, 0], [7, 3, 0, 1]],
[[0, 1], [7, 2, 0, 1]],
[[0, 2], [7, 1, 0, 1]],
[[0, 3], [7, 0, 0, 1]],
[[1, 0], [6, 3, 0, 1]],
[[1, 1], [6, 2, 0, 1]],
[[1, 2], [6, 1, 0, 1]],
[[1, 3], [6, 0, 0, 1]],
[[2, 0], [5, 3, 0, 1]],
[[2, 1], [5, 2, 0, 1]],
[[2, 2], [5, 1, 0, 1]],
[[2, 3], [5, 0, 0, 1]],
[[3, 0], [4, 3, 0, 1]],
[[3, 1], [4, 2, 0, 1]],
[[3, 2], [4, 1, 0, 1]],
[[3, 3], [4, 0, 0, 1]],
[[4, 0], [3, 3, 0, 1]],
[[4, 1], [3, 2, 0, 1]],
[[4, 2], [3, 1, 0, 1]],
[[4, 3], [3, 0, 0, 1]],
[[5, 0], [2, 3, 0, 1]],
[[5, 1], [2, 2, 0, 1]],
[[5, 2], [2, 1, 0, 1]],
[[5, 3], [2, 0, 0, 1]],
[[6, 0], [1, 3, 0, 1]],
[[6, 1], [1, 2, 0, 1]],
[[6, 2], [1, 1, 0, 1]],
[[6, 3], [1, 0, 0, 1]],
[[7, 0], [0, 3, 0, 1]],
[[7, 1], [0, 2, 0, 1]],
[[7, 2], [0, 1, 0, 1]],
[[7, 3], [0, 0, 0, 1]]
]
}
Loading
Loading