diff --git a/tests/ttnn/unit_tests/operations/test_all_gather.py b/tests/ttnn/unit_tests/operations/test_all_gather.py
index c2adad43aec..bc58fef10eb 100644
--- a/tests/ttnn/unit_tests/operations/test_all_gather.py
+++ b/tests/ttnn/unit_tests/operations/test_all_gather.py
@@ -16,11 +16,6 @@ def is_unsupported_case(input_shape, dim, mem_config, num_devices, num_links, in
     if layout == ttnn.ROW_MAJOR_LAYOUT and input_dtype == ttnn.bfloat8_b:
         return True, "Invalid combination"
 
-    if num_devices < 2:
-        return True, "Requires multiple devices to run"
-    elif num_devices == 2 and num_links <= 2:
-        return True, "Not enough links to run"
-
     if input_shape[dim] % num_devices != 0 or (dim == 3 and input_shape[dim] // num_devices % 32 != 0):
         return True, "Unsupported test case"
 
@@ -59,6 +54,19 @@ def is_unsupported_case(input_shape, dim, mem_config, num_devices, num_links, in
     return False, ""
 
 
+def is_unsupported_case_t3k(input_shape, dim, mem_config, num_devices, num_links, input_dtype, layout):
+    if num_devices < 2:
+        return True, "Requires multiple devices to run"
+    elif num_devices == 2 and num_links <= 2:
+        return True, "Not enough links to run"
+
+    return is_unsupported_case(input_shape, dim, mem_config, num_devices, num_links, input_dtype, layout)
+
+
+def is_unsupported_case_n300(input_shape, dim, mem_config, num_devices, num_links, input_dtype, layout):
+    return is_unsupported_case(input_shape, dim, mem_config, num_devices, num_links, input_dtype, layout)
+
+
 def run_with_trace(
     t3k_mesh_device,
     devices,
@@ -110,7 +118,7 @@ def run_with_trace(
     return tt_out_tensor
 
 
-def run_all_gather_on_t3000_impl(
+def run_all_gather_impl(
     all_devices,
     num_devices,
     input_shape,
@@ -122,12 +130,10 @@ def run_all_gather_on_t3000_impl(
     use_program_cache,
     function_level_defaults,
     all_gather_operation,
+    devices,
     num_iters=1,
     enable_async=False,
 ):
-    if len(all_devices) != 8:
-        pytest.skip("Not T3000!")
-
     # Use Async mode based on test input config
     for device in all_devices:
         device.enable_async(enable_async)
@@ -136,13 +142,6 @@ def run_all_gather_on_t3000_impl(
     logger.info(f"Input shape: {input_shape}")
     logger.info(f"dim: {dim}")
 
-    (is_known_failure, message) = is_unsupported_case(
-        input_shape, dim, mem_config, num_devices, num_links, input_dtype, layout
-    )
-    if is_known_failure:
-        pytest.skip(f"Skipping unsupported case {message}.")
-
-    devices = get_devices_for_t3000(all_devices, num_devices)
     # for device in devices:
     #    device.disable_and_clear_program_cache()
 
@@ -175,6 +174,92 @@ def run_all_gather_on_t3000_impl(
         assert eq, f"{i} FAILED: {output}"
 
 
+def run_all_gather_on_n300_impl(
+    all_devices,
+    num_devices,
+    input_shape,
+    dim,
+    num_links,
+    input_dtype,
+    layout,
+    mem_config,
+    use_program_cache,
+    function_level_defaults,
+    all_gather_operation,
+    num_iters=1,
+    enable_async=False,
+):
+    if len(all_devices) != 2:
+        pytest.skip("Not N300!")
+
+    (is_known_failure, message) = is_unsupported_case_n300(
+        input_shape, dim, mem_config, num_devices, num_links, input_dtype, layout
+    )
+    if is_known_failure:
+        pytest.skip(f"Skipping unsupported case {message}.")
+
+    return run_all_gather_impl(
+        all_devices,
+        num_devices,
+        input_shape,
+        dim,
+        num_links,
+        input_dtype,
+        layout,
+        mem_config,
+        use_program_cache,
+        function_level_defaults,
+        all_gather_operation,
+        all_devices,
+        num_iters,
+        enable_async,
+    )
+
+
+def run_all_gather_on_t3000_impl(
+    all_devices,
+    num_devices,
+    input_shape,
+    dim,
+    num_links,
+    input_dtype,
+    layout,
+    mem_config,
+    use_program_cache,
+    function_level_defaults,
+    all_gather_operation,
+    num_iters=1,
+    enable_async=False,
+):
+    if len(all_devices) != 8:
+        pytest.skip("Not T3000!")
+
+    (is_known_failure, message) = is_unsupported_case_t3k(
+        input_shape, dim, mem_config, num_devices, num_links, input_dtype, layout
+    )
+    if is_known_failure:
+        pytest.skip(f"Skipping unsupported case {message}.")
+
+    devices = get_devices_for_t3000(all_devices, num_devices)
+
+    return run_all_gather_impl(
+        all_devices,
+        num_devices,
+        input_shape,
+        dim,
+        num_links,
+        input_dtype,
+        layout,
+        mem_config,
+        use_program_cache,
+        function_level_defaults,
+        all_gather_operation,
+        devices,
+        num_iters,
+        enable_async,
+    )
+
+
 def run_all_gather_on_t3000_impl_tight_loop(
     all_devices,
     num_devices,
@@ -440,7 +525,7 @@ def run_line_all_gather(
     logger.info(f"Input shape: {input_shape}")
     logger.info(f"dim: {dim}")
 
-    (is_known_failure, message) = is_unsupported_case(
+    (is_known_failure, message) = is_unsupported_case_t3k(
         input_shape, dim, mem_config, num_devices, num_links, input_dtype, layout
     )
     if is_known_failure:
@@ -494,7 +579,7 @@ def run_line_all_gather_deprecated(
     logger.info(f"Input shape: {input_shape}")
     logger.info(f"dim: {dim}")
 
-    (is_known_failure, message) = is_unsupported_case(
+    (is_known_failure, message) = is_unsupported_case_t3k(
         input_shape, dim, mem_config, num_devices, num_links, input_dtype, layout
     )
     if is_known_failure:
@@ -929,18 +1014,13 @@ def run_all_gather_sharded(
     use_program_cache,
     function_level_defaults,
     all_gather_operation,
+    devices,
    enable_async,
     n_worker=None,
     n_buffer=None,
     num_iter=1,
     trace_mode=False,
 ):
-    if len(t3k_mesh_device.get_device_ids()) != 8:
-        pytest.skip("Not T3000!")
-
-    for device_id in t3k_mesh_device.get_device_ids():
-        t3k_mesh_device.get_device(device_id).enable_async(enable_async)
-
     numel = input_shape[0] * input_shape[1] * input_shape[2] * input_shape[3] * num_devices
     unchunked_input_shape = list(input_shape)
     unchunked_input_shape[dim] *= num_devices
@@ -962,7 +1042,6 @@ def run_all_gather_sharded(
     unchunked_input_tensor = unchunked_input_tensor.bfloat16()
 
     input_tensors = torch.chunk(unchunked_input_tensor, num_devices, dim)
-    devices = [t3k_mesh_device.get_device(t3k_mesh_device.get_device_ids()[i]) for i in range(num_devices)]
 
     # num_cores =
     # compute_grid_size = devices[0].compute_with_storage_grid_size()
@@ -1079,6 +1158,114 @@ def run_all_gather_sharded(
         assert all_eq, f"{i} FAILED: {output}"
 
 
+def run_all_gather_sharded_t3k(
+    t3k_mesh_device,
+    num_devices,
+    input_shape,
+    input_shard_shape,
+    shard_grid,
+    dim,
+    num_links,
+    orientation,
+    input_dtype,
+    tensor_layout,
+    tensor_mem_layout,
+    # num_cores,
+    use_program_cache,
+    function_level_defaults,
+    all_gather_operation,
+    enable_async,
+    n_worker=None,
+    n_buffer=None,
+    num_iter=1,
+    trace_mode=False,
+):
+    if len(t3k_mesh_device.get_device_ids()) != 8:
+        pytest.skip("Not T3000!")
+
+    for device_id in t3k_mesh_device.get_device_ids():
+        t3k_mesh_device.get_device(device_id).enable_async(enable_async)
+
+    devices = [t3k_mesh_device.get_device(t3k_mesh_device.get_device_ids()[i]) for i in range(num_devices)]
+
+    return run_all_gather_sharded(
+        t3k_mesh_device,
+        num_devices,
+        input_shape,
+        input_shard_shape,
+        shard_grid,
+        dim,
+        num_links,
+        orientation,
+        input_dtype,
+        tensor_layout,
+        tensor_mem_layout,
+        # num_cores,
+        use_program_cache,
+        function_level_defaults,
+        all_gather_operation,
+        devices,
+        enable_async,
+        n_worker,
+        n_buffer,
+        num_iter,
+        trace_mode,
+    )
+
+
+def run_all_gather_sharded_n300(
+    all_devices,
+    num_devices,
+    input_shape,
+    input_shard_shape,
+    shard_grid,
+    dim,
+    num_links,
+    orientation,
+    input_dtype,
+    tensor_layout,
+    tensor_mem_layout,
+    # num_cores,
+    use_program_cache,
+    function_level_defaults,
+    all_gather_operation,
+    enable_async,
+    n_worker=None,
+    n_buffer=None,
+    num_iter=1,
+    trace_mode=False,
+):
+    if len(all_devices) != 2:
+        pytest.skip("Not N300!")
+
+    for device in all_devices:
+        device.enable_async(enable_async)
+
+    return run_all_gather_sharded(
+        all_devices,
+        num_devices,
+        input_shape,
+        input_shard_shape,
+        shard_grid,
+        dim,
+        num_links,
+        orientation,
+        input_dtype,
+        tensor_layout,
+        tensor_mem_layout,
+        # num_cores,
+        use_program_cache,
+        function_level_defaults,
+        all_gather_operation,
+        all_devices,
+        enable_async,
+        n_worker,
+        n_buffer,
+        num_iter,
+        trace_mode,
+    )
+
+
 # @pytest.mark.parametrize("num_devices", [4, 8])
 @skip_for_grayskull("Requires eth connected devices to run")
 @pytest.mark.parametrize("num_devices", [8])
@@ -1146,7 +1333,7 @@ def test_all_gather_sharded_post_commit(
     function_level_defaults,
     enable_async,
 ):
-    run_all_gather_sharded(
+    run_all_gather_sharded_t3k(
         t3k_mesh_device,
         num_devices,
         input_shape,
@@ -1236,7 +1423,7 @@ def test_all_gather_height_sharded_post_commit(
     function_level_defaults,
     enable_async,
 ):
-    run_all_gather_sharded(
+    run_all_gather_sharded_t3k(
         t3k_mesh_device,
         num_devices,
         input_shape,
@@ -1320,7 +1507,7 @@ def test_all_gather_block_sharded_post_commit(
     function_level_defaults,
     enable_async,
 ):
-    run_all_gather_sharded(
+    run_all_gather_sharded_t3k(
         t3k_mesh_device,
         num_devices,
         input_shape,
@@ -1412,7 +1599,7 @@ def test_line_all_gather_sharded_post_commit(
     function_level_defaults,
     enable_async,
 ):
-    run_all_gather_sharded(
+    run_all_gather_sharded_t3k(
         t3k_mesh_device,
         num_devices,
         input_shape,
@@ -1576,7 +1763,7 @@ def test_sharded_all_gather_nightly(
     all_gather_operation,
     enable_async,
 ):
-    run_all_gather_sharded(
+    run_all_gather_sharded_t3k(
         t3k_mesh_device,
         num_devices,
         input_shape,
diff --git a/tests/ttnn/unit_tests/operations/test_all_gather_N300_post_commit.py b/tests/ttnn/unit_tests/operations/test_all_gather_N300_post_commit.py
new file mode 100644
index 00000000000..ba6105f9222
--- /dev/null
+++ b/tests/ttnn/unit_tests/operations/test_all_gather_N300_post_commit.py
@@ -0,0 +1,133 @@
+# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
+
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import pytest
+from loguru import logger
+import ttnn
+from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_equal, comp_pcc
+from models.utility_functions import skip_for_grayskull
+from tests.ttnn.unit_tests.operations.test_all_gather import (
+    run_all_gather_on_n300_impl,
+    run_all_gather_sharded_n300,
+)
+
+
+# Enumerate the post-commit cases explicitly
+@skip_for_grayskull("Requires eth connected devices to run")
+@pytest.mark.parametrize(
+    "num_devices, num_links, input_shape, dim, layout",
+    [
+        (2, 1, [1, 1, 64, 16384], 3, ttnn.TILE_LAYOUT),
+    ],
+)
+@pytest.mark.parametrize(
+    "input_dtype",
+    [
+        ttnn.bfloat16,
+    ],
+)
+@pytest.mark.parametrize(
+    "mem_config",
+    [
+        ttnn.MemoryConfig(buffer_type=ttnn.BufferType.DRAM),
+    ],
+)
+@pytest.mark.parametrize("num_iters", [1])
+@pytest.mark.parametrize("enable_async", [True, False])
+def test_all_gather_on_n300_post_commit(
+    all_devices,
+    num_devices,
+    input_shape,
+    dim,
+    num_links,
+    input_dtype,
+    layout,
+    mem_config,
+    num_iters,
+    use_program_cache,
+    function_level_defaults,
+    enable_async,
+):
+    run_all_gather_on_n300_impl(
+        all_devices,
+        num_devices,
+        input_shape,
+        dim,
+        num_links,
+        input_dtype,
+        layout,
+        mem_config,
+        use_program_cache,
+        function_level_defaults,
+        all_gather_operation=ttnn.all_gather,
+        num_iters=num_iters,
+        enable_async=enable_async,
+    )
+
+
+@skip_for_grayskull("Requires eth connected devices to run")
+@pytest.mark.parametrize("num_devices", [2])
+@pytest.mark.parametrize("dim", [3])
+@pytest.mark.parametrize("tensor_layout", [ttnn.TILE_LAYOUT])
+@pytest.mark.parametrize(
+    "input_dtype",
+    [
+        ttnn.bfloat16,
+    ],
+)
+@pytest.mark.parametrize(
+    "tensor_mem_layout",
+    [
+        ttnn.TensorMemoryLayout.WIDTH_SHARDED,
+    ],
+)
+@pytest.mark.parametrize("orientation", [ttnn.ShardOrientation.ROW_MAJOR])
+@pytest.mark.parametrize("num_links", [1])
+@pytest.mark.parametrize(
+    "input_shape, input_shard_shape,shard_grid",
+    (
+        (
+            (1, 1, 512, 2048),
+            (128, 256),
+            ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(7, 3))}),
+        ),
+    ),
+)
+@pytest.mark.parametrize("enable_async", [True])
+def test_all_gather_sharded_n300_post_commit(
+    all_devices,
+    num_devices,
+    input_shape,
+    input_shard_shape,
+    shard_grid,
+    dim,
+    num_links,
+    orientation,
+    input_dtype,
+    tensor_layout,
+    tensor_mem_layout,
+    # num_cores,
+    use_program_cache,
+    function_level_defaults,
+    enable_async,
+):
+    run_all_gather_sharded_n300(
+        all_devices,
+        num_devices,
+        input_shape,
+        input_shard_shape,
+        shard_grid,
+        dim,
+        num_links,
+        orientation,
+        input_dtype,
+        tensor_layout,
+        tensor_mem_layout,
+        # num_cores,
+        use_program_cache,
+        function_level_defaults,
+        all_gather_operation=ttnn.all_gather,
+        enable_async=enable_async,
+    )
diff --git a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp
index 65a9b6af40e..5e2cc7a5d74 100644
--- a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp
+++ b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp
@@ -138,12 +138,6 @@ void AllGather::validate(const std::vector<Tensor> &input_tensors) const {
     TT_FATAL(this->num_links > 0, "Error");
     TT_FATAL(this->num_links <= input_tensor.device()->compute_with_storage_grid_size().y, "Worker cores used by links are parallelizaed over rows");
     TT_FATAL(this->receiver_device_id.has_value() || this->sender_device_id.has_value(), "Error");
-    if (this->receiver_device_id == this->sender_device_id) {
-        TT_FATAL(input_tensor.device()->get_ethernet_sockets(this->receiver_device_id.value()).size() >= 2 * this->num_links, "2 Device all gather requires at least 2 eth connections per link");
-    } else {
-        TT_FATAL(this->topology == all_gather_op::Topology::Linear || (this->receiver_device_id.has_value() && input_tensor.device()->get_ethernet_sockets(this->receiver_device_id.value()).size() >= this->num_links), "All gather requires at least 1 eth connection per link between sender device {} and receiver device {}", this->sender_device_id, this->receiver_device_id);
-        TT_FATAL(this->topology == all_gather_op::Topology::Linear || (this->sender_device_id.has_value() &&input_tensor.device()->get_ethernet_sockets(this->sender_device_id.value()).size() >= this->num_links), "All gather requires at least 1 eth connection per link between sender device {} and receiver device {}", this->sender_device_id, this->receiver_device_id);
-    }
 
     TT_FATAL(input_tensor.memory_config().memory_layout == TensorMemoryLayout::INTERLEAVED ||
         input_tensor.memory_config().memory_layout == TensorMemoryLayout::WIDTH_SHARDED ||
@@ -193,7 +187,7 @@ Tensor all_gather(
     all_gather_op::Topology topology = all_gather_op::Topology::Ring;
     auto devices = input_tensor.get_workers();
     uint32_t num_devices = devices.size();
-    if (num_devices == 1){
+    if (num_devices == 2){
         topology = all_gather_op::Topology::Linear;
     }
     std::vector<Tensor> output_tensors = {Tensor(operation::get_workers_for_op_output({input_tensor}))};