diff --git a/src/Middleware/HeaderPropagation/src/HeaderPropagationMiddleware.cs b/src/Middleware/HeaderPropagation/src/HeaderPropagationMiddleware.cs index f62c9e4a72bd..bd24fb63e2e5 100644 --- a/src/Middleware/HeaderPropagation/src/HeaderPropagationMiddleware.cs +++ b/src/Middleware/HeaderPropagation/src/HeaderPropagationMiddleware.cs @@ -33,7 +33,8 @@ public HeaderPropagationMiddleware(RequestDelegate next, IOptions(StringComparer.OrdinalIgnoreCase); @@ -56,7 +57,7 @@ public Task Invoke(HttpContext context) } } - return _next.Invoke(context); + await _next.Invoke(context); } private static StringValues GetValue(HttpContext context, HeaderPropagationEntry entry) diff --git a/src/Middleware/HeaderPropagation/test/HeaderPropagationMiddlewareTest.cs b/src/Middleware/HeaderPropagation/test/HeaderPropagationMiddlewareTest.cs index f6576d2d688d..99bb4997a5ef 100644 --- a/src/Middleware/HeaderPropagation/test/HeaderPropagationMiddlewareTest.cs +++ b/src/Middleware/HeaderPropagation/test/HeaderPropagationMiddlewareTest.cs @@ -1,6 +1,8 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +using System; +using System.Collections.Generic; using System.Threading.Tasks; using Microsoft.AspNetCore.Http; using Microsoft.Extensions.Options; @@ -14,7 +16,11 @@ public class HeaderPropagationMiddlewareTest public HeaderPropagationMiddlewareTest() { Context = new DefaultHttpContext(); - Next = ctx => Task.CompletedTask; + Next = ctx => + { + CapturedHeaders = State.Headers; + return Task.CompletedTask; + }; Configuration = new HeaderPropagationOptions(); State = new HeaderPropagationValues(); Middleware = new HeaderPropagationMiddleware(Next, @@ -24,8 +30,10 @@ public HeaderPropagationMiddlewareTest() public DefaultHttpContext Context { get; set; } public RequestDelegate Next { get; set; } + public Action Assertion { get; set; } public HeaderPropagationOptions Configuration { get; set; } public HeaderPropagationValues State { get; set; } + public IDictionary CapturedHeaders { get; set; } public HeaderPropagationMiddleware Middleware { get; set; } [Fact] @@ -39,8 +47,8 @@ public async Task HeaderInRequest_AddCorrectValue() await Middleware.Invoke(Context); // Assert - Assert.Contains("in", State.Headers.Keys); - Assert.Equal(new[] { "test" }, State.Headers["in"]); + Assert.Contains("in", CapturedHeaders.Keys); + Assert.Equal(new[] { "test" }, CapturedHeaders["in"]); } [Fact] @@ -53,7 +61,7 @@ public async Task NoHeaderInRequest_DoesNotAddIt() await Middleware.Invoke(Context); // Assert - Assert.Empty(State.Headers); + Assert.Empty(CapturedHeaders); } [Fact] @@ -66,7 +74,7 @@ public async Task HeaderInRequest_NotInOptions_DoesNotAddIt() await Middleware.Invoke(Context); // Assert - Assert.Empty(State.Headers); + Assert.Empty(CapturedHeaders); } [Fact] @@ -82,10 +90,10 @@ public async Task MultipleHeadersInRequest_AddAllHeaders() await Middleware.Invoke(Context); // Assert - Assert.Contains("in", State.Headers.Keys); - Assert.Equal(new[] { "test" }, State.Headers["in"]); - Assert.Contains("another", State.Headers.Keys); - Assert.Equal(new[] { "test2" }, State.Headers["another"]); + Assert.Contains("in", CapturedHeaders.Keys); + Assert.Equal(new[] { "test" }, CapturedHeaders["in"]); + Assert.Contains("another", CapturedHeaders.Keys); + Assert.Equal(new[] { "test2" }, CapturedHeaders["another"]); } [Theory] @@ -101,7 +109,7 @@ public async Task HeaderEmptyInRequest_DoesNotAddIt(string headerValue) await Middleware.Invoke(Context); // Assert - Assert.DoesNotContain("in", State.Headers.Keys); + Assert.DoesNotContain("in", CapturedHeaders.Keys); } [Theory] @@ -127,8 +135,8 @@ public async Task UsesValueFilter(string[] filterValues, string[] expectedValues await Middleware.Invoke(Context); // Assert - Assert.Contains("in", State.Headers.Keys); - Assert.Equal(expectedValues, State.Headers["in"]); + Assert.Contains("in", CapturedHeaders.Keys); + Assert.Equal(expectedValues, CapturedHeaders["in"]); Assert.Equal("in", receivedName); Assert.Equal(new StringValues("value"), receivedValue); Assert.Same(Context, receivedContext); @@ -145,8 +153,8 @@ public async Task PreferValueFilter_OverRequestHeader() await Middleware.Invoke(Context); // Assert - Assert.Contains("in", State.Headers.Keys); - Assert.Equal("test", State.Headers["in"]); + Assert.Contains("in", CapturedHeaders.Keys); + Assert.Equal("test", CapturedHeaders["in"]); } [Fact] @@ -159,7 +167,7 @@ public async Task EmptyValuesFromValueFilter_DoesNotAddIt() await Middleware.Invoke(Context); // Assert - Assert.DoesNotContain("in", State.Headers.Keys); + Assert.DoesNotContain("in", CapturedHeaders.Keys); } [Fact] @@ -174,8 +182,46 @@ public async Task MultipleEntries_AddsFirstToProduceValue() await Middleware.Invoke(Context); // Assert - Assert.Contains("in", State.Headers.Keys); - Assert.Equal("Test", State.Headers["in"]); + Assert.Contains("in", CapturedHeaders.Keys); + Assert.Equal("Test", CapturedHeaders["in"]); + } + + [Fact] + public async Task HeaderInRequest_WithBleedAsyncLocal_HasCorrectValue() + { + // Arrange + Configuration.Headers.Add("in"); + + // Process first request + Context.Request.Headers.Add("in", "dirty"); + await Middleware.Invoke(Context); + + // Process second request + Context = new DefaultHttpContext(); + Context.Request.Headers.Add("in", "test"); + await Middleware.Invoke(Context); + + // Assert + Assert.Contains("in", CapturedHeaders.Keys); + Assert.Equal(new[] { "test" }, CapturedHeaders["in"]); + } + + [Fact] + public async Task NoHeaderInRequest_WithBleedAsyncLocal_DoesNotHaveIt() + { + // Arrange + Configuration.Headers.Add("in"); + + // Process first request + Context.Request.Headers.Add("in", "dirty"); + await Middleware.Invoke(Context); + + // Process second request + Context = new DefaultHttpContext(); + await Middleware.Invoke(Context); + + // Assert + Assert.Empty(CapturedHeaders); } } }