Skip to content

Commit

Permalink
feat(dream/server): add CORS headers
Browse files Browse the repository at this point in the history
  • Loading branch information
mgcrea committed Sep 10, 2022
1 parent bfb2781 commit b792915
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 6 deletions.
23 changes: 22 additions & 1 deletion ldm/dream/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,22 @@ class CanceledException(Exception):
class DreamServer(BaseHTTPRequestHandler):
model = None
outdir = None
cors = None
canceled = Event()

# CORS support
def send_cors_headers(self):
if (self.cors):
self.send_header("Access-Control-Allow-Origin", self.cors)
request_headers = self.headers.get('Access-Control-Request-Headers');
if (request_headers):
self.send_header("Access-Control-Allow-Headers", request_headers)

def do_GET(self):
if self.path == "/":
self.send_response(200)
self.send_header("Content-type", "text/html")
self.send_cors_headers();
self.end_headers()
with open("./static/dream_web/index.html", "rb") as content:
self.wfile.write(content.read())
Expand All @@ -74,6 +84,7 @@ def do_GET(self):
from ldm.gfpgan.gfpgan_tools import gfpgan_model_exists
self.send_response(200)
self.send_header("Content-type", "application/javascript")
self.send_cors_headers();
self.end_headers()
config = {
'gfpgan_model_exists': gfpgan_model_exists
Expand All @@ -84,7 +95,7 @@ def do_GET(self):
self.send_header("Content-type", "application/json")
self.end_headers()
output = []

log_file = os.path.join(self.outdir, "dream_web_log.txt")
if os.path.exists(log_file):
with open(log_file, "r") as log:
Expand All @@ -100,6 +111,7 @@ def do_GET(self):
self.canceled.set()
self.send_response(200)
self.send_header("Content-type", "application/json")
self.send_cors_headers();
self.end_headers()
self.wfile.write(bytes('{}', 'utf8'))
else:
Expand All @@ -113,6 +125,7 @@ def do_GET(self):
if mime_type is not None:
self.send_response(200)
self.send_header("Content-type", mime_type)
self.send_cors_headers();
self.end_headers()
with open("." + self.path, "rb") as content:
self.wfile.write(content.read())
Expand All @@ -122,6 +135,7 @@ def do_GET(self):
def do_POST(self):
self.send_response(200)
self.send_header("Content-type", "application/json")
self.send_cors_headers();
self.end_headers()

# unfortunately this import can't be at the top level, since that would cause a circular import
Expand Down Expand Up @@ -237,6 +251,13 @@ def image_progress(sample, step):
print(f"Canceled.")
return

# CORS support
def do_OPTIONS(self):
self.send_response(200)
self.send_cors_headers();
if (self.cors):
self.send_header("Access-Control-Allow-Methods", "GET,HEAD,PUT,PATCH,POST,DELETE")
self.end_headers()

class ThreadingDreamServer(ThreadingHTTPServer):
def __init__(self, server_address):
Expand Down
22 changes: 17 additions & 5 deletions scripts/dream.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,14 @@ def main():
"""Initialize command-line parsers and the diffusion model"""
arg_parser = create_argv_parser()
opt = arg_parser.parse_args()

if opt.laion400m:
print('--laion400m flag has been deprecated. Please use --model laion400m instead.')
sys.exit(-1)
if opt.weights != 'model':
print('--weights argument has been deprecated. Please configure ./configs/models.yaml, and call it using --model instead.')
sys.exit(-1)

try:
models = OmegaConf.load(opt.config)
width = models[opt.model].width
Expand Down Expand Up @@ -106,7 +106,7 @@ def main():

cmd_parser = create_cmd_parser()
if opt.web:
dream_server_loop(t2i, opt.host, opt.port, opt.outdir)
dream_server_loop(t2i, opt.host, opt.port, opt.outdir, opt.cors)
else:
main_loop(t2i, opt.outdir, opt.prompt_as_dir, cmd_parser, infile)

Expand Down Expand Up @@ -319,8 +319,13 @@ def get_next_command(infile=None) -> str: #command string
print(f'#{command}')
return command

def dream_server_loop(t2i, host, port, outdir):
def dream_server_loop(t2i, host, port, outdir, cors):
print('\n* --web was specified, starting web server...')
if cors:
print(f'* --cors was specified, enabling CORS support with "Access-Control-Allow-Origin: {cors}"...')
if cors == '*':
print(f'* UNSAFE CORS POLICY DETECTED! YOU ARE ALLOWING ANY ORIGIN TO CONNECT TO THIS SERVER!')
print(f'* PLEASE SPECIFY AN ORIGIN & SEE THIS STACKOVERFLOW POST FOR MORE INFORMATION: https://bit.ly/3RpIy6I! \n')
# Change working directory to the stable-diffusion directory
os.chdir(
os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
Expand All @@ -329,6 +334,7 @@ def dream_server_loop(t2i, host, port, outdir):
# Start server
DreamServer.model = t2i
DreamServer.outdir = outdir
DreamServer.cors = cors
dream_server = ThreadingDreamServer((host, port))
print(">> Started Stable Diffusion dream server!")
if host == '0.0.0.0':
Expand Down Expand Up @@ -371,7 +377,7 @@ def write_log_message(results, log_path):
def create_argv_parser():
parser = argparse.ArgumentParser(
description="""Generate images using Stable Diffusion.
Use --web to launch the web interface.
Use --web to launch the web interface.
Use --from_file to load prompts from a file path or standard input ("-").
Otherwise you will be dropped into an interactive command prompt (type -h for help.)
Other command-line arguments are defaults that can usually be overridden
Expand Down Expand Up @@ -489,6 +495,12 @@ def create_argv_parser():
default='9090',
help='Web server: Port to listen on'
)
parser.add_argument(
'--cors',
dest='cors',
type=str,
help='Web server: Access-Control-Allow-Origin value'
)
parser.add_argument(
'--weights',
default='model',
Expand Down

0 comments on commit b792915

Please sign in to comment.