diff --git a/src/andromede/expression/expression_efficient.py b/src/andromede/expression/expression_efficient.py index 426e9cd1..c6828b15 100644 --- a/src/andromede/expression/expression_efficient.py +++ b/src/andromede/expression/expression_efficient.py @@ -396,7 +396,7 @@ def expression_range( ) -@dataclass +@dataclass(frozen=True) class InstancesTimeIndex: """ Defines a set of time indices on which a time operator operates. @@ -429,9 +429,9 @@ def __init__( ) if isinstance(expressions, (int, ExpressionNodeEfficient)): - self.expressions = [wrap_in_node(expressions)] + object.__setattr__(self, "expressions", [wrap_in_node(expressions)]) else: - self.expressions = expressions + object.__setattr__(self, "expressions", expressions) def is_simple(self) -> bool: if isinstance(self.expressions, list): diff --git a/src/andromede/expression/indexing.py b/src/andromede/expression/indexing.py index 11051dd5..102f4c45 100644 --- a/src/andromede/expression/indexing.py +++ b/src/andromede/expression/indexing.py @@ -89,20 +89,20 @@ def division(self, node: DivisionNode) -> IndexingStructure: def comparison(self, node: ComparisonNode) -> IndexingStructure: return visit(node.left, self) | visit(node.right, self) - def variable(self, node: VariableNode) -> IndexingStructure: - time = self.context.get_variable_structure(node.name).time == True - scenario = self.context.get_variable_structure(node.name).scenario == True - return IndexingStructure(time, scenario) + # def variable(self, node: VariableNode) -> IndexingStructure: + # time = self.context.get_variable_structure(node.name).time == True + # scenario = self.context.get_variable_structure(node.name).scenario == True + # return IndexingStructure(time, scenario) def parameter(self, node: ParameterNode) -> IndexingStructure: time = self.context.get_parameter_structure(node.name).time == True scenario = self.context.get_parameter_structure(node.name).scenario == True return IndexingStructure(time, scenario) - def comp_variable(self, node: ComponentVariableNode) -> IndexingStructure: - return self.context.get_component_variable_structure( - node.component_id, node.name - ) + # def comp_variable(self, node: ComponentVariableNode) -> IndexingStructure: + # return self.context.get_component_variable_structure( + # node.component_id, node.name + # ) def comp_parameter(self, node: ComponentParameterNode) -> IndexingStructure: return self.context.get_component_parameter_structure( diff --git a/src/andromede/expression/linear_expression_efficient.py b/src/andromede/expression/linear_expression_efficient.py index fcb0ee28..8f6f04ff 100644 --- a/src/andromede/expression/linear_expression_efficient.py +++ b/src/andromede/expression/linear_expression_efficient.py @@ -14,6 +14,7 @@ Specific modelling for "instantiated" linear expressions, with only variables and literal coefficients. """ +import dataclasses from dataclasses import dataclass, field from typing import Any, Callable, Dict, List, Optional, TypeVar, Union @@ -22,17 +23,26 @@ from andromede.expression.expression_efficient import ( ComponentParameterNode, ExpressionNodeEfficient, + ExpressionRange, + Instances, + InstancesTimeIndex, LiteralNode, ParameterNode, + TimeOperatorNode, is_minus_one, is_one, is_zero, wrap_in_node, ) +from andromede.expression.indexing import ( + IndexingStructureProvider, + TimeScenarioIndexingVisitor, + compute_indexation, +) from andromede.expression.indexing_structure import IndexingStructure from andromede.expression.print import print_expr from andromede.expression.scenario_operator import ScenarioOperator -from andromede.expression.time_operator import TimeAggregator, TimeOperator +from andromede.expression.time_operator import TimeAggregator, TimeOperator, TimeShift T = TypeVar("T") @@ -70,11 +80,23 @@ class TermEfficient: time_aggregator: Optional[TimeAggregator] = None scenario_operator: Optional[ScenarioOperator] = None - # TODO: It may be useful to define __add__, __sub__, etc on terms, which should return a linear expression ? + # TODO: Try to remove this + instances: Instances = field(init=False, default=Instances.SIMPLE) def __post_init__(self) -> None: object.__setattr__(self, "coefficient", wrap_in_node(self.coefficient)) + if self.time_operator is not None and self.time_aggregator is None: + + # TODO: Make a fuinction in time operator class + time_op_instances = Instances.SIMPLE if self.time_operator.time_ids.is_simple() else Instances.MULTIPLE + + if self.coefficient.instances != time_op_instances: + raise ValueError( + f"Cannot build term with coefficient {self.coefficient} and operator {self.time_operator} as they do not have the same number of instances." + ) + self.instances = self.coefficient.instances + def __eq__(self, other: "TermEfficient") -> bool: return ( isinstance(other, TermEfficient) @@ -131,6 +153,58 @@ def evaluate(self, context: ValueProvider) -> float: variable_value = context.get_variable_value(self.variable_name) return evaluate(self.coefficient, context) * variable_value + def compute_indexation( + self, provider: IndexingStructureProvider + ) -> IndexingStructure: + + # TODO: Improve this if/else structure + if self.component_id: + time = ( + provider.get_component_variable_structure(self.variable_name).time + == True + ) + scenario = ( + provider.get_component_variable_structure(self.variable_name).scenario + == True + ) + else: + time = provider.get_variable_structure(self.variable_name).time == True + scenario = ( + provider.get_variable_structure(self.variable_name).scenario == True + ) + return IndexingStructure(time, scenario) + + def shift( + self, + expressions: Union[ + int, + "ExpressionNodeEfficient", + List["ExpressionNodeEfficient"], + "ExpressionRange", + ], + ) -> "TermEfficient": + """ + Time shift of term + """ + # The behavior is richer/different than the previous implementation (with linear expr as trees) as we can now apply a shift operator on a whole expression, rather than just on the variables of an expression + + # Example : (param("p") * var("x")).shift(1) + # Previous behavior : p[t]x[t-1] + # New behavior : p[t-1]x[t-1] + + if self.time_operator is not None: + raise ValueError( + f"Composition of time operators {self.time_operator} and {TimeShift(InstancesTimeIndex(expressions))} is not allowed" + ) + + return dataclasses.replace( + self, + coefficient=TimeOperatorNode( + self.coefficient, "TimeShift", InstancesTimeIndex(expressions) + ), + time_operator=TimeShift(InstancesTimeIndex(expressions)), + ) + def generate_key(term: TermEfficient) -> TermKeyEfficient: return TermKeyEfficient( @@ -240,6 +314,8 @@ class LinearExpressionEfficient: terms: Dict[TermKeyEfficient, TermEfficient] constant: ExpressionNodeEfficient + # TODO: Probably not necessary, for now we replicate old implementation functioning + instances: Instances # TODO: We need to check that terms.key is indeed a TermKey and change the tests that this will break def __init__( @@ -272,6 +348,9 @@ def __init__( raise TypeError( f"Terms must be either of type Dict[str, Term] or List[Term], whereas {terms} is of type {type(terms)}" ) + + def _compute_instances(self): + def is_zero(self) -> bool: return len(self.terms) == 0 and is_zero(self.constant) @@ -484,6 +563,90 @@ def is_constant(self) -> bool: # Constant expr like x-x could be seen as non constant as we do not simplify coefficient tree... return not self.terms + def compute_indexation( + self, provider: IndexingStructureProvider + ) -> IndexingStructure: + + indexing = compute_indexation(self.constant, provider) + for term in self.terms.values(): + indexing = indexing | term.compute_indexation(provider) + + return indexing + + # def sum(self) -> "ExpressionNode": + # if isinstance(self, TimeOperatorNode): + # return TimeAggregatorNode(self, "TimeSum", stay_roll=True) + # else: + # return _apply_if_node( + # self, lambda x: TimeAggregatorNode(x, "TimeSum", stay_roll=False) + # ) + + # def sum_connections(self) -> "ExpressionNode": + # if isinstance(self, PortFieldNode): + # return PortFieldAggregatorNode(self, aggregator="PortSum") + # raise ValueError( + # f"sum_connections() applies only for PortFieldNode, whereas the current node is of type {type(self)}." + # ) + + def shift( + self, + expressions: Union[ + int, + "ExpressionNodeEfficient", + List["ExpressionNodeEfficient"], + "ExpressionRange", + ], + ) -> "LinearExpressionEfficient": + """ + Time shift of variables + + Examples: + >>> x.shift([1, 2, 4]) represents the vector of variables (x[t+1], x[t+2], x[t+4]) + + No variables allowed in shift argument, but parameter trees are ok + + It is assumed that the shift operator is linear and distributes to all terms and to the constant of the linear expression on which it is applied. + + Examples: + >>> (param("a") * var("x") + param("b")).shift([1, 2, 4]) represents the vector of variables (a[t+1]x[t+1] + b[t+1], a[t+2]x[t+2] + b[t+2], a[t+4]x[t+4] + b[t+4]) + """ + + # The behavior is richer/different than the previous implementation (with linear expr as trees) as we can now apply a shift operator on a whole expression, rather than just on the variables of an expression + + # Example : (param("p") * var("x")).shift(1) + # Previous behavior : p[t]x[t-1] + # New behavior : p[t-1]x[t-1] + + result_terms = {} + for term in self.terms.values(): + term_with_operator = term.shift(expressions) + result_terms[generate_key(term_with_operator)] = term_with_operator + + result_constant = TimeOperatorNode( + self.constant, "TimeShift", InstancesTimeIndex(expressions) + ) + result_expr = LinearExpressionEfficient(result_terms, result_constant) + return result_expr + + # def eval( + # self, + # expressions: Union[ + # int, "ExpressionNode", List["ExpressionNode"], "ExpressionRange" + # ], + # ) -> "ExpressionNode": + # return _apply_if_node( + # self, + # lambda x: TimeOperatorNode( + # x, "TimeEvaluation", InstancesTimeIndex(expressions) + # ), + # ) + + # def expec(self) -> "ExpressionNode": + # return _apply_if_node(self, lambda x: ScenarioOperatorNode(x, "Expectation")) + + # def variance(self) -> "ExpressionNode": + # return _apply_if_node(self, lambda x: ScenarioOperatorNode(x, "Variance")) + def linear_expressions_equal( lhs: LinearExpressionEfficient, rhs: LinearExpressionEfficient @@ -542,13 +705,13 @@ def _wrap_in_linear_expr(obj: Any) -> LinearExpressionEfficient: raise TypeError(f"Unable to wrap {obj} into a linear expression") -def _apply_if_node( - obj: Any, func: Callable[[LinearExpressionEfficient], LinearExpressionEfficient] -) -> LinearExpressionEfficient: - if as_linear_expr := _wrap_in_linear_expr(obj): - return func(as_linear_expr) - else: - return NotImplemented +# def _apply_if_node( +# obj: Any, func: Callable[[LinearExpressionEfficient], LinearExpressionEfficient] +# ) -> LinearExpressionEfficient: +# if as_linear_expr := _wrap_in_linear_expr(obj): +# return func(as_linear_expr) +# else: +# return NotImplemented def _copy_expression( diff --git a/src/andromede/expression/time_operator.py b/src/andromede/expression/time_operator.py index 63059528..4d4fc676 100644 --- a/src/andromede/expression/time_operator.py +++ b/src/andromede/expression/time_operator.py @@ -18,6 +18,8 @@ from dataclasses import dataclass from typing import Any, List, Tuple +from andromede.expression.expression_efficient import InstancesTimeIndex + @dataclass(frozen=True) class TimeOperator(ABC): @@ -27,21 +29,21 @@ class TimeOperator(ABC): - is_rolling: bool, if true, this means that the time_ids are to be understood relatively to the current timestep of the context AND that the represented expression will have to be instanciated for all timesteps. Otherwise, the time_ids are "absolute" times and the expression only has to be instantiated once. """ - time_ids: List[int] + time_ids: InstancesTimeIndex @classmethod @abstractmethod def rolling(cls) -> bool: raise NotImplementedError - def __post_init__(self) -> None: - if isinstance(self.time_ids, int): - object.__setattr__(self, "time_ids", [self.time_ids]) - elif isinstance(self.time_ids, range): - object.__setattr__(self, "time_ids", list(self.time_ids)) + # def __post_init__(self) -> None: + # if isinstance(self.time_ids, int): + # object.__setattr__(self, "time_ids", [self.time_ids]) + # elif isinstance(self.time_ids, range): + # object.__setattr__(self, "time_ids", list(self.time_ids)) def key(self) -> Tuple[int, ...]: - return tuple(self.time_ids) + return self.time_ids def size(self) -> int: return len(self.time_ids)