Skip to content

Commit

Permalink
Appease ViPy (Vieira-py type checking)
Browse files Browse the repository at this point in the history
  • Loading branch information
jessegrabowski committed Oct 22, 2024
1 parent 1a33bdc commit a43886e
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 18 deletions.
2 changes: 1 addition & 1 deletion pytensor/tensor/rewriting/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -972,7 +972,7 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node):
return [eye_input * (non_eye_input**0.5)]


@node_rewriter([_solve_bilinear_discrete_lyapunov])
@node_rewriter([_solve_bilinear_discrete_lyapunov]) # type: ignore
def jax_bilinaer_lyapunov_to_direct(fgraph: FunctionGraph, node: Apply):
"""
Replace BilinearSolveDiscreteLyapunov with a direct computation that is supported by JAX
Expand Down
68 changes: 51 additions & 17 deletions pytensor/tensor/slinalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -777,6 +777,14 @@ def perform(self, node, inputs, outputs):


class SolveContinuousLyapunov(Op):
"""
Solves a continuous Lyapunov equation, :math:`AX + XA^H = B`, for :math:`X`.
Continuous-time Lyapunov equations are special cases of Sylvester equations, :math:`AX + XB = C`, and can be solved
efficiently using the Bartels-Stewart algorithm. For more details, see the docstring for
scipy.linalg.solve_continuous_lyapunov
"""

__props__ = ()
gufunc_signature = "(m,m),(m,m)->(m,m)"

Expand Down Expand Up @@ -815,6 +823,14 @@ def grad(self, inputs, output_grads):


class BilinearSolveDiscreteLyapunov(Op):
"""
Solves a discrete Lyapunov equation, :math:`AXA^H - X = Q`, for :math:`X`.
The solution is computed by first transforming the discrete-time problem into a continuous-time form. The
continuous-time Lyapunov equation is a special case of a Sylvester equation, and can be efficiently solved. For more
details, see the docstring for scipy.linalg.solve_discrete_lyapunov
"""

gufunc_signature = "(m,m),(m,m)->(m,m)"

def make_node(self, A, B):
Expand Down Expand Up @@ -861,7 +877,17 @@ def grad(self, inputs, output_grads):
)


def _direct_solve_discrete_lyapunov(A, Q) -> TensorVariable:
def _direct_solve_discrete_lyapunov(
A: TensorVariable, Q: TensorVariable
) -> TensorVariable:
r"""
Directly solve the discrete Lyapunov equation :math:`A X A^H - X = Q` using the Kronecker method of Magnus and
Neudecker.
This involves constructing and inverting an intermediate matrix :math:`A \otimes A`, with shape :math:`N^2 \times N^2`.
As a result, this method scales poorly with the size of :math:`N`, and should be avoided for large :math:`N`.
"""

if A.type.dtype.startswith("complex"):
AxA = kron(A, A.conj())
else:
Expand All @@ -876,17 +902,17 @@ def _direct_solve_discrete_lyapunov(A, Q) -> TensorVariable:


def solve_discrete_lyapunov(
A: TensorVariable,
Q: TensorVariable,
A: TensorLike,
Q: TensorLike,
method: Literal["direct", "bilinear"] = "bilinear",
) -> TensorVariable:
"""Solve the discrete Lyapunov equation :math:`A X A^H - X = Q`.
Parameters
----------
A: TensorVariable
A: TensorLike
Square matrix of shape N x N
Q: TensorVariable
Q: TensorLike
Square matrix of shape N x N
method: str, one of ``"direct"`` or ``"bilinear"``
Solver method used. ``"direct"`` solves the problem directly via matrix inversion. This has a pure
Expand All @@ -910,7 +936,8 @@ def solve_discrete_lyapunov(

if method == "direct":
signature = BilinearSolveDiscreteLyapunov.gufunc_signature
return pt.vectorize(_direct_solve_discrete_lyapunov, signature=signature)(A, Q)
X = pt.vectorize(_direct_solve_discrete_lyapunov, signature=signature)(A, Q)
return cast(TensorVariable, X)

elif method == "bilinear":
return cast(TensorVariable, _solve_bilinear_discrete_lyapunov(A, Q))
Expand All @@ -919,15 +946,15 @@ def solve_discrete_lyapunov(
raise ValueError(f"Unknown method {method}")


def solve_continuous_lyapunov(A: TensorVariable, Q: TensorVariable) -> TensorVariable:
def solve_continuous_lyapunov(A: TensorLike, Q: TensorLike) -> TensorVariable:
"""
Solve the continuous Lyapunov equation :math:`A X + X A^H + Q = 0`.
Parameters
----------
A: TensorVariable
A: TensorLike
Square matrix of shape ``N x N``.
Q: TensorVariable
Q: TensorLike
Square matrix of shape ``N x N``.
Returns
Expand Down Expand Up @@ -1002,24 +1029,31 @@ def grad(self, inputs, output_grads):


def solve_discrete_are(
A: TensorVariable,
B: TensorVariable,
Q: TensorVariable,
R: TensorVariable,
A: TensorLike,
B: TensorLike,
Q: TensorLike,
R: TensorLike,
enforce_Q_symmetric: bool = False,
) -> TensorVariable:
"""
Solve the discrete Algebraic Riccati equation :math:`A^TXA - X - (A^TXB)(R + B^TXB)^{-1}(B^TXA) + Q = 0`.
Discrete-time Algebraic Riccati equations arise in the context of optimal control and filtering problems, as the
solution to Linear-Quadratic Regulators (LQR), Linear-Quadratic-Gaussian (LQG) control problems, and as the
steady-state covariance of the Kalman Filter.
Such problems typically have many solutions, but we are generally only interested in the unique *stabilizing*
solution. This stable solution, if it exists, will be returned by this function.
Parameters
----------
A: TensorVariable
A: TensorLike
Square matrix of shape M x M
B: TensorVariable
B: TensorLike
Matrix of shape M x N
Q: TensorVariable
Q: TensorLike
Symmetric square matrix of shape M x M
R: TensorVariable
R: TensorLike
Square matrix of shape N x N
enforce_Q_symmetric: bool
If True, the provided Q matrix is transformed to 0.5 * (Q + Q.T) to ensure symmetry
Expand Down

0 comments on commit a43886e

Please sign in to comment.