Skip to content

Commit

Permalink
#12883: Commonize functions
Browse files Browse the repository at this point in the history
  • Loading branch information
Aswinmcw committed Sep 20, 2024
1 parent 0171c95 commit 3690932
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 278 deletions.
50 changes: 35 additions & 15 deletions tests/ttnn/unit_tests/operations/test_all_gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@
from ttnn import ShardTensorToMesh


def is_unsupported_case(input_shape, dim, mem_config, num_devices, num_links, input_dtype, layout):
def is_unsupported_case(input_shape, dim, mem_config, num_devices, num_links, input_dtype, layout, isN300=False):
if layout == ttnn.ROW_MAJOR_LAYOUT and input_dtype == ttnn.bfloat8_b:
return True, "Invalid combination"

if num_devices < 2:
if num_devices < 2 and not isN300:
return True, "Requires multiple devices to run"
elif num_devices == 2 and num_links <= 2:
elif num_devices == 2 and num_links <= 2 and not isN300:
return True, "Not enough links to run"

if input_shape[dim] % num_devices != 0 or (dim == 3 and input_shape[dim] // num_devices % 32 != 0):
Expand Down Expand Up @@ -110,7 +110,7 @@ def run_with_trace(
return tt_out_tensor


def run_all_gather_on_t3000_impl(
def run_all_gather_impl(
all_devices,
num_devices,
input_shape,
Expand All @@ -124,10 +124,14 @@ def run_all_gather_on_t3000_impl(
all_gather_operation,
num_iters=1,
enable_async=False,
isN300=False,
):
if len(all_devices) != 8:
if len(all_devices) != 8 and not isN300:
pytest.skip("Not T3000!")

if len(all_devices) != 2 and isN300:
pytest.skip("Not N300!")

# Use Async mode based on test input config
for device in all_devices:
device.enable_async(enable_async)
Expand All @@ -137,12 +141,14 @@ def run_all_gather_on_t3000_impl(
logger.info(f"dim: {dim}")

(is_known_failure, message) = is_unsupported_case(
input_shape, dim, mem_config, num_devices, num_links, input_dtype, layout
input_shape, dim, mem_config, num_devices, num_links, input_dtype, layout, isN300
)
if is_known_failure:
pytest.skip(f"Skipping unsupported case {message}.")

devices = get_devices_for_t3000(all_devices, num_devices)
devices = all_devices
if not isN300:
devices = get_devices_for_t3000(all_devices, num_devices)
# for device in devices:
# device.disable_and_clear_program_cache()

Expand Down Expand Up @@ -190,7 +196,7 @@ def run_all_gather_on_t3000_impl_tight_loop(
num_iters,
enable_async=False,
):
run_all_gather_on_t3000_impl(
run_all_gather_impl(
all_devices,
num_devices,
input_shape,
Expand Down Expand Up @@ -402,7 +408,7 @@ def test_all_gather_on_t3000_post_commit(
use_program_cache,
function_level_defaults,
):
run_all_gather_on_t3000_impl(
run_all_gather_impl(
all_devices,
num_devices,
input_shape,
Expand Down Expand Up @@ -898,7 +904,7 @@ def test_all_gather_on_t3000_nightly(
):
pytest.xfail(reason="Known failure")

run_all_gather_on_t3000_impl(
run_all_gather_impl(
all_devices,
num_devices,
input_shape,
Expand Down Expand Up @@ -934,12 +940,22 @@ def run_all_gather_sharded(
n_buffer=None,
num_iter=1,
trace_mode=False,
isN300=False,
):
if len(t3k_mesh_device.get_device_ids()) != 8:
pytest.skip("Not T3000!")
if not isN300:
if len(t3k_mesh_device.get_device_ids()) != 8:
pytest.skip("Not T3000!")

all_devices = t3k_mesh_device
if len(all_devices) != 2 and isN300:
pytest.skip("Not N300!")

for device_id in t3k_mesh_device.get_device_ids():
t3k_mesh_device.get_device(device_id).enable_async(enable_async)
if not isN300:
for device_id in t3k_mesh_device.get_device_ids():
t3k_mesh_device.get_device(device_id).enable_async(enable_async)
if isN300:
for device in all_devices:
device.enable_async(enable_async)

numel = input_shape[0] * input_shape[1] * input_shape[2] * input_shape[3] * num_devices
unchunked_input_shape = list(input_shape)
Expand All @@ -962,7 +978,11 @@ def run_all_gather_sharded(
unchunked_input_tensor = unchunked_input_tensor.bfloat16()

input_tensors = torch.chunk(unchunked_input_tensor, num_devices, dim)
devices = [t3k_mesh_device.get_device(t3k_mesh_device.get_device_ids()[i]) for i in range(num_devices)]

devices = all_devices

if not isN300:
devices = [t3k_mesh_device.get_device(t3k_mesh_device.get_device_ids()[i]) for i in range(num_devices)]

# num_cores =
# compute_grid_size = devices[0].compute_with_storage_grid_size()
Expand Down
Loading

0 comments on commit 3690932

Please sign in to comment.