diff --git a/Libraries/src/Amazon.Lambda.RuntimeSupport/Bootstrap/LambdaBootstrap.cs b/Libraries/src/Amazon.Lambda.RuntimeSupport/Bootstrap/LambdaBootstrap.cs index 655d1170c..044c842c3 100644 --- a/Libraries/src/Amazon.Lambda.RuntimeSupport/Bootstrap/LambdaBootstrap.cs +++ b/Libraries/src/Amazon.Lambda.RuntimeSupport/Bootstrap/LambdaBootstrap.cs @@ -36,10 +36,22 @@ public class LambdaBootstrap : IDisposable private LambdaBootstrapInitializer _initializer; private LambdaBootstrapHandler _handler; + private bool _ownsHttpClient; private HttpClient _httpClient; internal IRuntimeApiClient Client { get; set; } + /// <summary> + /// Create a LambdaBootstrap that will call the given initializer and handler. + /// </summary> + /// <param name="httpClient">The HTTP client to use with the Lambda runtime.</param> + /// <param name="handler">Delegate called for each invocation of the Lambda function.</param> + /// <param name="initializer">Delegate called to initialize the Lambda function. If not provided the initialization step is skipped.</param> + /// <returns></returns> + public LambdaBootstrap(HttpClient httpClient, LambdaBootstrapHandler handler, LambdaBootstrapInitializer initializer = null) + : this(httpClient, handler, initializer, ownsHttpClient: false) + { } + /// <summary> /// Create a LambdaBootstrap that will call the given initializer and handler. /// </summary> @@ -47,15 +59,8 @@ public class LambdaBootstrap : IDisposable /// <param name="initializer">Delegate called to initialize the Lambda function. If not provided the initialization step is skipped.</param> /// <returns></returns> public LambdaBootstrap(LambdaBootstrapHandler handler, LambdaBootstrapInitializer initializer = null) - { - _handler = handler ?? throw new ArgumentNullException(nameof(handler)); - _initializer = initializer; - _httpClient = new HttpClient - { - Timeout = RuntimeApiHttpTimeout - }; - Client = new RuntimeApiClient(new SystemEnvironmentVariables(), _httpClient); - } + : this(new HttpClient(), handler, initializer, ownsHttpClient: true) + { } /// <summary> /// Create a LambdaBootstrap that will call the given initializer and handler. @@ -67,6 +72,35 @@ public LambdaBootstrap(HandlerWrapper handlerWrapper, LambdaBootstrapInitializer : this(handlerWrapper.Handler, initializer) { } + /// <summary> + /// Create a LambdaBootstrap that will call the given initializer and handler. + /// </summary> + /// <param name="httpClient">The HTTP client to use with the Lambda runtime.</param> + /// <param name="handlerWrapper">The HandlerWrapper to call for each invocation of the Lambda function.</param> + /// <param name="initializer">Delegate called to initialize the Lambda function. If not provided the initialization step is skipped.</param> + /// <returns></returns> + public LambdaBootstrap(HttpClient httpClient, HandlerWrapper handlerWrapper, LambdaBootstrapInitializer initializer = null) + : this(httpClient, handlerWrapper.Handler, initializer, ownsHttpClient: false) + { } + + /// <summary> + /// Create a LambdaBootstrap that will call the given initializer and handler. + /// </summary> + /// <param name="httpClient">The HTTP client to use with the Lambda runtime.</param> + /// <param name="handler">Delegate called for each invocation of the Lambda function.</param> + /// <param name="initializer">Delegate called to initialize the Lambda function. If not provided the initialization step is skipped.</param> + /// <param name="ownsHttpClient">Whether the instance owns the HTTP client and should dispose of it.</param> + /// <returns></returns> + private LambdaBootstrap(HttpClient httpClient, LambdaBootstrapHandler handler, LambdaBootstrapInitializer initializer, bool ownsHttpClient) + { + _httpClient = httpClient ?? throw new ArgumentNullException(nameof(httpClient)); + _handler = handler ?? throw new ArgumentNullException(nameof(handler)); + _ownsHttpClient = ownsHttpClient; + _initializer = initializer; + _httpClient.Timeout = RuntimeApiHttpTimeout; + Client = new RuntimeApiClient(new SystemEnvironmentVariables(), _httpClient); + } + /// <summary> /// Run the initialization Func if provided. /// Then run the invoke loop, calling the handler for each invocation. @@ -79,7 +113,14 @@ public LambdaBootstrap(HandlerWrapper handlerWrapper, LambdaBootstrapInitializer while (doStartInvokeLoop && !cancellationToken.IsCancellationRequested) { - await InvokeOnceAsync(); + try + { + await InvokeOnceAsync(cancellationToken); + } + catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested) + { + // Loop cancelled + } } } @@ -96,9 +137,9 @@ internal async Task<bool> InitializeAsync() } } - internal async Task InvokeOnceAsync() + internal async Task InvokeOnceAsync(CancellationToken cancellationToken = default) { - using (var invocation = await Client.GetNextInvocationAsync()) + using (var invocation = await Client.GetNextInvocationAsync(cancellationToken)) { InvocationResponse response = null; bool invokeSucceeded = false; @@ -137,7 +178,7 @@ protected virtual void Dispose(bool disposing) { if (!disposedValue) { - if (disposing) + if (disposing && _ownsHttpClient) { _httpClient?.Dispose(); } diff --git a/Libraries/src/Amazon.Lambda.RuntimeSupport/Client/IRuntimeApiClient.cs b/Libraries/src/Amazon.Lambda.RuntimeSupport/Client/IRuntimeApiClient.cs index f2579f97b..7f148104d 100644 --- a/Libraries/src/Amazon.Lambda.RuntimeSupport/Client/IRuntimeApiClient.cs +++ b/Libraries/src/Amazon.Lambda.RuntimeSupport/Client/IRuntimeApiClient.cs @@ -33,31 +33,35 @@ public interface IRuntimeApiClient /// Report an initialization error as an asynchronous operation. /// </summary> /// <param name="exception">The exception to report.</param> + /// <param name="cancellationToken">The optional cancellation token to use.</param> /// <returns>A Task representing the asynchronous operation.</returns> - Task ReportInitializationErrorAsync(Exception exception); + Task ReportInitializationErrorAsync(Exception exception, CancellationToken cancellationToken = default); /// <summary> /// Send an initialization error with a type string but no other information as an asynchronous operation. /// This can be used to directly control flow in Step Functions without creating an Exception class and throwing it. /// </summary> /// <param name="errorType">The type of the error to report to Lambda. This does not need to be a .NET type name.</param> + /// <param name="cancellationToken">The optional cancellation token to use.</param> /// <returns>A Task representing the asynchronous operation.</returns> - Task ReportInitializationErrorAsync(string errorType); + Task ReportInitializationErrorAsync(string errorType, CancellationToken cancellationToken = default); /// <summary> /// Get the next function invocation from the Runtime API as an asynchronous operation. /// Completes when the next invocation is received. /// </summary> + /// <param name="cancellationToken">The optional cancellation token to use to stop listening for the next invocation.</param> /// <returns>A Task representing the asynchronous operation.</returns> - Task<InvocationRequest> GetNextInvocationAsync(); + Task<InvocationRequest> GetNextInvocationAsync(CancellationToken cancellationToken = default); /// <summary> /// Report an invocation error as an asynchronous operation. /// </summary> /// <param name="awsRequestId">The ID of the function request that caused the error.</param> /// <param name="exception">The exception to report.</param> + /// <param name="cancellationToken">The optional cancellation token to use.</param> /// <returns>A Task representing the asynchronous operation.</returns> - Task ReportInvocationErrorAsync(string awsRequestId, Exception exception); + Task ReportInvocationErrorAsync(string awsRequestId, Exception exception, CancellationToken cancellationToken = default); /// <summary> /// Send an initialization error with a type string but no other information as an asynchronous operation. @@ -65,15 +69,17 @@ public interface IRuntimeApiClient /// </summary> /// <param name="awsRequestId">The ID of the function request that caused the error.</param> /// <param name="errorType">The type of the error to report to Lambda. This does not need to be a .NET type name.</param> + /// <param name="cancellationToken">The optional cancellation token to use.</param> /// <returns>A Task representing the asynchronous operation.</returns> - Task ReportInvocationErrorAsync(string awsRequestId, string errorType); + Task ReportInvocationErrorAsync(string awsRequestId, string errorType, CancellationToken cancellationToken = default); /// <summary> /// Send a response to a function invocation to the Runtime API as an asynchronous operation. /// </summary> /// <param name="awsRequestId">The ID of the function request being responded to.</param> /// <param name="outputStream">The content of the response to the function invocation.</param> + /// <param name="cancellationToken">The optional cancellation token to use.</param> /// <returns></returns> - Task SendResponseAsync(string awsRequestId, Stream outputStream); + Task SendResponseAsync(string awsRequestId, Stream outputStream, CancellationToken cancellationToken = default); } } \ No newline at end of file diff --git a/Libraries/src/Amazon.Lambda.RuntimeSupport/Client/RuntimeApiClient.cs b/Libraries/src/Amazon.Lambda.RuntimeSupport/Client/RuntimeApiClient.cs index c9ccaf75f..f78b5f315 100644 --- a/Libraries/src/Amazon.Lambda.RuntimeSupport/Client/RuntimeApiClient.cs +++ b/Libraries/src/Amazon.Lambda.RuntimeSupport/Client/RuntimeApiClient.cs @@ -65,13 +65,14 @@ internal RuntimeApiClient(IEnvironmentVariables environmentVariables, IInternalR /// Report an initialization error as an asynchronous operation. /// </summary> /// <param name="exception">The exception to report.</param> + /// <param name="cancellationToken">The optional cancellation token to use.</param> /// <returns>A Task representing the asynchronous operation.</returns> - public Task ReportInitializationErrorAsync(Exception exception) + public Task ReportInitializationErrorAsync(Exception exception, CancellationToken cancellationToken = default) { if (exception == null) throw new ArgumentNullException(nameof(exception)); - return _internalClient.ErrorAsync(null, LambdaJsonExceptionWriter.WriteJson(ExceptionInfo.GetExceptionInfo(exception))); + return _internalClient.ErrorAsync(null, LambdaJsonExceptionWriter.WriteJson(ExceptionInfo.GetExceptionInfo(exception)), cancellationToken); } /// <summary> @@ -79,23 +80,25 @@ public Task ReportInitializationErrorAsync(Exception exception) /// This can be used to directly control flow in Step Functions without creating an Exception class and throwing it. /// </summary> /// <param name="errorType">The type of the error to report to Lambda. This does not need to be a .NET type name.</param> + /// <param name="cancellationToken">The optional cancellation token to use.</param> /// <returns>A Task representing the asynchronous operation.</returns> - public Task ReportInitializationErrorAsync(string errorType) + public Task ReportInitializationErrorAsync(string errorType, CancellationToken cancellationToken = default) { if (errorType == null) throw new ArgumentNullException(nameof(errorType)); - return _internalClient.ErrorAsync(errorType, null); + return _internalClient.ErrorAsync(errorType, null, cancellationToken); } /// <summary> /// Get the next function invocation from the Runtime API as an asynchronous operation. /// Completes when the next invocation is received. /// </summary> + /// <param name="cancellationToken">The optional cancellation token to use to stop listening for the next invocation.</param> /// <returns>A Task representing the asynchronous operation.</returns> - public async Task<InvocationRequest> GetNextInvocationAsync() + public async Task<InvocationRequest> GetNextInvocationAsync(CancellationToken cancellationToken = default) { - SwaggerResponse<Stream> response = await _internalClient.NextAsync(System.Threading.CancellationToken.None); + SwaggerResponse<Stream> response = await _internalClient.NextAsync(cancellationToken); var lambdaContext = new LambdaContext(new RuntimeApiHeaders(response.Headers), LambdaEnvironment); return new InvocationRequest @@ -110,8 +113,9 @@ public async Task<InvocationRequest> GetNextInvocationAsync() /// </summary> /// <param name="awsRequestId">The ID of the function request that caused the error.</param> /// <param name="exception">The exception to report.</param> + /// <param name="cancellationToken">The optional cancellation token to use.</param> /// <returns>A Task representing the asynchronous operation.</returns> - public Task ReportInvocationErrorAsync(string awsRequestId, Exception exception) + public Task ReportInvocationErrorAsync(string awsRequestId, Exception exception, CancellationToken cancellationToken = default) { if (awsRequestId == null) throw new ArgumentNullException(nameof(awsRequestId)); @@ -120,7 +124,7 @@ public Task ReportInvocationErrorAsync(string awsRequestId, Exception exception) throw new ArgumentNullException(nameof(exception)); var exceptionInfo = ExceptionInfo.GetExceptionInfo(exception); - return _internalClient.Error2Async(awsRequestId, exceptionInfo.ErrorType, LambdaJsonExceptionWriter.WriteJson(exceptionInfo)); + return _internalClient.Error2Async(awsRequestId, exceptionInfo.ErrorType, LambdaJsonExceptionWriter.WriteJson(exceptionInfo), cancellationToken); } /// <summary> @@ -129,10 +133,11 @@ public Task ReportInvocationErrorAsync(string awsRequestId, Exception exception) /// </summary> /// <param name="awsRequestId">The ID of the function request that caused the error.</param> /// <param name="errorType">The type of the error to report to Lambda. This does not need to be a .NET type name.</param> + /// <param name="cancellationToken">The optional cancellation token to use.</param> /// <returns>A Task representing the asynchronous operation.</returns> - public Task ReportInvocationErrorAsync(string awsRequestId, string errorType) + public Task ReportInvocationErrorAsync(string awsRequestId, string errorType, CancellationToken cancellationToken = default) { - return _internalClient.Error2Async(awsRequestId, errorType, null); + return _internalClient.Error2Async(awsRequestId, errorType, null, cancellationToken); } /// <summary> @@ -140,10 +145,11 @@ public Task ReportInvocationErrorAsync(string awsRequestId, string errorType) /// </summary> /// <param name="awsRequestId">The ID of the function request being responded to.</param> /// <param name="outputStream">The content of the response to the function invocation.</param> + /// <param name="cancellationToken">The optional cancellation token to use.</param> /// <returns></returns> - public async Task SendResponseAsync(string awsRequestId, Stream outputStream) + public async Task SendResponseAsync(string awsRequestId, Stream outputStream, CancellationToken cancellationToken = default) { - await _internalClient.ResponseAsync(awsRequestId, outputStream, CancellationToken.None); + await _internalClient.ResponseAsync(awsRequestId, outputStream, cancellationToken); } } } \ No newline at end of file diff --git a/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.UnitTests/LambdaBootstrapTests.cs b/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.UnitTests/LambdaBootstrapTests.cs index 19a44a18e..020dcf9a7 100644 --- a/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.UnitTests/LambdaBootstrapTests.cs +++ b/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.UnitTests/LambdaBootstrapTests.cs @@ -13,6 +13,7 @@ * permissions and limitations under the License. */ using System; +using System.Net.Http; using System.Text; using System.Threading.Tasks; using Xunit; @@ -29,6 +30,7 @@ public class LambdaBootstrapTests TestInitializer _testInitializer; TestRuntimeApiClient _testRuntimeApiClient; TestEnvironmentVariables _environmentVariables; + HandlerWrapper _testWrapper; public LambdaBootstrapTests() { @@ -36,6 +38,7 @@ public LambdaBootstrapTests() _testRuntimeApiClient = new TestRuntimeApiClient(_environmentVariables); _testInitializer = new TestInitializer(); _testFunction = new TestHandler(); + _testWrapper = HandlerWrapper.GetHandlerWrapper(_testFunction.HandlerVoidVoidSync); } [Fact] @@ -44,6 +47,13 @@ public void ThrowsExceptionForNullHandler() Assert.Throws<ArgumentNullException>("handler", () => { new LambdaBootstrap((LambdaBootstrapHandler)null); }); } + [Fact] + public void ThrowsExceptionForNullHttpClient() + { + Assert.Throws<ArgumentNullException>("httpClient", () => { new LambdaBootstrap((HttpClient)null, _testFunction.BaseHandlerAsync); }); + Assert.Throws<ArgumentNullException>("httpClient", () => { new LambdaBootstrap((HttpClient)null, _testWrapper); }); + } + [Fact] public async Task NoInitializer() { diff --git a/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.UnitTests/TestHelpers/TestRuntimeApiClient.cs b/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.UnitTests/TestHelpers/TestRuntimeApiClient.cs index 66c645c19..ae1005544 100644 --- a/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.UnitTests/TestHelpers/TestRuntimeApiClient.cs +++ b/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.UnitTests/TestHelpers/TestRuntimeApiClient.cs @@ -16,6 +16,7 @@ using System.Collections.Generic; using System.IO; using System.Text; +using System.Threading; using System.Threading.Tasks; using Xunit; @@ -75,7 +76,7 @@ public void VerifyOutput(byte[] expectedOutput) } } - public Task<InvocationRequest> GetNextInvocationAsync() + public Task<InvocationRequest> GetNextInvocationAsync(CancellationToken cancellationToken = default) { GetNextInvocationAsyncCalled = true; @@ -94,31 +95,31 @@ public Task<InvocationRequest> GetNextInvocationAsync() }); } - public Task ReportInitializationErrorAsync(Exception exception) + public Task ReportInitializationErrorAsync(Exception exception, CancellationToken cancellationToken = default) { ReportInitializationErrorAsyncExceptionCalled = true; return Task.Run(() => { }); } - public Task ReportInitializationErrorAsync(string errorType) + public Task ReportInitializationErrorAsync(string errorType, CancellationToken cancellationToken = default) { ReportInitializationErrorAsyncTypeCalled = true; return Task.Run(() => { }); } - public Task ReportInvocationErrorAsync(string awsRequestId, Exception exception) + public Task ReportInvocationErrorAsync(string awsRequestId, Exception exception, CancellationToken cancellationToken = default) { ReportInvocationErrorAsyncExceptionCalled = true; return Task.Run(() => { }); } - public Task ReportInvocationErrorAsync(string awsRequestId, string errorType) + public Task ReportInvocationErrorAsync(string awsRequestId, string errorType, CancellationToken cancellationToken = default) { ReportInvocationErrorAsyncTypeCalled = true; return Task.Run(() => { }); } - public Task SendResponseAsync(string awsRequestId, Stream outputStream) + public Task SendResponseAsync(string awsRequestId, Stream outputStream, CancellationToken cancellationToken = default) { if (outputStream != null) {