Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

how to quantize the lightweight SAM model? #24

Open
ranpin opened this issue May 7, 2024 · 0 comments
Open

how to quantize the lightweight SAM model? #24

ranpin opened this issue May 7, 2024 · 0 comments

Comments

@ranpin
Copy link

ranpin commented May 7, 2024

Hi, nice work it is. I'm tring your method to do some application and have some questions about the quantization.

I have carefully looked at the code in demo_quan.py and layer.py, but currently the model in demo_quan.py is loaded directly from the quantized weights. I would like to ask how to quantize the instances created from an existing pre-trained SAM model using your quantization method?

Since I don't know how you quantize your lightweight SAM model using the quantization method in layer.py, can you provide a reference example of how did you do when quantizing the model? Thank you very much!

Here is the demo I wrote, it runs successfully, but the test result after quantization is close to 0. does it need retraining? Or maybe I'm not thinking correctly? hoping your reply!

from quantization_layer.layers import InferQuantConv2d, InferQuantConvTranspose2d

model_type = 'vit_b'
checkpoint = 'checkpoints/sam_vit_b_01ec64.pth'
model = sam_model_registry[model_type](checkpoint=checkpoint)
model.to(device)
model.eval()
predictor = SamPredictor(model)

w_bit = 8
a_bit = 8
input_size = (1, 3, 1024, 1024)  
n_V = input_size[2]
n_H = input_size[3]
a_interval = torch.tensor(0.1)
a_bias = torch.tensor(0.0)
w_interval = torch.tensor(0.01)

# 量化模型中的卷积层和卷积转置层
def replace_with_quantized_layers(model):
    layers_to_replace = []
    for name, module in model.named_modules():
        if isinstance(module, nn.Conv2d) or isinstance(module, nn.ConvTranspose2d):
            layers_to_replace.append((name, module))
    for name, module in layers_to_replace:
        if isinstance(module, nn.Conv2d):
            quantized_module = InferQuantConv2d(
                in_channels=module.in_channels,
                out_channels=module.out_channels,
                kernel_size=module.kernel_size,
                stride=module.stride,
                padding=module.padding,
                dilation=module.dilation,
                groups=module.groups,
                bias=module.bias is not None,
                mode='quant_forward',
                w_bit=w_bit,
                a_bit=a_bit
            )
            quantized_module.get_parameter(n_V=n_V, 
                      n_H=n_H,
                      a_interval=a_interval,
                      a_bias=a_bias,
                      w_interval=w_interval)
        elif isinstance(module, nn.ConvTranspose2d):
            quantized_module = InferQuantConvTranspose2d(
                in_channels=module.in_channels,
                out_channels=module.out_channels,
                kernel_size=module.kernel_size,
                stride=module.stride,
                padding=module.padding,
                output_padding=module.output_padding,
                groups=module.groups,
                bias=module.bias is not None,
                mode='quant_forward',
                w_bit=w_bit,
                a_bit=a_bit
            )
            quantized_module.get_parameter(n_V=n_V,
                              n_H=n_H,
                              a_interval=a_interval,
                              a_bias=a_bias,
                              w_interval=w_interval)
        setattr(model, name, quantized_module)
    return model

quan_model = replace_with_quantized_layers(model)
print(quan_model)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant