[examples/summarization] deal with max_length and num_beams (huggingface#21740)

* Override the decoding parameters of Seq2SeqTrainer (see the sketch after this list)

* Fix quality

* Fix max_length parameter

* Fix quality

* Remove redundant parameter max_length

* Separate the preprocessing of the train and validation sets to use different max_target_length
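
A minimal sketch of how the first change is consumed, assuming a `model`, `tokenizer`, and `eval_dataset` have already been built as in the example script: once `generation_max_length` and `generation_num_beams` are set on the `Seq2SeqTrainingArguments`, `Seq2SeqTrainer` uses them for generation and `evaluate()`/`predict()` no longer need explicit decoding kwargs. The concrete values below are illustrative.

from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

# Decoding defaults now live on the training arguments instead of being
# threaded through every evaluate()/predict() call.
training_args = Seq2SeqTrainingArguments(
    output_dir="out",            # illustrative path
    predict_with_generate=True,
    generation_max_length=128,   # e.g. taken from val_max_target_length
    generation_num_beams=4,      # e.g. taken from --num_beams
)

trainer = Seq2SeqTrainer(
    model=model,                 # assumed to exist, as in the script
    args=training_args,
    eval_dataset=eval_dataset,   # assumed to exist, as in the script
    tokenizer=tokenizer,         # assumed to exist, as in the script
)

# No max_length / num_beams here: the defaults set above drive generation.
metrics = trainer.evaluate(metric_key_prefix="eval")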
bofenghuang authored and ArthurZucker committed Mar 2, 2023
1 parent f9cf164 commit 625cf9c
Showing 2 changed files with 27 additions and 25 deletions.
examples/pytorch/summarization/run_summarization.py (22 changes: 12 additions & 10 deletions)
@@ -639,6 +639,16 @@ def compute_metrics(eval_preds):
         result["gen_len"] = np.mean(prediction_lens)
         return result

+    # Override the decoding parameters of Seq2SeqTrainer
+    training_args.generation_max_length = (
+        training_args.generation_max_length
+        if training_args.generation_max_length is not None
+        else data_args.val_max_target_length
+    )
+    training_args.generation_num_beams = (
+        data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams
+    )
+
     # Initialize our Trainer
     trainer = Seq2SeqTrainer(
         model=model,
@@ -672,15 +682,9 @@ def compute_metrics(eval_preds):

     # Evaluation
     results = {}
-    max_length = (
-        training_args.generation_max_length
-        if training_args.generation_max_length is not None
-        else data_args.val_max_target_length
-    )
-    num_beams = data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams
     if training_args.do_eval:
         logger.info("*** Evaluate ***")
-        metrics = trainer.evaluate(max_length=max_length, num_beams=num_beams, metric_key_prefix="eval")
+        metrics = trainer.evaluate(metric_key_prefix="eval")
         max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
         metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))

@@ -690,9 +694,7 @@ def compute_metrics(eval_preds):
     if training_args.do_predict:
         logger.info("*** Predict ***")

-        predict_results = trainer.predict(
-            predict_dataset, metric_key_prefix="predict", max_length=max_length, num_beams=num_beams
-        )
+        predict_results = trainer.predict(predict_dataset, metric_key_prefix="predict")
         metrics = predict_results.metrics
         max_predict_samples = (
             data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset)
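
The simplified `predict` call still generates with the defaults set on the training arguments. A hedged sketch of decoding the generated token ids afterwards, mirroring what the example script does further down (`trainer`, `tokenizer`, and `predict_dataset` are assumed to exist as in the script):

import numpy as np

predict_results = trainer.predict(predict_dataset, metric_key_prefix="predict")

# Replace any -100 padding with the tokenizer's pad token before decoding.
predictions = np.where(
    predict_results.predictions != -100, predict_results.predictions, tokenizer.pad_token_id
)
decoded = tokenizer.batch_decode(predictions, skip_special_tokens=True)
print(decoded[:2])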
examples/pytorch/summarization/run_summarization_no_trainer.py (30 changes: 15 additions & 15 deletions)
@@ -161,15 +161,6 @@ def parse_args():
             "param of ``model.generate``, which is used during ``evaluate`` and ``predict``."
         ),
     )
-    parser.add_argument(
-        "--max_length",
-        type=int,
-        default=128,
-        help=(
-            "The maximum total input sequence length after tokenization. Sequences longer than this will be truncated,"
-            " sequences shorter will be padded if `--pad_to_max_lengh` is passed."
-        ),
-    )
     parser.add_argument(
         "--num_beams",
         type=int,
@@ -473,6 +464,9 @@ def main():
                 f"--summary_column' value '{args.summary_column}' needs to be one of: {', '.join(column_names)}"
             )

+    if args.val_max_target_length is None:
+        args.val_max_target_length = args.max_target_length
+
     # Temporarily set max_target_length for training.
     max_target_length = args.max_target_length
     padding = "max_length" if args.pad_to_max_length else False
@@ -497,7 +491,7 @@ def preprocess_function(examples):
         return model_inputs

     with accelerator.main_process_first():
-        processed_datasets = raw_datasets.map(
+        train_dataset = raw_datasets["train"].map(
             preprocess_function,
             batched=True,
             num_proc=args.preprocessing_num_workers,
@@ -506,8 +500,16 @@ def preprocess_function(examples):
             desc="Running tokenizer on dataset",
         )

-    train_dataset = processed_datasets["train"]
-    eval_dataset = processed_datasets["validation"]
+        # Temporarily set max_target_length for validation.
+        max_target_length = args.val_max_target_length
+        eval_dataset = raw_datasets["validation"].map(
+            preprocess_function,
+            batched=True,
+            num_proc=args.preprocessing_num_workers,
+            remove_columns=column_names,
+            load_from_cache_file=not args.overwrite_cache,
+            desc="Running tokenizer on dataset",
+        )

     # Log a few random samples from the training set:
     for index in random.sample(range(len(train_dataset)), 1):
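
The split above relies on `preprocess_function` reading `max_target_length` from the enclosing scope at call time, so reassigning the variable between the two `map()` calls changes the truncation length for the validation split only. A minimal, self-contained sketch of that late-binding behavior (no `datasets` dependency; the names are illustrative):

max_target_length = 128  # value used for the training split

def preprocess(example: str) -> str:
    # Reads max_target_length when called, not when defined.
    return example[:max_target_length]

train_out = preprocess("x" * 1000)   # truncated to 128 characters

max_target_length = 32  # temporarily lowered for the validation split
eval_out = preprocess("x" * 1000)    # truncated to 32 characters

assert len(train_out) == 128 and len(eval_out) == 32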
@@ -667,11 +669,9 @@ def postprocess_text(preds, labels):
                 break

         model.eval()
-        if args.val_max_target_length is None:
-            args.val_max_target_length = args.max_target_length

         gen_kwargs = {
-            "max_length": args.val_max_target_length if args is not None else config.max_length,
+            "max_length": args.val_max_target_length,
             "num_beams": args.num_beams,
         }
         for step, batch in enumerate(eval_dataloader):
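
Downstream, `gen_kwargs` is unpacked into `generate()` inside the evaluation loop. A hedged sketch of that call, following the pattern used elsewhere in the no_trainer script (the surrounding `args`, `accelerator`, `model`, and `eval_dataloader` are assumed to already exist):

import torch

model.eval()
gen_kwargs = {"max_length": args.val_max_target_length, "num_beams": args.num_beams}

for batch in eval_dataloader:
    with torch.no_grad():
        # Generation uses the validation max_target_length and beam count set above.
        generated_tokens = accelerator.unwrap_model(model).generate(
            batch["input_ids"],
            attention_mask=batch["attention_mask"],
            **gen_kwargs,
        )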
