Generating samples with their log-probability #18
This feature would be particularly necessary for applications like importance sampling, where samples and their probability should be computed in as few steps as possible. I'm not sure I understand the zuko code well enough, but would it work to just change this line? https://github.com/francois-rozet/zuko/blob/master/zuko/distributions.py#L121 from

to
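For context, here is a minimal, self-contained sketch (plain Python, not zuko; the `sample_and_log_prob` helper is hypothetical) of why importance sampling wants samples and log-probabilities returned together: every sample drawn from the proposal must be reweighted by the ratio of target to proposal density, so the proposal's log-probability is needed for each sample anyway.

```python
import math
import random

def log_normal(x, mu, sigma):
    # Log-density of a univariate Gaussian N(mu, sigma^2)
    return -0.5 * ((x - mu) / sigma) ** 2 - math.log(sigma * math.sqrt(2 * math.pi))

def sample_and_log_prob(mu, sigma, n, rng):
    # Draw n samples together with their log-probability, mirroring
    # the kind of fused API discussed in this issue (hypothetical name)
    xs = [rng.gauss(mu, sigma) for _ in range(n)]
    return xs, [log_normal(x, mu, sigma) for x in xs]

def importance_estimate(f, n=100_000, seed=0):
    # Estimate E_p[f(X)] for target p = N(0, 1) using proposal q = N(0, 2)
    rng = random.Random(seed)
    xs, log_q = sample_and_log_prob(0.0, 2.0, n, rng)
    # Importance weight w(x) = p(x) / q(x), computed in log space
    weights = [math.exp(log_normal(x, 0.0, 1.0) - lq) for x, lq in zip(xs, log_q)]
    return sum(w * f(x) for w, x in zip(weights, xs)) / n
```

With `f(x) = x**2` the estimate should be close to 1, the variance of the target distribution.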
I would be happy to help add this feature 👍
Hello @simonschnake and @valsdav, thank you for your interest in Zuko 🔥 This is a sensible request. I just have a few questions:
The former would require modifying all transformations, while the latter might only require modifying the
This could dramatically change the way the feature is implemented.
I would need the
I have roughly written an example using the NSF flow. I could provide a pull request. Thanks a lot for your help!
Before submitting a PR, we need to decide how the feature should be implemented. It is a major API change, so I don't want to rush it. IMHO, I like the idea of @valsdav to use the

```python
def rsample_and_log_prob(self, shape: Size = ()) -> Tuple[Tensor, Tensor]:
    x = self.rsample(shape)
    log_p = self.log_prob(x)
    return x, log_p
```

My idea is to modify the
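To make the generic fallback concrete, here is a toy stand-in (plain Python, not torch or zuko; names are illustrative): sampling and scoring are two separate passes over the distribution, which is exactly the work a fused implementation inside the flow can avoid.

```python
import math
import random

class ToyNormal:
    """Toy stand-in for a torch-style distribution (illustrative only)."""

    def __init__(self, loc, scale):
        self.loc, self.scale = loc, scale

    def rsample(self, rng):
        # Reparameterized sample: loc + scale * eps, eps ~ N(0, 1)
        return self.loc + self.scale * rng.gauss(0.0, 1.0)

    def log_prob(self, x):
        z = (x - self.loc) / self.scale
        return -0.5 * z * z - math.log(self.scale * math.sqrt(2 * math.pi))

    def rsample_and_log_prob(self, rng):
        # Generic two-pass fallback: sample, then evaluate the density
        x = self.rsample(rng)
        return x, self.log_prob(x)
```

By construction, the log-probability returned by the fused method agrees with a separate `log_prob` call on the same sample.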
My proposition is something like

```diff
diff --git a/zuko/distributions.py b/zuko/distributions.py
index b975e76..82a02e1 100644
--- a/zuko/distributions.py
+++ b/zuko/distributions.py
@@ -120,6 +120,18 @@ class NormalizingFlow(Distribution):
 
         return self.transform.inv(z)
 
+    def rsample_and_log_prob(self, shape: Size = ()) -> Tuple[Tensor, Tensor]:
+        if self.base.has_rsample:
+            z = self.base.rsample(shape)
+        else:
+            z = self.base.sample(shape)
+
+        log_p = self.base.log_prob(z)
+        x, ladj = self.transform.inv.call_and_ladj(z)
+        ladj = _sum_rightmost(ladj, self.reinterpreted)
+
+        return x, log_p - ladj
+
 
 class Joint(Distribution):
     r"""Creates a distribution for a multivariate random variable :math:`X` which
```

```diff
diff --git a/zuko/transforms.py b/zuko/transforms.py
index e926ee6..e4c087b 100644
--- a/zuko/transforms.py
+++ b/zuko/transforms.py
@@ -107,6 +107,17 @@ class ComposedTransform(Transform):
             x = t(x)
 
         return x
 
+    @property
+    def inv(self):
+        new = self.__new__(ComposedTransform)
+        new.transforms = [t.inv for t in reversed(self.transforms)]
+        new.domain_dim = self.codomain_dim
+        new.codomain_dim = self.domain_dim
+
+        Transform.__init__(new)
+
+        return new
+
     def _inverse(self, y: Tensor) -> Tensor:
         for t in reversed(self.transforms):
             y = t.inv(y)
```

It provides a slight boost in performance, and I think it would make back-propagation slightly more stable.
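To illustrate the mechanics behind these diffs, here is a toy sketch in plain Python (not the actual zuko classes): `call_and_ladj` applies a transform and returns the log-absolute-determinant of its Jacobian in the same pass, the `ladj` of a composition accumulates across its parts, and the inverse of a composition is the reversed composition of inverses.

```python
import math

class Affine:
    """Toy invertible transform y = a * x + b with closed-form log|det J|."""

    def __init__(self, a, b):
        self.a, self.b = a, b

    def call_and_ladj(self, x):
        # Forward value and log|dy/dx| = log|a| in one pass
        return self.a * x + self.b, math.log(abs(self.a))

    @property
    def inv(self):
        # Inverse is x = (y - b) / a, itself an affine transform
        return Affine(1.0 / self.a, -self.b / self.a)

class Composed:
    """Toy composition of transforms, applied left to right."""

    def __init__(self, transforms):
        self.transforms = transforms

    @property
    def inv(self):
        # Inverse of a composition: reversed inverses, as in the diff above
        return Composed([t.inv for t in reversed(self.transforms)])

    def call_and_ladj(self, x):
        # Single pass: apply each transform and accumulate the ladj
        ladj = 0.0
        for t in self.transforms:
            x, l = t.call_and_ladj(x)
            ladj += l
        return x, ladj
```

Running a point through the composition and back recovers the input, and the accumulated `ladj` of the inverse is the negative of the forward one, which is the identity the `log_p - ladj` line in the first diff relies on.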
That seems very elegant and nice, but I have the feeling that it can get very complicated. I quickly designed a fitting version of the
It could be that I misunderstand the architecture and it is a lot simpler.
Sorry, I was not clear enough. Only transformations that can benefit from a faster
Thanks a lot @francois-rozet! 💯
Description
There are some use cases (at least I have one) where one also needs the `ladj` while computing the inverse operation. For my use case, I am not only using the normalizing flow to generate samples, but also want to know the likelihood of the produced samples.
Implementation
The implementation could be somewhat hard, because in principle every transformation would need to include another method, `inverse_and_ladj`. I am willing to help and contribute pull requests for that. My main focus at the moment is neural spline flows. I'm mainly opening this issue to figure out whether this is wanted, and one has to decide how to introduce the functionality in the classes that consume the transformations.
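The per-transform method proposed here could look something like this toy sketch (plain Python with hypothetical names, not zuko's actual API): each transform returns the inverse value together with the log-absolute-determinant of the inverse's Jacobian, so the caller never needs a second pass over the transform.

```python
import math

class ExpTransform:
    """Toy element-wise transform y = exp(x) (illustrative, not zuko's API)."""

    def __call__(self, x):
        return math.exp(x)

    def inverse(self, y):
        return math.log(y)

    def inverse_and_ladj(self, y):
        # Inverse x = log(y) and log|dx/dy| = log(1/y) = -log(y),
        # returned together so sampling and scoring share one pass.
        x = math.log(y)
        return x, -x
```

For example, at `y = e` the inverse is `x = 1` and the `ladj` of the inverse is `-1`.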