retrain_attention_model.py
import os
import poutyne
from deepparse import download_from_public_repository
from deepparse.dataset_container import PickleDatasetContainer
from deepparse.parser import AddressParser
# First, let's download the train and test data with the new tags from the public repository.
saving_dir = "./data"
file_extension = "p"
training_dataset_name = "sample_incomplete_data"
test_dataset_name = "test_sample_data"
download_from_public_repository(training_dataset_name, saving_dir, file_extension=file_extension)
download_from_public_repository(test_dataset_name, saving_dir, file_extension=file_extension)
# Now let's create a training and test container.
training_container = PickleDatasetContainer(os.path.join(saving_dir, training_dataset_name + "." + file_extension))
test_container = PickleDatasetContainer(os.path.join(saving_dir, test_dataset_name + "." + file_extension))
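# Note: a PickleDatasetContainer expects the pickled file to hold a list of
# tuples (address, tags), where tags is the list of labels for each
# whitespace-separated token of the address.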
# We will retrain the BPEmb attention version of our pretrained model.
model = "bpemb"
address_parser = AddressParser(model_type=model, device=0, attention_mechanism=True)
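# device=0 selects the first GPU; pass device="cpu" instead if no GPU is available.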
# Now, let's retrain for 5 epochs using a batch size of 8 since the dataset for this example is quite small.
# Let's start with the default learning rate of 0.01 and use a learning rate scheduler to lower the learning rate
# as we progress.
lr_scheduler = poutyne.StepLR(step_size=1, gamma=0.1) # reduce LR by a factor of 10 each epoch
# The path to save our checkpoints
logging_path = "./checkpoints"
address_parser.retrain(
training_container,
train_ratio=0.8,
epochs=5,
batch_size=8,
num_workers=2,
callbacks=[lr_scheduler],
logging_path=logging_path,
layers_to_freeze="seq2seq",
)
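# During retraining, checkpoints are saved in the logging path at each epoch,
# and the best one (lowest validation loss on the held-out 20%) is tracked.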
# Now, let's test our fine-tuned model using the best checkpoint (default parameter).
address_parser.test(test_container, batch_size=256)
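
# As a final sanity check, we can parse an address with the retrained model.
# The address below is just an illustrative example, not part of the dataset.
parsed_address = address_parser("350 rue des Lilas Ouest Quebec city Quebec G1L 1B6")
print(parsed_address)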