Consistent implementation for reduce operation
RustedBones committed Sep 19, 2024
1 parent 365a1a3 commit 84bfe56
Showing 6 changed files with 45 additions and 25 deletions.
PairSCollectionFunctions.scala
@@ -29,9 +29,10 @@ import com.spotify.scio.hash._
 import com.spotify.scio.util._
 import com.spotify.scio.util.random.{BernoulliValueSampler, PoissonValueSampler}
 import com.twitter.algebird.{Aggregator, Monoid, MonoidAggregator, Semigroup}
+import org.apache.beam.sdk.transforms.DoFn.{Element, OutputReceiver, ProcessElement, Timestamp}
 import org.apache.beam.sdk.transforms._
 import org.apache.beam.sdk.values.{KV, PCollection}
-import org.joda.time.Duration
+import org.joda.time.{Duration, Instant}
 import org.slf4j.LoggerFactory
 
 import scala.collection.compat._
@@ -718,6 +719,23 @@ class PairSCollectionFunctions[K, V](val self: SCollection[(K, V)]) {
   def distinctByKey: SCollection[(K, V)] =
     self.distinctBy(_._1)
 
+  /**
+   * Convert values into pairs of (value, timestamp).
+   * @group transform
+   */
+  def withTimestampedValues: SCollection[(K, (V, Instant))] =
+    self.parDo(new DoFn[(K, V), (K, (V, Instant))] {
+      @ProcessElement
+      private[scio] def processElement(
+        @Element element: (K, V),
+        @Timestamp timestamp: Instant,
+        out: OutputReceiver[(K, (V, Instant))]
+      ): Unit = {
+        val (k, v) = element
+        out.output((k, (v, timestamp)))
+      }
+    })
+
   /**
    * Return a new SCollection of (key, value) pairs whose values satisfy the predicate.
   * @group transform
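
A minimal usage sketch of the new method (a hypothetical job, not part of this commit; the object and element names are made up):

import com.spotify.scio._
import org.joda.time.Instant

object WithTimestampedValuesExample {
  def main(cmdlineArgs: Array[String]): Unit = {
    val (sc, _) = ContextAndArgs(cmdlineArgs)

    sc.parallelizeTimestamped(
      Seq(("user1", 10), ("user2", 20)),
      Seq(new Instant(1000L), new Instant(2000L))
    )
      .withTimestampedValues // (K, V) => (K, (V, Instant)), reading Beam's element timestamp
      .map { case (k, (v, ts)) => s"$k=$v@${ts.getMillis}" }
      .debug() // user1=10@1000, user2=20@2000

    sc.run()
  }
}

Surfacing the timestamp inside the value is what lets the reworked latestByKey below reuse an ordinary maxByKey reduce instead of Beam's Latest combine.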
SCollection.scala
@@ -793,7 +793,7 @@ sealed trait SCollection[T] extends PCollectionWrapper[T] {
     val e = ev // defeat closure
     in.map(e.toDouble)
       .asInstanceOf[SCollection[JDouble]]
-      .pApply(Mean.globally())
+      .pApply(Mean.globally().withoutDefaults())
       .asInstanceOf[SCollection[Double]]
   }
 
@@ -808,14 +808,13 @@ sealed trait SCollection[T] extends PCollectionWrapper[T] {
     this.reduce(ord.min)
 
   /**
-   * Return the latest of this SCollection according to its event time, or null if there are no
-   * elements.
+   * Return the latest of this SCollection according to its event time.
    * @return
    *   a new SCollection with the latest element
    * @group transform
   */
  def latest: SCollection[T] =
-    this.pApply(Latest.globally())
+    this.withTimestamp.max(Ordering.by(_._2)).keys
 
  /**
   * Compute the SCollection's data distribution using approximate `N`-tiles.
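
The point of both changes in this file: Beam's Combine.globally-based transforms emit a default output for an empty input in the global window, which is why the old scaladoc warned that latest could return null. With Mean.globally().withoutDefaults() and a latest built on withTimestamp/max, an empty SCollection now produces an empty result, consistent with max and min, which already go through reduce. A sketch of the resulting behavior (hypothetical job, names made up):

import com.spotify.scio._
import org.joda.time.Instant

object LatestSemanticsExample {
  def main(cmdlineArgs: Array[String]): Unit = {
    val (sc, _) = ContextAndArgs(cmdlineArgs)

    // latest = withTimestamp.max(Ordering.by(_._2)).keys: pair each element with
    // its event timestamp, keep the pair with the max timestamp, drop the timestamp.
    sc.parallelize(Seq(1L, 2L, 3L))
      .timestampBy(Instant.ofEpochMilli)
      .latest
      .debug() // 3

    // Empty input now yields an empty SCollection instead of a default/null.
    sc.parallelize(Seq.empty[Long])
      .timestampBy(Instant.ofEpochMilli)
      .latest
      .debug() // prints nothing

    sc.run()
  }
}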
SCollectionWithFanout.scala
@@ -21,7 +21,7 @@ import com.spotify.scio.ScioContext
 import com.spotify.scio.util.Functions
 import com.spotify.scio.coders.Coder
 import com.twitter.algebird.{Aggregator, Monoid, MonoidAggregator, Semigroup}
-import org.apache.beam.sdk.transforms.{Combine, Latest, Mean, Reify, Top}
+import org.apache.beam.sdk.transforms.{Combine, Mean, Top}
 import org.apache.beam.sdk.values.PCollection
 
 import java.lang.{Double => JDouble, Iterable => JIterable}
@@ -136,8 +136,7 @@ class SCollectionWithFanout[T] private[values] (coll: SCollection[T], fanout: Int)
   /** [[SCollection.latest]] with fan out. */
   def latest: SCollection[T] = {
     coll.transform { in =>
-      in.pApply("Reify Timestamps", Reify.timestamps[T]())
-        .pApply("Latest Value", Combine.globally(Latest.combineFn[T]()).withFanout(fanout))
+      new SCollectionWithFanout(in.withTimestamp, this.fanout).max(Ordering.by(_._2)).keys
     }
   }
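
The fan-out variant now goes through the same timestamp-max reduce, wrapping in.withTimestamp back into a SCollectionWithFanout so max still runs with the configured fanout. A hedged sketch of exercising it, in the same PipelineSpec style as the tests further down (this test is not part of the commit):

  it should "support latest with fanout" in {
    runWithContext { sc =>
      val p = sc.parallelize(Seq(1L, 2L, 3L)).timestampBy(Instant.ofEpochMilli)
      // withFanout(10) pre-aggregates on 10 intermediate shards before the final merge
      p.withFanout(10).latest should containSingleValue(3L)
    }
  }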
SCollectionWithHotKeyFanout.scala
@@ -24,15 +24,7 @@ import com.spotify.scio.util.TupleFunctions._
 import com.twitter.algebird.{Aggregator, Monoid, MonoidAggregator, Semigroup}
 import org.apache.beam.sdk.transforms.Combine.PerKeyWithHotKeyFanout
 import org.apache.beam.sdk.transforms.Top.TopCombineFn
-import org.apache.beam.sdk.transforms.{
-  Combine,
-  Latest,
-  Mean,
-  PTransform,
-  Reify,
-  SerializableFunction
-}
-import org.apache.beam.sdk.values.{KV, PCollection}
+import org.apache.beam.sdk.transforms.{Combine, Mean, SerializableFunction}
 
 import java.lang.{Double => JDouble}
 
@@ -170,14 +162,11 @@ class SCollectionWithHotKeyFanout[K, V] private[values] (
 
   /** [[SCollection.latest]] with hot key fan out. */
   def latestByKey: SCollection[(K, V)] = {
-    self.applyPerKey(new PTransform[PCollection[KV[K, V]], PCollection[KV[K, V]]]() {
-      override def expand(input: PCollection[KV[K, V]]): PCollection[KV[K, V]] = {
-        input
-          .apply("Reify Timestamps", Reify.timestampsInValue[K, V])
-          .apply("Latest Value", withFanout(Combine.perKey(Latest.combineFn[V]())))
-          .setCoder(input.getCoder)
-      }
-    })(kvToTuple)
+    self.self.transform { in =>
+      new SCollectionWithHotKeyFanout(in.withTimestampedValues, this.hotKeyFanout)
+        .maxByKey(Ordering.by(_._2))
+        .mapValues(_._1)
+    }
   }
 
   /** [[PairSCollectionFunctions.topByKey]] with hot key fanout. */
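
Same pattern for the hot-key case: withTimestampedValues moves the timestamp into the value, maxByKey picks the latest pair per key under the configured fanout, and mapValues drops the timestamp. A hedged sketch of a matching test (not part of the commit):

  it should "support latestByKey with hot key fanout" in {
    runWithContext { sc =>
      val p = sc.parallelizeTimestamped(
        Seq(("hot", 1), ("hot", 2), ("cold", 3)),
        Seq(1L, 2L, 3L).map(new Instant(_))
      )
      // each hot key is spread over 16 intermediate keys before the final per-key merge
      p.withHotKeyFanout(16).latestByKey should containInAnyOrder(Seq(("hot", 2), ("cold", 3)))
    }
  }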
PairSCollectionFunctionsTest.scala
@@ -492,6 +492,17 @@ class PairSCollectionFunctionsTest extends PipelineSpec {
     }
   }
 
+  it should "support withTimestampedValues" in {
+    runWithContext { sc =>
+      val p = sc.parallelizeTimestamped(
+        Seq(("a", 1), ("b", 2), ("c", 3)),
+        Seq(1L, 2L, 3L).map(new Instant(_))
+      )
+      val r = p.withTimestampedValues.map { case (k, (v, ts)) => (k, v, ts.getMillis) }
+      r should containInAnyOrder(Seq(("a", 1, 1L), ("b", 2, 2L), ("c", 3, 3L)))
+    }
+  }
+
   it should "support filterValues()" in {
     runWithContext { sc =>
       val p = sc
SCollectionTest.scala
@@ -498,6 +498,7 @@ class SCollectionTest extends PipelineSpec {
     runWithContext { sc =>
       def max[T: Coder: Numeric](elems: T*): SCollection[T] =
         sc.parallelize(elems).max
+      max[Int]() should beEmpty
       max(1, 2, 3) should containSingleValue(3)
       max(1L, 2L, 3L) should containSingleValue(3L)
       max(1f, 2f, 3f) should containSingleValue(3f)
@@ -509,6 +510,7 @@ class SCollectionTest extends PipelineSpec {
     runWithContext { sc =>
       def mean[T: Coder: Numeric](elems: T*): SCollection[Double] =
         sc.parallelize(elems).mean
+      mean[Int]() should beEmpty
       mean(1, 2, 3) should containSingleValue(2.0)
       mean(1L, 2L, 3L) should containSingleValue(2.0)
       mean(1f, 2f, 3f) should containSingleValue(2.0)
@@ -520,6 +522,7 @@ class SCollectionTest extends PipelineSpec {
     runWithContext { sc =>
       def min[T: Coder: Numeric](elems: T*): SCollection[T] =
         sc.parallelize(elems).min
+      min[Int]() should beEmpty
       min(1, 2, 3) should containSingleValue(1)
       min(1L, 2L, 3L) should containSingleValue(1L)
       min(1f, 2f, 3f) should containSingleValue(1f)
@@ -531,6 +534,7 @@ class SCollectionTest extends PipelineSpec {
     runWithContext { sc =>
       def latest(elems: Long*): SCollection[Long] =
         sc.parallelize(elems).timestampBy(Instant.ofEpochMilli).latest
+      latest() should beEmpty
       latest(1L, 2L, 3L) should containSingleValue(3L)
     }
   }
