diff --git a/comfy/ldm/modules/diffusionmodules/mmdit.py b/comfy/ldm/modules/diffusionmodules/mmdit.py index b052c48e06a..e70f4431f5a 100644 --- a/comfy/ldm/modules/diffusionmodules/mmdit.py +++ b/comfy/ldm/modules/diffusionmodules/mmdit.py @@ -71,45 +71,33 @@ def __init__( strict_img_size: bool = True, dynamic_img_pad: bool = True, padding_mode='circular', + conv3d=False, dtype=None, device=None, operations=None, ): super().__init__() - self.patch_size = (patch_size, patch_size) + try: + len(patch_size) + self.patch_size = patch_size + except: + if conv3d: + self.patch_size = (patch_size, patch_size, patch_size) + else: + self.patch_size = (patch_size, patch_size) self.padding_mode = padding_mode - if img_size is not None: - self.img_size = (img_size, img_size) - self.grid_size = tuple([s // p for s, p in zip(self.img_size, self.patch_size)]) - self.num_patches = self.grid_size[0] * self.grid_size[1] - else: - self.img_size = None - self.grid_size = None - self.num_patches = None # flatten spatial dim and transpose to channels last, kept for bwd compat self.flatten = flatten self.strict_img_size = strict_img_size self.dynamic_img_pad = dynamic_img_pad - - self.proj = operations.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias, dtype=dtype, device=device) + if conv3d: + self.proj = operations.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias, dtype=dtype, device=device) + else: + self.proj = operations.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias, dtype=dtype, device=device) self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() def forward(self, x): - # B, C, H, W = x.shape - # if self.img_size is not None: - # if self.strict_img_size: - # _assert(H == self.img_size[0], f"Input height ({H}) doesn't match model ({self.img_size[0]}).") - # _assert(W == self.img_size[1], f"Input width ({W}) doesn't match model ({self.img_size[1]}).") - # elif not self.dynamic_img_pad: - # _assert( - # H % self.patch_size[0] == 0, - # f"Input height ({H}) should be divisible by patch size ({self.patch_size[0]})." - # ) - # _assert( - # W % self.patch_size[1] == 0, - # f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]})." - # ) if self.dynamic_img_pad: x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size, padding_mode=self.padding_mode) x = self.proj(x)