Skip to content

Commit

Permalink
try out map for tracking the checked set, instead of slice
Browse files Browse the repository at this point in the history
  • Loading branch information
jhump committed Apr 22, 2024
1 parent 63c135c commit 89cc7b1
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 22 deletions.
10 changes: 5 additions & 5 deletions linker/files.go
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ type fileResolver struct {
}

func (r fileResolver) FindFileByPath(path string) (protoreflect.FileDescriptor, error) {
return resolveInFile(r.f, false, nil, func(f File) (protoreflect.FileDescriptor, error) {
return resolveInFile(r.f, false, map[string]struct{}{}, func(f File) (protoreflect.FileDescriptor, error) {
if f.Path() == path {
return f, nil
}
Expand All @@ -210,7 +210,7 @@ func (r fileResolver) FindFileByPath(path string) (protoreflect.FileDescriptor,
}

func (r fileResolver) FindDescriptorByName(name protoreflect.FullName) (protoreflect.Descriptor, error) {
return resolveInFile(r.f, false, nil, func(f File) (protoreflect.Descriptor, error) {
return resolveInFile(r.f, false, map[string]struct{}{}, func(f File) (protoreflect.Descriptor, error) {
if d := f.FindDescriptorByName(name); d != nil {
return d, nil
}
Expand All @@ -219,7 +219,7 @@ func (r fileResolver) FindDescriptorByName(name protoreflect.FullName) (protoref
}

func (r fileResolver) FindMessageByName(message protoreflect.FullName) (protoreflect.MessageType, error) {
return resolveInFile(r.f, false, nil, func(f File) (protoreflect.MessageType, error) {
return resolveInFile(r.f, false, map[string]struct{}{}, func(f File) (protoreflect.MessageType, error) {
d := f.FindDescriptorByName(message)
if d != nil {
md, ok := d.(protoreflect.MessageDescriptor)
Expand All @@ -243,7 +243,7 @@ func messageNameFromURL(url string) string {
}

func (r fileResolver) FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error) {
return resolveInFile(r.f, false, nil, func(f File) (protoreflect.ExtensionType, error) {
return resolveInFile(r.f, false, map[string]struct{}{}, func(f File) (protoreflect.ExtensionType, error) {
d := f.FindDescriptorByName(field)
if d != nil {
fld, ok := d.(protoreflect.FieldDescriptor)
Expand All @@ -260,7 +260,7 @@ func (r fileResolver) FindExtensionByName(field protoreflect.FullName) (protoref
}

func (r fileResolver) FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error) {
return resolveInFile(r.f, false, nil, func(f File) (protoreflect.ExtensionType, error) {
return resolveInFile(r.f, false, map[string]struct{}{}, func(f File) (protoreflect.ExtensionType, error) {
ext := findExtension(f, message, field)
if ext != nil {
return ext.Type(), nil
Expand Down
36 changes: 19 additions & 17 deletions linker/resolve.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,15 @@ func (r *result) ResolveMessageLiteralExtensionName(node ast.IdentValueNode) str
return r.optionQualifiedNames[node]
}

func (r *result) resolveElement(name protoreflect.FullName, checkedCache []string) protoreflect.Descriptor {
func (r *result) resolveElement(name protoreflect.FullName, checkedCache map[string]struct{}) protoreflect.Descriptor {
if len(name) > 0 && name[0] == '.' {
name = name[1:]
}
res, _ := resolveInFile(r, false, checkedCache[:0], func(f File) (protoreflect.Descriptor, error) {
// clear cached map instance before use
for k := range checkedCache {
delete(checkedCache, k)
}
res, _ := resolveInFile(r, false, checkedCache, func(f File) (protoreflect.Descriptor, error) {
d := resolveElementInFile(name, f)
if d != nil {
return d, nil
Expand All @@ -48,16 +52,14 @@ func (r *result) resolveElement(name protoreflect.FullName, checkedCache []strin
return res
}

func resolveInFile[T any](f File, publicImportsOnly bool, checked []string, fn func(File) (T, error)) (T, error) {
func resolveInFile[T any](f File, publicImportsOnly bool, checked map[string]struct{}, fn func(File) (T, error)) (T, error) {
var zero T
path := f.Path()
for _, str := range checked {
if str == path {
// already checked
return zero, protoregistry.NotFound
}
if _, ok := checked[path]; ok {
// already checked
return zero, protoregistry.NotFound
}
checked = append(checked, path)
checked[path] = struct{}{}

res, err := fn(f)
if err == nil {
Expand Down Expand Up @@ -168,7 +170,7 @@ func (r *result) createDescendants() {

func (r *result) resolveReferences(handler *reporter.Handler, s *Symbols) error {
fd := r.FileDescriptorProto()
checkedCache := make([]string, 0, 16)
checkedCache := make(map[string]struct{}, 16)
scopes := []scope{fileScope(r, checkedCache)}
if fd.Options != nil {
if err := r.resolveOptions(handler, "file", protoreflect.FullName(fd.GetName()), fd.Options.UninterpretedOption, scopes, checkedCache); err != nil {
Expand Down Expand Up @@ -302,7 +304,7 @@ func allowedProto3Extendee(n string) bool {
return ok
}

func resolveFieldTypes(f *fldDescriptor, handler *reporter.Handler, extendees map[ast.Node]struct{}, s *Symbols, scopes []scope, checkedCache []string) error {
func resolveFieldTypes(f *fldDescriptor, handler *reporter.Handler, extendees map[ast.Node]struct{}, s *Symbols, scopes []scope, checkedCache map[string]struct{}) error {
r := f.file
fld := f.proto
file := r.FileNode()
Expand Down Expand Up @@ -466,7 +468,7 @@ func isValidMap(mapField protoreflect.FieldDescriptor, mapEntry protoreflect.Mes
string(mapEntry.Name()) == internal.InitCap(internal.JSONName(string(mapField.Name())))+"Entry"
}

func resolveMethodTypes(m *mtdDescriptor, handler *reporter.Handler, scopes []scope, checkedCache []string) error {
func resolveMethodTypes(m *mtdDescriptor, handler *reporter.Handler, scopes []scope, checkedCache map[string]struct{}) error {
scope := fmt.Sprintf("method %s", m.fqn)
r := m.file
mtd := m.proto
Expand Down Expand Up @@ -518,7 +520,7 @@ func resolveMethodTypes(m *mtdDescriptor, handler *reporter.Handler, scopes []sc
return nil
}

func (r *result) resolveOptions(handler *reporter.Handler, elemType string, elemName protoreflect.FullName, opts []*descriptorpb.UninterpretedOption, scopes []scope, checkedCache []string) error {
func (r *result) resolveOptions(handler *reporter.Handler, elemType string, elemName protoreflect.FullName, opts []*descriptorpb.UninterpretedOption, scopes []scope, checkedCache map[string]struct{}) error {
mc := &internal.MessageContext{
File: r,
ElementName: string(elemName),
Expand Down Expand Up @@ -552,7 +554,7 @@ opts:
return nil
}

func (r *result) resolveOptionValue(handler *reporter.Handler, mc *internal.MessageContext, val ast.ValueNode, scopes []scope, checkedCache []string) error {
func (r *result) resolveOptionValue(handler *reporter.Handler, mc *internal.MessageContext, val ast.ValueNode, scopes []scope, checkedCache map[string]struct{}) error {
optVal := val.Value()
switch optVal := optVal.(type) {
case []ast.ValueNode:
Expand Down Expand Up @@ -610,7 +612,7 @@ func (r *result) resolveOptionValue(handler *reporter.Handler, mc *internal.Mess
return nil
}

func (r *result) resolveExtensionName(name string, scopes []scope, checkedCache []string) (string, error) {
func (r *result) resolveExtensionName(name string, scopes []scope, checkedCache map[string]struct{}) (string, error) {
dsc := r.resolve(name, false, scopes, checkedCache)
if dsc == nil {
return "", fmt.Errorf("unknown extension %s", name)
Expand All @@ -626,7 +628,7 @@ func (r *result) resolveExtensionName(name string, scopes []scope, checkedCache
return string("." + dsc.FullName()), nil
}

func (r *result) resolve(name string, onlyTypes bool, scopes []scope, checkedCache []string) protoreflect.Descriptor {
func (r *result) resolve(name string, onlyTypes bool, scopes []scope, checkedCache map[string]struct{}) protoreflect.Descriptor {
if strings.HasPrefix(name, ".") {
// already fully-qualified
return r.resolveElement(protoreflect.FullName(name[1:]), checkedCache)
Expand Down Expand Up @@ -674,7 +676,7 @@ func isType(d protoreflect.Descriptor) bool {
// can be declared.
type scope func(firstName, fullName string) protoreflect.Descriptor

func fileScope(r *result, checkedCache []string) scope {
func fileScope(r *result, checkedCache map[string]struct{}) scope {
// we search symbols in this file, but also symbols in other files that have
// the same package as this file or a "parent" package (in protobuf,
// packages are a hierarchy like C++ namespaces)
Expand Down

0 comments on commit 89cc7b1

Please sign in to comment.