Skip to content

Commit

Permalink
Allow ctrl c when using --from_file (#472)
Browse files Browse the repository at this point in the history
* added ansi escapes to highlight key parts of CLI session

* adjust exception handling so that ^C will abort when reading prompts from a file
  • Loading branch information
lstein authored Sep 9, 2022
1 parent 75f633c commit 723d074
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 6 deletions.
6 changes: 5 additions & 1 deletion ldm/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ def __init__(
seamless = False,
embedding_path = None,
device_type = 'cuda',
ignore_ctrl_c = False,
):
self.iterations = iterations
self.width = width
Expand All @@ -134,6 +135,7 @@ def __init__(
self.seamless = seamless
self.embedding_path = embedding_path
self.device_type = device_type
self.ignore_ctrl_c = ignore_ctrl_c # note, this logic probably doesn't belong here...
self.model = None # empty for now
self.sampler = None
self.device = None
Expand Down Expand Up @@ -210,7 +212,7 @@ def prompt2image(
**args,
): # eat up additional cruft
"""
ldm.prompt2image() is the common entry point for txt2img() and img2img()
ldm.generate.prompt2image() is the common entry point for txt2img() and img2img()
It takes the following arguments:
prompt // prompt string (no default)
iterations // iterations (1); image count=iterations
Expand Down Expand Up @@ -341,6 +343,8 @@ def process_image(image,seed):

except KeyboardInterrupt:
print('*interrupted*')
if not self.ignore_ctrl_c:
raise KeyboardInterrupt
print(
'>> Partial results will be returned; if --grid was requested, nothing will be returned.'
)
Expand Down
21 changes: 16 additions & 5 deletions scripts/dream.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@
from ldm.dream.image_util import make_grid
from omegaconf import OmegaConf

# Placeholder to be replaced with proper class that tracks the
# outputs and associates with the prompt that generated them.
# Just want to get the formatting look right for now.
output_cntr = 0

def main():
"""Initialize command-line parsers and the diffusion model"""
arg_parser = create_argv_parser()
Expand Down Expand Up @@ -63,7 +68,8 @@ def main():
# this is solely for recreating the prompt
seamless = opt.seamless,
embedding_path = opt.embedding_path,
device_type = opt.device
device_type = opt.device,
ignore_ctrl_c = opt.infile is None,
)

# make sure the output directory exists
Expand Down Expand Up @@ -292,16 +298,18 @@ def image_writer(image, seed, upscaled=False):
print(e)
continue

print('Outputs:')
print('\033[1mOutputs:\033[0m')
log_path = os.path.join(current_outdir, 'dream_log.txt')
write_log_message(results, log_path)

print('goodbye!')
print('goodbye!\033[0m')


def get_next_command(infile=None) -> str: #command string
if infile is None:
command = input('dream> ')
print('\033[1m') # add some boldface
command = input('dream> ')
print('\033[0m',end='')
else:
command = infile.readline()
if not command:
Expand Down Expand Up @@ -339,8 +347,11 @@ def dream_server_loop(t2i, host, port, outdir):

def write_log_message(results, log_path):
"""logs the name of the output image, prompt, and prompt args to the terminal and log file"""
global output_cntr
log_lines = [f'{path}: {prompt}\n' for path, prompt in results]
print(*log_lines, sep='')
for l in log_lines:
output_cntr += 1
print(f'\033[1m[{output_cntr}]\033[0m {l}',end='')

with open(log_path, 'a', encoding='utf-8') as file:
file.writelines(log_lines)
Expand Down

1 comment on commit 723d074

@blessedcoolant
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This commit is overriding terminal colors on the print statements by appending \033[1m .. This needs to be fixed.

Please sign in to comment.