Skip to content

Commit

Permalink
fix parse nested structs and aliases (#1866)
Browse files Browse the repository at this point in the history
Co-authored-by: ma.mikhaylov <ma.mikhaylov@tinkoff.ru>
  • Loading branch information
zdon0 and ma.mikhaylov authored Aug 20, 2024
1 parent c7f1cd8 commit 10030b0
Show file tree
Hide file tree
Showing 9 changed files with 154 additions and 24 deletions.
20 changes: 18 additions & 2 deletions generics.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ func (pkgDefs *PackagesDefinitions) getTypeFromGenericParam(genericParam string,
Enums: typeSpecDef.Enums,
PkgPath: typeSpecDef.PkgPath,
ParentSpec: typeSpecDef.ParentSpec,
SchemaName: "array_" + typeSpecDef.SchemaName,
NotUnique: false,
}
}
Expand Down Expand Up @@ -96,16 +97,17 @@ func (pkgDefs *PackagesDefinitions) getTypeFromGenericParam(genericParam string,
Enums: typeSpecDef.Enums,
PkgPath: typeSpecDef.PkgPath,
ParentSpec: typeSpecDef.ParentSpec,
SchemaName: "map_" + parts[0] + "_" + typeSpecDef.SchemaName,
NotUnique: false,
}

}
if IsGolangPrimitiveType(genericParam) {
return &TypeSpecDef{
TypeSpec: &ast.TypeSpec{
Name: ast.NewIdent(genericParam),
Type: ast.NewIdent(genericParam),
},
SchemaName: genericParam,
}
}
return pkgDefs.FindTypeSpec(genericParam, file)
Expand Down Expand Up @@ -155,14 +157,27 @@ func (pkgDefs *PackagesDefinitions) parametrizeGenericType(file *ast.File, origi
}

name = fmt.Sprintf("%s%s-", string(IgnoreNameOverridePrefix), original.TypeName())
schemaName := fmt.Sprintf("%s-", original.SchemaName)

var nameParts []string
var schemaNameParts []string

for _, def := range formals {
if specDef, ok := genericParamTypeDefs[def.Name]; ok {
nameParts = append(nameParts, specDef.TypeName())
nameParts = append(nameParts, specDef.Name)

schemaNamePart := specDef.Name

if specDef.TypeSpec != nil {
schemaNamePart = specDef.TypeSpec.SchemaName
}

schemaNameParts = append(schemaNameParts, schemaNamePart)
}
}

name += normalizeGenericTypeName(strings.Join(nameParts, "-"))
schemaName += normalizeGenericTypeName(strings.Join(schemaNameParts, "-"))

if typeSpec, ok := pkgDefs.uniqueDefinitions[name]; ok {
return typeSpec
Expand All @@ -180,6 +195,7 @@ func (pkgDefs *PackagesDefinitions) parametrizeGenericType(file *ast.File, origi
Doc: original.TypeSpec.Doc,
Assign: original.TypeSpec.Assign,
},
SchemaName: schemaName,
}
pkgDefs.uniqueDefinitions[name] = parametrizedTypeSpec

Expand Down
24 changes: 17 additions & 7 deletions packages.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,8 @@ func (pkgDefs *PackagesDefinitions) parseTypesFromFile(astFile *ast.File, packag
pkgDefs.uniqueDefinitions[fullName] = nil
anotherTypeDef.NotUnique = true
pkgDefs.uniqueDefinitions[anotherTypeDef.TypeName()] = anotherTypeDef
anotherTypeDef.SetSchemaName()

typeSpecDef.NotUnique = true
fullName = typeSpecDef.TypeName()
pkgDefs.uniqueDefinitions[fullName] = typeSpecDef
Expand All @@ -174,6 +176,8 @@ func (pkgDefs *PackagesDefinitions) parseTypesFromFile(astFile *ast.File, packag
pkgDefs.uniqueDefinitions[fullName] = typeSpecDef
}

typeSpecDef.SetSchemaName()

if pkgDefs.packages[typeSpecDef.PkgPath] == nil {
pkgDefs.packages[typeSpecDef.PkgPath] = NewPackageDefinitions(astFile.Name.Name, typeSpecDef.PkgPath).AddTypeSpec(typeSpecDef.Name(), typeSpecDef)
} else if _, ok = pkgDefs.packages[typeSpecDef.PkgPath].TypeDefinitions[typeSpecDef.Name()]; !ok {
Expand Down Expand Up @@ -579,17 +583,23 @@ func (pkgDefs *PackagesDefinitions) FindTypeSpec(typeName string, file *ast.File
return typeDef
}

//in case that comment //@name renamed the type with a name without a dot
typeDef, ok = pkgDefs.uniqueDefinitions[typeName]
if ok {
return typeDef
}

name := parts[0]
typeDef, ok = pkgDefs.uniqueDefinitions[fullTypeName(file.Name.Name, name)]
if !ok {
pkgPaths, externalPkgPaths := pkgDefs.findPackagePathFromImports("", file)
typeDef = pkgDefs.findTypeSpecFromPackagePaths(pkgPaths, externalPkgPaths, name)
}
return pkgDefs.parametrizeGenericType(file, typeDef, typeName)

if typeDef != nil {
return pkgDefs.parametrizeGenericType(file, typeDef, typeName)
}

//in case that comment //@name renamed the type with a name without a dot
for _, v := range pkgDefs.uniqueDefinitions {
if v.SchemaName == typeName {
return v
}
}

return nil
}
8 changes: 7 additions & 1 deletion parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -1316,8 +1316,14 @@ func (parser *Parser) ParseDefinition(typeSpecDef *TypeSpecDef) (*Schema, error)
}
}

schemaName := typeName

if typeSpecDef.SchemaName != "" {
schemaName = typeSpecDef.SchemaName
}

sch := Schema{
Name: typeName,
Name: schemaName,
PkgPath: typeSpecDef.PkgPath,
Schema: definition,
}
Expand Down
19 changes: 19 additions & 0 deletions parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4367,3 +4367,22 @@ func Test(){
assert.True(t, ok)
assert.NotNil(t, val2.Get)
}

func TestParser_EmbeddedStructAsOtherAliasGoListNested(t *testing.T) {
t.Parallel()

p := New(SetParseDependency(1), ParseUsingGoList(true))

p.parseGoList = true

searchDir := "testdata/alias_nested"
expected, err := os.ReadFile(filepath.Join(searchDir, "expected.json"))
assert.NoError(t, err)

err = p.ParseAPI(searchDir, "cmd/main/main.go", 0)
assert.NoError(t, err)

b, err := json.MarshalIndent(p.swagger, "", " ")
assert.NoError(t, err)
assert.Equal(t, string(expected), string(b))
}
9 changes: 9 additions & 0 deletions testdata/alias_nested/cmd/main/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package main

import "github.com/swaggo/swag/testdata/alias_nested/pkg/good"

// @Success 200 {object} good.Gen
// @Router /api [get].
func main() {
var _ good.Gen
}
38 changes: 38 additions & 0 deletions testdata/alias_nested/expected.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
{
"swagger": "2.0",
"info": {
"contact": {}
},
"paths": {
"/api": {
"get": {
"responses": {
"200": {
"description": "OK",
"schema": {
"$ref": "#/definitions/Gen"
}
}
}
}
}
},
"definitions": {
"Gen": {
"type": "object",
"properties": {
"emb": {
"$ref": "#/definitions/github_com_swaggo_swag_testdata_alias_nested_pkg_good.Emb"
}
}
},
"github_com_swaggo_swag_testdata_alias_nested_pkg_good.Emb": {
"type": "object",
"properties": {
"good": {
"type": "boolean"
}
}
}
}
}
5 changes: 5 additions & 0 deletions testdata/alias_nested/pkg/bad/data.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package bad

type Emb struct {
Bad bool `json:"bad"`
} // @name Emb
9 changes: 9 additions & 0 deletions testdata/alias_nested/pkg/good/data.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package good

type Gen struct {
Emb Emb `json:"emb"`
} // @name Gen

type Emb struct {
Good bool `json:"good"`
}
46 changes: 32 additions & 14 deletions types.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ type TypeSpecDef struct {
PkgPath string
ParentSpec ast.Decl

SchemaName string

NotUnique bool
}

Expand All @@ -46,20 +48,6 @@ func (t *TypeSpecDef) Name() string {
func (t *TypeSpecDef) TypeName() string {
if ignoreNameOverride(t.TypeSpec.Name.Name) {
return t.TypeSpec.Name.Name[1:]
} else if t.TypeSpec.Comment != nil {
// get alias from comment '// @name '
const regexCaseInsensitive = "(?i)"
reTypeName, err := regexp.Compile(regexCaseInsensitive + `^@name\s+(\S+)`)
if err != nil {
panic(err)
}
for _, comment := range t.TypeSpec.Comment.List {
trimmedComment := strings.TrimSpace(strings.TrimLeft(comment.Text, "/"))
texts := reTypeName.FindStringSubmatch(trimmedComment)
if len(texts) > 1 {
return texts[1]
}
}
}

var names []string
Expand All @@ -86,6 +74,36 @@ func (t *TypeSpecDef) FullPath() string {
return t.PkgPath + "." + t.Name()
}

const regexCaseInsensitive = "(?i)"

var reTypeName = regexp.MustCompile(regexCaseInsensitive + `^@name\s+(\S+)`)

func (t *TypeSpecDef) Alias() string {
if t.TypeSpec.Comment == nil {
return ""
}

// get alias from comment '// @name '
for _, comment := range t.TypeSpec.Comment.List {
trimmedComment := strings.TrimSpace(strings.TrimLeft(comment.Text, "/"))
texts := reTypeName.FindStringSubmatch(trimmedComment)
if len(texts) > 1 {
return texts[1]
}
}

return ""
}

func (t *TypeSpecDef) SetSchemaName() {
if alias := t.Alias(); alias != "" {
t.SchemaName = alias
return
}

t.SchemaName = t.TypeName()
}

// AstFileInfo information of an ast.File.
type AstFileInfo struct {
//FileSet the FileSet object which is used to parse this go source file
Expand Down

0 comments on commit 10030b0

Please sign in to comment.