-
Notifications
You must be signed in to change notification settings - Fork 650
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(lidar_centerpoint): add IoU-based NMS
Signed-off-by: yukke42 <yusuke.muramatsu@tier4.jp>
- Loading branch information
yukke42
committed
Sep 22, 2022
1 parent
66e010c
commit 52dc529
Showing
9 changed files
with
237 additions
and
14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
80 changes: 80 additions & 0 deletions
80
...ption/lidar_centerpoint/include/lidar_centerpoint/postprocess/non_maximum_suppression.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
// Copyright 2022 TIER IV, Inc. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#ifndef LIDAR_CENTERPOINT__POSTPROCESS__NON_MAXIMUM_SUPPRESSION_HPP_ | ||
#define LIDAR_CENTERPOINT__POSTPROCESS__NON_MAXIMUM_SUPPRESSION_HPP_ | ||
|
||
#include "lidar_centerpoint/ros_utils.hpp" | ||
|
||
#include <Eigen/Eigen> | ||
|
||
#include "autoware_auto_perception_msgs/msg/detected_object.hpp" | ||
|
||
#include <string> | ||
#include <vector> | ||
|
||
namespace centerpoint | ||
{ | ||
using autoware_auto_perception_msgs::msg::DetectedObject; | ||
|
||
// TODO(yukke42): now only support IoU_BEV | ||
enum class NMS_TYPE { | ||
IoU_BEV | ||
// IoU_3D | ||
// Distance_2D | ||
// Distance_3D | ||
}; | ||
|
||
struct NMSParams | ||
{ | ||
NMS_TYPE nms_type_{}; | ||
std::vector<std::string> target_class_names_{}; | ||
double search_distance_2d_{}; | ||
double iou_threshold_{}; | ||
// double distance_threshold_{}; | ||
}; | ||
|
||
std::vector<bool> classNamesToBooleanMask(const std::vector<std::string> & class_names) | ||
{ | ||
std::vector<bool> mask; | ||
mask.resize(8); | ||
for (const auto & class_name : class_names) { | ||
const auto semantic_type = getSemanticType(class_name); | ||
mask.at(semantic_type) = true; | ||
} | ||
|
||
return mask; | ||
} | ||
|
||
class NonMaximumSuppression | ||
{ | ||
public: | ||
void setParameters(const NMSParams &); | ||
|
||
std::vector<DetectedObject> apply(const std::vector<DetectedObject> &); | ||
|
||
private: | ||
bool isTargetLabel(const std::uint8_t); | ||
|
||
bool isTargetPairObject(const DetectedObject &, const DetectedObject &); | ||
|
||
Eigen::MatrixXd generateIoUMatrix(const std::vector<DetectedObject> &); | ||
|
||
NMSParams params_{}; | ||
std::vector<bool> target_class_mask_{}; | ||
}; | ||
|
||
} // namespace centerpoint | ||
|
||
#endif // LIDAR_CENTERPOINT__POSTPROCESS__NON_MAXIMUM_SUPPRESSION_HPP_ |
105 changes: 105 additions & 0 deletions
105
perception/lidar_centerpoint/lib/postprocess/non_maximum_suppression.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
// Copyright 2022 TIER IV, Inc. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "lidar_centerpoint/postprocess/non_maximum_suppression.hpp" | ||
|
||
#include "perception_utils/geometry.hpp" | ||
#include "perception_utils/perception_utils.hpp" | ||
#include "tier4_autoware_utils/tier4_autoware_utils.hpp" | ||
|
||
namespace centerpoint | ||
{ | ||
|
||
void NonMaximumSuppression::setParameters(const NMSParams & params) | ||
{ | ||
assert(params.target_class_names_.size() == 8); | ||
assert(params.search_distance_2d_ >= 0.0); | ||
assert(params.iou_threshold_ >= 0.0 && params.iou_threshold_ <= 1.0); | ||
|
||
params_ = params; | ||
target_class_mask_ = classNamesToBooleanMask(params.target_class_names_); | ||
} | ||
|
||
bool NonMaximumSuppression::isTargetLabel(const uint8_t label) | ||
{ | ||
if (label >= target_class_mask_.size()) { | ||
return false; | ||
} | ||
return target_class_mask_.at(label); | ||
} | ||
|
||
bool NonMaximumSuppression::isTargetPairObject( | ||
const DetectedObject & object1, const DetectedObject & object2) | ||
{ | ||
const auto label1 = perception_utils::getHighestProbLabel(object1.classification); | ||
const auto label2 = perception_utils::getHighestProbLabel(object2.classification); | ||
|
||
if (isTargetLabel(label1) && isTargetLabel(label2)) { | ||
return true; | ||
} | ||
|
||
const auto search_sqr_dist_2d = params_.search_distance_2d_ * params_.search_distance_2d_; | ||
const auto sqrt_dist_2d = tier4_autoware_utils::calcSquaredDistance2d( | ||
perception_utils::getPose(object1), perception_utils::getPose(object2)); | ||
return sqrt_dist_2d <= search_sqr_dist_2d; | ||
} | ||
|
||
Eigen::MatrixXd NonMaximumSuppression::generateIoUMatrix( | ||
const std::vector<DetectedObject> & input_objects) | ||
{ | ||
// NOTE(yukke42): row = target objects to be suppressed, col = source objects to be compared | ||
Eigen::MatrixXd triangular_matrix = | ||
Eigen::MatrixXd::Zero(input_objects.size(), input_objects.size()); | ||
for (std::size_t target_i = 0; target_i < input_objects.size(); ++target_i) { | ||
for (std::size_t source_i = 0; source_i < target_i; ++source_i) { | ||
const auto & target_obj = input_objects.at(target_i); | ||
const auto & source_obj = input_objects.at(source_i); | ||
if (!isTargetPairObject(target_obj, source_obj)) { | ||
continue; | ||
} | ||
|
||
if (params_.nms_type_ == NMS_TYPE::IoU_BEV) { | ||
const double iou = perception_utils::get2dIoU(target_obj, source_obj); | ||
triangular_matrix(target_i, source_i) = iou; | ||
// NOTE(yukke42): If the target object has any objects with iou > iou_threshold, it | ||
// will be suppressed regardless of later results. | ||
if (iou > params_.iou_threshold_) { | ||
break; | ||
} | ||
} | ||
} | ||
} | ||
|
||
return triangular_matrix; | ||
} | ||
|
||
std::vector<DetectedObject> NonMaximumSuppression::apply( | ||
const std::vector<DetectedObject> & input_objects) | ||
{ | ||
Eigen::MatrixXd iou_matrix = generateIoUMatrix(input_objects); | ||
|
||
std::vector<DetectedObject> output_objects; | ||
output_objects.reserve(input_objects.size()); | ||
for (std::size_t i = 0; i < input_objects.size(); ++i) { | ||
const auto value = iou_matrix.row(i).maxCoeff(); | ||
if (params_.nms_type_ == NMS_TYPE::IoU_BEV) { | ||
if (value <= params_.iou_threshold_) { | ||
output_objects.emplace_back(input_objects.at(i)); | ||
} | ||
} | ||
} | ||
|
||
return output_objects; | ||
} | ||
} // namespace centerpoint |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters