diff --git a/point_transformer_pytorch/multihead_point_transformer_pytorch.py b/point_transformer_pytorch/multihead_point_transformer_pytorch.py index 926c519..ec6a740 100644 --- a/point_transformer_pytorch/multihead_point_transformer_pytorch.py +++ b/point_transformer_pytorch/multihead_point_transformer_pytorch.py @@ -57,9 +57,9 @@ def __init__( attn_inner_dim = inner_dim * attn_mlp_hidden_mult self.attn_mlp = nn.Sequential( - nn.Conv2d(inner_dim, attn_inner_dim, 1), + nn.Conv2d(inner_dim, attn_inner_dim, 1, groups = heads), nn.ReLU(), - nn.Conv2d(attn_inner_dim, heads, 1), + nn.Conv2d(attn_inner_dim, heads, 1, groups = heads), ) def forward(self, x, pos, mask = None): diff --git a/setup.py b/setup.py index 4f68cc6..359f2d2 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'point-transformer-pytorch', packages = find_packages(), - version = '0.1.1', + version = '0.1.2', license='MIT', description = 'Point Transformer - Pytorch', author = 'Phil Wang',