This code is a MegEngine version of the Segment Anything Model (SAM), ported from the original PyTorch code.
The Segment Anything Model (SAM) is a foundation model for image segmentation. It generates masks from input prompts such as points or boxes, and it can also generate masks for all objects in an image. For more information on SAM, please refer to this paper.
Install the dependencies:
pip install megengine opencv-python pycocotools matplotlib
There are two ways to get the MegEngine-SAM weights:
You can download MegEngine-SAM weights from here and save them as ./checkpoints/*.pkl.
vit_b: VIT-B Model
vit_l: VIT-L Model
vit_h: VIT-H Model
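The model-type strings above (vit_b, vit_l, vit_h) are the keys you pass to sam_model_registry later on. As a purely illustrative sketch, a mapping from model type to checkpoint path could look like this (the filenames are placeholders; use whatever names you saved the files under):

# Hypothetical checkpoint filenames -- substitute the names you actually saved.
checkpoints = {
    "vit_b": "./checkpoints/sam_vit_b.pkl",
    "vit_l": "./checkpoints/sam_vit_l.pkl",
    "vit_h": "./checkpoints/sam_vit_h.pkl",
}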
Alternatively, you can download the original PyTorch weights and save them as ./checkpoints/*.pth.
Then run:
export PYTHONPATH=/path/to/megengine-sam:$PYTHONPATH
python convert_weights.py
The converted MegEngine-SAM weights are saved as ./checkpoints/*.pkl.
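convert_weights.py handles the key mapping for you; conceptually, the conversion boils down to loading the PyTorch state dict and re-serializing it for MegEngine. The snippet below is only a rough sketch of that idea, not the actual script (the real conversion may also rename some parameters, and the filenames here are placeholders; torch is needed only for this step):

import pickle

import torch

# Load the original PyTorch checkpoint (filename is a placeholder).
state_dict = torch.load("./checkpoints/sam_vit_b.pth", map_location="cpu")

# Convert every tensor to a plain numpy array so MegEngine can consume it.
converted = {k: v.numpy() for k, v in state_dict.items()}

# Save as a .pkl file that the MegEngine model can later load.
with open("./checkpoints/sam_vit_b.pkl", "wb") as f:
    pickle.dump(converted, f)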
To run the provided example:
export PYTHONPATH=/path/to/megengine-sam:$PYTHONPATH
python example.py
This example generates masks for the images in images/src and saves the results in images/dst.
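If you want to inspect the output visually, one quick way is to overlay a saved mask on its source image with OpenCV and matplotlib (both installed above). The file names below are placeholders for whatever example.py actually reads and writes:

import cv2
import matplotlib.pyplot as plt

# Placeholder file names -- adjust them to the actual files in images/src and images/dst.
image = cv2.cvtColor(cv2.imread("images/src/example.jpg"), cv2.COLOR_BGR2RGB)
mask = cv2.imread("images/dst/example_mask.png", cv2.IMREAD_GRAYSCALE)

plt.imshow(image)
plt.imshow(mask, alpha=0.5, cmap="jet")  # semi-transparent mask overlay
plt.axis("off")
plt.show()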
MegEngine-SAM has the same API as segment-anything, so you can use MegEngine-SAM to generate masks from prompts just like the PyTorch version:
from mge_segment_anything import SamPredictor, sam_model_registry
predictor = SamPredictor(
sam_model_registry["model_name"](checkpoint="<path/to/checkpoint>")
)
predictor.set_image(<your_image>)
masks, _, _ = predictor.predict(<input_prompts>)
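Because the API matches segment-anything, a concrete point-prompt call would look roughly like the sketch below. The model type, checkpoint path, and image file are placeholders, and the keyword arguments follow the upstream segment-anything predictor interface, which MegEngine-SAM is stated to mirror:

import cv2
import numpy as np

from mge_segment_anything import SamPredictor, sam_model_registry

# Placeholder model type and checkpoint path.
sam = sam_model_registry["vit_b"](checkpoint="./checkpoints/sam_vit_b.pkl")
predictor = SamPredictor(sam)

# The predictor expects an RGB image; OpenCV loads BGR, so convert.
image = cv2.cvtColor(cv2.imread("images/src/example.jpg"), cv2.COLOR_BGR2RGB)
predictor.set_image(image)

# One foreground point prompt at pixel (x, y); label 1 = foreground, 0 = background.
masks, scores, logits = predictor.predict(
    point_coords=np.array([[500, 375]]),
    point_labels=np.array([1]),
    multimask_output=True,  # return several candidate masks ranked by predicted quality
)
print(masks.shape, scores)  # masks is a (num_masks, H, W) boolean array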
Or generate masks for a whole image:
from mge_segment_anything import SamAutomaticMaskGenerator, sam_model_registry
mask_generator = SamAutomaticMaskGenerator(
sam_model_registry["<model_type>"](checkpoint="<path/to/checkpoint>")
)
masks = mask_generator.generate(<your_image>)
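In upstream segment-anything, generate() returns a list of dicts describing each mask (keys such as segmentation, area, bbox, and predicted_iou). Assuming MegEngine-SAM mirrors that output format, the results can be iterated and saved like this (checkpoint path and image file are placeholders):

import cv2
import numpy as np

from mge_segment_anything import SamAutomaticMaskGenerator, sam_model_registry

# Placeholder model type and checkpoint path.
sam = sam_model_registry["vit_b"](checkpoint="./checkpoints/sam_vit_b.pkl")
mask_generator = SamAutomaticMaskGenerator(sam)

image = cv2.cvtColor(cv2.imread("images/src/example.jpg"), cv2.COLOR_BGR2RGB)
masks = mask_generator.generate(image)

# Each entry is assumed to follow the segment-anything output format.
for i, m in enumerate(masks):
    seg = m["segmentation"]  # (H, W) boolean mask
    print(i, m["area"], m["bbox"], m["predicted_iou"])
    cv2.imwrite(f"images/dst/mask_{i}.png", seg.astype(np.uint8) * 255)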