Commit e2e8b58
also allow for returning attention maps from tab transformer
lucidrains committed Apr 6, 2023
1 parent 582ebc8 commit e2e8b58
Showing 2 changed files with 26 additions and 13 deletions.
setup.py (2 changes: 1 addition & 1 deletion)

@@ -3,7 +3,7 @@
 setup(
   name = 'tab-transformer-pytorch',
   packages = find_packages(),
-  version = '0.2.2',
+  version = '0.2.5',
   license='MIT',
   description = 'Tab Transformer - Pytorch',
   author = 'Phil Wang',
tab_transformer_pytorch/tab_transformer_pytorch.py (37 changes: 25 additions & 12 deletions)

@@ -76,11 +76,11 @@ def forward(self, x):
         sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

         attn = sim.softmax(dim = -1)
-        attn = self.dropout(attn)
+        dropped_attn = self.dropout(attn)

-        out = einsum('b h i j, b h j d -> b h i d', attn, v)
+        out = einsum('b h i j, b h j d -> b h i d', dropped_attn, v)
         out = rearrange(out, 'b h n d -> b n (h d)', h = h)
-        return self.to_out(out)
+        return self.to_out(out), attn

 # transformer

@@ -92,18 +92,26 @@ def __init__(self, num_tokens, dim, depth, heads, dim_head, attn_dropout, ff_dropout):

         for _ in range(depth):
             self.layers.append(nn.ModuleList([
-                Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = attn_dropout))),
-                Residual(PreNorm(dim, FeedForward(dim, dropout = ff_dropout))),
+                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = attn_dropout)),
+                PreNorm(dim, FeedForward(dim, dropout = ff_dropout)),
             ]))

-    def forward(self, x):
+    def forward(self, x, return_attn = False):
         x = self.embeds(x)

+        post_softmax_attns = []
+
         for attn, ff in self.layers:
-            x = attn(x)
-            x = ff(x)
+            attn_out, post_softmax_attn = attn(x)
+            post_softmax_attns.append(post_softmax_attn)
+
+            x = x + attn_out
+            x = ff(x) + x
+
+        if not return_attn:
+            return x

-        return x
+        return x, torch.stack(post_softmax_attns)
 # mlp

 class MLP(nn.Module):
@@ -201,15 +209,15 @@ def __init__(

         self.mlp = MLP(all_dimensions, act = mlp_act)

-    def forward(self, x_categ, x_cont):
+    def forward(self, x_categ, x_cont, return_attn = False):
         xs = []

         assert x_categ.shape[-1] == self.num_categories, f'you must pass in {self.num_categories} values for your categories input'

         if self.num_unique_categories > 0:
             x_categ += self.categories_offset

-        x = self.transformer(x_categ)
+        x, attns = self.transformer(x_categ, return_attn = True)

         flat_categ = x.flatten(1)
         xs.append(flat_categ)
@@ -225,4 +233,9 @@ def forward(self, x_categ, x_cont):
             xs.append(normed_cont)

         x = torch.cat(xs, dim = -1)
-        return self.mlp(x)
+        logits = self.mlp(x)
+
+        if not return_attn:
+            return logits
+
+        return logits, attns
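
Usage note: with this change, passing return_attn = True to the model's forward call returns the logits together with the post-softmax attention maps from every transformer layer, stacked along a new leading depth dimension. A minimal sketch follows; the constructor arguments are taken from the project README and are assumptions here, not part of this commit.

import torch
from tab_transformer_pytorch import TabTransformer

model = TabTransformer(
    categories = (10, 5, 6, 5, 8),   # number of unique values per categorical column
    num_continuous = 10,             # number of continuous features
    dim = 32,
    dim_out = 1,
    depth = 6,
    heads = 8,
    attn_dropout = 0.1,
    ff_dropout = 0.1
)

x_categ = torch.randint(0, 5, (1, 5))   # one token per categorical column
x_cont = torch.randn(1, 10)             # continuous features

logits = model(x_categ, x_cont)                              # default behaviour: logits only
logits, attns = model(x_categ, x_cont, return_attn = True)   # attns: (depth, batch, heads, n, n), here (6, 1, 8, 5, 5)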
