Skip to content

Commit

Permalink
better coding style
Browse files Browse the repository at this point in the history
  • Loading branch information
nihui committed Dec 1, 2023
1 parent 3dbf76a commit 242817f
Showing 1 changed file with 28 additions and 26 deletions.
54 changes: 28 additions & 26 deletions src/layer/vulkan/pooling_vulkan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -514,35 +514,36 @@ int Pooling_vulkan::forward(const VkMat& bottom_blob, VkMat& top_blob, VkCompute
// reduce more
while (reduced_blob.w > 32)
{
VkMat prev_reduced_blob = reduced_blob;

int reduced_size = (reduced_blob.w + 7) / 8;
size_t reduced_elemsize = pooling_type == 0 ? elemsize : 4u * elempack;
reduced_blob.create(reduced_size, 1, channels, reduced_elemsize, elempack, opt.workspace_vkallocator);
if (reduced_blob.empty())
VkMat reduced_blob2;
reduced_blob2.create(reduced_size, 1, channels, reduced_elemsize, elempack, opt.workspace_vkallocator);
if (reduced_blob2.empty())
return -100;

std::vector<VkMat> bindings(2);
bindings[0] = prev_reduced_blob;
bindings[1] = reduced_blob;
bindings[0] = reduced_blob;
bindings[1] = reduced_blob2;

std::vector<vk_constant_type> constants(5);
constants[0].i = prev_reduced_blob.w;
constants[1].i = prev_reduced_blob.c;
constants[2].i = prev_reduced_blob.cstep;
constants[3].i = reduced_blob.w;
constants[4].i = reduced_blob.cstep;
constants[0].i = reduced_blob.w;
constants[1].i = reduced_blob.c;
constants[2].i = reduced_blob.cstep;
constants[3].i = reduced_blob2.w;
constants[4].i = reduced_blob2.cstep;

const Pipeline* pipeline = elempack == 8 ? pipeline_pooling_global_reduce_pack8
: elempack == 4 ? pipeline_pooling_global_reduce_pack4
: pipeline_pooling_global_reduce;

VkMat dispatcher;
dispatcher.w = reduced_blob.w;
dispatcher.w = reduced_blob2.w;
dispatcher.h = 1;
dispatcher.c = reduced_blob.c;
dispatcher.c = reduced_blob2.c;

cmd.record_pipeline(pipeline, bindings, constants, dispatcher);

reduced_blob = reduced_blob2;
}

// reduce last
Expand Down Expand Up @@ -799,35 +800,36 @@ int Pooling_vulkan::forward(const VkImageMat& bottom_blob, VkImageMat& top_blob,
// reduce more
while (reduced_blob.w > 32)
{
VkImageMat prev_reduced_blob = reduced_blob;

int reduced_size = (reduced_blob.w + 7) / 8;
size_t reduced_elemsize = pooling_type == 0 ? elemsize : 4u * elempack;
reduced_blob.create(reduced_size, 1, channels, reduced_elemsize, elempack, opt.workspace_vkallocator);
if (reduced_blob.empty())
VkImageMat reduced_blob2;
reduced_blob2.create(reduced_size, 1, channels, reduced_elemsize, elempack, opt.workspace_vkallocator);
if (reduced_blob2.empty())
return -100;

std::vector<VkImageMat> bindings(2);
bindings[0] = prev_reduced_blob;
bindings[1] = reduced_blob;
bindings[0] = reduced_blob;
bindings[1] = reduced_blob2;

std::vector<vk_constant_type> constants(5);
constants[0].i = prev_reduced_blob.w;
constants[1].i = prev_reduced_blob.c;
constants[2].i = 0; //prev_reduced_blob.cstep;
constants[3].i = reduced_blob.w;
constants[4].i = 0; //reduced_blob.cstep;
constants[0].i = reduced_blob.w;
constants[1].i = reduced_blob.c;
constants[2].i = 0;//reduced_blob.cstep;
constants[3].i = reduced_blob2.w;
constants[4].i = 0;//reduced_blob2.cstep;

const Pipeline* pipeline = elempack == 8 ? pipeline_pooling_global_reduce_pack8
: elempack == 4 ? pipeline_pooling_global_reduce_pack4
: pipeline_pooling_global_reduce;

VkImageMat dispatcher;
dispatcher.w = reduced_blob.w;
dispatcher.w = reduced_blob2.w;
dispatcher.h = 1;
dispatcher.c = reduced_blob.c;
dispatcher.c = reduced_blob2.c;

cmd.record_pipeline(pipeline, bindings, constants, dispatcher);

reduced_blob = reduced_blob2;
}

// reduce last
Expand Down

0 comments on commit 242817f

Please sign in to comment.