Skip to content

Commit

Permalink
Add a grace period between detecting a change and triggering generati…
Browse files Browse the repository at this point in the history
…on in live preview

* This will prevent some of the "useless" generations, e.g. from the very start of
  the brush stroke
* Period is configurable in settings; setting the default to 0 to
  preserve the existing behaviour

This at least partially addresses/follows the discussions in Acly#628 and Acly#1248
  • Loading branch information
modelflat committed Nov 26, 2024
1 parent e138bd5 commit 374b961
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 8 deletions.
28 changes: 20 additions & 8 deletions ai_diffusion/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ def estimate_cost(self, kind=JobKind.diffusion):
def generate_live(self):
eventloop.run(_report_errors(self, self._generate_live()))

async def _generate_live(self, last_input: WorkflowInput | None = None):
def _prepare_live_job_params(self):
strength = self.live.strength
workflow_kind = WorkflowKind.generate if strength == 1.0 else WorkflowKind.refine
client = self._connection.client
Expand Down Expand Up @@ -346,12 +346,15 @@ async def _generate_live(self, last_input: WorkflowInput | None = None):
inpaint=inpaint if mask else None,
is_live=True,
)
params = JobParams(bounds, conditioning.positive, regions=job_regions)
return input, params

async def _generate_live(self, last_input: WorkflowInput | None = None):
input, job_params = self._prepare_live_job_params()
if input != last_input:
self.clear_error()
params = JobParams(bounds, conditioning.positive, regions=job_regions)
await self.enqueue_jobs(input, JobKind.live_preview, params)
await self.enqueue_jobs(input, JobKind.live_preview, job_params)
return input

return None

async def _generate_custom(self, previous_input: WorkflowInput | None):
Expand Down Expand Up @@ -872,12 +875,21 @@ def handle_job_finished(self, job: Job):
eventloop.run(_report_errors(self._model, self._continue_generating()))

async def _continue_generating(self):
just_got_here = True
while self.is_active and self._model.document.is_active:
new_input = await self._model._generate_live(self._last_input)
if new_input is not None: # frame was scheduled
self._last_input = new_input
return
new_input, _ = self._model._prepare_live_job_params()
if self._last_input != new_input:
if settings.live_redraw_grace_period > 0 and not just_got_here:
# only use grace period if this isn't our first frame of polling
# if it is, and there are changes in the input, it's likely that we have some changes we ignored
# previously due to the generation process running, and we need to update the preview asap
await asyncio.sleep(settings.live_redraw_grace_period)
new_input = await self._model._generate_live(self._last_input)
if new_input is not None:
self._last_input = new_input
return
# no changes in input data
just_got_here = False
await asyncio.sleep(self._poll_rate)

def apply_result(self, layer_only=False):
Expand Down
7 changes: 7 additions & 0 deletions ai_diffusion/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,13 @@ class Settings(QObject):
_("Pick a new seed after copying the result to the canvas in Live mode"),
)

live_redraw_grace_period: float
_live_redraw_grace_period = Setting(
_("Live: Redraw grace period"),
0.0,
_("How long to delay scheduling the live preview job for after a change is made"),
)

prompt_translation: str
_prompt_translation = Setting(
_("Prompt Translation"),
Expand Down
4 changes: 4 additions & 0 deletions ai_diffusion/ui/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,10 @@ def __init__(self):
self.add("auto_preview", SwitchSetting(S._auto_preview, parent=self))
self.add("show_steps", SwitchSetting(S._show_steps, parent=self))
self.add("new_seed_after_apply", SwitchSetting(S._new_seed_after_apply, parent=self))
self.add(
"live_redraw_grace_period",
SliderSetting(S._live_redraw_grace_period, self, 0.0, 3.0, "{} s"),
)
self.add("debug_dump_workflow", SwitchSetting(S._debug_dump_workflow, parent=self))

languages = [(lang.name, lang.id) for lang in Localization.available]
Expand Down

0 comments on commit 374b961

Please sign in to comment.