Skip to content

Commit

Permalink
Support for complex parameters expressions in time operators (#54)
Browse files Browse the repository at this point in the history
We now have an "operators expansion" step, which transforms aggregation
operators into actual sums, easier to translate later in linear expressions.

Linear expressions now only contain simple terms related to one timestep and
one scenario (or none for constants).

---------

Signed-off-by: Sylvain Leclerc <sylvain.leclerc@rte-france.com>
Co-authored-by: Thomas Bittar <thomas.bittar@rte-france.com>
  • Loading branch information
sylvlecl and tbittar authored Sep 27, 2024
1 parent a447724 commit b1649dc
Show file tree
Hide file tree
Showing 23 changed files with 1,408 additions and 988 deletions.
12 changes: 12 additions & 0 deletions src/andromede/expression/context_adder.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
ComponentVariableNode,
ExpressionNode,
ParameterNode,
ProblemParameterNode,
ProblemVariableNode,
VariableNode,
)
from .visitor import visit
Expand Down Expand Up @@ -48,6 +50,16 @@ def comp_parameter(self, node: ComponentParameterNode) -> ExpressionNode:
"This expression has already been associated to another component."
)

def pb_variable(self, node: ProblemVariableNode) -> ExpressionNode:
raise ValueError(
"This expression has already been associated to another component."
)

def pb_parameter(self, node: ProblemParameterNode) -> ExpressionNode:
raise ValueError(
"This expression has already been associated to another component."
)


def add_component_context(id: str, expression: ExpressionNode) -> ExpressionNode:
return visit(expression, ContextAdder(id))
12 changes: 12 additions & 0 deletions src/andromede/expression/copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
ParameterNode,
PortFieldAggregatorNode,
PortFieldNode,
ProblemParameterNode,
ProblemVariableNode,
ScenarioOperatorNode,
TimeEvalNode,
TimeShiftNode,
Expand Down Expand Up @@ -58,6 +60,16 @@ def comp_variable(self, node: ComponentVariableNode) -> ExpressionNode:
def comp_parameter(self, node: ComponentParameterNode) -> ExpressionNode:
return ComponentParameterNode(node.component_id, node.name)

def pb_variable(self, node: ProblemVariableNode) -> ExpressionNode:
return ProblemVariableNode(
node.component_id, node.name, node.time_index, node.scenario_index
)

def pb_parameter(self, node: ProblemParameterNode) -> ExpressionNode:
return ProblemParameterNode(
node.component_id, node.name, node.time_index, node.scenario_index
)

def time_shift(self, node: TimeShiftNode) -> ExpressionNode:
return TimeShiftNode(visit(node.operand, self), visit(node.time_shift, self))

Expand Down
8 changes: 8 additions & 0 deletions src/andromede/expression/degree.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
ComponentVariableNode,
PortFieldAggregatorNode,
PortFieldNode,
ProblemParameterNode,
ProblemVariableNode,
TimeEvalNode,
TimeShiftNode,
TimeSumNode,
Expand Down Expand Up @@ -77,6 +79,12 @@ def comp_variable(self, node: ComponentVariableNode) -> int:
def comp_parameter(self, node: ComponentParameterNode) -> int:
return 0

def pb_variable(self, node: ProblemVariableNode) -> int:
return 1

def pb_parameter(self, node: ProblemParameterNode) -> int:
return 0

def time_shift(self, node: TimeShiftNode) -> int:
return visit(node.operand, self)

Expand Down
50 changes: 50 additions & 0 deletions src/andromede/expression/equality.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,12 @@
from andromede.expression.expression import (
AllTimeSumNode,
BinaryOperatorNode,
ComponentParameterNode,
ComponentVariableNode,
PortFieldAggregatorNode,
PortFieldNode,
ProblemParameterNode,
ProblemVariableNode,
ScenarioOperatorNode,
TimeEvalNode,
TimeShiftNode,
Expand Down Expand Up @@ -73,6 +77,22 @@ def visit(self, left: ExpressionNode, right: ExpressionNode) -> bool:
return self.variable(left, right)
if isinstance(left, ParameterNode) and isinstance(right, ParameterNode):
return self.parameter(left, right)
if isinstance(left, ComponentVariableNode) and isinstance(
right, ComponentVariableNode
):
return self.comp_variable(left, right)
if isinstance(left, ComponentParameterNode) and isinstance(
right, ComponentParameterNode
):
return self.comp_parameter(left, right)
if isinstance(left, ProblemVariableNode) and isinstance(
right, ProblemVariableNode
):
return self.problem_variable(left, right)
if isinstance(left, ProblemParameterNode) and isinstance(
right, ProblemParameterNode
):
return self.problem_parameter(left, right)
if isinstance(left, TimeShiftNode) and isinstance(right, TimeShiftNode):
return self.time_shift(left, right)
if isinstance(left, TimeEvalNode) and isinstance(right, TimeEvalNode):
Expand Down Expand Up @@ -130,6 +150,36 @@ def variable(self, left: VariableNode, right: VariableNode) -> bool:
def parameter(self, left: ParameterNode, right: ParameterNode) -> bool:
return left.name == right.name

def comp_variable(
self, left: ComponentVariableNode, right: ComponentVariableNode
) -> bool:
return left.name == right.name and left.component_id == right.component_id

def comp_parameter(
self, left: ComponentParameterNode, right: ComponentParameterNode
) -> bool:
return left.name == right.name and left.component_id == right.component_id

def problem_variable(
self, left: ProblemVariableNode, right: ProblemVariableNode
) -> bool:
return (
left.name == right.name
and left.component_id == right.component_id
and left.time_index == right.time_index
and left.scenario_index == right.scenario_index
)

def problem_parameter(
self, left: ProblemParameterNode, right: ProblemParameterNode
) -> bool:
return (
left.name == right.name
and left.component_id == right.component_id
and left.time_index == right.time_index
and left.scenario_index == right.scenario_index
)

def time_shift(self, left: TimeShiftNode, right: TimeShiftNode) -> bool:
return self.visit(left.time_shift, right.time_shift) and self.visit(
left.operand, right.operand
Expand Down
46 changes: 9 additions & 37 deletions src/andromede/expression/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
ComponentVariableNode,
PortFieldAggregatorNode,
PortFieldNode,
ProblemParameterNode,
ProblemVariableNode,
TimeEvalNode,
TimeShiftNode,
TimeSumNode,
Expand All @@ -33,6 +35,7 @@
ScenarioOperatorNode,
VariableNode,
)
from .indexing import IndexingStructureProvider
from .visitor import ExpressionVisitorOperations, visit


Expand All @@ -58,11 +61,6 @@ def get_component_variable_value(self, component_id: str, name: str) -> float:
def get_component_parameter_value(self, component_id: str, name: str) -> float:
...

# TODO: Should this really be an abstract method ? Or maybe, only the Provider in _make_value_provider should implement it. And the context attribute in the InstancesIndexVisitor is a ValueProvider that implements the parameter_is_constant_over_time method. Maybe create a child class of ValueProvider like TimeValueProvider ?
@abstractmethod
def parameter_is_constant_over_time(self, name: str) -> bool:
...


@dataclass(frozen=True)
class EvaluationContext(ValueProvider):
Expand All @@ -86,9 +84,6 @@ def get_component_variable_value(self, component_id: str, name: str) -> float:
def get_component_parameter_value(self, component_id: str, name: str) -> float:
raise NotImplementedError()

def parameter_is_constant_over_time(self, name: str) -> bool:
raise NotImplementedError()


@dataclass(frozen=True)
class EvaluationVisitor(ExpressionVisitorOperations[float]):
Expand Down Expand Up @@ -117,6 +112,12 @@ def comp_parameter(self, node: ComponentParameterNode) -> float:
def comp_variable(self, node: ComponentVariableNode) -> float:
return self.context.get_component_variable_value(node.component_id, node.name)

def pb_parameter(self, node: ProblemParameterNode) -> float:
raise ValueError("Should not reach here.")

def pb_variable(self, node: ProblemVariableNode) -> float:
raise ValueError("Should not reach here.")

def time_shift(self, node: TimeShiftNode) -> float:
raise NotImplementedError()

Expand All @@ -141,32 +142,3 @@ def port_field_aggregator(self, node: PortFieldAggregatorNode) -> float:

def evaluate(expression: ExpressionNode, value_provider: ValueProvider) -> float:
return visit(expression, EvaluationVisitor(value_provider))


@dataclass(frozen=True)
class InstancesIndexVisitor(EvaluationVisitor):
"""
Evaluates an expression given as instances index which should have no variable and constant parameter values.
"""

def variable(self, node: VariableNode) -> float:
raise ValueError("An instance index expression cannot contain variable")

def parameter(self, node: ParameterNode) -> float:
if not self.context.parameter_is_constant_over_time(node.name):
raise ValueError(
"Parameter given in an instance index expression must be constant over time"
)
return self.context.get_parameter_value(node.name)

def time_shift(self, node: TimeShiftNode) -> float:
raise ValueError("An instance index expression cannot contain time shift")

def time_eval(self, node: TimeEvalNode) -> float:
raise ValueError("An instance index expression cannot contain time eval")

def time_sum(self, node: TimeSumNode) -> float:
raise ValueError("An instance index expression cannot contain time sum")

def all_time_sum(self, node: AllTimeSumNode) -> float:
raise ValueError("An instance index expression cannot contain time sum")
39 changes: 1 addition & 38 deletions src/andromede/expression/evaluate_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,8 @@

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List

from andromede.expression.evaluate import InstancesIndexVisitor, ValueProvider
from andromede.expression.evaluate import ValueProvider

from .copy import CopyVisitor
from .expression import (
Expand Down Expand Up @@ -59,39 +58,3 @@ def resolve_parameters(
expression: ExpressionNode, parameter_provider: ParameterValueProvider
) -> ExpressionNode:
return visit(expression, ParameterResolver(parameter_provider))


def float_to_int(value: float) -> int:
if isinstance(value, int) or value.is_integer():
return int(value)
else:
raise ValueError(f"{value} is not an integer.")


def evaluate_time_id(expr: ExpressionNode, value_provider: ValueProvider) -> int:
float_time_id = visit(expr, InstancesIndexVisitor(value_provider))
try:
time_id = float_to_int(float_time_id)
except ValueError:
print(f"{expr} does not represent an integer time index.")
return time_id


# def get_time_ids_from_instances_index(
# instances_index: InstancesTimeIndex, value_provider: ValueProvider
# ) -> List[int]:
# time_ids = []
# if isinstance(instances_index.expressions, list): # List[ExpressionNode]
# for expr in instances_index.expressions:
# time_ids.append(evaluate_time_id(expr, value_provider))
#
# elif isinstance(instances_index.expressions, ExpressionRange): # ExpressionRange
# start_id = evaluate_time_id(instances_index.expressions.start, value_provider)
# stop_id = evaluate_time_id(instances_index.expressions.stop, value_provider)
# step_id = 1
# if instances_index.expressions.step is not None:
# step_id = evaluate_time_id(instances_index.expressions.step, value_provider)
# # ExpressionRange includes stop_id whereas range excludes it
# time_ids = list(range(start_id, stop_id + 1, step_id))
#
# return time_ids
Loading

0 comments on commit b1649dc

Please sign in to comment.