diff --git a/compiler/frontend/beam/src/main/java/org/apache/nemo/compiler/frontend/beam/PipelineTranslator.java b/compiler/frontend/beam/src/main/java/org/apache/nemo/compiler/frontend/beam/PipelineTranslator.java index 21ded1c268..cd9d7ad223 100644 --- a/compiler/frontend/beam/src/main/java/org/apache/nemo/compiler/frontend/beam/PipelineTranslator.java +++ b/compiler/frontend/beam/src/main/java/org/apache/nemo/compiler/frontend/beam/PipelineTranslator.java @@ -406,10 +406,11 @@ private static Pipeline.PipelineVisitor.CompositeBehavior combinePerKeyTranslato KvCoder.of(inputCoder.getKeyCoder(), accumulatorCoder), null, mainInput.getWindowingStrategy())); + final TupleTag partialMainOutputTag = new TupleTag<>(); final GBKTransform partialCombineStreamTransform = - new GBKTransform( - getOutputCoders(pTransform), - new TupleTag<>(), + new GBKTransform(inputCoder, + Collections.singletonMap(partialMainOutputTag, KvCoder.of(inputCoder.getKeyCoder(), accumulatorCoder)), + partialMainOutputTag, mainInput.getWindowingStrategy(), ctx.getPipelineOptions(), partialSystemReduceFn, @@ -418,9 +419,9 @@ private static Pipeline.PipelineVisitor.CompositeBehavior combinePerKeyTranslato true); final GBKTransform finalCombineStreamTransform = - new GBKTransform( + new GBKTransform(KvCoder.of(inputCoder.getKeyCoder(), accumulatorCoder), getOutputCoders(pTransform), - new TupleTag<>(), + Iterables.getOnlyElement(beamNode.getOutputs().keySet()), mainInput.getWindowingStrategy(), ctx.getPipelineOptions(), finalSystemReduceFn, @@ -556,7 +557,7 @@ private static Transform createGBKTransform( final AppliedPTransform pTransform = beamNode.toAppliedPTransform(ctx.getPipeline()); final PCollection mainInput = (PCollection) Iterables.getOnlyElement(TransformInputs.nonAdditionalInputs(pTransform)); - final TupleTag mainOutputTag = new TupleTag<>(); + final TupleTag mainOutputTag = Iterables.getOnlyElement(beamNode.getOutputs().keySet()); if (isGlobalWindow(beamNode, ctx.getPipeline())) { // GroupByKey Transform when using a global windowing strategy. @@ -564,6 +565,7 @@ private static Transform createGBKTransform( } else { // GroupByKey Transform when using a non-global windowing strategy. return new GBKTransform<>( + (KvCoder) mainInput.getCoder(), getOutputCoders(pTransform), mainOutputTag, mainInput.getWindowingStrategy(), diff --git a/compiler/frontend/beam/src/main/java/org/apache/nemo/compiler/frontend/beam/transform/GBKTransform.java b/compiler/frontend/beam/src/main/java/org/apache/nemo/compiler/frontend/beam/transform/GBKTransform.java index 9dd2e5a6e9..1bf6cb8d88 100644 --- a/compiler/frontend/beam/src/main/java/org/apache/nemo/compiler/frontend/beam/transform/GBKTransform.java +++ b/compiler/frontend/beam/src/main/java/org/apache/nemo/compiler/frontend/beam/transform/GBKTransform.java @@ -58,7 +58,8 @@ public final class GBKTransform private transient OutputCollector originOc; private final boolean isPartialCombining; - public GBKTransform(final Map, Coder> outputCoders, + public GBKTransform(final Coder> inputCoder, + final Map, Coder> outputCoders, final TupleTag> mainOutputTag, final WindowingStrategy windowingStrategy, final PipelineOptions options, @@ -67,7 +68,7 @@ public GBKTransform(final Map, Coder> outputCoders, final DisplayData displayData, final boolean isPartialCombining) { super(null, - null, + inputCoder, outputCoders, mainOutputTag, Collections.emptyList(), /* no additional outputs */ @@ -278,7 +279,7 @@ public GBKOutputCollector(final OutputCollector oc) { /** Emit output. If {@param output} is emitted on-time, save its timestamp in the output watermark map. */ @Override - public void emit(final WindowedValue> output) { + public final void emit(final WindowedValue> output) { // The watermark advances only in ON_TIME if (output.getPane().getTiming().equals(PaneInfo.Timing.ON_TIME)) { KV value = output.getValue(); @@ -296,13 +297,13 @@ public void emit(final WindowedValue> output) { /** Emit watermark. */ @Override - public void emitWatermark(final Watermark watermark) { + public final void emitWatermark(final Watermark watermark) { oc.emitWatermark(watermark); } /** Emit output value to {@param dstVertexId}. */ @Override - public void emit(final String dstVertexId, final T output) { + public final void emit(final String dstVertexId, final T output) { oc.emit(dstVertexId, output); } } diff --git a/compiler/test/src/test/java/org/apache/nemo/compiler/frontend/beam/transform/GBKTransformTest.java b/compiler/test/src/test/java/org/apache/nemo/compiler/frontend/beam/transform/GBKTransformTest.java index 45933b0d37..3c08c50fb2 100644 --- a/compiler/test/src/test/java/org/apache/nemo/compiler/frontend/beam/transform/GBKTransformTest.java +++ b/compiler/test/src/test/java/org/apache/nemo/compiler/frontend/beam/transform/GBKTransformTest.java @@ -18,6 +18,7 @@ */ package org.apache.nemo.compiler.frontend.beam.transform; +import com.google.common.collect.Iterables; import junit.framework.TestCase; import org.apache.beam.runners.core.SystemReduceFn; import org.apache.beam.sdk.coders.*; @@ -41,15 +42,12 @@ import static org.apache.beam.sdk.transforms.windowing.PaneInfo.Timing.*; import static org.apache.beam.sdk.values.WindowingStrategy.AccumulationMode.ACCUMULATING_FIRED_PANES; -import static org.apache.beam.sdk.values.WindowingStrategy.AccumulationMode.DISCARDING_FIRED_PANES; -import static org.junit.Assert.assertEquals; import static org.mockito.Mockito.mock; public class GBKTransformTest extends TestCase { private static final Logger LOG = LoggerFactory.getLogger(GBKTransformTest.class.getName()); private final static Coder STRING_CODER = StringUtf8Coder.of(); private final static Coder INTEGER_CODER = BigEndianIntegerCoder.of(); - private final static Map, Coder> NULL_OUTPUT_CODERS = null; private void checkOutput(final KV expected, final KV result) { // check key @@ -155,7 +153,8 @@ public void test_combine() { final GBKTransform combine_transform = new GBKTransform( - NULL_OUTPUT_CODERS, + KvCoder.of(STRING_CODER, INTEGER_CODER), + Collections.singletonMap(outputTag, KvCoder.of(STRING_CODER, INTEGER_CODER)), outputTag, WindowingStrategy.of(slidingWindows).withMode(ACCUMULATING_FIRED_PANES), PipelineOptionsFactory.as(NemoPipelineOptions.class), @@ -283,7 +282,8 @@ public void test_combine_lateData() { final GBKTransform combine_transform = new GBKTransform( - NULL_OUTPUT_CODERS, + KvCoder.of(STRING_CODER, INTEGER_CODER), + Collections.singletonMap(outputTag, KvCoder.of(STRING_CODER, INTEGER_CODER)), outputTag, WindowingStrategy.of(slidingWindows).withMode(ACCUMULATING_FIRED_PANES).withAllowedLateness(lateness), PipelineOptionsFactory.as(NemoPipelineOptions.class), @@ -377,7 +377,8 @@ public void test_gbk() { final GBKTransform> doFnTransform = new GBKTransform( - NULL_OUTPUT_CODERS, + KvCoder.of(STRING_CODER, STRING_CODER), + Collections.singletonMap(outputTag, KvCoder.of(STRING_CODER, IterableCoder.of(STRING_CODER))), outputTag, WindowingStrategy.of(slidingWindows), PipelineOptionsFactory.as(NemoPipelineOptions.class), @@ -562,7 +563,8 @@ public void test_gbk_eventTimeTrigger() { final GBKTransform> doFnTransform = new GBKTransform( - NULL_OUTPUT_CODERS, + KvCoder.of(STRING_CODER, STRING_CODER), + Collections.singletonMap(outputTag, KvCoder.of(STRING_CODER, IterableCoder.of(STRING_CODER))), outputTag, WindowingStrategy.of(window).withTrigger(trigger) .withMode(ACCUMULATING_FIRED_PANES)