Skip to content

Commit

Permalink
Make SpanContext sealed
Browse files Browse the repository at this point in the history
  • Loading branch information
iRevive committed Nov 10, 2023
1 parent 6c873d1 commit 762a460
Show file tree
Hide file tree
Showing 7 changed files with 368 additions and 69 deletions.
3 changes: 2 additions & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,8 @@ lazy val `java-trace` = project
"org.typelevel" %%% "cats-effect" % CatsEffectVersion,
"io.opentelemetry" % "opentelemetry-sdk-testing" % OpenTelemetryVersion % Test,
"org.typelevel" %%% "cats-effect-testkit" % CatsEffectVersion % Test,
"co.fs2" %% "fs2-core" % FS2Version % Test
"co.fs2" %% "fs2-core" % FS2Version % Test,
"org.scalameta" %%% "munit-scalacheck" % MUnitVersion % Test
),
mimaBinaryIssueFilters ++= Seq(
ProblemFilters.exclude[MissingClassProblem](
Expand Down
177 changes: 167 additions & 10 deletions core/trace/src/main/scala/org/typelevel/otel4s/trace/SpanContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@

package org.typelevel.otel4s.trace

import cats.Hash
import cats.Show
import cats.syntax.show._
import scodec.bits.ByteVector

/** A span context contains the state that must propagate to child spans and
Expand All @@ -24,8 +27,11 @@ import scodec.bits.ByteVector
* It contains the identifiers (a `trace_id` and `span_id`) associated with the
* span and a set of flags (currently only whether the context is sampled or
* not), as well as the remote flag.
*
* @see
* [[https://opentelemetry.io/docs/specs/otel/trace/api/#spancontext]]
*/
trait SpanContext {
sealed trait SpanContext {

/** Returns the trace identifier associated with this [[SpanContext]] as
* 16-byte vector.
Expand Down Expand Up @@ -69,6 +75,18 @@ trait SpanContext {
* parent.
*/
def isRemote: Boolean

override final def hashCode(): Int =
Hash[SpanContext].hash(this)

override final def equals(obj: Any): Boolean =
obj match {
case other: SpanContext => Hash[SpanContext].eqv(this, other)
case _ => false
}

override final def toString: String =
Show[SpanContext].show(this)
}

object SpanContext {
Expand Down Expand Up @@ -104,14 +122,153 @@ object SpanContext {
}

val invalid: SpanContext =
new SpanContext {
val traceIdHex: String = TraceId.InvalidHex
val traceId: ByteVector = ByteVector.fromValidHex(traceIdHex)
val spanIdHex: String = SpanId.InvalidHex
val spanId: ByteVector = ByteVector.fromValidHex(spanIdHex)
val traceFlags: TraceFlags = TraceFlags.Default
val traceState: TraceState = TraceState.empty
val isValid: Boolean = false
val isRemote: Boolean = false
SpanContextImpl(
traceId = TraceId.Invalid,
traceIdHex = TraceId.Invalid.toHex,
spanId = SpanId.Invalid,
spanIdHex = SpanId.Invalid.toHex,
traceFlags = TraceFlags.Default,
traceState = TraceState.empty,
isRemote = false,
isValid = false
)

/** Creates a new [[SpanContext]] with the given identifiers and options.
*
* If the `traceId` or the `spanId` are invalid (ie. do not conform to the
* requirements for hexadecimal ids of the appropriate lengths), both will be
* replaced with the standard "invalid" versions (i.e. all '0's).
*
* @param traceId
* the trace identifier of the span context
*
* @param spanId
* the span identifier of the span context
*
* @param traceFlags
* the trace flags of the span context
*
* @param traceState
* the trace state of the span context
*
* @param remote
* whether the span is propagated from the remote parent or not
*/
def apply(
traceId: ByteVector,
spanId: ByteVector,
traceFlags: TraceFlags,
traceState: TraceState,
remote: Boolean
): SpanContext =
createInternal(
traceId = traceId,
spanId = spanId,
traceFlags = traceFlags,
traceState = traceState,
remote = remote,
skipIdValidation = false
)

implicit val spanContextHash: Hash[SpanContext] =
Hash.by { ctx =>
(
ctx.traceIdHex,
ctx.spanIdHex,
ctx.traceFlags,
ctx.traceState,
ctx.isValid,
ctx.isRemote
)
}

implicit val spanContextShow: Show[SpanContext] =
Show.show { ctx =>
show"SpanContext{" +
show"traceId=${ctx.traceIdHex}, " +
show"spanId=${ctx.spanIdHex}, " +
show"traceFlags=${ctx.traceFlags}, " +
show"traceState=${ctx.traceState}, " +
show"remote=${ctx.isRemote}, " +
show"valid=${ctx.isValid}}"
}

/** Creates a new [[SpanContext]] with the given identifiers and options.
*
* If the id validation isn't skipped and the `traceId` or the `spanId` are
* invalid (ie. do not conform to the requirements for hexadecimal ids of the
* appropriate lengths), both will be replaced with the standard "invalid"
* versions (i.e. all '0's).
*
* '''Note''': the method is for the internal use only. It is not supposed to
* be publicly available.
*
* @see
* [[apply]]
*
* @param traceId
* the trace identifier of the span context
*
* @param spanId
* the span identifier of the span context
*
* @param traceFlags
* the trace flags of the span context
*
* @param traceState
* the trace state of the span context
*
* @param remote
* whether the span is propagated from the remote parent or not
*
* @param skipIdValidation
* pass true to skip validation of trace ID and span ID as an optimization
* in cases where they are known to have been already validated
*/
private[otel4s] def createInternal(
traceId: ByteVector,
spanId: ByteVector,
traceFlags: TraceFlags,
traceState: TraceState,
remote: Boolean,
skipIdValidation: Boolean
): SpanContext = {
if (
skipIdValidation || (TraceId.isValid(traceId) && SpanId.isValid(spanId))
) {
SpanContextImpl(
traceId = traceId,
traceIdHex = traceId.toHex,
spanId = spanId,
spanIdHex = spanId.toHex,
traceFlags = traceFlags,
traceState = traceState,
isRemote = remote,
isValid = true
)
} else {
SpanContextImpl(
traceId = TraceId.Invalid,
traceIdHex = TraceId.Invalid.toHex,
spanId = SpanId.Invalid,
spanIdHex = SpanId.Invalid.toHex,
traceFlags = traceFlags,
traceState = traceState,
isRemote = remote,
isValid = false
)
}
}

private final case class SpanContextImpl(
traceId: ByteVector,
traceIdHex: String,
spanId: ByteVector,
spanIdHex: String,
traceFlags: TraceFlags,
traceState: TraceState,
isRemote: Boolean,
isValid: Boolean
) extends SpanContext

}
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,147 @@

package org.typelevel.otel4s.trace

import cats.Show
import cats.kernel.laws.discipline.HashTests
import cats.syntax.show._
import munit._
import org.scalacheck.Arbitrary
import org.scalacheck.Cogen
import org.scalacheck.Gen
import org.scalacheck.Prop
import org.typelevel.otel4s.trace.SpanContext.SpanId
import org.typelevel.otel4s.trace.SpanContext.TraceId
import scodec.bits.ByteVector

class SpanContextSuite extends FunSuite {
class SpanContextSuite extends DisciplineSuite {

private val traceIdGen: Gen[ByteVector] =
for {
hi <- Gen.long
lo <- Gen.long.suchThat(_ != 0)
} yield SpanContext.TraceId.fromLongs(hi, lo)

private val spanIdGen: Gen[ByteVector] =
for {
value <- Gen.long.suchThat(_ != 0)
} yield SpanContext.SpanId.fromLong(value)

private val spanContextGen: Gen[SpanContext] =
for {
traceId <- traceIdGen
spanId <- spanIdGen
traceFlags <- Gen.oneOf(TraceFlags.Sampled, TraceFlags.Default)
remote <- Gen.oneOf(true, false)
} yield SpanContext(traceId, spanId, traceFlags, TraceState.empty, remote)

private implicit val spanContextArbitrary: Arbitrary[SpanContext] =
Arbitrary(spanContextGen)

private implicit val spanContextCogen: Cogen[SpanContext] =
Cogen[(String, String, Byte, Map[String, String], Boolean, Boolean)]
.contramap(c =>
(
c.traceIdHex,
c.spanIdHex,
c.traceFlags.toByte,
c.traceState.asMap,
c.isValid,
c.isRemote
)
)

private val ValidTraceIdHex =
"00000000000000000000000000000061"

private val ValidTraceId =
ByteVector.fromValidHex(ValidTraceIdHex)

private val ValidSpanIdHex =
"0000000000000061"

private val ValidSpanId =
ByteVector.fromValidHex(ValidSpanIdHex)

checkAll("SpanContext.HashLaws", HashTests[SpanContext].hash)

test("invalid span context") {
assertEquals(SpanContext.invalid.traceId, TraceId.Invalid)
assertEquals(SpanContext.invalid.traceIdHex, TraceId.InvalidHex)
assertEquals(SpanContext.invalid.spanId, SpanId.Invalid)
assertEquals(SpanContext.invalid.spanIdHex, SpanId.InvalidHex)
assertEquals(SpanContext.invalid.traceFlags, TraceFlags.Default)
assertEquals(SpanContext.invalid.traceState, TraceState.empty)
assertEquals(SpanContext.invalid.isValid, false)
assertEquals(SpanContext.invalid.isRemote, false)
}

test("fallback to an invalid span context when trace id is invalid") {
val input = List(
ByteVector.empty,
ByteVector.fromValidHex("0000"), // too short
ByteVector.fromValidHex("0000000000000000000000000000006100"), // too long
)

def context(traceId: ByteVector) = SpanContext(
traceId = traceId,
spanId = ByteVector.fromValidHex("0000000000000061"),
traceFlags = TraceFlags.Default,
traceState = TraceState.empty,
remote = false
)

input.foreach { traceId =>
assertEquals(context(traceId), SpanContext.invalid)
}
}

test("fallback to an invalid span context when span id is invalid") {
val input = List(
ByteVector.empty,
ByteVector.fromValidHex("0000"), // too short
ByteVector.fromValidHex("000000000000006100"), // too long
)

def context(spanId: ByteVector) = SpanContext(
traceId = ValidTraceId,
spanId = spanId,
traceFlags = TraceFlags.Default,
traceState = TraceState.empty,
remote = false
)

input.foreach { traceId =>
assertEquals(context(traceId), SpanContext.invalid)
}
}

test("convert ByteVector traceId and spanId to hex") {
val context = SpanContext(
traceId = ValidTraceId,
spanId = ValidSpanId,
traceFlags = TraceFlags.Sampled,
traceState = TraceState.empty,
remote = false
)

assertEquals(context.traceId, ValidTraceId)
assertEquals(context.traceIdHex, ValidTraceIdHex)
assertEquals(context.spanId, ValidSpanId)
assertEquals(context.spanIdHex, ValidSpanIdHex)
}

test("Show[SpanContext]") {
Prop.forAll(spanContextGen) { ctx =>
val expected = show"SpanContext{" +
show"traceId=${ctx.traceIdHex}, " +
show"spanId=${ctx.spanIdHex}, " +
show"traceFlags=${ctx.traceFlags}, " +
show"traceState=${ctx.traceState}, " +
show"remote=${ctx.isRemote}, " +
show"valid=${ctx.isValid}}"

assertEquals(Show[SpanContext].show(ctx), expected)
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ private[java] class SpanBackendImpl[F[_]: Sync](

private[java] object SpanBackendImpl {
def fromJSpan[F[_]: Sync](jSpan: JSpan): SpanBackendImpl[F] =
new SpanBackendImpl(jSpan, WrappedSpanContext(jSpan.getSpanContext))
new SpanBackendImpl(jSpan, WrappedSpanContext.wrap(jSpan.getSpanContext))

private def toJStatus(status: Status): JStatusCode =
status match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ private[java] class TracerImpl[F[_]: Sync](
case Context.Noop => None
case Context.Wrapped(underlying) =>
Option(JSpan.fromContextOrNull(underlying))
.map(jSpan => new WrappedSpanContext(jSpan.getSpanContext))
.map(jSpan => WrappedSpanContext.wrap(jSpan.getSpanContext))
}

def currentSpanOrNoop: F[Span[F]] =
Expand Down
Loading

0 comments on commit 762a460

Please sign in to comment.