diff --git a/examples/multicast.rs b/examples/multicast.rs index ea89a2e93..44162d6d3 100644 --- a/examples/multicast.rs +++ b/examples/multicast.rs @@ -82,11 +82,7 @@ fn main() { // Join a multicast group to receive mDNS traffic iface - .join_multicast_group( - &mut device, - Ipv4Address::from_bytes(&MDNS_GROUP), - Instant::now(), - ) + .join_multicast_group(Ipv4Address::from_bytes(&MDNS_GROUP)) .unwrap(); loop { diff --git a/examples/multicast6.rs b/examples/multicast6.rs index 814c4fe1e..46e4e7bd7 100644 --- a/examples/multicast6.rs +++ b/examples/multicast6.rs @@ -66,7 +66,7 @@ fn main() { // Join a multicast group iface - .join_multicast_group(&mut device, Ipv6Address::from_parts(&GROUP), Instant::now()) + .join_multicast_group(Ipv6Address::from_parts(&GROUP)) .unwrap(); loop { diff --git a/src/iface/interface/igmp.rs b/src/iface/interface/igmp.rs index d120a463d..e5506a98e 100644 --- a/src/iface/interface/igmp.rs +++ b/src/iface/interface/igmp.rs @@ -4,8 +4,6 @@ use super::*; #[derive(Debug, Clone, Copy, PartialEq, Eq)] #[cfg_attr(feature = "defmt", derive(defmt::Format))] pub enum MulticastError { - /// The hardware device transmit buffer is full. Try again later. - Exhausted, /// The table of joined multicast groups is already full. GroupTableFull, /// Cannot join/leave the given multicast group. @@ -15,7 +13,6 @@ pub enum MulticastError { impl core::fmt::Display for MulticastError { fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result { match self { - MulticastError::Exhausted => write!(f, "Exhausted"), MulticastError::GroupTableFull => write!(f, "GroupTableFull"), MulticastError::Unaddressable => write!(f, "Unaddressable"), } @@ -27,138 +24,52 @@ impl std::error::Error for MulticastError {} impl Interface { /// Add an address to a list of subscribed multicast IP addresses. - /// - /// Returns `Ok(announce_sent)` if the address was added successfully, where `announce_sent` - /// indicates whether an initial immediate announcement has been sent. - pub fn join_multicast_group>( + pub fn join_multicast_group>( &mut self, - device: &mut D, addr: T, - timestamp: Instant, - ) -> Result - where - D: Device + ?Sized, - { + ) -> Result<(), MulticastError> { let addr = addr.into(); - self.inner.now = timestamp; - - let is_not_new = self - .inner - .multicast_groups - .insert(addr, ()) - .map_err(|_| MulticastError::GroupTableFull)? - .is_some(); - if is_not_new { - return Ok(false); + if !addr.is_multicast() { + return Err(MulticastError::Unaddressable); } - match addr { - IpAddress::Ipv4(addr) => { - if let Some(pkt) = self.inner.igmp_report_packet(IgmpVersion::Version2, addr) { - // Send initial membership report - let tx_token = device - .transmit(timestamp) - .ok_or(MulticastError::Exhausted)?; - - // NOTE(unwrap): packet destination is multicast, which is always routable and doesn't require neighbor discovery. - self.inner - .dispatch_ip(tx_token, PacketMeta::default(), pkt, &mut self.fragmenter) - .unwrap(); - - Ok(true) - } else { - Ok(false) - } - } - #[cfg(feature = "proto-ipv6")] - IpAddress::Ipv6(addr) => { - // Build report packet containing this new address - if let Some(pkt) = self.inner.mldv2_report_packet(&[MldAddressRecordRepr::new( - MldRecordType::ChangeToInclude, - addr, - )]) { - // Send initial membership report - let tx_token = device - .transmit(timestamp) - .ok_or(MulticastError::Exhausted)?; - - // NOTE(unwrap): packet destination is multicast, which is always routable and doesn't require neighbor discovery. - self.inner - .dispatch_ip(tx_token, PacketMeta::default(), pkt, &mut self.fragmenter) - .unwrap(); - - Ok(true) - } else { - Ok(false) - } - } - #[allow(unreachable_patterns)] - _ => Err(MulticastError::Unaddressable), + if let Some(state) = self.inner.multicast_groups.get_mut(&addr) { + *state = match state { + MulticastGroupState::Joining => MulticastGroupState::Joining, + MulticastGroupState::Joined => MulticastGroupState::Joined, + MulticastGroupState::Leaving => MulticastGroupState::Joined, + }; + } else { + self.inner + .multicast_groups + .insert(addr, MulticastGroupState::Joining) + .map_err(|_| MulticastError::GroupTableFull)?; } + Ok(()) } /// Remove an address from the subscribed multicast IP addresses. - /// - /// Returns `Ok(leave_sent)` if the address was removed successfully, where `leave_sent` - /// indicates whether an immediate leave packet has been sent. - pub fn leave_multicast_group>( + pub fn leave_multicast_group>( &mut self, - device: &mut D, addr: T, - timestamp: Instant, - ) -> Result - where - D: Device + ?Sized, - { + ) -> Result<(), MulticastError> { let addr = addr.into(); - self.inner.now = timestamp; - let was_not_present = self.inner.multicast_groups.remove(&addr).is_none(); - if was_not_present { - return Ok(false); + if !addr.is_multicast() { + return Err(MulticastError::Unaddressable); } - match addr { - IpAddress::Ipv4(addr) => { - if let Some(pkt) = self.inner.igmp_leave_packet(addr) { - // Send group leave packet - let tx_token = device - .transmit(timestamp) - .ok_or(MulticastError::Exhausted)?; - - // NOTE(unwrap): packet destination is multicast, which is always routable and doesn't require neighbor discovery. - self.inner - .dispatch_ip(tx_token, PacketMeta::default(), pkt, &mut self.fragmenter) - .unwrap(); - - Ok(true) - } else { - Ok(false) - } + if let Some(state) = self.inner.multicast_groups.get_mut(&addr) { + let delete; + (*state, delete) = match state { + MulticastGroupState::Joining => (MulticastGroupState::Joined, true), + MulticastGroupState::Joined => (MulticastGroupState::Leaving, false), + MulticastGroupState::Leaving => (MulticastGroupState::Leaving, false), + }; + if delete { + self.inner.multicast_groups.remove(&addr); } - #[cfg(feature = "proto-ipv6")] - IpAddress::Ipv6(addr) => { - if let Some(pkt) = self.inner.mldv2_report_packet(&[MldAddressRecordRepr::new( - MldRecordType::ChangeToExclude, - addr, - )]) { - // Send group leave packet - let tx_token = device - .transmit(timestamp) - .ok_or(MulticastError::Exhausted)?; - - // NOTE(unwrap): packet destination is multicast, which is always routable and doesn't require neighbor discovery. - self.inner - .dispatch_ip(tx_token, PacketMeta::default(), pkt, &mut self.fragmenter) - .unwrap(); - - Ok(true) - } else { - Ok(false) - } - } - #[allow(unreachable_patterns)] - _ => Err(MulticastError::Unaddressable), } + Ok(()) } /// Check whether the interface listens to given destination multicast IP address. @@ -166,12 +77,101 @@ impl Interface { self.inner.has_multicast_group(addr) } - /// Depending on `igmp_report_state` and the therein contained - /// timeouts, send IGMP membership reports. - pub(crate) fn igmp_egress(&mut self, device: &mut D) -> bool + /// Do multicast egress. + /// + /// - Send join/leave packets according to the multicast group state. + /// - Depending on `igmp_report_state` and the therein contained + /// timeouts, send IGMP membership reports. + pub(crate) fn multicast_egress(&mut self, device: &mut D) -> bool where D: Device + ?Sized, { + // Process multicast joins. + while let Some((&addr, _)) = self + .inner + .multicast_groups + .iter() + .find(|(_, &state)| state == MulticastGroupState::Joining) + { + match addr { + IpAddress::Ipv4(addr) => { + if let Some(pkt) = self.inner.igmp_report_packet(IgmpVersion::Version2, addr) { + let Some(tx_token) = device.transmit(self.inner.now) else { + break; + }; + + // NOTE(unwrap): packet destination is multicast, which is always routable and doesn't require neighbor discovery. + self.inner + .dispatch_ip(tx_token, PacketMeta::default(), pkt, &mut self.fragmenter) + .unwrap(); + } + } + #[cfg(feature = "proto-ipv6")] + IpAddress::Ipv6(addr) => { + if let Some(pkt) = self.inner.mldv2_report_packet(&[MldAddressRecordRepr::new( + MldRecordType::ChangeToInclude, + addr, + )]) { + let Some(tx_token) = device.transmit(self.inner.now) else { + break; + }; + + // NOTE(unwrap): packet destination is multicast, which is always routable and doesn't require neighbor discovery. + self.inner + .dispatch_ip(tx_token, PacketMeta::default(), pkt, &mut self.fragmenter) + .unwrap(); + } + } + } + + // NOTE(unwrap): this is always replacing an existing entry, so it can't fail due to the map being full. + self.inner + .multicast_groups + .insert(addr, MulticastGroupState::Joined) + .unwrap(); + } + + // Process multicast leaves. + while let Some((&addr, _)) = self + .inner + .multicast_groups + .iter() + .find(|(_, &state)| state == MulticastGroupState::Leaving) + { + match addr { + IpAddress::Ipv4(addr) => { + if let Some(pkt) = self.inner.igmp_leave_packet(addr) { + let Some(tx_token) = device.transmit(self.inner.now) else { + break; + }; + + // NOTE(unwrap): packet destination is multicast, which is always routable and doesn't require neighbor discovery. + self.inner + .dispatch_ip(tx_token, PacketMeta::default(), pkt, &mut self.fragmenter) + .unwrap(); + } + } + #[cfg(feature = "proto-ipv6")] + IpAddress::Ipv6(addr) => { + if let Some(pkt) = self.inner.mldv2_report_packet(&[MldAddressRecordRepr::new( + MldRecordType::ChangeToExclude, + addr, + )]) { + let Some(tx_token) = device.transmit(self.inner.now) else { + break; + }; + + // NOTE(unwrap): packet destination is multicast, which is always routable and doesn't require neighbor discovery. + self.inner + .dispatch_ip(tx_token, PacketMeta::default(), pkt, &mut self.fragmenter) + .unwrap(); + } + } + } + + self.inner.multicast_groups.remove(&addr); + } + match self.inner.igmp_report_state { IgmpReportState::ToSpecificQuery { version, diff --git a/src/iface/interface/mod.rs b/src/iface/interface/mod.rs index 870e88e61..b2cbdc399 100644 --- a/src/iface/interface/mod.rs +++ b/src/iface/interface/mod.rs @@ -82,6 +82,16 @@ pub struct Interface { fragmenter: Fragmenter, } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum MulticastGroupState { + /// Joining group, we have to send the join packet. + Joining, + /// We've already sent the join packet, we have nothing to do. + Joined, + /// We want to leave the group, we have to send a leave packet. + Leaving, +} + /// The device independent part of an Ethernet network interface. /// /// Separating the device from the data required for processing and dispatching makes @@ -112,7 +122,7 @@ pub struct InterfaceInner { any_ip: bool, routes: Routes, #[cfg(any(feature = "proto-igmp", feature = "proto-ipv6"))] - multicast_groups: LinearMap, + multicast_groups: LinearMap, /// When to report for (all or) the next multicast group membership via IGMP #[cfg(feature = "proto-igmp")] igmp_report_state: IgmpReportState, @@ -437,7 +447,7 @@ impl Interface { #[cfg(feature = "proto-igmp")] { - readiness_may_have_changed |= self.igmp_egress(device); + readiness_may_have_changed |= self.multicast_egress(device); } readiness_may_have_changed @@ -749,18 +759,29 @@ impl InterfaceInner { /// If built without feature `proto-igmp` this function will /// always return `false` when using IPv4. fn has_multicast_group>(&self, addr: T) -> bool { + /// Return false if we don't have the multicast group, + /// or we're leaving it. + fn wanted_state(x: Option<&MulticastGroupState>) -> bool { + match x { + None => false, + Some(MulticastGroupState::Joining) => true, + Some(MulticastGroupState::Joined) => true, + Some(MulticastGroupState::Leaving) => false, + } + } + let addr = addr.into(); match addr { #[cfg(feature = "proto-igmp")] IpAddress::Ipv4(key) => { key == Ipv4Address::MULTICAST_ALL_SYSTEMS - || self.multicast_groups.get(&addr).is_some() + || wanted_state(self.multicast_groups.get(&addr)) } #[cfg(feature = "proto-ipv6")] IpAddress::Ipv6(key) => { key == Ipv6Address::LINK_LOCAL_ALL_NODES || self.has_solicited_node(key) - || self.multicast_groups.get(&addr).is_some() + || wanted_state(self.multicast_groups.get(&addr)) } #[cfg(feature = "proto-rpl")] IpAddress::Ipv6(Ipv6Address::LINK_LOCAL_ALL_RPL_NODES) => true, diff --git a/src/iface/interface/tests/ipv4.rs b/src/iface/interface/tests/ipv4.rs index c2c8ce460..c9e100fc8 100644 --- a/src/iface/interface/tests/ipv4.rs +++ b/src/iface/interface/tests/ipv4.rs @@ -702,10 +702,9 @@ fn test_handle_igmp(#[case] medium: Medium) { // Join multicast groups let timestamp = Instant::ZERO; for group in &groups { - iface - .join_multicast_group(&mut device, *group, timestamp) - .unwrap(); + iface.join_multicast_group(*group).unwrap(); } + iface.poll(timestamp, &mut device, &mut sockets); let reports = recv_igmp(&mut device, timestamp); assert_eq!(reports.len(), 2); @@ -745,10 +744,9 @@ fn test_handle_igmp(#[case] medium: Medium) { // Leave multicast groups let timestamp = Instant::ZERO; for group in &groups { - iface - .leave_multicast_group(&mut device, *group, timestamp) - .unwrap(); + iface.leave_multicast_group(*group).unwrap(); } + iface.poll(timestamp, &mut device, &mut sockets); let leaves = recv_igmp(&mut device, timestamp); assert_eq!(leaves.len(), 2); diff --git a/src/iface/interface/tests/ipv6.rs b/src/iface/interface/tests/ipv6.rs index f6e214e76..620712c9c 100644 --- a/src/iface/interface/tests/ipv6.rs +++ b/src/iface/interface/tests/ipv6.rs @@ -1289,7 +1289,7 @@ fn test_join_ipv6_multicast_group(#[case] medium: Medium) { .collect::>() } - let (mut iface, _sockets, mut device) = setup(medium); + let (mut iface, mut sockets, mut device) = setup(medium); let groups = [ Ipv6Address::from_parts(&[0xff05, 0, 0, 0, 0, 0, 0, 0x00fb]), @@ -1299,12 +1299,12 @@ fn test_join_ipv6_multicast_group(#[case] medium: Medium) { let timestamp = Instant::from_millis(0); for &group in &groups { - iface - .join_multicast_group(&mut device, group, timestamp) - .unwrap(); + iface.join_multicast_group(group).unwrap(); assert!(iface.has_multicast_group(group)); } assert!(iface.has_multicast_group(Ipv6Address::LINK_LOCAL_ALL_NODES)); + iface.poll(timestamp, &mut device, &mut sockets); + assert!(iface.has_multicast_group(Ipv6Address::LINK_LOCAL_ALL_NODES)); let reports = recv_icmpv6(&mut device, timestamp); assert_eq!(reports.len(), 2); @@ -1374,9 +1374,9 @@ fn test_join_ipv6_multicast_group(#[case] medium: Medium) { } ); - iface - .leave_multicast_group(&mut device, group_addr, timestamp) - .unwrap(); + iface.leave_multicast_group(group_addr).unwrap(); + assert!(!iface.has_multicast_group(group_addr)); + iface.poll(timestamp, &mut device, &mut sockets); assert!(!iface.has_multicast_group(group_addr)); } }