Skip to content

Commit

Permalink
Handle SetLastError=true (#360)
Browse files Browse the repository at this point in the history
  • Loading branch information
elinor-fung authored Nov 30, 2020
1 parent 91ff4be commit b36b0b8
Show file tree
Hide file tree
Showing 9 changed files with 307 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
<LangVersion>8.0</LangVersion>
<RootNamespace>System.Runtime.InteropServices</RootNamespace>
<Nullable>enable</Nullable>
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
</PropertyGroup>

</Project>
74 changes: 74 additions & 0 deletions DllImportGenerator/Ancillary.Interop/MarshalEx.cs
Original file line number Diff line number Diff line change
Expand Up @@ -39,5 +39,79 @@ public static void SetHandle(SafeHandle safeHandle, IntPtr handle)
{
typeof(SafeHandle).GetMethod("SetHandle", BindingFlags.NonPublic | BindingFlags.Instance)!.Invoke(safeHandle, new object[] { handle });
}

/// <summary>
/// Set the last platform invoke error on the thread
/// </summary>
public static void SetLastWin32Error(int error)
{
typeof(Marshal).GetMethod("SetLastWin32Error", BindingFlags.NonPublic | BindingFlags.Static)!.Invoke(null, new object[] { error });
}

/// <summary>
/// Get the last system error on the current thread (errno on Unix, GetLastError on Windows)
/// </summary>
public static unsafe int GetLastSystemError()
{
// Would be internal call that handles getting the last error for the thread using the PAL

if (OperatingSystem.IsWindows())
{
return Kernel32.GetLastError();
}
else if (OperatingSystem.IsMacOS())
{
return *libc.__error();
}
else if (OperatingSystem.IsLinux())
{
return *libc.__errno_location();
}

throw new NotImplementedException();
}

/// <summary>
/// Set the last system error on the current thread (errno on Unix, SetLastError on Windows)
/// </summary>
public static unsafe void SetLastSystemError(int error)
{
// Would be internal call that handles setting the last error for the thread using the PAL

if (OperatingSystem.IsWindows())
{
Kernel32.SetLastError(error);
}
else if (OperatingSystem.IsMacOS())
{
*libc.__error() = error;
}
else if (OperatingSystem.IsLinux())
{
*libc.__errno_location() = error;
}
else
{
throw new NotImplementedException();
}
}

private class Kernel32
{
[DllImport(nameof(Kernel32))]
public static extern void SetLastError(int error);

[DllImport(nameof(Kernel32))]
public static extern int GetLastError();
}

private class libc
{
[DllImport(nameof(libc))]
internal static unsafe extern int* __errno_location();

[DllImport(nameof(libc))]
internal static unsafe extern int* __error();
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
using System;
using System.Runtime.InteropServices;

using Xunit;

namespace DllImportGenerator.IntegrationTests
{
[BlittableType]
public struct SetLastErrorMarshaller
{
public int val;

public SetLastErrorMarshaller(int i)
{
val = i;
}

public int ToManaged()
{
// Explicity set the last error to something else on unmarshalling
MarshalEx.SetLastWin32Error(val * 2);
return val;
}
}

partial class NativeExportsNE
{
public partial class SetLastError
{
[GeneratedDllImport(nameof(NativeExportsNE), EntryPoint = "set_error", SetLastError = true)]
public static partial int SetError(int error, byte shouldSetError);

[GeneratedDllImport(nameof(NativeExportsNE), EntryPoint = "set_error_return_string", SetLastError = true)]
[return: MarshalUsing(typeof(SetLastErrorMarshaller))]
public static partial int SetError_CustomMarshallingSetsError(int error, byte shouldSetError);

[GeneratedDllImport(nameof(NativeExportsNE), EntryPoint = "set_error_return_string", SetLastError = true)]
[return: MarshalAs(UnmanagedType.LPWStr)]
public static partial string SetError_NonBlittableSignature(int error, [MarshalAs(UnmanagedType.U1)] bool shouldSetError, [MarshalAs(UnmanagedType.LPWStr)] string errorString);
}
}

public class SetLastErrorTests
{
[Theory]
[InlineData(0)]
[InlineData(2)]
[InlineData(-5)]
public void LastWin32Error_HasExpectedValue(int error)
{
string errorString = error.ToString();
string ret = NativeExportsNE.SetLastError.SetError_NonBlittableSignature(error, shouldSetError: true, errorString);
Assert.Equal(error, Marshal.GetLastWin32Error());
Assert.Equal(errorString, ret);

// Clear the last error
MarshalEx.SetLastWin32Error(0);

NativeExportsNE.SetLastError.SetError(error, shouldSetError: 1);
Assert.Equal(error, Marshal.GetLastWin32Error());

MarshalEx.SetLastWin32Error(0);

// Custom marshalling sets the last error on unmarshalling.
// Last error should reflect error from native call, not unmarshalling.
NativeExportsNE.SetLastError.SetError_CustomMarshallingSetsError(error, shouldSetError: 1);
Assert.Equal(error, Marshal.GetLastWin32Error());
}

[Fact]
public void ClearPreviousError()
{
int error = 100;
MarshalEx.SetLastWin32Error(error);

// Don't actually set the error in the native call. SetLastError=true should clear any existing error.
string errorString = error.ToString();
string ret = NativeExportsNE.SetLastError.SetError_NonBlittableSignature(error, shouldSetError: false, errorString);
Assert.Equal(0, Marshal.GetLastWin32Error());
Assert.Equal(errorString, ret);

MarshalEx.SetLastWin32Error(error);

// Don't actually set the error in the native call. SetLastError=true should clear any existing error.
NativeExportsNE.SetLastError.SetError(error, shouldSetError: 0);
Assert.Equal(0, Marshal.GetLastWin32Error());

// Don't actually set the error in the native call. Custom marshalling still sets the last error.
// SetLastError=true should clear any existing error and ignore error set by custom marshalling.
NativeExportsNE.SetLastError.SetError_CustomMarshallingSetsError(error, shouldSetError: 0);
Assert.Equal(0, Marshal.GetLastWin32Error());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,7 @@ public static IEnumerable<object[]> CodeSnippetsToCompile()

// Unsupported named arguments
// * BestFitMapping, ThrowOnUnmappableChar
// [TODO]: Expected diagnostic count should be 2 once we support SetLastError
yield return new object[] { CodeSnippets.AllDllImportNamedArguments, 3, 0 };
yield return new object[] { CodeSnippets.AllDllImportNamedArguments, 2, 0 };

// LCIDConversion
yield return new object[] { CodeSnippets.LCIDConversionAttribute, 1, 0 };
Expand Down
28 changes: 4 additions & 24 deletions DllImportGenerator/DllImportGenerator.UnitTests/Compiles.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,15 @@ namespace DllImportGenerator.UnitTests
{
public class Compiles
{
public static IEnumerable<object[]> CodeSnippetsToCompile_NoDiagnostics()
public static IEnumerable<object[]> CodeSnippetsToCompile()
{
yield return new[] { CodeSnippets.TrivialClassDeclarations };
yield return new[] { CodeSnippets.TrivialStructDeclarations };
yield return new[] { CodeSnippets.MultipleAttributes };
yield return new[] { CodeSnippets.NestedNamespace };
yield return new[] { CodeSnippets.NestedTypes };
yield return new[] { CodeSnippets.UserDefinedEntryPoint };
//yield return new[] { CodeSnippets.AllSupportedDllImportNamedArguments };
yield return new[] { CodeSnippets.AllSupportedDllImportNamedArguments };
yield return new[] { CodeSnippets.DefaultParameters };
yield return new[] { CodeSnippets.UseCSharpFeaturesForConstants };

Expand Down Expand Up @@ -161,14 +161,9 @@ public static IEnumerable<object[]> CodeSnippetsToCompile_NoDiagnostics()
yield return new[] { CodeSnippets.CustomStructMarshallingMarshalUsingParametersAndModifiers };
}

public static IEnumerable<object[]> CodeSnippetsToCompile_WithDiagnostics()
{
yield return new[] { CodeSnippets.AllSupportedDllImportNamedArguments };
}

[Theory]
[MemberData(nameof(CodeSnippetsToCompile_NoDiagnostics))]
public async Task ValidateSnippets_NoDiagnostics(string source)
[MemberData(nameof(CodeSnippetsToCompile))]
public async Task ValidateSnippets(string source)
{
Compilation comp = await TestUtils.CreateCompilation(source);
TestUtils.AssertPreSourceGeneratorCompilation(comp);
Expand All @@ -179,20 +174,5 @@ public async Task ValidateSnippets_NoDiagnostics(string source)
var newCompDiags = newComp.GetDiagnostics();
Assert.Empty(newCompDiags);
}

[Theory]
[MemberData(nameof(CodeSnippetsToCompile_WithDiagnostics))]
public async Task ValidateSnippets_WithDiagnostics(string source)
{
Compilation comp = await TestUtils.CreateCompilation(source);
TestUtils.AssertPreSourceGeneratorCompilation(comp);

var newComp = TestUtils.RunGenerators(comp, out var generatorDiags, new Microsoft.Interop.DllImportGenerator());
Assert.NotEmpty(generatorDiags);
Assert.All(generatorDiags, d => Assert.StartsWith(Microsoft.Interop.GeneratorDiagnostics.Ids.Prefix, d.Id));

var newCompDiags = newComp.GetDiagnostics();
Assert.Empty(newCompDiags);
}
}
}
6 changes: 0 additions & 6 deletions DllImportGenerator/DllImportGenerator/DllImportGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -107,12 +107,6 @@ public void Execute(GeneratorExecutionContext context)
generatorDiagnostics.ReportConfigurationNotSupported(generatedDllImportAttr, nameof(DllImportStub.GeneratedDllImportData.ThrowOnUnmappableChar));
}

// [TODO] Remove once we support SetLastError=true
if (dllImportData.SetLastError)
{
generatorDiagnostics.ReportConfigurationNotSupported(generatedDllImportAttr, nameof(DllImportStub.GeneratedDllImportData.SetLastError), "true");
}

if (lcidConversionAttr != null)
{
// Using LCIDConversion with GeneratedDllImport is not supported
Expand Down
53 changes: 52 additions & 1 deletion DllImportGenerator/DllImportGenerator/StubCodeGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ internal sealed class StubCodeGenerator : StubCodeContext
public string ReturnNativeIdentifier { get; private set; } = ReturnIdentifier;

private const string InvokeReturnIdentifier = "__invokeRetVal";
private const string LastErrorIdentifier = "__lastError";

// Error code representing success. This maps to S_OK for Windows HRESULT semantics and 0 for POSIX errno semantics.
private const int SuccessErrorCode = 0;

private static readonly Stage[] Stages = new Stage[]
{
Expand Down Expand Up @@ -170,6 +174,14 @@ public override (string managed, string native) GetIdentifiers(TypePositionInfo
AppendVariableDeclations(setupStatements, retMarshaller.TypeInfo, retMarshaller.Generator);
}

if (this.dllImportData.SetLastError)
{
// Declare variable for last error
setupStatements.Add(MarshallerHelpers.DeclareWithDefault(
PredefinedType(Token(SyntaxKind.IntKeyword)),
LastErrorIdentifier));
}

var tryStatements = new List<StatementSyntax>();
var finallyStatements = new List<StatementSyntax>();
var invoke = InvocationExpression(IdentifierName(dllImportName));
Expand Down Expand Up @@ -235,11 +247,37 @@ public override (string managed, string native) GetIdentifiers(TypePositionInfo
invoke));
}

if (this.dllImportData.SetLastError)
{
// Marshal.SetLastSystemError(0);
var clearLastError = ExpressionStatement(
InvocationExpression(
MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
ParseName(TypeNames.System_Runtime_InteropServices_MarshalEx),
IdentifierName("SetLastSystemError")),
ArgumentList(SingletonSeparatedList(
Argument(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(SuccessErrorCode)))))));

// <lastError> = Marshal.GetLastSystemError();
var getLastError = ExpressionStatement(
AssignmentExpression(
SyntaxKind.SimpleAssignmentExpression,
IdentifierName(LastErrorIdentifier),
InvocationExpression(
MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
ParseName(TypeNames.System_Runtime_InteropServices_MarshalEx),
IdentifierName("GetLastSystemError")))));

invokeStatement = Block(clearLastError, invokeStatement, getLastError);
}

// Nest invocation in fixed statements
if (fixedStatements.Any())
{
fixedStatements.Reverse();
invokeStatement = fixedStatements.First().WithStatement(Block(invokeStatement));
invokeStatement = fixedStatements.First().WithStatement(invokeStatement);
foreach (var fixedStatement in fixedStatements.Skip(1))
{
invokeStatement = fixedStatement.WithStatement(Block(invokeStatement));
Expand Down Expand Up @@ -274,6 +312,19 @@ public override (string managed, string native) GetIdentifiers(TypePositionInfo
allStatements.AddRange(tryStatements);
}

if (this.dllImportData.SetLastError)
{
// Marshal.SetLastWin32Error(<lastError>);
allStatements.Add(ExpressionStatement(
InvocationExpression(
MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
ParseName(TypeNames.System_Runtime_InteropServices_MarshalEx),
IdentifierName("SetLastWin32Error")),
ArgumentList(SingletonSeparatedList(
Argument(IdentifierName(LastErrorIdentifier)))))));
}

// Return
if (!stubReturnsVoid)
allStatements.Add(ReturnStatement(IdentifierName(ReturnIdentifier)));
Expand Down
Loading

0 comments on commit b36b0b8

Please sign in to comment.