From 5dc3ab554e89a87467b16a7da747a3ef808650d1 Mon Sep 17 00:00:00 2001 From: greycooker <526929599@qq.com> Date: Wed, 3 Jan 2024 09:57:04 +0000 Subject: [PATCH 1/6] add parse_json_file_and_cmd_lines --- llm/finetune_generation.py | 14 ++-- llm/run_pretrain.py | 10 ++- paddlenlp/trainer/argparser.py | 45 +++++++++++ tests/test_argparser.py | 132 +++++++++++++++++++++++++++++++++ 4 files changed, 193 insertions(+), 8 deletions(-) create mode 100644 tests/test_argparser.py diff --git a/llm/finetune_generation.py b/llm/finetune_generation.py index 826cdb4daf30..c0add5ee29d6 100644 --- a/llm/finetune_generation.py +++ b/llm/finetune_generation.py @@ -58,12 +58,16 @@ def read_local_dataset(path): def main(): # Arguments parser = PdArgumentParser((GenerateArgument, QuantArgument, ModelArgument, DataArgument, TrainingArguments)) - if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): - gen_args, quant_args, model_args, data_args, training_args = parser.parse_json_file( - json_file=os.path.abspath(sys.argv[1]) - ) - else: + json_indices = [index for index, string in enumerate(sys.argv) if string.endswith(".json")] + if len(json_indices) >= 2: + raise ValueError("Only support one file in json format at most, please check the command line parameters.") + elif len(json_indices) == 0: gen_args, quant_args, model_args, data_args, training_args = parser.parse_args_into_dataclasses() + else: + json_file_idx = json_indices[0] + gen_args, quant_args, model_args, data_args, training_args = parser.parse_json_file_and_cmd_lines( + json_file_idx + ) training_args.print_config(model_args, "Model") training_args.print_config(data_args, "Data") training_args.print_config(quant_args, "Quant") diff --git a/llm/run_pretrain.py b/llm/run_pretrain.py index 00f8928d2ead..e75a158cc3c9 100644 --- a/llm/run_pretrain.py +++ b/llm/run_pretrain.py @@ -372,10 +372,14 @@ def _get_train_sampler(self) -> Optional[paddle.io.Sampler]: def main(): parser = PdArgumentParser((ModelArguments, DataArguments, PreTrainingArguments)) - if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): - model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) - else: + json_indices = [index for index, string in enumerate(sys.argv) if string.endswith(".json")] + if len(json_indices) >= 2: + raise ValueError("Only support one file in json format at most, please check the command line parameters.") + elif len(json_indices) == 0: model_args, data_args, training_args = parser.parse_args_into_dataclasses() + else: + json_file_idx = json_indices[0] + model_args, data_args, training_args = parser.parse_json_file_and_cmd_lines(json_file_idx) if training_args.enable_linear_fused_grad_add: from fused_layers import mock_layers diff --git a/paddlenlp/trainer/argparser.py b/paddlenlp/trainer/argparser.py index 16f0ea8ca028..357297c8f6e6 100644 --- a/paddlenlp/trainer/argparser.py +++ b/paddlenlp/trainer/argparser.py @@ -18,6 +18,7 @@ import dataclasses import json +import os import sys from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser, ArgumentTypeError from copy import copy @@ -247,6 +248,50 @@ def parse_json_file(self, json_file: str) -> Tuple[DataClass, ...]: outputs.append(obj) return (*outputs,) + def parse_json_file_and_cmd_lines(self, json_file_idx: int) -> Tuple[DataClass, ...]: + """ + Extend the functionality of `parse_json_file` to handle command line arguments in addition to loading a JSON + file. + + This method combines data from a JSON file and command line arguments to populate instances of dataclasses. + The JSON file is identified using its index in the command line arguments array. + + Args: + json_file_idx : + The index of the JSON file argument within the command line arguments array. + This index is used to locate and extract the JSON file path from the command line arguments. + + Returns: + Tuple consisting of: + + - the dataclass instances in the same order as they were passed to the initializer.abspath + """ + json_file = os.path.abspath(sys.argv[json_file_idx]) + json_args = json.loads(Path(json_file).read_text()) + del sys.argv[json_file_idx] + output_dir_arg = next( + (arg for arg in sys.argv if arg == "--output_dir" or arg.startswith("--output_dir=")), None + ) + if output_dir_arg is None: + if "output_dir" in json_args.keys(): + sys.argv.extend(["--output_dir", json_args["output_dir"]]) + else: + raise ValueError("The following arguments are required: --output_dir") + cmd_args = vars(self.parse_args()) + merged_args = {} + for key in json_args.keys() | cmd_args.keys(): + if any(arg == f"--{key}" or arg.startswith(f"--{key}=") for arg in sys.argv): + merged_args[key] = cmd_args.get(key) + elif json_args.get(key): + merged_args[key] = json_args.get(key) + outputs = [] + for dtype in self.dataclass_types: + keys = {f.name for f in dataclasses.fields(dtype) if f.init} + inputs = {k: v for k, v in merged_args.items() if k in keys} + obj = dtype(**inputs) + outputs.append(obj) + return (*outputs,) + def parse_dict(self, args: dict) -> Tuple[DataClass, ...]: """ Alternative helper method that does not use `argparse` at all, instead uses a dict and populating the dataclass diff --git a/tests/test_argparser.py b/tests/test_argparser.py new file mode 100644 index 000000000000..46d557ce46f5 --- /dev/null +++ b/tests/test_argparser.py @@ -0,0 +1,132 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +import os +import sys +import tempfile +import unittest +from unittest.mock import patch + +from llm.run_pretrain import PreTrainingArguments +from paddlenlp.trainer.argparser import PdArgumentParser + + +def parse_args(): + parser = PdArgumentParser((PreTrainingArguments,)) + json_indices = [index for index, string in enumerate(sys.argv) if string.endswith(".json")] + if len(json_indices) >= 2: + raise ValueError("Only support one file in json format at most, please check the command line parameters.") + elif len(json_indices) == 0: + model_args = parser.parse_args_into_dataclasses() + else: + json_file_idx = json_indices[0] + model_args = parser.parse_json_file_and_cmd_lines(json_file_idx) + return model_args + + +def create_json_from_dict(data_dict, file_path): + with open(file_path, "w") as f: + json.dump(data_dict, f) + + +class ArgparserTest(unittest.TestCase): + script_name = "test_argparser.py" + args_dict = { + "max_steps": 3000, + "amp_master_grad": False, + "adam_beta1": 0.9, + "adam_beta2": 0.999, + "adam_epsilon": 1e-08, + "bf16": False, + "enable_linear_fused_grad_add": False, + "eval_steps": 3216, + "flatten_param_grads": False, + "fp16": 1, + "log_on_each_node": True, + "logging_dir": "./checkpoints/llama2_pretrain_ckpts/runs/Dec27_04-28-35_instance-047hzlt0-4", + "logging_first_step": False, + "logging_steps": 1, + "lr_end": 1e-07, + "max_evaluate_steps": -1, + "max_grad_norm": 1.0, + "min_learning_rate": 3e-06, + "no_cuda": False, + "num_cycles": 0.5, + "num_train_epochs": 3.0, + "output_dir": "./checkpoints/llama2_pretrain_ckpts", + } + + def test_parse_args_with_multiple_json_files(self): + with self.assertRaises(ValueError): + with patch("sys.argv", [ArgparserTest.script_name, "config1.json", "config2.json"]): + parse_args() + + def test_parse_cmd_lines(self): + cmd_line_args = [ArgparserTest.script_name] + for key, value in ArgparserTest.args_dict.items(): + cmd_line_args.extend([f"--{key}", str(value)]) + with patch("sys.argv", cmd_line_args): + model_args = vars(parse_args()[0]) + for key, value in ArgparserTest.args_dict.items(): + self.assertEqual(model_args.get(key), value) + + def test_parse_json_file(self): + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as tmpfile: + create_json_from_dict(ArgparserTest.args_dict, tmpfile.name) + tmpfile_path = tmpfile.name + with patch("sys.argv", [ArgparserTest.script_name, tmpfile_path]): + model_args = vars(parse_args()[0]) + for key, value in ArgparserTest.args_dict.items(): + self.assertEqual(model_args.get(key), value) + os.remove(tmpfile_path) + + def test_parse_json_file_and_cmd_lines(self): + half_size = len(ArgparserTest.args_dict) // 2 + json_part = {k: ArgparserTest.args_dict[k] for k in list(ArgparserTest.args_dict)[:half_size]} + cmd_line_part = {k: ArgparserTest.args_dict[k] for k in list(ArgparserTest.args_dict)[half_size:]} + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as tmpfile: + create_json_from_dict(json_part, tmpfile.name) + tmpfile_path = tmpfile.name + cmd_line_args = [ArgparserTest.script_name, tmpfile_path] + for key, value in cmd_line_part.items(): + cmd_line_args.extend([f"--{key}", str(value)]) + with patch("sys.argv", cmd_line_args): + model_args = vars(parse_args()[0]) + for key, value in ArgparserTest.args_dict.items(): + self.assertEqual(model_args.get(key), value) + os.remove(tmpfile_path) + + def test_parse_json_file_and_cmd_lines_with_conflict(self): + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as tmpfile: + json.dump(ArgparserTest.args_dict, tmpfile) + tmpfile_path = tmpfile.name + cmd_line_args = [ + ArgparserTest.script_name, + tmpfile_path, + "--min_learning_rate", + "2e-5", + "--max_steps", + "3000", + "--log_on_each_node", + "False", + ] + with patch("sys.argv", cmd_line_args): + model_args = vars(parse_args()[0]) + self.assertEqual(model_args.get("min_learning_rate"), 2e-5) + self.assertEqual(model_args.get("max_steps"), 3000) + self.assertEqual(model_args.get("log_on_each_node"), False) + for key, value in ArgparserTest.args_dict.items(): + if key not in ["min_learning_rate", "max_steps", "log_on_each_node"]: + self.assertEqual(model_args.get(key), value) + os.remove(tmpfile_path) From e348363dae0d521d66b1b51954508e8aa8511c11 Mon Sep 17 00:00:00 2001 From: greycooker <526929599@qq.com> Date: Thu, 4 Jan 2024 03:16:27 +0000 Subject: [PATCH 2/6] change unit test file path --- tests/{ => trainer}/test_argparser.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/{ => trainer}/test_argparser.py (100%) diff --git a/tests/test_argparser.py b/tests/trainer/test_argparser.py similarity index 100% rename from tests/test_argparser.py rename to tests/trainer/test_argparser.py From 6652ed69b5d4c069d46ffc53cdbb93ed66180b33 Mon Sep 17 00:00:00 2001 From: greycooker <526929599@qq.com> Date: Mon, 8 Jan 2024 08:22:49 +0000 Subject: [PATCH 3/6] Change the way the JSON file is determined --- llm/finetune_generation.py | 16 ++++++++-------- llm/run_pretrain.py | 14 +++++++------- paddlenlp/trainer/argparser.py | 17 +++++++++-------- tests/trainer/test_argparser.py | 17 +++++------------ 4 files changed, 29 insertions(+), 35 deletions(-) diff --git a/llm/finetune_generation.py b/llm/finetune_generation.py index e8ff15f07821..050d8b800195 100644 --- a/llm/finetune_generation.py +++ b/llm/finetune_generation.py @@ -58,16 +58,16 @@ def read_local_dataset(path): def main(): # Arguments parser = PdArgumentParser((GenerateArgument, QuantArgument, ModelArgument, DataArgument, TrainingArguments)) - json_indices = [index for index, string in enumerate(sys.argv) if string.endswith(".json")] - if len(json_indices) >= 2: - raise ValueError("Only support one file in json format at most, please check the command line parameters.") - elif len(json_indices) == 0: - gen_args, quant_args, model_args, data_args, training_args = parser.parse_args_into_dataclasses() - else: - json_file_idx = json_indices[0] + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + gen_args, quant_args, model_args, data_args, training_args = parser.parse_json_file( + json_file=os.path.abspath(sys.argv[1]) + ) + elif len(sys.argv) > 2 and sys.argv[1].endswith(".json"): gen_args, quant_args, model_args, data_args, training_args = parser.parse_json_file_and_cmd_lines( - json_file_idx + json_file=os.path.abspath(sys.argv[1]) ) + else: + gen_args, quant_args, model_args, data_args, training_args = parser.parse_args_into_dataclasses() training_args.print_config(model_args, "Model") training_args.print_config(data_args, "Data") training_args.print_config(quant_args, "Quant") diff --git a/llm/run_pretrain.py b/llm/run_pretrain.py index e75a158cc3c9..b78f7ea31d34 100644 --- a/llm/run_pretrain.py +++ b/llm/run_pretrain.py @@ -372,14 +372,14 @@ def _get_train_sampler(self) -> Optional[paddle.io.Sampler]: def main(): parser = PdArgumentParser((ModelArguments, DataArguments, PreTrainingArguments)) - json_indices = [index for index, string in enumerate(sys.argv) if string.endswith(".json")] - if len(json_indices) >= 2: - raise ValueError("Only support one file in json format at most, please check the command line parameters.") - elif len(json_indices) == 0: - model_args, data_args, training_args = parser.parse_args_into_dataclasses() + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) + elif len(sys.argv) > 2 and sys.argv[1].endswith(".json"): + model_args, data_args, training_args = parser.parse_json_file_and_cmd_lines( + json_file=os.path.abspath(sys.argv[1]) + ) else: - json_file_idx = json_indices[0] - model_args, data_args, training_args = parser.parse_json_file_and_cmd_lines(json_file_idx) + model_args, data_args, training_args = parser.parse_args_into_dataclasses() if training_args.enable_linear_fused_grad_add: from fused_layers import mock_layers diff --git a/paddlenlp/trainer/argparser.py b/paddlenlp/trainer/argparser.py index 357297c8f6e6..08ebbf586493 100644 --- a/paddlenlp/trainer/argparser.py +++ b/paddlenlp/trainer/argparser.py @@ -18,7 +18,6 @@ import dataclasses import json -import os import sys from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser, ArgumentTypeError from copy import copy @@ -248,27 +247,29 @@ def parse_json_file(self, json_file: str) -> Tuple[DataClass, ...]: outputs.append(obj) return (*outputs,) - def parse_json_file_and_cmd_lines(self, json_file_idx: int) -> Tuple[DataClass, ...]: + def parse_json_file_and_cmd_lines(self, json_file: str) -> Tuple[DataClass, ...]: """ Extend the functionality of `parse_json_file` to handle command line arguments in addition to loading a JSON file. + When there is a conflict between the command line arguments and the JSON file configuration, + the command line arguments will take precedence. + This method combines data from a JSON file and command line arguments to populate instances of dataclasses. - The JSON file is identified using its index in the command line arguments array. Args: - json_file_idx : - The index of the JSON file argument within the command line arguments array. - This index is used to locate and extract the JSON file path from the command line arguments. + json_file : + The path to the JSON formatted file should be at index position 1 in the command line + arguments array (sys.argv[1]). + Any JSON file path at other positions will be considered invalid. Returns: Tuple consisting of: - the dataclass instances in the same order as they were passed to the initializer.abspath """ - json_file = os.path.abspath(sys.argv[json_file_idx]) json_args = json.loads(Path(json_file).read_text()) - del sys.argv[json_file_idx] + del sys.argv[1] output_dir_arg = next( (arg for arg in sys.argv if arg == "--output_dir" or arg.startswith("--output_dir=")), None ) diff --git a/tests/trainer/test_argparser.py b/tests/trainer/test_argparser.py index 46d557ce46f5..200a265f4945 100644 --- a/tests/trainer/test_argparser.py +++ b/tests/trainer/test_argparser.py @@ -24,14 +24,12 @@ def parse_args(): parser = PdArgumentParser((PreTrainingArguments,)) - json_indices = [index for index, string in enumerate(sys.argv) if string.endswith(".json")] - if len(json_indices) >= 2: - raise ValueError("Only support one file in json format at most, please check the command line parameters.") - elif len(json_indices) == 0: - model_args = parser.parse_args_into_dataclasses() + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + model_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) + elif len(sys.argv) > 2 and sys.argv[1].endswith(".json"): + model_args = parser.parse_json_file_and_cmd_lines(json_file=os.path.abspath(sys.argv[1])) else: - json_file_idx = json_indices[0] - model_args = parser.parse_json_file_and_cmd_lines(json_file_idx) + model_args = parser.parse_args_into_dataclasses() return model_args @@ -67,11 +65,6 @@ class ArgparserTest(unittest.TestCase): "output_dir": "./checkpoints/llama2_pretrain_ckpts", } - def test_parse_args_with_multiple_json_files(self): - with self.assertRaises(ValueError): - with patch("sys.argv", [ArgparserTest.script_name, "config1.json", "config2.json"]): - parse_args() - def test_parse_cmd_lines(self): cmd_line_args = [ArgparserTest.script_name] for key, value in ArgparserTest.args_dict.items(): From 36c50e0c97e6284608c4775abab0a9d5aea24834 Mon Sep 17 00:00:00 2001 From: greycooker <526929599@qq.com> Date: Mon, 8 Jan 2024 08:56:28 +0000 Subject: [PATCH 4/6] Merge parameter parsing judgment branches and add comments. --- llm/finetune_generation.py | 8 +++----- llm/run_pretrain.py | 6 +++--- tests/trainer/test_argparser.py | 6 +++--- 3 files changed, 9 insertions(+), 11 deletions(-) diff --git a/llm/finetune_generation.py b/llm/finetune_generation.py index 050d8b800195..eed4793ec993 100644 --- a/llm/finetune_generation.py +++ b/llm/finetune_generation.py @@ -58,11 +58,9 @@ def read_local_dataset(path): def main(): # Arguments parser = PdArgumentParser((GenerateArgument, QuantArgument, ModelArgument, DataArgument, TrainingArguments)) - if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): - gen_args, quant_args, model_args, data_args, training_args = parser.parse_json_file( - json_file=os.path.abspath(sys.argv[1]) - ) - elif len(sys.argv) > 2 and sys.argv[1].endswith(".json"): + # Support format as "args.json --args1 value1 --args2 value2.” + # In case of conflict, command line arguments take precedence. + if len(sys.argv) >= 2 and sys.argv[1].endswith(".json"): gen_args, quant_args, model_args, data_args, training_args = parser.parse_json_file_and_cmd_lines( json_file=os.path.abspath(sys.argv[1]) ) diff --git a/llm/run_pretrain.py b/llm/run_pretrain.py index b78f7ea31d34..489aae9f7935 100644 --- a/llm/run_pretrain.py +++ b/llm/run_pretrain.py @@ -372,9 +372,9 @@ def _get_train_sampler(self) -> Optional[paddle.io.Sampler]: def main(): parser = PdArgumentParser((ModelArguments, DataArguments, PreTrainingArguments)) - if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): - model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) - elif len(sys.argv) > 2 and sys.argv[1].endswith(".json"): + # Support format as "args.json --args1 value1 --args2 value2.” + # In case of conflict, command line arguments take precedence. + if len(sys.argv) >= 2 and sys.argv[1].endswith(".json"): model_args, data_args, training_args = parser.parse_json_file_and_cmd_lines( json_file=os.path.abspath(sys.argv[1]) ) diff --git a/tests/trainer/test_argparser.py b/tests/trainer/test_argparser.py index 200a265f4945..ed4e9a3e4256 100644 --- a/tests/trainer/test_argparser.py +++ b/tests/trainer/test_argparser.py @@ -24,9 +24,9 @@ def parse_args(): parser = PdArgumentParser((PreTrainingArguments,)) - if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): - model_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) - elif len(sys.argv) > 2 and sys.argv[1].endswith(".json"): + # Support format as "args.json --args1 value1 --args2 value2.” + # In case of conflict, command line arguments take precedence. + if len(sys.argv) >= 2 and sys.argv[1].endswith(".json"): model_args = parser.parse_json_file_and_cmd_lines(json_file=os.path.abspath(sys.argv[1])) else: model_args = parser.parse_args_into_dataclasses() From 52cb70d7c8b533e8805af54022b2b05490cee90f Mon Sep 17 00:00:00 2001 From: greycooker <526929599@qq.com> Date: Tue, 9 Jan 2024 13:59:23 +0000 Subject: [PATCH 5/6] remove the special handling of output_dir --- llm/finetune_generation.py | 6 ++--- llm/run_pretrain.py | 6 ++--- paddlenlp/trainer/argparser.py | 45 +++++++++++++-------------------- tests/trainer/test_argparser.py | 4 +-- 4 files changed, 24 insertions(+), 37 deletions(-) diff --git a/llm/finetune_generation.py b/llm/finetune_generation.py index eed4793ec993..d9a54a0e6226 100644 --- a/llm/finetune_generation.py +++ b/llm/finetune_generation.py @@ -58,12 +58,10 @@ def read_local_dataset(path): def main(): # Arguments parser = PdArgumentParser((GenerateArgument, QuantArgument, ModelArgument, DataArgument, TrainingArguments)) - # Support format as "args.json --args1 value1 --args2 value2.” + # Support format as "args.json --arg1 value1 --arg2 value2.” # In case of conflict, command line arguments take precedence. if len(sys.argv) >= 2 and sys.argv[1].endswith(".json"): - gen_args, quant_args, model_args, data_args, training_args = parser.parse_json_file_and_cmd_lines( - json_file=os.path.abspath(sys.argv[1]) - ) + gen_args, quant_args, model_args, data_args, training_args = parser.parse_json_file_and_cmd_lines() else: gen_args, quant_args, model_args, data_args, training_args = parser.parse_args_into_dataclasses() training_args.print_config(model_args, "Model") diff --git a/llm/run_pretrain.py b/llm/run_pretrain.py index 489aae9f7935..23c185a2fcf1 100644 --- a/llm/run_pretrain.py +++ b/llm/run_pretrain.py @@ -372,12 +372,10 @@ def _get_train_sampler(self) -> Optional[paddle.io.Sampler]: def main(): parser = PdArgumentParser((ModelArguments, DataArguments, PreTrainingArguments)) - # Support format as "args.json --args1 value1 --args2 value2.” + # Support format as "args.json --arg1 value1 --arg2 value2.” # In case of conflict, command line arguments take precedence. if len(sys.argv) >= 2 and sys.argv[1].endswith(".json"): - model_args, data_args, training_args = parser.parse_json_file_and_cmd_lines( - json_file=os.path.abspath(sys.argv[1]) - ) + model_args, data_args, training_args = parser.parse_json_file_and_cmd_lines() else: model_args, data_args, training_args = parser.parse_args_into_dataclasses() diff --git a/paddlenlp/trainer/argparser.py b/paddlenlp/trainer/argparser.py index 08ebbf586493..c08506a64a64 100644 --- a/paddlenlp/trainer/argparser.py +++ b/paddlenlp/trainer/argparser.py @@ -247,7 +247,7 @@ def parse_json_file(self, json_file: str) -> Tuple[DataClass, ...]: outputs.append(obj) return (*outputs,) - def parse_json_file_and_cmd_lines(self, json_file: str) -> Tuple[DataClass, ...]: + def parse_json_file_and_cmd_lines(self) -> Tuple[DataClass, ...]: """ Extend the functionality of `parse_json_file` to handle command line arguments in addition to loading a JSON file. @@ -255,40 +255,31 @@ def parse_json_file_and_cmd_lines(self, json_file: str) -> Tuple[DataClass, ...] When there is a conflict between the command line arguments and the JSON file configuration, the command line arguments will take precedence. - This method combines data from a JSON file and command line arguments to populate instances of dataclasses. - - Args: - json_file : - The path to the JSON formatted file should be at index position 1 in the command line - arguments array (sys.argv[1]). - Any JSON file path at other positions will be considered invalid. - Returns: Tuple consisting of: - the dataclass instances in the same order as they were passed to the initializer.abspath """ - json_args = json.loads(Path(json_file).read_text()) - del sys.argv[1] - output_dir_arg = next( - (arg for arg in sys.argv if arg == "--output_dir" or arg.startswith("--output_dir=")), None - ) - if output_dir_arg is None: - if "output_dir" in json_args.keys(): - sys.argv.extend(["--output_dir", json_args["output_dir"]]) - else: - raise ValueError("The following arguments are required: --output_dir") - cmd_args = vars(self.parse_args()) - merged_args = {} - for key in json_args.keys() | cmd_args.keys(): - if any(arg == f"--{key}" or arg.startswith(f"--{key}=") for arg in sys.argv): - merged_args[key] = cmd_args.get(key) - elif json_args.get(key): - merged_args[key] = json_args.get(key) + if not sys.argv[1].endswith(".json"): + raise ValueError(f"The first argument should be a JSON file, but it is {sys.argv[1]}") + json_file = Path(sys.argv[1]) + if json_file.exists(): + with open(json_file, "r") as file: + data = json.load(file) + json_args = [] + for key, value in data.items(): + json_args.extend([f"--{key}", str(value)]) + else: + raise FileNotFoundError(f"The argument file {json_file} does not exist.") + # In case of conflict, command line arguments take precedence + args = json_args + sys.argv[2:] + namespace, _ = self.parse_known_args(args=args) outputs = [] for dtype in self.dataclass_types: keys = {f.name for f in dataclasses.fields(dtype) if f.init} - inputs = {k: v for k, v in merged_args.items() if k in keys} + inputs = {k: v for k, v in vars(namespace).items() if k in keys} + for k in keys: + delattr(namespace, k) obj = dtype(**inputs) outputs.append(obj) return (*outputs,) diff --git a/tests/trainer/test_argparser.py b/tests/trainer/test_argparser.py index ed4e9a3e4256..946edd6f337d 100644 --- a/tests/trainer/test_argparser.py +++ b/tests/trainer/test_argparser.py @@ -24,10 +24,10 @@ def parse_args(): parser = PdArgumentParser((PreTrainingArguments,)) - # Support format as "args.json --args1 value1 --args2 value2.” + # Support format as "args.json --arg1 value1 --arg2 value2.” # In case of conflict, command line arguments take precedence. if len(sys.argv) >= 2 and sys.argv[1].endswith(".json"): - model_args = parser.parse_json_file_and_cmd_lines(json_file=os.path.abspath(sys.argv[1])) + model_args = parser.parse_json_file_and_cmd_lines() else: model_args = parser.parse_args_into_dataclasses() return model_args From 9f8e8d5adfc914216ce69f64d8d8991747982887 Mon Sep 17 00:00:00 2001 From: greycooker <526929599@qq.com> Date: Tue, 9 Jan 2024 15:04:47 +0000 Subject: [PATCH 6/6] Add remaining_args warning --- paddlenlp/trainer/argparser.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/paddlenlp/trainer/argparser.py b/paddlenlp/trainer/argparser.py index c08506a64a64..6a54a46fa389 100644 --- a/paddlenlp/trainer/argparser.py +++ b/paddlenlp/trainer/argparser.py @@ -19,6 +19,7 @@ import dataclasses import json import sys +import warnings from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser, ArgumentTypeError from copy import copy from enum import Enum @@ -273,7 +274,7 @@ def parse_json_file_and_cmd_lines(self) -> Tuple[DataClass, ...]: raise FileNotFoundError(f"The argument file {json_file} does not exist.") # In case of conflict, command line arguments take precedence args = json_args + sys.argv[2:] - namespace, _ = self.parse_known_args(args=args) + namespace, remaining_args = self.parse_known_args(args=args) outputs = [] for dtype in self.dataclass_types: keys = {f.name for f in dataclasses.fields(dtype) if f.init} @@ -282,6 +283,9 @@ def parse_json_file_and_cmd_lines(self) -> Tuple[DataClass, ...]: delattr(namespace, k) obj = dtype(**inputs) outputs.append(obj) + if remaining_args: + warnings.warn(f"Some specified arguments are not used by the PdArgumentParser: {remaining_args}") + return (*outputs,) def parse_dict(self, args: dict) -> Tuple[DataClass, ...]: