From b640b22f8f092d2fd4119b990e653b0b02b8be01 Mon Sep 17 00:00:00 2001 From: Thomas Bittar Date: Mon, 15 Jul 2024 18:46:14 +0200 Subject: [PATCH] Implement shift, eval and time sum of linear expressions --- .../expression/expression_efficient.py | 76 +++--- .../expression/linear_expression_efficient.py | 128 ++++++--- .../expressions/test_expressions_efficient.py | 242 ++++++++++++++---- 3 files changed, 325 insertions(+), 121 deletions(-) diff --git a/src/andromede/expression/expression_efficient.py b/src/andromede/expression/expression_efficient.py index 8356c225..a8f76ac9 100644 --- a/src/andromede/expression/expression_efficient.py +++ b/src/andromede/expression/expression_efficient.py @@ -26,9 +26,9 @@ EPS = 10 ** (-16) -class Instances(enum.Enum): - SIMPLE = "SIMPLE" - MULTIPLE = "MULTIPLE" +# class Instances(enum.Enum): +# SIMPLE = "SIMPLE" +# MULTIPLE = "MULTIPLE" @dataclass(frozen=True) @@ -43,7 +43,7 @@ class ExpressionNodeEfficient: >>> expr = -var('x') + 5 / param('p') """ - instances: Instances = field(init=False, default=Instances.SIMPLE) + # instances: Instances = field(init=False, default=Instances.SIMPLE) def __neg__(self) -> "ExpressionNodeEfficient": return _negate_node(self) @@ -286,8 +286,8 @@ class LiteralNode(ExpressionNodeEfficient): class UnaryOperatorNode(ExpressionNodeEfficient): operand: ExpressionNodeEfficient - def __post_init__(self) -> None: - object.__setattr__(self, "instances", self.operand.instances) + # def __post_init__(self) -> None: + # object.__setattr__(self, "instances", self.operand.instances) @dataclass(frozen=True, eq=False) @@ -318,17 +318,17 @@ class BinaryOperatorNode(ExpressionNodeEfficient): left: ExpressionNodeEfficient right: ExpressionNodeEfficient - def __post_init__(self) -> None: - binary_operator_post_init(self, "apply binary operation with") + # def __post_init__(self) -> None: + # binary_operator_post_init(self, "apply binary operation with") -def binary_operator_post_init(node: BinaryOperatorNode, operation: str) -> None: - if node.left.instances != node.right.instances: - raise ValueError( - f"Cannot {operation} {node.left} and {node.right} as they do not have the same number of instances." - ) - else: - object.__setattr__(node, "instances", node.left.instances) +# def binary_operator_post_init(node: BinaryOperatorNode, operation: str) -> None: +# if node.left.instances != node.right.instances: +# raise ValueError( +# f"Cannot {operation} {node.left} and {node.right} as they do not have the same number of instances." +# ) +# else: +# object.__setattr__(node, "instances", node.left.instances) class Comparator(enum.Enum): @@ -341,32 +341,36 @@ class Comparator(enum.Enum): class ComparisonNode(BinaryOperatorNode): comparator: Comparator - def __post_init__(self) -> None: - binary_operator_post_init(self, "compare") + # def __post_init__(self) -> None: + # binary_operator_post_init(self, "compare") @dataclass(frozen=True, eq=False) class AdditionNode(BinaryOperatorNode): - def __post_init__(self) -> None: - binary_operator_post_init(self, "add") + pass + # def __post_init__(self) -> None: + # binary_operator_post_init(self, "add") @dataclass(frozen=True, eq=False) class SubstractionNode(BinaryOperatorNode): - def __post_init__(self) -> None: - binary_operator_post_init(self, "substract") + pass + # def __post_init__(self) -> None: + # binary_operator_post_init(self, "substract") @dataclass(frozen=True, eq=False) class MultiplicationNode(BinaryOperatorNode): - def __post_init__(self) -> None: - binary_operator_post_init(self, "multiply") + pass + # def __post_init__(self) -> None: + # binary_operator_post_init(self, "multiply") @dataclass(frozen=True, eq=False) class DivisionNode(BinaryOperatorNode): - def __post_init__(self) -> None: - binary_operator_post_init(self, "divide") + pass + # def __post_init__(self) -> None: + # binary_operator_post_init(self, "divide") @dataclass(frozen=True, eq=False) @@ -465,15 +469,15 @@ def __post_init__(self) -> None: raise ValueError( f"{self.name} is not a valid time aggregator, valid time aggregators are {valid_names}" ) - if self.operand.instances == Instances.SIMPLE: - if self.instances_index.is_simple(): - object.__setattr__(self, "instances", Instances.SIMPLE) - else: - object.__setattr__(self, "instances", Instances.MULTIPLE) - else: - raise ValueError( - "Cannot apply time operator on an expression that already represents multiple instances" - ) + # if self.operand.instances == Instances.SIMPLE: + # if self.instances_index.is_simple(): + # object.__setattr__(self, "instances", Instances.SIMPLE) + # else: + # object.__setattr__(self, "instances", Instances.MULTIPLE) + # else: + # raise ValueError( + # "Cannot apply time operator on an expression that already represents multiple instances" + # ) @dataclass(frozen=True, eq=False) @@ -493,7 +497,7 @@ def __post_init__(self) -> None: raise ValueError( f"{self.name} is not a valid time aggregator, valid time aggregators are {valid_names}" ) - object.__setattr__(self, "instances", Instances.SIMPLE) + # object.__setattr__(self, "instances", Instances.SIMPLE) @dataclass(frozen=True, eq=False) @@ -512,7 +516,7 @@ def __post_init__(self) -> None: raise ValueError( f"{self.name} is not a valid scenario operator, valid scenario operators are {valid_names}" ) - object.__setattr__(self, "instances", Instances.SIMPLE) + # object.__setattr__(self, "instances", Instances.SIMPLE) def sum_expressions( diff --git a/src/andromede/expression/linear_expression_efficient.py b/src/andromede/expression/linear_expression_efficient.py index 8deff23a..f7e10267 100644 --- a/src/andromede/expression/linear_expression_efficient.py +++ b/src/andromede/expression/linear_expression_efficient.py @@ -27,6 +27,7 @@ InstancesTimeIndex, LiteralNode, ParameterNode, + ScenarioOperatorNode, TimeAggregatorNode, TimeOperatorNode, is_minus_one, @@ -37,7 +38,7 @@ from andromede.expression.indexing import IndexingStructureProvider, compute_indexation from andromede.expression.indexing_structure import IndexingStructure from andromede.expression.print import print_expr -from andromede.expression.scenario_operator import ScenarioOperator +from andromede.expression.scenario_operator import Expectation, ScenarioOperator from andromede.expression.time_operator import ( TimeAggregator, TimeEvaluation, @@ -135,22 +136,39 @@ def evaluate(self, context: ValueProvider) -> float: def compute_indexation( self, provider: IndexingStructureProvider ) -> IndexingStructure: - # TODO: Improve this if/else structure - if self.component_id: - time = ( - provider.get_component_variable_structure(self.variable_name).time - == True - ) - scenario = ( - provider.get_component_variable_structure(self.variable_name).scenario - == True - ) + + return IndexingStructure( + self._compute_time_indexing(provider), + self._compute_scenario_indexing(provider), + ) + + def _compute_time_indexing(self, provider: IndexingStructureProvider) -> bool: + if (self.time_aggregator and not self.time_aggregator.stay_roll) or ( + self.time_operator and not self.time_operator.rolling() + ): + time = False else: - time = provider.get_variable_structure(self.variable_name).time == True - scenario = ( - provider.get_variable_structure(self.variable_name).scenario == True - ) - return IndexingStructure(time, scenario) + if self.component_id: + time = provider.get_component_variable_structure( + self.component_id, self.variable_name + ).time + else: + time = provider.get_variable_structure(self.variable_name).time + return time + + def _compute_scenario_indexing(self, provider: IndexingStructureProvider) -> bool: + if self.scenario_operator: + scenario = False + else: + # TODO: Improve this if/else structure, probably simplify IndexingStructureProvider + if self.component_id: + scenario = provider.get_component_variable_structure( + self.component_id, self.variable_name + ).scenario + + else: + scenario = provider.get_variable_structure(self.variable_name).scenario + return scenario def sum( self, @@ -204,7 +222,7 @@ def shift( List["ExpressionNodeEfficient"], "ExpressionRange", ], - ) -> "LinearExpressionEfficient": + ) -> "TermEfficient": """ Shorthand for shift on a single time step @@ -236,7 +254,7 @@ def eval( List["ExpressionNodeEfficient"], "ExpressionRange", ], - ) -> "LinearExpressionEfficient": + ) -> "TermEfficient": """ Shorthand for eval on a single time step @@ -260,6 +278,10 @@ def eval( else: return self.sum(eval=expressions) + def expec(self) -> "TermEfficient": + # TODO: Do we need checks, in case a scenario operator is already specified ? + return dataclasses.replace(self, scenario_operator=Expectation()) + def generate_key(term: TermEfficient) -> TermKeyEfficient: return TermKeyEfficient( @@ -595,7 +617,12 @@ def is_constant(self) -> bool: def compute_indexation( self, provider: IndexingStructureProvider ) -> IndexingStructure: - indexing = compute_indexation(self.constant, provider) + """ + Computes the (time, scenario) indexing of a linear expression. + + Time and scenario indexation is driven by the indexation of variables in the expression. If a single term is indexed by time (resp. scenario), then the linear expression is indexed by time (resp. scenario). + """ + indexing = IndexingStructure(False, False) for term in self.terms.values(): indexing = indexing | term.compute_indexation(provider) @@ -635,15 +662,40 @@ def sum( if shift is not None: sum_args = {"shift": shift} - stay_roll = True + + result_constant = TimeAggregatorNode( + TimeOperatorNode( + self.constant, + "TimeShift", + InstancesTimeIndex(shift), + ), + "TimeSum", + stay_roll=True, + ) elif eval is not None: sum_args = {"eval": eval} - stay_roll = True + + result_constant = TimeAggregatorNode( + TimeOperatorNode( + self.constant, + "TimeEvaluation", + InstancesTimeIndex(eval), + ), + "TimeSum", + stay_roll=True, + ) else: # x.sum() -> Sum over all time block sum_args = {} - stay_roll = False - return self._apply_operator(sum_args, stay_roll) + result_constant = TimeAggregatorNode( + self.constant, + "TimeSum", + stay_roll=False, + ) + + return LinearExpressionEfficient( + self._apply_operator(sum_args), result_constant + ) def _apply_operator( self, @@ -657,26 +709,13 @@ def _apply_operator( None, ], ], - stay_roll: bool, ): result_terms = {} for term in self.terms.values(): term_with_operator = term.sum(**sum_args) result_terms[generate_key(term_with_operator)] = term_with_operator - result_constant = TimeAggregatorNode( - TimeOperatorNode( - self.constant, - "TimeShift", - InstancesTimeIndex( - sum_args.popitem()[1] - ), # Dangerous as it modifies sum_args ? - ), - "TimeSum", - stay_roll=stay_roll, - ) - result_expr = LinearExpressionEfficient(result_terms, result_constant) - return result_expr + return result_terms # def sum_connections(self) -> "ExpressionNode": # if isinstance(self, PortFieldNode): @@ -737,8 +776,19 @@ def eval( else: return self.sum(eval=expressions) - # def expec(self) -> "ExpressionNode": - # return _apply_if_node(self, lambda x: ScenarioOperatorNode(x, "Expectation")) + def expec(self) -> "LinearExpressionEfficient": + """ + Expectation of linear expression. As the operator is linear, it distributes over all terms and the constant + """ + + result_terms = {} + for term in self.terms.values(): + term_with_operator = term.expec() + result_terms[generate_key(term_with_operator)] = term_with_operator + + result_constant = ScenarioOperatorNode(self.constant, "Expectation") + result_expr = LinearExpressionEfficient(result_terms, result_constant) + return result_expr # def variance(self) -> "ExpressionNode": # return _apply_if_node(self, lambda x: ScenarioOperatorNode(x, "Variance")) diff --git a/tests/unittests/expressions/test_expressions_efficient.py b/tests/unittests/expressions/test_expressions_efficient.py index c4cdaa91..8705d342 100644 --- a/tests/unittests/expressions/test_expressions_efficient.py +++ b/tests/unittests/expressions/test_expressions_efficient.py @@ -16,12 +16,18 @@ import pytest +from andromede.expression.equality import expressions_equal from andromede.expression.evaluate import EvaluationContext, ValueProvider from andromede.expression.evaluate_parameters import ParameterValueProvider from andromede.expression.expression_efficient import ( ComponentParameterNode, + ExpressionNodeEfficient, ExpressionRange, + InstancesTimeIndex, + LiteralNode, ParameterNode, + TimeAggregatorNode, + TimeOperatorNode, ) from andromede.expression.indexing import IndexingStructureProvider from andromede.expression.indexing_structure import IndexingStructure @@ -29,12 +35,14 @@ LinearExpressionEfficient, StandaloneConstraint, TermEfficient, + TermKeyEfficient, comp_param, comp_var, literal, param, var, ) +from andromede.expression.time_operator import TimeEvaluation, TimeShift, TimeSum from andromede.model.constraint import Constraint from andromede.simulation.linearize import linearize_expression @@ -341,6 +349,149 @@ def test_comparison() -> None: assert str(expr_eq) == "0 <= 5.0x + (3.0 - (p - 2.0)) <= 0" +# TODO: Maybe imagine other use cases, that should be forbidden (composition of operators...) +@pytest.mark.parametrize( + "expr, expec_terms, expec_constant", + [ + ( + (var("x") + var("y") + literal(1)).shift(1), + { + TermKeyEfficient( + "", + "x", + TimeShift(InstancesTimeIndex(1)), + time_aggregator=TimeSum( + stay_roll=True + ), # The internal representation of shift(1) is sum(shift=1) + scenario_operator=None, + ): TermEfficient( + TimeOperatorNode( + LiteralNode(1), "TimeShift", InstancesTimeIndex(1) + ), + "", + "x", + time_operator=TimeShift( + InstancesTimeIndex(1), + ), + time_aggregator=TimeSum(stay_roll=True), + ), + TermKeyEfficient( + "", + "y", + TimeShift( + InstancesTimeIndex(1), + ), + time_aggregator=TimeSum(stay_roll=True), + scenario_operator=None, + ): TermEfficient( + TimeOperatorNode( + LiteralNode(1), "TimeShift", InstancesTimeIndex(1) + ), + "", + "y", + time_operator=TimeShift(InstancesTimeIndex(1)), + time_aggregator=TimeSum(stay_roll=True), + ), + }, + TimeAggregatorNode( + TimeOperatorNode(LiteralNode(1), "TimeShift", InstancesTimeIndex(1)), + "TimeSum", + stay_roll=True, + ), # TODO: Could it be simplified online ? + ), + ( + (var("x") + var("y") + literal(1)).eval(1), + { + TermKeyEfficient( + "", + "x", + TimeEvaluation(InstancesTimeIndex(1)), + time_aggregator=TimeSum( + stay_roll=True + ), # The internal representation of eval(1) is sum(eval=1) + scenario_operator=None, + ): TermEfficient( + TimeOperatorNode( + LiteralNode(1), "TimeEvaluation", InstancesTimeIndex(1) + ), + "", + "x", + time_operator=TimeEvaluation( + InstancesTimeIndex(1), + ), + time_aggregator=TimeSum(stay_roll=True), + ), + TermKeyEfficient( + "", + "y", + TimeEvaluation( + InstancesTimeIndex(1), + ), + time_aggregator=TimeSum(stay_roll=True), + scenario_operator=None, + ): TermEfficient( + TimeOperatorNode( + LiteralNode(1), "TimeEvaluation", InstancesTimeIndex(1) + ), + "", + "y", + time_operator=TimeEvaluation(InstancesTimeIndex(1)), + time_aggregator=TimeSum(stay_roll=True), + ), + }, + TimeAggregatorNode( + TimeOperatorNode( + LiteralNode(1), "TimeEvaluation", InstancesTimeIndex(1) + ), + "TimeSum", + stay_roll=True, + ), # TODO: Could it be simplified online ? + ), + ( + (var("x") + var("y") + literal(1)).sum(), + { + TermKeyEfficient( + "", + "x", + time_operator=None, + time_aggregator=TimeSum(stay_roll=False), + scenario_operator=None, + ): TermEfficient( + LiteralNode(1), # Sum is not distributed to coeff + "", + "x", + time_operator=None, + time_aggregator=TimeSum(stay_roll=False), + ), + TermKeyEfficient( + "", + "y", + time_operator=None, + time_aggregator=TimeSum(stay_roll=False), + scenario_operator=None, + ): TermEfficient( + LiteralNode(1), # Sum is not distributed to coeff + "", + "y", + time_operator=None, + time_aggregator=TimeSum(stay_roll=False), + ), + }, + TimeAggregatorNode( + LiteralNode(1), "TimeSum", stay_roll=False + ), # TODO: Could it be simplified online ? + ), + ], +) +def test_operators_are_correctly_distributed_over_terms( + expr: LinearExpressionEfficient, + expec_terms: Dict[TermKeyEfficient, TermEfficient], + expec_constant: ExpressionNodeEfficient, +) -> None: + assert expr.terms == expec_terms + assert expressions_equal(expr.constant, expec_constant) + + class StructureProvider(IndexingStructureProvider): def get_component_variable_structure( self, component_id: str, name: str @@ -381,44 +532,52 @@ def test_eval_on_time_step_list_raises_value_error() -> None: _ = x.eval(ExpressionRange(1, 4)) -def test_shift_on_single_time_step() -> None: - x = var("x") - expr = x.shift(1) - - provider = StructureProvider() - assert expr.compute_indexation(provider) == IndexingStructure(True, True) - - -def test_shifting_sum() -> None: - x = var("x") - expr = x.sum(shift=ExpressionRange(1, 4)) - - provider = StructureProvider() - assert expr.compute_indexation(provider) == IndexingStructure(True, True) - - -def test_eval() -> None: - x = var("x") - expr = x.eval(1) - provider = StructureProvider() - - assert expr.compute_indexation(provider) == IndexingStructure(False, True) - - -def test_eval_sum() -> None: - x = var("x") - expr = x.eval(ExpressionRange(1, 4)).sum() - provider = StructureProvider() - - assert expr.compute_indexation(provider) == IndexingStructure(False, True) - - -def test_sum_over_whole_block() -> None: - x = var("x") - expr = x.sum() +@pytest.mark.parametrize( + "linear_expr, expected_indexation", + [ + ( + var("x").shift(1), + IndexingStructure(True, True), + ), + ( + var("x").sum(shift=ExpressionRange(1, 4)), + IndexingStructure(True, True), + ), + ( + var("x").eval(1), + IndexingStructure(False, True), + ), + ( + var("x").sum(eval=ExpressionRange(1, 4)), + IndexingStructure(False, True), + ), + ( + var("x").sum(), + IndexingStructure(False, True), + ), + ( + var("x").expec(), + IndexingStructure(True, False), + ), + ( + var("x").sum().expec(), + IndexingStructure(False, False), + ), + ( + var("x").shift(1).expec(), + IndexingStructure(True, False), + ), + ( + var("x").eval(1).expec(), + IndexingStructure(False, False), + ), + ], +) +def test_compute_indexation( + linear_expr: LinearExpressionEfficient, expected_indexation: IndexingStructure +) -> None: provider = StructureProvider() - - assert expr.compute_indexation(provider) == IndexingStructure(False, True) + assert linear_expr.compute_indexation(provider) == expected_indexation def test_forbidden_composition_should_raise_value_error() -> None: @@ -427,15 +586,6 @@ def test_forbidden_composition_should_raise_value_error() -> None: _ = x.shift(ExpressionRange(1, 4)) + var("y") -def test_expectation() -> None: - x = var("x") - expr = x.expec() - provider = StructureProvider() - - assert expr.compute_indexation(provider) == IndexingStructure(True, False) - assert expr.instances == Instances.SIMPLE - - def test_indexing_structure_comparison() -> None: free = IndexingStructure(True, True) constant = IndexingStructure(False, False)