diff --git a/src/Riok.Mapperly/Configuration/AttributeDataAccessor.cs b/src/Riok.Mapperly/Configuration/AttributeDataAccessor.cs index cc781c08f5..797ce49b51 100644 --- a/src/Riok.Mapperly/Configuration/AttributeDataAccessor.cs +++ b/src/Riok.Mapperly/Configuration/AttributeDataAccessor.cs @@ -10,11 +10,11 @@ namespace Riok.Mapperly.Configuration; /// internal class AttributeDataAccessor { - private readonly WellKnownTypes _types; + private readonly SymbolAccessor _symbolAccessor; - public AttributeDataAccessor(WellKnownTypes types) + public AttributeDataAccessor(SymbolAccessor symbolAccessor) { - _types = types; + _symbolAccessor = symbolAccessor; } public T AccessSingle(ISymbol symbol) @@ -42,11 +42,8 @@ public IEnumerable Access(ISymbol symbol) { var attrType = typeof(TAttribute); var dataType = typeof(TData); - var attrSymbol = _types.Get($"{attrType.Namespace}.{attrType.Name}"); - var attrDatas = symbol - .GetAttributes() - .Where(x => SymbolEqualityComparer.Default.Equals(x.AttributeClass?.ConstructedFrom ?? x.AttributeClass, attrSymbol)); + var attrDatas = _symbolAccessor.GetAttributes(symbol); foreach (var attrData in attrDatas) { diff --git a/src/Riok.Mapperly/Configuration/MapperConfiguration.cs b/src/Riok.Mapperly/Configuration/MapperConfiguration.cs index 7f340b45cb..9f220ae1a7 100644 --- a/src/Riok.Mapperly/Configuration/MapperConfiguration.cs +++ b/src/Riok.Mapperly/Configuration/MapperConfiguration.cs @@ -9,9 +9,9 @@ public class MapperConfiguration private readonly MappingConfiguration _defaultConfiguration; private readonly AttributeDataAccessor _dataAccessor; - public MapperConfiguration(WellKnownTypes wellKnownTypes, ISymbol mapperSymbol) + public MapperConfiguration(SymbolAccessor symbolAccessor, ISymbol mapperSymbol) { - _dataAccessor = new AttributeDataAccessor(wellKnownTypes); + _dataAccessor = new AttributeDataAccessor(symbolAccessor); Mapper = _dataAccessor.AccessSingle(mapperSymbol); _defaultConfiguration = new MappingConfiguration( new EnumMappingConfiguration( diff --git a/src/Riok.Mapperly/Descriptors/DescriptorBuilder.cs b/src/Riok.Mapperly/Descriptors/DescriptorBuilder.cs index 423866fb9d..3a1bd8b6b3 100644 --- a/src/Riok.Mapperly/Descriptors/DescriptorBuilder.cs +++ b/src/Riok.Mapperly/Descriptors/DescriptorBuilder.cs @@ -12,6 +12,7 @@ namespace Riok.Mapperly.Descriptors; public class DescriptorBuilder { private readonly MapperDescriptor _mapperDescriptor; + private readonly SymbolAccessor _symbolAccessor; private readonly MappingCollection _mappings = new(); private readonly MethodNameBuilder _methodNameBuilder = new(); @@ -25,15 +26,18 @@ public DescriptorBuilder( Compilation compilation, ClassDeclarationSyntax mapperSyntax, INamedTypeSymbol mapperSymbol, - WellKnownTypes wellKnownTypes + WellKnownTypes wellKnownTypes, + SymbolAccessor symbolAccessor ) { _mapperDescriptor = new MapperDescriptor(mapperSyntax, mapperSymbol, _methodNameBuilder); + _symbolAccessor = symbolAccessor; _mappingBodyBuilder = new MappingBodyBuilder(_mappings); _builderContext = new SimpleMappingBuilderContext( compilation, - new MapperConfiguration(wellKnownTypes, mapperSymbol), + new MapperConfiguration(symbolAccessor, mapperSymbol), wellKnownTypes, + _symbolAccessor, _mapperDescriptor, sourceContext, new MappingBuilder(_mappings), @@ -77,7 +81,7 @@ private void ExtractUserMappings() private void ReserveMethodNames() { - foreach (var methodSymbol in _mapperDescriptor.Symbol.GetAllMembers()) + foreach (var methodSymbol in _symbolAccessor.GetAllMembers(_mapperDescriptor.Symbol)) { _methodNameBuilder.Reserve(methodSymbol.Name); } diff --git a/src/Riok.Mapperly/Descriptors/Enumerables/CollectionInfoBuilder.cs b/src/Riok.Mapperly/Descriptors/Enumerables/CollectionInfoBuilder.cs index 5b6e8f25f2..8a4630d601 100644 --- a/src/Riok.Mapperly/Descriptors/Enumerables/CollectionInfoBuilder.cs +++ b/src/Riok.Mapperly/Descriptors/Enumerables/CollectionInfoBuilder.cs @@ -63,7 +63,12 @@ private readonly record struct CollectionTypeInfo( new CollectionTypeInfo(CollectionType.ReadOnlyMemory, typeof(ReadOnlyMemory<>)), }; - public static CollectionInfos? Build(WellKnownTypes wellKnownTypes, ITypeSymbol source, ITypeSymbol target) + public static CollectionInfos? Build( + WellKnownTypes wellKnownTypes, + SymbolAccessor symbolAccessor, + ITypeSymbol source, + ITypeSymbol target + ) { // check for enumerated type to quickly check that both are collection types var enumeratedSourceType = GetEnumeratedType(wellKnownTypes, source); @@ -74,13 +79,18 @@ private readonly record struct CollectionTypeInfo( if (enumeratedTargetType == null) return null; - var sourceInfo = BuildCollectionInfo(wellKnownTypes, source, enumeratedSourceType); - var targetInfo = BuildCollectionInfo(wellKnownTypes, target, enumeratedTargetType); + var sourceInfo = BuildCollectionInfo(wellKnownTypes, symbolAccessor, source, enumeratedSourceType); + var targetInfo = BuildCollectionInfo(wellKnownTypes, symbolAccessor, target, enumeratedTargetType); return new CollectionInfos(sourceInfo, targetInfo); } - private static CollectionInfo BuildCollectionInfo(WellKnownTypes wellKnownTypes, ITypeSymbol type, ITypeSymbol enumeratedType) + private static CollectionInfo BuildCollectionInfo( + WellKnownTypes wellKnownTypes, + SymbolAccessor symbolAccessor, + ITypeSymbol type, + ITypeSymbol enumeratedType + ) { var collectionTypeInfo = GetCollectionTypeInfo(wellKnownTypes, type); var typeInfo = collectionTypeInfo?.CollectionType ?? CollectionType.None; @@ -90,7 +100,7 @@ private static CollectionInfo BuildCollectionInfo(WellKnownTypes wellKnownTypes, typeInfo, GetImplementedCollectionTypes(wellKnownTypes, type, typeInfo), enumeratedType, - FindCountProperty(wellKnownTypes, type, typeInfo), + FindCountProperty(wellKnownTypes, symbolAccessor, type, typeInfo), HasValidAddMethod(wellKnownTypes, type, typeInfo), collectionTypeInfo?.Immutable == true ); @@ -139,7 +149,7 @@ or CollectionType.SortedSet || t.HasImplicitGenericImplementation(types.Get(typeof(ISet<>)), nameof(ISet.Add)); } - private static string? FindCountProperty(WellKnownTypes types, ITypeSymbol t, CollectionType typeInfo) + private static string? FindCountProperty(WellKnownTypes types, SymbolAccessor symbolAccessor, ITypeSymbol t, CollectionType typeInfo) { if (typeInfo is CollectionType.IEnumerable) return null; @@ -158,7 +168,8 @@ or CollectionType.ReadOnlyMemory return "Count"; var intType = types.Get(); - var member = t.GetAccessibleMappableMembers() + var member = symbolAccessor + .GetAllAccessibleMappableMembers(t) .FirstOrDefault( x => x.Name is nameof(ICollection.Count) or nameof(Array.Length) diff --git a/src/Riok.Mapperly/Descriptors/Enumerables/EnsureCapacity/EnsureCapacityBuilder.cs b/src/Riok.Mapperly/Descriptors/Enumerables/EnsureCapacity/EnsureCapacityBuilder.cs index a645fb6eab..7fb78d7510 100644 --- a/src/Riok.Mapperly/Descriptors/Enumerables/EnsureCapacity/EnsureCapacityBuilder.cs +++ b/src/Riok.Mapperly/Descriptors/Enumerables/EnsureCapacity/EnsureCapacityBuilder.cs @@ -1,6 +1,5 @@ using System.Diagnostics.CodeAnalysis; using Microsoft.CodeAnalysis; -using Riok.Mapperly.Helpers; namespace Riok.Mapperly.Descriptors.Enumerables.EnsureCapacity; @@ -19,8 +18,8 @@ public static class EnsureCapacityBuilder if (ctx.CollectionInfos == null) return null; - var capacityMethod = ctx.Target - .GetAllMethods(EnsureCapacityName) + var capacityMethod = ctx.SymbolAccessor + .GetAllMethods(ctx.Target, EnsureCapacityName) .FirstOrDefault(x => x.Parameters.Length == 1 && x.Parameters[0].Type.SpecialType == SpecialType.System_Int32 && !x.IsStatic); // if EnsureCapacity is not available then return null diff --git a/src/Riok.Mapperly/Descriptors/MappingBodyBuilders/BuilderContext/MembersMappingBuilderContext.cs b/src/Riok.Mapperly/Descriptors/MappingBodyBuilders/BuilderContext/MembersMappingBuilderContext.cs index c1b406e575..cb09e44455 100644 --- a/src/Riok.Mapperly/Descriptors/MappingBodyBuilders/BuilderContext/MembersMappingBuilderContext.cs +++ b/src/Riok.Mapperly/Descriptors/MappingBodyBuilders/BuilderContext/MembersMappingBuilderContext.cs @@ -68,12 +68,14 @@ private HashSet InitIgnoredUnmatchedProperties(IEnumerable allPr private HashSet GetSourceMemberNames() { - return Mapping.SourceType.GetAccessibleMappableMembers().Select(x => x.Name).ToHashSet(); + return BuilderContext.SymbolAccessor.GetAllAccessibleMappableMembers(Mapping.SourceType).Select(x => x.Name).ToHashSet(); } private Dictionary GetTargetMembers() { - return Mapping.TargetType.GetAccessibleMappableMembers().ToDictionary(x => x.Name, StringComparer.OrdinalIgnoreCase); + return BuilderContext.SymbolAccessor + .GetAllAccessibleMappableMembers(Mapping.TargetType) + .ToDictionary(x => x.Name, StringComparer.OrdinalIgnoreCase); } private Dictionary> GetMemberConfigurations() diff --git a/src/Riok.Mapperly/Descriptors/MappingBodyBuilders/NewInstanceObjectMemberMappingBodyBuilder.cs b/src/Riok.Mapperly/Descriptors/MappingBodyBuilders/NewInstanceObjectMemberMappingBodyBuilder.cs index 2aef6393c5..3159e5f03e 100644 --- a/src/Riok.Mapperly/Descriptors/MappingBodyBuilders/NewInstanceObjectMemberMappingBodyBuilder.cs +++ b/src/Riok.Mapperly/Descriptors/MappingBodyBuilders/NewInstanceObjectMemberMappingBodyBuilder.cs @@ -52,6 +52,7 @@ private static void BuildInitOnlyMemberMappings(INewInstanceBuilderContext memberConfigs return; } - if (!MemberPath.TryFind(ctx.Mapping.SourceType, memberConfig.Source.Path, out var sourceMemberPath)) + if ( + !MemberPath.TryFind( + ctx.Mapping.SourceType, + memberConfig.Source.Path, + ctx.BuilderContext.SymbolAccessor, + out var sourceMemberPath + ) + ) { ctx.BuilderContext.ReportDiagnostic( DiagnosticDescriptors.SourceMemberNotFound, @@ -182,15 +190,15 @@ private static void BuildConstructorMapping(INewInstanceBuilderContext // ctors annotated with [Obsolete] are considered last unless they have a MapperConstructor attribute set var ctorCandidates = namedTargetType.InstanceConstructors .Where(ctor => ctor.IsAccessible()) - .OrderByDescending(x => x.HasAttribute(ctx.BuilderContext.Types.Get())) - .ThenBy(x => x.HasAttribute(ctx.BuilderContext.Types.Get())) + .OrderByDescending(x => ctx.BuilderContext.SymbolAccessor.HasAttribute(x)) + .ThenBy(x => ctx.BuilderContext.SymbolAccessor.HasAttribute(x)) .ThenByDescending(x => x.Parameters.Length == 0) .ThenByDescending(x => x.Parameters.Length); foreach (var ctorCandidate in ctorCandidates) { if (!TryBuildConstructorMapping(ctx, ctorCandidate, out var mappedTargetMemberNames, out var constructorParameterMappings)) { - if (ctorCandidate.HasAttribute(ctx.BuilderContext.Types.Get())) + if (ctx.BuilderContext.SymbolAccessor.HasAttribute(ctorCandidate)) { ctx.BuilderContext.ReportDiagnostic( DiagnosticDescriptors.CannotMapToConfiguredConstructor, @@ -293,6 +301,7 @@ private static bool TryFindConstructorParameterSourcePath( MemberPathCandidateBuilder.BuildMemberPathCandidates(parameter.Name), ctx.IgnoredSourceMemberNames, StringComparer.OrdinalIgnoreCase, + ctx.BuilderContext.SymbolAccessor, out sourcePath ); } @@ -317,7 +326,7 @@ out sourcePath return false; } - if (!MemberPath.TryFind(ctx.Mapping.SourceType, memberConfig.Source.Path, out sourcePath)) + if (!MemberPath.TryFind(ctx.Mapping.SourceType, memberConfig.Source.Path, ctx.BuilderContext.SymbolAccessor, out sourcePath)) { ctx.BuilderContext.ReportDiagnostic( DiagnosticDescriptors.SourceMemberNotFound, diff --git a/src/Riok.Mapperly/Descriptors/MappingBodyBuilders/ObjectMemberMappingBodyBuilder.cs b/src/Riok.Mapperly/Descriptors/MappingBodyBuilders/ObjectMemberMappingBodyBuilder.cs index 0fd33d0d7f..e21b921a7d 100644 --- a/src/Riok.Mapperly/Descriptors/MappingBodyBuilders/ObjectMemberMappingBodyBuilder.cs +++ b/src/Riok.Mapperly/Descriptors/MappingBodyBuilders/ObjectMemberMappingBodyBuilder.cs @@ -48,6 +48,7 @@ public static void BuildMappingBody(IMembersContainerBuilderContext _collectionInfos ??= CollectionInfoBuilder.Build(Types, Source, Target); + public CollectionInfos? CollectionInfos => _collectionInfos ??= CollectionInfoBuilder.Build(Types, SymbolAccessor, Source, Target); protected IMethodSymbol? UserSymbol { get; } diff --git a/src/Riok.Mapperly/Descriptors/MappingBuilders/DictionaryMappingBuilder.cs b/src/Riok.Mapperly/Descriptors/MappingBuilders/DictionaryMappingBuilder.cs index 8f23016b78..b7f883e628 100644 --- a/src/Riok.Mapperly/Descriptors/MappingBuilders/DictionaryMappingBuilder.cs +++ b/src/Riok.Mapperly/Descriptors/MappingBuilders/DictionaryMappingBuilder.cs @@ -38,8 +38,8 @@ or CollectionType.IDictionary or CollectionType.IReadOnlyDictionary ) { - var sourceHasCount = ctx.Source - .GetAllProperties(CountPropertyName) + var sourceHasCount = ctx.SymbolAccessor + .GetAllProperties(ctx.Source, CountPropertyName) .Any(x => !x.IsStatic && x is { IsIndexer: false, IsWriteOnly: false, Type.SpecialType: SpecialType.System_Int32 }); var targetDictionarySymbol = ctx.Types.Get(typeof(Dictionary<,>)).Construct(keyMapping.TargetType, valueMapping.TargetType); diff --git a/src/Riok.Mapperly/Descriptors/MappingBuilders/EnumMappingBuilder.cs b/src/Riok.Mapperly/Descriptors/MappingBuilders/EnumMappingBuilder.cs index ee8e8203d2..d3763d050c 100644 --- a/src/Riok.Mapperly/Descriptors/MappingBuilders/EnumMappingBuilder.cs +++ b/src/Riok.Mapperly/Descriptors/MappingBuilders/EnumMappingBuilder.cs @@ -62,17 +62,15 @@ private static TypeMapping BuildEnumToEnumCastMapping( { var explicitMappingSourceNames = explicitMappings.Keys.Select(x => x.Name).ToHashSet(); var explicitMappingTargetNames = explicitMappings.Values.Select(x => x.Name).ToHashSet(); - var sourceValues = ctx.Source - .GetMembers() - .OfType() + var sourceValues = ctx.SymbolAccessor + .GetAllFields(ctx.Source) .Where(x => !explicitMappingSourceNames.Contains(x.Name)) .ToDictionary(field => field.Name, field => field.ConstantValue); - var targetValues = ctx.Target - .GetMembers() - .OfType() + var targetValues = ctx.SymbolAccessor + .GetAllFields(ctx.Target) .Where(x => !explicitMappingTargetNames.Contains(x.Name)) .ToDictionary(field => field.Name, field => field.ConstantValue); - var targetMemberNames = ctx.Target.GetMembers().OfType().Select(x => x.Name).ToHashSet(); + var targetMemberNames = ctx.SymbolAccessor.GetAllFields(ctx.Target).Select(x => x.Name).ToHashSet(); var missingTargetValues = targetValues.Where( field => @@ -99,7 +97,7 @@ private static TypeMapping BuildEnumToEnumCastMapping( var checkDefinedMode = checkTargetDefined switch { false => EnumCastMapping.CheckDefinedMode.NoCheck, - _ when ctx.Target.HasAttribute(ctx.Types.Get()) => EnumCastMapping.CheckDefinedMode.Flags, + _ when ctx.SymbolAccessor.HasAttribute(ctx.Target) => EnumCastMapping.CheckDefinedMode.Flags, _ => EnumCastMapping.CheckDefinedMode.Value, }; @@ -124,8 +122,8 @@ IReadOnlyDictionary explicitMappings ) { var fallbackMapping = BuildFallbackMapping(ctx); - var targetFieldsByName = ctx.Target.GetMembers().OfType().ToDictionary(x => x.Name); - var sourceFieldsByName = ctx.Source.GetMembers().OfType().ToDictionary(x => x.Name); + var targetFieldsByName = ctx.SymbolAccessor.GetAllFields(ctx.Target).ToDictionary(x => x.Name); + var sourceFieldsByName = ctx.SymbolAccessor.GetAllFields(ctx.Source).ToDictionary(x => x.Name); Func getTargetField; if (ctx.Configuration.Enum.IgnoreCase) @@ -143,9 +141,8 @@ IReadOnlyDictionary explicitMappings getTargetField = source => explicitMappings.GetValueOrDefault(source) ?? targetFieldsByName.GetValueOrDefault(source.Name); } - var enumMemberMappings = ctx.Source - .GetMembers() - .OfType() + var enumMemberMappings = ctx.SymbolAccessor + .GetAllFields(ctx.Source) .Select(x => (Source: x, Target: getTargetField(x))) .Where(x => x.Target != null) .ToDictionary(x => x.Source.Name, x => x.Target!.Name); diff --git a/src/Riok.Mapperly/Descriptors/MappingBuilders/EnumToStringMappingBuilder.cs b/src/Riok.Mapperly/Descriptors/MappingBuilders/EnumToStringMappingBuilder.cs index c60ea16fb8..146239d801 100644 --- a/src/Riok.Mapperly/Descriptors/MappingBuilders/EnumToStringMappingBuilder.cs +++ b/src/Riok.Mapperly/Descriptors/MappingBuilders/EnumToStringMappingBuilder.cs @@ -18,6 +18,6 @@ public static class EnumToStringMappingBuilder // to string => use an optimized method of Enum.ToString which would use slow reflection // use Enum.ToString as fallback (for ex. for flags) - return new EnumToStringMapping(ctx.Source, ctx.Target, ctx.Source.GetMembers().OfType()); + return new EnumToStringMapping(ctx.Source, ctx.Target, ctx.SymbolAccessor.GetAllFields(ctx.Source)); } } diff --git a/src/Riok.Mapperly/Descriptors/MappingBuilders/ParseMappingBuilder.cs b/src/Riok.Mapperly/Descriptors/MappingBuilders/ParseMappingBuilder.cs index 09b9bf6d5e..616a529664 100644 --- a/src/Riok.Mapperly/Descriptors/MappingBuilders/ParseMappingBuilder.cs +++ b/src/Riok.Mapperly/Descriptors/MappingBuilders/ParseMappingBuilder.cs @@ -19,8 +19,8 @@ public static class ParseMappingBuilder var targetIsNullable = ctx.Target.NonNullable(out var nonNullableTarget); - var parseMethodCandidates = nonNullableTarget - .GetAllMethods(ParseMethodName) + var parseMethodCandidates = ctx.SymbolAccessor + .GetAllMethods(nonNullableTarget, ParseMethodName) .Where( m => m.IsStatic diff --git a/src/Riok.Mapperly/Descriptors/MappingBuilders/StringToEnumMappingBuilder.cs b/src/Riok.Mapperly/Descriptors/MappingBuilders/StringToEnumMappingBuilder.cs index 79ae509326..77c9933468 100644 --- a/src/Riok.Mapperly/Descriptors/MappingBuilders/StringToEnumMappingBuilder.cs +++ b/src/Riok.Mapperly/Descriptors/MappingBuilders/StringToEnumMappingBuilder.cs @@ -37,7 +37,7 @@ public static class StringToEnumMappingBuilder // however we currently don't support all features of Enum.Parse yet (ex. flags) // therefore we use Enum.Parse as fallback. var fallbackMapping = BuildFallbackParseMapping(ctx, genericEnumParseMethodSupported); - var members = ctx.Target.GetFields(); + var members = ctx.SymbolAccessor.GetAllFields(ctx.Target); if (fallbackMapping.FallbackMember != null) { // no need to explicitly map fallback value diff --git a/src/Riok.Mapperly/Descriptors/ObjectFactories/ObjectFactoryBuilder.cs b/src/Riok.Mapperly/Descriptors/ObjectFactories/ObjectFactoryBuilder.cs index f4bd7774c9..95e2120f24 100644 --- a/src/Riok.Mapperly/Descriptors/ObjectFactories/ObjectFactoryBuilder.cs +++ b/src/Riok.Mapperly/Descriptors/ObjectFactories/ObjectFactoryBuilder.cs @@ -12,7 +12,7 @@ public static ObjectFactoryCollection ExtractObjectFactories(SimpleMappingBuilde var objectFactories = mapperSymbol .GetMembers() .OfType() - .Where(m => m.HasAttribute(ctx.Types.Get())) + .Where(m => ctx.SymbolAccessor.HasAttribute(m)) .Select(x => BuildObjectFactory(ctx, x)) .WhereNotNull() .ToList(); diff --git a/src/Riok.Mapperly/Descriptors/SimpleMappingBuilderContext.cs b/src/Riok.Mapperly/Descriptors/SimpleMappingBuilderContext.cs index 7d66a09fbe..17bde5273d 100644 --- a/src/Riok.Mapperly/Descriptors/SimpleMappingBuilderContext.cs +++ b/src/Riok.Mapperly/Descriptors/SimpleMappingBuilderContext.cs @@ -18,6 +18,7 @@ public SimpleMappingBuilderContext( Compilation compilation, MapperConfiguration configuration, WellKnownTypes types, + SymbolAccessor symbolAccessor, MapperDescriptor descriptor, SourceProductionContext context, MappingBuilder mappingBuilder, @@ -26,6 +27,7 @@ ExistingTargetMappingBuilder existingTargetMappingBuilder { Compilation = compilation; Types = types; + SymbolAccessor = symbolAccessor; _configuration = configuration; _descriptor = descriptor; _context = context; @@ -38,6 +40,7 @@ protected SimpleMappingBuilderContext(SimpleMappingBuilderContext ctx) ctx.Compilation, ctx._configuration, ctx.Types, + ctx.SymbolAccessor, ctx._descriptor, ctx._context, ctx.MappingBuilder, @@ -49,6 +52,7 @@ protected SimpleMappingBuilderContext(SimpleMappingBuilderContext ctx) public MapperAttribute MapperConfiguration => _configuration.Mapper; public WellKnownTypes Types { get; } + public SymbolAccessor SymbolAccessor { get; } protected MappingBuilder MappingBuilder { get; } diff --git a/src/Riok.Mapperly/Descriptors/SymbolAccessor.cs b/src/Riok.Mapperly/Descriptors/SymbolAccessor.cs new file mode 100644 index 0000000000..7acc89d93c --- /dev/null +++ b/src/Riok.Mapperly/Descriptors/SymbolAccessor.cs @@ -0,0 +1,109 @@ +using System.Collections.Immutable; +using Microsoft.CodeAnalysis; +using Riok.Mapperly.Helpers; +using Riok.Mapperly.Symbols; + +namespace Riok.Mapperly.Descriptors; + +public class SymbolAccessor +{ + private readonly WellKnownTypes _types; + private readonly Dictionary> _attributes = new(SymbolEqualityComparer.Default); + private readonly Dictionary> _allMembers = new(SymbolEqualityComparer.Default); + private readonly Dictionary> _allAccessibleMembers = + new(SymbolEqualityComparer.Default); + + public SymbolAccessor(WellKnownTypes types) + { + _types = types; + } + + internal IEnumerable GetAttributes(ISymbol symbol) + where T : Attribute + { + var attributeSymbol = _types.Get(); + return GetAttributesCore(symbol) + .Where(x => SymbolEqualityComparer.Default.Equals(x.AttributeClass?.ConstructedFrom ?? x.AttributeClass, attributeSymbol)); + } + + internal bool HasAttribute(ISymbol symbol) + where T : Attribute => GetAttributes(symbol).Any(); + + internal IEnumerable GetAllMethods(ITypeSymbol symbol) => GetAllMembers(symbol).OfType(); + + internal IEnumerable GetAllMethods(ITypeSymbol symbol, string name) => + GetAllMembers(symbol, name).OfType(); + + internal IEnumerable GetAllProperties(ITypeSymbol symbol, string name) => + GetAllMembers(symbol, name).OfType(); + + internal IEnumerable GetAllFields(ITypeSymbol symbol) => GetAllMembers(symbol).OfType(); + + internal IReadOnlyCollection GetAllMembers(ITypeSymbol symbol) + { + if (_allMembers.TryGetValue(symbol, out var members)) + { + return members; + } + + members = GetAllMembersCore(symbol).ToArray(); + _allMembers.Add(symbol, members); + + return members; + } + + internal IReadOnlyCollection GetAllAccessibleMappableMembers(ITypeSymbol symbol) + { + if (_allAccessibleMembers.TryGetValue(symbol, out var members)) + { + return members; + } + + members = GetAllAccessibleMappableMembersCore(symbol).ToArray(); + _allAccessibleMembers.Add(symbol, members); + + return members; + } + + internal IEnumerable GetMappableMembers(ITypeSymbol symbol, string name, IEqualityComparer comparer) + { + return GetAllAccessibleMappableMembers(symbol).Where(x => comparer.Equals(name, x.Name)); + } + + private IEnumerable GetAllMembers(ITypeSymbol symbol, string name) => GetAllMembers(symbol).Where(x => name.Equals(x.Name)); + + private ImmutableArray GetAttributesCore(ISymbol symbol) + { + if (_attributes.TryGetValue(symbol, out var attributes)) + { + return attributes; + } + + attributes = symbol.GetAttributes(); + _attributes.Add(symbol, attributes); + + return attributes; + } + + private IEnumerable GetAllMembersCore(ITypeSymbol symbol) + { + var members = symbol.GetMembers(); + + if (symbol.TypeKind == TypeKind.Interface) + { + var interfaceProperties = symbol.AllInterfaces.SelectMany(GetAllMembers); + return members.Concat(interfaceProperties); + } + + return symbol.BaseType == null ? members : members.Concat(GetAllMembers(symbol.BaseType)); + } + + private IEnumerable GetAllAccessibleMappableMembersCore(ITypeSymbol symbol) + { + return GetAllMembers(symbol) + .Where(x => x is { IsStatic: false, Kind: SymbolKind.Property or SymbolKind.Field } && x.IsAccessible()) + .DistinctBy(x => x.Name) + .Select(MappableMember.Create) + .WhereNotNull(); + } +} diff --git a/src/Riok.Mapperly/Descriptors/UserMethodMappingExtractor.cs b/src/Riok.Mapperly/Descriptors/UserMethodMappingExtractor.cs index baf7c6f923..6a2f08b48f 100644 --- a/src/Riok.Mapperly/Descriptors/UserMethodMappingExtractor.cs +++ b/src/Riok.Mapperly/Descriptors/UserMethodMappingExtractor.cs @@ -29,7 +29,7 @@ public static IEnumerable ExtractUserMappings(SimpleMappingBuilder yield break; // extract user implemented mappings from base methods - foreach (var method in ExtractBaseMethods(ctx.Compilation.ObjectType, mapperSymbol)) + foreach (var method in ExtractBaseMethods(ctx.Compilation.ObjectType, mapperSymbol, ctx.SymbolAccessor)) { // Partial method declarations are allowed for base classes, // but still treated as user implemented methods, @@ -43,10 +43,15 @@ public static IEnumerable ExtractUserMappings(SimpleMappingBuilder private static IEnumerable ExtractMethods(ITypeSymbol mapperSymbol) => mapperSymbol.GetMembers().OfType(); - private static IEnumerable ExtractBaseMethods(INamedTypeSymbol objectType, ITypeSymbol mapperSymbol) + private static IEnumerable ExtractBaseMethods( + INamedTypeSymbol objectType, + ITypeSymbol mapperSymbol, + SymbolAccessor symbolAccessor + ) { - var baseMethods = mapperSymbol.BaseType?.GetAllMethods() ?? Enumerable.Empty(); - var intfMethods = mapperSymbol.AllInterfaces.SelectMany(x => x.GetAllMethods()); + var baseMethods = + mapperSymbol.BaseType != null ? symbolAccessor.GetAllMethods(mapperSymbol.BaseType!) : Enumerable.Empty(); + var intfMethods = mapperSymbol.AllInterfaces.SelectMany(symbolAccessor.GetAllMethods); return baseMethods .Concat(intfMethods) .OfType() @@ -294,7 +299,9 @@ private static bool BuildParameters( private static MethodParameter? BuildReferenceHandlerParameter(SimpleMappingBuilderContext ctx, IMethodSymbol method) { - var refHandlerParameterSymbol = method.Parameters.FirstOrDefault(p => p.HasAttribute(ctx.Types.Get())); + var refHandlerParameterSymbol = method.Parameters.FirstOrDefault( + p => ctx.SymbolAccessor.HasAttribute(p) + ); if (refHandlerParameterSymbol == null) return null; diff --git a/src/Riok.Mapperly/Descriptors/WellKnownTypes.cs b/src/Riok.Mapperly/Descriptors/WellKnownTypes.cs index 8ddf15f8b3..d7ca04dae2 100644 --- a/src/Riok.Mapperly/Descriptors/WellKnownTypes.cs +++ b/src/Riok.Mapperly/Descriptors/WellKnownTypes.cs @@ -21,10 +21,16 @@ internal WellKnownTypes(Compilation compilation) public ITypeSymbol GetArrayType(ITypeSymbol type) => _compilation.CreateArrayTypeSymbol(type, elementNullableAnnotation: type.NullableAnnotation).NonNullable(); - public INamedTypeSymbol Get() => Get(typeof(T).FullName); + public INamedTypeSymbol Get() => Get(typeof(T)); - public INamedTypeSymbol Get(Type type) => - Get(type.FullName ?? throw new InvalidOperationException("Could not get name of type " + type)); + public INamedTypeSymbol Get(Type type) + { + if (type.IsConstructedGenericType) + { + type = type.GetGenericTypeDefinition(); + } + return Get(type.FullName ?? throw new InvalidOperationException("Could not get name of type " + type)); + } public INamedTypeSymbol Get(string typeFullName) => TryGet(typeFullName) ?? throw new InvalidOperationException("Could not get type " + typeFullName); diff --git a/src/Riok.Mapperly/Helpers/SymbolExtensions.cs b/src/Riok.Mapperly/Helpers/SymbolExtensions.cs index 2d5153ed1a..e9a591716f 100644 --- a/src/Riok.Mapperly/Helpers/SymbolExtensions.cs +++ b/src/Riok.Mapperly/Helpers/SymbolExtensions.cs @@ -1,9 +1,7 @@ using System.Collections.Immutable; using System.Diagnostics.CodeAnalysis; - using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; -using Riok.Mapperly.Symbols; namespace Riok.Mapperly.Helpers; @@ -14,9 +12,6 @@ internal static class SymbolExtensions typeof(Version).FullName ); - internal static bool HasAttribute(this ISymbol symbol, INamedTypeSymbol attributeSymbol) => - symbol.GetAttributes().Any(a => SymbolEqualityComparer.Default.Equals(a.AttributeClass, attributeSymbol)); - internal static bool IsImmutable(this ISymbol symbol) => symbol is INamedTypeSymbol namedSymbol && ( @@ -56,48 +51,8 @@ internal static bool TryGetEnumUnderlyingType(this ITypeSymbol t, [NotNullWhen(t return enumType != null; } - internal static IEnumerable GetAllMethods(this ITypeSymbol symbol) => symbol.GetAllMembers().OfType(); - - internal static IEnumerable GetAllMethods(this ITypeSymbol symbol, string name) => - symbol.GetAllMembers(name).OfType(); - - internal static IEnumerable GetAllProperties(this ITypeSymbol symbol, string name) => - symbol.GetAllMembers(name).OfType(); - internal static IEnumerable GetFields(this ITypeSymbol symbol) => symbol.GetMembers().OfType(); - internal static IEnumerable GetAllMembers(this ITypeSymbol symbol) - { - var members = symbol.GetMembers(); - - if (symbol.TypeKind == TypeKind.Interface) - { - var interfaceProperties = symbol.AllInterfaces.SelectMany(i => i.GetAllMembers()); - return members.Concat(interfaceProperties); - } - - return symbol.BaseType == null ? members : members.Concat(symbol.BaseType.GetAllMembers()); - } - - internal static IEnumerable GetMappableMembers( - this ITypeSymbol symbol, - string name, - IEqualityComparer comparer - ) - { - return symbol.GetAllMembers().Where(x => !x.IsStatic && comparer.Equals(name, x.Name)).Select(MappableMember.Create).WhereNotNull(); - } - - internal static IEnumerable GetAccessibleMappableMembers(this ITypeSymbol symbol) - { - return symbol - .GetAllMembers() - .Where(x => !x.IsStatic && x.IsAccessible()) - .DistinctBy(x => x.Name) - .Select(MappableMember.Create) - .WhereNotNull(); - } - internal static IMethodSymbol? GetStaticGenericMethod(this INamedTypeSymbol namedType, string methodName) { return namedType.GetMembers(methodName).OfType().FirstOrDefault(m => m.IsStatic && m.IsGenericMethod); @@ -209,7 +164,4 @@ ITypeSymbol type return true; } - - private static IEnumerable GetAllMembers(this ITypeSymbol symbol, string name) => - symbol.GetAllMembers().Where(x => name.Equals(x.Name)); } diff --git a/src/Riok.Mapperly/MapperGenerator.cs b/src/Riok.Mapperly/MapperGenerator.cs index b848ad3738..61a7dad107 100644 --- a/src/Riok.Mapperly/MapperGenerator.cs +++ b/src/Riok.Mapperly/MapperGenerator.cs @@ -38,6 +38,7 @@ private static void Execute(Compilation compilation, ImmutableArray(mapperSymbol)) continue; - var builder = new DescriptorBuilder(ctx, compilation, mapperSyntax, mapperSymbol, wellKnownTypes); + var builder = new DescriptorBuilder(ctx, compilation, mapperSyntax, mapperSymbol, wellKnownTypes, symbolAccessor); var descriptor = builder.Build(); ctx.AddSource( diff --git a/src/Riok.Mapperly/Symbols/MemberPath.cs b/src/Riok.Mapperly/Symbols/MemberPath.cs index cda2cc1ee0..e67683491f 100644 --- a/src/Riok.Mapperly/Symbols/MemberPath.cs +++ b/src/Riok.Mapperly/Symbols/MemberPath.cs @@ -2,6 +2,7 @@ using System.Diagnostics.CodeAnalysis; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp.Syntax; +using Riok.Mapperly.Descriptors; using Riok.Mapperly.Helpers; using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; using static Riok.Mapperly.Emit.SyntaxFactoryHelper; @@ -179,18 +180,20 @@ public static bool TryFind( ITypeSymbol type, IEnumerable> pathCandidates, IReadOnlyCollection ignoredNames, + SymbolAccessor symbolAccessor, [NotNullWhen(true)] out MemberPath? memberPath - ) => TryFind(type, pathCandidates, ignoredNames, StringComparer.Ordinal, out memberPath); + ) => TryFind(type, pathCandidates, ignoredNames, StringComparer.Ordinal, symbolAccessor, out memberPath); public static bool TryFind( ITypeSymbol type, IEnumerable> pathCandidates, IReadOnlyCollection ignoredNames, IEqualityComparer comparer, + SymbolAccessor symbolAccessor, [NotNullWhen(true)] out MemberPath? memberPath ) { - foreach (var pathCandidate in FindCandidates(type, pathCandidates, comparer)) + foreach (var pathCandidate in FindCandidates(type, pathCandidates, comparer, symbolAccessor)) { if (ignoredNames.Contains(pathCandidate.Path.First().Name)) continue; @@ -203,18 +206,23 @@ public static bool TryFind( return false; } - public static bool TryFind(ITypeSymbol type, IReadOnlyCollection path, [NotNullWhen(true)] out MemberPath? memberPath) => - TryFind(type, path, StringComparer.Ordinal, out memberPath); + public static bool TryFind( + ITypeSymbol type, + IReadOnlyCollection path, + SymbolAccessor symbolAccessor, + [NotNullWhen(true)] out MemberPath? memberPath + ) => TryFind(type, path, StringComparer.Ordinal, symbolAccessor, out memberPath); private static IEnumerable FindCandidates( ITypeSymbol type, IEnumerable> pathCandidates, - IEqualityComparer comparer + IEqualityComparer comparer, + SymbolAccessor symbolAccessor ) { foreach (var pathCandidate in pathCandidates) { - if (TryFind(type, pathCandidate.ToList(), comparer, out var memberPath)) + if (TryFind(type, pathCandidate.ToList(), comparer, symbolAccessor, out var memberPath)) yield return memberPath; } } @@ -223,10 +231,11 @@ private static bool TryFind( ITypeSymbol type, IReadOnlyCollection path, IEqualityComparer comparer, + SymbolAccessor symbolAccessor, [NotNullWhen(true)] out MemberPath? memberPath ) { - var foundPath = Find(type, path, comparer).ToList(); + var foundPath = Find(type, path, comparer, symbolAccessor).ToList(); if (foundPath.Count != path.Count) { memberPath = null; @@ -237,11 +246,16 @@ private static bool TryFind( return true; } - private static IEnumerable Find(ITypeSymbol type, IEnumerable path, IEqualityComparer comparer) + private static IEnumerable Find( + ITypeSymbol type, + IEnumerable path, + IEqualityComparer comparer, + SymbolAccessor symbolAccessor + ) { foreach (var name in path) { - if (FindMember(type, name, comparer) is not { } member) + if (FindMember(type, name, comparer, symbolAccessor) is not { } member) break; type = member.Type; @@ -249,8 +263,13 @@ private static IEnumerable Find(ITypeSymbol type, IEnumerable comparer) + private static IMappableMember? FindMember( + ITypeSymbol type, + string name, + IEqualityComparer comparer, + SymbolAccessor symbolAccessor + ) { - return type.GetMappableMembers(name, comparer).FirstOrDefault(); + return symbolAccessor.GetMappableMembers(type, name, comparer).FirstOrDefault(); } }