
Added x86 SIMD optimizations to crypto datatypes.
- The v128 operations are optimized for SSE2/SSSE3.
- srtp_octet_string_is_eq is optimized for SSE2. When SSE2 is not
  available, a pair of 32-bit accumulators speeds up the bulk of the
  operation; using two independent accumulators exposes the
  instruction-level parallelism that most modern CPUs provide (see the
  sketch after this list).
- In srtp_cleanse, use memset and keep it from being optimized away by
  following it with a dummy asm statement that tells the compiler the
  memory contents may still be consumed.
- Endian conversion functions use gcc-style intrinsics when possible.
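
For reference, a minimal sketch of the two-accumulator idea from the second
item (the helper name and loop structure are illustrative only; the committed
code is in crypto/math/datatypes.c below):

#include <stddef.h>
#include <stdint.h>
#include <string.h>

/* Illustrative sketch, not the committed code: OR the XOR of two buffers into
 * two independent accumulators. The two OR chains have no data dependency on
 * each other, so a superscalar CPU can execute both per iteration instead of
 * serializing on a single accumulator. A zero result means the buffers match. */
static uint32_t or_reduce_diff(const uint8_t *a, const uint8_t *b, size_t len)
{
    uint32_t acc1 = 0, acc2 = 0;
    size_t i = 0;

    for (; i + 8 <= len; i += 8) {
        uint32_t a1, b1, a2, b2;
        memcpy(&a1, a + i, 4);
        memcpy(&b1, b + i, 4);
        memcpy(&a2, a + i + 4, 4);
        memcpy(&b2, b + i + 4, 4);
        acc1 |= a1 ^ b1; /* dependency chain 1 */
        acc2 |= a2 ^ b2; /* dependency chain 2, independent of chain 1 */
    }
    for (; i < len; ++i)
        acc1 |= (uint32_t)(a[i] ^ b[i]);

    return acc1 | acc2;
}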

The SIMD code uses intrinsics, which are available on all modern compilers.
For MSVC, config_in_cmake.h is modified to define gcc/clang-style SSE macros
based on MSVC's predefined macros: all SSE versions are enabled when MSVC
indicates that AVX is enabled, and SSE2 is always enabled for x86-64, or for
x86 when SSE2 floating-point math is enabled.
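
A small, hypothetical way to verify that mapping (not part of the commit):
compile a probe like the one below with MSVC for x64, or for x86 with
/arch:SSE2 or /arch:AVX, and check which messages are printed.

/* sse_macro_probe.c -- hypothetical probe mirroring the config_in_cmake.h
 * logic added below; it relies only on MSVC's documented predefined macros. */
#if defined(_M_X64) || (defined(_M_IX86_FP) && _M_IX86_FP >= 2)
#pragma message("SSE2 assumed: __SSE2__ will be defined, SIMD paths enabled")
#else
#pragma message("SSE2 not assumed: scalar fallbacks will be used")
#endif

#if defined(__AVX__)
#pragma message("/arch:AVX (or higher): SSSE3 and SSE4.1 paths also enabled")
#endif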
Lastique committed Feb 28, 2023
1 parent 536a367 commit 52d61dd
Showing 3 changed files with 249 additions and 8 deletions.
11 changes: 11 additions & 0 deletions config_in_cmake.h
@@ -122,3 +122,14 @@
#define inline
#endif
#endif

/* Define gcc/clang-style SSE macros on compilers that don't define them (primarily MSVC). */
#if !defined(__SSE2__) && (defined(_M_X64) || (defined(_M_IX86_FP) && _M_IX86_FP >= 2))
#define __SSE2__
#endif
#if !defined(__SSSE3__) && defined(__AVX__)
#define __SSSE3__
#endif
#if !defined(__SSE4_1__) && defined(__AVX__)
#define __SSE4_1__
#endif
36 changes: 31 additions & 5 deletions crypto/include/datatypes.h
@@ -62,6 +62,10 @@
#error "Platform not recognized"
#endif

#if defined(__SSE2__)
#include <emmintrin.h>
#endif

#ifdef __cplusplus
extern "C" {
#endif
@@ -90,6 +94,26 @@ void v128_left_shift(v128_t *x, int shift_index);
*
*/

#if defined(__SSE2__)

#define v128_set_to_zero(x) \
(_mm_storeu_si128((__m128i *)(x), _mm_setzero_si128()))

#define v128_copy(x, y) \
(_mm_storeu_si128((__m128i *)(x), _mm_loadu_si128((const __m128i *)(y))))

#define v128_xor(z, x, y) \
(_mm_storeu_si128((__m128i *)(z), \
_mm_xor_si128(_mm_loadu_si128((const __m128i *)(x)), \
_mm_loadu_si128((const __m128i *)(y)))))

#define v128_xor_eq(z, x) \
(_mm_storeu_si128((__m128i *)(z), \
_mm_xor_si128(_mm_loadu_si128((const __m128i *)(x)), \
_mm_loadu_si128((const __m128i *)(z)))))

#else /* defined(__SSE2__) */

#define v128_set_to_zero(x) \
((x)->v32[0] = 0, (x)->v32[1] = 0, (x)->v32[2] = 0, (x)->v32[3] = 0)

@@ -113,6 +137,8 @@ void v128_left_shift(v128_t *x, int shift_index);
((z)->v64[0] ^= (x)->v64[0], (z)->v64[1] ^= (x)->v64[1])
#endif

#endif /* defined(__SSE2__) */

/* NOTE! This assumes an odd ordering! */
/* This will not be compatible directly with math on some processors */
/* bit 0 is first 32-bit word, low order bit. in little-endian, that's
@@ -173,13 +199,11 @@ void octet_string_set_to_zero(void *s, size_t len);
#define be64_to_cpu(x) OSSwapInt64(x)
#else /* WORDS_BIGENDIAN */

#if defined(__GNUC__) && (defined(HAVE_X86) || defined(__x86_64__))
#if defined(__GNUC__)
/* Fall back. */
static inline uint32_t be32_to_cpu(uint32_t v)
{
/* optimized for x86. */
asm("bswap %0" : "=r"(v) : "0"(v));
return v;
return __builtin_bswap32(v);
}
#else /* HAVE_X86 */
#ifdef HAVE_NETINET_IN_H
@@ -192,7 +216,9 @@ static inline uint32_t be32_to_cpu(uint32_t v)

static inline uint64_t be64_to_cpu(uint64_t v)
{
#ifdef NO_64BIT_MATH
#if defined(__GNUC__)
v = __builtin_bswap64(v);
#elif defined(NO_64BIT_MATH)
/* use the make64 functions to do 64-bit math */
v = make64(htonl(low32(v)), htonl(high32(v)));
#else /* NO_64BIT_MATH */
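
A minimal usage sketch of the header above (assumptions: crypto/include is on
the include path and the program is linked against crypto/math/datatypes.c;
the SSE2 and scalar variants of the macros are interchangeable from the
caller's point of view):

#include <stdio.h>
#include <string.h>
#include "datatypes.h"

int main(void)
{
    uint8_t bytes[16];
    v128_t a, b, c, zero;

    for (int i = 0; i < 16; ++i)
        bytes[i] = (uint8_t)i;

    v128_copy_octet_string(&a, bytes); /* unaligned load on the SSE2 path */
    v128_copy(&b, &a);
    v128_xor(&c, &a, &b);              /* a ^ a == 0 */
    v128_set_to_zero(&zero);

    printf("xor of equal vectors is zero: %d\n",
           memcmp(&c, &zero, sizeof(c)) == 0);
    printf("be32_to_cpu(0x01020304) = 0x%08x\n",
           (unsigned)be32_to_cpu(0x01020304u));
    return 0;
}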
210 changes: 207 additions & 3 deletions crypto/math/datatypes.c
@@ -53,6 +53,16 @@

#include "datatypes.h"

#if defined(__SSE2__)
#include <tmmintrin.h>
#endif

#if defined(_MSC_VER)
#define ALIGNMENT(N) __declspec(align(N))
#else
#define ALIGNMENT(N) __attribute__((aligned(N)))
#endif

/*
* bit_string is a buffer that is used to hold output strings, e.g.
* for printing.
@@ -123,6 +133,9 @@ char *v128_bit_string(v128_t *x)

void v128_copy_octet_string(v128_t *x, const uint8_t s[16])
{
#if defined(__SSE2__)
_mm_storeu_si128((__m128i *)(x), _mm_loadu_si128((const __m128i *)(s)));
#else
#ifdef ALIGNMENT_32BIT_REQUIRED
if ((((uint32_t)&s[0]) & 0x3) != 0)
#endif
@@ -151,8 +164,67 @@ void v128_copy_octet_string(v128_t *x, const uint8_t s[16])
v128_copy(x, v);
}
#endif
#endif /* defined(__SSE2__) */
}

#if defined(__SSSE3__)

/* clang-format off */

ALIGNMENT(16)
static const uint8_t right_shift_masks[5][16] = {
{ 0u, 1u, 2u, 3u, 4u, 5u, 6u, 7u,
8u, 9u, 10u, 11u, 12u, 13u, 14u, 15u },
{ 0x80, 0x80, 0x80, 0x80, 0u, 1u, 2u, 3u,
4u, 5u, 6u, 7u, 8u, 9u, 10u, 11u },
{ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80,
0u, 1u, 2u, 3u, 4u, 5u, 6u, 7u },
{ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80,
0x80, 0x80, 0x80, 0x80, 0u, 1u, 2u, 3u },
/* needed for bitvector_left_shift */
{ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80,
0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80 }
};

ALIGNMENT(16)
static const uint8_t left_shift_masks[4][16] = {
{ 0u, 1u, 2u, 3u, 4u, 5u, 6u, 7u,
8u, 9u, 10u, 11u, 12u, 13u, 14u, 15u },
{ 4u, 5u, 6u, 7u, 8u, 9u, 10u, 11u,
12u, 13u, 14u, 15u, 0x80, 0x80, 0x80, 0x80 },
{ 8u, 9u, 10u, 11u, 12u, 13u, 14u, 15u,
0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80 },
{ 12u, 13u, 14u, 15u, 0x80, 0x80, 0x80, 0x80,
0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80 }
};

/* clang-format on */

void v128_left_shift(v128_t *x, int shift)
{
if (shift > 127) {
v128_set_to_zero(x);
return;
}

const int base_index = shift >> 5;
const int bit_index = shift & 31;

__m128i mm = _mm_loadu_si128((const __m128i *)x);
__m128i mm_shift_right = _mm_cvtsi32_si128(bit_index);
__m128i mm_shift_left = _mm_cvtsi32_si128(32 - bit_index);
mm = _mm_shuffle_epi8(mm, ((const __m128i *)left_shift_masks)[base_index]);

__m128i mm1 = _mm_srl_epi32(mm, mm_shift_right);
__m128i mm2 = _mm_sll_epi32(mm, mm_shift_left);
mm2 = _mm_srli_si128(mm2, 4);
mm1 = _mm_or_si128(mm1, mm2);

_mm_storeu_si128((__m128i *)x, mm1);
}

#else /* defined(__SSSE3__) */

void v128_left_shift(v128_t *x, int shift)
{
int i;
@@ -179,6 +251,8 @@ void v128_left_shift(v128_t *x, int shift)
x->v32[i] = 0;
}

#endif /* defined(__SSSE3__) */

/* functions manipulating bitvector_t */

int bitvector_alloc(bitvector_t *v, unsigned long length)
@@ -190,6 +264,7 @@ int bitvector_alloc(bitvector_t *v, unsigned long length)
(length + bits_per_word - 1) & ~(unsigned long)((bits_per_word - 1));

l = length / bits_per_word * bytes_per_word;
l = (l + 15ul) & ~15ul;

/* allocate memory, then set parameters */
if (l == 0) {
@@ -225,6 +300,73 @@ void bitvector_set_to_zero(bitvector_t *x)
memset(x->word, 0, x->length >> 3);
}

#if defined(__SSSE3__)

void bitvector_left_shift(bitvector_t *x, int shift)
{
if ((uint32_t)shift >= x->length) {
bitvector_set_to_zero(x);
return;
}

const int base_index = shift >> 5;
const int bit_index = shift & 31;
const int vec_length = (x->length + 127u) >> 7;
const __m128i *from = ((const __m128i *)x->word) + (base_index >> 2);
__m128i *to = (__m128i *)x->word;
__m128i *const end = to + vec_length;

__m128i mm_right_shift_mask =
((const __m128i *)right_shift_masks)[4u - (base_index & 3u)];
__m128i mm_left_shift_mask =
((const __m128i *)left_shift_masks)[base_index & 3u];
__m128i mm_shift_right = _mm_cvtsi32_si128(bit_index);
__m128i mm_shift_left = _mm_cvtsi32_si128(32 - bit_index);

__m128i mm_current = _mm_loadu_si128(from);
__m128i mm_current_r = _mm_srl_epi32(mm_current, mm_shift_right);
__m128i mm_current_l = _mm_sll_epi32(mm_current, mm_shift_left);

while ((end - from) >= 2) {
++from;
__m128i mm_next = _mm_loadu_si128(from);

__m128i mm_next_r = _mm_srl_epi32(mm_next, mm_shift_right);
__m128i mm_next_l = _mm_sll_epi32(mm_next, mm_shift_left);
mm_current_l = _mm_alignr_epi8(mm_next_l, mm_current_l, 4);
mm_current = _mm_or_si128(mm_current_r, mm_current_l);

mm_current = _mm_shuffle_epi8(mm_current, mm_left_shift_mask);

__m128i mm_temp_next = _mm_srli_si128(mm_next_l, 4);
mm_temp_next = _mm_or_si128(mm_next_r, mm_temp_next);

mm_temp_next = _mm_shuffle_epi8(mm_temp_next, mm_right_shift_mask);
mm_current = _mm_or_si128(mm_temp_next, mm_current);

_mm_storeu_si128(to, mm_current);
++to;

mm_current_r = mm_next_r;
mm_current_l = mm_next_l;
}

mm_current_l = _mm_srli_si128(mm_current_l, 4);
mm_current = _mm_or_si128(mm_current_r, mm_current_l);

mm_current = _mm_shuffle_epi8(mm_current, mm_left_shift_mask);

_mm_storeu_si128(to, mm_current);
++to;

while (to < end) {
_mm_storeu_si128(to, _mm_setzero_si128());
++to;
}
}

#else /* defined(__SSSE3__) */

void bitvector_left_shift(bitvector_t *x, int shift)
{
int i;
@@ -253,16 +395,73 @@ void bitvector_left_shift(bitvector_t *x, int shift)
x->word[i] = 0;
}

#endif /* defined(__SSSE3__) */

int srtp_octet_string_is_eq(uint8_t *a, uint8_t *b, int len)
{
uint8_t *end = b + len;
uint8_t accumulator = 0;

/*
* We use this somewhat obscure implementation to try to ensure the running
* time only depends on len, even accounting for compiler optimizations.
* The accumulator ends up zero iff the strings are equal.
*/
uint8_t *end = b + len;
uint32_t accumulator = 0;

#if defined(__SSE2__)
__m128i mm_accumulator1 = _mm_setzero_si128();
__m128i mm_accumulator2 = _mm_setzero_si128();
for (int i = 0, n = len >> 5; i < n; ++i, a += 32, b += 32) {
__m128i mm_a1 = _mm_loadu_si128((const __m128i *)a);
__m128i mm_b1 = _mm_loadu_si128((const __m128i *)b);
__m128i mm_a2 = _mm_loadu_si128((const __m128i *)(a + 16));
__m128i mm_b2 = _mm_loadu_si128((const __m128i *)(b + 16));
mm_a1 = _mm_xor_si128(mm_a1, mm_b1);
mm_a2 = _mm_xor_si128(mm_a2, mm_b2);
mm_accumulator1 = _mm_or_si128(mm_accumulator1, mm_a1);
mm_accumulator2 = _mm_or_si128(mm_accumulator2, mm_a2);
}

mm_accumulator1 = _mm_or_si128(mm_accumulator1, mm_accumulator2);

if ((end - b) >= 16) {
__m128i mm_a1 = _mm_loadu_si128((const __m128i *)a);
__m128i mm_b1 = _mm_loadu_si128((const __m128i *)b);
mm_a1 = _mm_xor_si128(mm_a1, mm_b1);
mm_accumulator1 = _mm_or_si128(mm_accumulator1, mm_a1);
a += 16;
b += 16;
}

mm_accumulator1 = _mm_or_si128(
mm_accumulator1, _mm_unpackhi_epi64(mm_accumulator1, mm_accumulator1));
mm_accumulator1 =
_mm_or_si128(mm_accumulator1, _mm_srli_si128(mm_accumulator1, 4));
accumulator = _mm_cvtsi128_si32(mm_accumulator1);
#else
uint32_t accumulator2 = 0;
for (int i = 0, n = len >> 3; i < n; ++i, a += 8, b += 8) {
uint32_t a_val1, b_val1;
uint32_t a_val2, b_val2;
memcpy(&a_val1, a, sizeof(a_val1));
memcpy(&b_val1, b, sizeof(b_val1));
memcpy(&a_val2, a + 4, sizeof(a_val2));
memcpy(&b_val2, b + 4, sizeof(b_val2));
accumulator |= a_val1 ^ b_val1;
accumulator2 |= a_val2 ^ b_val2;
}

accumulator |= accumulator2;

if ((end - b) >= 4) {
uint32_t a_val, b_val;
memcpy(&a_val, a, sizeof(a_val));
memcpy(&b_val, b, sizeof(b_val));
accumulator |= a_val ^ b_val;
a += 4;
b += 4;
}
#endif

while (b < end)
accumulator |= (*a++ ^ *b++);

@@ -272,9 +471,14 @@ int srtp_octet_string_is_eq(uint8_t *a, uint8_t *b, int len)

void srtp_cleanse(void *s, size_t len)
{
#if defined(__GNUC__)
memset(s, 0, len);
__asm__ __volatile__("" : : "r"(s) : "memory");
#else
volatile unsigned char *p = (volatile unsigned char *)s;
while (len--)
*p++ = 0;
#endif
}

void octet_string_set_to_zero(void *s, size_t len)
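
Finally, a hypothetical end-to-end exercise of the two routines changed at the
bottom of datatypes.c (assumptions: same include/link setup as the sketch
above, and srtp_octet_string_is_eq returning zero iff the inputs are equal,
per the accumulator comment in the code):

#include <stdio.h>
#include <string.h>
#include "datatypes.h"

int main(void)
{
    uint8_t key[30], copy[30];

    memset(key, 0xAB, sizeof(key));
    memcpy(copy, key, sizeof(copy));

    /* Equal inputs: the accumulator stays zero on both the SSE2 and scalar paths. */
    printf("equal:   %d\n", srtp_octet_string_is_eq(key, copy, (int)sizeof(key)));

    copy[17] ^= 0x01; /* flip one bit beyond the first 16-byte SIMD block */
    printf("unequal: %d\n", srtp_octet_string_is_eq(key, copy, (int)sizeof(key)));

    /* Wipe the secret; the asm barrier keeps the memset from being elided
     * even though 'key' is never read again. */
    srtp_cleanse(key, sizeof(key));
    return 0;
}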
