Skip to content

Commit

Permalink
fix: Support postgres, mysql diarect (#42)
Browse files Browse the repository at this point in the history
  • Loading branch information
ginokent authored Aug 12, 2024
2 parents a64c8e7 + 35b5276 commit f8e328c
Show file tree
Hide file tree
Showing 11 changed files with 439 additions and 155 deletions.
37 changes: 19 additions & 18 deletions internal/arcgen/lang/go/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,23 +57,7 @@ func generate(arcSrcSetSlice ARCSourceSetSlice) error {
if config.GenerateGoCRUDPackage() {
crudFileExt := ".crud" + genFileExt

if err := func() error {
filename := filepath.Join(config.GoCRUDPackagePath(), "common"+crudFileExt)
f, err := os.OpenFile(filename, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, rw_r__r__)
if err != nil {
return errorz.Errorf("os.OpenFile: %w", err)
}
defer f.Close()

if err := fprintCRUDCommon(f, bytes.NewBuffer(nil), arcSrcSetSlice); err != nil {
return errorz.Errorf("sprint: %w", err)
}

return nil
}(); err != nil {
return errorz.Errorf("f: %w", err)
}

crudFiles := make([]string, 0)
for _, arcSrcSet := range arcSrcSetSlice {
// closure for defer
if err := func() error {
Expand All @@ -84,7 +68,7 @@ func generate(arcSrcSetSlice ARCSourceSetSlice) error {
return errorz.Errorf("os.OpenFile: %w", err)
}
defer f.Close()
f.Name()
crudFiles = append(crudFiles, filename)

if err := fprintCRUD(
f,
Expand All @@ -98,6 +82,23 @@ func generate(arcSrcSetSlice ARCSourceSetSlice) error {
return errorz.Errorf("f: %w", err)
}
}

if err := func() error {
filename := filepath.Join(config.GoCRUDPackagePath(), "common"+crudFileExt)
f, err := os.OpenFile(filename, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, rw_r__r__)
if err != nil {
return errorz.Errorf("os.OpenFile: %w", err)
}
defer f.Close()

if err := fprintCRUDCommon(f, bytes.NewBuffer(nil), arcSrcSetSlice, crudFiles); err != nil {
return errorz.Errorf("sprint: %w", err)
}

return nil
}(); err != nil {
return errorz.Errorf("f: %w", err)
}
}

return nil
Expand Down
174 changes: 149 additions & 25 deletions internal/arcgen/lang/go/generate_crud_common.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,22 @@ package arcgengo

import (
"go/ast"
"go/parser"
"go/printer"
"go/token"
"io"
"path/filepath"
"strconv"
"strings"

errorz "github.com/kunitsucom/util.go/errors"

"github.com/kunitsucom/arcgen/internal/arcgen/lang/util"
"github.com/kunitsucom/arcgen/internal/config"
)

func fprintCRUDCommon(osFile osFile, buf buffer, arcSrcSetSlice ARCSourceSetSlice) error {
content, err := generateCRUDCommonFileContent(buf, arcSrcSetSlice)
func fprintCRUDCommon(osFile osFile, buf buffer, arcSrcSetSlice ARCSourceSetSlice, crudFiles []string) error {
content, err := generateCRUDCommonFileContent(buf, arcSrcSetSlice, crudFiles)
if err != nil {
return errorz.Errorf("generateCRUDCommonFileContent: %w", err)
}
Expand All @@ -27,8 +30,13 @@ func fprintCRUDCommon(osFile osFile, buf buffer, arcSrcSetSlice ARCSourceSetSlic
return nil
}

//nolint:funlen
func generateCRUDCommonFileContent(buf buffer, _ ARCSourceSetSlice) (string, error) {
const (
sqlQueryerContextVarName = "sqlContext"
sqlQueryerContextTypeName = "sqlQueryerContext"
)

//nolint:cyclop,funlen,gocognit,maintidx
func generateCRUDCommonFileContent(buf buffer, arcSrcSetSlice ARCSourceSetSlice, crudFiles []string) (string, error) {
astFile := &ast.File{
// package
Name: &ast.Ident{
Expand All @@ -38,18 +46,19 @@ func generateCRUDCommonFileContent(buf buffer, _ ARCSourceSetSlice) (string, err
Decls: []ast.Decl{},
}

// // Since all directories are the same from arcSrcSetSlice[0].Filename to arcSrcSetSlice[len(-1)].Filename,
// // get the package path from arcSrcSetSlice[0].Filename.
// dir := filepath.Dir(arcSrcSetSlice[0].Filename)
// structPackagePath, err := util.GetPackagePath(dir)
// if err != nil {
// return "", errorz.Errorf("GetPackagePath: %w", err)
// }
// Since all directories are the same from arcSrcSetSlice[0].Filename to arcSrcSetSlice[len(-1)].Filename,
// get the package path from arcSrcSetSlice[0].Filename.
dir := filepath.Dir(arcSrcSetSlice[0].Filename)
structPackagePath, err := util.GetPackagePath(dir)
if err != nil {
return "", errorz.Errorf("GetPackagePath: %w", err)
}

astFile.Decls = append(astFile.Decls,
// import (
// "context"
// "database/sql"
// "log/slog"
//
// dao "path/to/your/dao"
// )
Expand All @@ -62,15 +71,18 @@ func generateCRUDCommonFileContent(buf buffer, _ ARCSourceSetSlice) (string, err
&ast.ImportSpec{
Path: &ast.BasicLit{Kind: token.STRING, Value: strconv.Quote("database/sql")},
},
// &ast.ImportSpec{
// Name: &ast.Ident{Name: "dao"},
// Path: &ast.BasicLit{Kind: token.STRING, Value: strconv.Quote(structPackagePath)},
// },
&ast.ImportSpec{
Path: &ast.BasicLit{Kind: token.STRING, Value: strconv.Quote("log/slog")},
},
&ast.ImportSpec{
Name: &ast.Ident{Name: "dao"},
Path: &ast.BasicLit{Kind: token.STRING, Value: strconv.Quote(structPackagePath)},
},
},
},
)

// type sqlContext interface {
// type sqlQueryerContext interface {
// QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
// QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
// ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
Expand All @@ -80,7 +92,8 @@ func generateCRUDCommonFileContent(buf buffer, _ ARCSourceSetSlice) (string, err
Tok: token.TYPE,
Specs: []ast.Spec{
&ast.TypeSpec{
Name: &ast.Ident{Name: "sqlContext"},
// Assign: token.Pos(1),
Name: &ast.Ident{Name: sqlQueryerContextTypeName},
Type: &ast.InterfaceType{
Methods: &ast.FieldList{
List: []*ast.Field{
Expand Down Expand Up @@ -133,27 +146,138 @@ func generateCRUDCommonFileContent(buf buffer, _ ARCSourceSetSlice) (string, err
},
)

// type Queryer struct {}
// type _CRUD struct {
// }
astFile.Decls = append(astFile.Decls,
&ast.GenDecl{
Tok: token.TYPE,
Specs: []ast.Spec{
&ast.TypeSpec{
Name: &ast.Ident{Name: "Queryer"},
Name: &ast.Ident{Name: config.GoCRUDTypeNameUnexported()},
Type: &ast.StructType{Fields: &ast.FieldList{}},
},
},
},
)

// func NewQueryer() *Query {
// return &Queryer{}
// }
// func LoggerFromContext(ctx context.Context) *slog.Logger {
// if ctx == nil {
// return slog.Default()
// }
// if logger, ok := ctx.Value((*slog.Logger)(nil)).(*slog.Logger); ok {
// return logger
// }
// return slog.Default()
// }
astFile.Decls = append(astFile.Decls,
&ast.FuncDecl{
Name: &ast.Ident{Name: "LoggerFromContext"},
Type: &ast.FuncType{Params: &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{{Name: "ctx"}}, Type: &ast.Ident{Name: "context.Context"}}}}, Results: &ast.FieldList{List: []*ast.Field{{Type: &ast.StarExpr{X: &ast.SelectorExpr{X: &ast.Ident{Name: "slog"}, Sel: &ast.Ident{Name: "Logger"}}}}}}},
Body: &ast.BlockStmt{
List: []ast.Stmt{
&ast.IfStmt{
Cond: &ast.BinaryExpr{X: &ast.Ident{Name: "ctx"}, Op: token.EQL, Y: &ast.Ident{Name: "nil"}},
Body: &ast.BlockStmt{List: []ast.Stmt{
&ast.ReturnStmt{Results: []ast.Expr{&ast.CallExpr{Fun: &ast.SelectorExpr{X: &ast.Ident{Name: "slog"}, Sel: &ast.Ident{Name: "Default"}}}}},
}},
},
&ast.IfStmt{
// if logger, ok := ctx.Value((*slog.Logger)(nil)).(*slog.Logger); ok {
Init: &ast.AssignStmt{
Lhs: []ast.Expr{&ast.Ident{Name: "logger"}, &ast.Ident{Name: "ok"}},
Tok: token.DEFINE,
Rhs: []ast.Expr{
&ast.TypeAssertExpr{
X: &ast.CallExpr{
Fun: &ast.Ident{Name: "ctx.Value"},
Args: []ast.Expr{&ast.CallExpr{Fun: &ast.ParenExpr{X: &ast.StarExpr{X: &ast.SelectorExpr{X: &ast.Ident{Name: "slog"}, Sel: &ast.Ident{Name: "Logger"}}}}, Args: []ast.Expr{&ast.Ident{Name: "nil"}}}},
},
Type: &ast.StarExpr{X: &ast.SelectorExpr{X: &ast.Ident{Name: "slog"}, Sel: &ast.Ident{Name: "Logger"}}},
},
},
},
Cond: &ast.Ident{Name: "ok"},
Body: &ast.BlockStmt{List: []ast.Stmt{&ast.ReturnStmt{Results: []ast.Expr{&ast.Ident{Name: "logger"}}}}},
},
&ast.ReturnStmt{Results: []ast.Expr{&ast.CallExpr{Fun: &ast.SelectorExpr{X: &ast.Ident{Name: "slog"}, Sel: &ast.Ident{Name: "Default"}}}}},
},
},
},
)

// func LoggerWithContext(ctx context.Context, logger *slog.Logger) context.Context {
// return context.WithValue(ctx, (*slog.Logger)(nil), logger)
// }
astFile.Decls = append(astFile.Decls,
&ast.FuncDecl{
Name: &ast.Ident{Name: "LoggerWithContext"},
Type: &ast.FuncType{Params: &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{{Name: "ctx"}}, Type: &ast.Ident{Name: "context.Context"}}, {Names: []*ast.Ident{{Name: "logger"}}, Type: &ast.StarExpr{X: &ast.SelectorExpr{X: &ast.Ident{Name: "slog"}, Sel: &ast.Ident{Name: "Logger"}}}}}}, Results: &ast.FieldList{List: []*ast.Field{{Type: &ast.Ident{Name: "context.Context"}}}}},
Body: &ast.BlockStmt{
List: []ast.Stmt{
&ast.ReturnStmt{Results: []ast.Expr{&ast.CallExpr{Fun: &ast.SelectorExpr{X: &ast.Ident{Name: "context"}, Sel: &ast.Ident{Name: "WithValue"}}, Args: []ast.Expr{&ast.Ident{Name: "ctx"}, &ast.CallExpr{Fun: &ast.ParenExpr{X: &ast.StarExpr{X: &ast.SelectorExpr{X: &ast.Ident{Name: "slog"}, Sel: &ast.Ident{Name: "Logger"}}}}, Args: []ast.Expr{&ast.Ident{Name: "nil"}}}, &ast.Ident{Name: "logger"}}}}},
},
},
},
)

// type CRUD interface {
// Create{StructName}(ctx context.Context, sqlQueryer sqlQueryerContext, s *{Struct}) error
// ...
// }
methods := make([]*ast.Field, 0)
fset := token.NewFileSet()
for _, crudFile := range crudFiles {
rootNode, err := parser.ParseFile(fset, crudFile, nil, parser.ParseComments)
if err != nil {
// MEMO: parser.ParseFile err contains file path, so no need to log it
return "", errorz.Errorf("parser.ParseFile: %w", err)
}

// MEMO: Inspect is used to get the method declaration from the file
ast.Inspect(rootNode, func(node ast.Node) bool {
switch n := node.(type) {
case *ast.FuncDecl:
//nolint:nestif
if n.Recv != nil && len(n.Recv.List) > 0 {
if t, ok := n.Recv.List[0].Type.(*ast.StarExpr); ok {
if ident, ok := t.X.(*ast.Ident); ok {
if ident.Name == config.GoCRUDTypeNameUnexported() {
methods = append(methods, &ast.Field{
Names: []*ast.Ident{{Name: n.Name.Name}},
Type: n.Type,
})
}
}
}
}
default:
// noop
}
return true
})
}
astFile.Decls = append(astFile.Decls,
&ast.GenDecl{
Tok: token.TYPE,
Specs: []ast.Spec{
&ast.TypeSpec{
Name: &ast.Ident{Name: config.GoCRUDTypeName()},
Type: &ast.InterfaceType{
Methods: &ast.FieldList{List: methods},
},
},
},
},
)

// func NewCRUD() CRUD {
// return &_CRUD{}
// }
astFile.Decls = append(astFile.Decls,
&ast.FuncDecl{
Name: &ast.Ident{Name: "NewQueryer"},
Type: &ast.FuncType{Results: &ast.FieldList{List: []*ast.Field{{Type: &ast.StarExpr{X: &ast.Ident{Name: "Queryer"}}}}}},
Body: &ast.BlockStmt{List: []ast.Stmt{&ast.ReturnStmt{Results: []ast.Expr{&ast.UnaryExpr{Op: token.AND, X: &ast.Ident{Name: "Queryer{}"}}}}}},
Name: &ast.Ident{Name: "New" + config.GoCRUDTypeName()},
Type: &ast.FuncType{Results: &ast.FieldList{List: []*ast.Field{{Type: &ast.Ident{Name: config.GoCRUDTypeName()}}}}},
Body: &ast.BlockStmt{List: []ast.Stmt{&ast.ReturnStmt{Results: []ast.Expr{&ast.UnaryExpr{Op: token.AND, X: &ast.Ident{Name: config.GoCRUDTypeNameUnexported() + "{}"}}}}}},
},
)

Expand Down
37 changes: 25 additions & 12 deletions internal/arcgen/lang/go/generate_crud_create.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (
"go/token"
"strconv"
"strings"

"github.com/kunitsucom/arcgen/internal/config"
)

//nolint:funlen
Expand All @@ -13,13 +15,14 @@ func generateCREATEContent(astFile *ast.File, arcSrcSet *ARCSourceSet) {
structName := arcSrc.extractStructName()
tableName := arcSrc.extractTableNameFromCommentGroup()
tableInfo := arcSrc.extractFieldNamesAndColumnNames()
columnNames := tableInfo.ColumnNames()
columnNames := tableInfo.Columns.ColumnNames()

// const Create{StructName}Query = `INSERT INTO {table_name} ({column_name1}, {column_name2}) VALUES (?, ?)`
// const Create{StructName}Query = `INSERT INTO {table_name} ({column_name1}, {column_name2}) VALUES ($1, $2)`
//
// func (q *query) Create{StructName}(ctx context.Context, queryer sqlContext, s *{Struct}) error {
// if _, err := queryer.ExecContext(ctx, Create{StructName}Query, s.{ColumnName1}, s.{ColumnName2}); err != nil {
// return fmt.Errorf("q.queryer.ExecContext: %w", err)
// func (q *query) Create{StructName}(ctx context.Context, queryer sqlQueryerContext, s *{Struct}) error {
// LoggerFromContext(ctx).Debug(Create{StructName}Query)
// if _, err := sqlContext.ExecContext(ctx, Create{StructName}Query, s.{ColumnName1}, s.{ColumnName2}); err != nil {
// return fmt.Errorf("sqlContext.ExecContext: %w", err)
// }
// return nil
// }
Expand All @@ -33,18 +36,18 @@ func generateCREATEContent(astFile *ast.File, arcSrcSet *ARCSourceSet) {
Names: []*ast.Ident{{Name: queryName}},
Values: []ast.Expr{&ast.BasicLit{
Kind: token.STRING,
Value: "`INSERT INTO " + tableName + " (" + strings.Join(columnNames, ", ") + ") VALUES (?" + strings.Repeat(", ?", len(columnNames)-1) + ")`",
Value: "`INSERT INTO " + tableName + " (" + strings.Join(columnNames, ", ") + ") VALUES (" + columnValuesPlaceholder(columnNames) + ")`",
}},
},
},
},
&ast.FuncDecl{
Recv: &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{{Name: "q"}}, Type: &ast.StarExpr{X: &ast.Ident{Name: "Queryer"}}}}},
Recv: &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{{Name: "q"}}, Type: &ast.StarExpr{X: &ast.Ident{Name: config.GoCRUDTypeNameUnexported()}}}}},
Name: &ast.Ident{Name: funcName},
Type: &ast.FuncType{
Params: &ast.FieldList{List: []*ast.Field{
{Names: []*ast.Ident{{Name: "ctx"}}, Type: &ast.Ident{Name: "context.Context"}},
{Names: []*ast.Ident{{Name: "sqlCtx"}}, Type: &ast.Ident{Name: "sqlContext"}},
{Names: []*ast.Ident{{Name: sqlQueryerContextVarName}}, Type: &ast.Ident{Name: sqlQueryerContextTypeName}},
{Names: []*ast.Ident{{Name: "s"}}, Type: &ast.StarExpr{X: &ast.Ident{Name: "dao." + structName}}},
}},
Results: &ast.FieldList{List: []*ast.Field{
Expand All @@ -53,14 +56,24 @@ func generateCREATEContent(astFile *ast.File, arcSrcSet *ARCSourceSet) {
},
Body: &ast.BlockStmt{
List: []ast.Stmt{
&ast.ExprStmt{
// LoggerFromContext(ctx).Debug(queryName)
X: &ast.CallExpr{
Fun: &ast.SelectorExpr{
X: &ast.CallExpr{Fun: &ast.Ident{Name: "LoggerFromContext"}, Args: []ast.Expr{&ast.Ident{Name: "ctx"}}},
Sel: &ast.Ident{Name: "Debug"},
},
Args: []ast.Expr{&ast.Ident{Name: queryName}},
},
},
&ast.IfStmt{
// if _, err := queryer.ExecContext(ctx, Create{StructName}Query, s.{ColumnName1}, s.{ColumnName2}); err != nil {
// if _, err := sqlQueryer.ExecContext(ctx, Create{StructName}Query, s.{ColumnName1}, s.{ColumnName2}); err != nil {
Init: &ast.AssignStmt{
Lhs: []ast.Expr{&ast.Ident{Name: "_"}, &ast.Ident{Name: "err"}},
Tok: token.DEFINE,
Rhs: []ast.Expr{&ast.CallExpr{
Fun: &ast.SelectorExpr{
X: &ast.Ident{Name: "sqlCtx"},
X: &ast.Ident{Name: sqlQueryerContextVarName},
Sel: &ast.Ident{Name: "ExecContext"},
},
Args: append(
Expand All @@ -80,10 +93,10 @@ func generateCREATEContent(astFile *ast.File, arcSrcSet *ARCSourceSet) {
// err != nil {
Cond: &ast.BinaryExpr{X: &ast.Ident{Name: "err"}, Op: token.NEQ, Y: &ast.Ident{Name: "nil"}},
Body: &ast.BlockStmt{List: []ast.Stmt{
// return fmt.Errorf("queryer.ExecContext: %w", err)
// return fmt.Errorf("sqlContext.ExecContext: %w", err)
&ast.ReturnStmt{Results: []ast.Expr{&ast.CallExpr{
Fun: &ast.SelectorExpr{X: &ast.Ident{Name: "fmt"}, Sel: &ast.Ident{Name: "Errorf"}},
Args: []ast.Expr{&ast.Ident{Name: strconv.Quote("queryer.ExecContext: %w")}, &ast.Ident{Name: "err"}},
Args: []ast.Expr{&ast.Ident{Name: strconv.Quote(sqlQueryerContextVarName + ".ExecContext: %w")}, &ast.Ident{Name: "err"}},
}}},
}},
},
Expand Down
Loading

0 comments on commit f8e328c

Please sign in to comment.