#12729: add row-major split for BERT
jaykru-tt authored Sep 19, 2024
1 parent ea8522c · commit bcbf802
Showing 2 changed files with 196 additions and 5 deletions.
@@ -0,0 +1,118 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import itertools
import os

import pytest
import torch
from loguru import logger

import ttnn
from models.utility_functions import comp_pcc

debug = False

# TODO: test other dims/num_splits
two_chunk_dim_two_tests = list(
    zip(
        [[1, 2**k, 2] for k in range(10)],  # shapes
        itertools.repeat(2),  # chunks
        itertools.repeat(2),  # dim
    )
)

two_chunk_dim_two_ids = [
    "x".join(map(str, shape)) + f"@{dim}" + f"->{chunks}" for (shape, chunks, dim) in two_chunk_dim_two_tests
]
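# e.g. the shape [1, 8, 2] split into 2 chunks along dim 2 gets the id "1x8x2@2->2"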


@pytest.mark.parametrize(
    "in_mem_config",
    (
        ttnn.MemoryConfig(ttnn.TensorMemoryLayout.INTERLEAVED, ttnn.BufferType.DRAM),
        ttnn.MemoryConfig(ttnn.TensorMemoryLayout.INTERLEAVED, ttnn.BufferType.L1),
    ),
    ids=["in_DRAM", "in_L1"],
)
@pytest.mark.parametrize(
    "out_mem_config",
    (
        ttnn.MemoryConfig(ttnn.TensorMemoryLayout.INTERLEAVED, ttnn.BufferType.DRAM),
        ttnn.MemoryConfig(ttnn.TensorMemoryLayout.INTERLEAVED, ttnn.BufferType.L1),
    ),
    ids=["out_DRAM", "out_L1"],
)
@pytest.mark.parametrize(
    "refshape_chunks_dim",
    tuple(two_chunk_dim_two_tests),
    ids=two_chunk_dim_two_ids,
)
def test_split_rm(refshape_chunks_dim, in_mem_config, out_mem_config, device, dtype=ttnn.bfloat16):
    (refshape, num_splits, dim) = refshape_chunks_dim
    profile_location = "split_rm/"
    os.system(f"rm -rf {profile_location}")

    torch.manual_seed(1234)

    Z = refshape[0]
    Y = refshape[1]
    X = refshape[2]

    assert dim in [0, 1, 2]

    if dim == 2:
        chunk_shape = [Z, Y, X // num_splits]
    elif dim == 1:
        chunk_shape = [Z, Y // num_splits, X]
    elif dim == 0:
        chunk_shape = [Z // num_splits, Y, X]

    logger.info(f"Split tensor of size: {str(refshape)}")

    dtype_torch = torch.bfloat16

    A = torch.arange(Z * Y * X, dtype=dtype_torch).reshape(refshape)
    assert list(A.size()) == refshape

    a_t = ttnn.from_torch(A, layout=ttnn.Layout.ROW_MAJOR, dtype=dtype, memory_config=in_mem_config, device=device)

    # Check memory of inputs and outputs
    # logger.debug(f"input to rm split is on: {a_t.memory_config().buffer_type}")

    dev_buffers = ttnn.split(a_t, num_splits, dim, memory_config=out_mem_config)
    # dev_buffers come out tilized
    pyt_buff_list = []

    assert len(dev_buffers) == num_splits
    for index, buff in enumerate(dev_buffers):
        logger.debug(f"buff{index} is on: {buff.memory_config().buffer_type}")
        assert list(buff.shape) == chunk_shape
        tt_host_rm_buff = (
            buff.cpu().to(ttnn.ROW_MAJOR_LAYOUT).unpad_from_tile(buff.get_legacy_shape().without_padding())
        )
        pyt_got_back_rm_buff = tt_host_rm_buff.to_torch()
        pyt_buff_list.append(pyt_got_back_rm_buff)

    golden_buffers = torch.chunk(A, num_splits, dim=dim)
    assert len(pyt_buff_list) == len(golden_buffers)
    if debug:
        for i in range(0, num_splits):
            print(f"torch result [{i+1}]: ", golden_buffers[i][0, 0, 0])
            print(f"our result [{i+1}]: ", pyt_buff_list[i][0, 0, 0])
            print()

    for index, pyt_buff in enumerate(pyt_buff_list):
        golden_buff = golden_buffers[index]
        passing_pcc, output_pcc = comp_pcc(pyt_buff, golden_buff, 1.0)
        logger.debug(f"Out passing={passing_pcc}")
        logger.debug(f"Output pcc={output_pcc}")
        assert passing_pcc
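
For orientation, here is a minimal usage sketch of the row-major path this test exercises. It is illustrative only: the shape, memory config, and the open_device/close_device handling are assumptions, not part of the commit.

import torch
import ttnn

# Illustrative sketch: split a row-major bfloat16 tensor into two chunks along the
# last dim and compare against torch.chunk (shape and device setup are assumptions).
device = ttnn.open_device(device_id=0)
mem_config = ttnn.MemoryConfig(ttnn.TensorMemoryLayout.INTERLEAVED, ttnn.BufferType.DRAM)

torch_input = torch.arange(1 * 8 * 2, dtype=torch.bfloat16).reshape(1, 8, 2)
tt_input = ttnn.from_torch(
    torch_input, layout=ttnn.Layout.ROW_MAJOR, dtype=ttnn.bfloat16, memory_config=mem_config, device=device
)

chunks = ttnn.split(tt_input, 2, 2, memory_config=mem_config)  # two chunks along dim 2
golden = torch.chunk(torch_input, 2, dim=2)
assert len(chunks) == len(golden)

ttnn.close_device(device)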
ttnn/cpp/ttnn/operations/data_movement/split/split.cpp — 83 changes: 78 additions & 5 deletions
@@ -4,20 +4,89 @@


#include "ttnn/common/constants.hpp"
#include "ttnn/operations/core/core.hpp"
#include "ttnn/run_operation.hpp"
#include "device/split_op.hpp"
#include "ttnn/cpp/ttnn/operations/data_movement/reshape/reshape.hpp"
#include "ttnn/operations/data_movement/transpose/transpose.hpp"
#include "ttnn/tensor/types.hpp"
#include "ttnn/operations/data_movement/split/split.hpp"
#include "ttnn/operations/data_movement/slice/slice.hpp"


namespace ttnn::operations::data_movement {


namespace detail {

-std::vector<Tensor> impl_split_last_dim_two_chunks_tiled(const Tensor &input_tensor, const MemoryConfig &mem_config) {
+std::vector<Tensor> split_dim_n_chunks_rm(const Tensor &input_tensor, int dim, int num_splits, const MemoryConfig &mem_config) {
    TT_FATAL(input_tensor.get_layout() == Layout::ROW_MAJOR, "ttnn.split only supports row major tensors.");
    TT_FATAL(
        input_tensor.get_shape()[dim] % num_splits == 0,
        "Split dimension {} must be divisible by num_splits {}.",
        input_tensor.get_shape()[dim],
        num_splits);
    auto input_shape = input_tensor.get_shape();
    auto input_rank = input_shape.size();

    const bool on_host =
        input_tensor.storage_type() == StorageType::OWNED || input_tensor.storage_type() == StorageType::BORROWED;
    std::optional<Device *> device = on_host ? std::nullopt : std::make_optional(input_tensor.device());

    Tensor preprocessed = ttnn::unsqueeze_to_4D(input_tensor);  // ensure we're 4D before slicing
    dim += 4 - input_rank;  // convert to 4D index

    if (!on_host && input_tensor.get_dtype() == DataType::BFLOAT16) {
        preprocessed = preprocessed.cpu();  // bf16 tensors must be handled on host due to limitations in slice
    }

    auto preproc_shape = preprocessed.get_shape();

    auto chunk_len = preproc_shape[dim] / num_splits;

    std::vector<Tensor> output_tensors;
    output_tensors.reserve(num_splits);

    for (int i = 0; i < num_splits; i++) {
        auto start = i * chunk_len;
        auto end = start + chunk_len - 1;
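        // end coordinates are inclusive here (note the "- 1"), matching the
        // preproc_shape[j] - 1 full-extent bounds built below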

        std::vector<uint32_t> start_shape(preproc_shape.size(), 0);
        start_shape[dim] = start;

        std::vector<uint32_t> end_shape(preproc_shape.size());
        for (int j = 0; j < end_shape.size(); j++) {
            if (j == dim) {
                end_shape[j] = end;
            } else {
                end_shape[j] = preproc_shape[j] - 1;
            }
        }

        Tensor output_chunk = ttnn::slice(
            preprocessed,
            tt::tt_metal::LegacyShape(start_shape),
            tt::tt_metal::LegacyShape(end_shape),
            std::nullopt,
            mem_config);
        if (input_rank < 4) {
            output_chunk = ttnn::squeeze_from_4D(output_chunk, input_rank);
        }

        tt::tt_metal::Layout layout = input_tensor.get_layout();
        if (device && (input_tensor.dtype() == DataType::BFLOAT16 || input_tensor.dtype() == DataType::UINT16) &&
            chunk_len % 2 != 0) {
            // bf16 and uint16 tensors must be tiled if the chunk length is odd due to packing constraints
            layout = Layout::TILE;
            output_chunk = output_chunk.pad_to_tile(0.0);
        }

        output_chunk = output_chunk.to(layout);

        if (device) {
            output_chunk = output_chunk.to(*device);
        }

        output_tensors.push_back(output_chunk);
    }

    return output_tensors;
}

std::vector<Tensor> impl_split_last_dim_two_chunks_tiled(const Tensor &input_tensor, const MemoryConfig &mem_config) {
    auto input_shape = input_tensor.get_legacy_shape();
    auto padded_input_shape = ttnn::operations::experimental::auto_format::AutoFormat::pad_to_tile_shape(input_shape);
    ttnn::operations::experimental::auto_format::FormatParams input_format_params = {
        .pad_shape = padded_input_shape, .pad_value = 0.0, .target_layout = Layout::TILE};
@@ -45,8 +114,9 @@ namespace detail {
}


-std::vector<Tensor> split_dim_two_chunks_tiled(
-    const Tensor &input_tensor, int dim /* = 3 */, const MemoryConfig &mem_config /* = default */) {
+std::vector<Tensor> split_dim_n_chunks_tiled(
+    const Tensor &input_tensor, int dim /* = 3 */, int num_splits, const MemoryConfig &mem_config /* = default */) {
+    TT_FATAL(num_splits == 2, "ttnn.split currently only supports splitting in 2 in tiled layout, but {} was passed", num_splits);
    if (dim == 3) {
        return split_last_dim_two_chunks_tiled(input_tensor, mem_config);
    }
@@ -71,9 +141,12 @@ std::vector<ttnn::Tensor> SplitOperation::invoke(
    const std::optional<MemoryConfig>& memory_config_arg) {

    auto memory_config = memory_config_arg.value_or(input_tensor.memory_config());
-    TT_FATAL(num_splits == 2, "Currently only supporting split in 2");
-    return detail::split_dim_two_chunks_tiled(input_tensor, dim, memory_config);
+    if (input_tensor.get_layout() == Layout::ROW_MAJOR) {
+        return detail::split_dim_n_chunks_rm(input_tensor, dim, num_splits, memory_config);
+    } else {
+        return detail::split_dim_n_chunks_tiled(input_tensor, dim, num_splits, memory_config);
+    }
}

std::vector<ttnn::Tensor> SplitOperation::invoke(

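For intuition about the slice coordinates computed in split_dim_n_chunks_rm above, here is a small Python sketch (illustrative only, not part of the commit) that mirrors the C++ loop; note the end coordinates are inclusive, matching the preproc_shape[j] - 1 pattern:

def chunk_bounds(shape, dim, num_splits):
    # Mirror of the C++ loop: per-chunk (start, end) coordinates with inclusive ends.
    assert shape[dim] % num_splits == 0
    chunk_len = shape[dim] // num_splits
    bounds = []
    for i in range(num_splits):
        start = [0] * len(shape)
        end = [s - 1 for s in shape]  # full extent, inclusive
        start[dim] = i * chunk_len
        end[dim] = start[dim] + chunk_len - 1
        bounds.append((start, end))
    return bounds

# For a [1, 8, 2] tensor split in two along dim 1:
# [([0, 0, 0], [0, 3, 1]), ([0, 4, 0], [0, 7, 1])]
print(chunk_bounds([1, 8, 2], 1, 2))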