#12729: add row-major split for BERT #12804

Merged · 7 commits · Sep 19, 2024
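This PR adds a row-major code path to ttnn.split, so row-major tensors (as used in BERT) can be split without the caller tilizing them first. A minimal usage sketch modeled on the new test below; the device id and tensor shape are illustrative, not taken from the PR:

import torch
import ttnn

device = ttnn.open_device(device_id=0)

# Build a small row-major bfloat16 tensor on the device.
torch_input = torch.rand(1, 64, 2, dtype=torch.bfloat16)
tt_input = ttnn.from_torch(
    torch_input,
    dtype=ttnn.bfloat16,
    layout=ttnn.Layout.ROW_MAJOR,
    device=device,
)

# Split into two chunks along the last dimension, mirroring torch.chunk.
chunks = ttnn.split(tt_input, 2, 2)
assert len(chunks) == 2  # each chunk has shape [1, 64, 1]

ttnn.close_device(device)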
@@ -0,0 +1,118 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import itertools
import os

import pytest
import torch
from loguru import logger

import ttnn
from models.utility_functions import comp_pcc

debug = False

# TODO: test other dims/num_splits
two_chunk_dim_two_tests = list(
zip(
[[1, 2**k, 2] for k in range(10)], # shapes
itertools.repeat(2), # chunks
itertools.repeat(2), # dim
)
)

two_chunk_dim_two_ids = [
"x".join(map(str, shape)) + f"@{dim}" + f"->{chunks}" for (shape, chunks, dim) in two_chunk_dim_two_tests
]
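# e.g. shape [1, 8, 2] split into 2 chunks along dim 2 gets the id '1x8x2@2->2'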


@pytest.mark.parametrize(
"in_mem_config",
(
ttnn.MemoryConfig(ttnn.TensorMemoryLayout.INTERLEAVED, ttnn.BufferType.DRAM),
ttnn.MemoryConfig(ttnn.TensorMemoryLayout.INTERLEAVED, ttnn.BufferType.L1),
),
ids=["in_DRAM", "in_L1"],
)
@pytest.mark.parametrize(
"out_mem_config",
(
ttnn.MemoryConfig(ttnn.TensorMemoryLayout.INTERLEAVED, ttnn.BufferType.DRAM),
ttnn.MemoryConfig(ttnn.TensorMemoryLayout.INTERLEAVED, ttnn.BufferType.L1),
),
ids=["out_DRAM", "out_L1"],
)
@pytest.mark.parametrize(
"refshape_chunks_dim",
tuple(two_chunk_dim_two_tests),
ids=two_chunk_dim_two_ids,
)
def test_split_rm(refshape_chunks_dim, in_mem_config, out_mem_config, device, dtype=ttnn.bfloat16):
(refshape, num_splits, dim) = refshape_chunks_dim
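# remove any leftover profiler output from previous runs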
profile_location = "split_rm/"
os.system(f"rm -rf {profile_location}")

torch.manual_seed(1234)

Z = refshape[0]
Y = refshape[1]
X = refshape[2]

assert dim in [0, 1, 2]

if dim == 2:
chunk_shape = [Z, Y, X // num_splits]
elif dim == 1:
chunk_shape = [Z, Y // num_splits, X]
elif dim == 0:
chunk_shape = [Z // num_splits, Y, X]

logger.info(f"Split tensor of size: {refshape}")

dtype_torch = torch.bfloat16

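# arange gives every element a distinct value, so misplaced chunk boundaries show up immediately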
A = torch.arange(Z * Y * X, dtype=dtype_torch).reshape(refshape)
assert list(A.size()) == refshape

a_t = ttnn.from_torch(A, layout=ttnn.Layout.ROW_MAJOR, dtype=dtype, memory_config=in_mem_config, device=device)

# Check memory of inputs and outputs
# logger.debug(f"input to rm split is on: {a_t.memory_config().buffer_type}")

dev_buffers = ttnn.split(a_t, num_splits, dim, memory_config=out_mem_config)
# dev_buffers come out tilized
pyt_buff_list = []

assert len(dev_buffers) == num_splits
for index, buff in enumerate(dev_buffers):
logger.debug(f"buff{index} is on: {buff.memory_config().buffer_type}")
assert list(buff.shape) == chunk_shape
tt_host_rm_buff = (
buff.cpu().to(ttnn.ROW_MAJOR_LAYOUT).unpad_from_tile(buff.get_legacy_shape().without_padding())
)
pyt_got_back_rm_buff = tt_host_rm_buff.to_torch()
pyt_buff_list.append(pyt_got_back_rm_buff)

golden_buffers = torch.chunk(A, num_splits, dim=dim)
assert len(pyt_buff_list) == len(golden_buffers)
if debug:
for i in range(0, num_splits):
print(f"torch result [{i+1}]: ", golden_buffers[i][0, 0, 0])
print(f"our result [{i+1}]: ", pyt_buff_list[i][0, 0, 0])
print()

for index, pyt_buff in enumerate(pyt_buff_list):
golden_buff = golden_buffers[index]
passing_pcc, output_pcc = comp_pcc(pyt_buff, golden_buff, 1.0)
logger.debug(f"Out passing={passing_pcc}")
logger.debug(f"Output pcc={output_pcc}")
assert passing_pcc
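The new cases run under plain pytest; the file path below is hypothetical, since the diff view does not show where the test file lives:

# path assumed for illustration -- substitute the test file's real location
pytest tests/ttnn/unit_tests/operations/test_split_rm.py -k "in_DRAM and out_L1"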
ttnn/cpp/ttnn/operations/data_movement/split/split.cpp (77 additions, 3 deletions)

@@ -4,20 +4,90 @@


#include "ttnn/common/constants.hpp"
#include "ttnn/operations/core/core.hpp"
#include "ttnn/run_operation.hpp"
#include "device/split_op.hpp"
#include "ttnn/cpp/ttnn/operations/data_movement/reshape/reshape.hpp"
#include "ttnn/operations/data_movement/transpose/transpose.hpp"
#include "ttnn/tensor/types.hpp"
#include "ttnn/operations/data_movement/split/split.hpp"
#include "ttnn/operations/data_movement/slice/slice.hpp"


namespace ttnn::operations::data_movement {


namespace detail {

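// Splits a row-major tensor into num_splits equal chunks along dim via ttnn::slice.
// bf16 device tensors are staged through the host, and odd-length bf16/uint16
// chunks are padded and tilized before being moved back to the device; see the
// step comments below.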
std::vector<Tensor> split_dim_n_chunks_rm(const Tensor &input_tensor, int dim, int num_splits, const MemoryConfig &mem_config) {
TT_FATAL(input_tensor.get_layout() == Layout::ROW_MAJOR, "This op only supports row major tensors.");
TT_FATAL(input_tensor.get_shape()[dim] % num_splits == 0, "Split dimension must be divisible by num_splits.");
auto input_shape = input_tensor.get_shape();
auto input_rank = input_shape.size();

const bool on_host = input_tensor.storage_type() == StorageType::OWNED || input_tensor.storage_type() == StorageType::BORROWED;
std::optional<Device *> dev = on_host ? std::nullopt : std::make_optional(input_tensor.device());

Tensor preprocessed = Tensor(input_tensor);
preprocessed = ttnn::unsqueeze_to_4D(preprocessed); // ensure we're 4D before slicing
dim += 4 - input_rank; // convert to 4D index

if (!on_host && input_tensor.get_dtype() == DataType::BFLOAT16) {
preprocessed = preprocessed.cpu(); // bf16 tensors must be handled on host due to limitations in slice
}

auto preproc_shape = preprocessed.get_shape();

auto chunk_len = preproc_shape[dim] / num_splits;

std::vector<Tensor> output_tensors;
output_tensors.reserve(num_splits);

for (int i = 0; i < num_splits; i++) {
auto start = i*chunk_len;
auto end = start + chunk_len - 1;

std::vector<uint32_t> start_shape(preproc_shape.size(), 0);
start_shape[dim] = start;

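// slice coordinates are inclusive here, hence the -1 on every end index below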
std::vector<uint32_t> end_shape(preproc_shape.size());
for (int j = 0; j < end_shape.size(); j++) {
if (j == dim) {
end_shape[j] = end;
} else {
end_shape[j] = preproc_shape[j] - 1;
}
}

Tensor output_chunk = ttnn::slice(preprocessed,
tt::tt_metal::Shape(start_shape),
tt::tt_metal::Shape(end_shape),
std::nullopt,
mem_config);
if (input_rank < 4) {
output_chunk = ttnn::squeeze_from_4D(output_chunk, input_rank);
}

tt::tt_metal::Layout layout = input_tensor.get_layout();
if (dev && (input_tensor.dtype() == DataType::BFLOAT16 || input_tensor.dtype() == DataType::UINT16)
&& chunk_len % 2 != 0) {
layout = Layout::TILE; // bf16 and uint16 tensors must be tiled if the chunk length is odd due to packing constraints
output_chunk = output_chunk.pad_to_tile(0.0);
}

output_chunk = output_chunk.to(layout);

if (dev) {
output_chunk = output_chunk.to(*dev);
}

output_tensors.push_back(output_chunk);
}

return output_tensors;
}

std::vector<Tensor> impl_split_last_dim_two_chunks_tiled(const Tensor &input_tensor, const MemoryConfig &mem_config) {
auto input_shape = input_tensor.get_legacy_shape();
auto padded_input_shape = ttnn::operations::experimental::auto_format::AutoFormat::pad_to_tile_shape(input_shape);
ttnn::operations::experimental::auto_format::FormatParams input_format_params = {.pad_shape = padded_input_shape, .pad_value = 0.0, .target_layout = Layout::TILE};
@@ -71,9 +141,13 @@ std::vector<ttnn::Tensor> SplitOperation::invoke(
const std::optional<MemoryConfig>& memory_config_arg) {

auto memory_config = memory_config_arg.value_or(input_tensor.memory_config());

if (input_tensor.get_layout() == Layout::ROW_MAJOR) {
return detail::split_dim_n_chunks_rm(input_tensor, dim, num_splits, memory_config);
} else {
TT_FATAL(num_splits == 2, "Currently only supporting split in 2");
return detail::split_dim_two_chunks_tiled(input_tensor, dim, memory_config);
}
}

std::vector<ttnn::Tensor> SplitOperation::invoke(