Reduce scatter perf sweep #12391

Merged
merged 2 commits on Sep 17, 2024
@@ -0,0 +1,120 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import torch
import pytest
from loguru import logger
import ttnn
from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_equal, comp_pcc
from models.utility_functions import skip_for_grayskull, get_devices_for_t3000
from tests.ttnn.unit_tests.operations.test_reduce_scatter_post_commit import run_reduce_scatter_sharded_test


@pytest.mark.timeout(120)
@pytest.mark.parametrize(
    "num_devices, num_links",
    [
        # (4, 1),
        (8, 1),
    ],
)
@pytest.mark.parametrize("dim", [3])
@pytest.mark.parametrize(
    "tensor_mem_layout",
    [
        ttnn.TensorMemoryLayout.WIDTH_SHARDED,
        # ttnn.TensorMemoryLayout.HEIGHT_SHARDED,
        # ttnn.TensorMemoryLayout.BLOCK_SHARDED,
    ],
)
@pytest.mark.parametrize("tensor_layout", [ttnn.TILE_LAYOUT])
@pytest.mark.parametrize("orientation", [ttnn.ShardOrientation.ROW_MAJOR])
@pytest.mark.parametrize(
    "input_dtype",
    [
        # ttnn.bfloat16,
        ttnn.bfloat8_b,
    ],
)
@pytest.mark.parametrize(
    "per_chip_output_shape,output_shard_shape,shard_grid",
    (
        # LLama
        (
            (1, 1, 32, 1024),
            (32, 32),
            ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(7, 3))}),
        ),
    ),
    ids=["llama70b_t3k_decode"],
)
@pytest.mark.parametrize(
    "n_buffer",
    (
        # LLama
        1,
        2,
        3,
        4,
        6,
        8,
    ),
)
@pytest.mark.parametrize(
    "n_worker",
    (
        # LLama
        2,
        4,
        8,
        10,
        12,
    ),
)
@pytest.mark.parametrize("num_iters", [1000])
@pytest.mark.parametrize("math_op", [ttnn.ReduceType.Sum])
@pytest.mark.parametrize("enable_async", [True])
@pytest.mark.parametrize("device_params", [{"trace_region_size": 17068032}], indirect=True)
def test_width_sharded_reduce_scatter_post_commit(
    t3k_mesh_device,
    num_devices,
    per_chip_output_shape,
    output_shard_shape,
    dim,
    num_links,
    math_op,
    shard_grid,
    orientation,
    input_dtype,
    tensor_layout,
    tensor_mem_layout,
    use_program_cache,
    function_level_defaults,
    enable_async,
    num_iters,
    n_worker,
    n_buffer,
):
    logger.info(f"Running for n_worker={n_worker}, n_buffer={n_buffer}:")
    run_reduce_scatter_sharded_test(
        t3k_mesh_device,
        num_devices,
        per_chip_output_shape,
        output_shard_shape,
        dim,
        num_links,
        math_op,
        shard_grid,
        orientation,
        input_dtype,
        tensor_layout,
        tensor_mem_layout,
        use_program_cache=use_program_cache,
        function_level_defaults=function_level_defaults,
        enable_async=enable_async,
        num_iters=num_iters,
        n_worker=n_worker,
        n_buffer=n_buffer,
        trace_mode=True,
    )
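For reference, the n_worker and n_buffer parametrizations above expand into a full grid; a quick sketch of what pytest generates, assuming only these two axes vary:

from itertools import product

n_workers = [2, 4, 8, 10, 12]
n_buffers = [1, 2, 3, 4, 6, 8]

# 5 x 6 = 30 (n_worker, n_buffer) combinations per shape/dtype/layout setting,
# each traced for num_iters=1000 iterations.
sweep = list(product(n_workers, n_buffers))
assert len(sweep) == 30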
@@ -28,6 +28,58 @@ def is_unsupported_case(input_shape, scatter_dim, math_op, mem_config, num_devic
    return False, ""


def run_with_trace(
    t3k_mesh_device,
    input_tensor_mesh,
    scatter_dim,
    num_links,
    math_op,
    output_mem_config,
    n_worker,
    n_buffer,
    num_iters,
):
    # Compile Run
    logger.info("Compiling model")
    output_tensor_mesh = ttnn.reduce_scatter(
        input_tensor_mesh,
        scatter_dim=scatter_dim,
        math_op=math_op,
        num_links=num_links,
        memory_config=output_mem_config,
        num_workers=n_worker,
        num_buffers_per_channel=n_buffer,
    )
    for device_id in t3k_mesh_device.get_device_ids():
        ttnn.synchronize_device(t3k_mesh_device.get_device(device_id))

    # Capture trace
    logger.info("Capturing trace")
    trace_id = ttnn.begin_trace_capture(t3k_mesh_device, cq_id=0)
    for i in range(num_iters):
        output_tensor_mesh = ttnn.reduce_scatter(
            input_tensor_mesh,
            scatter_dim=scatter_dim,
            math_op=math_op,
            num_links=num_links,
            memory_config=output_mem_config,
            num_workers=n_worker,
            num_buffers_per_channel=n_buffer,
        )
    ttnn.end_trace_capture(t3k_mesh_device, trace_id, cq_id=0)
    for device_id in t3k_mesh_device.get_device_ids():
        ttnn.synchronize_device(t3k_mesh_device.get_device(device_id))

    # Run the op
    logger.info("Starting Trace perf test...")
    ttnn.execute_trace(t3k_mesh_device, trace_id, blocking=False)
    ttnn.release_trace(t3k_mesh_device, trace_id)
    for device_id in t3k_mesh_device.get_device_ids():
        ttnn.synchronize_device(t3k_mesh_device.get_device(device_id))

    return output_tensor_mesh


def run_reduce_scatter_test(
    t3k_mesh_device,
    num_devices,
@@ -215,6 +267,9 @@ def run_reduce_scatter_sharded_test(
    function_level_defaults,
    enable_async=True,
    num_iters=1,
    n_worker=None,
    n_buffer=None,
    trace_mode=False,
):
    if len(t3k_mesh_device.get_device_ids()) != 8:
        pytest.skip("Not T3000!")
@@ -269,19 +324,33 @@ def run_reduce_scatter_sharded_test(
        )

    input_tensor_mesh = ttnn.aggregate_as_tensor(tt_input_tensors)

-    # Run the op
-    for i in range(num_iters):
-        output_tensor_mesh = ttnn.reduce_scatter(
-            input_tensor_mesh,
-            scatter_dim=scatter_dim,
-            math_op=math_op,
-            num_links=num_links,
-            memory_config=output_mem_config,
-        )
-
-        for device_id in t3k_mesh_device.get_device_ids():
-            ttnn.synchronize_device(t3k_mesh_device.get_device(device_id))
-        logger.info(f"Done iteration {i}")
+    if trace_mode:
+        output_tensor_mesh = run_with_trace(
+            t3k_mesh_device,
+            input_tensor_mesh,
+            scatter_dim,
+            num_links,
+            math_op,
+            output_mem_config,
+            n_worker,
+            n_buffer,
+            num_iters,
+        )
+    else:
+        for i in range(num_iters):
+            output_tensor_mesh = ttnn.reduce_scatter(
+                input_tensor_mesh,
+                scatter_dim=scatter_dim,
+                math_op=math_op,
+                num_links=num_links,
+                memory_config=output_mem_config,
+            )
+
+            for device_id in t3k_mesh_device.get_device_ids():
+                ttnn.synchronize_device(t3k_mesh_device.get_device(device_id))
+            logger.info(f"Done iteration {i}")

    # Compute golden
    # TODO: Make it model how reduce scatter actually works for numerical correctness/ordering
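The trace path above exists so the sweep can replay the captured num_iters iterations without re-dispatching each op from the host. A rough sketch (not part of this change) of how the captured trace could be timed, using only the ttnn calls already shown in run_with_trace:

import time
import ttnn

def time_trace_replay(mesh_device, trace_id, num_iters):
    # Replay the captured iterations and block until every device in the mesh is idle.
    start = time.time()
    ttnn.execute_trace(mesh_device, trace_id, blocking=False)
    for device_id in mesh_device.get_device_ids():
        ttnn.synchronize_device(mesh_device.get_device(device_id))
    elapsed = time.time() - start
    return elapsed / num_iters  # average seconds per iteration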
@@ -724,7 +724,9 @@ operation::ProgramWithCallbacks reduce_scatter_with_workers(
    const uint32_t ring_index,
    const std::optional<chip_id_t> receiver_device_id,
    const std::optional<chip_id_t> sender_device_id,
-    ttnn::ccl::Topology topology) {
+    ttnn::ccl::Topology topology,
+    const std::optional<size_t> user_defined_num_workers,
+    const std::optional<size_t> user_defined_num_buffers_per_channel) {
    log_trace(tt::LogOp, "reduce_scatter_with_workers entry");
    TT_ASSERT(
        input_tensor.get_legacy_shape()[scatter_split_dim] ==
@@ -750,13 +752,17 @@ operation::ProgramWithCallbacks reduce_scatter_with_workers(
        input_tensor_n_elems_per_slice / (tt::constants::TILE_WIDTH * tt::constants::TILE_HEIGHT);

    TT_ASSERT(input_tensor_num_units_per_tensor_slice > 0);
-    uint32_t max_num_workers = std::min<std::size_t>(8, input_tensor_num_units_per_tensor_slice);
+    uint32_t max_num_workers = std::min<std::size_t>(user_defined_num_workers.value_or(8), input_tensor_num_units_per_tensor_slice);
    bool enable_bidirectional = true;
-    auto num_edm_channels = decide_number_of_edm_channels(op_config, max_num_workers, enable_bidirectional);
+    std::size_t num_edm_channels = decide_number_of_edm_channels(op_config, max_num_workers, enable_bidirectional);
    log_trace(tt::LogOp, "num_edm_channels: {}", num_edm_channels);
    auto edm_termination_mode = ttnn::ccl::EriscDataMoverTerminationMode::WORKER_INITIATED;

-    constexpr std::size_t num_buffers_per_channel = 2; // enable double buffering later
+    std::size_t num_buffers_per_channel = 2;
+    if (user_defined_num_buffers_per_channel.has_value()) {
+        // Override with user defined value
+        num_buffers_per_channel = user_defined_num_buffers_per_channel.value();
+    }
    auto const& edm_builder = create_erisc_datamover_builder(
        num_edm_channels, op_config.get_page_size(), num_buffers_per_channel, buffer_sharing_mode, edm_termination_mode);
    TT_ASSERT(num_edm_channels > 0);
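In short, the program factory now honors the user-supplied overrides: the worker count is capped by the tile units per slice (falling back to the previous limit of 8), and the buffer depth stays at 2 (double buffering) unless a value is passed in. A rough Python paraphrase of that selection logic, for illustration only; names mirror the C++ variables:

def pick_edm_config(user_defined_num_workers, user_defined_num_buffers_per_channel, num_units_per_slice):
    # Cap workers at the number of tile units in one tensor slice; the default limit stays 8.
    requested_workers = user_defined_num_workers if user_defined_num_workers is not None else 8
    max_num_workers = min(requested_workers, num_units_per_slice)
    # Keep double buffering unless the caller asks for a different depth.
    num_buffers_per_channel = (
        user_defined_num_buffers_per_channel if user_defined_num_buffers_per_channel is not None else 2
    )
    return max_num_workers, num_buffers_per_channel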
@@ -51,7 +51,9 @@ operation::ProgramWithCallbacks ReduceScatter::create_program(
        this->ring_index,
        this->receiver_device_id,
        this->sender_device_id,
-        this->topology);
+        this->topology,
+        this->user_defined_num_workers,
+        this->user_defined_num_buffers_per_channel);
}

static ttnn::operations::binary::BinaryOpType convert_reduce_type_to_eltwise_type(ttnn::operations::reduction::ReduceType reduce_op) {
@@ -72,15 +74,17 @@ Tensor reduce_scatter(
    const uint32_t scatter_dim,
    ttnn::operations::reduction::ReduceType math_op,
    const uint32_t num_links,
-    const MemoryConfig& output_mem_config) {
+    const MemoryConfig& output_mem_config,
+    const std::optional<size_t> user_defined_num_workers,
+    const std::optional<size_t> user_defined_num_buffers_per_channel) {
    ttnn::operations::binary::BinaryOpType binary_op_type = convert_reduce_type_to_eltwise_type(math_op);
    const ttnn::ccl::Topology topology = ttnn::ccl::Topology::Ring;
    TT_FATAL(std::getenv("TT_METAL_SLOW_DISPATCH_MODE") == nullptr, "This op is only supported for Fast Dispatch");

    auto devices = input_tensor.get_workers();
    std::vector<Tensor> output_tensors = {Tensor(operation::get_workers_for_op_output({input_tensor}))};
    operation::launch_op(
-        [binary_op_type, scatter_dim, num_links, output_mem_config, topology, devices](
+        [binary_op_type, scatter_dim, num_links, output_mem_config, topology, devices, user_defined_num_workers, user_defined_num_buffers_per_channel](
            const std::vector<Tensor>& input_tensors,
            const std::vector<std::optional<const Tensor>>& optional_input_tensors,
            const std::vector<std::optional<Tensor>>& optional_output_tensors) mutable -> std::vector<Tensor> {
@@ -114,7 +118,9 @@ Tensor reduce_scatter(
                    receiver_device_id,
                    sender_device_id,
                    output_mem_config,
-                    topology},
+                    topology,
+                    user_defined_num_workers,
+                    user_defined_num_buffers_per_channel},
                {input_tensor});
        },
        {input_tensor},
@@ -21,6 +21,8 @@ struct ReduceScatter {
    const std::optional<chip_id_t> sender_device_id;
    const MemoryConfig output_mem_config;
    const ttnn::ccl::Topology topology;
    const std::optional<size_t> user_defined_num_workers;
    const std::optional<size_t> user_defined_num_buffers_per_channel;

    void validate(const std::vector<Tensor> &input_tensors) const;
    std::vector<tt::tt_metal::LegacyShape> compute_output_shapes(const std::vector<Tensor> &input_tensors) const;
@@ -41,7 +43,9 @@ operation::ProgramWithCallbacks reduce_scatter_with_workers(
    const uint32_t ring_index,
    const std::optional<chip_id_t> receiver_device_id,
    const std::optional<chip_id_t> sender_device_id,
-    ttnn::ccl::Topology topology);
+    ttnn::ccl::Topology topology,
+    const std::optional<size_t> user_defined_num_workers,
+    const std::optional<size_t> user_defined_num_buffers_per_channel);
}
}; // namespace ccl

@@ -52,7 +56,9 @@ namespace ccl{
    const uint32_t scatter_split_dim,
    ttnn::operations::reduction::ReduceType reduce_op = ttnn::operations::reduction::ReduceType::Sum,
    const uint32_t num_links = 1,
-    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG,
+    const std::optional<size_t> user_defined_num_workers = std::nullopt,
+    const std::optional<size_t> user_defined_num_buffers_per_channel = std::nullopt);
} // namespace ccl
} // namespace operations

@@ -13,10 +13,12 @@ ttnn::Tensor ExecuteReduceScatter::invoke(
    const uint32_t scatter_dim,
    ttnn::operations::reduction::ReduceType math_op,
    const uint32_t num_links,
-    const std::optional<ttnn::MemoryConfig>& memory_config) {
+    const std::optional<ttnn::MemoryConfig>& memory_config,
+    const std::optional<size_t> num_workers,
+    const std::optional<size_t> num_buffers_per_channel) {

    MemoryConfig out_memory_config = memory_config.value_or(input_tensor.memory_config());
-    return ttnn::operations::ccl::reduce_scatter(input_tensor, scatter_dim, math_op, num_links, out_memory_config);
+    return ttnn::operations::ccl::reduce_scatter(input_tensor, scatter_dim, math_op, num_links, out_memory_config, num_workers, num_buffers_per_channel);
}

} // namespace ttnn::operations::ccl
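End to end, this plumbing exposes the two knobs as optional keyword arguments on ttnn.reduce_scatter. A minimal usage sketch, assuming input_tensor_mesh and output_mem_config are prepared as in the sweep test above (values chosen only for illustration):

import ttnn

output_tensor_mesh = ttnn.reduce_scatter(
    input_tensor_mesh,
    scatter_dim=3,
    math_op=ttnn.ReduceType.Sum,
    num_links=1,
    memory_config=output_mem_config,
    num_workers=8,               # optional; was fixed at 8 internally before this change
    num_buffers_per_channel=4,   # optional; was fixed at 2 (double buffering) before this change
)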
@@ -18,7 +18,9 @@ struct ExecuteReduceScatter {
        const uint32_t scatter_dim,
        ttnn::operations::reduction::ReduceType math_op,
        const uint32_t num_links = 1,
-        const std::optional<ttnn::MemoryConfig>& memory_config = std::nullopt);
+        const std::optional<ttnn::MemoryConfig>& memory_config = std::nullopt,
+        const std::optional<size_t> num_workers = std::nullopt,
+        const std::optional<size_t> num_buffers_per_channel = std::nullopt);
};

} // namespace ccl