diff --git a/DllImportGenerator/DllImportGenerator.IntegrationTests/DllImportGenerator.IntegrationTests.csproj b/DllImportGenerator/DllImportGenerator.IntegrationTests/DllImportGenerator.IntegrationTests.csproj index 6b3469c88b3..092c17338e2 100644 --- a/DllImportGenerator/DllImportGenerator.IntegrationTests/DllImportGenerator.IntegrationTests.csproj +++ b/DllImportGenerator/DllImportGenerator.IntegrationTests/DllImportGenerator.IntegrationTests.csproj @@ -1,4 +1,5 @@  + net5.0 diff --git a/DllImportGenerator/DllImportGenerator.UnitTests/Compiles.cs b/DllImportGenerator/DllImportGenerator.UnitTests/Compiles.cs index c44f97ea528..93210466493 100644 --- a/DllImportGenerator/DllImportGenerator.UnitTests/Compiles.cs +++ b/DllImportGenerator/DllImportGenerator.UnitTests/Compiles.cs @@ -175,5 +175,68 @@ public async Task ValidateSnippets(string source) var newCompDiags = newComp.GetDiagnostics(); Assert.Empty(newCompDiags); } + + public static IEnumerable CodeSnippetsToCompileWithForwarder() + { + yield return new[] { CodeSnippets.UserDefinedEntryPoint }; + yield return new[] { CodeSnippets.AllSupportedDllImportNamedArguments }; + + // Parameter / return types (supported in DllImportGenerator) + yield return new[] { CodeSnippets.BasicParametersAndModifiers() }; + // Parameter / return types (not supported in DllImportGenerator) + yield return new[] { CodeSnippets.BasicParametersAndModifiers() }; + } + + [Theory] + [MemberData(nameof(CodeSnippetsToCompileWithForwarder))] + public async Task ValidateSnippetsWithForwarder(string source) + { + Compilation comp = await TestUtils.CreateCompilation(source); + TestUtils.AssertPreSourceGeneratorCompilation(comp); + + var newComp = TestUtils.RunGenerators( + comp, + new DllImportGeneratorOptionsProvider(useMarshalType: false, generateForwarders: true), + out var generatorDiags, + new Microsoft.Interop.DllImportGenerator()); + + Assert.Empty(generatorDiags); + + var newCompDiags = newComp.GetDiagnostics(); + Assert.Empty(newCompDiags); + } + + public static IEnumerable CodeSnippetsToCompileWithMarshalType() + { + // SetLastError + yield return new[] { CodeSnippets.AllSupportedDllImportNamedArguments }; + + // SafeHandle + yield return new[] { CodeSnippets.BasicParametersAndModifiers("Microsoft.Win32.SafeHandles.SafeFileHandle") }; + } + + [Theory] + [MemberData(nameof(CodeSnippetsToCompileWithMarshalType))] + public async Task ValidateSnippetsWithMarshalType(string source) + { + Compilation comp = await TestUtils.CreateCompilation(source); + TestUtils.AssertPreSourceGeneratorCompilation(comp); + + var newComp = TestUtils.RunGenerators( + comp, + new DllImportGeneratorOptionsProvider(useMarshalType: true, generateForwarders: false), + out var generatorDiags, + new Microsoft.Interop.DllImportGenerator()); + + Assert.Empty(generatorDiags); + + var newCompDiags = newComp.GetDiagnostics(); + + Assert.All(newCompDiags, diag => + { + Assert.Equal("CS0117", diag.Id); + Assert.StartsWith("'Marshal' does not contain a definition for ", diag.GetMessage()); + }); + } } } diff --git a/DllImportGenerator/DllImportGenerator.UnitTests/DllImportGeneratorOptionsProvider.cs b/DllImportGenerator/DllImportGenerator.UnitTests/DllImportGeneratorOptionsProvider.cs new file mode 100644 index 00000000000..a774bd2ef35 --- /dev/null +++ b/DllImportGenerator/DllImportGenerator.UnitTests/DllImportGeneratorOptionsProvider.cs @@ -0,0 +1,71 @@ +using System.Diagnostics.CodeAnalysis; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.Diagnostics; +using Microsoft.Interop; + +namespace DllImportGenerator.UnitTests +{ + /// + /// An implementation of that provides configuration in code + /// of the options supported by the DllImportGenerator source generator. Used for testing various configurations. + /// + internal class DllImportGeneratorOptionsProvider : AnalyzerConfigOptionsProvider + { + public DllImportGeneratorOptionsProvider(bool useMarshalType, bool generateForwarders) + { + GlobalOptions = new GlobalGeneratorOptions(useMarshalType, generateForwarders); + } + + public override AnalyzerConfigOptions GlobalOptions { get; } + + public override AnalyzerConfigOptions GetOptions(SyntaxTree tree) + { + return EmptyOptions.Instance; + } + + public override AnalyzerConfigOptions GetOptions(AdditionalText textFile) + { + return EmptyOptions.Instance; + } + + private class GlobalGeneratorOptions : AnalyzerConfigOptions + { + private readonly bool _useMarshalType = false; + private readonly bool _generateForwarders = false; + public GlobalGeneratorOptions(bool useMarshalType, bool generateForwarders) + { + _useMarshalType = useMarshalType; + _generateForwarders = generateForwarders; + } + + public override bool TryGetValue(string key, [NotNullWhen(true)] out string? value) + { + switch (key) + { + case OptionsHelper.UseMarshalTypeOption: + value = _useMarshalType.ToString(); + return true; + + case OptionsHelper.GenerateForwardersOption: + value = _generateForwarders.ToString(); + return true; + + default: + value = null; + return false; + } + } + } + + private class EmptyOptions : AnalyzerConfigOptions + { + public override bool TryGetValue(string key, [NotNullWhen(true)] out string? value) + { + value = null; + return false; + } + + public static AnalyzerConfigOptions Instance = new EmptyOptions(); + } + } +} \ No newline at end of file diff --git a/DllImportGenerator/DllImportGenerator.UnitTests/TestUtils.cs b/DllImportGenerator/DllImportGenerator.UnitTests/TestUtils.cs index 80b39053c3f..a57aca5cf62 100644 --- a/DllImportGenerator/DllImportGenerator.UnitTests/TestUtils.cs +++ b/DllImportGenerator/DllImportGenerator.UnitTests/TestUtils.cs @@ -1,6 +1,8 @@ using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.Diagnostics; using Microsoft.CodeAnalysis.Testing; +using System.Collections.Generic; using System.Collections.Immutable; using System.Linq; using System.Reflection; @@ -20,15 +22,18 @@ internal static class TestUtils /// public static void AssertPreSourceGeneratorCompilation(Compilation comp) { + var allowedDiagnostics = new HashSet() + { + "CS8795", // Partial method impl missing + "CS0234", // Missing type or namespace - GeneratedDllImportAttribute + "CS0246", // Missing type or namespace - GeneratedDllImportAttribute + "CS8019", // Unnecessary using + }; var compDiags = comp.GetDiagnostics(); - foreach (var diag in compDiags) + Assert.All(compDiags, diag => { - Assert.True( - "CS8795".Equals(diag.Id) // Partial method impl missing - || "CS0234".Equals(diag.Id) // Missing type or namespace - GeneratedDllImportAttribute - || "CS0246".Equals(diag.Id) // Missing type or namespace - GeneratedDllImportAttribute - || "CS8019".Equals(diag.Id)); // Unnecessary using - } + Assert.Subset(allowedDiagnostics, new HashSet { diag.Id }); + }); } /// @@ -85,13 +90,27 @@ public static (ReferenceAssemblies, MetadataReference) GetReferenceAssemblies() /// The resulting compilation public static Compilation RunGenerators(Compilation comp, out ImmutableArray diagnostics, params ISourceGenerator[] generators) { - CreateDriver(comp, generators).RunGeneratorsAndUpdateCompilation(comp, out var d, out diagnostics); + CreateDriver(comp, null, generators).RunGeneratorsAndUpdateCompilation(comp, out var d, out diagnostics); + return d; + } + + /// + /// Run the supplied generators on the compilation. + /// + /// Compilation target + /// Resulting diagnostics + /// Source generator instances + /// The resulting compilation + public static Compilation RunGenerators(Compilation comp, AnalyzerConfigOptionsProvider options, out ImmutableArray diagnostics, params ISourceGenerator[] generators) + { + CreateDriver(comp, options, generators).RunGeneratorsAndUpdateCompilation(comp, out var d, out diagnostics); return d; } - private static GeneratorDriver CreateDriver(Compilation c, params ISourceGenerator[] generators) + private static GeneratorDriver CreateDriver(Compilation c, AnalyzerConfigOptionsProvider? options, ISourceGenerator[] generators) => CSharpGeneratorDriver.Create( ImmutableArray.Create(generators), - parseOptions: (CSharpParseOptions)c.SyntaxTrees.First().Options); + parseOptions: (CSharpParseOptions)c.SyntaxTrees.First().Options, + optionsProvider: options); } } diff --git a/DllImportGenerator/DllImportGenerator/DllImportGenerator.cs b/DllImportGenerator/DllImportGenerator/DllImportGenerator.cs index 86bb2467ab1..1e012f8b69f 100644 --- a/DllImportGenerator/DllImportGenerator/DllImportGenerator.cs +++ b/DllImportGenerator/DllImportGenerator/DllImportGenerator.cs @@ -58,7 +58,7 @@ public void Execute(GeneratorExecutionContext context) generatorDiagnostics.ReportTargetFrameworkNotSupported(MinimumSupportedFrameworkVersion); } - var env = new StubEnvironment(context.Compilation, isSupported, targetFrameworkVersion); + var env = new StubEnvironment(context.Compilation, isSupported, targetFrameworkVersion, context.AnalyzerConfigOptions.GlobalOptions); var generatedDllImports = new StringBuilder(); foreach (SyntaxReference synRef in synRec.Methods) { @@ -94,7 +94,7 @@ public void Execute(GeneratorExecutionContext context) // Process the GeneratedDllImport attribute DllImportStub.GeneratedDllImportData dllImportData; - AttributeSyntax dllImportAttr = this.ProcessGeneratedDllImportAttribute(methodSymbolInfo, generatedDllImportAttr, out dllImportData); + AttributeSyntax dllImportAttr = this.ProcessGeneratedDllImportAttribute(methodSymbolInfo, generatedDllImportAttr, context.AnalyzerConfigOptions.GlobalOptions.GenerateForwarders(), out dllImportData); Debug.Assert((dllImportAttr is not null) && (dllImportData is not null)); if (dllImportData!.IsUserDefined.HasFlag(DllImportStub.DllImportMember.BestFitMapping)) @@ -213,6 +213,7 @@ private static bool IsGeneratedDllImportAttribute(AttributeSyntax attrSyntaxMayb private AttributeSyntax ProcessGeneratedDllImportAttribute( IMethodSymbol method, AttributeData attrData, + bool generateForwarders, out DllImportStub.GeneratedDllImportData dllImportData) { dllImportData = new DllImportStub.GeneratedDllImportData(); @@ -287,7 +288,9 @@ private AttributeSyntax ProcessGeneratedDllImportAttribute( Debug.Assert(expSyntaxMaybe is not null); - if (PassThroughToDllImportAttribute(namedArg.Key)) + // If we're generating a forwarder stub, then all parameters on the GenerateDllImport attribute + // must also be added to the generated DllImport attribute. + if (generateForwarders || PassThroughToDllImportAttribute(namedArg.Key)) { // Defer the name equals syntax till we know the value means something. If we created // an expression we know the key value was valid. @@ -340,9 +343,6 @@ static ExpressionSyntax CreateEnumExpressionSyntax(T value) where T : Enum static bool PassThroughToDllImportAttribute(string argName) { -#if GENERATE_FORWARDER - return true; -#else // Certain fields on DllImport will prevent inlining. Their functionality should be handled by the // generated source, so the generated DllImport declaration should not include these fields. return argName switch @@ -355,7 +355,6 @@ static bool PassThroughToDllImportAttribute(string argName) nameof(DllImportStub.GeneratedDllImportData.SetLastError) => false, _ => true }; -#endif } } diff --git a/DllImportGenerator/DllImportGenerator/DllImportGenerator.csproj b/DllImportGenerator/DllImportGenerator/DllImportGenerator.csproj index c27c6653a73..108713c4edd 100644 --- a/DllImportGenerator/DllImportGenerator/DllImportGenerator.csproj +++ b/DllImportGenerator/DllImportGenerator/DllImportGenerator.csproj @@ -10,9 +10,6 @@ Preview enable Microsoft.Interop - - - @@ -42,6 +39,7 @@ + diff --git a/DllImportGenerator/DllImportGenerator/DllImportStub.cs b/DllImportGenerator/DllImportGenerator/DllImportStub.cs index d51fe9ff6e0..9fe488f7dae 100644 --- a/DllImportGenerator/DllImportGenerator/DllImportStub.cs +++ b/DllImportGenerator/DllImportGenerator/DllImportStub.cs @@ -6,6 +6,7 @@ using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; +using Microsoft.CodeAnalysis.Diagnostics; using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; namespace Microsoft.Interop @@ -13,7 +14,8 @@ namespace Microsoft.Interop internal record StubEnvironment( Compilation Compilation, bool SupportedTargetFramework, - Version TargetFrameworkVersion); + Version TargetFrameworkVersion, + AnalyzerConfigOptions Options); internal class DllImportStub { @@ -177,7 +179,9 @@ public static DllImportStub Create( }; var managedRetTypeInfo = retTypeInfo; - if (!dllImportData.PreserveSig) + // Do not manually handle PreserveSig when generating forwarders. + // We want the runtime to handle everything. + if (!dllImportData.PreserveSig && !env.Options.GenerateForwarders()) { // Create type info for native HRESULT return retTypeInfo = TypePositionInfo.CreateForType(env.Compilation.GetSpecialType(SpecialType.System_Int32), NoMarshallingInfo.Instance); @@ -203,7 +207,7 @@ public static DllImportStub Create( } // Generate stub code - var stubGenerator = new StubCodeGenerator(method, dllImportData, paramsTypeInfo, retTypeInfo, diagnostics); + var stubGenerator = new StubCodeGenerator(method, dllImportData, paramsTypeInfo, retTypeInfo, diagnostics, env.Options); var (code, dllImport) = stubGenerator.GenerateSyntax(); var additionalAttrs = new List(); diff --git a/DllImportGenerator/DllImportGenerator/Marshalling/MarshallingGenerator.cs b/DllImportGenerator/DllImportGenerator/Marshalling/MarshallingGenerator.cs index d2dd817bd3e..16428f1a482 100644 --- a/DllImportGenerator/DllImportGenerator/Marshalling/MarshallingGenerator.cs +++ b/DllImportGenerator/DllImportGenerator/Marshalling/MarshallingGenerator.cs @@ -4,6 +4,7 @@ using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; +using Microsoft.CodeAnalysis.Diagnostics; using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; namespace Microsoft.Interop @@ -109,7 +110,6 @@ internal class MarshallingGenerators public static readonly Forwarder Forwarder = new Forwarder(); public static readonly BlittableMarshaller Blittable = new BlittableMarshaller(); public static readonly DelegateMarshaller Delegate = new DelegateMarshaller(); - public static readonly SafeHandleMarshaller SafeHandle = new SafeHandleMarshaller(); public static readonly HResultExceptionMarshaller HResultException = new HResultExceptionMarshaller(); /// @@ -120,11 +120,14 @@ internal class MarshallingGenerators /// A instance. public static IMarshallingGenerator Create( TypePositionInfo info, - StubCodeContext context) + StubCodeContext context, + AnalyzerConfigOptions options) { -#if GENERATE_FORWARDER - return MarshallingGenerators.Forwarder; -#else + if (options.GenerateForwarders()) + { + return MarshallingGenerators.Forwarder; + } + if (info.IsNativeReturnPosition && !info.IsManagedReturnPosition) { // Use marshaller for native HRESULT return / exception throwing @@ -180,7 +183,7 @@ public static IMarshallingGenerator Create( { throw new MarshallingNotSupportedException(info, context); } - return SafeHandle; + return new SafeHandleMarshaller(options); // Marshalling in new model. // Must go before the cases that do not explicitly check for marshalling info to support @@ -204,7 +207,7 @@ public static IMarshallingGenerator Create( return CreateStringMarshaller(info, context); case { ManagedType: IArrayTypeSymbol { IsSZArray: true, ElementType: ITypeSymbol elementType } }: - return CreateArrayMarshaller(info, context, elementType); + return CreateArrayMarshaller(info, context, options, elementType); case { ManagedType: { SpecialType: SpecialType.System_Void } }: return Forwarder; @@ -212,7 +215,6 @@ public static IMarshallingGenerator Create( default: throw new MarshallingNotSupportedException(info, context); } -#endif } private static IMarshallingGenerator CreateCharMarshaller(TypePositionInfo info, StubCodeContext context) @@ -303,7 +305,7 @@ private static IMarshallingGenerator CreateStringMarshaller(TypePositionInfo inf throw new MarshallingNotSupportedException(info, context); } - private static ExpressionSyntax GetNumElementsExpressionFromMarshallingInfo(TypePositionInfo info, StubCodeContext context) + private static ExpressionSyntax GetNumElementsExpressionFromMarshallingInfo(TypePositionInfo info, StubCodeContext context, AnalyzerConfigOptions options) { ExpressionSyntax numElementsExpression; if (info.MarshallingAttributeInfo is not ArrayMarshalAsInfo marshalAsInfo) @@ -338,7 +340,7 @@ private static ExpressionSyntax GetNumElementsExpressionFromMarshallingInfo(Type else { var (managed, native) = context.GetIdentifiers(paramIndexInfo); - string identifier = Create(paramIndexInfo, context).UsesNativeIdentifier(paramIndexInfo, context) ? native : managed; + string identifier = Create(paramIndexInfo, context, options).UsesNativeIdentifier(paramIndexInfo, context) ? native : managed; sizeParamIndexExpression = CastExpression( PredefinedType(Token(SyntaxKind.IntKeyword)), IdentifierName(identifier)); @@ -357,7 +359,7 @@ private static ExpressionSyntax GetNumElementsExpressionFromMarshallingInfo(Type return numElementsExpression; } - private static IMarshallingGenerator CreateArrayMarshaller(TypePositionInfo info, StubCodeContext context, ITypeSymbol elementType) + private static IMarshallingGenerator CreateArrayMarshaller(TypePositionInfo info, StubCodeContext context, AnalyzerConfigOptions options, ITypeSymbol elementType) { var elementMarshallingInfo = info.MarshallingAttributeInfo switch { @@ -367,12 +369,15 @@ private static IMarshallingGenerator CreateArrayMarshaller(TypePositionInfo info _ => throw new MarshallingNotSupportedException(info, context) }; - var elementMarshaller = Create(TypePositionInfo.CreateForType(elementType, elementMarshallingInfo), new ArrayMarshallingCodeContext(StubCodeContext.Stage.Setup, string.Empty, context, false)); + var elementMarshaller = Create( + TypePositionInfo.CreateForType(elementType, elementMarshallingInfo), + new ArrayMarshallingCodeContext(StubCodeContext.Stage.Setup, string.Empty, context, false), + options); ExpressionSyntax numElementsExpression = LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(0)); if (info.IsManagedReturnPosition || (info.IsByRef && info.RefKind != RefKind.In)) { // In this case, we need a numElementsExpression supplied from metadata, so we'll calculate it here. - numElementsExpression = GetNumElementsExpressionFromMarshallingInfo(info, context); + numElementsExpression = GetNumElementsExpressionFromMarshallingInfo(info, context, options); } return elementMarshaller == Blittable diff --git a/DllImportGenerator/DllImportGenerator/Marshalling/SafeHandleMarshaller.cs b/DllImportGenerator/DllImportGenerator/Marshalling/SafeHandleMarshaller.cs index 1ecece16c81..76bd0927613 100644 --- a/DllImportGenerator/DllImportGenerator/Marshalling/SafeHandleMarshaller.cs +++ b/DllImportGenerator/DllImportGenerator/Marshalling/SafeHandleMarshaller.cs @@ -3,6 +3,7 @@ using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; +using Microsoft.CodeAnalysis.Diagnostics; using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; namespace Microsoft.Interop @@ -10,6 +11,12 @@ namespace Microsoft.Interop internal class SafeHandleMarshaller : IMarshallingGenerator { private static readonly TypeSyntax NativeType = ParseTypeName("global::System.IntPtr"); + private readonly AnalyzerConfigOptions options; + + public SafeHandleMarshaller(AnalyzerConfigOptions options) + { + this.options = options; + } public TypeSyntax AsNativeType(TypePositionInfo info) { @@ -83,7 +90,7 @@ public IEnumerable Generate(TypePositionInfo info, StubCodeCont IdentifierName(managedIdentifier), InvocationExpression( MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, - ParseName(TypeNames.System_Runtime_InteropServices_MarshalEx), + ParseName(TypeNames.MarshalEx(options)), GenericName(Identifier("CreateSafeHandle"), TypeArgumentList(SingletonSeparatedList(info.ManagedType.AsTypeSyntax())))), ArgumentList()))); @@ -101,7 +108,7 @@ public IEnumerable Generate(TypePositionInfo info, StubCodeCont .WithInitializer(EqualsValueClause( InvocationExpression( MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, - ParseName(TypeNames.System_Runtime_InteropServices_MarshalEx), + ParseName(TypeNames.MarshalEx(options)), GenericName(Identifier("CreateSafeHandle"), TypeArgumentList(SingletonSeparatedList(info.ManagedType.AsTypeSyntax())))), ArgumentList())))))); @@ -160,7 +167,7 @@ public IEnumerable Generate(TypePositionInfo info, StubCodeCont StatementSyntax unmarshalStatement = ExpressionStatement( InvocationExpression( MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, - ParseTypeName(TypeNames.System_Runtime_InteropServices_MarshalEx), + ParseTypeName(TypeNames.MarshalEx(options)), IdentifierName("SetHandle")), ArgumentList(SeparatedList( new [] diff --git a/DllImportGenerator/DllImportGenerator/Microsoft.Interop.DllImportGenerator.props b/DllImportGenerator/DllImportGenerator/Microsoft.Interop.DllImportGenerator.props new file mode 100644 index 00000000000..bf6b2d4a0d0 --- /dev/null +++ b/DllImportGenerator/DllImportGenerator/Microsoft.Interop.DllImportGenerator.props @@ -0,0 +1,19 @@ + + + + + + + + + diff --git a/DllImportGenerator/DllImportGenerator/OptionsHelper.cs b/DllImportGenerator/DllImportGenerator/OptionsHelper.cs new file mode 100644 index 00000000000..39f2fdc80c1 --- /dev/null +++ b/DllImportGenerator/DllImportGenerator/OptionsHelper.cs @@ -0,0 +1,26 @@ +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.Diagnostics; +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Runtime.InteropServices; + +namespace Microsoft.Interop +{ + public static class OptionsHelper + { + public const string UseMarshalTypeOption = "build_property.DllImportGenerator_UseMarshalType"; + public const string GenerateForwardersOption = "build_property.DllImportGenerator_GenerateForwarders"; + + private static bool GetBoolOption(this AnalyzerConfigOptions options, string key) + { + return options.TryGetValue(key, out string? value) + && bool.TryParse(value, out bool result) + && result; + } + + internal static bool UseMarshalType(this AnalyzerConfigOptions options) => options.GetBoolOption(UseMarshalTypeOption); + + internal static bool GenerateForwarders(this AnalyzerConfigOptions options) => options.GetBoolOption(GenerateForwardersOption); + } +} diff --git a/DllImportGenerator/DllImportGenerator/StubCodeGenerator.cs b/DllImportGenerator/DllImportGenerator/StubCodeGenerator.cs index e11337b791e..4a5f15ede8d 100644 --- a/DllImportGenerator/DllImportGenerator/StubCodeGenerator.cs +++ b/DllImportGenerator/DllImportGenerator/StubCodeGenerator.cs @@ -6,6 +6,7 @@ using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; +using Microsoft.CodeAnalysis.Diagnostics; using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; namespace Microsoft.Interop @@ -48,7 +49,7 @@ internal sealed class StubCodeGenerator : StubCodeContext }; private readonly GeneratorDiagnostics diagnostics; - + private readonly AnalyzerConfigOptions options; private readonly IMethodSymbol stubMethod; private readonly DllImportStub.GeneratedDllImportData dllImportData; private readonly IEnumerable paramsTypeInfo; @@ -60,7 +61,8 @@ public StubCodeGenerator( DllImportStub.GeneratedDllImportData dllImportData, IEnumerable paramsTypeInfo, TypePositionInfo retTypeInfo, - GeneratorDiagnostics generatorDiagnostics) + GeneratorDiagnostics generatorDiagnostics, + AnalyzerConfigOptions options) { Debug.Assert(retTypeInfo.IsNativeReturnPosition); @@ -68,6 +70,7 @@ public StubCodeGenerator( this.dllImportData = dllImportData; this.paramsTypeInfo = paramsTypeInfo.ToList(); this.diagnostics = generatorDiagnostics; + this.options = options; // Get marshallers for parameters this.paramMarshallers = paramsTypeInfo.Select(p => CreateGenerator(p)).ToList(); @@ -79,7 +82,7 @@ public StubCodeGenerator( { try { - return (p, MarshallingGenerators.Create(p, this)); + return (p, MarshallingGenerators.Create(p, this, options)); } catch (MarshallingNotSupportedException e) { @@ -174,7 +177,9 @@ public override (string managed, string native) GetIdentifiers(TypePositionInfo AppendVariableDeclations(setupStatements, retMarshaller.TypeInfo, retMarshaller.Generator); } - if (this.dllImportData.SetLastError) + // Do not manually handle SetLastError when generating forwarders. + // We want the runtime to handle everything. + if (this.dllImportData.SetLastError && !options.GenerateForwarders()) { // Declare variable for last error setupStatements.Add(MarshallerHelpers.DeclareWithDefault( @@ -247,14 +252,16 @@ public override (string managed, string native) GetIdentifiers(TypePositionInfo invoke)); } - if (this.dllImportData.SetLastError) + // Do not manually handle SetLastError when generating forwarders. + // We want the runtime to handle everything. + if (this.dllImportData.SetLastError && !options.GenerateForwarders()) { // Marshal.SetLastSystemError(0); var clearLastError = ExpressionStatement( InvocationExpression( MemberAccessExpression( SyntaxKind.SimpleMemberAccessExpression, - ParseName(TypeNames.System_Runtime_InteropServices_MarshalEx), + ParseName(TypeNames.MarshalEx(options)), IdentifierName("SetLastSystemError")), ArgumentList(SingletonSeparatedList( Argument(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(SuccessErrorCode))))))); @@ -267,7 +274,7 @@ public override (string managed, string native) GetIdentifiers(TypePositionInfo InvocationExpression( MemberAccessExpression( SyntaxKind.SimpleMemberAccessExpression, - ParseName(TypeNames.System_Runtime_InteropServices_MarshalEx), + ParseName(TypeNames.MarshalEx(options)), IdentifierName("GetLastSystemError"))))); invokeStatement = Block(clearLastError, invokeStatement, getLastError); @@ -312,14 +319,14 @@ public override (string managed, string native) GetIdentifiers(TypePositionInfo allStatements.AddRange(tryStatements); } - if (this.dllImportData.SetLastError) + if (this.dllImportData.SetLastError && !options.GenerateForwarders()) { // Marshal.SetLastWin32Error(); allStatements.Add(ExpressionStatement( InvocationExpression( MemberAccessExpression( SyntaxKind.SimpleMemberAccessExpression, - ParseName(TypeNames.System_Runtime_InteropServices_MarshalEx), + ParseName(TypeNames.MarshalEx(options)), IdentifierName("SetLastWin32Error")), ArgumentList(SingletonSeparatedList( Argument(IdentifierName(LastErrorIdentifier))))))); diff --git a/DllImportGenerator/DllImportGenerator/TypeNames.cs b/DllImportGenerator/DllImportGenerator/TypeNames.cs index e998b90956c..a504b3b2bf4 100644 --- a/DllImportGenerator/DllImportGenerator/TypeNames.cs +++ b/DllImportGenerator/DllImportGenerator/TypeNames.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.Text; +using Microsoft.CodeAnalysis.Diagnostics; namespace Microsoft.Interop { @@ -27,7 +28,12 @@ static class TypeNames public const string System_Runtime_InteropServices_Marshal = "System.Runtime.InteropServices.Marshal"; - public const string System_Runtime_InteropServices_MarshalEx = "System.Runtime.InteropServices.MarshalEx"; + private const string System_Runtime_InteropServices_MarshalEx = "System.Runtime.InteropServices.MarshalEx"; + + public static string MarshalEx(AnalyzerConfigOptions options) + { + return options.UseMarshalType() ? System_Runtime_InteropServices_Marshal : System_Runtime_InteropServices_MarshalEx; + } public const string System_Runtime_InteropServices_MemoryMarshal = "System.Runtime.InteropServices.MemoryMarshal";