diff --git a/src/andromede/expression/indexing.py b/src/andromede/expression/indexing.py index c4cfda7..4bc8471 100644 --- a/src/andromede/expression/indexing.py +++ b/src/andromede/expression/indexing.py @@ -12,6 +12,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass +from typing import List from andromede.expression.indexing_structure import IndexingStructure @@ -29,11 +30,13 @@ ParameterNode, PortFieldAggregatorNode, PortFieldNode, + ProblemParameterNode, + ProblemVariableNode, ScenarioOperatorNode, TimeEvalNode, TimeShiftNode, TimeSumNode, - VariableNode, ProblemVariableNode, ProblemParameterNode, + VariableNode, ) from .visitor import ExpressionVisitor, T, visit @@ -74,21 +77,32 @@ def literal(self, node: LiteralNode) -> IndexingStructure: def negation(self, node: NegationNode) -> IndexingStructure: return visit(node.operand, self) - def addition(self, node: AdditionNode) -> IndexingStructure: - operands = [visit(o, self) for o in node.operands] - res = operands[0] - for o in node.operands[1:]: + def _combine(self, operands: List[ExpressionNode]) -> IndexingStructure: + if not operands: + return IndexingStructure(False, False) + res = visit(operands[0], self) + if res.is_time_scenario_varying(): + return res + for o in operands[1:]: res = res | visit(o, self) + if res.is_time_scenario_varying(): + return res return res + def addition(self, node: AdditionNode) -> IndexingStructure: + # performance note: + # here we don't need to visit all nodes, we can stop as soon as + # index is true/true + return self._combine(node.operands) + def multiplication(self, node: MultiplicationNode) -> IndexingStructure: - return visit(node.left, self) | visit(node.right, self) + return self._combine([node.left, node.right]) def division(self, node: DivisionNode) -> IndexingStructure: - return visit(node.left, self) | visit(node.right, self) + return self._combine([node.left, node.right]) def comparison(self, node: ComparisonNode) -> IndexingStructure: - return visit(node.left, self) | visit(node.right, self) + return self._combine([node.left, node.right]) def variable(self, node: VariableNode) -> IndexingStructure: time = self.context.get_variable_structure(node.name).time == True @@ -111,13 +125,13 @@ def comp_parameter(self, node: ComponentParameterNode) -> IndexingStructure: ) def pb_variable(self, node: ProblemVariableNode) -> IndexingStructure: - return self.context.get_component_variable_structure( - node.component_id, node.name + raise ValueError( + "Not relevant to compute indexation on already instantiated problem variables." ) def pb_parameter(self, node: ProblemParameterNode) -> IndexingStructure: - return self.context.get_component_parameter_structure( - node.component_id, node.name + raise ValueError( + "Not relevant to compute indexation on already instantiated problem parameters." ) def time_shift(self, node: TimeShiftNode) -> IndexingStructure: