Commit

add: grpc method for parallel steps
Signed-off-by: Vladislav Sukhin <vladislav@kubeshop.io>
vsukhin committed Dec 3, 2024
1 parent bfdc72f commit ce5215f
Showing 2 changed files with 199 additions and 0 deletions.
7 changes: 7 additions & 0 deletions pkg/agent/agent.go
@@ -178,6 +178,13 @@ func (ag *Agent) run(ctx context.Context) (err error) {
return ag.runTestWorkflowServiceNotificationsWorker(groupCtx, ag.testWorkflowServiceNotificationsWorkerCount)
})

g.Go(func() error {
return ag.runTestWorkflowParallelStepNotificationsLoop(groupCtx)
})
g.Go(func() error {
return ag.runTestWorkflowParallelStepNotificationsWorker(groupCtx, ag.testWorkflowParallelStepNotificationsWorkerCount)
})

err = g.Wait()

return err
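The agent.go hunk above wires the new receive loop and worker pool into the same errgroup as the existing notification streams: requests flow from the gRPC stream into a buffered request channel, workers execute them, and responses flow back through a buffered response channel to the stream sender. Below is a minimal, generic sketch of that loop/worker/buffer pattern; the types, buffer sizes, and worker count are simplified stand-ins, not the commit's own code (the real channels carry cloud.TestWorkflowParallelStepNotificationsRequest/Response messages).

package main

import (
	"context"
	"fmt"

	"golang.org/x/sync/errgroup"
)

type request struct{ id int }
type response struct{ id int }

func main() {
	// Buffered channels stand in for the agent's request/response buffers.
	reqBuf := make(chan *request, 5)
	respBuf := make(chan *response, 5)
	workerCount := 2 // stand-in for ag.testWorkflowParallelStepNotificationsWorkerCount

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	g, groupCtx := errgroup.WithContext(ctx)

	// Loop goroutine: in the agent this receives requests from the gRPC stream.
	g.Go(func() error {
		for i := 0; i < 3; i++ {
			reqBuf <- &request{id: i}
		}
		return nil
	})

	// Worker goroutines: execute requests and push responses back.
	for w := 0; w < workerCount; w++ {
		g.Go(func() error {
			for {
				select {
				case req := <-reqBuf:
					respBuf <- &response{id: req.id}
				case <-groupCtx.Done():
					return groupCtx.Err()
				}
			}
		})
	}

	// Sender stand-in: in the agent this writes responses back to the stream.
	for i := 0; i < 3; i++ {
		fmt.Println("sent response for request", (<-respBuf).id)
	}
	cancel()
	_ = g.Wait() // returns context.Canceled once the workers stop
}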
192 changes: 192 additions & 0 deletions pkg/agent/testworkflows.go
@@ -125,6 +125,51 @@ func (ag *Agent) runTestWorkflowServiceNotificationsLoop(ctx context.Context) er
return err
}

func (ag *Agent) runTestWorkflowParallelStepNotificationsLoop(ctx context.Context) error {
ctx = agentclient.AddAPIKeyMeta(ctx, ag.apiKey)

ag.logger.Infow("initiating workflow parallel step notifications streaming connection with Cloud API")
// creates a new Stream from the client side. ctx is used for the lifetime of the stream.
opts := []grpc.CallOption{grpc.UseCompressor(gzip.Name), grpc.MaxCallRecvMsgSize(math.MaxInt32)}
stream, err := ag.client.GetTestWorkflowParallelStepNotificationsStream(ctx, opts...)
if err != nil {
		ag.logger.Errorf("failed to set up parallel step notifications stream: %v", err)
return errors.Wrap(err, "failed to setup stream")
}

	// gRPC streams have special requirements for concurrency of SendMsg and RecvMsg calls:
	// at most one goroutine may call SendMsg and one may call RecvMsg on a stream at a time.
	// Please check https://github.com/grpc/grpc-go/blob/master/Documentation/concurrency.md
g, groupCtx := errgroup.WithContext(ctx)
g.Go(func() error {
for {
cmd, err := ag.receiveTestWorkflowParallelStepNotificationsRequest(groupCtx, stream)
if err != nil {
return err
}

ag.testWorkflowParallelStepNotificationsRequestBuffer <- cmd
}
})

g.Go(func() error {
for {
select {
case resp := <-ag.testWorkflowParallelStepNotificationsResponseBuffer:
err := ag.sendTestWorkflowParallelStepNotificationsResponse(groupCtx, stream, resp)
if err != nil {
return err
}
case <-groupCtx.Done():
return groupCtx.Err()
}
}
})

err = g.Wait()

return err
}

func (ag *Agent) runTestWorkflowNotificationsWorker(ctx context.Context, numWorkers int) error {
g, groupCtx := errgroup.WithContext(ctx)
for i := 0; i < numWorkers; i++ {
@@ -181,6 +226,34 @@ func (ag *Agent) runTestWorkflowServiceNotificationsWorker(ctx context.Context,
return g.Wait()
}

func (ag *Agent) runTestWorkflowParallelStepNotificationsWorker(ctx context.Context, numWorkers int) error {
g, groupCtx := errgroup.WithContext(ctx)
for i := 0; i < numWorkers; i++ {
g.Go(func() error {
for {
select {
case req := <-ag.testWorkflowParallelStepNotificationsRequestBuffer:
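				// Health-check requests are acknowledged right away with an empty response (SeqNo 0)
				// and are not executed; break exits the select and continues the worker loop.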
if req.RequestType == cloud.TestWorkflowNotificationsRequestType_WORKFLOW_STREAM_HEALTH_CHECK {
ag.testWorkflowParallelStepNotificationsResponseBuffer <- &cloud.TestWorkflowParallelStepNotificationsResponse{
StreamId: req.StreamId,
SeqNo: 0,
}
break
}

err := ag.executeWorkflowParallelStepNotificationsRequest(groupCtx, req)
if err != nil {
ag.logger.Errorf("error executing workflow parallel step notifications request: %s", err.Error())
}
case <-groupCtx.Done():
return groupCtx.Err()
}
}
})
}
return g.Wait()
}

func (ag *Agent) executeWorkflowNotificationsRequest(ctx context.Context, req *cloud.TestWorkflowNotificationsRequest) error {
notificationsCh, err := ag.testWorkflowNotificationsFunc(ctx, req.ExecutionId)
for i := 0; i < testWorkflowNotificationsRetryCount; i++ {
@@ -306,6 +379,71 @@ func (ag *Agent) executeWorkflowServiceNotificationsRequest(ctx context.Context,
}
}

func (ag *Agent) executeWorkflowParallelStepNotificationsRequest(ctx context.Context, req *cloud.TestWorkflowParallelStepNotificationsRequest) error {
notificationsCh, err := retry.DoWithData(
func() (<-chan testkube.TestWorkflowExecutionNotification, error) {
			// There is a race condition here: the Cloud is sometimes slow to start the service,
			// while the WorkflowNotifications request from websockets arrives faster,
			// so we retry until the service pod is up or the execution is finished.
return ag.testWorkflowServiceNotificationsFunc(ctx, req.ExecutionId, req.Ref, int(req.ParallelStepIndex))
},
retry.DelayType(retry.FixedDelay),
retry.Delay(logRetryDelay),
retry.RetryIf(func(err error) bool {
return errors.Is(err, registry.ErrResourceNotFound)
}),
retry.UntilSucceeded(),
)

if err != nil {
		message := fmt.Sprintf("cannot get parallel step pod logs: %s", err.Error())
ag.testWorkflowParallelStepNotificationsResponseBuffer <- &cloud.TestWorkflowParallelStepNotificationsResponse{
StreamId: req.StreamId,
SeqNo: 0,
Type: cloud.TestWorkflowNotificationType_WORKFLOW_STREAM_ERROR,
Message: fmt.Sprintf("%s %s", time.Now().Format(controller.KubernetesLogTimeFormat), message),
}
return nil
}

	var i uint32 // sequence number (SeqNo) for messages sent on this stream
	for {
select {
case n, ok := <-notificationsCh:
if !ok {
return nil
}
t := getTestWorkflowNotificationType(n)
msg := &cloud.TestWorkflowParallelStepNotificationsResponse{
StreamId: req.StreamId,
SeqNo: i,
Timestamp: n.Ts.Format(time.RFC3339Nano),
Ref: n.Ref,
Type: t,
}
if n.Result != nil {
m, _ := json.Marshal(n.Result)
msg.Message = string(m)
} else if n.Output != nil {
m, _ := json.Marshal(n.Output)
msg.Message = string(m)
} else {
msg.Message = n.Log
}
i++

select {
case ag.testWorkflowParallelStepNotificationsResponseBuffer <- msg:
case <-ctx.Done():
return ctx.Err()
}
case <-ctx.Done():
return ctx.Err()
}
}
}
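The retry wrapper above keeps polling until the logs source becomes available, treating only registry.ErrResourceNotFound as retryable. Below is a standalone sketch of that retry-go pattern, assuming the imported retry package is github.com/avast/retry-go/v4; errNotFound and fetchLogs are illustrative stand-ins, not part of this commit.

package main

import (
	"errors"
	"fmt"
	"time"

	retry "github.com/avast/retry-go/v4"
)

var errNotFound = errors.New("resource not found")

func main() {
	attempts := 0
	// fetchLogs stands in for the notifications source: it fails until the pod is "up".
	fetchLogs := func() (string, error) {
		attempts++
		if attempts < 3 {
			return "", errNotFound
		}
		return "log stream ready", nil
	}

	data, err := retry.DoWithData(
		fetchLogs,
		retry.DelayType(retry.FixedDelay),
		retry.Delay(100*time.Millisecond),
		// Only keep retrying for the "not found yet" case; any other error fails immediately.
		retry.RetryIf(func(err error) bool { return errors.Is(err, errNotFound) }),
		retry.UntilSucceeded(),
	)
	fmt.Println(data, err)
}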

func (ag *Agent) receiveTestWorkflowNotificationsRequest(ctx context.Context, stream cloud.TestKubeCloudAPI_GetTestWorkflowNotificationsStreamClient) (*cloud.TestWorkflowNotificationsRequest, error) {
respChan := make(chan testWorkflowNotificationsRequest, 1)
go func() {
@@ -364,6 +502,35 @@ type testWorkflowServiceNotificationsRequest struct {
err error
}

func (ag *Agent) receiveTestWorkflowParallelStepNotificationsRequest(ctx context.Context, stream cloud.TestKubeCloudAPI_GetTestWorkflowParallelStepNotificationsStreamClient) (*cloud.TestWorkflowParallelStepNotificationsRequest, error) {
respChan := make(chan testWorkflowParallelStepNotificationsRequest, 1)
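	// Run the blocking stream.Recv in its own goroutine so the select below can return promptly
	// when ctx is canceled; the buffered channel lets the goroutine exit even if its result is never read.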
go func() {
cmd, err := stream.Recv()
respChan <- testWorkflowParallelStepNotificationsRequest{resp: cmd, err: err}
}()

var cmd *cloud.TestWorkflowParallelStepNotificationsRequest
select {
case resp := <-respChan:
cmd = resp.resp
err := resp.err

if err != nil {
ag.logger.Errorf("agent stream receive: %v", err)
return nil, err
}
case <-ctx.Done():
return nil, ctx.Err()
}

return cmd, nil
}

type testWorkflowParallelStepNotificationsRequest struct {
resp *cloud.TestWorkflowParallelStepNotificationsRequest
err error
}

func (ag *Agent) sendTestWorkflowNotificationsResponse(ctx context.Context, stream cloud.TestKubeCloudAPI_GetTestWorkflowNotificationsStreamClient, resp *cloud.TestWorkflowNotificationsResponse) error {
errChan := make(chan error, 1)
go func() {
@@ -413,3 +580,28 @@ func (ag *Agent) sendTestWorkflowServiceNotificationsResponse(ctx context.Contex
return errors.New("send response too slow")
}
}

func (ag *Agent) sendTestWorkflowParallelStepNotificationsResponse(ctx context.Context, stream cloud.TestKubeCloudAPI_GetTestWorkflowParallelStepNotificationsStreamClient, resp *cloud.TestWorkflowParallelStepNotificationsResponse) error {
errChan := make(chan error, 1)
go func() {
errChan <- stream.Send(resp)
close(errChan)
}()

t := time.NewTimer(ag.sendTimeout)
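	// Bound the Send with ag.sendTimeout so a stalled stream cannot block this sender forever;
	// the Stop/drain pattern below prevents a fired timer from leaving a stale value in t.C.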
select {
case err := <-errChan:
if !t.Stop() {
<-t.C
}
return err
case <-ctx.Done():
if !t.Stop() {
<-t.C
}

return ctx.Err()
case <-t.C:
return errors.New("send response too slow")
}
}
