Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Supporting bfloat16 for tensorflow + jax (was failing because of intermediary numpy). #382

Merged
merged 1 commit into from
Nov 17, 2023

Conversation

Narsil
Copy link
Collaborator

@Narsil Narsil commented Nov 17, 2023

What does this PR do?

Adds support for loading bfloat16 directly on numpy (custom dtype), tensorflow, flax.
This PR is the minimal change that was found. Further improvements might be possible,
but not the goal here.

Fixes # (issue) or description of the problem this PR solves.

@Narsil Narsil requested a review from McPatate November 17, 2023 13:15
@Narsil Narsil merged commit 9e0bc08 into main Nov 17, 2023
11 checks passed
@Narsil Narsil deleted the bfloat16_tf_flax branch November 17, 2023 13:31
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants