-
Notifications
You must be signed in to change notification settings - Fork 3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MigraphX] Fix potential synchronization problem when ORT_ENABLE_STREAM is true #22589
[MigraphX] Fix potential synchronization problem when ORT_ENABLE_STREAM is true #22589
Conversation
/azp run Windows ARM64 QNN CI Pipeline,Windows x64 QNN CI Pipeline,Windows CPU CI Pipeline,Windows GPU CUDA CI Pipeline,Windows GPU DML CI Pipeline,Windows GPU Doc Gen CI Pipeline,Windows GPU TensorRT CI Pipeline,ONNX Runtime Web CI Pipeline,Linux CPU CI Pipeline,Linux CPU Minimal Build E2E CI Pipeline |
/azp run Linux GPU CI Pipeline,Linux GPU TensorRT CI Pipeline,Linux OpenVINO CI Pipeline,Linux QNN CI Pipeline,MacOS CI Pipeline,orttraining-amd-gpu-ci-pipeline,orttraining-linux-ci-pipeline,orttraining-linux-gpu-ci-pipeline,orttraining-ortmodule-distributed,onnxruntime-binary-size-checks-ci-pipeline |
/azp run Big Models,Linux Android Emulator QNN CI Pipeline,Android CI Pipeline,iOS CI Pipeline,ONNX Runtime React Native CI Pipeline |
Azure Pipelines successfully started running 5 pipeline(s). |
Azure Pipelines successfully started running 6 pipeline(s). |
Azure Pipelines successfully started running 10 pipeline(s). |
/azp run Windows ARM64 QNN CI Pipeline,Windows x64 QNN CI Pipeline,Windows CPU CI Pipeline,Windows GPU CUDA CI Pipeline,Windows GPU DML CI Pipeline,Windows GPU Doc Gen CI Pipeline,Windows GPU TensorRT CI Pipeline,ONNX Runtime Web CI Pipeline,Linux CPU CI Pipeline,Linux CPU Minimal Build E2E CI Pipeline |
/azp run Linux GPU CI Pipeline,Linux GPU TensorRT CI Pipeline,Linux OpenVINO CI Pipeline,Linux QNN CI Pipeline,MacOS CI Pipeline,orttraining-amd-gpu-ci-pipeline,orttraining-linux-ci-pipeline,orttraining-linux-gpu-ci-pipeline,orttraining-ortmodule-distributed,onnxruntime-binary-size-checks-ci-pipeline |
/azp run Big Models,Linux Android Emulator QNN CI Pipeline,Android CI Pipeline,iOS CI Pipeline,ONNX Runtime React Native CI Pipeline |
/azp run Windows ARM64 QNN CI Pipeline,Windows x64 QNN CI Pipeline,Windows CPU CI Pipeline,Windows GPU CUDA CI Pipeline,Windows GPU DML CI Pipeline,Windows GPU Doc Gen CI Pipeline,Windows GPU TensorRT CI Pipeline,ONNX Runtime Web CI Pipeline,Linux CPU CI Pipeline,Linux CPU Minimal Build E2E CI Pipeline |
/azp run Linux GPU CI Pipeline,Linux GPU TensorRT CI Pipeline,Linux OpenVINO CI Pipeline,Linux QNN CI Pipeline,MacOS CI Pipeline,orttraining-linux-ci-pipeline,orttraining-linux-gpu-ci-pipeline,onnxruntime-binary-size-checks-ci-pipeline,Big Models,Linux Android Emulator QNN CI Pipeline |
/azp run Android CI Pipeline,iOS CI Pipeline,ONNX Runtime React Native CI Pipeline,CoreML CI Pipeline,Linux DNNL CI Pipeline,Linux MIGraphX CI Pipeline,Linux ROCm CI Pipeline |
Azure Pipelines successfully started running 5 pipeline(s). |
Azure Pipelines successfully started running 6 pipeline(s). |
Azure Pipelines successfully started running 10 pipeline(s). |
Azure Pipelines successfully started running 7 pipeline(s). |
Azure Pipelines successfully started running 8 pipeline(s). |
Azure Pipelines successfully started running 10 pipeline(s). |
### Description Consolidate the gpu data transfer in CUDA, ROCm and Migraphx EP. (1) Remove some redundant stream synchronize on default stream according to spec of cudaMemcpy (2) consolidate CUDA, ROCm and MigrphaX to try use same logic. ### Motivation This is a follow up on reviewing #22589. ### Context https://docs.nvidia.com/cuda/cuda-runtime-api/api-sync-behavior.html#api-sync-behavior ##### cudaMemcpy() * For transfers from pageable host memory to device memory, a stream sync is performed before the copy is initiated. The function will return once the pageable buffer has been copied to the staging memory for DMA transfer to device memory, **but the DMA to final destination may not have completed**. * For transfers from pinned host memory to device memory, the function is synchronous with respect to the host. * For transfers from device to either pageable or pinned host memory, the function returns only once the copy has completed. * For transfers from device memory to device memory, **no host-side synchronization is performed**. * For transfers from any host memory to any host memory, the function is fully synchronous with respect to the host. #### cudaMemcpyAsync * For transfers between device memory and pageable host memory, the function might be synchronous with respect to host. * For transfers from any host memory to any host memory, the function is fully synchronous with respect to the host. * If pageable memory must first be staged to pinned memory, the driver may synchronize with the stream and stage the copy into pinned memory. * For all other transfers, the function should be fully asynchronous. https://rocm.docs.amd.com/projects/HIP/en/latest/doxygen/html/group___memory.html ##### hipMemcpyAsync() If host or dest are not pinned, the memory copy will be performed synchronously. For best performance, use hipHostMalloc to allocate host memory that is transferred asynchronously. on HCC hipMemcpyAsync does not support overlapped H2D and D2H copies. For hipMemcpy, the copy is always performed by the device associated with the specified stream. ##### hipMemcpy() For hipMemcpy, the copy is always performed by the current device (set by hipSetDevice). https://github.com/ROCm/ROCm/blob/roc-5.7.x/tools/autotag/templates/rocm_changes/5.6.1.md ROCm 5.6.1 release note: hipMemcpy device-to-device (intra device) is now asynchronous with respect to the host
…AM is true (microsoft#22589) ### Description Replace `hipMemcpy` with `hipMemcpyWithStream` ### Motivation and Context `hipMemcpy` uses default stream, which may be out of synchronization with the current stream when ORT_ENABLE_STREAM is defined.
### Description Consolidate the gpu data transfer in CUDA, ROCm and Migraphx EP. (1) Remove some redundant stream synchronize on default stream according to spec of cudaMemcpy (2) consolidate CUDA, ROCm and MigrphaX to try use same logic. ### Motivation This is a follow up on reviewing microsoft#22589. ### Context https://docs.nvidia.com/cuda/cuda-runtime-api/api-sync-behavior.html#api-sync-behavior ##### cudaMemcpy() * For transfers from pageable host memory to device memory, a stream sync is performed before the copy is initiated. The function will return once the pageable buffer has been copied to the staging memory for DMA transfer to device memory, **but the DMA to final destination may not have completed**. * For transfers from pinned host memory to device memory, the function is synchronous with respect to the host. * For transfers from device to either pageable or pinned host memory, the function returns only once the copy has completed. * For transfers from device memory to device memory, **no host-side synchronization is performed**. * For transfers from any host memory to any host memory, the function is fully synchronous with respect to the host. #### cudaMemcpyAsync * For transfers between device memory and pageable host memory, the function might be synchronous with respect to host. * For transfers from any host memory to any host memory, the function is fully synchronous with respect to the host. * If pageable memory must first be staged to pinned memory, the driver may synchronize with the stream and stage the copy into pinned memory. * For all other transfers, the function should be fully asynchronous. https://rocm.docs.amd.com/projects/HIP/en/latest/doxygen/html/group___memory.html ##### hipMemcpyAsync() If host or dest are not pinned, the memory copy will be performed synchronously. For best performance, use hipHostMalloc to allocate host memory that is transferred asynchronously. on HCC hipMemcpyAsync does not support overlapped H2D and D2H copies. For hipMemcpy, the copy is always performed by the device associated with the specified stream. ##### hipMemcpy() For hipMemcpy, the copy is always performed by the current device (set by hipSetDevice). https://github.com/ROCm/ROCm/blob/roc-5.7.x/tools/autotag/templates/rocm_changes/5.6.1.md ROCm 5.6.1 release note: hipMemcpy device-to-device (intra device) is now asynchronous with respect to the host
…AM is true (microsoft#22589) ### Description Replace `hipMemcpy` with `hipMemcpyWithStream` ### Motivation and Context `hipMemcpy` uses default stream, which may be out of synchronization with the current stream when ORT_ENABLE_STREAM is defined.
### Description Consolidate the gpu data transfer in CUDA, ROCm and Migraphx EP. (1) Remove some redundant stream synchronize on default stream according to spec of cudaMemcpy (2) consolidate CUDA, ROCm and MigrphaX to try use same logic. ### Motivation This is a follow up on reviewing microsoft#22589. ### Context https://docs.nvidia.com/cuda/cuda-runtime-api/api-sync-behavior.html#api-sync-behavior ##### cudaMemcpy() * For transfers from pageable host memory to device memory, a stream sync is performed before the copy is initiated. The function will return once the pageable buffer has been copied to the staging memory for DMA transfer to device memory, **but the DMA to final destination may not have completed**. * For transfers from pinned host memory to device memory, the function is synchronous with respect to the host. * For transfers from device to either pageable or pinned host memory, the function returns only once the copy has completed. * For transfers from device memory to device memory, **no host-side synchronization is performed**. * For transfers from any host memory to any host memory, the function is fully synchronous with respect to the host. #### cudaMemcpyAsync * For transfers between device memory and pageable host memory, the function might be synchronous with respect to host. * For transfers from any host memory to any host memory, the function is fully synchronous with respect to the host. * If pageable memory must first be staged to pinned memory, the driver may synchronize with the stream and stage the copy into pinned memory. * For all other transfers, the function should be fully asynchronous. https://rocm.docs.amd.com/projects/HIP/en/latest/doxygen/html/group___memory.html ##### hipMemcpyAsync() If host or dest are not pinned, the memory copy will be performed synchronously. For best performance, use hipHostMalloc to allocate host memory that is transferred asynchronously. on HCC hipMemcpyAsync does not support overlapped H2D and D2H copies. For hipMemcpy, the copy is always performed by the device associated with the specified stream. ##### hipMemcpy() For hipMemcpy, the copy is always performed by the current device (set by hipSetDevice). https://github.com/ROCm/ROCm/blob/roc-5.7.x/tools/autotag/templates/rocm_changes/5.6.1.md ROCm 5.6.1 release note: hipMemcpy device-to-device (intra device) is now asynchronous with respect to the host
…AM is true (microsoft#22589) ### Description Replace `hipMemcpy` with `hipMemcpyWithStream` ### Motivation and Context `hipMemcpy` uses default stream, which may be out of synchronization with the current stream when ORT_ENABLE_STREAM is defined.
### Description Consolidate the gpu data transfer in CUDA, ROCm and Migraphx EP. (1) Remove some redundant stream synchronize on default stream according to spec of cudaMemcpy (2) consolidate CUDA, ROCm and MigrphaX to try use same logic. ### Motivation This is a follow up on reviewing microsoft#22589. ### Context https://docs.nvidia.com/cuda/cuda-runtime-api/api-sync-behavior.html#api-sync-behavior ##### cudaMemcpy() * For transfers from pageable host memory to device memory, a stream sync is performed before the copy is initiated. The function will return once the pageable buffer has been copied to the staging memory for DMA transfer to device memory, **but the DMA to final destination may not have completed**. * For transfers from pinned host memory to device memory, the function is synchronous with respect to the host. * For transfers from device to either pageable or pinned host memory, the function returns only once the copy has completed. * For transfers from device memory to device memory, **no host-side synchronization is performed**. * For transfers from any host memory to any host memory, the function is fully synchronous with respect to the host. #### cudaMemcpyAsync * For transfers between device memory and pageable host memory, the function might be synchronous with respect to host. * For transfers from any host memory to any host memory, the function is fully synchronous with respect to the host. * If pageable memory must first be staged to pinned memory, the driver may synchronize with the stream and stage the copy into pinned memory. * For all other transfers, the function should be fully asynchronous. https://rocm.docs.amd.com/projects/HIP/en/latest/doxygen/html/group___memory.html ##### hipMemcpyAsync() If host or dest are not pinned, the memory copy will be performed synchronously. For best performance, use hipHostMalloc to allocate host memory that is transferred asynchronously. on HCC hipMemcpyAsync does not support overlapped H2D and D2H copies. For hipMemcpy, the copy is always performed by the device associated with the specified stream. ##### hipMemcpy() For hipMemcpy, the copy is always performed by the current device (set by hipSetDevice). https://github.com/ROCm/ROCm/blob/roc-5.7.x/tools/autotag/templates/rocm_changes/5.6.1.md ROCm 5.6.1 release note: hipMemcpy device-to-device (intra device) is now asynchronous with respect to the host
…AM is true (microsoft#22589) ### Description Replace `hipMemcpy` with `hipMemcpyWithStream` ### Motivation and Context `hipMemcpy` uses default stream, which may be out of synchronization with the current stream when ORT_ENABLE_STREAM is defined.
### Description Consolidate the gpu data transfer in CUDA, ROCm and Migraphx EP. (1) Remove some redundant stream synchronize on default stream according to spec of cudaMemcpy (2) consolidate CUDA, ROCm and MigrphaX to try use same logic. ### Motivation This is a follow up on reviewing microsoft#22589. ### Context https://docs.nvidia.com/cuda/cuda-runtime-api/api-sync-behavior.html#api-sync-behavior ##### cudaMemcpy() * For transfers from pageable host memory to device memory, a stream sync is performed before the copy is initiated. The function will return once the pageable buffer has been copied to the staging memory for DMA transfer to device memory, **but the DMA to final destination may not have completed**. * For transfers from pinned host memory to device memory, the function is synchronous with respect to the host. * For transfers from device to either pageable or pinned host memory, the function returns only once the copy has completed. * For transfers from device memory to device memory, **no host-side synchronization is performed**. * For transfers from any host memory to any host memory, the function is fully synchronous with respect to the host. #### cudaMemcpyAsync * For transfers between device memory and pageable host memory, the function might be synchronous with respect to host. * For transfers from any host memory to any host memory, the function is fully synchronous with respect to the host. * If pageable memory must first be staged to pinned memory, the driver may synchronize with the stream and stage the copy into pinned memory. * For all other transfers, the function should be fully asynchronous. https://rocm.docs.amd.com/projects/HIP/en/latest/doxygen/html/group___memory.html ##### hipMemcpyAsync() If host or dest are not pinned, the memory copy will be performed synchronously. For best performance, use hipHostMalloc to allocate host memory that is transferred asynchronously. on HCC hipMemcpyAsync does not support overlapped H2D and D2H copies. For hipMemcpy, the copy is always performed by the device associated with the specified stream. ##### hipMemcpy() For hipMemcpy, the copy is always performed by the current device (set by hipSetDevice). https://github.com/ROCm/ROCm/blob/roc-5.7.x/tools/autotag/templates/rocm_changes/5.6.1.md ROCm 5.6.1 release note: hipMemcpy device-to-device (intra device) is now asynchronous with respect to the host
Description
Replace
hipMemcpy
withhipMemcpyWithStream
Motivation and Context
hipMemcpy
uses default stream, which may be out of synchronization with the current stream when ORT_ENABLE_STREAM is defined.