
#12544: support wide channels (> 256) in maxpool #12625

Merged 1 commit from asarje/wide-maxpool into main on Sep 13, 2024

Conversation

mywoodstock (Contributor)

Ticket

#12544

Problem description

Maxpool was limited to at most 8 tiles per reduction (the DST register limit), which caps the supported channel width at 256.

What's changed

Added support for wide channels (> 256) by processing the channel dimension in blocks (in_nblocks_c), so each per-block reduction stays within the DST register limit; see the sketch below.
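For context, a minimal sketch of the blocking idea, using names that appear in the diff (in_ntiles_c, in_nblocks_c). TILE_WIDTH and MAX_TILES_PER_REDUCTION are illustrative names here, and the arithmetic is an assumption rather than a copy of the production code:

#include <cstdint>

// Sketch only: split the channel tiles into blocks so that each reduction
// fits within the 8-tile DST register limit that previously capped maxpool
// at 256 channels (8 tiles x 32-wide tiles).
constexpr uint32_t TILE_WIDTH = 32;               // width of one channel tile
constexpr uint32_t MAX_TILES_PER_REDUCTION = 8;   // DST register limit

inline uint32_t num_channel_blocks(uint32_t in_channels) {
    const uint32_t in_ntiles_c = (in_channels + TILE_WIDTH - 1) / TILE_WIDTH;
    return (in_ntiles_c + MAX_TILES_PER_REDUCTION - 1) / MAX_TILES_PER_REDUCTION;
}

// Example: 512 channels (VGG) -> 16 channel tiles -> 2 blocks of 8 tiles each;
// 256 channels or fewer -> 1 block, i.e. the previous single-shot behaviour.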

Checklist

  • Post commit CI passes
  • Blackhole Post commit (if applicable)
  • Model regression CI testing passes (if applicable)
  • Device performance regression CI testing passes (if applicable)
  • New/Existing tests provide coverage for changes

@@ -63,10 +63,13 @@ MaxPool2D::MultiCore::cached_program_t max_pool_2d_multi_core_sharded_with_halo_
 uint32_t in_ntiles_hw = (uint32_t)std::ceil((float)kernel_size_hw_padded / tt::constants::TILE_HEIGHT);
 uint32_t in_ntiles_c = (uint32_t)std::ceil((float)input_shape[3] / tt::constants::TILE_WIDTH);
 uint32_t out_ntiles_c = (uint32_t)std::ceil((float)output_shape[3] / tt::constants::TILE_WIDTH);
-uint32_t MAX_SMALL_KERNEL_SIZE_HW = 16;
+const uint32_t MAX_SMALL_KERNEL_SIZE_HW = 16;
Contributor:
nit: I prefer constexpr for this type of consts
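For illustration, the form the reviewer has in mind (an editorial sketch, not part of the diff):

// constexpr guarantees compile-time evaluation and lets the value be used
// in other constant expressions.
constexpr uint32_t MAX_SMALL_KERNEL_SIZE_HW = 16;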

constexpr bool is_partial_tile = in_c < 32;
static_assert((!is_partial_tile || (in_c == 16)), "Partial tile must have c_dim 16");
constexpr uint32_t num_faces_in_tile = is_partial_tile ? 1 : 2;
constexpr uint32_t num_out_rows = 1;

tilizeA_B_reduce_init<true>(in_cb_id, in_scalar_cb_id, in_ntiles_hwc, out_cb_id, num_faces_in_tile, window_size_hw);
uint32_t in_ntiles_hwc_block = in_ntiles_hwc / in_nblocks_c;
Contributor:
maybe add constexpr here
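A sketch of that suggestion; it assumes in_ntiles_hwc and in_nblocks_c are themselves compile-time constants in the kernel, which the diff does not show:

// Valid only if both operands are constexpr (e.g. derived from compile-time kernel args).
constexpr uint32_t in_ntiles_hwc_block = in_ntiles_hwc / in_nblocks_c;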

#include "compute_kernel_api/tilize.h"
#include "compute_kernel_api/reduce.h"
#include "compute_kernel_api/pack_untilize.h"
// #include "tools/profiler/kernel_profiler.hpp"

#define DEBUG_PRINT 0

#if DEBUG_PRINT == 1
Contributor:
feels like these should be in some common place for kernel debug utils

#endif
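A hedged sketch of what such a shared debug utility could look like; the header name and macro are hypothetical, not an existing file in the repo:

// kernel_debug_utils.hpp (hypothetical common header)
#pragma once

#ifndef DEBUG_PRINT
#define DEBUG_PRINT 0
#endif

#if DEBUG_PRINT == 1
    // Debug-only statements compile to real code...
    #define KERNEL_DEBUG(stmt) stmt
#else
    // ...and disappear entirely in normal builds.
    #define KERNEL_DEBUG(stmt)
#endif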

-template<uint32_t in_ntiles_hw, uint32_t in_ntiles_c, uint32_t out_ntiles_c, uint32_t nblocks, bool is_partial_tile, uint32_t split_reader, uint32_t unpA_face_r_dim>
+template<uint32_t in_ntiles_hw, uint32_t in_ntiles_c, uint32_t out_ntiles_c, bool is_partial_tile, uint32_t split_reader, uint32_t unpA_face_r_dim, uint32_t in_nblocks_c>
inline void reduce_h_fused(
Contributor:
"fused" in reduce_h_fused means tilize/untilize on the fly right?

Contributor Author:
yes
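For readers outside the kernel code, a rough sketch of the fused structure implied by the signatures in this diff -- tilize, reduce over the kernel window, and pack-untilize happen per channel block inside one loop. The body is illustrative comments only, not the real kernel:

// Illustrative skeleton (assumption: simplified; real calls and arguments omitted).
template <uint32_t in_ntiles_hwc_block, uint32_t in_nblocks_c>
inline void reduce_h_fused_sketch() {
    for (uint32_t block = 0; block < in_nblocks_c; ++block) {
        // 1. tilize the next block of input rows together with the scalar operand
        // 2. reduce the kernel-window rows of this block down to one output row
        // 3. pack-untilize the reduced tiles straight into the output circular buffer
    }
}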

@@ -113,7 +113,7 @@ def run_max_pool(
 # interleaved_mem_config = ttnn.L1_MEMORY_CONFIG
 # output = ttnn.to_memory_config(output, interleaved_mem_config)
 output_host = output.cpu()
-output_pytorch_padded = ttnn.to_torch(output_host)
+output_pytorch_padded = torch.Tensor(ttnn.to_torch(output_host))
Member:
why is this needed? ttnn.to_torch returns torch.Tensor

Contributor Author:
It doesn't really return a pure torch tensor -- the result still needs ttnn to subsequently read/manipulate it. This was the recommended way to convert it into a pure torch tensor.

Contributor:
Yeah, can confirm. If you don't do it, torch functions will crash with a segmentation fault.

Member:
🫠 okay, good to know

## wide for vgg
[1, 256, 56, 56],
[1, 512, 28, 28],
[1, 512, 14, 14],
Contributor:
Is it possible to test it on multiple batches too?
Also, what about the case where W and H are not equal to each other?
The case where W or H is equal to 1 is interesting too.

Contributor Author:
These are specific to the shapes used in models -- the generic testing will all be in sweep tests.

Contributor (@dmakoviichuk-tt, Sep 13, 2024):
But if it works for other cases, why not demonstrate it here?
The description says that support for wide channels was added; it doesn't say support was added only for VGG or some other specific CNN.
Given that, I would expect to see other wide-channel examples here.

@mywoodstock mywoodstock merged commit f1f1d37 into main Sep 13, 2024
6 checks passed
@mywoodstock mywoodstock deleted the asarje/wide-maxpool branch September 13, 2024 17:21