Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change backbone #86

Open
paolopepe00 opened this issue Aug 27, 2024 · 1 comment
Open

Change backbone #86

paolopepe00 opened this issue Aug 27, 2024 · 1 comment

Comments

@paolopepe00
Copy link

How can I modify the backbone using MobileNet?

@yxl502
Copy link

yxl502 commented Nov 12, 2024

from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from models.model_config import MODEL_SPECS

def make_divisible(
    value: float,
    divisor: int,
    min_value: Optional[float] = None,
    round_down_protect: bool = True,
) -> int:
    """Round `value` to the nearest multiple of `divisor`, clamped from below.

    Copied from
    "https://github.com/tensorflow/models/blob/master/official/vision/modeling/layers/nn_layers.py"
    to ensure all layers have channel counts divisible by 8.

    Args:
        value: Original value.
        divisor: Divisor the result must be a multiple of.
        min_value: Lower bound for the result (defaults to `divisor`).
        round_down_protect: If True, never round down by more than 10%.

    Returns:
        The adjusted value as an `int`, divisible by `divisor`.
    """
    floor = divisor if min_value is None else min_value
    rounded = max(floor, int(value + divisor / 2) // divisor * divisor)
    # Guard against shrinking the channel count by more than 10%.
    if round_down_protect and rounded < 0.9 * value:
        rounded += divisor
    return int(rounded)

def conv_2d(inp, oup, kernel_size=3, stride=1, groups=1, bias=False, norm=True, act=True):
    """Conv2d -> optional BatchNorm2d -> optional ReLU6, with 'same'-style padding.

    Args:
        inp: Input channels.
        oup: Output channels.
        kernel_size: Square kernel size; padding is (kernel_size - 1) // 2.
        stride: Convolution stride.
        groups: Convolution groups (set to `inp` for depthwise).
        bias: Whether the conv has a bias term.
        norm: Append a BatchNorm2d module when True.
        act: Append a ReLU6 activation when True.

    Returns:
        An nn.Sequential with submodules named 'conv', 'BatchNorm2d', 'Activation'.
    """
    pad = (kernel_size - 1) // 2
    stages = [('conv', nn.Conv2d(inp, oup, kernel_size, stride, pad, bias=bias, groups=groups))]
    if norm:
        stages.append(('BatchNorm2d', nn.BatchNorm2d(oup)))
    if act:
        stages.append(('Activation', nn.ReLU6()))
    block = nn.Sequential()
    for tag, module in stages:
        block.add_module(tag, module)
    return block

class InvertedResidual(nn.Module):
    """MobileNetV2-style (fused) inverted residual block.

    Used for the "fused_ib" entries of the model spec: a 3x3 expansion conv
    (carrying the stride) followed by a 1x1 projection, with a residual
    connection when stride == 1 and inp == oup.
    """

    def __init__(self, inp, oup, stride, expand_ratio, act=False, squeeze_excitation=False):
        # BUG FIX: original defined `init` / called `super(...).init()`, so the
        # module was never initialized; renamed to `__init__`.
        super().__init__()
        self.stride = stride
        assert stride in [1, 2]
        hidden_dim = int(round(inp * expand_ratio))
        self.block = nn.Sequential()
        if expand_ratio != 1:
            # Fused expansion: the 3x3 conv performs expansion and downsampling.
            self.block.add_module('exp_1x1', conv_2d(inp, hidden_dim, kernel_size=3, stride=stride))
        if squeeze_excitation:
            # NOTE(review): despite the flag name this adds a depthwise 3x3,
            # not an SE module — mirrors the original; verify against the spec.
            self.block.add_module('conv_3x3',
                                  conv_2d(hidden_dim, hidden_dim, kernel_size=3, stride=stride,
                                          groups=hidden_dim))
        # Linear (no activation by default) 1x1 projection back to `oup`.
        self.block.add_module('red_1x1', conv_2d(hidden_dim, oup, kernel_size=1, stride=1, act=act))
        self.use_res_connect = self.stride == 1 and inp == oup

    def forward(self, x):
        if self.use_res_connect:
            return x + self.block(x)
        return self.block(x)

class UniversalInvertedBottleneckBlock(nn.Module):
    """An inverted bottleneck block with optional depthwise convs (MobileNetV4 UIB).

    Referenced from
    https://github.com/tensorflow/models/blob/master/official/vision/modeling/layers/nn_blocks.py
    """

    def __init__(self,
                 inp,
                 oup,
                 start_dw_kernel_size,
                 middle_dw_kernel_size,
                 middle_dw_downsample,
                 stride,
                 expand_ratio):
        # BUG FIX: original defined `init` / called `super().init()`; renamed
        # to `__init__` so nn.Module is actually initialized.
        super().__init__()
        # Optional starting depthwise conv.
        self.start_dw_kernel_size = start_dw_kernel_size
        if self.start_dw_kernel_size:
            # When the middle dw handles downsampling, the starting dw keeps stride 1.
            stride_ = stride if not middle_dw_downsample else 1
            # BUG FIX: stored as `self.start_dw` but forward read `self._start_dw_`;
            # unified on `self._start_dw`.
            self._start_dw = conv_2d(inp, inp, kernel_size=start_dw_kernel_size,
                                     stride=stride_, groups=inp, act=False)
        # Expansion with 1x1 convs (channels kept divisible by 8).
        expand_filters = make_divisible(inp * expand_ratio, 8)
        # BUG FIX: stored as `self.expand_conv` but forward read `self._expand_conv`.
        self._expand_conv = conv_2d(inp, expand_filters, kernel_size=1)
        # Optional middle depthwise conv.
        self.middle_dw_kernel_size = middle_dw_kernel_size
        if self.middle_dw_kernel_size:
            # BUG FIX: the original's `stride_` assignment was garbled, so the
            # middle dw always used the raw `stride`; restored the conditional.
            stride_ = stride if middle_dw_downsample else 1
            # BUG FIX: stored as `self.middle_dw` but forward read `self._middle_dw`.
            self._middle_dw = conv_2d(expand_filters, expand_filters,
                                      kernel_size=middle_dw_kernel_size, stride=stride_,
                                      groups=expand_filters)
        # Linear projection with 1x1 conv.
        self._proj_conv = conv_2d(expand_filters, oup, kernel_size=1, stride=1, act=False)
        # Ending depthwise conv is unused in the reference; intentionally omitted.

    def forward(self, x):
        if self.start_dw_kernel_size:
            x = self._start_dw(x)
        x = self._expand_conv(x)
        if self.middle_dw_kernel_size:
            x = self._middle_dw(x)
        return self._proj_conv(x)

class MultiQueryAttentionLayerWithDownSampling(nn.Module):
    """Multi Query Attention with spatial downsampling.

    Referenced from
    https://github.com/tensorflow/models/blob/master/official/vision/modeling/layers/nn_blocks.py

    3 parameters are introduced for the spatial downsampling:
    1. kv_strides: downsampling factor on Key and Values only.
    2. query_h_strides: vertical strides on Query only.
    3. query_w_strides: horizontal strides on Query only.

    This is an optimized version:
    1. Projections in Attention are explicitly written out as 1x1 Conv2D.
    2. Additional reshapes are introduced to bring up to a 3x speed up.
    """

    def __init__(self, inp, num_heads, key_dim, value_dim, query_h_strides,
                 query_w_strides, kv_strides, dw_kernel_size=3, dropout=0.0):
        # BUG FIX: original defined `init`; renamed to `__init__`.
        super().__init__()
        self.num_heads = num_heads
        self.key_dim = key_dim
        self.value_dim = value_dim
        self.query_h_strides = query_h_strides
        self.query_w_strides = query_w_strides
        self.kv_strides = kv_strides
        self.dw_kernel_size = dw_kernel_size
        self.head_dim = key_dim // num_heads

        if self.query_h_strides > 1 or self.query_w_strides > 1:
            self._query_downsampling_norm = nn.BatchNorm2d(inp)
        self._query_proj = conv_2d(inp, num_heads * key_dim, 1, 1, norm=False, act=False)

        if self.kv_strides > 1:
            # Depthwise convs downsample K/V spatially before projection.
            self._key_dw_conv = conv_2d(inp, inp, dw_kernel_size, kv_strides,
                                        groups=inp, norm=True, act=False)
            self._value_dw_conv = conv_2d(inp, inp, dw_kernel_size, kv_strides,
                                          groups=inp, norm=True, act=False)
        # NOTE(review): _value_proj projects to key_dim (not value_dim), matching
        # the original code and the reshapes below — verify against the reference.
        self._key_proj = conv_2d(inp, key_dim, 1, 1, norm=False, act=False)
        self._value_proj = conv_2d(inp, key_dim, 1, 1, norm=False, act=False)

        self._output_proj = conv_2d(num_heads * key_dim, inp, 1, 1, norm=False, act=False)
        # BUG FIX: original first stored the raw float in `self.dropout`, then
        # overwrote it with the module; keep only the nn.Dropout module.
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        # x is NCHW; attention is computed over the flattened spatial positions.
        batch_size = x.size(0)
        if self.query_h_strides > 1 or self.query_w_strides > 1:
            # BUG FIX: original called F.avg_pool2d(self.query_h_stride,
            # self.query_w_stride) — no input tensor and misspelled attributes.
            q = F.avg_pool2d(x, kernel_size=(self.query_h_strides, self.query_w_strides))
            q = self._query_downsampling_norm(q)
            q = self._query_proj(q)
        else:
            q = self._query_proj(x)
        # Spatial size of the (possibly downsampled) query; assumes square input.
        px = q.size(2)
        q = q.view(batch_size, self.num_heads, -1, self.key_dim)  # [B, heads, seq, key_dim]

        if self.kv_strides > 1:
            k = self._key_proj(self._key_dw_conv(x))
            v = self._value_proj(self._value_dw_conv(x))
        else:
            k = self._key_proj(x)
            v = self._value_proj(x)
        k = k.view(batch_size, 1, self.key_dim, -1)  # [B, 1, key_dim, seq]
        v = v.view(batch_size, 1, -1, self.key_dim)  # [B, 1, seq, key_dim]

        # Scaled dot-product attention; K/V are shared across heads (multi-query).
        attn_score = torch.matmul(q, k) / (self.head_dim ** 0.5)
        # BUG FIX: softmax must precede dropout; the original dropped out raw logits.
        attn_score = F.softmax(attn_score, dim=-1)
        attn_score = self.dropout(attn_score)

        context = torch.matmul(attn_score, v)
        context = context.view(batch_size, self.num_heads * self.key_dim, px, px)
        return self._output_proj(context)

class MNV4LayerScale(nn.Module):
    """LayerScale as introduced in CaiT: https://arxiv.org/abs/2103.17239

    Referenced from
    https://github.com/tensorflow/models/blob/master/official/vision/modeling/layers/nn_blocks.py

    As used in MobileNetV4.

    Attributes:
        init_value (float): value to initialize the diagonal matrix of LayerScale.
    """

    def __init__(self, inp, init_value):
        # BUG FIX: original defined `init`; renamed to `__init__`.
        super().__init__()
        self.init_value = init_value
        # Learnable per-channel scale, broadcast over the H and W dimensions.
        self._gamma = nn.Parameter(self.init_value * torch.ones(inp, 1, 1))

    def forward(self, x):
        return x * self._gamma

class MultiHeadSelfAttentionBlock(nn.Module):
    """Self-attention block: input norm -> attention -> optional LayerScale -> residual.

    Dispatches to the multi-query attention layer (the MobileNetV4 path) or to
    a standard nn.MultiheadAttention depending on `use_multi_query`.
    """

    def __init__(
            self,
            inp,
            num_heads,
            key_dim,
            value_dim,
            query_h_strides,
            query_w_strides,
            kv_strides,
            use_layer_scale,
            use_multi_query,
            use_residual=True
    ):
        # BUG FIX: original defined `init` / called `super().init()`; renamed.
        super().__init__()
        self.query_h_strides = query_h_strides
        self.query_w_strides = query_w_strides
        self.kv_strides = kv_strides
        self.use_layer_scale = use_layer_scale
        self.use_multi_query = use_multi_query
        self.use_residual = use_residual

        self._input_norm = nn.BatchNorm2d(inp)
        if self.use_multi_query:
            self.multi_query_attention = MultiQueryAttentionLayerWithDownSampling(
                inp, num_heads, key_dim, value_dim, query_h_strides, query_w_strides, kv_strides
            )
        else:
            self.multi_head_attention = nn.MultiheadAttention(inp, num_heads, kdim=key_dim)

        if self.use_layer_scale:
            self.layer_scale_init_value = 1e-5
            self.layer_scale = MNV4LayerScale(inp, self.layer_scale_init_value)

    def forward(self, x):
        # Not using CPE, skipped.
        shortcut = x
        x = self._input_norm(x)
        if self.use_multi_query:
            x = self.multi_query_attention(x)
        else:
            # BUG FIX: nn.MultiheadAttention takes (query, key, value) and returns
            # (output, weights); the original passed two args and kept the tuple.
            # NOTE(review): x is NCHW here while MultiheadAttention expects a
            # sequence layout — this path likely needs a flatten/permute. TODO confirm.
            x, _ = self.multi_head_attention(x, x, x)
        if self.use_layer_scale:
            x = self.layer_scale(x)
        if self.use_residual:
            x = x + shortcut
        return x

def build_blocks(layer_spec):
    """Build one nn.Sequential stage from a layer-spec dict.

    Args:
        layer_spec: Dict with keys 'block_name' ("convbn", "uib" or "fused_ib"),
            'num_blocks', and 'block_specs' (one positional tuple per block).
            A spec without 'block_name' yields an empty nn.Sequential.

    Returns:
        nn.Sequential containing the constructed blocks.

    Raises:
        NotImplementedError: For an unrecognized 'block_name'.
    """
    block_name = layer_spec.get('block_name')
    if not block_name:
        return nn.Sequential()

    stage = nn.Sequential()
    num_blocks = layer_spec['num_blocks']
    specs = layer_spec['block_specs']

    if block_name == "convbn":
        keys = ['inp', 'oup', 'kernel_size', 'stride']
        for idx in range(num_blocks):
            stage.add_module(f"convbn_{idx}", conv_2d(**dict(zip(keys, specs[idx]))))
    elif block_name == "uib":
        keys = ['inp', 'oup', 'start_dw_kernel_size', 'middle_dw_kernel_size',
                'middle_dw_downsample', 'stride', 'expand_ratio', 'mhsa']
        for idx in range(num_blocks):
            kwargs = dict(zip(keys, specs[idx]))
            # 'mhsa' (optional trailing spec entry) is not a UIB argument; when
            # present and truthy it carries the attention-block parameters.
            mhsa = kwargs.pop("mhsa", 0)
            stage.add_module(f"uib_{idx}", UniversalInvertedBottleneckBlock(**kwargs))
            if mhsa:
                mhsa_keys = ["inp", "num_heads", "key_dim", "value_dim",
                             "query_h_strides", "query_w_strides", "kv_strides",
                             "use_layer_scale", "use_multi_query", "use_residual"]
                mhsa_kwargs = dict(zip(mhsa_keys, [kwargs['oup']] + list(mhsa)))
                stage.add_module(f"mhsa_{idx}", MultiHeadSelfAttentionBlock(**mhsa_kwargs))
    elif block_name == "fused_ib":
        keys = ['inp', 'oup', 'stride', 'expand_ratio', 'act']
        for idx in range(num_blocks):
            stage.add_module(f"fused_ib_{idx}", InvertedResidual(**dict(zip(keys, specs[idx]))))
    else:
        raise NotImplementedError

    return stage

class MobileNetV4(nn.Module):
    """MobileNetV4 backbone assembled from MODEL_SPECS stage specs.

    forward() returns the feature maps of layers 1–5 (layer5 global-pooled),
    suitable for use as a multi-scale backbone.
    """

    def __init__(self, model):
        # MobileNetV4ConvSmall MobileNetV4ConvMedium MobileNetV4ConvLarge
        # MobileNetV4HybridMedium MobileNetV4HybridLarge
        """Params to initiate MobileNetV4.

        Args:
            model: one of the 5 model types indicated in
                "https://github.com/tensorflow/models/blob/master/official/vision/modeling/backbones/mobilenet.py"
        """
        # BUG FIX: original defined `init` / called `super().init()`; renamed.
        super().__init__()
        assert model in MODEL_SPECS.keys()
        self.model = model
        self.spec = MODEL_SPECS[self.model]

        # Stem + five stages, each built from its spec entry.
        self.conv0 = build_blocks(self.spec['conv0'])
        self.layer1 = build_blocks(self.spec['layer1'])
        self.layer2 = build_blocks(self.spec['layer2'])
        self.layer3 = build_blocks(self.spec['layer3'])
        self.layer4 = build_blocks(self.spec['layer4'])
        self.layer5 = build_blocks(self.spec['layer5'])

    def forward(self, x):
        x0 = self.conv0(x)
        x1 = self.layer1(x0)
        x2 = self.layer2(x1)
        x3 = self.layer3(x2)
        x4 = self.layer4(x3)
        x5 = self.layer5(x4)
        # Collapse the final stage to 1x1 spatially; earlier stages keep their maps.
        x5 = nn.functional.adaptive_avg_pool2d(x5, 1)
        return [x1, x2, x3, x4, x5]

if __name__ == '__main__':
    # BUG FIX: original guard was `if name == 'main':`, which raises NameError;
    # the dunder form is required for a script entry point.
    # Smoke test: run a small batch through the smallest variant.
    samples = torch.randn(32, 3, 128, 128)
    model = MobileNetV4('MobileNetV4ConvSmall')(samples)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants