Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

StringMarshalling behavior override tests #86963

Merged
merged 2 commits into from
Jun 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ public unsafe partial class StringMarshallingTests
[LibraryImport(NativeExportsNE.NativeExportsNE_Binary, EntryPoint = "new_utf16_marshalling")]
public static partial void* NewIUtf16Marshalling();

[LibraryImport(NativeExportsNE.NativeExportsNE_Binary, EntryPoint = "new_string_marshalling_override")]
public static partial void* NewStringMarshallingOverride();

[GeneratedComClass]
internal partial class Utf8MarshalledClass : IUTF8Marshalling
{
Expand Down Expand Up @@ -107,5 +110,26 @@ public void RcwToCcw()
customUtf16ComObject.SetString("Set from COM object");
Assert.Equal(customUtf16.GetString(), customUtf16ComObject.GetString());
}

[Fact]
public void MarshalAsAndMarshalUsingOverrideStringMarshalling()
{
var ptr = NewStringMarshallingOverride();
var cw = new StrategyBasedComWrappers();
var obj = cw.GetOrCreateObjectForComInstance((nint)ptr, CreateObjectFlags.None);
var stringMarshallingOverride = (IStringMarshallingOverride)obj;
Assert.Equal("Your string: MyUtf8String", stringMarshallingOverride.StringMarshallingUtf8("MyUtf8String"));
Assert.Equal("Your string: MyLPWStrString", stringMarshallingOverride.MarshalAsLPWString("MyLPWStrString"));
Assert.Equal("Your string: MyUtf16String", stringMarshallingOverride.MarshalUsingUtf16("MyUtf16String"));

// Make sure the shadowing methods generated for the derived interface also follow the rules
var stringMarshallingOverrideDerived = (IStringMarshallingOverrideDerived)obj;
Assert.Equal("Your string: MyUtf8String", stringMarshallingOverrideDerived.StringMarshallingUtf8("MyUtf8String"));
Assert.Equal("Your string: MyLPWStrString", stringMarshallingOverrideDerived.MarshalAsLPWString("MyLPWStrString"));
Assert.Equal("Your string: MyUtf16String", stringMarshallingOverrideDerived.MarshalUsingUtf16("MyUtf16String"));
Assert.Equal("Your string 2: MyUtf8String", stringMarshallingOverrideDerived.StringMarshallingUtf8_2("MyUtf8String"));
Assert.Equal("Your string 2: MyLPWStrString", stringMarshallingOverrideDerived.MarshalAsLPWString_2("MyLPWStrString"));
Assert.Equal("Your string 2: MyUtf16String", stringMarshallingOverrideDerived.MarshalUsingUtf16_2("MyUtf16String"));
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.InteropServices.Marshalling;
using System.Text;
using System.Threading.Tasks;
using SharedTypes.ComInterfaces;
using static System.Runtime.InteropServices.ComWrappers;

namespace NativeExports.ComInterfaceGenerator
{
public unsafe partial class StringMarshallingOverride
{
[UnmanagedCallersOnly(EntryPoint = "new_string_marshalling_override")]
public static void* CreateStringMarshallingOverrideObject()
{
MyComWrapper cw = new();
var myObject = new Implementation();
nint ptr = cw.GetOrCreateComInterfaceForObject(myObject, CreateComInterfaceFlags.None);
return (void*)ptr;
}

class MyComWrapper : ComWrappers
{
static void* _s_comInterfaceVTable = null;
static void* S_VTable
{
get
{
if (_s_comInterfaceVTable != null)
return _s_comInterfaceVTable;
void** vtable = (void**)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(GetAndSetInt), sizeof(void*) * 6);
GetIUnknownImpl(out var fpQueryInterface, out var fpAddReference, out var fpRelease);
vtable[0] = (void*)fpQueryInterface;
vtable[1] = (void*)fpAddReference;
vtable[2] = (void*)fpRelease;
vtable[3] = (delegate* unmanaged<void*, byte*, byte**, int>)&Implementation.ABI.StringMarshallingUtf8;
vtable[4] = (delegate* unmanaged<void*, ushort*, ushort**, int>)&Implementation.ABI.MarshalAsLPWStr;
vtable[5] = (delegate* unmanaged<void*, ushort*, ushort**, int>)&Implementation.ABI.MarshalUsingUtf16;
_s_comInterfaceVTable = vtable;
return _s_comInterfaceVTable;
}
}

static void* _s_derivedVTable = null;
static void* S_DerivedVTable
{
get
{
if (_s_comInterfaceVTable != null)
return _s_comInterfaceVTable;
void** vtable = (void**)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(GetAndSetInt), sizeof(void*) * 9);
GetIUnknownImpl(out var fpQueryInterface, out var fpAddReference, out var fpRelease);
vtable[0] = (void*)fpQueryInterface;
vtable[1] = (void*)fpAddReference;
vtable[2] = (void*)fpRelease;
vtable[3] = (delegate* unmanaged<void*, byte*, byte**, int>)&Implementation.ABI.StringMarshallingUtf8;
vtable[4] = (delegate* unmanaged<void*, ushort*, ushort**, int>)&Implementation.ABI.MarshalAsLPWStr;
vtable[5] = (delegate* unmanaged<void*, ushort*, ushort**, int>)&Implementation.ABI.MarshalUsingUtf16;
vtable[6] = (delegate* unmanaged<void*, byte*, byte**, int>)&Implementation.ABI.StringMarshallingUtf8_2;
vtable[7] = (delegate* unmanaged<void*, ushort*, ushort**, int>)&Implementation.ABI.MarshalAsLPWStr_2;
vtable[8] = (delegate* unmanaged<void*, ushort*, ushort**, int>)&Implementation.ABI.MarshalUsingUtf16_2;
_s_comInterfaceVTable = vtable;
return _s_comInterfaceVTable;
}
}

protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count)
{
if (obj is IStringMarshallingOverrideDerived)
{
ComInterfaceEntry* comInterfaceEntry = (ComInterfaceEntry*)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(Implementation), sizeof(ComInterfaceEntry) * 2);
comInterfaceEntry[0].IID = new Guid(IStringMarshallingOverrideDerived._guid);
comInterfaceEntry[0].Vtable = (nint)S_DerivedVTable;
comInterfaceEntry[1].IID = new Guid(IStringMarshallingOverride._guid);
comInterfaceEntry[1].Vtable = (nint)S_VTable;
count = 2;
return comInterfaceEntry;
}
if (obj is IStringMarshallingOverride)
{
ComInterfaceEntry* comInterfaceEntry = (ComInterfaceEntry*)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(Implementation), sizeof(ComInterfaceEntry));
comInterfaceEntry->IID = new Guid(IStringMarshallingOverride._guid);
comInterfaceEntry->Vtable = (nint)S_VTable;
count = 1;
return comInterfaceEntry;
}
count = 0;
return null;
}

protected override object? CreateObject(nint externalComObject, CreateObjectFlags flags) => throw new NotImplementedException();
protected override void ReleaseObjects(IEnumerable objects) => throw new NotImplementedException();
}

partial class Implementation : IStringMarshallingOverride, IStringMarshallingOverrideDerived
{
string _data = "Your string: ";
string IStringMarshallingOverride.StringMarshallingUtf8(string input) => _data + input;
string IStringMarshallingOverride.MarshalAsLPWString(string input) => _data + input;
string IStringMarshallingOverride.MarshalUsingUtf16(string input) => _data + input;

string _data2 = "Your string 2: ";
string IStringMarshallingOverrideDerived.StringMarshallingUtf8_2(string input) => _data2 + input;
string IStringMarshallingOverrideDerived.MarshalAsLPWString_2(string input) => _data2 + input;
string IStringMarshallingOverrideDerived.MarshalUsingUtf16_2(string input) => _data2 + input;

// Provides function pointers in the COM format to use in COM VTables
public static class ABI
{
[UnmanagedCallersOnly]
public static int StringMarshallingUtf8(void* @this, byte* input, byte** output)
{
try
{
string inputStr = Utf8StringMarshaller.ConvertToManaged(input);
string currValue = ComInterfaceDispatch.GetInstance<IStringMarshallingOverride>((ComInterfaceDispatch*)@this).StringMarshallingUtf8(inputStr);
*output = Utf8StringMarshaller.ConvertToUnmanaged(currValue);
return 0;
}
catch (Exception e)
{
return e.HResult;
}
}

[UnmanagedCallersOnly]
public static int MarshalAsLPWStr(void* @this, ushort* input, ushort** output)
{
try
{
string inputStr = Utf16StringMarshaller.ConvertToManaged(input);
string currValue = ComInterfaceDispatch.GetInstance<IStringMarshallingOverride>((ComInterfaceDispatch*)@this).MarshalAsLPWString(inputStr);
*output = Utf16StringMarshaller.ConvertToUnmanaged(currValue);
return 0;
}
catch (Exception e)
{
return e.HResult;
}
}

[UnmanagedCallersOnly]
public static int MarshalUsingUtf16(void* @this, ushort* input, ushort** output)
{
try
{
string inputStr = Utf16StringMarshaller.ConvertToManaged(input);
string currValue = ComInterfaceDispatch.GetInstance<IStringMarshallingOverride>((ComInterfaceDispatch*)@this).MarshalUsingUtf16(inputStr);
*output = Utf16StringMarshaller.ConvertToUnmanaged(currValue);
return 0;
}
catch (Exception e)
{
return e.HResult;
}
}

[UnmanagedCallersOnly]
public static int StringMarshallingUtf8_2(void* @this, byte* input, byte** output)
{
try
{
string inputStr = Utf8StringMarshaller.ConvertToManaged(input);
string currValue = ComInterfaceDispatch.GetInstance<IStringMarshallingOverrideDerived>((ComInterfaceDispatch*)@this).StringMarshallingUtf8_2(inputStr);
*output = Utf8StringMarshaller.ConvertToUnmanaged(currValue);
return 0;
}
catch (Exception e)
{
return e.HResult;
}
}

[UnmanagedCallersOnly]
public static int MarshalAsLPWStr_2(void* @this, ushort* input, ushort** output)
{
try
{
string inputStr = Utf16StringMarshaller.ConvertToManaged(input);
string currValue = ComInterfaceDispatch.GetInstance<IStringMarshallingOverrideDerived>((ComInterfaceDispatch*)@this).MarshalAsLPWString_2(inputStr);
*output = Utf16StringMarshaller.ConvertToUnmanaged(currValue);
return 0;
}
catch (Exception e)
{
return e.HResult;
}
}

[UnmanagedCallersOnly]
public static int MarshalUsingUtf16_2(void* @this, ushort* input, ushort** output)
{
try
{
string inputStr = Utf16StringMarshaller.ConvertToManaged(input);
string currValue = ComInterfaceDispatch.GetInstance<IStringMarshallingOverrideDerived>((ComInterfaceDispatch*)@this).MarshalUsingUtf16_2(inputStr);
*output = Utf16StringMarshaller.ConvertToUnmanaged(currValue);
return 0;
}
catch (Exception e)
{
return e.HResult;
}
}
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.InteropServices;
using System.Runtime.InteropServices.Marshalling;
using System.Text;
using System.Threading.Tasks;

namespace SharedTypes.ComInterfaces
{
[GeneratedComInterface(StringMarshalling = System.Runtime.InteropServices.StringMarshalling.Utf8)]
[Guid(_guid)]
internal partial interface IStringMarshallingOverride
{
public const string _guid = "5146B7DB-0588-469B-B8E5-B38090A2FC15";
string StringMarshallingUtf8(string input);

[return: MarshalAs(UnmanagedType.LPWStr)]
string MarshalAsLPWString([MarshalAs(UnmanagedType.LPWStr)] string input);

[return: MarshalUsing(typeof(Utf16StringMarshaller))]
string MarshalUsingUtf16([MarshalUsing(typeof(Utf16StringMarshaller))] string input);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.InteropServices.Marshalling;
using System.Runtime.InteropServices;
using System.Text;
using System.Threading.Tasks;

namespace SharedTypes.ComInterfaces
{
[GeneratedComInterface(StringMarshalling = StringMarshalling.Utf8)]
[Guid(_guid)]
internal partial interface IStringMarshallingOverrideDerived : IStringMarshallingOverride
{
public new const string _guid = "3AFFE3FD-D11E-4195-8250-0C73321977A0";
string StringMarshallingUtf8_2(string input);

[return: MarshalAs(UnmanagedType.LPWStr)]
string MarshalAsLPWString_2([MarshalAs(UnmanagedType.LPWStr)] string input);

[return: MarshalUsing(typeof(Utf16StringMarshaller))]
string MarshalUsingUtf16_2([MarshalUsing(typeof(Utf16StringMarshaller))] string input);
}
}