#12729: add row-major split for BERT #12804

Merged Sep 19, 2024 (7 commits)
@@ -0,0 +1,118 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

from pathlib import Path
import itertools
import os
import random
import sys

import numpy as np
import pytest
import torch
from loguru import logger

import ttnn
from models.utility_functions import (
    comp_pcc,
)

debug = False

# TODO: test other dims/num_splits
two_chunk_dim_two_tests = list(
    zip(
        [[1, 2**k, 2] for k in range(10)],  # shapes
        itertools.repeat(2),  # chunks
        itertools.repeat(2),  # dim
    )
)

two_chunk_dim_two_ids = [
    "x".join(map(str, shape)) + f"@{dim}" + f"->{chunks}" for (shape, chunks, dim) in two_chunk_dim_two_tests
]
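# e.g. shape [1, 4, 2] split into 2 chunks along dim 2 yields the id "1x4x2@2->2"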


@pytest.mark.parametrize(
    "in_mem_config",
    (
        ttnn.MemoryConfig(ttnn.TensorMemoryLayout.INTERLEAVED, ttnn.BufferType.DRAM),
        ttnn.MemoryConfig(ttnn.TensorMemoryLayout.INTERLEAVED, ttnn.BufferType.L1),
    ),
    ids=["in_DRAM", "in_L1"],
)
@pytest.mark.parametrize(
    "out_mem_config",
    (
        ttnn.MemoryConfig(ttnn.TensorMemoryLayout.INTERLEAVED, ttnn.BufferType.DRAM),
        ttnn.MemoryConfig(ttnn.TensorMemoryLayout.INTERLEAVED, ttnn.BufferType.L1),
    ),
    ids=["out_DRAM", "out_L1"],
)
@pytest.mark.parametrize(
    "refshape_chunks_dim",
    tuple(two_chunk_dim_two_tests),
    ids=two_chunk_dim_two_ids,
)
def test_split_rm(refshape_chunks_dim, in_mem_config, out_mem_config, device, dtype=ttnn.bfloat16):
    (refshape, num_splits, dim) = refshape_chunks_dim
    profile_location = "split_rm/"
    os.system(f"rm -rf {profile_location}")

    torch.manual_seed(1234)

    Z = refshape[0]
    Y = refshape[1]
    X = refshape[2]

    assert dim in [0, 1, 2]

    if dim == 2:
        chunk_shape = [Z, Y, X // num_splits]
    elif dim == 1:
        chunk_shape = [Z, Y // num_splits, X]
    elif dim == 0:
        chunk_shape = [Z // num_splits, Y, X]

    logger.info(f"Split tensor of size: {str(refshape)}")

    dtype_torch = torch.bfloat16

    A = torch.arange(Z * Y * X, dtype=dtype_torch).reshape(refshape)
    assert list(A.size()) == refshape

    a_t = ttnn.from_torch(A, layout=ttnn.Layout.ROW_MAJOR, dtype=dtype, memory_config=in_mem_config, device=device)

    # Check memory of inputs and outputs
    # logger.debug(f"input to rm split is on: {a_t.memory_config().buffer_type}")

    dev_buffers = ttnn.split(a_t, num_splits, dim, memory_config=out_mem_config)
    # dev_buffers come out tilized
    pyt_buff_list = []

    assert len(dev_buffers) == num_splits
    for index, buff in enumerate(dev_buffers):
        logger.debug(f"buff{index} is on: {buff.memory_config().buffer_type}")
        assert list(buff.shape) == chunk_shape
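        # Outputs may come back tile-padded (the C++ path tilizes odd-length bf16/uint16
        # chunks), so round-trip to host row-major and strip tile padding before comparing.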
        tt_host_rm_buff = (
            buff.cpu().to(ttnn.ROW_MAJOR_LAYOUT).unpad_from_tile(buff.get_legacy_shape().without_padding())
        )
        pyt_got_back_rm_buff = tt_host_rm_buff.to_torch()
        pyt_buff_list.append(pyt_got_back_rm_buff)

    golden_buffers = torch.chunk(A, num_splits, dim=dim)
    assert len(pyt_buff_list) == len(golden_buffers)
    if debug:
        for i in range(0, num_splits):
            print(f"torch result [{i+1}]: ", golden_buffers[i][0, 0, 0])
            print(f"our result [{i+1}]: ", pyt_buff_list[i][0, 0, 0])
            print()

    for index, pyt_buff in enumerate(pyt_buff_list):
        golden_buff = golden_buffers[index]
        passing_pcc, output_pcc = comp_pcc(pyt_buff, golden_buff, 1.0)
        logger.debug(f"Out passing={passing_pcc}")
        logger.debug(f"Output pcc={output_pcc}")
        assert passing_pcc
83 changes: 78 additions & 5 deletions ttnn/cpp/ttnn/operations/data_movement/split/split.cpp
@@ -4,20 +4,89 @@


#include "ttnn/common/constants.hpp"
#include "ttnn/operations/core/core.hpp"
#include "ttnn/run_operation.hpp"
#include "device/split_op.hpp"
#include "ttnn/cpp/ttnn/operations/data_movement/reshape/reshape.hpp"
#include "ttnn/operations/data_movement/transpose/transpose.hpp"
#include "ttnn/tensor/types.hpp"
#include "ttnn/operations/data_movement/split/split.hpp"
#include "ttnn/operations/data_movement/slice/slice.hpp"


namespace ttnn::operations::data_movement {


namespace detail {

std::vector<Tensor> split_dim_n_chunks_rm(const Tensor &input_tensor, int dim, int num_splits, const MemoryConfig &mem_config) {
    TT_FATAL(input_tensor.get_layout() == Layout::ROW_MAJOR, "ttnn.split only supports row major tensors.");
    TT_FATAL(input_tensor.get_shape()[dim] % num_splits == 0, "Split dimension {} must be divisible by num_splits {}.", input_tensor.get_shape()[dim], num_splits);
    auto input_shape = input_tensor.get_shape();
    auto input_rank = input_shape.size();

    const bool on_host = input_tensor.storage_type() == StorageType::OWNED || input_tensor.storage_type() == StorageType::BORROWED;
    std::optional<Device *> device = on_host ? std::nullopt : std::make_optional(input_tensor.device());

    Tensor preprocessed = ttnn::unsqueeze_to_4D(input_tensor);  // ensure we're 4D before slicing
    dim += 4 - input_rank;  // convert to 4D index

    if (!on_host && input_tensor.get_dtype() == DataType::BFLOAT16) {
        preprocessed = preprocessed.cpu();  // bf16 tensors must be handled on host due to limitations in slice
    }

    auto preproc_shape = preprocessed.get_shape();

    auto chunk_len = preproc_shape[dim] / num_splits;

    std::vector<Tensor> output_tensors;
    output_tensors.reserve(num_splits);

    for (int i = 0; i < num_splits; i++) {
        auto start = i * chunk_len;
        auto end = start + chunk_len - 1;
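        // Worked example with hypothetical values: preproc_shape = [1, 1, 8, 2], dim = 2,
        // num_splits = 2 gives chunk_len = 4, so chunk 0 spans rows 0..3 and chunk 1
        // spans rows 4..7 (the slice bounds built below are inclusive).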

        std::vector<uint32_t> start_shape(preproc_shape.size(), 0);
        start_shape[dim] = start;

        std::vector<uint32_t> end_shape(preproc_shape.size());
        for (int j = 0; j < end_shape.size(); j++) {
            if (j == dim) {
                end_shape[j] = end;
            } else {
                end_shape[j] = preproc_shape[j] - 1;
            }
        }

        Tensor output_chunk = ttnn::slice(preprocessed,
            tt::tt_metal::LegacyShape(start_shape),
            tt::tt_metal::LegacyShape(end_shape),
            std::nullopt,
            mem_config);
        if (input_rank < 4) {
            output_chunk = ttnn::squeeze_from_4D(output_chunk, input_rank);
        }

        tt::tt_metal::Layout layout = input_tensor.get_layout();
        if (device && (input_tensor.dtype() == DataType::BFLOAT16 || input_tensor.dtype() == DataType::UINT16)
            && chunk_len % 2 != 0) {
            layout = Layout::TILE;  // bf16 and uint16 tensors must be tiled if the chunk length is odd due to packing constraints
            output_chunk = output_chunk.pad_to_tile(0.0);
        }

        output_chunk = output_chunk.to(layout);

        if (device) {
            output_chunk = output_chunk.to(*device);
        }

        output_tensors.push_back(output_chunk);
    }

    return output_tensors;
}

std::vector<Tensor> impl_split_last_dim_two_chunks_tiled(const Tensor &input_tensor, const MemoryConfig &mem_config) {
    auto input_shape = input_tensor.get_legacy_shape();
    auto padded_input_shape = ttnn::operations::experimental::auto_format::AutoFormat::pad_to_tile_shape(input_shape);
    ttnn::operations::experimental::auto_format::FormatParams input_format_params = {.pad_shape = padded_input_shape, .pad_value = 0.0, .target_layout = Layout::TILE};
@@ -45,8 +114,9 @@ namespace detail {
}


-std::vector<Tensor> split_dim_two_chunks_tiled(
-    const Tensor &input_tensor, int dim /* = 3 */, const MemoryConfig &mem_config /* = default */) {
+std::vector<Tensor> split_dim_n_chunks_tiled(
+    const Tensor &input_tensor, int dim /* = 3 */, int num_splits, const MemoryConfig &mem_config /* = default */) {
+    TT_FATAL(num_splits == 2, "ttnn.split currently only supports split in 2 in tiled layout, but {} is passed", num_splits);
    if (dim == 3) {
        return split_last_dim_two_chunks_tiled(input_tensor, mem_config);
    }
@@ -71,9 +141,12 @@ std::vector<ttnn::Tensor> SplitOperation::invoke(
    const std::optional<MemoryConfig>& memory_config_arg) {

    auto memory_config = memory_config_arg.value_or(input_tensor.memory_config());
-    TT_FATAL(num_splits == 2, "Currently only supporting split in 2");
-    return detail::split_dim_two_chunks_tiled(input_tensor, dim, memory_config);
+    if (input_tensor.get_layout() == Layout::ROW_MAJOR) {
+        return detail::split_dim_n_chunks_rm(input_tensor, dim, num_splits, memory_config);
+    } else {
+        return detail::split_dim_n_chunks_tiled(input_tensor, dim, num_splits, memory_config);
+    }
}

std::vector<ttnn::Tensor> SplitOperation::invoke(
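For context, here is a minimal sketch of how the new row-major path can be driven end to end from Python, in the spirit of the test above. The device id, tensor shape, and default memory configs are illustrative assumptions, not values taken from this PR:

import torch
import ttnn

device = ttnn.open_device(device_id=0)

# Row-major bf16 tensor; dim 2 has length 64, so two splits give chunks of length 32.
x = torch.arange(2 * 32 * 64, dtype=torch.bfloat16).reshape(2, 32, 64)
x_tt = ttnn.from_torch(x, layout=ttnn.Layout.ROW_MAJOR, dtype=ttnn.bfloat16, device=device)

# With this PR, a row-major input dispatches to detail::split_dim_n_chunks_rm,
# which slices the tensor into num_splits chunks along the given dim.
chunks = ttnn.split(x_tt, 2, 2)

golden = torch.chunk(x, 2, dim=2)
for tt_chunk, ref in zip(chunks, golden):
    assert list(tt_chunk.shape) == list(ref.shape)

ttnn.close_device(device)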