Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/auto increment add rest handlers #930

Merged
merged 11 commits into from
Aug 21, 2024
146 changes: 121 additions & 25 deletions pkg/gofr/crud_handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,12 @@ import (
)

var (
errInvalidObject = errors.New("unexpected object given for AddRESTHandlers")
errEntityNotFound = errors.New("entity not found")
errObjectIsNil = errors.New("object given for AddRESTHandlers is nil")
errNonPointerObject = errors.New("passed object is not pointer")
errInvalidObject = errors.New("unexpected object given for AddRESTHandlers")
errEntityNotFound = errors.New("entity not found")
errObjectIsNil = errors.New("object given for AddRESTHandlers is nil")
errNonPointerObject = errors.New("passed object is not pointer")
errFieldCannotBeNull = errors.New("field cannot be null")
errInvalidSQLTag = errors.New("invalid sql tag")
)

type Create interface {
Expand Down Expand Up @@ -54,11 +56,12 @@ type CRUD interface {

// entity stores information about an entity.
type entity struct {
name string
entityType reflect.Type
primaryKey string
tableName string
restPath string
name string
entityType reflect.Type
primaryKey string
tableName string
restPath string
constraints map[string]sql.FieldConstraints
}

// scanEntity extracts entity information for CRUD operations.
Expand Down Expand Up @@ -86,13 +89,52 @@ func scanEntity(object interface{}) (*entity, error) {
tableName := getTableName(object, structName)
restPath := getRestPath(object, structName)

return &entity{
name: structName,
entityType: entityType,
primaryKey: primaryKeyFieldName,
tableName: tableName,
restPath: restPath,
}, nil
e := &entity{
name: structName,
entityType: entityType,
primaryKey: primaryKeyFieldName,
tableName: tableName,
restPath: restPath,
constraints: make(map[string]sql.FieldConstraints),
}

for i := 0; i < entityType.NumField(); i++ {
field := entityType.Field(i)
fieldName := toSnakeCase(field.Name)

sqlTag := field.Tag.Get("sql")
if sqlTag != "" {
constraints, err := parseSQLTag(sqlTag)
if err != nil {
return nil, err
}

e.constraints[fieldName] = constraints
}
}

return e, nil
}

func parseSQLTag(tag string) (sql.FieldConstraints, error) {
var constraints sql.FieldConstraints

tags := strings.Split(tag, ",")
Umang01-hash marked this conversation as resolved.
Show resolved Hide resolved

for _, t := range tags {
Umang01-hash marked this conversation as resolved.
Show resolved Hide resolved
t = strings.ToLower(t) // Convert to lowercase for case-insensitivity

switch t {
case "auto_increment":
constraints.AutoIncrement = true
case "not_null":
constraints.NotNull = true
default:
return constraints, fmt.Errorf("%w: %s", errInvalidSQLTag, t)
}
}

return constraints, nil
}

func getTableName(object any, structName string) string {
Expand Down Expand Up @@ -148,30 +190,84 @@ func (a *App) registerCRUDHandlers(e *entity, object interface{}) {
}

func (e *entity) Create(c *Context) (interface{}, error) {
newEntity, err := e.bindAndValidateEntity(c)
if err != nil {
return nil, err
}

fieldNames, fieldValues := e.extractFields(newEntity)

stmt, err := sql.InsertQuery(c.SQL.Dialect(), e.tableName, fieldNames, fieldValues, e.constraints)
if err != nil {
return nil, err
}

result, err := c.SQL.ExecContext(c, stmt, fieldValues...)
if err != nil {
return nil, err
}

var lastID interface{}

if hasAutoIncrementID(e.constraints) { // Check for auto-increment ID
lastID, err = result.LastInsertId()
if err != nil {
return nil, err
}
} else {
lastID = fieldValues[0]
}

return fmt.Sprintf("%s successfully created with id: %v", e.name, lastID), nil
aryanmehrotra marked this conversation as resolved.
Show resolved Hide resolved
}

func (e *entity) bindAndValidateEntity(c *Context) (interface{}, error) {
newEntity := reflect.New(e.entityType).Interface()
err := c.Bind(newEntity)

err := c.Bind(newEntity)
if err != nil {
return nil, err
}

fieldNames := make([]string, 0, e.entityType.NumField())
fieldValues := make([]interface{}, 0, e.entityType.NumField())
for i := 0; i < e.entityType.NumField(); i++ {
field := e.entityType.Field(i)
fieldName := toSnakeCase(field.Name)

if e.constraints[fieldName].NotNull && reflect.ValueOf(newEntity).Elem().Field(i).Interface() == nil {
return nil, fmt.Errorf("%w: %s", errFieldCannotBeNull, fieldName)
}
}

return newEntity, nil
}

func (e *entity) extractFields(newEntity interface{}) (fieldNames []string, fieldValues []interface{}) {
fieldNames = make([]string, 0, e.entityType.NumField())
fieldValues = make([]interface{}, 0, e.entityType.NumField())
Umang01-hash marked this conversation as resolved.
Show resolved Hide resolved

for i := 0; i < e.entityType.NumField(); i++ {
field := e.entityType.Field(i)
fieldNames = append(fieldNames, toSnakeCase(field.Name))
fieldName := toSnakeCase(field.Name)

if e.constraints[fieldName].AutoIncrement {
continue // Skip auto-increment fields for insertion
}

fieldNames = append(fieldNames, fieldName)
fieldValues = append(fieldValues, reflect.ValueOf(newEntity).Elem().Field(i).Interface())
}

stmt := sql.InsertQuery(c.SQL.Dialect(), e.tableName, fieldNames)
return fieldNames, fieldValues
}

_, err = c.SQL.ExecContext(c, stmt, fieldValues...)
if err != nil {
return nil, err
func hasAutoIncrementID(constraints map[string]sql.FieldConstraints) bool {
for _, constraint := range constraints {
if constraint.AutoIncrement {
return true
}
}

return fmt.Sprintf("%s successfully created with id: %d", e.name, fieldValues[0]), nil
return false
}

func (e *entity) GetAll(c *Context) (interface{}, error) {
Expand Down
18 changes: 11 additions & 7 deletions pkg/gofr/crud_handlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ func Test_scanEntity(t *testing.T) {
var invalidObject int

type userTestEntity struct {
ID int
Name string
ID int `sql:"auto_increment"`
Name string `sql:"not_null"`
}

tests := []struct {
Expand All @@ -77,18 +77,22 @@ func Test_scanEntity(t *testing.T) {
primaryKey: "id",
tableName: "user_test_entity",
restPath: "userTestEntity",
constraints: map[string]gofrSql.FieldConstraints{"id": {AutoIncrement: true, NotNull: false},
"name": {AutoIncrement: false, NotNull: true},
},
},
err: nil,
},
{
desc: "success case (custom)",
input: &userEntity{},
resp: &entity{
name: "userEntity",
entityType: reflect.TypeOf(userEntity{}),
primaryKey: "id",
tableName: "user",
restPath: "users",
name: "userEntity",
entityType: reflect.TypeOf(userEntity{}),
primaryKey: "id",
tableName: "user",
restPath: "users",
constraints: map[string]gofrSql.FieldConstraints{},
},
err: nil,
},
Expand Down
87 changes: 80 additions & 7 deletions pkg/gofr/datasource/sql/query_builder.go
Original file line number Diff line number Diff line change
@@ -1,27 +1,51 @@
package sql

import (
"errors"
"fmt"

"reflect"
"strings"
Umang01-hash marked this conversation as resolved.
Show resolved Hide resolved
)

func InsertQuery(dialect, tableName string, fieldNames []string) string {
fieldNamesLength := len(fieldNames)
var (
errFieldCannotBeEmpty = errors.New("field cannot be empty")
errFieldCannotBeZero = errors.New("field cannot be zero")
errFieldCannotBeNull = errors.New("field cannot be null")
)

type FieldConstraints struct {
AutoIncrement bool
NotNull bool
}

func InsertQuery(dialect, tableName string, fieldNames []string, values []interface{},
constraints map[string]FieldConstraints) (string, error) {
bindVars := make([]string, 0, len(fieldNames))
columns := make([]string, 0, len(fieldNames))

for i, fieldName := range fieldNames {
if constraints[fieldName].AutoIncrement {
continue
}

if err := validateNotNull(fieldName, values[i], constraints[fieldName].NotNull); err != nil {
return "", err
}

var bindVars []string
for i := 1; i <= fieldNamesLength; i++ {
bindVars = append(bindVars, bindVar(dialect, i))
bindVars = append(bindVars, bindVar(dialect, i+1))
columns = append(columns, quotedString(quote(dialect), fieldName))
}

q := quote(dialect)

stmt := fmt.Sprintf(`INSERT INTO %s (%s) VALUES (%s)`,
quotedString(q, tableName),
quotedString(q, strings.Join(fieldNames, quotedString(q, ", "))),
strings.Join(columns, ", "),
strings.Join(bindVars, ", "),
)

return stmt
return stmt, nil
}

func SelectQuery(dialect, tableName string) string {
Expand Down Expand Up @@ -64,3 +88,52 @@ func DeleteByQuery(dialect, tableName, field string) string {
quotedString(q, field),
bindVar(dialect, 1))
}

func validateNotNull(fieldName string, value interface{}, isNotNull bool) error {
if !isNotNull {
return nil
}

switch v := value.(type) {
case string:
return validateStringNotNull(fieldName, v)
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
return validateIntNotNull(fieldName, v)
case float32, float64:
return validateFloatNotNull(fieldName, v)
default:
return validateDefaultNotNull(fieldName, value)
}
}

func validateStringNotNull(fieldName, value string) error {
if value == "" {
return fmt.Errorf("%w: %s", errFieldCannotBeEmpty, fieldName)
}

return nil
}

func validateIntNotNull(fieldName string, value interface{}) error {
if reflect.ValueOf(value).Int() == 0 {
return fmt.Errorf("%w: %s", errFieldCannotBeZero, fieldName)
}

return nil
}

func validateFloatNotNull(fieldName string, value interface{}) error {
if reflect.ValueOf(value).Float() == 0.0 {
return fmt.Errorf("%w: %s", errFieldCannotBeZero, fieldName)
}

return nil
}

func validateDefaultNotNull(fieldName string, value interface{}) error {
if reflect.ValueOf(value).IsNil() {
return fmt.Errorf("%w: %s", errFieldCannotBeNull, fieldName)
}

return nil
}
Loading
Loading