Skip to content

Commit

Permalink
Merge pull request #404 from Comcast/topic/lenient-ssm
Browse files Browse the repository at this point in the history
Fixes #402 - allows source specific multicast ranges outside RFC range
  • Loading branch information
mpilquist authored May 16, 2022
2 parents 58796ea + 07608ed commit 679c429
Show file tree
Hide file tree
Showing 12 changed files with 106 additions and 37 deletions.
11 changes: 8 additions & 3 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,12 @@ ThisBuild / initialCommands := "import com.comcast.ip4s._"

ThisBuild / mimaBinaryIssueFilters ++= Seq(
ProblemFilters.exclude[DirectMissingMethodProblem]("com.comcast.ip4s.Ipv6Address.toInetAddress"),
ProblemFilters.exclude[ReversedMissingMethodProblem]("com.comcast.ip4s.Dns.*") // sealed trait
ProblemFilters.exclude[ReversedMissingMethodProblem]("com.comcast.ip4s.Dns.*"), // sealed trait
// Scala 3 (erroneously?) considered Multicast/SourceSpecificMulticast as sum types
ProblemFilters.exclude[DirectMissingMethodProblem]("com.comcast.ip4s.Multicast.ordinal"),
ProblemFilters.exclude[MissingTypesProblem]("com.comcast.ip4s.Multicast$"),
ProblemFilters.exclude[DirectMissingMethodProblem]("com.comcast.ip4s.SourceSpecificMulticast.ordinal"),
ProblemFilters.exclude[MissingTypesProblem]("com.comcast.ip4s.SourceSpecificMulticast$")
)

lazy val root = tlCrossRootProject.aggregate(core, testKit)
Expand Down Expand Up @@ -93,8 +98,8 @@ lazy val docs = project
.settings(
mdocIn := baseDirectory.value / "src",
mdocOut := baseDirectory.value / "../docs",
crossScalaVersions := (ThisBuild / crossScalaVersions).value.filter(_.startsWith("2.")),
githubWorkflowArtifactUpload := false
githubWorkflowArtifactUpload := false,
libraryDependencies += "org.typelevel" %%% "cats-effect" % "3.3.11"
)

lazy val commonSettings = Seq(
Expand Down
15 changes: 12 additions & 3 deletions docs/guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ val homeIA = ip"127.0.0.1".toInetAddress
val home4IA = ipv4"127.0.0.1".toInetAddress
// home4IA: java.net.Inet4Address = /127.0.0.1
val home6IA = ipv6"::1".toInetAddress
// home6IA: java.net.Inet6Address = /0:0:0:0:0:0:0:1
// home6IA: java.net.InetAddress = /0:0:0:0:0:0:0:1
```

# Multicast
Expand Down Expand Up @@ -135,9 +135,18 @@ To construct instances of `Multicast[A]` and `SourceSpecificMulticast[A]`, we ca
val multicastIps = ips.flatMap(_.asMulticast)
// multicastIps: List[com.comcast.ip4s.Multicast[IpAddress]] = List(224.10.10.10, 232.11.11.11, ff00::10, ff3b::11)
val ssmIps = ips.flatMap(_.asSourceSpecificMulticast)
// ssmIps: List[com.comcast.ip4s.SourceSpecificMulticast[IpAddress]] = List(232.11.11.11, ff3b::11)
// ssmIps: List[SourceSpecificMulticast.Strict[IpAddress]] = List(232.11.11.11, ff3b::11)
```

It's common for source specific multicast to be used with group addresses outside the designated source specific multicast address range. To support such cases, use `asSourceSpecificMulticastLenient`:

```scala
val lenient = ips.flatMap(_.asSourceSpecificMulticastLenient)
// lenient: List[com.comcast.ip4s.SourceSpecificMulticast[IpAddress]] = List(224.10.10.10, 232.11.11.11, ff00::10, ff3b::11)
```

Additionally, the `SourceSpecificMulticast.Strict[A]` type provides the guarantee that the wrapped address is in the RFC defined source specific range.

## Multicast Literals

There are string interpolators for constructing multicast and source specific multicast address from literal strings, similar to the `ip`, `ipv4`, and `ipv6` interpolators. The multicast interpolators are:
Expand Down Expand Up @@ -259,7 +268,7 @@ On the JVM, a `SocketAddress` can be converted to a `java.net.InetSocketAddress`

```scala
val u = t.toInetSocketAddress
// u: java.net.InetSocketAddress = /0:0:0:0:0:0:0:1:5555
// u: java.net.InetSocketAddress = /[0:0:0:0:0:0:0:1]:5555
```

## Multicast Socket Addresses
Expand Down
8 changes: 8 additions & 0 deletions docs/src/guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,14 @@ val multicastIps = ips.flatMap(_.asMulticast)
val ssmIps = ips.flatMap(_.asSourceSpecificMulticast)
```

It's common for source specific multicast to be used with group addresses outside the designated source specific multicast address range. To support such cases, use `asSourceSpecificMulticastLenient`:

```scala mdoc:nest:to-string
val lenient = ips.flatMap(_.asSourceSpecificMulticastLenient)
```

Additionally, the `SourceSpecificMulticast.Strict[A]` type provides the guarantee that the wrapped address is in the RFC defined source specific range.

## Multicast Literals

There are string interpolators for constructing multicast and source specific multicast address from literal strings, similar to the `ip`, `ipv4`, and `ipv6` interpolators. The multicast interpolators are:
Expand Down
12 changes: 6 additions & 6 deletions shared/src/main/scala-2/Literals.scala
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ object Literals {
def make(c: Context)(args: c.Expr[Any]*): c.Expr[Multicast[Ipv6Address]] = apply(c)(args: _*)
}

object ssmip extends Literally[SourceSpecificMulticast[IpAddress]] {
object ssmip extends Literally[SourceSpecificMulticast.Strict[IpAddress]] {
def validate(c: Context)(s: String) = {
import c.universe._
IpAddress.fromString(s).flatMap(_.asSourceSpecificMulticast) match {
Expand All @@ -96,10 +96,10 @@ object Literals {
case None => Left("invalid source specific IP multicast address")
}
}
def make(c: Context)(args: c.Expr[Any]*): c.Expr[SourceSpecificMulticast[IpAddress]] = apply(c)(args: _*)
def make(c: Context)(args: c.Expr[Any]*): c.Expr[SourceSpecificMulticast.Strict[IpAddress]] = apply(c)(args: _*)
}

object ssmipv4 extends Literally[SourceSpecificMulticast[Ipv4Address]] {
object ssmipv4 extends Literally[SourceSpecificMulticast.Strict[Ipv4Address]] {
def validate(c: Context)(s: String) = {
import c.universe._
Ipv4Address.fromString(s).flatMap(_.asSourceSpecificMulticast) match {
Expand All @@ -108,10 +108,10 @@ object Literals {
case None => Left("invalid source specific IPv4 multicast address")
}
}
def make(c: Context)(args: c.Expr[Any]*): c.Expr[SourceSpecificMulticast[Ipv4Address]] = apply(c)(args: _*)
def make(c: Context)(args: c.Expr[Any]*): c.Expr[SourceSpecificMulticast.Strict[Ipv4Address]] = apply(c)(args: _*)
}

object ssmipv6 extends Literally[SourceSpecificMulticast[Ipv6Address]] {
object ssmipv6 extends Literally[SourceSpecificMulticast.Strict[Ipv6Address]] {
def validate(c: Context)(s: String) = {
import c.universe._
Ipv6Address.fromString(s).flatMap(_.asSourceSpecificMulticast) match {
Expand All @@ -120,7 +120,7 @@ object Literals {
case None => Left("invalid source specific IPv6 multicast address")
}
}
def make(c: Context)(args: c.Expr[Any]*): c.Expr[SourceSpecificMulticast[Ipv6Address]] = apply(c)(args: _*)
def make(c: Context)(args: c.Expr[Any]*): c.Expr[SourceSpecificMulticast.Strict[Ipv6Address]] = apply(c)(args: _*)
}

object port extends Literally[Port] {
Expand Down
6 changes: 3 additions & 3 deletions shared/src/main/scala-2/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,11 @@ package object ip4s extends ip4splatform {
def mipv6(args: Any*): Multicast[Ipv6Address] =
macro Literals.mipv6.make

def ssmip(args: Any*): SourceSpecificMulticast[IpAddress] =
def ssmip(args: Any*): SourceSpecificMulticast.Strict[IpAddress] =
macro Literals.ssmip.make
def ssmipv4(args: Any*): SourceSpecificMulticast[Ipv4Address] =
def ssmipv4(args: Any*): SourceSpecificMulticast.Strict[Ipv4Address] =
macro Literals.ssmipv4.make
def ssmipv6(args: Any*): SourceSpecificMulticast[Ipv6Address] =
def ssmipv6(args: Any*): SourceSpecificMulticast.Strict[Ipv6Address] =
macro Literals.ssmipv6.make

def port(args: Any*): Port =
Expand Down
24 changes: 15 additions & 9 deletions shared/src/main/scala-3/Literals.scala
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,13 @@ extension (inline ctx: StringContext)
inline def mipv6(inline args: Any*): Multicast[Ipv6Address] =
${ Literals.mipv6('ctx, 'args) }

inline def ssmip(inline args: Any*): SourceSpecificMulticast[IpAddress] =
inline def ssmip(inline args: Any*): SourceSpecificMulticast.Strict[IpAddress] =
${ Literals.ssmip('ctx, 'args) }

inline def ssmipv4(inline args: Any*): SourceSpecificMulticast[Ipv4Address] =
inline def ssmipv4(inline args: Any*): SourceSpecificMulticast.Strict[Ipv4Address] =
${ Literals.ssmipv4('ctx, 'args) }

inline def ssmipv6(inline args: Any*): SourceSpecificMulticast[Ipv6Address] =
inline def ssmipv6(inline args: Any*): SourceSpecificMulticast.Strict[Ipv6Address] =
${ Literals.ssmipv6('ctx, 'args) }

inline def port(inline args: Any*): Port =
Expand Down Expand Up @@ -93,25 +93,31 @@ object Literals:
case Some(_) => Right('{ _root_.com.comcast.ip4s.Ipv6Address.fromString(${ Expr(s) }).get.asMulticast.get })
case None => Left("Invalid IPv6 multicast address")

object ssmip extends Literally[SourceSpecificMulticast[IpAddress]]:
object ssmip extends Literally[SourceSpecificMulticast.Strict[IpAddress]]:
def validate(s: String)(using Quotes) =
IpAddress.fromString(s).flatMap(_.asSourceSpecificMulticast) match
case Some(_) =>
Right('{ _root_.com.comcast.ip4s.IpAddress.fromString(${ Expr(s) }).get.asSourceSpecificMulticast.get })
Right('{
_root_.com.comcast.ip4s.IpAddress.fromString(${ Expr(s) }).get.asSourceSpecificMulticast.get
})
case None => Left("Invalid source specific IP multicast address")

object ssmipv4 extends Literally[SourceSpecificMulticast[Ipv4Address]]:
object ssmipv4 extends Literally[SourceSpecificMulticast.Strict[Ipv4Address]]:
def validate(s: String)(using Quotes) =
Ipv4Address.fromString(s).flatMap(_.asSourceSpecificMulticast) match
case Some(_) =>
Right('{ _root_.com.comcast.ip4s.Ipv4Address.fromString(${ Expr(s) }).get.asSourceSpecificMulticast.get })
Right('{
_root_.com.comcast.ip4s.Ipv4Address.fromString(${ Expr(s) }).get.asSourceSpecificMulticast.get
})
case None => Left("Invalid source specific IPv4 multicast address")

object ssmipv6 extends Literally[SourceSpecificMulticast[Ipv6Address]]:
object ssmipv6 extends Literally[SourceSpecificMulticast.Strict[Ipv6Address]]:
def validate(s: String)(using Quotes) =
Ipv6Address.fromString(s).flatMap(_.asSourceSpecificMulticast) match
case Some(_) =>
Right('{ _root_.com.comcast.ip4s.Ipv6Address.fromString(${ Expr(s) }).get.asSourceSpecificMulticast.get })
Right('{
_root_.com.comcast.ip4s.Ipv6Address.fromString(${ Expr(s) }).get.asSourceSpecificMulticast.get
})
case None => Left("Invalid source specific IPv6 multicast address")

object port extends Literally[Port]:
Expand Down
6 changes: 5 additions & 1 deletion shared/src/main/scala/com/comcast/ip4s/Host.scala
Original file line number Diff line number Diff line change
Expand Up @@ -175,9 +175,13 @@ sealed abstract class IpAddress extends IpAddressPlatform with Host with Seriali
/** Converts this address to a source specific multicast address, as long as it is in the source specific multicast
* address range.
*/
def asSourceSpecificMulticast: Option[SourceSpecificMulticast[this.type]] =
def asSourceSpecificMulticast: Option[SourceSpecificMulticast.Strict[this.type]] =
SourceSpecificMulticast.fromIpAddress(this)

/** Like `asSourceSpecificMulticast` but allows multicast addresses outside the source specific range. */
def asSourceSpecificMulticastLenient: Option[SourceSpecificMulticast[this.type]] =
SourceSpecificMulticast.fromIpAddressLenient(this)

/** Narrows this address to an Ipv4Address if that is the underlying type. */
def asIpv4: Option[Ipv4Address] = collapseMappedV4.fold(Some(_), _ => None)

Expand Down
38 changes: 33 additions & 5 deletions shared/src/main/scala/com/comcast/ip4s/Multicast.scala
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@ sealed trait Multicast[+A <: IpAddress] extends Product with Serializable {
object Multicast {
private case class DefaultMulticast[+A <: IpAddress](address: A) extends Multicast[A] {
override def toString: String = address.toString
override def equals(that: Any): Boolean = that match {
case m: Multicast[_] => address == m.address
case _ => false
}
override def hashCode: Int = address.hashCode
}

/** Constructs a multicast IP address. Returns `None` is the supplied address is not in the valid multicast range. */
Expand All @@ -47,27 +52,50 @@ object Multicast {

/** Witnesses that the wrapped address of type `A` is a source specific multicast address.
*
* An instance of `SourceSpecificMulticast` is typically created by either calling `Multicast.apply` or by using the
* `asSourceSpecificMulticast` method on `IpAddress`.
* An instance of `SourceSpecificMulticast` is typically created by either calling `Multicast.apply` or by using
* `asSourceSpecificMulticast` and `asSourceSpecificMulticastLenient` methods on `IpAddress`.
*/
sealed trait SourceSpecificMulticast[+A <: IpAddress] extends Multicast[A] {

/** Ensures the referenced address is in the RFC defined source specific address range. */
def strict: Option[SourceSpecificMulticast.Strict[A]] =
if (address.isSourceSpecificMulticast) Some(SourceSpecificMulticast.unsafeCreateStrict(address)) else None

override def toString: String = address.toString
}

object SourceSpecificMulticast {

/** Indicates the address is within the RFC defined source specific multicast range. */
trait Strict[+A <: IpAddress] extends SourceSpecificMulticast[A]

private case class DefaultSourceSpecificMulticast[+A <: IpAddress](address: A) extends SourceSpecificMulticast[A] {
override def toString: String = address.toString
override def equals(that: Any): Boolean = that match {
case m: Multicast[_] => address == m.address
case _ => false
}
override def hashCode: Int = address.hashCode
}

/** Constructs a source specific multicast IP address. Returns `None` is the supplied address is not in the valid
/** Constructs a source specific multicast IP address. Returns `None` if the supplied address is not in the valid
* source specific multicast range.
*/
def fromIpAddress[A <: IpAddress](address: A): Option[SourceSpecificMulticast[A]] =
if (address.isSourceSpecificMulticast) Some(new DefaultSourceSpecificMulticast(address)) else None
def fromIpAddress[A <: IpAddress](address: A): Option[SourceSpecificMulticast.Strict[A]] =
if (address.isSourceSpecificMulticast) Some(unsafeCreateStrict(address)) else None

/** Constructs a source specific multicast IP address. Unlike `fromIpAddress`, multicast addresses outside the RFC
* defined source specific range are allowed.
*/
def fromIpAddressLenient[A <: IpAddress](address: A): Option[SourceSpecificMulticast[A]] =
if (address.isMulticast) Some(unsafeCreate(address)) else None

private[ip4s] def unsafeCreate[A <: IpAddress](address: A): SourceSpecificMulticast[A] =
DefaultSourceSpecificMulticast(address)

private[ip4s] def unsafeCreateStrict[A <: IpAddress](address: A): SourceSpecificMulticast.Strict[A] =
new DefaultSourceSpecificMulticast(address) with Strict[A]

implicit def ordering[A <: IpAddress]: Ordering[SourceSpecificMulticast[A]] =
Multicast.ordering[SourceSpecificMulticast, A]
}
4 changes: 2 additions & 2 deletions shared/src/main/scala/com/comcast/ip4s/MulticastJoin.scala
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,11 @@ object MulticastJoin {
case Some(sourceStr) =>
for {
source <- parse(sourceStr)
group <- parse(groupStr).flatMap(_.asSourceSpecificMulticast)
group <- parse(groupStr).flatMap(_.asSourceSpecificMulticastLenient)
} yield ssm(source, group)
case None =>
for {
group <- parse(groupStr).flatMap(_.asSourceSpecificMulticast)
group <- parse(groupStr).flatMap(_.asSourceSpecificMulticastLenient)
} yield asm(group)
}
case _ => None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ object MulticastSocketAddress {
case Some(sourceStr) =>
for {
source <- parse(sourceStr)
group <- parse(groupStr).flatMap(_.asSourceSpecificMulticast)
group <- parse(groupStr).flatMap(_.asSourceSpecificMulticastLenient)
port <- Port.fromString(portStr)
} yield MulticastSocketAddress(MulticastJoin.ssm(source, group), port)
case None =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,9 @@ object Arbitraries {
def multicastJoinGenerator[A <: IpAddress](genSource: Gen[A], genGroup: Gen[Multicast[A]]): Gen[MulticastJoin[A]] =
genGroup.flatMap { group =>
group.address.asSourceSpecificMulticast match {
case Some(grp) => genSource.filter(_.getClass == grp.getClass).flatMap(src => MulticastJoin.ssm(src, grp))
case None => MulticastJoin.asm(group)
case Some(grp) =>
genSource.filter(_.getClass == grp.address.getClass).flatMap(src => MulticastJoin.ssm(src, grp))
case None => MulticastJoin.asm(group)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,16 @@ class MulticastTest extends BaseTestSuite {
test("support equality") {
forAll { (mip: Multicast[IpAddress]) =>
assertEquals(mip.address.asMulticast, Some(mip))
mip.address.asSourceSpecificMulticast.foreach(x => assert(x == mip))
mip.address.asSourceSpecificMulticast.foreach(x => assertEquals(mip, x))
mip.address.asSourceSpecificMulticastLenient.foreach(x => assertEquals(mip, x))
mip.address.asSourceSpecificMulticastLenient.foreach(x => assert(x == mip))
}
}

test("support SSM outside source specific range") {
assertEquals(ip"239.10.10.10".asSourceSpecificMulticast, None)
assertEquals(
ip"239.10.10.10".asSourceSpecificMulticastLenient,
Some(SourceSpecificMulticast.unsafeCreate(ip"239.10.10.10"))
)
}
}

0 comments on commit 679c429

Please sign in to comment.