Skip to content

Commit

Permalink
add code for multi gpus retrain and code examples
Browse files Browse the repository at this point in the history
  • Loading branch information
davebulaval committed Oct 18, 2024
1 parent 48dde6f commit fb097de
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 1 deletion.
5 changes: 4 additions & 1 deletion deepparse/parser/address_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,7 @@ def retrain(
layers_to_freeze: Union[str, None] = None,
name_of_the_retrain_parser: Union[None, str] = None,
verbose: Union[None, bool] = None,
retrain_device: Union[None, str, int, List[int]] = None,
) -> List[Dict]:
# pylint: disable=too-many-arguments, too-many-locals, too-many-branches, too-many-statements

Expand Down Expand Up @@ -773,11 +774,13 @@ def retrain(

optimizer = SGD(self.model.parameters(), learning_rate)

device = retrain_device if retrain_device is not None else self.device

# Poutyne handle model.train()
exp = Experiment(
logging_path,
self.model,
device=self.device,
device=device,
optimizer=optimizer,
loss_function=nll_loss,
batch_metrics=[accuracy],
Expand Down
61 changes: 61 additions & 0 deletions examples/fine_tuning_multi_gpus.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import os
import poutyne

from deepparse import download_from_public_repository
from deepparse.dataset_container import PickleDatasetContainer
from deepparse.parser import AddressParser

def main() -> None:
    """Fine-tune the pretrained FastText address parsers on multiple GPUs.

    The example downloads a small training and test dataset, retrains the
    FastText model on GPUs 0 and 1, evaluates it, and then repeats the process
    for the FastText model with an attention mechanism using ``"all"`` as the
    retrain device.
    """
    # First, let's download the train and test data from the public repository.
    saving_dir = "./data"
    file_extension = "p"
    training_dataset_name = "sample_incomplete_data"
    test_dataset_name = "test_sample_data"
    download_from_public_repository(training_dataset_name, saving_dir, file_extension=file_extension)
    download_from_public_repository(test_dataset_name, saving_dir, file_extension=file_extension)

    # Now let's create a training and test container.
    training_container = PickleDatasetContainer(os.path.join(saving_dir, training_dataset_name + "." + file_extension))
    test_container = PickleDatasetContainer(os.path.join(saving_dir, test_dataset_name + "." + file_extension))

    # We will retrain the FastText version of our pretrained model.
    # The parser itself lives on GPU 0; the devices used for retraining are given
    # separately through `retrain_device` below.
    address_parser = AddressParser(model_type="fasttext", device=0)

    # Now, let's retrain for 5 epochs using a batch size of 8 since the data is really small for the example.
    # Let's start with the default learning rate of 0.01 and use a learning rate scheduler to lower the learning rate
    # as we progress.
    lr_scheduler = poutyne.StepLR(step_size=1, gamma=0.1)  # reduce LR by a factor of 10 each epoch

    # The checkpoints (ckpt) are saved in the default "./checkpoints" directory, so if you wish to retrain
    # another model (let's say BPEmb), you need to change the `logging_path` directory; otherwise, you will get
    # an error when retraining since Poutyne will try to use the last checkpoint.
    #
    # `retrain_device` overrides the parser's own device for the retraining only; here we explicitly
    # list the GPU indices to train on.
    address_parser.retrain(
        training_container,
        retrain_device=[0, 1],
        train_ratio=0.8,
        epochs=5,
        batch_size=8,
        num_workers=2,
        callbacks=[lr_scheduler],
    )

    # Now, let's test our fine-tuned model using the best checkpoint (default parameter).
    address_parser.test(test_container, batch_size=256)

    # Now let's retrain the FastText version but with an attention mechanism.
    address_parser = AddressParser(model_type="fasttext", device=0, attention_mechanism=True)

    # Since the previous checkpoints were saved in the default "./checkpoints" directory, we need to use a new one.
    # Otherwise, poutyne will try to reload the previous checkpoints, and our model has changed.
    #
    # NOTE(review): "all" is forwarded to Poutyne as the device; presumably it selects every
    # available GPU -- confirm against the Poutyne `Experiment` documentation.
    address_parser.retrain(
        training_container,
        retrain_device="all",
        train_ratio=0.8,
        epochs=5,
        batch_size=8,
        num_workers=2,
        callbacks=[lr_scheduler],
        logging_path="checkpoints_attention",
    )

    # Now, let's test our fine-tuned model using the best checkpoint (default parameter).
    address_parser.test(test_container, batch_size=256)


# The guard matters because `num_workers=2` presumably starts data-loading worker processes:
# on platforms that spawn (rather than fork) subprocesses, each worker re-imports this module,
# and without the guard it would re-run the downloads and the retraining.
if __name__ == "__main__":
    main()

0 comments on commit fb097de

Please sign in to comment.