get_flops.py
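"""Report the parameter count and multiply-accumulate (MAC) cost of a
Video-FocalNet classification model, profiled with thop on a single
1 x 8 x 3 x 224 x 224 input clip."""
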
import os
import time
import argparse
import datetime

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
from timm.utils import accuracy
from timm.models.layers import trunc_normal_
from thop import profile, clever_format

from config import get_config
from classification import build_model
from datasets.build import build_dataloader
from datasets.blending import CutmixMixupBlending
from lr_scheduler import build_scheduler
from optimizer import build_optimizer
from logger import create_logger
from utils import AverageMeter, load_checkpoint, save_checkpoint, get_grad_norm, auto_resume_helper, reduce_tensor
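
# Run the profiling pass on the first GPU if available, otherwise fall back to the CPU.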
device = "cuda:0" if torch.cuda.is_available() else "cpu"
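

# The argument parser mirrors the main training/evaluation script so the same
# config files and command-line overrides can be reused when measuring FLOPs.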
def parse_option():
    parser = argparse.ArgumentParser('FocalNet training and evaluation script', add_help=False)
    parser.add_argument('--cfg', type=str, required=False, metavar="FILE", help='path to config file',
                        default='./configs/kinetics400/video-focalnet_tiny.yaml')
    parser.add_argument(
        "--opts",
        help="Modify config options by adding 'KEY VALUE' pairs.",
        default=None,
        nargs='+',
    )

    # easy config modification
    parser.add_argument('--batch-size', type=int, help="batch size for single GPU")
    parser.add_argument('--resume', help='resume from checkpoint')
    parser.add_argument('--accumulation-steps', type=int, help="gradient accumulation steps")
    parser.add_argument('--use-checkpoint', action='store_true',
                        help="whether to use gradient checkpointing to save memory")
    parser.add_argument('--amp-opt-level', type=str, default='O1', choices=['O0', 'O1', 'O2'],
                        help='mixed precision opt level, if O0, no amp is used')
    parser.add_argument('--output', default='output', type=str, metavar='PATH',
                        help='root of output folder')
    parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
    parser.add_argument("--local_rank", type=int, default=0, help='local rank for DistributedDataParallel')

    args, unparsed = parser.parse_known_args()
    config = get_config(args)
    return args, config


_, config = parse_option()

# Profile with an 8-frame clip, overriding whatever the config specifies.
config.defrost()
config.DATA.NUM_FRAMES = 8
config.freeze()

model = build_model(config).to(device)

# Dummy input: batch of 1, 8 frames, 3 RGB channels, 224x224 resolution.
data = torch.randn(1, 8, 3, 224, 224).to(device)
macs, params = profile(model, inputs=(data,))
macs, params = clever_format([macs, params], "%.3f")

# thop counts multiply-accumulates (MACs); FLOPs are roughly twice this value.
print("macs:", macs)
print("params:", params)
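
# Example invocation (assuming the default Kinetics-400 config shipped with the repo):
#   python get_flops.py --cfg ./configs/kinetics400/video-focalnet_tiny.yaml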