Skip to content

Commit

Permalink
Improve the API for TypeNameSet (#1607)
Browse files Browse the repository at this point in the history
* Remove set returns from TypeNameSet and test

* Code gardening

* Fix bug in TypeMatcher
  • Loading branch information
theunrepentantgeek authored Jun 28, 2021
1 parent 76d4bf8 commit 3b1d6ce
Show file tree
Hide file tree
Showing 15 changed files with 111 additions and 64 deletions.
2 changes: 1 addition & 1 deletion hack/generator/pkg/astmodel/allof_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ func (allOf *AllOfType) Types() ReadonlyTypeSet {

// References returns any type referenced by the AllOf types
func (allOf *AllOfType) References() TypeNameSet {
var result TypeNameSet
result := NewTypeNameSet()
allOf.types.ForEach(func(t Type, _ int) {
result = SetUnion(result, t.References())
})
Expand Down
2 changes: 1 addition & 1 deletion hack/generator/pkg/astmodel/interface_implementation.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func (iface *InterfaceImplementation) RequiredPackageReferences() *PackageRefere

// References indicates whether this type includes any direct references to the given type
func (iface *InterfaceImplementation) References() TypeNameSet {
var results TypeNameSet
results := NewTypeNameSet()
for _, f := range iface.functions {
for ref := range f.References() {
results.Add(ref)
Expand Down
4 changes: 2 additions & 2 deletions hack/generator/pkg/astmodel/interface_implementer.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,10 @@ func (i InterfaceImplementer) WithInterface(iface *InterfaceImplementation) Inte
}

func (i InterfaceImplementer) References() TypeNameSet {
var results TypeNameSet
results := NewTypeNameSet()
for _, iface := range i.interfaces {
for ref := range iface.References() {
results = results.Add(ref)
results.Add(ref)
}
}

Expand Down
4 changes: 2 additions & 2 deletions hack/generator/pkg/astmodel/object_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -233,11 +233,11 @@ func (objectType *ObjectType) RequiredPackageReferences() *PackageReferenceSet {
func (objectType *ObjectType) References() TypeNameSet {
results := NewTypeNameSet()
for _, property := range objectType.properties {
results = results.AddAll(property.PropertyType().References())
results.AddAll(property.PropertyType().References())
}

for _, property := range objectType.embedded {
results = results.AddAll(property.PropertyType().References())
results.AddAll(property.PropertyType().References())
}

// Not collecting types from functions deliberately.
Expand Down
3 changes: 1 addition & 2 deletions hack/generator/pkg/astmodel/oneof_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,7 @@ func (oneOf *OneOfType) Types() ReadonlyTypeSet {

// References returns any type referenced by the OneOf types
func (oneOf *OneOfType) References() TypeNameSet {
var result TypeNameSet

result := NewTypeNameSet()
oneOf.types.ForEach(func(t Type, _ int) {
result = SetUnion(result, t.References())
})
Expand Down
4 changes: 2 additions & 2 deletions hack/generator/pkg/astmodel/reference_graph.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ type ReferenceGraph struct {
// CollectResourceDefinitions returns a TypeNameSet of all of the
// root definitions in the definitions passed in.
func CollectResourceDefinitions(definitions Types) TypeNameSet {
resources := make(TypeNameSet)
resources := NewTypeNameSet()
for _, def := range definitions {
if _, ok := def.Type().(*ResourceType); ok {
resources.Add(def.Name())
Expand All @@ -45,7 +45,7 @@ func CollectARMSpecAndStatusDefinitions(definitions Types) TypeNameSet {
return armName, nil
}

armSpecAndStatus := make(TypeNameSet)
armSpecAndStatus := NewTypeNameSet()
for _, def := range definitions {
if resourceType, ok := definitions.ResolveResourceType(def.Type()); ok {

Expand Down
2 changes: 1 addition & 1 deletion hack/generator/pkg/astmodel/reference_graph_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ func Test_ReferenceGraph_Gives_Correct_Depth(t *testing.T) {
}

names := func(ns ...string) TypeNameSet {
result := make(TypeNameSet)
result := NewTypeNameSet()
for _, n := range ns {
result.Add(name(n))
}
Expand Down
57 changes: 19 additions & 38 deletions hack/generator/pkg/astmodel/type_name_set.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@

package astmodel

import "fmt"
import (
"fmt"
)

// TypeNameSet stores type names in no particular order without
// duplicates.
Expand All @@ -14,23 +16,17 @@ type TypeNameSet map[TypeName]struct{}
// NewTypeNameSet makes a TypeNameSet containing the specified
// names. If no elements are passed it might be nil.
func NewTypeNameSet(initial ...TypeName) TypeNameSet {
var result TypeNameSet
result := make(TypeNameSet)
for _, name := range initial {
result = result.Add(name)
result.Add(name)
}

return result
}

// Add includes the passed name in the set and returns the updated
// set, so that adding can work for a nil set - this makes it more
// convenient to add to sets kept in a map (in the way you might with
// a map of slices).
func (ts TypeNameSet) Add(val TypeName) TypeNameSet {
if ts == nil {
ts = make(TypeNameSet)
}
// Add includes the passed name in the set
func (ts TypeNameSet) Add(val TypeName) {
ts[val] = struct{}{}
return ts
}

// Contains returns whether this name is in the set. Works for nil
Expand All @@ -43,16 +39,9 @@ func (ts TypeNameSet) Contains(val TypeName) bool {
return found
}

// Remove removes the specified item if it is in the set. If it is not in
// the set this is a no-op.
func (ts TypeNameSet) Remove(val TypeName) TypeNameSet {
if ts == nil {
return ts
}

// Remove removes the specified item if it is in the set. If it is not in the set this is a no-op.
func (ts TypeNameSet) Remove(val TypeName) {
delete(ts, val)

return ts
}

func (ts TypeNameSet) Equals(set TypeNameSet) bool {
Expand All @@ -72,39 +61,31 @@ func (ts TypeNameSet) Equals(set TypeNameSet) bool {
}

// AddAll adds the provided TypeNameSet to the set
func (ts TypeNameSet) AddAll(other TypeNameSet) TypeNameSet {
if ts == nil {
ts = make(TypeNameSet)
}

func (ts TypeNameSet) AddAll(other TypeNameSet) {
for val := range other {
ts[val] = struct{}{}
}

return ts
}

// Single returns the single TypeName in the set. This panics if there is not a single item in the set.
func (ts TypeNameSet) Single() TypeName {
if len(ts) != 1 {
panic(fmt.Sprintf("Single() cannot be called with %d types in the set", len(ts)))
}

for name := range ts {
return name
if len(ts) == 1 {
for name := range ts {
return name
}
}

panic("Reached unreachable code")
panic(fmt.Sprintf("Single() cannot be called with %d types in the set", len(ts)))
}

// SetUnion returns a new set with all of the names in s1 or s2.
func SetUnion(s1, s2 TypeNameSet) TypeNameSet {
var result TypeNameSet
result := NewTypeNameSet()
for val := range s1 {
result = result.Add(val)
result.Add(val)
}
for val := range s2 {
result = result.Add(val)
result.Add(val)
}
return result
}
46 changes: 46 additions & 0 deletions hack/generator/pkg/astmodel/type_name_set_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
/*
* Copyright (c) Microsoft Corporation.
* Licensed under the MIT license.
*/

package astmodel

import (
"testing"

. "github.com/onsi/gomega"
)

var (
oneTypeName = MakeTypeName(GenRuntimeReference, "One")
twoTypeName = MakeTypeName(GenRuntimeReference, "Two")
)

/*
* Empty Set tests
*/

func Test_TypeNameSet_WhenEmpty_HasLengthZero(t *testing.T) {
g := NewGomegaWithT(t)
emptySet := NewTypeNameSet()
g.Expect(len(emptySet)).To(Equal(0))
}

/*
* Add() Tests
*/

func Test_TypeNameSet_AfterAddingFirstItem_ContainsItem(t *testing.T) {
g := NewGomegaWithT(t)
set := NewTypeNameSet()
set.Add(oneTypeName)
g.Expect(set.Contains(oneTypeName)).To(BeTrue())
}

func Test_TypeNameSet_AfterAddingSecondItem_ContainsItem(t *testing.T) {
g := NewGomegaWithT(t)
set := NewTypeNameSet()
set.Add(oneTypeName)
set.Add(twoTypeName)
g.Expect(set.Contains(twoTypeName)).To(BeTrue())
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func RemoveEmptyObjects(types astmodel.Types) (astmodel.Types, error) {
}

func findEmptyObjectTypes(types astmodel.Types) astmodel.TypeNameSet {
result := make(astmodel.TypeNameSet)
result := astmodel.NewTypeNameSet()

for _, def := range types {
ot, ok := astmodel.AsObjectType(def.Type())
Expand Down
28 changes: 20 additions & 8 deletions hack/generator/pkg/codegen/embeddedresources/remover.go
Original file line number Diff line number Diff line change
Expand Up @@ -273,11 +273,21 @@ func findSubResourcePropertiesTypeNames(types astmodel.Types) (map[astmodel.Type
errs = append(errs, errors.Wrapf(err, "couldn't extract spec/status properties from %q", def.Name()))
continue
}

if specPropertiesTypeName != nil {
result[owner] = result[owner].Add(*specPropertiesTypeName)
if result[owner] == nil {
result[owner] = astmodel.NewTypeNameSet(*specPropertiesTypeName)
} else {
result[owner].Add(*specPropertiesTypeName)
}
}

if statusPropertiesTypeName != nil {
result[owner] = result[owner].Add(*statusPropertiesTypeName)
if result[owner] == nil {
result[owner] = astmodel.NewTypeNameSet(*statusPropertiesTypeName)
} else {
result[owner].Add(*statusPropertiesTypeName)
}
}
}

Expand Down Expand Up @@ -322,7 +332,7 @@ func findAllResourcePropertiesTypes(types astmodel.Types) (astmodel.TypeNameSet,
})

var errs []error
result := make(astmodel.TypeNameSet)
result := astmodel.NewTypeNameSet()

// Identify sub-resources and their "properties", associate them with parent resource
// Look through parent resource for subresource properties
Expand All @@ -338,11 +348,13 @@ func findAllResourcePropertiesTypes(types astmodel.Types) (astmodel.TypeNameSet,
errs = append(errs, errors.Wrapf(err, "couldn't extract spec/status properties from %q", def.Name()))
continue
}

if specPropertiesTypeName != nil {
result = result.Add(*specPropertiesTypeName)
result.Add(*specPropertiesTypeName)
}

if statusPropertiesTypeName != nil {
result = result.Add(*statusPropertiesTypeName)
result.Add(*statusPropertiesTypeName)
}
}

Expand All @@ -362,7 +374,7 @@ func findAllResourceStatusTypes(types astmodel.Types) astmodel.TypeNameSet {
return ok
})

result := make(astmodel.TypeNameSet)
result := astmodel.NewTypeNameSet()

// Identify sub-resources and their "properties", associate them with parent resource
// Look through parent resource for subresource properties
Expand All @@ -378,7 +390,7 @@ func findAllResourceStatusTypes(types astmodel.Types) astmodel.TypeNameSet {
continue
}

result = result.Add(statusName)
result.Add(statusName)
}

return result
Expand Down Expand Up @@ -412,7 +424,7 @@ func requiredResourceProperties() []string {
}

// optionalResourceProperties are properties which may or may not be on a resource. Technically all resources
// should have all of these properties, but because we drop the top-level allof that joins resource types with
// should have all of these properties, but because we drop the top-level AllOf that joins resource types with
// ResourceBase when parsing schemas sometimes they aren't defined.
func optionalResourceProperties() []string {
return []string{
Expand Down
6 changes: 5 additions & 1 deletion hack/generator/pkg/codegen/embeddedresources/renamer.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,11 @@ func simplifyTypeNames(types astmodel.Types, flag astmodel.TypeFlag) (astmodel.T
return nil, err
}

updatedNames[embeddedName.original] = updatedNames[embeddedName.original].Add(def.Name())
if updatedNames[embeddedName.original] == nil {
updatedNames[embeddedName.original] = astmodel.NewTypeNameSet(def.Name())
} else {
updatedNames[embeddedName.original].Add(def.Name())
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@ func TestConnectionChecker_Avoids_Cycles(t *testing.T) {
graph := astmodel.NewReferenceGraph(roots, references)
connectedSet := graph.Connected()

var names astmodel.TypeNameSet
names := astmodel.NewTypeNameSet()
for name := range connectedSet {
names = names.Add(name)
names.Add(name)
}

g.Expect(names).To(Equal(makeSet("res1", "res2", "A", "B", "C", "D")))
Expand Down
2 changes: 1 addition & 1 deletion hack/generator/pkg/config/configuration.go
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ func buildExportFilterFunc(f *ExportFilter, allTypes astmodel.Types) ExportFilte
}

case ExportFilterIncludeTransitive:
applicableTypes := make(astmodel.TypeNameSet)
applicableTypes := astmodel.NewTypeNameSet()
for tn := range allTypes {
if f.AppliesToType(tn) {
collectAllReferencedTypes(allTypes, tn, applicableTypes)
Expand Down
9 changes: 7 additions & 2 deletions hack/generator/pkg/config/type_matcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ func (t *TypeMatcher) Initialize() error {
t.groupRegex = createGlobbingRegex(t.Group)
t.versionRegex = createGlobbingRegex(t.Version)
t.nameRegex = createGlobbingRegex(t.Name)
t.matchedTypes = make(astmodel.TypeNameSet)
t.matchedTypes = astmodel.NewTypeNameSet()

// Default MatchRequired
if t.MatchRequired == nil {
temp := true
Expand Down Expand Up @@ -87,7 +88,11 @@ func (t *TypeMatcher) AppliesToType(typeName astmodel.TypeName) bool {

// Track this match so we can later report if we didn't match anything
if result {
t.matchedTypes = t.matchedTypes.Add(typeName)
if t.matchedTypes == nil {
t.matchedTypes = astmodel.NewTypeNameSet(typeName)
} else {
t.matchedTypes.Add(typeName)
}
}

return result
Expand Down

0 comments on commit 3b1d6ce

Please sign in to comment.