diff --git a/src/andromede/expression/operators_expansion.py b/src/andromede/expression/operators_expansion.py index f0f530b..03fa215 100644 --- a/src/andromede/expression/operators_expansion.py +++ b/src/andromede/expression/operators_expansion.py @@ -106,10 +106,10 @@ def scenario_operator(self, node: ScenarioOperatorNode) -> ExpressionNode: nodes.append(apply_scenario(operand, t)) return sum_expressions(nodes) / self.scenarios_count - def pb_parameter(self, node: ProblemParameterNode) -> T: + def pb_parameter(self, node: ProblemParameterNode) -> ExpressionNode: raise ValueError("Should not reach") - def pb_variable(self, node: ProblemVariableNode) -> T: + def pb_variable(self, node: ProblemVariableNode) -> ExpressionNode: raise ValueError("Should not reach") diff --git a/src/andromede/simulation/linear_expression.py b/src/andromede/simulation/linear_expression.py index ebe0fbf..fd187aa 100644 --- a/src/andromede/simulation/linear_expression.py +++ b/src/andromede/simulation/linear_expression.py @@ -156,15 +156,6 @@ def _scenario_index_to_str(scenario_index: ScenarioIndex) -> str: return "" -def _str_for_coeff(coeff: float) -> str: - if is_one(coeff): - return "+" - elif is_minus_one(coeff): - return "-" - else: - return "{:+g}".format(coeff) - - def _str_for_time_expansion(exp: TimeExpansion) -> str: if isinstance(exp, TimeShiftExpansion): return f".shift({exp.shift})" diff --git a/src/andromede/simulation/linearize.py b/src/andromede/simulation/linearize.py index 13b5372..454ce67 100644 --- a/src/andromede/simulation/linearize.py +++ b/src/andromede/simulation/linearize.py @@ -12,7 +12,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import List, Optional +from typing import Any, Dict, List, Optional, Union from andromede.expression import ( AdditionNode, @@ -34,6 +34,7 @@ ParameterNode, PortFieldAggregatorNode, PortFieldNode, + ProblemParameterNode, ProblemVariableNode, ScenarioIndex, ScenarioOperatorNode, @@ -89,7 +90,7 @@ class LinearExpressionData: constant: float def build(self) -> LinearExpression: - res_terms = {} + res_terms: Dict[TermKey, Any] = {} for t in self.terms: k = t.to_key() if k in res_terms: @@ -120,7 +121,7 @@ def negation(self, node: NegationNode) -> LinearExpressionData: def addition(self, node: AdditionNode) -> LinearExpressionData: operands = [visit(o, self) for o in node.operands] terms = [] - constant = 0 + constant: float = 0 for o in operands: constant += o.constant terms.extend(o.terms) @@ -170,12 +171,18 @@ def _get_timestep(self, time_index: TimeIndex) -> int: return time_index.timestep if isinstance(time_index, NoTimeIndex): return self.timestep + else: + raise TypeError(f"Type {type(time_index)} is not a valid TimeIndex type.") def _get_scenario(self, scenario_index: ScenarioIndex) -> int: if isinstance(scenario_index, OneScenarioIndex): return scenario_index.scenario if isinstance(scenario_index, NoScenarioIndex): return self.scenario + else: + raise TypeError( + f"Type {type(scenario_index)} is not a valid TimeIndex type." + ) def literal(self, node: LiteralNode) -> LinearExpressionData: return LinearExpressionData([], node.value) @@ -215,7 +222,7 @@ def comp_parameter(self, node: ComponentParameterNode) -> LinearExpressionData: "Parameters need to be associated with their timestep/scenario before linearization." ) - def pb_parameter(self, node: ProblemVariableNode) -> LinearExpressionData: + def pb_parameter(self, node: ProblemParameterNode) -> LinearExpressionData: # TODO SL: not the best place to do this. # in the future, we should evaluate coefficients of variables as time vectors once for all timesteps time_index = self._get_timestep(node.time_index) diff --git a/src/andromede/simulation/optimization.py b/src/andromede/simulation/optimization.py index 42562fd..eec666c 100644 --- a/src/andromede/simulation/optimization.py +++ b/src/andromede/simulation/optimization.py @@ -244,29 +244,8 @@ def get_variable( scenario, self.component.id, variable_name, - self.component.model.variables[variable_name].structure, ) - def linearize_expression( - self, - block_timestep: int, - scenario: int, - expression: ExpressionNode, - ) -> LinearExpression: - parameters_valued_provider = _make_parameter_value_provider( - self.opt_context, block_timestep, scenario - ) - evaluated_expr = resolve_parameters(expression, parameters_valued_provider) - - value_provider = _make_value_provider( - self.opt_context, block_timestep, scenario, self.component - ) - structure_provider = _make_data_structure_provider( - self.opt_context.network, self.component - ) - - return linearize_expression(evaluated_expr, structure_provider, value_provider) - class BlockBorderManagement(Enum): """ @@ -338,7 +317,7 @@ def connection_fields_expressions(self) -> Dict[PortFieldKey, List[ExpressionNod def block_timestep_to_absolute_timestep(self, block_timestep: int) -> int: return self._block.timesteps[self.get_actual_block_timestep(block_timestep)] - def get_actual_block_timestep(self, block_timestep): + def get_actual_block_timestep(self, block_timestep: int) -> int: if self._border_management == BlockBorderManagement.CYCLE: return block_timestep % self.block_length() else: