Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add inner loop aggregation options #9

Merged
merged 6 commits into from
Aug 24, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 50 additions & 6 deletions ciclo/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
from dataclasses import dataclass, replace
from datetime import datetime
from enum import Enum, auto
from typing import Any, Callable, Dict, Optional, Tuple, Union, overload
from typing import Any, Callable, Dict, List, Literal, Mapping, Optional, Tuple, Union, overload

from pkbar import Kbar
from tqdm import tqdm

from ciclo.logging import Logs
from ciclo.logging import Collection, Entry, History, Logs
from ciclo.loops.loop import (
CallbackOutput,
LoopCallbackBase,
Expand All @@ -23,6 +23,9 @@
from ciclo.utils import get_batch_size, is_scalar


InnerLoopAggregation = Literal["last", "mean", "sum", "min", "max", "first"]
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice!



def unavailable_dependency(msg: str) -> Any:
class DependencyNotAvailable(LoopCallbackBase[S]):
def __init__(self, *args: Any, **kwargs: Any) -> None:
Expand All @@ -39,6 +42,20 @@
max = auto()


def _transpose_history(log_history: History) -> Mapping[Collection, Mapping[Entry, List[Any]]]:
"""Convert a list of (nested) log dictionaries into a (nested) dictionary of lists."""
result = {}
for log_dict in log_history:
for collection, entries in log_dict.items():
if collection not in result:
result[collection] = {}
for entry, value in entries.items():
if entry not in result[collection]:
result[collection][entry] = []
result[collection][entry].append(value)
return result


class inner_loop(LoopCallbackBase[S]):
@overload
def __init__(
Expand All @@ -65,6 +82,9 @@
maybe_loop_fn: Optional[Callable[[S], LoopOutput[S]]] = None,
*,
output_state: bool = False,
aggregation: Union[
InnerLoopAggregation, Mapping[Collection, InnerLoopAggregation]
] = "last",
):
if isinstance(name_or_loop_fn, str):
assert maybe_loop_fn is not None
Expand All @@ -75,17 +95,20 @@
self.name = None
self.loop_fn = name_or_loop_fn
self.output_state = output_state
self.aggregation = aggregation

def __call__(self, state: S) -> Tuple[Logs, S]:
inner_state, log_history, _ = self.loop_fn(state)
logs = log_history[-1] if len(log_history) > 0 else Logs()
logs = _transpose_history(log_history)
logs = Logs(
{
collection: {
k + f"_{self.name}" if self.name else k: v
for k, v in values.items()
entry + f"_{self.name}"
if self.name
else entry: self.__get_aggregation_fn(collection)(values)
for entry, values in entries.items()
}
for collection, values in logs.items()
for collection, entries in logs.items()
if collection != "elapsed"
}
)
Expand All @@ -94,6 +117,27 @@
def __loop_callback__(self, loop_state: LoopState[S]) -> CallbackOutput[S]:
return self(loop_state.state)

def __get_aggregation_fn(self, collection: Collection) -> Callable[[List[Any]], Any]:
if isinstance(self.aggregation, str):
aggregation = self.aggregation
else:
aggregation = self.aggregation.get(collection, "last")

if aggregation == "last":
return lambda x: x[-1]
elif aggregation == "mean":
return lambda x: sum(x) / len(x)
elif aggregation == "sum":
return sum
elif aggregation == "min":
return min
elif aggregation == "max":
return max

Check warning on line 135 in ciclo/callbacks.py

View check run for this annotation

Codecov / codecov/patch

ciclo/callbacks.py#L135

Added line #L135 was not covered by tests
elif aggregation == "first":
return lambda x: x[0]
else:
raise ValueError(f"Invalid aggregation: {aggregation}")

Check warning on line 139 in ciclo/callbacks.py

View check run for this annotation

Codecov / codecov/patch

ciclo/callbacks.py#L139

Added line #L139 was not covered by tests


if importlib.util.find_spec("tensorflow") is not None:
from flax.training import checkpoints as flax_checkpoints
Expand Down
5 changes: 5 additions & 0 deletions ciclo/loops/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def train_loop(
catch_keyboard_interrupt: bool = True,
metadata: Optional[Any] = None,
batch_size_fn: Optional[Callable[[List[Tuple[int, ...]]], int]] = None,
inner_loop_kwargs: Optional[Dict[str, Any]] = None,
) -> LoopOutput[S]:
if tasks is None:
tasks = {}
Expand All @@ -83,6 +84,9 @@ def train_loop(
if isinstance(test_duration, int):
test_duration = Period.create(steps=test_duration)

if inner_loop_kwargs is None:
inner_loop_kwargs = {}

additionl_tasks: Dict[ScheduleLike, CallbackOrList] = {}
named_tasks: Dict[str, CallbackOrList] = {}
for schedule in list(tasks.keys()):
Expand Down Expand Up @@ -145,6 +149,7 @@ def train_loop(
stop=test_duration,
batch_size_fn=batch_size_fn,
),
**inner_loop_kwargs,
)
)
test_tasks += named_tasks.pop(ON_EPOCH_END, [])
Expand Down
115 changes: 115 additions & 0 deletions tests/test_callbacks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
import jax.numpy as jnp

import ciclo


def dummy_inner_loop_fn(_):
log_history = [
{
"stateful_metrics": {
"A": jnp.array(1.0, dtype=jnp.float32),
"B": jnp.array(1.0, dtype=jnp.float32),
},
"metrics": {
"C": jnp.array(0.0, dtype=jnp.float32),
"D": jnp.array(0.0, dtype=jnp.float32),
},
"elapsed": {
"steps": 1,
"samples": 1,
},
},
{
"stateful_metrics": {
"A": jnp.array(0.0, dtype=jnp.float32),
},
"metrics": {
"C": jnp.array(1.0, dtype=jnp.float32),
},
"elapsed": {
"steps": 2,
"samples": 2,
},
},
]
return None, log_history, None


class TestCallbacks:
def test_inner_loop_default_aggregation(self):
inner_loop = ciclo.callbacks.inner_loop(
"test",
dummy_inner_loop_fn,
)

log_history, _ = inner_loop(None)

assert log_history == {
"stateful_metrics": {
"A_test": jnp.array(0.0, dtype=jnp.float32),
"B_test": jnp.array(1.0, dtype=jnp.float32),
},
"metrics": {
"C_test": jnp.array(1.0, dtype=jnp.float32),
"D_test": jnp.array(0.0, dtype=jnp.float32),
},
}

def test_inner_loop_mean_aggregation(self):
inner_loop = ciclo.callbacks.inner_loop(
"test",
dummy_inner_loop_fn,
aggregation="mean",
)

log_history, _ = inner_loop(None)

assert log_history == {
"stateful_metrics": {
"A_test": jnp.array(0.5, dtype=jnp.float32),
"B_test": jnp.array(1.0, dtype=jnp.float32),
},
"metrics": {
"C_test": jnp.array(0.5, dtype=jnp.float32),
"D_test": jnp.array(0.0, dtype=jnp.float32),
},
}

def test_inner_loop_aggregation_dict(self):
inner_loop = ciclo.callbacks.inner_loop(
"test",
dummy_inner_loop_fn,
aggregation={"stateful_metrics": "sum", "metrics": "min"},
)

log_history, _ = inner_loop(None)

assert log_history == {
"stateful_metrics": {
"A_test": jnp.array(1.0, dtype=jnp.float32),
"B_test": jnp.array(1.0, dtype=jnp.float32),
},
"metrics": {
"C_test": jnp.array(0.0, dtype=jnp.float32),
"D_test": jnp.array(0.0, dtype=jnp.float32),
},
}

inner_loop = ciclo.callbacks.inner_loop(
"test",
dummy_inner_loop_fn,
aggregation={"stateful_metrics": "first"},
)

log_history, _ = inner_loop(None)

assert log_history == {
"stateful_metrics": {
"A_test": jnp.array(1.0, dtype=jnp.float32),
"B_test": jnp.array(1.0, dtype=jnp.float32),
},
"metrics": {
"C_test": jnp.array(1.0, dtype=jnp.float32),
"D_test": jnp.array(0.0, dtype=jnp.float32),
},
}
38 changes: 38 additions & 0 deletions tests/test_loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,41 @@ def data():

assert a_list == list(range(1, 4))
assert b_list == list(range(-1, -4, -1))

def test_inner_loop_kwargs(self):
def increment(state, key):
state[key] += 1
logs = ciclo.logs()
logs.add_metric(key, state[key])
return logs, state

state = {"a": 0}

state, history, _ = ciclo.train_loop(
state,
ciclo.elapse(range(1)),
{
ciclo.on_test_step: lambda state: increment(state, "a"),
},
test_dataset=lambda: ciclo.elapse(range(4)),
epoch_duration=1,
stop=1,
)

assert history[0]["metrics"]["a_test"] == 4

state = {"a": 0}

state, history, _ = ciclo.train_loop(
state,
ciclo.elapse(range(1)),
{
ciclo.on_test_step: lambda state: increment(state, "a"),
},
test_dataset=lambda: ciclo.elapse(range(4)),
epoch_duration=1,
stop=1,
inner_loop_kwargs={"aggregation": "sum"},
)

assert history[0]["metrics"]["a_test"] == 10
Loading