Skip to content

Commit

Permalink
multi gpu update
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s committed Mar 6, 2019
1 parent 46fe125 commit 2862a81
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 2 deletions.
2 changes: 2 additions & 0 deletions cuda/basis_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ template <typename scalar_t> struct BasisForward {

#define BASIS_FORWARD(M, PSEUDO, KERNEL_SIZE, IS_OPEN_SPLINE, KERNEL_NAME) \
[&]() -> std::tuple<at::Tensor, at::Tensor> { \
cudaSetDevice(PSEUDO.get_device()); \
auto E = PSEUDO.size(0); \
auto S = (int64_t)(powf(M + 1, KERNEL_SIZE.size(0)) + 0.5); \
auto basis = at::empty({E, S}, PSEUDO.options()); \
Expand Down Expand Up @@ -163,6 +164,7 @@ template <typename scalar_t> struct BasisBackward {
#define BASIS_BACKWARD(M, GRAD_BASIS, PSEUDO, KERNEL_SIZE, IS_OPEN_SPLINE, \
KERNEL_NAME) \
[&]() -> at::Tensor { \
cudaSetDevice(GRAD_BASIS.get_device()); \
auto E = PSEUDO.size(0); \
auto D = PSEUDO.size(1); \
auto grad_pseudo = at::empty({E, D}, PSEUDO.options()); \
Expand Down
4 changes: 4 additions & 0 deletions cuda/weighting_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ weighting_fw_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> out,

at::Tensor weighting_fw_cuda(at::Tensor x, at::Tensor weight, at::Tensor basis,
at::Tensor weight_index) {
cudaSetDevice(x.get_device());
auto E = x.size(0), M_out = weight.size(2);
auto out = at::empty({E, M_out}, x.options());
AT_DISPATCH_FLOATING_TYPES(out.type(), "weighting_fw", [&] {
Expand Down Expand Up @@ -86,6 +87,7 @@ __global__ void weighting_bw_x_kernel(

at::Tensor weighting_bw_x_cuda(at::Tensor grad_out, at::Tensor weight,
at::Tensor basis, at::Tensor weight_index) {
cudaSetDevice(grad_out.get_device());
auto E = grad_out.size(0), M_in = weight.size(1);
auto grad_x = at::empty({E, M_in}, grad_out.options());
weight = weight.transpose(1, 2).contiguous();
Expand Down Expand Up @@ -131,6 +133,7 @@ __global__ void weighting_bw_w_kernel(
at::Tensor weighting_bw_w_cuda(at::Tensor grad_out, at::Tensor x,
at::Tensor basis, at::Tensor weight_index,
int64_t K) {
cudaSetDevice(grad_out.get_device());
auto M_in = x.size(1), M_out = grad_out.size(1);
auto grad_weight = at::zeros({K, M_in, M_out}, grad_out.options());
AT_DISPATCH_FLOATING_TYPES(grad_out.type(), "weighting_bw_w", [&] {
Expand Down Expand Up @@ -175,6 +178,7 @@ __global__ void weighting_bw_b_kernel(

at::Tensor weighting_bw_b_cuda(at::Tensor grad_out, at::Tensor x,
at::Tensor weight, at::Tensor weight_index) {
cudaSetDevice(grad_out.get_device());
auto E = x.size(0), S = weight_index.size(1);
auto grad_basis = at::zeros({E, S}, grad_out.options());
AT_DISPATCH_FLOATING_TYPES(grad_out.type(), "weighting_bw_b", [&] {
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
['cuda/weighting.cpp', 'cuda/weighting_kernel.cu']),
]

__version__ = '1.0.5'
__version__ = '1.0.6'
url = 'https://github.com/rusty1s/pytorch_spline_conv'

install_requires = []
Expand Down
2 changes: 1 addition & 1 deletion torch_spline_conv/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@
from .weighting import SplineWeighting
from .conv import SplineConv

__version__ = '1.0.5'
__version__ = '1.0.6'

__all__ = ['SplineBasis', 'SplineWeighting', 'SplineConv', '__version__']

0 comments on commit 2862a81

Please sign in to comment.