diff --git a/pkg/engine/datasource/rest_datasource/rest_datasource.go b/pkg/engine/datasource/rest_datasource/rest_datasource.go index b4e4b1dc54..713fb2623b 100644 --- a/pkg/engine/datasource/rest_datasource/rest_datasource.go +++ b/pkg/engine/datasource/rest_datasource/rest_datasource.go @@ -5,6 +5,7 @@ import ( "context" "encoding/json" "fmt" + "github.com/buger/jsonparser" "io" "net/http" "regexp" @@ -22,16 +23,9 @@ type Planner struct { config Configuration rootField int operationDefinition int - argumentTypeMap map[string]int - Operation *ast.Document } -func (p *Planner) EnterDocument(operation, definition *ast.Document) { - p.Operation = operation -} - -func (p *Planner) LeaveDocument(operation, definition *ast.Document) { -} +const typeInt = "Int" func (p *Planner) DownstreamResponseFieldAlias(_ int) (alias string, exists bool) { // the REST DataSourcePlanner doesn't rewrite upstream fields: skip @@ -55,8 +49,7 @@ type Factory struct { func (f *Factory) Planner(ctx context.Context) plan.DataSourcePlanner { return &Planner{ - client: f.Client, - argumentTypeMap: map[string]int{}, + client: f.Client, } } @@ -84,16 +77,15 @@ type FetchConfiguration struct { } type QueryConfiguration struct { - Name string `json:"name"` - Value string `json:"value"` + Name string `json:"name"` + Value string `json:"value"` + rawMessage json.RawMessage } func (p *Planner) Register(visitor *plan.Visitor, configuration plan.DataSourceConfiguration, isNested bool) error { p.v = visitor - visitor.Walker.RegisterDocumentVisitor(p) visitor.Walker.RegisterEnterFieldVisitor(p) visitor.Walker.RegisterEnterOperationVisitor(p) - visitor.Walker.RegisterEnterArgumentVisitor(p) return json.Unmarshal(configuration.Custom, &p.config) } @@ -101,25 +93,6 @@ func (p *Planner) EnterField(ref int) { p.rootField = ref } -func (p *Planner) EnterArgument(ref int) { - fieldName := p.Operation.FieldNameString(p.rootField) - argumentName := p.Operation.ArgumentNameString(ref) - key := fmt.Sprintf("%s_%s", fieldName, argumentName) - fmt.Println(key) - val := p.Operation.Arguments[ref].Value - if val.Kind == ast.ValueKindVariable { - if !p.Operation.OperationDefinitionHasVariableDefinition(p.operationDefinition, p.Operation.VariableValueNameString(val.Ref)) { - return - } - variableDefinition, exists := p.Operation.VariableDefinitionByNameAndOperation(p.operationDefinition, p.Operation.VariableValueNameBytes(val.Ref)) - if !exists { - return - } - p.argumentTypeMap[key] = p.Operation.VariableDefinitions[variableDefinition].Type - return - } -} - func (p *Planner) configureInput() []byte { input := httpclient.SetInputURL(nil, []byte(p.config.Fetch.URL)) @@ -132,7 +105,7 @@ func (p *Planner) configureInput() []byte { } preparedQuery := p.prepareQueryParams(p.rootField, p.config.Fetch.Query) - query, err := json.Marshal(preparedQuery) + query, err := p.marshalQueryParams(preparedQuery) if err == nil && len(preparedQuery) != 0 { input = httpclient.SetInputQueryParams(input, query) } @@ -184,6 +157,17 @@ Next: if value.Kind != ast.ValueKindVariable { continue Next } + + variableDefRef, exists := p.v.Operation.VariableDefinitionByNameAndOperation(p.operationDefinition, p.v.Operation.VariableValueNameBytes(value.Ref)) + if !exists { + continue + } + typeName := p.v.Operation.TypeNameString(p.v.Operation.VariableDefinitions[variableDefRef].Type) + query[i].rawMessage = []byte(`"` + query[i].Value + `"`) + if typeName == typeInt { + query[i].rawMessage = []byte(query[i].Value) + } + variableName := p.v.Operation.VariableValueNameString(value.Ref) if !p.v.Operation.OperationDefinitionHasVariableDefinition(p.operationDefinition, variableName) { continue Next @@ -195,6 +179,24 @@ Next: return out } +func (p *Planner) marshalQueryParams(params []QueryConfiguration) ([]byte, error) { + marshalled, err := json.Marshal(params) + if err != nil { + return nil, err + } + for i := range params { + if params[i].rawMessage != nil { + marshalled, err = jsonparser.Set(marshalled, params[i].rawMessage, fmt.Sprintf("[%d]", i), "value") + } else { + marshalled, err = jsonparser.Set(marshalled, []byte(params[i].Value), fmt.Sprintf("[%d]", i), "value") + } + if err != nil { + return nil, err + } + } + return marshalled, nil +} + type Source struct { client *http.Client } diff --git a/pkg/engine/datasource/rest_datasource/rest_datasource_test.go b/pkg/engine/datasource/rest_datasource/rest_datasource_test.go index b9cae1416f..e7a04217a1 100644 --- a/pkg/engine/datasource/rest_datasource/rest_datasource_test.go +++ b/pkg/engine/datasource/rest_datasource/rest_datasource_test.go @@ -125,7 +125,6 @@ const ( query ArgumentQuery { withIntArgument(limit: 10) { name - hasArg(limit:10) } } ` @@ -1103,7 +1102,7 @@ func TestFastHttpJsonDataSourcePlanning(t *testing.T) { Data: &resolve.Object{ Fetch: &resolve.SingleFetch{ BufferId: 0, - Input: `{"query_params":[{"name":"names","value": $$0$$}],"method":"GET","url":"https://example.com/friend"}`, + Input: `{"query_params":[{"name":"limit","value":$$0$$}],"method":"GET","url":"https://example.com/friend"}`, DataSource: &Source{}, Variables: resolve.NewVariables( &resolve.ContextVariable{