Skip to content

Commit

Permalink
Add support for auto increment and not null in add rest handlers from…
Browse files Browse the repository at this point in the history
… struct definition (#930)
  • Loading branch information
Umang01-hash authored Aug 21, 2024
1 parent 290d4e8 commit b98a7a3
Show file tree
Hide file tree
Showing 6 changed files with 395 additions and 83 deletions.
15 changes: 15 additions & 0 deletions docs/quick-start/add-rest-handlers/page.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,21 @@ func (u *userEntity) TableName() string {
}
```

## Adding Database Constraints
By default, GoFr assumes to have manual insertion of id for a given struct, but to support sql constraints like `auto-increment`,
`not-null` user can use the `sql` tag while declaring the struct fields.

```go
type user struct {
ID int `json:"id" sql:"auto_increment"`
Name string `json:"name" sql:"not_null"`
Age int `json:"age"`
IsEmployed bool `json:"isEmployed"`
}
```

Now when posting data for the user struct, the `Id` we be auto-incremented and the `Name` will be a not-null field in table.

## Benefits of Adding REST Handlers of GoFr

1. Reduced Boilerplate Code: Eliminate repetitive code for CRUD operations, freeing user to focus on core application logic.
Expand Down
145 changes: 84 additions & 61 deletions pkg/gofr/crud_handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,17 @@ import (
"errors"
"fmt"
"reflect"
"strings"

"gofr.dev/pkg/gofr/datasource/sql"
)

var (
errInvalidObject = errors.New("unexpected object given for AddRESTHandlers")
errEntityNotFound = errors.New("entity not found")
errObjectIsNil = errors.New("object given for AddRESTHandlers is nil")
errNonPointerObject = errors.New("passed object is not pointer")
errInvalidObject = errors.New("unexpected object given for AddRESTHandlers")
errEntityNotFound = errors.New("entity not found")
errObjectIsNil = errors.New("object given for AddRESTHandlers is nil")
errNonPointerObject = errors.New("passed object is not pointer")
errFieldCannotBeNull = errors.New("field cannot be null")
errInvalidSQLTag = errors.New("invalid sql tag")
)

type Create interface {
Expand Down Expand Up @@ -54,11 +55,12 @@ type CRUD interface {

// entity stores information about an entity.
type entity struct {
name string
entityType reflect.Type
primaryKey string
tableName string
restPath string
name string
entityType reflect.Type
primaryKey string
tableName string
restPath string
constraints map[string]sql.FieldConstraints
}

// scanEntity extracts entity information for CRUD operations.
Expand Down Expand Up @@ -86,29 +88,28 @@ func scanEntity(object interface{}) (*entity, error) {
tableName := getTableName(object, structName)
restPath := getRestPath(object, structName)

return &entity{
name: structName,
entityType: entityType,
primaryKey: primaryKeyFieldName,
tableName: tableName,
restPath: restPath,
}, nil
}

func getTableName(object any, structName string) string {
if v, ok := object.(TableNameOverrider); ok {
return v.TableName()
e := &entity{
name: structName,
entityType: entityType,
primaryKey: primaryKeyFieldName,
tableName: tableName,
restPath: restPath,
constraints: make(map[string]sql.FieldConstraints),
}

return toSnakeCase(structName)
}
for i := 0; i < entityType.NumField(); i++ {
field := entityType.Field(i)
fieldName := toSnakeCase(field.Name)

constraints, err := parseSQLTag(field.Tag)
if err != nil {
return nil, err
}

func getRestPath(object any, structName string) string {
if v, ok := object.(RestPathOverrider); ok {
return v.RestPath()
e.constraints[fieldName] = constraints
}

return structName
return e, nil
}

// registerCRUDHandlers registers CRUD handlers for an entity.
Expand Down Expand Up @@ -148,30 +149,74 @@ func (a *App) registerCRUDHandlers(e *entity, object interface{}) {
}

func (e *entity) Create(c *Context) (interface{}, error) {
newEntity := reflect.New(e.entityType).Interface()
err := c.Bind(newEntity)
newEntity, err := e.bindAndValidateEntity(c)
if err != nil {
return nil, err
}

fieldNames, fieldValues := e.extractFields(newEntity)

stmt, err := sql.InsertQuery(c.SQL.Dialect(), e.tableName, fieldNames, fieldValues, e.constraints)
if err != nil {
return nil, err
}

fieldNames := make([]string, 0, e.entityType.NumField())
fieldValues := make([]interface{}, 0, e.entityType.NumField())
result, err := c.SQL.ExecContext(c, stmt, fieldValues...)
if err != nil {
return nil, err
}

for i := 0; i < e.entityType.NumField(); i++ {
field := e.entityType.Field(i)
fieldNames = append(fieldNames, toSnakeCase(field.Name))
fieldValues = append(fieldValues, reflect.ValueOf(newEntity).Elem().Field(i).Interface())
var lastID interface{}

if hasAutoIncrementID(e.constraints) { // Check for auto-increment ID
lastID, err = result.LastInsertId()
if err != nil {
return nil, err
}
} else {
lastID = fieldValues[0]
}

stmt := sql.InsertQuery(c.SQL.Dialect(), e.tableName, fieldNames)
return fmt.Sprintf("%s successfully created with id: %v", e.name, lastID), nil
}

func (e *entity) bindAndValidateEntity(c *Context) (interface{}, error) {
newEntity := reflect.New(e.entityType).Interface()

_, err = c.SQL.ExecContext(c, stmt, fieldValues...)
err := c.Bind(newEntity)
if err != nil {
return nil, err
}

return fmt.Sprintf("%s successfully created with id: %d", e.name, fieldValues[0]), nil
for i := 0; i < e.entityType.NumField(); i++ {
field := e.entityType.Field(i)
fieldName := toSnakeCase(field.Name)

if e.constraints[fieldName].NotNull && reflect.ValueOf(newEntity).Elem().Field(i).Interface() == nil {
return nil, fmt.Errorf("%w: %s", errFieldCannotBeNull, fieldName)
}
}

return newEntity, nil
}

func (e *entity) extractFields(newEntity any) (fieldNames []string, fieldValues []any) {
fieldNames = make([]string, 0, e.entityType.NumField())
fieldValues = make([]any, 0, e.entityType.NumField())

for i := 0; i < e.entityType.NumField(); i++ {
field := e.entityType.Field(i)
fieldName := toSnakeCase(field.Name)

if e.constraints[fieldName].AutoIncrement {
continue // Skip auto-increment fields for insertion
}

fieldNames = append(fieldNames, fieldName)
fieldValues = append(fieldValues, reflect.ValueOf(newEntity).Elem().Field(i).Interface())
}

return fieldNames, fieldValues
}

func (e *entity) GetAll(c *Context) (interface{}, error) {
Expand Down Expand Up @@ -287,25 +332,3 @@ func (e *entity) Delete(c *Context) (interface{}, error) {

return fmt.Sprintf("%s successfully deleted with id: %v", e.name, id), nil
}

func toSnakeCase(str string) string {
diff := 'a' - 'A'
length := len(str)

var builder strings.Builder

for i, char := range str {
if char >= 'a' {
builder.WriteRune(char)
continue
}

if (i != 0 || i == length-1) && ((i > 0 && rune(str[i-1]) >= 'a') || (i < length-1 && rune(str[i+1]) >= 'a')) {
builder.WriteRune('_')
}

builder.WriteRune(char + diff)
}

return builder.String()
}
9 changes: 7 additions & 2 deletions pkg/gofr/crud_handlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ func Test_scanEntity(t *testing.T) {
var invalidObject int

type userTestEntity struct {
ID int
Name string
ID int `sql:"auto_increment"`
Name string `sql:"not_null"`
}

tests := []struct {
Expand All @@ -77,6 +77,9 @@ func Test_scanEntity(t *testing.T) {
primaryKey: "id",
tableName: "user_test_entity",
restPath: "userTestEntity",
constraints: map[string]gofrSql.FieldConstraints{"id": {AutoIncrement: true, NotNull: false},
"name": {AutoIncrement: false, NotNull: true},
},
},
err: nil,
},
Expand All @@ -89,6 +92,8 @@ func Test_scanEntity(t *testing.T) {
primaryKey: "id",
tableName: "user",
restPath: "users",
constraints: map[string]gofrSql.FieldConstraints{"id": {AutoIncrement: false, NotNull: false},
"is_employed": {AutoIncrement: false, NotNull: false}, "name": {AutoIncrement: false, NotNull: false}},
},
err: nil,
},
Expand Down
83 changes: 83 additions & 0 deletions pkg/gofr/crud_helpers.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package gofr

import (
"fmt"
"reflect"
"strings"

"gofr.dev/pkg/gofr/datasource/sql"
)

func getTableName(object any, structName string) string {
if v, ok := object.(TableNameOverrider); ok {
return v.TableName()
}

return toSnakeCase(structName)
}

func getRestPath(object any, structName string) string {
if v, ok := object.(RestPathOverrider); ok {
return v.RestPath()
}

return structName
}

func hasAutoIncrementID(constraints map[string]sql.FieldConstraints) bool {
for _, constraint := range constraints {
if constraint.AutoIncrement {
return true
}
}

return false
}

func parseSQLTag(inputTags reflect.StructTag) (sql.FieldConstraints, error) {
var constraints sql.FieldConstraints

sqlTag := inputTags.Get("sql")
if sqlTag == "" {
return constraints, nil
}

tags := strings.Split(sqlTag, ",")

for _, tag := range tags {
tag = strings.ToLower(tag) // Convert to lowercase for case-insensitivity

switch tag {
case "auto_increment":
constraints.AutoIncrement = true
case "not_null":
constraints.NotNull = true
default:
return constraints, fmt.Errorf("%w: %s", errInvalidSQLTag, tag)
}
}

return constraints, nil
}

func toSnakeCase(str string) string {
diff := 'a' - 'A'
length := len(str)

var builder strings.Builder

for i, char := range str {
if char >= 'a' {
builder.WriteRune(char)
continue
}

if (i != 0 || i == length-1) && ((i > 0 && rune(str[i-1]) >= 'a') || (i < length-1 && rune(str[i+1]) >= 'a')) {
builder.WriteRune('_')
}

builder.WriteRune(char + diff)
}

return builder.String()
}
Loading

0 comments on commit b98a7a3

Please sign in to comment.