Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Internal change #236

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions learned_optimization/circular_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from typing import Generic, Tuple, TypeVar

import jax
from jax import tree_util
import jax.numpy as jnp

CircularBufferState = collections.namedtuple("CircularBufferState",
Expand Down Expand Up @@ -54,7 +55,7 @@ def build_one(x):
tiled = jnp.tile(expanded, [self.size] + [1] * len(x.shape))
return jnp.asarray(tiled, dtype=x.dtype)

empty_buffer = jax.tree_map(build_one, self.abstract_value)
empty_buffer = tree_util.tree_map(build_one, self.abstract_value)
return CircularBufferState(
idx=jnp.asarray(0, jnp.int64),
values=(empty_buffer,
Expand All @@ -71,7 +72,8 @@ def do_update(src, to_set):
else:
return src.at[idx, :].set(to_set)

new_jax_array = jax.tree_map(do_update, state.values, (value, state.idx))
new_jax_array = tree_util.tree_map(do_update, state.values,
(value, state.idx))
return CircularBufferState(idx=state.idx + 1, values=new_jax_array)

def _reorder(self, vals, idx):
Expand Down Expand Up @@ -100,13 +102,13 @@ def stack_reorder(self, state: CircularBufferState) -> Tuple[T, jnp.ndarray]: #
candidate = jnp.clip((state.values[1] - state.idx + self.size), -1,
self.size)
mask = self._reorder(jnp.where(candidate == -1, 0, 1), state.idx)
return jax.tree_map(lambda x: self._reorder(x, state.idx),
state.values[0]), mask
return tree_util.tree_map(lambda x: self._reorder(x, state.idx),
state.values[0]), mask

@functools.partial(jax.jit, static_argnums=(0,))
def gather_from_present(
self, state: CircularBufferState, idxs: jnp.ndarray) -> T: # pytype: disable=invalid-annotation
"""Get the values from for each idx in the past."""
offset = (idxs % self.size)
idx = (state.idx + offset) % self.size
return jax.tree_map(lambda x: x[idx], state.values[0])
return tree_util.tree_map(lambda x: x[idx], state.values[0])