
Commit

initial self attention implementation
#52

This is still work in progress:

- The test does not really test anything yet.
- Dim tags are not used consistently yet (#17).
- Normal attention is still missing.
- Auto-regressive self-attention is still missing.
albertz committed Nov 6, 2021
1 parent ed75042 commit a711320
Showing 3 changed files with 58 additions and 0 deletions.
1 change: 1 addition & 0 deletions nn/__init__.py
@@ -9,3 +9,4 @@
from .utils import *
from .rec import *
from .container import *
from .attention import *
40 changes: 40 additions & 0 deletions nn/attention.py
@@ -0,0 +1,40 @@
"""
Attention, self-attention, auto-regressive self-attention
"""

from .. import nn


class SelfAttention(nn.Module):
"""
Classic self attention
"""
def __init__(self, *, axis: nn.DimensionTag, key_dim_total, value_dim_total, num_heads: int):
super().__init__()
self.axis = axis
self.key_dim_total = key_dim_total
self.key_dim_per_head = key_dim_total // num_heads
self.value_dim_total = value_dim_total
self.value_dim_per_head = value_dim_total // num_heads
self.num_heads = num_heads
self.qkv = nn.Linear(key_dim_total * 2 + value_dim_total)
self.expand_dim = nn.DimensionTag(kind=nn.DimensionTag.Types.Spatial, description="self_att_expand_dim")

def forward(self, source: nn.LayerRef) -> nn.Layer:
"""forward"""
qkv = self.qkv(source)
qkv = nn.split_dims(
qkv, axis="F", dims=(self.num_heads, self.key_dim_per_head * 2 + self.value_dim_per_head),
name="qkv_split_dims")
q, k, v = nn.split(
qkv, axis="F", size_splits=(self.key_dim_per_head, self.key_dim_per_head, self.value_dim_per_head),
name="qkv_split")
q *= self.key_dim_per_head ** -0.5
k = nn.reinterpret_data(k, set_dim_tags={self.axis: self.expand_dim}, name="k_new_dim")
v = nn.reinterpret_data(v, set_dim_tags={self.axis: self.expand_dim}, name="v_new_dim")
energy = nn.dot([q, k], red1="static:-1", red2="static:-1", var1=self.axis, var2=self.expand_dim, name="energy")
att_weights = nn.softmax(energy, axis=self.expand_dim, name="att_weights")
att = nn.dot(
[att_weights, v], red1=self.expand_dim, red2=self.expand_dim, var1=self.axis, var2="static:-1", name="att")
output = nn.merge_dims(att, axes="static", name="output")
return output
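
For reference, the layer graph built in SelfAttention.forward corresponds to standard multi-head scaled dot-product self-attention. Below is a minimal NumPy sketch of that computation, not part of this commit; the function name, the explicit weight matrix w_qkv, and the shapes are illustrative assumptions only.

# Reference sketch only (assumed names/shapes, not returnn-common API):
# the same computation as SelfAttention.forward, written with plain NumPy.
import numpy as np


def self_attention_reference(x, w_qkv, num_heads, key_dim_total, value_dim_total):
  """x: [time, in_dim]; w_qkv: [in_dim, 2 * key_dim_total + value_dim_total]. Returns [time, value_dim_total]."""
  time_, _ = x.shape
  key_dim = key_dim_total // num_heads
  value_dim = value_dim_total // num_heads
  qkv = x @ w_qkv  # joint q/k/v projection, like the single nn.Linear above
  qkv = qkv.reshape(time_, num_heads, 2 * key_dim + value_dim)  # add heads dim (split_dims)
  q, k, v = np.split(qkv, [key_dim, 2 * key_dim], axis=-1)  # split into q/k/v (nn.split)
  q = q * key_dim ** -0.5  # scale queries
  energy = np.einsum("thk,shk->hts", q, k)  # [heads, query time, key time (the expand_dim)]
  weights = np.exp(energy - energy.max(axis=-1, keepdims=True))
  weights /= weights.sum(axis=-1, keepdims=True)  # softmax over the key/expand axis
  att = np.einsum("hts,shv->thv", weights, v)  # weighted sum of values
  return att.reshape(time_, value_dim_total)  # merge heads back (merge_dims)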
17 changes: 17 additions & 0 deletions tests/test_models_layers.py
@@ -531,3 +531,20 @@ def forward(self, x: nn.LayerRef) -> nn.Layer:
    'output': {'class': 'copy', 'from': 'mul'},
    'sigmoid': {'activation': 'sigmoid', 'class': 'activation', 'from': 'split/1'},
    'split': {'axis': 'F', 'class': 'split', 'from': 'data:data', 'num_splits': 2}})


def test_self_attention():
  time_dim = nn.DimensionTag(nn.DimensionTag.Types.Spatial, description="time")

  class _Net(nn.Module):
    def __init__(self):
      super().__init__()
      self.self_att = nn.SelfAttention(axis=time_dim, key_dim_total=21, value_dim_total=33, num_heads=3)

    def forward(self, x: nn.LayerRef) -> nn.Layer:
      """forward"""
      return self.self_att(x)

  net = _Net()
  net_dict = make_root_net_dict(net, "data")
  pprint(net_dict)
