Skip to content

Commit

Permalink
Minor cleanup for Rfc2898DeriveBytes
Browse files Browse the repository at this point in the history
The _password field is not needed since CryptDeriveKey was not ported
from the Desktop framework.

Removing the field also allows removing a defensive copy and clearing it
during disposal.
  • Loading branch information
vcsjones authored Jan 27, 2021
1 parent ad7a50e commit 2ecec9c
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ public class Rfc2898DeriveBytes : DeriveBytes
{
private const int MinimumSaltSize = 8;

private readonly byte[] _password;
private byte[] _salt;
private uint _iterations;
private HMAC _hmac;
Expand All @@ -38,26 +37,8 @@ public Rfc2898DeriveBytes(byte[] password, byte[] salt, int iterations)
}

public Rfc2898DeriveBytes(byte[] password, byte[] salt, int iterations, HashAlgorithmName hashAlgorithm)
:this(password, salt, iterations, hashAlgorithm, clearPassword: false)
{
if (salt == null)
throw new ArgumentNullException(nameof(salt));
if (salt.Length < MinimumSaltSize)
throw new ArgumentException(SR.Cryptography_PasswordDerivedBytes_FewBytesSalt, nameof(salt));
if (iterations <= 0)
throw new ArgumentOutOfRangeException(nameof(iterations), SR.ArgumentOutOfRange_NeedPosNum);
if (password == null)
throw new NullReferenceException(); // This "should" be ArgumentNullException but for compat, we throw NullReferenceException.

_salt = new byte[salt.Length + sizeof(uint)];
salt.AsSpan().CopyTo(_salt);
_iterations = (uint)iterations;
_password = password.CloneByteArray();
HashAlgorithm = hashAlgorithm;
_hmac = OpenHmac();
// _blockSize is in bytes, HashSize is in bits.
_blockSize = _hmac.HashSize >> 3;

Initialize();
}

public Rfc2898DeriveBytes(string password, byte[] salt)
Expand All @@ -71,7 +52,7 @@ public Rfc2898DeriveBytes(string password, byte[] salt, int iterations)
}

public Rfc2898DeriveBytes(string password, byte[] salt, int iterations, HashAlgorithmName hashAlgorithm)
: this(Encoding.UTF8.GetBytes(password), salt, iterations, hashAlgorithm)
: this(Encoding.UTF8.GetBytes(password), salt, iterations, hashAlgorithm, clearPassword: true)
{
}

Expand All @@ -98,15 +79,43 @@ public Rfc2898DeriveBytes(string password, int saltSize, int iterations, HashAlg
RandomNumberGenerator.Fill(_salt.AsSpan(0, saltSize));

_iterations = (uint)iterations;
_password = Encoding.UTF8.GetBytes(password);
byte[] passwordBytes = Encoding.UTF8.GetBytes(password);
HashAlgorithm = hashAlgorithm;
_hmac = OpenHmac();
_hmac = OpenHmac(passwordBytes);
CryptographicOperations.ZeroMemory(passwordBytes);
// _blockSize is in bytes, HashSize is in bits.
_blockSize = _hmac.HashSize >> 3;

Initialize();
}

private Rfc2898DeriveBytes(byte[] password, byte[] salt, int iterations, HashAlgorithmName hashAlgorithm, bool clearPassword)
{
if (salt is null)
throw new ArgumentNullException(nameof(salt));
if (salt.Length < MinimumSaltSize)
throw new ArgumentException(SR.Cryptography_PasswordDerivedBytes_FewBytesSalt, nameof(salt));
if (iterations <= 0)
throw new ArgumentOutOfRangeException(nameof(iterations), SR.ArgumentOutOfRange_NeedPosNum);
if (password is null)
throw new NullReferenceException(); // This "should" be ArgumentNullException but for compat, we throw NullReferenceException.

_salt = new byte[salt.Length + sizeof(uint)];
salt.AsSpan().CopyTo(_salt);
_iterations = (uint)iterations;
HashAlgorithm = hashAlgorithm;
_hmac = OpenHmac(password);

if (clearPassword)
{
CryptographicOperations.ZeroMemory(password);
}

// _blockSize is in bytes, HashSize is in bits.
_blockSize = _hmac.HashSize >> 3;
Initialize();
}

public int IterationCount
{
get
Expand Down Expand Up @@ -155,8 +164,6 @@ protected override void Dispose(bool disposing)

if (_buffer != null)
Array.Clear(_buffer, 0, _buffer.Length);
if (_password != null)
Array.Clear(_password, 0, _password.Length);
if (_salt != null)
Array.Clear(_salt, 0, _salt.Length);
}
Expand Down Expand Up @@ -228,23 +235,23 @@ public override void Reset()
}

[System.Diagnostics.CodeAnalysis.SuppressMessage("Microsoft.Security", "CA5350", Justification = "HMACSHA1 is needed for compat. (https://github.com/dotnet/runtime/issues/17618)")]
private HMAC OpenHmac()
private HMAC OpenHmac(byte[] password)
{
Debug.Assert(_password != null);
Debug.Assert(password != null);

HashAlgorithmName hashAlgorithm = HashAlgorithm;

if (string.IsNullOrEmpty(hashAlgorithm.Name))
throw new CryptographicException(SR.Cryptography_HashAlgorithmNameNullOrEmpty);

if (hashAlgorithm == HashAlgorithmName.SHA1)
return new HMACSHA1(_password);
return new HMACSHA1(password);
if (hashAlgorithm == HashAlgorithmName.SHA256)
return new HMACSHA256(_password);
return new HMACSHA256(password);
if (hashAlgorithm == HashAlgorithmName.SHA384)
return new HMACSHA384(_password);
return new HMACSHA384(password);
if (hashAlgorithm == HashAlgorithmName.SHA512)
return new HMACSHA512(_password);
return new HMACSHA512(password);

throw new CryptographicException(SR.Format(SR.Cryptography_UnknownHashAlgorithm, hashAlgorithm.Name));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,38 @@ public static void GetBytes_ExceedCounterLimit()
}
}

[Fact]
public static void Ctor_PasswordMutatedAfterCreate()
{
byte[] passwordBytes = Encoding.UTF8.GetBytes(TestPassword);
byte[] derived;

using (Rfc2898DeriveBytes deriveBytes = new Rfc2898DeriveBytes(passwordBytes, s_testSaltB, DefaultIterationCount))
{
derived = deriveBytes.GetBytes(64);
}

using (Rfc2898DeriveBytes deriveBytes = new Rfc2898DeriveBytes(passwordBytes, s_testSaltB, DefaultIterationCount))
{
passwordBytes[0] ^= 0xFF; // Flipping a byte after the object is constructed should not be observed.

byte[] actual = deriveBytes.GetBytes(64);
Assert.Equal(derived, actual);
}
}

[Fact]
public static void Ctor_PasswordBytes_NotCleared()
{
byte[] passwordBytes = Encoding.UTF8.GetBytes(TestPassword);
byte[] passwordBytesOriginal = passwordBytes.AsSpan().ToArray();

using (Rfc2898DeriveBytes deriveBytes = new Rfc2898DeriveBytes(passwordBytes, s_testSaltB, DefaultIterationCount))
{
Assert.Equal(passwordBytesOriginal, passwordBytes);
}
}

private static void TestKnownValue(string password, byte[] salt, int iterationCount, byte[] expected)
{
byte[] output;
Expand Down

0 comments on commit 2ecec9c

Please sign in to comment.