Skip to content

Commit

Permalink
Token Replay validation: Remove exceptions (#2679)
Browse files Browse the repository at this point in the history
* Added ReplayValidationResult to handle TokenReplay validations.
* Refactored ValidateTokenReplay, removing exception throwing
* Added tests
  • Loading branch information
iNinja authored Jul 4, 2024
1 parent 9f3c7f3 commit 67ee21d
Show file tree
Hide file tree
Showing 5 changed files with 595 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System;

#nullable enable
namespace Microsoft.IdentityModel.Tokens
{
/// <summary>
/// Contains the result of validating that a <see cref="SecurityToken"/> has not been replayed.
/// The <see cref="TokenValidationResult"/> contains a collection of <see cref="ValidationResult"/> for each step in the token validation.
/// </summary>
internal class ReplayValidationResult : ValidationResult
{
private Exception? _exception;

/// <summary>
/// Creates an instance of <see cref="ReplayValidationResult"/>.
/// </summary>
/// <paramref name="expirationTime"/> is the expiration date against which the token was validated.
public ReplayValidationResult(DateTime? expirationTime) : base(ValidationFailureType.ValidationSucceeded)
{
IsValid = true;
ExpirationTime = expirationTime;
}

/// <summary>
/// Creates an instance of <see cref="ReplayValidationResult"/>
/// </summary>
/// <paramref name="expirationTime"/> is the expiration date against which the token was validated.
/// <paramref name="validationFailure"/> is the <see cref="ValidationFailureType"/> that occurred during validation.
/// <paramref name="exceptionDetail"/> is the <see cref="ExceptionDetail"/> that occurred during validation.
public ReplayValidationResult(DateTime? expirationTime, ValidationFailureType validationFailure, ExceptionDetail exceptionDetail)
: base(validationFailure, exceptionDetail)
{
IsValid = false;
ExpirationTime = expirationTime;
}

/// <summary>
/// Gets the <see cref="Exception"/> that occurred during validation.
/// </summary>
public override Exception? Exception
{
get
{
if (_exception != null || ExceptionDetail == null)
return _exception;

HasValidOrExceptionWasRead = true;
_exception = ExceptionDetail.GetException();
_exception.Source = "Microsoft.IdentityModel.Tokens";

if (_exception is SecurityTokenReplayDetectedException securityTokenReplayDetectedException)
{
securityTokenReplayDetectedException.ExceptionDetail = ExceptionDetail;
}
else if (_exception is SecurityTokenReplayAddFailedException securityTokenReplayAddFailedException)
{
securityTokenReplayAddFailedException.ExceptionDetail = ExceptionDetail;
}

return _exception;
}
}

/// <summary>
/// Gets the expiration date against which the token was validated.
/// </summary>
public DateTime? ExpirationTime { get; }
}
}
#nullable restore
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,12 @@ private class SigningKeyValidationFailure : ValidationFailureType { internal Sig
public static readonly ValidationFailureType LifetimeValidationFailed = new LifetimeValidationFailure("LifetimeValidationFailure");
private class LifetimeValidationFailure : ValidationFailureType { internal LifetimeValidationFailure(string name) : base(name) { } }

/// <summary>
/// Defines a type that represents that token replay validation failed.
/// </summary>
public static readonly ValidationFailureType TokenReplayValidationFailed = new TokenReplayValidationFailure("TokenReplayValidationFailed");
private class TokenReplayValidationFailure : ValidationFailureType { internal TokenReplayValidationFailure(string name) : base(name) { } }

/// <summary>
/// Defines a type that represents that no evaluation has taken place.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,33 @@
// Licensed under the MIT License.

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Security.Cryptography.X509Certificates;
using Microsoft.IdentityModel.Abstractions;
using Microsoft.IdentityModel.Logging;

namespace Microsoft.IdentityModel.Tokens
{
/// <summary>
/// Definition for delegate that will validate that a <see cref="SecurityToken"/> has not been replayed.
/// </summary>
/// <param name="expirationTime">When does the <see cref="SecurityToken"/> expire..</param>
/// <param name="securityToken">The security token that is being validated.</param>
/// <param name="validationParameters"><see cref="TokenValidationParameters"/> required for validation.</param>
/// <param name="callContext"></param>
/// <returns>A <see cref="ReplayValidationResult"/>that contains the results of validating the token.</returns>
/// <remarks>This delegate is not expected to throw.</remarks>
internal delegate ReplayValidationResult ValidateTokenReplay(
DateTime? expirationTime,
string securityToken,
TokenValidationParameters validationParameters,
CallContext callContext);

/// <summary>
/// Partial class for Token Replay validation.
/// </summary>
public static partial class Validators
{
/// <summary>
Expand Down Expand Up @@ -78,5 +98,137 @@ public static void ValidateTokenReplay(string securityToken, DateTime? expiratio
{
ValidateTokenReplay(expirationTime, securityToken, validationParameters);
}

/// <summary>
/// Validates if a token has been replayed.
/// </summary>
/// <param name="expirationTime">When does the security token expire.</param>
/// <param name="securityToken">The <see cref="SecurityToken"/> being validated.</param>
/// <param name="validationParameters"><see cref="TokenValidationParameters"/> required for validation.</param>
/// <param name="callContext"></param>
/// <exception cref="ArgumentNullException">If 'securityToken' is null or whitespace.</exception>
/// <exception cref="ArgumentNullException">If 'validationParameters' is null or whitespace.</exception>
/// <exception cref="SecurityTokenNoExpirationException">If <see cref="TokenValidationParameters.TokenReplayCache"/> is not null and expirationTime.HasValue is false. When a TokenReplayCache is set, tokens require an expiration time.</exception>
/// <exception cref="SecurityTokenReplayDetectedException">If the 'securityToken' is found in the cache.</exception>
/// <exception cref="SecurityTokenReplayAddFailedException">If the 'securityToken' could not be added to the <see cref="TokenValidationParameters.TokenReplayCache"/>.</exception>
#pragma warning disable CA1801 // Review unused parameters
internal static ReplayValidationResult ValidateTokenReplay(DateTime? expirationTime, string securityToken, TokenValidationParameters validationParameters, CallContext callContext)
#pragma warning restore CA1801 // Review unused parameters
{
if (string.IsNullOrWhiteSpace(securityToken))
return new ReplayValidationResult(
expirationTime,
ValidationFailureType.NullArgument,
new ExceptionDetail(
new MessageDetail(
LogMessages.IDX10000,
LogHelper.MarkAsNonPII(nameof(securityToken))),
typeof(ArgumentNullException),
new StackFrame(),
null));

if (validationParameters == null)
return new ReplayValidationResult(
expirationTime,
ValidationFailureType.NullArgument,
new ExceptionDetail(
new MessageDetail(
LogMessages.IDX10000,
LogHelper.MarkAsNonPII(nameof(validationParameters))),
typeof(ArgumentNullException),
new StackFrame(),
null));

if (validationParameters.TokenReplayValidator != null)
{
return ValidateTokenReplayUsingDelegate(expirationTime, securityToken, validationParameters);
}

if (!validationParameters.ValidateTokenReplay)
{
LogHelper.LogVerbose(LogMessages.IDX10246);

return new ReplayValidationResult(expirationTime);
}

// check if token if replay cache is set, then there must be an expiration time.
if (validationParameters.TokenReplayCache != null)
{
if (expirationTime == null)
return new ReplayValidationResult(
expirationTime,
ValidationFailureType.TokenReplayValidationFailed,
new ExceptionDetail(
new MessageDetail(
LogMessages.IDX10227,
LogHelper.MarkAsUnsafeSecurityArtifact(securityToken, t => t.ToString())),
typeof(SecurityTokenReplayDetectedException),
new StackFrame(),
null));

if (validationParameters.TokenReplayCache.TryFind(securityToken))
return new ReplayValidationResult(
expirationTime,
ValidationFailureType.TokenReplayValidationFailed,
new ExceptionDetail(
new MessageDetail(
LogMessages.IDX10228,
LogHelper.MarkAsUnsafeSecurityArtifact(securityToken, t => t.ToString())),
typeof(SecurityTokenReplayDetectedException),
new StackFrame(),
null));

if (!validationParameters.TokenReplayCache.TryAdd(securityToken, expirationTime.Value))
return new ReplayValidationResult(
expirationTime,
ValidationFailureType.TokenReplayValidationFailed,
new ExceptionDetail(
new MessageDetail(
LogMessages.IDX10229,
LogHelper.MarkAsUnsafeSecurityArtifact(securityToken, t => t.ToString())),
typeof(SecurityTokenReplayAddFailedException),
new StackFrame(),
null));
}

// if it reaches here, that means no token replay is detected.
LogHelper.LogInformation(LogMessages.IDX10240);
return new ReplayValidationResult(expirationTime);
}

private static ReplayValidationResult ValidateTokenReplayUsingDelegate(DateTime? expirationTime, string securityToken, TokenValidationParameters validationParameters)
{
try
{
if (!validationParameters.TokenReplayValidator(expirationTime, securityToken, validationParameters))
return new ReplayValidationResult(
expirationTime,
ValidationFailureType.TokenReplayValidationFailed,
new ExceptionDetail(
new MessageDetail(
LogMessages.IDX10228,
LogHelper.MarkAsUnsafeSecurityArtifact(securityToken, t => t.ToString())),
typeof(SecurityTokenReplayDetectedException),
new StackFrame(),
null));

return new ReplayValidationResult(expirationTime);
}
#pragma warning disable CA1031 // Do not catch general exception types
catch (Exception exception)
#pragma warning restore CA1031 // Do not catch general exception types
{
return new ReplayValidationResult(
expirationTime,
ValidationFailureType.TokenReplayValidationFailed,
new ExceptionDetail(
new MessageDetail(
LogMessages.IDX10228,
LogHelper.MarkAsUnsafeSecurityArtifact(securityToken, t => t.ToString())),
exception.GetType(),
new StackFrame(),
exception));
}
}
}
}
66 changes: 66 additions & 0 deletions test/Microsoft.IdentityModel.TestUtils/IdentityComparer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -807,6 +807,72 @@ internal static bool AreLifetimeValidationResultsEqual(
return context.Merge(localContext);
}

public static bool AreTokenReplayValidationResultsEqual(object object1, object object2, CompareContext context)
{
var localContext = new CompareContext(context);
if (!ContinueCheckingEquality(object1, object2, context))
return context.Merge(localContext);

return AreTokenReplayValidationResultsEqual(
object1 as ReplayValidationResult,
object2 as ReplayValidationResult,
"ReplayValidationResult1",
"ReplayValidationResult2",
null,
context);
}

internal static bool AreTokenReplayValidationResultsEqual(
ReplayValidationResult replayValidationResult1,
ReplayValidationResult replayValidationResult2,
string name1,
string name2,
string stackPrefix,
CompareContext context)
{
var localContext = new CompareContext(context);
if (!ContinueCheckingEquality(replayValidationResult1, replayValidationResult2, localContext))
return context.Merge(localContext);

if (replayValidationResult1.ExpirationTime != replayValidationResult2.ExpirationTime)
localContext.Diffs.Add($"ReplayValidationResult1.ExpirationTime: '{replayValidationResult1.ExpirationTime}' != ReplayValidationResult2.ExpirationTime: '{replayValidationResult2.ExpirationTime}'");

if (replayValidationResult1.IsValid != replayValidationResult2.IsValid)
localContext.Diffs.Add($"ReplayValidationResult1.IsValid: {replayValidationResult1.IsValid} != ReplayValidationResult2.IsValid: {replayValidationResult2.IsValid}");

if (replayValidationResult1.ValidationFailureType != replayValidationResult2.ValidationFailureType)
localContext.Diffs.Add($"ReplayValidationResult1.ValidationFailureType: {replayValidationResult1.ValidationFailureType} != ReplayValidationResult2.ValidationFailureType: {replayValidationResult2.ValidationFailureType}");

// true => both are not null.
if (ContinueCheckingEquality(replayValidationResult1.Exception, replayValidationResult2.Exception, localContext))
{
AreStringsEqual(
replayValidationResult1.Exception.Message,
replayValidationResult2.Exception.Message,
$"({name1}).Exception.Message",
$"({name2}).Exception.Message",
localContext);

AreStringsEqual(
replayValidationResult1.Exception.Source,
replayValidationResult2.Exception.Source,
$"({name1}).Exception.Source",
$"({name2}).Exception.Source",
localContext);

if (!string.IsNullOrEmpty(stackPrefix))
AreStringPrefixesEqual(
replayValidationResult1.Exception.StackTrace.Trim(),
replayValidationResult2.Exception.StackTrace.Trim(),
$"({name1}).Exception.StackTrace",
$"({name2}).Exception.StackTrace",
stackPrefix.Trim(),
localContext);
}

return context.Merge(localContext);
}

public static bool AreJArraysEqual(object object1, object object2, CompareContext context)
{
var localContext = new CompareContext(context);
Expand Down
Loading

0 comments on commit 67ee21d

Please sign in to comment.