Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support bfloat16 for TensorFlow + JAX (loading previously failed because of the intermediate NumPy conversion). #382

Merged
merged 1 commit into from
Nov 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions bindings/python/py_src/safetensors/flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import jax.numpy as jnp
from jax import Array
from safetensors import numpy
from safetensors import numpy, safe_open


def save(tensors: Dict[str, Array], metadata: Optional[Dict[str, str]] = None) -> bytes:
Expand Down Expand Up @@ -122,8 +122,11 @@ def load_file(filename: Union[str, os.PathLike]) -> Dict[str, Array]:
loaded = load_file(file_path)
```
"""
flat = numpy.load_file(filename)
return _np2jnp(flat)
result = {}
with safe_open(filename, framework="flax") as f:
for k in f.keys():
result[k] = f.get_tensor(k)
return result


def _np2jnp(numpy_dict: Dict[str, np.ndarray]) -> Dict[str, Array]:
Expand Down
9 changes: 6 additions & 3 deletions bindings/python/py_src/safetensors/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np
import tensorflow as tf

from safetensors import numpy
from safetensors import numpy, safe_open


def save(tensors: Dict[str, tf.Tensor], metadata: Optional[Dict[str, str]] = None) -> bytes:
Expand Down Expand Up @@ -121,8 +121,11 @@ def load_file(filename: Union[str, os.PathLike]) -> Dict[str, tf.Tensor]:
loaded = load_file(file_path)
```
"""
flat = numpy.load_file(filename)
return _np2tf(flat)
result = {}
with safe_open(filename, framework="tf") as f:
for k in f.keys():
result[k] = f.get_tensor(k)
return result


def _np2tf(numpy_dict: Dict[str, np.ndarray]) -> Dict[str, tf.Tensor]:
Expand Down
11 changes: 10 additions & 1 deletion bindings/python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -962,7 +962,16 @@ fn get_pydtype(module: &PyModule, dtype: Dtype, is_numpy: bool) -> PyResult<PyOb
let dtype: PyObject = match dtype {
Dtype::F64 => module.getattr(intern!(py, "float64"))?.into(),
Dtype::F32 => module.getattr(intern!(py, "float32"))?.into(),
Dtype::BF16 => module.getattr(intern!(py, "bfloat16"))?.into(),
Dtype::BF16 => {
if is_numpy {
module
.getattr(intern!(py, "dtype"))?
.call1(("bfloat16",))?
.into()
} else {
module.getattr(intern!(py, "bfloat16"))?.into()
}
}
Dtype::F16 => module.getattr(intern!(py, "float16"))?.into(),
Dtype::U64 => module.getattr(intern!(py, "uint64"))?.into(),
Dtype::I64 => module.getattr(intern!(py, "int64"))?.into(),
Expand Down
7 changes: 3 additions & 4 deletions bindings/python/tests/test_flax_comparison.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import platform
import unittest

import numpy as np


if platform.system() != "Windows":
# This platform is not supported, we don't want to crash on import
Expand All @@ -21,6 +19,7 @@ def setUp(self):
"test": jnp.zeros((1024, 1024), dtype=jnp.float32),
"test2": jnp.zeros((1024, 1024), dtype=jnp.float32),
"test3": jnp.zeros((1024, 1024), dtype=jnp.float32),
"test4": jnp.zeros((1024, 1024), dtype=jnp.bfloat16),
}
self.flax_filename = "./tests/data/flax_load.msgpack"
self.sf_filename = "./tests/data/flax_load.safetensors"
Expand Down Expand Up @@ -51,7 +50,7 @@ def test_deserialization_safe(self):

for k, v in weights.items():
tv = flax_weights[k]
self.assertTrue(np.allclose(v, tv))
self.assertTrue(jnp.allclose(v, tv))

def test_deserialization_safe_open(self):
weights = {}
Expand All @@ -65,4 +64,4 @@ def test_deserialization_safe_open(self):

for k, v in weights.items():
tv = flax_weights[k]
self.assertTrue(np.allclose(v, tv))
self.assertTrue(jnp.allclose(v, tv))
14 changes: 14 additions & 0 deletions bindings/python/tests/test_tf_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,20 @@ def test_deserialization_safe(self):
tv = tf_weights[k]
self.assertTrue(np.allclose(v, tv))

def test_bfloat16(self):
    """Round-trip a bfloat16 tensor through a safetensors file with the TF backend.

    Saves a single bfloat16 tensor with ``save_file``, reads it back through
    ``safe_open(..., framework="tf")`` and checks that the loaded tensor has
    the same keys, dtype and values as the one that was written.
    """
    # Use ones rather than zeros: an all-zero buffer has the same bytes under
    # every dtype, so a zero fixture could not detect the payload being
    # decoded with the wrong dtype.
    data = {
        "test": tf.ones((1024, 1024), dtype=tf.bfloat16),
    }
    save_file(data, self.sf_filename)

    with safe_open(self.sf_filename, framework="tf") as f:
        weights = {k: f.get_tensor(k) for k in f.keys()}

    # Guard against the comparison loop below passing vacuously.
    self.assertEqual(set(weights), set(data))

    for k, v in weights.items():
        tv = data[k]
        # The dtype must survive the round trip, not just the values.
        self.assertEqual(v.dtype, tv.dtype)
        self.assertTrue(tf.experimental.numpy.allclose(v, tv))

def test_deserialization_safe_open(self):
weights = {}
with safe_open(self.sf_filename, framework="tf") as f:
Expand Down
Loading