From c710150168c502e4d6e7a45dab1207a88818c024 Mon Sep 17 00:00:00 2001 From: Moin Nadeem Date: Thu, 4 Jul 2024 23:05:11 +0000 Subject: [PATCH] exposing eps as a parameter --- torchcfm/conditional_flow_matching.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torchcfm/conditional_flow_matching.py b/torchcfm/conditional_flow_matching.py index b71b6bd..9605081 100644 --- a/torchcfm/conditional_flow_matching.py +++ b/torchcfm/conditional_flow_matching.py @@ -155,7 +155,7 @@ def compute_conditional_flow(self, x0, x1, t, xt): def sample_noise_like(self, x): return torch.randn_like(x) - def sample_location_and_conditional_flow(self, x0, x1, t=None, return_noise=False): + def sample_location_and_conditional_flow(self, x0, x1, t=None, eps=None, return_noise=False): """ Compute the sample xt (drawn from N(t * x1 + (1 - t) * x0, sigma)) and the conditional vector field ut(x1|x0) = x1 - x0, see Eq.(15) [1]. @@ -189,7 +189,8 @@ def sample_location_and_conditional_flow(self, x0, x1, t=None, return_noise=Fals t = torch.rand(x0.shape[0]).type_as(x0) assert len(t) == x0.shape[0], "t has to have batch size dimension" - eps = self.sample_noise_like(x0) + if eps is None: + eps = self.sample_noise_like(x0) xt = self.sample_xt(x0, x1, t, eps) ut = self.compute_conditional_flow(x0, x1, t, xt) if return_noise: