diff --git a/internal/internal_workflow_client_test.go b/internal/internal_workflow_client_test.go index eee1596ae..a1123f110 100644 --- a/internal/internal_workflow_client_test.go +++ b/internal/internal_workflow_client_test.go @@ -2348,8 +2348,7 @@ func (s *workflowClientTestSuite) TestQueryWorkflowWithOptions() { testcases := []struct { name string queryArgs []interface{} - requestValidator func(req *shared.QueryWorkflowRequest) - expectRPC bool + requestValidator func(req *shared.QueryWorkflowRequest) // nil if RPC is not expected rpcResponse *shared.QueryWorkflowResponse rpcError error responseValidator func(resp *QueryWorkflowWithOptionsResponse, err error) @@ -2357,8 +2356,13 @@ func (s *workflowClientTestSuite) TestQueryWorkflowWithOptions() { { name: "success without arguments", queryArgs: nil, - expectRPC: true, requestValidator: func(req *shared.QueryWorkflowRequest) { + // do common validation for common fields as well + s.Equal(domain, req.GetDomain()) + s.Equal(workflowID, req.GetExecution().GetWorkflowId()) + s.Equal(runID, req.GetExecution().GetRunId()) + s.Equal(queryType, req.GetQuery().GetQueryType()) + s.Empty(req.GetQuery().GetQueryArgs(), "no input queryArgs provided") }, @@ -2376,7 +2380,6 @@ func (s *workflowClientTestSuite) TestQueryWorkflowWithOptions() { { name: "success with arguments", queryArgs: []interface{}{"arg1", "arg2"}, - expectRPC: true, requestValidator: func(req *shared.QueryWorkflowRequest) { s.Equal("\"arg1\"\n\"arg2\"\n", string(req.GetQuery().GetQueryArgs())) }, @@ -2393,9 +2396,9 @@ func (s *workflowClientTestSuite) TestQueryWorkflowWithOptions() { }, }, { - name: "failed to encode arguments", - queryArgs: []interface{}{make(chan int)}, // you can't marshal this object to JSON - expectRPC: false, + name: "failed to encode arguments", + queryArgs: []interface{}{make(chan int)}, // you can't marshal this object to JSON + requestValidator: nil, responseValidator: func(resp *QueryWorkflowWithOptionsResponse, err error) { s.ErrorContains(err, "unable to encode") @@ -2405,7 +2408,6 @@ func (s *workflowClientTestSuite) TestQueryWorkflowWithOptions() { { name: "RPC fails", queryArgs: nil, - expectRPC: true, requestValidator: func(req *shared.QueryWorkflowRequest) {}, rpcResponse: nil, @@ -2418,7 +2420,6 @@ func (s *workflowClientTestSuite) TestQueryWorkflowWithOptions() { { name: "query rejected", queryArgs: nil, - expectRPC: true, requestValidator: func(req *shared.QueryWorkflowRequest) {}, rpcResponse: &shared.QueryWorkflowResponse{ @@ -2439,16 +2440,10 @@ func (s *workflowClientTestSuite) TestQueryWorkflowWithOptions() { for _, tt := range testcases { s.Run(tt.name, func() { - if tt.expectRPC { + if tt.requestValidator != nil { s.service.EXPECT(). QueryWorkflow(gomock.Any(), gomock.Any(), gomock.Any()). Do(func(_ context.Context, req *shared.QueryWorkflowRequest, _ ...yarpc.CallOption) { - // do common validation for some fields - s.Equal(domain, req.GetDomain()) - s.Equal(workflowID, req.GetExecution().GetWorkflowId()) - s.Equal(runID, req.GetExecution().GetRunId()) - s.Equal(queryType, req.GetQuery().GetQueryType()) - // ... and custom validation for the testcase tt.requestValidator(req) }). Return(tt.rpcResponse, tt.rpcError)