From faa79984ad0e7c918a34e512448e90892535711c Mon Sep 17 00:00:00 2001 From: Andrey Semashev Date: Mon, 21 Sep 2020 14:40:49 +0300 Subject: [PATCH] Added x86 SIMD optimizations to crypto datatypes. - The v128 operations are optimized for SSE2/SSSE3. - srtp_octet_string_is_eq is optimized for SSE2. When SSE2 is not available, use a pair of 32-bit accumulators to speed up the bulk of the operation. We use two accumulators to leverage instruction-level parallelism supported by most modern CPUs. - In srtp_cleanse, use memset and ensure it is not optimized away with a dummy asm statement, which can potentially consume the contents of the memory. - Endian conversion functions use gcc-style intrinsics, when possible. The SIMD code uses intrinsics, which are available on all modern compilers. For MSVC, config_in_cmake.h is modified to define gcc/clang-style SSE macros based on MSVC predefined macros. We enable all SSE versions when it indicates that AVX is enabled. SSE2 is always enabled for x86-64 or for x86 when SSE2 FP math is enabled. --- config_in_cmake.h | 11 ++ crypto/include/datatypes.h | 36 +++++- crypto/math/datatypes.c | 219 ++++++++++++++++++++++++++++++++++++- 3 files changed, 258 insertions(+), 8 deletions(-) diff --git a/config_in_cmake.h b/config_in_cmake.h index b884cb377..4a198dcfd 100644 --- a/config_in_cmake.h +++ b/config_in_cmake.h @@ -122,3 +122,14 @@ #define inline #endif #endif + +/* Define gcc/clang-style SSE macros on compilers that don't define them (primarily, MSVC). 
*/ +#if !defined(__SSE2__) && (defined(_M_X64) || (defined(_M_IX86_FP) && _M_IX86_FP >= 2)) +#define __SSE2__ +#endif +#if !defined(__SSSE3__) && defined(__AVX__) +#define __SSSE3__ +#endif +#if !defined(__SSE4_1__) && defined(__AVX__) +#define __SSE4_1__ +#endif diff --git a/crypto/include/datatypes.h b/crypto/include/datatypes.h index 98042a379..4f25c990d 100644 --- a/crypto/include/datatypes.h +++ b/crypto/include/datatypes.h @@ -62,6 +62,10 @@ #error "Platform not recognized" #endif +#if defined(__SSE2__) +#include <emmintrin.h> +#endif + #ifdef __cplusplus extern "C" { #endif @@ -90,6 +94,26 @@ void v128_left_shift(v128_t *x, int shift_index); * */ +#if defined(__SSE2__) + +#define v128_set_to_zero(x) \ + (_mm_storeu_si128((__m128i *)(x), _mm_setzero_si128())) + +#define v128_copy(x, y) \ + (_mm_storeu_si128((__m128i *)(x), _mm_loadu_si128((const __m128i *)(y)))) + +#define v128_xor(z, x, y) \ + (_mm_storeu_si128((__m128i *)(z), \ + _mm_xor_si128(_mm_loadu_si128((const __m128i *)(x)), \ + _mm_loadu_si128((const __m128i *)(y))))) + +#define v128_xor_eq(z, x) \ + (_mm_storeu_si128((__m128i *)(z), \ + _mm_xor_si128(_mm_loadu_si128((const __m128i *)(x)), \ + _mm_loadu_si128((const __m128i *)(z))))) + +#else /* defined(__SSE2__) */ + #define v128_set_to_zero(x) \ ((x)->v32[0] = 0, (x)->v32[1] = 0, (x)->v32[2] = 0, (x)->v32[3] = 0) @@ -113,6 +137,8 @@ void v128_left_shift(v128_t *x, int shift_index); ((z)->v64[0] ^= (x)->v64[0], (z)->v64[1] ^= (x)->v64[1]) #endif +#endif /* defined(__SSE2__) */ + /* NOTE! This assumes an odd ordering! */ /* This will not be compatible directly with math on some processors */ /* bit 0 is first 32-bit word, low order bit. in little-endian, that's @@ -173,13 +199,11 @@ void octet_string_set_to_zero(void *s, size_t len); #define be64_to_cpu(x) OSSwapInt64(x) #else /* WORDS_BIGENDIAN */ -#if defined(__GNUC__) && (defined(HAVE_X86) || defined(__x86_64__)) +#if defined(__GNUC__) /* Fall back. 
*/ static inline uint32_t be32_to_cpu(uint32_t v) { - /* optimized for x86. */ - asm("bswap %0" : "=r"(v) : "0"(v)); - return v; + return __builtin_bswap32(v); } #else /* HAVE_X86 */ #ifdef HAVE_NETINET_IN_H @@ -192,7 +216,9 @@ static inline uint32_t be32_to_cpu(uint32_t v) static inline uint64_t be64_to_cpu(uint64_t v) { -#ifdef NO_64BIT_MATH +#if defined(__GNUC__) + v = __builtin_bswap64(v); +#elif defined(NO_64BIT_MATH) /* use the make64 functions to do 64-bit math */ v = make64(htonl(low32(v)), htonl(high32(v))); #else /* NO_64BIT_MATH */ diff --git a/crypto/math/datatypes.c b/crypto/math/datatypes.c index 372f8d188..cfadc47c8 100644 --- a/crypto/math/datatypes.c +++ b/crypto/math/datatypes.c @@ -53,6 +53,16 @@ #include "datatypes.h" +#if defined(__SSE2__) +#include <tmmintrin.h> +#endif + +#if defined(_MSC_VER) +#define ALIGNMENT(N) __declspec(align(N)) +#else +#define ALIGNMENT(N) __attribute__((aligned(N))) +#endif + /* * bit_string is a buffer that is used to hold output strings, e.g. * for printing. 
@@ -123,6 +133,9 @@ char *v128_bit_string(v128_t *x) void v128_copy_octet_string(v128_t *x, const uint8_t s[16]) { +#if defined(__SSE2__) + _mm_storeu_si128((__m128i *)(x), _mm_loadu_si128((const __m128i *)(s))); +#else #ifdef ALIGNMENT_32BIT_REQUIRED if ((((uint32_t)&s[0]) & 0x3) != 0) #endif @@ -151,8 +164,67 @@ void v128_copy_octet_string(v128_t *x, const uint8_t s[16]) v128_copy(x, v); } #endif +#endif /* defined(__SSE2__) */ } +#if defined(__SSSE3__) + +/* clang-format off */ + +ALIGNMENT(16) +static const uint8_t right_shift_masks[5][16] = { + { 0u, 1u, 2u, 3u, 4u, 5u, 6u, 7u, + 8u, 9u, 10u, 11u, 12u, 13u, 14u, 15u }, + { 0x80, 0x80, 0x80, 0x80, 0u, 1u, 2u, 3u, + 4u, 5u, 6u, 7u, 8u, 9u, 10u, 11u }, + { 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0u, 1u, 2u, 3u, 4u, 5u, 6u, 7u }, + { 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0u, 1u, 2u, 3u }, + /* needed for bitvector_left_shift */ + { 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80 } +}; + +ALIGNMENT(16) +static const uint8_t left_shift_masks[4][16] = { + { 0u, 1u, 2u, 3u, 4u, 5u, 6u, 7u, + 8u, 9u, 10u, 11u, 12u, 13u, 14u, 15u }, + { 4u, 5u, 6u, 7u, 8u, 9u, 10u, 11u, + 12u, 13u, 14u, 15u, 0x80, 0x80, 0x80, 0x80 }, + { 8u, 9u, 10u, 11u, 12u, 13u, 14u, 15u, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80 }, + { 12u, 13u, 14u, 15u, 0x80, 0x80, 0x80, 0x80, + 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80 } +}; + +/* clang-format on */ + +void v128_left_shift(v128_t *x, int shift) +{ + if (shift > 127) { + v128_set_to_zero(x); + return; + } + + const int base_index = shift >> 5; + const int bit_index = shift & 31; + + __m128i mm = _mm_loadu_si128((const __m128i *)x); + __m128i mm_shift_right = _mm_cvtsi32_si128(bit_index); + __m128i mm_shift_left = _mm_cvtsi32_si128(32 - bit_index); + mm = _mm_shuffle_epi8(mm, ((const __m128i *)left_shift_masks)[base_index]); + + __m128i mm1 = _mm_srl_epi32(mm, mm_shift_right); + __m128i mm2 = 
_mm_sll_epi32(mm, mm_shift_left); + mm2 = _mm_srli_si128(mm2, 4); + mm1 = _mm_or_si128(mm1, mm2); + + _mm_storeu_si128((__m128i *)x, mm1); +} + +#else /* defined(__SSSE3__) */ + void v128_left_shift(v128_t *x, int shift) { int i; @@ -179,6 +251,8 @@ void v128_left_shift(v128_t *x, int shift) x->v32[i] = 0; } +#endif /* defined(__SSSE3__) */ + /* functions manipulating bitvector_t */ int bitvector_alloc(bitvector_t *v, unsigned long length) @@ -190,6 +264,7 @@ int bitvector_alloc(bitvector_t *v, unsigned long length) (length + bits_per_word - 1) & ~(unsigned long)((bits_per_word - 1)); l = length / bits_per_word * bytes_per_word; + l = (l + 15ul) & ~15ul; /* allocate memory, then set parameters */ if (l == 0) { @@ -225,6 +300,73 @@ void bitvector_set_to_zero(bitvector_t *x) memset(x->word, 0, x->length >> 3); } +#if defined(__SSSE3__) + +void bitvector_left_shift(bitvector_t *x, int shift) +{ + if ((uint32_t)shift >= x->length) { + bitvector_set_to_zero(x); + return; + } + + const int base_index = shift >> 5; + const int bit_index = shift & 31; + const int vec_length = (x->length + 127u) >> 7; + const __m128i *from = ((const __m128i *)x->word) + (base_index >> 2); + __m128i *to = (__m128i *)x->word; + __m128i *const end = to + vec_length; + + __m128i mm_right_shift_mask = + ((const __m128i *)right_shift_masks)[4u - (base_index & 3u)]; + __m128i mm_left_shift_mask = + ((const __m128i *)left_shift_masks)[base_index & 3u]; + __m128i mm_shift_right = _mm_cvtsi32_si128(bit_index); + __m128i mm_shift_left = _mm_cvtsi32_si128(32 - bit_index); + + __m128i mm_current = _mm_loadu_si128(from); + __m128i mm_current_r = _mm_srl_epi32(mm_current, mm_shift_right); + __m128i mm_current_l = _mm_sll_epi32(mm_current, mm_shift_left); + + while ((end - from) >= 2) { + ++from; + __m128i mm_next = _mm_loadu_si128(from); + + __m128i mm_next_r = _mm_srl_epi32(mm_next, mm_shift_right); + __m128i mm_next_l = _mm_sll_epi32(mm_next, mm_shift_left); + mm_current_l = _mm_alignr_epi8(mm_next_l, 
mm_current_l, 4); + mm_current = _mm_or_si128(mm_current_r, mm_current_l); + + mm_current = _mm_shuffle_epi8(mm_current, mm_left_shift_mask); + + __m128i mm_temp_next = _mm_srli_si128(mm_next_l, 4); + mm_temp_next = _mm_or_si128(mm_next_r, mm_temp_next); + + mm_temp_next = _mm_shuffle_epi8(mm_temp_next, mm_right_shift_mask); + mm_current = _mm_or_si128(mm_temp_next, mm_current); + + _mm_storeu_si128(to, mm_current); + ++to; + + mm_current_r = mm_next_r; + mm_current_l = mm_next_l; + } + + mm_current_l = _mm_srli_si128(mm_current_l, 4); + mm_current = _mm_or_si128(mm_current_r, mm_current_l); + + mm_current = _mm_shuffle_epi8(mm_current, mm_left_shift_mask); + + _mm_storeu_si128(to, mm_current); + ++to; + + while (to < end) { + _mm_storeu_si128(to, _mm_setzero_si128()); + ++to; + } +} + +#else /* defined(__SSSE3__) */ + void bitvector_left_shift(bitvector_t *x, int shift) { int i; @@ -253,16 +395,82 @@ void bitvector_left_shift(bitvector_t *x, int shift) x->word[i] = 0; } +#endif /* defined(__SSSE3__) */ + int srtp_octet_string_is_eq(uint8_t *a, uint8_t *b, int len) { - uint8_t *end = b + len; - uint8_t accumulator = 0; - /* * We use this somewhat obscure implementation to try to ensure the running * time only depends on len, even accounting for compiler optimizations. * The accumulator ends up zero iff the strings are equal. 
*/ + uint8_t *end = b + len; + uint32_t accumulator = 0; + +#if defined(__SSE2__) + __m128i mm_accumulator1 = _mm_setzero_si128(); + __m128i mm_accumulator2 = _mm_setzero_si128(); + for (int i = 0, n = len >> 5; i < n; ++i, a += 32, b += 32) { + __m128i mm_a1 = _mm_loadu_si128((const __m128i *)a); + __m128i mm_b1 = _mm_loadu_si128((const __m128i *)b); + __m128i mm_a2 = _mm_loadu_si128((const __m128i *)(a + 16)); + __m128i mm_b2 = _mm_loadu_si128((const __m128i *)(b + 16)); + mm_a1 = _mm_xor_si128(mm_a1, mm_b1); + mm_a2 = _mm_xor_si128(mm_a2, mm_b2); + mm_accumulator1 = _mm_or_si128(mm_accumulator1, mm_a1); + mm_accumulator2 = _mm_or_si128(mm_accumulator2, mm_a2); + } + + mm_accumulator1 = _mm_or_si128(mm_accumulator1, mm_accumulator2); + + if ((end - b) >= 16) { + __m128i mm_a1 = _mm_loadu_si128((const __m128i *)a); + __m128i mm_b1 = _mm_loadu_si128((const __m128i *)b); + mm_a1 = _mm_xor_si128(mm_a1, mm_b1); + mm_accumulator1 = _mm_or_si128(mm_accumulator1, mm_a1); + a += 16; + b += 16; + } + + if ((end - b) >= 8) { + __m128i mm_a1 = _mm_loadl_epi64((const __m128i *)a); + __m128i mm_b1 = _mm_loadl_epi64((const __m128i *)b); + mm_a1 = _mm_xor_si128(mm_a1, mm_b1); + mm_accumulator1 = _mm_or_si128(mm_accumulator1, mm_a1); + a += 8; + b += 8; + } + + mm_accumulator1 = _mm_or_si128( + mm_accumulator1, _mm_unpackhi_epi64(mm_accumulator1, mm_accumulator1)); + mm_accumulator1 = + _mm_or_si128(mm_accumulator1, _mm_srli_si128(mm_accumulator1, 4)); + accumulator = _mm_cvtsi128_si32(mm_accumulator1); +#else + uint32_t accumulator2 = 0; + for (int i = 0, n = len >> 3; i < n; ++i, a += 8, b += 8) { + uint32_t a_val1, b_val1; + uint32_t a_val2, b_val2; + memcpy(&a_val1, a, sizeof(a_val1)); + memcpy(&b_val1, b, sizeof(b_val1)); + memcpy(&a_val2, a + 4, sizeof(a_val2)); + memcpy(&b_val2, b + 4, sizeof(b_val2)); + accumulator |= a_val1 ^ b_val1; + accumulator2 |= a_val2 ^ b_val2; + } + + accumulator |= accumulator2; + + if ((end - b) >= 4) { + uint32_t a_val, b_val; + memcpy(&a_val, 
a, sizeof(a_val)); + memcpy(&b_val, b, sizeof(b_val)); + accumulator |= a_val ^ b_val; + a += 4; + b += 4; + } +#endif + while (b < end) accumulator |= (*a++ ^ *b++); @@ -272,9 +480,14 @@ int srtp_octet_string_is_eq(uint8_t *a, uint8_t *b, int len) void srtp_cleanse(void *s, size_t len) { +#if defined(__GNUC__) + memset(s, 0, len); + __asm__ __volatile__("" : : "r"(s) : "memory"); +#else volatile unsigned char *p = (volatile unsigned char *)s; while (len--) *p++ = 0; +#endif } void octet_string_set_to_zero(void *s, size_t len)