Commit

wip
nihui committed Jun 28, 2023
1 parent 2e662f3 commit 0ba59e3
Showing 4 changed files with 325 additions and 6 deletions.
23 changes: 22 additions & 1 deletion src/gpu.cpp
@@ -227,6 +227,7 @@ class GpuInfoPrivate
bool support_cooperative_matrix;
bool support_cooperative_matrix_16_8_8;
bool support_cooperative_matrix_16_8_16;
bool support_cooperative_matrix_16_16_16;

// extension capability
int support_VK_KHR_8bit_storage;
@@ -537,6 +538,11 @@ bool GpuInfo::support_cooperative_matrix_16_8_16() const
return d->support_cooperative_matrix_16_8_16;
}

bool GpuInfo::support_cooperative_matrix_16_16_16() const
{
return d->support_cooperative_matrix_16_16_16;
}

int GpuInfo::support_VK_KHR_8bit_storage() const
{
return d->support_VK_KHR_8bit_storage;
@@ -1535,6 +1541,7 @@ int create_gpu_instance()
gpu_info.support_cooperative_matrix = false;
gpu_info.support_cooperative_matrix_16_8_8 = false;
gpu_info.support_cooperative_matrix_16_8_16 = false;
gpu_info.support_cooperative_matrix_16_16_16 = false;
if (support_VK_KHR_get_physical_device_properties2)
{
void* queryExtensionFeatures = 0;
@@ -1699,6 +1706,13 @@ int create_gpu_instance()
{
gpu_info.support_cooperative_matrix_16_8_16 = true;
}
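// 16x16x16 tile: fp16 A/B with fp32 accumulation at subgroup scope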
if (cmp.MSize == 16 && cmp.NSize == 16 && cmp.KSize == 16
&& cmp.AType == VK_COMPONENT_TYPE_FLOAT16_KHR && cmp.BType == VK_COMPONENT_TYPE_FLOAT16_KHR
&& cmp.CType == VK_COMPONENT_TYPE_FLOAT32_KHR && cmp.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR
&& cmp.scope == VK_SCOPE_SUBGROUP_KHR)
{
gpu_info.support_cooperative_matrix_16_16_16 = true;
}
}
}
else
@@ -1725,7 +1739,7 @@ int create_gpu_instance()
for (uint32_t j = 0; j < properties.size(); j++)
{
const VkCooperativeMatrixPropertiesNV& cmp = properties[j];
// NCNN_LOGE("cpm %2d %2d %2d %d %d %d %d %d", cmp.MSize, cmp.NSize, cmp.KSize, cmp.AType, cmp.BType, cmp.CType, cmp.DType, cmp.scope);
NCNN_LOGE("cpm %2d %2d %2d %d %d %d %d %d", cmp.MSize, cmp.NSize, cmp.KSize, cmp.AType, cmp.BType, cmp.CType, cmp.DType, cmp.scope);

if (cmp.MSize == 16 && cmp.NSize == 8 && cmp.KSize == 8
&& cmp.AType == VK_COMPONENT_TYPE_FLOAT16_NV && cmp.BType == VK_COMPONENT_TYPE_FLOAT16_NV
Expand All @@ -1741,6 +1755,13 @@ int create_gpu_instance()
{
gpu_info.support_cooperative_matrix_16_8_16 = true;
}
if (cmp.MSize == 16 && cmp.NSize == 16 && cmp.KSize == 16
&& cmp.AType == VK_COMPONENT_TYPE_FLOAT16_NV && cmp.BType == VK_COMPONENT_TYPE_FLOAT16_NV
&& cmp.CType == VK_COMPONENT_TYPE_FLOAT32_NV && cmp.DType == VK_COMPONENT_TYPE_FLOAT32_NV
&& cmp.scope == VK_SCOPE_SUBGROUP_NV)
{
gpu_info.support_cooperative_matrix_16_16_16 = true;
}
}
}
}
1 change: 1 addition & 0 deletions src/gpu.h
@@ -177,6 +177,7 @@ class NCNN_EXPORT GpuInfo
bool support_cooperative_matrix() const;
bool support_cooperative_matrix_16_8_8() const;
bool support_cooperative_matrix_16_8_16() const;
bool support_cooperative_matrix_16_16_16() const;

// extension capability
int support_VK_KHR_8bit_storage() const;
61 changes: 56 additions & 5 deletions src/layer/vulkan/convolution_vulkan.cpp
@@ -728,13 +728,42 @@ int Convolution_vulkan::create_pipeline(const Option& _opt)
}
else
{
bool use_cooperative_matrix_16_16 = vkdev->info.support_cooperative_matrix_16_16_16() && opt.use_cooperative_matrix && is_conv1x1s1d1 && !opt.use_image_storage && !opt.use_shader_pack8 && opt.use_fp16_storage && num_input % 16 == 0 && num_output % 16 == 0;
bool use_cooperative_matrix = vkdev->info.support_cooperative_matrix_16_8_8() && opt.use_cooperative_matrix && is_conv1x1s1d1 && !opt.use_image_storage && !opt.use_shader_pack8 && opt.use_fp16_storage && num_input % 8 == 0 && num_output % 8 == 0;
if (use_cooperative_matrix)
// NCNN_LOGE("use_cooperative_matrix_16_16 = %d %d %d %d %d %d %d", use_cooperative_matrix_16_16, vkdev->info.support_cooperative_matrix_16_16_16(), opt.use_cooperative_matrix, is_conv1x1s1d1, !opt.use_image_storage, !opt.use_shader_pack8, opt.use_fp16_storage);
if (use_cooperative_matrix_16_16)
{
// dst = 8b-8a-inch/8a-outch/8b
// dst = 16b-16a-inch/16a-outch/16b
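// i.e. for each kernel position, weights are repacked into 16x16 tiles of 16 input channels x 16 output channels, with the output-channel index varying fastest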
Mat weight_data_r2 = weight_data.reshape(maxk, num_input, num_output);

weight_data_packed.create(maxk, num_input / 16, num_output / 16, (size_t)4 * 16 * 16, 16 * 16);

for (int q = 0; q + 15 < num_output; q += 16)
{
float* g00 = weight_data_packed.channel(q / 16);

for (int p = 0; p + 15 < num_input; p += 16)
{
for (int k = 0; k < maxk; k++)
{
for (int i = 0; i < 16; i++)
{
for (int j = 0; j < 16; j++)
{
const float* k00 = weight_data_r2.channel(q + j).row(p + i);
g00[0] = k00[k];
g00++;
}
}
}
}
}
}
else if (use_cooperative_matrix)
{
// dst = 8b-8a-inch/8a-outch/8b
Mat weight_data_r2 = weight_data.reshape(maxk, num_input, num_output);

weight_data_packed.create(maxk, num_input / 8, num_output / 8, (size_t)4 * 8 * 8, 8 * 8);

for (int q = 0; q + 7 < num_output; q += 8)
@@ -879,6 +908,8 @@ int Convolution_vulkan::create_pipeline(const Option& _opt)
}
else if (is_conv1x1s1d1)
{
bool use_cooperative_matrix_16_16 = vkdev->info.support_cooperative_matrix_16_16_16() && opt.use_cooperative_matrix && !opt.use_image_storage && !opt.use_shader_pack8 && opt.use_fp16_storage && num_input % 16 == 0 && num_output % 16 == 0;

bool use_cooperative_matrix = vkdev->info.support_cooperative_matrix_16_8_8() && opt.use_cooperative_matrix && !opt.use_image_storage && !opt.use_shader_pack8 && opt.use_fp16_storage && num_input % 8 == 0 && num_output % 8 == 0;

std::vector<vk_specialization_type> specializations(4 + 8);
@@ -906,13 +937,24 @@ int Convolution_vulkan::create_pipeline(const Option& _opt)
if (elempack == 4 && out_elempack == 8) shader_type_index = LayerShaderType::convolution_pack4to8_1x1s1d1;
if (elempack == 8 && out_elempack == 4) shader_type_index = LayerShaderType::convolution_pack8to4_1x1s1d1;

if (use_cooperative_matrix)
if (use_cooperative_matrix_16_16)
{
shader_type_index = LayerShaderType::convolution_pack4_1x1s1d1_cm_16_16_16;
}
else if (use_cooperative_matrix)
{
shader_type_index = LayerShaderType::convolution_pack4_1x1s1d1_cm_16_8_8;
}

pipeline_convolution_1x1s1d1 = new Pipeline(vkdev);
if (use_cooperative_matrix)
if (use_cooperative_matrix_16_16)
{
// TODO proper unroll y
// pipeline_convolution_1x1s1d1->set_local_size_xyz(32, 4, 1); // 16_16_16 ly*4
// pipeline_convolution_1x1s1d1->set_local_size_xyz(32, 2, 1); // 16_16_16 ly*2
pipeline_convolution_1x1s1d1->set_local_size_xyz(32, 1, 1); // 16_16_16 ly*1
}
else if (use_cooperative_matrix)
{
// TODO proper unroll y
pipeline_convolution_1x1s1d1->set_local_size_xyz(32, 4, 1); // 16_8_8 ly*4
@@ -1475,6 +1517,8 @@ int Convolution_vulkan::forward(const VkMat& bottom_blob, VkMat& top_blob, VkCom
}
if (is_conv1x1s1d1)
{
bool use_cooperative_matrix_16_16 = vkdev->info.support_cooperative_matrix_16_16_16() && opt.use_cooperative_matrix && !opt.use_image_storage && !opt.use_shader_pack8 && opt.use_fp16_storage && channels * elempack % 16 == 0 && num_output % 16 == 0;

bool use_cooperative_matrix = vkdev->info.support_cooperative_matrix_16_8_8() && opt.use_cooperative_matrix && !opt.use_image_storage && !opt.use_shader_pack8 && opt.use_fp16_storage && channels * elempack % 8 == 0 && num_output % 8 == 0;

top_blob.create(outw, outh, num_output / out_elempack, out_elemsize, out_elempack, opt.blob_vkallocator);
@@ -1502,7 +1546,14 @@ int Convolution_vulkan::forward(const VkMat& bottom_blob, VkMat& top_blob, VkCom
dispatcher.h = top_blob.c;
dispatcher.c = 1;

if (use_cooperative_matrix)
if (use_cooperative_matrix_16_16)
{
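// each 32-invocation workgroup covers two 16-pixel tiles along M and 16 output channels (4 pack4 channels) along N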
dispatcher.w = ((top_blob.w * top_blob.h + 15) / 16 + 1) / 2 * 32;
// dispatcher.w = (top_blob.w * top_blob.h + 15) / 16 * 32;
dispatcher.h = (top_blob.c + 3) / 4;
dispatcher.c = 1;
}
else if (use_cooperative_matrix)
{
dispatcher.w = ((top_blob.w * top_blob.h + 15) / 16 + 3) / 4 * 32;
dispatcher.h = (top_blob.c + 1) / 2;
