forked from alphagov/cdn-acceptance-tests
-
Notifications
You must be signed in to change notification settings - Fork 0
/
helpers.go
480 lines (407 loc) · 12.8 KB
/
helpers.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
package main
import (
"bytes"
"crypto/rand"
"crypto/tls"
"fmt"
"io/ioutil"
"log"
"mime"
"net"
"net/http"
"net/http/httptest"
"net/url"
"path/filepath"
"strconv"
"strings"
"testing"
"time"
)
// CDNBackendServer is a backend server which will receive and respond to
// requests from the CDN. It always serves HTTPS (see Start).
type CDNBackendServer struct {
	// Name is returned in the Backend-Name header of every response so tests
	// can tell which backend served a request.
	Name string
	// Port to listen on. If 0, the kernel picks a free port and Start
	// writes the chosen port back here.
	Port int
	// TLSCerts, if non-empty, is used instead of httptest's built-in
	// certificate.
	TLSCerts []tls.Certificate

	handler func(w http.ResponseWriter, r *http.Request)
	server  *httptest.Server
}

// ServeHTTP satisfies the http.Handler interface. Every response carries a
// Backend-Name header. HEAD requests are treated as health checks and always
// get an empty 200 response with a "PING: PONG" header; all other requests
// are passed to the handler installed by ResetHandler or SwitchHandler.
func (s *CDNBackendServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Backend-Name", s.Name)

	// Swallow health check requests so tests' custom handlers never see them.
	if r.Method == http.MethodHead {
		w.Header().Set("PING", "PONG")
		return
	}

	s.handler(w, r)
}

// ResetHandler sets the handler back to an empty function that will return
// a 200 response.
func (s *CDNBackendServer) ResetHandler() {
	s.handler = func(w http.ResponseWriter, r *http.Request) {}
}

// SwitchHandler sets the handler to a custom function. This is used by
// tests to pass in their own request inspection and response handler.
func (s *CDNBackendServer) SwitchHandler(h func(w http.ResponseWriter, r *http.Request)) {
	s.handler = h
}

// IsStarted checks whether the server is currently started.
func (s *CDNBackendServer) IsStarted() bool {
	return s.server != nil
}

// Stop closes all outstanding client connections and unbinds the port.
// Resets server back to nil, as if the backend had been instantiated but
// Start() not called. Calling Stop on a server that isn't started is a
// no-op.
func (s *CDNBackendServer) Stop() {
	if s.server == nil {
		return
	}
	s.server.Close()
	s.server = nil
}

// Start resets the handler back to the default and starts an HTTPS server
// on Port. If Port is 0 the kernel-assigned port is written back to Port.
// It exits the process (via log.Fatal) if it's unable to bind the port, due
// to permissions or a conflicting application.
func (s *CDNBackendServer) Start() {
	s.ResetHandler()

	ln, err := net.Listen("tcp", fmt.Sprintf(":%d", s.Port))
	if err != nil {
		log.Fatal(err)
	}

	// Store the port randomly assigned by the kernel if we started with 0.
	if s.Port == 0 {
		_, portStr, err := net.SplitHostPort(ln.Addr().String())
		if err != nil {
			log.Fatal(err)
		}
		if s.Port, err = strconv.Atoi(portStr); err != nil {
			log.Fatal(err)
		}
	}

	s.server = httptest.NewUnstartedServer(s)
	// NewUnstartedServer has already bound its own loopback listener. Close
	// it before substituting ours, otherwise it is leaked on every Start.
	s.server.Listener.Close()
	s.server.Listener = ln

	if len(s.TLSCerts) > 0 {
		s.server.TLS = &tls.Config{
			Certificates: s.TLSCerts,
		}
	}
	s.server.StartTLS()
	log.Printf("Started server on port %d", s.Port)
}
// CachedHostLookup caches DNS lookups for the given `Host` in order to
// prevent us switching to another edge location in the middle of tests.
type CachedHostLookup struct {
	Host         string
	hardCachedIP string
}

// lookup performs a DNS lookup and caches the first IP address returned.
// Subsequent requests always return the cached address, preventing further
// DNS requests. A failed lookup exits the process via log.Fatal.
func (c *CachedHostLookup) lookup(host string) string {
	if c.hardCachedIP == "" {
		ipAddresses, err := net.LookupHost(host)
		if err != nil {
			log.Fatal(err)
		}
		c.hardCachedIP = ipAddresses[0]
	}
	return c.hardCachedIP
}

// Dial acts as a wrapper for `net.Dial`, ostensibly for use with
// `http.Transport`. If the hostname matches `Host` then it will dial the
// cached address (on the requested port); any other host is dialled as-is.
// A malformed addr (e.g. missing port) is returned as an error.
func (c *CachedHostLookup) Dial(network, addr string) (net.Conn, error) {
	host, port, err := net.SplitHostPort(addr)
	if err != nil {
		// Report the bad address to the caller instead of killing the
		// whole test process.
		return nil, err
	}
	if host != c.Host {
		return net.Dial(network, addr)
	}
	return net.Dial(network, net.JoinHostPort(c.lookup(host), port))
}

// NewCachedDial returns the `Dial` function for a new CachedHostLookup
// object with the given host. The returned function keeps its own cache.
func NewCachedDial(host string) func(string, string) (net.Conn, error) {
	c := &CachedHostLookup{Host: host}
	return c.Dial
}
// NewUUID returns a random (version 4) UUID string in the canonical
// 8-4-4-4-12 lowercase hex form. The version nibble and the RFC 4122
// variant bits are set explicitly. Credit:
// https://groups.google.com/d/msg/golang-nuts/Rn13T6BZpgE/dBaYVJ4hB5gJ
//
// It exits the process (via log.Fatal) if the system random source fails,
// rather than silently returning a non-random value.
func NewUUID() string {
	bs := make([]byte, 16)
	if _, err := rand.Read(bs); err != nil {
		log.Fatal(err)
	}
	bs[6] = (bs[6] & 0x0f) | 0x40 // version 4 (random)
	bs[8] = (bs[8] & 0x3f) | 0x80 // RFC 4122 variant (binary 10xx)
	return fmt.Sprintf("%x-%x-%x-%x-%x", bs[0:4], bs[4:6], bs[6:8], bs[8:10], bs[10:])
}
// NewUniqueEdgeURL constructs a new HTTPS URL for the edge host. The path is
// always "/"; a random UUID is passed in the "nocache" query parameter to
// ensure the URL hasn't previously been cached. Using a query param on /
// (rather than a unique path) is so that some of the tests can be run
// against a service that hasn't been configured to point at our test
// backends.
func NewUniqueEdgeURL() string {
	// Named u rather than url so the net/url package isn't shadowed.
	u := url.URL{
		Scheme: "https",
		Host:   *edgeHost,
		Path:   "/",
		RawQuery: url.Values{
			"nocache": []string{NewUUID()},
		}.Encode(),
	}
	return u.String()
}
// NewUniqueEdgeGET builds, but does not send, a GET request for a fresh
// NewUniqueEdgeURL so the response can't already be in the edge's cache.
// Callers may change the returned request's Method (or other fields) before
// sending it. If the request can't be built, the calling test is aborted.
func NewUniqueEdgeGET(t *testing.T) *http.Request {
	req, err := http.NewRequest("GET", NewUniqueEdgeURL(), nil)
	if err != nil {
		t.Fatal(err)
	}
	return req
}
// RoundTripCheckError sends req with client.RoundTrip, which neither follows
// redirects nor handles cookies, and returns the response.
//
// A request taking longer than requestSlowThreshold is reported with
// t.Error. When *debugResp is set, the raw response value is logged. Any
// transport error aborts the calling test, so callers never receive a nil
// response.
func RoundTripCheckError(t *testing.T, req *http.Request) *http.Response {
	began := time.Now()
	resp, err := client.RoundTrip(req)
	elapsed := time.Since(began)

	if elapsed > requestSlowThreshold {
		t.Error("Slow request, took:", elapsed)
	}
	if *debugResp {
		t.Logf("%#v", resp)
	}
	if err != nil {
		t.Fatal(err)
	}
	return resp
}
// ResetBackends resets all backends, ensuring that they are started, have the
// default handler function, and that the edge considers them healthy. It may
// take some time because we need to receive and respond to enough probe
// health checks to be considered up.
//
// Backends are visited from the last element to the first (presumably the
// slice is ordered highest priority first — confirm with callers). Started
// backends just get their handler reset. On the first stopped backend found,
// every backend before it in the slice is stopped too, and each is then
// started and waited on one at a time, so that waitForBackend sees
// the backend it is waiting for. A failed wait exits the process.
func ResetBackends(backends []*CDNBackendServer) {
	othersStopped := false

	for idx := len(backends) - 1; idx >= 0; idx-- {
		b := backends[idx]

		if b.IsStarted() {
			b.ResetHandler()
			continue
		}

		if !othersStopped {
			// Ensure all remaining unchecked backends are stopped so that
			// waitForBackend will work. We'll bring them back one-by-one.
			stopBackends(backends[:idx])
			othersStopped = true
		}

		b.Start()
		if err := waitForBackend(b.Name); err != nil {
			log.Fatal(err)
		}
	}
}
// stopBackends stops every backend in the slice that is currently started;
// backends that are already stopped are left alone.
func stopBackends(backends []*CDNBackendServer) {
	for i := range backends {
		if b := backends[i]; b.IsStarted() {
			b.Stop()
		}
	}
}
// waitForBackend polls the edge until a response carries a Backend-Name
// header equal to expectedBackendName. This is designed to confirm that
// requests are hitting this specific backend, rather than a lower-level
// backend that this overrides (for example, origin over a mirror).
//
// It makes one initial attempt plus up to maxRetries retries, pausing
// between them. If the backend only appeared after at least one retry, it
// sleeps a little longer before returning (per the constant's name, to let
// the CDN's probe results propagate). It returns the first error from
// building or sending a request, or an error if the backend never appeared.
func waitForBackend(expectedBackendName string) error {
	const (
		maxRetries                 = 20
		maxAttempts                = maxRetries + 1
		waitForCdnProbeToPropagate = 5 * time.Second
		timeBetweenAttempts        = 2 * time.Second
	)

	log.Printf("Checking health of %s...", expectedBackendName)
	for try := 0; try < maxAttempts; try++ {
		req, err := http.NewRequest("GET", NewUniqueEdgeURL(), nil)
		if err != nil {
			return err
		}
		resp, err := client.RoundTrip(req)
		if err != nil {
			return err
		}
		resp.Body.Close()

		if resp.Header.Get("Backend-Name") == expectedBackendName {
			if try != 0 {
				time.Sleep(waitForCdnProbeToPropagate)
			}
			log.Println(expectedBackendName + " is up!")
			return nil // all is well!
		}

		// No point sleeping after the final attempt.
		if try < maxAttempts-1 {
			time.Sleep(timeBetweenAttempts)
		}
	}
	return fmt.Errorf(
		"%s still not available after %d attempts",
		expectedBackendName,
		maxAttempts,
	)
}
// responseCallback is invoked with the origin's ResponseWriter before the
// response body is written, letting a test modify the complete response
// (e.g. set headers or the status code).
type responseCallback func(w http.ResponseWriter)
// testRequestsCachedIndefinite runs testRequestsCachedDuration with a respTTL
// of zero, meaning the cached object is not expected to expire during the
// test: origin should see exactly one request and every response should be
// served from cache.
func testRequestsCachedIndefinite(
	t *testing.T,
	req *http.Request,
	respCB responseCallback,
) {
	const noExpiry time.Duration = 0
	testRequestsCachedDuration(t, req, respCB, noExpiry)
}
// testRequestsCachedDuration makes three requests and tests the responses.
// If respTTL is:
//
//   - zero: no delay between requests, origin should only see one request,
//     and all response bodies should be identical (from cache).
//   - non-zero: first and second request without delay, origin should only
//     see one request and response bodies should be identical, then after a
//     delay of respTTL plus a 25% buffer the third request should get a new
//     response directly from origin.
//
// A responseCallback, if not nil, will be called to modify the response
// before the origin calls Write(body).
func testRequestsCachedDuration(
	t *testing.T,
	req *http.Request,
	respCB responseCallback,
	respTTL time.Duration,
) {
	const responseCached = "first response"
	const responseNotCached = "subsequent response"

	testCacheExpiry := respTTL > 0
	// Wait an extra quarter of the TTL so the object has definitely expired.
	respTTLWithBuffer := respTTL + (respTTL / 4)

	requestsExpectedCount := 1
	if testCacheExpiry {
		requestsExpectedCount = 2
	}
	requestsReceivedCount := 0

	originServer.SwitchHandler(func(w http.ResponseWriter, r *http.Request) {
		if respCB != nil {
			respCB(w)
		}
		if requestsReceivedCount == 0 {
			w.Write([]byte(responseCached))
		} else {
			w.Write([]byte(responseNotCached))
		}
		requestsReceivedCount++
	})

	for requestCount := 1; requestCount <= 3; requestCount++ {
		if testCacheExpiry && requestCount == 3 {
			time.Sleep(respTTLWithBuffer)
		}

		resp := RoundTripCheckError(t, req)
		// Close each body as soon as it has been read instead of deferring
		// every Close until the function returns.
		body, err := ioutil.ReadAll(resp.Body)
		resp.Body.Close()
		if err != nil {
			t.Fatal(err)
		}

		expectedBody := responseCached
		if testCacheExpiry && requestCount > 2 {
			expectedBody = responseNotCached
		}

		if receivedBody := string(body); receivedBody != expectedBody {
			t.Errorf(
				"Request %d received incorrect response body. Expected %q, got %q",
				requestCount,
				expectedBody,
				receivedBody,
			)
		}
	}

	if requestsReceivedCount != requestsExpectedCount {
		t.Errorf(
			"Origin received the wrong number of requests. Expected %d, got %d",
			requestsExpectedCount,
			requestsReceivedCount,
		)
	}
}
// responseHeaderCallback is invoked with the origin response's headers before
// the body is written, letting a test modify them (e.g. set Cache-Control).
type responseHeaderCallback func(h http.Header)
// testThreeRequestsNotCached makes three requests and verifies that we get
// three unique and uncached responses back, i.e. each request reached
// origin. A responseHeaderCallback, if not nil, will be called to modify the
// response headers before origin writes the body.
//
// If origin receives more than three requests, the extra ones get a 500
// response instead of panicking in the handler, and the body mismatch is
// reported by the checks below.
func testThreeRequestsNotCached(t *testing.T, req *http.Request, headerCB responseHeaderCallback) {
	requestsReceivedCount := 0
	responseBodies := []string{
		"first response",
		"second response",
		"third response",
	}

	originServer.SwitchHandler(func(w http.ResponseWriter, r *http.Request) {
		if headerCB != nil {
			headerCB(w.Header())
		}
		if requestsReceivedCount < len(responseBodies) {
			w.Write([]byte(responseBodies[requestsReceivedCount]))
		} else {
			// Guard the index: an unexpected extra request must not panic.
			http.Error(w, "unexpected extra request to origin", http.StatusInternalServerError)
		}
		requestsReceivedCount++
	})

	for requestCount, expectedBody := range responseBodies {
		resp := RoundTripCheckError(t, req)
		// Close each body once read instead of deferring to function exit.
		body, err := ioutil.ReadAll(resp.Body)
		resp.Body.Close()
		if err != nil {
			t.Fatal(err)
		}

		if receivedBody := string(body); receivedBody != expectedBody {
			t.Errorf(
				"Request %d received incorrect response body. Expected %q, got %q",
				requestCount+1,
				expectedBody,
				receivedBody,
			)
		}
	}
}
// testResponseNotManipulated configures origin to respond to a request with
// the contents of fixtureFile. It then makes a request and asserts that the
// response body matches the original fixture file byte-for-byte, meaning that
// the CDN hasn't manipulated it in any way. The `Content-Type` and request
// path are set according to the fixture's file extension to ensure that the
// CDN detects it correctly.
//
// The test is aborted if the fixture can't be read or if its extension maps
// to no MIME type or to text/plain.
func testResponseNotManipulated(t *testing.T, fixtureFile string) {
	fixtureData, err := ioutil.ReadFile(fixtureFile)
	if err != nil {
		t.Fatalf("Unable to load fixture file %q: %v", fixtureFile, err)
	}

	contentType := mime.TypeByExtension(filepath.Ext(fixtureFile))
	if contentType == "" || strings.Contains(contentType, "text/plain") {
		t.Fatalf("Unable to determine fixture Content-Type. Got %q", contentType)
	}

	originServer.SwitchHandler(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", contentType)
		w.Write(fixtureData)
	})

	req := NewUniqueEdgeGET(t)
	req.URL.Path = "/" + filepath.Base(fixtureFile)

	resp := RoundTripCheckError(t, req)
	defer resp.Body.Close()

	body, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		t.Fatal(err)
	}

	if !bytes.Equal(body, fixtureData) {
		t.Error("Response body did not match fixture")
		// Sizes help spot truncation or compression without dumping bodies.
		t.Errorf(
			"Response body sizes for debug purposes. Expected %d, got %d",
			len(fixtureData),
			len(body),
		)
	}
}