update sam2 code
chflame163 committed Oct 22, 2024
1 parent 0e59ff0 commit cb87dd3
Showing 16 changed files with 1,199 additions and 78 deletions.
29 changes: 25 additions & 4 deletions py/sam2/modeling/backbones/hieradet.py
@@ -10,6 +10,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from iopath.common.file_io import g_pathmgr

from ....sam2.modeling.backbones.utils import (
    PatchEmbed,
@@ -46,11 +47,7 @@ def __init__(

        self.dim = dim
        self.dim_out = dim_out
        self.num_heads = num_heads
-        head_dim = dim_out // num_heads
-        self.scale = head_dim**-0.5
        self.q_pool = q_pool
        self.qkv = nn.Linear(dim, dim_out * 3)
        self.proj = nn.Linear(dim_out, dim_out)
@@ -197,6 +194,7 @@ def __init__(
            16,
            20,
        ),
        weights_path=None,
        return_interm_layers=True,  # return feats from every stage
    ):
        super().__init__()
@@ -266,6 +264,11 @@ def __init__(
            else [self.blocks[-1].dim_out]
        )

        if weights_path is not None:
            with g_pathmgr.open(weights_path, "rb") as f:
                chkpt = torch.load(f, map_location="cpu")
            logging.info("loading Hiera %s", self.load_state_dict(chkpt, strict=False))

    def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor:
        h, w = hw
        window_embed = self.pos_embed_window
@@ -293,3 +296,21 @@ def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
                outputs.append(feats)

        return outputs

    def get_layer_id(self, layer_name):
        # https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33
        num_layers = self.get_num_layers()

        if layer_name.find("rel_pos") != -1:
            return num_layers + 1
        elif layer_name.find("pos_embed") != -1:
            return 0
        elif layer_name.find("patch_embed") != -1:
            return 0
        elif layer_name.find("blocks") != -1:
            return int(layer_name.split("blocks")[1].split(".")[1]) + 1
        else:
            return num_layers + 1

    def get_num_layers(self) -> int:
        return len(self.blocks)
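
The `weights_path` argument added above loads a Hiera checkpoint through `g_pathmgr`, and `get_layer_id` maps parameter names to depth indices in the BEiT style, which is typically used for layer-wise learning-rate decay. A minimal sketch of that use, assuming the module is importable as `py.sam2.modeling.backbones.hieradet` and relying on the constructor defaults otherwise:

# Sketch only: layer-wise LR decay driven by Hiera.get_layer_id().
# The import path and constructor arguments are illustrative assumptions.
import torch
from py.sam2.modeling.backbones.hieradet import Hiera

trunk = Hiera(embed_dim=112, num_heads=2)  # passing weights_path would also load a checkpoint
num_layers = trunk.get_num_layers()
base_lr, decay = 1.0e-4, 0.9

groups = {}
for name, param in trunk.named_parameters():
    layer_id = trunk.get_layer_id(name)  # 0 = patch/pos embeddings, deeper blocks = larger id
    lr = base_lr * decay ** (num_layers + 1 - layer_id)  # deeper layers keep a higher LR
    groups.setdefault(layer_id, {"params": [], "lr": lr})["params"].append(param)

optimizer = torch.optim.AdamW(list(groups.values()))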
1 change: 1 addition & 0 deletions py/sam2/modeling/backbones/image_encoder.py
@@ -71,6 +71,7 @@ def __init__(
        self.position_encoding = position_encoding
        self.convs = nn.ModuleList()
        self.backbone_channel_list = backbone_channel_list
        self.d_model = d_model
        for dim in backbone_channel_list:
            current = nn.Sequential()
            current.add_module(
2 changes: 1 addition & 1 deletion py/sam2/modeling/sam/mask_decoder.py
@@ -247,7 +247,7 @@ def predict_masks(
    def _get_stability_scores(self, mask_logits):
        """
        Compute stability scores of the mask logits based on the IoU between upper and
-        lower thresholds, similar to https://github.com/fairinternal/onevision/pull/568.
+        lower thresholds.
        """
        mask_logits = mask_logits.flatten(-2)
        stability_delta = self.dynamic_multimask_stability_delta
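
The docstring above describes the stability score as the IoU between the mask thresholded at an upper and a lower logit cutoff. A self-contained sketch of that idea, using a hypothetical standalone function and an explicit `delta` in place of `self.dynamic_multimask_stability_delta`:

# Sketch only: stability score as IoU between a "strict" and a "loose" threshold.
import torch

def stability_scores(mask_logits: torch.Tensor, delta: float) -> torch.Tensor:
    flat = mask_logits.flatten(-2)                     # [..., H*W]
    area_i = torch.sum(flat > delta, dim=-1).float()   # intersection: mask at +delta
    area_u = torch.sum(flat > -delta, dim=-1).float()  # union: mask at -delta
    # close to 1.0 when the mask barely changes under threshold perturbation
    return torch.where(area_u > 0, area_i / area_u, 1.0)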
166 changes: 122 additions & 44 deletions py/sam2/modeling/sam2_base.py

Large diffs are not rendered by default.

174 changes: 174 additions & 0 deletions py/sam2/modeling/sam2_utils.py
@@ -6,11 +6,15 @@


import copy
from typing import Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from ..utils.misc import mask_to_box


def select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num):
"""
@@ -147,3 +151,173 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = (x - u) / torch.sqrt(s + self.eps)
        x = self.weight[:, None, None] * x + self.bias[:, None, None]
        return x


def sample_box_points(
    masks: torch.Tensor,
    noise: float = 0.1,  # SAM default
    noise_bound: int = 20,  # SAM default
    top_left_label: int = 2,
    bottom_right_label: int = 3,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Sample a noised version of the top-left and bottom-right corners of a given `bbox`.
    Inputs:
    - masks: [B, 1, H, W] masks, dtype=torch.Tensor
    - noise: noise as a fraction of box width and height, dtype=float
    - noise_bound: maximum amount of noise (in pixels), dtype=int
    Returns:
    - box_coords: [B, num_pt, 2], contains (x, y) coordinates of top-left and bottom-right box corners, dtype=torch.float
    - box_labels: [B, num_pt], label 2 is reserved for top-left and 3 for bottom-right corners, dtype=torch.int32
    """
    device = masks.device
    box_coords = mask_to_box(masks)
    B, _, H, W = masks.shape
    box_labels = torch.tensor(
        [top_left_label, bottom_right_label], dtype=torch.int, device=device
    ).repeat(B)
    if noise > 0.0:
        if not isinstance(noise_bound, torch.Tensor):
            noise_bound = torch.tensor(noise_bound, device=device)
        bbox_w = box_coords[..., 2] - box_coords[..., 0]
        bbox_h = box_coords[..., 3] - box_coords[..., 1]
        max_dx = torch.min(bbox_w * noise, noise_bound)
        max_dy = torch.min(bbox_h * noise, noise_bound)
        box_noise = 2 * torch.rand(B, 1, 4, device=device) - 1
        box_noise = box_noise * torch.stack((max_dx, max_dy, max_dx, max_dy), dim=-1)

        box_coords = box_coords + box_noise
        img_bounds = (
            torch.tensor([W, H, W, H], device=device) - 1
        )  # uncentered pixel coords
        box_coords.clamp_(torch.zeros_like(img_bounds), img_bounds)  # in-place clamping

    box_coords = box_coords.reshape(-1, 2, 2)  # always 2 points
    box_labels = box_labels.reshape(-1, 2)
    return box_coords, box_labels
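
A usage sketch for `sample_box_points`, assuming the module is importable as below; the mask content is arbitrary:

# Sketch only: derive noised box prompts from a batch of boolean masks.
import torch
from py.sam2.modeling.sam2_utils import sample_box_points  # assumed import path

masks = torch.zeros(2, 1, 64, 64, dtype=torch.bool)
masks[:, :, 16:48, 8:40] = True                    # one 32x32 object per sample
coords, labels = sample_box_points(masks, noise=0.1, noise_bound=20)
print(coords.shape, labels.shape)                  # torch.Size([2, 2, 2]) torch.Size([2, 2])
print(labels[0])                                   # tensor([2, 3]): top-left, bottom-right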


def sample_random_points_from_errors(gt_masks, pred_masks, num_pt=1):
    """
    Sample `num_pt` random points (along with their labels) independently from the error regions.
    Inputs:
    - gt_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool
    - pred_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool or None
    - num_pt: int, number of points to sample independently for each of the B error maps
    Outputs:
    - points: [B, num_pt, 2], dtype=torch.float, contains (x, y) coordinates of each sampled point
    - labels: [B, num_pt], dtype=torch.int32, where 1 means positive clicks and 0 means
      negative clicks
    """
    if pred_masks is None:  # if pred_masks is not provided, treat it as empty
        pred_masks = torch.zeros_like(gt_masks)
    assert gt_masks.dtype == torch.bool and gt_masks.size(1) == 1
    assert pred_masks.dtype == torch.bool and pred_masks.shape == gt_masks.shape
    assert num_pt >= 0

    B, _, H_im, W_im = gt_masks.shape
    device = gt_masks.device

    # false positive region, a new point sampled in this region should have
    # negative label to correct the FP error
    fp_masks = ~gt_masks & pred_masks
    # false negative region, a new point sampled in this region should have
    # positive label to correct the FN error
    fn_masks = gt_masks & ~pred_masks
    # whether the prediction completely matches the ground truth on each mask
    all_correct = torch.all((gt_masks == pred_masks).flatten(2), dim=2)
    all_correct = all_correct[..., None, None]

    # channel 0 is the FP map, while channel 1 is the FN map
    pts_noise = torch.rand(B, num_pt, H_im, W_im, 2, device=device)
    # sample a negative new click from the FP region or a positive new click
    # from the FN region, depending on where the maximum falls;
    # in case the predictions are all correct (no FP or FN), we just
    # sample a negative click from the background region
    pts_noise[..., 0] *= fp_masks | (all_correct & ~gt_masks)
    pts_noise[..., 1] *= fn_masks
    pts_idx = pts_noise.flatten(2).argmax(dim=2)
    labels = (pts_idx % 2).to(torch.int32)
    pts_idx = pts_idx // 2
    pts_x = pts_idx % W_im
    pts_y = pts_idx // W_im
    points = torch.stack([pts_x, pts_y], dim=2).to(torch.float)
    return points, labels
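
A usage sketch: with an empty prediction, every ground-truth pixel is a false negative, so the sampled click comes back positive (label 1):

# Sketch only: one corrective click sampled from the FN error region.
import torch

gt = torch.zeros(1, 1, 32, 32, dtype=torch.bool)
gt[:, :, 8:24, 8:24] = True
pred = torch.zeros_like(gt)                    # predict nothing: all of gt is FN
points, labels = sample_random_points_from_errors(gt, pred, num_pt=1)
print(labels)                                  # tensor([[1]], dtype=torch.int32)
assert gt[0, 0, int(points[0, 0, 1]), int(points[0, 0, 0])]  # click lies inside the FN region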


def sample_one_point_from_error_center(gt_masks, pred_masks, padding=True):
    """
    Sample 1 random point (along with its label) from the center of each error region,
    that is, the point with the largest distance to the boundary of each error region.
    This is the RITM sampling method from https://github.com/saic-vul/ritm_interactive_segmentation/blob/master/isegm/inference/clicker.py
    Inputs:
    - gt_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool
    - pred_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool or None
    - padding: if True, pad with a 1-px boundary for the distance transform
    Outputs:
    - points: [B, 1, 2], dtype=torch.float, contains (x, y) coordinates of each sampled point
    - labels: [B, 1], dtype=torch.int32, where 1 means positive clicks and 0 means negative clicks
    """
    import cv2

    if pred_masks is None:
        pred_masks = torch.zeros_like(gt_masks)
    assert gt_masks.dtype == torch.bool and gt_masks.size(1) == 1
    assert pred_masks.dtype == torch.bool and pred_masks.shape == gt_masks.shape

    B, _, _, W_im = gt_masks.shape
    device = gt_masks.device

    # false positive region, a new point sampled in this region should have
    # negative label to correct the FP error
    fp_masks = ~gt_masks & pred_masks
    # false negative region, a new point sampled in this region should have
    # positive label to correct the FN error
    fn_masks = gt_masks & ~pred_masks

    fp_masks = fp_masks.cpu().numpy()
    fn_masks = fn_masks.cpu().numpy()
    points = torch.zeros(B, 1, 2, dtype=torch.float)
    labels = torch.ones(B, 1, dtype=torch.int32)
    for b in range(B):
        fn_mask = fn_masks[b, 0]
        fp_mask = fp_masks[b, 0]
        if padding:
            fn_mask = np.pad(fn_mask, ((1, 1), (1, 1)), "constant")
            fp_mask = np.pad(fp_mask, ((1, 1), (1, 1)), "constant")
        # compute the distance of each point in the FN/FP region to its boundary
        fn_mask_dt = cv2.distanceTransform(fn_mask.astype(np.uint8), cv2.DIST_L2, 0)
        fp_mask_dt = cv2.distanceTransform(fp_mask.astype(np.uint8), cv2.DIST_L2, 0)
        if padding:
            fn_mask_dt = fn_mask_dt[1:-1, 1:-1]
            fp_mask_dt = fp_mask_dt[1:-1, 1:-1]

        # take the point in the FN/FP region with the largest distance to its boundary
        fn_mask_dt_flat = fn_mask_dt.reshape(-1)
        fp_mask_dt_flat = fp_mask_dt.reshape(-1)
        fn_argmax = np.argmax(fn_mask_dt_flat)
        fp_argmax = np.argmax(fp_mask_dt_flat)
        is_positive = fn_mask_dt_flat[fn_argmax] > fp_mask_dt_flat[fp_argmax]
        pt_idx = fn_argmax if is_positive else fp_argmax
        points[b, 0, 0] = pt_idx % W_im  # x
        points[b, 0, 1] = pt_idx // W_im  # y
        labels[b, 0] = int(is_positive)

    points = points.to(device)
    labels = labels.to(device)
    return points, labels


def get_next_point(gt_masks, pred_masks, method):
    if method == "uniform":
        return sample_random_points_from_errors(gt_masks, pred_masks)
    elif method == "center":
        return sample_one_point_from_error_center(gt_masks, pred_masks)
    else:
        raise ValueError(f"unknown sampling method {method}")
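
A usage sketch contrasting the two methods: "uniform" clicks a random error pixel, while "center" clicks the error pixel farthest from the region boundary ("center" requires OpenCV for the distance transform):

# Sketch only: with no prediction yet, the whole object is the error region.
import torch

gt = torch.zeros(1, 1, 33, 33, dtype=torch.bool)
gt[:, :, 4:29, 4:29] = True                              # a 25x25 square object
pts_u, _ = get_next_point(gt, None, method="uniform")    # random pixel inside the square
pts_c, _ = get_next_point(gt, None, method="center")     # the square's center, (16, 16)
print(pts_u, pts_c)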
116 changes: 116 additions & 0 deletions py/sam2/sam2_configs/sam2.1_hiera_b+.yaml
@@ -0,0 +1,116 @@
# @package _global_

# Model
model:
  _target_: sam2.modeling.sam2_base.SAM2Base
  image_encoder:
    _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
    scalp: 1
    trunk:
      _target_: sam2.modeling.backbones.hieradet.Hiera
      embed_dim: 112
      num_heads: 2
    neck:
      _target_: sam2.modeling.backbones.image_encoder.FpnNeck
      position_encoding:
        _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
        num_pos_feats: 256
        normalize: true
        scale: null
        temperature: 10000
      d_model: 256
      backbone_channel_list: [896, 448, 224, 112]
      fpn_top_down_levels: [2, 3]  # output level 0 and 1 directly use the backbone features
      fpn_interp_model: nearest

  memory_attention:
    _target_: sam2.modeling.memory_attention.MemoryAttention
    d_model: 256
    pos_enc_at_input: true
    layer:
      _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
      activation: relu
      dim_feedforward: 2048
      dropout: 0.1
      pos_enc_at_attn: false
      self_attention:
        _target_: sam2.modeling.sam.transformer.RoPEAttention
        rope_theta: 10000.0
        feat_sizes: [32, 32]
        embedding_dim: 256
        num_heads: 1
        downsample_rate: 1
        dropout: 0.1
      d_model: 256
      pos_enc_at_cross_attn_keys: true
      pos_enc_at_cross_attn_queries: false
      cross_attention:
        _target_: sam2.modeling.sam.transformer.RoPEAttention
        rope_theta: 10000.0
        feat_sizes: [32, 32]
        rope_k_repeat: True
        embedding_dim: 256
        num_heads: 1
        downsample_rate: 1
        dropout: 0.1
        kv_in_dim: 64
    num_layers: 4

  memory_encoder:
    _target_: sam2.modeling.memory_encoder.MemoryEncoder
    out_dim: 64
    position_encoding:
      _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
      num_pos_feats: 64
      normalize: true
      scale: null
      temperature: 10000
    mask_downsampler:
      _target_: sam2.modeling.memory_encoder.MaskDownSampler
      kernel_size: 3
      stride: 2
      padding: 1
    fuser:
      _target_: sam2.modeling.memory_encoder.Fuser
      layer:
        _target_: sam2.modeling.memory_encoder.CXBlock
        dim: 256
        kernel_size: 7
        padding: 3
        layer_scale_init_value: 1e-6
        use_dwconv: True  # depth-wise convs
      num_layers: 2

  num_maskmem: 7
  image_size: 1024
  # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
  sigmoid_scale_for_mem_enc: 20.0
  sigmoid_bias_for_mem_enc: -10.0
  use_mask_input_as_output_without_sam: true
  # Memory
  directly_add_no_mem_embed: true
  no_obj_embed_spatial: true
  # use high-resolution feature map in the SAM mask decoder
  use_high_res_features_in_sam: true
  # output 3 masks on the first click on initial conditioning frames
  multimask_output_in_sam: true
  # SAM heads
  iou_prediction_use_sigmoid: True
  # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
  use_obj_ptrs_in_encoder: true
  add_tpos_enc_to_obj_ptrs: true
  proj_tpos_enc_in_obj_ptrs: true
  use_signed_tpos_enc_to_obj_ptrs: true
  only_obj_ptrs_in_the_past_for_eval: true
  # object occlusion prediction
  pred_obj_scores: true
  pred_obj_scores_mlp: true
  fixed_no_obj_ptr: true
  # multimask tracking settings
  multimask_output_for_tracking: true
  use_multimask_token_for_obj_ptr: true
  multimask_min_pt_num: 0
  multimask_max_pt_num: 1
  use_mlp_for_obj_ptr_proj: true
  # Compilation flag
  compile_image_encoder: False
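
This is a Hydra config: each `_target_` names a class that `hydra.utils.instantiate` constructs recursively. A minimal loading sketch, with the config path and name as assumptions (the repo's own builder may differ; note that `initialize` resolves `config_path` relative to the calling module):

# Sketch only: turn the Hydra config above into a model.
from hydra import compose, initialize
from hydra.utils import instantiate

with initialize(version_base=None, config_path="py/sam2/sam2_configs"):
    cfg = compose(config_name="sam2.1_hiera_b+")
model = instantiate(cfg.model, _recursive_=True)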
