From 4cf6ecf8b04f7a550596fba3130d18a41eccbe6f Mon Sep 17 00:00:00 2001 From: Nick Stogner Date: Tue, 3 Dec 2024 07:54:22 -0500 Subject: [PATCH] Checkpoint --- api/v1/model_types.go | 27 +++ api/v1/zz_generated.deepcopy.go | 40 +++++ .../balance_chwbl.go} | 2 +- .../balance_least_load.go} | 2 +- internal/{endpoints => loadbalancer}/group.go | 20 ++- .../group_bench_test.go | 4 +- .../{endpoints => loadbalancer}/group_test.go | 2 +- .../load_balancer.go} | 27 +-- .../load_balancer_test.go} | 13 +- internal/manager/run.go | 18 +- internal/messenger/messenger.go | 19 +- internal/modelautoscaler/autoscaler.go | 18 +- internal/modelclient/client.go | 73 ++++++++ internal/modelclient/scale.go | 100 +++++++++++ internal/modelproxy/handler.go | 33 ++-- internal/modelproxy/handler_test.go | 7 +- internal/modelscaler/scaler.go | 163 ------------------ 17 files changed, 335 insertions(+), 233 deletions(-) rename internal/{endpoints/group_lb_chwbl.go => loadbalancer/balance_chwbl.go} (99%) rename internal/{endpoints/group_lb_least_load.go => loadbalancer/balance_least_load.go} (95%) rename internal/{endpoints => loadbalancer}/group.go (86%) rename internal/{endpoints => loadbalancer}/group_bench_test.go (76%) rename internal/{endpoints => loadbalancer}/group_test.go (99%) rename internal/{endpoints/resolver.go => loadbalancer/load_balancer.go} (82%) rename internal/{endpoints/resolver_test.go => loadbalancer/load_balancer_test.go} (83%) create mode 100644 internal/modelclient/client.go create mode 100644 internal/modelclient/scale.go delete mode 100644 internal/modelscaler/scaler.go diff --git a/api/v1/model_types.go b/api/v1/model_types.go index 8cace19b..5feda0ab 100644 --- a/api/v1/model_types.go +++ b/api/v1/model_types.go @@ -115,6 +115,10 @@ type ModelSpec struct { // DEPRECATED. // +kubebuilder:validation:Optional Owner string `json:"owner"` + + // LoadBalancing configuration for the model. + // If not specified, a default is used based on the engine and request. + LoadBalancing *LoadBalancing `json:"loadBalancing,omitempty"` } // +kubebuilder:validation:Enum=TextGeneration;TextEmbedding;SpeechToText @@ -144,6 +148,29 @@ type Adapter struct { URL string `json:"url"` } +type LoadBalancing struct { + Strategy LoadBalancingStrategy `json:"strategy"` + CHWBL *CHWBL `json:"chwbl,omitempty"` +} + +// +kubebuilder:validation:Enum=LeastLoad;CHWBL +type LoadBalancingStrategy string + +const ( + LeastLoadStrategy LoadBalancingStrategy = "LeastLoad" + CHWBLStrategy LoadBalancingStrategy = "CHWBL" +) + +type CHWBL struct { + // MeanLoadFactor is the multiple that any given endpoint's load must not exceed + // over the mean load of all endpoints in the hash ring. + MeanLoadFactor float64 `json:"meanLoadFactor"` + // Replication is the number of replicas of each endpoint on the hash ring. + // Higher values will result in a more even distribution of load but will + // decrease lookup performance. + Replication int `json:"replication"` +} + // ModelStatus defines the observed state of Model. type ModelStatus struct { Replicas ModelStatusReplicas `json:"replicas,omitempty"` diff --git a/api/v1/zz_generated.deepcopy.go b/api/v1/zz_generated.deepcopy.go index 9e61934e..fa7bfd16 100644 --- a/api/v1/zz_generated.deepcopy.go +++ b/api/v1/zz_generated.deepcopy.go @@ -39,6 +39,41 @@ func (in *Adapter) DeepCopy() *Adapter { return out } +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. 
+func (in *CHWBL) DeepCopyInto(out *CHWBL) { + *out = *in +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new CHWBL. +func (in *CHWBL) DeepCopy() *CHWBL { + if in == nil { + return nil + } + out := new(CHWBL) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *LoadBalancing) DeepCopyInto(out *LoadBalancing) { + *out = *in + if in.CHWBL != nil { + in, out := &in.CHWBL, &out.CHWBL + *out = new(CHWBL) + **out = **in + } +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new LoadBalancing. +func (in *LoadBalancing) DeepCopy() *LoadBalancing { + if in == nil { + return nil + } + out := new(LoadBalancing) + in.DeepCopyInto(out) + return out +} + // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *Model) DeepCopyInto(out *Model) { *out = *in @@ -143,6 +178,11 @@ func (in *ModelSpec) DeepCopyInto(out *ModelSpec) { *out = new(int64) **out = **in } + if in.LoadBalancing != nil { + in, out := &in.LoadBalancing, &out.LoadBalancing + *out = new(LoadBalancing) + (*in).DeepCopyInto(*out) + } } // DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new ModelSpec. diff --git a/internal/endpoints/group_lb_chwbl.go b/internal/loadbalancer/balance_chwbl.go similarity index 99% rename from internal/endpoints/group_lb_chwbl.go rename to internal/loadbalancer/balance_chwbl.go index 3a5a94d0..c9aa93ce 100644 --- a/internal/endpoints/group_lb_chwbl.go +++ b/internal/loadbalancer/balance_chwbl.go @@ -1,4 +1,4 @@ -package endpoints +package loadbalancer import ( "fmt" diff --git a/internal/endpoints/group_lb_least_load.go b/internal/loadbalancer/balance_least_load.go similarity index 95% rename from internal/endpoints/group_lb_least_load.go rename to internal/loadbalancer/balance_least_load.go index 5d7be750..84ddd7ab 100644 --- a/internal/endpoints/group_lb_least_load.go +++ b/internal/loadbalancer/balance_least_load.go @@ -1,4 +1,4 @@ -package endpoints +package loadbalancer func (g *group) getAddrLeastLoad(adapter string) (endpoint, bool) { var bestEp endpoint diff --git a/internal/endpoints/group.go b/internal/loadbalancer/group.go similarity index 86% rename from internal/endpoints/group.go rename to internal/loadbalancer/group.go index bfdac824..6c1951ce 100644 --- a/internal/endpoints/group.go +++ b/internal/loadbalancer/group.go @@ -1,4 +1,4 @@ -package endpoints +package loadbalancer import ( "context" @@ -6,6 +6,8 @@ import ( "log" "sync" "sync/atomic" + + v1 "github.com/substratusai/kubeai/api/v1" ) func newEndpointGroup() *group { @@ -49,7 +51,7 @@ const ( // getBestAddr returns the best "IP:Port". It blocks until there are available endpoints // in the endpoint group. 
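+// The balancing strategy, adapter, and prefix are taken from the AddressRequest;
+// an unknown strategy results in an error.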
-func (g *group) getBestAddr(ctx context.Context, strategy LoadBalancingStrategy, adapter string, awaitChangeEndpoints bool) (string, func(), error) { +func (g *group) getBestAddr(ctx context.Context, req AddressRequest, awaitChangeEndpoints bool) (string, func(), error) { g.mtx.RLock() // await endpoints exists for awaitChangeEndpoints || len(g.endpoints) == 0 { @@ -64,19 +66,19 @@ func (g *group) getBestAddr(ctx context.Context, strategy LoadBalancingStrategy, var ep endpoint var found bool - switch strategy { - case CHWBL: + switch req.Strategy { + case v1.CHWBLStrategy: // TODO: prefix - ep, found = g.chwbl.getAddr(adapter + prefix) - case LeastLoaded: - ep, found = g.getAddrLeastLoad(adapter) + ep, found = g.chwbl.getAddr(req.Adapter + req.Prefix) + case v1.LeastLoadStrategy: + ep, found = g.getAddrLeastLoad(req.Adapter) default: - return "", func() {}, fmt.Errorf("unknown load balancing strategy: %v", strategy) + return "", func() {}, fmt.Errorf("unknown load balancing strategy: %v", req.Strategy) } if !found { g.mtx.RUnlock() - return g.getBestAddr(ctx, strategy, adapter, true) + return g.getBestAddr(ctx, req, true) } g.addInFlight(ep.inFlight, 1) diff --git a/internal/endpoints/group_bench_test.go b/internal/loadbalancer/group_bench_test.go similarity index 76% rename from internal/endpoints/group_bench_test.go rename to internal/loadbalancer/group_bench_test.go index f05bb51a..47c8cf5f 100644 --- a/internal/endpoints/group_bench_test.go +++ b/internal/loadbalancer/group_bench_test.go @@ -1,4 +1,4 @@ -package endpoints +package loadbalancer import ( "context" @@ -11,7 +11,7 @@ func BenchmarkEndpointGroup(b *testing.B) { b.ResetTimer() b.RunParallel(func(pb *testing.PB) { for pb.Next() { - _, f, err := e.getBestAddr(context.Background(), "", false) + _, f, err := e.getBestAddr(context.Background(), AddressRequest{}, false) if err != nil { b.Fatal(err) } diff --git a/internal/endpoints/group_test.go b/internal/loadbalancer/group_test.go similarity index 99% rename from internal/endpoints/group_test.go rename to internal/loadbalancer/group_test.go index e5811031..d112b1bb 100644 --- a/internal/endpoints/group_test.go +++ b/internal/loadbalancer/group_test.go @@ -1,4 +1,4 @@ -package endpoints +package loadbalancer import ( "context" diff --git a/internal/endpoints/resolver.go b/internal/loadbalancer/load_balancer.go similarity index 82% rename from internal/endpoints/resolver.go rename to internal/loadbalancer/load_balancer.go index 96015510..146a1067 100644 --- a/internal/endpoints/resolver.go +++ b/internal/loadbalancer/load_balancer.go @@ -1,4 +1,4 @@ -package endpoints +package loadbalancer import ( "context" @@ -8,6 +8,7 @@ import ( "sync" kubeaiv1 "github.com/substratusai/kubeai/api/v1" + v1 "github.com/substratusai/kubeai/api/v1" "github.com/substratusai/kubeai/internal/k8sutils" corev1 "k8s.io/api/core/v1" "k8s.io/utils/ptr" @@ -17,8 +18,8 @@ import ( "sigs.k8s.io/controller-runtime/pkg/client" ) -func NewResolver(mgr ctrl.Manager) (*Resolver, error) { - r := &Resolver{} +func New(mgr ctrl.Manager) (*LoadBalancer, error) { + r := &LoadBalancer{} r.Client = mgr.GetClient() r.endpoints = map[string]*group{} r.ExcludePods = map[string]struct{}{} @@ -28,7 +29,7 @@ func NewResolver(mgr ctrl.Manager) (*Resolver, error) { return r, nil } -type Resolver struct { +type LoadBalancer struct { client.Client endpointsMtx sync.Mutex @@ -41,14 +42,14 @@ type Resolver struct { ExcludePods map[string]struct{} } -func (r *Resolver) SetupWithManager(mgr ctrl.Manager) error { +func (r 
*LoadBalancer) SetupWithManager(mgr ctrl.Manager) error {
 	return ctrl.NewControllerManagedBy(mgr).
 		WithOptions(controller.Options{NeedLeaderElection: ptr.To(false)}).
 		For(&corev1.Pod{}).
 		Complete(r)
 }
 
-func (r *Resolver) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) {
+func (r *LoadBalancer) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) {
 	var pod corev1.Pod
 	if err := r.Get(ctx, req.NamespacedName, &pod); err != nil {
 		return ctrl.Result{}, client.IgnoreNotFound(err)
@@ -148,7 +149,10 @@ func getPodAnnotation(pod corev1.Pod, key string) string {
 	return ""
 }
 
-func (r *Resolver) getEndpoints(model string) *group {
+// getEndpoints returns the endpoint group for the given model.
+// If the group does not exist, it is created.
+// This assumes that the existence of the model has already been checked.
+func (r *LoadBalancer) getEndpoints(model string) *group {
 	r.endpointsMtx.Lock()
 	e, ok := r.endpoints[model]
 	if !ok {
@@ -159,7 +163,7 @@ func (r *Resolver) getEndpoints(model string) *group {
 	return e
 }
 
-func (r *Resolver) GetSelfIPs() []string {
+func (r *LoadBalancer) GetSelfIPs() []string {
 	r.selfIPsMtx.RLock()
 	defer r.selfIPsMtx.RUnlock()
 	return r.selfIPs
@@ -169,16 +173,17 @@ type AddressRequest struct {
 	Model   string
 	Adapter string
 	Prefix  string
+
+	v1.LoadBalancing
 }
 
 // AwaitBestAddress returns the "IP:Port" with the lowest number of in-flight requests. It will block until an endpoint
 // becomes available or the context times out. It returns a function that should be called when the
 // request is complete to decrement the in-flight count.
-func (r *Resolver) AwaitBestAddress(ctx context.Context, req AddressRequest) (string, func(), error) {
-	return r.getEndpoints(req.Model).getBestAddr(ctx, req.Adapter, req.Prefix, false)
+func (r *LoadBalancer) AwaitBestAddress(ctx context.Context, req AddressRequest) (string, func(), error) {
+	return r.getEndpoints(req.Model).getBestAddr(ctx, req, false)
 }
 
 // GetAllHosts retrieves the list of all hosts for a given model.
-func (r *Resolver) GetAllAddresses(model string) []string { +func (r *LoadBalancer) GetAllAddresses(model string) []string { return r.getEndpoints(model).getAllAddrs() } diff --git a/internal/endpoints/resolver_test.go b/internal/loadbalancer/load_balancer_test.go similarity index 83% rename from internal/endpoints/resolver_test.go rename to internal/loadbalancer/load_balancer_test.go index 0b2c13e9..caa1b809 100644 --- a/internal/endpoints/resolver_test.go +++ b/internal/loadbalancer/load_balancer_test.go @@ -1,4 +1,4 @@ -package endpoints +package loadbalancer import ( "context" @@ -7,6 +7,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + v1 "github.com/substratusai/kubeai/api/v1" ) func TestAwaitBestHost(t *testing.T) { @@ -19,7 +20,7 @@ func TestAwaitBestHost(t *testing.T) { myAddrWithAdapter = "10.0.0.2:8000" ) - manager := &Resolver{endpoints: make(map[string]*group, 1)} + manager := &LoadBalancer{endpoints: make(map[string]*group, 1)} testCases := map[string]struct { model string @@ -67,7 +68,13 @@ func TestAwaitBestHost(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond) defer cancel() - gotAddr, gotFunc, gotErr := manager.AwaitBestAddress(ctx, spec.model, spec.adapter) + gotAddr, gotFunc, gotErr := manager.AwaitBestAddress(ctx, AddressRequest{ + Model: spec.model, + Adapter: spec.adapter, + LoadBalancing: v1.LoadBalancing{ + Strategy: v1.LeastLoadStrategy, + }, + }) if spec.expErr != nil { require.ErrorIs(t, spec.expErr, gotErr) return diff --git a/internal/manager/run.go b/internal/manager/run.go index fb9a6dcb..3ea13877 100644 --- a/internal/manager/run.go +++ b/internal/manager/run.go @@ -33,13 +33,13 @@ import ( metricsserver "sigs.k8s.io/controller-runtime/pkg/metrics/server" kubeaiv1 "github.com/substratusai/kubeai/api/v1" - "github.com/substratusai/kubeai/internal/endpoints" "github.com/substratusai/kubeai/internal/leader" + "github.com/substratusai/kubeai/internal/loadbalancer" "github.com/substratusai/kubeai/internal/messenger" "github.com/substratusai/kubeai/internal/modelautoscaler" + "github.com/substratusai/kubeai/internal/modelclient" "github.com/substratusai/kubeai/internal/modelcontroller" "github.com/substratusai/kubeai/internal/modelproxy" - "github.com/substratusai/kubeai/internal/modelscaler" "github.com/substratusai/kubeai/internal/openaiserver" "github.com/substratusai/kubeai/internal/vllmclient" @@ -204,7 +204,7 @@ func Run(ctx context.Context, k8sCfg *rest.Config, cfg config.System) error { cfg.LeaderElection.RetryPeriod.Duration, ) - endpointResolver, err := endpoints.NewResolver(mgr) + loadBalancer, err := loadbalancer.New(mgr) if err != nil { return fmt.Errorf("unable to setup model resolver: %w", err) } @@ -239,7 +239,7 @@ func Run(ctx context.Context, k8sCfg *rest.Config, cfg config.System) error { return fmt.Errorf("unable to set up ready check: %w", err) } - modelScaler := modelscaler.NewModelScaler(mgr.GetClient(), namespace) + modelClient := modelclient.NewModelClient(mgr.GetClient(), namespace) metricsPort, err := parsePortFromAddr(cfg.MetricsAddr) if err != nil { @@ -250,8 +250,8 @@ func Run(ctx context.Context, k8sCfg *rest.Config, cfg config.System) error { ctx, k8sClient, leaderElection, - modelScaler, - endpointResolver, + modelClient, + loadBalancer, cfg.ModelAutoscaling, metricsPort, types.NamespacedName{Name: cfg.ModelAutoscaling.StateConfigMapName, Namespace: namespace}, @@ -261,7 +261,7 @@ func Run(ctx context.Context, k8sCfg *rest.Config, cfg config.System) 
error { return fmt.Errorf("unable to create model autoscaler: %w", err) } - modelProxy := modelproxy.NewHandler(modelScaler, endpointResolver, 3, nil) + modelProxy := modelproxy.NewHandler(modelClient, loadBalancer, 3, nil) openaiHandler := openaiserver.NewHandler(mgr.GetClient(), modelProxy) mux := http.NewServeMux() mux.Handle("/openai/", openaiHandler) @@ -288,8 +288,8 @@ func Run(ctx context.Context, k8sCfg *rest.Config, cfg config.System) error { stream.ResponsesURL, stream.MaxHandlers, cfg.Messaging.ErrorMaxBackoff.Duration, - modelScaler, - endpointResolver, + modelClient, + loadBalancer, httpClient, ) if err != nil { diff --git a/internal/messenger/messenger.go b/internal/messenger/messenger.go index ce6d51b3..a78da87e 100644 --- a/internal/messenger/messenger.go +++ b/internal/messenger/messenger.go @@ -14,6 +14,7 @@ import ( "time" "github.com/substratusai/kubeai/internal/apiutils" + "github.com/substratusai/kubeai/internal/loadbalancer" "github.com/substratusai/kubeai/internal/metrics" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/metric" @@ -21,8 +22,8 @@ import ( ) type Messenger struct { - modelScaler ModelScaler - resolver EndpointResolver + modelScaler ModelScaler + loadBalancer LoadBalancer HTTPC *http.Client @@ -44,7 +45,7 @@ func NewMessenger( maxHandlers int, errorMaxBackoff time.Duration, modelScaler ModelScaler, - resolver EndpointResolver, + lb LoadBalancer, httpClient *http.Client, ) (*Messenger, error) { requests, err := pubsub.OpenSubscription(ctx, requestsURL) @@ -59,7 +60,7 @@ func NewMessenger( return &Messenger{ modelScaler: modelScaler, - resolver: resolver, + loadBalancer: lb, HTTPC: httpClient, requestsURL: requestsURL, requests: requests, @@ -74,8 +75,8 @@ type ModelScaler interface { ScaleAtLeastOneReplica(ctx context.Context, model string) error } -type EndpointResolver interface { - AwaitBestAddress(ctx context.Context, model, adapter string) (string, func(), error) +type LoadBalancer interface { + AwaitBestAddress(ctx context.Context, req loadbalancer.AddressRequest) (string, func(), error) } func (m *Messenger) Start(ctx context.Context) error { @@ -222,7 +223,11 @@ func (m *Messenger) handleRequest(ctx context.Context, msg *pubsub.Message) { log.Printf("Awaiting host for message %s", msg.LoggableID) - host, completeFunc, err := m.resolver.AwaitBestAddress(ctx, req.model, req.adapter) + host, completeFunc, err := m.loadBalancer.AwaitBestAddress(ctx, loadbalancer.AddressRequest{ + Model: req.model, + Adapter: req.adapter, + // TODO: Prefix + }) if err != nil { m.sendResponse(req, m.jsonError("error awaiting host for backend: %v", err), http.StatusBadGateway) return diff --git a/internal/modelautoscaler/autoscaler.go b/internal/modelautoscaler/autoscaler.go index 6d4bf88b..52c5d05c 100644 --- a/internal/modelautoscaler/autoscaler.go +++ b/internal/modelautoscaler/autoscaler.go @@ -9,9 +9,9 @@ import ( "time" "github.com/substratusai/kubeai/internal/config" - "github.com/substratusai/kubeai/internal/endpoints" "github.com/substratusai/kubeai/internal/leader" - "github.com/substratusai/kubeai/internal/modelscaler" + "github.com/substratusai/kubeai/internal/loadbalancer" + "github.com/substratusai/kubeai/internal/modelclient" "github.com/substratusai/kubeai/internal/movingaverage" "k8s.io/apimachinery/pkg/types" "sigs.k8s.io/controller-runtime/pkg/client" @@ -21,8 +21,8 @@ func New( ctx context.Context, k8sClient client.Client, leaderElection *leader.Election, - scaler *modelscaler.ModelScaler, - resolver *endpoints.Resolver, + modelClient 
*modelclient.ModelClient, + resolver *loadbalancer.LoadBalancer, cfg config.ModelAutoscaling, metricsPort int, stateConfigMapRef types.NamespacedName, @@ -31,7 +31,7 @@ func New( a := &Autoscaler{ k8sClient: k8sClient, leaderElection: leaderElection, - scaler: scaler, + modelClient: modelClient, resolver: resolver, movingAvgByModel: map[string]*movingaverage.Simple{}, cfg: cfg, @@ -78,8 +78,8 @@ type Autoscaler struct { leaderElection *leader.Election - scaler *modelscaler.ModelScaler - resolver *endpoints.Resolver + modelClient *modelclient.ModelClient + resolver *loadbalancer.LoadBalancer cfg config.ModelAutoscaling @@ -107,7 +107,7 @@ func (a *Autoscaler) Start(ctx context.Context) { // TODO: Remove hardcoded Service lookup by name "lingo". - models, err := a.scaler.ListAllModels(ctx) + models, err := a.modelClient.ListAllModels(ctx) if err != nil { log.Printf("Failed to list models: %v", err) continue @@ -159,7 +159,7 @@ func (a *Autoscaler) Start(ctx context.Context) { ceil := math.Ceil(normalized) log.Printf("Calculated target replicas for model %q: ceil(%v/%v) = %v, current requests: sum(%v) = %v, history: %v", m.Name, avgActiveRequests, *m.Spec.TargetRequests, ceil, activeRequests, activeRequestSum, avg.History()) - a.scaler.Scale(ctx, &m, int32(ceil), a.cfg.RequiredConsecutiveScaleDowns(*m.Spec.ScaleDownDelaySeconds)) + a.modelClient.Scale(ctx, &m, int32(ceil), a.cfg.RequiredConsecutiveScaleDowns(*m.Spec.ScaleDownDelaySeconds)) nextModelState.Models[m.Name] = modelState{ AverageActiveRequests: avgActiveRequests, diff --git a/internal/modelclient/client.go b/internal/modelclient/client.go new file mode 100644 index 00000000..d908f045 --- /dev/null +++ b/internal/modelclient/client.go @@ -0,0 +1,73 @@ +package modelclient + +import ( + "context" + "fmt" + "sync" + + kubeaiv1 "github.com/substratusai/kubeai/api/v1" + apierrors "k8s.io/apimachinery/pkg/api/errors" + "k8s.io/apimachinery/pkg/labels" + "k8s.io/apimachinery/pkg/types" + "sigs.k8s.io/controller-runtime/pkg/client" +) + +type ModelClient struct { + client client.Client + namespace string + consecutiveScaleDownsMtx sync.RWMutex + consecutiveScaleDowns map[string]int +} + +func NewModelClient(client client.Client, namespace string) *ModelClient { + return &ModelClient{client: client, namespace: namespace, consecutiveScaleDowns: map[string]int{}} +} + +// LookupModel checks if a model exists and matches the given label selectors. 
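+// It returns false with a nil error when the Model is not found, when any
+// label selector does not match the Model's labels, or when a non-empty
+// adapter name is not declared in the Model's .spec.adapters.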
+func (c *ModelClient) LookupModel(ctx context.Context, model, adapter string, labelSelectors []string) (bool, error) { + m := &kubeaiv1.Model{} + if err := c.client.Get(ctx, types.NamespacedName{Name: model, Namespace: c.namespace}, m); err != nil { + if apierrors.IsNotFound(err) { + return false, nil + } + return false, err + } + + modelLabels := m.GetLabels() + if modelLabels == nil { + modelLabels = map[string]string{} + } + for _, sel := range labelSelectors { + parsedSel, err := labels.Parse(sel) + if err != nil { + return false, fmt.Errorf("parse label selector: %w", err) + } + if !parsedSel.Matches(labels.Set(modelLabels)) { + return false, nil + } + } + + if adapter != "" { + adapterFound := false + for _, a := range m.Spec.Adapters { + if a.Name == adapter { + adapterFound = true + break + } + } + if !adapterFound { + return false, nil + } + } + + return true, nil +} + +func (s *ModelClient) ListAllModels(ctx context.Context) ([]kubeaiv1.Model, error) { + models := &kubeaiv1.ModelList{} + if err := s.client.List(ctx, models, client.InNamespace(s.namespace)); err != nil { + return nil, fmt.Errorf("list models: %w", err) + } + + return models.Items, nil +} diff --git a/internal/modelclient/scale.go b/internal/modelclient/scale.go new file mode 100644 index 00000000..f576da7c --- /dev/null +++ b/internal/modelclient/scale.go @@ -0,0 +1,100 @@ +package modelclient + +import ( + "context" + "fmt" + "log" + + kubeaiv1 "github.com/substratusai/kubeai/api/v1" + autoscalingv1 "k8s.io/api/autoscaling/v1" + "k8s.io/apimachinery/pkg/types" + "sigs.k8s.io/controller-runtime/pkg/client" +) + +func (c *ModelClient) ScaleAtLeastOneReplica(ctx context.Context, model string) error { + obj := &kubeaiv1.Model{} + if err := c.client.Get(ctx, types.NamespacedName{Namespace: c.namespace, Name: model}, obj); err != nil { + return fmt.Errorf("get scale: %w", err) + } + + if obj.Spec.AutoscalingDisabled { + return nil + } + + replicas := int32(0) + if obj.Spec.Replicas != nil { + replicas = *obj.Spec.Replicas + } + + if replicas == 0 && !obj.Spec.AutoscalingDisabled { + scale := &autoscalingv1.Scale{ + Spec: autoscalingv1.ScaleSpec{Replicas: 1}, + } + if err := c.client.SubResource("scale").Update(ctx, obj, client.WithSubResourceBody(scale)); err != nil { + return fmt.Errorf("update scale: %w", err) + } + } + + return nil +} + +// Scale scales the model to the desired number of replicas, enforcing the min and max replica bounds. +// Model should have .Spec defined before calling Scale(). +func (c *ModelClient) Scale(ctx context.Context, model *kubeaiv1.Model, replicas int32, requiredConsecutiveScaleDowns int) error { + //obj := &kubeaiv1.Model{} + //if err := s.client.Get(ctx, types.NamespacedName{Namespace: s.namespace, Name: model}, obj); err != nil { + // return fmt.Errorf("get scale: %w", err) + //} + + replicas = enforceReplicaBounds(replicas, model) + + var existingReplicas int32 = 0 + if model.Spec.Replicas != nil { + existingReplicas = *model.Spec.Replicas + } + + if existingReplicas > replicas { + // Scale down + c.consecutiveScaleDownsMtx.RLock() + consec := c.consecutiveScaleDowns[model.Name] + c.consecutiveScaleDownsMtx.RUnlock() + if consec < requiredConsecutiveScaleDowns { + log.Printf("model %s has %d consecutive scale downs (< %d), not scaling down yet", model.Name, consec, requiredConsecutiveScaleDowns) + c.consecutiveScaleDownsMtx.Lock() + c.consecutiveScaleDowns[model.Name]++ + c.consecutiveScaleDownsMtx.Unlock() + return nil + } + } else { + // Scale up or constant scale. 
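+		// Reset the consecutive scale-down counter: replicas are only reduced
+		// after requiredConsecutiveScaleDowns scale-down decisions in a row.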
+ c.consecutiveScaleDownsMtx.Lock() + c.consecutiveScaleDowns[model.Name] = 0 + c.consecutiveScaleDownsMtx.Unlock() + } + + if existingReplicas != replicas { + log.Printf("scaling model %s from %d to %d replicas", model.Name, existingReplicas, replicas) + scale := &autoscalingv1.Scale{ + Spec: autoscalingv1.ScaleSpec{Replicas: replicas}, + } + if err := c.client.SubResource("scale").Update(ctx, model, client.WithSubResourceBody(scale)); err != nil { + return fmt.Errorf("update scale: %w", err) + } + } + + return nil +} + +func enforceReplicaBounds(replicas int32, model *kubeaiv1.Model) int32 { + max := model.Spec.MaxReplicas + min := model.Spec.MinReplicas + if max != nil { + if replicas > *max { + return *max + } + } + if replicas < min { + return min + } + return replicas +} diff --git a/internal/modelproxy/handler.go b/internal/modelproxy/handler.go index 5edaf4a6..6d4dd47a 100644 --- a/internal/modelproxy/handler.go +++ b/internal/modelproxy/handler.go @@ -8,40 +8,41 @@ import ( "net/http/httputil" "net/url" + "github.com/substratusai/kubeai/internal/loadbalancer" "github.com/substratusai/kubeai/internal/metrics" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/metric" ) -type ModelScaler interface { +type ModelClient interface { LookupModel(ctx context.Context, model, adapter string, selectors []string) (bool, error) ScaleAtLeastOneReplica(ctx context.Context, model string) error } -type EndpointResolver interface { - AwaitBestAddress(ctx context.Context, model, adapter string) (string, func(), error) +type LoadBalancer interface { + AwaitBestAddress(ctx context.Context, req loadbalancer.AddressRequest) (string, func(), error) } // Handler serves http requests for end-clients. // It is also responsible for triggering scale-from-zero. 
type Handler struct { - modelScaler ModelScaler - resolver EndpointResolver - maxRetries int - retryCodes map[int]struct{} + modelScaler ModelClient + loadBalancer LoadBalancer + maxRetries int + retryCodes map[int]struct{} } func NewHandler( - modelScaler ModelScaler, - resolver EndpointResolver, + modelScaler ModelClient, + loadBalancer LoadBalancer, maxRetries int, retryCodes map[int]struct{}, ) *Handler { return &Handler{ - modelScaler: modelScaler, - resolver: resolver, - maxRetries: maxRetries, - retryCodes: retryCodes, + modelScaler: modelScaler, + loadBalancer: loadBalancer, + maxRetries: maxRetries, + retryCodes: retryCodes, } } @@ -100,7 +101,11 @@ var AdditionalProxyRewrite = func(*httputil.ProxyRequest) {} func (h *Handler) proxyHTTP(w http.ResponseWriter, pr *proxyRequest) { log.Printf("Waiting for host: %v", pr.id) - addr, decrementInflight, err := h.resolver.AwaitBestAddress(pr.r.Context(), pr.model, pr.adapter) + addr, decrementInflight, err := h.loadBalancer.AwaitBestAddress(pr.r.Context(), loadbalancer.AddressRequest{ + Model: pr.model, + Adapter: pr.adapter, + // TODO: Prefix + }) if err != nil { switch { case errors.Is(err, context.Canceled): diff --git a/internal/modelproxy/handler_test.go b/internal/modelproxy/handler_test.go index 5028ad92..da43e77f 100644 --- a/internal/modelproxy/handler_test.go +++ b/internal/modelproxy/handler_test.go @@ -14,6 +14,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/substratusai/kubeai/internal/apiutils" + "github.com/substratusai/kubeai/internal/loadbalancer" "github.com/substratusai/kubeai/internal/metrics/metricstest" ) @@ -294,9 +295,9 @@ func (t *testModelInterface) ScaleAtLeastOneReplica(ctx context.Context, model s return nil } -func (t *testModelInterface) AwaitBestAddress(ctx context.Context, model, adapter string) (string, func(), error) { +func (t *testModelInterface) AwaitBestAddress(ctx context.Context, req loadbalancer.AddressRequest) (string, func(), error) { t.hostRequestCount++ - t.requestedModel = model - t.requestedAdapter = adapter + t.requestedModel = req.Model + t.requestedAdapter = req.Adapter return t.address, func() {}, nil } diff --git a/internal/modelscaler/scaler.go b/internal/modelscaler/scaler.go deleted file mode 100644 index fdb44294..00000000 --- a/internal/modelscaler/scaler.go +++ /dev/null @@ -1,163 +0,0 @@ -package modelscaler - -import ( - "context" - "fmt" - "log" - "sync" - - kubeaiv1 "github.com/substratusai/kubeai/api/v1" - autoscalingv1 "k8s.io/api/autoscaling/v1" - apierrors "k8s.io/apimachinery/pkg/api/errors" - "k8s.io/apimachinery/pkg/labels" - "k8s.io/apimachinery/pkg/types" - "sigs.k8s.io/controller-runtime/pkg/client" -) - -type ModelScaler struct { - client client.Client - namespace string - consecutiveScaleDownsMtx sync.RWMutex - consecutiveScaleDowns map[string]int -} - -func NewModelScaler(client client.Client, namespace string) *ModelScaler { - return &ModelScaler{client: client, namespace: namespace, consecutiveScaleDowns: map[string]int{}} -} - -// LookupModel checks if a model exists and matches the given label selectors. 
-func (s *ModelScaler) LookupModel(ctx context.Context, model, adapter string, labelSelectors []string) (bool, error) { - m := &kubeaiv1.Model{} - if err := s.client.Get(ctx, types.NamespacedName{Name: model, Namespace: s.namespace}, m); err != nil { - if apierrors.IsNotFound(err) { - return false, nil - } - return false, err - } - - modelLabels := m.GetLabels() - if modelLabels == nil { - modelLabels = map[string]string{} - } - for _, sel := range labelSelectors { - parsedSel, err := labels.Parse(sel) - if err != nil { - return false, fmt.Errorf("parse label selector: %w", err) - } - if !parsedSel.Matches(labels.Set(modelLabels)) { - return false, nil - } - } - - if adapter != "" { - adapterFound := false - for _, a := range m.Spec.Adapters { - if a.Name == adapter { - adapterFound = true - break - } - } - if !adapterFound { - return false, nil - } - } - - return true, nil -} - -func (s *ModelScaler) ListAllModels(ctx context.Context) ([]kubeaiv1.Model, error) { - models := &kubeaiv1.ModelList{} - if err := s.client.List(ctx, models, client.InNamespace(s.namespace)); err != nil { - return nil, fmt.Errorf("list models: %w", err) - } - - return models.Items, nil -} - -func (s *ModelScaler) ScaleAtLeastOneReplica(ctx context.Context, model string) error { - obj := &kubeaiv1.Model{} - if err := s.client.Get(ctx, types.NamespacedName{Namespace: s.namespace, Name: model}, obj); err != nil { - return fmt.Errorf("get scale: %w", err) - } - - if obj.Spec.AutoscalingDisabled { - return nil - } - - replicas := int32(0) - if obj.Spec.Replicas != nil { - replicas = *obj.Spec.Replicas - } - - if replicas == 0 && !obj.Spec.AutoscalingDisabled { - scale := &autoscalingv1.Scale{ - Spec: autoscalingv1.ScaleSpec{Replicas: 1}, - } - if err := s.client.SubResource("scale").Update(ctx, obj, client.WithSubResourceBody(scale)); err != nil { - return fmt.Errorf("update scale: %w", err) - } - } - - return nil -} - -// Scale scales the model to the desired number of replicas, enforcing the min and max replica bounds. -// Model should have .Spec defined before calling Scale(). -func (s *ModelScaler) Scale(ctx context.Context, model *kubeaiv1.Model, replicas int32, requiredConsecutiveScaleDowns int) error { - //obj := &kubeaiv1.Model{} - //if err := s.client.Get(ctx, types.NamespacedName{Namespace: s.namespace, Name: model}, obj); err != nil { - // return fmt.Errorf("get scale: %w", err) - //} - - replicas = enforceReplicaBounds(replicas, model) - - var existingReplicas int32 = 0 - if model.Spec.Replicas != nil { - existingReplicas = *model.Spec.Replicas - } - - if existingReplicas > replicas { - // Scale down - s.consecutiveScaleDownsMtx.RLock() - consec := s.consecutiveScaleDowns[model.Name] - s.consecutiveScaleDownsMtx.RUnlock() - if consec < requiredConsecutiveScaleDowns { - log.Printf("model %s has %d consecutive scale downs (< %d), not scaling down yet", model.Name, consec, requiredConsecutiveScaleDowns) - s.consecutiveScaleDownsMtx.Lock() - s.consecutiveScaleDowns[model.Name]++ - s.consecutiveScaleDownsMtx.Unlock() - return nil - } - } else { - // Scale up or constant scale. 
- s.consecutiveScaleDownsMtx.Lock() - s.consecutiveScaleDowns[model.Name] = 0 - s.consecutiveScaleDownsMtx.Unlock() - } - - if existingReplicas != replicas { - log.Printf("scaling model %s from %d to %d replicas", model.Name, existingReplicas, replicas) - scale := &autoscalingv1.Scale{ - Spec: autoscalingv1.ScaleSpec{Replicas: replicas}, - } - if err := s.client.SubResource("scale").Update(ctx, model, client.WithSubResourceBody(scale)); err != nil { - return fmt.Errorf("update scale: %w", err) - } - } - - return nil -} - -func enforceReplicaBounds(replicas int32, model *kubeaiv1.Model) int32 { - max := model.Spec.MaxReplicas - min := model.Spec.MinReplicas - if max != nil { - if replicas > *max { - return *max - } - } - if replicas < min { - return min - } - return replicas -}
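
Illustrative sketch (not part of this patch): how a caller could build the new loadbalancer.AddressRequest from a Model, falling back to the LeastLoad strategy when .spec.loadBalancing is unset. The resolveLoadBalancing helper, the fallback choice, and the CHWBL numbers mentioned in the comment are assumptions for illustration only; the call signature mirrors load_balancer_test.go.

package example // hypothetical package, not part of this change

import (
	"context"

	v1 "github.com/substratusai/kubeai/api/v1"
	"github.com/substratusai/kubeai/internal/loadbalancer"
)

// resolveLoadBalancing falls back to LeastLoad when the Model does not
// specify .spec.loadBalancing. A CHWBL Model would instead carry, e.g.,
// v1.LoadBalancing{Strategy: v1.CHWBLStrategy, CHWBL: &v1.CHWBL{MeanLoadFactor: 1.25, Replication: 100}}
// (example values, not defaults taken from this patch).
func resolveLoadBalancing(m *v1.Model) v1.LoadBalancing {
	if m.Spec.LoadBalancing != nil {
		return *m.Spec.LoadBalancing
	}
	return v1.LoadBalancing{Strategy: v1.LeastLoadStrategy}
}

// awaitAddress resolves an endpoint for one request. The returned release
// function must be called when the request completes so the in-flight count
// is decremented.
func awaitAddress(ctx context.Context, lb *loadbalancer.LoadBalancer, m *v1.Model, adapter string) (string, func(), error) {
	return lb.AwaitBestAddress(ctx, loadbalancer.AddressRequest{
		Model:         m.Name,
		Adapter:       adapter,
		LoadBalancing: resolveLoadBalancing(m),
	})
}

The LeastLoad fallback above is only an assumption: in this checkpoint the proxy and messenger construct AddressRequest without setting a strategy, and the patch does not yet wire a default into those paths.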