"""
Copyright (c) 2019-present NAVER Corp.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import torch.nn as nn
from modules.transformation import TPS_SpatialTransformerNetwork
from modules.feature_extraction import VGG_FeatureExtractor, RCNN_FeatureExtractor, ResNet_FeatureExtractor
from modules.sequence_modeling import BidirectionalLSTM
from modules.prediction import Attention
from modules.resnet_aster import ResNet_ASTER
from modules.bert import Bert_Ocr
from modules.bert import Config
from modules.SRN_modules import Transforme_Encoder, SRN_Decoder, Torch_transformer_encoder
from modules.resnet_fpn import ResNet_FPN
class Model(nn.Module):
    """Four-stage text recognition model: Transformation -> FeatureExtraction -> SequenceModeling -> Prediction."""

    def __init__(self, opt):
        super(Model, self).__init__()
        self.opt = opt
        self.stages = {'Trans': opt.Transformation, 'Feat': opt.FeatureExtraction,
                       'Seq': opt.SequenceModeling, 'Pred': opt.Prediction}

        """ Transformation """
        if opt.Transformation == 'TPS':
            self.Transformation = TPS_SpatialTransformerNetwork(
                F=opt.num_fiducial, I_size=(opt.imgH, opt.imgW), I_r_size=(opt.imgH, opt.imgW), I_channel_num=opt.input_channel)
        else:
            print('No Transformation module specified')

        """ FeatureExtraction """
        if opt.FeatureExtraction == 'VGG':
            self.FeatureExtraction = VGG_FeatureExtractor(opt.input_channel, opt.output_channel)
        elif opt.FeatureExtraction == 'RCNN':
            self.FeatureExtraction = RCNN_FeatureExtractor(opt.input_channel, opt.output_channel)
        elif opt.FeatureExtraction == 'ResNet':
            self.FeatureExtraction = ResNet_FeatureExtractor(opt.input_channel, opt.output_channel)
            self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((None, 1))  # Transform final (imgH/16-1) -> 1
        elif opt.FeatureExtraction == 'AsterRes':
            self.FeatureExtraction = ResNet_ASTER(opt.input_channel, opt.output_channel)
        elif opt.FeatureExtraction == 'ResnetFpn':
            self.FeatureExtraction = ResNet_FPN()
        else:
            raise Exception('No FeatureExtraction module specified')
        self.FeatureExtraction_output = opt.output_channel  # int(imgH/16-1) * 512

        """ Sequence modeling """
        if opt.SequenceModeling == 'BiLSTM':
            self.SequenceModeling = nn.Sequential(
                BidirectionalLSTM(self.FeatureExtraction_output, opt.hidden_size, opt.hidden_size),
                BidirectionalLSTM(opt.hidden_size, opt.hidden_size, opt.hidden_size))
            self.SequenceModeling_output = opt.hidden_size
        elif opt.SequenceModeling == 'Bert':
            cfg = Config()
            cfg.dim = opt.output_channel
            cfg.dim_c = opt.output_channel  # reduce dimensionality to cut computation
            cfg.p_dim = opt.position_dim  # length of the CNN-encoded feature sequence for one image
            cfg.max_vocab_size = opt.batch_max_length + 1  # maximum number of characters in one image, +1 for EOS
            cfg.len_alphabet = opt.alphabet_size  # number of character classes
            self.SequenceModeling = Bert_Ocr(cfg)
        elif opt.SequenceModeling == 'SRN':
            self.SequenceModeling = Transforme_Encoder(n_layers=2, n_position=opt.position_dim)
            # self.SequenceModeling = Torch_transformer_encoder(n_layers=2, n_position=opt.position_dim)
            self.SequenceModeling_output = 512
        else:
            print('No SequenceModeling module specified')
            self.SequenceModeling_output = self.FeatureExtraction_output
""" Prediction """
if opt.Prediction == 'CTC':
self.Prediction = nn.Linear(self.SequenceModeling_output, opt.num_class)
elif opt.Prediction == 'Attn':
self.Prediction = Attention(self.SequenceModeling_output, opt.hidden_size, opt.num_class)
elif opt.Prediction == 'Bert_pred':
pass
elif opt.Prediction == 'SRN':
self.Prediction = SRN_Decoder(n_position=opt.position_dim, N_max_character=opt.batch_max_character + 1, n_class=opt.alphabet_size)
else:
raise Exception('Prediction is neither CTC or Attn')
    def forward(self, input, text, is_train=True):
        """ Transformation stage """
        if self.stages['Trans'] != "None":
            input = self.Transformation(input)

        """ Feature extraction stage """
        visual_feature = self.FeatureExtraction(input)
        if self.stages['Feat'] == 'AsterRes' or self.stages['Feat'] == 'ResnetFpn':
            b, c, h, w = visual_feature.shape
            visual_feature = visual_feature.permute(0, 1, 3, 2)
            visual_feature = visual_feature.contiguous().view(b, c, -1)
            visual_feature = visual_feature.permute(0, 2, 1)  # batch, seq, feature
        else:
            visual_feature = self.AdaptiveAvgPool(visual_feature.permute(0, 3, 1, 2))  # [b, c, h, w] -> [b, w, c, h]
            visual_feature = visual_feature.squeeze(3)
""" Sequence modeling stage """
if self.stages['Seq'] == 'BiLSTM':
contextual_feature = self.SequenceModeling(visual_feature)
elif self.stages['Seq'] == 'Bert':
pad_mask = text
contextual_feature = self.SequenceModeling(visual_feature, pad_mask)
elif self.stages['Seq'] == 'SRN':
contextual_feature = self.SequenceModeling(visual_feature, src_mask=None)[0]
else:
contextual_feature = visual_feature # for convenience. this is NOT contextually modeled by BiLSTM
""" Prediction stage """
if self.stages['Pred'] == 'CTC':
prediction = self.Prediction(contextual_feature.contiguous())
elif self.stages['Pred'] == 'Bert_pred':
prediction = contextual_feature
elif self.stages['Pred'] == 'SRN':
prediction = self.Prediction(contextual_feature)
else:
prediction = self.Prediction(contextual_feature.contiguous(), text, is_train, batch_max_length=self.opt.batch_max_length)
return prediction
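

# --- Usage sketch (not part of the original file) ---------------------------
# A minimal smoke test for the ResNet + BiLSTM + CTC configuration, assuming
# the `modules` package from this repository is importable and that the option
# names below (input_channel, output_channel, hidden_size, num_class and the
# 32x100 input size) match those produced by the project's argument parser;
# the concrete values are illustrative assumptions, not repository defaults.
if __name__ == '__main__':
    import torch
    from types import SimpleNamespace

    opt = SimpleNamespace(
        Transformation='None', FeatureExtraction='ResNet',
        SequenceModeling='BiLSTM', Prediction='CTC',
        input_channel=1, output_channel=512, hidden_size=256,
        num_class=37)  # num_class assumed: alphabet size + 1 blank for CTC

    model = Model(opt)
    image = torch.randn(2, opt.input_channel, 32, 100)  # [batch, channel, H, W]
    preds = model(image, text=None)  # the CTC branch does not use `text`
    print(preds.shape)  # expected: [batch, sequence_length, opt.num_class]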