Skip to content

Commit

Permalink
Handle constraints
Browse files Browse the repository at this point in the history
  • Loading branch information
tbittar committed Jun 25, 2024
1 parent ef9c8b6 commit 4f0542c
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 33 deletions.
58 changes: 56 additions & 2 deletions src/andromede/expression/linear_expression_efficient.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,12 +295,28 @@ def __str__(self) -> str:

return result

def __le__(self, rhs: Any) -> "StandaloneConstraint":
return StandaloneConstraint(
expression=self - rhs,
lower_bound=literal(-float("inf")),
upper_bound=literal(0),
)

def __ge__(self, rhs: Any) -> "ExpressionNodeEfficient":
return StandaloneConstraint(
expression=self - rhs,
lower_bound=literal(0),
upper_bound=literal(float("inf")),
)

# def __eq__(self, rhs: Any) -> "ExpressionNodeEfficient": # type: ignore
# return _apply_if_node(rhs, lambda x: ComparisonNode(self, x, Comparator.EQUAL))

def __eq__(self, rhs: object) -> bool:
return (
isinstance(rhs, LinearExpressionEfficient)
and expressions_equal(self.constant, rhs.constant)
and self.terms
== rhs.terms
and self.terms == rhs.terms
)

def __iadd__(
Expand Down Expand Up @@ -466,6 +482,44 @@ def is_constant(self) -> bool:
return not self.terms


@dataclass
class StandaloneConstraint:
"""
A standalone constraint, with rugid initialization.
"""

expression: LinearExpressionEfficient
lower_bound: LinearExpressionEfficient
upper_bound: LinearExpressionEfficient

def __init__(
self,
expression: LinearExpressionEfficient,
lower_bound: LinearExpressionEfficient,
upper_bound: LinearExpressionEfficient,
) -> None:

for bound in [lower_bound, upper_bound]:
if bound is not None and not bound.is_constant():
raise ValueError(
f"The bounds of a constraint should not contain variables, {print_expr(bound)} was given."
)

self.expression = expression
if lower_bound is not None:
self.lower_bound = lower_bound
else:
self.lower_bound = literal(-float("inf"))

if upper_bound is not None:
self.upper_bound = upper_bound
else:
self.upper_bound = literal(float("inf"))

def __str__(self) -> str:
return f"{str(self.lower_bound)} <= {str(self.expression)} <= {str(self.upper_bound)}"


def _wrap_in_linear_expr(obj: Any) -> LinearExpressionEfficient:
if isinstance(obj, LinearExpressionEfficient):
return obj
Expand Down
51 changes: 25 additions & 26 deletions src/andromede/model/constraint.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,24 @@
#
# This file is part of the Antares project.
from dataclasses import dataclass
from typing import Any, Optional
from typing import Any, Optional, Union

from andromede.expression.degree import is_constant
from andromede.expression.equality import (
expressions_equal,
expressions_equal_if_present,
)
from andromede.expression.expression import (
Comparator,
ComparisonNode,
ExpressionNode,

# from andromede.expression.expression import (
# Comparator,
# ComparisonNode,
# ExpressionNode,
# literal,
# )
from andromede.expression.expression_efficient import Comparator, ComparisonNode
from andromede.expression.linear_expression_efficient import (
LinearExpressionEfficient,
StandaloneConstraint,
literal,
)
from andromede.expression.print import print_expr
Expand All @@ -36,42 +43,31 @@ class Constraint:
"""

name: str
expression: ExpressionNode
lower_bound: ExpressionNode
upper_bound: ExpressionNode
expression: LinearExpressionEfficient
lower_bound: LinearExpressionEfficient
upper_bound: LinearExpressionEfficient
context: ProblemContext

def __init__(
self,
name: str,
expression: ExpressionNode,
lower_bound: Optional[ExpressionNode] = None,
upper_bound: Optional[ExpressionNode] = None,
expression: Union[LinearExpressionEfficient, StandaloneConstraint],
lower_bound: Optional[LinearExpressionEfficient] = None,
upper_bound: Optional[LinearExpressionEfficient] = None,
context: ProblemContext = ProblemContext.OPERATIONAL,
) -> None:
self.name = name
self.context = context

if isinstance(expression, ComparisonNode):
if isinstance(expression, StandaloneConstraint):
if lower_bound is not None or upper_bound is not None:
raise ValueError(
"Both comparison between two expressions and a bound are specfied, set either only a comparison between expressions or a single linear expression with bounds."
)

merged_expr = expression.left - expression.right
self.expression = merged_expr

if expression.comparator == Comparator.LESS_THAN:
# lhs - rhs <= 0
self.upper_bound = literal(0)
self.lower_bound = literal(-float("inf"))
elif expression.comparator == Comparator.GREATER_THAN:
# lhs - rhs >= 0
self.lower_bound = literal(0)
self.upper_bound = literal(float("inf"))
else: # lhs - rhs == 0
self.lower_bound = literal(0)
self.upper_bound = literal(0)
self.expression = expression.expression
self.lower_bound = expression.lower_bound
self.upper_bound = expression.upper_bound
else:
for bound in [lower_bound, upper_bound]:
if bound is not None and not is_constant(bound):
Expand Down Expand Up @@ -99,3 +95,6 @@ def __eq__(self, other: Any) -> bool:
and expressions_equal_if_present(self.lower_bound, other.lower_bound)
and expressions_equal_if_present(self.upper_bound, other.upper_bound)
)

def __str__(self) -> str:
return f"{str(self.lower_bound)} <= {str(self.expression)} <= {str(self.upper_bound)}"
16 changes: 11 additions & 5 deletions tests/unittests/expressions/test_expressions_efficient.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,11 @@
from andromede.expression.indexing_structure import IndexingStructure
from andromede.expression.linear_expression_efficient import (
LinearExpressionEfficient,
StandaloneConstraint,
TermEfficient,
comp_param,
comp_var,
literal,
param,
var,
)
Expand Down Expand Up @@ -214,6 +216,7 @@ def test_addition(
) -> None:
assert e1 + e2 == expected


@pytest.mark.parametrize(
"e1, e2, expected",
[
Expand Down Expand Up @@ -318,14 +321,17 @@ def test_linear_expression_equality(
# )


def test_standalone_constraint() -> None:
cst = StandaloneConstraint(var("x"), literal(0), literal(10))

assert str(cst) == "0 <= +x <= + 10"


def test_comparison() -> None:
x = var("x")
p = param("p")
expr: Constraint = (
5 * x + 3
) >= p - 2 ## Overloading operator to return a constraint object !

assert str(expr) == "((5.0 * x) + 3.0) >= (p - 2.0)"
expr = (5 * x + 3) >= p - 2
assert str(expr) == "0 <= 5.0x + (3.0 - (p - 2.0)) <= + inf"


class StructureProvider(IndexingStructureProvider):
Expand Down

0 comments on commit 4f0542c

Please sign in to comment.