From e2ba43c0415382e89591542ad2ab08c5f1378191 Mon Sep 17 00:00:00 2001
From: Matthew Soulanille
Date: Tue, 28 Nov 2023 12:37:51 -0800
Subject: [PATCH] Stream weights to the GPU when loading a model (#7994)

When downloading model weight data, slice it into weight tensors and push
them to the GPU eagerly. This avoids storing an extra copy of the weights on
the CPU, allowing larger models (1.3B parameters, and possibly ~6.7B or more)
to be loaded without causing a V8 OOM crash.

While streaming the weights, check CPU_HANDOFF_SIZE_THRESHOLD or
WEBGPU_CPU_HANDOFF_SIZE_THRESHOLD to determine whether each weight should be
sent to the GPU or remain on the CPU.

This feature is guarded by the streamWeights option in LoadOptions. Since
most of TFJS's graph model saving relies on the CPU copy of the weights,
model saving is disabled when the model has been streamed (i.e. it throws an
error, because the weights ArrayBuffer is missing).
---
 tfjs-converter/src/executor/graph_model.ts |  30 +-
 .../src/executor/graph_model_test.ts       |  38 ++-
 tfjs-core/src/io/http.ts                   |  68 ++--
 tfjs-core/src/io/io.ts                     |   3 +-
 tfjs-core/src/io/io_utils.ts               | 317 ++++++++++++------
 tfjs-core/src/io/io_utils_test.ts          | 253 ++++++++------
 tfjs-core/src/io/progress.ts               |   8 +-
 tfjs-core/src/io/router_registry_test.ts   |   2 +-
 tfjs-core/src/io/types.ts                  |  15 +-
 tfjs-core/src/io/weights_loader.ts         |  34 ++
 tfjs/yarn.lock                             |  48 +--
 11 files changed, 532 insertions(+), 284 deletions(-)

diff --git a/tfjs-converter/src/executor/graph_model.ts b/tfjs-converter/src/executor/graph_model.ts
index ef9c8ad57a1..95334467706 100644
--- a/tfjs-converter/src/executor/graph_model.ts
+++ b/tfjs-converter/src/executor/graph_model.ts
@@ -23,6 +23,8 @@ import {OperationMapper} from '../operations/operation_mapper';
 
 import {GraphExecutor} from './graph_executor';
 import {ResourceManager} from './resource_manager';
+// tslint:disable-next-line: no-imports-from-dist
+import {decodeWeightsStream} from '@tensorflow/tfjs-core/dist/io/io_utils';
 
 export const TFHUB_SEARCH_PARAM = '?tfjs-format=file';
 export const DEFAULT_MODEL_NAME = 'model.json';
@@ -154,7 +156,12 @@
     const loadResult = this.handler.load() as ReturnType<IOHandlerType['load']>;
     if (util.isPromise(loadResult)) {
-      return loadResult.then(artifacts => this.loadSync(artifacts)) as Result;
+      return loadResult.then(artifacts => {
+        if (artifacts.getWeightStream == null) {
+          return this.loadSync(artifacts);
+        }
+        return this.loadStreaming(artifacts);
+      }) as Result;
     }
 
     return this.loadSync(loadResult) as Result;
@@ -167,6 +174,25 @@
    * @doc {heading: 'Models', subheading: 'Classes', ignoreCI: true}
    */
   loadSync(artifacts: io.ModelArtifacts) {
+    const weightMap = this.io.decodeWeights(
+        artifacts.weightData, artifacts.weightSpecs);
+
+    return this.loadWithWeightMap(artifacts, weightMap);
+  }
+
+  private async loadStreaming(artifacts: io.ModelArtifacts): Promise<boolean> {
+    if (artifacts.getWeightStream == null) {
+      throw new Error('Model artifacts missing streamWeights function');
+    }
+
+    const weightMap = await decodeWeightsStream(
+        artifacts.getWeightStream(), artifacts.weightSpecs);
+
+    return this.loadWithWeightMap(artifacts, weightMap);
+  }
+
+  private loadWithWeightMap(artifacts: io.ModelArtifacts,
+                            weightMap: NamedTensorMap) {
     this.artifacts = artifacts;
     const graph = this.artifacts.modelTopology as tensorflow.IGraphDef;
 
@@ -184,8 +210,6 @@
     this.signature = signature;
 
     this.version = `${graph.versions.producer}.${graph.versions.minConsumer}`;
-    const weightMap =
this.io.decodeWeights( - this.artifacts.weightData, this.artifacts.weightSpecs); this.executor = new GraphExecutor( OperationMapper.Instance.transformGraph(graph, this.signature)); this.executor.weightMap = this.convertTensorMapToTensorsMap(weightMap); diff --git a/tfjs-converter/src/executor/graph_model_test.ts b/tfjs-converter/src/executor/graph_model_test.ts index 2b6d4ca104e..4ccd826b1ce 100644 --- a/tfjs-converter/src/executor/graph_model_test.ts +++ b/tfjs-converter/src/executor/graph_model_test.ts @@ -25,6 +25,8 @@ import {GraphNode} from '../operations/types'; import {GraphModel, loadGraphModel, loadGraphModelSync} from './graph_model'; import {HASH_TABLE_MODEL_V2} from './test_data/hash_table_v2_model_loader'; import {STRUCTURED_OUTPUTS_MODEL} from './test_data/structured_outputs_model_loader'; +// tslint:disable-next-line: no-imports-from-dist +import {expectArrayBuffersEqual} from '@tensorflow/tfjs-core/dist/test_util'; const HOST = 'http://example.org'; const MODEL_URL = `${HOST}/model.json`; @@ -125,6 +127,24 @@ const SIMPLE_HTTP_MODEL_LOADER = { } }; +const SIMPLE_STREAMING_MODEL_LOADER = { + load: async () => { + return { + modelTopology: SIMPLE_MODEL, + weightSpecs: weightsManifest, + getWeightStream: () => { + const data = bias.dataSync(); + const blob = new Blob([data]); + return blob.stream(); + }, + format: 'tfjs-graph-model', + generatedBy: '1.15', + convertedBy: '1.3.1', + userDefinedMetadata: {signature: SIGNATURE} + }; + } +}; + const NO_INPUT_SIGNATURE_MODEL_LOADER = { load: async () => { return { @@ -438,7 +458,7 @@ describe('loadGraphModel', () => { }); it('Pass a fetchFunc', async () => { - const fetchFunc = () => {}; + const fetchFunc = (() => {}) as unknown as typeof fetch; spyIo.getLoadHandlers.and.returnValue([CUSTOM_HTTP_MODEL_LOADER]); await loadGraphModel(MODEL_URL, {fetchFunc}, spyIo); expect(spyIo.getLoadHandlers).toHaveBeenCalledWith(MODEL_URL, {fetchFunc}); @@ -594,7 +614,13 @@ describe('Model', () => { describe('simple model', () => { beforeEach(() => { - spyIo.getLoadHandlers.and.returnValue([SIMPLE_HTTP_MODEL_LOADER]); + spyIo.getLoadHandlers.and.callFake((_url: string|string[], + loadOptions?: io.LoadOptions) => { + if (loadOptions.streamWeights) { + return [SIMPLE_STREAMING_MODEL_LOADER]; + } + return [SIMPLE_HTTP_MODEL_LOADER]; + }); spyIo.browserHTTPRequest.and.returnValue(SIMPLE_HTTP_MODEL_LOADER); }); it('load', async () => { @@ -776,6 +802,14 @@ describe('Model', () => { expect(model).toBeDefined(); }); + it('should stream graph model weights', async () => { + const model = await loadGraphModel(MODEL_URL, {streamWeights: true}, + spyIo); + expect(model).toBeDefined(); + expectArrayBuffersEqual(model.weights['Const'][0].dataSync(), + bias.dataSync()); + }); + describe('InferenceModel interface', () => { it('should expose inputs', async () => { await model.load(); diff --git a/tfjs-core/src/io/http.ts b/tfjs-core/src/io/http.ts index c30ce501dd3..a8ba2da62ca 100644 --- a/tfjs-core/src/io/http.ts +++ b/tfjs-core/src/io/http.ts @@ -27,8 +27,8 @@ import {assert} from '../util'; import {getModelArtifactsForJSON, getModelArtifactsInfoForJSON, getModelJSONForModelArtifacts, getWeightSpecs} from './io_utils'; import {CompositeArrayBuffer} from './composite_array_buffer'; import {IORouter, IORouterRegistry} from './router_registry'; -import {IOHandler, LoadOptions, ModelArtifacts, ModelJSON, OnProgressCallback, SaveResult, WeightData, WeightsManifestConfig, WeightsManifestEntry} from './types'; -import {loadWeightsAsArrayBuffer} from 
'./weights_loader';
+import {IOHandler, LoadOptions, ModelArtifacts, ModelJSON, SaveResult, WeightData, WeightsManifestConfig, WeightsManifestEntry} from './types';
+import {loadWeightsAsArrayBuffer, streamWeights} from './weights_loader';
 
 const OCTET_STREAM_MIME_TYPE = 'application/octet-stream';
 const JSON_TYPE = 'application/json';
@@ -36,7 +36,7 @@ export class HTTPRequest implements IOHandler {
   protected readonly path: string;
   protected readonly requestInit: RequestInit;
 
-  private readonly fetch: Function;
+  private readonly fetch: typeof fetch;
   private readonly weightUrlConverter: (weightName: string) => Promise<string>;
 
   readonly DEFAULT_METHOD = 'POST';
@@ -44,14 +44,13 @@
   static readonly URL_SCHEME_REGEX = /^https?:\/\//;
 
   private readonly weightPathPrefix: string;
-  private readonly onProgress: OnProgressCallback;
+  private readonly loadOptions: LoadOptions;
 
   constructor(path: string, loadOptions?: LoadOptions) {
     if (loadOptions == null) {
       loadOptions = {};
     }
     this.weightPathPrefix = loadOptions.weightPathPrefix;
-    this.onProgress = loadOptions.onProgress;
     this.weightUrlConverter = loadOptions.weightUrlConverter;
 
     if (loadOptions.fetchFunc != null) {
@@ -84,6 +83,7 @@
           'requestInit is expected to have no pre-existing body, but has one.');
     }
     this.requestInit = loadOptions.requestInit || {};
+    this.loadOptions = loadOptions;
   }
 
   async save(modelArtifacts: ModelArtifacts): Promise<SaveResult> {
@@ -135,15 +135,7 @@
     }
   }
 
-  /**
-   * Load model artifacts via HTTP request(s).
-   *
-   * See the documentation to `tf.io.http` for details on the saved
-   * artifacts.
-   *
-   * @returns The loaded model artifacts (if loading succeeds).
-   */
-  async load(): Promise<ModelArtifacts> {
+  private async loadModelJSON(): Promise<ModelJSON> {
     const modelConfigRequest = await this.fetch(this.path, this.requestInit);
 
     if (!modelConfigRequest.ok) {
@@ -182,18 +174,45 @@
           `topology or manifest for weights.`);
     }
 
+    return modelJSON;
+  }
+
+  /**
+   * Load model artifacts via HTTP request(s).
+   *
+   * See the documentation to `tf.io.http` for details on the saved
+   * artifacts.
+   *
+   * @returns The loaded model artifacts (if loading succeeds).
+   */
+  async load(): Promise<ModelArtifacts> {
+    if (this.loadOptions.streamWeights) {
+      return this.loadStream();
+    }
+    const modelJSON = await this.loadModelJSON();
     return getModelArtifactsForJSON(
         modelJSON, (weightsManifest) => this.loadWeights(weightsManifest));
   }
 
-  private async loadWeights(weightsManifest: WeightsManifestConfig):
-      Promise<[WeightsManifestEntry[], WeightData]> {
+  private async loadStream(): Promise<ModelArtifacts> {
+    const modelJSON = await this.loadModelJSON();
+    const fetchURLs = await this.getWeightUrls(modelJSON.weightsManifest);
+    const weightSpecs = getWeightSpecs(modelJSON.weightsManifest);
+    const stream = () => streamWeights(fetchURLs, this.loadOptions);
+
+    return {
+      ...modelJSON,
+      weightSpecs,
+      getWeightStream: stream,
+    };
+  }
+
+  private async getWeightUrls(weightsManifest: WeightsManifestConfig):
+      Promise<string[]> {
     const weightPath = Array.isArray(this.path) ? this.path[1] : this.path;
     const [prefix, suffix] = parseUrl(weightPath);
     const pathPrefix = this.weightPathPrefix || prefix;
 
-    const weightSpecs = getWeightSpecs(weightsManifest);
-
     const fetchURLs: string[] = [];
     const urlPromises: Array<Promise<string>> = [];
     for (const weightsGroup of weightsManifest) {
@@ -209,12 +228,15 @@
     if (this.weightUrlConverter) {
       fetchURLs.push(...await Promise.all(urlPromises));
     }
+    return fetchURLs;
+  }
+
+  private async loadWeights(weightsManifest: WeightsManifestConfig):
+      Promise<[WeightsManifestEntry[], WeightData]> {
+    const fetchURLs = await this.getWeightUrls(weightsManifest);
+    const weightSpecs = getWeightSpecs(weightsManifest);
 
-    const buffers = await loadWeightsAsArrayBuffer(fetchURLs, {
-      requestInit: this.requestInit,
-      fetchFunc: this.fetch,
-      onProgress: this.onProgress
-    });
+    const buffers = await loadWeightsAsArrayBuffer(fetchURLs, this.loadOptions);
     return [weightSpecs, buffers];
   }
 }
diff --git a/tfjs-core/src/io/io.ts b/tfjs-core/src/io/io.ts
index 49e9a1e2e06..3c1c8724e11 100644
--- a/tfjs-core/src/io/io.ts
+++ b/tfjs-core/src/io/io.ts
@@ -22,7 +22,7 @@ import './local_storage';
 
 import {browserFiles} from './browser_files';
 import {browserHTTPRequest, http, isHTTPScheme} from './http';
-import {concatenateArrayBuffers, decodeWeights, encodeWeights, getModelArtifactsForJSON, getModelArtifactsForJSONSync, getModelArtifactsInfoForJSON, getWeightSpecs} from './io_utils';
+import {concatenateArrayBuffers, decodeWeights, decodeWeightsStream, encodeWeights, getModelArtifactsForJSON, getModelArtifactsForJSONSync, getModelArtifactsInfoForJSON, getWeightSpecs} from './io_utils';
 import {fromMemory, fromMemorySync, withSaveHandler, withSaveHandlerSync} from './passthrough';
 import {getLoadHandlers, getSaveHandlers, registerLoadRouter, registerSaveRouter} from './router_registry';
 import {IOHandler, IOHandlerSync, LoadHandler, LoadOptions, ModelArtifacts, ModelArtifactsInfo, ModelJSON, ModelStoreManager, OnProgressCallback, RequestDetails, SaveConfig, SaveHandler, SaveResult, TrainingConfig, WeightGroup, WeightsManifestConfig, WeightsManifestEntry, WeightData} from './types';
@@ -36,6 +36,7 @@
   CompositeArrayBuffer,
   concatenateArrayBuffers,
   decodeWeights,
+  decodeWeightsStream,
   encodeWeights,
   fromMemory,
   fromMemorySync,
diff --git a/tfjs-core/src/io/io_utils.ts b/tfjs-core/src/io/io_utils.ts
index fa9005a9ba8..25dbe5fd46d 100644
--- a/tfjs-core/src/io/io_utils.ts
+++ b/tfjs-core/src/io/io_utils.ts
@@ -23,6 +23,11 @@ import {sizeFromShape} from '../util';
 
 import {DTYPE_VALUE_SIZE_MAP, ModelArtifacts, ModelArtifactsInfo, ModelJSON, WeightData, WeightGroup, WeightsManifestConfig, WeightsManifestEntry} from './types';
 import {CompositeArrayBuffer} from './composite_array_buffer';
+import {Tensor} from '../tensor';
+import {backend} from '../globals';
+import {DataId} from '../tensor_info';
+import {env} from '../environment';
+import {getBackend} from '../globals';
 
 /** Number of bytes reserved for the length of the string. (32bit integer). */
 const NUM_BYTES_STRING_LENGTH = 4;
@@ -117,120 +122,234 @@ export function decodeWeights(
 
   // TODO(adarob, cais): Support quantization.
   const compositeBuffer = new CompositeArrayBuffer(weightData);
   const out: NamedTensorMap = {};
-  let float16Decode: (buffer: Uint16Array) => Float32Array | undefined;
   let offset = 0;
   for (const spec of specs) {
-    const name = spec.name;
-    const dtype = spec.dtype;
-    const shape = spec.shape;
-    const size = sizeFromShape(shape);
-    let values: TypedArray|string[]|Uint8Array[];
-
-    if ('quantization' in spec) {
-      const quantization = spec.quantization;
-      if (quantization.dtype === 'uint8' || quantization.dtype === 'uint16') {
-        if (!('min' in quantization && 'scale' in quantization)) {
-          throw new Error(
-              `Weight ${spec.name} with quantization ${quantization.dtype} ` +
-              `doesn't have corresponding metadata min and scale.`);
-        }
-      } else if (quantization.dtype === 'float16') {
-        if (dtype !== 'float32') {
-          throw new Error(
-              `Weight ${spec.name} is quantized with ${quantization.dtype} ` +
-              `which only supports weights of type float32 not ${dtype}.`);
-        }
-      } else {
+    const byteLength = getWeightBytelength(spec, (start, end) => {
+      return compositeBuffer.slice(offset + start, offset + end);
+    });
+    out[spec.name] = decodeWeight(spec, compositeBuffer
+        .slice(offset, offset + byteLength));
+    offset += byteLength;
+  }
+  return out;
+}
+
+function getWeightBytelength(spec: WeightsManifestEntry,
+    slice: (start: number, end: number) => ArrayBuffer): number {
+
+  const size = sizeFromShape(spec.shape);
+  let bytesPerValue: number;
+  if ('quantization' in spec) {
+    const quantization = spec.quantization;
+    bytesPerValue = DTYPE_VALUE_SIZE_MAP[quantization.dtype];
+  } else if (spec.dtype === 'string') {
+    // Can not statically determine string length.
+    let byteLength = 0;
+    for (let i = 0; i < size; i++) {
+      byteLength += NUM_BYTES_STRING_LENGTH + new Uint32Array(
+          slice(byteLength, byteLength + NUM_BYTES_STRING_LENGTH))[0];
+    }
+    return byteLength;
+  } else {
+    bytesPerValue = DTYPE_VALUE_SIZE_MAP[spec.dtype];
+  }
+
+  return size * bytesPerValue;
+}
+
+async function getWeightBytelengthAsync(
+    spec: WeightsManifestEntry,
+    slice: (start: number, end: number) => Promise<ArrayBuffer>
+): Promise<number> {
+
+  const size = sizeFromShape(spec.shape);
+  let bytesPerValue: number;
+  if ('quantization' in spec) {
+    const quantization = spec.quantization;
+    bytesPerValue = DTYPE_VALUE_SIZE_MAP[quantization.dtype];
+  } else if (spec.dtype === 'string') {
+    // Can not statically determine string length.
+    let byteLength = 0;
+    for (let i = 0; i < size; i++) {
+      byteLength += NUM_BYTES_STRING_LENGTH + new Uint32Array(
+          await slice(byteLength, byteLength + NUM_BYTES_STRING_LENGTH))[0];
+    }
+    return byteLength;
+  } else {
+    bytesPerValue = DTYPE_VALUE_SIZE_MAP[spec.dtype];
+  }
+
+  return size * bytesPerValue;
+}
+
+function decodeWeight(
+    spec: WeightsManifestEntry,
+    byteBuffer: ArrayBuffer): Tensor {
+
+  const name = spec.name;
+  const dtype = spec.dtype;
+  const shape = spec.shape;
+  const size = sizeFromShape(shape);
+  let values: TypedArray | string[] | Uint8Array[];
+  let offset = 0;
+
+  if ('quantization' in spec) {
+    const quantization = spec.quantization;
+    if (quantization.dtype === 'uint8' || quantization.dtype === 'uint16') {
+      if (!('min' in quantization && 'scale' in quantization)) {
        throw new Error(
-            `Weight ${spec.name} has unknown ` +
-            `quantization dtype ${quantization.dtype}. 
` + - `Supported quantization dtypes are: ` + - `'uint8', 'uint16', and 'float16'.`); + `Weight ${spec.name} with quantization ${quantization.dtype} ` + + `doesn't have corresponding metadata min and scale.`); } - const quantizationSizeFactor = DTYPE_VALUE_SIZE_MAP[quantization.dtype]; - const byteBuffer = - compositeBuffer.slice(offset, offset + size * quantizationSizeFactor); - const quantizedArray = (quantization.dtype === 'uint8') ? - new Uint8Array(byteBuffer) : - new Uint16Array(byteBuffer); - if (dtype === 'float32') { - if (quantization.dtype === 'uint8' || quantization.dtype === 'uint16') { - values = new Float32Array(quantizedArray.length); - for (let i = 0; i < quantizedArray.length; i++) { - const v = quantizedArray[i]; - values[i] = v * quantization.scale + quantization.min; - } - } else if (quantization.dtype === 'float16') { - if (float16Decode === undefined) { - float16Decode = getFloat16Decoder(); - } - values = float16Decode(quantizedArray as Uint16Array); - } else { - throw new Error( - `Unsupported quantization type ${quantization.dtype} ` + - `for weight type float32.`); - } - } else if (dtype === 'int32') { - if (quantization.dtype !== 'uint8' && quantization.dtype !== 'uint16') { - throw new Error( - `Unsupported quantization type ${quantization.dtype} ` + - `for weight type int32.`); - } - values = new Int32Array(quantizedArray.length); + } else if (quantization.dtype === 'float16') { + if (dtype !== 'float32') { + throw new Error( + `Weight ${spec.name} is quantized with ${quantization.dtype} ` + + `which only supports weights of type float32 not ${dtype}.`); + } + } else { + throw new Error( + `Weight ${spec.name} has unknown ` + + `quantization dtype ${quantization.dtype}. ` + + `Supported quantization dtypes are: ` + + `'uint8', 'uint16', and 'float16'.`); + } + const quantizationSizeFactor = DTYPE_VALUE_SIZE_MAP[quantization.dtype]; + const quantizedArray = (quantization.dtype === 'uint8') ? + new Uint8Array(byteBuffer) : + new Uint16Array(byteBuffer); + if (dtype === 'float32') { + if (quantization.dtype === 'uint8' || quantization.dtype === 'uint16') { + values = new Float32Array(quantizedArray.length); for (let i = 0; i < quantizedArray.length; i++) { const v = quantizedArray[i]; - values[i] = Math.round(v * quantization.scale + quantization.min); + values[i] = v * quantization.scale + quantization.min; } + } else if (quantization.dtype === 'float16') { + // TODO: This is inefficient. Make getFloat16Decoder efficient. 
+ const float16Decode = getFloat16Decoder(); + values = float16Decode(quantizedArray as Uint16Array); } else { - throw new Error(`Unsupported dtype in weight '${name}': ${dtype}`); + throw new Error( + `Unsupported quantization type ${quantization.dtype} ` + + `for weight type float32.`); } - offset += size * quantizationSizeFactor; - } else if (dtype === 'string') { - const size = sizeFromShape(spec.shape); - values = []; - for (let i = 0; i < size; i++) { - const byteLength = new Uint32Array( - compositeBuffer.slice(offset, offset + NUM_BYTES_STRING_LENGTH))[0]; - offset += NUM_BYTES_STRING_LENGTH; - const bytes = new Uint8Array( - compositeBuffer.slice(offset, offset + byteLength)); - (values as Uint8Array[]).push(bytes); - offset += byteLength; + } else if (dtype === 'int32') { + if (quantization.dtype !== 'uint8' && quantization.dtype !== 'uint16') { + throw new Error( + `Unsupported quantization type ${quantization.dtype} ` + + `for weight type int32.`); + } + values = new Int32Array(quantizedArray.length); + for (let i = 0; i < quantizedArray.length; i++) { + const v = quantizedArray[i]; + values[i] = Math.round(v * quantization.scale + quantization.min); } } else { - const dtypeFactor = DTYPE_VALUE_SIZE_MAP[dtype]; - const byteBuffer = compositeBuffer.slice(offset, - offset + size * dtypeFactor); - - if (dtype === 'float32') { - values = new Float32Array(byteBuffer); - } else if (dtype === 'int32') { - values = new Int32Array(byteBuffer); - } else if (dtype === 'bool') { - values = new Uint8Array(byteBuffer); - } else if (dtype === 'complex64') { - values = new Float32Array(byteBuffer); - const real = new Float32Array(values.length / 2); - const image = new Float32Array(values.length / 2); - for (let i = 0; i < real.length; i++) { - real[i] = values[i * 2]; - image[i] = values[i * 2 + 1]; - } - const realTensor = tensor(real, shape, 'float32'); - const imageTensor = tensor(image, shape, 'float32'); - out[name] = complex(realTensor, imageTensor); - realTensor.dispose(); - imageTensor.dispose(); - } else { - throw new Error(`Unsupported dtype in weight '${name}': ${dtype}`); + throw new Error(`Unsupported dtype in weight '${name}': ${dtype}`); + } + offset += size * quantizationSizeFactor; + } else if (dtype === 'string') { + const size = sizeFromShape(spec.shape); + values = []; + for (let i = 0; i < size; i++) { + const byteLength = new Uint32Array( + byteBuffer.slice(offset, offset + NUM_BYTES_STRING_LENGTH))[0]; + offset += NUM_BYTES_STRING_LENGTH; + const bytes = new Uint8Array( + byteBuffer.slice(offset, offset + byteLength)); + (values as Uint8Array[]).push(bytes); + offset += byteLength; + } + } else { + const dtypeFactor = DTYPE_VALUE_SIZE_MAP[dtype]; + if (dtype === 'float32') { + values = new Float32Array(byteBuffer); + } else if (dtype === 'int32') { + values = new Int32Array(byteBuffer); + } else if (dtype === 'bool') { + values = new Uint8Array(byteBuffer); + } else if (dtype === 'complex64') { + values = new Float32Array(byteBuffer); + const real = new Float32Array(values.length / 2); + const image = new Float32Array(values.length / 2); + for (let i = 0; i < real.length; i++) { + real[i] = values[i * 2]; + image[i] = values[i * 2 + 1]; } - offset += size * dtypeFactor; + const realTensor = tensor(real, shape, 'float32'); + const imageTensor = tensor(image, shape, 'float32'); + const complexTensor = complex(realTensor, imageTensor); + realTensor.dispose(); + imageTensor.dispose(); + return complexTensor; + } else { + throw new Error(`Unsupported dtype in weight 
'${name}': ${dtype}`);
+    }
+    offset += size * dtypeFactor;
+  }
+  return tensor(values, shape, dtype);
+}
+
+async function readToLength(reader: ReadableStreamDefaultReader<ArrayBuffer>,
+                            initialData: ArrayBuffer,
+                            length: number): Promise<ArrayBuffer> {
+  let data = new Uint8Array(initialData);
+
+  while (data.byteLength < length) {
+    const {done, value} = await reader.read();
+    if (done && value == null) {
+      const missing = length - data.byteLength;
+      throw new Error(`Reader is done but ${missing} bytes are still expected`);
+    }
+
+    // TODO: Don't create a new array every loop.
+    const newData = new Uint8Array(data.length + value.byteLength);
+    newData.set(data, 0);
+    newData.set(new Uint8Array(value), data.length);
+    data = newData;
+  }
+
+  return data.buffer;
+}
+
+export async function decodeWeightsStream(
+    weightStream: ReadableStream<ArrayBuffer>,
+    specs: WeightsManifestEntry[]): Promise<NamedTensorMap> {
+
+  const tensors: NamedTensorMap = {};
+  const reader = weightStream.getReader();
+  let data = new ArrayBuffer(0);
+
+  for (const spec of specs) {
+    const byteLength = await getWeightBytelengthAsync(spec,
+        async (start, end) => {
+          data = await readToLength(reader, data, end);
+          return data.slice(start, end);
+        });
+    data = await readToLength(reader, data, byteLength);
+
+    // Slice the tensor out
+    const tensorData = data.slice(0, byteLength);
+    data = data.slice(byteLength);
+
+    const weightTensor = decodeWeight(spec, tensorData);
+    tensors[spec.name] = weightTensor;
+
+    // TODO(mattsoulanille): Better way to call uploadToGPU.
+    // TODO(mattsoulanille): Make this work for webgl too.
+    if (getBackend() === 'webgpu') {
+      const b = backend();
+
+      if ('uploadToGPU' in b &&
+          sizeFromShape(weightTensor.shape) >= (env()
+              .get('WEBGPU_CPU_HANDOFF_SIZE_THRESHOLD') as number)) {
+        (b.uploadToGPU as (dataId: DataId) => void)(weightTensor.dataId);
+      }
+    }
+  }
+
+  return tensors;
+}
 
 /**
diff --git a/tfjs-core/src/io/io_utils_test.ts b/tfjs-core/src/io/io_utils_test.ts
index 01e497c075b..6d710288537 100644
--- a/tfjs-core/src/io/io_utils_test.ts
+++ b/tfjs-core/src/io/io_utils_test.ts
@@ -469,118 +469,153 @@ describeWithFlags('encodeWeights', ALL_ENVS, () => {
 });
 
 describeWithFlags('decodeWeights', {}, () => {
-  it('Mixed dtype tensors', async () => {
-    const tensors: NamedTensorMap = {
-      x1: tensor2d([[10, 20], [30, 40]], [2, 2], 'int32'),
-      x2: scalar(13.37, 'float32'),
-      x3: tensor1d([true, false, false], 'bool'),
-      x4: tensor2d([['здраво', 'a'], ['b', 'c']], [2, 2], 'string'),
-      x5: tensor1d([''], 'string'),  // Empty string.
-      x6: scalar('hello'),           // Single string.
- y1: tensor2d([-10, -20, -30], [3, 1], 'float32'), - y2: tf.complex([1, 1], [2, 2]) - }; - const dataAndSpecs = await tf.io.encodeWeights(tensors); - const data = dataAndSpecs.data; - const specs = dataAndSpecs.specs; - const decoded = tf.io.decodeWeights(data, specs); - expect(Object.keys(decoded).length).toEqual(8); - expectArraysEqual(await decoded['x1'].data(), await tensors['x1'].data()); - expectArraysEqual(await decoded['x2'].data(), await tensors['x2'].data()); - expectArraysEqual(await decoded['x3'].data(), await tensors['x3'].data()); - expectArraysEqual(await decoded['x4'].data(), await tensors['x4'].data()); - expectArraysEqual(await decoded['x5'].data(), await tensors['x5'].data()); - expectArraysEqual(await decoded['x6'].data(), await tensors['x6'].data()); - expectArraysEqual(await decoded['y1'].data(), await tensors['y1'].data()); - expectArraysEqual(await decoded['y2'].data(), await tensors['y2'].data()); - }); - - it('Unsupported dtype raises Error', () => { - const buffer = new ArrayBuffer(4); - // tslint:disable-next-line:no-any - const specs: any = [ - { - name: 'x', - dtype: 'int16', - shape: [], - }, - {name: 'y', dtype: 'int16', shape: []} - ]; - expect(() => tf.io.decodeWeights(buffer, specs)) - .toThrowError(/Unsupported dtype in weight \'x\': int16/); - }); - - it('support quantization uint8 weights', async () => { - const manifestSpecs: WeightsManifestEntry[] = [ - { - 'name': 'weight0', - 'dtype': 'float32', - 'shape': [3], - 'quantization': {'min': -1, 'scale': 0.1, 'dtype': 'uint8'} - }, - { - 'name': 'weight1', - 'dtype': 'int32', - 'shape': [3], - 'quantization': {'min': -1, 'scale': 0.1, 'dtype': 'uint8'} + function toStream(buffer: ArrayBuffer): ReadableStream { + let position = 0; + const chunkSize = 14; // something relatively small for testing + return new ReadableStream({ + pull: (controller) => { + if (position < buffer.byteLength) { + const chunk = buffer.slice(position, position + chunkSize); + position += chunkSize; + controller.enqueue(chunk); + } else { + controller.close(); + } } - ]; - const data = new Uint8Array([0, 48, 255, 0, 48, 255]); - const decoded = tf.io.decodeWeights(data.buffer, manifestSpecs); - const weight0 = decoded['weight0']; - expectArraysClose(await weight0.data(), [-1, 3.8, 24.5]); - expect(weight0.shape).toEqual([3]); - expect(weight0.dtype).toEqual('float32'); + }); + } + + async function decodeAsBuffer(data: ArrayBuffer, + specs: tf.io.WeightsManifestEntry[]) { + const result = tf.io.decodeWeights(data, specs); + // Make sure it doesn't return a promise. + expect(result).not.toBeInstanceOf(Promise); + // Wrap it in a promise to work with the tests. + return Promise.resolve(result); + } + + async function decodeAsStream(data: ArrayBuffer, + specs: tf.io.WeightsManifestEntry[]) { + return tf.io.decodeWeightsStream(toStream(data), specs); + } + + for (const [name, decode] of [['from arraybuffer', decodeAsBuffer], + ['from stream', decodeAsStream]] as const) { + describe(name, () => { + it('Mixed dtype tensors', async () => { + const tensors: NamedTensorMap = { + x1: tensor2d([[10, 20], [30, 40]], [2, 2], 'int32'), + x2: scalar(13.37, 'float32'), + x3: tensor1d([true, false, false], 'bool'), + x4: tensor2d([['здраво', 'a'], ['b', 'c']], [2, 2], 'string'), + x5: tensor1d([''], 'string'), // Empty string. + x6: scalar('hello'), // Single string. 
+ y1: tensor2d([-10, -20, -30], [3, 1], 'float32'), + y2: tf.complex([1, 1], [2, 2]) + }; + const dataAndSpecs = await tf.io.encodeWeights(tensors); + const data = dataAndSpecs.data; + const specs = dataAndSpecs.specs; + const res = await decode(data, specs); + expect(Object.keys(res).length).toEqual(8); + expectArraysEqual(await res['x1'].data(), await tensors['x1'].data()); + expectArraysEqual(await res['x2'].data(), await tensors['x2'].data()); + expectArraysEqual(await res['x3'].data(), await tensors['x3'].data()); + expectArraysEqual(await res['x4'].data(), await tensors['x4'].data()); + expectArraysEqual(await res['x5'].data(), await tensors['x5'].data()); + expectArraysEqual(await res['x6'].data(), await tensors['x6'].data()); + expectArraysEqual(await res['y1'].data(), await tensors['y1'].data()); + expectArraysEqual(await res['y2'].data(), await tensors['y2'].data()); + }); - const weight1 = decoded['weight1']; - expectArraysEqual(await weight1.data(), [-1, 4, 25]); - expect(weight1.shape).toEqual([3]); - expect(weight1.dtype).toEqual('int32'); - }); + it('Unsupported dtype raises Error', async () => { + const buffer = new ArrayBuffer(4); + // tslint:disable-next-line:no-any + const specs: any = [ + { + name: 'x', + dtype: 'int16', + shape: [], + }, + {name: 'y', dtype: 'int16', shape: []} + ]; + await expectAsync(decode(buffer, specs)) + .toBeRejectedWithError(/Unsupported dtype in weight \'x\': int16/); + }); - it('support quantization uint16 weights', async () => { - const manifestSpecs: WeightsManifestEntry[] = [ - { - 'name': 'weight0', - 'dtype': 'float32', - 'shape': [3], - 'quantization': {'min': -1, 'scale': 0.1, 'dtype': 'uint16'} - }, - { - 'name': 'weight1', - 'dtype': 'int32', - 'shape': [3], - 'quantization': {'min': -1, 'scale': 0.1, 'dtype': 'uint16'} - } - ]; - const data = new Uint16Array([0, 48, 255, 0, 48, 255]); - const decoded = tf.io.decodeWeights(data.buffer, manifestSpecs); - const weight0 = decoded['weight0']; - expectArraysClose(await weight0.data(), [-1, 3.8, 24.5]); - expect(weight0.shape).toEqual([3]); - expect(weight0.dtype).toEqual('float32'); - - const weight1 = decoded['weight1']; - expectArraysEqual(await weight1.data(), [-1, 4, 25]); - expect(weight1.shape).toEqual([3]); - expect(weight1.dtype).toEqual('int32'); - }); - it('support quantization float16 weights', async () => { - const manifestSpecs: WeightsManifestEntry[] = [ - { - name: 'weight0', - dtype: 'float32', - shape: [3], - quantization: { dtype: 'float16' }, - }, - ]; - const data = new Uint16Array([13312, 14336, 14848]); - const decoded = tf.io.decodeWeights(data.buffer, manifestSpecs); - const weight0 = decoded['weight0']; - expectArraysClose(await weight0.data(), [0.25, 0.5, 0.75]); - expect(weight0.shape).toEqual([3]); - expect(weight0.dtype).toEqual('float32'); - }); + it('support quantization uint8 weights', async () => { + const manifestSpecs: WeightsManifestEntry[] = [ + { + 'name': 'weight0', + 'dtype': 'float32', + 'shape': [3], + 'quantization': {'min': -1, 'scale': 0.1, 'dtype': 'uint8'} + }, + { + 'name': 'weight1', + 'dtype': 'int32', + 'shape': [3], + 'quantization': {'min': -1, 'scale': 0.1, 'dtype': 'uint8'} + } + ]; + const data = new Uint8Array([0, 48, 255, 0, 48, 255]); + const decoded = await decode(data.buffer, manifestSpecs); + const weight0 = decoded['weight0']; + expectArraysClose(await weight0.data(), [-1, 3.8, 24.5]); + expect(weight0.shape).toEqual([3]); + expect(weight0.dtype).toEqual('float32'); + + const weight1 = decoded['weight1']; + 
expectArraysEqual(await weight1.data(), [-1, 4, 25]);
+        expect(weight1.shape).toEqual([3]);
+        expect(weight1.dtype).toEqual('int32');
+      });
+
+      it('support quantization uint16 weights', async () => {
+        const manifestSpecs: WeightsManifestEntry[] = [
+          {
+            'name': 'weight0',
+            'dtype': 'float32',
+            'shape': [3],
+            'quantization': {'min': -1, 'scale': 0.1, 'dtype': 'uint16'}
+          },
+          {
+            'name': 'weight1',
+            'dtype': 'int32',
+            'shape': [3],
+            'quantization': {'min': -1, 'scale': 0.1, 'dtype': 'uint16'}
+          }
+        ];
+        const data = new Uint16Array([0, 48, 255, 0, 48, 255]);
+        const decoded = await decode(data.buffer, manifestSpecs);
+        const weight0 = decoded['weight0'];
+        expectArraysClose(await weight0.data(), [-1, 3.8, 24.5]);
+        expect(weight0.shape).toEqual([3]);
+        expect(weight0.dtype).toEqual('float32');
+
+        const weight1 = decoded['weight1'];
+        expectArraysEqual(await weight1.data(), [-1, 4, 25]);
+        expect(weight1.shape).toEqual([3]);
+        expect(weight1.dtype).toEqual('int32');
+      });
+      it('support quantization float16 weights', async () => {
+        const manifestSpecs: WeightsManifestEntry[] = [
+          {
+            name: 'weight0',
+            dtype: 'float32',
+            shape: [3],
+            quantization: { dtype: 'float16' },
+          },
+        ];
+        const data = new Uint16Array([13312, 14336, 14848]);
+        const decoded = await decode(data.buffer, manifestSpecs);
+        const weight0 = decoded['weight0'];
+        expectArraysClose(await weight0.data(), [0.25, 0.5, 0.75]);
+        expect(weight0.shape).toEqual([3]);
+        expect(weight0.dtype).toEqual('float32');
+      });
+    });
+  }
 });
 
 describe('stringByteLength', () => {
diff --git a/tfjs-core/src/io/progress.ts b/tfjs-core/src/io/progress.ts
index 8d6b3d7fa8a..73e1e19d54c 100644
--- a/tfjs-core/src/io/progress.ts
+++ b/tfjs-core/src/io/progress.ts
@@ -27,8 +27,8 @@ import {OnProgressCallback} from './types';
  * @param startFraction Optional fraction start. Default to 0.
  * @param endFraction Optional fraction end. Default to 1.
  */
-export function monitorPromisesProgress(
-    promises: Array<Promise<{}|void>>, onProgress: OnProgressCallback,
+export function monitorPromisesProgress<T>(
+    promises: Array<Promise<T>>, onProgress: OnProgressCallback,
     startFraction?: number, endFraction?: number) {
   checkPromises(promises);
   startFraction = startFraction == null ? 0 : startFraction;
@@ -36,7 +36,7 @@
   checkFraction(startFraction, endFraction);
 
   let resolvedPromise = 0;
-  const registerMonitor = (promise: Promise<{}>) => {
+  const registerMonitor = (promise: Promise<T>) => {
     promise.then(value => {
       const fraction = startFraction +
           ++resolvedPromise / promises.length * (endFraction - startFraction);
@@ -47,7 +47,7 @@
     return promise;
   };
 
-  function checkPromises(promises: Array<Promise<{}|void>>): void {
+  function checkPromises(promises: Array<Promise<T>>): void {
     assert(
         promises != null && Array.isArray(promises) && promises.length > 0,
         () => 'promises must be a none empty array');
diff --git a/tfjs-core/src/io/router_registry_test.ts b/tfjs-core/src/io/router_registry_test.ts
index 834e8e3d10e..079a03a5602 100644
--- a/tfjs-core/src/io/router_registry_test.ts
+++ b/tfjs-core/src/io/router_registry_test.ts
@@ -136,7 +136,7 @@ describeWithFlags('IORouterRegistry', BROWSER_ENVS, () => {
 
     const loadOptions: LoadOptions = {
       onProgress: (fraction: number) => {},
-      fetchFunc: () => {}
+      fetchFunc: ((() => {}) as unknown as typeof fetch),
     };
     const loadHandler = tf.io.getLoadHandlers('foo:///123', loadOptions);
     expect(loadHandler.length).toEqual(1);
diff --git a/tfjs-core/src/io/types.ts b/tfjs-core/src/io/types.ts
index 2dc0893a82f..177884f2ef1 100644
--- a/tfjs-core/src/io/types.ts
+++ b/tfjs-core/src/io/types.ts
@@ -257,6 +257,13 @@ export declare interface ModelArtifacts {
    */
   weightData?: WeightData;
 
+  /**
+   * Returns a stream of the weights. Some models are too large to fit in
+   * V8's memory heap, and `getWeightStream` loads their weights without
+   * storing them all in memory at the same time.
+   */
+  getWeightStream?: () => ReadableStream<ArrayBuffer>;
+
   /**
    * Hard-coded format name for models saved from TensorFlow.js or converted
    * by TensorFlow.js Converter.
@@ -482,7 +489,7 @@ export interface LoadOptions {
   /**
    * A function used to override the `window.fetch` function.
    */
-  fetchFunc?: Function;
+  fetchFunc?: typeof fetch;
 
   /**
    * Strict loading model: whether extraneous weights or missing
@@ -532,6 +539,12 @@ export interface LoadOptions {
    * With this func you can convert the weight file name to any URL.
    */
   weightUrlConverter?: (weightFileName: string) => Promise<string>;
+
+  /**
+   * Whether to stream the model directly to the backend or cache all its
+   * weights on CPU first. Useful for large models.
+   */
+  streamWeights?: boolean;
 }
 
 /**
diff --git a/tfjs-core/src/io/weights_loader.ts b/tfjs-core/src/io/weights_loader.ts
index 8ad0ef2f85b..9a09a798c45 100644
--- a/tfjs-core/src/io/weights_loader.ts
+++ b/tfjs-core/src/io/weights_loader.ts
@@ -71,6 +71,40 @@ export async function loadWeightsAsArrayBuffer(
   return buffers;
 }
 
+export function streamWeights(fetchURLs: string[], loadOptions: LoadOptions):
+    ReadableStream<ArrayBuffer> {
+  const fetchFunc = loadOptions.fetchFunc == null ?
env().platform.fetch : + loadOptions.fetchFunc; + + let fetchIndex = 0; + let chunkReader: ReadableStreamDefaultReader | undefined; + loadOptions.onProgress?.(0); + return new ReadableStream({ + pull: async (controller) => { + while (fetchIndex < fetchURLs.length) { + if (!chunkReader) { + const body = (await fetchFunc(fetchURLs[fetchIndex], + loadOptions.requestInit, + {isBinary: true})).body; + + chunkReader = body.getReader(); + } + + const {done, value} = await chunkReader.read(); + + if (done) { + fetchIndex++; + chunkReader = undefined; + loadOptions.onProgress?.(fetchIndex / fetchURLs.length); + continue; + } + controller.enqueue(value); + return; + } + controller.close(); + }, + }); +} + /** * Reads a weights manifest JSON configuration, fetches the weights and * returns them as `Tensor`s. diff --git a/tfjs/yarn.lock b/tfjs/yarn.lock index efddf89820c..5dbc13b3f00 100644 --- a/tfjs/yarn.lock +++ b/tfjs/yarn.lock @@ -1976,14 +1976,6 @@ resolved "https://registry.yarnpkg.com/@types/minimatch/-/minimatch-3.0.4.tgz#f0ec25dbf2f0e4b18647313ac031134ca5b24b21" integrity sha512-1z8k4wzFnNjVK/tlxvrWuK5WMt6mydWWP7+zvH5eFep4oj+UkrfiJTRtjCeBXNpwaA/FYqqtb4/QS4ianFpIRA== -"@types/node-fetch@^2.1.2": - version "2.6.4" - resolved "https://registry.yarnpkg.com/@types/node-fetch/-/node-fetch-2.6.4.tgz#1bc3a26de814f6bf466b25aeb1473fa1afe6a660" - integrity sha512-1ZX9fcN4Rvkvgv4E6PAY5WXUFWFcRWxZa3EW83UjycOB9ljJCedb2CupIP4RZMEwF/M3eTcCihbBRgwtGbg5Rg== - dependencies: - "@types/node" "*" - form-data "^3.0.0" - "@types/node@*", "@types/node@>=10.0.0": version "18.11.9" resolved "https://registry.yarnpkg.com/@types/node/-/node-18.11.9.tgz#02d013de7058cea16d36168ef2fc653464cfbad4" @@ -2153,11 +2145,6 @@ async@^3.0.1: resolved "https://registry.yarnpkg.com/async/-/async-3.2.0.tgz#b3a2685c5ebb641d3de02d161002c60fc9f85720" integrity sha512-TR2mEZFVOj2pLStYxLht7TyfuRzaydfpxr3k9RpHIzMgw7A64dzsdqCxH1WJyQdoe8T10nDXd9wnEigmiuHIZw== -asynckit@^0.4.0: - version "0.4.0" - resolved "https://registry.yarnpkg.com/asynckit/-/asynckit-0.4.0.tgz#c79ed97f7f34cb8f2ba1bc9790bcc366474b4b79" - integrity sha512-Oei9OH4tRh0YqU3GxhX79dM/mwVgvbZJaSNaRk+bshkj0S5cfHcgYakreBjrHwatXKbz+IoIdYLxrKim2MjW0Q== - available-typed-arrays@^1.0.2: version "1.0.2" resolved "https://registry.yarnpkg.com/available-typed-arrays/-/available-typed-arrays-1.0.2.tgz#6b098ca9d8039079ee3f77f7b783c4480ba513f5" @@ -2586,13 +2573,6 @@ combine-source-map@^0.8.0: lodash.memoize "~3.0.3" source-map "~0.5.3" -combined-stream@^1.0.8: - version "1.0.8" - resolved "https://registry.yarnpkg.com/combined-stream/-/combined-stream-1.0.8.tgz#c3d45a8b34fd730631a110a8a2520682b31d5a7f" - integrity sha512-FQN4MRfuJeHf7cBbBMJFXhKSDq+2kAArBlmRBvcvFE5BB1HZKXtSFASDhdlz9zOYwxh8lDdnvmMOe/+5cdoEdg== - dependencies: - delayed-stream "~1.0.0" - commander@^2.12.1, commander@^2.20.0: version "2.20.3" resolved "https://registry.yarnpkg.com/commander/-/commander-2.20.3.tgz#fd485e84c03eb4881c20722ba48035e8531aeb33" @@ -2802,11 +2782,6 @@ define-properties@^1.1.3: dependencies: object-keys "^1.0.12" -delayed-stream@~1.0.0: - version "1.0.0" - resolved "https://registry.yarnpkg.com/delayed-stream/-/delayed-stream-1.0.0.tgz#df3ae199acadfb7d440aaae0b29e2272b24ec619" - integrity sha512-ZySD7Nf91aLB0RxL4KGrKHBXl7Eds1DAmEdcoVawXnLD7SDhpNgtuII2aAkg7a7QS41jxPSZ17p4VdGnMHk3MQ== - depd@~1.1.2: version "1.1.2" resolved "https://registry.yarnpkg.com/depd/-/depd-1.1.2.tgz#9bcd52e14c097763e749b274c4346ed2e560b5a9" @@ -3089,15 +3064,6 @@ foreach@^2.0.5: resolved 
"https://registry.yarnpkg.com/foreach/-/foreach-2.0.5.tgz#0bee005018aeb260d0a3af3ae658dd0136ec1b99" integrity sha1-C+4AUBiusmDQo6865ljdATbsG5k= -form-data@^3.0.0: - version "3.0.1" - resolved "https://registry.yarnpkg.com/form-data/-/form-data-3.0.1.tgz#ebd53791b78356a99af9a300d4282c4d5eb9755f" - integrity sha512-RHkBKtLWUVwd7SqRIvCZMEvAMoGUp0XU+seQiZejj0COz3RI3hWP4sCv3gZWWLjJTd7rGwcsF5eKZGii0r/hbg== - dependencies: - asynckit "^0.4.0" - combined-stream "^1.0.8" - mime-types "^2.1.12" - from@~0: version "0.1.7" resolved "https://registry.yarnpkg.com/from/-/from-0.1.7.tgz#83c60afc58b9c56997007ed1a768b3ab303a44fe" @@ -3927,13 +3893,6 @@ mime-db@1.52.0: resolved "https://registry.yarnpkg.com/mime-db/-/mime-db-1.52.0.tgz#bbabcdc02859f4987301c856e3387ce5ec43bf70" integrity sha512-sPU4uV7dYlvtWJxwwxHD0PuihVNiE7TyAbQ5SWxDCB9mUYvOgroQOwYQQOKPJ8CIbE+1ETVlOoK1UC2nU3gYvg== -mime-types@^2.1.12, mime-types@~2.1.34: - version "2.1.35" - resolved "https://registry.yarnpkg.com/mime-types/-/mime-types-2.1.35.tgz#381a871b62a734450660ae3deee44813f70d959a" - integrity sha512-ZDY+bPm5zTTF+YpCrAU9nK0UgICYPT0QtT1NZWFv4s++TNkcgVaT0g6+4R2uI4MjQjzysHB1zxuWL50hzaeXiw== - dependencies: - mime-db "1.52.0" - mime-types@~2.1.24: version "2.1.34" resolved "https://registry.yarnpkg.com/mime-types/-/mime-types-2.1.34.tgz#5a712f9ec1503511a945803640fafe09d3793c24" @@ -3941,6 +3900,13 @@ mime-types@~2.1.24: dependencies: mime-db "1.51.0" +mime-types@~2.1.34: + version "2.1.35" + resolved "https://registry.yarnpkg.com/mime-types/-/mime-types-2.1.35.tgz#381a871b62a734450660ae3deee44813f70d959a" + integrity sha512-ZDY+bPm5zTTF+YpCrAU9nK0UgICYPT0QtT1NZWFv4s++TNkcgVaT0g6+4R2uI4MjQjzysHB1zxuWL50hzaeXiw== + dependencies: + mime-db "1.52.0" + mime@^2.5.2: version "2.6.0" resolved "https://registry.yarnpkg.com/mime/-/mime-2.6.0.tgz#a2a682a95cd4d0cb1d6257e28f83da7e35800367"