Refactor generation backend #4201

Merged Aug 14, 2023 · 70 commits (feat/refactor_generation_backend into main)
Commits
9aaf67c
wip
StAlKeR7779 Aug 6, 2023
b0738b7
Fixes, zero tensor for empty negative prompt, remove raw prompt node
StAlKeR7779 Aug 7, 2023
2539e26
Apply denoising_start/end, add torch-sdp to memory-efficient attenti…
StAlKeR7779 Aug 7, 2023
1db2c93
Fix preview, inpaint
StAlKeR7779 Aug 7, 2023
492bfe0
Remove sdxl t2l/l2l nodes
StAlKeR7779 Aug 8, 2023
5f29526
Add seed to latents field
StAlKeR7779 Aug 8, 2023
96b7248
Add mask to l2l
StAlKeR7779 Aug 8, 2023
da0184a
Invert mask, fix l2l when no mask connected, remove zeroing latents on…
StAlKeR7779 Aug 8, 2023
a7e4467
Remove legacy/unused code
StAlKeR7779 Aug 8, 2023
f7aec3b
Move conditioning class to backend
StAlKeR7779 Aug 8, 2023
b4a74f6
Add MaskEdge and ColorCorrect nodes
StAlKeR7779 Aug 8, 2023
e98f7ed
Fix total_steps in generation event, add order field
StAlKeR7779 Aug 9, 2023
ade78b9
Merge branch 'main' into feat/refactor_generation_backend
StAlKeR7779 Aug 10, 2023
17fed1c
Fix merge conflict errors
StAlKeR7779 Aug 10, 2023
e9ec5ab
Apply requested changes
StAlKeR7779 Aug 10, 2023
231e665
Merge branch 'main' into feat/refactor_generation_backend
blessedcoolant Aug 11, 2023
7c0023a
feat: Remove TextToLatents / Rename LatentsToLatents -> DenoiseLatents
blessedcoolant Aug 11, 2023
87ce4ab
fix: Update default_graph to use new DenoiseLatents
blessedcoolant Aug 11, 2023
7479f9c
feat: Update LinearUI to use new backend (except Inpaint)
blessedcoolant Aug 11, 2023
f3ae52f
Fix error at high denoising_start, fix unipc(cpu_only)
StAlKeR7779 Aug 11, 2023
69a9dc7
wip: Add initial Inpaint Graph
blessedcoolant Aug 11, 2023
1affb7f
feat: Add Paste / Mask Blur / Color Correction to Inpainting
blessedcoolant Aug 11, 2023
5629d8f
fix: Key issue in LoRA list
blessedcoolant Aug 11, 2023
58a48bf
fix: LoRA list name sorting
blessedcoolant Aug 11, 2023
d7d6298
feat: Add Infill Method support
blessedcoolant Aug 11, 2023
f343ab0
wip: Port Outpainting to new backend
blessedcoolant Aug 11, 2023
7293a60
feat(wip): Add SDXL To Canvas
blessedcoolant Aug 11, 2023
8acd7ee
feat: Disable clip skip for SDXL Canvas
blessedcoolant Aug 11, 2023
ce3675f
Apply denoising_start/end according to timestep value
StAlKeR7779 Aug 12, 2023
6034fa1
feat: Add Mask Blur node
blessedcoolant Aug 12, 2023
7254ffc
chore: Split Inpaint and Outpaint Graphs
blessedcoolant Aug 12, 2023
7587b54
chore: Cleanup, comment and organize Node Graphs
blessedcoolant Aug 12, 2023
9f6221f
Merge branch 'main' into feat/refactor_generation_backend
blessedcoolant Aug 12, 2023
f296e5c
wip: Remove MaskBlur / Adjust color correction
blessedcoolant Aug 12, 2023
27bd127
fix: Do not add anything but final output to staging area
blessedcoolant Aug 12, 2023
ad96c41
feat: Add Canvas Output node to all Canvas Graphs
blessedcoolant Aug 12, 2023
746c7c5
fix: remove extra node for canvas output catch
blessedcoolant Aug 12, 2023
55d27f7
feat: Give each graph its own unique id
blessedcoolant Aug 12, 2023
500cd55
feat: Make SDXL work across the board + Custom VAE Support
blessedcoolant Aug 12, 2023
c33acf9
feat: Make Refiner work with Canvas
blessedcoolant Aug 12, 2023
28208e6
fix: Fix VAE Precision not working for SDXL Canvas Modes
blessedcoolant Aug 12, 2023
29f1c6d
fix: Image To Image FP32 Fix for Canvas SDXL
blessedcoolant Aug 12, 2023
fcf7f4a
feat: Add SDXL ControlNet To Linear UI
blessedcoolant Aug 12, 2023
c8864e4
fix: SDXL LoRAs not working on Canvas Image To Image
blessedcoolant Aug 12, 2023
b35cdc0
feat: Add Scaled Processing to Inpainting & Outpainting / 1.x & SDXL
blessedcoolant Aug 13, 2023
33779b6
chore: Remove shouldFitToWidthHeight from Inpaint Graphs
blessedcoolant Aug 13, 2023
3ff9961
fix: Circular dependency in Mask Blur Method
blessedcoolant Aug 13, 2023
561951a
chore: Black linting
blessedcoolant Aug 13, 2023
90fa3ee
feat: Make SDXL Style Prompt not take spaces
blessedcoolant Aug 13, 2023
499e89d
feat: Add SDXL Negative Aesthetic Score
blessedcoolant Aug 13, 2023
746e099
fix: Do not do step math for refinerSteps
blessedcoolant Aug 13, 2023
94636dd
Fix empty prompt handling
StAlKeR7779 Aug 13, 2023
6e0beb1
Fixes for second order scheduler timesteps
StAlKeR7779 Aug 13, 2023
59ba9fc
Flip bits in seed for sde/ancestral schedulers to have different nois…
StAlKeR7779 Aug 13, 2023
7a8f14d
Clean-up code a bit
StAlKeR7779 Aug 13, 2023
096333b
Fix error on zero timesteps
StAlKeR7779 Aug 13, 2023
d63bb39
Make dpmpp_sde(_k) use a non-random seed
StAlKeR7779 Aug 13, 2023
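
The two scheduler-seeding commits above (59ba9fc and d63bb39) both concern noise reproducibility: ancestral and SDE schedulers inject fresh noise at every sampling step, so the PR derives that noise from a bit-flipped copy of the user seed to keep it distinct from the initial latent noise, and gives dpmpp_sde a fixed rather than random seed so runs are repeatable. A minimal sketch of the idea; the XOR mask and generator wiring here are illustrative, not the PR's exact code:

import torch

seed = 1234
init_generator = torch.Generator().manual_seed(seed)  # seeds the initial latent noise
# Flip every bit of the 32-bit seed so the scheduler's in-sampling noise
# differs from, but is still fully determined by, the user's seed.
flipped_seed = seed ^ 0xFFFFFFFF
scheduler_generator = torch.Generator().manual_seed(flipped_seed)
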
75fb3f4
re: Re-add refiner step math but cap max steps at 1000
blessedcoolant Aug 13, 2023
cc85c98
feat: Upgrade Diffusers to 0.19.3
blessedcoolant Aug 13, 2023
550e6ef
re: Set the image denoise strength back to 0
blessedcoolant Aug 13, 2023
fecad2c
fix: SDXL Denoising Strength not plugged in correctly
blessedcoolant Aug 13, 2023
957ee6d
fix: SDXL Canvas Inpaint & Outpaint not respecting SDXL Refiner start…
blessedcoolant Aug 14, 2023
3d8da67
Remove callback-generator wrapper
StAlKeR7779 Aug 14, 2023
58d5c61
fix: SDXL Inpaint & Outpaint using regular Img2Img strength
blessedcoolant Aug 14, 2023
409e5d0
Fix cpu_only schedulers(unipc)
StAlKeR7779 Aug 14, 2023
511da59
Add magic to debug
StAlKeR7779 Aug 14, 2023
9217a21
fix(ui): refiner uses steps directly, no math
psychedelicious Aug 14, 2023
9fee3f7
Revert "Add magic to debug"
psychedelicious Aug 14, 2023
46a8eed
Merge branch 'main' into feat/refactor_generation_backend
psychedelicious Aug 14, 2023
9d3cd85
chore: black
psychedelicious Aug 14, 2023
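
Taken together, the denoising_start/denoising_end commits (2539e26, ce3675f) are what let a single DenoiseLatents node cover txt2img, img2img, inpainting, and the SDXL base-to-refiner handoff: each node runs only the slice of the noise schedule that falls inside its window. A minimal sketch of that mapping, with illustrative names and a simplified linear schedule rather than the PR's actual implementation:

import torch

def slice_timesteps(timesteps: torch.Tensor, denoising_start: float, denoising_end: float) -> torch.Tensor:
    """Keep the timesteps inside [denoising_start, denoising_end).

    Diffusion timesteps count down from t_max (pure noise) toward 0 (clean image),
    so a fraction f of the way through denoising corresponds to t = t_max * (1 - f).
    """
    t_max = timesteps.max().item()
    t_hi = t_max * (1.0 - denoising_start)  # first timestep this node still runs
    t_lo = t_max * (1.0 - denoising_end)    # everything at or below is left for the next node
    return timesteps[(timesteps <= t_hi) & (timesteps > t_lo)]

timesteps = torch.arange(999, -1, -50)                 # a coarse 20-step schedule
base_steps = slice_timesteps(timesteps, 0.0, 0.8)      # SDXL base: first 80% of denoising
refiner_steps = slice_timesteps(timesteps, 0.8, 1.0)   # refiner: final 20%
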
241 changes: 32 additions & 209 deletions invokeai/app/invocations/compel.py
@@ -16,7 +16,7 @@
 from ...backend.model_management import ModelType
 from ...backend.model_management.models import ModelNotFoundException
 from ...backend.model_management.lora import ModelPatcher
-from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
+from ...backend.stable_diffusion import InvokeAIDiffuserComponent, BasicConditioningInfo, SDXLConditioningInfo
 from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
 from .model import ClipField
 from dataclasses import dataclass
@@ -29,28 +29,9 @@ class Config:
         schema_extra = {"required": ["conditioning_name"]}
 
 
-@dataclass
-class BasicConditioningInfo:
-    # type: Literal["basic_conditioning"] = "basic_conditioning"
-    embeds: torch.Tensor
-    extra_conditioning: Optional[InvokeAIDiffuserComponent.ExtraConditioningInfo]
-    # weight: float
-    # mode: ConditioningAlgo
-
-
-@dataclass
-class SDXLConditioningInfo(BasicConditioningInfo):
-    # type: Literal["sdxl_conditioning"] = "sdxl_conditioning"
-    pooled_embeds: torch.Tensor
-    add_time_ids: torch.Tensor
-
-
-ConditioningInfoType = Annotated[Union[BasicConditioningInfo, SDXLConditioningInfo], Field(discriminator="type")]
-
-
 @dataclass
 class ConditioningFieldData:
-    conditionings: List[Union[BasicConditioningInfo, SDXLConditioningInfo]]
+    conditionings: List[BasicConditioningInfo]
     # unconditioned: Optional[torch.Tensor]
 
 
@@ -176,7 +157,15 @@ def _lora_loader():
 
 
 class SDXLPromptInvocationBase:
-    def run_clip_raw(self, context, clip_field, prompt, get_pooled, lora_prefix):
+    def run_clip_compel(
+        self,
+        context: InvocationContext,
+        clip_field: ClipField,
+        prompt: str,
+        get_pooled: bool,
+        lora_prefix: str,
+        zero_on_empty: bool,
+    ):
         tokenizer_info = context.services.model_manager.get_model(
             **clip_field.tokenizer.dict(),
             context=context,
@@ -186,82 +175,21 @@ def run_clip_raw(self, context, clip_field, prompt, get_pooled, lora_prefix):
             context=context,
         )
 
-        def _lora_loader():
-            for lora in clip_field.loras:
-                lora_info = context.services.model_manager.get_model(**lora.dict(exclude={"weight"}), context=context)
-                yield (lora_info.context.model, lora.weight)
-                del lora_info
-            return
-
-        # loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
-
-        ti_list = []
-        for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", prompt):
-            name = trigger[1:-1]
-            try:
-                ti_list.append(
-                    (
-                        name,
-                        context.services.model_manager.get_model(
-                            model_name=name,
-                            base_model=clip_field.text_encoder.base_model,
-                            model_type=ModelType.TextualInversion,
-                            context=context,
-                        ).context.model,
-                    )
-                )
-            except ModelNotFoundException:
-                # print(e)
-                # import traceback
-                # print(traceback.format_exc())
-                print(f'Warn: trigger: "{trigger}" not found')
-
-        with ModelPatcher.apply_lora(
-            text_encoder_info.context.model, _lora_loader(), lora_prefix
-        ), ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
-            tokenizer,
-            ti_manager,
-        ), ModelPatcher.apply_clip_skip(
-            text_encoder_info.context.model, clip_field.skipped_layers
-        ), text_encoder_info as text_encoder:
-            text_inputs = tokenizer(
-                prompt,
-                padding="max_length",
-                max_length=tokenizer.model_max_length,
-                truncation=True,
-                return_tensors="pt",
-            )
-            text_input_ids = text_inputs.input_ids
-            prompt_embeds = text_encoder(
-                text_input_ids.to(text_encoder.device),
-                output_hidden_states=True,
+        # return zero on empty
+        if prompt == "" and zero_on_empty:
+            cpu_text_encoder = text_encoder_info.context.model
+            c = torch.zeros(
+                (1, cpu_text_encoder.config.max_position_embeddings, cpu_text_encoder.config.hidden_size),
+                dtype=text_encoder_info.context.cache.precision,
             )
             if get_pooled:
-                c_pooled = prompt_embeds[0]
+                c_pooled = torch.zeros(
+                    (1, cpu_text_encoder.config.hidden_size),
+                    dtype=c.dtype,
+                )
             else:
                 c_pooled = None
-            c = prompt_embeds.hidden_states[-2]
-
-        del tokenizer
-        del text_encoder
-        del tokenizer_info
-        del text_encoder_info
-
-        c = c.detach().to("cpu")
-        if c_pooled is not None:
-            c_pooled = c_pooled.detach().to("cpu")
-
-        return c, c_pooled, None
-
-    def run_clip_compel(self, context, clip_field, prompt, get_pooled, lora_prefix):
-        tokenizer_info = context.services.model_manager.get_model(
-            **clip_field.tokenizer.dict(),
-            context=context,
-        )
-        text_encoder_info = context.services.model_manager.get_model(
-            **clip_field.text_encoder.dict(),
-            context=context,
-        )
+            return c, c_pooled, None
 
         def _lora_loader():
             for lora in clip_field.loras:
@@ -366,11 +294,17 @@ class Config(InvocationConfig):
 
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> CompelOutput:
-        c1, c1_pooled, ec1 = self.run_clip_compel(context, self.clip, self.prompt, False, "lora_te1_")
+        c1, c1_pooled, ec1 = self.run_clip_compel(
+            context, self.clip, self.prompt, False, "lora_te1_", zero_on_empty=True
+        )
         if self.style.strip() == "":
-            c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.prompt, True, "lora_te2_")
+            c2, c2_pooled, ec2 = self.run_clip_compel(
+                context, self.clip2, self.prompt, True, "lora_te2_", zero_on_empty=True
+            )
         else:
-            c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True, "lora_te2_")
+            c2, c2_pooled, ec2 = self.run_clip_compel(
+                context, self.clip2, self.style, True, "lora_te2_", zero_on_empty=True
+            )
 
         original_size = (self.original_height, self.original_width)
         crop_coords = (self.crop_top, self.crop_left)
@@ -425,118 +359,7 @@ class Config(InvocationConfig):
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> CompelOutput:
         # TODO: if there will appear lora for refiner - write proper prefix
-        c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True, "<NONE>")
-
-        original_size = (self.original_height, self.original_width)
-        crop_coords = (self.crop_top, self.crop_left)
-
-        add_time_ids = torch.tensor([original_size + crop_coords + (self.aesthetic_score,)])
-
-        conditioning_data = ConditioningFieldData(
-            conditionings=[
-                SDXLConditioningInfo(
-                    embeds=c2,
-                    pooled_embeds=c2_pooled,
-                    add_time_ids=add_time_ids,
-                    extra_conditioning=ec2,  # or None
-                )
-            ]
-        )
-
-        conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
-        context.services.latents.save(conditioning_name, conditioning_data)
-
-        return CompelOutput(
-            conditioning=ConditioningField(
-                conditioning_name=conditioning_name,
-            ),
-        )
-
-
-class SDXLRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
-    """Pass unmodified prompt to conditioning without compel processing."""
-
-    type: Literal["sdxl_raw_prompt"] = "sdxl_raw_prompt"
-
-    prompt: str = Field(default="", description="Prompt")
-    style: str = Field(default="", description="Style prompt")
-    original_width: int = Field(1024, description="")
-    original_height: int = Field(1024, description="")
-    crop_top: int = Field(0, description="")
-    crop_left: int = Field(0, description="")
-    target_width: int = Field(1024, description="")
-    target_height: int = Field(1024, description="")
-    clip: ClipField = Field(None, description="Clip to use")
-    clip2: ClipField = Field(None, description="Clip2 to use")
-
-    # Schema customisation
-    class Config(InvocationConfig):
-        schema_extra = {
-            "ui": {"title": "SDXL Prompt (Raw)", "tags": ["prompt", "compel"], "type_hints": {"model": "model"}},
-        }
-
-    @torch.no_grad()
-    def invoke(self, context: InvocationContext) -> CompelOutput:
-        c1, c1_pooled, ec1 = self.run_clip_raw(context, self.clip, self.prompt, False, "lora_te1_")
-        if self.style.strip() == "":
-            c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.prompt, True, "lora_te2_")
-        else:
-            c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.style, True, "lora_te2_")
-
-        original_size = (self.original_height, self.original_width)
-        crop_coords = (self.crop_top, self.crop_left)
-        target_size = (self.target_height, self.target_width)
-
-        add_time_ids = torch.tensor([original_size + crop_coords + target_size])
-
-        conditioning_data = ConditioningFieldData(
-            conditionings=[
-                SDXLConditioningInfo(
-                    embeds=torch.cat([c1, c2], dim=-1),
-                    pooled_embeds=c2_pooled,
-                    add_time_ids=add_time_ids,
-                    extra_conditioning=ec1,
-                )
-            ]
-        )
-
-        conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
-        context.services.latents.save(conditioning_name, conditioning_data)
-
-        return CompelOutput(
-            conditioning=ConditioningField(
-                conditioning_name=conditioning_name,
-            ),
-        )
-
-
-class SDXLRefinerRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
-    """Parse prompt using compel package to conditioning."""
-
-    type: Literal["sdxl_refiner_raw_prompt"] = "sdxl_refiner_raw_prompt"
-
-    style: str = Field(default="", description="Style prompt")  # TODO: ?
-    original_width: int = Field(1024, description="")
-    original_height: int = Field(1024, description="")
-    crop_top: int = Field(0, description="")
-    crop_left: int = Field(0, description="")
-    aesthetic_score: float = Field(6.0, description="")
-    clip2: ClipField = Field(None, description="Clip to use")
-
-    # Schema customisation
-    class Config(InvocationConfig):
-        schema_extra = {
-            "ui": {
-                "title": "SDXL Refiner Prompt (Raw)",
-                "tags": ["prompt", "compel"],
-                "type_hints": {"model": "model"},
-            },
-        }
-
-    @torch.no_grad()
-    def invoke(self, context: InvocationContext) -> CompelOutput:
-        # TODO: if there will appear lora for refiner - write proper prefix
-        c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.style, True, "<NONE>")
+        c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True, "<NONE>", zero_on_empty=False)
 
         original_size = (self.original_height, self.original_width)
         crop_coords = (self.crop_top, self.crop_left)
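
The zero_on_empty branch shown in the diff above short-circuits the text encoder for empty prompts. This matters because encoding "" still produces non-zero embeddings (the BOS/EOS tokens are real tokens), so returning true zero tensors is what makes an unset negative prompt genuinely neutral; this is the "zero tensor for empty negative prompt" behavior from commit b0738b7. A standalone sketch of the shape logic, using illustrative SDXL-like dimensions where the real code reads max_position_embeddings and hidden_size from the encoder config:

import torch

def zero_conditioning(max_tokens: int = 77, hidden_size: int = 1280,
                      get_pooled: bool = True, dtype: torch.dtype = torch.float16):
    # Per-token embeddings: the same (batch, tokens, channels) shape the encoder would emit.
    c = torch.zeros((1, max_tokens, hidden_size), dtype=dtype)
    # A pooled embedding is only produced for the second SDXL text encoder.
    c_pooled = torch.zeros((1, hidden_size), dtype=dtype) if get_pooled else None
    return c, c_pooled

c, c_pooled = zero_conditioning()
assert c.shape == (1, 77, 1280) and c_pooled.shape == (1, 1280)
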
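The removed invocations above also document how SDXL's add_time_ids micro-conditioning is packed: the base model concatenates original size, crop coordinates, and target size into one tensor, while the refiner replaces target size with a single aesthetic score (compare the "Add SDXL Negative Aesthetic Score" commit, 499e89d). A sketch with illustrative default values:

import torch

original_size = (1024, 1024)  # (height, width) the image is treated as coming from
crop_coords = (0, 0)          # (top, left) crop offset into that original
target_size = (1024, 1024)    # (height, width) actually being generated
aesthetic_score = 6.0

base_time_ids = torch.tensor([original_size + crop_coords + target_size])            # shape (1, 6)
refiner_time_ids = torch.tensor([original_size + crop_coords + (aesthetic_score,)])  # shape (1, 5)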