Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[Large Tensor] Fixed Spatial Transformer op (#17617)
Browse files Browse the repository at this point in the history
* Added CPU fix

* Added fix for backward on CPU

* Fixed lint error

* index_t for lower bound instead of hardcoded long int

* Fixed remaining lint errors

* Removed trailing whitespace

* Reverting to DType for vertices

* Added nightly test for SpatialTransformer
  • Loading branch information
connorgoggins authored Feb 25, 2020
1 parent 5098dbe commit 13b3893
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 18 deletions.
47 changes: 29 additions & 18 deletions src/operator/spatial_transformer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,9 @@ inline void BilinearSamplingForward(const Tensor<cpu, 4, DType> &output,
DType *out = output.dptr_;
const DType *data = input.dptr_;
const DType *grid = grid_src.dptr_;
const int o_n = output.size(0), o_c = output.size(1), o_h = output.size(2), o_w = output.size(3);
const int i_c = input.size(1), i_h = input.size(2), i_w = input.size(3);
const index_t o_n = output.size(0), o_c = output.size(1),
o_h = output.size(2), o_w = output.size(3);
const index_t i_c = input.size(1), i_h = input.size(2), i_w = input.size(3);
for (index_t n = 0; n < static_cast<index_t>(o_n); ++n) {
for (index_t c = 0; c < static_cast<index_t>(o_c); ++c) {
for (index_t h = 0; h < static_cast<index_t>(o_h); ++h) {
Expand All @@ -51,23 +52,28 @@ inline void BilinearSamplingForward(const Tensor<cpu, 4, DType> &output,
const index_t grid_index = n * o_h * o_w * 2 + h * o_w + w;
const DType y_real = (*(grid + grid_index + o_h * o_w) + 1) * (i_h - 1) / 2;
const DType x_real = (*(grid + grid_index) + 1) * (i_w - 1) / 2;
const auto top_left_y = static_cast<int>(std::floor(y_real));
const auto top_left_x = static_cast<int>(std::floor(x_real));
const auto top_left_y = static_cast<index_t>(std::floor(y_real));
const auto top_left_x = static_cast<index_t>(std::floor(x_real));
const DType top_left_y_w = 1.0 - (y_real - top_left_y);
const DType top_left_x_w = 1.0 - (x_real - top_left_x);
const int data_index = n * i_c * i_h * i_w + c * i_h * i_w +
const index_t data_index = n * i_c * i_h * i_w + c * i_h * i_w +
top_left_y * i_w + top_left_x;
DType top_left_v = 0;
DType top_right_v = 0;
DType bottom_left_v = 0;
DType bottom_right_v = 0;
if (between(top_left_x, 0, i_w-1) && between(top_left_y, 0, i_h-1))
index_t lower_bound = 0;
if (between(top_left_x, lower_bound, i_w-1) &&
between(top_left_y, lower_bound, i_h-1))
top_left_v = *(data + data_index);
if (between(top_left_x + 1, 0, i_w-1) && between(top_left_y, 0, i_h-1))
if (between(top_left_x + 1, lower_bound, i_w-1) &&
between(top_left_y, lower_bound, i_h-1))
top_right_v = *(data + data_index + 1);
if (between(top_left_x, 0, i_w-1) && between(top_left_y + 1, 0, i_h-1))
if (between(top_left_x, lower_bound, i_w-1) &&
between(top_left_y + 1, lower_bound, i_h-1))
bottom_left_v = *(data + data_index + i_w);
if (between(top_left_x+1, 0, i_w-1) && between(top_left_y + 1, 0, i_h-1))
if (between(top_left_x+1, lower_bound, i_w-1) &&
between(top_left_y + 1, lower_bound, i_h-1))
bottom_right_v = *(data + data_index + i_w + 1);
*(out+out_index) = top_left_v * top_left_y_w * top_left_x_w +
top_right_v * top_left_y_w * (1.0 - top_left_x_w) +
Expand All @@ -88,9 +94,9 @@ inline void BilinearSamplingBackward(const Tensor<cpu, 4, DType> &input_grad,
DType *grid_src = grid_src_data.dptr_;
const DType *grad = output_grad.dptr_;
const DType *data = input_data.dptr_;
const int o_n = output_grad.size(0), o_c = output_grad.size(1),
const index_t o_n = output_grad.size(0), o_c = output_grad.size(1),
o_h = output_grad.size(2), o_w = output_grad.size(3);
const int i_c = input_data.size(1), i_h = input_data.size(2), i_w = input_data.size(3);
const index_t i_c = input_data.size(1), i_h = input_data.size(2), i_w = input_data.size(3);
for (index_t n = 0; n < static_cast<index_t>(o_n); ++n) {
for (index_t h = 0; h < static_cast<index_t>(o_h); ++h) {
for (index_t w = 0; w < static_cast<index_t>(o_w); ++w) {
Expand All @@ -99,34 +105,39 @@ inline void BilinearSamplingBackward(const Tensor<cpu, 4, DType> &input_grad,
const index_t grid_src_index = n * o_h * o_w * 2 + h * o_w + w;
const DType y_real = (*(grid_src + grid_src_index + o_h * o_w) + 1) * (i_h - 1) / 2;
const DType x_real = (*(grid_src + grid_src_index) + 1) * (i_w - 1) / 2;
const auto top_left_y = static_cast<int>(std::floor(y_real));
const auto top_left_x = static_cast<int>(std::floor(x_real));
const auto top_left_y = static_cast<index_t>(std::floor(y_real));
const auto top_left_x = static_cast<index_t>(std::floor(x_real));
const DType top_left_y_w = 1.0 - (y_real - top_left_y);
const DType top_left_x_w = 1.0 - (x_real - top_left_x);
for (index_t c = 0; c < static_cast<index_t>(o_c); ++c) {
index_t grad_index = n * o_c * o_h * o_w + c * o_h * o_w + h * o_w + w;
const int data_index = n * i_c * i_h * i_w + c * i_h * i_w +
const index_t data_index = n * i_c * i_h * i_w + c * i_h * i_w +
top_left_y * i_w + top_left_x;
// calc 4 vertex value in input data
DType top_left_v = 0;
DType top_right_v = 0;
DType bottom_left_v = 0;
DType bottom_right_v = 0;
if (between(top_left_x, 0, i_w-1) && between(top_left_y, 0, i_h-1)) {
index_t lower_bound = 0;
if (between(top_left_x, lower_bound, i_w-1) &&
between(top_left_y, lower_bound, i_h-1)) {
*(g_input + data_index) += *(grad + grad_index) * top_left_y_w * top_left_x_w;
top_left_v = *(data + data_index);
}
if (between(top_left_x+1, 0, i_w-1) && between(top_left_y, 0, i_h-1)) {
if (between(top_left_x+1, lower_bound, i_w-1) &&
between(top_left_y, lower_bound, i_h-1)) {
*(g_input + data_index + 1) += *(grad + grad_index) * top_left_y_w
* (1.0 - top_left_x_w);
top_right_v = *(data + data_index + 1);
}
if (between(top_left_x, 0, i_w-1) && between(top_left_y+1, 0, i_h-1)) {
if (between(top_left_x, lower_bound, i_w-1) &&
between(top_left_y+1, lower_bound, i_h-1)) {
*(g_input + data_index+ i_w) += *(grad + grad_index) * (1.0 - top_left_y_w)
* top_left_x_w;
bottom_left_v = *(data + data_index + i_w);
}
if (between(top_left_x+1, 0, i_w-1) && between(top_left_y+1, 0, i_h-1)) {
if (between(top_left_x+1, lower_bound, i_w-1) &&
between(top_left_y+1, lower_bound, i_h-1)) {
*(g_input + data_index+ i_w + 1) += *(grad + grad_index) * (1.0 - top_left_y_w)
* (1.0 - top_left_x_w);
bottom_right_v = *(data + data_index + i_w + 1);
Expand Down
17 changes: 17 additions & 0 deletions tests/nightly/test_large_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,7 @@ def check_col2im():
assert res.shape[2] == 2
assert res.shape[3] == 2
assert res.shape[4] == 1

def check_embedding():
data = nd.random_normal(shape=(LARGE_TENSOR_SHAPE, 1))
weight = nd.random_normal(shape=(LARGE_TENSOR_SHAPE, 1))
Expand All @@ -479,6 +480,21 @@ def check_embedding():
assert out.shape[0] == LARGE_TENSOR_SHAPE
assert out.shape[1] == 1
assert out.shape[2] == 1

def check_spatial_transformer():
    # Regression test for the large-tensor (int -> index_t) fix in
    # spatial_transformer.cc: the input holds 2 * 2**29 * 1 * 6 elements,
    # which exceeds the range of a 32-bit index, so the op would overflow
    # its offsets before the fix.
    data = nd.random_normal(shape=(2, 2**29, 1, 6))
    loc = nd.random_normal(shape=(2, 6))
    transform_type = 'affine'
    sampler_type = 'bilinear'
    target_shape = (2, 6)

    res = nd.SpatialTransformer(data=data, loc=loc, transform_type=transform_type,
                                sampler_type=sampler_type, target_shape=target_shape)

    # Batch and channel dims pass through (536870912 == 2**29); the two
    # spatial dims are taken from target_shape.
    assert res.shape[0] == 2
    assert res.shape[1] == 536870912
    assert res.shape[2] == 2
    assert res.shape[3] == 6

check_gluon_embedding()
check_fully_connected()
Expand All @@ -501,6 +517,7 @@ def check_embedding():
check_instance_norm()
check_col2im()
check_embedding()
check_spatial_transformer()


def test_tensor():
Expand Down

0 comments on commit 13b3893

Please sign in to comment.