diff --git a/scripts/moduleReport.js b/scripts/moduleReport.js index 61bdff99e3..1937afc34d 100644 --- a/scripts/moduleReport.js +++ b/scripts/moduleReport.js @@ -1,7 +1,18 @@ const esbuild = require('esbuild'); // List of all modules accepted in ModulesMap -const moduleNames = ['Rest']; +const moduleNames = ['Rest', 'Crypto']; + +// List of all free-standing functions exported by the library along with the +// ModulesMap entries that we expect them to transitively import +const functions = [ + { name: 'generateRandomKey', transitiveImports: ['Crypto'] }, + { name: 'getDefaultCryptoParams', transitiveImports: ['Crypto'] }, + { name: 'decodeMessage', transitiveImports: [] }, + { name: 'decodeEncryptedMessage', transitiveImports: ['Crypto'] }, + { name: 'decodeMessages', transitiveImports: [] }, + { name: 'decodeEncryptedMessages', transitiveImports: ['Crypto'] }, +]; function formatBytes(bytes) { const kibibytes = bytes / 1024; @@ -35,19 +46,45 @@ const errors = []; // First display the size of the base client console.log(`${baseClient}: ${formatBytes(baseClientSize)}`); - // Then display the size of each module together with the base client - moduleNames.forEach((moduleName) => { - const size = getImportSize([baseClient, moduleName]); - console.log(`${baseClient} + ${moduleName}: ${formatBytes(size)}`); + // Then display the size of each export together with the base client + [...moduleNames, ...Object.values(functions).map((functionData) => functionData.name)].forEach((exportName) => { + const size = getImportSize([baseClient, exportName]); + console.log(`${baseClient} + ${exportName}: ${formatBytes(size)}`); - if (!(baseClientSize < size) && !(baseClient === 'BaseRest' && moduleName === 'Rest')) { + if (!(baseClientSize < size) && !(baseClient === 'BaseRest' && exportName === 'Rest')) { // Emit an error if adding the module does not increase the bundle size // (this means that the module is not being tree-shaken correctly). - errors.push(new Error(`Adding ${moduleName} to ${baseClient} does not increase the bundle size.`)); + errors.push(new Error(`Adding ${exportName} to ${baseClient} does not increase the bundle size.`)); } }); }); +for (const functionData of functions) { + const { name: functionName, transitiveImports } = functionData; + + // First display the size of the function + const standaloneSize = getImportSize([functionName]); + console.log(`${functionName}: ${formatBytes(standaloneSize)}`); + + // Then display the size of the function together with the modules we expect + // it to transitively import + if (transitiveImports.length > 0) { + const withTransitiveImportsSize = getImportSize([functionName, ...transitiveImports]); + console.log(`${functionName} + ${transitiveImports.join(' + ')}: ${formatBytes(withTransitiveImportsSize)}`); + + if (withTransitiveImportsSize > standaloneSize) { + // Emit an error if the bundle size is increased by adding the modules + // that we expect this function to have transitively imported anyway. + // This seemed like a useful sense check, but it might need tweaking in + // the future if we make future optimisations that mean that the + // standalone functions don’t necessarily import the whole module. + errors.push( + new Error(`Adding ${transitiveImports.join(' + ')} to ${functionName} unexpectedly increases the bundle size.`) + ); + } + } +} + if (errors.length > 0) { for (const error of errors) { console.log(error.message); diff --git a/src/common/lib/client/baseclient.ts b/src/common/lib/client/baseclient.ts index aeac034d62..da5ee730e3 100644 --- a/src/common/lib/client/baseclient.ts +++ b/src/common/lib/client/baseclient.ts @@ -10,10 +10,11 @@ import ClientOptions, { NormalisedClientOptions } from '../../types/ClientOption import * as API from '../../../../ably'; import Platform from '../../platform'; -import Message from '../types/message'; import PresenceMessage from '../types/presencemessage'; import { ModulesMap } from './modulesmap'; import { Rest } from './rest'; +import { IUntypedCryptoStatic } from 'common/types/ICryptoStatic'; +import { throwMissingModuleError } from '../util/utils'; type BatchResult = API.Types.BatchResult; type BatchPublishSpec = API.Types.BatchPublishSpec; @@ -38,6 +39,7 @@ class BaseClient { auth: Auth; private readonly _rest: Rest | null; + readonly _Crypto: IUntypedCryptoStatic | null; constructor(options: ClientOptions | string, modules: ModulesMap) { if (!options) { @@ -88,11 +90,12 @@ class BaseClient { this.auth = new Auth(this, normalOptions); this._rest = modules.Rest ? new modules.Rest(this) : null; + this._Crypto = modules.Crypto ?? null; } private get rest(): Rest { if (!this._rest) { - throw new ErrorInfo('Rest module not provided', 400, 40000); + throwMissingModuleError('Rest'); } return this._rest; } @@ -147,8 +150,6 @@ class BaseClient { } static Platform = Platform; - static Crypto?: typeof Platform.Crypto; - static Message = Message; static PresenceMessage = PresenceMessage; } diff --git a/src/common/lib/client/channel.ts b/src/common/lib/client/channel.ts index d14f4c9752..370c32993e 100644 --- a/src/common/lib/client/channel.ts +++ b/src/common/lib/client/channel.ts @@ -10,8 +10,8 @@ import { ChannelOptions } from '../../types/channel'; import { PaginatedResultCallback, StandardCallback } from '../../types/utils'; import BaseClient from './baseclient'; import * as API from '../../../../ably'; -import Platform from 'common/platform'; import Defaults from '../util/defaults'; +import { IUntypedCryptoStatic } from 'common/types/ICryptoStatic'; interface RestHistoryParams { start?: number; @@ -30,11 +30,11 @@ function allEmptyIds(messages: Array) { }); } -function normaliseChannelOptions(options?: ChannelOptions) { +function normaliseChannelOptions(Crypto: IUntypedCryptoStatic | null, options?: ChannelOptions) { const channelOptions = options || {}; if (channelOptions.cipher) { - if (!Platform.Crypto) throw new Error('Encryption not enabled; use ably.encryption.js instead'); - const cipher = Platform.Crypto.getCipher(channelOptions.cipher); + if (!Crypto) Utils.throwMissingModuleError('Crypto'); + const cipher = Crypto.getCipher(channelOptions.cipher); channelOptions.cipher = cipher.cipherParams; channelOptions.channelCipher = cipher.cipher; } else if ('cipher' in channelOptions) { @@ -60,11 +60,11 @@ class Channel extends EventEmitter { this.name = name; this.basePath = '/channels/' + encodeURIComponent(name); this.presence = new Presence(this); - this.channelOptions = normaliseChannelOptions(channelOptions); + this.channelOptions = normaliseChannelOptions(client._Crypto ?? null, channelOptions); } setOptions(options?: ChannelOptions): void { - this.channelOptions = normaliseChannelOptions(options); + this.channelOptions = normaliseChannelOptions(this.client._Crypto ?? null, options); } history( diff --git a/src/common/lib/client/defaultrealtime.ts b/src/common/lib/client/defaultrealtime.ts index 4b75d9bd94..e8b66f878e 100644 --- a/src/common/lib/client/defaultrealtime.ts +++ b/src/common/lib/client/defaultrealtime.ts @@ -4,16 +4,32 @@ import { allCommonModules } from './modulesmap'; import * as Utils from '../util/utils'; import ConnectionManager from '../transport/connectionmanager'; import ProtocolMessage from '../types/protocolmessage'; +import Platform from 'common/platform'; +import { DefaultMessage } from '../types/defaultmessage'; /** `DefaultRealtime` is the class that the non tree-shakable version of the SDK exports as `Realtime`. It ensures that this version of the SDK includes all of the functionality which is optionally available in the tree-shakable version. */ export class DefaultRealtime extends BaseRealtime { constructor(options: ClientOptions) { - super(options, allCommonModules); + super(options, { ...allCommonModules, Crypto: DefaultRealtime.Crypto ?? undefined }); } static Utils = Utils; static ConnectionManager = ConnectionManager; static ProtocolMessage = ProtocolMessage; + + private static _Crypto: typeof Platform.Crypto = null; + static get Crypto() { + if (this._Crypto === null) { + throw new Error('Encryption not enabled; use ably.encryption.js instead'); + } + + return this._Crypto; + } + static set Crypto(newValue: typeof Platform.Crypto) { + this._Crypto = newValue; + } + + static Message = DefaultMessage; } diff --git a/src/common/lib/client/defaultrest.ts b/src/common/lib/client/defaultrest.ts index 431e4547d8..1b7df607a6 100644 --- a/src/common/lib/client/defaultrest.ts +++ b/src/common/lib/client/defaultrest.ts @@ -1,12 +1,28 @@ import { BaseRest } from './baserest'; import ClientOptions from '../../types/ClientOptions'; import { allCommonModules } from './modulesmap'; +import Platform from 'common/platform'; +import { DefaultMessage } from '../types/defaultmessage'; /** `DefaultRest` is the class that the non tree-shakable version of the SDK exports as `Rest`. It ensures that this version of the SDK includes all of the functionality which is optionally available in the tree-shakable version. */ export class DefaultRest extends BaseRest { constructor(options: ClientOptions | string) { - super(options, allCommonModules); + super(options, { ...allCommonModules, Crypto: DefaultRest.Crypto ?? undefined }); } + + private static _Crypto: typeof Platform.Crypto = null; + static get Crypto() { + if (this._Crypto === null) { + throw new Error('Encryption not enabled; use ably.encryption.js instead'); + } + + return this._Crypto; + } + static set Crypto(newValue: typeof Platform.Crypto) { + this._Crypto = newValue; + } + + static Message = DefaultMessage; } diff --git a/src/common/lib/client/modulesmap.ts b/src/common/lib/client/modulesmap.ts index c56c2d98e7..4133bcc3a3 100644 --- a/src/common/lib/client/modulesmap.ts +++ b/src/common/lib/client/modulesmap.ts @@ -1,7 +1,9 @@ import { Rest } from './rest'; +import { IUntypedCryptoStatic } from '../../types/ICryptoStatic'; export interface ModulesMap { Rest?: typeof Rest; + Crypto?: IUntypedCryptoStatic; } export const allCommonModules: ModulesMap = { Rest }; diff --git a/src/common/lib/types/defaultmessage.ts b/src/common/lib/types/defaultmessage.ts new file mode 100644 index 0000000000..3ef7dab056 --- /dev/null +++ b/src/common/lib/types/defaultmessage.ts @@ -0,0 +1,16 @@ +import Message, { fromEncoded, fromEncodedArray } from './message'; +import * as API from '../../../../ably'; +import Platform from 'common/platform'; + +/** + `DefaultMessage` is the class returned by `DefaultRest` and `DefaultRealtime`’s `Message` static property. It introduces the static methods described in the `MessageStatic` interface of the public API of the non tree-shakable version of the library. + */ +export class DefaultMessage extends Message { + static async fromEncoded(encoded: unknown, inputOptions?: API.Types.ChannelOptions): Promise { + return fromEncoded(Platform.Crypto, encoded, inputOptions); + } + + static async fromEncodedArray(encodedArray: Array, options?: API.Types.ChannelOptions): Promise { + return fromEncodedArray(Platform.Crypto, encodedArray, options); + } +} diff --git a/src/common/lib/types/message.ts b/src/common/lib/types/message.ts index 052df0b178..692c5e069f 100644 --- a/src/common/lib/types/message.ts +++ b/src/common/lib/types/message.ts @@ -6,6 +6,7 @@ import PresenceMessage from './presencemessage'; import * as Utils from '../util/utils'; import { Bufferlike as BrowserBufferlike } from '../../../platform/web/lib/util/bufferutils'; import * as API from '../../../../ably'; +import { IUntypedCryptoStatic } from 'common/types/ICryptoStatic'; export type CipherOptions = { channelCipher: { @@ -42,10 +43,13 @@ function normaliseContext(context: CipherOptions | EncodingDecodingContext | Cha return context as EncodingDecodingContext; } -function normalizeCipherOptions(options: API.Types.ChannelOptions | null): ChannelOptions { +function normalizeCipherOptions( + Crypto: IUntypedCryptoStatic | null, + options: API.Types.ChannelOptions | null +): ChannelOptions { if (options && options.cipher) { - if (!Platform.Crypto) throw new Error('Encryption not enabled; use ably.encryption.js instead'); - const cipher = Platform.Crypto.getCipher(options.cipher); + if (!Crypto) Utils.throwMissingModuleError('Crypto'); + const cipher = Crypto.getCipher(options.cipher); return { cipher: cipher.cipherParams, channelCipher: cipher.cipher, @@ -71,6 +75,35 @@ function getMessageSize(msg: Message) { return size; } +export async function fromEncoded( + Crypto: IUntypedCryptoStatic | null, + encoded: unknown, + inputOptions?: API.Types.ChannelOptions +): Promise { + const msg = Message.fromValues(encoded); + const options = normalizeCipherOptions(Crypto, inputOptions ?? null); + /* if decoding fails at any point, catch and return the message decoded to + * the fullest extent possible */ + try { + await Message.decode(msg, options); + } catch (e) { + Logger.logAction(Logger.LOG_ERROR, 'Message.fromEncoded()', (e as Error).toString()); + } + return msg; +} + +export async function fromEncodedArray( + Crypto: IUntypedCryptoStatic | null, + encodedArray: Array, + options?: API.Types.ChannelOptions +): Promise { + return Promise.all( + encodedArray.map(function (encoded) { + return fromEncoded(Crypto, encoded, options); + }) + ); +} + class Message { name?: string; id?: string; @@ -330,27 +363,6 @@ class Message { return result; } - static async fromEncoded(encoded: unknown, inputOptions?: API.Types.ChannelOptions): Promise { - const msg = Message.fromValues(encoded); - const options = normalizeCipherOptions(inputOptions ?? null); - /* if decoding fails at any point, catch and return the message decoded to - * the fullest extent possible */ - try { - await Message.decode(msg, options); - } catch (e) { - Logger.logAction(Logger.LOG_ERROR, 'Message.fromEncoded()', (e as Error).toString()); - } - return msg; - } - - static async fromEncodedArray(encodedArray: Array, options?: API.Types.ChannelOptions): Promise { - return Promise.all( - encodedArray.map(function (encoded) { - return Message.fromEncoded(encoded, options); - }) - ); - } - /* This should be called on encode()d (and encrypt()d) Messages (as it * assumes the data is a string or buffer) */ static getMessagesSize(messages: Message[]): number { diff --git a/src/common/lib/util/utils.ts b/src/common/lib/util/utils.ts index 03d546b0a8..788f54d502 100644 --- a/src/common/lib/util/utils.ts +++ b/src/common/lib/util/utils.ts @@ -1,5 +1,6 @@ import Platform from 'common/platform'; import ErrorInfo, { PartialErrorInfo } from 'common/lib/types/errorinfo'; +import { ModulesMap } from '../client/modulesmap'; function randomPosn(arrOrStr: Array | string) { return Math.floor(Math.random() * arrOrStr.length); @@ -551,3 +552,7 @@ export function arrEquals(a: any[], b: any[]) { }) ); } + +export function throwMissingModuleError(moduleName: keyof ModulesMap): never { + throw new ErrorInfo(`${moduleName} module not provided`, 400, 40000); +} diff --git a/src/common/platform.ts b/src/common/platform.ts index 41992c3bc5..b55e625a26 100644 --- a/src/common/platform.ts +++ b/src/common/platform.ts @@ -7,6 +7,7 @@ import IBufferUtils from './types/IBufferUtils'; import Transport from './lib/transport/transport'; import * as WebBufferUtils from '../platform/web/lib/util/bufferutils'; import * as NodeBufferUtils from '../platform/nodejs/lib/util/bufferutils'; +import { IUntypedCryptoStatic } from '../common/types/ICryptoStatic'; type Bufferlike = WebBufferUtils.Bufferlike | NodeBufferUtils.Bufferlike; type BufferUtilsOutput = WebBufferUtils.Output | NodeBufferUtils.Output; @@ -23,12 +24,11 @@ export default class Platform { */ static BufferUtils: IBufferUtils; /* - This should be a class whose static methods implement the ICryptoStatic - interface, but (for the same reasons as described in the BufferUtils - comment above) Platform doesn’t currently allow us to express the - generic parameters, hence keeping the type as `any`. + We’d like this to be ICryptoStatic with the correct generic arguments, + but Platform doesn’t currently allow that, as described in the BufferUtils + comment above. */ - static Crypto: any; + static Crypto: IUntypedCryptoStatic | null; static Http: typeof IHttp; static Transports: Array<(connectionManager: typeof ConnectionManager) => Transport>; static Defaults: IDefaults; diff --git a/src/common/types/ICryptoStatic.ts b/src/common/types/ICryptoStatic.ts index 97dac8fcce..4115c5d248 100644 --- a/src/common/types/ICryptoStatic.ts +++ b/src/common/types/ICryptoStatic.ts @@ -13,3 +13,11 @@ export default interface ICryptoStatic ): IGetCipherReturnValue>; } + +/* + A less strongly typed version of ICryptoStatic to use until we + can make Platform a generic type (see comment there). + */ +export interface IUntypedCryptoStatic extends API.Types.Crypto { + getCipher(params: any): any; +} diff --git a/src/platform/nativescript/index.ts b/src/platform/nativescript/index.ts index be9941dc66..bdb55c8a34 100644 --- a/src/platform/nativescript/index.ts +++ b/src/platform/nativescript/index.ts @@ -1,5 +1,4 @@ // Common -import BaseClient from '../../common/lib/client/baseclient'; import { DefaultRest } from '../../common/lib/client/defaultrest'; import { DefaultRealtime } from '../../common/lib/client/defaultrealtime'; import Platform from '../../common/platform'; @@ -8,7 +7,7 @@ import ErrorInfo from '../../common/lib/types/errorinfo'; // Platform Specific import BufferUtils from '../web/lib/util/bufferutils'; // @ts-ignore -import CryptoFactory from '../web/lib/util/crypto'; +import { createCryptoClass } from '../web/lib/util/crypto'; import Http from '../web/lib/util/http'; // @ts-ignore import Config from './config'; @@ -21,7 +20,7 @@ import WebStorage from './lib/util/webstorage'; import PlatformDefaults from '../web/lib/util/defaults'; import msgpack from '../web/lib/util/msgpack'; -const Crypto = CryptoFactory(Config, BufferUtils); +const Crypto = createCryptoClass(Config, BufferUtils); Platform.Crypto = Crypto; Platform.BufferUtils = BufferUtils; @@ -30,7 +29,9 @@ Platform.Config = Config; Platform.Transports = Transports; Platform.WebStorage = WebStorage; -BaseClient.Crypto = Crypto; +for (const clientClass of [DefaultRest, DefaultRealtime]) { + clientClass.Crypto = Crypto; +} Logger.initLogHandlers(); diff --git a/src/platform/nodejs/index.ts b/src/platform/nodejs/index.ts index eb67752641..ba683f9bad 100644 --- a/src/platform/nodejs/index.ts +++ b/src/platform/nodejs/index.ts @@ -1,5 +1,4 @@ // Common -import BaseClient from '../../common/lib/client/baseclient'; import { DefaultRest } from '../../common/lib/client/defaultrest'; import { DefaultRealtime } from '../../common/lib/client/defaultrealtime'; import Platform from '../../common/platform'; @@ -8,7 +7,7 @@ import ErrorInfo from '../../common/lib/types/errorinfo'; // Platform Specific import BufferUtils from './lib/util/bufferutils'; // @ts-ignore -import CryptoFactory from './lib/util/crypto'; +import { createCryptoClass } from './lib/util/crypto'; import Http from './lib/util/http'; import Config from './config'; // @ts-ignore @@ -17,7 +16,7 @@ import Logger from '../../common/lib/util/logger'; import { getDefaults } from '../../common/lib/util/defaults'; import PlatformDefaults from './lib/util/defaults'; -const Crypto = CryptoFactory(BufferUtils); +const Crypto = createCryptoClass(BufferUtils); Platform.Crypto = Crypto; Platform.BufferUtils = BufferUtils as typeof Platform.BufferUtils; @@ -26,7 +25,9 @@ Platform.Config = Config; Platform.Transports = Transports; Platform.WebStorage = null; -BaseClient.Crypto = Crypto; +for (const clientClass of [DefaultRest, DefaultRealtime]) { + clientClass.Crypto = Crypto; +} Logger.initLogHandlers(); diff --git a/src/platform/nodejs/lib/util/crypto.ts b/src/platform/nodejs/lib/util/crypto.ts index 1d99b060af..62fe8303b8 100644 --- a/src/platform/nodejs/lib/util/crypto.ts +++ b/src/platform/nodejs/lib/util/crypto.ts @@ -18,7 +18,7 @@ type OutputCiphertext = Buffer; type InputCiphertext = CryptoDataTypes.InputCiphertext; type OutputPlaintext = Buffer; -var CryptoFactory = function (bufferUtils: typeof BufferUtils) { +var createCryptoClass = function (bufferUtils: typeof BufferUtils) { var DEFAULT_ALGORITHM = 'aes'; var DEFAULT_KEYLENGTH = 256; // bits var DEFAULT_MODE = 'cbc'; @@ -276,4 +276,4 @@ var CryptoFactory = function (bufferUtils: typeof BufferUtils) { return Crypto; }; -export default CryptoFactory; +export { createCryptoClass }; diff --git a/src/platform/react-native/index.ts b/src/platform/react-native/index.ts index 88dae850da..478ee9664d 100644 --- a/src/platform/react-native/index.ts +++ b/src/platform/react-native/index.ts @@ -1,5 +1,4 @@ // Common -import BaseClient from '../../common/lib/client/baseclient'; import { DefaultRest } from '../../common/lib/client/defaultrest'; import { DefaultRealtime } from '../../common/lib/client/defaultrealtime'; import Platform from '../../common/platform'; @@ -8,7 +7,7 @@ import ErrorInfo from '../../common/lib/types/errorinfo'; // Platform Specific import BufferUtils from '../web/lib/util/bufferutils'; // @ts-ignore -import CryptoFactory from '../web/lib/util/crypto'; +import { createCryptoClass } from '../web/lib/util/crypto'; import Http from '../web/lib/util/http'; import configFactory from './config'; // @ts-ignore @@ -21,7 +20,7 @@ import msgpack from '../web/lib/util/msgpack'; const Config = configFactory(BufferUtils); -const Crypto = CryptoFactory(Config, BufferUtils); +const Crypto = createCryptoClass(Config, BufferUtils); Platform.Crypto = Crypto; Platform.BufferUtils = BufferUtils; @@ -30,7 +29,9 @@ Platform.Config = Config; Platform.Transports = Transports; Platform.WebStorage = WebStorage; -BaseClient.Crypto = Crypto; +for (const clientClass of [DefaultRest, DefaultRealtime]) { + clientClass.Crypto = Crypto; +} Logger.initLogHandlers(); diff --git a/src/platform/web-noencryption/index.ts b/src/platform/web-noencryption/index.ts index e10ae425c2..80748e6b9e 100644 --- a/src/platform/web-noencryption/index.ts +++ b/src/platform/web-noencryption/index.ts @@ -1,5 +1,4 @@ // Common -import BaseClient from '../../common/lib/client/baseclient'; import { DefaultRest } from '../../common/lib/client/defaultrest'; import { DefaultRealtime } from '../../common/lib/client/defaultrealtime'; import Platform from '../../common/platform'; @@ -25,8 +24,6 @@ Platform.Config = Config; Platform.Transports = Transports; Platform.WebStorage = WebStorage; -BaseClient.Crypto = null; - Logger.initLogHandlers(); Platform.Defaults = getDefaults(PlatformDefaults); diff --git a/src/platform/web/index.ts b/src/platform/web/index.ts index ea6fb1b53e..9fa12989e4 100644 --- a/src/platform/web/index.ts +++ b/src/platform/web/index.ts @@ -1,5 +1,4 @@ // Common -import BaseClient from '../../common/lib/client/baseclient'; import { DefaultRest } from '../../common/lib/client/defaultrest'; import { DefaultRealtime } from '../../common/lib/client/defaultrealtime'; import Platform from '../../common/platform'; @@ -8,7 +7,7 @@ import ErrorInfo from '../../common/lib/types/errorinfo'; // Platform Specific import BufferUtils from './lib/util/bufferutils'; // @ts-ignore -import CryptoFactory from './lib/util/crypto'; +import { createCryptoClass } from './lib/util/crypto'; import Http from './lib/util/http'; import Config from './config'; // @ts-ignore @@ -19,7 +18,7 @@ import WebStorage from './lib/util/webstorage'; import PlatformDefaults from './lib/util/defaults'; import msgpack from './lib/util/msgpack'; -const Crypto = CryptoFactory(Config, BufferUtils); +const Crypto = createCryptoClass(Config, BufferUtils); Platform.Crypto = Crypto; Platform.BufferUtils = BufferUtils; @@ -28,7 +27,9 @@ Platform.Config = Config; Platform.Transports = Transports; Platform.WebStorage = WebStorage; -BaseClient.Crypto = Crypto; +for (const clientClass of [DefaultRest, DefaultRealtime]) { + clientClass.Crypto = Crypto; +} Logger.initLogHandlers(); diff --git a/src/platform/web/lib/util/crypto.ts b/src/platform/web/lib/util/crypto.ts index 23c572d9dd..eff091f094 100644 --- a/src/platform/web/lib/util/crypto.ts +++ b/src/platform/web/lib/util/crypto.ts @@ -16,7 +16,7 @@ type OutputCiphertext = ArrayBuffer; type InputCiphertext = CryptoDataTypes.InputCiphertext; type OutputPlaintext = ArrayBuffer; -var CryptoFactory = function (config: IPlatformConfig, bufferUtils: typeof BufferUtils) { +var createCryptoClass = function (config: IPlatformConfig, bufferUtils: typeof BufferUtils) { var DEFAULT_ALGORITHM = 'aes'; var DEFAULT_KEYLENGTH = 256; // bits var DEFAULT_MODE = 'cbc'; @@ -318,4 +318,4 @@ var CryptoFactory = function (config: IPlatformConfig, bufferUtils: typeof Buffe return Crypto; }; -export default CryptoFactory; +export { createCryptoClass }; diff --git a/src/platform/web/modules.ts b/src/platform/web/modules.ts index d924eab04e..ed91109bfe 100644 --- a/src/platform/web/modules.ts +++ b/src/platform/web/modules.ts @@ -1,5 +1,4 @@ // Common -import BaseClient from 'common/lib/client/baseclient'; import { BaseRest } from '../../common/lib/client/baserest'; import BaseRealtime from '../../common/lib/client/baserealtime'; import Platform from '../../common/platform'; @@ -8,7 +7,6 @@ import ErrorInfo from '../../common/lib/types/errorinfo'; // Platform Specific import BufferUtils from './lib/util/bufferutils'; // @ts-ignore -import CryptoFactory from './lib/util/crypto'; import Http from './lib/util/http'; import Config from './config'; // @ts-ignore @@ -18,17 +16,12 @@ import { getDefaults } from '../../common/lib/util/defaults'; import WebStorage from './lib/util/webstorage'; import PlatformDefaults from './lib/util/defaults'; -const Crypto = CryptoFactory(Config, BufferUtils); - -Platform.Crypto = Crypto; Platform.BufferUtils = BufferUtils; Platform.Http = Http; Platform.Config = Config; Platform.Transports = Transports; Platform.WebStorage = WebStorage; -BaseClient.Crypto = Crypto; - Logger.initLogHandlers(); Platform.Defaults = getDefaults(PlatformDefaults); @@ -46,5 +39,7 @@ if (Platform.Config.noUpgrade) { Platform.Defaults.upgradeTransports = []; } +export * from './modules/crypto'; +export * from './modules/message'; export { Rest } from '../../common/lib/client/rest'; export { BaseRest, BaseRealtime, ErrorInfo }; diff --git a/src/platform/web/modules/crypto.ts b/src/platform/web/modules/crypto.ts new file mode 100644 index 0000000000..2c65d23eb1 --- /dev/null +++ b/src/platform/web/modules/crypto.ts @@ -0,0 +1,14 @@ +import BufferUtils from '../lib/util/bufferutils'; +import { createCryptoClass } from '../lib/util/crypto'; +import Config from '../config'; +import * as API from '../../../../ably'; + +export const Crypto = /* @__PURE__@ */ createCryptoClass(Config, BufferUtils); + +export const generateRandomKey: API.Types.Crypto['generateRandomKey'] = (keyLength) => { + return Crypto.generateRandomKey(keyLength); +}; + +export const getDefaultCryptoParams: API.Types.Crypto['getDefaultParams'] = (params) => { + return Crypto.getDefaultParams(params); +}; diff --git a/src/platform/web/modules/message.ts b/src/platform/web/modules/message.ts new file mode 100644 index 0000000000..de8b2ab4f9 --- /dev/null +++ b/src/platform/web/modules/message.ts @@ -0,0 +1,21 @@ +import * as API from '../../../../ably'; +import { Crypto } from './crypto'; +import { fromEncoded, fromEncodedArray } from '../../../common/lib/types/message'; + +// The type assertions for the decode* functions below are due to https://github.com/ably/ably-js/issues/1421 + +export const decodeMessage = ((obj, options) => { + return fromEncoded(null, obj, options); +}) as API.Types.MessageStatic['fromEncoded']; + +export const decodeEncryptedMessage = ((obj, options) => { + return fromEncoded(Crypto, obj, options); +}) as API.Types.MessageStatic['fromEncoded']; + +export const decodeMessages = ((obj, options) => { + return fromEncodedArray(null, obj, options); +}) as API.Types.MessageStatic['fromEncodedArray']; + +export const decodeEncryptedMessages = ((obj, options) => { + return fromEncodedArray(Crypto, obj, options); +}) as API.Types.MessageStatic['fromEncodedArray']; diff --git a/test/browser/modules.test.js b/test/browser/modules.test.js index 6ccbf64bbc..72aabe2c8e 100644 --- a/test/browser/modules.test.js +++ b/test/browser/modules.test.js @@ -1,12 +1,36 @@ -import { BaseRest, BaseRealtime, Rest } from '../../build/modules/index.js'; +import { + BaseRest, + BaseRealtime, + Rest, + generateRandomKey, + getDefaultCryptoParams, + decodeMessage, + decodeEncryptedMessage, + decodeMessages, + decodeEncryptedMessages, + Crypto, +} from '../../build/modules/index.js'; describe('browser/modules', function () { this.timeout(10 * 1000); const expect = chai.expect; + const BufferUtils = BaseRest.Platform.BufferUtils; let ablyClientOptions; + let testResourcesPath; + let loadTestData; + let testMessageEquality; before((done) => { ablyClientOptions = window.ablyHelpers.ablyClientOptions; + testResourcesPath = window.ablyHelpers.testResourcesPath; + testMessageEquality = window.ablyHelpers.testMessageEquality; + + loadTestData = async (dataPath) => { + return new Promise((resolve, reject) => { + window.ablyHelpers.loadTestData(dataPath, (err, testData) => (err ? reject(err) : resolve(testData))); + }); + }; + window.ablyHelpers.setupApp(done); }); @@ -44,4 +68,182 @@ describe('browser/modules', function () { }); }); }); + + describe('Crypto standalone functions', () => { + it('generateRandomKey', async () => { + const key = await generateRandomKey(); + expect(key).to.be.an('ArrayBuffer'); + }); + + it('getDefaultCryptoParams', async () => { + const key = await generateRandomKey(); + const params = getDefaultCryptoParams({ key }); + expect(params).to.be.an('object'); + }); + }); + + describe('Message standalone functions', () => { + async function testDecodesMessageData(functionUnderTest) { + const testData = await loadTestData(testResourcesPath + 'crypto-data-128.json'); + + const item = testData.items[1]; + const decoded = await functionUnderTest(item.encoded); + + expect(decoded.data).to.be.an('ArrayBuffer'); + } + + describe('decodeMessage', () => { + it('decodes a message’s data', async () => { + testDecodesMessageData(decodeMessage); + }); + + it('throws an error when given channel options with a cipher', async () => { + const testData = await loadTestData(testResourcesPath + 'crypto-data-128.json'); + const key = BufferUtils.base64Decode(testData.key); + const iv = BufferUtils.base64Decode(testData.iv); + + let thrownError = null; + try { + await decodeMessage(testData.items[0].encrypted, { cipher: { key, iv } }); + } catch (error) { + thrownError = error; + } + + expect(thrownError).not.to.be.null; + expect(thrownError.message).to.equal('Crypto module not provided'); + }); + }); + + describe('decodeEncryptedMessage', async () => { + it('decodes a message’s data', async () => { + testDecodesMessageData(decodeEncryptedMessage); + }); + + it('decrypts a message', async () => { + const testData = await loadTestData(testResourcesPath + 'crypto-data-128.json'); + + const key = BufferUtils.base64Decode(testData.key); + const iv = BufferUtils.base64Decode(testData.iv); + + for (const item of testData.items) { + const [decodedFromEncoded, decodedFromEncrypted] = await Promise.all([ + decodeMessage(item.encoded), + decodeEncryptedMessage(item.encrypted, { cipher: { key, iv } }), + ]); + + testMessageEquality(decodedFromEncoded, decodedFromEncrypted); + } + }); + }); + + async function testDecodesMessagesData(functionUnderTest) { + const testData = await loadTestData(testResourcesPath + 'crypto-data-128.json'); + + const items = [testData.items[1], testData.items[3]]; + const decoded = await functionUnderTest(items.map((item) => item.encoded)); + + expect(decoded[0].data).to.be.an('ArrayBuffer'); + expect(decoded[1].data).to.be.an('array'); + } + + describe('decodeMessages', () => { + it('decodes messages’ data', async () => { + testDecodesMessagesData(decodeMessages); + }); + + it('throws an error when given channel options with a cipher', async () => { + const testData = await loadTestData(testResourcesPath + 'crypto-data-128.json'); + const key = BufferUtils.base64Decode(testData.key); + const iv = BufferUtils.base64Decode(testData.iv); + + let thrownError = null; + try { + await decodeMessages( + testData.items.map((item) => item.encrypted), + { cipher: { key, iv } } + ); + } catch (error) { + thrownError = error; + } + + expect(thrownError).not.to.be.null; + expect(thrownError.message).to.equal('Crypto module not provided'); + }); + }); + + describe('decodeEncryptedMessages', () => { + it('decodes messages’ data', async () => { + testDecodesMessagesData(decodeEncryptedMessages); + }); + + it('decrypts messages', async () => { + const testData = await loadTestData(testResourcesPath + 'crypto-data-128.json'); + + const key = BufferUtils.base64Decode(testData.key); + const iv = BufferUtils.base64Decode(testData.iv); + + const [decodedFromEncoded, decodedFromEncrypted] = await Promise.all([ + decodeMessages(testData.items.map((item) => item.encoded)), + decodeEncryptedMessages( + testData.items.map((item) => item.encrypted), + { cipher: { key, iv } } + ), + ]); + + for (let i = 0; i < decodedFromEncoded.length; i++) { + testMessageEquality(decodedFromEncoded[i], decodedFromEncrypted[i]); + } + }); + }); + }); + + describe('Crypto', () => { + describe('without Crypto', () => { + for (const clientClass of [BaseRest, BaseRealtime]) { + describe(clientClass.name, () => { + it('throws an error when given channel options with a cipher', async () => { + const client = new clientClass(ablyClientOptions(), {}); + const key = await generateRandomKey(); + expect(() => client.channels.get('channel', { cipher: { key } })).to.throw('Crypto module not provided'); + }); + }); + } + }); + + describe('with Crypto', () => { + for (const clientClass of [BaseRest, BaseRealtime]) { + describe(clientClass.name, () => { + it('is able to publish encrypted messages', async () => { + const clientOptions = ablyClientOptions(); + + const key = await generateRandomKey(); + + // Publish the message on a channel configured to use encryption, and receive it on one not configured to use encryption + + const rxClient = new BaseRealtime(clientOptions, {}); + const rxChannel = rxClient.channels.get('channel'); + await rxChannel.attach(); + + const rxMessagePromise = new Promise((resolve, _) => rxChannel.subscribe((message) => resolve(message))); + + const encryptionChannelOptions = { cipher: { key } }; + + const txMessage = { name: 'message', data: 'data' }; + const txClient = new clientClass(clientOptions, { Crypto }); + const txChannel = txClient.channels.get('channel', encryptionChannelOptions); + await txChannel.publish(txMessage); + + const rxMessage = await rxMessagePromise; + + // Verify that the message was published with encryption + expect(rxMessage.encoding).to.equal('utf-8/cipher+aes-256-cbc'); + + // Verify that the message was correctly encrypted + const rxMessageDecrypted = await decodeEncryptedMessage(rxMessage, encryptionChannelOptions); + testMessageEquality(rxMessageDecrypted, txMessage); + }); + }); + } + }); + }); }); diff --git a/test/common/modules/shared_helper.js b/test/common/modules/shared_helper.js index c3265b1b26..4d8581ab22 100644 --- a/test/common/modules/shared_helper.js +++ b/test/common/modules/shared_helper.js @@ -8,9 +8,12 @@ define([ 'test/common/modules/client_module', 'test/common/modules/testapp_manager', 'async', -], function (testAppModule, clientModule, testAppManager, async) { + 'chai', +], function (testAppModule, clientModule, testAppManager, async, chai) { var utils = clientModule.Ably.Realtime.Utils; var platform = clientModule.Ably.Realtime.Platform; + var BufferUtils = platform.BufferUtils; + var expect = chai.expect; clientModule.Ably.Realtime.ConnectionManager.initTransports(); var availableTransports = utils.keysArray(clientModule.Ably.Realtime.ConnectionManager.supportedTransports), bestTransport = availableTransports[0], @@ -222,6 +225,30 @@ define([ return Math.random().toString().slice(2); } + function testMessageEquality(one, two) { + // treat `null` same as `undefined` (using ==, rather than ===) + expect(one.encoding == two.encoding, "Encoding mismatch ('" + one.encoding + "' != '" + two.encoding + "').").to.be + .ok; + + if (typeof one.data === 'string' && typeof two.data === 'string') { + expect(one.data === two.data, 'String data contents mismatch.').to.be.ok; + return; + } + + if (BufferUtils.isBuffer(one.data) && BufferUtils.isBuffer(two.data)) { + expect(BufferUtils.areBuffersEqual(one.data, two.data), 'Buffer data contents mismatch.').to.equal(true); + return; + } + + var json1 = JSON.stringify(one.data); + var json2 = JSON.stringify(two.data); + if (null === json1 || undefined === json1 || null === json2 || undefined === json2) { + expect(false, 'JSON stringify failed.').to.be.ok; + return; + } + expect(json1 === json2, 'JSON data contents mismatch.').to.be.ok; + } + var exports = { setupApp: testAppModule.setup, tearDownApp: testAppModule.tearDown, @@ -255,6 +282,7 @@ define([ arrFilter: arrFilter, whenPromiseSettles: whenPromiseSettles, randomString: randomString, + testMessageEquality: testMessageEquality, }; if (typeof window !== 'undefined') { diff --git a/test/realtime/crypto.test.js b/test/realtime/crypto.test.js index 8952bcf71e..407dbc6e25 100644 --- a/test/realtime/crypto.test.js +++ b/test/realtime/crypto.test.js @@ -25,27 +25,7 @@ define(['ably', 'shared_helper', 'async', 'chai'], function (Ably, helper, async function testMessageEquality(done, one, two) { try { - // treat `null` same as `undefined` (using ==, rather than ===) - expect(one.encoding == two.encoding, "Encoding mismatch ('" + one.encoding + "' != '" + two.encoding + "').").to - .be.ok; - - if (typeof one.data === 'string' && typeof two.data === 'string') { - expect(one.data === two.data, 'String data contents mismatch.').to.be.ok; - return; - } - - if (BufferUtils.isBuffer(one.data) && BufferUtils.isBuffer(two.data)) { - expect(BufferUtils.areBuffersEqual(one.data, two.data), 'Buffer data contents mismatch.').to.equal(true); - return; - } - - var json1 = JSON.stringify(one.data); - var json2 = JSON.stringify(two.data); - if (null === json1 || undefined === json1 || null === json2 || undefined === json2) { - expect(false, 'JSON stringify failed.').to.be.ok; - return; - } - expect(json1 === json2, 'JSON data contents mismatch.').to.be.ok; + helper.testMessageEquality(one, two); } catch (err) { done(err); }