Skip to content

Commit

Permalink
Add a few failing tests
Browse files Browse the repository at this point in the history
Signed-off-by: Sylvain Leclerc <sylvain.leclerc@rte-france.com>
  • Loading branch information
sylvlecl committed Mar 15, 2024
1 parent 7bd061b commit 73f8301
Showing 1 changed file with 71 additions and 50 deletions.
121 changes: 71 additions & 50 deletions tests/andromede/expressions/parsing/test_expression_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,13 @@
# SPDX-License-Identifier: MPL-2.0
#
# This file is part of the Antares project.
from typing import Set

import pytest
from antlr4 import CommonTokenStream, InputStream

from andromede.expression import ExpressionNode, literal, param, print_expr, var
from andromede.expression.equality import expressions_equal
from andromede.expression.expression import ExpressionRange, port_field
from andromede.expression.parsing.antlr.ExprLexer import ExprLexer
from andromede.expression.parsing.antlr.ExprParser import ExprParser
from andromede.expression.parsing.parse_expression import (
AntaresParseException,
ModelIdentifiers,
Expand All @@ -26,43 +24,75 @@


@pytest.mark.parametrize(
"expression_str, expected",
"variables, parameters, expression_str, expected",
[
("1 + 2", literal(1) + 2),
("1 - 2", literal(1) - 2),
("1 - 3 + 4 - 2", literal(1) - 3 + 4 - 2),
({}, {}, "1 + 2", literal(1) + 2),
({}, {}, "1 - 2", literal(1) - 2),
({}, {}, "1 - 3 + 4 - 2", literal(1) - 3 + 4 - 2),
(
{"x"},
{"p"},
"1 + 2 * x = p",
literal(1) + 2 * var("x") == param("p"),
),
(
{},
{},
"port.f <= 0",
port_field("port", "f") <= 0,
),
("sum(x)", var("x").sum()),
("x[-1]", var("x").eval(-literal(1))),
("x[-1..5]", var("x").eval(ExpressionRange(-literal(1), literal(5)))),
("x[1]", var("x").eval(1)),
("x[t-1]", var("x").shift(-literal(1))),
({"x"}, {}, "sum(x)", var("x").sum()),
({"x"}, {}, "x[-1]", var("x").eval(-literal(1))),
(
{"x"},
{},
"x[-1..5]",
var("x").eval(ExpressionRange(-literal(1), literal(5))),
),
({"x"}, {}, "x[1]", var("x").eval(1)),
({"x"}, {}, "x[t-1]", var("x").shift(-literal(1))),
(
{"x"},
{},
"x[t-1, t+4]",
var("x").shift([-literal(1), literal(4)]),
),
(
{"x"},
{},
"x[t-1, t, t+4]",
var("x").shift([-literal(1), literal(0), literal(4)]),
),
("x[t-1..t+5]", var("x").shift(ExpressionRange(-literal(1), literal(5)))),
("x[t-1..t]", var("x").shift(ExpressionRange(-literal(1), literal(0)))),
("x[t..t+5]", var("x").shift(ExpressionRange(literal(0), literal(5)))),
("x[t]", var("x")),
("x[t+p]", var("x").shift(param("p"))),
(
{"x"},
{},
"x[t-1..t+5]",
var("x").shift(ExpressionRange(-literal(1), literal(5))),
),
(
{"x"},
{},
"x[t-1..t]",
var("x").shift(ExpressionRange(-literal(1), literal(0))),
),
(
{"x"},
{},
"x[t..t+5]",
var("x").shift(ExpressionRange(literal(0), literal(5))),
),
({"x"}, {}, "x[t]", var("x")),
({"x"}, {"p"}, "x[t+p]", var("x").shift(param("p"))),
(
{"x"},
{},
"sum(x[-1..5])",
var("x").eval(ExpressionRange(-literal(1), literal(5))).sum(),
),
("sum_connections(port.f)", port_field("port", "f").sum_connections()),
({}, {}, "sum_connections(port.f)", port_field("port", "f").sum_connections()),
(
{"level", "injection", "withdrawal"},
{"inflows", "efficiency"},
"level - level[-1] - efficiency * injection + withdrawal = inflows",
var("level")
- var("level").eval(-literal(1))
Expand All @@ -71,59 +101,50 @@
== param("inflows"),
),
(
{"nb_start", "nb_on"},
{"d_min_up"},
"sum(nb_start[-d_min_up + 1 .. 0]) <= nb_on",
var("nb_start")
.eval(ExpressionRange(-param("d_min_up") + 1, literal(0)))
.sum()
<= var("nb_on"),
),
(
{"generation"},
{"cost"},
"expec(sum(cost * generation))",
(param("cost") * var("generation")).sum().expec(),
),
],
)
def test_parsing_visitor(expression_str: str, expected: ExpressionNode):
identifiers = ModelIdentifiers(
variables={
"x",
"level",
"injection",
"withdrawal",
"nb_start",
"nb_on",
"generation",
},
parameters={"p", "inflows", "efficiency", "d_min_up", "cost"},
)

def test_parsing_visitor(
variables: Set[str],
parameters: Set[str],
expression_str: str,
expected: ExpressionNode,
):
identifiers = ModelIdentifiers(variables, parameters)
expr = parse_expression(expression_str, identifiers)
print()
print(print_expr(expr))
assert expressions_equal(expr, expected)


def test_parse_cancellation_err():
@pytest.mark.parametrize(
"expression_str",
[
"1**3",
"1 6",
"x[t+1-t]",
"x[2*t]",
],
)
def test_parse_cancellation_should_throw(expression_str: str):
# Console log error is displayed !
identifiers = ModelIdentifiers(
variables={
"x",
"level",
"injection",
"withdrawal",
"nb_start",
"nb_on",
"generation",
},
parameters={
"p",
"inflows",
"efficiency",
"d_min_up",
"cost",
},
variables={"x"},
parameters=set(),
)
expression_str = "x[t+1-t]"

with pytest.raises(
AntaresParseException,
Expand Down

0 comments on commit 73f8301

Please sign in to comment.