Skip to content

Commit

Permalink
add metadata to scn2img intermediate image output
Browse files Browse the repository at this point in the history
The following metadata is written:
- "prompt" contains the representation of the SceneObject corresponding to the intermediate image
- "seed" contains the seed at the start of the function that generated this intermediate image
- "width" and "height" contain the size of the image.

To read the seed at the start of a render function without consuming it, a SeedGenerator class is added and
used in place of the Python generator function gen_seeds.

Fixes warning thrown in console: "> Couldn't find metadata on image", originally reported by @codedealer in Sygil-Dev#1179 (review)
  • Loading branch information
xaedes committed Oct 2, 2022
1 parent 33b896d commit 849a569
Showing 1 changed file with 51 additions and 29 deletions.
80 changes: 51 additions & 29 deletions scripts/scn2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import collections
import yaml
import math
import copy
import random
from typing import List, Union, Dict, Callable, Any, Optional, Type, Tuple

Expand All @@ -15,6 +16,7 @@
import torch

from frontend.job_manager import JobInfo
from frontend.image_metadata import ImageMetadata

scn2img_cache = {
"seed": None,
Expand Down Expand Up @@ -279,8 +281,17 @@ def get_scn2img(MemUsageMonitor:Type, save_sample:Callable, get_next_sequence_nu
opt = opt or argparse.Namespace()

def next_seed(s):
s = seed_to_int(s)
return random.Random(s).randint(0, 2**32 - 1)
return random.Random(seed_to_int(s)).randint(0, 2**32 - 1)

class SeedGenerator:
    """Stateful source of seeds that can be inspected without advancing.

    The initial seed is normalised with ``seed_to_int``. Every call to
    ``next_seed`` hands out the currently stored seed and replaces it with
    the value the enclosing-scope ``next_seed`` function derives from it.
    """

    def __init__(self, seed):
        """Store the integer form of ``seed`` as the first seed to hand out."""
        self._seed = seed_to_int(seed)

    def next_seed(self):
        """Return the current seed and advance the internal state."""
        # The right-hand side is evaluated in full before assignment, so the
        # value returned is the seed held *before* advancing. The bare name
        # ``next_seed`` resolves to the enclosing function, not this method.
        current, self._seed = self._seed, next_seed(self._seed)
        return current

    def peek_seed(self):
        """Return the seed the next ``next_seed`` call will yield, unchanged."""
        return self._seed

def scn2img(prompt: str, toggles: List[int], seed: Union[int, str, None], fp = None, job_info: JobInfo = None):
global scn2img_cache
Expand Down Expand Up @@ -336,11 +347,6 @@ def log_exception(*args, **kwargs):
log_info("scn2img_cache")
log_info(list(scn2img_cache["cache"].keys()))

def gen_seeds(seed):
    """Yield ``seed`` first, then endlessly each successor from ``next_seed``.

    The generator never terminates on its own. A successor is only computed
    when the consumer asks for the value after the one just yielded.
    """
    current = seed
    while True:
        yield current
        current = next_seed(current)

def is_seed_invalid(s):
result = (
(type(s) != int)
Expand Down Expand Up @@ -631,20 +637,28 @@ def parse_scene_args(scene):

return scene

def save_sample_scn2img(img, obj):
def save_sample_scn2img(img, obj, name, seed):
if img is None:
return
base_count = get_next_sequence_number(outpath)
filename = "[SEED]_result"
filename = f"{base_count:05}-" + filename
filename = filename.replace("[SEED]", str(seed))
save_sample(img, outpath, filename, jpg_sample, None, None, None, None, None, False, None, None, None, None, None, None, None, None, None, False)
if write_info_files or write_sample_info_to_log_file:
info_dict = {
"prompt": prompt,
"scene_object": str(obj),
"seed": seed
}
wrapped = SceneObject(
func=name,
title=obj.title,
args={"seed":seed},
depth=obj.depth-1,
children=[obj]
)
info_dict = {
"prompt": prompt,
"scene_object": str(wrapped),
"seed": seed
}
# img.size is a (width, height) pair (PIL convention), so the height must be
# read from index 1; index 0 would store the width twice.
metadata = ImageMetadata(prompt=info_dict["scene_object"], seed=seed, width=img.size[0], height=img.size[1])
ImageMetadata.set_on_image(img, metadata)
save_sample(img, outpath, filename, jpg_sample, None, None, None, None, None, False, None, None, None, None, None, None, None, None, None, False, False)
if write_info_files:
filename_i = os.path.join(outpath, filename)
with open(f"{filename_i}.yaml", "w", encoding="utf8") as f:
Expand Down Expand Up @@ -931,11 +945,11 @@ def output_img(img):
output_image_set.add(img_id)
output_images.append(img)

def render_intermediate(img, obj):
def render_intermediate(img, obj, name, seed):
if output_intermediates:
output_img(img)
if not skip_save:
save_sample_scn2img(img, obj)
save_sample_scn2img(img, obj, name, seed)
return img

def render_3d(img, obj):
Expand All @@ -944,7 +958,7 @@ def render_3d(img, obj):
if obj["transform3d"] == True:
d2r = math.pi / 180.0
depth_model = obj["transform3d_depth_model"] if "transform3d_depth_model" in obj else 1
depth_near = obj["transform3d_depth_near"] if "transform3d_depth_near" in obj else 0.1
depth_near = obj["transform3d_depth_near"] if "transform3d_depth_near" in obj else 0.1
depth_scale = obj["transform3d_depth_scale"] if "transform3d_depth_scale" in obj else 1.0
from_hfov = obj["transform3d_from_hfov"] if "transform3d_from_hfov" in obj else (45*d2r)
from_pose = obj["transform3d_from_pose"] if "transform3d_from_pose" in obj else (0,0,0, 0,0,0)
Expand Down Expand Up @@ -983,6 +997,7 @@ def render_3d(img, obj):
return img

def render_image(seeds, obj):
start_seed = seeds.peek_seed()
img = create_image(obj["size"], obj["color"])
img = blend_objects(
seeds,
Expand All @@ -993,7 +1008,7 @@ def render_image(seeds, obj):
img = resize_image(img, obj["resize"], obj["crop"])
# if img is None: log_warn(f"result of render_image({obj}) is None")
img = render_3d(img, obj)
img = render_intermediate(img, obj)
img = render_intermediate(img, obj, "render_image", start_seed)
return img

def prepare_img2img_kwargs(seeds, obj, img):
Expand Down Expand Up @@ -1025,7 +1040,7 @@ def prepare_img2img_kwargs(seeds, obj, img):
if is_seed_valid(s):
img2img_kwargs["seed"] = int(s)
else:
img2img_kwargs["seed"] = next(seeds)
img2img_kwargs["seed"] = seeds.next_seed()

log_info('img2img_kwargs["seed"]', img2img_kwargs["seed"])

Expand All @@ -1047,7 +1062,7 @@ def prepare_img2img_kwargs(seeds, obj, img):
"image": img.convert("RGB").convert("RGBA"),
"mask": img.getchannel("A")
}
# render_intermediate(img2img_kwargs["init_info_mask"]["mask"].convert("RGBA"), obj)
# render_intermediate(img2img_kwargs["init_info_mask"]["mask"].convert("RGBA"), obj, "img2img_init_info_mask", start_seed)
log_info("img2img_kwargs")
log_info(img2img_kwargs)

Expand Down Expand Up @@ -1079,7 +1094,7 @@ def prepare_txt2img_kwargs(seeds, obj):
if is_seed_valid(s):
txt2img_kwargs["seed"] = int(s)
else:
txt2img_kwargs["seed"] = next(seeds)
txt2img_kwargs["seed"] = seeds.next_seed()

log_info('txt2img_kwargs["seed"]', txt2img_kwargs["seed"])

Expand All @@ -1102,6 +1117,7 @@ def prepare_txt2img_kwargs(seeds, obj):
return txt2img_kwargs

def render_img2img(seeds, obj):
start_seed = seeds.peek_seed()
global scn2img_cache
if obj["size"] is None:
obj["size"] = (img2img_defaults["width"], img2img_defaults["height"])
Expand All @@ -1112,7 +1128,7 @@ def render_img2img(seeds, obj):
obj.children
)
img = render_mask(seeds, obj, img)
img = render_intermediate(img, obj)
img = render_intermediate(img, obj, "render_img2img_input", start_seed)

img2img_kwargs = prepare_img2img_kwargs(seeds, obj, img)

Expand Down Expand Up @@ -1161,10 +1177,11 @@ def render_img2img(seeds, obj):
img = resize_image(img, obj["resize"], obj["crop"])
if img is None: log_warn(f"result of render_img2img({obj}) is None")
img = render_3d(img, obj)
img = render_intermediate(img, obj)
img = render_intermediate(img, obj, "render_img2img", start_seed)
return img

def render_txt2img(seeds, obj):
start_seed = seeds.peek_seed()
global scn2img_cache

txt2img_kwargs = prepare_txt2img_kwargs(seeds, obj)
Expand Down Expand Up @@ -1213,14 +1230,16 @@ def render_txt2img(seeds, obj):
img = resize_image(img, obj["resize"], obj["crop"])
if img is None: log_warn(f"result of render_txt2img({obj}) is None")
img = render_3d(img, obj)
img = render_intermediate(img, obj)
img = render_intermediate(img, obj, "render_txt2img", start_seed)
return img

def render_object(seeds, obj):
# log_trace(f"render_object({str(obj)})")

if "initial_seed" in obj:
seeds = gen_seeds(obj["initial_seed"])
# create new generator rather than resetting current generator,
# so that seeds generator from function argument is not changed.
seeds = SeedGenerator(obj["initial_seed"])

if obj.func == "scene":
assert(len(obj.children) == 1)
Expand All @@ -1240,7 +1259,9 @@ def render_scn2img(seeds, obj):
result = []

if "initial_seed" in obj:
seeds = gen_seeds(obj["initial_seed"])
# create new generator rather than resetting current generator,
# so that seeds generator from function argument is not changed.
seeds = SeedGenerator(obj["initial_seed"])

if obj.func == "scn2img":
# Note on seed generation and for-loops instead of
Expand All @@ -1257,6 +1278,7 @@ def render_scn2img(seeds, obj):
result.append(render_object(seeds, obj))
return result

start_seed = seeds.peek_seed()
for img in render_scn2img(seeds, scene):
if output_intermediates:
# img already in output, do nothing here
Expand All @@ -1267,7 +1289,7 @@ def render_scn2img(seeds, obj):
if skip_save:
# individual image save was skipped,
# we need to save them now
save_sample_scn2img(img, scene)
save_sample_scn2img(img, scene, "render_scene", start_seed)


return output_images
Expand All @@ -1285,7 +1307,7 @@ def render_scn2img(seeds, obj):
log_info(scene)
# log_info("comments", comments)

render_scene(output_images, scene, gen_seeds(seed))
render_scene(output_images, scene, SeedGenerator(seed))
log_info("output_images", output_images)
# log_info("comments", comments)

Expand Down

0 comments on commit 849a569

Please sign in to comment.