diff --git a/cgmanifests/generated/cgmanifest.json b/cgmanifests/generated/cgmanifest.json
index 78db7d735dad9..cda0222322d14 100644
--- a/cgmanifests/generated/cgmanifest.json
+++ b/cgmanifests/generated/cgmanifest.json
@@ -306,7 +306,7 @@
       "component": {
         "type": "git",
         "git": {
-          "commitHash": "6f47420213f757831fae65c686aa471749fa8d60",
+          "commitHash": "7d49e6c7e2f8896c47f586706e67e1fb215529dc",
           "repositoryUrl": "https://github.com/NVIDIA/cutlass.git"
         },
         "comments": "cutlass"
diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt
index 5200b447d553f..61e0f7d6c574e 100644
--- a/cmake/CMakeLists.txt
+++ b/cmake/CMakeLists.txt
@@ -46,6 +46,11 @@ else()
   set(CMAKE_CXX_STANDARD 17)
 endif()
 
+if (MSVC)
+  # Make sure Visual Studio sets __cplusplus macro correctly: https://learn.microsoft.com/en-us/cpp/build/reference/zc-cplusplus
+  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /Zc:__cplusplus")
+endif()
+
 set_property(GLOBAL PROPERTY USE_FOLDERS ON)
 # NOTE: POSITION INDEPENDENT CODE hurts performance, and it only make sense on POSIX systems
 set(CMAKE_POSITION_INDEPENDENT_CODE ON)
diff --git a/cmake/deps.txt b/cmake/deps.txt
index 88c1881ad82fb..703988a1513eb 100644
--- a/cmake/deps.txt
+++ b/cmake/deps.txt
@@ -53,7 +53,7 @@ pytorch_cpuinfo;https://github.com/pytorch/cpuinfo/archive/959002f82d7962a473d8b
 re2;https://github.com/google/re2/archive/refs/tags/2024-05-01.tar.gz;206cfee5ee0b4c6844680ba66275e9e8faa77405
 safeint;https://github.com/dcleblanc/SafeInt/archive/refs/tags/3.0.28.zip;23f252040ff6cb9f1fd18575b32fa8fb5928daac
 tensorboard;https://github.com/tensorflow/tensorboard/archive/373eb09e4c5d2b3cc2493f0949dc4be6b6a45e81.zip;67b833913605a4f3f499894ab11528a702c2b381
-cutlass;https://github.com/NVIDIA/cutlass/archive/refs/tags/v3.1.0.zip;757f90a795034a89d4f48a79d1f009f7a04c8dee
+cutlass;https://github.com/NVIDIA/cutlass/archive/refs/tags/v3.5.0.zip;ae038931b9fc2c416c17d9cda91d9706b343f56d
 utf8_range;https://github.com/protocolbuffers/utf8_range/archive/72c943dea2b9240cd09efde15191e144bc7c7d38.zip;9925739c9debc0efa2adcb194d371a35b6a03156
 extensions;https://github.com/microsoft/onnxruntime-extensions/archive/94142d8391c9791ec71c38336436319a2d4ac7a0.zip;4365ac5140338b4cb75a39944a4be276e3829b3c
 composable_kernel;https://github.com/ROCmSoftwarePlatform/composable_kernel/archive/5356c4a943a35e74d7cdc69486afcb8703b9a59a.zip;522382c2af437e09124287e5879ab64af5b2e299
diff --git a/cmake/external/cutlass.cmake b/cmake/external/cutlass.cmake
index f04f4bec76cd5..1ece2e7a509ba 100644
--- a/cmake/external/cutlass.cmake
+++ b/cmake/external/cutlass.cmake
@@ -3,6 +3,7 @@ FetchContent_Declare(
   cutlass
   URL ${DEP_URL_cutlass}
   URL_HASH SHA1=${DEP_SHA1_cutlass}
+  PATCH_COMMAND ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/cutlass/cutlass_3.5.0.patch
 )
 
 FetchContent_GetProperties(cutlass)
diff --git a/cmake/onnxruntime_providers_cuda.cmake b/cmake/onnxruntime_providers_cuda.cmake
index 46bc5fb3bd1ac..3b48a40bf1166 100644
--- a/cmake/onnxruntime_providers_cuda.cmake
+++ b/cmake/onnxruntime_providers_cuda.cmake
@@ -175,6 +175,10 @@
       endif()
     endif()
 
+    if(MSVC)
+      target_compile_options(${target} PRIVATE "$<$:SHELL:-Xcompiler /Zc:__cplusplus>")
+    endif()
+
     onnxruntime_add_include_to_target(${target} onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers::flatbuffers)
     if (onnxruntime_ENABLE_TRAINING_OPS)
       onnxruntime_add_include_to_target(${target} onnxruntime_training)
diff --git a/cmake/patches/cutlass/cutlass_3.5.0.patch b/cmake/patches/cutlass/cutlass_3.5.0.patch
new file mode 100644
index 0000000000000..3b829d2f8b2cf
--- /dev/null
+++ b/cmake/patches/cutlass/cutlass_3.5.0.patch
@@ -0,0 +1,25 @@
+diff --git a/include/cutlass/functional.h b/include/cutlass/functional.h
+index 964d2ff3..b366bc14 100644
+--- a/include/cutlass/functional.h
++++ b/include/cutlass/functional.h
+@@ -39,6 +39,7 @@
+ #include "cutlass/numeric_types.h"
+
+ #include
++#include
+
+ #if defined(CUTLASS_ARCH_WMMA_ENABLED)
+ #include
+@@ -230,8 +231,12 @@ struct inverse_square_root {
+   CUTLASS_HOST_DEVICE
+   half_t operator()(half_t const &lhs) const {
+ #if defined(__CUDA_ARCH__)
++#if (__CUDA_ARCH__ >= 530)
+     auto result = hrsqrt(reinterpret_cast<__half const &>(lhs));
+     return reinterpret_cast(result);
++#else
++    return half_t::convert((rsqrtf(half_t::convert(lhs))));
++#endif
+ #else
+     return half_t(1.f / std::sqrt(half_t::convert(lhs)));
+ #endif
diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.cc b/onnxruntime/contrib_ops/cuda/bert/attention.cc
index 7a807342ad685..3e6edb162360d 100644
--- a/onnxruntime/contrib_ops/cuda/bert/attention.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/attention.cc
@@ -219,11 +219,9 @@ Status Attention::ComputeInternal(OpKernelContext* context) const {
         !disable_memory_efficient_attention_ &&
         nullptr == past &&
         nullptr == present &&
-        (parameters.head_size & 7) == 0 &&
-        (parameters.v_head_size & 7) == 0 &&
         (nullptr == mask_index || parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START) &&
         (sizeof(T) == 2 || parameters.sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32) &&
-        has_memory_efficient_attention(sm, sizeof(T) == 2);
+        has_memory_efficient_attention(sm, sizeof(T) == 2, parameters.head_size, parameters.v_head_size);
 
   if (use_memory_efficient_attention) {
     bool is_good_for_rpb = relative_position_bias != nullptr && parameters.sequence_length % (4 * sizeof(T)) == 0;
diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h
index c12cb374d9adf..a5de20e44be1a 100644
--- a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h
+++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h
@@ -75,12 +75,8 @@ struct RightPaddingBatchHook {
           batch_id * lse_dim * p.num_heads + head_id * lse_dim + query_start;
     }
 
-    // Custom masking
-    if (p.causal_diagonal_ptr) {
-      p.causal_diagonal_offset = p.causal_diagonal_ptr[batch_id];
-    }
     if (p.custom_mask_type == AttentionKernel::CausalFromBottomRight) {
-      p.causal_diagonal_offset += p.num_keys - p.num_queries;
+      p.causal_diagonal_offset = p.num_keys - p.num_queries;
     }
     if (p.custom_mask_type == AttentionKernel::CausalFromTopLeft ||
         p.custom_mask_type == AttentionKernel::CausalFromBottomRight) {
@@ -143,9 +139,10 @@ __global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm)
   AK::attention_kernel(p);
 }
 
-template
+template
 void LaunchCutlassFmha(const MemoryEfficientAttentionParams& params) {
-  using Attention = AttentionKernel;
+  constexpr bool dropout = false;
+  using Attention = AttentionKernel;
   typename Attention::Params p;
   {  // set parameters
     p.query_ptr = const_cast(reinterpret_cast(params.query));
@@ -220,6 +217,7 @@ void LaunchCutlassFmha(const MemoryEfficientAttentionParams& params) {
   }
 
   auto kernel_fn = attention_kernel_batched_impl;
+
   if (params.has_custom_right_padding) {
     kernel_fn = attention_kernel_batched_impl_right_padding;
   }
@@ -237,20 +235,23 @@ void LaunchCutlassFmha(const MemoryEfficientAttentionParams& params) {
   kernel_fn<<>>(p);
 }
 
-template
+template
 void DispatchIsAligned(const MemoryEfficientAttentionParams& params) {
-  using AlignedAK = AttentionKernel;
+  using AlignedAK = AttentionKernel;
 #if defined(_MSC_VER) && !defined(__clang__)
 #pragma warning(push)
 #pragma warning(disable : 6287 4189)  // kAligned is used via capture so 4189 warning seems incorrect
 #endif
+  // Run a more efficient kernel with `isAligned=True` when memory is correctly aligned.
   bool is_aligned = params.qk_head_size % AlignedAK::kAlignmentQ == 0 &&
                     params.qk_head_size % AlignedAK::kAlignmentK == 0 &&
                     params.v_head_size % AlignedAK::kAlignmentV == 0;
+
   DISPATCH_BOOL(is_aligned, kIsAligned, ([&]() {
-                  LaunchCutlassFmha(params);
+                  LaunchCutlassFmha(params);
                 }));
+
 #if defined(_MSC_VER) && !defined(__clang__)
 #pragma warning(pop)
 #endif
@@ -259,11 +260,11 @@ void DispatchIsAligned(const MemoryEfficientAttentionParams& params) {
 template
 void DispatchBlockSize(const MemoryEfficientAttentionParams& params) {
   if (params.v_head_size <= 64) {
-    DispatchIsAligned(params);
+    DispatchIsAligned(params);
   } else if (params.v_head_size <= 128) {
-    DispatchIsAligned(params);
+    DispatchIsAligned(params);
   } else {
-    DispatchIsAligned(params);
+    DispatchIsAligned(params);
   }
 }
diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h
index 484b783db1724..08a562a12b844 100644
--- a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h
+++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h
@@ -11,6 +11,8 @@ namespace onnxruntime {
 namespace contrib {
 namespace cuda {
 
+constexpr int kEfficientAttentionMaxHeadSize = 1024;
+
 struct MemoryEfficientAttentionParams {
   int32_t sm;
   bool is_half;
@@ -49,8 +51,11 @@ struct MemoryEfficientAttentionParams {
 
 void run_memory_efficient_attention(const MemoryEfficientAttentionParams& params);
 
-inline bool has_memory_efficient_attention(int32_t sm, bool is_half) {
-  return sm >= (is_half ? 53 : 50);
+inline bool has_memory_efficient_attention(int32_t sm, bool is_half, int qk_head_size, int v_head_size) {
+  return sm >= (is_half ? 53 : 50) &&
+         (qk_head_size & 7) == 0 &&
+         (v_head_size & 7) == 0 &&
+         qk_head_size <= kEfficientAttentionMaxHeadSize && v_head_size <= kEfficientAttentionMaxHeadSize;
 }
 
 void run_memory_efficient_attention_sm80(const MemoryEfficientAttentionParams& params);
diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h
index 028233f66850f..1fac03882b4b1 100644
--- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h
+++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h
@@ -10,8 +10,7 @@
 #endif
 
 #include
-#include
-#include
+#include
 
 #include
 #include
@@ -98,7 +97,6 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi
   constexpr int kBlockN = Kernel_traits::kBlockN;
   constexpr int kHeadDim = Kernel_traits::kHeadDim;
   constexpr int kNWarps = Kernel_traits::kNWarps;
-  constexpr int MMA_M = kBlockM / decltype(cute::size<0>(typename Kernel_traits::TiledMma::TiledShape_MNK{}))::value;
 
   const BlockInfo binfo(params, bidb);
   if (m_block * kBlockM >= binfo.actual_seqlen_q || binfo.actual_seqlen_k == 0) return;
diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/kernel_traits.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/kernel_traits.h
index 1c0ed7f2fc2e8..52a4e56491c5e 100644
--- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/kernel_traits.h
+++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/kernel_traits.h
@@ -3,7 +3,7 @@
 ******************************************************************************/
 #pragma once
 
-#include
+#include
 
 #include
 #include
@@ -32,10 +32,8 @@ struct Flash_kernel_traits {
       std::is_same_v,
       MMA_Atom,
       MMA_Atom>;
-  using ValLayoutMNK = cute::Layout>;
 #else
   using MMA_Atom_Arch = MMA_Atom;
-  using ValLayoutMNK = cute::Layout>;
 #endif
 
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
@@ -77,7 +75,7 @@ struct Flash_fwd_kernel_traits : public Base {
   using TiledMma = TiledMMA<
       typename Base::MMA_Atom_Arch,
       Layout, _1, _1>>,  // 4x1x1 or 8x1x1 thread group
-      typename Base::ValLayoutMNK>;  // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM
+      Tile, _16, _16>>;
 
   using SmemLayoutAtomQ = decltype(composition(Swizzle{},
       // This has to be kBlockKSmem, using kHeadDim gives wrong results for d=128
@@ -208,17 +206,17 @@ struct Flash_bwd_kernel_traits : public Base {
   using TiledMmaSdP = TiledMMA<
       typename Base::MMA_Atom_Arch,
       cute::Layout, cute::Int, _1>>,
-      typename Base::ValLayoutMNK>;  // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM
+      Tile, Int<16 * kNWarps / AtomLayoutMSdP>, _16>>;
 
   using TiledMmadKV = TiledMMA<
       typename Base::MMA_Atom_Arch,
       cute::Layout, cute::Int, _1>>,
-      typename Base::ValLayoutMNK>;  // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM
+      Tile, Int<16 * kNWarps / AtomLayoutNdKV>, _16>>;
 
   using TiledMmadQ = TiledMMA<
       typename Base::MMA_Atom_Arch,
       cute::Layout, cute::Int, _1>>,  // 2x4x1 or 4x2x1 thread group
-      typename Base::ValLayoutMNK>;  // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM
+      Tile, Int<16 * kNWarps / AtomLayoutMdQ>, _16>>;
 
   using SmemLayoutAtomQdO = decltype(composition(Swizzle{},
       cute::Layout>,
diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/utils.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/utils.h
index 271112c5e890a..7aefd4799bc4d 100644
--- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/utils.h
+++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/utils.h
@@ -13,8 +13,7 @@
 #include
 #endif
 
-#include
-#include
+#include
 
 #include
 #include
diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
index 3c968d6c8b347..0c26f04edef99 100644
--- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
@@ -161,9 +161,8 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const {
       !use_flash_attention &&
       !disable_memory_efficient_attention_ &&
       local_window_size_ == -1 &&
-      (parameters.head_size & 7) == 0 &&
      (sizeof(T) == 2 || parameters.sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32) &&
-      has_memory_efficient_attention(sm, sizeof(T) == 2);
+      has_memory_efficient_attention(sm, sizeof(T) == 2, parameters.head_size, parameters.head_size);
 
   if (!use_flash_attention && !use_memory_efficient_attention && local_window_size_ != -1) {
     return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Local attention UNSUPPORTED for sm < 80 on CUDA.");
diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc
index 2ef011cdd9a21..5ae7c149fa05c 100644
--- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc
@@ -235,17 +235,16 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const {
   bool is_good_for_rpb = relative_position_bias != nullptr && parameters.sequence_length % (4 * sizeof(T)) == 0;
 
-  bool use_memory_efficient_attention = !use_flash_attention &&
-                                        fused_runner == nullptr &&
-                                        fused_cross_attention_kernel == nullptr &&
-                                        !disable_memory_efficient_attention_ &&
-                                        (parameters.head_size & 7) == 0 &&
-                                        (parameters.v_head_size & 7) == 0 &&
-                                        is_long_sequence &&
-                                        !past_no_bias &&
-                                        (relative_position_bias == nullptr || is_good_for_rpb) &&
-                                        (nullptr == key_padding_mask || parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START) &&
-                                        has_memory_efficient_attention(sm, sizeof(T) == 2);
+  bool use_memory_efficient_attention =
+      !use_flash_attention &&
+      fused_runner == nullptr &&
+      fused_cross_attention_kernel == nullptr &&
+      !disable_memory_efficient_attention_ &&
+      is_long_sequence &&
+      !past_no_bias &&
+      (relative_position_bias == nullptr || is_good_for_rpb) &&
+      (nullptr == key_padding_mask || parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START) &&
+      has_memory_efficient_attention(sm, sizeof(T) == 2, parameters.head_size, parameters.v_head_size);
 #else
   constexpr bool use_memory_efficient_attention = false;
 #endif
diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_attention.cc b/onnxruntime/contrib_ops/cuda/bert/packed_attention.cc
index e4b90727121cf..0146cce30c7d1 100644
--- a/onnxruntime/contrib_ops/cuda/bert/packed_attention.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/packed_attention.cc
@@ -288,11 +288,10 @@ Status PackedAttention::ComputeInternal(OpKernelContext* context) const {
   if (nullptr == fused_runner) {
     int sm = device_prop.major * 10 + device_prop.minor;
     bool is_good_for_rpb = !parameters.has_relative_position_bias || parameters.sequence_length % (4 * sizeof(T)) == 0;
-    use_memory_efficient_attention = is_good_for_rpb &&
-                                     sizeof(T) == 2 &&  // only enable for fp16
-                                     (parameters.head_size & 7) == 0 &&
-                                     (parameters.v_head_size & 7) == 0 &&
-                                     has_memory_efficient_attention(sm, sizeof(T) == 2);
+    use_memory_efficient_attention =
+        is_good_for_rpb &&
+        sizeof(T) == 2 &&  // only enable for fp16
+        has_memory_efficient_attention(sm, sizeof(T) == 2, parameters.head_size, parameters.v_head_size);
   }
 
 #endif
diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.cc
index 00ab32886112b..3fbbafc01254e 100644
--- a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.cc
@@ -272,9 +272,7 @@ Status PackedMultiHeadAttention::ComputeInternal(OpKernelContext* context) co
     use_memory_efficient_attention =
         is_good_for_rpb &&
         (sizeof(T) == 2 || parameters.sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32) &&
-        (parameters.head_size & 7) == 0 &&
-        (parameters.v_head_size & 7) == 0 &&
-        has_memory_efficient_attention(sm, sizeof(T) == 2);
+        has_memory_efficient_attention(sm, sizeof(T) == 2, parameters.head_size, parameters.v_head_size);
   }
 #endif
diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/quantb_mma_multistage.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/quantb_mma_multistage.h
index 28364cc34f2d7..6e281241a3427 100644
--- a/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/quantb_mma_multistage.h
+++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/quantb_mma_multistage.h
@@ -490,7 +490,10 @@ class QuantBMmaMultistage :
     //   accuracy, where each mainloop iteration first accumulates into a temporary
     //   set of freshly-cleared accumulators, which are subsequently added to the
     //   final accumulator set.
-    static bool const kStagedAccumulation = arch::UseStagedAccumulation::value;
+
+    // Change the following to avoid build error: class "cutlass::arch::OpMultiplyAdd" has no member "ElementA".
+    // kStagedAccumulation = arch::detail::UseStagedAccumulation::value;
+    static bool const kStagedAccumulation = false;
   };
 
 private:
diff --git a/onnxruntime/core/providers/tensorrt/nv_includes.h b/onnxruntime/core/providers/tensorrt/nv_includes.h
index c3e9f7a3a2a77..047f325f49b70 100644
--- a/onnxruntime/core/providers/tensorrt/nv_includes.h
+++ b/onnxruntime/core/providers/tensorrt/nv_includes.h
@@ -2,12 +2,11 @@
 // Licensed under the MIT License.
 #pragma once
 
-// File to include the required TRT headers with workarounds for warnings we can't fix.
-
-// Ignore warning C4100: unreferenced formal parameter
+// File to include the required TRT headers with workarounds for warnings we can't fix or not fixed yet.
 #if defined(_MSC_VER)
 #pragma warning(push)
-#pragma warning(disable : 4100)
+#pragma warning(disable : 4100)  // Ignore warning C4100: unreferenced formal parameter
+#pragma warning(disable : 4996)  // Ignore warning C4996: 'nvinfer1::IPluginV2' was declared deprecated
 #endif
 
 #include
diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
index dff74a404a456..9c2db494f0e41 100644
--- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
+++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
@@ -3142,7 +3142,15 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView
       if (mem_size > max_ctx_mem_size_) {
         max_ctx_mem_size_ = mem_size;
       }
+
+#if defined(_MSC_VER)
+#pragma warning(push)
+#pragma warning(disable : 4996)  // nvinfer1::ICudaEngine::createExecutionContextWithoutDeviceMemory was deprecated
+#endif
       trt_context = std::unique_ptr(trt_engine->createExecutionContextWithoutDeviceMemory());
+#if defined(_MSC_VER)
+#pragma warning(pop)
+#endif
     } else {
       trt_context = std::unique_ptr(trt_engine->createExecutionContext());
     }
@@ -3588,8 +3596,15 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView
 
     if (context_update) {
       if (trt_state->context_memory_sharing_enable) {
+#if defined(_MSC_VER)
+#pragma warning(push)
+#pragma warning(disable : 4996)  // nvinfer1::ICudaEngine::createExecutionContextWithoutDeviceMemory was deprecated
+#endif
        *(trt_state->context) = std::unique_ptr(
            trt_state->engine->get()->createExecutionContextWithoutDeviceMemory());
+#if defined(_MSC_VER)
+#pragma warning(pop)
+#endif
      } else {
        *(trt_state->context) = std::unique_ptr(
            trt_state->engine->get()->createExecutionContext());
@@ -3805,7 +3820,14 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(con
       if (mem_size > max_ctx_mem_size_) {
         max_ctx_mem_size_ = mem_size;
       }
+#if defined(_MSC_VER)
+#pragma warning(push)
+#pragma warning(disable : 4996)  // nvinfer1::ICudaEngine::createExecutionContextWithoutDeviceMemory was deprecated
+#endif
       trt_context = std::unique_ptr(trt_engine->createExecutionContextWithoutDeviceMemory());
+#if defined(_MSC_VER)
+#pragma warning(pop)
+#endif
     } else {
       trt_context = std::unique_ptr(trt_engine->createExecutionContext());
     }
diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc
index 58a1afd005563..a4d2d6c9d65f3 100644
--- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc
+++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc
@@ -60,6 +60,11 @@ common::Status CreateTensorRTCustomOpDomainList(std::vector&
     TensorrtLogger trt_logger = GetTensorrtLogger(false);
     initLibNvInferPlugins(&trt_logger, "");
 
+#if defined(_MSC_VER)
+#pragma warning(push)
+#pragma warning(disable : 4996)  // Ignore warning C4996: 'nvinfer1::*' was declared deprecated
+#endif
+
     int num_plugin_creator = 0;
     auto plugin_creators = getPluginRegistry()->getPluginCreatorList(&num_plugin_creator);
     std::unordered_set registered_plugin_names;
@@ -79,6 +84,11 @@ common::Status CreateTensorRTCustomOpDomainList(std::vector&
       custom_op_domain->custom_ops_.push_back(created_custom_op_list.back().get());
       registered_plugin_names.insert(plugin_name);
     }
+
+#if defined(_MSC_VER)
+#pragma warning(pop)
+#endif
+
     custom_op_domain->domain_ = "trt.plugins";
     domain_list.push_back(custom_op_domain.get());
   } catch (const std::exception&) {
diff --git a/onnxruntime/test/unittest_main/test_main.cc b/onnxruntime/test/unittest_main/test_main.cc
index d7e8bf9063645..b7c3b38538421 100644
--- a/onnxruntime/test/unittest_main/test_main.cc
+++ b/onnxruntime/test/unittest_main/test_main.cc
@@ -36,6 +36,7 @@ void ortenv_setup() {
 #if defined(_MSC_VER)
 #pragma warning(push)
 #pragma warning(disable : 4100)  // Ignore warning C4100: unreferenced format parameter.
+#pragma warning(disable : 4996)  // Ignore warning C4996: 'nvinfer1::IPluginV2' was declared deprecated
 #endif
 
 // TensorRT will load/unload libraries as builder objects are created and torn down. This will happen for
diff --git a/tools/ci_build/github/azure-pipelines/templates/download-deps.yml b/tools/ci_build/github/azure-pipelines/templates/download-deps.yml
index e7b230008dad4..f97fe5ef751e5 100644
--- a/tools/ci_build/github/azure-pipelines/templates/download-deps.yml
+++ b/tools/ci_build/github/azure-pipelines/templates/download-deps.yml
@@ -11,7 +11,7 @@ steps:
     packageType: upack
     feed: '/7424c8e4-5c62-490e-95c4-79446f31017c'
     definition: '517c4f6f-5437-4392-a70d-4f15ec5be2f0'
-    version: 1.0.156
+    version: 1.0.157
     downloadPath: $(Build.BinariesDirectory)/deps
 
 # The private ADO project
@@ -22,7 +22,7 @@ steps:
     packageType: upack
     feed: '/4c7631f5-24c0-4307-8822-1aa8f180c325'
     definition: 'fd9dd5ad-b73e-4678-890e-edcf680dbc1a'
-    version: 1.0.156
+    version: 1.0.157
     downloadPath: $(Build.BinariesDirectory)/deps
 
 # You can add more ADO accounts at here.