diff --git a/src/drivers/net/mod.rs b/src/drivers/net/mod.rs index d5626b3f34..d8afc926ca 100644 --- a/src/drivers/net/mod.rs +++ b/src/drivers/net/mod.rs @@ -43,6 +43,10 @@ pub trait NetworkInterface { fn set_polling_mode(&mut self, value: bool); /// Handle interrupt and check if a packet is available fn handle_interrupt(&mut self) -> bool; + /// Returns true, if the device has to create checksums + fn with_checksums(&self) -> bool { + true + } } #[inline] diff --git a/src/drivers/net/virtio_net.rs b/src/drivers/net/virtio_net.rs index bf0db81cd1..b76cca224d 100644 --- a/src/drivers/net/virtio_net.rs +++ b/src/drivers/net/virtio_net.rs @@ -15,7 +15,7 @@ use core::str::FromStr; use pci_types::InterruptLine; use zerocopy::AsBytes; -use self::constants::{FeatureSet, Features, NetHdrGSO, Status, MAX_NUM_VQ}; +use self::constants::{FeatureSet, Features, NetHdrFlag, NetHdrGSO, Status, MAX_NUM_VQ}; use self::error::VirtioNetError; use crate::arch::kernel::core_local::increment_irq_counter; use crate::config::VIRTIO_MAX_QUEUE_SIZE; @@ -48,8 +48,8 @@ pub(crate) struct NetDevCfg { #[derive(AsBytes, Debug)] #[repr(C)] pub struct VirtioNetHdr { - flags: u8, - gso_type: u8, + flags: NetHdrFlag, + gso_type: NetHdrGSO, /// Ethernet + IP + tcp/udp hdrs hdr_len: u16, /// Bytes to append to hdr_len per frame @@ -62,23 +62,11 @@ pub struct VirtioNetHdr { num_buffers: u16, } -impl VirtioNetHdr { - pub fn get_tx_hdr() -> VirtioNetHdr { - VirtioNetHdr { - flags: 0, - gso_type: NetHdrGSO::VIRTIO_NET_HDR_GSO_NONE.into(), - hdr_len: 0, - gso_size: 0, - csum_start: 0, - csum_offset: 0, - num_buffers: 0, - } - } - - pub fn get_rx_hdr() -> VirtioNetHdr { - VirtioNetHdr { - flags: 0, - gso_type: 0, +impl Default for VirtioNetHdr { + fn default() -> Self { + Self { + flags: NetHdrFlag::VIRTIO_NET_HDR_F_NONE, + gso_type: NetHdrGSO::VIRTIO_NET_HDR_GSO_NONE, hdr_len: 0, gso_size: 0, csum_start: 0, @@ -419,7 +407,7 @@ impl TxQueues { self.ready_queue.push( vq.prep_buffer(Rc::clone(vq), Some(spec.clone()), None) .unwrap() - .write_seq(Some(&VirtioNetHdr::get_tx_hdr()), None::<&VirtioNetHdr>) + .write_seq(Some(&VirtioNetHdr::default()), None::<&VirtioNetHdr>) .unwrap(), ) } @@ -568,7 +556,53 @@ impl NetworkInterface for VirtioNetDriver { fn send_tx_buffer(&mut self, tkn_handle: usize, _len: usize) -> Result<(), ()> { // This does not result in a new assignment, or in a drop of the BufferToken, which // would be dangerous, as the memory is freed then. - let tkn = *unsafe { Box::from_raw(tkn_handle as *mut BufferToken) }; + let mut tkn = *unsafe { Box::from_raw(tkn_handle as *mut BufferToken) }; + + // If a checksum isn't necessary, we have inform the host within the header + // see Virtio specification 5.1.6.2 + if !self.with_checksums() { + unsafe { + let (send_ptrs, _) = tkn.raw_ptrs(); + let (addr, _) = send_ptrs.unwrap()[0]; + let header = addr as *mut VirtioNetHdr; + let type_ = u16::from_be( + *(addr.offset( + (12 + core::mem::size_of::()) + .try_into() + .unwrap(), + ) as *const u16), + ); + + *header = Default::default(); + match type_ { + 0x0800 /* IPv4 */ => { + let protocol = *(addr.offset((14+9+core::mem::size_of::()).try_into().unwrap()) as *const u8); + if protocol == 6 /* TCP */ { + (*header).flags = NetHdrFlag::VIRTIO_NET_HDR_F_NEEDS_CSUM; + (*header).csum_start = 14+20; + (*header).csum_offset = 16; + } else if protocol == 17 /* UDP */ { + (*header).flags = NetHdrFlag::VIRTIO_NET_HDR_F_NEEDS_CSUM; + (*header).csum_start = 14+20; + (*header).csum_offset = 6; + } + }, + 0x86DD /* IPv6 */ => { + let protocol = *(addr.offset((14+9+core::mem::size_of::()).try_into().unwrap()) as *const u8); + if protocol == 6 /* TCP */ { + (*header).flags = NetHdrFlag::VIRTIO_NET_HDR_F_NEEDS_CSUM; + (*header).csum_start = 14+40; + (*header).csum_offset = 16; + } else if protocol == 17 /* UDP */ { + (*header).flags = NetHdrFlag::VIRTIO_NET_HDR_F_NEEDS_CSUM; + (*header).csum_start = 14+40; + (*header).csum_offset = 6; + } + }, + _ => {}, + } + } + } tkn.provide() .dispatch_await(Rc::clone(&self.send_vqs.poll_queue), false); @@ -598,6 +632,15 @@ impl NetworkInterface for VirtioNetDriver { // If the given length is zero, we currently fail. if recv_data.len() == 2 { let recv_payload = recv_data.pop().unwrap(); + /*let header = recv_data.pop().unwrap(); + let header = unsafe { + const HEADER_SIZE: usize = mem::size_of::(); + core::mem::transmute::<[u8; HEADER_SIZE], VirtioNetHdr>( + header[..HEADER_SIZE].try_into().unwrap(), + ) + }; + trace!("Receive data with header {:?}", header);*/ + // Create static reference for the user-space // As long as we keep the Transfer in a raw reference this reference is static, // so this is fine. @@ -613,6 +656,14 @@ impl NetworkInterface for VirtioNetDriver { Ok(vec_data) } else if recv_data.len() == 1 { let packet = recv_data.pop().unwrap(); + /*let header = unsafe { + const HEADER_SIZE: usize = mem::size_of::(); + core::mem::transmute::<[u8; HEADER_SIZE], VirtioNetHdr>( + packet[..HEADER_SIZE].try_into().unwrap(), + ) + }; + trace!("Receive data with header {:?}", header);*/ + let payload_ptr = (&packet[mem::size_of::()] as *const u8) as *mut u8; @@ -635,7 +686,7 @@ impl NetworkInterface for VirtioNetDriver { transfer .reuse() .unwrap() - .write_seq(None::<&VirtioNetHdr>, Some(&VirtioNetHdr::get_rx_hdr())) + .write_seq(None::<&VirtioNetHdr>, Some(&VirtioNetHdr::default())) .unwrap() .provide() .dispatch_await(Rc::clone(&self.recv_vqs.poll_queue), false); @@ -677,6 +728,15 @@ impl NetworkInterface for VirtioNetDriver { result } + + /// Returns `true` if the device supports the virtio feature + /// `VIRTIO_NET_F_GUEST_CSUM` and trust the incoming packages. + fn with_checksums(&self) -> bool { + !self + .dev_cfg + .features + .is_feature(Features::VIRTIO_NET_F_GUEST_CSUM) + } } // Backend-independent interface for Virtio network driver @@ -796,12 +856,13 @@ impl VirtioNetDriver { feats.push(Features::VIRTIO_NET_F_MTU); // Packed Vq can be used feats.push(Features::VIRTIO_F_RING_PACKED); + // Avoid the creation of checksums + feats.push(Features::VIRTIO_NET_F_GUEST_CSUM); // Currently the driver does NOT support the features below. // In order to provide functionality for these, the driver // needs to take care of calculating checksum in // RxQueues.post_processing() - // feats.push(Features::VIRTIO_NET_F_GUEST_CSUM); // feats.push(Features::VIRTIO_NET_F_GUEST_TSO4); // feats.push(Features::VIRTIO_NET_F_GUEST_TSO6); @@ -1068,6 +1129,8 @@ pub mod constants { use alloc::vec::Vec; use core::ops::{BitAnd, BitAndAssign, BitOr, BitOrAssign}; + use zerocopy::AsBytes; + pub use super::error::VirtioNetError; // Configuration constants @@ -1077,10 +1140,11 @@ pub mod constants { /// /// See Virtio specification v1.1. - 5.1.6 #[allow(dead_code, non_camel_case_types)] - #[derive(Copy, Clone, Debug)] + #[derive(AsBytes, Copy, Clone, Debug)] #[repr(u8)] - /// pub enum NetHdrFlag { + /// No further information + VIRTIO_NET_HDR_F_NONE = 0, /// use csum_start, csum_offset VIRTIO_NET_HDR_F_NEEDS_CSUM = 1, /// csum is valid @@ -1092,6 +1156,7 @@ pub mod constants { impl From for u8 { fn from(val: NetHdrFlag) -> Self { match val { + NetHdrFlag::VIRTIO_NET_HDR_F_NONE => 0, NetHdrFlag::VIRTIO_NET_HDR_F_NEEDS_CSUM => 1, NetHdrFlag::VIRTIO_NET_HDR_F_DATA_VALID => 2, NetHdrFlag::VIRTIO_NET_HDR_F_RSC_INFO => 4, @@ -1147,7 +1212,7 @@ pub mod constants { /// /// See Virtio specification v1.1. - 5.1.6 #[allow(dead_code, non_camel_case_types)] - #[derive(Copy, Clone, Debug)] + #[derive(AsBytes, Copy, Clone, Debug)] #[repr(u8)] pub enum NetHdrGSO { /// not a GSO frame diff --git a/src/executor/device.rs b/src/executor/device.rs index 476f019a09..a9e5367611 100644 --- a/src/executor/device.rs +++ b/src/executor/device.rs @@ -5,7 +5,7 @@ use core::slice; use core::str::FromStr; use smoltcp::iface::{Config, Interface, SocketSet}; -use smoltcp::phy::{self, Device, DeviceCapabilities, Medium}; +use smoltcp::phy::{self, Checksum, Device, DeviceCapabilities, Medium}; #[cfg(feature = "dhcpv4")] use smoltcp::socket::dhcpv4; use smoltcp::time::Instant; @@ -24,26 +24,34 @@ use crate::drivers::pci as hardware; #[derive(Debug, Copy, Clone)] #[repr(C)] pub(crate) struct HermitNet { - pub mtu: u16, + mtu: u16, + with_checksums: bool, } impl HermitNet { - pub(crate) const fn new(mtu: u16) -> Self { - Self { mtu } + pub(crate) const fn new(mtu: u16, with_checksums: bool) -> Self { + Self { + mtu, + with_checksums, + } } } impl<'a> NetworkInterface<'a> { #[cfg(feature = "dhcpv4")] pub(crate) fn create() -> NetworkState<'a> { - let (mtu, mac) = if let Some(driver) = hardware::get_network_driver() { + let (mtu, mac, with_checksums) = if let Some(driver) = hardware::get_network_driver() { let guard = driver.lock(); - (guard.get_mtu(), guard.get_mac_address()) + ( + guard.get_mtu(), + guard.get_mac_address(), + guard.with_checksums(), + ) } else { return NetworkState::InitializationFailed; }; - let mut device = HermitNet::new(mtu); + let mut device = HermitNet::new(mtu, with_checksums); let ethernet_addr = EthernetAddress([mac[0], mac[1], mac[2], mac[3], mac[4], mac[5]]); let hardware_addr = HardwareAddress::Ethernet(ethernet_addr); @@ -74,14 +82,18 @@ impl<'a> NetworkInterface<'a> { #[cfg(not(feature = "dhcpv4"))] pub(crate) fn create() -> NetworkState<'a> { - let (mtu, mac) = if let Some(driver) = hardware::get_network_driver() { + let (mtu, mac, with_checksums) = if let Some(driver) = hardware::get_network_driver() { let guard = driver.lock(); - (guard.get_mtu(), guard.get_mac_address()) + ( + guard.get_mtu(), + guard.get_mac_address(), + guard.with_checksums(), + ) } else { return NetworkState::InitializationFailed; }; - let mut device = HermitNet::new(mtu); + let mut device = HermitNet::new(mtu, with_checksums); let myip = Ipv4Address::from_str(hermit_var_or!("HERMIT_IP", "10.0.5.3")).unwrap(); let mygw = Ipv4Address::from_str(hermit_var_or!("HERMIT_GATEWAY", "10.0.5.1")).unwrap(); @@ -156,6 +168,10 @@ impl Device for HermitNet { fn capabilities(&self) -> DeviceCapabilities { let mut cap = DeviceCapabilities::default(); cap.max_transmission_unit = self.mtu.into(); + if !self.with_checksums { + cap.checksum.tcp = Checksum::None; + cap.checksum.udp = Checksum::None; + } cap }