#12883: Add initial unit tests for N300 #12922

Merged 3 commits on Sep 23, 2024
247 changes: 217 additions & 30 deletions tests/ttnn/unit_tests/operations/test_all_gather.py
@@ -16,11 +16,6 @@ def is_unsupported_case(input_shape, dim, mem_config, num_devices, num_links, input_dtype, layout):
     if layout == ttnn.ROW_MAJOR_LAYOUT and input_dtype == ttnn.bfloat8_b:
         return True, "Invalid combination"
 
-    if num_devices < 2:
-        return True, "Requires multiple devices to run"
-    elif num_devices == 2 and num_links <= 2:
-        return True, "Not enough links to run"
-
     if input_shape[dim] % num_devices != 0 or (dim == 3 and input_shape[dim] // num_devices % 32 != 0):
         return True, "Unsupported test case"

@@ -59,6 +54,19 @@ def is_unsupported_case(input_shape, dim, mem_config, num_devices, num_links, input_dtype, layout):
     return False, ""
 
 
+def is_unsupported_case_t3k(input_shape, dim, mem_config, num_devices, num_links, input_dtype, layout):
+    if num_devices < 2:
+        return True, "Requires multiple devices to run"
+    elif num_devices == 2 and num_links <= 2:
+        return True, "Not enough links to run"
+
+    return is_unsupported_case(input_shape, dim, mem_config, num_devices, num_links, input_dtype, layout)
+
+
+def is_unsupported_case_n300(input_shape, dim, mem_config, num_devices, num_links, input_dtype, layout):
+    return is_unsupported_case(input_shape, dim, mem_config, num_devices, num_links, input_dtype, layout)
+
+
 def run_with_trace(
     t3k_mesh_device,
     devices,
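Why the device/link gate moved out of the shared helper: its T3000 thresholds reject any two-device configuration with fewer than three links, which would skip every N300 run before it starts. A pure-Python mirror of the split (illustrative only, not code from this PR):

def t3k_gate(num_devices, num_links):
    # Mirrors the extra gate is_unsupported_case_t3k layers on top of the
    # shared checks; is_unsupported_case_n300 adds no gate at all.
    if num_devices < 2:
        return True, "Requires multiple devices to run"
    elif num_devices == 2 and num_links <= 2:
        return True, "Not enough links to run"
    return False, ""

# A two-device N300 run would always have been skipped under the old shared check:
print(t3k_gate(num_devices=2, num_links=1))  # (True, 'Not enough links to run')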
@@ -110,7 +118,7 @@ def run_with_trace(
     return tt_out_tensor
 
 
-def run_all_gather_on_t3000_impl(
+def run_all_gather_impl(
     all_devices,
     num_devices,
     input_shape,
@@ -122,12 +130,10 @@ def run_all_gather_on_t3000_impl(
     use_program_cache,
     function_level_defaults,
     all_gather_operation,
+    devices,
     num_iters=1,
     enable_async=False,
 ):
-    if len(all_devices) != 8:
-        pytest.skip("Not T3000!")
-
     # Use Async mode based on test input config
     for device in all_devices:
         device.enable_async(enable_async)
@@ -136,13 +142,6 @@ def run_all_gather_on_t3000_impl(
     logger.info(f"Input shape: {input_shape}")
     logger.info(f"dim: {dim}")
 
-    (is_known_failure, message) = is_unsupported_case(
-        input_shape, dim, mem_config, num_devices, num_links, input_dtype, layout
-    )
-    if is_known_failure:
-        pytest.skip(f"Skipping unsupported case {message}.")
-
-    devices = get_devices_for_t3000(all_devices, num_devices)
     # for device in devices:
     #     device.disable_and_clear_program_cache()
 
@@ -175,6 +174,92 @@ def run_all_gather_on_t3000_impl(
         assert eq, f"{i} FAILED: {output}"
 
 
+def run_all_gather_on_n300_impl(
+    all_devices,
+    num_devices,
+    input_shape,
+    dim,
+    num_links,
+    input_dtype,
+    layout,
+    mem_config,
+    use_program_cache,
+    function_level_defaults,
+    all_gather_operation,
+    num_iters=1,
+    enable_async=False,
+):
+    if len(all_devices) != 2:
+        pytest.skip("Not N300!")
+
+    (is_known_failure, message) = is_unsupported_case_n300(
+        input_shape, dim, mem_config, num_devices, num_links, input_dtype, layout
+    )
+    if is_known_failure:
+        pytest.skip(f"Skipping unsupported case {message}.")
+
+    return run_all_gather_impl(
+        all_devices,
+        num_devices,
+        input_shape,
+        dim,
+        num_links,
+        input_dtype,
+        layout,
+        mem_config,
+        use_program_cache,
+        function_level_defaults,
+        all_gather_operation,
+        all_devices,
+        num_iters,
+        enable_async,
+    )
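For illustration, a minimal post-commit test driving this new N300 entry point might look like the sketch below; the parametrization, fixtures, and ttnn values are assumptions modeled on the T3000 tests in this file, not code from this PR:

# Hypothetical N300 test, not part of this PR; shapes and dtypes are illustrative.
@pytest.mark.parametrize("num_devices", [2])
@pytest.mark.parametrize("num_links", [1])
@pytest.mark.parametrize("input_shape, dim", [([1, 1, 32, 64], 3)])
def test_all_gather_on_n300_sketch(
    all_devices,
    num_devices,
    input_shape,
    dim,
    num_links,
    use_program_cache,
    function_level_defaults,
):
    run_all_gather_on_n300_impl(
        all_devices,
        num_devices,
        input_shape,
        dim,
        num_links,
        input_dtype=ttnn.bfloat16,
        layout=ttnn.TILE_LAYOUT,
        mem_config=ttnn.MemoryConfig(buffer_type=ttnn.BufferType.DRAM),
        use_program_cache=use_program_cache,
        function_level_defaults=function_level_defaults,
        all_gather_operation=ttnn.all_gather,
        num_iters=1,
        enable_async=False,
    )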


+def run_all_gather_on_t3000_impl(
+    all_devices,
+    num_devices,
+    input_shape,
+    dim,
+    num_links,
+    input_dtype,
+    layout,
+    mem_config,
+    use_program_cache,
+    function_level_defaults,
+    all_gather_operation,
+    num_iters=1,
+    enable_async=False,
+):
+    if len(all_devices) != 8:
+        pytest.skip("Not T3000!")
+
+    (is_known_failure, message) = is_unsupported_case_t3k(
+        input_shape, dim, mem_config, num_devices, num_links, input_dtype, layout
+    )
+    if is_known_failure:
+        pytest.skip(f"Skipping unsupported case {message}.")
+
+    devices = get_devices_for_t3000(all_devices, num_devices)
+
+    return run_all_gather_impl(
+        all_devices,
+        num_devices,
+        input_shape,
+        dim,
+        num_links,
+        input_dtype,
+        layout,
+        mem_config,
+        use_program_cache,
+        function_level_defaults,
+        all_gather_operation,
+        devices,
+        num_iters,
+        enable_async,
+    )
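Note the one asymmetry between the two wrappers: the T3000 path narrows and reorders the pool with get_devices_for_t3000 before delegating, while the N300 path hands all_devices straight through as the device list. A hypothetical helper (not in this PR) that makes the dispatch explicit:

def run_all_gather_on_any(all_devices, *args, **kwargs):
    # Hypothetical convenience dispatcher, not part of this PR; each impl
    # still re-checks the device count and skips itself on a mismatch.
    if len(all_devices) == 8:
        return run_all_gather_on_t3000_impl(all_devices, *args, **kwargs)
    if len(all_devices) == 2:
        return run_all_gather_on_n300_impl(all_devices, *args, **kwargs)
    pytest.skip(f"Unsupported device count: {len(all_devices)}")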


 def run_all_gather_on_t3000_impl_tight_loop(
     all_devices,
     num_devices,
@@ -440,7 +525,7 @@ def run_line_all_gather(
     logger.info(f"Input shape: {input_shape}")
     logger.info(f"dim: {dim}")
 
-    (is_known_failure, message) = is_unsupported_case(
+    (is_known_failure, message) = is_unsupported_case_t3k(
         input_shape, dim, mem_config, num_devices, num_links, input_dtype, layout
     )
     if is_known_failure:
@@ -494,7 +579,7 @@ def run_line_all_gather_deprecated(
     logger.info(f"Input shape: {input_shape}")
     logger.info(f"dim: {dim}")
 
-    (is_known_failure, message) = is_unsupported_case(
+    (is_known_failure, message) = is_unsupported_case_t3k(
         input_shape, dim, mem_config, num_devices, num_links, input_dtype, layout
     )
     if is_known_failure:
@@ -929,18 +1014,13 @@ def run_all_gather_sharded(
     use_program_cache,
     function_level_defaults,
     all_gather_operation,
+    devices,
     enable_async,
     n_worker=None,
     n_buffer=None,
     num_iter=1,
     trace_mode=False,
 ):
-    if len(t3k_mesh_device.get_device_ids()) != 8:
-        pytest.skip("Not T3000!")
-
-    for device_id in t3k_mesh_device.get_device_ids():
-        t3k_mesh_device.get_device(device_id).enable_async(enable_async)
-
     numel = input_shape[0] * input_shape[1] * input_shape[2] * input_shape[3] * num_devices
     unchunked_input_shape = list(input_shape)
     unchunked_input_shape[dim] *= num_devices
@@ -962,7 +1042,6 @@ def run_all_gather_sharded(
     unchunked_input_tensor = unchunked_input_tensor.bfloat16()
 
     input_tensors = torch.chunk(unchunked_input_tensor, num_devices, dim)
-    devices = [t3k_mesh_device.get_device(t3k_mesh_device.get_device_ids()[i]) for i in range(num_devices)]
 
     # num_cores =
     # compute_grid_size = devices[0].compute_with_storage_grid_size()
@@ -1079,6 +1158,114 @@ def run_all_gather_sharded(
         assert all_eq, f"{i} FAILED: {output}"
 
 
+def run_all_gather_sharded_t3k(
+    t3k_mesh_device,
+    num_devices,
+    input_shape,
+    input_shard_shape,
+    shard_grid,
+    dim,
+    num_links,
+    orientation,
+    input_dtype,
+    tensor_layout,
+    tensor_mem_layout,
+    # num_cores,
+    use_program_cache,
+    function_level_defaults,
+    all_gather_operation,
+    enable_async,
+    n_worker=None,
+    n_buffer=None,
+    num_iter=1,
+    trace_mode=False,
+):
+    if len(t3k_mesh_device.get_device_ids()) != 8:
+        pytest.skip("Not T3000!")
+
+    for device_id in t3k_mesh_device.get_device_ids():
+        t3k_mesh_device.get_device(device_id).enable_async(enable_async)
+
+    devices = [t3k_mesh_device.get_device(t3k_mesh_device.get_device_ids()[i]) for i in range(num_devices)]
+
+    return run_all_gather_sharded(
+        t3k_mesh_device,
+        num_devices,
+        input_shape,
+        input_shard_shape,
+        shard_grid,
+        dim,
+        num_links,
+        orientation,
+        input_dtype,
+        tensor_layout,
+        tensor_mem_layout,
+        # num_cores,
+        use_program_cache,
+        function_level_defaults,
+        all_gather_operation,
+        devices,
+        enable_async,
+        n_worker,
+        n_buffer,
+        num_iter,
+        trace_mode,
+    )
+
+
+def run_all_gather_sharded_n300(
+    all_devices,
+    num_devices,
+    input_shape,
+    input_shard_shape,
+    shard_grid,
+    dim,
+    num_links,
+    orientation,
+    input_dtype,
+    tensor_layout,
+    tensor_mem_layout,
+    # num_cores,
+    use_program_cache,
+    function_level_defaults,
+    all_gather_operation,
+    enable_async,
+    n_worker=None,
+    n_buffer=None,
+    num_iter=1,
+    trace_mode=False,
+):
+    if len(all_devices) != 2:
+        pytest.skip("Not N300!")
+
+    for device in all_devices:
+        device.enable_async(enable_async)
+
+    return run_all_gather_sharded(
+        all_devices,
+        num_devices,
+        input_shape,
+        input_shard_shape,
+        shard_grid,
+        dim,
+        num_links,
+        orientation,
+        input_dtype,
+        tensor_layout,
+        tensor_mem_layout,
+        # num_cores,
+        use_program_cache,
+        function_level_defaults,
+        all_gather_operation,
+        all_devices,
+        enable_async,
+        n_worker,
+        n_buffer,
+        num_iter,
+        trace_mode,
+    )
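As on the unsharded path, an N300 sharded test can call this wrapper directly. A sketch of such an invocation follows; the shard shape, grid, and ttnn enum values are illustrative assumptions patterned on the T3000 parametrizations later in this file, not code from this PR:

# Hypothetical N300 sharded call, not part of this PR; all values illustrative.
run_all_gather_sharded_n300(
    all_devices,
    num_devices=2,
    input_shape=[1, 1, 32, 64],
    input_shard_shape=(32, 32),
    shard_grid=ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(1, 0))}),
    dim=3,
    num_links=1,
    orientation=ttnn.ShardOrientation.ROW_MAJOR,
    input_dtype=ttnn.bfloat16,
    tensor_layout=ttnn.TILE_LAYOUT,
    tensor_mem_layout=ttnn.TensorMemoryLayout.WIDTH_SHARDED,
    use_program_cache=use_program_cache,
    function_level_defaults=function_level_defaults,
    all_gather_operation=ttnn.all_gather,
    enable_async=False,
)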


 # @pytest.mark.parametrize("num_devices", [4, 8])
 @skip_for_grayskull("Requires eth connected devices to run")
 @pytest.mark.parametrize("num_devices", [8])
@@ -1146,7 +1333,7 @@ def test_all_gather_sharded_post_commit(
     function_level_defaults,
     enable_async,
 ):
-    run_all_gather_sharded(
+    run_all_gather_sharded_t3k(
         t3k_mesh_device,
         num_devices,
         input_shape,
@@ -1236,7 +1423,7 @@ def test_all_gather_height_sharded_post_commit(
     function_level_defaults,
     enable_async,
 ):
-    run_all_gather_sharded(
+    run_all_gather_sharded_t3k(
         t3k_mesh_device,
         num_devices,
         input_shape,
@@ -1320,7 +1507,7 @@ def test_all_gather_block_sharded_post_commit(
     function_level_defaults,
     enable_async,
 ):
-    run_all_gather_sharded(
+    run_all_gather_sharded_t3k(
         t3k_mesh_device,
         num_devices,
         input_shape,
@@ -1412,7 +1599,7 @@ def test_line_all_gather_sharded_post_commit(
     function_level_defaults,
     enable_async,
 ):
-    run_all_gather_sharded(
+    run_all_gather_sharded_t3k(
         t3k_mesh_device,
         num_devices,
         input_shape,
@@ -1576,7 +1763,7 @@ def test_sharded_all_gather_nightly(
     all_gather_operation,
     enable_async,
 ):
-    run_all_gather_sharded(
+    run_all_gather_sharded_t3k(
         t3k_mesh_device,
         num_devices,
         input_shape,