diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 918a1241f996..d49f55aa921d 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -340,8 +340,8 @@ def __call__( masked_image_latents = 0.18215 * masked_image_latents # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method - mask = mask.repeat(num_images_per_prompt, 1, 1, 1) - masked_image_latents = masked_image_latents.repeat(num_images_per_prompt, 1, 1, 1) + mask = mask.repeat(batch_size * num_images_per_prompt, 1, 1, 1) + masked_image_latents = masked_image_latents.repeat(batch_size * num_images_per_prompt, 1, 1, 1) mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask masked_image_latents = (