diff --git a/pyro/infer/mcmc/api.py b/pyro/infer/mcmc/api.py index 5a2a3f1535..2666188fbc 100644 --- a/pyro/infer/mcmc/api.py +++ b/pyro/infer/mcmc/api.py @@ -385,13 +385,15 @@ def run(self, *args, **kwargs): # If transforms is not explicitly provided, infer automatically using # model args, kwargs. if self.transforms is None: - if hasattr(self.kernel, 'transforms'): - if self.kernel.transforms is not None: - self.transforms = self.kernel.transforms + # Use `kernel.transforms` when available + if hasattr(self.kernel, 'transforms') and self.kernel.transforms is not None: + self.transforms = self.kernel.transforms + # Else, get transforms from model (e.g. in multiprocessing). elif self.kernel.model: _, _, self.transforms, _ = initialize_model(self.kernel.model, model_args=args, model_kwargs=kwargs) + # Assign default value else: self.transforms = {}