Skip to content

Commit

Permalink
Eliminate com crate; bump windows-rs version (#132)
Browse files Browse the repository at this point in the history
* Eliminate `com` crate; bump windows-rs version

* Improve error handling for `verify_system_version`

* rustfmt
  • Loading branch information
DrChat authored Jun 27, 2024
1 parent 73137e7 commit 06f81bd
Show file tree
Hide file tree
Showing 6 changed files with 62 additions and 231 deletions.
5 changes: 3 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,18 @@ time_rs = ["time"]
serde = [ "dep:serde", "time?/serde", "time?/serde-human-readable" ]

[dependencies]
windows = { version = "0.52", features = [
windows = { version = "0.57.0", features = [
"Win32_Foundation",
"Win32_Security_Authorization",
"Win32_System_Com",
"Win32_System_Diagnostics_Etw",
"Win32_System_LibraryLoader",
"Win32_System_Memory",
"Win32_System_Performance",
"Win32_System_SystemInformation",
"Win32_System_SystemServices",
"Win32_System_Time",
]}
com = "0.6.0"
memoffset = "0.9"
rand = "~0.8.0"
once_cell = "1.14"
Expand Down
20 changes: 13 additions & 7 deletions src/native/evntrace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,8 @@ where
PCWSTR::from_raw(properties.trace_name_array().as_ptr()),
properties.as_mut_ptr(),
)
};
}
.ok();

if let Err(status) = status {
let code = status.code();
Expand Down Expand Up @@ -258,7 +259,8 @@ pub(crate) fn enable_provider(
0,
Some(parameters.as_ptr()),
)
};
}
.ok();

res.map_err(|err| {
EvntraceNativeError::IoError(std::io::Error::from_raw_os_error(err.code().0))
Expand All @@ -280,7 +282,8 @@ pub(crate) fn process_trace(trace_handle: TraceHandle) -> EvntraceNativeResult<(
// * for real-time traces, this means we might process a few events already waiting in the buffers when the processing is starting. This is fine, I suppose.
let mut start = FILETIME::default();
Etw::ProcessTrace(&[trace_handle], Some(&mut start as *mut FILETIME), None)
};
}
.ok();

result.map_err(|err| {
EvntraceNativeError::IoError(std::io::Error::from_raw_os_error(err.code().0))
Expand Down Expand Up @@ -313,7 +316,8 @@ pub(crate) fn control_trace(
properties.as_mut_ptr(),
control_code,
)
};
}
.ok();

result.map_err(|err| {
EvntraceNativeError::IoError(std::io::Error::from_raw_os_error(err.code().0))
Expand All @@ -337,7 +341,8 @@ pub(crate) fn control_trace_by_name(
properties.as_mut_ptr(),
control_code,
)
};
}
.ok();

result.map_err(|err| {
EvntraceNativeError::IoError(std::io::Error::from_raw_os_error(err.code().0))
Expand All @@ -363,7 +368,7 @@ pub(crate) fn close_trace(
UNIQUE_VALID_CONTEXTS
.remove(callback_data.as_ref() as *const Arc<CallbackData> as *const c_void);

let status = unsafe { Etw::CloseTrace(handle) };
let status = unsafe { Etw::CloseTrace(handle) }.ok();

match status {
Ok(()) => Ok(false),
Expand All @@ -386,7 +391,8 @@ pub(crate) fn query_info(class: TraceInformation, buf: &mut [u8]) -> EvntraceNat
buf.len() as u32,
None,
)
};
}
.ok();

result.map_err(|err| {
EvntraceNativeError::IoError(std::io::Error::from_raw_os_error(err.code().0))
Expand Down
214 changes: 26 additions & 188 deletions src/native/pla.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,227 +5,65 @@
//!
//! This module shouldn't be accessed directly. Modules from the the crate level provide a safe API to interact
//! with the crate
use std::mem::MaybeUninit;
use windows::core::{BSTR, GUID};
use windows::{
core::{GUID, VARIANT},
Win32::System::{
Com::{CoCreateInstance, CoInitializeEx, CLSCTX_ALL, COINIT_MULTITHREADED},
Performance::{ITraceDataProviderCollection, TraceDataProviderCollection},
},
};

/// Pla native module errors
#[derive(Debug, PartialEq, Eq)]
pub enum PlaError {
/// Represents a Provider not found Error
NotFound,
/// Represents an HRESULT common error
ComHResultError(HResult),
ComError(windows::core::Error),
}

/// Wrapper over common HRESULT native errors (Incomplete)
#[derive(Debug, PartialEq, Eq)]
pub enum HResult {
/// Represents S_OK
HrOk,
/// Represents E_ABORT
HrAbort,
/// Represents E_ACCESSDENIED
HrAccessDenied,
/// Represents E_FAIL
HrFail,
/// Represents E_INVALIDARG
HrInvalidArg,
/// Represents E_OUTOFMEMORY
HrOutOfMemory,
/// Represent an HRESULT not implemented in the Wrapper
NotImplemented(i32),
}

impl From<i32> for HResult {
fn from(hr: i32) -> HResult {
match hr {
0x0 => HResult::HrOk,
-2147467260 => HResult::HrAbort,
-2147024891 => HResult::HrAccessDenied,
-2147467259 => HResult::HrFail,
-2147024809 => HResult::HrInvalidArg,
-2147024882 => HResult::HrOutOfMemory,
_ => HResult::NotImplemented(hr),
}
}
}

impl From<i32> for PlaError {
fn from(val: i32) -> PlaError {
PlaError::ComHResultError(HResult::from(val))
impl From<windows::core::Error> for PlaError {
fn from(val: windows::core::Error) -> PlaError {
PlaError::ComError(val)
}
}

pub(crate) type ProvidersComResult<T> = Result<T, PlaError>;

const VT_UI4: u16 = 0x13;
// We are just going to use VT_UI4 so we won't bother replicating the full VARIANT struct
// Not using Win32::Automation::VARIANT for commodity
#[repr(C)]
#[doc(hidden)]
#[derive(Debug, Default, Clone, Copy)]
pub struct Variant {
vt: u16,
w_reserved1: u16,
w_reserved2: u16,
w_reserved3: u16,
val: u32,
}

impl Variant {
pub fn new(vt: u16, val: u32) -> Self {
Variant {
vt,
val,
..Default::default()
}
}

pub fn increment_val(&mut self) {
self.val += 1;
}
pub fn get_val(&self) -> u32 {
self.val
}
}

fn check_hr(hr: i32) -> ProvidersComResult<()> {
let res = HResult::from(hr);
if res != HResult::HrOk {
return Err(PlaError::ComHResultError(res));
}

Ok(())
}

// https://github.com/microsoft/krabsetw/blob/31679cf84bc85360158672699f2f68a821e8a6d0/krabs/krabs/provider.hpp#L487
pub(crate) unsafe fn get_provider_guid(name: &str) -> ProvidersComResult<GUID> {
com::runtime::init_runtime()?;

let all_providers = com::runtime::create_instance::<
pla_interfaces::ITraceDataProviderCollection,
>(&pla_interfaces::CLSID_TRACE_DATA_PROV_COLLECTION)?;
// FIXME: This is not paired with a call to CoUninitialize, so this will leak COM resources.
unsafe { CoInitializeEx(None, COINIT_MULTITHREADED) }.ok()?;

let mut guid: MaybeUninit<GUID> = MaybeUninit::uninit();
let mut hr = all_providers.get_trace_data_providers(BSTR::from(""));
check_hr(hr)?;
let all_providers: ITraceDataProviderCollection =
unsafe { CoCreateInstance(&TraceDataProviderCollection, None, CLSCTX_ALL) }?;

// could we assume count is unsigned... let's trust that count won't be negative
let mut count = 0;
hr = all_providers.get_count(&mut count);
check_hr(hr)?;
all_providers.GetTraceDataProviders(None)?;

let mut index = Variant::new(VT_UI4, 0);
while index.get_val() < count as u32 {
let mut provider = None;
let count = all_providers.Count()? as u32;

hr = all_providers.get_item(index, &mut provider);
check_hr(hr)?;
let mut index = 0u32;
let mut guid = None;

// We can safely unwrap after check_hr
let mut raw_name: MaybeUninit<BSTR> = MaybeUninit::uninit();
let provider = provider.unwrap();
provider.get_display_name(raw_name.as_mut_ptr());
check_hr(hr)?;
while index < count as u32 {
let provider = all_providers.get_Item(&VARIANT::from(index))?;
let raw_name = provider.DisplayName()?;

let raw_name = raw_name.assume_init();
let prov_name = String::from_utf16_lossy(raw_name.as_wide());

index.increment_val();
index += 1;
// check if matches, if it does get guid and break
if prov_name.eq(name) {
hr = provider.get_guid(guid.as_mut_ptr());
check_hr(hr)?;
guid = Some(provider.Guid()?);
break;
}
}

if index.get_val() == count as u32 {
if index == count as u32 {
return Err(PlaError::NotFound);
}

// we can assume the guid is init if we reached this point eoc would return Error
Ok(guid.assume_init())
}

mod pla_interfaces {
use super::{Variant, BSTR, GUID};
use com::sys::IID;
use com::{interfaces, interfaces::iunknown::IUnknown, sys::HRESULT};

interfaces! {
// functions parameters not defined unless necessary
#[uuid("00020400-0000-0000-C000-000000000046")]
pub unsafe interface IDispatch: IUnknown {
pub fn get_type_info_count(&self) -> HRESULT;
pub fn get_type_info(&self) -> HRESULT;
pub fn get_ids_of_names(&self) -> HRESULT;
pub fn invoke(&self) -> HRESULT;
}

// pla.h
#[uuid("03837510-098b-11d8-9414-505054503030")]
pub unsafe interface ITraceDataProviderCollection: IDispatch {
pub fn get_count(&self, retval: *mut i32) -> HRESULT;
pub fn get_item(
&self,
#[pass_through]
index: Variant,
provider: *mut Option<ITraceDataProvider>,
) -> HRESULT;
pub fn get__new_enum(&self) -> HRESULT;
pub fn add(&self) -> HRESULT;
pub fn remove(&self) -> HRESULT;
pub fn clear(&self) -> HRESULT;
pub fn add_range(&self) -> HRESULT;
pub fn create_trace_data_provider(&self) -> HRESULT;
pub fn get_trace_data_providers(
&self,
#[pass_through]
server: BSTR
) -> HRESULT;
pub fn get_trace_data_providers_by_process(&self) -> HRESULT;
}

#[uuid("03837512-098b-11d8-9414-505054503030")]
pub unsafe interface ITraceDataProvider: IDispatch {
pub fn get_display_name(
&self,
#[pass_through]
name: *mut BSTR
) -> HRESULT;
pub fn put_display_name(&self) -> HRESULT;
pub fn get_guid(
&self,
#[pass_through]
guid: *mut GUID
) -> HRESULT;
pub fn put_guid(&self) -> HRESULT;
pub fn get_level(&self) -> HRESULT;
pub fn get_keywords_any(&self) -> HRESULT;
pub fn get_keywords_all(&self) -> HRESULT;
pub fn get_properties(&self) -> HRESULT;
pub fn get_filter_enabled(&self) -> HRESULT;
pub fn put_filter_enabled(&self) -> HRESULT;
pub fn get_filter_type(&self) -> HRESULT;
pub fn put_filter_type(&self) -> HRESULT;
pub fn get_filter_data(&self) -> HRESULT;
pub fn put_filter_data(&self) -> HRESULT;
pub fn query(&self) -> HRESULT;
pub fn resolve(&self) -> HRESULT;
pub fn set_security(&self) -> HRESULT;
pub fn get_security(&self) -> HRESULT;
pub fn get_registered_processes(&self) -> HRESULT;
}
}

// 03837511-098b-11d8-9414-505054503030
pub const CLSID_TRACE_DATA_PROV_COLLECTION: IID = IID {
data1: 0x03837511,
data2: 0x098b,
data3: 0x11d8,
data4: [0x94, 0x14, 0x50, 0x50, 0x54, 0x50, 0x30, 0x30],
};
Ok(guid.unwrap())
}

#[cfg(test)]
Expand Down
24 changes: 4 additions & 20 deletions src/native/sddl.rs
Original file line number Diff line number Diff line change
@@ -1,27 +1,9 @@
use core::ffi::c_void;
use std::str::Utf8Error;
use windows::core::PSTR;
use windows::Win32::Foundation::{HLOCAL, PSID};
use windows::Win32::Foundation::{LocalFree, HLOCAL, PSID};
use windows::Win32::Security::Authorization::ConvertSidToStringSidA;

// N.B windows-rs has an incorrect implementation for local free
// https://github.com/microsoft/windows-rs/issues/2488
#[allow(non_snake_case)]
pub unsafe fn LocalFree<P0>(hmem: P0) -> ::windows::core::Result<HLOCAL>
where
P0: ::windows::core::IntoParam<HLOCAL>,
{
#[link(name = "kernel32")]
extern "system" {
fn LocalFree(hmem: HLOCAL) -> HLOCAL;
}
let res = LocalFree(hmem.into_param().abi());
match res.0 as usize {
0 => Ok(res),
_ => Err(::windows::core::Error::from_win32()),
}
}

/// SDDL native error
#[derive(Debug)]
pub enum SddlNativeError {
Expand Down Expand Up @@ -58,7 +40,9 @@ pub fn convert_sid_to_string(sid: *const c_void) -> SddlResult<String> {

let sid_string = std::ffi::CStr::from_ptr(tmp.0.cast()).to_str()?.to_owned();

LocalFree(HLOCAL(tmp.0.cast())).map_err(|e| SddlNativeError::IoError(e.into()))?;
if LocalFree(HLOCAL(tmp.0.cast())) != HLOCAL(std::ptr::null_mut()) {
return Err(SddlNativeError::IoError(std::io::Error::last_os_error()));
}

Ok(sid_string)
}
Expand Down
Loading

0 comments on commit 06f81bd

Please sign in to comment.