Skip to content

Commit

Permalink
Feat/add new tags to retrain cli (#167)
Browse files Browse the repository at this point in the history
* add missing import in init

* add feature to allow new_prediction_tags in retrain CLI API
  • Loading branch information
davebulaval authored Nov 24, 2022
1 parent 1b536f7 commit 6de6511
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 5 deletions.
3 changes: 3 additions & 0 deletions deepparse/cli/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# pylint: disable=wildcard-import
from .download_model import *
from .download_models import *
from .parse import *
from .parser_arguments_adder import *
from .retrain import *
from .test import *
from .tools import *
29 changes: 29 additions & 0 deletions deepparse/cli/retrain.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import argparse
import json
import sys
from typing import Dict

Expand Down Expand Up @@ -45,6 +46,21 @@ def parse_retrained_arguments(parsed_args) -> Dict:
return parsed_retain_arguments


def handle_prediction_tags(parsed_args):
dict_parsed_args = vars(parsed_args)
path = dict_parsed_args.get("prediction_tags")

tags_dict_arguments = {"prediction_tags": None} # Default case

if path is not None:
with open(path, "r", encoding="UTF-8") as file:
prediction_tags = json.load(file)
if "EOS" not in prediction_tags.keys():
raise ValueError("The prediction tags dictionary is missing the EOS tag.")
tags_dict_arguments.update({"prediction_tags": prediction_tags})
return tags_dict_arguments


def main(args=None) -> None:
# pylint: disable=too-many-locals, too-many-branches
"""
Expand Down Expand Up @@ -115,6 +131,9 @@ def main(args=None) -> None:

address_parser = AddressParser(**parser_args)

new_tags_parser_args_update_args = handle_prediction_tags(parsed_args)
parser_args.update(**new_tags_parser_args_update_args)

parsed_retain_arguments = parse_retrained_arguments(parsed_args)

address_parser.retrain(
Expand Down Expand Up @@ -217,6 +236,16 @@ def get_parser() -> argparse.ArgumentParser:
default=None,
type=str,
)
parser.add_argument(
"--prediction_tags",
help=wrap(
"Path to a JSON file of prediction tags to use to retrain. Tags are in a key-value style, where "
"the key is the tag name, and the value is the index one."
"The last element has to be an EOS tag. Read the doc for more detail about EOS tag."
),
default=None,
type=str,
)

add_seed_arg(parser)

Expand Down
3 changes: 2 additions & 1 deletion docs/source/cli.rst
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,11 @@ One can use the command ``parse --help`` to output the same description in your
- ``--csv_column_names``: The column names to extract address in the CSV. Need to be specified if the provided dataset_path leads to a CSV file. Column names have to be separated by whitespace. For example, ``--csv_column_names column1 column2``.
- ``--csv_column_separator``: The column separator for the dataset container will only be used if the dataset is a CSV one. By default, ``'\t'``.
- ``--cache_dir``: To change the default cache directory (default to ``None``, e.g. default path).
- ``prediction_tags``: To change the prediction tags. The ``prediction_tags`` is a path leading to a JSON file of the new tags in a key-value style. For example, the path can be ``"a_path/file.json"`` and the content can be ``{"new_tag": 0, "other_tag": 1, "EOS": 2}``

.. autofunction:: deepparse.cli.retrain.main

We do not handle the ``seq2seq_params`` and ``prediction_tags`` fine-tuning argument for now.
We do not handle the ``seq2seq_params`` fine-tuning argument for now.

Test
****
Expand Down
47 changes: 43 additions & 4 deletions tests/cli/test_retrain.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
# pylint: disable=too-many-arguments, too-many-locals

# Pylint error for TemporaryDirectory ask for with statement
# pylint: disable=consider-using-with

import json
import os
import unittest
from tempfile import TemporaryDirectory
Expand All @@ -14,6 +11,27 @@
from deepparse.cli.retrain import get_args, parse_retrained_arguments
from tests.parser.integration.base_retrain import RetrainTestCase

# Pylint error for TemporaryDirectory ask for with statement
# pylint: disable=consider-using-with

_new_tag_set = {
"primary": 0,
"pre": 1,
"street": 2,
"suffix": 3,
"post": 4,
"secdes": 5,
"secnum": 6,
"extsecdes": 7,
"extsecnum": 8,
"pmbdes": 9,
"pmbnum": 10,
"city": 11,
"state": 12,
"zipcode": 13,
"EOS": 14,
}


@skipIf(
not os.path.exists(os.path.join(os.path.expanduser("~"), ".cache", "deepparse", "cc.fr.300.bin")),
Expand Down Expand Up @@ -46,6 +64,11 @@ def setUp(self) -> None:
# temp directory
self.logging_path = os.path.join(self.temp_checkpoints_obj.name, "checkpoints")

self.path_to_new_tags_set = os.path.join(self.temp_checkpoints_obj.name, "a_new_tag_set.json")

with open(self.path_to_new_tags_set, "w", encoding="UTF-8") as file:
json.dump(_new_tag_set, file, ensure_ascii=False)

def tearDown(self) -> None:
self.temp_checkpoints_obj.cleanup()

Expand All @@ -68,6 +91,7 @@ def set_up_params(
device="cpu", # By default, we set it to cpu instead of gpu device 0 as the CLI function.
csv_column_names: List = None,
csv_column_separator="\t",
prediction_tags=None, # By default, we do not set a new tags set
) -> List:
if model_type is None:
# The default case for the test is a FastText model
Expand Down Expand Up @@ -123,6 +147,10 @@ def set_up_params(
parser_params.extend(["--csv_column_names"])
parser_params.extend(csv_column_names) # Since csv_column_names is a list

if prediction_tags is not None:
# To handle the None case (that is using the default None of the argparser).
parser_params.extend(["--prediction_tags", prediction_tags])

return parser_params

def test_integration_cpu(self):
Expand Down Expand Up @@ -280,6 +308,17 @@ def test_integrationWithValDataset(self):
)
)

def test_integrationWithNewTagsSet(self):
parser_params = self.set_up_params(prediction_tags=self.path_to_new_tags_set)

retrain.main(parser_params)

self.assertTrue(
os.path.isfile(
os.path.join(self.temp_checkpoints_obj.name, "checkpoints", "retrained_fasttext_address_parser.ckpt")
)
)


if __name__ == "__main__":
unittest.main()

0 comments on commit 6de6511

Please sign in to comment.