Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adding indicative error message to union decoding #29

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 26 additions & 11 deletions pyrallis/parsers/decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
is_enum,
ParsingError,
format_error,
has_generic_arg
has_generic_arg,
add_tab_to_new_lines
)

logger = getLogger(__name__)
Expand Down Expand Up @@ -195,19 +196,31 @@ def _decode_optional(val: Optional[Any]) -> Optional[T]:
return _decode_optional


def try_functions(*funcs: Callable[[Any], T]) -> Callable[[Any], Union[T, Any]]:
def try_decoding_functions(funcs: List[Callable[[Any], T]],
types: List[T]) -> Callable[[Any], Union[T, Any]]:
"""Tries to use the functions in succession, else returns the same value unchanged."""

def _try_functions(val: Any) -> Union[T, Any]:
for func in funcs:
assert len(funcs) == len(types), 'Each decoding function must have a unique type.'

def _try_decoding_functions(val: Any) -> Union[T, Any]:
error_messages: List[str] = []
for func, t in zip(funcs, types):
try:
return func(val)
except Exception:
except Exception as e:
error_message = f"-> Failed to parse as class {t.__name__}: {add_tab_to_new_lines(format_error(e))}"
error_messages.append(error_message)
continue
# If no function worked, raise an exception
raise TypeError(f"No valid parsing for value {val}")
class_names = [t.__name__ for t in types]
exception_message = f"Failed to parse value for multiple classes: " \
f"value={val}, classes={class_names}.\n" \
f"Got parsing errors:"
for error_message in error_messages:
exception_message += f'\n{error_message}'
raise TypeError(add_tab_to_new_lines(exception_message))

return _try_functions
return _try_decoding_functions


def decode_union(*types: Type[T]) -> Callable[[Any], Union[T, Any]]:
Expand All @@ -217,11 +230,14 @@ def decode_union(*types: Type[T]) -> Callable[[Any], Union[T, Any]]:
while type(None) in types:
types.remove(type(None))

if len(types) == 1: # There is a single not None class
return decode_optional(types[0])

decoding_fns: List[Callable[[Any], T]] = [
decode_optional(t) if optional else get_decoding_fn(t) for t in types
]
# Try using each of the non-None types, in succession. Worst case, return the value.
return try_functions(*decoding_fns)
return try_decoding_functions(funcs=decoding_fns, types=types)


def decode_list(t: Type[T]) -> Callable[[List[Any]], List[T]]:
Expand Down Expand Up @@ -314,11 +330,10 @@ def no_op(v: T) -> T:

def try_constructor(t: Type[T]) -> Callable[[Any], Union[T, Any]]:
""" Tries to use the type as a constructor. If that fails, returns the value as-is. """
return try_functions(lambda val: t(**val), lambda val: t(val))
funcs = [lambda val: t(**val), lambda val: t(val)]
return try_decoding_functions(funcs=funcs, types=[t, t])


from pathlib import Path

decode.register(Path, Path)


4 changes: 4 additions & 0 deletions pyrallis/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,10 @@ def format_error(e: Exception):
return f'Exception: {e}'


def add_tab_to_new_lines(text: str):
return text.replace('\n', '\n\t')


def is_generic_arg(arg):
try:
return arg.__name__ in ['KT', 'VT', 'T']
Expand Down
5 changes: 3 additions & 2 deletions tests/test_decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import yaml
import json
import toml

from pyrallis.utils import PyrallisException
from .testutils import *
Expand Down Expand Up @@ -49,7 +50,7 @@ class SomeClass:
new_b = pyrallis.parse(config_class=SomeClass, config_path=tmp_file, args="")
assert new_b == b

arguments = shlex.split(f"--config_path {tmp_file}")
arguments = ['--config_path', str(tmp_file)]
new_b = pyrallis.parse(config_class=SomeClass, args=arguments)
assert new_b == b

Expand Down Expand Up @@ -98,7 +99,7 @@ class SomeClass:

new_b = pyrallis.parse(config_class=SomeClass, config_path=tmp_file, args="")
assert new_b == b
arguments = shlex.split(f"--config_path {tmp_file}")
arguments = ['--config_path', str(tmp_file)]
new_b = pyrallis.parse(config_class=SomeClass, args=arguments)
assert new_b == b

Expand Down