Skip to content

Commit

Permalink
Refactor: Change protocol + mirrord-layer to split messages into modu… (
Browse files Browse the repository at this point in the history
#150)

* Refactor: Change protocol + mirrord-layer to split messages into modules, so main module only handles general messages, passing down to the appropriate module for handling.
  • Loading branch information
aviramha authored Jun 16, 2022
1 parent 255fbe5 commit f677ce2
Show file tree
Hide file tree
Showing 14 changed files with 257 additions and 274 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ Check [Keep a Changelog](http://keepachangelog.com/) for recommendations on how
- CI: Remove building agent before building & running tests (duplicate)
- CI: Add Docker cache to Docker build-push action to reduce build duration.
- CD release: Fix universal binary for macOS
- Refactor: Change protocol + mirrord-layer to split messages into modules, so main module only handles general messages, passing down to the appropriate module for handling.

## 2.2.1
### Changed
Expand Down
36 changes: 23 additions & 13 deletions mirrord-agent/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use std::{
use error::AgentError;
use futures::SinkExt;
use mirrord_protocol::{
tcp::{DaemonTcp, LayerTcp},
ClientMessage, ConnectionID, DaemonCodec, DaemonMessage, FileRequest, FileResponse, Port,
};
use tokio::{
Expand Down Expand Up @@ -113,14 +114,23 @@ async fn handle_peer_messages(
peer_message: PeerMessage,
) -> Result<(), AgentError> {
match peer_message.client_message {
ClientMessage::PortSubscribe(ports) => {
ClientMessage::Tcp(LayerTcp::PortUnsubscribe(port)) => {
debug!(
"ClientMessage::PortUnsubscribe -> peer id {:#?}, port {port:#?}",
peer_message.peer_id
);
state
.port_subscriptions
.unsubscribe(peer_message.peer_id, port);
}
ClientMessage::Tcp(LayerTcp::PortSubscribe(port)) => {
debug!(
"ClientMessage::PortSubscribe -> peer id {:#?} asked to subscribe to {:#?}",
peer_message.peer_id, ports
peer_message.peer_id, port
);
state
.port_subscriptions
.subscribe_many(peer_message.peer_id, ports);
.subscribe(peer_message.peer_id, port);

let ports = state.port_subscriptions.get_subscribed_topics();
sniffer_command_tx
Expand All @@ -139,7 +149,7 @@ async fn handle_peer_messages(
.send(SnifferCommand::SetPorts(ports))
.await?;
}
ClientMessage::ConnectionUnsubscribe(connection_id) => {
ClientMessage::Tcp(LayerTcp::ConnectionUnsubscribe(connection_id)) => {
debug!("ClientMessage::ConnectionUnsubscribe -> peer id {:#?} unsubscribe connection id {:#?}", &peer_message.peer_id, connection_id);
state
.connections_subscriptions
Expand Down Expand Up @@ -288,15 +298,15 @@ async fn start_agent() -> Result<(), AgentError> {
match sniffer_output {
Some(sniffer_output) => {
match sniffer_output {
SnifferOutput::NewTCPConnection(new_connection) => {
debug!("SnifferOutput::NewTCPConnection -> connection {:#?}", new_connection);
SnifferOutput::NewTcpConnection(new_connection) => {
debug!("SnifferOutput::NewTcpConnection -> connection {:#?}", new_connection);
let peer_ids = state.port_subscriptions.get_topic_subscribers(new_connection.destination_port);

for peer_id in peer_ids {
state.connections_subscriptions.subscribe(peer_id, new_connection.connection_id);

if let Some(peer) = state.peers.get(&peer_id) {
match peer.channel.send(DaemonMessage::NewTCPConnection(new_connection.clone())).await {
match peer.channel.send(DaemonMessage::Tcp(DaemonTcp::NewConnection(new_connection.clone()))).await {
Ok(_) => {},
Err(err) => {
error!("error sending message {:?}", err);
Expand All @@ -305,13 +315,13 @@ async fn start_agent() -> Result<(), AgentError> {
}
}
},
SnifferOutput::TCPClose(close) => {
debug!("SnifferOutput::TCPClose -> close {:#?}", close);
SnifferOutput::TcpClose(close) => {
debug!("SnifferOutput::TcpClose -> close {:#?}", close);
let peer_ids = state.connections_subscriptions.get_topic_subscribers(close.connection_id);

for peer_id in peer_ids {
if let Some(peer) = state.peers.get(&peer_id) {
match peer.channel.send(DaemonMessage::TCPClose(close.clone())).await {
match peer.channel.send(DaemonMessage::Tcp(DaemonTcp::Close(close.clone()))).await {
Ok(_) => {},
Err(err) => {
error!("error sending message {:?}", err);
Expand All @@ -321,13 +331,13 @@ async fn start_agent() -> Result<(), AgentError> {
}
state.connections_subscriptions.remove_topic(close.connection_id);
},
SnifferOutput::TCPData(data) => {
debug!("SnifferOutput::TCPData -> data");
SnifferOutput::TcpData(data) => {
debug!("SnifferOutput::TcpData -> data");
let peer_ids = state.connections_subscriptions.get_topic_subscribers(data.connection_id);

for peer_id in peer_ids {
if let Some(peer) = state.peers.get(&peer_id) {
match peer.channel.send(DaemonMessage::TCPData(data.clone())).await {
match peer.channel.send(DaemonMessage::Tcp(DaemonTcp::Data(data.clone()))).await {
Ok(_) => {},
Err(err) => {
error!("error sending message {:?}", err);
Expand Down
32 changes: 16 additions & 16 deletions mirrord-agent/src/sniffer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use std::{
};

use futures::StreamExt;
use mirrord_protocol::{NewTCPConnection, TCPClose, TCPData};
use mirrord_protocol::tcp::{NewTcpConnection, TcpClose, TcpData};
use pcap::{stream::PacketCodec, Active, Capture, Device, Linktype};
use pnet::packet::{
ethernet::{EtherTypes, EthernetPacket},
Expand Down Expand Up @@ -42,22 +42,22 @@ pub enum SnifferCommand {

#[derive(Debug)]
pub enum SnifferOutput {
NewTCPConnection(NewTCPConnection),
TCPClose(TCPClose),
TCPData(TCPData),
NewTcpConnection(NewTcpConnection),
TcpClose(TcpClose),
TcpData(TcpData),
}

#[derive(Debug, Eq, Copy, Clone)]
pub struct TCPSessionIdentifier {
pub struct TcpSessionIdentifier {
source_addr: Ipv4Addr,
dest_addr: Ipv4Addr,
source_port: u16,
dest_port: u16,
}

impl PartialEq for TCPSessionIdentifier {
impl PartialEq for TcpSessionIdentifier {
/// It's the same session if 4 tuple is same/opposite.
fn eq(&self, other: &TCPSessionIdentifier) -> bool {
fn eq(&self, other: &TcpSessionIdentifier) -> bool {
self.source_addr == other.source_addr
&& self.dest_addr == other.dest_addr
&& self.source_port == other.source_port
Expand All @@ -69,7 +69,7 @@ impl PartialEq for TCPSessionIdentifier {
}
}

impl Hash for TCPSessionIdentifier {
impl Hash for TcpSessionIdentifier {
fn hash<H: Hasher>(&self, state: &mut H) {
if self.source_addr > self.dest_addr {
self.source_addr.hash(state);
Expand All @@ -89,7 +89,7 @@ impl Hash for TCPSessionIdentifier {
}

type Session = ConnectionID;
type SessionMap = HashMap<TCPSessionIdentifier, Session>;
type SessionMap = HashMap<TcpSessionIdentifier, Session>;

fn is_new_connection(flags: u16) -> bool {
flags == TcpFlags::SYN
Expand Down Expand Up @@ -156,7 +156,7 @@ impl ConnectionManager {
let tcp_flags = tcp_packet.get_flags();
let source_port = tcp_packet.get_source();

let identifier = TCPSessionIdentifier {
let identifier = TcpSessionIdentifier {
source_addr: ip_packet.get_source(),
dest_addr: ip_packet.get_destination(),
source_port,
Expand All @@ -179,7 +179,7 @@ impl ConnectionManager {
error!("connection index exhausted, dropping new connection");
None
})?;
messages.push(SnifferOutput::NewTCPConnection(NewTCPConnection {
messages.push(SnifferOutput::NewTcpConnection(NewTcpConnection {
destination_port: dest_port,
source_port,
connection_id: id,
Expand All @@ -193,7 +193,7 @@ impl ConnectionManager {
let data = tcp_packet.payload();

if !data.is_empty() {
messages.push(SnifferOutput::TCPData(TCPData {
messages.push(SnifferOutput::TcpData(TcpData {
bytes: data.to_vec(),
connection_id: session,
}));
Expand All @@ -203,7 +203,7 @@ impl ConnectionManager {
if is_closed_connection(tcp_flags) {
self.index_allocator.free_index(session);

messages.push(SnifferOutput::TCPClose(TCPClose {
messages.push(SnifferOutput::TcpClose(TcpClose {
connection_id: session,
}));
} else {
Expand All @@ -214,9 +214,9 @@ impl ConnectionManager {
}
}

pub struct TCPManagerCodec {}
pub struct TcpManagerCodec {}

impl PacketCodec for TCPManagerCodec {
impl PacketCodec for TcpManagerCodec {
type Type = Vec<u8>;

fn decode(&mut self, packet: pcap::Packet) -> Result<Self::Type, pcap::Error> {
Expand Down Expand Up @@ -282,7 +282,7 @@ pub async fn packet_worker(

debug!("preparing sniffer");
let sniffer = prepare_sniffer(interface)?;
let codec = TCPManagerCodec {};
let codec = TcpManagerCodec {};
let mut connection_manager = ConnectionManager::new();
let mut sniffer_stream = sniffer.stream(codec)?;

Expand Down
10 changes: 2 additions & 8 deletions mirrord-agent/src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,6 @@ where
.insert(client);
}

/// Subscribe many topics at once
pub fn subscribe_many(&mut self, client: C, topics: impl IntoIterator<Item = T>) {
for topic in topics {
self.subscribe(client, topic);
}
}

/// Remove a subscription of given client from the topic.
/// topic is removed if no subscribers left.
pub fn unsubscribe(&mut self, client: C, topic: T) {
Expand Down Expand Up @@ -136,7 +129,8 @@ mod subscription_tests {
#[test]
fn sanity() {
let mut subscriptions = Subscriptions::<Port, _>::new();
subscriptions.subscribe_many(3, vec![3, 2]);
subscriptions.subscribe(3, 2);
subscriptions.subscribe(3, 3);
subscriptions.subscribe(3, 1);
subscriptions.subscribe(1, 4);
subscriptions.subscribe(2, 1);
Expand Down
15 changes: 8 additions & 7 deletions mirrord-agent/tests/blackbox.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ mod tests {
use actix_codec::Framed;
use futures::SinkExt;
use mirrord_protocol::{
ClientCodec, ClientMessage, DaemonMessage, NewTCPConnection, TCPClose, TCPData,
tcp::{DaemonTcp, LayerTcp, NewTcpConnection, TcpClose, TcpData},
ClientCodec, ClientMessage, DaemonMessage,
};
use test_bin::get_test_bin;
use tokio::{
Expand Down Expand Up @@ -69,7 +70,7 @@ mod tests {
let mut codec = Framed::new(stream, ClientCodec::new());

codec
.send(ClientMessage::PortSubscribe(vec![1337, 1338]))
.send(ClientMessage::Tcp(LayerTcp::PortSubscribe(1337)))
.await
.expect("port subscribe failed");
// Let message be acknowledged and dummy socket to start listening
Expand Down Expand Up @@ -101,25 +102,25 @@ mod tests {
.expect("got invalid message");
assert_eq!(
new_conn_msg,
DaemonMessage::NewTCPConnection(NewTCPConnection {
DaemonMessage::Tcp(DaemonTcp::NewConnection(NewTcpConnection {
connection_id: 0,
address: IpAddr::V4("127.0.0.1".parse().unwrap()),
destination_port: 1337,
source_port: port
})
}))
);

assert_eq!(
data_msg,
DaemonMessage::TCPData(TCPData {
DaemonMessage::Tcp(DaemonTcp::Data(TcpData {
connection_id: 0,
bytes: test_data.to_vec()
})
}))
);

assert_eq!(
close_msg,
DaemonMessage::TCPClose(TCPClose { connection_id: 0 })
DaemonMessage::Tcp(DaemonTcp::Close(TcpClose { connection_id: 0 }))
);

drop(codec);
Expand Down
55 changes: 5 additions & 50 deletions mirrord-layer/src/common.rs
Original file line number Diff line number Diff line change
@@ -1,58 +1,13 @@
use std::{
borrow::Borrow,
hash::{Hash, Hasher},
io::SeekFrom,
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
os::unix::io::RawFd,
path::PathBuf,
};
use std::{io::SeekFrom, os::unix::io::RawFd, path::PathBuf};

use mirrord_protocol::{
CloseFileResponse, OpenFileResponse, OpenOptionsInternal, Port, ReadFileResponse,
SeekFileResponse, WriteFileResponse,
CloseFileResponse, OpenFileResponse, OpenOptionsInternal, ReadFileResponse, SeekFileResponse,
WriteFileResponse,
};
use tokio::sync::oneshot;

#[derive(Debug, Clone)]
pub struct Listen {
pub fake_port: Port,
pub real_port: Port,
pub ipv6: bool,
pub fd: RawFd,
}

impl PartialEq for Listen {
fn eq(&self, other: &Self) -> bool {
self.real_port == other.real_port
}
}

impl Eq for Listen {}
use crate::tcp::HookMessageTcp;

impl Hash for Listen {
fn hash<H: Hasher>(&self, state: &mut H) {
self.real_port.hash(state);
}
}

impl Borrow<Port> for Listen {
fn borrow(&self) -> &Port {
&self.real_port
}
}

impl From<&Listen> for SocketAddr {
fn from(listen: &Listen) -> Self {
let address = if listen.ipv6 {
SocketAddr::new(IpAddr::V6(Ipv6Addr::LOCALHOST), listen.fake_port)
} else {
SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), listen.fake_port)
};

debug_assert_eq!(address.port(), listen.fake_port);
address
}
}
// TODO: Some ideas around abstracting file operations:
// Alright, all these are pretty much the same thing, they could be
// abstract over a generic dependent-ish type like so:
Expand Down Expand Up @@ -119,7 +74,7 @@ pub struct CloseFileHook {
/// These messages are handled internally by -layer, and become `ClientMessage`s sent to -agent.
#[derive(Debug)]
pub enum HookMessage {
Listen(Listen),
Tcp(HookMessageTcp),
OpenFileHook(OpenFileHook),
OpenRelativeFileHook(OpenRelativeFileHook),
ReadFileHook(ReadFileHook),
Expand Down
11 changes: 6 additions & 5 deletions mirrord-layer/src/error.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use std::{env::VarError, os::unix::io::RawFd, str::ParseBoolError};

use mirrord_protocol::tcp::LayerTcp;
use thiserror::Error;
use tokio::sync::{mpsc::error::SendError, oneshot::error::RecvError};

use super::{tcp::TrafficHandlerInput, HookMessage};
use super::HookMessage;

#[derive(Error, Debug)]
pub enum LayerError {
Expand All @@ -16,12 +17,12 @@ pub enum LayerError {
#[error("mirrord-layer: Sender<HookMessage> failed with `{0}`!")]
SendErrorHookMessage(#[from] SendError<HookMessage>),

#[error("mirrord-layer: Sender<TrafficHandlerInput> failed with `{0}`!")]
SendErrorTrafficHandler(#[from] SendError<TrafficHandlerInput>),

#[error("mirrord-layer: Sender<TrafficHandlerInput> failed with `{0}`!")]
#[error("mirrord-layer: Sender<Vec<u8>> failed with `{0}`!")]
SendErrorConnection(#[from] SendError<Vec<u8>>),

#[error("mirrord-layer: Sender<LayerTcp> failed with `{0}`!")]
SendErrorLayerTcp(#[from] SendError<LayerTcp>),

#[error("mirrord-layer: Receiver failed with `{0}`!")]
RecvError(#[from] RecvError),

Expand Down
Loading

0 comments on commit f677ce2

Please sign in to comment.