diff --git a/src/andromede/simulation/linearize.py b/src/andromede/simulation/linearize.py index f1be143..13b5372 100644 --- a/src/andromede/simulation/linearize.py +++ b/src/andromede/simulation/linearize.py @@ -11,7 +11,7 @@ # This file is part of the Antares project. from abc import ABC, abstractmethod -from dataclasses import Field, dataclass +from dataclasses import dataclass from typing import List, Optional from andromede.expression import ( @@ -20,8 +20,6 @@ ExpressionVisitor, MultiplicationNode, NegationNode, - ValueProvider, - resolve_parameters, ) from andromede.expression.expression import ( AllTimeSumNode, @@ -47,7 +45,7 @@ TimeSumNode, VariableNode, ) -from andromede.expression.visitor import ExpressionVisitorOperations, T, visit +from andromede.expression.visitor import visit from andromede.simulation.linear_expression import LinearExpression, Term, TermKey @@ -85,15 +83,6 @@ def to_term(self) -> Term: ) -def generate_key(term: MutableTerm) -> TermKey: - return TermKey( - term.component_id, - term.variable_name, - term.time_index, - term.scenario_index, - ) - - @dataclass class LinearExpressionData: terms: List[MutableTerm] @@ -102,7 +91,7 @@ class LinearExpressionData: def build(self) -> LinearExpression: res_terms = {} for t in self.terms: - k = generate_key(t) + k = t.to_key() if k in res_terms: current_t = res_terms[k] current_t.coefficient += t.coefficient @@ -143,9 +132,13 @@ def multiplication(self, node: MultiplicationNode) -> LinearExpressionData: if not lhs.terms: multiplier = lhs.constant actual_expr = rhs - else: + elif not rhs.terms: multiplier = rhs.constant actual_expr = lhs + else: + raise ValueError( + "At least one operand of a multiplication must be a constant expression." + ) actual_expr.constant *= multiplier for t in actual_expr.terms: t.coefficient *= multiplier @@ -154,6 +147,10 @@ def multiplication(self, node: MultiplicationNode) -> LinearExpressionData: def division(self, node: DivisionNode) -> LinearExpressionData: lhs = visit(node.left, self) rhs = visit(node.right, self) + if rhs.terms: + raise ValueError( + "The second operand of a division must be a constant expression." + ) divider = rhs.constant actual_expr = lhs actual_expr.constant /= divider diff --git a/tests/unittests/expressions/test_linearization.py b/tests/unittests/expressions/test_linearization.py index 2895b27..7b14eb3 100644 --- a/tests/unittests/expressions/test_linearization.py +++ b/tests/unittests/expressions/test_linearization.py @@ -1,26 +1,26 @@ from unittest.mock import Mock import pytest +from pydantic._internal._validators import pattern_bytes_validator +from unittests.expressions.test_expressions import StructureProvider -from andromede.expression import ExpressionNode, var, LiteralNode +from andromede.expression import ExpressionNode, LiteralNode, literal, var from andromede.expression.expression import ( - comp_var, - comp_param, ComponentVariableNode, + NoScenarioIndex, + ProblemVariableNode, + TimeShift, + comp_param, + comp_var, + problem_var, ) from andromede.expression.operators_expansion import ( - expand_operators, ProblemDimensions, ProblemIndex, + expand_operators, ) -from andromede.simulation.linear_expression import ( - LinearExpression, - Term, -) -from andromede.simulation.linearize import linearize_expression, ParameterGetter -from unittests.expressions.test_expressions import ( - StructureProvider, -) +from andromede.simulation.linear_expression import LinearExpression, Term +from andromede.simulation.linearize import ParameterGetter, linearize_expression P = comp_param("c", "p") X = comp_var("c", "x") @@ -111,7 +111,19 @@ def test_linearization_of_nested_time_operations( assert _expand_and_linearize(expr, dimensions, index, params) == expected -# def test_expansion_and_linearization(): -# param_provider = ComponentEvaluationContext() -# with pytest.raises(ValueError): -# linearize_expression(expr, structure_provider, value_provider) +def test_invalid_multiplication() -> None: + params = Mock(spec=ParameterGetter) + + x = problem_var("c", "x", time_index=TimeShift(0), scenario_index=NoScenarioIndex()) + expression = x * x + with pytest.raises(ValueError, match="constant"): + linearize_expression(expression, 0, 0, params) + + +def test_invalid_division() -> None: + params = Mock(spec=ParameterGetter) + + x = problem_var("c", "x", time_index=TimeShift(0), scenario_index=NoScenarioIndex()) + expression = literal(1) / x + with pytest.raises(ValueError, match="constant"): + linearize_expression(expression, 0, 0, params)