Skip to content

Commit

Permalink
Implement shift, eval and time sum of linear expressions
Browse files Browse the repository at this point in the history
  • Loading branch information
tbittar committed Jul 15, 2024
1 parent 2860242 commit b640b22
Show file tree
Hide file tree
Showing 3 changed files with 325 additions and 121 deletions.
76 changes: 40 additions & 36 deletions src/andromede/expression/expression_efficient.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@
EPS = 10 ** (-16)


class Instances(enum.Enum):
SIMPLE = "SIMPLE"
MULTIPLE = "MULTIPLE"
# class Instances(enum.Enum):
# SIMPLE = "SIMPLE"
# MULTIPLE = "MULTIPLE"


@dataclass(frozen=True)
Expand All @@ -43,7 +43,7 @@ class ExpressionNodeEfficient:
>>> expr = -var('x') + 5 / param('p')
"""

instances: Instances = field(init=False, default=Instances.SIMPLE)
# instances: Instances = field(init=False, default=Instances.SIMPLE)

def __neg__(self) -> "ExpressionNodeEfficient":
return _negate_node(self)
Expand Down Expand Up @@ -286,8 +286,8 @@ class LiteralNode(ExpressionNodeEfficient):
class UnaryOperatorNode(ExpressionNodeEfficient):
operand: ExpressionNodeEfficient

def __post_init__(self) -> None:
object.__setattr__(self, "instances", self.operand.instances)
# def __post_init__(self) -> None:
# object.__setattr__(self, "instances", self.operand.instances)


@dataclass(frozen=True, eq=False)
Expand Down Expand Up @@ -318,17 +318,17 @@ class BinaryOperatorNode(ExpressionNodeEfficient):
left: ExpressionNodeEfficient
right: ExpressionNodeEfficient

def __post_init__(self) -> None:
binary_operator_post_init(self, "apply binary operation with")
# def __post_init__(self) -> None:
# binary_operator_post_init(self, "apply binary operation with")


def binary_operator_post_init(node: BinaryOperatorNode, operation: str) -> None:
if node.left.instances != node.right.instances:
raise ValueError(
f"Cannot {operation} {node.left} and {node.right} as they do not have the same number of instances."
)
else:
object.__setattr__(node, "instances", node.left.instances)
# def binary_operator_post_init(node: BinaryOperatorNode, operation: str) -> None:
# if node.left.instances != node.right.instances:
# raise ValueError(
# f"Cannot {operation} {node.left} and {node.right} as they do not have the same number of instances."
# )
# else:
# object.__setattr__(node, "instances", node.left.instances)


class Comparator(enum.Enum):
Expand All @@ -341,32 +341,36 @@ class Comparator(enum.Enum):
class ComparisonNode(BinaryOperatorNode):
comparator: Comparator

def __post_init__(self) -> None:
binary_operator_post_init(self, "compare")
# def __post_init__(self) -> None:
# binary_operator_post_init(self, "compare")


@dataclass(frozen=True, eq=False)
class AdditionNode(BinaryOperatorNode):
def __post_init__(self) -> None:
binary_operator_post_init(self, "add")
pass
# def __post_init__(self) -> None:
# binary_operator_post_init(self, "add")


@dataclass(frozen=True, eq=False)
class SubstractionNode(BinaryOperatorNode):
def __post_init__(self) -> None:
binary_operator_post_init(self, "substract")
pass
# def __post_init__(self) -> None:
# binary_operator_post_init(self, "substract")


@dataclass(frozen=True, eq=False)
class MultiplicationNode(BinaryOperatorNode):
def __post_init__(self) -> None:
binary_operator_post_init(self, "multiply")
pass
# def __post_init__(self) -> None:
# binary_operator_post_init(self, "multiply")


@dataclass(frozen=True, eq=False)
class DivisionNode(BinaryOperatorNode):
def __post_init__(self) -> None:
binary_operator_post_init(self, "divide")
pass
# def __post_init__(self) -> None:
# binary_operator_post_init(self, "divide")


@dataclass(frozen=True, eq=False)
Expand Down Expand Up @@ -465,15 +469,15 @@ def __post_init__(self) -> None:
raise ValueError(
f"{self.name} is not a valid time aggregator, valid time aggregators are {valid_names}"
)
if self.operand.instances == Instances.SIMPLE:
if self.instances_index.is_simple():
object.__setattr__(self, "instances", Instances.SIMPLE)
else:
object.__setattr__(self, "instances", Instances.MULTIPLE)
else:
raise ValueError(
"Cannot apply time operator on an expression that already represents multiple instances"
)
# if self.operand.instances == Instances.SIMPLE:
# if self.instances_index.is_simple():
# object.__setattr__(self, "instances", Instances.SIMPLE)
# else:
# object.__setattr__(self, "instances", Instances.MULTIPLE)
# else:
# raise ValueError(
# "Cannot apply time operator on an expression that already represents multiple instances"
# )


@dataclass(frozen=True, eq=False)
Expand All @@ -493,7 +497,7 @@ def __post_init__(self) -> None:
raise ValueError(
f"{self.name} is not a valid time aggregator, valid time aggregators are {valid_names}"
)
object.__setattr__(self, "instances", Instances.SIMPLE)
# object.__setattr__(self, "instances", Instances.SIMPLE)


@dataclass(frozen=True, eq=False)
Expand All @@ -512,7 +516,7 @@ def __post_init__(self) -> None:
raise ValueError(
f"{self.name} is not a valid scenario operator, valid scenario operators are {valid_names}"
)
object.__setattr__(self, "instances", Instances.SIMPLE)
# object.__setattr__(self, "instances", Instances.SIMPLE)


def sum_expressions(
Expand Down
128 changes: 89 additions & 39 deletions src/andromede/expression/linear_expression_efficient.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
InstancesTimeIndex,
LiteralNode,
ParameterNode,
ScenarioOperatorNode,
TimeAggregatorNode,
TimeOperatorNode,
is_minus_one,
Expand All @@ -37,7 +38,7 @@
from andromede.expression.indexing import IndexingStructureProvider, compute_indexation
from andromede.expression.indexing_structure import IndexingStructure
from andromede.expression.print import print_expr
from andromede.expression.scenario_operator import ScenarioOperator
from andromede.expression.scenario_operator import Expectation, ScenarioOperator
from andromede.expression.time_operator import (
TimeAggregator,
TimeEvaluation,
Expand Down Expand Up @@ -135,22 +136,39 @@ def evaluate(self, context: ValueProvider) -> float:
def compute_indexation(
self, provider: IndexingStructureProvider
) -> IndexingStructure:
# TODO: Improve this if/else structure
if self.component_id:
time = (
provider.get_component_variable_structure(self.variable_name).time
== True
)
scenario = (
provider.get_component_variable_structure(self.variable_name).scenario
== True
)

return IndexingStructure(
self._compute_time_indexing(provider),
self._compute_scenario_indexing(provider),
)

def _compute_time_indexing(self, provider: IndexingStructureProvider) -> bool:
if (self.time_aggregator and not self.time_aggregator.stay_roll) or (
self.time_operator and not self.time_operator.rolling()
):
time = False
else:
time = provider.get_variable_structure(self.variable_name).time == True
scenario = (
provider.get_variable_structure(self.variable_name).scenario == True
)
return IndexingStructure(time, scenario)
if self.component_id:
time = provider.get_component_variable_structure(
self.component_id, self.variable_name
).time
else:
time = provider.get_variable_structure(self.variable_name).time
return time

def _compute_scenario_indexing(self, provider: IndexingStructureProvider) -> bool:
if self.scenario_operator:
scenario = False
else:
# TODO: Improve this if/else structure, probably simplify IndexingStructureProvider
if self.component_id:
scenario = provider.get_component_variable_structure(
self.component_id, self.variable_name
).scenario

else:
scenario = provider.get_variable_structure(self.variable_name).scenario
return scenario

def sum(
self,
Expand Down Expand Up @@ -204,7 +222,7 @@ def shift(
List["ExpressionNodeEfficient"],
"ExpressionRange",
],
) -> "LinearExpressionEfficient":
) -> "TermEfficient":
"""
Shorthand for shift on a single time step
Expand Down Expand Up @@ -236,7 +254,7 @@ def eval(
List["ExpressionNodeEfficient"],
"ExpressionRange",
],
) -> "LinearExpressionEfficient":
) -> "TermEfficient":
"""
Shorthand for eval on a single time step
Expand All @@ -260,6 +278,10 @@ def eval(
else:
return self.sum(eval=expressions)

def expec(self) -> "TermEfficient":
# TODO: Do we need checks, in case a scenario operator is already specified ?
return dataclasses.replace(self, scenario_operator=Expectation())


def generate_key(term: TermEfficient) -> TermKeyEfficient:
return TermKeyEfficient(
Expand Down Expand Up @@ -595,7 +617,12 @@ def is_constant(self) -> bool:
def compute_indexation(
self, provider: IndexingStructureProvider
) -> IndexingStructure:
indexing = compute_indexation(self.constant, provider)
"""
Computes the (time, scenario) indexing of a linear expression.
Time and scenario indexation is driven by the indexation of variables in the expression. If a single term is indexed by time (resp. scenario), then the linear expression is indexed by time (resp. scenario).
"""
indexing = IndexingStructure(False, False)
for term in self.terms.values():
indexing = indexing | term.compute_indexation(provider)

Expand Down Expand Up @@ -635,15 +662,40 @@ def sum(

if shift is not None:
sum_args = {"shift": shift}
stay_roll = True

result_constant = TimeAggregatorNode(
TimeOperatorNode(
self.constant,
"TimeShift",
InstancesTimeIndex(shift),
),
"TimeSum",
stay_roll=True,
)
elif eval is not None:
sum_args = {"eval": eval}
stay_roll = True

result_constant = TimeAggregatorNode(
TimeOperatorNode(
self.constant,
"TimeEvaluation",
InstancesTimeIndex(eval),
),
"TimeSum",
stay_roll=True,
)
else: # x.sum() -> Sum over all time block
sum_args = {}
stay_roll = False

return self._apply_operator(sum_args, stay_roll)
result_constant = TimeAggregatorNode(
self.constant,
"TimeSum",
stay_roll=False,
)

return LinearExpressionEfficient(
self._apply_operator(sum_args), result_constant
)

def _apply_operator(
self,
Expand All @@ -657,26 +709,13 @@ def _apply_operator(
None,
],
],
stay_roll: bool,
):
result_terms = {}
for term in self.terms.values():
term_with_operator = term.sum(**sum_args)
result_terms[generate_key(term_with_operator)] = term_with_operator

result_constant = TimeAggregatorNode(
TimeOperatorNode(
self.constant,
"TimeShift",
InstancesTimeIndex(
sum_args.popitem()[1]
), # Dangerous as it modifies sum_args ?
),
"TimeSum",
stay_roll=stay_roll,
)
result_expr = LinearExpressionEfficient(result_terms, result_constant)
return result_expr
return result_terms

# def sum_connections(self) -> "ExpressionNode":
# if isinstance(self, PortFieldNode):
Expand Down Expand Up @@ -737,8 +776,19 @@ def eval(
else:
return self.sum(eval=expressions)

# def expec(self) -> "ExpressionNode":
# return _apply_if_node(self, lambda x: ScenarioOperatorNode(x, "Expectation"))
def expec(self) -> "LinearExpressionEfficient":
"""
Expectation of linear expression. As the operator is linear, it distributes over all terms and the constant
"""

result_terms = {}
for term in self.terms.values():
term_with_operator = term.expec()
result_terms[generate_key(term_with_operator)] = term_with_operator

result_constant = ScenarioOperatorNode(self.constant, "Expectation")
result_expr = LinearExpressionEfficient(result_terms, result_constant)
return result_expr

# def variance(self) -> "ExpressionNode":
# return _apply_if_node(self, lambda x: ScenarioOperatorNode(x, "Variance"))
Expand Down
Loading

0 comments on commit b640b22

Please sign in to comment.