Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Eliminate com crate; bump windows-rs version #132

Merged
merged 3 commits into from
Jun 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,18 @@ time_rs = ["time"]
serde = [ "dep:serde", "time?/serde", "time?/serde-human-readable" ]

[dependencies]
windows = { version = "0.52", features = [
windows = { version = "0.57.0", features = [
"Win32_Foundation",
"Win32_Security_Authorization",
"Win32_System_Com",
"Win32_System_Diagnostics_Etw",
"Win32_System_LibraryLoader",
"Win32_System_Memory",
"Win32_System_Performance",
"Win32_System_SystemInformation",
"Win32_System_SystemServices",
"Win32_System_Time",
]}
com = "0.6.0"
memoffset = "0.9"
rand = "~0.8.0"
once_cell = "1.14"
Expand Down
20 changes: 13 additions & 7 deletions src/native/evntrace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,8 @@ where
PCWSTR::from_raw(properties.trace_name_array().as_ptr()),
properties.as_mut_ptr(),
)
};
}
.ok();

if let Err(status) = status {
let code = status.code();
Expand Down Expand Up @@ -258,7 +259,8 @@ pub(crate) fn enable_provider(
0,
Some(parameters.as_ptr()),
)
};
}
.ok();

res.map_err(|err| {
EvntraceNativeError::IoError(std::io::Error::from_raw_os_error(err.code().0))
Expand All @@ -280,7 +282,8 @@ pub(crate) fn process_trace(trace_handle: TraceHandle) -> EvntraceNativeResult<(
// * for real-time traces, this means we might process a few events already waiting in the buffers when the processing is starting. This is fine, I suppose.
let mut start = FILETIME::default();
Etw::ProcessTrace(&[trace_handle], Some(&mut start as *mut FILETIME), None)
};
}
.ok();

result.map_err(|err| {
EvntraceNativeError::IoError(std::io::Error::from_raw_os_error(err.code().0))
Expand Down Expand Up @@ -313,7 +316,8 @@ pub(crate) fn control_trace(
properties.as_mut_ptr(),
control_code,
)
};
}
.ok();

result.map_err(|err| {
EvntraceNativeError::IoError(std::io::Error::from_raw_os_error(err.code().0))
Expand All @@ -337,7 +341,8 @@ pub(crate) fn control_trace_by_name(
properties.as_mut_ptr(),
control_code,
)
};
}
.ok();

result.map_err(|err| {
EvntraceNativeError::IoError(std::io::Error::from_raw_os_error(err.code().0))
Expand All @@ -363,7 +368,7 @@ pub(crate) fn close_trace(
UNIQUE_VALID_CONTEXTS
.remove(callback_data.as_ref() as *const Arc<CallbackData> as *const c_void);

let status = unsafe { Etw::CloseTrace(handle) };
let status = unsafe { Etw::CloseTrace(handle) }.ok();

match status {
Ok(()) => Ok(false),
Expand All @@ -386,7 +391,8 @@ pub(crate) fn query_info(class: TraceInformation, buf: &mut [u8]) -> EvntraceNat
buf.len() as u32,
None,
)
};
}
.ok();

result.map_err(|err| {
EvntraceNativeError::IoError(std::io::Error::from_raw_os_error(err.code().0))
Expand Down
214 changes: 26 additions & 188 deletions src/native/pla.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,227 +5,65 @@
//!
//! This module shouldn't be accessed directly. Modules from the the crate level provide a safe API to interact
//! with the crate
use std::mem::MaybeUninit;
use windows::core::{BSTR, GUID};
use windows::{
core::{GUID, VARIANT},
Win32::System::{
Com::{CoCreateInstance, CoInitializeEx, CLSCTX_ALL, COINIT_MULTITHREADED},
Performance::{ITraceDataProviderCollection, TraceDataProviderCollection},
},
};

/// Pla native module errors
#[derive(Debug, PartialEq, Eq)]
pub enum PlaError {
/// Represents a Provider not found Error
NotFound,
/// Represents an HRESULT common error
ComHResultError(HResult),
ComError(windows::core::Error),
}

/// Wrapper over common HRESULT native errors (Incomplete)
#[derive(Debug, PartialEq, Eq)]
pub enum HResult {
/// Represents S_OK
HrOk,
/// Represents E_ABORT
HrAbort,
/// Represents E_ACCESSDENIED
HrAccessDenied,
/// Represents E_FAIL
HrFail,
/// Represents E_INVALIDARG
HrInvalidArg,
/// Represents E_OUTOFMEMORY
HrOutOfMemory,
/// Represent an HRESULT not implemented in the Wrapper
NotImplemented(i32),
}

impl From<i32> for HResult {
fn from(hr: i32) -> HResult {
match hr {
0x0 => HResult::HrOk,
-2147467260 => HResult::HrAbort,
-2147024891 => HResult::HrAccessDenied,
-2147467259 => HResult::HrFail,
-2147024809 => HResult::HrInvalidArg,
-2147024882 => HResult::HrOutOfMemory,
_ => HResult::NotImplemented(hr),
}
}
}

impl From<i32> for PlaError {
fn from(val: i32) -> PlaError {
PlaError::ComHResultError(HResult::from(val))
impl From<windows::core::Error> for PlaError {
fn from(val: windows::core::Error) -> PlaError {
PlaError::ComError(val)
}
}

pub(crate) type ProvidersComResult<T> = Result<T, PlaError>;

const VT_UI4: u16 = 0x13;
// We are just going to use VT_UI4 so we won't bother replicating the full VARIANT struct
// Not using Win32::Automation::VARIANT for commodity
#[repr(C)]
#[doc(hidden)]
#[derive(Debug, Default, Clone, Copy)]
pub struct Variant {
vt: u16,
w_reserved1: u16,
w_reserved2: u16,
w_reserved3: u16,
val: u32,
}

impl Variant {
pub fn new(vt: u16, val: u32) -> Self {
Variant {
vt,
val,
..Default::default()
}
}

pub fn increment_val(&mut self) {
self.val += 1;
}
pub fn get_val(&self) -> u32 {
self.val
}
}

fn check_hr(hr: i32) -> ProvidersComResult<()> {
let res = HResult::from(hr);
if res != HResult::HrOk {
return Err(PlaError::ComHResultError(res));
}

Ok(())
}

// https://github.com/microsoft/krabsetw/blob/31679cf84bc85360158672699f2f68a821e8a6d0/krabs/krabs/provider.hpp#L487
pub(crate) unsafe fn get_provider_guid(name: &str) -> ProvidersComResult<GUID> {
com::runtime::init_runtime()?;

let all_providers = com::runtime::create_instance::<
pla_interfaces::ITraceDataProviderCollection,
>(&pla_interfaces::CLSID_TRACE_DATA_PROV_COLLECTION)?;
// FIXME: This is not paired with a call to CoUninitialize, so this will leak COM resources.
unsafe { CoInitializeEx(None, COINIT_MULTITHREADED) }.ok()?;

let mut guid: MaybeUninit<GUID> = MaybeUninit::uninit();
let mut hr = all_providers.get_trace_data_providers(BSTR::from(""));
check_hr(hr)?;
let all_providers: ITraceDataProviderCollection =
unsafe { CoCreateInstance(&TraceDataProviderCollection, None, CLSCTX_ALL) }?;

// could we assume count is unsigned... let's trust that count won't be negative
let mut count = 0;
hr = all_providers.get_count(&mut count);
check_hr(hr)?;
all_providers.GetTraceDataProviders(None)?;

let mut index = Variant::new(VT_UI4, 0);
while index.get_val() < count as u32 {
let mut provider = None;
let count = all_providers.Count()? as u32;

hr = all_providers.get_item(index, &mut provider);
check_hr(hr)?;
let mut index = 0u32;
let mut guid = None;

// We can safely unwrap after check_hr
let mut raw_name: MaybeUninit<BSTR> = MaybeUninit::uninit();
let provider = provider.unwrap();
provider.get_display_name(raw_name.as_mut_ptr());
check_hr(hr)?;
while index < count as u32 {
let provider = all_providers.get_Item(&VARIANT::from(index))?;
let raw_name = provider.DisplayName()?;

let raw_name = raw_name.assume_init();
let prov_name = String::from_utf16_lossy(raw_name.as_wide());

index.increment_val();
index += 1;
// check if matches, if it does get guid and break
if prov_name.eq(name) {
hr = provider.get_guid(guid.as_mut_ptr());
check_hr(hr)?;
guid = Some(provider.Guid()?);
break;
}
}

if index.get_val() == count as u32 {
if index == count as u32 {
return Err(PlaError::NotFound);
}

// we can assume the guid is init if we reached this point eoc would return Error
Ok(guid.assume_init())
}

mod pla_interfaces {
use super::{Variant, BSTR, GUID};
use com::sys::IID;
use com::{interfaces, interfaces::iunknown::IUnknown, sys::HRESULT};

interfaces! {
// functions parameters not defined unless necessary
#[uuid("00020400-0000-0000-C000-000000000046")]
pub unsafe interface IDispatch: IUnknown {
pub fn get_type_info_count(&self) -> HRESULT;
pub fn get_type_info(&self) -> HRESULT;
pub fn get_ids_of_names(&self) -> HRESULT;
pub fn invoke(&self) -> HRESULT;
}

// pla.h
#[uuid("03837510-098b-11d8-9414-505054503030")]
pub unsafe interface ITraceDataProviderCollection: IDispatch {
pub fn get_count(&self, retval: *mut i32) -> HRESULT;
pub fn get_item(
&self,
#[pass_through]
index: Variant,
provider: *mut Option<ITraceDataProvider>,
) -> HRESULT;
pub fn get__new_enum(&self) -> HRESULT;
pub fn add(&self) -> HRESULT;
pub fn remove(&self) -> HRESULT;
pub fn clear(&self) -> HRESULT;
pub fn add_range(&self) -> HRESULT;
pub fn create_trace_data_provider(&self) -> HRESULT;
pub fn get_trace_data_providers(
&self,
#[pass_through]
server: BSTR
) -> HRESULT;
pub fn get_trace_data_providers_by_process(&self) -> HRESULT;
}

#[uuid("03837512-098b-11d8-9414-505054503030")]
pub unsafe interface ITraceDataProvider: IDispatch {
pub fn get_display_name(
&self,
#[pass_through]
name: *mut BSTR
) -> HRESULT;
pub fn put_display_name(&self) -> HRESULT;
pub fn get_guid(
&self,
#[pass_through]
guid: *mut GUID
) -> HRESULT;
pub fn put_guid(&self) -> HRESULT;
pub fn get_level(&self) -> HRESULT;
pub fn get_keywords_any(&self) -> HRESULT;
pub fn get_keywords_all(&self) -> HRESULT;
pub fn get_properties(&self) -> HRESULT;
pub fn get_filter_enabled(&self) -> HRESULT;
pub fn put_filter_enabled(&self) -> HRESULT;
pub fn get_filter_type(&self) -> HRESULT;
pub fn put_filter_type(&self) -> HRESULT;
pub fn get_filter_data(&self) -> HRESULT;
pub fn put_filter_data(&self) -> HRESULT;
pub fn query(&self) -> HRESULT;
pub fn resolve(&self) -> HRESULT;
pub fn set_security(&self) -> HRESULT;
pub fn get_security(&self) -> HRESULT;
pub fn get_registered_processes(&self) -> HRESULT;
}
}

// 03837511-098b-11d8-9414-505054503030
pub const CLSID_TRACE_DATA_PROV_COLLECTION: IID = IID {
data1: 0x03837511,
data2: 0x098b,
data3: 0x11d8,
data4: [0x94, 0x14, 0x50, 0x50, 0x54, 0x50, 0x30, 0x30],
};
Ok(guid.unwrap())
}

#[cfg(test)]
Expand Down
24 changes: 4 additions & 20 deletions src/native/sddl.rs
Original file line number Diff line number Diff line change
@@ -1,27 +1,9 @@
use core::ffi::c_void;
use std::str::Utf8Error;
use windows::core::PSTR;
use windows::Win32::Foundation::{HLOCAL, PSID};
use windows::Win32::Foundation::{LocalFree, HLOCAL, PSID};
use windows::Win32::Security::Authorization::ConvertSidToStringSidA;

// N.B windows-rs has an incorrect implementation for local free
// https://github.com/microsoft/windows-rs/issues/2488
#[allow(non_snake_case)]
pub unsafe fn LocalFree<P0>(hmem: P0) -> ::windows::core::Result<HLOCAL>
where
P0: ::windows::core::IntoParam<HLOCAL>,
{
#[link(name = "kernel32")]
extern "system" {
fn LocalFree(hmem: HLOCAL) -> HLOCAL;
}
let res = LocalFree(hmem.into_param().abi());
match res.0 as usize {
0 => Ok(res),
_ => Err(::windows::core::Error::from_win32()),
}
}

/// SDDL native error
#[derive(Debug)]
pub enum SddlNativeError {
Expand Down Expand Up @@ -58,7 +40,9 @@ pub fn convert_sid_to_string(sid: *const c_void) -> SddlResult<String> {

let sid_string = std::ffi::CStr::from_ptr(tmp.0.cast()).to_str()?.to_owned();

LocalFree(HLOCAL(tmp.0.cast())).map_err(|e| SddlNativeError::IoError(e.into()))?;
if LocalFree(HLOCAL(tmp.0.cast())) != HLOCAL(std::ptr::null_mut()) {
return Err(SddlNativeError::IoError(std::io::Error::last_os_error()));
}

Ok(sid_string)
}
Expand Down
Loading
Loading