-
-
Notifications
You must be signed in to change notification settings - Fork 183
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Initial implementation of NotificationControllerV2 * Simplify implementation * Fix small type issue * Simplify and start using messaging system * Adapt tests for controller messaging * Simplify tests a bit * Add action handler * Fix types * Fix tests * Add some documentation and exports * Pivot to RateLimitController * Fix export * Remove subject metadata type * Make controller generic and allow caller to pass implementation mapping * Fix PR comments * Fix type name casing * Fix controller docstrings The docstrings contained references to "notifications" from this controller's infancy. * Return potential result from implementation, Wrap in promise in case implementation is async * Improve typing and API * Fix type issue * Fix PR comments * Cast action handler * Change use of useFakeTimers * Throw error when API is rate-limited * Small fixes Co-authored-by: Erik Marks <25517051+rekmarks@users.noreply.github.com>
- Loading branch information
1 parent
613c5b9
commit 9c5b75a
Showing
3 changed files
with
344 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,150 @@ | ||
import { ControllerMessenger } from '../ControllerMessenger'; | ||
import { | ||
ControllerActions, | ||
RateLimitStateChange, | ||
RateLimitController, | ||
RateLimitMessenger, | ||
GetRateLimitState, | ||
CallApi, | ||
} from './RateLimitController'; | ||
|
||
const name = 'RateLimitController'; | ||
|
||
const implementations = { | ||
showNativeNotification: jest.fn(), | ||
}; | ||
|
||
type RateLimitedApis = typeof implementations; | ||
|
||
/** | ||
* Constructs a unrestricted controller messenger. | ||
* | ||
* @returns A unrestricted controller messenger. | ||
*/ | ||
function getUnrestrictedMessenger() { | ||
return new ControllerMessenger< | ||
GetRateLimitState<RateLimitedApis> | CallApi<RateLimitedApis>, | ||
RateLimitStateChange<RateLimitedApis> | ||
>(); | ||
} | ||
|
||
/** | ||
* Constructs a restricted controller messenger. | ||
* | ||
* @param controllerMessenger - An optional unrestricted messenger | ||
* @returns A restricted controller messenger. | ||
*/ | ||
function getRestrictedMessenger( | ||
controllerMessenger = getUnrestrictedMessenger(), | ||
) { | ||
return controllerMessenger.getRestricted< | ||
typeof name, | ||
ControllerActions<RateLimitedApis>['type'], | ||
never | ||
>({ | ||
name, | ||
allowedActions: ['RateLimitController:call'], | ||
}) as RateLimitMessenger<RateLimitedApis>; | ||
} | ||
|
||
const origin = 'snap_test'; | ||
const message = 'foo'; | ||
|
||
describe('RateLimitController', () => { | ||
beforeEach(() => { | ||
jest.useFakeTimers(); | ||
}); | ||
|
||
afterEach(() => { | ||
implementations.showNativeNotification.mockClear(); | ||
jest.useRealTimers(); | ||
}); | ||
|
||
it('action: RateLimitController:call', async () => { | ||
const unrestricted = getUnrestrictedMessenger(); | ||
const messenger = getRestrictedMessenger(unrestricted); | ||
|
||
// Registers action handlers | ||
new RateLimitController({ | ||
implementations, | ||
messenger, | ||
}); | ||
|
||
expect( | ||
await unrestricted.call( | ||
'RateLimitController:call', | ||
origin, | ||
'showNativeNotification', | ||
origin, | ||
message, | ||
), | ||
).toBeUndefined(); | ||
|
||
expect(implementations.showNativeNotification).toHaveBeenCalledWith( | ||
origin, | ||
message, | ||
); | ||
}); | ||
|
||
it('uses showNativeNotification to show a notification', async () => { | ||
const messenger = getRestrictedMessenger(); | ||
|
||
const controller = new RateLimitController({ | ||
implementations, | ||
messenger, | ||
}); | ||
expect( | ||
await controller.call(origin, 'showNativeNotification', origin, message), | ||
).toBeUndefined(); | ||
|
||
expect(implementations.showNativeNotification).toHaveBeenCalledWith( | ||
origin, | ||
message, | ||
); | ||
}); | ||
|
||
it('returns false if rate-limited', async () => { | ||
const messenger = getRestrictedMessenger(); | ||
const controller = new RateLimitController({ | ||
implementations, | ||
messenger, | ||
rateLimitCount: 1, | ||
}); | ||
|
||
expect( | ||
await controller.call(origin, 'showNativeNotification', origin, message), | ||
).toBeUndefined(); | ||
|
||
await expect( | ||
controller.call(origin, 'showNativeNotification', origin, message), | ||
).rejects.toThrow( | ||
`"showNativeNotification" is currently rate-limited. Please try again later`, | ||
); | ||
expect(implementations.showNativeNotification).toHaveBeenCalledTimes(1); | ||
expect(implementations.showNativeNotification).toHaveBeenCalledWith( | ||
origin, | ||
message, | ||
); | ||
}); | ||
|
||
it('rate limit is reset after timeout', async () => { | ||
const messenger = getRestrictedMessenger(); | ||
const controller = new RateLimitController({ | ||
implementations, | ||
messenger, | ||
rateLimitCount: 1, | ||
}); | ||
expect( | ||
await controller.call(origin, 'showNativeNotification', origin, message), | ||
).toBeUndefined(); | ||
jest.runAllTimers(); | ||
expect( | ||
await controller.call(origin, 'showNativeNotification', origin, message), | ||
).toBeUndefined(); | ||
expect(implementations.showNativeNotification).toHaveBeenCalledTimes(2); | ||
expect(implementations.showNativeNotification).toHaveBeenCalledWith( | ||
origin, | ||
message, | ||
); | ||
}); | ||
}); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,193 @@ | ||
import { ethErrors } from 'eth-rpc-errors'; | ||
import type { Patch } from 'immer'; | ||
|
||
import { BaseController } from '../BaseControllerV2'; | ||
|
||
import type { RestrictedControllerMessenger } from '../ControllerMessenger'; | ||
|
||
/** | ||
* @type RateLimitState | ||
* @property requests - Object containing number of requests in a given interval for each origin and api type combination | ||
*/ | ||
export type RateLimitState< | ||
RateLimitedApis extends Record<string, (...args: any[]) => any> | ||
> = { | ||
requests: Record<keyof RateLimitedApis, Record<string, number>>; | ||
}; | ||
|
||
const name = 'RateLimitController'; | ||
|
||
export type RateLimitStateChange< | ||
RateLimitedApis extends Record<string, (...args: any[]) => any> | ||
> = { | ||
type: `${typeof name}:stateChange`; | ||
payload: [RateLimitState<RateLimitedApis>, Patch[]]; | ||
}; | ||
|
||
export type GetRateLimitState< | ||
RateLimitedApis extends Record<string, (...args: any[]) => any> | ||
> = { | ||
type: `${typeof name}:getState`; | ||
handler: () => RateLimitState<RateLimitedApis>; | ||
}; | ||
|
||
export type CallApi< | ||
RateLimitedApis extends Record<string, (...args: any[]) => any> | ||
> = { | ||
type: `${typeof name}:call`; | ||
handler: RateLimitController<RateLimitedApis>['call']; | ||
}; | ||
|
||
export type ControllerActions< | ||
RateLimitedApis extends Record<string, (...args: any[]) => any> | ||
> = GetRateLimitState<RateLimitedApis> | CallApi<RateLimitedApis>; | ||
|
||
export type RateLimitMessenger< | ||
RateLimitedApis extends Record<string, (...args: any[]) => any> | ||
> = RestrictedControllerMessenger< | ||
typeof name, | ||
ControllerActions<RateLimitedApis>, | ||
RateLimitStateChange<RateLimitedApis>, | ||
never, | ||
never | ||
>; | ||
|
||
const metadata = { | ||
requests: { persist: false, anonymous: false }, | ||
}; | ||
|
||
/** | ||
* Controller with logic for rate-limiting API endpoints per requesting origin. | ||
*/ | ||
export class RateLimitController< | ||
RateLimitedApis extends Record<string, (...args: any[]) => any> | ||
> extends BaseController< | ||
typeof name, | ||
RateLimitState<RateLimitedApis>, | ||
RateLimitMessenger<RateLimitedApis> | ||
> { | ||
private implementations; | ||
|
||
private rateLimitTimeout; | ||
|
||
private rateLimitCount; | ||
|
||
/** | ||
* Creates a RateLimitController instance. | ||
* | ||
* @param options - Constructor options. | ||
* @param options.messenger - A reference to the messaging system. | ||
* @param options.state - Initial state to set on this controller. | ||
* @param options.implementations - Mapping from API type to API implementation. | ||
* @param options.rateLimitTimeout - The time window in which the rate limit is applied (in ms). | ||
* @param options.rateLimitCount - The amount of calls an origin can make in the rate limit time window. | ||
*/ | ||
constructor({ | ||
rateLimitTimeout = 5000, | ||
rateLimitCount = 1, | ||
messenger, | ||
state, | ||
implementations, | ||
}: { | ||
rateLimitTimeout?: number; | ||
rateLimitCount?: number; | ||
messenger: RateLimitMessenger<RateLimitedApis>; | ||
state?: Partial<RateLimitState<RateLimitedApis>>; | ||
implementations: RateLimitedApis; | ||
}) { | ||
const defaultState = { | ||
requests: Object.keys(implementations).reduce( | ||
(acc, key) => ({ ...acc, [key]: {} }), | ||
{} as Record<keyof RateLimitedApis, Record<string, number>>, | ||
), | ||
}; | ||
super({ | ||
name, | ||
metadata, | ||
messenger, | ||
state: { ...defaultState, ...state }, | ||
}); | ||
this.implementations = implementations; | ||
this.rateLimitTimeout = rateLimitTimeout; | ||
this.rateLimitCount = rateLimitCount; | ||
|
||
this.messagingSystem.registerActionHandler( | ||
`${name}:call` as const, | ||
(( | ||
origin: string, | ||
type: keyof RateLimitedApis, | ||
...args: Parameters<RateLimitedApis[keyof RateLimitedApis]> | ||
) => this.call(origin, type, ...args)) as any, | ||
); | ||
} | ||
|
||
/** | ||
* Calls an API if the requesting origin is not rate-limited. | ||
* | ||
* @param origin - The requesting origin. | ||
* @param type - The type of API call to make. | ||
* @param args - Arguments for the API call. | ||
* @returns `false` if rate-limited, and `true` otherwise. | ||
*/ | ||
async call<ApiType extends keyof RateLimitedApis>( | ||
origin: string, | ||
type: ApiType, | ||
...args: Parameters<RateLimitedApis[ApiType]> | ||
): Promise<ReturnType<RateLimitedApis[ApiType]>> { | ||
if (this.isRateLimited(type, origin)) { | ||
throw ethErrors.rpc.limitExceeded({ | ||
message: `"${type}" is currently rate-limited. Please try again later.`, | ||
}); | ||
} | ||
this.recordRequest(type, origin); | ||
|
||
const implementation = this.implementations[type]; | ||
|
||
if (!implementation) { | ||
throw new Error('Invalid api type'); | ||
} | ||
|
||
return implementation(...args); | ||
} | ||
|
||
/** | ||
* Checks whether an origin is rate limited for the a specific API. | ||
* | ||
* @param api - The API the origin is trying to access. | ||
* @param origin - The origin trying to access the API. | ||
* @returns `true` if rate-limited, and `false` otherwise. | ||
*/ | ||
private isRateLimited(api: keyof RateLimitedApis, origin: string) { | ||
return this.state.requests[api][origin] >= this.rateLimitCount; | ||
} | ||
|
||
/** | ||
* Records that an origin has made a request to call an API, for rate-limiting purposes. | ||
* | ||
* @param api - The API the origin is trying to access. | ||
* @param origin - The origin trying to access the API. | ||
*/ | ||
private recordRequest(api: keyof RateLimitedApis, origin: string) { | ||
this.update((state) => { | ||
(state as any).requests[api][origin] = | ||
((state as any).requests[api][origin] ?? 0) + 1; | ||
|
||
setTimeout( | ||
() => this.resetRequestCount(api, origin), | ||
this.rateLimitTimeout, | ||
); | ||
}); | ||
} | ||
|
||
/** | ||
* Resets the request count for a given origin and API combination, for rate-limiting purposes. | ||
* | ||
* @param api - The API in question. | ||
* @param origin - The origin in question. | ||
*/ | ||
private resetRequestCount(api: keyof RateLimitedApis, origin: string) { | ||
this.update((state) => { | ||
(state as any).requests[api][origin] = 0; | ||
}); | ||
} | ||
} |