Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add inner loop aggregation options #9

Merged
merged 6 commits into from
Aug 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 74 additions & 6 deletions ciclo/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,23 @@
from dataclasses import dataclass, replace
from datetime import datetime
from enum import Enum, auto
from typing import Any, Callable, Dict, Optional, Tuple, Union, overload
from typing import (
Any,
Callable,
Dict,
List,
Literal,
Mapping,
Optional,
Tuple,
Union,
overload,
)

from pkbar import Kbar
from tqdm import tqdm

from ciclo.logging import Logs
from ciclo.logging import Collection, Entry, History, Logs
from ciclo.loops.loop import (
CallbackOutput,
LoopCallbackBase,
Expand All @@ -22,6 +33,11 @@
from ciclo.types import Batch, S
from ciclo.utils import get_batch_size, is_scalar

# A reducer: receives the list of values collected for one log entry and
# returns the single value that should be reported for that entry.
AggregationFn = Callable[[List[Any]], Any]
# How inner-loop logs are reduced: either the name of one of the built-in
# reductions listed in the Literal below, or a custom AggregationFn.
InnerLoopAggregation = Union[
Literal["last", "mean", "sum", "min", "max", "first"], AggregationFn
]


def unavailable_dependency(msg: str) -> Any:
class DependencyNotAvailable(LoopCallbackBase[S]):
Expand All @@ -39,6 +55,22 @@
max = auto()


def _transpose_history(
    log_history: History,
) -> Mapping[Collection, Mapping[Entry, List[Any]]]:
    """Turn a sequence of nested log dicts into one nested dict of value lists.

    Each item of ``log_history`` maps ``collection -> entry -> value``. The
    result keeps the same two-level nesting, but every entry maps to the list
    of values recorded for it, in the order the items appear in
    ``log_history``. An entry that is absent from some item simply contributes
    no value for that item, so the lists may have different lengths.
    """
    transposed = {}
    for step_logs in log_history:
        for collection, entries in step_logs.items():
            # Register the collection even when it has no entries.
            bucket = transposed.setdefault(collection, {})
            for entry, value in entries.items():
                bucket.setdefault(entry, []).append(value)
    return transposed


class inner_loop(LoopCallbackBase[S]):
@overload
def __init__(
Expand All @@ -65,6 +97,9 @@
maybe_loop_fn: Optional[Callable[[S], LoopOutput[S]]] = None,
*,
output_state: bool = False,
aggregation: Union[
InnerLoopAggregation, Mapping[Collection, InnerLoopAggregation]
] = "last",
):
if isinstance(name_or_loop_fn, str):
assert maybe_loop_fn is not None
Expand All @@ -75,17 +110,20 @@
self.name = None
self.loop_fn = name_or_loop_fn
self.output_state = output_state
self.aggregation = aggregation

def __call__(self, state: S) -> Tuple[Logs, S]:
inner_state, log_history, _ = self.loop_fn(state)
logs = log_history[-1] if len(log_history) > 0 else Logs()
logs = _transpose_history(log_history)
logs = Logs(
{
collection: {
k + f"_{self.name}" if self.name else k: v
for k, v in values.items()
entry + f"_{self.name}"
if self.name
else entry: self.__get_aggregation_fn(collection)(values)
for entry, values in entries.items()
}
for collection, values in logs.items()
for collection, entries in logs.items()
if collection != "elapsed"
}
)
Expand All @@ -94,6 +132,36 @@
def __loop_callback__(self, loop_state: LoopState[S]) -> CallbackOutput[S]:
    """Loop-callback entry point: call this object on the loop's current state."""
    current_state = loop_state.state
    return self(current_state)

def __get_aggregation_fn(self, collection: Collection) -> AggregationFn:
    """Resolve the reducer to apply to the entries of ``collection``.

    ``self.aggregation`` is either a single aggregation (used for every
    collection) or a mapping from collection to aggregation; collections
    missing from the mapping fall back to ``"last"``. An aggregation is
    one of the names ``"last"``, ``"mean"``, ``"sum"``, ``"min"``,
    ``"max"``, ``"first"``, or a callable, which is returned unchanged.

    Raises:
        ValueError: if the selected aggregation is neither a ``str`` nor a
            callable, or is a string that names no built-in reduction.
    """
    per_collection = isinstance(self.aggregation, Mapping)
    if per_collection:
        chosen = self.aggregation.get(collection, "last")
    else:
        chosen = self.aggregation

    if not isinstance(chosen, str) and not isinstance(chosen, Callable):
        if per_collection:
            raise ValueError(
                f"The aggregation ({chosen}) for collection {collection} "
                "must be a str or Callable."
            )
        raise ValueError(
            f"The aggregation ({chosen}) must be a str or Callable."
        )

    # Checked in order with ``==`` so name matching takes precedence over
    # the generic "it is a callable" case below.
    named_reducers = (
        ("last", lambda values: values[-1]),
        ("mean", lambda values: sum(values) / len(values)),
        ("sum", sum),
        ("min", min),
        ("max", max),
        ("first", lambda values: values[0]),
    )
    for label, reducer in named_reducers:
        if chosen == label:
            return reducer

    if isinstance(chosen, Callable):
        return chosen
    raise ValueError(f"Invalid aggregation: {chosen}")

Check warning on line 163 in ciclo/callbacks.py

View check run for this annotation

Codecov / codecov/patch

ciclo/callbacks.py#L163

Added line #L163 was not covered by tests


if importlib.util.find_spec("tensorflow") is not None:
from flax.training import checkpoints as flax_checkpoints
Expand Down
5 changes: 5 additions & 0 deletions ciclo/loops/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def train_loop(
catch_keyboard_interrupt: bool = True,
metadata: Optional[Any] = None,
batch_size_fn: Optional[Callable[[List[Tuple[int, ...]]], int]] = None,
inner_loop_kwargs: Optional[Dict[str, Any]] = None,
) -> LoopOutput[S]:
if tasks is None:
tasks = {}
Expand All @@ -83,6 +84,9 @@ def train_loop(
if isinstance(test_duration, int):
test_duration = Period.create(steps=test_duration)

if inner_loop_kwargs is None:
inner_loop_kwargs = {}

additionl_tasks: Dict[ScheduleLike, CallbackOrList] = {}
named_tasks: Dict[str, CallbackOrList] = {}
for schedule in list(tasks.keys()):
Expand Down Expand Up @@ -145,6 +149,7 @@ def train_loop(
stop=test_duration,
batch_size_fn=batch_size_fn,
),
**inner_loop_kwargs,
)
)
test_tasks += named_tasks.pop(ON_EPOCH_END, [])
Expand Down
135 changes: 135 additions & 0 deletions tests/test_callbacks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
import jax.numpy as jnp

import ciclo


def dummy_inner_loop_fn(_):
    """Fake inner loop used by the tests: ignores its argument.

    Returns ``(None, log_history, None)`` where ``log_history`` holds two
    steps. The second step deliberately omits entries ``B`` and ``D`` so
    the aggregations are exercised on value lists of different lengths.
    """

    def f32(value):
        return jnp.array(value, dtype=jnp.float32)

    first_step = {
        "stateful_metrics": {"A": f32(1.0), "B": f32(1.0)},
        "metrics": {"C": f32(0.0), "D": f32(0.0)},
        "elapsed": {"steps": 1, "samples": 1},
    }
    second_step = {
        "stateful_metrics": {"A": f32(0.0)},
        "metrics": {"C": f32(1.0)},
        "elapsed": {"steps": 2, "samples": 2},
    }
    return None, [first_step, second_step], None


class TestCallbacks:
    """Checks how ``ciclo.callbacks.inner_loop`` aggregates inner-loop logs.

    Every test wraps ``dummy_inner_loop_fn`` under the name ``"test"`` and
    compares the returned logs with the expected per-entry values; the
    ``elapsed`` collection is expected to be absent from the output.
    """

    @staticmethod
    def _run_inner_loop(**options):
        # Build the callback with the given keyword options and return its logs.
        callback = ciclo.callbacks.inner_loop(
            "test",
            dummy_inner_loop_fn,
            **options,
        )
        logs, _ = callback(None)
        return logs

    @staticmethod
    def _expected(a, b, c, d):
        # Expected output layout: entries renamed with the "_test" suffix.
        def f32(value):
            return jnp.array(value, dtype=jnp.float32)

        return {
            "stateful_metrics": {"A_test": f32(a), "B_test": f32(b)},
            "metrics": {"C_test": f32(c), "D_test": f32(d)},
        }

    def test_inner_loop_default_aggregation(self):
        # No aggregation argument: the last logged value of each entry is kept.
        logs = self._run_inner_loop()

        assert logs == self._expected(0.0, 1.0, 1.0, 0.0)

    def test_inner_loop_callable_aggregation(self):
        # A plain callable (the builtin ``sum``) is applied to every entry.
        logs = self._run_inner_loop(aggregation=sum)

        assert logs == self._expected(1.0, 1.0, 1.0, 0.0)

    def test_inner_loop_mean_aggregation(self):
        logs = self._run_inner_loop(aggregation="mean")

        assert logs == self._expected(0.5, 1.0, 0.5, 0.0)

    def test_inner_loop_aggregation_dict(self):
        # Per-collection aggregations.
        logs = self._run_inner_loop(
            aggregation={"stateful_metrics": "sum", "metrics": "min"},
        )

        assert logs == self._expected(1.0, 1.0, 0.0, 0.0)

        # Collections missing from the mapping ("metrics") use "last".
        logs = self._run_inner_loop(
            aggregation={"stateful_metrics": "first"},
        )

        assert logs == self._expected(1.0, 1.0, 1.0, 0.0)
38 changes: 38 additions & 0 deletions tests/test_loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,41 @@ def data():

assert a_list == list(range(1, 4))
assert b_list == list(range(-1, -4, -1))

def test_inner_loop_kwargs(self):
    """``inner_loop_kwargs`` passed to ``train_loop`` changes the test-loop aggregation.

    The test step increments a counter and logs it over four test batches,
    so the logged values are 1, 2, 3, 4: the default run reports the last
    one (4) and the ``"sum"`` run reports their total (10).
    """

    def increment(state, key):
        state[key] += 1
        logs = ciclo.logs()
        logs.add_metric(key, state[key])
        return logs, state

    def run_train_loop(**extra_kwargs):
        # Fresh state per run so both runs start counting from zero.
        state = {"a": 0}
        state, history, _ = ciclo.train_loop(
            state,
            ciclo.elapse(range(1)),
            {
                ciclo.on_test_step: lambda state: increment(state, "a"),
            },
            test_dataset=lambda: ciclo.elapse(range(4)),
            epoch_duration=1,
            stop=1,
            **extra_kwargs,
        )
        return history

    history = run_train_loop()
    assert history[0]["metrics"]["a_test"] == 4

    history = run_train_loop(inner_loop_kwargs={"aggregation": "sum"})
    assert history[0]["metrics"]["a_test"] == 10
Loading