Skip to content
This repository has been archived by the owner on Jan 30, 2023. It is now read-only.

Commit

Permalink
_cvxpy_{NegExpression,quad_over_lin,NonPos,NonNeg,Inequality}_sage_: New
Browse files Browse the repository at this point in the history
  • Loading branch information
Matthias Koeppe committed Jun 15, 2021
1 parent a60e225 commit bb624fb
Showing 1 changed file with 217 additions and 4 deletions.
221 changes: 217 additions & 4 deletions src/sage/interfaces/cvxpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ def _cvxpy_Constant_sage_(self):
sage: x._sage_().parent()
Vector space of dimension 5 over Real Field with 53 bits of precision
A matrix variable with shape ``(3, 2)``::
A matrix constant with shape ``(3, 2)``::
sage: A = cp.Constant([[1, 2, 3], [4, 5, 6]]); A
Constant(CONSTANT, NONNEGATIVE, (3, 2))
Expand Down Expand Up @@ -494,7 +494,44 @@ def _cvxpy_AddExpression_sage_(self):
return self._sage_object


def _cvxpy_NegExpression_sage_(self):
r"""
Return an equivalent Sage object.
EXAMPLES::
sage: from sage.interfaces.cvxpy import cvxpy_init
sage: cvxpy_init()
sage: import cvxpy as cp
A scalar variable::
sage: a = cp.Variable(name='a'); a
Variable(())
sage: s_a = a._sage_(); s_a
a
sage: neg_a = -a; neg_a
Expression(AFFINE, UNKNOWN, ())
sage: s_neg_a = neg_a._sage_(); s_neg_a
-a
sage: s_neg_a is neg_a._sage_()
True
sage: s_neg_a == -s_a
True
"""
try:
return self._sage_object
except AttributeError:
pass

expr = self.args[0]._sage_()

self._sage_object = -expr
return self._sage_object


# Elementwise
# https://www.cvxpy.org/api_reference/cvxpy.atoms.elementwise.html

def _cvxpy_log1p_sage_(self):
r"""
Expand Down Expand Up @@ -561,9 +598,16 @@ def _cvxpy_power_sage_(self):
sage: square_x = cp.square(x); square_x
Expression(CONVEX, NONNEGATIVE, (5,))
sage: s_square_x = square_x._sage_(); s_square_x
Traceback (most recent call last):
...
Coordinate functions (x_0^2, x_1^2, x_2^2, x_3^2, x_4^2) on the Chart (dom_x, (x_0, x_1, x_2, x_3, x_4))
Coordinate functions (x_0^2, x_1^2, x_2^2, x_3^2, x_4^2)
on the Chart (dom_x, (x_0, x_1, x_2, x_3, x_4))
Testing ``cp.sqrt``::
sage: sqrt_x = cp.sqrt(x); sqrt_x
Expression(CONCAVE, NONNEGATIVE, (5,))
sage: s_sqrt_x = sqrt_x._sage_(); s_sqrt_x
Coordinate functions (sqrt(x_0), sqrt(x_1), sqrt(x_2), sqrt(x_3), sqrt(x_4))
on the Chart (dom_x, (x_0, x_1, x_2, x_3, x_4))
"""
try:
return self._sage_object
Expand Down Expand Up @@ -623,6 +667,164 @@ def _cvxpy_log_det_sage_(self):
return self._sage_object


def _cvxpy_quad_over_lin_sage_(self):
"""
Return an equivalent Sage object.
EXAMPLES::
sage: from sage.interfaces.cvxpy import cvxpy_init
sage: cvxpy_init()
sage: import cvxpy as cp
A scalar variable::
sage: a = cp.Variable(name='a'); a
Variable(())
A vector variable with shape ``(3,)``::
sage: x = cp.Variable(3, name='x'); x
Variable((3,))
sage: s_x = x._sage_(); s_x
Coordinate functions (x_0, x_1, x_2) on the Chart (dom_x, (x_0, x_1, x_2))
A matrix variable with shape ``(3, 2)``::
sage: A = cp.Variable((3, 2), name='A'); A
Variable((3, 2))
Test ``sum_squares``::
sage: sqr = cp.sum_squares(a); sqr
Expression(CONVEX, NONNEGATIVE, ())
sage: s_sqr = sqr._sage_(); s_sqr
a^2
sage: sos = cp.sum_squares(x); sos
Expression(CONVEX, NONNEGATIVE, ())
sage: s_sos = sos._sage_(); s_sos
x_0^2 + x_1^2 + x_2^2
sage: matsos = cp.sum_squares(A); matsos
Expression(CONVEX, NONNEGATIVE, ())
sage: s_matsos = matsos._sage_(); s_matsos
A_0_0^2 + A_0_1^2 + A_1_0^2 + A_1_1^2 + A_2_0^2 + A_2_1^2
"""
try:
return self._sage_object
except AttributeError:
pass

x = self.args[0]._sage_()
y = self.args[1]._sage_()

if self.args[0].shape:
sos = sum(x_i ** 2 for x_i in x.list())
else:
sos = x ** 2

if y == 1:
# WISHLIST: cvxpy creates a fresh Constant for the denominator 1
# every time that sum_squares is used. It would be nice if that
# was a unique object -- that we could then provide with a _sage_
# method that gives an exact 1.
self._sage_object = sos
else:
self._sage_object = sos / y
return self._sage_object


# Constraints

def _cvxpy_Inequality_sage_(self):
"""
Return an equivalent Sage object (relation expression).
EXAMPLES::
sage: from sage.interfaces.cvxpy import cvxpy_init
sage: cvxpy_init()
sage: import cvxpy as cp
A scalar variable::
sage: t = cp.Variable(name='t', nonneg=True); t
Variable((), nonneg=True)
A vector variable with shape ``(2,)``::
sage: x = cp.Variable(2, name='x'); x
Variable((2,))
We map them to a common space::
sage: E.<s_t, s_x_1, s_x_2> = EuclideanSpace()
sage: cart = E.default_chart()
sage: t._sage_object = cart.function(s_t)
sage: t._sage_()
s_t
sage: x._sage_object = cart.multifunction(s_x_1, s_x_2)
sage: x._sage_()
Coordinate functions (s_x_1, s_x_2) on the Chart (E^3, (s_t, s_x_1, s_x_2))
An interval::
sage: t_le_3 = (t <= 3); t_le_3
Inequality(Variable((), nonneg=True))
sage: s_t_le_3 = t_le_3._sage_(); s_t_le_3
s_t <= 3.00000000000000
The second-order cone::
sage: iscreamuscream = (cp.square(t) >= cp.sum_squares(x)); iscreamuscream
Inequality(Expression(CONVEX, NONNEGATIVE, ()))
sage: s_iscreamuscream = iscreamuscream._sage_(); s_iscreamuscream
s_x_1^2 + s_x_2^2 <= s_t^2
"""
try:
return self._sage_object
except AttributeError:
pass

lhs = self.args[0]._sage_()
rhs = self.args[1]._sage_()

if self.shape:
raise NotImplementedError

def to_SR(x):
try:
return x.expr(method='SR')
except AttributeError:
return x

self._sage_object = (to_SR(lhs) <= to_SR(rhs))
return self._sage_object


def _cvxpy_NonPos_sage_(self):
try:
return self._sage_object
except AttributeError:
pass

expr = self.args[0]._sage_()
self._sage_object = (expr.expr(method='SR') <= 0)
return self._sage_object


def _cvxpy_NonNeg_sage_(self):
try:
return self._sage_object
except AttributeError:
pass

expr = self.args[0]._sage_()
self._sage_object = (0 <= expr.expr(method='SR'))
return self._sage_object


# Monkey patching

from sage.repl.ipython_extension import run_once
Expand Down Expand Up @@ -651,6 +853,9 @@ def cvxpy_init():
from cvxpy.atoms.affine.add_expr import AddExpression
AddExpression._sage_ = _cvxpy_AddExpression_sage_

from cvxpy.atoms.affine.unary_operators import NegExpression
NegExpression._sage_ = _cvxpy_NegExpression_sage_

from cvxpy.atoms.elementwise.log1p import log1p
log1p._sage_ = _cvxpy_log1p_sage_

Expand All @@ -659,3 +864,11 @@ def cvxpy_init():

from cvxpy.atoms.log_det import log_det
log_det._sage_ = _cvxpy_log_det_sage_

from cvxpy.atoms.quad_over_lin import quad_over_lin
quad_over_lin._sage_ = _cvxpy_quad_over_lin_sage_

from cvxpy.constraints.nonpos import NonPos, NonNeg, Inequality
NonPos._sage_ = _cvxpy_NonPos_sage_
NonNeg._sage_ = _cvxpy_NonNeg_sage_
Inequality._sage_ = _cvxpy_Inequality_sage_

0 comments on commit bb624fb

Please sign in to comment.