diff --git a/src/libraries/System.Security.Cryptography.Algorithms/src/System/Security/Cryptography/Rfc2898DeriveBytes.cs b/src/libraries/System.Security.Cryptography.Algorithms/src/System/Security/Cryptography/Rfc2898DeriveBytes.cs index b6e615e2b880c..bf833303ecc90 100644 --- a/src/libraries/System.Security.Cryptography.Algorithms/src/System/Security/Cryptography/Rfc2898DeriveBytes.cs +++ b/src/libraries/System.Security.Cryptography.Algorithms/src/System/Security/Cryptography/Rfc2898DeriveBytes.cs @@ -16,7 +16,6 @@ public class Rfc2898DeriveBytes : DeriveBytes { private const int MinimumSaltSize = 8; - private readonly byte[] _password; private byte[] _salt; private uint _iterations; private HMAC _hmac; @@ -38,26 +37,8 @@ public Rfc2898DeriveBytes(byte[] password, byte[] salt, int iterations) } public Rfc2898DeriveBytes(byte[] password, byte[] salt, int iterations, HashAlgorithmName hashAlgorithm) + :this(password, salt, iterations, hashAlgorithm, clearPassword: false) { - if (salt == null) - throw new ArgumentNullException(nameof(salt)); - if (salt.Length < MinimumSaltSize) - throw new ArgumentException(SR.Cryptography_PasswordDerivedBytes_FewBytesSalt, nameof(salt)); - if (iterations <= 0) - throw new ArgumentOutOfRangeException(nameof(iterations), SR.ArgumentOutOfRange_NeedPosNum); - if (password == null) - throw new NullReferenceException(); // This "should" be ArgumentNullException but for compat, we throw NullReferenceException. - - _salt = new byte[salt.Length + sizeof(uint)]; - salt.AsSpan().CopyTo(_salt); - _iterations = (uint)iterations; - _password = password.CloneByteArray(); - HashAlgorithm = hashAlgorithm; - _hmac = OpenHmac(); - // _blockSize is in bytes, HashSize is in bits. - _blockSize = _hmac.HashSize >> 3; - - Initialize(); } public Rfc2898DeriveBytes(string password, byte[] salt) @@ -71,7 +52,7 @@ public Rfc2898DeriveBytes(string password, byte[] salt, int iterations) } public Rfc2898DeriveBytes(string password, byte[] salt, int iterations, HashAlgorithmName hashAlgorithm) - : this(Encoding.UTF8.GetBytes(password), salt, iterations, hashAlgorithm) + : this(Encoding.UTF8.GetBytes(password), salt, iterations, hashAlgorithm, clearPassword: true) { } @@ -98,15 +79,43 @@ public Rfc2898DeriveBytes(string password, int saltSize, int iterations, HashAlg RandomNumberGenerator.Fill(_salt.AsSpan(0, saltSize)); _iterations = (uint)iterations; - _password = Encoding.UTF8.GetBytes(password); + byte[] passwordBytes = Encoding.UTF8.GetBytes(password); HashAlgorithm = hashAlgorithm; - _hmac = OpenHmac(); + _hmac = OpenHmac(passwordBytes); + CryptographicOperations.ZeroMemory(passwordBytes); // _blockSize is in bytes, HashSize is in bits. _blockSize = _hmac.HashSize >> 3; Initialize(); } + private Rfc2898DeriveBytes(byte[] password, byte[] salt, int iterations, HashAlgorithmName hashAlgorithm, bool clearPassword) + { + if (salt is null) + throw new ArgumentNullException(nameof(salt)); + if (salt.Length < MinimumSaltSize) + throw new ArgumentException(SR.Cryptography_PasswordDerivedBytes_FewBytesSalt, nameof(salt)); + if (iterations <= 0) + throw new ArgumentOutOfRangeException(nameof(iterations), SR.ArgumentOutOfRange_NeedPosNum); + if (password is null) + throw new NullReferenceException(); // This "should" be ArgumentNullException but for compat, we throw NullReferenceException. + + _salt = new byte[salt.Length + sizeof(uint)]; + salt.AsSpan().CopyTo(_salt); + _iterations = (uint)iterations; + HashAlgorithm = hashAlgorithm; + _hmac = OpenHmac(password); + + if (clearPassword) + { + CryptographicOperations.ZeroMemory(password); + } + + // _blockSize is in bytes, HashSize is in bits. + _blockSize = _hmac.HashSize >> 3; + Initialize(); + } + public int IterationCount { get @@ -155,8 +164,6 @@ protected override void Dispose(bool disposing) if (_buffer != null) Array.Clear(_buffer, 0, _buffer.Length); - if (_password != null) - Array.Clear(_password, 0, _password.Length); if (_salt != null) Array.Clear(_salt, 0, _salt.Length); } @@ -228,9 +235,9 @@ public override void Reset() } [System.Diagnostics.CodeAnalysis.SuppressMessage("Microsoft.Security", "CA5350", Justification = "HMACSHA1 is needed for compat. (https://github.com/dotnet/runtime/issues/17618)")] - private HMAC OpenHmac() + private HMAC OpenHmac(byte[] password) { - Debug.Assert(_password != null); + Debug.Assert(password != null); HashAlgorithmName hashAlgorithm = HashAlgorithm; @@ -238,13 +245,13 @@ private HMAC OpenHmac() throw new CryptographicException(SR.Cryptography_HashAlgorithmNameNullOrEmpty); if (hashAlgorithm == HashAlgorithmName.SHA1) - return new HMACSHA1(_password); + return new HMACSHA1(password); if (hashAlgorithm == HashAlgorithmName.SHA256) - return new HMACSHA256(_password); + return new HMACSHA256(password); if (hashAlgorithm == HashAlgorithmName.SHA384) - return new HMACSHA384(_password); + return new HMACSHA384(password); if (hashAlgorithm == HashAlgorithmName.SHA512) - return new HMACSHA512(_password); + return new HMACSHA512(password); throw new CryptographicException(SR.Format(SR.Cryptography_UnknownHashAlgorithm, hashAlgorithm.Name)); } diff --git a/src/libraries/System.Security.Cryptography.Algorithms/tests/Rfc2898Tests.cs b/src/libraries/System.Security.Cryptography.Algorithms/tests/Rfc2898Tests.cs index 8090be717e487..c8fabd02d466a 100644 --- a/src/libraries/System.Security.Cryptography.Algorithms/tests/Rfc2898Tests.cs +++ b/src/libraries/System.Security.Cryptography.Algorithms/tests/Rfc2898Tests.cs @@ -409,6 +409,38 @@ public static void GetBytes_ExceedCounterLimit() } } + [Fact] + public static void Ctor_PasswordMutatedAfterCreate() + { + byte[] passwordBytes = Encoding.UTF8.GetBytes(TestPassword); + byte[] derived; + + using (Rfc2898DeriveBytes deriveBytes = new Rfc2898DeriveBytes(passwordBytes, s_testSaltB, DefaultIterationCount)) + { + derived = deriveBytes.GetBytes(64); + } + + using (Rfc2898DeriveBytes deriveBytes = new Rfc2898DeriveBytes(passwordBytes, s_testSaltB, DefaultIterationCount)) + { + passwordBytes[0] ^= 0xFF; // Flipping a byte after the object is constructed should not be observed. + + byte[] actual = deriveBytes.GetBytes(64); + Assert.Equal(derived, actual); + } + } + + [Fact] + public static void Ctor_PasswordBytes_NotCleared() + { + byte[] passwordBytes = Encoding.UTF8.GetBytes(TestPassword); + byte[] passwordBytesOriginal = passwordBytes.AsSpan().ToArray(); + + using (Rfc2898DeriveBytes deriveBytes = new Rfc2898DeriveBytes(passwordBytes, s_testSaltB, DefaultIterationCount)) + { + Assert.Equal(passwordBytesOriginal, passwordBytes); + } + } + private static void TestKnownValue(string password, byte[] salt, int iterationCount, byte[] expected) { byte[] output;