From bc4ba0df53feb27e7b10923008b2d25ef985807b Mon Sep 17 00:00:00 2001 From: Alexander Peters Date: Thu, 21 Dec 2023 15:23:49 +0100 Subject: [PATCH] Handle model undeployment (#44) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 🚧 spike Simple version to fail the request fast when a model backed was undeployed while queued. Not handled is the case when a model was removed from the deployment annotation --- pkg/deployments/manager.go | 22 ++++++- pkg/proxy/handler.go | 8 +++ tests/integration/integration_test.go | 85 ++++++++++++++++++++++++--- tests/integration/main_test.go | 3 +- 4 files changed, 107 insertions(+), 11 deletions(-) diff --git a/pkg/deployments/manager.go b/pkg/deployments/manager.go index 02446fad..feebfcc1 100644 --- a/pkg/deployments/manager.go +++ b/pkg/deployments/manager.go @@ -11,8 +11,8 @@ import ( appsv1 "k8s.io/api/apps/v1" autoscalingv1 "k8s.io/api/autoscaling/v1" + apierrors "k8s.io/apimachinery/pkg/api/errors" "k8s.io/apimachinery/pkg/types" - ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/client" ) @@ -65,7 +65,11 @@ func (r *Manager) SetDesiredScale(deploymentName string, n int32) { func (r *Manager) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) { var d appsv1.Deployment - if err := r.Get(ctx, req.NamespacedName, &d); err != nil { + switch err := r.Get(ctx, req.NamespacedName, &d); { + case apierrors.IsNotFound(err): + r.removeDeployment(req) + return ctrl.Result{}, nil + case err != nil: return ctrl.Result{}, fmt.Errorf("get: %w", err) } @@ -98,6 +102,20 @@ func (r *Manager) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, return ctrl.Result{}, nil } +func (r *Manager) removeDeployment(req ctrl.Request) { + r.scalersMtx.Lock() + delete(r.scalers, req.Name) + r.scalersMtx.Unlock() + + r.modelToDeploymentMtx.Lock() + for model, deployment := range r.modelToDeployment { + if deployment == req.Name { + delete(r.modelToDeployment, model) + } + } + r.modelToDeploymentMtx.Unlock() +} + func (r *Manager) getScaler(deploymentName string) *scaler { r.scalersMtx.Lock() b, ok := r.scalers[deploymentName] diff --git a/pkg/proxy/handler.go b/pkg/proxy/handler.go index 36bfa616..7c32ba7b 100644 --- a/pkg/proxy/handler.go +++ b/pkg/proxy/handler.go @@ -74,6 +74,14 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { log.Println("Admitted into queue", id) defer complete() + // abort when deployment was removed meanwhile + if _, exists := h.Deployments.ResolveDeployment(modelName); !exists { + log.Printf("deployment not active for model removed: %v", err) + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(fmt.Sprintf("Deployment for model not found: %v", modelName))) + return + } + log.Println("Waiting for IPs", id) host := h.Endpoints.GetHost(r.Context(), deploy, "http") log.Printf("Got host: %v, id: %v\n", host, id) diff --git a/tests/integration/integration_test.go b/tests/integration/integration_test.go index 59c2f02f..10fb5e45 100644 --- a/tests/integration/integration_test.go +++ b/tests/integration/integration_test.go @@ -13,17 +13,20 @@ import ( "testing" "time" + "sigs.k8s.io/controller-runtime/pkg/client" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" appsv1 "k8s.io/api/apps/v1" corev1 "k8s.io/api/core/v1" disv1 "k8s.io/api/discovery/v1" + apierrors "k8s.io/apimachinery/pkg/api/errors" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/types" "k8s.io/utils/ptr" ) -func TestIntegration(t *testing.T) { +func TestScaleUpAndDown(t *testing.T) { const modelName = "test-model-a" deploy := testDeployment(modelName) @@ -57,7 +60,7 @@ func TestIntegration(t *testing.T) { // Send request number 1 var wg sync.WaitGroup - sendRequests(t, &wg, modelName, 1) + sendRequests(t, &wg, modelName, 1, http.StatusOK) requireDeploymentReplicas(t, deploy, 1) require.Equal(t, int32(1), backendRequests.Load(), "ensure the request made its way to the backend") @@ -66,11 +69,11 @@ func TestIntegration(t *testing.T) { // Ensure the deployment scaled scaled past 1. // 1/2 should be admitted // 1/2 should remain in queue - sendRequests(t, &wg, modelName, 2) + sendRequests(t, &wg, modelName, 2, http.StatusOK) requireDeploymentReplicas(t, deploy, 2) // Make sure deployment will not be scaled past default max (3). - sendRequests(t, &wg, modelName, 2) + sendRequests(t, &wg, modelName, 2, http.StatusOK) requireDeploymentReplicas(t, deploy, 3) // Have the mock backend respond to the remaining 4 requests. @@ -83,6 +86,71 @@ func TestIntegration(t *testing.T) { wg.Wait() } +func TestHandleModelUndeployment(t *testing.T) { + const modelName = "test-model-b" + deploy := testDeployment(modelName) + + require.NoError(t, testK8sClient.Create(testCtx, deploy)) + + backendComplete := make(chan struct{}) + + backendRequests := &atomic.Int32{} + testBackend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + log.Println("Serving request from testBackend") + backendRequests.Add(1) + <-backendComplete + w.WriteHeader(200) + })) + + // Mock an EndpointSlice. + testBackendURL, err := url.Parse(testBackend.URL) + require.NoError(t, err) + testBackendPort, err := strconv.Atoi(testBackendURL.Port()) + require.NoError(t, err) + require.NoError(t, testK8sClient.Create(testCtx, + endpointSlice( + modelName, + testBackendURL.Hostname(), + int32(testBackendPort), + ), + )) + + // Wait for deployment mapping to sync. + time.Sleep(3 * time.Second) + + // Send request number 1 + var wg sync.WaitGroup + // send single request to scale up and block on the handler to build a queue + sendRequests(t, &wg, modelName, 1, http.StatusOK) + + requireDeploymentReplicas(t, deploy, 1) + require.Equal(t, int32(1), backendRequests.Load(), "ensure the request made its way to the backend") + // Add some more requests to the queue but with 404 expected + // because the deployment is deleted before un-queued + sendRequests(t, &wg, modelName, 2, http.StatusNotFound) + + require.NoError(t, testK8sClient.Delete(testCtx, deploy)) + + // Check that the deployment was deleted + err = testK8sClient.Get(testCtx, client.ObjectKey{ + Namespace: deploy.Namespace, + Name: deploy.Name, + }, deploy) + + // ErrNotFound is desired since we delete the resource earlier + assert.True(t, apierrors.IsNotFound(err)) + // release blocked request + completeRequests(backendComplete, 1) + + // Wait for deployment mapping to sync. + require.Eventually(t, func() bool { + return queueManager.TotalCounts()[modelName+"-deploy"] == 0 + }, 3*time.Second, 100*time.Millisecond) + + t.Logf("Waiting for wait group") + wg.Wait() +} + func requireDeploymentReplicas(t *testing.T, deploy *appsv1.Deployment, n int32) { require.EventuallyWithT(t, func(t *assert.CollectT) { err := testK8sClient.Get(testCtx, types.NamespacedName{Namespace: deploy.Namespace, Name: deploy.Name}, deploy) @@ -92,13 +160,14 @@ func requireDeploymentReplicas(t *testing.T, deploy *appsv1.Deployment, n int32) }, 3*time.Second, time.Second/2, "waiting for the deployment to be scaled up") } -func sendRequests(t *testing.T, wg *sync.WaitGroup, modelName string, n int) { +func sendRequests(t *testing.T, wg *sync.WaitGroup, modelName string, n int, expCode int) { for i := 0; i < n; i++ { - sendRequest(t, wg, modelName) + sendRequest(t, wg, modelName, expCode) } } -func sendRequest(t *testing.T, wg *sync.WaitGroup, modelName string) { +func sendRequest(t *testing.T, wg *sync.WaitGroup, modelName string, expCode int) { + t.Helper() wg.Add(1) go func() { defer wg.Done() @@ -109,7 +178,7 @@ func sendRequest(t *testing.T, wg *sync.WaitGroup, modelName string) { res, err := testHTTPClient.Do(req) require.NoError(t, err) - require.Equal(t, 200, res.StatusCode) + require.Equal(t, expCode, res.StatusCode) }() } diff --git a/tests/integration/main_test.go b/tests/integration/main_test.go index b628bb9f..74697559 100644 --- a/tests/integration/main_test.go +++ b/tests/integration/main_test.go @@ -38,6 +38,7 @@ var ( testCancel context.CancelFunc testServer *httptest.Server testHTTPClient = &http.Client{Timeout: 10 * time.Second} + queueManager *queue.Manager ) func init() { @@ -78,7 +79,7 @@ func TestMain(m *testing.M) { requireNoError(err) const concurrencyPerReplica = 1 - queueManager := queue.NewManager(concurrencyPerReplica) + queueManager = queue.NewManager(concurrencyPerReplica) endpointManager, err := endpoints.NewManager(mgr) requireNoError(err)