Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[prism] Fix top for unfused execution. Move to register. #27585

Merged
merged 2 commits into from
Jul 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 7 additions & 8 deletions sdks/go/pkg/beam/transforms/top/top.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,11 @@ import (
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/util/reflectx"
"github.com/apache/beam/sdks/v2/go/pkg/beam/internal/errors"
"github.com/apache/beam/sdks/v2/go/pkg/beam/register"
)

//go:generate go install github.com/apache/beam/sdks/v2/go/cmd/starcgen
//go:generate starcgen --package=top
//go:generate go fmt

// init registers the combineFn type at package load time: once through
// beam.RegisterDoFn using its reflected pointer type, and once through
// register.Combiner3 parameterized with accum, beam.T and []beam.T.
// NOTE(review): both calls are shown because this is a diff rendering —
// presumably only one of them is present in the final file; confirm.
func init() {
beam.RegisterDoFn(reflect.TypeOf((*combineFn)(nil)))
register.Combiner3[accum, beam.T, []beam.T]((*combineFn)(nil))
}

var (
Expand Down Expand Up @@ -157,18 +154,20 @@ func accumEnc() func(accum) ([]byte, error) {
panic(err)
}
return func(a accum) ([]byte, error) {
if a.enc == nil {
return nil, errors.Errorf("top.accum: element encoder unspecified")
if len(a.list) > 0 && a.enc == nil {
return nil, errors.Errorf("top.accum: element encoder unspecified with non-zero elements: %v data available", len(a.data))
}
var values [][]byte
if len(a.list) == 0 && len(a.data) > 0 {
values = a.data
}
Comment on lines +161 to +163
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

More of an understanding question: could a.data also be set to nil, the same way we do for a.list after this if block?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a very good question.

My instinct is that since this function is supposed to write and encode data, it shouldn't be mutating the value at all.

Which means technically it shouldn't be nil'ing out that list field at all either.

In principle, the encoding will only happen when emitting the values downstream, which should only happen once per given value, at which point the value itself should be garbage collected away anyway.

So I'm going to do the opposite: remove the line that sets the list to nil during encoding.

for _, value := range a.list {
var buf bytes.Buffer
if err := a.enc.Encode(value, &buf); err != nil {
return nil, errors.WithContextf(err, "top.accum: marshalling %v", value)
}
values = append(values, buf.Bytes())
}
a.list = nil

var buf bytes.Buffer
if err := coder.WriteSimpleRowHeader(1, &buf); err != nil {
Expand Down
185 changes: 0 additions & 185 deletions sdks/go/pkg/beam/transforms/top/top.shims.go

This file was deleted.

54 changes: 26 additions & 28 deletions sdks/go/pkg/beam/transforms/top/top_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,33 @@ import (

"github.com/apache/beam/sdks/v2/go/pkg/beam"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/util/reflectx"
"github.com/apache/beam/sdks/v2/go/pkg/beam/register"
"github.com/apache/beam/sdks/v2/go/pkg/beam/testing/passert"
"github.com/apache/beam/sdks/v2/go/pkg/beam/testing/ptest"
)

// TestMain is the test binary's entry point; it hands the *testing.M over to
// ptest.Main rather than running the tests directly.
func TestMain(m *testing.M) {
ptest.Main(m)
}

// init registers the named helper functions used by the tests in this file
// with the register package: addKeyFn (two inputs, two outputs) and the two
// comparators lessInt and shorterString (two inputs, one output each).
func init() {
register.Function2x2(addKeyFn)
register.Function2x1(lessInt)
register.Function2x1(shorterString)
}

// lessInt reports whether a is strictly smaller than b. Equal values
// yield false.
func lessInt(a, b int) bool {
	if a >= b {
		return false
	}
	return true
}

// shorterString reports whether a has strictly fewer bytes than b.
// Only the lengths are compared; the contents of the strings are ignored.
func shorterString(a, b string) bool {
	lenA, lenB := len(a), len(b)
	return lenB > lenA
}

// TestCombineFn3String verifies that the accumulator correctly
// maintains the top 3 longest strings.
func TestCombineFn3String(t *testing.T) {
less := func(a, b string) bool {
return len(a) < len(b)
}
fn := newCombineFn(less, 3, reflectx.String, false)
fn := newCombineFn(shorterString, 3, reflectx.String, false)

tests := []struct {
Elms []string
Expand All @@ -57,10 +73,7 @@ func TestCombineFn3String(t *testing.T) {
// TestCombineFn3RevString verifies that the accumulator correctly
// maintains the top 3 shortest strings.
func TestCombineFn3RevString(t *testing.T) {
less := func(a, b string) bool {
return len(a) < len(b)
}
fn := newCombineFn(less, 3, reflectx.String, true)
fn := newCombineFn(shorterString, 3, reflectx.String, true)

tests := []struct {
Elms []string
Expand All @@ -86,10 +99,7 @@ func TestCombineFn3RevString(t *testing.T) {
// extractOutput still works on the marshalled accumulators it receives after
// merging.
func TestCombineFnMerge(t *testing.T) {
less := func(a, b string) bool {
return len(a) < len(b)
}
fn := newCombineFn(less, 3, reflectx.String, false)
fn := newCombineFn(shorterString, 3, reflectx.String, false)
tests := []struct {
Elms [][]string
Expected []string
Expand Down Expand Up @@ -170,12 +180,9 @@ func output(fn *combineFn, a accum) []string {
// TestLargest checks that the Largest transform outputs the correct elements
// for a given PCollection of ints and a comparator function.
func TestLargest(t *testing.T) {
less := func(a, b int) bool {
return a < b
}
p, s := beam.NewPipelineWithRoot()
col := beam.Create(s, 1, 11, 7, 5, 10)
topTwo := Largest(s, col, 2, less)
topTwo := Largest(s, col, 2, lessInt)
passert.Equals(s, topTwo, []int{11, 10})
if err := ptest.Run(p); err != nil {
t.Errorf("pipeline failed but should have succeeded, got %v", err)
Expand All @@ -185,12 +192,9 @@ func TestLargest(t *testing.T) {
// TestSmallest checks that the Smallest transform outputs the correct elements
// for a given PCollection of ints and a comparator function.
func TestSmallest(t *testing.T) {
less := func(a, b int) bool {
return a < b
}
p, s := beam.NewPipelineWithRoot()
col := beam.Create(s, 1, 11, 7, 5, 10)
botTwo := Smallest(s, col, 2, less)
botTwo := Smallest(s, col, 2, lessInt)
passert.Equals(s, botTwo, []int{1, 5})
if err := ptest.Run(p); err != nil {
t.Errorf("pipeline failed but should have succeeded, got %v", err)
Expand All @@ -209,9 +213,6 @@ func addKeyFn(elm beam.T, newKey int) (int, beam.T) {
// TestLargestPerKey ensures that the LargestPerKey transform outputs the proper
// collection for a PCollection of type <int, int>.
func TestLargestPerKey(t *testing.T) {
less := func(a, b int) bool {
return a < b
}
p, s := beam.NewPipelineWithRoot()
colZero := beam.Create(s, 1, 11, 7, 5, 10)
keyedZero := addKey(s, colZero, 0)
Expand All @@ -220,7 +221,7 @@ func TestLargestPerKey(t *testing.T) {
keyedOne := addKey(s, colOne, 1)

col := beam.Flatten(s, keyedZero, keyedOne)
top := LargestPerKey(s, col, 2, less)
top := LargestPerKey(s, col, 2, lessInt)
out := beam.DropKey(s, top)
passert.Equals(s, out, []int{11, 10}, []int{12, 11})
if err := ptest.Run(p); err != nil {
Expand All @@ -231,9 +232,6 @@ func TestLargestPerKey(t *testing.T) {
// TestSmallestPerKey ensures that the SmallestPerKey transform outputs the proper
// collection for a PCollection of type <int, int>.
func TestSmallestPerKey(t *testing.T) {
less := func(a, b int) bool {
return a < b
}
p, s := beam.NewPipelineWithRoot()
colZero := beam.Create(s, 1, 11, 7, 5, 10)
keyedZero := addKey(s, colZero, 0)
Expand All @@ -242,7 +240,7 @@ func TestSmallestPerKey(t *testing.T) {
keyedOne := addKey(s, colOne, 1)

col := beam.Flatten(s, keyedZero, keyedOne)
bot := SmallestPerKey(s, col, 2, less)
bot := SmallestPerKey(s, col, 2, lessInt)
out := beam.DropKey(s, bot)
passert.Equals(s, out, []int{1, 5}, []int{2, 6})
if err := ptest.Run(p); err != nil {
Expand Down