Commit

Test against complex inputs
jessegrabowski committed Oct 21, 2024
1 parent 6c7f1ba commit 1a33bdc
Showing 1 changed file with 38 additions and 16 deletions.
tests/tensor/test_slinalg.py: 38 additions & 16 deletions
@@ -537,15 +537,22 @@ def test_solve_discrete_lyapunov(
         precision = int(dtype[-2:])  # 64 or 32
         dtype = f"complex{int(2 * precision)}"
 
+    A1, A2 = rng.normal(size=(2, *shape)).astype(dtype)
+    Q1, Q2 = rng.normal(size=(2, *shape)).astype(dtype)
+
+    if use_complex:
+        A = A1 + 1j * A2
+        Q = Q1 + 1j * Q2
+    else:
+        A = A1
+        Q = Q1
+
     a = pt.tensor(name="a", shape=shape, dtype=dtype)
     q = pt.tensor(name="q", shape=shape, dtype=dtype)
 
     x = solve_discrete_lyapunov(a, q, method=method)
     f = function([a, q], x)
 
-    A = rng.normal(size=shape).astype(dtype)
-    Q = rng.normal(size=shape).astype(dtype)
-
     X = f(A, Q)
     Q_recovered = vec_recover_Q(A, X, continuous=False)
 
@@ -561,15 +568,12 @@ def test_solve_discrete_lyapunov_gradient(
 ):
     if config.floatX == "float32":
         pytest.skip(reason="Not enough precision in float32 to get a good gradient")
-
-    rng = np.random.default_rng(utt.fetch_seed())
-    dtype = config.floatX
     if use_complex:
-        precision = int(dtype[-2:])  # 64 or 32
-        dtype = f"complex{int(2 * precision)}"
+        pytest.skip(reason="Complex numbers are not supported in the gradient test")
 
-    A = rng.normal(size=shape).astype(dtype)
-    Q = rng.normal(size=shape).astype(dtype)
+    rng = np.random.default_rng(utt.fetch_seed())
+    A = rng.normal(size=shape).astype(config.floatX)
+    Q = rng.normal(size=shape).astype(config.floatX)
 
     utt.verify_grad(
         functools.partial(solve_discrete_lyapunov, method=method),
@@ -579,14 +583,29 @@ def test_solve_discrete_lyapunov_gradient(
 
 
 @pytest.mark.parametrize("shape", [(5, 5), (5, 5, 5)], ids=["matrix", "batched"])
-def test_solve_continuous_lyapunov(shape: tuple[int]):
+@pytest.mark.parametrize("use_complex", [False, True], ids=["float", "complex"])
+def test_solve_continuous_lyapunov(shape: tuple[int], use_complex: bool):
     rng = np.random.default_rng(utt.fetch_seed())
-    a = pt.tensor(name="a", shape=shape)
-    q = pt.tensor(name="q", shape=shape)
+
+    dtype = config.floatX
+    if use_complex:
+        precision = int(dtype[-2:])  # 64 or 32
+        dtype = f"complex{int(2 * precision)}"
+
+    A1, A2 = rng.normal(size=(2, *shape)).astype(dtype)
+    Q1, Q2 = rng.normal(size=(2, *shape)).astype(dtype)
+
+    if use_complex:
+        A = A1 + 1j * A2
+        Q = Q1 + 1j * Q2
+    else:
+        A = A1
+        Q = Q1
+
+    a = pt.tensor(name="a", shape=shape, dtype=dtype)
+    q = pt.tensor(name="q", shape=shape, dtype=dtype)
     f = function([a, q], [solve_continuous_lyapunov(a, q)])
 
-    A = rng.normal(size=shape).astype(config.floatX)
-    Q = rng.normal(size=shape).astype(config.floatX)
     X = f(A, Q)
 
     Q_recovered = vec_recover_Q(A, X, continuous=True)
@@ -596,9 +615,12 @@ def test_solve_continuous_lyapunov(shape: tuple[int]):
 
 
 @pytest.mark.parametrize("shape", [(5, 5), (5, 5, 5)], ids=["matrix", "batched"])
-def test_solve_continuous_lyapunov_grad(shape: tuple[int]):
+@pytest.mark.parametrize("use_complex", [False, True], ids=["float", "complex"])
+def test_solve_continuous_lyapunov_grad(shape: tuple[int], use_complex):
     if config.floatX == "float32":
         pytest.skip(reason="Not enough precision in float32 to get a good gradient")
+    if use_complex:
+        pytest.skip(reason="Complex numbers are not supported in the gradient test")
 
     rng = np.random.default_rng(utt.fetch_seed())
     A = rng.normal(size=shape).astype(config.floatX)
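
Note: the round-trip these tests rely on is the Lyapunov identity itself. vec_recover_Q, a helper defined elsewhere in test_slinalg.py, rebuilds Q from the solver output so it can be compared against the input. Below is a minimal NumPy/SciPy sketch of the same check on complex inputs; it is independent of this commit and assumes SciPy's documented conventions (discrete: A X A^H - X + Q = 0; continuous: A X + X A^H = Q). The rescaling of A is only to keep the example well conditioned and is not part of the commit.

import numpy as np
from scipy import linalg

rng = np.random.default_rng(0)
A = rng.normal(size=(5, 5)) + 1j * rng.normal(size=(5, 5))
Q = rng.normal(size=(5, 5)) + 1j * rng.normal(size=(5, 5))
A = 0.5 * A / np.linalg.norm(A, 2)  # shrink the spectrum so both equations are well conditioned

# Discrete Lyapunov: X satisfies A X A^H - X + Q = 0, so Q should come back as X - A X A^H.
X = linalg.solve_discrete_lyapunov(A, Q)
np.testing.assert_allclose(X - A @ X @ A.conj().T, Q, rtol=1e-8, atol=1e-8)

# Continuous Lyapunov: X satisfies A X + X A^H = Q, so Q should come back as A X + X A^H.
X = linalg.solve_continuous_lyapunov(A, Q)
np.testing.assert_allclose(A @ X + X @ A.conj().T, Q, rtol=1e-8, atol=1e-8)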
