From 73f8301a6b11930556187b2bf91a249da6bf8b06 Mon Sep 17 00:00:00 2001 From: Sylvain Leclerc Date: Fri, 15 Mar 2024 16:49:20 +0100 Subject: [PATCH] Add a few failing tests Signed-off-by: Sylvain Leclerc --- .../parsing/test_expression_parsing.py | 121 ++++++++++-------- 1 file changed, 71 insertions(+), 50 deletions(-) diff --git a/tests/andromede/expressions/parsing/test_expression_parsing.py b/tests/andromede/expressions/parsing/test_expression_parsing.py index 54af539e..5aab206d 100644 --- a/tests/andromede/expressions/parsing/test_expression_parsing.py +++ b/tests/andromede/expressions/parsing/test_expression_parsing.py @@ -9,15 +9,13 @@ # SPDX-License-Identifier: MPL-2.0 # # This file is part of the Antares project. +from typing import Set import pytest -from antlr4 import CommonTokenStream, InputStream from andromede.expression import ExpressionNode, literal, param, print_expr, var from andromede.expression.equality import expressions_equal from andromede.expression.expression import ExpressionRange, port_field -from andromede.expression.parsing.antlr.ExprLexer import ExprLexer -from andromede.expression.parsing.antlr.ExprParser import ExprParser from andromede.expression.parsing.parse_expression import ( AntaresParseException, ModelIdentifiers, @@ -26,43 +24,75 @@ @pytest.mark.parametrize( - "expression_str, expected", + "variables, parameters, expression_str, expected", [ - ("1 + 2", literal(1) + 2), - ("1 - 2", literal(1) - 2), - ("1 - 3 + 4 - 2", literal(1) - 3 + 4 - 2), + ({}, {}, "1 + 2", literal(1) + 2), + ({}, {}, "1 - 2", literal(1) - 2), + ({}, {}, "1 - 3 + 4 - 2", literal(1) - 3 + 4 - 2), ( + {"x"}, + {"p"}, "1 + 2 * x = p", literal(1) + 2 * var("x") == param("p"), ), ( + {}, + {}, "port.f <= 0", port_field("port", "f") <= 0, ), - ("sum(x)", var("x").sum()), - ("x[-1]", var("x").eval(-literal(1))), - ("x[-1..5]", var("x").eval(ExpressionRange(-literal(1), literal(5)))), - ("x[1]", var("x").eval(1)), - ("x[t-1]", var("x").shift(-literal(1))), + ({"x"}, {}, "sum(x)", var("x").sum()), + ({"x"}, {}, "x[-1]", var("x").eval(-literal(1))), ( + {"x"}, + {}, + "x[-1..5]", + var("x").eval(ExpressionRange(-literal(1), literal(5))), + ), + ({"x"}, {}, "x[1]", var("x").eval(1)), + ({"x"}, {}, "x[t-1]", var("x").shift(-literal(1))), + ( + {"x"}, + {}, "x[t-1, t+4]", var("x").shift([-literal(1), literal(4)]), ), ( + {"x"}, + {}, "x[t-1, t, t+4]", var("x").shift([-literal(1), literal(0), literal(4)]), ), - ("x[t-1..t+5]", var("x").shift(ExpressionRange(-literal(1), literal(5)))), - ("x[t-1..t]", var("x").shift(ExpressionRange(-literal(1), literal(0)))), - ("x[t..t+5]", var("x").shift(ExpressionRange(literal(0), literal(5)))), - ("x[t]", var("x")), - ("x[t+p]", var("x").shift(param("p"))), ( + {"x"}, + {}, + "x[t-1..t+5]", + var("x").shift(ExpressionRange(-literal(1), literal(5))), + ), + ( + {"x"}, + {}, + "x[t-1..t]", + var("x").shift(ExpressionRange(-literal(1), literal(0))), + ), + ( + {"x"}, + {}, + "x[t..t+5]", + var("x").shift(ExpressionRange(literal(0), literal(5))), + ), + ({"x"}, {}, "x[t]", var("x")), + ({"x"}, {"p"}, "x[t+p]", var("x").shift(param("p"))), + ( + {"x"}, + {}, "sum(x[-1..5])", var("x").eval(ExpressionRange(-literal(1), literal(5))).sum(), ), - ("sum_connections(port.f)", port_field("port", "f").sum_connections()), + ({}, {}, "sum_connections(port.f)", port_field("port", "f").sum_connections()), ( + {"level", "injection", "withdrawal"}, + {"inflows", "efficiency"}, "level - level[-1] - efficiency * injection + withdrawal = inflows", var("level") - var("level").eval(-literal(1)) @@ -71,6 +101,8 @@ == param("inflows"), ), ( + {"nb_start", "nb_on"}, + {"d_min_up"}, "sum(nb_start[-d_min_up + 1 .. 0]) <= nb_on", var("nb_start") .eval(ExpressionRange(-param("d_min_up") + 1, literal(0))) @@ -78,52 +110,41 @@ <= var("nb_on"), ), ( + {"generation"}, + {"cost"}, "expec(sum(cost * generation))", (param("cost") * var("generation")).sum().expec(), ), ], ) -def test_parsing_visitor(expression_str: str, expected: ExpressionNode): - identifiers = ModelIdentifiers( - variables={ - "x", - "level", - "injection", - "withdrawal", - "nb_start", - "nb_on", - "generation", - }, - parameters={"p", "inflows", "efficiency", "d_min_up", "cost"}, - ) - +def test_parsing_visitor( + variables: Set[str], + parameters: Set[str], + expression_str: str, + expected: ExpressionNode, +): + identifiers = ModelIdentifiers(variables, parameters) expr = parse_expression(expression_str, identifiers) print() print(print_expr(expr)) assert expressions_equal(expr, expected) -def test_parse_cancellation_err(): +@pytest.mark.parametrize( + "expression_str", + [ + "1**3", + "1 6", + "x[t+1-t]", + "x[2*t]", + ], +) +def test_parse_cancellation_should_throw(expression_str: str): # Console log error is displayed ! identifiers = ModelIdentifiers( - variables={ - "x", - "level", - "injection", - "withdrawal", - "nb_start", - "nb_on", - "generation", - }, - parameters={ - "p", - "inflows", - "efficiency", - "d_min_up", - "cost", - }, + variables={"x"}, + parameters=set(), ) - expression_str = "x[t+1-t]" with pytest.raises( AntaresParseException,