Skip to content

Commit

Permalink
Add checks on multiplications and divisions
Browse files Browse the repository at this point in the history
Signed-off-by: Sylvain Leclerc <sylvain.leclerc@rte-france.com>
  • Loading branch information
sylvlecl committed Sep 16, 2024
1 parent d29f77d commit 3e57ae9
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 31 deletions.
27 changes: 12 additions & 15 deletions src/andromede/simulation/linearize.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# This file is part of the Antares project.

from abc import ABC, abstractmethod
from dataclasses import Field, dataclass
from dataclasses import dataclass
from typing import List, Optional

from andromede.expression import (
Expand All @@ -20,8 +20,6 @@
ExpressionVisitor,
MultiplicationNode,
NegationNode,
ValueProvider,
resolve_parameters,
)
from andromede.expression.expression import (
AllTimeSumNode,
Expand All @@ -47,7 +45,7 @@
TimeSumNode,
VariableNode,
)
from andromede.expression.visitor import ExpressionVisitorOperations, T, visit
from andromede.expression.visitor import visit
from andromede.simulation.linear_expression import LinearExpression, Term, TermKey


Expand Down Expand Up @@ -85,15 +83,6 @@ def to_term(self) -> Term:
)


def generate_key(term: MutableTerm) -> TermKey:
return TermKey(
term.component_id,
term.variable_name,
term.time_index,
term.scenario_index,
)


@dataclass
class LinearExpressionData:
terms: List[MutableTerm]
Expand All @@ -102,7 +91,7 @@ class LinearExpressionData:
def build(self) -> LinearExpression:
res_terms = {}
for t in self.terms:
k = generate_key(t)
k = t.to_key()
if k in res_terms:
current_t = res_terms[k]
current_t.coefficient += t.coefficient
Expand Down Expand Up @@ -143,9 +132,13 @@ def multiplication(self, node: MultiplicationNode) -> LinearExpressionData:
if not lhs.terms:
multiplier = lhs.constant
actual_expr = rhs
else:
elif not rhs.terms:
multiplier = rhs.constant
actual_expr = lhs
else:
raise ValueError(
"At least one operand of a multiplication must be a constant expression."
)
actual_expr.constant *= multiplier
for t in actual_expr.terms:
t.coefficient *= multiplier
Expand All @@ -154,6 +147,10 @@ def multiplication(self, node: MultiplicationNode) -> LinearExpressionData:
def division(self, node: DivisionNode) -> LinearExpressionData:
lhs = visit(node.left, self)
rhs = visit(node.right, self)
if rhs.terms:
raise ValueError(
"The second operand of a division must be a constant expression."
)
divider = rhs.constant
actual_expr = lhs
actual_expr.constant /= divider
Expand Down
44 changes: 28 additions & 16 deletions tests/unittests/expressions/test_linearization.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,26 @@
from unittest.mock import Mock

import pytest
from pydantic._internal._validators import pattern_bytes_validator
from unittests.expressions.test_expressions import StructureProvider

from andromede.expression import ExpressionNode, var, LiteralNode
from andromede.expression import ExpressionNode, LiteralNode, literal, var
from andromede.expression.expression import (
comp_var,
comp_param,
ComponentVariableNode,
NoScenarioIndex,
ProblemVariableNode,
TimeShift,
comp_param,
comp_var,
problem_var,
)
from andromede.expression.operators_expansion import (
expand_operators,
ProblemDimensions,
ProblemIndex,
expand_operators,
)
from andromede.simulation.linear_expression import (
LinearExpression,
Term,
)
from andromede.simulation.linearize import linearize_expression, ParameterGetter
from unittests.expressions.test_expressions import (
StructureProvider,
)
from andromede.simulation.linear_expression import LinearExpression, Term
from andromede.simulation.linearize import ParameterGetter, linearize_expression

P = comp_param("c", "p")
X = comp_var("c", "x")
Expand Down Expand Up @@ -111,7 +111,19 @@ def test_linearization_of_nested_time_operations(
assert _expand_and_linearize(expr, dimensions, index, params) == expected


# def test_expansion_and_linearization():
# param_provider = ComponentEvaluationContext()
# with pytest.raises(ValueError):
# linearize_expression(expr, structure_provider, value_provider)
def test_invalid_multiplication() -> None:
params = Mock(spec=ParameterGetter)

x = problem_var("c", "x", time_index=TimeShift(0), scenario_index=NoScenarioIndex())
expression = x * x
with pytest.raises(ValueError, match="constant"):
linearize_expression(expression, 0, 0, params)


def test_invalid_division() -> None:
params = Mock(spec=ParameterGetter)

x = problem_var("c", "x", time_index=TimeShift(0), scenario_index=NoScenarioIndex())
expression = literal(1) / x
with pytest.raises(ValueError, match="constant"):
linearize_expression(expression, 0, 0, params)

0 comments on commit 3e57ae9

Please sign in to comment.