Skip to content

Commit

Permalink
[AutoNLP]add predict (#4967)
Browse files Browse the repository at this point in the history
* add_predict

* fix

* fix
  • Loading branch information
lugimzzz authored Feb 24, 2023
1 parent edf6fef commit f38a255
Show file tree
Hide file tree
Showing 4 changed files with 341 additions and 198 deletions.
18 changes: 14 additions & 4 deletions paddlenlp/experimental/autonlp/auto_trainer_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ def _preprocess_fn(
preprocess an example from raw features to input features that Transformers models expect (e.g. input_ids, attention_mask, labels, etc)
"""

@abstractmethod
def export(self, export_path, trial_id=None):
"""
Export the model from a certain `trial_id` to the given file path.
Expand Down Expand Up @@ -171,15 +172,24 @@ def evaluate(self, trial_id=None, eval_dataset=None) -> Dict[str, float]:
"""
raise NotImplementedError

@abstractmethod
def predict(self, test_dataset: Dataset, trial_id: Optional[str] = None):
"""
Run prediction and returns predictions and potential metrics from a certain `trial_id` on the given dataset
Args:
test_dataset (Dataset, required): Custom test dataset and must contains the 'text_column' and 'label_column' fields.
trial_id (str, optional): Specify the model to be evaluated through the `trial_id`. Defaults to the best model selected by `metric_for_best_model`.
"""
raise NotImplementedError

def _override_hp(self, config: Dict[str, Any], default_hp: Any) -> Any:
"""
Overrides the arguments with the provided hyperparameter config
"""
new_hp = copy.deepcopy(default_hp)
for key, value in config.items():
if key.startswith(default_hp.__class__.__name__):
_, hp_key = key.split(".")
setattr(new_hp, hp_key, value)
if key in new_hp.to_dict():
setattr(new_hp, key, value)
return new_hp

def _filter_model_candidates(
Expand Down Expand Up @@ -264,7 +274,7 @@ def train(
experiment_name: (str, optional): name of the experiment. Experiment log will be stored under <output_dir>/<experiment_name>.
Defaults to UNIX timestamp.
hp_overrides: (dict[str, Any], optional): Advanced users only.
override the hyperparameters of every model candidate. For example, {"TrainingArguments.max_steps": 5}.
override the hyperparameters of every model candidate. For example, {"max_steps": 5}.
custom_model_candiates: (dict[str, Any], optional): Advanced users only.
Run the user-provided model candidates instead of the default model candidated from PaddleNLP. See `._model_candidates` property as an example
Expand Down
Loading

0 comments on commit f38a255

Please sign in to comment.