Skip to content

Commit

Permalink
Add inner loop aggregation options (#9)
Browse files Browse the repository at this point in the history
* Add aggregation option to inner_loop callback

Also added a inner_loop_kwargs argument to train_loop.

Added tests for inner_loop aggregation and inner_loop_kwargs.

* Fix inconsistent quotes

* Make "last" the default aggregation with dict

* Allow Callable aggregations

* Fix formatting

* fix pre-commit

---------

Co-authored-by: Cristian Garcia <cgarcia.e88@gmail.com>
  • Loading branch information
JamesAllingham and cgarciae authored Aug 24, 2023
1 parent 2d5b0e6 commit f39f265
Show file tree
Hide file tree
Showing 4 changed files with 252 additions and 6 deletions.
80 changes: 74 additions & 6 deletions ciclo/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,23 @@
from dataclasses import dataclass, replace
from datetime import datetime
from enum import Enum, auto
from typing import Any, Callable, Dict, Optional, Tuple, Union, overload
from typing import (
Any,
Callable,
Dict,
List,
Literal,
Mapping,
Optional,
Tuple,
Union,
overload,
)

from pkbar import Kbar
from tqdm import tqdm

from ciclo.logging import Logs
from ciclo.logging import Collection, Entry, History, Logs
from ciclo.loops.loop import (
CallbackOutput,
LoopCallbackBase,
Expand All @@ -22,6 +33,11 @@
from ciclo.types import Batch, S
from ciclo.utils import get_batch_size, is_scalar

AggregationFn = Callable[[List[Any]], Any]
InnerLoopAggregation = Union[
Literal["last", "mean", "sum", "min", "max", "first"], AggregationFn
]


def unavailable_dependency(msg: str) -> Any:
class DependencyNotAvailable(LoopCallbackBase[S]):
Expand All @@ -39,6 +55,22 @@ class OptimizationMode(str, Enum):
max = auto()


def _transpose_history(
log_history: History,
) -> Mapping[Collection, Mapping[Entry, List[Any]]]:
"""Convert a list of (nested) log dictionaries into a (nested) dictionary of lists."""
result = {}
for log_dict in log_history:
for collection, entries in log_dict.items():
if collection not in result:
result[collection] = {}
for entry, value in entries.items():
if entry not in result[collection]:
result[collection][entry] = []
result[collection][entry].append(value)
return result


class inner_loop(LoopCallbackBase[S]):
@overload
def __init__(
Expand All @@ -65,6 +97,9 @@ def __init__(
maybe_loop_fn: Optional[Callable[[S], LoopOutput[S]]] = None,
*,
output_state: bool = False,
aggregation: Union[
InnerLoopAggregation, Mapping[Collection, InnerLoopAggregation]
] = "last",
):
if isinstance(name_or_loop_fn, str):
assert maybe_loop_fn is not None
Expand All @@ -75,17 +110,20 @@ def __init__(
self.name = None
self.loop_fn = name_or_loop_fn
self.output_state = output_state
self.aggregation = aggregation

def __call__(self, state: S) -> Tuple[Logs, S]:
inner_state, log_history, _ = self.loop_fn(state)
logs = log_history[-1] if len(log_history) > 0 else Logs()
logs = _transpose_history(log_history)
logs = Logs(
{
collection: {
k + f"_{self.name}" if self.name else k: v
for k, v in values.items()
entry + f"_{self.name}"
if self.name
else entry: self.__get_aggregation_fn(collection)(values)
for entry, values in entries.items()
}
for collection, values in logs.items()
for collection, entries in logs.items()
if collection != "elapsed"
}
)
Expand All @@ -94,6 +132,36 @@ def __call__(self, state: S) -> Tuple[Logs, S]:
def __loop_callback__(self, loop_state: LoopState[S]) -> CallbackOutput[S]:
return self(loop_state.state)

def __get_aggregation_fn(self, collection: Collection) -> AggregationFn:
if isinstance(self.aggregation, Mapping):
aggregation = self.aggregation.get(collection, "last")
error_message = f"The aggregation ({aggregation}) for collection {collection} must be a str or Callable."
else:
aggregation = self.aggregation
error_message = (
f"The aggregation ({aggregation}) must be a str or Callable."
)

if not (isinstance(aggregation, str) or isinstance(aggregation, Callable)):
raise ValueError(error_message)

if aggregation == "last":
return lambda x: x[-1]
elif aggregation == "mean":
return lambda x: sum(x) / len(x)
elif aggregation == "sum":
return sum
elif aggregation == "min":
return min
elif aggregation == "max":
return max
elif aggregation == "first":
return lambda x: x[0]
elif isinstance(aggregation, Callable):
return aggregation
else:
raise ValueError(f"Invalid aggregation: {aggregation}")


if importlib.util.find_spec("tensorflow") is not None:
from flax.training import checkpoints as flax_checkpoints
Expand Down
5 changes: 5 additions & 0 deletions ciclo/loops/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def train_loop(
catch_keyboard_interrupt: bool = True,
metadata: Optional[Any] = None,
batch_size_fn: Optional[Callable[[List[Tuple[int, ...]]], int]] = None,
inner_loop_kwargs: Optional[Dict[str, Any]] = None,
) -> LoopOutput[S]:
if tasks is None:
tasks = {}
Expand All @@ -83,6 +84,9 @@ def train_loop(
if isinstance(test_duration, int):
test_duration = Period.create(steps=test_duration)

if inner_loop_kwargs is None:
inner_loop_kwargs = {}

additionl_tasks: Dict[ScheduleLike, CallbackOrList] = {}
named_tasks: Dict[str, CallbackOrList] = {}
for schedule in list(tasks.keys()):
Expand Down Expand Up @@ -145,6 +149,7 @@ def train_loop(
stop=test_duration,
batch_size_fn=batch_size_fn,
),
**inner_loop_kwargs,
)
)
test_tasks += named_tasks.pop(ON_EPOCH_END, [])
Expand Down
135 changes: 135 additions & 0 deletions tests/test_callbacks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
import jax.numpy as jnp

import ciclo


def dummy_inner_loop_fn(_):
log_history = [
{
"stateful_metrics": {
"A": jnp.array(1.0, dtype=jnp.float32),
"B": jnp.array(1.0, dtype=jnp.float32),
},
"metrics": {
"C": jnp.array(0.0, dtype=jnp.float32),
"D": jnp.array(0.0, dtype=jnp.float32),
},
"elapsed": {
"steps": 1,
"samples": 1,
},
},
{
"stateful_metrics": {
"A": jnp.array(0.0, dtype=jnp.float32),
},
"metrics": {
"C": jnp.array(1.0, dtype=jnp.float32),
},
"elapsed": {
"steps": 2,
"samples": 2,
},
},
]
return None, log_history, None


class TestCallbacks:
def test_inner_loop_default_aggregation(self):
inner_loop = ciclo.callbacks.inner_loop(
"test",
dummy_inner_loop_fn,
)

log_history, _ = inner_loop(None)

assert log_history == {
"stateful_metrics": {
"A_test": jnp.array(0.0, dtype=jnp.float32),
"B_test": jnp.array(1.0, dtype=jnp.float32),
},
"metrics": {
"C_test": jnp.array(1.0, dtype=jnp.float32),
"D_test": jnp.array(0.0, dtype=jnp.float32),
},
}

def test_inner_loop_callable_aggregation(self):
inner_loop = ciclo.callbacks.inner_loop(
"test",
dummy_inner_loop_fn,
aggregation=sum,
)

log_history, _ = inner_loop(None)

assert log_history == {
"stateful_metrics": {
"A_test": jnp.array(1.0, dtype=jnp.float32),
"B_test": jnp.array(1.0, dtype=jnp.float32),
},
"metrics": {
"C_test": jnp.array(1.0, dtype=jnp.float32),
"D_test": jnp.array(0.0, dtype=jnp.float32),
},
}

def test_inner_loop_mean_aggregation(self):
inner_loop = ciclo.callbacks.inner_loop(
"test",
dummy_inner_loop_fn,
aggregation="mean",
)

log_history, _ = inner_loop(None)

assert log_history == {
"stateful_metrics": {
"A_test": jnp.array(0.5, dtype=jnp.float32),
"B_test": jnp.array(1.0, dtype=jnp.float32),
},
"metrics": {
"C_test": jnp.array(0.5, dtype=jnp.float32),
"D_test": jnp.array(0.0, dtype=jnp.float32),
},
}

def test_inner_loop_aggregation_dict(self):
inner_loop = ciclo.callbacks.inner_loop(
"test",
dummy_inner_loop_fn,
aggregation={"stateful_metrics": "sum", "metrics": "min"},
)

log_history, _ = inner_loop(None)

assert log_history == {
"stateful_metrics": {
"A_test": jnp.array(1.0, dtype=jnp.float32),
"B_test": jnp.array(1.0, dtype=jnp.float32),
},
"metrics": {
"C_test": jnp.array(0.0, dtype=jnp.float32),
"D_test": jnp.array(0.0, dtype=jnp.float32),
},
}

inner_loop = ciclo.callbacks.inner_loop(
"test",
dummy_inner_loop_fn,
aggregation={"stateful_metrics": "first"},
)

log_history, _ = inner_loop(None)

assert log_history == {
"stateful_metrics": {
"A_test": jnp.array(1.0, dtype=jnp.float32),
"B_test": jnp.array(1.0, dtype=jnp.float32),
},
"metrics": {
"C_test": jnp.array(1.0, dtype=jnp.float32),
"D_test": jnp.array(0.0, dtype=jnp.float32),
},
}
38 changes: 38 additions & 0 deletions tests/test_loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,41 @@ def data():

assert a_list == list(range(1, 4))
assert b_list == list(range(-1, -4, -1))

def test_inner_loop_kwargs(self):
def increment(state, key):
state[key] += 1
logs = ciclo.logs()
logs.add_metric(key, state[key])
return logs, state

state = {"a": 0}

state, history, _ = ciclo.train_loop(
state,
ciclo.elapse(range(1)),
{
ciclo.on_test_step: lambda state: increment(state, "a"),
},
test_dataset=lambda: ciclo.elapse(range(4)),
epoch_duration=1,
stop=1,
)

assert history[0]["metrics"]["a_test"] == 4

state = {"a": 0}

state, history, _ = ciclo.train_loop(
state,
ciclo.elapse(range(1)),
{
ciclo.on_test_step: lambda state: increment(state, "a"),
},
test_dataset=lambda: ciclo.elapse(range(4)),
epoch_duration=1,
stop=1,
inner_loop_kwargs={"aggregation": "sum"},
)

assert history[0]["metrics"]["a_test"] == 10

0 comments on commit f39f265

Please sign in to comment.