Skip to content

Commit

Permalink
Fix syntax
Browse files Browse the repository at this point in the history
  • Loading branch information
tbittar committed Jul 26, 2024
1 parent 13ea7de commit d092b1b
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 37 deletions.
26 changes: 13 additions & 13 deletions src/andromede/expression/context_adder.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,10 @@
from dataclasses import dataclass

from . import CopyVisitor
from .expression import (
from .expression_efficient import (
ComponentParameterNode,
ComponentVariableNode,
ExpressionNode,
ExpressionNodeEfficient,
ParameterNode,
VariableNode,
)
from .visitor import visit

Expand All @@ -32,22 +30,24 @@ class ContextAdder(CopyVisitor):

component_id: str

def variable(self, node: VariableNode) -> ExpressionNode:
return ComponentVariableNode(self.component_id, node.name)
# def variable(self, node: VariableNode) -> ExpressionNodeEfficient:
# return ComponentVariableNode(self.component_id, node.name)

def parameter(self, node: ParameterNode) -> ExpressionNode:
def parameter(self, node: ParameterNode) -> ExpressionNodeEfficient:
return ComponentParameterNode(self.component_id, node.name)

def comp_variable(self, node: ComponentVariableNode) -> ExpressionNode:
raise ValueError(
"This expression has already been associated to another component."
)
# def comp_variable(self, node: ComponentVariableNode) -> ExpressionNodeEfficient:
# raise ValueError(
# "This expression has already been associated to another component."
# )

def comp_parameter(self, node: ComponentParameterNode) -> ExpressionNode:
def comp_parameter(self, node: ComponentParameterNode) -> ExpressionNodeEfficient:
raise ValueError(
"This expression has already been associated to another component."
)


def add_component_context(id: str, expression: ExpressionNode) -> ExpressionNode:
def add_component_context(
id: str, expression: ExpressionNodeEfficient
) -> ExpressionNodeEfficient:
return visit(expression, ContextAdder(id))
25 changes: 21 additions & 4 deletions src/andromede/expression/linear_expression_efficient.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
overload,
)

from andromede.expression.context_adder import add_component_context
from andromede.expression.equality import expressions_equal
from andromede.expression.evaluate import ValueProvider, evaluate
from andromede.expression.expression_efficient import (
Expand Down Expand Up @@ -350,8 +351,7 @@ def _merge_dicts(
rhs: Dict[TermKeyEfficient, TermEfficient],
merge_func: Callable[[TermEfficient, TermEfficient], TermEfficient],
neutral: float,
) -> Dict[TermKeyEfficient, TermEfficient]:
...
) -> Dict[TermKeyEfficient, TermEfficient]: ...


@overload
Expand All @@ -360,8 +360,7 @@ def _merge_dicts(
rhs: Dict[PortFieldId, PortFieldTerm],
merge_func: Callable[[PortFieldTerm, PortFieldTerm], PortFieldTerm],
neutral: float,
) -> Dict[PortFieldId, PortFieldTerm]:
...
) -> Dict[PortFieldId, PortFieldTerm]: ...


def _get_neutral_term(term: T_val, neutral: float) -> T_val:
Expand Down Expand Up @@ -919,6 +918,24 @@ def resolve_port(
)
return self + port_expr

def add_component_context(self, component_id: str) -> "LinearExpressionEfficient":
result_terms = {}
for term in self.terms.values():
if term.component_id:
raise ValueError(
"This expression has already been associated to another component."
)
result_term = dataclasses.replace(
term,
component_id=component_id,
coefficient=add_component_context(component_id, term.coefficient),
)
result_terms[generate_key(result_term)] = result_term
result_constant = add_component_context(component_id, self.constant)
return LinearExpressionEfficient(
result_terms, result_constant, self.port_field_terms
)


def linear_expressions_equal(
lhs: LinearExpressionEfficient, rhs: LinearExpressionEfficient
Expand Down
39 changes: 19 additions & 20 deletions src/andromede/simulation/optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,22 +18,21 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import Dict, Iterable, List, Optional, Type
from typing import Dict, Iterable, List, Optional

import ortools.linear_solver.pywraplp as lp

from andromede.expression import ( # ExpressionNode,
EvaluationVisitor,
ParameterValueProvider,
ValueProvider,
resolve_parameters,
visit,
)
from andromede.expression.context_adder import add_component_context
from andromede.expression.indexing import IndexingStructureProvider, compute_indexation
from andromede.expression.indexing import IndexingStructureProvider
from andromede.expression.indexing_structure import IndexingStructure
from andromede.expression.linear_expression_efficient import LinearExpressionEfficient
from andromede.expression.port_resolver import PortFieldKey, resolve_port
from andromede.expression.linear_expression_efficient import (
LinearExpressionEfficient,
PortFieldKey,
)
from andromede.expression.scenario_operator import Expectation
from andromede.expression.time_operator import TimeEvaluation, TimeShift, TimeSum
from andromede.model.common import ValueType
Expand Down Expand Up @@ -149,8 +148,7 @@ def get_value(self, block_timestep: int, scenario: int) -> float:
param_value_provider = _make_value_provider(
self.context, block_timestep, scenario, self.component
)
visitor = EvaluationVisitor(param_value_provider)
return visit(self.expression, visitor)
return self.expression.evaluate(param_value_provider)


def _make_parameter_value_provider(
Expand Down Expand Up @@ -421,9 +419,9 @@ def _get_indexing(
constraint: Constraint, provider: IndexingStructureProvider
) -> IndexingStructure:
return (
compute_indexation(constraint.expression, provider)
or compute_indexation(constraint.lower_bound, provider)
or compute_indexation(constraint.upper_bound, provider)
constraint.expression.compute_indexation(provider)
or constraint.lower_bound.compute_indexation(provider)
or constraint.upper_bound.compute_indexation(provider)
)


Expand All @@ -447,9 +445,9 @@ def _instantiate_model_expression(
1. add component ID for variables and parameters of THIS component
2. replace port fields by their definition
"""
with_component = add_component_context(component_id, model_expression)
with_component_and_ports = resolve_port(
with_component, component_id, optimization_context.connection_fields_expressions
with_component = model_expression.add_component_context(component_id)
with_component_and_ports = with_component.resolve_port(
component_id, optimization_context.connection_fields_expressions
)
return with_component_and_ports

Expand All @@ -465,9 +463,10 @@ def _create_constraint(
constraint_indexing = _compute_indexing_structure(context, constraint)

# Perf: Perform linearization (tree traversing) without timesteps so that we can get the number of instances for the expression (from the time_ids of operators)
linear_expr = context.linearize_expression(0, 0, constraint.expression)
# Will there be cases where instances > 1 ? If not, maybe just a check that get_number_of_instances == 1 is sufficient ? Anyway, the function should be implemented
instances_per_time_step = linear_expr.number_of_instances()
# linear_expr = context.linearize_expression(0, 0, constraint.expression)
# # Will there be cases where instances > 1 ? If not, maybe just a check that get_number_of_instances == 1 is sufficient ? Anyway, the function should be implemented
# instances_per_time_step = linear_expr.number_of_instances()
instances_per_time_step = 1

for block_timestep in context.opt_context.get_time_indices(constraint_indexing):
for scenario in context.opt_context.get_scenario_indices(constraint_indexing):
Expand Down Expand Up @@ -703,8 +702,8 @@ def _register_connection_fields_definitions(self) -> None:
)
)
expression_node = port_definition.definition # type: ignore
instantiated_expression = add_component_context(
master_port.component.id, expression_node
instantiated_expression = expression_node.add_component_context(
master_port.component.id
)
self.context.register_connection_fields_expressions(
component_id=cnx.port1.component.id,
Expand Down

0 comments on commit d092b1b

Please sign in to comment.