Skip to content

Commit

Permalink
Fix shift distribution and add component context for time operators
Browse files Browse the repository at this point in the history
  • Loading branch information
tbittar committed Aug 20, 2024
1 parent b429782 commit bda088c
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 41 deletions.
4 changes: 4 additions & 0 deletions src/andromede/expression/evaluate_parameters_efficient.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,8 @@ def time_aggregator(
for k in operand_dict.keys()
if k.scenario == scenario
)
# As the sum aggregates on time, time indices on which to evaluate parent expression collapses on row_id.time
self.time_scenario_indices.time_indices = [self.row_id.time]
return result
else:
return NotImplemented
Expand All @@ -185,6 +187,8 @@ def scenario_operator(
operand_dict[k] for k in operand_dict.keys() if k.time == time
)
)
# As the expectation aggregates on scenario, scenario indices on which to evaluate parent expression collapses on row_id.scenario
self.time_scenario_indices.scenario_indices = [self.row_id.scenario]
return result

else:
Expand Down
66 changes: 52 additions & 14 deletions src/andromede/expression/linear_expression_efficient.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
Sequence,
TypeVar,
Union,
cast,
overload,
)

Expand Down Expand Up @@ -231,26 +232,25 @@ def sum(
if shift is not None and eval is not None:
raise ValueError("Only shift or eval arguments should specified, not both.")

# The shift or eval operators distribute over the coefficients whereas the sum only applies to the whole as (param("a") * var("x")).shift([1,5]) represents: a[t+1]x[t+1] + ... + a[t+5]x[t+5]
# And (param("a") * var("x")).eval([1,5]) represents: a[1]x[1] + ... + a[5]x[5]
# The shift or eval operators applies on the variable, then it will define at which time step the term coefficient * variable will be evaluated

if shift is not None:
return dataclasses.replace(
self,
coefficient=TimeOperatorNode(
self.coefficient, TimeOperatorName.SHIFT, InstancesTimeIndex(shift)
),
# coefficient=TimeOperatorNode(
# self.coefficient, TimeOperatorName.SHIFT, InstancesTimeIndex(shift)
# ),
time_operator=TimeShift(InstancesTimeIndex(shift)),
time_aggregator=TimeSum(stay_roll=True),
)
elif eval is not None:
return dataclasses.replace(
self,
coefficient=TimeOperatorNode(
self.coefficient,
TimeOperatorName.EVALUATION,
InstancesTimeIndex(eval),
),
# coefficient=TimeOperatorNode(
# self.coefficient,
# TimeOperatorName.EVALUATION,
# InstancesTimeIndex(eval),
# ),
time_operator=TimeEvaluation(InstancesTimeIndex(eval)),
time_aggregator=TimeSum(stay_roll=True),
)
Expand Down Expand Up @@ -380,8 +380,7 @@ def _merge_dicts(
rhs: Dict[TermKeyEfficient, TermEfficient],
merge_func: Callable[[TermEfficient, TermEfficient], TermEfficient],
neutral: float,
) -> Dict[TermKeyEfficient, TermEfficient]:
...
) -> Dict[TermKeyEfficient, TermEfficient]: ...


@overload
Expand All @@ -390,8 +389,7 @@ def _merge_dicts(
rhs: Dict[PortFieldId, PortFieldTerm],
merge_func: Callable[[PortFieldTerm, PortFieldTerm], PortFieldTerm],
neutral: float,
) -> Dict[PortFieldId, PortFieldTerm]:
...
) -> Dict[PortFieldId, PortFieldTerm]: ...


def _get_neutral_term(term: T_val, neutral: float) -> T_val:
Expand Down Expand Up @@ -959,10 +957,21 @@ def add_component_context(self, component_id: str) -> "LinearExpressionEfficient
raise ValueError(
"This expression has already been associated to another component."
)

result_term = dataclasses.replace(
term,
component_id=component_id,
coefficient=add_component_context(component_id, term.coefficient),
time_operator=(
dataclasses.replace(
term.time_operator,
time_ids=_add_component_context_to_instances_index(
component_id, term.time_operator.time_ids
),
)
if term.time_operator
else None
),
)
result_terms[generate_key(result_term)] = result_term
result_constant = add_component_context(component_id, self.constant)
Expand All @@ -971,6 +980,35 @@ def add_component_context(self, component_id: str) -> "LinearExpressionEfficient
)


def _add_component_context_to_expression_range(
component_id: str, expression_range: ExpressionRange
) -> ExpressionRange:
return ExpressionRange(
start=add_component_context(component_id, expression_range.start),
stop=add_component_context(component_id, expression_range.stop),
step=(
add_component_context(component_id, expression_range.step)
if expression_range.step is not None
else None
),
)


def _add_component_context_to_instances_index(
component_id: str, instances_index: InstancesTimeIndex
) -> InstancesTimeIndex:
expressions = instances_index.expressions
if isinstance(expressions, ExpressionRange):
return InstancesTimeIndex(
_add_component_context_to_expression_range(component_id, expressions)
)
if isinstance(expressions, list):
expressions_list = cast(List[ExpressionNodeEfficient], expressions)
copy = [add_component_context(component_id, e) for e in expressions_list]
return InstancesTimeIndex(copy)
raise ValueError("Unexpected type in instances index")


def linear_expressions_equal(
lhs: LinearExpressionEfficient, rhs: LinearExpressionEfficient
) -> bool:
Expand Down
26 changes: 15 additions & 11 deletions src/andromede/simulation/linear_expression_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,14 @@ def resolve(
# a_t * sum(x_t')
# a_t * x_t
# TODO: Next line is to be moved inside the for loop once we have figured out how to represent sum(a_t * x_t)
resolved_coeff = resolve_coefficient(
term.coefficient, self.value_provider, row_id
)

for ts_id, lp_variable in resolved_variables.items():
# TODO: Could we check in which case coeff resolution leads to the same result for each element in the for loop ? When there is only a literal, etc, etc ?
resolved_coeff = resolve_coefficient(
term.coefficient,
self.value_provider,
RowIndex(ts_id.time, ts_id.scenario),
)
resolved_terms.append(ResolvedTerm(resolved_coeff, lp_variable))

resolved_constant = resolve_coefficient(
Expand All @@ -80,14 +84,14 @@ def resolve_variables(
operator_ts_ids = self._row_id_to_term_time_scenario_id(term, row_id)
for time in operator_ts_ids.time_indices:
for scenario in operator_ts_ids.scenario_indices:
solver_vars[
TimeScenarioIndex(time, scenario)
] = self.context.get_component_variable(
time,
scenario,
term.component_id,
term.variable_name,
term.structure,
solver_vars[TimeScenarioIndex(time, scenario)] = (
self.context.get_component_variable(
time,
scenario,
term.component_id,
term.variable_name,
term.structure,
)
)
return solver_vars

Expand Down
20 changes: 4 additions & 16 deletions tests/unittests/expressions/test_expressions_efficient.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,9 +413,7 @@ def test_comparison() -> None:
), # The internal representation of shift(1) is sum(shift=1)
scenario_aggregator=None,
): TermEfficient(
TimeOperatorNode(
LiteralNode(1), TimeOperatorName.SHIFT, InstancesTimeIndex(1)
),
LiteralNode(1),
"",
"x",
time_operator=TimeShift(
Expand All @@ -432,9 +430,7 @@ def test_comparison() -> None:
time_aggregator=TimeSum(stay_roll=True),
scenario_aggregator=None,
): TermEfficient(
TimeOperatorNode(
LiteralNode(1), TimeOperatorName.SHIFT, InstancesTimeIndex(1)
),
LiteralNode(1),
"",
"y",
time_operator=TimeShift(InstancesTimeIndex(1)),
Expand All @@ -461,11 +457,7 @@ def test_comparison() -> None:
), # The internal representation of eval(1) is sum(eval=1)
scenario_aggregator=None,
): TermEfficient(
TimeOperatorNode(
LiteralNode(1),
TimeOperatorName.EVALUATION,
InstancesTimeIndex(1),
),
LiteralNode(1),
"",
"x",
time_operator=TimeEvaluation(
Expand All @@ -482,11 +474,7 @@ def test_comparison() -> None:
time_aggregator=TimeSum(stay_roll=True),
scenario_aggregator=None,
): TermEfficient(
TimeOperatorNode(
LiteralNode(1),
TimeOperatorName.EVALUATION,
InstancesTimeIndex(1),
),
LiteralNode(1),
"",
"y",
time_operator=TimeEvaluation(InstancesTimeIndex(1)),
Expand Down
21 changes: 21 additions & 0 deletions tests/unittests/expressions/test_resolve_coefficients.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
TimeOperatorName,
TimeOperatorNode,
comp_param,
literal,
param,
)
from andromede.expression.indexing_structure import IndexingStructure, RowIndex
Expand Down Expand Up @@ -275,6 +276,7 @@ def test_resolve_coefficient_on_elementary_operations(
[
(param("p").shift(2).sum(), RowIndex(0, 0), 3.0),
(param("p").shift(-1).sum(), RowIndex(2, 1), 5.0),
(literal(0).shift(-1).sum(), RowIndex(0, 0), 0.0),
(param("p").eval(2).sum(), RowIndex(0, 0), 3.0),
(param("p").eval(2).sum(), RowIndex(2, 0), 3.0),
(param("p").shift(ExpressionRange(0, 3)).sum(), RowIndex(0, 0), 13.0),
Expand Down Expand Up @@ -311,3 +313,22 @@ def test_resolve_coefficient_on_expectation(
provider: CustomValueProvider,
) -> None:
assert math.isclose(resolve_coefficient(expr, provider, row_id), expected)


@pytest.mark.parametrize(
"expr, row_id, expected",
[
(param("p").expec().sum(), RowIndex(0, 0), 18.0),
(param("p").sum().expec(), RowIndex(0, 0), 18.0),
(param("p").shift(comp_param("c", "q")).sum().expec(), RowIndex(1, 0), 6.5),
(param("p").expec().shift(comp_param("c", "q")).sum(), RowIndex(1, 0), 7.5),
(param("p").shift(comp_param("c", "q")).expec().sum(), RowIndex(1, 0), 6.5),
],
)
def test_resolve_coefficient_on_sum_and_expectation(
expr: ExpressionNodeEfficient,
row_id: RowIndex,
expected: float,
provider: CustomValueProvider,
) -> None:
assert math.isclose(resolve_coefficient(expr, provider, row_id), expected)

0 comments on commit bda088c

Please sign in to comment.