Skip to content

Commit

Permalink
Expose save/load
Browse files Browse the repository at this point in the history
  • Loading branch information
holl- committed Oct 27, 2024
1 parent e543bb8 commit 7a772b0
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 3 deletions.
2 changes: 1 addition & 1 deletion phiml/math/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
copy_with, replace, find_differences
)

from ._tensors import Tensor, wrap, tensor, layout, native, numpy_ as numpy, reshaped_numpy, Dict, to_dict, from_dict, is_scalar, BROADCAST_FORMATTER as f
from ._tensors import Tensor, wrap, tensor, layout, native, numpy_ as numpy, reshaped_numpy, Dict, to_dict, from_dict, is_scalar, BROADCAST_FORMATTER as f, save, load

from ._sparse import dense, get_sparsity, get_format, to_format, is_sparse, sparse_tensor, stored_indices, stored_values, tensor_like, matrix_rank

Expand Down
29 changes: 27 additions & 2 deletions phiml/math/_tensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -2874,7 +2874,20 @@ def specs_equal(spec1, spec2):
return spec1 == spec2


def save_tree(file: str, obj):
def save(file: str, obj):
"""
Saves a `Tensor` or tree using NumPy.
This function converts all tensors contained in `obj` to NumPy tensors before storing.
Each tensor is given a name corresponding to its path within `obj`, allowing reading only specific arrays from the file later on.
Pickle is used for structures, but no reference to `Tensor` or its sub-classes is included.
See Also:
`load()`.
Args:
file: Target file, will be stored as `.npz`.
obj: `Tensor` or tree to store.
"""
tree, tensors = disassemble_tree(obj, False, all_attributes)
paths = attr_paths(obj, all_attributes, 'root')
assert len(paths) == len(tensors)
Expand All @@ -2887,7 +2900,19 @@ def save_tree(file: str, obj):
np.savez(file, tree=tree, specs=specs, paths=paths, **{p: n for p, n in zip(all_paths, all_np)})


def load_tree(file: str):
def load(file: str):
"""
Loads a `Tensor` or tree from a file previously written using `save`.
All tensors are restored as NumPy arrays, not the backend-specific tensors they may have been written as.
Use `convert()` to convert all or some of the tensors to a different backend.
Args:
file: File to read.
Returns:
Same type as what was written.
"""
data = np.load(file, allow_pickle=True)
all_np = {k: data[k] for k in data if k not in ['tree', 'specs', 'paths']}
specs = [unserialize_spec(spec) for spec in data['specs'].tolist()]
Expand Down
7 changes: 7 additions & 0 deletions tests/commit/math/test__tensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -691,3 +691,10 @@ def test_reshaped_tensor(self):
s = spatial(x=4, y=3)
t = math.reshaped_tensor(a, [s, s.as_dual()])
self.assertEqual(set(t.shape), set(s & dual))

def test_save_load(self):
files = math.layout(["A", "test/filename.png", math.ones(spatial(x='a,b,c'))], 'example:b')
math.save("files.npz", files)
loaded = math.load("files.npz")
self.assertEqual(loaded.example[0].native(), "A")
self.assertTrue((loaded.example[2] == 1).all)

0 comments on commit 7a772b0

Please sign in to comment.