Skip to content

Commit

Permalink
Merge pull request #33 from elcaminoreal/add_error_and_subcommand
Browse files Browse the repository at this point in the history
add better error behavior and subcommand detection
  • Loading branch information
moshez authored Jan 5, 2024
2 parents 5f3cd8a + daa2cd6 commit bfb5972
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 3 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "gather"
version = "2024.1.5.1"
version = "2024.1.5.2"
description = "A gatherer"
readme = "README.rst"
authors = [{name = "Moshe Zadka", email = "[email protected]"}]
Expand Down
27 changes: 25 additions & 2 deletions src/gather/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,15 @@ def wrapped_dry_run(cmdargs, **kwargs):
args.orig_run = orig_run


def run_maybe_dry(*, parser, argv=sys.argv, env=os.environ, sp_run=subprocess.run):
def run_maybe_dry(
*,
parser,
argv=sys.argv,
env=os.environ,
sp_run=subprocess.run,
is_subcommand=False,
prefix=None,
):
"""
Run commands that only take ``args``.
Expand All @@ -128,11 +136,26 @@ def run_maybe_dry(*, parser, argv=sys.argv, env=os.environ, sp_run=subprocess.ru
* ``safe_run``: Run with logging
* ``orig_run``: Original function
"""

def error(args):
parser.print_help()
raise SystemExit(1)

argv = list(argv)
if is_subcommand:
argv[0:0] = [prefix or "base-command"]
argv[1] = argv[1].rsplit("/", 1)[-1]
if prefix is not None:
argv[1] = argv[1].removeprefix(prefix + "-")

args = parser.parse_args(argv[1:])
args.orig_run = sp_run
args.env = env
_make_safe_run(args)
command = args.__gather_command__
try:
command = args.__gather_command__
except AttributeError:
command = error
return command(
args=args,
)
Expand Down
79 changes: 79 additions & 0 deletions src/gather/tests/test_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,32 @@ def test_custom_parser(self):
contains_string("custom help message"),
)


class CommandMaybeDryTest(unittest.TestCase):

"""Test run_maybe_dry"""

def test_error(self):
"""Help message is printed out"""
parser = commands.set_parser(collected=MAYBE_DRY_COMMANDS_COLLECTOR.collect())
mock_output = mock.patch("sys.stdout", new=io.StringIO())
self.addCleanup(mock_output.stop)
fake_stdout = mock_output.start()
assert_that(
calling(commands.run_maybe_dry).with_args(
parser=parser,
argv=["command"],
env={},
sp_run=subprocess.run,
),
raises(SystemExit),
)
output = fake_stdout.getvalue()
assert_that(
output,
contains_string("usage"),
)

def test_with_dry(self):
"""Test running command in dry-run mode"""
parser = commands.set_parser(collected=MAYBE_DRY_COMMANDS_COLLECTOR.collect())
Expand Down Expand Up @@ -203,3 +229,56 @@ def test_with_dry_fail(self):
),
raises(subprocess.CalledProcessError),
)

def test_with_subcommand(self):
"""Test running command as subcommand"""
parser = commands.set_parser(collected=MAYBE_DRY_COMMANDS_COLLECTOR.collect())
with contextlib.ExitStack() as stack:
tmp_dir = pathlib.Path(stack.enter_context(tempfile.TemporaryDirectory()))
commands.run_maybe_dry(
parser=parser,
argv=[
"write-safely",
"--output-dir",
os.fspath(tmp_dir),
"--no-dry-run",
],
is_subcommand=True,
env={},
sp_run=subprocess.run,
)
contents = {child.name: child.read_text() for child in tmp_dir.iterdir()}
assert_that(
contents,
all_of(
has_entry("unsafe.txt", "2"),
has_entry("safe.txt", "2"),
),
)

def test_with_prefixed_subcommand(self):
"""Test running command as subcommand"""
parser = commands.set_parser(collected=MAYBE_DRY_COMMANDS_COLLECTOR.collect())
with contextlib.ExitStack() as stack:
tmp_dir = pathlib.Path(stack.enter_context(tempfile.TemporaryDirectory()))
commands.run_maybe_dry(
parser=parser,
argv=[
"command-write-safely",
"--output-dir",
os.fspath(tmp_dir),
"--no-dry-run",
],
is_subcommand=True,
prefix="command",
env={},
sp_run=subprocess.run,
)
contents = {child.name: child.read_text() for child in tmp_dir.iterdir()}
assert_that(
contents,
all_of(
has_entry("unsafe.txt", "2"),
has_entry("safe.txt", "2"),
),
)

0 comments on commit bfb5972

Please sign in to comment.