diff --git a/paddlenlp/trainer/argparser.py b/paddlenlp/trainer/argparser.py index c0289af11196..48a664682089 100644 --- a/paddlenlp/trainer/argparser.py +++ b/paddlenlp/trainer/argparser.py @@ -19,7 +19,6 @@ import dataclasses import json import sys -import warnings from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser, ArgumentTypeError from copy import copy from enum import Enum @@ -214,6 +213,10 @@ def parse_args_into_dataclasses( args = fargs + args if args is not None else fargs + sys.argv[1:] # in case of duplicate arguments the first one has precedence # so we append rather than prepend. + + return self.common_parse(args, return_remaining_strings) + + def common_parse(self, args, return_remaining_strings) -> Tuple[DataClass, ...]: namespace, remaining_args = self.parse_known_args(args=args) outputs = [] for dtype in self.dataclass_types: @@ -234,21 +237,30 @@ def parse_args_into_dataclasses( return (*outputs,) - def parse_json_file(self, json_file: str) -> Tuple[DataClass, ...]: + def read_json(self, json_file: str) -> list: + json_file = Path(json_file) + if json_file.exists(): + with open(json_file, "r") as file: + data = json.load(file) + json_args = [] + for key, value in data.items(): + if isinstance(value, list): + json_args.extend([f"--{key}", *[str(v) for v in value]]) + else: + json_args.extend([f"--{key}", str(value)]) + return json_args + else: + raise FileNotFoundError(f"The argument file {json_file} does not exist.") + + def parse_json_file(self, json_file: str, return_remaining_strings=False) -> Tuple[DataClass, ...]: """ Alternative helper method that does not use `argparse` at all, instead loading a json file and populating the dataclass types. """ - data = json.loads(Path(json_file).read_text()) - outputs = [] - for dtype in self.dataclass_types: - keys = {f.name for f in dataclasses.fields(dtype) if f.init} - inputs = {k: v for k, v in data.items() if k in keys} - obj = dtype(**inputs) - outputs.append(obj) - return (*outputs,) + json_args = self.read_json(json_file) + return self.common_parse(json_args, return_remaining_strings) - def parse_json_file_and_cmd_lines(self) -> Tuple[DataClass, ...]: + def parse_json_file_and_cmd_lines(self, return_remaining_strings=False) -> Tuple[DataClass, ...]: """ Extend the functionality of `parse_json_file` to handle command line arguments in addition to loading a JSON file. @@ -263,33 +275,10 @@ def parse_json_file_and_cmd_lines(self) -> Tuple[DataClass, ...]: """ if not sys.argv[1].endswith(".json"): raise ValueError(f"The first argument should be a JSON file, but it is {sys.argv[1]}") - json_file = Path(sys.argv[1]) - if json_file.exists(): - with open(json_file, "r") as file: - data = json.load(file) - json_args = [] - for key, value in data.items(): - if isinstance(value, list): - json_args.extend([f"--{key}", *[str(v) for v in value]]) - else: - json_args.extend([f"--{key}", str(value)]) - else: - raise FileNotFoundError(f"The argument file {json_file} does not exist.") + json_args = self.read_json(sys.argv[1]) # In case of conflict, command line arguments take precedence args = json_args + sys.argv[2:] - namespace, remaining_args = self.parse_known_args(args=args) - outputs = [] - for dtype in self.dataclass_types: - keys = {f.name for f in dataclasses.fields(dtype) if f.init} - inputs = {k: v for k, v in vars(namespace).items() if k in keys} - for k in keys: - delattr(namespace, k) - obj = dtype(**inputs) - outputs.append(obj) - if remaining_args: - warnings.warn(f"Some specified arguments are not used by the PdArgumentParser: {remaining_args}") - - return (*outputs,) + return self.common_parse(args, return_remaining_strings) def parse_dict(self, args: dict) -> Tuple[DataClass, ...]: """