diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/AttributedMarshallingModelGeneratorFactory.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/AttributedMarshallingModelGeneratorFactory.cs index 5b68a6cfc2653..2d44720dcf6c1 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/AttributedMarshallingModelGeneratorFactory.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/AttributedMarshallingModelGeneratorFactory.cs @@ -314,8 +314,8 @@ private IMarshallingGenerator CreateNativeCollectionMarshaller( } bool enableArrayPinning = elementMarshaller is BlittableMarshaller; - bool treatAsBlittable = enableArrayPinning || elementMarshaller is Utf16CharMarshaller; - if (treatAsBlittable) + bool treatElementAsBlittable = enableArrayPinning || elementMarshaller is Utf16CharMarshaller; + if (treatElementAsBlittable) { marshallingStrategy = new LinearCollectionWithBlittableElementsMarshalling(marshallingStrategy, collectionInfo.ElementType.Syntax, numElementsExpression); } @@ -345,7 +345,8 @@ private IMarshallingGenerator CreateNativeCollectionMarshaller( IMarshallingGenerator marshallingGenerator = new CustomNativeTypeMarshallingGenerator(marshallingStrategy, enableByValueContentsMarshalling: false); - if (collectionInfo.PinningFeatures.HasFlag(CustomTypeMarshallerPinning.ManagedType)) + // Elements in the collection must be blittable to use the pinnable marshaller. + if (collectionInfo.PinningFeatures.HasFlag(CustomTypeMarshallerPinning.ManagedType) && treatElementAsBlittable) { return new PinnableManagedValueMarshaller(marshallingGenerator); } diff --git a/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.Tests/CollectionTests.cs b/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.Tests/CollectionTests.cs index 947e9b28e4ed4..f0c35bf2ed24b 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.Tests/CollectionTests.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.Tests/CollectionTests.cs @@ -39,6 +39,9 @@ public partial class Collections [LibraryImport(NativeExportsNE_Binary, EntryPoint = "sum_string_lengths")] public static partial int SumStringLengths([MarshalUsing(typeof(ListMarshaller)), MarshalUsing(typeof(Utf16StringMarshaller), ElementIndirectionDepth = 1)] List strArray); + [LibraryImport(NativeExportsNE_Binary, EntryPoint = "sum_string_lengths")] + public static partial int SumStringLengths([MarshalUsing(typeof(Utf16StringMarshaller), ElementIndirectionDepth = 1)] WrappedList strArray); + [LibraryImport(NativeExportsNE_Binary, EntryPoint = "reverse_strings_replace")] public static partial void ReverseStrings_Ref([MarshalUsing(typeof(ListMarshaller), CountElementName = "numElements"), MarshalUsing(typeof(Utf16StringMarshaller), ElementIndirectionDepth = 1)] ref List strArray, out int numElements); @@ -57,7 +60,7 @@ public static partial void ReverseStrings_Out( public static partial List GetLongBytes(long l); [LibraryImport(NativeExportsNE_Binary, EntryPoint = "and_all_members")] - [return:MarshalAs(UnmanagedType.U1)] + [return: MarshalAs(UnmanagedType.U1)] public static partial bool AndAllMembers([MarshalUsing(typeof(ListMarshaller))] List pArray, int length); } } @@ -143,6 +146,13 @@ public void ByValueNullCollectionWithNonBlittableElements() Assert.Equal(0, NativeExportsNE.Collections.SumStringLengths(null)); } + [Fact] + public void ByValueCollectionWithNonBlittableElements_WithDefaultMarshalling() + { + var strings = new WrappedList(GetStringList()); + Assert.Equal(strings.Wrapped.Sum(str => str?.Length ?? 0), NativeExportsNE.Collections.SumStringLengths(strings)); + } + [Fact] public void ByRefCollectionWithNonBlittableElements() { diff --git a/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/NonBlittable.cs b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/NonBlittable.cs index 7732abe4c102c..d580054ad4ab1 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/NonBlittable.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/NonBlittable.cs @@ -227,4 +227,57 @@ public void FreeNative() Marshal.FreeCoTaskMem(allocatedMemory); } } + + [NativeMarshalling(typeof(WrappedListMarshaller<>))] + public struct WrappedList + { + public WrappedList(List list) + { + Wrapped = list; + } + + public List Wrapped { get; } + + public ref T GetPinnableReference() => ref CollectionsMarshal.AsSpan(Wrapped).GetPinnableReference(); + } + + [CustomTypeMarshaller(typeof(WrappedList<>), CustomTypeMarshallerKind.LinearCollection, Features = CustomTypeMarshallerFeatures.UnmanagedResources | CustomTypeMarshallerFeatures.TwoStageMarshalling | CustomTypeMarshallerFeatures.CallerAllocatedBuffer, BufferSize = 0x200)] + public unsafe ref struct WrappedListMarshaller + { + private ListMarshaller _marshaller; + + public WrappedListMarshaller(int sizeOfNativeElement) + : this() + { + this._marshaller = new ListMarshaller(sizeOfNativeElement); + } + + public WrappedListMarshaller(WrappedList managed, int sizeOfNativeElement) + : this(managed, Span.Empty, sizeOfNativeElement) + { + } + + public WrappedListMarshaller(WrappedList managed, Span stackSpace, int sizeOfNativeElement) + { + this._marshaller = new ListMarshaller(managed.Wrapped, stackSpace, sizeOfNativeElement); + } + + public ReadOnlySpan GetManagedValuesSource() => _marshaller.GetManagedValuesSource(); + + public Span GetManagedValuesDestination(int length) => _marshaller.GetManagedValuesDestination(length); + + public Span GetNativeValuesDestination() => _marshaller.GetNativeValuesDestination(); + + public ReadOnlySpan GetNativeValuesSource(int length) => _marshaller.GetNativeValuesSource(length); + + public ref byte GetPinnableReference() => ref _marshaller.GetPinnableReference(); + + public byte* ToNativeValue() => _marshaller.ToNativeValue(); + + public void FromNativeValue(byte* value) => _marshaller.FromNativeValue(value); + + public WrappedList ToManaged() => new(_marshaller.ToManaged()); + + public void FreeNative() => _marshaller.FreeNative(); + } }