diff --git a/pyproject.toml b/pyproject.toml index 03b05b9c2..df25eeed5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "truss" -version = "0.9.25" +version = "0.9.26" description = "A seamless bridge from model development to model delivery" license = "MIT" readme = "README.md" diff --git a/truss/cli/cli.py b/truss/cli/cli.py index 5ef353e50..bd2efa4fd 100644 --- a/truss/cli/cli.py +++ b/truss/cli/cli.py @@ -728,8 +728,6 @@ def predict( @click.argument("script", required=True) @click.argument("target_directory", required=False, default=os.getcwd()) def run_python(script, target_directory): - from python_on_whales.exceptions import DockerException - if not Path(script).exists(): raise click.BadParameter( f"File {script} does not exist. Please provide a valid file." @@ -744,20 +742,19 @@ def run_python(script, target_directory): ) tr = _get_truss_from_directory(target_directory=target_directory) - output_stream = tr.run_python_script(Path(script)) - try: - for output in output_stream: - output_type = output[0] - output_content = output[1] + container = tr.run_python_script(Path(script)) + for output in container.logs(): + output_type = output[0] + output_content = output[1] - options = {} + options = {} - if output_type == "stderr": - options["fg"] = "red" + if output_type == "stderr": + options["fg"] = "red" - click.secho(output_content.decode("utf-8", "replace"), nl=False, **options) - except DockerException: - pass + click.secho(output_content.decode("utf-8", "replace"), nl=False, **options) + exit_code = container.wait() + sys.exit(exit_code) @truss_cli.command() diff --git a/truss/truss_handle.py b/truss/truss_handle.py index 453275d93..0322bb33e 100644 --- a/truss/truss_handle.py +++ b/truss/truss_handle.py @@ -78,6 +78,21 @@ logger.addHandler(logging.StreamHandler(sys.stdout)) +class RunningContainer: + def __init__(self, container): + self.container = container + + def logs(self): + from python_on_whales import docker + + return docker.logs(self.container, follow=True, stream=True) + + def wait(self): + from python_on_whales import docker + + return docker.wait(self.container) + + class TrussHandle: def __init__(self, truss_dir: Path, validate: bool = True) -> None: self._truss_dir = truss_dir @@ -198,7 +213,7 @@ def _docker_run(gpus: Optional[str] = None): add_hosts=[("host.docker.internal", "host-gateway")], ) - return Docker.client().logs(container, follow=True, stream=True) + return RunningContainer(container) try: return _docker_run("all" if self._spec.config.resources.use_gpu else None)