Skip to content

Commit

Permalink
offer way for extractor to return latents without detaching them
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jul 16, 2022
1 parent 2fa2b62 commit f86e052
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 3 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '0.35.7',
version = '0.35.8',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
long_description_content_type = 'text/markdown',
Expand Down
13 changes: 11 additions & 2 deletions vit_pytorch/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@
def exists(val):
    """Return True when *val* is anything other than None."""
    if val is None:
        return False
    return True

def identity(t):
    """No-op passthrough: hand back *t* unchanged (used when the caller
    wants latents returned without detaching them from autograd)."""
    return t

def clone_and_detach(t):
    """Return a copy of tensor *t* severed from the autograd graph."""
    copied = t.clone()
    return copied.detach()

def apply_tuple_or_single(fn, val):
if isinstance(val, tuple):
return tuple(map(fn, val))
Expand All @@ -17,7 +23,8 @@ def __init__(
layer = None,
layer_name = 'transformer',
layer_save_input = False,
return_embeddings_only = False
return_embeddings_only = False,
detach = True
):
super().__init__()
self.vit = vit
Expand All @@ -34,9 +41,11 @@ def __init__(
self.layer_save_input = layer_save_input # whether to save input or output of layer
self.return_embeddings_only = return_embeddings_only

self.detach_fn = clone_and_detach if detach else identity

def _hook(self, _, inputs, output):
    """Forward hook: capture this layer's latents onto ``self.latents``.

    Saves the hook's *inputs* instead of its *output* when
    ``self.layer_save_input`` is set. The tensor (or tuple of tensors)
    is passed through ``self.detach_fn`` — either clone-and-detach or a
    no-op identity, depending on how the extractor was configured — so
    callers can opt to keep the autograd graph attached.
    """
    layer_output = inputs if self.layer_save_input else output
    # Single assignment through the configurable detach_fn; the earlier
    # hard-coded clone/detach lambda was dead work that also defeated
    # the detach=False option.
    self.latents = apply_tuple_or_single(self.detach_fn, layer_output)

def _register_hook(self):
if not exists(self.layer):
Expand Down

0 comments on commit f86e052

Please sign in to comment.