Skip to content

Commit

Permalink
Json parsers supports raise error
Browse files Browse the repository at this point in the history
  • Loading branch information
gongel committed Mar 21, 2024
1 parent 0c65a47 commit 766b993
Showing 1 changed file with 25 additions and 36 deletions.
61 changes: 25 additions & 36 deletions paddlenlp/trainer/argparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import dataclasses
import json
import sys
import warnings
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser, ArgumentTypeError
from copy import copy
from enum import Enum
Expand Down Expand Up @@ -214,6 +213,10 @@ def parse_args_into_dataclasses(
args = fargs + args if args is not None else fargs + sys.argv[1:]
# in case of duplicate arguments the first one has precedence
# so we append rather than prepend.

return self.common_parse(args, return_remaining_strings)

def common_parse(self, args, return_remaining_strings) -> Tuple[DataClass, ...]:
namespace, remaining_args = self.parse_known_args(args=args)
outputs = []
for dtype in self.dataclass_types:
Expand All @@ -234,21 +237,30 @@ def parse_args_into_dataclasses(

return (*outputs,)

def parse_json_file(self, json_file: str) -> Tuple[DataClass, ...]:
def read_json(self, json_file: str) -> list:
json_file = Path(json_file)
if json_file.exists():
with open(json_file, "r") as file:
data = json.load(file)
json_args = []
for key, value in data.items():
if isinstance(value, list):
json_args.extend([f"--{key}", *[str(v) for v in value]])
else:
json_args.extend([f"--{key}", str(value)])
return json_args
else:
raise FileNotFoundError(f"The argument file {json_file} does not exist.")

def parse_json_file(self, json_file: str, return_remaining_strings=False) -> Tuple[DataClass, ...]:
"""
Alternative helper method that does not use `argparse` at all, instead loading a json file and populating the
dataclass types.
"""
data = json.loads(Path(json_file).read_text())
outputs = []
for dtype in self.dataclass_types:
keys = {f.name for f in dataclasses.fields(dtype) if f.init}
inputs = {k: v for k, v in data.items() if k in keys}
obj = dtype(**inputs)
outputs.append(obj)
return (*outputs,)
json_args = self.read_json(json_file)
return self.common_parse(json_args, return_remaining_strings)

def parse_json_file_and_cmd_lines(self) -> Tuple[DataClass, ...]:
def parse_json_file_and_cmd_lines(self, return_remaining_strings=False) -> Tuple[DataClass, ...]:
"""
Extend the functionality of `parse_json_file` to handle command line arguments in addition to loading a JSON
file.
Expand All @@ -263,33 +275,10 @@ def parse_json_file_and_cmd_lines(self) -> Tuple[DataClass, ...]:
"""
if not sys.argv[1].endswith(".json"):
raise ValueError(f"The first argument should be a JSON file, but it is {sys.argv[1]}")
json_file = Path(sys.argv[1])
if json_file.exists():
with open(json_file, "r") as file:
data = json.load(file)
json_args = []
for key, value in data.items():
if isinstance(value, list):
json_args.extend([f"--{key}", *[str(v) for v in value]])
else:
json_args.extend([f"--{key}", str(value)])
else:
raise FileNotFoundError(f"The argument file {json_file} does not exist.")
json_args = self.read_json(sys.argv[1])
# In case of conflict, command line arguments take precedence
args = json_args + sys.argv[2:]
namespace, remaining_args = self.parse_known_args(args=args)
outputs = []
for dtype in self.dataclass_types:
keys = {f.name for f in dataclasses.fields(dtype) if f.init}
inputs = {k: v for k, v in vars(namespace).items() if k in keys}
for k in keys:
delattr(namespace, k)
obj = dtype(**inputs)
outputs.append(obj)
if remaining_args:
warnings.warn(f"Some specified arguments are not used by the PdArgumentParser: {remaining_args}")

return (*outputs,)
return self.common_parse(args, return_remaining_strings)

def parse_dict(self, args: dict) -> Tuple[DataClass, ...]:
"""
Expand Down

0 comments on commit 766b993

Please sign in to comment.