Skip to content

Commit

Permalink
#0: add host fallback when datatype is 4 or 8 bit which is not compat…
Browse files Browse the repository at this point in the history
…ible with row major
  • Loading branch information
jvegaTT committed Nov 19, 2024
1 parent 842d3d7 commit b7370c3
Showing 1 changed file with 11 additions and 3 deletions.
14 changes: 11 additions & 3 deletions ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,15 @@ ttnn::Tensor convert_tensor_to_rm_reshape_convert_back_to_orig_layout(const ttnn
auto tensor_shape_with_padding = tensor_shape.padded_shape();

//Constraint in device kernel

uint32_t ROW_MAJOR_WIDTH = 32/tensor.element_size();
ttnn::Tensor reshaped_rm_tensor;
if((tensor_shape[-1] % ROW_MAJOR_WIDTH == 0 && shape[-1] % ROW_MAJOR_WIDTH == 0)) {
if(tensor.element_size()<2 && layout == ttnn::TILE_LAYOUT)
{
//Can't call to_layout on 4b and 8b datatypes
reshaped_rm_tensor = host_reshape(tensor, shape);
}
else if((tensor_shape[-1] % ROW_MAJOR_WIDTH == 0 && shape[-1] % ROW_MAJOR_WIDTH == 0)) {
auto rm_tensor = ttnn::to_layout(tensor, ttnn::ROW_MAJOR_LAYOUT, std::nullopt, std::nullopt, (Device *)nullptr);
if (rm_tensor.is_contiguous()) {
// Page size depends on the width, so only modify the shape if the width is the same
Expand Down Expand Up @@ -105,11 +111,13 @@ ttnn::Shape tiling_reshape_corrector(const ttnn::Shape& shape) {
auto padded = shape.with_tile_padding();
auto rank = shape.rank();
const int8_t correction_1 =(ttnn::types::TILE_SIZE - (int)padded[-1] % ttnn::types::TILE_SIZE) % ttnn::types::TILE_SIZE;
if(rank == 1)
{
return ttnn::Shape({shape[0]},{padded[0]+correction_1});
}
const int8_t correction_2 =(ttnn::types::TILE_SIZE - (int)padded[-2] % ttnn::types::TILE_SIZE) % ttnn::types::TILE_SIZE;
switch(rank)
{
case 1:
return ttnn::Shape({shape[0]},{padded[0]+correction_1});
case 2:
return ttnn::Shape({shape[0],shape[1]},{padded[0]+correction_2,padded[1]+correction_1});
break;
Expand Down

0 comments on commit b7370c3

Please sign in to comment.