Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/auto increment add rest handlers #930

Merged
merged 11 commits into from
Aug 21, 2024
15 changes: 15 additions & 0 deletions docs/quick-start/add-rest-handlers/page.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,21 @@ func (u *userEntity) TableName() string {
}
```

## Adding Database Constraints
By default, GoFr assumes to have manual insertion of id for a given struct, but to support sql constraints like `auto-increment`,
`not-null` user can use the `sql` tag while declaring the struct fields.

```go
type user struct {
ID int `json:"id" sql:"auto_increment"`
Name string `json:"name" sql:"not_null"`
Age int `json:"age"`
IsEmployed bool `json:"isEmployed"`
}
```

Now when posting data for the user struct, the `Id` we be auto-incremented and the `Name` will be a not-null field in table.

## Benefits of Adding REST Handlers of GoFr

1. Reduced Boilerplate Code: Eliminate repetitive code for CRUD operations, freeing user to focus on core application logic.
Expand Down
145 changes: 84 additions & 61 deletions pkg/gofr/crud_handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,17 @@ import (
"errors"
"fmt"
"reflect"
"strings"

"gofr.dev/pkg/gofr/datasource/sql"
)

var (
errInvalidObject = errors.New("unexpected object given for AddRESTHandlers")
errEntityNotFound = errors.New("entity not found")
errObjectIsNil = errors.New("object given for AddRESTHandlers is nil")
errNonPointerObject = errors.New("passed object is not pointer")
errInvalidObject = errors.New("unexpected object given for AddRESTHandlers")
errEntityNotFound = errors.New("entity not found")
errObjectIsNil = errors.New("object given for AddRESTHandlers is nil")
errNonPointerObject = errors.New("passed object is not pointer")
errFieldCannotBeNull = errors.New("field cannot be null")
errInvalidSQLTag = errors.New("invalid sql tag")
)

type Create interface {
Expand Down Expand Up @@ -54,11 +55,12 @@ type CRUD interface {

// entity stores information about an entity.
type entity struct {
name string
entityType reflect.Type
primaryKey string
tableName string
restPath string
name string
entityType reflect.Type
primaryKey string
tableName string
restPath string
constraints map[string]sql.FieldConstraints
}

// scanEntity extracts entity information for CRUD operations.
Expand Down Expand Up @@ -86,29 +88,28 @@ func scanEntity(object interface{}) (*entity, error) {
tableName := getTableName(object, structName)
restPath := getRestPath(object, structName)

return &entity{
name: structName,
entityType: entityType,
primaryKey: primaryKeyFieldName,
tableName: tableName,
restPath: restPath,
}, nil
}

func getTableName(object any, structName string) string {
if v, ok := object.(TableNameOverrider); ok {
return v.TableName()
e := &entity{
name: structName,
entityType: entityType,
primaryKey: primaryKeyFieldName,
tableName: tableName,
restPath: restPath,
constraints: make(map[string]sql.FieldConstraints),
}

return toSnakeCase(structName)
}
for i := 0; i < entityType.NumField(); i++ {
field := entityType.Field(i)
fieldName := toSnakeCase(field.Name)

constraints, err := parseSQLTag(field.Tag)
if err != nil {
return nil, err
}

func getRestPath(object any, structName string) string {
if v, ok := object.(RestPathOverrider); ok {
return v.RestPath()
e.constraints[fieldName] = constraints
}

return structName
return e, nil
}

// registerCRUDHandlers registers CRUD handlers for an entity.
Expand Down Expand Up @@ -148,30 +149,74 @@ func (a *App) registerCRUDHandlers(e *entity, object interface{}) {
}

func (e *entity) Create(c *Context) (interface{}, error) {
newEntity := reflect.New(e.entityType).Interface()
err := c.Bind(newEntity)
newEntity, err := e.bindAndValidateEntity(c)
if err != nil {
return nil, err
}

fieldNames, fieldValues := e.extractFields(newEntity)

stmt, err := sql.InsertQuery(c.SQL.Dialect(), e.tableName, fieldNames, fieldValues, e.constraints)
if err != nil {
return nil, err
}

fieldNames := make([]string, 0, e.entityType.NumField())
fieldValues := make([]interface{}, 0, e.entityType.NumField())
result, err := c.SQL.ExecContext(c, stmt, fieldValues...)
if err != nil {
return nil, err
}

for i := 0; i < e.entityType.NumField(); i++ {
field := e.entityType.Field(i)
fieldNames = append(fieldNames, toSnakeCase(field.Name))
fieldValues = append(fieldValues, reflect.ValueOf(newEntity).Elem().Field(i).Interface())
var lastID interface{}

if hasAutoIncrementID(e.constraints) { // Check for auto-increment ID
lastID, err = result.LastInsertId()
if err != nil {
return nil, err
}
} else {
lastID = fieldValues[0]
}

stmt := sql.InsertQuery(c.SQL.Dialect(), e.tableName, fieldNames)
return fmt.Sprintf("%s successfully created with id: %v", e.name, lastID), nil
aryanmehrotra marked this conversation as resolved.
Show resolved Hide resolved
}

func (e *entity) bindAndValidateEntity(c *Context) (interface{}, error) {
newEntity := reflect.New(e.entityType).Interface()

_, err = c.SQL.ExecContext(c, stmt, fieldValues...)
err := c.Bind(newEntity)
if err != nil {
return nil, err
}

return fmt.Sprintf("%s successfully created with id: %d", e.name, fieldValues[0]), nil
for i := 0; i < e.entityType.NumField(); i++ {
field := e.entityType.Field(i)
fieldName := toSnakeCase(field.Name)

if e.constraints[fieldName].NotNull && reflect.ValueOf(newEntity).Elem().Field(i).Interface() == nil {
return nil, fmt.Errorf("%w: %s", errFieldCannotBeNull, fieldName)
}
}

return newEntity, nil
}

func (e *entity) extractFields(newEntity any) (fieldNames []string, fieldValues []any) {
fieldNames = make([]string, 0, e.entityType.NumField())
fieldValues = make([]any, 0, e.entityType.NumField())

for i := 0; i < e.entityType.NumField(); i++ {
field := e.entityType.Field(i)
fieldName := toSnakeCase(field.Name)

if e.constraints[fieldName].AutoIncrement {
continue // Skip auto-increment fields for insertion
}

fieldNames = append(fieldNames, fieldName)
fieldValues = append(fieldValues, reflect.ValueOf(newEntity).Elem().Field(i).Interface())
}

return fieldNames, fieldValues
}

func (e *entity) GetAll(c *Context) (interface{}, error) {
Expand Down Expand Up @@ -287,25 +332,3 @@ func (e *entity) Delete(c *Context) (interface{}, error) {

return fmt.Sprintf("%s successfully deleted with id: %v", e.name, id), nil
}

func toSnakeCase(str string) string {
diff := 'a' - 'A'
length := len(str)

var builder strings.Builder

for i, char := range str {
if char >= 'a' {
builder.WriteRune(char)
continue
}

if (i != 0 || i == length-1) && ((i > 0 && rune(str[i-1]) >= 'a') || (i < length-1 && rune(str[i+1]) >= 'a')) {
builder.WriteRune('_')
}

builder.WriteRune(char + diff)
}

return builder.String()
}
9 changes: 7 additions & 2 deletions pkg/gofr/crud_handlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ func Test_scanEntity(t *testing.T) {
var invalidObject int

type userTestEntity struct {
ID int
Name string
ID int `sql:"auto_increment"`
Name string `sql:"not_null"`
}

tests := []struct {
Expand All @@ -77,6 +77,9 @@ func Test_scanEntity(t *testing.T) {
primaryKey: "id",
tableName: "user_test_entity",
restPath: "userTestEntity",
constraints: map[string]gofrSql.FieldConstraints{"id": {AutoIncrement: true, NotNull: false},
"name": {AutoIncrement: false, NotNull: true},
},
},
err: nil,
},
Expand All @@ -89,6 +92,8 @@ func Test_scanEntity(t *testing.T) {
primaryKey: "id",
tableName: "user",
restPath: "users",
constraints: map[string]gofrSql.FieldConstraints{"id": {AutoIncrement: false, NotNull: false},
"is_employed": {AutoIncrement: false, NotNull: false}, "name": {AutoIncrement: false, NotNull: false}},
},
err: nil,
},
Expand Down
83 changes: 83 additions & 0 deletions pkg/gofr/crud_helpers.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package gofr

import (
"fmt"
"reflect"
"strings"

"gofr.dev/pkg/gofr/datasource/sql"
)

func getTableName(object any, structName string) string {
if v, ok := object.(TableNameOverrider); ok {
return v.TableName()
}

return toSnakeCase(structName)
}

func getRestPath(object any, structName string) string {
if v, ok := object.(RestPathOverrider); ok {
return v.RestPath()
}

return structName
}

func hasAutoIncrementID(constraints map[string]sql.FieldConstraints) bool {
for _, constraint := range constraints {
if constraint.AutoIncrement {
return true
}
}

return false
}

func parseSQLTag(inputTags reflect.StructTag) (sql.FieldConstraints, error) {
var constraints sql.FieldConstraints

sqlTag := inputTags.Get("sql")
if sqlTag == "" {
return constraints, nil
}

tags := strings.Split(sqlTag, ",")

for _, tag := range tags {
tag = strings.ToLower(tag) // Convert to lowercase for case-insensitivity

switch tag {
case "auto_increment":
constraints.AutoIncrement = true
case "not_null":
constraints.NotNull = true
default:
return constraints, fmt.Errorf("%w: %s", errInvalidSQLTag, tag)
}
}

return constraints, nil
}

func toSnakeCase(str string) string {
diff := 'a' - 'A'
length := len(str)

var builder strings.Builder

for i, char := range str {
if char >= 'a' {
builder.WriteRune(char)
continue
}

if (i != 0 || i == length-1) && ((i > 0 && rune(str[i-1]) >= 'a') || (i < length-1 && rune(str[i+1]) >= 'a')) {
builder.WriteRune('_')
}

builder.WriteRune(char + diff)
}

return builder.String()
}
Loading
Loading