#12544: support wide channels (> 256) in maxpool #12625
Conversation
@@ -63,10 +63,13 @@ MaxPool2D::MultiCore::cached_program_t max_pool_2d_multi_core_sharded_with_halo_
     uint32_t in_ntiles_hw = (uint32_t)std::ceil((float)kernel_size_hw_padded / tt::constants::TILE_HEIGHT);
     uint32_t in_ntiles_c = (uint32_t)std::ceil((float)input_shape[3] / tt::constants::TILE_WIDTH);
     uint32_t out_ntiles_c = (uint32_t)std::ceil((float)output_shape[3] / tt::constants::TILE_WIDTH);
-    uint32_t MAX_SMALL_KERNEL_SIZE_HW = 16;
+    const uint32_t MAX_SMALL_KERNEL_SIZE_HW = 16;
nit: I prefer constexpr for this kind of constant
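A minimal sketch of the suggested change (just swapping const for constexpr on the compile-time constant):

```cpp
#include <cstdint>

// constexpr makes the compile-time nature of the constant explicit
// and guarantees it is usable in constant expressions.
constexpr uint32_t MAX_SMALL_KERNEL_SIZE_HW = 16;
```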
    constexpr bool is_partial_tile = in_c < 32;
    static_assert((!is_partial_tile || (in_c == 16)), "Partial tile must have c_dim 16");
    constexpr uint32_t num_faces_in_tile = is_partial_tile ? 1 : 2;
    constexpr uint32_t num_out_rows = 1;

    tilizeA_B_reduce_init<true>(in_cb_id, in_scalar_cb_id, in_ntiles_hwc, out_cb_id, num_faces_in_tile, window_size_hw);
    uint32_t in_ntiles_hwc_block = in_ntiles_hwc / in_nblocks_c;
maybe add constexpr here
#include "compute_kernel_api/tilize.h"
#include "compute_kernel_api/reduce.h"
#include "compute_kernel_api/pack_untilize.h"
// #include "tools/profiler/kernel_profiler.hpp"

#define DEBUG_PRINT 0

#if DEBUG_PRINT == 1
feels like these should be in some common place for kernel debug utils
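A hypothetical shape such a shared header could take; the file name, macro name, and use of the device-print facility are assumptions, not from the PR:

```cpp
// Hypothetical shared header, e.g. "kernel_debug.hpp" (name assumed).
#pragma once

#ifndef DEBUG_PRINT
#define DEBUG_PRINT 0
#endif

#if DEBUG_PRINT == 1
    #include "debug/dprint.h"                      // tt-metal device-print facility (assumed available)
    #define KERNEL_DBG(x) DPRINT << (x) << ENDL()  // prints only in debug builds
#else
    #define KERNEL_DBG(x)                          // compiled out otherwise
#endif
```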
#endif

-template<uint32_t in_ntiles_hw, uint32_t in_ntiles_c, uint32_t out_ntiles_c, uint32_t nblocks, bool is_partial_tile, uint32_t split_reader, uint32_t unpA_face_r_dim>
+template<uint32_t in_ntiles_hw, uint32_t in_ntiles_c, uint32_t out_ntiles_c, bool is_partial_tile, uint32_t split_reader, uint32_t unpA_face_r_dim, uint32_t in_nblocks_c>
inline void reduce_h_fused(
"fused" in reduce_h_fused means tilize/untilize on the fly right?
yes
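A conceptual sketch of that structure (the stub names below are placeholders, not the actual compute_kernel_api calls): tilize, reduce, and pack-untilize are interleaved per pooling window instead of running as separate whole-tensor passes.

```cpp
#include <cstdint>

void tilize_window(uint32_t w);      // placeholder: row-major sticks -> tiles
void max_reduce_window(uint32_t w);  // placeholder: max-reduce the kernel window
void pack_untilize_out(uint32_t w);  // placeholder: tiles -> row-major output

void reduce_h_fused_sketch(uint32_t nwindows) {
    for (uint32_t w = 0; w < nwindows; ++w) {
        tilize_window(w);       // convert on the fly
        max_reduce_window(w);   // reduce immediately
        pack_untilize_out(w);   // write back; no intermediate tilized tensor
    }
}
```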
@@ -113,7 +113,7 @@ def run_max_pool(
     # interleaved_mem_config = ttnn.L1_MEMORY_CONFIG
     # output = ttnn.to_memory_config(output, interleaved_mem_config)
     output_host = output.cpu()
-    output_pytorch_padded = ttnn.to_torch(output_host)
+    output_pytorch_padded = torch.Tensor(ttnn.to_torch(output_host))
why is this needed? ttnn.to_torch returns a torch.Tensor
It doesn't really return a pure torch tensor -- ttnn is still needed to subsequently read/manipulate it. This was the recommended way to convert it into pure torch.
Yeah, can confirm. If you don't do this, torch functions will crash with a segmentation fault.
🫠 okay, good to know
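For reference, a minimal sketch of the workaround discussed above (variable names follow the diff; the segfault behavior is as reported in this thread, not independently verified):

```python
import torch
import ttnn

# output is a ttnn tensor produced by the maxpool op (assumed in scope)
output_host = output.cpu()
tt_wrapped = ttnn.to_torch(output_host)           # torch-like, but still backed by ttnn
output_pytorch_padded = torch.Tensor(tt_wrapped)  # materialize a pure torch.Tensor
```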
            ## wide for vgg
            [1, 256, 56, 56],
            [1, 512, 28, 28],
            [1, 512, 14, 14],
is it possible to test it on multiple batches too?
also, what if W and H are not equal to each other?
Also, the case where W or H is equal to 1 is interesting too.
These are specific to the shapes used in models -- the generic testing will all be in sweep tests.
But if it works on other cases, why not demonstrate it here?
The description says that support for wide channels was added; nothing says it was added only for VGG or some other specific CNN.
So I'd expect to see other wide-channel examples here.
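For concreteness, the kinds of extra cases being asked for might look like this (these shapes are illustrative, not taken from the PR):

```python
# [N, C, H, W] inputs exercising the cases raised above
extra_wide_shapes = [
    [2, 512, 28, 28],  # multiple batches
    [1, 512, 28, 14],  # H != W
    [1, 512, 1, 64],   # H == 1
]
```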
Ticket
#12544
Problem description
Maxpool was limited to at most 8 tiles per reduction (the DST register maximum), which capped the supported channel width.
What's changed
Added support for wide channels (> 256) by splitting the channel tiles into in_nblocks_c blocks for the reduction.
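A rough sketch of the idea, under the assumption (consistent with the diff's in_nblocks_c / in_ntiles_hwc_block) that the channel tiles are split into blocks that each fit the 8-tile DST limit; the names here are illustrative:

```cpp
#include <cstdint>

constexpr uint32_t TILE_WIDTH = 32;    // tile width in elements
constexpr uint32_t MAX_DST_TILES = 8;  // DST register capacity per reduction

// Number of channel blocks so each block's tiles fit in DST.
uint32_t num_channel_blocks(uint32_t in_c) {
    uint32_t in_ntiles_c = (in_c + TILE_WIDTH - 1) / TILE_WIDTH;  // channel tiles
    return (in_ntiles_c + MAX_DST_TILES - 1) / MAX_DST_TILES;     // ceil-divide into blocks
}
// e.g. in_c = 512 -> 16 channel tiles -> 2 blocks of 8 tiles each
```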
Checklist