From aa4accf3c484ef7ba9ad665eca690cc08bf208df Mon Sep 17 00:00:00 2001 From: Matisse Callewaert Date: Sun, 21 Apr 2024 15:26:03 +0200 Subject: [PATCH] :zap: Improve performance of reader --- feature-extraction-tool/src/main.rs | 209 ++++++++------------- feature-extraction-tool/src/utils/utils.rs | 4 +- 2 files changed, 78 insertions(+), 135 deletions(-) diff --git a/feature-extraction-tool/src/main.rs b/feature-extraction-tool/src/main.rs index c281fd1..d1e397f 100644 --- a/feature-extraction-tool/src/main.rs +++ b/feature-extraction-tool/src/main.rs @@ -543,8 +543,7 @@ where let start = Instant::now(); let mut amount_of_packets = 0; - let flow_map_ipv4: Arc> = Arc::new(DashMap::new()); - let flow_map_ipv6: Arc> = Arc::new(DashMap::new()); + let flow_map: Arc> = Arc::new(DashMap::new()); let mut cap = match pcap::Capture::from_file(path) { Ok(c) => c, @@ -561,14 +560,14 @@ where EtherTypes::Ipv4 => { if let Some(ipv4_packet) = Ipv4Packet::new(ethernet.payload()) { if let Some(features_ipv4) = extract_ipv4_features(&ipv4_packet) { - redirect_packet_ipv4(&features_ipv4, &flow_map_ipv4); + redirect_packet_ipv4(&features_ipv4, &flow_map); } } } EtherTypes::Ipv6 => { if let Some(ipv6_packet) = Ipv6Packet::new(ethernet.payload()) { if let Some(features_ipv6) = extract_ipv6_features(&ipv6_packet) { - redirect_packet_ipv6(&features_ipv6, &flow_map_ipv6); + redirect_packet_ipv6(&features_ipv6, &flow_map); } } } @@ -581,16 +580,7 @@ where } } - for entry in flow_map_ipv4.iter() { - let flow = entry.value(); - if *NO_CONTAMINANT_FEATURES.lock().unwrap().deref() { - export(&flow.dump_without_contamination()); - } else { - export(&flow.dump()); - } - } - - for entry in flow_map_ipv6.iter() { + for entry in flow_map.iter() { let flow = entry.value(); if *NO_CONTAMINANT_FEATURES.lock().unwrap().deref() { export(&flow.dump_without_contamination()); @@ -615,8 +605,7 @@ where let mut amount_of_packets = 0; let mut size: usize = 0; - let flow_map_ipv4: Arc> = Arc::new(DashMap::new()); - let flow_map_ipv6: Arc> = Arc::new(DashMap::new()); + let flow_map: Arc> = Arc::new(DashMap::new()); let mut cap = match pcap::Capture::from_file(path) { Ok(c) => c, @@ -640,14 +629,14 @@ where SLL_IPV4 => { if let Some(ipv4_packet) = Ipv4Packet::new(&packet.data[16..]) { if let Some(features_ipv4) = extract_ipv4_features(&ipv4_packet) { - redirect_packet_ipv4(&features_ipv4, &flow_map_ipv4); + redirect_packet_ipv4(&features_ipv4, &flow_map); } } } SLL_IPV6 => { if let Some(ipv6_packet) = Ipv6Packet::new(&packet.data[16..]) { if let Some(features_ipv6) = extract_ipv6_features(&ipv6_packet) { - redirect_packet_ipv6(&features_ipv6, &flow_map_ipv6); + redirect_packet_ipv6(&features_ipv6, &flow_map); } } } @@ -660,16 +649,7 @@ where } } - for entry in flow_map_ipv4.iter() { - let flow = entry.value(); - if *NO_CONTAMINANT_FEATURES.lock().unwrap().deref() { - export(&flow.dump_without_contamination()); - } else { - export(&flow.dump()); - } - } - - for entry in flow_map_ipv6.iter() { + for entry in flow_map.iter() { let flow = entry.value(); if *NO_CONTAMINANT_FEATURES.lock().unwrap().deref() { export(&flow.dump_without_contamination()); @@ -750,17 +730,17 @@ where }; let flow_id = if fwd { create_flow_id( - source, + &source, data.port_source, - destination, + &destination, data.port_destination, data.protocol, ) } else { create_flow_id( - destination, + &destination, data.port_destination, - source, + &source, data.port_source, data.protocol, ) @@ -830,17 +810,17 @@ where let flow_id = if fwd { create_flow_id( - source, + &source, data.port_source, - destination, + &destination, data.port_destination, data.protocol, ) } else { create_flow_id( - destination, + &destination, data.port_destination, - source, + &source, data.port_source, data.protocol, ) @@ -889,16 +869,16 @@ where T: Flow, { let fwd_flow_id = create_flow_id( - std::net::IpAddr::V4(Ipv4Addr::from(features_ipv4.ipv4_source)), + &std::net::IpAddr::V4(Ipv4Addr::from(features_ipv4.ipv4_source)), features_ipv4.port_source, - std::net::IpAddr::V4(Ipv4Addr::from(features_ipv4.ipv4_destination)), + &std::net::IpAddr::V4(Ipv4Addr::from(features_ipv4.ipv4_destination)), features_ipv4.port_destination, features_ipv4.protocol, ); let bwd_flow_id = create_flow_id( - std::net::IpAddr::V4(Ipv4Addr::from(features_ipv4.ipv4_destination)), + &std::net::IpAddr::V4(Ipv4Addr::from(features_ipv4.ipv4_destination)), features_ipv4.port_destination, - std::net::IpAddr::V4(Ipv4Addr::from(features_ipv4.ipv4_source)), + &std::net::IpAddr::V4(Ipv4Addr::from(features_ipv4.ipv4_source)), features_ipv4.port_source, features_ipv4.protocol, ); @@ -923,16 +903,16 @@ where T: Flow, { let fwd_flow_id = create_flow_id( - std::net::IpAddr::V6(Ipv6Addr::from(features_ipv6.ipv6_source)), + &std::net::IpAddr::V6(Ipv6Addr::from(features_ipv6.ipv6_source)), features_ipv6.port_source, - std::net::IpAddr::V6(Ipv6Addr::from(features_ipv6.ipv6_destination)), + &std::net::IpAddr::V6(Ipv6Addr::from(features_ipv6.ipv6_destination)), features_ipv6.port_destination, features_ipv6.protocol, ); let bwd_flow_id = create_flow_id( - std::net::IpAddr::V6(Ipv6Addr::from(features_ipv6.ipv6_destination)), + &std::net::IpAddr::V6(Ipv6Addr::from(features_ipv6.ipv6_destination)), features_ipv6.port_destination, - std::net::IpAddr::V6(Ipv6Addr::from(features_ipv6.ipv6_source)), + &std::net::IpAddr::V6(Ipv6Addr::from(features_ipv6.ipv6_source)), features_ipv6.port_source, features_ipv6.protocol, ); @@ -956,63 +936,45 @@ where /// /// * `Option` - Basic features of the packet. fn extract_ipv4_features(ipv4_packet: &Ipv4Packet) -> Option { - let source_ip = ipv4_packet.get_source(); - let destination_ip = ipv4_packet.get_destination(); - let protocol = ipv4_packet.get_next_level_protocol(); - - let source_port: u16; - let destination_port: u16; - - let mut combined_flags: u8 = 0; - let data_length: u16; - let header_length: u8; - let length: u16; - - let mut window_size: u16 = 0; - - if protocol.0 == IpNextHeaderProtocols::Tcp.0 { + if ipv4_packet.get_next_level_protocol().0 == IpNextHeaderProtocols::Tcp.0 { if let Some(tcp_packet) = TcpPacket::new(ipv4_packet.payload()) { - source_port = tcp_packet.get_source(); - destination_port = tcp_packet.get_destination(); - - data_length = tcp_packet.payload().len() as u16; - header_length = (tcp_packet.get_data_offset() * 4) as u8; - length = ipv4_packet.get_total_length(); - - window_size = tcp_packet.get_window(); - - combined_flags = tcp_packet.get_flags(); + return Some(BasicFeaturesIpv4::new( + ipv4_packet.get_destination().into(), + ipv4_packet.get_source().into(), + tcp_packet.get_destination(), + tcp_packet.get_source(), + tcp_packet.payload().len() as u16, + ipv4_packet.get_total_length(), + tcp_packet.get_window(), + tcp_packet.get_flags(), + ipv4_packet.get_next_level_protocol().0, + (tcp_packet.get_data_offset() * 4) as u8, + )); } else { return None; } - } else if protocol.0 == IpNextHeaderProtocols::Udp.0 { + } else if ipv4_packet.get_next_level_protocol().0 == IpNextHeaderProtocols::Udp.0 { if let Some(udp_packet) = pnet::packet::udp::UdpPacket::new(ipv4_packet.payload()) { - source_port = udp_packet.get_source(); - destination_port = udp_packet.get_destination(); - data_length = udp_packet.payload().len() as u16; - header_length = 8; - length = udp_packet.get_length(); + return Some(BasicFeaturesIpv4::new( + ipv4_packet.get_destination().into(), + ipv4_packet.get_source().into(), + udp_packet.get_destination(), + udp_packet.get_source(), + udp_packet.payload().len() as u16, + udp_packet.get_length(), + 0, + 0, + ipv4_packet.get_next_level_protocol().0, + 8, + )); } else { return None; } } else { return None; } - - Some(BasicFeaturesIpv4::new( - destination_ip.into(), - source_ip.into(), - destination_port, - source_port, - data_length, - length, - window_size, - combined_flags, - protocol.0, - header_length, - )) } /// Extracts the basic features of an ipv6 packet pnet struct. @@ -1025,63 +987,44 @@ fn extract_ipv4_features(ipv4_packet: &Ipv4Packet) -> Option /// /// * `Option` - Basic features of the packet. fn extract_ipv6_features(ipv6_packet: &Ipv6Packet) -> Option { - let source_ip = ipv6_packet.get_source(); - let destination_ip = ipv6_packet.get_destination(); - let protocol = ipv6_packet.get_next_header(); - - let source_port: u16; - let destination_port: u16; - - let mut combined_flags: u8 = 0; - - let data_length: u16; - let header_length: u8; - let length: u16; - - let mut window_size: u16 = 0; - if protocol == IpNextHeaderProtocols::Tcp { + if ipv6_packet.get_next_header() == IpNextHeaderProtocols::Tcp { if let Some(tcp_packet) = TcpPacket::new(ipv6_packet.payload()) { - source_port = tcp_packet.get_source(); - destination_port = tcp_packet.get_destination(); - - data_length = tcp_packet.payload().len() as u16; - header_length = (tcp_packet.get_data_offset() * 4) as u8; - length = ipv6_packet.packet().bytes().count() as u16; - - window_size = tcp_packet.get_window(); - - combined_flags = tcp_packet.get_flags(); + return Some(BasicFeaturesIpv6::new( + ipv6_packet.get_destination().into(), + ipv6_packet.get_source().into(), + tcp_packet.get_destination(), + tcp_packet.get_source(), + tcp_packet.payload().len() as u16, + ipv6_packet.packet().bytes().count() as u16, + tcp_packet.get_window(), + tcp_packet.get_flags(), + ipv6_packet.get_next_header().0, + (tcp_packet.get_data_offset() * 4) as u8, + )); } else { return None; } - } else if protocol == IpNextHeaderProtocols::Udp { + } else if ipv6_packet.get_next_header() == IpNextHeaderProtocols::Udp { if let Some(udp_packet) = pnet::packet::udp::UdpPacket::new(ipv6_packet.payload()) { - source_port = udp_packet.get_source(); - destination_port = udp_packet.get_destination(); - - data_length = udp_packet.payload().len() as u16; - header_length = 8; - length = udp_packet.get_length(); + return Some(BasicFeaturesIpv6::new( + ipv6_packet.get_destination().into(), + ipv6_packet.get_source().into(), + udp_packet.get_destination(), + udp_packet.get_source(), + udp_packet.payload().len() as u16, + ipv6_packet.packet().bytes().count() as u16, + 0, + 0, + ipv6_packet.get_next_header().0, + 8, + )); } else { return None; } } else { return None; } - - Some(BasicFeaturesIpv6::new( - destination_ip.into(), - source_ip.into(), - destination_port, - source_port, - data_length, - length, - window_size, - combined_flags, - protocol.0, - header_length, - )) } #[cfg(test)] diff --git a/feature-extraction-tool/src/utils/utils.rs b/feature-extraction-tool/src/utils/utils.rs index 8b18966..1155567 100644 --- a/feature-extraction-tool/src/utils/utils.rs +++ b/feature-extraction-tool/src/utils/utils.rs @@ -20,9 +20,9 @@ use chrono::{DateTime, Utc}; /// /// A string representing the unique identifier of the network flow. pub fn create_flow_id( - ip_source: IpAddr, + ip_source: &IpAddr, port_source: u16, - ip_destination: IpAddr, + ip_destination: &IpAddr, port_destination: u16, protocol: u8, ) -> String {