diff --git a/.github/workflows/beam_PreCommit_Java.yml b/.github/workflows/beam_PreCommit_Java.yml
index 6f4669e0c784..21898e5f758b 100644
--- a/.github/workflows/beam_PreCommit_Java.yml
+++ b/.github/workflows/beam_PreCommit_Java.yml
@@ -161,6 +161,7 @@ jobs:
       matrix:
         job_name: [beam_PreCommit_Java]
         job_phrase: [Run Java PreCommit]
+    timeout-minutes: 120
     if: |
       github.event_name == 'push' ||
       github.event_name == 'pull_request_target' ||
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/source/FlinkSourceReaderBase.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/source/FlinkSourceReaderBase.java
index f0b93e0dde0f..ce4404f8ce9a 100644
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/source/FlinkSourceReaderBase.java
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/source/FlinkSourceReaderBase.java
@@ -17,8 +17,6 @@
  */
 package org.apache.beam.runners.flink.translation.wrappers.streaming.io.source;
 
-import java.io.ByteArrayInputStream;
-import java.io.ByteArrayOutputStream;
 import java.io.IOException;
 import java.util.ArrayDeque;
 import java.util.ArrayList;
@@ -42,7 +40,6 @@
 import org.apache.beam.runners.flink.metrics.FlinkMetricContainerWithoutAccumulator;
 import org.apache.beam.runners.flink.metrics.ReaderInvocationUtil;
 import org.apache.beam.runners.flink.translation.wrappers.streaming.io.source.compat.FlinkSourceCompat;
-import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.io.BoundedSource;
 import org.apache.beam.sdk.io.Source;
 import org.apache.beam.sdk.io.UnboundedSource;
@@ -149,19 +146,11 @@ public List<FlinkSourceSplit<T>> snapshotState(long checkpointId) {
     // Add all the source splits being actively read.
     beamSourceReaders.forEach(
         (splitId, readerAndOutput) -> {
-          Source.Reader<T> reader = readerAndOutput.reader;
-          if (reader instanceof BoundedSource.BoundedReader) {
-            // Sometimes users may decide to run a bounded source in streaming mode as "finite
-            // stream."
-            // For bounded source, the checkpoint granularity is the entire source split.
-            // So, in case of failure, all the data from this split will be consumed again.
-            splitsState.add(new FlinkSourceSplit<>(splitId, reader.getCurrentSource()));
-          } else if (reader instanceof UnboundedSource.UnboundedReader) {
-            // The checkpoint for unbounded sources is fine granular.
-            byte[] checkpointState =
-                getAndEncodeCheckpointMark((UnboundedSource.UnboundedReader<T>) reader);
-            splitsState.add(
-                new FlinkSourceSplit<>(splitId, reader.getCurrentSource(), checkpointState));
+          try {
+            splitsState.add(getReaderCheckpoint(splitId, readerAndOutput));
+          } catch (IOException e) {
+            throw new IllegalStateException(
+                String.format("Failed to get checkpoint for split %d", splitId), e);
           }
         });
     return splitsState;
@@ -228,9 +217,17 @@ public void close() throws Exception {
    */
   protected abstract CompletableFuture<Void> isAvailableForAliveReaders();
 
+  /** Create a {@link FlinkSourceSplit} for the given {@code splitId}. */
+  protected abstract FlinkSourceSplit<T> getReaderCheckpoint(
+      int splitId, ReaderAndOutput readerAndOutput) throws IOException;
+
+  /** Create a {@link Source.Reader} for the given {@link FlinkSourceSplit}. */
+  protected abstract Source.Reader<T> createReader(@Nonnull FlinkSourceSplit<T> sourceSplit)
+      throws IOException;
+
   // ----------------- protected helper methods for subclasses --------------------
 
-  protected Optional<ReaderAndOutput> createAndTrackNextReader() throws IOException {
+  protected final Optional<ReaderAndOutput> createAndTrackNextReader() throws IOException {
     FlinkSourceSplit<T> sourceSplit = sourceSplits.poll();
     if (sourceSplit != null) {
       Source.Reader<T> reader = createReader(sourceSplit);
@@ -241,7 +238,7 @@ protected Optional<ReaderAndOutput> createAndTrackNextReader() throws IOExceptio
     return Optional.empty();
   }
 
-  protected void finishSplit(int splitIndex) throws IOException {
+  protected final void finishSplit(int splitIndex) throws IOException {
     ReaderAndOutput readerAndOutput = beamSourceReaders.remove(splitIndex);
     if (readerAndOutput != null) {
       LOG.info("Finished reading from split {}", readerAndOutput.splitId);
@@ -252,7 +249,7 @@ protected void finishSplit(int splitIndex) throws IOException {
     }
   }
 
-  protected boolean checkIdleTimeoutAndMaybeStartCountdown() {
+  protected final boolean checkIdleTimeoutAndMaybeStartCountdown() {
     if (idleTimeoutMs <= 0) {
       idleTimeoutFuture.complete(null);
     } else if (!idleTimeoutCountingDown) {
@@ -262,7 +259,7 @@ protected boolean checkIdleTimeoutAndMaybeStartCountdown() {
     return idleTimeoutFuture.isDone();
   }
 
-  protected boolean noMoreSplits() {
+  protected final boolean noMoreSplits() {
     return noMoreSplits;
   }
 
@@ -308,49 +305,6 @@ protected Map<Integer, ReaderAndOutput> allReaders() {
   protected static void ignoreReturnValue(Object o) {
     // do nothing.
   }
-  // ------------------------------ private methods ------------------------------
-
-  @SuppressWarnings("unchecked")
-  private <CheckpointMarkT extends UnboundedSource.CheckpointMark>
-      byte[] getAndEncodeCheckpointMark(UnboundedSource.UnboundedReader<T> reader) {
-    UnboundedSource<T, CheckpointMarkT> source =
-        (UnboundedSource<T, CheckpointMarkT>) reader.getCurrentSource();
-    CheckpointMarkT checkpointMark = (CheckpointMarkT) reader.getCheckpointMark();
-    Coder<CheckpointMarkT> coder = source.getCheckpointMarkCoder();
-    try (ByteArrayOutputStream baos = new ByteArrayOutputStream()) {
-      coder.encode(checkpointMark, baos);
-      return baos.toByteArray();
-    } catch (IOException ioe) {
-      throw new RuntimeException("Failed to encode checkpoint mark.", ioe);
-    }
-  }
-
-  private Source.Reader<T> createReader(@Nonnull FlinkSourceSplit<T> sourceSplit)
-      throws IOException {
-    Source<T> beamSource = sourceSplit.getBeamSplitSource();
-    if (beamSource instanceof BoundedSource) {
-      return ((BoundedSource<T>) beamSource).createReader(pipelineOptions);
-    } else if (beamSource instanceof UnboundedSource) {
-      return createUnboundedSourceReader(beamSource, sourceSplit.getSplitState());
-    } else {
-      throw new IllegalStateException("Unknown source type " + beamSource.getClass());
-    }
-  }
-
-  private <CheckpointMarkT extends UnboundedSource.CheckpointMark>
-      Source.Reader<T> createUnboundedSourceReader(
-          Source<T> beamSource, @Nullable byte[] splitState) throws IOException {
-    UnboundedSource<T, CheckpointMarkT> unboundedSource =
-        (UnboundedSource<T, CheckpointMarkT>) beamSource;
-    Coder<CheckpointMarkT> coder = unboundedSource.getCheckpointMarkCoder();
-    if (splitState == null) {
-      return unboundedSource.createReader(pipelineOptions, null);
-    } else {
-      try (ByteArrayInputStream bais = new ByteArrayInputStream(splitState)) {
-        return unboundedSource.createReader(pipelineOptions, coder.decode(bais));
-      }
-    }
-  }
 
   // -------------------- protected helper class ---------------------
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/source/FlinkSourceSplitEnumerator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/source/FlinkSourceSplitEnumerator.java
index 292697479bcd..8ceab393533d 100644
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/source/FlinkSourceSplitEnumerator.java
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/source/FlinkSourceSplitEnumerator.java
@@ -121,7 +121,6 @@ public void addReader(int subtaskId) {
     List<FlinkSourceSplit<T>> splitsForSubtask = pendingSplits.remove(subtaskId);
     if (splitsForSubtask != null) {
       assignSplitsAndLog(splitsForSubtask, subtaskId);
-      pendingSplits.remove(subtaskId);
     } else {
       if (splitsInitialized) {
         LOG.info("There is no split for subtask {}. Signaling no more splits.", subtaskId);
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/source/bounded/FlinkBoundedSourceReader.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/source/bounded/FlinkBoundedSourceReader.java
index b015b527aa45..a25964af809d 100644
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/source/bounded/FlinkBoundedSourceReader.java
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/source/bounded/FlinkBoundedSourceReader.java
@@ -18,18 +18,29 @@
 package org.apache.beam.runners.flink.translation.wrappers.streaming.io.source.bounded;
 
 import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
 import java.util.Optional;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ScheduledExecutorService;
 import java.util.function.Function;
+import javax.annotation.Nonnull;
 import javax.annotation.Nullable;
 import org.apache.beam.runners.flink.translation.wrappers.streaming.io.source.FlinkSourceReaderBase;
+import org.apache.beam.runners.flink.translation.wrappers.streaming.io.source.FlinkSourceSplit;
+import org.apache.beam.sdk.coders.CoderException;
+import org.apache.beam.sdk.coders.VarLongCoder;
+import org.apache.beam.sdk.io.BoundedSource;
 import org.apache.beam.sdk.io.Source;
 import org.apache.beam.sdk.options.PipelineOptions;
 import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
 import org.apache.beam.sdk.transforms.windowing.PaneInfo;
+import org.apache.beam.sdk.util.CoderUtils;
+import org.apache.beam.sdk.util.Preconditions;
 import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.MoreObjects;
+import org.apache.flink.api.common.eventtime.Watermark;
 import org.apache.flink.api.connector.source.ReaderOutput;
 import org.apache.flink.api.connector.source.SourceReaderContext;
 import org.apache.flink.core.io.InputStatus;
@@ -50,6 +61,8 @@
  */
 public class FlinkBoundedSourceReader<T> extends FlinkSourceReaderBase<T, WindowedValue<T>> {
   private static final Logger LOG = LoggerFactory.getLogger(FlinkBoundedSourceReader.class);
+  private static final VarLongCoder LONG_CODER = VarLongCoder.of();
+  private final Map<Integer, Long> consumedFromSplit = new HashMap<>();
   private @Nullable Source.Reader<T> currentReader;
   private int currentSplitId;
@@ -62,6 +75,40 @@ public FlinkBoundedSourceReader(
     currentSplitId = -1;
   }
 
+  @Override
+  protected FlinkSourceSplit<T> getReaderCheckpoint(int splitId, ReaderAndOutput readerAndOutput)
+      throws CoderException {
+    // Sometimes users may decide to run a bounded source in streaming mode as a
+    // "finite stream."
+    // For a bounded source, the checkpoint granularity is the entire source split.
+    // So, in case of failure, all the data from this split will be consumed again.
+    return new FlinkSourceSplit<>(
+        splitId, readerAndOutput.reader.getCurrentSource(), asBytes(consumedFromSplit(splitId)));
+  }
+
+  @Override
+  protected Source.Reader<T> createReader(@Nonnull FlinkSourceSplit<T> sourceSplit)
+      throws IOException {
+    Source<T> beamSource = sourceSplit.getBeamSplitSource();
+    byte[] state = sourceSplit.getSplitState();
+    if (state != null) {
+      consumedFromSplit.put(Integer.parseInt(sourceSplit.splitId()), fromBytes(state));
+    }
+    return ((BoundedSource<T>) beamSource).createReader(pipelineOptions);
+  }
+
+  private byte[] asBytes(long l) throws CoderException {
+    return CoderUtils.encodeToByteArray(LONG_CODER, l);
+  }
+
+  private long fromBytes(byte[] b) throws CoderException {
+    return CoderUtils.decodeFromByteArray(LONG_CODER, b);
+  }
+
+  private long consumedFromSplit(int splitId) {
+    return consumedFromSplit.getOrDefault(splitId, 0L);
+  }
+
   @VisibleForTesting
   protected FlinkBoundedSourceReader(
       String stepName,
@@ -78,26 +125,28 @@ public InputStatus pollNext(ReaderOutput<WindowedValue<T>> output) throws Except
     checkExceptionAndMaybeThrow();
     if (currentReader == null && !moveToNextNonEmptyReader()) {
       // Nothing to read for now.
-      if (noMoreSplits() && checkIdleTimeoutAndMaybeStartCountdown()) {
-        // All the source splits have been read and idle timeout has passed.
-        LOG.info(
-            "All splits have finished reading, and idle time {} ms has passed.", idleTimeoutMs);
-        return InputStatus.END_OF_INPUT;
-      } else {
-        // This reader either hasn't received NoMoreSplitsEvent yet or it is waiting for idle
-        // timeout.
-        return InputStatus.NOTHING_AVAILABLE;
+      if (noMoreSplits()) {
+        output.emitWatermark(Watermark.MAX_WATERMARK);
+        if (checkIdleTimeoutAndMaybeStartCountdown()) {
+          // All the source splits have been read and idle timeout has passed.
+          LOG.info(
+              "All splits have finished reading, and idle time {} ms has passed.", idleTimeoutMs);
+          return InputStatus.END_OF_INPUT;
+        }
       }
+      // This reader either hasn't received NoMoreSplitsEvent yet or it is waiting for idle
+      // timeout.
+      return InputStatus.NOTHING_AVAILABLE;
     }
-    Source.Reader<T> tempCurrentReader = currentReader;
-    if (tempCurrentReader != null) {
-      T record = tempCurrentReader.getCurrent();
+    if (currentReader != null) {
+      // make null checks happy
+      final @Nonnull Source.Reader<T> splitReader = currentReader;
+      // store the number of processed elements from this split
+      consumedFromSplit.compute(currentSplitId, (k, v) -> v == null ? 1 : v + 1);
+      T record = splitReader.getCurrent();
       WindowedValue<T> windowedValue =
           WindowedValue.of(
-              record,
-              tempCurrentReader.getCurrentTimestamp(),
-              GlobalWindow.INSTANCE,
-              PaneInfo.NO_FIRING);
+              record, splitReader.getCurrentTimestamp(), GlobalWindow.INSTANCE, PaneInfo.NO_FIRING);
       if (timestampExtractor == null) {
         output.collect(windowedValue);
       } else {
@@ -107,11 +156,12 @@ public InputStatus pollNext(ReaderOutput<WindowedValue<T>> output) throws Except
       // If the advance() invocation throws exception here, the job will just fail over and read
      // everything again from
      // the beginning. So the failover granularity is the entire Flink job.
-      if (!invocationUtil.invokeAdvance(tempCurrentReader)) {
+      if (!invocationUtil.invokeAdvance(splitReader)) {
         finishSplit(currentSplitId);
+        consumedFromSplit.remove(currentSplitId);
+        LOG.debug("Finished reading from {}", currentSplitId);
         currentReader = null;
         currentSplitId = -1;
-        LOG.debug("Finished reading from {}", currentSplitId);
       }
       // Always return MORE_AVAILABLE here regardless of the availability of next record. If there
       // is no more
@@ -138,6 +188,12 @@ private boolean moveToNextNonEmptyReader() throws IOException {
     if (invocationUtil.invokeStart(rao.reader)) {
       currentSplitId = Integer.parseInt(rao.splitId);
       currentReader = rao.reader;
+      long toSkipAfterStart =
+          MoreObjects.firstNonNull(consumedFromSplit.remove(currentSplitId), 0L);
+      @Nonnull Source.Reader<T> reader = Preconditions.checkArgumentNotNull(currentReader);
+      while (toSkipAfterStart > 0 && reader.advance()) {
+        toSkipAfterStart--;
+      }
       return true;
     } else {
       finishSplit(Integer.parseInt(rao.splitId));
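The bounded-reader change above checkpoints nothing but a consumed-element count (VarLong-encoded), and on restore re-opens the split from the beginning and advances past the already-emitted prefix. A minimal self-contained Go sketch of that count-and-skip strategy; countingReader, checkpoint, and restore are illustrative names, not Beam APIs:

package main

import "fmt"

// countingReader wraps an element stream and tracks how many
// elements were emitted, mirroring consumedFromSplit above.
type countingReader struct {
	elems    []int
	pos      int
	consumed int64
}

// next returns the next element and counts it as consumed.
func (r *countingReader) next() (int, bool) {
	if r.pos >= len(r.elems) {
		return 0, false
	}
	v := r.elems[r.pos]
	r.pos++
	r.consumed++
	return v, true
}

// checkpoint persists only the consumed count, not a reader position.
func (r *countingReader) checkpoint() int64 { return r.consumed }

// restore re-opens the split from the start and skips the
// already-consumed prefix, like the skip loop in moveToNextNonEmptyReader.
func restore(elems []int, consumed int64) *countingReader {
	r := &countingReader{elems: elems}
	for i := int64(0); i < consumed; i++ {
		if _, ok := r.next(); !ok {
			break
		}
	}
	return r
}

func main() {
	split := []int{10, 20, 30, 40}
	r := &countingReader{elems: split}
	r.next() // emit 10
	r.next() // emit 20
	cp := r.checkpoint()

	// After a failover, only 30 and 40 are emitted again.
	r2 := restore(split, cp)
	for v, ok := r2.next(); ok; v, ok = r2.next() {
		fmt.Println(v)
	}
}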
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/source/unbounded/FlinkUnboundedSourceReader.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/source/unbounded/FlinkUnboundedSourceReader.java
index 0a7acb669efd..7b02702e244c 100644
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/source/unbounded/FlinkUnboundedSourceReader.java
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/source/unbounded/FlinkUnboundedSourceReader.java
@@ -17,6 +17,8 @@
  */
 package org.apache.beam.runners.flink.translation.wrappers.streaming.io.source.unbounded;
 
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.List;
@@ -25,9 +27,12 @@
 import java.util.concurrent.ScheduledExecutorService;
 import java.util.concurrent.atomic.AtomicReference;
 import java.util.function.Function;
+import javax.annotation.Nonnull;
 import javax.annotation.Nullable;
 import org.apache.beam.runners.flink.FlinkPipelineOptions;
 import org.apache.beam.runners.flink.translation.wrappers.streaming.io.source.FlinkSourceReaderBase;
+import org.apache.beam.runners.flink.translation.wrappers.streaming.io.source.FlinkSourceSplit;
+import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.io.Source;
 import org.apache.beam.sdk.io.UnboundedSource;
 import org.apache.beam.sdk.options.PipelineOptions;
@@ -179,6 +184,22 @@ protected CompletableFuture<Void> isAvailableForAliveReaders() {
     }
   }
 
+  @Override
+  protected FlinkSourceSplit<T> getReaderCheckpoint(int splitId, ReaderAndOutput readerAndOutput) {
+    // The checkpoint for unbounded sources is fine-grained.
+    byte[] checkpointState =
+        getAndEncodeCheckpointMark((UnboundedSource.UnboundedReader<T>) readerAndOutput.reader);
+    return new FlinkSourceSplit<>(
+        splitId, readerAndOutput.reader.getCurrentSource(), checkpointState);
+  }
+
+  @Override
+  protected Source.Reader<T> createReader(@Nonnull FlinkSourceSplit<T> sourceSplit)
+      throws IOException {
+    Source<T> beamSource = sourceSplit.getBeamSplitSource();
+    return createUnboundedSourceReader(beamSource, sourceSplit.getSplitState());
+  }
+
   // -------------- private helper methods ----------------
 
   private void emitRecord(
@@ -274,4 +295,34 @@ private void createPendingBytesGauge(SourceReaderContext context) {
           return pendingBytes;
         });
   }
+
+  @SuppressWarnings("unchecked")
+  private <CheckpointMarkT extends UnboundedSource.CheckpointMark>
+      byte[] getAndEncodeCheckpointMark(UnboundedSource.UnboundedReader<T> reader) {
+    UnboundedSource<T, CheckpointMarkT> source =
+        (UnboundedSource<T, CheckpointMarkT>) reader.getCurrentSource();
+    CheckpointMarkT checkpointMark = (CheckpointMarkT) reader.getCheckpointMark();
+    Coder<CheckpointMarkT> coder = source.getCheckpointMarkCoder();
+    try (ByteArrayOutputStream baos = new ByteArrayOutputStream()) {
+      coder.encode(checkpointMark, baos);
+      return baos.toByteArray();
+    } catch (IOException ioe) {
+      throw new RuntimeException("Failed to encode checkpoint mark.", ioe);
+    }
+  }
+
+  private <CheckpointMarkT extends UnboundedSource.CheckpointMark>
+      Source.Reader<T> createUnboundedSourceReader(
+          Source<T> beamSource, @Nullable byte[] splitState) throws IOException {
+    UnboundedSource<T, CheckpointMarkT> unboundedSource =
+        (UnboundedSource<T, CheckpointMarkT>) beamSource;
+    Coder<CheckpointMarkT> coder = unboundedSource.getCheckpointMarkCoder();
+    if (splitState == null) {
+      return unboundedSource.createReader(pipelineOptions, null);
+    } else {
+      try (ByteArrayInputStream bais = new ByteArrayInputStream(splitState)) {
+        return unboundedSource.createReader(pipelineOptions, coder.decode(bais));
+      }
+    }
+  }
 }
diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/source/FlinkSourceReaderTestBase.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/source/FlinkSourceReaderTestBase.java
index 462a1ba0153d..c635a5778b5c 100644
--- a/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/source/FlinkSourceReaderTestBase.java
+++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/source/FlinkSourceReaderTestBase.java
@@ -363,11 +363,9 @@ public int numCollectedRecords() {
     }
 
     public boolean allRecordsConsumed() {
-      boolean allRecordsConsumed = true;
-      for (Source source : sources) {
-        allRecordsConsumed = allRecordsConsumed && ((TestSource) source).isConsumptionCompleted();
-      }
-      return allRecordsConsumed;
+      return sources.stream()
+          .map(TestSource.class::cast)
+          .allMatch(TestSource::isConsumptionCompleted);
     }
 
     public boolean allTimestampReceived() {
diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/source/bounded/FlinkBoundedSourceReaderTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/source/bounded/FlinkBoundedSourceReaderTest.java
index 84cb2a72ddaf..022f1abde826 100644
--- a/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/source/bounded/FlinkBoundedSourceReaderTest.java
+++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/source/bounded/FlinkBoundedSourceReaderTest.java
@@ -20,6 +20,7 @@
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertTrue;
+import static org.mockito.Mockito.verify;
 
 import java.util.List;
 import java.util.concurrent.CompletableFuture;
@@ -34,6 +35,7 @@
 import org.apache.beam.sdk.io.Source;
 import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.sdk.values.KV;
+import org.apache.flink.api.common.eventtime.Watermark;
 import org.apache.flink.api.connector.source.ReaderOutput;
 import org.apache.flink.api.connector.source.SourceReader;
 import org.apache.flink.api.connector.source.SourceReaderContext;
@@ -62,6 +64,20 @@ public void testPollWithIdleTimeout() throws Exception {
     }
   }
 
+  @Test
+  public void testPollEmitsMaxWatermark() throws Exception {
+    ManuallyTriggeredScheduledExecutorService executor =
+        new ManuallyTriggeredScheduledExecutorService();
+    ReaderOutput<WindowedValue<KV<Integer, Integer>>> mockReaderOutput =
+        Mockito.mock(ReaderOutput.class);
+    try (FlinkBoundedSourceReader<KV<Integer, Integer>> reader =
+        (FlinkBoundedSourceReader<KV<Integer, Integer>>) createReader(executor, Long.MAX_VALUE)) {
+      reader.notifyNoMoreSplits();
+      assertEquals(InputStatus.NOTHING_AVAILABLE, reader.pollNext(mockReaderOutput));
+      verify(mockReaderOutput).emitWatermark(Watermark.MAX_WATERMARK);
+    }
+  }
+
   @Test
   public void testPollWithoutIdleTimeout() throws Exception {
     ReaderOutput<WindowedValue<KV<Integer, Integer>>> mockReaderOutput =
@@ -107,8 +123,6 @@ public void testSnapshotStateAndRestore() throws Exception {
       snapshot = reader.snapshotState(0L);
     }
 
-    // Create a new validating output because the first split will be consumed from very beginning.
-    validatingOutput = new RecordsValidatingOutput(splits);
     // Create another reader, add the snapshot splits back.
     try (SourceReader<WindowedValue<KV<Integer, Integer>>, FlinkSourceSplit<KV<Integer, Integer>>> reader =
         createReader()) {
diff --git a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java
index 6a16ca18fef9..ea23e28ddb66 100644
--- a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java
+++ b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java
@@ -915,10 +915,9 @@ public Coder valueCoder() {
     // Expect the following requests for the first bundle:
     // * one to read iterable side input
     // * one to read keys from multimap side input
-    // * one to attempt multimap side input bulk read
     // * one to read key1 iterable from multimap side input
     // * one to read key2 iterable from multimap side input
-    assertEquals(5, stateRequestHandler.receivedRequests.size());
+    assertEquals(4, stateRequestHandler.receivedRequests.size());
     assertEquals(
         stateRequestHandler.receivedRequests.get(0).getStateKey().getIterableSideInput(),
         BeamFnApi.StateKey.IterableSideInput.newBuilder()
@@ -932,20 +931,14 @@ public Coder valueCoder() {
             .setTransformId(transformId)
             .build());
     assertEquals(
-        stateRequestHandler.receivedRequests.get(2).getStateKey().getMultimapKeysValuesSideInput(),
-        BeamFnApi.StateKey.MultimapKeysValuesSideInput.newBuilder()
-            .setSideInputId(multimapView.getTagInternal().getId())
-            .setTransformId(transformId)
-            .build());
-    assertEquals(
-        stateRequestHandler.receivedRequests.get(3).getStateKey().getMultimapSideInput(),
+        stateRequestHandler.receivedRequests.get(2).getStateKey().getMultimapSideInput(),
         BeamFnApi.StateKey.MultimapSideInput.newBuilder()
             .setSideInputId(multimapView.getTagInternal().getId())
            .setTransformId(transformId)
            .setKey(encode("key1"))
            .build());
     assertEquals(
-        stateRequestHandler.receivedRequests.get(4).getStateKey().getMultimapSideInput(),
+        stateRequestHandler.receivedRequests.get(3).getStateKey().getMultimapSideInput(),
         BeamFnApi.StateKey.MultimapSideInput.newBuilder()
             .setSideInputId(multimapView.getTagInternal().getId())
             .setTransformId(transformId)
diff --git a/sdks/go.mod b/sdks/go.mod
index c46c7e28a58c..11bd34cf079b 100644
--- a/sdks/go.mod
+++ b/sdks/go.mod
@@ -24,18 +24,18 @@ go 1.20
 require (
 	cloud.google.com/go/bigquery v1.57.1
-	cloud.google.com/go/bigtable v1.20.0
+	cloud.google.com/go/bigtable v1.21.0
 	cloud.google.com/go/datastore v1.15.0
 	cloud.google.com/go/profiler v0.4.0
 	cloud.google.com/go/pubsub v1.33.0
 	cloud.google.com/go/spanner v1.53.1
 	cloud.google.com/go/storage v1.35.1
-	github.com/aws/aws-sdk-go-v2 v1.23.5
-	github.com/aws/aws-sdk-go-v2/config v1.25.8
-	github.com/aws/aws-sdk-go-v2/credentials v1.16.9
+	github.com/aws/aws-sdk-go-v2 v1.24.0
+	github.com/aws/aws-sdk-go-v2/config v1.26.1
+	github.com/aws/aws-sdk-go-v2/credentials v1.16.12
 	github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.13.8
 	github.com/aws/aws-sdk-go-v2/service/s3 v1.42.2
-	github.com/aws/smithy-go v1.18.1
+	github.com/aws/smithy-go v1.19.0
 	github.com/docker/go-connections v0.4.0
 	github.com/dustin/go-humanize v1.0.1
 	github.com/go-sql-driver/mysql v1.7.1
@@ -45,7 +45,7 @@ require (
 	github.com/johannesboyne/gofakes3 v0.0.0-20221110173912-32fb85c5aed6
 	github.com/lib/pq v1.10.9
 	github.com/linkedin/goavro/v2 v2.12.0
-	github.com/nats-io/nats-server/v2 v2.10.6
+	github.com/nats-io/nats-server/v2 v2.10.7
 	github.com/nats-io/nats.go v1.31.0
 	github.com/proullon/ramsql v0.1.3
 	github.com/spf13/cobra v1.8.0
@@ -106,18 +106,18 @@ require (
 	github.com/apache/thrift v0.16.0 // indirect
 	github.com/aws/aws-sdk-go v1.34.0 // indirect
 	github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.5.1 // indirect
-	github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.14.9 // indirect
-	github.com/aws/aws-sdk-go-v2/internal/configsources v1.2.8 // indirect
-	github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.5.8 // indirect
-	github.com/aws/aws-sdk-go-v2/internal/ini v1.7.1 // indirect
+	github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.14.10 // indirect
+	github.com/aws/aws-sdk-go-v2/internal/configsources v1.2.9 // indirect
+	github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.5.9 // indirect
+	github.com/aws/aws-sdk-go-v2/internal/ini v1.7.2 // indirect
 	github.com/aws/aws-sdk-go-v2/internal/v4a v1.2.3 // indirect
-	github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.10.3 // indirect
+	github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.10.4 // indirect
 	github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.2.3 // indirect
-	github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.10.8 // indirect
+	github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.10.9 // indirect
 	github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.16.3 // indirect
-	github.com/aws/aws-sdk-go-v2/service/sso v1.18.2 // indirect
-	github.com/aws/aws-sdk-go-v2/service/ssooidc v1.21.2 // indirect
-	github.com/aws/aws-sdk-go-v2/service/sts v1.26.2 // indirect
+	github.com/aws/aws-sdk-go-v2/service/sso v1.18.5 // indirect
+	github.com/aws/aws-sdk-go-v2/service/ssooidc v1.21.5 // indirect
+	github.com/aws/aws-sdk-go-v2/service/sts v1.26.5 // indirect
 	github.com/cenkalti/backoff/v4 v4.2.1 // indirect
 	github.com/census-instrumentation/opencensus-proto v0.4.1 // indirect
 	github.com/cespare/xxhash/v2 v2.2.0 // indirect
@@ -146,7 +146,7 @@ require (
 	github.com/inconshreveable/mousetrap v1.1.0 // indirect
 	github.com/jmespath/go-jmespath v0.4.0 // indirect
 	github.com/klauspost/asmfmt v1.3.2 // indirect
-	github.com/klauspost/compress v1.17.3 // indirect
+	github.com/klauspost/compress v1.17.4 // indirect
 	github.com/klauspost/cpuid/v2 v2.2.5 // indirect
 	github.com/kr/text v0.2.0 // indirect
 	github.com/magiconair/properties v1.8.7 // indirect
diff --git a/sdks/go.sum b/sdks/go.sum
index 5df50d7fd020..df82d0afbce1 100644
--- a/sdks/go.sum
+++ b/sdks/go.sum
@@ -15,8 +15,8 @@ cloud.google.com/go/bigquery v1.3.0/go.mod h1:PjpwJnslEMmckchkHFfq+HTD2DmtT67aNF
 cloud.google.com/go/bigquery v1.4.0/go.mod h1:S8dzgnTigyfTmLBfrtrhyYhwRxG72rYxvftPBK2Dvzc=
 cloud.google.com/go/bigquery v1.57.1 h1:FiULdbbzUxWD0Y4ZGPSVCDLvqRSyCIO6zKV7E2nf5uA=
 cloud.google.com/go/bigquery v1.57.1/go.mod h1:iYzC0tGVWt1jqSzBHqCr3lrRn0u13E8e+AqowBsDgug=
-cloud.google.com/go/bigtable v1.20.0 h1:NqZC/WcesSn4O8L0I2JmuNsUigSyBQifVLYgM9LMQeQ=
-cloud.google.com/go/bigtable v1.20.0/go.mod h1:upJDn8frsjzpRMfybiWkD1PG6WCCL7CRl26MgVeoXY4=
+cloud.google.com/go/bigtable v1.21.0 h1:BFN4jhkA9ULYYV2Ug7AeOtetVLnN2jKuIq5TcRc5C38=
+cloud.google.com/go/bigtable v1.21.0/go.mod h1:V0sYNRtk0dgAKjyRr/MyBpHpSXqh+9P39euf820EZ74=
 cloud.google.com/go/compute v1.23.3 h1:6sVlXXBmbd7jNX0Ipq0trII3e4n1/MsADLK6a+aiVlk=
 cloud.google.com/go/compute v1.23.3/go.mod h1:VCgBUoMnIVIR0CscqQiPJLAG25E3ZRZMzcFZeQ+h8CI=
 cloud.google.com/go/compute/metadata v0.2.3 h1:mg4jlk7mCAj6xXp9UJ4fjI9VUI5rubuGBW5aJ7UnBMY=
@@ -81,39 +81,39 @@ github.com/aws/aws-sdk-go v1.30.19/go.mod h1:5zCpMtNQVjRREroY7sYe8lOMRSxkhG6MZve
 github.com/aws/aws-sdk-go v1.34.0 h1:brux2dRrlwCF5JhTL7MUT3WUwo9zfDHZZp3+g3Mvlmo=
 github.com/aws/aws-sdk-go v1.34.0/go.mod h1:5zCpMtNQVjRREroY7sYe8lOMRSxkhG6MZveU8YkpAk0=
 github.com/aws/aws-sdk-go-v2 v1.7.1/go.mod h1:L5LuPC1ZgDr2xQS7AmIec/Jlc7O/Y1u2KxJyNVab250=
-github.com/aws/aws-sdk-go-v2 v1.23.5 h1:xK6C4udTyDMd82RFvNkDQxtAd00xlzFUtX4fF2nMZyg=
-github.com/aws/aws-sdk-go-v2 v1.23.5/go.mod h1:t3szzKfP0NeRU27uBFczDivYJjsmSnqI8kIvKyWb9ds=
+github.com/aws/aws-sdk-go-v2 v1.24.0 h1:890+mqQ+hTpNuw0gGP6/4akolQkSToDJgHfQE7AwGuk=
+github.com/aws/aws-sdk-go-v2 v1.24.0/go.mod h1:LNh45Br1YAkEKaAqvmE1m8FUx6a5b/V0oAKV7of29b4=
 github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.5.1 h1:ZY3108YtBNq96jNZTICHxN1gSBSbnvIdYwwqnvCV4Mc=
 github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.5.1/go.mod h1:t8PYl/6LzdAqsU4/9tz28V/kU+asFePvpOMkdul0gEQ=
 github.com/aws/aws-sdk-go-v2/config v1.5.0/go.mod h1:RWlPOAW3E3tbtNAqTwvSW54Of/yP3oiZXMI0xfUdjyA=
-github.com/aws/aws-sdk-go-v2/config v1.25.8 h1:CHr7PIzyfevjNiqL9rU6xoqHZKCO2ldY6LmvRDfpRuI=
-github.com/aws/aws-sdk-go-v2/config v1.25.8/go.mod h1:zefIy117FDPOVU0xSOFG8mx9kJunuVopzI639tjYXc0=
+github.com/aws/aws-sdk-go-v2/config v1.26.1 h1:z6DqMxclFGL3Zfo+4Q0rLnAZ6yVkzCRxhRMsiRQnD1o=
+github.com/aws/aws-sdk-go-v2/config v1.26.1/go.mod h1:ZB+CuKHRbb5v5F0oJtGdhFTelmrxd4iWO1lf0rQwSAg=
 github.com/aws/aws-sdk-go-v2/credentials v1.3.1/go.mod h1:r0n73xwsIVagq8RsxmZbGSRQFj9As3je72C2WzUIToc=
-github.com/aws/aws-sdk-go-v2/credentials v1.16.9 h1:LQo3MUIOzod9JdUK+wxmSdgzLVYUbII3jXn3S/HJZU0=
-github.com/aws/aws-sdk-go-v2/credentials v1.16.9/go.mod h1:R7mDuIJoCjH6TxGUc/cylE7Lp/o0bhKVoxdBThsjqCM=
+github.com/aws/aws-sdk-go-v2/credentials v1.16.12 h1:v/WgB8NxprNvr5inKIiVVrXPuuTegM+K8nncFkr1usU=
+github.com/aws/aws-sdk-go-v2/credentials v1.16.12/go.mod h1:X21k0FjEJe+/pauud82HYiQbEr9jRKY3kXEIQ4hXeTQ=
 github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.3.0/go.mod h1:2LAuqPx1I6jNfaGDucWfA2zqQCYCOMCDHiCOciALyNw=
-github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.14.9 h1:FZVFahMyZle6WcogZCOxo6D/lkDA2lqKIn4/ueUmVXw=
-github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.14.9/go.mod h1:kjq7REMIkxdtcEC9/4BVXjOsNY5isz6jQbEgk6osRTU=
+github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.14.10 h1:w98BT5w+ao1/r5sUuiH6JkVzjowOKeOJRHERyy1vh58=
+github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.14.10/go.mod h1:K2WGI7vUvkIv1HoNbfBA1bvIZ+9kL3YVmWxeKuLQsiw=
 github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.3.2/go.mod h1:qaqQiHSrOUVOfKe6fhgQ6UzhxjwqVW8aHNegd6Ws4w4=
 github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.13.8 h1:wuOjvalpd2CnXffks74Vq6n3yv9vunKCoy4R1sjStGk=
 github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.13.8/go.mod h1:vywwjy6VnrR48Izg136JoSUXC4mH9QeUi3g0EH9DSrA=
-github.com/aws/aws-sdk-go-v2/internal/configsources v1.2.8 h1:8GVZIR0y6JRIUNSYI1xAMF4HDfV8H/bOsZ/8AD/uY5Q=
-github.com/aws/aws-sdk-go-v2/internal/configsources v1.2.8/go.mod h1:rwBfu0SoUkBUZndVgPZKAD9Y2JigaZtRP68unRiYToQ=
-github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.5.8 h1:ZE2ds/qeBkhk3yqYvS3CDCFNvd9ir5hMjlVStLZWrvM=
-github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.5.8/go.mod h1:/lAPPymDYL023+TS6DJmjuL42nxix2AvEvfjqOBRODk=
+github.com/aws/aws-sdk-go-v2/internal/configsources v1.2.9 h1:v+HbZaCGmOwnTTVS86Fleq0vPzOd7tnJGbFhP0stNLs=
+github.com/aws/aws-sdk-go-v2/internal/configsources v1.2.9/go.mod h1:Xjqy+Nyj7VDLBtCMkQYOw1QYfAEZCVLrfI0ezve8wd4=
+github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.5.9 h1:N94sVhRACtXyVcjXxrwK1SKFIJrA9pOJ5yu2eSHnmls=
+github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.5.9/go.mod h1:hqamLz7g1/4EJP+GH5NBhcUMLjW+gKLQabgyz6/7WAU=
 github.com/aws/aws-sdk-go-v2/internal/ini v1.1.1/go.mod h1:Zy8smImhTdOETZqfyn01iNOe0CNggVbPjCajyaz6Gvg=
-github.com/aws/aws-sdk-go-v2/internal/ini v1.7.1 h1:uR9lXYjdPX0xY+NhvaJ4dD8rpSRz5VY81ccIIoNG+lw=
-github.com/aws/aws-sdk-go-v2/internal/ini v1.7.1/go.mod h1:6fQQgfuGmw8Al/3M2IgIllycxV7ZW7WCdVSqfBeUiCY=
+github.com/aws/aws-sdk-go-v2/internal/ini v1.7.2 h1:GrSw8s0Gs/5zZ0SX+gX4zQjRnRsMJDJ2sLur1gRBhEM=
+github.com/aws/aws-sdk-go-v2/internal/ini v1.7.2/go.mod h1:6fQQgfuGmw8Al/3M2IgIllycxV7ZW7WCdVSqfBeUiCY=
 github.com/aws/aws-sdk-go-v2/internal/v4a v1.2.3 h1:lMwCXiWJlrtZot0NJTjbC8G9zl+V3i68gBTBBvDeEXA=
 github.com/aws/aws-sdk-go-v2/internal/v4a v1.2.3/go.mod h1:5yzAuE9i2RkVAttBl8yxZgQr5OCq4D5yDnG7j9x2L0U=
 github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.2.1/go.mod h1:v33JQ57i2nekYTA70Mb+O18KeH4KqhdqxTJZNK1zdRE=
-github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.10.3 h1:e3PCNeEaev/ZF01cQyNZgmYE9oYYePIMJs2mWSKG514=
-github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.10.3/go.mod h1:gIeeNyaL8tIEqZrzAnTeyhHcE0yysCtcaP+N9kxLZ+E=
+github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.10.4 h1:/b31bi3YVNlkzkBrm9LfpaKoaYZUxIAj4sHfOTmLfqw=
+github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.10.4/go.mod h1:2aGXHFmbInwgP9ZfpmdIfOELL79zhdNYNmReK8qDfdQ=
 github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.2.3 h1:xbwRyCy7kXrOj89iIKLB6NfE2WCpP9HoKyk8dMDvnIQ=
 github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.2.3/go.mod h1:R+/S1O4TYpcktbVwddeOYg+uwUfLhADP2S/x4QwsCTM=
 github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.2.1/go.mod h1:zceowr5Z1Nh2WVP8bf/3ikB41IZW59E4yIYbg+pC6mw=
-github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.10.8 h1:EamsKe+ZjkOQjDdHd86/JCEucjFKQ9T0atWKO4s2Lgs=
-github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.10.8/go.mod h1:Q0vV3/csTpbkfKLI5Sb56cJQTCTtJ0ixdb7P+Wedqiw=
+github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.10.9 h1:Nf2sHxjMJR8CSImIVCONRi4g0Su3J+TSTbS7G0pUeMU=
+github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.10.9/go.mod h1:idky4TER38YIjr2cADF1/ugFMKvZV7p//pVeV5LZbF0=
 github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.5.1/go.mod h1:6EQZIwNNvHpq/2/QSJnp4+ECvqIy55w95Ofs0ze+nGQ=
 github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.16.3 h1:KV0z2RDc7euMtg8aUT1czv5p29zcLlXALNFsd3jkkEc=
 github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.16.3/go.mod h1:KZgs2ny8HsxRIRbDwgvJcHHBZPOzQr/+NtGwnP+w2ec=
@@ -121,16 +121,16 @@ github.com/aws/aws-sdk-go-v2/service/s3 v1.11.1/go.mod h1:XLAGFrEjbvMCLvAtWLLP32
 github.com/aws/aws-sdk-go-v2/service/s3 v1.42.2 h1:NnduxUd9+Fq9DcCDdJK8v6l9lR1xDX4usvog+JuQAno=
 github.com/aws/aws-sdk-go-v2/service/s3 v1.42.2/go.mod h1:NXRKkiRF+erX2hnybnVU660cYT5/KChRD4iUgJ97cI8=
 github.com/aws/aws-sdk-go-v2/service/sso v1.3.1/go.mod h1:J3A3RGUvuCZjvSuZEcOpHDnzZP/sKbhDWV2T1EOzFIM=
-github.com/aws/aws-sdk-go-v2/service/sso v1.18.2 h1:xJPydhNm0Hiqct5TVKEuHG7weC0+sOs4MUnd7A5n5F4=
-github.com/aws/aws-sdk-go-v2/service/sso v1.18.2/go.mod h1:zxk6y1X2KXThESWMS5CrKRvISD8mbIMab6nZrCGxDG0=
-github.com/aws/aws-sdk-go-v2/service/ssooidc v1.21.2 h1:8dU9zqA77C5egbU6yd4hFLaiIdPv3rU+6cp7sz5FjCU=
-github.com/aws/aws-sdk-go-v2/service/ssooidc v1.21.2/go.mod h1:7Lt5mjQ8x5rVdKqg+sKKDeuwoszDJIIPmkd8BVsEdS0=
+github.com/aws/aws-sdk-go-v2/service/sso v1.18.5 h1:ldSFWz9tEHAwHNmjx2Cvy1MjP5/L9kNoR0skc6wyOOM=
+github.com/aws/aws-sdk-go-v2/service/sso v1.18.5/go.mod h1:CaFfXLYL376jgbP7VKC96uFcU8Rlavak0UlAwk1Dlhc=
+github.com/aws/aws-sdk-go-v2/service/ssooidc v1.21.5 h1:2k9KmFawS63euAkY4/ixVNsYYwrwnd5fIvgEKkfZFNM=
+github.com/aws/aws-sdk-go-v2/service/ssooidc v1.21.5/go.mod h1:W+nd4wWDVkSUIox9bacmkBP5NMFQeTJ/xqNabpzSR38=
 github.com/aws/aws-sdk-go-v2/service/sts v1.6.0/go.mod h1:q7o0j7d7HrJk/vr9uUt3BVRASvcU7gYZB9PUgPiByXg=
-github.com/aws/aws-sdk-go-v2/service/sts v1.26.2 h1:fFrLsy08wEbAisqW3KDl/cPHrF43GmV79zXB9EwJiZw=
-github.com/aws/aws-sdk-go-v2/service/sts v1.26.2/go.mod h1:7Ld9eTqocTvJqqJ5K/orbSDwmGcpRdlDiLjz2DO+SL8=
+github.com/aws/aws-sdk-go-v2/service/sts v1.26.5 h1:5UYvv8JUvllZsRnfrcMQ+hJ9jNICmcgKPAO1CER25Wg=
+github.com/aws/aws-sdk-go-v2/service/sts v1.26.5/go.mod h1:XX5gh4CB7wAs4KhcF46G6C8a2i7eupU19dcAAE+EydU=
 github.com/aws/smithy-go v1.6.0/go.mod h1:SObp3lf9smib00L/v3U2eAKG8FyQ7iLrJnQiAmR5n+E=
-github.com/aws/smithy-go v1.18.1 h1:pOdBTUfXNazOlxLrgeYalVnuTpKreACHtc62xLwIB3c=
-github.com/aws/smithy-go v1.18.1/go.mod h1:NukqUGpCZIILqqiV0NIjeFh24kd/FAa4beRb6nbIUPE=
+github.com/aws/smithy-go v1.19.0 h1:KWFKQV80DpP3vJrrA9sVAHQ5gc2z8i4EzrLhLlWXcBM=
+github.com/aws/smithy-go v1.19.0/go.mod h1:NukqUGpCZIILqqiV0NIjeFh24kd/FAa4beRb6nbIUPE=
 github.com/cenkalti/backoff/v4 v4.2.1 h1:y4OZtCnogmCPw98Zjyt5a6+QwPLGkiQsYW5oUqylYbM=
 github.com/cenkalti/backoff/v4 v4.2.1/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE=
 github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
@@ -321,8 +321,8 @@ github.com/klauspost/asmfmt v1.3.2/go.mod h1:AG8TuvYojzulgDAMCnYn50l/5QV3Bs/tp6j
 github.com/klauspost/compress v1.9.7/go.mod h1:RyIbtBH6LamlWaDj8nUwkbUhJ87Yi3uG0guNDohfE1A=
 github.com/klauspost/compress v1.13.1/go.mod h1:8dP1Hq4DHOhN9w426knH3Rhby4rFm6D8eO+e+Dq5Gzg=
 github.com/klauspost/compress v1.13.6/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk=
-github.com/klauspost/compress v1.17.3 h1:qkRjuerhUU1EmXLYGkSH6EZL+vPSxIrYjLNAK4slzwA=
-github.com/klauspost/compress v1.17.3/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM=
+github.com/klauspost/compress v1.17.4 h1:Ej5ixsIri7BrIjBkRZLTo6ghwrEtHFk7ijlczPW4fZ4=
+github.com/klauspost/compress v1.17.4/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM=
 github.com/klauspost/cpuid/v2 v2.2.5 h1:0E5MSMDEoAulmXNFquVs//DdoomxaoTY1kUhbc/qbZg=
 github.com/klauspost/cpuid/v2 v2.2.5/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=
 github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
@@ -366,8 +366,8 @@ github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7P
 github.com/mrunalp/fileutils v0.5.0/go.mod h1:M1WthSahJixYnrXQl/DFQuteStB1weuxD2QJNHXfbSQ=
 github.com/nats-io/jwt/v2 v2.5.3 h1:/9SWvzc6hTfamcgXJ3uYRpgj+QuY2aLNqRiqrKcrpEo=
 github.com/nats-io/jwt/v2 v2.5.3/go.mod h1:iysuPemFcc7p4IoYots3IuELSI4EDe9Y0bQMe+I3Bf4=
-github.com/nats-io/nats-server/v2 v2.10.6 h1:40U3ngyAKyC1tNT4Kw7PjuvivY74NTYD3qyIHxZUHKQ=
-github.com/nats-io/nats-server/v2 v2.10.6/go.mod h1:IrTXS8o4Roa3G2kW8L5mEtSdmSrFjKhYb/m2g0gQ/vc=
+github.com/nats-io/nats-server/v2 v2.10.7 h1:f5VDy+GMu7JyuFA0Fef+6TfulfCs5nBTgq7MMkFJx5Y=
+github.com/nats-io/nats-server/v2 v2.10.7/go.mod h1:V2JHOvPiPdtfDXTuEUsthUnCvSDeFrK4Xn9hRo6du7c=
 github.com/nats-io/nats.go v1.31.0 h1:/WFBHEc/dOKBF6qf1TZhrdEfTmOZ5JzdJ+Y3m6Y/p7E=
 github.com/nats-io/nats.go v1.31.0/go.mod h1:di3Bm5MLsoB4Bx61CBTsxuarI36WbhAwOm8QrW39+i8=
 github.com/nats-io/nkeys v0.4.6 h1:IzVe95ru2CT6ta874rt9saQRkWfe2nFj1NtvYSLqMzY=
diff --git a/sdks/go/examples/timer_wordcap/wordcap.go b/sdks/go/examples/timer_wordcap/wordcap.go
index db64c10eb4d7..01f87edc59bd 100644
--- a/sdks/go/examples/timer_wordcap/wordcap.go
+++ b/sdks/go/examples/timer_wordcap/wordcap.go
@@ -118,7 +118,6 @@ func (s *Stateful) OnTimer(ctx context.Context, ts beam.EventTime, sp state.Prov
 			// Clean up the state that has been evicted.
 			s.ElementBag.Clear(sp)
 			s.MinTime.Clear(sp)
-			s.OutputState.ClearTag(tp, tag) // Clean up the fired timer tag. (Temporary workaround for a runner bug.)
 		}
 	}
 }
diff --git a/sdks/go/pkg/beam/core/runtime/exec/datasampler.go b/sdks/go/pkg/beam/core/runtime/exec/datasampler.go
new file mode 100644
index 000000000000..426213a12afd
--- /dev/null
+++ b/sdks/go/pkg/beam/core/runtime/exec/datasampler.go
@@ -0,0 +1,153 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package exec
+
+import (
+	"context"
+	"sync"
+	"time"
+)
+
+// dataSample contains the properties of a single sampled element.
+type dataSample struct {
+	PCollectionID string
+	Timestamp     time.Time
+	Element       []byte
+}
+
+// DataSampler manages sampled elements, keyed by PCollectionID.
+type DataSampler struct {
+	sampleChannel chan *dataSample
+	samplesMap    sync.Map // Key: PCollectionID string, Value: *outputSamples pointer
+	ctx           context.Context
+}
+
+// NewDataSampler initializes a new DataSampler and returns a pointer to it.
+func NewDataSampler(ctx context.Context) *DataSampler {
+	return &DataSampler{
+		sampleChannel: make(chan *dataSample, 1000),
+		ctx:           ctx,
+	}
+}
+
+// Process consumes sampled elements from the sample channel until the context is done.
+func (d *DataSampler) Process() {
+	for {
+		select {
+		case <-d.ctx.Done():
+			return
+		case sample := <-d.sampleChannel:
+			d.addSample(sample)
+		}
+	}
+}
+
+// GetSamples returns samples for the given PCollection IDs.
+// If no PCollection ID is provided, it returns all available samples.
+func (d *DataSampler) GetSamples(pids []string) map[string][]*dataSample {
+	if len(pids) == 0 {
+		return d.getAllSamples()
+	}
+	return d.getSamplesForPCollections(pids)
+}
+
+// SendSample is called by a PCollection node to send a sampled element to the
+// DataSampler asynchronously.
+func (d *DataSampler) SendSample(pCollectionID string, element []byte, timestamp time.Time) {
+	sample := dataSample{
+		PCollectionID: pCollectionID,
+		Element:       element,
+		Timestamp:     timestamp,
+	}
+	d.sampleChannel <- &sample
+}
+
+func (d *DataSampler) getAllSamples() map[string][]*dataSample {
+	var res = make(map[string][]*dataSample)
+	d.samplesMap.Range(func(key any, value any) bool {
+		pid := key.(string)
+		samples := d.getSamples(pid)
+		if len(samples) > 0 {
+			res[pid] = samples
+		}
+		return true
+	})
+	return res
+}
+
+func (d *DataSampler) getSamplesForPCollections(pids []string) map[string][]*dataSample {
+	var res = make(map[string][]*dataSample)
+	for _, pid := range pids {
+		samples := d.getSamples(pid)
+		if len(samples) > 0 {
+			res[pid] = samples
+		}
+	}
+	return res
+}
+
+func (d *DataSampler) addSample(sample *dataSample) {
+	p, ok := d.samplesMap.Load(sample.PCollectionID)
+	if !ok {
+		p = &outputSamples{maxElements: 10, sampleIndex: 0}
+		d.samplesMap.Store(sample.PCollectionID, p)
+	}
+	outputSamples := p.(*outputSamples)
+	outputSamples.addSample(sample)
+}
+
+func (d *DataSampler) getSamples(pCollectionID string) []*dataSample {
+	p, ok := d.samplesMap.Load(pCollectionID)
+	if !ok {
+		return nil
+	}
+	outputSamples := p.(*outputSamples)
+	return outputSamples.getSamples()
+}
+
+type outputSamples struct {
+	elements    []*dataSample
+	mu          sync.Mutex
+	maxElements int
+	sampleIndex int
+}
+
+func (o *outputSamples) addSample(element *dataSample) {
+	o.mu.Lock()
+	defer o.mu.Unlock()
+
+	// Once the per-PCollection buffer of maxElements is full, overwrite the oldest sample.
+	if len(o.elements) < o.maxElements {
+		o.elements = append(o.elements, element)
+	} else {
+		o.elements[o.sampleIndex] = element
+		o.sampleIndex = (o.sampleIndex + 1) % o.maxElements
+	}
+}
+
+func (o *outputSamples) getSamples() []*dataSample {
+	o.mu.Lock()
+	defer o.mu.Unlock()
+	if len(o.elements) == 0 {
+		return nil
+	}
+	samples := o.elements
+
+	// Reset the index and samples.
+	o.sampleIndex = 0
+	// Release memory; samples are returned at most once, on a best-effort basis.
+	o.elements = nil
+
+	return samples
+}
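The sampler above is a channel-fed collector that keeps at most ten samples per PCollection in a small ring buffer and hands each sample out at most once. A minimal usage sketch, assuming the package is imported from outside the harness; the main function below is hypothetical and not part of the patch:

package main

import (
	"context"
	"fmt"
	"time"

	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/exec"
)

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// Start the sampler; Process drains the channel until ctx is done.
	ds := exec.NewDataSampler(ctx)
	go ds.Process()

	// Producers (PCollection nodes) push encoded elements asynchronously.
	ds.SendSample("pcol-1", []byte("encoded-element"), time.Now())

	// Consumers pull samples; an empty ID list means "all PCollections".
	time.Sleep(10 * time.Millisecond) // give the channel a moment to drain
	for pid, samples := range ds.GetSamples(nil) {
		fmt.Println(pid, len(samples))
	}
}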
diff --git a/sdks/go/pkg/beam/core/runtime/exec/datasampler_test.go b/sdks/go/pkg/beam/core/runtime/exec/datasampler_test.go
new file mode 100644
index 000000000000..d648fd89efaa
--- /dev/null
+++ b/sdks/go/pkg/beam/core/runtime/exec/datasampler_test.go
@@ -0,0 +1,148 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package exec
+
+import (
+	"context"
+	"reflect"
+	"sort"
+	"testing"
+	"time"
+)
+
+// TestDataSampler verifies that the DataSampler works correctly.
+func TestDataSampler(t *testing.T) {
+	timestamp := time.Now()
+	tests := []struct {
+		name    string
+		samples []dataSample
+		pids    []string
+		want    map[string][]*dataSample
+	}{
+		{
+			name: "GetAllSamples",
+			samples: []dataSample{
+				{PCollectionID: "pid1", Element: []byte("element1"), Timestamp: timestamp},
+				{PCollectionID: "pid2", Element: []byte("element2"), Timestamp: timestamp},
+			},
+			pids: []string{},
+			want: map[string][]*dataSample{
+				"pid1": {{PCollectionID: "pid1", Element: []byte("element1"), Timestamp: timestamp}},
+				"pid2": {{PCollectionID: "pid2", Element: []byte("element2"), Timestamp: timestamp}},
+			},
+		},
+		{
+			name: "GetSamplesForPCollections",
+			samples: []dataSample{
+				{PCollectionID: "pid1", Element: []byte("element1"), Timestamp: timestamp},
+				{PCollectionID: "pid2", Element: []byte("element2"), Timestamp: timestamp},
+			},
+			pids: []string{"pid1"},
+			want: map[string][]*dataSample{
+				"pid1": {{PCollectionID: "pid1", Element: []byte("element1"), Timestamp: timestamp}},
+			},
+		},
+		{
+			name: "GetSamplesForPCollectionsWithNoResult",
+			samples: []dataSample{
+				{PCollectionID: "pid1", Element: []byte("element1"), Timestamp: timestamp},
+				{PCollectionID: "pid2", Element: []byte("element2"), Timestamp: timestamp},
+			},
+			pids: []string{"pid3"},
+			want: map[string][]*dataSample{},
+		},
+		{
+			name: "GetSamplesForPCollectionsTooManySamples",
+			samples: []dataSample{
+				{PCollectionID: "pid1", Element: []byte("element1"), Timestamp: timestamp},
+				{PCollectionID: "pid1", Element: []byte("element2"), Timestamp: timestamp},
+				{PCollectionID: "pid1", Element: []byte("element3"), Timestamp: timestamp},
+				{PCollectionID: "pid1", Element: []byte("element4"), Timestamp: timestamp},
+				{PCollectionID: "pid1", Element: []byte("element5"), Timestamp: timestamp},
+				{PCollectionID: "pid1", Element: []byte("element6"), Timestamp: timestamp},
+				{PCollectionID: "pid1", Element: []byte("element7"), Timestamp: timestamp},
+				{PCollectionID: "pid1", Element: []byte("element8"), Timestamp: timestamp},
+				{PCollectionID: "pid1", Element: []byte("element9"), Timestamp: timestamp},
+				{PCollectionID: "pid1", Element: []byte("element10"), Timestamp: timestamp},
+				{PCollectionID: "pid1", Element: []byte("element11"), Timestamp: timestamp},
+			},
+			pids: []string{"pid1"},
+			want: map[string][]*dataSample{
+				"pid1": {
+					{PCollectionID: "pid1", Element: []byte("element2"), Timestamp: timestamp},
+					{PCollectionID: "pid1", Element: []byte("element3"), Timestamp: timestamp},
+					{PCollectionID: "pid1", Element: []byte("element4"), Timestamp: timestamp},
+					{PCollectionID: "pid1", Element: []byte("element5"), Timestamp: timestamp},
+					{PCollectionID: "pid1", Element: []byte("element6"), Timestamp: timestamp},
+					{PCollectionID: "pid1", Element: []byte("element7"), Timestamp: timestamp},
+					{PCollectionID: "pid1", Element: []byte("element8"), Timestamp: timestamp},
+					{PCollectionID: "pid1", Element: []byte("element9"), Timestamp: timestamp},
+					{PCollectionID: "pid1", Element: []byte("element10"), Timestamp: timestamp},
+					{PCollectionID: "pid1", Element: []byte("element11"), Timestamp: timestamp},
+				}},
+		},
+	}
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			ctx := context.Background()
+			ctx, cancel := context.WithCancel(ctx)
+			dataSampler := NewDataSampler(ctx)
+			go dataSampler.Process()
+			for _, sample := range test.samples {
+				dataSampler.SendSample(sample.PCollectionID, sample.Element, sample.Timestamp)
+			}
+			var samplesCount = -1
+			var samples map[string][]*dataSample
+			for i := 0; i < 5; i++ {
+				samples = dataSampler.GetSamples(test.pids)
+				if len(samples) == len(test.want) {
+					samplesCount = len(samples)
+					break
+				}
+				time.Sleep(time.Second)
+			}
+			cancel()
+			if samplesCount != len(test.want) {
+				t.Errorf("got an unexpected number of sampled elements: %v, want: %v", samplesCount, len(test.want))
+			}
+			if !verifySampledElements(samples, test.want) {
+				t.Errorf("got unexpected sampled elements: %v, want: %v", samples, test.want)
+			}
+		})
+	}
+}
+
+func verifySampledElements(samples, want map[string][]*dataSample) bool {
+	if len(samples) != len(want) {
+		return false
+	}
+	for pid, samples := range samples {
+		expected, ok := want[pid]
+		if !ok {
+			return false
+		}
+		sort.SliceStable(samples, func(i, j int) bool {
+			return string(samples[i].Element) < string(samples[j].Element)
+		})
+		sort.SliceStable(expected, func(i, j int) bool {
+			return string(expected[i].Element) < string(expected[j].Element)
+		})
+		if !reflect.DeepEqual(samples, expected) {
+			return false
+		}
+	}
+	return true
+}
diff --git a/sdks/go/pkg/beam/core/runtime/exec/datasource.go b/sdks/go/pkg/beam/core/runtime/exec/datasource.go
index 401cdbef7a37..674de44cf35b 100644
--- a/sdks/go/pkg/beam/core/runtime/exec/datasource.go
+++ b/sdks/go/pkg/beam/core/runtime/exec/datasource.go
@@ -241,6 +241,7 @@ func (n *DataSource) Process(ctx context.Context) ([]*Checkpoint, error) {
 				return err
 			}
 			// Collect the actual size of the element, and reset the bytecounter reader.
+			// TODO(zechenj18) 2023-12-07: Currently we never sample anything from the DataSource; we need to validate CoGBKs and similar types against the sampling implementation.
 			n.PCol.addSize(int64(bcr.reset()))
 
 			// Check if there's a continuation and return residuals
diff --git a/sdks/go/pkg/beam/core/runtime/exec/pcollection.go b/sdks/go/pkg/beam/core/runtime/exec/pcollection.go
index 3b2e3ab3bf2c..ed13c65a6f5b 100644
--- a/sdks/go/pkg/beam/core/runtime/exec/pcollection.go
+++ b/sdks/go/pkg/beam/core/runtime/exec/pcollection.go
@@ -16,12 +16,14 @@
 package exec
 
 import (
+	"bytes"
 	"context"
 	"fmt"
 	"math"
 	"math/rand"
 	"sync"
 	"sync/atomic"
+	"time"
 
 	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/coder"
 )
@@ -32,19 +34,22 @@ import (
 // In particular, must not be placed after a Multiplex, and must be placed
 // after a Flatten.
 type PCollection struct {
-	UID    UnitID
-	PColID string
-	Out    Node // Out is the consumer of this PCollection.
-	Coder  *coder.Coder
-	Seed   int64
+	UID         UnitID
+	PColID      string
+	Out         Node // Out is the consumer of this PCollection.
+	Coder       *coder.Coder
+	WindowCoder *coder.WindowCoder
+	Seed        int64
 
 	r             *rand.Rand
 	nextSampleIdx int64 // The index of the next value to sample.
 	elementCoder  ElementEncoder
+	windowCoder   WindowEncoder
 
 	elementCount                         int64 // must use atomic operations.
 	sizeMu                               sync.Mutex
 	sizeCount, sizeSum, sizeMin, sizeMax int64
+	dataSampler                          *DataSampler
 }
 
 // ID returns the debug id for this unit.
@@ -57,6 +62,7 @@ func (p *PCollection) Up(ctx context.Context) error {
 	// dedicated rand source
 	p.r = rand.New(rand.NewSource(p.Seed))
 	p.elementCoder = MakeElementEncoder(p.Coder)
+	p.windowCoder = MakeWindowEncoder(p.WindowCoder)
 	return nil
 }
@@ -93,9 +99,19 @@ func (p *PCollection) ProcessElement(ctx context.Context, elm *FullValue, values
 		} else {
 			p.nextSampleIdx = cur + p.r.Int63n(cur/10+2) + 1
 		}
-		var w byteCounter
-		p.elementCoder.Encode(elm, &w)
-		p.addSize(int64(w.count))
+
+		if p.dataSampler == nil {
+			var w byteCounter
+			p.elementCoder.Encode(elm, &w)
+			p.addSize(int64(w.count))
+		} else {
+			var buf bytes.Buffer
+			EncodeWindowedValueHeader(p.windowCoder, elm.Windows, elm.Timestamp, elm.Pane, &buf)
+			winSize := buf.Len()
+			p.elementCoder.Encode(elm, &buf)
+			p.addSize(int64(buf.Len() - winSize))
+			p.dataSampler.SendSample(p.PColID, buf.Bytes(), time.Now())
+		}
 	}
 	return p.Out.ProcessElement(ctx, elm, values...)
 }
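In the sampling branch above, the windowed-value header is encoded first, the buffer length is noted, and the element is appended, so addSize still counts element bytes only while the sample payload carries the full windowed value. A self-contained sketch of that buffer-offset bookkeeping; encodeHeader and encodeElement are stand-ins, not the exec package's coders:

package main

import (
	"bytes"
	"fmt"
)

// encodeHeader stands in for EncodeWindowedValueHeader: window,
// timestamp, and pane metadata written ahead of the element.
func encodeHeader(buf *bytes.Buffer) {
	buf.Write([]byte{0x01, 0x02, 0x03}) // fake header bytes
}

// encodeElement stands in for the element coder.
func encodeElement(buf *bytes.Buffer, elm string) {
	buf.WriteString(elm)
}

func main() {
	var buf bytes.Buffer
	encodeHeader(&buf)
	winSize := buf.Len() // bytes consumed by the header alone
	encodeElement(&buf, "hello")

	// Size accounting counts only the element bytes...
	fmt.Println("element size:", buf.Len()-winSize)
	// ...while the sample payload keeps header + element together.
	fmt.Println("sample bytes:", buf.Len())
}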
diff --git a/sdks/go/pkg/beam/core/runtime/exec/pcollection_test.go b/sdks/go/pkg/beam/core/runtime/exec/pcollection_test.go
index 1cb6adee97d6..1b702588c051 100644
--- a/sdks/go/pkg/beam/core/runtime/exec/pcollection_test.go
+++ b/sdks/go/pkg/beam/core/runtime/exec/pcollection_test.go
@@ -30,7 +30,7 @@ import (
 // randomness for the samples.
 func TestPCollection(t *testing.T) {
 	a := &CaptureNode{UID: 1}
-	pcol := &PCollection{UID: 2, Out: a, Coder: coder.NewVarInt()}
+	pcol := &PCollection{UID: 2, Out: a, Coder: coder.NewVarInt(), WindowCoder: coder.NewGlobalWindow()}
 	// The "large" 2nd value is to ensure the values are encoded properly,
 	// and that Min & Max are behaving.
 	inputs := []any{int64(1), int64(2000000000), int64(3)}
@@ -99,7 +99,7 @@ func BenchmarkPCollection(b *testing.B) {
 			Elm: int64(1),
 		}})
 	}
-	pcol := &PCollection{UID: 2, Out: out, Coder: coder.NewVarInt()}
+	pcol := &PCollection{UID: 2, Out: out, Coder: coder.NewVarInt(), WindowCoder: coder.NewGlobalWindow()}
 	n := &FixedRoot{UID: 3, Elements: process, Out: pcol}
 	p, err := NewPlan("a", []Unit{n, pcol, out})
 	if err != nil {
diff --git a/sdks/go/pkg/beam/core/runtime/exec/to_string.go b/sdks/go/pkg/beam/core/runtime/exec/to_string.go
index 2196fd951806..df7050483b16 100644
--- a/sdks/go/pkg/beam/core/runtime/exec/to_string.go
+++ b/sdks/go/pkg/beam/core/runtime/exec/to_string.go
@@ -45,6 +45,7 @@ func (m *ToString) ProcessElement(ctx context.Context, elm *FullValue, values ..
 		Elm:       elm.Elm,
 		Elm2:      fmt.Sprintf("%v", elm.Elm2),
 		Timestamp: elm.Timestamp,
+		Pane:      elm.Pane,
 	}
 
 	return m.Out.ProcessElement(ctx, &ret, values...)
diff --git a/sdks/go/pkg/beam/core/runtime/exec/translate.go b/sdks/go/pkg/beam/core/runtime/exec/translate.go
index 6b3e3e457229..115fe187daa4 100644
--- a/sdks/go/pkg/beam/core/runtime/exec/translate.go
+++ b/sdks/go/pkg/beam/core/runtime/exec/translate.go
@@ -51,8 +51,8 @@ const (
 )
 
 // UnmarshalPlan converts a model bundle descriptor into an execution Plan.
-func UnmarshalPlan(desc *fnpb.ProcessBundleDescriptor) (*Plan, error) {
-	b, err := newBuilder(desc)
+func UnmarshalPlan(desc *fnpb.ProcessBundleDescriptor, dataSampler *DataSampler) (*Plan, error) {
+	b, err := newBuilder(desc, dataSampler)
 	if err != nil {
 		return nil, err
 	}
@@ -169,8 +169,9 @@ type builder struct {
 	nodes map[string]*PCollection // PCollectionID -> Node (cache)
 	links map[linkID]Node         // linkID -> Node (cache)
 
-	units []Unit // result
-	idgen *GenID
+	units       []Unit // result
+	idgen       *GenID
+	dataSampler *DataSampler
 }
 
 // linkID represents an incoming data link to an Node.
@@ -179,7 +180,7 @@ type linkID struct {
 	input int // input index. If > 0, it's a side input.
 }
 
-func newBuilder(desc *fnpb.ProcessBundleDescriptor) (*builder, error) {
+func newBuilder(desc *fnpb.ProcessBundleDescriptor, dataSampler *DataSampler) (*builder, error) {
 	// Preprocess graph structure to allow insertion of Multiplex,
 	// Flatten and Discard.
 
@@ -216,7 +217,8 @@ func newBuilder(desc *fnpb.ProcessBundleDescriptor) (*builder, error) {
 		nodes: make(map[string]*PCollection),
 		links: make(map[linkID]Node),
 
-		idgen: &GenID{},
+		idgen:       &GenID{},
+		dataSampler: dataSampler,
 	}
 	return b, nil
 }
@@ -411,11 +413,11 @@ func (b *builder) makePCollection(id string) (*PCollection, error) {
 }
 
 func (b *builder) newPCollectionNode(id string, out Node) (*PCollection, error) {
-	ec, _, err := b.makeCoderForPCollection(id)
+	ec, wc, err := b.makeCoderForPCollection(id)
 	if err != nil {
 		return nil, err
 	}
-	u := &PCollection{UID: b.idgen.New(), Out: out, PColID: id, Coder: ec, Seed: rand.Int63()}
+	u := &PCollection{UID: b.idgen.New(), Out: out, PColID: id, Coder: ec, WindowCoder: wc, Seed: rand.Int63(), dataSampler: b.dataSampler}
 	b.nodes[id] = u
 	b.units = append(b.units, u)
 	return u, nil
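UnmarshalPlan now threads an optional *DataSampler through the builder into every PCollection node; passing nil (as the updated tests below do) leaves sampling off. A compile-oriented sketch of the two call shapes, assuming the caller already has a bundle descriptor; the package and function names here are hypothetical:

package harnessutil

import (
	"context"

	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/exec"
	fnpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/fnexecution_v1"
)

// buildPlan shows both call shapes of the changed constructor: a nil
// sampler disables sampling, a live one enables it for all PCollections.
func buildPlan(ctx context.Context, desc *fnpb.ProcessBundleDescriptor, sample bool) (*exec.Plan, error) {
	if !sample {
		return exec.UnmarshalPlan(desc, nil)
	}
	ds := exec.NewDataSampler(ctx)
	go ds.Process() // drain sampled elements until ctx is done
	return exec.UnmarshalPlan(desc, ds)
}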
newBuilder(%v) = (%v, %v), want (%v, %v)", test.inputDesc, b, err, test.outputBuilder, test.outputError) } else if err != nil && err != test.outputError { diff --git a/sdks/go/pkg/beam/core/runtime/graphx/translate.go b/sdks/go/pkg/beam/core/runtime/graphx/translate.go index 9ef28eb7809b..3b7cfc5639cd 100644 --- a/sdks/go/pkg/beam/core/runtime/graphx/translate.go +++ b/sdks/go/pkg/beam/core/runtime/graphx/translate.go @@ -73,6 +73,7 @@ const ( URNMultiCore = "beam:protocol:multi_core_bundle_processing:v1" URNWorkerStatus = "beam:protocol:worker_status:v1" URNMonitoringInfoShortID = "beam:protocol:monitoring_info_short_ids:v1" + URNDataSampling = "beam:protocol:data_sampling:v1" URNRequiresSplittableDoFn = "beam:requirement:pardo:splittable_dofn:v1" URNRequiresBundleFinalization = "beam:requirement:pardo:finalization:v1" @@ -109,6 +110,7 @@ func goCapabilities() []string { URNMonitoringInfoShortID, URNBaseVersionGo, URNToString, + URNDataSampling, } return append(capabilities, knownStandardCoders()...) } diff --git a/sdks/go/pkg/beam/core/runtime/harness/datamgr.go b/sdks/go/pkg/beam/core/runtime/harness/datamgr.go index ed57e3eca59b..0f2de99dd2ad 100644 --- a/sdks/go/pkg/beam/core/runtime/harness/datamgr.go +++ b/sdks/go/pkg/beam/core/runtime/harness/datamgr.go @@ -386,14 +386,31 @@ func (c *DataChannel) read(ctx context.Context) { return } + // Consolidate required timer writer creation into an optional single lock section. + type seenTimers struct { + InstID instructionID + PTransformID, FamilyID string + } + neededTimers := map[seenTimers]struct{}{} + // Each message may contain segments for multiple streams, so we // must treat each segment in isolation. We maintain a local cache // to reduce lock contention. iterateElements(c, cache, &seenLast, msg.GetTimers(), func(elm *fnpb.Elements_Timers) exec.Elements { + neededTimers[seenTimers{InstID: instructionID(elm.GetInstructionId()), PTransformID: elm.GetTransformId(), FamilyID: elm.GetTimerFamilyId()}] = struct{}{} return exec.Elements{Timers: elm.GetTimers(), PtransformID: elm.GetTransformId(), TimerFamilyID: elm.GetTimerFamilyId()} }) + // Creating a writer is necessary to ensure an "is_last" signal is returned for timers that aren't set. + if len(neededTimers) > 0 { + c.mu.Lock() + for key := range neededTimers { + c.makeTimerWriterLocked(ctx, clientID{ptransformID: key.PTransformID, instID: key.InstID}, key.FamilyID) + } + c.mu.Unlock() + } + iterateElements(c, cache, &seenLast, msg.GetData(), func(elm *fnpb.Elements_Data) exec.Elements { return exec.Elements{Data: elm.GetData(), PtransformID: elm.GetTransformId()} @@ -629,7 +646,13 @@ func (w *dataWriter) Write(p []byte) (n int, err error) { func (c *DataChannel) makeTimerWriter(ctx context.Context, id clientID, family string) *timerWriter { c.mu.Lock() defer c.mu.Unlock() + return c.makeTimerWriterLocked(ctx, id, family) +} +// makeTimerWriterLocked does the work of makeTimerWriter, but doesn't call the lock methods. +// +// c.mu must be locked when this is called.
+func (c *DataChannel) makeTimerWriterLocked(ctx context.Context, id clientID, family string) *timerWriter { var m map[timerKey]*timerWriter var ok bool if m, ok = c.timerWriters[id.instID]; !ok { diff --git a/sdks/go/pkg/beam/core/runtime/harness/datamgr_test.go b/sdks/go/pkg/beam/core/runtime/harness/datamgr_test.go index c7f8ac5858c1..92c4d0a8f8cd 100644 --- a/sdks/go/pkg/beam/core/runtime/harness/datamgr_test.go +++ b/sdks/go/pkg/beam/core/runtime/harness/datamgr_test.go @@ -389,6 +389,18 @@ func TestElementChan(t *testing.T) { return elms }, wantSum: 0, wantCount: 0, + }, { + name: "SomeTimersAndADataThenReaderThenCleanup", + sequenceFn: func(ctx context.Context, t *testing.T, client *fakeChanClient, c *DataChannel) <-chan exec.Elements { + client.Send(&fnpb.Elements{ + Timers: []*fnpb.Elements_Timers{timerElm(1, false), timerElm(2, true)}, + Data: []*fnpb.Elements_Data{dataElm(3, true)}, + }) + elms := openChan(ctx, t, c, timerID) + c.removeInstruction(instID) + return elms + }, + wantSum: 6, wantCount: 3, }, } for _, test := range tests { diff --git a/sdks/go/pkg/beam/core/runtime/harness/harness.go b/sdks/go/pkg/beam/core/runtime/harness/harness.go index c5db9a85f367..6a66c81a0a60 100644 --- a/sdks/go/pkg/beam/core/runtime/harness/harness.go +++ b/sdks/go/pkg/beam/core/runtime/harness/harness.go @@ -30,6 +30,7 @@ import ( "github.com/apache/beam/sdks/v2/go/pkg/beam" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/metrics" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/exec" + "github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/graphx" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/harness/statecache" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/util/hooks" "github.com/apache/beam/sdks/v2/go/pkg/beam/internal/errors" @@ -40,6 +41,7 @@ import ( "golang.org/x/sync/singleflight" "google.golang.org/grpc" "google.golang.org/protobuf/types/known/durationpb" + "google.golang.org/protobuf/types/known/timestamppb" ) // URNMonitoringInfoShortID is a URN indicating support for short monitoring info IDs. @@ -157,6 +159,11 @@ func MainWithOptions(ctx context.Context, loggingEndpoint, controlEndpoint strin runnerCapabilities: rcMap, } + if enabled, ok := rcMap[graphx.URNDataSampling]; ok && enabled { + ctrl.dataSampler = exec.NewDataSampler(ctx) + go ctrl.dataSampler.Process() + } + // if the runner supports worker status api then expose SDK harness status if opts.StatusEndpoint != "" { statusHandler, err := newWorkerStatusHandler(ctx, opts.StatusEndpoint, ctrl.cache, func(statusInfo *strings.Builder) { ctrl.metStoreToString(statusInfo) }) @@ -304,6 +311,7 @@ type control struct { // TODO(BEAM-11097): Cache is currently unused. 
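The `makeTimerWriter`/`makeTimerWriterLocked` split above follows a common pattern: the public method acquires the lock, while the `Locked` variant assumes the caller already holds it, so `read` can create several writers under a single lock acquisition. A minimal Python sketch of the pattern (hypothetical names, not the Beam code):

```python
import threading

class TimerWriters:
    """Sketch of the *Locked-variant* locking pattern."""
    def __init__(self):
        self._mu = threading.Lock()
        self._writers = {}

    def make_writer(self, key):
        # Public entry point: takes the lock itself.
        with self._mu:
            return self._make_writer_locked(key)

    def _make_writer_locked(self, key):
        # self._mu must be held by the caller.
        if key not in self._writers:
            self._writers[key] = object()  # stand-in for a real writer
        return self._writers[key]

    def make_writers(self, keys):
        # Batch creation: one lock acquisition for many writers,
        # mirroring the neededTimers loop in DataChannel.read.
        with self._mu:
            return [self._make_writer_locked(k) for k in keys]
```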
cache *statecache.SideInputCache runnerCapabilities map[string]bool + dataSampler *exec.DataSampler } func (c *control) metStoreToString(statusInfo *strings.Builder) { @@ -345,7 +353,7 @@ func (c *control) getOrCreatePlan(bdID bundleDescriptorID) (*exec.Plan, error) { } desc = newDesc.(*fnpb.ProcessBundleDescriptor) } - newPlan, err := exec.UnmarshalPlan(desc) + newPlan, err := exec.UnmarshalPlan(desc, c.dataSampler) if err != nil { return nil, errors.WithContextf(err, "invalid bundle desc: %v\n%v\n", bdID, desc.String()) } @@ -654,7 +662,28 @@ func (c *control) handleInstruction(ctx context.Context, req *fnpb.InstructionRe }, }, } - + case req.GetSampleData() != nil: + msg := req.GetSampleData() + var samples = make(map[string]*fnpb.SampleDataResponse_ElementList) + var elementsMap = c.dataSampler.GetSamples(msg.GetPcollectionIds()) + + for pid, elements := range elementsMap { + var elementList fnpb.SampleDataResponse_ElementList + for i := range elements { + var sampledElement = &fnpb.SampledElement{ + Element: elements[i].Element, + SampleTimestamp: timestamppb.New(elements[i].Timestamp), + } + elementList.Elements = append(elementList.Elements, sampledElement) + } + samples[pid] = &elementList + } + return &fnpb.InstructionResponse{ + InstructionId: string(instID), + Response: &fnpb.InstructionResponse_SampleData{ + SampleData: &fnpb.SampleDataResponse{ElementSamples: samples}, + }, + } default: return fail(ctx, instID, "Unexpected request: %v", req) } diff --git a/sdks/go/pkg/beam/core/runtime/harness/harness_test.go b/sdks/go/pkg/beam/core/runtime/harness/harness_test.go index 84c5770c71a1..91dd3c591d5b 100644 --- a/sdks/go/pkg/beam/core/runtime/harness/harness_test.go +++ b/sdks/go/pkg/beam/core/runtime/harness/harness_test.go @@ -94,7 +94,7 @@ func invalidDescriptor(t *testing.T) *fnpb.ProcessBundleDescriptor { func TestControl_getOrCreatePlan(t *testing.T) { testBDID := bundleDescriptorID("test") - testPlan, err := exec.UnmarshalPlan(validDescriptor(t)) + testPlan, err := exec.UnmarshalPlan(validDescriptor(t), nil) if err != nil { t.Fatal("bad testPlan") } diff --git a/sdks/go/pkg/beam/core/runtime/harness/init/init.go b/sdks/go/pkg/beam/core/runtime/harness/init/init.go index 55f3d0beabdc..27d44f7029e5 100644 --- a/sdks/go/pkg/beam/core/runtime/harness/init/init.go +++ b/sdks/go/pkg/beam/core/runtime/harness/init/init.go @@ -22,6 +22,7 @@ import ( "context" "encoding/json" "flag" + "slices" "strings" "time" @@ -31,6 +32,7 @@ import ( "runtime/debug" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime" + "github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/graphx" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/harness" // Import gcs filesystem so that it can be used to upload heap dumps. @@ -78,6 +80,13 @@ func hook() { return } + // Extract environment variables. These are optional runner supported capabilities. + // Expected env variables: + // RUNNER_CAPABILITIES : list of runner supported capability urn. + // STATUS_ENDPOINT : Endpoint to connect to status server used for worker status reporting. 
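The new `SampleData` branch in `handleInstruction` maps each requested PCollection id to the samples buffered by the data sampler. A plain-dict Python sketch of the same shaping (the real code builds `fnpb` protos with `timestamppb` timestamps; all names here are illustrative):

```python
from dataclasses import dataclass
from datetime import datetime

@dataclass
class Sample:  # stand-in for a buffered sample in the harness
    element: bytes
    timestamp: datetime

def build_sample_response(instruction_id: str, samples: dict) -> dict:
    # One ElementList per PCollection id, one entry per sample.
    return {
        "instruction_id": instruction_id,
        "sample_data": {
            "element_samples": {
                pid: {"elements": [{
                    "element": s.element,
                    "sample_timestamp": s.timestamp.isoformat(),
                } for s in elems]}
                for pid, elems in samples.items()
            },
        },
    }
```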
+ statusEndpoint := os.Getenv("STATUS_ENDPOINT") + runnerCapabilities := strings.Split(os.Getenv("RUNNER_CAPABILITIES"), " ") + // Initialization logging // // We use direct output to stderr here, because it is expected that logging @@ -91,6 +100,14 @@ func hook() { os.Exit(1) } runtime.GlobalOptions.Import(opt.Options) + var experiments []string + if e, ok := opt.Options.Options["experiments"]; ok { + experiments = strings.Split(e, ",") + } + // TODO(zechenj18) 2023-12-07: Remove once the data sampling URN is properly sent in via the capabilities + if slices.Contains(experiments, "enable_data_sampling") { + runnerCapabilities = append(runnerCapabilities, graphx.URNDataSampling) + } } defer func() { @@ -120,12 +137,6 @@ func hook() { fmt.Println("Error Setting Rlimit ", err) } - // Extract environment variables. These are optional runner supported capabilities. - // Expected env variables: - // RUNNER_CAPABILITIES : list of runner supported capability urn. - // STATUS_ENDPOINT : Endpoint to connect to status server used for worker status reporting. - statusEndpoint := os.Getenv("STATUS_ENDPOINT") - runnerCapabilities := strings.Split(os.Getenv("RUNNER_CAPABILITIES"), " ") options := harness.Options{ StatusEndpoint: statusEndpoint, RunnerCapabilities: runnerCapabilities, diff --git a/sdks/go/pkg/beam/core/typex/fulltype.go b/sdks/go/pkg/beam/core/typex/fulltype.go index 41ef0ab09d22..ff5520c28617 100644 --- a/sdks/go/pkg/beam/core/typex/fulltype.go +++ b/sdks/go/pkg/beam/core/typex/fulltype.go @@ -124,7 +124,7 @@ func New(t reflect.Type, components ...FullType) FullType { if len(components) != 2 { panic(fmt.Sprintf("Invalid number of components for KV: %v, %v", t, components)) } - if isAnyNonKVComposite(components) { + if isAnyNonKVAndNonWindowedComposite(components) { panic(fmt.Sprintf("Invalid to nest composite composites inside KV: %v, %v", t, components)) } return &tree{class, t, components} @@ -169,6 +169,15 @@ func isAnyNonKVComposite(list []FullType) bool { return false } +func isAnyNonKVAndNonWindowedComposite(list []FullType) bool { + for _, t := range list { + if t.Class() == Composite && t.Type() != KVType && t.Type() != WindowedValueType { + return true + } + } + return false +} + // Convenience functions. // IsW returns true iff the type is a WindowedValue. diff --git a/sdks/go/pkg/beam/runners/dataflow/dataflowlib/job.go b/sdks/go/pkg/beam/runners/dataflow/dataflowlib/job.go index eb26071d10ec..ed706ec1a482 100644 --- a/sdks/go/pkg/beam/runners/dataflow/dataflowlib/job.go +++ b/sdks/go/pkg/beam/runners/dataflow/dataflowlib/job.go @@ -154,6 +154,7 @@ func Translate(ctx context.Context, p *pipepb.Pipeline, opts *JobOptions, worker return nil, err } + opts.Options.Options["experiments"] = strings.Join(opts.Experiments, ",") job := &df.Job{ ProjectId: opts.Project, Name: opts.Name, diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/SchemaTransformProvider.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/SchemaTransformProvider.java index e542007c9a55..c76d7a25e69b 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/SchemaTransformProvider.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/SchemaTransformProvider.java @@ -36,6 +36,15 @@ public interface SchemaTransformProvider { /** Returns an id that uniquely represents this transform. 
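The init hook above derives the effective runner capabilities from the `RUNNER_CAPABILITIES` environment variable, then (as a temporary bridge) promotes the `enable_data_sampling` experiment to the data-sampling capability URN. A condensed Python sketch of that gating, under the assumption that pipeline options arrive as a flat dict:

```python
import os

def effective_capabilities(options: dict) -> list:
    """Illustrative sketch of the init hook's capability gating."""
    caps = os.environ.get("RUNNER_CAPABILITIES", "").split(" ")
    experiments = options.get("experiments", "").split(",")
    # Bridge until runners send the URN in RUNNER_CAPABILITIES directly.
    if "enable_data_sampling" in experiments:
        caps.append("beam:protocol:data_sampling:v1")
    return caps
```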
*/ String identifier(); + /** + * Returns a description regarding the {@link SchemaTransform} represented by the {@link + * SchemaTransformProvider}. Please keep the language generic (i.e. not specific to any + * programming language). The description may be markdown formatted. + */ + default String description() { + return ""; + } + /** * Returns the expected schema of the configuration object. Note this is distinct from the schema * of the transform itself. diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/transforms/TypedSchemaTransformProviderTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/transforms/TypedSchemaTransformProviderTest.java index db7b1436a128..2b698f4f67bb 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/transforms/TypedSchemaTransformProviderTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/transforms/TypedSchemaTransformProviderTest.java @@ -61,6 +61,11 @@ public String identifier() { return "fake:v1"; } + @Override + public String description() { + return "Description of fake provider"; + } + @Override protected Class configurationClass() { return Configuration.class; @@ -115,6 +120,7 @@ public void testFrom() { Configuration outputConfig = ((FakeSchemaTransform) provider.from(inputConfig)).config; assertEquals("field1", outputConfig.getField1()); assertEquals(13, outputConfig.getField2().intValue()); + assertEquals("Description of fake provider", provider.description()); } @Test diff --git a/sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/ExpansionService.java b/sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/ExpansionService.java index 43690c603701..7760cab64acc 100644 --- a/sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/ExpansionService.java +++ b/sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/ExpansionService.java @@ -767,6 +767,7 @@ DiscoverSchemaTransformResponse discover(DiscoverSchemaTransformRequest request) transformProvider.getAllProviders()) { SchemaTransformConfig.Builder schemaTransformConfigBuilder = SchemaTransformConfig.newBuilder(); + schemaTransformConfigBuilder.setDescription(provider.description()); schemaTransformConfigBuilder.setConfigSchema( SchemaTranslation.schemaToProto(provider.configurationSchema(), true)); schemaTransformConfigBuilder.addAllInputPcollectionNames(provider.inputCollectionNames()); diff --git a/sdks/java/extensions/sql/expansion-service/src/main/java/org/apache/beam/sdk/extensions/sql/expansion/SqlTransformSchemaTransformProvider.java b/sdks/java/extensions/sql/expansion-service/src/main/java/org/apache/beam/sdk/extensions/sql/expansion/SqlTransformSchemaTransformProvider.java index 54415644152f..f032da0799d8 100644 --- a/sdks/java/extensions/sql/expansion-service/src/main/java/org/apache/beam/sdk/extensions/sql/expansion/SqlTransformSchemaTransformProvider.java +++ b/sdks/java/extensions/sql/expansion-service/src/main/java/org/apache/beam/sdk/extensions/sql/expansion/SqlTransformSchemaTransformProvider.java @@ -71,6 +71,21 @@ public String identifier() { return "schematransform:org.apache.beam:sql_transform:v1"; } + @Override + public String description() { + return "A transform that executes a SQL query on its input PCollections.\n\n" + + "If a single input is given, it may be referred to as `PCOLLECTION`, e.g. 
the query could be of the form" + + "\n\n" + + " SELECT a, sum(b) FROM PCOLLECTION" + + "\n\n" + + "If multiple inputs are given, they should be named as they are in the query, e.g." + + "\n\n" + + " SELECT a, b, c FROM pcoll_1 join pcoll_2 using (b)" + + "\n\n" + + "For more details about Beam SQL in general see " + + "[the Beam SQL documentation](https://beam.apache.org/documentation/dsls/sql/overview/)."; + } + @Override public Schema configurationSchema() { List providers = new ArrayList<>(); @@ -82,7 +97,7 @@ public Schema configurationSchema() { EnumerationType providerEnum = EnumerationType.create(providers); return Schema.of( - Schema.Field.of("query", Schema.FieldType.STRING), + Schema.Field.of("query", Schema.FieldType.STRING).withDescription("SQL query to execute"), Schema.Field.nullable( "ddl", Schema.FieldType.STRING), // TODO: Underlying builder seems more capable? Schema.Field.nullable("dialect", Schema.FieldType.logicalType(QUERY_ENUMERATION)), diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/MultimapSideInput.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/MultimapSideInput.java index 619eea6cc70f..ec7429fcdc0e 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/MultimapSideInput.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/MultimapSideInput.java @@ -54,6 +54,7 @@ public class MultimapSideInput implements MultimapView { private final Coder keyCoder; private final Coder valueCoder; private volatile Function> bulkReadResult; + private final boolean useBulkRead; public MultimapSideInput( Cache cache, @@ -62,6 +63,18 @@ public MultimapSideInput( StateKey stateKey, Coder keyCoder, Coder valueCoder) { + // TODO(robertwb): Plumb the value of useBulkRead from runner capabilities.
+ this(cache, beamFnStateClient, instructionId, stateKey, keyCoder, valueCoder, false); + } + + public MultimapSideInput( + Cache cache, + BeamFnStateClient beamFnStateClient, + String instructionId, + StateKey stateKey, + Coder keyCoder, + Coder valueCoder, + boolean useBulkRead) { checkArgument( stateKey.hasMultimapKeysSideInput(), "Expected MultimapKeysSideInput StateKey but received %s.", @@ -72,6 +85,7 @@ public MultimapSideInput( StateRequest.newBuilder().setInstructionId(instructionId).setStateKey(stateKey).build(); this.keyCoder = keyCoder; this.valueCoder = valueCoder; + this.useBulkRead = useBulkRead; } @Override @@ -84,62 +98,70 @@ public Iterable get() { public Iterable get(K k) { ByteString encodedKey = encodeKey(k); - if (bulkReadResult == null) { - synchronized (this) { - if (bulkReadResult == null) { - Map> bulkRead = new HashMap<>(); - StateKey bulkReadStateKey = - StateKey.newBuilder() - .setMultimapKeysValuesSideInput( - StateKey.MultimapKeysValuesSideInput.newBuilder() - .setTransformId( - keysRequest.getStateKey().getMultimapKeysSideInput().getTransformId()) - .setSideInputId( - keysRequest.getStateKey().getMultimapKeysSideInput().getSideInputId()) - .setWindow( - keysRequest.getStateKey().getMultimapKeysSideInput().getWindow())) - .build(); + if (useBulkRead) { + if (bulkReadResult == null) { + synchronized (this) { + if (bulkReadResult == null) { + Map> bulkRead = new HashMap<>(); + StateKey bulkReadStateKey = + StateKey.newBuilder() + .setMultimapKeysValuesSideInput( + StateKey.MultimapKeysValuesSideInput.newBuilder() + .setTransformId( + keysRequest + .getStateKey() + .getMultimapKeysSideInput() + .getTransformId()) + .setSideInputId( + keysRequest + .getStateKey() + .getMultimapKeysSideInput() + .getSideInputId()) + .setWindow( + keysRequest.getStateKey().getMultimapKeysSideInput().getWindow())) + .build(); - StateRequest bulkReadRequest = - keysRequest.toBuilder().setStateKey(bulkReadStateKey).build(); - try { - Iterator>> entries = - StateFetchingIterators.readAllAndDecodeStartingFrom( - Caches.subCache(cache, "ValuesForKey", encodedKey), - beamFnStateClient, - bulkReadRequest, - KvCoder.of(keyCoder, IterableCoder.of(valueCoder))) - .iterator(); - while (bulkRead.size() < BULK_READ_SIZE && entries.hasNext()) { - KV> entry = entries.next(); - bulkRead.put(encodeKey(entry.getKey()), entry.getValue()); - } - if (entries.hasNext()) { + StateRequest bulkReadRequest = + keysRequest.toBuilder().setStateKey(bulkReadStateKey).build(); + try { + Iterator>> entries = + StateFetchingIterators.readAllAndDecodeStartingFrom( + Caches.subCache(cache, "ValuesForKey", encodedKey), + beamFnStateClient, + bulkReadRequest, + KvCoder.of(keyCoder, IterableCoder.of(valueCoder))) + .iterator(); + while (bulkRead.size() < BULK_READ_SIZE && entries.hasNext()) { + KV> entry = entries.next(); + bulkRead.put(encodeKey(entry.getKey()), entry.getValue()); + } + if (entries.hasNext()) { + bulkReadResult = bulkRead::get; + } else { + bulkReadResult = + key -> { + Iterable result = bulkRead.get(key); + if (result == null) { + // As we read the entire set of values, we don't have to do a lookup to know + // this key doesn't exist. + // Missing keys are treated as empty iterables in this multimap. 
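The guarded bulk read above uses double-checked locking on `bulkReadResult` and distinguishes a complete read, where a missing key is known to be empty, from a partial one, where a miss must fall back to a point lookup. A Python sketch of the same control flow (illustrative only; the Python SDK's `_BULK_READ_LIMIT` path later in this diff follows the same shape):

```python
import threading

class BulkReadMultimap:
    """Sketch of the bulk-read strategy; not the Beam implementation."""
    BULK_READ_SIZE = 100

    def __init__(self, fetch_all, point_lookup):
        self._fetch_all = fetch_all        # yields (key, values) pairs
        self._point_lookup = point_lookup  # per-key state fetch
        self._lock = threading.Lock()
        self._bulk = None                  # callable: key -> values or None

    def get(self, key):
        if self._bulk is None:
            with self._lock:
                if self._bulk is None:  # double-checked locking
                    cache, complete = {}, True
                    for i, (k, vs) in enumerate(self._fetch_all()):
                        if i >= self.BULK_READ_SIZE:
                            complete = False  # too large to hold fully
                            break
                        cache[k] = vs
                    if complete:
                        # Full read: a miss means the key has no values.
                        self._bulk = lambda k: cache.get(k, [])
                    else:
                        # Partial read: a miss means "unknown".
                        self._bulk = cache.get
        values = self._bulk(key)
        return values if values is not None else self._point_lookup(key)
```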
+ return Collections.emptyList(); + } else { + return result; + } + }; + } + } catch (Exception exn) { bulkReadResult = bulkRead::get; - } else { - bulkReadResult = - key -> { - Iterable result = bulkRead.get(key); - if (result == null) { - // As we read the entire set of values, we don't have to do a lookup to know - // this key doesn't exist. - // Missing keys are treated as empty iterables in this multimap. - return Collections.emptyList(); - } else { - return result; - } - }; } - } catch (Exception exn) { - bulkReadResult = bulkRead::get; } } } - } - Iterable bulkReadValues = bulkReadResult.apply(encodedKey); - if (bulkReadValues != null) { - return bulkReadValues; + Iterable bulkReadValues = bulkReadResult.apply(encodedKey); + if (bulkReadValues != null) { + return bulkReadValues; + } } StateKey stateKey = diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/MultimapSideInputTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/MultimapSideInputTest.java index 17ebf4234396..23a572894b40 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/MultimapSideInputTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/MultimapSideInputTest.java @@ -69,7 +69,8 @@ public void testGetWithBulkRead() throws Exception { "instructionId", keysStateKey(), ByteArrayCoder.of(), - StringUtf8Coder.of()); + StringUtf8Coder.of(), + true); assertArrayEquals( new String[] {"A1", "A2", "A3"}, Iterables.toArray(multimapSideInput.get(A), String.class)); assertArrayEquals( @@ -94,7 +95,8 @@ public void testGet() throws Exception { "instructionId", keysStateKey(), ByteArrayCoder.of(), - StringUtf8Coder.of()); + StringUtf8Coder.of(), + true); assertArrayEquals( new String[] {"A1", "A2", "A3"}, Iterables.toArray(multimapSideInput.get(A), String.class)); assertArrayEquals( @@ -124,7 +126,8 @@ public void testGetCached() throws Exception { "instructionId", keysStateKey(), ByteArrayCoder.of(), - StringUtf8Coder.of()); + StringUtf8Coder.of(), + true); assertArrayEquals( new String[] {"A1", "A2", "A3"}, Iterables.toArray(multimapSideInput.get(A), String.class)); @@ -147,7 +150,8 @@ public void testGetCached() throws Exception { "instructionId", keysStateKey(), ByteArrayCoder.of(), - StringUtf8Coder.of()); + StringUtf8Coder.of(), + true); assertArrayEquals( new String[] {"A1", "A2", "A3"}, Iterables.toArray(multimapSideInput.get(A), String.class)); diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryStorageWriteApiSchemaTransformProvider.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryStorageWriteApiSchemaTransformProvider.java index 98cc246ce0dd..8c4edd2244b4 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryStorageWriteApiSchemaTransformProvider.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryStorageWriteApiSchemaTransformProvider.java @@ -98,6 +98,16 @@ public String identifier() { return String.format("beam:schematransform:org.apache.beam:bigquery_storage_write:v2"); } + @Override + public String description() { + return String.format( + "Writes data to BigQuery using the Storage Write API (https://cloud.google.com/bigquery/docs/write-api)." + + "\n\nThis expects a single PCollection of Beam Rows and outputs two dead-letter queues (DLQ) that " + + "contain failed rows. 
The first DLQ has tag [%s] and contains the failed rows. The second DLQ has " + "tag [%s] and contains the failed rows along with their respective errors.", FAILED_ROWS_TAG, FAILED_ROWS_WITH_ERRORS_TAG); + } + @Override public List inputCollectionNames() { return Collections.singletonList(INPUT_ROWS_TAG); } diff --git a/sdks/python/apache_beam/io/requestresponseio.py b/sdks/python/apache_beam/io/requestresponseio.py index fc742fa00cad..0ec586e64018 100644 --- a/sdks/python/apache_beam/io/requestresponseio.py +++ b/sdks/python/apache_beam/io/requestresponseio.py @@ -17,11 +17,24 @@ """``PTransform`` for reading from and writing to Web APIs.""" import abc +import concurrent.futures +import contextlib +import logging +import sys +from typing import Generic +from typing import Optional from typing import TypeVar +import apache_beam as beam +from apache_beam.pvalue import PCollection + RequestT = TypeVar('RequestT') ResponseT = TypeVar('ResponseT') +DEFAULT_TIMEOUT_SECS = 30 # seconds + +_LOGGER = logging.getLogger(__name__) + class UserCodeExecutionException(Exception): """Base class for errors related to calling Web APIs.""" @@ -37,8 +50,10 @@ class UserCodeTimeoutException(UserCodeExecutionException): """Extends ``UserCodeExecutionException`` to signal a user code timeout.""" -class Caller(metaclass=abc.ABCMeta): - """Interfaces user custom code intended for API calls.""" +class Caller(contextlib.AbstractContextManager, abc.ABC): + """Interface for user custom code intended for API calls. + For setup and teardown of clients when applicable, implement the + ``__enter__`` and ``__exit__`` methods respectively.""" @abc.abstractmethod def __call__(self, request: RequestT, *args, **kwargs) -> ResponseT: """Calls a Web API with the ``RequestT`` and returns a @@ -48,18 +63,156 @@ def __call__(self, request: RequestT, *args, **kwargs) -> ResponseT: """ pass + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + return None + + +class ShouldBackOff(abc.ABC): + """ + ShouldBackOff provides a mechanism to apply adaptive throttling. + """ + pass + + +class Repeater(abc.ABC): + """Repeater provides a mechanism to repeat requests under a + configurable condition.""" + pass -class SetupTeardown(metaclass=abc.ABCMeta): - """Interfaces user custom code to set up and teardown the API clients. - Called by ``RequestResponseIO`` within its DoFn's setup and teardown - methods. + +class CacheReader(abc.ABC): + """CacheReader provides a mechanism to read from the cache.""" + pass + + +class CacheWriter(abc.ABC): + """CacheWriter provides a mechanism to write to the cache.""" + pass + + +class PreCallThrottler(abc.ABC): + """PreCallThrottler provides a throttle mechanism before sending a request.""" + pass + + +class RequestResponseIO(beam.PTransform[beam.PCollection[RequestT], + beam.PCollection[ResponseT]]): + """A :class:`RequestResponseIO` transform to read and write to APIs. + + Processes an input :class:`~apache_beam.pvalue.PCollection` of requests + by making a call to the API as defined in :class:`Caller`'s `__call__` + and returns a :class:`~apache_beam.pvalue.PCollection` of responses. + """
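Taken together with the `Caller` contract above, the transform composes as below. A minimal, hypothetical echo pipeline (mirroring the new `test_request_response_io` integration test; `EchoCaller` is invented for illustration):

```python
import apache_beam as beam
from apache_beam.io.requestresponseio import Caller, RequestResponseIO

class EchoCaller(Caller):
    # Hypothetical caller; a real one would invoke a Web API and may
    # implement __enter__/__exit__ for client setup and teardown.
    def __call__(self, request: str, *args, **kwargs) -> str:
        return f"ACK: {request}"

with beam.Pipeline() as p:
    _ = (
        p
        | beam.Create(["a", "b"])
        | RequestResponseIO(caller=EchoCaller(), timeout=10)
        | beam.Map(print))
```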
+ """ + def __init__( + self, + caller: [Caller], + timeout: Optional[float] = DEFAULT_TIMEOUT_SECS, + should_backoff: Optional[ShouldBackOff] = None, + repeater: Optional[Repeater] = None, + cache_reader: Optional[CacheReader] = None, + cache_writer: Optional[CacheWriter] = None, + throttler: Optional[PreCallThrottler] = None, + ): """ - @abc.abstractmethod - def setup(self) -> None: - """Called during the DoFn's setup lifecycle method.""" - pass + Instantiates a RequestResponseIO transform. - @abc.abstractmethod - def teardown(self) -> None: - """Called during the DoFn's teardown lifecycle method.""" - pass + Args: + caller (~apache_beam.io.requestresponseio.Caller): an implementation of + `Caller` object that makes call to the API. + timeout (float): timeout value in seconds to wait for response from API. + should_backoff (~apache_beam.io.requestresponseio.ShouldBackOff): + (Optional) provides methods for backoff. + repeater (~apache_beam.io.requestresponseio.Repeater): (Optional) + provides methods to repeat requests to API. + cache_reader (~apache_beam.io.requestresponseio.CacheReader): (Optional) + provides methods to read external cache. + cache_writer (~apache_beam.io.requestresponseio.CacheWriter): (Optional) + provides methods to write to external cache. + throttler (~apache_beam.io.requestresponseio.PreCallThrottler): + (Optional) provides methods to pre-throttle a request. + """ + self._caller = caller + self._timeout = timeout + self._should_backoff = should_backoff + self._repeater = repeater + self._cache_reader = cache_reader + self._cache_writer = cache_writer + self._throttler = throttler + + def expand(self, requests: PCollection[RequestT]) -> PCollection[ResponseT]: + # TODO(riteshghorse): add Cache and Throttle PTransforms. + return requests | _Call( + caller=self._caller, + timeout=self._timeout, + should_backoff=self._should_backoff, + repeater=self._repeater) + + +class _Call(beam.PTransform[beam.PCollection[RequestT], + beam.PCollection[ResponseT]]): + """(Internal-only) PTransform that invokes a remote function on each element + of the input PCollection. + + This PTransform uses a `Caller` object to invoke the actual API calls, + and uses ``__enter__`` and ``__exit__`` to manage setup and teardown of + clients when applicable. Additionally, a timeout value is specified to + regulate the duration of each call, defaults to 30 seconds. + + Args: + caller (:class:`apache_beam.io.requestresponseio.Caller`): a callable + object that invokes API call. + timeout (float): timeout value in seconds to wait for response from API. + """ + def __init__( + self, + caller: Caller, + timeout: Optional[float] = DEFAULT_TIMEOUT_SECS, + should_backoff: Optional[ShouldBackOff] = None, + repeater: Optional[Repeater] = None, + ): + """Initialize the _Call transform. + Args: + caller (:class:`apache_beam.io.requestresponseio.Caller`): a callable + object that invokes API call. + timeout (float): timeout value in seconds to wait for response from API. + should_backoff (~apache_beam.io.requestresponseio.ShouldBackOff): + (Optional) provides methods for backoff. + repeater (~apache_beam.io.requestresponseio.Repeater): (Optional) provides + methods to repeat requests to API. 
+ """ + self._caller = caller + self._timeout = timeout + self._should_backoff = should_backoff + self._repeater = repeater + + def expand( + self, + requests: beam.PCollection[RequestT]) -> beam.PCollection[ResponseT]: + return requests | beam.ParDo(_CallDoFn(self._caller, self._timeout)) + + +class _CallDoFn(beam.DoFn, Generic[RequestT, ResponseT]): + def setup(self): + self._caller.__enter__() + + def __init__(self, caller: Caller, timeout: float): + self._caller = caller + self._timeout = timeout + + def process(self, request, *args, **kwargs): + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit(self._caller, request) + try: + yield future.result(timeout=self._timeout) + except concurrent.futures.TimeoutError: + raise UserCodeTimeoutException( + f'Timeout {self._timeout} exceeded ' + f'while completing request: {request}') + except RuntimeError: + raise UserCodeExecutionException('could not complete request') + + def teardown(self): + self._caller.__exit__(*sys.exc_info()) diff --git a/sdks/python/apache_beam/io/requestresponseio_it_test.py b/sdks/python/apache_beam/io/requestresponseio_it_test.py index f291ff96a4d7..aae6b4e6ef2c 100644 --- a/sdks/python/apache_beam/io/requestresponseio_it_test.py +++ b/sdks/python/apache_beam/io/requestresponseio_it_test.py @@ -23,10 +23,13 @@ import urllib3 +import apache_beam as beam from apache_beam.io.requestresponseio import Caller +from apache_beam.io.requestresponseio import RequestResponseIO from apache_beam.io.requestresponseio import UserCodeExecutionException from apache_beam.io.requestresponseio import UserCodeQuotaException from apache_beam.options.pipeline_options import PipelineOptions +from apache_beam.testing.test_pipeline import TestPipeline _HTTP_PATH = '/v1/echo' _PAYLOAD = base64.b64encode(bytes('payload', 'utf-8')) @@ -86,7 +89,6 @@ def __call__(self, request: EchoRequest, *args, **kwargs) -> EchoResponse: ``UserCodeExecutionException``, ``UserCodeTimeoutException``, or a ``UserCodeQuotaException``. """ - try: resp = urllib3.request( "POST", @@ -104,8 +106,8 @@ def __call__(self, request: EchoRequest, *args, **kwargs) -> EchoResponse: if resp.status == 429: # Too Many Requests raise UserCodeQuotaException(resp.reason) - - raise UserCodeExecutionException(resp.reason) + else: + raise UserCodeExecutionException(resp.status, resp.reason, request) except urllib3.exceptions.HTTPError as e: raise UserCodeExecutionException(e) @@ -167,6 +169,16 @@ def test_not_found_should_raise(self): self.assertRaisesRegex( UserCodeExecutionException, "Not Found", lambda: client(req)) + def test_request_response_io(self): + client, options = EchoHTTPCallerTestIT._get_client_and_options() + req = EchoRequest(id=options.never_exceed_quota_id, payload=_PAYLOAD) + with TestPipeline(is_integration_test=True) as test_pipeline: + output = ( + test_pipeline + | 'Create PCollection' >> beam.Create([req]) + | 'RRIO Transform' >> RequestResponseIO(client)) + self.assertIsNotNone(output) + if __name__ == '__main__': unittest.main(argv=sys.argv[:1]) diff --git a/sdks/python/apache_beam/io/requestresponseio_test.py b/sdks/python/apache_beam/io/requestresponseio_test.py new file mode 100644 index 000000000000..2828a3578871 --- /dev/null +++ b/sdks/python/apache_beam/io/requestresponseio_test.py @@ -0,0 +1,88 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import time +import unittest + +import apache_beam as beam +from apache_beam.io.requestresponseio import Caller +from apache_beam.io.requestresponseio import RequestResponseIO +from apache_beam.io.requestresponseio import UserCodeExecutionException +from apache_beam.io.requestresponseio import UserCodeTimeoutException +from apache_beam.testing.test_pipeline import TestPipeline + + +class AckCaller(Caller): + """AckCaller acknowledges the incoming request by returning the + request prefixed with ACK.""" + def __enter__(self): + pass + + def __call__(self, request: str): + return f"ACK: {request}" + + def __exit__(self, exc_type, exc_val, exc_tb): + return None + + +class CallerWithTimeout(AckCaller): + """CallerWithTimeout sleeps for 2 seconds before responding. + Used to test timeout in RequestResponseIO.""" + def __call__(self, request: str, *args, **kwargs): + time.sleep(2) + return f"ACK: {request}" + + +class CallerWithRuntimeError(AckCaller): + """CallerWithRuntimeError raises a `RuntimeError` so that RequestResponseIO + raises a UserCodeExecutionException.""" + def __call__(self, request: str, *args, **kwargs): + if not request: + raise RuntimeError("Exception expected, not an error.") + + +class TestCaller(unittest.TestCase): + def test_valid_call(self): + caller = AckCaller() + with TestPipeline() as test_pipeline: + output = ( + test_pipeline + | beam.Create(["sample_request"]) + | RequestResponseIO(caller=caller)) + + self.assertIsNotNone(output) + + def test_call_timeout(self): + caller = CallerWithTimeout() + with self.assertRaises(UserCodeTimeoutException): + with TestPipeline() as test_pipeline: + _ = ( + test_pipeline + | beam.Create(["timeout_request"]) + | RequestResponseIO(caller=caller, timeout=1)) + + def test_call_runtime_error(self): + caller = CallerWithRuntimeError() + with self.assertRaises(UserCodeExecutionException): + with TestPipeline() as test_pipeline: + _ = ( + test_pipeline + | beam.Create([""]) + | RequestResponseIO(caller=caller)) + + +if __name__ == '__main__': + unittest.main() diff --git a/sdks/python/apache_beam/runners/worker/bundle_processor.py b/sdks/python/apache_beam/runners/worker/bundle_processor.py index b35997c4250f..02a2f6016f71 100644 --- a/sdks/python/apache_beam/runners/worker/bundle_processor.py +++ b/sdks/python/apache_beam/runners/worker/bundle_processor.py @@ -388,7 +388,8 @@ def __init__(self, transform_id, # type: str tag, # type: Optional[str] side_input_data, # type: pvalue.SideInputData - coder # type: WindowedValueCoder + coder, # type: WindowedValueCoder + use_bulk_read = False, # type: bool ): # type: (...) -> None self._state_handler = state_handler @@ -399,6 +400,7 @@ def __init__(self, self._target_window_coder = coder.window_coder # TODO(robertwb): Limit the cache size.
self._cache = {} # type: Dict[BoundedWindow, Any] + self._use_bulk_read = use_bulk_read def __getitem__(self, window): target_window = self._side_input_data.window_mapping_fn(window) @@ -432,42 +434,44 @@ def __getitem__(self, window): key_coder = self._element_coder.key_coder() key_coder_impl = key_coder.get_impl() value_coder = self._element_coder.value_coder() + use_bulk_read = self._use_bulk_read class MultiMap(object): _bulk_read = None _lock = threading.Lock() def __getitem__(self, key): - if self._bulk_read is None: - with self._lock: - if self._bulk_read is None: - try: - # Attempt to bulk read the key-values over the iterable - # protocol which, if supported, can be much more efficient - # than point lookups if it fits into memory. - for ix, (k, vs) in enumerate(_StateBackedIterable( - state_handler, - kv_iter_state_key, - coders.TupleCoder( - (key_coder, coders.IterableCoder(value_coder))))): - cache[k] = vs - if ix > StateBackedSideInputMap._BULK_READ_LIMIT: + if use_bulk_read: + if self._bulk_read is None: + with self._lock: + if self._bulk_read is None: + try: + # Attempt to bulk read the key-values over the iterable + # protocol which, if supported, can be much more efficient + # than point lookups if it fits into memory. + for ix, (k, vs) in enumerate(_StateBackedIterable( + state_handler, + kv_iter_state_key, + coders.TupleCoder( + (key_coder, coders.IterableCoder(value_coder))))): + cache[k] = vs + if ix > StateBackedSideInputMap._BULK_READ_LIMIT: + self._bulk_read = ( + StateBackedSideInputMap._BULK_READ_PARTIALLY) + break + else: + # We reached the end of the iteration without breaking. self._bulk_read = ( - StateBackedSideInputMap._BULK_READ_PARTIALLY) - break - else: - # We reached the end of the iteration without breaking. + StateBackedSideInputMap._BULK_READ_FULLY) + except Exception: + _LOGGER.error( + "Iterable access of map side inputs unsupported.", + exc_info=True) self._bulk_read = ( - StateBackedSideInputMap._BULK_READ_FULLY) - except Exception: - _LOGGER.error( - "Iterable access of map side inputs unsupported.", - exc_info=True) - self._bulk_read = ( - StateBackedSideInputMap._BULK_READ_PARTIALLY) - - if (self._bulk_read == StateBackedSideInputMap._BULK_READ_FULLY): - return cache.get(key, []) + StateBackedSideInputMap._BULK_READ_PARTIALLY) + + if (self._bulk_read == StateBackedSideInputMap._BULK_READ_FULLY): + return cache.get(key, []) if key not in cache: keyed_state_key = beam_fn_api_pb2.StateKey() diff --git a/sdks/python/apache_beam/transforms/external.py b/sdks/python/apache_beam/transforms/external.py index a69ecbaee220..997cea347d33 100644 --- a/sdks/python/apache_beam/transforms/external.py +++ b/sdks/python/apache_beam/transforms/external.py @@ -41,6 +41,7 @@ from apache_beam.portability.api import beam_expansion_api_pb2_grpc from apache_beam.portability.api import beam_runner_api_pb2 from apache_beam.portability.api import external_transforms_pb2 +from apache_beam.portability.api import schema_pb2 from apache_beam.runners import pipeline_context from apache_beam.runners.portability import artifact_service from apache_beam.transforms import environments @@ -51,6 +52,7 @@ from apache_beam.typehints.schemas import named_fields_to_schema from apache_beam.typehints.schemas import named_tuple_from_schema from apache_beam.typehints.schemas import named_tuple_to_schema +from apache_beam.typehints.schemas import typing_from_runner_api from apache_beam.typehints.trivial_inference import instance_to_type from apache_beam.typehints.typehints import Union from 
apache_beam.typehints.typehints import UnionConstraint @@ -450,8 +452,24 @@ def discover_iter(expansion_service, ignore_errors=True): schema = named_tuple_from_schema(proto_config.config_schema) except Exception as exn: if ignore_errors: - logging.info("Bad schema for %s: %s", identifier, str(exn)[:250]) - continue + truncated_schema = schema_pb2.Schema() + truncated_schema.CopyFrom(proto_config.config_schema) + for field in truncated_schema.fields: + try: + typing_from_runner_api(field.type) + except Exception: + if field.type.nullable: + # Set it to an empty placeholder type. + field.type.CopyFrom( + schema_pb2.FieldType( + nullable=True, + row_type=schema_pb2.RowType( + schema=schema_pb2.Schema()))) + try: + schema = named_tuple_from_schema(truncated_schema) + except Exception as exn: + logging.info("Bad schema for %s: %s", identifier, str(exn)[:250]) + continue else: raise diff --git a/sdks/python/apache_beam/yaml/standard_providers.yaml b/sdks/python/apache_beam/yaml/standard_providers.yaml index c612d4412081..0f1bc14c47c4 100644 --- a/sdks/python/apache_beam/yaml/standard_providers.yaml +++ b/sdks/python/apache_beam/yaml/standard_providers.yaml @@ -18,23 +18,18 @@ # TODO(robertwb): Add more providers. # TODO(robertwb): Perhaps auto-generate this file? -- type: 'beamJar' - config: - gradle_target: 'sdks:java:extensions:sql:expansion-service:shadowJar' - version: BEAM_VERSION - transforms: - Sql: 'beam:external:java:sql:v1' - MapToFields-java: "beam:schematransform:org.apache.beam:yaml:map_to_fields-java:v1" - MapToFields-generic: "beam:schematransform:org.apache.beam:yaml:map_to_fields-java:v1" - - type: renaming transforms: + 'Sql': 'Sql' 'MapToFields-java': 'MapToFields-java' 'MapToFields-generic': 'MapToFields-java' 'Filter-java': 'Filter-java' 'Explode': 'Explode' config: mappings: + 'Sql': + query: 'query' + # Unfortunately dialect is a java logical type. 'MapToFields-generic': language: 'language' append: 'append' @@ -57,6 +52,7 @@ underlying_provider: type: beamJar transforms: + Sql: "schematransform:org.apache.beam:sql_transform:v1" MapToFields-java: "beam:schematransform:org.apache.beam:yaml:map_to_fields-java:v1" Filter-java: "beam:schematransform:org.apache.beam:yaml:filter-java:v1" Explode: "beam:schematransform:org.apache.beam:yaml:explode:v1" diff --git a/sdks/python/apache_beam/yaml/yaml_mapping.py b/sdks/python/apache_beam/yaml/yaml_mapping.py index 08c7a59819a7..0ce706bbea58 100644 --- a/sdks/python/apache_beam/yaml/yaml_mapping.py +++ b/sdks/python/apache_beam/yaml/yaml_mapping.py @@ -326,6 +326,28 @@ def expand(pcoll, error_handling=None, **kwargs): class _Explode(beam.PTransform): + """Explodes (aka unnest/flatten) one or more fields, producing multiple rows. + + Given one or more fields of iterable type, produces multiple rows, one for + each value of that field. For example, a row of the form `('a', [1, 2, 3])` + would expand to `('a', 1)`, `('a', 2)`, and `('a', 3)` when exploded on + the second field. + + This is akin to a `FlatMap` when paired with the MapToFields transform. + + Args: + fields: The list of fields to expand. + cross_product: If multiple fields are specified, indicates whether the + full cross-product of combinations should be produced, or if the + first element of the first field corresponds to the first element + of the second field, etc.
For example, the row + `(['a', 'b'], [1, 2])` would expand to the four rows + `('a', 1)`, `('a', 2)`, `('b', 1)`, and `('b', 2)` when + `cross_product` is set to `true` but only the two rows + `('a', 1)` and `('b', 2)` when it is set to `false`. + Only meaningful (and required) if multiple fields are specified. + error_handling: Whether and how to handle errors during iteration. + """ def __init__( self, fields: Union[str, Collection[str]], diff --git a/sdks/python/apache_beam/yaml/yaml_provider.py b/sdks/python/apache_beam/yaml/yaml_provider.py index 6a2d313183e5..f6078769c654 100755 --- a/sdks/python/apache_beam/yaml/yaml_provider.py +++ b/sdks/python/apache_beam/yaml/yaml_provider.py @@ -546,7 +546,7 @@ def dicts_to_rows(o): def create_builtin_provider(): - def create(elements: Iterable[Any], reshuffle: bool = True): + def create(elements: Iterable[Any], reshuffle: Optional[bool] = True): """Creates a collection containing a specified set of elements. YAML/JSON-style mappings will be interpreted as Beam rows. For example:: @@ -560,17 +560,48 @@ def create(elements: Iterable[Any], reshuffle: bool = True): Args: elements: The set of elements that should belong to the PCollection. YAML/JSON-style mappings will be interpreted as Beam rows. - reshuffle (optional): Whether to introduce a reshuffle if there is more - than one element in the collection. Defaults to True. + reshuffle (optional): Whether to introduce a reshuffle (to possibly + redistribute the work) if there is more than one element in the + collection. Defaults to True. """ - return beam.Create([element_to_rows(e) for e in elements], reshuffle) + return beam.Create([element_to_rows(e) for e in elements], + reshuffle=reshuffle is not False) # Or should this be posargs, args? # pylint: disable=dangerous-default-value def fully_qualified_named_transform( constructor: str, - args: Iterable[Any] = (), - kwargs: Mapping[str, Any] = {}): + args: Optional[Iterable[Any]] = (), + kwargs: Optional[Mapping[str, Any]] = {}): + """A Python PTransform identified by fully qualified name. + + This allows one to import, construct, and apply any Beam Python transform. + This can be useful for using transforms that have not yet been exposed + via a YAML interface. Note, however, that conversion may be required if this + transform does not accept or produce Beam Rows. + + For example, + + type: PyTransform + config: + constructor: apache_beam.pkg.mod.SomeClass + args: [1, 'foo'] + kwargs: + baz: 3 + + can be used to access the transform + `apache_beam.pkg.mod.SomeClass(1, 'foo', baz=3)`. + + Args: + constructor: Fully qualified name of a callable used to construct the + transform. Often this is a class such as + `apache_beam.pkg.mod.SomeClass` but it can also be a function or + any other callable that returns a PTransform. + args: A list of parameters to pass to the callable as positional + arguments. + kwargs: A mapping of parameters to pass to the callable as keyword + arguments. + """ with FullyQualifiedNamedTransform.with_filter('*'): return constructor >> FullyQualifiedNamedTransform( constructor, args, kwargs) @@ -579,6 +610,19 @@ def fully_qualified_named_transform( # exactly zero or one PCollection in yaml (as they would be interpreted as # PBegin and the PCollection itself respectively). class Flatten(beam.PTransform): + """Flattens multiple PCollections into a single PCollection. + + The elements of the resulting PCollection will be the (disjoint) union of + all the elements of all the inputs.
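The `cross_product` semantics documented in the `_Explode` docstring above reduce to `itertools.product` versus `zip` over the selected fields. A pure-Python sketch (illustrative, not the Beam implementation):

```python
from itertools import product

def explode(row: dict, fields: list, cross_product: bool):
    iterables = [row[f] for f in fields]
    # Cross product pairs every value with every other value; parallel
    # iteration pairs the i-th values of each field together.
    combos = product(*iterables) if cross_product else zip(*iterables)
    for values in combos:
        yield {**row, **dict(zip(fields, values))}

row = {"key": "x", "a": ["a", "b"], "b": [1, 2]}
list(explode(row, ["a", "b"], cross_product=True))
# -> 4 rows: (a, 1), (a, 2), (b, 1), (b, 2)
list(explode(row, ["a", "b"], cross_product=False))
# -> 2 rows: (a, 1), (b, 2)
```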
+ + Note that in YAML, transforms can always take a list of inputs, which will + be implicitly flattened. + """ + def __init__(self): + # Suppress the "label" argument from the superclass for better docs. + # pylint: disable=useless-parent-delegation + super().__init__() + def expand(self, pcolls): if isinstance(pcolls, beam.PCollection): pipeline_arg = {} @@ -592,6 +636,24 @@ def expand(self, pcolls): return pcolls | beam.Flatten(**pipeline_arg) class WindowInto(beam.PTransform): + # pylint: disable=line-too-long + + """A window transform assigning windows to each element of a PCollection. + + The assigned windows will affect all downstream aggregating operations, + which will aggregate by window as well as by key. + + See [the Beam documentation on windowing](https://beam.apache.org/documentation/programming-guide/#windowing) + for more details. + + Note that any Yaml transform can have a + [windowing parameter](https://github.com/apache/beam/blob/master/sdks/python/apache_beam/yaml/README.md#windowing), + which is applied to its inputs (if any) or outputs (if there are no inputs), + which means that explicit WindowInto operations are not typically needed. + + Args: + windowing: the type and parameters of the windowing to perform + """ def __init__(self, windowing): self._window_transform = self._parse_window_spec(windowing) @@ -617,13 +679,21 @@ def _parse_window_spec(spec): # TODO: Triggering, etc. return beam.WindowInto(window_fn) - def log_and_return(x): - logging.info(x) - return x + def LogForTesting(): + """Logs each element of its input PCollection. + + The output of this transform is a copy of its input for ease of use in + chain-style pipelines. + """ + def log_and_return(x): + logging.info(x) + return x + + return beam.Map(log_and_return) return InlineProvider({ 'Create': create, - 'LogForTesting': lambda: beam.Map(log_and_return), + 'LogForTesting': LogForTesting, 'PyTransform': fully_qualified_named_transform, 'Flatten': Flatten, 'WindowInto': WindowInto, diff --git a/sdks/python/build.gradle b/sdks/python/build.gradle index 7795e77e3963..ab6f75fc653b 100644 --- a/sdks/python/build.gradle +++ b/sdks/python/build.gradle @@ -99,11 +99,12 @@ platform_identifiers_map.each { platform, idsuffix -> environment CIBW_ENVIRONMENT: "SETUPTOOLS_USE_DISTUTILS=stdlib" // note: sync cibuildwheel version with GitHub Action // .github/workflow/build_wheel.yml:build_wheels "Install cibuildwheel" step + // note(https://github.com/pypa/cibuildwheel/issues/1692): cibuildwheel appears to time out occasionally. executable 'sh' args '-c', ". ${envdir}/bin/activate && " + "pip install cibuildwheel==2.9.0 && " + "cibuildwheel --print-build-identifiers --platform ${platform} --archs ${archs} && " + - "cibuildwheel --output-dir ${buildDir} --platform ${platform} --archs ${archs}" + "for i in {1..3}; do cibuildwheel --output-dir ${buildDir} --platform ${platform} --archs ${archs} && break; done" } } } @@ -162,4 +163,4 @@ tasks.register("wordCount") { args '-c', ". 
${envdir}/bin/activate && python -m apache_beam.examples.wordcount --runner DirectRunner --output /tmp/output.txt" } } -} \ No newline at end of file +} diff --git a/sdks/python/scripts/generate_pydoc.sh b/sdks/python/scripts/generate_pydoc.sh index 8d5b43167dd1..82740ae67c9f 100755 --- a/sdks/python/scripts/generate_pydoc.sh +++ b/sdks/python/scripts/generate_pydoc.sh @@ -133,7 +133,7 @@ autodoc_inherit_docstrings = False autodoc_member_order = 'bysource' autodoc_mock_imports = ["tensorrt", "cuda", "torch", "onnxruntime", "onnx", "tensorflow", "tensorflow_hub", - "tensorflow_transform", "tensorflow_metadata", "transformers", "tensorflow_text", + "tensorflow_transform", "tensorflow_metadata", "transformers", "xgboost", "datatable", "transformers", "sentence_transformers", ] diff --git a/sdks/python/tox.ini b/sdks/python/tox.ini index dc9e6a28cb9e..dbe90c084af2 100644 --- a/sdks/python/tox.ini +++ b/sdks/python/tox.ini @@ -148,10 +148,6 @@ deps = sphinx_rtd_theme==0.4.3 docutils<0.18 Jinja2==3.0.3 # TODO(https://github.com/apache/beam/issues/21587): Sphinx version is too old. - torch - xgboost<=1.7.6 - datatable==1.0.0 - transformers commands = time {toxinidir}/scripts/generate_pydoc.sh