Skip to content

Commit

Permalink
[WIP] Fixed a bug and using .compute() in lzyexpr tests
Browse files Browse the repository at this point in the history
  • Loading branch information
FrancescAlted committed Oct 15, 2024
1 parent bbeaa90 commit 437e9ec
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 57 deletions.
9 changes: 4 additions & 5 deletions examples/ndarray/reduce_expr_save.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,10 @@

# This shows how to evaluate expressions with NDArray instances as operands.

import numpy as np

import blosc2
import numpy as np

shape = (2, 1, 2)
shape = (2, 2)

# Create a NDArray from a NumPy array
npa = np.linspace(0, 1, np.prod(shape), dtype=np.float32).reshape(shape)
Expand All @@ -25,7 +24,7 @@
# Get a LazyExpr instance
c = a**2 + b**2 + 2 * a * b + 1
# Evaluate: output is a NDArray
# d = c.sum(axis=1)
# d = c.sum()
# d = blosc2.sum(c, axis=1)
# d = blosc2.sum(c) + blosc2.mean(a)
# d = blosc2.sum(c, axis=1) + blosc2.mean(a, axis=0)
Expand All @@ -42,7 +41,7 @@
assert isinstance(e, blosc2.NDArray)
sum = e[()]
print("Reduction with Blosc2:\n", sum)
# npsum = npc.sum(axis=1)
# npsum = npc.sum()
# npsum = np.sum(npc, axis=1)
# npsum = np.sum(npc) + np.mean(npa)
# npsum = np.sum(npc, axis=1) + np.mean(npa, axis=0)
Expand Down
15 changes: 6 additions & 9 deletions src/blosc2/lazyexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1606,6 +1606,7 @@ def update_expr(self, new_op):
op_name = f"o{len(self.operands)}"
new_operands = {op_name: value2}
expression = f"({self.expression} {op} {op_name})"
self.operands = value1.operands
else:
if np.isscalar(value1):
expression = f"({value1} {op} {self.expression})"
Expand All @@ -1621,6 +1622,7 @@ def update_expr(self, new_op):
expression = f"({op_name}[{self.expression}])"
else:
expression = f"({op_name} {op} {self.expression})"
self.operands = value2.operands
blosc2._disable_overloaded_equal = False
# Return a new expression
operands = self.operands | new_operands
Expand All @@ -1629,22 +1631,17 @@ def update_expr(self, new_op):
@property
def dtype(self):
if hasattr(self, "_dtype"):
# This comes from string expressions, so it is always the same
# This comes from string expressions (probably saved on disk),
# so it is always the same
return self._dtype
# Updating the expression can change the dtype
return guess_dtype(self)
# Infer the dtype by evaluating the scalar version of the expression
# scalar_inputs = {}
# for key, value in self.operands.items():
# single_item = (0,) * len(value.shape)
# scalar_inputs[key] = value[single_item]
# # Evaluate the expression with scalar inputs (it is cheap)
# return ne.evaluate(self.expression, scalar_inputs).dtype

@property
def shape(self):
if hasattr(self, "_shape"):
# Contrarily to dtype, shape cannot change after creation of the expression
# This comes from string expressions (probably saved on disk),
# so it is always the same
return self._shape
self._shape, chunks, blocks, fast_path = validate_inputs(self.operands)
if fast_path:
Expand Down
Loading

0 comments on commit 437e9ec

Please sign in to comment.