From 3b1d6ce58a00d01144ee7684a79c12e6553b589d Mon Sep 17 00:00:00 2001 From: Bevan Arps Date: Mon, 28 Jun 2021 13:56:27 +1200 Subject: [PATCH] Improve the API for TypeNameSet (#1607) * Remove set returns from TypeNameSet and test * Code gardening * Fix bug in TypeMatcher --- hack/generator/pkg/astmodel/allof_type.go | 2 +- .../pkg/astmodel/interface_implementation.go | 2 +- .../pkg/astmodel/interface_implementer.go | 4 +- hack/generator/pkg/astmodel/object_type.go | 4 +- hack/generator/pkg/astmodel/oneof_type.go | 3 +- .../generator/pkg/astmodel/reference_graph.go | 4 +- .../pkg/astmodel/reference_graph_test.go | 2 +- hack/generator/pkg/astmodel/type_name_set.go | 57 +++++++------------ .../pkg/astmodel/type_name_set_test.go | 46 +++++++++++++++ .../embeddedresources/remove_empty_objects.go | 2 +- .../pkg/codegen/embeddedresources/remover.go | 28 ++++++--- .../pkg/codegen/embeddedresources/renamer.go | 6 +- .../pipeline/strip_unused_types_test.go | 4 +- hack/generator/pkg/config/configuration.go | 2 +- hack/generator/pkg/config/type_matcher.go | 9 ++- 15 files changed, 111 insertions(+), 64 deletions(-) create mode 100644 hack/generator/pkg/astmodel/type_name_set_test.go diff --git a/hack/generator/pkg/astmodel/allof_type.go b/hack/generator/pkg/astmodel/allof_type.go index 8de6583baf1..c5d9f51a50b 100644 --- a/hack/generator/pkg/astmodel/allof_type.go +++ b/hack/generator/pkg/astmodel/allof_type.go @@ -97,7 +97,7 @@ func (allOf *AllOfType) Types() ReadonlyTypeSet { // References returns any type referenced by the AllOf types func (allOf *AllOfType) References() TypeNameSet { - var result TypeNameSet + result := NewTypeNameSet() allOf.types.ForEach(func(t Type, _ int) { result = SetUnion(result, t.References()) }) diff --git a/hack/generator/pkg/astmodel/interface_implementation.go b/hack/generator/pkg/astmodel/interface_implementation.go index 8dd419eeddf..3712ac56c39 100644 --- a/hack/generator/pkg/astmodel/interface_implementation.go +++ b/hack/generator/pkg/astmodel/interface_implementation.go @@ -47,7 +47,7 @@ func (iface *InterfaceImplementation) RequiredPackageReferences() *PackageRefere // References indicates whether this type includes any direct references to the given type func (iface *InterfaceImplementation) References() TypeNameSet { - var results TypeNameSet + results := NewTypeNameSet() for _, f := range iface.functions { for ref := range f.References() { results.Add(ref) diff --git a/hack/generator/pkg/astmodel/interface_implementer.go b/hack/generator/pkg/astmodel/interface_implementer.go index f88d540d11b..15a7c0cae68 100644 --- a/hack/generator/pkg/astmodel/interface_implementer.go +++ b/hack/generator/pkg/astmodel/interface_implementer.go @@ -41,10 +41,10 @@ func (i InterfaceImplementer) WithInterface(iface *InterfaceImplementation) Inte } func (i InterfaceImplementer) References() TypeNameSet { - var results TypeNameSet + results := NewTypeNameSet() for _, iface := range i.interfaces { for ref := range iface.References() { - results = results.Add(ref) + results.Add(ref) } } diff --git a/hack/generator/pkg/astmodel/object_type.go b/hack/generator/pkg/astmodel/object_type.go index 583905bdada..1c7e5a98188 100644 --- a/hack/generator/pkg/astmodel/object_type.go +++ b/hack/generator/pkg/astmodel/object_type.go @@ -233,11 +233,11 @@ func (objectType *ObjectType) RequiredPackageReferences() *PackageReferenceSet { func (objectType *ObjectType) References() TypeNameSet { results := NewTypeNameSet() for _, property := range objectType.properties { - results = results.AddAll(property.PropertyType().References()) + results.AddAll(property.PropertyType().References()) } for _, property := range objectType.embedded { - results = results.AddAll(property.PropertyType().References()) + results.AddAll(property.PropertyType().References()) } // Not collecting types from functions deliberately. diff --git a/hack/generator/pkg/astmodel/oneof_type.go b/hack/generator/pkg/astmodel/oneof_type.go index 6561179de1a..cbcf199fee5 100644 --- a/hack/generator/pkg/astmodel/oneof_type.go +++ b/hack/generator/pkg/astmodel/oneof_type.go @@ -61,8 +61,7 @@ func (oneOf *OneOfType) Types() ReadonlyTypeSet { // References returns any type referenced by the OneOf types func (oneOf *OneOfType) References() TypeNameSet { - var result TypeNameSet - + result := NewTypeNameSet() oneOf.types.ForEach(func(t Type, _ int) { result = SetUnion(result, t.References()) }) diff --git a/hack/generator/pkg/astmodel/reference_graph.go b/hack/generator/pkg/astmodel/reference_graph.go index 1a29673a084..2cd1104e91a 100644 --- a/hack/generator/pkg/astmodel/reference_graph.go +++ b/hack/generator/pkg/astmodel/reference_graph.go @@ -18,7 +18,7 @@ type ReferenceGraph struct { // CollectResourceDefinitions returns a TypeNameSet of all of the // root definitions in the definitions passed in. func CollectResourceDefinitions(definitions Types) TypeNameSet { - resources := make(TypeNameSet) + resources := NewTypeNameSet() for _, def := range definitions { if _, ok := def.Type().(*ResourceType); ok { resources.Add(def.Name()) @@ -45,7 +45,7 @@ func CollectARMSpecAndStatusDefinitions(definitions Types) TypeNameSet { return armName, nil } - armSpecAndStatus := make(TypeNameSet) + armSpecAndStatus := NewTypeNameSet() for _, def := range definitions { if resourceType, ok := definitions.ResolveResourceType(def.Type()); ok { diff --git a/hack/generator/pkg/astmodel/reference_graph_test.go b/hack/generator/pkg/astmodel/reference_graph_test.go index f66cfc20a35..b6894f11cdd 100644 --- a/hack/generator/pkg/astmodel/reference_graph_test.go +++ b/hack/generator/pkg/astmodel/reference_graph_test.go @@ -21,7 +21,7 @@ func Test_ReferenceGraph_Gives_Correct_Depth(t *testing.T) { } names := func(ns ...string) TypeNameSet { - result := make(TypeNameSet) + result := NewTypeNameSet() for _, n := range ns { result.Add(name(n)) } diff --git a/hack/generator/pkg/astmodel/type_name_set.go b/hack/generator/pkg/astmodel/type_name_set.go index 9ea3943471e..f8bb119db5c 100644 --- a/hack/generator/pkg/astmodel/type_name_set.go +++ b/hack/generator/pkg/astmodel/type_name_set.go @@ -5,7 +5,9 @@ package astmodel -import "fmt" +import ( + "fmt" +) // TypeNameSet stores type names in no particular order without // duplicates. @@ -14,23 +16,17 @@ type TypeNameSet map[TypeName]struct{} // NewTypeNameSet makes a TypeNameSet containing the specified // names. If no elements are passed it might be nil. func NewTypeNameSet(initial ...TypeName) TypeNameSet { - var result TypeNameSet + result := make(TypeNameSet) for _, name := range initial { - result = result.Add(name) + result.Add(name) } + return result } -// Add includes the passed name in the set and returns the updated -// set, so that adding can work for a nil set - this makes it more -// convenient to add to sets kept in a map (in the way you might with -// a map of slices). -func (ts TypeNameSet) Add(val TypeName) TypeNameSet { - if ts == nil { - ts = make(TypeNameSet) - } +// Add includes the passed name in the set +func (ts TypeNameSet) Add(val TypeName) { ts[val] = struct{}{} - return ts } // Contains returns whether this name is in the set. Works for nil @@ -43,16 +39,9 @@ func (ts TypeNameSet) Contains(val TypeName) bool { return found } -// Remove removes the specified item if it is in the set. If it is not in -// the set this is a no-op. -func (ts TypeNameSet) Remove(val TypeName) TypeNameSet { - if ts == nil { - return ts - } - +// Remove removes the specified item if it is in the set. If it is not in the set this is a no-op. +func (ts TypeNameSet) Remove(val TypeName) { delete(ts, val) - - return ts } func (ts TypeNameSet) Equals(set TypeNameSet) bool { @@ -72,39 +61,31 @@ func (ts TypeNameSet) Equals(set TypeNameSet) bool { } // AddAll adds the provided TypeNameSet to the set -func (ts TypeNameSet) AddAll(other TypeNameSet) TypeNameSet { - if ts == nil { - ts = make(TypeNameSet) - } - +func (ts TypeNameSet) AddAll(other TypeNameSet) { for val := range other { ts[val] = struct{}{} } - - return ts } // Single returns the single TypeName in the set. This panics if there is not a single item in the set. func (ts TypeNameSet) Single() TypeName { - if len(ts) != 1 { - panic(fmt.Sprintf("Single() cannot be called with %d types in the set", len(ts))) - } - - for name := range ts { - return name + if len(ts) == 1 { + for name := range ts { + return name + } } - panic("Reached unreachable code") + panic(fmt.Sprintf("Single() cannot be called with %d types in the set", len(ts))) } // SetUnion returns a new set with all of the names in s1 or s2. func SetUnion(s1, s2 TypeNameSet) TypeNameSet { - var result TypeNameSet + result := NewTypeNameSet() for val := range s1 { - result = result.Add(val) + result.Add(val) } for val := range s2 { - result = result.Add(val) + result.Add(val) } return result } diff --git a/hack/generator/pkg/astmodel/type_name_set_test.go b/hack/generator/pkg/astmodel/type_name_set_test.go new file mode 100644 index 00000000000..6349662ff43 --- /dev/null +++ b/hack/generator/pkg/astmodel/type_name_set_test.go @@ -0,0 +1,46 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +package astmodel + +import ( + "testing" + + . "github.com/onsi/gomega" +) + +var ( + oneTypeName = MakeTypeName(GenRuntimeReference, "One") + twoTypeName = MakeTypeName(GenRuntimeReference, "Two") +) + +/* + * Empty Set tests + */ + +func Test_TypeNameSet_WhenEmpty_HasLengthZero(t *testing.T) { + g := NewGomegaWithT(t) + emptySet := NewTypeNameSet() + g.Expect(len(emptySet)).To(Equal(0)) +} + +/* + * Add() Tests + */ + +func Test_TypeNameSet_AfterAddingFirstItem_ContainsItem(t *testing.T) { + g := NewGomegaWithT(t) + set := NewTypeNameSet() + set.Add(oneTypeName) + g.Expect(set.Contains(oneTypeName)).To(BeTrue()) +} + +func Test_TypeNameSet_AfterAddingSecondItem_ContainsItem(t *testing.T) { + g := NewGomegaWithT(t) + set := NewTypeNameSet() + set.Add(oneTypeName) + set.Add(twoTypeName) + g.Expect(set.Contains(twoTypeName)).To(BeTrue()) +} diff --git a/hack/generator/pkg/codegen/embeddedresources/remove_empty_objects.go b/hack/generator/pkg/codegen/embeddedresources/remove_empty_objects.go index 14e833ce029..1b87ac0a640 100644 --- a/hack/generator/pkg/codegen/embeddedresources/remove_empty_objects.go +++ b/hack/generator/pkg/codegen/embeddedresources/remove_empty_objects.go @@ -32,7 +32,7 @@ func RemoveEmptyObjects(types astmodel.Types) (astmodel.Types, error) { } func findEmptyObjectTypes(types astmodel.Types) astmodel.TypeNameSet { - result := make(astmodel.TypeNameSet) + result := astmodel.NewTypeNameSet() for _, def := range types { ot, ok := astmodel.AsObjectType(def.Type()) diff --git a/hack/generator/pkg/codegen/embeddedresources/remover.go b/hack/generator/pkg/codegen/embeddedresources/remover.go index 4daaeb309e5..2e58e62de5b 100644 --- a/hack/generator/pkg/codegen/embeddedresources/remover.go +++ b/hack/generator/pkg/codegen/embeddedresources/remover.go @@ -273,11 +273,21 @@ func findSubResourcePropertiesTypeNames(types astmodel.Types) (map[astmodel.Type errs = append(errs, errors.Wrapf(err, "couldn't extract spec/status properties from %q", def.Name())) continue } + if specPropertiesTypeName != nil { - result[owner] = result[owner].Add(*specPropertiesTypeName) + if result[owner] == nil { + result[owner] = astmodel.NewTypeNameSet(*specPropertiesTypeName) + } else { + result[owner].Add(*specPropertiesTypeName) + } } + if statusPropertiesTypeName != nil { - result[owner] = result[owner].Add(*statusPropertiesTypeName) + if result[owner] == nil { + result[owner] = astmodel.NewTypeNameSet(*statusPropertiesTypeName) + } else { + result[owner].Add(*statusPropertiesTypeName) + } } } @@ -322,7 +332,7 @@ func findAllResourcePropertiesTypes(types astmodel.Types) (astmodel.TypeNameSet, }) var errs []error - result := make(astmodel.TypeNameSet) + result := astmodel.NewTypeNameSet() // Identify sub-resources and their "properties", associate them with parent resource // Look through parent resource for subresource properties @@ -338,11 +348,13 @@ func findAllResourcePropertiesTypes(types astmodel.Types) (astmodel.TypeNameSet, errs = append(errs, errors.Wrapf(err, "couldn't extract spec/status properties from %q", def.Name())) continue } + if specPropertiesTypeName != nil { - result = result.Add(*specPropertiesTypeName) + result.Add(*specPropertiesTypeName) } + if statusPropertiesTypeName != nil { - result = result.Add(*statusPropertiesTypeName) + result.Add(*statusPropertiesTypeName) } } @@ -362,7 +374,7 @@ func findAllResourceStatusTypes(types astmodel.Types) astmodel.TypeNameSet { return ok }) - result := make(astmodel.TypeNameSet) + result := astmodel.NewTypeNameSet() // Identify sub-resources and their "properties", associate them with parent resource // Look through parent resource for subresource properties @@ -378,7 +390,7 @@ func findAllResourceStatusTypes(types astmodel.Types) astmodel.TypeNameSet { continue } - result = result.Add(statusName) + result.Add(statusName) } return result @@ -412,7 +424,7 @@ func requiredResourceProperties() []string { } // optionalResourceProperties are properties which may or may not be on a resource. Technically all resources -// should have all of these properties, but because we drop the top-level allof that joins resource types with +// should have all of these properties, but because we drop the top-level AllOf that joins resource types with // ResourceBase when parsing schemas sometimes they aren't defined. func optionalResourceProperties() []string { return []string{ diff --git a/hack/generator/pkg/codegen/embeddedresources/renamer.go b/hack/generator/pkg/codegen/embeddedresources/renamer.go index 3fd14f6d8a2..e4fca51f87c 100644 --- a/hack/generator/pkg/codegen/embeddedresources/renamer.go +++ b/hack/generator/pkg/codegen/embeddedresources/renamer.go @@ -145,7 +145,11 @@ func simplifyTypeNames(types astmodel.Types, flag astmodel.TypeFlag) (astmodel.T return nil, err } - updatedNames[embeddedName.original] = updatedNames[embeddedName.original].Add(def.Name()) + if updatedNames[embeddedName.original] == nil { + updatedNames[embeddedName.original] = astmodel.NewTypeNameSet(def.Name()) + } else { + updatedNames[embeddedName.original].Add(def.Name()) + } } } diff --git a/hack/generator/pkg/codegen/pipeline/strip_unused_types_test.go b/hack/generator/pkg/codegen/pipeline/strip_unused_types_test.go index 6180de27623..a18f9d02c26 100644 --- a/hack/generator/pkg/codegen/pipeline/strip_unused_types_test.go +++ b/hack/generator/pkg/codegen/pipeline/strip_unused_types_test.go @@ -43,9 +43,9 @@ func TestConnectionChecker_Avoids_Cycles(t *testing.T) { graph := astmodel.NewReferenceGraph(roots, references) connectedSet := graph.Connected() - var names astmodel.TypeNameSet + names := astmodel.NewTypeNameSet() for name := range connectedSet { - names = names.Add(name) + names.Add(name) } g.Expect(names).To(Equal(makeSet("res1", "res2", "A", "B", "C", "D"))) diff --git a/hack/generator/pkg/config/configuration.go b/hack/generator/pkg/config/configuration.go index ebc7bd5e562..0848aad215a 100644 --- a/hack/generator/pkg/config/configuration.go +++ b/hack/generator/pkg/config/configuration.go @@ -328,7 +328,7 @@ func buildExportFilterFunc(f *ExportFilter, allTypes astmodel.Types) ExportFilte } case ExportFilterIncludeTransitive: - applicableTypes := make(astmodel.TypeNameSet) + applicableTypes := astmodel.NewTypeNameSet() for tn := range allTypes { if f.AppliesToType(tn) { collectAllReferencedTypes(allTypes, tn, applicableTypes) diff --git a/hack/generator/pkg/config/type_matcher.go b/hack/generator/pkg/config/type_matcher.go index d50c340613a..4f49f92e4cc 100644 --- a/hack/generator/pkg/config/type_matcher.go +++ b/hack/generator/pkg/config/type_matcher.go @@ -41,7 +41,8 @@ func (t *TypeMatcher) Initialize() error { t.groupRegex = createGlobbingRegex(t.Group) t.versionRegex = createGlobbingRegex(t.Version) t.nameRegex = createGlobbingRegex(t.Name) - t.matchedTypes = make(astmodel.TypeNameSet) + t.matchedTypes = astmodel.NewTypeNameSet() + // Default MatchRequired if t.MatchRequired == nil { temp := true @@ -87,7 +88,11 @@ func (t *TypeMatcher) AppliesToType(typeName astmodel.TypeName) bool { // Track this match so we can later report if we didn't match anything if result { - t.matchedTypes = t.matchedTypes.Add(typeName) + if t.matchedTypes == nil { + t.matchedTypes = astmodel.NewTypeNameSet(typeName) + } else { + t.matchedTypes.Add(typeName) + } } return result