Resize vocab and classifier automatically for taggers
hanlpbot committed Oct 19, 2023
1 parent 9579160 commit 19eb67a
Showing 7 changed files with 90 additions and 14 deletions.
15 changes: 8 additions & 7 deletions hanlp/common/torch_component.py
@@ -248,6 +248,8 @@ def fit(self,
self.load(finetune, devices=devices)
else:
self.load(save_dir, devices=devices)
self.config.finetune = finetune
self.vocabs.unlock() # For extending vocabs
logger.info(
f'Finetune model loaded with {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}'
f'/{sum(p.numel() for p in self.model.parameters())} trainable/total parameters.')
@@ -258,13 +260,12 @@ def fit(self,
dev = self.build_dataloader(**merge_dict(config, data=dev_data, batch_size=batch_size, shuffle=False,
training=None, device=first_device, logger=logger, vocabs=self.vocabs,
overwrite=True)) if dev_data else None
if not finetune:
flash('[yellow]Building model [blink]...[/blink][/yellow]')
self.model = self.build_model(**merge_dict(config, training=True))
flash('')
logger.info(f'Model built with {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}'
f'/{sum(p.numel() for p in self.model.parameters())} trainable/total parameters.')
assert self.model, 'build_model is not properly implemented.'
flash('[yellow]Building model [blink]...[/blink][/yellow]')
self.model = self.build_model(**merge_dict(config, training=True), logger=logger)
flash('')
logger.info(f'Model built with {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}'
f'/{sum(p.numel() for p in self.model.parameters())} trainable/total parameters.')
assert self.model, 'build_model is not properly implemented.'
_description = repr(self.model)
if len(_description.split('\n')) < 10:
logger.info(_description)
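Taken together, the fine-tuning branch of fit() now reduces to roughly the following (a condensed, non-verbatim sketch of the logic above; argument lists are elided):

    if finetune:
        self.load(finetune, devices=devices)   # load the pretrained checkpoint
        self.config.finetune = finetune        # remember that this run is a fine-tune
        self.vocabs.unlock()                   # let the new corpus extend the vocabs
    trn = self.build_dataloader(...)           # may add unseen tags to the unlocked vocabs
    # build_model is now always called, even when fine-tuning, and receives the logger,
    # so subclasses can rebuild the classifier at the new size and report what changed
    self.model = self.build_model(**merge_dict(config, training=True), logger=logger)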
10 changes: 9 additions & 1 deletion hanlp/common/transform.py
@@ -162,12 +162,20 @@ def _load_vocabs(vd, vocabs: dict, vocab_cls=Vocab):

def lock(self):
"""
Lock each vocabs.
Lock each vocab.
"""
for key, value in self.items():
if isinstance(value, Vocab):
value.lock()

def unlock(self):
"""
Unlock each vocab.
"""
for key, value in self.items():
if isinstance(value, Vocab):
value.unlock()

@property
def mutable(self):
status = [v.mutable for v in self.values() if isinstance(v, Vocab)]
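The new unlock() is the counterpart of lock() and is what fit() calls before reading a fine-tuning corpus. A minimal usage sketch, written as it would appear inside a component (it assumes Vocab exposes an add() method; in practice the training pipeline adds the new labels while scanning the corpus, not user code):

    self.vocabs.unlock()            # every Vocab value becomes mutable again
    self.vocabs.tag.add('B-NLP')    # a label the pretrained checkpoint has never seen
    self.vocabs.lock()              # freeze again before the resized classifier is built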
2 changes: 1 addition & 1 deletion hanlp/components/ner/transformer_ner.py
@@ -141,7 +141,7 @@ def fit(self, trn_data, dev_data, save_dir, transformer,
sampler_builder: SamplerBuilder = None,
epochs=3,
tagset=None,
token_key=None,
token_key='token',
max_seq_len=None,
sent_delimiter=None,
char_level=False,
21 changes: 18 additions & 3 deletions hanlp/components/taggers/transformers/transformer_tagger.py
@@ -20,7 +20,7 @@
from hanlp.layers.transformers.encoder import TransformerEncoder
from hanlp.transform.transformer_tokenizer import TransformerSequenceTokenizer
from hanlp.utils.time_util import CountdownTimer
from hanlp.utils.torch_util import clip_grad_norm, lengths_to_mask
from hanlp.utils.torch_util import clip_grad_norm, lengths_to_mask, filter_state_dict_safely
from hanlp_common.util import merge_locals_kwargs


@@ -142,14 +142,28 @@ def compute_distill_loss(self, kd_criterion, out_S, out_T, mask, temperature_sch
temperature = temperature_scheduler(logits_S, logits_T)
return kd_criterion(logits_S, logits_T, temperature)

def build_model(self, training=True, extra_embeddings: Embedding = None, **kwargs) -> torch.nn.Module:
def build_model(self, training=True, extra_embeddings: Embedding = None, finetune=False, logger=None,
**kwargs) -> torch.nn.Module:
model = TransformerTaggingModel(
self.build_transformer(training=training),
len(self.vocabs.tag),
self.config.crf,
self.config.get('secondary_encoder', None),
extra_embeddings=extra_embeddings.module(self.vocabs) if extra_embeddings else None,
)
if finetune:
model_state = model.state_dict()
load_state = self.model.state_dict()
safe_state = filter_state_dict_safely(model_state, load_state)
missing_params = model_state.keys() - safe_state.keys()
if missing_params:
logger.info(f'The following parameters were missing from the checkpoint: '
f'{", ".join(sorted(missing_params))}.')
model.load_state_dict(safe_state, strict=False)
n = self.model.classifier.bias.size(0)
if model.classifier.bias.size(0) != n:
model.classifier.weight.data[:n, :] = self.model.classifier.weight.data[:n, :]
model.classifier.bias.data[:n] = self.model.classifier.bias.data[:n]
return model

# noinspection PyMethodOverriding
@@ -203,7 +217,8 @@ def tokenizer_transform(self) -> TransformerSequenceTokenizer:
return self._tokenizer_transform

def build_vocabs(self, trn, logger, **kwargs):
self.vocabs.tag = Vocab(pad_token=None, unk_token=None)
if 'tag' not in self.vocabs:
self.vocabs.tag = Vocab(pad_token=None, unk_token=None)
timer = CountdownTimer(len(trn))
max_seq_len = 0
token_key = self.config.token_key
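The classifier-resizing step in build_model() boils down to the following standalone pattern (an illustrative sketch with made-up sizes, not HanLP code): shape-compatible parameters are loaded as-is, and the old classifier's rows are copied into the front of the new, larger classifier, so existing tags keep their learned weights while newly added tags start from random initialization.

    import torch

    old_classifier = torch.nn.Linear(256, 4)   # pretrained model: 4 tags
    new_classifier = torch.nn.Linear(256, 6)   # the fine-tuning corpus added 2 tags

    n = old_classifier.bias.size(0)            # number of tags the checkpoint knows
    with torch.no_grad():
        new_classifier.weight[:n, :] = old_classifier.weight
        new_classifier.bias[:n] = old_classifier.bias
    # rows n and above stay randomly initialized and are learned from the new corpus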
11 changes: 10 additions & 1 deletion hanlp/utils/torch_util.py
@@ -283,4 +283,13 @@ def lengths_to_mask(seq_len, max_len=None):


def activation_from_name(name: str):
return getattr(torch.nn, name)
return getattr(torch.nn, name)


def filter_state_dict_safely(model_state: dict, load_state: dict):
safe_state = dict()
for k, v in load_state.items():
model_v = model_state.get(k, None)
if model_v is not None and model_v.shape == v.shape:
safe_state[k] = v
return safe_state
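A minimal usage sketch of filter_state_dict_safely (the toy modules are illustrative, not HanLP models): entries whose key is missing from the target model or whose shape differs are dropped, so the subsequent load_state_dict(strict=False) never raises on a resized layer.

    import torch
    from hanlp.utils.torch_util import filter_state_dict_safely

    checkpoint_model = torch.nn.Sequential(
        torch.nn.Linear(8, 8),      # "encoder": same shape in both models
        torch.nn.Linear(8, 4),      # old classifier head: 4 tags
    )
    new_model = torch.nn.Sequential(
        torch.nn.Linear(8, 8),
        torch.nn.Linear(8, 6),      # resized classifier head: 6 tags
    )

    safe = filter_state_dict_safely(new_model.state_dict(), checkpoint_model.state_dict())
    # only the shape-compatible encoder parameters ('0.weight' and '0.bias') survive
    new_model.load_state_dict(safe, strict=False)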
2 changes: 1 addition & 1 deletion hanlp/version.py
@@ -2,7 +2,7 @@
# Author: hankcs
# Date: 2019-12-28 19:26

__version__ = '2.1.0-beta.51'
__version__ = '2.1.0-beta.52'
"""HanLP version"""


43 changes: 43 additions & 0 deletions plugins/hanlp_demo/hanlp_demo/zh/train/finetune_ner.py
@@ -0,0 +1,43 @@
# -*- coding:utf-8 -*-
# Author: hankcs
# Date: 2023-10-18 18:49
import os

import hanlp
from hanlp.components.ner.transformer_ner import TransformerNamedEntityRecognizer
from tests import cdroot

cdroot()

your_training_corpus = 'data/ner/finetune/word_to_iobes.tsv'
your_development_corpus = your_training_corpus  # Use a separate development corpus in practice
save_dir = 'data/ner/finetune/model'

if not os.path.exists(your_training_corpus):
os.makedirs(os.path.dirname(your_training_corpus), exist_ok=True)
with open(your_training_corpus, 'w') as out:
out.write(
'''训练\tB-NLP
语料\tE-NLP
为\tO
IOBES\tO
格式\tO
'''
)

ner = TransformerNamedEntityRecognizer()
ner.fit(
trn_data=your_training_corpus,
dev_data=your_development_corpus,
save_dir=save_dir,
epochs=50, # Since the corpus is small, overfit it
finetune=hanlp.pretrained.ner.MSRA_NER_ELECTRA_SMALL_ZH,
# You MUST use the same settings as the model being fine-tuned:
average_subwords=True,
transformer='hfl/chinese-electra-180g-small-discriminator',
)

HanLP = hanlp.pipeline()\
.append(hanlp.load(hanlp.pretrained.tok.FINE_ELECTRA_SMALL_ZH), output_key='tok')\
.append(ner, output_key='ner')
HanLP(['训练语料为IOBES格式', '晓美焰来到北京立方庭参观自然语义科技公司。']).pretty_print()
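Once fit() finishes, the fine-tuned tagger persists to save_dir and can be reloaded later without retraining. A hedged sketch (component-level load() from a local directory is the same mechanism fit() itself uses; the pre-tokenized input format is illustrative):

    ner_reloaded = TransformerNamedEntityRecognizer()
    ner_reloaded.load(save_dir)
    print(ner_reloaded([['训练', '语料', '为', 'IOBES', '格式']]))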
