SAM 2 Update 12/11/2024 -- full model compilation for a major VOS speedup and a new SAM2VideoPredictor to better handle multi-object tracking (#486)

This PR provides new features and updates for SAM 2:

- We now support `torch.compile` of the entire SAM 2 model on videos, which can be turned on by setting `vos_optimized=True` in `build_sam2_video_predictor` (it uses the new `SAM2VideoPredictorVOS` predictor class in `sam2/sam2_video_predictor.py`).
  * Compared to the previous setting (which only compiled the image encoder backbone), full model compilation gives a major speedup in inference FPS.
  * In the VOS prediction script `tools/vos_inference.py`, you can enable this option via the `--use_vos_optimized_video_predictor` flag.
  * Note that turning on this flag might introduce a small variance in the predictions due to numerical differences caused by `torch.compile` of the full model.
  * **PyTorch 2.5.1 is the minimum version for full support of this feature.** (Earlier PyTorch versions might run into compilation errors in some cases.) We have therefore updated the minimum PyTorch version to 2.5.1 in the installation scripts.
- We also update the implementation of the `SAM2VideoPredictor` class for SAM 2 video prediction in `sam2/sam2_video_predictor.py`, which now allows for independent per-object inference. Specifically, in the new `SAM2VideoPredictor`:
  * **We now handle the inference of each object independently** (as if we were opening a separate session for each object) while sharing their backbone features.
  * This change relaxes the prompting assumption for multi-object tracking. Previously (due to the batching behavior in inference), if a video frame received clicks for only a subset of objects, the remaining (non-prompted) objects were assumed to be non-existent in that frame (i.e., in such frames, the user was telling SAM 2 that the rest of the objects don't appear). Now, if a frame receives clicks for only a subset of objects, we make no assumptions about the remaining (non-prompted) objects (each object is handled independently and is not affected by how other objects are prompted). As a result, **we allow adding new objects after tracking starts**, which was previously a restriction on usage (see the sketch below).
  * We believe the new version is a more natural inference behavior and have therefore switched to it as the default. The previous implementation of `SAM2VideoPredictor` is backed up in `sam2/sam2_video_predictor_legacy.py`. All VOS inference results using `tools/vos_inference.py` should remain the same after this change to the `SAM2VideoPredictor` class.
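As a rough illustration of both changes, here is a minimal sketch of building the VOS-optimized predictor and then prompting a second object after tracking has already started. It reuses the checkpoint/config paths and the `add_new_points_or_box` / `propagate_in_video` calls from the benchmark script added in this commit; the video directory, click coordinates, frame indices, and object IDs are placeholders for illustration only.

```python
import numpy as np
import torch

from sam2.build_sam import build_sam2_video_predictor

# Checkpoint/config paths as used in the benchmark script in this commit.
checkpoint = "checkpoints/sam2.1_hiera_base_plus.pt"
model_cfg = "configs/sam2.1/sam2.1_hiera_b+.yaml"

# vos_optimized=True turns on torch.compile of the full model (needs PyTorch >= 2.5.1).
predictor = build_sam2_video_predictor(
    model_cfg, checkpoint, device=torch.device("cuda"), vos_optimized=True
)

with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
    # Placeholder video directory with JPEG frames.
    state = predictor.init_state(video_path="notebooks/videos/bedroom")

    # Prompt the first object on frame 0 (coordinates are illustrative).
    predictor.add_new_points_or_box(
        inference_state=state,
        frame_idx=0,
        obj_id=1,
        points=np.array([[210, 350]], dtype=np.float32),
        labels=np.array([1], dtype=np.int32),
    )

    # Track the first object through the video.
    for frame_idx, obj_ids, mask_logits in predictor.propagate_in_video(state):
        pass

    # With the new SAM2VideoPredictor, a second object can be prompted on a later
    # frame even though tracking has already started; since each object is handled
    # independently, this does not affect the object prompted earlier.
    predictor.add_new_points_or_box(
        inference_state=state,
        frame_idx=50,
        obj_id=2,
        points=np.array([[100, 200]], dtype=np.float32),
        labels=np.array([1], dtype=np.int32),
    )

    # Propagate again to obtain masks for both objects.
    for frame_idx, obj_ids, mask_logits in predictor.propagate_in_video(state):
        pass
```

With the previous `SAM2VideoPredictor` (now backed up in `sam2/sam2_video_predictor_legacy.py`), the second `add_new_points_or_box` call after propagation would not have been allowed.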
1 parent c2ec8e1 · commit 393ae33

Showing 27 changed files with 1,794 additions and 443 deletions.
```diff
@@ -8,3 +8,4 @@ build/*
 _C.*
 outputs/*
 checkpoints/*.pt
+demo/backend/checkpoints/*.pt
```
@@ -0,0 +1,27 @@

## SAM 2 release notes

### 12/11/2024 -- full model compilation for a major VOS speedup and a new `SAM2VideoPredictor` to better handle multi-object tracking

- We now support `torch.compile` of the entire SAM 2 model on videos, which can be turned on by setting `vos_optimized=True` in `build_sam2_video_predictor` (it uses the new `SAM2VideoPredictorVOS` predictor class in `sam2/sam2_video_predictor.py`).
  * Compared to the previous setting (which only compiled the image encoder backbone), full model compilation gives a major speedup in inference FPS.
  * In the VOS prediction script `tools/vos_inference.py`, you can enable this option via the `--use_vos_optimized_video_predictor` flag.
  * Note that turning on this flag might introduce a small variance in the predictions due to numerical differences caused by `torch.compile` of the full model.
  * **PyTorch 2.5.1 is the minimum version for full support of this feature.** (Earlier PyTorch versions might run into compilation errors in some cases.) We have therefore updated the minimum PyTorch version to 2.5.1 in the installation scripts.
- We also update the implementation of the `SAM2VideoPredictor` class for SAM 2 video prediction in `sam2/sam2_video_predictor.py`, which now allows for independent per-object inference. Specifically, in the new `SAM2VideoPredictor`:
  * **We now handle the inference of each object independently** (as if we were opening a separate session for each object) while sharing their backbone features.
  * This change relaxes the prompting assumption for multi-object tracking. Previously (due to the batching behavior in inference), if a video frame received clicks for only a subset of objects, the remaining (non-prompted) objects were assumed to be non-existent in that frame (i.e., in such frames, the user was telling SAM 2 that the rest of the objects don't appear). Now, if a frame receives clicks for only a subset of objects, we make no assumptions about the remaining (non-prompted) objects (each object is handled independently and is not affected by how other objects are prompted). As a result, **we allow adding new objects after tracking starts**, which was previously a restriction on usage.
  * We believe the new version is a more natural inference behavior and have therefore switched to it as the default. The previous implementation of `SAM2VideoPredictor` is backed up in `sam2/sam2_video_predictor_legacy.py`. All VOS inference results using `tools/vos_inference.py` should remain the same after this change to the `SAM2VideoPredictor` class.

### 09/30/2024 -- SAM 2.1 Developer Suite (new checkpoints, training code, web demo) is released

- A new suite of improved model checkpoints (denoted as **SAM 2.1**) has been released. See [Model Description](#model-description) for details.
  * To use the new SAM 2.1 checkpoints, you need the latest model code from this repo. If you have installed an earlier version of this repo, please first uninstall the previous version via `pip uninstall SAM-2`, pull the latest code from this repo (with `git pull`), and then reinstall the repo following [Installation](#installation) below.
- The training (and fine-tuning) code has been released. See [`training/README.md`](training/README.md) on how to get started.
- The frontend + backend code for the SAM 2 web demo has been released. See [`demo/README.md`](demo/README.md) for details.

### 07/29/2024 -- SAM 2 is released

- We release Segment Anything Model 2 (SAM 2), a foundation model towards solving promptable visual segmentation in images and videos.
  * SAM 2 code: https://github.com/facebookresearch/sam2
  * SAM 2 demo: https://sam2.metademolab.com/
  * SAM 2 paper: https://arxiv.org/abs/2408.00714
```diff
@@ -1,6 +1,6 @@
 [build-system]
 requires = [
     "setuptools>=61.0",
-    "torch>=2.3.1",
+    "torch>=2.5.1",
 ]
 build-backend = "setuptools.build_meta"
```
@@ -0,0 +1,92 @@

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import time

import numpy as np
import torch
from tqdm import tqdm

from sam2.build_sam import build_sam2_video_predictor

# Only cuda supported
assert torch.cuda.is_available()
device = torch.device("cuda")

torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
if torch.cuda.get_device_properties(0).major >= 8:
    # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

# Config and checkpoint
sam2_checkpoint = "checkpoints/sam2.1_hiera_base_plus.pt"
model_cfg = "configs/sam2.1/sam2.1_hiera_b+.yaml"

# Build video predictor with vos_optimized=True setting
predictor = build_sam2_video_predictor(
    model_cfg, sam2_checkpoint, device=device, vos_optimized=True
)


# Initialize with video
video_dir = "notebooks/videos/bedroom"
# scan all the JPEG frame names in this directory
frame_names = [
    p
    for p in os.listdir(video_dir)
    if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
]
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
inference_state = predictor.init_state(video_path=video_dir)


# Number of runs, warmup etc
warm_up, runs = 5, 25
verbose = True
num_frames = len(frame_names)
total, count = 0, 0
torch.cuda.empty_cache()

# We will select an object with a click.
# See video_predictor_example.ipynb for more detailed explanation
ann_frame_idx, ann_obj_id = 0, 1
# Add a positive click at (x, y) = (210, 350)
# For labels, `1` means positive click
points = np.array([[210, 350]], dtype=np.float32)
labels = np.array([1], np.int32)

_, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
    inference_state=inference_state,
    frame_idx=ann_frame_idx,
    obj_id=ann_obj_id,
    points=points,
    labels=labels,
)

# Warmup and then average FPS over several runs
with torch.autocast("cuda", torch.bfloat16):
    with torch.inference_mode():
        for i in tqdm(range(runs), disable=not verbose, desc="Benchmarking"):
            start = time.time()
            # Start tracking
            for (
                out_frame_idx,
                out_obj_ids,
                out_mask_logits,
            ) in predictor.propagate_in_video(inference_state):
                pass

            end = time.time()
            total += end - start
            count += 1
            if i == warm_up - 1:
                # Report FPS over the warmup runs, then reset the counters
                print("Warmup FPS: ", count * num_frames / total)
                total = 0
                count = 0

print("FPS: ", count * num_frames / total)
```