Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Trainer support simultaneously parse JSON files and cmd arguments. #7768

Merged
merged 10 commits into from
Jan 10, 2024
6 changes: 4 additions & 2 deletions llm/finetune_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,10 @@ def read_local_dataset(path):
def main():
# Arguments
parser = PdArgumentParser((GenerateArgument, QuantArgument, ModelArgument, DataArgument, TrainingArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
gen_args, quant_args, model_args, data_args, training_args = parser.parse_json_file(
# Support format as "args.json --args1 value1 --args2 value2.”
# In case of conflict, command line arguments take precedence.
if len(sys.argv) >= 2 and sys.argv[1].endswith(".json"):
gen_args, quant_args, model_args, data_args, training_args = parser.parse_json_file_and_cmd_lines(
json_file=os.path.abspath(sys.argv[1])
greycooker marked this conversation as resolved.
Show resolved Hide resolved
)
else:
Expand Down
8 changes: 6 additions & 2 deletions llm/run_pretrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,8 +372,12 @@ def _get_train_sampler(self) -> Optional[paddle.io.Sampler]:

def main():
parser = PdArgumentParser((ModelArguments, DataArguments, PreTrainingArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
# Support format as "args.json --args1 value1 --args2 value2.”
# In case of conflict, command line arguments take precedence.
if len(sys.argv) >= 2 and sys.argv[1].endswith(".json"):
model_args, data_args, training_args = parser.parse_json_file_and_cmd_lines(
json_file=os.path.abspath(sys.argv[1])
)
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()

Expand Down
46 changes: 46 additions & 0 deletions paddlenlp/trainer/argparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,52 @@
outputs.append(obj)
return (*outputs,)

def parse_json_file_and_cmd_lines(self, json_file: str) -> Tuple[DataClass, ...]:
"""
Extend the functionality of `parse_json_file` to handle command line arguments in addition to loading a JSON
file.

When there is a conflict between the command line arguments and the JSON file configuration,
the command line arguments will take precedence.

This method combines data from a JSON file and command line arguments to populate instances of dataclasses.

Args:
json_file :
The path to the JSON formatted file should be at index position 1 in the command line
arguments array (sys.argv[1]).
Any JSON file path at other positions will be considered invalid.

Returns:
Tuple consisting of:

- the dataclass instances in the same order as they were passed to the initializer.abspath
"""
json_args = json.loads(Path(json_file).read_text())
del sys.argv[1]
output_dir_arg = next(
(arg for arg in sys.argv if arg == "--output_dir" or arg.startswith("--output_dir=")), None
)
if output_dir_arg is None:
if "output_dir" in json_args.keys():
sys.argv.extend(["--output_dir", json_args["output_dir"]])
else:
raise ValueError("The following arguments are required: --output_dir")

Check warning on line 280 in paddlenlp/trainer/argparser.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/argparser.py#L280

Added line #L280 was not covered by tests
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

为什么这里要 特判 output_dir ?

Copy link
Contributor Author

@greycooker greycooker Jan 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

如果不特判output_dir,出现json文件里有output_dir参数,但是命令行里没有的情况,执行281行vars(self.parse_args())的时候就会报错,但是我们现在不希望让它报错,所以进行了output_dir的特判

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

https://github.com/PaddlePaddle/PaddleNLP/blob/d89c01130a7f27c39d762cefb15926c4c69aa711/paddlenlp/trainer/argparser.py#L177C9-L177C36

你参考一下这个函数,这个函数也是一样的支持本地文件。
看看这个是怎么处理的。

这个作为通用的parser,在这里做output_dir之类的特判是不太合理的。

cmd_args = vars(self.parse_args())
merged_args = {}
for key in json_args.keys() | cmd_args.keys():
if any(arg == f"--{key}" or arg.startswith(f"--{key}=") for arg in sys.argv):
merged_args[key] = cmd_args.get(key)
elif json_args.get(key):
merged_args[key] = json_args.get(key)
outputs = []
for dtype in self.dataclass_types:
keys = {f.name for f in dataclasses.fields(dtype) if f.init}
inputs = {k: v for k, v in merged_args.items() if k in keys}
obj = dtype(**inputs)
outputs.append(obj)
return (*outputs,)

def parse_dict(self, args: dict) -> Tuple[DataClass, ...]:
"""
Alternative helper method that does not use `argparse` at all, instead uses a dict and populating the dataclass
Expand Down
125 changes: 125 additions & 0 deletions tests/trainer/test_argparser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import sys
import tempfile
import unittest
from unittest.mock import patch

from llm.run_pretrain import PreTrainingArguments
from paddlenlp.trainer.argparser import PdArgumentParser


def parse_args():
parser = PdArgumentParser((PreTrainingArguments,))
# Support format as "args.json --args1 value1 --args2 value2.”
# In case of conflict, command line arguments take precedence.
if len(sys.argv) >= 2 and sys.argv[1].endswith(".json"):
model_args = parser.parse_json_file_and_cmd_lines(json_file=os.path.abspath(sys.argv[1]))
else:
model_args = parser.parse_args_into_dataclasses()
return model_args


def create_json_from_dict(data_dict, file_path):
with open(file_path, "w") as f:
json.dump(data_dict, f)


class ArgparserTest(unittest.TestCase):
script_name = "test_argparser.py"
args_dict = {
"max_steps": 3000,
"amp_master_grad": False,
"adam_beta1": 0.9,
"adam_beta2": 0.999,
"adam_epsilon": 1e-08,
"bf16": False,
"enable_linear_fused_grad_add": False,
"eval_steps": 3216,
"flatten_param_grads": False,
"fp16": 1,
"log_on_each_node": True,
"logging_dir": "./checkpoints/llama2_pretrain_ckpts/runs/Dec27_04-28-35_instance-047hzlt0-4",
"logging_first_step": False,
"logging_steps": 1,
"lr_end": 1e-07,
"max_evaluate_steps": -1,
"max_grad_norm": 1.0,
"min_learning_rate": 3e-06,
"no_cuda": False,
"num_cycles": 0.5,
"num_train_epochs": 3.0,
"output_dir": "./checkpoints/llama2_pretrain_ckpts",
}

def test_parse_cmd_lines(self):
cmd_line_args = [ArgparserTest.script_name]
for key, value in ArgparserTest.args_dict.items():
cmd_line_args.extend([f"--{key}", str(value)])
with patch("sys.argv", cmd_line_args):
model_args = vars(parse_args()[0])
for key, value in ArgparserTest.args_dict.items():
self.assertEqual(model_args.get(key), value)

def test_parse_json_file(self):
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as tmpfile:
create_json_from_dict(ArgparserTest.args_dict, tmpfile.name)
tmpfile_path = tmpfile.name
with patch("sys.argv", [ArgparserTest.script_name, tmpfile_path]):
model_args = vars(parse_args()[0])
for key, value in ArgparserTest.args_dict.items():
self.assertEqual(model_args.get(key), value)
os.remove(tmpfile_path)

def test_parse_json_file_and_cmd_lines(self):
half_size = len(ArgparserTest.args_dict) // 2
json_part = {k: ArgparserTest.args_dict[k] for k in list(ArgparserTest.args_dict)[:half_size]}
cmd_line_part = {k: ArgparserTest.args_dict[k] for k in list(ArgparserTest.args_dict)[half_size:]}
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as tmpfile:
create_json_from_dict(json_part, tmpfile.name)
tmpfile_path = tmpfile.name
cmd_line_args = [ArgparserTest.script_name, tmpfile_path]
for key, value in cmd_line_part.items():
cmd_line_args.extend([f"--{key}", str(value)])
with patch("sys.argv", cmd_line_args):
model_args = vars(parse_args()[0])
for key, value in ArgparserTest.args_dict.items():
self.assertEqual(model_args.get(key), value)
os.remove(tmpfile_path)

def test_parse_json_file_and_cmd_lines_with_conflict(self):
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as tmpfile:
json.dump(ArgparserTest.args_dict, tmpfile)
tmpfile_path = tmpfile.name
cmd_line_args = [
ArgparserTest.script_name,
tmpfile_path,
"--min_learning_rate",
"2e-5",
"--max_steps",
"3000",
"--log_on_each_node",
"False",
]
with patch("sys.argv", cmd_line_args):
model_args = vars(parse_args()[0])
self.assertEqual(model_args.get("min_learning_rate"), 2e-5)
self.assertEqual(model_args.get("max_steps"), 3000)
self.assertEqual(model_args.get("log_on_each_node"), False)
for key, value in ArgparserTest.args_dict.items():
if key not in ["min_learning_rate", "max_steps", "log_on_each_node"]:
self.assertEqual(model_args.get(key), value)
os.remove(tmpfile_path)