diff --git a/Cargo.lock b/Cargo.lock index a03d7b8..a7f7fb3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -57,6 +57,7 @@ dependencies = [ "pretty_assertions", "tempfile", "tokio", + "winapi", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 28d01c4..20c642d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,6 +16,9 @@ path-dedot = "3.0.14" tokio = { version = "1", features = ["fs", "io-std", "io-util", "macros", "process", "rt-multi-thread", "sync", "time"] } os_pipe = "1.0.1" +[target.'cfg(windows)'.dependencies] +winapi = { version = "0.3.9", features = ["jobapi2"] } + [dev-dependencies] parking_lot = "0.12.0" pretty_assertions = "1" diff --git a/src/child_process_tracker.rs b/src/child_process_tracker.rs new file mode 100644 index 0000000..d8c3c43 --- /dev/null +++ b/src/child_process_tracker.rs @@ -0,0 +1,160 @@ +use anyhow::Result; +use std::sync::Arc; + +use tokio::process::Child; + +/// Windows does not have a concept of parent processes and so +/// killing the deno task process will not also kill any spawned +/// processes by default. To make this work, we can use winapi's +/// jobs api, which allows for associating a main process so that +/// when the main process terminates, it will also terminate the +/// associated processes. +/// +/// Read more: https://stackoverflow.com/questions/3342941/kill-child-process-when-parent-process-is-killed +#[derive(Clone)] +pub struct ChildProcessTracker(Arc); + +impl ChildProcessTracker { + #[cfg(windows)] + pub fn new() -> Self { + match windows::WinChildProcessTracker::new() { + Ok(tracker) => Self(Arc::new(tracker)), + Err(err) => { + if cfg!(debug_assertions) { + panic!("Could not start tracking processes. {:#}", err); + } else { + // fallback to not tracking processes if this fails + Self(Arc::new(NullChildProcessTracker)) + } + } + } + } + + #[cfg(not(windows))] + pub fn new() -> Self { + Self(Arc::new(NullChildProcessTracker)) + } + + pub fn track(&self, child: &Child) { + if let Err(err) = self.0.track(child) { + if cfg!(debug_assertions) { + panic!("Could not track process: {:#}", err); + } + } + } +} + +trait Tracker: Send + Sync { + fn track(&self, child: &Child) -> Result<()>; +} + +struct NullChildProcessTracker; + +impl Tracker for NullChildProcessTracker { + fn track(&self, _: &Child) -> Result<()> { + Ok(()) + } +} + +#[cfg(target_os = "windows")] +mod windows { + use anyhow::bail; + use anyhow::Result; + use std::ptr; + use tokio::process::Child; + use winapi::shared::minwindef::DWORD; + use winapi::shared::minwindef::LPVOID; + use winapi::shared::minwindef::TRUE; + use winapi::um::handleapi::INVALID_HANDLE_VALUE; + use winapi::um::jobapi2::AssignProcessToJobObject; + use winapi::um::jobapi2::CreateJobObjectW; + use winapi::um::jobapi2::SetInformationJobObject; + use winapi::um::winnt::JobObjectExtendedLimitInformation; + use winapi::um::winnt::HANDLE; + use winapi::um::winnt::JOBOBJECT_EXTENDED_LIMIT_INFORMATION; + use winapi::um::winnt::JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE; + + use super::Tracker; + + pub struct WinChildProcessTracker { + handle: WinHandle, + } + + impl WinChildProcessTracker { + pub fn new() -> Result { + unsafe { + let handle = CreateJobObjectW(ptr::null_mut(), ptr::null()); + let mut info: JOBOBJECT_EXTENDED_LIMIT_INFORMATION = std::mem::zeroed(); + info.BasicLimitInformation.LimitFlags = + JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE; + let result = SetInformationJobObject( + handle, + JobObjectExtendedLimitInformation, + &mut info as *mut _ as LPVOID, + std::mem::size_of_val(&info) as DWORD, + ); + if result != TRUE { + bail!( + "Could not set job information object. {:#}", + std::io::Error::last_os_error() + ); + } + + Ok(Self { + handle: WinHandle::new(handle), + }) + } + } + + unsafe fn add_process_handle(&self, process_handle: HANDLE) -> Result<()> { + let result = + AssignProcessToJobObject(self.handle.as_raw_handle(), process_handle); + if result != TRUE { + bail!( + "Could not assign process to job object. {:#}", + std::io::Error::last_os_error() + ); + } else { + Ok(()) + } + } + } + + impl Tracker for WinChildProcessTracker { + fn track(&self, child: &Child) -> Result<()> { + if let Some(handle) = child.raw_handle() { + unsafe { self.add_process_handle(handle) } + } else { + // process exited... ignore + Ok(()) + } + } + } + + struct WinHandle { + inner: HANDLE, + } + + impl WinHandle { + pub fn new(handle: HANDLE) -> Self { + WinHandle { inner: handle } + } + + pub fn as_raw_handle(&self) -> HANDLE { + self.inner + } + } + + unsafe impl Send for WinHandle {} + unsafe impl Sync for WinHandle {} + + impl Drop for WinHandle { + fn drop(&mut self) { + unsafe { + if !self.inner.is_null() && self.inner != INVALID_HANDLE_VALUE { + winapi::um::handleapi::CloseHandle(self.inner); + } + } + } + } +} diff --git a/src/lib.rs b/src/lib.rs index 440bf18..079ba40 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,6 @@ // Copyright 2018-2022 the Deno authors. All rights reserved. MIT license. +mod child_process_tracker; mod combinators; mod commands; mod fs_util; diff --git a/src/shell.rs b/src/shell.rs index b30b99d..fc170b0 100644 --- a/src/shell.rs +++ b/src/shell.rs @@ -400,6 +400,8 @@ async fn execute_command( } }; + state.track_child_process(&child); + // avoid deadlock since this is holding onto the pipes drop(sub_command); diff --git a/src/shell_types.rs b/src/shell_types.rs index 5a29677..2602354 100644 --- a/src/shell_types.rs +++ b/src/shell_types.rs @@ -10,6 +10,7 @@ use anyhow::Result; use futures::future::BoxFuture; use tokio::task::JoinHandle; +use crate::child_process_tracker::ChildProcessTracker; use crate::fs_util; #[derive(Clone)] @@ -21,6 +22,7 @@ pub struct ShellState { /// not passed down to any sub commands. shell_vars: HashMap, cwd: PathBuf, + process_tracker: ChildProcessTracker, } impl ShellState { @@ -29,6 +31,7 @@ impl ShellState { env_vars: Default::default(), shell_vars: Default::default(), cwd: PathBuf::new(), + process_tracker: ChildProcessTracker::new(), }; // ensure the data is normalized for (name, value) in env_vars { @@ -107,6 +110,10 @@ impl ShellState { } } } + + pub fn track_child_process(&self, child: &tokio::process::Child) { + self.process_tracker.track(child); + } } #[derive(Debug, PartialEq)]