Skip to content

Commit

Permalink
Fix SIMD header value check on char >= 0x80
Browse files Browse the repository at this point in the history
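The SIMD intrinsics *_cmpgt_epi8 compare their operands as signed 8-bit values, so any byte >= 0x80 was treated as negative and failed the "greater than 0x1f" test.
This change performs the comparison as unsigned instead: dat >= 0x20 is computed as max_epu8(dat, 0x20) == dat.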
  • Loading branch information
eaufavor authored and seanmonstar committed Jan 14, 2022
1 parent 507f582 commit 454efce
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 12 deletions.
32 changes: 26 additions & 6 deletions src/simd/avx2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,10 +100,11 @@ unsafe fn match_header_value_char_32_avx(buf: &[u8]) -> usize {
// %x09 %x20-%x7e %x80-%xff
let TAB: __m256i = _mm256_set1_epi8(0x09);
let DEL: __m256i = _mm256_set1_epi8(0x7f);
let LOW: __m256i = _mm256_set1_epi8(0x1f);
let LOW: __m256i = _mm256_set1_epi8(0x20);

let dat = _mm256_lddqu_si256(ptr as *const _);
let low = _mm256_cmpgt_epi8(dat, LOW);
// unsigned comparison dat >= LOW
let low = _mm256_cmpeq_epi8(_mm256_max_epu8(dat, LOW), dat);
let tab = _mm256_cmpeq_epi8(dat, TAB);
let del = _mm256_cmpeq_epi8(dat, DEL);
let bit = _mm256_andnot_si256(del, _mm256_or_si256(low, tab));
Expand All @@ -126,19 +127,38 @@ fn avx2_code_matches_uri_chars_table() {
}

unsafe {
assert!(byte_is_allowed(b'_'));
assert!(byte_is_allowed(b'_', parse_uri_batch_32));

for (b, allowed) in ::URI_MAP.iter().cloned().enumerate() {
assert_eq!(
byte_is_allowed(b as u8), allowed,
byte_is_allowed(b as u8, parse_uri_batch_32), allowed,
"byte_is_allowed({:?}) should be {:?}", b, allowed,
);
}
}
}

/// Cross-checks the AVX2 header-value batch scanner against the scalar
/// `HEADER_VALUE_MAP` lookup table, one table entry at a time.
///
/// The test returns early (and therefore passes trivially) when `super::detect()`
/// reports neither `AVX_2` nor `AVX_2_AND_SSE_42`.
#[test]
fn avx2_code_matches_header_value_chars_table() {
    match super::detect() {
        super::AVX_2 | super::AVX_2_AND_SSE_42 => {}
        _ => return,
    }

    // Sanity check: the filler byte used by `byte_is_allowed` must itself be accepted.
    let filler_ok = unsafe { byte_is_allowed(b'_', match_header_value_batch_32) };
    assert!(filler_ok);

    // Each table index is the byte under test; the entry is the expected verdict.
    for (b, &allowed) in ::HEADER_VALUE_MAP.iter().enumerate() {
        let actual = unsafe { byte_is_allowed(b as u8, match_header_value_batch_32) };
        assert_eq!(
            actual, allowed,
            "byte_is_allowed({:?}) should be {:?}", b, allowed,
        );
    }
}

#[cfg(test)]
unsafe fn byte_is_allowed(byte: u8) -> bool {
unsafe fn byte_is_allowed(byte: u8, f: unsafe fn(bytes: &mut Bytes<'_>) -> Scan) -> bool {
let slice = [
b'_', b'_', b'_', b'_',
b'_', b'_', b'_', b'_',
Expand All @@ -151,7 +171,7 @@ unsafe fn byte_is_allowed(byte: u8) -> bool {
];
let mut bytes = Bytes::new(&slice);

parse_uri_batch_32(&mut bytes);
f(&mut bytes);

match bytes.pos() {
32 => true,
Expand Down
32 changes: 26 additions & 6 deletions src/simd/sse42.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,11 @@ unsafe fn match_header_value_char_16_sse(buf: &[u8]) -> usize {
// %x09 %x20-%x7e %x80-%xff
let TAB: __m128i = _mm_set1_epi8(0x09);
let DEL: __m128i = _mm_set1_epi8(0x7f);
let LOW: __m128i = _mm_set1_epi8(0x1f);
let LOW: __m128i = _mm_set1_epi8(0x20);

let dat = _mm_lddqu_si128(ptr as *const _);
let low = _mm_cmpgt_epi8(dat, LOW);
// unsigned comparison dat >= LOW
let low = _mm_cmpeq_epi8(_mm_max_epu8(dat, LOW), dat);
let tab = _mm_cmpeq_epi8(dat, TAB);
let del = _mm_cmpeq_epi8(dat, DEL);
let bit = _mm_andnot_si128(del, _mm_or_si128(low, tab));
Expand All @@ -106,19 +107,38 @@ fn sse_code_matches_uri_chars_table() {
}

unsafe {
assert!(byte_is_allowed(b'_'));
assert!(byte_is_allowed(b'_', parse_uri_batch_16));

for (b, allowed) in ::URI_MAP.iter().cloned().enumerate() {
assert_eq!(
byte_is_allowed(b as u8), allowed,
byte_is_allowed(b as u8, parse_uri_batch_16), allowed,
"byte_is_allowed({:?}) should be {:?}", b, allowed,
);
}
}
}

/// Cross-checks the SSE4.2 header-value batch scanner against the scalar
/// `HEADER_VALUE_MAP` lookup table, one table entry at a time.
///
/// The test returns early (and therefore passes trivially) when `super::detect()`
/// reports neither `SSE_42` nor `AVX_2_AND_SSE_42`.
#[test]
fn sse_code_matches_header_value_chars_table() {
    match super::detect() {
        super::SSE_42 | super::AVX_2_AND_SSE_42 => {}
        _ => return,
    }

    // Sanity check: the filler byte used by `byte_is_allowed` must itself be accepted.
    let filler_ok = unsafe { byte_is_allowed(b'_', match_header_value_batch_16) };
    assert!(filler_ok);

    // Each table index is the byte under test; the entry is the expected verdict.
    for (b, &allowed) in ::HEADER_VALUE_MAP.iter().enumerate() {
        let actual = unsafe { byte_is_allowed(b as u8, match_header_value_batch_16) };
        assert_eq!(
            actual, allowed,
            "byte_is_allowed({:?}) should be {:?}", b, allowed,
        );
    }
}

#[cfg(test)]
unsafe fn byte_is_allowed(byte: u8) -> bool {
unsafe fn byte_is_allowed(byte: u8, f: unsafe fn(bytes: &mut Bytes<'_>)) -> bool {
let slice = [
b'_', b'_', b'_', b'_',
b'_', b'_', b'_', b'_',
Expand All @@ -127,7 +147,7 @@ unsafe fn byte_is_allowed(byte: u8) -> bool {
];
let mut bytes = Bytes::new(&slice);

parse_uri_batch_16(&mut bytes);
f(&mut bytes);

match bytes.pos() {
16 => true,
Expand Down

0 comments on commit 454efce

Please sign in to comment.