Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better concurrent request handling for model host address #38

Merged
merged 6 commits into from
Jan 11, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion pkg/autoscaler/autoscaler.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,11 @@ func (a *Autoscaler) Start() {
log.Println("Calculating scales for all")

// TODO: Remove hardcoded Service lookup by name "lingo".
otherLingoEndpoints := a.Endpoints.GetAllHosts(context.Background(), "lingo", "stats")
otherLingoEndpoints, err := a.Endpoints.GetAllHosts(context.Background(), "lingo", "stats")
if err != nil {
log.Printf("Failed to find endpoints: %v", err)
continue
}

stats, errs := aggregateStats(stats.Stats{
ActiveRequests: a.Queues.TotalCounts(),
Expand Down
56 changes: 38 additions & 18 deletions pkg/endpoints/endpoints.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ func newEndpointGroup() *endpointGroup {
e := &endpointGroup{}
e.ports = make(map[string]int32)
e.endpoints = make(map[string]endpoint)
e.active = sync.NewCond(&e.mtx)
e.active = sync.NewCond(&sync.Mutex{})
return e
}

Expand All @@ -19,19 +19,27 @@ type endpoint struct {
}

type endpointGroup struct {
mtx sync.RWMutex
ports map[string]int32
endpoints map[string]endpoint
active *sync.Cond
mtx sync.Mutex

active *sync.Cond
}

func (e *endpointGroup) getHost(portName string) string {
e.mtx.Lock()
defer e.mtx.Unlock()

for len(e.endpoints) == 0 {
e.active.Wait()
}
// getBestHost returns the best host for the given port name. It blocks until there are available endpoints
// in the endpoint group.
//
// It selects the host with the minimum in-flight requests among all the available endpoints.
// The host is returned as a string in the format "IP:Port".
//
// Parameters:
// - portName: The name of the port for which the best host needs to be determined.
//
// Returns:
// - string: The best host with the minimum in-flight requests.
func (e *endpointGroup) getBestHost(portName string) string {
alpe marked this conversation as resolved.
Show resolved Hide resolved
e.mtx.RLock()
e.awaitAnyEndpointsExist()

var bestIP string
port := e.getPort(portName)
Expand All @@ -43,13 +51,25 @@ func (e *endpointGroup) getHost(portName string) string {
minInFlight = inFlight
}
}

e.mtx.RUnlock()
return fmt.Sprintf("%s:%v", bestIP, port)
}

// awaitAnyEndpointsExist blocks until e.endpoints contains at least one entry.
//
// It expects e.mtx to be read-locked by the caller on entry and returns with
// e.mtx read-locked again. If the map is already non-empty it returns without
// touching any lock.
//
// NOTE(review): e.active.L is acquired only after e.mtx.RUnlock(). A Broadcast
// that fires in between is not seen by this waiter, which then stays parked
// until the next Broadcast. Confirm whether the notifier can run in that
// window.
func (e *endpointGroup) awaitAnyEndpointsExist() {
for len(e.endpoints) == 0 {
// Release the read lock so that a writer can populate e.endpoints while we wait.
e.mtx.RUnlock()
// Wait requires its locker to be held: it unlocks e.active.L while parked
// and re-locks it before returning, hence the Lock/Unlock pair around it.
e.active.L.Lock()
e.active.Wait()
e.active.L.Unlock()
// Re-acquire the read lock before the loop condition reads e.endpoints again.
e.mtx.RLock()
}
}

func (e *endpointGroup) getAllHosts(portName string) []string {
e.mtx.Lock()
defer e.mtx.Unlock()
e.mtx.RLock()
defer e.mtx.RUnlock()

var hosts []string
port := e.getPort(portName)
Expand All @@ -70,15 +90,13 @@ func (e *endpointGroup) getPort(portName string) int32 {
}

func (g *endpointGroup) lenIPs() int {
g.mtx.Lock()
defer g.mtx.Unlock()
g.mtx.RLock()
defer g.mtx.RUnlock()
return len(g.endpoints)
}

func (g *endpointGroup) setIPs(ips map[string]struct{}, ports map[string]int32) {
g.mtx.Lock()
defer g.mtx.Unlock()

g.ports = ports
for ip := range ips {
if _, ok := g.endpoints[ip]; !ok {
Expand All @@ -90,8 +108,10 @@ func (g *endpointGroup) setIPs(ips map[string]struct{}, ports map[string]int32)
delete(g.endpoints, ip)
}
}
g.mtx.Unlock()

if len(g.endpoints) > 0 {
// notify waiting requests
if len(ips) > 0 {
g.active.Broadcast()
}
}
54 changes: 54 additions & 0 deletions pkg/endpoints/endpoints_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package endpoints

import (
"sync"
"testing"

"k8s.io/apimachinery/pkg/util/rand"
)

// TestConcurrentAccess runs many concurrent getBestHost readers and setIPs
// writers against a single endpointGroup. It has no assertions of its own; it
// fails on a deadlock, a panic, or (when run with -race) a data race.
func TestConcurrentAccess(t *testing.T) {
	const myService = "myService"
	const myPort = "myPort"

	testCases := map[string]struct {
		readerCount int
		writerCount int
	}{
		"lot of reader": {readerCount: 10_000, writerCount: 1},
		"lot of writer": {readerCount: 1, writerCount: 10_000},
		"lot of both":   {readerCount: 10_000, writerCount: 10_000},
	}

	for name, spec := range testCases {
		t.Run(name, func(t *testing.T) {
			endpoint := newEndpointGroup()
			endpoint.setIPs(
				map[string]struct{}{myService: {}},
				map[string]int32{myPort: 1},
			)

			// Register every goroutine before the first one is launched.
			// Calling Add again for the writers while readers are still
			// returning from startWg.Wait() is WaitGroup misuse and can panic;
			// it would also let the readers run before any writer exists.
			total := spec.readerCount + spec.writerCount
			var startWg, doneWg sync.WaitGroup
			startWg.Add(total)
			doneWg.Add(total)

			// startTogether launches n goroutines that each check in on the
			// start barrier, wait for all readers and writers to be ready,
			// and then run f.
			startTogether := func(n int, f func()) {
				for i := 0; i < n; i++ {
					go func() {
						// Deferred so that doneWg is released even if f panics.
						defer doneWg.Done()
						startWg.Done()
						startWg.Wait()
						f()
					}()
				}
			}
			startTogether(spec.readerCount, func() { endpoint.getBestHost(myPort) })
			startTogether(spec.writerCount, func() {
				endpoint.setIPs(
					map[string]struct{}{rand.String(1): {}},
					map[string]int32{rand.String(1): 1},
				)
			})
			doneWg.Wait()
		})
	}
}
29 changes: 25 additions & 4 deletions pkg/endpoints/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,10 +110,31 @@ func (r *Manager) getEndpoints(service string) *endpointGroup {
return e
}

func (r *Manager) GetHost(ctx context.Context, service, portName string) string {
return r.getEndpoints(service).getHost(portName)
// AwaitHostAddress returns the host address with the lowest number of in-flight requests. It will block until the host address
// becomes available or the context times out.
//
// It returns a string in the format "host:port" or error on timeout
func (r *Manager) AwaitHostAddress(ctx context.Context, service, portName string) (string, error) {
return execWithTimeout(ctx, func() string { return r.getEndpoints(service).getBestHost(portName) })
}

func (r *Manager) GetAllHosts(ctx context.Context, service, portName string) []string {
return r.getEndpoints(service).getAllHosts(portName)
// GetAllHosts retrieves the list of all hosts for a given service and port.
// It returns a slice of strings representing the hosts or an error if any context timeout occurred.
func (r *Manager) GetAllHosts(ctx context.Context, service, portName string) ([]string, error) {
return execWithTimeout(ctx, func() []string { return r.getEndpoints(service).getAllHosts(portName) })
}

// execWithTimeout runs f in its own goroutine and waits for whichever happens
// first: f returning or ctx being done.
//
// It returns f's result and a nil error when f finishes first. When ctx is
// done first it returns the zero value of T together with ctx.Err().
//
// f is not interrupted when ctx is done; its goroutine keeps running until f
// returns on its own. The result channel is buffered and never closed so that
// this late send neither blocks forever nor panics on a closed channel.
func execWithTimeout[T any](ctx context.Context, f func() T) (T, error) {
	// Capacity 1 lets the worker hand over its result and exit even after the
	// caller has already given up on it.
	resultChan := make(chan T, 1)
	go func() {
		resultChan <- f()
	}()

	select {
	case <-ctx.Done():
		var zero T
		return zero, ctx.Err()
	case result := <-resultChan:
		return result, nil
	}
}
59 changes: 59 additions & 0 deletions pkg/endpoints/manager_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package endpoints

import (
"context"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestAwaitBestHost(t *testing.T) {
const myService = "myService"
const myPort = "myPort"

manager := &Manager{endpoints: make(map[string]*endpointGroup, 1)}
manager.getEndpoints(myService).
setIPs(map[string]struct{}{myService: {}}, map[string]int32{myPort: 1})

testCases := map[string]struct {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think these would be better written as individual tests instead of a table of test cases. They each test different things. For example: for the timeout example it would be good to assert that the returned error is due to context cancellation, and this code would only be used for that test case.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree the error check is a bit vague, but IMHO it makes sense to have a spec for the methods that defines all cases. I find it more readable.
But to be fair, I use table tests as my default structure for unit tests and may be biased. If this is very important to you, I can refactor. The error type is checked now.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not a strong opinion, good with this.

service string
portName string
timeout time.Duration
expErr bool
}{
"all good": {
service: myService,
portName: myPort,
timeout: time.Millisecond,
},
"unknown port - returns default if only 1": {
service: myService,
portName: "unknownPort",
timeout: time.Millisecond,
},
"unknown service - blocks until timeout": {
service: "unknownService",
portName: myPort,
timeout: time.Millisecond,
expErr: true,
},
// not covered: unknown port with multiple ports on entrypoint
}

for name, spec := range testCases {
t.Run(name, func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), spec.timeout)
defer cancel()

gotHost, gotErr := manager.AwaitHostAddress(ctx, spec.service, spec.portName)
if spec.expErr {
require.Error(t, gotErr)
return
}
require.NoError(t, gotErr)
assert.Equal(t, myService+":1", gotHost)
})
}
}
8 changes: 7 additions & 1 deletion pkg/proxy/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,13 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
defer complete()

log.Println("Waiting for IPs", id)
host := h.Endpoints.GetHost(r.Context(), deploy, "http")
host, err := h.Endpoints.AwaitHostAddress(r.Context(), deploy, "http")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might be more robust to check this error instead of assuming it is a timeout error. In the future the error logic in the invoked function might be updated to return different error types, but this call site might not be reconsidered. Also, it is not always a timeout today: if the caller cancels the request, the context will be cancelled (not technically a timeout).

if err != nil {
log.Printf("timeout finding the host address %v", err)
w.WriteHeader(http.StatusRequestTimeout)
w.Write([]byte(fmt.Sprintf("Request timed out for model: %v", modelName)))
return
}
log.Printf("Got host: %v, id: %v\n", host, id)

// TODO: Avoid creating new reverse proxies for each request.
Expand Down