From 02bde94f14327e2328267d2bd2d5c7339b393d51 Mon Sep 17 00:00:00 2001
From: Jay Kruer
Date: Wed, 11 Sep 2024 20:14:26 -0700
Subject: [PATCH 1/5] #12729: add row-major split for BERT

---
 .../misc/test_split_any_dim_rm.py             | 118 ++++++++++++++++++
 .../operations/data_movement/split/split.cpp  |  80 +++++++++++-
 2 files changed, 195 insertions(+), 3 deletions(-)
 create mode 100644 tests/tt_eager/python_api_testing/unit_testing/misc/test_split_any_dim_rm.py

diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_split_any_dim_rm.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_split_any_dim_rm.py
new file mode 100644
index 00000000000..ce839117501
--- /dev/null
+++ b/tests/tt_eager/python_api_testing/unit_testing/misc/test_split_any_dim_rm.py
@@ -0,0 +1,118 @@
+# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
+
+# SPDX-License-Identifier: Apache-2.0
+
+from pathlib import Path
+import sys
+from loguru import logger
+import random
+import numpy as np
+import ttnn
+
+from models.utility_functions import (
+    comp_pcc,
+)
+import torch
+import sys
+import numpy
+import pytest
+import os
+import itertools
+
+debug = False
+
+# TODO: test other dims/num_splits
+two_chunk_dim_two_tests = list(
+    zip(
+        [[1, 2**k, 2] for k in range(10)],  # shapes
+        itertools.repeat(2),  # chunks
+        itertools.repeat(2),  # dim
+    )
+)
+
+two_chunk_dim_two_ids = [
+    "x".join(map(str, shape)) + f"@{dim}" + f"->{chunks}" for (shape, chunks, dim) in two_chunk_dim_two_tests
+]
+
+
+@pytest.mark.parametrize(
+    "in_mem_config",
+    (
+        ttnn.MemoryConfig(ttnn.TensorMemoryLayout.INTERLEAVED, ttnn.BufferType.DRAM),
+        ttnn.MemoryConfig(ttnn.TensorMemoryLayout.INTERLEAVED, ttnn.BufferType.L1),
+    ),
+    ids=["in_DRAM", "in_L1"],
+)
+@pytest.mark.parametrize(
+    "out_mem_config",
+    (
+        ttnn.MemoryConfig(ttnn.TensorMemoryLayout.INTERLEAVED, ttnn.BufferType.DRAM),
+        ttnn.MemoryConfig(ttnn.TensorMemoryLayout.INTERLEAVED, ttnn.BufferType.L1),
+    ),
+    ids=["out_DRAM", "out_L1"],
+)
+@pytest.mark.parametrize(
+    "refshape_chunks_dim",
+    tuple(two_chunk_dim_two_tests),
+    ids=two_chunk_dim_two_ids,
+)
+def test_split_rm(refshape_chunks_dim, in_mem_config, out_mem_config, device, dtype=ttnn.bfloat16):
+    (refshape, num_splits, dim) = refshape_chunks_dim
+    profile_location = "split_rm/"
+    os.system(f"rm -rf {profile_location}")
+
+    torch.manual_seed(1234)
+
+    Z = refshape[0]
+    Y = refshape[1]
+    X = refshape[2]
+
+    assert dim in [0, 1, 2]
+
+    if dim == 2:
+        chunk_shape = [Z, Y, X // num_splits]
+    elif dim == 1:
+        chunk_shape = [Z, Y // num_splits, X]
+    elif dim == 0:
+        chunk_shape = [Z // num_splits, Y, X]
+
+    logger.info(f"Split tensor of size: {str(refshape)}")
+
+    dtype_torch = torch.bfloat16
+
+    A = torch.arange(Z * Y * X, dtype=dtype_torch).reshape(refshape)
+    assert list(A.size()) == refshape
+
+    a_t = ttnn.from_torch(A, layout=ttnn.Layout.ROW_MAJOR, dtype=dtype, memory_config=in_mem_config, device=device)
+
+    # Check memory of inputs and outputs
+    # logger.debug(f"input to rm split is on: {a_t.memory_config().buffer_type}")
+
+    dev_buffers = ttnn.split(a_t, num_splits, dim, memory_config=out_mem_config)
+    # dev_buffers come out tilized
+    pyt_buff_list = []
+
+    assert len(dev_buffers) == num_splits
+    for index, buff in enumerate(dev_buffers):
+        logger.debug(f"buff{index} is on: {buff.memory_config().buffer_type}")
+        assert list(buff.shape) == chunk_shape
+        tt_host_rm_buff = (
+            buff.cpu().to(ttnn.ROW_MAJOR_LAYOUT).unpad_from_tile(buff.get_legacy_shape().without_padding())
+        )
+        pyt_got_back_rm_buff = tt_host_rm_buff.to_torch()
+        pyt_buff_list.append(pyt_got_back_rm_buff)
+
+    golden_buffers = torch.chunk(A, num_splits, dim=dim)
+    assert len(pyt_buff_list) == len(golden_buffers)
+    if debug:
+        for i in range(0, num_splits):
+            print(f"torch result [{i+1}]: ", golden_buffers[i][0, 0, 0])
+            print(f"our result [{i+1}]: ", pyt_buff_list[i][0, 0, 0])
+            print()
+
+    for index, pyt_buff in enumerate(pyt_buff_list):
+        golden_buff = golden_buffers[index]
+        passing_pcc, output_pcc = comp_pcc(pyt_buff, golden_buff, 1.0)
+        logger.debug(f"Out passing={passing_pcc}")
+        logger.debug(f"Output pcc={output_pcc}")
+        assert passing_pcc
diff --git a/ttnn/cpp/ttnn/operations/data_movement/split/split.cpp b/ttnn/cpp/ttnn/operations/data_movement/split/split.cpp
index eea04300d56..49be1a38c92 100644
--- a/ttnn/cpp/ttnn/operations/data_movement/split/split.cpp
+++ b/ttnn/cpp/ttnn/operations/data_movement/split/split.cpp
@@ -4,11 +4,14 @@

 #include "ttnn/common/constants.hpp"
+#include "ttnn/operations/core/core.hpp"
 #include "ttnn/run_operation.hpp"
 #include "device/split_op.hpp"
 #include "ttnn/cpp/ttnn/operations/data_movement/reshape/reshape.hpp"
 #include "ttnn/operations/data_movement/transpose/transpose.hpp"
+#include "ttnn/tensor/types.hpp"
 #include "ttnn/operations/data_movement/split/split.hpp"
+#include "ttnn/operations/data_movement/slice/slice.hpp"

 namespace ttnn::operations::data_movement {

@@ -16,8 +19,75 @@ namespace ttnn::operations::data_movement {

 namespace detail {

-    std::vector<Tensor> impl_split_last_dim_two_chunks_tiled(const Tensor &input_tensor, const MemoryConfig &mem_config) {
+    std::vector<Tensor> split_dim_n_chunks_rm(const Tensor &input_tensor, int dim, int num_splits, const MemoryConfig &mem_config) {
+        TT_FATAL(input_tensor.get_layout() == Layout::ROW_MAJOR, "This op only supports row major tensors.");
+        TT_FATAL(input_tensor.get_shape()[dim] % num_splits == 0, "Split dimension must be divisible by num_splits.");
+        auto input_shape = input_tensor.get_shape();
+        auto input_rank = input_shape.size();
+
+        const bool on_host = input_tensor.storage_type() == StorageType::OWNED || input_tensor.storage_type() == StorageType::BORROWED;
+        std::optional<Device *> dev = on_host ? std::nullopt : std::make_optional(input_tensor.device());
+
+        Tensor preprocessed = Tensor(input_tensor);
+        preprocessed = ttnn::unsqueeze_to_4D(preprocessed); // ensure we're 4D before slicing
+        dim += 4 - input_rank; // convert to 4D index
+
+        if (!on_host && input_tensor.get_dtype() == DataType::BFLOAT16) {
+            preprocessed = preprocessed.cpu(); // bf16 tensors must be handled on host due to limitations in slice
+        }
+
+        auto preproc_shape = preprocessed.get_shape();
+
+        auto chunk_len = preproc_shape[dim] / num_splits;
+
+        std::vector<Tensor> output_tensors;
+        output_tensors.reserve(num_splits);
+
+        for (int i = 0; i < num_splits; i++) {
+            auto start = i * chunk_len;
+            auto end = start + chunk_len - 1;
+
+            std::vector<uint32_t> start_shape(preproc_shape.size(), 0);
+            start_shape[dim] = start;
+
+            std::vector<uint32_t> end_shape(preproc_shape.size());
+            for (int j = 0; j < end_shape.size(); j++) {
+                if (j == dim) {
+                    end_shape[j] = end;
+                } else {
+                    end_shape[j] = preproc_shape[j] - 1;
+                }
+            }
+            Tensor output_chunk = ttnn::slice(preprocessed,
+                tt::tt_metal::Shape(start_shape),
+                tt::tt_metal::Shape(end_shape),
+                std::nullopt,
+                mem_config);
+            if (input_rank < 4) {
+                output_chunk = ttnn::squeeze_from_4D(output_chunk, input_rank);
+            }
+
+            tt::tt_metal::Layout layout = input_tensor.get_layout();
+            if (dev && (input_tensor.dtype() == DataType::BFLOAT16 || input_tensor.dtype() == DataType::UINT16)
+                && chunk_len % 2 != 0) {
+                layout = Layout::TILE; // bf16 and uint16 tensors must be tiled if the chunk length is odd due to packing constraints
+                output_chunk = output_chunk.pad_to_tile(0.0);
+            }
+
+            output_chunk = output_chunk.to(layout);
+
+            if (dev) {
+                output_chunk = output_chunk.to(*dev);
+            }
+
+            output_tensors.push_back(output_chunk);
+        }
+
+        return output_tensors;
+    }
+
+    std::vector<Tensor> impl_split_last_dim_two_chunks_tiled(const Tensor &input_tensor, const MemoryConfig &mem_config) {
         auto input_shape = input_tensor.get_legacy_shape();
         auto padded_input_shape = ttnn::operations::experimental::auto_format::AutoFormat::pad_to_tile_shape(input_shape);
         ttnn::operations::experimental::auto_format::FormatParams input_format_params = {.pad_shape = padded_input_shape, .pad_value = 0.0, .target_layout = Layout::TILE};
@@ -71,9 +141,13 @@ std::vector<Tensor> SplitOperation::invoke(
     const std::optional<MemoryConfig>& memory_config_arg) {

     auto memory_config = memory_config_arg.value_or(input_tensor.memory_config());
-    TT_FATAL(num_splits == 2, "Currently only supporting split in 2");
-    return detail::split_dim_two_chunks_tiled(input_tensor, dim, memory_config);
+    if (input_tensor.get_layout() == Layout::ROW_MAJOR) {
+        return detail::split_dim_n_chunks_rm(input_tensor, dim, num_splits, memory_config);
+    } else {
+        TT_FATAL(num_splits == 2, "Currently only supporting split in 2");
+        return detail::split_dim_two_chunks_tiled(input_tensor, dim, memory_config);
+    }
 }

 std::vector<Tensor> SplitOperation::invoke(
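The core of the new row-major path above is the per-chunk slice-bounds arithmetic: chunk i of a split dimension of length L covers the inclusive index range [i * chunk_len, (i + 1) * chunk_len - 1], where chunk_len = L / num_splits. A minimal Python mirror of that logic, for reference only (the chunk_bounds helper is hypothetical and not part of the patch):

    # Mirrors the bounds computed inside split_dim_n_chunks_rm; ttnn::slice
    # takes inclusive start/end indices, hence the "- 1" on the end bound.
    def chunk_bounds(dim_len: int, num_splits: int, i: int) -> tuple[int, int]:
        assert dim_len % num_splits == 0, "split dim must divide evenly"
        chunk_len = dim_len // num_splits
        start = i * chunk_len
        end = start + chunk_len - 1  # inclusive end, matching the C++ above
        return start, end

    # Splitting a length-8 dimension in two gives [(0, 3), (4, 7)].
    print([chunk_bounds(8, 2, i) for i in range(2)])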
From ae9905119281715271a88975f2008fc9a4581505 Mon Sep 17 00:00:00 2001
From: Jay Kruer
Date: Wed, 18 Sep 2024 17:34:17 +0000
Subject: [PATCH 2/5] #0: update year on RM split pytest

---
 .../unit_testing/misc/test_split_any_dim_rm.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_split_any_dim_rm.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_split_any_dim_rm.py
index ce839117501..1259e2f9616 100644
--- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_split_any_dim_rm.py
+++ b/tests/tt_eager/python_api_testing/unit_testing/misc/test_split_any_dim_rm.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
+# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.

 # SPDX-License-Identifier: Apache-2.0


From d914debef976bc33c3e368e7b2bb4b4f799e1d7b Mon Sep 17 00:00:00 2001
From: Jay Kruer
Date: Wed, 18 Sep 2024 17:35:19 +0000
Subject: [PATCH 3/5] #0: make 2-way split assert more specific and relocate

---
 ttnn/cpp/ttnn/operations/data_movement/split/split.cpp | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/ttnn/cpp/ttnn/operations/data_movement/split/split.cpp b/ttnn/cpp/ttnn/operations/data_movement/split/split.cpp
index 49be1a38c92..26d280cc242 100644
--- a/ttnn/cpp/ttnn/operations/data_movement/split/split.cpp
+++ b/ttnn/cpp/ttnn/operations/data_movement/split/split.cpp
@@ -115,8 +115,9 @@ namespace detail {

 }

-std::vector<Tensor> split_dim_two_chunks_tiled(
-    const Tensor &input_tensor, int dim /* = 3 */, const MemoryConfig &mem_config /* = default */) {
+std::vector<Tensor> split_dim_n_chunks_tiled(
+    const Tensor &input_tensor, int dim /* = 3 */, int num_splits, const MemoryConfig &mem_config /* = default */) {
+    TT_FATAL(num_splits == 2, "ttnn.split currently only supports split in 2 in tiled layout, but {} was passed", num_splits);
     if (dim == 3) {
         return split_last_dim_two_chunks_tiled(input_tensor, mem_config);
     }
@@ -145,8 +146,7 @@ std::vector<Tensor> SplitOperation::invoke(

     if (input_tensor.get_layout() == Layout::ROW_MAJOR) {
         return detail::split_dim_n_chunks_rm(input_tensor, dim, num_splits, memory_config);
     } else {
-        TT_FATAL(num_splits == 2, "Currently only supporting split in 2");
-        return detail::split_dim_two_chunks_tiled(input_tensor, dim, memory_config);
+        return detail::split_dim_n_chunks_tiled(input_tensor, dim, num_splits, memory_config);
     }
 }

From bf5eb55ff3b8bb3c1fbff00cfdb3f7f79a767e99 Mon Sep 17 00:00:00 2001
From: Jay Kruer
Date: Thu, 19 Sep 2024 01:38:36 +0000
Subject: [PATCH 4/5] #0: Use LegacyShape

---
 ttnn/cpp/ttnn/operations/data_movement/split/split.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/ttnn/cpp/ttnn/operations/data_movement/split/split.cpp b/ttnn/cpp/ttnn/operations/data_movement/split/split.cpp
index 26d280cc242..e1caeb0fbe1 100644
--- a/ttnn/cpp/ttnn/operations/data_movement/split/split.cpp
+++ b/ttnn/cpp/ttnn/operations/data_movement/split/split.cpp
@@ -60,8 +60,8 @@ namespace detail {
             }
             Tensor output_chunk = ttnn::slice(preprocessed,
-                tt::tt_metal::Shape(start_shape),
-                tt::tt_metal::Shape(end_shape),
+                tt::tt_metal::LegacyShape(start_shape),
+                tt::tt_metal::LegacyShape(end_shape),
                 std::nullopt,
                 mem_config);
             if (input_rank < 4) {

From 0c7fc4959e4e65f0f24c5a2f86303d0f637faeea Mon Sep 17 00:00:00 2001
From: Jay Kruer
Date: Thu, 19 Sep 2024 02:14:03 +0000
Subject: [PATCH 5/5] #0: Tweaks proposed by Artem

---
 .../ttnn/operations/data_movement/split/split.cpp | 15 +++++++--------
 1 file changed, 7 insertions(+), 8 deletions(-)

diff --git a/ttnn/cpp/ttnn/operations/data_movement/split/split.cpp b/ttnn/cpp/ttnn/operations/data_movement/split/split.cpp
index e1caeb0fbe1..81a84ffb6db 100644
--- a/ttnn/cpp/ttnn/operations/data_movement/split/split.cpp
+++ b/ttnn/cpp/ttnn/operations/data_movement/split/split.cpp
@@ -20,16 +20,15 @@ namespace ttnn::operations::data_movement {

 namespace detail {

     std::vector<Tensor> split_dim_n_chunks_rm(const Tensor &input_tensor, int dim, int num_splits, const MemoryConfig &mem_config) {
-        TT_FATAL(input_tensor.get_layout() == Layout::ROW_MAJOR, "This op only supports row major tensors.");
-        TT_FATAL(input_tensor.get_shape()[dim] % num_splits == 0, "Split dimension must be divisible by num_splits.");
+        TT_FATAL(input_tensor.get_layout() == Layout::ROW_MAJOR, "ttnn.split only supports row major tensors.");
+        TT_FATAL(input_tensor.get_shape()[dim] % num_splits == 0, "Split dimension {} must be divisible by num_splits {}.", input_tensor.get_shape()[dim], num_splits);
         auto input_shape = input_tensor.get_shape();
         auto input_rank = input_shape.size();

         const bool on_host = input_tensor.storage_type() == StorageType::OWNED || input_tensor.storage_type() == StorageType::BORROWED;
-        std::optional<Device *> dev = on_host ? std::nullopt : std::make_optional(input_tensor.device());
+        std::optional<Device *> device = on_host ? std::nullopt : std::make_optional(input_tensor.device());

-        Tensor preprocessed = Tensor(input_tensor);
-        preprocessed = ttnn::unsqueeze_to_4D(preprocessed); // ensure we're 4D before slicing
+        Tensor preprocessed = ttnn::unsqueeze_to_4D(input_tensor); // ensure we're 4D before slicing
         dim += 4 - input_rank; // convert to 4D index

         if (!on_host && input_tensor.get_dtype() == DataType::BFLOAT16) {
@@ -69,7 +68,7 @@ namespace detail {
             }

             tt::tt_metal::Layout layout = input_tensor.get_layout();
-            if (dev && (input_tensor.dtype() == DataType::BFLOAT16 || input_tensor.dtype() == DataType::UINT16)
+            if (device && (input_tensor.dtype() == DataType::BFLOAT16 || input_tensor.dtype() == DataType::UINT16)
                 && chunk_len % 2 != 0) {
                 layout = Layout::TILE; // bf16 and uint16 tensors must be tiled if the chunk length is odd due to packing constraints
                 output_chunk = output_chunk.pad_to_tile(0.0);
@@ -77,8 +76,8 @@ namespace detail {

             output_chunk = output_chunk.to(layout);

-            if (dev) {
-                output_chunk = output_chunk.to(*dev);
+            if (device) {
+                output_chunk = output_chunk.to(*device);
             }

             output_tensors.push_back(output_chunk);
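Taken together, the series lets ttnn.split operate on row-major tensors along any dimension. A minimal usage sketch mirroring the new pytest (this assumes a locally attached Tenstorrent device; the shape, split parameters, and open/close calls are illustrative, not prescribed by the patches):

    import torch
    import ttnn

    device = ttnn.open_device(device_id=0)

    # Row-major bf16 input. The per-chunk length along dim 2 is even (32),
    # so outputs stay row-major; odd chunk lengths come back tilized and
    # padded, as the op's comments note.
    a = torch.arange(1 * 4 * 64, dtype=torch.bfloat16).reshape(1, 4, 64)
    a_tt = ttnn.from_torch(a, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, device=device)

    chunks = ttnn.split(a_tt, 2, 2)  # two chunks along the last dim
    golden = torch.chunk(a, 2, dim=2)

    for tt_chunk, ref in zip(chunks, golden):
        out = ttnn.to_torch(tt_chunk)
        assert out.shape == ref.shape
        assert torch.equal(out, ref)  # same bf16 payload should round-trip

    ttnn.close_device(device)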