Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
tbittar committed Jun 26, 2024
1 parent 946c366 commit e1ebb29
Show file tree
Hide file tree
Showing 4 changed files with 192 additions and 27 deletions.
6 changes: 3 additions & 3 deletions src/andromede/expression/expression_efficient.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,7 @@ def expression_range(
)


@dataclass
@dataclass(frozen=True)
class InstancesTimeIndex:
"""
Defines a set of time indices on which a time operator operates.
Expand Down Expand Up @@ -429,9 +429,9 @@ def __init__(
)

if isinstance(expressions, (int, ExpressionNodeEfficient)):
self.expressions = [wrap_in_node(expressions)]
object.__setattr__(self, "expressions", [wrap_in_node(expressions)])
else:
self.expressions = expressions
object.__setattr__(self, "expressions", expressions)

def is_simple(self) -> bool:
if isinstance(self.expressions, list):
Expand Down
16 changes: 8 additions & 8 deletions src/andromede/expression/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,20 +89,20 @@ def division(self, node: DivisionNode) -> IndexingStructure:
def comparison(self, node: ComparisonNode) -> IndexingStructure:
return visit(node.left, self) | visit(node.right, self)

def variable(self, node: VariableNode) -> IndexingStructure:
time = self.context.get_variable_structure(node.name).time == True
scenario = self.context.get_variable_structure(node.name).scenario == True
return IndexingStructure(time, scenario)
# def variable(self, node: VariableNode) -> IndexingStructure:
# time = self.context.get_variable_structure(node.name).time == True
# scenario = self.context.get_variable_structure(node.name).scenario == True
# return IndexingStructure(time, scenario)

def parameter(self, node: ParameterNode) -> IndexingStructure:
time = self.context.get_parameter_structure(node.name).time == True
scenario = self.context.get_parameter_structure(node.name).scenario == True
return IndexingStructure(time, scenario)

def comp_variable(self, node: ComponentVariableNode) -> IndexingStructure:
return self.context.get_component_variable_structure(
node.component_id, node.name
)
# def comp_variable(self, node: ComponentVariableNode) -> IndexingStructure:
# return self.context.get_component_variable_structure(
# node.component_id, node.name
# )

def comp_parameter(self, node: ComponentParameterNode) -> IndexingStructure:
return self.context.get_component_parameter_structure(
Expand Down
181 changes: 172 additions & 9 deletions src/andromede/expression/linear_expression_efficient.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
Specific modelling for "instantiated" linear expressions,
with only variables and literal coefficients.
"""
import dataclasses
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, TypeVar, Union

Expand All @@ -22,17 +23,26 @@
from andromede.expression.expression_efficient import (
ComponentParameterNode,
ExpressionNodeEfficient,
ExpressionRange,
Instances,
InstancesTimeIndex,
LiteralNode,
ParameterNode,
TimeOperatorNode,
is_minus_one,
is_one,
is_zero,
wrap_in_node,
)
from andromede.expression.indexing import (
IndexingStructureProvider,
TimeScenarioIndexingVisitor,
compute_indexation,
)
from andromede.expression.indexing_structure import IndexingStructure
from andromede.expression.print import print_expr
from andromede.expression.scenario_operator import ScenarioOperator
from andromede.expression.time_operator import TimeAggregator, TimeOperator
from andromede.expression.time_operator import TimeAggregator, TimeOperator, TimeShift

T = TypeVar("T")

Expand Down Expand Up @@ -70,11 +80,23 @@ class TermEfficient:
time_aggregator: Optional[TimeAggregator] = None
scenario_operator: Optional[ScenarioOperator] = None

# TODO: It may be useful to define __add__, __sub__, etc on terms, which should return a linear expression ?
# TODO: Try to remove this
instances: Instances = field(init=False, default=Instances.SIMPLE)

def __post_init__(self) -> None:
object.__setattr__(self, "coefficient", wrap_in_node(self.coefficient))

if self.time_operator is not None and self.time_aggregator is None:

# TODO: Make a fuinction in time operator class
time_op_instances = Instances.SIMPLE if self.time_operator.time_ids.is_simple() else Instances.MULTIPLE

if self.coefficient.instances != time_op_instances:
raise ValueError(
f"Cannot build term with coefficient {self.coefficient} and operator {self.time_operator} as they do not have the same number of instances."
)
self.instances = self.coefficient.instances

def __eq__(self, other: "TermEfficient") -> bool:
return (
isinstance(other, TermEfficient)
Expand Down Expand Up @@ -131,6 +153,58 @@ def evaluate(self, context: ValueProvider) -> float:
variable_value = context.get_variable_value(self.variable_name)
return evaluate(self.coefficient, context) * variable_value

def compute_indexation(
self, provider: IndexingStructureProvider
) -> IndexingStructure:

# TODO: Improve this if/else structure
if self.component_id:
time = (
provider.get_component_variable_structure(self.variable_name).time
== True
)
scenario = (
provider.get_component_variable_structure(self.variable_name).scenario
== True
)
else:
time = provider.get_variable_structure(self.variable_name).time == True
scenario = (
provider.get_variable_structure(self.variable_name).scenario == True
)
return IndexingStructure(time, scenario)

def shift(
self,
expressions: Union[
int,
"ExpressionNodeEfficient",
List["ExpressionNodeEfficient"],
"ExpressionRange",
],
) -> "TermEfficient":
"""
Time shift of term
"""
# The behavior is richer/different than the previous implementation (with linear expr as trees) as we can now apply a shift operator on a whole expression, rather than just on the variables of an expression

# Example : (param("p") * var("x")).shift(1)
# Previous behavior : p[t]x[t-1]
# New behavior : p[t-1]x[t-1]

if self.time_operator is not None:
raise ValueError(
f"Composition of time operators {self.time_operator} and {TimeShift(InstancesTimeIndex(expressions))} is not allowed"
)

return dataclasses.replace(
self,
coefficient=TimeOperatorNode(
self.coefficient, "TimeShift", InstancesTimeIndex(expressions)
),
time_operator=TimeShift(InstancesTimeIndex(expressions)),
)


def generate_key(term: TermEfficient) -> TermKeyEfficient:
return TermKeyEfficient(
Expand Down Expand Up @@ -240,6 +314,8 @@ class LinearExpressionEfficient:

terms: Dict[TermKeyEfficient, TermEfficient]
constant: ExpressionNodeEfficient
# TODO: Probably not necessary, for now we replicate old implementation functioning
instances: Instances

# TODO: We need to check that terms.key is indeed a TermKey and change the tests that this will break
def __init__(
Expand Down Expand Up @@ -272,6 +348,9 @@ def __init__(
raise TypeError(
f"Terms must be either of type Dict[str, Term] or List[Term], whereas {terms} is of type {type(terms)}"
)

def _compute_instances(self):


def is_zero(self) -> bool:
return len(self.terms) == 0 and is_zero(self.constant)
Expand Down Expand Up @@ -484,6 +563,90 @@ def is_constant(self) -> bool:
# Constant expr like x-x could be seen as non constant as we do not simplify coefficient tree...
return not self.terms

def compute_indexation(
self, provider: IndexingStructureProvider
) -> IndexingStructure:

indexing = compute_indexation(self.constant, provider)
for term in self.terms.values():
indexing = indexing | term.compute_indexation(provider)

return indexing

# def sum(self) -> "ExpressionNode":
# if isinstance(self, TimeOperatorNode):
# return TimeAggregatorNode(self, "TimeSum", stay_roll=True)
# else:
# return _apply_if_node(
# self, lambda x: TimeAggregatorNode(x, "TimeSum", stay_roll=False)
# )

# def sum_connections(self) -> "ExpressionNode":
# if isinstance(self, PortFieldNode):
# return PortFieldAggregatorNode(self, aggregator="PortSum")
# raise ValueError(
# f"sum_connections() applies only for PortFieldNode, whereas the current node is of type {type(self)}."
# )

def shift(
self,
expressions: Union[
int,
"ExpressionNodeEfficient",
List["ExpressionNodeEfficient"],
"ExpressionRange",
],
) -> "LinearExpressionEfficient":
"""
Time shift of variables
Examples:
>>> x.shift([1, 2, 4]) represents the vector of variables (x[t+1], x[t+2], x[t+4])
No variables allowed in shift argument, but parameter trees are ok
It is assumed that the shift operator is linear and distributes to all terms and to the constant of the linear expression on which it is applied.
Examples:
>>> (param("a") * var("x") + param("b")).shift([1, 2, 4]) represents the vector of variables (a[t+1]x[t+1] + b[t+1], a[t+2]x[t+2] + b[t+2], a[t+4]x[t+4] + b[t+4])
"""

# The behavior is richer/different than the previous implementation (with linear expr as trees) as we can now apply a shift operator on a whole expression, rather than just on the variables of an expression

# Example : (param("p") * var("x")).shift(1)
# Previous behavior : p[t]x[t-1]
# New behavior : p[t-1]x[t-1]

result_terms = {}
for term in self.terms.values():
term_with_operator = term.shift(expressions)
result_terms[generate_key(term_with_operator)] = term_with_operator

result_constant = TimeOperatorNode(
self.constant, "TimeShift", InstancesTimeIndex(expressions)
)
result_expr = LinearExpressionEfficient(result_terms, result_constant)
return result_expr

# def eval(
# self,
# expressions: Union[
# int, "ExpressionNode", List["ExpressionNode"], "ExpressionRange"
# ],
# ) -> "ExpressionNode":
# return _apply_if_node(
# self,
# lambda x: TimeOperatorNode(
# x, "TimeEvaluation", InstancesTimeIndex(expressions)
# ),
# )

# def expec(self) -> "ExpressionNode":
# return _apply_if_node(self, lambda x: ScenarioOperatorNode(x, "Expectation"))

# def variance(self) -> "ExpressionNode":
# return _apply_if_node(self, lambda x: ScenarioOperatorNode(x, "Variance"))


def linear_expressions_equal(
lhs: LinearExpressionEfficient, rhs: LinearExpressionEfficient
Expand Down Expand Up @@ -542,13 +705,13 @@ def _wrap_in_linear_expr(obj: Any) -> LinearExpressionEfficient:
raise TypeError(f"Unable to wrap {obj} into a linear expression")


def _apply_if_node(
obj: Any, func: Callable[[LinearExpressionEfficient], LinearExpressionEfficient]
) -> LinearExpressionEfficient:
if as_linear_expr := _wrap_in_linear_expr(obj):
return func(as_linear_expr)
else:
return NotImplemented
# def _apply_if_node(
# obj: Any, func: Callable[[LinearExpressionEfficient], LinearExpressionEfficient]
# ) -> LinearExpressionEfficient:
# if as_linear_expr := _wrap_in_linear_expr(obj):
# return func(as_linear_expr)
# else:
# return NotImplemented


def _copy_expression(
Expand Down
16 changes: 9 additions & 7 deletions src/andromede/expression/time_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from dataclasses import dataclass
from typing import Any, List, Tuple

from andromede.expression.expression_efficient import InstancesTimeIndex


@dataclass(frozen=True)
class TimeOperator(ABC):
Expand All @@ -27,21 +29,21 @@ class TimeOperator(ABC):
- is_rolling: bool, if true, this means that the time_ids are to be understood relatively to the current timestep of the context AND that the represented expression will have to be instanciated for all timesteps. Otherwise, the time_ids are "absolute" times and the expression only has to be instantiated once.
"""

time_ids: List[int]
time_ids: InstancesTimeIndex

@classmethod
@abstractmethod
def rolling(cls) -> bool:
raise NotImplementedError

def __post_init__(self) -> None:
if isinstance(self.time_ids, int):
object.__setattr__(self, "time_ids", [self.time_ids])
elif isinstance(self.time_ids, range):
object.__setattr__(self, "time_ids", list(self.time_ids))
# def __post_init__(self) -> None:
# if isinstance(self.time_ids, int):
# object.__setattr__(self, "time_ids", [self.time_ids])
# elif isinstance(self.time_ids, range):
# object.__setattr__(self, "time_ids", list(self.time_ids))

def key(self) -> Tuple[int, ...]:
return tuple(self.time_ids)
return self.time_ids

def size(self) -> int:
return len(self.time_ids)
Expand Down

0 comments on commit e1ebb29

Please sign in to comment.