Skip to content

Commit

Permalink
WIP: convert addition operands to list
Browse files Browse the repository at this point in the history
Signed-off-by: Sylvain Leclerc <sylvain.leclerc@rte-france.com>
  • Loading branch information
sylvlecl committed Sep 10, 2024
1 parent 7f3bb36 commit 74215f8
Show file tree
Hide file tree
Showing 9 changed files with 50 additions and 62 deletions.
1 change: 0 additions & 1 deletion src/andromede/expression/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
MultiplicationNode,
NegationNode,
ParameterNode,
SubstractionNode,
VariableNode,
literal,
param,
Expand Down
7 changes: 2 additions & 5 deletions src/andromede/expression/degree.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
NegationNode,
ParameterNode,
ScenarioOperatorNode,
SubstractionNode,
VariableNode,
)
from .visitor import ExpressionVisitor, T, visit
Expand All @@ -51,10 +50,8 @@ def negation(self, node: NegationNode) -> int:

# TODO: Take into account simplification that can occur with literal coefficient for add, sub, mult, div
def addition(self, node: AdditionNode) -> int:
return max(visit(node.left, self), visit(node.right, self))

def substraction(self, node: SubstractionNode) -> int:
return max(visit(node.left, self), visit(node.right, self))
degrees = [visit(o, self) for o in node.operands]
return max(degrees)

def multiplication(self, node: MultiplicationNode) -> int:
return visit(node.left, self) + visit(node.right, self)
Expand Down
10 changes: 3 additions & 7 deletions src/andromede/expression/equality.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
MultiplicationNode,
NegationNode,
ParameterNode,
SubstractionNode,
VariableNode,
)
from andromede.expression.expression import (
Expand Down Expand Up @@ -66,8 +65,6 @@ def visit(self, left: ExpressionNode, right: ExpressionNode) -> bool:
return self.negation(left, right)
if isinstance(left, AdditionNode) and isinstance(right, AdditionNode):
return self.addition(left, right)
if isinstance(left, SubstractionNode) and isinstance(right, SubstractionNode):
return self.substraction(left, right)
if isinstance(left, DivisionNode) and isinstance(right, DivisionNode):
return self.division(left, right)
if isinstance(left, MultiplicationNode) and isinstance(
Expand Down Expand Up @@ -130,10 +127,9 @@ def negation(self, left: NegationNode, right: NegationNode) -> bool:
return self.visit(left.operand, right.operand)

def addition(self, left: AdditionNode, right: AdditionNode) -> bool:
return self._visit_operands(left, right)

def substraction(self, left: SubstractionNode, right: SubstractionNode) -> bool:
return self._visit_operands(left, right)
left_ops = left.operands
right_ops = right.operands
return len(left_ops) == len(right_ops) and all(self.visit(l, r) for l, r in zip(left_ops, right_ops))

def multiplication(
self, left: MultiplicationNode, right: MultiplicationNode
Expand Down
32 changes: 20 additions & 12 deletions src/andromede/expression/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import enum
import inspect
from dataclasses import dataclass
from typing import Any, Callable, Optional, Sequence, Union
from typing import Any, Callable, Optional, Sequence, Union, List

import andromede.expression.port_operator
import andromede.expression.scenario_operator
Expand All @@ -40,16 +40,29 @@ def __neg__(self) -> "ExpressionNode":
return NegationNode(self)

def __add__(self, rhs: Any) -> "ExpressionNode":
return _apply_if_node(rhs, lambda x: AdditionNode(self, x))
lhs = self
operands = []
rhs = _wrap_in_node(rhs)
operands.extend(lhs.operands if isinstance(lhs, AdditionNode) else [lhs])
operands.extend(rhs.operands if isinstance(rhs, AdditionNode) else [rhs])
return AdditionNode(operands)

def __radd__(self, lhs: Any) -> "ExpressionNode":
return _apply_if_node(lhs, lambda x: AdditionNode(x, self))
lhs = _wrap_in_node(lhs)
return lhs + self

def __sub__(self, rhs: Any) -> "ExpressionNode":
return _apply_if_node(rhs, lambda x: SubstractionNode(self, x))
lhs = self
operands = []
rhs = _wrap_in_node(rhs)
operands.extend(lhs.operands if isinstance(lhs, AdditionNode) else [lhs])
right_operands = rhs.operands if isinstance(rhs, AdditionNode) else [rhs]
operands.extend([-o for o in right_operands])
return AdditionNode(operands)

def __rsub__(self, lhs: Any) -> "ExpressionNode":
return _apply_if_node(lhs, lambda x: SubstractionNode(x, self))
lhs = _wrap_in_node(lhs)
return lhs + self

def __mul__(self, rhs: Any) -> "ExpressionNode":
return _apply_if_node(rhs, lambda x: MultiplicationNode(self, x))
Expand Down Expand Up @@ -322,13 +335,8 @@ class ComparisonNode(BinaryOperatorNode):


@dataclass(frozen=True, eq=False)
class AdditionNode(BinaryOperatorNode):
pass


@dataclass(frozen=True, eq=False)
class SubstractionNode(BinaryOperatorNode):
pass
class AdditionNode(ExpressionNode):
operands: List[ExpressionNode]


@dataclass(frozen=True, eq=False)
Expand Down
10 changes: 5 additions & 5 deletions src/andromede/expression/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
PortFieldAggregatorNode,
PortFieldNode,
ScenarioOperatorNode,
SubstractionNode,
TimeEvalNode,
TimeShiftNode,
TimeSumNode,
Expand Down Expand Up @@ -76,10 +75,11 @@ def negation(self, node: NegationNode) -> IndexingStructure:
return visit(node.operand, self)

def addition(self, node: AdditionNode) -> IndexingStructure:
return visit(node.left, self) | visit(node.right, self)

def substraction(self, node: SubstractionNode) -> IndexingStructure:
return visit(node.left, self) | visit(node.right, self)
operands = [visit(o, self) for o in node.operands]
res = operands[0]
for o in node.operands[1:]:
res = res | visit(o, self)
return res

def multiplication(self, node: MultiplicationNode) -> IndexingStructure:
return visit(node.left, self) | visit(node.right, self)
Expand Down
21 changes: 11 additions & 10 deletions src/andromede/expression/print.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
PortFieldNode,
TimeEvalNode,
TimeShiftNode,
TimeSumNode,
TimeSumNode, ProblemVariableNode, ProblemParameterNode,
)
from andromede.expression.visitor import T

Expand All @@ -36,7 +36,6 @@
NegationNode,
ParameterNode,
ScenarioOperatorNode,
SubstractionNode,
VariableNode,
)
from .visitor import ExpressionVisitor, visit
Expand All @@ -63,14 +62,8 @@ def negation(self, node: NegationNode) -> str:
return f"-({visit(node.operand, self)})"

def addition(self, node: AdditionNode) -> str:
left_value = visit(node.left, self)
right_value = visit(node.right, self)
return f"({left_value} + {right_value})"

def substraction(self, node: SubstractionNode) -> str:
left_value = visit(node.left, self)
right_value = visit(node.right, self)
return f"({left_value} - {right_value})"
values = [visit(o, self) for o in node.operands]
return f"({' + '.join(values)})"

def multiplication(self, node: MultiplicationNode) -> str:
left_value = visit(node.left, self)
Expand Down Expand Up @@ -100,6 +93,14 @@ def comp_variable(self, node: ComponentVariableNode) -> str:
def comp_parameter(self, node: ComponentParameterNode) -> str:
return f"{node.component_id}.{node.name}"

def pb_variable(self, node: ProblemVariableNode) -> str:
# TODO
return f"{node.component_id}.{node.name}"

def pb_parameter(self, node: ProblemParameterNode) -> str:
# TODO
return f"{node.component_id}.{node.name}"

def time_shift(self, node: TimeShiftNode) -> str:
return f"({visit(node.operand, self)}.shift({visit(node.time_shift, self)}))"

Expand Down
20 changes: 5 additions & 15 deletions src/andromede/expression/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
ProblemParameterNode,
ProblemVariableNode,
ScenarioOperatorNode,
SubstractionNode,
TimeEvalNode,
TimeShiftNode,
TimeSumNode,
Expand Down Expand Up @@ -65,10 +64,6 @@ def negation(self, node: NegationNode) -> T:
def addition(self, node: AdditionNode) -> T:
...

@abstractmethod
def substraction(self, node: SubstractionNode) -> T:
...

@abstractmethod
def multiplication(self, node: MultiplicationNode) -> T:
...
Expand Down Expand Up @@ -160,8 +155,6 @@ def visit(root: ExpressionNode, visitor: ExpressionVisitor[T]) -> T:
return visitor.multiplication(root)
elif isinstance(root, DivisionNode):
return visitor.division(root)
elif isinstance(root, SubstractionNode):
return visitor.substraction(root)
elif isinstance(root, ComparisonNode):
return visitor.comparison(root)
elif isinstance(root, TimeShiftNode):
Expand Down Expand Up @@ -220,14 +213,11 @@ def negation(self, node: NegationNode) -> T_op:
return -visit(node.operand, self)

def addition(self, node: AdditionNode) -> T_op:
left_value = visit(node.left, self)
right_value = visit(node.right, self)
return left_value + right_value

def substraction(self, node: SubstractionNode) -> T_op:
left_value = visit(node.left, self)
right_value = visit(node.right, self)
return left_value - right_value
operands = [visit(o, self) for o in node.operands]
res = operands[0]
for o in operands[1:]:
res = res + o
return res

def multiplication(self, node: MultiplicationNode) -> T_op:
left_value = visit(node.left, self)
Expand Down
7 changes: 2 additions & 5 deletions src/andromede/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
MultiplicationNode,
NegationNode,
ParameterNode,
SubstractionNode,
VariableNode,
)
from andromede.expression.degree import is_linear
Expand Down Expand Up @@ -243,10 +242,8 @@ def _visit_binary_op(self, node: BinaryOperatorNode) -> None:
visit(node.right, self)

def addition(self, node: AdditionNode) -> None:
self._visit_binary_op(node)

def substraction(self, node: SubstractionNode) -> None:
self._visit_binary_op(node)
for n in node.operands:
visit(n, self)

def multiplication(self, node: MultiplicationNode) -> None:
self._visit_binary_op(node)
Expand Down
4 changes: 2 additions & 2 deletions tests/unittests/expressions/test_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def parameter_is_constant_over_time(self, name: str) -> bool:


def test_comp_parameter() -> None:
add_node = AdditionNode(LiteralNode(1), ComponentVariableNode("comp1", "x"))
add_node = AdditionNode([LiteralNode(1), ComponentVariableNode("comp1", "x")])
expr = DivisionNode(add_node, ComponentParameterNode("comp1", "p"))

assert visit(expr, PrinterVisitor()) == "((1 + comp1.x) / comp1.p)"
Expand All @@ -93,7 +93,7 @@ def test_comp_parameter() -> None:


def test_ast() -> None:
add_node = AdditionNode(LiteralNode(1), VariableNode("x"))
add_node = AdditionNode([LiteralNode(1), VariableNode("x")])
expr = DivisionNode(add_node, ParameterNode("p"))

assert visit(expr, PrinterVisitor()) == "((1 + x) / p)"
Expand Down

0 comments on commit 74215f8

Please sign in to comment.