From 74215f8497769b464648320243c9125905e2706e Mon Sep 17 00:00:00 2001 From: Sylvain Leclerc Date: Tue, 10 Sep 2024 09:59:03 +0200 Subject: [PATCH] WIP: convert addition operands to list Signed-off-by: Sylvain Leclerc --- src/andromede/expression/__init__.py | 1 - src/andromede/expression/degree.py | 7 ++-- src/andromede/expression/equality.py | 10 ++---- src/andromede/expression/expression.py | 32 ++++++++++++------- src/andromede/expression/indexing.py | 10 +++--- src/andromede/expression/print.py | 21 ++++++------ src/andromede/expression/visitor.py | 20 +++--------- src/andromede/model/model.py | 7 ++-- .../unittests/expressions/test_expressions.py | 4 +-- 9 files changed, 50 insertions(+), 62 deletions(-) diff --git a/src/andromede/expression/__init__.py b/src/andromede/expression/__init__.py index 2fe9b94d..3d949b6e 100644 --- a/src/andromede/expression/__init__.py +++ b/src/andromede/expression/__init__.py @@ -28,7 +28,6 @@ MultiplicationNode, NegationNode, ParameterNode, - SubstractionNode, VariableNode, literal, param, diff --git a/src/andromede/expression/degree.py b/src/andromede/expression/degree.py index 34f53156..8d21235a 100644 --- a/src/andromede/expression/degree.py +++ b/src/andromede/expression/degree.py @@ -32,7 +32,6 @@ NegationNode, ParameterNode, ScenarioOperatorNode, - SubstractionNode, VariableNode, ) from .visitor import ExpressionVisitor, T, visit @@ -51,10 +50,8 @@ def negation(self, node: NegationNode) -> int: # TODO: Take into account simplification that can occur with literal coefficient for add, sub, mult, div def addition(self, node: AdditionNode) -> int: - return max(visit(node.left, self), visit(node.right, self)) - - def substraction(self, node: SubstractionNode) -> int: - return max(visit(node.left, self), visit(node.right, self)) + degrees = [visit(o, self) for o in node.operands] + return max(degrees) def multiplication(self, node: MultiplicationNode) -> int: return visit(node.left, self) + visit(node.right, self) diff --git a/src/andromede/expression/equality.py b/src/andromede/expression/equality.py index a5264935..55c91e6d 100644 --- a/src/andromede/expression/equality.py +++ b/src/andromede/expression/equality.py @@ -23,7 +23,6 @@ MultiplicationNode, NegationNode, ParameterNode, - SubstractionNode, VariableNode, ) from andromede.expression.expression import ( @@ -66,8 +65,6 @@ def visit(self, left: ExpressionNode, right: ExpressionNode) -> bool: return self.negation(left, right) if isinstance(left, AdditionNode) and isinstance(right, AdditionNode): return self.addition(left, right) - if isinstance(left, SubstractionNode) and isinstance(right, SubstractionNode): - return self.substraction(left, right) if isinstance(left, DivisionNode) and isinstance(right, DivisionNode): return self.division(left, right) if isinstance(left, MultiplicationNode) and isinstance( @@ -130,10 +127,9 @@ def negation(self, left: NegationNode, right: NegationNode) -> bool: return self.visit(left.operand, right.operand) def addition(self, left: AdditionNode, right: AdditionNode) -> bool: - return self._visit_operands(left, right) - - def substraction(self, left: SubstractionNode, right: SubstractionNode) -> bool: - return self._visit_operands(left, right) + left_ops = left.operands + right_ops = right.operands + return len(left_ops) == len(right_ops) and all(self.visit(l, r) for l, r in zip(left_ops, right_ops)) def multiplication( self, left: MultiplicationNode, right: MultiplicationNode diff --git a/src/andromede/expression/expression.py b/src/andromede/expression/expression.py index 4112fce3..03e6aa90 100644 --- a/src/andromede/expression/expression.py +++ b/src/andromede/expression/expression.py @@ -16,7 +16,7 @@ import enum import inspect from dataclasses import dataclass -from typing import Any, Callable, Optional, Sequence, Union +from typing import Any, Callable, Optional, Sequence, Union, List import andromede.expression.port_operator import andromede.expression.scenario_operator @@ -40,16 +40,29 @@ def __neg__(self) -> "ExpressionNode": return NegationNode(self) def __add__(self, rhs: Any) -> "ExpressionNode": - return _apply_if_node(rhs, lambda x: AdditionNode(self, x)) + lhs = self + operands = [] + rhs = _wrap_in_node(rhs) + operands.extend(lhs.operands if isinstance(lhs, AdditionNode) else [lhs]) + operands.extend(rhs.operands if isinstance(rhs, AdditionNode) else [rhs]) + return AdditionNode(operands) def __radd__(self, lhs: Any) -> "ExpressionNode": - return _apply_if_node(lhs, lambda x: AdditionNode(x, self)) + lhs = _wrap_in_node(lhs) + return lhs + self def __sub__(self, rhs: Any) -> "ExpressionNode": - return _apply_if_node(rhs, lambda x: SubstractionNode(self, x)) + lhs = self + operands = [] + rhs = _wrap_in_node(rhs) + operands.extend(lhs.operands if isinstance(lhs, AdditionNode) else [lhs]) + right_operands = rhs.operands if isinstance(rhs, AdditionNode) else [rhs] + operands.extend([-o for o in right_operands]) + return AdditionNode(operands) def __rsub__(self, lhs: Any) -> "ExpressionNode": - return _apply_if_node(lhs, lambda x: SubstractionNode(x, self)) + lhs = _wrap_in_node(lhs) + return lhs + self def __mul__(self, rhs: Any) -> "ExpressionNode": return _apply_if_node(rhs, lambda x: MultiplicationNode(self, x)) @@ -322,13 +335,8 @@ class ComparisonNode(BinaryOperatorNode): @dataclass(frozen=True, eq=False) -class AdditionNode(BinaryOperatorNode): - pass - - -@dataclass(frozen=True, eq=False) -class SubstractionNode(BinaryOperatorNode): - pass +class AdditionNode(ExpressionNode): + operands: List[ExpressionNode] @dataclass(frozen=True, eq=False) diff --git a/src/andromede/expression/indexing.py b/src/andromede/expression/indexing.py index 49dc7466..c4cfda76 100644 --- a/src/andromede/expression/indexing.py +++ b/src/andromede/expression/indexing.py @@ -30,7 +30,6 @@ PortFieldAggregatorNode, PortFieldNode, ScenarioOperatorNode, - SubstractionNode, TimeEvalNode, TimeShiftNode, TimeSumNode, @@ -76,10 +75,11 @@ def negation(self, node: NegationNode) -> IndexingStructure: return visit(node.operand, self) def addition(self, node: AdditionNode) -> IndexingStructure: - return visit(node.left, self) | visit(node.right, self) - - def substraction(self, node: SubstractionNode) -> IndexingStructure: - return visit(node.left, self) | visit(node.right, self) + operands = [visit(o, self) for o in node.operands] + res = operands[0] + for o in node.operands[1:]: + res = res | visit(o, self) + return res def multiplication(self, node: MultiplicationNode) -> IndexingStructure: return visit(node.left, self) | visit(node.right, self) diff --git a/src/andromede/expression/print.py b/src/andromede/expression/print.py index b031bee8..a682ed6d 100644 --- a/src/andromede/expression/print.py +++ b/src/andromede/expression/print.py @@ -22,7 +22,7 @@ PortFieldNode, TimeEvalNode, TimeShiftNode, - TimeSumNode, + TimeSumNode, ProblemVariableNode, ProblemParameterNode, ) from andromede.expression.visitor import T @@ -36,7 +36,6 @@ NegationNode, ParameterNode, ScenarioOperatorNode, - SubstractionNode, VariableNode, ) from .visitor import ExpressionVisitor, visit @@ -63,14 +62,8 @@ def negation(self, node: NegationNode) -> str: return f"-({visit(node.operand, self)})" def addition(self, node: AdditionNode) -> str: - left_value = visit(node.left, self) - right_value = visit(node.right, self) - return f"({left_value} + {right_value})" - - def substraction(self, node: SubstractionNode) -> str: - left_value = visit(node.left, self) - right_value = visit(node.right, self) - return f"({left_value} - {right_value})" + values = [visit(o, self) for o in node.operands] + return f"({' + '.join(values)})" def multiplication(self, node: MultiplicationNode) -> str: left_value = visit(node.left, self) @@ -100,6 +93,14 @@ def comp_variable(self, node: ComponentVariableNode) -> str: def comp_parameter(self, node: ComponentParameterNode) -> str: return f"{node.component_id}.{node.name}" + def pb_variable(self, node: ProblemVariableNode) -> str: + # TODO + return f"{node.component_id}.{node.name}" + + def pb_parameter(self, node: ProblemParameterNode) -> str: + # TODO + return f"{node.component_id}.{node.name}" + def time_shift(self, node: TimeShiftNode) -> str: return f"({visit(node.operand, self)}.shift({visit(node.time_shift, self)}))" diff --git a/src/andromede/expression/visitor.py b/src/andromede/expression/visitor.py index d1e790f4..351be9e5 100644 --- a/src/andromede/expression/visitor.py +++ b/src/andromede/expression/visitor.py @@ -34,7 +34,6 @@ ProblemParameterNode, ProblemVariableNode, ScenarioOperatorNode, - SubstractionNode, TimeEvalNode, TimeShiftNode, TimeSumNode, @@ -65,10 +64,6 @@ def negation(self, node: NegationNode) -> T: def addition(self, node: AdditionNode) -> T: ... - @abstractmethod - def substraction(self, node: SubstractionNode) -> T: - ... - @abstractmethod def multiplication(self, node: MultiplicationNode) -> T: ... @@ -160,8 +155,6 @@ def visit(root: ExpressionNode, visitor: ExpressionVisitor[T]) -> T: return visitor.multiplication(root) elif isinstance(root, DivisionNode): return visitor.division(root) - elif isinstance(root, SubstractionNode): - return visitor.substraction(root) elif isinstance(root, ComparisonNode): return visitor.comparison(root) elif isinstance(root, TimeShiftNode): @@ -220,14 +213,11 @@ def negation(self, node: NegationNode) -> T_op: return -visit(node.operand, self) def addition(self, node: AdditionNode) -> T_op: - left_value = visit(node.left, self) - right_value = visit(node.right, self) - return left_value + right_value - - def substraction(self, node: SubstractionNode) -> T_op: - left_value = visit(node.left, self) - right_value = visit(node.right, self) - return left_value - right_value + operands = [visit(o, self) for o in node.operands] + res = operands[0] + for o in operands[1:]: + res = res + o + return res def multiplication(self, node: MultiplicationNode) -> T_op: left_value = visit(node.left, self) diff --git a/src/andromede/model/model.py b/src/andromede/model/model.py index 1e13b703..e061c9fc 100644 --- a/src/andromede/model/model.py +++ b/src/andromede/model/model.py @@ -29,7 +29,6 @@ MultiplicationNode, NegationNode, ParameterNode, - SubstractionNode, VariableNode, ) from andromede.expression.degree import is_linear @@ -243,10 +242,8 @@ def _visit_binary_op(self, node: BinaryOperatorNode) -> None: visit(node.right, self) def addition(self, node: AdditionNode) -> None: - self._visit_binary_op(node) - - def substraction(self, node: SubstractionNode) -> None: - self._visit_binary_op(node) + for n in node.operands: + visit(n, self) def multiplication(self, node: MultiplicationNode) -> None: self._visit_binary_op(node) diff --git a/tests/unittests/expressions/test_expressions.py b/tests/unittests/expressions/test_expressions.py index d1fdaed2..b8521852 100644 --- a/tests/unittests/expressions/test_expressions.py +++ b/tests/unittests/expressions/test_expressions.py @@ -81,7 +81,7 @@ def parameter_is_constant_over_time(self, name: str) -> bool: def test_comp_parameter() -> None: - add_node = AdditionNode(LiteralNode(1), ComponentVariableNode("comp1", "x")) + add_node = AdditionNode([LiteralNode(1), ComponentVariableNode("comp1", "x")]) expr = DivisionNode(add_node, ComponentParameterNode("comp1", "p")) assert visit(expr, PrinterVisitor()) == "((1 + comp1.x) / comp1.p)" @@ -93,7 +93,7 @@ def test_comp_parameter() -> None: def test_ast() -> None: - add_node = AdditionNode(LiteralNode(1), VariableNode("x")) + add_node = AdditionNode([LiteralNode(1), VariableNode("x")]) expr = DivisionNode(add_node, ParameterNode("p")) assert visit(expr, PrinterVisitor()) == "((1 + x) / p)"