Commit
added yolo into the node
KalanaRatnayake committed Sep 12, 2024
1 parent 58b707b commit b1b1430
Showing 2 changed files with 140 additions and 20 deletions.
157 changes: 138 additions & 19 deletions boxmot_ros/node.py
@@ -4,33 +4,43 @@

from pathlib import Path

from message_filters import Subscriber
from message_filters import ApproximateTimeSynchronizer

from rclpy.node import Node
from rclpy.qos import QoSProfile, QoSReliabilityPolicy, QoSHistoryPolicy

from sensor_msgs.msg import Image
from sensor_msgs.msg import Image, PointCloud2
from detection_msgs.msg import Detections

from cv_bridge import CvBridge

from ultralytics import YOLO
from boxmot import BoTSORT, DeepOCSORT, OCSORT, HybridSORT, BYTETracker, StrongSORT

class BoxmotROS(Node):

def __init__(self):
super().__init__('boxmot_ros')

self.declare_parameter("yolo_model", "yolov8n.pt")
self.declare_parameter("tracking_model", "deepocsort")
self.declare_parameter("reid_model", "osnet_x0_25_msmt17.pt")
self.declare_parameter("input_topic", "/input")
self.declare_parameter("input_rgb_topic", "/camera/color/image_raw")
self.declare_parameter("input_depth_topic", "/camera/depth/points")
self.declare_parameter("subscribe_depth", False)
self.declare_parameter("publish_annotated_image", False)
self.declare_parameter("annotated_topic", "/boxmot_ros/annotated_image")
self.declare_parameter("detailed_topic", "/boxmot_ros/tracking_result")
self.declare_parameter("threshold", 0.25)
self.declare_parameter("device", "cpu")

self.yolo_model = self.get_parameter("yolo_model").get_parameter_value().string_value
self.tracking_model = self.get_parameter("tracking_model").get_parameter_value().string_value
self.reid_model = self.get_parameter("reid_model").get_parameter_value().string_value
self.input_topic = self.get_parameter("input_topic").get_parameter_value().string_value
self.input_rgb_topic = self.get_parameter("input_rgb_topic").get_parameter_value().string_value
self.input_depth_topic = self.get_parameter("input_depth_topic").get_parameter_value().string_value
self.subscribe_depth = self.get_parameter("subscribe_depth").get_parameter_value().bool_value
self.publish_annotated_image = self.get_parameter("publish_annotated_image").get_parameter_value().bool_value
self.annotated_topic = self.get_parameter("annotated_topic").get_parameter_value().string_value
self.detailed_topic = self.get_parameter("detailed_topic").get_parameter_value().string_value
@@ -39,6 +49,9 @@ def __init__(self):

self.bridge = CvBridge()

self.model = YOLO(self.yolo_model)  # load the detector named by the yolo_model parameter
self.model.fuse()                   # fuse Conv+BatchNorm layers for faster inference

if self.tracking_model == "deepocsort":
self.tracker = DeepOCSORT( model_weights=Path(self.reid_model), # which ReID model to use
device=self.device,
@@ -73,7 +86,15 @@ def __init__(self):
history=QoSHistoryPolicy.KEEP_LAST,
depth=1 )

self.subscription = self.create_subscription(Detections, self.input_topic, self.image_callback, qos_profile=self.subscriber_qos_profile)
if self.subscribe_depth:
self.rgb_message_filter = Subscriber(self, Image, self.input_rgb_topic, qos_profile=self.subscriber_qos_profile)
self.depth_message_filter = Subscriber(self, PointCloud2, self.input_depth_topic, qos_profile=self.subscriber_qos_profile)

# ApproximateTimeSynchronizer(filters, queue_size, slop): queue_size=10, slop=1 s,
# so an RGB image and a point cloud stamped within one second of each other are
# paired and delivered to sync_callback together
self.synchronizer = ApproximateTimeSynchronizer([self.rgb_message_filter, self.depth_message_filter], 10, 1)
self.synchronizer.registerCallback(self.sync_callback)

else:
self.subscription = self.create_subscription(Image, self.input_rgb_topic, self.image_callback, qos_profile=self.subscriber_qos_profile)

self.publisher_results = self.create_publisher(Detections, self.detailed_topic, 10)

@@ -85,32 +106,132 @@ def __init__(self):
self.time = 0

self.tracking_msg = Detections()
self.class_list_set = False

def image_callback(self, received_msg):

def sync_callback(self, rgb_msg, depth_msg):
start = time.time_ns()

self.input_image = self.bridge.imgmsg_to_cv2(received_msg.source_rgb, desired_encoding="bgr8")
self.input_image = self.bridge.imgmsg_to_cv2(rgb_msg, desired_encoding="bgr8")

self.result = self.model.predict(source = self.input_image,
conf=self.threshold,
device=self.device,
verbose=False)

if (not self.class_list_set) and (self.result is not None):
for i in range(len(self.result[0].names)):
self.tracking_msg.full_class_list.append(self.result[0].names.get(i))
self.class_list_set = True

if received_msg.detections:
if self.result is not None:
detection_list = []

for i in range(len(received_msg.class_id)):
for bbox, cls, conf in zip(self.result[0].boxes.xywh, self.result[0].boxes.cls, self.result[0].boxes.conf):
detection = []

clid = received_msg.class_id[i]
conf = received_msg.confidence[i]
cx = int(bbox[0])
cy = int(bbox[1])
sw = int(bbox[2])
sh = int(bbox[3])

x1 = cx - (sw/2)
y1 = cy - (sh/2)
x2 = cx + (sw/2)
y2 = cy + (sh/2)

detection = [x1, y1, x2, y2, float(conf), int(cls)]

detection_list.append(detection)

detection_numpy = np.array(detection_list)
else:
detection_numpy = np.empty((0, 6))  # six columns to match the (x, y, x, y, conf, cls) rows above

# input is of shape (x, y, x, y, conf, cls)
# output is of shape (x, y, x, y, id, conf, cls, ind)
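# e.g. (illustrative values) a detection row [320.0, 240.0, 480.0, 400.0, 0.87, 0]
# could come back from the tracker as [320, 240, 480, 400, 7, 0.87, 0, 0],
# where 7 is the assigned track id and the trailing 0 is the index of the matched detection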

self.result_tracks = self.tracker.update(detection_numpy, self.input_image)

if self.result_tracks is not None:

self.tracking_msg.header = rgb_msg.header
self.tracking_msg.source_rgb = rgb_msg
self.tracking_msg.source_depth = depth_msg

for track in self.result_tracks:

x1 = track[0].astype('int')
y1 = track[1].astype('int')
x2 = track[2].astype('int')
y2 = track[3].astype('int')

tracking_id = track[4].astype('int')
confidence = track[5].astype('float')
class_id = track[6].astype('int')

cx = (x2 + x1)/2
cy = (y2 + y1)/2
sw = x2 - x1
sh = y2 - y1

self.tracking_msg.bbx_center_x.append(int(cx))
self.tracking_msg.bbx_center_y.append(int(cy))
self.tracking_msg.bbx_size_w.append(int(sw))
self.tracking_msg.bbx_size_h.append(int(sh))
self.tracking_msg.class_id.append(class_id)
self.tracking_msg.tracking_id.append(tracking_id)
self.tracking_msg.confidence.append(confidence)

self.publisher_results.publish(self.tracking_msg)

if self.publish_annotated_image:
self.output_image = self.tracker.plot_results(self.input_image, show_trajectories=True)
result_msg = self.bridge.cv2_to_imgmsg(self.output_image, encoding="bgr8")

self.publisher_image.publish(result_msg)

self.counter += 1
self.time += time.time_ns() - start

if (self.counter == 100):
self.get_logger().info('Average callback execution time over 100 loops: %d ms' % ((self.time/100)/1000000))
self.time = 0
self.counter = 0


def image_callback(self, rgb_image):
start = time.time_ns()

self.input_image = self.bridge.imgmsg_to_cv2(rgb_image, desired_encoding="bgr8")

self.result = self.model.predict(source = self.input_image,
conf=self.threshold,
device=self.device,
verbose=False)

if (not self.class_list_set) and (self.result is not None):
for i in range(len(self.result[0].names)):
self.tracking_msg.full_class_list.append(self.result[0].names.get(i))
self.class_list_set = True

if self.result is not None:
detection_list = []

for bbox, cls, conf in zip(self.result[0].boxes.xywh, self.result[0].boxes.cls, self.result[0].boxes.conf):
detection = []

cx = received_msg.bbx_center_x[i]
cy = received_msg.bbx_center_y[i]
sw = received_msg.bbx_size_w[i]
sh = received_msg.bbx_size_h[i]
cx = int(bbox[0])
cy = int(bbox[1])
sw = int(bbox[2])
sh = int(bbox[3])

x1 = cx - (sw/2)
y1 = cy - (sh/2)
x2 = cx + (sw/2)
y2 = cy + (sh/2)

detection = [x1, y1, x2, y2, conf, clid]
detection = [x1, y1, x2, y2, float(conf), int(cls)]

detection_list.append(detection)

@@ -125,10 +246,8 @@ def image_callback(self, received_msg):

if self.result_tracks is not None:

self.tracking_msg.header = received_msg.header
self.tracking_msg.source_rgb = received_msg.source_rgb
self.tracking_msg.source_depth = received_msg.source_depth
self.tracking_msg.full_class_list = received_msg.full_class_list
self.tracking_msg.header = rgb_image.header
self.tracking_msg.source_rgb = rgb_image

for track in self.result_tracks:

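For reference, a minimal standalone sketch of the detection-to-tracking hand-off this commit wires into the node: YOLO boxes in centre/size (xywh) form are converted into the corner-form (x1, y1, x2, y2, conf, cls) array that boxmot's tracker.update expects. The weight file names match the node's parameter defaults; the track_frame helper and the fp16=False flag are illustrative assumptions, not part of the commit.

import numpy as np
from pathlib import Path

from ultralytics import YOLO
from boxmot import DeepOCSORT

model = YOLO("yolov8n.pt")                                         # same default as the yolo_model parameter
tracker = DeepOCSORT(model_weights=Path("osnet_x0_25_msmt17.pt"),  # default reid_model
                     device="cpu",
                     fp16=False)                                   # assumed flag, as in boxmot's examples

def track_frame(frame_bgr):
    # run detection; result[0].boxes holds xywh boxes, class ids and confidences
    result = model.predict(source=frame_bgr, conf=0.25, device="cpu", verbose=False)

    detections = []
    for bbox, cls, conf in zip(result[0].boxes.xywh, result[0].boxes.cls, result[0].boxes.conf):
        cx, cy, sw, sh = (float(v) for v in bbox)
        # centre/size -> corner form: (x1, y1, x2, y2, conf, cls)
        detections.append([cx - sw / 2, cy - sh / 2, cx + sw / 2, cy + sh / 2, float(conf), int(cls)])

    dets = np.array(detections) if detections else np.empty((0, 6))

    # each returned row is (x1, y1, x2, y2, id, conf, cls, ind)
    return tracker.update(dets, frame_bgr)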
3 changes: 2 additions & 1 deletion requirements.txt
@@ -1 +1,2 @@
boxmot
boxmot
ultralytics
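
A quick, hypothetical sanity check that the updated requirements resolve both libraries node.py now imports:

# run after: pip install -r requirements.txt
from ultralytics import YOLO     # new dependency added by this commit
from boxmot import DeepOCSORT    # existing dependency
print(YOLO, DeepOCSORT)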
