Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add retries to evaluator client #306

Merged
merged 1 commit into from
Oct 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion app/pkg/agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ func (a *Agent) completeWithRetries(ctx context.Context, req *v1alpha1.GenerateR
if len(blocks) == 0 {
assertBlocks.Result = v1alpha1.AssertResult_FAILED
}
log.Info(logs.Level1Assertion, "assertion", assertion)
log.Info(logs.Level1Assertion, "assertion", assertBlocks)
return blocks, nil
}
err := errors.Errorf("Failed to generate a chat completion after %d tries", maxTries)
Expand Down Expand Up @@ -450,6 +450,12 @@ func (a *Agent) StreamGenerate(ctx context.Context, stream *connect.BidiStream[v
// Terminate because the request got cancelled
case <-ctx.Done():
log.Info("Context cancelled; stopping streaming request", "err", ctx.Err())
if errors.Is(ctx.Err(), context.Canceled) || errors.Is(ctx.Err(), context.DeadlineExceeded) {
// N.B. If the context was cancelled then we should return a DeadlineExceeded error to indicate we hit
// a timeout on the server.
// My assumption is if the client terminates the connection there is a different error.
return connect.NewError(connect.CodeDeadlineExceeded, errors.Wrapf(ctx.Err(), "The request context was cancelled. This usually happens because the read or write timeout of the HTTP server was reched."))
}
// Cancel functions will be called when this function returns
return ctx.Err()
case s := <-statusChan:
Expand Down Expand Up @@ -486,6 +492,13 @@ func (a *Agent) GenerateCells(ctx context.Context, req *connect.Request[v1alpha1
agentResp, err := a.Generate(ctx, agentReq)
if err != nil {
log.Error(err, "Agent.Generate failed")
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
// N.B. If the context was cancelled then we should return a DeadlineExceeded error to indicate we hit
// a timeout on the server.
// My assumption is if the client terminates the connection there is a different error.
err := errors.Wrapf(err, "Agent.Generate failed; traceId %s. \"The request context was cancelled. This usually happens because the read or write timeout of the HTTP server was reached.", span.SpanContext().TraceID().String())
return nil, connect.NewError(connect.CodeDeadlineExceeded, err)
}
err := errors.Wrapf(err, "Agent.Generate failed; traceId %s", span.SpanContext().TraceID().String())
return nil, err
}
Expand Down
62 changes: 62 additions & 0 deletions app/pkg/agent/client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package agent

import (
"context"
"errors"
"time"

"connectrpc.com/connect"
)

// RetryInterceptor defines a retry interceptor
type RetryInterceptor struct {
MaxRetries int
Backoff time.Duration
}

func (r *RetryInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc {
// We return a function that will wrap the next function call in a try loop.
return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) {
var resp connect.AnyResponse
var err error

for i := 0; i <= r.MaxRetries; i++ {
resp, err = next(ctx, req)

// If no error, return the response
if err == nil {
return resp, nil
}

// Check if the error is a DeadlineExceeded or Cancelled
// Check if the error is a DeadlineExceeded or Canceled
var connectErr *connect.Error
if errors.As(err, &connectErr) {
code := connectErr.Code()
if code == connect.CodeDeadlineExceeded || code == connect.CodeCanceled {
// Delay before retrying
time.Sleep(r.Backoff)
continue
}
}

// For other errors, return immediately
return nil, err
}

// After max retries, return the last error
return nil, err
}
}

// WrapStreamingClient implements [Interceptor] with a no-op.
func (r *RetryInterceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc {
// TODO(jeremy): We should implement this
return next
}

// WrapStreamingHandler implements [Interceptor] with a no-op.
func (r *RetryInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc {
// TODO(jeremy): We should implement this
return next
}
133 changes: 133 additions & 0 deletions app/pkg/agent/client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
package agent

import (
"context"
"crypto/tls"
"fmt"
"net"
"net/http"
"os"
"os/signal"
"testing"
"time"

"connectrpc.com/connect"
"github.com/go-logr/zapr"
"github.com/jlewi/foyle/protos/go/foyle/v1alpha1"
"github.com/jlewi/foyle/protos/go/foyle/v1alpha1/v1alpha1connect"
"github.com/jlewi/monogo/networking"
"github.com/pkg/errors"
parserv1 "github.com/stateful/runme/v3/pkg/api/gen/proto/go/runme/parser/v1"
"go.uber.org/zap"
"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"
)

type FakeAgent struct {
*v1alpha1connect.UnimplementedAIServiceHandler
numTries int
}

func (f *FakeAgent) GenerateCells(ctx context.Context, req *connect.Request[v1alpha1.GenerateCellsRequest]) (*connect.Response[v1alpha1.GenerateCellsResponse], error) {
f.numTries += 1
if f.numTries < 2 {
return nil, connect.NewError(connect.CodeDeadlineExceeded, errors.New("Deadline exceeded"))
}

resp := &v1alpha1.GenerateCellsResponse{
Cells: []*parserv1.Cell{
{
Kind: parserv1.CellKind_CELL_KIND_MARKUP,
},
},
}
return connect.NewResponse(resp), nil
}

func setupAndRunFakeServer(addr string, a *FakeAgent) (*http.Server, error) {
log := zapr.NewLogger(zap.L())
mux := http.NewServeMux()
path, handler := v1alpha1connect.NewAIServiceHandler(a)
mux.Handle(path, handler)

srv := &http.Server{
Addr: addr,
// NB that we are using h2c here to support HTTP/2 without TLS
// bidirectional streaming requires HTTP/2
Handler: h2c.NewHandler(mux, &http2.Server{}),
}

// Graceful shutdown setup
idleConnsClosed := make(chan struct{})
go func() {
sigint := make(chan os.Signal, 1)
signal.Notify(sigint, os.Interrupt)
<-sigint

log.Info("Shutting down server...")

ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := srv.Shutdown(ctx); err != nil {
log.Info("HTTP server Shutdown: %v", err)
}
close(idleConnsClosed)
}()

go func() {
log.Info("Server starting on ", "address", addr)
if err := srv.ListenAndServe(); err != http.ErrServerClosed {
log.Error(err, "Server ListenAndServe error")
}

<-idleConnsClosed
log.Info("Server stopped")
}()
return srv, nil
}

func Test_RetryInterceptor(t *testing.T) {
port, err := networking.GetFreePort()
if err != nil {
t.Fatalf("Error getting free port: %v", err)
}

addr := fmt.Sprintf("localhost:%d", port)

fake := &FakeAgent{}
srv, err := setupAndRunFakeServer(addr, fake)
if err != nil {
t.Fatalf("Error starting server: %v", err)
}
baseURL := fmt.Sprintf("http://%s", addr)
client := v1alpha1connect.NewAIServiceClient(
&http.Client{
Transport: &http2.Transport{
AllowHTTP: true,
DialTLSContext: func(ctx context.Context, network, addr string, _ *tls.Config) (net.Conn, error) {
// Use the standard Dial function to create a plain TCP connection
return net.Dial(network, addr)
},
},
},
baseURL,
connect.WithInterceptors(&RetryInterceptor{
MaxRetries: 3,
Backoff: 10 * time.Millisecond,
}),
)

// First call should fail but the interceptor should retry
resp, err := client.GenerateCells(context.Background(), connect.NewRequest(&v1alpha1.GenerateCellsRequest{}))
if err != nil {
t.Fatalf("Error calling GenerateCells: %v", err)
}

if len(resp.Msg.Cells) != 1 {
t.Fatalf("Expected 1 cell but got: %v", len(resp.Msg.Cells))
}

if err := srv.Shutdown(context.Background()); err != nil {
t.Logf("Error shutting down server: %v", err)
}
}
9 changes: 8 additions & 1 deletion app/pkg/eval/evaluator.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,14 @@ func (e *Evaluator) Reconcile(ctx context.Context, experiment api.Experiment) er
return errors.Wrapf(err, "Failed to create OpenTelemetry interceptor")
}

aiClient := newAIServiceClient(experiment.Spec.AgentAddress, connect.WithInterceptors(otelInterceptor))
// Handle retries for the AI service.
// This should help with requests ocassionally timing out.
retryer := &agent.RetryInterceptor{
MaxRetries: 3,
Backoff: 5 * time.Second,
}

aiClient := newAIServiceClient(experiment.Spec.AgentAddress, connect.WithInterceptors(otelInterceptor, retryer))

logsClient := logspbconnect.NewLogsServiceClient(
newHTTPClient(),
Expand Down
Loading