Skip to content

Commit

Permalink
HeaderPropagation: reset AsyncLocal per request
Browse files Browse the repository at this point in the history
As Kestrel can bleed the AsyncLocal across requests,
see dotnet#13991.
  • Loading branch information
alefranz committed Jan 12, 2020
1 parent e6af4bf commit fbe7ecf
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ public HeaderPropagationMiddleware(RequestDelegate next, IOptions<HeaderPropagat
_values = values ?? throw new ArgumentNullException(nameof(values));
}

public Task Invoke(HttpContext context)
// This needs to be async as otherwise the AsyncLocal could bleed across requests, see https://github.com/aspnet/AspNetCore/issues/13991.
public async Task Invoke(HttpContext context)
{
// We need to intialize the headers because the message handler will use this to detect misconfiguration.
var headers = _values.Headers ??= new Dictionary<string, StringValues>(StringComparer.OrdinalIgnoreCase);
Expand All @@ -56,7 +57,7 @@ public Task Invoke(HttpContext context)
}
}

return _next.Invoke(context);
await _next.Invoke(context);
}

private static StringValues GetValue(HttpContext context, HeaderPropagationEntry entry)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Collections.Generic;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Options;
Expand All @@ -14,7 +16,11 @@ public class HeaderPropagationMiddlewareTest
public HeaderPropagationMiddlewareTest()
{
Context = new DefaultHttpContext();
Next = ctx => Task.CompletedTask;
Next = ctx =>
{
CapturedHeaders = State.Headers;
return Task.CompletedTask;
};
Configuration = new HeaderPropagationOptions();
State = new HeaderPropagationValues();
Middleware = new HeaderPropagationMiddleware(Next,
Expand All @@ -24,8 +30,10 @@ public HeaderPropagationMiddlewareTest()

public DefaultHttpContext Context { get; set; }
public RequestDelegate Next { get; set; }
public Action Assertion { get; set; }
public HeaderPropagationOptions Configuration { get; set; }
public HeaderPropagationValues State { get; set; }
public IDictionary<string, StringValues> CapturedHeaders { get; set; }
public HeaderPropagationMiddleware Middleware { get; set; }

[Fact]
Expand All @@ -39,8 +47,8 @@ public async Task HeaderInRequest_AddCorrectValue()
await Middleware.Invoke(Context);

// Assert
Assert.Contains("in", State.Headers.Keys);
Assert.Equal(new[] { "test" }, State.Headers["in"]);
Assert.Contains("in", CapturedHeaders.Keys);
Assert.Equal(new[] { "test" }, CapturedHeaders["in"]);
}

[Fact]
Expand All @@ -53,7 +61,7 @@ public async Task NoHeaderInRequest_DoesNotAddIt()
await Middleware.Invoke(Context);

// Assert
Assert.Empty(State.Headers);
Assert.Empty(CapturedHeaders);
}

[Fact]
Expand All @@ -66,7 +74,7 @@ public async Task HeaderInRequest_NotInOptions_DoesNotAddIt()
await Middleware.Invoke(Context);

// Assert
Assert.Empty(State.Headers);
Assert.Empty(CapturedHeaders);
}

[Fact]
Expand All @@ -82,10 +90,10 @@ public async Task MultipleHeadersInRequest_AddAllHeaders()
await Middleware.Invoke(Context);

// Assert
Assert.Contains("in", State.Headers.Keys);
Assert.Equal(new[] { "test" }, State.Headers["in"]);
Assert.Contains("another", State.Headers.Keys);
Assert.Equal(new[] { "test2" }, State.Headers["another"]);
Assert.Contains("in", CapturedHeaders.Keys);
Assert.Equal(new[] { "test" }, CapturedHeaders["in"]);
Assert.Contains("another", CapturedHeaders.Keys);
Assert.Equal(new[] { "test2" }, CapturedHeaders["another"]);
}

[Theory]
Expand All @@ -101,7 +109,7 @@ public async Task HeaderEmptyInRequest_DoesNotAddIt(string headerValue)
await Middleware.Invoke(Context);

// Assert
Assert.DoesNotContain("in", State.Headers.Keys);
Assert.DoesNotContain("in", CapturedHeaders.Keys);
}

[Theory]
Expand All @@ -127,8 +135,8 @@ public async Task UsesValueFilter(string[] filterValues, string[] expectedValues
await Middleware.Invoke(Context);

// Assert
Assert.Contains("in", State.Headers.Keys);
Assert.Equal(expectedValues, State.Headers["in"]);
Assert.Contains("in", CapturedHeaders.Keys);
Assert.Equal(expectedValues, CapturedHeaders["in"]);
Assert.Equal("in", receivedName);
Assert.Equal(new StringValues("value"), receivedValue);
Assert.Same(Context, receivedContext);
Expand All @@ -145,8 +153,8 @@ public async Task PreferValueFilter_OverRequestHeader()
await Middleware.Invoke(Context);

// Assert
Assert.Contains("in", State.Headers.Keys);
Assert.Equal("test", State.Headers["in"]);
Assert.Contains("in", CapturedHeaders.Keys);
Assert.Equal("test", CapturedHeaders["in"]);
}

[Fact]
Expand All @@ -159,7 +167,7 @@ public async Task EmptyValuesFromValueFilter_DoesNotAddIt()
await Middleware.Invoke(Context);

// Assert
Assert.DoesNotContain("in", State.Headers.Keys);
Assert.DoesNotContain("in", CapturedHeaders.Keys);
}

[Fact]
Expand All @@ -174,8 +182,46 @@ public async Task MultipleEntries_AddsFirstToProduceValue()
await Middleware.Invoke(Context);

// Assert
Assert.Contains("in", State.Headers.Keys);
Assert.Equal("Test", State.Headers["in"]);
Assert.Contains("in", CapturedHeaders.Keys);
Assert.Equal("Test", CapturedHeaders["in"]);
}

[Fact]
public async Task HeaderInRequest_WithBleedAsyncLocal_HasCorrectValue()
{
// Arrange
Configuration.Headers.Add("in");

// Process first request
Context.Request.Headers.Add("in", "dirty");
await Middleware.Invoke(Context);

// Process second request
Context = new DefaultHttpContext();
Context.Request.Headers.Add("in", "test");
await Middleware.Invoke(Context);

// Assert
Assert.Contains("in", CapturedHeaders.Keys);
Assert.Equal(new[] { "test" }, CapturedHeaders["in"]);
}

[Fact]
public async Task NoHeaderInRequest_WithBleedAsyncLocal_DoesNotHaveIt()
{
// Arrange
Configuration.Headers.Add("in");

// Process first request
Context.Request.Headers.Add("in", "dirty");
await Middleware.Invoke(Context);

// Process second request
Context = new DefaultHttpContext();
await Middleware.Invoke(Context);

// Assert
Assert.Empty(CapturedHeaders);
}
}
}

0 comments on commit fbe7ecf

Please sign in to comment.