add combine_terms option to exact MLL #1863
Conversation
@jacobrgardner @gpleiss any thoughts?

Yeah, this would be awesome to add!
(Branch force-pushed from ba187db to 328ebd0.)
@gpleiss how does everything look?
```diff
@@ -203,12 +203,12 @@ def get_base_samples(self, sample_shape=torch.Size()):
             return base_samples.view(new_shape).transpose(-1, -2).contiguous()
         return base_samples.view(*sample_shape, *self._output_shape)

-    def log_prob(self, value):
+    def log_prob(self, value, combine_terms=True):
```
In general, I don't think we want to be adding flags to the standard `log_prob` call here, to maintain compatibility with the MVN API in PyTorch. Let's have this be a `_log_prob` method, with `log_prob` just calling `_log_prob(value=value, combine_terms=True)`?
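A minimal sketch of that wrapper pattern, using a toy stand-in for the MVN class (the class name and its internals here are illustrative assumptions, not the PR's actual code):

```python
import math
import torch

class _ToyMVN:
    """Toy stand-in for gpytorch's MultivariateNormal, illustrating the
    suggested private-helper pattern (all names here are hypothetical)."""

    def __init__(self, mean, covar):
        self.mean = mean
        self.covar = covar  # dense covariance matrix, for simplicity

    def _log_prob(self, value, combine_terms=True):
        diff = value - self.mean
        inv_quad = diff @ torch.linalg.solve(self.covar, diff)
        logdet = torch.logdet(self.covar)
        norm_const = torch.tensor(value.size(-1) * math.log(2 * math.pi))
        terms = [-0.5 * inv_quad, -0.5 * logdet, -0.5 * norm_const]
        return sum(terms) if combine_terms else terms

    def log_prob(self, value):
        # Public signature stays compatible with torch.distributions.
        return self._log_prob(value=value, combine_terms=True)
```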
```diff
@@ -142,7 +146,7 @@ def lazy_covariance_matrix(self):
         else:
             return lazify(super().covariance_matrix)

-    def log_prob(self, value):
+    def log_prob(self, value, combine_terms=True):
```
Same here, change to `_log_prob`?
Looks like the failing unit test was flaky.
```diff
@@ -59,9 +62,17 @@ def forward(self, function_dist, target, *params):

         # Get the log prob of the marginal distribution
         output = self.likelihood(function_dist, *params)
-        res = output.log_prob(target)
-        res = self._add_other_terms(res, params)
+        res = output.log_prob(target, combine_terms=self.combine_terms)
```
Stylistically, the proposed change from `log_prob` to `_log_prob` is problematic here, because you would essentially be calling a "private" method publicly. More generally, I think the `combine_terms` option is broadly useful, and burying it inside the class makes it harder to use.

Personally, I don't see why the GPyTorch `log_prob` API can't allow optional keyword arguments like `combine_terms`, as long as the default behavior is consistent.
@gpleiss @jacobrgardner care to weigh in?
A compromise would be to just call it `log_prob_terms` instead of `_log_prob`.
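Reusing the toy class from the sketch above, that compromise might look like this (again, the names apart from `log_prob_terms` itself are hypothetical):

```python
class _ToyMVNWithTerms(_ToyMVN):
    """Variant following the log_prob_terms compromise."""

    def log_prob_terms(self, value):
        # Public, discoverable entry point for the split terms.
        return self._log_prob(value, combine_terms=False)

    def log_prob(self, value):
        # The combined value is just the sum of the split terms.
        return sum(self.log_prob_terms(value))
```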
```python
        split_terms = [inv_quad, logdet, norm_const]
        split_terms = [-0.5 * term for term in split_terms]
```
Suggested change:

```diff
-        split_terms = [inv_quad, logdet, norm_const]
-        split_terms = [-0.5 * term for term in split_terms]
+        split_terms = [-0.5 * inv_quad, logdet, -0.5 * norm_const]
```
```diff
@@ -17,6 +19,7 @@ class ExactMarginalLogLikelihood(MarginalLogLikelihood):

     :param ~gpytorch.likelihoods.GaussianLikelihood likelihood: The Gaussian likelihood for the model
     :param ~gpytorch.models.ExactGP model: The exact GP model
+    :param ~bool combine_terms (optional): If `False`, the MLL call returns each MLL term separately
```
Should probably also describe what happens if there are "other terms" (i.e., that they are added to the returned elements).
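For illustration, a hedged sketch of how `forward` might surface those other terms when `combine_terms=False` (the `_add_other_terms` and `log_prob_terms` names are taken from the diffs above; everything else is an assumption, not the PR's actual code):

```python
import torch

def exact_mll_forward(mll, function_dist, target, *params):
    """Sketch only: the split path returns the MLL terms plus the
    accumulated "other terms" (priors, added-loss terms) as a final element."""
    output = mll.likelihood(function_dist, *params)
    if mll.combine_terms:
        res = output.log_prob(target)
        return mll._add_other_terms(res, params)
    terms = output.log_prob_terms(target)
    # Accumulate priors/added-loss terms onto zero so they come back separately.
    other = mll._add_other_terms(torch.zeros_like(terms[0]), params)
    return [*terms, other]
```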
```diff
         actual = TMultivariateNormal(mean, torch.eye(4, device=device, dtype=dtype) * var).log_prob(values)
         self.assertLess((res - actual).div(res).abs().item(), 1e-2)

+        res2 = mvn.log_prob_terms(values)
+        assert len(res2) == 3
```
Suggested change:

```diff
-        assert len(res2) == 3
+        self.assertEqual(len(res2), 3)
```
Also in other places in the tests below.
I've found logging the inv_quad and logdet terms separately (rather than just the train loss) to be very helpful for debugging. Right now, classes like `VariationalELBO` have a `combine_terms` option that allows the user to sum the terms after the MLL call. This is a nice feature, since otherwise you essentially have to pay for an extra training step just to log the terms separately.

In this PR I've demonstrated how we could go about adding this option to the subclasses of `MarginalLogLikelihood`, starting with the Gaussian likelihood case. There are a few unit tests that aren't passing yet, but I wanted to check whether this feature would be approved before fixing it up.
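For context, here is the kind of training-loop logging this feature would enable (a sketch against the proposed API; `model`, `likelihood`, `optimizer`, and the training data are assumed to already exist):

```python
import gpytorch

# Hypothetical usage of the proposed combine_terms=False option.
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model, combine_terms=False)

for i in range(50):
    optimizer.zero_grad()
    output = model(train_x)
    # Terms come back separately (already scaled), plus any "other terms".
    inv_quad, logdet, norm_const, *other = mll(output, train_y)
    loss = -sum([inv_quad, logdet, norm_const, *other])
    loss.backward()
    optimizer.step()
    print(f"iter {i}: inv_quad={inv_quad.item():.3f}, logdet={logdet.item():.3f}")
```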