Skip to content

Commit

Permalink
Add F4E2M1FN type: import mxfloat.h
Browse files Browse the repository at this point in the history
  • Loading branch information
sergey-kozub committed Dec 3, 2024
1 parent b86b937 commit fa539fb
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 2 deletions.
1 change: 1 addition & 0 deletions third_party/tsl/tsl/platform/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1066,6 +1066,7 @@ cc_library(
deps = [
"@ml_dtypes//:float8",
"@ml_dtypes//:intn",
"@ml_dtypes//:mxfloat",
],
)

Expand Down
6 changes: 4 additions & 2 deletions third_party/tsl/tsl/platform/ml_dtypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@ limitations under the License.
#ifndef TENSORFLOW_TSL_PLATFORM_ML_DTYPES_H_
#define TENSORFLOW_TSL_PLATFORM_ML_DTYPES_H_

#include "ml_dtypes/include/float8.h" // from @ml_dtypes
#include "ml_dtypes/include/intn.h" // from @ml_dtypes
#include "ml_dtypes/include/float8.h" // from @ml_dtypes
#include "ml_dtypes/include/intn.h" // from @ml_dtypes
#include "ml_dtypes/include/mxfloat.h" // from @ml_dtypes

namespace tsl {
using float4_e2m1fn = ::ml_dtypes::float4_e2m1fn;
using float8_e3m4 = ::ml_dtypes::float8_e3m4;
using float8_e4m3 = ::ml_dtypes::float8_e4m3;
using float8_e4m3fn = ::ml_dtypes::float8_e4m3fn;
Expand Down

0 comments on commit fa539fb

Please sign in to comment.