Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Load/Store for SIMD type wrappers #4288

Merged
merged 1 commit into from
Aug 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 75 additions & 6 deletions src/lib/utils/loadstor.h
Original file line number Diff line number Diff line change
Expand Up @@ -167,12 +167,34 @@ constexpr bool native_endianness_is_unknown() {
#endif
}

/**
* Models a custom type that provides factory methods to be loaded in big- or
* little-endian byte order.
*
* A type T satisfies this concept if both `T::load_be(data)` and
* `T::load_le(data)` are valid expressions for a fixed-extent span of exactly
* `sizeof(T)` constant bytes, and each of them yields exactly a `T`.
*/
template <typename T>
concept custom_loadable = requires(std::span<const uint8_t, sizeof(T)> data) {
// Both byte orders must be supported; the result type must be T itself.
{ T::load_be(data) } -> std::same_as<T>;
{ T::load_le(data) } -> std::same_as<T>;
};

/**
* Models a custom type that provides store methods to be stored in big- or
* little-endian byte order.
*
* A type T satisfies this concept if both `value.store_be(data)` and
* `value.store_le(data)` are valid expressions for a const-qualified `value`
* of type T and a fixed-extent span of exactly `sizeof(T)` mutable bytes.
* The result type of the store methods is not constrained.
*/
template <typename T>
concept custom_storable = requires(std::span<uint8_t, sizeof(T)> data, const T value) {
// `value` is const, so the store methods must be callable on const objects.
{ value.store_be(data) };
{ value.store_le(data) };
};

/**
* Models a type that can be loaded/stored from/to a byte range.
*/
template <typename T>
concept unsigned_integralish = std::unsigned_integral<strong_type_wrapped_type<T>> ||
(std::is_enum_v<T> && std::unsigned_integral<std::underlying_type_t<T>>);
concept unsigned_integralish =
std::unsigned_integral<strong_type_wrapped_type<T>> ||
(std::is_enum_v<T> && std::unsigned_integral<std::underlying_type_t<T>>) ||
(custom_loadable<strong_type_wrapped_type<T>> || custom_storable<strong_type_wrapped_type<T>>);

template <typename T>
struct wrapped_type_helper_with_enum {
Expand Down Expand Up @@ -276,6 +298,7 @@ inline constexpr void fallback_store_any(InT in, OutR&& out_range) {
* @return T loaded from @p in_range, interpreted in the requested byte order
*/
template <Endianness endianness, unsigned_integralish WrappedOutT, ranges::contiguous_range<uint8_t> InR>
requires(!custom_loadable<strong_type_wrapped_type<WrappedOutT>>)
inline constexpr WrappedOutT load_any(InR&& in_range) {
using OutT = detail::wrapped_type<WrappedOutT>;
ranges::assert_exact_byte_length<sizeof(OutT)>(in_range);
Expand All @@ -302,6 +325,28 @@ inline constexpr WrappedOutT load_any(InR&& in_range) {
}());
}

/**
* Load a custom object from a range in either big or little endian byte order
*
* This is the base implementation for types satisfying custom_loadable (e.g.
* SIMD type wrappers); it delegates to the type's own load_be()/load_le()
* factory. All other overloads are just convenience overloads.
*
* @param in_range a fixed-length byte range
* @return T loaded from @p in_range, interpreted in the requested byte order
*/
template <Endianness endianness, unsigned_integralish WrappedOutT, ranges::contiguous_range<uint8_t> InR>
   requires(custom_loadable<strong_type_wrapped_type<WrappedOutT>>)
inline constexpr WrappedOutT load_any(InR&& in_range) {
   using LoadedT = detail::wrapped_type<WrappedOutT>;
   constexpr auto byte_count = sizeof(LoadedT);

   // The input must provide exactly as many bytes as the object occupies.
   ranges::assert_exact_byte_length<byte_count>(in_range);
   std::span<const uint8_t, byte_count> bytes{in_range};

   // Anything that is not explicitly big-endian is loaded as little-endian.
   if constexpr(endianness != Endianness::Big) {
      return wrap_strong_type<WrappedOutT>(LoadedT::load_le(bytes));
   } else {
      return wrap_strong_type<WrappedOutT>(LoadedT::load_be(bytes));
   }
}

/**
* Load many unsigned integers
* @param in a fixed-length span to some bytes
Expand Down Expand Up @@ -335,9 +380,9 @@ template <Endianness endianness,
(std::same_as<AutoDetect, OutT> || std::same_as<OutT, std::ranges::range_value_t<OutR>>))
inline constexpr void load_any(OutR&& out, InR&& in) {
ranges::assert_equal_byte_lengths(out, in);
using element_type = std::ranges::range_value_t<OutR>;

auto load_elementwise = [&] {
using element_type = std::ranges::range_value_t<OutR>;
constexpr size_t bytes_per_element = sizeof(element_type);
std::span<const uint8_t> in_s(in);
for(auto& out_elem : out) {
Expand All @@ -352,7 +397,7 @@ inline constexpr void load_any(OutR&& out, InR&& in) {
if(std::is_constant_evaluated()) /* TODO: C++23: if consteval {} */ {
load_elementwise();
} else {
if constexpr(is_native(endianness)) {
if constexpr(is_native(endianness) && !custom_loadable<element_type>) {
typecast_copy(out, in);
} else {
load_elementwise();
Expand Down Expand Up @@ -502,6 +547,7 @@ namespace detail {
* @param out_range a byte range to store the word into
*/
template <Endianness endianness, unsigned_integralish WrappedInT, ranges::contiguous_output_range<uint8_t> OutR>
requires(!custom_storable<strong_type_wrapped_type<WrappedInT>>)
inline constexpr void store_any(WrappedInT wrapped_in, OutR&& out_range) {
const auto in = detail::unwrap_strong_type_or_enum(wrapped_in);
using InT = decltype(in);
Expand All @@ -527,6 +573,29 @@ inline constexpr void store_any(WrappedInT wrapped_in, OutR&& out_range) {
}
}

/**
* Store a custom object in either big or little endian byte order into a range
*
* This is the base implementation for types satisfying custom_storable; it
* delegates to the object's own store_be()/store_le() method. All other
* overloads are just convenience overloads.
*
* @param wrapped_in a custom object to be stored
* @param out_range a byte range to store the object into
*/
template <Endianness endianness, unsigned_integralish WrappedInT, ranges::contiguous_output_range<uint8_t> OutR>
   requires(custom_storable<strong_type_wrapped_type<WrappedInT>>)
inline constexpr void store_any(WrappedInT wrapped_in, OutR&& out_range) {
   const auto value = detail::unwrap_strong_type_or_enum(wrapped_in);
   constexpr auto byte_count = sizeof(value);

   // The output must provide exactly as many bytes as the object occupies.
   ranges::assert_exact_byte_length<byte_count>(out_range);
   std::span<uint8_t, byte_count> out_bytes{out_range};

   // Anything that is not explicitly big-endian is stored as little-endian.
   if constexpr(endianness != Endianness::Big) {
      value.store_le(out_bytes);
   } else {
      value.store_be(out_bytes);
   }
}

/**
* Store many unsigned integers words into a byte range
* @param out a sized range of some bytes
Expand Down Expand Up @@ -561,9 +630,9 @@ template <Endianness endianness,
requires(std::same_as<AutoDetect, InT> || std::same_as<InT, std::ranges::range_value_t<InR>>)
inline constexpr void store_any(OutR&& out, InR&& in) {
ranges::assert_equal_byte_lengths(out, in);
using element_type = std::ranges::range_value_t<InR>;

auto store_elementwise = [&] {
using element_type = std::ranges::range_value_t<InR>;
constexpr size_t bytes_per_element = sizeof(element_type);
std::span<uint8_t> out_s(out);
for(auto in_elem : in) {
Expand All @@ -578,7 +647,7 @@ inline constexpr void store_any(OutR&& out, InR&& in) {
if(std::is_constant_evaluated()) /* TODO: C++23: if consteval {} */ {
store_elementwise();
} else {
if constexpr(is_native(endianness)) {
if constexpr(is_native(endianness) && !custom_storable<element_type>) {
typecast_copy(out, in);
} else {
store_elementwise();
Expand Down
10 changes: 10 additions & 0 deletions src/lib/utils/simd/simd_32.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

#include <botan/types.h>

#include <span>

#if defined(BOTAN_TARGET_SUPPORTS_SSE2)
#include <emmintrin.h>
#define BOTAN_SIMD_USE_SSE2
Expand Down Expand Up @@ -186,6 +188,10 @@ class SIMD_4x32 final {
#endif
}

/**
* Little-endian load from a fixed-extent span of 16 bytes; forwards to the
* load_le() overload that takes a raw pointer.
*/
static SIMD_4x32 load_le(std::span<const uint8_t, 16> in) {
   const auto* bytes = in.data();
   return SIMD_4x32::load_le(bytes);
}

/**
* Big-endian load from a fixed-extent span of 16 bytes; forwards to the
* load_be() overload that takes a raw pointer.
*/
static SIMD_4x32 load_be(std::span<const uint8_t, 16> in) {
   const auto* bytes = in.data();
   return SIMD_4x32::load_be(bytes);
}

void store_le(uint32_t out[4]) const noexcept { this->store_le(reinterpret_cast<uint8_t*>(out)); }

void store_be(uint32_t out[4]) const noexcept { this->store_be(reinterpret_cast<uint8_t*>(out)); }
Expand Down Expand Up @@ -246,6 +252,10 @@ class SIMD_4x32 final {
#endif
}

/**
* Big-endian store into a fixed-extent span of 16 bytes; forwards to the
* store_be() overload that takes a raw byte pointer.
*/
void store_be(std::span<uint8_t, 16> out) const {
   uint8_t* const dest = out.data();
   this->store_be(dest);
}

/**
* Little-endian store into a fixed-extent span of 16 bytes; forwards to the
* store_le() overload that takes a raw byte pointer.
*/
void store_le(std::span<uint8_t, 16> out) const {
   uint8_t* const dest = out.data();
   this->store_le(dest);
}

/*
* This is used for SHA-2/SHACAL2
*/
Expand Down
34 changes: 34 additions & 0 deletions src/tests/test_simd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include <botan/internal/loadstor.h>
#include <botan/internal/rotate.h>
#include <botan/internal/simd_32.h>
#include <botan/internal/stl_util.h>
#endif

namespace Botan_Tests {
Expand Down Expand Up @@ -134,6 +135,39 @@ class SIMD_32_Tests final : public Test {
test_eq(result, "shift right 2", input.shift_elems_right<2>(), pat3, pat4, 0, 0);
test_eq(result, "shift right 3", input.shift_elems_right<3>(), pat4, 0, 0, 0);

// Test load/store round trips of SIMD wrapper types
const auto simd_le_in = Botan::hex_decode("ABCDEF01234567890123456789ABCDEF");
const auto simd_be_in = Botan::hex_decode("0123456789ABCDEFABCDEF0123456789");
const auto simd_le_array_in = Botan::concat(simd_le_in, simd_be_in);
const auto simd_be_array_in = Botan::concat(simd_be_in, simd_le_in);

auto simd_le = Botan::load_le<Botan::SIMD_4x32>(simd_le_in);
auto simd_be = Botan::load_be<Botan::SIMD_4x32>(simd_be_in);
auto simd_le_array = Botan::load_le<std::array<Botan::SIMD_4x32, 2>>(simd_le_array_in);
auto simd_be_array = Botan::load_be<std::array<Botan::SIMD_4x32, 2>>(simd_be_array_in);

auto simd_le_vec = Botan::store_le<std::vector<uint8_t>>(simd_le);
auto simd_be_vec = Botan::store_be(simd_be);
auto simd_le_array_vec = Botan::store_le<std::vector<uint8_t>>(simd_le_array);
auto simd_be_array_vec = Botan::store_be(simd_be_array);

result.test_is_eq("roundtrip SIMD little-endian", simd_le_vec, simd_le_in);
result.test_is_eq(
"roundtrip SIMD big-endian", std::vector(simd_be_vec.begin(), simd_be_vec.end()), simd_be_in);
result.test_is_eq("roundtrip SIMD array little-endian", simd_le_array_vec, simd_le_array_in);
result.test_is_eq("roundtrip SIMD array big-endian",
std::vector(simd_be_array_vec.begin(), simd_be_array_vec.end()),
simd_be_array_in);

using StrongSIMD = Botan::Strong<Botan::SIMD_4x32, struct StrongSIMD_>;
const auto simd_le_strong = Botan::load_le<StrongSIMD>(simd_le_in);
const auto simd_be_strong = Botan::load_be<StrongSIMD>(simd_be_in);

result.test_is_eq(
"roundtrip SIMD strong little-endian", Botan::store_le<std::vector<uint8_t>>(simd_le_strong), simd_le_in);
result.test_is_eq(
"roundtrip SIMD strong big-endian", Botan::store_be<std::vector<uint8_t>>(simd_be_strong), simd_be_in);

return {result};
}

Expand Down
Loading