diff --git a/packages/utils/src/ClassUtils.test.ts b/packages/utils/src/ClassUtils.test.ts new file mode 100644 index 0000000000..cf57debfc5 --- /dev/null +++ b/packages/utils/src/ClassUtils.test.ts @@ -0,0 +1,81 @@ +/* eslint-disable max-classes-per-file */ +import { bindAllMethods, getAllMethodNames } from './ClassUtils'; + +class Aaa { + nameA = 'Aaa'; + + getAaa() { + return this.nameA; + } +} + +class Bbb extends Aaa { + nameB = 'Bbb'; + + getBbb() { + return this.nameB; + } +} + +class Ccc extends Bbb { + nameC = 'Ccc'; + + getCcc() { + return this.nameC; + } + + getCcc2 = () => this.nameC; +} + +beforeEach(() => { + jest.clearAllMocks(); + expect.hasAssertions(); +}); + +describe('getAllMethodNames', () => { + it.each([true, false])( + 'should return all method names: %s', + traversePrototypeChain => { + const instance = new Ccc(); + + const methodNames = getAllMethodNames( + instance, + traversePrototypeChain + ).sort(); + + if (traversePrototypeChain) { + expect(methodNames).toEqual(['getAaa', 'getBbb', 'getCcc', 'getCcc2']); + } else { + expect(methodNames).toEqual(['getCcc', 'getCcc2']); + } + } + ); +}); + +describe('bindAllMethods', () => { + it.each([true, false, undefined])( + 'should bind all methods: %s', + traversePrototypeChain => { + const instance = new Ccc(); + + bindAllMethods(instance, traversePrototypeChain); + + const { getAaa, getBbb, getCcc, getCcc2 } = instance; + + if (traversePrototypeChain === true) { + expect(getAaa()).toEqual('Aaa'); + expect(getBbb()).toEqual('Bbb'); + } else { + expect(() => getAaa()).toThrow( + "Cannot read properties of undefined (reading 'nameA')" + ); + expect(() => getBbb()).toThrow( + "Cannot read properties of undefined (reading 'nameB')" + ); + } + + expect(getCcc()).toEqual('Ccc'); + expect(getCcc2()).toEqual('Ccc'); + } + ); +}); diff --git a/packages/utils/src/ClassUtils.ts b/packages/utils/src/ClassUtils.ts new file mode 100644 index 0000000000..19a70d1456 --- /dev/null +++ b/packages/utils/src/ClassUtils.ts @@ -0,0 +1,65 @@ +export type MethodName = { + [K in keyof T]: T[K] extends (...args: unknown[]) => unknown ? K : never; +}[keyof T]; + +/** + * Bind all methods on the instance + its prototype to the instance. If + * `traversePrototypeChain` is true, the prototype chain will be traversed until + * Object.prototype is reached, and any additional methods found will be included. + * @param instance The instance to bind methods to + * @param traversePrototypeChain Whether to traverse the prototype chain or not + */ +export function bindAllMethods( + instance: object, + traversePrototypeChain = false +): void { + const methodNames = getAllMethodNames(instance, traversePrototypeChain); + + methodNames.forEach(methodName => { + // eslint-disable-next-line no-param-reassign + (instance as Record)[methodName] = ( + instance[methodName] as (...args: unknown[]) => unknown + ).bind(instance); + }); +} + +/** + * Get all class method names. This will return names of all methods defined on + * the instance + its prototype. If `traversePrototypeChain` is true, the prototype + * chain will be traversed until Object.prototype is reached, and any additional + * methods found will be included. + * @param instance Instance to get method names from + * @param traversePrototypeChain Whether to traverse the prototype chain or not + */ +export function getAllMethodNames( + instance: T, + traversePrototypeChain: boolean +): MethodName[] { + const methodNames = new Set>(); + + let current = instance; + + // Get method names for instance + prototype. Optionally traverse prototype + // chain until Object.prototype is reached. + let level = 0; + while ( + current != null && + current !== Object.prototype && + (level <= 1 || traversePrototypeChain) + ) { + // eslint-disable-next-line no-restricted-syntax + for (const name of Object.getOwnPropertyNames(current)) { + if ( + name !== 'constructor' && + typeof current[name as keyof typeof current] === 'function' + ) { + methodNames.add(name as MethodName); + } + } + + current = Object.getPrototypeOf(current); + level += 1; + } + + return [...methodNames.keys()]; +} diff --git a/packages/utils/src/index.ts b/packages/utils/src/index.ts index b1b0d34de0..319fef2123 100644 --- a/packages/utils/src/index.ts +++ b/packages/utils/src/index.ts @@ -1,5 +1,6 @@ export * from './DataUtils'; export { default as CanceledPromiseError } from './CanceledPromiseError'; +export * from './ClassUtils'; export { default as ColorUtils } from './ColorUtils'; export * from './ClipboardUtils'; export { default as DbNameValidator } from './DbNameValidator';