diff --git a/newsfragments/102.bugfix.rst b/newsfragments/102.bugfix.rst new file mode 100644 index 0000000..247e29a --- /dev/null +++ b/newsfragments/102.bugfix.rst @@ -0,0 +1 @@ +Support reading annotated variables (e.g. `__requires__: list[str] = ["jaraco.functools"]`). \ No newline at end of file diff --git a/pip_run/scripts.py b/pip_run/scripts.py index 32baff6..2c8c43d 100644 --- a/pip_run/scripts.py +++ b/pip_run/scripts.py @@ -1,6 +1,5 @@ import abc import ast -import contextlib import itertools import json import pathlib @@ -151,24 +150,28 @@ def read_python(self): >>> DepsReader(r"__requires__='foo\nbar\n#baz'").read() ['foo', 'bar'] """ - raw_reqs = suppress(ValueError)(self._read)('__requires__') or [] + raw_reqs = self._read('__requires__', default=()) reqs_items = jaraco.text.yield_lines(raw_reqs) deps = Dependencies.load(reqs_items) - with contextlib.suppress(Exception): - deps.index_url = self._read('__index_url__') + deps.index_url = self._read('__index_url__') return deps - def _read(self, var_name): + def _read(self, var_name, default=None): mod = ast.parse(self.script) - (node,) = ( - node - for node in mod.body - if isinstance(node, ast.Assign) - and len(node.targets) == 1 - and isinstance(node.targets[0], ast.Name) - and node.targets[0].id == var_name - ) - return ast.literal_eval(node.value) + code = None + for node in mod.body: + if isinstance(node, ast.Assign) and len(node.targets) == 1: + (target,) = node.targets + elif isinstance(node, ast.AnnAssign): + target = node.target + else: + continue + if isinstance(target, ast.Name) and target.id == var_name: + if code: + code = None + break + code = node.value + return ast.literal_eval(code) if code else default class SourceDepsReader(DepsReader): diff --git a/tests/test_scripts.py b/tests/test_scripts.py index 877c0a5..62ae31a 100644 --- a/tests/test_scripts.py +++ b/tests/test_scripts.py @@ -73,6 +73,31 @@ def test_single_dep(self): ) assert scripts.DepsReader(script).read() == ['foo'] + def test_single_annotated_dep(self): + script = textwrap.dedent( + """ + __requires__:str='foo' + """ + ) + assert scripts.DepsReader(script).read() == ['foo'] + + def test_multiple_annotated_deps(self): + script = textwrap.dedent( + """ + __requires__:list[str]=['foo', 'bar'] + """ + ) + assert scripts.DepsReader(script).read() == ['foo', 'bar'] + + def test_skips_on_ambiguity(self): + script = textwrap.dedent( + """ + __requires__:list[str]=['foo'] + __requires__='bar' + """ + ) + assert scripts.DepsReader(script).read() == [] + def test_index_url(self): script = textwrap.dedent( """