-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add inner loop aggregation options (#9)
* Add aggregation option to inner_loop callback Also added a inner_loop_kwargs argument to train_loop. Added tests for inner_loop aggregation and inner_loop_kwargs. * Fix inconsistent quotes * Make "last" the default aggregation with dict * Allow Callable aggregations * Fix formatting * fix pre-commit --------- Co-authored-by: Cristian Garcia <cgarcia.e88@gmail.com>
- Loading branch information
1 parent
2d5b0e6
commit f39f265
Showing
4 changed files
with
252 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
import jax.numpy as jnp | ||
|
||
import ciclo | ||
|
||
|
||
def dummy_inner_loop_fn(_): | ||
log_history = [ | ||
{ | ||
"stateful_metrics": { | ||
"A": jnp.array(1.0, dtype=jnp.float32), | ||
"B": jnp.array(1.0, dtype=jnp.float32), | ||
}, | ||
"metrics": { | ||
"C": jnp.array(0.0, dtype=jnp.float32), | ||
"D": jnp.array(0.0, dtype=jnp.float32), | ||
}, | ||
"elapsed": { | ||
"steps": 1, | ||
"samples": 1, | ||
}, | ||
}, | ||
{ | ||
"stateful_metrics": { | ||
"A": jnp.array(0.0, dtype=jnp.float32), | ||
}, | ||
"metrics": { | ||
"C": jnp.array(1.0, dtype=jnp.float32), | ||
}, | ||
"elapsed": { | ||
"steps": 2, | ||
"samples": 2, | ||
}, | ||
}, | ||
] | ||
return None, log_history, None | ||
|
||
|
||
class TestCallbacks: | ||
def test_inner_loop_default_aggregation(self): | ||
inner_loop = ciclo.callbacks.inner_loop( | ||
"test", | ||
dummy_inner_loop_fn, | ||
) | ||
|
||
log_history, _ = inner_loop(None) | ||
|
||
assert log_history == { | ||
"stateful_metrics": { | ||
"A_test": jnp.array(0.0, dtype=jnp.float32), | ||
"B_test": jnp.array(1.0, dtype=jnp.float32), | ||
}, | ||
"metrics": { | ||
"C_test": jnp.array(1.0, dtype=jnp.float32), | ||
"D_test": jnp.array(0.0, dtype=jnp.float32), | ||
}, | ||
} | ||
|
||
def test_inner_loop_callable_aggregation(self): | ||
inner_loop = ciclo.callbacks.inner_loop( | ||
"test", | ||
dummy_inner_loop_fn, | ||
aggregation=sum, | ||
) | ||
|
||
log_history, _ = inner_loop(None) | ||
|
||
assert log_history == { | ||
"stateful_metrics": { | ||
"A_test": jnp.array(1.0, dtype=jnp.float32), | ||
"B_test": jnp.array(1.0, dtype=jnp.float32), | ||
}, | ||
"metrics": { | ||
"C_test": jnp.array(1.0, dtype=jnp.float32), | ||
"D_test": jnp.array(0.0, dtype=jnp.float32), | ||
}, | ||
} | ||
|
||
def test_inner_loop_mean_aggregation(self): | ||
inner_loop = ciclo.callbacks.inner_loop( | ||
"test", | ||
dummy_inner_loop_fn, | ||
aggregation="mean", | ||
) | ||
|
||
log_history, _ = inner_loop(None) | ||
|
||
assert log_history == { | ||
"stateful_metrics": { | ||
"A_test": jnp.array(0.5, dtype=jnp.float32), | ||
"B_test": jnp.array(1.0, dtype=jnp.float32), | ||
}, | ||
"metrics": { | ||
"C_test": jnp.array(0.5, dtype=jnp.float32), | ||
"D_test": jnp.array(0.0, dtype=jnp.float32), | ||
}, | ||
} | ||
|
||
def test_inner_loop_aggregation_dict(self): | ||
inner_loop = ciclo.callbacks.inner_loop( | ||
"test", | ||
dummy_inner_loop_fn, | ||
aggregation={"stateful_metrics": "sum", "metrics": "min"}, | ||
) | ||
|
||
log_history, _ = inner_loop(None) | ||
|
||
assert log_history == { | ||
"stateful_metrics": { | ||
"A_test": jnp.array(1.0, dtype=jnp.float32), | ||
"B_test": jnp.array(1.0, dtype=jnp.float32), | ||
}, | ||
"metrics": { | ||
"C_test": jnp.array(0.0, dtype=jnp.float32), | ||
"D_test": jnp.array(0.0, dtype=jnp.float32), | ||
}, | ||
} | ||
|
||
inner_loop = ciclo.callbacks.inner_loop( | ||
"test", | ||
dummy_inner_loop_fn, | ||
aggregation={"stateful_metrics": "first"}, | ||
) | ||
|
||
log_history, _ = inner_loop(None) | ||
|
||
assert log_history == { | ||
"stateful_metrics": { | ||
"A_test": jnp.array(1.0, dtype=jnp.float32), | ||
"B_test": jnp.array(1.0, dtype=jnp.float32), | ||
}, | ||
"metrics": { | ||
"C_test": jnp.array(1.0, dtype=jnp.float32), | ||
"D_test": jnp.array(0.0, dtype=jnp.float32), | ||
}, | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters