Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add peek APIs to std::net #38983

Merged
merged 1 commit into from
Feb 5, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/libstd/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,7 @@
#![feature(oom)]
#![feature(optin_builtin_traits)]
#![feature(panic_unwind)]
#![feature(peek)]
#![feature(placement_in_syntax)]
#![feature(prelude_import)]
#![feature(pub_restricted)]
Expand Down
54 changes: 54 additions & 0 deletions src/libstd/net/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,29 @@ impl TcpStream {
self.0.write_timeout()
}

/// Receives data on the socket from the remote adress to which it is
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It'd be good to document the return value as well.

/// connected, without removing that data from the queue. On success,
/// returns the number of bytes peeked.
///
/// Successive calls return the same data. This is accomplished by passing
/// `MSG_PEEK` as a flag to the underlying `recv` system call.
///
/// # Examples
///
/// ```no_run
/// #![feature(peek)]
/// use std::net::TcpStream;
///
/// let stream = TcpStream::connect("127.0.0.1:8000")
/// .expect("couldn't bind to address");
/// let mut buf = [0; 10];
/// let len = stream.peek(&mut buf).expect("peek failed");
/// ```
#[unstable(feature = "peek", issue = "38980")]
pub fn peek(&self, buf: &mut [u8]) -> io::Result<usize> {
self.0.peek(buf)
}

/// Sets the value of the `TCP_NODELAY` option on this socket.
///
/// If set, this option disables the Nagle algorithm. This means that
Expand Down Expand Up @@ -1405,4 +1428,35 @@ mod tests {
Err(e) => panic!("unexpected error {}", e),
}
}

#[test]
fn peek() {
each_ip(&mut |addr| {
let (txdone, rxdone) = channel();

let srv = t!(TcpListener::bind(&addr));
let _t = thread::spawn(move|| {
let mut cl = t!(srv.accept()).0;
cl.write(&[1,3,3,7]).unwrap();
t!(rxdone.recv());
});

let mut c = t!(TcpStream::connect(&addr));
let mut b = [0; 10];
for _ in 1..3 {
let len = c.peek(&mut b).unwrap();
assert_eq!(len, 4);
}
let len = c.read(&mut b).unwrap();
assert_eq!(len, 4);

t!(c.set_nonblocking(true));
match c.peek(&mut b) {
Ok(_) => panic!("expected error"),
Err(ref e) if e.kind() == ErrorKind::WouldBlock => {}
Err(e) => panic!("unexpected error {}", e),
}
t!(txdone.send(()));
})
}
}
97 changes: 97 additions & 0 deletions src/libstd/net/udp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,30 @@ impl UdpSocket {
self.0.recv_from(buf)
}

/// Receives data from the socket, without removing it from the queue.
///
/// Successive calls return the same data. This is accomplished by passing
/// `MSG_PEEK` as a flag to the underlying `recvfrom` system call.
///
/// On success, returns the number of bytes peeked and the address from
/// whence the data came.
///
/// # Examples
///
/// ```no_run
/// #![feature(peek)]
/// use std::net::UdpSocket;
///
/// let socket = UdpSocket::bind("127.0.0.1:34254").expect("couldn't bind to address");
/// let mut buf = [0; 10];
/// let (number_of_bytes, src_addr) = socket.peek_from(&mut buf)
/// .expect("Didn't receive data");
/// ```
#[unstable(feature = "peek", issue = "38980")]
pub fn peek_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
self.0.peek_from(buf)
}

/// Sends data on the socket to the given address. On success, returns the
/// number of bytes written.
///
Expand Down Expand Up @@ -579,6 +603,37 @@ impl UdpSocket {
self.0.recv(buf)
}

/// Receives data on the socket from the remote adress to which it is
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Docs should include info about the return value.

/// connected, without removing that data from the queue. On success,
/// returns the number of bytes peeked.
///
/// Successive calls return the same data. This is accomplished by passing
/// `MSG_PEEK` as a flag to the underlying `recv` system call.
///
/// # Errors
///
/// This method will fail if the socket is not connected. The `connect` method
/// will connect this socket to a remote address.
///
/// # Examples
///
/// ```no_run
/// #![feature(peek)]
/// use std::net::UdpSocket;
///
/// let socket = UdpSocket::bind("127.0.0.1:34254").expect("couldn't bind to address");
/// socket.connect("127.0.0.1:8080").expect("connect function failed");
/// let mut buf = [0; 10];
/// match socket.peek(&mut buf) {
/// Ok(received) => println!("received {} bytes", received),
/// Err(e) => println!("peek function failed: {:?}", e),
/// }
/// ```
#[unstable(feature = "peek", issue = "38980")]
pub fn peek(&self, buf: &mut [u8]) -> io::Result<usize> {
self.0.peek(buf)
}

/// Moves this UDP socket into or out of nonblocking mode.
///
/// On Unix this corresponds to calling fcntl, and on Windows this
Expand Down Expand Up @@ -869,6 +924,48 @@ mod tests {
assert_eq!(b"hello world", &buf[..]);
}

#[test]
fn connect_send_peek_recv() {
each_ip(&mut |addr, _| {
let socket = t!(UdpSocket::bind(&addr));
t!(socket.connect(addr));

t!(socket.send(b"hello world"));

for _ in 1..3 {
let mut buf = [0; 11];
let size = t!(socket.peek(&mut buf));
assert_eq!(b"hello world", &buf[..]);
assert_eq!(size, 11);
}

let mut buf = [0; 11];
let size = t!(socket.recv(&mut buf));
assert_eq!(b"hello world", &buf[..]);
assert_eq!(size, 11);
})
}

#[test]
fn peek_from() {
each_ip(&mut |addr, _| {
let socket = t!(UdpSocket::bind(&addr));
t!(socket.send_to(b"hello world", &addr));

for _ in 1..3 {
let mut buf = [0; 11];
let (size, _) = t!(socket.peek_from(&mut buf));
assert_eq!(b"hello world", &buf[..]);
assert_eq!(size, 11);
}

let mut buf = [0; 11];
let (size, _) = t!(socket.recv_from(&mut buf));
assert_eq!(b"hello world", &buf[..]);
assert_eq!(size, 11);
})
}

#[test]
fn ttl() {
let ttl = 100;
Expand Down
45 changes: 42 additions & 3 deletions src/libstd/sys/unix/net.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,13 @@

use ffi::CStr;
use io;
use libc::{self, c_int, size_t, sockaddr, socklen_t, EAI_SYSTEM};
use libc::{self, c_int, c_void, size_t, sockaddr, socklen_t, EAI_SYSTEM, MSG_PEEK};
use mem;
use net::{SocketAddr, Shutdown};
use str;
use sys::fd::FileDesc;
use sys_common::{AsInner, FromInner, IntoInner};
use sys_common::net::{getsockopt, setsockopt};
use sys_common::net::{getsockopt, setsockopt, sockaddr_to_addr};
use time::Duration;

pub use sys::{cvt, cvt_r};
Expand Down Expand Up @@ -155,8 +156,46 @@ impl Socket {
self.0.duplicate().map(Socket)
}

fn recv_with_flags(&self, buf: &mut [u8], flags: c_int) -> io::Result<usize> {
let ret = cvt(unsafe {
libc::recv(self.0.raw(),
buf.as_mut_ptr() as *mut c_void,
buf.len(),
flags)
})?;
Ok(ret as usize)
}

pub fn read(&self, buf: &mut [u8]) -> io::Result<usize> {
self.0.read(buf)
self.recv_with_flags(buf, 0)
}

pub fn peek(&self, buf: &mut [u8]) -> io::Result<usize> {
self.recv_with_flags(buf, MSG_PEEK)
}

fn recv_from_with_flags(&self, buf: &mut [u8], flags: c_int)
-> io::Result<(usize, SocketAddr)> {
let mut storage: libc::sockaddr_storage = unsafe { mem::zeroed() };
let mut addrlen = mem::size_of_val(&storage) as libc::socklen_t;

let n = cvt(unsafe {
libc::recvfrom(self.0.raw(),
buf.as_mut_ptr() as *mut c_void,
buf.len(),
flags,
&mut storage as *mut _ as *mut _,
&mut addrlen)
})?;
Ok((n as usize, sockaddr_to_addr(&storage, addrlen as usize)?))
}

pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
self.recv_from_with_flags(buf, 0)
}

pub fn peek_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
self.recv_from_with_flags(buf, MSG_PEEK)
}

pub fn read_to_end(&self, buf: &mut Vec<u8>) -> io::Result<usize> {
Expand Down
1 change: 1 addition & 0 deletions src/libstd/sys/windows/c.rs
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,7 @@ pub const IP_ADD_MEMBERSHIP: c_int = 12;
pub const IP_DROP_MEMBERSHIP: c_int = 13;
pub const IPV6_ADD_MEMBERSHIP: c_int = 12;
pub const IPV6_DROP_MEMBERSHIP: c_int = 13;
pub const MSG_PEEK: c_int = 0x2;

#[repr(C)]
pub struct ip_mreq {
Expand Down
44 changes: 42 additions & 2 deletions src/libstd/sys/windows/net.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,19 +147,59 @@ impl Socket {
Ok(socket)
}

pub fn read(&self, buf: &mut [u8]) -> io::Result<usize> {
fn recv_with_flags(&self, buf: &mut [u8], flags: c_int) -> io::Result<usize> {
// On unix when a socket is shut down all further reads return 0, so we
// do the same on windows to map a shut down socket to returning EOF.
let len = cmp::min(buf.len(), i32::max_value() as usize) as i32;
unsafe {
match c::recv(self.0, buf.as_mut_ptr() as *mut c_void, len, 0) {
match c::recv(self.0, buf.as_mut_ptr() as *mut c_void, len, flags) {
-1 if c::WSAGetLastError() == c::WSAESHUTDOWN => Ok(0),
-1 => Err(last_error()),
n => Ok(n as usize)
}
}
}

pub fn read(&self, buf: &mut [u8]) -> io::Result<usize> {
self.recv_with_flags(buf, 0)
}

pub fn peek(&self, buf: &mut [u8]) -> io::Result<usize> {
self.recv_with_flags(buf, c::MSG_PEEK)
}

fn recv_from_with_flags(&self, buf: &mut [u8], flags: c_int)
-> io::Result<(usize, SocketAddr)> {
let mut storage: c::SOCKADDR_STORAGE_LH = unsafe { mem::zeroed() };
let mut addrlen = mem::size_of_val(&storage) as c::socklen_t;
let len = cmp::min(buf.len(), <wrlen_t>::max_value() as usize) as wrlen_t;

// On unix when a socket is shut down all further reads return 0, so we
// do the same on windows to map a shut down socket to returning EOF.
unsafe {
match c::recvfrom(self.0,
buf.as_mut_ptr() as *mut c_void,
len,
flags,
&mut storage as *mut _ as *mut _,
&mut addrlen) {
-1 if c::WSAGetLastError() == c::WSAESHUTDOWN => {
Ok((0, net::sockaddr_to_addr(&storage, addrlen as usize)?))
},
-1 => Err(last_error()),
n => Ok((n as usize, net::sockaddr_to_addr(&storage, addrlen as usize)?)),
}
}
}

pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
self.recv_from_with_flags(buf, 0)
}

pub fn peek_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
self.recv_from_with_flags(buf, c::MSG_PEEK)
}

pub fn read_to_end(&self, buf: &mut Vec<u8>) -> io::Result<usize> {
let mut me = self;
(&mut me).read_to_end(buf)
Expand Down
24 changes: 13 additions & 11 deletions src/libstd/sys_common/net.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ fn sockname<F>(f: F) -> io::Result<SocketAddr>
}
}

fn sockaddr_to_addr(storage: &c::sockaddr_storage,
pub fn sockaddr_to_addr(storage: &c::sockaddr_storage,
len: usize) -> io::Result<SocketAddr> {
match storage.ss_family as c_int {
c::AF_INET => {
Expand Down Expand Up @@ -222,6 +222,10 @@ impl TcpStream {
self.inner.timeout(c::SO_SNDTIMEO)
}

pub fn peek(&self, buf: &mut [u8]) -> io::Result<usize> {
self.inner.peek(buf)
}

pub fn read(&self, buf: &mut [u8]) -> io::Result<usize> {
self.inner.read(buf)
}
Expand Down Expand Up @@ -441,17 +445,11 @@ impl UdpSocket {
}

pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
let mut storage: c::sockaddr_storage = unsafe { mem::zeroed() };
let mut addrlen = mem::size_of_val(&storage) as c::socklen_t;
let len = cmp::min(buf.len(), <wrlen_t>::max_value() as usize) as wrlen_t;
self.inner.recv_from(buf)
}

let n = cvt(unsafe {
c::recvfrom(*self.inner.as_inner(),
buf.as_mut_ptr() as *mut c_void,
len, 0,
&mut storage as *mut _ as *mut _, &mut addrlen)
})?;
Ok((n as usize, sockaddr_to_addr(&storage, addrlen as usize)?))
pub fn peek_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
self.inner.peek_from(buf)
}

pub fn send_to(&self, buf: &[u8], dst: &SocketAddr) -> io::Result<usize> {
Expand Down Expand Up @@ -578,6 +576,10 @@ impl UdpSocket {
self.inner.read(buf)
}

pub fn peek(&self, buf: &mut [u8]) -> io::Result<usize> {
self.inner.peek(buf)
}

pub fn send(&self, buf: &[u8]) -> io::Result<usize> {
let len = cmp::min(buf.len(), <wrlen_t>::max_value() as usize) as wrlen_t;
let ret = cvt(unsafe {
Expand Down