Skip to content

Commit

Permalink
Fix handling of non-bidirectional MarshalMode.Default with collection…
Browse files Browse the repository at this point in the history
… marshalling (#72075)
  • Loading branch information
elinor-fung authored Jul 13, 2022
1 parent 5a91883 commit 4c63062
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ or MarshalMode.UnmanagedToManagedRef
if (mode != MarshalMode.Default && !shape.HasFlag(MarshallerShape.CallerAllocatedBuffer) && !shape.HasFlag(MarshallerShape.ToUnmanaged))
return null;

if (isLinearCollectionMarshaller)
if (isLinearCollectionMarshaller && methods.ManagedValuesSource is not null)
{
// Element type is the type parameter of the ReadOnlySpan returned by GetManagedValuesSource
collectionElementType = ((INamedTypeSymbol)methods.ManagedValuesSource.ReturnType).TypeArguments[0];
Expand All @@ -382,13 +382,19 @@ or MarshalMode.UnmanagedToManagedRef

if (isLinearCollectionMarshaller)
{
// Native type is the first parameter of GetUnmanagedValuesSource
nativeType = methods.UnmanagedValuesSource.Parameters[0].Type;
if (nativeType is null && methods.UnmanagedValuesSource is not null)
{
// Native type is the first parameter of GetUnmanagedValuesSource
nativeType = methods.UnmanagedValuesSource.Parameters[0].Type;
}

// Element type is the type parameter of the Span returned by GetManagedValuesDestination
collectionElementType = ((INamedTypeSymbol)methods.ManagedValuesDestination.ReturnType).TypeArguments[0];
if (collectionElementType is null && methods.ManagedValuesDestination is not null)
{
// Element type is the type parameter of the Span returned by GetManagedValuesDestination
collectionElementType = ((INamedTypeSymbol)methods.ManagedValuesDestination.ReturnType).TypeArguments[0];
}
}
else
else if (nativeType is null)
{
// Native type is the first parameter of ConvertToManaged or ConvertToManagedFinally
if (methods.ToManagedFinally is not null)
Expand Down Expand Up @@ -457,7 +463,7 @@ or MarshalMode.UnmanagedToManagedRef
nativeType = methods.ToUnmanaged.ReturnType;
}

if (isLinearCollectionMarshaller)
if (isLinearCollectionMarshaller && methods.ManagedValuesSource is not null)
{
// Element type is the type parameter of the ReadOnlySpan returned by GetManagedValuesSource
collectionElementType = ((INamedTypeSymbol)methods.ManagedValuesSource.ReturnType).TypeArguments[0];
Expand All @@ -475,7 +481,7 @@ or MarshalMode.UnmanagedToManagedRef
nativeType = methods.FromUnmanaged.Parameters[0].Type;
}

if (isLinearCollectionMarshaller && collectionElementType is null)
if (isLinearCollectionMarshaller && collectionElementType is null && methods.ManagedValuesDestination is not null)
{
// Element type is the type parameter of the Span returned by GetManagedValuesDestination
collectionElementType = ((INamedTypeSymbol)methods.ManagedValuesDestination.ReturnType).TypeArguments[0];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -804,7 +804,7 @@ public struct Native { }
}
";
private static string DefaultOut = @"
[CustomMarshaller(typeof(S), MarshalMode.ManagedToUnmanagedOut, typeof(Marshaller))]
[CustomMarshaller(typeof(S), MarshalMode.Default, typeof(Marshaller))]
public static class Marshaller
{
public struct Native { }
Expand Down Expand Up @@ -1398,7 +1398,7 @@ static unsafe class Marshaller<T, TUnmanagedElement> where TUnmanagedElement : u
public static System.Span<TUnmanagedElement> GetUnmanagedValuesDestination(byte* unmanaged, int numElements) => throw null;
}
";
public const string Ref = @"
public const string Default = @"
[CustomMarshaller(typeof(TestCollection<>), MarshalMode.Default, typeof(Marshaller<,>))]
[ContiguousCollectionMarshaller]
static unsafe class Marshaller<T, TUnmanagedElement> where TUnmanagedElement : unmanaged
Expand All @@ -1412,7 +1412,7 @@ static unsafe class Marshaller<T, TUnmanagedElement> where TUnmanagedElement : u
public static System.ReadOnlySpan<TUnmanagedElement> GetUnmanagedValuesSource(byte* unmanaged, int numElements) => throw null;
}
";
public const string RefNested = @"
public const string DefaultNested = @"
[CustomMarshaller(typeof(TestCollection<>), MarshalMode.Default, typeof(Marshaller<,>.Nested.Ref))]
[ContiguousCollectionMarshaller]
static unsafe class Marshaller<T, TUnmanagedElement> where TUnmanagedElement : unmanaged
Expand Down Expand Up @@ -1441,6 +1441,26 @@ static unsafe class Marshaller<T, TUnmanagedElement> where TUnmanagedElement : u
public static System.Span<T> GetManagedValuesDestination(TestCollection<T> managed) => throw null;
public static System.ReadOnlySpan<TUnmanagedElement> GetUnmanagedValuesSource(byte* unmanaged, int numElements) => throw null;
}
";
public const string DefaultIn = @"
[CustomMarshaller(typeof(TestCollection<>), MarshalMode.Default, typeof(Marshaller<,>))]
[ContiguousCollectionMarshaller]
static unsafe class Marshaller<T, TUnmanagedElement> where TUnmanagedElement : unmanaged
{
public static byte* AllocateContainerForUnmanagedElements(TestCollection<T> managed, out int numElements) => throw null;
public static System.ReadOnlySpan<T> GetManagedValuesSource(TestCollection<T> managed) => throw null;
public static System.Span<TUnmanagedElement> GetUnmanagedValuesDestination(byte* unmanaged, int numElements) => throw null;
}
";
public const string DefaultOut = @"
[CustomMarshaller(typeof(TestCollection<>), MarshalMode.Default, typeof(Marshaller<,>))]
[ContiguousCollectionMarshaller]
static unsafe class Marshaller<T, TUnmanagedElement> where TUnmanagedElement : unmanaged
{
public static TestCollection<T> AllocateContainerForManagedElements(byte* unmanaged, int length) => throw null;
public static System.Span<T> GetManagedValuesDestination(TestCollection<T> managed) => throw null;
public static System.ReadOnlySpan<TUnmanagedElement> GetUnmanagedValuesSource(byte* unmanaged, int numElements) => throw null;
}
";
public static string ByValue<T>() => ByValue(typeof(T).ToString());
public static string ByValue(string elementType) => BasicParameterByValue($"TestCollection<{elementType}>", DisableRuntimeMarshalling)
Expand All @@ -1460,17 +1480,17 @@ public static string ByValueCallerAllocatedBuffer(string elementType) => BasicPa
public static string DefaultMarshallerParametersAndModifiers<T>() => DefaultMarshallerParametersAndModifiers(typeof(T).ToString());
public static string DefaultMarshallerParametersAndModifiers(string elementType) => MarshalUsingCollectionCountInfoParametersAndModifiers($"TestCollection<{elementType}>")
+ TestCollection()
+ Ref;
+ Default;

public static string CustomMarshallerParametersAndModifiers<T>() => CustomMarshallerParametersAndModifiers(typeof(T).ToString());
public static string CustomMarshallerParametersAndModifiers(string elementType) => MarshalUsingCollectionParametersAndModifiers($"TestCollection<{elementType}>", $"Marshaller<,>")
+ TestCollection(defineNativeMarshalling: false)
+ Ref;
+ Default;

public static string CustomMarshallerReturnValueLength<T>() => CustomMarshallerReturnValueLength(typeof(T).ToString());
public static string CustomMarshallerReturnValueLength(string elementType) => MarshalUsingCollectionReturnValueLength($"TestCollection<{elementType}>", $"Marshaller<,>")
+ TestCollection(defineNativeMarshalling: false)
+ Ref;
+ Default;

public static string NativeToManagedOnlyOutParameter<T>() => NativeToManagedOnlyOutParameter(typeof(T).ToString());
public static string NativeToManagedOnlyOutParameter(string elementType) => CollectionOutParameter($"TestCollection<{elementType}>")
Expand All @@ -1485,7 +1505,7 @@ public static string NativeToManagedOnlyReturnValue(string elementType) => Colle
public static string NestedMarshallerParametersAndModifiers<T>() => NestedMarshallerParametersAndModifiers(typeof(T).ToString());
public static string NestedMarshallerParametersAndModifiers(string elementType) => MarshalUsingCollectionCountInfoParametersAndModifiers($"TestCollection<{elementType}>")
+ TestCollection()
+ RefNested;
+ DefaultNested;

public static string NonBlittableElementParametersAndModifiers => DefaultMarshallerParametersAndModifiers("Element")
+ NonBlittableElement
Expand All @@ -1503,6 +1523,14 @@ public static string NestedMarshallerParametersAndModifiers(string elementType)
+ NonBlittableElement
+ ElementOut;

public static string DefaultModeByValueInParameter => BasicParameterByValue($"TestCollection<int>", DisableRuntimeMarshalling)
+ TestCollection()
+ DefaultIn;

public static string DefaultModeReturnValue => CollectionOutParameter($"TestCollection<int>")
+ TestCollection()
+ DefaultOut;

public static string GenericCollectionMarshallingArityMismatch => BasicParameterByValue("TestCollection<int>", DisableRuntimeMarshalling)
+ @"
[NativeMarshalling(typeof(Marshaller<,,>))]
Expand Down Expand Up @@ -1542,7 +1570,7 @@ out int pOutSize
}}
"
+ TestCollection()
+ Ref
+ Default
+ CustomIntMarshaller;

public static string CustomElementMarshallingDuplicateElementIndirectionDepth => $@"
Expand Down Expand Up @@ -1669,6 +1697,34 @@ public ref struct Out
public System.ReadOnlySpan<TUnmanagedElement> GetUnmanagedValuesSource(int numElements) => throw null;
}
}
";
public const string DefaultIn = @"
[ContiguousCollectionMarshaller]
[CustomMarshaller(typeof(TestCollection<>), MarshalMode.Default, typeof(Marshaller<,>.In))]
static unsafe class Marshaller<T, TUnmanagedElement> where TUnmanagedElement : unmanaged
{
public ref struct In
{
public void FromManaged(TestCollection<T> managed) => throw null;
public byte* ToUnmanaged() => throw null;
public System.ReadOnlySpan<T> GetManagedValuesSource() => throw null;
public System.Span<TUnmanagedElement> GetUnmanagedValuesDestination() => throw null;
}
}
";
public const string DefaultOut = @"
[ContiguousCollectionMarshaller]
[CustomMarshaller(typeof(TestCollection<>), MarshalMode.Default, typeof(Marshaller<,>.Out))]
static unsafe class Marshaller<T, TUnmanagedElement> where TUnmanagedElement : unmanaged
{
public ref struct Out
{
public void FromUnmanaged(byte* value) => throw null;
public TestCollection<T> ToManaged() => throw null;
public System.Span<T> GetManagedValuesDestination(int numElements) => throw null;
public System.ReadOnlySpan<TUnmanagedElement> GetUnmanagedValuesSource(int numElements) => throw null;
}
}
";
public static string ByValue<T>() => ByValue(typeof(T).ToString());
public static string ByValue(string elementType) => BasicParameterByValue($"TestCollection<{elementType}>", DisableRuntimeMarshalling)
Expand Down Expand Up @@ -1731,6 +1787,14 @@ public static string NativeToManagedOnlyReturnValue(string elementType) => Colle
+ NonBlittableElement
+ ElementOut;

public static string DefaultModeByValueInParameter => BasicParameterByValue($"TestCollection<int>", DisableRuntimeMarshalling)
+ TestCollection()
+ DefaultIn;

public static string DefaultModeReturnValue => CollectionOutParameter($"TestCollection<int>")
+ TestCollection()
+ DefaultOut;

public static string CustomElementMarshalling => $@"
using System.Runtime.InteropServices;
using System.Runtime.InteropServices.Marshalling;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,8 @@ public static IEnumerable<object[]> CustomCollections()
yield return new[] { ID(), CodeSnippets.CustomCollectionMarshalling.Stateless.NonBlittableElementParametersAndModifiers };
yield return new[] { ID(), CodeSnippets.CustomCollectionMarshalling.Stateless.NonBlittableElementNativeToManagedOnlyOutParameter };
yield return new[] { ID(), CodeSnippets.CustomCollectionMarshalling.Stateless.NonBlittableElementNativeToManagedOnlyReturnValue };
yield return new[] { ID(), CodeSnippets.CustomCollectionMarshalling.Stateless.DefaultModeByValueInParameter };
yield return new[] { ID(), CodeSnippets.CustomCollectionMarshalling.Stateless.DefaultModeReturnValue };
yield return new[] { ID(), CodeSnippets.CustomCollectionMarshalling.Stateless.CustomElementMarshalling };
yield return new[] { ID(), CodeSnippets.CustomCollectionMarshalling.Stateful.DefaultMarshallerParametersAndModifiers<byte>() };
yield return new[] { ID(), CodeSnippets.CustomCollectionMarshalling.Stateful.DefaultMarshallerParametersAndModifiers<sbyte>() };
Expand Down Expand Up @@ -392,6 +394,8 @@ public static IEnumerable<object[]> CustomCollections()
yield return new[] { ID(), CodeSnippets.CustomCollectionMarshalling.Stateful.NonBlittableElementParametersAndModifiers };
yield return new[] { ID(), CodeSnippets.CustomCollectionMarshalling.Stateful.NonBlittableElementNativeToManagedOnlyOutParameter };
yield return new[] { ID(), CodeSnippets.CustomCollectionMarshalling.Stateful.NonBlittableElementNativeToManagedOnlyReturnValue };
yield return new[] { ID(), CodeSnippets.CustomCollectionMarshalling.Stateful.DefaultModeByValueInParameter };
yield return new[] { ID(), CodeSnippets.CustomCollectionMarshalling.Stateful.DefaultModeReturnValue };
yield return new[] { ID(), CodeSnippets.CustomCollectionMarshalling.Stateful.CustomElementMarshalling };
yield return new[] { ID(), CodeSnippets.CollectionsOfCollectionsStress };
}
Expand Down

0 comments on commit 4c63062

Please sign in to comment.