diff --git a/connection_test.go b/connection_test.go index 5679ed9f..b6509a5b 100644 --- a/connection_test.go +++ b/connection_test.go @@ -496,6 +496,9 @@ func TestServerClientCancellation(t *testing.T) { opts.DefaultConnectionOptions.SendCancelOnContextCanceled = true opts.DefaultConnectionOptions.PropagateCancel = true + serverStats := newRecordingStatsReporter() + opts.StatsReporter = serverStats + testutils.WithTestServer(t, opts, func(t testing.TB, ts *testutils.TestServer) { callReceived := make(chan struct{}) testutils.RegisterFunc(ts.Server(), "ctxWait", func(ctx context.Context, args *raw.Args) (*raw.Res, error) { @@ -519,16 +522,23 @@ func TestServerClientCancellation(t *testing.T) { _, _, _, err := raw.Call(ctx, ts.Server(), ts.HostPort(), ts.ServiceName(), "ctxWait", nil, nil) assert.Equal(t, ErrRequestCancelled, err, "client call result") + statsTags := ts.Server().StatsTags() + serverStats.Expected.IncCounter("inbound.cancels.requested", statsTags, 1) + serverStats.Expected.IncCounter("inbound.cancels.honored", statsTags, 1) + calls := relaytest.NewMockStats() calls.Add(ts.ServiceName(), ts.ServiceName(), "ctxWait").Failed("canceled").End() ts.AssertRelayStats(calls) }) + + serverStats.ValidateExpected(t) } func TestCancelWithoutSendCancelOnContextCanceled(t *testing.T) { tests := []struct { msg string sendCancelOnContextCanceled bool + wantCancelRequested bool }{ { msg: "no send or process cancel", @@ -537,6 +547,7 @@ func TestCancelWithoutSendCancelOnContextCanceled(t *testing.T) { { msg: "only enable cancels on outbounds", sendCancelOnContextCanceled: true, + wantCancelRequested: true, }, } @@ -545,7 +556,12 @@ func TestCancelWithoutSendCancelOnContextCanceled(t *testing.T) { opts := testutils.NewOpts() opts.DefaultConnectionOptions.SendCancelOnContextCanceled = tt.sendCancelOnContextCanceled + serverStats := newRecordingStatsReporter() + opts.StatsReporter = serverStats + testutils.WithTestServer(t, opts, func(t testing.TB, ts *testutils.TestServer) { + serverStats.Reset() + callReceived := make(chan struct{}) testutils.RegisterFunc(ts.Server(), "ctxWait", func(ctx context.Context, args *raw.Args) (*raw.Res, error) { require.NoError(t, ctx.Err(), "context valid before cancellation") @@ -571,6 +587,17 @@ func TestCancelWithoutSendCancelOnContextCanceled(t *testing.T) { calls := relaytest.NewMockStats() calls.Add(ts.ServiceName(), ts.ServiceName(), "ctxWait").Failed("timeout").End() ts.AssertRelayStats(calls) + + ts.AddPostFn(func() { + // Validating these at the end of the test, when server has fully processed the cancellation. + if tt.wantCancelRequested && !ts.HasRelay() { + serverStats.Expected.IncCounter("inbound.cancels.requested", ts.Server().StatsTags(), 1) + serverStats.ValidateExpected(t) + } else { + serverStats.EnsureNotPresent(t, "inbound.cancels.requested") + } + serverStats.EnsureNotPresent(t, "inbound.cancels.honored") + }) }) }) } diff --git a/inbound.go b/inbound.go index d6e70ab9..dd525237 100644 --- a/inbound.go +++ b/inbound.go @@ -144,6 +144,8 @@ func (c *Connection) handleCallReqContinue(frame *Frame) bool { } func (c *Connection) handleCancel(frame *Frame) bool { + c.statsReporter.IncCounter("inbound.cancels.requested", c.commonStatsTags, 1) + if !c.opts.PropagateCancel { if c.log.Enabled(LogLevelDebug) { c.log.Debugf("Ignoring cancel for %v", frame.Header.ID) @@ -151,6 +153,8 @@ func (c *Connection) handleCancel(frame *Frame) bool { return true } + c.statsReporter.IncCounter("inbound.cancels.honored", c.commonStatsTags, 1) + c.inbound.handleCancel(frame) // Free the frame, as it's consumed immediately. diff --git a/stats_utils_test.go b/stats_utils_test.go index 312eeb42..e820dc79 100644 --- a/stats_utils_test.go +++ b/stats_utils_test.go @@ -125,9 +125,28 @@ func (r *recordingStatsReporter) Validate(t *testing.T) { assert.Equal(t, keysMap(r.Expected.Values), keysMap(r.Values), "Metric keys are different") - for counterKey, counter := range r.Values { - expectedCounter, ok := r.Expected.Values[counterKey] - if !ok { + r.validateExpectedLocked(t) +} + +// ValidateExpected only validates metrics added to expected rather than all recorded metrics. +func (r *recordingStatsReporter) ValidateExpected(t testing.TB) { + r.Lock() + defer r.Unlock() + + r.validateExpectedLocked(t) +} + +func (r *recordingStatsReporter) EnsureNotPresent(t testing.TB, counter string) { + r.Lock() + defer r.Unlock() + + assert.NotContains(t, r.Values, counter, "metric should not be present") +} + +func (r *recordingStatsReporter) validateExpectedLocked(t testing.TB) { + for counterKey, expectedCounter := range r.Expected.Values { + counter, ok := r.Values[counterKey] + if !assert.True(t, ok, "expected %v not found", counterKey) { continue } diff --git a/testutils/test_server.go b/testutils/test_server.go index f3abae47..c8353831 100644 --- a/testutils/test_server.go +++ b/testutils/test_server.go @@ -400,6 +400,11 @@ func (ts *TestServer) verify(ch *tchannel.Channel) { assert.NoError(ts, errs, "Verification failed. Channel state:\n%v", IntrospectJSON(ch, nil /* opts */)) } +// AddPostFn registers a function that will be executed after channels are closed. +func (ts *TestServer) AddPostFn(fn func()) { + ts.postFns = append(ts.postFns, fn) +} + func (ts *TestServer) post() { if !ts.Failed() { for _, ch := range ts.channels {