Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Spike: Add integration test #65

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 120 additions & 24 deletions tests/integration/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package integration
import (
"bytes"
"fmt"
"io"
"log"
"net/http"
"net/http/httptest"
Expand Down Expand Up @@ -43,17 +44,7 @@ func TestScaleUpAndDown(t *testing.T) {
}))

// Mock an EndpointSlice.
testBackendURL, err := url.Parse(testBackend.URL)
require.NoError(t, err)
testBackendPort, err := strconv.Atoi(testBackendURL.Port())
require.NoError(t, err)
require.NoError(t, testK8sClient.Create(testCtx,
endpointSlice(
modelName,
testBackendURL.Hostname(),
int32(testBackendPort),
),
))
withMockEndpointSlice(t, testBackend, modelName)

// Wait for deployment mapping to sync.
time.Sleep(3 * time.Second)
Expand Down Expand Up @@ -103,17 +94,7 @@ func TestHandleModelUndeployment(t *testing.T) {
}))

// Mock an EndpointSlice.
testBackendURL, err := url.Parse(testBackend.URL)
require.NoError(t, err)
testBackendPort, err := strconv.Atoi(testBackendURL.Port())
require.NoError(t, err)
require.NoError(t, testK8sClient.Create(testCtx,
endpointSlice(
modelName,
testBackendURL.Hostname(),
int32(testBackendPort),
),
))
withMockEndpointSlice(t, testBackend, modelName)

// Wait for deployment mapping to sync.
time.Sleep(3 * time.Second)
Expand All @@ -132,7 +113,7 @@ func TestHandleModelUndeployment(t *testing.T) {
require.NoError(t, testK8sClient.Delete(testCtx, deploy))

// Check that the deployment was deleted
err = testK8sClient.Get(testCtx, client.ObjectKey{
err := testK8sClient.Get(testCtx, client.ObjectKey{
Namespace: deploy.Namespace,
Name: deploy.Name,
}, deploy)
Expand All @@ -151,6 +132,107 @@ func TestHandleModelUndeployment(t *testing.T) {
wg.Wait()
}

func TestRetryMiddleware(t *testing.T) {
const modelName = "test-model-c"
deploy := testDeployment(modelName)
require.NoError(t, testK8sClient.Create(testCtx, deploy))

// Wait for deployment mapping to sync.
time.Sleep(3 * time.Second)
backendRequests := &atomic.Int32{}
var serverCodes []int
testBackend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
expBody := []byte(fmt.Sprintf(`{"model": %q}`, modelName))
gotBody, err := io.ReadAll(r.Body)
require.NoError(t, err)
assert.Equal(t, expBody, gotBody)

i := backendRequests.Add(1)
code := serverCodes[i-1]
t.Logf("Serving request from testBackend: %d; code: %d\n", i, code)
w.WriteHeader(code)
_, err = w.Write([]byte(strconv.Itoa(code)))
require.NoError(t, err)
}))

// Mock an EndpointSlice.
withMockEndpointSlice(t, testBackend, modelName)

specs := map[string]struct {
serverCodes []int
header []tuple
expResultCode int
expResultBody string
expBackendHits int32
}{
"max retries - succeeds": {
serverCodes: []int{http.StatusBadGateway, http.StatusServiceUnavailable, http.StatusGatewayTimeout, http.StatusOK},
expResultCode: http.StatusOK,
expResultBody: "200",
expBackendHits: 4,
},
"max retries - fails": {
serverCodes: []int{http.StatusServiceUnavailable, http.StatusServiceUnavailable, http.StatusServiceUnavailable, http.StatusBadGateway},
expResultCode: http.StatusBadGateway,
expResultBody: "502",
expBackendHits: 4,
},
"non retryable error code": {
serverCodes: []int{http.StatusNotImplemented},
expResultCode: http.StatusNotImplemented,
expResultBody: "501",
expBackendHits: 1,
},
"200 status code": {
serverCodes: []int{http.StatusOK},
expResultCode: http.StatusOK,
expResultBody: "200",
expBackendHits: 1,
},
"200 status code - model header": {
serverCodes: []int{http.StatusOK},
header: []tuple{{k: "X-Model", v: modelName}},
expResultCode: http.StatusOK,
expResultBody: "200",
expBackendHits: 1,
},
"503 with model header": {
serverCodes: []int{http.StatusServiceUnavailable, http.StatusServiceUnavailable, http.StatusServiceUnavailable, http.StatusServiceUnavailable},
header: []tuple{{k: "X-Model", v: modelName}},
expResultCode: http.StatusServiceUnavailable,
expResultBody: "503",
expBackendHits: 4,
},
}
for name, spec := range specs {
t.Run(name, func(t *testing.T) {
// setup
serverCodes = spec.serverCodes
backendRequests.Store(0)

// when single request sent
gotBody := <-sendRequest(t, &sync.WaitGroup{}, modelName, spec.expResultCode, spec.header...)
// then only the last body is written
assert.Equal(t, spec.expResultBody, gotBody)
require.Equal(t, spec.expBackendHits, backendRequests.Load(), "ensure backend hit")
})
}
}

func withMockEndpointSlice(t *testing.T, testBackend *httptest.Server, modelName string) {
testBackendURL, err := url.Parse(testBackend.URL)
require.NoError(t, err)
testBackendPort, err := strconv.Atoi(testBackendURL.Port())
require.NoError(t, err)
require.NoError(t, testK8sClient.Create(testCtx,
endpointSlice(
modelName,
testBackendURL.Hostname(),
int32(testBackendPort),
),
))
}

func requireDeploymentReplicas(t *testing.T, deploy *appsv1.Deployment, n int32) {
require.EventuallyWithT(t, func(t *assert.CollectT) {
err := testK8sClient.Get(testCtx, types.NamespacedName{Namespace: deploy.Namespace, Name: deploy.Name}, deploy)
Expand All @@ -166,20 +248,34 @@ func sendRequests(t *testing.T, wg *sync.WaitGroup, modelName string, n int, exp
}
}

func sendRequest(t *testing.T, wg *sync.WaitGroup, modelName string, expCode int) {
type tuple struct {
k, v string
}

func sendRequest(t *testing.T, wg *sync.WaitGroup, modelName string, expCode int, headers ...tuple) <-chan string {
t.Helper()
wg.Add(1)
bodyRespChan := make(chan string, 1)
go func() {
defer wg.Done()
defer close(bodyRespChan)

body := []byte(fmt.Sprintf(`{"model": %q}`, modelName))
req, err := http.NewRequest(http.MethodPost, testServer.URL, bytes.NewReader(body))
requireNoError(err)
for _, e := range headers {
req.Header.Add(e.k, e.v)
}

res, err := testHTTPClient.Do(req)
require.NoError(t, err)
require.Equal(t, expCode, res.StatusCode)
got, err := io.ReadAll(res.Body)
_ = res.Body.Close()
require.NoError(t, err)
bodyRespChan <- string(got)
}()
return bodyRespChan
}

func completeRequests(c chan struct{}, n int) {
Expand Down
1 change: 1 addition & 0 deletions tests/integration/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ func TestMain(m *testing.M) {
Deployments: deploymentManager,
Endpoints: endpointManager,
Queues: queueManager,
MaxRetries: 3,
}
testServer = httptest.NewServer(handler)
defer testServer.Close()
Expand Down
Loading