#12938: squeeze fix for dim == 0 with padding
ntarafdar committed Sep 20, 2024
1 parent e23e894 commit a87fbec
Showing 2 changed files with 25 additions and 13 deletions.
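Background: a ttnn tensor in a tiled layout carries both a logical shape and a tile-padded shape. Before this fix, squeeze rebuilt only the logical shape, so squeezing dim 0 of a padded tensor lost the padding information; the change below drops the squeezed dimension from the logical and padded shapes together. As a minimal sketch (illustrative only; the function name is hypothetical and does not appear in the commit), the new shape logic in squeeze.cpp behaves like this:

    def squeeze_shapes(logical_shape, padded_shape, dim):
        # Normalize a negative dim, as the C++ does with normal_dim.
        rank = len(logical_shape)
        if dim < 0:
            dim += rank
        # Out-of-range dim, or a dim that is not of size 1: shapes are unchanged
        # (the C++ returns the input tensor untouched in this case).
        if dim >= rank or logical_shape[dim] != 1:
            return list(logical_shape), list(padded_shape)
        # Drop index dim from BOTH shapes; the old code rebuilt only the logical one.
        out_logical = [logical_shape[i] for i in range(rank) if i != dim]
        out_padded = [padded_shape[i] for i in range(rank) if i != dim]
        return out_logical, out_padded

    # Example matching the new ((1, 32, 16), 0) test case, assuming 32x32 tiles:
    # squeeze_shapes([1, 32, 16], [1, 32, 32], 0) -> ([32, 16], [32, 32])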
11 changes: 10 additions & 1 deletion tests/ttnn/unit_tests/test_squeeze.py
@@ -16,9 +16,18 @@
         ((1, 1, 1, 256), -1),
         ((1, 1, 1, 30), 2),
         ((1, 1, 1, 30), -1),
+        ((1, 32, 16), 0),
+        ((1, 1, 24576), 0),
+        ((1, 19), 0),
+        ((1, 1, 480, 640), 1),
+        ((3, 1370, 1, 1, 1280), -2),
+        ((3, 197, 1, 1, 1024), -2),
+        ((3, 197, 1, 1, 768), -2),
+        ((3, 50, 1, 1, 1024), -2),
+        ((3, 50, 1, 1, 768), -2),
     ],
 )
-def test_squeeze_as_reshape(device, input_shape, dim):
+def test_squeeze(device, input_shape, dim):
     torch_input_tensor = torch.rand(input_shape, dtype=torch.float32)
     torch_squeeze_tensor = torch.squeeze(torch_input_tensor, dim)
     input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.ROW_MAJOR_LAYOUT, device=device)
27 changes: 15 additions & 12 deletions ttnn/cpp/ttnn/operations/data_movement/squeeze/squeeze.cpp
@@ -13,30 +13,33 @@ ttnn::Tensor SqueezeOperation::invoke(
     const int dim
     ) {
 
-    const auto tensor_shape = input_tensor.get_shape();
-    const auto rank = tensor_shape.rank();
-    std::vector<uint32_t> output_shape_vector;
+    const auto original_logical_shape = input_tensor.get_shape();
+    const auto padded_shape = input_tensor.get_shape().with_tile_padding();
+    const auto input_tensor_rank = original_logical_shape.rank();
 
     int normal_dim = dim;
     if (dim < 0) {
         // Handle negative dimension by converting it to positive
-        normal_dim += rank;
+        normal_dim += input_tensor_rank;
     }
 
-    // Remove the dimension if it is of size 1
-    for (size_t i = 0; i < tensor_shape.size(); ++i) {
-        if (static_cast<int>(i) != normal_dim || tensor_shape[i] != 1) {
-            output_shape_vector.push_back(tensor_shape[i]);
+    std::vector<uint32_t> original_logical_shape_vector(input_tensor_rank - 1);
+    std::vector<uint32_t> padded_shape_vector(input_tensor_rank - 1);
+    uint32_t vector_id = 0;
+    for(int i=0; i< input_tensor_rank; i++) {
+        if(i != normal_dim or original_logical_shape[i] != 1) {
+            original_logical_shape_vector[vector_id] = original_logical_shape[i];
+            padded_shape_vector[vector_id] = padded_shape[i];
+            vector_id++;
         }
     }
 
-    // If dim is out of range or original dimension was not of size 1, include all dimensions
-    if (dim >= static_cast<int>(tensor_shape.size()) || tensor_shape[dim] != 1) {
+    //// If dim is out of range or original dimension was not of size 1, include all dimensions
+    if (normal_dim >= static_cast<int>(original_logical_shape.size()) || original_logical_shape[normal_dim] != 1) {
         return input_tensor;
     }
 
-    ttnn::Shape output_shape(output_shape_vector);
-    return ttnn::reshape(input_tensor, output_shape);
+    return ttnn::reshape(input_tensor, ttnn::Shape(original_logical_shape_vector, padded_shape_vector));
 
 }

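For context, a hedged usage sketch in the spirit of the updated unit test. The test above uses ROW_MAJOR_LAYOUT; TILE_LAYOUT is used here instead so that tile padding is actually present, and the device setup is an assumption, not part of the commit:

    import torch
    import ttnn

    # Assumes a Tenstorrent device is available; not part of the commit.
    device = ttnn.open_device(device_id=0)

    # Squeeze dim 0 of a tensor whose tile-padded shape is (1, 32, 32).
    torch_input = torch.rand((1, 32, 16), dtype=torch.float32)
    torch_expected = torch.squeeze(torch_input, 0)  # shape (32, 16)

    input_tensor = ttnn.from_torch(torch_input, layout=ttnn.TILE_LAYOUT, device=device)
    output_tensor = ttnn.squeeze(input_tensor, 0)

    # With this fix, the padded shape drops dim 0 along with the logical shape,
    # so the round trip matches torch's result.
    assert ttnn.to_torch(output_tensor).shape == torch_expected.shape

    ttnn.close_device(device)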
