Skip to content

Commit

Permalink
Reverts 2533c35
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 709153611
  • Loading branch information
jaeyoo authored and Google-ML-Automation committed Dec 23, 2024
1 parent 7738c74 commit e19d979
Show file tree
Hide file tree
Showing 83 changed files with 367 additions and 1,856 deletions.
1 change: 0 additions & 1 deletion third_party/tsl/tsl/platform/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -985,7 +985,6 @@ cc_library(
deps = [
"@ml_dtypes//:float8",
"@ml_dtypes//:intn",
"@ml_dtypes//:mxfloat",
],
)

Expand Down
3 changes: 0 additions & 3 deletions third_party/tsl/tsl/platform/ml_dtypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,15 @@ limitations under the License.

#include "ml_dtypes/include/float8.h" // from @ml_dtypes
#include "ml_dtypes/include/intn.h" // from @ml_dtypes
#include "ml_dtypes/include/mxfloat.h" // from @ml_dtypes

namespace tsl {
using float4_e2m1fn = ::ml_dtypes::float4_e2m1fn;
using float8_e3m4 = ::ml_dtypes::float8_e3m4;
using float8_e4m3 = ::ml_dtypes::float8_e4m3;
using float8_e4m3fn = ::ml_dtypes::float8_e4m3fn;
using float8_e4m3fnuz = ::ml_dtypes::float8_e4m3fnuz;
using float8_e4m3b11fnuz = ::ml_dtypes::float8_e4m3b11fnuz;
using float8_e5m2 = ::ml_dtypes::float8_e5m2;
using float8_e5m2fnuz = ::ml_dtypes::float8_e5m2fnuz;
using float8_e8m0fnu = ::ml_dtypes::float8_e8m0fnu;

using int1 = ::ml_dtypes::int1;
using uint1 = ::ml_dtypes::uint1;
Expand Down
28 changes: 0 additions & 28 deletions xla/array2d_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -219,34 +219,6 @@ TEST(Array2dTest, LinspaceF8E3M4) {
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(2, 1)), 3.5);
}

TEST(Array2dTest, LinspaceF4E2M1FN) {
auto arr = MakeLinspaceArray2D<tsl::float4_e2m1fn>(1.0, 3.5, 3, 2);

EXPECT_EQ(arr->n1(), 3);
EXPECT_EQ(arr->n2(), 2);

EXPECT_FLOAT_EQ(static_cast<float>((*arr)(0, 0)), 1.0);
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(0, 1)), 1.5);
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(1, 0)), 2.0);
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(1, 1)), 2.0); // 2.5 rounded down
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(2, 0)), 3.0);
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(2, 1)), 4.0); // 3.5 rounded up
}

TEST(Array2dTest, LinspaceF8E8M0FNU) {
auto arr = MakeLinspaceArray2D<tsl::float8_e8m0fnu>(1.0, 3.5, 3, 2);

EXPECT_EQ(arr->n1(), 3);
EXPECT_EQ(arr->n2(), 2);

EXPECT_FLOAT_EQ(static_cast<float>((*arr)(0, 0)), 1.0);
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(0, 1)), 2.0); // 1.5 rounded up
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(1, 0)), 2.0);
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(1, 1)), 2.0); // 2.5 rounded down
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(2, 0)), 4.0); // 3.0 rounded up
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(2, 1)), 4.0); // 3.5 rounded up
}

TEST(Array2dTest, Stringification) {
auto arr = MakeLinspaceArray2D(1.0, 3.5, 3, 2);
const std::string expected = R"([[1, 1.5],
Expand Down
Loading

0 comments on commit e19d979

Please sign in to comment.