Skip to content

Commit

Permalink
Fix mypy
Browse files Browse the repository at this point in the history
  • Loading branch information
tbittar committed Sep 18, 2024
1 parent 21c920c commit 9c0904b
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 37 deletions.
4 changes: 2 additions & 2 deletions src/andromede/expression/operators_expansion.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,10 @@ def scenario_operator(self, node: ScenarioOperatorNode) -> ExpressionNode:
nodes.append(apply_scenario(operand, t))
return sum_expressions(nodes) / self.scenarios_count

def pb_parameter(self, node: ProblemParameterNode) -> T:
def pb_parameter(self, node: ProblemParameterNode) -> ExpressionNode:
raise ValueError("Should not reach")

def pb_variable(self, node: ProblemVariableNode) -> T:
def pb_variable(self, node: ProblemVariableNode) -> ExpressionNode:
raise ValueError("Should not reach")


Expand Down
9 changes: 0 additions & 9 deletions src/andromede/simulation/linear_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,15 +156,6 @@ def _scenario_index_to_str(scenario_index: ScenarioIndex) -> str:
return ""


def _str_for_coeff(coeff: float) -> str:
if is_one(coeff):
return "+"
elif is_minus_one(coeff):
return "-"
else:
return "{:+g}".format(coeff)


def _str_for_time_expansion(exp: TimeExpansion) -> str:
if isinstance(exp, TimeShiftExpansion):
return f".shift({exp.shift})"
Expand Down
15 changes: 11 additions & 4 deletions src/andromede/simulation/linearize.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List, Optional
from typing import Any, Dict, List, Optional, Union

from andromede.expression import (
AdditionNode,
Expand All @@ -34,6 +34,7 @@
ParameterNode,
PortFieldAggregatorNode,
PortFieldNode,
ProblemParameterNode,
ProblemVariableNode,
ScenarioIndex,
ScenarioOperatorNode,
Expand Down Expand Up @@ -89,7 +90,7 @@ class LinearExpressionData:
constant: float

def build(self) -> LinearExpression:
res_terms = {}
res_terms: Dict[TermKey, Any] = {}
for t in self.terms:
k = t.to_key()
if k in res_terms:
Expand Down Expand Up @@ -120,7 +121,7 @@ def negation(self, node: NegationNode) -> LinearExpressionData:
def addition(self, node: AdditionNode) -> LinearExpressionData:
operands = [visit(o, self) for o in node.operands]
terms = []
constant = 0
constant: float = 0
for o in operands:
constant += o.constant
terms.extend(o.terms)
Expand Down Expand Up @@ -170,12 +171,18 @@ def _get_timestep(self, time_index: TimeIndex) -> int:
return time_index.timestep
if isinstance(time_index, NoTimeIndex):
return self.timestep
else:
raise TypeError(f"Type {type(time_index)} is not a valid TimeIndex type.")

def _get_scenario(self, scenario_index: ScenarioIndex) -> int:
if isinstance(scenario_index, OneScenarioIndex):
return scenario_index.scenario
if isinstance(scenario_index, NoScenarioIndex):
return self.scenario
else:
raise TypeError(
f"Type {type(scenario_index)} is not a valid TimeIndex type."
)

def literal(self, node: LiteralNode) -> LinearExpressionData:
return LinearExpressionData([], node.value)
Expand Down Expand Up @@ -215,7 +222,7 @@ def comp_parameter(self, node: ComponentParameterNode) -> LinearExpressionData:
"Parameters need to be associated with their timestep/scenario before linearization."
)

def pb_parameter(self, node: ProblemVariableNode) -> LinearExpressionData:
def pb_parameter(self, node: ProblemParameterNode) -> LinearExpressionData:
# TODO SL: not the best place to do this.
# in the future, we should evaluate coefficients of variables as time vectors once for all timesteps
time_index = self._get_timestep(node.time_index)
Expand Down
23 changes: 1 addition & 22 deletions src/andromede/simulation/optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,29 +244,8 @@ def get_variable(
scenario,
self.component.id,
variable_name,
self.component.model.variables[variable_name].structure,
)

def linearize_expression(
self,
block_timestep: int,
scenario: int,
expression: ExpressionNode,
) -> LinearExpression:
parameters_valued_provider = _make_parameter_value_provider(
self.opt_context, block_timestep, scenario
)
evaluated_expr = resolve_parameters(expression, parameters_valued_provider)

value_provider = _make_value_provider(
self.opt_context, block_timestep, scenario, self.component
)
structure_provider = _make_data_structure_provider(
self.opt_context.network, self.component
)

return linearize_expression(evaluated_expr, structure_provider, value_provider)


class BlockBorderManagement(Enum):
"""
Expand Down Expand Up @@ -338,7 +317,7 @@ def connection_fields_expressions(self) -> Dict[PortFieldKey, List[ExpressionNod
def block_timestep_to_absolute_timestep(self, block_timestep: int) -> int:
return self._block.timesteps[self.get_actual_block_timestep(block_timestep)]

def get_actual_block_timestep(self, block_timestep):
def get_actual_block_timestep(self, block_timestep: int) -> int:
if self._border_management == BlockBorderManagement.CYCLE:
return block_timestep % self.block_length()
else:
Expand Down

0 comments on commit 9c0904b

Please sign in to comment.