Skip to content

Commit

Permalink
SAM2Ultra and ObjectDetector nodes support image batch
Browse files Browse the repository at this point in the history
  • Loading branch information
chflame163 committed Oct 29, 2024
1 parent c49a40f commit 1afff89
Show file tree
Hide file tree
Showing 5 changed files with 208 additions and 185 deletions.
1 change: 1 addition & 0 deletions README.MD
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ Please try downgrading the ```protobuf``` dependency package to 3.20.3, or set e

<font size="4">**If the dependency package error after updating, please double clicking ```repair_dependency.bat``` (for Official ComfyUI Protable) or ```repair_dependency_aki.bat``` (for ComfyUI-aki-v1.x) in the plugin folder to reinstall the dependency packages. </font><br />

* [SAM2Ultra](#SAM2Ultra) and ObjectDetector nodes support image batch.
* [SAM2Ultra](#SAM2Ultra) and [SAM2VideoUltra](#SAM2VideoUltra) nodes add support for SAM2.1 model, including [kijai](https://github.com/kijai)'s FP16 model. Download model files from [BaiduNetdisk](https://pan.baidu.com/s/1xaQYBA6ktxvAxm310HXweQ?pwd=auki) or [huggingface.co/Kijai/sam2-safetensors](https://huggingface.co/Kijai/sam2-safetensors/tree/main) and copy to ```ComfyUI/models/sam2``` folder.
* Commit [JoyCaption2Split](#JoyCaption2Split) and [LoadJoyCaption2Model](#LoadJoyCaption2Model) nodes, Sharing the model across multiple JoyCaption2 nodes improves efficiency.
* [SegmentAnythingUltra](#SegmentAnythingUltra) and [SegmentAnythingUltraV2](#SegmentAnythingUltraV2) add the ```cache_model``` option, Easy to flexibly manage VRAM usage.
Expand Down
1 change: 1 addition & 0 deletions README_CN.MD
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ If this call came from a _pb2.py file, your generated code is out of date and mu
## 更新说明
<font size="4">**如果本插件更新后出现依赖包错误,请双击运行插件目录下的```install_requirements.bat```(官方便携包),或 ```install_requirements_aki.bat```(秋叶整合包) 重新安装依赖包。

* [SAM2Ultra](#SAM2Ultra) 及 ObjectDetector 节点支持图像批次。
* [SAM2Ultra](#SAM2Ultra)[SAM2VideoUltra](#SAM2VideoUltra) 节点增加支持SAM2.1模型,包括[kijai](https://github.com/kijai)量化版fp16模型。请从请从[百度网盘](https://pan.baidu.com/s/1xaQYBA6ktxvAxm310HXweQ?pwd=auki) 或者 [huggingface.co/Kijai/sam2-safetensors](https://huggingface.co/Kijai/sam2-safetensors/tree/main)下载模型文件并复制到```ComfyUI/models/sam2```文件夹。
* 添加 [JoyCaption2Split](#JoyCaption2Split)[LoadJoyCaption2Model](#LoadJoyCaption2Model) 节点,在多个JoyCaption2节点时共用模型提高效率。
* [SegmentAnythingUltra](#SegmentAnythingUltra)[SegmentAnythingUltraV2](#SegmentAnythingUltraV2) 增加 ```cache_model``` 参数,便于灵活管理显存。
Expand Down
195 changes: 108 additions & 87 deletions py/object_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ def INPUT_TYPES(cls):

def object_detector_fl2(self, image, prompt, florence2_model, sort_method, bbox_select, select_index):

ret_bboxes = []
bboxes = []
ret_previews = []
max_new_tokens = 512
Expand All @@ -115,27 +116,30 @@ def object_detector_fl2(self, image, prompt, florence2_model, sort_method, bbox_
model = florence2_model['model']
processor = florence2_model['processor']

img = tensor2pil(image[0]).convert("RGB")
task = 'caption to phrase grounding'
from .florence2_ultra import process_image
results, _ = process_image(model, processor, img, task,
max_new_tokens, num_beams, do_sample,
fill_mask, prompt)

if isinstance(results, dict):
results["width"] = img.width
results["height"] = img.height

bboxes = self.fbboxes_to_list(results)
bboxes = sort_bboxes(bboxes, sort_method)
bboxes = select_bboxes(bboxes, bbox_select, select_index)
preview = draw_bounding_boxes(img, bboxes, color="random", line_width=-1)
ret_previews.append(pil2tensor(preview))
if len(bboxes) == 0:
log(f"{self.NODE_NAME} no object found", message_type='warning')
else:
log(f"{self.NODE_NAME} found {len(bboxes)} object(s)", message_type='info')
return (standardize_bbox(bboxes), torch.cat(ret_previews, dim=0))
for img in image:
img = tensor2pil(img.unsqueeze(0)).convert("RGB")
task = 'caption to phrase grounding'
from .florence2_ultra import process_image
results, _ = process_image(model, processor, img, task,
max_new_tokens, num_beams, do_sample,
fill_mask, prompt)

if isinstance(results, dict):
results["width"] = img.width
results["height"] = img.height

bboxes = self.fbboxes_to_list(results)
bboxes = sort_bboxes(bboxes, sort_method)
bboxes = select_bboxes(bboxes, bbox_select, select_index)
preview = draw_bounding_boxes(img, bboxes, color="random", line_width=-1)
ret_previews.append(pil2tensor(preview))
ret_bboxes.append(standardize_bbox(bboxes))
if len(bboxes) == 0:
log(f"{self.NODE_NAME} no object found", message_type='warning')
else:
log(f"{self.NODE_NAME} found {len(bboxes)} object(s)", message_type='info')

return (ret_bboxes, torch.cat(ret_previews, dim=0))

def fbboxes_to_list(self, F_BBOXES) -> list:
if isinstance(F_BBOXES, str):
Expand Down Expand Up @@ -210,31 +214,34 @@ def INPUT_TYPES(cls):

def object_detector_mask(self, object_mask, sort_method, bbox_select, select_index):

ret_bboxes = []
ret_previews = []
bboxes = []
if object_mask.dim() == 2:
object_mask = torch.unsqueeze(object_mask, 0)

cv_mask = tensor2cv2(object_mask[0])
cv_mask = cv2.cvtColor(cv_mask, cv2.COLOR_BGR2GRAY)
_, binary = cv2.threshold(cv_mask, 127, 255, cv2.THRESH_BINARY)
# invert mask
# binary = cv2.bitwise_not(binary)
contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
for contour in contours:
x, y, w, h = cv2.boundingRect(contour)
bboxes.append([x, y, x + w, y + h])
bboxes = sort_bboxes(bboxes, sort_method)
bboxes = select_bboxes(bboxes, bbox_select, select_index)
ret_previews = []
preview = draw_bounding_boxes(tensor2pil(object_mask[0]).convert("RGB"), bboxes, color="random", line_width=-1)
ret_previews.append(pil2tensor(preview))

if len(bboxes) == 0:
log(f"{self.NODE_NAME} no object found", message_type='warning')
else:
log(f"{self.NODE_NAME} found {len(bboxes)} object(s)", message_type='info')

return (standardize_bbox(bboxes), torch.cat(ret_previews, dim=0))
for msk in object_mask:
cv_mask = tensor2cv2(msk)
cv_mask = cv2.cvtColor(cv_mask, cv2.COLOR_BGR2GRAY)
_, binary = cv2.threshold(cv_mask, 127, 255, cv2.THRESH_BINARY)
# invert mask
# binary = cv2.bitwise_not(binary)
contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
for contour in contours:
x, y, w, h = cv2.boundingRect(contour)
bboxes.append([x, y, x + w, y + h])
bboxes = sort_bboxes(bboxes, sort_method)
bboxes = select_bboxes(bboxes, bbox_select, select_index)
preview = draw_bounding_boxes(tensor2pil(msk).convert("RGB"), bboxes, color="random", line_width=-1)
ret_previews.append(pil2tensor(preview))

if len(bboxes) == 0:
log(f"{self.NODE_NAME} no object found", message_type='warning')
else:
log(f"{self.NODE_NAME} found {len(bboxes)} object(s)", message_type='info')
ret_bboxes.append(standardize_bbox(bboxes))

return (ret_bboxes, torch.cat(ret_previews, dim=0))


class LS_OBJECT_DETECTOR_YOLO8:
Expand Down Expand Up @@ -271,31 +278,34 @@ def object_detector_yolo8(self, image, yolo_model, sort_method, bbox_select, sel
model_path = os.path.join(folder_paths.models_dir, 'yolo')
yolo_model = YOLO(os.path.join(model_path, yolo_model))

ret_bboxes = []
bboxes = []
ret_previews = []

img = torch.unsqueeze(image[0], 0)
_image = tensor2pil(img)
results = yolo_model(_image, retina_masks=True)
for result in results:
yolo_plot_image = cv2.cvtColor(result.plot(), cv2.COLOR_BGR2RGB)

# no mask, if have box, draw box
if result.boxes is not None and len(result.boxes.xyxy) > 0:
for box in result.boxes:
x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
bboxes.append([x1, y1, x2, y2])
bboxes = sort_bboxes(bboxes, sort_method)
bboxes = select_bboxes(bboxes, bbox_select, select_index)
preview = draw_bounding_boxes(_image.convert("RGB"), bboxes, color="random", line_width=-1)
ret_previews.append(pil2tensor(preview))

if len(bboxes) == 0:
log(f"{self.NODE_NAME} no object found", message_type='warning')
else:
log(f"{self.NODE_NAME} found {len(bboxes)} object(s)", message_type='info')

return (standardize_bbox(bboxes), torch.cat(ret_previews, dim=0),)
for img in image:
img = torch.unsqueeze(img.unsqueeze(0), 0)
_image = tensor2pil(img)
results = yolo_model(_image, retina_masks=True)
for result in results:
yolo_plot_image = cv2.cvtColor(result.plot(), cv2.COLOR_BGR2RGB)

# no mask, if have box, draw box
if result.boxes is not None and len(result.boxes.xyxy) > 0:
for box in result.boxes:
x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
bboxes.append([x1, y1, x2, y2])
bboxes = sort_bboxes(bboxes, sort_method)
bboxes = select_bboxes(bboxes, bbox_select, select_index)
preview = draw_bounding_boxes(_image.convert("RGB"), bboxes, color="random", line_width=-1)
ret_previews.append(pil2tensor(preview))

if len(bboxes) == 0:
log(f"{self.NODE_NAME} no object found", message_type='warning')
else:
log(f"{self.NODE_NAME} found {len(bboxes)} object(s)", message_type='info')
ret_bboxes.append(standardize_bbox(bboxes))

return (ret_bboxes, torch.cat(ret_previews, dim=0),)

class LS_OBJECT_DETECTOR_YOLOWORLD:

Expand Down Expand Up @@ -332,33 +342,44 @@ def INPUT_TYPES(cls):
def object_detector_yoloworld(self, image, yolo_world_model,
confidence_threshold, nms_iou_threshold, prompt,
sort_method, bbox_select, select_index):
ret_previews = []
ret_bboxes = []

import supervision as sv

model=self.load_yolo_world_model(yolo_world_model, prompt)
infer_outputs = []
img = (255 * image[0].cpu().numpy()).astype(np.uint8)
results = model.infer(
img, confidence=confidence_threshold)
detections = sv.Detections.from_inference(results)
detections = detections.with_nms(
class_agnostic=False,
threshold=nms_iou_threshold
)
infer_outputs.append(detections)
bboxes = infer_outputs[0].xyxy.tolist()
bboxes = [[int(value) for value in sublist] for sublist in bboxes]
bboxes = sort_bboxes(bboxes, sort_method)
bboxes = select_bboxes(bboxes, bbox_select, select_index)
ret_previews = []
preview = draw_bounding_boxes(tensor2pil(image[0]).convert('RGB'), bboxes, color="random", line_width=-1)
ret_previews.append(pil2tensor(preview))

if len(bboxes) == 0:
log(f"{self.NODE_NAME} no object found", message_type='warning')
else:
log(f"{self.NODE_NAME} found {len(bboxes)} object(s)", message_type='info')

return (standardize_bbox(bboxes), torch.cat(ret_previews, dim=0))
for i in image:
infer_outputs = []
# img = (255 * img.unsqueeze(0).cpu().numpy()).astype(np.uint8)
img = tensor2np(i)
results = model.infer(
img, confidence=confidence_threshold)
detections = sv.Detections.from_inference(results)
detections = detections.with_nms(
class_agnostic=False,
threshold=nms_iou_threshold
)
infer_outputs.append(detections)

if len(infer_outputs[0].xyxy) > 0:
bboxes = infer_outputs[0].xyxy.tolist()
bboxes = [[int(value) for value in sublist] for sublist in bboxes]
bboxes = sort_bboxes(bboxes, sort_method)
bboxes = select_bboxes(bboxes, bbox_select, select_index)
else:
bboxes = [[0, 0, i.shape[1], i.shape[0]]]

preview = draw_bounding_boxes(tensor2pil(i.unsqueeze(0)).convert('RGB'), bboxes, color="random", line_width=-1)
ret_previews.append(pil2tensor(preview))

if len(bboxes) == 0:
log(f"{self.NODE_NAME} no object found", message_type='warning')
else:
log(f"{self.NODE_NAME} found {len(bboxes)} object(s)", message_type='info')
ret_bboxes.append(standardize_bbox(bboxes))

return (ret_bboxes, torch.cat(ret_previews, dim=0))

def process_categories(self, categories: str) -> List[str]:
return [category.strip().lower() for category in categories.split(',')]
Expand Down
Loading

0 comments on commit 1afff89

Please sign in to comment.