From 40722c2016cecb2276eafccd7eb22ac7b85c9b30 Mon Sep 17 00:00:00 2001 From: nihui Date: Fri, 18 Oct 2024 22:18:15 +0800 Subject: [PATCH] avx vnni int8, avx vnni int16, avx ne convert infrastructure --- CMakeLists.txt | 48 +++++++++++++++++ cmake/ncnn_add_layer.cmake | 27 ++++++++++ src/CMakeLists.txt | 27 ++++++++++ src/cpu.cpp | 102 +++++++++++++++++++++++++++++++++++++ src/cpu.h | 6 +++ src/platform.h.in | 3 ++ 6 files changed, 213 insertions(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index 875a8d06598f..c60f087a1363 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -506,6 +506,15 @@ else() set(CMAKE_REQUIRED_FLAGS "/arch:AVX2") check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m256i _s, _a, _b; _s = _mm256_dpwssd_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI) + set(CMAKE_REQUIRED_FLAGS "/arch:AVX2") + check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m256i _s, _a, _b; _s = _mm256_dpbssd_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI_INT8) + + set(CMAKE_REQUIRED_FLAGS "/arch:AVX2") + check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m256i _s, _a, _b; _s = _mm256_dpwsud_avx_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI_INT16) + + set(CMAKE_REQUIRED_FLAGS "/arch:AVX2") + check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m256 _a; __m128bh _s = _mm256_cvtneps_avx_pbh(_a); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_NE_CONVERT) + set(CMAKE_REQUIRED_FLAGS "/arch:AVX512") check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m512i _s, _a, _b; _s = _mm512_dpwssd_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX512_VNNI) @@ -532,6 +541,15 @@ else() set(CMAKE_REQUIRED_FLAGS "/arch:AVX2 -mfma -mf16c -mavxvnni") check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m256i _s, _a, _b; _s = _mm256_dpwssd_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI) + set(CMAKE_REQUIRED_FLAGS "/arch:AVX2 -mfma -mf16c -mavxvnni -mavxvnniint8") + check_cxx_source_compiles("#include <immintrin.h>
\nint main() { __m256i _s, _a, _b; _s = _mm256_dpbssd_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI_INT8) + + set(CMAKE_REQUIRED_FLAGS "/arch:AVX2 -mfma -mf16c -mavxvnni -mavxvnniint16") + check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m256i _s, _a, _b; _s = _mm256_dpwsud_avx_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI_INT16) + + set(CMAKE_REQUIRED_FLAGS "/arch:AVX2 -mfma -mf16c -mavxneconvert") + check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m256 _a; __m128bh _s = _mm256_cvtneps_avx_pbh(_a); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_NE_CONVERT) + set(CMAKE_REQUIRED_FLAGS "/arch:AVX512 -mfma -mf16c -mavx512cd -mavx512bw -mavx512dq -mavx512vl -mavx512vnni") check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m512i _s, _a, _b; _s = _mm512_dpwssd_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX512_VNNI) @@ -556,6 +574,15 @@ else() set(CMAKE_REQUIRED_FLAGS "-mfma -mf16c -mavx2 -mavxvnni") check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m256i _s, _a, _b; _s = _mm256_dpwssd_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI) + set(CMAKE_REQUIRED_FLAGS "-mfma -mf16c -mavx2 -mavxvnni -mavxvnniint8") + check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m256i _s, _a, _b; _s = _mm256_dpbssd_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI_INT8) + + set(CMAKE_REQUIRED_FLAGS "-mfma -mf16c -mavx2 -mavxvnni -mavxvnniint16") + check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m256i _s, _a, _b; _s = _mm256_dpwsud_avx_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI_INT16) + + set(CMAKE_REQUIRED_FLAGS "-mfma -mf16c -mavx2 -mavxneconvert") + check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m256 _a; __m128bh _s = _mm256_cvtneps_avx_pbh(_a); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_NE_CONVERT) + set(CMAKE_REQUIRED_FLAGS "-mfma -mf16c -mavx512f -mavx512cd -mavx512bw -mavx512dq -mavx512vl -mavx512vnni") check_cxx_source_compiles("#include <immintrin.h>\nint 
main() { __m512i _s, _a, _b; _s = _mm512_dpwssd_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX512_VNNI) @@ -599,9 +626,30 @@ else() if(NCNN_AVX2) option(NCNN_AVXVNNI "optimize x86 platform with avx vnni extension" ON) endif() + if(NCNN_COMPILER_SUPPORT_X86_AVX_VNNI_INT8) + if(NCNN_AVXVNNI) + option(NCNN_AVXVNNIINT8 "optimize x86 platform with avx vnni int8 extension" ON) + endif() + else() + message(WARNING "The compiler does not support avx vnni int8 extension. NCNN_AVXVNNIINT8 will be OFF.") + endif() + if(NCNN_COMPILER_SUPPORT_X86_AVX_VNNI_INT16) + if(NCNN_AVXVNNI) + option(NCNN_AVXVNNIINT16 "optimize x86 platform with avx vnni int16 extension" ON) + endif() + else() + message(WARNING "The compiler does not support avx vnni int16 extension. NCNN_AVXVNNIINT16 will be OFF.") + endif() else() message(WARNING "The compiler does not support avx vnni extension. NCNN_AVXVNNI will be OFF.") endif() + if(NCNN_COMPILER_SUPPORT_X86_AVX_NE_CONVERT) + if(NCNN_AVX2) + option(NCNN_AVXNECONVERT "optimize x86 platform with avx ne convert extension" ON) + endif() + else() + message(WARNING "The compiler does not support avx ne convert extension. 
NCNN_AVXNECONVERT will be OFF.") + endif() if(NCNN_COMPILER_SUPPORT_X86_AVX512) if(NCNN_AVX2) option(NCNN_AVX512 "optimize x86 platform with avx512 extension" ON) diff --git a/cmake/ncnn_add_layer.cmake b/cmake/ncnn_add_layer.cmake index 7f334fb0b68d..d9f898f62c95 100644 --- a/cmake/ncnn_add_layer.cmake +++ b/cmake/ncnn_add_layer.cmake @@ -156,6 +156,15 @@ macro(ncnn_add_layer class) if(NCNN_RUNTIME_CPU AND NCNN_AVXVNNI) ncnn_add_arch_opt_source(${class} avxvnni "/arch:AVX2 /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__ /D__AVXVNNI__") endif() + if(NCNN_RUNTIME_CPU AND NCNN_AVXVNNIINT8) + ncnn_add_arch_opt_source(${class} avxvnniint8 "/arch:AVX2 /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__ /D__AVXVNNI__ /D__AVXVNNIINT8__") + endif() + if(NCNN_RUNTIME_CPU AND NCNN_AVXVNNIINT16) + ncnn_add_arch_opt_source(${class} avxvnniint16 "/arch:AVX2 /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__ /D__AVXVNNI__ /D__AVXVNNIINT16__") + endif() + if(NCNN_RUNTIME_CPU AND NCNN_AVXNECONVERT) + ncnn_add_arch_opt_source(${class} avxneconvert "/arch:AVX2 /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__ /D__AVXNECONVERT__") + endif() if(NCNN_RUNTIME_CPU AND NCNN_AVX2) ncnn_add_arch_opt_source(${class} avx2 "/arch:AVX2 /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__") endif() @@ -187,6 +196,15 @@ macro(ncnn_add_layer class) if(NCNN_RUNTIME_CPU AND NCNN_AVXVNNI) ncnn_add_arch_opt_source(${class} avxvnni "/arch:AVX2 -mfma -mf16c -mavxvnni /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__ /D__AVXVNNI__") endif() + if(NCNN_RUNTIME_CPU AND NCNN_AVXVNNIINT8) + ncnn_add_arch_opt_source(${class} avxvnniint8 "/arch:AVX2 -mfma -mf16c -mavxvnni -mavxvnniint8 /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__ /D__AVXVNNI__ /D__AVXVNNIINT8__") + endif() + if(NCNN_RUNTIME_CPU AND NCNN_AVXVNNIINT16) + ncnn_add_arch_opt_source(${class} avxvnniint16 "/arch:AVX2 -mfma -mf16c -mavxvnni -mavxvnniint16 /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__ /D__AVXVNNI__ /D__AVXVNNIINT16__") + endif() + if(NCNN_RUNTIME_CPU AND 
NCNN_AVXNECONVERT) + ncnn_add_arch_opt_source(${class} avxneconvert "/arch:AVX2 -mfma -mf16c -mavxneconvert /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__ /D__AVXNECONVERT__") + endif() if(NCNN_RUNTIME_CPU AND NCNN_AVX2) ncnn_add_arch_opt_source(${class} avx2 "/arch:AVX2 -mfma -mf16c /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__") endif() @@ -218,6 +236,15 @@ macro(ncnn_add_layer class) if(NCNN_RUNTIME_CPU AND NCNN_AVXVNNI) ncnn_add_arch_opt_source(${class} avxvnni "-mavx2 -mfma -mf16c -mavxvnni") endif() + if(NCNN_RUNTIME_CPU AND NCNN_AVXVNNIINT8) + ncnn_add_arch_opt_source(${class} avxvnniint8 "-mavx2 -mfma -mf16c -mavxvnni -mavxvnniint8") + endif() + if(NCNN_RUNTIME_CPU AND NCNN_AVXVNNIINT16) + ncnn_add_arch_opt_source(${class} avxvnniint16 "-mavx2 -mfma -mf16c -mavxvnni -mavxvnniint16") + endif() + if(NCNN_RUNTIME_CPU AND NCNN_AVXNECONVERT) + ncnn_add_arch_opt_source(${class} avxneconvert "-mavx2 -mfma -mf16c -mavxneconvert") + endif() if(NCNN_RUNTIME_CPU AND NCNN_AVX2) ncnn_add_arch_opt_source(${class} avx2 "-mavx2 -mfma -mf16c") endif() diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 803c34a780d4..a9c45fd4645b 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -441,6 +441,15 @@ if(NCNN_TARGET_ARCH STREQUAL "x86") else() target_compile_options(ncnn PRIVATE /arch:AVX /D__SSSE3__ /D__SSE4_1__ /D__FMA__) endif() + if(NCNN_AVXVNNIINT8) + target_compile_options(ncnn PRIVATE /D__AVXVNNIINT8__) + endif() + if(NCNN_AVXVNNIINT16) + target_compile_options(ncnn PRIVATE /D__AVXVNNIINT16__) + endif() + if(NCNN_AVXNECONVERT) + target_compile_options(ncnn PRIVATE /D__AVXNECONVERT__) + endif() if(NCNN_AVXVNNI) target_compile_options(ncnn PRIVATE /D__AVXVNNI__) elseif(NCNN_XOP) @@ -455,6 +464,15 @@ if(NCNN_TARGET_ARCH STREQUAL "x86") else() target_compile_options(ncnn PRIVATE /arch:AVX -mfma /D__SSSE3__ /D__SSE4_1__ /D__FMA__) endif() + if(NCNN_AVXVNNIINT8) + target_compile_options(ncnn PRIVATE -mavxvnniint8 /D__AVXVNNIINT8__) + endif() + 
if(NCNN_AVXVNNIINT16) + target_compile_options(ncnn PRIVATE -mavxvnniint16 /D__AVXVNNIINT16__) + endif() + if(NCNN_AVXNECONVERT) + target_compile_options(ncnn PRIVATE -mavxneconvert /D__AVXNECONVERT__) + endif() if(NCNN_AVXVNNI) target_compile_options(ncnn PRIVATE -mavxvnni /D__AVXVNNI__) elseif(NCNN_XOP) @@ -469,6 +487,15 @@ if(NCNN_TARGET_ARCH STREQUAL "x86") else() target_compile_options(ncnn PRIVATE -mavx -mfma) endif() + if(NCNN_AVXVNNIINT8) + target_compile_options(ncnn PRIVATE -mavxvnniint8) + endif() + if(NCNN_AVXVNNIINT16) + target_compile_options(ncnn PRIVATE -mavxvnniint16) + endif() + if(NCNN_AVXNECONVERT) + target_compile_options(ncnn PRIVATE -mavxneconvert) + endif() if(NCNN_AVXVNNI) target_compile_options(ncnn PRIVATE -mavxvnni) elseif(NCNN_XOP) diff --git a/src/cpu.cpp b/src/cpu.cpp index 9ab0ebb31e99..c9307619ce91 100644 --- a/src/cpu.cpp +++ b/src/cpu.cpp @@ -183,6 +183,9 @@ static int g_cpu_support_x86_xop; static int g_cpu_support_x86_f16c; static int g_cpu_support_x86_avx2; static int g_cpu_support_x86_avx_vnni; +static int g_cpu_support_x86_avx_vnni_int8; +static int g_cpu_support_x86_avx_vnni_int16; +static int g_cpu_support_x86_avx_ne_convert; static int g_cpu_support_x86_avx512; static int g_cpu_support_x86_avx512_vnni; static int g_cpu_support_x86_avx512_bf16; @@ -617,6 +620,72 @@ static int get_cpu_support_x86_avx_vnni() return cpu_info[0] & (1u << 4); } +static int get_cpu_support_x86_avx_vnni_int8() +{ + unsigned int cpu_info[4] = {0}; + x86_cpuid(0, cpu_info); + + int nIds = cpu_info[0]; + if (nIds < 7) + return 0; + + x86_cpuid(1, cpu_info); + // check AVX XSAVE OSXSAVE + if (!(cpu_info[2] & (1u << 28)) || !(cpu_info[2] & (1u << 26)) || !(cpu_info[2] & (1u << 27))) + return 0; + + // check XSAVE enabled by kernel + if ((x86_get_xcr0() & 6) != 6) + return 0; + + x86_cpuid_sublevel(7, 1, cpu_info); + return cpu_info[3] & (1u << 4); +} + +static int get_cpu_support_x86_avx_vnni_int16() +{ + unsigned int cpu_info[4] = {0}; + x86_cpuid(0, 
cpu_info); + + int nIds = cpu_info[0]; + if (nIds < 7) + return 0; + + x86_cpuid(1, cpu_info); + // check AVX XSAVE OSXSAVE + if (!(cpu_info[2] & (1u << 28)) || !(cpu_info[2] & (1u << 26)) || !(cpu_info[2] & (1u << 27))) + return 0; + + // check XSAVE enabled by kernel + if ((x86_get_xcr0() & 6) != 6) + return 0; + + x86_cpuid_sublevel(7, 1, cpu_info); + return cpu_info[3] & (1u << 10); +} + +static int get_cpu_support_x86_avx_ne_convert() +{ + unsigned int cpu_info[4] = {0}; + x86_cpuid(0, cpu_info); + + int nIds = cpu_info[0]; + if (nIds < 7) + return 0; + + x86_cpuid(1, cpu_info); + // check AVX XSAVE OSXSAVE + if (!(cpu_info[2] & (1u << 28)) || !(cpu_info[2] & (1u << 26)) || !(cpu_info[2] & (1u << 27))) + return 0; + + // check XSAVE enabled by kernel + if ((x86_get_xcr0() & 6) != 6) + return 0; + + x86_cpuid_sublevel(7, 1, cpu_info); + return cpu_info[3] & (1u << 5); +} + static int get_cpu_support_x86_avx512() { #if __APPLE__ @@ -1967,6 +2036,9 @@ static void initialize_global_cpu_info() g_cpu_support_x86_f16c = get_cpu_support_x86_f16c(); g_cpu_support_x86_avx2 = get_cpu_support_x86_avx2(); g_cpu_support_x86_avx_vnni = get_cpu_support_x86_avx_vnni(); + g_cpu_support_x86_avx_vnni_int8 = get_cpu_support_x86_avx_vnni_int8(); + g_cpu_support_x86_avx_vnni_int16 = get_cpu_support_x86_avx_vnni_int16(); + g_cpu_support_x86_avx_ne_convert = get_cpu_support_x86_avx_ne_convert(); g_cpu_support_x86_avx512 = get_cpu_support_x86_avx512(); g_cpu_support_x86_avx512_vnni = get_cpu_support_x86_avx512_vnni(); g_cpu_support_x86_avx512_bf16 = get_cpu_support_x86_avx512_bf16(); @@ -2489,6 +2561,36 @@ int cpu_support_x86_avx_vnni() #endif } +int cpu_support_x86_avx_vnni_int8() +{ + try_initialize_global_cpu_info(); +#if defined(__i386__) || defined(__x86_64__) || defined(_M_IX86) || defined(_M_X64) + return g_cpu_support_x86_avx_vnni_int8; +#else + return 0; +#endif +} + +int cpu_support_x86_avx_vnni_int16() +{ + try_initialize_global_cpu_info(); +#if defined(__i386__) || 
defined(__x86_64__) || defined(_M_IX86) || defined(_M_X64) + return g_cpu_support_x86_avx_vnni_int16; +#else + return 0; +#endif +} + +int cpu_support_x86_avx_ne_convert() +{ + try_initialize_global_cpu_info(); +#if defined(__i386__) || defined(__x86_64__) || defined(_M_IX86) || defined(_M_X64) + return g_cpu_support_x86_avx_ne_convert; +#else + return 0; +#endif +} + int cpu_support_x86_avx512() { try_initialize_global_cpu_info(); diff --git a/src/cpu.h b/src/cpu.h index 2ae6b8c3ffe9..f0e4728633fe 100644 --- a/src/cpu.h +++ b/src/cpu.h @@ -93,6 +93,12 @@ NCNN_EXPORT int cpu_support_x86_f16c(); NCNN_EXPORT int cpu_support_x86_avx2(); // avx_vnni = x86 avx vnni NCNN_EXPORT int cpu_support_x86_avx_vnni(); +// avx_vnni_int8 = x86 avx vnni int8 +NCNN_EXPORT int cpu_support_x86_avx_vnni_int8(); +// avx_vnni_int16 = x86 avx vnni int16 +NCNN_EXPORT int cpu_support_x86_avx_vnni_int16(); +// avx_ne_convert = x86 avx ne convert +NCNN_EXPORT int cpu_support_x86_avx_ne_convert(); // avx512 = x86 avx512f + avx512cd + avx512bw + avx512dq + avx512vl NCNN_EXPORT int cpu_support_x86_avx512(); // avx512_vnni = x86 avx512 vnni diff --git a/src/platform.h.in b/src/platform.h.in index 50a9454b7da0..a0b372d8296b 100644 --- a/src/platform.h.in +++ b/src/platform.h.in @@ -40,6 +40,9 @@ #cmakedefine01 NCNN_F16C #cmakedefine01 NCNN_AVX2 #cmakedefine01 NCNN_AVXVNNI +#cmakedefine01 NCNN_AVXVNNIINT8 +#cmakedefine01 NCNN_AVXVNNIINT16 +#cmakedefine01 NCNN_AVXNECONVERT #cmakedefine01 NCNN_AVX512 #cmakedefine01 NCNN_AVX512VNNI #cmakedefine01 NCNN_AVX512BF16