diff --git a/.eslintrc.json b/.eslintrc.json index edd0417a..2682af37 100644 --- a/.eslintrc.json +++ b/.eslintrc.json @@ -43,6 +43,7 @@ "@typescript-eslint/restrict-template-expressions": "off", "@typescript-eslint/unbound-method": "off", "no-empty": "off", + "prefer-const": ["error", { "destructuring": "all" }], "prefer-rest-params": "off", "prefer-spread": "off" }, diff --git a/src/Disposable.ts b/src/Disposable.ts new file mode 100644 index 00000000..c40aee79 --- /dev/null +++ b/src/Disposable.ts @@ -0,0 +1,35 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the MIT License. + +/** + * Based off of VS Code + * https://github.com/microsoft/vscode/blob/a64e8e5673a44e5b9c2d493666bde684bd5a135c/src/vs/workbench/api/common/extHostTypes.ts#L32 + */ +export class Disposable { + static from(...inDisposables: { dispose(): any }[]): Disposable { + let disposables: ReadonlyArray<{ dispose(): any }> | undefined = inDisposables; + return new Disposable(function () { + if (disposables) { + for (const disposable of disposables) { + if (disposable && typeof disposable.dispose === 'function') { + disposable.dispose(); + } + } + disposables = undefined; + } + }); + } + + #callOnDispose?: () => any; + + constructor(callOnDispose: () => any) { + this.#callOnDispose = callOnDispose; + } + + dispose(): any { + if (this.#callOnDispose instanceof Function) { + this.#callOnDispose(); + this.#callOnDispose = undefined; + } + } +} diff --git a/src/Worker.ts b/src/Worker.ts index ebf77868..a9d0d030 100644 --- a/src/Worker.ts +++ b/src/Worker.ts @@ -4,6 +4,7 @@ import * as parseArgs from 'minimist'; import { FunctionLoader } from './FunctionLoader'; import { CreateGrpcEventStream } from './GrpcClient'; +import { setupCoreModule } from './setupCoreModule'; import { setupEventStream } from './setupEventStream'; import { ensureErrorType } from './utils/ensureErrorType'; import { InternalException } from './utils/InternalException'; @@ -42,6 +43,7 @@ export function startNodeWorker(args) { const channel = new WorkerChannel(eventStream, new FunctionLoader()); setupEventStream(workerId, channel); + setupCoreModule(channel); eventStream.write({ requestId: requestId, diff --git a/src/WorkerChannel.ts b/src/WorkerChannel.ts index bd78e6d2..5e0630f1 100644 --- a/src/WorkerChannel.ts +++ b/src/WorkerChannel.ts @@ -1,9 +1,10 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the MIT License. -import { Context } from '@azure/functions'; +import { HookCallback, HookContext } from '@azure/functions-core'; import { readJson } from 'fs-extra'; import { AzureFunctionsRpcMessages as rpc } from '../azure-functions-language-worker-protobuf/src/rpc'; +import { Disposable } from './Disposable'; import { IFunctionLoader } from './FunctionLoader'; import { IEventStream } from './GrpcClient'; import { ensureErrorType } from './utils/ensureErrorType'; @@ -11,9 +12,6 @@ import path = require('path'); import LogLevel = rpc.RpcLog.Level; import LogCategory = rpc.RpcLog.RpcLogCategory; -type InvocationRequestBefore = (context: Context, userFn: Function) => Function; -type InvocationRequestAfter = (context: Context) => void; - export interface PackageJson { type?: string; } @@ -22,15 +20,13 @@ export class WorkerChannel { public eventStream: IEventStream; public functionLoader: IFunctionLoader; public packageJson: PackageJson; - private _invocationRequestBefore: InvocationRequestBefore[]; - private _invocationRequestAfter: InvocationRequestAfter[]; + #preInvocationHooks: HookCallback[] = []; + #postInvocationHooks: HookCallback[] = []; constructor(eventStream: IEventStream, functionLoader: IFunctionLoader) { this.eventStream = eventStream; this.functionLoader = functionLoader; this.packageJson = {}; - this._invocationRequestBefore = []; - this._invocationRequestAfter = []; } /** @@ -44,32 +40,32 @@ export class WorkerChannel { }); } - /** - * Register a patching function to be run before User Function is executed. - * Hook should return a patched version of User Function. - */ - public registerBeforeInvocationRequest(beforeCb: InvocationRequestBefore): void { - this._invocationRequestBefore.push(beforeCb); - } - - /** - * Register a function to be run after User Function resolves. - */ - public registerAfterInvocationRequest(afterCb: InvocationRequestAfter): void { - this._invocationRequestAfter.push(afterCb); + public registerHook(hookName: string, callback: HookCallback): Disposable { + const hooks = this.#getHooks(hookName); + hooks.push(callback); + return new Disposable(() => { + const index = hooks.indexOf(callback); + if (index > -1) { + hooks.splice(index, 1); + } + }); } - public runInvocationRequestBefore(context: Context, userFunction: Function): Function { - let wrappedFunction = userFunction; - for (const before of this._invocationRequestBefore) { - wrappedFunction = before(context, wrappedFunction); + public async executeHooks(hookName: string, context: HookContext): Promise { + const callbacks = this.#getHooks(hookName); + for (const callback of callbacks) { + await callback(context); } - return wrappedFunction; } - public runInvocationRequestAfter(context: Context) { - for (const after of this._invocationRequestAfter) { - after(context); + #getHooks(hookName: string): HookCallback[] { + switch (hookName) { + case 'preInvocation': + return this.#preInvocationHooks; + case 'postInvocation': + return this.#postInvocationHooks; + default: + throw new RangeError(`Unrecognized hook "${hookName}"`); } } diff --git a/src/eventHandlers/invocationRequest.ts b/src/eventHandlers/invocationRequest.ts index 90c9766a..3f1d7b22 100644 --- a/src/eventHandlers/invocationRequest.ts +++ b/src/eventHandlers/invocationRequest.ts @@ -1,6 +1,7 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the MIT License. +import { HookData, PostInvocationContext, PreInvocationContext } from '@azure/functions-core'; import { format } from 'util'; import { AzureFunctionsRpcMessages as rpc } from '../../azure-functions-language-worker-protobuf/src/rpc'; import { CreateContextAndInputs } from '../Context'; @@ -67,7 +68,7 @@ export async function invocationRequest(channel: WorkerChannel, requestId: strin isDone = true; } - const { context, inputs, doneEmitter } = CreateContextAndInputs(info, msg, userLog); + let { context, inputs, doneEmitter } = CreateContextAndInputs(info, msg, userLog); try { const legacyDoneTask = new Promise((resolve, reject) => { doneEmitter.on('done', (err?: unknown, result?: any) => { @@ -80,8 +81,13 @@ export async function invocationRequest(channel: WorkerChannel, requestId: strin }); }); - let userFunction = channel.functionLoader.getFunc(nonNullProp(msg, 'functionId')); - userFunction = channel.runInvocationRequestBefore(context, userFunction); + const hookData: HookData = {}; + const userFunction = channel.functionLoader.getFunc(nonNullProp(msg, 'functionId')); + const preInvocContext: PreInvocationContext = { hookData, invocationContext: context, inputs }; + + await channel.executeHooks('preInvocation', preInvocContext); + inputs = preInvocContext.inputs; + let rawResult = userFunction(context, ...inputs); resultIsPromise = rawResult && typeof rawResult.then === 'function'; let resultTask: Promise; @@ -95,7 +101,24 @@ export async function invocationRequest(channel: WorkerChannel, requestId: strin resultTask = legacyDoneTask; } - const result = await resultTask; + const postInvocContext: PostInvocationContext = { + hookData, + invocationContext: context, + inputs, + result: null, + error: null, + }; + try { + postInvocContext.result = await resultTask; + } catch (err) { + postInvocContext.error = err; + } + await channel.executeHooks('postInvocation', postInvocContext); + + if (isError(postInvocContext.error)) { + throw postInvocContext.error; + } + const result = postInvocContext.result; // Allow HTTP response from context.res if HTTP response is not defined from the context.bindings object if (info.httpOutputName && context.res && context.bindings[info.httpOutputName] === undefined) { @@ -164,6 +187,4 @@ export async function invocationRequest(channel: WorkerChannel, requestId: strin requestId: requestId, invocationResponse: response, }); - - channel.runInvocationRequestAfter(context); } diff --git a/src/setupCoreModule.ts b/src/setupCoreModule.ts new file mode 100644 index 00000000..34bf511e --- /dev/null +++ b/src/setupCoreModule.ts @@ -0,0 +1,29 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the MIT License. + +import { HookCallback } from '@azure/functions-core'; +import { Disposable } from './Disposable'; +import { WorkerChannel } from './WorkerChannel'; +import Module = require('module'); + +/** + * Intercepts the default "require" method so that we can provide our own "built-in" module + * This module is essentially the publicly accessible API for our worker + * This module is available to users only at runtime, not as an installable npm package + */ +export function setupCoreModule(channel: WorkerChannel): void { + const coreApi = { + registerHook: (hookName: string, callback: HookCallback) => channel.registerHook(hookName, callback), + Disposable, + }; + + Module.prototype.require = new Proxy(Module.prototype.require, { + apply(target, thisArg, argArray) { + if (argArray[0] === '@azure/functions-core') { + return coreApi; + } else { + return Reflect.apply(target, thisArg, argArray); + } + }, + }); +} diff --git a/test/eventHandlers/beforeEventHandlerSuite.ts b/test/eventHandlers/beforeEventHandlerSuite.ts index 78b19fa2..c543c4ec 100644 --- a/test/eventHandlers/beforeEventHandlerSuite.ts +++ b/test/eventHandlers/beforeEventHandlerSuite.ts @@ -3,6 +3,7 @@ import * as sinon from 'sinon'; import { FunctionLoader } from '../../src/FunctionLoader'; +import { setupCoreModule } from '../../src/setupCoreModule'; import { setupEventStream } from '../../src/setupEventStream'; import { WorkerChannel } from '../../src/WorkerChannel'; import { TestEventStream } from './TestEventStream'; @@ -12,5 +13,6 @@ export function beforeEventHandlerSuite() { const loader = sinon.createStubInstance(FunctionLoader); const channel = new WorkerChannel(stream, loader); setupEventStream('workerId', channel); + setupCoreModule(channel); return { stream, loader, channel }; } diff --git a/test/eventHandlers/invocationRequest.test.ts b/test/eventHandlers/invocationRequest.test.ts index b1f74fc4..8a67f713 100644 --- a/test/eventHandlers/invocationRequest.test.ts +++ b/test/eventHandlers/invocationRequest.test.ts @@ -4,13 +4,13 @@ /* eslint-disable deprecation/deprecation */ import { AzureFunction, Context } from '@azure/functions'; +import * as coreTypes from '@azure/functions-core'; import { expect } from 'chai'; import 'mocha'; import * as sinon from 'sinon'; import { AzureFunctionsRpcMessages as rpc } from '../../azure-functions-language-worker-protobuf/src/rpc'; import { FunctionInfo } from '../../src/FunctionInfo'; import { FunctionLoader } from '../../src/FunctionLoader'; -import { WorkerChannel } from '../../src/WorkerChannel'; import { beforeEventHandlerSuite } from './beforeEventHandlerSuite'; import { TestEventStream } from './TestEventStream'; import LogCategory = rpc.RpcLog.RpcLogCategory; @@ -59,7 +59,7 @@ namespace Binding { }; export const queue = { bindings: { - test: { + testOutput: { type: 'queue', direction: 1, dataType: 1, @@ -79,6 +79,8 @@ function addSuffix(asyncFunc: AzureFunction, callbackFunc: AzureFunction): [Azur ]; } +let hookData: string; + namespace TestFunc { const basicAsync = async (context: Context) => { context.log('testUserLog'); @@ -113,6 +115,27 @@ namespace TestFunc { }; export const resHttp = addSuffix(resHttpAsync, resHttpCallback); + const logHookDataAsync = async (context: Context) => { + hookData += 'invoc'; + context.log(hookData); + return 'hello'; + }; + const logHookDataCallback = (context: Context) => { + hookData += 'invoc'; + context.log(hookData); + context.done(null, 'hello'); + }; + export const logHookData = addSuffix(logHookDataAsync, logHookDataCallback); + + const logInputAsync = async (context: Context, input: any) => { + context.log(input); + }; + const logInputCallback = (context: Context, input: any) => { + context.log(input); + context.done(); + }; + export const logInput = addSuffix(logInputAsync, logInputCallback); + const multipleBindingsAsync = async (context: Context) => { context.bindings.queueOutput = 'queue message'; context.bindings.overriddenQueueOutput = 'start message'; @@ -194,15 +217,17 @@ namespace Msg { logCategory: LogCategory.System, }, }; - export const userTestLog: rpc.IStreamingMessage = { - rpcLog: { - category: 'testFuncName.Invocation', - invocationId: '1', - message: 'testUserLog', - level: LogLevel.Information, - logCategory: LogCategory.User, - }, - }; + export function userTestLog(data = 'testUserLog'): rpc.IStreamingMessage { + return { + rpcLog: { + category: 'testFuncName.Invocation', + invocationId: '1', + message: data, + level: LogLevel.Information, + logCategory: LogCategory.User, + }, + }; + } export const invocResFailed: rpc.IStreamingMessage = { requestId: 'testReqId', invocationResponse: { @@ -248,17 +273,50 @@ namespace Msg { } } +namespace InputData { + export const http = { + name: 'req', + data: { + data: 'http', + http: { + body: { + string: 'blahh', + }, + rawBody: { + string: 'blahh', + }, + }, + }, + }; + + export const string = { + name: 'testInput', + data: { + data: 'string', + string: 'testStringData', + }, + }; +} + describe('invocationRequest', () => { - let channel: WorkerChannel; let stream: TestEventStream; let loader: sinon.SinonStubbedInstance; + let coreApi: typeof coreTypes; + let testDisposables: coreTypes.Disposable[] = []; + + before(async () => { + ({ stream, loader } = beforeEventHandlerSuite()); + coreApi = await import('@azure/functions-core'); + }); - before(() => { - ({ stream, loader, channel } = beforeEventHandlerSuite()); + beforeEach(async () => { + hookData = ''; }); afterEach(async () => { await stream.afterEachEventHandlerTest(); + coreApi.Disposable.from(...testDisposables).dispose(); + testDisposables = []; }); function sendInvokeMessage(inputData?: rpc.IParameterBinding[] | null): void { @@ -272,21 +330,6 @@ describe('invocationRequest', () => { }); } - const httpInputData = { - name: 'req', - data: { - data: 'http', - http: { - body: { - string: 'blahh', - }, - rawBody: { - string: 'blahh', - }, - }, - }, - }; - function getHttpResponse(rawBody?: string | {} | undefined, name = 'res'): rpc.IParameterBinding { let body: rpc.ITypedData; if (typeof rawBody === 'string') { @@ -314,10 +357,10 @@ describe('invocationRequest', () => { it('invokes function' + suffix, async () => { loader.getFunc.returns(func); loader.getInfo.returns(new FunctionInfo(Binding.httpRes)); - sendInvokeMessage([httpInputData]); + sendInvokeMessage([InputData.http]); await stream.assertCalledWith( Msg.receivedInvocLog(), - Msg.userTestLog, + Msg.userTestLog(), Msg.invocResponse([getHttpResponse()]) ); }); @@ -327,7 +370,7 @@ describe('invocationRequest', () => { it('returns correct data with $return binding' + suffix, async () => { loader.getFunc.returns(func); loader.getInfo.returns(new FunctionInfo(Binding.httpReturn)); - sendInvokeMessage([httpInputData]); + sendInvokeMessage([InputData.http]); const expectedOutput = getHttpResponse(undefined, '$return'); const expectedReturnValue = { http: { @@ -369,7 +412,7 @@ describe('invocationRequest', () => { it('serializes output binding data through context.done' + suffix, async () => { loader.getFunc.returns(func); loader.getInfo.returns(new FunctionInfo(Binding.httpRes)); - sendInvokeMessage([httpInputData]); + sendInvokeMessage([InputData.http]); const expectedOutput = [getHttpResponse({ hello: 'world' })]; await stream.assertCalledWith(Msg.receivedInvocLog(), Msg.invocResponse(expectedOutput)); }); @@ -389,7 +432,7 @@ describe('invocationRequest', () => { name: 'testFuncName', }) ); - sendInvokeMessage([httpInputData]); + sendInvokeMessage([InputData.http]); const expectedOutput = [ getHttpResponse({ hello: 'world' }), { @@ -413,7 +456,7 @@ describe('invocationRequest', () => { it('returns failed status for user error' + suffix, async () => { loader.getFunc.returns(func); loader.getInfo.returns(new FunctionInfo(Binding.queue)); - sendInvokeMessage([httpInputData]); + sendInvokeMessage([InputData.http]); await stream.assertCalledWith(Msg.receivedInvocLog(), Msg.invocResFailed); }); } @@ -429,7 +472,7 @@ describe('invocationRequest', () => { it('empty function does not return invocation response', async () => { loader.getFunc.returns(() => {}); loader.getInfo.returns(new FunctionInfo(Binding.httpRes)); - sendInvokeMessage([httpInputData]); + sendInvokeMessage([InputData.http]); await stream.assertCalledWith(Msg.receivedInvocLog()); }); @@ -438,7 +481,7 @@ describe('invocationRequest', () => { context.done(); }); loader.getInfo.returns(new FunctionInfo(Binding.httpRes)); - sendInvokeMessage([httpInputData]); + sendInvokeMessage([InputData.http]); await stream.assertCalledWith( Msg.receivedInvocLog(), Msg.asyncAndDoneLog, @@ -452,7 +495,7 @@ describe('invocationRequest', () => { context.done(); }); loader.getInfo.returns(new FunctionInfo(Binding.httpRes)); - sendInvokeMessage([httpInputData]); + sendInvokeMessage([InputData.http]); await stream.assertCalledWith( Msg.receivedInvocLog(), Msg.duplicateDoneLog, @@ -466,11 +509,11 @@ describe('invocationRequest', () => { context.log('testUserLog'); }); loader.getInfo.returns(new FunctionInfo(Binding.httpRes)); - sendInvokeMessage([httpInputData]); + sendInvokeMessage([InputData.http]); await stream.assertCalledWith( Msg.receivedInvocLog(), Msg.unexpectedLogAfterDoneLog, - Msg.userTestLog, + Msg.userTestLog(), Msg.invocResponse([getHttpResponse()]) ); }); @@ -482,109 +525,193 @@ describe('invocationRequest', () => { return 'hello'; }); loader.getInfo.returns(new FunctionInfo(Binding.httpRes)); - sendInvokeMessage([httpInputData]); + sendInvokeMessage([InputData.http]); // wait for first two messages to ensure invocation happens await stream.assertCalledWith(Msg.receivedInvocLog(), Msg.invocResponse([getHttpResponse()])); // then add extra context.log _context!.log('testUserLog'); - await stream.assertCalledWith(Msg.unexpectedLogAfterDoneLog, Msg.userTestLog); + await stream.assertCalledWith(Msg.unexpectedLogAfterDoneLog, Msg.userTestLog()); }); - describe('#invocationRequestBefore, #invocationRequestAfter', () => { - afterEach(() => { - channel['_invocationRequestAfter'] = []; - channel['_invocationRequestBefore'] = []; - }); + for (const [func, suffix] of TestFunc.logHookData) { + it('preInvocationHook' + suffix, async () => { + loader.getFunc.returns(func); + loader.getInfo.returns(new FunctionInfo(Binding.queue)); - it('should apply hook before user function is executed', async () => { - channel.registerBeforeInvocationRequest((context, userFunction) => { - context['magic_flag'] = 'magic value'; - return userFunction.bind({ __wrapped: true }); - }); + testDisposables.push( + coreApi.registerHook('preInvocation', () => { + hookData += 'pre'; + }) + ); - channel.registerBeforeInvocationRequest((context, userFunction) => { - context['secondary_flag'] = 'magic value'; - return userFunction; - }); + sendInvokeMessage([InputData.http]); + await stream.assertCalledWith( + Msg.receivedInvocLog(), + Msg.userTestLog('preinvoc'), + Msg.invocResponse([], { string: 'hello' }) + ); + expect(hookData).to.equal('preinvoc'); + }); + } - loader.getFunc.returns(function (this: any, context) { - expect(context['magic_flag']).to.equal('magic value'); - expect(context['secondary_flag']).to.equal('magic value'); - expect(this.__wrapped).to.equal(true); - expect(channel['_invocationRequestBefore'].length).to.equal(2); - expect(channel['_invocationRequestAfter'].length).to.equal(0); - context.done(); - }); + for (const [func, suffix] of TestFunc.logInput) { + it('preInvocationHook respects change to inputs' + suffix, async () => { + loader.getFunc.returns(func); loader.getInfo.returns(new FunctionInfo(Binding.queue)); - sendInvokeMessage([httpInputData]); - await stream.assertCalledWith(Msg.receivedInvocLog(), Msg.invocResponse([])); + testDisposables.push( + coreApi.registerHook('preInvocation', (context: coreTypes.PreInvocationContext) => { + expect(context.inputs.length).to.equal(1); + expect(context.inputs[0]).to.equal('testStringData'); + context.inputs = ['changedStringData']; + }) + ); + + sendInvokeMessage([InputData.string]); + await stream.assertCalledWith( + Msg.receivedInvocLog(), + Msg.userTestLog('changedStringData'), + Msg.invocResponse([]) + ); }); + } - it('should apply hook after user function is executed (callback)', async () => { - let finished = false; - let count = 0; - channel.registerAfterInvocationRequest((_context) => { - expect(finished).to.equal(true); - count += 1; - }); + for (const [func, suffix] of TestFunc.logHookData) { + it('postInvocationHook' + suffix, async () => { + loader.getFunc.returns(func); + loader.getInfo.returns(new FunctionInfo(Binding.queue)); - loader.getFunc.returns((context: Context) => { - finished = true; - expect(channel['_invocationRequestBefore'].length).to.equal(0); - expect(channel['_invocationRequestAfter'].length).to.equal(1); - expect(count).to.equal(0); - context.done(); - }); + testDisposables.push( + coreApi.registerHook('postInvocation', (context: coreTypes.PostInvocationContext) => { + hookData += 'post'; + expect(context.result).to.equal('hello'); + expect(context.error).to.be.null; + }) + ); + + sendInvokeMessage([InputData.http]); + await stream.assertCalledWith( + Msg.receivedInvocLog(), + Msg.userTestLog('invoc'), + Msg.invocResponse([], { string: 'hello' }) + ); + expect(hookData).to.equal('invocpost'); + }); + } + + for (const [func, suffix] of TestFunc.logHookData) { + it('postInvocationHook respects change to context.result' + suffix, async () => { + loader.getFunc.returns(func); loader.getInfo.returns(new FunctionInfo(Binding.queue)); - sendInvokeMessage([httpInputData]); - await stream.assertCalledWith(Msg.receivedInvocLog(), Msg.invocResponse([])); - expect(count).to.equal(1); + testDisposables.push( + coreApi.registerHook('postInvocation', (context: coreTypes.PostInvocationContext) => { + hookData += 'post'; + expect(context.result).to.equal('hello'); + expect(context.error).to.be.null; + context.result = 'world'; + }) + ); + + sendInvokeMessage([InputData.http]); + await stream.assertCalledWith( + Msg.receivedInvocLog(), + Msg.userTestLog('invoc'), + Msg.invocResponse([], { string: 'world' }) + ); + expect(hookData).to.equal('invocpost'); }); + } - it('should apply hook after user function resolves (promise)', async () => { - let finished = false; - let count = 0; - channel.registerAfterInvocationRequest((_context) => { - expect(finished).to.equal(true); - count += 1; - }); + for (const [func, suffix] of TestFunc.error) { + it('postInvocationHook executes if function throws error' + suffix, async () => { + loader.getFunc.returns(func); + loader.getInfo.returns(new FunctionInfo(Binding.queue)); - loader.getFunc.returns(async () => { - finished = true; - expect(channel['_invocationRequestBefore'].length).to.equal(0); - expect(channel['_invocationRequestAfter'].length).to.equal(1); - expect(count).to.equal(0); - }); + testDisposables.push( + coreApi.registerHook('postInvocation', (context: coreTypes.PostInvocationContext) => { + hookData += 'post'; + expect(context.result).to.be.null; + expect(context.error).to.equal(testError); + }) + ); + + sendInvokeMessage([InputData.http]); + await stream.assertCalledWith(Msg.receivedInvocLog(), Msg.invocResFailed); + expect(hookData).to.equal('post'); + }); + } + + for (const [func, suffix] of TestFunc.error) { + it('postInvocationHook respects change to context.error' + suffix, async () => { + loader.getFunc.returns(func); loader.getInfo.returns(new FunctionInfo(Binding.queue)); - sendInvokeMessage([httpInputData]); - await stream.assertCalledWith(Msg.receivedInvocLog(), Msg.invocResponse([])); - expect(count).to.equal(1); + testDisposables.push( + coreApi.registerHook('postInvocation', (context: coreTypes.PostInvocationContext) => { + hookData += 'post'; + expect(context.result).to.be.null; + expect(context.error).to.equal(testError); + context.error = null; + context.result = 'hello'; + }) + ); + + sendInvokeMessage([InputData.http]); + await stream.assertCalledWith(Msg.receivedInvocLog(), Msg.invocResponse([], { string: 'hello' })); + expect(hookData).to.equal('post'); }); + } - it('should apply hook after user function rejects (promise)', async () => { - let finished = false; - let count = 0; - channel.registerAfterInvocationRequest((_context) => { - expect(finished).to.equal(true); - count += 1; - }); + it('pre and post invocation hooks share data', async () => { + loader.getFunc.returns(async () => {}); + loader.getInfo.returns(new FunctionInfo(Binding.queue)); - loader.getFunc.returns(async () => { - finished = true; - expect(channel['_invocationRequestBefore'].length).to.equal(0); - expect(channel['_invocationRequestAfter'].length).to.equal(1); - expect(count).to.equal(0); - throw testError; - }); - loader.getInfo.returns(new FunctionInfo(Binding.queue)); + testDisposables.push( + coreApi.registerHook('preInvocation', (context: coreTypes.PreInvocationContext) => { + context.hookData['hello'] = 'world'; + hookData += 'pre'; + }) + ); - sendInvokeMessage([httpInputData]); - await stream.assertCalledWith(Msg.receivedInvocLog(), Msg.invocResFailed); - expect(count).to.equal(1); + testDisposables.push( + coreApi.registerHook('postInvocation', (context: coreTypes.PostInvocationContext) => { + expect(context.hookData['hello']).to.equal('world'); + hookData += 'post'; + }) + ); + + sendInvokeMessage([InputData.http]); + await stream.assertCalledWith(Msg.receivedInvocLog(), Msg.invocResponse([])); + expect(hookData).to.equal('prepost'); + }); + + it('dispose hooks', async () => { + loader.getFunc.returns(async () => {}); + loader.getInfo.returns(new FunctionInfo(Binding.queue)); + + const disposableA: coreTypes.Disposable = coreApi.registerHook('preInvocation', () => { + hookData += 'a'; }); + testDisposables.push(disposableA); + const disposableB: coreTypes.Disposable = coreApi.registerHook('preInvocation', () => { + hookData += 'b'; + }); + testDisposables.push(disposableB); + + sendInvokeMessage([InputData.http]); + await stream.assertCalledWith(Msg.receivedInvocLog(), Msg.invocResponse([])); + expect(hookData).to.equal('ab'); + + disposableA.dispose(); + sendInvokeMessage([InputData.http]); + await stream.assertCalledWith(Msg.receivedInvocLog(), Msg.invocResponse([])); + expect(hookData).to.equal('abb'); + + disposableB.dispose(); + sendInvokeMessage([InputData.http]); + await stream.assertCalledWith(Msg.receivedInvocLog(), Msg.invocResponse([])); + expect(hookData).to.equal('abb'); }); for (const [func, suffix] of TestFunc.returnEmptyString) { diff --git a/tsconfig.json b/tsconfig.json index 10fb7691..44222174 100644 --- a/tsconfig.json +++ b/tsconfig.json @@ -11,6 +11,9 @@ "paths": { "@azure/functions": [ "types" + ], + "@azure/functions-core": [ + "types-core" ] } } diff --git a/types-core/index.d.ts b/types-core/index.d.ts new file mode 100644 index 00000000..05be85ac --- /dev/null +++ b/types-core/index.d.ts @@ -0,0 +1,101 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the MIT License. + +import { Context } from '@azure/functions'; + +/** + * This module is shipped as a built-in part of the Azure Functions Node.js worker and is available at runtime + */ +declare module '@azure/functions-core' { + /** + * Register a hook to interact with the lifecycle of Azure Functions. + * Hooks are executed in the order they were registered and will block execution if they throw an error + */ + export function registerHook(hookName: 'preInvocation', callback: PreInvocationCallback): Disposable; + export function registerHook(hookName: 'postInvocation', callback: PostInvocationCallback): Disposable; + export function registerHook(hookName: string, callback: HookCallback): Disposable; + + export type HookCallback = (context: HookContext) => void | Promise; + export type PreInvocationCallback = (context: PreInvocationContext) => void | Promise; + export type PostInvocationCallback = (context: PostInvocationContext) => void | Promise; + + export type HookData = { [key: string]: any }; + + /** + * Base interface for all hook context objects + */ + export interface HookContext { + /** + * The recommended place to share data between hooks + */ + hookData: HookData; + } + + /** + * Context on a function that is about to be executed + * This object will be passed to all pre invocation hooks + */ + export interface PreInvocationContext extends HookContext { + /** + * The context object passed to the function + */ + invocationContext: Context; + + /** + * The input values for this specific invocation. Changes to this array _will_ affect the inputs passed to your function + */ + inputs: any[]; + } + + /** + * Context on a function that has just executed + * This object will be passed to all post invocation hooks + */ + export interface PostInvocationContext extends HookContext { + /** + * The context object passed to the function + */ + invocationContext: Context; + + /** + * The input values for this specific invocation + */ + inputs: any[]; + + /** + * The result of the function, or null if there is no result. Changes to this value _will_ affect the overall result of the function + */ + result: any; + + /** + * The error for the function, or null if there is no error. Changes to this value _will_ affect the overall result of the function + */ + error: any; + } + + /** + * Represents a type which can release resources, such as event listening or a timer. + */ + export class Disposable { + /** + * Combine many disposable-likes into one. You can use this method when having objects with a dispose function which aren't instances of `Disposable`. + * + * @param disposableLikes Objects that have at least a `dispose`-function member. Note that asynchronous dispose-functions aren't awaited. + * @return Returns a new disposable which, upon dispose, will dispose all provided disposables. + */ + static from(...disposableLikes: { dispose: () => any }[]): Disposable; + + /** + * Creates a new disposable that calls the provided function on dispose. + * *Note* that an asynchronous function is not awaited. + * + * @param callOnDispose Function that disposes something. + */ + constructor(callOnDispose: () => any); + + /** + * Dispose this object. + */ + dispose(): any; + } +}