Skip to content

Commit

Permalink
Processor for Stable Diffusion image generation
Browse files Browse the repository at this point in the history
  • Loading branch information
stijn-uva committed Nov 10, 2023
1 parent b815c54 commit 1e1b136
Show file tree
Hide file tree
Showing 2 changed files with 266 additions and 1 deletion.
5 changes: 4 additions & 1 deletion common/lib/dmi_service_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def check_progress(self):
self.processor.dataset.update_progress(current_completed / self.num_files_to_process)
self.processed_files = current_completed

def send_request_and_wait_for_results(self, service_endpoint, data, wait_period=60, check_process=True):
def send_request_and_wait_for_results(self, service_endpoint, data, wait_period=60, check_process=True, callback=None):
"""
Send request and wait for results to be ready.
Expand Down Expand Up @@ -164,6 +164,9 @@ def send_request_and_wait_for_results(self, service_endpoint, data, wait_period=
if check_process:
self.check_progress()

if callback:
callback(self)

result = requests.get(results_url, timeout=30)
if 'status' in result.json().keys() and result.json()['status'] == 'running':
# Still running
Expand Down
262 changes: 262 additions & 0 deletions processors/visualisation/generate_images.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,262 @@
"""
Generate images with Stable Diffusion
Why? Because we can.
"""
import shutil
import json
import re
import os

from backend.lib.processor import BasicProcessor
from common.lib.dmi_service_manager import DmiServiceManager, DmiServiceManagerException, DsmOutOfMemory
from common.lib.exceptions import ProcessorInterruptedException
from common.lib.user_input import UserInput
from common.config_manager import config

__author__ = "Stijn Peeters"
__credits__ = ["Stijn Peeters"]
__maintainer__ = "Stijn Peeters"
__email__ = "4cat@oilab.eu"


class CategorizeImagesCLIP(BasicProcessor):
"""
Generate images with Stable Diffusion
"""
type = "text-to-images" # job type ID
category = "Visual" # category
title = "Generate images from text prompts" # title displayed in UI
description = "Given a list of prompts, generates images for those prompts using the Stable Diffusion XL image model." # description displayed in UI
extension = "zip" # extension of result file, used internally and in UI

references = [
"[Stable Diffusion XL 1.0 model card](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)"
]

config = {
"dmi-service-manager.sd_intro-1": {
"type": UserInput.OPTION_INFO,
"help": "StabilityAI's Stable Diffusion model can generate images for a given text prompt. Ensure the DMI "
"Service Manager is running and has a [prebuilt SD XL "
"image](https://github.com/digitalmethodsinitiative/dmi_dockerized_services/tree/main/stable_diffusion).",
},
"dmi-service-manager.sd_enabled": {
"type": UserInput.OPTION_TOGGLE,
"default": False,
"help": "Enable Stable Diffusion image generation",
},
"dmi-service-manager.sd_num_files": {
"type": UserInput.OPTION_TEXT,
"coerce_type": int,
"default": 0,
"help": "Max images to generate with Stable Diffusion",
"tooltip": "Use '0' to allow unlimited number"
},
}

options = {
"prompt-column": {
"type": UserInput.OPTION_TEXT,
"default": False,
"help": "Dataset field containing prompt",
"tooltip": "Prompts will be truncated to 70 characters"
},
"negative-prompt-column": {
"type": UserInput.OPTION_TEXT,
"default": False,
"help": "Dataset field containing negative prompt",
"tooltip": "The model will try to avoid generating an image that fits the negative prompt. Prompts will be "
"truncated to 70 characters"
}
}

@classmethod
def get_options(cls, parent_dataset=None, user=None):
"""
Get processor options
These are dynamic for this processor: the 'column names' option is
populated with the column names from the parent dataset, if available.
:param DataSet parent_dataset: Parent dataset
:param user: Flask User to which the options are shown, if applicable
:return dict: Processor options
"""
options = cls.options
if parent_dataset is None:
return options

parent_columns = parent_dataset.get_columns()

if parent_columns:
parent_columns = {c: c for c in sorted(parent_columns)}
options["prompt-column"].update({
"type": UserInput.OPTION_CHOICE,
"options": parent_columns,
})

# negative prompt optional, so add empty option
options["negative-prompt-column"].update({
"type": UserInput.OPTION_CHOICE,
"options": {"": "", **parent_columns},
"default": ""
})

return options

@classmethod
def is_compatible_with(cls, module=None, user=None):
"""
Allow on datasets with columns (from which a prompt can be retrieved)
"""
return config.get("dmi-service-manager.sd_enabled", False, user=user) and \
config.get("dmi-service-manager.ab_server_address", False, user=user) and \
module.get_columns()

@staticmethod
def get_progress_callback(num_prompts):
"""
Callback for updating dataset progress
This generates a function that is periodically called by the service
manager to update the progress of the query. It sets the progress and
updates the dataset status.
:param int num_prompts: Total number of prompts that will be processed.
:return: Callback
"""

def callback(manager):
if manager.local_or_remote == "local":
current_completed = manager.count_local_files(
manager.processor.config.get("PATH_DATA").joinpath(manager.path_to_results))
elif manager.local_or_remote == "remote":
existing_files = manager.request_folder_files(manager.server_file_collection_name)
current_completed = len(existing_files.get(manager.server_results_folder_name, []))

manager.processor.dataset.update_status(
f"Generated images for {current_completed:,} of {num_prompts:,} prompt(s)")
manager.processor.dataset.update_progress(current_completed / num_prompts)

return callback

def process(self):
"""
This takes a dataset and generates images for prompts retrieved from that dataset
"""
prompts = {}
max_prompts = self.config.get("dmi-service-manager.sd_num_files")

prompt_c = self.parameters["prompt-column"]
neg_c = self.parameters.get("negative-prompt-column")
for item in self.source_dataset.iterate_items():
if max_prompts and len(prompts) >= max_prompts:
break

prompts[item.get("id", len(prompts) + 1)] = {
"prompt": item.get(prompt_c, ""),
"negative": item.get(neg_c) if item.get(neg_c) is not None else ""
}

if not any([p["prompt"] for p in prompts.values()]):
return self.dataset.finish_with_error(
f"No prompts found in dataset's '{prompt_c}' field. Use a different field and try again.")

# Make output dir
staging_area = self.dataset.get_staging_area()
output_dir = self.dataset.get_staging_area()

# Initialize DMI Service Manager
dmi_service_manager = DmiServiceManager(processor=self)

# Check GPU memory available
try:
gpu_memory, info = dmi_service_manager.check_gpu_memory_available("stable_diffusion")
except DmiServiceManagerException as e:
return self.dataset.finish_with_error(str(e))
staging_area.unlink()
output_dir.unlink()

if not gpu_memory:
if info.get("reason") == "GPU not enabled on this instance of DMI Service Manager":
self.dataset.update_status("DMI Service Manager GPU not enabled; using CPU")
elif int(info.get("memory", {}).get("gpu_free_mem", 0)) < 1000000:
return self.dataset.finish_with_error(
"DMI Service Manager currently too busy; no GPU memory available. Please try again later.")
shutil.rmtree(staging_area)
shutil.rmtree(output_dir)

# Results should be unique to this dataset
results_folder_name = f"images_{self.dataset.key}"
file_collection_name = dmi_service_manager.get_folder_name(self.dataset)

# the prompts may be lengthy, so dump them in a file rather than
# relying on passing them as a command line argument
prompts_file = staging_area.joinpath("prompts.temp.json")
with prompts_file.open("w") as outfile:
json.dump(prompts, outfile)

path_to_files, path_to_results = dmi_service_manager.process_files(staging_area, [prompts_file.name],
output_dir, file_collection_name,
results_folder_name)

# interface.py args
data = {"args": ['--output-dir', f"data/{path_to_results}",
"--prompts-file",
f"data/{path_to_files.joinpath(dmi_service_manager.sanitize_filenames(prompts_file.name))}"]
}

# Send request to DMI Service Manager
self.dataset.update_status(f"Requesting service from DMI Service Manager...")
api_endpoint = "stable_diffusion"

try:
dmi_service_manager.send_request_and_wait_for_results(api_endpoint, data, wait_period=5, check_process=None,
callback=self.get_progress_callback(len(prompts)))
except DsmOutOfMemory:
shutil.rmtree(staging_area)
shutil.rmtree(output_dir)
return self.dataset.finish_with_error(
"DMI Service Manager ran out of memory; Try decreasing the number of prompts or try again or try again later.")
except DmiServiceManagerException as e:
shutil.rmtree(staging_area)
shutil.rmtree(output_dir)
return self.dataset.finish_with_error(str(e))

# Download the result files
self.dataset.update_status("Processing generated images...")
dmi_service_manager.process_results(output_dir)

# Output folder is basically already ready for archiving
# Add a metadata JSON file before that though
def make_filename(id, prompt):
"""
Generate filename for generated image
Should mirror the make_filename method in interface.py in the SD Docker.
:param prompt_id: Unique identifier, eg `54`
:param str prompt: Text prompt, will be sanitised, e.g. `Rasta Bill Gates`
:return str: For example, `54-rasta-bill-gates.jpeg`
"""
safe_prompt = re.sub(r"[^a-zA-Z0-9 _-]", "", prompt).replace(" ", "-").lower()[:90]
return f"{id}-{safe_prompt}.jpeg"

self.dataset.update_status("Verifying results")
with output_dir.joinpath(".metadata.json").open("w") as outfile:
results = [{
"id": prompt_id,
"prompt": data["prompt"],
"negative-prompt": data["negative"],
"filename": make_filename(prompt_id, data["prompt"]),
"success": output_dir.joinpath(make_filename(prompt_id, data["prompt"])).exists()
} for prompt_id, data in prompts.items()]
json.dump(results, outfile)

shutil.rmtree(staging_area)

self.dataset.update_status(
f"Generated {len([r for r in results if r['success']]):,} image(s) for {len(prompts):,} prompt(s)",
is_final=True)
self.write_archive_and_finish(output_dir, num_items=len([r for r in results if r['success']]))

0 comments on commit 1e1b136

Please sign in to comment.