Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add batched calculation option to energy_score_empirical in order to reduce memory consumption #3402

Merged
merged 2 commits into from
Sep 26, 2024

Conversation

BenZickel
Copy link
Contributor

The Problem

When calculating the energy score of large samples with pyro.ops.stats.energy_score_empirical the memory consumption can be very high which is prohibitive for some applications.

The Solution

Add an option to do the energy score calculation in batches by providing an additional optional parameter pred_batch_size which specifies the sample batch to use in the calculation. The calculation time is longer but the result stays the same up to numerical accuracy.

fritzo
fritzo previously approved these changes Sep 25, 2024
Copy link
Member

@fritzo fritzo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM after explaining or replacing the cat-then-add

pyro/ops/stats.py Outdated Show resolved Hide resolved
@fritzo fritzo merged commit 04c371f into pyro-ppl:dev Sep 26, 2024
9 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants