diff --git a/agent/hcp/telemetry_provider_test.go b/agent/hcp/telemetry_provider_test.go index c8d8ca66c295..184aadd99e6d 100644 --- a/agent/hcp/telemetry_provider_test.go +++ b/agent/hcp/telemetry_provider_test.go @@ -19,9 +19,10 @@ import ( ) const ( - testRefreshInterval = 100 * time.Millisecond - testSinkServiceName = "test.telemetry_config_provider" - testRaceSampleCount = 5000 + testRefreshInterval = 100 * time.Millisecond + testSinkServiceName = "test.telemetry_config_provider" + testRaceWriteSampleCount = 100 + testRaceReadSampleCount = 5000 ) var ( @@ -222,58 +223,109 @@ func TestTelemetryConfigProviderGetUpdate(t *testing.T) { } } -func TestTelemetryConfigProvider_Race(t *testing.T) { - cfg, err := testTelemetryCfg(&testConfig{ - endpoint: "http://test.com/v1/metrics", - filters: "test", - labels: map[string]string{ - "test_label": "123", - }, +type mockRaceClient struct { + cfg *client.TelemetryConfig + rw sync.RWMutex +} + +func newMockRaceClient() (*mockRaceClient, error) { + initCfg, err := testTelemetryCfg(&testConfig{ + endpoint: "test.com", + filters: "test", + labels: map[string]string{"test_label": "test_value"}, refreshInterval: testRefreshInterval, }) - require.NoError(t, err) + if err != nil { + return nil, err + } + + return &mockRaceClient{ + cfg: initCfg, + }, nil +} + +func (m *mockRaceClient) updateCfg(count int) (*client.TelemetryConfig, error) { + m.rw.Lock() + defer m.rw.Unlock() + + labels := map[string]string{fmt.Sprintf("label_%d", count): fmt.Sprintf("value_%d", count)} + + filters, err := regexp.Compile(fmt.Sprintf("consul_filter_%d", count)) + if err != nil { + return nil, err + } + + endpoint, err := url.Parse(fmt.Sprintf("http://consul-endpoint-%d.com", count)) + if err != nil { + return nil, err + } + + cfg := &client.TelemetryConfig{ + MetricsConfig: &client.MetricsConfig{ + Filters: filters, + Endpoint: endpoint, + Labels: labels, + }, + RefreshConfig: &client.RefreshConfig{ + RefreshInterval: testRefreshInterval, + }, + } + m.cfg = cfg + + return cfg, nil +} + +func (m *mockRaceClient) FetchBootstrap(ctx context.Context) (*client.BootstrapConfig, error) { + return nil, nil +} +func (m *mockRaceClient) PushServerStatus(ctx context.Context, status *client.ServerStatus) error { + return nil +} +func (m *mockRaceClient) DiscoverServers(ctx context.Context) ([]string, error) { return nil, nil } +func (m *mockRaceClient) FetchTelemetryConfig(ctx context.Context) (*client.TelemetryConfig, error) { + m.rw.RLock() + defer m.rw.RUnlock() + + return m.cfg, nil +} +func TestTelemetryConfigProvider_Race(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - mockClient := client.NewMockClient(t) - mockClient.EXPECT().FetchTelemetryConfig(mock.Anything).Return(cfg, nil) + m, err := newMockRaceClient() + require.NoError(t, err) - // Start the provider goroutine - // Every refresh interval, config will be modified. - provider, err := NewHCPProvider(ctx, mockClient, cfg) + // Start the provider goroutine, which continuously fetches config. + provider, err := NewHCPProvider(ctx, m, m.cfg) require.NoError(t, err) - // Every refresh interval, try to query config using Get* methods inducing a race condition. - timer := time.NewTimer(testRefreshInterval) - defer timer.Stop() - for { - select { - case <-timer.C: - wg := &sync.WaitGroup{} - // Start goroutines that try to access label configuration. - kickOff(wg, testRaceSampleCount, provider, func(provider *hcpProviderImpl) { - require.Equal(t, provider.GetLabels(), cfg.MetricsConfig.Labels) - }) - - // Start goroutines that try to access endpoint configuration. - kickOff(wg, testRaceSampleCount, provider, func(provider *hcpProviderImpl) { - require.Equal(t, provider.GetFilters(), cfg.MetricsConfig.Filters) - }) - - // Start goroutines that try to access filter configuration. - kickOff(wg, testRaceSampleCount, provider, func(provider *hcpProviderImpl) { - require.Equal(t, provider.GetEndpoint(), cfg.MetricsConfig.Endpoint) - }) - - wg.Wait() - // Stop after 10 refresh intervals. - case <-time.After(10 * testRefreshInterval): - return - case <-ctx.Done(): - require.Fail(t, "Context cancelled before test finishes") - return - } + for count := 0; count < testRaceWriteSampleCount; count++ { + // Force a config value change in the client. + cfg, err := m.updateCfg(count) + require.NoError(t, err) + // Force provider to obtain new client config immediately. + // The provider goroutine in the background continues to fetch this value every RefreshInterval + // but this call is necessary to guarantee we can assert on expected values below. + provider.getUpdate(context.Background()) + + // Start goroutines to access label configuration. + wg := &sync.WaitGroup{} + kickOff(wg, testRaceReadSampleCount, provider, func(provider *hcpProviderImpl) { + require.Equal(t, provider.GetLabels(), cfg.MetricsConfig.Labels) + }) + + // Start goroutines to access endpoint configuration. + kickOff(wg, testRaceReadSampleCount, provider, func(provider *hcpProviderImpl) { + require.Equal(t, provider.GetFilters(), cfg.MetricsConfig.Filters) + }) + + // Start goroutines to access filter configuration. + kickOff(wg, testRaceReadSampleCount, provider, func(provider *hcpProviderImpl) { + require.Equal(t, provider.GetEndpoint(), cfg.MetricsConfig.Endpoint) + }) + + wg.Wait() } }