Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Support postgres, mysql diarect #42

Merged
merged 1 commit into from
Aug 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 19 additions & 18 deletions internal/arcgen/lang/go/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,23 +57,7 @@ func generate(arcSrcSetSlice ARCSourceSetSlice) error {
if config.GenerateGoCRUDPackage() {
crudFileExt := ".crud" + genFileExt

if err := func() error {
filename := filepath.Join(config.GoCRUDPackagePath(), "common"+crudFileExt)
f, err := os.OpenFile(filename, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, rw_r__r__)
if err != nil {
return errorz.Errorf("os.OpenFile: %w", err)
}
defer f.Close()

if err := fprintCRUDCommon(f, bytes.NewBuffer(nil), arcSrcSetSlice); err != nil {
return errorz.Errorf("sprint: %w", err)
}

return nil
}(); err != nil {
return errorz.Errorf("f: %w", err)
}

crudFiles := make([]string, 0)
for _, arcSrcSet := range arcSrcSetSlice {
// closure for defer
if err := func() error {
Expand All @@ -84,7 +68,7 @@ func generate(arcSrcSetSlice ARCSourceSetSlice) error {
return errorz.Errorf("os.OpenFile: %w", err)
}
defer f.Close()
f.Name()
crudFiles = append(crudFiles, filename)

if err := fprintCRUD(
f,
Expand All @@ -98,6 +82,23 @@ func generate(arcSrcSetSlice ARCSourceSetSlice) error {
return errorz.Errorf("f: %w", err)
}
}

if err := func() error {
filename := filepath.Join(config.GoCRUDPackagePath(), "common"+crudFileExt)
f, err := os.OpenFile(filename, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, rw_r__r__)
if err != nil {
return errorz.Errorf("os.OpenFile: %w", err)
}
defer f.Close()

if err := fprintCRUDCommon(f, bytes.NewBuffer(nil), arcSrcSetSlice, crudFiles); err != nil {
return errorz.Errorf("sprint: %w", err)
}

return nil
}(); err != nil {
return errorz.Errorf("f: %w", err)
}
}

return nil
Expand Down
174 changes: 149 additions & 25 deletions internal/arcgen/lang/go/generate_crud_common.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,22 @@ package arcgengo

import (
"go/ast"
"go/parser"
"go/printer"
"go/token"
"io"
"path/filepath"
"strconv"
"strings"

errorz "github.com/kunitsucom/util.go/errors"

"github.com/kunitsucom/arcgen/internal/arcgen/lang/util"
"github.com/kunitsucom/arcgen/internal/config"
)

func fprintCRUDCommon(osFile osFile, buf buffer, arcSrcSetSlice ARCSourceSetSlice) error {
content, err := generateCRUDCommonFileContent(buf, arcSrcSetSlice)
func fprintCRUDCommon(osFile osFile, buf buffer, arcSrcSetSlice ARCSourceSetSlice, crudFiles []string) error {
content, err := generateCRUDCommonFileContent(buf, arcSrcSetSlice, crudFiles)
if err != nil {
return errorz.Errorf("generateCRUDCommonFileContent: %w", err)
}
Expand All @@ -27,8 +30,13 @@ func fprintCRUDCommon(osFile osFile, buf buffer, arcSrcSetSlice ARCSourceSetSlic
return nil
}

//nolint:funlen
func generateCRUDCommonFileContent(buf buffer, _ ARCSourceSetSlice) (string, error) {
const (
sqlQueryerContextVarName = "sqlContext"
sqlQueryerContextTypeName = "sqlQueryerContext"
)

//nolint:cyclop,funlen,gocognit,maintidx
func generateCRUDCommonFileContent(buf buffer, arcSrcSetSlice ARCSourceSetSlice, crudFiles []string) (string, error) {
astFile := &ast.File{
// package
Name: &ast.Ident{
Expand All @@ -38,18 +46,19 @@ func generateCRUDCommonFileContent(buf buffer, _ ARCSourceSetSlice) (string, err
Decls: []ast.Decl{},
}

// // Since all directories are the same from arcSrcSetSlice[0].Filename to arcSrcSetSlice[len(-1)].Filename,
// // get the package path from arcSrcSetSlice[0].Filename.
// dir := filepath.Dir(arcSrcSetSlice[0].Filename)
// structPackagePath, err := util.GetPackagePath(dir)
// if err != nil {
// return "", errorz.Errorf("GetPackagePath: %w", err)
// }
// Since all directories are the same from arcSrcSetSlice[0].Filename to arcSrcSetSlice[len(-1)].Filename,
// get the package path from arcSrcSetSlice[0].Filename.
dir := filepath.Dir(arcSrcSetSlice[0].Filename)
structPackagePath, err := util.GetPackagePath(dir)
if err != nil {
return "", errorz.Errorf("GetPackagePath: %w", err)
}

astFile.Decls = append(astFile.Decls,
// import (
// "context"
// "database/sql"
// "log/slog"
//
// dao "path/to/your/dao"
// )
Expand All @@ -62,15 +71,18 @@ func generateCRUDCommonFileContent(buf buffer, _ ARCSourceSetSlice) (string, err
&ast.ImportSpec{
Path: &ast.BasicLit{Kind: token.STRING, Value: strconv.Quote("database/sql")},
},
// &ast.ImportSpec{
// Name: &ast.Ident{Name: "dao"},
// Path: &ast.BasicLit{Kind: token.STRING, Value: strconv.Quote(structPackagePath)},
// },
&ast.ImportSpec{
Path: &ast.BasicLit{Kind: token.STRING, Value: strconv.Quote("log/slog")},
},
&ast.ImportSpec{
Name: &ast.Ident{Name: "dao"},
Path: &ast.BasicLit{Kind: token.STRING, Value: strconv.Quote(structPackagePath)},
},
},
},
)

// type sqlContext interface {
// type sqlQueryerContext interface {
// QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
// QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
// ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
Expand All @@ -80,7 +92,8 @@ func generateCRUDCommonFileContent(buf buffer, _ ARCSourceSetSlice) (string, err
Tok: token.TYPE,
Specs: []ast.Spec{
&ast.TypeSpec{
Name: &ast.Ident{Name: "sqlContext"},
// Assign: token.Pos(1),
Name: &ast.Ident{Name: sqlQueryerContextTypeName},
Type: &ast.InterfaceType{
Methods: &ast.FieldList{
List: []*ast.Field{
Expand Down Expand Up @@ -133,27 +146,138 @@ func generateCRUDCommonFileContent(buf buffer, _ ARCSourceSetSlice) (string, err
},
)

// type Queryer struct {}
// type _CRUD struct {
// }
astFile.Decls = append(astFile.Decls,
&ast.GenDecl{
Tok: token.TYPE,
Specs: []ast.Spec{
&ast.TypeSpec{
Name: &ast.Ident{Name: "Queryer"},
Name: &ast.Ident{Name: config.GoCRUDTypeNameUnexported()},
Type: &ast.StructType{Fields: &ast.FieldList{}},
},
},
},
)

// func NewQueryer() *Query {
// return &Queryer{}
// }
// func LoggerFromContext(ctx context.Context) *slog.Logger {
// if ctx == nil {
// return slog.Default()
// }
// if logger, ok := ctx.Value((*slog.Logger)(nil)).(*slog.Logger); ok {
// return logger
// }
// return slog.Default()
// }
astFile.Decls = append(astFile.Decls,
&ast.FuncDecl{
Name: &ast.Ident{Name: "LoggerFromContext"},
Type: &ast.FuncType{Params: &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{{Name: "ctx"}}, Type: &ast.Ident{Name: "context.Context"}}}}, Results: &ast.FieldList{List: []*ast.Field{{Type: &ast.StarExpr{X: &ast.SelectorExpr{X: &ast.Ident{Name: "slog"}, Sel: &ast.Ident{Name: "Logger"}}}}}}},
Body: &ast.BlockStmt{
List: []ast.Stmt{
&ast.IfStmt{
Cond: &ast.BinaryExpr{X: &ast.Ident{Name: "ctx"}, Op: token.EQL, Y: &ast.Ident{Name: "nil"}},
Body: &ast.BlockStmt{List: []ast.Stmt{
&ast.ReturnStmt{Results: []ast.Expr{&ast.CallExpr{Fun: &ast.SelectorExpr{X: &ast.Ident{Name: "slog"}, Sel: &ast.Ident{Name: "Default"}}}}},
}},
},
&ast.IfStmt{
// if logger, ok := ctx.Value((*slog.Logger)(nil)).(*slog.Logger); ok {
Init: &ast.AssignStmt{
Lhs: []ast.Expr{&ast.Ident{Name: "logger"}, &ast.Ident{Name: "ok"}},
Tok: token.DEFINE,
Rhs: []ast.Expr{
&ast.TypeAssertExpr{
X: &ast.CallExpr{
Fun: &ast.Ident{Name: "ctx.Value"},
Args: []ast.Expr{&ast.CallExpr{Fun: &ast.ParenExpr{X: &ast.StarExpr{X: &ast.SelectorExpr{X: &ast.Ident{Name: "slog"}, Sel: &ast.Ident{Name: "Logger"}}}}, Args: []ast.Expr{&ast.Ident{Name: "nil"}}}},
},
Type: &ast.StarExpr{X: &ast.SelectorExpr{X: &ast.Ident{Name: "slog"}, Sel: &ast.Ident{Name: "Logger"}}},
},
},
},
Cond: &ast.Ident{Name: "ok"},
Body: &ast.BlockStmt{List: []ast.Stmt{&ast.ReturnStmt{Results: []ast.Expr{&ast.Ident{Name: "logger"}}}}},
},
&ast.ReturnStmt{Results: []ast.Expr{&ast.CallExpr{Fun: &ast.SelectorExpr{X: &ast.Ident{Name: "slog"}, Sel: &ast.Ident{Name: "Default"}}}}},
},
},
},
)

// func LoggerWithContext(ctx context.Context, logger *slog.Logger) context.Context {
// return context.WithValue(ctx, (*slog.Logger)(nil), logger)
// }
astFile.Decls = append(astFile.Decls,
&ast.FuncDecl{
Name: &ast.Ident{Name: "LoggerWithContext"},
Type: &ast.FuncType{Params: &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{{Name: "ctx"}}, Type: &ast.Ident{Name: "context.Context"}}, {Names: []*ast.Ident{{Name: "logger"}}, Type: &ast.StarExpr{X: &ast.SelectorExpr{X: &ast.Ident{Name: "slog"}, Sel: &ast.Ident{Name: "Logger"}}}}}}, Results: &ast.FieldList{List: []*ast.Field{{Type: &ast.Ident{Name: "context.Context"}}}}},
Body: &ast.BlockStmt{
List: []ast.Stmt{
&ast.ReturnStmt{Results: []ast.Expr{&ast.CallExpr{Fun: &ast.SelectorExpr{X: &ast.Ident{Name: "context"}, Sel: &ast.Ident{Name: "WithValue"}}, Args: []ast.Expr{&ast.Ident{Name: "ctx"}, &ast.CallExpr{Fun: &ast.ParenExpr{X: &ast.StarExpr{X: &ast.SelectorExpr{X: &ast.Ident{Name: "slog"}, Sel: &ast.Ident{Name: "Logger"}}}}, Args: []ast.Expr{&ast.Ident{Name: "nil"}}}, &ast.Ident{Name: "logger"}}}}},
},
},
},
)

// type CRUD interface {
// Create{StructName}(ctx context.Context, sqlQueryer sqlQueryerContext, s *{Struct}) error
// ...
// }
methods := make([]*ast.Field, 0)
fset := token.NewFileSet()
for _, crudFile := range crudFiles {
rootNode, err := parser.ParseFile(fset, crudFile, nil, parser.ParseComments)
if err != nil {
// MEMO: parser.ParseFile err contains file path, so no need to log it
return "", errorz.Errorf("parser.ParseFile: %w", err)
}

// MEMO: Inspect is used to get the method declaration from the file
ast.Inspect(rootNode, func(node ast.Node) bool {
switch n := node.(type) {
case *ast.FuncDecl:
//nolint:nestif
if n.Recv != nil && len(n.Recv.List) > 0 {
if t, ok := n.Recv.List[0].Type.(*ast.StarExpr); ok {
if ident, ok := t.X.(*ast.Ident); ok {
if ident.Name == config.GoCRUDTypeNameUnexported() {
methods = append(methods, &ast.Field{
Names: []*ast.Ident{{Name: n.Name.Name}},
Type: n.Type,
})
}
}
}
}
default:
// noop
}
return true
})
}
astFile.Decls = append(astFile.Decls,
&ast.GenDecl{
Tok: token.TYPE,
Specs: []ast.Spec{
&ast.TypeSpec{
Name: &ast.Ident{Name: config.GoCRUDTypeName()},
Type: &ast.InterfaceType{
Methods: &ast.FieldList{List: methods},
},
},
},
},
)

// func NewCRUD() CRUD {
// return &_CRUD{}
// }
astFile.Decls = append(astFile.Decls,
&ast.FuncDecl{
Name: &ast.Ident{Name: "NewQueryer"},
Type: &ast.FuncType{Results: &ast.FieldList{List: []*ast.Field{{Type: &ast.StarExpr{X: &ast.Ident{Name: "Queryer"}}}}}},
Body: &ast.BlockStmt{List: []ast.Stmt{&ast.ReturnStmt{Results: []ast.Expr{&ast.UnaryExpr{Op: token.AND, X: &ast.Ident{Name: "Queryer{}"}}}}}},
Name: &ast.Ident{Name: "New" + config.GoCRUDTypeName()},
Type: &ast.FuncType{Results: &ast.FieldList{List: []*ast.Field{{Type: &ast.Ident{Name: config.GoCRUDTypeName()}}}}},
Body: &ast.BlockStmt{List: []ast.Stmt{&ast.ReturnStmt{Results: []ast.Expr{&ast.UnaryExpr{Op: token.AND, X: &ast.Ident{Name: config.GoCRUDTypeNameUnexported() + "{}"}}}}}},
},
)

Expand Down
37 changes: 25 additions & 12 deletions internal/arcgen/lang/go/generate_crud_create.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (
"go/token"
"strconv"
"strings"

"github.com/kunitsucom/arcgen/internal/config"
)

//nolint:funlen
Expand All @@ -13,13 +15,14 @@ func generateCREATEContent(astFile *ast.File, arcSrcSet *ARCSourceSet) {
structName := arcSrc.extractStructName()
tableName := arcSrc.extractTableNameFromCommentGroup()
tableInfo := arcSrc.extractFieldNamesAndColumnNames()
columnNames := tableInfo.ColumnNames()
columnNames := tableInfo.Columns.ColumnNames()

// const Create{StructName}Query = `INSERT INTO {table_name} ({column_name1}, {column_name2}) VALUES (?, ?)`
// const Create{StructName}Query = `INSERT INTO {table_name} ({column_name1}, {column_name2}) VALUES ($1, $2)`
//
// func (q *query) Create{StructName}(ctx context.Context, queryer sqlContext, s *{Struct}) error {
// if _, err := queryer.ExecContext(ctx, Create{StructName}Query, s.{ColumnName1}, s.{ColumnName2}); err != nil {
// return fmt.Errorf("q.queryer.ExecContext: %w", err)
// func (q *query) Create{StructName}(ctx context.Context, queryer sqlQueryerContext, s *{Struct}) error {
// LoggerFromContext(ctx).Debug(Create{StructName}Query)
// if _, err := sqlContext.ExecContext(ctx, Create{StructName}Query, s.{ColumnName1}, s.{ColumnName2}); err != nil {
// return fmt.Errorf("sqlContext.ExecContext: %w", err)
// }
// return nil
// }
Expand All @@ -33,18 +36,18 @@ func generateCREATEContent(astFile *ast.File, arcSrcSet *ARCSourceSet) {
Names: []*ast.Ident{{Name: queryName}},
Values: []ast.Expr{&ast.BasicLit{
Kind: token.STRING,
Value: "`INSERT INTO " + tableName + " (" + strings.Join(columnNames, ", ") + ") VALUES (?" + strings.Repeat(", ?", len(columnNames)-1) + ")`",
Value: "`INSERT INTO " + tableName + " (" + strings.Join(columnNames, ", ") + ") VALUES (" + columnValuesPlaceholder(columnNames) + ")`",
}},
},
},
},
&ast.FuncDecl{
Recv: &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{{Name: "q"}}, Type: &ast.StarExpr{X: &ast.Ident{Name: "Queryer"}}}}},
Recv: &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{{Name: "q"}}, Type: &ast.StarExpr{X: &ast.Ident{Name: config.GoCRUDTypeNameUnexported()}}}}},
Name: &ast.Ident{Name: funcName},
Type: &ast.FuncType{
Params: &ast.FieldList{List: []*ast.Field{
{Names: []*ast.Ident{{Name: "ctx"}}, Type: &ast.Ident{Name: "context.Context"}},
{Names: []*ast.Ident{{Name: "sqlCtx"}}, Type: &ast.Ident{Name: "sqlContext"}},
{Names: []*ast.Ident{{Name: sqlQueryerContextVarName}}, Type: &ast.Ident{Name: sqlQueryerContextTypeName}},
{Names: []*ast.Ident{{Name: "s"}}, Type: &ast.StarExpr{X: &ast.Ident{Name: "dao." + structName}}},
}},
Results: &ast.FieldList{List: []*ast.Field{
Expand All @@ -53,14 +56,24 @@ func generateCREATEContent(astFile *ast.File, arcSrcSet *ARCSourceSet) {
},
Body: &ast.BlockStmt{
List: []ast.Stmt{
&ast.ExprStmt{
// LoggerFromContext(ctx).Debug(queryName)
X: &ast.CallExpr{
Fun: &ast.SelectorExpr{
X: &ast.CallExpr{Fun: &ast.Ident{Name: "LoggerFromContext"}, Args: []ast.Expr{&ast.Ident{Name: "ctx"}}},
Sel: &ast.Ident{Name: "Debug"},
},
Args: []ast.Expr{&ast.Ident{Name: queryName}},
},
},
&ast.IfStmt{
// if _, err := queryer.ExecContext(ctx, Create{StructName}Query, s.{ColumnName1}, s.{ColumnName2}); err != nil {
// if _, err := sqlQueryer.ExecContext(ctx, Create{StructName}Query, s.{ColumnName1}, s.{ColumnName2}); err != nil {
Init: &ast.AssignStmt{
Lhs: []ast.Expr{&ast.Ident{Name: "_"}, &ast.Ident{Name: "err"}},
Tok: token.DEFINE,
Rhs: []ast.Expr{&ast.CallExpr{
Fun: &ast.SelectorExpr{
X: &ast.Ident{Name: "sqlCtx"},
X: &ast.Ident{Name: sqlQueryerContextVarName},
Sel: &ast.Ident{Name: "ExecContext"},
},
Args: append(
Expand All @@ -80,10 +93,10 @@ func generateCREATEContent(astFile *ast.File, arcSrcSet *ARCSourceSet) {
// err != nil {
Cond: &ast.BinaryExpr{X: &ast.Ident{Name: "err"}, Op: token.NEQ, Y: &ast.Ident{Name: "nil"}},
Body: &ast.BlockStmt{List: []ast.Stmt{
// return fmt.Errorf("queryer.ExecContext: %w", err)
// return fmt.Errorf("sqlContext.ExecContext: %w", err)
&ast.ReturnStmt{Results: []ast.Expr{&ast.CallExpr{
Fun: &ast.SelectorExpr{X: &ast.Ident{Name: "fmt"}, Sel: &ast.Ident{Name: "Errorf"}},
Args: []ast.Expr{&ast.Ident{Name: strconv.Quote("queryer.ExecContext: %w")}, &ast.Ident{Name: "err"}},
Args: []ast.Expr{&ast.Ident{Name: strconv.Quote(sqlQueryerContextVarName + ".ExecContext: %w")}, &ast.Ident{Name: "err"}},
}}},
}},
},
Expand Down
Loading
Loading