diff --git a/sdk/identity/Azure.Identity/src/AuthenticationRecord.cs b/sdk/identity/Azure.Identity/src/AuthenticationRecord.cs index 0b125a7421033..c74a76b1aae54 100644 --- a/sdk/identity/Azure.Identity/src/AuthenticationRecord.cs +++ b/sdk/identity/Azure.Identity/src/AuthenticationRecord.cs @@ -45,10 +45,9 @@ internal AuthenticationRecord(AuthenticationResult authResult, string clientId) internal AuthenticationRecord(string username, string authority, string homeAccountId, string tenantId, string clientId) { - Username = username; Authority = authority; - AccountId = new AccountId(homeAccountId); + AccountId = BuildAccountIdFromString(homeAccountId); TenantId = tenantId; ClientId = clientId; } @@ -176,7 +175,7 @@ private static async Task DeserializeAsync(Stream stream, authProfile.Authority = prop.Value.GetString(); break; case HomeAccountIdPropertyName: - authProfile.AccountId = new AccountId(prop.Value.GetString()); + authProfile.AccountId = BuildAccountIdFromString(prop.Value.GetString()); break; case TenantIdPropertyName: authProfile.TenantId = prop.Value.GetString(); @@ -189,5 +188,22 @@ private static async Task DeserializeAsync(Stream stream, return authProfile; } + + private static AccountId BuildAccountIdFromString(string homeAccountId) + { + //For the Microsoft identity platform (formerly named Azure AD v2.0), the identifier is the concatenation of + // Microsoft.Identity.Client.AccountId.ObjectId and Microsoft.Identity.Client.AccountId.TenantId separated by a dot. + var homeAccountSegments = homeAccountId.Split('.'); + AccountId accountId; + if (homeAccountSegments.Length == 2) + { + accountId = new AccountId(homeAccountId, homeAccountSegments[0], homeAccountSegments[1]); + } + else + { + accountId = new AccountId(homeAccountId); + } + return accountId; + } } } diff --git a/sdk/identity/Azure.Identity/tests/AuthenticationRecordTests.cs b/sdk/identity/Azure.Identity/tests/AuthenticationRecordTests.cs index 283d9e38998dd..cb19019aab161 100644 --- a/sdk/identity/Azure.Identity/tests/AuthenticationRecordTests.cs +++ b/sdk/identity/Azure.Identity/tests/AuthenticationRecordTests.cs @@ -6,6 +6,9 @@ using System.Text; using System.Threading; using System.Threading.Tasks; + +using Microsoft.Identity.Client; + using NUnit.Framework; namespace Azure.Identity.Tests @@ -14,6 +17,20 @@ public class AuthenticationRecordTests { private const int TestBufferSize = 512; + [Test] + public void AuthenticationRecordConstructor() + { + var record = new AuthenticationRecord(Guid.NewGuid().ToString(), Guid.NewGuid().ToString(), + $"{Guid.NewGuid()}.{Guid.NewGuid()}", Guid.NewGuid().ToString(), Guid.NewGuid().ToString()); + + IAccount account = (AuthenticationAccount)record; + Assert.NotNull(account.Username); + Assert.NotNull(account.Environment); + Assert.NotNull(account.HomeAccountId.Identifier); + Assert.NotNull(account.HomeAccountId.ObjectId); + Assert.NotNull(account.HomeAccountId.TenantId); + } + [Test] public void SerializeDeserializeInputChecks() { @@ -28,23 +45,31 @@ public void SerializeDeserializeInputChecks() [Test] public async Task SerializeDeserializeAsync() { - var expRecord = new AuthenticationRecord(Guid.NewGuid().ToString(), Guid.NewGuid().ToString(), Guid.NewGuid().ToString(), Guid.NewGuid().ToString(), Guid.NewGuid().ToString()); + var expRecord = new AuthenticationRecord(Guid.NewGuid().ToString(), Guid.NewGuid().ToString(), $"{Guid.NewGuid()}.{Guid.NewGuid()}", Guid.NewGuid().ToString(), Guid.NewGuid().ToString()); byte[] buff = new byte[TestBufferSize]; var stream = new MemoryStream(buff); await expRecord.SerializeAsync(stream); + IAccount expAccount = (AuthenticationAccount)expRecord; stream = new MemoryStream(buff, 0, (int)stream.Position); var actRecord = await AuthenticationRecord.DeserializeAsync(stream); + IAccount actAccount = (AuthenticationAccount)actRecord; Assert.AreEqual(expRecord.Username, actRecord.Username); Assert.AreEqual(expRecord.Authority, actRecord.Authority); Assert.AreEqual(expRecord.HomeAccountId, actRecord.HomeAccountId); Assert.AreEqual(expRecord.TenantId, actRecord.TenantId); Assert.AreEqual(expRecord.ClientId, actRecord.ClientId); + + Assert.AreEqual(expAccount.Username, actAccount.Username); + Assert.AreEqual(expAccount.Environment, actAccount.Environment); + Assert.AreEqual(expAccount.HomeAccountId.Identifier, actAccount.HomeAccountId.Identifier); + Assert.AreEqual(expAccount.HomeAccountId.ObjectId, actAccount.HomeAccountId.ObjectId); + Assert.AreEqual(expAccount.HomeAccountId.TenantId, actAccount.HomeAccountId.TenantId); } [Test]