From 5dca0f12e0ae4fb4b76f42c7de9364c6fe440676 Mon Sep 17 00:00:00 2001 From: David Lee <10739819+DL444@users.noreply.github.com> Date: Thu, 9 Nov 2023 07:47:59 +0800 Subject: [PATCH] Fix incorrect invocation of functions with names only differ in casing (#2033) Co-authored-by: Fabio Cavalcante --- .../FunctionExecutorGenerator.Emitter.cs | 4 +- sdk/release_notes.md | 1 + .../E2EApps/E2EApp/Http/BasicHttpFunctions.cs | 26 ++++ test/E2ETests/E2ETests/HttpEndToEndTests.cs | 2 + .../FunctionExecutorGeneratorTests.cs | 123 +++++++++++++++--- 5 files changed, 134 insertions(+), 22 deletions(-) diff --git a/sdk/Sdk.Generators/FunctionExecutor/FunctionExecutorGenerator.Emitter.cs b/sdk/Sdk.Generators/FunctionExecutor/FunctionExecutorGenerator.Emitter.cs index d74bc645a..2d4c3e5fc 100644 --- a/sdk/Sdk.Generators/FunctionExecutor/FunctionExecutorGenerator.Emitter.cs +++ b/sdk/Sdk.Generators/FunctionExecutor/FunctionExecutorGenerator.Emitter.cs @@ -124,14 +124,16 @@ private static string GetMethodBody(IEnumerable functions) var inputArguments = inputBindingResult.Values; """); + bool first = true; foreach (ExecutableFunction function in functions) { sb.Append($$""" - if (string.Equals(context.FunctionDefinition.EntryPoint, "{{function.EntryPoint}}", StringComparison.OrdinalIgnoreCase)) + {{(first ? string.Empty : "else ")}}if (string.Equals(context.FunctionDefinition.EntryPoint, "{{function.EntryPoint}}", StringComparison.Ordinal)) { """); + first = false; int functionParamCounter = 0; var functionParamList = new List(); foreach (var argumentTypeName in function.ParameterTypeNames) diff --git a/sdk/release_notes.md b/sdk/release_notes.md index 577369971..10507d77b 100644 --- a/sdk/release_notes.md +++ b/sdk/release_notes.md @@ -18,3 +18,4 @@ - Updated source generated versions of IFunctionExecutor to use `global::` namespace prefix to avoid build errors for function class with the same name as its containing namespace. (#1993) - Updated source generated versions of IFunctionExecutor to include XML documentation for all public types and members - Updated source generated versions of IFunctionMedatadaProvider to include XML documentation for all public types and members +- Updated source generated versions of IFunctionExecutor to use case-sensitive comparison to fix incorrect invocation of functions with method names only differ in casing. (#2003) diff --git a/test/E2ETests/E2EApps/E2EApp/Http/BasicHttpFunctions.cs b/test/E2ETests/E2EApps/E2EApp/Http/BasicHttpFunctions.cs index f9ea25e69..4cb1a949c 100644 --- a/test/E2ETests/E2EApps/E2EApp/Http/BasicHttpFunctions.cs +++ b/test/E2ETests/E2EApps/E2EApp/Http/BasicHttpFunctions.cs @@ -11,6 +11,32 @@ namespace Microsoft.Azure.Functions.Worker.E2EApp { public static class BasicHttpFunctions { + [Function("HelloPascal")] + public static HttpResponseData Hello( + [HttpTrigger(AuthorizationLevel.Anonymous, "get", "post", Route = null)] HttpRequestData req, + FunctionContext context) + { + var logger = context.GetLogger(nameof(Hello)); + logger.LogInformation(".NET Worker HTTP trigger function processed a request"); + + var response = req.CreateResponse(HttpStatusCode.OK); + response.WriteString("Hello!"); + return response; + } + + [Function("HelloAllCaps")] + public static HttpResponseData HELLO( + [HttpTrigger(AuthorizationLevel.Anonymous, "get", "post", Route = null)] HttpRequestData req, + FunctionContext context) + { + var logger = context.GetLogger(nameof(HELLO)); + logger.LogInformation(".NET Worker HTTP trigger function processed a request"); + + var response = req.CreateResponse(HttpStatusCode.OK); + response.WriteString("HELLO!"); + return response; + } + [Function(nameof(HelloFromQuery))] public static HttpResponseData HelloFromQuery( [HttpTrigger(AuthorizationLevel.Anonymous, "get", "post", Route = null)] HttpRequestData req, diff --git a/test/E2ETests/E2ETests/HttpEndToEndTests.cs b/test/E2ETests/E2ETests/HttpEndToEndTests.cs index 5c6da07b6..764043d40 100644 --- a/test/E2ETests/E2ETests/HttpEndToEndTests.cs +++ b/test/E2ETests/E2ETests/HttpEndToEndTests.cs @@ -24,6 +24,8 @@ public HttpEndToEndTests(FunctionAppFixture fixture, ITestOutputHelper testOutpu } [Theory] + [InlineData("HelloPascal", "", HttpStatusCode.OK, "Hello!")] + [InlineData("HelloAllCaps", "", HttpStatusCode.OK, "HELLO!")] [InlineData("HelloFromQuery", "?name=Test", HttpStatusCode.OK, "Hello Test")] [InlineData("HelloFromQuery", "?name=John&lastName=Doe", HttpStatusCode.OK, "Hello John")] [InlineData("HelloFromQuery", "?emptyProperty=&name=Jane", HttpStatusCode.OK, "Hello Jane")] diff --git a/test/Sdk.Generator.Tests/FunctionExecutor/FunctionExecutorGeneratorTests.cs b/test/Sdk.Generator.Tests/FunctionExecutor/FunctionExecutorGeneratorTests.cs index a0bf57d76..b437b295e 100644 --- a/test/Sdk.Generator.Tests/FunctionExecutor/FunctionExecutorGeneratorTests.cs +++ b/test/Sdk.Generator.Tests/FunctionExecutor/FunctionExecutorGeneratorTests.cs @@ -133,29 +133,29 @@ public async ValueTask ExecuteAsync(FunctionContext context) var inputBindingResult = await inputBindingFeature.BindFunctionInputAsync(context)!; var inputArguments = inputBindingResult.Values; - if (string.Equals(context.FunctionDefinition.EntryPoint, ""MyCompany.MyHttpTriggers.Foo"", StringComparison.OrdinalIgnoreCase)) + if (string.Equals(context.FunctionDefinition.EntryPoint, ""MyCompany.MyHttpTriggers.Foo"", StringComparison.Ordinal)) {{ var instanceType = types[""MyCompany.MyHttpTriggers""]; var i = _functionActivator.CreateInstance(instanceType, context) as global::MyCompany.MyHttpTriggers; context.GetInvocationResult().Value = i.Foo((global::Microsoft.Azure.Functions.Worker.Http.HttpRequestData)inputArguments[0], (global::Microsoft.Azure.Functions.Worker.FunctionContext)inputArguments[1]); }} - if (string.Equals(context.FunctionDefinition.EntryPoint, ""MyCompany.MyHttpTriggers2.Bar"", StringComparison.OrdinalIgnoreCase)) + else if (string.Equals(context.FunctionDefinition.EntryPoint, ""MyCompany.MyHttpTriggers2.Bar"", StringComparison.Ordinal)) {{ var instanceType = types[""MyCompany.MyHttpTriggers2""]; var i = _functionActivator.CreateInstance(instanceType, context) as global::MyCompany.MyHttpTriggers2; context.GetInvocationResult().Value = i.Bar((global::Microsoft.Azure.Functions.Worker.Http.HttpRequestData)inputArguments[0]); }} - if (string.Equals(context.FunctionDefinition.EntryPoint, ""MyCompany.Foo.MyAsyncStaticMethod"", StringComparison.OrdinalIgnoreCase)) + else if (string.Equals(context.FunctionDefinition.EntryPoint, ""MyCompany.Foo.MyAsyncStaticMethod"", StringComparison.Ordinal)) {{ context.GetInvocationResult().Value = await global::MyCompany.Foo.MyAsyncStaticMethod((string)inputArguments[0]); }} - if (string.Equals(context.FunctionDefinition.EntryPoint, ""MyCompany.QueueTriggers.Run"", StringComparison.OrdinalIgnoreCase)) + else if (string.Equals(context.FunctionDefinition.EntryPoint, ""MyCompany.QueueTriggers.Run"", StringComparison.Ordinal)) {{ var instanceType = types[""MyCompany.QueueTriggers""]; var i = _functionActivator.CreateInstance(instanceType, context) as global::MyCompany.QueueTriggers; i.Run((global::Azure.Storage.Queues.Models.QueueMessage)inputArguments[0]); }} - if (string.Equals(context.FunctionDefinition.EntryPoint, ""MyCompany.QueueTriggers.Run2"", StringComparison.OrdinalIgnoreCase)) + else if (string.Equals(context.FunctionDefinition.EntryPoint, ""MyCompany.QueueTriggers.Run2"", StringComparison.Ordinal)) {{ var instanceType = types[""MyCompany.QueueTriggers""]; var i = _functionActivator.CreateInstance(instanceType, context) as global::MyCompany.QueueTriggers; @@ -240,13 +240,13 @@ public async ValueTask ExecuteAsync(FunctionContext context) var inputBindingResult = await inputBindingFeature.BindFunctionInputAsync(context)!; var inputArguments = inputBindingResult.Values; - if (string.Equals(context.FunctionDefinition.EntryPoint, ""MyCompany.MyHttpTriggers.Run1"", StringComparison.OrdinalIgnoreCase)) + if (string.Equals(context.FunctionDefinition.EntryPoint, ""MyCompany.MyHttpTriggers.Run1"", StringComparison.Ordinal)) {{ var instanceType = types[""MyCompany.MyHttpTriggers""]; var i = _functionActivator.CreateInstance(instanceType, context) as global::MyCompany.MyHttpTriggers; context.GetInvocationResult().Value = i.Run1((global::Microsoft.Azure.Functions.Worker.Http.HttpRequestData)inputArguments[0]); }} - if (string.Equals(context.FunctionDefinition.EntryPoint, ""MyCompany.MyHttpTriggers.Run2"", StringComparison.OrdinalIgnoreCase)) + else if (string.Equals(context.FunctionDefinition.EntryPoint, ""MyCompany.MyHttpTriggers.Run2"", StringComparison.Ordinal)) {{ var instanceType = types[""MyCompany.MyHttpTriggers""]; var i = _functionActivator.CreateInstance(instanceType, context) as global::MyCompany.MyHttpTriggers; @@ -373,47 +373,47 @@ public async ValueTask ExecuteAsync(FunctionContext context) var inputBindingResult = await inputBindingFeature.BindFunctionInputAsync(context)!; var inputArguments = inputBindingResult.Values; - if (string.Equals(context.FunctionDefinition.EntryPoint, ""FunctionApp26.MyQTriggers.MyTaskStaticMethod"", StringComparison.OrdinalIgnoreCase)) + if (string.Equals(context.FunctionDefinition.EntryPoint, ""FunctionApp26.MyQTriggers.MyTaskStaticMethod"", StringComparison.Ordinal)) {{ await global::FunctionApp26.MyQTriggers.MyTaskStaticMethod((string)inputArguments[0]); }} - if (string.Equals(context.FunctionDefinition.EntryPoint, ""FunctionApp26.MyQTriggers.MyAsyncStaticMethod"", StringComparison.OrdinalIgnoreCase)) + else if (string.Equals(context.FunctionDefinition.EntryPoint, ""FunctionApp26.MyQTriggers.MyAsyncStaticMethod"", StringComparison.Ordinal)) {{ context.GetInvocationResult().Value = await global::FunctionApp26.MyQTriggers.MyAsyncStaticMethod((string)inputArguments[0]); }} - if (string.Equals(context.FunctionDefinition.EntryPoint, ""FunctionApp26.MyQTriggers.MyVoidStaticMethod"", StringComparison.OrdinalIgnoreCase)) + else if (string.Equals(context.FunctionDefinition.EntryPoint, ""FunctionApp26.MyQTriggers.MyVoidStaticMethod"", StringComparison.Ordinal)) {{ global::FunctionApp26.MyQTriggers.MyVoidStaticMethod((string)inputArguments[0]); }} - if (string.Equals(context.FunctionDefinition.EntryPoint, ""FunctionApp26.MyQTriggers.MyAsyncStaticMethodWithReturn"", StringComparison.OrdinalIgnoreCase)) + else if (string.Equals(context.FunctionDefinition.EntryPoint, ""FunctionApp26.MyQTriggers.MyAsyncStaticMethodWithReturn"", StringComparison.Ordinal)) {{ context.GetInvocationResult().Value = await global::FunctionApp26.MyQTriggers.MyAsyncStaticMethodWithReturn((string)inputArguments[0], (string)inputArguments[1]); }} - if (string.Equals(context.FunctionDefinition.EntryPoint, ""FunctionApp26.MyQTriggers.MyValueTaskOfTStaticAsyncMethod"", StringComparison.OrdinalIgnoreCase)) + else if (string.Equals(context.FunctionDefinition.EntryPoint, ""FunctionApp26.MyQTriggers.MyValueTaskOfTStaticAsyncMethod"", StringComparison.Ordinal)) {{ context.GetInvocationResult().Value = await global::FunctionApp26.MyQTriggers.MyValueTaskOfTStaticAsyncMethod((string)inputArguments[0]); }} - if (string.Equals(context.FunctionDefinition.EntryPoint, ""FunctionApp26.MyQTriggers.MyValueTaskStaticAsyncMethod2"", StringComparison.OrdinalIgnoreCase)) + else if (string.Equals(context.FunctionDefinition.EntryPoint, ""FunctionApp26.MyQTriggers.MyValueTaskStaticAsyncMethod2"", StringComparison.Ordinal)) {{ await global::FunctionApp26.MyQTriggers.MyValueTaskStaticAsyncMethod2((string)inputArguments[0]); }} - if (string.Equals(context.FunctionDefinition.EntryPoint, ""FunctionApp26.BlobTriggers.Run"", StringComparison.OrdinalIgnoreCase)) + else if (string.Equals(context.FunctionDefinition.EntryPoint, ""FunctionApp26.BlobTriggers.Run"", StringComparison.Ordinal)) {{ await global::FunctionApp26.BlobTriggers.Run((global::System.IO.Stream)inputArguments[0], (string)inputArguments[1]); }} - if (string.Equals(context.FunctionDefinition.EntryPoint, ""FunctionApp26.EventHubTriggers.Run1"", StringComparison.OrdinalIgnoreCase)) + else if (string.Equals(context.FunctionDefinition.EntryPoint, ""FunctionApp26.EventHubTriggers.Run1"", StringComparison.Ordinal)) {{ global::FunctionApp26.EventHubTriggers.Run1((global::Azure.Messaging.EventHubs.EventData[])inputArguments[0]); }} - if (string.Equals(context.FunctionDefinition.EntryPoint, ""FunctionApp26.EventHubTriggers.Run2"", StringComparison.OrdinalIgnoreCase)) + else if (string.Equals(context.FunctionDefinition.EntryPoint, ""FunctionApp26.EventHubTriggers.Run2"", StringComparison.Ordinal)) {{ context.GetInvocationResult().Value = global::FunctionApp26.EventHubTriggers.Run2((global::Azure.Messaging.EventHubs.EventData)inputArguments[0]); }} - if (string.Equals(context.FunctionDefinition.EntryPoint, ""FunctionApp26.EventHubTriggers.RunAsync1"", StringComparison.OrdinalIgnoreCase)) + else if (string.Equals(context.FunctionDefinition.EntryPoint, ""FunctionApp26.EventHubTriggers.RunAsync1"", StringComparison.Ordinal)) {{ await global::FunctionApp26.EventHubTriggers.RunAsync1((global::Azure.Messaging.EventHubs.EventData[])inputArguments[0]); }} - if (string.Equals(context.FunctionDefinition.EntryPoint, ""FunctionApp26.EventHubTriggers.RunAsync2"", StringComparison.OrdinalIgnoreCase)) + else if (string.Equals(context.FunctionDefinition.EntryPoint, ""FunctionApp26.EventHubTriggers.RunAsync2"", StringComparison.Ordinal)) {{ await global::FunctionApp26.EventHubTriggers.RunAsync2((global::Azure.Messaging.EventHubs.EventData[])inputArguments[0]); }} @@ -486,7 +486,7 @@ public async ValueTask ExecuteAsync(FunctionContext context) var inputBindingResult = await inputBindingFeature.BindFunctionInputAsync(context)!; var inputArguments = inputBindingResult.Values; - if (string.Equals(context.FunctionDefinition.EntryPoint, ""MyCompany.MyHttpTriggers.Run1"", StringComparison.OrdinalIgnoreCase)) + if (string.Equals(context.FunctionDefinition.EntryPoint, ""MyCompany.MyHttpTriggers.Run1"", StringComparison.Ordinal)) {{ var instanceType = types[""MyCompany.MyHttpTriggers""]; var i = _functionActivator.CreateInstance(instanceType, context) as global::MyCompany.MyHttpTriggers; @@ -571,13 +571,13 @@ public async ValueTask ExecuteAsync(FunctionContext context) var inputBindingResult = await inputBindingFeature.BindFunctionInputAsync(context)!; var inputArguments = inputBindingResult.Values; - if (string.Equals(context.FunctionDefinition.EntryPoint, ""TestProject.TestProject.Foo"", StringComparison.OrdinalIgnoreCase)) + if (string.Equals(context.FunctionDefinition.EntryPoint, ""TestProject.TestProject.Foo"", StringComparison.Ordinal)) {{ var instanceType = types[""TestProject.TestProject""]; var i = _functionActivator.CreateInstance(instanceType, context) as global::TestProject.TestProject; context.GetInvocationResult().Value = i.Foo((global::Microsoft.Azure.Functions.Worker.Http.HttpRequestData)inputArguments[0], (global::Microsoft.Azure.Functions.Worker.FunctionContext)inputArguments[1]); }} - if (string.Equals(context.FunctionDefinition.EntryPoint, ""TestProject.TestProject.FooStatic"", StringComparison.OrdinalIgnoreCase)) + else if (string.Equals(context.FunctionDefinition.EntryPoint, ""TestProject.TestProject.FooStatic"", StringComparison.Ordinal)) {{ context.GetInvocationResult().Value = global::TestProject.TestProject.FooStatic((global::Microsoft.Azure.Functions.Worker.Http.HttpRequestData)inputArguments[0], (global::Microsoft.Azure.Functions.Worker.FunctionContext)inputArguments[1]); }} @@ -593,6 +593,87 @@ public async ValueTask ExecuteAsync(FunctionContext context) expectedOutput); } + [Fact] + public async Task FunctionsWithSameNameExceptForCasing() + { + const string inputSourceCode = @" +using System; +using System.Threading.Tasks; +using Microsoft.Extensions.Hosting; +using Azure.Storage.Queues.Models; +using Microsoft.Azure.Functions.Worker; +using Microsoft.Azure.Functions.Worker.Http; +using Microsoft.Extensions.Logging; +namespace MyCompany +{ + public class MyHttpTriggers + { + [Function(""FunctionA"")] + public HttpResponseData Hello([HttpTrigger(AuthorizationLevel.User, ""get"")] HttpRequestData r, FunctionContext c) + { + return r.CreateResponse(System.Net.HttpStatusCode.OK); + } + + [Function(""FunctionB"")] + public static HttpResponseData HELLO([HttpTrigger(AuthorizationLevel.User, ""get"")] HttpRequestData r, FunctionContext c) + { + return r.CreateResponse(System.Net.HttpStatusCode.OK); + } + } +} +"; + var expectedOutput = $@"// +using System; +using System.Threading.Tasks; +using System.Collections.Generic; +using Microsoft.Extensions.Hosting; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Azure.Functions.Worker; +using Microsoft.Azure.Functions.Worker.Context.Features; +using Microsoft.Azure.Functions.Worker.Invocation; +namespace TestProject +{{ + internal class DirectFunctionExecutor : IFunctionExecutor + {{ + private readonly IFunctionActivator _functionActivator; + private readonly Dictionary types = new() + {{ + {{ ""MyCompany.MyHttpTriggers"", Type.GetType(""MyCompany.MyHttpTriggers"")! }} + }}; + + public DirectFunctionExecutor(IFunctionActivator functionActivator) + {{ + _functionActivator = functionActivator ?? throw new ArgumentNullException(nameof(functionActivator)); + }} + + public async ValueTask ExecuteAsync(FunctionContext context) + {{ + var inputBindingFeature = context.Features.Get()!; + var inputBindingResult = await inputBindingFeature.BindFunctionInputAsync(context)!; + var inputArguments = inputBindingResult.Values; + + if (string.Equals(context.FunctionDefinition.EntryPoint, ""MyCompany.MyHttpTriggers.Hello"", StringComparison.Ordinal)) + {{ + var instanceType = types[""MyCompany.MyHttpTriggers""]; + var i = _functionActivator.CreateInstance(instanceType, context) as global::MyCompany.MyHttpTriggers; + context.GetInvocationResult().Value = i.Hello((global::Microsoft.Azure.Functions.Worker.Http.HttpRequestData)inputArguments[0], (global::Microsoft.Azure.Functions.Worker.FunctionContext)inputArguments[1]); + }} + else if (string.Equals(context.FunctionDefinition.EntryPoint, ""MyCompany.MyHttpTriggers.HELLO"", StringComparison.Ordinal)) + {{ + context.GetInvocationResult().Value = global::MyCompany.MyHttpTriggers.HELLO((global::Microsoft.Azure.Functions.Worker.Http.HttpRequestData)inputArguments[0], (global::Microsoft.Azure.Functions.Worker.FunctionContext)inputArguments[1]); + }} + }} + }} +{GetExpectedExtensionMethodCode()} +}}".Replace("'", "\""); + + await TestHelpers.RunTestAsync( + _referencedAssemblies, + inputSourceCode, + Constants.FileNames.GeneratedFunctionExecutor, + expectedOutput); + } + private static string GetExpectedExtensionMethodCode(bool includeAutoStartupType = false) { if (includeAutoStartupType)