From 5fd32d998950187970181f55967a715b43ef7797 Mon Sep 17 00:00:00 2001 From: Eirik Tsarpalis Date: Wed, 21 Jun 2023 19:39:23 +0100 Subject: [PATCH] Fix support for JsonSerializerContext contained in arbitrary types. (#87829) * Fix support for JsonSerializerContext contained in struct types. * Extend support to other kinds. * Simplify testing * Improve namespace detection logic. --- .../gen/Helpers/RoslynExtensions.cs | 28 ++++ .../System.Text.Json/gen/JsonConstants.cs | 2 - .../gen/JsonSourceGenerator.Emitter.cs | 2 +- .../gen/JsonSourceGenerator.Parser.cs | 126 ++++-------------- .../JsonSourceGeneratorTests.cs | 45 +++++++ 5 files changed, 102 insertions(+), 101 deletions(-) diff --git a/src/libraries/System.Text.Json/gen/Helpers/RoslynExtensions.cs b/src/libraries/System.Text.Json/gen/Helpers/RoslynExtensions.cs index 6dbfe0ca5fe8f..e88278884aacd 100644 --- a/src/libraries/System.Text.Json/gen/Helpers/RoslynExtensions.cs +++ b/src/libraries/System.Text.Json/gen/Helpers/RoslynExtensions.cs @@ -8,6 +8,7 @@ using System.Linq; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; using Microsoft.CodeAnalysis.DotnetRuntime.Extensions; namespace System.Text.Json.SourceGeneration @@ -256,5 +257,32 @@ public static INamedTypeSymbol[] GetSortedTypeHierarchy(this ITypeSymbol type) return JsonHelpers.TraverseGraphWithTopologicalSort(namedType, static t => t.AllInterfaces, SymbolEqualityComparer.Default); } } + + /// + /// Returns the kind keyword corresponding to the specified declaration syntax node. + /// + public static string GetTypeKindKeyword(this TypeDeclarationSyntax typeDeclaration) + { + switch (typeDeclaration.Kind()) + { + case SyntaxKind.ClassDeclaration: + return "class"; + case SyntaxKind.InterfaceDeclaration: + return "interface"; + case SyntaxKind.StructDeclaration: + return "struct"; + case SyntaxKind.RecordDeclaration: + return "record"; + case SyntaxKind.RecordStructDeclaration: + return "record struct"; + case SyntaxKind.EnumDeclaration: + return "enum"; + case SyntaxKind.DelegateDeclaration: + return "delegate"; + default: + Debug.Fail("unexpected syntax kind"); + return null; + } + } } } diff --git a/src/libraries/System.Text.Json/gen/JsonConstants.cs b/src/libraries/System.Text.Json/gen/JsonConstants.cs index 148ffae3c84be..71c121c7c5574 100644 --- a/src/libraries/System.Text.Json/gen/JsonConstants.cs +++ b/src/libraries/System.Text.Json/gen/JsonConstants.cs @@ -5,8 +5,6 @@ namespace System.Text.Json { internal static partial class JsonConstants { - public const string GlobalNamespaceValue = ""; - public const string SystemTextJsonSourceGenerationName = "System.Text.Json.SourceGeneration"; public const string IJsonOnSerializedFullName = "System.Text.Json.Serialization.IJsonOnSerialized"; diff --git a/src/libraries/System.Text.Json/gen/JsonSourceGenerator.Emitter.cs b/src/libraries/System.Text.Json/gen/JsonSourceGenerator.Emitter.cs index 6383b2e3c8daa..634d8f62ffb36 100644 --- a/src/libraries/System.Text.Json/gen/JsonSourceGenerator.Emitter.cs +++ b/src/libraries/System.Text.Json/gen/JsonSourceGenerator.Emitter.cs @@ -141,7 +141,7 @@ private static SourceWriter CreateSourceWriterWithContextHeader(ContextGeneratio """); - if (contextSpec.Namespace != JsonConstants.GlobalNamespaceValue) + if (contextSpec.Namespace != null) { writer.WriteLine($"namespace {contextSpec.Namespace}"); writer.WriteLine('{'); diff --git a/src/libraries/System.Text.Json/gen/JsonSourceGenerator.Parser.cs b/src/libraries/System.Text.Json/gen/JsonSourceGenerator.Parser.cs index 29c20bc47a60f..afd0c12772cc6 100644 --- a/src/libraries/System.Text.Json/gen/JsonSourceGenerator.Parser.cs +++ b/src/libraries/System.Text.Json/gen/JsonSourceGenerator.Parser.cs @@ -81,14 +81,14 @@ public Parser(KnownTypeSymbols knownSymbols) Debug.Assert(_typesToGenerate.Count == 0); Debug.Assert(_generatedTypes.Count == 0); - if (!DerivesFromJsonSerializerContext(contextClassDeclaration, _knownSymbols.JsonSerializerContextType, semanticModel, cancellationToken)) + INamedTypeSymbol? contextTypeSymbol = semanticModel.GetDeclaredSymbol(contextClassDeclaration, cancellationToken); + Debug.Assert(contextTypeSymbol != null); + + if (!_knownSymbols.JsonSerializerContextType.IsAssignableFrom(contextTypeSymbol)) { return null; } - INamedTypeSymbol? contextTypeSymbol = semanticModel.GetDeclaredSymbol(contextClassDeclaration, cancellationToken); - Debug.Assert(contextTypeSymbol != null); - if (!TryParseJsonSerializerContextAttributes( contextTypeSymbol, out List? rootSerializableTypes, @@ -105,7 +105,7 @@ public Parser(KnownTypeSymbols knownSymbols) } Location contextLocation = contextClassDeclaration.GetLocation(); - if (!TryGetClassDeclarationList(contextTypeSymbol, out List? classDeclarationList)) + if (!TryGetNestedTypeDeclarations(contextClassDeclaration, semanticModel, cancellationToken, out List? classDeclarationList)) { // Class or one of its containing types is not partial so we can't add to it. ReportDiagnostic(DiagnosticDescriptors.ContextClassesMustBePartial, contextLocation, contextTypeSymbol.Name); @@ -138,7 +138,7 @@ public Parser(KnownTypeSymbols knownSymbols) { ContextType = new(contextTypeSymbol), GeneratedTypes = _generatedTypes.Values.OrderBy(t => t.TypeRef.FullyQualifiedName).ToImmutableEquatableArray(), - Namespace = contextTypeSymbol.ContainingNamespace.ToDisplayString(), + Namespace = contextTypeSymbol.ContainingNamespace is { IsGlobalNamespace: false } ns ? ns.ToDisplayString() : null, ContextClassDeclarations = classDeclarationList.ToImmutableEquatableArray(), DefaultIgnoreCondition = options.DefaultIgnoreCondition, IgnoreReadOnlyFields = options.IgnoreReadOnlyFields, @@ -154,112 +154,42 @@ public Parser(KnownTypeSymbols knownSymbols) return contextGenSpec; } - // Returns true if a given type derives directly from JsonSerializerContext. - private static bool DerivesFromJsonSerializerContext( - ClassDeclarationSyntax classDeclarationSyntax, - INamedTypeSymbol jsonSerializerContextSymbol, - SemanticModel compilationSemanticModel, - CancellationToken cancellationToken) + private static bool TryGetNestedTypeDeclarations(ClassDeclarationSyntax contextClassSyntax, SemanticModel semanticModel, CancellationToken cancellationToken, [NotNullWhen(true)] out List? typeDeclarations) { - SeparatedSyntaxList? baseTypeSyntaxList = classDeclarationSyntax.BaseList?.Types; - if (baseTypeSyntaxList == null) - { - return false; - } - - INamedTypeSymbol? match = null; + typeDeclarations = null; - foreach (BaseTypeSyntax baseTypeSyntax in baseTypeSyntaxList) + for (TypeDeclarationSyntax? currentType = contextClassSyntax; currentType != null; currentType = currentType.Parent as TypeDeclarationSyntax) { - INamedTypeSymbol? candidate = compilationSemanticModel.GetSymbolInfo(baseTypeSyntax.Type, cancellationToken).Symbol as INamedTypeSymbol; - if (candidate != null && jsonSerializerContextSymbol.Equals(candidate, SymbolEqualityComparer.Default)) + StringBuilder stringBuilder = new(); + bool isPartialType = false; + + foreach (SyntaxToken modifier in currentType.Modifiers) { - match = candidate; - break; + stringBuilder.Append(modifier.Text); + stringBuilder.Append(' '); + isPartialType |= modifier.IsKind(SyntaxKind.PartialKeyword); } - } - return match != null; - } - - private static bool TryGetClassDeclarationList(INamedTypeSymbol typeSymbol, [NotNullWhen(true)] out List? classDeclarationList) - { - INamedTypeSymbol currentSymbol = typeSymbol; - classDeclarationList = null; - - while (currentSymbol != null) - { - ClassDeclarationSyntax? classDeclarationSyntax = currentSymbol.DeclaringSyntaxReferences.First().GetSyntax() as ClassDeclarationSyntax; - - if (classDeclarationSyntax != null) + if (!isPartialType) { - SyntaxTokenList tokenList = classDeclarationSyntax.Modifiers; - int tokenCount = tokenList.Count; - - bool isPartial = false; - - string[] declarationElements = new string[tokenCount + 2]; - - for (int i = 0; i < tokenCount; i++) - { - SyntaxToken token = tokenList[i]; - declarationElements[i] = token.Text; - - if (token.IsKind(SyntaxKind.PartialKeyword)) - { - isPartial = true; - } - } - - if (!isPartial) - { - classDeclarationList = null; - return false; - } - - declarationElements[tokenCount] = "class"; - declarationElements[tokenCount + 1] = GetClassDeclarationName(currentSymbol); - - (classDeclarationList ??= new List()).Add(string.Join(" ", declarationElements)); + typeDeclarations = null; + return false; } - currentSymbol = currentSymbol.ContainingType; - } + stringBuilder.Append(currentType.GetTypeKindKeyword()); + stringBuilder.Append(' '); - Debug.Assert(classDeclarationList?.Count > 0); - return true; - } + INamedTypeSymbol? typeSymbol = semanticModel.GetDeclaredSymbol(currentType, cancellationToken); + Debug.Assert(typeSymbol != null); - private static string GetClassDeclarationName(INamedTypeSymbol typeSymbol) - { - if (typeSymbol.TypeArguments.Length == 0) - { - return typeSymbol.Name; - } + string typeName = typeSymbol.ToDisplayString(SymbolDisplayFormat.MinimallyQualifiedFormat); + stringBuilder.Append(typeName); - StringBuilder sb = new StringBuilder(); - - sb.Append(typeSymbol.Name); - sb.Append('<'); - - bool first = true; - foreach (ITypeSymbol typeArg in typeSymbol.TypeArguments) - { - if (!first) - { - sb.Append(", "); - } - else - { - first = false; - } - - sb.Append(typeArg.Name); + (typeDeclarations ??= new()).Add(stringBuilder.ToString()); } - sb.Append('>'); - - return sb.ToString(); + Debug.Assert(typeDeclarations?.Count > 0); + return true; } private TypeRef EnqueueType(ITypeSymbol type, JsonSourceGenerationMode? generationMode, string? typeInfoPropertyName = null, Location? attributeLocation = null) diff --git a/src/libraries/System.Text.Json/tests/System.Text.Json.SourceGeneration.Unit.Tests/JsonSourceGeneratorTests.cs b/src/libraries/System.Text.Json/tests/System.Text.Json.SourceGeneration.Unit.Tests/JsonSourceGeneratorTests.cs index d8f51dd3cec10..43df7000d1971 100644 --- a/src/libraries/System.Text.Json/tests/System.Text.Json.SourceGeneration.Unit.Tests/JsonSourceGeneratorTests.cs +++ b/src/libraries/System.Text.Json/tests/System.Text.Json.SourceGeneration.Unit.Tests/JsonSourceGeneratorTests.cs @@ -800,5 +800,50 @@ public class NestedGenericClass result.AssertContainsType("global::HelloWorld.MyGenericClass.NestedGenericClass"); result.AssertContainsType("string"); } + + [Theory] + [InlineData("public sealed partial class MySealedClass")] + [InlineData("public partial class MyGenericClass")] + [InlineData("public partial interface IMyInterface")] + [InlineData("public partial interface IMyGenericInterface")] + [InlineData("public partial struct MyStruct")] + [InlineData("public partial struct MyGenericStruct")] + [InlineData("public ref partial struct MyRefStruct")] + [InlineData("public ref partial struct MyGenericRefStruct")] + [InlineData("public readonly partial struct MyReadOnlyStruct")] + [InlineData("public readonly ref partial struct MyReadOnlyRefStruct")] +#if ROSLYN4_0_OR_GREATER && NETCOREAPP + [InlineData("public partial record MyRecord(int x)")] + [InlineData("public partial record struct MyRecordStruct(int x)")] +#endif + public void NestedContextsAreSupported(string containingTypeDeclarationHeader) + { + string source = $$""" + using System.Text.Json.Serialization; + + namespace HelloWorld + { + {{containingTypeDeclarationHeader}} + { + [JsonSerializable(typeof(MyClass))] + internal partial class JsonContext : JsonSerializerContext + { + } + } + + public class MyClass + { + } + } + """; + + Compilation compilation = CompilationHelper.CreateCompilation(source); + + JsonSourceGeneratorResult result = CompilationHelper.RunJsonSourceGenerator(compilation); + + // Make sure compilation was successful. + Assert.Empty(result.NewCompilation.GetDiagnostics()); + Assert.Empty(result.Diagnostics); + } } }