Skip to content

Commit

Permalink
optimize indexing visitor
Browse files Browse the repository at this point in the history
Signed-off-by: Sylvain Leclerc <sylvain.leclerc@rte-france.com>
  • Loading branch information
sylvlecl committed Sep 16, 2024
1 parent 3e57ae9 commit 231d0e4
Showing 1 changed file with 26 additions and 12 deletions.
38 changes: 26 additions & 12 deletions src/andromede/expression/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List

from andromede.expression.indexing_structure import IndexingStructure

Expand All @@ -29,11 +30,13 @@
ParameterNode,
PortFieldAggregatorNode,
PortFieldNode,
ProblemParameterNode,
ProblemVariableNode,
ScenarioOperatorNode,
TimeEvalNode,
TimeShiftNode,
TimeSumNode,
VariableNode, ProblemVariableNode, ProblemParameterNode,
VariableNode,
)
from .visitor import ExpressionVisitor, T, visit

Expand Down Expand Up @@ -74,21 +77,32 @@ def literal(self, node: LiteralNode) -> IndexingStructure:
def negation(self, node: NegationNode) -> IndexingStructure:
return visit(node.operand, self)

def addition(self, node: AdditionNode) -> IndexingStructure:
operands = [visit(o, self) for o in node.operands]
res = operands[0]
for o in node.operands[1:]:
def _combine(self, operands: List[ExpressionNode]) -> IndexingStructure:
if not operands:
return IndexingStructure(False, False)
res = visit(operands[0], self)
if res.is_time_scenario_varying():
return res
for o in operands[1:]:
res = res | visit(o, self)
if res.is_time_scenario_varying():
return res
return res

def addition(self, node: AdditionNode) -> IndexingStructure:
# performance note:
# here we don't need to visit all nodes, we can stop as soon as
# index is true/true
return self._combine(node.operands)

def multiplication(self, node: MultiplicationNode) -> IndexingStructure:
return visit(node.left, self) | visit(node.right, self)
return self._combine([node.left, node.right])

def division(self, node: DivisionNode) -> IndexingStructure:
return visit(node.left, self) | visit(node.right, self)
return self._combine([node.left, node.right])

def comparison(self, node: ComparisonNode) -> IndexingStructure:
return visit(node.left, self) | visit(node.right, self)
return self._combine([node.left, node.right])

def variable(self, node: VariableNode) -> IndexingStructure:
time = self.context.get_variable_structure(node.name).time == True
Expand All @@ -111,13 +125,13 @@ def comp_parameter(self, node: ComponentParameterNode) -> IndexingStructure:
)

def pb_variable(self, node: ProblemVariableNode) -> IndexingStructure:
return self.context.get_component_variable_structure(
node.component_id, node.name
raise ValueError(
"Not relevant to compute indexation on already instantiated problem variables."
)

def pb_parameter(self, node: ProblemParameterNode) -> IndexingStructure:
return self.context.get_component_parameter_structure(
node.component_id, node.name
raise ValueError(
"Not relevant to compute indexation on already instantiated problem parameters."
)

def time_shift(self, node: TimeShiftNode) -> IndexingStructure:
Expand Down

0 comments on commit 231d0e4

Please sign in to comment.