Skip to content

Commit

Permalink
[XLA] Start adding S1 and U1 as data types
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 706706372
  • Loading branch information
amitsabne1 authored and The ml_dtypes Authors committed Dec 16, 2024
1 parent 401ed6a commit 010c646
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 12 deletions.
8 changes: 8 additions & 0 deletions ml_dtypes/include/intn.h
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,9 @@ struct intN {
}
};

using int1 = intN<1, int8_t>;
using int2 = intN<2, int8_t>;
using uint1 = intN<1, uint8_t>;
using uint2 = intN<2, uint8_t>;
using int4 = intN<4, int8_t>;
using uint4 = intN<4, uint8_t>;
Expand Down Expand Up @@ -295,6 +297,12 @@ struct intN_numeric_limits_base {

namespace std {

template <>
struct numeric_limits<ml_dtypes::int1>
: public ml_dtypes::internal::intN_numeric_limits_base<ml_dtypes::int1> {};
template <>
struct numeric_limits<ml_dtypes::uint1>
: public ml_dtypes::internal::intN_numeric_limits_base<ml_dtypes::uint1> {};
template <>
struct numeric_limits<ml_dtypes::int2>
: public ml_dtypes::internal::intN_numeric_limits_base<ml_dtypes::int2> {};
Expand Down
39 changes: 27 additions & 12 deletions ml_dtypes/tests/intn_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ struct IntNTestParamNames {
}
};

using IntNTypes = ::testing::Types<int2, uint2, int4, uint4>;
using IntNTypes = ::testing::Types<int1, uint1, int2, uint2, int4, uint4>;
TYPED_TEST_SUITE(IntNTest, IntNTypes, IntNTestParamNames);

TEST(IntNTest, NumericLimits) {
Expand All @@ -69,6 +69,13 @@ TEST(IntNTest, NumericLimits) {
EXPECT_EQ(static_cast<int>(std::numeric_limits<int4>::lowest()), -8);
EXPECT_EQ(std::numeric_limits<int4>::digits, 3);
EXPECT_EQ(std::numeric_limits<int4>::digits10, 0);
EXPECT_EQ(std::numeric_limits<int1>::is_signed, true);
EXPECT_EQ(std::numeric_limits<int1>::is_modulo, false);
EXPECT_EQ(static_cast<int>(std::numeric_limits<int1>::min()), -1);
EXPECT_EQ(static_cast<int>(std::numeric_limits<int1>::max()), 0);
EXPECT_EQ(static_cast<int>(std::numeric_limits<int1>::lowest()), -1);
EXPECT_EQ(std::numeric_limits<int1>::digits, 0);
EXPECT_EQ(std::numeric_limits<int1>::digits10, 0);
}

TEST(UIntNTest, NumericLimits) {
Expand All @@ -79,6 +86,13 @@ TEST(UIntNTest, NumericLimits) {
EXPECT_EQ(static_cast<int>(std::numeric_limits<uint4>::lowest()), 0);
EXPECT_EQ(std::numeric_limits<uint4>::digits, 4);
EXPECT_EQ(std::numeric_limits<uint4>::digits10, 1);
EXPECT_EQ(std::numeric_limits<uint1>::is_signed, false);
EXPECT_EQ(std::numeric_limits<uint1>::is_modulo, true);
EXPECT_EQ(static_cast<int>(std::numeric_limits<uint1>::min()), 0);
EXPECT_EQ(static_cast<int>(std::numeric_limits<uint1>::max()), 1);
EXPECT_EQ(static_cast<int>(std::numeric_limits<uint4>::lowest()), 0);
EXPECT_EQ(std::numeric_limits<uint1>::digits, 1);
EXPECT_EQ(std::numeric_limits<uint1>::digits10, 0);
}

TYPED_TEST(IntNTest, NumericLimitsBase) {
Expand Down Expand Up @@ -141,7 +155,7 @@ struct ConstexprEvaluator {
};

// To avoid warnings about unused left-side of comma expressions,
// we additionally pass the expression through a contexpr function.
// we additionally pass the expression through a constexpr function.
template <typename T>
constexpr void ConstexprEvaluatorFunc(T&&) {}

Expand Down Expand Up @@ -211,8 +225,8 @@ TYPED_TEST(IntNTest, Casts) {
}

// Implicit conversion to optional.
std::optional<int64_t> c = IntN(1);
EXPECT_EQ(c, 1);
std::optional<int64_t> c = IntN(0);
EXPECT_EQ(c, 0);

// Loop through all valid values.
for (int i = static_cast<int>(std::numeric_limits<IntN>::min());
Expand Down Expand Up @@ -329,17 +343,18 @@ struct CustomInt {
int x;
};

#define GEN_DEST_TYPES(Type) \
std::pair<Type, bool>, std::pair<Type, uint2>, std::pair<Type, uint4>, \
std::pair<Type, uint8_t>, std::pair<Type, uint16_t>, \
std::pair<Type, uint32_t>, std::pair<Type, uint64_t>, \
std::pair<Type, int2>, std::pair<Type, int4>, std::pair<Type, int8_t>, \
std::pair<Type, int16_t>, std::pair<Type, int32_t>, \
#define GEN_DEST_TYPES(Type) \
std::pair<Type, bool>, std::pair<Type, uint1>, std::pair<Type, uint2>, \
std::pair<Type, uint4>, std::pair<Type, uint8_t>, \
std::pair<Type, uint16_t>, std::pair<Type, uint32_t>, \
std::pair<Type, uint64_t>, std::pair<Type, int1>, std::pair<Type, int2>, \
std::pair<Type, int4>, std::pair<Type, int8_t>, \
std::pair<Type, int16_t>, std::pair<Type, int32_t>, \
std::pair<Type, int64_t>, std::pair<Type, CustomInt>

#define GEN_TYPE_PAIRS() \
GEN_DEST_TYPES(int2), GEN_DEST_TYPES(uint2), GEN_DEST_TYPES(int4), \
GEN_DEST_TYPES(uint4)
GEN_DEST_TYPES(int1), GEN_DEST_TYPES(uint1), GEN_DEST_TYPES(int2), \
GEN_DEST_TYPES(uint2), GEN_DEST_TYPES(int4), GEN_DEST_TYPES(uint4)

using IntNCastTypePairs = ::testing::Types<GEN_TYPE_PAIRS()>;
template <typename CastPair>
Expand Down

0 comments on commit 010c646

Please sign in to comment.