Skip to content

Commit

Permalink
add check for valid dimension name before creating new variable (#360)
Browse files Browse the repository at this point in the history
* fix: unsupported dimension names (#359)

Add function to ensure that added variables do not use unvalid dimension names.

* Update linopy/model.py

More condensed check of dim_names

Co-authored-by: Lukas Trippe <lkstrp@pm.me>

* Update linopy/model.py

remove unnecessary else block

Co-authored-by: Lukas Trippe <lkstrp@pm.me>

* Update linopy/model.py

make check_valid_dim_names private

Co-authored-by: Lukas Trippe <lkstrp@pm.me>

* Update linopy/model.py

make check_valid_dim_names private

Co-authored-by: Lukas Trippe <lkstrp@pm.me>

---------

Co-authored-by: Lukas Trippe <lkstrp@pm.me>
  • Loading branch information
tom-welfonder and lkstrp authored Dec 12, 2024
1 parent 30ef9ea commit ede48e3
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 0 deletions.
27 changes: 27 additions & 0 deletions linopy/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,32 @@ def check_force_dim_names(self, ds: DataArray | Dataset) -> None:
else:
return

def _check_valid_dim_names(self, ds: DataArray | Dataset) -> None:
"""
Ensure that the added data does not lead to a naming conflict.
Parameters
----------
model : linopy.Model
ds : xr.DataArray/Variable/LinearExpression
Data that should be added to the model.
Raises
------
ValueError
If broadcasted data leads to unsupported dimension names.
Returns
-------
None.
"""
unsupported_dim_names = ["labels", "coeffs", "vars", "sign", "rhs"]
if any(dim in unsupported_dim_names for dim in ds.dims):
raise ValueError(
"Added data contains unsupported dimension names. "
"Dimensions cannot be named 'labels', 'coeffs', 'vars', 'sign' or 'rhs'."
)

def add_variables(
self,
lower: Any = -inf,
Expand Down Expand Up @@ -474,6 +500,7 @@ def add_variables(
)
(data,) = xr.broadcast(data)
self.check_force_dim_names(data)
self._check_valid_dim_names(data)

if mask is not None:
mask = as_dataarray(mask, coords=data.coords, dims=data.dims).astype(bool)
Expand Down
18 changes: 18 additions & 0 deletions test/test_variable_assignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,15 @@ def test_variable_assignment_without_coords_and_dims_names():
assert x.dims == ("i", "j")


def test_variable_assignment_without_coords_and_invalid_dims_names():
# setting bounds without explicit coords
m = Model()
lower = np.zeros((10, 10))
upper = np.ones((10, 10))
with pytest.raises(ValueError):
m.add_variables(lower, upper, name="x", dims=["sign", "j"])


def test_variable_assignment_without_coords_in_bounds():
# setting bounds without explicit coords
m = Model()
Expand All @@ -141,6 +150,15 @@ def test_variable_assignment_without_coords_in_bounds():
assert x.dims == ("i", "j")


def test_variable_assignment_without_coords_in_bounds_invalid_dims_names():
# setting bounds without explicit coords
m = Model()
lower = xr.DataArray(np.zeros((10, 10)), dims=["i", "sign"])
upper = xr.DataArray(np.ones((10, 10)), dims=["i", "sign"])
with pytest.raises(ValueError):
m.add_variables(lower, upper, name="x")


def test_variable_assignment_without_coords_pandas_types():
# setting bounds without explicit coords
m = Model()
Expand Down

0 comments on commit ede48e3

Please sign in to comment.