Skip to content

Commit

Permalink
fix: rename mul_add to mul_add_round
Browse files Browse the repository at this point in the history
Signed-off-by: usamoi <usamoi@outlook.com>
  • Loading branch information
usamoi committed Aug 30, 2024
1 parent b2434ff commit 80c772a
Showing 1 changed file with 11 additions and 11 deletions.
22 changes: 11 additions & 11 deletions crates/quantization/src/quantize.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use base::scalar::*;

mod mul_add {
mod mul_add_round {
#[cfg(target_arch = "x86_64")]
#[detect::target_cpu(enable = "v4")]
unsafe fn mul_add_v4(this: &[f32], k: f32, b: f32) -> Vec<u8> {
unsafe fn mul_add_round_v4(this: &[f32], k: f32, b: f32) -> Vec<u8> {
let n = this.len();
let mut r = Vec::<u8>::with_capacity(n);
unsafe {
Expand Down Expand Up @@ -42,7 +42,7 @@ mod mul_add {

#[cfg(all(target_arch = "x86_64", test))]
#[test]
fn mul_add_v4_test() {
fn mul_add_round_v4_test() {
detect::init();
if !detect::v4::detect() {
println!("test {} ... skipped (v4)", module_path!());
Expand All @@ -55,16 +55,16 @@ mod mul_add {
let x = &x[..z];
let k = 20.0;
let b = 20.0;
let specialized = unsafe { mul_add_v4(x, k, b) };
let fallback = unsafe { mul_add_fallback(x, k, b) };
let specialized = unsafe { mul_add_round_v4(x, k, b) };
let fallback = unsafe { mul_add_round_fallback(x, k, b) };
assert_eq!(specialized, fallback);
}
}
}

#[cfg(target_arch = "x86_64")]
#[detect::target_cpu(enable = "v3")]
unsafe fn mul_add_v3(this: &[f32], k: f32, b: f32) -> Vec<u8> {
unsafe fn mul_add_round_v3(this: &[f32], k: f32, b: f32) -> Vec<u8> {
let n = this.len();
let mut r = Vec::<u8>::with_capacity(n);
unsafe {
Expand Down Expand Up @@ -111,7 +111,7 @@ mod mul_add {

#[cfg(all(target_arch = "x86_64", test))]
#[test]
fn mul_add_v3_test() {
fn mul_add_round_v3_test() {
detect::init();
if !detect::v3::detect() {
println!("test {} ... skipped (v3)", module_path!());
Expand All @@ -124,15 +124,15 @@ mod mul_add {
let x = &x[..z];
let k = 20.0;
let b = 20.0;
let specialized = unsafe { mul_add_v3(x, k, b) };
let fallback = unsafe { mul_add_fallback(x, k, b) };
let specialized = unsafe { mul_add_round_v3(x, k, b) };
let fallback = unsafe { mul_add_round_fallback(x, k, b) };
assert_eq!(specialized, fallback);
}
}
}

#[detect::multiversion(v4 = import, v3 = import, v2, neon, fallback = export)]
pub fn mul_add(this: &[f32], k: f32, b: f32) -> Vec<u8> {
pub fn mul_add_round(this: &[f32], k: f32, b: f32) -> Vec<u8> {
let n = this.len();
let mut r = Vec::<u8>::with_capacity(n);
for i in 0..n {
Expand All @@ -154,7 +154,7 @@ pub fn quantize<const N: u8>(lut: &[f32]) -> (f32, f32, Vec<u8>) {
let (min, max) = f32::reduce_min_max_of_x(lut);
let k = 0.0f32.max((max - min) / (N as f32));
let b = min;
(k, b, mul_add::mul_add(lut, 1.0 / k, -b / k))
(k, b, mul_add_round::mul_add_round(lut, 1.0 / k, -b / k))
}

#[inline(always)]
Expand Down

0 comments on commit 80c772a

Please sign in to comment.