diff --git a/linker/symbols.go b/linker/symbols.go index 8a66af0..ef32cf9 100644 --- a/linker/symbols.go +++ b/linker/symbols.go @@ -53,7 +53,7 @@ type packageSymbols struct { children map[protoreflect.FullName]*packageSymbols files map[protoreflect.FileDescriptor]struct{} symbols map[protoreflect.FullName]symbolEntry - exts map[extNumber]ast.SourcePos + exts map[extNumber]ast.SourceSpan } type extNumber struct { @@ -68,7 +68,7 @@ type symbolEntry struct { } type extDecl struct { - pos ast.SourcePos + span ast.SourceSpan extendee protoreflect.FullName tag protoreflect.FieldNumber } @@ -174,15 +174,15 @@ func (s *Symbols) importPackages(pkgSpan ast.SourceSpan, pkg protoreflect.FullNa return &s.pkgTrie, nil } - parts := strings.Split(string(pkg), ".") - for i := 1; i < len(parts); i++ { - parts[i] = parts[i-1] + "." + parts[i] - } - cur := &s.pkgTrie - for _, p := range parts { + enumerator := nameEnumerator{name: pkg} + for { + p, ok := enumerator.next() + if !ok { + return cur, nil + } var err error - cur, err = cur.importPackage(pkgSpan, protoreflect.FullName(p), handler) + cur, err = cur.importPackage(pkgSpan, p, handler) if err != nil { return nil, err } @@ -190,8 +190,6 @@ func (s *Symbols) importPackages(pkgSpan ast.SourceSpan, pkg protoreflect.FullNa return nil, nil } } - - return cur, nil } func (s *packageSymbols) importPackage(pkgSpan ast.SourceSpan, pkg protoreflect.FullName, handler *reporter.Handler) (*packageSymbols, error) { @@ -232,29 +230,29 @@ func (s *packageSymbols) importPackage(pkgSpan ast.SourceSpan, pkg protoreflect. return child, nil } -func (s *Symbols) getPackage(pkg protoreflect.FullName) *packageSymbols { +func (s *Symbols) getPackage(pkg protoreflect.FullName, exact bool) *packageSymbols { if pkg == "" { return &s.pkgTrie } - - parts := strings.Split(string(pkg), ".") - for i := 1; i < len(parts); i++ { - parts[i] = parts[i-1] + "." + parts[i] - } - cur := &s.pkgTrie - for _, p := range parts { + enumerator := nameEnumerator{name: pkg} + for { + p, ok := enumerator.next() + if !ok { + return cur + } cur.mu.RLock() - next := cur.children[protoreflect.FullName(p)] + next := cur.children[p] cur.mu.RUnlock() if next == nil { - return nil + if exact { + return nil + } + return cur } cur = next } - - return cur } func reportSymbolCollision(span ast.SourceSpan, fqn protoreflect.FullName, additionIsEnumVal bool, existing symbolEntry, handler *reporter.Handler) error { @@ -405,7 +403,7 @@ func (s *packageSymbols) commitFileLocked(f protoreflect.FileDescriptor) { s.symbols = map[protoreflect.FullName]symbolEntry{} } if s.exts == nil { - s.exts = map[extNumber]ast.SourcePos{} + s.exts = map[extNumber]ast.SourceSpan{} } _ = walk.Descriptors(f, func(d protoreflect.Descriptor) error { span := sourceSpanFor(d) @@ -543,7 +541,7 @@ func (s *packageSymbols) commitResultLocked(r *result) { s.symbols = map[protoreflect.FullName]symbolEntry{} } if s.exts == nil { - s.exts = map[extNumber]ast.SourcePos{} + s.exts = map[extNumber]ast.SourceSpan{} } _ = walk.DescriptorProtos(r.FileDescriptorProto(), func(fqn protoreflect.FullName, d proto.Message) error { span := nameSpan(r.FileNode(), r.Node(d)) @@ -567,7 +565,7 @@ func (s *Symbols) AddExtension(pkg, extendee protoreflect.FullName, tag protoref return handler.HandleErrorf(span, "could not register extension: extendee %q does not match package %q", extendee, pkg) } } - pkgSyms := s.getPackage(pkg) + pkgSyms := s.getPackage(pkg, true) if pkgSyms == nil { // should never happen return handler.HandleErrorf(span, "could not register extension: missing package symbols for %q", pkg) @@ -581,13 +579,13 @@ func (s *packageSymbols) addExtension(extendee protoreflect.FullName, tag protor extNum := extNumber{extendee: extendee, tag: tag} if existing, ok := s.exts[extNum]; ok { - return handler.HandleErrorf(span, "extension with tag %d for message %s already defined at %v", tag, extendee, existing) + return handler.HandleErrorf(span, "extension with tag %d for message %s already defined at %v", tag, extendee, existing.Start()) } if s.exts == nil { - s.exts = map[extNumber]ast.SourcePos{} + s.exts = map[extNumber]ast.SourceSpan{} } - s.exts[extNum] = span.Start() + s.exts[extNum] = span return nil } @@ -602,15 +600,53 @@ func (s *Symbols) AddExtensionDeclaration(extension, extendee protoreflect.FullN // This is a declaration that has already been added. Ignore. return nil } - return handler.HandleErrorf(span, "extension %s already declared as extending %s with tag %d at %v", extension, existing.extendee, existing.tag, existing.pos) + return handler.HandleErrorf(span, "extension %s already declared as extending %s with tag %d at %v", extension, existing.extendee, existing.tag, existing.span.Start()) } if s.extDecls == nil { s.extDecls = map[protoreflect.FullName]extDecl{} } s.extDecls[extension] = extDecl{ - pos: span.Start(), + span: span, extendee: extendee, tag: tag, } return nil } + +// Lookup finds the registered location of the given name. If the given name has +// not been seen/registered, nil is returned. +func (s *Symbols) Lookup(name protoreflect.FullName) ast.SourceSpan { + // note: getPackage never returns nil when exact=false + pkgSyms := s.getPackage(name, false) + if entry, ok := pkgSyms.symbols[name]; ok { + return entry.span + } + return nil +} + +// LookupExtension finds the registered location of the given extension. If the given +// extension has not been seen/registered, nil is returned. +func (s *Symbols) LookupExtension(messageName protoreflect.FullName, extensionNumber protoreflect.FieldNumber) ast.SourceSpan { + // note: getPackage never returns nil when exact=false + pkgSyms := s.getPackage(messageName, false) + return pkgSyms.exts[extNumber{messageName, extensionNumber}] +} + +type nameEnumerator struct { + name protoreflect.FullName + start int +} + +func (e *nameEnumerator) next() (protoreflect.FullName, bool) { + if e.start < 0 { + return "", false + } + pos := strings.IndexByte(string(e.name[e.start:]), '.') + if pos == -1 { + e.start = -1 + return e.name, true + } + pos += e.start + e.start = pos + 1 + return e.name[:pos], true +} diff --git a/linker/symbols_benchmark_test.go b/linker/symbols_benchmark_test.go new file mode 100644 index 0000000..d14f94f --- /dev/null +++ b/linker/symbols_benchmark_test.go @@ -0,0 +1,31 @@ +// Copyright 2020-2024 Buf Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package linker + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func BenchmarkSymbols(b *testing.B) { + s := &Symbols{} + _, err := s.importPackages(nil, "foo.bar.baz.fizz.buzz.frob.nitz", nil) + require.NoError(b, err) + for i := 0; i < b.N; i++ { + pkg := s.getPackage("foo.bar.baz.fizz.buzz.frob.nitz", true) + require.NotNil(b, pkg) + } +} diff --git a/linker/symbols_test.go b/linker/symbols_test.go index 3999de3..bff472b 100644 --- a/linker/symbols_test.go +++ b/linker/symbols_test.go @@ -34,7 +34,7 @@ func TestSymbolsPackages(t *testing.T) { var s Symbols // default/nameless package is the root - assert.Equal(t, &s.pkgTrie, s.getPackage("")) + assert.Equal(t, &s.pkgTrie, s.getPackage("", true)) h := reporter.NewHandler(nil) span := ast.UnknownSpan("foo.proto") @@ -46,7 +46,7 @@ func TestSymbolsPackages(t *testing.T) { assert.Empty(t, pkg.symbols) assert.Empty(t, pkg.exts) - assert.Equal(t, pkg, s.getPackage("build.buf.foo.bar.baz")) + assert.Equal(t, pkg, s.getPackage("build.buf.foo.bar.baz", true)) // verify that trie was created correctly: // each package has just one entry, which is its immediate sub-package @@ -115,7 +115,7 @@ func TestSymbolsImport(t *testing.T) { // verify contents of s - pkg := s.getPackage("foo.bar") + pkg := s.getPackage("foo.bar", true) syms := pkg.symbols assert.Len(t, syms, 6) assert.Contains(t, syms, protoreflect.FullName("foo.bar.Foo")) @@ -129,7 +129,7 @@ func TestSymbolsImport(t *testing.T) { assert.Contains(t, exts, extNumber{"foo.bar.Foo", 10}) assert.Contains(t, exts, extNumber{"foo.bar.Foo", 11}) - pkg = s.getPackage("google.protobuf") + pkg = s.getPackage("google.protobuf", true) exts = pkg.exts assert.Len(t, exts, 1) assert.Contains(t, exts, extNumber{"google.protobuf.FieldOptions", 20000}) @@ -179,13 +179,13 @@ func TestSymbolExtensions(t *testing.T) { // verify contents of s - pkg := s.getPackage("foo.bar") + pkg := s.getPackage("foo.bar", true) exts := pkg.exts assert.Len(t, exts, 2) assert.Contains(t, exts, extNumber{"foo.bar.Foo", 11}) assert.Contains(t, exts, extNumber{"foo.bar.Foo", 12}) - pkg = s.getPackage("google.protobuf") + pkg = s.getPackage("google.protobuf", true) exts = pkg.exts assert.Len(t, exts, 3) assert.Contains(t, exts, extNumber{"google.protobuf.FileOptions", 10101})