Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[XLA] Start adding S1 and U1 as data types #244

Merged
merged 1 commit into from
Dec 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions ml_dtypes/include/intn.h
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,9 @@ struct intN {
}
};

using int1 = intN<1, int8_t>;
using int2 = intN<2, int8_t>;
using uint1 = intN<1, uint8_t>;
using uint2 = intN<2, uint8_t>;
using int4 = intN<4, int8_t>;
using uint4 = intN<4, uint8_t>;
Expand Down Expand Up @@ -295,6 +297,12 @@ struct intN_numeric_limits_base {

namespace std {

template <>
struct numeric_limits<ml_dtypes::int1>
: public ml_dtypes::internal::intN_numeric_limits_base<ml_dtypes::int1> {};
template <>
struct numeric_limits<ml_dtypes::uint1>
: public ml_dtypes::internal::intN_numeric_limits_base<ml_dtypes::uint1> {};
template <>
struct numeric_limits<ml_dtypes::int2>
: public ml_dtypes::internal::intN_numeric_limits_base<ml_dtypes::int2> {};
Expand Down
39 changes: 27 additions & 12 deletions ml_dtypes/tests/intn_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ struct IntNTestParamNames {
}
};

using IntNTypes = ::testing::Types<int2, uint2, int4, uint4>;
using IntNTypes = ::testing::Types<int1, uint1, int2, uint2, int4, uint4>;
TYPED_TEST_SUITE(IntNTest, IntNTypes, IntNTestParamNames);

TEST(IntNTest, NumericLimits) {
Expand All @@ -69,6 +69,13 @@ TEST(IntNTest, NumericLimits) {
EXPECT_EQ(static_cast<int>(std::numeric_limits<int4>::lowest()), -8);
EXPECT_EQ(std::numeric_limits<int4>::digits, 3);
EXPECT_EQ(std::numeric_limits<int4>::digits10, 0);
EXPECT_EQ(std::numeric_limits<int1>::is_signed, true);
EXPECT_EQ(std::numeric_limits<int1>::is_modulo, false);
EXPECT_EQ(static_cast<int>(std::numeric_limits<int1>::min()), -1);
EXPECT_EQ(static_cast<int>(std::numeric_limits<int1>::max()), 0);
EXPECT_EQ(static_cast<int>(std::numeric_limits<int1>::lowest()), -1);
EXPECT_EQ(std::numeric_limits<int1>::digits, 0);
EXPECT_EQ(std::numeric_limits<int1>::digits10, 0);
}

TEST(UIntNTest, NumericLimits) {
Expand All @@ -79,6 +86,13 @@ TEST(UIntNTest, NumericLimits) {
EXPECT_EQ(static_cast<int>(std::numeric_limits<uint4>::lowest()), 0);
EXPECT_EQ(std::numeric_limits<uint4>::digits, 4);
EXPECT_EQ(std::numeric_limits<uint4>::digits10, 1);
EXPECT_EQ(std::numeric_limits<uint1>::is_signed, false);
EXPECT_EQ(std::numeric_limits<uint1>::is_modulo, true);
EXPECT_EQ(static_cast<int>(std::numeric_limits<uint1>::min()), 0);
EXPECT_EQ(static_cast<int>(std::numeric_limits<uint1>::max()), 1);
EXPECT_EQ(static_cast<int>(std::numeric_limits<uint4>::lowest()), 0);
EXPECT_EQ(std::numeric_limits<uint1>::digits, 1);
EXPECT_EQ(std::numeric_limits<uint1>::digits10, 0);
}

TYPED_TEST(IntNTest, NumericLimitsBase) {
Expand Down Expand Up @@ -141,7 +155,7 @@ struct ConstexprEvaluator {
};

// To avoid warnings about unused left-side of comma expressions,
// we additionally pass the expression through a contexpr function.
// we additionally pass the expression through a constexpr function.
template <typename T>
constexpr void ConstexprEvaluatorFunc(T&&) {}

Expand Down Expand Up @@ -211,8 +225,8 @@ TYPED_TEST(IntNTest, Casts) {
}

// Implicit conversion to optional.
std::optional<int64_t> c = IntN(1);
EXPECT_EQ(c, 1);
std::optional<int64_t> c = IntN(0);
EXPECT_EQ(c, 0);

// Loop through all valid values.
for (int i = static_cast<int>(std::numeric_limits<IntN>::min());
Expand Down Expand Up @@ -329,17 +343,18 @@ struct CustomInt {
int x;
};

#define GEN_DEST_TYPES(Type) \
std::pair<Type, bool>, std::pair<Type, uint2>, std::pair<Type, uint4>, \
std::pair<Type, uint8_t>, std::pair<Type, uint16_t>, \
std::pair<Type, uint32_t>, std::pair<Type, uint64_t>, \
std::pair<Type, int2>, std::pair<Type, int4>, std::pair<Type, int8_t>, \
std::pair<Type, int16_t>, std::pair<Type, int32_t>, \
#define GEN_DEST_TYPES(Type) \
std::pair<Type, bool>, std::pair<Type, uint1>, std::pair<Type, uint2>, \
std::pair<Type, uint4>, std::pair<Type, uint8_t>, \
std::pair<Type, uint16_t>, std::pair<Type, uint32_t>, \
std::pair<Type, uint64_t>, std::pair<Type, int1>, std::pair<Type, int2>, \
std::pair<Type, int4>, std::pair<Type, int8_t>, \
std::pair<Type, int16_t>, std::pair<Type, int32_t>, \
std::pair<Type, int64_t>, std::pair<Type, CustomInt>

#define GEN_TYPE_PAIRS() \
GEN_DEST_TYPES(int2), GEN_DEST_TYPES(uint2), GEN_DEST_TYPES(int4), \
GEN_DEST_TYPES(uint4)
GEN_DEST_TYPES(int1), GEN_DEST_TYPES(uint1), GEN_DEST_TYPES(int2), \
GEN_DEST_TYPES(uint2), GEN_DEST_TYPES(int4), GEN_DEST_TYPES(uint4)

using IntNCastTypePairs = ::testing::Types<GEN_TYPE_PAIRS()>;
template <typename CastPair>
Expand Down
Loading