Skip to content

Commit

Permalink
fix(misc): Fixed unreasonable parts in the pull request
Browse files Browse the repository at this point in the history
1. add comment to `ossl_init` referring to rats-tls repository
2. inline `init` method of `Client` and `Server` to each
Builders' method, remove unused Option
3. remove `Arc<Mutex<Cell<>>` wrapper for openssl CRYPTO index
since it never change after initialized
4. move `GetFdDumpImpl` to test module as a stream mock
  • Loading branch information
csyJoy authored and imlk0 committed Sep 2, 2024
1 parent 8680139 commit fe1e961
Show file tree
Hide file tree
Showing 4 changed files with 141 additions and 185 deletions.
2 changes: 2 additions & 0 deletions rats-rs/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ pub enum ErrorKind {

SpdmlibError,

OsslTlsBuilderStreamUnset,

OsslUnsupportedPkeyAlgo,

OsslCtxUninitialize,
Expand Down
148 changes: 64 additions & 84 deletions rats-rs/src/transport/tls/client.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use super::{
as_raw, as_raw_mut, ossl_init, verify_certificate_default, EpvPkey, GetFd, GetFdDumpImpl,
SslMode, TcpWrapper, OPENSSL_EX_DATA_IDX,
as_raw, as_raw_mut, ossl_init, verify_certificate_default, EpvPkey, GetFd, SslMode, TcpWrapper,
OPENSSL_EX_DATA_IDX,
};
use crate::{
cert::{
Expand All @@ -25,12 +25,12 @@ use std::{
ffi::c_int,
net::{TcpStream, ToSocketAddrs},
os::fd::AsRawFd,
ptr,
ptr::{self, null_mut},
sync::{Arc, Mutex},
};

pub struct Client {
ctx: Option<*mut SSL_CTX>,
ctx: *mut SSL_CTX,
ssl_session: Option<*mut SSL>,
verify_callback: SSL_verify_cb,
stream: Box<dyn GetFd>,
Expand All @@ -39,28 +39,42 @@ pub struct Client {

pub struct TlsClientBuilder {
verify: SSL_verify_cb,
stream: Box<dyn GetFd>,
stream: Option<Box<dyn GetFd>>,
attest_self: bool,
}

impl TlsClientBuilder {
pub fn build(self) -> Result<Client> {
ossl_init()?;
let ctx = unsafe { SSL_CTX_new(TLS_client_method()) };
if ctx.is_null() {
return Err(Error::kind(ErrorKind::OsslCtxInitializeFail));
}
let mut c = Client {
ctx: None,
ctx: ctx,
ssl_session: None,
verify_callback: Some(self.verify.unwrap_or(verify_certificate_default)),
stream: self.stream,
stream: self
.stream
.ok_or(Error::kind(ErrorKind::OsslTlsBuilderStreamUnset))?,
attest_self: self.attest_self,
};
c.init()?;
if c.attest_self {
let privkey = DefaultCrypto::gen_private_key(crate::crypto::AsymmetricAlgo::Rsa2048)?;
c.use_privkey(&privkey)?;
let cert = CertBuilder::new(AutoAttester::new(), HashAlgo::Sha256)
.build_with_private_key(&privkey)?
.cert_to_der()?;
c.use_cert(&cert)?;
}
Ok(c)
}
pub fn with_verify(mut self, verify: SSL_verify_cb) -> Self {
self.verify = verify;
self
}
pub fn with_tcp_stream(mut self, stream: TcpStream) -> Self {
self.stream = Box::new(TcpWrapper(stream));
self.stream = Some(Box::new(TcpWrapper(stream)));
self
}
pub fn with_attest_self(mut self, attest_self: bool) -> Self {
Expand All @@ -70,7 +84,7 @@ impl TlsClientBuilder {
pub fn new() -> Self {
Self {
verify: None,
stream: Box::new(GetFdDumpImpl {}),
stream: None,
attest_self: false,
}
}
Expand All @@ -79,7 +93,7 @@ impl TlsClientBuilder {
#[maybe_async]
impl GenericSecureTransPortWrite for Client {
async fn send(&mut self, bytes: &[u8]) -> Result<()> {
if self.ctx.is_none() || self.ssl_session.is_none() {
if self.ssl_session.is_none() {
return Err(Error::kind(ErrorKind::OsslCtxOrSessionUninitialized));
}
let res = unsafe {
Expand All @@ -99,22 +113,29 @@ impl GenericSecureTransPortWrite for Client {
if let Some(ssl_session) = self.ssl_session {
unsafe {
SSL_shutdown(ssl_session);
SSL_free(ssl_session);
}
}
if let Some(ctx) = self.ctx {
Ok(())
}
}

impl Drop for Client {
fn drop(&mut self) {
if let Some(ssl_session) = self.ssl_session {
unsafe {
SSL_CTX_free(ctx);
SSL_free(ssl_session);
}
}
Ok(())
unsafe {
SSL_CTX_free(self.ctx);
}
}
}

#[maybe_async]
impl GenericSecureTransPortRead for Client {
async fn receive(&mut self, buf: &mut [u8]) -> Result<usize> {
if self.ctx.is_none() || self.ssl_session.is_none() {
if self.ssl_session.is_none() {
return Err(Error::kind(ErrorKind::OsslCtxOrSessionUninitialized));
}
let res = unsafe {
Expand All @@ -134,9 +155,7 @@ impl GenericSecureTransPortRead for Client {
#[maybe_async]
impl GenericSecureTransPort for Client {
async fn negotiate(&mut self) -> Result<()> {
let ctx = self
.ctx
.ok_or(Error::kind(ErrorKind::OsslCtxUninitialize))?;
let ctx = self.ctx;
if self.verify_callback.is_some() {
let mode = SslMode::SSL_VERIFY_PEER;
unsafe {
Expand All @@ -150,7 +169,7 @@ impl GenericSecureTransPort for Client {
unsafe {
X509_STORE_set_ex_data(
SSL_CTX_get_cert_store(ctx),
OPENSSL_EX_DATA_IDX.lock().unwrap().get(),
*OPENSSL_EX_DATA_IDX,
as_raw_mut(self),
);
}
Expand All @@ -174,23 +193,6 @@ impl GenericSecureTransPort for Client {
}

impl Client {
pub fn init(&mut self) -> Result<()> {
ossl_init()?;
let ctx = unsafe { SSL_CTX_new(TLS_client_method()) };
if ctx.is_null() {
return Err(Error::kind(ErrorKind::OsslCtxInitializeFail));
}
self.ctx = Some(ctx);
if self.attest_self {
let privkey = DefaultCrypto::gen_private_key(crate::crypto::AsymmetricAlgo::Rsa2048)?;
self.use_privkey(&privkey)?;
let cert = CertBuilder::new(AutoAttester::new(), HashAlgo::Sha256)
.build_with_private_key(&privkey)?
.cert_to_der()?;
self.use_cert(&cert)?;
}
Ok(())
}
pub fn use_privkey(&mut self, privkey: &AsymmetricPrivateKey) -> Result<()> {
let pkey;
let epkey: ::libc::c_int;
Expand All @@ -206,9 +208,7 @@ impl Client {
epkey = EpvPkey::EC.bits();
}
}
let ctx = self
.ctx
.ok_or(Error::kind(ErrorKind::OsslCtxUninitialize))?;
let ctx = self.ctx;
let pkey_len = pkey.as_bytes().len() as ::libc::c_long;
let pkey_buffer = as_raw(&pkey.as_bytes()[0]);
unsafe {
Expand All @@ -220,9 +220,7 @@ impl Client {
Ok(())
}
pub fn use_cert(&mut self, cert: &Vec<u8>) -> Result<()> {
let ctx = self
.ctx
.ok_or(Error::kind(ErrorKind::OsslCtxUninitialize))?;
let ctx = self.ctx;
let res = unsafe {
SSL_CTX_use_certificate_ASN1(
ctx,
Expand All @@ -239,49 +237,41 @@ impl Client {

#[cfg(test)]
mod tests {
use super::Client;
use super::{Client, TlsClientBuilder};
use crate::{
cert::create::CertBuilder,
crypto::{AsymmetricAlgo, DefaultCrypto, HashAlgo},
errors::*,
tee::{AutoAttester, AutoVerifier},
transport::{
tls::{as_raw, as_raw_mut, GetFdDumpImpl},
tls::{as_raw, as_raw_mut, ossl_init, GetFd},
GenericSecureTransPortWrite,
},
};
use openssl_sys::*;
use std::{ptr, slice};
use std::{
net::TcpStream,
ptr::{self, null_mut},
slice,
};

#[test]
fn test_client_init() -> Result<()> {
let mut c = Client {
ctx: None,
ssl_session: None,
verify_callback: None,
stream: Box::new(GetFdDumpImpl),
attest_self: false,
};
c.init()?;
Ok(())
struct GetFdDumpImpl;
impl GetFd for GetFdDumpImpl {
fn get_fd(&self) -> i32 {
0
}
}

#[test]
fn test_client_shutdown() -> Result<()> {
let mut c = Client {
ctx: None,
ssl_session: None,
verify_callback: None,
stream: Box::new(GetFdDumpImpl),
attest_self: false,
};
c.init()?;
let mut builder = TlsClientBuilder::new();
builder.stream = Some(Box::new(GetFdDumpImpl));
let mut c = builder.build()?;
c.shutdown()?;
Ok(())
}

fn ossl_get_privkey(c: &mut Client) -> Vec<u8> {
let ssl_session = unsafe { SSL_new(c.ctx.unwrap()) };
let ssl_session = unsafe { SSL_new(c.ctx) };
let pkey = unsafe { SSL_get_privatekey(ssl_session) };
let bio = unsafe { BIO_new(BIO_s_mem()) };
let res = unsafe {
Expand Down Expand Up @@ -309,14 +299,9 @@ mod tests {

#[test]
fn test_client_use_key() -> Result<()> {
let mut c = Client {
ctx: None,
ssl_session: None,
verify_callback: None,
stream: Box::new(GetFdDumpImpl),
attest_self: false,
};
c.init()?;
let mut builder = TlsClientBuilder::new();
builder.stream = Some(Box::new(GetFdDumpImpl));
let mut c = builder.build()?;
let privkey = DefaultCrypto::gen_private_key(AsymmetricAlgo::Rsa2048)?;
let binding = privkey.to_pkcs8_pem()?;
let privpem = binding.as_bytes();
Expand All @@ -329,20 +314,15 @@ mod tests {

#[test]
fn test_client_use_cert() -> Result<()> {
let mut c = Client {
ctx: None,
ssl_session: None,
verify_callback: None,
stream: Box::new(GetFdDumpImpl),
attest_self: false,
};
c.init()?;
let mut builder = TlsClientBuilder::new();
builder.stream = Some(Box::new(GetFdDumpImpl));
let mut c = builder.build()?;
let privkey = DefaultCrypto::gen_private_key(AsymmetricAlgo::Rsa2048)?;
let cert = CertBuilder::new(AutoAttester::new(), HashAlgo::Sha256)
.build_with_private_key(&privkey)?
.cert_to_der()?;
c.use_cert(&cert)?;
let raw_cert = unsafe { SSL_CTX_get0_certificate(c.ctx.unwrap()) };
let raw_cert = unsafe { SSL_CTX_get0_certificate(c.ctx) };
let mut raw_ptr = ptr::null_mut::<u8>();
let len = unsafe { i2d_X509(raw_cert, &mut raw_ptr as *mut *mut u8) };
let now = unsafe { slice::from_raw_parts(raw_ptr as *const u8, len as usize).to_vec() };
Expand Down
33 changes: 11 additions & 22 deletions rats-rs/src/transport/tls/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,23 @@ mod server;
use super::{Error, ErrorKind, Result};
use crate::{
cert::{
create::CertBuilder,
dice::extensions::{OID_TCG_DICE_ENDORSEMENT_MANIFEST, OID_TCG_DICE_TAGGED_EVIDENCE},
verify::{
verify_cert_der, CertVerifier, VerifiyPolicy, VerifiyPolicy::Contains,
verify_cert_der, CertVerifier,
VerifiyPolicy::{self, Contains},
VerifyPolicyOutput,
},
},
tee::claims::Claims,
crypto::{AsymmetricAlgo, AsymmetricPrivateKey, DefaultCrypto, HashAlgo},
tee::{claims::Claims, AutoAttester},
};
use bitflags::bitflags;
pub use client::{Client, TlsClientBuilder};
use lazy_static::lazy_static;
use libc::c_int;
use libc::*;
use log::{debug, error};
use openssl_sys::*;
pub use openssl_sys::*;
use pkcs8::ObjectIdentifier;
pub use server::{Server, TlsServerBuilder};
use std::{
Expand All @@ -29,16 +32,8 @@ use std::{
};

lazy_static! {
static ref OPENSSL_EX_DATA_IDX: Arc<Mutex<Cell<i32>>> = unsafe {
Arc::new(Mutex::new(Cell::new(CRYPTO_get_ex_new_index(
4,
0,
ptr::null_mut(),
None,
None,
None,
))))
};
static ref OPENSSL_EX_DATA_IDX: i32 =
unsafe { CRYPTO_get_ex_new_index(4, 0, ptr::null_mut(), None, None, None,) };
}

static START: Once = Once::new();
Expand All @@ -47,14 +42,6 @@ trait GetFd {
fn get_fd(&self) -> i32;
}

struct GetFdDumpImpl;

impl GetFd for GetFdDumpImpl {
fn get_fd(&self) -> i32 {
0
}
}

struct TcpWrapper(TcpStream);

impl GetFd for TcpWrapper {
Expand Down Expand Up @@ -98,6 +85,8 @@ bitflags! {
}
}

// Initialize OpenSSL, referring to the rat-tls repository
// https://github.com/inclavare-containers/rats-tls/blob/cf5e911a480f7120da480f046417a209e222e101/src/tls_wrappers/openssl/init.c#L11
pub fn ossl_init() -> Result<()> {
START.call_once(|| unsafe {
OPENSSL_init_crypto(
Expand Down
Loading

0 comments on commit fe1e961

Please sign in to comment.