Skip to content

Commit

Permalink
Fix SSH execution not redirecting output properly
Browse files Browse the repository at this point in the history
This patch aims to save commands stdout/stderr once they are executed
inside the SSH channel. The whole implementation has been changed, in
order to customize the current SSH session object and to store both
stdout and stderr messages inside it, as well as checking for Kernel
Panic triggered by the command. In this way, the whole SSH
implementation should be also more stable.

Also, SSH tests have been improved by adding more tests, in order to
check stderr acquisition and long stdout text messages.
  • Loading branch information
acerv committed Mar 14, 2024
1 parent cf0680d commit cda4a74
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 24 deletions.
82 changes: 58 additions & 24 deletions libkirk/ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,42 @@
try:
import asyncssh
import asyncssh.misc

class MySSHClientSession(asyncssh.SSHClientSession):
"""
Custom SSHClientSession used to store stdout during execution of commands
and to check if Kernel Panic has occured in the system.
"""

def __init__(self, iobuffer: IOBuffer):
self._output = []
self._iobuffer = iobuffer
self._panic = False

def data_received(self, data, datatype):
"""
Override default data_received callback, storing stdout/stderr inside
a buffer and checking for kernel panic.
"""
self._output.append(data)

if self._iobuffer:
self._iobuffer.write(data)

if "Kernel panic" in data:
self._panic = True

def kernel_panic(self) -> bool:
"""
True if command triggered a kernel panic during its execution.
"""
return self._panic

def get_output(self) -> list:
"""
Return the list containing stored stdout/stderr messages.
"""
return self._output
except ModuleNotFoundError:
pass

Expand All @@ -42,7 +78,7 @@ def __init__(self) -> None:
self._stop = False
self._conn = None
self._downloader = None
self._procs = []
self._channels = []

@property
def name(self) -> str:
Expand Down Expand Up @@ -188,14 +224,14 @@ async def stop(self, iobuffer: IOBuffer = None) -> None:

self._stop = True
try:
if self._procs:
self._logger.info("Killing %d process(es)", len(self._procs))
if self._channels:
self._logger.info("Killing %d process(es)", len(self._channels))

for proc in self._procs:
for proc in self._channels:
proc.kill()
await proc.wait()
await proc.wait_closed()

self._procs.clear()
self._channels.clear()

if self._downloader:
await self._downloader.close()
Expand Down Expand Up @@ -244,38 +280,36 @@ async def run_command(
async with self._session_sem:
cmd = self._create_command(command, cwd, env)
ret = None
proc = None
start_t = 0
stdout = None
panic = False
channel = None
session = None

try:
self._logger.info("Running command: %s", repr(command))

proc = await self._conn.create_process(cmd)
self._procs.append(proc)
channel, session = await self._conn.create_session(
lambda: MySSHClientSession(iobuffer),
cmd
)

self._channels.append(channel)
start_t = time.time()
panic = False
stdout = ""

async for data in proc.stdout:
stdout += data

if iobuffer:
await iobuffer.write(data)
await channel.wait_closed()

if "Kernel panic" in data:
panic = True
panic = session.kernel_panic()
stdout = session.get_output()
finally:
if proc:
self._procs.remove(proc)

await proc.wait()
if channel:
self._channels.remove(channel)

ret = {
"command": command,
"returncode": proc.returncode,
"returncode": channel.get_returncode(),
"exec_time": time.time() - start_t,
"stdout": stdout
"stdout": "".join(stdout)
}

if panic:
Expand Down
26 changes: 26 additions & 0 deletions libkirk/tests/test_ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Unittests for ssh module.
"""
import os
import subprocess
import asyncio
import pytest
from libkirk.sut import IOBuffer
Expand Down Expand Up @@ -108,6 +109,31 @@ async def test_kernel_panic(self, sut):
await sut.run_command(
"echo 'Kernel panic\nThis is a generic message'")

async def test_stderr(self, sut):
"""
Test if we are correctly reading stderr.
"""
await sut.communicate()

ret = await sut.run_command(">&2 echo ciao_stderr && echo ciao_stdout")
assert ret["stdout"] == "ciao_stdout\nciao_stderr\n"

async def test_long_stdout(self, sut):
"""
Test really long stdout.
"""
await sut.communicate()

result = subprocess.run(
"tr -dc 'a-zA-Z0-9' </dev/urandom | head -c 10000",
shell=True,
capture_output=True,
text=True,
check=True)

ret = await sut.run_command(f"echo -n {result.stdout}")
assert ret["stdout"] == result.stdout


@pytest.fixture
def config_password(tmpdir):
Expand Down

0 comments on commit cda4a74

Please sign in to comment.