From 3aec165ea8c0bd53931ce11bdf69d9b827a334e3 Mon Sep 17 00:00:00 2001 From: brightening-eyes Date: Sat, 23 Sep 2023 18:03:49 +0330 Subject: [PATCH] range fixes and checks --- src/layer/range.cpp | 40 ++++++++++++++++++++++++++-------------- 1 file changed, 26 insertions(+), 14 deletions(-) diff --git a/src/layer/range.cpp b/src/layer/range.cpp index 7ce121b17fc0..a1833c772699 100644 --- a/src/layer/range.cpp +++ b/src/layer/range.cpp @@ -26,15 +26,15 @@ Range::Range() int Range::forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const { - if (bottom_blobs.size() < 2 || bottom_blobs.size() > 3) + if (bottom_blobs.size() < 2 || bottom_blobs.size() > 3 || top_blobs.size() != 1) return -100; const Mat& start = bottom_blobs[0]; - if(start.empty()) + if (start.empty()) return -100; const Mat& limit = bottom_blobs[1]; - if(limit.empty()) + if (limit.empty()) return -100; const Mat& delta = bottom_blobs.size() == 3 ? bottom_blobs[2] : Mat(); @@ -44,29 +44,41 @@ int Range::forward(const std::vector& bottom_blobs, std::vector& top_b if (start.w * start.h * start.d * start.c != 1 || limit.w * limit.h * limit.d * limit.c != 1 || (!delta.empty() && delta.w * delta.h * delta.d * delta.c != 1)) return -100; - const int* start_ptr = start; - const int* limit_ptr = limit; - const int* delta_ptr = delta; + if (start.elemsize != limit.elemsize || (!delta.empty() && start.elemsize != delta.elemsize)) + return -100; - int start_val = *start_ptr; - int limit_val = *limit_ptr; - int delta_val = delta.empty() ? 1 : *delta_ptr; + const float* start_ptr = start; + const float* limit_ptr = limit; - if (delta_val == 0) + float start_val = start_ptr[0]; + float limit_val = limit_ptr[0]; + float delta_val = 1.0f; + if (!delta.empty()) + { + const float* delta_ptr = delta; + delta_val = delta_ptr[0]; + } + + if (delta_val == 0.0f || (limit_val - start_val) * delta_val <= 0.0f) return -100; - int number_of_elements = (int) std::max((int) ceilf((limit_val - start_val) / delta_val), 0); + if (limit_val < start_val && delta_val > 0.0f) + delta_val = -delta_val; + + int number_of_elements = static_cast(ceil((limit_val - start_val) / delta_val)); + if (number_of_elements < 0) + number_of_elements = 0; - output.create(number_of_elements, start.elemsize, opt.blob_allocator); + output.create(number_of_elements, start.elemsize, start.elempack, opt.blob_allocator); if (output.empty()) return -100; - int* outptr = output; + float* outptr = output; #pragma omp parallel for num_threads(opt.num_threads) for (int i = 0; i < number_of_elements; i++) { - ((int*)outptr)[i] = start_val + (i * delta_val); + ((float*)outptr)[i] = start_val + (i * delta_val); } return 0;