Skip to content

Commit

Permalink
aead: Support stacked borrows model using a new InOut type.
Browse files Browse the repository at this point in the history
Notably, `InOut::input_output_len` constructs the `input` pointer
from the `output` pointer in a way that safely avoids any concerns
about the order of borrowing the (now implicit) input slice and
output slice, and in particular whether any such borrowing
invalidates any pointers derived from those slices.

Practically, this helps people who are using Miri in its default
stacked borrows mode (as opposed to the tree borrows mode)
verify the memory safety of our code.
  • Loading branch information
briansmith committed Dec 5, 2024
1 parent 224bd7d commit 50fd56a
Show file tree
Hide file tree
Showing 10 changed files with 121 additions and 61 deletions.
4 changes: 3 additions & 1 deletion src/aead.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2015-2021 Brian Smith.
// Copyright 2015-2024 Brian Smith.
//
// Permission to use, copy, modify, and/or distribute this software for any
// purpose with or without fee is hereby granted, provided that the above
Expand Down Expand Up @@ -34,6 +34,7 @@ pub use self::{
sealing_key::SealingKey,
unbound_key::UnboundKey,
};
use inout::InOut;

/// A sequences of unique nonces.
///
Expand Down Expand Up @@ -175,6 +176,7 @@ mod chacha;
mod chacha20_poly1305;
pub mod chacha20_poly1305_openssh;
mod gcm;
mod inout;
mod less_safe_key;
mod nonce;
mod opening_key;
Expand Down
7 changes: 3 additions & 4 deletions src/aead/aes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,13 @@
// OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
// CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

use super::{nonce::Nonce, quic::Sample, NONCE_LEN};
use super::{nonce::Nonce, quic::Sample, InOut, NONCE_LEN};
use crate::{
constant_time,
cpu::{self, GetFeature as _},
error,
};
use cfg_if::cfg_if;
use core::ops::RangeFrom;

pub(super) use ffi::Counter;

Expand Down Expand Up @@ -158,7 +157,7 @@ pub(super) trait EncryptBlock {
}

pub(super) trait EncryptCtr32 {
fn ctr32_encrypt_within(&self, in_out: &mut [u8], src: RangeFrom<usize>, ctr: &mut Counter);
fn ctr32_encrypt_within(&self, in_out: InOut<'_>, ctr: &mut Counter);
}

#[allow(dead_code)]
Expand All @@ -178,7 +177,7 @@ fn encrypt_iv_xor_block_using_encrypt_block(
#[allow(dead_code)]
fn encrypt_iv_xor_block_using_ctr32(key: &impl EncryptCtr32, iv: Iv, mut block: Block) -> Block {
let mut ctr = Counter(iv.0); // This is OK because we're only encrypting one block.
key.ctr32_encrypt_within(&mut block, 0.., &mut ctr);
key.ctr32_encrypt_within(InOut::in_place(&mut block), &mut ctr);
block
}

Expand Down
8 changes: 3 additions & 5 deletions src/aead/aes/bs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@

#![cfg(target_arch = "arm")]

use super::{Counter, AES_KEY};
use core::ops::RangeFrom;
use super::{Counter, InOut, AES_KEY};

/// SAFETY:
/// * The caller must ensure that if blocks > 0 then either `input` and
Expand All @@ -28,8 +27,7 @@ use core::ops::RangeFrom;
/// * Upon returning, `blocks` blocks will have been read from `input` and
/// written to `output`.
pub(super) unsafe fn ctr32_encrypt_blocks_with_vpaes_key(
in_out: &mut [u8],
src: RangeFrom<usize>,
in_out: InOut<'_>,
vpaes_key: &AES_KEY,
ctr: &mut Counter,
) {
Expand Down Expand Up @@ -57,6 +55,6 @@ pub(super) unsafe fn ctr32_encrypt_blocks_with_vpaes_key(
// * `bsaes_ctr32_encrypt_blocks` satisfies the contract for
// `ctr32_encrypt_blocks`.
unsafe {
ctr32_encrypt_blocks!(bsaes_ctr32_encrypt_blocks, in_out, src, &bsaes_key, ctr);
ctr32_encrypt_blocks!(bsaes_ctr32_encrypt_blocks, in_out, &bsaes_key, ctr);
}
}
9 changes: 3 additions & 6 deletions src/aead/aes/fallback.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,8 @@
// OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
// CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

use super::{Block, Counter, EncryptBlock, EncryptCtr32, Iv, KeyBytes, AES_KEY};
use super::{Block, Counter, EncryptBlock, EncryptCtr32, InOut, Iv, KeyBytes, AES_KEY};
use crate::error;
use core::ops::RangeFrom;

#[derive(Clone)]
pub struct Key {
Expand All @@ -39,9 +38,7 @@ impl EncryptBlock for Key {
}

impl EncryptCtr32 for Key {
fn ctr32_encrypt_within(&self, in_out: &mut [u8], src: RangeFrom<usize>, ctr: &mut Counter) {
unsafe {
ctr32_encrypt_blocks!(aes_nohw_ctr32_encrypt_blocks, in_out, src, &self.inner, ctr)
}
fn ctr32_encrypt_within(&self, in_out: InOut<'_>, ctr: &mut Counter) {
unsafe { ctr32_encrypt_blocks!(aes_nohw_ctr32_encrypt_blocks, in_out, &self.inner, ctr) }
}
}
24 changes: 11 additions & 13 deletions src/aead/aes/ffi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
// OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
// CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

use super::{Block, KeyBytes, BLOCK_LEN};
use crate::{bits::BitLength, c, error, polyfill::slice};
use core::{num::NonZeroUsize, ops::RangeFrom};
use super::{Block, InOut, KeyBytes, BLOCK_LEN};
use crate::{bits::BitLength, c, error};
use core::num::NonZeroUsize;

/// nonce || big-endian counter.
#[repr(transparent)]
Expand Down Expand Up @@ -127,7 +127,7 @@ impl AES_KEY {
/// * The caller must ensure that fhe function `$name` satisfies the conditions
/// for the `f` parameter to `ctr32_encrypt_blocks`.
macro_rules! ctr32_encrypt_blocks {
($name:ident, $in_out:expr, $src:expr, $key:expr, $ctr:expr $(,)? ) => {{
($name:ident, $in_out:expr, $key:expr, $ctr:expr $(,)? ) => {{
use crate::{
aead::aes::{ffi::AES_KEY, Counter, BLOCK_LEN},
c,
Expand All @@ -141,7 +141,7 @@ macro_rules! ctr32_encrypt_blocks {
ivec: &Counter,
);
}
$key.ctr32_encrypt_blocks($name, $in_out, $src, $ctr)
$key.ctr32_encrypt_blocks($name, $in_out, $ctr)
}};
}

Expand All @@ -167,25 +167,23 @@ impl AES_KEY {
key: &AES_KEY,
ivec: &Counter,
),
in_out: &mut [u8],
src: RangeFrom<usize>,
mut in_out: InOut<'_>,
ctr: &mut Counter,
) {
let (input, leftover) = slice::as_chunks(&in_out[src]);
debug_assert_eq!(leftover.len(), 0);
let (input, output, len) = in_out.input_output_len();
debug_assert_eq!(len % BLOCK_LEN, 0);

let blocks = match NonZeroUsize::new(input.len()) {
let blocks = match NonZeroUsize::new(len / BLOCK_LEN) {
Some(blocks) => blocks,
None => {
return;
}
};

let input: *const [u8; BLOCK_LEN] = input.cast();
let output: *mut [u8; BLOCK_LEN] = output.cast();
let blocks_u32: u32 = blocks.get().try_into().unwrap();

let input = input.as_ptr();
let output: *mut [u8; BLOCK_LEN] = in_out.as_mut_ptr().cast();

// SAFETY:
// * `input` points to `blocks` blocks.
// * `output` points to space for `blocks` blocks to be written.
Expand Down
7 changes: 3 additions & 4 deletions src/aead/aes/hw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,8 @@

#![cfg(any(target_arch = "aarch64", target_arch = "x86", target_arch = "x86_64"))]

use super::{Block, Counter, EncryptBlock, EncryptCtr32, Iv, KeyBytes, AES_KEY};
use super::{Block, Counter, EncryptBlock, EncryptCtr32, InOut, Iv, KeyBytes, AES_KEY};
use crate::{cpu, error};
use core::ops::RangeFrom;

#[cfg(target_arch = "aarch64")]
pub(in super::super) type RequiredCpuFeatures = cpu::arm::Aes;
Expand Down Expand Up @@ -56,9 +55,9 @@ impl EncryptBlock for Key {
}

impl EncryptCtr32 for Key {
fn ctr32_encrypt_within(&self, in_out: &mut [u8], src: RangeFrom<usize>, ctr: &mut Counter) {
fn ctr32_encrypt_within(&self, in_out: InOut<'_>, ctr: &mut Counter) {
#[cfg(target_arch = "x86_64")]
let _: cpu::Features = cpu::features();
unsafe { ctr32_encrypt_blocks!(aes_hw_ctr32_encrypt_blocks, in_out, src, &self.inner, ctr) }
unsafe { ctr32_encrypt_blocks!(aes_hw_ctr32_encrypt_blocks, in_out, &self.inner, ctr) }
}
}
28 changes: 13 additions & 15 deletions src/aead/aes/vp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,8 @@
target_arch = "x86_64"
))]

use super::{Block, Counter, EncryptBlock, EncryptCtr32, Iv, KeyBytes, AES_KEY};
use super::{Block, Counter, EncryptBlock, EncryptCtr32, InOut, Iv, KeyBytes, AES_KEY};
use crate::{cpu, error};
use core::ops::RangeFrom;

#[cfg(any(target_arch = "aarch64", target_arch = "arm"))]
type RequiredCpuFeatures = cpu::arm::Neon;
Expand Down Expand Up @@ -57,17 +56,18 @@ impl EncryptBlock for Key {

#[cfg(any(target_arch = "aarch64", target_arch = "x86_64"))]
impl EncryptCtr32 for Key {
fn ctr32_encrypt_within(&self, in_out: &mut [u8], src: RangeFrom<usize>, ctr: &mut Counter) {
unsafe { ctr32_encrypt_blocks!(vpaes_ctr32_encrypt_blocks, in_out, src, &self.inner, ctr) }
fn ctr32_encrypt_within(&self, in_out: InOut<'_>, ctr: &mut Counter) {
unsafe { ctr32_encrypt_blocks!(vpaes_ctr32_encrypt_blocks, in_out, &self.inner, ctr) }
}
}

#[cfg(target_arch = "arm")]
impl EncryptCtr32 for Key {
fn ctr32_encrypt_within(&self, in_out: &mut [u8], src: RangeFrom<usize>, ctr: &mut Counter) {
fn ctr32_encrypt_within(&self, in_out: InOut<'_>, ctr: &mut Counter) {
use super::{bs, BLOCK_LEN};

let in_out = {
let (in_out, src) = in_out.into_slice_src_mut();
let blocks = in_out[src.clone()].len() / BLOCK_LEN;

// bsaes operates in batches of 8 blocks.
Expand All @@ -84,28 +84,26 @@ impl EncryptCtr32 for Key {
0
};
let bsaes_in_out_len = bsaes_blocks * BLOCK_LEN;
let bs_in_out =
InOut::overlapping(&mut in_out[..(src.start + bsaes_in_out_len)], src.clone())
.unwrap();

// SAFETY:
// * self.inner was initialized with `vpaes_set_encrypt_key` above,
// as required by `bsaes_ctr32_encrypt_blocks_with_vpaes_key`.
unsafe {
bs::ctr32_encrypt_blocks_with_vpaes_key(
&mut in_out[..(src.start + bsaes_in_out_len)],
src.clone(),
&self.inner,
ctr,
);
bs::ctr32_encrypt_blocks_with_vpaes_key(bs_in_out, &self.inner, ctr);
}

&mut in_out[bsaes_in_out_len..]
InOut::overlapping(&mut in_out[bsaes_in_out_len..], src).unwrap()
};

// SAFETY:
// * self.inner was initialized with `vpaes_set_encrypt_key` above,
// as required by `vpaes_ctr32_encrypt_blocks`.
// * `vpaes_ctr32_encrypt_blocks` satisfies the contract for
// `ctr32_encrypt_blocks`.
unsafe { ctr32_encrypt_blocks!(vpaes_ctr32_encrypt_blocks, in_out, src, &self.inner, ctr) }
unsafe { ctr32_encrypt_blocks!(vpaes_ctr32_encrypt_blocks, in_out, &self.inner, ctr) }
}
}

Expand All @@ -122,8 +120,8 @@ impl EncryptBlock for Key {

#[cfg(target_arch = "x86")]
impl EncryptCtr32 for Key {
fn ctr32_encrypt_within(&self, in_out: &mut [u8], src: RangeFrom<usize>, ctr: &mut Counter) {
super::super::shift::shift_full_blocks(in_out, src, |input| {
fn ctr32_encrypt_within(&self, in_out: InOut<'_>, ctr: &mut Counter) {
super::super::shift::shift_full_blocks(in_out, |input| {
self.encrypt_iv_xor_block(ctr.increment(), *input)
});
}
Expand Down
19 changes: 8 additions & 11 deletions src/aead/aes_gcm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

use super::{
aes::{self, Counter, BLOCK_LEN, ZERO_BLOCK},
gcm, shift, Aad, Nonce, Tag,
gcm, shift, Aad, InOut, Nonce, Tag,
};
use crate::{
cpu, error,
Expand Down Expand Up @@ -160,7 +160,7 @@ pub(super) fn seal(
}
};
let (whole, remainder) = slice::as_chunks_mut(ramaining);
aes_key.ctr32_encrypt_within(slice::flatten_mut(whole), 0.., &mut ctr);
aes_key.ctr32_encrypt_within(InOut::in_place(slice::flatten_mut(whole)), &mut ctr);
auth.update_blocks(whole);
seal_finish(aes_key, auth, remainder, ctr, tag_iv)
}
Expand Down Expand Up @@ -240,7 +240,7 @@ fn seal_strided<A: aes::EncryptBlock + aes::EncryptCtr32, G: gcm::UpdateBlocks +
let (whole, remainder) = slice::as_chunks_mut(in_out);

for chunk in whole.chunks_mut(CHUNK_BLOCKS) {
aes_key.ctr32_encrypt_within(slice::flatten_mut(chunk), 0.., &mut ctr);
aes_key.ctr32_encrypt_within(InOut::in_place(slice::flatten_mut(chunk)), &mut ctr);
auth.update_blocks(chunk);
}

Expand Down Expand Up @@ -331,11 +331,8 @@ pub(super) fn open(
let whole_len = slice::flatten(whole).len();

// Decrypt any remaining whole blocks.
aes_key.ctr32_encrypt_within(
&mut in_out[..(src.start + whole_len)],
src.clone(),
&mut ctr,
);
let whole = InOut::overlapping(&mut in_out[..(src.start + whole_len)], src.clone())?;
aes_key.ctr32_encrypt_within(whole, &mut ctr);

let in_out = match in_out.get_mut(whole_len..) {
Some(partial) => partial,
Expand Down Expand Up @@ -450,11 +447,11 @@ fn open_strided<A: aes::EncryptBlock + aes::EncryptCtr32, G: gcm::UpdateBlocks +
}
auth.update_blocks(ciphertext);

aes_key.ctr32_encrypt_within(
let chunk = InOut::overlapping(
&mut in_out[output..][..(chunk_len + in_prefix_len)],
in_prefix_len..,
&mut ctr,
);
)?;
aes_key.ctr32_encrypt_within(chunk, &mut ctr);
output += chunk_len;
input += chunk_len;
}
Expand Down
Loading

0 comments on commit 50fd56a

Please sign in to comment.