Add microbenchmark for layer normalization and improve latency #22223

Draft
wants to merge 14 commits into main
3 changes: 2 additions & 1 deletion cmake/onnxruntime_unittests.cmake
@@ -1128,7 +1128,8 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP)
${BENCHMARK_DIR}/gelu.cc
${BENCHMARK_DIR}/activation.cc
${BENCHMARK_DIR}/quantize.cc
${BENCHMARK_DIR}/reduceminmax.cc)
${BENCHMARK_DIR}/reduceminmax.cc
${BENCHMARK_DIR}/layer_normalization.cc)
target_include_directories(onnxruntime_benchmark PRIVATE ${ONNXRUNTIME_ROOT} ${onnxruntime_graph_header} ${ONNXRUNTIME_ROOT}/core/mlas/inc)
target_compile_definitions(onnxruntime_benchmark PRIVATE BENCHMARK_STATIC_DEFINE)
if(WIN32)
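The new `${BENCHMARK_DIR}/layer_normalization.cc` source itself is not shown in this diff. Purely as an illustration of the kind of harness such a file contains, below is a minimal Google Benchmark sketch over a naive layer-normalization kernel; the kernel, function names, and hidden sizes are assumptions made for the sketch, not code from this PR.

```cpp
// Hypothetical sketch only -- not the layer_normalization.cc added by this PR.
#include <benchmark/benchmark.h>

#include <cmath>
#include <cstdint>
#include <vector>

// Naive single-row layer normalization, used here purely as a stand-in kernel.
static void NaiveLayerNorm(const float* x, const float* gamma, const float* beta,
                           float epsilon, int64_t hidden_size, float* y) {
  float mean = 0.0f;
  float mean_square = 0.0f;
  for (int64_t h = 0; h < hidden_size; ++h) {
    mean += x[h];
    mean_square += x[h] * x[h];
  }
  mean /= hidden_size;
  const float inv_std = 1.0f / std::sqrt(mean_square / hidden_size - mean * mean + epsilon);
  for (int64_t h = 0; h < hidden_size; ++h) {
    y[h] = (x[h] - mean) * inv_std * gamma[h] + beta[h];
  }
}

static void BM_NaiveLayerNorm(benchmark::State& state) {
  const int64_t hidden_size = state.range(0);
  std::vector<float> x(hidden_size, 1.0f), gamma(hidden_size, 1.0f), beta(hidden_size, 0.0f), y(hidden_size);
  for (auto _ : state) {
    NaiveLayerNorm(x.data(), gamma.data(), beta.data(), 1e-5f, hidden_size, y.data());
    benchmark::DoNotOptimize(y.data());  // keep the result live so the loop is not elided
  }
}
// Typical transformer hidden sizes.
BENCHMARK(BM_NaiveLayerNorm)->Arg(768)->Arg(4096);

BENCHMARK_MAIN();  // entry point for a standalone build of this sketch
```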
237 changes: 140 additions & 97 deletions onnxruntime/contrib_ops/cpu/skip_layer_norm.cc
Expand Up @@ -2,6 +2,7 @@
// Licensed under the MIT License.

#include "core/framework/tensor.h"
#include "core/mlas/inc/mlas.h"
#include "core/util/math_cpuonly.h"
#include "core/providers/common.h"
#include "core/platform/threadpool.h"
@@ -36,48 +37,145 @@ REGISTER_KERNEL_TYPED(float)
REGISTER_KERNEL_TYPED(double)
REGISTER_KERNEL_TYPED(MLFloat16)

// Utility to convert from MLFloat16 to float only when the input type is MLFloat16.
template <typename T, typename Ret>
ORT_FORCEINLINE Ret ConvertMLFloat16ToDoubleOrFloatIfNeeded(T val);

template <>
ORT_FORCEINLINE float ConvertMLFloat16ToDoubleOrFloatIfNeeded<MLFloat16, float>(MLFloat16 val) {
return val.ToFloat();
namespace {

template <typename T, typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double>, void>>
void ComputeJob(
const T* input_data,
const T* skip_data,
const T* gamma_data,
const T* beta_data,
const T* bias_data,
ptrdiff_t task_idx,
int hidden_size,
int64_t skip_size,
float epsilon,
bool simplified,
T* output_data,
T* skip_input_bias_add_output_data
) {
auto offset = task_idx * hidden_size;
const T* p_input = input_data + offset;
const T* p_skip = skip_data + (offset % skip_size);
T* p_output = output_data + offset;
T* p_skip_input_bias_add_output = skip_input_bias_add_output_data == nullptr ? nullptr : skip_input_bias_add_output_data + offset;

T mean(0.0f);
T mean_square(0.0f);

for (decltype(hidden_size) h = 0; h < hidden_size; h++) {
T val = p_input[h] + p_skip[h];

if (nullptr != bias_data) {
val += bias_data[h];
}

if (nullptr != p_skip_input_bias_add_output) {
p_skip_input_bias_add_output[h] = val;
}

p_output[h] = val;
mean += val;
mean_square += val * val;
}

mean = mean / hidden_size;
if (simplified) {
mean_square = sqrt(mean_square / hidden_size + epsilon);
} else {
mean_square = sqrt(mean_square / hidden_size - mean * mean + epsilon);
}

for (decltype(hidden_size) h = 0; h < hidden_size; h++) {
if (simplified) {
p_output[h] = p_output[h] / mean_square * gamma_data[h];
} else if (nullptr == beta_data) {
p_output[h] = (p_output[h] - mean) / mean_square * gamma_data[h];
} else {
p_output[h] = (p_output[h] - mean) / mean_square * gamma_data[h] + beta_data[h];
}
}
}

template <>
ORT_FORCEINLINE double ConvertMLFloat16ToDoubleOrFloatIfNeeded<MLFloat16, double>(MLFloat16 val) {
return static_cast<double>(ConvertMLFloat16ToDoubleOrFloatIfNeeded<MLFloat16, float>(val));
void ComputeJob(
const MLFloat16* input_data,
const MLFloat16* skip_data,
const MLFloat16* gamma_data,
const MLFloat16* beta_data,
const MLFloat16* bias_data,
ptrdiff_t task_idx,
int hidden_size,
int64_t skip_size,
float epsilon,
bool simplified,
MLFloat16* output_data,
MLFloat16* skip_input_bias_add_output_data
) {
auto offset = task_idx * hidden_size;
const MLFloat16* p_input = input_data + offset;
const MLFloat16* p_skip = skip_data + (offset % skip_size);
MLFloat16* p_output = output_data + offset;
MLFloat16* p_skip_input_bias_add_output = skip_input_bias_add_output_data == nullptr ? nullptr : skip_input_bias_add_output_data + offset;

float mean(0.0f);
float mean_square(0.0f);

std::vector<float> float_input(hidden_size);
MlasConvertHalfToFloatBuffer(p_input, &float_input[0], hidden_size);
std::vector<float> float_skip(hidden_size);
MlasConvertHalfToFloatBuffer(p_skip, &float_skip[0], hidden_size);
std::vector<float> float_bias;
if (bias_data != nullptr) {
float_bias.resize(hidden_size);
MlasConvertHalfToFloatBuffer(bias_data, &float_bias[0], hidden_size);
}

std::vector<float> float_output(hidden_size);

for (decltype(hidden_size) h = 0; h < hidden_size; h++) {
float val = float_input[h] + float_skip[h];

if (nullptr != bias_data) {
val += float_bias[h];
}

float_output[h] = val;
mean += val;
mean_square += val * val;
}

if (nullptr != p_skip_input_bias_add_output) {
MlasConvertFloatToHalfBuffer(&float_output[0], p_skip_input_bias_add_output, hidden_size);
}

mean = mean / hidden_size;
if (simplified) {
mean_square = sqrt(mean_square / hidden_size + epsilon);
} else {
mean_square = sqrt(mean_square / hidden_size - mean * mean + epsilon);
}

std::vector<float> float_gamma(hidden_size);
MlasConvertHalfToFloatBuffer(gamma_data, &float_gamma[0], hidden_size);
std::vector<float> float_beta;
if (beta_data != nullptr) {  // beta is optional; avoid converting from a null buffer
float_beta.resize(hidden_size);
MlasConvertHalfToFloatBuffer(beta_data, &float_beta[0], hidden_size);
}

for (decltype(hidden_size) h = 0; h < hidden_size; h++) {
if (simplified) {
float_output[h] = float_output[h] / mean_square * float_gamma[h];
} else if (nullptr == beta_data) {
float_output[h] = (float_output[h] - mean) / mean_square * float_gamma[h];
} else {
float_output[h] = (float_output[h] - mean) / mean_square * float_gamma[h] + float_beta[h];
}
}

MlasConvertFloatToHalfBuffer(&float_output[0], p_output, hidden_size);
}

template <>
ORT_FORCEINLINE constexpr float ConvertMLFloat16ToDoubleOrFloatIfNeeded<float, float>(float val) {
return val;
}

template <>
ORT_FORCEINLINE constexpr double ConvertMLFloat16ToDoubleOrFloatIfNeeded<double, double>(double val) {
return val;
}

// Function template that only converts the input value to MLFloat16 if T is MLFloat16.
template <typename T>
ORT_FORCEINLINE constexpr typename std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double>, T>
ConvertDoubleOrFloatToMLFloat16IfNeeded(T val) {
return val;
}

template <typename T>
ORT_FORCEINLINE constexpr typename std::enable_if_t<std::is_same_v<T, MLFloat16>, T>
ConvertDoubleOrFloatToMLFloat16IfNeeded(float val) {
return MLFloat16(val);
}
} // namespace

template <typename T>
ORT_FORCEINLINE constexpr typename std::enable_if_t<std::is_same_v<T, MLFloat16>, T>
ConvertDoubleOrFloatToMLFloat16IfNeeded(double val) {
return MLFloat16(static_cast<float>(val));
}

template <typename T, bool simplified>
SkipLayerNorm<T, simplified>::SkipLayerNorm(const OpKernelInfo& op_kernel_info)
@@ -94,8 +192,7 @@ Status SkipLayerNorm<T, simplified>::Compute(OpKernelContext* p_ctx) const {
const Tensor* beta = p_ctx->Input<Tensor>(3);
const Tensor* bias = p_ctx->Input<Tensor>(4);
Tensor* output = p_ctx->Output(0, input->Shape());
// For inferencing, we support one more optional output which is the sum
// of the input and skip tensors
// For inferencing, we support one more optional output which is the sum of the input and skip tensors
Tensor* skip_input_bias_add_output = p_ctx->Output(3, input->Shape());

const auto& input_dims = input->Shape().GetDims();
@@ -120,70 +217,16 @@ Status SkipLayerNorm<T, simplified>::Compute(OpKernelContext* p_ctx) const {

T* output_data = output->MutableData<T>();

// For inferencing, we support one more optional output which is the sum
// of the input and skip tensors
T* skip_input_bias_add_output_data = skip_input_bias_add_output != nullptr ? skip_input_bias_add_output->MutableData<T>() : nullptr;
// For inferencing, we support one more optional output which is the sum of the input and skip tensors
T* skip_input_bias_add_output_data = skip_input_bias_add_output == nullptr ? nullptr : skip_input_bias_add_output->MutableData<T>();

const auto& skip_size = skip->Shape().Size();
const int64_t& skip_size = skip->Shape().Size();

concurrency::ThreadPool::TryBatchParallelFor(
p_ctx->GetOperatorThreadPool(), static_cast<int32_t>(task_count),
[&](ptrdiff_t task_idx) {
auto offset = task_idx * hidden_size;

const T* p_input = input_data + offset;
const T* p_skip = skip_data + (offset % skip_size);
T* p_output = output_data + offset;
T* p_skip_input_bias_add_output_data = skip_input_bias_add_output_data != nullptr ? skip_input_bias_add_output_data + offset : nullptr;

using DoubleOrFloat = typename std::conditional<
std::is_same<T, double>::value, // If T is double
double, // Use double
float // Otherwise, use float (covers float and MLFloat16)
>::type;

DoubleOrFloat mean(0.0f);
DoubleOrFloat mean_square(0.0f);

std::unique_ptr<DoubleOrFloat[]> output_buffer = std::make_unique<DoubleOrFloat[]>(hidden_size);
for (size_t h = 0; h < static_cast<size_t>(hidden_size); h++) {
DoubleOrFloat input_value = ConvertMLFloat16ToDoubleOrFloatIfNeeded<T, DoubleOrFloat>(p_input[h]);
DoubleOrFloat skip_value = ConvertMLFloat16ToDoubleOrFloatIfNeeded<T, DoubleOrFloat>(p_skip[h]);

DoubleOrFloat value = input_value + skip_value;

if (nullptr != bias_data) {
value += ConvertMLFloat16ToDoubleOrFloatIfNeeded<T, DoubleOrFloat>(bias_data[h]);
}

output_buffer[h] = value;
T converted_value = ConvertDoubleOrFloatToMLFloat16IfNeeded<T>(value);
if (nullptr != p_skip_input_bias_add_output_data) {
p_skip_input_bias_add_output_data[h] = converted_value;
}

mean += value;
mean_square += value * value;
}

mean = mean / hidden_size;
if (simplified) {
mean_square = sqrt(mean_square / hidden_size + epsilon_);
} else {
mean_square = sqrt(mean_square / hidden_size - mean * mean + epsilon_);
}

for (size_t h = 0; h < static_cast<size_t>(hidden_size); h++) {
DoubleOrFloat gamma_value = ConvertMLFloat16ToDoubleOrFloatIfNeeded<T, DoubleOrFloat>(gamma_data[h]);
if (simplified) {
p_output[h] = ConvertDoubleOrFloatToMLFloat16IfNeeded<T>(output_buffer[h] / mean_square * gamma_value);
} else if (nullptr == beta_data) {
p_output[h] = ConvertDoubleOrFloatToMLFloat16IfNeeded<T>((output_buffer[h] - mean) / mean_square * gamma_value);
} else {
DoubleOrFloat beta_value = ConvertMLFloat16ToDoubleOrFloatIfNeeded<T, DoubleOrFloat>(beta_data[h]);
p_output[h] = ConvertDoubleOrFloatToMLFloat16IfNeeded<T>((output_buffer[h] - mean) / mean_square * gamma_value + beta_value);
}
}
ComputeJob(input_data, skip_data, gamma_data, beta_data, bias_data, task_idx, hidden_size, skip_size, epsilon_,
simplified, output_data, skip_input_bias_add_output_data);
},
0);

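For reference, the two normalization variants implemented by `ComputeJob` (standard and `simplified`) reduce to the following, where $x_h$ is the row value after the skip and bias additions, $H$ is `hidden_size`, $\mu = \tfrac{1}{H}\sum_h x_h$ and $m_2 = \tfrac{1}{H}\sum_h x_h^2$:

$$
y_h = \frac{x_h - \mu}{\sqrt{m_2 - \mu^2 + \epsilon}}\,\gamma_h + \beta_h \quad\text{(standard; } \beta \text{ omitted when not provided)},
\qquad
y_h = \frac{x_h}{\sqrt{m_2 + \epsilon}}\,\gamma_h \quad\text{(simplified)}.
$$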