diff --git a/deucalion/Cargo.toml b/deucalion/Cargo.toml index 67bd81f..5aaeaef 100644 --- a/deucalion/Cargo.toml +++ b/deucalion/Cargo.toml @@ -60,6 +60,7 @@ winapi = { version = "0.3", features = [ "libloaderapi", "consoleapi", "wincon", + "tlhelp32", ] } [dev-dependencies] diff --git a/deucalion/src/lib.rs b/deucalion/src/lib.rs index 57a9467..2431a56 100644 --- a/deucalion/src/lib.rs +++ b/deucalion/src/lib.rs @@ -5,6 +5,7 @@ use std::path::PathBuf; use std::time::SystemTime; use simplelog::{LevelFilter, WriteLogger}; +use w32module::drop_ref_count_to_one; #[cfg(windows)] use winapi::shared::minwindef::*; use winapi::um::libloaderapi; @@ -24,6 +25,7 @@ use tokio::select; use tokio::sync::oneshot; mod hook; +mod w32module; pub mod namedpipe; pub mod procloader; @@ -152,10 +154,12 @@ async fn main_with_result() -> Result<()> { Ok(()) } +const DLL_PROCESS_ATTACH: u32 = 1; + #[allow(non_snake_case)] #[no_mangle] unsafe extern "system" fn DllMain(hModule: HINSTANCE, reason: u32, _: u32) -> BOOL { - if reason == 1 { + if reason == DLL_PROCESS_ATTACH { processthreadsapi::CreateThread( 0 as LPSECURITY_ATTRIBUTES, 0, @@ -220,6 +224,9 @@ unsafe extern "system" fn main(dll_base_addr: LPVOID) -> u32 { error!("Panic happened: {cause:?}"); pause(); } + if let Err(e) = drop_ref_count_to_one(dll_base_addr as HMODULE) { + error!("Could not drop ref count to one: {e}") + } info!("Shut down!"); #[cfg(debug_assertions)] wincon::FreeConsole(); diff --git a/deucalion/src/w32module.rs b/deucalion/src/w32module.rs new file mode 100644 index 0000000..a52c608 --- /dev/null +++ b/deucalion/src/w32module.rs @@ -0,0 +1,86 @@ +use anyhow::{format_err, Result}; + +use winapi::{ + shared::{minwindef::HMODULE, ntdef::HANDLE}, + um::{ + errhandlingapi::GetLastError, + handleapi::CloseHandle, + libloaderapi::FreeLibrary, + processthreadsapi::GetCurrentProcessId, + tlhelp32::{ + CreateToolhelp32Snapshot, Module32First, Module32Next, MODULEENTRY32, TH32CS_SNAPMODULE, + }, + }, +}; + +use log::info; + +struct TH32Handle(HANDLE); + +impl TH32Handle { + unsafe fn new(handle: HANDLE) -> Result { + if handle.is_null() { + return Err(format_err!( + "Failed to call CreateToolhelp32Snapshot: {}", + GetLastError() + )); + } + Ok(TH32Handle(handle)) + } +} +impl Drop for TH32Handle { + fn drop(&mut self) { + unsafe { + let _ = CloseHandle(self.0); + } + } +} + +unsafe fn get_ref_count(hmodule: HMODULE) -> Result { + let pid = GetCurrentProcessId(); + let snapshot_handle = TH32Handle::new(CreateToolhelp32Snapshot(TH32CS_SNAPMODULE, pid))?; + + let mut me32: MODULEENTRY32 = core::mem::zeroed(); + let me32_size = std::mem::size_of::() as u32; + me32.dwSize = me32_size; + + if Module32First(snapshot_handle.0, &mut me32) == 0 { + return Err(format_err!( + "Failed to call Module32First: {}", + GetLastError() + )); + } + + let mut more_modules: bool = true; + + while more_modules { + if hmodule == me32.hModule { + if me32.GlblcntUsage == 0xFFFF { + continue; + } + return Ok(me32.GlblcntUsage); + } + more_modules = Module32Next(snapshot_handle.0, &mut me32) > 0; + } + Err(format_err!("Could not find ref count for current module")) +} + +pub unsafe fn drop_ref_count_to_one(hmodule: HMODULE) -> Result<()> { + let count = get_ref_count(hmodule)?; + if count <= 1 { + return Ok(()); + } + info!( + "Ref count is {count}. Calling FreeLibrary {} extra time(s)...", + count - 1 + ); + for _ in 0..count - 1 { + if FreeLibrary(hmodule) == 0 { + return Err(format_err!( + "Failed to call FreeLibrary: {}", + GetLastError() + )); + }; + } + Ok(()) +}