Skip to content

Commit

Permalink
fix: reentrant tokio runtime when using Resolver (#142)
Browse files Browse the repository at this point in the history
`trust_dns_resolver` is implemented using async Rust and the tokio
runtime. Its blocking API is creating a new tokio runtime internally,
this is a problem because tokio does not allow starting a runtime from
within another runtime (panic).
  • Loading branch information
CBenoit authored Aug 14, 2023
1 parent c339695 commit ba3af51
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 35 deletions.
5 changes: 3 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "sspi"
version = "0.10.0"
version = "0.10.1"
edition = "2021"
readme = "README.md"
license = "MIT/Apache-2.0"
Expand All @@ -23,7 +23,7 @@ exclude = [
[features]
default = []
network_client = ["dep:reqwest", "dep:portpicker"]
dns_resolver = ["dep:trust-dns-resolver"]
dns_resolver = ["dep:trust-dns-resolver", "dep:tokio"]
# TSSSP should be used only on Windows as a native CREDSSP replacement
tsssp = ["dep:rustls"]

Expand Down Expand Up @@ -59,6 +59,7 @@ num-bigint-dig = "0.8.1"
tracing = "0.1.37"
rustls = { version = "0.20.7", features = ["dangerous_configuration"], optional = true }
zeroize = { version = "1.5.7", features = ["zeroize_derive"] }
tokio = { version = "1.1", features = ["time", "rt"], optional = true }

[target.'cfg(windows)'.dependencies]
winreg = "0.10"
Expand Down
74 changes: 41 additions & 33 deletions src/dns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,9 +132,7 @@ cfg_if::cfg_if! {
cfg_if::cfg_if! {
if #[cfg(any(target_os="macos", target_os="ios"))] {
use std::time::Duration;
use std::thread;
use tokio::time::timeout;
use tokio::runtime;
use futures::stream::{StreamExt};
use async_dnssd::{query_record, QueryRecordResult, QueriedRecordFlags, Type};

Expand Down Expand Up @@ -179,10 +177,11 @@ cfg_if::cfg_if! {
}

pub fn dns_query_srv_records(name: &str) -> Vec<DnsSrvRecord> {
let query_timeout = 1000;
async fn query_with_timeout(name: String, query_timeout: u64) -> Vec<DnsSrvRecord> {
const QUERY_TIMEOUT: u64 = 1000;

async fn query_with_timeout(name: &str, query_timeout: u64) -> Vec<DnsSrvRecord> {
let mut dns_records: Vec<DnsSrvRecord> = Vec::new();
let mut query = query_record(&name, Type::SRV);
let mut query = query_record(name, Type::SRV);

loop {
match timeout(Duration::from_millis(query_timeout), query.next()).await {
Expand Down Expand Up @@ -210,26 +209,7 @@ cfg_if::cfg_if! {
dns_records
}

match runtime::Handle::try_current() {
Ok(handle) => {
// Tokio runtime already exists, cannot block again on the same thread.
// Spawn a new thread to run the blocking code.
let name = name.to_owned();
thread::spawn(move || {
// Send the dns_records back to the main thread.
handle.block_on(query_with_timeout(name, query_timeout))
}).join().unwrap() // returns the vec of dns records
},
Err(err) => {
if err.is_missing_context() {
// No existing tokio runtime context, block on a new one.
let rt = runtime::Builder::new_current_thread().enable_all().build().unwrap();
return rt.block_on(query_with_timeout(name.to_owned(), query_timeout));
}
// ThreadLocalDestroyed error should never happen.
panic!("Unexpected error when trying to get current runtime: {}", err);
}
}
execute_future(query_with_timeout(name, QUERY_TIMEOUT))
}

pub fn detect_kdc_hosts_from_dns_apple(domain: &str) -> Vec<String> {
Expand All @@ -254,7 +234,7 @@ cfg_if::cfg_if! {

cfg_if::cfg_if! {
if #[cfg(feature="dns_resolver")] {
use trust_dns_resolver::Resolver;
use trust_dns_resolver::TokioAsyncResolver;
use trust_dns_resolver::system_conf::read_system_conf;
use trust_dns_resolver::config::{ResolverConfig,NameServerConfig,Protocol,ResolverOpts};
use std::env;
Expand Down Expand Up @@ -293,7 +273,7 @@ cfg_if::cfg_if! {
None
}

fn get_trust_dns_resolver_from_name_servers(name_servers: Vec<String>) -> Option<Resolver> {
fn get_trust_dns_resolver_from_name_servers(name_servers: Vec<String>) -> Option<TokioAsyncResolver> {
let mut resolver_config = ResolverConfig::new();

for name_server_url in name_servers {
Expand All @@ -305,23 +285,23 @@ cfg_if::cfg_if! {
let mut resolver_options = ResolverOpts::default();
resolver_options.validate = false;

Resolver::new(resolver_config, resolver_options).ok()
TokioAsyncResolver::tokio(resolver_config, resolver_options).ok()
}

#[cfg(target_os="windows")]
fn get_trust_dns_resolver(domain: &str) -> Option<Resolver> {
fn get_trust_dns_resolver(domain: &str) -> Option<TokioAsyncResolver> {
let name_servers = get_name_servers_for_domain(domain);
get_trust_dns_resolver_from_name_servers(name_servers)
}

#[cfg(not(target_os="windows"))]
fn get_trust_dns_resolver(_domain: &str) -> Option<Resolver> {
fn get_trust_dns_resolver(_domain: &str) -> Option<TokioAsyncResolver> {
if let Ok(name_server_list) = env::var("SSPI_DNS_URL") {
let name_servers: Vec<String> = name_server_list
.split(',').map(|c|c.trim().to_string()).filter(|x: &String| !x.is_empty()).collect();
get_trust_dns_resolver_from_name_servers(name_servers)
} else if let Ok((resolver_config, resolver_options)) = read_system_conf() {
Resolver::new(resolver_config, resolver_options).ok()
TokioAsyncResolver::tokio(resolver_config, resolver_options).ok()
} else {
None
}
Expand All @@ -331,7 +311,7 @@ cfg_if::cfg_if! {
let mut kdc_hosts = Vec::new();

if let Some(resolver) = get_trust_dns_resolver(domain) {
if let Ok(records) = resolver.srv_lookup(format!("_kerberos._tcp.{}", domain)) {
if let Ok(records) = execute_future(resolver.srv_lookup(format!("_kerberos._tcp.{}", domain))) {
for record in records {
let port = record.port();
let target_name = record.target().to_string();
Expand All @@ -341,7 +321,7 @@ cfg_if::cfg_if! {
}
}

if let Ok(records) = resolver.srv_lookup(format!("_kerberos._udp.{}", domain)) {
if let Ok(records) = execute_future(resolver.srv_lookup(format!("_kerberos._udp.{}", domain))) {
for record in records {
let port = record.port();
let target_name = record.target().to_string();
Expand All @@ -357,6 +337,34 @@ cfg_if::cfg_if! {
}
}

#[cfg(any(feature = "dns_resolver", target_os = "macos", target_os = "ios"))]
fn execute_future<Fut>(fut: Fut) -> Fut::Output
where
Fut: std::future::IntoFuture + Send,
Fut::Output: Send,
{
use std::thread;
use tokio::runtime::{Builder, Handle};

match Handle::try_current() {
Ok(handle) => {
// Tokio runtime already exists, cannot block again on the same thread.
// Spawn a new thread to run the blocking code.
thread::scope(|s| s.spawn(move || handle.block_on(fut.into_future())).join().unwrap())
}
Err(err) => {
if err.is_missing_context() {
// No existing tokio runtime context, block on a new one.
let rt = Builder::new_current_thread().enable_all().build().unwrap();
return rt.block_on(fut.into_future());
}

// ThreadLocalDestroyed error should never happen.
panic!("Unexpected error when trying to get current runtime: {}", err);
}
}
}

#[allow(unused_variables)]
#[instrument(level = "debug", ret)]
pub fn detect_kdc_hosts_from_dns(domain: &str) -> Vec<String> {
Expand Down

0 comments on commit ba3af51

Please sign in to comment.