From ede48e3c97d613a4d660024f57c3073e102d3f37 Mon Sep 17 00:00:00 2001 From: Tom Welfonder <13498192+tom-welfonder@users.noreply.github.com> Date: Thu, 12 Dec 2024 16:50:55 +0100 Subject: [PATCH] add check for valid dimension name before creating new variable (#360) * fix: unsupported dimension names (#359) Add function to ensure that added variables do not use unvalid dimension names. * Update linopy/model.py More condensed check of dim_names Co-authored-by: Lukas Trippe * Update linopy/model.py remove unnecessary else block Co-authored-by: Lukas Trippe * Update linopy/model.py make check_valid_dim_names private Co-authored-by: Lukas Trippe * Update linopy/model.py make check_valid_dim_names private Co-authored-by: Lukas Trippe --------- Co-authored-by: Lukas Trippe --- linopy/model.py | 27 +++++++++++++++++++++++++++ test/test_variable_assignment.py | 18 ++++++++++++++++++ 2 files changed, 45 insertions(+) diff --git a/linopy/model.py b/linopy/model.py index 62f8b990..ab6c86da 100644 --- a/linopy/model.py +++ b/linopy/model.py @@ -367,6 +367,32 @@ def check_force_dim_names(self, ds: DataArray | Dataset) -> None: else: return + def _check_valid_dim_names(self, ds: DataArray | Dataset) -> None: + """ + Ensure that the added data does not lead to a naming conflict. + + Parameters + ---------- + model : linopy.Model + ds : xr.DataArray/Variable/LinearExpression + Data that should be added to the model. + + Raises + ------ + ValueError + If broadcasted data leads to unsupported dimension names. + + Returns + ------- + None. + """ + unsupported_dim_names = ["labels", "coeffs", "vars", "sign", "rhs"] + if any(dim in unsupported_dim_names for dim in ds.dims): + raise ValueError( + "Added data contains unsupported dimension names. " + "Dimensions cannot be named 'labels', 'coeffs', 'vars', 'sign' or 'rhs'." + ) + def add_variables( self, lower: Any = -inf, @@ -474,6 +500,7 @@ def add_variables( ) (data,) = xr.broadcast(data) self.check_force_dim_names(data) + self._check_valid_dim_names(data) if mask is not None: mask = as_dataarray(mask, coords=data.coords, dims=data.dims).astype(bool) diff --git a/test/test_variable_assignment.py b/test/test_variable_assignment.py index 4dd7ca08..6c2e735a 100644 --- a/test/test_variable_assignment.py +++ b/test/test_variable_assignment.py @@ -131,6 +131,15 @@ def test_variable_assignment_without_coords_and_dims_names(): assert x.dims == ("i", "j") +def test_variable_assignment_without_coords_and_invalid_dims_names(): + # setting bounds without explicit coords + m = Model() + lower = np.zeros((10, 10)) + upper = np.ones((10, 10)) + with pytest.raises(ValueError): + m.add_variables(lower, upper, name="x", dims=["sign", "j"]) + + def test_variable_assignment_without_coords_in_bounds(): # setting bounds without explicit coords m = Model() @@ -141,6 +150,15 @@ def test_variable_assignment_without_coords_in_bounds(): assert x.dims == ("i", "j") +def test_variable_assignment_without_coords_in_bounds_invalid_dims_names(): + # setting bounds without explicit coords + m = Model() + lower = xr.DataArray(np.zeros((10, 10)), dims=["i", "sign"]) + upper = xr.DataArray(np.ones((10, 10)), dims=["i", "sign"]) + with pytest.raises(ValueError): + m.add_variables(lower, upper, name="x") + + def test_variable_assignment_without_coords_pandas_types(): # setting bounds without explicit coords m = Model()