I've been loving this library – it makes quickly spinning up a training loop so much easier and really reduces the amount of boilerplate code I'm writing.
Perhaps I'm not using it correctly, but I think the current implementation of `inner_loop` implicitly assumes that the metrics being used are stateful. In particular, this line:
```python
logs = log_history[-1] if len(log_history) > 0 else Logs()
```
only makes sense if the metrics are stateful, so that the last entry in the history holds the final value of each metric, accumulated over the whole test dataset and then computed. If the metrics are not stateful, this line keeps only the estimate from the last batch. (This is especially problematic when the last batch is smaller than the rest!)
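To make the problem concrete, here is a small self-contained sketch with no `ciclo` code involved. It assumes a plain accuracy metric computed independently per batch, on hypothetical data of 10 examples split into batches of 4, 4 and 2. Taking the last history entry yields only the last batch's estimate, and even an unweighted mean of the batch estimates differs from the dataset-level accuracy when batch sizes differ.

```python
# Per-batch (stateless) accuracy vs. accuracy over the whole dataset.
# Hypothetical data: 10 examples split into batches of size 4, 4 and 2.
batch_sizes = [4, 4, 2]
batch_correct = [3, 3, 0]  # number of correct predictions in each batch

# A stateless metric produces one independent estimate per batch.
batch_accuracies = [c / n for c, n in zip(batch_correct, batch_sizes)]

# What reading only the last history entry amounts to: the last batch's estimate.
last_batch_estimate = batch_accuracies[-1]

# Unweighted mean of the batch estimates: still biased when batch sizes differ.
mean_of_batches = sum(batch_accuracies) / len(batch_accuracies)

# The quantity we actually want: accuracy over the whole dataset.
dataset_accuracy = sum(batch_correct) / sum(batch_sizes)

print(last_batch_estimate, mean_of_batches, dataset_accuracy)  # 0.0 0.5 0.6
```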
Have I missed an obvious way around this? I'm happy to provide a minimal working example if that would help.
If you think this is an actual issue, I'd be happy to take a stab at a fix, perhaps by adding an `aggregate` option to `inner_loop`.
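One possible shape for such an `aggregate` option is sketched below. `aggregate_logs` is a hypothetical helper, not part of `ciclo`. For simplicity it assumes each history entry is a flat dict mapping metric names to floats, that every entry has the same keys, and that the batch size of each entry is known. It combines the per-batch values with a mean weighted by batch size.

```python
from typing import Dict, Sequence


def aggregate_logs(
    history: Sequence[Dict[str, float]], batch_sizes: Sequence[int]
) -> Dict[str, float]:
    """Hypothetical helper: combine per-batch metric logs into one value per
    metric using a mean weighted by batch size.

    Assumes every entry has the same keys and that each metric is a mean over
    the examples in its batch (e.g. accuracy or mean loss).
    """
    if len(history) != len(batch_sizes):
        raise ValueError("need exactly one batch size per history entry")
    total = sum(batch_sizes)
    if total == 0:
        return {}
    weighted_sums: Dict[str, float] = {}
    for logs, n in zip(history, batch_sizes):
        for name, value in logs.items():
            # Undo the per-batch averaging by weighting each value by its batch size.
            weighted_sums[name] = weighted_sums.get(name, 0.0) + value * n
    return {name: s / total for name, s in weighted_sums.items()}


history = [{"accuracy": 0.75}, {"accuracy": 0.75}, {"accuracy": 0.0}]
print(aggregate_logs(history, [4, 4, 2]))  # {'accuracy': 0.6}
```

A weighted mean like this is only correct for metrics that are averages over examples. Metrics such as F1 or AUC cannot be recovered from per-batch values at all, which is one argument for keeping the metrics themselves stateful.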
Hey @JamesAllingham, thanks for reporting this. You are right: this assumes that metrics (and all other collections in `Logs`) are stateful. Metrics from both `clu.metrics` and `jax_metrics` are supported and are stateful, and those are mostly what I've used.
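The reason stateful metrics avoid the problem is that the state after the last batch already summarizes the whole dataset, so reading the last log entry gives the correct result. The class below is a plain-Python sketch of that accumulate-then-compute idea, written in the immutable-update style common in JAX code. It is not the `clu.metrics` or `jax_metrics` API, and the data reuse the hypothetical batch counts from the earlier discussion.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Accuracy:
    """Minimal stateful accuracy: keeps running counts, computes only at the end."""

    correct: int = 0
    count: int = 0

    def update(self, batch_correct: int, batch_size: int) -> "Accuracy":
        # Return a new state instead of mutating, as is common in JAX code.
        return Accuracy(self.correct + batch_correct, self.count + batch_size)

    def compute(self) -> float:
        # NaN signals that no examples have been seen yet.
        return self.correct / self.count if self.count else float("nan")


metric = Accuracy()
for batch_correct, batch_size in zip([3, 3, 0], [4, 4, 2]):
    metric = metric.update(batch_correct, batch_size)

# The final state covers all 10 examples, so the last value is the dataset accuracy.
print(metric.compute())  # 0.6
```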
Feel free to send a proposal / PR, happy to review it.