diff --git a/src/lib.rs b/src/lib.rs index 25fd6bd..88540f6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -312,6 +312,7 @@ impl Pica { fn connect(&mut self, stream: TcpStream) { let (packet_tx, mut packet_rx) = mpsc::unbounded_channel(); + let invalid_packet_tx = packet_tx.clone(); let device_handle = self.counter; let pica_tx = self.command_tx.clone(); let pcapng_dir = self.pcapng_dir.clone(); @@ -339,10 +340,10 @@ impl Pica { // The task notifies pica when exiting to let it clean // the state. tokio::task::spawn(async move { - let mut pcapng_file = if let Some(dir) = pcapng_dir { + let pcapng_file = if let Some(dir) = pcapng_dir { let full_path = dir.join(format!("device-{}.pcapng", device_handle)); log::debug!("Recording pcapng to file {}", full_path.as_path().display()); - Some(pcapng::File::create(full_path).await.unwrap()) + Some(pcapng::File::create(full_path).unwrap()) } else { None }; @@ -351,34 +352,45 @@ impl Pica { let mut uci_reader = packets::uci::Reader::new(uci_rx); let mut uci_writer = packets::uci::Writer::new(uci_tx); - 'outer: loop { - tokio::select! { - // Read command packet sent from connected UWB host. - // Run associated command. - result = uci_reader.read(&mut pcapng_file) => - match result { - Ok(packet) => - match parse_uci_packet(&packet) { - UciParseResult::UciCommand(cmd) => { - pica_tx.send(PicaCommand::UciCommand(device_handle, cmd)).await.unwrap() - }, - UciParseResult::UciData(data) => { - pica_tx.send(PicaCommand::UciData(device_handle, data)).await.unwrap() - }, - UciParseResult::Err(response) => - uci_writer.write(&response, &mut pcapng_file).await.unwrap(), - UciParseResult::Skip => (), - }, - Err(_) => break 'outer - }, - - // Send response packets to the connected UWB host. - Some(packet) = packet_rx.recv() => - if uci_writer.write(&packet, &mut pcapng_file).await.is_err() { - break 'outer + tokio::try_join!( + async { + loop { + // Read UCI packets sent from connected UWB host. + // Run associated command. + match uci_reader.read(&pcapng_file).await { + Ok(packet) => match parse_uci_packet(&packet) { + UciParseResult::UciCommand(cmd) => pica_tx + .send(PicaCommand::UciCommand(device_handle, cmd)) + .await + .unwrap(), + UciParseResult::UciData(data) => pica_tx + .send(PicaCommand::UciData(device_handle, data)) + .await + .unwrap(), + UciParseResult::Err(response) => { + invalid_packet_tx.send(response.into()).unwrap() + } + UciParseResult::Skip => (), + }, + err => break err, } + } + }, + async { + loop { + // Write UCI packets to connected UWB host. + let Some(packet) = packet_rx.recv().await else { + anyhow::bail!("uci packet channel closed"); + }; + match uci_writer.write(&packet, &pcapng_file).await { + Ok(_) => (), + err => break err, + } + } } - } + ) + .unwrap(); + pica_tx .send(PicaCommand::Disconnect(device_handle)) .await diff --git a/src/packets.rs b/src/packets.rs index f0b4b64..576e571 100644 --- a/src/packets.rs +++ b/src/packets.rs @@ -68,7 +68,7 @@ pub mod uci { /// re-assembled if segmented on the UCI transport. Data segments /// are _not_ re-assembled but returned immediatly for credit /// acknowledgment. - pub async fn read(&mut self, pcapng: &mut Option) -> anyhow::Result> { + pub async fn read(&mut self, pcapng: &Option) -> anyhow::Result> { use tokio::io::AsyncReadExt; let mut complete_packet = vec![0; HEADER_SIZE]; @@ -102,11 +102,11 @@ pub mod uci { self.socket.read_exact(&mut payload_bytes).await?; complete_packet.extend(&payload_bytes); - if let Some(ref mut pcapng) = pcapng { + if let Some(ref pcapng) = pcapng { let mut packet_bytes = vec![]; packet_bytes.extend(&complete_packet[0..HEADER_SIZE]); packet_bytes.extend(&payload_bytes); - pcapng.write(&packet_bytes, pcapng::Direction::Tx).await?; + pcapng.write(&packet_bytes, pcapng::Direction::Tx)?; } if common_packet_header.get_mt() == MessageType::Data { @@ -135,7 +135,7 @@ pub mod uci { pub async fn write( &mut self, mut packet: &[u8], - pcapng: &mut Option, + pcapng: &Option, ) -> anyhow::Result<()> { use tokio::io::AsyncWriteExt; @@ -169,11 +169,11 @@ pub mod uci { _ => header_bytes[3] = chunk_length as u8, } - if let Some(ref mut pcapng) = pcapng { + if let Some(ref pcapng) = pcapng { let mut packet_bytes = vec![]; packet_bytes.extend(&header_bytes); packet_bytes.extend(&packet[..chunk_length]); - pcapng.write(&packet_bytes, pcapng::Direction::Rx).await? + pcapng.write(&packet_bytes, pcapng::Direction::Rx)? } // Write the header and payload segment bytes. diff --git a/src/pcapng.rs b/src/pcapng.rs index 3119bb1..9c1da7e 100644 --- a/src/pcapng.rs +++ b/src/pcapng.rs @@ -14,12 +14,12 @@ #![allow(clippy::unused_io_amount)] +use std::io::Write; use std::path::Path; use std::time::Instant; -use tokio::io::AsyncWriteExt; pub struct File { - file: tokio::fs::File, + file: std::sync::Mutex, start_time: Instant, } @@ -29,51 +29,50 @@ pub enum Direction { } impl File { - pub async fn create>(path: P) -> std::io::Result { - let mut file = tokio::fs::File::create(path).await?; + pub fn create>(path: P) -> std::io::Result { + let mut file = std::fs::File::create(path)?; // PCAPng files must start with a Section Header Block. - file.write(&u32::to_le_bytes(0x0A0D0D0A)).await?; // Block Type - file.write(&u32::to_le_bytes(28)).await?; // Block Total Length - file.write(&u32::to_le_bytes(0x1A2B3C4D)).await?; // Byte-Order Magic - file.write(&u16::to_le_bytes(1)).await?; // Major Version - file.write(&u16::to_le_bytes(0)).await?; // Minor Version - file.write(&u64::to_le_bytes(0xFFFFFFFFFFFFFFFF)).await?; // Section Length (not specified) - file.write(&u32::to_le_bytes(28)).await?; // Block Total Length + file.write(&u32::to_le_bytes(0x0A0D0D0A))?; // Block Type + file.write(&u32::to_le_bytes(28))?; // Block Total Length + file.write(&u32::to_le_bytes(0x1A2B3C4D))?; // Byte-Order Magic + file.write(&u16::to_le_bytes(1))?; // Major Version + file.write(&u16::to_le_bytes(0))?; // Minor Version + file.write(&u64::to_le_bytes(0xFFFFFFFFFFFFFFFF))?; // Section Length (not specified) + file.write(&u32::to_le_bytes(28))?; // Block Total Length // Write the Interface Description Block used for all // UCI records. - file.write(&u32::to_le_bytes(0x00000001)).await?; // Block Type - file.write(&u32::to_le_bytes(20)).await?; // Block Total Length - file.write(&u16::to_le_bytes(293)).await?; // LinkType - file.write(&u16::to_le_bytes(0)).await?; // Reserved - file.write(&u32::to_le_bytes(0)).await?; // SnapLen (no limit) - file.write(&u32::to_le_bytes(20)).await?; // Block Total Length + file.write(&u32::to_le_bytes(0x00000001))?; // Block Type + file.write(&u32::to_le_bytes(20))?; // Block Total Length + file.write(&u16::to_le_bytes(293))?; // LinkType + file.write(&u16::to_le_bytes(0))?; // Reserved + file.write(&u32::to_le_bytes(0))?; // SnapLen (no limit) + file.write(&u32::to_le_bytes(20))?; // Block Total Length Ok(File { - file, + file: std::sync::Mutex::new(file), start_time: Instant::now(), }) } - pub async fn write(&mut self, packet: &[u8], _dir: Direction) -> std::io::Result<()> { + pub fn write(&self, packet: &[u8], _dir: Direction) -> std::io::Result<()> { let packet_data_padding: usize = 4 - packet.len() % 4; let block_total_length: u32 = packet.len() as u32 + packet_data_padding as u32 + 32; let timestamp = self.start_time.elapsed().as_micros(); - let file = &mut self.file; + let mut file = self.file.lock().unwrap(); // Wrap the packet inside an Enhanced Packet Block. - file.write(&u32::to_le_bytes(0x00000006)).await?; // Block Type - file.write(&u32::to_le_bytes(block_total_length)).await?; - file.write(&u32::to_le_bytes(0)).await?; // Interface ID - file.write(&u32::to_le_bytes((timestamp >> 32) as u32)) - .await?; // Timestamp (High) - file.write(&u32::to_le_bytes(timestamp as u32)).await?; // Timestamp (Low) - file.write(&u32::to_le_bytes(packet.len() as u32)).await?; // Captured Packet Length - file.write(&u32::to_le_bytes(packet.len() as u32)).await?; // Original Packet Length - file.write(packet).await?; - file.write(&vec![0; packet_data_padding]).await?; - file.write(&u32::to_le_bytes(block_total_length)).await?; // Block Total Length + file.write(&u32::to_le_bytes(0x00000006))?; // Block Type + file.write(&u32::to_le_bytes(block_total_length))?; + file.write(&u32::to_le_bytes(0))?; // Interface ID + file.write(&u32::to_le_bytes((timestamp >> 32) as u32))?; // Timestamp (High) + file.write(&u32::to_le_bytes(timestamp as u32))?; // Timestamp (Low) + file.write(&u32::to_le_bytes(packet.len() as u32))?; // Captured Packet Length + file.write(&u32::to_le_bytes(packet.len() as u32))?; // Original Packet Length + file.write(packet)?; + file.write(&vec![0; packet_data_padding])?; + file.write(&u32::to_le_bytes(block_total_length))?; // Block Total Length Ok(()) } }