diff --git a/Cargo.lock b/Cargo.lock index 2122da74988..663a82d0946 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6404,6 +6404,7 @@ dependencies = [ "predicates", "regex", "rstest", + "scopeguard", "shell-words", "smol", "socket2", diff --git a/wezterm-ssh/Cargo.toml b/wezterm-ssh/Cargo.toml index 1fa5c0066d2..eab0a2c7d70 100644 --- a/wezterm-ssh/Cargo.toml +++ b/wezterm-ssh/Cargo.toml @@ -39,6 +39,7 @@ wezterm-uds = { path = "../wezterm-uds" } # Not used directly, but is used to centralize the openssl vendor feature selection async_ossl = { path = "../async_ossl" } +scopeguard = "1.2.0" [dev-dependencies] assert_fs = "1.0.4" diff --git a/wezterm-ssh/src/sessioninner.rs b/wezterm-ssh/src/sessioninner.rs index bbc145a1572..3888e49412b 100644 --- a/wezterm-ssh/src/sessioninner.rs +++ b/wezterm-ssh/src/sessioninner.rs @@ -15,11 +15,13 @@ use filedescriptor::{ poll, pollfd, socketpair, AsRawSocketDescriptor, FileDescriptor, POLLIN, POLLOUT, }; use portable_pty::ExitStatus; +use scopeguard::defer; use smol::channel::{bounded, Receiver, Sender, TryRecvError}; use socket2::{Domain, Socket, Type}; use std::collections::{HashMap, VecDeque}; use std::io::{Read, Write}; use std::net::ToSocketAddrs; +use std::process::Child; use std::time::Duration; #[derive(Debug)] @@ -205,8 +207,9 @@ impl SessionInner { sess.set_option(libssh_rs::SshOption::HostKeys(host_key.to_string()))?; } - let sock = + let (sock, child) = self.connect_to_host(&hostname, port, verbose, self.config.get("proxycommand"))?; + defer! { clean_up_proxy_command_child(child); } let raw = { #[cfg(unix)] { @@ -288,8 +291,9 @@ impl SessionInner { )))) .context("notifying user of banner")?; - let sock = + let (sock, child) = self.connect_to_host(&hostname, port, verbose, self.config.get("proxycommand"))?; + defer! { clean_up_proxy_command_child(child); } let mut sess = ssh2::Session::new()?; if verbose { @@ -332,7 +336,7 @@ impl SessionInner { port: u16, verbose: bool, proxy_command: Option<&String>, - ) -> anyhow::Result { + ) -> anyhow::Result<(Socket, Option)> { match proxy_command.map(|s| s.as_str()) { Some("none") | None => {} Some(proxy_command) => { @@ -351,19 +355,19 @@ impl SessionInner { cmd.stdin(b.as_stdio()?); cmd.stdout(b.as_stdio()?); cmd.stderr(std::process::Stdio::inherit()); - let _child = cmd + let child = cmd .spawn() .with_context(|| format!("spawning ProxyCommand {}", proxy_command))?; #[cfg(unix)] unsafe { use std::os::unix::io::{FromRawFd, IntoRawFd}; - return Ok(Socket::from_raw_fd(a.into_raw_fd())); + return Ok((Socket::from_raw_fd(a.into_raw_fd()), Some(child))); } #[cfg(windows)] unsafe { use std::os::windows::io::{FromRawSocket, IntoRawSocket}; - return Ok(Socket::from_raw_socket(a.into_raw_socket())); + return Ok((Socket::from_raw_socket(a.into_raw_socket()), Some(child))); } } } @@ -392,7 +396,7 @@ impl SessionInner { sock.connect(&addr.into()) .with_context(|| format!("Connecting to {hostname}:{port} ({addr:?})"))?; - Ok(sock) + Ok((sock, None)) } /// Used to restrict to_socket_addrs results to the address @@ -1086,3 +1090,14 @@ where } Ok(true) } + +fn clean_up_proxy_command_child(child: Option) { + if let Some(mut child) = child { + if let Err(err) = child.kill() { + log::error!("Error killing ProxyCommand: {}", err); + } + if let Err(err) = child.wait() { + log::error!("Error waiting for ProxyCommand to finish: {}", err); + } + } +}