
Commit

add torch==2.4.0, fix warning
mjun0812 committed Aug 16, 2024
1 parent d71af3f commit d3ace38
Showing 5 changed files with 47 additions and 14 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
@@ -31,7 +31,7 @@ jobs:
      fail-fast: false
      matrix:
        python: ["3.11", "3.10"]
-        torch: ["2.1.2", "2.2.1", "2.3.1"]
+        torch: ["2.1.2", "2.2.1", "2.3.1", "2.4.0"]
        cuda: ["11.8.0", "12.1.1", "12.2.2"]

    steps:
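The matrix above now also builds wheels against torch 2.4.0. A quick, generic way to confirm which torch / CUDA combination from this matrix is active in a given environment (an illustrative check, not part of the workflow itself):

import torch

print(torch.__version__)          # e.g. "2.4.0"
print(torch.version.cuda)         # CUDA toolkit the wheel was built against, e.g. "12.1"
print(torch.cuda.is_available())  # True if a compatible GPU driver is present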
2 changes: 1 addition & 1 deletion torch_cpp/__init__.py
@@ -1,3 +1,3 @@
from .src import * # noqa: F401, F403

-__version__ = "1.0.5"
+__version__ = "1.1.0"
26 changes: 20 additions & 6 deletions torch_cpp/src/DCNv3/dcn_v3.py
@@ -13,9 +13,13 @@
from torch import nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
-from torch.cuda.amp import custom_bwd, custom_fwd
from torch.nn.init import constant_, xavier_uniform_

+try:
+    from torch.amp import custom_bwd, custom_fwd
+except ImportError:
+    from torch.cuda.amp import custom_bwd, custom_fwd

from torch_cpp import _C
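The try/except above prefers the newer torch.amp location of these decorators (available from PyTorch 2.4) and falls back to torch.cuda.amp on older releases, which avoids the FutureWarning that the torch.cuda.amp variants now emit. As a rough, hypothetical sketch of how such a shim is typically consumed by a custom autograd Function (not this repository's exact code; note that the torch.amp variants require an explicit device_type keyword):

from torch.autograd import Function

try:
    # PyTorch >= 2.4: decorators live in torch.amp and need device_type
    from torch.amp import custom_bwd, custom_fwd
    _custom_fwd = custom_fwd(device_type="cuda")
    _custom_bwd = custom_bwd(device_type="cuda")
except ImportError:
    # Older PyTorch: CUDA-only variants, no device_type argument
    from torch.cuda.amp import custom_bwd as _custom_bwd, custom_fwd as _custom_fwd

class _Square(Function):
    @staticmethod
    @_custom_fwd
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    @_custom_bwd
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return 2 * x * grad_out

With this pattern the same Function definition works unchanged across the torch versions listed in the build matrix.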


@@ -69,13 +73,17 @@ def build_act_layer(act_layer):

def _is_power_of_2(n):
    if (not isinstance(n, int)) or (n < 0):
-        raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n)))
+        raise ValueError(
+            "invalid input for _is_power_of_2: {} (type: {})".format(n, type(n))
+        )

    return (n & (n - 1) == 0) and n != 0
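The return line uses the standard power-of-two bit trick: for n > 0, n & (n - 1) clears the lowest set bit, so the result is 0 exactly when n has a single bit set. A small illustration:

for n in (1, 2, 6, 8, 12, 16):
    print(n, (n & (n - 1) == 0) and n != 0)
# 1 True, 2 True, 6 False, 8 True, 12 False, 16 True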


class CenterFeatureScaleModule(nn.Module):
-    def forward(self, query, center_feature_scale_proj_weight, center_feature_scale_proj_bias):
+    def forward(
+        self, query, center_feature_scale_proj_weight, center_feature_scale_proj_bias
+    ):
        center_feature_scale = F.linear(
            query,
            weight=center_feature_scale_proj_weight,
@@ -154,8 +162,12 @@ def __init__(
            build_norm_layer(channels, norm_layer, "channels_first", "channels_last"),
            build_act_layer(act_layer),
        )
-        self.offset = nn.Linear(channels, group * (kernel_size * kernel_size - remove_center) * 2)
-        self.mask = nn.Linear(channels, group * (kernel_size * kernel_size - remove_center))
+        self.offset = nn.Linear(
+            channels, group * (kernel_size * kernel_size - remove_center) * 2
+        )
+        self.mask = nn.Linear(
+            channels, group * (kernel_size * kernel_size - remove_center)
+        )
        self.input_proj = nn.Linear(channels, channels)
        self.output_proj = nn.Linear(channels, channels)
        self._reset_parameters()
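For a sense of the sizes the offset and mask layers above produce (a hypothetical configuration, not values taken from this commit): with group=4, kernel_size=3 and remove_center=0, each spatial position gets one (x, y) offset pair and one mask weight per group and sampling point.

group, kernel_size, remove_center = 4, 3, 0  # hypothetical example values
P_ = kernel_size * kernel_size - remove_center
print(group * P_ * 2)  # 72 offset channels (x/y pairs per sampling point)
print(group * P_)      # 36 mask channels (one weight per sampling point)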
@@ -221,7 +233,9 @@ def forward(self, input):

        if self.center_feature_scale:
            center_feature_scale = self.center_feature_scale_module(
-                x1, self.center_feature_scale_proj_weight, self.center_feature_scale_proj_bias
+                x1,
+                self.center_feature_scale_proj_weight,
+                self.center_feature_scale_proj_bias,
            )
            # N, H, W, groups -> N, H, W, groups, 1 -> N, H, W, groups, _d_per_group -> N, H, W, channels
            center_feature_scale = (
25 changes: 20 additions & 5 deletions torch_cpp/src/DCNv3/dcn_v3_torch.py
@@ -80,8 +80,12 @@ def dcnv3_core_pytorch(
    # for debug and test only,
    # need to use cuda version instead

-    if remove_center and (kernel_h % 2 == 0 or kernel_w % 2 == 0 or kernel_w != kernel_h):
-        raise ValueError("remove_center is only compatible with square odd kernel size.")
+    if remove_center and (
+        kernel_h % 2 == 0 or kernel_w % 2 == 0 or kernel_w != kernel_h
+    ):
+        raise ValueError(
+            "remove_center is only compatible with square odd kernel size."
+        )

    input = F.pad(input, [0, 0, pad_h, pad_h, pad_w, pad_w])
    N_, H_in, W_in, _ = input.shape
@@ -127,7 +131,9 @@ def dcnv3_core_pytorch(
    )
    # N_, H_out, W_out, group*P_*2 -> N_, H_out*W_out, group, P_, 2 -> N_, group, H_out*W_out, P_, 2 -> N_*group, H_out*W_out, P_, 2
    sampling_grid_ = (
-        sampling_grids.view(N_, H_out * W_out, group, P_, 2).transpose(1, 2).flatten(0, 1)
+        sampling_grids.view(N_, H_out * W_out, group, P_, 2)
+        .transpose(1, 2)
+        .flatten(0, 1)
    )
    # N_*group, group_channels, H_out*W_out, P_
    sampling_input_ = F.grid_sample(
@@ -144,7 +150,9 @@ def dcnv3_core_pytorch(
        .transpose(1, 2)
        .reshape(N_ * group, 1, H_out * W_out, P_)
    )
-    output = (sampling_input_ * mask).sum(-1).view(N_, group * group_channels, H_out * W_out)
+    output = (
+        (sampling_input_ * mask).sum(-1).view(N_, group * group_channels, H_out * W_out)
+    )

    return output.transpose(1, 2).reshape(N_, H_out, W_out, -1).contiguous()
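The heart of this reference path is F.grid_sample, which bilinearly samples the value tensor at the normalized offset locations, after which the mask weights the P_ samples per output position. A tiny self-contained illustration of that sampling-and-weighting step (shapes chosen arbitrarily, not taken from the function above):

import torch
import torch.nn.functional as F

value = torch.arange(16.0).reshape(1, 1, 4, 4)  # N=1, C=1, 4x4 feature map
# grid holds (x, y) coordinates normalized to [-1, 1]; here 2 output positions, 3 samples each
grid = torch.rand(1, 2, 3, 2) * 2 - 1
samples = F.grid_sample(
    value, grid, mode="bilinear", padding_mode="zeros", align_corners=False
)
print(samples.shape)  # torch.Size([1, 1, 2, 3]) - one bilinear sample per grid entry
mask = torch.softmax(torch.rand(1, 1, 2, 3), dim=-1)  # toy per-sample weights
out = (samples * mask).sum(-1)                        # weighted sum over the 3 samples
print(out.shape)  # torch.Size([1, 1, 2])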

@@ -184,6 +192,7 @@ def _get_reference_points(
            dtype=torch.float32,
            device=device,
        ),
+        indexing="ij",
    )
    ref_y = ref_y.reshape(-1)[None] / H_
    ref_x = ref_x.reshape(-1)[None] / W_
@@ -213,10 +222,16 @@ def _generate_dilation_grids(
            dtype=torch.float32,
            device=device,
        ),
+        indexing="ij",
    )

    points_list.extend([x / W_, y / H_])
-    grid = torch.stack(points_list, -1).reshape(-1, 1, 2).repeat(1, group, 1).permute(1, 0, 2)
+    grid = (
+        torch.stack(points_list, -1)
+        .reshape(-1, 1, 2)
+        .repeat(1, group, 1)
+        .permute(1, 0, 2)
+    )
    grid = grid.reshape(1, 1, 1, group * kernel_h * kernel_w, 2)

    return grid
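Passing indexing="ij" in both meshgrid calls makes the original row-major (matrix) indexing explicit; that is what torch.meshgrid has always defaulted to, and supplying the argument silences the UserWarning newer PyTorch emits when it is omitted. A small illustration of the two indexing modes:

import torch

ys, xs = torch.meshgrid(torch.arange(2), torch.arange(3), indexing="ij")
print(ys.shape, xs.shape)  # torch.Size([2, 3]) torch.Size([2, 3]) - matrix ("ij") layout

xs2, ys2 = torch.meshgrid(torch.arange(3), torch.arange(2), indexing="xy")
print(xs2.shape, ys2.shape)  # torch.Size([2, 3]) torch.Size([2, 3]) - Cartesian ("xy") layout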
6 changes: 5 additions & 1 deletion torch_cpp/src/DCNv4/functions/dcnv4_func.py
@@ -8,7 +8,11 @@

from torch.autograd import Function
from torch.autograd.function import once_differentiable
-from torch.cuda.amp import custom_bwd, custom_fwd

+try:
+    from torch.amp import custom_bwd, custom_fwd
+except ImportError:
+    from torch.cuda.amp import custom_bwd, custom_fwd

from torch_cpp import _C
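This is the same compatibility shim as in dcn_v3.py. One detail worth keeping in mind when the decorators are applied (a hedged sketch of the two signatures, not code from this file): the torch.amp variants take a required device_type keyword, while both old and new variants accept cast_inputs to force a dtype for the custom kernel under autocast.

import torch

try:
    from torch.amp import custom_fwd  # PyTorch >= 2.4
    fp32_fwd = custom_fwd(device_type="cuda", cast_inputs=torch.float32)
except ImportError:
    from torch.cuda.amp import custom_fwd  # older releases
    fp32_fwd = custom_fwd(cast_inputs=torch.float32)

# fp32_fwd can now decorate a Function.forward so autocast hands it float32 inputs.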

