diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake
index a4ba85e868896..f6ace371531f9 100644
--- a/cmake/onnxruntime_unittests.cmake
+++ b/cmake/onnxruntime_unittests.cmake
@@ -1128,7 +1128,8 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP)
     ${BENCHMARK_DIR}/gelu.cc
     ${BENCHMARK_DIR}/activation.cc
     ${BENCHMARK_DIR}/quantize.cc
-    ${BENCHMARK_DIR}/reduceminmax.cc)
+    ${BENCHMARK_DIR}/reduceminmax.cc
+    ${BENCHMARK_DIR}/layer_normalization.cc)
   target_include_directories(onnxruntime_benchmark PRIVATE ${ONNXRUNTIME_ROOT} ${onnxruntime_graph_header} ${ONNXRUNTIME_ROOT}/core/mlas/inc)
   target_compile_definitions(onnxruntime_benchmark PRIVATE BENCHMARK_STATIC_DEFINE)
   if(WIN32)
diff --git a/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc b/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc
index faf78cae80ee1..67b4950af73bf 100644
--- a/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc
+++ b/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc
@@ -2,6 +2,7 @@
 // Licensed under the MIT License.

 #include "core/framework/tensor.h"
+#include "core/mlas/inc/mlas.h"
 #include "core/util/math_cpuonly.h"
 #include "core/providers/common.h"
 #include "core/platform/threadpool.h"
@@ -36,52 +37,188 @@ REGISTER_KERNEL_TYPED(float)
 REGISTER_KERNEL_TYPED(double)
 REGISTER_KERNEL_TYPED(MLFloat16)

-// Utility to convert from MLFloat16 to float only when the input type is MLFloat16.
-template <typename T, typename Ret>
-ORT_FORCEINLINE Ret ConvertMLFloat16ToDoubleOrFloatIfNeeded(T val);
-
-template <>
-ORT_FORCEINLINE float ConvertMLFloat16ToDoubleOrFloatIfNeeded<MLFloat16, float>(MLFloat16 val) {
-  return val.ToFloat();
-}
-
-template <>
-ORT_FORCEINLINE double ConvertMLFloat16ToDoubleOrFloatIfNeeded<MLFloat16, double>(MLFloat16 val) {
-  return static_cast<double>(ConvertMLFloat16ToDoubleOrFloatIfNeeded<MLFloat16, float>(val));
+namespace {
+
+template <typename T, typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double>, void>>
+void ComputeJob(
+    const T* input_data,
+    const T* skip_data,
+    const T* gamma_data,
+    const T* beta_data,
+    const T* bias_data,
+    IAllocatorUniquePtr<float>& skip_float_uptr,
+    IAllocatorUniquePtr<float>& gamma_float_uptr,
+    IAllocatorUniquePtr<float>& beta_float_uptr,
+    IAllocatorUniquePtr<float>& bias_float_uptr,
+    ptrdiff_t task_idx,
+    int hidden_size,
+    int64_t skip_size,
+    float epsilon,
+    bool simplified,
+    T* output_data,
+    T* skip_input_bias_add_output_data,
+    AllocatorPtr alloc) {
+  ORT_UNUSED_PARAMETER(skip_float_uptr);   // only used in MLFloat16 overload
+  ORT_UNUSED_PARAMETER(gamma_float_uptr);  // only used in MLFloat16 overload
+  ORT_UNUSED_PARAMETER(beta_float_uptr);   // only used in MLFloat16 overload
+  ORT_UNUSED_PARAMETER(bias_float_uptr);   // only used in MLFloat16 overload
+  ORT_UNUSED_PARAMETER(alloc);
+
+  auto offset = task_idx * hidden_size;
+  const T* p_input = input_data + offset;
+  const T* p_skip = skip_data + (offset % skip_size);
+  T* p_output = output_data + offset;
+  T* p_skip_input_bias_add_output = skip_input_bias_add_output_data == nullptr ? nullptr : skip_input_bias_add_output_data + offset;
+
+  T mean(0.0f);
+  T mean_square(0.0f);
+
+  for (decltype(hidden_size) h = 0; h < hidden_size; h++) {
+    T val = p_input[h] + p_skip[h];
+
+    if (nullptr != bias_data) {
+      val += bias_data[h];
+    }
+
+    if (nullptr != p_skip_input_bias_add_output) {
+      p_skip_input_bias_add_output[h] = val;
+    }
+
+    p_output[h] = val;
+    mean += val;
+    mean_square += val * val;
+  }
+
+  mean = mean / hidden_size;
+  if (simplified) {
+    mean_square = sqrt(mean_square / hidden_size + epsilon);
+  } else {
+    mean_square = sqrt(mean_square / hidden_size - mean * mean + epsilon);
+  }
+
+  for (decltype(hidden_size) h = 0; h < hidden_size; h++) {
+    if (simplified) {
+      p_output[h] = p_output[h] / mean_square * gamma_data[h];
+    } else if (nullptr == beta_data) {
+      p_output[h] = (p_output[h] - mean) / mean_square * gamma_data[h];
+    } else {
+      p_output[h] = (p_output[h] - mean) / mean_square * gamma_data[h] + beta_data[h];
+    }
+  }
 }

-template <>
-ORT_FORCEINLINE constexpr float ConvertMLFloat16ToDoubleOrFloatIfNeeded<float, float>(float val) {
-  return val;
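+// Overload for MLFloat16. Converts the fp16 inputs to fp32 with MLAS, performs the skip + bias add and the
+// normalization in fp32, then converts the result back to fp16. The fp32 copies of skip/gamma/beta/bias are
+// cached in the caller-provided buffers, so inputs not already converted in PrePack are converted on first use
+// and reused afterwards.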
+void ComputeJob(
+    const MLFloat16* input_data,
+    const MLFloat16* skip_data,
+    const MLFloat16* gamma_data,
+    const MLFloat16* beta_data,
+    const MLFloat16* bias_data,
+    IAllocatorUniquePtr<float>& skip_float_uptr,
+    IAllocatorUniquePtr<float>& gamma_float_uptr,
+    IAllocatorUniquePtr<float>& beta_float_uptr,
+    IAllocatorUniquePtr<float>& bias_float_uptr,
+    ptrdiff_t task_idx,
+    int hidden_size,
+    int64_t skip_size,
+    float epsilon,
+    bool simplified,
+    MLFloat16* output_data,
+    MLFloat16* skip_input_bias_add_output_data,
+    AllocatorPtr alloc) {
+  auto offset = task_idx * hidden_size;
+  const MLFloat16* p_input = input_data + offset;
+  const MLFloat16* p_skip = skip_data + (offset % skip_size);
+  MLFloat16* p_output = output_data + offset;
+  MLFloat16* p_skip_input_bias_add_output = skip_input_bias_add_output_data == nullptr ? nullptr : skip_input_bias_add_output_data + offset;
+
+  float mean(0.0f);
+  float mean_square(0.0f);
+  const size_t num_elems = static_cast<size_t>(hidden_size);
+
+  IAllocatorUniquePtr<float> input_float_uptr = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
+  MlasConvertHalfToFloatBuffer(p_input, input_float_uptr.get(), num_elems);
+
+  if (!skip_float_uptr) {
+    skip_float_uptr = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
+    MlasConvertHalfToFloatBuffer(p_skip, skip_float_uptr.get(), num_elems);
+  }
+
+  if (bias_data && !bias_float_uptr) {
+    bias_float_uptr = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
+    MlasConvertHalfToFloatBuffer(bias_data, bias_float_uptr.get(), num_elems);
+  }
+
+  IAllocatorUniquePtr<float> output_float_uptr = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
+  float* output_float_ptr = output_float_uptr.get();
+
+  const float* input_float_ptr = input_float_uptr.get();
+  const float* skip_float_ptr = skip_float_uptr.get();
+  const float* bias_float_ptr = bias_float_uptr.get();
+  for (size_t h = 0; h < num_elems; h++) {
+    float val = input_float_ptr[h] + skip_float_ptr[h];
+
+    if (bias_float_uptr) {
+      val += bias_float_ptr[h];
+    }
+
+    output_float_ptr[h] = val;
+    mean += val;
+    mean_square += val * val;
+  }
+
+  if (nullptr != p_skip_input_bias_add_output) {
+    MlasConvertFloatToHalfBuffer(output_float_ptr, p_skip_input_bias_add_output, num_elems);
+  }
+
+  mean = mean / hidden_size;
+  if (simplified) {
+    mean_square = sqrt(mean_square / hidden_size + epsilon);
+  } else {
+    mean_square = sqrt(mean_square / hidden_size - mean * mean + epsilon);
+  }
+
+  if (!gamma_float_uptr) {
+    gamma_float_uptr = std::move(input_float_uptr);  // overwrite input with gamma values, since they have the same size
+    MlasConvertHalfToFloatBuffer(gamma_data, gamma_float_uptr.get(), num_elems);
+  }
+
+  if (beta_data && !beta_float_uptr) {
+    beta_float_uptr = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
+    MlasConvertHalfToFloatBuffer(beta_data, beta_float_uptr.get(), num_elems);
+  }
+
+  const float* gamma_float_ptr = gamma_float_uptr.get();
+  const float* beta_float_ptr = beta_float_uptr.get();
+  for (size_t h = 0; h < num_elems; h++) {
+    if (simplified) {
+      output_float_ptr[h] = output_float_ptr[h] / mean_square * gamma_float_ptr[h];
+    } else if (nullptr == beta_float_uptr) {
+      output_float_ptr[h] = (output_float_ptr[h] - mean) / mean_square * gamma_float_ptr[h];
+    } else {
+      output_float_ptr[h] = (output_float_ptr[h] - mean) / mean_square * gamma_float_ptr[h] + beta_float_ptr[h];
+    }
+  }
+
+  MlasConvertFloatToHalfBuffer(output_float_ptr, p_output, num_elems);
 }

-template <>
-ORT_FORCEINLINE constexpr double ConvertMLFloat16ToDoubleOrFloatIfNeeded<double, double>(double val) {
-  return val;
-}
+void ConvertMLFloat16ToFloatIfNeeded(const Tensor& tensor, AllocatorPtr alloc, IAllocatorUniquePtr<float>& dest, bool& is_packed) {
+  if (tensor.GetElementType() == utils::ToTensorProtoElementType<MLFloat16>()) {
+    auto tensor_data_ptr = tensor.Data<MLFloat16>();
+    auto tensor_size = static_cast<size_t>(tensor.Shape().Size());
+    auto float_ptr = IAllocator::MakeUniquePtr<float>(alloc, tensor_size, true);

-// Function template that only converts the input value to MLFloat16 if T is MLFloat16.
-template <typename T>
-ORT_FORCEINLINE constexpr typename std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double>, T>
-ConvertDoubleOrFloatToMLFloat16IfNeeded(T val) {
-  return val;
+    MlasConvertHalfToFloatBuffer(tensor_data_ptr, float_ptr.get(), tensor_size);
+    dest = std::move(float_ptr);
+    is_packed = true;
+  }
 }

-template <typename T>
-ORT_FORCEINLINE constexpr typename std::enable_if_t<std::is_same_v<T, MLFloat16>, T>
-ConvertDoubleOrFloatToMLFloat16IfNeeded(float val) {
-  return MLFloat16(val);
-}
-
-template <typename T>
-ORT_FORCEINLINE constexpr typename std::enable_if_t<std::is_same_v<T, MLFloat16>, T>
-ConvertDoubleOrFloatToMLFloat16IfNeeded(double val) {
-  return MLFloat16(static_cast<float>(val));
-}
+}  // namespace

 template <typename T, bool simplified>
 SkipLayerNorm<T, simplified>::SkipLayerNorm(const OpKernelInfo& op_kernel_info)
-    : OpKernel(op_kernel_info) {
+    : OpKernel(op_kernel_info), skip_fp32_(nullptr), gamma_fp32_(nullptr), beta_fp32_(nullptr), bias_fp32_(nullptr) {
   ORT_ENFORCE(op_kernel_info.GetAttr<float>("epsilon", &epsilon_).IsOK());
   ORT_ENFORCE(epsilon_ >= 0);
 }
@@ -94,8 +231,7 @@ Status SkipLayerNorm<T, simplified>::Compute(OpKernelContext* p_ctx) const {
   const Tensor* beta = p_ctx->Input<Tensor>(3);
   const Tensor* bias = p_ctx->Input<Tensor>(4);
   Tensor* output = p_ctx->Output(0, input->Shape());
-  // For inferencing, we support one more optional output which is the sum
-  // of the input and skip tensors
+  // For inferencing, we support one more optional output which is the sum of the input and skip tensors
   Tensor* skip_input_bias_add_output = p_ctx->Output(3, input->Shape());

   const auto& input_dims = input->Shape().GetDims();
@@ -120,75 +256,44 @@ Status SkipLayerNorm<T, simplified>::Compute(OpKernelContext* p_ctx) const {
   T* output_data = output->MutableData<T>();

-  // For inferencing, we support one more optional output which is the sum
-  // of the input and skip tensors
-  T* skip_input_bias_add_output_data = skip_input_bias_add_output != nullptr ? skip_input_bias_add_output->MutableData<T>() : nullptr;
+  // For inferencing, we support one more optional output which is the sum of the input and skip tensors
+  T* skip_input_bias_add_output_data = skip_input_bias_add_output == nullptr ? nullptr : skip_input_bias_add_output->MutableData<T>();

-  const auto& skip_size = skip->Shape().Size();
+  const int64_t& skip_size = skip->Shape().Size();
+
+  AllocatorPtr alloc;
+  ORT_RETURN_IF_ERROR(p_ctx->GetTempSpaceAllocator(&alloc));

   concurrency::ThreadPool::TryBatchParallelFor(
       p_ctx->GetOperatorThreadPool(), static_cast<int32_t>(task_count),
       [&](ptrdiff_t task_idx) {
-        auto offset = task_idx * hidden_size;
-
-        const T* p_input = input_data + offset;
-        const T* p_skip = skip_data + (offset % skip_size);
-        T* p_output = output_data + offset;
-        T* p_skip_input_bias_add_output_data = skip_input_bias_add_output_data != nullptr ? skip_input_bias_add_output_data + offset : nullptr;
-
-        using DoubleOrFloat = typename std::conditional<
-            std::is_same<T, double>::value,  // If T is double
-            double,                          // Use double
-            float                            // Otherwise, use float (covers float and MLFloat16)
-            >::type;
-
-        DoubleOrFloat mean(0.0f);
-        DoubleOrFloat mean_square(0.0f);
-
-        std::unique_ptr<DoubleOrFloat[]> output_buffer = std::make_unique<DoubleOrFloat[]>(hidden_size);
-        for (size_t h = 0; h < static_cast<size_t>(hidden_size); h++) {
-          DoubleOrFloat input_value = ConvertMLFloat16ToDoubleOrFloatIfNeeded<T, DoubleOrFloat>(p_input[h]);
-          DoubleOrFloat skip_value = ConvertMLFloat16ToDoubleOrFloatIfNeeded<T, DoubleOrFloat>(p_skip[h]);
-
-          DoubleOrFloat value = input_value + skip_value;
-
-          if (nullptr != bias_data) {
-            value += ConvertMLFloat16ToDoubleOrFloatIfNeeded<T, DoubleOrFloat>(bias_data[h]);
-          }
-
-          output_buffer[h] = value;
-          T converted_value = ConvertDoubleOrFloatToMLFloat16IfNeeded<T>(value);
-          if (nullptr != p_skip_input_bias_add_output_data) {
-            p_skip_input_bias_add_output_data[h] = converted_value;
-          }
-
-          mean += value;
-          mean_square += value * value;
-        }
-
-        mean = mean / hidden_size;
-        if (simplified) {
-          mean_square = sqrt(mean_square / hidden_size + epsilon_);
-        } else {
-          mean_square = sqrt(mean_square / hidden_size - mean * mean + epsilon_);
-        }
-
-        for (size_t h = 0; h < static_cast<size_t>(hidden_size); h++) {
-          DoubleOrFloat gamma_value = ConvertMLFloat16ToDoubleOrFloatIfNeeded<T, DoubleOrFloat>(gamma_data[h]);
-          if (simplified) {
-            p_output[h] = ConvertDoubleOrFloatToMLFloat16IfNeeded<T>(output_buffer[h] / mean_square * gamma_value);
-          } else if (nullptr == beta_data) {
-            p_output[h] = ConvertDoubleOrFloatToMLFloat16IfNeeded<T>((output_buffer[h] - mean) / mean_square * gamma_value);
-          } else {
-            DoubleOrFloat beta_value = ConvertMLFloat16ToDoubleOrFloatIfNeeded<T, DoubleOrFloat>(beta_data[h]);
-            p_output[h] = ConvertDoubleOrFloatToMLFloat16IfNeeded<T>((output_buffer[h] - mean) / mean_square * gamma_value + beta_value);
-          }
-        }
+        ComputeJob(input_data, skip_data, gamma_data, beta_data, bias_data, skip_fp32_, gamma_fp32_, beta_fp32_,
+                   bias_fp32_, task_idx, hidden_size, skip_size, epsilon_, simplified, output_data,
+                   skip_input_bias_add_output_data, alloc);
       },
       0);

   return Status::OK();
 }

+template <typename T, bool simplified>
+Status SkipLayerNorm<T, simplified>::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
+                                             bool& is_packed, PrePackedWeights* prepacked_weights) {
+  ORT_UNUSED_PARAMETER(prepacked_weights);
+
+  is_packed = false;
+  if (input_idx == 1) {  // skip
+    ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, skip_fp32_, is_packed);
+  } else if (input_idx == 2) {  // gamma
+    ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, gamma_fp32_, is_packed);
+  } else if (input_idx == 3) {  // beta
+    ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, beta_fp32_, is_packed);
+  } else if (input_idx == 4) {  // bias
+    ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, bias_fp32_, is_packed);
+  }
+
+  return Status::OK();
+}
+
 }  // namespace contrib
 }  // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cpu/skip_layer_norm.h b/onnxruntime/contrib_ops/cpu/skip_layer_norm.h
index 69edf4609e340..08e2276c3d9d5 100644
--- a/onnxruntime/contrib_ops/cpu/skip_layer_norm.h
+++ b/onnxruntime/contrib_ops/cpu/skip_layer_norm.h
@@ -16,8 +16,15 @@ class SkipLayerNorm final : public OpKernel {
   SkipLayerNorm(const OpKernelInfo& op_kernel_info);
   Status Compute(OpKernelContext* p_op_kernel_context) const override;

+  Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
+                 bool& is_packed, PrePackedWeights* prepacked_weights) override;
+
  private:
   float epsilon_;
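+  // fp32 copies of the fp16 skip/gamma/beta/bias inputs, filled in PrePack (for initializers) or on first use
+  // inside Compute and then reused; they remain empty for the float/double kernels.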
+  mutable IAllocatorUniquePtr<float> skip_fp32_;
+  mutable IAllocatorUniquePtr<float> gamma_fp32_;
+  mutable IAllocatorUniquePtr<float> beta_fp32_;
+  mutable IAllocatorUniquePtr<float> bias_fp32_;
 };

 }  // namespace contrib
diff --git a/onnxruntime/core/providers/cpu/nn/layer_norm_impl.cc b/onnxruntime/core/providers/cpu/nn/layer_norm_impl.cc
index 23630dcb63efa..f73efcddcedd4 100644
--- a/onnxruntime/core/providers/cpu/nn/layer_norm_impl.cc
+++ b/onnxruntime/core/providers/cpu/nn/layer_norm_impl.cc
@@ -5,6 +5,7 @@

 #include "core/common/safeint.h"
 #include "core/framework/tensor.h"
+#include "core/mlas/inc/mlas.h"
 #include "core/platform/threadpool.h"
 #include "core/providers/common.h"
 #include "core/util/force_inline.h"
@@ -12,66 +13,166 @@ namespace onnxruntime {

-// Utility to convert from MLFloat16 to float only when the input type is MLFloat16.
-template <typename T, typename Ret>
-ORT_FORCEINLINE Ret ConvertMLFloat16ToDoubleOrFloatIfNeeded(T val);
+namespace {

-template <>
-ORT_FORCEINLINE float ConvertMLFloat16ToDoubleOrFloatIfNeeded<MLFloat16, float>(MLFloat16 val) {
-  return val.ToFloat();
-}
+template <typename T, typename U, typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double>, void>>
+void ComputeJob(
+    const T* X_data,
+    const T* scale_data,
+    const T* bias_data,
+    const ptrdiff_t task_idx,
+    const int64_t norm_size,
+    IAllocatorUniquePtr<float>& scale_float_uptr,
+    IAllocatorUniquePtr<float>& bias_float_uptr,
+    float epsilon,
+    bool simplified,
+    T* Y_data,
+    U* mean_data,
+    U* inv_std_dev_data,
+    AllocatorPtr alloc) {
+  ORT_UNUSED_PARAMETER(scale_float_uptr);  // only used in MLFloat16 overload
+  ORT_UNUSED_PARAMETER(bias_float_uptr);   // only used in MLFloat16 overload
+  ORT_UNUSED_PARAMETER(alloc);
+
+  const T* p_input = X_data + task_idx * norm_size;
+  T* p_output = Y_data + task_idx * norm_size;
+
+  T mean(0.0f);
+  T mean_square(0.0f);
+
+  for (int64_t h = 0; h < norm_size; h++) {
+    p_output[h] = p_input[h];
+    mean += p_input[h];
+    mean_square += p_input[h] * p_input[h];
+  }

-template <>
-ORT_FORCEINLINE double ConvertMLFloat16ToDoubleOrFloatIfNeeded<MLFloat16, double>(MLFloat16 val) {
-  return double(ConvertMLFloat16ToDoubleOrFloatIfNeeded<MLFloat16, float>(val));
-}
+  mean = mean / norm_size;
+  if (simplified) {
+    mean_square = sqrt(mean_square / norm_size + epsilon);
+  } else {
+    mean_square = sqrt(mean_square / norm_size - mean * mean + epsilon);
+  }

-template <>
-ORT_FORCEINLINE constexpr float ConvertMLFloat16ToDoubleOrFloatIfNeeded<float, float>(float val) {
-  return val;
-}
+  for (int64_t h = 0; h < norm_size; h++) {
+    if (simplified) {
+      p_output[h] = p_output[h] / mean_square * scale_data[h];
+    } else if (nullptr == bias_data) {
+      p_output[h] = (p_output[h] - mean) / mean_square * scale_data[h];
+    } else {
+      p_output[h] = (p_output[h] - mean) / mean_square * scale_data[h] + bias_data[h];
+    }
+  }

-template <>
-ORT_FORCEINLINE constexpr double ConvertMLFloat16ToDoubleOrFloatIfNeeded<double, double>(double val) {
-  return val;
-}
+  if (mean_data != nullptr) {
+    // ONNX spec doesn't support 'double' for 'U' so when 'T' == double, 'U' == float and we need to narrow
+    mean_data[task_idx] = gsl::narrow_cast<U>(mean);
+  }

-ORT_FORCEINLINE constexpr float ConvertToFloatIfNeeded(float val) {
-  return val;
+  if (inv_std_dev_data != nullptr) {
+    inv_std_dev_data[task_idx] = gsl::narrow_cast<U>(1 / mean_square);
+  }
 }

-ORT_FORCEINLINE constexpr float ConvertToFloatIfNeeded(double val) {
-  // ONNX spec doesn't support 'double' for 'Ret' so when 'T' == double, 'Ret' == float and we need to narrow
-  return gsl::narrow_cast<float>(val);
-}
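+// Overload for MLFloat16 input: converts X to fp32 with MLAS, runs the normalization in fp32, converts the
+// result back to fp16, and caches the fp32 copies of scale/bias in the provided buffers for reuse.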
+template <typename U>
+void ComputeJob(
+    const MLFloat16* X_data,
+    const MLFloat16* scale_data,
+    const MLFloat16* bias_data,
+    const ptrdiff_t task_idx,
+    const int64_t norm_size,
+    IAllocatorUniquePtr<float>& scale_float_uptr,
+    IAllocatorUniquePtr<float>& bias_float_uptr,
+    float epsilon,
+    bool simplified,
+    MLFloat16* Y_data,
+    U* mean_data,
+    U* inv_std_dev_data,
+    AllocatorPtr alloc) {
+  const MLFloat16* p_input = X_data + task_idx * norm_size;
+  MLFloat16* p_output = Y_data + task_idx * norm_size;
+
+  float mean(0.0f);
+  float mean_square(0.0f);
+
+  const size_t num_elems = static_cast<size_t>(norm_size);
+  IAllocatorUniquePtr<float> input_float_uptr = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
+  MlasConvertHalfToFloatBuffer(p_input, input_float_uptr.get(), num_elems);
+
+  IAllocatorUniquePtr<float> output_float_uptr = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
+  float* output_float_ptr = output_float_uptr.get();
+
+  const float* input_float_ptr = input_float_uptr.get();
+  for (size_t h = 0; h < num_elems; h++) {
+    output_float_ptr[h] = input_float_ptr[h];
+    mean += input_float_ptr[h];
+    mean_square += input_float_ptr[h] * input_float_ptr[h];
+  }

-// Function template that only converts the input value to MLFloat16 if T is MLFloat16.
-template <typename T>
-ORT_FORCEINLINE constexpr typename std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double>, float>
-ConvertToMLFloat16IfNeeded(float val) {
-  return val;
-}
+  mean = mean / norm_size;
+  if (simplified) {
+    mean_square = sqrt(mean_square / norm_size + epsilon);
+  } else {
+    mean_square = sqrt(mean_square / norm_size - mean * mean + epsilon);
+  }

-template <typename T>
-ORT_FORCEINLINE constexpr typename std::enable_if_t<std::is_same_v<T, MLFloat16>, MLFloat16>
-ConvertToMLFloat16IfNeeded(float val) {
-  return MLFloat16(val);
+  if (!scale_float_uptr) {
+    scale_float_uptr = std::move(input_float_uptr);  // overwrite input with scale values, since they have the same size
+    MlasConvertHalfToFloatBuffer(scale_data, scale_float_uptr.get(), num_elems);
+  }
+
+  if (bias_data && !bias_float_uptr) {
+    bias_float_uptr = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
+    MlasConvertHalfToFloatBuffer(bias_data, bias_float_uptr.get(), num_elems);
+  }
+
+  const float* scale_float_ptr = scale_float_uptr.get();
+  const float* bias_float_ptr = bias_float_uptr.get();
+  for (size_t h = 0; h < num_elems; h++) {
+    if (simplified) {
+      output_float_ptr[h] = output_float_ptr[h] / mean_square * scale_float_ptr[h];
+    } else if (nullptr == bias_data) {
+      output_float_ptr[h] = (output_float_ptr[h] - mean) / mean_square * scale_float_ptr[h];
+    } else {
+      output_float_ptr[h] = (output_float_ptr[h] - mean) / mean_square * scale_float_ptr[h] + bias_float_ptr[h];
+    }
+  }
+
+  MlasConvertFloatToHalfBuffer(output_float_ptr, p_output, num_elems);
+
+  if (mean_data != nullptr) {
+    // ONNX spec doesn't support 'double' for 'U' so when 'T' == double, 'U' == float and we need to narrow
+    mean_data[task_idx] = MLFloat16(mean);
+  }
+
+  if (inv_std_dev_data != nullptr) {
+    inv_std_dev_data[task_idx] = MLFloat16(1 / mean_square);
+  }
 }

-template <typename T>
-ORT_FORCEINLINE constexpr double ConvertToMLFloat16IfNeeded(double val) {
-  return val;
+void ConvertMLFloat16ToFloatIfNeeded(const Tensor& tensor, AllocatorPtr alloc, IAllocatorUniquePtr<float>& dest, bool& is_packed) {
+  if (tensor.GetElementType() == utils::ToTensorProtoElementType<MLFloat16>()) {
+    auto tensor_data_ptr = tensor.Data<MLFloat16>();
+    auto tensor_size = static_cast<size_t>(tensor.Shape().Size());
+    auto float_ptr = IAllocator::MakeUniquePtr<float>(alloc, tensor_size, true);
+
+    MlasConvertHalfToFloatBuffer(tensor_data_ptr, float_ptr.get(), tensor_size);
+    dest = std::move(float_ptr);
+    is_packed = true;
+  }
 }

+}  // namespace
+
 LayerNormImpl::LayerNormImpl(const OpKernelInfo& op_kernel_info, bool simplified, bool contrib_op)
-    : OpKernel(op_kernel_info), simplified_{simplified}, contrib_op_{contrib_op} {
+    : OpKernel(op_kernel_info), simplified_{simplified}, contrib_op_{contrib_op}, scale_fp32_(nullptr), bias_fp32_(nullptr) {
   ORT_ENFORCE(op_kernel_info.GetAttr<int64_t>("axis", &axis_).IsOK());
   ORT_ENFORCE(op_kernel_info.GetAttr<float>("epsilon", &epsilon_).IsOK());
 }

-namespace {
 template <typename T, typename U>
-Status ComputeImpl(OpKernelContext* p_ctx, int64_t orig_axis, float epsilon, bool simplified) {
+Status LayerNormImpl::ComputeImpl(OpKernelContext* p_ctx, int64_t orig_axis, float epsilon, bool simplified) const {
   // Inputs
   const Tensor* X = p_ctx->Input<Tensor>(0);
   const Tensor* scale = p_ctx->Input<Tensor>(1);
@@ -81,21 +182,12 @@ Status ComputeImpl(OpKernelContext* p_ctx, int64_t orig_axis, float epsilon, boo
   const T* bias_data = (simplified || nullptr == bias) ? nullptr : bias->Data<T>();

   const TensorShape& x_shape = X->Shape();
-  const int64_t axis = HandleNegativeAxis(orig_axis, x_shape.NumDimensions());
-  int64_t norm_count = x_shape.SizeToDimension(onnxruntime::narrow<size_t>(axis));
-  int64_t norm_size = x_shape.SizeFromDimension(onnxruntime::narrow<size_t>(axis));
-
-  const auto scale_size = scale->Shape().Size();
-  const auto bias_size = (bias_data) ? bias->Shape().Size() : 0;
-  if (scale_size != norm_size || (bias_data && bias_size != norm_size)) {
-    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
-                           "Size of X.shape()[axis:] == ", norm_size,
-                           ". Size of scale and bias (if provided) must match this. Got scale size of ",
-                           scale_size, " and bias size of ", bias_size);
-  }
-
+  const TensorShape& scale_shape = scale->Shape();
+  const TensorShape& bias_shape = bias->Shape();
   Tensor* Y = p_ctx->Output(0, x_shape);
-  auto Y_data = Y->MutableData<T>();
+  T* Y_data = Y->MutableData<T>();
+
+  const int64_t axis = HandleNegativeAxis(orig_axis, x_shape.NumDimensions());

   std::vector<int64_t> mean_inv_std_dev_dim;
   mean_inv_std_dev_dim.reserve(x_shape.NumDimensions());
@@ -107,11 +199,7 @@
     }
   }

-  AllocatorPtr alloc;
-  ORT_RETURN_IF_ERROR(p_ctx->GetTempSpaceAllocator(&alloc));
-
   int output_index = 1;
-
   U* mean_data = nullptr;
   if (!simplified) {
     Tensor* mean = p_ctx->Output(output_index++, TensorShape(mean_inv_std_dev_dim));
@@ -126,87 +214,74 @@
     inv_std_dev_data = inv_std_dev->MutableData<U>();
   }

-  concurrency::ThreadPool::TryBatchParallelFor(
-      p_ctx->GetOperatorThreadPool(), static_cast<int32_t>(norm_count),
-      [&](ptrdiff_t task_idx) {
-        const T* p_input = X_data + task_idx * norm_size;
-        T* p_output = Y_data + task_idx * norm_size;
-
-        using DoubleOrFloat = typename std::conditional<
-            std::is_same<T, double>::value,  // If T is double
-            double,                          // Use double
-            float                            // Otherwise, use float (covers float and MLFloat16)
-            >::type;
-
-        DoubleOrFloat mean(0.0f);
-        DoubleOrFloat mean_square(0.0f);
-
-        for (int64_t h = 0; h < norm_size; h++) {
-          DoubleOrFloat input_value = ConvertMLFloat16ToDoubleOrFloatIfNeeded<T, DoubleOrFloat>(p_input[h]);
-          mean += input_value;
-          mean_square += input_value * input_value;
-        }
-
-        mean = mean / norm_size;
-        if (simplified) {
-          mean_square = sqrt(mean_square / norm_size + epsilon);
-        } else {
-          mean_square = sqrt(mean_square / norm_size - mean * mean + epsilon);
-        }
-
-        for (int64_t h = 0; h < norm_size; h++) {
-          DoubleOrFloat input_value = ConvertMLFloat16ToDoubleOrFloatIfNeeded<T, DoubleOrFloat>(p_input[h]);
-          DoubleOrFloat scale_value = ConvertMLFloat16ToDoubleOrFloatIfNeeded<T, DoubleOrFloat>(scale_data[h]);
-          if (simplified) {
-            p_output[h] = ConvertToMLFloat16IfNeeded<T>(input_value / mean_square * scale_value);
-          } else if (nullptr == bias) {
-            p_output[h] = ConvertToMLFloat16IfNeeded<T>((input_value - mean) / mean_square * scale_value);
-          } else {
-            DoubleOrFloat bias_value = ConvertMLFloat16ToDoubleOrFloatIfNeeded<T, DoubleOrFloat>(bias_data[h]);
-            p_output[h] = ConvertToMLFloat16IfNeeded<T>((input_value - mean) / mean_square * scale_value + bias_value);
-          }
-        }
-
-        if (mean_data != nullptr) {
-          // ONNX spec doesn't support 'double' for 'U' so when 'T' == double, 'U' == float and we need to narrow
-          mean_data[task_idx] = ConvertToMLFloat16IfNeeded<U>(ConvertToFloatIfNeeded(mean));
-        }
-
-        if (inv_std_dev_data != nullptr) {
-          inv_std_dev_data[task_idx] = ConvertToMLFloat16IfNeeded<U>(ConvertToFloatIfNeeded(1 / mean_square));
-        }
-      },
-      0);
+  onnxruntime::concurrency::ThreadPool* thread_pool = p_ctx->GetOperatorThreadPool();

-  return Status::OK();
+  AllocatorPtr alloc;
+  ORT_RETURN_IF_ERROR(p_ctx->GetTempSpaceAllocator(&alloc));
+  return ComputeWithoutContext(X_data, x_shape, scale_data, scale_shape, bias_data, bias_shape, Y_data, mean_data,
+                               inv_std_dev_data, thread_pool, axis, epsilon, simplified, alloc);
 }

-template <typename T>
-struct SrcDispatcher {
-  Status operator()(OpKernelContext* p_ctx, int64_t orig_axis, float epsilon, bool simplified, bool contrib_op) const {
-    // the contrib op kernel was always registered with the same type for all constraints.
-    // our implementation of the onnx op only supports 'float' as the U constraint.
-#if !defined(DISABLE_CONTRIB_OPS)
-    if (contrib_op) {
-      return ComputeImpl<T, T>(p_ctx, orig_axis, epsilon, simplified);
-    } else
-#else
-    ORT_UNUSED_PARAMETER(contrib_op);
-#endif
-    {
-      return ComputeImpl<T, float>(p_ctx, orig_axis, epsilon, simplified);
-    }
-  }
-};
-}  // namespace
-
 Status LayerNormImpl::Compute(OpKernelContext* p_ctx) const {
   const auto elem_type = p_ctx->Input<Tensor>(0)->GetElementType();

   using SupportedTypeList = boost::mp11::mp_list<float, double, MLFloat16>;

   utils::MLTypeCallDispatcherFromTypeList<SupportedTypeList> t_disp(elem_type);
-  return t_disp.InvokeRet<Status, SrcDispatcher>(p_ctx, axis_, epsilon_, simplified_, contrib_op_);
+  return t_disp.InvokeRet<Status, SrcDispatcher>(this, p_ctx, axis_, epsilon_, simplified_, contrib_op_);
+}
+
+Status LayerNormImpl::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
+                              bool& is_packed, PrePackedWeights* prepacked_weights) {
+  ORT_UNUSED_PARAMETER(prepacked_weights);
+
+  is_packed = false;
+  if (input_idx == 1) {  // scale
+    ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, scale_fp32_, is_packed);
+  } else if (input_idx == 2) {  // bias
+    ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, bias_fp32_, is_packed);
+  }
+
+  return Status::OK();
+}
+
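+// Context-free implementation shared by ComputeImpl and the layer_normalization microbenchmark: validates the
+// scale/bias sizes against X.shape()[axis:] and runs one ComputeJob per normalization row on the thread pool.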
+template <typename T, typename U>
+Status LayerNormImpl::ComputeWithoutContext(
+    const T* X_data,
+    const TensorShape& x_shape,
+    const T* scale_data,
+    const TensorShape& scale_shape,
+    const T* bias_data,
+    const TensorShape& bias_shape,
+    T* Y_data,
+    U* mean_data,
+    U* inv_std_dev_data,
+    onnxruntime::concurrency::ThreadPool* thread_pool,
+    int64_t axis,
+    float epsilon,
+    bool simplified,
+    AllocatorPtr alloc) const {
+  int64_t norm_count = x_shape.SizeToDimension(onnxruntime::narrow<size_t>(axis));
+  int64_t norm_size = x_shape.SizeFromDimension(onnxruntime::narrow<size_t>(axis));
+
+  const auto scale_size = scale_shape.Size();
+  const auto bias_size = (bias_data) ? bias_shape.Size() : 0;
+  if (scale_size != norm_size || (bias_data && bias_size != norm_size)) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                           "Size of X.shape()[axis:] == ", norm_size,
+                           ". Size of scale and bias (if provided) must match this. Got scale size of ",
+                           scale_size, " and bias size of ", bias_size);
+  }
+
+  concurrency::ThreadPool::TryBatchParallelFor(
+      thread_pool, static_cast<int32_t>(norm_count),
+      [&](ptrdiff_t task_idx) {
+        ComputeJob(X_data, scale_data, bias_data, task_idx, norm_size, scale_fp32_, bias_fp32_,
+                   epsilon, simplified, Y_data, mean_data, inv_std_dev_data, alloc);
+      },
+      0);
+
+  return Status::OK();
 }

 }  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/cpu/nn/layer_norm_impl.h b/onnxruntime/core/providers/cpu/nn/layer_norm_impl.h
index 393c637dbda18..f6325c31cc71a 100644
--- a/onnxruntime/core/providers/cpu/nn/layer_norm_impl.h
+++ b/onnxruntime/core/providers/cpu/nn/layer_norm_impl.h
@@ -4,6 +4,7 @@
 #pragma once

 #include "core/common/common.h"
+#include "core/framework/allocator.h"
 #include "core/framework/op_kernel.h"
 #include "core/framework/tensor.h"

@@ -14,11 +15,56 @@ class LayerNormImpl : public OpKernel {
   LayerNormImpl(const OpKernelInfo& op_kernel_info, bool simplified = false, bool contrib_op = false);
   Status Compute(OpKernelContext* p_op_kernel_context) const override;

+  Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
+                 bool& is_packed, PrePackedWeights* prepacked_weights) override;
+
+  // This method was created so that it can be called directly from `test/onnx/microbenchmark/layer_normalization.cc`.
+  template <typename T, typename U>
+  Status ComputeWithoutContext(
+      const T* X_data,
+      const TensorShape& x_shape,
+      const T* scale_data,
+      const TensorShape& scale_shape,
+      const T* bias_data,
+      const TensorShape& bias_shape,
+      T* Y_data,
+      U* mean_data,
+      U* inv_std_dev,
+      onnxruntime::concurrency::ThreadPool* thread_pool,
+      int64_t axis,
+      float epsilon,
+      bool simplified,
+      AllocatorPtr alloc) const;
+
  private:
+  template <typename T, typename U>
+  Status ComputeImpl(OpKernelContext* p_ctx, int64_t orig_axis, float epsilon, bool simplified) const;
+
+  template <typename T>
+  struct SrcDispatcher {
+    Status operator()(const LayerNormImpl* p_instance, OpKernelContext* p_ctx, int64_t orig_axis,
+                      float epsilon, bool simplified, bool contrib_op) const {
+      // the contrib op kernel was always registered with the same type for all constraints.
+      // our implementation of the onnx op only supports 'float' as the U constraint.
+#if !defined(DISABLE_CONTRIB_OPS)
+      if (contrib_op) {
+        return p_instance->ComputeImpl<T, T>(p_ctx, orig_axis, epsilon, simplified);
+      } else
+#else
+      ORT_UNUSED_PARAMETER(contrib_op);
+#endif
+      {
+        return p_instance->ComputeImpl<T, float>(p_ctx, orig_axis, epsilon, simplified);
+      }
+    }
+  };
+
   int64_t axis_;
   float epsilon_;
   const bool simplified_;
   const bool contrib_op_;
+  mutable IAllocatorUniquePtr<float> scale_fp32_;
+  mutable IAllocatorUniquePtr<float> bias_fp32_;
 };

 }  // namespace onnxruntime
diff --git a/onnxruntime/test/onnx/microbenchmark/layer_normalization.cc b/onnxruntime/test/onnx/microbenchmark/layer_normalization.cc
new file mode 100644
index 0000000000000..75ce7b77acd4e
--- /dev/null
+++ b/onnxruntime/test/onnx/microbenchmark/layer_normalization.cc
@@ -0,0 +1,141 @@
+#ifdef _WIN32
+
+#include "core/platform/threadpool.h"
+#include "core/util/thread_utils.h"
+#include <benchmark/benchmark.h>
+
+#if defined(__GNUC__) && !defined(__clang__)
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
+#endif
+
+#include "core/framework/allocator.h"
+#include "core/framework/config_options.h"
+#include "core/framework/data_transfer_manager.h"
+#include "core/framework/op_kernel_info.h"
+#include "core/framework/ort_value_name_idx_map.h"
+#include "core/platform/windows/env.h"
+#include "core/providers/cpu/nn/layer_norm_impl.h"
+#include "core/providers/cpu/cpu_provider_factory.h"
+#include "core/providers/cpu/cpu_provider_factory_creator.h"
+#include "core/util/thread_utils.h"
+
+#include "test/onnx/microbenchmark/common.h"
+
+#if defined(__GNUC__) && !defined(__clang__)
+#pragma GCC diagnostic pop
+#endif
+
+using namespace onnxruntime;
+
+namespace {
+
+std::vector<MLFloat16> createMLFloat16Vector(float* vals, int64_t num_elems) {
+  std::vector<MLFloat16> fp16vec;
+  fp16vec.reserve(num_elems);
+
+  for (int64_t i = 0; i < num_elems; i++) {
+    fp16vec.push_back(MLFloat16(vals[i]));
+  }
+
+  return fp16vec;
+}
+
+}  // namespace
+
+template <typename T, typename U>
+static void BM_LayerNormalization(benchmark::State& state) {
+  bool simplified = false;
+  const float epsilon = 1e-05f;
+  int64_t axis = 1;
+
+  onnxruntime::Node node;
+  // Required by LayerNormImpl constructor
+  node.AddAttribute("axis", axis);
+  node.AddAttribute("epsilon", epsilon);
+
+  KernelDef kernel_def;
+  std::unique_ptr<IExecutionProvider> execution_provider = CPUProviderFactoryCreator::Create(true)->CreateProvider();
+  std::unordered_map<int, OrtValue> constant_initialized_tensors;
+  OrtValueNameIdxMap mlvalue_name_idx_map;
+  DataTransferManager data_transfer_mgr;
+  AllocatorMap allocators;
+  ConfigOptions config_options;
+
+  OpKernelInfo op_kernel_info(node, kernel_def, *execution_provider, constant_initialized_tensors, mlvalue_name_idx_map,
+                              data_transfer_mgr, allocators, config_options);
+
+  LayerNormImpl layer_norm_impl(op_kernel_info);
+
+  const std::vector<int64_t> dims{1, 256, 1024};
+  const size_t num_elems = dims[0] * dims[1] * dims[2];
+
+  TensorShape x_shape(dims);
+  TensorShape scale_shape(dims);
+  TensorShape bias_shape(dims);
+
+  const float low = -1.0f;
+  const float high = 1.0f;
+
+  float* x_float = GenerateArrayWithRandomValue<float>(num_elems, low, high);
+  float* scale_float = GenerateArrayWithRandomValue<float>(num_elems, 0.1f, high);
+  float* bias_float = GenerateArrayWithRandomValue<float>(num_elems, low, high);
+
+  std::vector<MLFloat16> x_MLFloat16 = createMLFloat16Vector(x_float, num_elems);
+  std::vector<MLFloat16> scale_MLFloat16 = createMLFloat16Vector(scale_float, num_elems);
+  std::vector<MLFloat16> bias_MLFloat16 = createMLFloat16Vector(bias_float, num_elems);
+
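+  // Pick the source buffers that match the element type T being benchmarked.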
+  T* x_data = nullptr;
+  T* scale_data = nullptr;
+  T* bias_data = nullptr;
+  if (std::is_same_v<T, MLFloat16>) {
+    x_data = (T*)x_MLFloat16.data();
+    scale_data = (T*)scale_MLFloat16.data();
+    bias_data = (T*)bias_MLFloat16.data();
+  } else if (std::is_same_v<T, float>) {
+    x_data = (T*)x_float;
+    scale_data = (T*)scale_float;
+    bias_data = (T*)bias_float;
+  }
+  assert(x_data);
+
+  T* Y_data = static_cast<T*>(aligned_alloc(num_elems * sizeof(T), 64));
+  U* mean_data = static_cast<U*>(aligned_alloc(num_elems * sizeof(U), 64));
+  U* inv_std_dev_data = static_cast<U*>(aligned_alloc(num_elems * sizeof(U), 64));
+
+  OrtThreadPoolParams tp_params;
+  tp_params.name = ORT_TSTR("intra-op");
+  std::unique_ptr<concurrency::ThreadPool> thread_pool = concurrency::CreateThreadPool(
+      &Env::Default(), tp_params, concurrency::ThreadPoolType::INTRA_OP);
+
+  OrtMemoryInfo memory_info(onnxruntime::CPU, OrtAllocatorType::OrtArenaAllocator);
+  AllocatorPtr alloc = std::make_shared<CPUAllocator>(memory_info);
+  for (auto _ : state) {
+    auto status = layer_norm_impl.ComputeWithoutContext(x_data, x_shape, scale_data, scale_shape, bias_data, bias_shape,
+                                                        Y_data, mean_data, inv_std_dev_data, thread_pool.get(), axis,
+                                                        epsilon, simplified, alloc);
+    if (!status.IsOK()) {
+      std::cout << "ComputeWithoutContext status not OK: " << status.ErrorMessage() << std::endl;
+      break;
+    }
+  }
+
+  aligned_free(x_float);
+  aligned_free(scale_float);
+  aligned_free(bias_float);
+  aligned_free(Y_data);
+  aligned_free(mean_data);
+  aligned_free(inv_std_dev_data);
+}
+
+BENCHMARK(BM_LayerNormalization<float, float>)
+    ->Arg(1)
+    ->UseRealTime()
+    ->Unit(benchmark::TimeUnit::kMicrosecond);
+
+BENCHMARK(BM_LayerNormalization<MLFloat16, MLFloat16>)
+    ->Arg(1)
+    ->UseRealTime()
+    ->Unit(benchmark::TimeUnit::kMicrosecond);
+
+#endif