Better concurrent request handling for model host address #38

Merged · 6 commits · Jan 11, 2024
2 changes: 2 additions & 0 deletions .github/workflows/integration-tests.yml
@@ -17,5 +17,7 @@ jobs:
         uses: actions/setup-go@v4
         with:
           go-version: '>=1.21.0'
+      - name: Run race tests
+        run: make test-race
       - name: Run integration tests
         run: make test-integration
14 changes: 14 additions & 0 deletions Makefile
@@ -1,6 +1,20 @@
 ENVTEST_K8S_VERSION = 1.27.1
 
+.PHONY: test
+test: test-unit
+
+.PHONY: test-all
+test-all: test-race test-integration
+
+.PHONY: test-unit
+test-unit:
+	go test -mod=readonly ./pkg/...
+
+.PHONY: test-race
+test-race:
+	go test -mod=readonly -race ./pkg/...
+
 .PHONY: test-integration
 test-integration: envtest
 	KUBEBUILDER_ASSETS="$(shell $(ENVTEST) use $(ENVTEST_K8S_VERSION) --bin-dir $(LOCALBIN) -p path)" go test ./tests/integration -v
2 changes: 1 addition & 1 deletion pkg/autoscaler/autoscaler.go
@@ -80,7 +80,7 @@ func (a *Autoscaler) Start() {
 		log.Println("Calculating scales for all")
 
 		// TODO: Remove hardcoded Service lookup by name "lingo".
-		otherLingoEndpoints := a.Endpoints.GetAllHosts(context.Background(), "lingo", "stats")
+		otherLingoEndpoints := a.Endpoints.GetAllHosts("lingo", "stats")
 
 		stats, errs := aggregateStats(stats.Stats{
 			ActiveRequests: a.Queues.TotalCounts(),
68 changes: 50 additions & 18 deletions pkg/endpoints/endpoints.go
@@ -1,6 +1,7 @@
 package endpoints
 
 import (
+	"context"
 	"fmt"
 	"sync"
 	"sync/atomic"
@@ -10,7 +11,7 @@ func newEndpointGroup() *endpointGroup {
 	e := &endpointGroup{}
 	e.ports = make(map[string]int32)
 	e.endpoints = make(map[string]endpoint)
-	e.active = sync.NewCond(&e.mtx)
+	e.bcast = make(chan struct{})
 	return e
 }
 
@@ -19,20 +20,37 @@ type endpoint struct {
 }
 
 type endpointGroup struct {
+	mtx       sync.RWMutex
 	ports     map[string]int32
 	endpoints map[string]endpoint
-	active    *sync.Cond
-	mtx       sync.Mutex
-}
-
-func (e *endpointGroup) getHost(portName string) string {
-	e.mtx.Lock()
-	defer e.mtx.Unlock()
+
+	bmtx  sync.RWMutex
+	bcast chan struct{} // closed when there's a broadcast
+}
+
+// getBestHost returns the best host for the given port name. It blocks until there are available endpoints
+// in the endpoint group.
+//
+// It selects the host with the minimum in-flight requests among all the available endpoints.
+// The host is returned as a string in the format "IP:Port".
+//
+// Parameters:
+// - portName: The name of the port for which the best host needs to be determined.
+//
+// Returns:
+// - string: The best host with the minimum in-flight requests.
+// - error: The context's error if ctx ends before any endpoint becomes available.
+func (e *endpointGroup) getBestHost(ctx context.Context, portName string) (string, error) {
+	e.mtx.RLock()
+	// wait until endpoints exist
 	for len(e.endpoints) == 0 {
-		e.active.Wait()
+		e.mtx.RUnlock()
+		select {
+		case <-e.awaitEndpoints():
+		case <-ctx.Done():
+			return "", ctx.Err()
+		}
+		e.mtx.RLock()
 	}
 
 	var bestIP string
 	port := e.getPort(portName)
 	var minInFlight int
@@ -43,13 +61,19 @@ func (e *endpointGroup) getHost(portName string) string {
 			minInFlight = inFlight
 		}
 	}
+	e.mtx.RUnlock()
+	return fmt.Sprintf("%s:%v", bestIP, port), nil
+}
 
-	return fmt.Sprintf("%s:%v", bestIP, port)
+func (e *endpointGroup) awaitEndpoints() chan struct{} {
+	e.bmtx.RLock()
+	defer e.bmtx.RUnlock()
+	return e.bcast
 }
 
 func (e *endpointGroup) getAllHosts(portName string) []string {
-	e.mtx.Lock()
-	defer e.mtx.Unlock()
+	e.mtx.RLock()
+	defer e.mtx.RUnlock()
 
 	var hosts []string
 	port := e.getPort(portName)
@@ -70,15 +94,13 @@ func (e *endpointGroup) getPort(portName string) int32 {
 }
 
 func (g *endpointGroup) lenIPs() int {
-	g.mtx.Lock()
-	defer g.mtx.Unlock()
+	g.mtx.RLock()
+	defer g.mtx.RUnlock()
 	return len(g.endpoints)
 }
 
 func (g *endpointGroup) setIPs(ips map[string]struct{}, ports map[string]int32) {
 	g.mtx.Lock()
-	defer g.mtx.Unlock()
-
 	g.ports = ports
 	for ip := range ips {
 		if _, ok := g.endpoints[ip]; !ok {
@@ -90,8 +112,18 @@ func (g *endpointGroup) setIPs(ips map[string]struct{}, ports map[string]int32)
 			delete(g.endpoints, ip)
 		}
 	}
+	g.mtx.Unlock()
 
-	if len(g.endpoints) > 0 {
-		g.active.Broadcast()
+	// notify waiting requests
+	if len(ips) > 0 {
+		g.broadcastEndpoints()
 	}
 }
+
+func (g *endpointGroup) broadcastEndpoints() {
+	g.bmtx.Lock()
+	defer g.bmtx.Unlock()
+
+	close(g.bcast)
+	g.bcast = make(chan struct{})
+}
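The crux of the change above is swapping sync.Cond for a broadcast channel: a Cond.Wait cannot be raced against ctx.Done(), whereas a channel receive can sit in a select next to the context. A minimal standalone sketch of the close-and-recreate pattern (notifier, wait, and broadcast are hypothetical names used for illustration; the PR's broadcastEndpoints guards its channel with bmtx in the same way):

package main

import (
	"context"
	"fmt"
	"sync"
	"time"
)

// notifier wakes all waiters at once by closing its channel, then
// installs a fresh channel for the next round of waiters.
type notifier struct {
	mu sync.RWMutex
	ch chan struct{}
}

func newNotifier() *notifier { return &notifier{ch: make(chan struct{})} }

// wait snapshots the current channel; a receive on it unblocks
// once broadcast closes that channel.
func (n *notifier) wait() <-chan struct{} {
	n.mu.RLock()
	defer n.mu.RUnlock()
	return n.ch
}

// broadcast wakes every waiter holding the old channel.
func (n *notifier) broadcast() {
	n.mu.Lock()
	defer n.mu.Unlock()
	close(n.ch)
	n.ch = make(chan struct{})
}

func main() {
	n := newNotifier()
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()

	go func() {
		time.Sleep(10 * time.Millisecond)
		n.broadcast() // e.g., endpoints just arrived
	}()

	// Unlike sync.Cond.Wait, this wait can be abandoned via the context.
	select {
	case <-n.wait():
		fmt.Println("woken by broadcast")
	case <-ctx.Done():
		fmt.Println("gave up:", ctx.Err())
	}
}

Closing a channel wakes every pending receiver at once, and immediately replacing it re-arms the notifier for the next wait cycle.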
117 changes: 117 additions & 0 deletions pkg/endpoints/endpoints_test.go
@@ -0,0 +1,117 @@
package endpoints

import (
	"context"
	"sync"
	"sync/atomic"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	"k8s.io/apimachinery/pkg/util/rand"
)

func TestConcurrentAccess(t *testing.T) {
	const myService = "myService"
	const myPort = "myPort"

	testCases := map[string]struct {
		readerCount int
		writerCount int
	}{
		"lot of reader": {readerCount: 1_000, writerCount: 1},
		"lot of writer": {readerCount: 1, writerCount: 1_000},
		"lot of both":   {readerCount: 1_000, writerCount: 1_000},
	}
	for name, spec := range testCases {
		randomReadFn := []func(g *endpointGroup){
			func(g *endpointGroup) { g.getBestHost(context.Background(), myPort) },
			func(g *endpointGroup) { g.getAllHosts(myPort) },
			func(g *endpointGroup) { g.lenIPs() },
		}
		t.Run(name, func(t *testing.T) {
			// set up an endpoint with one service so that requests are not waiting
			endpoint := newEndpointGroup()
			endpoint.setIPs(
				map[string]struct{}{myService: {}},
				map[string]int32{myPort: 1},
			)

			var startWg, doneWg sync.WaitGroup
			startTogether := func(n int, f func()) {
				startWg.Add(n)
				doneWg.Add(n)
				for i := 0; i < n; i++ {
					go func() {
						startWg.Done()
						startWg.Wait()
						f()
						doneWg.Done()
					}()
				}
			}
			// when
			startTogether(spec.readerCount, func() { randomReadFn[rand.Intn(len(randomReadFn))](endpoint) })
			startTogether(spec.writerCount, func() {
				endpoint.setIPs(
					map[string]struct{}{rand.String(1): {}},
					map[string]int32{rand.String(1): 1},
				)
			})
			doneWg.Wait()
		})
	}
}

func TestBlockAndWaitForEndpoints(t *testing.T) {
	var completed atomic.Int32
	var startWg, doneWg sync.WaitGroup
	startTogether := func(n int, f func()) {
		startWg.Add(n)
		doneWg.Add(n)
		for i := 0; i < n; i++ {
			go func() {
				startWg.Done()
				startWg.Wait()
				f()
				completed.Add(1)
				doneWg.Done()
			}()
		}
	}
	endpoint := newEndpointGroup()
	ctx := context.TODO()
	startTogether(100, func() {
		endpoint.getBestHost(ctx, rand.String(4))
	})
	startWg.Wait()

	// when broadcast triggered
	endpoint.setIPs(
		map[string]struct{}{rand.String(4): {}},
		map[string]int32{rand.String(4): 1},
	)
	// then
	doneWg.Wait()
	assert.Equal(t, int32(100), completed.Load())
}

func TestAbortOnCtxCancel(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())

	var startWg, doneWg sync.WaitGroup
	startWg.Add(1)
	doneWg.Add(1)
	go func(t *testing.T) {
		startWg.Wait()
		endpoint := newEndpointGroup()
		_, err := endpoint.getBestHost(ctx, rand.String(4))
		require.Error(t, err)
		doneWg.Done()
	}(t)
	startWg.Done()
	cancel()

	doneWg.Wait()
}
20 changes: 20 additions & 0 deletions pkg/endpoints/endponts_bench_test.go
@@ -0,0 +1,20 @@
package endpoints

import (
	"context"
	"testing"
)

func BenchmarkEndpointGroup(b *testing.B) {
	e := newEndpointGroup()
	e.setIPs(map[string]struct{}{"10.0.0.1": {}}, map[string]int32{"testPort": 1})
	b.ResetTimer()
	b.RunParallel(func(pb *testing.PB) {
		for pb.Next() {
			_, err := e.getBestHost(context.Background(), "testPort")
			if err != nil {
				b.Fatal(err)
			}
		}
	})
}
11 changes: 8 additions & 3 deletions pkg/endpoints/manager.go
@@ -110,10 +110,15 @@ func (r *Manager) getEndpoints(service string) *endpointGroup {
 	return e
 }
 
-func (r *Manager) GetHost(ctx context.Context, service, portName string) string {
-	return r.getEndpoints(service).getHost(portName)
+// AwaitHostAddress returns the host address with the lowest number of in-flight requests.
+// It blocks until a host address becomes available or the context ends.
+//
+// It returns a string in the format "host:port" or an error on timeout.
+func (r *Manager) AwaitHostAddress(ctx context.Context, service, portName string) (string, error) {
+	return r.getEndpoints(service).getBestHost(ctx, portName)
 }
 
-func (r *Manager) GetAllHosts(ctx context.Context, service, portName string) []string {
+// GetAllHosts retrieves the list of all hosts for a given service and port.
+func (r *Manager) GetAllHosts(service, portName string) []string {
 	return r.getEndpoints(service).getAllHosts(portName)
 }
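Since AwaitHostAddress now blocks, callers are expected to pass a bounded context. A hedged caller-side sketch (the handler shape, the hostAwaiter interface, the "lingo"/"http" names, and the 30-second budget are assumptions for illustration, not part of this PR):

package example

import (
	"context"
	"fmt"
	"net/http"
	"time"
)

// hostAwaiter mirrors the Manager method introduced in this PR.
type hostAwaiter interface {
	AwaitHostAddress(ctx context.Context, service, portName string) (string, error)
}

// handle bounds the wait with the request's context so a request that
// arrives while zero endpoints are ready eventually gives up cleanly.
func handle(w http.ResponseWriter, r *http.Request, mgr hostAwaiter) {
	ctx, cancel := context.WithTimeout(r.Context(), 30*time.Second)
	defer cancel()

	// Blocks until an endpoint is available or the context ends.
	host, err := mgr.AwaitHostAddress(ctx, "lingo", "http")
	if err != nil {
		// Typically context.DeadlineExceeded or context.Canceled.
		http.Error(w, "no backend available", http.StatusServiceUnavailable)
		return
	}
	fmt.Fprintf(w, "routing to %s\n", host) // host is "IP:Port"
}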
59 changes: 59 additions & 0 deletions pkg/endpoints/manager_test.go
@@ -0,0 +1,59 @@
package endpoints

import (
	"context"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func TestAwaitBestHost(t *testing.T) {
	const myService = "myService"
	const myPort = "myPort"

	manager := &Manager{endpoints: make(map[string]*endpointGroup, 1)}
	manager.getEndpoints(myService).
		setIPs(map[string]struct{}{myService: {}}, map[string]int32{myPort: 1})

[Review comment, Contributor]: I think these would be better written as individual tests instead of a table of test cases. They each test different things. For example: for the timeout example, it would be good to assert that the returned error is due to context cancellation, and this code would only be used for that test case.

[Reply, Contributor Author]: I agree the error check was a bit vague, but IMHO it makes sense to have a spec for the methods that defines all cases. I find it more readable. To be fair, I use table tests as my default structure for unit tests and may be biased. If this is very important to you, I can refactor. The error type is checked now.

[Reply, Contributor]: Not a strong opinion, good with this.

A sketch of that individual-test style follows this file's diff.

	testCases := map[string]struct {
		service  string
		portName string
		timeout  time.Duration
		expErr   error
	}{
		"all good": {
			service:  myService,
			portName: myPort,
			timeout:  time.Millisecond,
		},
		"unknown port - returns default if only 1": {
			service:  myService,
			portName: "unknownPort",
			timeout:  time.Millisecond,
		},
		"unknown service - blocks until timeout": {
			service:  "unknownService",
			portName: myPort,
			timeout:  time.Millisecond,
			expErr:   context.DeadlineExceeded,
		},
		// not covered: unknown port with multiple ports on entrypoint
	}

	for name, spec := range testCases {
		t.Run(name, func(t *testing.T) {
			ctx, cancel := context.WithTimeout(context.Background(), spec.timeout)
			defer cancel()

			gotHost, gotErr := manager.AwaitHostAddress(ctx, spec.service, spec.portName)
			if spec.expErr != nil {
				require.ErrorIs(t, gotErr, spec.expErr)
				return
			}
			require.NoError(t, gotErr)
			assert.Equal(t, myService+":1", gotHost)
		})
	}
}
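For comparison, the individual-test style suggested in the review thread above might look like the following sketch (not part of the merged PR); it pins the timeout case to the exact context error:

func TestAwaitHostAddressTimeout(t *testing.T) {
	// No endpoints are ever registered for this service, so the call
	// must block until the context deadline fires.
	manager := &Manager{endpoints: make(map[string]*endpointGroup, 1)}

	ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
	defer cancel()

	_, err := manager.AwaitHostAddress(ctx, "unknownService", "myPort")
	require.ErrorIs(t, err, context.DeadlineExceeded)
}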