diff --git a/benchmark.py b/benchmark.py index 0a2aa27..e7b5453 100644 --- a/benchmark.py +++ b/benchmark.py @@ -8,14 +8,16 @@ import cpuinfo import icontract + """Run benchmarks and, if specified, overwrite README.""" def benchmark_against_others(repo_root: pathlib.Path, overwrite: bool) -> None: """Run benchmars against other libraries and include them in the Readme.""" script_rel_paths = [ - 'benchmarks/against_others/compare_invariant.py', 'benchmarks/against_others/compare_precondition.py', - 'benchmarks/against_others/compare_postcondition.py' + "benchmarks/against_others/compare_invariant.py", + "benchmarks/against_others/compare_precondition.py", + "benchmarks/against_others/compare_postcondition.py", ] if not overwrite: @@ -24,58 +26,88 @@ def benchmark_against_others(repo_root: pathlib.Path, overwrite: bool) -> None: print() subprocess.check_call([sys.executable, str(repo_root / script_rel_path)]) else: - out = ['The following scripts were run:\n\n'] + out = ["The following scripts were run:\n\n"] for script_rel_path in script_rel_paths: - out.append('* `{0} `_\n'.format(script_rel_path)) - out.append('\n') - - out.append(('The benchmarks were executed on {}.\nWe used Python {}, ' - 'icontract {}, deal 4.23.3 and dpcontracts 0.6.0.\n\n').format(cpuinfo.get_cpu_info()['brand'], - platform.python_version(), - icontract.__version__)) - - out.append('The following tables summarize the results.\n\n') + out.append( + "* `{0} `_\n".format( + script_rel_path + ) + ) + out.append("\n") + + out.append( + ( + "The benchmarks were executed on {}.\nWe used Python {}, " + "icontract {}, deal 4.23.3 and dpcontracts 0.6.0.\n\n" + ).format( + cpuinfo.get_cpu_info()["brand"], + platform.python_version(), + icontract.__version__, + ) + ) + + out.append("The following tables summarize the results.\n\n") stdouts = [] # type: List[str] for script_rel_path in script_rel_paths: - stdout = subprocess.check_output([sys.executable, str(repo_root / script_rel_path)]).decode() + stdout = subprocess.check_output( + [sys.executable, str(repo_root / script_rel_path)] + ).decode() stdouts.append(stdout) out.append(stdout) - out.append('\n') + out.append("\n") readme_path = repo_root / "docs" / "source" / "benchmarks.rst" - readme = readme_path.read_text(encoding='utf-8') - marker_start = '.. Becnhmark report from benchmark.py starts.' - marker_end = '.. Benchmark report from benchmark.py ends.' + readme = readme_path.read_text(encoding="utf-8") + marker_start = ".. Becnhmark report from benchmark.py starts." + marker_end = ".. Benchmark report from benchmark.py ends." lines = readme.splitlines() try: index_start = lines.index(marker_start) except ValueError as exc: - raise ValueError('Could not find the marker for the benchmarks in the {}: {}'.format( - readme_path, marker_start)) from exc + raise ValueError( + "Could not find the marker for the benchmarks in the {}: {}".format( + readme_path, marker_start + ) + ) from exc try: index_end = lines.index(marker_end) except ValueError as exc: - raise ValueError('Could not find the start marker for the benchmarks in the {}: {}'.format( - readme_path, marker_end)) from exc - - assert index_start < index_end, 'Unexpected end marker before start marker for the benchmarks.' - - lines = lines[:index_start + 1] + ['\n'] + (''.join(out)).splitlines() + ['\n'] + lines[index_end:] - readme_path.write_text('\n'.join(lines) + '\n', encoding='utf-8') + raise ValueError( + "Could not find the start marker for the benchmarks in the {}: {}".format( + readme_path, marker_end + ) + ) from exc + + assert ( + index_start < index_end + ), "Unexpected end marker before start marker for the benchmarks." + + lines = ( + lines[: index_start + 1] + + ["\n"] + + ("".join(out)).splitlines() + + ["\n"] + + lines[index_end:] + ) + readme_path.write_text("\n".join(lines) + "\n", encoding="utf-8") # This is necessary so that the benchmarks do not complain on a Windows machine if the console encoding has not # been properly set. - sys.stdout.buffer.write(('\n\n'.join(stdouts) + '\n').encode('utf-8')) + sys.stdout.buffer.write(("\n\n".join(stdouts) + "\n").encode("utf-8")) def main() -> int: - """"Execute main routine.""" + """ "Execute main routine.""" parser = argparse.ArgumentParser(description=__doc__) - parser.add_argument("--overwrite", help="Overwrites the corresponding section in the docs.", action='store_true') + parser.add_argument( + "--overwrite", + help="Overwrites the corresponding section in the docs.", + action="store_true", + ) args = parser.parse_args() diff --git a/benchmarks/against_others/compare_invariant.py b/benchmarks/against_others/compare_invariant.py index 4fe75b6..e609dd3 100644 --- a/benchmarks/against_others/compare_invariant.py +++ b/benchmarks/against_others/compare_invariant.py @@ -22,7 +22,7 @@ def __init__(self, identifier: str) -> None: self.parts = identifier.split(".") def some_func(self) -> str: - return '.'.join(self.parts) + return ".".join(self.parts) @dpcontracts.invariant("some dummy invariant", lambda self: len(self.parts) > 0) @@ -31,7 +31,7 @@ def __init__(self, identifier: str) -> None: self.parts = identifier.split(".") def some_func(self) -> str: - return '.'.join(self.parts) + return ".".join(self.parts) @deal.inv(validator=lambda self: len(self.parts) > 0, message="some dummy invariant") @@ -40,7 +40,7 @@ def __init__(self, identifier: str) -> None: self.parts = identifier.split(".") def some_func(self) -> str: - return '.'.join(self.parts) + return ".".join(self.parts) class ClassWithInlineContract: @@ -50,7 +50,7 @@ def __init__(self, identifier: str) -> None: def some_func(self) -> str: assert len(self.parts) > 0 - result = '.'.join(self.parts) + result = ".".join(self.parts) assert len(self.parts) > 0 return result @@ -58,10 +58,10 @@ def some_func(self) -> str: # dpcontracts change __name__ attribute of the class, so we can not use # ClassWithDpcontractsInvariant.__name__ for a more maintainable list. clses = [ - 'ClassWithIcontract', - 'ClassWithDpcontracts', - 'ClassWithDeal', - 'ClassWithInlineContract', + "ClassWithIcontract", + "ClassWithDpcontracts", + "ClassWithDeal", + "ClassWithInlineContract", ] @@ -72,8 +72,8 @@ def writeln_utf8(text: str) -> None: We can not use ``print()`` as we can not rely on the correct encoding in Windows. See: https://stackoverflow.com/questions/31469707/changing-the-locale-preferred-encoding-in-python-3-in-windows """ - sys.stdout.buffer.write(text.encode('utf-8')) - sys.stdout.buffer.write(os.linesep.encode('utf-8')) + sys.stdout.buffer.write(text.encode("utf-8")) + sys.stdout.buffer.write(os.linesep.encode("utf-8")) def measure_invariant_at_init() -> None: @@ -82,7 +82,11 @@ def measure_invariant_at_init() -> None: number = 1 * 1000 * 1000 for i, cls in enumerate(clses): - duration = timeit.timeit("{}('X.Y')".format(cls), setup="from __main__ import {}".format(cls), number=number) + duration = timeit.timeit( + "{}('X.Y')".format(cls), + setup="from __main__ import {}".format(cls), + number=number, + ) durations[i] = duration writeln_utf8("Benchmarking invariant at __init__:\n") @@ -116,7 +120,10 @@ def measure_invariant_at_function() -> None: for i, cls in enumerate(clses): duration = timeit.timeit( - "a.some_func()", setup="from __main__ import {0}; a = {0}('X.Y')".format(cls), number=number) + "a.some_func()", + setup="from __main__ import {0}; a = {0}('X.Y')".format(cls), + number=number, + ) durations[i] = duration writeln_utf8("Benchmarking invariant at a function:\n") @@ -145,5 +152,5 @@ def measure_invariant_at_function() -> None: if __name__ == "__main__": measure_invariant_at_init() - writeln_utf8('') + writeln_utf8("") measure_invariant_at_function() diff --git a/benchmarks/against_others/compare_postcondition.py b/benchmarks/against_others/compare_postcondition.py index aaa638e..0856c85 100644 --- a/benchmarks/against_others/compare_postcondition.py +++ b/benchmarks/against_others/compare_postcondition.py @@ -49,14 +49,17 @@ def writeln_utf8(text: str) -> None: We can not use ``print()`` as we can not rely on the correct encoding in Windows. See: https://stackoverflow.com/questions/31469707/changing-the-locale-preferred-encoding-in-python-3-in-windows """ - sys.stdout.buffer.write(text.encode('utf-8')) - sys.stdout.buffer.write(os.linesep.encode('utf-8')) + sys.stdout.buffer.write(text.encode("utf-8")) + sys.stdout.buffer.write(os.linesep.encode("utf-8")) def measure_functions() -> None: funcs = [ - 'function_with_icontract', 'function_with_dpcontracts', 'function_with_deal_post', 'function_with_deal_ensure', - 'function_with_inline_contract' + "function_with_icontract", + "function_with_dpcontracts", + "function_with_deal_post", + "function_with_deal_ensure", + "function_with_inline_contract", ] durations = [0.0] * len(funcs) @@ -64,7 +67,11 @@ def measure_functions() -> None: number = 10 * 1000 for i, func in enumerate(funcs): - duration = timeit.timeit("{}(198.4)".format(func), setup="from __main__ import {}".format(func), number=number) + duration = timeit.timeit( + "{}(198.4)".format(func), + setup="from __main__ import {}".format(func), + number=number, + ) durations[i] = duration table = [] # type: List[List[str]] @@ -92,5 +99,5 @@ def measure_functions() -> None: if __name__ == "__main__": writeln_utf8("Benchmarking postcondition:") - writeln_utf8('') + writeln_utf8("") measure_functions() diff --git a/benchmarks/against_others/compare_precondition.py b/benchmarks/against_others/compare_precondition.py index 46aa7af..ac423dc 100644 --- a/benchmarks/against_others/compare_precondition.py +++ b/benchmarks/against_others/compare_precondition.py @@ -33,7 +33,7 @@ def function_with_deal(some_arg: int) -> float: def function_with_inline_contract(some_arg: int) -> float: - assert (some_arg > 0) + assert some_arg > 0 return math.sqrt(some_arg) @@ -48,13 +48,16 @@ def writeln_utf8(text: str) -> None: We can not use ``print()`` as we can not rely on the correct encoding in Windows. See: https://stackoverflow.com/questions/31469707/changing-the-locale-preferred-encoding-in-python-3-in-windows """ - sys.stdout.buffer.write(text.encode('utf-8')) - sys.stdout.buffer.write(os.linesep.encode('utf-8')) + sys.stdout.buffer.write(text.encode("utf-8")) + sys.stdout.buffer.write(os.linesep.encode("utf-8")) def measure_functions() -> None: funcs = [ - 'function_with_icontract', 'function_with_dpcontracts', 'function_with_deal', 'function_with_inline_contract' + "function_with_icontract", + "function_with_dpcontracts", + "function_with_deal", + "function_with_inline_contract", ] durations = [0.0] * len(funcs) @@ -62,7 +65,11 @@ def measure_functions() -> None: number = 10 * 1000 for i, func in enumerate(funcs): - duration = timeit.timeit("{}(198.4)".format(func), setup="from __main__ import {}".format(func), number=number) + duration = timeit.timeit( + "{}(198.4)".format(func), + setup="from __main__ import {}".format(func), + number=number, + ) durations[i] = duration table = [] # type: List[List[str]] @@ -90,5 +97,5 @@ def measure_functions() -> None: if __name__ == "__main__": writeln_utf8("Benchmarking precondition:") - writeln_utf8('') + writeln_utf8("") measure_functions() diff --git a/benchmarks/import_cost/generate.py b/benchmarks/import_cost/generate.py index 2c3d650..6b6a8ae 100755 --- a/benchmarks/import_cost/generate.py +++ b/benchmarks/import_cost/generate.py @@ -22,7 +22,9 @@ def generate_functions(functions: int, contracts: int, disabled: bool) -> str: if not disabled: out.write("@icontract.require(lambda x: x > {})\n".format(j)) else: - out.write("@icontract.require(lambda x: x > {}, enabled=False)\n".format(j)) + out.write( + "@icontract.require(lambda x: x > {}, enabled=False)\n".format(j) + ) out.write("def some_func{}(x: int) -> None:\n pass\n".format(i)) @@ -42,25 +44,36 @@ def generate_classes(classes: int, invariants: int, disabled: bool) -> str: if not disabled: out.write("@icontract.invariant(lambda self: self.x > {})\n".format(j)) else: - out.write("@icontract.invariant(lambda self: self.x > {}, enabled=False)\n".format(j)) + out.write( + "@icontract.invariant(lambda self: self.x > {}, enabled=False)\n".format( + j + ) + ) out.write( - textwrap.dedent("""\ + textwrap.dedent( + """\ class SomeClass{}: def __init__(self) -> None: self.x = 100 def some_func(self) -> None: pass - """.format(i))) + """.format( + i + ) + ) + ) return out.getvalue() def main() -> None: - """"Execute the main routine.""" + """ "Execute the main routine.""" parser = argparse.ArgumentParser(description=__doc__) - parser.add_argument("--outdir", help="output directory", default=os.path.dirname(__file__)) + parser.add_argument( + "--outdir", help="output directory", default=os.path.dirname(__file__) + ) args = parser.parse_args() outdir = pathlib.Path(args.outdir) @@ -82,7 +95,9 @@ def main() -> None: if contracts == 1: pth = outdir / "functions_100_with_1_disabled_contract.py" else: - pth = outdir / "functions_100_with_{}_disabled_contracts.py".format(contracts) + pth = outdir / "functions_100_with_{}_disabled_contracts.py".format( + contracts + ) text = generate_functions(functions=100, contracts=contracts, disabled=True) pth.write_text(text) @@ -102,7 +117,9 @@ def main() -> None: if invariants == 1: pth = outdir / "classes_100_with_1_disabled_invariant.py" else: - pth = outdir / "classes_100_with_{}_disabled_invariants.py".format(invariants) + pth = outdir / "classes_100_with_{}_disabled_invariants.py".format( + invariants + ) text = generate_classes(classes=100, invariants=invariants, disabled=True) pth.write_text(text) diff --git a/benchmarks/import_cost/measure.py b/benchmarks/import_cost/measure.py index c5b3761..2d7df4a 100755 --- a/benchmarks/import_cost/measure.py +++ b/benchmarks/import_cost/measure.py @@ -10,7 +10,7 @@ def main() -> None: - """"Execute the main routine.""" + """ "Execute the main routine.""" parser = argparse.ArgumentParser(description=__doc__) parser.add_argument( "--module", @@ -31,78 +31,93 @@ def main() -> None: "classes_100_with_5_disabled_invariants", "classes_100_with_10_disabled_invariants", ], - required=True) + required=True, + ) args = parser.parse_args() a_module = str(args.module) if a_module == "functions_100_with_no_contract": start = time.time() import functions_100_with_no_contract + print(time.time() - start) elif a_module == "functions_100_with_1_contract": start = time.time() import functions_100_with_1_contract + print(time.time() - start) elif a_module == "functions_100_with_5_contracts": start = time.time() import functions_100_with_5_contracts + print(time.time() - start) elif a_module == "functions_100_with_10_contracts": start = time.time() import functions_100_with_10_contracts + print(time.time() - start) elif a_module == "functions_100_with_1_disabled_contract": start = time.time() import functions_100_with_1_disabled_contract + print(time.time() - start) elif a_module == "functions_100_with_5_disabled_contracts": start = time.time() import functions_100_with_5_disabled_contracts + print(time.time() - start) elif a_module == "functions_100_with_10_disabled_contracts": start = time.time() import functions_100_with_10_disabled_contracts + print(time.time() - start) elif a_module == "classes_100_with_no_invariant": start = time.time() import classes_100_with_no_invariant + print(time.time() - start) elif a_module == "classes_100_with_1_invariant": start = time.time() import classes_100_with_1_invariant + print(time.time() - start) elif a_module == "classes_100_with_5_invariants": start = time.time() import classes_100_with_5_invariants + print(time.time() - start) elif a_module == "classes_100_with_10_invariants": start = time.time() import classes_100_with_10_invariants + print(time.time() - start) elif a_module == "classes_100_with_1_disabled_invariant": start = time.time() import classes_100_with_1_disabled_invariant + print(time.time() - start) elif a_module == "classes_100_with_5_disabled_invariants": start = time.time() import classes_100_with_5_disabled_invariants + print(time.time() - start) elif a_module == "classes_100_with_10_disabled_invariants": start = time.time() import classes_100_with_10_disabled_invariants + print(time.time() - start) else: diff --git a/benchmarks/import_cost/runme.py b/benchmarks/import_cost/runme.py index 7a82834..11f0af4 100755 --- a/benchmarks/import_cost/runme.py +++ b/benchmarks/import_cost/runme.py @@ -7,7 +7,7 @@ def main() -> None: - """"Execute the main routine.""" + """ "Execute the main routine.""" modules = [ "functions_100_with_no_contract", "functions_100_with_1_contract", @@ -29,13 +29,20 @@ def main() -> None: durations = [] # type: List[float] for i in range(0, 10): duration = float( - subprocess.check_output(["./measure.py", "--module", a_module], cwd=os.path.dirname(__file__)).strip()) + subprocess.check_output( + ["./measure.py", "--module", a_module], + cwd=os.path.dirname(__file__), + ).strip() + ) durations.append(duration) - print("Duration to import the module {} (in milliseconds): {:.2f} ± {:.2f}".format( - a_module, - statistics.mean(durations) * 10e3, - statistics.stdev(durations) * 10e3)) + print( + "Duration to import the module {} (in milliseconds): {:.2f} ± {:.2f}".format( + a_module, + statistics.mean(durations) * 10e3, + statistics.stdev(durations) * 10e3, + ) + ) if __name__ == "__main__": diff --git a/icontract/__init__.py b/icontract/__init__.py index bf8bec5..51089fd 100644 --- a/icontract/__init__.py +++ b/icontract/__init__.py @@ -8,11 +8,11 @@ # imports in setup.py. # Don't forget to update the version in __init__.py and CHANGELOG.rst! -__version__ = '2.6.2' -__author__ = 'Marko Ristin' -__copyright__ = 'Copyright 2019 Parquery AG' -__license__ = 'MIT' -__status__ = 'Production' +__version__ = "2.6.2" +__author__ = "Marko Ristin" +__copyright__ = "Copyright 2019 Parquery AG" +__license__ = "MIT" +__status__ = "Production" # pylint: disable=invalid-name # pylint: disable=wrong-import-position @@ -24,22 +24,27 @@ # https://stackoverflow.com/questions/44344327/cant-make-mypy-work-with-init-py-aliases import icontract._decorators + require = icontract._decorators.require snapshot = icontract._decorators.snapshot ensure = icontract._decorators.ensure invariant = icontract._decorators.invariant import icontract._globals + aRepr = icontract._globals.aRepr SLOW = icontract._globals.SLOW import icontract._metaclass + DBCMeta = icontract._metaclass.DBCMeta DBC = icontract._metaclass.DBC import icontract._types + _Contract = icontract._types.Contract _Snapshot = icontract._types.Snapshot import icontract.errors + ViolationError = icontract.errors.ViolationError diff --git a/icontract/_checkers.py b/icontract/_checkers.py index 8869cc1..513d52c 100644 --- a/icontract/_checkers.py +++ b/icontract/_checkers.py @@ -2,8 +2,19 @@ import contextvars import functools import inspect -from typing import Callable, Any, Iterable, Optional, Tuple, List, Mapping, \ - MutableMapping, Dict, cast, Set +from typing import ( + Callable, + Any, + Iterable, + Optional, + Tuple, + List, + Mapping, + MutableMapping, + Dict, + cast, + Set, +) import icontract._represent from icontract._globals import CallableT, ClassT @@ -15,7 +26,7 @@ # pylint: disable=raising-bad-type -def _walk_decorator_stack(func: CallableT) -> Iterable['CallableT']: +def _walk_decorator_stack(func: CallableT) -> Iterable["CallableT"]: """ Iterate through the stack of decorated functions until the original function. @@ -33,14 +44,20 @@ def find_checker(func: CallableT) -> Optional[CallableT]: """Iterate through the decorator stack till we find the contract checker.""" contract_checker = None # type: Optional[CallableT] for a_wrapper in _walk_decorator_stack(func): - if hasattr(a_wrapper, "__preconditions__") or hasattr(a_wrapper, "__postconditions__"): + if hasattr(a_wrapper, "__preconditions__") or hasattr( + a_wrapper, "__postconditions__" + ): contract_checker = a_wrapper return contract_checker -def kwargs_from_call(param_names: List[str], kwdefaults: Dict[str, Any], args: Tuple[Any, ...], - kwargs: Dict[str, Any]) -> MutableMapping[str, Any]: +def kwargs_from_call( + param_names: List[str], + kwdefaults: Dict[str, Any], + args: Tuple[Any, ...], + kwargs: Dict[str, Any], +) -> MutableMapping[str, Any]: """ Inspect the input values received at the wrapper for the actual function call. @@ -57,7 +74,7 @@ def kwargs_from_call(param_names: List[str], kwdefaults: Dict[str, Any], args: T # (*e.g.*, when the contracts do not need them or don't use any argument at all). # We need to have a concrete issue where profiling helps us determine if this is a real # bottleneck or not and not optimize for no real benefit. - resolved_kwargs = {'_ARGS': args, '_KWARGS': kwargs} + resolved_kwargs = {"_ARGS": args, "_KWARGS": kwargs} # Set the default argument values as condition parameters. for param_name, param_value in kwdefaults.items(): @@ -99,12 +116,14 @@ def not_check(check: Any, contract: Contract) -> bool: if contract.location is not None: msg_parts.append("{}:\n".format(contract.location)) - msg_parts.append('Failed to negate the evaluation of the condition.') + msg_parts.append("Failed to negate the evaluation of the condition.") - raise ValueError(''.join(msg_parts)) from err + raise ValueError("".join(msg_parts)) from err -def select_condition_kwargs(contract: Contract, resolved_kwargs: Mapping[str, Any]) -> Mapping[str, Any]: +def select_condition_kwargs( + contract: Contract, resolved_kwargs: Mapping[str, Any] +) -> Mapping[str, Any]: """ Select the keyword arguments that are used by the contract. @@ -114,24 +133,34 @@ def select_condition_kwargs(contract: Contract, resolved_kwargs: Mapping[str, An :return: a subset of resolved_kwargs """ # Check that all arguments to the condition function have been set. - missing_args = [arg_name for arg_name in contract.mandatory_args if arg_name not in resolved_kwargs] + missing_args = [ + arg_name + for arg_name in contract.mandatory_args + if arg_name not in resolved_kwargs + ] if missing_args: msg_parts = [] # type: List[str] if contract.location is not None: msg_parts.append("{}:\n".format(contract.location)) msg_parts.append( - ("The argument(s) of the contract condition have not been set: {}. " - "Does the original function define them? Did you supply them in the call?").format(missing_args)) + ( + "The argument(s) of the contract condition have not been set: {}. " + "Does the original function define them? Did you supply them in the call?" + ).format(missing_args) + ) - if 'OLD' in missing_args: - msg_parts.append(' Did you decorate the function with a snapshot to capture OLD values?') + if "OLD" in missing_args: + msg_parts.append( + " Did you decorate the function with a snapshot to capture OLD values?" + ) - raise TypeError(''.join(msg_parts)) + raise TypeError("".join(msg_parts)) condition_kwargs = { arg_name: value - for arg_name, value in resolved_kwargs.items() if arg_name in contract.condition_arg_set + for arg_name, value in resolved_kwargs.items() + if arg_name in contract.condition_arg_set } return condition_kwargs @@ -139,18 +168,24 @@ def select_condition_kwargs(contract: Contract, resolved_kwargs: Mapping[str, An def _assert_no_invalid_kwargs(kwargs: Any) -> Optional[TypeError]: """Check that kwargs of a function contain no unexpected arguments.""" - if '_ARGS' in kwargs: - return TypeError('The arguments of the function call include "_ARGS" which is ' - 'a placeholder for positional arguments in a condition.') - - if '_KWARGS' in kwargs: - return TypeError('The arguments of the function call include "_KWARGS" which is ' - 'a placeholder for keyword arguments in a condition.') + if "_ARGS" in kwargs: + return TypeError( + 'The arguments of the function call include "_ARGS" which is ' + "a placeholder for positional arguments in a condition." + ) + + if "_KWARGS" in kwargs: + return TypeError( + 'The arguments of the function call include "_KWARGS" which is ' + "a placeholder for keyword arguments in a condition." + ) return None -def _unpack_pre_snap_posts(wrapper: CallableT) -> Tuple[List[List[Contract]], List[Snapshot], List[Contract]]: +def _unpack_pre_snap_posts( + wrapper: CallableT, +) -> Tuple[List[List[Contract]], List[Snapshot], List[Contract]]: """Retrieve the preconditions, snapshots and postconditions defined for the given wrapper checker.""" preconditions = getattr(wrapper, "__preconditions__") # type: List[List[Contract]] snapshots = getattr(wrapper, "__postcondition_snapshots__") # type: List[Snapshot] @@ -159,26 +194,35 @@ def _unpack_pre_snap_posts(wrapper: CallableT) -> Tuple[List[List[Contract]], Li return preconditions, snapshots, postconditions -def _assert_resolved_kwargs_valid(postconditions: List[Contract], - resolved_kwargs: Mapping[str, Any]) -> Optional[TypeError]: +def _assert_resolved_kwargs_valid( + postconditions: List[Contract], resolved_kwargs: Mapping[str, Any] +) -> Optional[TypeError]: """Check that the resolved kwargs of a decorated function are valid.""" if postconditions: - if 'result' in resolved_kwargs: - return TypeError("Unexpected argument 'result' in a function decorated with postconditions.") + if "result" in resolved_kwargs: + return TypeError( + "Unexpected argument 'result' in a function decorated with postconditions." + ) - if 'OLD' in resolved_kwargs: - return TypeError("Unexpected argument 'OLD' in a function decorated with postconditions.") + if "OLD" in resolved_kwargs: + return TypeError( + "Unexpected argument 'OLD' in a function decorated with postconditions." + ) return None -def _create_violation_error(contract: Contract, resolved_kwargs: Mapping[str, Any]) -> BaseException: +def _create_violation_error( + contract: Contract, resolved_kwargs: Mapping[str, Any] +) -> BaseException: """Create the violation error based on the violated contract.""" exception = None # type: Optional[BaseException] if contract.error is None: try: - msg = icontract._represent.generate_message(contract=contract, resolved_kwargs=resolved_kwargs) + msg = icontract._represent.generate_message( + contract=contract, resolved_kwargs=resolved_kwargs + ) except Exception as err: parts = ["Failed to recompute the values of the contract condition:\n"] if contract.location is not None: @@ -187,45 +231,62 @@ def _create_violation_error(contract: Contract, resolved_kwargs: Mapping[str, An if contract.description is not None: parts.append("{}: ".format(contract.description)) - parts.append(icontract._represent.represent_condition(condition=contract.condition)) + parts.append( + icontract._represent.represent_condition(condition=contract.condition) + ) - raise RuntimeError(''.join(parts)) from err + raise RuntimeError("".join(parts)) from err exception = ViolationError(msg) elif inspect.ismethod(contract.error) or inspect.isfunction(contract.error): - assert contract.error_arg_set is not None, ("Expected error_arg_set non-None if contract.error a function.") - assert contract.error_args is not None, ("Expected error_args non-None if contract.error a function.") + assert ( + contract.error_arg_set is not None + ), "Expected error_arg_set non-None if contract.error a function." + assert ( + contract.error_args is not None + ), "Expected error_args non-None if contract.error a function." - error_kwargs = select_error_kwargs(contract=contract, resolved_kwargs=resolved_kwargs) + error_kwargs = select_error_kwargs( + contract=contract, resolved_kwargs=resolved_kwargs + ) exception = cast(BaseException, contract.error(**error_kwargs)) if not isinstance(exception, BaseException): raise TypeError( "The exception returned by the contract's error {} does not inherit from BaseException.".format( - contract.error)) + contract.error + ) + ) elif isinstance(contract.error, type): if not issubclass(contract.error, BaseException): raise TypeError( "The exception class supplied in the contract's error {} is not a subclass of BaseException.".format( - contract.error)) + contract.error + ) + ) - msg = icontract._represent.generate_message(contract=contract, resolved_kwargs=resolved_kwargs) + msg = icontract._represent.generate_message( + contract=contract, resolved_kwargs=resolved_kwargs + ) exception = contract.error(msg) elif isinstance(contract.error, BaseException): exception = contract.error else: raise NotImplementedError( - ("icontract does not know how to handle the error of type {} " - "(expected a function, a subclass of BaseException or an instance of BaseException)").format( - type(contract.error))) + ( + "icontract does not know how to handle the error of type {} " + "(expected a function, a subclass of BaseException or an instance of BaseException)" + ).format(type(contract.error)) + ) assert exception is not None return exception -async def _assert_preconditions_async(preconditions: List[List[Contract]], - resolved_kwargs: Mapping[str, Any]) -> Optional[BaseException]: +async def _assert_preconditions_async( + preconditions: List[List[Contract]], resolved_kwargs: Mapping[str, Any] +) -> Optional[BaseException]: """Assert that the preconditions of an async function hold.""" exception = None # type: Optional[BaseException] @@ -236,9 +297,13 @@ async def _assert_preconditions_async(preconditions: List[List[Contract]], exception = None for contract in group: - assert exception is None, "No exception as long as pre-condition group is satisfiable." + assert ( + exception is None + ), "No exception as long as pre-condition group is satisfiable." - condition_kwargs = select_condition_kwargs(contract=contract, resolved_kwargs=resolved_kwargs) + condition_kwargs = select_condition_kwargs( + contract=contract, resolved_kwargs=resolved_kwargs + ) if inspect.iscoroutinefunction(contract.condition): check = await contract.condition(**condition_kwargs) @@ -250,7 +315,9 @@ async def _assert_preconditions_async(preconditions: List[List[Contract]], check = check_or_coroutine if not_check(check=check, contract=contract): - exception = _create_violation_error(contract=contract, resolved_kwargs=resolved_kwargs) + exception = _create_violation_error( + contract=contract, resolved_kwargs=resolved_kwargs + ) break # The group of preconditions was satisfied, no need to check the other groups. @@ -260,8 +327,11 @@ async def _assert_preconditions_async(preconditions: List[List[Contract]], return exception -def _assert_preconditions(preconditions: List[List[Contract]], resolved_kwargs: Mapping[str, Any], - func: CallableT) -> Optional[BaseException]: +def _assert_preconditions( + preconditions: List[List[Contract]], + resolved_kwargs: Mapping[str, Any], + func: CallableT, +) -> Optional[BaseException]: """Assert that the preconditions of a sync function hold.""" exception = None # type: Optional[BaseException] @@ -272,22 +342,34 @@ def _assert_preconditions(preconditions: List[List[Contract]], resolved_kwargs: exception = None for contract in group: - assert exception is None, "No exception as long as pre-condition group is satisfiable." + assert ( + exception is None + ), "No exception as long as pre-condition group is satisfiable." - condition_kwargs = select_condition_kwargs(contract=contract, resolved_kwargs=resolved_kwargs) + condition_kwargs = select_condition_kwargs( + contract=contract, resolved_kwargs=resolved_kwargs + ) if inspect.iscoroutinefunction(contract.condition): - raise ValueError("Unexpected coroutine (async) condition {} for a sync function {}.".format( - contract.condition, func)) + raise ValueError( + "Unexpected coroutine (async) condition {} for a sync function {}.".format( + contract.condition, func + ) + ) check = contract.condition(**condition_kwargs) if inspect.iscoroutine(check): - raise ValueError("Unexpected coroutine resulting from the condition {} for a sync function {}.".format( - contract.condition, func)) + raise ValueError( + "Unexpected coroutine resulting from the condition {} for a sync function {}.".format( + contract.condition, func + ) + ) if not_check(check=check, contract=contract): - exception = _create_violation_error(contract=contract, resolved_kwargs=resolved_kwargs) + exception = _create_violation_error( + contract=contract, resolved_kwargs=resolved_kwargs + ) break # The group of preconditions was satisfied, no need to check the other groups. @@ -297,16 +379,22 @@ def _assert_preconditions(preconditions: List[List[Contract]], resolved_kwargs: return exception -async def _capture_old_async(snapshots: List[Snapshot], resolved_kwargs: Mapping[str, Any]) -> 'Old': +async def _capture_old_async( + snapshots: List[Snapshot], resolved_kwargs: Mapping[str, Any] +) -> "Old": """Capture all snapshots of an async function and return the captured values bundled in an ``Old``.""" old_as_mapping = dict() # type: MutableMapping[str, Any] for snap in snapshots: # This assert is just a last defense. # Conflicting snapshot names should have been caught before, either during the decoration or # in the meta-class. - assert snap.name not in old_as_mapping, "Snapshots with the conflicting name: {}" + assert ( + snap.name not in old_as_mapping + ), "Snapshots with the conflicting name: {}" - capture_kwargs = select_capture_kwargs(a_snapshot=snap, resolved_kwargs=resolved_kwargs) + capture_kwargs = select_capture_kwargs( + a_snapshot=snap, resolved_kwargs=resolved_kwargs + ) if inspect.iscoroutinefunction(snap.capture): old_as_mapping[snap.name] = await snap.capture(**capture_kwargs) @@ -322,39 +410,56 @@ async def _capture_old_async(snapshots: List[Snapshot], resolved_kwargs: Mapping return Old(mapping=old_as_mapping) -def _capture_old(snapshots: List[Snapshot], resolved_kwargs: Mapping[str, Any], func: CallableT) -> 'Old': +def _capture_old( + snapshots: List[Snapshot], resolved_kwargs: Mapping[str, Any], func: CallableT +) -> "Old": """Capture all snapshots of a sync function and return the captured values bundled in an ``Old``.""" old_as_mapping = dict() # type: MutableMapping[str, Any] for snap in snapshots: # This assert is just a last defense. # Conflicting snapshot names should have been caught before, either during the decoration or # in the meta-class. - assert snap.name not in old_as_mapping, "Snapshots with the conflicting name: {}" + assert ( + snap.name not in old_as_mapping + ), "Snapshots with the conflicting name: {}" if inspect.iscoroutinefunction(snap.capture): - raise ValueError("Unexpected coroutine (async) snapshot capture {} for a sync function {}.".format( - snap.capture, func)) + raise ValueError( + "Unexpected coroutine (async) snapshot capture {} for a sync function {}.".format( + snap.capture, func + ) + ) - capture_kwargs = select_capture_kwargs(a_snapshot=snap, resolved_kwargs=resolved_kwargs) + capture_kwargs = select_capture_kwargs( + a_snapshot=snap, resolved_kwargs=resolved_kwargs + ) captured = snap.capture(**capture_kwargs) if inspect.iscoroutine(captured): - raise ValueError(("Unexpected coroutine resulting from the snapshot capture {} " - "of a sync function {}.").format(snap.capture, func)) + raise ValueError( + ( + "Unexpected coroutine resulting from the snapshot capture {} " + "of a sync function {}." + ).format(snap.capture, func) + ) old_as_mapping[snap.name] = captured return Old(mapping=old_as_mapping) -async def _assert_postconditions_async(postconditions: List[Contract], - resolved_kwargs: Mapping[str, Any]) -> Optional[BaseException]: +async def _assert_postconditions_async( + postconditions: List[Contract], resolved_kwargs: Mapping[str, Any] +) -> Optional[BaseException]: """Assert that the postconditions of an async function hold.""" - assert 'result' in resolved_kwargs, \ - "Expected 'result' to be already set in resolved kwargs before calling this function." + assert ( + "result" in resolved_kwargs + ), "Expected 'result' to be already set in resolved kwargs before calling this function." for contract in postconditions: - condition_kwargs = select_condition_kwargs(contract=contract, resolved_kwargs=resolved_kwargs) + condition_kwargs = select_condition_kwargs( + contract=contract, resolved_kwargs=resolved_kwargs + ) if inspect.iscoroutinefunction(contract.condition): check = await contract.condition(**condition_kwargs) @@ -366,34 +471,48 @@ async def _assert_postconditions_async(postconditions: List[Contract], check = check_or_coroutine if not_check(check=check, contract=contract): - exception = _create_violation_error(contract=contract, resolved_kwargs=resolved_kwargs) + exception = _create_violation_error( + contract=contract, resolved_kwargs=resolved_kwargs + ) return exception return None -def _assert_postconditions(postconditions: List[Contract], resolved_kwargs: Mapping[str, Any], - func: CallableT) -> Optional[BaseException]: +def _assert_postconditions( + postconditions: List[Contract], resolved_kwargs: Mapping[str, Any], func: CallableT +) -> Optional[BaseException]: """Assert that the postconditions of a sync function hold.""" - assert 'result' in resolved_kwargs, \ - "Expected 'result' to be already set in resolved kwargs before calling this function." + assert ( + "result" in resolved_kwargs + ), "Expected 'result' to be already set in resolved kwargs before calling this function." for contract in postconditions: if inspect.iscoroutinefunction(contract.condition): - raise ValueError("Unexpected coroutine (async) condition {} for a sync function {}.".format( - contract.condition, func)) + raise ValueError( + "Unexpected coroutine (async) condition {} for a sync function {}.".format( + contract.condition, func + ) + ) - condition_kwargs = select_condition_kwargs(contract=contract, resolved_kwargs=resolved_kwargs) + condition_kwargs = select_condition_kwargs( + contract=contract, resolved_kwargs=resolved_kwargs + ) check = contract.condition(**condition_kwargs) if inspect.iscoroutine(check): - raise ValueError("Unexpected coroutine resulting from the condition {} for a sync function {}.".format( - contract.condition, func)) + raise ValueError( + "Unexpected coroutine resulting from the condition {} for a sync function {}.".format( + contract.condition, func + ) + ) if not_check(check=check, contract=contract): - exception = _create_violation_error(contract=contract, resolved_kwargs=resolved_kwargs) + exception = _create_violation_error( + contract=contract, resolved_kwargs=resolved_kwargs + ) return exception @@ -402,16 +521,20 @@ def _assert_postconditions(postconditions: List[Contract], resolved_kwargs: Mapp def _assert_invariant(contract: Contract, instance: Any) -> None: """Assert that the contract holds as a class invariant given the instance of the class.""" - if 'self' in contract.condition_arg_set: + if "self" in contract.condition_arg_set: check = contract.condition(self=instance) else: check = contract.condition() if not_check(check=check, contract=contract): - raise _create_violation_error(contract=contract, resolved_kwargs={'self': instance}) + raise _create_violation_error( + contract=contract, resolved_kwargs={"self": instance} + ) -def select_capture_kwargs(a_snapshot: Snapshot, resolved_kwargs: Mapping[str, Any]) -> Mapping[str, Any]: +def select_capture_kwargs( + a_snapshot: Snapshot, resolved_kwargs: Mapping[str, Any] +) -> Mapping[str, Any]: """ Select the keyword arguments that are used by the snapshot capture. @@ -419,22 +542,33 @@ def select_capture_kwargs(a_snapshot: Snapshot, resolved_kwargs: Mapping[str, An :param resolved_kwargs: resolved keyword arguments (including the default values) :return: a subset of resolved_kwargs """ - missing_args = [arg_name for arg_name in a_snapshot.args if arg_name not in resolved_kwargs] + missing_args = [ + arg_name for arg_name in a_snapshot.args if arg_name not in resolved_kwargs + ] if missing_args: msg_parts = [] if a_snapshot.location is not None: msg_parts.append("{}:\n".format(a_snapshot.location)) msg_parts.append( - ("The argument(s) of the snapshot have not been set: {}. " - "Does the original function define them? Did you supply them in the call?").format(missing_args)) - - raise TypeError(''.join(msg_parts)) - - return {arg_name: arg_value for arg_name, arg_value in resolved_kwargs.items() if arg_name in a_snapshot.arg_set} + ( + "The argument(s) of the snapshot have not been set: {}. " + "Does the original function define them? Did you supply them in the call?" + ).format(missing_args) + ) + + raise TypeError("".join(msg_parts)) + + return { + arg_name: arg_value + for arg_name, arg_value in resolved_kwargs.items() + if arg_name in a_snapshot.arg_set + } -def select_error_kwargs(contract: Contract, resolved_kwargs: Mapping[str, Any]) -> Mapping[str, Any]: +def select_error_kwargs( + contract: Contract, resolved_kwargs: Mapping[str, Any] +) -> Mapping[str, Any]: """ Select the keyword arguments that are used by the error creator of the contract. @@ -447,20 +581,26 @@ def select_error_kwargs(contract: Contract, resolved_kwargs: Mapping[str, Any]) error_kwargs = { arg_name: value - for arg_name, value in resolved_kwargs.items() if arg_name in contract.error_arg_set + for arg_name, value in resolved_kwargs.items() + if arg_name in contract.error_arg_set } - missing_args = [arg_name for arg_name in contract.error_args if arg_name not in resolved_kwargs] + missing_args = [ + arg_name for arg_name in contract.error_args if arg_name not in resolved_kwargs + ] if missing_args: msg_parts = [] # type: List[str] if contract.location is not None: msg_parts.append("{}:\n".format(contract.location)) msg_parts.append( - ("The argument(s) of the contract error have not been set: {}. " - "Does the original function define them? Did you supply them in the call?").format(missing_args)) + ( + "The argument(s) of the contract error have not been set: {}. " + "Does the original function define them? Did you supply them in the call?" + ).format(missing_args) + ) - raise TypeError(''.join(msg_parts)) + raise TypeError("".join(msg_parts)) return error_kwargs @@ -478,8 +618,12 @@ def __init__(self, mapping: Mapping[str, Any]) -> None: def __getattr__(self, item: str) -> Any: """Raise an error as this ``item`` should not be in the ``__dict__``.""" - raise AttributeError("The snapshot with the name {!r} is not available in the OLD of a postcondition. " - "Have you decorated the function with a corresponding snapshot decorator?".format(item)) + raise AttributeError( + "The snapshot with the name {!r} is not available in the OLD of a postcondition. " + "Have you decorated the function with a corresponding snapshot decorator?".format( + item + ) + ) def __repr__(self) -> str: """Represent the old values with a string literal as user is unaware of the class.""" @@ -502,31 +646,38 @@ def resolve_kwdefaults(sign: inspect.Signature) -> Dict[str, Any]: # contract checking is already in progress. # # The key refers to the id() of the function (preconditions and postconditions) or instance (invariants). -_IN_PROGRESS = contextvars.ContextVar("_IN_PROGRESS", default=None) # type: contextvars.ContextVar[Optional[Set[int]]] +_IN_PROGRESS = contextvars.ContextVar( + "_IN_PROGRESS", default=None +) # type: contextvars.ContextVar[Optional[Set[int]]] def decorate_with_checker(func: CallableT) -> CallableT: """Decorate the function with a checker that verifies the preconditions and postconditions.""" - assert not hasattr(func, "__preconditions__"), \ - "Expected func to have no list of preconditions (there should be only a single contract checker per function)." + assert not hasattr( + func, "__preconditions__" + ), "Expected func to have no list of preconditions (there should be only a single contract checker per function)." - assert not hasattr(func, "__postconditions__"), \ - "Expected func to have no list of postconditions (there should be only a single contract checker per function)." + assert not hasattr( + func, "__postconditions__" + ), "Expected func to have no list of postconditions (there should be only a single contract checker per function)." - assert not hasattr(func, "__postcondition_snapshots__"), \ - "Expected func to have no list of postcondition snapshots (there should be only a single contract checker " \ + assert not hasattr(func, "__postcondition_snapshots__"), ( + "Expected func to have no list of postcondition snapshots (there should be only a single contract checker " "per function)." + ) sign = inspect.signature(func) - if '_ARGS' in sign.parameters: + if "_ARGS" in sign.parameters: raise TypeError( 'The arguments of the function to be decorated with a contract checker include "_ARGS" which is ' - 'a reserved placeholder for positional arguments in the condition.') + "a reserved placeholder for positional arguments in the condition." + ) - if '_KWARGS' in sign.parameters: + if "_KWARGS" in sign.parameters: raise TypeError( 'The arguments of the function to be decorated with a contract checker include "_KWARGS" which is ' - 'a reserved placeholder for keyword arguments in the condition.') + "a reserved placeholder for keyword arguments in the condition." + ) param_names = list(sign.parameters.keys()) @@ -571,24 +722,34 @@ async def wrapper(*args, **kwargs): # type: ignore in_progress.add(id_func) - (preconditions, snapshots, postconditions) = _unpack_pre_snap_posts(wrapper) + (preconditions, snapshots, postconditions) = _unpack_pre_snap_posts( + wrapper + ) resolved_kwargs = kwargs_from_call( - param_names=param_names, kwdefaults=kwdefaults, args=args, kwargs=kwargs) + param_names=param_names, + kwdefaults=kwdefaults, + args=args, + kwargs=kwargs, + ) - type_error = _assert_resolved_kwargs_valid(postconditions, resolved_kwargs) + type_error = _assert_resolved_kwargs_valid( + postconditions, resolved_kwargs + ) if type_error: raise type_error violation_error = await _assert_preconditions_async( - preconditions=preconditions, resolved_kwargs=resolved_kwargs) + preconditions=preconditions, resolved_kwargs=resolved_kwargs + ) if violation_error: raise violation_error # Capture the snapshots if postconditions and snapshots: - resolved_kwargs['OLD'] = await _capture_old_async( - snapshots=snapshots, resolved_kwargs=resolved_kwargs) + resolved_kwargs["OLD"] = await _capture_old_async( + snapshots=snapshots, resolved_kwargs=resolved_kwargs + ) # Ideally, we would catch any exception here and strip the checkers from the traceback. # Unfortunately, this can not be done in Python 3, see @@ -596,16 +757,18 @@ async def wrapper(*args, **kwargs): # type: ignore result = await func(*args, **kwargs) if postconditions: - resolved_kwargs['result'] = result + resolved_kwargs["result"] = result violation_error = await _assert_postconditions_async( - postconditions=postconditions, resolved_kwargs=resolved_kwargs) + postconditions=postconditions, resolved_kwargs=resolved_kwargs + ) if violation_error: raise violation_error return result finally: in_progress.discard(id_func) + else: def wrapper(*args, **kwargs): # type: ignore @@ -632,25 +795,36 @@ def wrapper(*args, **kwargs): # type: ignore in_progress.add(id_func) - (preconditions, snapshots, postconditions) = _unpack_pre_snap_posts(wrapper) + (preconditions, snapshots, postconditions) = _unpack_pre_snap_posts( + wrapper + ) resolved_kwargs = kwargs_from_call( - param_names=param_names, kwdefaults=kwdefaults, args=args, kwargs=kwargs) + param_names=param_names, + kwdefaults=kwdefaults, + args=args, + kwargs=kwargs, + ) type_error = _assert_resolved_kwargs_valid( - postconditions=postconditions, resolved_kwargs=resolved_kwargs) + postconditions=postconditions, resolved_kwargs=resolved_kwargs + ) if type_error: raise type_error violation_error = _assert_preconditions( - preconditions=preconditions, resolved_kwargs=resolved_kwargs, func=func) + preconditions=preconditions, + resolved_kwargs=resolved_kwargs, + func=func, + ) if violation_error: raise violation_error # Capture the snapshots if postconditions and snapshots: - resolved_kwargs['OLD'] = _capture_old( - snapshots=snapshots, resolved_kwargs=resolved_kwargs, func=func) + resolved_kwargs["OLD"] = _capture_old( + snapshots=snapshots, resolved_kwargs=resolved_kwargs, func=func + ) # Ideally, we would catch any exception here and strip the checkers from the traceback. # Unfortunately, this can not be done in Python 3, see @@ -658,10 +832,13 @@ def wrapper(*args, **kwargs): # type: ignore result = func(*args, **kwargs) if postconditions: - resolved_kwargs['result'] = result + resolved_kwargs["result"] = result violation_error = _assert_postconditions( - postconditions=postconditions, resolved_kwargs=resolved_kwargs, func=func) + postconditions=postconditions, + resolved_kwargs=resolved_kwargs, + func=func, + ) if violation_error: raise violation_error @@ -672,10 +849,15 @@ def wrapper(*args, **kwargs): # type: ignore # Copy __doc__ and other properties so that doctests can run functools.update_wrapper(wrapper=wrapper, wrapped=func) - assert not hasattr(wrapper, "__preconditions__"), "Expected no preconditions set on a pristine contract checker." - assert not hasattr(wrapper, "__postcondition_snapshots__"), \ - "Expected no postcondition snapshots set on a pristine contract checker." - assert not hasattr(wrapper, "__postconditions__"), "Expected no postconditions set on a pristine contract checker." + assert not hasattr( + wrapper, "__preconditions__" + ), "Expected no preconditions set on a pristine contract checker." + assert not hasattr( + wrapper, "__postcondition_snapshots__" + ), "Expected no postcondition snapshots set on a pristine contract checker." + assert not hasattr( + wrapper, "__postconditions__" + ), "Expected no postconditions set on a pristine contract checker." # Precondition is a list of condition groups (i.e. disjunctive normal form): # each group consists of AND'ed preconditions, while the groups are OR'ed. @@ -700,10 +882,11 @@ def add_precondition_to_checker(checker: CallableT, contract: Contract) -> None: assert hasattr(checker, "__preconditions__") preconditions = getattr(checker, "__preconditions__") assert isinstance(preconditions, list) - assert len(preconditions) <= 1, \ - ("At most a single group of preconditions expected when wrapping with a contract checker. " - "The preconditions are merged only in the DBC metaclass. " - "The current number of precondition groups: {}").format(len(preconditions)) + assert len(preconditions) <= 1, ( + "At most a single group of preconditions expected when wrapping with a contract checker. " + "The preconditions are merged only in the DBC metaclass. " + "The current number of precondition groups: {}" + ).format(len(preconditions)) if len(preconditions) == 0: # Create the first group if there is no group so far, i.e. this is the first decorator. @@ -728,7 +911,9 @@ def add_snapshot_to_checker(checker: CallableT, snapshot: Snapshot) -> None: for snap in snapshots: assert isinstance(snap, Snapshot) if snap.name == snapshot.name: - raise ValueError("There are conflicting snapshots with the name: {!r}".format(snap.name)) + raise ValueError( + "There are conflicting snapshots with the name: {!r}".format(snap.name) + ) snapshots.append(snapshot) @@ -746,7 +931,9 @@ def add_postcondition_to_checker(checker: CallableT, contract: Contract) -> None getattr(checker, "__postconditions__").append(contract) -def _find_self(param_names: List[str], args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any: +def _find_self( + param_names: List[str], args: Tuple[Any, ...], kwargs: Dict[str, Any] +) -> Any: """Find the instance of ``self`` in the arguments.""" instance_i = None try: @@ -813,9 +1000,12 @@ def wrapper(*args, **kwargs): # type: ignore try: instance = _find_self(param_names=param_names, args=args, kwargs=kwargs) except KeyError as err: - raise KeyError(("The parameter 'self' could not be found in the call to function {!r}: " - "the param names were {!r}, the args were {!r} and kwargs were {!r}").format( - func, param_names, args, kwargs)) from err + raise KeyError( + ( + "The parameter 'self' could not be found in the call to function {!r}: " + "the param names were {!r}, the args were {!r} and kwargs were {!r}" + ).format(func, param_names, args, kwargs) + ) from err # We need to disable the invariants check during the constructor. @@ -859,11 +1049,16 @@ def wrapper(*args, **kwargs): # type: ignore async def wrapper(*args, **kwargs): # type: ignore """Wrap a function of a class by checking the invariants *before* and *after* the invocation.""" try: - instance = _find_self(param_names=param_names, args=args, kwargs=kwargs) + instance = _find_self( + param_names=param_names, args=args, kwargs=kwargs + ) except KeyError as err: - raise KeyError(("The parameter 'self' could not be found in the call to function {!r}: " - "the param names were {!r}, the args were {!r} and kwargs were {!r}").format( - func, param_names, args, kwargs)) from err + raise KeyError( + ( + "The parameter 'self' could not be found in the call to function {!r}: " + "the param names were {!r}, the args were {!r} and kwargs were {!r}" + ).format(func, param_names, args, kwargs) + ) from err # We need to create a new in-progress set if it is None as the ``ContextVar`` does not accept # a factory function for the default argument. If we didn't do this, and simply set an empty @@ -902,11 +1097,16 @@ async def wrapper(*args, **kwargs): # type: ignore def wrapper(*args, **kwargs): # type: ignore """Wrap a function of a class by checking the invariants *before* and *after* the invocation.""" try: - instance = _find_self(param_names=param_names, args=args, kwargs=kwargs) + instance = _find_self( + param_names=param_names, args=args, kwargs=kwargs + ) except KeyError as err: - raise KeyError(("The parameter 'self' could not be found in the call to function {!r}: " - "the param names were {!r}, the args were {!r} and kwargs were {!r}").format( - func, param_names, args, kwargs)) from err + raise KeyError( + ( + "The parameter 'self' could not be found in the call to function {!r}: " + "the param names were {!r}, the args were {!r} and kwargs were {!r}" + ).format(func, param_names, args, kwargs) + ) from err # The following dunder indicates whether another invariant is currently being checked. If so, # we need to suspend any further invariant check to avoid endless recursion. @@ -981,19 +1181,30 @@ def add_invariant_checks(cls: ClassT) -> None: # We need to ignore __repr__ to prevent endless loops when generating error messages. # __getattribute__, __setattr__ and __delattr__ are too invasive and alter the state of the instance. # Hence we don't consider them "public". - if name in ["__new__", "__repr__", "__getattribute__", "__setattr__", "__delattr__"]: + if name in [ + "__new__", + "__repr__", + "__getattribute__", + "__setattr__", + "__delattr__", + ]: continue if name == "__init__": - assert inspect.isfunction(value) or isinstance(value, _SLOT_WRAPPER_TYPE), \ - "Expected __init__ to be either a function or a slot wrapper, but got: {}".format( - type(value)) + assert inspect.isfunction(value) or isinstance( + value, _SLOT_WRAPPER_TYPE + ), "Expected __init__ to be either a function or a slot wrapper, but got: {}".format( + type(value) + ) init_name_func = (name, value) continue - if not inspect.isfunction(value) and not isinstance(value, _SLOT_WRAPPER_TYPE) and \ - not isinstance(value, property): + if ( + not inspect.isfunction(value) + and not isinstance(value, _SLOT_WRAPPER_TYPE) + and not isinstance(value, property) + ): continue # Ignore "protected"/"private" methods @@ -1017,7 +1228,11 @@ def add_invariant_checks(cls: ClassT) -> None: names_properties.append((name, value)) else: - raise NotImplementedError("Unhandled directory entry of class {} for {}: {}".format(cls, name, value)) + raise NotImplementedError( + "Unhandled directory entry of class {} for {}: {}".format( + cls, name, value + ) + ) if init_name_func: name, func = init_name_func @@ -1038,8 +1253,15 @@ def add_invariant_checks(cls: ClassT) -> None: for name, prop in names_properties: new_prop = property( - fget=_decorate_with_invariants(func=prop.fget, is_init=False) if prop.fget else None, - fset=_decorate_with_invariants(func=prop.fset, is_init=False) if prop.fset else None, - fdel=_decorate_with_invariants(func=prop.fdel, is_init=False) if prop.fdel else None, - doc=prop.__doc__) + fget=_decorate_with_invariants(func=prop.fget, is_init=False) + if prop.fget + else None, + fset=_decorate_with_invariants(func=prop.fset, is_init=False) + if prop.fset + else None, + fdel=_decorate_with_invariants(func=prop.fdel, is_init=False) + if prop.fdel + else None, + doc=prop.__doc__, + ) setattr(cls, name, new_prop) diff --git a/icontract/_decorators.py b/icontract/_decorators.py index 5950a42..d5014c6 100644 --- a/icontract/_decorators.py +++ b/icontract/_decorators.py @@ -2,7 +2,14 @@ import inspect import reprlib import traceback -from typing import Callable, Optional, Union, Any, List, Type, TypeVar # pylint: disable=unused-import +from typing import ( + Callable, + Optional, + Union, + Any, + List, + Type, +) # pylint: disable=unused-import import icontract._checkers from icontract._globals import CallableT, ExceptionT, ClassT @@ -16,12 +23,16 @@ class require: # pylint: disable=invalid-name The arguments of the precondition are expected to be a subset of the arguments of the wrapped function. """ - def __init__(self, - condition: Callable[..., Any], - description: Optional[str] = None, - a_repr: reprlib.Repr = icontract._globals.aRepr, - enabled: bool = __debug__, - error: Optional[Union[Callable[..., ExceptionT], Type[ExceptionT], BaseException]] = None) -> None: + def __init__( + self, + condition: Callable[..., Any], + description: Optional[str] = None, + a_repr: reprlib.Repr = icontract._globals.aRepr, + enabled: bool = __debug__, + error: Optional[ + Union[Callable[..., ExceptionT], Type[ExceptionT], BaseException] + ] = None, + ) -> None: """ Initialize. @@ -58,22 +69,40 @@ def __init__(self, pass elif isinstance(error, type): if not issubclass(error, BaseException): - raise ValueError(("The error of the contract is given as a type, " - "but the type does not inherit from BaseException: {}").format(error)) + raise ValueError( + ( + "The error of the contract is given as a type, " + "but the type does not inherit from BaseException: {}" + ).format(error) + ) else: - if not inspect.isfunction(error) and not inspect.ismethod(error) and not isinstance(error, BaseException): + if ( + not inspect.isfunction(error) + and not inspect.ismethod(error) + and not isinstance(error, BaseException) + ): raise ValueError( - ("The error of the contract must be either a callable (a function or a method), " - "a class (subclass of BaseException) or an instance of BaseException, but got: {}").format(error)) + ( + "The error of the contract must be either a callable (a function or a method), " + "a class (subclass of BaseException) or an instance of BaseException, but got: {}" + ).format(error) + ) location = None # type: Optional[str] tb_stack = traceback.extract_stack(limit=2)[:1] if len(tb_stack) > 0: frame = tb_stack[0] - location = 'File {}, line {} in {}'.format(frame.filename, frame.lineno, frame.name) + location = "File {}, line {} in {}".format( + frame.filename, frame.lineno, frame.name + ) self._contract = Contract( - condition=condition, description=description, a_repr=a_repr, error=error, location=location) + condition=condition, + description=description, + a_repr=a_repr, + error=error, + location=location, + ) def __call__(self, func: CallableT) -> CallableT: """ @@ -98,7 +127,9 @@ def __call__(self, func: CallableT) -> CallableT: result = contract_checker assert self._contract is not None - icontract._checkers.add_precondition_to_checker(checker=contract_checker, contract=self._contract) + icontract._checkers.add_precondition_to_checker( + checker=contract_checker, contract=self._contract + ) return result @@ -115,7 +146,12 @@ class snapshot: # pylint: disable=invalid-name Snapshots are inherited from the base classes and must not have conflicting names in the class hierarchy. """ - def __init__(self, capture: Callable[..., Any], name: Optional[str] = None, enabled: bool = __debug__) -> None: + def __init__( + self, + capture: Callable[..., Any], + name: Optional[str] = None, + enabled: bool = __debug__, + ) -> None: """ Initialize. @@ -141,7 +177,9 @@ def __init__(self, capture: Callable[..., Any], name: Optional[str] = None, enab tb_stack = traceback.extract_stack(limit=2)[:1] if len(tb_stack) > 0: frame = tb_stack[0] - location = 'File {}, line {} in {}'.format(frame.filename, frame.lineno, frame.name) + location = "File {}, line {} in {}".format( + frame.filename, frame.lineno, frame.name + ) self._snapshot = Snapshot(capture=capture, name=name, location=location) @@ -161,12 +199,18 @@ def __call__(self, func: CallableT) -> CallableT: contract_checker = icontract._checkers.find_checker(func=func) if contract_checker is None: - raise ValueError("You are decorating a function with a snapshot, but no postcondition was defined " - "on the function before.") + raise ValueError( + "You are decorating a function with a snapshot, but no postcondition was defined " + "on the function before." + ) - assert self._snapshot is not None, "Expected the enabled snapshot to have the property ``snapshot`` set." + assert ( + self._snapshot is not None + ), "Expected the enabled snapshot to have the property ``snapshot`` set." - icontract._checkers.add_snapshot_to_checker(checker=contract_checker, snapshot=self._snapshot) + icontract._checkers.add_snapshot_to_checker( + checker=contract_checker, snapshot=self._snapshot + ) return func @@ -180,12 +224,16 @@ class ensure: # pylint: disable=invalid-name not have "result" among its arguments. """ - def __init__(self, - condition: Callable[..., Any], - description: Optional[str] = None, - a_repr: reprlib.Repr = icontract._globals.aRepr, - enabled: bool = __debug__, - error: Optional[Union[Callable[..., ExceptionT], Type[ExceptionT], BaseException]] = None) -> None: + def __init__( + self, + condition: Callable[..., Any], + description: Optional[str] = None, + a_repr: reprlib.Repr = icontract._globals.aRepr, + enabled: bool = __debug__, + error: Optional[ + Union[Callable[..., ExceptionT], Type[ExceptionT], BaseException] + ] = None, + ) -> None: """ Initialize. @@ -222,22 +270,40 @@ def __init__(self, pass elif isinstance(error, type): if not issubclass(error, BaseException): - raise ValueError(("The error of the contract is given as a type, " - "but the type does not inherit from BaseException: {}").format(error)) + raise ValueError( + ( + "The error of the contract is given as a type, " + "but the type does not inherit from BaseException: {}" + ).format(error) + ) else: - if not inspect.isfunction(error) and not inspect.ismethod(error) and not isinstance(error, BaseException): + if ( + not inspect.isfunction(error) + and not inspect.ismethod(error) + and not isinstance(error, BaseException) + ): raise ValueError( - ("The error of the contract must be either a callable (a function or a method), " - "a class (subclass of BaseException) or an instance of BaseException, but got: {}").format(error)) + ( + "The error of the contract must be either a callable (a function or a method), " + "a class (subclass of BaseException) or an instance of BaseException, but got: {}" + ).format(error) + ) location = None # type: Optional[str] tb_stack = traceback.extract_stack(limit=2)[:1] if len(tb_stack) > 0: frame = tb_stack[0] - location = 'File {}, line {} in {}'.format(frame.filename, frame.lineno, frame.name) + location = "File {}, line {} in {}".format( + frame.filename, frame.lineno, frame.name + ) self._contract = Contract( - condition=condition, description=description, a_repr=a_repr, error=error, location=location) + condition=condition, + description=description, + a_repr=a_repr, + error=error, + location=location, + ) def __call__(self, func: CallableT) -> CallableT: """ @@ -262,7 +328,9 @@ def __call__(self, func: CallableT) -> CallableT: result = contract_checker assert self._contract is not None - icontract._checkers.add_postcondition_to_checker(checker=contract_checker, contract=self._contract) + icontract._checkers.add_postcondition_to_checker( + checker=contract_checker, contract=self._contract + ) return result @@ -284,12 +352,16 @@ class invariant: # pylint: disable=invalid-name """ - def __init__(self, - condition: Callable[..., Any], - description: Optional[str] = None, - a_repr: reprlib.Repr = icontract._globals.aRepr, - enabled: bool = __debug__, - error: Optional[Union[Callable[..., ExceptionT], Type[ExceptionT], BaseException]] = None) -> None: + def __init__( + self, + condition: Callable[..., Any], + description: Optional[str] = None, + a_repr: reprlib.Repr = icontract._globals.aRepr, + enabled: bool = __debug__, + error: Optional[ + Union[Callable[..., ExceptionT], Type[ExceptionT], BaseException] + ] = None, + ) -> None: """ Initialize a class decorator to establish the invariant on all the public methods. @@ -328,30 +400,52 @@ def __init__(self, pass elif isinstance(error, type): if not issubclass(error, BaseException): - raise ValueError(("The error of the contract is given as a type, " - "but the type does not inherit from BaseException: {}").format(error)) + raise ValueError( + ( + "The error of the contract is given as a type, " + "but the type does not inherit from BaseException: {}" + ).format(error) + ) else: - if not inspect.isfunction(error) and not inspect.ismethod(error) and not isinstance(error, BaseException): + if ( + not inspect.isfunction(error) + and not inspect.ismethod(error) + and not isinstance(error, BaseException) + ): raise ValueError( - ("The error of the contract must be either a callable (a function or a method), " - "a class (subclass of BaseException) or an instance of BaseException, but got: {}").format(error)) + ( + "The error of the contract must be either a callable (a function or a method), " + "a class (subclass of BaseException) or an instance of BaseException, but got: {}" + ).format(error) + ) location = None # type: Optional[str] tb_stack = traceback.extract_stack(limit=2)[:1] if len(tb_stack) > 0: frame = tb_stack[0] - location = 'File {}, line {} in {}'.format(frame.filename, frame.lineno, frame.name) + location = "File {}, line {} in {}".format( + frame.filename, frame.lineno, frame.name + ) if inspect.iscoroutinefunction(condition): raise ValueError( - "Async conditions are not possible in invariants as sync methods such as __init__ have to be wrapped.") + "Async conditions are not possible in invariants as sync methods such as __init__ have to be wrapped." + ) self._contract = Contract( - condition=condition, description=description, a_repr=a_repr, error=error, location=location) - - if self._contract.mandatory_args and self._contract.mandatory_args != ['self']: - raise ValueError("Expected an invariant condition with at most an argument 'self', but got: {}".format( - self._contract.condition_args)) + condition=condition, + description=description, + a_repr=a_repr, + error=error, + location=location, + ) + + if self._contract.mandatory_args and self._contract.mandatory_args != ["self"]: + raise ValueError( + "Expected an invariant condition with at most an argument 'self', but got: {}".format( + self._contract.condition_args + ) + ) def __call__(self, cls: ClassT) -> ClassT: """ @@ -364,15 +458,20 @@ def __call__(self, cls: ClassT) -> ClassT: if not self.enabled: return cls - assert self._contract is not None, "self._contract must be set if the contract was enabled." + assert ( + self._contract is not None + ), "self._contract must be set if the contract was enabled." if not hasattr(cls, "__invariants__"): invariants = [] # type: List[Contract] setattr(cls, "__invariants__", invariants) else: invariants = getattr(cls, "__invariants__") - assert isinstance(invariants, list), \ - "Expected invariants of class {} to be a list, but got: {}".format(cls, type(invariants)) + assert isinstance( + invariants, list + ), "Expected invariants of class {} to be a list, but got: {}".format( + cls, type(invariants) + ) invariants.append(self._contract) diff --git a/icontract/_globals.py b/icontract/_globals.py index f9075a9..08d77e9 100644 --- a/icontract/_globals.py +++ b/icontract/_globals.py @@ -27,6 +27,6 @@ # # Contracts marked with SLOW are also disabled if the interpreter is run in optimized mode (``-O`` or ``-OO``). SLOW = __debug__ and os.environ.get("ICONTRACT_SLOW", "") != "" -CallableT = TypeVar('CallableT', bound=Callable[..., Any]) -ClassT = TypeVar('ClassT', bound=type) -ExceptionT = TypeVar('ExceptionT', bound=BaseException) +CallableT = TypeVar("CallableT", bound=Callable[..., Any]) +ClassT = TypeVar("ClassT", bound=type) +ExceptionT = TypeVar("ExceptionT", bound=BaseException) diff --git a/icontract/_metaclass.py b/icontract/_metaclass.py index 8e03116..623d933 100644 --- a/icontract/_metaclass.py +++ b/icontract/_metaclass.py @@ -3,8 +3,17 @@ import inspect import sys import weakref -from typing import List, MutableMapping, Any, Callable, Optional, cast, Set, Type, \ - TypeVar # pylint: disable=unused-import +from typing import ( + List, + MutableMapping, + Any, + Callable, + Optional, + cast, + Set, + Type, + TypeVar, +) # pylint: disable=unused-import from icontract._types import Contract, Snapshot import icontract._checkers @@ -13,7 +22,9 @@ # pylint: skip-file -def _collapse_invariants(bases: List[type], namespace: MutableMapping[str, Any]) -> None: +def _collapse_invariants( + bases: List[type], namespace: MutableMapping[str, Any] +) -> None: """Collect invariants from the bases and merge them with the invariants in the namespace.""" invariants = [] # type: List[Contract] @@ -23,16 +34,20 @@ def _collapse_invariants(bases: List[type], namespace: MutableMapping[str, Any]) invariants.extend(getattr(base, "__invariants__")) # Add invariants in the current namespace - if '__invariants__' in namespace: - invariants.extend(namespace['__invariants__']) + if "__invariants__" in namespace: + invariants.extend(namespace["__invariants__"]) # Change the final invariants in the namespace if invariants: namespace["__invariants__"] = invariants -def _collapse_preconditions(base_preconditions: List[List[Contract]], bases_have_func: bool, - preconditions: List[List[Contract]], func: Callable[..., Any]) -> List[List[Contract]]: +def _collapse_preconditions( + base_preconditions: List[List[Contract]], + bases_have_func: bool, + preconditions: List[List[Contract]], + func: Callable[..., Any], +) -> List[List[Contract]]: """ Collapse function preconditions with the preconditions collected from the base classes. @@ -43,15 +58,21 @@ def _collapse_preconditions(base_preconditions: List[List[Contract]], bases_have :return: collapsed sequence of precondition groups """ if not base_preconditions and bases_have_func and preconditions: - raise TypeError(("The function {} can not weaken the preconditions because the bases specify " - "no preconditions at all. Hence this function must accept all possible input since " - "the preconditions are OR'ed and no precondition implies a dummy precondition which is always " - "fulfilled.").format(func.__qualname__)) + raise TypeError( + ( + "The function {} can not weaken the preconditions because the bases specify " + "no preconditions at all. Hence this function must accept all possible input since " + "the preconditions are OR'ed and no precondition implies a dummy precondition which is always " + "fulfilled." + ).format(func.__qualname__) + ) return base_preconditions + preconditions -def _collapse_snapshots(base_snapshots: List[Snapshot], snapshots: List[Snapshot]) -> List[Snapshot]: +def _collapse_snapshots( + base_snapshots: List[Snapshot], snapshots: List[Snapshot] +) -> List[Snapshot]: """ Collapse snapshots of pre-invocation values with the snapshots collected from the base classes. @@ -64,17 +85,22 @@ def _collapse_snapshots(base_snapshots: List[Snapshot], snapshots: List[Snapshot for snap in collapsed: if snap.name in seen_names: - raise ValueError("There are conflicting snapshots with the name: {!r}.\n\n" - "Please mind that the snapshots are inherited from the base classes. " - "Does one of the base classes defines a snapshot with the same name?".format(snap.name)) + raise ValueError( + "There are conflicting snapshots with the name: {!r}.\n\n" + "Please mind that the snapshots are inherited from the base classes. " + "Does one of the base classes defines a snapshot with the same name?".format( + snap.name + ) + ) seen_names.add(snap.name) return collapsed -def _collapse_postconditions(base_postconditions: List[Contract], postconditions: List[Contract]) -> \ - List[Contract]: +def _collapse_postconditions( + base_postconditions: List[Contract], postconditions: List[Contract] +) -> List[Contract]: """ Collapse function postconditions with the postconditions collected from the base classes. @@ -85,7 +111,9 @@ def _collapse_postconditions(base_postconditions: List[Contract], postconditions return base_postconditions + postconditions -def _decorate_namespace_function(bases: List[type], namespace: MutableMapping[str, Any], key: str) -> None: +def _decorate_namespace_function( + bases: List[type], namespace: MutableMapping[str, Any], key: str +) -> None: """Collect preconditions and postconditions from the bases and decorate the function at the ``key``.""" value = namespace[key] assert inspect.isfunction(value) or isinstance(value, (staticmethod, classmethod)) @@ -114,7 +142,7 @@ def _decorate_namespace_function(bases: List[type], namespace: MutableMapping[st # Preconditions and postconditions of __init__ and __new__ of base classes are deliberately ignored # (and not collapsed) since initialization is an operation specific to the concrete class and # does not relate to the class hierarchy. - if key not in ['__init__', '__new__']: + if key not in ["__init__", "__new__"]: base_preconditions = [] # type: List[List[Contract]] base_snapshots = [] # type: List[Snapshot] base_postconditions = [] # type: List[Contract] @@ -131,7 +159,9 @@ def _decorate_namespace_function(bases: List[type], namespace: MutableMapping[st # Ignore functions which don't have preconditions or postconditions if base_contract_checker is not None: base_preconditions.extend(base_contract_checker.__preconditions__) - base_snapshots.extend(base_contract_checker.__postcondition_snapshots__) + base_snapshots.extend( + base_contract_checker.__postcondition_snapshots__ + ) base_postconditions.extend(base_contract_checker.__postconditions__) # Collapse preconditions and postconditions from the bases with the function's own ones @@ -139,12 +169,16 @@ def _decorate_namespace_function(bases: List[type], namespace: MutableMapping[st base_preconditions=base_preconditions, bases_have_func=bases_have_func, preconditions=preconditions, - func=func) + func=func, + ) - snapshots = _collapse_snapshots(base_snapshots=base_snapshots, snapshots=snapshots) + snapshots = _collapse_snapshots( + base_snapshots=base_snapshots, snapshots=snapshots + ) postconditions = _collapse_postconditions( - base_postconditions=base_postconditions, postconditions=postconditions) + base_postconditions=base_postconditions, postconditions=postconditions + ) if preconditions or postconditions: if contract_checker is None: @@ -160,7 +194,9 @@ def _decorate_namespace_function(bases: List[type], namespace: MutableMapping[st namespace[key] = classmethod(contract_checker) else: - raise NotImplementedError("Unexpected value for a function: {}".format(value)) + raise NotImplementedError( + "Unexpected value for a function: {}".format(value) + ) # Override the preconditions and postconditions contract_checker.__preconditions__ = preconditions # type: ignore @@ -168,7 +204,9 @@ def _decorate_namespace_function(bases: List[type], namespace: MutableMapping[st contract_checker.__postconditions__ = postconditions # type: ignore -def _decorate_namespace_property(bases: List[type], namespace: MutableMapping[str, Any], key: str) -> None: +def _decorate_namespace_property( + bases: List[type], namespace: MutableMapping[str, Any], key: str +) -> None: """Collect contracts for all getters/setters/deleters corresponding to ``key`` and decorate them.""" value = namespace[key] assert isinstance(value, property) @@ -192,9 +230,11 @@ def _decorate_namespace_property(bases: List[type], namespace: MutableMapping[st for base in bases: if hasattr(base, key): base_property = getattr(base, key) - assert isinstance(base_property, property), \ - "Expected base {} to have {} as property, but got: {}".format(base, key, - base_property) + assert isinstance( + base_property, property + ), "Expected base {} to have {} as property, but got: {}".format( + base, key, base_property + ) if func == value.fget: base_func = getattr(base, key).fget @@ -203,7 +243,9 @@ def _decorate_namespace_property(bases: List[type], namespace: MutableMapping[st elif func == value.fdel: base_func = getattr(base, key).fdel else: - raise NotImplementedError("Unhandled case: func neither value.fget, value.fset nor value.fdel") + raise NotImplementedError( + "Unhandled case: func neither value.fget, value.fset nor value.fdel" + ) if base_func is None: continue @@ -216,7 +258,9 @@ def _decorate_namespace_property(bases: List[type], namespace: MutableMapping[st # Ignore functions which don't have preconditions or postconditions if base_contract_checker is not None: base_preconditions.extend(base_contract_checker.__preconditions__) - base_snapshots.extend(base_contract_checker.__postcondition_snapshots__) + base_snapshots.extend( + base_contract_checker.__postcondition_snapshots__ + ) base_postconditions.extend(base_contract_checker.__postconditions__) # Add preconditions and postconditions of the function @@ -234,12 +278,16 @@ def _decorate_namespace_property(bases: List[type], namespace: MutableMapping[st base_preconditions=base_preconditions, bases_have_func=bases_have_func, preconditions=preconditions, - func=func) + func=func, + ) - snapshots = _collapse_snapshots(base_snapshots=base_snapshots, snapshots=snapshots) + snapshots = _collapse_snapshots( + base_snapshots=base_snapshots, snapshots=snapshots + ) postconditions = _collapse_postconditions( - base_postconditions=base_postconditions, postconditions=postconditions) + base_postconditions=base_postconditions, postconditions=postconditions + ) if preconditions or postconditions: if contract_checker is None: @@ -253,7 +301,9 @@ def _decorate_namespace_property(bases: List[type], namespace: MutableMapping[st elif func == value.fdel: fdel = contract_checker else: - raise NotImplementedError("Unhandled case: func neither fget, fset nor fdel") + raise NotImplementedError( + "Unhandled case: func neither fget, fset nor fdel" + ) # Override the preconditions and postconditions contract_checker.__preconditions__ = preconditions # type: ignore @@ -264,7 +314,9 @@ def _decorate_namespace_property(bases: List[type], namespace: MutableMapping[st namespace[key] = property(fget=fget, fset=fset, fdel=fdel) -def _dbc_decorate_namespace(bases: List[type], namespace: MutableMapping[str, Any]) -> None: +def _dbc_decorate_namespace( + bases: List[type], namespace: MutableMapping[str, Any] +) -> None: """ Collect invariants, preconditions and postconditions from the bases and decorate all the methods. @@ -287,7 +339,7 @@ def _dbc_decorate_namespace(bases: List[type], namespace: MutableMapping[str, An _CONTRACT_CLASSES = weakref.WeakSet() # type: ignore -T = TypeVar('T') # pylint: disable=invalid-name +T = TypeVar("T") # pylint: disable=invalid-name def _register_for_hypothesis(cls: Type[T]) -> None: @@ -316,8 +368,10 @@ class DBCMeta(abc.ABCMeta): # instead of ``mcs``. # pylint: disable=bad-mcs-classmethod-argument - if sys.version_info < (3, ): - raise NotImplementedError("Python versions below not supported, got: {}".format(sys.version_info)) + if sys.version_info < (3,): + raise NotImplementedError( + "Python versions below not supported, got: {}".format(sys.version_info) + ) if sys.version_info < (3, 6): # pylint: disable=arguments-differ @@ -338,6 +392,7 @@ def __new__(mlcs, name, bases, namespace): _register_for_hypothesis(cls) return cls + else: def __new__(mlcs, name, bases, namespace, **kwargs): # type: ignore diff --git a/icontract/_recompute.py b/icontract/_recompute.py index 5276ede..f5a5f9d 100644 --- a/icontract/_recompute.py +++ b/icontract/_recompute.py @@ -7,7 +7,20 @@ import inspect import sys import uuid -from typing import (Any, Mapping, Dict, List, Optional, Union, Tuple, Set, Callable, cast, Iterable, TypeVar) # pylint: disable=unused-import +from typing import ( + Any, + Mapping, + Dict, + List, + Optional, + Union, + Tuple, + Set, + Callable, + cast, + Iterable, + TypeVar, +) # pylint: disable=unused-import from _ast import If @@ -41,7 +54,7 @@ def __bool__(self) -> Any: return self.result -ContextT = TypeVar('ContextT', bound=ast.expr_context) +ContextT = TypeVar("ContextT", bound=ast.expr_context) class _CollectStoredNamesVisitor(ast.NodeVisitor): @@ -85,8 +98,11 @@ def _collect_name_loads(nodes: Iterable[ast.expr]) -> List[ast.expr]: # noinspection PyTypeChecker -def _translate_all_expression_to_a_module(generator_exp: ast.GeneratorExp, generated_function_name: str, - name_to_value: Mapping[str, Any]) -> ast.Module: +def _translate_all_expression_to_a_module( + generator_exp: ast.GeneratorExp, + generated_function_name: str, + name_to_value: Mapping[str, Any], +) -> ast.Module: """ Generate the AST of the module to trace an all quantifier on an generator expression. @@ -101,14 +117,18 @@ def _translate_all_expression_to_a_module(generator_exp: ast.GeneratorExp, gener assert not hasattr(builtins, generated_function_name) # Collect all the names involved in the generation - relevant_names = _collect_stored_names(generator.target for generator in generator_exp.generators) + relevant_names = _collect_stored_names( + generator.target for generator in generator_exp.generators + ) assert generated_function_name not in relevant_names # Work backwards, from the most-inner block outwards - result_id = 'icontract_tracing_all_result_{}'.format(uuid.uuid4().hex) - result_assignment = ast.Assign(targets=[ast.Name(id=result_id, ctx=ast.Store())], value=generator_exp.elt) + result_id = "icontract_tracing_all_result_{}".format(uuid.uuid4().hex) + result_assignment = ast.Assign( + targets=[ast.Name(id=result_id, ctx=ast.Store())], value=generator_exp.elt + ) exceptional_return = ast.Return( ast.Tuple( @@ -119,22 +139,36 @@ def _translate_all_expression_to_a_module(generator_exp: ast.GeneratorExp, gener ast.Tuple( elts=[ ast.Constant(value=relevant_name, kind=None), - ast.Name(id=relevant_name, ctx=ast.Load()) + ast.Name(id=relevant_name, ctx=ast.Load()), ], - ctx=ast.Load()) for relevant_name in relevant_names + ctx=ast.Load(), + ) + for relevant_name in relevant_names ], - ctx=ast.Load()) + ctx=ast.Load(), + ), ], - ctx=ast.Load())) + ctx=ast.Load(), + ) + ) # While happy return shall not be executed, we add it here for robustness in case # future refactorings forget to check for that edge case. happy_return = ast.Return( - ast.Tuple(elts=[ast.Name(id=result_id, ctx=ast.Load()), - ast.Constant(value=None, kind=None)], ctx=ast.Load())) + ast.Tuple( + elts=[ + ast.Name(id=result_id, ctx=ast.Load()), + ast.Constant(value=None, kind=None), + ], + ctx=ast.Load(), + ) + ) critical_if: If = ast.If( - test=ast.Name(id=result_id, ctx=ast.Load()), body=[ast.Pass()], orelse=[exceptional_return]) + test=ast.Name(id=result_id, ctx=ast.Load()), + body=[ast.Pass()], + orelse=[exceptional_return], + ) # Previous inner block to be added as body to the next outer block block = None # type: Optional[List[ast.stmt]] @@ -148,9 +182,23 @@ def _translate_all_expression_to_a_module(generator_exp: ast.GeneratorExp, gener block = [ast.If(test=condition, body=block, orelse=[])] if not comprehension.is_async: - block = [ast.For(target=comprehension.target, iter=comprehension.iter, body=block, orelse=[])] + block = [ + ast.For( + target=comprehension.target, + iter=comprehension.iter, + body=block, + orelse=[], + ) + ] else: - block = [ast.AsyncFor(target=comprehension.target, iter=comprehension.iter, body=block, orelse=[])] + block = [ + ast.AsyncFor( + target=comprehension.target, + iter=comprehension.iter, + body=block, + orelse=[], + ) + ] assert block is not None @@ -163,42 +211,76 @@ def _translate_all_expression_to_a_module(generator_exp: ast.GeneratorExp, gener args = [ast.arg(arg=name, annotation=None) for name in sorted(name_to_value.keys())] if sys.version_info < (3, 5): - raise NotImplementedError("Python versions below 3.5 not supported, got: {}".format(sys.version_info)) + raise NotImplementedError( + "Python versions below 3.5 not supported, got: {}".format(sys.version_info) + ) if not is_async: if sys.version_info < (3, 8): func_def_node = ast.FunctionDef( name=generated_function_name, - args=ast.arguments(args=args, kwonlyargs=[], kw_defaults=[], defaults=[], vararg=None, kwarg=None), + args=ast.arguments( + args=args, + kwonlyargs=[], + kw_defaults=[], + defaults=[], + vararg=None, + kwarg=None, + ), decorator_list=[], - body=block) # type: Union[ast.FunctionDef, ast.AsyncFunctionDef] + body=block, + ) # type: Union[ast.FunctionDef, ast.AsyncFunctionDef] module_node = ast.Module(body=[func_def_node]) else: func_def_node = ast.FunctionDef( name=generated_function_name, args=ast.arguments( - args=args, posonlyargs=[], kwonlyargs=[], kw_defaults=[], defaults=[], vararg=None, kwarg=None), + args=args, + posonlyargs=[], + kwonlyargs=[], + kw_defaults=[], + defaults=[], + vararg=None, + kwarg=None, + ), decorator_list=[], - body=block) + body=block, + ) module_node = ast.Module(body=[func_def_node], type_ignores=[]) else: if sys.version_info < (3, 8): async_func_def_node = ast.AsyncFunctionDef( name=generated_function_name, - args=ast.arguments(args=args, kwonlyargs=[], kw_defaults=[], defaults=[], vararg=None, kwarg=None), + args=ast.arguments( + args=args, + kwonlyargs=[], + kw_defaults=[], + defaults=[], + vararg=None, + kwarg=None, + ), decorator_list=[], - body=block) + body=block, + ) module_node = ast.Module(body=[async_func_def_node]) else: async_func_def_node = ast.AsyncFunctionDef( name=generated_function_name, args=ast.arguments( - args=args, posonlyargs=[], kwonlyargs=[], kw_defaults=[], defaults=[], vararg=None, kwarg=None), + args=args, + posonlyargs=[], + kwonlyargs=[], + kw_defaults=[], + defaults=[], + vararg=None, + kwarg=None, + ), decorator_list=[], - body=block) + body=block, + ) module_node = ast.Module(body=[async_func_def_node], type_ignores=[]) @@ -267,6 +349,7 @@ def visit_NameConstant(self, node: ast.NameConstant) -> Any: """Forward the node value as a result.""" self.recomputed_values[node] = node.value return node.value + else: def visit_Constant(self, node: ast.Constant) -> Any: @@ -276,7 +359,9 @@ def visit_Constant(self, node: ast.Constant) -> Any: if sys.version_info >= (3, 6): - def visit_FormattedValue(self, node: ast.FormattedValue) -> Union[str, Placeholder]: + def visit_FormattedValue( + self, node: ast.FormattedValue + ) -> Union[str, Placeholder]: """Format the node value.""" recomputed_format_spec = None # type: Optional[Union[str, Placeholder]] if node.format_spec is not None: @@ -291,27 +376,30 @@ def visit_FormattedValue(self, node: ast.FormattedValue) -> Union[str, Placehold if recomputed_format_spec is PLACEHOLDER or recomputed_value is PLACEHOLDER: return PLACEHOLDER - fmt = ['{'] + fmt = ["{"] # See https://docs.python.org/3/library/ast.html#ast.FormattedValue for these # constants if node.conversion == -1: pass elif node.conversion == 115: - fmt.append('!s') + fmt.append("!s") elif node.conversion == 114: - fmt.append('!r') + fmt.append("!r") elif node.conversion == 97: - fmt.append('!a') + fmt.append("!a") else: - raise NotImplementedError("Unhandled conversion of a formatted value node {!r}: {}".format( - node, node.conversion)) + raise NotImplementedError( + "Unhandled conversion of a formatted value node {!r}: {}".format( + node, node.conversion + ) + ) if recomputed_format_spec is not None: fmt.append(f":{recomputed_format_spec}") - fmt.append('}') + fmt.append("}") - return ''.join(fmt).format(recomputed_value) + return "".join(fmt).format(recomputed_value) def visit_JoinedStr(self, node: ast.JoinedStr) -> Union[str, Placeholder]: """Visit the values and concatenate them.""" @@ -321,7 +409,7 @@ def visit_JoinedStr(self, node: ast.JoinedStr) -> Union[str, Placeholder]: if PLACEHOLDER in recomputed_values: return PLACEHOLDER - joined_str = ''.join(recomputed_values) + joined_str = "".join(recomputed_values) self.recomputed_values[node] = joined_str return joined_str @@ -373,7 +461,10 @@ def visit_Dict(self, node: ast.Dict) -> Union[Dict[Any, Any], Placeholder]: recomputed_dict[self.visit(node=key)] = self.visit(node=val) # Please see "NOTE ABOUT PLACEHOLDERS AND RE-COMPUTATION" - if any(key is PLACEHOLDER or value is PLACEHOLDER for key, value in recomputed_dict.items()): + if any( + key is PLACEHOLDER or value is PLACEHOLDER + for key, value in recomputed_dict.items() + ): return PLACEHOLDER self.recomputed_values[node] = recomputed_dict @@ -382,8 +473,11 @@ def visit_Dict(self, node: ast.Dict) -> Union[Dict[Any, Any], Placeholder]: def visit_Name(self, node: ast.Name) -> Any: """Load the variable by looking it up in the variable look-up and in the built-ins.""" if not isinstance(node.ctx, ast.Load): - raise NotImplementedError("Can only compute a value of Load on a name {}, but got context: {}".format( - node.id, node.ctx)) + raise NotImplementedError( + "Can only compute a value of Load on a name {}, but got context: {}".format( + node.id, node.ctx + ) + ) result = None # type: Optional[Any] @@ -500,7 +594,9 @@ def visit_Compare(self, node: ast.Compare) -> Any: comparators = [self.visit(node=comparator) for comparator in node.comparators] # Please see "NOTE ABOUT PLACEHOLDERS AND RE-COMPUTATION" - if left is PLACEHOLDER or any(comparator is PLACEHOLDER for comparator in comparators): + if left is PLACEHOLDER or any( + comparator is PLACEHOLDER for comparator in comparators + ): return PLACEHOLDER result = None # type: Optional[Any] @@ -547,13 +643,19 @@ def visit_Call(self, node: ast.Call) -> Any: return PLACEHOLDER if not callable(func): - raise ValueError("Unexpected call to a non-calllable during the re-computation: {}".format(func)) + raise ValueError( + "Unexpected call to a non-calllable during the re-computation: {}".format( + func + ) + ) if inspect.iscoroutinefunction(func): raise ValueError( - ("Unexpected coroutine function {} as a condition of a contract. " - "You must specify your own error if the condition of your contract is a coroutine function." - ).format(func)) + ( + "Unexpected coroutine function {} as a condition of a contract. " + "You must specify your own error if the condition of your contract is a coroutine function." + ).format(func) + ) # Short-circuit tracing the all quantifier over a generator expression # yapf: disable @@ -633,7 +735,10 @@ def visit_Attribute(self, node: ast.Attribute) -> Any: if not isinstance(node.ctx, ast.Load): raise NotImplementedError( - "Can only compute a value of Load on the attribute {}, but got context: {}".format(node.attr, node.ctx)) + "Can only compute a value of Load on the attribute {}, but got context: {}".format( + node.attr, node.ctx + ) + ) result = getattr(value, node.attr) @@ -656,7 +761,10 @@ def visit_NamedExpr(self, node: ast.NamedExpr) -> Any: if not isinstance(target.ctx, ast.Store): raise NotImplementedError( - "Expected Store context in the target of a named expression, but got: {}".format(target.ctx)) + "Expected Store context in the target of a named expression, but got: {}".format( + target.ctx + ) + ) self._name_to_value[target.id] = value @@ -700,7 +808,9 @@ def visit_Slice(self, node: ast.Slice) -> Union[slice, Placeholder]: if sys.version_info < (3, 9): - def visit_ExtSlice(self, node: ast.ExtSlice) -> Union[Tuple[Any, ...], Placeholder]: + def visit_ExtSlice( + self, node: ast.ExtSlice + ) -> Union[Tuple[Any, ...], Placeholder]: """Visit each dimension of the advanced slicing and assemble the dimensions in a tuple.""" result = tuple(self.visit(node=dim) for dim in node.dims) @@ -723,7 +833,9 @@ def visit_Subscript(self, node: ast.Subscript) -> Any: self.recomputed_values[node] = result return result - def _trace_all_with_generator(self, func: Callable[..., Any], node: ast.Call) -> Any: + def _trace_all_with_generator( + self, func: Callable[..., Any], node: ast.Call + ) -> Any: """Re-write the all call with for loops to trace the first offending item, if any.""" assert func == builtins.all # pylint: disable=comparison-with-callable assert len(node.args) == 1 @@ -736,7 +848,7 @@ def _trace_all_with_generator(self, func: Callable[..., Any], node: ast.Call) -> if recomputed_arg is PLACEHOLDER: return PLACEHOLDER - result = func(*(self.visit(node=node.args[0]), )) + result = func(*(self.visit(node=node.args[0]),)) if result: return result @@ -747,18 +859,21 @@ def _trace_all_with_generator(self, func: Callable[..., Any], node: ast.Call) -> generator_exp = node.args[0] assert isinstance(generator_exp, ast.GeneratorExp) - generated_function_name = "icontract_tracing_all_with_generator_expr_{}".format(uuid.uuid4().hex) + generated_function_name = "icontract_tracing_all_with_generator_expr_{}".format( + uuid.uuid4().hex + ) module_node = _translate_all_expression_to_a_module( generator_exp=generator_exp, generated_function_name=generated_function_name, - name_to_value=self._name_to_value) + name_to_value=self._name_to_value, + ) # In case you want to debug the generated function at this point, # you probably want to use ``astor`` module to generate the source code # based on the ``module_node``. - code = compile(source=module_node, filename='', mode='exec') + code = compile(source=module_node, filename="", mode="exec") module_locals = {} # type: Dict[str, Any] module_globals = {} # type: Dict[str, Any] @@ -770,41 +885,65 @@ def _trace_all_with_generator(self, func: Callable[..., Any], node: ast.Call) -> assert not bool(result), "Expected the unhappy path here" assert isinstance(inputs, tuple) - assert all(isinstance(item, tuple) and len(item) == 2 and isinstance(item[0], str) for item in inputs) - - return FirstExceptionInAll(result=result, inputs=cast(Tuple[Tuple[str, Any]], inputs)) - - def _execute_comprehension(self, node: Union[ast.ListComp, ast.SetComp, ast.GeneratorExp, ast.DictComp]) -> Any: + assert all( + isinstance(item, tuple) and len(item) == 2 and isinstance(item[0], str) + for item in inputs + ) + + return FirstExceptionInAll( + result=result, inputs=cast(Tuple[Tuple[str, Any]], inputs) + ) + + def _execute_comprehension( + self, node: Union[ast.ListComp, ast.SetComp, ast.GeneratorExp, ast.DictComp] + ) -> Any: """Compile the generator or comprehension from the node and execute the compiled code.""" # Please see "NOTE ABOUT NAME 🠒 VALUE STACKING". if any(value is PLACEHOLDER for value in self._name_to_value.values()): return PLACEHOLDER - args = [ast.arg(arg=name, annotation=None) for name in sorted(self._name_to_value.keys())] + args = [ + ast.arg(arg=name, annotation=None) + for name in sorted(self._name_to_value.keys()) + ] - if sys.version_info < (3, ): - raise NotImplementedError("Python versions below 3 not supported, got: {}".format(sys.version_info)) + if sys.version_info < (3,): + raise NotImplementedError( + "Python versions below 3 not supported, got: {}".format( + sys.version_info + ) + ) if sys.version_info < (3, 8): func_def_node = ast.FunctionDef( name="generator_expr", - args=ast.arguments(args=args, kwonlyargs=[], kw_defaults=[], defaults=[]), + args=ast.arguments( + args=args, kwonlyargs=[], kw_defaults=[], defaults=[] + ), decorator_list=[], - body=[ast.Return(node)]) + body=[ast.Return(node)], + ) module_node = ast.Module(body=[func_def_node]) else: func_def_node = ast.FunctionDef( name="generator_expr", - args=ast.arguments(args=args, posonlyargs=[], kwonlyargs=[], kw_defaults=[], defaults=[]), + args=ast.arguments( + args=args, + posonlyargs=[], + kwonlyargs=[], + kw_defaults=[], + defaults=[], + ), decorator_list=[], - body=[ast.Return(node)]) + body=[ast.Return(node)], + ) module_node = ast.Module(body=[func_def_node], type_ignores=[]) ast.fix_missing_locations(module_node) - code = compile(source=module_node, filename='', mode='exec') + code = compile(source=module_node, filename="", mode="exec") module_locals = {} # type: Dict[str, Any] module_globals = {} # type: Dict[str, Any] @@ -841,7 +980,9 @@ def visit_GeneratorExp(self, node: ast.GeneratorExp) -> Any: # PLACEHOLDER's. old_name_to_value = copy.copy(self._name_to_value) - for target_name in _collect_stored_names([generator.target for generator in node.generators]): + for target_name in _collect_stored_names( + [generator.target for generator in node.generators] + ): self._name_to_value[target_name] = PLACEHOLDER self.visit(node.elt) @@ -866,7 +1007,9 @@ def visit_ListComp(self, node: ast.ListComp) -> Any: # Please see "NOTE ABOUT NAME 🠒 VALUE STACKING". old_name_to_value = copy.copy(self._name_to_value) - for target_name in _collect_stored_names([generator.target for generator in node.generators]): + for target_name in _collect_stored_names( + [generator.target for generator in node.generators] + ): self._name_to_value[target_name] = PLACEHOLDER self.visit(node.elt) @@ -893,7 +1036,9 @@ def visit_SetComp(self, node: ast.SetComp) -> Any: # Please see "NOTE ABOUT NAME 🠒 VALUE STACKING". old_name_to_value = copy.copy(self._name_to_value) - for target_name in _collect_stored_names([generator.target for generator in node.generators]): + for target_name in _collect_stored_names( + [generator.target for generator in node.generators] + ): self._name_to_value[target_name] = PLACEHOLDER self.visit(node.elt) @@ -920,7 +1065,9 @@ def visit_DictComp(self, node: ast.DictComp) -> Any: # Please see "NOTE ABOUT NAME 🠒 VALUE STACKING". old_name_to_value = copy.copy(self._name_to_value) - for target_name in _collect_stored_names([generator.target for generator in node.generators]): + for target_name in _collect_stored_names( + [generator.target for generator in node.generators] + ): self._name_to_value[target_name] = PLACEHOLDER self.visit(node.key) @@ -946,12 +1093,19 @@ def visit_Lambda(self, node: ast.Lambda) -> Callable[..., Any]: raise NotImplementedError( "Re-computation of in-line lambda functions is not supported since it is quite tricky to implement and " "we decided to implement it only once there is a real need for it. " - "Please make a feature request on https://github.com/Parquery/icontract") + "Please make a feature request on https://github.com/Parquery/icontract" + ) def visit_Return(self, node: ast.Return) -> Any: # pylint: disable=no-self-use """Raise an exception that this node is unexpected.""" - raise AssertionError("Unexpected return node during the re-computation: {}".format(ast.dump(node))) + raise AssertionError( + "Unexpected return node during the re-computation: {}".format( + ast.dump(node) + ) + ) def generic_visit(self, node: ast.AST) -> None: """Raise an exception that this node has not been handled.""" - raise NotImplementedError("Unhandled re-computation of the node: {} {}".format(type(node), node)) + raise NotImplementedError( + "Unhandled re-computation of the node: {} {}".format(type(node), node) + ) diff --git a/icontract/_represent.py b/icontract/_represent.py index 7dcc720..d2711d1 100644 --- a/icontract/_represent.py +++ b/icontract/_represent.py @@ -6,7 +6,16 @@ import sys import textwrap import uuid -from typing import (Any, Mapping, MutableMapping, Callable, List, Dict, cast, Optional) # pylint: disable=unused-import +from typing import ( + Any, + Mapping, + MutableMapping, + Callable, + List, + Dict, + cast, + Optional, +) # pylint: disable=unused-import import asttokens.asttokens @@ -27,8 +36,13 @@ def _representable(value: Any) -> bool: :param value: value related to an AST node :return: True if we want to represent it in the violation error """ - return not inspect.isclass(value) and not inspect.isfunction(value) and not inspect.ismethod(value) and not \ - inspect.ismodule(value) and not inspect.isbuiltin(value) + return ( + not inspect.isclass(value) + and not inspect.isfunction(value) + and not inspect.ismethod(value) + and not inspect.ismodule(value) + and not inspect.isbuiltin(value) + ) class Visitor(ast.NodeVisitor): @@ -37,8 +51,12 @@ class Visitor(ast.NodeVisitor): # pylint: disable=invalid-name # pylint: disable=missing-docstring - def __init__(self, recomputed_values: Mapping[ast.AST, Any], variable_lookup: List[Mapping[str, Any]], - atok: asttokens.asttokens.ASTTokens) -> None: + def __init__( + self, + recomputed_values: Mapping[ast.AST, Any], + variable_lookup: List[Mapping[str, Any]], + atok: asttokens.asttokens.ASTTokens, + ) -> None: """ Initialize. @@ -199,8 +217,8 @@ def __init__(self, atok: asttokens.asttokens.ASTTokens, node: ast.Lambda) -> Non self.text = text -_DECORATOR_RE = re.compile(r'^\s*@[a-zA-Z_]') -_DEF_CLASS_RE = re.compile(r'^\s*(async\s+def|def |class )') +_DECORATOR_RE = re.compile(r"^\s*@[a-zA-Z_]") +_DEF_CLASS_RE = re.compile(r"^\s*(async\s+def|def |class )") class DecoratorInspection: @@ -217,7 +235,9 @@ def __init__(self, atok: asttokens.asttokens.ASTTokens, node: ast.Call) -> None: self.node = node -def inspect_decorator(lines: List[str], lineno: int, filename: str) -> DecoratorInspection: +def inspect_decorator( + lines: List[str], lineno: int, filename: str +) -> DecoratorInspection: """ Parse the file in which the decorator is called and figure out the corresponding call AST node. @@ -227,9 +247,13 @@ def inspect_decorator(lines: List[str], lineno: int, filename: str) -> Decorator :return: inspected decorator call """ if lineno < 0 or lineno >= len(lines): - raise ValueError(("Given line number {} of one of the decorator lines " - "is not within the range [{}, {}) of lines in {}.\n\n" - "The decorator lines were:\n{}").format(lineno, 0, len(lines), filename, "\n".join(lines))) + raise ValueError( + ( + "Given line number {} of one of the decorator lines " + "is not within the range [{}, {}) of lines in {}.\n\n" + "The decorator lines were:\n{}" + ).format(lineno, 0, len(lines), filename, "\n".join(lines)) + ) # Go up till a line starts with a decorator decorator_lineno = None # type: Optional[int] @@ -239,8 +263,11 @@ def inspect_decorator(lines: List[str], lineno: int, filename: str) -> Decorator break if decorator_lineno is None: - raise SyntaxError("Decorator corresponding to the line {} could not be found in file {}: {!r}".format( - lineno + 1, filename, lines[lineno])) + raise SyntaxError( + "Decorator corresponding to the line {} could not be found in file {}: {!r}".format( + lineno + 1, filename, lines[lineno] + ) + ) # Find the decorator end -- it's either a function definition, a class definition or another decorator decorator_end_lineno = None # type: Optional[int] @@ -252,69 +279,91 @@ def inspect_decorator(lines: List[str], lineno: int, filename: str) -> Decorator break if decorator_end_lineno is None: - raise SyntaxError(("The next statement following the decorator corresponding to the line {} " - "could not be found in file {}: {!r}").format(lineno + 1, filename, lines[lineno])) + raise SyntaxError( + ( + "The next statement following the decorator corresponding to the line {} " + "could not be found in file {}: {!r}" + ).format(lineno + 1, filename, lines[lineno]) + ) decorator_lines = lines[decorator_lineno:decorator_end_lineno] # We need to dedent the decorator and add a dummy decorate so that we can parse its text as valid source code. - decorator_text = textwrap.dedent("".join(decorator_lines)) + "def dummy_{}(): pass".format(uuid.uuid4().hex) + decorator_text = textwrap.dedent( + "".join(decorator_lines) + ) + "def dummy_{}(): pass".format(uuid.uuid4().hex) atok = asttokens.asttokens.ASTTokens(decorator_text, parse=True) if not isinstance(atok.tree, ast.Module): - raise ValueError(("Expected the parsed decorator text to live in an AST module. " - "Are you trying to inspect a condition lambda which was not stated in a decorator? " - "(This feature is currently unsupported in icontract.) " - "The decorator was expected at line {} in {}. " - "The decorator lines under inspection were {}-{}.").format( - lineno + 1, filename, decorator_lineno + 1, decorator_end_lineno)) + raise ValueError( + ( + "Expected the parsed decorator text to live in an AST module. " + "Are you trying to inspect a condition lambda which was not stated in a decorator? " + "(This feature is currently unsupported in icontract.) " + "The decorator was expected at line {} in {}. " + "The decorator lines under inspection were {}-{}." + ).format(lineno + 1, filename, decorator_lineno + 1, decorator_end_lineno) + ) module_node = atok.tree if len(module_node.body) != 1: - raise ValueError(("Expected the module AST of the decorator text to have a single statement. " - "Are you trying to inspect a condition lambda which was not stated in a decorator? " - "(This feature is currently unsupported in icontract.) " - "The decorator was expected at line {} in {}. " - "The decorator lines under inspection were {}-{}.").format( - lineno + 1, filename, decorator_lineno + 1, decorator_end_lineno)) + raise ValueError( + ( + "Expected the module AST of the decorator text to have a single statement. " + "Are you trying to inspect a condition lambda which was not stated in a decorator? " + "(This feature is currently unsupported in icontract.) " + "The decorator was expected at line {} in {}. " + "The decorator lines under inspection were {}-{}." + ).format(lineno + 1, filename, decorator_lineno + 1, decorator_end_lineno) + ) if not isinstance(module_node.body[0], ast.FunctionDef): - raise ValueError(("Expected the only statement in the AST module corresponding to the decorator text " - "to be a function definition. " - "Are you trying to inspect a condition lambda which was not stated in a decorator? " - "(This feature is currently unsupported in icontract.) " - "The decorator was expected at line {} in {}. " - "The decorator lines under inspection were {}-{}.").format( - lineno + 1, filename, decorator_lineno + 1, decorator_end_lineno)) + raise ValueError( + ( + "Expected the only statement in the AST module corresponding to the decorator text " + "to be a function definition. " + "Are you trying to inspect a condition lambda which was not stated in a decorator? " + "(This feature is currently unsupported in icontract.) " + "The decorator was expected at line {} in {}. " + "The decorator lines under inspection were {}-{}." + ).format(lineno + 1, filename, decorator_lineno + 1, decorator_end_lineno) + ) func_def_node = module_node.body[0] if len(func_def_node.decorator_list) != 1: raise ValueError( - ("Expected the function AST node corresponding to the decorator text to have a single decorator. " - "Are you trying to inspect a condition lambda which was not stated in a decorator? " - "(This feature is currently unsupported in icontract.) " - "The decorator was expected at line {} in {}. " - "The decorator lines under inspection were {}-{}.").format(lineno + 1, filename, decorator_lineno + 1, - decorator_end_lineno)) + ( + "Expected the function AST node corresponding to the decorator text to have a single decorator. " + "Are you trying to inspect a condition lambda which was not stated in a decorator? " + "(This feature is currently unsupported in icontract.) " + "The decorator was expected at line {} in {}. " + "The decorator lines under inspection were {}-{}." + ).format(lineno + 1, filename, decorator_lineno + 1, decorator_end_lineno) + ) if not isinstance(func_def_node.decorator_list[0], ast.Call): - raise ValueError(("Expected the only decorator in the function definition AST node corresponding " - "to the decorator text to be a call node. " - "Are you trying to inspect a condition lambda which was not stated in a decorator? " - "(This feature is currently unsupported in icontract.) " - "The decorator was expected at line {} in {}. " - "The decorator lines under inspection were {}-{}.").format( - lineno + 1, filename, decorator_lineno + 1, decorator_end_lineno)) + raise ValueError( + ( + "Expected the only decorator in the function definition AST node corresponding " + "to the decorator text to be a call node. " + "Are you trying to inspect a condition lambda which was not stated in a decorator? " + "(This feature is currently unsupported in icontract.) " + "The decorator was expected at line {} in {}. " + "The decorator lines under inspection were {}-{}." + ).format(lineno + 1, filename, decorator_lineno + 1, decorator_end_lineno) + ) call_node = func_def_node.decorator_list[0] return DecoratorInspection(atok=atok, node=call_node) -def find_lambda_condition(decorator_inspection: DecoratorInspection) -> Optional[ConditionLambdaInspection]: +def find_lambda_condition( + decorator_inspection: DecoratorInspection, +) -> Optional[ConditionLambdaInspection]: """ Inspect the decorator and extract the condition as lambda. @@ -325,31 +374,39 @@ def find_lambda_condition(decorator_inspection: DecoratorInspection) -> Optional lambda_node = None # type: Optional[ast.Lambda] if len(call_node.args) > 0: - assert isinstance(call_node.args[0], ast.Lambda), \ - ("Expected the first argument to the decorator to be a condition as lambda AST node, " - "but got: {}").format(type(call_node.args[0])) + assert isinstance(call_node.args[0], ast.Lambda), ( + "Expected the first argument to the decorator to be a condition as lambda AST node, " + "but got: {}" + ).format(type(call_node.args[0])) lambda_node = call_node.args[0] elif len(call_node.keywords) > 0: for keyword in call_node.keywords: if keyword.arg == "condition": - assert isinstance(keyword.value, ast.Lambda), \ - "Expected lambda node as value of the 'condition' argument to the decorator." + assert isinstance( + keyword.value, ast.Lambda + ), "Expected lambda node as value of the 'condition' argument to the decorator." lambda_node = keyword.value break - assert lambda_node is not None, "Expected to find a keyword AST node with 'condition' arg, but found none" + assert ( + lambda_node is not None + ), "Expected to find a keyword AST node with 'condition' arg, but found none" else: raise AssertionError( "Expected a call AST node of a decorator to have either args or keywords, but got: {}".format( - ast.dump(call_node))) + ast.dump(call_node) + ) + ) return ConditionLambdaInspection(atok=decorator_inspection.atok, node=lambda_node) -def inspect_lambda_condition(condition: Callable[..., Any]) -> Optional[ConditionLambdaInspection]: +def inspect_lambda_condition( + condition: Callable[..., Any] +) -> Optional[ConditionLambdaInspection]: """ Try to extract the source code of the condition as lambda. @@ -362,7 +419,9 @@ def inspect_lambda_condition(condition: Callable[..., Any]) -> Optional[Conditio filename = inspect.getsourcefile(condition) assert filename is not None - decorator_inspection = inspect_decorator(lines=lines, lineno=condition_lineno, filename=filename) + decorator_inspection = inspect_decorator( + lines=lines, lineno=condition_lineno, filename=filename + ) lambda_inspection = find_lambda_condition(decorator_inspection=decorator_inspection) diff --git a/icontract/_types.py b/icontract/_types.py index e3a76bc..3e96bfb 100644 --- a/icontract/_types.py +++ b/icontract/_types.py @@ -1,7 +1,16 @@ """Define data structures shared among the modules.""" import inspect import reprlib -from typing import Callable, Optional, Union, Set, List, Any, Type, cast # pylint: disable=unused-import +from typing import ( + Callable, + Optional, + Union, + Set, + List, + Any, + Type, + cast, +) # pylint: disable=unused-import import icontract._globals @@ -11,12 +20,16 @@ class Contract: """Represent a contract to be enforced as a precondition, postcondition or as an invariant.""" - def __init__(self, - condition: Callable[..., Any], - description: Optional[str] = None, - a_repr: reprlib.Repr = icontract._globals.aRepr, - error: Optional[Union[Callable[..., ExceptionT], Type[ExceptionT], BaseException]] = None, - location: Optional[str] = None) -> None: + def __init__( + self, + condition: Callable[..., Any], + description: Optional[str] = None, + a_repr: reprlib.Repr = icontract._globals.aRepr, + error: Optional[ + Union[Callable[..., ExceptionT], Type[ExceptionT], BaseException] + ] = None, + location: Optional[str] = None, + ) -> None: """ Initialize. @@ -42,7 +55,9 @@ def __init__(self, # Names of the mandatory arguments of the condition self.mandatory_args = [ - name for name, param in signature.parameters.items() if param.default == inspect.Parameter.empty + name + for name, param in signature.parameters.items() + if param.default == inspect.Parameter.empty ] self.description = description @@ -53,7 +68,9 @@ def __init__(self, self.error_arg_set = None # type: Optional[Set[str]] if error is not None and (inspect.isfunction(error) or inspect.ismethod(error)): error_as_callable = cast(Callable[..., ExceptionT], error) - self.error_args = list(inspect.signature(error_as_callable).parameters.keys()) + self.error_args = list( + inspect.signature(error_as_callable).parameters.keys() + ) self.error_arg_set = set(self.error_args) self.location = location @@ -62,7 +79,12 @@ def __init__(self, class Snapshot: """Define a snapshot of an argument *prior* to the function invocation that is later supplied to a postcondition.""" - def __init__(self, capture: Callable[..., Any], name: Optional[str] = None, location: Optional[str] = None) -> None: + def __init__( + self, + capture: Callable[..., Any], + name: Optional[str] = None, + location: Optional[str] = None, + ) -> None: """ Initialize. @@ -80,14 +102,20 @@ def __init__(self, capture: Callable[..., Any], name: Optional[str] = None, loca if name is None: if len(args) == 0: - raise ValueError("You must name a snapshot if no argument was given in the capture function.") + raise ValueError( + "You must name a snapshot if no argument was given in the capture function." + ) elif len(args) > 1: - raise ValueError("You must name a snapshot if multiple arguments were given in the capture function.") + raise ValueError( + "You must name a snapshot if multiple arguments were given in the capture function." + ) else: assert len(args) == 1 name = args[0] - assert name is not None, "Expected ``name`` to be set in the preceding execution paths." + assert ( + name is not None + ), "Expected ``name`` to be set in the preceding execution paths." self.name = name self.args = args diff --git a/precommit.py b/precommit.py index 1cbb4a5..bb33a36 100755 --- a/precommit.py +++ b/precommit.py @@ -8,13 +8,14 @@ def main() -> int: - """"Execute main routine.""" + """ "Execute main routine.""" parser = argparse.ArgumentParser() parser.add_argument( "--overwrite", help="Overwrites the unformatted source files with the well-formatted code in place. " "If not set, an exception is raised if any of the files do not conform to the style guide.", - action='store_true') + action="store_true", + ) args = parser.parse_args() @@ -22,64 +23,78 @@ def main() -> int: repo_root = pathlib.Path(__file__).parent - print("YAPF'ing...") - yapf_targets = ["tests", "icontract", "setup.py", "precommit.py", "benchmark.py", "benchmarks", "tests_with_others"] + print("Reformatting...") + reformat_targets = [ + "tests", + "icontract", + "setup.py", + "precommit.py", + "benchmark.py", + "benchmarks", + "tests_with_others", + ] if sys.version_info >= (3, 6): - yapf_targets.append('tests_3_6') + reformat_targets.append("tests_3_6") if sys.version_info >= (3, 7): - yapf_targets.append('tests_3_7') + reformat_targets.append("tests_3_7") if sys.version_info >= (3, 8, 5): - yapf_targets.append('tests_3_8') + reformat_targets.append("tests_3_8") if overwrite: subprocess.check_call( - ["yapf", "--in-place", "--style=style.yapf", "--recursive"] + yapf_targets, cwd=str(repo_root)) + [sys.executable, "-m", "black"] + reformat_targets, cwd=str(repo_root) + ) else: subprocess.check_call( - ["yapf", "--diff", "--style=style.yapf", "--recursive"] + yapf_targets, cwd=str(repo_root)) + [sys.executable, "-m", "black"] + reformat_targets, cwd=str(repo_root) + ) if sys.version_info < (3, 8): - print("Mypy since 1.5 dropped support for Python 3.7 and " - "you are running Python {}, so skipping.".format(sys.version_info)) + print( + "Mypy since 1.5 dropped support for Python 3.7 and " + "you are running Python {}, so skipping.".format(sys.version_info) + ) else: print("Mypy'ing...") mypy_targets = ["icontract", "tests"] if sys.version_info >= (3, 6): - mypy_targets.append('tests_3_6') + mypy_targets.append("tests_3_6") if sys.version_info >= (3, 7): - mypy_targets.append('tests_3_7') + mypy_targets.append("tests_3_7") if sys.version_info >= (3, 8): - mypy_targets.append('tests_3_8') - mypy_targets.append('tests_with_others') + mypy_targets.append("tests_3_8") + mypy_targets.append("tests_with_others") subprocess.check_call(["mypy", "--strict"] + mypy_targets, cwd=str(repo_root)) print("Pylint'ing...") - pylint_targets = ['icontract', 'tests'] + pylint_targets = ["icontract", "tests"] if sys.version_info >= (3, 6): - pylint_targets.append('tests_3_6') + pylint_targets.append("tests_3_6") if sys.version_info >= (3, 7): - pylint_targets.append('tests_3_7') + pylint_targets.append("tests_3_7") if sys.version_info >= (3, 8): - pylint_targets.append('tests_3_8') - pylint_targets.append('tests_with_others') + pylint_targets.append("tests_3_8") + pylint_targets.append("tests_with_others") - subprocess.check_call(["pylint", "--rcfile=pylint.rc"] + pylint_targets, cwd=str(repo_root)) + subprocess.check_call( + ["pylint", "--rcfile=pylint.rc"] + pylint_targets, cwd=str(repo_root) + ) print("Pydocstyle'ing...") subprocess.check_call(["pydocstyle", "icontract"], cwd=str(repo_root)) print("Testing...") env = os.environ.copy() - env['ICONTRACT_SLOW'] = 'true' + env["ICONTRACT_SLOW"] = "true" # yapf: disable subprocess.check_call( @@ -111,7 +126,9 @@ def main() -> int: subprocess.check_call([sys.executable, "-m", "doctest", str(pth)]) print("Checking the restructured text of the readme...") - subprocess.check_call([sys.executable, 'setup.py', 'check', '--restructuredtext', '--strict']) + subprocess.check_call( + [sys.executable, "setup.py", "check", "--restructuredtext", "--strict"] + ) return 0 diff --git a/setup.py b/setup.py index fbb1f36..ef9fd49 100644 --- a/setup.py +++ b/setup.py @@ -13,7 +13,7 @@ here = os.path.abspath(os.path.dirname(__file__)) # pylint: disable=invalid-name -with open(os.path.join(here, 'README.rst'), encoding='utf-8') as fid: +with open(os.path.join(here, "README.rst"), encoding="utf-8") as fid: long_description = fid.read() # pylint: disable=invalid-name with open(os.path.join(here, "requirements.txt"), encoding="utf-8") as fid: @@ -26,14 +26,14 @@ # problems with installing icontract through pip on their servers with # imports in setup.py. setup( - name='icontract', + name="icontract", # Don't forget to update the version in __init__.py and CHANGELOG.rst! - version='2.6.2', - description='Provide design-by-contract with informative violation messages.', + version="2.6.2", + description="Provide design-by-contract with informative violation messages.", long_description=long_description, - url='https://github.com/Parquery/icontract', - author='Marko Ristin', - author_email='marko@ristin.ch', + url="https://github.com/Parquery/icontract", + author="Marko Ristin", + author_email="marko@ristin.ch", classifiers=[ # yapf: disable 'Development Status :: 5 - Production/Stable', @@ -45,19 +45,31 @@ 'Programming Language :: Python :: 3.10' # yapf: enable ], - license='License :: OSI Approved :: MIT License', - keywords='design-by-contract precondition postcondition validation', - packages=find_packages(exclude=['tests*']), + license="License :: OSI Approved :: MIT License", + keywords="design-by-contract precondition postcondition validation", + packages=find_packages(exclude=["tests*"]), install_requires=install_requires, extras_require={ - 'dev': [ - 'pylint==2.13.9', 'yapf==0.20.2', 'tox>=3.0.0', 'pydocstyle>=6.1.1,<7', 'coverage>=4.5.1,<5', - 'docutils>=0.14,<1', 'pygments>=2.2.0,<3', 'dpcontracts==0.6.0', 'tabulate>=0.8.7,<1', - 'py-cpuinfo>=5.0.0,<6', 'typeguard>=2,<3', 'astor==0.8.1', 'numpy>=1,<2' - ] + (['mypy==1.5.1'] if sys.version_info >= (3, 8) else []) + - (['deal==4.23.3'] if sys.version_info >= (3, 8) else []) + (['asyncstdlib==3.9.1'] - if sys.version_info >= (3, 8) else []), + "dev": [ + "pylint==2.13.9", + "black==23.9.1", + "tox>=3.0.0", + "pydocstyle>=6.1.1,<7", + "coverage>=4.5.1,<5", + "docutils>=0.14,<1", + "pygments>=2.2.0,<3", + "dpcontracts==0.6.0", + "tabulate>=0.8.7,<1", + "py-cpuinfo>=5.0.0,<6", + "typeguard>=2,<3", + "astor==0.8.1", + "numpy>=1,<2", + ] + + (["mypy==1.5.1"] if sys.version_info >= (3, 8) else []) + + (["deal==4.23.3"] if sys.version_info >= (3, 8) else []) + + (["asyncstdlib==3.9.1"] if sys.version_info >= (3, 8) else []), }, - py_modules=['icontract'], + py_modules=["icontract"], package_data={"icontract": ["py.typed"]}, - data_files=[(".", ["LICENSE.txt", "README.rst", "requirements.txt"])]) + data_files=[(".", ["LICENSE.txt", "README.rst", "requirements.txt"])], +) diff --git a/tests/error.py b/tests/error.py index eabf626..159c22a 100644 --- a/tests/error.py +++ b/tests/error.py @@ -2,7 +2,10 @@ """Manipulate the error text.""" import re -_LOCATION_RE = re.compile(r'\AFile [^\n]+, line [0-9]+ in [a-zA-Z_0-9]+:\n(.*)\Z', flags=re.MULTILINE | re.DOTALL) +_LOCATION_RE = re.compile( + r"\AFile [^\n]+, line [0-9]+ in [a-zA-Z_0-9]+:\n(.*)\Z", + flags=re.MULTILINE | re.DOTALL, +) def wo_mandatory_location(text: str) -> str: @@ -23,6 +26,10 @@ def wo_mandatory_location(text: str) -> str: """ mtch = _LOCATION_RE.match(text) if not mtch: - raise AssertionError("Expected the text to match {}, but got: {!r}".format(_LOCATION_RE.pattern, text)) + raise AssertionError( + "Expected the text to match {}, but got: {!r}".format( + _LOCATION_RE.pattern, text + ) + ) return mtch.group(1) diff --git a/tests/mock.py b/tests/mock.py index 1a37f2b..ae1cda9 100644 --- a/tests/mock.py +++ b/tests/mock.py @@ -9,17 +9,19 @@ def __init__(self, values: List[Union[int, bool]]) -> None: """Initialize with the given values.""" self.values = values - def __lt__(self, other: int) -> 'NumpyArray': + def __lt__(self, other: int) -> "NumpyArray": """Map the value to each comparison with ``other``.""" return NumpyArray(values=[value < other for value in self.values]) - def __gt__(self, other: int) -> 'NumpyArray': + def __gt__(self, other: int) -> "NumpyArray": """Map the value to each comparison with ``other``.""" return NumpyArray(values=[value > other for value in self.values]) def __bool__(self) -> bool: """Raise a ValueError.""" - raise ValueError("The truth value of an array with more than one element is ambiguous.") + raise ValueError( + "The truth value of an array with more than one element is ambiguous." + ) def all(self) -> bool: """Return True if all values are True.""" @@ -27,4 +29,4 @@ def all(self) -> bool: def __repr__(self) -> str: """Represent with the constructor.""" - return 'NumpyArray({!r})'.format(self.values) + return "NumpyArray({!r})".format(self.values) diff --git a/tests/test_args_and_kwargs_in_contract.py b/tests/test_args_and_kwargs_in_contract.py index f1f175d..39860bb 100644 --- a/tests/test_args_and_kwargs_in_contract.py +++ b/tests/test_args_and_kwargs_in_contract.py @@ -30,7 +30,7 @@ def some_func(x: int) -> None: some_func(3) assert recorded_args is not None - self.assertTupleEqual((3, ), recorded_args) + self.assertTupleEqual((3,), recorded_args) def test_args_with_named_and_variable_positional_arguments(self) -> None: recorded_args = None # type: Optional[Tuple[Any, ...]] @@ -96,11 +96,15 @@ def some_func(*args: Any) -> None: assert violation_error is not None self.assertEqual( - textwrap.dedent('''\ + textwrap.dedent( + """\ len(_ARGS) > 2: _ARGS was (3,) args was 3 - len(_ARGS) was 1'''), tests.error.wo_mandatory_location(str(violation_error))) + len(_ARGS) was 1""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) class TestKwargs(unittest.TestCase): @@ -155,7 +159,9 @@ def some_func(**kwargs: Any) -> None: assert recorded_kwargs is not None self.assertDictEqual({"x": 3, "y": 2, "z": 1}, recorded_kwargs) - def test_kwargs_with_uncommon_argument_name_for_variable_keyword_arguments(self) -> None: + def test_kwargs_with_uncommon_argument_name_for_variable_keyword_arguments( + self, + ) -> None: recorded_kwargs = None # type: Optional[Dict[str, Any]] def set_kwargs(kwargs: Dict[str, Any]) -> bool: @@ -173,7 +179,7 @@ def some_func(**parameters: Any) -> None: self.assertDictEqual({"x": 3, "y": 2, "z": 1, "a": 0}, recorded_kwargs) def test_fail(self) -> None: - @icontract.require(lambda _KWARGS: 'x' in _KWARGS) + @icontract.require(lambda _KWARGS: "x" in _KWARGS) def some_func(**kwargs: Any) -> None: pass @@ -185,10 +191,14 @@ def some_func(**kwargs: Any) -> None: assert violation_error is not None self.assertEqual( - textwrap.dedent("""\ - 'x' in _KWARGS: + textwrap.dedent( + """\ + "x" in _KWARGS: _KWARGS was {'y': 3} - y was 3"""), tests.error.wo_mandatory_location(str(violation_error))) + y was 3""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) class TestArgsAndKwargs(unittest.TestCase): @@ -210,7 +220,7 @@ def some_func(*args: Any, **kwargs: Any) -> None: some_func(5, x=10, y=20, z=30) assert recorded_args is not None - self.assertTupleEqual((5, ), recorded_args) + self.assertTupleEqual((5,), recorded_args) assert recorded_kwargs is not None self.assertDictEqual({"x": 10, "y": 20, "z": 30}, recorded_kwargs) @@ -247,12 +257,16 @@ def test_in_function_definition(self) -> None: @icontract.require(lambda _ARGS: True) def some_func(_ARGS: Any) -> None: pass + except TypeError as error: type_error = error assert type_error is not None - self.assertEqual('The arguments of the function to be decorated with a contract checker include "_ARGS" ' - 'which is a reserved placeholder for positional arguments in the condition.', str(type_error)) + self.assertEqual( + 'The arguments of the function to be decorated with a contract checker include "_ARGS" ' + "which is a reserved placeholder for positional arguments in the condition.", + str(type_error), + ) def test_in_kwargs_in_call(self) -> None: @icontract.require(lambda _ARGS: True) @@ -266,8 +280,11 @@ def some_func(*args, **kwargs) -> None: # type: ignore type_error = error assert type_error is not None - self.assertEqual('The arguments of the function call include "_ARGS" ' - 'which is a placeholder for positional arguments in a condition.', str(type_error)) + self.assertEqual( + 'The arguments of the function call include "_ARGS" ' + "which is a placeholder for positional arguments in a condition.", + str(type_error), + ) class TestConflictOnKWARGSReported(unittest.TestCase): @@ -278,12 +295,16 @@ def test_in_function_definition(self) -> None: @icontract.require(lambda _ARGS: True) def some_func(_KWARGS: Any) -> None: pass + except TypeError as error: type_error = error assert type_error is not None - self.assertEqual('The arguments of the function to be decorated with a contract checker include "_KWARGS" ' - 'which is a reserved placeholder for keyword arguments in the condition.', str(type_error)) + self.assertEqual( + 'The arguments of the function to be decorated with a contract checker include "_KWARGS" ' + "which is a reserved placeholder for keyword arguments in the condition.", + str(type_error), + ) def test_in_kwargs_in_call(self) -> None: @icontract.require(lambda _ARGS: True) @@ -297,9 +318,12 @@ def some_func(*args, **kwargs) -> None: # type: ignore type_error = error assert type_error is not None - self.assertEqual('The arguments of the function call include "_KWARGS" ' - 'which is a placeholder for keyword arguments in a condition.', str(type_error)) + self.assertEqual( + 'The arguments of the function call include "_KWARGS" ' + "which is a placeholder for keyword arguments in a condition.", + str(type_error), + ) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/test_checkers.py b/tests/test_checkers.py index 8a771c6..a880dac 100644 --- a/tests/test_checkers.py +++ b/tests/test_checkers.py @@ -35,14 +35,20 @@ def test_wo_decorators(self) -> None: def func() -> int: return 0 - self.assertListEqual([0], [a_func() for a_func in icontract._checkers._walk_decorator_stack(func)]) + self.assertListEqual( + [0], + [a_func() for a_func in icontract._checkers._walk_decorator_stack(func)], + ) def test_with_single_decorator(self) -> None: @decorator_plus_1 def func() -> int: return 0 - self.assertListEqual([1, 0], [a_func() for a_func in icontract._checkers._walk_decorator_stack(func)]) + self.assertListEqual( + [1, 0], + [a_func() for a_func in icontract._checkers._walk_decorator_stack(func)], + ) def test_with_double_decorator(self) -> None: @decorator_plus_2 @@ -50,7 +56,10 @@ def test_with_double_decorator(self) -> None: def func() -> int: return 0 - self.assertListEqual([3, 1, 0], [a_func() for a_func in icontract._checkers._walk_decorator_stack(func)]) + self.assertListEqual( + [3, 1, 0], + [a_func() for a_func in icontract._checkers._walk_decorator_stack(func)], + ) class TestResolveKwargs(unittest.TestCase): @@ -67,7 +76,9 @@ def some_func(x: int, y: int) -> None: assert type_error is not None self.assertRegex( - str(type_error), r"^([a-zA-Z_0-9<>.]+\.)?some_func\(\) takes 2 positional arguments but 3 were given$") + str(type_error), + r"^([a-zA-Z_0-9<>.]+\.)?some_func\(\) takes 2 positional arguments but 3 were given$", + ) def test_that_result_in_kwargs_raises_an_error(self) -> None: @icontract.ensure(lambda result: result > 0) @@ -83,7 +94,10 @@ def some_func(*args, **kwargs) -> int: # type: ignore assert type_error is not None - self.assertEqual("Unexpected argument 'result' in a function decorated with postconditions.", str(type_error)) + self.assertEqual( + "Unexpected argument 'result' in a function decorated with postconditions.", + str(type_error), + ) def test_that_OLD_in_kwargs_raises_an_error(self) -> None: @icontract.ensure(lambda result: result > 0) @@ -99,8 +113,11 @@ def some_func(*args, **kwargs) -> int: # type: ignore assert type_error is not None - self.assertEqual("Unexpected argument 'OLD' in a function decorated with postconditions.", str(type_error)) + self.assertEqual( + "Unexpected argument 'OLD' in a function decorated with postconditions.", + str(type_error), + ) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/test_error.py b/tests/test_error.py index 2c1d069..16a7f98 100644 --- a/tests/test_error.py +++ b/tests/test_error.py @@ -25,12 +25,17 @@ def some_func(x: int) -> None: violation_error = err assert violation_error is not None - self.assertEqual('x > 0: x was -1', tests.error.wo_mandatory_location(str(violation_error))) + self.assertEqual( + "x > 0: x was -1", tests.error.wo_mandatory_location(str(violation_error)) + ) class TestSpecifiedAsFunction(unittest.TestCase): def test_lambda(self) -> None: - @icontract.require(lambda x: x > 0, error=lambda x: ValueError("x must be positive: {}".format(x))) + @icontract.require( + lambda x: x > 0, + error=lambda x: ValueError("x must be positive: {}".format(x)), + ) def some_func(x: int) -> None: pass @@ -41,7 +46,7 @@ def some_func(x: int) -> None: value_error = err assert value_error is not None - self.assertEqual('x must be positive: -1', str(value_error)) + self.assertEqual("x must be positive: -1", str(value_error)) def test_separate_function(self) -> None: def error_func(x: int) -> ValueError: @@ -58,7 +63,7 @@ def some_func(x: int) -> None: value_error = err assert value_error is not None - self.assertEqual('x must be positive: -1', str(value_error)) + self.assertEqual("x must be positive: -1", str(value_error)) def test_separate_method(self) -> None: class Errorer: @@ -78,7 +83,7 @@ def some_func(x: int) -> None: value_error = err assert value_error is not None - self.assertEqual('x must be positive: -1', str(value_error)) + self.assertEqual("x must be positive: -1", str(value_error)) def test_report_if_result_is_not_base_exception(self) -> None: @icontract.require(lambda x: x > 0, error=lambda x: "x must be positive") # type: ignore @@ -94,7 +99,8 @@ def some_func(x: int) -> None: assert type_error is not None self.assertRegex( str(type_error), - r"^The exception returned by the contract's error does not inherit from BaseException\.$") + r"^The exception returned by the contract's error does not inherit from BaseException\.$", + ) class TestSpecifiedAsType(unittest.TestCase): @@ -110,7 +116,9 @@ def some_func(x: int) -> None: value_error = err assert value_error is not None - self.assertEqual('x > 0: x was -1', tests.error.wo_mandatory_location(str(value_error))) + self.assertEqual( + "x > 0: x was -1", tests.error.wo_mandatory_location(str(value_error)) + ) class TestSpecifiedAsInstance(unittest.TestCase): @@ -126,7 +134,7 @@ def some_func(x: int) -> None: value_error = err assert value_error is not None - self.assertEqual('negative x', str(value_error)) + self.assertEqual("negative x", str(value_error)) def test_repeated_raising(self) -> None: @icontract.require(lambda x: x > 0, error=ValueError("negative x")) @@ -140,7 +148,7 @@ def some_func(x: int) -> None: value_error = err assert value_error is not None - self.assertEqual('negative x', str(value_error)) + self.assertEqual("negative x", str(value_error)) # Repeat value_error = None @@ -150,7 +158,7 @@ def some_func(x: int) -> None: value_error = err assert value_error is not None - self.assertEqual('negative x', str(value_error)) + self.assertEqual("negative x", str(value_error)) class TestSpecifiedAsInvalidType(unittest.TestCase): @@ -164,13 +172,16 @@ class A: @icontract.require(lambda x: x > 0, error=A) # type: ignore def some_func(x: int) -> None: pass + except ValueError as err: value_error = err assert value_error is not None self.assertRegex( - str(value_error), r"The error of the contract is given as a type, " - r"but the type does not inherit from BaseException: ") + str(value_error), + r"The error of the contract is given as a type, " + r"but the type does not inherit from BaseException: ", + ) def test_in_postcondition(self) -> None: class A: @@ -182,13 +193,16 @@ class A: @icontract.ensure(lambda result: result > 0, error=A) # type: ignore def some_func() -> int: return -1 + except ValueError as err: value_error = err assert value_error is not None self.assertRegex( - str(value_error), r"The error of the contract is given as a type, " - r"but the type does not inherit from BaseException: ") + str(value_error), + r"The error of the contract is given as a type, " + r"but the type does not inherit from BaseException: ", + ) def test_in_invariant(self) -> None: value_error = None # type: Optional[ValueError] @@ -201,13 +215,16 @@ class A: class B: def __init__(self) -> None: self.x = -1 + except ValueError as err: value_error = err assert value_error is not None self.assertRegex( - str(value_error), r"The error of the contract is given as a type, " - r"but the type does not inherit from BaseException: ") + str(value_error), + r"The error of the contract is given as a type, " + r"but the type does not inherit from BaseException: ", + ) class TestSpecifiedAsInstanceOfInvalidType(unittest.TestCase): @@ -222,14 +239,17 @@ def __init__(self, msg: str) -> None: @icontract.require(lambda x: x > 0, error=A("something went wrong")) # type: ignore def some_func(x: int) -> None: pass + except ValueError as err: value_error = err assert value_error is not None self.assertRegex( - str(value_error), r"^The error of the contract must be either a callable \(a function or a method\), " + str(value_error), + r"^The error of the contract must be either a callable \(a function or a method\), " r"a class \(subclass of BaseException\) or an instance of BaseException, " - r"but got: <.*\.A object at 0x.*>$") + r"but got: <.*\.A object at 0x.*>$", + ) def test_in_postcondition(self) -> None: class A: @@ -242,14 +262,17 @@ def __init__(self, msg: str) -> None: @icontract.ensure(lambda result: result > 0, error=A("something went wrong")) # type: ignore def some_func() -> int: return -1 + except ValueError as err: value_error = err assert value_error is not None self.assertRegex( - str(value_error), r"^The error of the contract must be either a callable \(a function or a method\), " + str(value_error), + r"^The error of the contract must be either a callable \(a function or a method\), " r"a class \(subclass of BaseException\) or an instance of BaseException, " - r"but got: <.*\.A object at 0x.*>$") + r"but got: <.*\.A object at 0x.*>$", + ) def test_in_invariant(self) -> None: class A: @@ -263,15 +286,18 @@ def __init__(self, msg: str) -> None: class B: def __init__(self) -> None: self.x = -1 + except ValueError as err: value_error = err assert value_error is not None self.assertRegex( - str(value_error), r"^The error of the contract must be either a callable \(a function or a method\), " + str(value_error), + r"^The error of the contract must be either a callable \(a function or a method\), " r"a class \(subclass of BaseException\) or an instance of BaseException, " - r"but got: <.*\.A object at 0x.*>$") + r"but got: <.*\.A object at 0x.*>$", + ) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/test_for_integrators.py b/tests/test_for_integrators.py index 76af981..1a13932 100644 --- a/tests/test_for_integrators.py +++ b/tests/test_for_integrators.py @@ -12,11 +12,17 @@ @icontract.require(lambda x: x > 0) -@icontract.snapshot(lambda cumulative: None if len(cumulative) == 0 else cumulative[-1], "last") +@icontract.snapshot( + lambda cumulative: None if len(cumulative) == 0 else cumulative[-1], "last" +) @icontract.snapshot(lambda cumulative: len(cumulative), "len_cumulative") @icontract.ensure(lambda cumulative, OLD: len(cumulative) == OLD.len_cumulative + 1) -@icontract.ensure(lambda x, cumulative, OLD: OLD.last is None or OLD.last + x == cumulative[-1]) -@icontract.ensure(lambda x, cumulative, OLD: OLD.last is not None or x == cumulative[-1]) +@icontract.ensure( + lambda x, cumulative, OLD: OLD.last is None or OLD.last + x == cumulative[-1] +) +@icontract.ensure( + lambda x, cumulative, OLD: OLD.last is not None or x == cumulative[-1] +) def func_with_contracts(x: int, cumulative: List[int]) -> None: if len(cumulative) == 0: cumulative.append(x) @@ -48,13 +54,17 @@ def test_evaluating(self) -> None: preconditions = checker.__preconditions__ # type: ignore assert isinstance(preconditions, list) assert all(isinstance(group, list) for group in preconditions) - assert all(isinstance(contract, icontract._types.Contract) for group in preconditions for contract in group) + assert all( + isinstance(contract, icontract._types.Contract) + for group in preconditions + for contract in group + ) ## # Evaluate manually preconditions ## - kwargs = {'x': 4, 'cumulative': [2]} + kwargs = {"x": 4, "cumulative": [2]} success = True # We have to check preconditions in groups in case they are weakened @@ -62,7 +72,8 @@ def test_evaluating(self) -> None: success = True for contract in group: condition_kwargs = icontract._checkers.select_condition_kwargs( - contract=contract, resolved_kwargs=kwargs) + contract=contract, resolved_kwargs=kwargs + ) success = contract.condition(**condition_kwargs) if not success: @@ -88,7 +99,11 @@ def some_func(x: int) -> None: # pylint: disable=unused-argument checker=checker, contract=icontract._types.Contract( condition=lambda x: x > 0, - error=lambda x: icontract.ViolationError("x must be positive, but got: {}".format(x)))) + error=lambda x: icontract.ViolationError( + "x must be positive, but got: {}".format(x) + ), + ), + ) violation_error = None # type: Optional[icontract.ViolationError] try: @@ -98,7 +113,7 @@ def some_func(x: int) -> None: # pylint: disable=unused-argument assert violation_error is not None - self.assertEqual('x must be positive, but got: -1', str(violation_error)) + self.assertEqual("x must be positive, but got: -1", str(violation_error)) class TestPostconditions(unittest.TestCase): @@ -109,24 +124,31 @@ def test_evaluating(self) -> None: # Retrieve postconditions postconditions = checker.__postconditions__ # type: ignore assert isinstance(postconditions, list) - assert all(isinstance(contract, icontract._types.Contract) for contract in postconditions) + assert all( + isinstance(contract, icontract._types.Contract) + for contract in postconditions + ) # Retrieve snapshots snapshots = checker.__postcondition_snapshots__ # type: ignore assert isinstance(snapshots, list) - assert all(isinstance(snapshot, icontract._types.Snapshot) for snapshot in snapshots) + assert all( + isinstance(snapshot, icontract._types.Snapshot) for snapshot in snapshots + ) ## # Evaluate manually postconditions ## cumulative = [2] - kwargs = {'x': 4, 'cumulative': cumulative} # kwargs **before** the call + kwargs = {"x": 4, "cumulative": cumulative} # kwargs **before** the call # Capture OLD old_as_mapping = dict() # type: MutableMapping[str, Any] for snap in snapshots: - snap_kwargs = icontract._checkers.select_capture_kwargs(a_snapshot=snap, resolved_kwargs=kwargs) + snap_kwargs = icontract._checkers.select_capture_kwargs( + a_snapshot=snap, resolved_kwargs=kwargs + ) old_as_mapping[snap.name] = snap.capture(**snap_kwargs) @@ -136,11 +158,13 @@ def test_evaluating(self) -> None: cumulative.append(6) # Evaluate the postconditions - kwargs['OLD'] = old + kwargs["OLD"] = old success = True for contract in postconditions: - condition_kwargs = icontract._checkers.select_condition_kwargs(contract=contract, resolved_kwargs=kwargs) + condition_kwargs = icontract._checkers.select_condition_kwargs( + contract=contract, resolved_kwargs=kwargs + ) success = contract.condition(**condition_kwargs) @@ -165,10 +189,16 @@ def some_func(lst: List[int]) -> None: checker=checker, contract=icontract._types.Contract( condition=lambda OLD, lst: OLD.len_lst == len(lst), - error=icontract.ViolationError("The size of lst must not change."))) + error=icontract.ViolationError("The size of lst must not change."), + ), + ) icontract._checkers.add_snapshot_to_checker( - checker=checker, snapshot=icontract._types.Snapshot(capture=lambda lst: len(lst), name="len_lst")) + checker=checker, + snapshot=icontract._types.Snapshot( + capture=lambda lst: len(lst), name="len_lst" + ), + ) violation_error = None # type: Optional[icontract.ViolationError] try: @@ -179,7 +209,7 @@ def some_func(lst: List[int]) -> None: assert violation_error is not None - self.assertEqual('The size of lst must not change.', str(violation_error)) + self.assertEqual("The size of lst must not change.", str(violation_error)) class TestInvariants(unittest.TestCase): @@ -189,11 +219,15 @@ def test_reading(self) -> None: invariants = ClassWithInvariants.__invariants__ # type: ignore assert isinstance(invariants, list) - assert all(isinstance(invariant, icontract._types.Contract) for invariant in invariants) + assert all( + isinstance(invariant, icontract._types.Contract) for invariant in invariants + ) invariants = instance.__invariants__ # type: ignore assert isinstance(invariants, list) - assert all(isinstance(invariant, icontract._types.Contract) for invariant in invariants) + assert all( + isinstance(invariant, icontract._types.Contract) for invariant in invariants + ) success = True for contract in invariants: @@ -216,11 +250,15 @@ def test_condition_text(self) -> None: assert icontract._represent.is_lambda(a_function=contract.condition) - lambda_inspection = icontract._represent.inspect_lambda_condition(condition=contract.condition) + lambda_inspection = icontract._represent.inspect_lambda_condition( + condition=contract.condition + ) assert lambda_inspection is not None - self.assertEqual('OLD.last is not None or x == cumulative[-1]', lambda_inspection.text) + self.assertEqual( + "OLD.last is not None or x == cumulative[-1]", lambda_inspection.text + ) assert isinstance(lambda_inspection.node, ast.Lambda) @@ -233,7 +271,10 @@ def test_condition_representation(self) -> None: assert isinstance(contract, icontract._types.Contract) text = icontract._represent.represent_condition(contract.condition) - self.assertEqual('lambda x, cumulative, OLD: OLD.last is not None or x == cumulative[-1]', text) + self.assertEqual( + "lambda x, cumulative, OLD: OLD.last is not None or x == cumulative[-1]", + text, + ) if __name__ == "__main__": diff --git a/tests/test_globals.py b/tests/test_globals.py index dab8a0d..c2a95ff 100644 --- a/tests/test_globals.py +++ b/tests/test_globals.py @@ -7,10 +7,12 @@ class TestSlow(unittest.TestCase): def test_slow_set(self) -> None: - self.assertTrue(icontract.SLOW, - "icontract.SLOW was not set. Please check if you set the environment variable ICONTRACT_SLOW " - "before running this test.") + self.assertTrue( + icontract.SLOW, + "icontract.SLOW was not set. Please check if you set the environment variable ICONTRACT_SLOW " + "before running this test.", + ) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/test_inheritance_invariant.py b/tests/test_inheritance_invariant.py index 90896f3..4730a6a 100644 --- a/tests/test_inheritance_invariant.py +++ b/tests/test_inheritance_invariant.py @@ -38,10 +38,18 @@ def some_func(self) -> int: return 2 inst = B() - self.assertEqual(1, Increment.count, "Invariant is expected to run only once at the initializer.") + self.assertEqual( + 1, + Increment.count, + "Invariant is expected to run only once at the initializer.", + ) inst.some_func() - self.assertEqual(3, Increment.count, "Invariant is expected to run before and after the method call.") + self.assertEqual( + 3, + Increment.count, + "Invariant is expected to run before and after the method call.", + ) class TestViolation(unittest.TestCase): @@ -70,10 +78,14 @@ def __repr__(self) -> str: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ self.x > 0: self was instance of B - self.x was -1"""), tests.error.wo_mandatory_location(str(violation_error))) + self.x was -1""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_inherited_violated_in_child(self) -> None: @icontract.invariant(lambda self: self.x > 0) @@ -103,10 +115,14 @@ def __repr__(self) -> str: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ self.x > 0: self was instance of B - self.x was -1"""), tests.error.wo_mandatory_location(str(violation_error))) + self.x was -1""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_additional_invariant_violated_in_childs_init(self) -> None: @icontract.invariant(lambda self: self.x > 0) @@ -130,10 +146,14 @@ def __repr__(self) -> str: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ self.x > 100: self was an instance of B - self.x was 10"""), tests.error.wo_mandatory_location(str(violation_error))) + self.x was 10""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_method_violates_in_child(self) -> None: @icontract.invariant(lambda self: self.x > 0) @@ -161,10 +181,14 @@ def __repr__(self) -> str: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ self.x > 100: self was an instance of B - self.x was 10"""), tests.error.wo_mandatory_location(str(violation_error))) + self.x was 10""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_triple_inheritance(self) -> None: @icontract.invariant(lambda self: self.x > 0) @@ -195,10 +219,14 @@ def __repr__(self) -> str: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ self.x > 0: self was instance of C - self.x was -1"""), tests.error.wo_mandatory_location(str(violation_error))) + self.x was -1""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_with_abstract_method(self) -> None: @icontract.invariant(lambda self: self.x > 0) @@ -229,10 +257,14 @@ def __repr__(self) -> str: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ self.x > 0: self was an instance of B - self.x was -1"""), tests.error.wo_mandatory_location(str(violation_error))) + self.x was -1""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) class TestProperty(unittest.TestCase): @@ -261,10 +293,14 @@ def __repr__(self) -> str: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ not self.toggled: self was an instance of SomeClass - self.toggled was True"""), tests.error.wo_mandatory_location(str(violation_error))) + self.toggled was True""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_inherited_setter(self) -> None: @icontract.invariant(lambda self: not self.toggled) @@ -294,10 +330,14 @@ def __repr__(self) -> str: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ not self.toggled: self was an instance of SomeClass - self.toggled was True"""), tests.error.wo_mandatory_location(str(violation_error))) + self.toggled was True""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_inherited_deleter(self) -> None: @icontract.invariant(lambda self: not self.toggled) @@ -327,10 +367,14 @@ def __repr__(self) -> str: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ not self.toggled: self was an instance of SomeClass - self.toggled was True"""), tests.error.wo_mandatory_location(str(violation_error))) + self.toggled was True""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_inherited_invariant_on_getter(self) -> None: @icontract.invariant(lambda self: not self.toggled) @@ -357,10 +401,14 @@ def __repr__(self) -> str: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ not self.toggled: self was an instance of SomeClass - self.toggled was True"""), tests.error.wo_mandatory_location(str(violation_error))) + self.toggled was True""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_inherited_invariant_on_setter(self) -> None: @icontract.invariant(lambda self: not self.toggled) @@ -390,10 +438,14 @@ def __repr__(self) -> str: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ not self.toggled: self was an instance of SomeClass - self.toggled was True"""), tests.error.wo_mandatory_location(str(violation_error))) + self.toggled was True""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_inherited_invariant_on_deleter(self) -> None: @icontract.invariant(lambda self: not self.toggled) @@ -423,11 +475,15 @@ def __repr__(self) -> str: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ not self.toggled: self was an instance of SomeClass - self.toggled was True"""), tests.error.wo_mandatory_location(str(violation_error))) + self.toggled was True""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/test_inheritance_postcondition.py b/tests/test_inheritance_postcondition.py index 6ae2316..14541f8 100644 --- a/tests/test_inheritance_postcondition.py +++ b/tests/test_inheritance_postcondition.py @@ -46,10 +46,14 @@ def some_func(self) -> None: return inst = SomeClass() - self.assertEqual(1, Increment.count) # Invariant needs to be checked once after the initialization. + self.assertEqual( + 1, Increment.count + ) # Invariant needs to be checked once after the initialization. inst.some_func() - self.assertEqual(3, Increment.count) # Invariant needs to be checked before and after some_func. + self.assertEqual( + 3, Increment.count + ) # Invariant needs to be checked before and after some_func. def test_count_checks_in_slot_wrappers(self) -> None: class Increment: @@ -66,10 +70,14 @@ class SomeClass: pass inst = SomeClass() - self.assertEqual(1, Increment.count) # Invariant needs to be checked once after the initialization. + self.assertEqual( + 1, Increment.count + ) # Invariant needs to be checked once after the initialization. _ = str(inst) - self.assertEqual(3, Increment.count) # Invariant needs to be checked before and after __str__. + self.assertEqual( + 3, Increment.count + ) # Invariant needs to be checked before and after __str__. class TestViolation(unittest.TestCase): @@ -92,10 +100,14 @@ def __repr__(self) -> str: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ result < 100: result was 1000 - self was an instance of B"""), tests.error.wo_mandatory_location(str(violation_error))) + self was an instance of B""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_inherited_with_modified_implementation(self) -> None: class A(icontract.DBC): @@ -119,10 +131,14 @@ def __repr__(self) -> str: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ result < 100: result was 10000 - self was an instance of B"""), tests.error.wo_mandatory_location(str(violation_error))) + self was an instance of B""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_ensure_then_violated_in_base(self) -> None: class A(icontract.DBC): @@ -148,10 +164,14 @@ def __repr__(self) -> str: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ result % 2 == 0: result was 3 - self was an instance of B"""), tests.error.wo_mandatory_location(str(violation_error))) + self was an instance of B""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_ensure_then_violated_in_child(self) -> None: class A(icontract.DBC): @@ -177,10 +197,14 @@ def __repr__(self) -> str: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ result % 3 == 0: result was 2 - self was an instance of B"""), tests.error.wo_mandatory_location(str(violation_error))) + self was an instance of B""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_abstract_method(self) -> None: class A(icontract.DBC): @@ -208,10 +232,14 @@ def __repr__(self) -> str: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ result < 100: result was 1000 - self was an instance of B"""), tests.error.wo_mandatory_location(str(violation_error))) + self was an instance of B""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_that_base_postconditions_apply_to_init_if_not_defined(self) -> None: class A(icontract.DBC): @@ -234,12 +262,16 @@ def __repr__(self) -> str: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ self.x >= 0: result was None self was an instance of B self.x was -1 - x was -1"""), tests.error.wo_mandatory_location(str(violation_error))) + x was -1""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_that_base_postconditions_dont_apply_to_init_if_overridden(self) -> None: class A(icontract.DBC): @@ -267,12 +299,16 @@ def __repr__(self) -> str: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ self.x < 0: result was None self was an instance of B self.x was 0 - x was 0"""), tests.error.wo_mandatory_location(str(violation_error))) + x was 0""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) class TestPropertyOK(unittest.TestCase): @@ -337,11 +373,15 @@ def __repr__(self) -> str: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ not self.toggled: result was 0 self was an instance of SomeClass - self.toggled was True"""), tests.error.wo_mandatory_location(str(violation_error))) + self.toggled was True""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_setter(self) -> None: class SomeBase(icontract.DBC): @@ -375,12 +415,16 @@ def __repr__(self) -> str: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ not self.toggled: result was None self was an instance of SomeClass self.toggled was True - value was 0"""), tests.error.wo_mandatory_location(str(violation_error))) + value was 0""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_deleter(self) -> None: class SomeBase(icontract.DBC): @@ -414,11 +458,15 @@ def some_prop(self) -> None: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ not self.toggled: result was None self was an instance of SomeClass - self.toggled was True"""), tests.error.wo_mandatory_location(str(violation_error))) + self.toggled was True""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_setter_strengthened(self) -> None: class SomeBase(icontract.DBC): @@ -452,12 +500,16 @@ def __repr__(self) -> str: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ not self.toggled: result was None self was an instance of SomeClass self.toggled was True - value was 0"""), tests.error.wo_mandatory_location(str(violation_error))) + value was 0""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) class TestInvalid(unittest.TestCase): @@ -480,10 +532,16 @@ class B(A): self.assertIsNotNone(type_err) if sys.version_info < (3, 9): - self.assertEqual("Can't instantiate abstract class B with abstract methods func", str(type_err)) + self.assertEqual( + "Can't instantiate abstract class B with abstract methods func", + str(type_err), + ) else: - self.assertEqual("Can't instantiate abstract class B with abstract method func", str(type_err)) + self.assertEqual( + "Can't instantiate abstract class B with abstract method func", + str(type_err), + ) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/test_inheritance_precondition.py b/tests/test_inheritance_precondition.py index 4f8c459..55c81ab 100644 --- a/tests/test_inheritance_precondition.py +++ b/tests/test_inheritance_precondition.py @@ -68,10 +68,14 @@ def __repr__(self) -> str: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ x < 100: self was an instance of B - x was 1000"""), tests.error.wo_mandatory_location(str(violation_error))) + x was 1000""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_inherited_with_implementation(self) -> None: class A(icontract.DBC): @@ -95,10 +99,14 @@ def __repr__(self) -> str: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ x < 100: self was an instance of B - x was 1000"""), tests.error.wo_mandatory_location(str(violation_error))) + x was 1000""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_require_else(self) -> None: class A(icontract.DBC): @@ -124,10 +132,14 @@ def __repr__(self) -> str: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ x % 3 == 0: self was an instance of B - x was 5"""), tests.error.wo_mandatory_location(str(violation_error))) + x was 5""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_triple_inheritance_wo_implementation(self) -> None: class A(icontract.DBC): @@ -152,10 +164,14 @@ def __repr__(self) -> str: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ x < 100: self was an instance of C - x was 1000"""), tests.error.wo_mandatory_location(str(violation_error))) + x was 1000""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_triple_inheritance_with_implementation(self) -> None: class A(icontract.DBC): @@ -182,10 +198,14 @@ def __repr__(self) -> str: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ x < 100: self was an instance of C - x was 1000"""), tests.error.wo_mandatory_location(str(violation_error))) + x was 1000""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_triple_inheritance_with_require_else(self) -> None: class A(icontract.DBC): @@ -216,10 +236,14 @@ def __repr__(self) -> str: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ x % 5 == 0: self was an instance of C - x was 7"""), tests.error.wo_mandatory_location(str(violation_error))) + x was 7""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_abstract_method(self) -> None: class A(icontract.DBC): @@ -244,10 +268,14 @@ def __repr__(self) -> str: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ x > 0: self was an instance of B - x was -1"""), tests.error.wo_mandatory_location(str(violation_error))) + x was -1""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_that_base_preconditions_apply_to_init_if_not_defined(self) -> None: class A(icontract.DBC): @@ -267,10 +295,14 @@ def __repr__(self) -> str: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ x >= 0: self was an instance of B - x was -1"""), tests.error.wo_mandatory_location(str(violation_error))) + x was -1""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_that_base_preconditions_dont_apply_to_init_if_overridden(self) -> None: class A(icontract.DBC): @@ -298,10 +330,14 @@ def __repr__(self) -> str: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ x < 0: self was an instance of B - x was 0"""), tests.error.wo_mandatory_location(str(violation_error))) + x was 0""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) class TestPropertyOK(unittest.TestCase): @@ -366,10 +402,14 @@ def __repr__(self) -> str: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ not self.toggled: self was an instance of SomeClass - self.toggled was True"""), tests.error.wo_mandatory_location(str(violation_error))) + self.toggled was True""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_setter(self) -> None: class SomeBase(icontract.DBC): @@ -400,10 +440,14 @@ def __repr__(self) -> str: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ value > 0: self was an instance of SomeClass - value was 0"""), tests.error.wo_mandatory_location(str(violation_error))) + value was 0""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_deleter(self) -> None: class SomeBase(icontract.DBC): @@ -437,10 +481,14 @@ def some_prop(self) -> None: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ not self.toggled: self was an instance of SomeClass - self.toggled was True"""), tests.error.wo_mandatory_location(str(violation_error))) + self.toggled was True""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) class TestConstructor(unittest.TestCase): @@ -470,14 +518,18 @@ def __repr__(self) -> str: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ x > 0: self was an instance of B - x was -1"""), tests.error.wo_mandatory_location(str(violation_error))) + x was -1""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_new_tightens_preconditions(self) -> None: class A(icontract.DBC): - def __new__(cls, xs: Sequence[int]) -> 'A': + def __new__(cls, xs: Sequence[int]) -> "A": return cast(A, xs) class B(A): @@ -485,7 +537,7 @@ class B(A): # __new__ is a special case: while other functions need to satisfy Liskov substitution principle, # __new__ is an exception. @icontract.require(lambda xs: all(x > 0 for x in xs)) - def __new__(cls, xs: Sequence[int]) -> 'B': + def __new__(cls, xs: Sequence[int]) -> "B": return cast(B, xs) def __repr__(self) -> str: @@ -501,11 +553,15 @@ def __repr__(self) -> str: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent('''\ + textwrap.dedent( + """\ all(x > 0 for x in xs): all(x > 0 for x in xs) was False, e.g., with x = -1 - xs was [-1, -2, -3]'''), tests.error.wo_mandatory_location(str(violation_error))) + xs was [-1, -2, -3]""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) class TestInvalid(unittest.TestCase): @@ -528,9 +584,15 @@ class B(A): self.assertIsNotNone(type_err) if sys.version_info < (3, 9): - self.assertEqual("Can't instantiate abstract class B with abstract methods func", str(type_err)) + self.assertEqual( + "Can't instantiate abstract class B with abstract methods func", + str(type_err), + ) else: - self.assertEqual("Can't instantiate abstract class B with abstract method func", str(type_err)) + self.assertEqual( + "Can't instantiate abstract class B with abstract method func", + str(type_err), + ) def test_cant_weaken_base_function_without_preconditions(self) -> None: class A(icontract.DBC): @@ -545,6 +607,7 @@ class B(A): # pylint: disable=unused-variable @icontract.require(lambda x: x < 0) def func(self, x: int) -> int: return 1000 + except TypeError as err: type_error = err @@ -554,8 +617,10 @@ def func(self, x: int) -> int: "TestInvalid.test_cant_weaken_base_function_without_preconditions..B.func can not " "weaken the preconditions because the bases specify no preconditions at all. Hence this function must " "accept all possible input since the preconditions are OR'ed and no precondition implies a dummy " - "precondition which is always fulfilled.", str(type_error)) + "precondition which is always fulfilled.", + str(type_error), + ) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/test_inheritance_snapshot.py b/tests/test_inheritance_snapshot.py index a656afb..666e5a8 100644 --- a/tests/test_inheritance_snapshot.py +++ b/tests/test_inheritance_snapshot.py @@ -53,14 +53,18 @@ def __repr__(self) -> str: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ OLD.lst + [val] == self.lst: OLD was a bunch of OLD values OLD.lst was [] result was None self was an instance of B self.lst was [2, 1984] - val was 2"""), tests.error.wo_mandatory_location(str(violation_error))) + val was 2""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_with_inherited_snapshot(self) -> None: class A(icontract.DBC): @@ -91,7 +95,8 @@ def __repr__(self) -> str: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ OLD.len_lst + 1 == len(self.lst): OLD was a bunch of OLD values OLD.len_lst was 0 @@ -99,7 +104,10 @@ def __repr__(self) -> str: result was None self was an instance of B self.lst was [2, 1984] - val was 2"""), tests.error.wo_mandatory_location(str(violation_error))) + val was 2""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) class TestPropertyOK(unittest.TestCase): @@ -188,13 +196,17 @@ def __repr__(self) -> str: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ self.gets == OLD.gets + 1: OLD was a bunch of OLD values OLD.gets was 0 result was 0 self was an instance of SomeClass - self.gets was 0"""), tests.error.wo_mandatory_location(str(violation_error))) + self.gets was 0""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) # setter fails violation_error = None @@ -205,14 +217,18 @@ def __repr__(self) -> str: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ self.sets == OLD.sets + 1: OLD was a bunch of OLD values OLD.sets was 0 result was None self was an instance of SomeClass self.sets was 0 - value was 1"""), tests.error.wo_mandatory_location(str(violation_error))) + value was 1""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) # deleter fails violation_error = None @@ -223,13 +239,17 @@ def __repr__(self) -> str: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ self.dels == OLD.dels + 1: OLD was a bunch of OLD values OLD.dels was 0 result was None self was an instance of SomeClass - self.dels was 0"""), tests.error.wo_mandatory_location(str(violation_error))) + self.dels was 0""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) class TestInvalid(unittest.TestCase): @@ -259,9 +279,12 @@ def some_func(self, val: int) -> None: value_error = err self.assertIsNotNone(value_error) - self.assertEqual("There are conflicting snapshots with the name: 'len_lst'.\n\n" - "Please mind that the snapshots are inherited from the base classes. " - "Does one of the base classes defines a snapshot with the same name?", str(value_error)) + self.assertEqual( + "There are conflicting snapshots with the name: 'len_lst'.\n\n" + "Please mind that the snapshots are inherited from the base classes. " + "Does one of the base classes defines a snapshot with the same name?", + str(value_error), + ) class TestPropertyInvalid(unittest.TestCase): @@ -292,9 +315,12 @@ def some_prop(self) -> int: value_error = err self.assertIsNotNone(value_error) - self.assertEqual("There are conflicting snapshots with the name: 'gets'.\n\n" - "Please mind that the snapshots are inherited from the base classes. " - "Does one of the base classes defines a snapshot with the same name?", str(value_error)) + self.assertEqual( + "There are conflicting snapshots with the name: 'gets'.\n\n" + "Please mind that the snapshots are inherited from the base classes. " + "Does one of the base classes defines a snapshot with the same name?", + str(value_error), + ) def test_setter_with_conflicting_snapshot_names(self) -> None: class SomeBase(icontract.DBC): @@ -332,9 +358,12 @@ def some_prop(self, value: int) -> None: value_error = err self.assertIsNotNone(value_error) - self.assertEqual("There are conflicting snapshots with the name: 'sets'.\n\n" - "Please mind that the snapshots are inherited from the base classes. " - "Does one of the base classes defines a snapshot with the same name?", str(value_error)) + self.assertEqual( + "There are conflicting snapshots with the name: 'sets'.\n\n" + "Please mind that the snapshots are inherited from the base classes. " + "Does one of the base classes defines a snapshot with the same name?", + str(value_error), + ) def test_deleter_with_conflicting_snapshot_names(self) -> None: class SomeBase(icontract.DBC): @@ -370,10 +399,13 @@ def some_prop(self) -> None: value_error = err self.assertIsNotNone(value_error) - self.assertEqual("There are conflicting snapshots with the name: 'dels'.\n\n" - "Please mind that the snapshots are inherited from the base classes. " - "Does one of the base classes defines a snapshot with the same name?", str(value_error)) + self.assertEqual( + "There are conflicting snapshots with the name: 'dels'.\n\n" + "Please mind that the snapshots are inherited from the base classes. " + "Does one of the base classes defines a snapshot with the same name?", + str(value_error), + ) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/test_invariant.py b/tests/test_invariant.py index 14729b9..03e4f61 100644 --- a/tests/test_invariant.py +++ b/tests/test_invariant.py @@ -4,7 +4,13 @@ import textwrap import time import unittest -from typing import Dict, Iterator, Mapping, Optional, Any, NamedTuple # pylint: disable=unused-import +from typing import ( + Dict, + Iterator, + Mapping, + Optional, + Any, +) # pylint: disable=unused-import import icontract import tests.error @@ -218,7 +224,7 @@ def test_new_exempted(self) -> None: @icontract.invariant(lambda self: True) class Foo: - def __new__(cls, *args, **kwargs) -> 'Foo': # type: ignore + def __new__(cls, *args, **kwargs) -> "Foo": # type: ignore nonlocal new_call_counter new_call_counter += 1 return super(Foo, cls).__new__(cls) @@ -255,10 +261,10 @@ def __len__(self) -> int: return len(self._table) def __str__(self) -> str: - return '{}({})'.format(self.__class__.__name__, self._table) + return "{}({})".format(self.__class__.__name__, self._table) - f = Foo({'a': 1}) # test the constructor - _ = f['a'] # test __getitem__ + f = Foo({"a": 1}) # test the constructor + _ = f["a"] # test __getitem__ _ = iter(f) # test __iter__ _ = len(f) # test __len__ _ = str(f) # test __str__ @@ -289,10 +295,14 @@ def __repr__(self) -> str: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ self.x > 0: self was an instance of SomeClass - self.x was 0"""), tests.error.wo_mandatory_location(str(violation_error))) + self.x was 0""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_inv_as_precondition(self) -> None: @icontract.invariant(lambda self: self.x > 0) @@ -316,10 +326,14 @@ def __repr__(self) -> str: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ self.x > 0: self was an instance of SomeClass - self.x was -1"""), tests.error.wo_mandatory_location(str(violation_error))) + self.x was -1""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_method(self) -> None: @icontract.invariant(lambda self: self.x > 0) @@ -342,10 +356,14 @@ def __repr__(self) -> str: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ self.x > 0: self was an instance of SomeClass - self.x was -1"""), tests.error.wo_mandatory_location(str(violation_error))) + self.x was -1""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_magic_method(self) -> None: @icontract.invariant(lambda self: self.x > 0) @@ -368,10 +386,14 @@ def __repr__(self) -> str: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ self.x > 0: self was an instance of SomeClass - self.x was -1"""), tests.error.wo_mandatory_location(str(violation_error))) + self.x was -1""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_multiple_invs_first_violated(self) -> None: @icontract.invariant(lambda self: self.x > 0) @@ -391,10 +413,14 @@ def __repr__(self) -> str: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ self.x > 0: self was an instance of SomeClass - self.x was -1"""), tests.error.wo_mandatory_location(str(violation_error))) + self.x was -1""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_multiple_invs_last_violated(self) -> None: @icontract.invariant(lambda self: self.x > 0) @@ -414,10 +440,14 @@ def __repr__(self) -> str: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ self.x < 10: self was an instance of SomeClass - self.x was 100"""), tests.error.wo_mandatory_location(str(violation_error))) + self.x was 100""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_inv_violated_after_pre(self) -> None: @icontract.invariant(lambda self: self.x > 0) @@ -441,10 +471,14 @@ def __repr__(self) -> str: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ y > 0: self was an instance of SomeClass - y was -1"""), tests.error.wo_mandatory_location(str(violation_error))) + y was -1""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) violation_error = None try: @@ -455,10 +489,14 @@ def __repr__(self) -> str: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ self.x > 0: self was an instance of SomeClass - self.x was -1"""), tests.error.wo_mandatory_location(str(violation_error))) + self.x was -1""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_inv_ok_but_post_violated(self) -> None: @icontract.invariant(lambda self: self.x > 0) @@ -483,10 +521,14 @@ def __repr__(self) -> str: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ result > 0: result was -1 - self was an instance of SomeClass"""), tests.error.wo_mandatory_location(str(violation_error))) + self was an instance of SomeClass""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_inv_violated_but_post_ok(self) -> None: @icontract.invariant(lambda self: self.x > 0) @@ -511,10 +553,14 @@ def __repr__(self) -> str: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ self.x > 0: self was an instance of SomeClass - self.x was -1"""), tests.error.wo_mandatory_location(str(violation_error))) + self.x was -1""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_inv_with_empty_arguments(self) -> None: z = 42 @@ -532,13 +578,17 @@ def __repr__(self) -> str: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ z != 42: self was an instance of A - z was 42"""), tests.error.wo_mandatory_location(str(violation_error))) + z was 42""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_condition_as_function(self) -> None: - def some_condition(self: 'A') -> bool: + def some_condition(self: "A") -> bool: return self.x > 0 @icontract.invariant(some_condition) @@ -563,10 +613,13 @@ def __repr__(self) -> str: violation_error = err self.assertIsNotNone(violation_error) - self.assertEqual('some_condition: self was A(x=-1)', tests.error.wo_mandatory_location(str(violation_error))) + self.assertEqual( + "some_condition: self was A(x=-1)", + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_condition_as_function_with_default_argument_value(self) -> None: - def some_condition(self: 'A', y: int = 0) -> bool: + def some_condition(self: "A", y: int = 0) -> bool: return self.x > y @icontract.invariant(some_condition) @@ -591,7 +644,10 @@ def __repr__(self) -> str: violation_error = err self.assertIsNotNone(violation_error) - self.assertEqual('some_condition: self was A(x=-1)', tests.error.wo_mandatory_location(str(violation_error))) + self.assertEqual( + "some_condition: self was A(x=-1)", + tests.error.wo_mandatory_location(str(violation_error)), + ) class TestProperty(unittest.TestCase): @@ -619,10 +675,14 @@ def __repr__(self) -> str: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ not self.toggled: self was an instance of SomeClass - self.toggled was True"""), tests.error.wo_mandatory_location(str(violation_error))) + self.toggled was True""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_property_setter(self) -> None: @icontract.invariant(lambda self: not self.toggled) @@ -651,10 +711,14 @@ def __repr__(self) -> str: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ not self.toggled: self was an instance of SomeClass - self.toggled was True"""), tests.error.wo_mandatory_location(str(violation_error))) + self.toggled was True""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_property_deleter(self) -> None: @icontract.invariant(lambda self: not self.toggled) @@ -683,10 +747,14 @@ def __repr__(self) -> str: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ not self.toggled: self was an instance of SomeClass - self.toggled was True"""), tests.error.wo_mandatory_location(str(violation_error))) + self.toggled was True""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) class TestError(unittest.TestCase): @@ -708,14 +776,22 @@ def __repr__(self) -> str: self.assertIsNotNone(value_error) self.assertIsInstance(value_error, ValueError) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ self.x > 0: self was an instance of A - self.x was 0"""), tests.error.wo_mandatory_location(str(value_error))) + self.x was 0""" + ), + tests.error.wo_mandatory_location(str(value_error)), + ) def test_as_function(self) -> None: @icontract.invariant( - lambda self: self.x > 0, error=lambda self: ValueError("x must be positive, but got: {}".format(self.x))) + lambda self: self.x > 0, + error=lambda self: ValueError( + "x must be positive, but got: {}".format(self.x) + ), + ) class A: def __init__(self) -> None: self.x = 0 @@ -731,10 +807,12 @@ def __repr__(self) -> str: self.assertIsNotNone(value_error) self.assertIsInstance(value_error, ValueError) - self.assertEqual('x must be positive, but got: 0', str(value_error)) + self.assertEqual("x must be positive, but got: 0", str(value_error)) def test_as_function_with_empty_args(self) -> None: - @icontract.invariant(lambda self: self.x > 0, error=lambda: ValueError("x must be positive")) + @icontract.invariant( + lambda self: self.x > 0, error=lambda: ValueError("x must be positive") + ) class A: def __init__(self) -> None: self.x = 0 @@ -750,7 +828,7 @@ def __repr__(self) -> str: self.assertIsNotNone(value_error) self.assertIsInstance(value_error, ValueError) - self.assertEqual('x must be positive', str(value_error)) + self.assertEqual("x must be positive", str(value_error)) class TestToggling(unittest.TestCase): @@ -765,7 +843,9 @@ def __init__(self) -> None: class TestBenchmark(unittest.TestCase): - @unittest.skip("Skipped the benchmark, execute manually on a prepared benchmark machine.") + @unittest.skip( + "Skipped the benchmark, execute manually on a prepared benchmark machine." + ) def test_benchmark_when_disabled(self) -> None: def some_long_condition() -> bool: time.sleep(5) @@ -805,8 +885,10 @@ def __init__(self) -> None: val_err = err self.assertIsNotNone(val_err) - self.assertEqual("Expected an invariant condition with at most an argument 'self', but got: ['self', 'z']", - str(val_err)) + self.assertEqual( + "Expected an invariant condition with at most an argument 'self', but got: ['self', 'z']", + str(val_err), + ) def test_no_boolyness(self) -> None: @icontract.invariant(lambda self: tests.mock.NumpyArray([True, False])) @@ -821,9 +903,11 @@ def __init__(self) -> None: value_error = err self.assertIsNotNone(value_error) - self.assertEqual('Failed to negate the evaluation of the condition.', - tests.error.wo_mandatory_location(str(value_error))) + self.assertEqual( + "Failed to negate the evaluation of the condition.", + tests.error.wo_mandatory_location(str(value_error)), + ) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/test_mypy_decorators.py b/tests/test_mypy_decorators.py index f97c419..d135d61 100644 --- a/tests/test_mypy_decorators.py +++ b/tests/test_mypy_decorators.py @@ -6,8 +6,12 @@ import unittest if sys.version_info < (3, 8): - raise unittest.SkipTest(("Mypy since 1.5 dropped support for Python 3.7 and " - "you are running Python {}, so skipping.").format(sys.version_info)) + raise unittest.SkipTest( + ( + "Mypy since 1.5 dropped support for Python 3.7 and " + "you are running Python {}, so skipping." + ).format(sys.version_info) + ) class TestMypyDecorators(unittest.TestCase): @@ -35,21 +39,27 @@ def f3(x: int) -> int: ''' pth = pathlib.Path(tmpdir) / "source.py" - pth.write_text(content, encoding='utf-8') + pth.write_text(content, encoding="utf-8") with subprocess.Popen( - ['mypy', '--strict', str(pth)], universal_newlines=True, stdout=subprocess.PIPE) as proc: + ["mypy", "--strict", str(pth)], + universal_newlines=True, + stdout=subprocess.PIPE, + ) as proc: out, err = proc.communicate() self.assertIsNone(err) self.assertEqual( - '''\ + """\ {path}:8: error: Argument 1 to "f1" has incompatible type "str"; expected "int" [arg-type] {path}:13: error: Argument 1 to "f2" has incompatible type "str"; expected "int" [arg-type] {path}:18: error: Argument 1 to "f3" has incompatible type "str"; expected "int" [arg-type] Found 3 errors in 1 file (checked 1 source file) -'''.format(path=pth), - out) +""".format( + path=pth + ), + out, + ) def test_class_type_when_decorated_with_invariant(self) -> None: with tempfile.TemporaryDirectory(prefix="mypy_fail_case_") as tmpdir: @@ -72,20 +82,26 @@ def __init__(self) -> None: ''' pth = pathlib.Path(tmpdir) / "source.py" - pth.write_text(content, encoding='utf-8') + pth.write_text(content, encoding="utf-8") with subprocess.Popen( - ['mypy', '--strict', str(pth)], universal_newlines=True, stdout=subprocess.PIPE) as proc: + ["mypy", "--strict", str(pth)], + universal_newlines=True, + stdout=subprocess.PIPE, + ) as proc: out, err = proc.communicate() self.assertIsNone(err) self.assertEqual( - '''\ + """\ {path}:8: note: Revealed type is "def () -> source.SomeClass" {path}:15: note: Revealed type is "def () -> source.Decorated" Success: no issues found in 1 source file -'''.format(path=pth), - out) +""".format( + path=pth + ), + out, + ) def test_that_mypy_complains_when_decorating_non_type_with_invariant(self) -> None: with tempfile.TemporaryDirectory(prefix="mypy_fail_case_") as tmpdir: @@ -100,20 +116,26 @@ def some_func() -> None: ''' pth = pathlib.Path(tmpdir) / "source.py" - pth.write_text(content, encoding='utf-8') + pth.write_text(content, encoding="utf-8") with subprocess.Popen( - ['mypy', '--strict', str(pth)], universal_newlines=True, stdout=subprocess.PIPE) as proc: + ["mypy", "--strict", str(pth)], + universal_newlines=True, + stdout=subprocess.PIPE, + ) as proc: out, err = proc.communicate() self.assertIsNone(err) self.assertEqual( - '''\ + """\ {path}:5: error: Value of type variable "ClassT" of "__call__" of "invariant" cannot be "Callable[[], None]" [type-var] Found 1 error in 1 file (checked 1 source file) -'''.format(path=pth), - out) +""".format( + path=pth + ), + out, + ) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/test_postcondition.py b/tests/test_postcondition.py index b87e0ed..e5166d6 100644 --- a/tests/test_postcondition.py +++ b/tests/test_postcondition.py @@ -40,11 +40,15 @@ def some_func(x: int, y: int = 5) -> int: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ result > x: result was -4 x was 1 - y was 5"""), tests.error.wo_mandatory_location(str(violation_error))) + y was 5""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_condition_as_function(self) -> None: def some_condition(result: int) -> bool: @@ -66,10 +70,14 @@ def some_func(x: int) -> int: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ some_condition: result was 1 - x was 1"""), tests.error.wo_mandatory_location(str(violation_error))) + x was 1""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_condition_as_function_with_default_argument_value(self) -> None: def some_condition(result: int, y: int = 0) -> bool: @@ -91,10 +99,14 @@ def some_func(x: int) -> int: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ some_condition: result was -1 - x was -1"""), tests.error.wo_mandatory_location(str(violation_error))) + x was -1""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_condition_as_function_with_default_argument_value_set(self) -> None: def some_condition(result: int, y: int = 0) -> bool: @@ -116,11 +128,15 @@ def some_func(x: int, y: int) -> int: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ some_condition: result was 1 x was 1 - y was 3"""), tests.error.wo_mandatory_location(str(violation_error))) + y was 3""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_with_description(self) -> None: @icontract.ensure(lambda result, x: result > x, "expected summation") @@ -135,11 +151,15 @@ def some_func(x: int, y: int = 5) -> int: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ expected summation: result > x: result was -4 x was 1 - y was 5"""), tests.error.wo_mandatory_location(str(violation_error))) + y was 5""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_with_stacked_decorators(self) -> None: def mydecorator(f: CallableT) -> CallableT: @@ -168,12 +188,16 @@ def some_func(x: int, y: int) -> int: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ y > result + another_var: another_var was 2 result was 100 x was 0 - y was 10"""), tests.error.wo_mandatory_location(str(violation_error))) + y was 10""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_with_default_values_outer(self) -> None: @icontract.ensure(lambda result, c: result % c == 0) @@ -190,12 +214,16 @@ def some_func(a: int, b: int = 21, c: int = 2) -> int: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ result % c == 0: a was 13 b was 21 c was 2 - result was 13"""), tests.error.wo_mandatory_location(str(violation_error))) + result was 13""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) # Check the inner post condition violation_error = None @@ -206,12 +234,16 @@ def some_func(a: int, b: int = 21, c: int = 2) -> int: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ result < b: a was 36 b was 21 c was 2 - result was 36"""), tests.error.wo_mandatory_location(str(violation_error))) + result was 36""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_only_result(self) -> None: @icontract.ensure(lambda result: result > 3) @@ -226,10 +258,14 @@ def some_func(x: int) -> int: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ result > 3: result was 0 - x was 10000"""), tests.error.wo_mandatory_location(str(violation_error))) + x was 10000""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) class TestError(unittest.TestCase): @@ -247,13 +283,20 @@ def some_func(x: int) -> int: self.assertIsNotNone(value_error) self.assertIsInstance(value_error, ValueError) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ result > 0: result was 0 - x was 0"""), tests.error.wo_mandatory_location(str(value_error))) + x was 0""" + ), + tests.error.wo_mandatory_location(str(value_error)), + ) def test_as_function(self) -> None: - @icontract.ensure(lambda result: result > 0, error=lambda result: ValueError("result must be positive.")) + @icontract.ensure( + lambda result: result > 0, + error=lambda result: ValueError("result must be positive."), + ) def some_func(x: int) -> int: return x @@ -265,10 +308,13 @@ def some_func(x: int) -> int: self.assertIsNotNone(value_error) self.assertIsInstance(value_error, ValueError) - self.assertEqual('result must be positive.', str(value_error)) + self.assertEqual("result must be positive.", str(value_error)) def test_with_empty_args(self) -> None: - @icontract.ensure(lambda result: result > 0, error=lambda: ValueError("result must be positive")) + @icontract.ensure( + lambda result: result > 0, + error=lambda: ValueError("result must be positive"), + ) def some_func(x: int) -> int: return x @@ -280,11 +326,15 @@ def some_func(x: int) -> int: self.assertIsNotNone(value_error) self.assertIsInstance(value_error, ValueError) - self.assertEqual('result must be positive', str(value_error)) + self.assertEqual("result must be positive", str(value_error)) def test_with_different_args_from_condition(self) -> None: @icontract.ensure( - lambda result: result > 0, error=lambda x, result: ValueError("x is {}, result is {}".format(x, result))) + lambda result: result > 0, + error=lambda x, result: ValueError( + "x is {}, result is {}".format(x, result) + ), + ) def some_func(x: int) -> int: return x @@ -296,7 +346,7 @@ def some_func(x: int) -> int: self.assertIsNotNone(value_error) self.assertIsInstance(value_error, ValueError) - self.assertEqual('x is 0, result is 0', str(value_error)) + self.assertEqual("x is 0, result is 0", str(value_error)) class TestToggling(unittest.TestCase): @@ -334,16 +384,20 @@ def some_func(x: int) -> int: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ result != 0: result was 0 - x was 0"""), tests.error.wo_mandatory_location(str(violation_error))) + x was 0""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_postcondition_in_class_method(self) -> None: class SomeClass: @classmethod @icontract.ensure(lambda result: result != 0) - def some_func(cls: Type['SomeClass'], x: int) -> int: + def some_func(cls: Type["SomeClass"], x: int) -> int: return x result = SomeClass.some_func(x=1) @@ -357,10 +411,14 @@ def some_func(cls: Type['SomeClass'], x: int) -> int: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ result != 0: result was 0 - x was 0"""), tests.error.wo_mandatory_location(str(violation_error))) + x was 0""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_postcondition_in_abstract_static_method(self) -> None: class SomeAbstract(icontract.DBC): @@ -386,22 +444,26 @@ def some_func(x: int) -> int: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ result != 0: result was 0 - x was 0"""), tests.error.wo_mandatory_location(str(violation_error))) + x was 0""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_postcondition_in_abstract_class_method(self) -> None: class Abstract(icontract.DBC): @classmethod @abc.abstractmethod @icontract.ensure(lambda result: result != 0) - def some_func(cls: Type['Abstract'], x: int) -> int: + def some_func(cls: Type["Abstract"], x: int) -> int: pass class SomeClass(Abstract): @classmethod - def some_func(cls: Type['SomeClass'], x: int) -> int: + def some_func(cls: Type["SomeClass"], x: int) -> int: return x result = SomeClass.some_func(x=1) @@ -415,10 +477,14 @@ def some_func(cls: Type['SomeClass'], x: int) -> int: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ result != 0: result was 0 - x was 0"""), tests.error.wo_mandatory_location(str(violation_error))) + x was 0""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_getter(self) -> None: class SomeClass: @@ -443,10 +509,14 @@ def __repr__(self) -> str: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ result > 0: result was -1 - self was an instance of SomeClass"""), tests.error.wo_mandatory_location(str(violation_error))) + self was an instance of SomeClass""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_setter(self) -> None: class SomeClass: @@ -472,12 +542,16 @@ def __repr__(self) -> str: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ self.some_prop > 0: result was None self was an instance of SomeClass self.some_prop was 0 - value was -1"""), tests.error.wo_mandatory_location(str(violation_error))) + value was -1""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_deleter(self) -> None: class SomeClass: @@ -506,12 +580,16 @@ def __repr__(self) -> str: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ self.some_prop > 0: result was None self was an instance of SomeClass - self.some_prop was -1"""), tests.error.wo_mandatory_location(str(violation_error))) + self.some_prop was -1""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/test_precondition.py b/tests/test_precondition.py index d5e0d2b..a0f04f3 100644 --- a/tests/test_precondition.py +++ b/tests/test_precondition.py @@ -41,10 +41,14 @@ def some_func(x: int, y: int = 5) -> None: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ x > 3: x was 1 - y was 5"""), tests.error.wo_mandatory_location(str(violation_error))) + y was 5""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_with_description(self) -> None: @icontract.require(lambda x: x > 3, "x must not be small") @@ -59,10 +63,14 @@ def some_func(x: int, y: int = 5) -> None: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ x must not be small: x > 3: x was 1 - y was 5"""), tests.error.wo_mandatory_location(str(violation_error))) + y was 5""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_condition_as_function(self) -> None: def some_condition(x: int) -> bool: @@ -80,10 +88,14 @@ def some_func(x: int, y: int = 5) -> str: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ some_condition: x was 1 - y was 5"""), tests.error.wo_mandatory_location(str(violation_error))) + y was 5""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_condition_as_function_with_default_argument_value(self) -> None: def some_condition(x: int, y: int = 0) -> bool: @@ -104,7 +116,10 @@ def some_func(x: int) -> None: violation_error = err self.assertIsNotNone(violation_error) - self.assertEqual("some_condition: x was -1", tests.error.wo_mandatory_location(str(violation_error))) + self.assertEqual( + "some_condition: x was -1", + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_condition_as_function_with_default_argument_value_set(self) -> None: def some_condition(x: int, y: int = 0) -> bool: @@ -126,10 +141,14 @@ def some_func(x: int, y: int) -> None: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ some_condition: x was -1 - y was 1"""), tests.error.wo_mandatory_location(str(violation_error))) + y was 1""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_with_pathlib(self) -> None: @icontract.require(lambda path: path.exists()) @@ -143,15 +162,18 @@ def some_func(path: pathlib.Path) -> None: violation_error = err # This dummy path is necessary to obtain the class name. - dummy_path = pathlib.Path('/also/doesnt/exist') + dummy_path = pathlib.Path("/also/doesnt/exist") self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent('''\ + textwrap.dedent( + """\ path.exists(): path was {}('/doesnt/exist/test_contract') - path.exists() was False''').format(dummy_path.__class__.__name__), - tests.error.wo_mandatory_location(str(violation_error))) + path.exists() was False""" + ).format(dummy_path.__class__.__name__), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_with_multiple_comparators(self) -> None: @icontract.require(lambda x: 0 < x < 3) @@ -165,7 +187,10 @@ def some_func(x: int) -> str: violation_error = err self.assertIsNotNone(violation_error) - self.assertEqual("0 < x < 3: x was 10", tests.error.wo_mandatory_location(str(violation_error))) + self.assertEqual( + "0 < x < 3: x was 10", + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_with_stacked_decorators(self) -> None: def mydecorator(f: CallableT) -> CallableT: @@ -194,10 +219,14 @@ def some_func(x: int) -> str: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ x > another_var: another_var was 0 - x was 0"""), tests.error.wo_mandatory_location(str(violation_error))) + x was 0""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_with_default_values(self) -> None: @icontract.require(lambda a: a < 10) @@ -214,11 +243,15 @@ def some_func(a: int, b: int = 21, c: int = 22) -> int: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ c < 10: a was 2 b was 21 - c was 22"""), tests.error.wo_mandatory_location(str(violation_error))) + c was 22""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) violation_error = None try: @@ -228,15 +261,21 @@ def some_func(a: int, b: int = 21, c: int = 22) -> int: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ b < 10: a was 2 b was 21 - c was 8"""), tests.error.wo_mandatory_location(str(violation_error))) + c was 8""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) class TestBenchmark(unittest.TestCase): - @unittest.skip("Skipped the benchmark, execute manually on a prepared benchmark machine.") + @unittest.skip( + "Skipped the benchmark, execute manually on a prepared benchmark machine." + ) def test_enabled(self) -> None: @icontract.require(lambda x: x > 3) def pow_with_pre(x: int, y: int) -> int: @@ -264,7 +303,9 @@ def pow_wo_pre(x: int, y: int) -> int: self.assertLess(duration_with_pre / duration_wo_pre, 6) - @unittest.skip("Skipped the benchmark, execute manually on a prepared benchmark machine.") + @unittest.skip( + "Skipped the benchmark, execute manually on a prepared benchmark machine." + ) def test_disabled(self) -> None: @icontract.require(lambda x: x > 3, enabled=False) def pow_with_pre(x: int, y: int) -> int: @@ -304,10 +345,14 @@ def some_func(x: int) -> int: self.assertIsNotNone(value_error) self.assertIsInstance(value_error, ValueError) - self.assertEqual('x > 0: x was 0', tests.error.wo_mandatory_location(str(value_error))) + self.assertEqual( + "x > 0: x was 0", tests.error.wo_mandatory_location(str(value_error)) + ) def test_as_function(self) -> None: - @icontract.require(lambda x: x > 0, error=lambda x: ValueError("x non-negative")) + @icontract.require( + lambda x: x > 0, error=lambda x: ValueError("x non-negative") + ) def some_func(x: int) -> int: return 0 @@ -319,12 +364,15 @@ def some_func(x: int) -> int: self.assertIsNotNone(value_error) self.assertIsInstance(value_error, ValueError) - self.assertEqual('x non-negative', str(value_error)) + self.assertEqual("x non-negative", str(value_error)) def test_as_function_with_outer_scope(self) -> None: z = 42 - @icontract.require(lambda x: x > 0, error=lambda x: ValueError("x non-negative, z: {}".format(z))) + @icontract.require( + lambda x: x > 0, + error=lambda x: ValueError("x non-negative, z: {}".format(z)), + ) def some_func(x: int) -> int: return 0 @@ -336,10 +384,12 @@ def some_func(x: int) -> int: self.assertIsNotNone(value_error) self.assertIsInstance(value_error, ValueError) - self.assertEqual('x non-negative, z: 42', str(value_error)) + self.assertEqual("x non-negative, z: 42", str(value_error)) def test_with_empty_args(self) -> None: - @icontract.require(lambda x: x > 0, error=lambda: ValueError("x must be positive")) + @icontract.require( + lambda x: x > 0, error=lambda: ValueError("x must be positive") + ) def some_func(x: int) -> int: return 0 @@ -351,10 +401,13 @@ def some_func(x: int) -> int: self.assertIsNotNone(value_error) self.assertIsInstance(value_error, ValueError) - self.assertEqual('x must be positive', str(value_error)) + self.assertEqual("x must be positive", str(value_error)) def test_with_different_args_from_condition(self) -> None: - @icontract.require(lambda x: x > 0, error=lambda x, y: ValueError("x is {}, y is {}".format(x, y))) + @icontract.require( + lambda x: x > 0, + error=lambda x, y: ValueError("x is {}, y is {}".format(x, y)), + ) def some_func(x: int, y: int) -> int: return 0 @@ -366,7 +419,7 @@ def some_func(x: int, y: int) -> int: self.assertIsNotNone(value_error) self.assertIsInstance(value_error, ValueError) - self.assertEqual('x is 0, y is 10', str(value_error)) + self.assertEqual("x is 0, y is 10", str(value_error)) class TestToggling(unittest.TestCase): @@ -413,10 +466,14 @@ def __repr__(self) -> str: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ x > 3: self was an instance of A - x was 1"""), tests.error.wo_mandatory_location(str(violation_error))) + x was 1""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) # Test method with self violation_error = None @@ -427,10 +484,14 @@ def __repr__(self) -> str: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ self.y > 10: self was an instance of A - self.y was 5"""), tests.error.wo_mandatory_location(str(violation_error))) + self.y was 5""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_unbound_instance_method_with_self_as_kwarg(self) -> None: class A: @@ -456,10 +517,14 @@ def __repr__(self) -> str: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ self.y > 10: self was an instance of A - self.y was 5"""), tests.error.wo_mandatory_location(str(violation_error))) + self.y was 5""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_getter(self) -> None: class SomeClass: @@ -484,10 +549,14 @@ def __repr__(self) -> str: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ self._some_prop > 0: self was an instance of SomeClass - self._some_prop was -1"""), tests.error.wo_mandatory_location(str(violation_error))) + self._some_prop was -1""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_setter(self) -> None: class SomeClass: @@ -513,10 +582,14 @@ def __repr__(self) -> str: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ value > 0: self was an instance of SomeClass - value was -1"""), tests.error.wo_mandatory_location(str(violation_error))) + value was -1""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_deleter(self) -> None: class SomeClass: @@ -545,10 +618,14 @@ def __repr__(self) -> str: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ self.some_prop > 0: self was an instance of SomeClass - self.some_prop was -1"""), tests.error.wo_mandatory_location(str(violation_error))) + self.some_prop was -1""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) class TestInvalid(unittest.TestCase): @@ -564,12 +641,17 @@ def some_function(a: int) -> None: # pylint: disable=unused-variable type_err = err self.assertIsNotNone(type_err) - self.assertEqual("The argument(s) of the contract condition have not been set: ['b']. " - "Does the original function define them? Did you supply them in the call?", - tests.error.wo_mandatory_location(str(type_err))) + self.assertEqual( + "The argument(s) of the contract condition have not been set: ['b']. " + "Does the original function define them? Did you supply them in the call?", + tests.error.wo_mandatory_location(str(type_err)), + ) def test_error_with_invalid_arguments(self) -> None: - @icontract.require(lambda x: x > 0, error=lambda x, z: ValueError("x is {}, y is {}".format(x, z))) + @icontract.require( + lambda x: x > 0, + error=lambda x, z: ValueError("x is {}, y is {}".format(x, z)), + ) def some_func(x: int, y: int) -> int: return 0 @@ -580,9 +662,11 @@ def some_func(x: int, y: int) -> int: type_error = err self.assertIsNotNone(type_error) - self.assertEqual("The argument(s) of the contract error have not been set: ['z']. " - "Does the original function define them? Did you supply them in the call?", - tests.error.wo_mandatory_location(str(type_error))) + self.assertEqual( + "The argument(s) of the contract error have not been set: ['z']. " + "Does the original function define them? Did you supply them in the call?", + tests.error.wo_mandatory_location(str(type_error)), + ) def test_no_boolyness(self) -> None: @icontract.require(lambda: tests.mock.NumpyArray([True, False])) @@ -596,8 +680,10 @@ def some_func() -> None: value_error = err self.assertIsNotNone(value_error) - self.assertEqual('Failed to negate the evaluation of the condition.', - tests.error.wo_mandatory_location(str(value_error))) + self.assertEqual( + "Failed to negate the evaluation of the condition.", + tests.error.wo_mandatory_location(str(value_error)), + ) def test_unexpected_positional_argument(self) -> None: @icontract.require(lambda: True) @@ -612,7 +698,9 @@ def some_func() -> None: self.assertIsNotNone(type_error) self.assertRegex( - str(type_error), r'^([a-zA-Z_0-9<>.]+\.)?some_func\(\) takes 0 positional arguments but 1 was given$') + str(type_error), + r"^([a-zA-Z_0-9<>.]+\.)?some_func\(\) takes 0 positional arguments but 1 was given$", + ) def test_unexpected_keyword_argument(self) -> None: @icontract.require(lambda: True) @@ -628,8 +716,10 @@ def some_func() -> None: self.assertIsNotNone(type_error) self.assertRegex( - str(type_error), r"^([a-zA-Z_0-9<>.]+\.)?some_func\(\) got an unexpected keyword argument 'x'$") + str(type_error), + r"^([a-zA-Z_0-9<>.]+\.)?some_func\(\) got an unexpected keyword argument 'x'$", + ) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/test_recompute.py b/tests/test_recompute.py index f02e057..7e6fc23 100644 --- a/tests/test_recompute.py +++ b/tests/test_recompute.py @@ -38,31 +38,41 @@ def translate_all_expression(input_source_code: str) -> str: module_node = icontract._recompute._translate_all_expression_to_a_module( generator_exp=generator_exp, - generated_function_name='some_func', - name_to_value={'SOME_GLOBAL_CONSTANT': 10}) + generated_function_name="some_func", + name_to_value={"SOME_GLOBAL_CONSTANT": 10}, + ) got = astor.to_source(module_node) # We need to replace the UUID of the result variable for reproducibility. - got = re.sub(r'icontract_tracing_all_result_[a-zA-Z0-9]+', 'icontract_tracing_all_result', got) + got = re.sub( + r"icontract_tracing_all_result_[a-zA-Z0-9]+", + "icontract_tracing_all_result", + got, + ) assert isinstance(got, str) return got def test_global_variable(self) -> None: - input_source_code = textwrap.dedent('''\ + input_source_code = textwrap.dedent( + """\ all( x > SOME_GLOBAL_CONSTANT for x in lst ) - ''') + """ + ) - got_source_code = TestTranslationForTracingAll.translate_all_expression(input_source_code=input_source_code) + got_source_code = TestTranslationForTracingAll.translate_all_expression( + input_source_code=input_source_code + ) # Please see ``TestTranslationForTracingAll.translate_all_expression`` and the note about ``name_to_value`` # if you wonder why ``lst`` is not in the arguments. self.assertEqual( - textwrap.dedent('''\ + textwrap.dedent( + """\ def some_func(SOME_GLOBAL_CONSTANT): for x in lst: icontract_tracing_all_result = (x > @@ -74,10 +84,14 @@ def some_func(SOME_GLOBAL_CONSTANT): icontract_tracing_all_result, (('x', x),)) return icontract_tracing_all_result, None - '''), got_source_code) + """ + ), + got_source_code, + ) def test_translation_two_fors_and_two_ifs(self) -> None: - input_source_code = textwrap.dedent('''\ + input_source_code = textwrap.dedent( + """\ all( cell > SOME_GLOBAL_CONSTANT for i, row in enumerate(matrix) @@ -85,14 +99,18 @@ def test_translation_two_fors_and_two_ifs(self) -> None: for j, cell in enumerate(row) if i == j ) - ''') + """ + ) - got_source_code = TestTranslationForTracingAll.translate_all_expression(input_source_code=input_source_code) + got_source_code = TestTranslationForTracingAll.translate_all_expression( + input_source_code=input_source_code + ) # Please see ``TestTranslationForTracingAll.translate_all_expression`` and the note about ``name_to_value`` # if you wonder why ``matrix`` is not in the arguments. self.assertEqual( - textwrap.dedent('''\ + textwrap.dedent( + """\ def some_func(SOME_GLOBAL_CONSTANT): for i, row in enumerate(matrix): if i > 0: @@ -109,24 +127,32 @@ def some_func(SOME_GLOBAL_CONSTANT): , (('i', i), ('row', row), ('j', j), ('cell', cell))) return icontract_tracing_all_result, None - '''), got_source_code) + """ + ), + got_source_code, + ) def test_nested_all(self) -> None: # Nesting is not recursively followed by design. Only the outer-most all expression should be traced. - input_source_code = textwrap.dedent('''\ + input_source_code = textwrap.dedent( + """\ all( all(cell > SOME_GLOBAL_CONSTANT for cell in row) for row in matrix ) - ''') + """ + ) - got_source_code = TestTranslationForTracingAll.translate_all_expression(input_source_code=input_source_code) + got_source_code = TestTranslationForTracingAll.translate_all_expression( + input_source_code=input_source_code + ) # Please see ``TestTranslationForTracingAll.translate_all_expression`` and the note about ``name_to_value`` # if you wonder why ``matrix`` is not in the arguments. self.assertEqual( - textwrap.dedent('''\ + textwrap.dedent( + """\ def some_func(SOME_GLOBAL_CONSTANT): for row in matrix: icontract_tracing_all_result = all( @@ -138,7 +164,10 @@ def some_func(SOME_GLOBAL_CONSTANT): icontract_tracing_all_result, (('row', row),)) return icontract_tracing_all_result, None - '''), got_source_code) + """ + ), + got_source_code, + ) if __name__ == "__main__": diff --git a/tests/test_recursion.py b/tests/test_recursion.py index 2a78e73..8df2b9b 100644 --- a/tests/test_recursion.py +++ b/tests/test_recursion.py @@ -1,5 +1,6 @@ # pylint: disable=missing-docstring # pylint: disable=no-self-use +# pylint: disable=unnecessary-lambda import unittest from typing import List @@ -10,14 +11,16 @@ class TestPrecondition(unittest.TestCase): def test_ok(self) -> None: order = [] # type: List[str] - @icontract.require(lambda: another_func()) # pylint: disable=unnecessary-lambda - @icontract.require(lambda: yet_another_func()) # pylint: disable=unnecessary-lambda + @icontract.require(lambda: another_func()) + @icontract.require(lambda: yet_another_func()) def some_func() -> bool: order.append(some_func.__name__) return True - @icontract.require(lambda: some_func()) # pylint: disable=unnecessary-lambda - @icontract.require(lambda: yet_yet_another_func()) # pylint: disable=unnecessary-lambda + @icontract.require(lambda: some_func()) + @icontract.require( + lambda: yet_yet_another_func() + ) # pylint: disable=unnecessary-lambda def another_func() -> bool: order.append(another_func.__name__) return True @@ -32,8 +35,16 @@ def yet_yet_another_func() -> bool: some_func() - self.assertListEqual(['yet_another_func', 'yet_yet_another_func', 'some_func', 'another_func', 'some_func'], - order) + self.assertListEqual( + [ + "yet_another_func", + "yet_yet_another_func", + "some_func", + "another_func", + "some_func", + ], + order, + ) def test_recover_after_exception(self) -> None: order = [] # type: List[str] @@ -42,16 +53,16 @@ def test_recover_after_exception(self) -> None: class CustomError(Exception): pass - @icontract.require(lambda: another_func()) # pylint: disable=unnecessary-lambda - @icontract.require(lambda: yet_another_func()) # pylint: disable=unnecessary-lambda + @icontract.require(lambda: another_func()) + @icontract.require(lambda: yet_another_func()) def some_func() -> bool: order.append(some_func.__name__) if some_func_should_raise: - raise CustomError('some_func_should_raise') + raise CustomError("some_func_should_raise") return True - @icontract.require(lambda: some_func()) # pylint: disable=unnecessary-lambda - @icontract.require(lambda: yet_yet_another_func()) # pylint: disable=unnecessary-lambda + @icontract.require(lambda: some_func()) + @icontract.require(lambda: yet_yet_another_func()) def another_func() -> bool: order.append(another_func.__name__) return True @@ -69,7 +80,9 @@ def yet_yet_another_func() -> bool: except CustomError: pass - self.assertListEqual(['yet_another_func', 'yet_yet_another_func', 'some_func'], order) + self.assertListEqual( + ["yet_another_func", "yet_yet_another_func", "some_func"], order + ) # Reset for the next experiment order = [] @@ -77,8 +90,16 @@ def yet_yet_another_func() -> bool: some_func() - self.assertListEqual(['yet_another_func', 'yet_yet_another_func', 'some_func', 'another_func', 'some_func'], - order) + self.assertListEqual( + [ + "yet_another_func", + "yet_yet_another_func", + "some_func", + "another_func", + "some_func", + ], + order, + ) class TestPostcondition(unittest.TestCase): @@ -89,18 +110,18 @@ def test_ok(self) -> None: class CustomError(Exception): pass - @icontract.ensure(lambda: another_func()) # pylint: disable=unnecessary-lambda - @icontract.ensure(lambda: yet_another_func()) # pylint: disable=unnecessary-lambda + @icontract.ensure(lambda: another_func()) + @icontract.ensure(lambda: yet_another_func()) def some_func() -> bool: order.append(some_func.__name__) return True - @icontract.ensure(lambda: some_func()) # pylint: disable=unnecessary-lambda - @icontract.ensure(lambda: yet_yet_another_func()) # pylint: disable=unnecessary-lambda + @icontract.ensure(lambda: some_func()) + @icontract.ensure(lambda: yet_yet_another_func()) def another_func() -> bool: order.append(another_func.__name__) if another_func_should_raise: - raise CustomError('some_func_should_raise') + raise CustomError("some_func_should_raise") return True @@ -117,7 +138,7 @@ def yet_yet_another_func() -> bool: except CustomError: pass - self.assertListEqual(['some_func', 'yet_another_func', 'another_func'], order) + self.assertListEqual(["some_func", "yet_another_func", "another_func"], order) # Reset for the next experiments order = [] @@ -125,20 +146,28 @@ def yet_yet_another_func() -> bool: some_func() - self.assertListEqual(['some_func', 'yet_another_func', 'another_func', 'yet_yet_another_func', 'some_func'], - order) + self.assertListEqual( + [ + "some_func", + "yet_another_func", + "another_func", + "yet_yet_another_func", + "some_func", + ], + order, + ) def test_recover_after_exception(self) -> None: order = [] # type: List[str] - @icontract.ensure(lambda: another_func()) # pylint: disable=unnecessary-lambda - @icontract.ensure(lambda: yet_another_func()) # pylint: disable=unnecessary-lambda + @icontract.ensure(lambda: another_func()) + @icontract.ensure(lambda: yet_another_func()) def some_func() -> bool: order.append(some_func.__name__) return True - @icontract.ensure(lambda: some_func()) # pylint: disable=unnecessary-lambda - @icontract.ensure(lambda: yet_yet_another_func()) # pylint: disable=unnecessary-lambda + @icontract.ensure(lambda: some_func()) + @icontract.ensure(lambda: yet_yet_another_func()) def another_func() -> bool: order.append(another_func.__name__) return True @@ -153,8 +182,16 @@ def yet_yet_another_func() -> bool: some_func() - self.assertListEqual(['some_func', 'yet_another_func', 'another_func', 'yet_yet_another_func', 'some_func'], - order) + self.assertListEqual( + [ + "some_func", + "yet_another_func", + "another_func", + "yet_yet_another_func", + "some_func", + ], + order, + ) class TestInvariant(unittest.TestCase): @@ -164,24 +201,24 @@ def test_ok(self) -> None: @icontract.invariant(lambda self: self.some_func()) class SomeClass(icontract.DBC): def __init__(self) -> None: - order.append('__init__') + order.append("__init__") def some_func(self) -> bool: - order.append('some_func') + order.append("some_func") return True def another_func(self) -> bool: - order.append('another_func') + order.append("another_func") return True some_instance = SomeClass() - self.assertListEqual(['__init__', 'some_func'], order) + self.assertListEqual(["__init__", "some_func"], order) # Reset for the next experiment order = [] some_instance.another_func() - self.assertListEqual(['some_func', 'another_func', 'some_func'], order) + self.assertListEqual(["some_func", "another_func", "some_func"], order) def test_recover_after_exception(self) -> None: order = [] # type: List[str] @@ -193,21 +230,21 @@ class CustomError(Exception): @icontract.invariant(lambda self: self.some_func()) class SomeClass(icontract.DBC): def __init__(self) -> None: - order.append('__init__') + order.append("__init__") def some_func(self) -> bool: - order.append('some_func') + order.append("some_func") if some_func_should_raise: - raise CustomError('some_func_should_raise') + raise CustomError("some_func_should_raise") return True def another_func(self) -> bool: - order.append('another_func') + order.append("another_func") return True some_instance = SomeClass() - self.assertListEqual(['__init__', 'some_func'], order) + self.assertListEqual(["__init__", "some_func"], order) # Reset for the next experiment order = [] @@ -218,14 +255,14 @@ def another_func(self) -> bool: except CustomError: pass - self.assertListEqual(['some_func'], order) + self.assertListEqual(["some_func"], order) # Reset for the next experiment order = [] some_func_should_raise = False some_instance.another_func() - self.assertListEqual(['some_func', 'another_func', 'some_func'], order) + self.assertListEqual(["some_func", "another_func", "some_func"], order) def test_member_function_call_in_constructor(self) -> None: order = [] # type: List[str] @@ -233,17 +270,17 @@ def test_member_function_call_in_constructor(self) -> None: @icontract.invariant(lambda self: self.some_attribute > 0) class SomeClass(icontract.DBC): def __init__(self) -> None: - order.append('__init__ enters') + order.append("__init__ enters") self.some_attribute = self.some_func() - order.append('__init__ exits') + order.append("__init__ exits") def some_func(self) -> int: - order.append('some_func') + order.append("some_func") return 3 _ = SomeClass() - self.assertListEqual(['__init__ enters', 'some_func', '__init__ exits'], order) + self.assertListEqual(["__init__ enters", "some_func", "__init__ exits"], order) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/test_represent.py b/tests/test_represent.py index 7f37a8b..e2f3395 100644 --- a/tests/test_represent.py +++ b/tests/test_represent.py @@ -30,7 +30,9 @@ def func(x: int) -> int: violation_error = err self.assertIsNotNone(violation_error) - self.assertEqual("x < 5: x was 100", tests.error.wo_mandatory_location(str(violation_error))) + self.assertEqual( + "x < 5: x was 100", tests.error.wo_mandatory_location(str(violation_error)) + ) def test_str(self) -> None: @icontract.require(lambda x: x != "oi") @@ -44,7 +46,10 @@ def func(x: str) -> str: violation_error = err self.assertIsNotNone(violation_error) - self.assertEqual("""x != "oi": x was 'oi'""", tests.error.wo_mandatory_location(str(violation_error))) + self.assertEqual( + """x != "oi": x was 'oi'""", + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_bytes(self) -> None: @icontract.require(lambda x: x != b"oi") @@ -58,7 +63,10 @@ def func(x: bytes) -> bytes: violation_error = err self.assertIsNotNone(violation_error) - self.assertEqual("""x != b"oi": x was b'oi'""", tests.error.wo_mandatory_location(str(violation_error))) + self.assertEqual( + """x != b"oi": x was b'oi'""", + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_bool(self) -> None: @icontract.require(lambda x: x is not False) @@ -72,7 +80,10 @@ def func(x: bool) -> bool: violation_error = err self.assertIsNotNone(violation_error) - self.assertEqual('x is not False: x was False', tests.error.wo_mandatory_location(str(violation_error))) + self.assertEqual( + "x is not False: x was False", + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_list(self) -> None: y = 1 @@ -89,11 +100,15 @@ def func(x: int) -> int: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ sum([1, y, x]) == 1: sum([1, y, x]) was 5 x was 3 - y was 1"""), tests.error.wo_mandatory_location(str(violation_error))) + y was 1""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_tuple(self) -> None: y = 1 @@ -110,11 +125,15 @@ def func(x: int) -> int: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ sum((1, y, x)) == 1: sum((1, y, x)) was 5 x was 3 - y was 1"""), tests.error.wo_mandatory_location(str(violation_error))) + y was 1""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_set(self) -> None: y = 2 @@ -131,11 +150,15 @@ def func(x: int) -> int: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ sum({1, y, x}) == 1: sum({1, y, x}) was 6 x was 3 - y was 2"""), tests.error.wo_mandatory_location(str(violation_error))) + y was 2""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_dict(self) -> None: y = "someKey" @@ -152,11 +175,15 @@ def func(x: str) -> str: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ len({y: 3, x: 8}) == 6: len({y: 3, x: 8}) was 2 x was 'oi' - y was 'someKey'"""), tests.error.wo_mandatory_location(str(violation_error))) + y was 'someKey'""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_unary_op(self) -> None: @icontract.require(lambda x: not -x + 10 > 3) @@ -170,7 +197,10 @@ def func(x: int) -> int: violation_error = err self.assertIsNotNone(violation_error) - self.assertEqual('not -x + 10 > 3: x was 1', tests.error.wo_mandatory_location(str(violation_error))) + self.assertEqual( + "not -x + 10 > 3: x was 1", + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_binary_op(self) -> None: @icontract.require(lambda x: -x + x - x * x / x // x**x % x > 3) @@ -184,8 +214,10 @@ def func(x: int) -> int: violation_error = err self.assertIsNotNone(violation_error) - self.assertEqual('-x + x - x * x / x // x**x % x > 3: x was 1', - tests.error.wo_mandatory_location(str(violation_error))) + self.assertEqual( + "-x + x - x * x / x // x**x % x > 3: x was 1", + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_binary_op_bit(self) -> None: @icontract.require(lambda x: ~(x << x | x & x ^ x) >> x > x) @@ -199,8 +231,10 @@ def func(x: int) -> int: violation_error = err self.assertIsNotNone(violation_error) - self.assertEqual('~(x << x | x & x ^ x) >> x > x: x was 1', - tests.error.wo_mandatory_location(str(violation_error))) + self.assertEqual( + "~(x << x | x & x ^ x) >> x > x: x was 1", + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_bool_op_single(self) -> None: # pylint: disable=chained-comparison @@ -215,7 +249,10 @@ def func(x: int) -> int: violation_error = err self.assertIsNotNone(violation_error) - self.assertEqual('x > 3 and x < 10: x was 1', tests.error.wo_mandatory_location(str(violation_error))) + self.assertEqual( + "x > 3 and x < 10: x was 1", + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_bool_op_multiple(self) -> None: # pylint: disable=chained-comparison @@ -230,16 +267,25 @@ def func(x: int) -> int: violation_error = err self.assertIsNotNone(violation_error) - self.assertEqual('x > 3 and x < 10 and x % 2 == 0: x was 1', - tests.error.wo_mandatory_location(str(violation_error))) + self.assertEqual( + "x > 3 and x < 10 and x % 2 == 0: x was 1", + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_compare(self) -> None: # pylint: disable=chained-comparison # Chain the compare operators in a meaningless order and semantics @icontract.require( - lambda x: 0 < x < 3 and x > 10 and x != 7 and x >= 10 and x <= 11 and x is not None and - x in [1, 2, 3] and x not in [1, 2, 3]) + lambda x: 0 < x < 3 + and x > 10 + and x != 7 + and x >= 10 + and x <= 11 + and x is not None + and x in [1, 2, 3] + and x not in [1, 2, 3] + ) def func(x: int) -> int: return x @@ -251,10 +297,17 @@ def func(x: int) -> int: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ - 0 < x < 3 and x > 10 and x != 7 and x >= 10 and x <= 11 and x is not None and - x in [1, 2, 3] and x not in [1, 2, 3]: x was 1"""), - tests.error.wo_mandatory_location(str(violation_error))) + """\ +0 < x < 3 + and x > 10 + and x != 7 + and x >= 10 + and x <= 11 + and x is not None + and x in [1, 2, 3] + and x not in [1, 2, 3]: x was 1""", + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_call(self) -> None: def y() -> int: @@ -272,10 +325,14 @@ def func(x: int) -> int: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ x < y(): x was 1 - y() was 1"""), tests.error.wo_mandatory_location(str(violation_error))) + y() was 1""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_if_exp_body(self) -> None: y = 5 @@ -292,10 +349,14 @@ def func(x: int) -> int: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ x < (x**2 if y == 5 else x**3): x was 1 - y was 5"""), tests.error.wo_mandatory_location(str(violation_error))) + y was 5""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_if_exp_orelse(self) -> None: y = 5 @@ -312,10 +373,14 @@ def func(x: int) -> int: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ x < (x**2 if y != 5 else x**3): x was 1 - y was 5"""), tests.error.wo_mandatory_location(str(violation_error))) + y was 5""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_attr(self) -> None: class A: @@ -339,11 +404,15 @@ def func(x: int) -> int: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ x > a.y: a was an instance of A a.y was 3 - x was 1"""), tests.error.wo_mandatory_location(str(violation_error))) + x was 1""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_index(self) -> None: lst = [1, 2, 3] @@ -360,11 +429,15 @@ def func(x: int) -> int: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ x > lst[1]: lst was [1, 2, 3] lst[1] was 2 - x was 1"""), tests.error.wo_mandatory_location(str(violation_error))) + x was 1""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_slice(self) -> None: lst = [1, 2, 3] @@ -381,12 +454,16 @@ def func(x: int) -> int: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ x > sum(lst[1:2:1]): lst was [1, 2, 3] lst[1:2:1] was [2] sum(lst[1:2:1]) was 2 - x was 1"""), tests.error.wo_mandatory_location(str(violation_error))) + x was 1""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_ext_slice(self) -> None: class SomeClass: @@ -409,10 +486,14 @@ def func(something: SomeClass) -> None: self.assertIsNotNone(violation_err) self.assertEqual( - textwrap.dedent('''\ + textwrap.dedent( + """\ something[1, 2:3] is None: something was - something[1, 2:3] was (1, slice(2, 3, None))'''), tests.error.wo_mandatory_location(str(violation_err))) + something[1, 2:3] was (1, slice(2, 3, None))""" + ), + tests.error.wo_mandatory_location(str(violation_err)), + ) def test_lambda(self) -> None: @icontract.require(lambda x: x > (lambda y: y + 4).__call__(y=7)) # type: ignore @@ -432,14 +513,18 @@ def func(x: int) -> int: not_implemented_error = runtime_error.__cause__ self.assertEqual( - 'Re-computation of in-line lambda functions is not supported since it is quite tricky to implement and ' - 'we decided to implement it only once there is a real need for it. ' - 'Please make a feature request on https://github.com/Parquery/icontract', str(not_implemented_error)) + "Re-computation of in-line lambda functions is not supported since it is quite tricky to implement and " + "we decided to implement it only once there is a real need for it. " + "Please make a feature request on https://github.com/Parquery/icontract", + str(not_implemented_error), + ) class TestGeneratorExpr(unittest.TestCase): def test_attr_on_element(self) -> None: - @icontract.ensure(lambda result: all(single_res[1].is_absolute() for single_res in result)) + @icontract.ensure( + lambda result: all(single_res[1].is_absolute() for single_res in result) + ) def some_func() -> List[Tuple[pathlib.Path, pathlib.Path]]: return [(pathlib.Path("/home/file1"), pathlib.Path("home/file2"))] @@ -450,16 +535,19 @@ def some_func() -> List[Tuple[pathlib.Path, pathlib.Path]]: violation_error = err # This dummy path is necessary to obtain the class name. - dummy_path = pathlib.Path('/also/doesnt/exist') + dummy_path = pathlib.Path("/also/doesnt/exist") self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent('''\ + textwrap.dedent( + """\ all(single_res[1].is_absolute() for single_res in result): all(single_res[1].is_absolute() for single_res in result) was False, e.g., with single_res = ({0}('/home/file1'), {0}('home/file2')) - result was [({0}('/home/file1'), {0}('home/file2'))]''').format(dummy_path.__class__.__name__), - tests.error.wo_mandatory_location(str(violation_error))) + result was [({0}('/home/file1'), {0}('home/file2'))]""" + ).format(dummy_path.__class__.__name__), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_multiple_for(self) -> None: lst = [[1, 2], [3]] @@ -481,32 +569,37 @@ def func(x: int) -> int: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent('''\ + textwrap.dedent( + """\ all(item == x for sublst in lst for item in sublst): all(item == x for sublst in lst for item in sublst) was False, e.g., with sublst = [1, 2] item = 1 lst was [[1, 2], [3]] - x was 0'''), tests.error.wo_mandatory_location(str(violation_error))) + x was 0""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_zip_and_multiple_for(self) -> None: # Taken from a solution for Advent of Code 2020 day 11. @icontract.ensure( - lambda layout, result: - all(cell == result_cell + lambda layout, result: all( + cell == result_cell for row, result_row in zip(layout, result[0]) for cell, result_cell in zip(row, result_row) - if cell == '.'), - "Floor remains floor" + if cell == "." + ), + "Floor remains floor", ) def apply(layout: List[List[str]]) -> Tuple[List[List[str]], int]: height = len(layout) width = len(layout[0]) - result = [[''] * width] * height + result = [[""] * width] * height return result, 0 - layout = [['L', '.', '#'], ['.', '#', '#']] + layout = [["L", ".", "#"], [".", "#", "#"]] violation_error = None # type: Optional[icontract.ViolationError] try: @@ -516,19 +609,27 @@ def apply(layout: List[List[str]]) -> Tuple[List[List[str]], int]: self.assertIsNotNone(violation_error) - text = re.sub(r'', '', - tests.error.wo_mandatory_location(str(violation_error))) + text = re.sub( + r"", + "", + tests.error.wo_mandatory_location(str(violation_error)), + ) self.assertEqual( - textwrap.dedent('''\ - Floor remains floor: all(cell == result_cell + textwrap.dedent( + """\ + Floor remains floor: all( + cell == result_cell for row, result_row in zip(layout, result[0]) for cell, result_cell in zip(row, result_row) - if cell == '.'): - all(cell == result_cell + if cell == "." + ): + all( + cell == result_cell for row, result_row in zip(layout, result[0]) for cell, result_cell in zip(row, result_row) - if cell == '.') was False, e.g., with + if cell == "." + ) was False, e.g., with row = ['L', '.', '#'] result_row = ['', '', ''] cell = '.' @@ -536,7 +637,10 @@ def apply(layout: List[List[str]]) -> Tuple[List[List[str]], int]: layout was [['L', '.', '#'], ['.', '#', '#']] result was ([['', '', ''], ['', '', '']], 0) result[0] was [['', '', ''], ['', '', '']] - zip(layout, result[0]) was '''), text) + zip(layout, result[0]) was """ + ), + text, + ) class TestListComprehension(unittest.TestCase): @@ -555,11 +659,15 @@ def func(x: int) -> int: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent('''\ + textwrap.dedent( + """\ [item < x for item in lst if item % x == 0] == []: [item < x for item in lst if item % x == 0] was [False] lst was [1, 2, 3] - x was 2'''), tests.error.wo_mandatory_location(str(violation_error))) + x was 2""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_nested(self) -> None: lst_of_lsts = [[1, 2, 3]] @@ -584,7 +692,8 @@ def func() -> None: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent('''\ + textwrap.dedent( + """\ [ [item for item in sublst if item > 0] for sublst in lst_of_lsts @@ -593,14 +702,19 @@ def func() -> None: [item for item in sublst if item > 0] for sublst in lst_of_lsts ] was [[1, 2, 3]] - lst_of_lsts was [[1, 2, 3]]'''), tests.error.wo_mandatory_location(str(violation_error))) + lst_of_lsts was [[1, 2, 3]]""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) class TestSetComprehension(unittest.TestCase): def test_single(self) -> None: lst = [1, 2, 3] - @icontract.require(lambda x: len({item < x for item in lst if item % x == 0}) == 0) + @icontract.require( + lambda x: len({item < x for item in lst if item % x == 0}) == 0 + ) def func(x: int) -> int: return x @@ -612,13 +726,16 @@ def func(x: int) -> int: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent('''\ + textwrap.dedent( + """\ len({item < x for item in lst if item % x == 0}) == 0: len({item < x for item in lst if item % x == 0}) was 1 lst was [1, 2, 3] x was 2 - {item < x for item in lst if item % x == 0} was {False}'''), - tests.error.wo_mandatory_location(str(violation_error))) + {item < x for item in lst if item % x == 0} was {False}""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_nested(self) -> None: lst_of_lsts = [[1, 2, 3]] @@ -643,7 +760,8 @@ def func() -> None: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent('''\ + textwrap.dedent( + """\ { len({item for item in lst if item > 0}) for lst in lst_of_lsts @@ -653,7 +771,10 @@ def func() -> None: { len({item for item in lst if item > 0}) for lst in lst_of_lsts - } was {3}'''), tests.error.wo_mandatory_location(str(violation_error))) + } was {3}""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) class TestDictComprehension(unittest.TestCase): @@ -670,13 +791,16 @@ def func(x: int) -> int: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent('''\ + textwrap.dedent( + """\ len({i: i**2 for i in range(x)}) == 0: len({i: i**2 for i in range(x)}) was 2 range(x) was range(0, 2) x was 2 - {i: i**2 for i in range(x)} was {0: 0, 1: 1}'''), tests.error.wo_mandatory_location( - str(violation_error))) + {i: i**2 for i in range(x)} was {0: 0, 1: 1}""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_nested(self) -> None: lst_of_lsts = [[1, 2, 3]] @@ -701,7 +825,8 @@ def func() -> None: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent('''\ + textwrap.dedent( + """\ len({ len(lst): {item: item for item in lst} for lst in lst_of_lsts @@ -714,7 +839,10 @@ def func() -> None: { len(lst): {item: item for item in lst} for lst in lst_of_lsts - } was {3: {1: 1, 2: 2, 3: 3}}'''), tests.error.wo_mandatory_location(str(violation_error))) + } was {3: {1: 1, 2: 2, 3: 3}}""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) class TestConditionAsText(unittest.TestCase): @@ -735,7 +863,9 @@ def func(x: int) -> int: violation_err = err self.assertIsNotNone(violation_err) - self.assertEqual("x > 3: x was 0", tests.error.wo_mandatory_location(str(violation_err))) + self.assertEqual( + "x > 3: x was 0", tests.error.wo_mandatory_location(str(violation_err)) + ) def test_condition_on_next_line(self) -> None: # yapf: disable @@ -753,7 +883,9 @@ def func(x: int) -> int: violation_err = err self.assertIsNotNone(violation_err) - self.assertEqual("x > 3: x was 0", tests.error.wo_mandatory_location(str(violation_err))) + self.assertEqual( + "x > 3: x was 0", tests.error.wo_mandatory_location(str(violation_err)) + ) def test_condition_on_multiple_lines(self) -> None: # yapf: disable @@ -775,10 +907,14 @@ def func(x: int) -> int: self.assertIsNotNone(violation_err) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ x > - 3: x was 0"""), tests.error.wo_mandatory_location(str(violation_err))) + 3: x was 0""" + ), + tests.error.wo_mandatory_location(str(violation_err)), + ) def test_with_multiple_lambdas_on_a_line(self) -> None: # pylint: disable=unnecessary-lambda @@ -831,10 +967,14 @@ def some_func(x: List[int]) -> None: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ len(x) < 10: len(x) was 10000 - x was [0, 1, 2, ...]"""), tests.error.wo_mandatory_location(str(violation_error))) + x was [0, 1, 2, ...]""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) class TestClass(unittest.TestCase): @@ -870,11 +1010,15 @@ def __repr__(self) -> str: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ self.b.x > 0: self was A() self.b was B(x=0) - self.b.x was 0"""), tests.error.wo_mandatory_location(str(violation_error))) + self.b.x was 0""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_nested_method(self) -> None: z = 10 @@ -903,7 +1047,10 @@ class A: def __init__(self) -> None: self.b = B() - @icontract.require(lambda self: pathlib.Path(str(gt_zero(self.b.c(x=0).x() + 12.2 * z))) is None) + @icontract.require( + lambda self: pathlib.Path(str(gt_zero(self.b.c(x=0).x() + 12.2 * z))) + is None + ) def some_func(self) -> None: pass @@ -919,12 +1066,14 @@ def __repr__(self) -> str: violation_error = err # This dummy path is necessary to obtain the class name. - dummy_path = pathlib.Path('/just/a/dummy/path') + dummy_path = pathlib.Path("/just/a/dummy/path") self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent('''\ - pathlib.Path(str(gt_zero(self.b.c(x=0).x() + 12.2 * z))) is None: + textwrap.dedent( + """\ + pathlib.Path(str(gt_zero(self.b.c(x=0).x() + 12.2 * z))) + is None: gt_zero(self.b.c(x=0).x() + 12.2 * z) was True pathlib.Path(str(gt_zero(self.b.c(x=0).x() + 12.2 * z))) was {}('True') self was A() @@ -932,8 +1081,10 @@ def __repr__(self) -> str: self.b.c(x=0) was C(x=0) self.b.c(x=0).x() was 0 str(gt_zero(self.b.c(x=0).x() + 12.2 * z)) was 'True' - z was 10''').format(dummy_path.__class__.__name__), - tests.error.wo_mandatory_location(str(violation_error))) + z was 10""" + ).format(dummy_path.__class__.__name__), + tests.error.wo_mandatory_location(str(violation_error)), + ) class TestClosures(unittest.TestCase): @@ -953,11 +1104,15 @@ def some_func(x: int) -> None: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ x < y + z: x was 100 y was 4 - z was 5"""), tests.error.wo_mandatory_location(str(violation_error))) + z was 5""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_global(self) -> None: @icontract.require(lambda x: x < SOME_GLOBAL_CONSTANT) @@ -972,10 +1127,14 @@ def some_func(x: int) -> None: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ x < SOME_GLOBAL_CONSTANT: SOME_GLOBAL_CONSTANT was 10 - x was 100"""), tests.error.wo_mandatory_location(str(violation_error))) + x was 100""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_closure_and_global(self) -> None: y = 4 @@ -992,11 +1151,15 @@ def some_func(x: int) -> None: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ x < y + SOME_GLOBAL_CONSTANT: SOME_GLOBAL_CONSTANT was 10 x was 100 - y was 4"""), tests.error.wo_mandatory_location(str(violation_error))) + y was 4""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) class TestWithNumpyMock(unittest.TestCase): @@ -1010,7 +1173,10 @@ def test_that_mock_works(self) -> None: value_err = err self.assertIsNotNone(value_err) - self.assertEqual('The truth value of an array with more than one element is ambiguous.', str(value_err)) + self.assertEqual( + "The truth value of an array with more than one element is ambiguous.", + str(value_err), + ) def test_that_single_comparator_works(self) -> None: @icontract.require(lambda arr: (arr > 0).all()) @@ -1025,10 +1191,14 @@ def some_func(arr: tests.mock.NumpyArray) -> None: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ (arr > 0).all(): (arr > 0).all() was False - arr was NumpyArray([-3, 3])"""), tests.error.wo_mandatory_location(str(violation_error))) + arr was NumpyArray([-3, 3])""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_that_multiple_comparators_fail(self) -> None: """ @@ -1061,7 +1231,10 @@ def some_func(arr: tests.mock.NumpyArray) -> None: value_err = err self.assertIsNotNone(value_err) - self.assertEqual('The truth value of an array with more than one element is ambiguous.', str(value_err)) + self.assertEqual( + "The truth value of an array with more than one element is ambiguous.", + str(value_err), + ) class TestNumpyArrays(unittest.TestCase): @@ -1081,10 +1254,14 @@ def some_func(arr: Any) -> None: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ len(arr) > 2: arr was array([0, 1]) - len(arr) was 2"""), tests.error.wo_mandatory_location(str(violation_error))) + len(arr) was 2""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_arange_in_kwargs_values(self) -> None: # This test case addresses ``visit_Call` in the Visitor. @@ -1105,10 +1282,14 @@ def some_func(arr: Any) -> None: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ custom_len(arr=arr) > 2: arr was array([0, 1]) - custom_len(arr=arr) was 2"""), tests.error.wo_mandatory_location(str(violation_error))) + custom_len(arr=arr) was 2""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) class TestRecomputationFailure(unittest.TestCase): @@ -1138,17 +1319,24 @@ def some_func() -> None: assert runtime_error is not None lines = [ - re.sub(r'^File (.*), line ([0-9]+) in (.*):$', - 'File , line in :', line) + re.sub( + r"^File (.*), line ([0-9]+) in (.*):$", + "File , line in :", + line, + ) for line in str(runtime_error).splitlines() ] - text = '\n'.join(lines) + text = "\n".join(lines) self.assertEqual( - textwrap.dedent('''\ + textwrap.dedent( + """\ Failed to recompute the values of the contract condition: File , line in : - lambda: some_condition()'''), text) + lambda: some_condition()""" + ), + text, + ) class TestTracingAll(unittest.TestCase): @@ -1176,7 +1364,8 @@ def func(lst: List[int]) -> None: got = tests.error.wo_mandatory_location(str(violation_error)) self.assertEqual( - textwrap.dedent('''\ + textwrap.dedent( + """\ all( value > SOME_GLOBAL_CONSTANT for value in lst @@ -1187,7 +1376,10 @@ def func(lst: List[int]) -> None: for value in lst ) was False, e.g., with value = -1 - lst was [-1, -2]'''), got) + lst was [-1, -2]""" + ), + got, + ) def test_formatted_string(self) -> None: # yapf: disable @@ -1204,7 +1396,7 @@ def func(lst: List[str]) -> None: violation_error = None # type: Optional[icontract.ViolationError] try: - func(lst=['y']) + func(lst=["y"]) except icontract.ViolationError as err: violation_error = err @@ -1213,7 +1405,8 @@ def func(lst: List[str]) -> None: got = tests.error.wo_mandatory_location(str(violation_error)) self.assertEqual( - textwrap.dedent('''\ + textwrap.dedent( + """\ all( f'{value}' == 'x' for value in lst @@ -1223,7 +1416,10 @@ def func(lst: List[str]) -> None: for value in lst ) was False, e.g., with value = 'y' - lst was ['y']'''), got) + lst was ['y']""" + ), + got, + ) def test_two_fors_and_two_ifs(self) -> None: # yapf: disable @@ -1249,11 +1445,15 @@ def func(matrix: List[List[int]]) -> None: self.assertIsNotNone(violation_error) - got = re.sub(r'', '', - tests.error.wo_mandatory_location(str(violation_error))) + got = re.sub( + r"", + "", + tests.error.wo_mandatory_location(str(violation_error)), + ) self.assertEqual( - textwrap.dedent('''\ + textwrap.dedent( + """\ all( cell > SOME_GLOBAL_CONSTANT for i, row in enumerate(matrix) @@ -1274,7 +1474,10 @@ def func(matrix: List[List[int]]) -> None: j = 1 cell = -1 enumerate(matrix) was - matrix was [[-1, -1], [-1, -1]]'''), got) + matrix was [[-1, -1], [-1, -1]]""" + ), + got, + ) def test_nested_all(self) -> None: # Nesting is not recursively followed by design. Only the outer-most all expression should be traced. @@ -1299,7 +1502,8 @@ def func(lst_of_lsts: List[List[int]]) -> None: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent('''\ + textwrap.dedent( + """\ all( all(item > 0 for item in sublst) for sublst in lst_of_lsts @@ -1309,7 +1513,10 @@ def func(lst_of_lsts: List[List[int]]) -> None: for sublst in lst_of_lsts ) was False, e.g., with sublst = [-1, -1] - lst_of_lsts was [[-1, -1]]'''), tests.error.wo_mandatory_location(str(violation_error))) + lst_of_lsts was [[-1, -1]]""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_property_of_an_object_represented(self) -> None: class Something: @@ -1340,7 +1547,8 @@ def func(something: Something, lst: List[int]) -> None: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent('''\ + textwrap.dedent( + """\ all( item > something.some_property for item in lst @@ -1352,7 +1560,10 @@ def func(something: Something, lst: List[int]) -> None: item = -1 lst was [-1] something was Something() - something.some_property was 0'''), tests.error.wo_mandatory_location(str(violation_error))) + something.some_property was 0""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_shadows_in_targets(self) -> None: # yapf: disable @@ -1375,7 +1586,8 @@ def func(lst_of_lsts: List[List[int]]) -> None: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent('''\ + textwrap.dedent( + """\ all( all(item > 0 for item in item) for item in lst_of_lsts @@ -1385,8 +1597,11 @@ def func(lst_of_lsts: List[List[int]]) -> None: for item in lst_of_lsts ) was False, e.g., with item = [-1, -1] - lst_of_lsts was [[-1, -1]]'''), tests.error.wo_mandatory_location(str(violation_error))) + lst_of_lsts was [[-1, -1]]""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/test_snapshot.py b/tests/test_snapshot.py index e0b57aa..38ffb8c 100644 --- a/tests/test_snapshot.py +++ b/tests/test_snapshot.py @@ -42,7 +42,9 @@ def some_func(lst: List[int], val: int) -> None: def test_with_multiple_arguments(self) -> None: @icontract.snapshot(lambda lst_a, lst_b: set(lst_a).union(lst_b), name="union") - @icontract.ensure(lambda OLD, lst_a, lst_b: set(lst_a).union(lst_b) == OLD.union) + @icontract.ensure( + lambda OLD, lst_a, lst_b: set(lst_a).union(lst_b) == OLD.union + ) def some_func(lst_a: List[int], lst_b: List[int]) -> None: pass @@ -66,13 +68,17 @@ def some_func(lst: List[int], val: int) -> None: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ OLD.lst + [val] == lst: OLD was a bunch of OLD values OLD.lst was [1] lst was [1, 2, 1984] result was None - val was 2"""), tests.error.wo_mandatory_location(str(violation_error))) + val was 2""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_with_custom_name(self) -> None: @icontract.snapshot(lambda lst: len(lst), name="len_lst") @@ -89,18 +95,24 @@ def some_func(lst: List[int], val: int) -> None: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ OLD.len_lst + 1 == len(lst): OLD was a bunch of OLD values OLD.len_lst was 1 len(lst) was 3 lst was [1, 2, 1984] result was None - val was 2"""), tests.error.wo_mandatory_location(str(violation_error))) + val was 2""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) def test_with_multiple_arguments(self) -> None: @icontract.snapshot(lambda lst_a, lst_b: set(lst_a).union(lst_b), name="union") - @icontract.ensure(lambda OLD, lst_a, lst_b: set(lst_a).union(lst_b) == OLD.union) + @icontract.ensure( + lambda OLD, lst_a, lst_b: set(lst_a).union(lst_b) == OLD.union + ) def some_func(lst_a: List[int], lst_b: List[int]) -> None: lst_a.append(1984) # bug @@ -112,7 +124,8 @@ def some_func(lst_a: List[int], lst_b: List[int]) -> None: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent('''\ + textwrap.dedent( + """\ set(lst_a).union(lst_b) == OLD.union: OLD was a bunch of OLD values OLD.union was {1, 2, 3, 4} @@ -120,8 +133,10 @@ def some_func(lst_a: List[int], lst_b: List[int]) -> None: lst_b was [3, 4] result was None set(lst_a) was {1, 2, 1984} - set(lst_a).union(lst_b) was {1, 2, 3, 4, 1984}'''), - tests.error.wo_mandatory_location(str(violation_error))) + set(lst_a).union(lst_b) was {1, 2, 3, 4, 1984}""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) class TestInvalid(unittest.TestCase): @@ -137,10 +152,12 @@ def some_func(lst: List[int], val: int) -> None: type_error = err self.assertIsNotNone(type_error) - self.assertEqual("The argument(s) of the contract condition have not been set: ['OLD']. " - "Does the original function define them? Did you supply them in the call? " - "Did you decorate the function with a snapshot to capture OLD values?", - tests.error.wo_mandatory_location(str(type_error))) + self.assertEqual( + "The argument(s) of the contract condition have not been set: ['OLD']. " + "Does the original function define them? Did you supply them in the call? " + "Did you decorate the function with a snapshot to capture OLD values?", + tests.error.wo_mandatory_location(str(type_error)), + ) def test_conflicting_snapshots_with_argument_name(self) -> None: value_error = None # type: Optional[ValueError] @@ -157,15 +174,17 @@ def some_func(lst: List[int], val: int) -> None: value_error = err self.assertIsNotNone(value_error) - self.assertEqual("There are conflicting snapshots with the name: 'lst'", str(value_error)) + self.assertEqual( + "There are conflicting snapshots with the name: 'lst'", str(value_error) + ) def test_conflicting_snapshots_with_custom_name(self) -> None: value_error = None # type: Optional[ValueError] try: # pylint: disable=unused-variable - @icontract.snapshot(lambda lst: len(lst), name='len_lst') - @icontract.snapshot(lambda lst: len(lst), name='len_lst') + @icontract.snapshot(lambda lst: len(lst), name="len_lst") + @icontract.snapshot(lambda lst: len(lst), name="len_lst") @icontract.ensure(lambda OLD, val, lst: OLD.len_lst + 1 == len(lst)) def some_func(lst: List[int], val: int) -> None: lst.append(val) @@ -174,14 +193,16 @@ def some_func(lst: List[int], val: int) -> None: value_error = err self.assertIsNotNone(value_error) - self.assertEqual("There are conflicting snapshots with the name: 'len_lst'", str(value_error)) + self.assertEqual( + "There are conflicting snapshots with the name: 'len_lst'", str(value_error) + ) def test_with_invalid_argument(self) -> None: # lst versus a_list type_error = None # type: Optional[TypeError] try: - @icontract.snapshot(lambda lst: len(lst), name='len_lst') + @icontract.snapshot(lambda lst: len(lst), name="len_lst") @icontract.ensure(lambda OLD, val, a_list: OLD.len_lst + 1 == len(a_list)) def some_func(a_list: List[int], val: int) -> None: a_list.append(val) @@ -191,9 +212,11 @@ def some_func(a_list: List[int], val: int) -> None: type_error = err self.assertIsNotNone(type_error) - self.assertEqual("The argument(s) of the snapshot have not been set: ['lst']. " - "Does the original function define them? Did you supply them in the call?", - tests.error.wo_mandatory_location(str(type_error))) + self.assertEqual( + "The argument(s) of the snapshot have not been set: ['lst']. " + "Does the original function define them? Did you supply them in the call?", + tests.error.wo_mandatory_location(str(type_error)), + ) def test_with_no_arguments_and_no_name(self) -> None: z = [1] @@ -211,7 +234,10 @@ def some_func(val: int) -> None: value_error = err self.assertIsNotNone(value_error) - self.assertEqual("You must name a snapshot if no argument was given in the capture function.", str(value_error)) + self.assertEqual( + "You must name a snapshot if no argument was given in the capture function.", + str(value_error), + ) def test_with_multiple_arguments_and_no_name(self) -> None: value_error = None # type: Optional[ValueError] @@ -219,7 +245,9 @@ def test_with_multiple_arguments_and_no_name(self) -> None: # pylint: disable=unused-variable @icontract.snapshot(lambda lst_a, lst_b: set(lst_a).union(lst_b)) - @icontract.ensure(lambda OLD, lst_a, lst_b: set(lst_a).union(lst_b) == OLD.union) + @icontract.ensure( + lambda OLD, lst_a, lst_b: set(lst_a).union(lst_b) == OLD.union + ) def some_func(lst_a: List[int], lst_b: List[int]) -> None: pass @@ -227,8 +255,10 @@ def some_func(lst_a: List[int], lst_b: List[int]) -> None: value_error = err self.assertIsNotNone(value_error) - self.assertEqual("You must name a snapshot if multiple arguments were given in the capture function.", - str(value_error)) + self.assertEqual( + "You must name a snapshot if multiple arguments were given in the capture function.", + str(value_error), + ) def test_with_no_postcondition(self) -> None: value_error = None # type: Optional[ValueError] @@ -238,16 +268,22 @@ def test_with_no_postcondition(self) -> None: @icontract.snapshot(lambda lst: lst[:]) def some_func(lst: List[int]) -> None: return + except ValueError as err: value_error = err self.assertIsNotNone(value_error) - self.assertEqual("You are decorating a function with a snapshot, " - "but no postcondition was defined on the function before.", str(value_error)) + self.assertEqual( + "You are decorating a function with a snapshot, " + "but no postcondition was defined on the function before.", + str(value_error), + ) def test_missing_old_attribute(self) -> None: @icontract.snapshot(lambda lst: lst[:]) - @icontract.ensure(lambda OLD, lst: OLD.len_list == lst) # We miss len_lst in OLD here! + @icontract.ensure( + lambda OLD, lst: OLD.len_list == lst + ) # We miss len_lst in OLD here! def some_func(lst: List[int]) -> None: return @@ -260,10 +296,12 @@ def some_func(lst: List[int]) -> None: assert attribute_error is not None - self.assertEqual("The snapshot with the name 'len_list' is not available in the OLD of a postcondition. " - "Have you decorated the function with a corresponding snapshot decorator?", - str(attribute_error)) + self.assertEqual( + "The snapshot with the name 'len_list' is not available in the OLD of a postcondition. " + "Have you decorated the function with a corresponding snapshot decorator?", + str(attribute_error), + ) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/test_threading.py b/tests/test_threading.py index cad57b5..bdbfcc2 100644 --- a/tests/test_threading.py +++ b/tests/test_threading.py @@ -39,5 +39,5 @@ def run(self) -> None: another_worker.join() -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/test_typeguard.py b/tests/test_typeguard.py index 8abcfa8..14fc947 100644 --- a/tests/test_typeguard.py +++ b/tests/test_typeguard.py @@ -44,7 +44,7 @@ def some_func(x: A) -> None: type_error = err expected = 'type of argument "x" must be ' - self.assertEqual(expected, str(type_error)[:len(expected)]) + self.assertEqual(expected, str(type_error)[: len(expected)]) def test_precondition_fails_and_typeguard_ok(self) -> None: @typeguard.typechecked @@ -59,7 +59,9 @@ def some_func(x: int) -> None: violation_error = err self.assertIsNotNone(violation_error) - self.assertEqual('x > 0: x was -10', tests.error.wo_mandatory_location(str(violation_error))) + self.assertEqual( + "x > 0: x was -10", tests.error.wo_mandatory_location(str(violation_error)) + ) class TestInvariant(unittest.TestCase): @@ -100,7 +102,7 @@ def __init__(self, a: A) -> None: type_error = err expected = 'type of argument "a" must be ' - self.assertEqual(expected, str(type_error)[:len(expected)]) + self.assertEqual(expected, str(type_error)[: len(expected)]) def test_invariant_fails_and_typeguard_ok(self) -> None: @typeguard.typechecked @@ -120,10 +122,14 @@ def __repr__(self) -> str: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent('''\ + textwrap.dedent( + """\ self.x > 0: self was an instance of A - self.x was -1'''), tests.error.wo_mandatory_location(str(violation_error))) + self.x was -1""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) class TestInheritance(unittest.TestCase): @@ -172,7 +178,7 @@ class D(C): self.assertIsNotNone(type_error) expected = 'type of argument "a" must be ' - self.assertEqual(expected, str(type_error)[:len(expected)]) + self.assertEqual(expected, str(type_error)[: len(expected)]) def test_invariant_fails_and_typeguard_ok(self) -> None: @typeguard.typechecked @@ -196,11 +202,15 @@ def __repr__(self) -> str: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent('''\ + textwrap.dedent( + """\ self.x > 0: self was an instance of B - self.x was -1'''), tests.error.wo_mandatory_location(str(violation_error))) + self.x was -1""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests_3_6/test_represent.py b/tests_3_6/test_represent.py index 5989713..dce4dd9 100644 --- a/tests_3_6/test_represent.py +++ b/tests_3_6/test_represent.py @@ -14,7 +14,7 @@ class TestLiteralStringInterpolation(unittest.TestCase): def test_plain_string(self) -> None: # pylint: disable=f-string-without-interpolation - @icontract.require(lambda x: f"something" == '') # type: ignore + @icontract.require(lambda x: f"something" == "") # type: ignore def func(x: float) -> float: return x @@ -26,13 +26,17 @@ def func(x: float) -> float: self.assertIsNotNone(violation_err) self.assertEqual( - textwrap.dedent("""\ - f"something" == '': + textwrap.dedent( + """\ + f"something" == "": f"something" was 'something' - x was 0"""), tests.error.wo_mandatory_location(str(violation_err))) + x was 0""" + ), + tests.error.wo_mandatory_location(str(violation_err)), + ) def test_simple_interpolation(self) -> None: - @icontract.require(lambda x: f"{x}" == '') + @icontract.require(lambda x: f"{x}" == "") def func(x: float) -> float: return x @@ -44,13 +48,17 @@ def func(x: float) -> float: self.assertIsNotNone(violation_err) self.assertEqual( - textwrap.dedent("""\ - f"{x}" == '': + textwrap.dedent( + """\ + f"{x}" == "": f"{x}" was '0' - x was 0"""), tests.error.wo_mandatory_location(str(violation_err))) + x was 0""" + ), + tests.error.wo_mandatory_location(str(violation_err)), + ) def test_string_formatting(self) -> None: - @icontract.require(lambda x: f"{x!s}" == '') + @icontract.require(lambda x: f"{x!s}" == "") def func(x: float) -> float: return x @@ -62,13 +70,17 @@ def func(x: float) -> float: self.assertIsNotNone(violation_err) self.assertEqual( - textwrap.dedent("""\ - f"{x!s}" == '': + textwrap.dedent( + """\ + f"{x!s}" == "": f"{x!s}" was '1.984' - x was 1.984"""), tests.error.wo_mandatory_location(str(violation_err))) + x was 1.984""" + ), + tests.error.wo_mandatory_location(str(violation_err)), + ) def test_repr_formatting(self) -> None: - @icontract.require(lambda x: f"{x!r}" == '') + @icontract.require(lambda x: f"{x!r}" == "") def func(x: float) -> float: return x @@ -80,13 +92,17 @@ def func(x: float) -> float: self.assertIsNotNone(violation_err) self.assertEqual( - textwrap.dedent("""\ - f"{x!r}" == '': + textwrap.dedent( + """\ + f"{x!r}" == "": f"{x!r}" was '1.984' - x was 1.984"""), tests.error.wo_mandatory_location(str(violation_err))) + x was 1.984""" + ), + tests.error.wo_mandatory_location(str(violation_err)), + ) def test_ascii_formatting(self) -> None: - @icontract.require(lambda x: f"{x!a}" == '') + @icontract.require(lambda x: f"{x!a}" == "") def func(x: float) -> float: return x @@ -98,13 +114,17 @@ def func(x: float) -> float: self.assertIsNotNone(violation_err) self.assertEqual( - textwrap.dedent("""\ - f"{x!a}" == '': + textwrap.dedent( + """\ + f"{x!a}" == "": f"{x!a}" was '1.984' - x was 1.984"""), tests.error.wo_mandatory_location(str(violation_err))) + x was 1.984""" + ), + tests.error.wo_mandatory_location(str(violation_err)), + ) def test_format_spec(self) -> None: - @icontract.require(lambda x: f"{x:.3}" == '') + @icontract.require(lambda x: f"{x:.3}" == "") def func(x: float) -> float: return x @@ -116,13 +136,17 @@ def func(x: float) -> float: self.assertIsNotNone(violation_err) self.assertEqual( - textwrap.dedent("""\ - f"{x:.3}" == '': + textwrap.dedent( + """\ + f"{x:.3}" == "": f"{x:.3}" was '1.98' - x was 1.984"""), tests.error.wo_mandatory_location(str(violation_err))) + x was 1.984""" + ), + tests.error.wo_mandatory_location(str(violation_err)), + ) def test_conversion_and_format_spec(self) -> None: - @icontract.require(lambda x: f"{x!r:.3}" == '') + @icontract.require(lambda x: f"{x!r:.3}" == "") def func(x: float) -> float: return x @@ -134,11 +158,15 @@ def func(x: float) -> float: self.assertIsNotNone(violation_err) self.assertEqual( - textwrap.dedent("""\ - f"{x!r:.3}" == '': + textwrap.dedent( + """\ + f"{x!r:.3}" == "": f"{x!r:.3}" was '1.9' - x was 1.984"""), tests.error.wo_mandatory_location(str(violation_err))) + x was 1.984""" + ), + tests.error.wo_mandatory_location(str(violation_err)), + ) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests_3_7/test_invariant.py b/tests_3_7/test_invariant.py index 67af8b1..7b281ed 100644 --- a/tests_3_7/test_invariant.py +++ b/tests_3_7/test_invariant.py @@ -21,8 +21,10 @@ class RightHalfPlanePoint: _ = RightHalfPlanePoint(1, 0) - self.assertEqual('Create and return a new object. See help(type) for accurate signature.', - RightHalfPlanePoint.__new__.__doc__) + self.assertEqual( + "Create and return a new object. See help(type) for accurate signature.", + RightHalfPlanePoint.__new__.__doc__, + ) class TestViolation(unittest.TestCase): @@ -42,11 +44,15 @@ class RightHalfPlanePoint: self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent('''\ + textwrap.dedent( + """\ self.second > 0: self was TestViolation.test_on_dataclass..RightHalfPlanePoint(first=1, second=-1) - self.second was -1'''), tests.error.wo_mandatory_location(str(violation_error))) + self.second was -1""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests_3_8/async/separately_test_concurrent.py b/tests_3_8/async/separately_test_concurrent.py index 66082bc..cffa9b3 100644 --- a/tests_3_8/async/separately_test_concurrent.py +++ b/tests_3_8/async/separately_test_concurrent.py @@ -22,7 +22,10 @@ async def is_between_0_and_100(value: int) -> bool: @icontract.require( lambda bar: is_between_0_and_100(bar), - error=lambda bar: icontract.ViolationError(f"bar between 0 and 100, but got {bar}")) + error=lambda bar: icontract.ViolationError( + f"bar between 0 and 100, but got {bar}" + ), + ) async def is_less_than_42(bar: int) -> bool: sleep_time = random.randint(1, 5) await asyncio.sleep(sleep_time) @@ -31,9 +34,9 @@ async def is_less_than_42(bar: int) -> bool: results_or_errors = await asyncio.gather( is_less_than_42(0), # Should return True is_less_than_42(101), # Should violate the pre-condition - is_less_than_42(-1) # Should violate the pre-condition - , - return_exceptions=True) + is_less_than_42(-1), # Should violate the pre-condition + return_exceptions=True, + ) assert len(results_or_errors) == 3 assert results_or_errors[0] diff --git a/tests_3_8/async/test_args_and_kwargs_in_contract.py b/tests_3_8/async/test_args_and_kwargs_in_contract.py index b22de36..866846a 100644 --- a/tests_3_8/async/test_args_and_kwargs_in_contract.py +++ b/tests_3_8/async/test_args_and_kwargs_in_contract.py @@ -28,7 +28,7 @@ async def some_func(x: int) -> None: await some_func(3) assert recorded_args is not None - self.assertTupleEqual((3, ), recorded_args) + self.assertTupleEqual((3,), recorded_args) async def test_args_with_named_and_variable_positional_arguments(self) -> None: recorded_args = None # type: Optional[Tuple[Any, ...]] @@ -94,11 +94,15 @@ async def some_func(*args: Any) -> None: assert violation_error is not None self.assertEqual( - textwrap.dedent('''\ + textwrap.dedent( + """\ len(_ARGS) > 2: _ARGS was (3,) args was 3 - len(_ARGS) was 1'''), tests.error.wo_mandatory_location(str(violation_error))) + len(_ARGS) was 1""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) class TestKwargs(unittest.IsolatedAsyncioTestCase): @@ -153,7 +157,9 @@ async def some_func(**kwargs: Any) -> None: assert recorded_kwargs is not None self.assertDictEqual({"x": 3, "y": 2, "z": 1}, recorded_kwargs) - async def test_kwargs_with_uncommon_argument_name_for_variable_keyword_arguments(self) -> None: + async def test_kwargs_with_uncommon_argument_name_for_variable_keyword_arguments( + self, + ) -> None: recorded_kwargs = None # type: Optional[Dict[str, Any]] def set_kwargs(kwargs: Dict[str, Any]) -> bool: @@ -171,7 +177,7 @@ async def some_func(**parameters: Any) -> None: self.assertDictEqual({"x": 3, "y": 2, "z": 1, "a": 0}, recorded_kwargs) async def test_fail(self) -> None: - @icontract.require(lambda _KWARGS: 'x' in _KWARGS) + @icontract.require(lambda _KWARGS: "x" in _KWARGS) async def some_func(**kwargs: Any) -> None: pass @@ -183,10 +189,14 @@ async def some_func(**kwargs: Any) -> None: assert violation_error is not None self.assertEqual( - textwrap.dedent("""\ - 'x' in _KWARGS: + textwrap.dedent( + """\ + "x" in _KWARGS: _KWARGS was {'y': 3} - y was 3"""), tests.error.wo_mandatory_location(str(violation_error))) + y was 3""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) class TestArgsAndKwargs(unittest.IsolatedAsyncioTestCase): @@ -208,7 +218,7 @@ async def some_func(*args: Any, **kwargs: Any) -> None: await some_func(5, x=10, y=20, z=30) assert recorded_args is not None - self.assertTupleEqual((5, ), recorded_args) + self.assertTupleEqual((5,), recorded_args) assert recorded_kwargs is not None self.assertDictEqual({"x": 10, "y": 20, "z": 30}, recorded_kwargs) @@ -250,8 +260,11 @@ async def some_func(*args, **kwargs) -> None: # type: ignore type_error = error assert type_error is not None - self.assertEqual('The arguments of the function call include "_ARGS" ' - 'which is a placeholder for positional arguments in a condition.', str(type_error)) + self.assertEqual( + 'The arguments of the function call include "_ARGS" ' + "which is a placeholder for positional arguments in a condition.", + str(type_error), + ) class TestConflictOnKWARGSReported(unittest.IsolatedAsyncioTestCase): @@ -267,9 +280,12 @@ async def some_func(*args, **kwargs) -> None: # type: ignore type_error = error assert type_error is not None - self.assertEqual('The arguments of the function call include "_KWARGS" ' - 'which is a placeholder for keyword arguments in a condition.', str(type_error)) + self.assertEqual( + 'The arguments of the function call include "_KWARGS" ' + "which is a placeholder for keyword arguments in a condition.", + str(type_error), + ) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests_3_8/async/test_coroutine_example.py b/tests_3_8/async/test_coroutine_example.py index 029c553..c064c19 100644 --- a/tests_3_8/async/test_coroutine_example.py +++ b/tests_3_8/async/test_coroutine_example.py @@ -29,18 +29,33 @@ class Book: author: str @icontract.require(lambda categories: a.map(has_category, categories)) - @icontract.ensure(lambda result: a.all(a.await_each(has_author(book.author) for book in result))) + @icontract.ensure( + lambda result: a.all( + a.await_each(has_author(book.author) for book in result) + ) + ) async def list_books(categories: List[str]) -> List[Book]: result = [] # type: List[Book] for category in categories: if category == "sci-fi": - result.extend([Book(identifier="The Blazing World", author="Margaret Cavendish")]) + result.extend( + [ + Book( + identifier="The Blazing World", + author="Margaret Cavendish", + ) + ] + ) elif category == "romance": - result.extend([Book(identifier="Pride and Prejudice", author="Jane Austen")]) + result.extend( + [Book(identifier="Pride and Prejudice", author="Jane Austen")] + ) else: raise AssertionError(category) return result - sci_fi_books = await list_books(categories=['sci-fi']) - self.assertListEqual(['The Blazing World'], [book.identifier for book in sci_fi_books]) + sci_fi_books = await list_books(categories=["sci-fi"]) + self.assertListEqual( + ["The Blazing World"], [book.identifier for book in sci_fi_books] + ) diff --git a/tests_3_8/async/test_exceptions.py b/tests_3_8/async/test_exceptions.py index 94af4d4..e3b46ed 100644 --- a/tests_3_8/async/test_exceptions.py +++ b/tests_3_8/async/test_exceptions.py @@ -26,7 +26,8 @@ def some_func(x: int) -> int: self.assertIsNotNone(value_error) self.assertRegex( str(value_error), - r'^Unexpected coroutine \(async\) condition <.*> for a sync function <.*\.some_func at .*>.') + r"^Unexpected coroutine \(async\) condition <.*> for a sync function <.*\.some_func at .*>.", + ) def test_postcondition(self) -> None: async def result_greater_zero(result: int) -> bool: @@ -45,7 +46,8 @@ def some_func() -> int: self.assertIsNotNone(value_error) self.assertRegex( str(value_error), - r'^Unexpected coroutine \(async\) condition <.*> for a sync function <.*\.some_func at .*>.') + r"^Unexpected coroutine \(async\) condition <.*> for a sync function <.*\.some_func at .*>.", + ) def test_snapshot(self) -> None: async def capture_len_lst(lst: List[int]) -> int: @@ -64,8 +66,10 @@ def some_func(lst: List[int]) -> None: self.assertIsNotNone(value_error) self.assertRegex( - str(value_error), r'^Unexpected coroutine \(async\) snapshot capture ' - r'for a sync function \.') + str(value_error), + r"^Unexpected coroutine \(async\) snapshot capture " + r"for a sync function \.", + ) class TestSyncFunctionConditionCoroutineFail(unittest.IsolatedAsyncioTestCase): @@ -87,7 +91,8 @@ def some_func(x: int) -> int: self.assertRegex( str(value_error), - r"^Unexpected coroutine resulting from the condition for a sync function \.$") + r"^Unexpected coroutine resulting from the condition for a sync function \.$", + ) def test_postcondition(self) -> None: async def result_greater_zero(result: int) -> bool: @@ -107,7 +112,8 @@ def some_func() -> int: self.assertRegex( str(value_error), - r"^Unexpected coroutine resulting from the condition for a sync function \.$") + r"^Unexpected coroutine resulting from the condition for a sync function \.$", + ) def test_snapshot(self) -> None: async def capture_len_lst(lst: List[int]) -> int: @@ -126,13 +132,15 @@ def some_func(lst: List[int]) -> None: assert value_error is not None self.assertRegex( - str(value_error), r'^Unexpected coroutine resulting ' - r'from the snapshot capture of a sync function .$') + str(value_error), + r"^Unexpected coroutine resulting " + r"from the snapshot capture of a sync function .$", + ) class TestAsyncInvariantsFail(unittest.IsolatedAsyncioTestCase): def test_that_async_invariants_reported(self) -> None: - async def some_async_invariant(self: 'A') -> bool: + async def some_async_invariant(self: "A") -> bool: return self.x > 0 value_error = None # type: Optional[ValueError] @@ -142,6 +150,7 @@ async def some_async_invariant(self: 'A') -> bool: class A: def __init__(self) -> None: self.x = 100 + except ValueError as error: value_error = error @@ -149,8 +158,9 @@ def __init__(self) -> None: self.assertEqual( "Async conditions are not possible in invariants as sync methods such as __init__ have to be wrapped.", - str(value_error)) + str(value_error), + ) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests_3_8/async/test_invariant.py b/tests_3_8/async/test_invariant.py index 73eeb4a..0ae436a 100644 --- a/tests_3_8/async/test_invariant.py +++ b/tests_3_8/async/test_invariant.py @@ -48,4 +48,8 @@ async def some_func(self) -> None: violation_error = err self.assertIsNotNone(violation_error) - self.assertTrue(tests.error.wo_mandatory_location(str(violation_error)).startswith('self.x > 0')) + self.assertTrue( + tests.error.wo_mandatory_location(str(violation_error)).startswith( + "self.x > 0" + ) + ) diff --git a/tests_3_8/async/test_postcondition.py b/tests_3_8/async/test_postcondition.py index ab6f9d5..73663bf 100644 --- a/tests_3_8/async/test_postcondition.py +++ b/tests_3_8/async/test_postcondition.py @@ -43,7 +43,10 @@ async def some_func() -> int: violation_error = err self.assertIsNotNone(violation_error) - self.assertEqual("result > 0: result was -100", tests.error.wo_mandatory_location(str(violation_error))) + self.assertEqual( + "result > 0: result was -100", + tests.error.wo_mandatory_location(str(violation_error)), + ) class TestAsyncFunctionAsyncCondition(unittest.IsolatedAsyncioTestCase): @@ -86,8 +89,10 @@ async def some_func() -> int: violation_error = err self.assertIsNotNone(violation_error) - self.assertEqual("result_greater_zero: result was -100", tests.error.wo_mandatory_location( - str(violation_error))) + self.assertEqual( + "result_greater_zero: result was -100", + tests.error.wo_mandatory_location(str(violation_error)), + ) class TestCoroutine(unittest.IsolatedAsyncioTestCase): @@ -105,7 +110,9 @@ async def test_fail(self) -> None: async def some_condition() -> bool: return False - @icontract.ensure(lambda: some_condition(), error=lambda: icontract.ViolationError("hihi")) + @icontract.ensure( + lambda: some_condition(), error=lambda: icontract.ViolationError("hihi") + ) async def some_func() -> None: pass @@ -118,7 +125,9 @@ async def some_func() -> None: self.assertIsNotNone(violation_error) self.assertEqual("hihi", str(violation_error)) - async def test_reported_if_no_error_is_specified_as_we_can_not_recompute_coroutine_functions(self) -> None: + async def test_reported_if_no_error_is_specified_as_we_can_not_recompute_coroutine_functions( + self, + ) -> None: async def some_condition() -> bool: return False @@ -139,8 +148,10 @@ async def some_func() -> None: value_error = runtime_error.__cause__ self.assertRegex( - str(value_error), r"^Unexpected coroutine function as a condition of a contract\. " - r"You must specify your own error if the condition of your contract is a coroutine function\.") + str(value_error), + r"^Unexpected coroutine function as a condition of a contract\. " + r"You must specify your own error if the condition of your contract is a coroutine function\.", + ) async def test_snapshot(self) -> None: async def some_capture() -> int: @@ -167,13 +178,17 @@ async def some_function(a: int) -> None: # pylint: disable=unused-variable type_err = err self.assertIsNotNone(type_err) - self.assertEqual("The argument(s) of the contract condition have not been set: ['b']. " - "Does the original function define them? Did you supply them in the call?", - tests.error.wo_mandatory_location(str(type_err))) + self.assertEqual( + "The argument(s) of the contract condition have not been set: ['b']. " + "Does the original function define them? Did you supply them in the call?", + tests.error.wo_mandatory_location(str(type_err)), + ) async def test_conflicting_result_argument(self) -> None: @icontract.ensure(lambda a, result: a > result) - async def some_function(a: int, result: int) -> None: # pylint: disable=unused-variable + async def some_function( + a: int, result: int + ) -> None: # pylint: disable=unused-variable pass type_err = None # type: Optional[TypeError] @@ -183,12 +198,17 @@ async def some_function(a: int, result: int) -> None: # pylint: disable=unused- type_err = err self.assertIsNotNone(type_err) - self.assertEqual("Unexpected argument 'result' in a function decorated with postconditions.", str(type_err)) + self.assertEqual( + "Unexpected argument 'result' in a function decorated with postconditions.", + str(type_err), + ) async def test_conflicting_OLD_argument(self) -> None: @icontract.snapshot(lambda a: a[:]) @icontract.ensure(lambda OLD, a: a == OLD.a) - async def some_function(a: List[int], OLD: int) -> None: # pylint: disable=unused-variable + async def some_function( + a: List[int], OLD: int + ) -> None: # pylint: disable=unused-variable pass type_err = None # type: Optional[TypeError] @@ -198,11 +218,18 @@ async def some_function(a: List[int], OLD: int) -> None: # pylint: disable=unus type_err = err self.assertIsNotNone(type_err) - self.assertEqual("Unexpected argument 'OLD' in a function decorated with postconditions.", str(type_err)) + self.assertEqual( + "Unexpected argument 'OLD' in a function decorated with postconditions.", + str(type_err), + ) async def test_error_with_invalid_arguments(self) -> None: @icontract.ensure( - lambda result: result > 0, error=lambda z, result: ValueError("x is {}, result is {}".format(z, result))) + lambda result: result > 0, + error=lambda z, result: ValueError( + "x is {}, result is {}".format(z, result) + ), + ) async def some_func(x: int) -> int: return x @@ -213,9 +240,11 @@ async def some_func(x: int) -> int: type_error = err self.assertIsNotNone(type_error) - self.assertEqual("The argument(s) of the contract error have not been set: ['z']. " - "Does the original function define them? Did you supply them in the call?", - tests.error.wo_mandatory_location(str(type_error))) + self.assertEqual( + "The argument(s) of the contract error have not been set: ['z']. " + "Does the original function define them? Did you supply them in the call?", + tests.error.wo_mandatory_location(str(type_error)), + ) async def test_no_boolyness(self) -> None: @icontract.ensure(lambda: tests.mock.NumpyArray([True, False])) @@ -229,9 +258,11 @@ async def some_func() -> None: value_error = err self.assertIsNotNone(value_error) - self.assertEqual('Failed to negate the evaluation of the condition.', - tests.error.wo_mandatory_location(str(value_error))) + self.assertEqual( + "Failed to negate the evaluation of the condition.", + tests.error.wo_mandatory_location(str(value_error)), + ) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests_3_8/async/test_precondition.py b/tests_3_8/async/test_precondition.py index ca3bd80..b2da5c5 100644 --- a/tests_3_8/async/test_precondition.py +++ b/tests_3_8/async/test_precondition.py @@ -31,7 +31,9 @@ async def some_func(x: int) -> int: violation_error = err self.assertIsNotNone(violation_error) - self.assertEqual("x > 0: x was -1", tests.error.wo_mandatory_location(str(violation_error))) + self.assertEqual( + "x > 0: x was -1", tests.error.wo_mandatory_location(str(violation_error)) + ) class TestAsyncFunctionAsyncCondition(unittest.IsolatedAsyncioTestCase): @@ -61,7 +63,10 @@ async def some_func(x: int) -> int: violation_error = err self.assertIsNotNone(violation_error) - self.assertEqual("x_greater_zero: x was -1", tests.error.wo_mandatory_location(str(violation_error))) + self.assertEqual( + "x_greater_zero: x was -1", + tests.error.wo_mandatory_location(str(violation_error)), + ) class TestCoroutine(unittest.IsolatedAsyncioTestCase): @@ -79,7 +84,9 @@ async def test_fail(self) -> None: async def some_condition() -> bool: return False - @icontract.require(lambda: some_condition(), error=lambda: icontract.ViolationError("hihi")) + @icontract.require( + lambda: some_condition(), error=lambda: icontract.ViolationError("hihi") + ) async def some_func() -> None: pass @@ -92,7 +99,9 @@ async def some_func() -> None: self.assertIsNotNone(violation_error) self.assertEqual("hihi", str(violation_error)) - async def test_reported_if_no_error_is_specified_as_we_can_not_recompute_coroutine_functions(self) -> None: + async def test_reported_if_no_error_is_specified_as_we_can_not_recompute_coroutine_functions( + self, + ) -> None: async def some_condition() -> bool: return False @@ -113,8 +122,10 @@ async def some_func() -> None: value_error = runtime_error.__cause__ self.assertRegex( - str(value_error), r"^Unexpected coroutine function as a condition of a contract\. " - r"You must specify your own error if the condition of your contract is a coroutine function\.") + str(value_error), + r"^Unexpected coroutine function as a condition of a contract\. " + r"You must specify your own error if the condition of your contract is a coroutine function\.", + ) if __name__ == "__main__": diff --git a/tests_3_8/async/test_recursion.py b/tests_3_8/async/test_recursion.py index b961ae0..9a08d6f 100644 --- a/tests_3_8/async/test_recursion.py +++ b/tests_3_8/async/test_recursion.py @@ -33,8 +33,16 @@ async def yet_yet_another_func() -> bool: await some_func() - self.assertListEqual(['yet_another_func', 'yet_yet_another_func', 'some_func', 'another_func', 'some_func'], - order) + self.assertListEqual( + [ + "yet_another_func", + "yet_yet_another_func", + "some_func", + "another_func", + "some_func", + ], + order, + ) async def test_recover_after_exception(self) -> None: order = [] # type: List[str] @@ -44,15 +52,19 @@ class CustomError(Exception): pass @icontract.require(lambda: another_func()) # pylint: disable=unnecessary-lambda - @icontract.require(lambda: yet_another_func()) # pylint: disable=unnecessary-lambda + @icontract.require( + lambda: yet_another_func() + ) # pylint: disable=unnecessary-lambda async def some_func() -> bool: order.append(some_func.__name__) if some_func_should_raise: - raise CustomError('some_func_should_raise') + raise CustomError("some_func_should_raise") return True @icontract.require(lambda: some_func()) # pylint: disable=unnecessary-lambda - @icontract.require(lambda: yet_yet_another_func()) # pylint: disable=unnecessary-lambda + @icontract.require( + lambda: yet_yet_another_func() + ) # pylint: disable=unnecessary-lambda async def another_func() -> bool: order.append(another_func.__name__) return True @@ -70,7 +82,9 @@ async def yet_yet_another_func() -> bool: except CustomError: pass - self.assertListEqual(['yet_another_func', 'yet_yet_another_func', 'some_func'], order) + self.assertListEqual( + ["yet_another_func", "yet_yet_another_func", "some_func"], order + ) # Reset for the next experiment order = [] @@ -78,8 +92,16 @@ async def yet_yet_another_func() -> bool: await some_func() - self.assertListEqual(['yet_another_func', 'yet_yet_another_func', 'some_func', 'another_func', 'some_func'], - order) + self.assertListEqual( + [ + "yet_another_func", + "yet_yet_another_func", + "some_func", + "another_func", + "some_func", + ], + order, + ) class TestPostcondition(unittest.IsolatedAsyncioTestCase): @@ -91,17 +113,21 @@ class CustomError(Exception): pass @icontract.ensure(lambda: another_func()) # pylint: disable=unnecessary-lambda - @icontract.ensure(lambda: yet_another_func()) # pylint: disable=unnecessary-lambda + @icontract.ensure( + lambda: yet_another_func() + ) # pylint: disable=unnecessary-lambda async def some_func() -> bool: order.append(some_func.__name__) return True @icontract.ensure(lambda: some_func()) # pylint: disable=unnecessary-lambda - @icontract.ensure(lambda: yet_yet_another_func()) # pylint: disable=unnecessary-lambda + @icontract.ensure( + lambda: yet_yet_another_func() + ) # pylint: disable=unnecessary-lambda async def another_func() -> bool: order.append(another_func.__name__) if another_func_should_raise: - raise CustomError('some_func_should_raise') + raise CustomError("some_func_should_raise") return True @@ -118,7 +144,7 @@ async def yet_yet_another_func() -> bool: except CustomError: pass - self.assertListEqual(['some_func', 'yet_another_func', 'another_func'], order) + self.assertListEqual(["some_func", "yet_another_func", "another_func"], order) # Reset for the next experiments order = [] @@ -126,20 +152,32 @@ async def yet_yet_another_func() -> bool: await some_func() - self.assertListEqual(['some_func', 'yet_another_func', 'another_func', 'yet_yet_another_func', 'some_func'], - order) + self.assertListEqual( + [ + "some_func", + "yet_another_func", + "another_func", + "yet_yet_another_func", + "some_func", + ], + order, + ) async def test_recover_after_exception(self) -> None: order = [] # type: List[str] @icontract.ensure(lambda: another_func()) # pylint: disable=unnecessary-lambda - @icontract.ensure(lambda: yet_another_func()) # pylint: disable=unnecessary-lambda + @icontract.ensure( + lambda: yet_another_func() + ) # pylint: disable=unnecessary-lambda async def some_func() -> bool: order.append(some_func.__name__) return True @icontract.ensure(lambda: some_func()) # pylint: disable=unnecessary-lambda - @icontract.ensure(lambda: yet_yet_another_func()) # pylint: disable=unnecessary-lambda + @icontract.ensure( + lambda: yet_yet_another_func() + ) # pylint: disable=unnecessary-lambda async def another_func() -> bool: order.append(another_func.__name__) return True @@ -154,9 +192,17 @@ async def yet_yet_another_func() -> bool: await some_func() - self.assertListEqual(['some_func', 'yet_another_func', 'another_func', 'yet_yet_another_func', 'some_func'], - order) + self.assertListEqual( + [ + "some_func", + "yet_another_func", + "another_func", + "yet_yet_another_func", + "some_func", + ], + order, + ) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests_3_8/test_error.py b/tests_3_8/test_error.py index b7e3677..224fa63 100644 --- a/tests_3_8/test_error.py +++ b/tests_3_8/test_error.py @@ -14,7 +14,9 @@ class TestNoneSpecified(unittest.TestCase): - def test_that_original_call_arguments_do_not_shadow_condition_variables_in_the_generated_message(self) -> None: + def test_that_original_call_arguments_do_not_shadow_condition_variables_in_the_generated_message( + self, + ) -> None: # ``y`` in the condition shadows the ``y`` in the arguments, but the condition lambda does not refer to # the original ``y``. @icontract.require(lambda x: (y := x + 3, x > 0)[1]) @@ -29,11 +31,15 @@ def some_func(x: int, y: int) -> None: assert violation_error is not None self.assertEqual( - textwrap.dedent("""\ + textwrap.dedent( + """\ (y := x + 3, x > 0)[1]: (y := x + 3, x > 0)[1] was False x was -1 - y was 2"""), tests.error.wo_mandatory_location(str(violation_error))) + y was 2""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) if __name__ == "__main__": diff --git a/tests_3_8/test_invariant.py b/tests_3_8/test_invariant.py index 6dabcbe..499c500 100644 --- a/tests_3_8/test_invariant.py +++ b/tests_3_8/test_invariant.py @@ -23,8 +23,10 @@ class RightHalfPlanePoint(NamedTuple): _ = RightHalfPlanePoint(1, 0) - self.assertEqual('Create new instance of RightHalfPlanePoint(first, second)', - RightHalfPlanePoint.__new__.__doc__) + self.assertEqual( + "Create new instance of RightHalfPlanePoint(first, second)", + RightHalfPlanePoint.__new__.__doc__, + ) class TestViolation(unittest.TestCase): @@ -47,11 +49,15 @@ class RightHalfPlanePoint(NamedTuple): self.assertIsNotNone(violation_error) self.assertEqual( - textwrap.dedent('''\ + textwrap.dedent( + """\ self.second > 0: self was RightHalfPlanePoint(first=1, second=-1) - self.second was -1'''), tests.error.wo_mandatory_location(str(violation_error))) + self.second was -1""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests_3_8/test_represent.py b/tests_3_8/test_represent.py index b6a783f..f0d7f1e 100644 --- a/tests_3_8/test_represent.py +++ b/tests_3_8/test_represent.py @@ -13,7 +13,9 @@ class TestReprValues(unittest.TestCase): def test_named_expression(self) -> None: - @icontract.require(lambda x: (t := x + 1) and t > 1) # pylint: disable=undefined-variable + @icontract.require( + lambda x: (t := x + 1) and t > 1 + ) # pylint: disable=undefined-variable def func(x: int) -> int: return x @@ -25,11 +27,15 @@ def func(x: int) -> int: self.assertIsNotNone(violation_err) self.assertEqual( - textwrap.dedent('''\ + textwrap.dedent( + """\ (t := x + 1) and t > 1: t was 1 - x was 0'''), tests.error.wo_mandatory_location(str(violation_err))) + x was 0""" + ), + tests.error.wo_mandatory_location(str(violation_err)), + ) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests_with_others/test_deal.py b/tests_with_others/test_deal.py index 279325c..5a151ca 100644 --- a/tests_with_others/test_deal.py +++ b/tests_with_others/test_deal.py @@ -51,5 +51,5 @@ def some_func(self) -> int: b.some_func() -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests_with_others/test_dpcontracts.py b/tests_with_others/test_dpcontracts.py index 2cd7e5f..67a8204 100644 --- a/tests_with_others/test_dpcontracts.py +++ b/tests_with_others/test_dpcontracts.py @@ -36,12 +36,12 @@ def yet_yet_another_func() -> bool: def test_inheritance_of_postconditions_incorrect(self) -> None: class A: - @dpcontracts.ensure('dummy contract', lambda args, result: result % 2 == 0) # type: ignore + @dpcontracts.ensure("dummy contract", lambda args, result: result % 2 == 0) # type: ignore def some_func(self) -> int: return 2 class B(A): - @dpcontracts.ensure('dummy contract', lambda args, result: result % 3 == 0) # type: ignore + @dpcontracts.ensure("dummy contract", lambda args, result: result % 3 == 0) # type: ignore def some_func(self) -> int: # The result 9 satisfies the postcondition of B.some_func, but not A.some_func. return 9 @@ -51,5 +51,5 @@ def some_func(self) -> int: b.some_func() -if __name__ == '__main__': +if __name__ == "__main__": unittest.main()