Skip to content

Commit

Permalink
Add jax/flax utilities.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 682316546
  • Loading branch information
jan-matthis authored and copybara-github committed Oct 8, 2024
1 parent 241455e commit 7f4ae8f
Show file tree
Hide file tree
Showing 11 changed files with 1,817 additions and 0 deletions.
139 changes: 139 additions & 0 deletions connectomics/jax/checkpoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for model checkpointing."""

import dataclasses
import re
from typing import Any, Optional, Sequence, TypeVar

from clu import checkpoint as checkpoint_lib
from etils import epath
import flax
import grain.python as grain
import grain.tensorflow as tfgrain
from orbax import checkpoint as ocp
import tensorflow as tf


# Generic type variable tying the `state` argument of `load_state` to its
# return type.
T = TypeVar('T')


class MixedMultihostCheckpoint(checkpoint_lib.MultihostCheckpoint):
  """Like MultihostCheckpoint, but with a single source of FLAX weights.

  TF settings are restored per-host as in the base class.

  This prevents the model from loading potentially inconsistent weights
  saved by other hosts. Weights might be inconsistent when they are saved
  based on wall-clock time instead of step count.
  """

  def load_state(
      self, state: Optional[T], checkpoint: Optional[str] = None
  ) -> T:
    """Loads the FLAX state from the `checkpoints-0` copy of a checkpoint.

    Args:
      state: Target object handed to `flax.serialization.from_bytes`.
      checkpoint: Checkpoint to load. It is resolved through
        `_checkpoint_or_latest` (presumably the latest checkpoint is used
        when None -- confirm against the base class).

    Returns:
      The state deserialized from the FLAX file.

    Raises:
      FileNotFoundError: If the rewritten FLAX path does not exist.
    """
    flax_path = self._flax_path(self._checkpoint_or_latest(checkpoint))
    # Whatever `checkpoints-<digits>` directory the path names, always read
    # the weights stored under `checkpoints-0`.
    flax_path = re.sub(r'checkpoints-[0-9]*', 'checkpoints-0', flax_path)
    if not tf.io.gfile.exists(flax_path):
      # Report the path that was actually checked: `checkpoint` defaults to
      # None, which used to yield "Checkpoint None does not exist".
      raise FileNotFoundError(f'Checkpoint {flax_path} does not exist')
    with tf.io.gfile.GFile(flax_path, 'rb') as f:
      return flax.serialization.from_bytes(state, f.read())


def get_checkpoint_manager(
    workdir: epath.PathLike,
    item_names: Sequence[str],
) -> ocp.CheckpointManager:
  """Builds an Orbax checkpoint manager rooted at `<workdir>/checkpoints`.

  Args:
    workdir: Working directory; the manager uses its `checkpoints` subpath.
    item_names: Item names forwarded to the manager.

  Returns:
    A `CheckpointManager` configured with `create=True` and
    `cleanup_tmp_directories=True`.
  """
  root = epath.Path(workdir)
  target_dir = root / 'checkpoints'
  manager_options = ocp.CheckpointManagerOptions(
      create=True, cleanup_tmp_directories=True
  )
  return ocp.CheckpointManager(
      target_dir, item_names=item_names, options=manager_options
  )


def save_checkpoint(
    manager: ocp.CheckpointManager,
    state: Any,
    step: int,
    pygrain_checkpointers: Sequence[str] = ('train_iter',),
    wait_until_finished: bool = True,
):
  """Writes every entry of `state` as one composite checkpoint.

  Args:
    manager: Checkpoint manager to use.
    state: Mapping from item name to the data to be saved.
    step: Step at which to save the data.
    pygrain_checkpointers: Names of items saved via the pygrain checkpointer;
      all other items use the standard save args.
    wait_until_finished: If True, blocks until checkpoint is written.
  """

  def _save_arg(name, value):
    # Items listed in `pygrain_checkpointers` need the pygrain save args.
    if name in pygrain_checkpointers:
      return grain.PyGrainCheckpointSave(value)
    return ocp.args.StandardSave(value)

  per_item_args = {
      name: _save_arg(name, value) for name, value in state.items()
  }
  manager.save(step, args=ocp.args.Composite(**per_item_args))
  if not wait_until_finished:
    return
  manager.wait_until_finished()


def restore_checkpoint(
    manager: ocp.CheckpointManager,
    state: Any,
    step: int | None = None,
    pygrain_checkpointers: Sequence[str] = ('train_iter',),
) -> Any:
  """Reads every entry of `state` back from a composite checkpoint.

  Args:
    manager: Checkpoint manager to use.
    state: Mapping from item name to the data to be restored.
    step: Step to restore. If None, `manager.latest_step()` is used.
    pygrain_checkpointers: Names of items restored via the pygrain
      checkpointer; all other items use the standard restore args.

  Returns:
    Restored data, as returned by `manager.restore`.
  """

  def _restore_arg(name, value):
    # Items listed in `pygrain_checkpointers` need the pygrain restore args.
    if name in pygrain_checkpointers:
      return grain.PyGrainCheckpointRestore(value)
    return ocp.args.StandardRestore(value)

  per_item_args = {
      name: _restore_arg(name, value) for name, value in state.items()
  }
  if step is None:
    target_step = manager.latest_step()
  else:
    target_step = step
  return manager.restore(
      target_step, args=ocp.args.Composite(**per_item_args)
  )


class TfGrainCheckpointHandler(tfgrain.OrbaxCheckpointHandler):
  """Handler that unwraps `TfGrainCheckpointArgs` before delegating to the base."""

  def save(self, directory: epath.Path, args: 'TfGrainCheckpointArgs') -> None:
    """Passes `args.item` to the base class `save` for `directory`."""
    wrapped_item = args.item
    return super().save(directory, wrapped_item)

  def restore(
      self, directory: epath.Path, args: 'TfGrainCheckpointArgs'
  ) -> tfgrain.TfGrainDatasetIterator:
    """Passes `args.item` to the base class `restore` and returns its result."""
    wrapped_item = args.item
    return super().restore(directory, wrapped_item)


@ocp.args.register_with_handler( # pytype:disable=wrong-arg-types
    TfGrainCheckpointHandler, for_save=True, for_restore=True
)
@dataclasses.dataclass
class TfGrainCheckpointArgs(ocp.args.CheckpointArgs):
  """Checkpoint args registered with `TfGrainCheckpointHandler` for save and restore."""

  # Object that `TfGrainCheckpointHandler` forwards to the base handler's
  # save/restore.
  item: Any
138 changes: 138 additions & 0 deletions connectomics/jax/config_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helper tools for config files.
While configs should remain simple and self-explanatory, it can also be very
useful to augment the configs with a bit of logic that helps organize
complicated sweeps.
This module contains shared code that allows for powerful uncluttered configs.
"""

from typing import Any, Sequence

import ml_collections as mlc


def parse_arg(arg, lazy=False, **spec):
  """Turns the single-string `get_config` argument into typed options.

  Example use in the config file:

    from connectomics.jax import config_util

    def get_config(arg):
      arg = config_util.parse_arg(arg,
          res=(224, int),
          runlocal=False,
          schedule='short',
      )
      # ...
      config.shuffle_buffer = 250_000 if not arg.runlocal else 50

  Ways that values can be passed when launching:

    --config amazing.py:runlocal,schedule=long,res=128
    --config amazing.py:res=128
    --config amazing.py:runlocal  # A boolean needs no value for "true".
    --config amazing.py:runlocal=False  # Explicit false boolean.
    --config amazing.py:128  # The first spec entry may be passed unnamed alone.

  Booleans are converted strictly: 'True'/'true' become True and
  'False'/'false'/'' become False.

  Args:
    arg: The string argument that is passed to get_config (may be None).
    lazy: If True, arguments that are not in the spec are accepted and their
      type is auto-detected (bool, int, float, else str).
    **spec: Names and default values of the expected options. A tuple value
      is read as (default, converter), where the converter is called on the
      string. Otherwise the converter is derived from the default value.

  Returns:
    ConfigDict object with extracted type-converted values.

  Raises:
    ValueError: If arguments outside the spec remain and `lazy` is False.
  """
  # A missing argument behaves like an empty string.
  arg = arg or ''

  # Bring every spec entry into (default, converter) form.
  normalized_spec = {}
  for key, entry in spec.items():
    if not isinstance(entry, tuple):
      entry = (entry, _get_type(entry))
    normalized_spec[key] = entry
  spec = normalized_spec

  # type_safe=False: the ConfigDict is only used for convenient dot-access.
  result = mlc.ConfigDict(type_safe=False)

  # Shorthand: a lone token without ',' or '='.
  is_lone_token = bool(arg) and ',' not in arg and '=' not in arg
  if is_lone_token:
    if arg in spec or not spec:
      # A spec name (or anything at all when there is no spec) is a flag.
      arg = f'{arg}=True'
    else:
      # Otherwise it is the value of the first spec entry; dicts keep
      # insertion order.
      first_name = next(iter(spec))
      arg = f'{first_name}={arg}'

  # Split the string into raw key/value strings; a bare key means 'True'.
  raw_kv = {}
  for token in arg.split(','):
    if not token:
      continue
    key, equals, value = token.partition('=')
    raw_kv[key] = value if equals else 'True'

  # Fill in every spec entry from the provided value or its default.
  for name, (default, type_fn) in spec.items():
    val = raw_kv.pop(name, None)
    if val is None:
      result[name] = default
    else:
      result[name] = type_fn(val)

  # Whatever is left was not declared in the spec.
  if raw_kv and not lazy:
    raise ValueError(f'Unhandled config args remain: {raw_kv}')
  for k, v in raw_kv.items():
    result[k] = _autotype(v)

  return result


def _get_type(v):
"""Returns type of v and for boolean returns a strict bool function."""
if isinstance(v, bool):
def strict_bool(x):
assert x.lower() in {'true', 'false', ''}
return x.lower() == 'true'
return strict_bool
return type(v)


def _autotype(x):
"""Auto-converts string to bool/int/float if possible."""
assert isinstance(x, str)
if x.lower() in {'true', 'false'}:
return x.lower() == 'true' # Returns as bool.
try:
return int(x) # Returns as int.
except ValueError:
try:
return float(x) # Returns as float.
except ValueError:
return x # Returns as str.


def sequence_to_string(x: Sequence[Any], separator: str = ',') -> str:
  """Joins the `str()` form of every element of `x` with `separator`."""
  return separator.join(map(str, x))


def string_to_sequence(x: str, separator: str = ',') -> Sequence[Any]:
  """Splits `x` on `separator` and auto-converts each piece via `_autotype`."""
  pieces = x.split(separator)
  return list(map(_autotype, pieces))
81 changes: 81 additions & 0 deletions connectomics/jax/config_util_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for config_util."""

from absl.testing import absltest
from absl.testing import parameterized
from connectomics.jax import config_util as cutil


class ConfigUtilTest(parameterized.TestCase):
  """Tests for argument parsing and sequence/string helpers in config_util."""

  @parameterized.parameters(False, True)
  def test_parse_arg_works(self, lazy):
    spec = {
        'res': 224,
        'lr': 0.1,
        'runlocal': False,
        'schedule': 'short',
    }

    # Each case: (arg string, runlocal, schedule, res, lr).
    cases = (
        (None, False, 'short', 224, 0.1),
        ('', False, 'short', 224, 0.1),
        ('runlocal=True', True, 'short', 224, 0.1),
        ('runlocal=False', False, 'short', 224, 0.1),
        ('runlocal=', False, 'short', 224, 0.1),
        ('runlocal', True, 'short', 224, 0.1),
        ('res=128', False, 'short', 128, 0.1),
        ('128', False, 'short', 128, 0.1),
        ('schedule=long', False, 'long', 224, 0.1),
        ('runlocal,schedule=long,res=128', True, 'long', 128, 0.1),
    )

    for arg, runlocal, schedule, res, lr in cases:
      result = cutil.parse_arg(arg, lazy=lazy, **spec)
      # Values first, then the types they were converted to.
      self.assertEqual(result.runlocal, runlocal)
      self.assertEqual(result.schedule, schedule)
      self.assertEqual(result.res, res)
      self.assertEqual(result.lr, lr)
      self.assertIsInstance(result.runlocal, bool)
      self.assertIsInstance(result.schedule, str)
      self.assertIsInstance(result.res, int)
      self.assertIsInstance(result.lr, float)

  @parameterized.parameters(
      (None, {}, {}),
      (None, {'res': 224}, {'res': 224}),
      ('640', {'res': 224}, {'res': 640}),
      ('runlocal', {}, {'runlocal': True}),
      ('res=640,lr=0.1,runlocal=false,schedule=long', {},
       {'res': 640, 'lr': 0.1, 'runlocal': False, 'schedule': 'long'}),
  )
  def test_lazy_parse_arg_works(self, arg, spec, expected):
    parsed = cutil.parse_arg(arg, lazy=True, **spec)
    self.assertEqual(dict(parsed), expected)

  def test_sequence_to_string(self):
    joined = cutil.sequence_to_string(['a', True, 1, 1.0])
    self.assertEqual(joined, 'a,True,1,1.0')

  def test_string_to_sequence(self):
    parts = cutil.string_to_sequence('a,True,1,1.0')
    self.assertEqual(parts, ['a', True, 1, 1.0])

# Run the tests through absltest when this file is executed as a script.
if __name__ == '__main__':
  absltest.main()
Loading

0 comments on commit 7f4ae8f

Please sign in to comment.