
Revert "Mesh Virtualization (#12719)"
This reverts commit 7738225.
tt-aho committed Sep 23, 2024
1 parent 34675ee commit 742a37d
Showing 31 changed files with 234 additions and 595 deletions.
16 changes: 13 additions & 3 deletions conftest.py
@@ -207,7 +207,9 @@ def mesh_device(request, silicon_arch_name, silicon_arch_wormhole_b0, device_par

request.node.pci_ids = [ttnn.GetPCIeDeviceID(i) for i in device_ids[:num_devices_requested]]

mesh_device = ttnn.open_mesh_device(mesh_shape, dispatch_core_type=get_dispatch_core_type(), **device_params)
mesh_device = ttnn.open_mesh_device(
mesh_shape, device_ids[:num_devices_requested], dispatch_core_type=get_dispatch_core_type(), **device_params
)

logger.debug(f"multidevice with {mesh_device.get_num_devices()} devices is created")
yield mesh_device
@@ -233,9 +235,9 @@ def pcie_mesh_device(request, silicon_arch_name, silicon_arch_wormhole_b0, devic

mesh_device = ttnn.open_mesh_device(
ttnn.MeshShape(1, num_pcie_devices_requested),
device_ids[:num_pcie_devices_requested],
dispatch_core_type=get_dispatch_core_type(),
**device_params,
physical_device_ids=device_ids[:num_pcie_devices_requested],
)

logger.debug(f"multidevice with {mesh_device.get_num_devices()} devices is created")
@@ -254,9 +256,17 @@ def t3k_mesh_device(request, silicon_arch_name, silicon_arch_wormhole_b0, device

if ttnn.get_num_devices() < 8:
pytest.skip()
device_ids = [0, 4, 5, 1, 2, 6, 7, 3]
try:
num_devices_requested = min(request.param, len(device_ids))
except (ValueError, AttributeError):
num_devices_requested = len(device_ids)

request.node.pci_ids = [ttnn.GetPCIeDeviceID(i) for i in device_ids[:num_devices_requested]]

mesh_device = ttnn.open_mesh_device(
ttnn.MeshShape(2, 4),
ttnn.MeshShape(1, num_devices_requested),
device_ids[:num_devices_requested],
dispatch_core_type=get_dispatch_core_type(),
**device_params,
)
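For context, this is the calling convention the fixtures return to after the revert: physical device IDs are passed to `ttnn.open_mesh_device` explicitly instead of being resolved from a mesh-configuration JSON. A minimal sketch, assuming a T3000 box and reusing the ring order and 1x8 shape hardcoded in the conftest hunk above (error handling is illustrative only):

```python
import ttnn

# T3000 Hamiltonian ring order, as hardcoded in the reverted conftest.py above.
device_ids = [0, 4, 5, 1, 2, 6, 7, 3]

if ttnn.get_num_devices() >= 8:
    mesh_device = ttnn.open_mesh_device(
        ttnn.MeshShape(1, len(device_ids)),  # 1x8 line mesh rather than the virtualized 2x4 grid
        device_ids,                          # physical device IDs are explicit again
    )
    try:
        print(f"opened mesh with {mesh_device.get_num_devices()} devices")
    finally:
        ttnn.close_mesh_device(mesh_device)
```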
@@ -147,7 +147,7 @@ def test_LlamaModel_inference(
if t3k_mesh_device.get_num_devices() < n_devices:
pytest.skip(f"Requires at {n_devices} devices to run")

compute_grid_size = t3k_mesh_device.compute_with_storage_grid_size()
compute_grid_size = t3k_mesh_device.get_device(0).compute_with_storage_grid_size()
if compute_grid_size.x < model_config["MAX_GRID_SIZE"][0] or compute_grid_size.y < model_config["MAX_GRID_SIZE"][1]:
pytest.skip(f"Requires grid size of at least {model_config['MAX_GRID_SIZE']} to run")

2 changes: 1 addition & 1 deletion models/demos/t3000/llama2_70b/tests/test_llama_perf.py
@@ -310,7 +310,7 @@ def test_Llama_perf_host(
if t3k_mesh_device.get_num_devices() < n_devices and not emulated:
pytest.skip(f"Requires at {n_devices} devices to run")

compute_grid_size = t3k_mesh_device.compute_with_storage_grid_size()
compute_grid_size = t3k_mesh_device.get_device(0).compute_with_storage_grid_size()
if compute_grid_size.x < model_config["MAX_GRID_SIZE"][0] or compute_grid_size.y < model_config["MAX_GRID_SIZE"][1]:
pytest.skip(f"Requires grid size of at least {model_config['MAX_GRID_SIZE']} to run")

@@ -143,7 +143,7 @@ def test_Llama_stress_test(
if t3k_mesh_device.get_num_devices() < n_devices and not emulated:
pytest.skip(f"Requires at {n_devices} devices to run")

compute_grid_size = t3k_mesh_device.compute_with_storage_grid_size()
compute_grid_size = t3k_mesh_device.get_device(0).compute_with_storage_grid_size()
if compute_grid_size.x < model_config["MAX_GRID_SIZE"][0] or compute_grid_size.y < model_config["MAX_GRID_SIZE"][1]:
pytest.skip(f"Requires grid size of at least {model_config['MAX_GRID_SIZE']} to run")

2 changes: 1 addition & 1 deletion models/demos/t3000/llama2_70b/tt/llama_common.py
@@ -189,7 +189,7 @@ def check_mesh_device(t3k_mesh_device, model_config):
model_config["NUM_DEVICES"],
)

compute_grid_size = t3k_mesh_device.compute_with_storage_grid_size()
compute_grid_size = t3k_mesh_device.get_device(0).compute_with_storage_grid_size()
assert not (
compute_grid_size.x < model_config["MAX_GRID_SIZE"][0] or compute_grid_size.y < model_config["MAX_GRID_SIZE"][1]
), ("Requires grid size of at least %d to run", model_config["MAX_GRID_SIZE"])
2 changes: 1 addition & 1 deletion models/demos/tg/llama3_70b/tt/llama_attention_galaxy.py
@@ -472,7 +472,7 @@ def attn_mqa(
value_layer.deallocate(True)

program_config = ttnn.SDPAProgramConfig(
compute_with_storage_grid_size=self.mesh_device.compute_with_storage_grid_size(),
compute_with_storage_grid_size=self.mesh_device.get_device(0).compute_with_storage_grid_size(),
q_chunk_size=0, # unused
k_chunk_size=0, # unused
)
4 changes: 2 additions & 2 deletions models/demos/tg/llama3_70b/tt/llama_mlp_galaxy.py
@@ -54,8 +54,8 @@ def get_mlp_model_config(self, mode):
ttnn.CoreRange(
ttnn.CoreCoord(0, 0),
ttnn.CoreCoord(
self.mesh_device.dram_grid_size().x - 1,
self.mesh_device.dram_grid_size().y - 1,
self.mesh_device.get_device(0).dram_grid_size().x - 1,
self.mesh_device.get_device(0).dram_grid_size().y - 1,
),
)
}
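The DRAM-sharded call sites follow the same convention. A hedged sketch of the reverted construction, mirroring the hunk above (variable names are illustrative):

```python
# Build a core range spanning the full DRAM grid of one device in the mesh.
dram_grid = mesh_device.get_device(0).dram_grid_size()
full_dram_range = ttnn.CoreRange(
    ttnn.CoreCoord(0, 0),
    ttnn.CoreCoord(dram_grid.x - 1, dram_grid.y - 1),
)
```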
7 changes: 3 additions & 4 deletions models/utility_functions.py
@@ -971,10 +971,9 @@ def get_devices_for_t3000(all_devices, num_devices):
if num_devices <= 4:
return all_devices[:num_devices]
elif num_devices == 8:
# Temporary until we move request for ring order to CCL operations directly.
# This is better because we no longer need to manually manage the ring order.
ring_indices = ttnn.get_t3k_physical_device_ids_ring()
return [all_devices[i] for i in ring_indices]
# TODO: Generalize this for different arch
hamiltonian_ring_indices = [0, 4, 5, 1, 2, 6, 7, 3]
return [all_devices[i] for i in hamiltonian_ring_indices]
else:
raise NotImplementedError("Only supports 1, 2, 3, 4, and 8 chip configurations!")
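A usage sketch of the restored helper; the `all_devices` list is hypothetical and assumed to hold the eight opened T3000 device handles in ID order:

```python
# Reorder the opened devices into the hardcoded Hamiltonian ring for CCL ops.
ring_ordered = get_devices_for_t3000(all_devices, num_devices=8)
# Equivalent to: [all_devices[i] for i in [0, 4, 5, 1, 2, 6, 7, 3]]
```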

1 change: 0 additions & 1 deletion tests/scripts/tgg/run_tgg_unit_tests.sh
@@ -7,7 +7,6 @@ run_tgg_tests() {

TT_METAL_SLOW_DISPATCH_MODE=1 ./build/test/tt_metal/unit_tests_galaxy --gtest_filter="GalaxyFixture.*:TGGFixture.*"
./build/test/tt_metal/unit_tests_galaxy --gtest_filter="GalaxyFixture.*:TGGFixture.*"
pytest -s tests/ttnn/multichip_unit_tests/test_mesh_device_TGG.py
}

main() {
4 changes: 2 additions & 2 deletions tests/sweep_framework/README.md
@@ -190,11 +190,11 @@ def mesh_device_fixture():
assert ttnn.get_num_devices() >= 8, "Not T3000!"
device_ids = ttnn.get_t3k_physical_device_ids_ring()
device_ids = [0, 4, 5, 1, 2, 6, 7, 3]
num_devices_requested = len(device_ids)
mesh_device = ttnn.open_mesh_device(
ttnn.MeshShape(1, num_devices_requested),
ttnn.MeshShape(1, num_devices_requested), device_ids[:num_devices_requested]
)
print("ADD: Opened device mesh")
3 changes: 2 additions & 1 deletion tests/sweep_framework/sweeps/line_all_gather.py
@@ -62,8 +62,9 @@ def invalidate_vector(test_vector) -> Tuple[bool, Optional[str]]:


def mesh_device_fixture():

assert ttnn.get_num_devices() >= 8, "Not T3000!"
device_ids = ttnn.get_t3k_physical_device_ids_ring()
device_ids = [0, 4, 5, 1, 2, 6, 7, 3]
num_devices_requested = len(device_ids)
mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(1, num_devices_requested), device_ids[:num_devices_requested])
print("ALL GATHER: Opened device mesh")
12 changes: 0 additions & 12 deletions tests/ttnn/multichip_unit_tests/test_mesh_device_TGG.py

This file was deleted.

10 changes: 6 additions & 4 deletions tests/ttnn/multichip_unit_tests/test_multidevice_TG.py
@@ -224,7 +224,9 @@ def test_galaxy_matmul_2d_fracture_dram_sharded(M, K, N, weights_dtype, mesh_sha
{
ttnn.CoreRange(
ttnn.CoreCoord(0, 0),
ttnn.CoreCoord(mesh_device.dram_grid_size().x - 1, mesh_device.dram_grid_size().y - 1),
ttnn.CoreCoord(
mesh_device.get_device(0).dram_grid_size().x - 1, mesh_device.get_device(0).dram_grid_size().y - 1
),
)
}
)
@@ -735,7 +737,7 @@ def test_fill_cache(

xt = x

compute_grid_size = mesh_device.compute_with_storage_grid_size()
compute_grid_size = mesh_device.get_device(0).compute_with_storage_grid_size()
num_cores = min(seq_len // 32 * num_heads, 32) # Always use max 32 cores for testing
mesh_shape = ttnn.CoreRangeSet(ttnn.num_cores_to_corerange_set(num_cores, compute_grid_size, True))
input_shard_spec = ttnn.ShardSpec(
@@ -809,7 +811,7 @@ def test_update_cache_decode(
x_new = torch.cat((x_new, torch.zeros(32 - num_users - batch_offset, num_heads, 1, head_dim)), dim=0)
assert x_new.shape[0] == 32, f"Expected x.shape[0] to be 32, got {x_new.shape[0]}"
xt = x_new.permute(2, 1, 0, 3)
compute_grid_size = mesh_device.compute_with_storage_grid_size()
compute_grid_size = mesh_device.get_device(0).compute_with_storage_grid_size()
num_cores = min(max(num_users, 32) // 32 * num_heads, compute_grid_size.x * compute_grid_size.y)
mesh_shape = ttnn.CoreRangeSet(ttnn.num_cores_to_corerange_set(num_cores, compute_grid_size, True))
input_shard_spec = ttnn.ShardSpec(
@@ -891,7 +893,7 @@ def run_test_sdpa_decode_single_iter(
sharded_in=False,
sharded_out=False,
):
compute_grid_size = mesh_device.compute_with_storage_grid_size()
compute_grid_size = mesh_device.get_device(0).compute_with_storage_grid_size()
if grid_size[0] > compute_grid_size.x or grid_size[1] > compute_grid_size.y:
pytest.skip(f"Need {grid_size} grid size to run this test but core grid is {compute_grid_size}")

32 changes: 25 additions & 7 deletions tests/ttnn/unit_tests/gtests/test_ccl_on_tg.cpp
@@ -55,8 +55,17 @@ TEST(TGTests, TestAllGatherDeadlock) {
TT_FATAL(num_devices_in_tunnel == 4, "Expected Galaxy to have tunnel depth of 4");
TT_FATAL(num_mmio_devices * cluster_tunnel_count == 8, "Expected 8 tunnels in a Galaxy");

std::vector<chip_id_t> all_device_ids = {};
for (uint32_t mmio_idx = 0; mmio_idx < num_mmio_devices; mmio_idx++) {
auto tunnels_from_mmio = tt::Cluster::instance().get_tunnels_from_mmio_device(mmio_idx);
for (uint32_t tunnel_idx = 0; tunnel_idx < tunnels_from_mmio.size(); tunnel_idx++) {
auto remote_devices_in_tunnel = tunnels_from_mmio.at(tunnel_idx);
all_device_ids.insert(all_device_ids.end(), remote_devices_in_tunnel.begin(), remote_devices_in_tunnel.end());
}
}

// Create the device mesh: Grid size is <num_tunnels, tunnel_depth>.
auto mesh = ttnn::multi_device::open_mesh_device({cluster_tunnel_count * num_mmio_devices, num_devices_in_tunnel}, 0, 0, 1, DispatchCoreType::WORKER);
auto mesh = ttnn::multi_device::open_mesh_device({cluster_tunnel_count * num_mmio_devices, num_devices_in_tunnel}, all_device_ids, 0, 0, 1, DispatchCoreType::WORKER);

// Setup input data and output data containers
MemoryConfig mem_cfg = MemoryConfig{
@@ -78,7 +87,7 @@ TEST(TGTests, TestAllGatherDeadlock) {
// Iterate over each tunnel and run line all-gather multiple times.
// For each tunnel, send adversarial traffic to the first chip, that can hang the network if the CCL is not tagged.
for (uint32_t row = 0; row < 8; row++) {
auto devs = mesh->get_devices_on_row(row);
auto devs = mesh.get_devices_on_row(row);
std::vector<uint32_t> device_ids = {};
for (auto dev : devs) {
device_ids.push_back(dev->id());
@@ -137,17 +146,26 @@ TEST(TGTests, TestReduceScatterDeadlock) {
TT_FATAL(num_devices_in_tunnel == 4, "Expected Galaxy to have tunnel depth of 4");
TT_FATAL(num_mmio_devices * cluster_tunnel_count == 8, "Expected 8 tunnels in a Galaxy");

std::vector<chip_id_t> all_device_ids = {};
for (uint32_t mmio_idx = 0; mmio_idx < num_mmio_devices; mmio_idx++) {
auto tunnels_from_mmio = tt::Cluster::instance().get_tunnels_from_mmio_device(mmio_idx);
for (uint32_t tunnel_idx = 0; tunnel_idx < tunnels_from_mmio.size(); tunnel_idx++) {
auto remote_devices_in_tunnel = tunnels_from_mmio.at(tunnel_idx);
all_device_ids.insert(all_device_ids.end(), remote_devices_in_tunnel.begin(), remote_devices_in_tunnel.end());
}
}

// Create the device mesh: Grid size is <num_tunnels, tunnel_depth>.
auto mesh = ttnn::multi_device::open_mesh_device({cluster_tunnel_count * num_mmio_devices, num_devices_in_tunnel}, 0, 0, 1, DispatchCoreType::WORKER);
auto mesh = ttnn::multi_device::open_mesh_device({cluster_tunnel_count * num_mmio_devices, num_devices_in_tunnel}, all_device_ids, 0, 0, 1, DispatchCoreType::WORKER);
// Create the outer ring on which Reduce Scatter will be run. This allows us to verify that there are no deadlocks when we send CCLs to the
// first tunnel (forward path).
std::vector<Device*> ring_devices = mesh->get_devices_on_row(0); // Tunnel 0
std::vector<Device*> ring_devices_1 = mesh->get_devices_on_column(3); // Orthogonal to tunnel .. no deadlocks
std::vector<Device*> ring_devices = mesh.get_devices_on_row(0); // Tunnel 0
std::vector<Device*> ring_devices_1 = mesh.get_devices_on_column(3); // Orthogonal to tunnel .. no deadlocks
ring_devices_1 = std::vector<Device*>(ring_devices_1.begin() + 1, ring_devices_1.end());
std::vector<Device*> ring_devices_2 = mesh->get_devices_on_row(7); // Tunnel 7 .. potential deadlocks with lack of buffering
std::vector<Device*> ring_devices_2 = mesh.get_devices_on_row(7); // Tunnel 7 .. potential deadlocks with lack of buffering
std::reverse(ring_devices_2.begin(), ring_devices_2.end());
ring_devices_2 = std::vector<Device*>(ring_devices_2.begin() + 1, ring_devices_2.end());
std::vector<Device*> ring_devices_3 = mesh->get_devices_on_column(0); // Orthogonal to tunnel .. no deadlocks
std::vector<Device*> ring_devices_3 = mesh.get_devices_on_column(0); // Orthogonal to tunnel .. no deadlocks
std::reverse(ring_devices_3.begin(), ring_devices_3.end());
ring_devices_3 = std::vector<Device*>(ring_devices_3.begin() + 1, ring_devices_3.end() - 1);

13 changes: 6 additions & 7 deletions tests/ttnn/unit_tests/gtests/ttnn_test_fixtures.hpp
@@ -67,20 +67,19 @@ class T3kMultiDeviceFixture : public ::testing::Test {
if (num_devices < 8 or arch != tt::ARCH::WORMHOLE_B0) {
GTEST_SKIP() << "Skipping T3K Multi-Device test suite on non T3K machine.";
}
const auto T3K_DEVICE_IDS = DeviceIds{0, 4, 5, 1, 2, 6, 7, 3};
constexpr auto DEFAULT_NUM_COMMAND_QUEUES = 1;
mesh_device_ = MeshDevice::create(
MeshShape{2, 4},
mesh_device_ = std::make_unique<MeshDevice>(
MeshShape{1, num_devices},
T3K_DEVICE_IDS,
DEFAULT_L1_SMALL_SIZE,
DEFAULT_TRACE_REGION_SIZE,
DEFAULT_NUM_COMMAND_QUEUES,
DispatchCoreType::WORKER);
}

void TearDown() override {
mesh_device_->close_devices();
mesh_device_.reset();
}
std::shared_ptr<MeshDevice> mesh_device_;
void TearDown() override { mesh_device_.reset(); }
std::unique_ptr<MeshDevice> mesh_device_;
};

} // namespace ttnn::multi_device::test
9 changes: 5 additions & 4 deletions tests/ttnn/unit_tests/test_multi_device.py
@@ -23,7 +23,8 @@ def test_mesh_device_open_close_explicit(silicon_arch_name, silicon_arch_wormhol
if num_pcie_devices <= 1:
pytest.skip("Requires multiple devices to run")

multi_device = ttnn.open_mesh_device(ttnn.MeshShape(2, 2))
mesh_shape, device_ids = ttnn.MeshShape(2, 2), ttnn.get_pcie_device_ids()
multi_device = ttnn.open_mesh_device(mesh_shape, device_ids)
ttnn.close_mesh_device(multi_device)


@@ -33,12 +34,12 @@ def test_multi_device_subset_mesh(silicon_arch_name, silicon_arch_wormhole_b0):
if num_pcie_devices <= 1:
pytest.skip("Requires multiple devices to run")

mesh_shape = ttnn.MeshShape(1, 2)
multi_device = ttnn.open_mesh_device(mesh_shape)
mesh_shape, device_ids = ttnn.MeshShape(1, 2), ttnn.get_pcie_device_ids()
multi_device = ttnn.open_mesh_device(mesh_shape, device_ids)
assert multi_device.get_num_devices() == 2
ttnn.close_mesh_device(multi_device)

multi_device = ttnn.open_mesh_device(mesh_shape)
multi_device = ttnn.open_mesh_device(mesh_shape, device_ids)
assert multi_device.get_num_devices() == 2
ttnn.close_mesh_device(multi_device)
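A minimal sketch of the explicit-ID flow these tests now exercise, assuming at least two PCIe-attached devices are available:

```python
device_ids = ttnn.get_pcie_device_ids()  # physical PCIe device IDs
mesh = ttnn.open_mesh_device(ttnn.MeshShape(1, 2), device_ids)
assert mesh.get_num_devices() == 2
ttnn.close_mesh_device(mesh)
```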

5 changes: 0 additions & 5 deletions tt_metal/impl/device/mesh_configurations/N300.json

This file was deleted.

6 changes: 0 additions & 6 deletions tt_metal/impl/device/mesh_configurations/T3000.json

This file was deleted.

36 changes: 0 additions & 36 deletions tt_metal/impl/device/mesh_configurations/TG.json

This file was deleted.
