Skip to content

Commit

Permalink
add lookup operations to *Symbols; improve getPackages and importPack…
Browse files Browse the repository at this point in the history
…ages to not allocate slice and names
  • Loading branch information
jhump committed Apr 22, 2024
1 parent c4cc0aa commit b347ffd
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 37 deletions.
114 changes: 83 additions & 31 deletions linker/symbols.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ type packageSymbols struct {
children map[protoreflect.FullName]*packageSymbols
files map[protoreflect.FileDescriptor]struct{}
symbols map[protoreflect.FullName]symbolEntry
exts map[extNumber]ast.SourcePos
exts map[extNumber]ast.SourceSpan
}

type extNumber struct {
Expand All @@ -68,7 +68,7 @@ type symbolEntry struct {
}

type extDecl struct {
pos ast.SourcePos
span ast.SourceSpan
extendee protoreflect.FullName
tag protoreflect.FieldNumber
}
Expand Down Expand Up @@ -174,24 +174,22 @@ func (s *Symbols) importPackages(pkgSpan ast.SourceSpan, pkg protoreflect.FullNa
return &s.pkgTrie, nil
}

parts := strings.Split(string(pkg), ".")
for i := 1; i < len(parts); i++ {
parts[i] = parts[i-1] + "." + parts[i]
}

cur := &s.pkgTrie
for _, p := range parts {
enumerator := nameEnumerator{name: pkg}
for {
p, ok := enumerator.next()
if !ok {
return cur, nil
}
var err error
cur, err = cur.importPackage(pkgSpan, protoreflect.FullName(p), handler)
cur, err = cur.importPackage(pkgSpan, p, handler)
if err != nil {
return nil, err
}
if cur == nil {
return nil, nil
}
}

return cur, nil
}

func (s *packageSymbols) importPackage(pkgSpan ast.SourceSpan, pkg protoreflect.FullName, handler *reporter.Handler) (*packageSymbols, error) {
Expand Down Expand Up @@ -232,29 +230,29 @@ func (s *packageSymbols) importPackage(pkgSpan ast.SourceSpan, pkg protoreflect.
return child, nil
}

func (s *Symbols) getPackage(pkg protoreflect.FullName) *packageSymbols {
func (s *Symbols) getPackage(pkg protoreflect.FullName, exact bool) *packageSymbols {
if pkg == "" {
return &s.pkgTrie
}

parts := strings.Split(string(pkg), ".")
for i := 1; i < len(parts); i++ {
parts[i] = parts[i-1] + "." + parts[i]
}

cur := &s.pkgTrie
for _, p := range parts {
enumerator := nameEnumerator{name: pkg}
for {
p, ok := enumerator.next()
if !ok {
return cur
}
cur.mu.RLock()
next := cur.children[protoreflect.FullName(p)]
next := cur.children[p]
cur.mu.RUnlock()

if next == nil {
return nil
if exact {
return nil
}
return cur
}
cur = next
}

return cur
}

func reportSymbolCollision(span ast.SourceSpan, fqn protoreflect.FullName, additionIsEnumVal bool, existing symbolEntry, handler *reporter.Handler) error {
Expand Down Expand Up @@ -405,7 +403,7 @@ func (s *packageSymbols) commitFileLocked(f protoreflect.FileDescriptor) {
s.symbols = map[protoreflect.FullName]symbolEntry{}
}
if s.exts == nil {
s.exts = map[extNumber]ast.SourcePos{}
s.exts = map[extNumber]ast.SourceSpan{}
}
_ = walk.Descriptors(f, func(d protoreflect.Descriptor) error {
span := sourceSpanFor(d)
Expand Down Expand Up @@ -543,7 +541,7 @@ func (s *packageSymbols) commitResultLocked(r *result) {
s.symbols = map[protoreflect.FullName]symbolEntry{}
}
if s.exts == nil {
s.exts = map[extNumber]ast.SourcePos{}
s.exts = map[extNumber]ast.SourceSpan{}
}
_ = walk.DescriptorProtos(r.FileDescriptorProto(), func(fqn protoreflect.FullName, d proto.Message) error {
span := nameSpan(r.FileNode(), r.Node(d))
Expand All @@ -567,7 +565,7 @@ func (s *Symbols) AddExtension(pkg, extendee protoreflect.FullName, tag protoref
return handler.HandleErrorf(span, "could not register extension: extendee %q does not match package %q", extendee, pkg)
}
}
pkgSyms := s.getPackage(pkg)
pkgSyms := s.getPackage(pkg, true)
if pkgSyms == nil {
// should never happen
return handler.HandleErrorf(span, "could not register extension: missing package symbols for %q", pkg)
Expand All @@ -581,13 +579,13 @@ func (s *packageSymbols) addExtension(extendee protoreflect.FullName, tag protor

extNum := extNumber{extendee: extendee, tag: tag}
if existing, ok := s.exts[extNum]; ok {
return handler.HandleErrorf(span, "extension with tag %d for message %s already defined at %v", tag, extendee, existing)
return handler.HandleErrorf(span, "extension with tag %d for message %s already defined at %v", tag, extendee, existing.Start())
}

if s.exts == nil {
s.exts = map[extNumber]ast.SourcePos{}
s.exts = map[extNumber]ast.SourceSpan{}
}
s.exts[extNum] = span.Start()
s.exts[extNum] = span
return nil
}

Expand All @@ -602,15 +600,69 @@ func (s *Symbols) AddExtensionDeclaration(extension, extendee protoreflect.FullN
// This is a declaration that has already been added. Ignore.
return nil
}
return handler.HandleErrorf(span, "extension %s already declared as extending %s with tag %d at %v", extension, existing.extendee, existing.tag, existing.pos)
return handler.HandleErrorf(span, "extension %s already declared as extending %s with tag %d at %v", extension, existing.extendee, existing.tag, existing.span.Start())
}
if s.extDecls == nil {
s.extDecls = map[protoreflect.FullName]extDecl{}
}
s.extDecls[extension] = extDecl{
pos: span.Start(),
span: span,
extendee: extendee,
tag: tag,
}
return nil
}

// Lookup finds the registered location of the given name. If the given name has
// not been seen/registered, nil is returned.
func (s *Symbols) Lookup(name protoreflect.FullName) ast.SourceSpan {
pkg := name.Parent()
for {
pkgSyms := s.getPackage(pkg, false)
if pkgSyms != nil {
if entry, ok := pkgSyms.symbols[name]; ok {
return entry.span
}
return nil
}
if pkg == "" {
return nil
}
pkg = pkg.Parent()
}
}

// LookupExtension finds the registered location of the given extension. If the given
// extension has not been seen/registered, nil is returned.
func (s *Symbols) LookupExtension(messageName protoreflect.FullName, extensionNumber protoreflect.FieldNumber) ast.SourceSpan {
pkg := messageName.Parent()
for {
pkgSyms := s.getPackage(pkg, false)
if pkgSyms != nil {
return pkgSyms.exts[extNumber{messageName, extensionNumber}]
}
if pkg == "" {
return nil
}
pkg = pkg.Parent()
}
}

type nameEnumerator struct {
name protoreflect.FullName
start int
}

func (e *nameEnumerator) next() (protoreflect.FullName, bool) {
if e.start < 0 {
return "", false
}
pos := strings.IndexByte(string(e.name[e.start:]), '.')
if pos == -1 {
e.start = -1
return e.name, true
}
pos += e.start
e.start = pos + 1
return e.name[:pos], true
}
12 changes: 6 additions & 6 deletions linker/symbols_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func TestSymbolsPackages(t *testing.T) {

var s Symbols
// default/nameless package is the root
assert.Equal(t, &s.pkgTrie, s.getPackage(""))
assert.Equal(t, &s.pkgTrie, s.getPackage("", true))

h := reporter.NewHandler(nil)
span := ast.UnknownSpan("foo.proto")
Expand All @@ -46,7 +46,7 @@ func TestSymbolsPackages(t *testing.T) {
assert.Empty(t, pkg.symbols)
assert.Empty(t, pkg.exts)

assert.Equal(t, pkg, s.getPackage("build.buf.foo.bar.baz"))
assert.Equal(t, pkg, s.getPackage("build.buf.foo.bar.baz", true))

// verify that trie was created correctly:
// each package has just one entry, which is its immediate sub-package
Expand Down Expand Up @@ -115,7 +115,7 @@ func TestSymbolsImport(t *testing.T) {

// verify contents of s

pkg := s.getPackage("foo.bar")
pkg := s.getPackage("foo.bar", true)
syms := pkg.symbols
assert.Len(t, syms, 6)
assert.Contains(t, syms, protoreflect.FullName("foo.bar.Foo"))
Expand All @@ -129,7 +129,7 @@ func TestSymbolsImport(t *testing.T) {
assert.Contains(t, exts, extNumber{"foo.bar.Foo", 10})
assert.Contains(t, exts, extNumber{"foo.bar.Foo", 11})

pkg = s.getPackage("google.protobuf")
pkg = s.getPackage("google.protobuf", true)
exts = pkg.exts
assert.Len(t, exts, 1)
assert.Contains(t, exts, extNumber{"google.protobuf.FieldOptions", 20000})
Expand Down Expand Up @@ -179,13 +179,13 @@ func TestSymbolExtensions(t *testing.T) {

// verify contents of s

pkg := s.getPackage("foo.bar")
pkg := s.getPackage("foo.bar", true)
exts := pkg.exts
assert.Len(t, exts, 2)
assert.Contains(t, exts, extNumber{"foo.bar.Foo", 11})
assert.Contains(t, exts, extNumber{"foo.bar.Foo", 12})

pkg = s.getPackage("google.protobuf")
pkg = s.getPackage("google.protobuf", true)
exts = pkg.exts
assert.Len(t, exts, 3)
assert.Contains(t, exts, extNumber{"google.protobuf.FileOptions", 10101})
Expand Down

0 comments on commit b347ffd

Please sign in to comment.