Skip to content

Commit

Permalink
Unify libssh/russh errors in one type.
Browse files Browse the repository at this point in the history
Also add a `.with()` method to attach information to the error which may
or may not be present.
  • Loading branch information
Tehforsch committed Oct 29, 2024
1 parent 03c3e95 commit e0b66ca
Show file tree
Hide file tree
Showing 14 changed files with 238 additions and 199 deletions.
170 changes: 170 additions & 0 deletions rust/src/nasl/builtin/ssh/error.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
use std::fmt;

use thiserror::Error;

use crate::nasl::FunctionErrorKind;

use super::SessionId;

/// A cloneable representation of the Error type of the underlying SSH lib
#[derive(Clone, Debug, PartialEq, Eq, Error)]
#[error("{0}")]
pub struct LibError(String);

#[cfg(feature = "nasl-builtin-libssh")]
impl From<libssh_rs::Error> for LibError {
fn from(e: libssh_rs::Error) -> Self {
Self(format!("{}", e))
}
}

#[cfg(not(feature = "nasl-builtin-libssh"))]
impl From<russh::Error> for LibError {
fn from(e: russh::Error) -> Self {
Self(format!("{}", e))
}
}

#[derive(Clone, Debug, PartialEq, Eq, Error)]
pub struct SshError {
pub kind: SshErrorKind,
id: Option<SessionId>,
#[source]
source: Option<LibError>,
}

pub type Result<T> = std::result::Result<T, SshError>;

impl fmt::Display for SshError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.kind)?;
if let Some(id) = self.id {
write!(f, " Session ID: {0}.", id)?;
}
if let Some(ref source) = self.source {
write!(f, " {}", source)?;
}
Ok(())
}
}

#[derive(Clone, Debug, PartialEq, Eq, Error)]
pub enum SshErrorKind {
#[error("Failed to open new SSH session.")]
NewSession,
#[error("Invalid SSH session ID.")]
InvalidSessionId,
#[error("Poisoned lock.")]
PoisonedLock,
#[error("Failed to connect.")]
Connect,
#[error("Failed to open a new channel.")]
OpenChannel,
#[error("No available channel.")]
NoAvailableChannel,
#[error("Channel unexpectedly closed.")]
ChannelClosed,
#[error("Failed to request subsystem {0}.")]
RequestSubsystem(String),
#[error("Failed to open session.")]
OpenSession,
#[error("Failed to close channel.")]
Close,
#[error("Failed to request PTY.")]
RequestPty,
#[error("Failed to request command execution.")]
RequestExec(String),
#[error("Failed to request shell.")]
RequestShell,
#[error("Failed to get server public key.")]
GetServerPublicKey,
#[error("Failed to get server banner.")]
GetServerBanner,
#[error("Failed to get issue banner.")]
GetIssueBanner,
#[error("Failed to set SSH option {0}.")]
SetOption(String),
#[error("Failed to set authentication to keyboard-interactive.")]
UserAuthKeyboardInteractiveInfo,
#[error("Failed to initiate keyboard-interactive authentication.")]
UserAuthKeyboardInteractive,
#[error("Failed to set answers for authentication via keyboard-interactive.")]
UserAuthKeyboardInteractiveSetAnswers,
#[error("Failed to authenticate via password.")]
UserAuthPassword,
#[error("Failed to perform 'none' authentication.")]
UserAuthNone,
#[error("Failed to request list of authentication methods.")]
UserAuthList,
#[error("Failed to check whether public key authentication is possible")]
UserAuthTryPublicKey,
#[error("Failed to authenticate with public key.")]
UserAuthPublicKey,
#[error("Failed to read.")]
ReadSsh,
#[error("Error initiating SFTP.")]
Sftp,
#[error("Failed to parse IP address '{0}' with error {1}.")]
InvalidIpAddr(String, std::net::AddrParseError),
#[error("Attempted to authenticate without authentication data.")]
NoAuthenticationGiven,
#[error("Error while converting private key")]
ConvertPrivateKey,
#[error("Not yet implemented.")]
Unimplemented,
}

pub trait ErrorInfo {
fn attach_error_info(self, e: SshError) -> SshError;
}

impl ErrorInfo for SessionId {
fn attach_error_info(self, mut e: SshError) -> SshError {
e.id = Some(self);
e
}
}

#[cfg(feature = "nasl-builtin-libssh")]
impl ErrorInfo for libssh_rs::Error {
fn attach_error_info(self, mut e: SshError) -> SshError {
e.source = Some(self.into());
e
}
}

#[cfg(not(feature = "nasl-builtin-libssh"))]
impl ErrorInfo for russh::Error {
fn attach_error_info(self, mut e: SshError) -> SshError {
e.source = Some(self.into());
e
}
}

impl From<SshErrorKind> for SshError {
fn from(kind: SshErrorKind) -> Self {
SshError {
kind,
source: None,
id: None,
}
}
}

impl SshErrorKind {
pub fn with(self, m: impl ErrorInfo) -> SshError {
m.attach_error_info(self.into())
}
}

impl SshError {
pub fn with(self, m: impl ErrorInfo) -> SshError {
m.attach_error_info(self)
}
}

impl From<SshError> for FunctionErrorKind {
fn from(e: SshError) -> Self {
FunctionErrorKind::Ssh(e)
}
}
11 changes: 6 additions & 5 deletions rust/src/nasl/builtin/ssh/impls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ use crate::nasl::{
utils::{IntoFunctionSet, StoredFunctionSet},
};

use super::{utils::CommaSeparated, AuthMethods, SessionId, Socket, Ssh, SshError};
use super::error::SshError;
use super::{error::SshErrorKind, utils::CommaSeparated, AuthMethods, SessionId, Socket, Ssh};

#[cfg(feature = "nasl-builtin-libssh")]
mod libssh_uses {
Expand Down Expand Up @@ -111,9 +112,9 @@ impl Ssh {
let port = port
.filter(|_| socket.is_none())
.unwrap_or(DEFAULT_SSH_PORT);
let ip = ctx
.target_ip()
.map_err(|e| SshError::InvalidIpAddr(ctx.target().to_string(), e))?;
let ip = ctx.target_ip().map_err(|e| {
SshError::from(SshErrorKind::InvalidIpAddr(ctx.target().to_string(), e))
})?;
let timeout = timeout.map(Duration::from_secs);
let keytype = keytype
.map(|keytype| keytype.0)
Expand Down Expand Up @@ -249,7 +250,7 @@ impl Ssh {
) -> Result<()> {
if password.is_none() && privatekey.is_none() && passphrase.is_none() {
//TODO: Get values from KB
return Err(SshError::NoAuthenticationGiven(session_id).into());
return Err(SshErrorKind::NoAuthenticationGiven.with(session_id).into());
}
let login = login.unwrap_or("");
let mut session = self.get_by_id(session_id).await?;
Expand Down
37 changes: 20 additions & 17 deletions rust/src/nasl/builtin/ssh/libssh/channel.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
use std::time::Duration;

use super::{
error::{Result, SshError},
SessionId,
};
use crate::nasl::builtin::ssh::error::SshErrorKind;

use super::{super::error::Result, SessionId};

pub struct Channel {
channel: libssh_rs::Channel,
Expand All @@ -20,14 +19,16 @@ impl Channel {

pub fn request_subsystem(&self, subsystem: &str) -> Result<()> {
self.channel.request_subsystem(subsystem).map_err(|e| {
SshError::RequestSubsystem(self.session_id, e.into(), subsystem.to_string())
SshErrorKind::RequestSubsystem(subsystem.to_string())
.with(e)
.with(self.session_id)
})
}

pub fn open_session(&self) -> Result<()> {
self.channel
.open_session()
.map_err(|e| SshError::OpenSession(self.session_id, e.into()))
.map_err(|e| SshErrorKind::OpenSession.with(self.session_id).with(e))
}

pub fn is_closed(&self) -> bool {
Expand All @@ -37,7 +38,7 @@ impl Channel {
pub fn close(&self) -> Result<()> {
self.channel
.close()
.map_err(|e| SshError::Close(self.session_id, e.into()))
.map_err(|e| SshErrorKind::Close.with(self.session_id).with(e))
}

pub fn stdin(&self) -> impl std::io::Write + '_ {
Expand All @@ -47,31 +48,33 @@ impl Channel {
pub fn request_pty(&self, term: &str, columns: u32, rows: u32) -> Result<()> {
self.channel
.request_pty(term, columns, rows)
.map_err(|e| SshError::RequestPty(self.session_id, e.into()))
.map_err(|e| SshErrorKind::RequestPty.with(self.session_id).with(e))
}

pub fn request_exec(&self, command: &str) -> Result<()> {
self.channel
.request_exec(command)
.map_err(|e| SshError::RequestExec(self.session_id, e.into()))
self.channel.request_exec(command).map_err(|e| {
SshErrorKind::RequestExec(command.to_string())
.with(self.session_id)
.with(e)
})
}

pub fn request_shell(&self) -> Result<()> {
self.channel
.request_shell()
.map_err(|e| SshError::RequestShell(self.session_id, e.into()))
.map_err(|e| SshErrorKind::RequestShell.with(self.session_id).with(e))
}

pub fn ensure_open(&self) -> Result<()> {
if self.is_closed() {
Err(SshError::ChannelClosed(self.session_id))
Err(SshErrorKind::ChannelClosed.with(self.session_id))
} else {
Ok(())
}
}

fn buf_as_str<'a>(&self, buf: &'a [u8]) -> Result<&'a str> {
std::str::from_utf8(buf).map_err(|_| SshError::ReadSsh(self.session_id))
std::str::from_utf8(buf).map_err(|_| SshErrorKind::ReadSsh.with(self.session_id))
}

pub fn read_timeout(&self, timeout: Duration, stderr: bool) -> Result<String> {
Expand All @@ -85,7 +88,7 @@ impl Channel {
}
Err(libssh_rs::Error::TryAgain) => {}
Err(_) => {
return Err(SshError::ReadSsh(self.session_id));
return Err(SshErrorKind::ReadSsh.with(self.session_id));
}
}
}
Expand All @@ -105,13 +108,13 @@ impl Channel {
let response = self.buf_as_str(&buf[..n])?.to_string();
Ok(response)
}
Err(_) => Err(SshError::ReadSsh(self.session_id)),
Err(_) => Err(SshErrorKind::ReadSsh.with(self.session_id)),
}
}

pub fn read_ssh_nonblocking(&self) -> Result<String> {
if self.channel.is_closed() || self.channel.is_eof() {
return Err(SshError::ReadSsh(self.session_id));
return Err(SshErrorKind::ReadSsh.with(self.session_id));
}

let stderr = self.read_nonblocking(true)?;
Expand Down
Loading

0 comments on commit e0b66ca

Please sign in to comment.