diff --git a/src/TestFramework/TestFramework/Assertions/Assert.AreEqual.cs b/src/TestFramework/TestFramework/Assertions/Assert.AreEqual.cs index 50eb9b44be..f2fece6699 100644 --- a/src/TestFramework/TestFramework/Assertions/Assert.AreEqual.cs +++ b/src/TestFramework/TestFramework/Assertions/Assert.AreEqual.cs @@ -201,6 +201,105 @@ public static void AreEqual(T? expected, T? actual, IEqualityComparer? com ThrowAssertFailed("Assert.AreEqual", finalMessage); } + /// + /// Tests whether the specified values are equal and throws an exception + /// if the two values are not equal. + /// The equality is computed using the default . + /// + /// + /// The type of values to compare. + /// + /// + /// The first value to compare. This is the value the tests expects. + /// + /// + /// The second value to compare. This is the value produced by the code under test. + /// + /// + /// Thrown if is not equal to . + /// + public static void AreEqual(IEquatable? expected, IEquatable? actual) + => AreEqual(expected, actual, string.Empty, null); + + /// + /// Tests whether the specified values are equal and throws an exception + /// if the two values are not equal. + /// The equality is computed using the default . + /// + /// + /// The type of values to compare. + /// + /// + /// The first value to compare. This is the value the tests expects. + /// + /// + /// The second value to compare. This is the value produced by the code under test. + /// + /// + /// The message to include in the exception when + /// is not equal to . The message is shown in + /// test results. + /// + /// + /// Thrown if is not equal to + /// . + /// + public static void AreEqual(IEquatable? expected, IEquatable? actual, string? message) + => AreEqual(expected, actual, message, null); + + /// + /// Tests whether the specified values are equal and throws an exception + /// if the two values are not equal. + /// The equality is computed using the default . + /// + /// + /// The type of values to compare. + /// + /// + /// The first value to compare. This is the value the tests expects. + /// + /// + /// The second value to compare. This is the value produced by the code under test. + /// + /// + /// The message to include in the exception when + /// is not equal to . The message is shown in + /// test results. + /// + /// + /// An array of parameters to use when formatting . + /// + /// + /// Thrown if is not equal to + /// . + /// + public static void AreEqual(IEquatable? expected, IEquatable? actual, string? message, params object?[]? parameters) + { + if (actual?.Equals(expected) == true) + { + return; + } + + string userMessage = BuildUserMessage(message, parameters); + string finalMessage = actual != null && expected != null && !actual.GetType().Equals(expected.GetType()) + ? string.Format( + CultureInfo.CurrentCulture, + FrameworkMessages.AreEqualDifferentTypesFailMsg, + userMessage, + ReplaceNulls(expected), + expected.GetType().FullName, + ReplaceNulls(actual), + actual.GetType().FullName) + : string.Format( + CultureInfo.CurrentCulture, + FrameworkMessages.AreEqualFailMsg, + userMessage, + ReplaceNulls(expected), + ReplaceNulls(actual)); + + ThrowAssertFailed("Assert.AreEqual", finalMessage); + } + /// /// Tests whether the specified values are unequal and throws an exception /// if the two values are equal. diff --git a/src/TestFramework/TestFramework/PublicAPI/PublicAPI.Shipped.txt b/src/TestFramework/TestFramework/PublicAPI/PublicAPI.Shipped.txt index f5260779bf..82e656b659 100644 --- a/src/TestFramework/TestFramework/PublicAPI/PublicAPI.Shipped.txt +++ b/src/TestFramework/TestFramework/PublicAPI/PublicAPI.Shipped.txt @@ -250,6 +250,9 @@ static Microsoft.VisualStudio.TestTools.UnitTesting.Assert.AreEqual(string? expe static Microsoft.VisualStudio.TestTools.UnitTesting.Assert.AreEqual(string? expected, string? actual, bool ignoreCase, System.Globalization.CultureInfo? culture) -> void static Microsoft.VisualStudio.TestTools.UnitTesting.Assert.AreEqual(string? expected, string? actual, bool ignoreCase, System.Globalization.CultureInfo? culture, string? message) -> void static Microsoft.VisualStudio.TestTools.UnitTesting.Assert.AreEqual(string? expected, string? actual, bool ignoreCase, System.Globalization.CultureInfo? culture, string? message, params object?[]? parameters) -> void +static Microsoft.VisualStudio.TestTools.UnitTesting.Assert.AreEqual(System.IEquatable? expected, System.IEquatable? actual) -> void +static Microsoft.VisualStudio.TestTools.UnitTesting.Assert.AreEqual(System.IEquatable? expected, System.IEquatable? actual, string? message) -> void +static Microsoft.VisualStudio.TestTools.UnitTesting.Assert.AreEqual(System.IEquatable? expected, System.IEquatable? actual, string? message, params object?[]? parameters) -> void static Microsoft.VisualStudio.TestTools.UnitTesting.Assert.AreEqual(T? expected, T? actual) -> void static Microsoft.VisualStudio.TestTools.UnitTesting.Assert.AreEqual(T? expected, T? actual, string? message) -> void static Microsoft.VisualStudio.TestTools.UnitTesting.Assert.AreEqual(T? expected, T? actual, string? message, params object?[]? parameters) -> void diff --git a/test/UnitTests/TestFramework.UnitTests/Assertions/AssertTests.AreEqualTests.cs b/test/UnitTests/TestFramework.UnitTests/Assertions/AssertTests.AreEqualTests.cs index 56a3dd19dd..2f80745e1e 100644 --- a/test/UnitTests/TestFramework.UnitTests/Assertions/AssertTests.AreEqualTests.cs +++ b/test/UnitTests/TestFramework.UnitTests/Assertions/AssertTests.AreEqualTests.cs @@ -381,6 +381,18 @@ public void AreEqualStringIgnoreCaseCultureInfoMessageParametersNullabilityPostC _ = cultureInfo.Calendar; // no warning } + public void AreEqualUsingCustomIEquatable() + { + var instanceOfA = new A { Id = "SomeId" }; + var instanceOfB = new B { Id = "SomeId" }; + + // This call works because B implements IEquatable + Assert.AreEqual(instanceOfA, instanceOfB); + + // This one doesn't work + VerifyThrows(() => Assert.AreEqual(instanceOfB, instanceOfA)); + } + private CultureInfo? GetCultureInfo() => CultureInfo.CurrentCulture; private class TypeOverridesEquals @@ -426,4 +438,32 @@ public override int GetHashCode(TypeOverridesEquals obj) throw new NotImplementedException(); } } + + private class A : IEquatable + { + public string Id { get; set; } = string.Empty; + + public bool Equals(A? other) + => other?.Id == Id; + + public override bool Equals(object? obj) + => Equals(obj as A); + + public override int GetHashCode() + => Id.GetHashCode() + 123; + } + + private class B : IEquatable + { + public string Id { get; set; } = string.Empty; + + public override bool Equals(object? obj) + => Equals(obj as A); + + public bool Equals(A? other) + => other?.Id == Id; + + public override int GetHashCode() + => Id.GetHashCode() + 1234; + } }