From 66fe64c86c01a15169d6b7fbd86f0f5fd0e75004 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 17 Nov 2023 14:13:46 +0100 Subject: [PATCH] Supporting bfloat16 for tensorflow + jax (was failing because of intermediary numpy). --- bindings/python/py_src/safetensors/flax.py | 9 ++++++--- bindings/python/py_src/safetensors/tensorflow.py | 9 ++++++--- bindings/python/src/lib.rs | 11 ++++++++++- bindings/python/tests/test_flax_comparison.py | 7 +++---- bindings/python/tests/test_tf_comparison.py | 14 ++++++++++++++ 5 files changed, 39 insertions(+), 11 deletions(-) diff --git a/bindings/python/py_src/safetensors/flax.py b/bindings/python/py_src/safetensors/flax.py index 906ae02c..208264ab 100644 --- a/bindings/python/py_src/safetensors/flax.py +++ b/bindings/python/py_src/safetensors/flax.py @@ -5,7 +5,7 @@ import jax.numpy as jnp from jax import Array -from safetensors import numpy +from safetensors import numpy, safe_open def save(tensors: Dict[str, Array], metadata: Optional[Dict[str, str]] = None) -> bytes: @@ -122,8 +122,11 @@ def load_file(filename: Union[str, os.PathLike]) -> Dict[str, Array]: loaded = load_file(file_path) ``` """ - flat = numpy.load_file(filename) - return _np2jnp(flat) + result = {} + with safe_open(filename, framework="flax") as f: + for k in f.keys(): + result[k] = f.get_tensor(k) + return result def _np2jnp(numpy_dict: Dict[str, np.ndarray]) -> Dict[str, Array]: diff --git a/bindings/python/py_src/safetensors/tensorflow.py b/bindings/python/py_src/safetensors/tensorflow.py index 96b704f1..6b0cbf28 100644 --- a/bindings/python/py_src/safetensors/tensorflow.py +++ b/bindings/python/py_src/safetensors/tensorflow.py @@ -4,7 +4,7 @@ import numpy as np import tensorflow as tf -from safetensors import numpy +from safetensors import numpy, safe_open def save(tensors: Dict[str, tf.Tensor], metadata: Optional[Dict[str, str]] = None) -> bytes: @@ -121,8 +121,11 @@ def load_file(filename: Union[str, os.PathLike]) -> Dict[str, tf.Tensor]: loaded = load_file(file_path) ``` """ - flat = numpy.load_file(filename) - return _np2tf(flat) + result = {} + with safe_open(filename, framework="tf") as f: + for k in f.keys(): + result[k] = f.get_tensor(k) + return result def _np2tf(numpy_dict: Dict[str, np.ndarray]) -> Dict[str, tf.Tensor]: diff --git a/bindings/python/src/lib.rs b/bindings/python/src/lib.rs index d088c277..f44915cc 100644 --- a/bindings/python/src/lib.rs +++ b/bindings/python/src/lib.rs @@ -962,7 +962,16 @@ fn get_pydtype(module: &PyModule, dtype: Dtype, is_numpy: bool) -> PyResult module.getattr(intern!(py, "float64"))?.into(), Dtype::F32 => module.getattr(intern!(py, "float32"))?.into(), - Dtype::BF16 => module.getattr(intern!(py, "bfloat16"))?.into(), + Dtype::BF16 => { + if is_numpy { + module + .getattr(intern!(py, "dtype"))? + .call1(("bfloat16",))? + .into() + } else { + module.getattr(intern!(py, "bfloat16"))?.into() + } + } Dtype::F16 => module.getattr(intern!(py, "float16"))?.into(), Dtype::U64 => module.getattr(intern!(py, "uint64"))?.into(), Dtype::I64 => module.getattr(intern!(py, "int64"))?.into(), diff --git a/bindings/python/tests/test_flax_comparison.py b/bindings/python/tests/test_flax_comparison.py index fcb0db5d..23d6c4bb 100644 --- a/bindings/python/tests/test_flax_comparison.py +++ b/bindings/python/tests/test_flax_comparison.py @@ -1,8 +1,6 @@ import platform import unittest -import numpy as np - if platform.system() != "Windows": # This platform is not supported, we don't want to crash on import @@ -21,6 +19,7 @@ def setUp(self): "test": jnp.zeros((1024, 1024), dtype=jnp.float32), "test2": jnp.zeros((1024, 1024), dtype=jnp.float32), "test3": jnp.zeros((1024, 1024), dtype=jnp.float32), + "test4": jnp.zeros((1024, 1024), dtype=jnp.bfloat16), } self.flax_filename = "./tests/data/flax_load.msgpack" self.sf_filename = "./tests/data/flax_load.safetensors" @@ -51,7 +50,7 @@ def test_deserialization_safe(self): for k, v in weights.items(): tv = flax_weights[k] - self.assertTrue(np.allclose(v, tv)) + self.assertTrue(jnp.allclose(v, tv)) def test_deserialization_safe_open(self): weights = {} @@ -65,4 +64,4 @@ def test_deserialization_safe_open(self): for k, v in weights.items(): tv = flax_weights[k] - self.assertTrue(np.allclose(v, tv)) + self.assertTrue(jnp.allclose(v, tv)) diff --git a/bindings/python/tests/test_tf_comparison.py b/bindings/python/tests/test_tf_comparison.py index bdaae061..ac41e6f6 100644 --- a/bindings/python/tests/test_tf_comparison.py +++ b/bindings/python/tests/test_tf_comparison.py @@ -62,6 +62,20 @@ def test_deserialization_safe(self): tv = tf_weights[k] self.assertTrue(np.allclose(v, tv)) + def test_bfloat16(self): + data = { + "test": tf.zeros((1024, 1024), dtype=tf.bfloat16), + } + save_file(data, self.sf_filename) + weights = {} + with safe_open(self.sf_filename, framework="tf") as f: + for k in f.keys(): + weights[k] = f.get_tensor(k) + + for k, v in weights.items(): + tv = data[k] + self.assertTrue(tf.experimental.numpy.allclose(v, tv)) + def test_deserialization_safe_open(self): weights = {} with safe_open(self.sf_filename, framework="tf") as f: