Skip to content

Commit

Permalink
fix for numpy>=2.0
Browse files Browse the repository at this point in the history
  • Loading branch information
ahalev committed Aug 8, 2024
1 parent fc2c64f commit e22b7ee
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions src/pymgrid/modules/base/base_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,9 +149,10 @@ def step(self, action, normalized=True):
except (IndexError, TypeError):
if not isinstance(denormalized_action, (float, int)):
try:
flat_dim = np.product(denormalized_action.shape)
assert flat_dim == 0
except (AttributeError, AssertionError):
flat_dim = np.prod(denormalized_action.shape)
if flat_dim != 0:
raise ValueError(f'Bad action {denormalized_action}')
except AttributeError:
raise ValueError(f'Bad action {denormalized_action}')
else:
denormalized_action = 0.0
Expand Down

0 comments on commit e22b7ee

Please sign in to comment.