Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

avx vnni int8, avx vnni int16, avx ne convert infrastructure #5749

Merged
merged 2 commits into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -506,6 +506,15 @@ else()
# Compiler capability probes using MSVC-style flags.
# Each probe sets CMAKE_REQUIRED_FLAGS and then tries to compile a tiny program that
# calls one representative intrinsic; the named result variable is true only when
# that program compiles.

# AVX-VNNI: _mm256_dpwssd_epi32
set(CMAKE_REQUIRED_FLAGS "/arch:AVX2")
check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m256i _s, _a, _b; _s = _mm256_dpwssd_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI)

# AVX-VNNI-INT8: _mm256_dpbssd_epi32
set(CMAKE_REQUIRED_FLAGS "/arch:AVX2")
check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m256i _s, _a, _b; _s = _mm256_dpbssd_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI_INT8)

# AVX-VNNI-INT16: _mm256_dpwsud_avx_epi32
set(CMAKE_REQUIRED_FLAGS "/arch:AVX2")
check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m256i _s, _a, _b; _s = _mm256_dpwsud_avx_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI_INT16)

# AVX-NE-CONVERT: _mm256_cvtneps_avx_pbh (also requires the __m128bh type to exist)
set(CMAKE_REQUIRED_FLAGS "/arch:AVX2")
check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m256 _a; __m128bh _s = _mm256_cvtneps_avx_pbh(_a); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_NE_CONVERT)

# AVX512-VNNI: _mm512_dpwssd_epi32, probed under /arch:AVX512
set(CMAKE_REQUIRED_FLAGS "/arch:AVX512")
check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m512i _s, _a, _b; _s = _mm512_dpwssd_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX512_VNNI)

Expand All @@ -532,6 +541,15 @@ else()
# Compiler capability probes for a compiler that takes /arch: together with GNU-style
# -m feature flags. Each probe enables the AVX2/FMA/F16C baseline plus only the
# extension flag(s) needed by the intrinsic being tested.

# AVX-VNNI: -mavxvnni, _mm256_dpwssd_epi32
set(CMAKE_REQUIRED_FLAGS "/arch:AVX2 -mfma -mf16c -mavxvnni")
check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m256i _s, _a, _b; _s = _mm256_dpwssd_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI)

# AVX-VNNI-INT8: -mavxvnniint8 on top of -mavxvnni, _mm256_dpbssd_epi32
set(CMAKE_REQUIRED_FLAGS "/arch:AVX2 -mfma -mf16c -mavxvnni -mavxvnniint8")
check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m256i _s, _a, _b; _s = _mm256_dpbssd_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI_INT8)

# AVX-VNNI-INT16: -mavxvnniint16 on top of -mavxvnni, _mm256_dpwsud_avx_epi32
set(CMAKE_REQUIRED_FLAGS "/arch:AVX2 -mfma -mf16c -mavxvnni -mavxvnniint16")
check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m256i _s, _a, _b; _s = _mm256_dpwsud_avx_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI_INT16)

# AVX-NE-CONVERT: -mavxneconvert (no VNNI flag), _mm256_cvtneps_avx_pbh
set(CMAKE_REQUIRED_FLAGS "/arch:AVX2 -mfma -mf16c -mavxneconvert")
check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m256 _a; __m128bh _s = _mm256_cvtneps_avx_pbh(_a); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_NE_CONVERT)

# AVX512-VNNI: /arch:AVX512 plus the avx512 cd/bw/dq/vl/vnni flags, _mm512_dpwssd_epi32
set(CMAKE_REQUIRED_FLAGS "/arch:AVX512 -mfma -mf16c -mavx512cd -mavx512bw -mavx512dq -mavx512vl -mavx512vnni")
check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m512i _s, _a, _b; _s = _mm512_dpwssd_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX512_VNNI)

Expand All @@ -556,6 +574,15 @@ else()
# Compiler capability probes using GNU-style -m flags only.
# Each probe enables the FMA/F16C/AVX2 baseline plus only the extension flag(s)
# needed by the intrinsic being tested; the result variable is true when the
# test program compiles.

# AVX-VNNI: -mavxvnni, _mm256_dpwssd_epi32
set(CMAKE_REQUIRED_FLAGS "-mfma -mf16c -mavx2 -mavxvnni")
check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m256i _s, _a, _b; _s = _mm256_dpwssd_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI)

# AVX-VNNI-INT8: -mavxvnniint8 on top of -mavxvnni, _mm256_dpbssd_epi32
set(CMAKE_REQUIRED_FLAGS "-mfma -mf16c -mavx2 -mavxvnni -mavxvnniint8")
check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m256i _s, _a, _b; _s = _mm256_dpbssd_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI_INT8)

# AVX-VNNI-INT16: -mavxvnniint16 on top of -mavxvnni, _mm256_dpwsud_avx_epi32
set(CMAKE_REQUIRED_FLAGS "-mfma -mf16c -mavx2 -mavxvnni -mavxvnniint16")
check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m256i _s, _a, _b; _s = _mm256_dpwsud_avx_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI_INT16)

# AVX-NE-CONVERT: -mavxneconvert (no VNNI flag), _mm256_cvtneps_avx_pbh
set(CMAKE_REQUIRED_FLAGS "-mfma -mf16c -mavx2 -mavxneconvert")
check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m256 _a; __m128bh _s = _mm256_cvtneps_avx_pbh(_a); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_NE_CONVERT)

# AVX512-VNNI: avx512 f/cd/bw/dq/vl/vnni flags, _mm512_dpwssd_epi32
set(CMAKE_REQUIRED_FLAGS "-mfma -mf16c -mavx512f -mavx512cd -mavx512bw -mavx512dq -mavx512vl -mavx512vnni")
check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m512i _s, _a, _b; _s = _mm512_dpwssd_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX512_VNNI)

Expand Down Expand Up @@ -599,9 +626,30 @@ else()
# NCNN_AVXVNNI is only declared (default ON) when NCNN_AVX2 is enabled.
if(NCNN_AVX2)
option(NCNN_AVXVNNI "optimize x86 platform with avx vnni extension" ON)
endif()
# NCNN_AVXVNNIINT8 is declared (default ON) only when the compiler probe succeeded
# AND NCNN_AVXVNNI is enabled. A warning is printed only for the probe failure;
# when NCNN_AVXVNNI is off the option is silently left undeclared.
if(NCNN_COMPILER_SUPPORT_X86_AVX_VNNI_INT8)
if(NCNN_AVXVNNI)
option(NCNN_AVXVNNIINT8 "optimize x86 platform with avx vnni int8 extension" ON)
endif()
else()
message(WARNING "The compiler does not support avx vnni int8 extension. NCNN_AVXVNNIINT8 will be OFF.")
endif()
# NCNN_AVXVNNIINT16 follows the same gating as NCNN_AVXVNNIINT8, using the INT16 probe.
if(NCNN_COMPILER_SUPPORT_X86_AVX_VNNI_INT16)
if(NCNN_AVXVNNI)
option(NCNN_AVXVNNIINT16 "optimize x86 platform with avx vnni int16 extension" ON)
endif()
else()
message(WARNING "The compiler does not support avx vnni int16 extension. NCNN_AVXVNNIINT16 will be OFF.")
endif()
else()
message(WARNING "The compiler does not support avx vnni extension. NCNN_AVXVNNI will be OFF.")
endif()
# NCNN_AVXNECONVERT is declared (default ON) only when the compiler probe succeeded
# AND NCNN_AVX2 is enabled; unlike the VNNI INT8/INT16 options it is gated on
# NCNN_AVX2 rather than NCNN_AVXVNNI. A warning is printed only when the probe failed.
if(NCNN_COMPILER_SUPPORT_X86_AVX_NE_CONVERT)
if(NCNN_AVX2)
option(NCNN_AVXNECONVERT "optimize x86 platform with avx ne convert extension" ON)
endif()
else()
message(WARNING "The compiler does not support avx ne convert extension. NCNN_AVXNECONVERT will be OFF.")
endif()
if(NCNN_COMPILER_SUPPORT_X86_AVX512)
if(NCNN_AVX2)
option(NCNN_AVX512 "optimize x86 platform with avx512 extension" ON)
Expand Down
27 changes: 27 additions & 0 deletions cmake/ncnn_add_layer.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,15 @@ macro(ncnn_add_layer class)
if(NCNN_RUNTIME_CPU AND NCNN_AVXVNNI)
ncnn_add_arch_opt_source(${class} avxvnni "/arch:AVX2 /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__ /D__AVXVNNI__")
endif()
# With NCNN_RUNTIME_CPU enabled, request one extra arch-specific source per enabled
# extension for this layer class (suffixes avxvnniint8 / avxvnniint16 / avxneconvert).
# Only /arch:AVX2 is passed as a codegen switch here; the feature macros
# (__AVXVNNI__, __AVXVNNIINT8__, ...) are defined explicitly via /D.
# NOTE(review): exact behaviour of ncnn_add_arch_opt_source is defined elsewhere — confirm there.
if(NCNN_RUNTIME_CPU AND NCNN_AVXVNNIINT8)
ncnn_add_arch_opt_source(${class} avxvnniint8 "/arch:AVX2 /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__ /D__AVXVNNI__ /D__AVXVNNIINT8__")
endif()
if(NCNN_RUNTIME_CPU AND NCNN_AVXVNNIINT16)
ncnn_add_arch_opt_source(${class} avxvnniint16 "/arch:AVX2 /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__ /D__AVXVNNI__ /D__AVXVNNIINT16__")
endif()
# AVX-NE-CONVERT does not define __AVXVNNI__; it only adds __AVXNECONVERT__ to the AVX2 baseline.
if(NCNN_RUNTIME_CPU AND NCNN_AVXNECONVERT)
ncnn_add_arch_opt_source(${class} avxneconvert "/arch:AVX2 /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__ /D__AVXNECONVERT__")
endif()
if(NCNN_RUNTIME_CPU AND NCNN_AVX2)
ncnn_add_arch_opt_source(${class} avx2 "/arch:AVX2 /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__")
endif()
Expand Down Expand Up @@ -187,6 +196,15 @@ macro(ncnn_add_layer class)
if(NCNN_RUNTIME_CPU AND NCNN_AVXVNNI)
ncnn_add_arch_opt_source(${class} avxvnni "/arch:AVX2 -mfma -mf16c -mavxvnni /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__ /D__AVXVNNI__")
endif()
# With NCNN_RUNTIME_CPU enabled, request one extra arch-specific source per enabled
# extension for this layer class. This flavour passes /arch:AVX2 together with the
# GNU-style -m feature flags and still defines the feature macros explicitly via /D.
# NOTE(review): exact behaviour of ncnn_add_arch_opt_source is defined elsewhere — confirm there.
if(NCNN_RUNTIME_CPU AND NCNN_AVXVNNIINT8)
ncnn_add_arch_opt_source(${class} avxvnniint8 "/arch:AVX2 -mfma -mf16c -mavxvnni -mavxvnniint8 /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__ /D__AVXVNNI__ /D__AVXVNNIINT8__")
endif()
if(NCNN_RUNTIME_CPU AND NCNN_AVXVNNIINT16)
ncnn_add_arch_opt_source(${class} avxvnniint16 "/arch:AVX2 -mfma -mf16c -mavxvnni -mavxvnniint16 /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__ /D__AVXVNNI__ /D__AVXVNNIINT16__")
endif()
# AVX-NE-CONVERT is built without -mavxvnni and without __AVXVNNI__.
if(NCNN_RUNTIME_CPU AND NCNN_AVXNECONVERT)
ncnn_add_arch_opt_source(${class} avxneconvert "/arch:AVX2 -mfma -mf16c -mavxneconvert /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__ /D__AVXNECONVERT__")
endif()
if(NCNN_RUNTIME_CPU AND NCNN_AVX2)
ncnn_add_arch_opt_source(${class} avx2 "/arch:AVX2 -mfma -mf16c /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__")
endif()
Expand Down Expand Up @@ -218,6 +236,15 @@ macro(ncnn_add_layer class)
if(NCNN_RUNTIME_CPU AND NCNN_AVXVNNI)
ncnn_add_arch_opt_source(${class} avxvnni "-mavx2 -mfma -mf16c -mavxvnni")
endif()
# With NCNN_RUNTIME_CPU enabled, request one extra arch-specific source per enabled
# extension for this layer class, using GNU-style -m flags only (no explicit
# feature-macro defines are passed in this flavour).
# NOTE(review): exact behaviour of ncnn_add_arch_opt_source is defined elsewhere — confirm there.
if(NCNN_RUNTIME_CPU AND NCNN_AVXVNNIINT8)
ncnn_add_arch_opt_source(${class} avxvnniint8 "-mavx2 -mfma -mf16c -mavxvnni -mavxvnniint8")
endif()
if(NCNN_RUNTIME_CPU AND NCNN_AVXVNNIINT16)
ncnn_add_arch_opt_source(${class} avxvnniint16 "-mavx2 -mfma -mf16c -mavxvnni -mavxvnniint16")
endif()
# AVX-NE-CONVERT is built without -mavxvnni.
if(NCNN_RUNTIME_CPU AND NCNN_AVXNECONVERT)
ncnn_add_arch_opt_source(${class} avxneconvert "-mavx2 -mfma -mf16c -mavxneconvert")
endif()
if(NCNN_RUNTIME_CPU AND NCNN_AVX2)
ncnn_add_arch_opt_source(${class} avx2 "-mavx2 -mfma -mf16c")
endif()
Expand Down
27 changes: 27 additions & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,15 @@ if(NCNN_TARGET_ARCH STREQUAL "x86")
else()
target_compile_options(ncnn PRIVATE /arch:AVX /D__SSSE3__ /D__SSE4_1__ /D__FMA__)
endif()
# For each enabled extension option, add the matching feature macro to the ncnn
# target's PRIVATE compile options. Only /D defines are added here; no extra
# codegen switch is passed for these extensions in this flavour.
if(NCNN_AVXVNNIINT8)
target_compile_options(ncnn PRIVATE /D__AVXVNNIINT8__)
endif()
if(NCNN_AVXVNNIINT16)
target_compile_options(ncnn PRIVATE /D__AVXVNNIINT16__)
endif()
if(NCNN_AVXNECONVERT)
target_compile_options(ncnn PRIVATE /D__AVXNECONVERT__)
endif()
if(NCNN_AVXVNNI)
target_compile_options(ncnn PRIVATE /D__AVXVNNI__)
elseif(NCNN_XOP)
Expand All @@ -455,6 +464,15 @@ if(NCNN_TARGET_ARCH STREQUAL "x86")
else()
target_compile_options(ncnn PRIVATE /arch:AVX -mfma /D__SSSE3__ /D__SSE4_1__ /D__FMA__)
endif()
# For each enabled extension option, add both the GNU-style -m feature flag and the
# explicit /D feature macro to the ncnn target's PRIVATE compile options.
if(NCNN_AVXVNNIINT8)
target_compile_options(ncnn PRIVATE -mavxvnniint8 /D__AVXVNNIINT8__)
endif()
if(NCNN_AVXVNNIINT16)
target_compile_options(ncnn PRIVATE -mavxvnniint16 /D__AVXVNNIINT16__)
endif()
if(NCNN_AVXNECONVERT)
target_compile_options(ncnn PRIVATE -mavxneconvert /D__AVXNECONVERT__)
endif()
if(NCNN_AVXVNNI)
target_compile_options(ncnn PRIVATE -mavxvnni /D__AVXVNNI__)
elseif(NCNN_XOP)
Expand All @@ -469,6 +487,15 @@ if(NCNN_TARGET_ARCH STREQUAL "x86")
else()
target_compile_options(ncnn PRIVATE -mavx -mfma)
endif()
# For each enabled extension option, add the GNU-style -m feature flag to the ncnn
# target's PRIVATE compile options (no explicit macro defines in this flavour).
if(NCNN_AVXVNNIINT8)
target_compile_options(ncnn PRIVATE -mavxvnniint8)
endif()
if(NCNN_AVXVNNIINT16)
target_compile_options(ncnn PRIVATE -mavxvnniint16)
endif()
if(NCNN_AVXNECONVERT)
target_compile_options(ncnn PRIVATE -mavxneconvert)
endif()
if(NCNN_AVXVNNI)
target_compile_options(ncnn PRIVATE -mavxvnni)
elseif(NCNN_XOP)
Expand Down
102 changes: 102 additions & 0 deletions src/cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,9 @@ static int g_cpu_support_x86_xop;
static int g_cpu_support_x86_f16c;
static int g_cpu_support_x86_avx2;
static int g_cpu_support_x86_avx_vnni;
// Stored results of the AVX-VNNI-INT8 / AVX-VNNI-INT16 / AVX-NE-CONVERT runtime probes;
// assigned in initialize_global_cpu_info() and read by the cpu_support_x86_*() accessors.
static int g_cpu_support_x86_avx_vnni_int8;
static int g_cpu_support_x86_avx_vnni_int16;
static int g_cpu_support_x86_avx_ne_convert;
static int g_cpu_support_x86_avx512;
static int g_cpu_support_x86_avx512_vnni;
static int g_cpu_support_x86_avx512_bf16;
Expand Down Expand Up @@ -617,6 +620,72 @@ static int get_cpu_support_x86_avx_vnni()
return cpu_info[0] & (1u << 4);
}

// Runtime probe for the AVX-VNNI-INT8 extension.
// Returns a nonzero value (the raw masked feature bit) when every prerequisite
// holds and the feature bit is set, 0 otherwise.
static int get_cpu_support_x86_avx_vnni_int8()
{
    unsigned int regs[4] = {0};

    // leaf 0: word 0 is treated as the highest queryable leaf; leaf 7 is needed below
    x86_cpuid(0, regs);
    const int max_leaf = regs[0];
    if (max_leaf < 7)
        return 0;

    // leaf 1, word 2: XSAVE (bit 26), OSXSAVE (bit 27) and AVX (bit 28) must all be present
    x86_cpuid(1, regs);
    const unsigned int required_bits = (1u << 26) | (1u << 27) | (1u << 28);
    if ((regs[2] & required_bits) != required_bits)
        return 0;

    // XSAVE state must be enabled by the kernel: XCR0 bits 1 and 2 both set
    if ((x86_get_xcr0() & 6) != 6)
        return 0;

    // leaf 7, sub-leaf 1, word 3: bit 4 is the feature flag tested for AVX-VNNI-INT8
    x86_cpuid_sublevel(7, 1, regs);
    return regs[3] & (1u << 4);
}

// Runtime probe for the AVX-VNNI-INT16 extension.
// Returns a nonzero value (the raw masked feature bit) when every prerequisite
// holds and the feature bit is set, 0 otherwise.
static int get_cpu_support_x86_avx_vnni_int16()
{
    unsigned int regs[4] = {0};

    // leaf 0: word 0 is treated as the highest queryable leaf; leaf 7 is needed below
    x86_cpuid(0, regs);
    const int max_leaf = regs[0];
    if (max_leaf < 7)
        return 0;

    // leaf 1, word 2: XSAVE (bit 26), OSXSAVE (bit 27) and AVX (bit 28) must all be present
    x86_cpuid(1, regs);
    const unsigned int required_bits = (1u << 26) | (1u << 27) | (1u << 28);
    if ((regs[2] & required_bits) != required_bits)
        return 0;

    // XSAVE state must be enabled by the kernel: XCR0 bits 1 and 2 both set
    if ((x86_get_xcr0() & 6) != 6)
        return 0;

    // leaf 7, sub-leaf 1, word 3: bit 10 is the feature flag tested for AVX-VNNI-INT16
    x86_cpuid_sublevel(7, 1, regs);
    return regs[3] & (1u << 10);
}

// Runtime probe for the AVX-NE-CONVERT extension.
// Returns a nonzero value (the raw masked feature bit) when every prerequisite
// holds and the feature bit is set, 0 otherwise.
static int get_cpu_support_x86_avx_ne_convert()
{
    unsigned int regs[4] = {0};

    // leaf 0: word 0 is treated as the highest queryable leaf; leaf 7 is needed below
    x86_cpuid(0, regs);
    const int max_leaf = regs[0];
    if (max_leaf < 7)
        return 0;

    // leaf 1, word 2: XSAVE (bit 26), OSXSAVE (bit 27) and AVX (bit 28) must all be present
    x86_cpuid(1, regs);
    const unsigned int required_bits = (1u << 26) | (1u << 27) | (1u << 28);
    if ((regs[2] & required_bits) != required_bits)
        return 0;

    // XSAVE state must be enabled by the kernel: XCR0 bits 1 and 2 both set
    if ((x86_get_xcr0() & 6) != 6)
        return 0;

    // leaf 7, sub-leaf 1, word 3: bit 5 is the feature flag tested for AVX-NE-CONVERT
    x86_cpuid_sublevel(7, 1, regs);
    return regs[3] & (1u << 5);
}

static int get_cpu_support_x86_avx512()
{
#if __APPLE__
Expand Down Expand Up @@ -1967,6 +2036,9 @@ static void initialize_global_cpu_info()
g_cpu_support_x86_f16c = get_cpu_support_x86_f16c();
g_cpu_support_x86_avx2 = get_cpu_support_x86_avx2();
g_cpu_support_x86_avx_vnni = get_cpu_support_x86_avx_vnni();
g_cpu_support_x86_avx_vnni_int8 = get_cpu_support_x86_avx_vnni_int8();
g_cpu_support_x86_avx_vnni_int16 = get_cpu_support_x86_avx_vnni_int16();
g_cpu_support_x86_avx_ne_convert = get_cpu_support_x86_avx_ne_convert();
g_cpu_support_x86_avx512 = get_cpu_support_x86_avx512();
g_cpu_support_x86_avx512_vnni = get_cpu_support_x86_avx512_vnni();
g_cpu_support_x86_avx512_bf16 = get_cpu_support_x86_avx512_bf16();
Expand Down Expand Up @@ -2489,6 +2561,36 @@ int cpu_support_x86_avx_vnni()
#endif
}

// Public query for AVX-VNNI-INT8 support.
// Calls try_initialize_global_cpu_info() first, then returns the stored probe result
// g_cpu_support_x86_avx_vnni_int8 on x86 / x86_64 builds; every other target returns 0.
int cpu_support_x86_avx_vnni_int8()
{
try_initialize_global_cpu_info();
#if defined(__i386__) || defined(__x86_64__) || defined(_M_IX86) || defined(_M_X64)
return g_cpu_support_x86_avx_vnni_int8;
#else
return 0;
#endif
}

// Public query for AVX-VNNI-INT16 support.
// Calls try_initialize_global_cpu_info() first, then returns the stored probe result
// g_cpu_support_x86_avx_vnni_int16 on x86 / x86_64 builds; every other target returns 0.
int cpu_support_x86_avx_vnni_int16()
{
try_initialize_global_cpu_info();
#if defined(__i386__) || defined(__x86_64__) || defined(_M_IX86) || defined(_M_X64)
return g_cpu_support_x86_avx_vnni_int16;
#else
return 0;
#endif
}

// Public query for AVX-NE-CONVERT support.
// Calls try_initialize_global_cpu_info() first, then returns the stored probe result
// g_cpu_support_x86_avx_ne_convert on x86 / x86_64 builds; every other target returns 0.
int cpu_support_x86_avx_ne_convert()
{
try_initialize_global_cpu_info();
#if defined(__i386__) || defined(__x86_64__) || defined(_M_IX86) || defined(_M_X64)
return g_cpu_support_x86_avx_ne_convert;
#else
return 0;
#endif
}

int cpu_support_x86_avx512()
{
try_initialize_global_cpu_info();
Expand Down
6 changes: 6 additions & 0 deletions src/cpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,12 @@ NCNN_EXPORT int cpu_support_x86_f16c();
NCNN_EXPORT int cpu_support_x86_avx2();
// avx_vnni = x86 avx vnni
NCNN_EXPORT int cpu_support_x86_avx_vnni();
// avx_vnni_int8 = x86 avx vnni int8 (nonzero when detected at runtime, always 0 on non-x86 builds)
NCNN_EXPORT int cpu_support_x86_avx_vnni_int8();
// avx_vnni_int16 = x86 avx vnni int16 (nonzero when detected at runtime, always 0 on non-x86 builds)
NCNN_EXPORT int cpu_support_x86_avx_vnni_int16();
// avx_ne_convert = x86 avx ne convert (nonzero when detected at runtime, always 0 on non-x86 builds)
NCNN_EXPORT int cpu_support_x86_avx_ne_convert();
// avx512 = x86 avx512f + avx512cd + avx512bw + avx512dq + avx512vl
NCNN_EXPORT int cpu_support_x86_avx512();
// avx512_vnni = x86 avx512 vnni
Expand Down
3 changes: 3 additions & 0 deletions src/platform.h.in
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@
#cmakedefine01 NCNN_F16C
#cmakedefine01 NCNN_AVX2
#cmakedefine01 NCNN_AVXVNNI
#cmakedefine01 NCNN_AVXVNNIINT8
#cmakedefine01 NCNN_AVXVNNIINT16
#cmakedefine01 NCNN_AVXNECONVERT
#cmakedefine01 NCNN_AVX512
#cmakedefine01 NCNN_AVX512VNNI
#cmakedefine01 NCNN_AVX512BF16
Expand Down
Loading