Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support loading models with weights above 2GB on Chrome #7609

Merged
merged 14 commits into from
May 4, 2023

Conversation

mattsoulanille
Copy link
Member

@mattsoulanille mattsoulanille commented Apr 20, 2023

Chrome ArrayBuffers throw allocation errors above 2GB in size. This makes it impossible to load TFJS models above this size in Chrome (even with weight sharding) because model loading involves concatenating all the weights into a single ArrayBuffer.

This PR avoids this concatenation. Instead of slicing the weight tensors out of a single concatenated ArrayBuffer, it keeps the weight buffers in their original shards and slices them using the CompositeArrayBuffer class created in #7598.

To see the logs from the Cloud Build CI, please join either our discussion or announcement mailing list.


This change is Reviewable

@mattsoulanille mattsoulanille force-pushed the large_model_weights branch 4 times, most recently from fe844d2 to b21b302 Compare April 20, 2023 20:21
@mattsoulanille mattsoulanille marked this pull request as ready for review April 20, 2023 20:26
Comment on lines -54 to -80

// TODO(cais): Use explicit tf.io.ModelArtifactsInfo return type below once it
// is available.
/**
* Populate ModelArtifactsInfo fields for a model with JSON topology.
* @param modelArtifacts
* @returns A ModelArtifactsInfo object.
*/
export function getModelArtifactsInfoForJSON(
modelArtifacts: tf.io.ModelArtifacts) {
if (modelArtifacts.modelTopology instanceof ArrayBuffer) {
throw new Error('Expected JSON model topology, received ArrayBuffer.');
}
return {
dateSaved: new Date(),
modelTopologyType: 'JSON',
modelTopologyBytes: modelArtifacts.modelTopology == null ?
0 :
Buffer.byteLength(JSON.stringify(modelArtifacts.modelTopology), 'utf8'),
weightSpecsBytes: modelArtifacts.weightSpecs == null ?
0 :
Buffer.byteLength(JSON.stringify(modelArtifacts.weightSpecs), 'utf8'),
weightDataBytes: modelArtifacts.weightData == null ?
0 :
modelArtifacts.weightData.byteLength,
};
}
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is duplicated in tfjs-core/src/io/io_utils.ts

@@ -285,7 +291,7 @@ const useNodeBuffer = typeof Buffer !== 'undefined' &&
*/
export function stringByteLength(str: string): number {
if (useNodeBuffer) {
return Buffer.byteLength(str);
return Buffer.byteLength(str, 'utf8');
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tfjs-node used utf8 in its implementation, so I think it should also be here.

* @returns Result of concatenating `buffers` in order.
*/
export function concatenateArrayBuffers(buffers: ArrayBuffer[]): ArrayBuffer {
export function concatenateArrayBuffers(buffers: ArrayBuffer[]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we move this to be part of CompositeArrayBuffer? like static method CompositeArrayBuffer.join(buffers: ArrayBuffer[]) or through public method new CompositeArrayBuffer(buffers).toArrayBuffer(), which makes it easier to bridge CompositeArrayBuffer with native ArrayBuffer and pass CompositeArrayBuffer around in the future if needed.

Copy link
Member Author

@mattsoulanille mattsoulanille Apr 20, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was considering that, and my original implementation actually used new CompositeArrayBuffer(buffers).slice(), but I removed it in favor of concatenateArrayBuffers because of an issue with the types in tfjs-converter tests (here was my fix for it in the spy_ops.ts file, but it's a bit hacky).

I'm fine with using the converter spy_ops.ts fix if it'll make the core implementation cleaner. What do you think?

Edit: ...and we can add a toArrayBuffer or static join method instead of using .slice.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alternatively, I can move composite_array_buffer.ts out of io/

Copy link
Collaborator

@chunnienc chunnienc Apr 21, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I took a look at the usage of spyOnAllFunctions in tests, and I think the test is something we should fix. A hacky way like what you did is probably fine.

In general, instead of automatically replace everything with spy using spyOnAllFunctions, we should explicitly create an ioSpy object which only contains the function we want to spy, so that we can make the test more controllable and reliable. There are some stuffs exported in io apparently should not be spied, like getWeightSpecs, which is a io helper function instead of a function to do io.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@chunnienc I've replaced concatenateArrayBuffers with CompositeArrayBuffer.join in tfjs-core and deprecated concatenateArrayBuffers. We can't replace it in other packages yet because that would introduce a breaking change. Downstream packages could not be used with an earlier version of tfjs-core that does not implement CompositeArrayBuffer (see #7273 for an example of why this is important). We can apply this change to all the packages in the next major release.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, it's fine to use it in tests, since users will never run those. I'll swap concatenateArrayBuffers for CompositeArrayBuffer.join in the test files.

@mattsoulanille mattsoulanille force-pushed the large_model_weights branch from 0276e01 to 94ed22e Compare May 3, 2023 21:11
@mattsoulanille mattsoulanille requested a review from chunnienc May 3, 2023 22:37
Copy link
Collaborator

@pyu10055 pyu10055 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reviewed 4 of 18 files at r1, 1 of 1 files at r2, 13 of 13 files at r3, 4 of 4 files at r4, all commit messages.
Reviewable status: :shipit: complete! 2 of 1 approvals obtained (waiting on @chunnienc)

@mattsoulanille mattsoulanille merged commit 086e9d8 into tensorflow:master May 4, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants