-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing
11 changed files
with
1,817 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,139 @@ | ||
# coding=utf-8 | ||
# Copyright 2024 The Google Research Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""Utilities for model checkpointing.""" | ||
|
||
import dataclasses | ||
import re | ||
from typing import Any, Optional, Sequence, TypeVar | ||
|
||
from clu import checkpoint as checkpoint_lib | ||
from etils import epath | ||
import flax | ||
import grain.python as grain | ||
import grain.tensorflow as tfgrain | ||
from orbax import checkpoint as ocp | ||
import tensorflow as tf | ||
|
||
|
||
# Generic type of the state object accepted and returned by `load_state`.
T = TypeVar('T')
|
||
|
||
class MixedMultihostCheckpoint(checkpoint_lib.MultihostCheckpoint):
  """Like MultihostCheckpoint, but with a single source of FLAX weights.

  TF settings are restored per-host as in the base class.

  This prevents the model from loading potentially inconsistent weights
  saved by other hosts. Weights might be inconsistent when they are saved
  based on wall-clock time instead of step count.
  """

  def load_state(
      self, state: Optional[T], checkpoint: Optional[str] = None
  ) -> T:
    """Restores the FLAX state from the `checkpoints-0` directory.

    Args:
      state: Target handed to `flax.serialization.from_bytes` for the
        deserialized data.
      checkpoint: Checkpoint to load. It is resolved through
        `_checkpoint_or_latest`, so None presumably selects the latest one.

    Returns:
      The deserialized state.

    Raises:
      FileNotFoundError: If the rewritten FLAX file does not exist.
    """
    flax_path = self._flax_path(self._checkpoint_or_latest(checkpoint))
    # Point every host at the weights stored under `checkpoints-0`, whatever
    # numeric suffix this host's own checkpoint directory carries.
    flax_path = re.sub('checkpoints-[0-9]*', 'checkpoints-0', flax_path)
    if not tf.io.gfile.exists(flax_path):
      # Include the probed path: `checkpoint` is None by default, which alone
      # makes for an uninformative message.
      raise FileNotFoundError(
          f'Checkpoint {checkpoint} does not exist (looked for {flax_path})'
      )
    with tf.io.gfile.GFile(flax_path, 'rb') as f:
      return flax.serialization.from_bytes(state, f.read())
|
||
|
||
def get_checkpoint_manager(
    workdir: epath.PathLike,
    item_names: Sequence[str],
) -> ocp.CheckpointManager:
  """Builds a checkpoint manager rooted at `<workdir>/checkpoints`.

  Args:
    workdir: Working directory under which the `checkpoints` folder lives.
    item_names: Names of the items handled by the manager.

  Returns:
    A checkpoint manager configured with `create=True` and
    `cleanup_tmp_directories=True`.
  """
  manager_options = ocp.CheckpointManagerOptions(
      create=True, cleanup_tmp_directories=True
  )
  return ocp.CheckpointManager(
      epath.Path(workdir) / 'checkpoints',
      item_names=item_names,
      options=manager_options,
  )
|
||
|
||
def save_checkpoint(
    manager: ocp.CheckpointManager,
    state: Any,
    step: int,
    pygrain_checkpointers: Sequence[str] = ('train_iter',),
    wait_until_finished: bool = True,
):
  """Writes `state` as a composite checkpoint for the given step.

  Args:
    manager: Checkpoint manager that performs the save.
    state: Mapping from item name to the data to store for that item.
    step: Step under which the checkpoint is saved.
    pygrain_checkpointers: Item names saved with `PyGrainCheckpointSave`; all
      other items are saved with `StandardSave`.
    wait_until_finished: If True, blocks until the manager reports that the
      save has finished.
  """
  composite_items = {
      name: (
          grain.PyGrainCheckpointSave(value)
          if name in pygrain_checkpointers
          else ocp.args.StandardSave(value)
      )
      for name, value in state.items()
  }
  manager.save(step, args=ocp.args.Composite(**composite_items))
  if not wait_until_finished:
    return
  manager.wait_until_finished()
|
||
|
||
def restore_checkpoint(
    manager: ocp.CheckpointManager,
    state: Any,
    step: int | None = None,
    pygrain_checkpointers: Sequence[str] = ('train_iter',),
) -> Any:
  """Reads a composite checkpoint back through the manager.

  Args:
    manager: Checkpoint manager that performs the restore.
    state: Mapping from item name to the object passed as restore target for
      that item.
    step: Step to restore from. If None, the manager's latest step is used.
    pygrain_checkpointers: Item names restored with `PyGrainCheckpointRestore`;
      all other items are restored with `StandardRestore`.

  Returns:
    Whatever `manager.restore` returns for the composite arguments.
  """
  composite_items = {
      name: (
          grain.PyGrainCheckpointRestore(target)
          if name in pygrain_checkpointers
          else ocp.args.StandardRestore(target)
      )
      for name, target in state.items()
  }
  if step is None:
    target_step = manager.latest_step()
  else:
    target_step = step
  return manager.restore(
      target_step, args=ocp.args.Composite(**composite_items)
  )
|
||
|
||
class TfGrainCheckpointHandler(tfgrain.OrbaxCheckpointHandler):
  """tfgrain checkpoint handler that accepts `TfGrainCheckpointArgs`.

  Both methods unwrap `args.item` and delegate to the parent implementation.
  """

  def save(self, directory: epath.Path, args: 'TfGrainCheckpointArgs') -> None:
    """Saves the item wrapped in `args` to `directory`."""
    wrapped_item = args.item
    return super().save(directory, wrapped_item)

  def restore(
      self, directory: epath.Path, args: 'TfGrainCheckpointArgs'
  ) -> tfgrain.TfGrainDatasetIterator:
    """Restores from `directory`, passing the wrapped item to the parent."""
    wrapped_item = args.item
    return super().restore(directory, wrapped_item)
|
||
|
||
@ocp.args.register_with_handler(  # pytype:disable=wrong-arg-types
    TfGrainCheckpointHandler,
    for_save=True,
    for_restore=True,
)
@dataclasses.dataclass
class TfGrainCheckpointArgs(ocp.args.CheckpointArgs):
  """Checkpoint args registered for both save and restore.

  They are handled by `TfGrainCheckpointHandler`, which forwards `item` to the
  parent tfgrain handler.
  """

  # Object forwarded unchanged to the handler's save/restore.
  item: Any
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,138 @@ | ||
# coding=utf-8 | ||
# Copyright 2024 The Google Research Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""Helper tools for config files. | ||
While configs should remain simple and self-explanatory, it can also be very | ||
useful to augment the configs with a bit of logic that helps organizing | ||
complicated sweeps. | ||
This module contains shared code that allows for powerful uncluttered configs. | ||
""" | ||
|
||
from typing import Any, Sequence | ||
|
||
import ml_collections as mlc | ||
|
||
|
||
def parse_arg(arg, lazy=False, **spec):
  """Turns the single string passed to `get_config` into typed options.

  Example use in the config file:

    import big_vision.configs.common as bvcc
    def get_config(arg):
      arg = bvcc.parse_arg(arg,
                           res=(224, int),
                           runlocal=False,
                           schedule='short',
      )
      # ...
      config.shuffle_buffer = 250_000 if not arg.runlocal else 50

  Ways that values can be passed when launching:

    --config amazing.py:runlocal,schedule=long,res=128
    --config amazing.py:res=128
    --config amazing.py:runlocal  # A boolean needs no value for "true".
    --config amazing.py:runlocal=False  # Explicit false boolean.
    --config amazing.py:128  # The first spec entry may be passed unnamed alone.

  Booleans are converted strictly: 'True'/'true' become True, while
  'False'/'false'/'' become False.

  Args:
    arg: The string argument that is passed to get_config (may be None).
    lazy: If True, options that are not in `spec` are accepted and their type
      is auto-detected (bool, int, float, falling back to str).
    **spec: Names and default values of the expected options. A tuple value is
      read as `(default, converter)`, where the converter is called on the raw
      string. Otherwise the converter is derived from the default's type.

  Returns:
    ConfigDict object with the extracted, type-converted values.

  Raises:
    ValueError: If options outside of `spec` remain and `lazy` is False.
  """
  arg = arg or ''  # Treat None like an empty argument string.

  # Bring every spec entry into the `(default, converter)` layout.
  normalized_spec = {}
  for name, entry in spec.items():
    if not isinstance(entry, tuple):
      entry = (entry, _get_type(entry))
    normalized_spec[name] = entry

  result = mlc.ConfigDict(type_safe=False)  # For convenient dot-access only.

  # Convenience forms for a lone token without `=` or `,`.
  is_lone_token = bool(arg) and ',' not in arg and '=' not in arg
  if is_lone_token:
    if arg in normalized_spec or not normalized_spec:
      # (think :runlocal) A known name, or any name without a spec, is a flag.
      arg = f'{arg}=True'
    else:
      # Otherwise it is the value of the first spec entry; this relies on
      # dicts preserving insertion order (Py3.7+).
      first_name = next(iter(normalized_spec))
      arg = f'{first_name}={arg}'

  # Split the argument string into raw key/value strings.
  raw_kv = {}
  for token in arg.split(','):
    if not token:
      continue
    key, equals_sign, value = token.partition('=')
    raw_kv[key] = value if equals_sign else 'True'

  # Walk the spec, converting provided values and defaulting the rest.
  for name, (default, type_fn) in normalized_spec.items():
    if name in raw_kv:
      result[name] = type_fn(raw_kv.pop(name))
    else:
      result[name] = default

  # Whatever is left was not declared in the spec.
  if raw_kv and not lazy:
    raise ValueError(f'Unhandled config args remain: {raw_kv}')
  for name, raw_value in raw_kv.items():
    result[name] = _autotype(raw_value)

  return result
|
||
|
||
def _get_type(v): | ||
"""Returns type of v and for boolean returns a strict bool function.""" | ||
if isinstance(v, bool): | ||
def strict_bool(x): | ||
assert x.lower() in {'true', 'false', ''} | ||
return x.lower() == 'true' | ||
return strict_bool | ||
return type(v) | ||
|
||
|
||
def _autotype(x): | ||
"""Auto-converts string to bool/int/float if possible.""" | ||
assert isinstance(x, str) | ||
if x.lower() in {'true', 'false'}: | ||
return x.lower() == 'true' # Returns as bool. | ||
try: | ||
return int(x) # Returns as int. | ||
except ValueError: | ||
try: | ||
return float(x) # Returns as float. | ||
except ValueError: | ||
return x # Returns as str. | ||
|
||
|
||
def sequence_to_string(x: Sequence[Any], separator: str = ',') -> str:
  """Joins the `str()` form of each element of `x` with `separator`."""
  return separator.join(map(str, x))
|
||
|
||
def string_to_sequence(x: str, separator: str = ',') -> Sequence[Any]:
  """Splits `x` on `separator` and auto-converts each piece.

  Each piece is passed through `_autotype`, so it becomes a bool, int or float
  when it parses as one and stays a str otherwise.
  """
  pieces = x.split(separator)
  return list(map(_autotype, pieces))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
# coding=utf-8 | ||
# Copyright 2024 The Google Research Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""Tests for config_util.""" | ||
|
||
from absl.testing import absltest | ||
from absl.testing import parameterized | ||
from connectomics.jax import config_util as cutil | ||
|
||
|
||
class ConfigUtilTest(parameterized.TestCase):
  """Tests for argument parsing and sequence/string conversion helpers."""

  @parameterized.parameters(False, True)
  def test_parse_arg_works(self, lazy):
    # Key order matters: an unnamed lone value is assigned to the first entry.
    spec = {'res': 224, 'lr': 0.1, 'runlocal': False, 'schedule': 'short'}

    # Each case: (arg string, expected runlocal, schedule, res, lr).
    cases = [
        (None, False, 'short', 224, 0.1),
        ('', False, 'short', 224, 0.1),
        ('runlocal=True', True, 'short', 224, 0.1),
        ('runlocal=False', False, 'short', 224, 0.1),
        ('runlocal=', False, 'short', 224, 0.1),
        ('runlocal', True, 'short', 224, 0.1),
        ('res=128', False, 'short', 128, 0.1),
        ('128', False, 'short', 128, 0.1),
        ('schedule=long', False, 'long', 224, 0.1),
        ('runlocal,schedule=long,res=128', True, 'long', 128, 0.1),
    ]

    for arg, runlocal, schedule, res, lr in cases:
      result = cutil.parse_arg(arg, lazy=lazy, **spec)
      # Values must match ...
      self.assertEqual(result.runlocal, runlocal)
      self.assertEqual(result.schedule, schedule)
      self.assertEqual(result.res, res)
      self.assertEqual(result.lr, lr)
      # ... and carry the type of their spec default.
      self.assertIsInstance(result.runlocal, bool)
      self.assertIsInstance(result.schedule, str)
      self.assertIsInstance(result.res, int)
      self.assertIsInstance(result.lr, float)

  @parameterized.parameters(
      (None, {}, {}),
      (None, {'res': 224}, {'res': 224}),
      ('640', {'res': 224}, {'res': 640}),
      ('runlocal', {}, {'runlocal': True}),
      (
          'res=640,lr=0.1,runlocal=false,schedule=long',
          {},
          {'res': 640, 'lr': 0.1, 'runlocal': False, 'schedule': 'long'},
      ),
  )
  def test_lazy_parse_arg_works(self, arg, spec, expected):
    parsed = cutil.parse_arg(arg, lazy=True, **spec)
    self.assertEqual(dict(parsed), expected)

  def test_sequence_to_string(self):
    joined = cutil.sequence_to_string(['a', True, 1, 1.0])
    self.assertEqual(joined, 'a,True,1,1.0')

  def test_string_to_sequence(self):
    parsed = cutil.string_to_sequence('a,True,1,1.0')
    self.assertEqual(parsed, ['a', True, 1, 1.0])
|
||
# Run the tests through absltest when this file is executed as a script.
if __name__ == '__main__':
  absltest.main()
Oops, something went wrong.