Skip to content

Commit

Permalink
Add RateLimitController (#698)
Browse files Browse the repository at this point in the history
* Initial implementation of NotificationControllerV2

* Simplify implementation

* Fix small type issue

* Simplify and start using messaging system

* Adapt tests for controller messaging

* Simplify tests a bit

* Add action handler

* Fix types

* Fix tests

* Add some documentation and exports

* Pivot to RateLimitController

* Fix export

* Remove subject metadata type

* Make controller generic and allow caller to pass implementation mapping

* Fix PR comments

* Fix type name casing

* Fix controller docstrings

The docstrings contained references to "notifications" from this controller's infancy.

* Return potential result from implementation, Wrap in promise in case implementation is async

* Improve typing and API

* Fix type issue

* Fix PR comments

* Cast action handler

* Change use of useFakeTimers

* Throw error when API is rate-limited

* Small fixes

Co-authored-by: Erik Marks <25517051+rekmarks@users.noreply.github.com>
  • Loading branch information
FrederikBolding and rekmarks authored Mar 4, 2022
1 parent 613c5b9 commit 9c5b75a
Show file tree
Hide file tree
Showing 3 changed files with 344 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,5 @@ export * from './assets/TokenDetectionController';
export * from './assets/CollectibleDetectionController';
export * from './permissions';
export * from './subject-metadata';
export * from './ratelimit/RateLimitController';
export { util };
150 changes: 150 additions & 0 deletions src/ratelimit/RateLimitController.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
import { ControllerMessenger } from '../ControllerMessenger';
import {
ControllerActions,
RateLimitStateChange,
RateLimitController,
RateLimitMessenger,
GetRateLimitState,
CallApi,
} from './RateLimitController';

const name = 'RateLimitController';

const implementations = {
showNativeNotification: jest.fn(),
};

type RateLimitedApis = typeof implementations;

/**
* Constructs a unrestricted controller messenger.
*
* @returns A unrestricted controller messenger.
*/
function getUnrestrictedMessenger() {
return new ControllerMessenger<
GetRateLimitState<RateLimitedApis> | CallApi<RateLimitedApis>,
RateLimitStateChange<RateLimitedApis>
>();
}

/**
* Constructs a restricted controller messenger.
*
* @param controllerMessenger - An optional unrestricted messenger
* @returns A restricted controller messenger.
*/
function getRestrictedMessenger(
controllerMessenger = getUnrestrictedMessenger(),
) {
return controllerMessenger.getRestricted<
typeof name,
ControllerActions<RateLimitedApis>['type'],
never
>({
name,
allowedActions: ['RateLimitController:call'],
}) as RateLimitMessenger<RateLimitedApis>;
}

const origin = 'snap_test';
const message = 'foo';

describe('RateLimitController', () => {
beforeEach(() => {
jest.useFakeTimers();
});

afterEach(() => {
implementations.showNativeNotification.mockClear();
jest.useRealTimers();
});

it('action: RateLimitController:call', async () => {
const unrestricted = getUnrestrictedMessenger();
const messenger = getRestrictedMessenger(unrestricted);

// Registers action handlers
new RateLimitController({
implementations,
messenger,
});

expect(
await unrestricted.call(
'RateLimitController:call',
origin,
'showNativeNotification',
origin,
message,
),
).toBeUndefined();

expect(implementations.showNativeNotification).toHaveBeenCalledWith(
origin,
message,
);
});

it('uses showNativeNotification to show a notification', async () => {
const messenger = getRestrictedMessenger();

const controller = new RateLimitController({
implementations,
messenger,
});
expect(
await controller.call(origin, 'showNativeNotification', origin, message),
).toBeUndefined();

expect(implementations.showNativeNotification).toHaveBeenCalledWith(
origin,
message,
);
});

it('returns false if rate-limited', async () => {
const messenger = getRestrictedMessenger();
const controller = new RateLimitController({
implementations,
messenger,
rateLimitCount: 1,
});

expect(
await controller.call(origin, 'showNativeNotification', origin, message),
).toBeUndefined();

await expect(
controller.call(origin, 'showNativeNotification', origin, message),
).rejects.toThrow(
`"showNativeNotification" is currently rate-limited. Please try again later`,
);
expect(implementations.showNativeNotification).toHaveBeenCalledTimes(1);
expect(implementations.showNativeNotification).toHaveBeenCalledWith(
origin,
message,
);
});

it('rate limit is reset after timeout', async () => {
const messenger = getRestrictedMessenger();
const controller = new RateLimitController({
implementations,
messenger,
rateLimitCount: 1,
});
expect(
await controller.call(origin, 'showNativeNotification', origin, message),
).toBeUndefined();
jest.runAllTimers();
expect(
await controller.call(origin, 'showNativeNotification', origin, message),
).toBeUndefined();
expect(implementations.showNativeNotification).toHaveBeenCalledTimes(2);
expect(implementations.showNativeNotification).toHaveBeenCalledWith(
origin,
message,
);
});
});
193 changes: 193 additions & 0 deletions src/ratelimit/RateLimitController.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
import { ethErrors } from 'eth-rpc-errors';
import type { Patch } from 'immer';

import { BaseController } from '../BaseControllerV2';

import type { RestrictedControllerMessenger } from '../ControllerMessenger';

/**
* @type RateLimitState
* @property requests - Object containing number of requests in a given interval for each origin and api type combination
*/
export type RateLimitState<
RateLimitedApis extends Record<string, (...args: any[]) => any>
> = {
requests: Record<keyof RateLimitedApis, Record<string, number>>;
};

const name = 'RateLimitController';

export type RateLimitStateChange<
RateLimitedApis extends Record<string, (...args: any[]) => any>
> = {
type: `${typeof name}:stateChange`;
payload: [RateLimitState<RateLimitedApis>, Patch[]];
};

export type GetRateLimitState<
RateLimitedApis extends Record<string, (...args: any[]) => any>
> = {
type: `${typeof name}:getState`;
handler: () => RateLimitState<RateLimitedApis>;
};

export type CallApi<
RateLimitedApis extends Record<string, (...args: any[]) => any>
> = {
type: `${typeof name}:call`;
handler: RateLimitController<RateLimitedApis>['call'];
};

export type ControllerActions<
RateLimitedApis extends Record<string, (...args: any[]) => any>
> = GetRateLimitState<RateLimitedApis> | CallApi<RateLimitedApis>;

export type RateLimitMessenger<
RateLimitedApis extends Record<string, (...args: any[]) => any>
> = RestrictedControllerMessenger<
typeof name,
ControllerActions<RateLimitedApis>,
RateLimitStateChange<RateLimitedApis>,
never,
never
>;

const metadata = {
requests: { persist: false, anonymous: false },
};

/**
* Controller with logic for rate-limiting API endpoints per requesting origin.
*/
export class RateLimitController<
RateLimitedApis extends Record<string, (...args: any[]) => any>
> extends BaseController<
typeof name,
RateLimitState<RateLimitedApis>,
RateLimitMessenger<RateLimitedApis>
> {
private implementations;

private rateLimitTimeout;

private rateLimitCount;

/**
* Creates a RateLimitController instance.
*
* @param options - Constructor options.
* @param options.messenger - A reference to the messaging system.
* @param options.state - Initial state to set on this controller.
* @param options.implementations - Mapping from API type to API implementation.
* @param options.rateLimitTimeout - The time window in which the rate limit is applied (in ms).
* @param options.rateLimitCount - The amount of calls an origin can make in the rate limit time window.
*/
constructor({
rateLimitTimeout = 5000,
rateLimitCount = 1,
messenger,
state,
implementations,
}: {
rateLimitTimeout?: number;
rateLimitCount?: number;
messenger: RateLimitMessenger<RateLimitedApis>;
state?: Partial<RateLimitState<RateLimitedApis>>;
implementations: RateLimitedApis;
}) {
const defaultState = {
requests: Object.keys(implementations).reduce(
(acc, key) => ({ ...acc, [key]: {} }),
{} as Record<keyof RateLimitedApis, Record<string, number>>,
),
};
super({
name,
metadata,
messenger,
state: { ...defaultState, ...state },
});
this.implementations = implementations;
this.rateLimitTimeout = rateLimitTimeout;
this.rateLimitCount = rateLimitCount;

this.messagingSystem.registerActionHandler(
`${name}:call` as const,
((
origin: string,
type: keyof RateLimitedApis,
...args: Parameters<RateLimitedApis[keyof RateLimitedApis]>
) => this.call(origin, type, ...args)) as any,
);
}

/**
* Calls an API if the requesting origin is not rate-limited.
*
* @param origin - The requesting origin.
* @param type - The type of API call to make.
* @param args - Arguments for the API call.
* @returns `false` if rate-limited, and `true` otherwise.
*/
async call<ApiType extends keyof RateLimitedApis>(
origin: string,
type: ApiType,
...args: Parameters<RateLimitedApis[ApiType]>
): Promise<ReturnType<RateLimitedApis[ApiType]>> {
if (this.isRateLimited(type, origin)) {
throw ethErrors.rpc.limitExceeded({
message: `"${type}" is currently rate-limited. Please try again later.`,
});
}
this.recordRequest(type, origin);

const implementation = this.implementations[type];

if (!implementation) {
throw new Error('Invalid api type');
}

return implementation(...args);
}

/**
* Checks whether an origin is rate limited for the a specific API.
*
* @param api - The API the origin is trying to access.
* @param origin - The origin trying to access the API.
* @returns `true` if rate-limited, and `false` otherwise.
*/
private isRateLimited(api: keyof RateLimitedApis, origin: string) {
return this.state.requests[api][origin] >= this.rateLimitCount;
}

/**
* Records that an origin has made a request to call an API, for rate-limiting purposes.
*
* @param api - The API the origin is trying to access.
* @param origin - The origin trying to access the API.
*/
private recordRequest(api: keyof RateLimitedApis, origin: string) {
this.update((state) => {
(state as any).requests[api][origin] =
((state as any).requests[api][origin] ?? 0) + 1;

setTimeout(
() => this.resetRequestCount(api, origin),
this.rateLimitTimeout,
);
});
}

/**
* Resets the request count for a given origin and API combination, for rate-limiting purposes.
*
* @param api - The API in question.
* @param origin - The origin in question.
*/
private resetRequestCount(api: keyof RateLimitedApis, origin: string) {
this.update((state) => {
(state as any).requests[api][origin] = 0;
});
}
}

0 comments on commit 9c5b75a

Please sign in to comment.