From 65e92502bfcb43031eca9537f00172a8625a4317 Mon Sep 17 00:00:00 2001 From: "Dr. Chat" Date: Sun, 23 Jun 2024 13:32:06 -0500 Subject: [PATCH 1/3] Eliminate `com` crate; bump windows-rs version --- Cargo.toml | 5 +- src/native/evntrace.rs | 20 ++-- src/native/pla.rs | 214 +++++------------------------------ src/native/sddl.rs | 24 +--- src/native/version_helper.rs | 2 +- 5 files changed, 47 insertions(+), 218 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index b15a634..9fa833b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,17 +17,18 @@ time_rs = ["time"] serde = [ "dep:serde", "time?/serde", "time?/serde-human-readable" ] [dependencies] -windows = { version = "0.52", features = [ +windows = { version = "0.57.0", features = [ "Win32_Foundation", "Win32_Security_Authorization", + "Win32_System_Com", "Win32_System_Diagnostics_Etw", "Win32_System_LibraryLoader", "Win32_System_Memory", + "Win32_System_Performance", "Win32_System_SystemInformation", "Win32_System_SystemServices", "Win32_System_Time", ]} -com = "0.6.0" memoffset = "0.9" rand = "~0.8.0" once_cell = "1.14" diff --git a/src/native/evntrace.rs b/src/native/evntrace.rs index 4d935e5..0e3f5fa 100644 --- a/src/native/evntrace.rs +++ b/src/native/evntrace.rs @@ -174,7 +174,8 @@ where PCWSTR::from_raw(properties.trace_name_array().as_ptr()), properties.as_mut_ptr(), ) - }; + } + .ok(); if let Err(status) = status { let code = status.code(); @@ -258,7 +259,8 @@ pub(crate) fn enable_provider( 0, Some(parameters.as_ptr()), ) - }; + } + .ok(); res.map_err(|err| { EvntraceNativeError::IoError(std::io::Error::from_raw_os_error(err.code().0)) @@ -280,7 +282,8 @@ pub(crate) fn process_trace(trace_handle: TraceHandle) -> EvntraceNativeResult<( // * for real-time traces, this means we might process a few events already waiting in the buffers when the processing is starting. This is fine, I suppose. let mut start = FILETIME::default(); Etw::ProcessTrace(&[trace_handle], Some(&mut start as *mut FILETIME), None) - }; + } + .ok(); result.map_err(|err| { EvntraceNativeError::IoError(std::io::Error::from_raw_os_error(err.code().0)) @@ -313,7 +316,8 @@ pub(crate) fn control_trace( properties.as_mut_ptr(), control_code, ) - }; + } + .ok(); result.map_err(|err| { EvntraceNativeError::IoError(std::io::Error::from_raw_os_error(err.code().0)) @@ -337,7 +341,8 @@ pub(crate) fn control_trace_by_name( properties.as_mut_ptr(), control_code, ) - }; + } + .ok(); result.map_err(|err| { EvntraceNativeError::IoError(std::io::Error::from_raw_os_error(err.code().0)) @@ -363,7 +368,7 @@ pub(crate) fn close_trace( UNIQUE_VALID_CONTEXTS .remove(callback_data.as_ref() as *const Arc as *const c_void); - let status = unsafe { Etw::CloseTrace(handle) }; + let status = unsafe { Etw::CloseTrace(handle) }.ok(); match status { Ok(()) => Ok(false), @@ -386,7 +391,8 @@ pub(crate) fn query_info(class: TraceInformation, buf: &mut [u8]) -> EvntraceNat buf.len() as u32, None, ) - }; + } + .ok(); result.map_err(|err| { EvntraceNativeError::IoError(std::io::Error::from_raw_os_error(err.code().0)) diff --git a/src/native/pla.rs b/src/native/pla.rs index fa2bb40..68fe58c 100644 --- a/src/native/pla.rs +++ b/src/native/pla.rs @@ -5,8 +5,13 @@ //! //! This module shouldn't be accessed directly. Modules from the the crate level provide a safe API to interact //! with the crate -use std::mem::MaybeUninit; -use windows::core::{BSTR, GUID}; +use windows::{ + core::{GUID, VARIANT}, + Win32::System::{ + Com::{CoCreateInstance, CoInitializeEx, CLSCTX_ALL, COINIT_MULTITHREADED}, + Performance::{ITraceDataProviderCollection, TraceDataProviderCollection}, + }, +}; /// Pla native module errors #[derive(Debug, PartialEq, Eq)] @@ -14,218 +19,51 @@ pub enum PlaError { /// Represents a Provider not found Error NotFound, /// Represents an HRESULT common error - ComHResultError(HResult), + ComError(windows::core::Error), } -/// Wrapper over common HRESULT native errors (Incomplete) -#[derive(Debug, PartialEq, Eq)] -pub enum HResult { - /// Represents S_OK - HrOk, - /// Represents E_ABORT - HrAbort, - /// Represents E_ACCESSDENIED - HrAccessDenied, - /// Represents E_FAIL - HrFail, - /// Represents E_INVALIDARG - HrInvalidArg, - /// Represents E_OUTOFMEMORY - HrOutOfMemory, - /// Represent an HRESULT not implemented in the Wrapper - NotImplemented(i32), -} - -impl From for HResult { - fn from(hr: i32) -> HResult { - match hr { - 0x0 => HResult::HrOk, - -2147467260 => HResult::HrAbort, - -2147024891 => HResult::HrAccessDenied, - -2147467259 => HResult::HrFail, - -2147024809 => HResult::HrInvalidArg, - -2147024882 => HResult::HrOutOfMemory, - _ => HResult::NotImplemented(hr), - } - } -} - -impl From for PlaError { - fn from(val: i32) -> PlaError { - PlaError::ComHResultError(HResult::from(val)) +impl From for PlaError { + fn from(val: windows::core::Error) -> PlaError { + PlaError::ComError(val) } } pub(crate) type ProvidersComResult = Result; -const VT_UI4: u16 = 0x13; -// We are just going to use VT_UI4 so we won't bother replicating the full VARIANT struct -// Not using Win32::Automation::VARIANT for commodity -#[repr(C)] -#[doc(hidden)] -#[derive(Debug, Default, Clone, Copy)] -pub struct Variant { - vt: u16, - w_reserved1: u16, - w_reserved2: u16, - w_reserved3: u16, - val: u32, -} - -impl Variant { - pub fn new(vt: u16, val: u32) -> Self { - Variant { - vt, - val, - ..Default::default() - } - } - - pub fn increment_val(&mut self) { - self.val += 1; - } - pub fn get_val(&self) -> u32 { - self.val - } -} - -fn check_hr(hr: i32) -> ProvidersComResult<()> { - let res = HResult::from(hr); - if res != HResult::HrOk { - return Err(PlaError::ComHResultError(res)); - } - - Ok(()) -} - // https://github.com/microsoft/krabsetw/blob/31679cf84bc85360158672699f2f68a821e8a6d0/krabs/krabs/provider.hpp#L487 pub(crate) unsafe fn get_provider_guid(name: &str) -> ProvidersComResult { - com::runtime::init_runtime()?; - - let all_providers = com::runtime::create_instance::< - pla_interfaces::ITraceDataProviderCollection, - >(&pla_interfaces::CLSID_TRACE_DATA_PROV_COLLECTION)?; + // FIXME: This is not paired with a call to CoUninitialize, so this will leak COM resources. + unsafe { CoInitializeEx(None, COINIT_MULTITHREADED) }.ok()?; - let mut guid: MaybeUninit = MaybeUninit::uninit(); - let mut hr = all_providers.get_trace_data_providers(BSTR::from("")); - check_hr(hr)?; + let all_providers: ITraceDataProviderCollection = + unsafe { CoCreateInstance(&TraceDataProviderCollection, None, CLSCTX_ALL) }?; - // could we assume count is unsigned... let's trust that count won't be negative - let mut count = 0; - hr = all_providers.get_count(&mut count); - check_hr(hr)?; + all_providers.GetTraceDataProviders(None)?; - let mut index = Variant::new(VT_UI4, 0); - while index.get_val() < count as u32 { - let mut provider = None; + let count = all_providers.Count()? as u32; - hr = all_providers.get_item(index, &mut provider); - check_hr(hr)?; + let mut index = 0u32; + let mut guid = None; - // We can safely unwrap after check_hr - let mut raw_name: MaybeUninit = MaybeUninit::uninit(); - let provider = provider.unwrap(); - provider.get_display_name(raw_name.as_mut_ptr()); - check_hr(hr)?; + while index < count as u32 { + let provider = all_providers.get_Item(&VARIANT::from(index))?; + let raw_name = provider.DisplayName()?; - let raw_name = raw_name.assume_init(); let prov_name = String::from_utf16_lossy(raw_name.as_wide()); - index.increment_val(); + index += 1; // check if matches, if it does get guid and break if prov_name.eq(name) { - hr = provider.get_guid(guid.as_mut_ptr()); - check_hr(hr)?; + guid = Some(provider.Guid()?); break; } } - if index.get_val() == count as u32 { + if index == count as u32 { return Err(PlaError::NotFound); } - // we can assume the guid is init if we reached this point eoc would return Error - Ok(guid.assume_init()) -} - -mod pla_interfaces { - use super::{Variant, BSTR, GUID}; - use com::sys::IID; - use com::{interfaces, interfaces::iunknown::IUnknown, sys::HRESULT}; - - interfaces! { - // functions parameters not defined unless necessary - #[uuid("00020400-0000-0000-C000-000000000046")] - pub unsafe interface IDispatch: IUnknown { - pub fn get_type_info_count(&self) -> HRESULT; - pub fn get_type_info(&self) -> HRESULT; - pub fn get_ids_of_names(&self) -> HRESULT; - pub fn invoke(&self) -> HRESULT; - } - - // pla.h - #[uuid("03837510-098b-11d8-9414-505054503030")] - pub unsafe interface ITraceDataProviderCollection: IDispatch { - pub fn get_count(&self, retval: *mut i32) -> HRESULT; - pub fn get_item( - &self, - #[pass_through] - index: Variant, - provider: *mut Option, - ) -> HRESULT; - pub fn get__new_enum(&self) -> HRESULT; - pub fn add(&self) -> HRESULT; - pub fn remove(&self) -> HRESULT; - pub fn clear(&self) -> HRESULT; - pub fn add_range(&self) -> HRESULT; - pub fn create_trace_data_provider(&self) -> HRESULT; - pub fn get_trace_data_providers( - &self, - #[pass_through] - server: BSTR - ) -> HRESULT; - pub fn get_trace_data_providers_by_process(&self) -> HRESULT; - } - - #[uuid("03837512-098b-11d8-9414-505054503030")] - pub unsafe interface ITraceDataProvider: IDispatch { - pub fn get_display_name( - &self, - #[pass_through] - name: *mut BSTR - ) -> HRESULT; - pub fn put_display_name(&self) -> HRESULT; - pub fn get_guid( - &self, - #[pass_through] - guid: *mut GUID - ) -> HRESULT; - pub fn put_guid(&self) -> HRESULT; - pub fn get_level(&self) -> HRESULT; - pub fn get_keywords_any(&self) -> HRESULT; - pub fn get_keywords_all(&self) -> HRESULT; - pub fn get_properties(&self) -> HRESULT; - pub fn get_filter_enabled(&self) -> HRESULT; - pub fn put_filter_enabled(&self) -> HRESULT; - pub fn get_filter_type(&self) -> HRESULT; - pub fn put_filter_type(&self) -> HRESULT; - pub fn get_filter_data(&self) -> HRESULT; - pub fn put_filter_data(&self) -> HRESULT; - pub fn query(&self) -> HRESULT; - pub fn resolve(&self) -> HRESULT; - pub fn set_security(&self) -> HRESULT; - pub fn get_security(&self) -> HRESULT; - pub fn get_registered_processes(&self) -> HRESULT; - } - } - - // 03837511-098b-11d8-9414-505054503030 - pub const CLSID_TRACE_DATA_PROV_COLLECTION: IID = IID { - data1: 0x03837511, - data2: 0x098b, - data3: 0x11d8, - data4: [0x94, 0x14, 0x50, 0x50, 0x54, 0x50, 0x30, 0x30], - }; + Ok(guid.unwrap()) } #[cfg(test)] diff --git a/src/native/sddl.rs b/src/native/sddl.rs index c889dd9..d963b54 100644 --- a/src/native/sddl.rs +++ b/src/native/sddl.rs @@ -1,27 +1,9 @@ use core::ffi::c_void; use std::str::Utf8Error; use windows::core::PSTR; -use windows::Win32::Foundation::{HLOCAL, PSID}; +use windows::Win32::Foundation::{LocalFree, HLOCAL, PSID}; use windows::Win32::Security::Authorization::ConvertSidToStringSidA; -// N.B windows-rs has an incorrect implementation for local free -// https://github.com/microsoft/windows-rs/issues/2488 -#[allow(non_snake_case)] -pub unsafe fn LocalFree(hmem: P0) -> ::windows::core::Result -where - P0: ::windows::core::IntoParam, -{ - #[link(name = "kernel32")] - extern "system" { - fn LocalFree(hmem: HLOCAL) -> HLOCAL; - } - let res = LocalFree(hmem.into_param().abi()); - match res.0 as usize { - 0 => Ok(res), - _ => Err(::windows::core::Error::from_win32()), - } -} - /// SDDL native error #[derive(Debug)] pub enum SddlNativeError { @@ -58,7 +40,9 @@ pub fn convert_sid_to_string(sid: *const c_void) -> SddlResult { let sid_string = std::ffi::CStr::from_ptr(tmp.0.cast()).to_str()?.to_owned(); - LocalFree(HLOCAL(tmp.0.cast())).map_err(|e| SddlNativeError::IoError(e.into()))?; + if LocalFree(HLOCAL(tmp.0.cast())) != HLOCAL(std::ptr::null_mut()) { + return Err(SddlNativeError::IoError(std::io::Error::last_os_error())); + } Ok(sid_string) } diff --git a/src/native/version_helper.rs b/src/native/version_helper.rs index a7f25c0..45ef70c 100644 --- a/src/native/version_helper.rs +++ b/src/native/version_helper.rs @@ -50,7 +50,7 @@ fn verify_system_version(major: u8, minor: u8, sp_major: u16) -> VersionHelperRe ) }; - let error = unsafe { GetLastError() }; + let error = unsafe { GetLastError() }.ok(); // See https://learn.microsoft.com/en-us/windows/win32/api/winbase/nf-winbase-verifyversioninfoa#return-value match (res.is_ok(), error) { From 305c36f44dad203e4c456f43d6778ad873186aed Mon Sep 17 00:00:00 2001 From: "Dr. Chat" Date: Sun, 23 Jun 2024 13:54:05 -0500 Subject: [PATCH 2/3] Improve error handling for `verify_system_version` --- src/native/version_helper.rs | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/src/native/version_helper.rs b/src/native/version_helper.rs index 45ef70c..11474db 100644 --- a/src/native/version_helper.rs +++ b/src/native/version_helper.rs @@ -5,6 +5,7 @@ //! //! At the moment the only option available is to check if the actual System Version is greater than //! Win8, is the only check we need for the crate to work as expected +use windows::core::HRESULT; use windows::Win32::Foundation::GetLastError; use windows::Win32::Foundation::ERROR_OLD_WIN_VERSION; use windows::Win32::System::SystemInformation::{VerSetConditionMask, VerifyVersionInfoA}; @@ -50,13 +51,15 @@ fn verify_system_version(major: u8, minor: u8, sp_major: u16) -> VersionHelperRe ) }; - let error = unsafe { GetLastError() }.ok(); - // See https://learn.microsoft.com/en-us/windows/win32/api/winbase/nf-winbase-verifyversioninfoa#return-value - match (res.is_ok(), error) { - (true, _) => Ok(true), - (false, Err(err)) if err.code() == ERROR_OLD_WIN_VERSION.to_hresult() => Ok(false), - (false, _err) => Err(VersionHelperError::IoError(std::io::Error::last_os_error())), + match res { + Ok(_) => Ok(true), + Err(e) => match e.code() { + e if e == HRESULT::from_win32(ERROR_OLD_WIN_VERSION.0) => Ok(false), + _ => Err(VersionHelperError::IoError( + std::io::Error::from_raw_os_error(unsafe { GetLastError() }.0 as i32), + )), + }, } } From db1023976d71bd1c74fc974f0c88cc7406a58f31 Mon Sep 17 00:00:00 2001 From: "Dr. Chat" Date: Sun, 23 Jun 2024 14:00:25 -0500 Subject: [PATCH 3/3] rustfmt --- src/parser.rs | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/src/parser.rs b/src/parser.rs index 75b2a17..4f2bf48 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -386,14 +386,13 @@ macro_rules! impl_try_parse_primitive_array { if prop_slice.buffer.as_ptr() as usize % align != 0 { return Err(ParserError::PropertyError( - "buffer alignment mismatch".into() + "buffer alignment mismatch".into(), )); } - if size.checked_mul(count).is_none() || (size * count) > isize::MAX as usize { - return Err(ParserError::PropertyError( - "size overflow".into() - )); + if size.checked_mul(count).is_none() || (size * count) > isize::MAX as usize + { + return Err(ParserError::PropertyError("size overflow".into())); } let slice = unsafe { @@ -468,9 +467,9 @@ impl private::TryParse for Parser<'_, '_> { )); } - // std::slice::from_raw_parts requires a pointer to be aligned, but we can't - // guarantee that the buffer is aligned. In testing, I found that the buffer - // is in fact never aligned appropriately, so a cheap workaround is to copy + // std::slice::from_raw_parts requires a pointer to be aligned, but we can't + // guarantee that the buffer is aligned. In testing, I found that the buffer + // is in fact never aligned appropriately, so a cheap workaround is to copy // the buffer into a new Vec and use that as the source for the slice // until we can find a better solution. let mut aligned_buffer = Vec::with_capacity(prop_slice.buffer.len() / 2);