diff --git a/go/ast/inspector/inspector.go b/go/ast/inspector/inspector.go index 1fc1de0bd10..0e0ba4c035c 100644 --- a/go/ast/inspector/inspector.go +++ b/go/ast/inspector/inspector.go @@ -73,6 +73,15 @@ func (in *Inspector) Preorder(types []ast.Node, f func(ast.Node)) { // check, Preorder is almost twice as fast as Nodes. The two // features seem to contribute similar slowdowns (~1.4x each). + // This function is equivalent to the PreorderSeq call below, + // but to avoid the additional dynamic call (which adds 13-35% + // to the benchmarks), we expand it out. + // + // in.PreorderSeq(types...)(func(n ast.Node) bool { + // f(n) + // return true + // }) + mask := maskOf(types) for i := 0; i < len(in.events); { ev := in.events[i] diff --git a/go/ast/inspector/inspector_test.go b/go/ast/inspector/inspector_test.go index 5d7cb6e44eb..a19ba653e0a 100644 --- a/go/ast/inspector/inspector_test.go +++ b/go/ast/inspector/inspector_test.go @@ -160,7 +160,8 @@ func TestInspectPruning(t *testing.T) { compare(t, nodesA, nodesB) } -func compare(t *testing.T, nodesA, nodesB []ast.Node) { +// compare calls t.Error if !slices.Equal(nodesA, nodesB). +func compare[N comparable](t *testing.T, nodesA, nodesB []N) { if len(nodesA) != len(nodesB) { t.Errorf("inconsistent node lists: %d vs %d", len(nodesA), len(nodesB)) } else { diff --git a/go/ast/inspector/iter.go b/go/ast/inspector/iter.go new file mode 100644 index 00000000000..b7e959114cb --- /dev/null +++ b/go/ast/inspector/iter.go @@ -0,0 +1,85 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.23 + +package inspector + +import ( + "go/ast" + "iter" +) + +// PreorderSeq returns an iterator that visits all the +// nodes of the files supplied to New in depth-first order. +// It visits each node n before n's children. +// The complete traversal sequence is determined by ast.Inspect. +// +// The types argument, if non-empty, enables type-based +// filtering of events: only nodes whose type matches an +// element of the types slice are included in the sequence. +func (in *Inspector) PreorderSeq(types ...ast.Node) iter.Seq[ast.Node] { + + // This implementation is identical to Preorder, + // except that it supports breaking out of the loop. + + return func(yield func(ast.Node) bool) { + mask := maskOf(types) + for i := 0; i < len(in.events); { + ev := in.events[i] + if ev.index > i { + // push + if ev.typ&mask != 0 { + if !yield(ev.node) { + break + } + } + pop := ev.index + if in.events[pop].typ&mask == 0 { + // Subtrees do not contain types: skip them and pop. + i = pop + 1 + continue + } + } + i++ + } + } +} + +// All[N] returns an iterator over all the nodes of type N. +// N must be a pointer-to-struct type that implements ast.Node. +// +// Example: +// +// for call := range All[*ast.CallExpr](in) { ... } +func All[N interface { + *S + ast.Node +}, S any](in *Inspector) iter.Seq[N] { + + // To avoid additional dynamic call overheads, + // we duplicate rather than call the logic of PreorderSeq. + + mask := typeOf((N)(nil)) + return func(yield func(N) bool) { + for i := 0; i < len(in.events); { + ev := in.events[i] + if ev.index > i { + // push + if ev.typ&mask != 0 { + if !yield(ev.node.(N)) { + break + } + } + pop := ev.index + if in.events[pop].typ&mask == 0 { + // Subtrees do not contain types: skip them and pop. + i = pop + 1 + continue + } + } + i++ + } + } +} diff --git a/go/ast/inspector/iter_test.go b/go/ast/inspector/iter_test.go new file mode 100644 index 00000000000..2f52998c558 --- /dev/null +++ b/go/ast/inspector/iter_test.go @@ -0,0 +1,83 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.23 + +package inspector_test + +import ( + "go/ast" + "iter" + "slices" + "testing" + + "golang.org/x/tools/go/ast/inspector" +) + +// TestPreorderSeq checks PreorderSeq against Preorder. +func TestPreorderSeq(t *testing.T) { + inspect := inspector.New(netFiles) + + nodeFilter := []ast.Node{(*ast.FuncDecl)(nil), (*ast.FuncLit)(nil)} + + // reference implementation + var want []ast.Node + inspect.Preorder(nodeFilter, func(n ast.Node) { + want = append(want, n) + }) + + // Check entire sequence. + got := slices.Collect(inspect.PreorderSeq(nodeFilter...)) + compare(t, got, want) + + // Check that break works. + got = firstN(10, inspect.PreorderSeq(nodeFilter...)) + compare(t, got, want[:10]) +} + +// TestAll checks All against Preorder. +func TestAll(t *testing.T) { + inspect := inspector.New(netFiles) + + // reference implementation + var want []*ast.CallExpr + inspect.Preorder([]ast.Node{(*ast.CallExpr)(nil)}, func(n ast.Node) { + want = append(want, n.(*ast.CallExpr)) + }) + + // Check entire sequence. + got := slices.Collect(inspector.All[*ast.CallExpr](inspect)) + compare(t, got, want) + + // Check that break works. + got = firstN(10, inspector.All[*ast.CallExpr](inspect)) + compare(t, got, want[:10]) +} + +// firstN(n, seq), returns a slice of up to n elements of seq. +func firstN[T any](n int, seq iter.Seq[T]) (res []T) { + for x := range seq { + res = append(res, x) + if len(res) == n { + break + } + } + return res +} + +// BenchmarkAllCalls is like BenchmarkInspectCalls, +// but using the single-type filtering iterator, All. +// (The iterator adds about 5-15%.) +func BenchmarkAllCalls(b *testing.B) { + inspect := inspector.New(netFiles) + b.ResetTimer() + + // Measure marginal cost of traversal. + var ncalls int + for range b.N { + for range inspector.All[*ast.CallExpr](inspect) { + ncalls++ + } + } +}