Skip to content

Commit

Permalink
Add process_before_every_sampling hook
Browse files Browse the repository at this point in the history
  • Loading branch information
huchenlei committed May 28, 2024
1 parent 1c0a0c4 commit 4cd0225
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 0 deletions.
24 changes: 24 additions & 0 deletions modules/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1325,6 +1325,15 @@ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subs
# here we generate an image normally

x = self.rng.next()
if self.scripts is not None:
self.scripts.process_before_every_sampling(
p=self,
x=x,
noise=x,
c=conditioning,
uc=unconditional_conditioning
)

samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
del x

Expand Down Expand Up @@ -1425,6 +1434,13 @@ def save_intermediate(image, index):

if self.scripts is not None:
self.scripts.before_hr(self)
self.scripts.process_before_every_sampling(
p=self,
x=samples,
noise=noise,
c=self.hr_c,
uc=self.hr_uc,
)

samples = self.sampler.sample_img2img(self, samples, noise, self.hr_c, self.hr_uc, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)

Expand Down Expand Up @@ -1738,6 +1754,14 @@ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subs
self.extra_generation_params["Noise multiplier"] = self.initial_noise_multiplier
x *= self.initial_noise_multiplier

if self.scripts is not None:
self.scripts.process_before_every_sampling(
p=self,
x=self.init_latent,
noise=x,
c=conditioning,
uc=unconditional_conditioning
)
samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)

if self.mask is not None:
Expand Down
15 changes: 15 additions & 0 deletions modules/scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,13 @@ def after_extra_networks_activate(self, p, *args, **kwargs):
"""
pass

def process_before_every_sampling(self, p, *args, **kwargs):
"""
Similar to process(), called before every sampling.
If you use high-res fix, this will be called two times.
"""
pass

def process_batch(self, p, *args, **kwargs):
"""
Same as process(), but called for every batch.
Expand Down Expand Up @@ -826,6 +833,14 @@ def process(self, p):
except Exception:
errors.report(f"Error running process: {script.filename}", exc_info=True)

def process_before_every_sampling(self, p, **kwargs):
for script in self.ordered_scripts('process_before_every_sampling'):
try:
script_args = p.script_args[script.args_from:script.args_to]
script.process_before_every_sampling(p, *script_args, **kwargs)
except Exception:
errors.report(f"Error running process_before_every_sampling: {script.filename}", exc_info=True)

def before_process_batch(self, p, **kwargs):
for script in self.ordered_scripts('before_process_batch'):
try:
Expand Down

0 comments on commit 4cd0225

Please sign in to comment.