Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add pre and post invocation hooks #548

Merged
merged 13 commits into from
Mar 24, 2022
1 change: 1 addition & 0 deletions .eslintrc.json
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
"@typescript-eslint/restrict-template-expressions": "off",
"@typescript-eslint/unbound-method": "off",
"no-empty": "off",
"prefer-const": ["error", { "destructuring": "all" }],
"prefer-rest-params": "off",
"prefer-spread": "off"
},
Expand Down
35 changes: 35 additions & 0 deletions src/Disposable.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the MIT License.

/**
* Based off of VS Code
* https://github.com/microsoft/vscode/blob/a64e8e5673a44e5b9c2d493666bde684bd5a135c/src/vs/workbench/api/common/extHostTypes.ts#L32
*/
export class Disposable {
static from(...inDisposables: { dispose(): any }[]): Disposable {
let disposables: ReadonlyArray<{ dispose(): any }> | undefined = inDisposables;
return new Disposable(function () {
if (disposables) {
for (const disposable of disposables) {
if (disposable && typeof disposable.dispose === 'function') {
disposable.dispose();
}
}
disposables = undefined;
}
});
}

#callOnDispose?: () => any;
ejizba marked this conversation as resolved.
Show resolved Hide resolved

constructor(callOnDispose: () => any) {
this.#callOnDispose = callOnDispose;
}

dispose(): any {
if (this.#callOnDispose instanceof Function) {
this.#callOnDispose();
this.#callOnDispose = undefined;
}
}
}
2 changes: 2 additions & 0 deletions src/Worker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import * as parseArgs from 'minimist';
import { FunctionLoader } from './FunctionLoader';
import { CreateGrpcEventStream } from './GrpcClient';
import { setupCoreModule } from './setupCoreModule';
import { setupEventStream } from './setupEventStream';
import { ensureErrorType } from './utils/ensureErrorType';
import { InternalException } from './utils/InternalException';
Expand Down Expand Up @@ -42,6 +43,7 @@ export function startNodeWorker(args) {

const channel = new WorkerChannel(eventStream, new FunctionLoader());
setupEventStream(workerId, channel);
setupCoreModule(channel);

eventStream.write({
requestId: requestId,
Expand Down
54 changes: 25 additions & 29 deletions src/WorkerChannel.ts
Original file line number Diff line number Diff line change
@@ -1,19 +1,17 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the MIT License.

import { Context } from '@azure/functions';
import { HookCallback, HookContext } from '@azure/functions-core';
import { readJson } from 'fs-extra';
import { AzureFunctionsRpcMessages as rpc } from '../azure-functions-language-worker-protobuf/src/rpc';
import { Disposable } from './Disposable';
import { IFunctionLoader } from './FunctionLoader';
import { IEventStream } from './GrpcClient';
import { ensureErrorType } from './utils/ensureErrorType';
import path = require('path');
import LogLevel = rpc.RpcLog.Level;
import LogCategory = rpc.RpcLog.RpcLogCategory;

type InvocationRequestBefore = (context: Context, userFn: Function) => Function;
hossam-nasr marked this conversation as resolved.
Show resolved Hide resolved
type InvocationRequestAfter = (context: Context) => void;

export interface PackageJson {
type?: string;
}
Expand All @@ -22,15 +20,13 @@ export class WorkerChannel {
public eventStream: IEventStream;
public functionLoader: IFunctionLoader;
public packageJson: PackageJson;
private _invocationRequestBefore: InvocationRequestBefore[];
private _invocationRequestAfter: InvocationRequestAfter[];
#preInvocationHooks: HookCallback[] = [];
#postInvocationHooks: HookCallback[] = [];

constructor(eventStream: IEventStream, functionLoader: IFunctionLoader) {
this.eventStream = eventStream;
this.functionLoader = functionLoader;
this.packageJson = {};
this._invocationRequestBefore = [];
this._invocationRequestAfter = [];
}

/**
Expand All @@ -44,32 +40,32 @@ export class WorkerChannel {
});
}

/**
* Register a patching function to be run before User Function is executed.
* Hook should return a patched version of User Function.
*/
public registerBeforeInvocationRequest(beforeCb: InvocationRequestBefore): void {
this._invocationRequestBefore.push(beforeCb);
}

/**
* Register a function to be run after User Function resolves.
*/
public registerAfterInvocationRequest(afterCb: InvocationRequestAfter): void {
this._invocationRequestAfter.push(afterCb);
public registerHook(hookName: string, callback: HookCallback): Disposable {
const hooks = this.#getHooks(hookName);
hooks.push(callback);
return new Disposable(() => {
const index = hooks.indexOf(callback);
if (index > -1) {
hooks.splice(index, 1);
}
});
}

public runInvocationRequestBefore(context: Context, userFunction: Function): Function {
let wrappedFunction = userFunction;
for (const before of this._invocationRequestBefore) {
wrappedFunction = before(context, wrappedFunction);
public async executeHooks(hookName: string, context: HookContext): Promise<void> {
const callbacks = this.#getHooks(hookName);
for (const callback of callbacks) {
await callback(context);
}
return wrappedFunction;
}

public runInvocationRequestAfter(context: Context) {
for (const after of this._invocationRequestAfter) {
after(context);
#getHooks(hookName: string): HookCallback[] {
switch (hookName) {
case 'preInvocation':
return this.#preInvocationHooks;
case 'postInvocation':
ejizba marked this conversation as resolved.
Show resolved Hide resolved
return this.#postInvocationHooks;
default:
hossam-nasr marked this conversation as resolved.
Show resolved Hide resolved
throw new RangeError(`Unrecognized hook "${hookName}"`);
}
}

Expand Down
33 changes: 27 additions & 6 deletions src/eventHandlers/invocationRequest.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the MIT License.

import { HookData, PostInvocationContext, PreInvocationContext } from '@azure/functions-core';
import { format } from 'util';
import { AzureFunctionsRpcMessages as rpc } from '../../azure-functions-language-worker-protobuf/src/rpc';
import { CreateContextAndInputs } from '../Context';
Expand Down Expand Up @@ -67,7 +68,7 @@ export async function invocationRequest(channel: WorkerChannel, requestId: strin
isDone = true;
}

const { context, inputs, doneEmitter } = CreateContextAndInputs(info, msg, userLog);
let { context, inputs, doneEmitter } = CreateContextAndInputs(info, msg, userLog);
try {
const legacyDoneTask = new Promise((resolve, reject) => {
doneEmitter.on('done', (err?: unknown, result?: any) => {
Expand All @@ -80,8 +81,13 @@ export async function invocationRequest(channel: WorkerChannel, requestId: strin
});
});

let userFunction = channel.functionLoader.getFunc(nonNullProp(msg, 'functionId'));
userFunction = channel.runInvocationRequestBefore(context, userFunction);
const hookData: HookData = {};
const userFunction = channel.functionLoader.getFunc(nonNullProp(msg, 'functionId'));
const preInvocContext: PreInvocationContext = { hookData, invocationContext: context, inputs };

await channel.executeHooks('preInvocation', preInvocContext);
inputs = preInvocContext.inputs;

let rawResult = userFunction(context, ...inputs);
resultIsPromise = rawResult && typeof rawResult.then === 'function';
let resultTask: Promise<any>;
Expand All @@ -95,7 +101,24 @@ export async function invocationRequest(channel: WorkerChannel, requestId: strin
resultTask = legacyDoneTask;
}

const result = await resultTask;
const postInvocContext: PostInvocationContext = {
hookData,
invocationContext: context,
inputs,
result: null,
error: null,
};
try {
postInvocContext.result = await resultTask;
} catch (err) {
postInvocContext.error = err;
}
await channel.executeHooks('postInvocation', postInvocContext);

if (isError(postInvocContext.error)) {
throw postInvocContext.error;
}
const result = postInvocContext.result;

// Allow HTTP response from context.res if HTTP response is not defined from the context.bindings object
if (info.httpOutputName && context.res && context.bindings[info.httpOutputName] === undefined) {
Expand Down Expand Up @@ -164,6 +187,4 @@ export async function invocationRequest(channel: WorkerChannel, requestId: strin
requestId: requestId,
invocationResponse: response,
});

channel.runInvocationRequestAfter(context);
}
29 changes: 29 additions & 0 deletions src/setupCoreModule.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the MIT License.

import { HookCallback } from '@azure/functions-core';
import { Disposable } from './Disposable';
import { WorkerChannel } from './WorkerChannel';
import Module = require('module');

/**
* Intercepts the default "require" method so that we can provide our own "built-in" module
* This module is essentially the publicly accessible API for our worker
* This module is available to users only at runtime, not as an installable npm package
*/
export function setupCoreModule(channel: WorkerChannel): void {
const coreApi = {
registerHook: (hookName: string, callback: HookCallback) => channel.registerHook(hookName, callback),
Disposable,
};

Module.prototype.require = new Proxy(Module.prototype.require, {
apply(target, thisArg, argArray) {
if (argArray[0] === '@azure/functions-core') {
return coreApi;
} else {
return Reflect.apply(target, thisArg, argArray);
}
},
});
}
2 changes: 2 additions & 0 deletions test/eventHandlers/beforeEventHandlerSuite.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import * as sinon from 'sinon';
import { FunctionLoader } from '../../src/FunctionLoader';
import { setupCoreModule } from '../../src/setupCoreModule';
import { setupEventStream } from '../../src/setupEventStream';
import { WorkerChannel } from '../../src/WorkerChannel';
import { TestEventStream } from './TestEventStream';
Expand All @@ -12,5 +13,6 @@ export function beforeEventHandlerSuite() {
const loader = sinon.createStubInstance<FunctionLoader>(FunctionLoader);
const channel = new WorkerChannel(stream, loader);
setupEventStream('workerId', channel);
setupCoreModule(channel);
return { stream, loader, channel };
}
Loading