Avoid allocating a large arraybuffer when loading weights #7598

Merged


@mattsoulanille mattsoulanille commented Apr 18, 2023

The loadWeights function loads weights in 4MB chunks and then concatenates them into a single large ArrayBuffer. That ArrayBuffer is used for splitting the weights data back up into tensors. Allocating large ArrayBuffers (3.5GB) can be unstable on Chrome, so this PR avoids this allocation, instead slicing the weights out of the chunks manually.

The implementation wraps the array of weights (stored as ArrayBuffer[]) in a new CompositeArrayBuffer class. This class implements slice by copying the desired range out of the buffer(s) that it overlaps with.
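The description above can be illustrated with a minimal TypeScript sketch, assuming a class that records each buffer's start and end offsets and copies only the overlapping byte ranges in `slice`. The name `CompositeBufferSketch` and its internals are hypothetical and are not the code merged in this PR. Unlike `ArrayBuffer.prototype.slice`, the sketch clamps indices and does not handle negative ones.

```typescript
// Hypothetical sketch for illustration; not the tfjs-core implementation.
type ShardSketch = {buffer: ArrayBuffer; start: number; end: number};

class CompositeBufferSketch {
  private readonly shards: ShardSketch[] = [];
  readonly byteLength: number;

  constructor(buffers: ArrayBuffer[]) {
    // Record the [start, end) byte range each buffer covers in the whole.
    let start = 0;
    for (const buffer of buffers) {
      const end = start + buffer.byteLength;
      this.shards.push({buffer, start, end});
      start = end;
    }
    this.byteLength = start;
  }

  slice(start = 0, end = this.byteLength): ArrayBuffer {
    // Assumption: clamp to valid bounds; negative indices are not supported.
    start = Math.max(0, Math.min(start, this.byteLength));
    end = Math.max(start, Math.min(end, this.byteLength));

    // Only the requested range is allocated, never the full concatenation.
    const outputBuffer = new ArrayBuffer(end - start);
    const output = new Uint8Array(outputBuffer);

    for (const shard of this.shards) {
      if (shard.end <= start) {
        continue;  // Shard lies entirely before the requested range.
      }
      if (shard.start >= end) {
        break;  // Shard lies entirely after the requested range.
      }
      // Copy the part of this shard that overlaps [start, end).
      const from = Math.max(start, shard.start);
      const to = Math.min(end, shard.end);
      const source =
          new Uint8Array(shard.buffer, from - shard.start, to - from);
      output.set(source, from - start);
    }
    return outputBuffer;
  }
}
```

The linear scan over shards keeps the sketch short; how to find the first overlapping shard faster is a separate question.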

To see the logs from the Cloud Build CI, please join either our discussion or announcement mailing list.



@mattsoulanille
Member Author

@pyu10055 This isn't a full review request yet, but I'm interested in knowing which approach you think is better. Thanks!

@chunnienc
Collaborator

chunnienc commented Apr 18, 2023

If the size of each chunk is fixed (4MB) except for the last one, you can get the start and end offsets in O(1) by division and modulo. There is no need for a binary search or a two-pointer approach on sorted data.

In terms of API, I prefer option 2, which separates concerns better. In the weight loader I just need to specify where to slice and don't need to worry about how to slice. You can implement a lazy/offline slicer if performance is a concern.
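As a rough illustration of the O(1) lookup suggested here, the following sketch assumes every chunk except possibly the last has exactly `chunkSize` bytes. The function name `locateByte` and its return shape are hypothetical.

```typescript
// Hypothetical helper; only valid when all chunks (except possibly the
// last) share the same size.
function locateByte(
    byteIndex: number, chunkSize: number): {chunk: number; offset: number} {
  return {
    chunk: Math.floor(byteIndex / chunkSize),  // Which chunk holds the byte.
    offset: byteIndex % chunkSize,             // Position inside that chunk.
  };
}
```

If chunk sizes can differ, this arithmetic no longer holds and a search over the recorded chunk boundaries is needed instead.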

@mattsoulanille
Member Author

> If the size of each chunk is fixed (4MB) except for the last one, you can get the start and end offsets in O(1) by division and modulo. There is no need for a binary search or a two-pointer approach on sorted data.
>
> In terms of API, I prefer option 2, which separates concerns better. In the weight loader I just need to specify where to slice and don't need to worry about how to slice. You can implement a lazy/offline slicer if performance is a concern.

Unfortunately, there's no guarantee on the size of the chunks. We let people configure it when converting the model. They should all be the same size, but I'd even hesitate to assume that, since it seems a bit flaky.

I agree with you and also prefer option 2. I'll implement it as a binsearch for now, and if we need better perf, I can try to automatically detect the chunk size or make it check chunks near the last one read before doing a full binsearch.
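A minimal sketch of the approach described in this reply, assuming each chunk's byte range is stored as a half-open `[start, end)` interval in sorted order. The names `ByteRange`, `findRange`, and `findRangeNear` are hypothetical and are not taken from the PR. The second function shows the "check chunks near the last one read" idea before falling back to a full binary search.

```typescript
// Hypothetical types and names for illustration; not the tfjs-core code.
type ByteRange = {start: number; end: number};  // end is exclusive

// Binary search for the index of the range containing byteIndex, or -1.
function findRange(ranges: ByteRange[], byteIndex: number): number {
  let lo = 0;
  let hi = ranges.length - 1;
  while (lo <= hi) {
    const mid = Math.floor((lo + hi) / 2);
    const range = ranges[mid];
    if (byteIndex < range.start) {
      hi = mid - 1;
    } else if (byteIndex >= range.end) {
      lo = mid + 1;
    } else {
      return mid;
    }
  }
  return -1;
}

// Try the previously used range and its successor before searching, which
// helps when weights are read in order.
function findRangeNear(
    ranges: ByteRange[], byteIndex: number, lastIndex: number): number {
  for (const i of [lastIndex, lastIndex + 1]) {
    const range = ranges[i];
    if (range !== undefined && byteIndex >= range.start &&
        byteIndex < range.end) {
      return i;
    }
  }
  return findRange(ranges, byteIndex);
}
```

This makes no assumption about chunk sizes, at the cost of O(log n) lookups when the hint misses.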

@mattsoulanille mattsoulanille marked this pull request as ready for review April 18, 2023 21:33
@mattsoulanille
Member Author

mattsoulanille commented Apr 18, 2023

Looking at this again, it doesn't actually prevent us from storing the weights in a single ArrayBuffer. ModelArtifacts contains the weightData key, which stores the model weights as a single ArrayBuffer. This is constructed by IOHandlers like http, which load the weights and then concatenate them into a single ArrayBuffer.

I'll leave this PR as-is and create a new one to fix this issue.

@mattsoulanille mattsoulanille changed the title from "Avoid allocating a large arraybuffer to store the model weights" to "Avoid allocating a large arraybuffer when loading weights" Apr 18, 2023
@mattsoulanille mattsoulanille force-pushed the avoid_large_arraybuffer branch from 159fcd1 to 6a07547 April 20, 2023 00:12
buffer: ArrayBuffer,
};

export class CompositeArrayBuffer {
Member Author

This will be used in another PR that enables large model weights to be stored in a list of ArrayBuffers. That's why it's exported here.

@mattsoulanille
Member Author

> Looking at this again, it doesn't actually prevent us from storing the weights in a single ArrayBuffer. ModelArtifacts contains the weightData key, which stores the model weights as a single ArrayBuffer. This is constructed by IOHandlers like http, which load the weights and then concatenate them into a single ArrayBuffer.
>
> I'll leave this PR as-is and create a new one to fix this issue.

I'm sending this out for review since it's easier to review separately from the other part of the large weights fix. I'll submit the other part, which integrates this code with the rest of the codebase, in a separate PR.

@mattsoulanille mattsoulanille force-pushed the avoid_large_arraybuffer branch from ee87834 to b86ee6e April 20, 2023 00:31
Collaborator

@pyu10055 pyu10055 left a comment

Reviewable status: 0 of 1 approvals obtained (waiting on @chunnienc and @mattsoulanille)


tfjs-core/src/io/weights_loader.ts line 290 at r2 (raw file):

      }
    }
    return outputBuffer

missing ;


tfjs-core/src/io/weights_loader.ts line 292 at r2 (raw file):

    return outputBuffer
  }
  private search(byteIndex: number) {

This could be improved if the searches happen in order, which I believe is how our weights are set up.


tfjs-core/src/io/weights_loader.ts line 245 at r6 (raw file):

Previously, mattsoulanille (Matthew Soulanille) wrote…

This will be used in another PR that enables large model weights to be stored in a list of ArrayBuffers. That's why it's exported here.

It might be good to put this in a separate file.


tfjs-core/src/io/weights_loader.ts line 272 at r9 (raw file):

    let start = 0;

    for (let i = 0; i < buffers.length; i++) {

start from 1?


// Create the ranges, including their start and end points.
const end = start + buffer.byteLength;
this.ranges.push({buffer, start, end,});
Collaborator

nit: remove ',' or format to multiple lines

}

slice(start = 0, end = this.byteLength): ArrayBuffer {
// NaN is treated as zero for slicing. This matches ArrayBuffer's behavior.
Collaborator

Convert start and end to numbers with Number(...) before checking for NaN?

Collaborator

@chunnienc chunnienc Apr 20, 2023

Update:
I assume you added these NaN checks because you think there may be calls from JS, now or in the future, that bypass the TypeScript type check. In that case I'd suggest doing start = Number(start), since isNaN('123') returns false and ArrayBuffer.prototype.slice accepts numbers given as strings.

Member Author

@mattsoulanille mattsoulanille Apr 20, 2023

I added these NaN checks because some of the tests were failing. They intentionally gave it no datatype, which I think eventually resulted in a NaN being passed to slice, since tfjs didn't know the byte length of the datatype, so I think the tests themselves are correct. I'd like this to match ArrayBuffer.slice as closely as possible, so I implemented your comment.
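The coercion discussed in this thread can be sketched as follows, assuming the goal is to mirror how the native `ArrayBuffer.prototype.slice` treats a NaN index as 0 and accepts numeric strings. The helper name `toSliceIndex` is hypothetical and this is not the code merged in the PR.

```typescript
// Hypothetical helper for illustration; not the tfjs-core implementation.
function toSliceIndex(value: unknown, fallback: number): number {
  // An omitted argument keeps its default (0 for start, byteLength for end).
  if (value === undefined) {
    return fallback;
  }
  // Number('123') is 123, so numeric strings from untyped JS callers work.
  const n = Number(value);
  // NaN is treated as zero, as the native slice does.
  return isNaN(n) ? 0 : n;
}
```

Coercing with `Number(...)` before the NaN check means the stored value is always a real number, so later arithmetic on the index does not depend on implicit string conversion.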

@@ -245,3 +235,180 @@ export function weightsLoaderFactory(
return weightsTensorMap;
};
}

type BufferRange = {
Collaborator

Naming: range -> chunk/shard/partition
And all related variable and function names

Member Author

Good point. That's a much better name. Fixed.

@mattsoulanille mattsoulanille requested a review from pyu10055 April 20, 2023 17:47
Collaborator

@pyu10055 pyu10055 left a comment

Reviewed 3 of 4 files at r11, 1 of 1 files at r12, all commit messages.
Reviewable status: :shipit: complete! 2 of 1 approvals obtained (waiting on @chunnienc and @mattsoulanille)

@mattsoulanille mattsoulanille merged commit 3ceace9 into tensorflow:master Apr 20, 2023
mattsoulanille added a commit that referenced this pull request Apr 24, 2023
* webgpu: Fix a bug in softmax (#7607)

* Avoid allocating a large arraybuffer when loading weights (#7598)


* Support using a list of ArrayBuffers as model weight data

* Avoid 'Array.flat()'

* Simplify some of the tests

* Do not export 'CompositeArrayBuffer' from tfjs-core

* Update doc for weightData

* Fix tfjs-node

* Remove unused import

---------

Co-authored-by: Jiajia Qin <jiajia.qin@intel.com>
mattsoulanille added a commit that referenced this pull request May 4, 2023
Chrome ArrayBuffers throw allocation errors above 2GB in size. This makes it impossible to load TFJS models above this size in Chrome (even with weight sharding) because model loading involves concatenating all the weights into a single ArrayBuffer.

This PR avoids this concatenation. Instead of slicing the weight tensors out of a single concatenated ArrayBuffer, it keeps the weight buffers in their original shards and slices them using the CompositeArrayBuffer class created in #7598.