Skip to content

Commit

Permalink
Add slogdet implementation (#28)
Browse files Browse the repository at this point in the history
* Add slogdet implementation

* add compare_allclose decorator to test_slogdet; compare_allclose now manage tuple results.

* add compare_allclose decorator to test_slogdet; compare_allclose now manage tuple results.

* make black

* do not check dimension in slogdet (since raw implementation accept multi-dimensional arrays)

* merge compare_allclose_tuple and compare_allclose

* restore compare_allclose

* test sign and logdet separatly

* add test on stack_matrices

* correct typing

* remove manual comparaison with reference values since compare_all already compare all implementations with numpy implementation

* factorize tests with parameterize fixture

* factorize slogdet test using parameterize fixture and itertools.product

* add test on initial 10x10 matrix

* add ids to slogdet tests

* reverse order of array and output in test

* swap compare_allclose and pytest.mark.parametrize
  • Loading branch information
eserie authored Mar 6, 2021
1 parent 6ef7313 commit 850a905
Show file tree
Hide file tree
Showing 8 changed files with 51 additions and 1 deletion.
1 change: 0 additions & 1 deletion docs/.vuepress/eagerpy

This file was deleted.

4 changes: 4 additions & 0 deletions eagerpy/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,10 @@ def crossentropy(logits: TensorType, labels: TensorType) -> TensorType:
return logits.crossentropy(labels)


def slogdet(matrix: TensorType) -> Tuple[TensorType, TensorType]:
return matrix.slogdet()


@overload
def value_and_grad_fn(
t: TensorType, f: Callable[..., TensorType]
Expand Down
4 changes: 4 additions & 0 deletions eagerpy/tensor/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,10 @@ def crossentropy(self: TensorType, labels: TensorType) -> TensorType:
).squeeze(axis=1)
return type(self)(ces)

def slogdet(self: TensorType) -> Tuple[TensorType, TensorType]:
sign, logabsdet = np.linalg.slogdet(self.raw)
return type(self)(sign), type(self)(logabsdet)

@overload
def _value_and_grad_fn(
self: TensorType, f: Callable[..., TensorType]
Expand Down
4 changes: 4 additions & 0 deletions eagerpy/tensor/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,10 @@ def crossentropy(self: TensorType, labels: TensorType) -> TensorType:
).squeeze(axis=1)
return type(self)(ces)

def slogdet(self: TensorType) -> Tuple[TensorType, TensorType]:
sign, logabsdet = np.linalg.slogdet(self.raw)
return type(self)(sign), type(self)(logabsdet)

@overload
def _value_and_grad_fn(
self: TensorType, f: Callable[..., TensorType]
Expand Down
4 changes: 4 additions & 0 deletions eagerpy/tensor/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,10 @@ def crossentropy(self: TensorType, labels: TensorType) -> TensorType:
torch.nn.functional.cross_entropy(self.raw, labels.raw, reduction="none")
)

def slogdet(self: TensorType) -> Tuple[TensorType, TensorType]:
sign, logabsdet = torch.slogdet(self.raw)
return type(self)(sign), type(self)(logabsdet)

@overload
def _value_and_grad_fn(
self: TensorType, f: Callable[..., TensorType]
Expand Down
4 changes: 4 additions & 0 deletions eagerpy/tensor/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,10 @@ def isinf(self: TensorType) -> TensorType:
def crossentropy(self: TensorType, labels: TensorType) -> TensorType:
...

@abstractmethod
def slogdet(matrix: TensorType) -> Tuple[TensorType, TensorType]:
...

@overload
def _value_and_grad_fn(
self: TensorType, f: Callable[..., TensorType]
Expand Down
4 changes: 4 additions & 0 deletions eagerpy/tensor/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,10 @@ def crossentropy(self: TensorType, labels: TensorType) -> TensorType:
tf.nn.sparse_softmax_cross_entropy_with_logits(labels.raw, self.raw)
)

def slogdet(self: TensorType) -> Tuple[TensorType, TensorType]:
sign, logabsdet = tf.linalg.slogdet(self.raw)
return type(self)(sign), type(self)(logabsdet)

@overload
def _value_and_grad_fn(
self: TensorType, f: Callable[..., TensorType]
Expand Down
27 changes: 27 additions & 0 deletions tests/test_main.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Callable, Dict, Any, Tuple, Union, Optional, cast
import pytest
import functools
import itertools
import numpy as np
import eagerpy as ep
from eagerpy import Tensor
Expand Down Expand Up @@ -1280,6 +1281,32 @@ def test_crossentropy(dummy: Tensor) -> Tensor:
return ep.crossentropy(t, t.argmax(axis=-1))


@pytest.mark.parametrize(
"array, output",
itertools.product(
[
np.array([[1, 2], [3, 4]]),
np.array([[[1, 2], [3, 4]], [[1, 2], [2, 1]], [[1, 3], [3, 1]]]),
np.arange(100).reshape((10, 10)),
],
["sign", "logdet"],
),
ids=map(
lambda *l: "_".join(*l),
itertools.product(
["matrix_finite", "stack_of_matrices", "matrix_infinite"],
["sign", "logdet"],
),
),
)
@compare_allclose
def test_slogdet(dummy: Tensor, array: Tensor, output: str) -> Tensor:
a = ep.from_numpy(dummy, array).float32()
outputs = dict()
outputs["sign"], outputs["logdet"] = ep.slogdet(a)
return outputs[output]


@pytest.mark.parametrize("axis", [0, 1, -1])
@compare_all
def test_stack(t1: Tensor, t2: Tensor, axis: int) -> Tensor:
Expand Down

0 comments on commit 850a905

Please sign in to comment.