Skip to content

Commit

Permalink
#12250: un-deprecate moreh_matmul
Browse files Browse the repository at this point in the history
  • Loading branch information
o2buzzle committed Sep 5, 2024
1 parent 28964ee commit c128fe5
Show file tree
Hide file tree
Showing 20 changed files with 1,829 additions and 33 deletions.
256 changes: 256 additions & 0 deletions tests/ttnn/unit_tests/operations/test_moreh_matmul.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,256 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import pytest
import torch
from loguru import logger

import ttnn
from models.utility_functions import comp_allclose_and_pcc
from tests.tt_eager.python_api_testing.unit_testing.misc.test_utils import (
get_compute_kernel_options,
compute_kernel_options,
compute_kernel_ids,
)


def create_tt_tensor(tensor: torch.Tensor, device):
return ttnn.from_torch(tensor, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)


def get_tensors(
input_shape, other_shape, output_shape, require_input_grad, require_other_grad, is_1d, device, use_randint=True
):
npu_dtype = ttnn.bfloat16
cpu_dtype = torch.bfloat16
npu_layout = ttnn.TILE_LAYOUT
cpu_layout = ttnn.ROW_MAJOR_LAYOUT

# create tensors for forward
if use_randint:
input = torch.randint(-2, 3, input_shape, dtype=cpu_dtype)
other = torch.randint(-2, 3, other_shape, dtype=cpu_dtype)
output = torch.randint(-2, 3, output_shape, dtype=cpu_dtype)
else:
input = torch.rand(input_shape, dtype=cpu_dtype)
other = torch.rand(other_shape, dtype=cpu_dtype)
output = torch.rand(output_shape, dtype=cpu_dtype)

# tt_input = ttnn.Tensor(input, npu_dtype).pad_to_tile(float(1)).to(npu_layout).to(device)
# tt_other = ttnn.Tensor(other, npu_dtype).pad_to_tile(float("nan")).to(npu_layout).to(device)
# tt_output = ttnn.Tensor(output, npu_dtype).pad_to_tile(float("nan")).to(npu_layout).to(device)

tt_input = create_tt_tensor(input, device)
tt_other = create_tt_tensor(other, device)
tt_output = create_tt_tensor(output, device)

torch_input = input.reshape(-1) if is_1d else input
torch_other = other.reshape(-1) if is_1d else other

# tensors for backward
output_grad = tt_output_grad = torch_output_grad = tt_input_grad = tt_other_grad = None
if require_input_grad or require_other_grad:
output_grad = torch.randint(-2, 3, output_shape, dtype=cpu_dtype)
# tt_output_grad = ttnn.Tensor(output_grad, npu_dtype).pad_to_tile(float("nan")).to(npu_layout).to(device)
tt_output_grad = ttnn.Tensor(output_grad, npu_dtype).pad_to_tile(float(-1)).to(npu_layout).to(device)
torch_output_grad = output_grad[0][0][0][0] if is_1d else output_grad

if require_input_grad:
input_grad = torch.full(input_shape, float("nan"), dtype=cpu_dtype)
tt_input_grad = ttnn.Tensor(input_grad, npu_dtype).pad_to_tile(float("nan")).to(npu_layout).to(device)

if require_other_grad:
other_grad = torch.full(other_shape, float("nan"), dtype=cpu_dtype)
tt_other_grad = (
ttnn.Tensor(
other_grad,
npu_dtype,
)
.pad_to_tile(float("nan"))
.to(npu_layout)
.to(device)
)

return (
tt_input,
tt_other,
tt_output,
tt_output_grad,
tt_input_grad,
tt_other_grad,
torch_input,
torch_other,
torch_output_grad,
)


def moreh_matmul(params, has_output, compute_kernel_config, device):
torch.manual_seed(3072)
input_shape, other_shape, output_shape, transpose_input, transpose_other = params
tt_input, tt_other, tt_output, _, _, _, torch_input, torch_other, _ = get_tensors(
input_shape, other_shape, output_shape, False, False, False, device
)
if not has_output:
tt_output = None

torch_input = torch_input.transpose(-1, -2) if transpose_input else torch_input
torch_other = torch_other.transpose(-1, -2) if transpose_other else torch_other

# tt matmul
cpu_layout = ttnn.ROW_MAJOR_LAYOUT
tt_output = ttnn.operations.moreh.matmul(
tt_input,
tt_other,
transpose_input=transpose_input,
transpose_other=transpose_other,
output=tt_output,
compute_kernel_config=compute_kernel_config,
)
tt_output_cpu = tt_output.cpu().to(cpu_layout).unpad_from_tile(output_shape).to_torch()

# torch matmul
torch_out = torch.matmul(torch_input, torch_other)

# test for equivalance
rtol = atol = 0.1
passing, output_pcc = comp_allclose_and_pcc(torch_out, tt_output_cpu, pcc=0.999, rtol=rtol, atol=atol)
logger.debug(f"Out passing={passing}")
logger.debug(f"Output pcc={output_pcc}")

return passing


@pytest.mark.parametrize(
"params",
(
# input, other, output shape, transpose input, other
([32, 32], [32, 32], [32, 32], False, False), # single-core
([1024, 128], [128, 1024], [1024, 1024], False, False), # multi-core
([128, 1024], [128, 1024], [1024, 1024], True, False), # input transpose
([1024, 128], [1024, 128], [1024, 1024], False, True), # other transpose
([128, 1024], [1024, 128], [1024, 1024], True, True), # input, other transpose
([1020, 128], [128, 1024], [1020, 1024], False, False), # input mask
([1024, 128], [128, 1020], [1024, 1020], False, False), # other mask
([1020, 310], [310, 1020], [1020, 1020], False, False), # input, other mask
([128, 1020], [128, 1024], [1020, 1024], True, False), # input mask, transpose
([1024, 128], [1020, 128], [1024, 1020], False, True), # other mask, transpose
([310, 1020], [1020, 310], [1020, 1020], True, True), # input, other mask, transpose
([3, 1, 2, 1, 4, 1, 319, 95], [4, 2, 95, 470], [3, 1, 2, 1, 4, 2, 319, 470], False, False), # batched matmul
([2, 319, 95], [2, 1, 3, 4, 1, 95, 470], [2, 1, 3, 4, 2, 319, 470], False, False), # batched matmul
([3, 1, 2, 1, 4, 1, 95, 319], [4, 2, 95, 470], [3, 1, 2, 1, 4, 2, 319, 470], True, False), # batched matmul
([2, 319, 95], [2, 1, 3, 4, 1, 470, 95], [2, 1, 3, 4, 2, 319, 470], False, True), # batched matmul
(
[2, 3, 1, 2, 3, 2, 64, 64],
[2, 1, 4, 2, 1, 2, 64, 64],
[2, 3, 4, 2, 3, 2, 64, 64],
False,
False,
), # batched matmul
),
)
@pytest.mark.parametrize("compute_kernel_options", compute_kernel_options, ids=compute_kernel_ids)
def test_moreh_matmul(params, compute_kernel_options, device):
compute_kernel_config = get_compute_kernel_options(compute_kernel_options)
passing = moreh_matmul(params, True, compute_kernel_config, device)
assert passing


@pytest.mark.parametrize(
"params",
(
# input, other, output shape, transpose input, other
([32, 32], [32, 32], [32, 32], False, False), # single-core
([3, 1, 2, 1, 4, 1, 95, 319], [4, 2, 95, 470], [3, 1, 2, 1, 4, 2, 319, 470], True, False), # batched matmul
([2, 319, 95], [2, 1, 3, 4, 1, 470, 95], [2, 1, 3, 4, 2, 319, 470], False, True), # batched matmul
(
[2, 3, 1, 2, 3, 2, 64, 64],
[2, 1, 4, 2, 1, 2, 64, 64],
[2, 3, 4, 2, 3, 2, 64, 64],
False,
False,
), # batched matmul
),
)
def test_moreh_matmul_wo_output(params, device):
passing = moreh_matmul(params, False, None, device)
assert passing


@pytest.mark.parametrize(
"params",
(
# input, weight, bias(1d or scalar), output
([32, 32], [32, 32], [32, 32], False, False), # single-core
(
[2, 3, 1, 2, 3, 2, 64, 64],
[2, 1, 4, 2, 1, 2, 64, 64],
[2, 3, 4, 2, 3, 2, 64, 64],
False,
False,
), # batched matmul
),
)
def test_moreh_matmul_enable_cache(params, device, use_program_cache):
device.enable_program_cache()
torch.manual_seed(3072)
for i in range(4):
# change input's transpose option
if i % 2 == 1:
param_list = list(params)
param_list[3] = False if param_list[3] else True
params = tuple(param_list)
passing = moreh_matmul(params, False, None, device)
assert passing
assert device.num_program_cache_entries() == 2
device.disable_and_clear_program_cache()


@pytest.mark.parametrize(
"params",
(
# input, other, output shape, transpose input, other
([32, 3200], [3200, 32], [32, 32], False, False),
([3100, 31], [3100, 31], [31, 31], True, False),
),
)
@pytest.mark.parametrize("compute_kernel_options", compute_kernel_options, ids=compute_kernel_ids)
def test_moreh_matmul_fp32_dest_acc(params, compute_kernel_options, device):
torch.manual_seed(3072)
input_shape, other_shape, output_shape, transpose_input, transpose_other = params
tt_input, tt_other, tt_output, _, _, _, torch_input, torch_other, _ = get_tensors(
input_shape, other_shape, output_shape, False, False, False, device, use_randint=False
)

compute_kernel_config = get_compute_kernel_options(compute_kernel_options)

torch_input = torch_input.transpose(-1, -2) if transpose_input else torch_input
torch_other = torch_other.transpose(-1, -2) if transpose_other else torch_other

# tt matmul
cpu_layout = ttnn.ROW_MAJOR_LAYOUT
tt_output = ttnn.operations.moreh.matmul(
tt_input,
tt_other,
transpose_input=transpose_input,
transpose_other=transpose_other,
output=tt_output,
compute_kernel_config=compute_kernel_config,
)
tt_output_cpu = tt_output.cpu().to(cpu_layout).unpad_from_tile(output_shape).to_torch()

# torch matmul (float)
torch_out = torch.matmul(torch_input.float(), torch_other.float())

# test for equivalance
rtol = atol = 0.1
passing, output_pcc = comp_allclose_and_pcc(torch_out, tt_output_cpu, pcc=0.999, rtol=rtol, atol=atol)
logger.debug(f"Out passing={passing}")
logger.debug(f"Output pcc={output_pcc}")
diff = torch.abs(torch_out - tt_output_cpu)
logger.debug(f"std={torch.std(diff)}")
logger.debug(f"mean={diff.mean()}")
logger.debug(f"topk(5) {torch.topk(diff.reshape(-1), 5)}")

# TODO
# assert passing
6 changes: 6 additions & 0 deletions ttnn/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,12 @@ set(ALL_TTNN_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/sharded/reshard/reshard_pybind.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/sharded/reshard/device/reshard_op.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/sharded/reshard/device/reshard_program_factory.cpp

${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_pybind.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_matmul/moreh_matmul.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_matmul/moreh_matmul_pybind.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_matmul/device/moreh_matmul_device_operation.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_matmul/device/moreh_matmul_program_factory.cpp
)

# Split src and python bindings
Expand Down
4 changes: 2 additions & 2 deletions ttnn/cpp/pybind11/__init__.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ PYBIND11_MODULE(_ttnn, module) {
/*
We have to make sure every class and enum is bound before any function that uses it as an argument or a return type.
So we split the binding calls into two parts: one for classes and enums, and one for functions.
Another issue to be aware of is that we have to define each shared submodule only once. Therefore, all def_submodule calls
have to be put in here.
Another issue to be aware of is that we have to define each shared submodule only once. Therefore, all def_submodule
calls have to be put in here.
*/

// MODULES
Expand Down
4 changes: 4 additions & 0 deletions ttnn/cpp/pybind11/operations/__init__.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include "ttnn/operations/kv_cache/kv_cache_pybind.hpp"
#include "ttnn/operations/loss/loss_pybind.hpp"
#include "ttnn/operations/matmul/matmul_pybind.hpp"
#include "ttnn/operations/moreh/moreh_pybind.hpp"
#include "ttnn/operations/normalization/normalization_pybind.hpp"
#include "ttnn/operations/pool/avgpool/avg_pool_pybind.hpp"
#include "ttnn/operations/pool/downsample/downsample_pybind.hpp"
Expand Down Expand Up @@ -130,6 +131,9 @@ void py_module(py::module& module) {

auto m_experimental = module.def_submodule("experimental", "experimental operations");
experimental::py_module(m_experimental);

auto m_moreh = module.def_submodule("moreh", "ttnn moreh");
moreh::py_module(m_moreh);
}

} // namespace operations
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@
//
// SPDX-License-Identifier: Apache-2.0

#include "ttnn/deprecated/tt_dnn/op_library/moreh_dot/moreh_dot_op.hpp"
#include "ttnn/deprecated/tt_dnn/op_library/moreh_matmul/moreh_matmul_op.hpp"
#include "ttnn/deprecated/tt_dnn/op_library/moreh_helper_functions.hpp"

#include "tt_metal/common/constants.hpp"
#include "tt_metal/host_api.hpp"
#include "ttnn/deprecated/tt_dnn/op_library/moreh_dot/moreh_dot_op.hpp"
#include "ttnn/deprecated/tt_dnn/op_library/moreh_helper_functions.hpp"

using namespace tt::constants;

Expand Down Expand Up @@ -208,7 +209,12 @@ void MorehMatmul::validate_with_output_tensors(
get_tensor_dim(other_dim, other_shape);
for (auto i = 2; i < tt::tt_metal::MAX_NUM_DIMENSIONS; ++i) {
if (input_dim[i] != other_dim[i]) {
TT_FATAL(input_dim[i] == 1 || other_dim[i] ==1, "one of dim must be one. {}th dim input_dim {}, other_dim {}", i, input_dim[i], other_dim[i]);
TT_FATAL(
input_dim[i] == 1 || other_dim[i] == 1,
"one of dim must be one. {}th dim input_dim {}, other_dim {}",
i,
input_dim[i],
other_dim[i]);
}
}

Expand All @@ -225,7 +231,12 @@ void MorehMatmul::validate_with_output_tensors(
get_tensor_dim(output_dim, output_shape);

for (auto i = 2; i < tt::tt_metal::MAX_NUM_DIMENSIONS; ++i) {
TT_FATAL(std::max(input_dim[i], other_dim[i]) == output_dim[i], "{}th max(input_dim[i], other_dim[i]) {} must be the same as output_dim[i] {}", i, std::max(input_dim[i], other_dim[i]), output_dim[i]);
TT_FATAL(
std::max(input_dim[i], other_dim[i]) == output_dim[i],
"{}th max(input_dim[i], other_dim[i]) {} must be the same as output_dim[i] {}",
i,
std::max(input_dim[i], other_dim[i]),
output_dim[i]);
}
}

Expand All @@ -235,7 +246,11 @@ void MorehMatmul::validate_with_output_tensors(
uint32_t bias_rank = bias_wo_shape.rank();
uint32_t bias_w = bias_wo_shape[-1];
TT_FATAL(bias_rank == 2, "bias rank {} must be 2 (tilized).", bias_rank);
TT_FATAL(bias_w == 1 || bias_w == other_n, "bias_w must be one or the same as other_n. bias_w {}, other_n {}", bias_w, other_n);
TT_FATAL(
bias_w == 1 || bias_w == other_n,
"bias_w must be one or the same as other_n. bias_w {}, other_n {}",
bias_w,
other_n);
}
}

Expand All @@ -251,7 +266,8 @@ Tensor moreh_matmul_(
log_debug(LogOp, "{}:{} run matmul {} {}", __func__, __LINE__, transpose_input, transpose_other);

TT_FATAL(input.storage_type() == StorageType::DEVICE);
auto kernel_config_val = init_device_compute_kernel_config(input.device()->arch(), compute_kernel_config, MathFidelity::HiFi4);
auto kernel_config_val =
init_device_compute_kernel_config(input.device()->arch(), compute_kernel_config, MathFidelity::HiFi4);

std::vector<Tensor> output_tensors = {Tensor(operation::get_workers_for_op_output({input, other}, {bias}))};

Expand Down Expand Up @@ -287,13 +303,13 @@ Tensor moreh_matmul(
const std::optional<const Tensor> bias,
const MemoryConfig& output_mem_config,
std::optional<const ttnn::DeviceComputeKernelConfig> compute_kernel_config) {

// TODO(seunghwan100): Add the argument "output_tensor" to moreh_dot.
if (is_dot_forward(input, other, transpose_input, transpose_other)) {
TT_ASSERT(!bias.has_value());
return moreh_dot(input, other, output_mem_config);
}
return moreh_matmul_(input, other, transpose_input, transpose_other, output, bias, output_mem_config, compute_kernel_config);
return moreh_matmul_(
input, other, transpose_input, transpose_other, output, bias, output_mem_config, compute_kernel_config);
}

} // namespace primary
Expand Down
Loading

0 comments on commit c128fe5

Please sign in to comment.