diff --git a/src/mono/sample/wasi/native/Makefile b/src/mono/sample/wasi/native/Makefile new file mode 100644 index 0000000000000..eae68ce1d83d3 --- /dev/null +++ b/src/mono/sample/wasi/native/Makefile @@ -0,0 +1,15 @@ +TOP=../../../../.. + +include ../wasi.mk + +ifneq ($(AOT),) +override MSBUILD_ARGS+=/p:RunAOTCompilation=true +endif + +ifneq ($(V),) +DOTNET_MONO_LOG_LEVEL=--setenv=MONO_LOG_LEVEL=debug +endif + +PROJECT_NAME=Wasi.Console.Sample.csproj + +run: run-console diff --git a/src/mono/sample/wasi/native/Program.cs b/src/mono/sample/wasi/native/Program.cs new file mode 100644 index 0000000000000..cb2fd0f36caf5 --- /dev/null +++ b/src/mono/sample/wasi/native/Program.cs @@ -0,0 +1,31 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Runtime.InteropServices; + +public unsafe class Test +{ + [UnmanagedCallersOnly(EntryPoint = "ManagedFunc")] + public static int MyExport(int number) + { + // called from MyImport aka UnmanagedFunc + Console.WriteLine($"MyExport({number}) -> 42"); + return 42; + } + + [DllImport("*", EntryPoint = "UnmanagedFunc")] + public static extern void MyImport(); // calls ManagedFunc aka MyExport + + public unsafe static int Main(string[] args) + { + Console.WriteLine($"main: {args.Length}"); + // workaround to force the interpreter to initialize wasm_native_to_interp_ftndesc for MyExport + if (args.Length > 10000) { + ((IntPtr)(delegate* unmanaged)&MyExport).ToString(); + } + + MyImport(); + return 0; + } +} diff --git a/src/mono/sample/wasi/native/Wasi.Native.Sample.csproj b/src/mono/sample/wasi/native/Wasi.Native.Sample.csproj new file mode 100644 index 0000000000000..9af9d62eeee98 --- /dev/null +++ b/src/mono/sample/wasi/native/Wasi.Native.Sample.csproj @@ -0,0 +1,17 @@ + + + $(NetCoreAppCurrent) + wasi-wasm + + wasi + true + false + false + true + + + + + + + diff --git a/src/mono/sample/wasi/native/local.c b/src/mono/sample/wasi/native/local.c new file mode 100644 index 0000000000000..9141571f949be --- /dev/null +++ b/src/mono/sample/wasi/native/local.c @@ -0,0 +1,11 @@ +#include + +int ManagedFunc(int number); + +void UnmanagedFunc() +{ + int ret = 0; + printf("UnmanagedFunc calling ManagedFunc\n"); + ret = ManagedFunc(123); + printf("ManagedFunc returned %d\n", ret); +} \ No newline at end of file diff --git a/src/tasks/WasmAppBuilder/PInvokeCollector.cs b/src/tasks/WasmAppBuilder/PInvokeCollector.cs index 6c26c3a7be979..b760899ea0c8c 100644 --- a/src/tasks/WasmAppBuilder/PInvokeCollector.cs +++ b/src/tasks/WasmAppBuilder/PInvokeCollector.cs @@ -15,17 +15,19 @@ internal sealed class PInvoke : IEquatable #pragma warning restore CA1067 { - public PInvoke(string entryPoint, string module, MethodInfo method) + public PInvoke(string entryPoint, string module, MethodInfo method, bool wasmLinkage) { EntryPoint = entryPoint; Module = module; Method = method; + WasmLinkage = wasmLinkage; } public string EntryPoint; public string Module; public MethodInfo Method; public bool Skip; + public bool WasmLinkage; public bool Equals(PInvoke? other) => other != null && @@ -100,9 +102,10 @@ void CollectPInvokesForMethod(MethodInfo method) if ((method.Attributes & MethodAttributes.PinvokeImpl) != 0) { var dllimport = method.CustomAttributes.First(attr => attr.AttributeType.Name == "DllImportAttribute"); + var wasmLinkage = method.CustomAttributes.Any(attr => attr.AttributeType.Name == "WasmImportLinkageAttribute"); var module = (string)dllimport.ConstructorArguments[0].Value!; var entrypoint = (string)dllimport.NamedArguments.First(arg => arg.MemberName == "EntryPoint").TypedValue.Value!; - pinvokes.Add(new PInvoke(entrypoint, module, method)); + pinvokes.Add(new PInvoke(entrypoint, module, method, wasmLinkage)); string? signature = SignatureMapper.MethodToSignature(method); if (signature == null) @@ -241,8 +244,23 @@ internal sealed class PInvokeCallback public PInvokeCallback(MethodInfo method) { Method = method; + foreach (var attr in method.CustomAttributes) + { + if (attr.AttributeType.Name == "UnmanagedCallersOnlyAttribute") + { + foreach(var arg in attr.NamedArguments) + { + if (arg.MemberName == "EntryPoint") + { + EntryPoint = arg.TypedValue.Value!.ToString(); + return; + } + } + } + } } + public string? EntryPoint; public MethodInfo Method; public string? EntryName; } diff --git a/src/tasks/WasmAppBuilder/PInvokeTableGenerator.cs b/src/tasks/WasmAppBuilder/PInvokeTableGenerator.cs index 2665508c9dfa5..d4f7041035089 100644 --- a/src/tasks/WasmAppBuilder/PInvokeTableGenerator.cs +++ b/src/tasks/WasmAppBuilder/PInvokeTableGenerator.cs @@ -7,6 +7,7 @@ using System.IO; using System.Linq; using System.Text; +using System.Text.RegularExpressions; using System.Reflection; using Microsoft.Build.Framework; using Microsoft.Build.Utilities; @@ -63,39 +64,40 @@ public IEnumerable Generate(string[] pinvokeModules, string outputPath) return signatures; } - private static bool HasAttribute(MemberInfo element, params string[] attributeNames) + private void EmitPInvokeTable(StreamWriter w, Dictionary modules, List pinvokes) { - foreach (CustomAttributeData cattr in CustomAttributeData.GetCustomAttributes(element)) + + foreach (var pinvoke in pinvokes) { - try + if (modules.ContainsKey(pinvoke.Module)) + continue; + // Handle special modules, and add them to the list of modules + // otherwise, skip them and throw an exception at runtime if they + // are called. + if (pinvoke.WasmLinkage) { - for (int i = 0; i < attributeNames.Length; ++i) - { - if (cattr.AttributeType.FullName == attributeNames [i] || - cattr.AttributeType.Name == attributeNames[i]) - { - return true; - } - } + // WasmLinkage means we needs to import the module + modules.Add(pinvoke.Module, pinvoke.Module); + Log.LogMessage(MessageImportance.Low, $"Adding module {pinvoke.Module} for WasmImportLinkage"); } - catch + else if (pinvoke.Module == "*" || pinvoke.Module == "__Internal") { - // Assembly not found, ignore + // Special case for __Internal and * modules to indicate static linking wihtout specifying the module + modules.Add(pinvoke.Module, pinvoke.Module); + Log.LogMessage(MessageImportance.Low, $"Adding module {pinvoke.Module} for static linking"); } } - return false; - } - private void EmitPInvokeTable(StreamWriter w, Dictionary modules, List pinvokes) - { - w.WriteLine("// GENERATED FILE, DO NOT MODIFY"); - w.WriteLine(); + w.WriteLine( + $""" + // GENERATED FILE, DO NOT MODIFY"); + + """); var pinvokesGroupedByEntryPoint = pinvokes .Where(l => modules.ContainsKey(l.Module)) .OrderBy(l => l.EntryPoint) - .GroupBy(l => l.EntryPoint); - + .GroupBy(CEntryPoint); var comparer = new PInvokeComparer(); foreach (IGrouping group in pinvokesGroupedByEntryPoint) { @@ -120,7 +122,7 @@ private void EmitPInvokeTable(StreamWriter w, Dictionary modules foreach (var candidate in candidates) { var decl = GenPInvokeDecl(candidate); - if (decl == null || decls.Contains(decl)) + if (decl is null || decls.Contains(decl)) continue; w.WriteLine(decl); @@ -130,37 +132,35 @@ private void EmitPInvokeTable(StreamWriter w, Dictionary modules foreach (var module in modules.Keys) { - string symbol = _fixupSymbolName(module) + "_imports"; - w.WriteLine("static PinvokeImport " + symbol + " [] = {"); + var assemblies_pinvokes = pinvokes + .Where(l => l.Module == module && !l.Skip) + .OrderBy(l => l.EntryPoint) + .GroupBy(d => d.EntryPoint) + .Select(l => $"{{\"{EscapeLiteral(l.Key)}\", {CEntryPoint(l.First())}}}, " + + "// " + string.Join(", ", l.Select(c => c.Method.DeclaringType!.Module!.Assembly!.GetName()!.Name!).Distinct().OrderBy(n => n))) + .Append("{NULL, NULL}"); + + w.Write( + $$""" + static PinvokeImport {{_fixupSymbolName(module)}}_imports [] = { + {{string.Join("\n ", assemblies_pinvokes)}} + }; + + """); + } - var assemblies_pinvokes = pinvokes. - Where(l => l.Module == module && !l.Skip). - OrderBy(l => l.EntryPoint). - GroupBy(d => d.EntryPoint). - Select(l => "{\"" + _fixupSymbolName(l.Key) + "\", " + _fixupSymbolName(l.Key) + "}, " + - "// " + string.Join(", ", l.Select(c => c.Method.DeclaringType!.Module!.Assembly!.GetName()!.Name!).Distinct().OrderBy(n => n))); + w.Write( + $$""" - foreach (var pinvoke in assemblies_pinvokes) - { - w.WriteLine(pinvoke); - } + static void *pinvoke_tables[] = { + {{string.Join(", ", modules.Keys.Select(m => $"(void*){_fixupSymbolName(m)}_imports"))}} + }; - w.WriteLine("{NULL, NULL}"); - w.WriteLine("};"); - } - w.Write("static void *pinvoke_tables[] = { "); - foreach (var module in modules.Keys) - { - string symbol = _fixupSymbolName(module) + "_imports"; - w.Write(symbol + ","); - } - w.WriteLine("};"); - w.Write("static char *pinvoke_names[] = { "); - foreach (var module in modules.Keys) - { - w.Write("\"" + module + "\"" + ","); - } - w.WriteLine("};"); + static char *pinvoke_names[] = { + {{string.Join(", ", modules.Keys.Select(m => $"\"{EscapeLiteral(m)}\""))}} + }; + + """); static bool ShouldTreatAsVariadic(PInvoke[] candidates) { @@ -179,15 +179,14 @@ static bool ShouldTreatAsVariadic(PInvoke[] candidates) } } - private string SymbolNameForMethod(MethodInfo method) + private string CEntryPoint(PInvoke pinvoke) { - StringBuilder sb = new(); - Type? type = method.DeclaringType; - sb.Append($"{type!.Module!.Assembly!.GetName()!.Name!}_"); - sb.Append($"{(type!.IsNested ? type!.FullName : type!.Name)}_"); - sb.Append(method.Name); - - return _fixupSymbolName(sb.ToString()); + if (pinvoke.WasmLinkage) + { + // We mangle the name to avoid collisions with symbols in other modules + return _fixupSymbolName($"{pinvoke.Module}_{pinvoke.EntryPoint}"); + } + return _fixupSymbolName(pinvoke.EntryPoint); } private static string MapType(Type t) => t.Name switch @@ -224,14 +223,12 @@ private static bool TryIsMethodGetParametersUnsupported(MethodInfo method, [NotN private string? GenPInvokeDecl(PInvoke pinvoke) { - var sb = new StringBuilder(); var method = pinvoke.Method; if (method.Name == "EnumCalendarInfo") { // FIXME: System.Reflection.MetadataLoadContext can't decode function pointer types // https://github.com/dotnet/runtime/issues/43791 - sb.Append($"int {_fixupSymbolName(pinvoke.EntryPoint)} (int, int, int, int, int);"); - return sb.ToString(); + return $"int {_fixupSymbolName(pinvoke.EntryPoint)} (int, int, int, int, int);"; } if (TryIsMethodGetParametersUnsupported(pinvoke.Method, out string? reason)) @@ -245,21 +242,40 @@ private static bool TryIsMethodGetParametersUnsupported(MethodInfo method, [NotN return null; } - sb.Append(MapType(method.ReturnType)); - sb.Append($" {_fixupSymbolName(pinvoke.EntryPoint)} ("); - int pindex = 0; - var pars = method.GetParameters(); - foreach (var p in pars) - { - if (pindex > 0) - sb.Append(','); - sb.Append(MapType(pars[pindex].ParameterType)); - pindex++; + return + $$""" + {{(pinvoke.WasmLinkage ? $"__attribute__((import_module(\"{EscapeLiteral(pinvoke.Module)}\"),import_name(\"{EscapeLiteral(pinvoke.EntryPoint)}\")))" : "")}} + {{(pinvoke.WasmLinkage ? "extern " : "")}}{{MapType(method.ReturnType)}} {{CEntryPoint(pinvoke)}} ({{ + string.Join(", ", method.GetParameters().Select(p => MapType(p.ParameterType))) + }}); + """; + } + + private string CEntryPoint(PInvokeCallback export) + { + if (export.EntryPoint is not null) { + return _fixupSymbolName(export.EntryPoint); } - sb.Append(");"); - return sb.ToString(); + + var method = export.Method; + // EntryPoint wasn't specified generate a name for the entry point + return _fixupSymbolName($"wasm_native_to_interp_{method.DeclaringType!.Module!.Assembly!.GetName()!.Name!}_{method.DeclaringType.Name}_{method.Name}"); + } + + private string DelegateKey(PInvokeCallback export) + { + // FIXME: this is a hack, we need to encode this better + // and allow reflection in the interp case but either way + // it needs to match the key generated in get_native_to_interp + var method = export.Method; + string module_symbol = method.DeclaringType!.Module!.Assembly!.GetName()!.Name!; + return $"\"{module_symbol}_{method.DeclaringType.Name}_{method.Name}\"".Replace('.', '_'); } +#pragma warning disable SYSLIB1045 // framework doesn't support GeneratedRegexAttribute + private static string EscapeLiteral(string s) => Regex.Replace(s, @"(\\|\"")", @"\$1"); +#pragma warning restore SYSLIB1045 + private void EmitNativeToInterp(StreamWriter w, List callbacks) { // Generate native->interp entry functions @@ -273,112 +289,89 @@ private void EmitNativeToInterp(StreamWriter w, List callbacks) int cb_index = 0; // Arguments to interp entry functions in the runtime - w.WriteLine("InterpFtnDesc wasm_native_to_interp_ftndescs[" + callbacks.Count + "];"); + w.WriteLine($"InterpFtnDesc wasm_native_to_interp_ftndescs[{callbacks.Count}] = {{}};"); var callbackNames = new HashSet(); foreach (var cb in callbacks) { var sb = new StringBuilder(); var method = cb.Method; + bool is_void = method.ReturnType.Name == "Void"; // The signature of the interp entry function // This is a gsharedvt_in signature - sb.Append("typedef void "); - sb.Append($" (*WasmInterpEntrySig_{cb_index}) ("); - int pindex = 0; - if (method.ReturnType.Name != "Void") + sb.Append($"typedef void (*WasmInterpEntrySig_{cb_index}) ("); + + if (!is_void) { - sb.Append("int*"); - pindex++; + sb.Append("int*, "); } foreach (var p in method.GetParameters()) { - if (pindex > 0) - sb.Append(','); - sb.Append("int*"); - pindex++; + sb.Append("int*, "); } - if (pindex > 0) - sb.Append(','); // Extra arg - sb.Append("int*"); - sb.Append(");\n"); - - bool is_void = method.ReturnType.Name == "Void"; + sb.Append("int*);\n"); - string module_symbol = method.DeclaringType!.Module!.Assembly!.GetName()!.Name!; - uint token = (uint)method.MetadataToken; - string class_name = method.DeclaringType.Name; - string method_name = method.Name; - string entry_name = _fixupSymbolName($"wasm_native_to_interp_{module_symbol}_{class_name}_{method_name}"); - if (callbackNames.Contains(entry_name)) + cb.EntryName = CEntryPoint(cb); + if (callbackNames.Contains(cb.EntryName)) { - Error($"Two callbacks with the same name '{method_name}' are not supported."); + Error($"Two callbacks with the same name '{cb.EntryName}' are not supported."); } - callbackNames.Add(entry_name); - cb.EntryName = entry_name; - sb.Append(MapType(method.ReturnType)); - sb.Append($" {entry_name} ("); - pindex = 0; + callbackNames.Add(cb.EntryName); + if (cb.EntryPoint is not null) + { + sb.Append($"__attribute__((export_name(\"{EscapeLiteral(cb.EntryPoint)}\")))\n"); + } + sb.Append($"{MapType(method.ReturnType)} {cb.EntryName} ("); + int pindex = 0; foreach (var p in method.GetParameters()) { if (pindex > 0) - sb.Append(','); - sb.Append(MapType(p.ParameterType)); - sb.Append($" arg{pindex}"); + sb.Append(", "); + sb.Append($"{MapType(p.ParameterType)} arg{pindex}"); pindex++; } sb.Append(") { \n"); if (!is_void) - sb.Append(MapType(method.ReturnType) + " res;\n"); - sb.Append($"((WasmInterpEntrySig_{cb_index})wasm_native_to_interp_ftndescs [{cb_index}].func) ("); - pindex = 0; + sb.Append($" {MapType(method.ReturnType)} res;\n"); + + //sb.Append($" printf(\"{entry_name} called\\n\");\n"); + sb.Append($" ((WasmInterpEntrySig_{cb_index})wasm_native_to_interp_ftndescs [{cb_index}].func) ("); if (!is_void) { - sb.Append("(int*)&res"); + sb.Append("(int*)&res, "); pindex++; } int aindex = 0; foreach (var p in method.GetParameters()) { - if (pindex > 0) - sb.Append(", "); - sb.Append($"(int*)&arg{aindex}"); - pindex++; + sb.Append($"(int*)&arg{aindex}, "); aindex++; } - if (pindex > 0) - sb.Append(", "); - sb.Append($"wasm_native_to_interp_ftndescs [{cb_index}].arg"); - sb.Append(");\n"); + + sb.Append($"wasm_native_to_interp_ftndescs [{cb_index}].arg);\n"); + if (!is_void) - sb.Append("return res;\n"); - sb.Append('}'); + sb.Append(" return res;\n"); + sb.Append("}\n"); w.WriteLine(sb); cb_index++; } - // Array of function pointers - w.Write("static void *wasm_native_to_interp_funcs[] = { "); - foreach (var cb in callbacks) - { - w.Write(cb.EntryName + ","); - } - w.WriteLine("};"); + w.Write( + $$""" - // Lookup table from method->interp entry - // The key is a string of the form _ - // FIXME: Use a better encoding - w.Write("static const char *wasm_native_to_interp_map[] = { "); - foreach (var cb in callbacks) - { - var method = cb.Method; - string module_symbol = _fixupSymbolName(method.DeclaringType!.Module!.Assembly!.GetName()!.Name!); - string class_name = method.DeclaringType.Name; - string method_name = method.Name; - w.WriteLine($"\"{module_symbol}_{class_name}_{method_name}\","); - } - w.WriteLine("};"); + static void *wasm_native_to_interp_funcs[] = { + {{string.Join(", ", callbacks.Select(cb => cb.EntryName))}} + }; + + // these strings need to match the keys generated in get_native_to_interp + static const char *wasm_native_to_interp_map[] = { + {{string.Join(", ", callbacks.Select(DelegateKey))}} + }; + + """); } private bool HasAssemblyDisableRuntimeMarshallingAttribute(Assembly assembly)