Skip to content

Commit

Permalink
Apply review feedback; refactorings
Browse files Browse the repository at this point in the history
  • Loading branch information
alpe committed Jan 15, 2024
1 parent 6001e97 commit 9b99c6c
Show file tree
Hide file tree
Showing 6 changed files with 100 additions and 60 deletions.
13 changes: 2 additions & 11 deletions pkg/endpoints/endpoints.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,28 +132,19 @@ func (g *endpointGroup) broadcastEndpoints() {
g.bcast = make(chan struct{})
}

func (e *endpointGroup) AddInflight(addr string, cancelRequest context.CancelFunc) (func(), error) {
func (e *endpointGroup) AddInflight(addr string) (func(), error) {
tokens := strings.Split(addr, ":")
if len(tokens) != 2 {
return nil, errors.New("unsupported address format")
}
e.mtx.RLock()
defer e.mtx.RUnlock()
endpoint, ok := e.endpoints[tokens[0]]
e.mtx.RUnlock()
if !ok {
return nil, errors.New("unsupported endpoint address")
}
endpoint.inFlight.Add(1)
done := make(chan struct{})
go func() {
select {
case <-done:
case <-endpoint.terminated:
cancelRequest()
}
}()
return func() {
close(done)
endpoint.inFlight.Add(-1)
}, nil
}
6 changes: 2 additions & 4 deletions pkg/endpoints/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,6 @@ func (r *Manager) GetAllHosts(service, portName string) []string {
return r.getEndpoints(service).getAllHosts(portName)
}

func (r *Manager) RegisterInFlight(ctx context.Context, service string, hostAddr string) (context.Context, func(), error) {
ctx, cancel := context.WithCancel(ctx)
completed, err := r.getEndpoints(service).AddInflight(hostAddr, cancel)
return ctx, completed, err
func (r *Manager) RegisterInFlight(service string, hostAddr string) (func(), error) {
return r.getEndpoints(service).AddInflight(hostAddr)
}
19 changes: 13 additions & 6 deletions pkg/proxy/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,25 +109,32 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
log.Printf("Got host: %v, id: %v\n", host, id)

done, err := h.Endpoints.RegisterInFlight(deploy, host)
if err != nil {
log.Printf("error registering in-flight request: %v", err)
w.WriteHeader(http.StatusInternalServerError)
return
}
defer done()

log.Printf("Proxying request to host %v: %v\n", host, id)
// TODO: Avoid creating new reverse proxies for each request.
// TODO: Consider implementing a round robin scheme.
log.Printf("Proxying request to host %v: %v\n", host, id)
middleware := withCancelDeadTargets(h.Endpoints, deploy, host)
middleware(newReverseProxy(host)).ServeHTTP(w, proxyRequest)
newReverseProxy(host).ServeHTTP(w, proxyRequest)
}

func withCancelDeadTargets(endpoints *endpoints.Manager, deploy string, host string) func(other http.Handler) http.HandlerFunc {
func withInflightCounted(endpoints *endpoints.Manager, deploy string, host string) func(other http.Handler) http.HandlerFunc {
return func(other http.Handler) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
newCtx, done, err := endpoints.RegisterInFlight(r.Context(), deploy, host)
done, err := endpoints.RegisterInFlight(deploy, host)
if err != nil {
log.Printf("error registering in-flight request: %v", err)
w.WriteHeader(http.StatusInternalServerError)
return
}
defer done()

other.ServeHTTP(w, r.Clone(newCtx))
other.ServeHTTP(w, r)
}
}
}
Expand Down
14 changes: 3 additions & 11 deletions pkg/proxy/handler_test.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
package proxy

import (
"context"
"net/http"
"net/http/httptest"
"net/http/httputil"
"net/url"
"strings"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand All @@ -23,7 +21,7 @@ func TestProxy(t *testing.T) {
}{
"ok": {
request: httptest.NewRequest(http.MethodGet, "/", strings.NewReader(`{"model":"my_model"}`)),
expCode: http.StatusBadGateway,
expCode: http.StatusOK,
},
}
for name, spec := range specs {
Expand All @@ -41,8 +39,7 @@ func TestProxy(t *testing.T) {

svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
em.SetEndpoints("my-deployment", map[string]struct{}{"my-other-ip": {}}, map[string]int32{"my-other-port": 8080})
time.Sleep(time.Millisecond)
w.WriteHeader(999)
w.WriteHeader(http.StatusOK)
}))
recorder := httptest.NewRecorder()

Expand All @@ -51,12 +48,7 @@ func TestProxy(t *testing.T) {
}

// when
// newCtx, cancel := context.WithCancel(spec.request.Context())
// cancel()
// newCtx, _ := context.WithTimeout(spec.request.Context(), time.Nanosecond)
newCtx := context.Background()

h.ServeHTTP(recorder, spec.request.Clone(newCtx))
h.ServeHTTP(recorder, spec.request)
// then
assert.Equal(t, spec.expCode, recorder.Code)
})
Expand Down
102 changes: 74 additions & 28 deletions pkg/proxy/middleware.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
package proxy

import (
"bytes"
"log"
"io"
"math/rand"
"net/http"
"time"
Expand All @@ -28,50 +27,97 @@ func NewRetryMiddleware(maxRetries int, other http.Handler) *RetryMiddleware {
}

func (r RetryMiddleware) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
var capturedResp *responseBuffer
var capturedResp *responseWriterDelegator
for i := 0; ; i++ {
capturedResp = &responseBuffer{
header: make(http.Header),
body: bytes.NewBuffer([]byte{}),
capturedResp = &responseWriterDelegator{
ResponseWriter: writer,
headerBuf: make(http.Header),
discardErrResp: i < r.MaxRetries &&
request.Context().Err() == nil, // abort early on timeout, context cancel
}
// call next handler in chain
r.nextHandler.ServeHTTP(capturedResp, request.Clone(request.Context()))

if i == r.MaxRetries || // max retries reached
request.Context().Err() != nil || // abort early on timeout, context cancel
capturedResp.status != http.StatusBadGateway &&
capturedResp.status != http.StatusServiceUnavailable {
if !capturedResp.discardErrResp || // max retries reached
!isRetryableStatusCode(capturedResp.statusCode) {
break
}
totalRetries.Inc()
// Exponential backoff
jitter := time.Duration(r.rSource.Intn(50))
time.Sleep(time.Millisecond*time.Duration(1<<uint(i)) + jitter)
}
for k, v := range capturedResp.header {
writer.Header()[k] = v
}
writer.WriteHeader(capturedResp.status)
if _, err := capturedResp.body.WriteTo(writer); err != nil {
log.Printf("response write: %v", err)
}
}

type responseBuffer struct {
header http.Header
body *bytes.Buffer
status int
func isRetryableStatusCode(status int) bool {
return status == http.StatusBadGateway ||
status == http.StatusServiceUnavailable ||
status == http.StatusGatewayTimeout
}

var (
_ http.Flusher = &responseWriterDelegator{}
_ io.ReaderFrom = &responseWriterDelegator{}
)

type responseWriterDelegator struct {
http.ResponseWriter
headerBuf http.Header
wroteHeader bool
statusCode int
// always writes to responseWriter when false
discardErrResp bool
}

func (rb *responseBuffer) Header() http.Header {
return rb.header
func (r *responseWriterDelegator) Header() http.Header {
return r.headerBuf
}

func (r *responseBuffer) WriteHeader(status int) {
r.status = status
r.header = r.Header().Clone()
func (r *responseWriterDelegator) WriteHeader(status int) {
r.statusCode = status
if !r.wroteHeader {
r.wroteHeader = true
// any 1xx informational response should be written
r.discardErrResp = r.discardErrResp && !(status >= 100 && status < 200)
}
if r.discardErrResp && isRetryableStatusCode(status) {
return
}
// copy header values to target
for k, vals := range r.headerBuf {
for _, val := range vals {
r.ResponseWriter.Header().Add(k, val)
}
}
r.ResponseWriter.WriteHeader(status)
}

func (rb *responseBuffer) Write(data []byte) (int, error) {
return rb.body.Write(data)
func (r *responseWriterDelegator) Write(data []byte) (int, error) {
// ensure header is set. default is 200 in Go stdlib
if !r.wroteHeader {
r.WriteHeader(http.StatusOK)
}
if r.discardErrResp && isRetryableStatusCode(r.statusCode) {
return io.Discard.Write(data)
} else {
return r.ResponseWriter.Write(data)
}
}

func (r *responseWriterDelegator) ReadFrom(re io.Reader) (int64, error) {
// ensure header is set. default is 200 in Go stdlib
if !r.wroteHeader {
r.WriteHeader(http.StatusOK)
}
if r.discardErrResp && isRetryableStatusCode(r.statusCode) {
return io.Copy(io.Discard, re)
} else {
return r.ResponseWriter.(io.ReaderFrom).ReadFrom(re)
}
}

func (r *responseWriterDelegator) Flush() {
if f, ok := r.ResponseWriter.(http.Flusher); ok {
f.Flush()
}
}
6 changes: 6 additions & 0 deletions pkg/proxy/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,12 @@ func TestServeHTTP(t *testing.T) {
respStatus: http.StatusBadGateway,
expRetries: 3,
},
"not buffered on 100": {
context: func() context.Context { return context.TODO() },
maxRetries: 3,
respStatus: http.StatusContinue,
expRetries: 0,
},
"context cancelled": {
context: func() context.Context {
ctx, cancel := context.WithCancel(context.TODO())
Expand Down

0 comments on commit 9b99c6c

Please sign in to comment.