make sure batch size is set correctly, if only inpainting images are passed in, for unconditional image synthesis. also allow for forcing unconditional image synthesis on network trained conditionally, although not recommended
lucidrains committed Sep 9, 2022
1 parent 77cc6a3 commit 038f300
Showing 3 changed files with 31 additions and 1 deletion.
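For context, here is a minimal usage sketch of the two behaviours this commit enables: forcing a text-conditioned network to sample unconditionally via the new force_unconditional_() method, and letting the sample batch size broadcast from the inpainting images when no batch size is given. The Unet/Imagen constructor arguments, the (batch, height, width) mask shape, and the inpaint_images/inpaint_masks keywords shown here are assumptions about the surrounding API, not part of this diff.

import torch
from imagen_pytorch import Unet, Imagen

# a small unet trained with text conditioning
unet = Unet(dim = 32, cond_dim = 512, dim_mults = (1, 2, 4), layer_cross_attns = (False, True, True))

imagen = Imagen(
    unets = unet,
    image_sizes = 64,
    timesteps = 100,
    condition_on_text = True
)

# new in this commit: drop text conditioning on the model and all of its unets
# (not recommended, since the network was trained conditionally)
imagen.force_unconditional_()

inpaint_images = torch.randn(4, 3, 64, 64)     # four images to inpaint
inpaint_masks = torch.ones(4, 64, 64).bool()   # assumed mask shape (batch, height, width)

# batch_size defaults to 1; with this commit it is broadcast to
# inpaint_images.shape[0] when sampling unconditionally
images = imagen.sample(
    inpaint_images = inpaint_images,
    inpaint_masks = inpaint_masks
)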
15 changes: 15 additions & 0 deletions imagen_pytorch/elucidated_imagen.py
@@ -223,6 +223,13 @@ def __init__(

self.to(next(self.unets.parameters()).device)

def force_unconditional_(self):
self.condition_on_text = False
self.unconditional = True

for unet in self.unets:
unet.cond_on_text = False

@property
def device(self):
return self._temp.device
@@ -550,6 +557,14 @@ def sample(
text_masks = default(text_masks, lambda: torch.any(text_embeds != 0., dim = -1))
batch_size = text_embeds.shape[0]

if exists(inpaint_images):
if self.unconditional:
if batch_size == 1: # assume researcher wants to broadcast along inpainted images
batch_size = inpaint_images.shape[0]

assert inpaint_images.shape[0] == batch_size, 'number of inpainting images must be equal to the batch size specified on sample, i.e. `sample(batch_size=<int>)`'
assert not (self.condition_on_text and inpaint_images.shape[0] != text_embeds.shape[0]), 'number of inpainting images must be equal to the number of text to be conditioned on'

assert not (self.condition_on_text and not exists(text_embeds)), 'text or text encodings must be passed into imagen if specified'
assert not (not self.condition_on_text and exists(text_embeds)), 'imagen specified not to be conditioned on text, yet it is presented'
assert not (exists(text_embeds) and text_embeds.shape[-1] != self.text_embed_dim), f'invalid text embedding dimension being passed in (should be {self.text_embed_dim})'
15 changes: 15 additions & 0 deletions imagen_pytorch/imagen_pytorch.py
@@ -1908,6 +1908,13 @@ def __init__(

self.to(next(self.unets.parameters()).device)

def force_unconditional_(self):
self.condition_on_text = False
self.unconditional = True

for unet in self.unets:
unet.cond_on_text = False

@property
def device(self):
return self._temp.device
@@ -2184,6 +2191,14 @@ def sample(
text_masks = default(text_masks, lambda: torch.any(text_embeds != 0., dim = -1))
batch_size = text_embeds.shape[0]

if exists(inpaint_images):
if self.unconditional:
if batch_size == 1: # assume researcher wants to broadcast along inpainted images
batch_size = inpaint_images.shape[0]

assert inpaint_images.shape[0] == batch_size, 'number of inpainting images must be equal to the batch size specified on sample, i.e. `sample(batch_size=<int>)`'
assert not (self.condition_on_text and inpaint_images.shape[0] != text_embeds.shape[0]), 'number of inpainting images must be equal to the number of text to be conditioned on'

assert not (self.condition_on_text and not exists(text_embeds)), 'text or text encodings must be passed into imagen if specified'
assert not (not self.condition_on_text and exists(text_embeds)), 'imagen specified not to be conditioned on text, yet it is presented'
assert not (exists(text_embeds) and text_embeds.shape[-1] != self.text_embed_dim), f'invalid text embedding dimension being passed in (should be {self.text_embed_dim})'
2 changes: 1 addition & 1 deletion imagen_pytorch/version.py
@@ -1 +1 @@
-__version__ = '1.11.9'
+__version__ = '1.11.10'
