Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Trainer: support parsing JSON files and command-line arguments simultaneously. #7768

Merged
merged 10 commits into from
Jan 10, 2024
8 changes: 4 additions & 4 deletions llm/finetune_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,10 @@ def read_local_dataset(path):
def main():
# Arguments
parser = PdArgumentParser((GenerateArgument, QuantArgument, ModelArgument, DataArgument, TrainingArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
gen_args, quant_args, model_args, data_args, training_args = parser.parse_json_file(
json_file=os.path.abspath(sys.argv[1])
)
# Supports the format "args.json --arg1 value1 --arg2 value2".
# In case of conflict, command line arguments take precedence.
if len(sys.argv) >= 2 and sys.argv[1].endswith(".json"):
gen_args, quant_args, model_args, data_args, training_args = parser.parse_json_file_and_cmd_lines()
else:
gen_args, quant_args, model_args, data_args, training_args = parser.parse_args_into_dataclasses()
training_args.print_config(model_args, "Model")
Expand Down
6 changes: 4 additions & 2 deletions llm/run_pretrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,8 +372,10 @@ def _get_train_sampler(self) -> Optional[paddle.io.Sampler]:

def main():
parser = PdArgumentParser((ModelArguments, DataArguments, PreTrainingArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
# Supports the format "args.json --arg1 value1 --arg2 value2".
# In case of conflict, command line arguments take precedence.
if len(sys.argv) >= 2 and sys.argv[1].endswith(".json"):
model_args, data_args, training_args = parser.parse_json_file_and_cmd_lines()
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()

Expand Down
41 changes: 41 additions & 0 deletions paddlenlp/trainer/argparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import dataclasses
import json
import sys
import warnings
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser, ArgumentTypeError
from copy import copy
from enum import Enum
Expand Down Expand Up @@ -247,6 +248,46 @@
outputs.append(obj)
return (*outputs,)

def parse_json_file_and_cmd_lines(self) -> Tuple[DataClass, ...]:
    """
    Extend the functionality of `parse_json_file` to handle command line arguments in addition to loading a JSON
    file.

    The JSON file path is taken from `sys.argv[1]`. Every top-level key/value pair of the file is turned into
    `--key value` tokens (a JSON list becomes `--key v1 v2 ...`), and `sys.argv[2:]` is appended after them before
    everything is handed to `parse_known_args`.

    When there is a conflict between the command line arguments and the JSON file configuration,
    the command line arguments will take precedence.

    Returns:
        Tuple consisting of:

            - the dataclass instances in the same order as they were passed to the initializer.

    Raises:
        ValueError: if `sys.argv` has no first argument, if that argument does not end with ".json", or if the
            file does not contain a JSON object at the top level.
        FileNotFoundError: if the JSON file does not exist.
    """
    # Check the length first so a missing argument is reported as a ValueError instead of an IndexError.
    if len(sys.argv) < 2 or not sys.argv[1].endswith(".json"):
        first_arg = sys.argv[1] if len(sys.argv) >= 2 else None
        # NOTE(review): Codecov reported this raise as not covered by tests.
        raise ValueError(f"The first argument should be a JSON file, but it is {first_arg}")

    json_file = Path(sys.argv[1])
    if not json_file.exists():
        # NOTE(review): Codecov reported this raise as not covered by tests.
        raise FileNotFoundError(f"The argument file {json_file} does not exist.")

    # JSON text is UTF-8; do not depend on the platform's default encoding.
    with open(json_file, "r", encoding="utf-8") as file:
        data = json.load(file)
    if not isinstance(data, dict):
        raise ValueError(f"The argument file {json_file} should contain a JSON object at the top level.")

    json_args = []
    for key, value in data.items():
        json_args.append(f"--{key}")
        if isinstance(value, list):
            # One token per element, so multi-value options receive a real list rather than the
            # single string "['a', 'b']".
            json_args.extend(str(item) for item in value)
        else:
            json_args.append(str(value))

    # In case of conflict, command line arguments take precedence: they come last in the token list.
    args = json_args + sys.argv[2:]
    namespace, remaining_args = self.parse_known_args(args=args)
    outputs = []
    for dtype in self.dataclass_types:
        keys = {f.name for f in dataclasses.fields(dtype) if f.init}
        inputs = {k: v for k, v in vars(namespace).items() if k in keys}
        # Remove the consumed fields from the namespace once they have been copied into `inputs`.
        for k in keys:
            delattr(namespace, k)
        outputs.append(dtype(**inputs))
    if remaining_args:
        # NOTE(review): Codecov reported this warning as not covered by tests.
        warnings.warn(f"Some specified arguments are not used by the PdArgumentParser: {remaining_args}")

    return (*outputs,)

def parse_dict(self, args: dict) -> Tuple[DataClass, ...]:
"""
Alternative helper method that does not use `argparse` at all, instead uses a dict and populating the dataclass
Expand Down
125 changes: 125 additions & 0 deletions tests/trainer/test_argparser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import sys
import tempfile
import unittest
from unittest.mock import patch

from llm.run_pretrain import PreTrainingArguments
from paddlenlp.trainer.argparser import PdArgumentParser


def parse_args():
    """Parse ``sys.argv`` into ``PreTrainingArguments`` through a ``PdArgumentParser``.

    When the first argument ends with ".json" the call is routed to
    ``parse_json_file_and_cmd_lines``; otherwise ``parse_args_into_dataclasses`` is used.
    The parser's return value is passed through unchanged (the tests below index it with ``[0]``).
    """
    arg_parser = PdArgumentParser((PreTrainingArguments,))
    # Supported form: "args.json --arg1 value1 --arg2 value2".
    # In case of conflict, command line arguments take precedence.
    starts_with_json = len(sys.argv) >= 2 and sys.argv[1].endswith(".json")
    if starts_with_json:
        return arg_parser.parse_json_file_and_cmd_lines()
    return arg_parser.parse_args_into_dataclasses()


def create_json_from_dict(data_dict, file_path):
    """Write ``data_dict`` to ``file_path`` as JSON, replacing any existing content.

    Args:
        data_dict: JSON-serializable object to dump.
        file_path: Path of the file to create or overwrite.
    """
    # Explicit UTF-8 so the result does not depend on the platform's default encoding.
    with open(file_path, "w", encoding="utf-8") as f:
        json.dump(data_dict, f)


class ArgparserTest(unittest.TestCase):
    """Checks that ``parse_args`` yields the expected values from the command line, a JSON file, or both."""

    # Used as argv[0] in every patched ``sys.argv``.
    script_name = "test_argparser.py"
    # Reference configuration; each test feeds it to the parser through a different channel.
    args_dict = {
        "max_steps": 3000,
        "amp_master_grad": False,
        "adam_beta1": 0.9,
        "adam_beta2": 0.999,
        "adam_epsilon": 1e-08,
        "bf16": False,
        "enable_linear_fused_grad_add": False,
        "eval_steps": 3216,
        "flatten_param_grads": False,
        "fp16": 1,
        "log_on_each_node": True,
        "logging_dir": "./checkpoints/llama2_pretrain_ckpts/runs/Dec27_04-28-35_instance-047hzlt0-4",
        "logging_first_step": False,
        "logging_steps": 1,
        "lr_end": 1e-07,
        "max_evaluate_steps": -1,
        "max_grad_norm": 1.0,
        "min_learning_rate": 3e-06,
        "no_cuda": False,
        "num_cycles": 0.5,
        "num_train_epochs": 3.0,
        "output_dir": "./checkpoints/llama2_pretrain_ckpts",
    }

    @staticmethod
    def _to_cmd_line(args):
        """Flatten ``args`` into ``["--key", "value", ...]`` command line tokens."""
        tokens = []
        for key, value in args.items():
            tokens.extend([f"--{key}", str(value)])
        return tokens

    def _write_json_config(self, config):
        """Dump ``config`` into a temporary ".json" file and return its path.

        Removal is registered with ``addCleanup`` so the file is deleted even when an assertion fails.
        """
        with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as tmpfile:
            self.addCleanup(os.remove, tmpfile.name)
            json.dump(config, tmpfile)
        return tmpfile.name

    def _parse(self, *cli_args):
        """Run ``parse_args`` with ``sys.argv`` patched to ``[script_name, *cli_args]``; return its first result as a dict."""
        with patch("sys.argv", [self.script_name, *cli_args]):
            return vars(parse_args()[0])

    def _assert_parsed(self, parsed, expected, skip=()):
        """Assert that ``parsed`` holds every value of ``expected`` except the keys listed in ``skip``."""
        for key, value in expected.items():
            if key in skip:
                continue
            self.assertEqual(parsed.get(key), value, msg=f"unexpected value for {key!r}")

    def test_parse_cmd_lines(self):
        model_args = self._parse(*self._to_cmd_line(self.args_dict))
        self._assert_parsed(model_args, self.args_dict)

    def test_parse_json_file(self):
        json_path = self._write_json_config(self.args_dict)
        model_args = self._parse(json_path)
        self._assert_parsed(model_args, self.args_dict)

    def test_parse_json_file_and_cmd_lines(self):
        # First half of the options comes from the JSON file, the second half from the command line.
        names = list(self.args_dict)
        half_size = len(names) // 2
        json_part = {k: self.args_dict[k] for k in names[:half_size]}
        cmd_line_part = {k: self.args_dict[k] for k in names[half_size:]}

        json_path = self._write_json_config(json_part)
        model_args = self._parse(json_path, *self._to_cmd_line(cmd_line_part))
        self._assert_parsed(model_args, self.args_dict)

    def test_parse_json_file_and_cmd_lines_with_conflict(self):
        json_path = self._write_json_config(self.args_dict)
        # "max_steps" repeats the JSON value; the other two options differ from the JSON file.
        model_args = self._parse(
            json_path,
            "--min_learning_rate",
            "2e-5",
            "--max_steps",
            "3000",
            "--log_on_each_node",
            "False",
        )
        self.assertEqual(model_args.get("min_learning_rate"), 2e-5)
        self.assertEqual(model_args.get("max_steps"), 3000)
        self.assertEqual(model_args.get("log_on_each_node"), False)
        self._assert_parsed(
            model_args,
            self.args_dict,
            skip=("min_learning_rate", "max_steps", "log_on_each_node"),
        )