Skip to content

Commit

Permalink
Remove concept of "instances" (#53)
Browse files Browse the repository at this point in the history
* start removing "instances"

The only place ranges are allowed is now in time sum,
not any more in the too large time shift operator.
For now we remove the possibility to have multiple
explicit indices in a sum, it's not actually used.

Signed-off-by: Sylvain Leclerc <sylvain.leclerc@rte-france.com>

* finalized removal of instances

Only printing remains to update

Signed-off-by: Sylvain Leclerc <sylvain.leclerc@rte-france.com>

* add detection of nested time operations

Signed-off-by: Sylvain Leclerc <sylvain.leclerc@rte-france.com>

* more tests and printing

Signed-off-by: Sylvain Leclerc <sylvain.leclerc@rte-france.com>

* Fix import

Signed-off-by: Sylvain Leclerc <sylvain.leclerc@rte-france.com>

* Remove duplicate tests

Signed-off-by: Sylvain Leclerc <sylvain.leclerc@rte-france.com>

---------

Signed-off-by: Sylvain Leclerc <sylvain.leclerc@rte-france.com>
  • Loading branch information
sylvlecl authored Sep 18, 2024
1 parent ce4f60f commit fb635a9
Show file tree
Hide file tree
Showing 43 changed files with 1,376 additions and 1,695 deletions.
8 changes: 4 additions & 4 deletions grammar/Expr.g4
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@ expr
| expr op=('/' | '*') expr # muldiv
| expr op=('+' | '-') expr # addsub
| expr COMPARISON expr # comparison
| 'sum' '(' expr ')' # allTimeSum
| 'sum' '(' from=shift '..' to=shift ',' expr ')' # timeSum
| IDENTIFIER '(' expr ')' # function
| IDENTIFIER '[' shift (',' shift)* ']' # timeShift
| IDENTIFIER '[' expr (',' expr )* ']' # timeIndex
| IDENTIFIER '[' shift1=shift '..' shift2=shift ']' # timeShiftRange
| IDENTIFIER '[' expr '..' expr ']' # timeRange
| IDENTIFIER '[' shift ']' # timeShift
| IDENTIFIER '[' expr ']' # timeIndex
;

atom
Expand Down
56 changes: 17 additions & 39 deletions src/andromede/expression/copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,30 +11,25 @@
# This file is part of the Antares project.

from dataclasses import dataclass
from typing import List, Union, cast
from typing import List, cast

from .expression import (
AdditionNode,
AllTimeSumNode,
ComparisonNode,
ComponentParameterNode,
ComponentVariableNode,
DivisionNode,
ExpressionNode,
ExpressionRange,
InstancesTimeIndex,
LiteralNode,
MultiplicationNode,
NegationNode,
ParameterNode,
PortFieldAggregatorNode,
PortFieldNode,
ScenarioOperatorNode,
SubstractionNode,
TimeAggregatorNode,
TimeOperatorNode,
TimeEvalNode,
TimeShiftNode,
TimeSumNode,
VariableNode,
)
from .visitor import ExpressionVisitor, ExpressionVisitorOperations, T, visit
from .visitor import ExpressionVisitorOperations, visit


@dataclass(frozen=True)
Expand Down Expand Up @@ -63,38 +58,21 @@ def comp_variable(self, node: ComponentVariableNode) -> ExpressionNode:
def comp_parameter(self, node: ComponentParameterNode) -> ExpressionNode:
return ComponentParameterNode(node.component_id, node.name)

def copy_expression_range(
self, expression_range: ExpressionRange
) -> ExpressionRange:
return ExpressionRange(
start=visit(expression_range.start, self),
stop=visit(expression_range.stop, self),
step=visit(expression_range.step, self)
if expression_range.step is not None
else None,
)
def time_shift(self, node: TimeShiftNode) -> ExpressionNode:
return TimeShiftNode(visit(node.operand, self), visit(node.time_shift, self))

def time_eval(self, node: TimeEvalNode) -> ExpressionNode:
return TimeShiftNode(visit(node.operand, self), visit(node.eval_time, self))

def copy_instances_index(
self, instances_index: InstancesTimeIndex
) -> InstancesTimeIndex:
expressions = instances_index.expressions
if isinstance(expressions, ExpressionRange):
return InstancesTimeIndex(self.copy_expression_range(expressions))
if isinstance(expressions, list):
expressions_list = cast(List[ExpressionNode], expressions)
copy = [visit(e, self) for e in expressions_list]
return InstancesTimeIndex(copy)
raise ValueError("Unexpected type in instances index")

def time_operator(self, node: TimeOperatorNode) -> ExpressionNode:
return TimeOperatorNode(
def time_sum(self, node: TimeSumNode) -> ExpressionNode:
return TimeSumNode(
visit(node.operand, self),
node.name,
self.copy_instances_index(node.instances_index),
visit(node.from_time, self),
visit(node.to_time, self),
)

def time_aggregator(self, node: TimeAggregatorNode) -> ExpressionNode:
return TimeAggregatorNode(visit(node.operand, self), node.name, node.stay_roll)
def all_time_sum(self, node: AllTimeSumNode) -> ExpressionNode:
return AllTimeSumNode(visit(node.operand, self))

def scenario_operator(self, node: ScenarioOperatorNode) -> ExpressionNode:
return ScenarioOperatorNode(visit(node.operand, self), node.name)
Expand Down
28 changes: 15 additions & 13 deletions src/andromede/expression/degree.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,14 @@

import andromede.expression.scenario_operator
from andromede.expression.expression import (
AllTimeSumNode,
ComponentParameterNode,
ComponentVariableNode,
PortFieldAggregatorNode,
PortFieldNode,
TimeOperatorNode,
TimeEvalNode,
TimeShiftNode,
TimeSumNode,
)

from .expression import (
Expand All @@ -30,7 +33,6 @@
ParameterNode,
ScenarioOperatorNode,
SubstractionNode,
TimeAggregatorNode,
VariableNode,
)
from .visitor import ExpressionVisitor, T, visit
Expand Down Expand Up @@ -78,17 +80,17 @@ def comp_variable(self, node: ComponentVariableNode) -> int:
def comp_parameter(self, node: ComponentParameterNode) -> int:
return 0

def time_operator(self, node: TimeOperatorNode) -> int:
if node.name in ["TimeShift", "TimeEvaluation"]:
return visit(node.operand, self)
else:
return NotImplemented

def time_aggregator(self, node: TimeAggregatorNode) -> int:
if node.name in ["TimeSum"]:
return visit(node.operand, self)
else:
return NotImplemented
def time_shift(self, node: TimeShiftNode) -> int:
return visit(node.operand, self)

def time_eval(self, node: TimeEvalNode) -> int:
return visit(node.operand, self)

def time_sum(self, node: TimeSumNode) -> int:
return visit(node.operand, self)

def all_time_sum(self, node: AllTimeSumNode) -> int:
return visit(node.operand, self)

def scenario_operator(self, node: ScenarioOperatorNode) -> int:
scenario_operator_cls = getattr(
Expand Down
64 changes: 25 additions & 39 deletions src/andromede/expression/equality.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,14 @@
VariableNode,
)
from andromede.expression.expression import (
AllTimeSumNode,
BinaryOperatorNode,
ExpressionRange,
InstancesTimeIndex,
PortFieldAggregatorNode,
PortFieldNode,
ScenarioOperatorNode,
TimeAggregatorNode,
TimeOperatorNode,
TimeEvalNode,
TimeShiftNode,
TimeSumNode,
)


Expand Down Expand Up @@ -76,12 +76,14 @@ def visit(self, left: ExpressionNode, right: ExpressionNode) -> bool:
return self.variable(left, right)
if isinstance(left, ParameterNode) and isinstance(right, ParameterNode):
return self.parameter(left, right)
if isinstance(left, TimeOperatorNode) and isinstance(right, TimeOperatorNode):
return self.time_operator(left, right)
if isinstance(left, TimeAggregatorNode) and isinstance(
right, TimeAggregatorNode
):
return self.time_aggregator(left, right)
if isinstance(left, TimeShiftNode) and isinstance(right, TimeShiftNode):
return self.time_shift(left, right)
if isinstance(left, TimeEvalNode) and isinstance(right, TimeEvalNode):
return self.time_eval(left, right)
if isinstance(left, TimeSumNode) and isinstance(right, TimeSumNode):
return self.time_sum(left, right)
if isinstance(left, AllTimeSumNode) and isinstance(right, AllTimeSumNode):
return self.all_time_sum(left, right)
if isinstance(left, ScenarioOperatorNode) and isinstance(
right, ScenarioOperatorNode
):
Expand Down Expand Up @@ -130,42 +132,26 @@ def variable(self, left: VariableNode, right: VariableNode) -> bool:
def parameter(self, left: ParameterNode, right: ParameterNode) -> bool:
return left.name == right.name

def expression_range(self, left: ExpressionRange, right: ExpressionRange) -> bool:
if not self.visit(left.start, right.start):
return False
if not self.visit(left.stop, right.stop):
return False
if left.step is not None and right.step is not None:
return self.visit(left.step, right.step)
return left.step is None and right.step is None

def instances_index(self, lhs: InstancesTimeIndex, rhs: InstancesTimeIndex) -> bool:
if isinstance(lhs.expressions, ExpressionRange) and isinstance(
rhs.expressions, ExpressionRange
):
return self.expression_range(lhs.expressions, rhs.expressions)
if isinstance(lhs.expressions, list) and isinstance(rhs.expressions, list):
return len(lhs.expressions) == len(rhs.expressions) and all(
self.visit(l, r) for l, r in zip(lhs.expressions, rhs.expressions)
)
return False
def time_shift(self, left: TimeShiftNode, right: TimeShiftNode) -> bool:
return self.visit(left.time_shift, right.time_shift) and self.visit(
left.operand, right.operand
)

def time_operator(self, left: TimeOperatorNode, right: TimeOperatorNode) -> bool:
return (
left.name == right.name
and self.instances_index(left.instances_index, right.instances_index)
and self.visit(left.operand, right.operand)
def time_eval(self, left: TimeEvalNode, right: TimeEvalNode) -> bool:
return self.visit(left.eval_time, right.eval_time) and self.visit(
left.operand, right.operand
)

def time_aggregator(
self, left: TimeAggregatorNode, right: TimeAggregatorNode
) -> bool:
def time_sum(self, left: TimeSumNode, right: TimeSumNode) -> bool:
return (
left.name == right.name
and left.stay_roll == right.stay_roll
self.visit(left.from_time, right.from_time)
and self.visit(left.to_time, right.to_time)
and self.visit(left.operand, right.operand)
)

def all_time_sum(self, left: AllTimeSumNode, right: AllTimeSumNode) -> bool:
return self.visit(left.operand, right.operand)

def scenario_operator(
self, left: ScenarioOperatorNode, right: ScenarioOperatorNode
) -> bool:
Expand Down
37 changes: 23 additions & 14 deletions src/andromede/expression/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,28 +15,25 @@
from typing import Dict

from andromede.expression.expression import (
AllTimeSumNode,
ComponentParameterNode,
ComponentVariableNode,
PortFieldAggregatorNode,
PortFieldNode,
TimeOperatorNode,
TimeEvalNode,
TimeShiftNode,
TimeSumNode,
)

from .expression import (
AdditionNode,
ComparisonNode,
DivisionNode,
ExpressionNode,
LiteralNode,
MultiplicationNode,
NegationNode,
ParameterNode,
ScenarioOperatorNode,
SubstractionNode,
TimeAggregatorNode,
VariableNode,
)
from .visitor import ExpressionVisitor, ExpressionVisitorOperations, T, visit
from .visitor import ExpressionVisitorOperations, visit


class ValueProvider(ABC):
Expand Down Expand Up @@ -120,10 +117,16 @@ def comp_parameter(self, node: ComponentParameterNode) -> float:
def comp_variable(self, node: ComponentVariableNode) -> float:
return self.context.get_component_variable_value(node.component_id, node.name)

def time_operator(self, node: TimeOperatorNode) -> float:
def time_shift(self, node: TimeShiftNode) -> float:
raise NotImplementedError()

def time_aggregator(self, node: TimeAggregatorNode) -> float:
def time_eval(self, node: TimeEvalNode) -> float:
raise NotImplementedError()

def time_sum(self, node: TimeSumNode) -> float:
raise NotImplementedError()

def all_time_sum(self, node: AllTimeSumNode) -> float:
raise NotImplementedError()

def scenario_operator(self, node: ScenarioOperatorNode) -> float:
Expand Down Expand Up @@ -156,8 +159,14 @@ def parameter(self, node: ParameterNode) -> float:
)
return self.context.get_parameter_value(node.name)

def time_operator(self, node: TimeOperatorNode) -> float:
raise ValueError("An instance index expression cannot contain time operator")
def time_shift(self, node: TimeShiftNode) -> float:
raise ValueError("An instance index expression cannot contain time shift")

def time_eval(self, node: TimeEvalNode) -> float:
raise ValueError("An instance index expression cannot contain time eval")

def time_sum(self, node: TimeSumNode) -> float:
raise ValueError("An instance index expression cannot contain time sum")

def time_aggregator(self, node: TimeAggregatorNode) -> float:
raise ValueError("An instance index expression cannot contain time aggregator")
def all_time_sum(self, node: AllTimeSumNode) -> float:
raise ValueError("An instance index expression cannot contain time sum")
38 changes: 18 additions & 20 deletions src/andromede/expression/evaluate_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@
from .expression import (
ComponentParameterNode,
ExpressionNode,
ExpressionRange,
InstancesTimeIndex,
LiteralNode,
ParameterNode,
)
Expand Down Expand Up @@ -79,21 +77,21 @@ def evaluate_time_id(expr: ExpressionNode, value_provider: ValueProvider) -> int
return time_id


def get_time_ids_from_instances_index(
instances_index: InstancesTimeIndex, value_provider: ValueProvider
) -> List[int]:
time_ids = []
if isinstance(instances_index.expressions, list): # List[ExpressionNode]
for expr in instances_index.expressions:
time_ids.append(evaluate_time_id(expr, value_provider))

elif isinstance(instances_index.expressions, ExpressionRange): # ExpressionRange
start_id = evaluate_time_id(instances_index.expressions.start, value_provider)
stop_id = evaluate_time_id(instances_index.expressions.stop, value_provider)
step_id = 1
if instances_index.expressions.step is not None:
step_id = evaluate_time_id(instances_index.expressions.step, value_provider)
# ExpressionRange includes stop_id whereas range excludes it
time_ids = list(range(start_id, stop_id + 1, step_id))

return time_ids
# def get_time_ids_from_instances_index(
# instances_index: InstancesTimeIndex, value_provider: ValueProvider
# ) -> List[int]:
# time_ids = []
# if isinstance(instances_index.expressions, list): # List[ExpressionNode]
# for expr in instances_index.expressions:
# time_ids.append(evaluate_time_id(expr, value_provider))
#
# elif isinstance(instances_index.expressions, ExpressionRange): # ExpressionRange
# start_id = evaluate_time_id(instances_index.expressions.start, value_provider)
# stop_id = evaluate_time_id(instances_index.expressions.stop, value_provider)
# step_id = 1
# if instances_index.expressions.step is not None:
# step_id = evaluate_time_id(instances_index.expressions.step, value_provider)
# # ExpressionRange includes stop_id whereas range excludes it
# time_ids = list(range(start_id, stop_id + 1, step_id))
#
# return time_ids
Loading

0 comments on commit fb635a9

Please sign in to comment.