From c1c9607011e6ba52151d1d63359161e4e702f5b4 Mon Sep 17 00:00:00 2001 From: Bartek Sokorski Date: Wed, 25 Oct 2023 15:31:38 +0200 Subject: [PATCH] Fix script name detection --- src/cleo/commands/completions_command.py | 30 +++++++++++++++++++++--- 1 file changed, 27 insertions(+), 3 deletions(-) diff --git a/src/cleo/commands/completions_command.py b/src/cleo/commands/completions_command.py index b5c871d6..be2f8b3c 100644 --- a/src/cleo/commands/completions_command.py +++ b/src/cleo/commands/completions_command.py @@ -9,11 +9,13 @@ from typing import TYPE_CHECKING from typing import ClassVar +from typing import cast from cleo import helpers from cleo._compat import shell_quote from cleo.commands.command import Command from cleo.commands.completions.templates import TEMPLATES +from cleo.exceptions import CleoRuntimeError if TYPE_CHECKING: @@ -137,10 +139,32 @@ def render(self, shell: str) -> str: raise RuntimeError(f"Unrecognized shell: {shell}") + @staticmethod + def _get_prog_name_from_stack() -> str: + package_name = "" + frame = inspect.currentframe() + f_back = frame.f_back if frame is not None else None + f_globals = f_back.f_globals if f_back is not None else None + # break reference cycle + # https://docs.python.org/3/library/inspect.html#the-interpreter-stack + del frame + + if f_globals is not None: + package_name = cast(str, f_globals.get("__name__")) + + if package_name == "__main__": + package_name = cast(str, f_globals.get("__package__")) + + if package_name: + package_name = package_name.partition(".")[0] + + if not package_name: + raise CleoRuntimeError("Can not determine package name") + + return package_name + def _get_script_name_and_path(self) -> tuple[str, str]: - # FIXME: when generating completions via `python -m script completions`, - # we incorrectly infer `script_name` as `__main__.py` - script_name = self._io.input.script_name or inspect.stack()[-1][1] + script_name = self._io.input.script_name or self._get_prog_name_from_stack() script_path = posixpath.realpath(script_name) script_name = os.path.basename(script_path)