Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable multiple arguments to config_path #25

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 31 additions & 13 deletions pyrallis/argparsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ class ArgumentParser(Generic[T], argparse.ArgumentParser):
def __init__(
self,
config_class: Type[T],
config_path: Optional[str] = None,
config_path: Optional[Union[Path, str, List[Union[str, Path]]]] = None,
commandline_overwrites: bool = True,
formatter_class: Type[HelpFormatter] = SimpleHelpFormatter,
*args,
**kwargs,
Expand All @@ -43,11 +44,22 @@ def __init__(

self._wrappers: List[DataclassWrapper] = []

self.config_path = config_path
self.commandline_overwrites = commandline_overwrites
if config_path is None:
self.config_path = []
elif isinstance(config_path, list):
self.config_path = config_path
else:
self.config_path = [config_path]
self.config_class = config_class

self._assert_no_conflicts()
self.add_argument(f'--{utils.CONFIG_ARG}', type=str, help='Path for a config file to parse with pyrallis')
self.add_argument(
f'--{utils.CONFIG_ARG}',
type=str,
nargs='*',
help='Paths for config files to parse with pyrallis, read from left to right',
)
self.set_dataclass(config_class)

def set_dataclass(
Expand Down Expand Up @@ -116,22 +128,27 @@ def _postprocessing(self, parsed_args: Namespace) -> T:

parsed_arg_values = vars(parsed_args)

for key in parsed_arg_values:
parsed_arg_values[key] = cfgparsing.parse_string(parsed_arg_values[key])

config_path = self.config_path # Could be NONE
config_path = self.config_path[:] # copy, always list

if utils.CONFIG_ARG in parsed_arg_values:
new_config_path = parsed_arg_values[utils.CONFIG_ARG]
if config_path is not None:
if self.commandline_overwrites and len(config_path) > 0:
warnings.warn(
UserWarning(f'Overriding default {config_path} with {new_config_path}')
)
config_path = new_config_path
config_path = new_config_path
else:
config_path.extend(new_config_path)
del parsed_arg_values[utils.CONFIG_ARG]

if config_path is not None:
file_args = cfgparsing.load_config(open(config_path, 'r'))
for key in parsed_arg_values:
parsed_arg_values[key] = cfgparsing.parse_string(parsed_arg_values[key])

# values read first will then be used to overwrite values
# read during later parses, so reverse the order for standard
# left to right parsing
for cfg_path in config_path[::-1]:
file_args = cfgparsing.load_config(open(cfg_path, 'r'))
file_args = utils.flatten(file_args, sep='.')
file_args.update(parsed_arg_values)
parsed_arg_values = file_args
Expand All @@ -143,8 +160,9 @@ def _postprocessing(self, parsed_args: Namespace) -> T:


def parse(config_class: Type[T], config_path: Optional[Union[Path, str]] = None,
args: Optional[Sequence[str]] = None) -> T:
parser = ArgumentParser(config_class=config_class, config_path=config_path)
args: Optional[Sequence[str]] = None, commandline_overwrites: bool=True) -> T:
parser = ArgumentParser(config_class=config_class, config_path=config_path,
commandline_overwrites=commandline_overwrites)
return parser.parse_args(args)


Expand Down
118 changes: 118 additions & 0 deletions tests/test_multi_configs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
import json
import warnings
from dataclasses import dataclass
from enum import Enum, auto

import yaml

from pyrallis.utils import PyrallisException

from .testutils import *

# List of simple attributes to use in tests:
two_arguments: List[Tuple[Type, Any, Any]] = [
# type, first (parsed) value, second (parsed) different value
(int, 123, 124),
(int, -1, 1),
(float, 123.0, 124.0),
(float, 0.123, 0.124),
(bool, True, False),
(str, "bob", "alice"),
(str, "[123]", "[124]"),
(str, "123", "124"),
]

def switch_args(args):
for t, a, b in args:
yield (t, a, b)
yield (t, b, a)

@pytest.fixture(params=list(switch_args(two_arguments)))
def two_attribute(request):
"""Test fixture that produces an tuple of (type, value 1, value 2) where
both values are different"""
return request.param


def test_multi_load(two_attribute, tmp_path):
some_type, val_a, val_b = two_attribute

@dataclass
class SomeClass:
val_a: Optional[some_type] = None
val_b: Optional[some_type] = None
val_c: Optional[some_type] = None

a = SomeClass(val_a=val_a)
b = SomeClass(val_a=val_b, val_b=val_b)
d = SomeClass(val_b=val_b)

# c = b, a
c = SomeClass(val_a=val_a, val_b=val_b)

tmp_file_a = tmp_path / 'config_a'
pyrallis.dump(a, tmp_file_a.open('w'), omit_defaults=True)
tmp_file_b = tmp_path / 'config_b'
pyrallis.dump(b, tmp_file_b.open('w'), omit_defaults=True)
tmp_file_d = tmp_path / 'config_d'
pyrallis.dump(d, tmp_file_d.open('w'), omit_defaults=True)

# b at second place overrides a
# both as python arguments
new_b = pyrallis.parse(config_class=SomeClass,
config_path=[tmp_file_a, tmp_file_b],
args="",
)
assert new_b == b

# both as commandline arguments
arguments = shlex.split(f"--config_path {tmp_file_a} {tmp_file_b}")
new_b = pyrallis.parse(config_class=SomeClass, args=arguments)
assert new_b == b

# mixed command line and python arguments
arguments = shlex.split(f"--config_path {tmp_file_b}")
new_b = pyrallis.parse(config_class=SomeClass, config_path=tmp_file_a, args=arguments,
commandline_overwrites=False)
assert new_b == b

# a at second place overrides b for some value only
# both as python arguments
new_c = pyrallis.parse(config_class=SomeClass,
config_path=[tmp_file_b, tmp_file_a],
args="",
)
assert new_c == c

# both as commandline arguments
arguments = shlex.split(f"--config_path {tmp_file_b} {tmp_file_a}")
new_c = pyrallis.parse(config_class=SomeClass, args=arguments)
assert new_c == c


# mixed command line and python arguments
arguments = shlex.split(f"--config_path {tmp_file_a}")
new_c = pyrallis.parse(config_class=SomeClass, config_path=tmp_file_b, args=arguments,
commandline_overwrites=False)
assert new_c == c

# merge files with mutually exclusive parameters
# both as python arguments
new_c = pyrallis.parse(config_class=SomeClass,
config_path=[tmp_file_a, tmp_file_d],
args="",
)
assert new_c == c

# commandline_overwrites = True
# mixed command line and python arguments with override
arguments = shlex.split(f"--config_path {tmp_file_b}")
with pytest.warns(UserWarning):
new_b = pyrallis.parse(config_class=SomeClass, config_path=tmp_file_a, args=arguments, commandline_overwrites=True)
assert new_b == b

# mixed command line and python arguments with override
arguments = shlex.split(f"--config_path {tmp_file_a}")
with pytest.warns(UserWarning):
new_a = pyrallis.parse(config_class=SomeClass, config_path=tmp_file_b, args=arguments, commandline_overwrites=True)
assert new_a == a