Skip to content

Commit

Permalink
Add numerically safe implementation of unit_vector
Browse files Browse the repository at this point in the history
Fixes #4.
  • Loading branch information
schmrlng committed Dec 3, 2023
1 parent ac4db1d commit 6683248
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 1 deletion.
4 changes: 3 additions & 1 deletion hj_reachability/sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from flax import struct
import jax.numpy as jnp

from hj_reachability import utils

from typing import Any

Array = Any
Expand Down Expand Up @@ -61,7 +63,7 @@ class Ball(BoundedSet):

def extreme_point(self, direction: Array) -> Array:
"""Computes the point `x` in the set such that the dot product `x @ direction` is greatest."""
return self.center + self.radius * direction / jnp.linalg.norm(direction)
return self.center + self.radius * utils.unit_vector(direction)

@property
def bounding_box(self) -> "Box":
Expand Down
32 changes: 32 additions & 0 deletions hj_reachability/sets_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from absl.testing import absltest
import jax
import numpy as np

from hj_reachability import sets


class SetsTest(absltest.TestCase):

def setUp(self):
np.random.seed(0)

def test_box(self):
box = sets.Box(np.ones(3), 2 * np.ones(3))
np.testing.assert_allclose(box.extreme_point(np.array([1, -1, 1])), np.array([2, 1, 2]))
self.assertTrue(np.all(np.isfinite(box.extreme_point(np.zeros(3)))))
self.assertEqual(box.bounding_box, box)
np.testing.assert_allclose(box.max_magnitudes, 2 * np.ones(3))
self.assertEqual(box.ndim, 3)

def test_ball(self):
ball = sets.Ball(np.ones(3), np.sqrt(3))
np.testing.assert_allclose(ball.extreme_point(np.array([1, -1, 1])), np.array([2, 0, 2]), atol=1e-6)
self.assertTrue(np.all(np.isfinite(ball.extreme_point(np.zeros(3)))))
jax.tree_map(np.testing.assert_allclose, ball.bounding_box,
sets.Box((1 - np.sqrt(3)) * np.ones(3), (1 + np.sqrt(3)) * np.ones(3)))
np.testing.assert_allclose(ball.max_magnitudes, (1 + np.sqrt(3)) * np.ones(3))
self.assertEqual(ball.ndim, 3)


if __name__ == "__main__":
absltest.main()
8 changes: 8 additions & 0 deletions hj_reachability/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import functools

import jax
import jax.numpy as jnp
import numpy as np

from typing import Any, Callable, Iterable, List, Mapping, Optional, TypeVar, Union
Expand Down Expand Up @@ -56,3 +57,10 @@ def get_axis_sequence(axis_array: np.ndarray) -> List:
vmap_kwargs = jax.tree_util.tree_transpose(jax.tree_util.tree_structure(multivmap_kwargs), axis_sequence_structure,
jax.tree_map(get_axis_sequence, multivmap_kwargs))
return functools.reduce(lambda f, kwargs: jax.vmap(f, **kwargs), vmap_kwargs, fun)


def unit_vector(x):
"""Normalizes a vector `x`, returning a unit vector in the same direction, or a zero vector if `x` is zero."""
norm2 = jnp.sum(jnp.square(x))
iszero = norm2 < jnp.finfo(jnp.zeros(()).dtype).eps**2
return jnp.where(iszero, jnp.zeros_like(x), x / jnp.sqrt(jnp.where(iszero, 1, norm2)))
40 changes: 40 additions & 0 deletions hj_reachability/utils_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from absl.testing import absltest
import jax
import jax.numpy as jnp
import numpy as np

from hj_reachability import utils


class UtilsTest(absltest.TestCase):

def setUp(self):
np.random.seed(0)

def test_multivmap(self):
a = np.random.random((3, 4, 5, 6))
np.testing.assert_allclose(utils.multivmap(jnp.max, np.array([0, 1]))(a), np.max(a, (2, 3)))
np.testing.assert_allclose(utils.multivmap(jnp.max, np.array([0, 1, 2]))(a), np.max(a, -1))
np.testing.assert_allclose(utils.multivmap(jnp.max, np.array([0, 1, 3]), np.array([0, 1, 2]))(a), np.max(a, 2))
np.testing.assert_allclose(
utils.multivmap(jnp.max, np.array([1, 0, 2]), np.array([0, 1, 2]))(a),
np.max(a, 3).swapaxes(0, 1))
np.testing.assert_allclose(
utils.multivmap(jnp.max, np.array([3, 2]), np.array([0, 1]))(a),
np.max(a, (0, 1)).swapaxes(0, 1))

def test_unit_vector(self):
unsafe_unit_vector = lambda x: x / jnp.linalg.norm(x, axis=-1, keepdims=True)
for d in range(1, 4):
np.testing.assert_array_equal(utils.unit_vector(np.zeros(d)), np.zeros(d))
self.assertTrue(np.all(np.isfinite(jax.jacobian(utils.unit_vector)(np.zeros(d)))))
self.assertTrue(np.all(np.isnan(jax.jacobian(unsafe_unit_vector)(np.zeros(d)))))
a = np.random.random((100, d))
np.testing.assert_allclose(jax.vmap(utils.unit_vector)(a), unsafe_unit_vector(a), atol=1e-6)
np.testing.assert_allclose(jax.vmap(jax.jacobian(utils.unit_vector))(a),
jax.vmap(jax.jacobian(unsafe_unit_vector))(a),
atol=1e-6)


if __name__ == "__main__":
absltest.main()

0 comments on commit 6683248

Please sign in to comment.