diff --git a/protocols/noise/Cargo.toml b/protocols/noise/Cargo.toml index 000fb508a8d..724f1bafac3 100644 --- a/protocols/noise/Cargo.toml +++ b/protocols/noise/Cargo.toml @@ -10,7 +10,7 @@ edition = "2018" [dependencies] bytes = "0.4" curve25519-dalek = "1" -futures-preview = "0.3.0-alpha.17" +futures-preview = "0.3.0-alpha.18" lazy_static = "1.2" libp2p-core = { version = "0.12.0", path = "../../core" } log = "0.4" diff --git a/protocols/noise/src/io.rs b/protocols/noise/src/io.rs index 67c1aeb4237..a6fb4143e5c 100644 --- a/protocols/noise/src/io.rs +++ b/protocols/noise/src/io.rs @@ -22,11 +22,11 @@ pub mod handshake; -use futures::Poll; +use futures::{ready, Poll}; +use futures::prelude::*; use log::{debug, trace}; use snow; -use std::{fmt, io}; -use tokio_io::{AsyncRead, AsyncWrite}; +use std::{fmt, io, pin::Pin, ops::DerefMut, task::Context}; const MAX_NOISE_PKG_LEN: usize = 65535; const MAX_WRITE_BUF_LEN: usize = 16384; @@ -121,57 +121,75 @@ enum WriteState { EncErr } -impl io::Read for NoiseOutput { - fn read(&mut self, buf: &mut [u8]) -> io::Result { - let buffer = self.buffer.borrow_mut(); +impl AsyncRead for NoiseOutput { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + let mut this = self.deref_mut(); + + let buffer = this.buffer.borrow_mut(); + loop { - trace!("read state: {:?}", self.read_state); - match self.read_state { + trace!("read state: {:?}", this.read_state); + match this.read_state { ReadState::Init => { - self.read_state = ReadState::ReadLen { buf: [0, 0], off: 0 }; + this.read_state = ReadState::ReadLen { buf: [0, 0], off: 0 }; } ReadState::ReadLen { mut buf, mut off } => { - let n = match read_frame_len(&mut self.io, &mut buf, &mut off) { - Ok(Some(n)) => n, - Ok(None) => { + let n = match read_frame_len(&mut this.io, cx, &mut buf, &mut off) { + Poll::Ready(Ok(Some(n))) => n, + Poll::Ready(Ok(None)) => { trace!("read: eof"); - self.read_state = ReadState::Eof(Ok(())); - return Ok(0) + this.read_state = ReadState::Eof(Ok(())); + return Poll::Ready(Ok(0)) } - Err(e) => { - if e.kind() == io::ErrorKind::WouldBlock { - // Preserve read state - self.read_state = ReadState::ReadLen { buf, off }; - } - return Err(e) + Poll::Ready(Err(e)) => { + return Poll::Ready(Err(e)) + } + Poll::Pending => { + this.read_state = ReadState::ReadLen { buf, off }; + + return Poll::Pending; } }; trace!("read: next frame len = {}", n); if n == 0 { trace!("read: empty frame"); - self.read_state = ReadState::Init; + this.read_state = ReadState::Init; continue } - self.read_state = ReadState::ReadData { len: usize::from(n), off: 0 } + this.read_state = ReadState::ReadData { len: usize::from(n), off: 0 } } ReadState::ReadData { len, ref mut off } => { - let n = self.io.read(&mut buffer.read[*off .. len])?; + let n = match ready!( + Pin::new(&mut this.io).poll_read(cx, &mut buffer.read[*off ..len]) + ) { + Ok(n) => n, + Err(e) => return Poll::Ready(Err(e)), + }; + trace!("read: read {}/{} bytes", *off + n, len); if n == 0 { trace!("read: eof"); - self.read_state = ReadState::Eof(Err(())); - return Err(io::ErrorKind::UnexpectedEof.into()) + this.read_state = ReadState::Eof(Err(())); + return Poll::Ready(Err(io::ErrorKind::UnexpectedEof.into())) } + *off += n; if len == *off { trace!("read: decrypting {} bytes", len); - if let Ok(n) = self.session.read_message(&buffer.read[.. len], buffer.read_crypto) { + if let Ok(n) = this.session.read_message( + &buffer.read[.. len], + buffer.read_crypto + ){ trace!("read: payload len = {} bytes", n); - self.read_state = ReadState::CopyData { len: n, off: 0 } + this.read_state = ReadState::CopyData { len: n, off: 0 } } else { debug!("decryption error"); - self.read_state = ReadState::DecErr; - return Err(io::ErrorKind::InvalidData.into()) + this.read_state = ReadState::DecErr; + return Poll::Ready(Err(io::ErrorKind::InvalidData.into())) } } } @@ -181,32 +199,43 @@ impl io::Read for NoiseOutput { trace!("read: copied {}/{} bytes", *off + n, len); *off += n; if len == *off { - self.read_state = ReadState::ReadLen { buf: [0, 0], off: 0 }; + this.read_state = ReadState::ReadLen { buf: [0, 0], off: 0 }; } - return Ok(n) + return Poll::Ready(Ok(n)) } ReadState::Eof(Ok(())) => { trace!("read: eof"); - return Ok(0) + return Poll::Ready(Ok(0)) } ReadState::Eof(Err(())) => { trace!("read: eof (unexpected)"); - return Err(io::ErrorKind::UnexpectedEof.into()) + return Poll::Ready(Err(io::ErrorKind::UnexpectedEof.into())) } - ReadState::DecErr => return Err(io::ErrorKind::InvalidData.into()) + ReadState::DecErr => return Poll::Ready(Err(io::ErrorKind::InvalidData.into())) } } } + + unsafe fn initializer(&self) -> futures::io::Initializer { + futures::io::Initializer::nop() + } } -impl io::Write for NoiseOutput { - fn write(&mut self, buf: &[u8]) -> io::Result { - let buffer = self.buffer.borrow_mut(); +impl AsyncWrite for NoiseOutput { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll>{ + let mut this = self.deref_mut(); + + let buffer = this.buffer.borrow_mut(); + loop { - trace!("write state: {:?}", self.write_state); - match self.write_state { + trace!("write state: {:?}", this.write_state); + match this.write_state { WriteState::Init => { - self.write_state = WriteState::BufferData { off: 0 } + this.write_state = WriteState::BufferData { off: 0 } } WriteState::BufferData { ref mut off } => { let n = std::cmp::min(MAX_WRITE_BUF_LEN - *off, buf.len()); @@ -215,136 +244,157 @@ impl io::Write for NoiseOutput { *off += n; if *off == MAX_WRITE_BUF_LEN { trace!("write: encrypting {} bytes", *off); - if let Ok(n) = self.session.write_message(buffer.write, buffer.write_crypto) { - trace!("write: cipher text len = {} bytes", n); - self.write_state = WriteState::WriteLen { - len: n, - buf: u16::to_be_bytes(n as u16), - off: 0 + match this.session.write_message(buffer.write, buffer.write_crypto) { + Ok(n) => { + trace!("write: cipher text len = {} bytes", n); + this.write_state = WriteState::WriteLen { + len: n, + buf: u16::to_be_bytes(n as u16), + off: 0 + } + } + Err(e) => { + debug!("encryption error: {:?}", e); + this.write_state = WriteState::EncErr; + return Poll::Ready(Err(io::ErrorKind::InvalidData.into())) } - } else { - debug!("encryption error"); - self.write_state = WriteState::EncErr; - return Err(io::ErrorKind::InvalidData.into()) } } - return Ok(n) + return Poll::Ready(Ok(n)) } WriteState::WriteLen { len, mut buf, mut off } => { trace!("write: writing len ({}, {:?}, {}/2)", len, buf, off); - match write_frame_len(&mut self.io, &mut buf, &mut off) { - Err(e) => { - if e.kind() == io::ErrorKind::WouldBlock { - self.write_state = WriteState::WriteLen{ len, buf, off }; - } - return Err(e) - } - Ok(false) => { + match write_frame_len(&mut this.io, cx, &mut buf, &mut off) { + Poll::Ready(Ok(true)) => (), + Poll::Ready(Ok(false)) => { trace!("write: eof"); - self.write_state = WriteState::Eof; - return Err(io::ErrorKind::WriteZero.into()) + this.write_state = WriteState::Eof; + return Poll::Ready(Err(io::ErrorKind::WriteZero.into())) + } + Poll::Ready(Err(e)) => { + return Poll::Ready(Err(e)) + } + Poll::Pending => { + this.write_state = WriteState::WriteLen{ len, buf, off }; + + return Poll::Pending } - Ok(true) => () } - self.write_state = WriteState::WriteData { len, off: 0 } + this.write_state = WriteState::WriteData { len, off: 0 } } WriteState::WriteData { len, ref mut off } => { - let n = self.io.write(&buffer.write_crypto[*off .. len])?; + let n = match ready!( + Pin::new(&mut this.io).poll_write(cx, &buffer.write_crypto[*off .. len]) + ) { + Ok(n) => n, + Err(e) => return Poll::Ready(Err(e)), + }; trace!("write: wrote {}/{} bytes", *off + n, len); if n == 0 { trace!("write: eof"); - self.write_state = WriteState::Eof; - return Err(io::ErrorKind::WriteZero.into()) + this.write_state = WriteState::Eof; + return Poll::Ready(Err(io::ErrorKind::WriteZero.into())) } *off += n; if len == *off { trace!("write: finished writing {} bytes", len); - self.write_state = WriteState::Init + this.write_state = WriteState::Init } } WriteState::Eof => { trace!("write: eof"); - return Err(io::ErrorKind::WriteZero.into()) + return Poll::Ready(Err(io::ErrorKind::WriteZero.into())) } - WriteState::EncErr => return Err(io::ErrorKind::InvalidData.into()) + WriteState::EncErr => return Poll::Ready(Err(io::ErrorKind::InvalidData.into())) } } } - fn flush(&mut self) -> io::Result<()> { - let buffer = self.buffer.borrow_mut(); + fn poll_flush( + mut self: Pin<&mut Self>, + cx: &mut Context<'_> + ) -> Poll> { + let mut this = self.deref_mut(); + + let buffer = this.buffer.borrow_mut(); + loop { - match self.write_state { - WriteState::Init => return Ok(()), + match this.write_state { + WriteState::Init => return Poll::Ready(Ok(())), WriteState::BufferData { off } => { trace!("flush: encrypting {} bytes", off); - if let Ok(n) = self.session.write_message(&buffer.write[.. off], buffer.write_crypto) { - trace!("flush: cipher text len = {} bytes", n); - self.write_state = WriteState::WriteLen { - len: n, - buf: u16::to_be_bytes(n as u16), - off: 0 + match this.session.write_message(&buffer.write[.. off], buffer.write_crypto) { + Ok(n) => { + trace!("flush: cipher text len = {} bytes", n); + this.write_state = WriteState::WriteLen { + len: n, + buf: u16::to_be_bytes(n as u16), + off: 0 + } + } + Err(e) => { + debug!("encryption error: {:?}", e); + this.write_state = WriteState::EncErr; + return Poll::Ready(Err(io::ErrorKind::InvalidData.into())) } - } else { - debug!("encryption error"); - self.write_state = WriteState::EncErr; - return Err(io::ErrorKind::InvalidData.into()) } } WriteState::WriteLen { len, mut buf, mut off } => { trace!("flush: writing len ({}, {:?}, {}/2)", len, buf, off); - match write_frame_len(&mut self.io, &mut buf, &mut off) { - Ok(true) => (), - Ok(false) => { + match write_frame_len(&mut this.io, cx, &mut buf, &mut off) { + Poll::Ready(Ok(true)) => (), + Poll::Ready(Ok(false)) => { trace!("write: eof"); - self.write_state = WriteState::Eof; - return Err(io::ErrorKind::WriteZero.into()) + this.write_state = WriteState::Eof; + return Poll::Ready(Err(io::ErrorKind::WriteZero.into())) } - Err(e) => { - if e.kind() == io::ErrorKind::WouldBlock { - // Preserve write state - self.write_state = WriteState::WriteLen { len, buf, off }; - } - return Err(e) + Poll::Ready(Err(e)) => { + return Poll::Ready(Err(e)) + } + Poll::Pending => { + this.write_state = WriteState::WriteLen { len, buf, off }; + + return Poll::Pending } } - self.write_state = WriteState::WriteData { len, off: 0 } + this.write_state = WriteState::WriteData { len, off: 0 } } WriteState::WriteData { len, ref mut off } => { - let n = self.io.write(&buffer.write_crypto[*off .. len])?; + let n = match ready!( + Pin::new(&mut this.io).poll_write(cx, &buffer.write_crypto[*off .. len]) + ) { + Ok(n) => n, + Err(e) => return Poll::Ready(Err(e)), + }; trace!("flush: wrote {}/{} bytes", *off + n, len); if n == 0 { trace!("flush: eof"); - self.write_state = WriteState::Eof; - return Err(io::ErrorKind::WriteZero.into()) + this.write_state = WriteState::Eof; + return Poll::Ready(Err(io::ErrorKind::WriteZero.into())) } *off += n; if len == *off { trace!("flush: finished writing {} bytes", len); - self.write_state = WriteState::Init; - return Ok(()) + this.write_state = WriteState::Init; + return Poll::Ready(Ok(())) } } WriteState::Eof => { trace!("flush: eof"); - return Err(io::ErrorKind::WriteZero.into()) + return Poll::Ready(Err(io::ErrorKind::WriteZero.into())) } - WriteState::EncErr => return Err(io::ErrorKind::InvalidData.into()) + WriteState::EncErr => return Poll::Ready(Err(io::ErrorKind::InvalidData.into())) } } } -} -impl AsyncRead for NoiseOutput { - unsafe fn prepare_uninitialized_buffer(&self, _: &mut [u8]) -> bool { - false + fn poll_close( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>{ + Pin::new(&mut self.io).poll_close(cx) } -} -impl AsyncWrite for NoiseOutput { - fn shutdown(&mut self) -> Poll<(), io::Error> { - self.io.shutdown() - } } /// Read 2 bytes as frame length from the given source into the given buffer. @@ -356,17 +406,26 @@ impl AsyncWrite for NoiseOutput { /// for the next invocation. /// /// Returns `None` if EOF has been encountered. -fn read_frame_len(io: &mut R, buf: &mut [u8; 2], off: &mut usize) - -> io::Result> -{ +fn read_frame_len( + mut io: &mut R, + cx: &mut Context<'_>, + buf: &mut [u8; 2], + off: &mut usize, +) -> Poll, std::io::Error>> { loop { - let n = io.read(&mut buf[*off ..])?; - if n == 0 { - return Ok(None) - } - *off += n; - if *off == 2 { - return Ok(Some(u16::from_be_bytes(*buf))) + match ready!(Pin::new(&mut io).poll_read(cx, &mut buf[*off ..])) { + Ok(n) => { + if n == 0 { + return Poll::Ready(Ok(None)); + } + *off += n; + if *off == 2 { + return Poll::Ready(Ok(Some(u16::from_be_bytes(*buf)))); + } + }, + Err(e) => { + return Poll::Ready(Err(e)); + }, } } } @@ -380,18 +439,26 @@ fn read_frame_len(io: &mut R, buf: &mut [u8; 2], off: &mut usize) /// be preserved for the next invocation. /// /// Returns `false` if EOF has been encountered. -fn write_frame_len(io: &mut W, buf: &[u8; 2], off: &mut usize) - -> io::Result -{ +fn write_frame_len( + mut io: &mut W, + cx: &mut Context<'_>, + buf: &[u8; 2], + off: &mut usize, +) -> Poll> { loop { - let n = io.write(&buf[*off ..])?; - if n == 0 { - return Ok(false) - } - *off += n; - if *off == 2 { - return Ok(true) + match ready!(Pin::new(&mut io).poll_write(cx, &buf[*off ..])) { + Ok(n) => { + if n == 0 { + return Poll::Ready(Ok(false)) + } + *off += n; + if *off == 2 { + return Poll::Ready(Ok(true)) + } + } + Err(e) => { + return Poll::Ready(Err(e)); + } } } } - diff --git a/protocols/noise/src/io/handshake.rs b/protocols/noise/src/io/handshake.rs index f0dac45c2e0..f11d6c999fe 100644 --- a/protocols/noise/src/io/handshake.rs +++ b/protocols/noise/src/io/handshake.rs @@ -26,9 +26,10 @@ use crate::error::NoiseError; use crate::protocol::{Protocol, PublicKey, KeypairIdentity}; use libp2p_core::identity; use futures::prelude::*; -use std::{mem, io, task::Poll}; +use futures::task; +use futures::io::AsyncReadExt; use protobuf::Message; - +use std::{pin::Pin, task::Context}; use super::NoiseOutput; /// The identity of the remote established during a handshake. @@ -86,129 +87,162 @@ pub enum IdentityExchange { None { remote: identity::PublicKey } } -impl Handshake +/// A future performing a Noise handshake pattern. +pub struct Handshake( + Pin, NoiseOutput), NoiseError>, + > + Send>> +); + +impl Future for Handshake { + type Output = Result<(RemoteIdentity, NoiseOutput), NoiseError>; + + fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> task::Poll { + Pin::new(&mut self.0).poll(ctx) + } +} + +/// Creates an authenticated Noise handshake for the initiator of a +/// single roundtrip (2 message) handshake pattern. +/// +/// Subject to the chosen [`IdentityExchange`], this message sequence +/// identifies the local node to the remote with the first message payload +/// (i.e. unencrypted) and expects the remote to identify itself in the +/// second message payload. +/// +/// This message sequence is suitable for authenticated 2-message Noise handshake +/// patterns where the static keys of the initiator and responder are either +/// known (i.e. appear in the pre-message pattern) or are sent with +/// the first and second message, respectively (e.g. `IK` or `IX`). +/// +/// ```raw +/// initiator -{id}-> responder +/// initiator <-{id}- responder +/// ``` +pub fn rt1_initiator( + io: T, + session: Result, + identity: KeypairIdentity, + identity_x: IdentityExchange +) -> Handshake where - T: AsyncRead + AsyncWrite + Send + 'static, - C: Protocol + AsRef<[u8]> + Send + 'static, + T: AsyncWrite + AsyncRead + Send + Unpin + 'static, + C: Protocol + AsRef<[u8]> { - /// Creates an authenticated Noise handshake for the initiator of a - /// single roundtrip (2 message) handshake pattern. - /// - /// Subject to the chosen [`IdentityExchange`], this message sequence - /// identifies the local node to the remote with the first message payload - /// (i.e. unencrypted) and expects the remote to identify itself in the - /// second message payload. - /// - /// This message sequence is suitable for authenticated 2-message Noise handshake - /// patterns where the static keys of the initiator and responder are either - /// known (i.e. appear in the pre-message pattern) or are sent with - /// the first and second message, respectively (e.g. `IK` or `IX`). - /// - /// ```raw - /// initiator -{id}-> responder - /// initiator <-{id}- responder - /// ``` - pub fn rt1_initiator( - io: T, - session: Result, - identity: KeypairIdentity, - identity_x: IdentityExchange - ) -> Result<(RemoteIdentity, NoiseOutput), NoiseError> { - let mut state = State::new(io, session, identity, identity_x); + Handshake(Box::pin(async move { + let mut state = State::new(io, session, identity, identity_x)?; send_identity(&mut state).await?; recv_identity(&mut state).await?; - state.finish.await - } + state.finish() + })) +} - /// Creates an authenticated Noise handshake for the responder of a - /// single roundtrip (2 message) handshake pattern. - /// - /// Subject to the chosen [`IdentityExchange`], this message sequence expects the - /// remote to identify itself in the first message payload (i.e. unencrypted) - /// and identifies the local node to the remote in the second message payload. - /// - /// This message sequence is suitable for authenticated 2-message Noise handshake - /// patterns where the static keys of the initiator and responder are either - /// known (i.e. appear in the pre-message pattern) or are sent with the first - /// and second message, respectively (e.g. `IK` or `IX`). - /// - /// ```raw - /// initiator -{id}-> responder - /// initiator <-{id}- responder - /// ``` - pub fn rt1_responder( - io: T, - session: Result, - identity: KeypairIdentity, - identity_x: IdentityExchange, - ) -> Result<(RemoteIdentity, NoiseOutput), NoiseError> { - let mut state = State::new(io, session, identity, identity_x); +/// Creates an authenticated Noise handshake for the responder of a +/// single roundtrip (2 message) handshake pattern. +/// +/// Subject to the chosen [`IdentityExchange`], this message sequence expects the +/// remote to identify itself in the first message payload (i.e. unencrypted) +/// and identifies the local node to the remote in the second message payload. +/// +/// This message sequence is suitable for authenticated 2-message Noise handshake +/// patterns where the static keys of the initiator and responder are either +/// known (i.e. appear in the pre-message pattern) or are sent with the first +/// and second message, respectively (e.g. `IK` or `IX`). +/// +/// ```raw +/// initiator -{id}-> responder +/// initiator <-{id}- responder +/// ``` +pub fn rt1_responder( + io: T, + session: Result, + identity: KeypairIdentity, + identity_x: IdentityExchange, +) -> Handshake +where + T: AsyncWrite + AsyncRead + Send + Unpin + 'static, + C: Protocol + AsRef<[u8]> +{ + Handshake(Box::pin(async move { + let mut state = State::new(io, session, identity, identity_x)?; recv_identity(&mut state).await?; send_identity(&mut state).await?; - state.finish.await - } + state.finish() + })) +} - /// Creates an authenticated Noise handshake for the initiator of a - /// 1.5-roundtrip (3 message) handshake pattern. - /// - /// Subject to the chosen [`IdentityExchange`], this message sequence expects - /// the remote to identify itself in the second message payload and - /// identifies the local node to the remote in the third message payload. - /// The first (unencrypted) message payload is always empty. - /// - /// This message sequence is suitable for authenticated 3-message Noise handshake - /// patterns where the static keys of the responder and initiator are either known - /// (i.e. appear in the pre-message pattern) or are sent with the second and third - /// message, respectively (e.g. `XX`). - /// - /// ```raw - /// initiator --{}--> responder - /// initiator <-{id}- responder - /// initiator -{id}-> responder - /// ``` - pub fn rt15_initiator( - io: T, - session: Result, - identity: KeypairIdentity, - identity_x: IdentityExchange - ) -> Result<(RemoteIdentity, NoiseOutput), NoiseError> { - let mut state = State::new(io, session, identity, identity_x); +/// Creates an authenticated Noise handshake for the initiator of a +/// 1.5-roundtrip (3 message) handshake pattern. +/// +/// Subject to the chosen [`IdentityExchange`], this message sequence expects +/// the remote to identify itself in the second message payload and +/// identifies the local node to the remote in the third message payload. +/// The first (unencrypted) message payload is always empty. +/// +/// This message sequence is suitable for authenticated 3-message Noise handshake +/// patterns where the static keys of the responder and initiator are either known +/// (i.e. appear in the pre-message pattern) or are sent with the second and third +/// message, respectively (e.g. `XX`). +/// +/// ```raw +/// initiator --{}--> responder +/// initiator <-{id}- responder +/// initiator -{id}-> responder +/// ``` +pub fn rt15_initiator( + io: T, + session: Result, + identity: KeypairIdentity, + identity_x: IdentityExchange +) -> Handshake +where + T: AsyncWrite + AsyncRead + Unpin + Send + 'static, + C: Protocol + AsRef<[u8]> +{ + Handshake(Box::pin(async move { + let mut state = State::new(io, session, identity, identity_x)?; send_empty(&mut state).await?; - send_identity(&mut state).await?; recv_identity(&mut state).await?; - state.finish.await - } + send_identity(&mut state).await?; + state.finish() + })) +} - /// Creates an authenticated Noise handshake for the responder of a - /// 1.5-roundtrip (3 message) handshake pattern. - /// - /// Subject to the chosen [`IdentityExchange`], this message sequence - /// identifies the local node in the second message payload and expects - /// the remote to identify itself in the third message payload. The first - /// (unencrypted) message payload is always empty. - /// - /// This message sequence is suitable for authenticated 3-message Noise handshake - /// patterns where the static keys of the responder and initiator are either known - /// (i.e. appear in the pre-message pattern) or are sent with the second and third - /// message, respectively (e.g. `XX`). - /// - /// ```raw - /// initiator --{}--> responder - /// initiator <-{id}- responder - /// initiator -{id}-> responder - /// ``` - pub async fn rt15_responder( - io: T, - session: Result, - identity: KeypairIdentity, - identity_x: IdentityExchange - ) -> Result<(RemoteIdentity, NoiseOutput), NoiseError> { - let mut state = State::new(io, session, identity, identity_x); +/// Creates an authenticated Noise handshake for the responder of a +/// 1.5-roundtrip (3 message) handshake pattern. +/// +/// Subject to the chosen [`IdentityExchange`], this message sequence +/// identifies the local node in the second message payload and expects +/// the remote to identify itself in the third message payload. The first +/// (unencrypted) message payload is always empty. +/// +/// This message sequence is suitable for authenticated 3-message Noise handshake +/// patterns where the static keys of the responder and initiator are either known +/// (i.e. appear in the pre-message pattern) or are sent with the second and third +/// message, respectively (e.g. `XX`). +/// +/// ```raw +/// initiator --{}--> responder +/// initiator <-{id}- responder +/// initiator -{id}-> responder +/// ``` +pub fn rt15_responder( + io: T, + session: Result, + identity: KeypairIdentity, + identity_x: IdentityExchange +) -> Handshake +where + T: AsyncWrite + AsyncRead + Unpin + Send + 'static, + C: Protocol + AsRef<[u8]> +{ + Handshake(Box::pin(async move { + let mut state = State::new(io, session, identity, identity_x)?; recv_empty(&mut state).await?; send_identity(&mut state).await?; recv_identity(&mut state).await?; - state.finish().await - } + state.finish() + })) } ////////////////////////////////////////////////////////////////////////////// @@ -240,14 +274,14 @@ impl State { session: Result, identity: KeypairIdentity, identity_x: IdentityExchange - ) -> FutureResult { + ) -> Result { let (id_remote_pubkey, send_identity) = match identity_x { IdentityExchange::Mutual => (None, true), IdentityExchange::Send { remote } => (Some(remote), true), IdentityExchange::Receive => (None, false), IdentityExchange::None { remote } => (Some(remote), false) }; - future::result(session.map(|s| + session.map(|s| State { identity, io: NoiseOutput::new(io, s), @@ -255,7 +289,7 @@ impl State { id_remote_pubkey, send_identity } - )) + ) } } @@ -263,19 +297,19 @@ impl State { /// Finish a handshake, yielding the established remote identity and the /// [`NoiseOutput`] for communicating on the encrypted channel. - fn finish(self) -> FutureResult<(RemoteIdentity, NoiseOutput), NoiseError> + fn finish(self) -> Result<(RemoteIdentity, NoiseOutput), NoiseError> where C: Protocol + AsRef<[u8]> { let dh_remote_pubkey = match self.io.session.get_remote_static() { None => None, Some(k) => match C::public_from_bytes(k) { - Err(e) => return future::err(e), + Err(e) => return Err(e), Ok(dh_pk) => Some(dh_pk) } }; match self.io.session.into_transport_mode() { - Err(e) => future::err(e.into()), + Err(e) => Err(e.into()), Ok(s) => { let remote = match (self.id_remote_pubkey, dh_remote_pubkey) { (_, None) => RemoteIdentity::Unknown, @@ -284,11 +318,11 @@ impl State if C::verify(&id_pk, &dh_pk, &self.dh_remote_pubkey_sig) { RemoteIdentity::IdentityKey(id_pk) } else { - return future::err(NoiseError::InvalidKey) + return Err(NoiseError::InvalidKey) } } }; - future::ok((remote, NoiseOutput { session: s, .. self.io })) + Ok((remote, NoiseOutput { session: s, .. self.io })) } } } @@ -297,121 +331,72 @@ impl State ////////////////////////////////////////////////////////////////////////////// // Handshake Message Futures -// RecvEmpty ----------------------------------------------------------------- - /// A future for receiving a Noise handshake message with an empty payload. -/// -/// Obtained from [`Handshake::recv_empty`]. async fn recv_empty(state: &mut State) -> Result<(), NoiseError> where - T: AsyncRead + T: AsyncRead + Unpin { state.io.read(&mut []).await?; Ok(()) } -// SendEmpty ----------------------------------------------------------------- - /// A future for sending a Noise handshake message with an empty payload. -/// -/// Obtained from [`Handshake::send_empty`]. async fn send_empty(state: &mut State) -> Result<(), NoiseError> where - T: AsyncWrite + T: AsyncWrite + Unpin { - state.write(&[]).await?; - state.flush().await?; + state.io.write(&[]).await?; + state.io.flush().await?; Ok(()) } -// RecvIdentity -------------------------------------------------------------- - /// A future for receiving a Noise handshake message with a payload /// identifying the remote. -/// -/// Obtained from [`Handshake::recv_identity`]. -struct RecvIdentity { - state: RecvIdentityState -} - -enum RecvIdentityState { - Init(State), - ReadPayloadLen(nio::ReadExact, [u8; 2]>), - ReadPayload(nio::ReadExact, Vec>), - Done -} - -impl Future for RecvIdentity +async fn recv_identity(state: &mut State) -> Result<(), NoiseError> where - T: AsyncRead, + T: AsyncRead + Unpin, { - type Error = NoiseError; - type Item = State; + let mut len_buf = [0,0]; + state.io.read_exact(&mut len_buf).await?; + let len = u16::from_be_bytes(len_buf) as usize; - fn poll(&mut self) -> Poll { - loop { - match mem::replace(&mut self.state, RecvIdentityState::Done) { - RecvIdentityState::Init(st) => { - self.state = RecvIdentityState::ReadPayloadLen(nio::read_exact(st, [0, 0])); - }, - RecvIdentityState::ReadPayloadLen(mut read_len) => { - if let Async::Ready((st, bytes)) = read_len.poll()? { - let len = u16::from_be_bytes(bytes) as usize; - let buf = vec![0; len]; - self.state = RecvIdentityState::ReadPayload(nio::read_exact(st, buf)); - } else { - self.state = RecvIdentityState::ReadPayloadLen(read_len); - return Ok(Async::NotReady); - } - }, - RecvIdentityState::ReadPayload(mut read_payload) => { - if let Async::Ready((mut st, bytes)) = read_payload.poll()? { - let pb: payload::Identity = protobuf::parse_from_bytes(&bytes)?; - if !pb.pubkey.is_empty() { - let pk = identity::PublicKey::from_protobuf_encoding(pb.get_pubkey()) - .map_err(|_| NoiseError::InvalidKey)?; - if let Some(ref k) = st.id_remote_pubkey { - if k != &pk { - return Err(NoiseError::InvalidKey) - } - } - st.id_remote_pubkey = Some(pk); - } - if !pb.signature.is_empty() { - st.dh_remote_pubkey_sig = Some(pb.signature) - } - return Ok(Async::Ready(st)) - } else { - self.state = RecvIdentityState::ReadPayload(read_payload); - return Ok(Async::NotReady) - } - }, - RecvIdentityState::Done => panic!("RecvIdentity polled after completion") + let mut payload_buf = vec![0; len]; + state.io.read_exact(&mut payload_buf).await?; + let pb: payload::Identity = protobuf::parse_from_bytes(&payload_buf)?; + + if !pb.pubkey.is_empty() { + let pk = identity::PublicKey::from_protobuf_encoding(pb.get_pubkey()) + .map_err(|_| NoiseError::InvalidKey)?; + if let Some(ref k) = state.id_remote_pubkey { + if k != &pk { + return Err(NoiseError::InvalidKey) } } + state.id_remote_pubkey = Some(pk); + } + if !pb.signature.is_empty() { + state.dh_remote_pubkey_sig = Some(pb.signature); } -} -// SendIdentity -------------------------------------------------------------- + Ok(()) +} /// Send a Noise handshake message with a payload identifying the local node to the remote. -/// -/// Obtained from [`Handshake::send_identity`]. async fn send_identity(state: &mut State) -> Result<(), NoiseError> where - T: AsyncWrite + T: AsyncWrite + Unpin, { let mut pb = payload::Identity::new(); - if st.send_identity { - pb.set_pubkey(st.identity.public.clone().into_protobuf_encoding()); + if state.send_identity { + pb.set_pubkey(state.identity.public.clone().into_protobuf_encoding()); } - if let Some(ref sig) = st.identity.signature { + if let Some(ref sig) = state.identity.signature { pb.set_signature(sig.clone()); } let pb_bytes = pb.write_to_bytes()?; let len = (pb_bytes.len() as u16).to_be_bytes(); - st.write_all(&len).await?; - st.write_all(&pb_bytes).await?; - st.flush().await?; + state.io.write_all(&len).await?; + state.io.write_all(&pb_bytes).await?; + state.io.flush().await?; Ok(()) } diff --git a/protocols/noise/src/lib.rs b/protocols/noise/src/lib.rs index 97346a52d89..e82d7ff554f 100644 --- a/protocols/noise/src/lib.rs +++ b/protocols/noise/src/lib.rs @@ -25,11 +25,11 @@ //! //! This crate provides `libp2p_core::InboundUpgrade` and `libp2p_core::OutboundUpgrade` //! implementations for various noise handshake patterns (currently `IK`, `IX`, and `XX`) -//! over a particular choice of DH key agreement (currently only X25519). +//! over a particular choice of Diffie–Hellman key agreement (currently only X25519). //! //! All upgrades produce as output a pair, consisting of the remote's static public key //! and a `NoiseOutput` which represents the established cryptographic session with the -//! remote, implementing `tokio_io::AsyncRead` and `tokio_io::AsyncWrite`. +//! remote, implementing `futures::io::AsyncRead` and `futures::io::AsyncWrite`. //! //! # Usage //! @@ -57,12 +57,14 @@ mod protocol; pub use error::NoiseError; pub use io::NoiseOutput; -pub use io::handshake::{RemoteIdentity, IdentityExchange}; +pub use io::handshake; +pub use io::handshake::{Handshake, RemoteIdentity, IdentityExchange}; pub use protocol::{Keypair, AuthenticKeypair, KeypairIdentity, PublicKey, SecretKey}; pub use protocol::{Protocol, ProtocolParams, x25519::X25519, IX, IK, XX}; +use futures::prelude::*; use libp2p_core::{identity, PeerId, UpgradeInfo, InboundUpgrade, OutboundUpgrade, Negotiated}; -use tokio_io::{AsyncRead, AsyncWrite}; +use std::pin::Pin; use zeroize::Zeroize; /// The protocol upgrade configuration. @@ -157,7 +159,7 @@ where impl InboundUpgrade for NoiseConfig where NoiseConfig: UpgradeInfo, - T: AsyncRead + AsyncWrite + Send + 'static, + T: AsyncRead + AsyncWrite + Unpin + Send + 'static, C: Protocol + AsRef<[u8]> + Zeroize + Send + 'static, { type Output = (RemoteIdentity, NoiseOutput>); @@ -169,7 +171,7 @@ where .local_private_key(self.dh_keys.secret().as_ref()) .build_responder() .map_err(NoiseError::from); - Handshake::rt1_responder(socket, session, + handshake::rt1_responder(socket, session, self.dh_keys.into_identity(), IdentityExchange::Mutual) } @@ -178,7 +180,7 @@ where impl OutboundUpgrade for NoiseConfig where NoiseConfig: UpgradeInfo, - T: AsyncRead + AsyncWrite + Send + 'static, + T: AsyncRead + AsyncWrite + Unpin + Send + 'static, C: Protocol + AsRef<[u8]> + Zeroize + Send + 'static, { type Output = (RemoteIdentity, NoiseOutput>); @@ -190,9 +192,9 @@ where .local_private_key(self.dh_keys.secret().as_ref()) .build_initiator() .map_err(NoiseError::from); - Handshake::rt1_initiator(socket, session, - self.dh_keys.into_identity(), - IdentityExchange::Mutual) + handshake::rt1_initiator(socket, session, + self.dh_keys.into_identity(), + IdentityExchange::Mutual) } } @@ -201,7 +203,7 @@ where impl InboundUpgrade for NoiseConfig where NoiseConfig: UpgradeInfo, - T: AsyncRead + AsyncWrite + Send + 'static, + T: AsyncRead + AsyncWrite + Unpin + Send + 'static, C: Protocol + AsRef<[u8]> + Zeroize + Send + 'static, { type Output = (RemoteIdentity, NoiseOutput>); @@ -213,7 +215,7 @@ where .local_private_key(self.dh_keys.secret().as_ref()) .build_responder() .map_err(NoiseError::from); - Handshake::rt15_responder(socket, session, + handshake::rt15_responder(socket, session, self.dh_keys.into_identity(), IdentityExchange::Mutual) } @@ -222,7 +224,7 @@ where impl OutboundUpgrade for NoiseConfig where NoiseConfig: UpgradeInfo, - T: AsyncRead + AsyncWrite + Send + 'static, + T: AsyncRead + AsyncWrite + Unpin + Send + 'static, C: Protocol + AsRef<[u8]> + Zeroize + Send + 'static, { type Output = (RemoteIdentity, NoiseOutput>); @@ -234,7 +236,7 @@ where .local_private_key(self.dh_keys.secret().as_ref()) .build_initiator() .map_err(NoiseError::from); - Handshake::rt15_initiator(socket, session, + handshake::rt15_initiator(socket, session, self.dh_keys.into_identity(), IdentityExchange::Mutual) } @@ -245,7 +247,7 @@ where impl InboundUpgrade for NoiseConfig where NoiseConfig: UpgradeInfo, - T: AsyncRead + AsyncWrite + Send + 'static, + T: AsyncRead + AsyncWrite + Unpin + Send + 'static, C: Protocol + AsRef<[u8]> + Zeroize + Send + 'static, { type Output = (RemoteIdentity, NoiseOutput>); @@ -257,7 +259,7 @@ where .local_private_key(self.dh_keys.secret().as_ref()) .build_responder() .map_err(NoiseError::from); - Handshake::rt1_responder(socket, session, + handshake::rt1_responder(socket, session, self.dh_keys.into_identity(), IdentityExchange::Receive) } @@ -266,7 +268,7 @@ where impl OutboundUpgrade for NoiseConfig, identity::PublicKey)> where NoiseConfig, identity::PublicKey)>: UpgradeInfo, - T: AsyncRead + AsyncWrite + Send + 'static, + T: AsyncRead + AsyncWrite + Unpin + Send + 'static, C: Protocol + AsRef<[u8]> + Zeroize + Send + 'static, { type Output = (RemoteIdentity, NoiseOutput>); @@ -279,7 +281,7 @@ where .remote_public_key(self.remote.0.as_ref()) .build_initiator() .map_err(NoiseError::from); - Handshake::rt1_initiator(socket, session, + handshake::rt1_initiator(socket, session, self.dh_keys.into_identity(), IdentityExchange::Send { remote: self.remote.1 }) } @@ -319,23 +321,20 @@ where NoiseConfig: UpgradeInfo + InboundUpgrade, NoiseOutput>), Error = NoiseError - >, + > + 'static, + as InboundUpgrade>::Future: Send, T: AsyncRead + AsyncWrite + Send + 'static, C: Protocol + AsRef<[u8]> + Zeroize + Send + 'static, { type Output = (PeerId, NoiseOutput>); type Error = NoiseError; - type Future = future::AndThen< - as InboundUpgrade>::Future, - FutureResult, - fn((RemoteIdentity, NoiseOutput>)) -> FutureResult - >; + type Future = Pin> + Send>>; fn upgrade_inbound(self, socket: Negotiated, info: Self::Info) -> Self::Future { - self.config.upgrade_inbound(socket, info) - .and_then(|(remote, io)| future::result(match remote { - RemoteIdentity::IdentityKey(pk) => Ok((pk.into_peer_id(), io)), - _ => Err(NoiseError::AuthenticationFailed) + Box::pin(self.config.upgrade_inbound(socket, info) + .and_then(|(remote, io)| match remote { + RemoteIdentity::IdentityKey(pk) => future::ok((pk.into_peer_id(), io)), + _ => future::err(NoiseError::AuthenticationFailed) })) } } @@ -345,24 +344,20 @@ where NoiseConfig: UpgradeInfo + OutboundUpgrade, NoiseOutput>), Error = NoiseError - >, + > + 'static, + as OutboundUpgrade>::Future: Send, T: AsyncRead + AsyncWrite + Send + 'static, C: Protocol + AsRef<[u8]> + Zeroize + Send + 'static, { type Output = (PeerId, NoiseOutput>); type Error = NoiseError; - type Future = future::AndThen< - as OutboundUpgrade>::Future, - FutureResult, - fn((RemoteIdentity, NoiseOutput>)) -> FutureResult - >; + type Future = Pin> + Send>>; fn upgrade_outbound(self, socket: Negotiated, info: Self::Info) -> Self::Future { - self.config.upgrade_outbound(socket, info) - .and_then(|(remote, io)| future::result(match remote { - RemoteIdentity::IdentityKey(pk) => Ok((pk.into_peer_id(), io)), - _ => Err(NoiseError::AuthenticationFailed) + Box::pin(self.config.upgrade_outbound(socket, info) + .and_then(|(remote, io)| match remote { + RemoteIdentity::IdentityKey(pk) => future::ok((pk.into_peer_id(), io)), + _ => future::err(NoiseError::AuthenticationFailed) })) } } - diff --git a/protocols/noise/tests/smoke.rs b/protocols/noise/tests/smoke.rs index ff7a9d5a163..6fd8de94931 100644 --- a/protocols/noise/tests/smoke.rs +++ b/protocols/noise/tests/smoke.rs @@ -26,7 +26,6 @@ use libp2p_noise::{Keypair, X25519, NoiseConfig, RemoteIdentity, NoiseError, Noi use libp2p_tcp::{TcpConfig, TcpTransStream}; use log::info; use quickcheck::QuickCheck; -use tokio::{self, io}; #[allow(dead_code)] fn core_upgrade_compat() { @@ -113,9 +112,9 @@ fn ik_xx() { let server_transport = TcpConfig::new() .and_then(move |output, endpoint| { if endpoint.is_listener() { - Either::A(apply_inbound(output, NoiseConfig::ik_listener(server_dh))) + Either::Left(apply_inbound(output, NoiseConfig::ik_listener(server_dh))) } else { - Either::B(apply_outbound(output, NoiseConfig::xx(server_dh))) + Either::Right(apply_outbound(output, NoiseConfig::xx(server_dh))) } }) .and_then(move |out, _| expect_identity(out, &client_id_public)); @@ -125,10 +124,10 @@ fn ik_xx() { let client_transport = TcpConfig::new() .and_then(move |output, endpoint| { if endpoint.is_dialer() { - Either::A(apply_outbound(output, + Either::Left(apply_outbound(output, NoiseConfig::ik_dialer(client_dh, server_id_public, server_dh_public))) } else { - Either::B(apply_inbound(output, NoiseConfig::xx(client_dh))) + Either::Right(apply_inbound(output, NoiseConfig::xx(client_dh))) } }) .and_then(move |out, _| expect_identity(out, &server_id_public2)); @@ -145,55 +144,63 @@ fn run(server_transport: T, client_transport: U, message1: Vec) where T: Transport, T::Dial: Send + 'static, - T::Listener: Send + 'static, + T::Listener: Send + Unpin + futures::stream::TryStream + 'static, T::ListenerUpgrade: Send + 'static, U: Transport, U::Dial: Send + 'static, U::Listener: Send + 'static, U::ListenerUpgrade: Send + 'static, { - let message2 = message1.clone(); - - let mut server = server_transport - .listen_on("/ip4/127.0.0.1/tcp/0".parse().unwrap()) - .unwrap(); - - let server_address = server.by_ref().wait() - .next() - .expect("some event") - .expect("no error") - .into_new_address() - .expect("listen address"); - - let server = server.take(1) - .filter_map(ListenerEvent::into_upgrade) - .and_then(|client| client.0) - .map_err(|e| panic!("server error: {}", e)) - .and_then(|(_, client)| { + futures::executor::block_on(async { + let mut message2 = message1.clone(); + + let mut server: T::Listener = server_transport + .listen_on("/ip4/127.0.0.1/tcp/0".parse().unwrap()) + .unwrap(); + + let server_address = server.try_next() + .await + .expect("some event") + .expect("no error") + .into_new_address() + .expect("listen address"); + + let client_fut = async { + let mut client_session = client_transport.dial(server_address.clone()) + .unwrap() + .await + .map(|(_, session)| session) + .expect("no error"); + + client_session.write_all(&mut message2).await.expect("no error"); + client_session.flush().await.expect("no error"); + }; + + let server_fut = async { + let mut server_session = server.try_next() + .await + .expect("some event") + .map(ListenerEvent::into_upgrade) + .expect("no error") + .map(|client| client.0) + .expect("listener upgrade") + .await + .map(|(_, session)| session) + .expect("no error"); + + let mut server_buffer = vec![]; info!("server: reading message"); - io::read_to_end(client, Vec::new()) - }) - .for_each(move |msg| { - assert_eq!(msg.1, message1); - Ok(()) - }); - - let client = client_transport.dial(server_address.clone()).unwrap() - .map_err(|e| panic!("client error: {}", e)) - .and_then(move |(_, server)| { - io::write_all(server, message2).and_then(|(client, _)| io::flush(client)) - }) - .map(|_| ()); - - let future = client.join(server) - .map_err(|e| panic!("{:?}", e)) - .map(|_| ()); - - tokio::run(future) + server_session.read_to_end(&mut server_buffer).await.expect("no error"); + + assert_eq!(server_buffer, message1); + }; + + futures::future::join(server_fut, client_fut).await; + }) } fn expect_identity(output: Output, pk: &identity::PublicKey) - -> impl Future + -> impl Future> { match output.0 { RemoteIdentity::IdentityKey(ref k) if k == pk => future::ok(output),