diff --git a/docs/references/context/page.md b/docs/references/context/page.md index 5c1a5dead..61502f1ae 100644 --- a/docs/references/context/page.md +++ b/docs/references/context/page.md @@ -50,6 +50,24 @@ parts of the request. ctx.Bind(&p) // the Bind() method will map the incoming request to variable p ``` + + ### Binding multipart-form data + - To bind multipart-form data, you can use the Bind method similarly. The struct fields should be tagged appropriately + to map the form fields to the struct fields. + + ```go + type Data struct { + Name string `form:"name"` + + Compressed file.Zip `file:"upload"` + + FileHeader *multipart.FileHeader `file:"file_upload"` + } + ``` + + - The `form` tag is used to bind non-file fields. + - The `file` tag is used to bind file fields. If the tag is not present, the field name is used as the key. + - `HostName()` - to access the host name for the incoming request ```go diff --git a/examples/using-file-bind/README.md b/examples/using-file-bind/README.md index ef74f96a5..0b10006b8 100644 --- a/examples/using-file-bind/README.md +++ b/examples/using-file-bind/README.md @@ -8,7 +8,7 @@ it to the fields of the struct. GoFr currently supports zip file type and also b type Data struct { Compressed file.Zip `file:"upload"` - FileHeader *multipart.FileHeader `file:"a"` + FileHeader *multipart.FileHeader `file:"file_upload"` } func Handler (c *gofr.Context) (interface{}, error) { diff --git a/examples/using-file-bind/main.go b/examples/using-file-bind/main.go index aa813504e..9ff76507b 100644 --- a/examples/using-file-bind/main.go +++ b/examples/using-file-bind/main.go @@ -30,7 +30,7 @@ type Data struct { // The FileHeader determines the generic file format that we can get // from the multipart form that gets parsed by the incoming HTTP request - FileHeader *multipart.FileHeader `file:"a"` + FileHeader *multipart.FileHeader `file:"file_upload"` } func UploadHandler(c *gofr.Context) (interface{}, error) { diff --git a/examples/using-file-bind/main_test.go b/examples/using-file-bind/main_test.go index 50a385a74..839e3277f 100644 --- a/examples/using-file-bind/main_test.go +++ b/examples/using-file-bind/main_test.go @@ -56,7 +56,7 @@ func generateMultiPartBody(t *testing.T) (*bytes.Buffer, string) { t.Fatalf("Failed to write file to form: %v", err) } - fileHeader, err := writer.CreateFormFile("a", "hello.txt") + fileHeader, err := writer.CreateFormFile("file_upload", "hello.txt") if err != nil { t.Fatalf("Failed to create form file: %v", err) } diff --git a/pkg/gofr/http/form_data_binder.go b/pkg/gofr/http/form_data_binder.go new file mode 100644 index 000000000..811cdfff4 --- /dev/null +++ b/pkg/gofr/http/form_data_binder.go @@ -0,0 +1,205 @@ +package http + +import ( + "encoding/json" + "fmt" + "reflect" + "strings" +) + +func (*formData) setInterfaceValue(value reflect.Value, data any) (bool, error) { + if !value.CanSet() { + return false, fmt.Errorf("%w: %s", errUnsupportedInterfaceType, value.Kind()) + } + + value.Set(reflect.ValueOf(data)) + + return true, nil +} + +func (uf *formData) setSliceOrArrayValue(value reflect.Value, data string) (bool, error) { + if value.Kind() != reflect.Slice && value.Kind() != reflect.Array { + return false, fmt.Errorf("%w: %s", errUnsupportedKind, value.Kind()) + } + + elemType := value.Type().Elem() + + elements := strings.Split(data, ",") + + // Create a new slice/array with appropriate length and capacity + var newSlice reflect.Value + + if value.Kind() == reflect.Slice { + newSlice = reflect.MakeSlice(value.Type(), len(elements), len(elements)) + } else if len(elements) > value.Len() { + return false, errDataLengthExceeded + } else { + newSlice = reflect.New(value.Type()).Elem() + } + + // Create a reusable element value to avoid unnecessary allocations + elemValue := reflect.New(elemType).Elem() + + // Set the elements of the slice/array + for i, strVal := range elements { + // Update the reusable element value + if _, err := uf.setFieldValue(elemValue, strVal); err != nil { + return false, fmt.Errorf("%w %d: %w", errSettingValueFailure, i, err) + } + + newSlice.Index(i).Set(elemValue) + } + + value.Set(newSlice) + + return true, nil +} + +func (*formData) setStructValue(value reflect.Value, data string) (bool, error) { + if value.Kind() != reflect.Struct { + return false, errNotAStruct + } + + dataMap, err := parseStringToMap(data) + if err != nil { + return false, err + } + + if len(dataMap) == 0 { + return false, errFieldsNotSet + } + + numFieldsSet := 0 + + var multiErr error + + // Create a map for case-insensitive lookups + caseInsensitiveMap := make(map[string]interface{}) + for key, val := range dataMap { + caseInsensitiveMap[strings.ToLower(key)] = val + } + + for i := 0; i < value.NumField(); i++ { + fieldType := value.Type().Field(i) + fieldValue := value.Field(i) + fieldName := fieldType.Name + + // Perform case-insensitive lookup for the key in dataMap + val, exists := caseInsensitiveMap[strings.ToLower(fieldName)] + if !exists { + continue + } + + if !fieldValue.CanSet() { + multiErr = fmt.Errorf("%w: %s", errUnexportedField, fieldName) + continue + } + + if err := setFieldValueFromData(fieldValue, val); err != nil { + multiErr = fmt.Errorf("%w; %w", multiErr, err) + continue + } + + numFieldsSet++ + } + + if numFieldsSet == 0 { + return false, errFieldsNotSet + } + + return true, multiErr +} + +// setFieldValueFromData sets the field's value based on the provided data. +func setFieldValueFromData(field reflect.Value, data interface{}) error { + switch field.Kind() { + case reflect.String: + return setStringField(field, data) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return setIntField(field, data) + case reflect.Float32, reflect.Float64: + return setFloatField(field, data) + case reflect.Bool: + return setBoolField(field, data) + case reflect.Invalid, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr, + reflect.Complex64, reflect.Complex128, reflect.Array, reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, + reflect.Pointer, reflect.Slice, reflect.Struct, reflect.UnsafePointer: + return fmt.Errorf("%w: %s, %T", errUnsupportedFieldType, field.Type().Name(), data) + default: + return fmt.Errorf("%w: %s, %T", errUnsupportedFieldType, field.Type().Name(), data) + } +} + +type customUnmarshaller struct { + dataMap map[string]interface{} +} + +// UnmarshalJSON is a custom unmarshaller because json package in Go unmarshal numbers to float64 by default. +func (c *customUnmarshaller) UnmarshalJSON(data []byte) error { + var rawData map[string]interface{} + + err := json.Unmarshal(data, &rawData) + if err != nil { + return err + } + + dataMap := make(map[string]any, len(rawData)) + + for key, val := range rawData { + if valFloat, ok := val.(float64); ok { + valInt := int(valFloat) + if valFloat == float64(valInt) { + val = valInt + } + } + + dataMap[key] = val + } + + *c = customUnmarshaller{dataMap} + + return nil +} + +func parseStringToMap(data string) (map[string]interface{}, error) { + var c customUnmarshaller + err := json.Unmarshal([]byte(data), &c) + + return c.dataMap, err +} + +func setStringField(field reflect.Value, data interface{}) error { + if val, ok := data.(string); ok { + field.SetString(val) + return nil + } + + return fmt.Errorf("%w: expected string but got %T", errUnsupportedFieldType, data) +} + +func setIntField(field reflect.Value, data interface{}) error { + if val, ok := data.(int); ok { + field.SetInt(int64(val)) + return nil + } + + return fmt.Errorf("%w: expected int but got %T", errUnsupportedFieldType, data) +} + +func setFloatField(field reflect.Value, data interface{}) error { + if val, ok := data.(float64); ok { + field.SetFloat(val) + return nil + } + + return fmt.Errorf("%w: expected float64 but got %T", errUnsupportedFieldType, data) +} + +func setBoolField(field reflect.Value, data interface{}) error { + if val, ok := data.(bool); ok { + field.SetBool(val) + return nil + } + + return fmt.Errorf("%w: expected bool but got %T", errUnsupportedFieldType, data) +} diff --git a/pkg/gofr/http/multipart_file_bind.go b/pkg/gofr/http/multipart_file_bind.go index c55025509..a4fe481f6 100644 --- a/pkg/gofr/http/multipart_file_bind.go +++ b/pkg/gofr/http/multipart_file_bind.go @@ -1,6 +1,7 @@ package http import ( + "errors" "io" "mime/multipart" "reflect" @@ -9,6 +10,17 @@ import ( "gofr.dev/pkg/gofr/file" ) +var ( + errUnsupportedInterfaceType = errors.New("unsupported interface value type") + errDataLengthExceeded = errors.New("data length exceeds array capacity") + errUnsupportedKind = errors.New("unsupported kind") + errSettingValueFailure = errors.New("error setting value at index") + errNotAStruct = errors.New("provided value is not a struct") + errUnexportedField = errors.New("cannot set field; it might be unexported") + errUnsupportedFieldType = errors.New("unsupported type for field") + errFieldsNotSet = errors.New("no fields were set") +) + type formData struct { fields map[string][]string files map[string][]*multipart.FileHeader @@ -134,6 +146,8 @@ func (*formData) setFile(value reflect.Value, header []*multipart.FileHeader) (b } func (uf *formData) setFieldValue(value reflect.Value, data string) (bool, error) { + value = dereferencePointerType(value) + kind := value.Kind() switch kind { case reflect.String: @@ -146,13 +160,31 @@ func (uf *formData) setFieldValue(value reflect.Value, data string) (bool, error return uf.setFloatValue(value, data) case reflect.Bool: return uf.setBoolValue(value, data) - case reflect.Invalid, reflect.Complex64, reflect.Complex128, reflect.Array, reflect.Chan, reflect.Func, reflect.Interface, - reflect.Map, reflect.Pointer, reflect.Slice, reflect.Struct, reflect.UnsafePointer: - // These types are not supported for setting via form data - return false, nil - default: + case reflect.Slice, reflect.Array: + return uf.setSliceOrArrayValue(value, data) + case reflect.Interface: + return uf.setInterfaceValue(value, data) + case reflect.Struct: + return uf.setStructValue(value, data) + case reflect.Invalid, reflect.Complex64, reflect.Complex128, reflect.Chan, reflect.Func, + reflect.Map, reflect.Pointer, reflect.UnsafePointer: return false, nil } + + return false, nil +} + +func dereferencePointerType(value reflect.Value) reflect.Value { + if value.Kind() == reflect.Ptr { + if value.IsNil() { + // Initialize the pointer to a new value if it's nil + value.Set(reflect.New(value.Type().Elem())) + } + + value = value.Elem() // Dereference the pointer + } + + return value } func (*formData) setStringValue(value reflect.Value, data string) (bool, error) { diff --git a/pkg/gofr/http/multipart_file_bind_test.go b/pkg/gofr/http/multipart_file_bind_test.go index 762836734..e69dd97bb 100644 --- a/pkg/gofr/http/multipart_file_bind_test.go +++ b/pkg/gofr/http/multipart_file_bind_test.go @@ -1,10 +1,18 @@ package http import ( + "errors" + "fmt" "reflect" "testing" + "unsafe" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var ( + errUnsupportedType = errors.New("unsupported type for field: expected float64 but got bool") + errJSON = errors.New("unexpected end of JSON input") ) func TestGetFieldName(t *testing.T) { @@ -49,8 +57,209 @@ func TestGetFieldName(t *testing.T) { for i, tt := range tests { t.Run(tt.desc, func(t *testing.T) { result, gotOk := getFieldName(tt.field) - assert.Equal(t, tt.key, result, "TestGetFieldName[%d] : %v Failed!", i, tt.desc) - assert.Equal(t, tt.wantOk, gotOk, "TestGetFieldName[%d] : %v Failed!", i, tt.desc) + require.Equal(t, tt.key, result, "TestGetFieldName[%d] : %v Failed!", i, tt.desc) + require.Equal(t, tt.wantOk, gotOk, "TestGetFieldName[%d] : %v Failed!", i, tt.desc) + }) + } +} + +type testValue struct { + kind reflect.Kind + value interface{} +} + +func Test_SetFieldValue_Success(t *testing.T) { + testCases := []struct { + desc string + data string + expected bool + valueType testValue + }{ + {"String", "test", true, testValue{reflect.String, "string"}}, + {"Int", "10", true, testValue{reflect.Int, 0}}, + {"Uint", "10", true, testValue{reflect.Uint16, uint16(10)}}, + {"Float64", "3.14", true, testValue{reflect.Float64, 0.0}}, + {"Bool", "true", true, testValue{reflect.Bool, false}}, + {"Slice", "1,2,3,4,5", true, testValue{reflect.Slice, []int{}}}, + {"Array", "1,2,3,4,5", true, testValue{reflect.Array, [5]int{}}}, + {"Struct", `{"name": "John", "age": 30}`, true, testValue{reflect.Struct, struct { + Name string `json:"name"` + Age int `json:"age"` + }{}}}, + {"Interface", "test interface", true, testValue{reflect.Interface, new(any)}}, + } + + for _, tc := range testCases { + f := &formData{} + val := reflect.New(reflect.TypeOf(tc.valueType.value)).Elem() + + set, err := f.setFieldValue(val, tc.data) + + require.NoErrorf(t, err, "Unexpected error for value kind %v and data %q", val.Kind(), tc.data) + + require.Equalf(t, tc.expected, set, "Expected set to be %v for value kind %v and data %q", tc.expected, val.Kind(), tc.data) + } +} + +func TestSetFieldValue_InvalidKinds(t *testing.T) { + uf := &formData{} + + tests := []struct { + kind reflect.Kind + data string + typ reflect.Type + }{ + {reflect.Complex64, "foo", reflect.TypeOf(complex64(0))}, + {reflect.Complex128, "bar", reflect.TypeOf(complex128(0))}, + {reflect.Chan, "baz", reflect.TypeOf(make(chan int))}, + {reflect.Func, "qux", reflect.TypeOf(func() {})}, + {reflect.Map, "quux", reflect.TypeOf(map[string]int{})}, + {reflect.UnsafePointer, "grault", reflect.TypeOf(unsafe.Pointer(nil))}, + } + + for _, tt := range tests { + value := reflect.New(tt.typ).Elem() + ok, err := uf.setFieldValue(value, tt.data) + + require.False(t, ok, "expected false, got true for kind %v", tt.kind) + + require.NoError(t, err, "expected nil, got %v for kind %v", err, tt.kind) + } +} + +func TestSetSliceOrArrayValue(t *testing.T) { + type testStruct struct { + Slice []string + Array [3]string + } + + uf := &formData{} + + // Test with a slice + value := reflect.ValueOf(&testStruct{Slice: nil}).Elem().FieldByName("Slice") + + data := "a,b,c" + + ok, err := uf.setSliceOrArrayValue(value, data) + + require.True(t, ok, "setSliceOrArrayValue failed") + + require.NoError(t, err, "setSliceOrArrayValue failed: %v", err) + + require.Len(t, value.Interface().([]string), 3, "slice not set correctly") + + // Test with an array + value = reflect.ValueOf(&testStruct{Array: [3]string{}}).Elem().FieldByName("Array") + + data = "a,b,c" + + ok, err = uf.setSliceOrArrayValue(value, data) + + require.True(t, ok, "setSliceOrArrayValue failed") + + require.NoError(t, err, "setSliceOrArrayValue failed: %v", err) +} + +func TestSetStructValue_Success(t *testing.T) { + type testStruct struct { + Field1 string + Field2 int + } + + uf := &formData{} + + tests := []struct { + name string + data string + wantField1 string + wantField2 int + }{ + { + name: "Valid input with correct case", + data: `{"Field1":"value1","Field2":123}`, + wantField1: "value1", + wantField2: 123, + }, + { + name: "Valid input with case insensitive fields", + data: `{"field1":"value2","FIELD2":456}`, + wantField1: "value2", + wantField2: 456, + }, + { + name: "Mixed Case and invalid field names", + data: `{"FielD1":"value4", "invalidField":"ignored", "FiEld2":789}`, + wantField1: "value4", + wantField2: 789, + }, + { + name: "Case-insensitive field name but not in dataMap", + data: `{"fIeLd1":"value5", "not_in_dataMap": 123}`, + wantField1: "value5", + wantField2: 0, // Field2 should remain unset (default 0) + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + value := reflect.ValueOf(&testStruct{}).Elem() + + ok, err := uf.setStructValue(value, tt.data) + + require.NoError(t, err, "TestSetStructValue_Success Failed.") + require.True(t, ok, "TestSetStructValue_Success Failed.") + require.Equal(t, tt.wantField1, value.FieldByName("Field1").String(), + "TestSetStructValue_Success Failed : Field1 not set correctly") + require.Equal(t, tt.wantField2, int(value.FieldByName("Field2").Int()), + "TestSetStructValue_Success Failed : Field2 not set correctly") + }) + } +} + +func TestSetStructValue_Errors(t *testing.T) { + type testStruct struct { + Field1 string + Field2 int + Field4 float64 + } + + uf := &formData{} + + tests := []struct { + name string + data string + err error + }{ + { + name: "Unexported field", + data: `{"field3":"value3"}`, + err: errFieldsNotSet, + }, + { + name: "Unsupported field type", + data: `{"field2":1,"Field4":true}`, + err: fmt.Errorf("%w; %w", nil, errUnsupportedType), + }, + { + name: "Invalid JSON", + data: `{"Field1":"value1", "Field2":123,`, + err: errJSON, // JSON parsing error + }, + { + name: "Field not settable", + data: `{"Field1":"value1", "Field2":123, "Field4": "not a float"}`, + err: errUnsupportedFieldType, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + value := reflect.ValueOf(&testStruct{}).Elem() + + _, err := uf.setStructValue(value, tt.data) + + require.Error(t, err, "TestSetStructValue_Errors Failed.") + require.Contains(t, err.Error(), tt.err.Error(), "TestSetStructValue_Errors Failed.") }) } }