This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Exception: Mask conflict happenes! #2666

Closed
YoungSharp opened this issue Jul 9, 2020 · 21 comments
Labels: close-as-overdue, model compression, support, user raised


@YoungSharp

YoungSharp commented Jul 9, 2020

Environment: Linux 16.04

  • NNI version: '999.0.0-developing' (20200608)
  • NNI mode (local|remote|pai): local
  • Client OS: Linux 16.04
  • Server OS (for remote mode only): no
  • Python version: 3.6
  • PyTorch/TensorFlow version: PyTorch 1.3.0
  • Is conda/virtualenv/venv used?: no
  • Is running in Docker?: no

Log message:

  • nnimanager.log:
  • dispatcher.log:
  • nnictl stdout and stderr:

What issue did you meet, and what's expected?:
I used NNI's "SlimPruner" on MobileNetV2 and then tried to speed up the model, and got the error "Exception: Mask conflict happenes!".
To find the conflicting layer, I changed the code as follows:
```python
def find_successors(self, unique_name):
    """
    Find successor nodes of the given node

    Parameters
    ----------
    unique_name : str
        The unique name of the node

    Returns
    -------
    list
        a list of nodes who are the given node's successor
    """
    successors = []
    for output in self.name_to_node[unique_name].outputs:
        if output not in self.input_to_node:
            # may reach the output of the whole graph
            continue
        nodes_py = self.input_to_node[output]
        for node_py in nodes_py:
            successors.append(node_py.unique_name)
    print("successors = ", successors)  # debug print added to locate the conflict
    return successors
```

I got the following log:
```
successors = ['backbone.conv1.2']
successors = ['backbone.stage1.0.conv.0.0']
successors = ['backbone.stage1.0.conv.0.2']
successors = ['backbone.stage1.0.conv.1']
successors = ['backbone.stage2.0.conv.0.0']
successors = ['backbone.stage2.0.conv.0.2']
successors = ['backbone.stage2.0.conv.1.0']
successors = ['backbone.stage2.0.conv.1.2']
successors = ['backbone.stage2.0.conv.2']
successors = ['backbone.stage2.1.conv.0.0', 'backbone.stage2.1.aten::add.138']
successors = ['backbone.stage3.0.conv.0.0']
successors = ['backbone.stage2.1.conv.0.2']
successors = ['backbone.stage2.1.conv.1.0']
successors = ['backbone.stage2.1.conv.1.2']
successors = ['backbone.stage2.1.conv.2']
successors = ['backbone.stage2.1.aten::add.138']
Traceback (most recent call last):
  File "tools/compress_onnx_export.py", line 164, in <module>
    main()
  File "tools/compress_onnx_export.py", line 137, in main
    m_speedup.speedup_model()
  File "speedup/compressor.py", line 187, in speedup_model
    self.infer_modules_masks()
  File "/speedup/compressor.py", line 146, in infer_modules_masks
    self.infer_module_mask(module_name, None, mask=mask)
  File "/speedup/compressor.py", line 138, in infer_module_mask
    self.infer_module_mask(_module_name, module_name, in_shape=output_cmask)
  File "/speedup/compressor.py", line 122, in infer_module_mask
    output_cmask = infer_from_inshape[m_type](module_masks, in_shape)
  File "//speedup/infer_shape.py", line 240, in <lambda>
    'aten::add': lambda module_mask, mask: add_inshape(module_mask, mask),
  File "//speedup/infer_shape.py", line 364, in add_inshape
    raise Exception('Mask conflict happenes!')
Exception: Mask conflict happenes!
```
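The debug print above shows each successor list but not the node it belongs to, so it is hard to see where the graph forks. A small standalone sketch, assuming a plain dict that maps each node name to its successor names (a hypothetical structure, not NNI's internal graph object), prints the owner next to its successors and flags nodes with more than one successor. Those are the branch points where a residual `aten::add` splits off:

```python
def report_successors(successor_map):
    """Print each node with its successors and return the nodes that fan out.

    successor_map: hypothetical dict {node_name: [successor_name, ...]}.
    A node with more than one successor is a branch point, e.g. a block
    output that feeds both the next conv and a residual add.
    """
    branch_points = []
    for name, successors in successor_map.items():
        print(f"{name} -> {successors}")
        if len(successors) > 1:
            branch_points.append(name)
    return branch_points


# Names taken from the log above; which node owns each list is an assumption.
demo = {
    "backbone.stage2.0.conv.2": [
        "backbone.stage2.1.conv.0.0",
        "backbone.stage2.1.aten::add.138",
    ],
    "backbone.stage2.1.conv.0.0": ["backbone.stage2.1.conv.0.2"],
}
print("branch points:", report_successors(demo))
```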
How to reproduce it?:

Additional information:

@zheng-ningxin
Contributor

Hi there, thanks for the feedback! This error occurs because NNI does not yet support fixing mask conflicts between BN layers (we may support this in the next release). I suggest you use the L1FilterPruner or L2FilterPruner to prune and speed up MobileNet.
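To see why a residual add can trigger this, consider `x + self.conv(x)` in an inverted-residual block: one input carries the channel mask of the previous block's output BN, the other carries the mask of this block's last BN. If the pruner picks different channels for the two, an elementwise add cannot drop channels consistently. A minimal sketch with plain 0/1 lists (a hypothetical helper, not NNI's `add_inshape`); `strict=False` shows one possible resolution, keeping a channel if either input keeps it:

```python
def resolve_add_masks(mask_a, mask_b, strict=True):
    """Combine per-channel keep-masks (1 = keep, 0 = prune) of the two
    inputs of an elementwise add such as x + self.conv(x).

    strict=True refuses differing masks, like a conflict check would.
    strict=False keeps a channel if either input keeps it (union); this is
    an assumed resolution for illustration only.
    """
    if len(mask_a) != len(mask_b):
        raise ValueError("add inputs must have the same number of channels")
    if list(mask_a) == list(mask_b):
        return list(mask_a)
    if strict:
        raise ValueError("mask conflict: the add inputs prune different channels")
    return [1 if a or b else 0 for a, b in zip(mask_a, mask_b)]


skip_mask = [1, 1, 0, 0]    # channels kept on the skip path
branch_mask = [1, 0, 1, 0]  # channels kept by the conv branch
print(resolve_add_masks(skip_mask, branch_mask, strict=False))  # [1, 1, 1, 0]
```

The union keeps the tensor shapes compatible at the cost of pruning fewer channels than either layer asked for.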

@YoungSharp
Author

YoungSharp commented Jul 9, 2020

Thanks for the reply. Do you have a schedule for your next release?

By the way, after using L1FilterPruner, I got the same error.

@zheng-ningxin
Contributor

zheng-ningxin commented Jul 9, 2020

Could you please show the config list of the L1FilterPruner? Thanks~

@YoungSharp
Author

config_list=[{ 'sparsity': 0.7, 'op_types': ['Conv2d'] }],
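For reference, this config asks for 70% of the filters in every `Conv2d` to be pruned. The usual L1-norm criterion ranks a layer's output filters by the sum of their absolute weights and masks the smallest ones. A pure-Python sketch of that ranking, assuming each filter is given as a flat list of weights and that the pruned count is `int(n * sparsity)` (NNI's exact rounding and tie-breaking may differ):

```python
def l1_filter_mask(weights, sparsity):
    """Return a keep-mask (1 = keep, 0 = prune) over a layer's output filters.

    weights: list of filters, each a flat list of numbers (hypothetical layout).
    The int(n * sparsity) filters with the smallest L1 norm are pruned.
    """
    n = len(weights)
    num_prune = int(n * sparsity)
    norms = [sum(abs(w) for w in f) for f in weights]
    order = sorted(range(n), key=lambda i: norms[i])  # ascending L1 norm
    mask = [1] * n
    for i in order[:num_prune]:
        mask[i] = 0
    return mask


filters = [[1.0, -1.0], [0.1, 0.1], [3.0, 0.0], [-0.5, 0.2]]
print(l1_filter_mask(filters, 0.5))  # [1, 0, 1, 0]
```

Because each conv is ranked on its own, layers whose outputs meet at a residual add can still end up with different masks.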

@zheng-ningxin
Contributor

I ran L1FilterPruner with this config_list, and it works fine. Could you check whether the mask file was generated by L1FilterPruner? I suggest deleting the mask file and regenerating it with L1FilterPruner. If the problem persists, could you please show me the code (just the pruner and speedup snippets are fine)? Thanks~

@YoungSharp
Author

YoungSharp commented Jul 9, 2020

Actually, I integrated NNI model compression into mmdet.
```python
import os

compress_flag = cfg.compress_cfg['compress_flag']
if compress_flag:
    prunner_epochs = cfg.compress_cfg['prune_epochs']
    prunner_name = cfg.compress_cfg['compress_mode']
    config_list = cfg.compress_cfg['config_list']
    input_shape = cfg.compress_cfg['input_shape']
    compress_model_name = cfg.compress_cfg['compress_model_name']
    compress_model_mask_name = cfg.compress_cfg['compress_model_mask_name']
    compress_model_name = os.path.join(cfg.work_dir, compress_model_name)
    compress_model_mask_name = os.path.join(cfg.work_dir, compress_model_mask_name)
    print(compress_model_name)
    print(compress_model_mask_name)
    optimizer1 = build_optimizer(model, cfg.optimizer)
    pruner = create_pruner(model, prunner_name, config_list, optimizer1)
    model = pruner.compress()
    runner = Runner(model, batch_processor, optimizer1, cfg.work_dir,
                    cfg.log_level, mix_up_cfg=cfg.mix_up_cfg, attack_cfg=cfg.attack_cfg)

    runner.register_training_hooks(cfg.lr_config, optimizer_config,
                                   cfg.checkpoint_config, cfg.log_config)

    runner.run(data_loaders, cfg.workflow, prunner_epochs)
    pruner.export_model(model_path=compress_model_name, mask_path=compress_model_mask_name)
```
This code follows your sample "nni-master/examples/model_compress/model_prune_torch.py".

@zheng-ningxin
Contributor

Could you please also show the speedup part? By the way, does this problem still exist after you delete and re-generate the mask file?

@YoungSharp
Author

Sorry, something else came up. I did not delete the mask file; I will try to regenerate the mask and run speedup again.

Speedup code:
```python
import time

import numpy as np
import torch

if hasattr(model, 'forward_dummy'):
    model.forward = model.forward_dummy
else:
    raise NotImplementedError(
        'ONNX exporting is currently not supported with {}'.format(
            model.__class__.__name__))

config = config1[args.example_name]

use_mask_out = use_speedup_out = None
# must run use_mask before use_speedup because use_speedup modifies the model
if use_mask:
    apply_compression_results(model, args.masks_file, 'cpu' if config['device'] == 'cpu' else None)
    start = time.time()
    for _ in range(32):
        use_mask_out = model(dummy_input)
    print('elapsed time when use mask: ', time.time() - start)
    # print(use_mask_out.size())
    print(use_mask_out)

if use_speedup:
    m_speedup = ModelSpeedup(model, dummy_input, args.masks_file,
                             'cpu' if config['device'] == 'cpu' else None)
    m_speedup.speedup_model()
    start = time.time()
    for _ in range(32):
        use_speedup_out = model(dummy_input)
    print('elapsed time when use speedup: ', time.time() - start)

if compare_results:
    if torch.allclose(use_mask_out, use_speedup_out, atol=1e-02):
        print('the outputs from use_mask and use_speedup are the same')
    else:
        raise RuntimeError('the outputs from use_mask and use_speedup are different')

torch.onnx.export(model, dummy_input.to(device), 'after_speedup.onnx')
```
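A side note on the timing loops in this snippet: they include first-call overhead, and on a GPU, CUDA kernels run asynchronously, so `time.time()` can be read before the work has finished. A small helper sketch (hypothetical, not part of NNI) that adds untimed warm-up calls and an optional synchronization callback; for a CUDA model one would pass `torch.cuda.synchronize` as `sync`:

```python
import time


def time_calls(fn, n=32, warmup=3, sync=None):
    """Return the average seconds per call of fn() over n timed calls.

    warmup: untimed calls run first, to exclude one-off setup costs.
    sync: optional callable invoked before each clock read (for example
          torch.cuda.synchronize) so pending asynchronous work is counted.
    """
    for _ in range(warmup):
        fn()
    if sync is not None:
        sync()
    start = time.perf_counter()
    for _ in range(n):
        fn()
    if sync is not None:
        sync()
    return (time.perf_counter() - start) / n


avg = time_calls(lambda: sum(range(10000)), n=10)
print(f"average: {avg * 1e6:.1f} us per call")
```

With the names from the snippet, usage would look like `time_calls(lambda: model(dummy_input), sync=torch.cuda.synchronize)`.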
This code follows your sample "nni-master/examples/model_compress/model_speedup.py".

@YoungSharp
Author

After regenerating the mask, I got the same error.

@zheng-ningxin
Contributor

Hi, since the context of mmdet seems very complicated, could you please just run the following code and see if it can finish without an error?

```python
import os
import json
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets

from nni.compression.torch.speedup.compressor import ModelSpeedup
from nni.compression.torch import L1FilterPruner


net = torchvision.models.mobilenet_v2(pretrained=True)
net.cuda()
cfg = [{ 'sparsity': 0.7, 'op_types': ['Conv2d'] }]
pruner = L1FilterPruner(net, cfg)
pruner.compress()
pruner.export_model('./mobile_v2.pth', './mobile_v2_mask')


another = torchvision.models.mobilenet_v2()
another.cuda()
state_dict = torch.load('./mobile_v2.pth')
another.load_state_dict(state_dict)
data = torch.rand(1, 3, 224, 224).cuda()
ms = ModelSpeedup(another, data, './mobile_v2_mask')
ms.speedup_model()
```

@YoungSharp
Author

Actually, I can run your sample (VGG-slim) successfully.
My MobileNetV2 is:
```python
import logging

import torch.nn as nn

# BACKBONES and load_checkpoint come from mmdet/mmcv; their imports were not
# included in the original snippet.


def _make_divisible(v, divisor, min_value=None):
    """
    This function is taken from the original tf repo.
    It ensures that all layers have a channel number that is divisible by 8
    It can be seen here:
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
    :param v:
    :param divisor:
    :param min_value:
    :return:
    """
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v


class ConvBNReLU(nn.Sequential):

    def __init__(self,
                 in_planes,
                 out_planes,
                 kernel_size=3,
                 stride=1,
                 groups=1):
        padding = (kernel_size - 1) // 2
        super(ConvBNReLU, self).__init__(
            nn.Conv2d(
                in_planes,
                out_planes,
                kernel_size,
                stride,
                padding,
                groups=groups,
                bias=False), nn.BatchNorm2d(out_planes),
            nn.ReLU6(inplace=True))


class InvertedResidual(nn.Module):

    def __init__(self, inp, oup, stride, expand_ratio):
        super(InvertedResidual, self).__init__()
        self.stride = stride
        assert stride in [1, 2]

        hidden_dim = int(round(inp * expand_ratio))
        self.use_res_connect = self.stride == 1 and inp == oup

        layers = []
        if expand_ratio != 1:
            # pw
            layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
        layers.extend([
            # dw
            ConvBNReLU(
                hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
            # pw-linear
            nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
            nn.BatchNorm2d(oup),
        ])
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        if self.use_res_connect:
            return x + self.conv(x)
        else:
            return self.conv(x)


@BACKBONES.register_module
class MobileNetV2(nn.Module):
    """
    MobileNetV2 is taken from pytorch hub.
    https://github.com/pytorch/vision/blob/master/torchvision/models/mobilenet.py
    """

    def __init__(self,
                 out_indices=(1, 2, 4, 6),
                 frozen_stages=-1,
                 width_mult=1.0,
                 inverted_residual_setting=None,
                 round_nearest=8):
        """
        MobileNet V2 main class
        Args:
            width_mult (float): Width multiplier - adjusts number of channels
            in each layer by this amount
            inverted_residual_setting: Network structure
            round_nearest (int): Round the number of channels in each layer to
             be a multiple of this number. Set to 1 to turn off rounding
        """
        super(MobileNetV2, self).__init__()
        block = InvertedResidual
        input_channel = 32

        if inverted_residual_setting is None:
            inverted_residual_setting = [
                # t, c, n, s
                [1, 16, 1, 1],  # 0
                [6, 24, 2, 2],  # 1
                [6, 32, 3, 2],  # 2
                [6, 64, 4, 2],  # 3
                [6, 96, 3, 1],  # 4
                [6, 160, 3, 2],  # 5
                [6, 320, 1, 1],  # 6
            ]

        # only check the first element,
        # assuming user knows t,c,n,s are required
        if len(inverted_residual_setting) == 0 or len(
                inverted_residual_setting[0]) != 4:
            raise ValueError("inverted_residual_setting should be non-empty "
                             "or a 4-element list, got {}".format(
                                 inverted_residual_setting))

        self.frozen_stages = frozen_stages
        self.out_indices = out_indices
        assert max(out_indices) < len(inverted_residual_setting)
        # building first layer
        input_channel = _make_divisible(input_channel * width_mult,
                                        round_nearest)
        self.conv1 = ConvBNReLU(3, input_channel, stride=2)
        # building inverted residual blocks
        self.stages = []
        for si, (t, c, n, s) in enumerate(inverted_residual_setting):
            output_channel = _make_divisible(c * width_mult, round_nearest)
            stage = []
            for i in range(n):
                stride = s if i == 0 else 1
                stage.append(
                    block(
                        input_channel, output_channel, stride, expand_ratio=t))
                input_channel = output_channel
            stage_name = 'stage{}'.format(si + 1)
            self.add_module(stage_name, nn.Sequential(*stage))
            self.stages.append(stage_name)

        self._freeze_stages()

    def _freeze_stages(self):
        if self.frozen_stages >= 0:
            self.conv1.eval()
            for param in self.conv1.parameters():
                param.requires_grad = False

        for i in range(1, self.frozen_stages + 1):
            m = getattr(self, 'stage{}'.format(i))
            m.eval()
            for param in m.parameters():
                param.requires_grad = False

    def init_weights(self, pretrained=None):
        if isinstance(pretrained, str):
            logger = logging.getLogger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            # weight initialization
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.kaiming_normal_(m.weight, mode='fan_out')
                    if m.bias is not None:
                        nn.init.zeros_(m.bias)
                elif isinstance(m, nn.BatchNorm2d):
                    nn.init.ones_(m.weight)
                    nn.init.zeros_(m.bias)
        else:
            raise TypeError('pretrained must be a str or None')

    def forward(self, x):
        x = self.conv1(x)
        outs = []
        for i, stage_name in enumerate(self.stages):
            stage = getattr(self, stage_name)
            x = stage(x)
            if i in self.out_indices:
                outs.append(x)
        return tuple(outs)
```
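In this backbone, `InvertedResidual` only adds a skip connection when `stride == 1` and `inp == oup`, and those `x + self.conv(x)` adds are where channel masks from two different layers meet. A quick sketch that replays the constructor's channel bookkeeping (reusing the `_make_divisible` rule above; `residual_blocks` is a hypothetical helper) lists which blocks get a residual add for the default setting:

```python
def _make_divisible(v, divisor, min_value=None):
    # Same rounding rule as in the backbone above.
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v


DEFAULT_SETTING = [
    # t, c, n, s
    [1, 16, 1, 1],
    [6, 24, 2, 2],
    [6, 32, 3, 2],
    [6, 64, 4, 2],
    [6, 96, 3, 1],
    [6, 160, 3, 2],
    [6, 320, 1, 1],
]


def residual_blocks(setting=DEFAULT_SETTING, width_mult=1.0, round_nearest=8):
    """Return names like 'stage2.1' for blocks that compute x + self.conv(x)."""
    input_channel = _make_divisible(32 * width_mult, round_nearest)
    names = []
    for si, (t, c, n, s) in enumerate(setting):
        output_channel = _make_divisible(c * width_mult, round_nearest)
        for i in range(n):
            stride = s if i == 0 else 1
            if stride == 1 and input_channel == output_channel:
                names.append('stage{}.{}'.format(si + 1, i))
            input_channel = output_channel
    return names


print(residual_blocks())
```

The first entry, `stage2.1`, is the block whose `aten::add.138` appears in the log right before the exception.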

@zheng-ningxin
Contributor

Can you run the code I provided yesterday successfully?


@YoungSharp
Author

I integrated your code into mmdetection and it ran successfully, producing output without errors.

@zheng-ningxin
Contributor

Thanks for the quick reply~
In addition, I found that your MobileNetV2 uses a tuple of tensors as its output. Tuple outputs are not supported yet; support will be provided in the next release (#2609).

```python
def forward(self, x):
    x = self.conv1(x)
    outs = []
    for i, stage_name in enumerate(self.stages):
        stage = getattr(self, stage_name)
        x = stage(x)
        if i in self.out_indices:
            outs.append(x)
    return tuple(outs)
```

@YoungSharp
Author

@zheng-ningxin I had already changed that, and it no longer raises the tensor-tuple error.

So do you know how to fix my code, or is this a bug?

@zheng-ningxin
Contributor

> I integrated your code into mmdetection and it ran successfully, producing output without errors.

I thought you had fixed this problem. So you still cannot speed up MobileNetV2? Could you please paste the newest version of your MobileNetV2 (without the tuple output)?

@YoungSharp
Author

In mmdet, the old version was:
```python
from functools import partial


def multi_apply_compress(func, *args, **kwargs):
    pfunc = partial(func, **kwargs) if kwargs else func
    map_results = map(pfunc, *args)
    return tuple(map(list, zip(*map_results)))
```
With it, I got the tuple error. So I changed the code to:
```python
def multi_apply_compress(func, *args, **kwargs):
    pfunc = partial(func, **kwargs) if kwargs else func
    map_results = map(pfunc, *args)
    result = tuple(map(list, zip(*map_results)))
    return result[0][0]
```
and no longer get the tuple error.
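To see what the modified helper returns, here is the original `multi_apply_compress` run on a toy per-level function (`toy_head` and its values are made up for illustration). It transposes a list of per-input result tuples into a tuple of per-output lists, so `result[0][0]` keeps only the first output for the first input and discards the rest:

```python
from functools import partial


def multi_apply_compress(func, *args, **kwargs):
    # Original version from the comment above: apply func to each input,
    # then regroup the results by output position.
    pfunc = partial(func, **kwargs) if kwargs else func
    map_results = map(pfunc, *args)
    return tuple(map(list, zip(*map_results)))


def toy_head(feat, scale=1):
    # Hypothetical head returning two outputs per feature level.
    return feat * scale, feat + scale


result = multi_apply_compress(toy_head, [1, 2, 3], scale=10)
print(result)        # ([10, 20, 30], [11, 12, 13])
print(result[0][0])  # 10
```

If the traced forward ends with `result[0][0]`, the graph seen by speedup would presumably cover only the layers feeding that single tensor, so it may be worth checking that the traced graph still contains every pruned layer.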

@zheng-ningxin
Contributor

zheng-ningxin commented Jul 14, 2020

Hi, I ran the speedup module on the MobileNetV2 you gave me, and it works fine. Could you check whether you can run the following code successfully?

```python
import os
import json
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets

from nni.compression.torch.speedup.compressor import ModelSpeedup
from nni.compression.torch import L1FilterPruner


def _make_divisible(v, divisor, min_value=None):
    """
    This function is taken from the original tf repo.
    It ensures that all layers have a channel number that is divisible by 8
    It can be seen here:
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
    :param v:
    :param divisor:
    :param min_value:
    :return:
    """
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v


class ConvBNReLU(nn.Sequential):

    def __init__(self,
                 in_planes,
                 out_planes,
                 kernel_size=3,
                 stride=1,
                 groups=1):
        padding = (kernel_size - 1) // 2
        super(ConvBNReLU, self).__init__(
            nn.Conv2d(
                in_planes,
                out_planes,
                kernel_size,
                stride,
                padding,
                groups=groups,
                bias=False), nn.BatchNorm2d(out_planes),
            nn.ReLU6(inplace=True))


class InvertedResidual(nn.Module):

    def __init__(self, inp, oup, stride, expand_ratio):
        super(InvertedResidual, self).__init__()
        self.stride = stride
        assert stride in [1, 2]

        hidden_dim = int(round(inp * expand_ratio))
        self.use_res_connect = self.stride == 1 and inp == oup

        layers = []
        if expand_ratio != 1:
            # pw
            layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
        layers.extend([
            # dw
            ConvBNReLU(
                hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
            # pw-linear
            nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
            nn.BatchNorm2d(oup),
        ])
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        if self.use_res_connect:
            return x + self.conv(x)
        else:
            return self.conv(x)


class MobileNetV2(nn.Module):
    """
    MobileNetV2 is taken from pytorch hub.
    https://github.com/pytorch/vision/blob/master/torchvision/models/mobilenet.py
    """
    def __init__(self,
                 out_indices=(1, 2, 4, 6),
                 frozen_stages=-1,
                 width_mult=1.0,
                 inverted_residual_setting=None,
                 round_nearest=8):
        """
        MobileNet V2 main class
        Args:
            width_mult (float): Width multiplier - adjusts number of channels
            in each layer by this amount
            inverted_residual_setting: Network structure
            round_nearest (int): Round the number of channels in each layer to
            be a multiple of this number. Set to 1 to turn off rounding
        """
        super(MobileNetV2, self).__init__()
        block = InvertedResidual
        input_channel = 32

        if inverted_residual_setting is None:
            inverted_residual_setting = [
                # t, c, n, s
                [1, 16, 1, 1],  # 0
                [6, 24, 2, 2],  # 1
                [6, 32, 3, 2],  # 2
                [6, 64, 4, 2],  # 3
                [6, 96, 3, 1],  # 4
                [6, 160, 3, 2],  # 5
                [6, 320, 1, 1],  # 6
            ]

        # only check the first element,
        # assuming user knows t,c,n,s are required
        if len(inverted_residual_setting) == 0 or len(
                inverted_residual_setting[0]) != 4:
            raise ValueError("inverted_residual_setting should be non-empty "
                             "or a 4-element list, got {}".format(
                                 inverted_residual_setting))

        self.frozen_stages = frozen_stages
        self.out_indices = out_indices
        assert max(out_indices) < len(inverted_residual_setting)
        # building first layer
        input_channel = _make_divisible(input_channel * width_mult,
                                        round_nearest)
        self.conv1 = ConvBNReLU(3, input_channel, stride=2)
        # building inverted residual blocks
        self.stages = []
        for si, (t, c, n, s) in enumerate(inverted_residual_setting):
            output_channel = _make_divisible(c * width_mult, round_nearest)
            stage = []
            for i in range(n):
                stride = s if i == 0 else 1
                stage.append(
                    block(
                        input_channel, output_channel, stride, expand_ratio=t))
                input_channel = output_channel
            stage_name = 'stage{}'.format(si + 1)
            self.add_module(stage_name, nn.Sequential(*stage))
            self.stages.append(stage_name)

        self._freeze_stages()

    def _freeze_stages(self):
        if self.frozen_stages >= 0:
            self.conv1.eval()
            for param in self.conv1.parameters():
                param.requires_grad = False

        for i in range(1, self.frozen_stages + 1):
            m = getattr(self, 'stage{}'.format(i))
            m.eval()
            for param in m.parameters():
                param.requires_grad = False

    def init_weights(self, pretrained=None):
        if isinstance(pretrained, str):
            logger = logging.getLogger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            # weight initialization
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.kaiming_normal_(m.weight, mode='fan_out')
                    if m.bias is not None:
                        nn.init.zeros_(m.bias)
                elif isinstance(m, nn.BatchNorm2d):
                    nn.init.ones_(m.weight)
                    nn.init.zeros_(m.bias)
        else:
            raise TypeError('pretrained must be a str or None')

    def forward(self, x):
        x = self.conv1(x)
        outs = []
        for i, stage_name in enumerate(self.stages):
            stage = getattr(self, stage_name)
            x = stage(x)
            if i in self.out_indices:
                outs.append(x)
        return tuple(outs)


net = MobileNetV2()
net.cuda()
cfg = [{ 'sparsity': 0.7, 'op_types': ['Conv2d'] }]
pruner = L1FilterPruner(net, cfg)
pruner.compress()
pruner.export_model('./mobile_v2.pth', './mobile_v2_mask')


another = MobileNetV2()
another.cuda()
state_dict = torch.load('./mobile_v2.pth')
another.load_state_dict(state_dict)
data = torch.rand(1, 3, 224, 224).cuda()
ms = ModelSpeedup(another, data, './mobile_v2_mask')
ms.speedup_model()
```

So I suspect the problem you hit is caused by mmdet.

> So do you know how to fix my code, or is this a bug?

Because we can successfully speed up both the MobileNetV2 you gave me and the one in torchvision, I don't think this is caused by a bug in speedup. I can only see your code snippets. Without the complete code and environment, it is difficult to tell whether the problem comes from your code or from some conflict between mmdet and NNI speedup; for example, both frameworks may wrap or modify the model.

@YoungSharp
Author

Thanks a lot. I will look into mmdet and try to find the bug.

@kvartet
Contributor

kvartet commented Dec 7, 2020

@YoungSharp I’m closing this issue as it has had no updates from the user for 3 months. Please feel free to reopen it if you still see this as an active issue.

kvartet closed this as completed on Dec 7, 2020
kvartet added the close-as-overdue label on Dec 7, 2020
@guojunyao419

> Thanks a lot. I will look into mmdet and try to find the bug.

Hi @YoungSharp, I am trying to integrate NNI into mmdetection. May I ask whether you have successfully applied NNI compression/speedup to any object detection model in mmdetection? Thanks!


5 participants