diff --git a/sdks/go.mod b/sdks/go.mod index 96446993b3524..be7626b22948b 100644 --- a/sdks/go.mod +++ b/sdks/go.mod @@ -20,13 +20,15 @@ // directory. module github.com/apache/beam/sdks/v2 -go 1.18 +go 1.19 require ( cloud.google.com/go/bigquery v1.45.0 + cloud.google.com/go/bigtable v1.18.1 cloud.google.com/go/datastore v1.10.0 cloud.google.com/go/profiler v0.3.1 cloud.google.com/go/pubsub v1.28.0 + cloud.google.com/go/spanner v1.43.0 cloud.google.com/go/storage v1.29.0 github.com/aws/aws-sdk-go-v2 v1.17.3 github.com/aws/aws-sdk-go-v2/config v1.18.11 @@ -46,27 +48,23 @@ require ( github.com/proullon/ramsql v0.0.0-20211120092837-c8d0a408b939 github.com/spf13/cobra v1.6.1 github.com/testcontainers/testcontainers-go v0.15.0 + github.com/tetratelabs/wazero v1.0.0-pre.9 github.com/xitongsys/parquet-go v1.6.2 github.com/xitongsys/parquet-go-source v0.0.0-20220315005136-aec0fe3e777c go.mongodb.org/mongo-driver v1.11.1 + golang.org/x/exp v0.0.0-20230206171751-46f607a40771 golang.org/x/net v0.5.0 golang.org/x/oauth2 v0.0.0-20221014153046-6fdb5e3db783 golang.org/x/sync v0.1.0 golang.org/x/sys v0.4.0 golang.org/x/text v0.6.0 - google.golang.org/api v0.108.0 + google.golang.org/api v0.109.0 google.golang.org/genproto v0.0.0-20230110181048-76db0878b65f google.golang.org/grpc v1.52.3 google.golang.org/protobuf v1.28.1 gopkg.in/retry.v1 v1.0.3 gopkg.in/yaml.v2 v2.4.0 -) - -require cloud.google.com/go/spanner v1.43.0 - -require ( - cloud.google.com/go/bigtable v1.18.1 - github.com/tetratelabs/wazero v1.0.0-pre.7 + gopkg.in/yaml.v3 v3.0.1 ) require ( @@ -136,9 +134,8 @@ require ( github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d // indirect go.opencensus.io v0.24.0 // indirect golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d // indirect - golang.org/x/tools v0.1.12 // indirect + golang.org/x/tools v0.2.0 // indirect golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2 // indirect google.golang.org/appengine v1.6.7 // indirect gopkg.in/linkedin/goavro.v1 v1.0.5 // indirect - gopkg.in/yaml.v3 v3.0.1 // indirect -) +) \ No newline at end of file diff --git a/sdks/go.sum b/sdks/go.sum index 55bf439e72b58..e40c086d144e5 100644 --- a/sdks/go.sum +++ b/sdks/go.sum @@ -911,6 +911,8 @@ golang.org/x/exp v0.0.0-20191227195350-da58074b4299/go.mod h1:2RIsYlXP63K8oxa1u0 golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4= golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM= golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU= +golang.org/x/exp v0.0.0-20230206171751-46f607a40771 h1:xP7rWLUr1e1n2xkK5YB4LI0hPEy3LJC6Wk+D4pGlOJg= +golang.org/x/exp v0.0.0-20230206171751-46f607a40771/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc= golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= @@ -1137,8 +1139,8 @@ golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roY golang.org/x/tools v0.0.0-20200916195026-c9a70fc28ce3/go.mod h1:z6u4i615ZeAfBE4XtMziQW1fSVJXACjjbWkB/mvPzlU= golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.1.1/go.mod 
h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
-golang.org/x/tools v0.1.12 h1:VveCTK38A2rkS8ZqFY25HIDFscX5X9OoEhJd3quQmXU=
-golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
+golang.org/x/tools v0.2.0 h1:G6AHpWxTMGY1KyEYoAQ5WTtIekUUvDNjan3ugu60JvE=
+golang.org/x/tools v0.2.0/go.mod h1:y4OqIKeOV/fWJetJ8bXPU1sEVniLMIyDAZWeHdV+NTA=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
diff --git a/sdks/go/pkg/beam/runners/prism/README.md b/sdks/go/pkg/beam/runners/prism/README.md
new file mode 100644
index 0000000000000..0db3ec59bdf7e
--- /dev/null
+++ b/sdks/go/pkg/beam/runners/prism/README.md
@@ -0,0 +1,190 @@
+
+
+# Apache Beam Go Prism Runner
+
+Prism is a local portable Apache Beam runner authored in Go.
+
+* Local, for fast startup and ease of testing on a single machine.
+* Portable, in that it uses the Beam FnAPI to communicate with Beam SDKs of any language.
+* Go's simple concurrency enables clear structures for testing batch through streaming jobs.
+
+It's intended to replace the current Go Direct runner, but also to serve
+general single machine use.
+
+For Go SDK users:
+  - Short term: set runner to "prism" to use it, or invoke directly.
+  - Medium term: switch the default from "direct" to "prism".
+  - Long term: alias "direct" to "prism", and delete the legacy Go direct runner.
+
+Prisms allow breaking apart and separating a beam of light into
+its component wavelengths, as well as recombining them.
+
+The Prism Runner leans on this metaphor with the goal of making it
+easier for users and Beam SDK developers alike to test and validate
+aspects of Beam that are presently underrepresented.
+
+## Configurability
+
+Prism is configurable using YAML, which is eagerly validated on startup.
+The configuration contains a set of variants to specify execution behavior,
+either to support specific testing goals, or to emulate different runners.
+
+Beam's implementation contains a number of details that are hidden from
+users, and to date, no runner implements the same set of features. This
+can make SDK or pipeline development difficult, since exactly what is
+being tested will vary with the runner being used.
+
+At the top level the configuration contains "variants", and the variants
+configure the behaviors of different "handlers" in Prism.
+
+Jobs will be able to provide a pipeline option to select which variant to
+use. Multiple jobs on the same prism instance can use different variants.
+Jobs which don't provide a variant will default to testing behavior.
+
+All variants should execute the Beam Model faithfully and correctly,
+and with few exceptions it should not be possible for there to be an
+invalid execution. The machine's the limit.
+
+Not all handler options are expected to be useful for pipeline authors;
+they should remain useful to SDK developers,
+or for more precise issue reproduction.
+
+For more detail on the motivation, see Robert Burke's (@lostluck) Beam Summit 2022 talk:
+https://2022.beamsummit.org/sessions/portable-go-beam-runner/.
+
+Here's a non-exhaustive set of variants.
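+
+Before the highlights below, here's an illustrative sketch of what a
+configuration file with two variants could look like. The schema is still
+evolving; the variant and handler names shown are illustrative only.
+
+```yaml
+flink:
+  combine:
+    lift: false
+dataflow:
+  combine:
+    lift: true
+  sdf:
+    enabled: true
+    batchsize: 5
+```
+
+Each top-level key names a variant; each nested key names a handler, whose
+fields decode into that handler's characteristic configuration.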
+
+### Variant Highlight: "default"
+
+The "default" variant is testing focused, intending to root out issues at development
+time, rather than discovering them on production runners. Notably, this mode should
+never use fusion, executing each Transform individually and independently, one at a time.
+
+This variant should be able to execute arbitrary pipelines, correctly, with clarity and
+precision when an error occurs. Other features supported by the SDK should be enabled by default to
+ensure good coverage, such as caches, or RPC reductions like sending elements in
+ProcessBundleRequest and Response, as they should not affect correctness. Composite
+transforms like Splittable DoFns and Combines should be expanded to ensure coverage.
+
+Additional validations may be added as time goes on.
+
+It does not retry or provide other resilience features, since those may mask errors.
+
+To ensure coverage, there may be sibling variants that use mutually exclusive alternative
+executions.
+
+### Variant Highlight: "fast"
+
+Not Yet Implemented - Illustrative goal.
+
+The "fast" variant is performance focused, intended for local scale execution.
+A pseudo-production execution. Fusion optimizations should be performed.
+Large PCollections should be offloaded to persistent disk. Bundles should be
+dynamically split. Multiple Bundles should be executed simultaneously. And so on.
+
+Pipelines should execute as swiftly as possible within the bounds of correct
+execution.
+
+### Variant Highlight: "flink" "dataflow" "spark" AKA Emulations
+
+Not Yet Implemented - Illustrative goal.
+
+Emulation variants have the goal of replicating, at the local scale,
+the behaviors of other runners. Flink execution never "lifts" Combines, and
+doesn't dynamically split. Dataflow has different characteristics for batch
+and streaming execution, with certain execution characteristics enabled or
+disabled.
+
+As Prism is intended to implement all facets of Beam Model execution, the handlers
+can have features selectively disabled to match the behavior of the runner being
+emulated.
+
+## Current Limitations
+
+* Experimental and testing use only.
+* Executing docker containers isn't yet implemented.
+  * This precludes running the Java and Python SDKs, or their transforms for Cross Language.
+  * Loopback execution only.
+  * No standalone execution.
+* In Memory Only
+  * Not yet suitable for larger jobs, which may have intermediate data that exceeds memory bounds.
+  * Doesn't yet support sufficient intermediate data garbage collection for indefinite stream processing.
+* Doesn't yet execute all beam pipeline features.
+* No UI for job status inspection.
+
+## Implemented so far
+
+* DoFns
+  * Side Inputs
+  * Multiple Outputs
+* Flattens
+* GBKs
+  * Includes handling session windows.
+  * Global Window
+  * Interval Windowing
+  * Session Windows
+* Combines lifted and unlifted.
+* Expands Splittable DoFns
+* Limited support for Process Continuations
+  * Residuals are rescheduled for execution immediately.
+  * The transform must be finite (and eventually return a stop process continuation)
+* Basic Metrics support
+
+## Next feature short list (unordered)
+
+See https://github.com/apache/beam/issues/24789 for current status.
+
+* Resolve watermark advancement for Process Continuations
+* Test Stream
+* Triggers & Complex Windowing Strategy execution.
+* State
+* Timers
+* "PubSub" Transform
+* Support SDK Containers via Testcontainers
+  * Cross Language Transforms
+* FnAPI Optimizations
+  * Fusion
+  * Data with ProcessBundleRequest & Response
+* Progress tracking
+  * Channel Splitting
+  * Dynamic Splitting
+* Standalone execution support
+* UI reporting of in progress jobs
+
+This is not a comprehensive feature set, but a set of goals to best
+support users of the Go SDK in testing their pipelines.
+
+## How to contribute
+
+Until additional structure is necessary, check the main issue
+https://github.com/apache/beam/issues/24789 for the current
+status, file an issue for the feature or bug to fix with `[prism]`
+in the title, and refer to the main issue, before beginning work
+to avoid duplication of effort.
+
+If a feature will take a long time, please send a PR to
+link to your issue from this README to help others discover it.
+
+Otherwise, ordinary [Beam contribution guidelines apply](https://beam.apache.org/contribute/).
+
+# Long Term Goals
+
+Once support for containers is implemented, Prism should become a target
+for the Java Runner Validation tests, which are the current specification
+for correct runner behavior. This will inform further feature development.
\ No newline at end of file
diff --git a/sdks/go/pkg/beam/runners/prism/internal/README.md b/sdks/go/pkg/beam/runners/prism/internal/README.md
new file mode 100644
index 0000000000000..a8771f913f230
--- /dev/null
+++ b/sdks/go/pkg/beam/runners/prism/internal/README.md
@@ -0,0 +1,62 @@
+
+
+# Prism internal packages
+
+Go has a mechanism for ["internal" packages](https://go.dev/doc/go1.4#internalpackages)
+to prevent use of implementation details outside of their intended use.
+
+This mechanism is used thoroughly for Prism to ensure we can make changes to the
+runner's internals without worrying that changes to the exposed surface will break
+non-compliant users.
+
+# Structure
+
+Here's a loose description of the current structure of the runner. Leaf packages should
+not depend on other parts of the runner. Runner packages can and do depend on other
+parts of the SDK, such as for Coder handling.
+
+`config` contains configuration parsing and handling. Leaf package.
+Handler configurations are registered by dependent packages.
+
+`urns` contains beam URN strings pulled from the protos. Leaf package.
+
+`engine` contains the core manager for handling elements, watermarks, and windowing strategies.
+Determines bundle readiness, and stages to execute. Leaf package.
+
+`jobservices` contains GRPC service handlers for job management and submission.
+Should only depend on the `config` and `urns` packages.
+
+`worker` contains interactions with FnAPI services to communicate with worker SDKs. Leaf package
+except for dependency on `engine.TentativeData` which will likely be removed at some point.
+
+`internal` AKA the package in this directory root. Contains the job execution
+flow. Jobs are sent to it from `jobservices`, and those jobs are then executed by coordinating
+with the `engine` and `worker` packages, and the configured handlers.
+Most configurable behavior is determined here.
+
+# Testing
+
+The sub packages should have reasonable unit test coverage in their own directories, but
+most features will be exercised via executing pipelines in this package.
+
+For the time being, test DoFns should be added to the standard build in order to validate execution
+coverage, in particular for Combine and Splittable DoFns.
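+
+For illustration, a minimal coverage-oriented DoFn might look like the
+following sketch. The package, function names, and pipeline shape are
+hypothetical; only `beam.RegisterFunction`, `beam.Create`, and `beam.ParDo`
+are existing SDK APIs.
+
+```go
+package primitives // hypothetical home for coverage DoFns
+
+import "github.com/apache/beam/sdks/v2/go/pkg/beam"
+
+// doubleFn is an illustrative DoFn: simple enough to assert exact
+// outputs, while still exercising a full ParDo stage.
+func doubleFn(v int, emit func(int)) {
+	emit(2 * v)
+}
+
+func init() {
+	// Registration keeps the DoFn usable on portable runners.
+	beam.RegisterFunction(doubleFn)
+}
+
+// buildPipeline wires the DoFn into a tiny pipeline a test can execute.
+func buildPipeline(s beam.Scope) beam.PCollection {
+	return beam.ParDo(s, doubleFn, beam.Create(s, 1, 2, 3))
+}
+```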
+
+Eventually these behaviors should be covered by using Prism in the main SDK tests.
\ No newline at end of file
diff --git a/sdks/go/pkg/beam/runners/prism/internal/coders.go b/sdks/go/pkg/beam/runners/prism/internal/coders.go
new file mode 100644
index 0000000000000..d88ed763f8048
--- /dev/null
+++ b/sdks/go/pkg/beam/runners/prism/internal/coders.go
@@ -0,0 +1,243 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package internal
+
+import (
+	"bytes"
+	"fmt"
+	"io"
+
+	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/coder"
+	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/exec"
+	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/util/ioutilx"
+	pipepb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v1"
+	"github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/urns"
+	"golang.org/x/exp/slog"
+	"google.golang.org/protobuf/encoding/prototext"
+)
+
+// leafCoders lists coder urns the runner knows how to manipulate.
+// In particular, ones that won't be a problem to parse, in general
+// because they have a known total size.
+var leafCoders = map[string]struct{}{
+	urns.CoderBytes:          {},
+	urns.CoderStringUTF8:     {},
+	urns.CoderLengthPrefix:   {},
+	urns.CoderVarInt:         {},
+	urns.CoderDouble:         {},
+	urns.CoderBool:           {},
+	urns.CoderGlobalWindow:   {},
+	urns.CoderIntervalWindow: {},
+}
+
+func isLeafCoder(c *pipepb.Coder) bool {
+	_, ok := leafCoders[c.GetSpec().GetUrn()]
+	return ok
+}
+
+// makeWindowedValueCoder gets the coder for the PCollection, renders it safe, and adds it to the coders map.
+//
+// PCollection coders are not inherently WindowedValueCoder wrapped, and they are added by the runner
+// for crossing the FnAPI boundary at data sources and data sinks.
+func makeWindowedValueCoder(pID string, comps *pipepb.Components, coders map[string]*pipepb.Coder) string {
+	col := comps.GetPcollections()[pID]
+	cID := lpUnknownCoders(col.GetCoderId(), coders, comps.GetCoders())
+	wcID := comps.GetWindowingStrategies()[col.GetWindowingStrategyId()].GetWindowCoderId()
+
+	// The runner needs to be defensive, and tell the SDK to Length Prefix
+	// any coders that it doesn't understand.
+	// So here, we look at the coder and its components, and produce
+	// new coders that we know how to deal with.
+
+	// Produce ID for the Windowed Value Coder
+	wvcID := "cwv_" + pID
+	wInC := &pipepb.Coder{
+		Spec: &pipepb.FunctionSpec{
+			Urn: urns.CoderWindowedValue,
+		},
+		ComponentCoderIds: []string{cID, wcID},
+	}
+	// Populate the coders to send with the new windowed value coder.
+	coders[wvcID] = wInC
+	return wvcID
+}
+
+// makeWindowCoders makes the coder pair, but behavior is ultimately determined by the strategy's windowFn.
+func makeWindowCoders(wc *pipepb.Coder) (exec.WindowDecoder, exec.WindowEncoder) {
+	var cwc *coder.WindowCoder
+	switch wc.GetSpec().GetUrn() {
+	case urns.CoderGlobalWindow:
+		cwc = coder.NewGlobalWindow()
+	case urns.CoderIntervalWindow:
+		cwc = coder.NewIntervalWindow()
+	default:
+		slog.Log(slog.LevelError, "makeWindowCoders: unknown urn", slog.String("urn", wc.GetSpec().GetUrn()))
+		panic(fmt.Sprintf("makeWindowCoders, unknown urn: %v", prototext.Format(wc)))
+	}
+	return exec.MakeWindowDecoder(cwc), exec.MakeWindowEncoder(cwc)
+}
+
+// lpUnknownCoders takes a coder, and populates the bundle coder map with any new
+// coders that the runner needs to be safe, and speedy.
+// It returns either the passed in coder id, or the new safe coder id.
+func lpUnknownCoders(cID string, bundle, base map[string]*pipepb.Coder) string {
+	// First check if we've already added the LP version of this coder to the bundle.
+	lpcID := cID + "_lp"
+	// Check if we've done this one before.
+	if _, ok := bundle[lpcID]; ok {
+		return lpcID
+	}
+	// All coders in the coders map have been processed.
+	if _, ok := bundle[cID]; ok {
+		return cID
+	}
+	// Look up the canonical location.
+	c, ok := base[cID]
+	if !ok {
+		// We messed up somewhere.
+		panic(fmt.Sprint("unknown coder id:", cID))
+	}
+	// Add the original coder to the coders map.
+	bundle[cID] = c
+	// If we don't know this coder, and it has no sub components,
+	// we must LP it, and we return the LP'd version.
+	leaf := isLeafCoder(c)
+	if len(c.GetComponentCoderIds()) == 0 && !leaf {
+		lpc := &pipepb.Coder{
+			Spec: &pipepb.FunctionSpec{
+				Urn: urns.CoderLengthPrefix,
+			},
+			ComponentCoderIds: []string{cID},
+		}
+		bundle[lpcID] = lpc
+		return lpcID
+	}
+	// We know we have a composite, so if we count this as a leaf, move everything to
+	// the coders map.
+	if leaf {
+		// Copy the components from the base.
+		for _, cc := range c.GetComponentCoderIds() {
+			bundle[cc] = base[cc]
+		}
+		return cID
+	}
+	var needNewComposite bool
+	var comps []string
+	for _, cc := range c.GetComponentCoderIds() {
+		rcc := lpUnknownCoders(cc, bundle, base)
+		if cc != rcc {
+			needNewComposite = true
+		}
+		comps = append(comps, rcc)
+	}
+	if needNewComposite {
+		lpc := &pipepb.Coder{
+			Spec:              c.GetSpec(),
+			ComponentCoderIds: comps,
+		}
+		bundle[lpcID] = lpc
+		return lpcID
+	}
+	return cID
+}
+
+// reconcileCoders ensures that the bundle coders are primed with initial coders from
+// the base pipeline components.
+func reconcileCoders(bundle, base map[string]*pipepb.Coder) {
+	for {
+		var comps []string
+		for _, c := range bundle {
+			for _, ccid := range c.GetComponentCoderIds() {
+				if _, ok := bundle[ccid]; !ok {
+					// We don't have the coder yet, so in we go.
+					comps = append(comps, ccid)
+				}
+			}
+		}
+		if len(comps) == 0 {
+			return
+		}
+		for _, ccid := range comps {
+			c, ok := base[ccid]
+			if !ok {
+				panic(fmt.Sprintf("unknown coder id during reconciliation: %v", ccid))
+			}
+			bundle[ccid] = c
+		}
+	}
+}
+
+// pullDecoder returns a function that will extract the bytes
+// for the associated coder. Uses a buffer and a TeeReader to extract the original
+// bytes while decoding elements.
+func pullDecoder(c *pipepb.Coder, coders map[string]*pipepb.Coder) func(io.Reader) []byte {
+	dec := pullDecoderNoAlloc(c, coders)
+	return func(r io.Reader) []byte {
+		var buf bytes.Buffer
+		tr := io.TeeReader(r, &buf)
+		dec(tr)
+		return buf.Bytes()
+	}
+}
+
+// pullDecoderNoAlloc returns a function that decodes a single element of the given coder.
+// Intended to only be used as an internal function for pullDecoder, which will use an io.TeeReader
+// to extract the bytes.
+func pullDecoderNoAlloc(c *pipepb.Coder, coders map[string]*pipepb.Coder) func(io.Reader) {
+	urn := c.GetSpec().GetUrn()
+	switch urn {
+	// Anything length prefixed can be treated as opaque.
+	case urns.CoderBytes, urns.CoderStringUTF8, urns.CoderLengthPrefix:
+		return func(r io.Reader) {
+			l, _ := coder.DecodeVarInt(r)
+			ioutilx.ReadN(r, int(l))
+		}
+	case urns.CoderVarInt:
+		return func(r io.Reader) {
+			coder.DecodeVarInt(r)
+		}
+	case urns.CoderBool:
+		return func(r io.Reader) {
+			coder.DecodeBool(r)
+		}
+	case urns.CoderDouble:
+		return func(r io.Reader) {
+			coder.DecodeDouble(r)
+		}
+	case urns.CoderIterable:
+		ccids := c.GetComponentCoderIds()
+		ed := pullDecoderNoAlloc(coders[ccids[0]], coders)
+		return func(r io.Reader) {
+			l, _ := coder.DecodeInt32(r)
+			for i := int32(0); i < l; i++ {
+				ed(r)
+			}
+		}
+
+	case urns.CoderKV:
+		ccids := c.GetComponentCoderIds()
+		kd := pullDecoderNoAlloc(coders[ccids[0]], coders)
+		vd := pullDecoderNoAlloc(coders[ccids[1]], coders)
+		return func(r io.Reader) {
+			kd(r)
+			vd(r)
+		}
+	case urns.CoderRow:
+		panic(fmt.Sprintf("Runner forgot to LP this Row Coder. %v", prototext.Format(c)))
+	default:
+		panic(fmt.Sprintf("unknown coder urn key: %v", urn))
+	}
+}
diff --git a/sdks/go/pkg/beam/runners/prism/internal/coders_test.go b/sdks/go/pkg/beam/runners/prism/internal/coders_test.go
new file mode 100644
index 0000000000000..ad6e36496286d
--- /dev/null
+++ b/sdks/go/pkg/beam/runners/prism/internal/coders_test.go
@@ -0,0 +1,377 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package internal
+
+import (
+	"bytes"
+	"encoding/binary"
+	"math"
+	"testing"
+
+	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/mtime"
+	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/window"
+	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex"
+	pipepb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v1"
+	"github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/urns"
+	"github.com/google/go-cmp/cmp"
+	"google.golang.org/protobuf/testing/protocmp"
+)
+
+func Test_isLeafCoder(t *testing.T) {
+	tests := []struct {
+		urn    string
+		isLeaf bool
+	}{
+		{urns.CoderBytes, true},
+		{urns.CoderStringUTF8, true},
+		{urns.CoderLengthPrefix, true},
+		{urns.CoderVarInt, true},
+		{urns.CoderDouble, true},
+		{urns.CoderBool, true},
+		{urns.CoderGlobalWindow, true},
+		{urns.CoderIntervalWindow, true},
+		{urns.CoderIterable, false},
+		{urns.CoderRow, false},
+		{urns.CoderKV, false},
+	}
+	for _, test := range tests {
+		undertest := &pipepb.Coder{
+			Spec: &pipepb.FunctionSpec{
+				Urn: test.urn,
+			},
+		}
+		if got, want := isLeafCoder(undertest), test.isLeaf; got != want {
+			t.Errorf("isLeafCoder(%v) = %v, want %v", test.urn, got, want)
+		}
+	}
+}
+
+func Test_makeWindowedValueCoder(t *testing.T) {
+	coders := map[string]*pipepb.Coder{}
+
+	gotID := makeWindowedValueCoder("testPID", &pipepb.Components{
+		Pcollections: map[string]*pipepb.PCollection{
+			"testPID": {CoderId: "testCoderID"},
+		},
+		Coders: map[string]*pipepb.Coder{
+			"testCoderID": {
+				Spec: &pipepb.FunctionSpec{
+					Urn: urns.CoderBool,
+				},
+			},
+		},
+	}, coders)
+
+	if gotID == "" {
+		t.Errorf("makeWindowedValueCoder(...) = %v, want non-empty", gotID)
+	}
+	got := coders[gotID]
+	if got == nil {
+		t.Errorf("makeWindowedValueCoder(...) = ID %v, had nil entry", gotID)
+	}
+	if got.GetSpec().GetUrn() != urns.CoderWindowedValue {
+		t.Errorf("makeWindowedValueCoder(...) = ID %v, got urn %v, want %v", gotID, got.GetSpec().GetUrn(), urns.CoderWindowedValue)
+	}
+}
+
+func Test_makeWindowCoders(t *testing.T) {
+	tests := []struct {
+		urn    string
+		window typex.Window
+	}{
+		{urns.CoderGlobalWindow, window.GlobalWindow{}},
+		{urns.CoderIntervalWindow, window.IntervalWindow{
+			Start: mtime.MinTimestamp,
+			End:   mtime.MaxTimestamp,
+		}},
+	}
+	for _, test := range tests {
+		undertest := &pipepb.Coder{
+			Spec: &pipepb.FunctionSpec{
+				Urn: test.urn,
+			},
+		}
+		dec, enc := makeWindowCoders(undertest)
+
+		// Validate we're getting a round trip coder.
+ var buf bytes.Buffer + if err := enc.EncodeSingle(test.window, &buf); err != nil { + t.Errorf("encoder[%v].EncodeSingle(%v) = %v, want nil", test.urn, test.window, err) + } + got, err := dec.DecodeSingle(&buf) + if err != nil { + t.Errorf("decoder[%v].DecodeSingle(%v) = %v, want nil", test.urn, test.window, err) + } + + if want := test.window; got != want { + t.Errorf("makeWindowCoders(%v) didn't round trip: got %v, want %v", test.urn, got, want) + } + } +} + +func Test_lpUnknownCoders(t *testing.T) { + tests := []struct { + name string + urn string + components []string + bundle, base map[string]*pipepb.Coder + want map[string]*pipepb.Coder + }{ + {"alreadyProcessed", + urns.CoderBool, nil, + map[string]*pipepb.Coder{ + "test": {Spec: &pipepb.FunctionSpec{Urn: urns.CoderBool}}, + }, + map[string]*pipepb.Coder{}, + map[string]*pipepb.Coder{ + "test": {Spec: &pipepb.FunctionSpec{Urn: urns.CoderBool}}, + }, + }, + {"alreadyProcessedLP", + urns.CoderBool, nil, + map[string]*pipepb.Coder{ + "test_lp": {Spec: &pipepb.FunctionSpec{Urn: urns.CoderLengthPrefix}, ComponentCoderIds: []string{"test"}}, + "test": {Spec: &pipepb.FunctionSpec{Urn: urns.CoderBool}}, + }, + map[string]*pipepb.Coder{}, + map[string]*pipepb.Coder{ + "test_lp": {Spec: &pipepb.FunctionSpec{Urn: urns.CoderLengthPrefix}, ComponentCoderIds: []string{"test"}}, + "test": {Spec: &pipepb.FunctionSpec{Urn: urns.CoderBool}}, + }, + }, + {"noNeedForLP", + urns.CoderBool, nil, + map[string]*pipepb.Coder{}, + map[string]*pipepb.Coder{}, + map[string]*pipepb.Coder{ + "test": {Spec: &pipepb.FunctionSpec{Urn: urns.CoderBool}}, + }, + }, + {"needLP", + urns.CoderRow, nil, + map[string]*pipepb.Coder{}, + map[string]*pipepb.Coder{}, + map[string]*pipepb.Coder{ + "test_lp": {Spec: &pipepb.FunctionSpec{Urn: urns.CoderLengthPrefix}, ComponentCoderIds: []string{"test"}}, + "test": {Spec: &pipepb.FunctionSpec{Urn: urns.CoderRow}}, + }, + }, + {"needLP_recurse", + urns.CoderKV, []string{"k", "v"}, + map[string]*pipepb.Coder{}, + map[string]*pipepb.Coder{ + "k": {Spec: &pipepb.FunctionSpec{Urn: urns.CoderRow}}, + "v": {Spec: &pipepb.FunctionSpec{Urn: urns.CoderBool}}, + }, + map[string]*pipepb.Coder{ + "test_lp": {Spec: &pipepb.FunctionSpec{Urn: urns.CoderKV}, ComponentCoderIds: []string{"k_lp", "v"}}, + "test": {Spec: &pipepb.FunctionSpec{Urn: urns.CoderKV}, ComponentCoderIds: []string{"k", "v"}}, + "k_lp": {Spec: &pipepb.FunctionSpec{Urn: urns.CoderLengthPrefix}, ComponentCoderIds: []string{"k"}}, + "k": {Spec: &pipepb.FunctionSpec{Urn: urns.CoderRow}}, + "v": {Spec: &pipepb.FunctionSpec{Urn: urns.CoderBool}}, + }, + }, + {"alreadyLP", urns.CoderLengthPrefix, []string{"k"}, + map[string]*pipepb.Coder{}, + map[string]*pipepb.Coder{ + "k": {Spec: &pipepb.FunctionSpec{Urn: urns.CoderRow}}, + }, + map[string]*pipepb.Coder{ + "test": {Spec: &pipepb.FunctionSpec{Urn: urns.CoderLengthPrefix}, ComponentCoderIds: []string{"k"}}, + "k": {Spec: &pipepb.FunctionSpec{Urn: urns.CoderRow}}, + }, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + // Add the initial coder to base. 
+ test.base["test"] = &pipepb.Coder{ + Spec: &pipepb.FunctionSpec{Urn: test.urn}, + ComponentCoderIds: test.components, + } + + lpUnknownCoders("test", test.bundle, test.base) + + if d := cmp.Diff(test.want, test.bundle, protocmp.Transform()); d != "" { + t.Fatalf("lpUnknownCoders(%v); (-want, +got):\n%v", test.urn, d) + } + }) + } +} + +func Test_reconcileCoders(t *testing.T) { + tests := []struct { + name string + bundle, base map[string]*pipepb.Coder + want map[string]*pipepb.Coder + }{ + {name: "noChanges", + bundle: map[string]*pipepb.Coder{ + "a": {Spec: &pipepb.FunctionSpec{Urn: urns.CoderBool}}, + }, + base: map[string]*pipepb.Coder{ + "a": {Spec: &pipepb.FunctionSpec{Urn: urns.CoderBool}}, + "b": {Spec: &pipepb.FunctionSpec{Urn: urns.CoderBytes}}, + "c": {Spec: &pipepb.FunctionSpec{Urn: urns.CoderStringUTF8}}, + }, + want: map[string]*pipepb.Coder{ + "a": {Spec: &pipepb.FunctionSpec{Urn: urns.CoderBool}}, + }, + }, + {name: "KV", + bundle: map[string]*pipepb.Coder{ + "kv": {Spec: &pipepb.FunctionSpec{Urn: urns.CoderKV}, ComponentCoderIds: []string{"k", "v"}}, + }, + base: map[string]*pipepb.Coder{ + "kv": {Spec: &pipepb.FunctionSpec{Urn: urns.CoderKV}, ComponentCoderIds: []string{"k", "v"}}, + "k": {Spec: &pipepb.FunctionSpec{Urn: urns.CoderBool}}, + "v": {Spec: &pipepb.FunctionSpec{Urn: urns.CoderBool}}, + }, + want: map[string]*pipepb.Coder{ + "kv": {Spec: &pipepb.FunctionSpec{Urn: urns.CoderKV}, ComponentCoderIds: []string{"k", "v"}}, + "k": {Spec: &pipepb.FunctionSpec{Urn: urns.CoderBool}}, + "v": {Spec: &pipepb.FunctionSpec{Urn: urns.CoderBool}}, + }, + }, + {name: "KV-nested", + bundle: map[string]*pipepb.Coder{ + "kv": {Spec: &pipepb.FunctionSpec{Urn: urns.CoderKV}, ComponentCoderIds: []string{"k", "v"}}, + }, + base: map[string]*pipepb.Coder{ + "kv": {Spec: &pipepb.FunctionSpec{Urn: urns.CoderKV}, ComponentCoderIds: []string{"k", "v"}}, + "k": {Spec: &pipepb.FunctionSpec{Urn: urns.CoderKV}, ComponentCoderIds: []string{"a", "b"}}, + "v": {Spec: &pipepb.FunctionSpec{Urn: urns.CoderBool}}, + "a": {Spec: &pipepb.FunctionSpec{Urn: urns.CoderBytes}}, + "b": {Spec: &pipepb.FunctionSpec{Urn: urns.CoderRow}}, + "c": {Spec: &pipepb.FunctionSpec{Urn: urns.CoderStringUTF8}}, + }, + want: map[string]*pipepb.Coder{ + "kv": {Spec: &pipepb.FunctionSpec{Urn: urns.CoderKV}, ComponentCoderIds: []string{"k", "v"}}, + "k": {Spec: &pipepb.FunctionSpec{Urn: urns.CoderKV}, ComponentCoderIds: []string{"a", "b"}}, + "v": {Spec: &pipepb.FunctionSpec{Urn: urns.CoderBool}}, + "a": {Spec: &pipepb.FunctionSpec{Urn: urns.CoderBytes}}, + "b": {Spec: &pipepb.FunctionSpec{Urn: urns.CoderRow}}, + }, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + reconcileCoders(test.bundle, test.base) + + if d := cmp.Diff(test.want, test.bundle, protocmp.Transform()); d != "" { + t.Fatalf("reconcileCoders(...); (-want, +got):\n%v", d) + } + }) + } +} + +func Test_pullDecoder(t *testing.T) { + + doubleBytes := make([]byte, 8) + binary.BigEndian.PutUint64(doubleBytes, math.Float64bits(math.SqrtPi)) + + tests := []struct { + name string + coder *pipepb.Coder + coders map[string]*pipepb.Coder + input []byte + }{ + { + "bytes", + &pipepb.Coder{ + Spec: &pipepb.FunctionSpec{ + Urn: urns.CoderBytes, + }, + }, + map[string]*pipepb.Coder{}, + []byte{3, 1, 2, 3}, + }, { + "varint", + &pipepb.Coder{ + Spec: &pipepb.FunctionSpec{ + Urn: urns.CoderVarInt, + }, + }, + map[string]*pipepb.Coder{}, + []byte{255, 3}, + }, { + "bool", + &pipepb.Coder{ + Spec: &pipepb.FunctionSpec{ + Urn: urns.CoderBool, + }, 
+ }, + map[string]*pipepb.Coder{}, + []byte{1}, + }, { + "double", + &pipepb.Coder{ + Spec: &pipepb.FunctionSpec{ + Urn: urns.CoderDouble, + }, + }, + map[string]*pipepb.Coder{}, + doubleBytes, + }, { + "iterable", + &pipepb.Coder{ + Spec: &pipepb.FunctionSpec{ + Urn: urns.CoderIterable, + }, + ComponentCoderIds: []string{"elm"}, + }, + map[string]*pipepb.Coder{ + "elm": &pipepb.Coder{ + Spec: &pipepb.FunctionSpec{ + Urn: urns.CoderVarInt, + }, + }, + }, + []byte{4, 0, 1, 2, 3}, + }, { + "kv", + &pipepb.Coder{ + Spec: &pipepb.FunctionSpec{ + Urn: urns.CoderKV, + }, + ComponentCoderIds: []string{"key", "value"}, + }, + map[string]*pipepb.Coder{ + "key": &pipepb.Coder{ + Spec: &pipepb.FunctionSpec{ + Urn: urns.CoderVarInt, + }, + }, + "value": &pipepb.Coder{ + Spec: &pipepb.FunctionSpec{ + Urn: urns.CoderBool, + }, + }, + }, + []byte{3, 0}, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + dec := pullDecoder(test.coder, test.coders) + buf := bytes.NewBuffer(test.input) + got := dec(buf) + if !bytes.EqualFold(test.input, got) { + t.Fatalf("pullDecoder(%v)(...) = %v, want %v", test.coder, got, test.input) + } + }) + } +} diff --git a/sdks/go/pkg/beam/runners/prism/internal/config/config.go b/sdks/go/pkg/beam/runners/prism/internal/config/config.go new file mode 100644 index 0000000000000..9c3bdd012bcb0 --- /dev/null +++ b/sdks/go/pkg/beam/runners/prism/internal/config/config.go @@ -0,0 +1,247 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package config defines and handles the parsing and provision of configurations +// for the runner. This package should be refered to, and should not take dependencies +// on other parts of this runner. +// +// 1. A given configuation file has one or more variations configured. +// 2. Each variation has a name, and one or more handlers configured. +// 3. Each handler maps to a specific struct. +// +// : +// : +// +// : +// +// +// : +// : +// +// : +// +// +// Handler has it's own name, and an associated characterisitc type. +package config + +import ( + "bytes" + "fmt" + "reflect" + "sort" + "strings" + + "golang.org/x/exp/maps" + "gopkg.in/yaml.v3" +) + +// configFile is the struct configs are decoded into by YAML. +// This represents the whole configuration file. +type configFile struct { + Version int + HandlerOrder []string + Default string // reserved for laer + Variants map[string]*rawVariant `yaml:",inline"` +} + +// rawVariant holds an individual Variant's handlers, +// and any common fields as decoded by YAML. +type rawVariant struct { + HandlerOrder []string + Handlers map[string]yaml.Node `yaml:",inline"` +} + +// HandlerMetadata is required information about handler configurations. 
+// Handlers have a URN, which is the key by which configurations refer to them,
+// and a characteristic type, which is its own individual configuration.
+//
+// Characteristic types must have useful zero values, representing the
+// default configuration for the handler.
+type HandlerMetadata interface {
+	// ConfigURN represents the urn for the handler.
+	ConfigURN() string
+
+	// ConfigCharacteristic returns the type of the detailed configuration for the handler.
+	// A characteristic type must have a useful zero value that defines the default behavior.
+	ConfigCharacteristic() reflect.Type
+}
+
+type unknownHandlersErr struct {
+	handlersToVariants map[string][]string
+}
+
+func (e *unknownHandlersErr) valid() bool {
+	return e.handlersToVariants != nil
+}
+
+func (e *unknownHandlersErr) add(handler, variant string) {
+	if e.handlersToVariants == nil {
+		e.handlersToVariants = map[string][]string{}
+	}
+	vs := e.handlersToVariants[handler]
+	vs = append(vs, variant)
+	e.handlersToVariants[handler] = vs
+}
+
+func (e *unknownHandlersErr) Error() string {
+	var sb strings.Builder
+	sb.WriteString("yaml config contained unknown handlers")
+	for h, vs := range e.handlersToVariants {
+		sort.Strings(vs)
+		sb.WriteString("\n\t")
+		sb.WriteString(h)
+		sb.WriteString(" present in variants ")
+		sb.WriteString(strings.Join(vs, ","))
+	}
+	return sb.String()
+}
+
+// Variant represents a single complete configuration of all handlers in the registry.
+type Variant struct {
+	parent *HandlerRegistry
+
+	name     string
+	handlers map[string]yaml.Node
+}
+
+// GetCharacteristics returns the characteristics of this handler within this variant.
+//
+// If the variant doesn't configure this handler, the zero value of the handler characteristic
+// type will be returned. If the handler is unknown to the registry this variant came from,
+// a nil will be returned.
+func (v *Variant) GetCharacteristics(handler string) any {
+	if v == nil {
+		return nil
+	}
+	md, ok := v.parent.metadata[handler]
+	if !ok {
+		return nil
+	}
+	rt := md.ConfigCharacteristic()
+
+	// Get a pointer to the concrete value.
+	rtv := reflect.New(rt)
+
+	// Look up the handler urn in the variant.
+	yn := v.handlers[handler]
+	if err := yn.Decode(rtv.Interface()); err != nil {
+		// We prevalidated the config, so this shouldn't happen.
+		panic(fmt.Sprintf("couldn't decode characteristic for variant %v handler %v: %v", v.name, handler, err))
+	}
+
+	// Return the value pointed to by the pointer.
+	return rtv.Elem().Interface()
+}
+
+// HandlerRegistry stores known handlers and their associated metadata needed to parse
+// the YAML configuration.
+type HandlerRegistry struct {
+	variations map[string]*rawVariant
+	metadata   map[string]HandlerMetadata
+
+	// cached names
+	variantIDs, handlerIDs []string
+}
+
+// NewHandlerRegistry creates an initialized HandlerRegistry.
+func NewHandlerRegistry() *HandlerRegistry {
+	return &HandlerRegistry{
+		variations: map[string]*rawVariant{},
+		metadata:   map[string]HandlerMetadata{},
+	}
+}
+
+// RegisterHandlers registers the metadata for handler configurations.
+func (r *HandlerRegistry) RegisterHandlers(mds ...HandlerMetadata) {
+	for _, md := range mds {
+		r.metadata[md.ConfigURN()] = md
+	}
+}
+
+// LoadFromYaml takes in a yaml formatted configuration and eagerly processes it for errors.
+//
+// All handlers are validated against their registered characteristic, and it is an error
+// to have configurations for unknown handlers.
+func (r *HandlerRegistry) LoadFromYaml(in []byte) error {
+	vs := configFile{Variants: r.variations}
+	buf := bytes.NewBuffer(in)
+	d := yaml.NewDecoder(buf)
+	if err := d.Decode(&vs); err != nil {
+		return err
+	}
+
+	err := &unknownHandlersErr{}
+	handlers := map[string]struct{}{}
+	for v, hs := range r.variations {
+		for hk, hyn := range hs.Handlers {
+			handlers[hk] = struct{}{}
+
+			md, ok := r.metadata[hk]
+			if !ok {
+				err.add(hk, v)
+				continue
+			}
+
+			// Validate the handler config so we can give a good error message now.
+			// We re-encode, then decode, so we don't need to re-implement
+			// the existing KnownFields checking. Sadly, this doesn't persist through
+			// yaml.Node fields.
+			hb, err := yaml.Marshal(hyn)
+			if err != nil {
+				panic(fmt.Sprintf("error re-encoding characteristic for variant %v handler %v: %v", v, hk, err))
+			}
+			buf := bytes.NewBuffer(hb)
+			dec := yaml.NewDecoder(buf)
+			dec.KnownFields(true)
+			rt := md.ConfigCharacteristic()
+			rtv := reflect.New(rt)
+			if err := dec.Decode(rtv.Interface()); err != nil {
+				return fmt.Errorf("error decoding characteristic strictly for variant %v handler %v: %v", v, hk, err)
+			}
+
+		}
+	}
+
+	if err.valid() {
+		return err
+	}
+
+	r.variantIDs = maps.Keys(r.variations)
+	sort.Strings(r.variantIDs)
+	r.handlerIDs = maps.Keys(handlers)
+	sort.Strings(r.handlerIDs)
+	return nil
+}
+
+// Variants returns the IDs of all variations loaded into this registry.
+func (r *HandlerRegistry) Variants() []string {
+	return r.variantIDs
+}
+
+// UsedHandlers returns the IDs of all handlers used in variations.
+func (r *HandlerRegistry) UsedHandlers() []string {
+	return r.handlerIDs
+}
+
+// GetVariant returns the Variant with the given name.
+// If none exist, GetVariant returns nil.
+func (r *HandlerRegistry) GetVariant(name string) *Variant {
+	vs, ok := r.variations[name]
+	if !ok {
+		return nil
+	}
+	return &Variant{parent: r, name: name, handlers: vs.Handlers}
+}
diff --git a/sdks/go/pkg/beam/runners/prism/internal/config/config_test.go b/sdks/go/pkg/beam/runners/prism/internal/config/config_test.go
new file mode 100644
index 0000000000000..4c2642e78f99e
--- /dev/null
+++ b/sdks/go/pkg/beam/runners/prism/internal/config/config_test.go
@@ -0,0 +1,221 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +package config + +import ( + "reflect" + "strings" + "testing" + + "github.com/google/go-cmp/cmp" +) + +type generalMetadata struct { + urn string + characteristic reflect.Type +} + +func (m generalMetadata) ConfigURN() string { + return m.urn +} + +func (m generalMetadata) ConfigCharacteristic() reflect.Type { + return m.characteristic +} + +func TestHandlerRegistry(t *testing.T) { + type testCombine struct { + Lift bool + } + combineMetadata := generalMetadata{"combine", reflect.TypeOf(testCombine{})} + type testIterables struct { + StateBackedEnabled bool + StateBackedPageSize int64 + } + iterableMetadata := generalMetadata{"iterable", reflect.TypeOf(testIterables{})} + type testSdf struct { + Enabled bool + BatchSize int64 + } + sdfMetadata := generalMetadata{"sdf", reflect.TypeOf(testSdf{})} + + type spotCheck struct { + v, h string + want any + } + tests := []struct { + name string + handlers []HandlerMetadata + config string + + wantVariants, wantHandlers []string + wantSpots []spotCheck + }{ + { + name: "basics", + handlers: []HandlerMetadata{combineMetadata, iterableMetadata, sdfMetadata}, + config: ` +flink: + combine: + lift: false +dataflow: + combine: + lift: true + sdf: + enabled: true + batchsize: 5 +`, + wantVariants: []string{"dataflow", "flink"}, + wantHandlers: []string{"combine", "sdf"}, + wantSpots: []spotCheck{ + {v: "dataflow", h: "combine", want: testCombine{Lift: true}}, + {v: "flink", h: "combine", want: testCombine{Lift: false}}, + {v: "dataflow", h: "sdf", want: testSdf{Enabled: true, BatchSize: 5}}, + {v: "flink", h: "sdf", want: testSdf{Enabled: false, BatchSize: 0}}, // Unset means 0 value configs. + {v: "unknown", h: "missing", want: nil}, + {v: "dataflow", h: "missing", want: nil}, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + reg := NewHandlerRegistry() + reg.RegisterHandlers(test.handlers...) 
+ + if err := reg.LoadFromYaml([]byte(test.config)); err != nil { + t.Fatalf("error unmarshalling test config: %v", err) + } + + if d := cmp.Diff(test.wantVariants, reg.Variants()); d != "" { + t.Errorf("mismatch in variants (-want, +got):\n%v", d) + } + if d := cmp.Diff(test.wantHandlers, reg.UsedHandlers()); d != "" { + t.Errorf("mismatch in used handlers (-want, +got):\n%v", d) + } + for _, spot := range test.wantSpots { + got := reg.GetVariant(spot.v).GetCharacteristics(spot.h) + if d := cmp.Diff(spot.want, got); d != "" { + t.Errorf("mismatch in spot check for (%v, %v) (-want, +got):\n%v", spot.v, spot.h, d) + } + } + }) + } + + t.Run("trying to read a config with an unregistered handler should fail", func(t *testing.T) { + reg := NewHandlerRegistry() + reg.RegisterHandlers(combineMetadata) + + config := ` +dataflow: + sdf: + enabled: true + batchsize: 5 + combine: + lift: true` + + err := reg.LoadFromYaml([]byte(config)) + if err == nil { + t.Fatal("loaded config, got nil; want error") + } + if !strings.Contains(err.Error(), "sdf") { + t.Fatalf("error should contain \"sdf\", but was: %v", err) + } + }) + + t.Run("duplicate variants", func(t *testing.T) { + reg := NewHandlerRegistry() + reg.RegisterHandlers(combineMetadata) + + config := ` +dataflow: + combine: + lift: true +dataflow: + combine: + lift: false +` + err := reg.LoadFromYaml([]byte(config)) + if err == nil { + t.Fatal("loaded config, got nil; want error") + } + }) + + t.Run("duplicate handlers", func(t *testing.T) { + reg := NewHandlerRegistry() + reg.RegisterHandlers(combineMetadata) + + config := ` +dataflow: + combine: + lift: true + combine: + lift: false +` + err := reg.LoadFromYaml([]byte(config)) + if err == nil { + t.Fatal("loaded config, got nil; want error") + } + }) + + t.Run("invalid handler config:fieldtype", func(t *testing.T) { + reg := NewHandlerRegistry() + reg.RegisterHandlers(combineMetadata) + + config := ` +dataflow: + combine: + lift: d +` + err := reg.LoadFromYaml([]byte(config)) + if err == nil { + t.Fatal("loaded config, got nil; want error") + } + }) + t.Run("invalid handler config:extra field", func(t *testing.T) { + reg := NewHandlerRegistry() + reg.RegisterHandlers(combineMetadata) + + config := ` +dataflow: + combine: + lift: no + lower: foo +` + err := reg.LoadFromYaml([]byte(config)) + if err == nil { + t.Fatal("loaded config, got nil; want error") + } + }) + + t.Run("no variant", func(t *testing.T) { + reg := NewHandlerRegistry() + reg.RegisterHandlers(combineMetadata) + + config := ` +dataflow: + combine: + lift: true +` + err := reg.LoadFromYaml([]byte(config)) + if err != nil { + t.Fatalf("error loading config: %v", err) + } + if got, want := reg.GetVariant("notpresent"), (*Variant)(nil); got != want { + t.Errorf("GetVariant('notpresent') = %v, want %v", got, want) + } + }) +} diff --git a/sdks/go/pkg/beam/runners/prism/internal/engine/data.go b/sdks/go/pkg/beam/runners/prism/internal/engine/data.go new file mode 100644 index 0000000000000..6fc192ac83be6 --- /dev/null +++ b/sdks/go/pkg/beam/runners/prism/internal/engine/data.go @@ -0,0 +1,30 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. 
You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package engine
+
+// TentativeData is where data for in-progress bundles is put
+// until the bundle executes successfully.
+type TentativeData struct {
+	Raw map[string][][]byte
+}
+
+// WriteData adds data to a given global collectionID.
+func (d *TentativeData) WriteData(colID string, data []byte) {
+	if d.Raw == nil {
+		d.Raw = map[string][][]byte{}
+	}
+	d.Raw[colID] = append(d.Raw[colID], data)
+}
diff --git a/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go b/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go
new file mode 100644
index 0000000000000..aeabc81b8123f
--- /dev/null
+++ b/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go
@@ -0,0 +1,675 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package engine handles the operational components of a runner, to
+// track elements, watermarks, timers, triggers, etc.
+package engine
+
+import (
+	"bytes"
+	"container/heap"
+	"context"
+	"fmt"
+	"io"
+	"sync"
+
+	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/mtime"
+	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/window"
+	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/exec"
+	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex"
+	"golang.org/x/exp/slog"
+)
+
+type element struct {
+	window    typex.Window
+	timestamp mtime.Time
+	pane      typex.PaneInfo
+
+	elmBytes []byte
+}
+
+type elements struct {
+	es           []element
+	minTimestamp mtime.Time
+}
+
+type PColInfo struct {
+	GlobalID string
+	WDec     exec.WindowDecoder
+	WEnc     exec.WindowEncoder
+	EDec     func(io.Reader) []byte
+}
+
+// ToData recodes the elements with their appropriate windowed value header.
+func (es elements) ToData(info PColInfo) [][]byte {
+	var ret [][]byte
+	for _, e := range es.es {
+		var buf bytes.Buffer
+		exec.EncodeWindowedValueHeader(info.WEnc, []typex.Window{e.window}, e.timestamp, e.pane, &buf)
+		buf.Write(e.elmBytes)
+		ret = append(ret, buf.Bytes())
+	}
+	return ret
+}
+
+// elementHeap orders elements based on their timestamps
+// so we can always find the minimum timestamp of pending elements.
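+// It implements heap.Interface, so callers maintain the ordering through the
+// container/heap package: heap.Push(&h, e) adds an element, and
+// heap.Pop(&h).(element) removes the one with the smallest timestamp.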
+type elementHeap []element
+
+func (h elementHeap) Len() int           { return len(h) }
+func (h elementHeap) Less(i, j int) bool { return h[i].timestamp < h[j].timestamp }
+func (h elementHeap) Swap(i, j int)      { h[i], h[j] = h[j], h[i] }
+
+func (h *elementHeap) Push(x any) {
+	// Push and Pop use pointer receivers because they modify the slice's length,
+	// not just its contents.
+	*h = append(*h, x.(element))
+}
+
+func (h *elementHeap) Pop() any {
+	old := *h
+	n := len(old)
+	x := old[n-1]
+	*h = old[0 : n-1]
+	return x
+}
+
+type Config struct {
+	// MaxBundleSize caps the number of elements permitted in a bundle.
+	// 0 or less means this is ignored.
+	MaxBundleSize int
+}
+
+// ElementManager handles elements, watermarks, and related errata to determine
+// if a stage is able to be executed. It is the core execution engine of Prism.
+//
+// Essentially, it needs to track the current watermarks for each PCollection
+// and transform/stage. But it's tricky, since the watermarks for the
+// PCollections are always relative to transforms/stages.
+//
+// Key parts:
+//
+//   - The parallel input's PCollection's watermark is relative to committed consumed
+//     elements. That is, the input elements consumed by the transform after a successful
+//     bundle can advance the watermark, based on the minimum of their timestamps.
+//   - An output PCollection's watermark is relative to its producing transform,
+//     which relates to *all of its outputs*.
+//
+// This means that a PCollection's watermark is the minimum of all its consuming transforms.
+//
+// So, the watermark manager needs to track:
+// Pending Elements for each stage, along with their windows and timestamps.
+// Each transform's view of the watermarks for the PCollections.
+//
+// Watermarks are advanced based on consumed input, except if the stage produces residuals.
+type ElementManager struct {
+	config Config
+
+	stages map[string]*stageState // The state for each stage.
+
+	consumers     map[string][]string // Map from pcollectionID to stageIDs that consume them as primary input.
+	sideConsumers map[string][]string // Map from pcollectionID to stageIDs that consume them as side input.
+
+	pcolParents map[string]string // Map from pcollectionID to stageIDs that produce the pcollection.
+
+	refreshCond        sync.Cond   // refreshCond protects the following fields with its lock, and unblocks bundle scheduling.
+	inprogressBundles  set[string] // Active bundleIDs
+	watermarkRefreshes set[string] // Scheduled stageID watermark refreshes
+
+	pendingElements sync.WaitGroup // pendingElements counts all unprocessed elements in a job. Jobs with no pending elements terminate successfully.
+}
+
+func NewElementManager(config Config) *ElementManager {
+	return &ElementManager{
+		config:             config,
+		stages:             map[string]*stageState{},
+		consumers:          map[string][]string{},
+		sideConsumers:      map[string][]string{},
+		pcolParents:        map[string]string{},
+		watermarkRefreshes: set[string]{},
+		inprogressBundles:  set[string]{},
+		refreshCond:        sync.Cond{L: &sync.Mutex{}},
+	}
+}
+
+// AddStage adds a stage to this element manager, connecting its PCollections and
+// nodes to the watermark propagation graph.
+func (em *ElementManager) AddStage(ID string, inputIDs, sides, outputIDs []string) {
+	slog.Debug("AddStage", slog.String("ID", ID), slog.Any("inputs", inputIDs), slog.Any("sides", sides), slog.Any("outputs", outputIDs))
+	ss := makeStageState(ID, inputIDs, sides, outputIDs)
+
+	em.stages[ss.ID] = ss
+	for _, outputID := range ss.outputIDs {
+		em.pcolParents[outputID] = ss.ID
+	}
+	for _, input := range inputIDs {
+		em.consumers[input] = append(em.consumers[input], ss.ID)
+	}
+	for _, side := range ss.sides {
+		em.sideConsumers[side] = append(em.sideConsumers[side], ss.ID)
+	}
+}
+
+// StageAggregates marks the given stage as an aggregation, which
+// means elements will only be processed based on windowing strategies.
+func (em *ElementManager) StageAggregates(ID string) {
+	em.stages[ID].aggregate = true
+}
+
+// Impulse marks and initializes the given stage as an impulse, which
+// is a root transform that starts processing.
+func (em *ElementManager) Impulse(stageID string) {
+	stage := em.stages[stageID]
+	newPending := []element{{
+		window:    window.GlobalWindow{},
+		timestamp: mtime.MinTimestamp,
+		pane:      typex.NoFiringPane(),
+		elmBytes:  []byte{0}, // Represents an encoded 0 length byte slice.
+	}}
+
+	consumers := em.consumers[stage.outputIDs[0]]
+	slog.Debug("Impulse", slog.String("stageID", stageID), slog.Any("outputs", stage.outputIDs), slog.Any("consumers", consumers))
+
+	em.pendingElements.Add(len(consumers))
+	for _, sID := range consumers {
+		consumer := em.stages[sID]
+		consumer.AddPending(newPending)
+	}
+	refreshes := stage.updateWatermarks(mtime.MaxTimestamp, mtime.MaxTimestamp, em)
+	em.addRefreshes(refreshes)
+}
+
+type RunBundle struct {
+	StageID   string
+	BundleID  string
+	Watermark mtime.Time
+}
+
+func (rb RunBundle) LogValue() slog.Value {
+	return slog.GroupValue(
+		slog.String("ID", rb.BundleID),
+		slog.String("stage", rb.StageID),
+		slog.Time("watermark", rb.Watermark.ToTime()))
+}
+
+// Bundles is the core execution loop. It produces a sequence of bundles able to be executed.
+// The returned channel is closed when the context is canceled, or there are no pending elements
+// remaining.
+func (em *ElementManager) Bundles(ctx context.Context, nextBundID func() string) <-chan RunBundle {
+	runStageCh := make(chan RunBundle)
+	ctx, cancelFn := context.WithCancel(ctx)
+	go func() {
+		em.pendingElements.Wait()
+		slog.Info("no more pending elements: terminating pipeline")
+		cancelFn()
+		// Ensure the watermark evaluation goroutine exits.
+		em.refreshCond.Broadcast()
+	}()
+	// Watermark evaluation goroutine.
+	go func() {
+		defer close(runStageCh)
+		for {
+			em.refreshCond.L.Lock()
+			// If there are no watermark refreshes available, we wait until there are.
+			for len(em.watermarkRefreshes) == 0 {
+				// Check to see if we must exit.
+				select {
+				case <-ctx.Done():
+					em.refreshCond.L.Unlock()
+					return
+				default:
+				}
+				em.refreshCond.Wait() // until watermarks may have changed.
+			}
+
+			// We know there is some work we can do that may advance the watermarks,
+			// refresh them, and see which stages have advanced.
+			advanced := em.refreshWatermarks()
+
+			// Check each advanced stage, to see if it's able to execute based on the watermark.
+			for stageID := range advanced {
+				ss := em.stages[stageID]
+				watermark, ready := ss.bundleReady(em)
+				if ready {
+					bundleID, ok := ss.startBundle(watermark, nextBundID)
+					if !ok {
+						continue
+					}
+					rb := RunBundle{StageID: stageID, BundleID: bundleID, Watermark: watermark}
+
+					em.inprogressBundles.insert(rb.BundleID)
+					em.refreshCond.L.Unlock()
+
+					select {
+					case <-ctx.Done():
+						return
+					case runStageCh <- rb:
+					}
+					em.refreshCond.L.Lock()
+				}
+			}
+			em.refreshCond.L.Unlock()
+		}
+	}()
+	return runStageCh
+}
+
+// InputForBundle returns pre-allocated data for the given bundle, encoding the elements using
+// the PCollection's coders.
+func (em *ElementManager) InputForBundle(rb RunBundle, info PColInfo) [][]byte {
+	ss := em.stages[rb.StageID]
+	ss.mu.Lock()
+	defer ss.mu.Unlock()
+	es := ss.inprogress[rb.BundleID]
+	return es.ToData(info)
+}
+
+// PersistBundle uses the tentative bundle output to update the watermarks for the stage.
+// Each stage has two monotonically increasing watermarks: the input watermark and the
+// output watermark.
+//
+// MAX(CurrentInputWatermark, MIN(PendingElements, InputPCollectionWatermarks))
+// MAX(CurrentOutputWatermark, MIN(InputWatermark, WatermarkHolds))
+//
+// PersistBundle takes in the stage ID, the ID of the bundle associated with the pending
+// input elements, and the committed output elements.
+func (em *ElementManager) PersistBundle(rb RunBundle, col2Coders map[string]PColInfo, d TentativeData, inputInfo PColInfo, residuals [][]byte, estimatedOWM map[string]mtime.Time) {
+	stage := em.stages[rb.StageID]
+	for output, data := range d.Raw {
+		info := col2Coders[output]
+		var newPending []element
+		slog.Debug("PersistBundle: processing output", "bundle", rb, slog.String("output", output))
+		for _, datum := range data {
+			buf := bytes.NewBuffer(datum)
+			if len(datum) == 0 {
+				panic(fmt.Sprintf("zero length data for %v: ", output))
+			}
+			for {
+				var rawBytes bytes.Buffer
+				tee := io.TeeReader(buf, &rawBytes)
+				ws, et, pn, err := exec.DecodeWindowedValueHeader(info.WDec, tee)
+				if err != nil {
+					if err == io.EOF {
+						break
+					}
+					slog.Error("PersistBundle: error decoding windowed value header", err, "bundle", rb, slog.String("output", output))
+					panic("error decoding windowed value header")
+				}
+				// TODO: Optimize unnecessary copies. This is double-teeing.
+				elmBytes := info.EDec(tee)
+				for _, w := range ws {
+					newPending = append(newPending,
+						element{
+							window:    w,
+							timestamp: et,
+							pane:      pn,
+							elmBytes:  elmBytes,
+						})
+				}
+			}
+		}
+		consumers := em.consumers[output]
+		slog.Debug("PersistBundle: bundle has downstream consumers.", "bundle", rb, slog.Int("newPending", len(newPending)), "consumers", consumers)
+		for _, sID := range consumers {
+			em.pendingElements.Add(len(newPending))
+			consumer := em.stages[sID]
+			consumer.AddPending(newPending)
+		}
+	}
+
+	// Return unprocessed elements to this stage's pending set.
+	var unprocessedElements []element
+	for _, residual := range residuals {
+		buf := bytes.NewBuffer(residual)
+		ws, et, pn, err := exec.DecodeWindowedValueHeader(inputInfo.WDec, buf)
+		if err != nil {
+			if err == io.EOF {
+				break
+			}
+			slog.Error("PersistBundle: error decoding residual header", err, "bundle", rb)
+			panic("error decoding residual header")
+		}
+
+		for _, w := range ws {
+			unprocessedElements = append(unprocessedElements,
+				element{
+					window:    w,
+					timestamp: et,
+					pane:      pn,
+					elmBytes:  buf.Bytes(),
+				})
+		}
+	}
+	// Add unprocessed back to the pending stack.
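+	// Residual elements count as pending again, holding this stage's
+	// input watermark back until they are reprocessed.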
+	if len(unprocessedElements) > 0 {
+		em.pendingElements.Add(len(unprocessedElements))
+		stage.AddPending(unprocessedElements)
+	}
+	// Clear out the inprogress elements associated with the completed bundle.
+	// Must be done after adding the new pending elements to avoid an incorrect
+	// watermark advancement.
+	stage.mu.Lock()
+	completed := stage.inprogress[rb.BundleID]
+	em.pendingElements.Add(-len(completed.es))
+	delete(stage.inprogress, rb.BundleID)
+	// If there are estimated output watermarks, set the stage's estimated
+	// output watermark to their minimum.
+	if len(estimatedOWM) > 0 {
+		estimate := mtime.MaxTimestamp
+		for _, t := range estimatedOWM {
+			estimate = mtime.Min(estimate, t)
+		}
+		stage.estimatedOutput = estimate
+	}
+	stage.mu.Unlock()
+
+	// TODO support state/timer watermark holds.
+	em.addRefreshAndClearBundle(stage.ID, rb.BundleID)
+}
+
+func (em *ElementManager) addRefreshes(stages set[string]) {
+	em.refreshCond.L.Lock()
+	defer em.refreshCond.L.Unlock()
+	em.watermarkRefreshes.merge(stages)
+	em.refreshCond.Broadcast()
+}
+
+func (em *ElementManager) addRefreshAndClearBundle(stageID, bundID string) {
+	em.refreshCond.L.Lock()
+	defer em.refreshCond.L.Unlock()
+	delete(em.inprogressBundles, bundID)
+	em.watermarkRefreshes.insert(stageID)
+	em.refreshCond.Broadcast()
+}
+
+// refreshWatermarks incrementally refreshes the watermarks, and returns the set of stages
+// where the watermark may have advanced.
+// Must be called while holding em.refreshCond.L.
+func (em *ElementManager) refreshWatermarks() set[string] {
+	// Need to have at least one refresh signal.
+	nextUpdates := set[string]{}
+	refreshed := set[string]{}
+	var i int
+	for stageID := range em.watermarkRefreshes {
+		// Clear out the old signal.
+		em.watermarkRefreshes.remove(stageID)
+		ss := em.stages[stageID]
+		refreshed.insert(stageID)
+
+		dummyStateHold := mtime.MaxTimestamp
+
+		refreshes := ss.updateWatermarks(ss.minPendingTimestamp(), dummyStateHold, em)
+		nextUpdates.merge(refreshes)
+		// Cap the number of refreshes processed per pass.
+		if i < 10 {
+			i++
+		} else {
+			break
+		}
+	}
+	em.watermarkRefreshes.merge(nextUpdates)
+	return refreshed
+}
+
+type set[K comparable] map[K]struct{}
+
+func (s set[K]) remove(k K) {
+	delete(s, k)
+}
+
+func (s set[K]) insert(k K) {
+	s[k] = struct{}{}
+}
+
+func (s set[K]) merge(o set[K]) {
+	for k := range o {
+		s.insert(k)
+	}
+}
+
+// stageState is the internal watermark and input tracking for a stage.
+type stageState struct {
+	ID        string
+	inputID   string   // PCollection ID of the parallel input
+	outputIDs []string // PCollection IDs of outputs to update consumers.
+	sides     []string // PCollection IDs of side inputs that can block execution.
+
+	// Special handling bits
+	aggregate bool     // whether this state needs to block for aggregation.
+	strat     winStrat // Windowing Strategy for aggregation firings.
+
+	mu                 sync.Mutex
+	upstreamWatermarks sync.Map   // watermark set from inputPCollection's parent.
+	input              mtime.Time // input watermark for the parallel input.
+	output             mtime.Time // Output watermark for the whole stage
+	estimatedOutput    mtime.Time // Estimated watermark output from DoFns
+
+	pending    elementHeap         // pending input elements for this stage that are to be processed
+	inprogress map[string]elements // inprogress elements by active bundles, keyed by bundle
+}
+
+// makeStageState produces an initialized stageState.
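+// For example, a hypothetical stage with one parallel input and one output:
+//
+//	ss := makeStageState("myStage", []string{"inputPCol"}, nil, []string{"outputPCol"})
+//
+// All watermarks start at mtime.MinTimestamp.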
+func makeStageState(ID string, inputIDs, sides, outputIDs []string) *stageState {
+	ss := &stageState{
+		ID:        ID,
+		outputIDs: outputIDs,
+		sides:     sides,
+		strat:     defaultStrat{},
+
+		input:           mtime.MinTimestamp,
+		output:          mtime.MinTimestamp,
+		estimatedOutput: mtime.MinTimestamp,
+	}
+
+	// Initialize the upstream watermarks to minTime.
+	for _, pcol := range inputIDs {
+		ss.upstreamWatermarks.Store(pcol, mtime.MinTimestamp)
+	}
+	if len(inputIDs) == 1 {
+		ss.inputID = inputIDs[0]
+	}
+	return ss
+}
+
+// AddPending adds elements to the pending heap.
+func (ss *stageState) AddPending(newPending []element) {
+	ss.mu.Lock()
+	defer ss.mu.Unlock()
+	ss.pending = append(ss.pending, newPending...)
+	heap.Init(&ss.pending)
+}
+
+// updateUpstreamWatermark is for the parent of the input pcollection
+// to call, to update downstream stages with its current watermark.
+// This avoids downstream stages inverting lock orderings from
+// calling their parent stage to get their input pcollection's watermark.
+func (ss *stageState) updateUpstreamWatermark(pcol string, upstream mtime.Time) {
+	// Each input PCollection has a single producing stage, so
+	// we can simply overwrite the stored value.
+	ss.upstreamWatermarks.Store(pcol, upstream)
+}
+
+// UpstreamWatermark gets the minimum value of all upstream watermarks.
+func (ss *stageState) UpstreamWatermark() (string, mtime.Time) {
+	upstream := mtime.MaxTimestamp
+	var name string
+	ss.upstreamWatermarks.Range(func(key, val any) bool {
+		// Use <= so a name is returned whenever any watermark is available.
+		if val.(mtime.Time) <= upstream {
+			upstream = val.(mtime.Time)
+			name = key.(string)
+		}
+		return true
+	})
+	return name, upstream
+}
+
+// InputWatermark gets the current input watermark for the stage.
+func (ss *stageState) InputWatermark() mtime.Time {
+	ss.mu.Lock()
+	defer ss.mu.Unlock()
+	return ss.input
+}
+
+// OutputWatermark gets the current output watermark for the stage.
+func (ss *stageState) OutputWatermark() mtime.Time {
+	ss.mu.Lock()
+	defer ss.mu.Unlock()
+	return ss.output
+}
+
+// startBundle initializes a bundle with elements if possible.
+// A bundle only starts if there are elements at all, and, for an
+// aggregation stage, only if the windowing strategy allows it.
+func (ss *stageState) startBundle(watermark mtime.Time, genBundID func() string) (string, bool) {
+	defer func() {
+		if e := recover(); e != nil {
+			panic(fmt.Sprintf("generating bundle for stage %v at %v panicked\n%v", ss.ID, watermark, e))
+		}
+	}()
+	ss.mu.Lock()
+	defer ss.mu.Unlock()
+
+	var toProcess, notYet []element
+	for _, e := range ss.pending {
+		if !ss.aggregate || (ss.aggregate && ss.strat.EarliestCompletion(e.window) <= watermark) {
+			toProcess = append(toProcess, e)
+		} else {
+			notYet = append(notYet, e)
+		}
+	}
+	ss.pending = notYet
+	heap.Init(&ss.pending)
+
+	if len(toProcess) == 0 {
+		return "", false
+	}
+	// Is THIS where basic splits / per-element processing should happen?
+	es := elements{
+		es:           toProcess,
+		minTimestamp: toProcess[0].timestamp,
+	}
+	if ss.inprogress == nil {
+		ss.inprogress = make(map[string]elements)
+	}
+	bundID := genBundID()
+	ss.inprogress[bundID] = es
+	return bundID, true
+}
+
+// minPendingTimestamp returns the minimum pending timestamp from all pending elements,
+// including in-progress ones.
+//
+// Assumes that the pending heap is initialized if it's not empty.
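+// The heap property keeps the minimum pending element at index 0, so only
+// the in-progress bundles need to be scanned.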
+func (ss *stageState) minPendingTimestamp() mtime.Time {
+	ss.mu.Lock()
+	defer ss.mu.Unlock()
+	minPending := mtime.MaxTimestamp
+	if len(ss.pending) != 0 {
+		minPending = ss.pending[0].timestamp
+	}
+	for _, es := range ss.inprogress {
+		minPending = mtime.Min(minPending, es.minTimestamp)
+	}
+	return minPending
+}
+
+func (ss *stageState) String() string {
+	pcol, up := ss.UpstreamWatermark()
+	return fmt.Sprintf("[%v] IN: %v OUT: %v UP: %q %v, aggregation: %v", ss.ID, ss.input, ss.output, pcol, up, ss.aggregate)
+}
+
+// updateWatermarks performs the following operations:
+//
+// Watermark_In' = MAX(Watermark_In, MIN(U(TS_Pending), U(Watermark_InputPCollection)))
+// Watermark_Out' = MAX(Watermark_Out, MIN(Watermark_In', U(StateHold)))
+// Watermark_PCollection = Watermark_Out_ProducingPTransform
+func (ss *stageState) updateWatermarks(minPending, minStateHold mtime.Time, em *ElementManager) set[string] {
+	ss.mu.Lock()
+	defer ss.mu.Unlock()
+
+	// PCollection watermarks are based on their parent's output watermark.
+	_, newIn := ss.UpstreamWatermark()
+
+	// Set the input watermark based on the minimum pending elements,
+	// and the current input pcollection watermark.
+	if minPending < newIn {
+		newIn = minPending
+	}
+
+	// If bigger, advance the input watermark.
+	if newIn > ss.input {
+		ss.input = newIn
+	}
+	// The output starts with the new input as the basis.
+	newOut := ss.input
+
+	// If we're given an estimate, and it's further ahead, we use that instead.
+	if ss.estimatedOutput > ss.output {
+		newOut = ss.estimatedOutput
+	}
+
+	// We adjust based on the minimum state hold.
+	if minStateHold < newOut {
+		newOut = minStateHold
+	}
+	refreshes := set[string]{}
+	// If bigger, advance the output watermark.
+	if newOut > ss.output {
+		ss.output = newOut
+		for _, outputCol := range ss.outputIDs {
+			consumers := em.consumers[outputCol]
+
+			for _, sID := range consumers {
+				em.stages[sID].updateUpstreamWatermark(outputCol, ss.output)
+				refreshes.insert(sID)
+			}
+			// Inform side input consumers, but don't update the upstream watermark.
+			for _, sID := range em.sideConsumers[outputCol] {
+				refreshes.insert(sID)
+			}
+		}
+	}
+	return refreshes
+}
+
+// bundleReady returns the maximum allowed watermark for this stage, and whether
+// it's permitted to execute by side inputs.
+func (ss *stageState) bundleReady(em *ElementManager) (mtime.Time, bool) {
+	ss.mu.Lock()
+	defer ss.mu.Unlock()
+	// If the upstream watermark and the input watermark are the same,
+	// then we can't yet process this stage.
+	inputW := ss.input
+	_, upstreamW := ss.UpstreamWatermark()
+	if inputW == upstreamW {
+		slog.Debug("bundleReady: insufficient upstream watermark",
+			slog.String("stage", ss.ID),
+			slog.Group("watermark",
+				slog.Any("upstream", upstreamW),
+				slog.Any("input", inputW)))
+		return mtime.MinTimestamp, false
+	}
+	ready := true
+	for _, side := range ss.sides {
+		pID := em.pcolParents[side]
+		parent := em.stages[pID]
+		ow := parent.OutputWatermark()
+		if upstreamW > ow {
+			ready = false
+		}
+	}
+	return upstreamW, ready
+}
diff --git a/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager_test.go b/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager_test.go
new file mode 100644
index 0000000000000..ddfdd5b8816c7
--- /dev/null
+++ b/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager_test.go
@@ -0,0 +1,516 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements.
See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package engine + +import ( + "container/heap" + "context" + "fmt" + "io" + "testing" + + "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/coder" + "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/mtime" + "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/window" + "github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/exec" + "github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex" + "github.com/google/go-cmp/cmp" +) + +func TestElementHeap(t *testing.T) { + elements := elementHeap{ + element{timestamp: mtime.EndOfGlobalWindowTime}, + element{timestamp: mtime.MaxTimestamp}, + element{timestamp: 3}, + element{timestamp: mtime.MinTimestamp}, + element{timestamp: 2}, + element{timestamp: mtime.ZeroTimestamp}, + element{timestamp: 1}, + } + heap.Init(&elements) + heap.Push(&elements, element{timestamp: 4}) + + if got, want := elements.Len(), len(elements); got != want { + t.Errorf("elements.Len() = %v, want %v", got, want) + } + if got, want := elements[0].timestamp, mtime.MinTimestamp; got != want { + t.Errorf("elements[0].timestamp = %v, want %v", got, want) + } + + wanted := []mtime.Time{mtime.MinTimestamp, mtime.ZeroTimestamp, 1, 2, 3, 4, mtime.EndOfGlobalWindowTime, mtime.MaxTimestamp} + for i, want := range wanted { + if got := heap.Pop(&elements).(element).timestamp; got != want { + t.Errorf("[%d] heap.Pop(&elements).(element).timestamp = %v, want %v", i, got, want) + } + } +} + +func TestStageState_minPendingTimestamp(t *testing.T) { + + newState := func() *stageState { + return makeStageState("test", []string{"testInput"}, nil, []string{"testOutput"}) + } + t.Run("noElements", func(t *testing.T) { + ss := newState() + got := ss.minPendingTimestamp() + want := mtime.MaxTimestamp + if got != want { + t.Errorf("ss.minPendingTimestamp() = %v, want %v", got, want) + } + }) + + want := mtime.ZeroTimestamp - 20 + t.Run("onlyPending", func(t *testing.T) { + ss := newState() + ss.pending = elementHeap{ + element{timestamp: mtime.EndOfGlobalWindowTime}, + element{timestamp: mtime.MaxTimestamp}, + element{timestamp: 3}, + element{timestamp: want}, + element{timestamp: 2}, + element{timestamp: mtime.ZeroTimestamp}, + element{timestamp: 1}, + } + heap.Init(&ss.pending) + + got := ss.minPendingTimestamp() + if got != want { + t.Errorf("ss.minPendingTimestamp() = %v, want %v", got, want) + } + }) + + t.Run("onlyInProgress", func(t *testing.T) { + ss := newState() + ss.inprogress = map[string]elements{ + "a": { + es: []element{ + {timestamp: mtime.EndOfGlobalWindowTime}, + {timestamp: mtime.MaxTimestamp}, + }, + minTimestamp: mtime.EndOfGlobalWindowTime, + }, + "b": { + es: []element{ + {timestamp: 3}, + {timestamp: want}, + {timestamp: 2}, + {timestamp: 1}, + }, + minTimestamp: want, + }, + "c": { + es: []element{ + {timestamp: mtime.ZeroTimestamp}, + }, + minTimestamp: mtime.ZeroTimestamp, + }, 
+ } + + got := ss.minPendingTimestamp() + if got != want { + t.Errorf("ss.minPendingTimestamp() = %v, want %v", got, want) + } + }) + + t.Run("minInPending", func(t *testing.T) { + ss := newState() + ss.pending = elementHeap{ + {timestamp: 3}, + {timestamp: want}, + {timestamp: 2}, + {timestamp: 1}, + } + heap.Init(&ss.pending) + ss.inprogress = map[string]elements{ + "a": { + es: []element{ + {timestamp: mtime.EndOfGlobalWindowTime}, + {timestamp: mtime.MaxTimestamp}, + }, + minTimestamp: mtime.EndOfGlobalWindowTime, + }, + "c": { + es: []element{ + {timestamp: mtime.ZeroTimestamp}, + }, + minTimestamp: mtime.ZeroTimestamp, + }, + } + + got := ss.minPendingTimestamp() + if got != want { + t.Errorf("ss.minPendingTimestamp() = %v, want %v", got, want) + } + }) + t.Run("minInProgress", func(t *testing.T) { + ss := newState() + ss.pending = elementHeap{ + {timestamp: 3}, + {timestamp: 2}, + {timestamp: 1}, + } + heap.Init(&ss.pending) + ss.inprogress = map[string]elements{ + "a": { + es: []element{ + {timestamp: want}, + {timestamp: mtime.EndOfGlobalWindowTime}, + {timestamp: mtime.MaxTimestamp}, + }, + minTimestamp: want, + }, + "c": { + es: []element{ + {timestamp: mtime.ZeroTimestamp}, + }, + minTimestamp: mtime.ZeroTimestamp, + }, + } + + got := ss.minPendingTimestamp() + if got != want { + t.Errorf("ss.minPendingTimestamp() = %v, want %v", got, want) + } + }) +} + +func TestStageState_UpstreamWatermark(t *testing.T) { + impulse := makeStageState("impulse", nil, nil, []string{"output"}) + _, up := impulse.UpstreamWatermark() + if got, want := up, mtime.MaxTimestamp; got != want { + t.Errorf("impulse.UpstreamWatermark() = %v, want %v", got, want) + } + + dofn := makeStageState("dofn", []string{"input"}, nil, []string{"output"}) + dofn.updateUpstreamWatermark("input", 42) + + _, up = dofn.UpstreamWatermark() + if got, want := up, mtime.Time(42); got != want { + t.Errorf("dofn.UpstreamWatermark() = %v, want %v", got, want) + } + + flatten := makeStageState("flatten", []string{"a", "b", "c"}, nil, []string{"output"}) + flatten.updateUpstreamWatermark("a", 50) + flatten.updateUpstreamWatermark("b", 42) + flatten.updateUpstreamWatermark("c", 101) + _, up = flatten.UpstreamWatermark() + if got, want := up, mtime.Time(42); got != want { + t.Errorf("flatten.UpstreamWatermark() = %v, want %v", got, want) + } +} + +func TestStageState_updateWatermarks(t *testing.T) { + inputCol := "testInput" + outputCol := "testOutput" + newState := func() (*stageState, *stageState, *ElementManager) { + underTest := makeStageState("underTest", []string{inputCol}, nil, []string{outputCol}) + outStage := makeStageState("outStage", []string{outputCol}, nil, nil) + em := &ElementManager{ + consumers: map[string][]string{ + inputCol: {underTest.ID}, + outputCol: {outStage.ID}, + }, + stages: map[string]*stageState{ + outStage.ID: outStage, + underTest.ID: underTest, + }, + } + return underTest, outStage, em + } + + tests := []struct { + name string + initInput, initOutput mtime.Time + upstream, minPending, minStateHold mtime.Time + wantInput, wantOutput, wantDownstream mtime.Time + }{ + { + name: "initialized", + initInput: mtime.MinTimestamp, + initOutput: mtime.MinTimestamp, + upstream: mtime.MinTimestamp, + minPending: mtime.EndOfGlobalWindowTime, + minStateHold: mtime.EndOfGlobalWindowTime, + wantInput: mtime.MinTimestamp, // match default + wantOutput: mtime.MinTimestamp, // match upstream + wantDownstream: mtime.MinTimestamp, // match upstream + }, { + name: "upstream", + initInput: mtime.MinTimestamp, + 
initOutput: mtime.MinTimestamp, + upstream: mtime.ZeroTimestamp, + minPending: mtime.EndOfGlobalWindowTime, + minStateHold: mtime.EndOfGlobalWindowTime, + wantInput: mtime.ZeroTimestamp, // match upstream + wantOutput: mtime.ZeroTimestamp, // match upstream + wantDownstream: mtime.ZeroTimestamp, // match upstream + }, { + name: "useMinPending", + initInput: mtime.MinTimestamp, + initOutput: mtime.MinTimestamp, + upstream: mtime.ZeroTimestamp, + minPending: -20, + minStateHold: mtime.EndOfGlobalWindowTime, + wantInput: -20, // match minPending + wantOutput: -20, // match minPending + wantDownstream: -20, // match minPending + }, { + name: "useStateHold", + initInput: mtime.MinTimestamp, + initOutput: mtime.MinTimestamp, + upstream: mtime.ZeroTimestamp, + minPending: -20, + minStateHold: -30, + wantInput: -20, // match minPending + wantOutput: -30, // match state hold + wantDownstream: -30, // match state hold + }, { + name: "noAdvance", + initInput: 20, + initOutput: 30, + upstream: mtime.MinTimestamp, + wantInput: 20, // match original input + wantOutput: 30, // match original output + wantDownstream: mtime.MinTimestamp, // not propagated + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ss, outStage, em := newState() + ss.input = test.initInput + ss.output = test.initOutput + ss.updateUpstreamWatermark(inputCol, test.upstream) + ss.updateWatermarks(test.minPending, test.minStateHold, em) + if got, want := ss.input, test.wantInput; got != want { + pcol, up := ss.UpstreamWatermark() + t.Errorf("ss.updateWatermarks(%v,%v); ss.input = %v, want %v (upstream %v %v)", test.minPending, test.minStateHold, got, want, pcol, up) + } + if got, want := ss.output, test.wantOutput; got != want { + pcol, up := ss.UpstreamWatermark() + t.Errorf("ss.updateWatermarks(%v,%v); ss.output = %v, want %v (upstream %v %v)", test.minPending, test.minStateHold, got, want, pcol, up) + } + _, up := outStage.UpstreamWatermark() + if got, want := up, test.wantDownstream; got != want { + t.Errorf("outStage.UpstreamWatermark() = %v, want %v", got, want) + } + }) + } + +} + +func TestElementManager(t *testing.T) { + t.Run("impulse", func(t *testing.T) { + em := NewElementManager(Config{}) + em.AddStage("impulse", nil, nil, []string{"output"}) + em.AddStage("dofn", []string{"output"}, nil, nil) + + em.Impulse("impulse") + + if got, want := em.stages["impulse"].OutputWatermark(), mtime.MaxTimestamp; got != want { + t.Fatalf("impulse.OutputWatermark() = %v, want %v", got, want) + } + + var i int + ch := em.Bundles(context.Background(), func() string { + defer func() { i++ }() + return fmt.Sprintf("%v", i) + }) + rb, ok := <-ch + if !ok { + t.Error("Bundles channel unexpectedly closed") + } + if got, want := rb.StageID, "dofn"; got != want { + t.Errorf("stage to execute = %v, want %v", got, want) + } + em.PersistBundle(rb, nil, TentativeData{}, PColInfo{}, nil, nil) + _, ok = <-ch + if ok { + t.Error("Bundles channel expected to be closed") + } + if got, want := i, 1; got != want { + t.Errorf("got %v bundles, want %v", got, want) + } + }) + + info := PColInfo{ + GlobalID: "generic_info", // GlobalID isn't used except for debugging. 
+ WDec: exec.MakeWindowDecoder(coder.NewGlobalWindow()), + WEnc: exec.MakeWindowEncoder(coder.NewGlobalWindow()), + EDec: func(r io.Reader) []byte { + b, err := io.ReadAll(r) + if err != nil { + t.Fatalf("error decoding \"generic_info\" data:%v", err) + } + return b + }, + } + es := elements{ + es: []element{{ + window: window.GlobalWindow{}, + timestamp: mtime.MinTimestamp, + pane: typex.NoFiringPane(), + elmBytes: []byte{3, 65, 66, 67}, // "ABC" + }}, + minTimestamp: mtime.MinTimestamp, + } + + t.Run("dofn", func(t *testing.T) { + em := NewElementManager(Config{}) + em.AddStage("impulse", nil, nil, []string{"input"}) + em.AddStage("dofn1", []string{"input"}, nil, []string{"output"}) + em.AddStage("dofn2", []string{"output"}, nil, nil) + em.Impulse("impulse") + + var i int + ch := em.Bundles(context.Background(), func() string { + defer func() { i++ }() + t.Log("generating bundle", i) + return fmt.Sprintf("%v", i) + }) + rb, ok := <-ch + if !ok { + t.Error("Bundles channel unexpectedly closed") + } + t.Log("received bundle", i) + + td := TentativeData{} + for _, d := range es.ToData(info) { + td.WriteData("output", d) + } + outputCoders := map[string]PColInfo{ + "output": info, + } + + em.PersistBundle(rb, outputCoders, td, info, nil, nil) + rb, ok = <-ch + if !ok { + t.Error("Bundles channel not expected to be closed") + } + // Check the data is what's expected: + data := em.InputForBundle(rb, info) + if got, want := len(data), 1; got != want { + t.Errorf("data len = %v, want %v", got, want) + } + if !cmp.Equal([]byte{127, 223, 59, 100, 90, 28, 172, 9, 0, 0, 0, 1, 15, 3, 65, 66, 67}, data[0]) { + t.Errorf("unexpected data, got %v", data[0]) + } + em.PersistBundle(rb, outputCoders, TentativeData{}, info, nil, nil) + rb, ok = <-ch + if ok { + t.Error("Bundles channel expected to be closed", rb) + } + + if got, want := i, 2; got != want { + t.Errorf("got %v bundles, want %v", got, want) + } + }) + + t.Run("side", func(t *testing.T) { + em := NewElementManager(Config{}) + em.AddStage("impulse", nil, nil, []string{"input"}) + em.AddStage("dofn1", []string{"input"}, nil, []string{"output"}) + em.AddStage("dofn2", []string{"input"}, []string{"output"}, nil) + em.Impulse("impulse") + + var i int + ch := em.Bundles(context.Background(), func() string { + defer func() { i++ }() + t.Log("generating bundle", i) + return fmt.Sprintf("%v", i) + }) + rb, ok := <-ch + if !ok { + t.Error("Bundles channel unexpectedly closed") + } + t.Log("received bundle", i) + + if got, want := rb.StageID, "dofn1"; got != want { + t.Fatalf("stage to execute = %v, want %v", got, want) + } + + td := TentativeData{} + for _, d := range es.ToData(info) { + td.WriteData("output", d) + } + outputCoders := map[string]PColInfo{ + "output": info, + "input": info, + "impulse": info, + } + + em.PersistBundle(rb, outputCoders, td, info, nil, nil) + rb, ok = <-ch + if !ok { + t.Fatal("Bundles channel not expected to be closed") + } + if got, want := rb.StageID, "dofn2"; got != want { + t.Fatalf("stage to execute = %v, want %v", got, want) + } + em.PersistBundle(rb, outputCoders, TentativeData{}, info, nil, nil) + rb, ok = <-ch + if ok { + t.Error("Bundles channel expected to be closed") + } + + if got, want := i, 2; got != want { + t.Errorf("got %v bundles, want %v", got, want) + } + }) + t.Run("residual", func(t *testing.T) { + em := NewElementManager(Config{}) + em.AddStage("impulse", nil, nil, []string{"input"}) + em.AddStage("dofn", []string{"input"}, nil, nil) + em.Impulse("impulse") + + var i int + ch := 
em.Bundles(context.Background(), func() string { + defer func() { i++ }() + t.Log("generating bundle", i) + return fmt.Sprintf("%v", i) + }) + rb, ok := <-ch + if !ok { + t.Error("Bundles channel unexpectedly closed") + } + t.Log("received bundle", i) + + // Add a residual + resid := es.ToData(info) + em.PersistBundle(rb, nil, TentativeData{}, info, resid, nil) + rb, ok = <-ch + if !ok { + t.Error("Bundles channel not expected to be closed") + } + // Check the data is what's expected: + data := em.InputForBundle(rb, info) + if got, want := len(data), 1; got != want { + t.Errorf("data len = %v, want %v", got, want) + } + if !cmp.Equal([]byte{127, 223, 59, 100, 90, 28, 172, 9, 0, 0, 0, 1, 15, 3, 65, 66, 67}, data[0]) { + t.Errorf("unexpected data, got %v", data[0]) + } + em.PersistBundle(rb, nil, TentativeData{}, info, nil, nil) + rb, ok = <-ch + if ok { + t.Error("Bundles channel expected to be closed", rb) + } + + if got, want := i, 2; got != want { + t.Errorf("got %v bundles, want %v", got, want) + } + }) +} diff --git a/sdks/go/pkg/beam/runners/prism/internal/engine/strategy.go b/sdks/go/pkg/beam/runners/prism/internal/engine/strategy.go new file mode 100644 index 0000000000000..44e6064958c09 --- /dev/null +++ b/sdks/go/pkg/beam/runners/prism/internal/engine/strategy.go @@ -0,0 +1,50 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package engine + +import ( + "fmt" + "time" + + "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/mtime" + "github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex" +) + +type winStrat interface { + EarliestCompletion(typex.Window) mtime.Time +} + +type defaultStrat struct{} + +func (ws defaultStrat) EarliestCompletion(w typex.Window) mtime.Time { + return w.MaxTimestamp() +} + +func (defaultStrat) String() string { + return "default" +} + +type sessionStrat struct { + GapSize time.Duration +} + +func (ws sessionStrat) EarliestCompletion(w typex.Window) mtime.Time { + return w.MaxTimestamp().Add(ws.GapSize) +} + +func (ws sessionStrat) String() string { + return fmt.Sprintf("session[GapSize:%v]", ws.GapSize) +} diff --git a/sdks/go/pkg/beam/runners/prism/internal/engine/strategy_test.go b/sdks/go/pkg/beam/runners/prism/internal/engine/strategy_test.go new file mode 100644 index 0000000000000..9d558396f8067 --- /dev/null +++ b/sdks/go/pkg/beam/runners/prism/internal/engine/strategy_test.go @@ -0,0 +1,45 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. 
You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package engine
+
+import (
+	"testing"
+	"time"
+
+	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/mtime"
+	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/window"
+	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex"
+)
+
+func TestEarliestCompletion(t *testing.T) {
+	tests := []struct {
+		strat winStrat
+		input typex.Window
+		want  mtime.Time
+	}{
+		{defaultStrat{}, window.GlobalWindow{}, mtime.EndOfGlobalWindowTime},
+		{defaultStrat{}, window.IntervalWindow{Start: 0, End: 4}, 3},
+		{defaultStrat{}, window.IntervalWindow{Start: mtime.MinTimestamp, End: mtime.MaxTimestamp}, mtime.MaxTimestamp - 1},
+		{sessionStrat{}, window.IntervalWindow{Start: 0, End: 4}, 3},
+		{sessionStrat{GapSize: 3 * time.Millisecond}, window.IntervalWindow{Start: 0, End: 4}, 6},
+	}
+
+	for _, test := range tests {
+		if got, want := test.strat.EarliestCompletion(test.input), test.want; got != want {
+			t.Errorf("%v.EarliestCompletion(%v) = %v, want %v", test.strat, test.input, got, want)
+		}
+	}
+}
diff --git a/sdks/go/pkg/beam/runners/prism/internal/execute.go b/sdks/go/pkg/beam/runners/prism/internal/execute.go
new file mode 100644
index 0000000000000..b317b8e9f2123
--- /dev/null
+++ b/sdks/go/pkg/beam/runners/prism/internal/execute.go
@@ -0,0 +1,676 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package internal
+
+import (
+	"bytes"
+	"context"
+	"fmt"
+	"io"
+	"sort"
+
+	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/mtime"
+	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/exec"
+	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex"
+	fnpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/fnexecution_v1"
+	pipepb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v1"
+	"github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/engine"
+	"github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/jobservices"
+	"github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/urns"
+	"github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/worker"
+	"golang.org/x/exp/maps"
+	"golang.org/x/exp/slog"
+	"google.golang.org/grpc"
+	"google.golang.org/grpc/credentials/insecure"
+	"google.golang.org/protobuf/proto"
+)
+
+func executePipeline(ctx context.Context, wk *worker.W, j *jobservices.Job) {
+	pipeline := j.Pipeline
+	comps := proto.Clone(pipeline.GetComponents()).(*pipepb.Components)
+
+	// TODO: configure the preprocessor from pipeline options.
+	// Maybe change these returns to a single struct for convenience and further
+	// annotation?
+
+	handlers := []any{
+		Combine(CombineCharacteristic{EnableLifting: true}),
+		ParDo(ParDoCharacteristic{DisableSDF: true}),
+		Runner(RunnerCharacteristic{
+			SDKFlatten: false,
+		}),
+	}
+
+	proc := processor{
+		transformExecuters: map[string]transformExecuter{},
+	}
+
+	var preppers []transformPreparer
+	for _, h := range handlers {
+		if th, ok := h.(transformPreparer); ok {
+			preppers = append(preppers, th)
+		}
+		if th, ok := h.(transformExecuter); ok {
+			for _, urn := range th.ExecuteUrns() {
+				proc.transformExecuters[urn] = th
+			}
+		}
+	}
+
+	prepro := newPreprocessor(preppers)
+
+	topo := prepro.preProcessGraph(comps)
+	ts := comps.GetTransforms()
+
+	em := engine.NewElementManager(engine.Config{})
+
+	// This is where the Batch -> Streaming tension exists.
+	// We don't *pre* do this, and we need a different mechanism
+	// to sort out processing order.
+	stages := map[string]*stage{}
+	var impulses []string
+	for i, stage := range topo {
+		if len(stage.transforms) != 1 {
+			panic(fmt.Sprintf("unsupported stage[%d]: contains multiple transforms: %v; TODO: implement fusion", i, stage.transforms))
+		}
+		tid := stage.transforms[0]
+		t := ts[tid]
+		urn := t.GetSpec().GetUrn()
+		stage.exe = proc.transformExecuters[urn]
+
+		// Stopgap until everything's moved to handlers.
+		stage.envID = t.GetEnvironmentId()
+		if stage.exe != nil {
+			stage.envID = stage.exe.ExecuteWith(t)
+		}
+		stage.ID = wk.NextStage()
+
+		switch stage.envID {
+		case "": // Runner Transforms
+
+			var onlyOut string
+			for _, out := range t.GetOutputs() {
+				onlyOut = out
+			}
+			stage.OutputsToCoders = map[string]engine.PColInfo{}
+			coders := map[string]*pipepb.Coder{}
+			makeWindowedValueCoder(onlyOut, comps, coders)
+
+			col := comps.GetPcollections()[onlyOut]
+			ed := collectionPullDecoder(col.GetCoderId(), coders, comps)
+			wDec, wEnc := getWindowValueCoders(comps, col, coders)
+
+			stage.OutputsToCoders[onlyOut] = engine.PColInfo{
+				GlobalID: onlyOut,
+				WDec:     wDec,
+				WEnc:     wEnc,
+				EDec:     ed,
+			}
+
+			// There's either 0, 1 or many inputs, but they should all be the same,
+			// so break after the first one.
+			for _, global := range t.GetInputs() {
+				col := comps.GetPcollections()[global]
+				ed := collectionPullDecoder(col.GetCoderId(), coders, comps)
+				wDec, wEnc := getWindowValueCoders(comps, col, coders)
+				stage.inputInfo = engine.PColInfo{
+					GlobalID: global,
+					WDec:     wDec,
+					WEnc:     wEnc,
+					EDec:     ed,
+				}
+				break
+			}
+
+			switch urn {
+			case urns.TransformGBK:
+				em.AddStage(stage.ID, []string{getOnlyValue(t.GetInputs())}, nil, []string{getOnlyValue(t.GetOutputs())})
+				for _, global := range t.GetInputs() {
+					col := comps.GetPcollections()[global]
+					ed := collectionPullDecoder(col.GetCoderId(), coders, comps)
+					wDec, wEnc := getWindowValueCoders(comps, col, coders)
+					stage.inputInfo = engine.PColInfo{
+						GlobalID: global,
+						WDec:     wDec,
+						WEnc:     wEnc,
+						EDec:     ed,
+					}
+				}
+				em.StageAggregates(stage.ID)
+			case urns.TransformImpulse:
+				impulses = append(impulses, stage.ID)
+				em.AddStage(stage.ID, nil, nil, []string{getOnlyValue(t.GetOutputs())})
+			case urns.TransformFlatten:
+				inputs := maps.Values(t.GetInputs())
+				sort.Strings(inputs)
+				em.AddStage(stage.ID, inputs, nil, []string{getOnlyValue(t.GetOutputs())})
+			}
+			stages[stage.ID] = stage
+			wk.Descriptors[stage.ID] = stage.desc
+		case wk.ID:
+			// Great! This is for this environment. // Broken abstraction.
+			buildStage(stage, tid, t, comps, wk)
+			stages[stage.ID] = stage
+			slog.Debug("pipelineBuild", slog.Group("stage", slog.String("ID", stage.ID), slog.String("transformName", t.GetUniqueName())))
+			outputs := maps.Keys(stage.OutputsToCoders)
+			sort.Strings(outputs)
+			em.AddStage(stage.ID, []string{stage.mainInputPCol}, stage.sides, outputs)
+		default:
+			err := fmt.Errorf("unknown environment[%v]", t.GetEnvironmentId())
+			slog.Error("Execute", err)
+			panic(err)
+		}
+	}
+
+	// Prime the initial impulses, since we now know what consumes them.
+	for _, id := range impulses {
+		em.Impulse(id)
+	}
+
+	// Execute stages here
+	for rb := range em.Bundles(ctx, wk.NextInst) {
+		s := stages[rb.StageID]
+		s.Execute(j, wk, comps, em, rb)
+	}
+	slog.Info("pipeline done!", slog.String("job", j.String()))
+}
+
+func getOnlyValue[K comparable, V any](in map[K]V) V {
+	if len(in) != 1 {
+		panic(fmt.Sprintf("expected single value map, had %v", len(in)))
+	}
+	for _, v := range in {
+		return v
+	}
+	panic("unreachable")
+}
+
+func buildStage(s *stage, tid string, t *pipepb.PTransform, comps *pipepb.Components, wk *worker.W) {
+	s.inputTransformID = tid + "_source"
+
+	coders := map[string]*pipepb.Coder{}
+	transforms := map[string]*pipepb.PTransform{
+		tid: t, // The Transform to Execute!
+	}
+
+	sis, err := getSideInputs(t)
+	if err != nil {
+		slog.Error("buildStage: getSideInputs", err, slog.String("transformID", tid))
+		panic(err)
+	}
+	var inputInfo engine.PColInfo
+	var sides []string
+	for local, global := range t.GetInputs() {
+		// This id is directly used for the source, but this also copies
+		// coders used by side inputs to the coders map for the bundle, so
+		// needs to be run for every ID.
+ wInCid := makeWindowedValueCoder(global, comps, coders) + _, ok := sis[local] + if ok { + sides = append(sides, global) + } else { + // this is the main input + transforms[s.inputTransformID] = sourceTransform(s.inputTransformID, portFor(wInCid, wk), global) + col := comps.GetPcollections()[global] + ed := collectionPullDecoder(col.GetCoderId(), coders, comps) + wDec, wEnc := getWindowValueCoders(comps, col, coders) + inputInfo = engine.PColInfo{ + GlobalID: global, + WDec: wDec, + WEnc: wEnc, + EDec: ed, + } + } + // We need to process all inputs to ensure we have all input coders, so we must continue. + } + + prepareSides, err := handleSideInputs(t, comps, coders, wk) + if err != nil { + slog.Error("buildStage: handleSideInputs", err, slog.String("transformID", tid)) + panic(err) + } + + // TODO: We need a new logical PCollection to represent the source + // so we can avoid double counting PCollection metrics later. + // But this also means replacing the ID for the input in the bundle. + sink2Col := map[string]string{} + col2Coders := map[string]engine.PColInfo{} + for local, global := range t.GetOutputs() { + wOutCid := makeWindowedValueCoder(global, comps, coders) + sinkID := tid + "_" + local + col := comps.GetPcollections()[global] + ed := collectionPullDecoder(col.GetCoderId(), coders, comps) + wDec, wEnc := getWindowValueCoders(comps, col, coders) + sink2Col[sinkID] = global + col2Coders[global] = engine.PColInfo{ + GlobalID: global, + WDec: wDec, + WEnc: wEnc, + EDec: ed, + } + transforms[sinkID] = sinkTransform(sinkID, portFor(wOutCid, wk), global) + } + + reconcileCoders(coders, comps.GetCoders()) + + desc := &fnpb.ProcessBundleDescriptor{ + Id: s.ID, + Transforms: transforms, + WindowingStrategies: comps.GetWindowingStrategies(), + Pcollections: comps.GetPcollections(), + Coders: coders, + StateApiServiceDescriptor: &pipepb.ApiServiceDescriptor{ + Url: wk.Endpoint(), + }, + } + + s.desc = desc + s.outputCount = len(t.Outputs) + s.prepareSides = prepareSides + s.sides = sides + s.SinkToPCollection = sink2Col + s.OutputsToCoders = col2Coders + s.mainInputPCol = inputInfo.GlobalID + s.inputInfo = inputInfo + + wk.Descriptors[s.ID] = s.desc +} + +func getSideInputs(t *pipepb.PTransform) (map[string]*pipepb.SideInput, error) { + if t.GetSpec().GetUrn() != urns.TransformParDo { + return nil, nil + } + pardo := &pipepb.ParDoPayload{} + if err := (proto.UnmarshalOptions{}).Unmarshal(t.GetSpec().GetPayload(), pardo); err != nil { + return nil, fmt.Errorf("unable to decode ParDoPayload") + } + return pardo.GetSideInputs(), nil +} + +// handleSideInputs ensures appropriate coders are available to the bundle, and prepares a function to stage the data. +func handleSideInputs(t *pipepb.PTransform, comps *pipepb.Components, coders map[string]*pipepb.Coder, wk *worker.W) (func(b *worker.B, tid string, watermark mtime.Time), error) { + sis, err := getSideInputs(t) + if err != nil { + return nil, err + } + var prepSides []func(b *worker.B, tid string, watermark mtime.Time) + + // Get WindowedValue Coders for the transform's input and output PCollections. + for local, global := range t.GetInputs() { + si, ok := sis[local] + if !ok { + continue // This is the main input. 
+		}
+
+		// This is a side input.
+		switch si.GetAccessPattern().GetUrn() {
+		case urns.SideInputIterable:
+			slog.Debug("urnSideInputIterable",
+				slog.String("sourceTransform", t.GetUniqueName()),
+				slog.String("local", local),
+				slog.String("global", global))
+			col := comps.GetPcollections()[global]
+			ed := collectionPullDecoder(col.GetCoderId(), coders, comps)
+			wDec, wEnc := getWindowValueCoders(comps, col, coders)
+			// May be of zero length, but that's OK. Side inputs can be empty.
+
+			global, local := global, local
+			prepSides = append(prepSides, func(b *worker.B, tid string, watermark mtime.Time) {
+				data := wk.D.GetAllData(global)
+
+				if b.IterableSideInputData == nil {
+					b.IterableSideInputData = map[string]map[string]map[typex.Window][][]byte{}
+				}
+				if _, ok := b.IterableSideInputData[tid]; !ok {
+					b.IterableSideInputData[tid] = map[string]map[typex.Window][][]byte{}
+				}
+				b.IterableSideInputData[tid][local] = collateByWindows(data, watermark, wDec, wEnc,
+					func(r io.Reader) [][]byte {
+						return [][]byte{ed(r)}
+					}, func(a, b [][]byte) [][]byte {
+						return append(a, b...)
+					})
+			})
+
+		case urns.SideInputMultiMap:
+			slog.Debug("urnSideInputMultiMap",
+				slog.String("sourceTransform", t.GetUniqueName()),
+				slog.String("local", local),
+				slog.String("global", global))
+			col := comps.GetPcollections()[global]
+
+			kvc := comps.GetCoders()[col.GetCoderId()]
+			if kvc.GetSpec().GetUrn() != urns.CoderKV {
+				return nil, fmt.Errorf("multimap side inputs need a KV coder, got %v", kvc.GetSpec().GetUrn())
+			}
+
+			kd := collectionPullDecoder(kvc.GetComponentCoderIds()[0], coders, comps)
+			vd := collectionPullDecoder(kvc.GetComponentCoderIds()[1], coders, comps)
+			wDec, wEnc := getWindowValueCoders(comps, col, coders)
+
+			global, local := global, local
+			prepSides = append(prepSides, func(b *worker.B, tid string, watermark mtime.Time) {
+				// May be of zero length, but that's OK. Side inputs can be empty.
+				data := wk.D.GetAllData(global)
+				if b.MultiMapSideInputData == nil {
+					b.MultiMapSideInputData = map[string]map[string]map[typex.Window]map[string][][]byte{}
+				}
+				if _, ok := b.MultiMapSideInputData[tid]; !ok {
+					b.MultiMapSideInputData[tid] = map[string]map[typex.Window]map[string][][]byte{}
+				}
+				b.MultiMapSideInputData[tid][local] = collateByWindows(data, watermark, wDec, wEnc,
+					func(r io.Reader) map[string][][]byte {
+						kb := kd(r)
+						return map[string][][]byte{
+							string(kb): {vd(r)},
+						}
+					}, func(a, b map[string][][]byte) map[string][][]byte {
+						if len(a) == 0 {
+							return b
+						}
+						for k, vs := range b {
+							a[k] = append(a[k], vs...)
+						}
+						return a
+					})
+			})
+		default:
+			return nil, fmt.Errorf("local input %v (global %v) uses access pattern %v", local, global, si.GetAccessPattern().GetUrn())
+		}
+	}
+	return func(b *worker.B, tid string, watermark mtime.Time) {
+		for _, prep := range prepSides {
+			prep(b, tid, watermark)
+		}
+	}, nil
+}
+
+func collectionPullDecoder(coldCId string, coders map[string]*pipepb.Coder, comps *pipepb.Components) func(io.Reader) []byte {
+	cID := lpUnknownCoders(coldCId, coders, comps.GetCoders())
+	return pullDecoder(coders[cID], coders)
+}
+
+func getWindowValueCoders(comps *pipepb.Components, col *pipepb.PCollection, coders map[string]*pipepb.Coder) (exec.WindowDecoder, exec.WindowEncoder) {
+	ws := comps.GetWindowingStrategies()[col.GetWindowingStrategyId()]
+	wcID := lpUnknownCoders(ws.GetWindowCoderId(), coders, comps.GetCoders())
+	return makeWindowCoders(coders[wcID])
+}
+
+func sourceTransform(parentID string, sourcePortBytes []byte, outPID string) *pipepb.PTransform {
+	source := &pipepb.PTransform{
+		UniqueName: parentID,
+		Spec: &pipepb.FunctionSpec{
+			Urn:     urns.TransformSource,
+			Payload: sourcePortBytes,
+		},
+		Outputs: map[string]string{
+			"i0": outPID,
+		},
+	}
+	return source
+}
+
+func sinkTransform(sinkID string, sinkPortBytes []byte, inPID string) *pipepb.PTransform {
+	sink := &pipepb.PTransform{
+		UniqueName: sinkID,
+		Spec: &pipepb.FunctionSpec{
+			Urn:     urns.TransformSink,
+			Payload: sinkPortBytes,
+		},
+		Inputs: map[string]string{
+			"i0": inPID,
+		},
+	}
+	return sink
+}
+
+func portFor(wInCid string, wk *worker.W) []byte {
+	sourcePort := &fnpb.RemoteGrpcPort{
+		CoderId: wInCid,
+		ApiServiceDescriptor: &pipepb.ApiServiceDescriptor{
+			Url: wk.Endpoint(),
+		},
+	}
+	sourcePortBytes, err := proto.Marshal(sourcePort)
+	if err != nil {
+		slog.Error("bad port", err, slog.String("endpoint", sourcePort.ApiServiceDescriptor.GetUrl()))
+	}
+	return sourcePortBytes
+}
+
+type transformExecuter interface {
+	ExecuteUrns() []string
+	ExecuteWith(t *pipepb.PTransform) string
+	ExecuteTransform(tid string, t *pipepb.PTransform, comps *pipepb.Components, watermark mtime.Time, data [][]byte) *worker.B
+}
+
+type processor struct {
+	transformExecuters map[string]transformExecuter
+}
+
+// collateByWindows takes the data and collates them into window keyed maps.
+// Uses generics to consolidate the repetitive window loops.
+func collateByWindows[T any](data [][]byte, watermark mtime.Time, wDec exec.WindowDecoder, wEnc exec.WindowEncoder, ed func(io.Reader) T, join func(T, T) T) map[typex.Window]T {
+	windowed := map[typex.Window]T{}
+	for _, datum := range data {
+		inBuf := bytes.NewBuffer(datum)
+		for {
+			ws, _, _, err := exec.DecodeWindowedValueHeader(wDec, inBuf)
+			if err == io.EOF {
+				break
+			}
+			// Get the element out, and window it properly.
+			e := ed(inBuf)
+			for _, w := range ws {
+				// if w.MaxTimestamp() > watermark {
+				// 	var t T
+				// 	slog.Debug(fmt.Sprintf("collateByWindows[%T]: window not yet closed, skipping %v > %v", t, w.MaxTimestamp(), watermark))
+				// 	continue
+				// }
+				windowed[w] = join(windowed[w], e)
+			}
+		}
+	}
+	return windowed
+}
+
+// stage represents a fused subgraph.
+//
+// TODO: do we guarantee that they are all
+// the same environment at this point, or
+// should that be handled later?
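+//
+// Currently a stage holds exactly one transform (executePipeline panics
+// otherwise), executed either directly by the runner or on a worker via
+// the FnAPI, depending on its environment ID.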
+type stage struct {
+	ID         string
+	transforms []string
+
+	envID            string
+	exe              transformExecuter
+	outputCount      int
+	inputTransformID string
+	mainInputPCol    string
+	inputInfo        engine.PColInfo
+	desc             *fnpb.ProcessBundleDescriptor
+	sides            []string
+	prepareSides     func(b *worker.B, tid string, watermark mtime.Time)
+
+	SinkToPCollection map[string]string
+	OutputsToCoders   map[string]engine.PColInfo
+}
+
+func (s *stage) Execute(j *jobservices.Job, wk *worker.W, comps *pipepb.Components, em *engine.ElementManager, rb engine.RunBundle) {
+	tid := s.transforms[0]
+	slog.Debug("Execute: starting bundle", "bundle", rb, slog.String("tid", tid))
+
+	var b *worker.B
+	var send bool
+	inputData := em.InputForBundle(rb, s.inputInfo)
+	switch s.envID {
+	case "": // Runner Transforms
+		// Runner transforms are processed immediately.
+		b = s.exe.ExecuteTransform(tid, comps.GetTransforms()[tid], comps, rb.Watermark, inputData)
+		b.InstID = rb.BundleID
+		slog.Debug("Execute: runner transform", "bundle", rb, slog.String("tid", tid))
+	case wk.ID:
+		send = true
+		b = &worker.B{
+			PBDID:  s.ID,
+			InstID: rb.BundleID,
+
+			InputTransformID: s.inputTransformID,
+
+			// TODO Here's where we can split data for processing in multiple bundles.
+			InputData: inputData,
+
+			SinkToPCollection: s.SinkToPCollection,
+			OutputCount:       s.outputCount,
+		}
+		b.Init()
+
+		s.prepareSides(b, s.transforms[0], rb.Watermark)
+	default:
+		err := fmt.Errorf("unknown environment[%v]", s.envID)
+		slog.Error("Execute", err)
+		panic(err)
+	}
+
+	if send {
+		slog.Debug("Execute: processing", "bundle", rb)
+		b.ProcessOn(wk) // Blocks until finished.
+	}
+	// Tentative Data is ready, commit it to the main datastore.
+	slog.Debug("Execute: committing data", "bundle", rb, slog.Any("outputsWithData", maps.Keys(b.OutputData.Raw)), slog.Any("outputs", maps.Keys(s.OutputsToCoders)))
+
+	resp := &fnpb.ProcessBundleResponse{}
+	if send {
+		resp = <-b.Resp
+		// Tally metrics immediately so they're available before
+		// pipeline termination.
+		j.ContributeMetrics(resp)
+	}
+	// TODO handle side input data properly.
+	wk.D.Commit(b.OutputData)
+	var residualData [][]byte
+	var minOutputWatermark map[string]mtime.Time
+	for _, rr := range resp.GetResidualRoots() {
+		ba := rr.GetApplication()
+		residualData = append(residualData, ba.GetElement())
+		if len(ba.GetElement()) == 0 {
+			slog.Log(slog.LevelError, "returned empty residual application", "bundle", rb)
+			panic("sdk returned empty residual application")
+		}
+		for col, wm := range ba.GetOutputWatermarks() {
+			if minOutputWatermark == nil {
+				minOutputWatermark = map[string]mtime.Time{}
+			}
+			cur, ok := minOutputWatermark[col]
+			if !ok {
+				cur = mtime.MaxTimestamp
+			}
+			minOutputWatermark[col] = mtime.Min(mtime.FromTime(wm.AsTime()), cur)
+		}
+	}
+	if l := len(residualData); l > 0 {
+		slog.Debug("returned residual applications", "bundle", rb, slog.Int("numResiduals", l), slog.String("pcollection", s.mainInputPCol))
+	}
+	em.PersistBundle(rb, s.OutputsToCoders, b.OutputData, s.inputInfo, residualData, minOutputWatermark)
+	b.OutputData = engine.TentativeData{} // Clear the data.
+}
+
+// RunPipeline starts the main thread for executing this job.
+// It's analogous to the manager-side process for a distributed pipeline.
+// It will begin "workers" for the job's environments.
+func RunPipeline(j *jobservices.Job) {
+	j.SendMsg("starting " + j.String())
+	j.Start()
+
+	// In a "proper" runner, we'd iterate through all the
+	// environments, and start up docker containers, but
+	// here, we only want and need the go one, operating
+	// in loopback mode.
+	env := "go"
+	wk := worker.New(env) // Cheating by having the worker id match the environment id.
+	go wk.Serve()
+
+	// When this function exits, cancel the job's context to stop the environment.
+	defer func() {
+		j.CancelFn()
+	}()
+	go runEnvironment(j.RootCtx, j, env, wk)
+
+	j.SendMsg("running " + j.String())
+	j.Running()
+
+	executePipeline(j.RootCtx, wk, j)
+	j.SendMsg("pipeline completed " + j.String())
+
+	// Stop the worker.
+	wk.Stop()
+
+	j.SendMsg("terminating " + j.String())
+	j.Done()
+}
+
+func runEnvironment(ctx context.Context, j *jobservices.Job, env string, wk *worker.W) {
+	// TODO fix broken abstraction.
+	// We're starting a worker pool here, because that's the loopback environment.
+	// It's sort of a mess, largely because of loopback, which has
+	// a different flow from a provisioned docker container.
+	e := j.Pipeline.GetComponents().GetEnvironments()[env]
+	switch e.GetUrn() {
+	case urns.EnvExternal:
+		ep := &pipepb.ExternalPayload{}
+		if err := (proto.UnmarshalOptions{}).Unmarshal(e.GetPayload(), ep); err != nil {
+			slog.Error("unmarshalling environment payload", err, slog.String("envID", wk.ID))
+		}
+		externalEnvironment(ctx, ep, wk)
+		slog.Info("environment stopped", slog.String("envID", wk.String()), slog.String("job", j.String()))
+	default:
+		panic(fmt.Sprintf("environment %v with urn %v unimplemented", env, e.GetUrn()))
+	}
+}
+
+func externalEnvironment(ctx context.Context, ep *pipepb.ExternalPayload, wk *worker.W) {
+	conn, err := grpc.Dial(ep.GetEndpoint().GetUrl(), grpc.WithTransportCredentials(insecure.NewCredentials()))
+	if err != nil {
+		panic(fmt.Sprintf("unable to dial sdk worker %v: %v", ep.GetEndpoint().GetUrl(), err))
+	}
+	defer conn.Close()
+	pool := fnpb.NewBeamFnExternalWorkerPoolClient(conn)
+
+	endpoint := &pipepb.ApiServiceDescriptor{
+		Url: wk.Endpoint(),
+	}
+
+	pool.StartWorker(ctx, &fnpb.StartWorkerRequest{
+		WorkerId:          wk.ID,
+		ControlEndpoint:   endpoint,
+		LoggingEndpoint:   endpoint,
+		ArtifactEndpoint:  endpoint,
+		ProvisionEndpoint: endpoint,
+		Params:            nil,
+	})
+
+	// Job processing happens here, but is orchestrated by other goroutines.
+	// This goroutine blocks until the context is cancelled, signalling
+	// that the pool runner should stop the worker.
+	<-ctx.Done()
+
+	// The previous context was cancelled, so we need a new one
+	// for this request.
+	pool.StopWorker(context.Background(), &fnpb.StopWorkerRequest{
+		WorkerId: wk.ID,
+	})
+}
diff --git a/sdks/go/pkg/beam/runners/prism/internal/execute_test.go b/sdks/go/pkg/beam/runners/prism/internal/execute_test.go
new file mode 100644
index 0000000000000..de7247486bbc3
--- /dev/null
+++ b/sdks/go/pkg/beam/runners/prism/internal/execute_test.go
@@ -0,0 +1,417 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License.
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + "context" + "os" + "testing" + + "github.com/apache/beam/sdks/v2/go/pkg/beam" + "github.com/apache/beam/sdks/v2/go/pkg/beam/core/metrics" + "github.com/apache/beam/sdks/v2/go/pkg/beam/options/jobopts" + "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/jobservices" + "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/universal" + "github.com/apache/beam/sdks/v2/go/pkg/beam/testing/passert" + "github.com/apache/beam/sdks/v2/go/pkg/beam/testing/ptest" + "github.com/apache/beam/sdks/v2/go/pkg/beam/transforms/filter" + "github.com/apache/beam/sdks/v2/go/test/integration/primitives" +) + +func initRunner(t *testing.T) { + t.Helper() + if *jobopts.Endpoint == "" { + s := jobservices.NewServer(0, RunPipeline) + *jobopts.Endpoint = s.Endpoint() + go s.Serve() + t.Cleanup(func() { + *jobopts.Endpoint = "" + s.Stop() + }) + } + if !jobopts.IsLoopback() { + *jobopts.EnvironmentType = "loopback" + } + // Since we force loopback, avoid cross-compilation. + f, err := os.CreateTemp("", "dummy") + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { os.Remove(f.Name()) }) + *jobopts.WorkerBinary = f.Name() +} + +func execute(ctx context.Context, p *beam.Pipeline) (beam.PipelineResult, error) { + return universal.Execute(ctx, p) +} + +func executeWithT(ctx context.Context, t *testing.T, p *beam.Pipeline) (beam.PipelineResult, error) { + t.Log("startingTest - ", t.Name()) + return execute(ctx, p) +} + +func init() { + // Not actually being used, but explicitly registering + // will avoid accidentally using a different runner for + // the tests if I change things later. 
+ beam.RegisterRunner("testlocal", execute) +} + +func TestRunner_Pipelines(t *testing.T) { + initRunner(t) + + tests := []struct { + name string + pipeline func(s beam.Scope) + metrics func(t *testing.T, pr beam.PipelineResult) + }{ + { + name: "simple", + pipeline: func(s beam.Scope) { + imp := beam.Impulse(s) + col := beam.ParDo(s, dofn1, imp) + beam.ParDo(s, &int64Check{ + Name: "simple", + Want: []int{1, 2, 3}, + }, col) + }, + }, { + name: "sequence", + pipeline: func(s beam.Scope) { + imp := beam.Impulse(s) + beam.Seq(s, imp, dofn1, dofn2, dofn2, dofn2, &int64Check{Name: "sequence", Want: []int{4, 5, 6}}) + }, + }, { + name: "gbk", + pipeline: func(s beam.Scope) { + imp := beam.Impulse(s) + col := beam.ParDo(s, dofnKV, imp) + gbk := beam.GroupByKey(s, col) + beam.Seq(s, gbk, dofnGBK, &int64Check{Name: "gbk", Want: []int{9, 12}}) + }, + }, { + name: "gbk2", + pipeline: func(s beam.Scope) { + imp := beam.Impulse(s) + col := beam.ParDo(s, dofnKV2, imp) + gbk := beam.GroupByKey(s, col) + beam.Seq(s, gbk, dofnGBK2, &stringCheck{Name: "gbk2", Want: []string{"aaa", "bbb"}}) + }, + }, { + name: "gbk3", + pipeline: func(s beam.Scope) { + imp := beam.Impulse(s) + col := beam.ParDo(s, dofnKV3, imp) + gbk := beam.GroupByKey(s, col) + beam.Seq(s, gbk, dofnGBK3, &stringCheck{Name: "gbk3", Want: []string{"{a 1}: {a 1}"}}) + }, + }, { + name: "sink_nooutputs", + pipeline: func(s beam.Scope) { + imp := beam.Impulse(s) + beam.ParDo0(s, dofnSink, imp) + }, + metrics: func(t *testing.T, pr beam.PipelineResult) { + qr := pr.Metrics().Query(func(sr metrics.SingleResult) bool { + return sr.Name() == "sunk" + }) + if got, want := qr.Counters()[0].Committed, int64(73); got != want { + t.Errorf("pr.Metrics.Query(Name = \"sunk\")).Committed = %v, want %v", got, want) + } + }, + }, { + name: "fork_impulse", + pipeline: func(s beam.Scope) { + imp := beam.Impulse(s) + col1 := beam.ParDo(s, dofn1, imp) + col2 := beam.ParDo(s, dofn1, imp) + beam.ParDo(s, &int64Check{ + Name: "fork check1", + Want: []int{1, 2, 3}, + }, col1) + beam.ParDo(s, &int64Check{ + Name: "fork check2", + Want: []int{1, 2, 3}, + }, col2) + }, + }, { + name: "fork_postDoFn", + pipeline: func(s beam.Scope) { + imp := beam.Impulse(s) + col := beam.ParDo(s, dofn1, imp) + beam.ParDo(s, &int64Check{ + Name: "fork check1", + Want: []int{1, 2, 3}, + }, col) + beam.ParDo(s, &int64Check{ + Name: "fork check2", + Want: []int{1, 2, 3}, + }, col) + }, + }, { + name: "fork_multipleOutputs1", + pipeline: func(s beam.Scope) { + imp := beam.Impulse(s) + col1, col2, col3, col4, col5 := beam.ParDo5(s, dofn1x5, imp) + beam.ParDo(s, &int64Check{ + Name: "col1", + Want: []int{1, 6}, + }, col1) + beam.ParDo(s, &int64Check{ + Name: "col2", + Want: []int{2, 7}, + }, col2) + beam.ParDo(s, &int64Check{ + Name: "col3", + Want: []int{3, 8}, + }, col3) + beam.ParDo(s, &int64Check{ + Name: "col4", + Want: []int{4, 9}, + }, col4) + beam.ParDo(s, &int64Check{ + Name: "col5", + Want: []int{5, 10}, + }, col5) + }, + }, { + name: "fork_multipleOutputs2", + pipeline: func(s beam.Scope) { + imp := beam.Impulse(s) + col1, col2, col3, col4, col5 := beam.ParDo5(s, dofn1x5, imp) + beam.ParDo(s, &int64Check{ + Name: "col1", + Want: []int{1, 6}, + }, col1) + beam.ParDo(s, &int64Check{ + Name: "col2", + Want: []int{2, 7}, + }, col2) + beam.ParDo(s, &int64Check{ + Name: "col3", + Want: []int{3, 8}, + }, col3) + beam.ParDo(s, &int64Check{ + Name: "col4", + Want: []int{4, 9}, + }, col4) + beam.ParDo(s, &int64Check{ + Name: "col5", + Want: []int{5, 10}, + }, col5) + }, + }, { + name: 
"flatten", + pipeline: func(s beam.Scope) { + imp := beam.Impulse(s) + col1 := beam.ParDo(s, dofn1, imp) + col2 := beam.ParDo(s, dofn1, imp) + flat := beam.Flatten(s, col1, col2) + beam.ParDo(s, &int64Check{ + Name: "flatten check", + Want: []int{1, 1, 2, 2, 3, 3}, + }, flat) + }, + }, { + name: "sideinput_iterable_oneimpulse", + pipeline: func(s beam.Scope) { + imp := beam.Impulse(s) + col1 := beam.ParDo(s, dofn1, imp) + sum := beam.ParDo(s, dofn2x1, imp, beam.SideInput{Input: col1}) + beam.ParDo(s, &int64Check{ + Name: "iter sideinput check", + Want: []int{6}, + }, sum) + }, + }, { + name: "sideinput_iterable_twoimpulse", + pipeline: func(s beam.Scope) { + imp1 := beam.Impulse(s) + col1 := beam.ParDo(s, dofn1, imp1) + imp2 := beam.Impulse(s) + sum := beam.ParDo(s, dofn2x1, imp2, beam.SideInput{Input: col1}) + beam.ParDo(s, &int64Check{ + Name: "iter sideinput check", + Want: []int{6}, + }, sum) + }, + }, { + name: "sideinput_iterableKV", + pipeline: func(s beam.Scope) { + imp := beam.Impulse(s) + col1 := beam.ParDo(s, dofnKV, imp) + keys, sum := beam.ParDo2(s, dofn2x2KV, imp, beam.SideInput{Input: col1}) + beam.ParDo(s, &stringCheck{ + Name: "iterKV sideinput check K", + Want: []string{"a", "a", "a", "b", "b", "b"}, + }, keys) + beam.ParDo(s, &int64Check{ + Name: "iterKV sideinput check V", + Want: []int{21}, + }, sum) + }, + }, { + name: "sideinput_iterableKV", + pipeline: func(s beam.Scope) { + imp := beam.Impulse(s) + col1 := beam.ParDo(s, dofnKV, imp) + keys, sum := beam.ParDo2(s, dofn2x2KV, imp, beam.SideInput{Input: col1}) + beam.ParDo(s, &stringCheck{ + Name: "iterKV sideinput check K", + Want: []string{"a", "a", "a", "b", "b", "b"}, + }, keys) + beam.ParDo(s, &int64Check{ + Name: "iterKV sideinput check V", + Want: []int{21}, + }, sum) + }, + }, { + name: "sideinput_multimap", + pipeline: func(s beam.Scope) { + imp := beam.Impulse(s) + col1 := beam.ParDo(s, dofnKV, imp) + keys := filter.Distinct(s, beam.DropValue(s, col1)) + ks, sum := beam.ParDo2(s, dofnMultiMap, keys, beam.SideInput{Input: col1}) + beam.ParDo(s, &stringCheck{ + Name: "multiMap sideinput check K", + Want: []string{"a", "b"}, + }, ks) + beam.ParDo(s, &int64Check{ + Name: "multiMap sideinput check V", + Want: []int{9, 12}, + }, sum) + }, + }, { + // Ensures topological sort is correct. 
+ name: "sideinput_2iterable", + pipeline: func(s beam.Scope) { + imp := beam.Impulse(s) + col0 := beam.ParDo(s, dofn1, imp) + col1 := beam.ParDo(s, dofn1, imp) + col2 := beam.ParDo(s, dofn2, col1) + sum := beam.ParDo(s, dofn3x1, col0, beam.SideInput{Input: col1}, beam.SideInput{Input: col2}) + beam.ParDo(s, &int64Check{ + Name: "iter sideinput check", + Want: []int{16, 17, 18}, + }, sum) + }, + }, { + name: "combine_perkey", + pipeline: func(s beam.Scope) { + imp := beam.Impulse(s) + in := beam.ParDo(s, dofn1kv, imp) + keyedsum := beam.CombinePerKey(s, combineIntSum, in) + sum := beam.DropKey(s, keyedsum) + beam.ParDo(s, &int64Check{ + Name: "combine", + Want: []int{6}, + }, sum) + }, + }, { + name: "combine_global", + pipeline: func(s beam.Scope) { + imp := beam.Impulse(s) + in := beam.ParDo(s, dofn1, imp) + sum := beam.Combine(s, combineIntSum, in) + beam.ParDo(s, &int64Check{ + Name: "combine", + Want: []int{6}, + }, sum) + }, + }, { + name: "sdf_single_split", + pipeline: func(s beam.Scope) { + configs := beam.Create(s, SourceConfig{NumElements: 10, InitialSplits: 1}) + in := beam.ParDo(s, &intRangeFn{}, configs) + beam.ParDo(s, &int64Check{ + Name: "sdf_single", + Want: []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + }, in) + }, + }, { + name: "WindowedSideInputs", + pipeline: primitives.ValidateWindowedSideInputs, + }, { + name: "WindowSums_GBK", + pipeline: primitives.WindowSums_GBK, + }, { + name: "WindowSums_Lifted", + pipeline: primitives.WindowSums_Lifted, + }, { + name: "ProcessContinuations_globalCombine", + pipeline: func(s beam.Scope) { + out := beam.ParDo(s, &selfCheckpointingDoFn{}, beam.Impulse(s)) + passert.Count(s, out, "num ints", 10) + }, + }, { + name: "flatten_to_sideInput", + pipeline: func(s beam.Scope) { + imp := beam.Impulse(s) + col1 := beam.ParDo(s, dofn1, imp) + col2 := beam.ParDo(s, dofn1, imp) + flat := beam.Flatten(s, col1, col2) + beam.ParDo(s, &int64Check{ + Name: "flatten check", + Want: []int{1, 1, 2, 2, 3, 3}, + }, flat) + passert.NonEmpty(s, flat) + }, + }, + } + // TODO: Explicit DoFn Failure case. + // TODO: Session windows, where some are not merged. + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + p, s := beam.NewPipelineWithRoot() + test.pipeline(s) + pr, err := executeWithT(context.Background(), t, p) + if err != nil { + t.Fatal(err) + } + if test.metrics != nil { + test.metrics(t, pr) + } + }) + } +} + +func TestRunner_Metrics(t *testing.T) { + initRunner(t) + t.Run("counter", func(t *testing.T) { + p, s := beam.NewPipelineWithRoot() + imp := beam.Impulse(s) + beam.ParDo(s, dofn1Counter, imp) + pr, err := executeWithT(context.Background(), t, p) + if err != nil { + t.Fatal(err) + } + qr := pr.Metrics().Query(func(sr metrics.SingleResult) bool { + return sr.Name() == "count" + }) + if got, want := qr.Counters()[0].Committed, int64(1); got != want { + t.Errorf("pr.Metrics.Query(Name = \"count\")).Committed = %v, want %v", got, want) + } + }) +} + +// TODO: PCollection metrics tests, in particular for element counts, in multi transform pipelines +// There's a doubling bug since we re-use the same pcollection IDs for the source & sink, and +// don't do any re-writing. 
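+//
+// TestMain below hooks this package's tests up to ptest, defaulting them
+// to the "testlocal" runner registered in init, which submits jobs to the
+// in-process prism job server stood up by initRunner.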
+
+func TestMain(m *testing.M) {
+ ptest.MainWithDefault(m, "testlocal")
+} diff --git a/sdks/go/pkg/beam/runners/prism/internal/handlecombine.go b/sdks/go/pkg/beam/runners/prism/internal/handlecombine.go new file mode 100644 index 0000000000000..ff9bd1e1c88a1 --- /dev/null +++ b/sdks/go/pkg/beam/runners/prism/internal/handlecombine.go @@ -0,0 +1,209 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+
+package internal
+
+import (
+ "fmt"
+ "reflect"
+
+ pipepb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v1"
+ "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/urns"
+ "google.golang.org/protobuf/proto"
+)
+
+// This file retains the logic for the combine handler.
+
+// CombineCharacteristic holds the configuration for Combines.
+type CombineCharacteristic struct {
+ EnableLifting bool // Sets whether a combine composite does combiner lifting or not.
+}
+
+// TODO figure out the factory we'd like.
+
+func Combine(config any) *combine {
+ return &combine{config: config.(CombineCharacteristic)}
+}
+
+// combine represents an instance of the combine handler.
+type combine struct {
+ config CombineCharacteristic
+}
+
+// ConfigURN returns the name for combine in the configuration file.
+func (*combine) ConfigURN() string {
+ return "combine"
+}
+
+func (*combine) ConfigCharacteristic() reflect.Type {
+ return reflect.TypeOf((*CombineCharacteristic)(nil)).Elem()
+}
+
+var _ transformPreparer = (*combine)(nil)
+
+func (*combine) PrepareUrns() []string {
+ return []string{urns.TransformCombinePerKey}
+}
+
+// PrepareTransform returns lifted combines and removes the leaves if enabled. Otherwise returns nothing.
+func (h *combine) PrepareTransform(tid string, t *pipepb.PTransform, comps *pipepb.Components) (*pipepb.Components, []string) {
+ // If we aren't lifting, the "default impl" for combines should be sufficient.
+ if !h.config.EnableLifting {
+ return nil, nil
+ }
+
+ // To lift a combine, the spec should contain a CombinePayload.
+ // That contains the actual FunctionSpec for the DoFn, and the
+ // id for the accumulator coder.
+ // We can synthetically produce/determine the remaining coders for
+ // the Input and Output types from the existing PCollections.
+ //
+ // This means we also need to synthesize pcollections with the accumulator coder too.
+
+ // What we have:
+ // Input PCol: KV<K, I>              -- INPUT
+ // -> GBK := KV<K, Iter<I>>          -- GROUPED_I
+ // -> Combine := KV<K, O>            -- OUTPUT
+ //
+ // What we want:
+ // Input PCol: KV<K, I>              -- INPUT
+ // -> PreCombine := KV<K, A>         -- LIFTED
+ // -> GBK := KV<K, Iter<A>>          -- GROUPED_A
+ // -> MergeAccumulators := KV<K, A>  -- MERGED_A
+ // -> ExtractOutput := KV<K, O>      -- OUTPUT
+ //
+ // First we need to produce new coders for Iter<A>, KV<K, Iter<A>>, and KV<K, A>.
+ // The A coder ID is in the combine payload.
+ //
+ // Then we can produce the PCollections.
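+ //
+ // As an illustrative example only (these IDs are assumptions, not real
+ // output): with tid "t5", key coder "ck", and accumulator coder "ca",
+ // the synthesized coder IDs below come out as "ct5_iter_ca" for Iter<A>,
+ // "ct5_kv_ck_ca" for KV<K, A>, and "ct5_kv_ck_iterca" for KV<K, Iter<A>>.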
+ // We can reuse the INPUT and OUTPUT PCollections.
+ // We need LIFTED to have KV<K, A>           kv_k_a
+ // We need GROUPED_A to have KV<K, Iter<A>>  kv_k_iter_a
+ // We need MERGED_A to have KV<K, A>         kv_k_a
+ //
+ // GROUPED_I ends up unused.
+ //
+ // The PCollections inherit the properties of the Input PCollection
+ // such as Boundedness, and Windowing Strategy.
+ //
+ // With these, we can produce the PTransforms with the appropriate URNs for the
+ // different parts of the composite, and return the new components.
+
+ cmbPayload := t.GetSpec().GetPayload()
+ cmb := &pipepb.CombinePayload{}
+ if err := (proto.UnmarshalOptions{}).Unmarshal(cmbPayload, cmb); err != nil {
+ panic(fmt.Sprintf("unable to decode CombinePayload for transform[%v]", t.GetUniqueName()))
+ }
+
+ // First let's get the key coder ID.
+ var pcolInID string
+ // There's only one input.
+ for _, pcol := range t.GetInputs() {
+ pcolInID = pcol
+ }
+ inputPCol := comps.GetPcollections()[pcolInID]
+ kvkiID := inputPCol.GetCoderId()
+ kID := comps.GetCoders()[kvkiID].GetComponentCoderIds()[0]
+
+ // Now we can start synthesis!
+ // Coder IDs
+ aID := cmb.AccumulatorCoderId
+
+ ckvprefix := "c" + tid + "_kv_"
+
+ iterACID := "c" + tid + "_iter_" + aID
+ kvkaCID := ckvprefix + kID + "_" + aID
+ kvkIterACID := ckvprefix + kID + "_iter" + aID
+
+ // PCollection IDs
+ nprefix := "n" + tid + "_"
+ liftedNID := nprefix + "lifted"
+ groupedNID := nprefix + "grouped"
+ mergedNID := nprefix + "merged"
+
+ // Now we need the output collection ID
+ var pcolOutID string
+ // There's only one output.
+ for _, pcol := range t.GetOutputs() {
+ pcolOutID = pcol
+ }
+
+ // Transform IDs
+ eprefix := "e" + tid + "_"
+ liftEID := eprefix + "lift"
+ gbkEID := eprefix + "gbk"
+ mergeEID := eprefix + "merge"
+ extractEID := eprefix + "extract"
+
+ coder := func(urn string, componentIDs ...string) *pipepb.Coder {
+ return &pipepb.Coder{
+ Spec: &pipepb.FunctionSpec{
+ Urn: urn,
+ },
+ ComponentCoderIds: componentIDs,
+ }
+ }
+
+ pcol := func(name, coderID string) *pipepb.PCollection {
+ return &pipepb.PCollection{
+ UniqueName: name,
+ CoderId: coderID,
+ IsBounded: inputPCol.GetIsBounded(),
+ WindowingStrategyId: inputPCol.GetWindowingStrategyId(),
+ }
+ }
+
+ tform := func(name, urn, in, out, env string) *pipepb.PTransform {
+ return &pipepb.PTransform{
+ UniqueName: name,
+ Spec: &pipepb.FunctionSpec{
+ Urn: urn,
+ Payload: cmbPayload,
+ },
+ Inputs: map[string]string{
+ "i0": in,
+ },
+ Outputs: map[string]string{
+ "i0": out,
+ },
+ EnvironmentId: env,
+ }
+ }
+
+ newComps := &pipepb.Components{
+ Coders: map[string]*pipepb.Coder{
+ iterACID: coder(urns.CoderIterable, aID),
+ kvkaCID: coder(urns.CoderKV, kID, aID),
+ kvkIterACID: coder(urns.CoderKV, kID, iterACID),
+ },
+ Pcollections: map[string]*pipepb.PCollection{
+ liftedNID: pcol(liftedNID, kvkaCID),
+ groupedNID: pcol(groupedNID, kvkIterACID),
+ mergedNID: pcol(mergedNID, kvkaCID),
+ },
+ Transforms: map[string]*pipepb.PTransform{
+ liftEID: tform(liftEID, urns.TransformPreCombine, pcolInID, liftedNID, t.GetEnvironmentId()),
+ gbkEID: tform(gbkEID, urns.TransformGBK, liftedNID, groupedNID, ""),
+ mergeEID: tform(mergeEID, urns.TransformMerge, groupedNID, mergedNID, t.GetEnvironmentId()),
+ extractEID: tform(extractEID, urns.TransformExtract, mergedNID, pcolOutID, t.GetEnvironmentId()),
+ },
+ }
+
+ // Now we return everything!
+ // TODO recurse through sub transforms to remove?
+ // We don't need to remove the composite, since we don't add it in
+ // when we return the new transforms, so it's not in the topology.
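+ // The second return value is read here as the list of transform IDs the
+ // preprocessor should remove; handing back the composite's subtransforms
+ // drops the SDK's own expansion in favor of the synthesized transforms above.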
+ return newComps, t.GetSubtransforms()
+} diff --git a/sdks/go/pkg/beam/runners/prism/internal/handlepardo.go b/sdks/go/pkg/beam/runners/prism/internal/handlepardo.go new file mode 100644 index 0000000000000..2ac5ca5bbf595 --- /dev/null +++ b/sdks/go/pkg/beam/runners/prism/internal/handlepardo.go @@ -0,0 +1,244 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+
+package internal
+
+import (
+ "fmt"
+ "reflect"
+
+ pipepb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v1"
+ "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/urns"
+ "golang.org/x/exp/maps"
+ "google.golang.org/protobuf/proto"
+)
+
+// This file retains the logic for the pardo handler.
+
+// ParDoCharacteristic holds the configuration for ParDos.
+type ParDoCharacteristic struct {
+ DisableSDF bool // Sets whether a pardo supports SDFs or not.
+}
+
+func ParDo(config any) *pardo {
+ return &pardo{config: config.(ParDoCharacteristic)}
+}
+
+// pardo represents an instance of the pardo handler.
+type pardo struct {
+ config ParDoCharacteristic
+}
+
+// ConfigURN returns the name for pardo in the configuration file.
+func (*pardo) ConfigURN() string {
+ return "pardo"
+}
+
+func (*pardo) ConfigCharacteristic() reflect.Type {
+ return reflect.TypeOf((*ParDoCharacteristic)(nil)).Elem()
+}
+
+var _ transformPreparer = (*pardo)(nil)
+
+func (*pardo) PrepareUrns() []string {
+ return []string{urns.TransformParDo}
+}
+
+// PrepareTransform handles special processing with respect to ParDos, since their handling is dependent on supported features
+// and requirements.
+func (h *pardo) PrepareTransform(tid string, t *pipepb.PTransform, comps *pipepb.Components) (*pipepb.Components, []string) {
+
+ // ParDos are a pain in the butt.
+ // Combines, by comparison, are dramatically simpler.
+ // This is because for ParDos, depending on how they are handled and what
+ // kinds of transforms are in and around the ParDo, the actual shape of the graph will change.
+ // At their simplest, it's something a DoFn will handle on its own.
+ // At their most complex, they require intimate interaction with the subgraph
+ // bundling process, the data layer, state layers, and control layers.
+ // But unlike combines, which have a clear urn for composite + special payload,
+ // ParDos have the standard URN for composites with the standard payload.
+ // So we always need to unmarshal the payload first.
+
+ pardoPayload := t.GetSpec().GetPayload()
+ pdo := &pipepb.ParDoPayload{}
+ if err := (proto.UnmarshalOptions{}).Unmarshal(pardoPayload, pdo); err != nil {
+ panic(fmt.Sprintf("unable to decode ParDoPayload for transform[%v]", t.GetUniqueName()))
+ }
+
+ // Let's check for and remove anything that makes things less simple.
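+ // That is, a ParDo with no window expiration timer, no bundle
+ // finalization, no stable or time sorted input requirement, no state,
+ // no timers, and no restriction coder (so not an SDF) needs no special
+ // handling here.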
+ if pdo.OnWindowExpirationTimerFamilySpec == "" &&
+ !pdo.RequestsFinalization &&
+ !pdo.RequiresStableInput &&
+ !pdo.RequiresTimeSortedInput &&
+ len(pdo.StateSpecs) == 0 &&
+ len(pdo.TimerFamilySpecs) == 0 &&
+ pdo.RestrictionCoderId == "" {
+ // Which inputs are side inputs doesn't change the graph further,
+ // so they're not checked here; nearly any ParDo can have them.
+
+ // At their simplest, we don't need to do anything special at pre-processing time, and simply pass through as normal.
+ return &pipepb.Components{
+ Transforms: map[string]*pipepb.PTransform{
+ tid: t,
+ },
+ }, nil
+ }
+
+ // Side inputs add to topology and make fusion harder to deal with
+ // (side input producers can't be in the same stage as their consumers).
+ // But we don't have fusion yet, so no worries.
+
+ // State, Timers, Stable Input, Time Sorted Input, and some parts of SDFs
+ // are easier to deal with by including a fusion break. We can do that with a
+ // runner specific transform for stable input, and another for time sorted
+ // input.
+
+ // SplittableDoFns have 3 required phases and a 4th optional phase.
+ //
+ // PAIR_WITH_RESTRICTION which pairs elements with their restrictions
+ // Input: element := INPUT
+ // Output: KV(element, restriction) := PWR
+ //
+ // SPLIT_AND_SIZE_RESTRICTIONS splits the pairs into sub element ranges
+ // and a relative size for each, in a float64 format.
+ // Input: KV(element, restriction) := PWR
+ // Output: KV(KV(element, restriction), float64) := SPLITnSIZED
+ //
+ // PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS actually processes the
+ // elements. This is also where splits need to be handled.
+ // In particular, primary and residual splits have the same format as the input.
+ // Input: KV(KV(element, restriction), size) := SPLITnSIZED
+ // Output: DoFn's output. := OUTPUT
+ //
+ // TRUNCATE_SIZED_RESTRICTION is how the runner has an SDK turn an
+ // unbounded transform into a bounded one. Not needed until the pipeline
+ // is told to drain.
+ // Input: KV(KV(element, restriction), float64) := synthetic split results from above
+ // Output: KV(KV(element, restriction), float64). := synthetic, truncated results sent as Split n Sized
+ //
+ // So with that, we can figure out the coders we need.
+ //
+ // cE - Element Coder (same as input coder)
+ // cR - Restriction Coder
+ // cS - Size Coder (float64)
+ // ckvER - KV<Element, Restriction>
+ // ckvERS - KV<KV<Element, Restriction>, Size>
+ //
+ // There could be a few output coders, but the outputs can be copied from
+ // the original transform directly.
+
+ // First let's get the parallel input coder ID.
+ var pcolInID, inputLocalID string
+ for localID, globalID := range t.GetInputs() {
+ // The parallel input is the one that isn't a side input.
+ if _, ok := pdo.SideInputs[localID]; !ok {
+ inputLocalID = localID
+ pcolInID = globalID
+ break
+ }
+ }
+ inputPCol := comps.GetPcollections()[pcolInID]
+ cEID := inputPCol.GetCoderId()
+ cRID := pdo.RestrictionCoderId
+ cSID := "c" + tid + "size"
+ ckvERID := "c" + tid + "kv_ele_rest"
+ ckvERSID := ckvERID + "_size"
+
+ coder := func(urn string, componentIDs ...string) *pipepb.Coder {
+ return &pipepb.Coder{
+ Spec: &pipepb.FunctionSpec{
+ Urn: urn,
+ },
+ ComponentCoderIds: componentIDs,
+ }
+ }
+
+ coders := map[string]*pipepb.Coder{
+ ckvERID: coder(urns.CoderKV, cEID, cRID),
+ cSID: coder(urns.CoderDouble),
+ ckvERSID: coder(urns.CoderKV, ckvERID, cSID),
+ }
+
+ // PCollections only have two new ones.
+ // INPUT -> same as ordinary DoFn + // PWR, uses ckvER + // SPLITnSIZED, uses ckvERS + // OUTPUT -> same as ordinary outputs + + nPWRID := "n" + tid + "_pwr" + nSPLITnSIZEDID := "n" + tid + "_splitnsized" + + pcol := func(name, coderID string) *pipepb.PCollection { + return &pipepb.PCollection{ + UniqueName: name, + CoderId: coderID, + IsBounded: inputPCol.GetIsBounded(), + WindowingStrategyId: inputPCol.GetWindowingStrategyId(), + } + } + + pcols := map[string]*pipepb.PCollection{ + nPWRID: pcol(nPWRID, ckvERID), + nSPLITnSIZEDID: pcol(nSPLITnSIZEDID, ckvERSID), + } + + // PTransforms have 3 new ones, with process sized elements and restrictions + // taking the brunt of the complexity, consuming the inputs + + ePWRID := "e" + tid + "_pwr" + eSPLITnSIZEDID := "e" + tid + "_splitnsize" + eProcessID := "e" + tid + "_processandsplit" + + tform := func(name, urn, in, out string) *pipepb.PTransform { + return &pipepb.PTransform{ + UniqueName: name, + Spec: &pipepb.FunctionSpec{ + Urn: urn, + Payload: pardoPayload, + }, + Inputs: map[string]string{ + inputLocalID: in, + }, + Outputs: map[string]string{ + "i0": out, + }, + EnvironmentId: t.GetEnvironmentId(), + } + } + + newInputs := maps.Clone(t.GetInputs()) + newInputs[inputLocalID] = nSPLITnSIZEDID + + tforms := map[string]*pipepb.PTransform{ + ePWRID: tform(ePWRID, urns.TransformPairWithRestriction, pcolInID, nPWRID), + eSPLITnSIZEDID: tform(eSPLITnSIZEDID, urns.TransformSplitAndSize, nPWRID, nSPLITnSIZEDID), + eProcessID: { + UniqueName: eProcessID, + Spec: &pipepb.FunctionSpec{ + Urn: urns.TransformProcessSizedElements, + Payload: pardoPayload, + }, + Inputs: newInputs, + Outputs: t.GetOutputs(), + EnvironmentId: t.GetEnvironmentId(), + }, + } + + return &pipepb.Components{ + Coders: coders, + Pcollections: pcols, + Transforms: tforms, + }, t.GetSubtransforms() +} diff --git a/sdks/go/pkg/beam/runners/prism/internal/handlerunner.go b/sdks/go/pkg/beam/runners/prism/internal/handlerunner.go new file mode 100644 index 0000000000000..e841620625e97 --- /dev/null +++ b/sdks/go/pkg/beam/runners/prism/internal/handlerunner.go @@ -0,0 +1,298 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+package internal
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "reflect"
+ "sort"
+
+ "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/coder"
+ "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/mtime"
+ "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/window"
+ "github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/exec"
+ "github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex"
+ pipepb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v1"
+ "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/engine"
+ "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/urns"
+ "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/worker"
+ "golang.org/x/exp/slog"
+ "google.golang.org/protobuf/encoding/prototext"
+ "google.golang.org/protobuf/proto"
+)
+
+// This file retains the logic for the runner handler.
+
+// RunnerCharacteristic holds the configuration for Runner based transforms,
+// such as GBKs, Flattens.
+type RunnerCharacteristic struct {
+ SDKFlatten bool // Sets whether we should force an SDK side flatten.
+ SDKGBK bool // Sets whether the GBK should be handled by the SDK, if possible.
+}
+
+func Runner(config any) *runner {
+ return &runner{config: config.(RunnerCharacteristic)}
+}
+
+// runner represents an instance of the runner transform handler.
+type runner struct {
+ config RunnerCharacteristic
+}
+
+// ConfigURN returns the name for runner in the configuration file.
+func (*runner) ConfigURN() string {
+ return "runner"
+}
+
+func (*runner) ConfigCharacteristic() reflect.Type {
+ return reflect.TypeOf((*RunnerCharacteristic)(nil)).Elem()
+}
+
+var _ transformExecuter = (*runner)(nil)
+
+func (*runner) ExecuteUrns() []string {
+ return []string{urns.TransformFlatten, urns.TransformGBK}
+}
+
+// ExecuteWith returns which environment the transform should execute in.
+// The empty string means the transform is executed runner side.
+func (h *runner) ExecuteWith(t *pipepb.PTransform) string {
+ urn := t.GetSpec().GetUrn()
+ if urn == urns.TransformFlatten && !h.config.SDKFlatten {
+ return ""
+ }
+ if urn == urns.TransformGBK && !h.config.SDKGBK {
+ return ""
+ }
+ return t.GetEnvironmentId()
+}
+
+// ExecuteTransform handles special processing with respect to runner specific transforms.
+func (h *runner) ExecuteTransform(tid string, t *pipepb.PTransform, comps *pipepb.Components, watermark mtime.Time, inputData [][]byte) *worker.B {
+ urn := t.GetSpec().GetUrn()
+ var data [][]byte
+ var onlyOut string
+ for _, out := range t.GetOutputs() {
+ onlyOut = out
+ }
+
+ switch urn {
+ case urns.TransformFlatten:
+ // Already done and collated.
+ data = inputData
+
+ case urns.TransformGBK:
+ ws := windowingStrategy(comps, tid)
+ kvc := onlyInputCoderForTransform(comps, tid)
+
+ coders := map[string]*pipepb.Coder{}
+
+ // TODO assert this is a KV. It's probably fine, but we should fail anyway.
+ wcID := lpUnknownCoders(ws.GetWindowCoderId(), coders, comps.GetCoders())
+ kcID := lpUnknownCoders(kvc.GetComponentCoderIds()[0], coders, comps.GetCoders())
+ ecID := lpUnknownCoders(kvc.GetComponentCoderIds()[1], coders, comps.GetCoders())
+ reconcileCoders(coders, comps.GetCoders())
+
+ wc := coders[wcID]
+ kc := coders[kcID]
+ ec := coders[ecID]
+
+ data = append(data, gbkBytes(ws, wc, kc, ec, inputData, coders, watermark))
+ if len(data[0]) == 0 {
+ panic("no data for GBK")
+ }
+ default:
+ panic(fmt.Sprintf("unimplemented runner transform[%v]", urn))
+ }
+
+ // To avoid conflicts with these single transform
+ // bundles, we suffix the transform IDs.
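+ // For example (illustrative IDs only): a GBK transform "e4" with output
+ // local name "i0" has its bundle read from "e4_i0", so the synthetic
+ // source can't collide with the transform ID itself.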
+ var localID string
+ for key := range t.GetOutputs() {
+ localID = key
+ }
+
+ if localID == "" {
+ panic(fmt.Sprintf("bad transform: %v", prototext.Format(t)))
+ }
+ output := engine.TentativeData{}
+ for _, d := range data {
+ output.WriteData(onlyOut, d)
+ }
+
+ dataID := tid + "_" + localID // The ID the consumer will read from.
+ b := &worker.B{
+ InputTransformID: dataID,
+ SinkToPCollection: map[string]string{
+ dataID: onlyOut,
+ },
+ OutputData: output,
+ }
+ return b
+}
+
+// windowingStrategy sources the transform's windowing strategy from a single parallel input.
+func windowingStrategy(comps *pipepb.Components, tid string) *pipepb.WindowingStrategy {
+ t := comps.GetTransforms()[tid]
+ var inputPColID string
+ for _, pcolID := range t.GetInputs() {
+ inputPColID = pcolID
+ }
+ pcol := comps.GetPcollections()[inputPColID]
+ return comps.GetWindowingStrategies()[pcol.GetWindowingStrategyId()]
+}
+
+// gbkBytes re-encodes gbk inputs in a gbk result.
+func gbkBytes(ws *pipepb.WindowingStrategy, wc, kc, vc *pipepb.Coder, toAggregate [][]byte, coders map[string]*pipepb.Coder, watermark mtime.Time) []byte {
+ var outputTime func(typex.Window, mtime.Time) mtime.Time
+ switch ws.GetOutputTime() {
+ case pipepb.OutputTime_END_OF_WINDOW:
+ outputTime = func(w typex.Window, et mtime.Time) mtime.Time {
+ return w.MaxTimestamp()
+ }
+ default:
+ // TODO need to correct session logic if output time is different.
+ panic(fmt.Sprintf("unsupported OutputTime behavior: %v", ws.GetOutputTime()))
+ }
+ wDec, wEnc := makeWindowCoders(wc)
+
+ type keyTime struct {
+ key []byte
+ w typex.Window
+ time mtime.Time
+ values [][]byte
+ }
+ // Map each window to a map from key to that key's aggregated values and output time.
+ // We ultimately emit the window, the key, the time, and the iterable of elements,
+ // all contained in the final value.
+ windows := map[typex.Window]map[string]keyTime{}
+
+ kd := pullDecoder(kc, coders)
+ vd := pullDecoder(vc, coders)
+
+ // We need the key coder and the element coder,
+ // because we'll need to pull out anything the runner knows how to deal with.
+ // And repeat.
+ for _, data := range toAggregate {
+ // Parse out each element's data, and repeat.
+ buf := bytes.NewBuffer(data)
+ for {
+ ws, tm, _, err := exec.DecodeWindowedValueHeader(wDec, buf)
+ if err == io.EOF {
+ break
+ }
+ if err != nil {
+ panic(fmt.Sprintf("can't decode windowed value header with %v: %v", wc, err))
+ }
+
+ keyByt := kd(buf)
+ key := string(keyByt)
+ value := vd(buf)
+ for _, w := range ws {
+ ft := outputTime(w, tm)
+ wk, ok := windows[w]
+ if !ok {
+ wk = make(map[string]keyTime)
+ windows[w] = wk
+ }
+ kt := wk[key]
+ kt.time = ft
+ kt.key = keyByt
+ kt.w = w
+ kt.values = append(kt.values, value)
+ wk[key] = kt
+ }
+ }
+ }
+
+ // If the strategy is session windows, then we need to get all the windows, sort them
+ // and see which ones need to be merged together.
+ if ws.GetWindowFn().GetUrn() == urns.WindowFnSession {
+ slog.Debug("sorting by session window")
+ session := &pipepb.SessionWindowsPayload{}
+ if err := (proto.UnmarshalOptions{}).Unmarshal(ws.GetWindowFn().GetPayload(), session); err != nil {
+ panic("unable to decode SessionWindowsPayload")
+ }
+ gapSize := mtime.Time(session.GetGapSize().AsDuration())
+
+ ordered := make([]window.IntervalWindow, 0, len(windows))
+ for k := range windows {
+ ordered = append(ordered, k.(window.IntervalWindow))
+ }
+ // Use a decreasing sort (latest to earliest) so we can correct
+ // the output timestamp to the new end of window immediately.
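+ // Illustrative walkthrough (assuming a 10 minute gap): with windows
+ // sorted latest first as [20:00, 20:05), [19:50, 19:58), [19:00, 19:10),
+ // the second merges into the first because 19:58 plus the gap reaches
+ // past 20:00, yielding [19:50, 20:05), while [19:00, 19:10) falls short
+ // and starts a new session.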
+ // TODO need to correct this if output time is different.
+ sort.Slice(ordered, func(i, j int) bool {
+ return ordered[i].MaxTimestamp() > ordered[j].MaxTimestamp()
+ })
+
+ cur := ordered[0]
+ sessionData := windows[cur]
+ for _, iw := range ordered[1:] {
+ // If the next window is outside the gap of the current session,
+ // they can't merge, so start a new session. Otherwise merge the data.
+ if iw.End+gapSize < cur.Start {
+ // Start a new session.
+ windows[cur] = sessionData
+ cur = iw
+ sessionData = windows[iw]
+ continue
+ }
+ // Extend the session
+ cur.Start = iw.Start
+ toMerge := windows[iw]
+ delete(windows, iw)
+ for k, kt := range toMerge {
+ skt := sessionData[k]
+ skt.key = kt.key
+ skt.w = cur
+ skt.values = append(skt.values, kt.values...)
+ sessionData[k] = skt
+ }
+ }
+ }
+ // Everything's aggregated!
+ // Time to turn things into a windowed KV<K, Iterable<V>>.
+
+ var buf bytes.Buffer
+ for _, w := range windows {
+ for _, kt := range w {
+ exec.EncodeWindowedValueHeader(
+ wEnc,
+ []typex.Window{kt.w},
+ kt.time,
+ typex.NoFiringPane(),
+ &buf,
+ )
+ buf.Write(kt.key)
+ coder.EncodeInt32(int32(len(kt.values)), &buf)
+ for _, value := range kt.values {
+ buf.Write(value)
+ }
+ }
+ }
+ return buf.Bytes()
+}
+
+func onlyInputCoderForTransform(comps *pipepb.Components, tid string) *pipepb.Coder {
+ t := comps.GetTransforms()[tid]
+ var inputPColID string
+ for _, pcolID := range t.GetInputs() {
+ inputPColID = pcolID
+ }
+ pcol := comps.GetPcollections()[inputPColID]
+ return comps.GetCoders()[pcol.GetCoderId()]
+} diff --git a/sdks/go/pkg/beam/runners/prism/internal/jobservices/artifact.go b/sdks/go/pkg/beam/runners/prism/internal/jobservices/artifact.go new file mode 100644 index 0000000000000..e66def5b0fe86 --- /dev/null +++ b/sdks/go/pkg/beam/runners/prism/internal/jobservices/artifact.go @@ -0,0 +1,81 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +package jobservices + +import ( + "fmt" + "io" + + jobpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/jobmanagement_v1" + "golang.org/x/exp/slog" +) + +func (s *Server) ReverseArtifactRetrievalService(stream jobpb.ArtifactStagingService_ReverseArtifactRetrievalServiceServer) error { + in, err := stream.Recv() + if err == io.EOF { + return nil + } + if err != nil { + return err + } + job := s.jobs[in.GetStagingToken()] + + envs := job.Pipeline.GetComponents().GetEnvironments() + for _, env := range envs { + for _, dep := range env.GetDependencies() { + slog.Debug("GetArtifact start", + slog.Group("dep", + slog.String("urn", dep.GetTypeUrn()), + slog.String("payload", string(dep.GetTypePayload())))) + stream.Send(&jobpb.ArtifactRequestWrapper{ + Request: &jobpb.ArtifactRequestWrapper_GetArtifact{ + GetArtifact: &jobpb.GetArtifactRequest{ + Artifact: dep, + }, + }, + }) + var count int + for { + in, err := stream.Recv() + if err == io.EOF { + return nil + } + if err != nil { + return err + } + if in.IsLast { + slog.Debug("GetArtifact finish", + slog.Group("dep", + slog.String("urn", dep.GetTypeUrn()), + slog.String("payload", string(dep.GetTypePayload()))), + slog.Int("bytesReceived", count)) + break + } + // Here's where we go through each environment's artifacts. + // We do nothing with them. + switch req := in.GetResponse().(type) { + case *jobpb.ArtifactResponseWrapper_GetArtifactResponse: + count += len(req.GetArtifactResponse.GetData()) + case *jobpb.ArtifactResponseWrapper_ResolveArtifactResponse: + err := fmt.Errorf("unexpected ResolveArtifactResponse to GetArtifact: %v", in.GetResponse()) + slog.Error("GetArtifact failure", err) + return err + } + } + } + } + return nil +} diff --git a/sdks/go/pkg/beam/runners/prism/internal/jobservices/job.go b/sdks/go/pkg/beam/runners/prism/internal/jobservices/job.go new file mode 100644 index 0000000000000..95b1ce12af93e --- /dev/null +++ b/sdks/go/pkg/beam/runners/prism/internal/jobservices/job.go @@ -0,0 +1,120 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package jobservices handles services necessary WRT handling jobs from +// SDKs. Nominally this is the entry point for most users, and a job's +// external interactions outside of pipeline execution. +// +// This includes handling receiving, staging, and provisioning artifacts, +// and orchestrating external workers, such as for loopback mode. +// +// Execution of jobs is abstracted away to an execute function specified +// at server construction time. 
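+//
+// A Job reports its lifecycle over the Job Management API states used
+// below: STOPPED once prepared, then STARTING, RUNNING, and finally a
+// terminal DONE or FAILED.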
+package jobservices + +import ( + "context" + "fmt" + "sort" + "strings" + "sync/atomic" + + fnpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/fnexecution_v1" + jobpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/jobmanagement_v1" + pipepb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v1" + "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/urns" + "golang.org/x/exp/slog" + "google.golang.org/protobuf/types/known/structpb" +) + +var capabilities = map[string]struct{}{ + urns.RequirementSplittableDoFn: {}, +} + +// TODO, move back to main package, and key off of executor handlers? +// Accept whole pipeline instead, and look at every PTransform too. +func isSupported(requirements []string) error { + var unsupported []string + for _, req := range requirements { + if _, ok := capabilities[req]; !ok { + unsupported = append(unsupported, req) + } + } + if len(unsupported) > 0 { + sort.Strings(unsupported) + return fmt.Errorf("local runner doesn't support the following required features: %v", strings.Join(unsupported, ",")) + } + return nil +} + +// Job is an interface to the job services for executing pipelines. +// It allows the executor to communicate status, messages, and metrics +// back to callers of the Job Management API. +type Job struct { + key string + jobName string + + Pipeline *pipepb.Pipeline + options *structpb.Struct + + // Management side concerns. + msgChan chan string + state atomic.Value // jobpb.JobState_Enum + stateChan chan jobpb.JobState_Enum + + // Context used to terminate this job. + RootCtx context.Context + CancelFn context.CancelFunc + + metrics metricsStore +} + +func (j *Job) ContributeMetrics(payloads *fnpb.ProcessBundleResponse) { + j.metrics.ContributeMetrics(payloads) +} + +func (j *Job) String() string { + return fmt.Sprintf("%v[%v]", j.key, j.jobName) +} + +func (j *Job) LogValue() slog.Value { + return slog.GroupValue( + slog.String("key", j.key), + slog.String("name", j.jobName)) +} + +func (j *Job) SendMsg(msg string) { + j.msgChan <- msg +} + +// Start indicates that the job is preparing to execute. +func (j *Job) Start() { + j.stateChan <- jobpb.JobState_STARTING +} + +// Running indicates that the job is executing. +func (j *Job) Running() { + j.stateChan <- jobpb.JobState_RUNNING +} + +// Done indicates that the job completed successfully. +func (j *Job) Done() { + j.stateChan <- jobpb.JobState_DONE +} + +// Failed indicates that the job completed unsuccessfully. +func (j *Job) Failed() { + j.stateChan <- jobpb.JobState_FAILED +} diff --git a/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go b/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go new file mode 100644 index 0000000000000..af6c8c71a1d99 --- /dev/null +++ b/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go @@ -0,0 +1,142 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +package jobservices + +import ( + "context" + "fmt" + "sync/atomic" + + jobpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/jobmanagement_v1" + pipepb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v1" + "golang.org/x/exp/slog" +) + +func (s *Server) nextId() string { + v := atomic.AddUint32(&s.index, 1) + return fmt.Sprintf("job-%03d", v) +} + +func (s *Server) Prepare(ctx context.Context, req *jobpb.PrepareJobRequest) (*jobpb.PrepareJobResponse, error) { + s.mu.Lock() + defer s.mu.Unlock() + + // Since jobs execute in the background, they should not be tied to a request's context. + rootCtx, cancelFn := context.WithCancel(context.Background()) + job := &Job{ + key: s.nextId(), + Pipeline: req.GetPipeline(), + jobName: req.GetJobName(), + options: req.GetPipelineOptions(), + + msgChan: make(chan string, 100), + stateChan: make(chan jobpb.JobState_Enum, 1), + RootCtx: rootCtx, + CancelFn: cancelFn, + } + + // Queue initial state of the job. + job.state.Store(jobpb.JobState_STOPPED) + job.stateChan <- job.state.Load().(jobpb.JobState_Enum) + + if err := isSupported(job.Pipeline.GetRequirements()); err != nil { + slog.Error("unable to run job", err, slog.String("jobname", req.GetJobName())) + return nil, err + } + s.jobs[job.key] = job + return &jobpb.PrepareJobResponse{ + PreparationId: job.key, + StagingSessionToken: job.key, + ArtifactStagingEndpoint: &pipepb.ApiServiceDescriptor{ + Url: s.Endpoint(), + }, + }, nil +} + +func (s *Server) Run(ctx context.Context, req *jobpb.RunJobRequest) (*jobpb.RunJobResponse, error) { + s.mu.Lock() + job := s.jobs[req.GetPreparationId()] + s.mu.Unlock() + + // Bring up a background goroutine to allow the job to continue processing. + go s.execute(job) + + return &jobpb.RunJobResponse{ + JobId: job.key, + }, nil +} + +// GetMessageStream subscribes to a stream of state changes and messages from the job +func (s *Server) GetMessageStream(req *jobpb.JobMessagesRequest, stream jobpb.JobService_GetMessageStreamServer) error { + s.mu.Lock() + job := s.jobs[req.GetJobId()] + s.mu.Unlock() + + for { + select { + case msg := <-job.msgChan: + stream.Send(&jobpb.JobMessagesResponse{ + Response: &jobpb.JobMessagesResponse_MessageResponse{ + MessageResponse: &jobpb.JobMessage{ + MessageText: msg, + Importance: jobpb.JobMessage_JOB_MESSAGE_BASIC, + }, + }, + }) + + case state, ok := <-job.stateChan: + // TODO: Don't block job execution if WaitForCompletion isn't being run. + // The state channel means the job may only execute if something is observing + // the message stream, as the send on the state or message channel may block + // once full. + // Not a problem for tests or short lived batch, but would be hazardous for + // asynchronous jobs. + + // Channel is closed, so the job must be done. + if !ok { + state = jobpb.JobState_DONE + } + job.state.Store(state) + stream.Send(&jobpb.JobMessagesResponse{ + Response: &jobpb.JobMessagesResponse_StateResponse{ + StateResponse: &jobpb.JobStateEvent{ + State: state, + }, + }, + }) + switch state { + case jobpb.JobState_CANCELLED, jobpb.JobState_DONE, jobpb.JobState_DRAINED, jobpb.JobState_FAILED, jobpb.JobState_UPDATED: + // Reached terminal state. + return nil + } + } + } + +} + +// GetJobMetrics Fetch metrics for a given job. 
+func (s *Server) GetJobMetrics(ctx context.Context, req *jobpb.GetJobMetricsRequest) (*jobpb.GetJobMetricsResponse, error) {
+ j := s.getJob(req.GetJobId())
+ if j == nil {
+ return nil, fmt.Errorf("GetJobMetrics: unknown jobID: %v", req.GetJobId())
+ }
+ return &jobpb.GetJobMetricsResponse{
+ Metrics: &jobpb.MetricResults{
+ Attempted: j.metrics.Results(tentative),
+ Committed: j.metrics.Results(committed),
+ },
+ }, nil
+} diff --git a/sdks/go/pkg/beam/runners/prism/internal/jobservices/metrics.go b/sdks/go/pkg/beam/runners/prism/internal/jobservices/metrics.go new file mode 100644 index 0000000000000..39936bae72f19 --- /dev/null +++ b/sdks/go/pkg/beam/runners/prism/internal/jobservices/metrics.go @@ -0,0 +1,495 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+
+package jobservices
+
+import (
+ "bytes"
+ "fmt"
+ "hash/maphash"
+ "math"
+ "sort"
+ "sync"
+
+ "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/coder"
+ "github.com/apache/beam/sdks/v2/go/pkg/beam/core/metrics"
+ fnpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/fnexecution_v1"
+ pipepb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v1"
+ "golang.org/x/exp/constraints"
+ "golang.org/x/exp/slog"
+ "google.golang.org/protobuf/proto"
+ "google.golang.org/protobuf/reflect/protoreflect"
+)
+
+type labelsToKeyFunc func(string, map[string]string) metricKey
+
+type urnOps struct {
+ // keyFn produces the key for this metric from the labels,
+ // based on the required label set for the metric from its spec.
+ keyFn labelsToKeyFunc
+ // newAccum produces an accumulator assuming we don't have an accumulator for it already,
+ // based on the type urn of the metric from its spec.
+ newAccum accumFactory
+}
+
+var (
+ mUrn2Ops = map[string]urnOps{}
+)
+
+func init() {
+ mUrn2Spec := map[string]*pipepb.MonitoringInfoSpec{}
+ specs := (pipepb.MonitoringInfoSpecs_Enum)(0).Descriptor().Values()
+ for i := 0; i < specs.Len(); i++ {
+ enum := specs.ByNumber(protoreflect.EnumNumber(i))
+ spec := proto.GetExtension(enum.Options(), pipepb.E_MonitoringInfoSpec).(*pipepb.MonitoringInfoSpec)
+ mUrn2Spec[spec.GetUrn()] = spec
+ }
+ mUrn2Ops = buildUrnToOpsMap(mUrn2Spec)
+}
+
+// Should probably just construct a slice or map to get the urns out
+// since we'll ultimately be using them a lot.
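+//
+// For example, getMetTyp(pipepb.MonitoringInfoTypeUrns_SUM_INT64_TYPE)
+// resolves the enum's beam_urn extension to the type string stored in
+// MonitoringInfo.Type, e.g. "beam:metrics:sum_int64:v1".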
+var metTyps = (pipepb.MonitoringInfoTypeUrns_Enum)(0).Descriptor().Values() + +func getMetTyp(t pipepb.MonitoringInfoTypeUrns_Enum) string { + return proto.GetExtension(metTyps.ByNumber(protoreflect.EnumNumber(t)).Options(), pipepb.E_BeamUrn).(string) +} + +func buildUrnToOpsMap(mUrn2Spec map[string]*pipepb.MonitoringInfoSpec) map[string]urnOps { + var hasher maphash.Hash + + props := (pipepb.MonitoringInfo_MonitoringInfoLabels)(0).Descriptor().Values() + getProp := func(l pipepb.MonitoringInfo_MonitoringInfoLabels) string { + return proto.GetExtension(props.ByNumber(protoreflect.EnumNumber(l)).Options(), pipepb.E_LabelProps).(*pipepb.MonitoringInfoLabelProps).GetName() + } + + l2func := make(map[uint64]labelsToKeyFunc) + labelsToKey := func(required []pipepb.MonitoringInfo_MonitoringInfoLabels, fn labelsToKeyFunc) { + hasher.Reset() + // We need the string versions of things to sort against + // for consistent hashing. + var req []string + for _, l := range required { + v := getProp(l) + req = append(req, v) + } + sort.Strings(req) + for _, v := range req { + hasher.WriteString(v) + } + key := hasher.Sum64() + l2func[key] = fn + } + ls := func(ls ...pipepb.MonitoringInfo_MonitoringInfoLabels) []pipepb.MonitoringInfo_MonitoringInfoLabels { + return ls + } + + ptransformLabel := getProp(pipepb.MonitoringInfo_TRANSFORM) + namespaceLabel := getProp(pipepb.MonitoringInfo_NAMESPACE) + nameLabel := getProp(pipepb.MonitoringInfo_NAME) + pcollectionLabel := getProp(pipepb.MonitoringInfo_PCOLLECTION) + statusLabel := getProp(pipepb.MonitoringInfo_STATUS) + serviceLabel := getProp(pipepb.MonitoringInfo_SERVICE) + resourceLabel := getProp(pipepb.MonitoringInfo_RESOURCE) + methodLabel := getProp(pipepb.MonitoringInfo_METHOD) + + // Here's where we build the raw map from kinds of labels to the actual functions. + labelsToKey(ls(pipepb.MonitoringInfo_TRANSFORM, + pipepb.MonitoringInfo_NAMESPACE, + pipepb.MonitoringInfo_NAME), + func(urn string, labels map[string]string) metricKey { + return userMetricKey{ + urn: urn, + ptransform: labels[ptransformLabel], + namespace: labels[namespaceLabel], + name: labels[nameLabel], + } + }) + labelsToKey(ls(pipepb.MonitoringInfo_TRANSFORM), + func(urn string, labels map[string]string) metricKey { + return ptransformKey{ + urn: urn, + ptransform: labels[ptransformLabel], + } + }) + labelsToKey(ls(pipepb.MonitoringInfo_PCOLLECTION), + func(urn string, labels map[string]string) metricKey { + return pcollectionKey{ + urn: urn, + pcollection: labels[pcollectionLabel], + } + }) + labelsToKey(ls(pipepb.MonitoringInfo_SERVICE, + pipepb.MonitoringInfo_METHOD, + pipepb.MonitoringInfo_RESOURCE, + pipepb.MonitoringInfo_TRANSFORM, + pipepb.MonitoringInfo_STATUS), + func(urn string, labels map[string]string) metricKey { + return apiRequestKey{ + urn: urn, + service: labels[serviceLabel], + method: labels[methodLabel], + resource: labels[resourceLabel], + ptransform: labels[ptransformLabel], + status: labels[statusLabel], + } + }) + labelsToKey(ls(pipepb.MonitoringInfo_SERVICE, + pipepb.MonitoringInfo_METHOD, + pipepb.MonitoringInfo_RESOURCE, + pipepb.MonitoringInfo_TRANSFORM), + func(urn string, labels map[string]string) metricKey { + return apiRequestLatenciesKey{ + urn: urn, + service: labels[serviceLabel], + method: labels[methodLabel], + resource: labels[resourceLabel], + ptransform: labels[ptransformLabel], + } + }) + + // Specify accumulator decoders for all the metric types. 
+ // These are a combination of the decoder (accepting the payload bytes) + // and represent how we hold onto them. Ultimately, these will also be + // able to extract back out to the protos. + + typ2accumFac := map[string]accumFactory{ + getMetTyp(pipepb.MonitoringInfoTypeUrns_SUM_INT64_TYPE): func() metricAccumulator { return &sumInt64{} }, + getMetTyp(pipepb.MonitoringInfoTypeUrns_SUM_DOUBLE_TYPE): func() metricAccumulator { return &sumFloat64{} }, + getMetTyp(pipepb.MonitoringInfoTypeUrns_DISTRIBUTION_INT64_TYPE): func() metricAccumulator { + // Defaults should be safe since the metric only exists if we get any values at all. + return &distributionInt64{dist: metrics.DistributionValue{Min: math.MaxInt64, Max: math.MinInt64}} + }, + getMetTyp(pipepb.MonitoringInfoTypeUrns_PROGRESS_TYPE): func() metricAccumulator { return &progress{} }, + } + + ret := make(map[string]urnOps) + for urn, spec := range mUrn2Spec { + hasher.Reset() + sorted := spec.GetRequiredLabels() + sort.Strings(sorted) + for _, l := range sorted { + hasher.WriteString(l) + } + key := hasher.Sum64() + fn, ok := l2func[key] + if !ok { + slog.Debug("unknown MonitoringSpec required Labels", + slog.String("urn", spec.GetType()), + slog.String("key", spec.GetType()), + slog.Any("sortedlabels", sorted)) + continue + } + fac, ok := typ2accumFac[spec.GetType()] + if !ok { + slog.Debug("unknown MonitoringSpec type") + continue + } + ret[urn] = urnOps{ + keyFn: fn, + newAccum: fac, + } + } + return ret +} + +type sumInt64 struct { + sum int64 +} + +func (m *sumInt64) accumulate(pyld []byte) error { + v, err := coder.DecodeVarInt(bytes.NewBuffer(pyld)) + if err != nil { + return err + } + m.sum += v + return nil +} + +func (m *sumInt64) toProto(key metricKey) *pipepb.MonitoringInfo { + var buf bytes.Buffer + coder.EncodeVarInt(m.sum, &buf) + return &pipepb.MonitoringInfo{ + Urn: key.Urn(), + Type: getMetTyp(pipepb.MonitoringInfoTypeUrns_SUM_INT64_TYPE), + Payload: buf.Bytes(), + Labels: key.Labels(), + } +} + +type sumFloat64 struct { + sum float64 +} + +func (m *sumFloat64) accumulate(pyld []byte) error { + v, err := coder.DecodeDouble(bytes.NewBuffer(pyld)) + if err != nil { + return err + } + m.sum += v + return nil +} + +func (m *sumFloat64) toProto(key metricKey) *pipepb.MonitoringInfo { + var buf bytes.Buffer + coder.EncodeDouble(m.sum, &buf) + return &pipepb.MonitoringInfo{ + Urn: key.Urn(), + Type: getMetTyp(pipepb.MonitoringInfoTypeUrns_SUM_DOUBLE_TYPE), + Payload: buf.Bytes(), + Labels: key.Labels(), + } +} + +type progress struct { + snap []float64 +} + +func (m *progress) accumulate(pyld []byte) error { + buf := bytes.NewBuffer(pyld) + // Assuming known length iterable + n, err := coder.DecodeInt32(buf) + if err != nil { + return err + } + progs := make([]float64, 0, n) + for i := int32(0); i < n; i++ { + v, err := coder.DecodeDouble(buf) + if err != nil { + return err + } + progs = append(progs, v) + } + m.snap = progs + return nil +} + +func (m *progress) toProto(key metricKey) *pipepb.MonitoringInfo { + var buf bytes.Buffer + coder.EncodeInt32(int32(len(m.snap)), &buf) + for _, v := range m.snap { + coder.EncodeDouble(v, &buf) + } + return &pipepb.MonitoringInfo{ + Urn: key.Urn(), + Type: getMetTyp(pipepb.MonitoringInfoTypeUrns_PROGRESS_TYPE), + Payload: buf.Bytes(), + Labels: key.Labels(), + } +} + +func ordMin[T constraints.Ordered](a T, b T) T { + if a < b { + return a + } + return b +} + +func ordMax[T constraints.Ordered](a T, b T) T { + if a > b { + return a + } + return b +} + +type distributionInt64 struct { 
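+ // dist accumulates count, sum, min, and max across contributions.
+ // Min and Max start at math.MaxInt64 and math.MinInt64 in the factory
+ // above, so the first contribution always wins the comparison.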
+ dist metrics.DistributionValue +} + +func (m *distributionInt64) accumulate(pyld []byte) error { + buf := bytes.NewBuffer(pyld) + var dist metrics.DistributionValue + var err error + if dist.Count, err = coder.DecodeVarInt(buf); err != nil { + return err + } + if dist.Sum, err = coder.DecodeVarInt(buf); err != nil { + return err + } + if dist.Min, err = coder.DecodeVarInt(buf); err != nil { + return err + } + if dist.Max, err = coder.DecodeVarInt(buf); err != nil { + return err + } + m.dist = metrics.DistributionValue{ + Count: m.dist.Count + dist.Count, + Sum: m.dist.Sum + dist.Sum, + Min: ordMin(m.dist.Min, dist.Min), + Max: ordMax(m.dist.Max, dist.Max), + } + return nil +} + +func (m *distributionInt64) toProto(key metricKey) *pipepb.MonitoringInfo { + var buf bytes.Buffer + coder.EncodeVarInt(m.dist.Count, &buf) + coder.EncodeVarInt(m.dist.Sum, &buf) + coder.EncodeVarInt(m.dist.Min, &buf) + coder.EncodeVarInt(m.dist.Max, &buf) + return &pipepb.MonitoringInfo{ + Urn: key.Urn(), + Type: getMetTyp(pipepb.MonitoringInfoTypeUrns_DISTRIBUTION_INT64_TYPE), + Payload: buf.Bytes(), + Labels: key.Labels(), + } +} + +type durability int + +const ( + tentative = durability(iota) + committed +) + +type metricAccumulator interface { + accumulate([]byte) error + // TODO, maybe just the payload, and another method for its type urn, + // Since they're all the same except for the payloads and type urn. + toProto(key metricKey) *pipepb.MonitoringInfo +} + +type accumFactory func() metricAccumulator + +type metricKey interface { + Urn() string + Labels() map[string]string +} + +type userMetricKey struct { + urn, ptransform, namespace, name string +} + +func (k userMetricKey) Urn() string { + return k.urn +} + +func (k userMetricKey) Labels() map[string]string { + return map[string]string{ + "PTRANSFORM": k.ptransform, + "NAMESPACE": k.namespace, + "NAME": k.name, + } +} + +type pcollectionKey struct { + urn, pcollection string +} + +func (k pcollectionKey) Urn() string { + return k.urn +} + +func (k pcollectionKey) Labels() map[string]string { + return map[string]string{ + "PCOLLECTION": k.pcollection, + } +} + +type ptransformKey struct { + urn, ptransform string +} + +func (k ptransformKey) Urn() string { + return k.urn +} + +func (k ptransformKey) Labels() map[string]string { + return map[string]string{ + "PTRANSFORM": k.ptransform, + } +} + +type apiRequestKey struct { + urn, service, method, resource, ptransform, status string +} + +func (k apiRequestKey) Urn() string { + return k.urn +} + +func (k apiRequestKey) Labels() map[string]string { + return map[string]string{ + "PTRANSFORM": k.ptransform, + "SERVICE": k.service, + "METHOD": k.method, + "RESOURCE": k.resource, + "STATUS": k.status, + } +} + +type apiRequestLatenciesKey struct { + urn, service, method, resource, ptransform string +} + +func (k apiRequestLatenciesKey) Urn() string { + return k.urn +} + +func (k apiRequestLatenciesKey) Labels() map[string]string { + return map[string]string{ + "PTRANSFORM": k.ptransform, + "SERVICE": k.service, + "METHOD": k.method, + "RESOURCE": k.resource, + } +} + +type metricsStore struct { + mu sync.Mutex + accums map[metricKey]metricAccumulator +} + +func (m *metricsStore) ContributeMetrics(payloads *fnpb.ProcessBundleResponse) { + m.mu.Lock() + defer m.mu.Unlock() + if m.accums == nil { + m.accums = map[metricKey]metricAccumulator{} + } + // Old and busted. 
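+ // That is, the legacy MonitoringInfo protos; the payload-only
+ // MonitoringData map handled below is the intended replacement.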
+ mons := payloads.GetMonitoringInfos()
+ for _, mon := range mons {
+ urn := mon.GetUrn()
+ ops, ok := mUrn2Ops[urn]
+ if !ok {
+ slog.Debug("unknown metrics urn", slog.String("urn", urn))
+ continue
+ }
+ key := ops.keyFn(urn, mon.GetLabels())
+ a, ok := m.accums[key]
+ if !ok {
+ a = ops.newAccum()
+ }
+ if err := a.accumulate(mon.GetPayload()); err != nil {
+ panic(fmt.Sprintf("error decoding metrics %v: %+v\n\t%+v\n\t%v", urn, key, a, err))
+ }
+ m.accums[key] = a
+ }
+ // New hotness.
+ mdata := payloads.GetMonitoringData()
+ _ = mdata
+}
+
+func (m *metricsStore) Results(d durability) []*pipepb.MonitoringInfo {
+ // We don't gather tentative metrics yet.
+ if d == tentative {
+ return nil
+ }
+ infos := make([]*pipepb.MonitoringInfo, 0, len(m.accums))
+ for key, accum := range m.accums {
+ infos = append(infos, accum.toProto(key))
+ }
+ return infos
+}
diff --git a/sdks/go/pkg/beam/runners/prism/internal/jobservices/metrics_test.go b/sdks/go/pkg/beam/runners/prism/internal/jobservices/metrics_test.go
new file mode 100644
index 0000000000000..e0346731f3004
--- /dev/null
+++ b/sdks/go/pkg/beam/runners/prism/internal/jobservices/metrics_test.go
@@ -0,0 +1,134 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package jobservices
+
+import (
+ "bytes"
+ "encoding/binary"
+ "math"
+ "testing"
+
+ "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/coder"
+ fnpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/fnexecution_v1"
+ pipepb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v1"
+ "github.com/google/go-cmp/cmp"
+ "google.golang.org/protobuf/proto"
+ "google.golang.org/protobuf/reflect/protoreflect"
+ "google.golang.org/protobuf/testing/protocmp"
+)
+
+var metSpecs = (pipepb.MonitoringInfoSpecs_Enum)(0).Descriptor().Values()
+
+// makeInfo generates dummy MonitoringInfos from a spec.
+func makeInfo(enum pipepb.MonitoringInfoSpecs_Enum, payload []byte) *pipepb.MonitoringInfo {
+ spec := proto.GetExtension(metSpecs.ByNumber(protoreflect.EnumNumber(enum)).Options(), pipepb.E_MonitoringInfoSpec).(*pipepb.MonitoringInfoSpec)
+
+ labels := map[string]string{}
+ for _, l := range spec.GetRequiredLabels() {
+ labels[l] = l
+ }
+ return &pipepb.MonitoringInfo{
+ Urn: spec.GetUrn(),
+ Type: spec.GetType(),
+ Labels: labels,
+ Payload: payload,
+ }
+}
+
+// This test validates that multiple contributions are correctly summed up and accumulated.
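+//
+// For example, in the int64Sum case below, two contributions with varint
+// payloads 3 and 5 decode into the same sumInt64 accumulator, and the
+// committed result re-encodes as a single MonitoringInfo with payload 8.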
+func Test_metricsStore_ContributeMetrics(t *testing.T) { + + doubleBytes := func(v float64) []byte { + b := make([]byte, 8) + binary.BigEndian.PutUint64(b, math.Float64bits(v)) + return b + } + + progress := func(vs ...float64) []byte { + var buf bytes.Buffer + coder.EncodeInt32(int32(len(vs)), &buf) + for _, v := range vs { + coder.EncodeDouble(v, &buf) + } + return buf.Bytes() + } + + tests := []struct { + name string + + // TODO convert input to non-legacy metrics once we support, and then delete these. + input [][]*pipepb.MonitoringInfo + + want []*pipepb.MonitoringInfo + }{ + { + name: "int64Sum", + input: [][]*pipepb.MonitoringInfo{ + {makeInfo(pipepb.MonitoringInfoSpecs_USER_SUM_INT64, []byte{3})}, + {makeInfo(pipepb.MonitoringInfoSpecs_USER_SUM_INT64, []byte{5})}, + }, + want: []*pipepb.MonitoringInfo{ + makeInfo(pipepb.MonitoringInfoSpecs_USER_SUM_INT64, []byte{8}), + }, + }, { + name: "float64Sum", + input: [][]*pipepb.MonitoringInfo{ + {makeInfo(pipepb.MonitoringInfoSpecs_USER_SUM_DOUBLE, doubleBytes(3.14))}, + {makeInfo(pipepb.MonitoringInfoSpecs_USER_SUM_DOUBLE, doubleBytes(1.06))}, + }, + want: []*pipepb.MonitoringInfo{ + makeInfo(pipepb.MonitoringInfoSpecs_USER_SUM_DOUBLE, doubleBytes(4.20)), + }, + }, { + name: "progress", + input: [][]*pipepb.MonitoringInfo{ + {makeInfo(pipepb.MonitoringInfoSpecs_WORK_REMAINING, progress(1, 2.2, 78))}, + {makeInfo(pipepb.MonitoringInfoSpecs_WORK_REMAINING, progress(0, 7.8, 22))}, + }, + want: []*pipepb.MonitoringInfo{ + makeInfo(pipepb.MonitoringInfoSpecs_WORK_REMAINING, progress(0, 7.8, 22)), + }, + }, { + name: "int64Distribution", + input: [][]*pipepb.MonitoringInfo{ + {makeInfo(pipepb.MonitoringInfoSpecs_USER_DISTRIBUTION_INT64, []byte{1, 2, 2, 2})}, + {makeInfo(pipepb.MonitoringInfoSpecs_USER_DISTRIBUTION_INT64, []byte{3, 17, 5, 7})}, + }, + want: []*pipepb.MonitoringInfo{ + makeInfo(pipepb.MonitoringInfoSpecs_USER_DISTRIBUTION_INT64, []byte{4, 19, 2, 7}), + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ms := metricsStore{} + + for _, payload := range test.input { + resp := &fnpb.ProcessBundleResponse{ + MonitoringInfos: payload, + } + ms.ContributeMetrics(resp) + } + + got := ms.Results(committed) + + if diff := cmp.Diff(test.want, got, protocmp.Transform()); diff != "" { + t.Fatalf("metricsStore.ContributeMetrics(%v) diff (-want,+got):\n%v", test.input, diff) + } + }) + } +} diff --git a/sdks/go/pkg/beam/runners/prism/internal/jobservices/server.go b/sdks/go/pkg/beam/runners/prism/internal/jobservices/server.go new file mode 100644 index 0000000000000..2f88293c1dabe --- /dev/null +++ b/sdks/go/pkg/beam/runners/prism/internal/jobservices/server.go @@ -0,0 +1,82 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package jobservices + +import ( + "fmt" + "net" + "sync" + + jobpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/jobmanagement_v1" + "golang.org/x/exp/slog" + "google.golang.org/grpc" +) + +type Server struct { + jobpb.UnimplementedJobServiceServer + jobpb.UnimplementedArtifactStagingServiceServer + + // Server management + lis net.Listener + server *grpc.Server + + // Job Management + mu sync.Mutex + index uint32 + jobs map[string]*Job + + // execute defines how a job is executed. + execute func(*Job) +} + +// NewServer acquires the indicated port. +func NewServer(port int, execute func(*Job)) *Server { + lis, err := net.Listen("tcp", fmt.Sprintf(":%d", port)) + if err != nil { + panic(fmt.Sprintf("failed to listen: %v", err)) + } + s := &Server{ + lis: lis, + jobs: make(map[string]*Job), + execute: execute, + } + slog.Info("Serving JobManagement", slog.String("endpoint", s.Endpoint())) + var opts []grpc.ServerOption + s.server = grpc.NewServer(opts...) + jobpb.RegisterJobServiceServer(s.server, s) + jobpb.RegisterArtifactStagingServiceServer(s.server, s) + return s +} + +func (s *Server) getJob(id string) *Job { + s.mu.Lock() + defer s.mu.Unlock() + return s.jobs[id] +} + +func (s *Server) Endpoint() string { + return s.lis.Addr().String() +} + +// Serve serves on the started listener. Blocks. +func (s *Server) Serve() { + s.server.Serve(s.lis) +} + +// Stop the GRPC server. +func (s *Server) Stop() { + s.server.GracefulStop() +} diff --git a/sdks/go/pkg/beam/runners/prism/internal/jobservices/server_test.go b/sdks/go/pkg/beam/runners/prism/internal/jobservices/server_test.go new file mode 100644 index 0000000000000..2223f030ce1d3 --- /dev/null +++ b/sdks/go/pkg/beam/runners/prism/internal/jobservices/server_test.go @@ -0,0 +1,79 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package jobservices + +import ( + "context" + "sync" + "testing" + + jobpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/jobmanagement_v1" + pipepb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v1" + "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/urns" + "google.golang.org/protobuf/encoding/prototext" +) + +// TestServer_Lifecycle validates that a server can start and stop. +func TestServer_Lifecycle(t *testing.T) { + undertest := NewServer(0, func(j *Job) { + t.Fatalf("unexpected call to execute: %v", j) + }) + + go undertest.Serve() + + undertest.Stop() +} + +// Validates that a job can start and stop. 
+func TestServer_JobLifecycle(t *testing.T) {
+ var called sync.WaitGroup
+ called.Add(1)
+ undertest := NewServer(0, func(j *Job) {
+ called.Done()
+ })
+ ctx := context.Background()
+
+ wantPipeline := &pipepb.Pipeline{
+ Requirements: []string{urns.RequirementSplittableDoFn},
+ }
+ wantName := "testJob"
+
+ resp, err := undertest.Prepare(ctx, &jobpb.PrepareJobRequest{
+ Pipeline: wantPipeline,
+ JobName: wantName,
+ })
+ if err != nil {
+ t.Fatalf("server.Prepare() = %v, want nil", err)
+ }
+
+ if got := resp.GetPreparationId(); got == "" {
+ t.Fatalf("server.Prepare() returned empty preparation ID, want non-empty: %v", prototext.Format(resp))
+ }
+
+ runResp, err := undertest.Run(ctx, &jobpb.RunJobRequest{
+ PreparationId: resp.GetPreparationId(),
+ })
+ if err != nil {
+ t.Fatalf("server.Run() = %v, want nil", err)
+ }
+ if got := runResp.GetJobId(); got == "" {
+ t.Fatalf("server.Run() returned empty job ID, want non-empty")
+ }
+ // If execute is never called, this doesn't unblock and the test times out.
+ called.Wait()
+ t.Log("success!")
+ // Nothing to clean up, because we didn't start serving.
+}
diff --git a/sdks/go/pkg/beam/runners/prism/internal/preprocess.go b/sdks/go/pkg/beam/runners/prism/internal/preprocess.go
new file mode 100644
index 0000000000000..8769a05d38f47
--- /dev/null
+++ b/sdks/go/pkg/beam/runners/prism/internal/preprocess.go
@@ -0,0 +1,148 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package internal
+
+import (
+ "sort"
+
+ "github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/pipelinex"
+ pipepb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v1"
+ "golang.org/x/exp/maps"
+ "golang.org/x/exp/slog"
+)
+
+// transformPreparer is an interface for handling different urns in the preprocessor,
+// largely for exchanging transforms for others, to be added to the complete set of
+// components in the pipeline.
+type transformPreparer interface {
+ // PrepareUrns returns the Beam URNs that this handler deals with for preprocessing.
+ PrepareUrns() []string
+ // PrepareTransform takes a PTransform proto and returns a set of new Components, and a list of
+ // transform IDs to remove and ignore from graph processing.
+ PrepareTransform(tid string, t *pipepb.PTransform, comps *pipepb.Components) (*pipepb.Components, []string)
+}
+
+// preprocessor retains configuration for preprocessing the
+// graph, such as special handling for lifted combiners or
+// other configuration.
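+//
+// Preparers registered here rewrite individual transforms: a preparer
+// typically returns replacement subtransforms plus the transform IDs to
+// remove; testPreparer in preprocess_test.go is a minimal example.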
+type preprocessor struct {
+ transformPreparers map[string]transformPreparer
+}
+
+func newPreprocessor(preps []transformPreparer) *preprocessor {
+ preparers := map[string]transformPreparer{}
+ for _, prep := range preps {
+ for _, urn := range prep.PrepareUrns() {
+ preparers[urn] = prep
+ }
+ }
+ return &preprocessor{
+ transformPreparers: preparers,
+ }
+}
+
+// preProcessGraph takes the graph and preprocesses it for consumption in bundles.
+// The output is the topological sort of the transform ids.
+//
+// This is how transforms are related in graph form, but not the specific
+// bundles themselves, which will come later.
+//
+// Handles awareness of composite transforms and similar. Ultimately, after this point
+// the graph stops being a hypergraph, with composite transforms being treated as
+// "leaves" downstream as needed.
+//
+// This is where Combines become lifted (if it makes sense, or is configured), and similar behaviors.
+func (p *preprocessor) preProcessGraph(comps *pipepb.Components) []*stage {
+ ts := comps.GetTransforms()
+
+ // TODO move this out of this part of the pre-processor?
+ leaves := map[string]struct{}{}
+ ignore := map[string]struct{}{}
+ for tid, t := range ts {
+ if _, ok := ignore[tid]; ok {
+ continue
+ }
+
+ spec := t.GetSpec()
+ if spec == nil {
+ // Most composites don't have specs.
+ slog.Debug("transform is missing a spec",
+ slog.Group("transform", slog.String("ID", tid), slog.String("name", t.GetUniqueName())))
+ continue
+ }
+
+ // Composite Transforms basically means needing to remove the "leaves" from the
+ // handling set, and producing the new sub component transforms. The top level
+ // composite should have enough information to produce the new sub transforms.
+ // In particular, the inputs and outputs need to all be connected and matched up
+ // so the topological sort still works out.
+ h := p.transformPreparers[spec.GetUrn()]
+ if h == nil {
+
+ // If there's an unknown urn, and it's not composite, simply add it to the leaves.
+ if len(t.GetSubtransforms()) == 0 {
+ leaves[tid] = struct{}{}
+ } else {
+ slog.Info("composite transform has unknown urn",
+ slog.Group("transform", slog.String("ID", tid),
+ slog.String("name", t.GetUniqueName()),
+ slog.String("urn", spec.GetUrn())))
+ }
+ continue
+ }
+
+ subs, toRemove := h.PrepareTransform(tid, t, comps)
+
+ // Clear out unnecessary leaves from this composite for topological sort handling.
+ for _, key := range toRemove {
+ ignore[key] = struct{}{}
+ delete(leaves, key)
+ }
+
+ // ts should be a clone, so we should be able to add new transforms into the map.
+ for tid, t := range subs.GetTransforms() {
+ leaves[tid] = struct{}{}
+ ts[tid] = t
+ }
+ for cid, c := range subs.GetCoders() {
+ comps.GetCoders()[cid] = c
+ }
+ for nid, n := range subs.GetPcollections() {
+ comps.GetPcollections()[nid] = n
+ }
+ // It's unlikely for these to change, but better to handle them now, to save a headache later.
+ for wid, w := range subs.GetWindowingStrategies() {
+ comps.GetWindowingStrategies()[wid] = w
+ }
+ for envid, env := range subs.GetEnvironments() {
+ comps.GetEnvironments()[envid] = env
+ }
+ }
+
+ // With composites expanded, topologically sort the remaining leaf transforms.
+
+ keptLeaves := maps.Keys(leaves)
+ sort.Strings(keptLeaves)
+ topological := pipelinex.TopologicalSort(ts, keptLeaves)
+ slog.Debug("topological transform ordering", slog.Any("transforms", topological))
+
+ var stages []*stage
+ for _, tid := range topological {
+ stages = append(stages, &stage{
+ transforms: []string{tid},
+ })
+ }
+ return stages
+}
diff --git a/sdks/go/pkg/beam/runners/prism/internal/preprocess_test.go b/sdks/go/pkg/beam/runners/prism/internal/preprocess_test.go
new file mode 100644
index 0000000000000..add69a7c76792
--- /dev/null
+++ b/sdks/go/pkg/beam/runners/prism/internal/preprocess_test.go
@@ -0,0 +1,181 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package internal
+
+import (
+ "testing"
+
+ pipepb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v1"
+ "github.com/google/go-cmp/cmp"
+ "google.golang.org/protobuf/testing/protocmp"
+)
+
+func Test_preprocessor_preProcessGraph(t *testing.T) {
+ tests := []struct {
+ name string
+ input *pipepb.Components
+
+ wantComponents *pipepb.Components
+ wantStages []*stage
+ }{
+ {
+ name: "noPreparer",
+ input: &pipepb.Components{
+ Transforms: map[string]*pipepb.PTransform{
+ "e1": {
+ UniqueName: "e1",
+ Spec: &pipepb.FunctionSpec{
+ Urn: "defaultUrn",
+ },
+ },
+ },
+ },
+
+ wantStages: []*stage{{transforms: []string{"e1"}}},
+ wantComponents: &pipepb.Components{
+ Transforms: map[string]*pipepb.PTransform{
+ "e1": {
+ UniqueName: "e1",
+ Spec: &pipepb.FunctionSpec{
+ Urn: "defaultUrn",
+ },
+ },
+ },
+ },
+ }, {
+ name: "preparer",
+ input: &pipepb.Components{
+ Transforms: map[string]*pipepb.PTransform{
+ "e1": {
+ UniqueName: "e1",
+ Spec: &pipepb.FunctionSpec{
+ Urn: "test_urn",
+ },
+ },
+ },
+ // Initialize maps, since proto unmarshalling always initializes them.
+ Pcollections: map[string]*pipepb.PCollection{}, + WindowingStrategies: map[string]*pipepb.WindowingStrategy{}, + Coders: map[string]*pipepb.Coder{}, + Environments: map[string]*pipepb.Environment{}, + }, + + wantStages: []*stage{{transforms: []string{"e1_early"}}, {transforms: []string{"e1_late"}}}, + wantComponents: &pipepb.Components{ + Transforms: map[string]*pipepb.PTransform{ + // Original is always kept + "e1": { + UniqueName: "e1", + Spec: &pipepb.FunctionSpec{ + Urn: "test_urn", + }, + }, + "e1_early": { + UniqueName: "e1_early", + Spec: &pipepb.FunctionSpec{ + Urn: "defaultUrn", + }, + Outputs: map[string]string{"i0": "pcol1"}, + EnvironmentId: "env1", + }, + "e1_late": { + UniqueName: "e1_late", + Spec: &pipepb.FunctionSpec{ + Urn: "defaultUrn", + }, + Inputs: map[string]string{"i0": "pcol1"}, + EnvironmentId: "env1", + }, + }, + Pcollections: map[string]*pipepb.PCollection{ + "pcol1": { + UniqueName: "pcol1", + CoderId: "coder1", + WindowingStrategyId: "ws1", + }, + }, + Coders: map[string]*pipepb.Coder{ + "coder1": {Spec: &pipepb.FunctionSpec{Urn: "coder1"}}, + }, + WindowingStrategies: map[string]*pipepb.WindowingStrategy{ + "ws1": {WindowCoderId: "global"}, + }, + Environments: map[string]*pipepb.Environment{ + "env1": {Urn: "env1"}, + }, + }, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + pre := newPreprocessor([]transformPreparer{&testPreparer{}}) + + gotStages := pre.preProcessGraph(test.input) + if diff := cmp.Diff(test.wantStages, gotStages, cmp.AllowUnexported(stage{})); diff != "" { + t.Errorf("preProcessGraph(%q) stages diff (-want,+got)\n%v", test.name, diff) + } + + if diff := cmp.Diff(test.input, test.wantComponents, protocmp.Transform()); diff != "" { + t.Errorf("preProcessGraph(%q) components diff (-want,+got)\n%v", test.name, diff) + } + }) + } +} + +type testPreparer struct{} + +func (p *testPreparer) PrepareUrns() []string { + return []string{"test_urn"} +} + +func (p *testPreparer) PrepareTransform(tid string, t *pipepb.PTransform, comps *pipepb.Components) (*pipepb.Components, []string) { + return &pipepb.Components{ + Transforms: map[string]*pipepb.PTransform{ + "e1_early": { + UniqueName: "e1_early", + Spec: &pipepb.FunctionSpec{ + Urn: "defaultUrn", + }, + Outputs: map[string]string{"i0": "pcol1"}, + EnvironmentId: "env1", + }, + "e1_late": { + UniqueName: "e1_late", + Spec: &pipepb.FunctionSpec{ + Urn: "defaultUrn", + }, + Inputs: map[string]string{"i0": "pcol1"}, + EnvironmentId: "env1", + }, + }, + Pcollections: map[string]*pipepb.PCollection{ + "pcol1": { + UniqueName: "pcol1", + CoderId: "coder1", + WindowingStrategyId: "ws1", + }, + }, + Coders: map[string]*pipepb.Coder{ + "coder1": {Spec: &pipepb.FunctionSpec{Urn: "coder1"}}, + }, + WindowingStrategies: map[string]*pipepb.WindowingStrategy{ + "ws1": {WindowCoderId: "global"}, + }, + Environments: map[string]*pipepb.Environment{ + "env1": {Urn: "env1"}, + }, + }, []string{"e1"} +} diff --git a/sdks/go/pkg/beam/runners/prism/internal/separate_test.go b/sdks/go/pkg/beam/runners/prism/internal/separate_test.go new file mode 100644 index 0000000000000..edfe37365031e --- /dev/null +++ b/sdks/go/pkg/beam/runners/prism/internal/separate_test.go @@ -0,0 +1,593 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. 
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package internal
+
+import (
+ "context"
+ "fmt"
+ "net"
+ "net/http"
+ "net/rpc"
+ "sync"
+ "testing"
+ "time"
+
+ "github.com/apache/beam/sdks/v2/go/pkg/beam"
+ "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/mtime"
+ "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/window"
+ "github.com/apache/beam/sdks/v2/go/pkg/beam/core/sdf"
+ "github.com/apache/beam/sdks/v2/go/pkg/beam/core/util/reflectx"
+ "github.com/apache/beam/sdks/v2/go/pkg/beam/io/rtrackers/offsetrange"
+ "github.com/apache/beam/sdks/v2/go/pkg/beam/register"
+ "github.com/apache/beam/sdks/v2/go/pkg/beam/testing/passert"
+ "github.com/apache/beam/sdks/v2/go/pkg/beam/transforms/stats"
+ "golang.org/x/exp/slog"
+)
+
+// separate_test.go retains structures and tests to ensure the runner can
+// perform separation, and terminate checkpoints.
+
+// Global variable, so only one is registered with the OS.
+var ws = &Watchers{}
+
+// TestSeparation validates that the runner is able to split
+// elements in time and space. Beam has a few mechanisms to
+// do this.
+//
+// First is channel splits, where a slowly processing
+// bundle might have its remaining buffered elements truncated
+// so they can be processed by another bundle,
+// possibly simultaneously.
+//
+// Second is sub element splitting, where a single element
+// in an SDF might be split into smaller restrictions.
+//
+// Third, with Checkpointing or ProcessContinuations,
+// a user DoFn may decide to defer processing of an element
+// until later, permitting a bundle to terminate earlier,
+// delaying processing.
+//
+// All these may be tested locally or in process with a small
+// server the DoFns can connect to. This can then indicate which
+// elements, or positions, are considered "sentinels".
+//
+// When a sentinel is to be processed, the DoFn blocks instead.
+// The goal for Splitting tests is to succeed only when all
+// sentinels are blocking, waiting to be processed.
+// This indicates the runner has "separated" the sentinels, hence
+// the name "separation harness tests".
+//
+// Delayed Process Continuations can be similarly tested,
+// as this emulates external processing servers anyway.
+// It's much simpler though, as the request is to determine if
+// a given element should be delayed or not. This could be used
+// for arbitrarily complex splitting patterns, as desired.
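+//
+// As a rough sketch of the coordination implemented below: each sentinel
+// calls Watchers.Block over RPC and then parks on a shared channel, while a
+// per-watcher goroutine polls Watchers.Check once a second and closes that
+// channel once every expected sentinel has blocked.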
+func TestSeparation(t *testing.T) {
+ initRunner(t)
+
+ ws.initRPCServer()
+
+ tests := []struct {
+ name string
+ pipeline func(s beam.Scope)
+ metrics func(t *testing.T, pr beam.PipelineResult)
+ }{
+ {
+ name: "ProcessContinuations_combine_globalWindow",
+ pipeline: func(s beam.Scope) {
+ count := 10
+ imp := beam.Impulse(s)
+ out := beam.ParDo(s, &sepHarnessSdfStream{
+ Base: sepHarnessBase{
+ WatcherID: ws.newWatcher(3),
+ Sleep: time.Second,
+ IsSentinelEncoded: beam.EncodedFunc{Fn: reflectx.MakeFunc(allSentinel)},
+ LocalService: ws.serviceAddress,
+ },
+ RestSize: int64(count),
+ }, imp)
+ passert.Count(s, out, "global num ints", count)
+ },
+ }, {
+ name: "ProcessContinuations_stepped_combine_globalWindow",
+ pipeline: func(s beam.Scope) {
+ count := 10
+ imp := beam.Impulse(s)
+ out := beam.ParDo(s, &singleStepSdfStream{
+ Sleep: time.Second,
+ RestSize: int64(count),
+ }, imp)
+ passert.Count(s, out, "global stepped num ints", count)
+ sum := beam.ParDo(s, dofn2x1, imp, beam.SideInput{Input: out})
+ beam.ParDo(s, &int64Check{Name: "stepped", Want: []int{45}}, sum)
+ },
+ }, {
+ name: "ProcessContinuations_stepped_combine_fixedWindow",
+ pipeline: func(s beam.Scope) {
+ elms, mod := 1000, 10
+ count := int(elms / mod)
+ imp := beam.Impulse(s)
+ out := beam.ParDo(s, &eventtimeSDFStream{
+ Sleep: time.Second,
+ RestSize: int64(elms),
+ Mod: int64(mod),
+ Fixed: 1,
+ }, imp)
+ windowed := beam.WindowInto(s, window.NewFixedWindows(time.Second*10), out)
+ sum := stats.Sum(s, windowed)
+ // We expect each window to be processed ASAP, and produced one
+ // at a time, with the same results.
+ beam.ParDo(s, &int64Check{Name: "single", Want: []int{55}}, sum)
+ // But we need to receive the expected number of identical results
+ gsum := beam.WindowInto(s, window.NewGlobalWindows(), sum)
+ passert.Count(s, gsum, "total sums", count)
+ },
+ },
+ }
+
+ // TODO: Channel Splits
+ // TODO: SubElement/dynamic splits.
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ p, s := beam.NewPipelineWithRoot()
+ test.pipeline(s)
+ pr, err := executeWithT(context.Background(), t, p)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if test.metrics != nil {
+ test.metrics(t, pr)
+ }
+ })
+ }
+}
+
+func init() {
+ register.Function1x1(allSentinel)
+}
+
+// allSentinel indicates that all elements are sentinels.
+func allSentinel(v beam.T) bool {
+ return true
+}
+
+// watcher is a single instance of the sentinel counters.
+type watcher struct {
+ id int
+ mu sync.Mutex
+ sentinelCount, sentinelCap int
+}
+
+func (w *watcher) LogValue() slog.Value {
+ return slog.GroupValue(
+ slog.Int("id", w.id),
+ slog.Int("sentinelCount", w.sentinelCount),
+ slog.Int("sentinelCap", w.sentinelCap),
+ )
+}
+
+// Watchers is a "net/rpc" service.
+type Watchers struct {
+ mu sync.Mutex
+ nextID int
+ lookup map[int]*watcher
+ serviceOnce sync.Once
+ serviceAddress string
+}
+
+// Args is the set of parameters to the watchers RPC methods.
+type Args struct {
+ WatcherID int
+}
+
+// Block is called once per sentinel, to indicate it will block
+// until all sentinels are blocked.
+func (ws *Watchers) Block(args *Args, _ *bool) error {
+ ws.mu.Lock()
+ defer ws.mu.Unlock()
+ w, ok := ws.lookup[args.WatcherID]
+ if !ok {
+ return fmt.Errorf("no watcher with id %v", args.WatcherID)
+ }
+ w.mu.Lock()
+ w.sentinelCount++
+ w.mu.Unlock()
+ return nil
+}
+
+// Check returns whether the sentinels are unblocked or not.
+func (ws *Watchers) Check(args *Args, unblocked *bool) error {
+ ws.mu.Lock()
+ defer ws.mu.Unlock()
+ w, ok := ws.lookup[args.WatcherID]
+ if !ok {
+ return fmt.Errorf("no watcher with id %v", args.WatcherID)
+ }
+ w.mu.Lock()
+ *unblocked = w.sentinelCount >= w.sentinelCap
+ w.mu.Unlock()
+ slog.Debug("Check: sentinel target", "watcher", w, slog.Bool("unblocked", *unblocked))
+ return nil
+}
+
+// Delay returns whether the sentinels should delay.
+// This increments the sentinel count, and delays as long as it remains under the cap.
+// Intended to validate ProcessContinuation behavior.
+func (ws *Watchers) Delay(args *Args, delay *bool) error {
+ ws.mu.Lock()
+ defer ws.mu.Unlock()
+ w, ok := ws.lookup[args.WatcherID]
+ if !ok {
+ return fmt.Errorf("no watcher with id %v", args.WatcherID)
+ }
+ w.mu.Lock()
+ w.sentinelCount++
+ // Delay as long as the sentinel count is under the cap.
+ *delay = w.sentinelCount < w.sentinelCap
+ w.mu.Unlock()
+ slog.Debug("Delay: sentinel target", "watcher", w, slog.Bool("delay", *delay))
+ return nil
+}
+
+func (ws *Watchers) initRPCServer() {
+ ws.serviceOnce.Do(func() {
+ l, err := net.Listen("tcp", ":0")
+ if err != nil {
+ panic(err)
+ }
+ rpc.Register(ws)
+ rpc.HandleHTTP()
+ go http.Serve(l, nil)
+ ws.serviceAddress = l.Addr().String()
+ })
+}
+
+// newWatcher starts an rpc server to manage state for watching for
+// sentinels across local machines.
+func (ws *Watchers) newWatcher(sentinelCap int) int {
+ ws.mu.Lock()
+ defer ws.mu.Unlock()
+ ws.initRPCServer()
+ if ws.lookup == nil {
+ ws.lookup = map[int]*watcher{}
+ }
+ w := &watcher{id: ws.nextID, sentinelCap: sentinelCap}
+ ws.nextID++
+ ws.lookup[w.id] = w
+ return w.id
+}
+
+// sepHarnessBase contains fields and functions that are shared by all
+// versions of the separation harness.
+type sepHarnessBase struct {
+ WatcherID int
+ Sleep time.Duration
+ IsSentinelEncoded beam.EncodedFunc
+ LocalService string
+}
+
+// One connection per binary.
+var (
+ sepClientOnce sync.Once
+ sepClient *rpc.Client
+ sepClientMu sync.Mutex
+ sepWaitMap map[int]chan struct{}
+)
+
+func (fn *sepHarnessBase) setup() error {
+ sepClientMu.Lock()
+ defer sepClientMu.Unlock()
+ sepClientOnce.Do(func() {
+ client, err := rpc.DialHTTP("tcp", fn.LocalService)
+ if err != nil {
+ slog.Error("failed to dial sentinels server", err, slog.String("endpoint", fn.LocalService))
+ panic(fmt.Sprintf("dialing sentinels server %v: %v", fn.LocalService, err))
+ }
+ sepClient = client
+ sepWaitMap = map[int]chan struct{}{}
+ })
+
+ // Check if there's already a local channel for this id, and return early
+ // if so. Otherwise start a watcher goroutine to poll and unblock the
+ // harness when the expected number of sentinels is reached.
+ if _, ok := sepWaitMap[fn.WatcherID]; ok {
+ return nil
+ }
+ // We need a channel to block on for this watcherID
+ // We use a channel instead of a wait group since the finished
+ // count is hosted in a different process.
+ c := make(chan struct{})
+ sepWaitMap[fn.WatcherID] = c
+ go func(id int, c chan struct{}) {
+ for {
+ time.Sleep(time.Second * 1) // Check counts every second.
+ sepClientMu.Lock()
+ var unblock bool
+ err := sepClient.Call("Watchers.Check", &Args{WatcherID: id}, &unblock)
+ if err != nil {
+ slog.Error("Watchers.Check: sentinels server error", err, slog.String("endpoint", fn.LocalService))
+ panic("sentinel server error")
+ }
+ if unblock {
+ close(c) // unblock all the local waiters.
+ slog.Debug("sentinel target for watcher, unblocking", slog.Int("watcherID", id)) + sepClientMu.Unlock() + return + } + slog.Debug("sentinel target for watcher not met", slog.Int("watcherID", id)) + sepClientMu.Unlock() + } + }(fn.WatcherID, c) + return nil +} + +func (fn *sepHarnessBase) block() { + sepClientMu.Lock() + var ignored bool + err := sepClient.Call("Watchers.Block", &Args{WatcherID: fn.WatcherID}, &ignored) + if err != nil { + slog.Error("Watchers.Block error", err, slog.String("endpoint", fn.LocalService)) + panic(err) + } + c := sepWaitMap[fn.WatcherID] + sepClientMu.Unlock() + + // Block until the watcher closes the channel. + <-c +} + +// delay inform the DoFn whether or not to return a delayed Processing continuation for this position. +func (fn *sepHarnessBase) delay() bool { + sepClientMu.Lock() + defer sepClientMu.Unlock() + var delay bool + err := sepClient.Call("Watchers.Delay", &Args{WatcherID: fn.WatcherID}, &delay) + if err != nil { + slog.Error("Watchers.Delay error", err) + panic(err) + } + return delay +} + +// sepHarness is a simple DoFn that blocks when reaching a sentinel. +// It's useful for testing blocks on channel splits. +type sepHarness struct { + Base sepHarnessBase +} + +func (fn *sepHarness) Setup() error { + return fn.Base.setup() +} + +func (fn *sepHarness) ProcessElement(v beam.T) beam.T { + if fn.Base.IsSentinelEncoded.Fn.Call([]any{v})[0].(bool) { + slog.Debug("blocking on sentinel", slog.Any("sentinel", v)) + fn.Base.block() + slog.Debug("unblocking from sentinel", slog.Any("sentinel", v)) + } else { + time.Sleep(fn.Base.Sleep) + } + return v +} + +type sepHarnessSdf struct { + Base sepHarnessBase + RestSize int64 +} + +func (fn *sepHarnessSdf) Setup() error { + return fn.Base.setup() +} + +func (fn *sepHarnessSdf) CreateInitialRestriction(v beam.T) offsetrange.Restriction { + return offsetrange.Restriction{Start: 0, End: fn.RestSize} +} + +func (fn *sepHarnessSdf) SplitRestriction(v beam.T, r offsetrange.Restriction) []offsetrange.Restriction { + return r.EvenSplits(2) +} + +func (fn *sepHarnessSdf) RestrictionSize(v beam.T, r offsetrange.Restriction) float64 { + return r.Size() +} + +func (fn *sepHarnessSdf) CreateTracker(r offsetrange.Restriction) *sdf.LockRTracker { + return sdf.NewLockRTracker(offsetrange.NewTracker(r)) +} + +func (fn *sepHarnessSdf) ProcessElement(rt *sdf.LockRTracker, v beam.T, emit func(beam.T)) { + i := rt.GetRestriction().(offsetrange.Restriction).Start + for rt.TryClaim(i) { + if fn.Base.IsSentinelEncoded.Fn.Call([]any{i, v})[0].(bool) { + slog.Debug("blocking on sentinel", slog.Group("sentinel", slog.Any("value", v), slog.Int64("pos", i))) + fn.Base.block() + slog.Debug("unblocking from sentinel", slog.Group("sentinel", slog.Any("value", v), slog.Int64("pos", i))) + } else { + time.Sleep(fn.Base.Sleep) + } + emit(v) + i++ + } +} + +func init() { + register.DoFn3x1[*sdf.LockRTracker, beam.T, func(beam.T), sdf.ProcessContinuation]((*sepHarnessSdfStream)(nil)) + register.Emitter1[beam.T]() + register.DoFn3x1[*sdf.LockRTracker, beam.T, func(int64), sdf.ProcessContinuation]((*singleStepSdfStream)(nil)) + register.Emitter1[int64]() + register.DoFn4x1[*CWE, *sdf.LockRTracker, beam.T, func(beam.EventTime, int64), sdf.ProcessContinuation]((*eventtimeSDFStream)(nil)) + register.Emitter2[beam.EventTime, int64]() +} + +type sepHarnessSdfStream struct { + Base sepHarnessBase + RestSize int64 +} + +func (fn *sepHarnessSdfStream) Setup() error { + return fn.Base.setup() +} + +func (fn *sepHarnessSdfStream) 
CreateInitialRestriction(v beam.T) offsetrange.Restriction {
+ return offsetrange.Restriction{Start: 0, End: fn.RestSize}
+}
+
+func (fn *sepHarnessSdfStream) SplitRestriction(v beam.T, r offsetrange.Restriction) []offsetrange.Restriction {
+ return r.EvenSplits(2)
+}
+
+func (fn *sepHarnessSdfStream) RestrictionSize(v beam.T, r offsetrange.Restriction) float64 {
+ return r.Size()
+}
+
+func (fn *sepHarnessSdfStream) CreateTracker(r offsetrange.Restriction) *sdf.LockRTracker {
+ return sdf.NewLockRTracker(offsetrange.NewTracker(r))
+}
+
+func (fn *sepHarnessSdfStream) ProcessElement(rt *sdf.LockRTracker, v beam.T, emit func(beam.T)) sdf.ProcessContinuation {
+ if fn.Base.IsSentinelEncoded.Fn.Call([]any{v})[0].(bool) {
+ if fn.Base.delay() {
+ slog.Debug("delaying on sentinel", slog.Group("sentinel", slog.Any("value", v)))
+ return sdf.ResumeProcessingIn(fn.Base.Sleep)
+ }
+ slog.Debug("cleared to process sentinel", slog.Group("sentinel", slog.Any("value", v)))
+ }
+ r := rt.GetRestriction().(offsetrange.Restriction)
+ i := r.Start
+ for rt.TryClaim(i) {
+ emit(v)
+ i++
+ }
+ return sdf.StopProcessing()
+}
+
+// singleStepSdfStream only emits a single position at a time, then sleeps.
+// Stops when a restriction of size 0 is provided.
+type singleStepSdfStream struct {
+ RestSize int64
+ Sleep time.Duration
+}
+
+func (fn *singleStepSdfStream) Setup() error {
+ return nil
+}
+
+func (fn *singleStepSdfStream) CreateInitialRestriction(v beam.T) offsetrange.Restriction {
+ return offsetrange.Restriction{Start: 0, End: fn.RestSize}
+}
+
+func (fn *singleStepSdfStream) SplitRestriction(v beam.T, r offsetrange.Restriction) []offsetrange.Restriction {
+ return r.EvenSplits(2)
+}
+
+func (fn *singleStepSdfStream) RestrictionSize(v beam.T, r offsetrange.Restriction) float64 {
+ return r.Size()
+}
+
+func (fn *singleStepSdfStream) CreateTracker(r offsetrange.Restriction) *sdf.LockRTracker {
+ return sdf.NewLockRTracker(offsetrange.NewTracker(r))
+}
+
+func (fn *singleStepSdfStream) ProcessElement(rt *sdf.LockRTracker, v beam.T, emit func(int64)) sdf.ProcessContinuation {
+ r := rt.GetRestriction().(offsetrange.Restriction)
+ i := r.Start
+ if r.Size() < 1 {
+ slog.Debug("size 0 restriction, stopping processing of sentinel", slog.Any("value", v))
+ return sdf.StopProcessing()
+ }
+ slog.Debug("emitting element for restriction", slog.Any("value", v), slog.Group("restriction",
+ slog.Any("value", v),
+ slog.Float64("size", r.Size()),
+ slog.Int64("pos", i),
+ ))
+ if rt.TryClaim(i) {
+ emit(i)
+ }
+ return sdf.ResumeProcessingIn(fn.Sleep)
+}
+
+type eventtimeSDFStream struct {
+ RestSize, Mod, Fixed int64
+ Sleep time.Duration
+}
+
+func (fn *eventtimeSDFStream) Setup() error {
+ return nil
+}
+
+func (fn *eventtimeSDFStream) CreateInitialRestriction(v beam.T) offsetrange.Restriction {
+ return offsetrange.Restriction{Start: 0, End: fn.RestSize}
+}
+
+func (fn *eventtimeSDFStream) SplitRestriction(v beam.T, r offsetrange.Restriction) []offsetrange.Restriction {
+ // No split
+ return []offsetrange.Restriction{r}
+}
+
+func (fn *eventtimeSDFStream) RestrictionSize(v beam.T, r offsetrange.Restriction) float64 {
+ return r.Size()
+}
+
+func (fn *eventtimeSDFStream) CreateTracker(r offsetrange.Restriction) *sdf.LockRTracker {
+ return sdf.NewLockRTracker(offsetrange.NewTracker(r))
+}
+
+func (fn *eventtimeSDFStream) ProcessElement(_ *CWE, rt *sdf.LockRTracker, v beam.T, emit func(beam.EventTime, int64)) sdf.ProcessContinuation {
+ r := rt.GetRestriction().(offsetrange.Restriction)
+ i := r.Start
+ if r.Size() < 1
{
+ slog.Debug("size 0 restriction, stopping processing of sentinel", slog.Any("value", v))
+ return sdf.StopProcessing()
+ }
+ slog.Debug("emitting element for restriction", slog.Any("value", v), slog.Group("restriction",
+ slog.Any("value", v),
+ slog.Float64("size", r.Size()),
+ slog.Int64("pos", i),
+ ))
+ if rt.TryClaim(i) {
+ timestamp := mtime.FromMilliseconds(int64((i + 1) * 1000)).Subtract(10 * time.Millisecond)
+ v := (i % fn.Mod) + fn.Fixed
+ emit(timestamp, v)
+ }
+ return sdf.ResumeProcessingIn(fn.Sleep)
+}
+
+func (fn *eventtimeSDFStream) InitialWatermarkEstimatorState(_ beam.EventTime, _ offsetrange.Restriction, _ beam.T) int64 {
+ return int64(mtime.MinTimestamp)
+}
+
+func (fn *eventtimeSDFStream) CreateWatermarkEstimator(initialState int64) *CWE {
+ return &CWE{Watermark: initialState}
+}
+
+func (fn *eventtimeSDFStream) WatermarkEstimatorState(e *CWE) int64 {
+ return e.Watermark
+}
+
+type CWE struct {
+ Watermark int64 // uses int64, since the SDK prevents mtime.Time from serialization.
+}
+
+func (e *CWE) CurrentWatermark() time.Time {
+ return mtime.Time(e.Watermark).ToTime()
+}
+
+func (e *CWE) ObserveTimestamp(ts time.Time) {
+ // Subtract an offset to hold the watermark just behind the emitted
+ // timestamps, allowing window boundaries to progress after emitting.
+ e.Watermark = int64(mtime.FromTime(ts.Add(-90 * time.Millisecond)))
+}
diff --git a/sdks/go/pkg/beam/runners/prism/internal/testdofns.go b/sdks/go/pkg/beam/runners/prism/internal/testdofns.go
new file mode 100644
index 0000000000000..4aa07a46c6f22
--- /dev/null
+++ b/sdks/go/pkg/beam/runners/prism/internal/testdofns.go
@@ -0,0 +1,349 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package internal
+
+import (
+ "context"
+ "fmt"
+ "sort"
+ "time"
+
+ "github.com/apache/beam/sdks/v2/go/pkg/beam"
+ "github.com/apache/beam/sdks/v2/go/pkg/beam/core/sdf"
+ "github.com/apache/beam/sdks/v2/go/pkg/beam/io/rtrackers/offsetrange"
+ "github.com/apache/beam/sdks/v2/go/pkg/beam/log"
+ "github.com/apache/beam/sdks/v2/go/pkg/beam/register"
+ "github.com/google/go-cmp/cmp"
+)
+
+// The Test DoFns live outside of the test files to get coverage information on DoFn
+// Lifecycle method execution. This inflates binary size, but ensures the runner is
+// exercising the expected feature set.
+//
+// Once there's enough confidence in the runner, we can move these into a dedicated testing
+// package along with the pipelines that use them.
+
+// Registrations should happen in the test files, so the compiler can prune these
+// when they are not in use.
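+//
+// For example, dofn1 below is registered in testdofns_test.go via
+// register.Function2x0(dofn1), rather than in this file.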
+ +func dofn1(imp []byte, emit func(int64)) { + emit(1) + emit(2) + emit(3) +} + +func dofn1kv(imp []byte, emit func(int64, int64)) { + emit(0, 1) + emit(0, 2) + emit(0, 3) +} + +func dofn1x2(imp []byte, emitA func(int64), emitB func(int64)) { + emitA(1) + emitA(2) + emitA(3) + emitB(4) + emitB(5) + emitB(6) +} + +func dofn1x5(imp []byte, emitA, emitB, emitC, emitD, emitE func(int64)) { + emitA(1) + emitB(2) + emitC(3) + emitD(4) + emitE(5) + emitA(6) + emitB(7) + emitC(8) + emitD(9) + emitE(10) +} + +func dofn2x1(imp []byte, iter func(*int64) bool, emit func(int64)) { + var v, sum, c int64 + for iter(&v) { + fmt.Println("dofn2x1 v", v, " c ", c) + sum += v + c++ + } + fmt.Println("dofn2x1 sum", sum, "count", c) + emit(sum) +} + +func dofn2x2KV(imp []byte, iter func(*string, *int64) bool, emitK func(string), emitV func(int64)) { + var k string + var v, sum int64 + for iter(&k, &v) { + sum += v + emitK(k) + } + emitV(sum) +} + +func dofnMultiMap(key string, lookup func(string) func(*int64) bool, emitK func(string), emitV func(int64)) { + var v, sum int64 + iter := lookup(key) + for iter(&v) { + sum += v + } + emitK(key) + emitV(sum) +} + +func dofn3x1(sum int64, iter1, iter2 func(*int64) bool, emit func(int64)) { + var v int64 + for iter1(&v) { + sum += v + } + for iter2(&v) { + sum += v + } + emit(sum) +} + +// int64Check validates that within a single bundle, +// we received the expected int64 values & sends them downstream. +// +// Invalid pattern for general testing, as it will fail +// on other valid execution patterns, like single element bundles. +type int64Check struct { + Name string + Want []int + got []int +} + +func (fn *int64Check) ProcessElement(v int64, _ func(int64)) { + fn.got = append(fn.got, int(v)) +} + +func (fn *int64Check) FinishBundle(_ func(int64)) error { + sort.Ints(fn.got) + sort.Ints(fn.Want) + if d := cmp.Diff(fn.Want, fn.got); d != "" { + return fmt.Errorf("int64Check[%v] (-want, +got): %v", fn.Name, d) + } + // Clear for subsequent calls. + fn.got = nil + return nil +} + +// stringCheck validates that within a single bundle, +// we received the expected string values. +// Re-emits them downstream. +// +// Invalid pattern for general testing, as it will fail +// on other valid execution patterns, like single element bundles. 
+type stringCheck struct {
+ Name string
+ Want []string
+ got []string
+}
+
+func (fn *stringCheck) ProcessElement(v string, _ func(string)) {
+ fn.got = append(fn.got, v)
+}
+
+func (fn *stringCheck) FinishBundle(_ func(string)) error {
+ sort.Strings(fn.got)
+ sort.Strings(fn.Want)
+ if d := cmp.Diff(fn.Want, fn.got); d != "" {
+ return fmt.Errorf("stringCheck[%v] (-want, +got): %v", fn.Name, d)
+ }
+ return nil
+}
+
+func dofn2(v int64, emit func(int64)) {
+ emit(v + 1)
+}
+
+func dofnKV(imp []byte, emit func(string, int64)) {
+ emit("a", 1)
+ emit("b", 2)
+ emit("a", 3)
+ emit("b", 4)
+ emit("a", 5)
+ emit("b", 6)
+}
+
+func dofnKV2(imp []byte, emit func(int64, string)) {
+ emit(1, "a")
+ emit(2, "b")
+ emit(1, "a")
+ emit(2, "b")
+ emit(1, "a")
+ emit(2, "b")
+}
+
+func dofnGBK(k string, vs func(*int64) bool, emit func(int64)) {
+ var v, sum int64
+ for vs(&v) {
+ sum += v
+ }
+ emit(sum)
+}
+
+func dofnGBK2(k int64, vs func(*string) bool, emit func(string)) {
+ var v, sum string
+ for vs(&v) {
+ sum += v
+ }
+ emit(sum)
+}
+
+type testRow struct {
+ A string
+ B int64
+}
+
+func dofnKV3(imp []byte, emit func(testRow, testRow)) {
+ emit(testRow{"a", 1}, testRow{"a", 1})
+}
+
+func dofnGBK3(k testRow, vs func(*testRow) bool, emit func(string)) {
+ var v testRow
+ vs(&v)
+ emit(fmt.Sprintf("%v: %v", k, v))
+}
+
+const (
+ ns = "localtest"
+)
+
+func dofnSink(ctx context.Context, _ []byte) {
+ beam.NewCounter(ns, "sunk").Inc(ctx, 73)
+}
+
+func dofn1Counter(ctx context.Context, _ []byte, emit func(int64)) {
+ beam.NewCounter(ns, "count").Inc(ctx, 1)
+}
+
+func combineIntSum(a, b int64) int64 {
+ return a + b
+}
+
+// SourceConfig is a struct containing all the configuration options for a
+// synthetic source (the fields are public to allow encoding).
+type SourceConfig struct {
+ NumElements int64 `json:"num_records" beam:"num_records"`
+ InitialSplits int64 `json:"initial_splits" beam:"initial_splits"`
+}
+
+// intRangeFn is a splittable DoFn for counting from 1 to N.
+type intRangeFn struct{}
+
+// CreateInitialRestriction creates an offset range restriction representing
+// the number of elements to emit.
+func (fn *intRangeFn) CreateInitialRestriction(config SourceConfig) offsetrange.Restriction {
+ return offsetrange.Restriction{
+ Start: 0,
+ End: int64(config.NumElements),
+ }
+}
+
+// SplitRestriction splits restrictions equally according to the number of
+// initial splits specified in SourceConfig. Each restriction output by this
+// method will contain at least one element, so the number of splits will not
+// exceed the number of elements.
+func (fn *intRangeFn) SplitRestriction(config SourceConfig, rest offsetrange.Restriction) (splits []offsetrange.Restriction) {
+ return rest.EvenSplits(int64(config.InitialSplits))
+}
+
+// RestrictionSize outputs the size of the restriction as the number of elements
+// that restriction will output.
+func (fn *intRangeFn) RestrictionSize(_ SourceConfig, rest offsetrange.Restriction) float64 {
+ return rest.Size()
+}
+
+// CreateTracker just creates an offset range restriction tracker for the
+// restriction.
+func (fn *intRangeFn) CreateTracker(rest offsetrange.Restriction) *sdf.LockRTracker {
+ return sdf.NewLockRTracker(offsetrange.NewTracker(rest))
+}
+
+// ProcessElement claims each position in the restriction and emits its
+// one-based value, producing the integers [1, N].
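+// For example, a restriction of [0, 3) claims positions 0, 1, and 2, and
+// emits the values 1, 2, and 3.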
+func (fn *intRangeFn) ProcessElement(rt *sdf.LockRTracker, config SourceConfig, emit func(int64)) error { + for i := rt.GetRestriction().(offsetrange.Restriction).Start; rt.TryClaim(i); i++ { + // Add 1 since the restrictions are from [0 ,N), but we want [1, N] + emit(i + 1) + } + return nil +} + +func init() { + register.DoFn3x1[*sdf.LockRTracker, []byte, func(int64), sdf.ProcessContinuation](&selfCheckpointingDoFn{}) + register.Emitter1[int64]() +} + +type selfCheckpointingDoFn struct{} + +// CreateInitialRestriction creates the restriction being used by the SDF. In this case, the range +// of values produced by the restriction is [Start, End). +func (fn *selfCheckpointingDoFn) CreateInitialRestriction(_ []byte) offsetrange.Restriction { + return offsetrange.Restriction{ + Start: int64(0), + End: int64(10), + } +} + +// CreateTracker wraps the given restriction into a LockRTracker type. +func (fn *selfCheckpointingDoFn) CreateTracker(rest offsetrange.Restriction) *sdf.LockRTracker { + return sdf.NewLockRTracker(offsetrange.NewTracker(rest)) +} + +// RestrictionSize returns the size of the current restriction +func (fn *selfCheckpointingDoFn) RestrictionSize(_ []byte, rest offsetrange.Restriction) float64 { + return rest.Size() +} + +// SplitRestriction modifies the offsetrange.Restriction's sized restriction function to produce a size-zero restriction +// at the end of execution. +func (fn *selfCheckpointingDoFn) SplitRestriction(_ []byte, rest offsetrange.Restriction) []offsetrange.Restriction { + size := int64(3) + s := rest.Start + var splits []offsetrange.Restriction + for e := s + size; e <= rest.End; s, e = e, e+size { + splits = append(splits, offsetrange.Restriction{Start: s, End: e}) + } + splits = append(splits, offsetrange.Restriction{Start: s, End: rest.End}) + return splits +} + +// ProcessElement continually gets the start position of the restriction and emits it as an int64 value before checkpointing. +// This causes the restriction to be split after the claimed work and produce no primary roots. +func (fn *selfCheckpointingDoFn) ProcessElement(rt *sdf.LockRTracker, _ []byte, emit func(int64)) sdf.ProcessContinuation { + position := rt.GetRestriction().(offsetrange.Restriction).Start + + for { + if rt.TryClaim(position) { + // Successful claim, emit the value and move on. + emit(position) + position++ + } else if rt.GetError() != nil || rt.IsDone() { + // Stop processing on error or completion + if err := rt.GetError(); err != nil { + log.Errorf(context.Background(), "error in restriction tracker, got %v", err) + } + return sdf.StopProcessing() + } else { + // Resume later. + return sdf.ResumeProcessingIn(5 * time.Second) + } + } +} diff --git a/sdks/go/pkg/beam/runners/prism/internal/testdofns_test.go b/sdks/go/pkg/beam/runners/prism/internal/testdofns_test.go new file mode 100644 index 0000000000000..3596c40f0dcd2 --- /dev/null +++ b/sdks/go/pkg/beam/runners/prism/internal/testdofns_test.go @@ -0,0 +1,55 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + "github.com/apache/beam/sdks/v2/go/pkg/beam/core/sdf" + "github.com/apache/beam/sdks/v2/go/pkg/beam/register" +) + +// Test DoFns are registered in the test file, to allow them to be pruned +// by the compiler outside of test use. +func init() { + register.Function2x0(dofn1) + register.Function2x0(dofn1kv) + register.Function3x0(dofn1x2) + register.Function6x0(dofn1x5) + register.Function3x0(dofn2x1) + register.Function4x0(dofn2x2KV) + register.Function4x0(dofnMultiMap) + register.Iter1[int64]() + register.Function4x0(dofn3x1) + register.Iter2[string, int64]() + register.Emitter1[string]() + + register.Function2x0(dofn2) + register.Function2x0(dofnKV) + register.Function2x0(dofnKV2) + register.Function3x0(dofnGBK) + register.Function3x0(dofnGBK2) + register.DoFn2x0[int64, func(int64)]((*int64Check)(nil)) + register.DoFn2x0[string, func(string)]((*stringCheck)(nil)) + register.Function2x0(dofnKV3) + register.Function3x0(dofnGBK3) + register.Function3x0(dofn1Counter) + register.Function2x0(dofnSink) + + register.Function2x1(combineIntSum) + + register.DoFn3x1[*sdf.LockRTracker, SourceConfig, func(int64), error]((*intRangeFn)(nil)) + register.Emitter1[int64]() + register.Emitter2[int64, int64]() +} diff --git a/sdks/go/pkg/beam/runners/prism/internal/urns/urns.go b/sdks/go/pkg/beam/runners/prism/internal/urns/urns.go new file mode 100644 index 0000000000000..035ab3c0727fa --- /dev/null +++ b/sdks/go/pkg/beam/runners/prism/internal/urns/urns.go @@ -0,0 +1,128 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package urns handles extracting urns from all the protos. +package urns + +import ( + pipepb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v1" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protoreflect" +) + +type protoEnum interface { + ~int32 + Descriptor() protoreflect.EnumDescriptor +} + +// toUrn returns a function that can get the urn string from the proto. +func toUrn[Enum protoEnum]() func(Enum) string { + evd := (Enum)(0).Descriptor().Values() + return func(v Enum) string { + return proto.GetExtension(evd.ByNumber(protoreflect.EnumNumber(v)).Options(), pipepb.E_BeamUrn).(string) + } +} + +// quickUrn handles one off urns instead of retaining a helper function. +// Notably useful for the windowFns due to their older design. 
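+//
+// For example, quickUrn(pipepb.StandardPTransforms_PAR_DO) resolves to
+// "beam:transform:pardo:v1", the same value the retained helper
+// toUrn[pipepb.StandardPTransforms_Primitives]() produces for that enum.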
+func quickUrn[Enum protoEnum](v Enum) string { + return toUrn[Enum]()(v) +} + +var ( + ptUrn = toUrn[pipepb.StandardPTransforms_Primitives]() + ctUrn = toUrn[pipepb.StandardPTransforms_Composites]() + cmbtUrn = toUrn[pipepb.StandardPTransforms_CombineComponents]() + sdfUrn = toUrn[pipepb.StandardPTransforms_SplittableParDoComponents]() + siUrn = toUrn[pipepb.StandardSideInputTypes_Enum]() + cdrUrn = toUrn[pipepb.StandardCoders_Enum]() + reqUrn = toUrn[pipepb.StandardRequirements_Enum]() + envUrn = toUrn[pipepb.StandardEnvironments_Environments]() +) + +var ( + // SDK transforms. + TransformParDo = ptUrn(pipepb.StandardPTransforms_PAR_DO) + TransformCombinePerKey = ctUrn(pipepb.StandardPTransforms_COMBINE_PER_KEY) + TransformPreCombine = cmbtUrn(pipepb.StandardPTransforms_COMBINE_PER_KEY_PRECOMBINE) + TransformMerge = cmbtUrn(pipepb.StandardPTransforms_COMBINE_PER_KEY_MERGE_ACCUMULATORS) + TransformExtract = cmbtUrn(pipepb.StandardPTransforms_COMBINE_PER_KEY_EXTRACT_OUTPUTS) + TransformPairWithRestriction = sdfUrn(pipepb.StandardPTransforms_PAIR_WITH_RESTRICTION) + TransformSplitAndSize = sdfUrn(pipepb.StandardPTransforms_SPLIT_AND_SIZE_RESTRICTIONS) + TransformProcessSizedElements = sdfUrn(pipepb.StandardPTransforms_PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS) + TransformTruncate = sdfUrn(pipepb.StandardPTransforms_TRUNCATE_SIZED_RESTRICTION) + + // Window Manipulation + TransformAssignWindows = ptUrn(pipepb.StandardPTransforms_ASSIGN_WINDOWS) + TransformMapWindows = ptUrn(pipepb.StandardPTransforms_MAP_WINDOWS) + TransformMergeWindows = ptUrn(pipepb.StandardPTransforms_MERGE_WINDOWS) + + // Undocumented Urns + GoDoFn = "beam:go:transform:dofn:v1" // Only used for Go DoFn. + TransformSource = "beam:runner:source:v1" // The data source reading transform. + TransformSink = "beam:runner:sink:v1" // The data sink writing transform. + + // Runner transforms. 
+	TransformImpulse = ptUrn(pipepb.StandardPTransforms_IMPULSE)
+	TransformGBK     = ptUrn(pipepb.StandardPTransforms_GROUP_BY_KEY)
+	TransformFlatten = ptUrn(pipepb.StandardPTransforms_FLATTEN)
+
+	// Side Input access patterns
+	SideInputIterable = siUrn(pipepb.StandardSideInputTypes_ITERABLE)
+	SideInputMultiMap = siUrn(pipepb.StandardSideInputTypes_MULTIMAP)
+
+	// WindowFns
+	WindowFnGlobal  = quickUrn(pipepb.GlobalWindowsPayload_PROPERTIES)
+	WindowFnFixed   = quickUrn(pipepb.FixedWindowsPayload_PROPERTIES)
+	WindowFnSliding = quickUrn(pipepb.SlidingWindowsPayload_PROPERTIES)
+	WindowFnSession = quickUrn(pipepb.SessionWindowsPayload_PROPERTIES)
+
+	// Coders
+	CoderBytes      = cdrUrn(pipepb.StandardCoders_BYTES)
+	CoderBool       = cdrUrn(pipepb.StandardCoders_BOOL)
+	CoderDouble     = cdrUrn(pipepb.StandardCoders_DOUBLE)
+	CoderStringUTF8 = cdrUrn(pipepb.StandardCoders_STRING_UTF8)
+	CoderRow        = cdrUrn(pipepb.StandardCoders_ROW)
+	CoderVarInt     = cdrUrn(pipepb.StandardCoders_VARINT)
+
+	CoderGlobalWindow   = cdrUrn(pipepb.StandardCoders_GLOBAL_WINDOW)
+	CoderIntervalWindow = cdrUrn(pipepb.StandardCoders_INTERVAL_WINDOW)
+	CoderCustomWindow   = cdrUrn(pipepb.StandardCoders_CUSTOM_WINDOW)
+
+	CoderParamWindowedValue = cdrUrn(pipepb.StandardCoders_PARAM_WINDOWED_VALUE)
+	CoderWindowedValue      = cdrUrn(pipepb.StandardCoders_WINDOWED_VALUE)
+	CoderTimer              = cdrUrn(pipepb.StandardCoders_TIMER)
+
+	CoderKV                  = cdrUrn(pipepb.StandardCoders_KV)
+	CoderLengthPrefix        = cdrUrn(pipepb.StandardCoders_LENGTH_PREFIX)
+	CoderNullable            = cdrUrn(pipepb.StandardCoders_NULLABLE)
+	CoderIterable            = cdrUrn(pipepb.StandardCoders_ITERABLE)
+	CoderStateBackedIterable = cdrUrn(pipepb.StandardCoders_STATE_BACKED_ITERABLE)
+	CoderShardedKey          = cdrUrn(pipepb.StandardCoders_SHARDED_KEY)
+
+	// Requirements
+	RequirementSplittableDoFn     = reqUrn(pipepb.StandardRequirements_REQUIRES_SPLITTABLE_DOFN)
+	RequirementBundleFinalization = reqUrn(pipepb.StandardRequirements_REQUIRES_BUNDLE_FINALIZATION)
+	RequirementOnWindowExpiration = reqUrn(pipepb.StandardRequirements_REQUIRES_ON_WINDOW_EXPIRATION)
+	RequirementStableInput        = reqUrn(pipepb.StandardRequirements_REQUIRES_STABLE_INPUT)
+	RequirementStatefulProcessing = reqUrn(pipepb.StandardRequirements_REQUIRES_STATEFUL_PROCESSING)
+	RequirementTimeSortedInput    = reqUrn(pipepb.StandardRequirements_REQUIRES_TIME_SORTED_INPUT)
+
+	// Environment types
+	EnvDocker   = envUrn(pipepb.StandardEnvironments_DOCKER)
+	EnvProcess  = envUrn(pipepb.StandardEnvironments_PROCESS)
+	EnvExternal = envUrn(pipepb.StandardEnvironments_EXTERNAL)
+	EnvDefault  = envUrn(pipepb.StandardEnvironments_DEFAULT)
+)
diff --git a/sdks/go/pkg/beam/runners/prism/internal/urns/urns_test.go b/sdks/go/pkg/beam/runners/prism/internal/urns/urns_test.go
new file mode 100644
index 0000000000000..7b553f6ad6519
--- /dev/null
+++ b/sdks/go/pkg/beam/runners/prism/internal/urns/urns_test.go
@@ -0,0 +1,36 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License.
You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package urns handles extracting urns from all the protos.
+package urns
+
+import (
+	"testing"
+
+	pipepb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v1"
+)
+
+// Test_toUrn validates that the generic urn extraction mechanism works, as it
+// is used for all the urns present.
+func Test_toUrn(t *testing.T) {
+	want := "beam:transform:pardo:v1"
+	if got := TransformParDo; got != want {
+		t.Errorf("TransformParDo = %v, want %v", got, want)
+	}
+	// Validate that quickUrn gets the same thing.
+	if got := quickUrn(pipepb.StandardPTransforms_PAR_DO); got != want {
+		t.Errorf("quickUrn(\"pipepb.StandardPTransforms_PAR_DO\") = %v, want %v", got, want)
+	}
+}
diff --git a/sdks/go/pkg/beam/runners/prism/internal/worker/bundle.go b/sdks/go/pkg/beam/runners/prism/internal/worker/bundle.go
new file mode 100644
index 0000000000000..f6fbf1293f47e
--- /dev/null
+++ b/sdks/go/pkg/beam/runners/prism/internal/worker/bundle.go
@@ -0,0 +1,114 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package worker
+
+import (
+	"sync"
+
+	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex"
+	fnpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/fnexecution_v1"
+	"github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/engine"
+	"golang.org/x/exp/slog"
+)
+
+// B represents an extant ProcessBundle instruction sent to an SDK worker.
+// Generally manipulated by another package to interact with a worker.
+type B struct {
+	InstID string // ID for the instruction processing this bundle.
+	PBDID  string // ID for the ProcessBundleDescriptor.
+
+	// InputTransformID identifies the SDK-side transform that receives this
+	// bundle's input data.
+	InputTransformID string
+	InputData        [][]byte // Data specifically for this bundle.
+
+	// TODO: change to a single map[tid] -> map[input] -> map[window] -> struct { Iter data, MultiMap data } instead of all maps.
+	// IterableSideInputData is a map from transformID, to inputID, to window, to data.
+	IterableSideInputData map[string]map[string]map[typex.Window][][]byte
+	// MultiMapSideInputData is a map from transformID, to inputID, to window, to data key, to data values.
+	MultiMapSideInputData map[string]map[string]map[typex.Window]map[string][][]byte
+
+	// OutputCount is the number of data outputs this bundle has.
+	// We need to see this many closed data channels before the bundle is complete.
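+	// For example (illustrative): a stage with two data sinks sets OutputCount
+	// to 2, and the bundle only completes once two IsLast signals arrive.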
+	OutputCount int
+	// dataWait is how we determine if a bundle is finished, by waiting for each of
+	// a Bundle's DataSinks to produce their last output.
+	// After this point we can "commit" the bundle's output for downstream use.
+	dataWait   sync.WaitGroup
+	OutputData engine.TentativeData
+	Resp       chan *fnpb.ProcessBundleResponse
+
+	SinkToPCollection map[string]string
+
+	// TODO: Metrics for this bundle, can be handled after the fact.
+}
+
+// Init initializes the bundle's internal state for waiting on all
+// data and for relaying a response back.
+func (b *B) Init() {
+	// We need to see final data signals that match the number of
+	// outputs the stage this bundle executes possesses.
+	b.dataWait.Add(b.OutputCount)
+	b.Resp = make(chan *fnpb.ProcessBundleResponse, 1)
+}
+
+func (b *B) LogValue() slog.Value {
+	return slog.GroupValue(
+		slog.String("ID", b.InstID),
+		slog.String("stage", b.PBDID))
+}
+
+// ProcessOn executes the given bundle on the given W, blocking
+// until all data is complete.
+//
+// Assumes the bundle is initialized (all maps are non-nil, the data waitgroup
+// is set, and the response channel is initialized).
+// Assumes the bundle descriptor is already registered with the W.
+//
+// While this method mostly manipulates a W, putting it on a B avoids mixing the
+// worker's public gRPC APIs up with local calls.
+func (b *B) ProcessOn(wk *W) {
+	wk.mu.Lock()
+	wk.bundles[b.InstID] = b
+	wk.mu.Unlock()
+
+	slog.Debug("processing", "bundle", b, "worker", wk)
+
+	// Tell the SDK to start processing the bundle.
+	wk.InstReqs <- &fnpb.InstructionRequest{
+		InstructionId: b.InstID,
+		Request: &fnpb.InstructionRequest_ProcessBundle{
+			ProcessBundle: &fnpb.ProcessBundleRequest{
+				ProcessBundleDescriptorId: b.PBDID,
+			},
+		},
+	}
+
+	// TODO: make batching decisions.
+	for i, d := range b.InputData {
+		wk.DataReqs <- &fnpb.Elements{
+			Data: []*fnpb.Elements_Data{
+				{
+					InstructionId: b.InstID,
+					TransformId:   b.InputTransformID,
+					Data:          d,
+					IsLast:        i+1 == len(b.InputData),
+				},
+			},
+		}
+	}
+
+	slog.Debug("waiting on data", "bundle", b)
+	b.dataWait.Wait() // Wait until data is ready.
+}
diff --git a/sdks/go/pkg/beam/runners/prism/internal/worker/bundle_test.go b/sdks/go/pkg/beam/runners/prism/internal/worker/bundle_test.go
new file mode 100644
index 0000000000000..154306c3f6ba2
--- /dev/null
+++ b/sdks/go/pkg/beam/runners/prism/internal/worker/bundle_test.go
@@ -0,0 +1,52 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package worker
+
+import (
+	"bytes"
+	"sync"
+	"testing"
+)
+
+func TestBundle_ProcessOn(t *testing.T) {
+	wk := New("test")
+	b := &B{
+		InstID:      "testInst",
+		PBDID:       "testPBDID",
+		OutputCount: 1,
+		InputData:   [][]byte{{1, 2, 3}},
+	}
+	b.Init()
+	var completed sync.WaitGroup
+	completed.Add(1)
+	go func() {
+		b.ProcessOn(wk)
+		completed.Done()
+	}()
+	b.dataWait.Done()
+	gotData := <-wk.DataReqs
+	if got, want := gotData.GetData()[0].GetData(), []byte{1, 2, 3}; !bytes.Equal(got, want) {
+		t.Errorf("ProcessOn(): data not sent; got %v, want %v", got, want)
+	}
+
+	gotInst := <-wk.InstReqs
+	if got, want := gotInst.GetInstructionId(), b.InstID; got != want {
+		t.Errorf("ProcessOn(): bad instruction ID; got %v, want %v", got, want)
+	}
+	if got, want := gotInst.GetProcessBundle().GetProcessBundleDescriptorId(), b.PBDID; got != want {
+		t.Errorf("ProcessOn(): bad process bundle descriptor ID; got %v, want %v", got, want)
+	}
+}
diff --git a/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go b/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go
new file mode 100644
index 0000000000000..8458ce39e1168
--- /dev/null
+++ b/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go
@@ -0,0 +1,421 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package worker handles interactions with SDK-side workers: representing
+// the worker services, communicating with those services, and managing SDK
+// environments.
+package worker
+
+import (
+	"bytes"
+	"context"
+	"fmt"
+	"io"
+	"net"
+	"sync"
+	"sync/atomic"
+
+	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/coder"
+	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/window"
+	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/exec"
+	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex"
+	fnpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/fnexecution_v1"
+	"github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/engine"
+	"golang.org/x/exp/slog"
+	"google.golang.org/grpc"
+	"google.golang.org/grpc/codes"
+	"google.golang.org/grpc/status"
+	"google.golang.org/protobuf/encoding/prototext"
+)
+
+// A W manages worker environments, sending them work
+// that they're able to execute, and manages the server-side
+// handlers for FnAPI RPCs.
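+//
+// A minimal lifecycle sketch (illustrative; "test" is an arbitrary worker ID):
+//
+//	wk := New("test") // starts listening on a free local TCP port
+//	go wk.Serve()     // serve the FnAPI gRPC services
+//	defer wk.Stop()   // close request channels and stop the server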
+type W struct {
+	fnpb.UnimplementedBeamFnControlServer
+	fnpb.UnimplementedBeamFnDataServer
+	fnpb.UnimplementedBeamFnStateServer
+	fnpb.UnimplementedBeamFnLoggingServer
+
+	ID string
+
+	// Server management
+	lis    net.Listener
+	server *grpc.Server
+
+	// These are the ID sources.
+	inst, bund uint64
+
+	InstReqs chan *fnpb.InstructionRequest
+	DataReqs chan *fnpb.Elements
+
+	mu          sync.Mutex
+	bundles     map[string]*B                            // Bundles keyed by InstructionID
+	Descriptors map[string]*fnpb.ProcessBundleDescriptor // Stages keyed by PBDID
+
+	D *DataService
+}
+
+// New starts the worker server components of FnAPI Execution.
+func New(id string) *W {
+	lis, err := net.Listen("tcp", ":0")
+	if err != nil {
+		panic(fmt.Sprintf("failed to listen: %v", err))
+	}
+	var opts []grpc.ServerOption
+	wk := &W{
+		ID:     id,
+		lis:    lis,
+		server: grpc.NewServer(opts...),
+
+		InstReqs: make(chan *fnpb.InstructionRequest, 10),
+		DataReqs: make(chan *fnpb.Elements, 10),
+
+		bundles:     make(map[string]*B),
+		Descriptors: make(map[string]*fnpb.ProcessBundleDescriptor),
+
+		D: &DataService{},
+	}
+	slog.Info("Serving Worker components", slog.String("endpoint", wk.Endpoint()))
+	fnpb.RegisterBeamFnControlServer(wk.server, wk)
+	fnpb.RegisterBeamFnDataServer(wk.server, wk)
+	fnpb.RegisterBeamFnLoggingServer(wk.server, wk)
+	fnpb.RegisterBeamFnStateServer(wk.server, wk)
+	return wk
+}
+
+// Endpoint returns the address the worker's gRPC services are listening on.
+func (wk *W) Endpoint() string {
+	return wk.lis.Addr().String()
+}
+
+// Serve serves on the started listener. Blocks.
+func (wk *W) Serve() {
+	wk.server.Serve(wk.lis)
+}
+
+func (wk *W) String() string {
+	return "worker[" + wk.ID + "]"
+}
+
+func (wk *W) LogValue() slog.Value {
+	return slog.GroupValue(
+		slog.String("ID", wk.ID),
+		slog.String("endpoint", wk.Endpoint()),
+	)
+}
+
+// Stop closes the request channels and stops the gRPC server.
+func (wk *W) Stop() {
+	slog.Debug("stopping", "worker", wk)
+	close(wk.InstReqs)
+	close(wk.DataReqs)
+	wk.server.Stop()
+	wk.lis.Close()
+	slog.Debug("stopped", "worker", wk)
+}
+
+// NextInst returns the next unique instruction ID for this worker.
+func (wk *W) NextInst() string {
+	return fmt.Sprintf("inst%03d", atomic.AddUint64(&wk.inst, 1))
+}
+
+// NextStage returns the next unique stage ID for this worker.
+func (wk *W) NextStage() string {
+	return fmt.Sprintf("stage%03d", atomic.AddUint64(&wk.bund, 1))
+}
+
+// TODO: set logging level.
+var minsev = fnpb.LogEntry_Severity_DEBUG
+
+// Logging relates SDK worker messages back to the job that spawned them.
+// Messages are received from the SDK and, for now, logged locally.
+func (wk *W) Logging(stream fnpb.BeamFnLogging_LoggingServer) error {
+	for {
+		in, err := stream.Recv()
+		if err == io.EOF {
+			return nil
+		}
+		if err != nil {
+			slog.Error("logging.Recv", err, "worker", wk)
+			return err
+		}
+		for _, l := range in.GetLogEntries() {
+			if l.Severity >= minsev {
+				// TODO: Connect to the associated Job for this worker instead of
+				// logging locally for SDK side logging.
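+				// Map the FnAPI severity onto an slog level (see toSlogSev
+				// below) before emitting the entry locally.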
+				slog.Log(toSlogSev(l.GetSeverity()), l.GetMessage(),
+					slog.String(slog.SourceKey, l.GetLogLocation()),
+					slog.Time(slog.TimeKey, l.GetTimestamp().AsTime()),
+					"worker", wk,
+				)
+			}
+		}
+	}
+}
+
+// toSlogSev maps FnAPI logging severities onto slog levels.
+func toSlogSev(sev fnpb.LogEntry_Severity_Enum) slog.Level {
+	switch sev {
+	case fnpb.LogEntry_Severity_TRACE:
+		return slog.Level(-8)
+	case fnpb.LogEntry_Severity_DEBUG:
+		return slog.LevelDebug // -4
+	case fnpb.LogEntry_Severity_INFO:
+		return slog.LevelInfo // 0
+	case fnpb.LogEntry_Severity_NOTICE:
+		return slog.Level(2)
+	case fnpb.LogEntry_Severity_WARN:
+		return slog.LevelWarn // 4
+	case fnpb.LogEntry_Severity_ERROR:
+		return slog.LevelError // 8
+	case fnpb.LogEntry_Severity_CRITICAL:
+		return slog.Level(10)
+	}
+	return slog.LevelInfo
+}
+
+// GetProcessBundleDescriptor returns the requested descriptor, if it is registered with this worker.
+func (wk *W) GetProcessBundleDescriptor(ctx context.Context, req *fnpb.GetProcessBundleDescriptorRequest) (*fnpb.ProcessBundleDescriptor, error) {
+	desc, ok := wk.Descriptors[req.GetProcessBundleDescriptorId()]
+	if !ok {
+		return nil, fmt.Errorf("descriptor %v not found", req.GetProcessBundleDescriptorId())
+	}
+	return desc, nil
+}
+
+// Control relays instructions to SDKs and back again, coordinated via unique instructionIDs.
+//
+// Requests come from the runner, and are sent to the client in the SDK.
+func (wk *W) Control(ctrl fnpb.BeamFnControl_ControlServer) error {
+	done := make(chan bool)
+	go func() {
+		for {
+			resp, err := ctrl.Recv()
+			if err == io.EOF {
+				slog.Debug("ctrl.Recv finished; marking done", "worker", wk)
+				done <- true // means stream is finished
+				return
+			}
+			if err != nil {
+				switch status.Code(err) {
+				case codes.Canceled: // Might ignore this all the time instead.
+					slog.Error("ctrl.Recv Canceled", err, "worker", wk)
+					done <- true // means stream is finished
+					return
+				default:
+					slog.Error("ctrl.Recv failed", err, "worker", wk)
+					panic(err)
+				}
+			}
+
+			// TODO: Do more than assume these are ProcessBundleResponses.
+			wk.mu.Lock()
+			if b, ok := wk.bundles[resp.GetInstructionId()]; ok {
+				// TODO: Better pipeline error handling.
+				if resp.Error != "" {
+					slog.Log(slog.LevelError, "ctrl.Recv pipeline error", slog.ErrorKey, resp.GetError())
+					panic(resp.GetError())
+				}
+				b.Resp <- resp.GetProcessBundle()
+			} else {
+				slog.Debug("ctrl.Recv for unknown instruction", "response", resp)
+			}
+			wk.mu.Unlock()
+		}
+	}()
+
+	for req := range wk.InstReqs {
+		ctrl.Send(req)
+	}
+	slog.Debug("ctrl.Send finished waiting on done")
+	<-done
+	slog.Debug("Control done")
+	return nil
+}
+
+// Data relays elements and timer bytes to SDKs and back again, coordinated via
+// ProcessBundle instructionIDs and the receiving input transforms.
+//
+// Data is multiplexed on a single stream for all active bundles on a worker.
+func (wk *W) Data(data fnpb.BeamFnData_DataServer) error {
+	go func() {
+		for {
+			resp, err := data.Recv()
+			if err == io.EOF {
+				return
+			}
+			if err != nil {
+				switch status.Code(err) {
+				case codes.Canceled:
+					slog.Error("data.Recv Canceled", err, "worker", wk)
+					return
+				default:
+					slog.Error("data.Recv failed", err, "worker", wk)
+					panic(err)
+				}
+			}
+			wk.mu.Lock()
+			for _, d := range resp.GetData() {
+				b, ok := wk.bundles[d.GetInstructionId()]
+				if !ok {
+					slog.Info("data.Recv for unknown bundle", "response", resp)
+					continue
+				}
+				colID := b.SinkToPCollection[d.GetTransformId()]
+
+				// There might not be data, e.g. for side inputs, so we need to reconcile this elsewhere for
+				// downstream side inputs.
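+				// Only write non-empty payloads; an IsLast signal still counts
+				// toward bundle completion either way.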
+				if len(d.GetData()) > 0 {
+					b.OutputData.WriteData(colID, d.GetData())
+				}
+				if d.GetIsLast() {
+					b.dataWait.Done()
+				}
+			}
+			wk.mu.Unlock()
+		}
+	}()
+
+	for req := range wk.DataReqs {
+		if err := data.Send(req); err != nil {
+			slog.Log(slog.LevelDebug, "data.Send error", slog.ErrorKey, err)
+		}
+	}
+	return nil
+}
+
+// State relays state requests and responses between SDKs and the runner,
+// coordinated via ProcessBundle instructionIDs.
+//
+// State requests come from SDKs, and the runner responds.
+func (wk *W) State(state fnpb.BeamFnState_StateServer) error {
+	responses := make(chan *fnpb.StateResponse)
+	go func() {
+		// This goroutine creates all responses to state requests from the worker,
+		// so we want to close the State handler when it's all done.
+		defer close(responses)
+		for {
+			req, err := state.Recv()
+			if err == io.EOF {
+				return
+			}
+			if err != nil {
+				switch status.Code(err) {
+				case codes.Canceled:
+					slog.Error("state.Recv Canceled", err, "worker", wk)
+					return
+				default:
+					slog.Error("state.Recv failed", err, "worker", wk)
+					panic(err)
+				}
+			}
+			switch req.GetRequest().(type) {
+			case *fnpb.StateRequest_Get:
+				// TODO: move data handling to be pcollection based.
+				b := wk.bundles[req.GetInstructionId()]
+				key := req.GetStateKey()
+				slog.Debug("StateRequest_Get", "request", prototext.Format(req), "bundle", b)
+
+				var data [][]byte
+				switch key.GetType().(type) {
+				case *fnpb.StateKey_IterableSideInput_:
+					ikey := key.GetIterableSideInput()
+					wKey := ikey.GetWindow()
+					var w typex.Window
+					if len(wKey) == 0 {
+						w = window.GlobalWindow{}
+					} else {
+						w, err = exec.MakeWindowDecoder(coder.NewIntervalWindow()).DecodeSingle(bytes.NewBuffer(wKey))
+						if err != nil {
+							panic(fmt.Sprintf("error decoding iterable side input window key %v: %v", wKey, err))
+						}
+					}
+					winMap := b.IterableSideInputData[ikey.GetTransformId()][ikey.GetSideInputId()]
+					var wins []typex.Window
+					for w := range winMap {
+						wins = append(wins, w)
+					}
+					slog.Debug(fmt.Sprintf("side input[%v][%v] I Key: %v Windows: %v", req.GetId(), req.GetInstructionId(), w, wins))
+					data = winMap[w]
+
+				case *fnpb.StateKey_MultimapSideInput_:
+					mmkey := key.GetMultimapSideInput()
+					wKey := mmkey.GetWindow()
+					var w typex.Window
+					if len(wKey) == 0 {
+						w = window.GlobalWindow{}
+					} else {
+						w, err = exec.MakeWindowDecoder(coder.NewIntervalWindow()).DecodeSingle(bytes.NewBuffer(wKey))
+						if err != nil {
+							panic(fmt.Sprintf("error decoding multimap side input window key %v: %v", wKey, err))
+						}
+					}
+					dKey := mmkey.GetKey()
+					winMap := b.MultiMapSideInputData[mmkey.GetTransformId()][mmkey.GetSideInputId()]
+					var wins []typex.Window
+					for w := range winMap {
+						wins = append(wins, w)
+					}
+					slog.Debug(fmt.Sprintf("side input[%v][%v] MM Key: %v Windows: %v", req.GetId(), req.GetInstructionId(), w, wins))
+
+					data = winMap[w][string(dKey)]
+
+				default:
+					panic(fmt.Sprintf("unsupported StateKey Access type: %T: %v", key.GetType(), prototext.Format(key)))
+				}
+
+				// Encode the runner iterable (no length, just consecutive elements), and send it out.
+				// This is also where we can handle things like State Backed Iterables.
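+				// For example (illustrative): element payloads [0x01] and
+				// [0x02, 0x03] are sent as the concatenation 0x01 0x02 0x03;
+				// the SDK's element coder decodes values until the buffer is
+				// exhausted.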
+				var buf bytes.Buffer
+				for _, value := range data {
+					buf.Write(value)
+				}
+				responses <- &fnpb.StateResponse{
+					Id: req.GetId(),
+					Response: &fnpb.StateResponse_Get{
+						Get: &fnpb.StateGetResponse{
+							Data: buf.Bytes(),
+						},
+					},
+				}
+			default:
+				panic(fmt.Sprintf("unsupported StateRequest kind %T: %v", req.GetRequest(), prototext.Format(req)))
+			}
+		}
+	}()
+	for resp := range responses {
+		if err := state.Send(resp); err != nil {
+			slog.Error("state.Send error", err)
+		}
+	}
+	return nil
+}
+
+// DataService is slated to be deleted in favour of stage-based state
+// management for side inputs.
+type DataService struct {
+	// TODO: actually process the data into windows here as well.
+	raw map[string][][]byte
+}
+
+// Commit tentative data to the datastore.
+func (d *DataService) Commit(tent engine.TentativeData) {
+	if d.raw == nil {
+		d.raw = map[string][][]byte{}
+	}
+	for colID, data := range tent.Raw {
+		d.raw[colID] = append(d.raw[colID], data...)
+	}
+}
+
+// GetAllData is a hack for Side Inputs until watermarks are sorted out.
+func (d *DataService) GetAllData(colID string) [][]byte {
+	return d.raw[colID]
+}
diff --git a/sdks/go/pkg/beam/runners/prism/internal/worker/worker_test.go b/sdks/go/pkg/beam/runners/prism/internal/worker/worker_test.go
new file mode 100644
index 0000000000000..29b3fab92d648
--- /dev/null
+++ b/sdks/go/pkg/beam/runners/prism/internal/worker/worker_test.go
@@ -0,0 +1,281 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +package worker + +import ( + "bytes" + "context" + "net" + "sync" + "testing" + "time" + + "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/window" + "github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex" + fnpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/fnexecution_v1" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/test/bufconn" +) + +func TestWorker_New(t *testing.T) { + w := New("test") + if got, want := w.ID, "test"; got != want { + t.Errorf("New(%q) = %v, want %v", want, got, want) + } +} + +func TestWorker_NextInst(t *testing.T) { + w := New("test") + + instIDs := map[string]struct{}{} + for i := 0; i < 100; i++ { + instIDs[w.NextInst()] = struct{}{} + } + if got, want := len(instIDs), 100; got != want { + t.Errorf("calling w.NextInst() got %v unique ids, want %v", got, want) + } +} + +func TestWorker_NextStage(t *testing.T) { + w := New("test") + + stageIDs := map[string]struct{}{} + for i := 0; i < 100; i++ { + stageIDs[w.NextStage()] = struct{}{} + } + if got, want := len(stageIDs), 100; got != want { + t.Errorf("calling w.NextStage() got %v unique ids, want %v", got, want) + } +} + +func TestWorker_GetProcessBundleDescriptor(t *testing.T) { + w := New("test") + + id := "available" + w.Descriptors[id] = &fnpb.ProcessBundleDescriptor{ + Id: id, + } + + pbd, err := w.GetProcessBundleDescriptor(context.Background(), &fnpb.GetProcessBundleDescriptorRequest{ + ProcessBundleDescriptorId: id, + }) + if err != nil { + t.Errorf("got GetProcessBundleDescriptor(%q) error: %v, want nil", id, err) + } + if got, want := pbd.GetId(), id; got != want { + t.Errorf("got GetProcessBundleDescriptor(%q) = %v, want id %v", id, got, want) + } + + pbd, err = w.GetProcessBundleDescriptor(context.Background(), &fnpb.GetProcessBundleDescriptorRequest{ + ProcessBundleDescriptorId: "unknown", + }) + if err == nil { + t.Errorf("got GetProcessBundleDescriptor(%q) = %v, want error", "unknown", pbd) + } +} + +func serveTestWorker(t *testing.T) (context.Context, *W, *grpc.ClientConn) { + t.Helper() + ctx, cancelFn := context.WithCancel(context.Background()) + t.Cleanup(cancelFn) + + w := New("test") + lis := bufconn.Listen(2048) + w.lis = lis + t.Cleanup(func() { w.Stop() }) + go w.Serve() + + clientConn, err := grpc.DialContext(ctx, "", grpc.WithContextDialer(func(ctx context.Context, _ string) (net.Conn, error) { + return lis.DialContext(ctx) + }), grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithBlock()) + if err != nil { + t.Fatal("couldn't create bufconn grpc connection:", err) + } + return ctx, w, clientConn +} + +func TestWorker_Logging(t *testing.T) { + ctx, _, clientConn := serveTestWorker(t) + + logCli := fnpb.NewBeamFnLoggingClient(clientConn) + logStream, err := logCli.Logging(ctx) + if err != nil { + t.Fatal("couldn't create log client:", err) + } + + logStream.Send(&fnpb.LogEntry_List{ + LogEntries: []*fnpb.LogEntry{{ + Severity: fnpb.LogEntry_Severity_INFO, + Message: "squeamish ossiphrage", + }}, + }) + + // TODO: Connect to the job management service. + // At this point job messages are just logged to wherever the prism runner executes + // But this should pivot to anyone connecting to the Job Management service for the + // job. + // In the meantime, sleep to validate execution via coverage. 
+ time.Sleep(20 * time.Millisecond) +} + +func TestWorker_Control_HappyPath(t *testing.T) { + ctx, wk, clientConn := serveTestWorker(t) + + ctrlCli := fnpb.NewBeamFnControlClient(clientConn) + ctrlStream, err := ctrlCli.Control(ctx) + if err != nil { + t.Fatal("couldn't create control client:", err) + } + + instID := wk.NextInst() + + b := &B{} + b.Init() + wk.bundles[instID] = b + b.ProcessOn(wk) + + ctrlStream.Send(&fnpb.InstructionResponse{ + InstructionId: instID, + Response: &fnpb.InstructionResponse_ProcessBundle{ + ProcessBundle: &fnpb.ProcessBundleResponse{ + RequiresFinalization: true, // Simple thing to check. + }, + }, + }) + + if err := ctrlStream.CloseSend(); err != nil { + t.Errorf("ctrlStream.CloseSend() = %v", err) + } + resp := <-b.Resp + + if !resp.RequiresFinalization { + t.Errorf("got %v, want response that Requires Finalization", resp) + } +} + +func TestWorker_Data_HappyPath(t *testing.T) { + ctx, wk, clientConn := serveTestWorker(t) + + dataCli := fnpb.NewBeamFnDataClient(clientConn) + dataStream, err := dataCli.Data(ctx) + if err != nil { + t.Fatal("couldn't create data client:", err) + } + + instID := wk.NextInst() + + b := &B{ + InstID: instID, + PBDID: wk.NextStage(), + InputData: [][]byte{ + {1, 1, 1, 1, 1, 1}, + }, + OutputCount: 1, + } + b.Init() + wk.bundles[instID] = b + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + b.ProcessOn(wk) + }() + + wk.InstReqs <- &fnpb.InstructionRequest{ + InstructionId: instID, + } + + elements, err := dataStream.Recv() + if err != nil { + t.Fatal("couldn't receive data elements:", err) + } + + if got, want := elements.GetData()[0].GetInstructionId(), b.InstID; got != want { + t.Fatalf("couldn't receive data elements ID: got %v, want %v", got, want) + } + if got, want := elements.GetData()[0].GetData(), []byte{1, 1, 1, 1, 1, 1}; !bytes.Equal(got, want) { + t.Fatalf("client Data received %v, want %v", got, want) + } + if got, want := elements.GetData()[0].GetIsLast(), true; got != want { + t.Fatalf("client Data received wasn't last: got %v, want %v", got, want) + } + + dataStream.Send(elements) + + if err := dataStream.CloseSend(); err != nil { + t.Errorf("ctrlStream.CloseSend() = %v", err) + } + + wg.Wait() + t.Log("ProcessOn successfully exited") +} + +func TestWorker_State_Iterable(t *testing.T) { + ctx, wk, clientConn := serveTestWorker(t) + + stateCli := fnpb.NewBeamFnStateClient(clientConn) + stateStream, err := stateCli.State(ctx) + if err != nil { + t.Fatal("couldn't create state client:", err) + } + + instID := wk.NextInst() + wk.bundles[instID] = &B{ + IterableSideInputData: map[string]map[string]map[typex.Window][][]byte{ + "transformID": { + "i1": { + window.GlobalWindow{}: [][]byte{ + {42}, + }, + }, + }, + }, + } + + stateStream.Send(&fnpb.StateRequest{ + Id: "first", + InstructionId: instID, + Request: &fnpb.StateRequest_Get{ + Get: &fnpb.StateGetRequest{}, + }, + StateKey: &fnpb.StateKey{Type: &fnpb.StateKey_IterableSideInput_{ + IterableSideInput: &fnpb.StateKey_IterableSideInput{ + TransformId: "transformID", + SideInputId: "i1", + Window: []byte{}, // Global Windows + }, + }}, + }) + + resp, err := stateStream.Recv() + if err != nil { + t.Fatal("couldn't receive state response:", err) + } + + if got, want := resp.GetId(), "first"; got != want { + t.Fatalf("didn't receive expected state response: got %v, want %v", got, want) + } + + if got, want := resp.GetGet().GetData(), []byte{42}; !bytes.Equal(got, want) { + t.Fatalf("didn't receive expected state response data: got %v, want %v", got, 
want)
+	}
+
+	if err := stateStream.CloseSend(); err != nil {
+		t.Errorf("stateStream.CloseSend() = %v", err)
+	}
+}
diff --git a/sdks/go/pkg/beam/runners/prism/prism.go b/sdks/go/pkg/beam/runners/prism/prism.go
new file mode 100644
index 0000000000000..dc78e5e6c2307
--- /dev/null
+++ b/sdks/go/pkg/beam/runners/prism/prism.go
@@ -0,0 +1,48 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package prism contains a local runner for running
+// pipelines in the current process. Useful for testing.
+package prism

+import (
+	"context"
+
+	"github.com/apache/beam/sdks/v2/go/pkg/beam"
+	"github.com/apache/beam/sdks/v2/go/pkg/beam/options/jobopts"
+	"github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal"
+	"github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/jobservices"
+	"github.com/apache/beam/sdks/v2/go/pkg/beam/runners/universal"
+)
+
+func init() {
+	beam.RegisterRunner("prism", Execute)
+	beam.RegisterRunner("PrismRunner", Execute)
+}
+
+// Execute runs the given pipeline on prism. If no job service endpoint is
+// configured, an in-process job server is started and reused by subsequent
+// pipelines.
+func Execute(ctx context.Context, p *beam.Pipeline) (beam.PipelineResult, error) {
+	if *jobopts.Endpoint == "" {
+		// One hasn't been selected, so let's start one up and set the address.
+		// Conveniently, this means that if multiple pipelines are executed against
+		// the local runner, they will all use the same server.
+		s := jobservices.NewServer(0, internal.RunPipeline)
+		*jobopts.Endpoint = s.Endpoint()
+		go s.Serve()
+	}
+	if !jobopts.IsLoopback() {
+		*jobopts.EnvironmentType = "loopback"
+	}
+	return universal.Execute(ctx, p)
+}