Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

webgpu: Replace timestamp-query-in-passes with timestamp-query #7714

Merged
merged 12 commits into from
May 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
205 changes: 108 additions & 97 deletions tfjs-backend-webgpu/src/backend_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ type TensorData = {
interface DataId {}

export type WebGPUKernelInfo = {
name: string; query: Promise<number>;
name: string,
query: Promise<number>,
};

export type TimerNode = RecursiveArray<WebGPUKernelInfo>|WebGPUKernelInfo;
Expand Down Expand Up @@ -105,10 +106,10 @@ export class WebGPUBackend extends KernelBackend {
thresholdToIncreaseWorkgroups: number;

private activeTimers: TimerNode[];
private currentCommandEncoder: GPUCommandEncoder;
private currentComputePass: GPUComputePassEncoder;
private commandEncoder: GPUCommandEncoder;
private computePassEncoder: GPUComputePassEncoder;
private commandQueueOwnedIds = new WeakSet<DataId>();
private dispatchNumberInEncoder = 0;
private dispatchCountInPass = 0;
private disposed = false;
private downloadWaitMs = 0;
private dummyCanvas: HTMLCanvasElement;
Expand All @@ -118,12 +119,15 @@ export class WebGPUBackend extends KernelBackend {
private pipelineCache:
{[key: string]: GPUComputePipeline|Promise<GPUComputePipeline>};
private programTimersStack: TimerNode[];
private querySet: GPUQuerySet;
private queryResolveBuffer: GPUBuffer = null;
private querySet: GPUQuerySet = null;
private querySetCount = 2;
gyagp marked this conversation as resolved.
Show resolved Hide resolved
private stagingPendingDisposal: GPUBuffer[] = [];
private supportTimeQuery: boolean;
private supportTimestampQuery: boolean;
private uniformPendingDisposal: GPUBuffer[] = [];
private uploadWaitMs = 0;
private hasReadSyncWarned = false;
private hasTimestampQueryWarned = false;

private nextDataId(): number {
return WebGPUBackend.nextDataId++;
Expand All @@ -137,23 +141,16 @@ export class WebGPUBackend extends KernelBackend {
this.pipelineCache = {};
this.device = device;
this.queue = device.queue;
this.currentCommandEncoder = null;
this.currentComputePass = null;
this.supportTimeQuery =
device.features.has('timestamp-query-inside-passes');
this.commandEncoder = null;
this.computePassEncoder = null;
this.adapterInfo = new AdapterInfo(adapterInfo);
this.supportTimestampQuery = this.device.features.has('timestamp-query');
this.thresholdToIncreaseWorkgroups =
this.adapterInfo.intelGPUGeneration >= 12 ? 16 : 8;

this.bufferManager = new BufferManager(this.device);
this.textureManager = new TextureManager(this.device);
this.tensorMap = new DataStorage(this, engine());
if (this.supportTimeQuery) {
this.querySet = this.device.createQuerySet({
type: 'timestamp',
count: 2,
});
}

// Profiling tools like PIX needs this dummy canvas to
// trigger capturing a frame.
Expand Down Expand Up @@ -290,10 +287,9 @@ export class WebGPUBackend extends KernelBackend {
}

submitQueue() {
this.ensureComputePassEnded();
this.queue.submit([this.currentCommandEncoder.finish()]);
this.currentCommandEncoder = null;
this.dispatchNumberInEncoder = 0;
this.queue.submit([this.commandEncoder.finish()]);
this.commandEncoder = null;
this.dispatchCountInPass = 0;

this.commandQueueOwnedIds = new WeakSet<DataId>();

Expand All @@ -313,23 +309,16 @@ export class WebGPUBackend extends KernelBackend {
}

ensureCommandEncoderReady() {
if (!this.currentCommandEncoder) {
this.currentCommandEncoder = this.device.createCommandEncoder();
}
}

ensureComputePassEnded() {
if (this.currentComputePass) {
this.currentComputePass.end();
this.currentComputePass = null;
if (!this.commandEncoder) {
this.commandEncoder = this.device.createCommandEncoder();
}
}

getComputePass() {
if (!this.currentComputePass) {
this.currentComputePass = this.currentCommandEncoder.beginComputePass();
endComputePassEncoder() {
if (this.computePassEncoder) {
this.computePassEncoder.end();
this.computePassEncoder = null;
}
return this.currentComputePass;
}

// Check if parallel compilation is done.
Expand All @@ -356,9 +345,8 @@ export class WebGPUBackend extends KernelBackend {
const stagingBuffer = this.bufferManager.acquireBuffer(
size, GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ);
this.ensureCommandEncoderReady();
this.ensureComputePassEnded();
this.currentCommandEncoder.copyBufferToBuffer(
buffer, 0, stagingBuffer, 0, size);
this.endComputePassEncoder();
this.commandEncoder.copyBufferToBuffer(buffer, 0, stagingBuffer, 0, size);
this.submitQueue();

await stagingBuffer.mapAsync(GPUMapMode.READ);
Expand Down Expand Up @@ -431,7 +419,7 @@ export class WebGPUBackend extends KernelBackend {
alphaModes.map(_ => new OffscreenCanvas(canvasWidth, canvasHeight));
const stagingHostStorage = new OffscreenCanvas(canvasWidth, canvasHeight);

this.ensureComputePassEnded();
this.endComputePassEncoder();
stagingDeviceStorage
.map((storage, index) => {
const context = storage.getContext('webgpu');
Expand All @@ -450,7 +438,7 @@ export class WebGPUBackend extends KernelBackend {
const readDataGPUToCPU =
(width: number, height: number, offset: number) => {
this.ensureCommandEncoderReady();
this.currentCommandEncoder.copyBufferToTexture(
this.commandEncoder.copyBufferToTexture(
{
buffer,
bytesPerRow,
Expand Down Expand Up @@ -556,9 +544,8 @@ export class WebGPUBackend extends KernelBackend {
const usage = srcBuffer.usage;
const dstBuffer = this.bufferManager.acquireBuffer(size, usage);
this.ensureCommandEncoderReady();
this.ensureComputePassEnded();
this.currentCommandEncoder.copyBufferToBuffer(
srcBuffer, 0, dstBuffer, 0, size);
this.endComputePassEncoder();
this.commandEncoder.copyBufferToBuffer(srcBuffer, 0, dstBuffer, 0, size);
this.submitQueue();
return dstBuffer;
}
Expand Down Expand Up @@ -627,8 +614,8 @@ export class WebGPUBackend extends KernelBackend {
const usage = srcBuffer.usage;
const buffer = this.bufferManager.acquireBuffer(size, usage);
this.ensureCommandEncoderReady();
this.ensureComputePassEnded();
this.currentCommandEncoder.copyBufferToBuffer(
this.endComputePassEncoder();
this.commandEncoder.copyBufferToBuffer(
resource as GPUBuffer, 0, buffer, 0, size);
this.submitQueue();

Expand Down Expand Up @@ -660,15 +647,16 @@ export class WebGPUBackend extends KernelBackend {
}

override async time(f: () => void): Promise<WebGPUTimingInfo> {
if (!this.supportTimeQuery) {
if (!this.supportTimestampQuery && !this.hasTimestampQueryWarned) {
console.warn(
`This device doesn't support timestamp-query-inside-passes extension. ` +
`This device doesn't support timestamp-query extension. ` +
`Start Chrome browser with flag ` +
`--disable-dawn-features=disallow_unsafe_apis then try again. ` +
`--disable-dawn-features=disallow_unsafe_apis to try it again. ` +
`Otherwise, zero will be shown for the kernel time when profiling ` +
`mode is enabled. Using performance.now is not workable for webgpu ` +
`since it doesn't support synchronous data read from GPU.`);
`mode is enabled.`);
this.hasTimestampQueryWarned = true;
}

const oldActiveTimers = this.activeTimers;
const newActiveTimers: TimerNode[] = [];

Expand Down Expand Up @@ -742,14 +730,6 @@ export class WebGPUBackend extends KernelBackend {
return resource;
}

async getQueryTime(query: GPUQuerySet): Promise<number> {
if (this.supportTimeQuery) {
return this.getTimeFromQuerySet(query);
} else {
return 0;
}
}

uploadToGPU(dataId: DataId): void {
const tensorData = this.tensorMap.get(dataId);
// Already on the GPU.
Expand All @@ -776,8 +756,8 @@ export class WebGPUBackend extends KernelBackend {
}
stagingBuffer.unmap();
this.ensureCommandEncoderReady();
this.ensureComputePassEnded();
this.currentCommandEncoder.copyBufferToBuffer(
this.endComputePassEncoder();
this.commandEncoder.copyBufferToBuffer(
stagingBuffer, 0, buffer, 0, size);

this.stagingPendingDisposal.push(stagingBuffer);
Expand Down Expand Up @@ -966,57 +946,85 @@ export class WebGPUBackend extends KernelBackend {
layout: program.pipeline.getBindGroupLayout(0),
entries: bindings.map((b, i) => ({binding: i, resource: b})),
});
this.ensureCommandEncoderReady();
const pass = this.getComputePass();

const shouldTimeProgram = this.activeTimers != null;
if (shouldTimeProgram && this.supportTimeQuery) {
// tslint:disable-next-line:no-any
(pass as any).writeTimestamp(this.querySet, 0);
this.ensureCommandEncoderReady();

if (!this.computePassEncoder) {
const computePassDescriptor: GPUComputePassDescriptor = {};
if (shouldTimeProgram && this.supportTimestampQuery) {
if (this.querySet == null) {
this.querySet = this.device.createQuerySet({
type: 'timestamp',
count: this.querySetCount,
});
}
computePassDescriptor.timestampWrites = [
{
querySet: this.querySet,
queryIndex: 0,
location: 'beginning',
},
{
querySet: this.querySet,
queryIndex: 1,
location: 'end',
}
];
}
this.computePassEncoder =
this.commandEncoder.beginComputePass(computePassDescriptor);
}

pass.setPipeline(program.pipeline);
pass.setBindGroup(0, bindGroup);
pass.dispatchWorkgroups(
this.computePassEncoder.setPipeline(program.pipeline);
this.computePassEncoder.setBindGroup(0, bindGroup);
this.computePassEncoder.dispatchWorkgroups(
program.dispatch[0], program.dispatch[1], program.dispatch[2]);

if (shouldTimeProgram && this.supportTimeQuery) {
// tslint:disable-next-line:no-any
(pass as any).writeTimestamp(this.querySet, 1);
this.dispatchCountInPass++;

if (shouldTimeProgram ||
env().get('WEBGPU_DEFERRED_SUBMIT_BATCH_SIZE') as
number <= this.dispatchCountInPass) {
this.endComputePassEncoder();
qjia7 marked this conversation as resolved.
Show resolved Hide resolved
if (shouldTimeProgram) {
this.activeTimers.push(
{name: program.constructor.name, query: this.getQueryTime()});
} else {
this.submitQueue();
qjia7 marked this conversation as resolved.
Show resolved Hide resolved
}
}
this.dispatchNumberInEncoder++;
}

if (env().get('WEBGPU_DEFERRED_SUBMIT_BATCH_SIZE') as
number <= this.dispatchNumberInEncoder) {
this.submitQueue();
async getQueryTime(): Promise<number> {
if (!this.supportTimestampQuery) {
return 0;
}
if (shouldTimeProgram) {
this.activeTimers.push({
name: program.constructor.name,
query: this.getQueryTime(this.querySet)
});

if (this.queryResolveBuffer == null) {
this.queryResolveBuffer = this.bufferManager.acquireBuffer(
this.querySetCount * 8,
GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST |
GPUBufferUsage.QUERY_RESOLVE);
}
}
this.commandEncoder.resolveQuerySet(
this.querySet, 0, this.querySetCount, this.queryResolveBuffer, 0);

async getTimeFromQuerySet(querySet: GPUQuerySet) {
const queryBuffer = this.bufferManager.acquireBuffer(
16, GPUBufferUsage.COPY_SRC | GPUBufferUsage.QUERY_RESOLVE);
const dst = this.bufferManager.acquireBuffer(
16, GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST);
const queryStagingBuffer = this.bufferManager.acquireBuffer(
this.querySetCount * 8,
GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST);

this.commandEncoder.copyBufferToBuffer(
this.queryResolveBuffer, 0, queryStagingBuffer, 0,
this.querySetCount * 8);

this.ensureCommandEncoderReady();
this.ensureComputePassEnded();
this.currentCommandEncoder.resolveQuerySet(querySet, 0, 2, queryBuffer, 0);
this.currentCommandEncoder.copyBufferToBuffer(queryBuffer, 0, dst, 0, 16);
this.submitQueue();
await dst.mapAsync(GPUMapMode.READ);
const arrayBuf = new BigUint64Array(dst.getMappedRange());
const timeElapsedNanos = Number((arrayBuf[1] - arrayBuf[0]));
dst.unmap();
this.bufferManager.releaseBuffer(dst);
this.bufferManager.releaseBuffer(queryBuffer);
// Return milliseconds.
return timeElapsedNanos / 1000000;

await queryStagingBuffer.mapAsync(GPUMapMode.READ);
const arrayBuffer = new BigUint64Array(queryStagingBuffer.getMappedRange());
const time = Number(arrayBuffer[1] - arrayBuffer[0]) / 1000000;
queryStagingBuffer.unmap();
this.bufferManager.releaseBuffer(queryStagingBuffer);
return time;
}

shouldExecuteOnCPU(
Expand All @@ -1036,6 +1044,9 @@ export class WebGPUBackend extends KernelBackend {
if (this.disposed) {
return;
}
if (this.querySet != null) {
this.querySet.destroy();
}
this.bufferManager.dispose();
this.textureManager.dispose();
this.disposed = true;
Expand Down
10 changes: 2 additions & 8 deletions tfjs-backend-webgpu/src/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,8 @@ if (isWebGPUSupported()) {
const adapter = await navigator.gpu.requestAdapter(gpuDescriptor);
const deviceDescriptor: GPUDeviceDescriptor = {};

// Note that timestamp-query-inside-passes is not formally in spec as
// timestamp within a pass is not generally supported on all the platforms.
// More details can be found at
// https://github.com/gpuweb/gpuweb/blob/main/proposals/timestamp-query-inside-passes.md
if (adapter.features.has('timestamp-query-inside-passes')) {
deviceDescriptor.requiredFeatures =
// tslint:disable-next-line:no-any
['timestamp-query-inside-passes' as any];
if (adapter.features.has('timestamp-query')) {
deviceDescriptor.requiredFeatures = ['timestamp-query'];
}

const adapterLimits = adapter.limits;
Expand Down
8 changes: 4 additions & 4 deletions tfjs-core/yarn.lock
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@
resolved "https://registry.yarnpkg.com/@types/webgl-ext/-/webgl-ext-0.0.30.tgz#0ce498c16a41a23d15289e0b844d945b25f0fb9d"
integrity sha512-LKVgNmBxN0BbljJrVUwkxwRYqzsAEPcZOe6S2T6ZaBDIrFp0qu4FNlpc5sM1tGbXUYFgdVQIoeLk1Y1UoblyEg==

"@webgpu/types@0.1.21":
version "0.1.21"
resolved "https://registry.yarnpkg.com/@webgpu/types/-/types-0.1.21.tgz#b181202daec30d66ccd67264de23814cfd176d3a"
integrity sha512-pUrWq3V5PiSGFLeLxoGqReTZmiiXwY3jRkIG5sLLKjyqNxrwm/04b4nw7LSmGWJcKk59XOM/YRTUwOzo4MMlow==
"@webgpu/types@0.1.30":
version "0.1.30"
resolved "https://registry.npmjs.org/@webgpu/types/-/types-0.1.30.tgz#b6406dc4a1c1e0d469028ceb30ddffbbd2fa706c"
integrity sha512-9AXJSmL3MzY8ZL//JjudA//q+2kBRGhLBFpkdGksWIuxrMy81nFrCzj2Am+mbh8WoU6rXmv7cY5E3rdlyru2Qg==

long@4.0.0:
version "4.0.0"
Expand Down
4 changes: 2 additions & 2 deletions yarn.lock
Original file line number Diff line number Diff line change
Expand Up @@ -520,7 +520,7 @@

"@webgpu/types@0.1.30":
version "0.1.30"
resolved "https://registry.npmmirror.com/@webgpu/types/-/types-0.1.30.tgz#b6406dc4a1c1e0d469028ceb30ddffbbd2fa706c"
resolved "https://registry.npmjs.org/@webgpu/types/-/types-0.1.30.tgz#b6406dc4a1c1e0d469028ceb30ddffbbd2fa706c"
integrity sha512-9AXJSmL3MzY8ZL//JjudA//q+2kBRGhLBFpkdGksWIuxrMy81nFrCzj2Am+mbh8WoU6rXmv7cY5E3rdlyru2Qg==

"@xmldom/xmldom@^0.7.3":
Expand Down Expand Up @@ -4507,7 +4507,7 @@ type@^2.5.0:

typescript@4.9.4:
version "4.9.4"
resolved "https://registry.yarnpkg.com/typescript/-/typescript-4.9.4.tgz#a2a3d2756c079abda241d75f149df9d561091e78"
resolved "https://registry.npmjs.org/typescript/-/typescript-4.9.4.tgz#a2a3d2756c079abda241d75f149df9d561091e78"
integrity sha512-Uz+dTXYzxXXbsFpM86Wh3dKCxrQqUcVMxwU54orwlJjOpO3ao8L7j5lH+dWfTwgCwIuM9GQ2kvVotzYJMXTBZg==

ua-parser-js@^0.7.30:
Expand Down