-
Notifications
You must be signed in to change notification settings - Fork 8
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add inner loop aggregation options #9
Merged
Merged
Changes from 3 commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
eca3a93
Add aggregation option to inner_loop callback
JamesAllingham 2a9abb0
Fix inconsistent quotes
JamesAllingham ff3ad05
Make "last" the default aggregation with dict
JamesAllingham ccaa6ed
Allow Callable aggregations
JamesAllingham 1d1f5b6
Fix formatting
JamesAllingham c441c0f
fix pre-commit
cgarciae File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
import jax.numpy as jnp | ||
|
||
import ciclo | ||
|
||
|
||
def dummy_inner_loop_fn(_): | ||
log_history = [ | ||
{ | ||
"stateful_metrics": { | ||
"A": jnp.array(1.0, dtype=jnp.float32), | ||
"B": jnp.array(1.0, dtype=jnp.float32), | ||
}, | ||
"metrics": { | ||
"C": jnp.array(0.0, dtype=jnp.float32), | ||
"D": jnp.array(0.0, dtype=jnp.float32), | ||
}, | ||
"elapsed": { | ||
"steps": 1, | ||
"samples": 1, | ||
}, | ||
}, | ||
{ | ||
"stateful_metrics": { | ||
"A": jnp.array(0.0, dtype=jnp.float32), | ||
}, | ||
"metrics": { | ||
"C": jnp.array(1.0, dtype=jnp.float32), | ||
}, | ||
"elapsed": { | ||
"steps": 2, | ||
"samples": 2, | ||
}, | ||
}, | ||
] | ||
return None, log_history, None | ||
|
||
|
||
class TestCallbacks: | ||
def test_inner_loop_default_aggregation(self): | ||
inner_loop = ciclo.callbacks.inner_loop( | ||
"test", | ||
dummy_inner_loop_fn, | ||
) | ||
|
||
log_history, _ = inner_loop(None) | ||
|
||
assert log_history == { | ||
"stateful_metrics": { | ||
"A_test": jnp.array(0.0, dtype=jnp.float32), | ||
"B_test": jnp.array(1.0, dtype=jnp.float32), | ||
}, | ||
"metrics": { | ||
"C_test": jnp.array(1.0, dtype=jnp.float32), | ||
"D_test": jnp.array(0.0, dtype=jnp.float32), | ||
}, | ||
} | ||
|
||
def test_inner_loop_mean_aggregation(self): | ||
inner_loop = ciclo.callbacks.inner_loop( | ||
"test", | ||
dummy_inner_loop_fn, | ||
aggregation="mean", | ||
) | ||
|
||
log_history, _ = inner_loop(None) | ||
|
||
assert log_history == { | ||
"stateful_metrics": { | ||
"A_test": jnp.array(0.5, dtype=jnp.float32), | ||
"B_test": jnp.array(1.0, dtype=jnp.float32), | ||
}, | ||
"metrics": { | ||
"C_test": jnp.array(0.5, dtype=jnp.float32), | ||
"D_test": jnp.array(0.0, dtype=jnp.float32), | ||
}, | ||
} | ||
|
||
def test_inner_loop_aggregation_dict(self): | ||
inner_loop = ciclo.callbacks.inner_loop( | ||
"test", | ||
dummy_inner_loop_fn, | ||
aggregation={"stateful_metrics": "sum", "metrics": "min"}, | ||
) | ||
|
||
log_history, _ = inner_loop(None) | ||
|
||
assert log_history == { | ||
"stateful_metrics": { | ||
"A_test": jnp.array(1.0, dtype=jnp.float32), | ||
"B_test": jnp.array(1.0, dtype=jnp.float32), | ||
}, | ||
"metrics": { | ||
"C_test": jnp.array(0.0, dtype=jnp.float32), | ||
"D_test": jnp.array(0.0, dtype=jnp.float32), | ||
}, | ||
} | ||
|
||
inner_loop = ciclo.callbacks.inner_loop( | ||
"test", | ||
dummy_inner_loop_fn, | ||
aggregation={"stateful_metrics": "first"}, | ||
) | ||
|
||
log_history, _ = inner_loop(None) | ||
|
||
assert log_history == { | ||
"stateful_metrics": { | ||
"A_test": jnp.array(1.0, dtype=jnp.float32), | ||
"B_test": jnp.array(1.0, dtype=jnp.float32), | ||
}, | ||
"metrics": { | ||
"C_test": jnp.array(1.0, dtype=jnp.float32), | ||
"D_test": jnp.array(0.0, dtype=jnp.float32), | ||
}, | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice!