Skip to content

Commit

Permalink
exposing eps as a parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
moinnadeem committed Jul 4, 2024
1 parent ec4da08 commit c710150
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions torchcfm/conditional_flow_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def compute_conditional_flow(self, x0, x1, t, xt):
def sample_noise_like(self, x):
return torch.randn_like(x)

def sample_location_and_conditional_flow(self, x0, x1, t=None, return_noise=False):
def sample_location_and_conditional_flow(self, x0, x1, t=None, eps=None, return_noise=False):
"""
Compute the sample xt (drawn from N(t * x1 + (1 - t) * x0, sigma))
and the conditional vector field ut(x1|x0) = x1 - x0, see Eq.(15) [1].
Expand Down Expand Up @@ -189,7 +189,8 @@ def sample_location_and_conditional_flow(self, x0, x1, t=None, return_noise=Fals
t = torch.rand(x0.shape[0]).type_as(x0)
assert len(t) == x0.shape[0], "t has to have batch size dimension"

eps = self.sample_noise_like(x0)
if eps is None:
eps = self.sample_noise_like(x0)
xt = self.sample_xt(x0, x1, t, eps)
ut = self.compute_conditional_flow(x0, x1, t, xt)
if return_noise:
Expand Down

0 comments on commit c710150

Please sign in to comment.