Commit

Fix stride indexing bugs in reorg and reorg_gradient functions (CPU & CUDA) (#3012)

* Fix Stride Indexing Bugs in `reorg` and `reorg_gradient` Functions (CPU & CUDA) and Add `add_to` Parameter

* Add the missing 'add_to' parameter to the CUDA reorg_gradient() call to launch_kernel()

* Cleanup: remove using namespace std; (#3016)

* remove using namespace std from headers

* more std::

* more std::

* more std:: on windows stuff

* remove uses of using namespace std::chrono

* do not use C++17 features

* Apply Davis' suggestion

* revert some more stuff

* revert removing include

* more std::chrono stuff

* fix build error

* Adjust comment formatting to be like other dlib comments

---------

Co-authored-by: Adrià <1671644+arrufat@users.noreply.github.com>
Co-authored-by: Davis King <davis@dlib.net>
3 people authored Sep 23, 2024
1 parent 90c8d78 commit 72822fe
Showing 9 changed files with 190 additions and 129 deletions.
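
For context on the fix itself: reorg is a space-to-depth reshuffle, and its index math decomposes the block offset k / src.k() into a (row, column) position inside a row_stride x col_stride block. A row-major block has col_stride columns, so the row offset is (k / src.k()) / col_stride; the old code divided by row_stride instead, which is only correct when the two strides happen to be equal. A minimal standalone sketch (illustrative only, not dlib code) showing where the old and fixed mappings diverge:

// Decompose the block offset blk = k / src.k() into a (row, col) position
// inside a row_stride x col_stride block. The old code divided by row_stride;
// for row_stride=2, col_stride=3 that produces row offsets up to 2, past the
// end of the block, while the fixed version always stays in range.
#include <cstdio>

int main()
{
    const int row_stride = 2, col_stride = 3;
    for (int blk = 0; blk < row_stride * col_stride; ++blk)
    {
        const int old_row = blk / row_stride;  // buggy: can reach row_stride
        const int new_row = blk / col_stride;  // fixed: always < row_stride
        const int col     = blk % col_stride;  // unchanged by the fix
        std::printf("blk=%d old=(%d,%d) fixed=(%d,%d)\n",
                    blk, old_row, col, new_row, col);
    }
    return 0;
}
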
91 changes: 51 additions & 40 deletions dlib/cuda/cpu_dlib.cpp
@@ -2333,58 +2333,67 @@ namespace dlib

// ----------------------------------------------------------------------------------------

void reorg (
void reorg(
bool add_to,
tensor& dest,
const int row_stride,
const int col_stride,
const tensor& src
)
{
DLIB_CASSERT(is_same_object(dest, src)==false);
DLIB_CASSERT(src.nr() % row_stride == 0);
DLIB_CASSERT(src.nc() % col_stride == 0);
DLIB_CASSERT(dest.num_samples() == src.num_samples());
DLIB_CASSERT(dest.k() == src.k() * row_stride * col_stride);
DLIB_CASSERT(dest.nr() == src.nr() / row_stride);
DLIB_CASSERT(dest.nc() == src.nc() / col_stride);
DLIB_CASSERT(!is_same_object(dest, src), "Destination and source must be distinct objects.");
DLIB_CASSERT(src.nr() % row_stride == 0, "The number of rows in src must be divisible by row_stride.");
DLIB_CASSERT(src.nc() % col_stride == 0, "The number of columns in src must be divisible by col_stride.");
DLIB_CASSERT(dest.num_samples() == src.num_samples(), "The number of samples must match.");
DLIB_CASSERT(dest.k() == src.k() * row_stride * col_stride, "The number of channels must match.");
DLIB_CASSERT(dest.nr() == src.nr() / row_stride, "The number of rows must match.");
DLIB_CASSERT(dest.nc() == src.nc() / col_stride, "The number of columns must match.");

const float* s = src.host();
float* d = dest.host();

parallel_for(0, dest.num_samples(), [&](long n)
const size_t sk = src.k(), snr = src.nr(), snc = src.nc();
const size_t dk = dest.k(), dnr = dest.nr(), dnc = dest.nc(), dsize = dest.size();

dlib::parallel_for(0, dsize, [&](long i)
{
for (long k = 0; k < dest.k(); ++k)
{
for (long r = 0; r < dest.nr(); ++r)
{
for (long c = 0; c < dest.nc(); ++c)
{
const auto out_idx = tensor_index(dest, n, k, r, c);
const auto in_idx = tensor_index(src,
n,
k % src.k(),
r * row_stride + (k / src.k()) / row_stride,
c * col_stride + (k / src.k()) % col_stride);
d[out_idx] = s[in_idx];
}
}
}
const size_t out_plane_size = dnr * dnc;
const size_t out_sample_size = dk * out_plane_size;

const size_t n = i / out_sample_size;
const size_t out_idx = i % out_sample_size;
const size_t out_k = out_idx / out_plane_size;
const size_t out_rc = out_idx % out_plane_size;
const size_t out_r = out_rc / dnc;
const size_t out_c = out_rc % dnc;

const size_t in_k = out_k % sk;
const size_t in_r = out_r * row_stride + (out_k / sk) / col_stride;
const size_t in_c = out_c * col_stride + (out_k / sk) % col_stride;

const size_t in_idx = ((n * sk + in_k) * snr + in_r) * snc + in_c;

if (add_to) d[i] += s[in_idx];
else d[i] = s[in_idx];
});
}
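
The flattened parallel_for above trades four nested loops for a single linear index into a row-major (num_samples, k, nr, nc) volume. A small self-contained check (illustrative; assumes the same row-major layout the in_idx arithmetic implies) that the decomposition recovers the original loop variables:

// Verify that (n, out_k, out_r, out_c) recovered from a flat index i matches
// plain nested loops over a row-major num x k x nr x nc volume.
#include <cassert>
#include <cstddef>

int main()
{
    const std::size_t num = 2, dk = 8, dnr = 3, dnc = 5;
    const std::size_t out_plane_size = dnr * dnc;
    const std::size_t out_sample_size = dk * out_plane_size;

    std::size_t i = 0;
    for (std::size_t n = 0; n < num; ++n)
    for (std::size_t k = 0; k < dk; ++k)
    for (std::size_t r = 0; r < dnr; ++r)
    for (std::size_t c = 0; c < dnc; ++c, ++i)
    {
        const std::size_t out_idx = i % out_sample_size;
        const std::size_t out_rc  = out_idx % out_plane_size;
        assert(i / out_sample_size == n);
        assert(out_idx / out_plane_size == k);
        assert(out_rc / dnc == r && out_rc % dnc == c);
    }
    return 0;
}
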

void reorg_gradient (
void reorg_gradient(
bool add_to,
tensor& grad,
const int row_stride,
const int col_stride,
const tensor& gradient_input
)
{
DLIB_CASSERT(is_same_object(grad, gradient_input)==false);
DLIB_CASSERT(grad.nr() % row_stride == 0);
DLIB_CASSERT(grad.nc() % col_stride == 0);
DLIB_CASSERT(grad.num_samples() == gradient_input.num_samples());
DLIB_CASSERT(grad.k() == gradient_input.k() / row_stride / col_stride);
DLIB_CASSERT(grad.nr() == gradient_input.nr() * row_stride);
DLIB_CASSERT(grad.nc() == gradient_input.nc() * row_stride);
DLIB_CASSERT(!is_same_object(grad, gradient_input), "Grad and gradient_input must be distinct objects.");
DLIB_CASSERT(grad.nr() % row_stride == 0, "The number of rows in grad must be divisible by row_stride.");
DLIB_CASSERT(grad.nc() % col_stride == 0, "The number of columns in grad must be divisible by col_stride.");
DLIB_CASSERT(grad.num_samples() == gradient_input.num_samples(), "The number of samples in grad and gradient_input must match.");
DLIB_CASSERT(grad.k() == gradient_input.k() / row_stride / col_stride, "The number of channels in grad must be gradient_input.k() divided by row_stride and col_stride.");
DLIB_CASSERT(grad.nr() == gradient_input.nr() * row_stride, "The number of rows in grad must be gradient_input.nr() multiplied by row_stride.");
DLIB_CASSERT(grad.nc() == gradient_input.nc() * col_stride, "The number of columns in grad must be gradient_input.nc() multiplied by col_stride.");

const float* gi = gradient_input.host();
float* g = grad.host();

@@ -2396,13 +2405,15 @@
{
for (long c = 0; c < gradient_input.nc(); ++c)
{
const auto in_idx = tensor_index(gradient_input, n, k, r, c);
const auto out_idx = tensor_index(grad,
n,
k % grad.k(),
r * row_stride + (k / grad.k()) / row_stride,
c * col_stride + (k / grad.k()) % col_stride);
g[out_idx] += gi[in_idx];
const auto in_idx = tensor_index(gradient_input, n, k, r, c);
const auto out_idx = tensor_index(grad,
n,
k % grad.k(),
r * row_stride + (k / grad.k()) / col_stride,
c * col_stride + (k / grad.k()) % col_stride);

if (add_to) g[out_idx] += gi[in_idx];
else g[out_idx] = gi[in_idx];
}
}
}
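Since the corrected mapping pairs each src element with exactly one dest element, reorg followed by reorg_gradient with add_to == false must reproduce the input exactly. A standalone round-trip check on plain arrays (illustrative; one sample, with row_stride != col_stride, the case the old indexing handled wrongly):

// Forward reorg (gather), then scatter back through the same index map,
// i.e. reorg_gradient with add_to == false. The map is a bijection, so the
// round trip is the identity.
#include <cassert>
#include <cstddef>
#include <vector>

int main()
{
    const int sk = 2, snr = 4, snc = 6;          // source k, rows, cols
    const int rs = 2, cs = 3;                    // row_stride, col_stride
    const int dk = sk * rs * cs, dnr = snr / rs, dnc = snc / cs;

    std::vector<float> src(sk * snr * snc), dest(src.size()), back(src.size());
    for (std::size_t i = 0; i < src.size(); ++i) src[i] = float(i);

    for (int k = 0; k < dk; ++k)
    for (int r = 0; r < dnr; ++r)
    for (int c = 0; c < dnc; ++c)
    {
        const int in_k = k % sk;
        const int in_r = r * rs + (k / sk) / cs;  // the fixed row index
        const int in_c = c * cs + (k / sk) % cs;
        const int in_idx  = (in_k * snr + in_r) * snc + in_c;
        const int out_idx = (k * dnr + r) * dnc + c;
        dest[out_idx] = src[in_idx];              // forward: gather
        back[in_idx]  = dest[out_idx];            // backward: scatter
    }

    for (std::size_t i = 0; i < src.size(); ++i)
        assert(back[i] == src[i]);                // bijection => identity
    return 0;
}
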
2 changes: 2 additions & 0 deletions dlib/cuda/cpu_dlib.h
@@ -502,13 +502,15 @@ namespace dlib
// -----------------------------------------------------------------------------------

void reorg (
bool add_to,
tensor& dest,
const int row_stride,
const int col_stride,
const tensor& src
);

void reorg_gradient (
bool add_to,
tensor& grad,
const int row_stride,
const int col_stride,
87 changes: 46 additions & 41 deletions dlib/cuda/cuda_dlib.cu
@@ -2001,86 +2001,91 @@

__global__ void _cuda_reorg(size_t dsize, size_t dk, size_t dnr, size_t dnc, float* d,
size_t sk, size_t snr, int snc, const float* s,
const size_t row_stride, const size_t col_stride)
const size_t row_stride, const size_t col_stride, const bool add_to)
{
const auto out_plane_size = dnr * dnc;
const auto sample_size = dk * out_plane_size;
for(auto i : grid_stride_range(0, dsize))
const auto out_sample_size = dk * out_plane_size;
for (auto i : grid_stride_range(0, dsize))
{
const auto n = i / sample_size;
const auto idx = i % out_plane_size;
const auto out_k = (i / out_plane_size) % dk;
const auto out_r = idx / dnc;
const auto out_c = idx % dnc;
const auto n = i / out_sample_size;
const auto out_idx = i % out_sample_size;
const auto out_k = out_idx / out_plane_size;
const auto out_rc = out_idx % out_plane_size;
const auto out_r = out_rc / dnc;
const auto out_c = out_rc % dnc;

const auto in_k = out_k % sk;
const auto in_r = out_r * row_stride + (out_k / sk) / row_stride;
const auto in_r = out_r * row_stride + (out_k / sk) / col_stride;
const auto in_c = out_c * col_stride + (out_k / sk) % col_stride;

const auto in_idx = ((n * sk + in_k) * snr + in_r) * snc + in_c;
d[i] = s[in_idx];
if (add_to) d[i] += s[in_idx];
else d[i] = s[in_idx];
}
}
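
The kernel walks the output with dlib's grid_stride_range helper. For readers unfamiliar with the idiom, a sketch of the equivalent raw CUDA loop (a hypothetical standalone kernel, not part of this commit):

// The raw CUDA equivalent of `for (auto i : grid_stride_range(0, n))`: each
// thread starts at its global index and strides by the total thread count,
// so correctness does not depend on the launch configuration.
__global__ void grid_stride_copy(size_t n, float* dst, const float* src)
{
    for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
         i += (size_t)gridDim.x * blockDim.x)
    {
        dst[i] = src[i];  // one write per output element, as in _cuda_reorg
    }
}
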

__global__ void _cuda_reorg_gradient(size_t ssize, size_t dk, size_t dnr, size_t dnc, float* d,
size_t sk, size_t snr, int snc, const float* s,
const size_t row_stride, const size_t col_stride)
size_t sk, size_t snr, int snc, const float* s, const size_t row_stride,
const size_t col_stride, const bool add_to
)
{
const auto in_plane_size = snr * snc;
const auto sample_size = sk * in_plane_size;
for(auto i : grid_stride_range(0, ssize))
{
const auto n = i / sample_size;
const auto idx = i % in_plane_size;
const auto in_k = (i / in_plane_size) % sk;
const auto in_r = idx / snc;
const auto in_c = idx % snc;
const auto n = i / (sk * snr * snc);
const auto sample_idx = i % (sk * snr * snc);
const auto in_k = (sample_idx / (snr * snc)) % sk;
const auto in_r = (sample_idx / snc) % snr;
const auto in_c = sample_idx % snc;

const auto out_k = in_k % dk;
const auto out_r = in_r * row_stride + (in_k / dk) / row_stride;
const auto out_r = in_r * row_stride + (in_k / dk) / col_stride;
const auto out_c = in_c * col_stride + (in_k / dk) % col_stride;

const auto out_idx = ((n * dk + out_k) * dnr + out_r) * dnc + out_c;
d[out_idx] += s[i];

if (add_to) d[out_idx] += s[i];
else d[out_idx] = s[i];
}
}
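
Note the asymmetry between the two kernels: _cuda_reorg parallelizes over destination elements and gathers, while _cuda_reorg_gradient parallelizes over gradient_input elements and scatters. Because the index map is one-to-one, each d[out_idx] is written by exactly one thread, so the scatter needs no atomic operations even when add_to accumulates.
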

void reorg (
void reorg(
bool add_to,
tensor& dest,
const int row_stride,
const int col_stride,
const tensor& src
)
{
DLIB_CASSERT(is_same_object(dest, src)==false);
DLIB_CASSERT(src.nr() % row_stride == 0);
DLIB_CASSERT(src.nc() % col_stride == 0);
DLIB_CASSERT(dest.num_samples() == src.num_samples());
DLIB_CASSERT(dest.k() == src.k() * row_stride * col_stride);
DLIB_CASSERT(dest.nr() == src.nr() / row_stride);
DLIB_CASSERT(dest.nc() == src.nc() / col_stride);
DLIB_CASSERT(!is_same_object(dest, src), "Destination and source must be distinct objects.");
DLIB_CASSERT(src.nr() % row_stride == 0, "The number of rows in src must be divisible by row_stride.");
DLIB_CASSERT(src.nc() % col_stride == 0, "The number of columns in src must be divisible by col_stride.");
DLIB_CASSERT(dest.num_samples() == src.num_samples(), "The number of samples must match.");
DLIB_CASSERT(dest.k() == src.k() * row_stride * col_stride, "The number of channels must match.");
DLIB_CASSERT(dest.nr() == src.nr() / row_stride, "The number of rows must match.");
DLIB_CASSERT(dest.nc() == src.nc() / col_stride, "The number of columns must match.");

launch_kernel(_cuda_reorg, dest.size(), dest.k(), dest.nr(), dest.nc(), dest.device(),
src.k(), src.nr(), src.nc(), src.device(), row_stride, col_stride);
src.k(), src.nr(), src.nc(), src.device(), row_stride, col_stride, add_to);
}

void reorg_gradient (
void reorg_gradient(
bool add_to,
tensor& grad,
const int row_stride,
const int col_stride,
const tensor& gradient_input
)
{
DLIB_CASSERT(is_same_object(grad, gradient_input)==false);
DLIB_CASSERT(grad.nr() % row_stride == 0);
DLIB_CASSERT(grad.nc() % col_stride == 0);
DLIB_CASSERT(grad.num_samples() == gradient_input.num_samples());
DLIB_CASSERT(grad.k() == gradient_input.k() / row_stride / col_stride);
DLIB_CASSERT(grad.nr() == gradient_input.nr() * row_stride);
DLIB_CASSERT(grad.nc() == gradient_input.nc() * row_stride);
DLIB_CASSERT(!is_same_object(grad, gradient_input), "Grad and gradient_input must be distinct objects.");
DLIB_CASSERT(grad.nr() % row_stride == 0, "The number of rows in grad must be divisible by row_stride.");
DLIB_CASSERT(grad.nc() % col_stride == 0, "The number of columns in grad must be divisible by col_stride.");
DLIB_CASSERT(grad.num_samples() == gradient_input.num_samples(), "The number of samples in grad and gradient_input must match.");
DLIB_CASSERT(grad.k() == gradient_input.k() / row_stride / col_stride, "The number of channels in grad must be gradient_input.k() divided by row_stride and col_stride.");
DLIB_CASSERT(grad.nr() == gradient_input.nr() * row_stride, "The number of rows in grad must be gradient_input.nr() multiplied by row_stride.");
DLIB_CASSERT(grad.nc() == gradient_input.nc() * col_stride, "The number of columns in grad must be gradient_input.nc() multiplied by col_stride.");

launch_kernel(_cuda_reorg_gradient, gradient_input.size(), grad.k(), grad.nr(), grad.nc(), grad.device(),
gradient_input.k(), gradient_input.nr(), gradient_input.nc(), gradient_input.device(),
row_stride, col_stride);
gradient_input.k(), gradient_input.nr(), gradient_input.nc(), gradient_input.device(),
row_stride, col_stride, add_to);
}

// ----------------------------------------------------------------------------------------
2 changes: 2 additions & 0 deletions dlib/cuda/cuda_dlib.h
@@ -546,13 +546,15 @@ namespace dlib
// ----------------------------------------------------------------------------------------

void reorg (
bool add_to,
tensor& dest,
const int row_stride,
const int col_stride,
const tensor& src
);

void reorg_gradient (
bool add_to,
tensor& grad,
const int row_stride,
const int col_stride,
10 changes: 6 additions & 4 deletions dlib/cuda/tensor_tools.cpp
@@ -1219,30 +1219,32 @@ namespace dlib { namespace tt
// ------------------------------------------------------------------------------------

void reorg (
bool add_to,
tensor& dest,
const int row_stride,
const int col_stride,
const tensor& src
)
{
#ifdef DLIB_USE_CUDA
cuda::reorg(dest, row_stride, col_stride, src);
cuda::reorg(add_to, dest, row_stride, col_stride, src);
#else
cpu::reorg(dest, row_stride, col_stride, src);
cpu::reorg(add_to, dest, row_stride, col_stride, src);
#endif
}

void reorg_gradient (
bool add_to,
tensor& grad,
const int row_stride,
const int col_stride,
const tensor& gradient_input
)
{
#ifdef DLIB_USE_CUDA
cuda::reorg_gradient(grad, row_stride, col_stride, gradient_input);
cuda::reorg_gradient(add_to, grad, row_stride, col_stride, gradient_input);
#else
cpu::reorg_gradient(grad, row_stride, col_stride, gradient_input);
cpu::reorg_gradient(add_to, grad, row_stride, col_stride, gradient_input);
#endif
}

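A usage sketch of the new signatures (illustrative; dlib::resizable_tensor, the fill-with-scalar assignment, and the <dlib/dnn.h> umbrella header are standard dlib, while the shapes follow the DLIB_CASSERTs in this diff):

// reorg with distinct strides (2 rows, 3 cols) and the new add_to flag.
// Shapes: dest.k() = src.k()*2*3, dest.nr() = src.nr()/2, dest.nc() = src.nc()/3.
#include <dlib/dnn.h>

int main()
{
    dlib::resizable_tensor src(1, 2, 4, 6);            // n=1, k=2, nr=4, nc=6
    dlib::resizable_tensor dest(1, 2 * 2 * 3, 4 / 2, 6 / 3);
    src = 1;                                           // fill with ones
    dest = 0;

    dlib::tt::reorg(false, dest, 2, 3, src);           // overwrite dest
    dlib::tt::reorg_gradient(true, src, 2, 3, dest);   // accumulate into src
    return 0;
}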