From d0d02d849725c6601271d1d78c01ee2420902401 Mon Sep 17 00:00:00 2001 From: irving ou Date: Tue, 26 Sep 2023 16:03:17 -0400 Subject: [PATCH] feat: add generator based io structure --- Cargo.lock | 12 ++ Cargo.toml | 1 + examples/client.rs | 8 +- ffi/src/macros.rs | 2 - ffi/src/sec_handle.rs | 35 ++-- src/credssp/mod.rs | 133 +++++++++------ src/credssp/sspi_cred_ssp/mod.rs | 49 ++++-- src/generator.rs | 225 +++++++++++++++++++++++++ src/kerberos/config.rs | 29 +--- src/kerberos/mod.rs | 265 ++++++++++++++++-------------- src/lib.rs | 12 +- src/negotiate.rs | 149 +++++++++-------- src/network_client.rs | 72 ++------ src/ntlm/mod.rs | 110 +++++++------ src/pku2u/mod.rs | 42 +++-- tests/common.rs | 4 +- tools/wasm-testcompile/src/lib.rs | 14 -- 17 files changed, 711 insertions(+), 451 deletions(-) create mode 100644 src/generator.rs diff --git a/Cargo.lock b/Cargo.lock index 7065294d..3ef6a97d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -89,6 +89,17 @@ dependencies = [ "winapi", ] +[[package]] +name = "async-recursion" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5fd55a5ba1179988837d24ab4c7cc8ed6efdeff578ede0416b4225a5fca35bd0" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "async-trait" version = "0.1.73" @@ -1862,6 +1873,7 @@ name = "sspi" version = "0.10.1" dependencies = [ "async-dnssd", + "async-recursion", "bitflags 2.4.0", "byteorder", "cfg-if", diff --git a/Cargo.toml b/Cargo.toml index fe3da140..464d7880 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -65,6 +65,7 @@ rustls = { version = "0.21", features = ["dangerous_configuration"], optional = zeroize = { version = "1.6", features = ["zeroize_derive"] } tokio = { version = "1.32", features = ["time", "rt"], optional = true } pcsc = { version = "2.8", optional = true } +async-recursion = "1.0.5" [target.'cfg(windows)'.dependencies] winreg = "0.51" diff --git a/examples/client.rs b/examples/client.rs index d65baccb..e33532ff 100644 --- a/examples/client.rs +++ b/examples/client.rs @@ -85,7 +85,9 @@ fn do_authentication(ntlm: &mut Ntlm, identity: &AuthIdentity, mut stream: &mut .with_target_name(username.as_str()) .with_output(&mut output_buffer); - let _result = ntlm.initialize_security_context_impl(&mut builder)?; + let _result = ntlm + .initialize_security_context_impl(&mut builder) + .resolve_to_result()?; write_message(&mut stream, &output_buffer[0].buffer)?; @@ -104,7 +106,9 @@ fn do_authentication(ntlm: &mut Ntlm, identity: &AuthIdentity, mut stream: &mut .with_input(&mut input_buffer) .with_output(&mut output_buffer); - let result = ntlm.initialize_security_context_impl(&mut builder)?; + let result = ntlm + .initialize_security_context_impl(&mut builder) + .resolve_to_result()?; if [SecurityStatus::CompleteAndContinue, SecurityStatus::CompleteNeeded].contains(&result.status) { println!("Completing the token..."); diff --git a/ffi/src/macros.rs b/ffi/src/macros.rs index e6e743c9..c8bac5df 100644 --- a/ffi/src/macros.rs +++ b/ffi/src/macros.rs @@ -1,7 +1,5 @@ macro_rules! try_execute { ($x:expr) => {{ - use num_traits::ToPrimitive; - match $x { Ok(value) => value, Err(err) => { diff --git a/ffi/src/sec_handle.rs b/ffi/src/sec_handle.rs index 8743d170..b1a22c59 100644 --- a/ffi/src/sec_handle.rs +++ b/ffi/src/sec_handle.rs @@ -10,7 +10,6 @@ use sspi::credssp::sspi_cred_ssp; use sspi::credssp::sspi_cred_ssp::SspiCredSsp; use sspi::credssp::SspiContext; use sspi::kerberos::config::KerberosConfig; -use sspi::network_client::reqwest_network_client::{RequestClientFactory, ReqwestNetworkClient}; use sspi::ntlm::NtlmConfig; use sspi::{ kerberos, negotiate, ntlm, pku2u, ClientRequestFlags, CredentialsBuffers, DataRepresentation, Error, ErrorKind, @@ -91,13 +90,9 @@ fn create_negotiate_context(attributes: &CredentialsAttributes) -> Result::default(), hostname.clone()); - let negotiate_config = NegotiateConfig::new( - Box::new(kerberos_config), - attributes.package_list.clone(), - hostname, - Box::new(RequestClientFactory), - ); + let kerberos_config = KerberosConfig::new(&kdc_url, hostname.clone()); + let negotiate_config = + NegotiateConfig::new(Box::new(kerberos_config), attributes.package_list.clone(), hostname); Negotiate::new(negotiate_config) } else { @@ -105,7 +100,6 @@ fn create_negotiate_context(attributes: &CredentialsAttributes) -> Result::default(), - hostname, + &kdc_url, hostname, ))?) } else { - let mut krb_config = KerberosConfig::from_env(); - krb_config.hostname = Some(hostname); + let krb_config = KerberosConfig { + hostname: Some(hostname), + url: None, + }; SspiContext::Kerberos(Kerberos::new_client_from_config(krb_config)?) } } @@ -380,7 +374,7 @@ pub unsafe extern "system" fn InitializeSecurityContextA( .with_target_name(service_principal) .with_input(&mut input_tokens) .with_output(&mut output_tokens); - let result_status = sspi_context.initialize_security_context_impl(&mut builder); + let result_status = sspi_context.initialize_security_context_impl(&mut builder).resolve_with_default_network_client(); let context_requirements = ClientRequestFlags::from_bits_retain(f_context_req); let allocate = context_requirements.contains(ClientRequestFlags::ALLOCATE_MEMORY); @@ -479,7 +473,7 @@ pub unsafe extern "system" fn InitializeSecurityContextW( .with_target_name(&service_principal) .with_input(&mut input_tokens) .with_output(&mut output_tokens); - let result_status = sspi_context.initialize_security_context_impl(&mut builder); + let result_status = sspi_context.initialize_security_context_impl(&mut builder).resolve_with_default_network_client(); let context_requirements = ClientRequestFlags::from_bits_retain(f_context_req); let allocate = context_requirements.contains(ClientRequestFlags::ALLOCATE_MEMORY); @@ -939,13 +933,14 @@ pub unsafe extern "system" fn ChangeAccountPasswordA( protocol_config: Box::new(NtlmConfig::new(whoami::hostname())), package_list: None, hostname: whoami::hostname(), - network_client_factory: Box::new(RequestClientFactory), }; SspiContext::Negotiate(try_execute!(Negotiate::new(negotiate_config))) }, kerberos::PKG_NAME => { - let mut krb_config = KerberosConfig::from_env(); - krb_config.hostname = Some(whoami::hostname()); + let krb_config = KerberosConfig{ + hostname:Some(whoami::hostname()), + url:None + }; SspiContext::Kerberos(try_execute!(Kerberos::new_client_from_config( krb_config ))) @@ -956,7 +951,7 @@ pub unsafe extern "system" fn ChangeAccountPasswordA( } }; - let result_status = sspi_context.change_password(change_password); + let result_status = sspi_context.change_password(change_password).resolve_with_default_network_client(); copy_to_c_sec_buffer((*p_output).p_buffers, &output_tokens, false); diff --git a/src/credssp/mod.rs b/src/credssp/mod.rs index 7545f374..d7eb939f 100644 --- a/src/credssp/mod.rs +++ b/src/credssp/mod.rs @@ -22,6 +22,9 @@ use ts_request::{NONCE_SIZE, TS_REQUEST_VERSION}; use self::sspi_cred_ssp::SspiCredSsp; use crate::builders::{ChangePassword, EmptyInitializeSecurityContext}; use crate::crypto::compute_sha256; +use crate::generator::{ + Generator, GeneratorChangePassword, GeneratorInitSecurityContext, NetworkRequest, YieldPointLocal, +}; use crate::kerberos::config::KerberosConfig; use crate::kerberos::{self, Kerberos}; use crate::ntlm::{self, Ntlm, NtlmConfig, SIGNATURE_SIZE}; @@ -228,7 +231,20 @@ impl CredSspClient { } #[instrument(fields(state = ?self.state), skip_all)] - pub fn process(&mut self, mut ts_request: TsRequest) -> crate::Result { + pub fn process<'a>( + &'a mut self, + ts_request: TsRequest, + ) -> Generator<'a, NetworkRequest, crate::Result>, crate::Result> { + Generator::<'a, NetworkRequest, crate::Result>, crate::Result>::new( + move |mut yield_point| async move { self.process_impl(&mut yield_point, ts_request).await }, + ) + } + + async fn process_impl( + &mut self, + yield_point: &mut YieldPointLocal, + mut ts_request: TsRequest, + ) -> crate::Result { ts_request.check_error()?; if let Some(ref mut context) = self.context { context.check_peer_version(ts_request.version)?; @@ -282,7 +298,8 @@ impl CredSspClient { .with_output(&mut output_token); let result = cred_ssp_context .sspi_context - .initialize_security_context_impl(&mut builder)?; + .initialize_security_context_impl(yield_point, &mut builder) + .await?; self.credentials_handle = credentials_handle; ts_request.nego_tokens = Some(output_token.remove(0).buffer); @@ -643,41 +660,6 @@ impl SspiImpl for SspiContext { }) } - #[instrument(ret, fields(security_package = self.package_name()), skip_all)] - fn initialize_security_context_impl<'a>( - &mut self, - builder: &mut FilledInitializeSecurityContext<'a, Self::CredentialsHandle>, - ) -> crate::Result { - match self { - SspiContext::Ntlm(ntlm) => { - let mut auth_identity = if let Some(Some(CredentialsBuffers::AuthIdentity(ref identity))) = - builder.credentials_handle_mut() - { - Some(identity.clone()) - } else { - None - }; - let mut new_builder = builder.full_transform(Some(&mut auth_identity)); - ntlm.initialize_security_context_impl(&mut new_builder) - } - SspiContext::Kerberos(kerberos) => kerberos.initialize_security_context_impl(builder), - SspiContext::Negotiate(negotiate) => negotiate.initialize_security_context_impl(builder), - SspiContext::Pku2u(pku2u) => { - let mut auth_identity = if let Some(Some(CredentialsBuffers::AuthIdentity(ref identity))) = - builder.credentials_handle_mut() - { - Some(identity.clone()) - } else { - None - }; - let mut new_builder = builder.full_transform(Some(&mut auth_identity)); - pku2u.initialize_security_context_impl(&mut new_builder) - } - #[cfg(feature = "tsssp")] - SspiContext::CredSsp(credssp) => credssp.initialize_security_context_impl(builder), - } - } - #[instrument(ret, fields(security_package = self.package_name()), skip_all)] fn accept_security_context_impl<'a>( &'a mut self, @@ -714,6 +696,69 @@ impl SspiImpl for SspiContext { SspiContext::CredSsp(credssp) => builder.transform(credssp).execute(), } } + + fn initialize_security_context_impl<'a>( + &'a mut self, + builder: &'a mut FilledInitializeSecurityContext<'a, Self::CredentialsHandle>, + ) -> GeneratorInitSecurityContext { + Generator::new(move |mut yield_point| async move { + self.initialize_security_context_impl(&mut yield_point, builder).await + }) + } +} + +impl<'a> SspiContext { + #[instrument(ret, fields(security_package = self.package_name()), skip_all)] + async fn change_password_impl( + &'a mut self, + yield_point: &mut YieldPointLocal, + change_password: ChangePassword<'a>, + ) -> crate::Result<()> { + match self { + SspiContext::Kerberos(kerberos) => kerberos.change_password(yield_point, change_password).await, + SspiContext::Negotiate(negotiate) => negotiate.change_password(yield_point, change_password).await, + _ => Err(crate::Error::new( + ErrorKind::UnsupportedFunction, + "Change password not supported for this protocol", + )), + } + } + + #[instrument(ret, fields(security_package = self.package_name()), skip_all)] + async fn initialize_security_context_impl( + &'a mut self, + yield_point: &mut YieldPointLocal, + builder: &'a mut FilledInitializeSecurityContext<'_, ::CredentialsHandle>, + ) -> crate::Result { + match self { + SspiContext::Ntlm(ntlm) => { + let mut auth_identity = if let Some(Some(CredentialsBuffers::AuthIdentity(ref identity))) = + builder.credentials_handle_mut() + { + Some(identity.clone()) + } else { + None + }; + let mut new_builder = builder.full_transform(Some(&mut auth_identity)); + ntlm.initialize_security_context_impl(&mut new_builder) + } + SspiContext::Kerberos(kerberos) => kerberos.initialize_security_context_impl(yield_point, builder).await, + SspiContext::Negotiate(negotiate) => negotiate.initialize_security_context_impl(yield_point, builder).await, + SspiContext::Pku2u(pku2u) => { + let mut auth_identity = if let Some(Some(CredentialsBuffers::AuthIdentity(ref identity))) = + builder.credentials_handle_mut() + { + Some(identity.clone()) + } else { + None + }; + let mut new_builder = builder.full_transform(Some(&mut auth_identity)); + pku2u.initialize_security_context_impl(&mut new_builder) + } + #[cfg(feature = "tsssp")] + SspiContext::CredSsp(credssp) => credssp.initialize_security_context_impl(yield_point, builder).await, + } + } } impl Sspi for SspiContext { @@ -858,16 +903,10 @@ impl Sspi for SspiContext { } } - #[instrument(ret, fields(security_package = self.package_name()), skip_all)] - fn change_password(&mut self, change_password: ChangePassword) -> crate::Result<()> { - match self { - SspiContext::Ntlm(ntlm) => ntlm.change_password(change_password), - SspiContext::Kerberos(kerberos) => kerberos.change_password(change_password), - SspiContext::Negotiate(negotiate) => negotiate.change_password(change_password), - SspiContext::Pku2u(pku2u) => pku2u.change_password(change_password), - #[cfg(feature = "tsssp")] - SspiContext::CredSsp(credssp) => credssp.change_password(change_password), - } + fn change_password<'a>(&'a mut self, change_password: ChangePassword<'a>) -> GeneratorChangePassword { + GeneratorChangePassword::new(move |mut yield_point| async move { + self.change_password_impl(&mut yield_point, change_password).await + }) } } diff --git a/src/credssp/sspi_cred_ssp/mod.rs b/src/credssp/sspi_cred_ssp/mod.rs index b57bbda3..7cc64ce6 100644 --- a/src/credssp/sspi_cred_ssp/mod.rs +++ b/src/credssp/sspi_cred_ssp/mod.rs @@ -12,6 +12,7 @@ use self::tls_connection::{danger, TlsConnection, TLS_PACKET_HEADER_LEN}; use super::ts_request::NONCE_SIZE; use super::{CredSspContext, CredSspMode, EndpointType, SspiContext, TsRequest}; use crate::builders::EmptyInitializeSecurityContext; +use crate::generator::{GeneratorChangePassword, GeneratorInitSecurityContext, YieldPointLocal}; use crate::{ builders, negotiate, AcquireCredentialsHandleResult, CertContext, CertEncodingType, CertTrustErrorStatus, CertTrustInfoStatus, CertTrustStatus, ClientRequestFlags, ClientResponseFlags, ConnectionInfo, ContextNames, @@ -19,6 +20,7 @@ use crate::{ Error, ErrorKind, InitializeSecurityContextResult, PackageCapabilities, PackageInfo, Result, SecurityBuffer, SecurityBufferType, SecurityPackageType, SecurityStatus, Sspi, SspiEx, SspiImpl, StreamSizes, PACKAGE_ID_NONE, }; +use async_recursion::async_recursion; pub const PKG_NAME: &str = "CREDSSP"; @@ -259,11 +261,12 @@ impl Sspi for SspiCredSsp { } #[instrument(level = "debug", ret, fields(state = ?self.state), skip_all)] - fn change_password(&mut self, _change_password: builders::ChangePassword) -> Result<()> { + fn change_password(&mut self, _change_password: builders::ChangePassword) -> GeneratorChangePassword { Err(Error::new( ErrorKind::UnsupportedFunction, "ChangePassword is not supported in SspiCredSsp context", )) + .into() } } @@ -295,10 +298,34 @@ impl SspiImpl for SspiCredSsp { }) } - #[instrument(ret, fields(state = ?self.state), skip_all)] fn initialize_security_context_impl<'a>( + &'a mut self, + builder: &'a mut builders::FilledInitializeSecurityContext<'a, Self::CredentialsHandle>, + ) -> GeneratorInitSecurityContext<'a> { + GeneratorInitSecurityContext::new(move |mut yield_point| async move { + self.initialize_security_context_impl(&mut yield_point, builder).await + }) + } + + #[instrument(level = "debug", ret, fields(state = ?self.state), skip(self, _builder))] + fn accept_security_context_impl<'a>( + &'a mut self, + _builder: builders::FilledAcceptSecurityContext<'a, Self::AuthenticationData, Self::CredentialsHandle>, + ) -> Result { + Err(Error::new( + ErrorKind::UnsupportedFunction, + "AcceptSecurityContext is not supported in SspiCredSsp context", + )) + } +} + +impl SspiCredSsp { + #[instrument(ret, fields(state = ?self.state), skip_all)] + #[async_recursion] + pub(crate) async fn initialize_security_context_impl<'a>( &mut self, - builder: &mut builders::FilledInitializeSecurityContext<'a, Self::CredentialsHandle>, + yield_point: &mut YieldPointLocal, + builder: &mut builders::FilledInitializeSecurityContext<'a, ::CredentialsHandle>, ) -> Result { trace!(?builder); // In the CredSSP we always set DELEGATE flag @@ -327,7 +354,7 @@ impl SspiImpl for SspiCredSsp { // delete the previous TLS message builder.input = None; - return self.initialize_security_context_impl(builder); + return self.initialize_security_context_impl(yield_point, builder).await; } let output_token = SecurityBuffer::find_buffer_mut(builder.output, SecurityBufferType::Token)?; @@ -369,7 +396,8 @@ impl SspiImpl for SspiCredSsp { let result = self .cred_ssp_context .sspi_context - .initialize_security_context_impl(&mut inner_builder)?; + .initialize_security_context_impl(yield_point, &mut inner_builder) + .await?; ts_request.nego_tokens = Some(output_token.remove(0).buffer); @@ -470,17 +498,6 @@ impl SspiImpl for SspiCredSsp { expiry: None, }) } - - #[instrument(level = "debug", ret, fields(state = ?self.state), skip(self, _builder))] - fn accept_security_context_impl<'a>( - &'a mut self, - _builder: builders::FilledAcceptSecurityContext<'a, Self::AuthenticationData, Self::CredentialsHandle>, - ) -> Result { - Err(Error::new( - ErrorKind::UnsupportedFunction, - "AcceptSecurityContext is not supported in SspiCredSsp context", - )) - } } impl SspiEx for SspiCredSsp { diff --git a/src/generator.rs b/src/generator.rs new file mode 100644 index 00000000..f1e95c8f --- /dev/null +++ b/src/generator.rs @@ -0,0 +1,225 @@ +use std::future::Future; +use std::pin::Pin; +use std::sync::{Arc, Mutex}; +use std::task::{Context, Poll, Wake, Waker}; + +use url::Url; + +use crate::network_client::{NetworkClient, NetworkProtocol}; +use crate::{Error, InitializeSecurityContextResult}; + +pub struct Interrupt { + value_to_yield: Option, + yielded_value: YieldedValue, + resumed_value: ResumedValue, + ready_to_resume: bool, +} + +impl Future for Interrupt +where + YieldTy: Unpin, +{ + type Output = ResumeTy; + + fn poll(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll { + let this = self.get_mut(); + + if this.ready_to_resume { + let resumed_value = this.resumed_value.try_lock().unwrap().take().unwrap(); + Poll::Ready(resumed_value) + } else { + let value_to_yield = this.value_to_yield.take().unwrap(); + *this.yielded_value.try_lock().unwrap() = Some(value_to_yield); + this.ready_to_resume = true; + Poll::Pending + } + } +} + +#[derive(Debug)] +pub struct YieldPoint { + yielded_value: YieldedValue, + resumed_value: ResumedValue, +} + +impl Clone for YieldPoint { + fn clone(&self) -> Self { + Self { + yielded_value: self.yielded_value.clone(), + resumed_value: self.resumed_value.clone(), + } + } +} + +impl<'point, YieldTy, ResumeTy> YieldPoint { + pub fn suspend(&'point mut self, value: YieldTy) -> Interrupt { + Interrupt { + value_to_yield: Some(value), + yielded_value: Arc::clone(&self.yielded_value), + resumed_value: Arc::clone(&self.resumed_value), + ready_to_resume: false, + } + } +} + +type YieldedValue = Arc>>; +type ResumedValue = Arc>>; +type PinnedFuture<'a, T> = Pin + Send + 'a>>; + +pub enum GeneratorState { + Suspended(YieldTy), + Completed(OutTy), +} + +pub struct Generator<'a, YieldTy, ResumeTy, OutTy> { + yielded_value: YieldedValue, + resumed_value: ResumedValue, + generator: PinnedFuture<'a, OutTy>, +} + +impl<'a, YieldTy, ResumeTy, OutTy> Generator<'a, YieldTy, ResumeTy, OutTy> +where + OutTy: Send + 'a, +{ + pub fn new(producer: Producer) -> Self + where + Producer: FnOnce(YieldPoint) -> Task, + Task: Future + Send + 'a, + { + let yielded_value = Arc::new(Mutex::new(None)); + let resumed_value = Arc::new(Mutex::new(None)); + + let yield_point = YieldPoint { + yielded_value: Arc::clone(&yielded_value), + resumed_value: Arc::clone(&resumed_value), + }; + Self { + yielded_value, + resumed_value, + generator: Box::pin(producer(yield_point)), + } + } + + pub fn start(&mut self) -> GeneratorState { + self.step() + } + + pub fn resume(&mut self, value: ResumeTy) -> GeneratorState { + *self.resumed_value.try_lock().unwrap() = Some(value); + self.step() + } + + fn step(&mut self) -> GeneratorState { + match execute_one_step(&mut self.generator) { + None => { + let value = self.yielded_value.try_lock().unwrap().take().unwrap(); + GeneratorState::Suspended(value) + } + Some(value) => GeneratorState::Completed(value), + } + } +} + +fn execute_one_step(task: &mut PinnedFuture) -> Option { + struct NoopWake; + + impl Wake for NoopWake { + fn wake(self: std::sync::Arc) { + // do nothing + } + } + + let waker = Waker::from(Arc::new(NoopWake)); + let mut context = Context::from_waker(&waker); + + match task.as_mut().poll(&mut context) { + Poll::Pending => None, + Poll::Ready(item) => Some(item), + } +} + +/// Utility types and methods +impl<'a, YieldTy, ResumeTy, OutTy> Generator<'a, YieldTy, ResumeTy, Result> +where + OutTy: Send + 'a, +{ + pub fn resolve_to_result(&mut self) -> Result { + let state = self.start(); + match state { + GeneratorState::Suspended(_) => Err(Error::new( + crate::ErrorKind::UnsupportedFunction, + "cannot finish generator", + )), + GeneratorState::Completed(res) => res, + } + } + + pub fn unwrap(&mut self) -> OutTy { + self.resolve_to_result().unwrap() + } + + pub fn expect(&mut self, msg: &str) -> OutTy { + self.resolve_to_result().expect(msg) + } +} +#[derive(Debug)] +pub struct NetworkRequest { + pub protocol: NetworkProtocol, + pub url: Url, + pub data: Vec, // avoid life time problem, suspend requires 'static life time +} + +impl<'a, YieldTy, ResumeTy, OutTy> std::fmt::Debug for Generator<'a, YieldTy, ResumeTy, OutTy> +where + YieldTy: std::fmt::Debug, + ResumeTy: std::fmt::Debug, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Generator") + .field("yielded_value", &self.yielded_value) + .field("resumed_value", &self.resumed_value) + .finish() + } +} +pub type GeneratorInitSecurityContext<'a> = + Generator<'a, NetworkRequest, crate::Result>, crate::Result>; + +pub type GeneratorChangePassword<'a> = Generator<'a, NetworkRequest, crate::Result>, crate::Result<()>>; + +pub(crate) type YieldPointLocal = YieldPoint>>; + +impl<'a, YieldType, ResumeType, OutType, ErrorType> From> + for Generator<'a, YieldType, ResumeType, Result> +where + OutType: Send + 'a, + ErrorType: Send + 'a, +{ + fn from(value: Result) -> Self { + Generator::new(move |_| async move { value }) + } +} +/// Utilitiies for working with network client +impl<'a, OutTy> Generator<'a, NetworkRequest, crate::Result>, OutTy> +where + OutTy: 'a + Send, +{ + #[cfg(feature = "network_client")] + pub fn resolve_with_default_network_client(&mut self) -> OutTy { + let network_client = crate::network_client::reqwest_network_client::ReqwestNetworkClient; + self.resolve_with_client(&network_client) + } + + pub fn resolve_with_client(&mut self, network_client: &dyn NetworkClient) -> OutTy { + let mut state = self.start(); + loop { + match state { + GeneratorState::Suspended(ref request) => { + state = self.resume(crate::network_client::NetworkClient::send(network_client, request)); + } + GeneratorState::Completed(res) => { + return res; + } + } + } + } +} diff --git a/src/kerberos/config.rs b/src/kerberos/config.rs index bcc09712..594a91f9 100644 --- a/src/kerberos/config.rs +++ b/src/kerberos/config.rs @@ -5,12 +5,10 @@ use url::Url; use crate::kdc::detect_kdc_url; use crate::negotiate::{NegotiatedProtocol, ProtocolConfig}; -use crate::network_client::NetworkClient; use crate::{Kerberos, Result}; pub struct KerberosConfig { pub url: Option, - pub network_client: Box, pub hostname: Option, } @@ -42,12 +40,11 @@ pub fn parse_kdc_url(mut kdc: String) -> Option { } impl KerberosConfig { - pub fn new(url: &str, network_client: Box, hostname: String) -> Self { + pub fn new(url: &str, hostname: String) -> Self { let kdc_url = parse_kdc_url(url.to_owned()); Self { url: kdc_url, - network_client, hostname: Some(hostname), } } @@ -60,42 +57,20 @@ impl KerberosConfig { } } - pub fn new_with_network_client(network_client: Box) -> Self { - Self { - url: None, - network_client, - hostname: None, - } - } - - #[cfg(feature = "network_client")] - pub fn from_env() -> Self { - use crate::network_client::reqwest_network_client::ReqwestNetworkClient; - - Self::new_with_network_client(Box::::default()) - } - - pub fn from_kdc_url(url: &str, network_client: Box) -> Self { + pub fn from_kdc_url(url: &str) -> Self { let kdc_url = parse_kdc_url(url.to_owned()); Self { url: kdc_url, - network_client, hostname: None, } } - - #[cfg(not(feature = "network_client"))] - pub fn from_env(network_client: Box) -> Self { - Self::new_with_network_client(network_client) - } } impl Clone for KerberosConfig { fn clone(&self) -> Self { Self { url: self.url.clone(), - network_client: self.network_client.box_clone(), hostname: self.hostname.clone(), } } diff --git a/src/kerberos/mod.rs b/src/kerberos/mod.rs index bddfbd2e..f6a590f2 100644 --- a/src/kerberos/mod.rs +++ b/src/kerberos/mod.rs @@ -41,6 +41,7 @@ use self::server::extractors::extract_tgt_ticket; use self::utils::{serialize_message, unwrap_hostname}; use super::channel_bindings::ChannelBindings; use crate::builders::ChangePassword; +use crate::generator::{GeneratorChangePassword, GeneratorInitSecurityContext, NetworkRequest, YieldPointLocal}; use crate::kerberos::client::extractors::{extract_salt_from_krb_error, extract_status_code_from_krb_priv_response}; use crate::kerberos::client::generators::{ generate_authenticator, generate_final_neg_token_targ, get_mech_list, GenerateTgsReqOptions, @@ -173,7 +174,7 @@ impl Kerberos { } } - fn send(&self, data: &[u8]) -> Result> { + async fn send<'data>(&self, yield_point: &mut YieldPointLocal, data: &'data [u8]) -> Result> { if let Some((realm, kdc_url)) = self.get_kdc() { let protocol = NetworkProtocol::from_url_scheme(kdc_url.scheme()).ok_or_else(|| { Error::new( @@ -182,20 +183,15 @@ impl Kerberos { ) })?; - if !self.config.network_client.is_protocol_supported(protocol) { - return Err(Error::new( - ErrorKind::InvalidParameter, - format!( - "Network protocol `{}` is not supported by `{}` network client. Supported protocols are: {:?}", - kdc_url.scheme(), - self.config.network_client.name(), - self.config.network_client.supported_protocols(), - ), - )); - } - return match protocol { - NetworkProtocol::Tcp => self.config.network_client.send(protocol, kdc_url, data), + NetworkProtocol::Tcp => { + let request = NetworkRequest { + protocol, + url: kdc_url.clone(), + data: data.to_vec(), + }; + yield_point.suspend(request).await + } NetworkProtocol::Udp => { if data.len() < 4 { return Err(Error::new( @@ -208,7 +204,12 @@ impl Kerberos { } // First 4 bytes are message length and it’s not included when using UDP - self.config.network_client.send(protocol, kdc_url, &data[4..]) + let request = NetworkRequest { + protocol, + url: kdc_url.clone(), + data: data[4..].to_vec(), + }; + yield_point.suspend(request).await } NetworkProtocol::Http | NetworkProtocol::Https => { let data = OctetStringAsn1::from(data.to_vec()); @@ -221,7 +222,12 @@ impl Kerberos { }; let message_request = picky_asn1_der::to_vec(&kdc_proxy_message)?; - let result_bytes = self.config.network_client.send(protocol, kdc_url, &message_request)?; + let request = NetworkRequest { + protocol, + url: kdc_url, + data: message_request, + }; + let result_bytes = yield_point.suspend(request).await?; let message_response: KdcProxyMessage = picky_asn1_der::from_bytes(&result_bytes)?; Ok(message_response.kerb_message.0 .0) } @@ -230,30 +236,39 @@ impl Kerberos { Err(Error::new(ErrorKind::NoAuthenticatingAuthority, "No KDC server found")) } - pub fn as_exchange(&mut self, kdc_req_body: &KdcReqBody, mut pa_data_options: AsReqPaDataOptions) -> Result { + pub async fn as_exchange( + &mut self, + yield_point: &mut YieldPointLocal, + kdc_req_body: &KdcReqBody, + mut pa_data_options: AsReqPaDataOptions<'_>, + ) -> Result { pa_data_options.with_pre_auth(false); let pa_datas = pa_data_options.generate()?; let as_req = generate_as_req(pa_datas, kdc_req_body.clone()); - let response = self.send(&serialize_message(&as_req)?)?; + let response = self.send(yield_point, &serialize_message(&as_req)?).await?; // first 4 bytes are message len. skipping them - let mut d = picky_asn1_der::Deserializer::new_from_bytes(&response[4..]); - let as_rep: KrbResult = KrbResult::deserialize(&mut d)?; + { + let mut d = picky_asn1_der::Deserializer::new_from_bytes(&response[4..]); + let as_rep: KrbResult = KrbResult::deserialize(&mut d)?; - if as_rep.is_ok() { - error!("KDC replied with AS_REP to the AS_REQ without the encrypted timestamp. The KRB_ERROR expected."); + if as_rep.is_ok() { + error!( + "KDC replied with AS_REP to the AS_REQ without the encrypted timestamp. The KRB_ERROR expected." + ); - return Err(Error::new( - ErrorKind::InvalidToken, - "KDC server should not process AS_REQ without the pa-pac data", - )); - } + return Err(Error::new( + ErrorKind::InvalidToken, + "KDC server should not process AS_REQ without the pa-pac data", + )); + } - if let Some(correct_salt) = extract_salt_from_krb_error(&as_rep.unwrap_err())? { - debug!("salt extracted successfully from the KRB_ERROR"); + if let Some(correct_salt) = extract_salt_from_krb_error(&as_rep.unwrap_err())? { + debug!("salt extracted successfully from the KRB_ERROR"); - pa_data_options.with_salt(correct_salt.as_bytes().to_vec()); + pa_data_options.with_salt(correct_salt.as_bytes().to_vec()); + } } pa_data_options.with_pre_auth(true); @@ -261,7 +276,7 @@ impl Kerberos { let as_req = generate_as_req(pa_datas, kdc_req_body.clone()); - let response = self.send(&serialize_message(&as_req)?)?; + let response = self.send(yield_point, &serialize_message(&as_req)?).await?; // first 4 bytes are message len. skipping them let mut d = picky_asn1_der::Deserializer::new_from_bytes(&response[4..]); @@ -440,8 +455,94 @@ impl Sspi for Kerberos { )) } + fn change_password<'a>(&'a mut self, change_password: ChangePassword<'a>) -> GeneratorChangePassword { + GeneratorChangePassword::new(move |mut yield_point| async move { + self.change_password(&mut yield_point, change_password).await + }) + } +} + +impl SspiImpl for Kerberos { + type CredentialsHandle = Option; + + type AuthenticationData = Credentials; + + #[instrument(level = "trace", ret, fields(state = ?self.state), skip(self))] + fn acquire_credentials_handle_impl( + &mut self, + builder: crate::builders::FilledAcquireCredentialsHandle<'_, Self::CredentialsHandle, Self::AuthenticationData>, + ) -> Result> { + if builder.credential_use == CredentialUse::Outbound && builder.auth_data.is_none() { + return Err(Error::new( + ErrorKind::NoCredentials, + String::from("The client must specify the auth data"), + )); + } + + self.auth_identity = builder + .auth_data + .cloned() + .map(|auth_data| auth_data.try_into()) + .transpose()?; + + Ok(AcquireCredentialsHandleResult { + credentials_handle: self.auth_identity.clone(), + expiry: None, + }) + } + + #[instrument(level = "debug", ret, fields(state = ?self.state), skip(self, builder))] + fn accept_security_context_impl( + &mut self, + builder: crate::builders::FilledAcceptSecurityContext<'_, Self::AuthenticationData, Self::CredentialsHandle>, + ) -> Result { + let input = builder + .input + .ok_or_else(|| crate::Error::new(ErrorKind::InvalidToken, "Input buffers must be specified"))?; + + let status = match &self.state { + KerberosState::ApExchange => { + let input_token = SecurityBuffer::find_buffer(input, SecurityBufferType::Token)?; + + let _ap_req: ApReq = picky_asn1_der::from_bytes(&input_token.buffer) + .map_err(|e| Error::new(ErrorKind::DecryptFailure, format!("{:?}", e)))?; + + self.state = KerberosState::Final; + + SecurityStatus::Ok + } + state => { + return Err(Error::new( + ErrorKind::OutOfSequence, + format!("Got wrong Kerberos state: {:?}", state), + )) + } + }; + + Ok(AcceptSecurityContextResult { + status, + flags: ServerResponseFlags::empty(), + expiry: None, + }) + } + + fn initialize_security_context_impl<'a>( + &'a mut self, + builder: &'a mut crate::builders::FilledInitializeSecurityContext, + ) -> GeneratorInitSecurityContext { + GeneratorInitSecurityContext::new(move |mut yield_point| async move { + self.initialize_security_context_impl(&mut yield_point, builder).await + }) + } +} + +impl<'a> Kerberos { #[instrument(level = "debug", ret, fields(state = ?self.state), skip(self, change_password))] - fn change_password(&mut self, change_password: ChangePassword) -> Result<()> { + pub async fn change_password( + &'a mut self, + yield_point: &mut YieldPointLocal, + change_password: ChangePassword<'a>, + ) -> Result<()> { let username = &change_password.account_name; let domain = &change_password.domain_name; let password = &change_password.old_password; @@ -471,7 +572,7 @@ impl Sspi for Kerberos { with_pre_auth: false, }); - let as_rep = self.as_exchange(&kdc_req_body, pa_data_options)?; + let as_rep = self.as_exchange(yield_point, &kdc_req_body, pa_data_options).await?; info!("AS exchange finished successfully."); @@ -520,7 +621,7 @@ impl Sspi for Kerberos { .set_port(Some(KPASSWD_PORT)) .map_err(|_| Error::new(ErrorKind::InvalidParameter, "Cannot set port for KDC URL"))?; - let response = self.send(&serialize_message(&krb_priv)?)?; + let response = self.send(yield_point, &serialize_message(&krb_priv)?).await?; trace!(?response, "Change password raw response"); let krb_priv_response = KrbPrivMessage::deserialize(&response[4..]).map_err(|err| { @@ -551,41 +652,11 @@ impl Sspi for Kerberos { Ok(()) } -} -impl SspiImpl for Kerberos { - type CredentialsHandle = Option; - - type AuthenticationData = Credentials; - - #[instrument(level = "trace", ret, fields(state = ?self.state), skip(self))] - fn acquire_credentials_handle_impl( - &mut self, - builder: crate::builders::FilledAcquireCredentialsHandle<'_, Self::CredentialsHandle, Self::AuthenticationData>, - ) -> Result> { - if builder.credential_use == CredentialUse::Outbound && builder.auth_data.is_none() { - return Err(Error::new( - ErrorKind::NoCredentials, - String::from("The client must specify the auth data"), - )); - } - - self.auth_identity = builder - .auth_data - .cloned() - .map(|auth_data| auth_data.try_into()) - .transpose()?; - - Ok(AcquireCredentialsHandleResult { - credentials_handle: self.auth_identity.clone(), - expiry: None, - }) - } - - #[instrument(level = "debug", ret, fields(state = ?self.state), skip(self, builder))] - fn initialize_security_context_impl( - &mut self, - builder: &mut crate::builders::FilledInitializeSecurityContext<'_, Self::CredentialsHandle>, + pub async fn initialize_security_context_impl( + &'a mut self, + yield_point: &mut YieldPointLocal, + builder: &'a mut crate::builders::FilledInitializeSecurityContext<'_, ::CredentialsHandle>, ) -> Result { trace!(?builder); @@ -723,7 +794,7 @@ impl SspiImpl for Kerberos { } }; - let as_rep = self.as_exchange(&kdc_req_body, pa_data_options)?; + let as_rep = self.as_exchange(yield_point, &kdc_req_body, pa_data_options).await?; info!("AS exchange finished successfully."); @@ -776,7 +847,7 @@ impl SspiImpl for Kerberos { context_requirements: builder.context_requirements, })?; - let response = self.send(&serialize_message(&tgs_req)?)?; + let response = self.send(yield_point, &serialize_message(&tgs_req)?).await?; // first 4 bytes are message len. skipping them let mut d = picky_asn1_der::Deserializer::new_from_bytes(&response[4..]); @@ -899,41 +970,6 @@ impl SspiImpl for Kerberos { expiry: None, }) } - - #[instrument(level = "debug", ret, fields(state = ?self.state), skip(self, builder))] - fn accept_security_context_impl( - &mut self, - builder: crate::builders::FilledAcceptSecurityContext<'_, Self::AuthenticationData, Self::CredentialsHandle>, - ) -> Result { - let input = builder - .input - .ok_or_else(|| crate::Error::new(ErrorKind::InvalidToken, "Input buffers must be specified"))?; - - let status = match &self.state { - KerberosState::ApExchange => { - let input_token = SecurityBuffer::find_buffer(input, SecurityBufferType::Token)?; - - let _ap_req: ApReq = picky_asn1_der::from_bytes(&input_token.buffer) - .map_err(|e| Error::new(ErrorKind::DecryptFailure, format!("{:?}", e)))?; - - self.state = KerberosState::Final; - - SecurityStatus::Ok - } - state => { - return Err(Error::new( - ErrorKind::OutOfSequence, - format!("Got wrong Kerberos state: {:?}", state), - )) - } - }; - - Ok(AcceptSecurityContextResult { - status, - flags: ServerResponseFlags::empty(), - expiry: None, - }) - } } impl SspiEx for Kerberos { @@ -951,27 +987,16 @@ mod tests { use picky_krb::crypto::CipherSuite; use super::EncryptionParams; - use crate::network_client::{NetworkClient, NetworkProtocol}; + use crate::generator::NetworkRequest; + use crate::network_client::NetworkClient; use crate::{EncryptionFlags, Kerberos, KerberosConfig, KerberosState, SecurityBuffer, SecurityBufferType, Sspi}; struct NetworkClientMock; impl NetworkClient for NetworkClientMock { - fn send(&self, _protocol: NetworkProtocol, _url: url::Url, _data: &[u8]) -> crate::Result> { + fn send(&self, _request: &NetworkRequest) -> crate::Result> { unreachable!("unsupported protocol") } - - fn box_clone(&self) -> Box { - Box::new(Self) - } - - fn name(&self) -> &'static str { - "Mock" - } - - fn supported_protocols(&self) -> &[crate::network_client::NetworkProtocol] { - &[] - } } #[test] @@ -991,7 +1016,6 @@ mod tests { state: KerberosState::Final, config: KerberosConfig { url: None, - network_client: Box::new(NetworkClientMock), hostname: None, }, auth_identity: None, @@ -1014,7 +1038,6 @@ mod tests { state: KerberosState::Final, config: KerberosConfig { url: None, - network_client: Box::new(NetworkClientMock), hostname: None, }, auth_identity: None, diff --git a/src/lib.rs b/src/lib.rs index f471e770..7beda24c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -55,6 +55,7 @@ extern crate tracing; pub mod builders; pub mod channel_bindings; pub mod credssp; +pub mod generator; pub mod kerberos; pub mod negotiate; pub mod network_client; @@ -88,6 +89,7 @@ use std::{error, fmt, io, result, str, string}; use bitflags::bitflags; #[cfg(feature = "tsssp")] use credssp::sspi_cred_ssp; +use generator::{GeneratorChangePassword, GeneratorInitSecurityContext}; use num_derive::{FromPrimitive, ToPrimitive}; use picky_asn1::restricted_string::CharSetError; use picky_asn1_der::Asn1DerError; @@ -878,7 +880,7 @@ where /// # MSDN /// /// * [ChangeAccountPasswordW function](https://docs.microsoft.com/en-us/windows/win32/api/sspi/nf-sspi-changeaccountpasswordw) - fn change_password(&mut self, change_password: ChangePassword) -> Result<()>; + fn change_password<'a>(&'a mut self, change_password: ChangePassword<'a>) -> GeneratorChangePassword; } /// Protocol used to establish connection. @@ -993,10 +995,10 @@ pub trait SspiImpl { builder: FilledAcquireCredentialsHandle<'a, Self::CredentialsHandle, Self::AuthenticationData>, ) -> Result>; - fn initialize_security_context_impl( - &mut self, - builder: &mut FilledInitializeSecurityContext, - ) -> Result; + fn initialize_security_context_impl<'a>( + &'a mut self, + builder: &'a mut FilledInitializeSecurityContext<'a, Self::CredentialsHandle>, + ) -> GeneratorInitSecurityContext; fn accept_security_context_impl<'a>( &'a mut self, diff --git a/src/negotiate.rs b/src/negotiate.rs index 6f99529f..70825604 100644 --- a/src/negotiate.rs +++ b/src/negotiate.rs @@ -3,19 +3,18 @@ use std::net::IpAddr; use lazy_static::lazy_static; +use crate::generator::{GeneratorChangePassword, GeneratorInitSecurityContext, YieldPointLocal}; use crate::kdc::detect_kdc_url; use crate::kerberos::client::generators::get_client_principal_realm; -use crate::network_client::NetworkClientFactory; use crate::ntlm::NtlmConfig; #[allow(unused)] use crate::utils::is_azure_ad_domain; -#[cfg(feature = "network_client")] -use crate::KerberosConfig; use crate::{ builders, kerberos, ntlm, pku2u, AcceptSecurityContextResult, AcquireCredentialsHandleResult, AuthIdentity, CertTrustStatus, ContextNames, ContextSizes, CredentialUse, Credentials, CredentialsBuffers, DecryptionFlags, - Error, ErrorKind, InitializeSecurityContextResult, Kerberos, Ntlm, PackageCapabilities, PackageInfo, Pku2u, Result, - SecurityBuffer, SecurityPackageType, SecurityStatus, Sspi, SspiEx, SspiImpl, PACKAGE_ID_NONE, + Error, ErrorKind, InitializeSecurityContextResult, Kerberos, KerberosConfig, Ntlm, PackageCapabilities, + PackageInfo, Pku2u, Result, SecurityBuffer, SecurityPackageType, SecurityStatus, Sspi, SspiEx, SspiImpl, + PACKAGE_ID_NONE, }; pub const PKG_NAME: &str = "Negotiate"; @@ -40,34 +39,27 @@ pub struct NegotiateConfig { pub protocol_config: Box, pub package_list: Option, pub hostname: String, - pub network_client_factory: Box, } impl NegotiateConfig { + /// package_list format, "kerberos,ntlm,pku2u" pub fn new( protocol_config: Box, package_list: Option, hostname: String, - network_client_factory: Box, ) -> Self { Self { protocol_config, package_list, hostname, - network_client_factory, } } - pub fn from_protocol_config( - protocol_config: Box, - hostname: String, - network_client_factory: Box, - ) -> Self { + pub fn from_protocol_config(protocol_config: Box, hostname: String) -> Self { Self { protocol_config, package_list: None, hostname, - network_client_factory, } } } @@ -78,7 +70,6 @@ impl Clone for NegotiateConfig { protocol_config: self.protocol_config.clone(), package_list: None, hostname: self.hostname.clone(), - network_client_factory: self.network_client_factory.box_clone(), } } } @@ -107,7 +98,6 @@ pub struct Negotiate { package_list: Option, auth_identity: Option, hostname: String, - network_client_factory: Box, } impl Clone for Negotiate { @@ -117,7 +107,6 @@ impl Clone for Negotiate { package_list: self.package_list.clone(), auth_identity: self.auth_identity.clone(), hostname: self.hostname.clone(), - network_client_factory: self.network_client_factory.box_clone(), } } } @@ -140,7 +129,6 @@ impl Negotiate { package_list: config.package_list, auth_identity: None, hostname: config.hostname, - network_client_factory: config.network_client_factory, }) } @@ -171,7 +159,6 @@ impl Negotiate { self.protocol = NegotiatedProtocol::Kerberos(Kerberos::new_client_from_config(crate::KerberosConfig { url: Some(host), - network_client: self.network_client_factory.network_client(), hostname: Some(self.hostname.clone()), })?); } @@ -213,7 +200,7 @@ impl Negotiate { fn filter_protocol( negotiated_protocol: &NegotiatedProtocol, package_list: &Option, - #[allow(unused_variables)] hostname: &str, // Unused if `network_client` feature is disabled + hostname: &str, ) -> Result> { let mut filtered_protocol = None; let PackageListConfig { @@ -241,17 +228,11 @@ impl Negotiate { } } NegotiatedProtocol::Ntlm(_) => { - #[cfg(not(feature = "network_client"))] if !is_ntlm { - return Err(Error::new( - ErrorKind::InvalidParameter, - "Can not initialize Kerberos: network client is not provided", - )); - } - #[cfg(feature = "network_client")] - if !is_ntlm { - let mut config = KerberosConfig::from_env(); - config.hostname = Some(hostname.to_owned()); + let config = KerberosConfig { + hostname: Some(hostname.to_owned()), + url: None, + }; let kerberos_client = Kerberos::new_client_from_config(config)?; filtered_protocol = Some(NegotiatedProtocol::Kerberos(kerberos_client)); @@ -388,15 +369,10 @@ impl Sspi for Negotiate { } } - #[instrument(ret, fields(protocol = self.protocol.protocol_name()), skip_all)] - fn change_password(&mut self, change_password: builders::ChangePassword) -> Result<()> { - self.negotiate_protocol(&change_password.account_name, &change_password.domain_name)?; - - match &mut self.protocol { - NegotiatedProtocol::Pku2u(pku2u) => pku2u.change_password(change_password), - NegotiatedProtocol::Kerberos(kerberos) => kerberos.change_password(change_password), - NegotiatedProtocol::Ntlm(ntlm) => ntlm.change_password(change_password), - } + fn change_password<'a>(&'a mut self, change_password: builders::ChangePassword<'a>) -> GeneratorChangePassword { + GeneratorChangePassword::new(move |mut yield_point| async move { + self.change_password(&mut yield_point, change_password).await + }) } } @@ -463,9 +439,66 @@ impl SspiImpl for Negotiate { } #[instrument(ret, fields(protocol = self.protocol.protocol_name()), skip_all)] + fn accept_security_context_impl<'a>( + &'a mut self, + builder: builders::FilledAcceptSecurityContext<'a, Self::AuthenticationData, Self::CredentialsHandle>, + ) -> Result { + match &mut self.protocol { + NegotiatedProtocol::Pku2u(pku2u) => { + let mut creds_handle = if let Some(creds_handle) = &builder.credentials_handle { + creds_handle.as_ref().and_then(|c| c.clone().auth_identity()) + } else { + None + }; + let new_builder = builder.full_transform(pku2u, Some(&mut creds_handle)); + new_builder.execute() + } + NegotiatedProtocol::Kerberos(kerberos) => kerberos.accept_security_context_impl(builder), + NegotiatedProtocol::Ntlm(ntlm) => { + let mut creds_handle = if let Some(creds_handle) = &builder.credentials_handle { + creds_handle.as_ref().and_then(|c| c.clone().auth_identity()) + } else { + None + }; + let new_builder = builder.full_transform(ntlm, Some(&mut creds_handle)); + new_builder.execute() + } + } + } + fn initialize_security_context_impl<'a>( - &mut self, - builder: &mut builders::FilledInitializeSecurityContext<'a, Self::CredentialsHandle>, + &'a mut self, + builder: &'a mut builders::FilledInitializeSecurityContext, + ) -> GeneratorInitSecurityContext { + GeneratorInitSecurityContext::new(move |mut yield_point| async move { + self.initialize_security_context_impl(&mut yield_point, builder).await + }) + } +} + +impl<'a> Negotiate { + #[instrument(ret, fields(protocol = self.protocol.protocol_name()), skip_all)] + pub(crate) async fn change_password( + &'a mut self, + yield_point: &mut YieldPointLocal, + change_password: builders::ChangePassword<'a>, + ) -> Result<()> { + self.negotiate_protocol(&change_password.account_name, &change_password.domain_name)?; + + match &mut self.protocol { + NegotiatedProtocol::Kerberos(kerberos) => kerberos.change_password(yield_point, change_password).await, + _ => Err(crate::Error::new( + ErrorKind::UnsupportedFunction, + "cannot change password for this protocol", + )), + } + } + + #[instrument(ret, fields(protocol = self.protocol.protocol_name()), skip_all)] + pub(crate) async fn initialize_security_context_impl( + &'a mut self, + yield_point: &mut YieldPointLocal, + builder: &'a mut builders::FilledInitializeSecurityContext<'_, ::CredentialsHandle>, ) -> Result { if let Some(target_name) = &builder.target_name { self.check_target_name_for_ntlm_downgrade(target_name); @@ -499,7 +532,7 @@ impl SspiImpl for Negotiate { } if let NegotiatedProtocol::Kerberos(kerberos) = &mut self.protocol { - match kerberos.initialize_security_context_impl(builder) { + match kerberos.initialize_security_context_impl(yield_point, builder).await { Result::Err(Error { error_type: ErrorKind::NoCredentials, .. @@ -528,7 +561,9 @@ impl SspiImpl for Negotiate { pku2u.initialize_security_context_impl(&mut transformed_builder) } - NegotiatedProtocol::Kerberos(kerberos) => kerberos.initialize_security_context_impl(builder), + NegotiatedProtocol::Kerberos(kerberos) => { + kerberos.initialize_security_context_impl(yield_point, builder).await + } NegotiatedProtocol::Ntlm(ntlm) => { let mut credentials_handle = self.auth_identity.as_mut().and_then(|c| c.clone().auth_identity()); let mut transformed_builder = builder.full_transform(Some(&mut credentials_handle)); @@ -537,32 +572,4 @@ impl SspiImpl for Negotiate { } } } - - #[instrument(ret, fields(protocol = self.protocol.protocol_name()), skip_all)] - fn accept_security_context_impl<'a>( - &'a mut self, - builder: builders::FilledAcceptSecurityContext<'a, Self::AuthenticationData, Self::CredentialsHandle>, - ) -> Result { - match &mut self.protocol { - NegotiatedProtocol::Pku2u(pku2u) => { - let mut creds_handle = if let Some(creds_handle) = &builder.credentials_handle { - creds_handle.as_ref().and_then(|c| c.clone().auth_identity()) - } else { - None - }; - let new_builder = builder.full_transform(pku2u, Some(&mut creds_handle)); - new_builder.execute() - } - NegotiatedProtocol::Kerberos(kerberos) => kerberos.accept_security_context_impl(builder), - NegotiatedProtocol::Ntlm(ntlm) => { - let mut creds_handle = if let Some(creds_handle) = &builder.credentials_handle { - creds_handle.as_ref().and_then(|c| c.clone().auth_identity()) - } else { - None - }; - let new_builder = builder.full_transform(ntlm, Some(&mut creds_handle)); - new_builder.execute() - } - } - } } diff --git a/src/network_client.rs b/src/network_client.rs index be2f7ce2..e9977ba0 100644 --- a/src/network_client.rs +++ b/src/network_client.rs @@ -1,8 +1,6 @@ use std::fmt::Debug; -use url::Url; - -use crate::Result; +use crate::{generator::NetworkRequest, Result}; #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum NetworkProtocol { @@ -26,29 +24,12 @@ impl NetworkProtocol { } } -pub trait NetworkClientFactory: Debug + Send + Sync { - fn network_client(&self) -> Box; - fn box_clone(&self) -> Box; -} - pub trait NetworkClient: Send + Sync { - /// Return the name of the network client instance (for logging/error reporting purposes). - fn name(&self) -> &'static str; - /// Return list of supported protocols by the network client. - fn supported_protocols(&self) -> &[NetworkProtocol]; - /// Return true if the protocol is supported by the network client. - fn is_protocol_supported(&self, protocol: NetworkProtocol) -> bool { - self.supported_protocols().contains(&protocol) - } - - /// Clone network client instance via trait object. - fn box_clone(&self) -> Box; - /// Send request to the server and return the response. URL scheme is guaranteed to be /// the same as specified by `protocol` argument. `sspi-rs` will call this method only if /// `NetworkClient::is_protocol_supported` returned true prior to the call, so unsupported /// `protocol` values could be marked as `unreachable!`. - fn send(&self, protocol: NetworkProtocol, url: Url, data: &[u8]) -> Result>; + fn send(&self, request: &NetworkRequest) -> Result>; } #[cfg(feature = "network_client")] @@ -60,17 +41,15 @@ pub mod reqwest_network_client { use reqwest::blocking::Client; use url::Url; - use super::{NetworkClient, NetworkClientFactory, NetworkProtocol}; + use super::{NetworkClient, NetworkProtocol}; + use crate::generator::NetworkRequest; use crate::{Error, ErrorKind, Result}; #[derive(Debug, Clone, Default)] pub struct ReqwestNetworkClient; impl ReqwestNetworkClient { - const NAME: &str = "Reqwest"; - const SUPPORTED_PROTOCOLS: &[NetworkProtocol] = NetworkProtocol::ALL; - - fn send_tcp(&self, url: Url, data: &[u8]) -> Result> { + fn send_tcp(&self, url: &Url, data: &[u8]) -> Result> { let addr = format!("{}:{}", url.host_str().unwrap_or_default(), url.port().unwrap_or(88)); let mut stream = TcpStream::connect(addr) .map_err(|e| Error::new(ErrorKind::NoAuthenticatingAuthority, format!("{:?}", e)))?; @@ -93,7 +72,7 @@ pub mod reqwest_network_client { Ok(buf) } - fn send_udp(&self, url: Url, data: &[u8]) -> Result> { + fn send_udp(&self, url: &Url, data: &[u8]) -> Result> { let port = portpicker::pick_unused_port().ok_or_else(|| Error::new(ErrorKind::InternalError, "No free ports"))?; let udp_socket = UdpSocket::bind((IpAddr::V4(Ipv4Addr::LOCALHOST), port))?; @@ -113,11 +92,11 @@ pub mod reqwest_network_client { Ok(reply_buf) } - fn send_http(&self, url: Url, data: &[u8]) -> Result> { + fn send_http(&self, url: &Url, data: &[u8]) -> Result> { let client = Client::new(); let result_bytes = client - .post(url) + .post(url.clone()) .body(data.to_vec()) .send() .map_err(|err| match err { @@ -144,37 +123,12 @@ pub mod reqwest_network_client { } impl NetworkClient for ReqwestNetworkClient { - fn send(&self, protocol: NetworkProtocol, url: Url, data: &[u8]) -> Result> { - match protocol { - NetworkProtocol::Tcp => self.send_tcp(url, data), - NetworkProtocol::Udp => self.send_udp(url, data), - NetworkProtocol::Http | NetworkProtocol::Https => self.send_http(url, data), + fn send(&self, request: &NetworkRequest) -> Result> { + match request.protocol { + NetworkProtocol::Tcp => self.send_tcp(&request.url, &request.data), + NetworkProtocol::Udp => self.send_udp(&request.url, &request.data), + NetworkProtocol::Http | NetworkProtocol::Https => self.send_http(&request.url, &request.data), } } - - fn box_clone(&self) -> Box { - Box::new(Clone::clone(self)) - } - - fn name(&self) -> &'static str { - Self::NAME - } - - fn supported_protocols(&self) -> &[NetworkProtocol] { - Self::SUPPORTED_PROTOCOLS - } - } - - #[derive(Debug, Clone, Default)] - pub struct RequestClientFactory; - - impl NetworkClientFactory for RequestClientFactory { - fn network_client(&self) -> Box { - Box::::default() - } - - fn box_clone(&self) -> Box { - Box::new(Clone::clone(self)) - } } } diff --git a/src/ntlm/mod.rs b/src/ntlm/mod.rs index 15986288..31c51fef 100644 --- a/src/ntlm/mod.rs +++ b/src/ntlm/mod.rs @@ -14,9 +14,10 @@ use messages::{client, server}; pub use self::config::NtlmConfig; use super::channel_bindings::ChannelBindings; use crate::crypto::{compute_hmac_md5, Rc4, HASH_SIZE}; +use crate::generator::GeneratorInitSecurityContext; use crate::{ AcceptSecurityContextResult, AcquireCredentialsHandleResult, AuthIdentity, AuthIdentityBuffers, CertTrustStatus, - ClientResponseFlags, ContextNames, ContextSizes, CredentialUse, DecryptionFlags, EncryptionFlags, + ClientResponseFlags, ContextNames, ContextSizes, CredentialUse, DecryptionFlags, EncryptionFlags, Error, ErrorKind, FilledAcceptSecurityContext, FilledAcquireCredentialsHandle, FilledInitializeSecurityContext, InitializeSecurityContextResult, PackageCapabilities, PackageInfo, SecurityBuffer, SecurityBufferType, SecurityPackageType, SecurityStatus, ServerResponseFlags, Sspi, SspiEx, SspiImpl, PACKAGE_ID_NONE, @@ -237,10 +238,63 @@ impl SspiImpl for Ntlm { }) } + #[instrument(level = "debug", ret, fields(state = ?self.state), skip(self, builder))] + fn accept_security_context_impl( + &mut self, + builder: FilledAcceptSecurityContext<'_, Self::AuthenticationData, Self::CredentialsHandle>, + ) -> crate::Result { + let input = builder + .input + .ok_or_else(|| crate::Error::new(crate::ErrorKind::InvalidToken, "Input buffers must be specified"))?; + let status = match self.state { + NtlmState::Initial => { + let input_token = SecurityBuffer::find_buffer(input, SecurityBufferType::Token)?; + let output_token = SecurityBuffer::find_buffer_mut(builder.output, SecurityBufferType::Token)?; + + self.state = NtlmState::Negotiate; + server::read_negotiate(self, input_token.buffer.as_slice())?; + + server::write_challenge(self, &mut output_token.buffer)? + } + NtlmState::Authenticate => { + let input_token = SecurityBuffer::find_buffer(input, SecurityBufferType::Token)?; + + self.identity = builder.credentials_handle.cloned().flatten(); + + if let Ok(sec_buffer) = SecurityBuffer::find_buffer(input, SecurityBufferType::ChannelBindings) { + self.channel_bindings = Some(ChannelBindings::from_bytes(&sec_buffer.buffer)?); + } + + server::read_authenticate(self, input_token.buffer.as_slice())? + } + _ => { + return Err(crate::Error::new( + crate::ErrorKind::OutOfSequence, + format!("got wrong NTLM state: {:?}", self.state), + )) + } + }; + + Ok(AcceptSecurityContextResult { + status, + flags: ServerResponseFlags::empty(), + expiry: None, + }) + } + #[instrument(ret, fields(state = ?self.state), skip_all)] fn initialize_security_context_impl( &mut self, builder: &mut FilledInitializeSecurityContext<'_, Self::CredentialsHandle>, + ) -> GeneratorInitSecurityContext { + self.initialize_security_context_impl(builder).into() + } +} + +impl Ntlm { + pub(crate) fn initialize_security_context_impl( + &mut self, + builder: &mut FilledInitializeSecurityContext<'_, ::CredentialsHandle>, ) -> crate::Result { trace!(?builder); @@ -296,50 +350,6 @@ impl SspiImpl for Ntlm { expiry: None, }) } - - #[instrument(level = "debug", ret, fields(state = ?self.state), skip(self, builder))] - fn accept_security_context_impl( - &mut self, - builder: FilledAcceptSecurityContext<'_, Self::AuthenticationData, Self::CredentialsHandle>, - ) -> crate::Result { - let input = builder - .input - .ok_or_else(|| crate::Error::new(crate::ErrorKind::InvalidToken, "Input buffers must be specified"))?; - let status = match self.state { - NtlmState::Initial => { - let input_token = SecurityBuffer::find_buffer(input, SecurityBufferType::Token)?; - let output_token = SecurityBuffer::find_buffer_mut(builder.output, SecurityBufferType::Token)?; - - self.state = NtlmState::Negotiate; - server::read_negotiate(self, input_token.buffer.as_slice())?; - - server::write_challenge(self, &mut output_token.buffer)? - } - NtlmState::Authenticate => { - let input_token = SecurityBuffer::find_buffer(input, SecurityBufferType::Token)?; - - self.identity = builder.credentials_handle.cloned().flatten(); - - if let Ok(sec_buffer) = SecurityBuffer::find_buffer(input, SecurityBufferType::ChannelBindings) { - self.channel_bindings = Some(ChannelBindings::from_bytes(&sec_buffer.buffer)?); - } - - server::read_authenticate(self, input_token.buffer.as_slice())? - } - _ => { - return Err(crate::Error::new( - crate::ErrorKind::OutOfSequence, - format!("got wrong NTLM state: {:?}", self.state), - )) - } - }; - - Ok(AcceptSecurityContextResult { - status, - flags: ServerResponseFlags::empty(), - expiry: None, - }) - } } impl Sspi for Ntlm { @@ -451,12 +461,12 @@ impl Sspi for Ntlm { )) } - #[instrument(level = "debug", ret, fields(state = ?self.state), skip_all)] - fn change_password(&mut self, _change_password: crate::builders::ChangePassword) -> crate::Result<()> { - Err(crate::Error::new( - crate::ErrorKind::UnsupportedFunction, - "change_password is not supported in NTLM", + fn change_password(&mut self, _: crate::builders::ChangePassword) -> crate::generator::GeneratorChangePassword { + Err(Error::new( + ErrorKind::UnsupportedFunction, + "NTLM does not support change pasword", )) + .into() } } diff --git a/src/pku2u/mod.rs b/src/pku2u/mod.rs index 983471c4..714a982f 100644 --- a/src/pku2u/mod.rs +++ b/src/pku2u/mod.rs @@ -37,6 +37,7 @@ use self::generators::{ generate_server_dh_parameters, WELLKNOWN_REALM, }; use crate::builders::ChangePassword; +use crate::generator::GeneratorInitSecurityContext; use crate::kerberos::client::generators::{ generate_ap_req, generate_as_req, generate_as_req_kdc_body, ChecksumOptions, EncKey, GenerateAsReqOptions, GenerateAuthenticatorOptions, @@ -346,12 +347,12 @@ impl Sspi for Pku2u { )) } - #[instrument(level = "debug", ret, fields(state = ?self.state), skip(self, _change_password))] - fn change_password(&mut self, _change_password: ChangePassword) -> Result<()> { + fn change_password(&mut self, _: ChangePassword) -> crate::generator::GeneratorChangePassword { Err(Error::new( ErrorKind::UnsupportedFunction, - "change_password is not supported in PKU2U", + "Pku2u does not support change pasword", )) + .into() } } @@ -380,11 +381,31 @@ impl SspiImpl for Pku2u { }) } - #[instrument(ret, fields(state = ?self.state), skip_all)] + #[instrument(level = "debug", ret, fields(state = ?self.state), skip(self, _builder))] + fn accept_security_context_impl<'a>( + &'a mut self, + _builder: crate::builders::FilledAcceptSecurityContext<'a, Self::AuthenticationData, Self::CredentialsHandle>, + ) -> Result { + Err(Error::new( + ErrorKind::UnsupportedFunction, + "accept_security_context_impl is not implemented yet", + )) + } + fn initialize_security_context_impl( &mut self, builder: &mut crate::builders::FilledInitializeSecurityContext<'_, Self::CredentialsHandle>, - ) -> Result { + ) -> GeneratorInitSecurityContext { + self.initialize_security_context_impl(builder).into() + } +} + +impl Pku2u { + #[instrument(ret, fields(state = ?self.state), skip_all)] + pub(crate) fn initialize_security_context_impl( + &mut self, + builder: &mut crate::builders::FilledInitializeSecurityContext<'_, ::CredentialsHandle>, + ) -> crate::Result { trace!(?builder); let status = match self.state { @@ -784,17 +805,6 @@ impl SspiImpl for Pku2u { expiry: None, }) } - - #[instrument(level = "debug", ret, fields(state = ?self.state), skip(self, _builder))] - fn accept_security_context_impl<'a>( - &'a mut self, - _builder: crate::builders::FilledAcceptSecurityContext<'a, Self::AuthenticationData, Self::CredentialsHandle>, - ) -> Result { - Err(Error::new( - ErrorKind::UnsupportedFunction, - "accept_security_context_impl is not implemented yet", - )) - } } impl SspiEx for Pku2u { diff --git a/tests/common.rs b/tests/common.rs index e70ec611..29cfab79 100644 --- a/tests/common.rs +++ b/tests/common.rs @@ -112,7 +112,9 @@ where .with_input(&mut server_output) .with_output(&mut client_output); - let client_result = client.initialize_security_context_impl(&mut builder)?; + let client_result = client + .initialize_security_context_impl(&mut builder) + .resolve_to_result()?; client_status = client_result.status; if client_status != SecurityStatus::ContinueNeeded && server_status != SecurityStatus::ContinueNeeded { diff --git a/tools/wasm-testcompile/src/lib.rs b/tools/wasm-testcompile/src/lib.rs index 690a2c98..f3e40f55 100644 --- a/tools/wasm-testcompile/src/lib.rs +++ b/tools/wasm-testcompile/src/lib.rs @@ -1,20 +1,7 @@ -use sspi::network_client::{NetworkClient, NetworkClientFactory}; use sspi::ntlm::NtlmConfig; use sspi::{credssp, Credentials}; use wasm_bindgen::prelude::*; -#[derive(Debug, Clone)] -struct DummyNetworkClientFactory; - -impl NetworkClientFactory for DummyNetworkClientFactory { - fn network_client(&self) -> Box { - unimplemented!() - } - - fn box_clone(&self) -> Box { - Box::new(Clone::clone(self)) - } -} #[wasm_bindgen] pub fn credssp_client() { @@ -26,7 +13,6 @@ pub fn credssp_client() { protocol_config: Box::::default(), package_list: None, hostname: "testhostname".into(), - network_client_factory: Box::new(DummyNetworkClientFactory), }), String::new(), )