Skip to content

Commit

Permalink
fix: remove extra block assumptions in mbedtls integration (shaka-pro…
Browse files Browse the repository at this point in the history
…ject#1323)

The current mbedtls integration was not working for some modes. See for
example shaka-project#1316 and also lots of failing integration tests.

For example in pattern encryptor it works on one block at a time so it
cannot assume it's going to always get a buffer with a padding for an
extra block.

From what I can tell when the padding mode is correctly set to
`MBEDTLS_PADDING_NONE` there is no extra block being written to or
required.

This passes all crypto unit tests and integration tests.

Closes shaka-project#1316
  • Loading branch information
cosmin authored Feb 8, 2024
1 parent 9b9adf3 commit db59ad5
Show file tree
Hide file tree
Showing 6 changed files with 75 additions and 102 deletions.
6 changes: 2 additions & 4 deletions packager/media/base/aes_cryptor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,7 @@ bool AesCryptor::Crypt(const std::vector<uint8_t>& text,
// Save text size to make it work for in-place conversion, since the
// next statement will update the text size.
const size_t text_size = text.size();
// mbedtls requires an extra block's worth of output buffer available.
crypt_text->resize(text_size + NumPaddingBytes(text_size) + AES_BLOCK_SIZE);
crypt_text->resize(text_size + NumPaddingBytes(text_size));
size_t crypt_text_size = crypt_text->size();
if (!Crypt(text.data(), text_size, crypt_text->data(), &crypt_text_size)) {
return false;
Expand All @@ -58,8 +57,7 @@ bool AesCryptor::Crypt(const std::string& text, std::string* crypt_text) {
// Save text size to make it work for in-place conversion, since the
// next statement will update the text size.
const size_t text_size = text.size();
// mbedtls requires an extra block's worth of output buffer available.
crypt_text->resize(text_size + NumPaddingBytes(text_size) + AES_BLOCK_SIZE);
crypt_text->resize(text_size + NumPaddingBytes(text_size));
size_t crypt_text_size = crypt_text->size();
if (!Crypt(reinterpret_cast<const uint8_t*>(text.data()), text_size,
reinterpret_cast<uint8_t*>(&(*crypt_text)[0]), &crypt_text_size))
Expand Down
114 changes: 50 additions & 64 deletions packager/media/base/aes_decryptor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,7 @@ bool AesCbcDecryptor::InitializeWithIv(const std::vector<uint8_t>& key,
}

size_t AesCbcDecryptor::RequiredOutputSize(size_t plaintext_size) {
// mbedtls requires a buffer large enough for one extra block.
return plaintext_size + AES_BLOCK_SIZE;
return plaintext_size;
}

bool AesCbcDecryptor::CryptInternal(const uint8_t* ciphertext,
Expand All @@ -60,14 +59,12 @@ bool AesCbcDecryptor::CryptInternal(const uint8_t* ciphertext,
// Plaintext size is the same as ciphertext size except for pkcs5 padding.
// Will update later if using pkcs5 padding. For pkcs5 padding, we still
// need at least |ciphertext_size| bytes for intermediate operation.
// mbedtls requires a buffer large enough for one extra block.
const size_t required_plaintext_size = ciphertext_size + AES_BLOCK_SIZE;
if (*plaintext_size < required_plaintext_size) {
LOG(ERROR) << "Expecting output size of at least "
<< required_plaintext_size << " bytes.";
if (*plaintext_size < ciphertext_size) {
LOG(ERROR) << "Expecting output size of at least " << ciphertext_size
<< " bytes.";
return false;
}
*plaintext_size = required_plaintext_size - AES_BLOCK_SIZE;
*plaintext_size = ciphertext_size;

// If the ciphertext size is 0, this can be a no-op decrypt, so long as the
// padding mode isn't PKCS5.
Expand All @@ -83,15 +80,9 @@ bool AesCbcDecryptor::CryptInternal(const uint8_t* ciphertext,

const size_t residual_block_size = ciphertext_size % AES_BLOCK_SIZE;
const size_t cbc_size = ciphertext_size - residual_block_size;

// Copy the residual block early, since mbedtls may overwrite one extra block
// of the output, and input and output may be the same buffer.
std::vector<uint8_t> residual_block(ciphertext + cbc_size,
ciphertext + ciphertext_size);
DCHECK_EQ(residual_block.size(), residual_block_size);

if (residual_block_size == 0) {
CbcDecryptBlocks(ciphertext, ciphertext_size, plaintext);
CbcDecryptBlocks(ciphertext, ciphertext_size, plaintext,
internal_iv_.data());
if (padding_scheme_ != kPkcs5Padding)
return true;

Expand All @@ -105,10 +96,11 @@ bool AesCbcDecryptor::CryptInternal(const uint8_t* ciphertext,
*plaintext_size -= num_padding_bytes;
return true;
} else if (padding_scheme_ == kNoPadding) {
CbcDecryptBlocks(ciphertext, cbc_size, plaintext);

if (cbc_size > 0) {
CbcDecryptBlocks(ciphertext, cbc_size, plaintext, internal_iv_.data());
}
// The residual block is not encrypted.
memcpy(plaintext + cbc_size, residual_block.data(), residual_block_size);
memcpy(plaintext + cbc_size, ciphertext + cbc_size, residual_block_size);
return true;
} else if (padding_scheme_ != kCtsPadding) {
LOG(ERROR) << "Expecting cipher text size to be multiple of "
Expand All @@ -123,49 +115,44 @@ bool AesCbcDecryptor::CryptInternal(const uint8_t* ciphertext,
return true;
}

// Copy the next-to-last block early, since mbedtls may overwrite one extra
// block of the output, and input and output may be the same buffer.
// NOTE: Before this point, there may not be such a block. Here, we know
// this is safe.
std::vector<uint8_t> next_to_last_block(
ciphertext + cbc_size - AES_BLOCK_SIZE, ciphertext + cbc_size);

// AES-CBC decrypt everything up to the next-to-last full block.
if (cbc_size > AES_BLOCK_SIZE) {
CbcDecryptBlocks(ciphertext, cbc_size - AES_BLOCK_SIZE, plaintext);
CbcDecryptBlocks(ciphertext, cbc_size - AES_BLOCK_SIZE, plaintext,
internal_iv_.data());
}

const uint8_t* next_to_last_ciphertext_block =
ciphertext + ciphertext_size - residual_block_size - AES_BLOCK_SIZE;
uint8_t* next_to_last_plaintext_block =
plaintext + ciphertext_size - residual_block_size - AES_BLOCK_SIZE;

// Determine what the last IV should be so that we can "skip ahead" in the
// CBC decryption.
std::vector<uint8_t> last_iv(
ciphertext + ciphertext_size - residual_block_size,
ciphertext + ciphertext_size);
last_iv.resize(AES_BLOCK_SIZE, 0);

// Decrypt the next-to-last block using the IV determined above. This decrypts
// the residual block bits.
CbcDecryptBlocks(next_to_last_ciphertext_block, AES_BLOCK_SIZE,
next_to_last_plaintext_block, last_iv.data());

// Swap back the residual block bits and the next-to-last block.
if (plaintext == ciphertext) {
std::swap_ranges(next_to_last_plaintext_block,
next_to_last_plaintext_block + residual_block_size,
next_to_last_plaintext_block + AES_BLOCK_SIZE);
} else {
memcpy(next_to_last_plaintext_block + AES_BLOCK_SIZE,
next_to_last_plaintext_block, residual_block_size);
memcpy(next_to_last_plaintext_block,
next_to_last_ciphertext_block + AES_BLOCK_SIZE, residual_block_size);
}

uint8_t* next_to_last_plaintext_block = plaintext + cbc_size - AES_BLOCK_SIZE;

// The next-to-last block should be decrypted first in ECB mode, which is
// effectively what you get with an IV of all zeroes.
std::vector<uint8_t> backup_iv(internal_iv_);
internal_iv_.assign(AES_BLOCK_SIZE, 0);
// mbedtls requires a buffer large enough for one extra block.
std::vector<uint8_t> stolen_bits(AES_BLOCK_SIZE * 2);
CbcDecryptBlocks(next_to_last_block.data(), AES_BLOCK_SIZE,
stolen_bits.data());

// Reconstruct the final two blocks of ciphertext.
std::vector<uint8_t> reconstructed_blocks(AES_BLOCK_SIZE * 2);
memcpy(reconstructed_blocks.data(), residual_block.data(),
residual_block_size);
memcpy(reconstructed_blocks.data() + residual_block_size,
stolen_bits.data() + residual_block_size,
AES_BLOCK_SIZE - residual_block_size);
memcpy(reconstructed_blocks.data() + AES_BLOCK_SIZE,
next_to_last_block.data(), AES_BLOCK_SIZE);

// Decrypt the last two blocks.
internal_iv_ = backup_iv;
// mbedtls requires a buffer large enough for one extra block.
std::vector<uint8_t> final_output_blocks(AES_BLOCK_SIZE * 3);
CbcDecryptBlocks(reconstructed_blocks.data(), AES_BLOCK_SIZE * 2,
final_output_blocks.data());

// Copy the final output.
memcpy(next_to_last_plaintext_block, final_output_blocks.data(),
AES_BLOCK_SIZE + residual_block_size);
// Decrypt the next-to-last full block.
CbcDecryptBlocks(next_to_last_plaintext_block, AES_BLOCK_SIZE,
next_to_last_plaintext_block, internal_iv_.data());
return true;
}

Expand All @@ -176,7 +163,8 @@ void AesCbcDecryptor::SetIvInternal() {

void AesCbcDecryptor::CbcDecryptBlocks(const uint8_t* ciphertext,
size_t ciphertext_size,
uint8_t* plaintext) {
uint8_t* plaintext,
uint8_t* iv) {
CHECK_EQ(ciphertext_size % AES_BLOCK_SIZE, 0u);
CHECK_GT(ciphertext_size, 0u);

Expand All @@ -186,14 +174,12 @@ void AesCbcDecryptor::CbcDecryptBlocks(const uint8_t* ciphertext,
std::vector<uint8_t> next_iv(last_block, last_block + AES_BLOCK_SIZE);

size_t output_size = 0;
CHECK_EQ(mbedtls_cipher_crypt(&cipher_ctx_, internal_iv_.data(),
AES_BLOCK_SIZE, ciphertext, ciphertext_size,
plaintext, &output_size),
CHECK_EQ(mbedtls_cipher_crypt(&cipher_ctx_, iv, AES_BLOCK_SIZE, ciphertext,
ciphertext_size, plaintext, &output_size),
0);
DCHECK_EQ(output_size % AES_BLOCK_SIZE, 0u);

// Update the internal IV.
internal_iv_ = next_iv;
memcpy(iv, next_iv.data(), next_iv.size());
}

} // namespace media
Expand Down
3 changes: 2 additions & 1 deletion packager/media/base/aes_decryptor.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ class AesCbcDecryptor : public AesCryptor {
void SetIvInternal() override;
void CbcDecryptBlocks(const uint8_t* plaintext,
size_t plaintext_size,
uint8_t* ciphertext);
uint8_t* ciphertext,
uint8_t* iv);

const CbcPaddingScheme padding_scheme_;
// 16-byte internal iv for crypto operations.
Expand Down
46 changes: 18 additions & 28 deletions packager/media/base/aes_encryptor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,7 @@ namespace media {
AesCtrEncryptor::AesCtrEncryptor()
: AesCryptor(kDontUseConstantIv),
block_offset_(0),
// mbedtls requires an extra output block.
encrypted_counter_(AES_BLOCK_SIZE * 2, 0) {}
encrypted_counter_(AES_BLOCK_SIZE, 0) {}

AesCtrEncryptor::~AesCtrEncryptor() {}

Expand Down Expand Up @@ -129,8 +128,7 @@ bool AesCbcEncryptor::InitializeWithIv(const std::vector<uint8_t>& key,
}

size_t AesCbcEncryptor::RequiredOutputSize(size_t plaintext_size) {
// mbedtls requires a buffer large enough for one extra block.
return plaintext_size + NumPaddingBytes(plaintext_size) + AES_BLOCK_SIZE;
return plaintext_size + NumPaddingBytes(plaintext_size);
}

bool AesCbcEncryptor::CryptInternal(const uint8_t* plaintext,
Expand All @@ -146,19 +144,12 @@ bool AesCbcEncryptor::CryptInternal(const uint8_t* plaintext,
<< required_ciphertext_size << " bytes.";
return false;
}
*ciphertext_size = required_ciphertext_size - AES_BLOCK_SIZE;
*ciphertext_size = required_ciphertext_size;

// Encrypt everything but the residual block using CBC.
const size_t cbc_size = plaintext_size - residual_block_size;

// Copy the residual block early, since mbedtls may overwrite one extra block
// of the output, and input and output may be the same buffer.
std::vector<uint8_t> residual_block(plaintext + cbc_size,
plaintext + plaintext_size);
DCHECK_EQ(residual_block.size(), residual_block_size);

if (cbc_size != 0) {
CbcEncryptBlocks(plaintext, cbc_size, ciphertext);
CbcEncryptBlocks(plaintext, cbc_size, ciphertext, internal_iv_.data());
} else if (padding_scheme_ == kCtsPadding) {
// Don't have a full block, leave unencrypted.
memcpy(ciphertext, plaintext, plaintext_size);
Expand All @@ -175,27 +166,26 @@ bool AesCbcEncryptor::CryptInternal(const uint8_t* plaintext,
return true;
}

std::vector<uint8_t> residual_block(plaintext + cbc_size,
plaintext + plaintext_size);
DCHECK_EQ(residual_block.size(), residual_block_size);
uint8_t* residual_ciphertext_block = ciphertext + cbc_size;

if (padding_scheme_ == kPkcs5Padding) {
DCHECK_EQ(num_padding_bytes, AES_BLOCK_SIZE - residual_block_size);

// Pad residue block with PKCS5 padding.
residual_block.resize(AES_BLOCK_SIZE, static_cast<char>(num_padding_bytes));

CbcEncryptBlocks(residual_block.data(), AES_BLOCK_SIZE,
residual_ciphertext_block);
residual_ciphertext_block, internal_iv_.data());
} else {
DCHECK_EQ(num_padding_bytes, 0u);
DCHECK_EQ(padding_scheme_, kCtsPadding);

// Zero-pad the residual block and encrypt using CBC.
residual_block.resize(AES_BLOCK_SIZE, 0);
// mbedtls requires an extra block in the output buffer, and it cannot be
// the same as the input buffer.
std::vector<uint8_t> encrypted_residual_block(AES_BLOCK_SIZE * 2);

CbcEncryptBlocks(residual_block.data(), AES_BLOCK_SIZE,
encrypted_residual_block.data());
residual_block.data(), internal_iv_.data());

// Replace the last full block with the zero-padded, encrypted residual
// block, and replace the residual block with the equivalent portion of the
Expand All @@ -206,8 +196,8 @@ bool AesCbcEncryptor::CryptInternal(const uint8_t* plaintext,
// https://en.wikipedia.org/wiki/Ciphertext_stealing#CS2
memcpy(residual_ciphertext_block,
residual_ciphertext_block - AES_BLOCK_SIZE, residual_block_size);
memcpy(residual_ciphertext_block - AES_BLOCK_SIZE,
encrypted_residual_block.data(), AES_BLOCK_SIZE);
memcpy(residual_ciphertext_block - AES_BLOCK_SIZE, residual_block.data(),
AES_BLOCK_SIZE);
}
return true;
}
Expand All @@ -225,20 +215,20 @@ size_t AesCbcEncryptor::NumPaddingBytes(size_t size) const {

void AesCbcEncryptor::CbcEncryptBlocks(const uint8_t* plaintext,
size_t plaintext_size,
uint8_t* ciphertext) {
uint8_t* ciphertext,
uint8_t* iv) {
CHECK_EQ(plaintext_size % AES_BLOCK_SIZE, 0u);

size_t output_size = 0;
CHECK_EQ(
mbedtls_cipher_crypt(&cipher_ctx_, internal_iv_.data(), AES_BLOCK_SIZE,
plaintext, plaintext_size, ciphertext, &output_size),
0);
CHECK_EQ(mbedtls_cipher_crypt(&cipher_ctx_, iv, AES_BLOCK_SIZE, plaintext,
plaintext_size, ciphertext, &output_size),
0);

CHECK_EQ(output_size % AES_BLOCK_SIZE, 0u);
CHECK_GT(output_size, 0u);

uint8_t* last_block = ciphertext + output_size - AES_BLOCK_SIZE;
internal_iv_.assign(last_block, last_block + AES_BLOCK_SIZE);
memcpy(iv, last_block, AES_BLOCK_SIZE);
}

} // namespace media
Expand Down
3 changes: 2 additions & 1 deletion packager/media/base/aes_encryptor.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@ class AesCbcEncryptor : public AesCryptor {

void CbcEncryptBlocks(const uint8_t* plaintext,
size_t plaintext_size,
uint8_t* ciphertext);
uint8_t* ciphertext,
uint8_t* iv);

const CbcPaddingScheme padding_scheme_;
// 16-byte internal iv for crypto operations.
Expand Down
5 changes: 1 addition & 4 deletions packager/media/base/playready_pssh_generator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,7 @@ void AesEcbEncrypt(const std::vector<uint8_t>& key,
const std::vector<uint8_t>& plaintext,
std::vector<uint8_t>* ciphertext) {
CHECK_EQ(plaintext.size() % AES_BLOCK_SIZE, 0u);
// mbedtls requires an extra block worth of output buffer.
ciphertext->resize(plaintext.size() + AES_BLOCK_SIZE);
ciphertext->resize(plaintext.size());

mbedtls_cipher_context_t ctx;
mbedtls_cipher_init(&ctx);
Expand All @@ -98,8 +97,6 @@ void AesEcbEncrypt(const std::vector<uint8_t>& key,
plaintext.data(), plaintext.size(),
ciphertext->data(), &output_size),
0);
// Truncate the output to the correct size.
ciphertext->resize(plaintext.size());

mbedtls_cipher_free(&ctx);
}
Expand Down

0 comments on commit db59ad5

Please sign in to comment.