Skip to content

Commit

Permalink
#3188: Refactoring moreh_matmul op
Browse files Browse the repository at this point in the history
  • Loading branch information
dongjin-na committed Nov 4, 2023
1 parent d612a23 commit 38ba158
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 98 deletions.
9 changes: 4 additions & 5 deletions tt_eager/tt_dnn/op_library/moreh_dot/moreh_dot_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@
#include "tt_metal/host_api.hpp"
#include "tt_metal/tools/profiler/op_profiler.hpp"

using namespace tt::constants;

namespace tt {

using namespace constants;
namespace operations {
namespace primary {

Expand All @@ -35,8 +35,7 @@ void MorehDot::validate(const std::vector<Tensor>& input_tensors) const {
TT_ASSERT(a_shape_wo_padding[3] == b_shape_wo_padding[3]);

TT_ASSERT(
input_tensor_a.dtype() == tt::tt_metal::DataType::BFLOAT16 ||
input_tensor_a.dtype() == tt::tt_metal::DataType::BFLOAT8_B,
input_tensor_a.dtype() == DataType::BFLOAT16 || input_tensor_a.dtype() == DataType::BFLOAT8_B,
"Unsupported data format");
TT_ASSERT(
input_tensor_a.storage_type() == StorageType::DEVICE and input_tensor_b.storage_type() == StorageType::DEVICE,
Expand Down Expand Up @@ -70,7 +69,7 @@ operation::ProgramWithCallbacks MorehDot::create_program(
return moreh_dot_single_core(input_tensor_a, input_tensor_b, output_tensor);
}

tt::stl::reflection::Attributes MorehDot::attributes() const {
stl::reflection::Attributes MorehDot::attributes() const {
return {
{"output_mem_config", this->output_mem_config},
{"output_dtype", this->output_dtype},
Expand Down
2 changes: 1 addition & 1 deletion tt_eager/tt_dnn/op_library/moreh_dot/moreh_dot_op.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ struct MorehDot {
std::vector<Tensor> create_output_tensors(const std::vector<Tensor> &input_tensors) const;
operation::ProgramWithCallbacks create_program(
const std::vector<Tensor> &input_tensors, std::vector<Tensor> &output_tensors) const;
tt::stl::reflection::Attributes attributes() const;
stl::reflection::Attributes attributes() const;
};

inline Tensor moreh_dot(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
// SPDX-License-Identifier: Apache-2.0

#include "tt_dnn/op_library/moreh_dot/moreh_dot_op.hpp"
#include "tt_eager/tt_dnn/op_library/moreh_helper_functions.hpp"
#include "tt_metal/common/constants.hpp"
#include "tt_metal/detail/util.hpp"
#include "tt_metal/host_api.hpp"
Expand All @@ -17,14 +18,11 @@ namespace primary {

operation::ProgramWithCallbacks moreh_dot_single_core(const Tensor &a, const Tensor &b, Tensor &output) {
Program program{};
CoreRange core = {.start = {0, 0}, .end = {0, 0}};
CoreCoord core = {0, 0};
const uint32_t core_num = 1;

tt::DataFormat src0_cb_data_format = tt_metal::datatype_to_dataformat_converter(a.dtype());
uint32_t src0_single_tile_size = tt_metal::detail::TileSize(src0_cb_data_format);
tt::DataFormat src1_cb_data_format = tt_metal::datatype_to_dataformat_converter(b.dtype());
uint32_t src1_single_tile_size = tt_metal::detail::TileSize(src1_cb_data_format);
tt::DataFormat dst_cb_data_format = tt_metal::datatype_to_dataformat_converter(output.dtype());
uint32_t dst_single_tile_size = tt_metal::detail::TileSize(dst_cb_data_format);
DataFormat cb_data_format = datatype_to_dataformat_converter(output.dtype());
uint32_t single_tile_size = detail::TileSize(cb_data_format);

tt_metal::Buffer *src0_buffer = a.buffer();
tt_metal::Buffer *src1_buffer = b.buffer();
Expand All @@ -42,91 +40,69 @@ operation::ProgramWithCallbacks moreh_dot_single_core(const Tensor &a, const Ten
tt_metal::Buffer *dst_buffer = output.buffer();
TT_ASSERT(dst_buffer != nullptr, "Output buffer should be allocated on device!");

uint32_t src0_cb_index = CB::c_in0;
uint32_t num_input_tiles = 2;
tt_metal::CircularBufferConfig cb_src0_config =
tt_metal::CircularBufferConfig(num_input_tiles * src0_single_tile_size, {{src0_cb_index, src0_cb_data_format}})
.set_page_size(src0_cb_index, src0_single_tile_size);
auto cb_src0 = tt_metal::CreateCircularBuffer(program, core, cb_src0_config);

uint32_t src1_cb_index = CB::c_in1;
tt_metal::CircularBufferConfig cb_src1_config =
tt_metal::CircularBufferConfig(num_input_tiles * src1_single_tile_size, {{src1_cb_index, src1_cb_data_format}})
.set_page_size(src1_cb_index, src1_single_tile_size);
auto cb_src1 = tt_metal::CreateCircularBuffer(program, core, cb_src1_config);

uint32_t output_cb_index = CB::c_out0; // output operands start at index 16
uint32_t num_output_tiles = 2;
tt_metal::CircularBufferConfig cb_output_config =
tt_metal::CircularBufferConfig(num_output_tiles * dst_single_tile_size, {{output_cb_index, dst_cb_data_format}})
.set_page_size(output_cb_index, dst_single_tile_size);
auto cb_output = tt_metal::CreateCircularBuffer(program, core, cb_output_config);

tt_metal::CircularBufferConfig cb_scaler_config =
tt_metal::CircularBufferConfig(dst_single_tile_size, {{CB::c_in2, dst_cb_data_format}})
.set_page_size(CB::c_in2, dst_single_tile_size);
auto cb_src2 = tt_metal::CreateCircularBuffer(program, core, cb_scaler_config);

uint32_t interm0_cb_index = CB::c_intermed0;
tt_metal::CircularBufferConfig interm0_cb_config =
tt_metal::CircularBufferConfig(dst_single_tile_size, {{interm0_cb_index, dst_cb_data_format}})
.set_page_size(interm0_cb_index, dst_single_tile_size);
auto cb_interm0 = tt_metal::CreateCircularBuffer(program, core, interm0_cb_config);

uint32_t interm1_cb_index = CB::c_intermed1;
tt_metal::CircularBufferConfig interm1_cb_config =
tt_metal::CircularBufferConfig(dst_single_tile_size, {{interm1_cb_index, dst_cb_data_format}})
.set_page_size(interm1_cb_index, dst_single_tile_size);
auto cb_interm1 = tt_metal::CreateCircularBuffer(program, core, interm1_cb_config);

bool src0_is_dram = src0_buffer->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0;
bool src1_is_dram = src1_buffer->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0;
////////////////////////////////////////////////////////////////////////////
// CircularBuffer Setup
////////////////////////////////////////////////////////////////////////////
const uint32_t in0_t = 2; // a
const uint32_t in1_t = 2; // b
const uint32_t in2_t = 1; // scaler
const uint32_t out0_t = 2; // out
const uint32_t im0_t = 1;
const uint32_t im1_t = 1;

CreateCircularBuffer(
program,
std::set<CoreRange>{CoreRange{.start = core, .end = core}},
cb_data_format,
{
{CB::c_in0, in0_t},
{CB::c_in1, in1_t},
{CB::c_in2, in2_t},
{CB::c_out0, out0_t},
{CB::c_intermed0, im0_t},
{CB::c_intermed1, im1_t},
});

////////////////////////////////////////////////////////////////////////////
// DataMovementKernel SetUp
////////////////////////////////////////////////////////////////////////////
std::vector<uint32_t> reader_compile_time_args = {
(std::uint32_t)src0_is_dram, (std::uint32_t)src1_is_dram, *reinterpret_cast<uint32_t *>(&scaler)};
(std::uint32_t)is_dram(src0_buffer),
(std::uint32_t)is_dram(src1_buffer),
*reinterpret_cast<uint32_t *>(&scaler)};

bool dst_is_dram = dst_buffer->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0;
std::vector<uint32_t> writer_compile_time_args = {(std::uint32_t)output_cb_index, (std::uint32_t)dst_is_dram};
std::vector<uint32_t> writer_compile_time_args = {(std::uint32_t)CB::c_out0, (std::uint32_t)is_dram(dst_buffer)};

KernelID binary_reader_kernel_id = tt_metal::CreateDataMovementKernel(
program,
"tt_eager/tt_dnn/op_library/moreh_dot/single_core/kernels/reader_moreh_dot.cpp",
core,
tt_metal::DataMovementConfig{
.processor = tt_metal::DataMovementProcessor::RISCV_1,
.noc = tt_metal::NOC::RISCV_1_default,
.compile_args = reader_compile_time_args});
const auto reader_kernel_file = "tt_eager/tt_dnn/op_library/moreh_dot/single_core/kernels/reader_moreh_dot.cpp";
const auto writer_kernel_file = "tt_eager/tt_dnn/op_library/moreh_dot/single_core/kernels/writer_moreh_dot.cpp";

KernelID unary_writer_kernel_id = tt_metal::CreateDataMovementKernel(
program,
"tt_eager/tt_dnn/op_library/moreh_dot/single_core/kernels/writer_moreh_dot.cpp",
core,
tt_metal::DataMovementConfig{
.processor = tt_metal::DataMovementProcessor::RISCV_0,
.noc = tt_metal::NOC::RISCV_0_default,
.compile_args = writer_compile_time_args});
const auto reader_kernel_id = CreateReadKernel(program, reader_kernel_file, core, reader_compile_time_args);
const auto writer_kernel_id = CreateWriteKernel(program, writer_kernel_file, core, writer_compile_time_args);

////////////////////////////////////////////////////////////////////////////
// ComputeKernel SetUp
////////////////////////////////////////////////////////////////////////////
vector<uint32_t> compute_kernel_args = {};
std::map<string, string> defines;
defines["REDUCE_OP"] = "PoolType::SUM";
defines["REDUCE_DIM"] = "ReduceDim::REDUCE_ROW";

auto dot_kernel = tt_metal::CreateComputeKernel(
program,
"tt_eager/tt_dnn/op_library/moreh_dot/single_core/kernels/moreh_dot.cpp",
core,
tt_metal::ComputeConfig{.compile_args = compute_kernel_args, .defines = defines});

tt_metal::SetRuntimeArgs(
std::map<string, string> compute_defines;
compute_defines["REDUCE_OP"] = "PoolType::SUM";
compute_defines["REDUCE_DIM"] = "ReduceDim::REDUCE_ROW";

const auto compute_kernel_file = "tt_eager/tt_dnn/op_library/moreh_dot/single_core/kernels/moreh_dot.cpp";
const auto compute_kernel_id =
CreateComputeKernel(program, compute_kernel_file, {core, core_num, compute_kernel_args}, compute_defines);

////////////////////////////////////////////////////////////////////////////
// RuntimeArgs SetUp
////////////////////////////////////////////////////////////////////////////
SetRuntimeArgs(
program,
binary_reader_kernel_id,
reader_kernel_id,
core,
{src0_buffer->address(), src1_buffer->address(), num_tiles, 0, mask_h, mask_w});
SetRuntimeArgs(program, compute_kernel_id, core, {num_tiles, 1});
SetRuntimeArgs(program, writer_kernel_id, core, {output.buffer()->address(), 1, 0});

tt_metal::SetRuntimeArgs(program, dot_kernel, core, {num_tiles, 1});

tt_metal::SetRuntimeArgs(program, unary_writer_kernel_id, core, {output.buffer()->address(), 1, 0});

auto override_runtime_arguments_callback = [binary_reader_kernel_id, unary_writer_kernel_id, dot_kernel](
auto override_runtime_arguments_callback = [reader_kernel_id, writer_kernel_id, compute_kernel_id](
const void *operation,
const Program &program,
const std::vector<Tensor> &input_tensors,
Expand All @@ -142,24 +118,24 @@ operation::ProgramWithCallbacks moreh_dot_single_core(const Tensor &a, const Ten
uint32_t num_tiles = input_tensors.at(0).volume() / TILE_HW;

{
auto runtime_args = GetRuntimeArgs(program, binary_reader_kernel_id, core);
auto runtime_args = GetRuntimeArgs(program, reader_kernel_id, core);
runtime_args[0] = src_buffer_a->address();
runtime_args[1] = src_buffer_b->address();
runtime_args[2] = num_tiles;
SetRuntimeArgs(program, binary_reader_kernel_id, core, runtime_args);
SetRuntimeArgs(program, reader_kernel_id, core, runtime_args);
}

{
auto runtime_args = GetRuntimeArgs(program, dot_kernel, core);
auto runtime_args = GetRuntimeArgs(program, compute_kernel_id, core);
runtime_args[0] = num_tiles;
SetRuntimeArgs(program, dot_kernel, core, runtime_args);
SetRuntimeArgs(program, compute_kernel_id, core, runtime_args);
}

{
auto runtime_args = GetRuntimeArgs(program, unary_writer_kernel_id, core);
auto runtime_args = GetRuntimeArgs(program, writer_kernel_id, core);
runtime_args[0] = dst_buffer->address();
runtime_args[1] = 1;
SetRuntimeArgs(program, unary_writer_kernel_id, core, runtime_args);
SetRuntimeArgs(program, writer_kernel_id, core, runtime_args);
}
};
return {.program = std::move(program), .override_runtime_arguments_callback = override_runtime_arguments_callback};
Expand Down
7 changes: 3 additions & 4 deletions tt_eager/tt_dnn/op_library/moreh_matmul/moreh_matmul_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ namespace tt {
namespace tt_metal {

// Thin convenience wrapper in the tt::tt_metal namespace: forwards to the
// operations::primary::moreh_matmul overload with default options —
// std::nullopt (presumably "no preallocated output"; TODO confirm against the
// primary overload's signature) and transpose flags both false.
// NOTE(review): the two return statements below are a diff-rendering artifact
// (the first is the removed line, the second the added line of commit 38ba158);
// in the actual file only one of them exists. The change drops the redundant
// leading `tt::` qualifier, which is already the enclosing namespace here.
Tensor moreh_matmul(const Tensor& input_tensor, const Tensor& other_tensor, const MemoryConfig& mem_config) {
return tt::operations::primary::moreh_matmul(input_tensor, other_tensor, std::nullopt, false, false, mem_config);
return operations::primary::moreh_matmul(input_tensor, other_tensor, std::nullopt, false, false, mem_config);
}

} // namespace tt_metal
Expand All @@ -38,8 +38,7 @@ void MorehMatmul::validate(const std::vector<Tensor>& input_tensors) const {
"Inputs to matmul must be tilized");

TT_ASSERT(
input_tensor.dtype() == tt::tt_metal::DataType::BFLOAT16 ||
input_tensor.dtype() == tt::tt_metal::DataType::BFLOAT8_B,
input_tensor.dtype() == DataType::BFLOAT16 || input_tensor.dtype() == DataType::BFLOAT8_B,
"Unsupported data format");
TT_ASSERT(
input_tensor.storage_type() == StorageType::DEVICE and other_tensor.storage_type() == StorageType::DEVICE,
Expand Down Expand Up @@ -78,7 +77,7 @@ operation::ProgramWithCallbacks MorehMatmul::create_program(
this->output_start_tile_id);
}

tt::stl::reflection::Attributes MorehMatmul::attributes() const {
stl::reflection::Attributes MorehMatmul::attributes() const {
return {
{"transpose_input", this->transpose_input},
{"transpose_other", this->transpose_other},
Expand Down

0 comments on commit 38ba158

Please sign in to comment.