Skip to content

Commit

Permalink
Merge pull request #806 from stlankes/async
Browse files Browse the repository at this point in the history
add support of VIRTIO_NET_F_GUEST_CSUM
  • Loading branch information
mkroening authored Jul 26, 2023
2 parents 9cf5554 + 15d924a commit 4f7ce89
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 37 deletions.
4 changes: 4 additions & 0 deletions src/drivers/net/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ pub trait NetworkInterface {
fn set_polling_mode(&mut self, value: bool);
/// Handle interrupt and check if a packet is available
fn handle_interrupt(&mut self) -> bool;
/// Returns true, if the device has to create checksums
fn with_checksums(&self) -> bool {
true
}
}

#[inline]
Expand Down
119 changes: 92 additions & 27 deletions src/drivers/net/virtio_net.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use core::str::FromStr;
use pci_types::InterruptLine;
use zerocopy::AsBytes;

use self::constants::{FeatureSet, Features, NetHdrGSO, Status, MAX_NUM_VQ};
use self::constants::{FeatureSet, Features, NetHdrFlag, NetHdrGSO, Status, MAX_NUM_VQ};
use self::error::VirtioNetError;
use crate::arch::kernel::core_local::increment_irq_counter;
use crate::config::VIRTIO_MAX_QUEUE_SIZE;
Expand Down Expand Up @@ -48,8 +48,8 @@ pub(crate) struct NetDevCfg {
#[derive(AsBytes, Debug)]
#[repr(C)]
pub struct VirtioNetHdr {
flags: u8,
gso_type: u8,
flags: NetHdrFlag,
gso_type: NetHdrGSO,
/// Ethernet + IP + tcp/udp hdrs
hdr_len: u16,
/// Bytes to append to hdr_len per frame
Expand All @@ -62,23 +62,11 @@ pub struct VirtioNetHdr {
num_buffers: u16,
}

impl VirtioNetHdr {
pub fn get_tx_hdr() -> VirtioNetHdr {
VirtioNetHdr {
flags: 0,
gso_type: NetHdrGSO::VIRTIO_NET_HDR_GSO_NONE.into(),
hdr_len: 0,
gso_size: 0,
csum_start: 0,
csum_offset: 0,
num_buffers: 0,
}
}

pub fn get_rx_hdr() -> VirtioNetHdr {
VirtioNetHdr {
flags: 0,
gso_type: 0,
impl Default for VirtioNetHdr {
fn default() -> Self {
Self {
flags: NetHdrFlag::VIRTIO_NET_HDR_F_NONE,
gso_type: NetHdrGSO::VIRTIO_NET_HDR_GSO_NONE,
hdr_len: 0,
gso_size: 0,
csum_start: 0,
Expand Down Expand Up @@ -419,7 +407,7 @@ impl TxQueues {
self.ready_queue.push(
vq.prep_buffer(Rc::clone(vq), Some(spec.clone()), None)
.unwrap()
.write_seq(Some(&VirtioNetHdr::get_tx_hdr()), None::<&VirtioNetHdr>)
.write_seq(Some(&VirtioNetHdr::default()), None::<&VirtioNetHdr>)
.unwrap(),
)
}
Expand Down Expand Up @@ -568,7 +556,53 @@ impl NetworkInterface for VirtioNetDriver {
fn send_tx_buffer(&mut self, tkn_handle: usize, _len: usize) -> Result<(), ()> {
// This does not result in a new assignment, or in a drop of the BufferToken, which
// would be dangerous, as the memory is freed then.
let tkn = *unsafe { Box::from_raw(tkn_handle as *mut BufferToken) };
let mut tkn = *unsafe { Box::from_raw(tkn_handle as *mut BufferToken) };

// If a checksum isn't necessary, we have inform the host within the header
// see Virtio specification 5.1.6.2
if !self.with_checksums() {
unsafe {
let (send_ptrs, _) = tkn.raw_ptrs();
let (addr, _) = send_ptrs.unwrap()[0];
let header = addr as *mut VirtioNetHdr;
let type_ = u16::from_be(
*(addr.offset(
(12 + core::mem::size_of::<VirtioNetHdr>())
.try_into()
.unwrap(),
) as *const u16),
);

*header = Default::default();
match type_ {
0x0800 /* IPv4 */ => {
let protocol = *(addr.offset((14+9+core::mem::size_of::<VirtioNetHdr>()).try_into().unwrap()) as *const u8);
if protocol == 6 /* TCP */ {
(*header).flags = NetHdrFlag::VIRTIO_NET_HDR_F_NEEDS_CSUM;
(*header).csum_start = 14+20;
(*header).csum_offset = 16;
} else if protocol == 17 /* UDP */ {
(*header).flags = NetHdrFlag::VIRTIO_NET_HDR_F_NEEDS_CSUM;
(*header).csum_start = 14+20;
(*header).csum_offset = 6;
}
},
0x86DD /* IPv6 */ => {
let protocol = *(addr.offset((14+9+core::mem::size_of::<VirtioNetHdr>()).try_into().unwrap()) as *const u8);
if protocol == 6 /* TCP */ {
(*header).flags = NetHdrFlag::VIRTIO_NET_HDR_F_NEEDS_CSUM;
(*header).csum_start = 14+40;
(*header).csum_offset = 16;
} else if protocol == 17 /* UDP */ {
(*header).flags = NetHdrFlag::VIRTIO_NET_HDR_F_NEEDS_CSUM;
(*header).csum_start = 14+40;
(*header).csum_offset = 6;
}
},
_ => {},
}
}
}

tkn.provide()
.dispatch_await(Rc::clone(&self.send_vqs.poll_queue), false);
Expand Down Expand Up @@ -598,6 +632,15 @@ impl NetworkInterface for VirtioNetDriver {
// If the given length is zero, we currently fail.
if recv_data.len() == 2 {
let recv_payload = recv_data.pop().unwrap();
/*let header = recv_data.pop().unwrap();
let header = unsafe {
const HEADER_SIZE: usize = mem::size_of::<VirtioNetHdr>();
core::mem::transmute::<[u8; HEADER_SIZE], VirtioNetHdr>(
header[..HEADER_SIZE].try_into().unwrap(),
)
};
trace!("Receive data with header {:?}", header);*/

// Create static reference for the user-space
// As long as we keep the Transfer in a raw reference this reference is static,
// so this is fine.
Expand All @@ -613,6 +656,14 @@ impl NetworkInterface for VirtioNetDriver {
Ok(vec_data)
} else if recv_data.len() == 1 {
let packet = recv_data.pop().unwrap();
/*let header = unsafe {
const HEADER_SIZE: usize = mem::size_of::<VirtioNetHdr>();
core::mem::transmute::<[u8; HEADER_SIZE], VirtioNetHdr>(
packet[..HEADER_SIZE].try_into().unwrap(),
)
};
trace!("Receive data with header {:?}", header);*/

let payload_ptr =
(&packet[mem::size_of::<VirtioNetHdr>()] as *const u8) as *mut u8;

Expand All @@ -635,7 +686,7 @@ impl NetworkInterface for VirtioNetDriver {
transfer
.reuse()
.unwrap()
.write_seq(None::<&VirtioNetHdr>, Some(&VirtioNetHdr::get_rx_hdr()))
.write_seq(None::<&VirtioNetHdr>, Some(&VirtioNetHdr::default()))
.unwrap()
.provide()
.dispatch_await(Rc::clone(&self.recv_vqs.poll_queue), false);
Expand Down Expand Up @@ -677,6 +728,15 @@ impl NetworkInterface for VirtioNetDriver {

result
}

/// Returns `true` if the device supports the virtio feature
/// `VIRTIO_NET_F_GUEST_CSUM` and trust the incoming packages.
fn with_checksums(&self) -> bool {
!self
.dev_cfg
.features
.is_feature(Features::VIRTIO_NET_F_GUEST_CSUM)
}
}

// Backend-independent interface for Virtio network driver
Expand Down Expand Up @@ -796,12 +856,13 @@ impl VirtioNetDriver {
feats.push(Features::VIRTIO_NET_F_MTU);
// Packed Vq can be used
feats.push(Features::VIRTIO_F_RING_PACKED);
// Avoid the creation of checksums
feats.push(Features::VIRTIO_NET_F_GUEST_CSUM);

// Currently the driver does NOT support the features below.
// In order to provide functionality for these, the driver
// needs to take care of calculating checksum in
// RxQueues.post_processing()
// feats.push(Features::VIRTIO_NET_F_GUEST_CSUM);
// feats.push(Features::VIRTIO_NET_F_GUEST_TSO4);
// feats.push(Features::VIRTIO_NET_F_GUEST_TSO6);

Expand Down Expand Up @@ -1068,6 +1129,8 @@ pub mod constants {
use alloc::vec::Vec;
use core::ops::{BitAnd, BitAndAssign, BitOr, BitOrAssign};

use zerocopy::AsBytes;

pub use super::error::VirtioNetError;

// Configuration constants
Expand All @@ -1077,10 +1140,11 @@ pub mod constants {
///
/// See Virtio specification v1.1. - 5.1.6
#[allow(dead_code, non_camel_case_types)]
#[derive(Copy, Clone, Debug)]
#[derive(AsBytes, Copy, Clone, Debug)]
#[repr(u8)]
///
pub enum NetHdrFlag {
/// No further information
VIRTIO_NET_HDR_F_NONE = 0,
/// use csum_start, csum_offset
VIRTIO_NET_HDR_F_NEEDS_CSUM = 1,
/// csum is valid
Expand All @@ -1092,6 +1156,7 @@ pub mod constants {
impl From<NetHdrFlag> for u8 {
fn from(val: NetHdrFlag) -> Self {
match val {
NetHdrFlag::VIRTIO_NET_HDR_F_NONE => 0,
NetHdrFlag::VIRTIO_NET_HDR_F_NEEDS_CSUM => 1,
NetHdrFlag::VIRTIO_NET_HDR_F_DATA_VALID => 2,
NetHdrFlag::VIRTIO_NET_HDR_F_RSC_INFO => 4,
Expand Down Expand Up @@ -1147,7 +1212,7 @@ pub mod constants {
///
/// See Virtio specification v1.1. - 5.1.6
#[allow(dead_code, non_camel_case_types)]
#[derive(Copy, Clone, Debug)]
#[derive(AsBytes, Copy, Clone, Debug)]
#[repr(u8)]
pub enum NetHdrGSO {
/// not a GSO frame
Expand Down
36 changes: 26 additions & 10 deletions src/executor/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use core::slice;
use core::str::FromStr;

use smoltcp::iface::{Config, Interface, SocketSet};
use smoltcp::phy::{self, Device, DeviceCapabilities, Medium};
use smoltcp::phy::{self, Checksum, Device, DeviceCapabilities, Medium};
#[cfg(feature = "dhcpv4")]
use smoltcp::socket::dhcpv4;
use smoltcp::time::Instant;
Expand All @@ -24,26 +24,34 @@ use crate::drivers::pci as hardware;
#[derive(Debug, Copy, Clone)]
#[repr(C)]
pub(crate) struct HermitNet {
pub mtu: u16,
mtu: u16,
with_checksums: bool,
}

impl HermitNet {
pub(crate) const fn new(mtu: u16) -> Self {
Self { mtu }
pub(crate) const fn new(mtu: u16, with_checksums: bool) -> Self {
Self {
mtu,
with_checksums,
}
}
}

impl<'a> NetworkInterface<'a> {
#[cfg(feature = "dhcpv4")]
pub(crate) fn create() -> NetworkState<'a> {
let (mtu, mac) = if let Some(driver) = hardware::get_network_driver() {
let (mtu, mac, with_checksums) = if let Some(driver) = hardware::get_network_driver() {
let guard = driver.lock();
(guard.get_mtu(), guard.get_mac_address())
(
guard.get_mtu(),
guard.get_mac_address(),
guard.with_checksums(),
)
} else {
return NetworkState::InitializationFailed;
};

let mut device = HermitNet::new(mtu);
let mut device = HermitNet::new(mtu, with_checksums);

let ethernet_addr = EthernetAddress([mac[0], mac[1], mac[2], mac[3], mac[4], mac[5]]);
let hardware_addr = HardwareAddress::Ethernet(ethernet_addr);
Expand Down Expand Up @@ -74,14 +82,18 @@ impl<'a> NetworkInterface<'a> {

#[cfg(not(feature = "dhcpv4"))]
pub(crate) fn create() -> NetworkState<'a> {
let (mtu, mac) = if let Some(driver) = hardware::get_network_driver() {
let (mtu, mac, with_checksums) = if let Some(driver) = hardware::get_network_driver() {
let guard = driver.lock();
(guard.get_mtu(), guard.get_mac_address())
(
guard.get_mtu(),
guard.get_mac_address(),
guard.with_checksums(),
)
} else {
return NetworkState::InitializationFailed;
};

let mut device = HermitNet::new(mtu);
let mut device = HermitNet::new(mtu, with_checksums);

let myip = Ipv4Address::from_str(hermit_var_or!("HERMIT_IP", "10.0.5.3")).unwrap();
let mygw = Ipv4Address::from_str(hermit_var_or!("HERMIT_GATEWAY", "10.0.5.1")).unwrap();
Expand Down Expand Up @@ -156,6 +168,10 @@ impl Device for HermitNet {
fn capabilities(&self) -> DeviceCapabilities {
let mut cap = DeviceCapabilities::default();
cap.max_transmission_unit = self.mtu.into();
if !self.with_checksums {
cap.checksum.tcp = Checksum::None;
cap.checksum.udp = Checksum::None;
}
cap
}

Expand Down

0 comments on commit 4f7ce89

Please sign in to comment.