diff --git a/tfjs-backend-webgpu/src/backend_webgpu.ts b/tfjs-backend-webgpu/src/backend_webgpu.ts index 4307813e1af..f4f89130d37 100644 --- a/tfjs-backend-webgpu/src/backend_webgpu.ts +++ b/tfjs-backend-webgpu/src/backend_webgpu.ts @@ -49,7 +49,8 @@ type TensorData = { interface DataId {} export type WebGPUKernelInfo = { - name: string; query: Promise; + name: string, + query: Promise, }; export type TimerNode = RecursiveArray|WebGPUKernelInfo; @@ -105,10 +106,10 @@ export class WebGPUBackend extends KernelBackend { thresholdToIncreaseWorkgroups: number; private activeTimers: TimerNode[]; - private currentCommandEncoder: GPUCommandEncoder; - private currentComputePass: GPUComputePassEncoder; + private commandEncoder: GPUCommandEncoder; + private computePassEncoder: GPUComputePassEncoder; private commandQueueOwnedIds = new WeakSet(); - private dispatchNumberInEncoder = 0; + private dispatchCountInPass = 0; private disposed = false; private downloadWaitMs = 0; private dummyCanvas: HTMLCanvasElement; @@ -118,12 +119,15 @@ export class WebGPUBackend extends KernelBackend { private pipelineCache: {[key: string]: GPUComputePipeline|Promise}; private programTimersStack: TimerNode[]; - private querySet: GPUQuerySet; + private queryResolveBuffer: GPUBuffer = null; + private querySet: GPUQuerySet = null; + private querySetCount = 2; private stagingPendingDisposal: GPUBuffer[] = []; - private supportTimeQuery: boolean; + private supportTimestampQuery: boolean; private uniformPendingDisposal: GPUBuffer[] = []; private uploadWaitMs = 0; private hasReadSyncWarned = false; + private hasTimestampQueryWarned = false; private nextDataId(): number { return WebGPUBackend.nextDataId++; @@ -137,23 +141,16 @@ export class WebGPUBackend extends KernelBackend { this.pipelineCache = {}; this.device = device; this.queue = device.queue; - this.currentCommandEncoder = null; - this.currentComputePass = null; - this.supportTimeQuery = - device.features.has('timestamp-query-inside-passes'); + this.commandEncoder = null; + this.computePassEncoder = null; this.adapterInfo = new AdapterInfo(adapterInfo); + this.supportTimestampQuery = this.device.features.has('timestamp-query'); this.thresholdToIncreaseWorkgroups = this.adapterInfo.intelGPUGeneration >= 12 ? 16 : 8; this.bufferManager = new BufferManager(this.device); this.textureManager = new TextureManager(this.device); this.tensorMap = new DataStorage(this, engine()); - if (this.supportTimeQuery) { - this.querySet = this.device.createQuerySet({ - type: 'timestamp', - count: 2, - }); - } // Profiling tools like PIX needs this dummy canvas to // trigger capturing a frame. @@ -290,10 +287,9 @@ export class WebGPUBackend extends KernelBackend { } submitQueue() { - this.ensureComputePassEnded(); - this.queue.submit([this.currentCommandEncoder.finish()]); - this.currentCommandEncoder = null; - this.dispatchNumberInEncoder = 0; + this.queue.submit([this.commandEncoder.finish()]); + this.commandEncoder = null; + this.dispatchCountInPass = 0; this.commandQueueOwnedIds = new WeakSet(); @@ -313,23 +309,16 @@ export class WebGPUBackend extends KernelBackend { } ensureCommandEncoderReady() { - if (!this.currentCommandEncoder) { - this.currentCommandEncoder = this.device.createCommandEncoder(); - } - } - - ensureComputePassEnded() { - if (this.currentComputePass) { - this.currentComputePass.end(); - this.currentComputePass = null; + if (!this.commandEncoder) { + this.commandEncoder = this.device.createCommandEncoder(); } } - getComputePass() { - if (!this.currentComputePass) { - this.currentComputePass = this.currentCommandEncoder.beginComputePass(); + endComputePassEncoder() { + if (this.computePassEncoder) { + this.computePassEncoder.end(); + this.computePassEncoder = null; } - return this.currentComputePass; } // Check if parallel compilation is done. @@ -356,9 +345,8 @@ export class WebGPUBackend extends KernelBackend { const stagingBuffer = this.bufferManager.acquireBuffer( size, GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ); this.ensureCommandEncoderReady(); - this.ensureComputePassEnded(); - this.currentCommandEncoder.copyBufferToBuffer( - buffer, 0, stagingBuffer, 0, size); + this.endComputePassEncoder(); + this.commandEncoder.copyBufferToBuffer(buffer, 0, stagingBuffer, 0, size); this.submitQueue(); await stagingBuffer.mapAsync(GPUMapMode.READ); @@ -431,7 +419,7 @@ export class WebGPUBackend extends KernelBackend { alphaModes.map(_ => new OffscreenCanvas(canvasWidth, canvasHeight)); const stagingHostStorage = new OffscreenCanvas(canvasWidth, canvasHeight); - this.ensureComputePassEnded(); + this.endComputePassEncoder(); stagingDeviceStorage .map((storage, index) => { const context = storage.getContext('webgpu'); @@ -450,7 +438,7 @@ export class WebGPUBackend extends KernelBackend { const readDataGPUToCPU = (width: number, height: number, offset: number) => { this.ensureCommandEncoderReady(); - this.currentCommandEncoder.copyBufferToTexture( + this.commandEncoder.copyBufferToTexture( { buffer, bytesPerRow, @@ -556,9 +544,8 @@ export class WebGPUBackend extends KernelBackend { const usage = srcBuffer.usage; const dstBuffer = this.bufferManager.acquireBuffer(size, usage); this.ensureCommandEncoderReady(); - this.ensureComputePassEnded(); - this.currentCommandEncoder.copyBufferToBuffer( - srcBuffer, 0, dstBuffer, 0, size); + this.endComputePassEncoder(); + this.commandEncoder.copyBufferToBuffer(srcBuffer, 0, dstBuffer, 0, size); this.submitQueue(); return dstBuffer; } @@ -627,8 +614,8 @@ export class WebGPUBackend extends KernelBackend { const usage = srcBuffer.usage; const buffer = this.bufferManager.acquireBuffer(size, usage); this.ensureCommandEncoderReady(); - this.ensureComputePassEnded(); - this.currentCommandEncoder.copyBufferToBuffer( + this.endComputePassEncoder(); + this.commandEncoder.copyBufferToBuffer( resource as GPUBuffer, 0, buffer, 0, size); this.submitQueue(); @@ -660,15 +647,16 @@ export class WebGPUBackend extends KernelBackend { } override async time(f: () => void): Promise { - if (!this.supportTimeQuery) { + if (!this.supportTimestampQuery && !this.hasTimestampQueryWarned) { console.warn( - `This device doesn't support timestamp-query-inside-passes extension. ` + + `This device doesn't support timestamp-query extension. ` + `Start Chrome browser with flag ` + - `--disable-dawn-features=disallow_unsafe_apis then try again. ` + + `--disable-dawn-features=disallow_unsafe_apis to try it again. ` + `Otherwise, zero will be shown for the kernel time when profiling ` + - `mode is enabled. Using performance.now is not workable for webgpu ` + - `since it doesn't support synchronous data read from GPU.`); + `mode is enabled.`); + this.hasTimestampQueryWarned = true; } + const oldActiveTimers = this.activeTimers; const newActiveTimers: TimerNode[] = []; @@ -742,14 +730,6 @@ export class WebGPUBackend extends KernelBackend { return resource; } - async getQueryTime(query: GPUQuerySet): Promise { - if (this.supportTimeQuery) { - return this.getTimeFromQuerySet(query); - } else { - return 0; - } - } - uploadToGPU(dataId: DataId): void { const tensorData = this.tensorMap.get(dataId); // Already on the GPU. @@ -776,8 +756,8 @@ export class WebGPUBackend extends KernelBackend { } stagingBuffer.unmap(); this.ensureCommandEncoderReady(); - this.ensureComputePassEnded(); - this.currentCommandEncoder.copyBufferToBuffer( + this.endComputePassEncoder(); + this.commandEncoder.copyBufferToBuffer( stagingBuffer, 0, buffer, 0, size); this.stagingPendingDisposal.push(stagingBuffer); @@ -966,57 +946,85 @@ export class WebGPUBackend extends KernelBackend { layout: program.pipeline.getBindGroupLayout(0), entries: bindings.map((b, i) => ({binding: i, resource: b})), }); - this.ensureCommandEncoderReady(); - const pass = this.getComputePass(); const shouldTimeProgram = this.activeTimers != null; - if (shouldTimeProgram && this.supportTimeQuery) { - // tslint:disable-next-line:no-any - (pass as any).writeTimestamp(this.querySet, 0); + this.ensureCommandEncoderReady(); + + if (!this.computePassEncoder) { + const computePassDescriptor: GPUComputePassDescriptor = {}; + if (shouldTimeProgram && this.supportTimestampQuery) { + if (this.querySet == null) { + this.querySet = this.device.createQuerySet({ + type: 'timestamp', + count: this.querySetCount, + }); + } + computePassDescriptor.timestampWrites = [ + { + querySet: this.querySet, + queryIndex: 0, + location: 'beginning', + }, + { + querySet: this.querySet, + queryIndex: 1, + location: 'end', + } + ]; + } + this.computePassEncoder = + this.commandEncoder.beginComputePass(computePassDescriptor); } - pass.setPipeline(program.pipeline); - pass.setBindGroup(0, bindGroup); - pass.dispatchWorkgroups( + this.computePassEncoder.setPipeline(program.pipeline); + this.computePassEncoder.setBindGroup(0, bindGroup); + this.computePassEncoder.dispatchWorkgroups( program.dispatch[0], program.dispatch[1], program.dispatch[2]); - - if (shouldTimeProgram && this.supportTimeQuery) { - // tslint:disable-next-line:no-any - (pass as any).writeTimestamp(this.querySet, 1); + this.dispatchCountInPass++; + + if (shouldTimeProgram || + env().get('WEBGPU_DEFERRED_SUBMIT_BATCH_SIZE') as + number <= this.dispatchCountInPass) { + this.endComputePassEncoder(); + if (shouldTimeProgram) { + this.activeTimers.push( + {name: program.constructor.name, query: this.getQueryTime()}); + } else { + this.submitQueue(); + } } - this.dispatchNumberInEncoder++; + } - if (env().get('WEBGPU_DEFERRED_SUBMIT_BATCH_SIZE') as - number <= this.dispatchNumberInEncoder) { - this.submitQueue(); + async getQueryTime(): Promise { + if (!this.supportTimestampQuery) { + return 0; } - if (shouldTimeProgram) { - this.activeTimers.push({ - name: program.constructor.name, - query: this.getQueryTime(this.querySet) - }); + + if (this.queryResolveBuffer == null) { + this.queryResolveBuffer = this.bufferManager.acquireBuffer( + this.querySetCount * 8, + GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST | + GPUBufferUsage.QUERY_RESOLVE); } - } + this.commandEncoder.resolveQuerySet( + this.querySet, 0, this.querySetCount, this.queryResolveBuffer, 0); - async getTimeFromQuerySet(querySet: GPUQuerySet) { - const queryBuffer = this.bufferManager.acquireBuffer( - 16, GPUBufferUsage.COPY_SRC | GPUBufferUsage.QUERY_RESOLVE); - const dst = this.bufferManager.acquireBuffer( - 16, GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST); + const queryStagingBuffer = this.bufferManager.acquireBuffer( + this.querySetCount * 8, + GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST); + + this.commandEncoder.copyBufferToBuffer( + this.queryResolveBuffer, 0, queryStagingBuffer, 0, + this.querySetCount * 8); - this.ensureCommandEncoderReady(); - this.ensureComputePassEnded(); - this.currentCommandEncoder.resolveQuerySet(querySet, 0, 2, queryBuffer, 0); - this.currentCommandEncoder.copyBufferToBuffer(queryBuffer, 0, dst, 0, 16); this.submitQueue(); - await dst.mapAsync(GPUMapMode.READ); - const arrayBuf = new BigUint64Array(dst.getMappedRange()); - const timeElapsedNanos = Number((arrayBuf[1] - arrayBuf[0])); - dst.unmap(); - this.bufferManager.releaseBuffer(dst); - this.bufferManager.releaseBuffer(queryBuffer); - // Return milliseconds. - return timeElapsedNanos / 1000000; + + await queryStagingBuffer.mapAsync(GPUMapMode.READ); + const arrayBuffer = new BigUint64Array(queryStagingBuffer.getMappedRange()); + const time = Number(arrayBuffer[1] - arrayBuffer[0]) / 1000000; + queryStagingBuffer.unmap(); + this.bufferManager.releaseBuffer(queryStagingBuffer); + return time; } shouldExecuteOnCPU( @@ -1036,6 +1044,9 @@ export class WebGPUBackend extends KernelBackend { if (this.disposed) { return; } + if (this.querySet != null) { + this.querySet.destroy(); + } this.bufferManager.dispose(); this.textureManager.dispose(); this.disposed = true; diff --git a/tfjs-backend-webgpu/src/base.ts b/tfjs-backend-webgpu/src/base.ts index be71afb2d2d..8f4ad52877c 100644 --- a/tfjs-backend-webgpu/src/base.ts +++ b/tfjs-backend-webgpu/src/base.ts @@ -33,14 +33,8 @@ if (isWebGPUSupported()) { const adapter = await navigator.gpu.requestAdapter(gpuDescriptor); const deviceDescriptor: GPUDeviceDescriptor = {}; - // Note that timestamp-query-inside-passes is not formally in spec as - // timestamp within a pass is not generally supported on all the platforms. - // More details can be found at - // https://github.com/gpuweb/gpuweb/blob/main/proposals/timestamp-query-inside-passes.md - if (adapter.features.has('timestamp-query-inside-passes')) { - deviceDescriptor.requiredFeatures = - // tslint:disable-next-line:no-any - ['timestamp-query-inside-passes' as any]; + if (adapter.features.has('timestamp-query')) { + deviceDescriptor.requiredFeatures = ['timestamp-query']; } const adapterLimits = adapter.limits; diff --git a/tfjs-core/yarn.lock b/tfjs-core/yarn.lock index 10bf8b85753..9bb5eb1c6ec 100644 --- a/tfjs-core/yarn.lock +++ b/tfjs-core/yarn.lock @@ -32,10 +32,10 @@ resolved "https://registry.yarnpkg.com/@types/webgl-ext/-/webgl-ext-0.0.30.tgz#0ce498c16a41a23d15289e0b844d945b25f0fb9d" integrity sha512-LKVgNmBxN0BbljJrVUwkxwRYqzsAEPcZOe6S2T6ZaBDIrFp0qu4FNlpc5sM1tGbXUYFgdVQIoeLk1Y1UoblyEg== -"@webgpu/types@0.1.21": - version "0.1.21" - resolved "https://registry.yarnpkg.com/@webgpu/types/-/types-0.1.21.tgz#b181202daec30d66ccd67264de23814cfd176d3a" - integrity sha512-pUrWq3V5PiSGFLeLxoGqReTZmiiXwY3jRkIG5sLLKjyqNxrwm/04b4nw7LSmGWJcKk59XOM/YRTUwOzo4MMlow== +"@webgpu/types@0.1.30": + version "0.1.30" + resolved "https://registry.npmjs.org/@webgpu/types/-/types-0.1.30.tgz#b6406dc4a1c1e0d469028ceb30ddffbbd2fa706c" + integrity sha512-9AXJSmL3MzY8ZL//JjudA//q+2kBRGhLBFpkdGksWIuxrMy81nFrCzj2Am+mbh8WoU6rXmv7cY5E3rdlyru2Qg== long@4.0.0: version "4.0.0" diff --git a/yarn.lock b/yarn.lock index a52b1eb6180..29251ec4a1c 100644 --- a/yarn.lock +++ b/yarn.lock @@ -520,7 +520,7 @@ "@webgpu/types@0.1.30": version "0.1.30" - resolved "https://registry.npmmirror.com/@webgpu/types/-/types-0.1.30.tgz#b6406dc4a1c1e0d469028ceb30ddffbbd2fa706c" + resolved "https://registry.npmjs.org/@webgpu/types/-/types-0.1.30.tgz#b6406dc4a1c1e0d469028ceb30ddffbbd2fa706c" integrity sha512-9AXJSmL3MzY8ZL//JjudA//q+2kBRGhLBFpkdGksWIuxrMy81nFrCzj2Am+mbh8WoU6rXmv7cY5E3rdlyru2Qg== "@xmldom/xmldom@^0.7.3": @@ -4507,7 +4507,7 @@ type@^2.5.0: typescript@4.9.4: version "4.9.4" - resolved "https://registry.yarnpkg.com/typescript/-/typescript-4.9.4.tgz#a2a3d2756c079abda241d75f149df9d561091e78" + resolved "https://registry.npmjs.org/typescript/-/typescript-4.9.4.tgz#a2a3d2756c079abda241d75f149df9d561091e78" integrity sha512-Uz+dTXYzxXXbsFpM86Wh3dKCxrQqUcVMxwU54orwlJjOpO3ao8L7j5lH+dWfTwgCwIuM9GQ2kvVotzYJMXTBZg== ua-parser-js@^0.7.30: