diff --git a/torchcfm/conditional_flow_matching.py b/torchcfm/conditional_flow_matching.py index b71b6bd..9605081 100644 --- a/torchcfm/conditional_flow_matching.py +++ b/torchcfm/conditional_flow_matching.py @@ -155,7 +155,7 @@ def compute_conditional_flow(self, x0, x1, t, xt): def sample_noise_like(self, x): return torch.randn_like(x) - def sample_location_and_conditional_flow(self, x0, x1, t=None, return_noise=False): + def sample_location_and_conditional_flow(self, x0, x1, t=None, eps=None, return_noise=False): """ Compute the sample xt (drawn from N(t * x1 + (1 - t) * x0, sigma)) and the conditional vector field ut(x1|x0) = x1 - x0, see Eq.(15) [1]. @@ -189,7 +189,8 @@ def sample_location_and_conditional_flow(self, x0, x1, t=None, return_noise=Fals t = torch.rand(x0.shape[0]).type_as(x0) assert len(t) == x0.shape[0], "t has to have batch size dimension" - eps = self.sample_noise_like(x0) + if eps is None: + eps = self.sample_noise_like(x0) xt = self.sample_xt(x0, x1, t, eps) ut = self.compute_conditional_flow(x0, x1, t, xt) if return_noise: