Add a way to patch blocks in SD3.
comfyanonymous committed Oct 29, 2024
1 parent 13b0ff8 commit 30c0c81
Showing 1 changed file with 23 additions and 8 deletions.
31 changes: 23 additions & 8 deletions comfy/ldm/modules/diffusionmodules/mmdit.py
@@ -949,7 +949,9 @@ def forward_core_with_concat(
         c_mod: torch.Tensor,
         context: Optional[torch.Tensor] = None,
         control = None,
+        transformer_options = {},
     ) -> torch.Tensor:
+        patches_replace = transformer_options.get("patches_replace", {})
         if self.register_length > 0:
             context = torch.cat(
                 (
@@ -961,14 +963,25 @@ def forward_core_with_concat(
 
         # context is B, L', D
         # x is B, L, D
+        blocks_replace = patches_replace.get("dit", {})
         blocks = len(self.joint_blocks)
         for i in range(blocks):
-            context, x = self.joint_blocks[i](
-                context,
-                x,
-                c=c_mod,
-                use_checkpoint=self.use_checkpoint,
-            )
+            if ("double_block", i) in blocks_replace:
+                def block_wrap(args):
+                    out = {}
+                    out["txt"], out["img"] = self.joint_blocks[i](args["txt"], args["img"], c=args["vec"])
+                    return out
+
+                out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": c_mod}, {"original_block": block_wrap})
+                context = out["txt"]
+                x = out["img"]
+            else:
+                context, x = self.joint_blocks[i](
+                    context,
+                    x,
+                    c=c_mod,
+                    use_checkpoint=self.use_checkpoint,
+                )
             if control is not None:
                 control_o = control.get("output")
                 if i < len(control_o):
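For reference, an entry registered under ("double_block", i) is a callable invoked exactly as in the hunk above: it receives a dict carrying the current "img" (x), "txt" (context) and "vec" (c_mod) tensors plus a dict holding "original_block", and must return a dict with new "txt" and "img" tensors. A minimal sketch of such a patch (the name scale_img_patch is illustrative, not part of this commit):

def scale_img_patch(args, extra):
    # Run the unpatched joint block through the block_wrap closure created in the loop above.
    out = extra["original_block"](args)
    # out["img"] holds the image tokens, out["txt"] the text/context tokens.
    out["img"] = out["img"] * 1.05  # illustrative tweak of the image stream
    return out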
@@ -986,6 +999,7 @@ def forward(
         y: Optional[torch.Tensor] = None,
         context: Optional[torch.Tensor] = None,
         control = None,
+        transformer_options = {},
     ) -> torch.Tensor:
         """
         Forward pass of DiT.
@@ -1007,7 +1021,7 @@ def forward(
         if context is not None:
             context = self.context_embedder(context)
 
-        x = self.forward_core_with_concat(x, c, context, control)
+        x = self.forward_core_with_concat(x, c, context, control, transformer_options)
 
         x = self.unpatchify(x, hw=hw) # (N, out_channels, H, W)
         return x[:,:,:hw[-2],:hw[-1]]
@@ -1021,7 +1035,8 @@ def forward(
         context: Optional[torch.Tensor] = None,
         y: Optional[torch.Tensor] = None,
         control = None,
+        transformer_options = {},
         **kwargs,
     ) -> torch.Tensor:
-        return super().forward(x, timesteps, context=context, y=y, control=control)
+        return super().forward(x, timesteps, context=context, y=y, control=control, transformer_options=transformer_options)
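With transformer_options now threaded through both forward signatures, a caller can swap out individual joint blocks by populating patches_replace before invoking the model. A hedged sketch, assuming the dict layout read by forward_core_with_concat above (the block index, the scale_img_patch function from the earlier sketch, and the call site are illustrative):

transformer_options = {
    "patches_replace": {
        "dit": {
            # Replace joint block 3; unpatched indices fall through to the original call.
            ("double_block", 3): scale_img_patch,
        }
    }
}

out = model(x, timesteps, context=context, y=y, control=None, transformer_options=transformer_options)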
