[Unity][Frontend][NN] Add diffusers style Attention layer (#15609)
This PR adds support for the `Attention` layer in the nn module API. This layer mimics the behavior of the [Attention layer used in huggingface Diffusers](https://github.com/huggingface/diffusers/blob/80871ac5971fe7e708befa3b553463c4e61b22ab/src/diffusers/models/attention_processor.py#L36). Under the hood it uses scaled dot product attention. Notably, some features are still missing: for example, I didn't add support for attention masks yet, and I assume 3-dimensional inputs, so this likely wouldn't be useful for the VAE of stable diffusion. However, it does allow us to represent all the attention layers of the UNet. I'm hoping we can expand the functionality further in future PRs if needed.
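For orientation, here is a minimal sketch of how the new layer can be constructed and exported through the nn module API. It mirrors the unit test added in this PR; the import path and argument shapes are assumptions taken from that test, so treat the snippet as illustrative rather than canonical.

```python
# Illustrative sketch mirroring the unit test in this PR (import path and
# shapes are assumptions, not part of the commit text itself).
from tvm.relax.frontend.nn import modules, spec

attn = modules.Attention(
    query_dim=640,             # channels in the query
    cross_attention_dim=2048,  # channels in encoder_hidden_states
    heads=10,
    norm_num_groups=8,         # enables the optional GroupNorm on the input
)

# Only tensor arguments need spec entries; attention_mask stays at its default.
tvm_mod, params = attn.export_tvm(
    spec={
        "forward": {
            "hidden_states": spec.Tensor((2, 4096, 640), "float32"),
            "encoder_hidden_states": spec.Tensor((2, 77, 2048), "float32"),
        }
    }
)
```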
Josh Fromm authored Aug 26, 2023
1 parent d28613f commit d5dcabf
Showing 7 changed files with 318 additions and 18 deletions.
3 changes: 3 additions & 0 deletions python/tvm/relax/frontend/nn/core.py
@@ -486,6 +486,9 @@ def __setitem__(self, idx, module):
def __len__(self):
return len(self.modules)

def append(self, module):
self.modules.append(module)

def to(self, dtype: Optional[str] = None) -> None: # pylint: disable=invalid-name
for module in self.modules:
module.to(dtype=dtype)
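A small illustrative sketch of the new `append`, mirroring how the Attention layer later in this diff builds its `to_out` projection as a `ModuleList`. The import paths and sizes are assumptions, not taken from the commit.

```python
# Hedged sketch: ModuleList now supports list-style append, which lets callers
# grow a module stack after construction (imports and sizes are assumptions).
from tvm.relax.frontend.nn.core import ModuleList
from tvm.relax.frontend.nn.modules import Linear

to_out = ModuleList([Linear(640, 640, bias=True)])
to_out.append(Linear(640, 640, bias=True))  # behaves like list.append
assert len(to_out) == 2
```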
129 changes: 125 additions & 4 deletions python/tvm/relax/frontend/nn/modules.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=too-many-arguments,invalid-name,protected-access
# pylint: disable=too-many-arguments,invalid-name,protected-access,unused-argument
"""Builtin Modules."""
from typing import List, Optional, Sequence, Union

@@ -24,7 +24,7 @@
from tvm.runtime import NDArray

from . import op
from .core import Effect, Module, Parameter, Tensor, get_default_dtype
from .core import Effect, Module, Parameter, Tensor, get_default_dtype, ModuleList


class IOEffect(Effect):
@@ -344,21 +344,28 @@ def __init__(
self.weight = None
self.bias = None

def forward(self, x: Tensor):
def forward(self, x: Tensor, channel_axis: int = 1, axes: Optional[List[int]] = None):
"""
Forward method for group norm layer.
Parameters
----------
x : Tensor
The input tensor.
channel_axis : int
Channel axis of the input data.
axes : Optional[List[int]]
Optional list of axes to compute the norm over. If not specified,
the norm is computed over all axes except the first two.
Returns
-------
ret : Tensor
The output tensor for the group norm layer.
"""
return op.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
return op.group_norm(
x, self.num_groups, self.weight, self.bias, self.eps, channel_axis, axes
)


class KVCache(Effect):
@@ -621,3 +628,117 @@ def forward(self, x: Tensor):
flip_sin_to_cos=self.flip_sin_to_cos,
downscale_freq_shift=self.downscale_freq_shift,
)


class Attention(Module):
"""
A cross attention layer.
Parameters
----------
query_dim : int
The number of channels in the query.
cross_attention_dim : Optional[int]
The number of channels in the encoder_hidden_states.
If not given, defaults to `query_dim`.
heads : int
The number of heads to use for multi-head attention.
dim_head : int
The number of channels in each head.
bias : bool
Set to `True` for the query, key, and value linear layers to contain a bias parameter.
norm_num_groups : Optional[int]
When set, group norm is applied to the input using this number of groups.
out_bias : bool
Set to `True` to apply a bias to the output linear layer.
scale_qk : bool
Whether to apply scaling to query and key tensors.
"""

def __init__(
self,
query_dim: int,
cross_attention_dim: Optional[int] = None,
heads: int = 8,
dim_head: int = 64,
bias: bool = False,
norm_num_groups: Optional[int] = None,
out_bias: bool = True,
scale_qk: bool = True,
):
self.query_dim = query_dim
self.cross_attention_dim = cross_attention_dim if cross_attention_dim else query_dim
self.heads = heads
self.dim_head = dim_head
self.bias = bias
self.norm_num_groups = norm_num_groups
self.out_bias = out_bias
self.scale_qk = scale_qk

self.scale = dim_head**-0.5 if self.scale_qk else 1.0
self.inner_dim = dim_head * heads

self.to_q = Linear(self.query_dim, self.inner_dim, bias=self.bias)
self.to_k = Linear(self.cross_attention_dim, self.inner_dim, bias=self.bias)
self.to_v = Linear(self.cross_attention_dim, self.inner_dim, bias=self.bias)

if self.norm_num_groups is not None:
self.group_norm = GroupNorm(
num_channels=self.query_dim, num_groups=self.norm_num_groups, affine=True
)
else:
self.group_norm = None

self.to_out = ModuleList([Linear(self.inner_dim, self.query_dim, bias=self.out_bias)])

def forward(
self,
hidden_states: Tensor,
encoder_hidden_states: Optional[Tensor] = None,
attention_mask: Optional[Tensor] = None,
**cross_attention_kwargs,
):
"""
Forward method for Attention layer.
Parameters
----------
hidden_states : Tensor
The input sample tensor.
encoder_hidden_states : Optional[Tensor]
Optional encoder hidden states for cross attention. If not given,
self attention is applied to `hidden_states`.
attention_mask : Optional[Tensor]
Mask tensor for attention, currently not supported.
Returns
-------
ret : Tensor
The output tensor of the attention layer.
"""
# This implementation assumes use of torch 2.0 scaled_dot_product attention.
assert attention_mask is None, "Attention mask not yet supported."

if self.group_norm is not None:
hidden_states = self.group_norm(hidden_states, channel_axis=2, axes=[1])

query = self.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states

key = self.to_k(encoder_hidden_states)
value = self.to_v(encoder_hidden_states)
head_dim = int(self.inner_dim // self.heads)

query = op.reshape(query, [0, -1, self.heads, head_dim])
key = op.reshape(key, [0, -1, self.heads, head_dim])
value = op.reshape(value, [0, -1, self.heads, head_dim])

hidden_states = op.scaled_dot_product_attention(query, key, value, is_causal=False)

# Return to proper shape.
hidden_states = op.reshape(hidden_states, (0, -1, self.heads * head_dim))

# Linear projection
hidden_states = self.to_out[0](hidden_states)

return hidden_states
52 changes: 51 additions & 1 deletion python/tvm/relax/frontend/nn/op.py
@@ -759,6 +759,8 @@ def group_norm(
weight: Optional[Tensor],
bias: Optional[Tensor],
eps: float = 1e-5,
channel_axis: int = 1,
axes: Optional[List[int]] = None,
name: str = "group_norm",
) -> Tensor:
r"""
@@ -785,6 +787,13 @@
epsilon : float
Small float added to square mean to avoid dividing by zero.
channel_axis: int
The channel axis of the data.
axes : Optional[List[int]]
Which axes to compute the group norm over. If None, the norm is
computed over all axes except the first two.
name : str
Name hint.
@@ -798,9 +807,11 @@
if bias is not None:
bias = bias._expr
dim = len(x._expr.struct_info.shape)
if axes is None:
axes = list(range(2, dim))
return _wrap_nested(
_op.nn.group_norm(
x._expr, weight, bias, num_groups, channel_axis=1, axes=list(range(2, dim)), epsilon=eps
x._expr, weight, bias, num_groups, channel_axis=channel_axis, axes=axes, epsilon=eps
),
name,
)
@@ -955,6 +966,45 @@ def get_timestep_embedding(
return _wrap_nested(emb, name)


def scaled_dot_product_attention(
query: Tensor,
key: Tensor,
value: Tensor,
attn_mask: Optional[Tensor] = None,
is_causal: Optional[bool] = False,
scale: Optional[float] = None,
name: str = "scaled_dot_product_attention",
):
"""
Computes scaled dot product attention over the provided query, key, and
value tensors. The semantics match the functional torch implementation.
Parameters
----------
query : Tensor
Tensor representing current attention lookup.
key : Tensor
Tensor representing cross attention mapping.
value : Tensor
Tensor representing embedded attention values.
attn_mask : Optional[Tensor]
Optional mask for attention, not yet supported.
is_causal : Optional[bool]
If set, uses a causal attention mask.
scale : Optional[float]
Optional extra scaling argument applied to attention.
name : str
Name hint for this function.
"""
assert attn_mask is None, "attn_mask not yet supported."
causal_mask = "TopLeft" if is_causal else None

attn = _op.nn.attention(
query._expr, key._expr, value._expr, causal_mask=causal_mask, scale=scale
)
return _wrap_nested(attn, name)
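For reference, a hedged NumPy sketch of the computation this wrapper lowers to, using the (batch, seq, num_heads, head_dim) layout that appears in the unit test below; the causal-mask path is omitted since the wrapper only forwards `causal_mask="TopLeft"` to `R.nn.attention`.

```python
# NumPy reference (reading aid only, not part of the commit): scaled dot
# product attention in the (batch, seq, num_heads, head_dim) layout.
import numpy as np

def sdpa_reference(query, key, value, scale=None):
    head_dim = query.shape[-1]
    if scale is None:
        scale = 1.0 / np.sqrt(head_dim)  # default 1/sqrt(d) scaling
    # Move heads next to batch so the matmuls act on (seq, head_dim) slices.
    q, k, v = (np.transpose(t, (0, 2, 1, 3)) for t in (query, key, value))
    scores = (q @ np.swapaxes(k, -1, -2)) * scale            # (B, H, Sq, Skv)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights = weights / weights.sum(axis=-1, keepdims=True)  # softmax
    out = weights @ v                                        # (B, H, Sq, Dh)
    return np.transpose(out, (0, 2, 1, 3))                   # (B, Sq, H, Dh)
```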


def tensor_expr_op(
tensor_expr_func: Callable,
name_hint: str,
21 changes: 11 additions & 10 deletions python/tvm/relax/frontend/nn/spec.py
@@ -106,16 +106,17 @@ def from_raw(spec: MethodSpecType, method: Callable) -> "MethodSpec":
arg_names = list(method_signature.parameters.keys())
arg_specs = []
for arg_name in arg_names:
arg_spec = spec[arg_name]
if arg_spec is Int or arg_spec is int:
arg_spec = Int()
elif isinstance(arg_spec, str) and arg_spec == "int":
arg_spec = Int()
elif isinstance(arg_spec, (Int, Tensor)):
pass
else:
raise TypeError(f"Invalid spec for argument {arg_name}: {arg_spec}")
arg_specs.append(arg_spec)
if arg_name in spec:
arg_spec = spec[arg_name]
if arg_spec is Int or arg_spec is int:
arg_spec = Int()
elif isinstance(arg_spec, str) and arg_spec == "int":
arg_spec = Int()
elif isinstance(arg_spec, (Int, Tensor)):
pass
else:
raise TypeError(f"Invalid spec for argument {arg_name}: {arg_spec}")
arg_specs.append(arg_spec)
return MethodSpec(method, arg_names, arg_specs)

@staticmethod
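To illustrate the effect of this change: a forward method can now declare defaulted arguments that simply do not appear in the spec, whereas previously the lookup `spec[arg_name]` would raise a KeyError. The toy module below is hypothetical, a rough sketch assuming the `core.Module` / `export_tvm` usage seen in the tests.

```python
# Hypothetical sketch (not from the commit): a defaulted, non-tensor argument
# can now be omitted from the spec dict entirely.
from tvm.relax.frontend.nn import core, spec

class Passthrough(core.Module):
    def forward(self, x: core.Tensor, flip: bool = False):
        # `flip` exists only to demonstrate a defaulted argument with no spec
        # entry; before this change it would have triggered a KeyError.
        return x

tvm_mod, _ = Passthrough().export_tvm(
    spec={"forward": {"x": spec.Tensor((2, 8), "float32")}}
)
```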
82 changes: 82 additions & 0 deletions tests/python/relax/test_frontend_nn_modules.py
@@ -335,5 +335,87 @@ def forward(self, x: core.Tensor) -> core.Tensor:
assert_structural_equal(tvm_mod, Module, True)


def test_attention():
@R.function
def forward(
hidden_states: R.Tensor((2, 4096, 640), dtype="float32"),
encoder_hidden_states: R.Tensor((2, 77, 2048), dtype="float32"),
to_q_weight: R.Tensor((640, 640), dtype="float32"),
to_k_weight: R.Tensor((640, 2048), dtype="float32"),
to_v_weight: R.Tensor((640, 2048), dtype="float32"),
group_norm_weight: R.Tensor((640,), dtype="float32"),
group_norm_bias: R.Tensor((640,), dtype="float32"),
to_out_0_weight: R.Tensor((640, 640), dtype="float32"),
to_out_0_bias: R.Tensor((640,), dtype="float32"),
_io: R.Object,
) -> R.Tuple(R.Tensor((2, 4096, 640), dtype="float32"), R.Tuple(R.Object)):
with R.dataflow():
group_norm: R.Tensor((2, 4096, 640), dtype="float32") = R.nn.group_norm(
hidden_states,
group_norm_weight,
group_norm_bias,
num_groups=8,
channel_axis=2,
axes=[1],
epsilon=1.0000000000000001e-05,
center=True,
scale=True,
)
permute_dims: R.Tensor((640, 640), dtype="float32") = R.permute_dims(
to_q_weight, axes=None
)
matmul: R.Tensor((2, 4096, 640), dtype="float32") = R.matmul(
group_norm, permute_dims, out_dtype="void"
)
permute_dims1: R.Tensor((2048, 640), dtype="float32") = R.permute_dims(
to_k_weight, axes=None
)
matmul1: R.Tensor((2, 77, 640), dtype="float32") = R.matmul(
encoder_hidden_states, permute_dims1, out_dtype="void"
)
permute_dims2: R.Tensor((2048, 640), dtype="float32") = R.permute_dims(
to_v_weight, axes=None
)
matmul2: R.Tensor((2, 77, 640), dtype="float32") = R.matmul(
encoder_hidden_states, permute_dims2, out_dtype="void"
)
reshape: R.Tensor((2, 4096, 10, 64), dtype="float32") = R.reshape(
matmul, R.shape([2, 4096, 10, 64])
)
reshape1: R.Tensor((2, 77, 10, 64), dtype="float32") = R.reshape(
matmul1, R.shape([2, 77, 10, 64])
)
reshape2: R.Tensor((2, 77, 10, 64), dtype="float32") = R.reshape(
matmul2, R.shape([2, 77, 10, 64])
)
scaled_dot_product_attention: R.Tensor(
(2, 4096, 10, 64), dtype="float32"
) = R.nn.attention(reshape, reshape1, reshape2, scale=None, causal_mask=None)
reshape3: R.Tensor((2, 4096, 640), dtype="float32") = R.reshape(
scaled_dot_product_attention, R.shape([2, 4096, 640])
)
permute_dims3: R.Tensor((640, 640), dtype="float32") = R.permute_dims(
to_out_0_weight, axes=None
)
matmul3: R.Tensor((2, 4096, 640), dtype="float32") = R.matmul(
reshape3, permute_dims3, out_dtype="void"
)
add: R.Tensor((2, 4096, 640), dtype="float32") = R.add(matmul3, to_out_0_bias)
gv1: R.Tuple(R.Tensor((2, 4096, 640), dtype="float32"), R.Tuple(R.Object)) = add, (_io,)
R.output(gv1)
return gv1

mod = modules.Attention(query_dim=640, cross_attention_dim=2048, heads=10, norm_num_groups=8)
tvm_mod, _ = mod.export_tvm(
spec={
"forward": {
"hidden_states": spec.Tensor((2, 4096, 640), "float32"),
"encoder_hidden_states": spec.Tensor((2, 77, 2048), "float32"),
}
}
)
assert_structural_equal(tvm_mod["forward"], forward, True)


if __name__ == "__main__":
tvm.testing.main()