Skip to content

Latest commit

 

History

History
72 lines (47 loc) · 2.56 KB

README_cn.md

File metadata and controls

72 lines (47 loc) · 2.56 KB

mge-segment-anything

这是 MegEngine 版本SAM 模型, 相关代码是 torch 版本 SAM 在 MegEngine 中的实现。

SAM 是一个图片分割的基础模型。它可以根据用户输入的 prompts 来为图片生成 mask。用户也可以使用 SAM 为一张图片中的所有物体生成 mask。这篇论文中有着关于 SAM 模型的更多信息。

环境准备

pip install megengine opencv-python pycocotools matplotlib

权值下载

有两个方法可以得到 MegEngine-SAM 的权值:

方法一:直接下载

可以从这里直接下载 MegEngine-SAM 的权值,下载完成后请存储为 checkpoints/*.pkl`。

vit_b: VIT-B Model

vit_l: VIT-L Model

vit_h: VIT-H Model

方法二:从 Torch 转换

可以下载 torch weights 存储为 checkpoints/*.pth`。

然后执行以下代码进行权值转换:

export PYTHONPATH=/path/to/megengine-sam:$PYTHONPATH
python convert_weights.py

转换完成后,被转换好的 MegEngine-SAM 权值会被存为 ./checkpoints/*.pkl

例子

export PYTHONPATH=/path/to/megengine-sam:$PYTHONPATH
python example.py

这个例子会为 images/src 底下图片的生成 mask,相关结果会被存储到 images/dst 底下。

使用

MegEngine-SAM 的 api 和原始版本的 segment-anything 保持了一致。

所以你可以用下面的代码根据 prompt 为一张图片生成 mask:

from mge_segment_anything import SamPredictor, sam_model_registry
predictor = SamPredictor(
    sam_model_registry["model_name"](checkpoint="<path/to/checkpoint>")
)
predictor.set_image(<your_image>)
masks, _, _ = predictor.predict(<input_prompts>)

或者为一张图片中所有物体生成 mask:

from mge_segment_anything import SamAutomaticMaskGenerator, sam_model_registry
mask_generator = SamAutomaticMaskGenerator(
    sam_model_registry["<model_type>"](checkpoint="<path/to/checkpoint>")
)
masks = mask_generator.generate(<your_image>)