diff --git a/neqo-transport/src/connection/mod.rs b/neqo-transport/src/connection/mod.rs index 3bf6b91263..9e8c7ed66a 100644 --- a/neqo-transport/src/connection/mod.rs +++ b/neqo-transport/src/connection/mod.rs @@ -1560,8 +1560,10 @@ impl Connection { let mut ack_eliciting = false; let mut probing = true; let mut d = Decoder::from(&packet[..]); + let mut decoded_frames = false; while d.remaining() > 0 { let f = Frame::decode(&mut d)?; + decoded_frames = true; ack_eliciting |= f.ack_eliciting(); probing &= f.path_probing(); let t = f.get_type(); @@ -1569,6 +1571,10 @@ impl Connection { self.capture_error(Some(Rc::clone(path)), now, t, Err(e))?; } } + if !decoded_frames { + qerror!([self], "Received packet with no frames"); + return Err(Error::ProtocolViolation); + } let largest_received = if let Some(space) = self .acks diff --git a/neqo-transport/src/packet/mod.rs b/neqo-transport/src/packet/mod.rs index 0843d050ab..d435ac0dd8 100644 --- a/neqo-transport/src/packet/mod.rs +++ b/neqo-transport/src/packet/mod.rs @@ -764,14 +764,10 @@ impl<'a> PublicPacket<'a> { assert_ne!(self.packet_type, PacketType::Retry); assert_ne!(self.packet_type, PacketType::VersionNegotiation); - qtrace!( - "unmask hdr={}", - hex(&self.data[..self.header_len + SAMPLE_OFFSET]) - ); - let sample_offset = self.header_len + SAMPLE_OFFSET; let mask = if let Some(sample) = self.data.get(sample_offset..(sample_offset + SAMPLE_SIZE)) { + qtrace!("unmask hdr={}", hex(&self.data[..sample_offset])); crypto.compute_mask(sample) } else { Err(Error::NoMoreData) diff --git a/neqo-transport/tests/common/mod.rs b/neqo-transport/tests/common/mod.rs index e36e66f753..84a2bc0d27 100644 --- a/neqo-transport/tests/common/mod.rs +++ b/neqo-transport/tests/common/mod.rs @@ -8,7 +8,7 @@ use std::{cell::RefCell, mem, ops::Range, rc::Rc}; -use neqo_common::{event::Provider, hex_with_len, qtrace, Datagram, Decoder, Role}; +use neqo_common::{event::Provider, hex_with_len, qdebug, qtrace, Datagram, Decoder, Role}; use neqo_crypto::{ constants::{TLS_AES_128_GCM_SHA256, TLS_VERSION_1_3}, hkdf, diff --git a/neqo-transport/tests/connection.rs b/neqo-transport/tests/connection.rs index 0b91fcf306..cdcf1b9ef1 100644 --- a/neqo-transport/tests/connection.rs +++ b/neqo-transport/tests/connection.rs @@ -127,6 +127,56 @@ fn reorder_server_initial() { assert_eq!(*client.state(), State::Confirmed); } +/// Test that the stack treats a packet without any frames as a protocol violation. +#[test] +fn packet_without_frames() { + let mut client = new_client( + ConnectionParameters::default().versions(Version::Version1, vec![Version::Version1]), + ); + let mut server = default_server(); + + let client_initial = client.process_output(now()); + let (_, client_dcid, _, _) = + decode_initial_header(client_initial.as_dgram_ref().unwrap(), Role::Client); + let client_dcid = client_dcid.to_owned(); + + let server_packet = server.process(client_initial.as_dgram_ref(), now()).dgram(); + let (server_initial, _server_hs) = split_datagram(server_packet.as_ref().unwrap()); + let (protected_header, _, _, payload) = decode_initial_header(&server_initial, Role::Server); + + // Now decrypt the packet. + let (aead, hp) = initial_aead_and_hp(&client_dcid, Role::Server); + let (mut header, pn) = remove_header_protection(&hp, protected_header, payload); + assert_eq!(pn, 0); + // Re-encode the packet number as a four-byte varint, so we have enough material for the header + // protection sample. + let hl = header.len(); + header[hl - 2] = u8::try_from(4 + aead.expansion()).unwrap(); + header.resize(header.len() + 3, 0); + let hl = header.len(); + header[hl - 4..].copy_from_slice(&[0; 4]); + header[0] |= 0b0000_0011; // Set the packet number length to 4. + + // And build an empty packet. + let mut packet = header.clone(); + packet.resize(header.len() + aead.expansion(), 0); + aead.encrypt(pn, &header, &[], &mut packet[header.len()..]) + .unwrap(); + apply_header_protection(&hp, &mut packet, protected_header.len()..header.len()); + let empty = Datagram::new( + server_initial.source(), + server_initial.destination(), + server_initial.tos(), + server_initial.ttl(), + packet, + ); + client.process_input(&empty, now()); + assert!(matches!( + client.state(), + State::Closed(ConnectionError::Transport(Error::ProtocolViolation)) + )); +} + /// Overflow the crypto buffer. #[allow(clippy::similar_names)] // For ..._scid and ..._dcid, which are fine. #[test]