diff --git a/pkg/cache/v3/delta_test.go b/pkg/cache/v3/delta_test.go index 22af7e80fc..9e4825bed5 100644 --- a/pkg/cache/v3/delta_test.go +++ b/pkg/cache/v3/delta_test.go @@ -24,9 +24,7 @@ import ( func assertResourceMapEqual(t *testing.T, want, got map[string]types.Resource) { t.Helper() - if !cmp.Equal(want, got, protocmp.Transform()) { - t.Errorf("got resources %v, want %v", got, want) - } + assert.Truef(t, cmp.Equal(want, got, protocmp.Transform()), "got resources %v, want %v", got, want) } func TestSnapshotCacheDeltaWatch(t *testing.T) { @@ -45,9 +43,7 @@ func TestSnapshotCacheDeltaWatch(t *testing.T) { }, stream.NewStreamState(true, nil), watches[typ]) } - if err := c.SetSnapshot(context.Background(), key, fixture.snapshot()); err != nil { - t.Fatal(err) - } + require.NoError(t, c.SetSnapshot(context.Background(), key, fixture.snapshot())) versionMap := make(map[string]map[string]string) for _, typ := range testTypes { @@ -81,19 +77,15 @@ func TestSnapshotCacheDeltaWatch(t *testing.T) { }, state, watches[typ]) } - if count := c.GetStatusInfo(key).GetNumDeltaWatches(); count != len(testTypes) { - t.Errorf("watches should be created for the latest version, saw %d watches expected %d", count, len(testTypes)) - } + count := c.GetStatusInfo(key).GetNumDeltaWatches() + assert.Lenf(t, testTypes, count, "watches should be created for the latest version, saw %d watches expected %d", count, len(testTypes)) // set partially-versioned snapshot snapshot2 := fixture.snapshot() snapshot2.Resources[types.Endpoint] = cache.NewResources(fixture.version2, []types.Resource{resource.MakeEndpoint(clusterName, 9090)}) - if err := c.SetSnapshot(context.Background(), key, snapshot2); err != nil { - t.Fatal(err) - } - if count := c.GetStatusInfo(key).GetNumDeltaWatches(); count != len(testTypes)-1 { - t.Errorf("watches should be preserved for all but one, got: %d open watches instead of the expected %d open watches", count, len(testTypes)-1) - } + require.NoError(t, c.SetSnapshot(context.Background(), key, snapshot2)) + count = c.GetStatusInfo(key).GetNumDeltaWatches() + assert.Equalf(t, count, len(testTypes)-1, "watches should be preserved for all but one, got: %d open watches instead of the expected %d open watches", count, len(testTypes)-1) // validate response for endpoints select { @@ -127,9 +119,7 @@ func TestDeltaRemoveResources(t *testing.T) { }, *streams[typ], watches[typ]) } - if err := c.SetSnapshot(context.Background(), key, fixture.snapshot()); err != nil { - t.Fatal(err) - } + require.NoError(t, c.SetSnapshot(context.Background(), key, fixture.snapshot())) for _, typ := range testTypes { t.Run(typ, func(t *testing.T) { @@ -157,16 +147,13 @@ func TestDeltaRemoveResources(t *testing.T) { }, *streams[typ], watches[typ]) } - if count := c.GetStatusInfo(key).GetNumDeltaWatches(); count != len(testTypes) { - t.Errorf("watches should be created for the latest version, saw %d watches expected %d", count, len(testTypes)) - } + count := c.GetStatusInfo(key).GetNumDeltaWatches() + assert.Lenf(t, testTypes, count, "watches should be created for the latest version, saw %d watches expected %d", count, len(testTypes)) // set a partially versioned snapshot with no endpoints snapshot2 := fixture.snapshot() snapshot2.Resources[types.Endpoint] = cache.NewResources(fixture.version2, []types.Resource{}) - if err := c.SetSnapshot(context.Background(), key, snapshot2); err != nil { - t.Fatal(err) - } + require.NoError(t, c.SetSnapshot(context.Background(), key, snapshot2)) // validate response for endpoints select { @@ -177,9 +164,7 @@ func TestDeltaRemoveResources(t *testing.T) { nextVersionMap := out.GetNextVersionMap() // make sure the version maps are different since we no longer are tracking any endpoint resources - if reflect.DeepEqual(streams[testTypes[0]].GetResourceVersions(), nextVersionMap) { - t.Fatalf("versionMap for the endpoint resource type did not change, received: %v, instead of an empty map", nextVersionMap) - } + require.Falsef(t, reflect.DeepEqual(streams[testTypes[0]].GetResourceVersions(), nextVersionMap), "versionMap for the endpoint resource type did not change, received: %v, instead of an empty map", nextVersionMap) case <-time.After(time.Second): t.Fatal("failed to receive snapshot response") } @@ -196,13 +181,10 @@ func TestConcurrentSetDeltaWatch(t *testing.T) { responses := make(chan cache.DeltaResponse, 1) if i < 25 { snap, err := cache.NewSnapshot("", map[rsrc.Type][]types.Resource{}) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) snap.Resources[types.Endpoint] = cache.NewResources(version, []types.Resource{resource.MakeEndpoint(clusterName, uint32(i))}) - if err := c.SetSnapshot(context.Background(), key, snap); err != nil { - t.Fatalf("snapshot failed: %s", err) - } + err = c.SetSnapshot(context.Background(), key, snap) + require.NoErrorf(t, err, "snapshot failed") } else { cancel := c.CreateDeltaWatch(&discovery.DeltaDiscoveryRequest{ Node: &core.Node{ @@ -282,17 +264,14 @@ func TestSnapshotCacheDeltaWatchCancel(t *testing.T) { cancel() } // c.GetStatusKeys() should return at least 1 because we register a node ID with the above watch creations - if keys := c.GetStatusKeys(); len(keys) == 0 { - t.Errorf("expected to see a status info registered for watch, saw %d entries", len(keys)) - } + keys := c.GetStatusKeys() + assert.NotEmptyf(t, keys, "expected to see a status info registered for watch, saw %d entries", len(keys)) for _, typ := range testTypes { - if count := c.GetStatusInfo(key).GetNumDeltaWatches(); count > 0 { - t.Errorf("watches should be released for %s", typ) - } + count := c.GetStatusInfo(key).GetNumDeltaWatches() + assert.LessOrEqualf(t, count, 0, "watches should be released for %s", typ) } - if s := c.GetStatusInfo("missing"); s != nil { - t.Errorf("should not return a status for unknown key: got %#v", s) - } + s := c.GetStatusInfo("missing") + assert.Nilf(t, s, "should not return a status for unknown key: got %#v", s) } diff --git a/pkg/cache/v3/linear_test.go b/pkg/cache/v3/linear_test.go index 01ef92f9a0..4327faf14b 100644 --- a/pkg/cache/v3/linear_test.go +++ b/pkg/cache/v3/linear_test.go @@ -41,28 +41,17 @@ func testResource(s string) types.Resource { func verifyResponse(t *testing.T, ch <-chan Response, version string, num int) { t.Helper() r := <-ch - if r.GetRequest().GetTypeUrl() != testType { - t.Errorf("unexpected empty request type URL: %q", r.GetRequest().GetTypeUrl()) - } - if r.GetContext() == nil { - t.Errorf("unexpected empty response context") - } + assert.Equalf(t, testType, r.GetRequest().GetTypeUrl(), "unexpected empty request type URL: %q", r.GetRequest().GetTypeUrl()) + assert.NotNilf(t, r.GetContext(), "unexpected empty response context") out, err := r.GetDiscoveryResponse() - if err != nil { - t.Fatal(err) - } - if out.GetVersionInfo() == "" { - t.Error("unexpected response empty version") - } - if n := len(out.GetResources()); n != num { - t.Errorf("unexpected number of responses: got %d, want %d", n, num) - } - if version != "" && out.GetVersionInfo() != version { - t.Errorf("unexpected version: got %q, want %q", out.GetVersionInfo(), version) - } - if out.GetTypeUrl() != testType { - t.Errorf("unexpected type URL: %q", out.GetTypeUrl()) + require.NoError(t, err) + assert.NotEqualf(t, "", out.GetVersionInfo(), "unexpected response empty version") + n := len(out.GetResources()) + assert.Equalf(t, n, num, "unexpected number of responses: got %d, want %d", n, num) + if version != "" { + assert.Equalf(t, out.GetVersionInfo(), version, "unexpected version: got %q, want %q", out.GetVersionInfo(), version) } + assert.Equalf(t, testType, out.GetTypeUrl(), "unexpected type URL: %q", out.GetTypeUrl()) } type resourceInfo struct { @@ -73,16 +62,10 @@ type resourceInfo struct { func validateDeltaResponse(t *testing.T, resp DeltaResponse, resources []resourceInfo, deleted []string) { t.Helper() - if resp.GetDeltaRequest().GetTypeUrl() != testType { - t.Errorf("unexpected empty request type URL: %q", resp.GetDeltaRequest().GetTypeUrl()) - } + assert.Equalf(t, testType, resp.GetDeltaRequest().GetTypeUrl(), "unexpected empty request type URL: %q", resp.GetDeltaRequest().GetTypeUrl()) out, err := resp.GetDeltaDiscoveryResponse() - if err != nil { - t.Fatal(err) - } - if len(out.GetResources()) != len(resources) { - t.Errorf("unexpected number of responses: got %d, want %d", len(out.GetResources()), len(resources)) - } + require.NoError(t, err) + assert.Equalf(t, len(out.GetResources()), len(resources), "unexpected number of responses: got %d, want %d", len(out.GetResources()), len(resources)) for _, r := range resources { found := false for _, r1 := range out.GetResources() { @@ -95,16 +78,10 @@ func validateDeltaResponse(t *testing.T, resp DeltaResponse, resources []resourc break } } - if !found { - t.Errorf("resource with name %q not found in response", r.name) - } - } - if out.GetTypeUrl() != testType { - t.Errorf("unexpected type URL: %q", out.GetTypeUrl()) - } - if len(out.GetRemovedResources()) != len(deleted) { - t.Errorf("unexpected number of removed resurces: got %d, want %d", len(out.GetRemovedResources()), len(deleted)) + assert.Truef(t, found, "resource with name %q not found in response", r.name) } + assert.Equalf(t, testType, out.GetTypeUrl(), "unexpected type URL: %q", out.GetTypeUrl()) + assert.Equalf(t, len(out.GetRemovedResources()), len(deleted), "unexpected number of removed resurces: got %d, want %d", len(out.GetRemovedResources()), len(deleted)) for _, r := range deleted { found := false for _, rr := range out.GetRemovedResources() { @@ -113,9 +90,7 @@ func validateDeltaResponse(t *testing.T, resp DeltaResponse, resources []resourc break } } - if !found { - t.Errorf("Expected resource %s to be deleted", r) - } + assert.Truef(t, found, "Expected resource %s to be deleted", r) } } @@ -134,32 +109,25 @@ func verifyDeltaResponse(t *testing.T, ch <-chan DeltaResponse, resources []reso func checkWatchCount(t *testing.T, c *LinearCache, name string, count int) { t.Helper() - if i := c.NumWatches(name); i != count { - t.Errorf("unexpected number of watches for %q: got %d, want %d", name, i, count) - } + i := c.NumWatches(name) + assert.Equalf(t, i, count, "unexpected number of watches for %q: got %d, want %d", name, i, count) } func checkDeltaWatchCount(t *testing.T, c *LinearCache, count int) { t.Helper() - if i := c.NumDeltaWatches(); i != count { - t.Errorf("unexpected number of delta watches: got %d, want %d", i, count) - } + i := c.NumDeltaWatches() + assert.Equalf(t, i, count, "unexpected number of delta watches: got %d, want %d", i, count) } func checkVersionMapNotSet(t *testing.T, c *LinearCache) { t.Helper() - if c.versionMap != nil { - t.Errorf("version map is set on the cache with %d elements", len(c.versionMap)) - } + assert.Nilf(t, c.versionMap, "version map is set on the cache with %d elements", len(c.versionMap)) } func checkVersionMapSet(t *testing.T, c *LinearCache) { t.Helper() - if c.versionMap == nil { - t.Errorf("version map is not set on the cache") - } else if len(c.versionMap) != len(c.resources) { - t.Errorf("version map has the wrong number of elements: %d instead of %d expected", len(c.versionMap), len(c.resources)) - } + assert.NotNilf(t, c.versionMap, "version map is not set on the cache") + assert.Equalf(t, len(c.versionMap), len(c.resources), "version map has the wrong number of elements: %d instead of %d expected", len(c.versionMap), len(c.resources)) } func mustBlock(t *testing.T, w <-chan Response) { @@ -180,9 +148,7 @@ func mustBlockDelta(t *testing.T, w <-chan DeltaResponse) { func hashResource(t *testing.T, resource types.Resource) string { marshaledResource, err := MarshalResource(resource) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) v := HashResource(marshaledResource) if v == "" { t.Fatal(errors.New("failed to build resource version")) @@ -213,17 +179,13 @@ func TestLinearCornerCases(t *testing.T) { streamState := stream.NewStreamState(false, map[string]string{}) c := NewLinearCache(testType) err := c.UpdateResource("a", nil) - if err == nil { - t.Error("expected error on nil resource") - } + require.Errorf(t, err, "expected error on nil resource") // create an incorrect type URL request w := make(chan Response, 1) c.CreateWatch(&Request{TypeUrl: "test"}, streamState, w) select { case r := <-w: - if r != nil { - t.Error("response should be nil") - } + assert.Nilf(t, r, "response should be nil") default: t.Error("should receive nil response") } @@ -325,9 +287,7 @@ func TestLinearGetResources(t *testing.T) { resources := c.GetResources() - if !reflect.DeepEqual(expectedResources, resources) { - t.Errorf("resources are not equal. got: %v want: %v", resources, expectedResources) - } + assert.Truef(t, reflect.DeepEqual(expectedResources, resources), "resources are not equal. got: %v want: %v", resources, expectedResources) } func TestLinearVersionPrefix(t *testing.T) { diff --git a/pkg/cache/v3/resource_test.go b/pkg/cache/v3/resource_test.go index 0f311af5c8..cfbe2db35b 100644 --- a/pkg/cache/v3/resource_test.go +++ b/pkg/cache/v3/resource_test.go @@ -77,12 +77,10 @@ func TestValidate(t *testing.T) { }}, } - if err := invalidRoute.Validate(); err == nil { - t.Error("expected an error") - } - if err := invalidRoute.GetVirtualHosts()[0].Validate(); err == nil { - t.Error("expected an error") - } + err := invalidRoute.Validate() + require.Errorf(t, err, "expected an error") + err = invalidRoute.GetVirtualHosts()[0].Validate() + assert.Errorf(t, err, "expected an error") } type customResource struct { @@ -96,33 +94,24 @@ func (cs *customResource) GetName() string { return customName } var _ types.ResourceWithName = &customResource{} func TestGetResourceName(t *testing.T) { - if name := cache.GetResourceName(testEndpoint); name != clusterName { - t.Errorf("GetResourceName(%v) => got %q, want %q", testEndpoint, name, clusterName) - } - if name := cache.GetResourceName(testCluster); name != clusterName { - t.Errorf("GetResourceName(%v) => got %q, want %q", testCluster, name, clusterName) - } - if name := cache.GetResourceName(testRoute); name != routeName { - t.Errorf("GetResourceName(%v) => got %q, want %q", testRoute, name, routeName) - } - if name := cache.GetResourceName(testScopedRoute); name != scopedRouteName { - t.Errorf("GetResourceName(%v) => got %q, want %q", testScopedRoute, name, scopedRouteName) - } - if name := cache.GetResourceName(testVirtualHost); name != virtualHostName { - t.Errorf("GetResourceName(%v) => got %q, want %q", testVirtualHost, name, virtualHostName) - } - if name := cache.GetResourceName(testListener); name != listenerName { - t.Errorf("GetResourceName(%v) => got %q, want %q", testListener, name, listenerName) - } - if name := cache.GetResourceName(testRuntime); name != runtimeName { - t.Errorf("GetResourceName(%v) => got %q, want %q", testRuntime, name, runtimeName) - } - if name := cache.GetResourceName(&customResource{}); name != customName { - t.Errorf("GetResourceName(nil) => got %q, want %q", name, customName) - } - if name := cache.GetResourceName(nil); name != "" { - t.Errorf("GetResourceName(nil) => got %q, want none", name) - } + name := cache.GetResourceName(testEndpoint) + assert.Equalf(t, clusterName, name, "GetResourceName(%v) => got %q, want %q", testEndpoint, name, clusterName) + name = cache.GetResourceName(testCluster) + assert.Equalf(t, clusterName, name, "GetResourceName(%v) => got %q, want %q", testCluster, name, clusterName) + name = cache.GetResourceName(testRoute) + assert.Equalf(t, routeName, name, "GetResourceName(%v) => got %q, want %q", testRoute, name, routeName) + name = cache.GetResourceName(testScopedRoute) + assert.Equalf(t, scopedRouteName, name, "GetResourceName(%v) => got %q, want %q", testScopedRoute, name, scopedRouteName) + name = cache.GetResourceName(testVirtualHost) + assert.Equalf(t, virtualHostName, name, "GetResourceName(%v) => got %q, want %q", testVirtualHost, name, virtualHostName) + name = cache.GetResourceName(testListener) + assert.Equalf(t, listenerName, name, "GetResourceName(%v) => got %q, want %q", testListener, name, listenerName) + name = cache.GetResourceName(testRuntime) + assert.Equalf(t, runtimeName, name, "GetResourceName(%v) => got %q, want %q", testRuntime, name, runtimeName) + name = cache.GetResourceName(&customResource{}) + assert.Equalf(t, customName, name, "GetResourceName(nil) => got %q, want %q", name, customName) + name = cache.GetResourceName(nil) + assert.Equalf(t, "", name, "GetResourceName(nil) => got %q, want none", name) } func TestGetResourceNames(t *testing.T) { @@ -218,9 +207,7 @@ func TestGetResourceReferences(t *testing.T) { } for _, cs := range cases { names := cache.GetResourceReferences(cache.IndexResourcesByName([]types.ResourceWithTTL{{Resource: cs.in}})) - if !reflect.DeepEqual(names, cs.out) { - t.Errorf("GetResourceReferences(%v) => got %v, want %v", cs.in, names, cs.out) - } + assert.Truef(t, reflect.DeepEqual(names, cs.out), "GetResourceReferences(%v) => got %v, want %v", cs.in, names, cs.out) } } @@ -238,7 +225,5 @@ func TestGetAllResourceReferencesReturnsExpectedRefs(t *testing.T) { resources[types.ScopedRoute] = cache.NewResources("1", []types.Resource{testScopedRoute}) actual := cache.GetAllResourceReferences(resources) - if !reflect.DeepEqual(actual, expected) { - t.Errorf("GetAllResourceReferences(%v) => got %v, want %v", resources, actual, expected) - } + assert.Truef(t, reflect.DeepEqual(actual, expected), "GetAllResourceReferences(%v) => got %v, want %v", resources, actual, expected) } diff --git a/pkg/cache/v3/simple_test.go b/pkg/cache/v3/simple_test.go index eba4cf96d9..9df1b6c1cb 100644 --- a/pkg/cache/v3/simple_test.go +++ b/pkg/cache/v3/simple_test.go @@ -101,21 +101,15 @@ func TestSnapshotCacheWithTTL(t *testing.T) { defer cancel() c := cache.NewSnapshotCacheWithHeartbeating(ctx, true, group{}, logger{t: t}, time.Second) - if _, err := c.GetSnapshot(key); err == nil { - t.Errorf("unexpected snapshot found for key %q", key) - } + _, err := c.GetSnapshot(key) + require.Errorf(t, err, "unexpected snapshot found for key %q", key) - if err := c.SetSnapshot(context.Background(), key, snapshotWithTTL); err != nil { - t.Fatal(err) - } + err = c.SetSnapshot(context.Background(), key, snapshotWithTTL) + require.NoError(t, err) snap, err := c.GetSnapshot(key) - if err != nil { - t.Fatal(err) - } - if !reflect.DeepEqual(snap, snapshotWithTTL) { - t.Errorf("expect snapshot: %v, got: %v", snapshotWithTTL, snap) - } + require.NoError(t, err) + assert.Truef(t, reflect.DeepEqual(snap, snapshotWithTTL), "expect snapshot: %v, got: %v", snapshotWithTTL, snap) wg := sync.WaitGroup{} // All the resources should respond immediately when version is not up to date. @@ -128,12 +122,9 @@ func TestSnapshotCacheWithTTL(t *testing.T) { c.CreateWatch(&discovery.DiscoveryRequest{TypeUrl: typ, ResourceNames: names[typ]}, streamState, value) select { case out := <-value: - if gotVersion, _ := out.GetVersion(); gotVersion != fixture.version { - t.Errorf("got version %q, want %q", gotVersion, fixture.version) - } - if !reflect.DeepEqual(cache.IndexResourcesByName(out.(*cache.RawResponse).Resources), snapshotWithTTL.GetResourcesAndTTL(typ)) { - t.Errorf("get resources %v, want %v", out.(*cache.RawResponse).Resources, snapshotWithTTL.GetResourcesAndTTL(typ)) - } + gotVersion, _ := out.GetVersion() + assert.Equalf(t, gotVersion, fixture.version, "got version %q, want %q", gotVersion, fixture.version) + assert.Truef(t, reflect.DeepEqual(cache.IndexResourcesByName(out.(*cache.RawResponse).Resources), snapshotWithTTL.GetResourcesAndTTL(typ)), "get resources %v, want %v", out.(*cache.RawResponse).Resources, snapshotWithTTL.GetResourcesAndTTL(typ)) // Update streamState streamState.SetKnownResourceNamesAsList(typ, out.GetRequest().GetResourceNames()) case <-time.After(2 * time.Second): @@ -159,16 +150,11 @@ func TestSnapshotCacheWithTTL(t *testing.T) { select { case out := <-value: - if gotVersion, _ := out.GetVersion(); gotVersion != fixture.version { - t.Errorf("got version %q, want %q", gotVersion, fixture.version) - } - if !reflect.DeepEqual(cache.IndexResourcesByName(out.(*cache.RawResponse).Resources), snapshotWithTTL.GetResourcesAndTTL(typ)) { - t.Errorf("get resources %v, want %v", out.(*cache.RawResponse).Resources, snapshotWithTTL.GetResources(typ)) - } + gotVersion, _ := out.GetVersion() + assert.Equalf(t, gotVersion, fixture.version, "got version %q, want %q", gotVersion, fixture.version) + assert.Truef(t, reflect.DeepEqual(cache.IndexResourcesByName(out.(*cache.RawResponse).Resources), snapshotWithTTL.GetResourcesAndTTL(typ)), "get resources %v, want %v", out.(*cache.RawResponse).Resources, snapshotWithTTL.GetResources(typ)) - if !reflect.DeepEqual(cache.IndexResourcesByName(out.(*cache.RawResponse).Resources), snapshotWithTTL.GetResourcesAndTTL(typ)) { - t.Errorf("get resources %v, want %v", out.(*cache.RawResponse).Resources, snapshotWithTTL.GetResources(typ)) - } + assert.Truef(t, reflect.DeepEqual(cache.IndexResourcesByName(out.(*cache.RawResponse).Resources), snapshotWithTTL.GetResourcesAndTTL(typ)), "get resources %v, want %v", out.(*cache.RawResponse).Resources, snapshotWithTTL.GetResources(typ)) updatesByType[typ]++ @@ -183,33 +169,23 @@ func TestSnapshotCacheWithTTL(t *testing.T) { wg.Wait() - if len(updatesByType) != 1 { - t.Errorf("expected to only receive updates for TTL'd type, got %v", updatesByType) - } + assert.Lenf(t, updatesByType, 1, "expected to only receive updates for TTL'd type, got %v", updatesByType) // Avoid an exact match on number of triggers to avoid this being flaky. - if updatesByType[rsrc.EndpointType] < 2 { - t.Errorf("expected at least two TTL updates for endpoints, got %d", updatesByType[rsrc.EndpointType]) - } + assert.GreaterOrEqualf(t, updatesByType[rsrc.EndpointType], 2, "expected at least two TTL updates for endpoints, got %d", updatesByType[rsrc.EndpointType]) } func TestSnapshotCache(t *testing.T) { c := cache.NewSnapshotCache(true, group{}, logger{t: t}) - if _, err := c.GetSnapshot(key); err == nil { - t.Errorf("unexpected snapshot found for key %q", key) - } + _, err := c.GetSnapshot(key) + require.Errorf(t, err, "unexpected snapshot found for key %q", key) - if err := c.SetSnapshot(context.Background(), key, fixture.snapshot()); err != nil { - t.Fatal(err) - } + err = c.SetSnapshot(context.Background(), key, fixture.snapshot()) + require.NoError(t, err) snap, err := c.GetSnapshot(key) - if err != nil { - t.Fatal(err) - } - if !reflect.DeepEqual(snap, fixture.snapshot()) { - t.Errorf("expect snapshot: %v, got: %v", fixture.snapshot(), snap) - } + require.NoError(t, err) + assert.Truef(t, reflect.DeepEqual(snap, fixture.snapshot()), "expect snapshot: %v, got: %v", fixture.snapshot(), snap) // try to get endpoints with incorrect list of names // should not receive response @@ -232,12 +208,9 @@ func TestSnapshotCache(t *testing.T) { select { case out := <-value: snapshot := fixture.snapshot() - if gotVersion, _ := out.GetVersion(); gotVersion != fixture.version { - t.Errorf("got version %q, want %q", gotVersion, fixture.version) - } - if !reflect.DeepEqual(cache.IndexResourcesByName(out.(*cache.RawResponse).Resources), snapshot.GetResourcesAndTTL(typ)) { - t.Errorf("get resources %v, want %v", out.(*cache.RawResponse).Resources, snapshot.GetResourcesAndTTL(typ)) - } + gotVersion, _ := out.GetVersion() + assert.Equalf(t, gotVersion, fixture.version, "got version %q, want %q", gotVersion, fixture.version) + assert.Truef(t, reflect.DeepEqual(cache.IndexResourcesByName(out.(*cache.RawResponse).Resources), snapshot.GetResourcesAndTTL(typ)), "get resources %v, want %v", out.(*cache.RawResponse).Resources, snapshot.GetResourcesAndTTL(typ)) case <-time.After(time.Second): t.Fatal("failed to receive snapshot response") } @@ -247,33 +220,29 @@ func TestSnapshotCache(t *testing.T) { func TestSnapshotCacheFetch(t *testing.T) { c := cache.NewSnapshotCache(true, group{}, logger{t: t}) - if err := c.SetSnapshot(context.Background(), key, fixture.snapshot()); err != nil { - t.Fatal(err) - } + require.NoError(t, c.SetSnapshot(context.Background(), key, fixture.snapshot())) for _, typ := range testTypes { t.Run(typ, func(t *testing.T) { resp, err := c.Fetch(context.Background(), &discovery.DiscoveryRequest{TypeUrl: typ, ResourceNames: names[typ]}) - if err != nil || resp == nil { - t.Fatal("unexpected error or null response") - } - if gotVersion, _ := resp.GetVersion(); gotVersion != fixture.version { - t.Errorf("got version %q, want %q", gotVersion, fixture.version) - } + require.NoErrorf(t, err, "unexpected error") + require.NotNilf(t, resp, "null response") + gotVersion, _ := resp.GetVersion() + assert.Equalf(t, gotVersion, fixture.version, "got version %q, want %q", gotVersion, fixture.version) }) } // no response for missing snapshot - if resp, err := c.Fetch(context.Background(), - &discovery.DiscoveryRequest{TypeUrl: rsrc.ClusterType, Node: &core.Node{Id: "oof"}}); resp != nil || err == nil { - t.Errorf("missing snapshot: response is not nil %v", resp) - } + resp, err := c.Fetch(context.Background(), + &discovery.DiscoveryRequest{TypeUrl: rsrc.ClusterType, Node: &core.Node{Id: "oof"}}) + require.Errorf(t, err, "missing snapshot: response is not nil %v", resp) + assert.Nilf(t, resp, "missing snapshot: response is not nil %v", resp) // no response for latest version - if resp, err := c.Fetch(context.Background(), - &discovery.DiscoveryRequest{TypeUrl: rsrc.ClusterType, VersionInfo: fixture.version}); resp != nil || err == nil { - t.Errorf("latest version: response is not nil %v", resp) - } + resp, err = c.Fetch(context.Background(), + &discovery.DiscoveryRequest{TypeUrl: rsrc.ClusterType, VersionInfo: fixture.version}) + require.Errorf(t, err, "latest version: response is not nil %v", resp) + assert.Nilf(t, resp, "latest version: response is not nil %v", resp) } func TestSnapshotCacheWatch(t *testing.T) { @@ -284,20 +253,16 @@ func TestSnapshotCacheWatch(t *testing.T) { watches[typ] = make(chan cache.Response, 1) c.CreateWatch(&discovery.DiscoveryRequest{TypeUrl: typ, ResourceNames: names[typ]}, streamState, watches[typ]) } - if err := c.SetSnapshot(context.Background(), key, fixture.snapshot()); err != nil { - t.Fatal(err) - } + + require.NoError(t, c.SetSnapshot(context.Background(), key, fixture.snapshot())) for _, typ := range testTypes { t.Run(typ, func(t *testing.T) { select { case out := <-watches[typ]: - if gotVersion, _ := out.GetVersion(); gotVersion != fixture.version { - t.Errorf("got version %q, want %q", gotVersion, fixture.version) - } + gotVersion, _ := out.GetVersion() + assert.Equalf(t, gotVersion, fixture.version, "got version %q, want %q", gotVersion, fixture.version) snapshot := fixture.snapshot() - if !reflect.DeepEqual(cache.IndexResourcesByName(out.(*cache.RawResponse).Resources), snapshot.GetResourcesAndTTL(typ)) { - t.Errorf("get resources %v, want %v", out.(*cache.RawResponse).Resources, snapshot.GetResourcesAndTTL(typ)) - } + assert.Truef(t, reflect.DeepEqual(cache.IndexResourcesByName(out.(*cache.RawResponse).Resources), snapshot.GetResourcesAndTTL(typ)), "get resources %v, want %v", out.(*cache.RawResponse).Resources, snapshot.GetResourcesAndTTL(typ)) streamState.SetKnownResourceNamesAsList(typ, out.GetRequest().GetResourceNames()) case <-time.After(time.Second): t.Fatal("failed to receive snapshot response") @@ -311,29 +276,22 @@ func TestSnapshotCacheWatch(t *testing.T) { c.CreateWatch(&discovery.DiscoveryRequest{TypeUrl: typ, ResourceNames: names[typ], VersionInfo: fixture.version}, streamState, watches[typ]) } - if count := c.GetStatusInfo(key).GetNumWatches(); count != len(testTypes) { - t.Errorf("watches should be created for the latest version: %d", count) - } + count := c.GetStatusInfo(key).GetNumWatches() + assert.Lenf(t, testTypes, count, "watches should be created for the latest version: %d", count) // set partially-versioned snapshot snapshot2 := fixture.snapshot() snapshot2.Resources[types.Endpoint] = cache.NewResources(fixture.version2, []types.Resource{resource.MakeEndpoint(clusterName, 9090)}) - if err := c.SetSnapshot(context.Background(), key, snapshot2); err != nil { - t.Fatal(err) - } - if count := c.GetStatusInfo(key).GetNumWatches(); count != len(testTypes)-1 { - t.Errorf("watches should be preserved for all but one: %d", count) - } + require.NoError(t, c.SetSnapshot(context.Background(), key, snapshot2)) + count = c.GetStatusInfo(key).GetNumWatches() + assert.Equalf(t, count, len(testTypes)-1, "watches should be preserved for all but one: %d", count) // validate response for endpoints select { case out := <-watches[rsrc.EndpointType]: - if gotVersion, _ := out.GetVersion(); gotVersion != fixture.version2 { - t.Errorf("got version %q, want %q", gotVersion, fixture.version2) - } - if !reflect.DeepEqual(cache.IndexResourcesByName(out.(*cache.RawResponse).Resources), snapshot2.Resources[types.Endpoint].Items) { - t.Errorf("got resources %v, want %v", out.(*cache.RawResponse).Resources, snapshot2.Resources[types.Endpoint].Items) - } + gotVersion, _ := out.GetVersion() + assert.Equalf(t, gotVersion, fixture.version2, "got version %q, want %q", gotVersion, fixture.version2) + assert.Truef(t, reflect.DeepEqual(cache.IndexResourcesByName(out.(*cache.RawResponse).Resources), snapshot2.Resources[types.Endpoint].Items), "got resources %v, want %v", out.(*cache.RawResponse).Resources, snapshot2.Resources[types.Endpoint].Items) case <-time.After(time.Second): t.Fatal("failed to receive snapshot response") } @@ -350,9 +308,8 @@ func TestConcurrentSetWatch(t *testing.T) { if i < 25 { snap := cache.Snapshot{} snap.Resources[types.Endpoint] = cache.NewResources(fmt.Sprintf("v%d", i), []types.Resource{resource.MakeEndpoint(clusterName, uint32(i))}) - if err := c.SetSnapshot(context.Background(), id, &snap); err != nil { - t.Fatalf("failed to set snapshot %q: %s", id, err) - } + err := c.SetSnapshot(context.Background(), id, &snap) + require.NoErrorf(t, err, "failed to set snapshot %q", id) } else { streamState := stream.NewStreamState(false, map[string]string{}) cancel := c.CreateWatch(&discovery.DiscoveryRequest{ @@ -374,19 +331,16 @@ func TestSnapshotCacheWatchCancel(t *testing.T) { cancel() } // should be status info for the node - if keys := c.GetStatusKeys(); len(keys) == 0 { - t.Error("got 0, want status info for the node") - } + keys := c.GetStatusKeys() + assert.NotEmptyf(t, keys, "got 0, want status info for the node") for _, typ := range testTypes { - if count := c.GetStatusInfo(key).GetNumWatches(); count > 0 { - t.Errorf("watches should be released for %s", typ) - } + count := c.GetStatusInfo(key).GetNumWatches() + assert.LessOrEqualf(t, count, 0, "watches should be released for %s", typ) } - if empty := c.GetStatusInfo("missing"); empty != nil { - t.Errorf("should not return a status for unknown key: got %#v", empty) - } + empty := c.GetStatusInfo("missing") + assert.Nilf(t, empty, "should not return a status for unknown key: got %#v", empty) } func TestSnapshotCacheWatchTimeout(t *testing.T) { @@ -443,9 +397,8 @@ func TestSnapshotCreateWatchWithResourcePreviouslyNotRequested(t *testing.T) { rsrc.SecretType: {}, rsrc.ExtensionConfigType: {}, }) - if err := c.SetSnapshot(context.Background(), key, snapshot2); err != nil { - t.Fatal(err) - } + err := c.SetSnapshot(context.Background(), key, snapshot2) + require.NoError(t, err) watch := make(chan cache.Response) // Request resource with name=ClusterName @@ -456,13 +409,10 @@ func TestSnapshotCreateWatchWithResourcePreviouslyNotRequested(t *testing.T) { select { case out := <-watch: - if gotVersion, _ := out.GetVersion(); gotVersion != fixture.version { - t.Errorf("got version %q, want %q", gotVersion, fixture.version) - } + gotVersion, _ := out.GetVersion() + assert.Equalf(t, gotVersion, fixture.version, "got version %q, want %q", gotVersion, fixture.version) want := map[string]types.ResourceWithTTL{clusterName: snapshot2.Resources[types.Endpoint].Items[clusterName]} - if !reflect.DeepEqual(cache.IndexResourcesByName(out.(*cache.RawResponse).Resources), want) { - t.Errorf("got resources %v, want %v", out.(*cache.RawResponse).Resources, want) - } + assert.Truef(t, reflect.DeepEqual(cache.IndexResourcesByName(out.(*cache.RawResponse).Resources), want), "got resources %v, want %v", out.(*cache.RawResponse).Resources, want) case <-time.After(time.Second): t.Fatal("failed to receive snapshot response") } @@ -479,12 +429,9 @@ func TestSnapshotCreateWatchWithResourcePreviouslyNotRequested(t *testing.T) { select { case out := <-watch: - if gotVersion, _ := out.GetVersion(); gotVersion != fixture.version { - t.Errorf("got version %q, want %q", gotVersion, fixture.version) - } - if !reflect.DeepEqual(cache.IndexResourcesByName(out.(*cache.RawResponse).Resources), snapshot2.Resources[types.Endpoint].Items) { - t.Errorf("got resources %v, want %v", out.(*cache.RawResponse).Resources, snapshot2.Resources[types.Endpoint].Items) - } + gotVersion, _ := out.GetVersion() + assert.Equalf(t, gotVersion, fixture.version, "got version %q, want %q", gotVersion, fixture.version) + assert.Truef(t, reflect.DeepEqual(cache.IndexResourcesByName(out.(*cache.RawResponse).Resources), snapshot2.Resources[types.Endpoint].Items), "got resources %v, want %v", out.(*cache.RawResponse).Resources, snapshot2.Resources[types.Endpoint].Items) case <-time.After(time.Second): t.Fatal("failed to receive snapshot response") } @@ -504,16 +451,13 @@ func TestSnapshotCreateWatchWithResourcePreviouslyNotRequested(t *testing.T) { func TestSnapshotClear(t *testing.T) { c := cache.NewSnapshotCache(true, group{}, logger{t: t}) - if err := c.SetSnapshot(context.Background(), key, fixture.snapshot()); err != nil { - t.Fatal(err) - } + + require.NoError(t, c.SetSnapshot(context.Background(), key, fixture.snapshot())) c.ClearSnapshot(key) - if empty := c.GetStatusInfo(key); empty != nil { - t.Errorf("cache should be cleared") - } - if keys := c.GetStatusKeys(); len(keys) != 0 { - t.Errorf("keys should be empty") - } + empty := c.GetStatusInfo(key) + assert.Nilf(t, empty, "cache should be cleared") + keys := c.GetStatusKeys() + assert.Emptyf(t, keys, "keys should be empty") } type singleResourceSnapshot struct { @@ -637,9 +581,8 @@ func TestAvertPanicForWatchOnNonExistentSnapshot(t *testing.T) { name: "one-second", resource: durationpb.New(time.Second), } - if err := c.SetSnapshot(ctx, "test", srs); err != nil { - t.Errorf("unexpected error setting snapshot %v", err) - } + err := c.SetSnapshot(ctx, "test", srs) + assert.NoErrorf(t, err, "unexpected error setting snapshot %v", err) }() <-responder diff --git a/pkg/cache/v3/snapshot_test.go b/pkg/cache/v3/snapshot_test.go index c32f492076..5335ec98e4 100644 --- a/pkg/cache/v3/snapshot_test.go +++ b/pkg/cache/v3/snapshot_test.go @@ -30,43 +30,38 @@ import ( func TestTestSnapshotIsConsistent(t *testing.T) { snapshot := fixture.snapshot() - if err := snapshot.Consistent(); err != nil { - t.Errorf("got inconsistent snapshot for %#v\nerr=%s", snapshot, err.Error()) - } + err := snapshot.Consistent() + assert.NoErrorf(t, err, "got inconsistent snapshot for %#v\nerr=%s", snapshot, err.Error()) } func TestSnapshotWithOnlyEndpointIsInconsistent(t *testing.T) { - if snap, _ := cache.NewSnapshot(fixture.version, map[rsrc.Type][]types.Resource{ + snap, _ := cache.NewSnapshot(fixture.version, map[rsrc.Type][]types.Resource{ rsrc.EndpointType: {testEndpoint}, - }); snap.Consistent() == nil { - t.Errorf("got consistent snapshot %#v", snap) - } + }) + assert.Errorf(t, snap.Consistent(), "got consistent snapshot %#v", snap) } func TestClusterWithMissingEndpointIsInconsistent(t *testing.T) { - if snap, _ := cache.NewSnapshot(fixture.version, map[rsrc.Type][]types.Resource{ + snap, _ := cache.NewSnapshot(fixture.version, map[rsrc.Type][]types.Resource{ rsrc.EndpointType: {resource.MakeEndpoint("missing", 8080)}, rsrc.ClusterType: {testCluster}, - }); snap.Consistent() == nil { - t.Errorf("got consistent snapshot %#v", snap) - } + }) + assert.Errorf(t, snap.Consistent(), "got consistent snapshot %#v", snap) } func TestListenerWithMissingRoutesIsInconsistent(t *testing.T) { - if snap, _ := cache.NewSnapshot(fixture.version, map[rsrc.Type][]types.Resource{ + snap, _ := cache.NewSnapshot(fixture.version, map[rsrc.Type][]types.Resource{ rsrc.ListenerType: {testListener}, - }); snap.Consistent() == nil { - t.Errorf("got consistent snapshot %#v", snap) - } + }) + assert.Errorf(t, snap.Consistent(), "got consistent snapshot %#v", snap) } func TestListenerWithUnidentifiedRouteIsInconsistent(t *testing.T) { - if snap, _ := cache.NewSnapshot(fixture.version, map[rsrc.Type][]types.Resource{ + snap, _ := cache.NewSnapshot(fixture.version, map[rsrc.Type][]types.Resource{ rsrc.RouteType: {resource.MakeRouteConfig("test", clusterName)}, rsrc.ListenerType: {testListener}, - }); snap.Consistent() == nil { - t.Errorf("got consistent snapshot %#v", snap) - } + }) + assert.Errorf(t, snap.Consistent(), "got consistent snapshot %#v", snap) } func TestRouteListenerWithRouteIsConsistent(t *testing.T) { @@ -79,22 +74,20 @@ func TestRouteListenerWithRouteIsConsistent(t *testing.T) { }, }) - if err := snap.Consistent(); err != nil { - t.Errorf("got inconsistent snapshot %s, %#v", err.Error(), snap) - } + err := snap.Consistent() + assert.NoErrorf(t, err, "got inconsistent snapshot %s, %#v", err.Error(), snap) } func TestScopedRouteListenerWithScopedRouteOnlyIsInconsistent(t *testing.T) { - if snap, _ := cache.NewSnapshot(fixture.version, map[rsrc.Type][]types.Resource{ + snap, _ := cache.NewSnapshot(fixture.version, map[rsrc.Type][]types.Resource{ rsrc.ListenerType: { resource.MakeScopedRouteHTTPListener(resource.Xds, "listener0", 80), }, rsrc.ScopedRouteType: { resource.MakeScopedRouteConfig("scopedRoute0", "testRoute0", []string{"1.2.3.4"}), }, - }); snap.Consistent() == nil { - t.Errorf("got consistent snapshot %#v", snap) - } + }) + assert.Errorf(t, snap.Consistent(), "got consistent snapshot %#v", snap) } func TestScopedRouteListenerWithScopedRouteAndRouteIsConsistent(t *testing.T) { @@ -156,30 +149,24 @@ func TestMultipleListenersWithScopedRouteAndRouteIsConsistent(t *testing.T) { }, }) - if err := snap.Consistent(); err != nil { - t.Errorf("got inconsistent snapshot %s, %#v", err.Error(), snap) - } + err := snap.Consistent() + assert.NoErrorf(t, err, "got inconsistent snapshot %s, %#v", err.Error(), snap) } func TestSnapshotGetters(t *testing.T) { var nilsnap *cache.Snapshot - if out := nilsnap.GetResources(rsrc.EndpointType); out != nil { - t.Errorf("got non-empty resources for nil snapshot: %#v", out) - } - if out := nilsnap.Consistent(); out == nil { - t.Errorf("nil snapshot should be inconsistent") - } - if out := nilsnap.GetVersion(rsrc.EndpointType); out != "" { - t.Errorf("got non-empty version for nil snapshot: %#v", out) - } + outRT := nilsnap.GetResources(rsrc.EndpointType) + assert.Nilf(t, outRT, "got non-empty resources for nil snapshot: %#v", outRT) + err := nilsnap.Consistent() + require.Errorf(t, err, "nil snapshot should be inconsistent") + version := nilsnap.GetVersion(rsrc.EndpointType) + assert.Equalf(t, "", version, "got non-empty version for nil snapshot: %#v", version) snapshot := fixture.snapshot() - if out := snapshot.GetResources("not a type"); out != nil { - t.Errorf("got non-empty resources for unknown type: %#v", out) - } - if out := snapshot.GetVersion("not a type"); out != "" { - t.Errorf("got non-empty version for unknown type: %#v", out) - } + out := snapshot.GetResources("not a type") + assert.Nilf(t, out, "got non-empty resources for unknown type: %#v", out) + version = snapshot.GetVersion("not a type") + assert.Equalf(t, "", version, "got non-empty version for unknown type: %#v", version) } func TestNewSnapshotBadType(t *testing.T) { diff --git a/pkg/cache/v3/status_test.go b/pkg/cache/v3/status_test.go index def8346118..ca8fea8aa3 100644 --- a/pkg/cache/v3/status_test.go +++ b/pkg/cache/v3/status_test.go @@ -18,40 +18,35 @@ import ( "reflect" "testing" + "github.com/stretchr/testify/assert" + core "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" ) func TestIDHash(t *testing.T) { node := &core.Node{Id: "test"} - if got := (IDHash{}).ID(node); got != "test" { - t.Errorf("IDHash.ID(%v) => got %s, want %s", node, got, node.GetId()) - } - if got := (IDHash{}).ID(nil); got != "" { - t.Errorf("IDHash.ID(nil) => got %s, want empty", got) - } + got := (IDHash{}).ID(node) + assert.Equalf(t, "test", got, "IDHash.ID(%v) => got %s, want %s", node, got, node.GetId()) + got = (IDHash{}).ID(nil) + assert.Equalf(t, "", got, "IDHash.ID(nil) => got %s, want empty", got) } func TestNewStatusInfo(t *testing.T) { node := &core.Node{Id: "test"} info := newStatusInfo(node) - if got := info.GetNode(); !reflect.DeepEqual(got, node) { - t.Errorf("GetNode() => got %#v, want %#v", got, node) - } + gotNode := info.GetNode() + assert.Truef(t, reflect.DeepEqual(gotNode, node), "GetNode() => got %#v, want %#v", gotNode, node) - if got := info.GetNumWatches(); got != 0 { - t.Errorf("GetNumWatches() => got %d, want 0", got) - } + gotNumWatches := info.GetNumWatches() + assert.Equalf(t, 0, gotNumWatches, "GetNumWatches() => got %d, want 0", gotNumWatches) - if got := info.GetLastWatchRequestTime(); !got.IsZero() { - t.Errorf("GetLastWatchRequestTime() => got %v, want zero time", got) - } + gotLastWatchRequestTime := info.GetLastWatchRequestTime() + assert.Truef(t, gotLastWatchRequestTime.IsZero(), "GetLastWatchRequestTime() => got %v, want zero time", gotLastWatchRequestTime) - if got := info.GetNumDeltaWatches(); got != 0 { - t.Errorf("GetNumDeltaWatches() => got %d, want 0", got) - } + gotNumDeltaWatches := info.GetNumDeltaWatches() + assert.Equalf(t, 0, gotNumDeltaWatches, "GetNumDeltaWatches() => got %d, want 0", gotNumDeltaWatches) - if got := info.GetLastDeltaWatchRequestTime(); !got.IsZero() { - t.Errorf("GetLastDeltaWatchRequestTime() => got %v, want zero time", got) - } + gotLastDeltaWatchRequestTime := info.GetLastDeltaWatchRequestTime() + assert.Truef(t, gotLastDeltaWatchRequestTime.IsZero(), "GetLastDeltaWatchRequestTime() => got %v, want zero time", gotLastDeltaWatchRequestTime) } diff --git a/pkg/conversion/struct_test.go b/pkg/conversion/struct_test.go index 0c6fef869b..531b50c4bf 100644 --- a/pkg/conversion/struct_test.go +++ b/pkg/conversion/struct_test.go @@ -18,6 +18,8 @@ import ( "testing" "github.com/google/go-cmp/cmp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/structpb" @@ -32,9 +34,7 @@ func TestConversion(t *testing.T) { Node: &core.Node{Id: "proxy"}, } st, err := conversion.MessageToStruct(pb) - if err != nil { - t.Fatalf("unexpected error %v", err) - } + require.NoErrorf(t, err, "unexpected error") pbst := map[string]*structpb.Value{ "version_info": {Kind: &structpb.Value_StringValue{StringValue: "test"}}, "node": {Kind: &structpb.Value_StructValue{StructValue: &structpb.Struct{ @@ -43,24 +43,16 @@ func TestConversion(t *testing.T) { }, }}}, } - if !cmp.Equal(st.GetFields(), pbst, cmp.Comparer(proto.Equal)) { - t.Errorf("MessageToStruct(%v) => got %v, want %v", pb, st.GetFields(), pbst) - } + assert.Truef(t, cmp.Equal(st.GetFields(), pbst, cmp.Comparer(proto.Equal)), "MessageToStruct(%v) => got %v, want %v", pb, st.GetFields(), pbst) out := &discovery.DiscoveryRequest{} err = conversion.StructToMessage(st, out) - if err != nil { - t.Fatalf("unexpected error %v", err) - } - if !cmp.Equal(pb, out, cmp.Comparer(proto.Equal)) { - t.Errorf("StructToMessage(%v) => got %v, want %v", st, out, pb) - } + require.NoErrorf(t, err, "unexpected error") + assert.Truef(t, cmp.Equal(pb, out, cmp.Comparer(proto.Equal)), "StructToMessage(%v) => got %v, want %v", st, out, pb) - if _, err = conversion.MessageToStruct(nil); err == nil { - t.Error("MessageToStruct(nil) => got no error") - } + _, err = conversion.MessageToStruct(nil) + require.Errorf(t, err, "MessageToStruct(nil) => got no error") - if err = conversion.StructToMessage(nil, &discovery.DiscoveryRequest{}); err == nil { - t.Error("StructToMessage(nil) => got no error") - } + err = conversion.StructToMessage(nil, &discovery.DiscoveryRequest{}) + assert.Errorf(t, err, "StructToMessage(nil) => got no error") } diff --git a/pkg/server/v3/gateway_test.go b/pkg/server/v3/gateway_test.go index 26dba5be8a..f750278c4f 100644 --- a/pkg/server/v3/gateway_test.go +++ b/pkg/server/v3/gateway_test.go @@ -22,6 +22,9 @@ import ( "testing" "testing/iotest" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + discovery "github.com/envoyproxy/go-control-plane/envoy/service/discovery/v3" "github.com/envoyproxy/go-control-plane/pkg/cache/types" "github.com/envoyproxy/go-control-plane/pkg/cache/v3" @@ -88,35 +91,19 @@ func TestGateway(t *testing.T) { } for _, cs := range failCases { req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, cs.path, cs.body) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) resp, code, err := gtw.ServeHTTP(req) - if err == nil { - t.Errorf("ServeHTTP succeeded, but should have failed") - } - if resp != nil { - t.Errorf("handler returned wrong response") - } - if status := code; status != cs.expect { - t.Errorf("handler returned wrong status: %d, want %d", status, cs.expect) - } + require.Errorf(t, err, "ServeHTTP succeeded, but should have failed") + assert.Nilf(t, resp, "handler returned wrong response") + assert.Equalf(t, code, cs.expect, "handler returned wrong status: %d, want %d", code, cs.expect) } for _, path := range []string{resource.FetchClusters, resource.FetchRoutes, resource.FetchListeners} { req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, path, strings.NewReader("{\"node\": {\"id\": \"test\"}}")) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) resp, code, err := gtw.ServeHTTP(req) - if err != nil { - t.Fatal(err) - } - if resp == nil { - t.Errorf("handler returned wrong response") - } - if status := code; status != 200 { - t.Errorf("handler returned wrong status: %d, want %d", status, 200) - } + require.NoError(t, err) + assert.NotNilf(t, resp, "handler returned wrong response") + assert.Equalf(t, 200, code, "handler returned wrong status: %d, want %d", code, 200) } } diff --git a/pkg/server/v3/server_test.go b/pkg/server/v3/server_test.go index 53b937c037..f8ac66f43f 100644 --- a/pkg/server/v3/server_test.go +++ b/pkg/server/v3/server_test.go @@ -293,9 +293,7 @@ func TestServerShutdown(t *testing.T) { case opaqueType: err = s.StreamAggregatedResources(resp) } - if err != nil { - t.Errorf("Stream() => got %v, want no error", err) - } + assert.NoErrorf(t, err, "Stream() => got %v, want no error", err) shutdown <- true }(typ) @@ -356,9 +354,8 @@ func TestResponseHandlers(t *testing.T) { select { case <-resp.sent: close(resp.recv) - if want := map[string]int{typ: 1}; !reflect.DeepEqual(want, config.counts) { - t.Errorf("watch counts => got %v, want %v", config.counts, want) - } + want := map[string]int{typ: 1} + assert.Truef(t, reflect.DeepEqual(want, config.counts), "watch counts => got %v, want %v", config.counts, want) case <-time.After(1 * time.Second): t.Fatalf("got no response") }