Skip to content

Commit

Permalink
Merge pull request #36 from matissecallewaert/bug/32-add-customflow-to-program
Browse files Browse the repository at this point in the history

⚡ Improve performance of reader
  • Loading branch information
matissecallewaert authored Apr 21, 2024
2 parents 975fc1d + aa4accf commit a98c573
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 135 deletions.
209 changes: 76 additions & 133 deletions feature-extraction-tool/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -543,8 +543,7 @@ where
let start = Instant::now();
let mut amount_of_packets = 0;

let flow_map_ipv4: Arc<DashMap<String, T>> = Arc::new(DashMap::new());
let flow_map_ipv6: Arc<DashMap<String, T>> = Arc::new(DashMap::new());
let flow_map: Arc<DashMap<String, T>> = Arc::new(DashMap::new());

let mut cap = match pcap::Capture::from_file(path) {
Ok(c) => c,
Expand All @@ -561,14 +560,14 @@ where
EtherTypes::Ipv4 => {
if let Some(ipv4_packet) = Ipv4Packet::new(ethernet.payload()) {
if let Some(features_ipv4) = extract_ipv4_features(&ipv4_packet) {
redirect_packet_ipv4(&features_ipv4, &flow_map_ipv4);
redirect_packet_ipv4(&features_ipv4, &flow_map);
}
}
}
EtherTypes::Ipv6 => {
if let Some(ipv6_packet) = Ipv6Packet::new(ethernet.payload()) {
if let Some(features_ipv6) = extract_ipv6_features(&ipv6_packet) {
redirect_packet_ipv6(&features_ipv6, &flow_map_ipv6);
redirect_packet_ipv6(&features_ipv6, &flow_map);
}
}
}
Expand All @@ -581,16 +580,7 @@ where
}
}

for entry in flow_map_ipv4.iter() {
let flow = entry.value();
if *NO_CONTAMINANT_FEATURES.lock().unwrap().deref() {
export(&flow.dump_without_contamination());
} else {
export(&flow.dump());
}
}

for entry in flow_map_ipv6.iter() {
for entry in flow_map.iter() {
let flow = entry.value();
if *NO_CONTAMINANT_FEATURES.lock().unwrap().deref() {
export(&flow.dump_without_contamination());
Expand All @@ -615,8 +605,7 @@ where
let mut amount_of_packets = 0;
let mut size: usize = 0;

let flow_map_ipv4: Arc<DashMap<String, T>> = Arc::new(DashMap::new());
let flow_map_ipv6: Arc<DashMap<String, T>> = Arc::new(DashMap::new());
let flow_map: Arc<DashMap<String, T>> = Arc::new(DashMap::new());

let mut cap = match pcap::Capture::from_file(path) {
Ok(c) => c,
Expand All @@ -640,14 +629,14 @@ where
SLL_IPV4 => {
if let Some(ipv4_packet) = Ipv4Packet::new(&packet.data[16..]) {
if let Some(features_ipv4) = extract_ipv4_features(&ipv4_packet) {
redirect_packet_ipv4(&features_ipv4, &flow_map_ipv4);
redirect_packet_ipv4(&features_ipv4, &flow_map);
}
}
}
SLL_IPV6 => {
if let Some(ipv6_packet) = Ipv6Packet::new(&packet.data[16..]) {
if let Some(features_ipv6) = extract_ipv6_features(&ipv6_packet) {
redirect_packet_ipv6(&features_ipv6, &flow_map_ipv6);
redirect_packet_ipv6(&features_ipv6, &flow_map);
}
}
}
Expand All @@ -660,16 +649,7 @@ where
}
}

for entry in flow_map_ipv4.iter() {
let flow = entry.value();
if *NO_CONTAMINANT_FEATURES.lock().unwrap().deref() {
export(&flow.dump_without_contamination());
} else {
export(&flow.dump());
}
}

for entry in flow_map_ipv6.iter() {
for entry in flow_map.iter() {
let flow = entry.value();
if *NO_CONTAMINANT_FEATURES.lock().unwrap().deref() {
export(&flow.dump_without_contamination());
Expand Down Expand Up @@ -750,17 +730,17 @@ where
};
let flow_id = if fwd {
create_flow_id(
source,
&source,
data.port_source,
destination,
&destination,
data.port_destination,
data.protocol,
)
} else {
create_flow_id(
destination,
&destination,
data.port_destination,
source,
&source,
data.port_source,
data.protocol,
)
Expand Down Expand Up @@ -830,17 +810,17 @@ where

let flow_id = if fwd {
create_flow_id(
source,
&source,
data.port_source,
destination,
&destination,
data.port_destination,
data.protocol,
)
} else {
create_flow_id(
destination,
&destination,
data.port_destination,
source,
&source,
data.port_source,
data.protocol,
)
Expand Down Expand Up @@ -889,16 +869,16 @@ where
T: Flow,
{
let fwd_flow_id = create_flow_id(
std::net::IpAddr::V4(Ipv4Addr::from(features_ipv4.ipv4_source)),
&std::net::IpAddr::V4(Ipv4Addr::from(features_ipv4.ipv4_source)),
features_ipv4.port_source,
std::net::IpAddr::V4(Ipv4Addr::from(features_ipv4.ipv4_destination)),
&std::net::IpAddr::V4(Ipv4Addr::from(features_ipv4.ipv4_destination)),
features_ipv4.port_destination,
features_ipv4.protocol,
);
let bwd_flow_id = create_flow_id(
std::net::IpAddr::V4(Ipv4Addr::from(features_ipv4.ipv4_destination)),
&std::net::IpAddr::V4(Ipv4Addr::from(features_ipv4.ipv4_destination)),
features_ipv4.port_destination,
std::net::IpAddr::V4(Ipv4Addr::from(features_ipv4.ipv4_source)),
&std::net::IpAddr::V4(Ipv4Addr::from(features_ipv4.ipv4_source)),
features_ipv4.port_source,
features_ipv4.protocol,
);
Expand All @@ -923,16 +903,16 @@ where
T: Flow,
{
let fwd_flow_id = create_flow_id(
std::net::IpAddr::V6(Ipv6Addr::from(features_ipv6.ipv6_source)),
&std::net::IpAddr::V6(Ipv6Addr::from(features_ipv6.ipv6_source)),
features_ipv6.port_source,
std::net::IpAddr::V6(Ipv6Addr::from(features_ipv6.ipv6_destination)),
&std::net::IpAddr::V6(Ipv6Addr::from(features_ipv6.ipv6_destination)),
features_ipv6.port_destination,
features_ipv6.protocol,
);
let bwd_flow_id = create_flow_id(
std::net::IpAddr::V6(Ipv6Addr::from(features_ipv6.ipv6_destination)),
&std::net::IpAddr::V6(Ipv6Addr::from(features_ipv6.ipv6_destination)),
features_ipv6.port_destination,
std::net::IpAddr::V6(Ipv6Addr::from(features_ipv6.ipv6_source)),
&std::net::IpAddr::V6(Ipv6Addr::from(features_ipv6.ipv6_source)),
features_ipv6.port_source,
features_ipv6.protocol,
);
Expand All @@ -956,63 +936,45 @@ where
///
/// * `Option<BasicFeaturesIpv4>` - Basic features of the packet.
fn extract_ipv4_features(ipv4_packet: &Ipv4Packet) -> Option<BasicFeaturesIpv4> {
let source_ip = ipv4_packet.get_source();
let destination_ip = ipv4_packet.get_destination();
let protocol = ipv4_packet.get_next_level_protocol();

let source_port: u16;
let destination_port: u16;

let mut combined_flags: u8 = 0;

let data_length: u16;
let header_length: u8;
let length: u16;

let mut window_size: u16 = 0;

if protocol.0 == IpNextHeaderProtocols::Tcp.0 {
if ipv4_packet.get_next_level_protocol().0 == IpNextHeaderProtocols::Tcp.0 {
if let Some(tcp_packet) = TcpPacket::new(ipv4_packet.payload()) {
source_port = tcp_packet.get_source();
destination_port = tcp_packet.get_destination();

data_length = tcp_packet.payload().len() as u16;
header_length = (tcp_packet.get_data_offset() * 4) as u8;
length = ipv4_packet.get_total_length();

window_size = tcp_packet.get_window();

combined_flags = tcp_packet.get_flags();
return Some(BasicFeaturesIpv4::new(
ipv4_packet.get_destination().into(),
ipv4_packet.get_source().into(),
tcp_packet.get_destination(),
tcp_packet.get_source(),
tcp_packet.payload().len() as u16,
ipv4_packet.get_total_length(),
tcp_packet.get_window(),
tcp_packet.get_flags(),
ipv4_packet.get_next_level_protocol().0,
(tcp_packet.get_data_offset() * 4) as u8,
));
} else {
return None;
}
} else if protocol.0 == IpNextHeaderProtocols::Udp.0 {
} else if ipv4_packet.get_next_level_protocol().0 == IpNextHeaderProtocols::Udp.0 {
if let Some(udp_packet) = pnet::packet::udp::UdpPacket::new(ipv4_packet.payload()) {
source_port = udp_packet.get_source();
destination_port = udp_packet.get_destination();

data_length = udp_packet.payload().len() as u16;
header_length = 8;
length = udp_packet.get_length();
return Some(BasicFeaturesIpv4::new(
ipv4_packet.get_destination().into(),
ipv4_packet.get_source().into(),
udp_packet.get_destination(),
udp_packet.get_source(),
udp_packet.payload().len() as u16,
udp_packet.get_length(),
0,
0,
ipv4_packet.get_next_level_protocol().0,
8,
));
} else {
return None;
}
} else {
return None;
}

Some(BasicFeaturesIpv4::new(
destination_ip.into(),
source_ip.into(),
destination_port,
source_port,
data_length,
length,
window_size,
combined_flags,
protocol.0,
header_length,
))
}

/// Extracts the basic features of an ipv6 packet pnet struct.
Expand All @@ -1025,63 +987,44 @@ fn extract_ipv4_features(ipv4_packet: &Ipv4Packet) -> Option<BasicFeaturesIpv4>
///
/// * `Option<BasicFeaturesIpv6>` - Basic features of the packet.
fn extract_ipv6_features(ipv6_packet: &Ipv6Packet) -> Option<BasicFeaturesIpv6> {
let source_ip = ipv6_packet.get_source();
let destination_ip = ipv6_packet.get_destination();
let protocol = ipv6_packet.get_next_header();

let source_port: u16;
let destination_port: u16;

let mut combined_flags: u8 = 0;

let data_length: u16;
let header_length: u8;
let length: u16;

let mut window_size: u16 = 0;

if protocol == IpNextHeaderProtocols::Tcp {
if ipv6_packet.get_next_header() == IpNextHeaderProtocols::Tcp {
if let Some(tcp_packet) = TcpPacket::new(ipv6_packet.payload()) {
source_port = tcp_packet.get_source();
destination_port = tcp_packet.get_destination();

data_length = tcp_packet.payload().len() as u16;
header_length = (tcp_packet.get_data_offset() * 4) as u8;
length = ipv6_packet.packet().bytes().count() as u16;

window_size = tcp_packet.get_window();

combined_flags = tcp_packet.get_flags();
return Some(BasicFeaturesIpv6::new(
ipv6_packet.get_destination().into(),
ipv6_packet.get_source().into(),
tcp_packet.get_destination(),
tcp_packet.get_source(),
tcp_packet.payload().len() as u16,
ipv6_packet.packet().bytes().count() as u16,
tcp_packet.get_window(),
tcp_packet.get_flags(),
ipv6_packet.get_next_header().0,
(tcp_packet.get_data_offset() * 4) as u8,
));
} else {
return None;
}
} else if protocol == IpNextHeaderProtocols::Udp {
} else if ipv6_packet.get_next_header() == IpNextHeaderProtocols::Udp {
if let Some(udp_packet) = pnet::packet::udp::UdpPacket::new(ipv6_packet.payload()) {
source_port = udp_packet.get_source();
destination_port = udp_packet.get_destination();

data_length = udp_packet.payload().len() as u16;
header_length = 8;
length = udp_packet.get_length();
return Some(BasicFeaturesIpv6::new(
ipv6_packet.get_destination().into(),
ipv6_packet.get_source().into(),
udp_packet.get_destination(),
udp_packet.get_source(),
udp_packet.payload().len() as u16,
ipv6_packet.packet().bytes().count() as u16,
0,
0,
ipv6_packet.get_next_header().0,
8,
));
} else {
return None;
}
} else {
return None;
}

Some(BasicFeaturesIpv6::new(
destination_ip.into(),
source_ip.into(),
destination_port,
source_port,
data_length,
length,
window_size,
combined_flags,
protocol.0,
header_length,
))
}

#[cfg(test)]
Expand Down
4 changes: 2 additions & 2 deletions feature-extraction-tool/src/utils/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ use chrono::{DateTime, Utc};
///
/// A string representing the unique identifier of the network flow.
pub fn create_flow_id(
ip_source: IpAddr,
ip_source: &IpAddr,
port_source: u16,
ip_destination: IpAddr,
ip_destination: &IpAddr,
port_destination: u16,
protocol: u8,
) -> String {
Expand Down

0 comments on commit a98c573

Please sign in to comment.