Add x86 SIMD optimizations to crypto datatypes #507

Merged: 1 commit, Mar 14, 2023
11 changes: 11 additions & 0 deletions config_in_cmake.h
@@ -122,3 +122,14 @@
#define inline
#endif
#endif

/* Define gcc/clang-style SSE macros on compilers that don't define them (primarily, MSVC). */
#if !defined(__SSE2__) && (defined(_M_X64) || (defined(_M_IX86_FP) && _M_IX86_FP >= 2))
#define __SSE2__
#endif
#if !defined(__SSSE3__) && defined(__AVX__)
#define __SSSE3__
#endif
#if !defined(__SSE4_1__) && defined(__AVX__)
#define __SSE4_1__
#endif
36 changes: 31 additions & 5 deletions crypto/include/datatypes.h
@@ -62,6 +62,10 @@
#error "Platform not recognized"
#endif

#if defined(__SSE2__)
#include <emmintrin.h>
#endif

#ifdef __cplusplus
extern "C" {
#endif
@@ -90,6 +94,26 @@ void v128_left_shift(v128_t *x, int shift_index);
*
*/

#if defined(__SSE2__)

#define v128_set_to_zero(x) \
(_mm_storeu_si128((__m128i *)(x), _mm_setzero_si128()))

#define v128_copy(x, y) \
(_mm_storeu_si128((__m128i *)(x), _mm_loadu_si128((const __m128i *)(y))))

#define v128_xor(z, x, y) \
(_mm_storeu_si128((__m128i *)(z), \
_mm_xor_si128(_mm_loadu_si128((const __m128i *)(x)), \
_mm_loadu_si128((const __m128i *)(y)))))

#define v128_xor_eq(z, x) \
(_mm_storeu_si128((__m128i *)(z), \
_mm_xor_si128(_mm_loadu_si128((const __m128i *)(x)), \
_mm_loadu_si128((const __m128i *)(z)))))

#else /* defined(__SSE2__) */

#define v128_set_to_zero(x) \
((x)->v32[0] = 0, (x)->v32[1] = 0, (x)->v32[2] = 0, (x)->v32[3] = 0)

@@ -113,6 +137,8 @@ void v128_left_shift(v128_t *x, int shift_index);
((z)->v64[0] ^= (x)->v64[0], (z)->v64[1] ^= (x)->v64[1])
#endif

#endif /* defined(__SSE2__) */
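
Note: because the SSE2 variants use _mm_loadu_si128/_mm_storeu_si128 (unaligned load and store), callers need no 16-byte alignment guarantee on v128_t. A small usage sketch, assuming the v128_t layout and macros from this header:

#include "datatypes.h" /* v128_t and the macros above */

int main(void)
{
    v128_t a, b, out;
    v128_set_to_zero(&a);   /* a = 0 */
    v128_copy(&b, &a);      /* b = a */
    v128_xor(&out, &a, &b); /* out = a ^ b (all zero here) */
    v128_xor_eq(&out, &b);  /* out ^= b */
    return 0;
}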

/* NOTE! This assumes an odd ordering! */
/* This will not be compatible directly with math on some processors */
/* bit 0 is first 32-bit word, low order bit. in little-endian, that's
@@ -173,13 +199,11 @@ void octet_string_set_to_zero(void *s, size_t len);
#define be64_to_cpu(x) OSSwapInt64(x)
#else /* WORDS_BIGENDIAN */

-#if defined(__GNUC__) && (defined(HAVE_X86) || defined(__x86_64__))
+#if defined(__GNUC__)
 /* Fall back. */
 static inline uint32_t be32_to_cpu(uint32_t v)
 {
-    /* optimized for x86. */
-    asm("bswap %0" : "=r"(v) : "0"(v));
-    return v;
+    return __builtin_bswap32(v);
 }
 #else /* HAVE_X86 */
#ifdef HAVE_NETINET_IN_H
@@ -192,7 +216,9 @@ static inline uint32_t be32_to_cpu(uint32_t v)

static inline uint64_t be64_to_cpu(uint64_t v)
{
-#ifdef NO_64BIT_MATH
+#if defined(__GNUC__)
+    v = __builtin_bswap64(v);
+#elif defined(NO_64BIT_MATH)
     /* use the make64 functions to do 64-bit math */
     v = make64(htonl(low32(v)), htonl(high32(v)));
 #else /* NO_64BIT_MATH */
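
Note: on GCC and Clang, __builtin_bswap32/__builtin_bswap64 compile to a single bswap (x86) or rev (ARM) instruction, so dropping the hand-written inline asm loses nothing and also covers non-x86 targets. A quick standalone sanity check (not part of the PR):

#include <stdint.h>
#include <assert.h>

int main(void)
{
    /* byte-swapping reverses byte order, so swapping twice is the identity */
    assert(__builtin_bswap32(0x01020304U) == 0x04030201U);
    assert(__builtin_bswap64(0x0102030405060708ULL) == 0x0807060504030201ULL);
    assert(__builtin_bswap32(__builtin_bswap32(0xdeadbeefU)) == 0xdeadbeefU);
    return 0;
}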
219 changes: 216 additions & 3 deletions crypto/math/datatypes.c
@@ -53,6 +53,16 @@

#include "datatypes.h"

#if defined(__SSE2__)
#include <tmmintrin.h>
#endif

#if defined(_MSC_VER)
#define ALIGNMENT(N) __declspec(align(N))
#else
#define ALIGNMENT(N) __attribute__((aligned(N)))
#endif

/*
* bit_string is a buffer that is used to hold output strings, e.g.
* for printing.
@@ -123,6 +133,9 @@ char *v128_bit_string(v128_t *x)

void v128_copy_octet_string(v128_t *x, const uint8_t s[16])
{
#if defined(__SSE2__)
    _mm_storeu_si128((__m128i *)(x), _mm_loadu_si128((const __m128i *)(s)));
#else
#ifdef ALIGNMENT_32BIT_REQUIRED
    if ((((uint32_t)&s[0]) & 0x3) != 0)
#endif
@@ -151,8 +164,67 @@ void v128_copy_octet_string(v128_t *x, const uint8_t s[16])
        v128_copy(x, v);
    }
#endif
#endif /* defined(__SSE2__) */
}

#if defined(__SSSE3__)

/* clang-format off */

ALIGNMENT(16)
static const uint8_t right_shift_masks[5][16] = {
{ 0u, 1u, 2u, 3u, 4u, 5u, 6u, 7u,
8u, 9u, 10u, 11u, 12u, 13u, 14u, 15u },
{ 0x80, 0x80, 0x80, 0x80, 0u, 1u, 2u, 3u,
4u, 5u, 6u, 7u, 8u, 9u, 10u, 11u },
{ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80,
0u, 1u, 2u, 3u, 4u, 5u, 6u, 7u },
{ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80,
0x80, 0x80, 0x80, 0x80, 0u, 1u, 2u, 3u },
/* needed for bitvector_left_shift */
{ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80,
0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80 }
};

ALIGNMENT(16)
static const uint8_t left_shift_masks[4][16] = {
{ 0u, 1u, 2u, 3u, 4u, 5u, 6u, 7u,
8u, 9u, 10u, 11u, 12u, 13u, 14u, 15u },
{ 4u, 5u, 6u, 7u, 8u, 9u, 10u, 11u,
12u, 13u, 14u, 15u, 0x80, 0x80, 0x80, 0x80 },
{ 8u, 9u, 10u, 11u, 12u, 13u, 14u, 15u,
0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80 },
{ 12u, 13u, 14u, 15u, 0x80, 0x80, 0x80, 0x80,
0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80 }
};

/* clang-format on */
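
Note: _mm_shuffle_epi8 selects source bytes by index, and any mask byte with its high bit set (0x80) yields zero, so these tables implement whole-word shifts with zero fill. A minimal sketch of that behavior, assuming an SSSE3-capable build (illustrative only, not part of the PR):

#include <tmmintrin.h>
#include <stdint.h>
#include <stdio.h>

int main(void)
{
    const uint8_t bytes[16] = { 0, 1, 2,  3,  4,  5,  6,  7,
                                8, 9, 10, 11, 12, 13, 14, 15 };
    /* same values as left_shift_masks[1]: move bytes 4..15 down,
       zero the top four bytes */
    const uint8_t mask[16] = { 4,  5,  6,  7,  8,    9,    10,   11,
                               12, 13, 14, 15, 0x80, 0x80, 0x80, 0x80 };
    __m128i v = _mm_loadu_si128((const __m128i *)bytes);
    __m128i m = _mm_loadu_si128((const __m128i *)mask);
    uint8_t out[16];
    _mm_storeu_si128((__m128i *)out, _mm_shuffle_epi8(v, m));
    for (int i = 0; i < 16; ++i)
        printf("%d ", out[i]); /* prints: 4 5 6 7 8 9 10 11 12 13 14 15 0 0 0 0 */
    printf("\n");
    return 0;
}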

void v128_left_shift(v128_t *x, int shift)
{
    if (shift > 127) {
        v128_set_to_zero(x);
        return;
    }

    /* split the shift into a whole-word part and a within-word bit part */
    const int base_index = shift >> 5;
    const int bit_index = shift & 31;

    __m128i mm = _mm_loadu_si128((const __m128i *)x);
    __m128i mm_shift_right = _mm_cvtsi32_si128(bit_index);
    __m128i mm_shift_left = _mm_cvtsi32_si128(32 - bit_index);
    /* move whole 32-bit words into place, zero-filling the vacated words */
    mm = _mm_shuffle_epi8(mm, ((const __m128i *)left_shift_masks)[base_index]);

    /* word i becomes (word[i] >> bit_index) | (word[i+1] << (32 - bit_index)),
       matching the scalar implementation's word-level formula */
    __m128i mm1 = _mm_srl_epi32(mm, mm_shift_right);
    __m128i mm2 = _mm_sll_epi32(mm, mm_shift_left);
    mm2 = _mm_srli_si128(mm2, 4);
    mm1 = _mm_or_si128(mm1, mm2);

    _mm_storeu_si128((__m128i *)x, mm1);
}
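
Note: a concrete spot check of the word ordering, assuming datatypes.h and its convention that word i receives (word[i+base] >> bit) | (word[i+base+1] << (32 - bit)) (hypothetical test, not from the PR):

#include <stdio.h>
#include "datatypes.h" /* v128_t, v128_left_shift */

int main(void)
{
    v128_t x;
    x.v32[0] = 0x11111111; x.v32[1] = 0x22222222;
    x.v32[2] = 0x33333333; x.v32[3] = 0x44444444;
    v128_left_shift(&x, 36); /* base_index = 1, bit_index = 4 */
    /* expect v32[0] = (0x22222222 >> 4) | (0x33333333 << 28) = 0x32222222,
       and v32[3] = 0 */
    for (int i = 0; i < 4; ++i)
        printf("v32[%d] = 0x%08x\n", i, x.v32[i]);
    return 0;
}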

#else /* defined(__SSSE3__) */

void v128_left_shift(v128_t *x, int shift)
{
    int i;
@@ -179,6 +251,8 @@ void v128_left_shift(v128_t *x, int shift)
        x->v32[i] = 0;
}

#endif /* defined(__SSSE3__) */

/* functions manipulating bitvector_t */

int bitvector_alloc(bitvector_t *v, unsigned long length)
@@ -190,6 +264,7 @@ int bitvector_alloc(bitvector_t *v, unsigned long length)
        (length + bits_per_word - 1) & ~(unsigned long)((bits_per_word - 1));

    l = length / bits_per_word * bytes_per_word;
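    /* round the byte count up to a multiple of 16 so the SSE2/SSSE3 paths
       can load and store whole 128-bit blocks, including the final partial one */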
    l = (l + 15ul) & ~15ul;

    /* allocate memory, then set parameters */
    if (l == 0) {
@@ -225,6 +300,73 @@ void bitvector_set_to_zero(bitvector_t *x)
memset(x->word, 0, x->length >> 3);
}

#if defined(__SSSE3__)

void bitvector_left_shift(bitvector_t *x, int shift)
{
    if ((uint32_t)shift >= x->length) {
        bitvector_set_to_zero(x);
        return;
    }

    /* split the shift into a whole-word part and a within-word bit part */
    const int base_index = shift >> 5;
    const int bit_index = shift & 31;
    /* number of 128-bit blocks covering the vector (length is in bits) */
    const int vec_length = (x->length + 127u) >> 7;
    const __m128i *from = ((const __m128i *)x->word) + (base_index >> 2);
    __m128i *to = (__m128i *)x->word;
    __m128i *const end = to + vec_length;

    /* masks that realign whole words across the 128-bit block boundary */
    __m128i mm_right_shift_mask =
        ((const __m128i *)right_shift_masks)[4u - (base_index & 3u)];
    __m128i mm_left_shift_mask =
        ((const __m128i *)left_shift_masks)[base_index & 3u];
    __m128i mm_shift_right = _mm_cvtsi32_si128(bit_index);
    __m128i mm_shift_left = _mm_cvtsi32_si128(32 - bit_index);

    __m128i mm_current = _mm_loadu_si128(from);
    __m128i mm_current_r = _mm_srl_epi32(mm_current, mm_shift_right);
    __m128i mm_current_l = _mm_sll_epi32(mm_current, mm_shift_left);

    while ((end - from) >= 2) {
        ++from;
        __m128i mm_next = _mm_loadu_si128(from);

        __m128i mm_next_r = _mm_srl_epi32(mm_next, mm_shift_right);
        __m128i mm_next_l = _mm_sll_epi32(mm_next, mm_shift_left);
        /* pull in the carry bits from the following 128-bit block */
        mm_current_l = _mm_alignr_epi8(mm_next_l, mm_current_l, 4);
        mm_current = _mm_or_si128(mm_current_r, mm_current_l);

        mm_current = _mm_shuffle_epi8(mm_current, mm_left_shift_mask);

        __m128i mm_temp_next = _mm_srli_si128(mm_next_l, 4);
        mm_temp_next = _mm_or_si128(mm_next_r, mm_temp_next);

        mm_temp_next = _mm_shuffle_epi8(mm_temp_next, mm_right_shift_mask);
        mm_current = _mm_or_si128(mm_temp_next, mm_current);

        _mm_storeu_si128(to, mm_current);
        ++to;

        mm_current_r = mm_next_r;
        mm_current_l = mm_next_l;
    }

    /* last block: there is no following block, so the carry-in is zero */
    mm_current_l = _mm_srli_si128(mm_current_l, 4);
    mm_current = _mm_or_si128(mm_current_r, mm_current_l);

    mm_current = _mm_shuffle_epi8(mm_current, mm_left_shift_mask);

    _mm_storeu_si128(to, mm_current);
    ++to;

    /* zero-fill the blocks vacated by the shift */
    while (to < end) {
        _mm_storeu_si128(to, _mm_setzero_si128());
        ++to;
    }
}
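
Note: a quick equivalence check of the word-level semantics, assuming bitvector_alloc returns 0 on success and bitvector_dealloc frees the vector, as declared in datatypes.h (hypothetical test, not part of the PR):

#include <stdio.h>
#include "datatypes.h" /* bitvector_t, bitvector_alloc, bitvector_left_shift */

int main(void)
{
    bitvector_t v;
    if (bitvector_alloc(&v, 256) != 0)
        return 1;
    bitvector_set_to_zero(&v);
    v.word[7] = 0x80000000u;      /* set bit 255, the highest bit */
    bitvector_left_shift(&v, 65); /* two whole words plus one bit */
    /* expect word[5] = word[7] >> 1 = 0x40000000 */
    printf("word[5] = 0x%08x\n", v.word[5]);
    bitvector_dealloc(&v);
    return 0;
}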

#else /* defined(__SSSE3__) */

void bitvector_left_shift(bitvector_t *x, int shift)
{
    int i;
@@ -253,16 +395,82 @@ void bitvector_left_shift(bitvector_t *x, int shift)
        x->word[i] = 0;
}

#endif /* defined(__SSSE3__) */

int srtp_octet_string_is_eq(uint8_t *a, uint8_t *b, int len)
{
-    uint8_t *end = b + len;
-    uint8_t accumulator = 0;
-
     /*
      * We use this somewhat obscure implementation to try to ensure the running
      * time only depends on len, even accounting for compiler optimizations.
      * The accumulator ends up zero iff the strings are equal.
      */
+    uint8_t *end = b + len;
+    uint32_t accumulator = 0;

#if defined(__SSE2__)
    /* process 32 bytes per iteration, using two independent accumulators
       to keep the XOR/OR dependency chains short */
    __m128i mm_accumulator1 = _mm_setzero_si128();
    __m128i mm_accumulator2 = _mm_setzero_si128();
    for (int i = 0, n = len >> 5; i < n; ++i, a += 32, b += 32) {
        __m128i mm_a1 = _mm_loadu_si128((const __m128i *)a);
        __m128i mm_b1 = _mm_loadu_si128((const __m128i *)b);
        __m128i mm_a2 = _mm_loadu_si128((const __m128i *)(a + 16));
        __m128i mm_b2 = _mm_loadu_si128((const __m128i *)(b + 16));
        mm_a1 = _mm_xor_si128(mm_a1, mm_b1);
        mm_a2 = _mm_xor_si128(mm_a2, mm_b2);
        mm_accumulator1 = _mm_or_si128(mm_accumulator1, mm_a1);
        mm_accumulator2 = _mm_or_si128(mm_accumulator2, mm_a2);
    }

    mm_accumulator1 = _mm_or_si128(mm_accumulator1, mm_accumulator2);

    /* 16-byte and 8-byte tails */
    if ((end - b) >= 16) {
        __m128i mm_a1 = _mm_loadu_si128((const __m128i *)a);
        __m128i mm_b1 = _mm_loadu_si128((const __m128i *)b);
        mm_a1 = _mm_xor_si128(mm_a1, mm_b1);
        mm_accumulator1 = _mm_or_si128(mm_accumulator1, mm_a1);
        a += 16;
        b += 16;
    }

    if ((end - b) >= 8) {
        __m128i mm_a1 = _mm_loadl_epi64((const __m128i *)a);
        __m128i mm_b1 = _mm_loadl_epi64((const __m128i *)b);
        mm_a1 = _mm_xor_si128(mm_a1, mm_b1);
        mm_accumulator1 = _mm_or_si128(mm_accumulator1, mm_a1);
        a += 8;
        b += 8;
    }

    /* fold the 128-bit accumulator down to 32 bits */
    mm_accumulator1 = _mm_or_si128(
        mm_accumulator1, _mm_unpackhi_epi64(mm_accumulator1, mm_accumulator1));
    mm_accumulator1 =
        _mm_or_si128(mm_accumulator1, _mm_srli_si128(mm_accumulator1, 4));
    accumulator = _mm_cvtsi128_si32(mm_accumulator1);
#else
    /* scalar analogue: 8 bytes per iteration through two 32-bit accumulators,
       with memcpy used for safe unaligned loads */
    uint32_t accumulator2 = 0;
    for (int i = 0, n = len >> 3; i < n; ++i, a += 8, b += 8) {
        uint32_t a_val1, b_val1;
        uint32_t a_val2, b_val2;
        memcpy(&a_val1, a, sizeof(a_val1));
        memcpy(&b_val1, b, sizeof(b_val1));
        memcpy(&a_val2, a + 4, sizeof(a_val2));
        memcpy(&b_val2, b + 4, sizeof(b_val2));
        accumulator |= a_val1 ^ b_val1;
        accumulator2 |= a_val2 ^ b_val2;
    }

    accumulator |= accumulator2;

    if ((end - b) >= 4) {
        uint32_t a_val, b_val;
        memcpy(&a_val, a, sizeof(a_val));
        memcpy(&b_val, b, sizeof(b_val));
        accumulator |= a_val ^ b_val;
        a += 4;
        b += 4;
    }
#endif

    /* byte-by-byte tail */
    while (b < end)
        accumulator |= (*a++ ^ *b++);

@@ -272,9 +480,14 @@ int srtp_octet_string_is_eq(uint8_t *a, uint8_t *b, int len)
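
Note: the point of the accumulator pattern is that no branch or early exit depends on the data being compared; only the final reduction (elided above) inspects the accumulator. A standalone scalar sketch of the same idea (hypothetical helper, not part of the PR):

#include <stdint.h>
#include <stddef.h>

/* returns 0 when the buffers are equal, nonzero otherwise; the running
   time depends only on len, not on where the first mismatch occurs */
static uint32_t ct_diff(const uint8_t *a, const uint8_t *b, size_t len)
{
    uint32_t acc = 0;
    for (size_t i = 0; i < len; i++)
        acc |= (uint32_t)(a[i] ^ b[i]);
    return acc;
}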

void srtp_cleanse(void *s, size_t len)
{
#if defined(__GNUC__)
    memset(s, 0, len);
    __asm__ __volatile__("" : : "r"(s) : "memory");
(Review comment from a project member on this line: "This is probably a little unrelated to the rest ... unless it was a performance problem.")
#else
    volatile unsigned char *p = (volatile unsigned char *)s;
    while (len--)
        *p++ = 0;
#endif
}
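
Note: the empty __asm__ __volatile__ statement with a "memory" clobber acts as a compiler barrier: it tells GCC/Clang that s may still be observed, so the preceding memset cannot be eliminated as a dead store. The idiom in isolation (a sketch, assuming GCC or Clang):

#include <string.h>

static void secure_zero(void *s, size_t len)
{
    memset(s, 0, len);
    /* barrier: forces the compiler to assume s is read afterwards,
       keeping the memset even under aggressive optimization */
    __asm__ __volatile__("" : : "r"(s) : "memory");
}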

void octet_string_set_to_zero(void *s, size_t len)