Skip to content

Commit

Permalink
lint fix
Browse files Browse the repository at this point in the history
  • Loading branch information
hust17yixuan committed Dec 2, 2024
1 parent db6da71 commit 007ba98
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 88 deletions.
13 changes: 7 additions & 6 deletions mmcv/ops/csrc/pytorch/npu/assign_score_withk_npu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@ using namespace std;

void assign_score_withk_forward_npu(int B, int N0, int N1, int M, int K, int O,
int aggregate, const Tensor& points,
const Tensor& centers,
const Tensor& scores,
const Tensor& centers, const Tensor& scores,
const Tensor& knn_idx, Tensor& output) {
at::Tensor points_trans = points.permute({0, 3, 1, 2});
at::Tensor centers_trans = centers.permute({0, 3, 1, 2});
EXEC_NPU_CMD(aclnnAssignScoreWithk, points_trans, centers_trans, scores, knn_idx, B, N0, N1, M, K, O, aggregate, output);
at::Tensor points_trans = points.permute({0, 3, 1, 2});
at::Tensor centers_trans = centers.permute({0, 3, 1, 2});
EXEC_NPU_CMD(aclnnAssignScoreWithk, points_trans, centers_trans, scores,
knn_idx, B, N0, N1, M, K, O, aggregate, output);
}

void assign_score_withk_forward_impl(int B, int N0, int N1, int M, int K, int O,
Expand All @@ -19,4 +19,5 @@ void assign_score_withk_forward_impl(int B, int N0, int N1, int M, int K, int O,
const Tensor& scores,
const Tensor& knn_idx, Tensor& output);

REGISTER_NPU_IMPL(assign_score_withk_forward_impl, assign_score_withk_forward_npu);
REGISTER_NPU_IMPL(assign_score_withk_forward_impl,
assign_score_withk_forward_npu);
167 changes: 85 additions & 82 deletions tests/test_ops/test_assign_score_withk.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,88 +17,91 @@
not IS_NPU_AVAILABLE, reason='requires NPU support'))
])
def test_paconv_assign_scores(device):
scores = torch.tensor([[[[0.06947571, 0.6065746], [0.28462553, 0.8378516],
[0.7595994, 0.97220325], [0.519155, 0.766185]],
[[0.15348864, 0.6051019], [0.21510637, 0.31916398],
[0.00236845, 0.5842595], [0.6783676, 0.5216348]]],
[[[0.23089725, 0.5568468], [0.7405102, 0.06438422],
[0.6887394, 0.22089851], [0.0502342, 0.79228795]],
[[0.44883424, 0.15427643],
[0.13817799, 0.34856772], [0.7989621, 0.33788306],
[0.15699774, 0.7693662]]]]).float().cuda()
scores.requires_grad_()
points = torch.tensor([[[[0.06001121, 0.92963666, 0.5753327, 0.7251477],
[0.53563064, 0.23129565, 0.92366195, 0.44261628]],
[[0.5770022, 0.56625944, 0.23560429, 0.11178821],
[0.7735967, 0.95678777, 0.25468266, 0.02895975]],
[[0.0589869, 0.09017515, 0.5977862, 0.02797985],
[0.603862, 0.35991007, 0.85761684, 0.3096559]],
[[0.22359002, 0.13983732, 0.5544243, 0.68863827],
[0.85646236, 0.75651926, 0.8638947, 0.83600986]],
[[0.45424145, 0.27458847, 0.6456112, 0.47162914],
[0.15773582, 0.47645122, 0.79964715, 0.3323908]],
[[0.8351399, 0.84696376, 0.9431732, 0.29418713],
[0.77168906, 0.6996871, 0.19354361, 0.03392768]],
[[0.30976456, 0.7074133, 0.581795, 0.976677],
[0.69656056, 0.07199162, 0.4708506, 0.29117996]],
[[0.5829035, 0.30201727, 0.76556486, 0.0935446],
[0.88030535, 0.16129416, 0.9242525, 0.49545723]]],
[[[0.50899494, 0.06482804, 0.44939405, 0.37704808],
[0.47028124, 0.11969638, 0.62823206, 0.28560323]],
[[0.40690207, 0.689753, 0.51636654, 0.23040164],
[0.06935787, 0.00488842, 0.22462702, 0.09182382]],
[[0.26611632, 0.00184339, 0.7730655, 0.5228131],
[0.87776035, 0.77895886, 0.2787183, 0.16620636]],
[[0.502574, 0.04039001, 0.5368497, 0.98379374],
[0.40973026, 0.3238272, 0.9733018, 0.13988364]],
[[0.04586202, 0.20983845, 0.20662665, 0.22270602],
[0.60387236, 0.5155574, 0.51237285, 0.6528438]],
[[0.45735973, 0.86821306, 0.61054605, 0.8370336],
[0.45193362, 0.3734138, 0.7825672, 0.5699416]],
[[0.44591594, 0.12447512, 0.09282011, 0.7055254],
[0.25223452, 0.46696228, 0.7051136, 0.892151]],
[[0.49615085, 0.47321403, 0.93138885, 0.7652197],
[0.38766378, 0.30332977, 0.23131835,
0.02863514]]]]).float().cuda()
points.requires_grad_()
centers = torch.tensor([[[[0.83878064, 0.96658987, 0.8033424, 0.9598312],
[0.45035273, 0.8768925, 0.977736, 0.54547966]],
[[0.01041394, 0.597893, 0.36212963, 0.4410367],
[0.94879234, 0.8372817, 0.21237361, 0.67945415]],
[[0.5096087, 0.26401454, 0.60034937, 0.5417416],
[0.87591463, 0.546456, 0.4096033, 0.16373193]],
[[0.79547447, 0.1482386, 0.12840575, 0.45384115],
[0.5640288, 0.944541, 0.5745328, 0.73229736]],
[[0.93011934, 0.7406011, 0.62621707, 0.8677915],
[0.91563636, 0.3595413, 0.6678378, 0.6085383]],
[[0.22431666, 0.65617776, 0.7483924, 0.6263364],
[0.30968404, 0.78204364, 0.14899081,
0.09628749]],
[[0.73675203, 0.72104895, 0.4648038, 0.6101647],
[0.7817645, 0.16572917, 0.3311919, 0.43407398]],
[[0.8193154, 0.09559608, 0.05978829, 0.90262103],
[0.4256065, 0.8165596, 0.8206446, 0.6604721]]],
[[[0.7159653, 0.18600845, 0.21433902, 0.3159626],
[0.3921569, 0.33221376, 0.5061177, 0.7961841]],
[[0.95338356, 0.04785997, 0.67185795, 0.6538394],
[0.4729132, 0.33404195, 0.17750603, 0.8445621]],
[[0.6755793, 0.16193843, 0.75943846, 0.92123103],
[0.2781859, 0.03114432, 0.710638, 0.52729136]],
[[0.8376105, 0.10858494, 0.13208169, 0.365772],
[0.5930795, 0.27390373, 0.14036089, 0.170403]],
[[0.3479789, 0.89855295, 0.04844379, 0.9871029],
[0.29781651, 0.0244137, 0.9179047, 0.8081611]],
[[0.12460887, 0.44991326, 0.19382608, 0.35037738],
[0.2773472, 0.4362057, 0.36757517, 0.5993509]],
[[0.29630446, 0.90046406, 0.5417113, 0.13510644],
[0.09623539, 0.04226565, 0.32001644,
0.44358212]],
[[0.5274848, 0.82096446, 0.9415489, 0.7123748],
[0.7537517, 0.8086482, 0.85345286,
0.7472754]]]]).float().cuda()
centers.requires_grad_()
knn_idx = torch.tensor([[[6, 7, 4, 6], [2, 4, 2, 4]],
[[7, 1, 3, 2], [6, 0, 2, 6]]]).long().cuda()
scores = torch.tensor(
[[[[0.06947571, 0.6065746], [0.28462553, 0.8378516],
[0.7595994, 0.97220325], [0.519155, 0.766185]],
[[0.15348864, 0.6051019], [0.21510637, 0.31916398],
[0.00236845, 0.5842595], [0.6783676, 0.5216348]]],
[[[0.23089725, 0.5568468], [0.7405102, 0.06438422],
[0.6887394, 0.22089851], [0.0502342, 0.79228795]],
[[0.44883424, 0.15427643], [0.13817799, 0.34856772],
[0.7989621, 0.33788306], [0.15699774, 0.7693662]]]],
device=device).float()
points = torch.tensor(
[[[[0.06001121, 0.92963666, 0.5753327, 0.7251477],
[0.53563064, 0.23129565, 0.92366195, 0.44261628]],
[[0.5770022, 0.56625944, 0.23560429, 0.11178821],
[0.7735967, 0.95678777, 0.25468266, 0.02895975]],
[[0.0589869, 0.09017515, 0.5977862, 0.02797985],
[0.603862, 0.35991007, 0.85761684, 0.3096559]],
[[0.22359002, 0.13983732, 0.5544243, 0.68863827],
[0.85646236, 0.75651926, 0.8638947, 0.83600986]],
[[0.45424145, 0.27458847, 0.6456112, 0.47162914],
[0.15773582, 0.47645122, 0.79964715, 0.3323908]],
[[0.8351399, 0.84696376, 0.9431732, 0.29418713],
[0.77168906, 0.6996871, 0.19354361, 0.03392768]],
[[0.30976456, 0.7074133, 0.581795, 0.976677],
[0.69656056, 0.07199162, 0.4708506, 0.29117996]],
[[0.5829035, 0.30201727, 0.76556486, 0.0935446],
[0.88030535, 0.16129416, 0.9242525, 0.49545723]]],
[[[0.50899494, 0.06482804, 0.44939405, 0.37704808],
[0.47028124, 0.11969638, 0.62823206, 0.28560323]],
[[0.40690207, 0.689753, 0.51636654, 0.23040164],
[0.06935787, 0.00488842, 0.22462702, 0.09182382]],
[[0.26611632, 0.00184339, 0.7730655, 0.5228131],
[0.87776035, 0.77895886, 0.2787183, 0.16620636]],
[[0.502574, 0.04039001, 0.5368497, 0.98379374],
[0.40973026, 0.3238272, 0.9733018, 0.13988364]],
[[0.04586202, 0.20983845, 0.20662665, 0.22270602],
[0.60387236, 0.5155574, 0.51237285, 0.6528438]],
[[0.45735973, 0.86821306, 0.61054605, 0.8370336],
[0.45193362, 0.3734138, 0.7825672, 0.5699416]],
[[0.44591594, 0.12447512, 0.09282011, 0.7055254],
[0.25223452, 0.46696228, 0.7051136, 0.892151]],
[[0.49615085, 0.47321403, 0.93138885, 0.7652197],
[0.38766378, 0.30332977, 0.23131835, 0.02863514]]]],
device=device).float()
centers = torch.tensor(
[[[[0.83878064, 0.96658987, 0.8033424, 0.9598312],
[0.45035273, 0.8768925, 0.977736, 0.54547966]],
[[0.01041394, 0.597893, 0.36212963, 0.4410367],
[0.94879234, 0.8372817, 0.21237361, 0.67945415]],
[[0.5096087, 0.26401454, 0.60034937, 0.5417416],
[0.87591463, 0.546456, 0.4096033, 0.16373193]],
[[0.79547447, 0.1482386, 0.12840575, 0.45384115],
[0.5640288, 0.944541, 0.5745328, 0.73229736]],
[[0.93011934, 0.7406011, 0.62621707, 0.8677915],
[0.91563636, 0.3595413, 0.6678378, 0.6085383]],
[[0.22431666, 0.65617776, 0.7483924, 0.6263364],
[0.30968404, 0.78204364, 0.14899081, 0.09628749]],
[[0.73675203, 0.72104895, 0.4648038, 0.6101647],
[0.7817645, 0.16572917, 0.3311919, 0.43407398]],
[[0.8193154, 0.09559608, 0.05978829, 0.90262103],
[0.4256065, 0.8165596, 0.8206446, 0.6604721]]],
[[[0.7159653, 0.18600845, 0.21433902, 0.3159626],
[0.3921569, 0.33221376, 0.5061177, 0.7961841]],
[[0.95338356, 0.04785997, 0.67185795, 0.6538394],
[0.4729132, 0.33404195, 0.17750603, 0.8445621]],
[[0.6755793, 0.16193843, 0.75943846, 0.92123103],
[0.2781859, 0.03114432, 0.710638, 0.52729136]],
[[0.8376105, 0.10858494, 0.13208169, 0.365772],
[0.5930795, 0.27390373, 0.14036089, 0.170403]],
[[0.3479789, 0.89855295, 0.04844379, 0.9871029],
[0.29781651, 0.0244137, 0.9179047, 0.8081611]],
[[0.12460887, 0.44991326, 0.19382608, 0.35037738],
[0.2773472, 0.4362057, 0.36757517, 0.5993509]],
[[0.29630446, 0.90046406, 0.5417113, 0.13510644],
[0.09623539, 0.04226565, 0.32001644, 0.44358212]],
[[0.5274848, 0.82096446, 0.9415489, 0.7123748],
[0.7537517, 0.8086482, 0.85345286, 0.7472754]]]],
device=device).float()
if device == 'cuda':
points.requires_grad_()
scores.requires_grad_()
centers.requires_grad_()
knn_idx = torch.tensor(
[[[6, 7, 4, 6], [2, 4, 2, 4]], [[7, 1, 3, 2], [6, 0, 2, 6]]],
device=device).long()
aggregate = 'sum'
expected_output = torch.tensor(
[[[[-0.08134781, 0.03877336, -0.8212776, -0.2869547],
Expand Down

0 comments on commit 007ba98

Please sign in to comment.