diff --git a/internal/internal_workflow_client_test.go b/internal/internal_workflow_client_test.go index f52220cd8..59324c35e 100644 --- a/internal/internal_workflow_client_test.go +++ b/internal/internal_workflow_client_test.go @@ -1913,6 +1913,128 @@ func (s *workflowClientTestSuite) TestResetWorkflow() { } } +func (s *workflowClientTestSuite) TestQueryWorkflowWithOptions() { + testcases := []struct { + name string + queryArgs []interface{} + requestValidator func(req *shared.QueryWorkflowRequest) + expectRPC bool + rpcResponse *shared.QueryWorkflowResponse + rpcError error + responseValidator func(resp *QueryWorkflowWithOptionsResponse, err error) + }{ + { + name: "success without arguments", + queryArgs: nil, + expectRPC: true, + requestValidator: func(req *shared.QueryWorkflowRequest) { + s.Empty(req.GetQuery().GetQueryArgs(), "no input queryArgs provided") + }, + + rpcResponse: &shared.QueryWorkflowResponse{QueryResult: []byte("\"result\"")}, + rpcError: nil, + responseValidator: func(resp *QueryWorkflowWithOptionsResponse, err error) { + s.Require().Nil(err) + s.Nil(resp.QueryRejected) + + var res string + s.NoError(resp.QueryResult.Get(&res)) + s.Equal("result", res) + }, + }, + { + name: "success with arguments", + queryArgs: []interface{}{"arg1", "arg2"}, + expectRPC: true, + requestValidator: func(req *shared.QueryWorkflowRequest) { + s.Equal("\"arg1\"\n\"arg2\"\n", string(req.GetQuery().GetQueryArgs())) + }, + + rpcResponse: &shared.QueryWorkflowResponse{QueryResult: []byte("\"result\"")}, + rpcError: nil, + responseValidator: func(resp *QueryWorkflowWithOptionsResponse, err error) { + s.Require().Nil(err) + s.Nil(resp.QueryRejected) + + var res string + s.NoError(resp.QueryResult.Get(&res)) + s.Equal("result", res) + }, + }, + { + name: "failed to encode arguments", + queryArgs: []interface{}{make(chan int)}, // you can't marshal this object to JSON + expectRPC: false, + + responseValidator: func(resp *QueryWorkflowWithOptionsResponse, err error) { + s.ErrorContains(err, "unable to encode") + s.Nil(resp) + }, + }, + { + name: "RPC fails", + queryArgs: nil, + expectRPC: true, + requestValidator: func(req *shared.QueryWorkflowRequest) {}, + + rpcResponse: nil, + rpcError: &shared.AccessDeniedError{}, + responseValidator: func(resp *QueryWorkflowWithOptionsResponse, err error) { + s.Equal(&shared.AccessDeniedError{}, err) + s.Nil(resp) + }, + }, + { + name: "query rejected", + queryArgs: nil, + expectRPC: true, + requestValidator: func(req *shared.QueryWorkflowRequest) {}, + + rpcResponse: &shared.QueryWorkflowResponse{ + QueryRejected: &shared.QueryRejected{ + CloseStatus: shared.WorkflowExecutionCloseStatusTerminated.Ptr(), + }, + }, + rpcError: nil, + responseValidator: func(resp *QueryWorkflowWithOptionsResponse, err error) { + s.Require().Nil(err) + s.Nil(resp.QueryResult, "should be nil when query rejected") + + s.Require().NotNil(resp.QueryRejected) + s.Equal(shared.WorkflowExecutionCloseStatusTerminated, resp.QueryRejected.GetCloseStatus()) + }, + }, + } + + for _, tt := range testcases { + s.Run(tt.name, func() { + if tt.expectRPC { + s.service.EXPECT(). + QueryWorkflow(gomock.Any(), gomock.Any(), gomock.Any()). + Do(func(_ context.Context, req *shared.QueryWorkflowRequest, _ ...yarpc.CallOption) { + // do common validation for some fields + s.Equal(domain, req.GetDomain()) + s.Equal(workflowID, req.GetExecution().GetWorkflowId()) + s.Equal(runID, req.GetExecution().GetRunId()) + s.Equal(queryType, req.GetQuery().GetQueryType()) + // ... and custom validation for the testcase + tt.requestValidator(req) + }). + Return(tt.rpcResponse, tt.rpcError) + } + + request := &QueryWorkflowWithOptionsRequest{ + WorkflowID: workflowID, + QueryType: queryType, + RunID: runID, + Args: tt.queryArgs, + } + resp, err := s.client.QueryWorkflowWithOptions(context.Background(), request) + tt.responseValidator(resp, err) + }) + } +} + func (s *workflowClientTestSuite) TestGetWorkflowHistory() { // Page 1 of 2 //// Events