Commit: Checkpoint

nstogner committed Dec 3, 2024
1 parent 5726cb2 commit 4cf6ecf
Showing 17 changed files with 335 additions and 233 deletions.
27 changes: 27 additions & 0 deletions api/v1/model_types.go
@@ -115,6 +115,10 @@ type ModelSpec struct {
// DEPRECATED.
// +kubebuilder:validation:Optional
Owner string `json:"owner"`

// LoadBalancing configuration for the model.
// If not specified, a default is used based on the engine and request.
LoadBalancing *LoadBalancing `json:"loadBalancing,omitempty"`
}

// +kubebuilder:validation:Enum=TextGeneration;TextEmbedding;SpeechToText
@@ -144,6 +148,29 @@ type Adapter struct {
URL string `json:"url"`
}

type LoadBalancing struct {
Strategy LoadBalancingStrategy `json:"strategy"`
CHWBL *CHWBL `json:"chwbl,omitempty"`
}

// +kubebuilder:validation:Enum=LeastLoad;CHWBL
type LoadBalancingStrategy string

const (
LeastLoadStrategy LoadBalancingStrategy = "LeastLoad"
CHWBLStrategy LoadBalancingStrategy = "CHWBL"
)

type CHWBL struct {
// MeanLoadFactor is the multiple of the mean load (across all endpoints
// in the hash ring) that any given endpoint's load must not exceed.
MeanLoadFactor float64 `json:"meanLoadFactor"`
// Replication is the number of replicas of each endpoint on the hash ring.
// Higher values will result in a more even distribution of load but will
// decrease lookup performance.
Replication int `json:"replication"`
}

// ModelStatus defines the observed state of Model.
type ModelStatus struct {
Replicas ModelStatusReplicas `json:"replicas,omitempty"`
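For illustration (this snippet is not part of the commit's diff): a minimal Go sketch of a Model opting into the new CHWBL strategy using the types added above. The numeric values are hypothetical, not defaults from this change, and the import path is written as it appears in the diff.

package main

import (
	"fmt"

	v1 "github.com/substratusai/kubeai/api/v1"
)

func main() {
	// Hypothetical example: opt a Model into CHWBL-based load balancing.
	spec := v1.ModelSpec{
		LoadBalancing: &v1.LoadBalancing{
			Strategy: v1.CHWBLStrategy,
			CHWBL: &v1.CHWBL{
				MeanLoadFactor: 1.25, // illustrative: no endpoint may exceed 1.25x the mean load
				Replication:    100,  // illustrative: 100 virtual nodes per endpoint on the hash ring
			},
		},
	}
	fmt.Printf("strategy=%s\n", spec.LoadBalancing.Strategy)
}

If the field is left nil, the comment on ModelSpec above says a default is chosen based on the engine and request.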
40 changes: 40 additions & 0 deletions api/v1/zz_generated.deepcopy.go

Some generated files are not rendered by default.

@@ -1,4 +1,4 @@
package endpoints
package loadbalancer

import (
"fmt"
@@ -1,4 +1,4 @@
package endpoints
package loadbalancer

func (g *group) getAddrLeastLoad(adapter string) (endpoint, bool) {
var bestEp endpoint
20 changes: 11 additions & 9 deletions internal/endpoints/group.go → internal/loadbalancer/group.go
@@ -1,11 +1,13 @@
package endpoints
package loadbalancer

import (
"context"
"fmt"
"log"
"sync"
"sync/atomic"

v1 "github.com/substratusai/kubeai/api/v1"
)

func newEndpointGroup() *group {
@@ -49,7 +51,7 @@ const (

// getBestAddr returns the best "IP:Port". It blocks until there are available endpoints
// in the endpoint group.
func (g *group) getBestAddr(ctx context.Context, strategy LoadBalancingStrategy, adapter string, awaitChangeEndpoints bool) (string, func(), error) {
func (g *group) getBestAddr(ctx context.Context, req AddressRequest, awaitChangeEndpoints bool) (string, func(), error) {
g.mtx.RLock()
// wait until endpoints exist
for awaitChangeEndpoints || len(g.endpoints) == 0 {
@@ -64,19 +66,19 @@ func (g *group) getBestAddr(ctx context.Context, strategy LoadBalancingStrategy,

var ep endpoint
var found bool
switch strategy {
case CHWBL:
switch req.Strategy {
case v1.CHWBLStrategy:
// TODO: prefix
ep, found = g.chwbl.getAddr(adapter + prefix)
case LeastLoaded:
ep, found = g.getAddrLeastLoad(adapter)
ep, found = g.chwbl.getAddr(req.Adapter + req.Prefix)
case v1.LeastLoadStrategy:
ep, found = g.getAddrLeastLoad(req.Adapter)
default:
return "", func() {}, fmt.Errorf("unknown load balancing strategy: %v", strategy)
return "", func() {}, fmt.Errorf("unknown load balancing strategy: %v", req.Strategy)
}

if !found {
g.mtx.RUnlock()
return g.getBestAddr(ctx, strategy, adapter, true)
return g.getBestAddr(ctx, req, true)
}

g.addInFlight(ep.inFlight, 1)
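The CHWBL lookup that g.chwbl.getAddr dispatches to is not included in the hunks shown here. As a rough, self-contained sketch of the underlying idea (consistent hashing with bounded loads, not the repository's implementation; the ring, add, and get names are illustrative), a lookup hashes the request key onto a ring of virtual nodes and walks clockwise past any endpoint whose load already exceeds MeanLoadFactor times the mean:

package main

import (
	"fmt"
	"hash/fnv"
	"math"
	"sort"
)

// ring is a toy consistent-hash ring with bounded loads, showing how
// CHWBL.Replication (virtual nodes) and CHWBL.MeanLoadFactor (load bound)
// could be applied. It is not the code from this commit.
type ring struct {
	hashes     []uint64          // sorted virtual-node hashes
	owner      map[uint64]string // virtual-node hash -> endpoint address
	load       map[string]int    // in-flight requests per endpoint
	loadFactor float64           // plays the role of CHWBL.MeanLoadFactor
}

func hashKey(s string) uint64 {
	h := fnv.New64a()
	h.Write([]byte(s))
	return h.Sum64()
}

// add places `replication` virtual nodes for addr on the ring
// (plays the role of CHWBL.Replication).
func (r *ring) add(addr string, replication int) {
	if _, ok := r.load[addr]; !ok {
		r.load[addr] = 0
	}
	for i := 0; i < replication; i++ {
		h := hashKey(fmt.Sprintf("%s-%d", addr, i))
		r.hashes = append(r.hashes, h)
		r.owner[h] = addr
	}
	sort.Slice(r.hashes, func(i, j int) bool { return r.hashes[i] < r.hashes[j] })
}

// get hashes the key onto the ring and walks clockwise until it finds an
// endpoint whose load, counting this request, stays within
// loadFactor * mean load.
func (r *ring) get(key string) (string, bool) {
	if len(r.hashes) == 0 {
		return "", false
	}
	var total float64
	for _, l := range r.load {
		total += float64(l)
	}
	bound := math.Ceil(r.loadFactor * (total + 1) / float64(len(r.load)))

	h := hashKey(key)
	start := sort.Search(len(r.hashes), func(i int) bool { return r.hashes[i] >= h })
	for i := 0; i < len(r.hashes); i++ {
		addr := r.owner[r.hashes[(start+i)%len(r.hashes)]]
		if float64(r.load[addr])+1 <= bound {
			r.load[addr]++ // account for the new in-flight request
			return addr, true
		}
	}
	return "", false
}

func main() {
	r := &ring{owner: map[uint64]string{}, load: map[string]int{}, loadFactor: 1.25}
	r.add("10.0.0.1:8000", 100) // 100 virtual nodes, like Replication: 100
	r.add("10.0.0.2:8000", 100)
	for i := 0; i < 4; i++ {
		addr, ok := r.get("my-adapter/prefix")
		fmt.Println(addr, ok)
	}
}

Repeated requests for the same key stick to one endpoint until its load exceeds the bound, then spill over to the next endpoint on the ring, which is the behavior the MeanLoadFactor and Replication fields above are meant to tune.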
@@ -1,4 +1,4 @@
package endpoints
package loadbalancer

import (
"context"
@@ -11,7 +11,7 @@ func BenchmarkEndpointGroup(b *testing.B) {
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_, f, err := e.getBestAddr(context.Background(), "", false)
_, f, err := e.getBestAddr(context.Background(), AddressRequest{}, false)
if err != nil {
b.Fatal(err)
}
@@ -1,4 +1,4 @@
package endpoints
package loadbalancer

import (
"context"
@@ -1,4 +1,4 @@
package endpoints
package loadbalancer

import (
"context"
@@ -8,6 +8,7 @@ import (
"sync"

kubeaiv1 "github.com/substratusai/kubeai/api/v1"
v1 "github.com/substratusai/kubeai/api/v1"
"github.com/substratusai/kubeai/internal/k8sutils"
corev1 "k8s.io/api/core/v1"
"k8s.io/utils/ptr"
@@ -17,8 +18,8 @@ import (
"sigs.k8s.io/controller-runtime/pkg/client"
)

func NewResolver(mgr ctrl.Manager) (*Resolver, error) {
r := &Resolver{}
func New(mgr ctrl.Manager) (*LoadBalancer, error) {
r := &LoadBalancer{}
r.Client = mgr.GetClient()
r.endpoints = map[string]*group{}
r.ExcludePods = map[string]struct{}{}
@@ -28,7 +29,7 @@ func NewResolver(mgr ctrl.Manager) (*Resolver, error) {
return r, nil
}

type Resolver struct {
type LoadBalancer struct {
client.Client

endpointsMtx sync.Mutex
@@ -41,14 +42,14 @@ type Resolver struct {
ExcludePods map[string]struct{}
}

func (r *Resolver) SetupWithManager(mgr ctrl.Manager) error {
func (r *LoadBalancer) SetupWithManager(mgr ctrl.Manager) error {
return ctrl.NewControllerManagedBy(mgr).
WithOptions(controller.Options{NeedLeaderElection: ptr.To(false)}).
For(&corev1.Pod{}).
Complete(r)
}

func (r *Resolver) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) {
func (r *LoadBalancer) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) {
var pod corev1.Pod
if err := r.Get(ctx, req.NamespacedName, &pod); err != nil {
return ctrl.Result{}, client.IgnoreNotFound(err)
@@ -148,7 +149,10 @@ func getPodAnnotation(pod corev1.Pod, key string) string {
return ""
}

func (r *Resolver) getEndpoints(model string) *group {
// getEndpoints returns the endpoint group for the given model.
// If the group does not exist, it is created.
// This assumes that the existence of the model has already been checked.
func (r *LoadBalancer) getEndpoints(model string) *group {
r.endpointsMtx.Lock()
e, ok := r.endpoints[model]
if !ok {
@@ -159,7 +163,7 @@ func (r *Resolver) getEndpoints(model string) *group {
return e
}

func (r *Resolver) GetSelfIPs() []string {
func (r *LoadBalancer) GetSelfIPs() []string {
r.selfIPsMtx.RLock()
defer r.selfIPsMtx.RUnlock()
return r.selfIPs
@@ -169,16 +173,17 @@ type AddressRequest struct {
Model string
Adapter string
Prefix string
v1.LoadBalancing
}

// AwaitBestAddress returns the "IP:Port" with the lowest number of in-flight requests. It will block until an endpoint
// becomes available or the context times out. It returns a function that should be called when the
// request is complete to decrement the in-flight count.
func (r *Resolver) AwaitBestAddress(ctx context.Context, req AddressRequest) (string, func(), error) {
return r.getEndpoints(req.Model).getBestAddr(ctx, req.Adapter, req.Prefix, false)
func (r *LoadBalancer) AwaitBestAddress(ctx context.Context, req AddressRequest) (string, func(), error) {
return r.getEndpoints(req.Model).getBestAddr(ctx, req, false)
}

// GetAllAddresses retrieves the list of all addresses for a given model.
func (r *Resolver) GetAllAddresses(model string) []string {
func (r *LoadBalancer) GetAllAddresses(model string) []string {
return r.getEndpoints(model).getAllAddrs()
}
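For context (not part of this commit's diff): a hedged sketch of how a caller inside the module might use the renamed LoadBalancer. Only AwaitBestAddress, AddressRequest, and the v1 types come from the code above; the proxyOnce function and package name are hypothetical, and the import paths are written as they appear in the diff.

package example

import (
	"context"

	v1 "github.com/substratusai/kubeai/api/v1"
	"github.com/substratusai/kubeai/internal/loadbalancer"
)

// proxyOnce is a hypothetical caller: resolve an endpoint for the model,
// hand the request off, and release the in-flight slot afterwards.
func proxyOnce(ctx context.Context, lb *loadbalancer.LoadBalancer, model, adapter string) (string, error) {
	addr, done, err := lb.AwaitBestAddress(ctx, loadbalancer.AddressRequest{
		Model:   model,
		Adapter: adapter,
		LoadBalancing: v1.LoadBalancing{
			Strategy: v1.LeastLoadStrategy, // least-load selection for this sketch
		},
	})
	if err != nil {
		return "", err // e.g. the context expired while waiting for endpoints to appear
	}
	defer done() // decrement the endpoint's in-flight count when the request completes

	// ... forward the HTTP request to addr here ...
	return addr, nil
}

Forgetting to call the returned function would leave the endpoint's in-flight count permanently inflated, which is why the doc comment on AwaitBestAddress calls it out.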
@@ -1,4 +1,4 @@
package endpoints
package loadbalancer

import (
"context"
@@ -7,6 +7,7 @@ import (

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
v1 "github.com/substratusai/kubeai/api/v1"
)

func TestAwaitBestHost(t *testing.T) {
@@ -19,7 +20,7 @@ func TestAwaitBestHost(t *testing.T) {
myAddrWithAdapter = "10.0.0.2:8000"
)

manager := &Resolver{endpoints: make(map[string]*group, 1)}
manager := &LoadBalancer{endpoints: make(map[string]*group, 1)}

testCases := map[string]struct {
model string
@@ -67,7 +68,13 @@ func TestAwaitBestHost(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
defer cancel()

gotAddr, gotFunc, gotErr := manager.AwaitBestAddress(ctx, spec.model, spec.adapter)
gotAddr, gotFunc, gotErr := manager.AwaitBestAddress(ctx, AddressRequest{
Model: spec.model,
Adapter: spec.adapter,
LoadBalancing: v1.LoadBalancing{
Strategy: v1.LeastLoadStrategy,
},
})
if spec.expErr != nil {
require.ErrorIs(t, spec.expErr, gotErr)
return
18 changes: 9 additions & 9 deletions internal/manager/run.go
@@ -33,13 +33,13 @@ import (
metricsserver "sigs.k8s.io/controller-runtime/pkg/metrics/server"

kubeaiv1 "github.com/substratusai/kubeai/api/v1"
"github.com/substratusai/kubeai/internal/endpoints"
"github.com/substratusai/kubeai/internal/leader"
"github.com/substratusai/kubeai/internal/loadbalancer"
"github.com/substratusai/kubeai/internal/messenger"
"github.com/substratusai/kubeai/internal/modelautoscaler"
"github.com/substratusai/kubeai/internal/modelclient"
"github.com/substratusai/kubeai/internal/modelcontroller"
"github.com/substratusai/kubeai/internal/modelproxy"
"github.com/substratusai/kubeai/internal/modelscaler"
"github.com/substratusai/kubeai/internal/openaiserver"
"github.com/substratusai/kubeai/internal/vllmclient"

@@ -204,7 +204,7 @@ func Run(ctx context.Context, k8sCfg *rest.Config, cfg config.System) error {
cfg.LeaderElection.RetryPeriod.Duration,
)

endpointResolver, err := endpoints.NewResolver(mgr)
loadBalancer, err := loadbalancer.New(mgr)
if err != nil {
return fmt.Errorf("unable to setup model resolver: %w", err)
}
@@ -239,7 +239,7 @@ func Run(ctx context.Context, k8sCfg *rest.Config, cfg config.System) error {
return fmt.Errorf("unable to set up ready check: %w", err)
}

modelScaler := modelscaler.NewModelScaler(mgr.GetClient(), namespace)
modelClient := modelclient.NewModelClient(mgr.GetClient(), namespace)

metricsPort, err := parsePortFromAddr(cfg.MetricsAddr)
if err != nil {
@@ -250,8 +250,8 @@ func Run(ctx context.Context, k8sCfg *rest.Config, cfg config.System) error {
ctx,
k8sClient,
leaderElection,
modelScaler,
endpointResolver,
modelClient,
loadBalancer,
cfg.ModelAutoscaling,
metricsPort,
types.NamespacedName{Name: cfg.ModelAutoscaling.StateConfigMapName, Namespace: namespace},
@@ -261,7 +261,7 @@ func Run(ctx context.Context, k8sCfg *rest.Config, cfg config.System) error {
return fmt.Errorf("unable to create model autoscaler: %w", err)
}

modelProxy := modelproxy.NewHandler(modelScaler, endpointResolver, 3, nil)
modelProxy := modelproxy.NewHandler(modelClient, loadBalancer, 3, nil)
openaiHandler := openaiserver.NewHandler(mgr.GetClient(), modelProxy)
mux := http.NewServeMux()
mux.Handle("/openai/", openaiHandler)
@@ -288,8 +288,8 @@ func Run(ctx context.Context, k8sCfg *rest.Config, cfg config.System) error {
stream.ResponsesURL,
stream.MaxHandlers,
cfg.Messaging.ErrorMaxBackoff.Duration,
modelScaler,
endpointResolver,
modelClient,
loadBalancer,
httpClient,
)
if err != nil {