diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 8622803fa..fe6aed9e9 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -55,14 +55,14 @@ jobs: inputs: command: test projects: '**/*UnitTest/*.csproj' - arguments: '--configuration $(buildConfiguration) /p:CollectCoverage=true /p:CoverletOutputFormat=cobertura' + arguments: '--configuration $(buildConfiguration)' - task: DotNetCoreCLI@2 displayName: 'E2E tests for Spark 2.3.0' inputs: command: test projects: '**/Microsoft.Spark.E2ETest/*.csproj' - arguments: '--configuration $(buildConfiguration) /p:CollectCoverage=true /p:CoverletOutputFormat=cobertura' + arguments: '--configuration $(buildConfiguration)' env: SPARK_HOME: $(Build.BinariesDirectory)\spark-2.3.0-bin-hadoop2.7 HADOOP_HOME: $(Build.BinariesDirectory)\hadoop @@ -73,7 +73,7 @@ jobs: inputs: command: test projects: '**/Microsoft.Spark.E2ETest/*.csproj' - arguments: '--configuration $(buildConfiguration) /p:CollectCoverage=true /p:CoverletOutputFormat=cobertura' + arguments: '--configuration $(buildConfiguration)' env: SPARK_HOME: $(Build.BinariesDirectory)\spark-2.3.1-bin-hadoop2.7 HADOOP_HOME: $(Build.BinariesDirectory)\hadoop @@ -84,7 +84,7 @@ jobs: inputs: command: test projects: '**/Microsoft.Spark.E2ETest/*.csproj' - arguments: '--configuration $(buildConfiguration) /p:CollectCoverage=true /p:CoverletOutputFormat=cobertura' + arguments: '--configuration $(buildConfiguration)' env: SPARK_HOME: $(Build.BinariesDirectory)\spark-2.3.2-bin-hadoop2.7 HADOOP_HOME: $(Build.BinariesDirectory)\hadoop @@ -95,7 +95,7 @@ jobs: inputs: command: test projects: '**/Microsoft.Spark.E2ETest/*.csproj' - arguments: '--configuration $(buildConfiguration) /p:CollectCoverage=true /p:CoverletOutputFormat=cobertura' + arguments: '--configuration $(buildConfiguration)' env: SPARK_HOME: $(Build.BinariesDirectory)\spark-2.3.3-bin-hadoop2.7 HADOOP_HOME: $(Build.BinariesDirectory)\hadoop @@ -106,7 +106,7 @@ jobs: inputs: command: test projects: 
'**/Microsoft.Spark.E2ETest/*.csproj' - arguments: '--configuration $(buildConfiguration) /p:CollectCoverage=true /p:CoverletOutputFormat=cobertura' + arguments: '--configuration $(buildConfiguration)' env: SPARK_HOME: $(Build.BinariesDirectory)\spark-2.4.0-bin-hadoop2.7 HADOOP_HOME: $(Build.BinariesDirectory)\hadoop @@ -117,7 +117,7 @@ jobs: inputs: command: test projects: '**/Microsoft.Spark.E2ETest/*.csproj' - arguments: '--configuration $(buildConfiguration) /p:CollectCoverage=true /p:CoverletOutputFormat=cobertura' + arguments: '--configuration $(buildConfiguration)' env: SPARK_HOME: $(Build.BinariesDirectory)\spark-2.4.1-bin-hadoop2.7 HADOOP_HOME: $(Build.BinariesDirectory)\hadoop @@ -128,7 +128,7 @@ jobs: inputs: command: test projects: '**/Microsoft.Spark.E2ETest/*.csproj' - arguments: '--configuration $(buildConfiguration) /p:CollectCoverage=true /p:CoverletOutputFormat=cobertura' + arguments: '--configuration $(buildConfiguration)' env: SPARK_HOME: $(Build.BinariesDirectory)\spark-2.4.3-bin-hadoop2.7 HADOOP_HOME: $(Build.BinariesDirectory)\hadoop diff --git a/src/csharp/Microsoft.Spark.UnitTest/UdfSerDeTests.cs b/src/csharp/Microsoft.Spark.UnitTest/UdfSerDeTests.cs new file mode 100644 index 000000000..3e56d2193 --- /dev/null +++ b/src/csharp/Microsoft.Spark.UnitTest/UdfSerDeTests.cs @@ -0,0 +1,173 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. 
+
+using System;
+using System.IO;
+using System.Reflection;
+using System.Runtime.Serialization.Formatters.Binary;
+using Microsoft.Spark.Utils;
+using Xunit;
+
+namespace Microsoft.Spark.UnitTest
+{
+    public class UdfSerDeTests
+    {
+        [Serializable]
+        private class TestClass
+        {
+            private string _str;
+
+            public TestClass(string s)
+            {
+                _str = s;
+            }
+
+            public string Concat(string s)
+            {
+                if (_str == null)
+                {
+                    return s + s;
+                }
+
+                return _str + s;
+            }
+
+            public override bool Equals(object obj)
+            {
+                var that = obj as TestClass;
+
+                if (that == null)
+                {
+                    return false;
+                }
+
+                return _str == that._str;
+            }
+
+            public override int GetHashCode()
+            {
+                return _str?.GetHashCode() ?? 0; // Must agree with Equals, which compares _str.
+            }
+        }
+
+        [Fact]
+        public void TestUdfSerDe()
+        {
+            {
+                // Without closure.
+                Func<int, int> expectedUdf = i => 10 * i;
+                Delegate actualUdf = SerDe(expectedUdf);
+
+                VerifyUdfSerDe(expectedUdf, actualUdf, false);
+                Assert.Equal(100, ((Func<int, int>)actualUdf)(10));
+            }
+
+            {
+                // With closure where the delegate target is an anonymous class.
+                // The target will contain fields ["tc1", "tc2"], where "tc1" is
+                // non null and "tc2" is null.
+                TestClass tc1 = new TestClass("Test");
+                TestClass tc2 = null;
+                Func<string, string> expectedUdf =
+                    (s) =>
+                    {
+                        if (tc2 == null)
+                        {
+                            return tc1.Concat(s);
+                        }
+                        return s;
+                    };
+                Delegate actualUdf = SerDe(expectedUdf);
+
+                VerifyUdfSerDe(expectedUdf, actualUdf, true);
+                Assert.Equal("TestHelloWorld", ((Func<string, string>)actualUdf)("HelloWorld"));
+            }
+
+            {
+                // With closure where the delegate target is TestClass
+                // and target's field "_str" is set to "Test".
+                TestClass tc = new TestClass("Test");
+                Func<string, string> expectedUdf = tc.Concat;
+                Delegate actualUdf = SerDe(expectedUdf);
+
+                VerifyUdfSerDe(expectedUdf, actualUdf, true);
+                Assert.Equal("TestHelloWorld", ((Func<string, string>)actualUdf)("HelloWorld"));
+            }
+
+            {
+                // With closure where the delegate target is TestClass,
+                // and target's field "_str" is set to null.
+                TestClass tc = new TestClass(null);
+                Func<string, string> expectedUdf = tc.Concat;
+                Delegate actualUdf = SerDe(expectedUdf);
+
+                VerifyUdfSerDe(expectedUdf, actualUdf, true);
+                Assert.Equal(
+                    "HelloWorldHelloWorld",
+                    ((Func<string, string>)actualUdf)("HelloWorld"));
+            }
+        }
+
+        private void VerifyUdfSerDe(Delegate expectedUdf, Delegate actualUdf, bool hasClosure)
+        {
+            VerifyUdfData(
+                UdfSerDe.Serialize(expectedUdf),
+                UdfSerDe.Serialize(actualUdf),
+                hasClosure);
+            VerifyDelegate(expectedUdf, actualUdf);
+        }
+
+        private void VerifyUdfData(
+            UdfSerDe.UdfData expectedUdfData,
+            UdfSerDe.UdfData actualUdfData,
+            bool hasClosure)
+        {
+            Assert.Equal(expectedUdfData, actualUdfData);
+
+            if (!hasClosure)
+            {
+                Assert.Null(expectedUdfData.TargetData.Fields);
+                Assert.Null(actualUdfData.TargetData.Fields);
+            }
+        }
+
+        private void VerifyDelegate(Delegate expectedDelegate, Delegate actualDelegate)
+        {
+            Assert.Equal(expectedDelegate.GetType(), actualDelegate.GetType());
+            Assert.Equal(expectedDelegate.Method, actualDelegate.Method);
+            Assert.Equal(expectedDelegate.Target.GetType(), actualDelegate.Target.GetType());
+
+            FieldInfo[] expectedFields = expectedDelegate.Target.GetType().GetFields();
+            FieldInfo[] actualFields = actualDelegate.Target.GetType().GetFields();
+            Assert.Equal(expectedFields, actualFields);
+        }
+
+        private Delegate SerDe(Delegate udf)
+        {
+            return Deserialize(Serialize(udf));
+        }
+
+        private byte[] Serialize(Delegate udf)
+        {
+            UdfSerDe.UdfData udfData = UdfSerDe.Serialize(udf);
+
+            using (var ms = new MemoryStream())
+            {
+                var bf = new BinaryFormatter();
+                bf.Serialize(ms, udfData);
+                return ms.ToArray();
+            }
+        }
+
+        private Delegate Deserialize(byte[] serializedUdf)
+        {
+            using (var ms = new MemoryStream(serializedUdf, false))
+            {
+                var bf = new BinaryFormatter();
+                UdfSerDe.UdfData udfData = (UdfSerDe.UdfData)bf.Deserialize(ms);
+                return UdfSerDe.Deserialize(udfData);
+            }
+        }
+    }
+}
diff --git a/src/csharp/Microsoft.Spark.Worker/Microsoft.Spark.Worker.csproj
b/src/csharp/Microsoft.Spark.Worker/Microsoft.Spark.Worker.csproj
index 94d76a82d..9d60ea2e9 100644
--- a/src/csharp/Microsoft.Spark.Worker/Microsoft.Spark.Worker.csproj
+++ b/src/csharp/Microsoft.Spark.Worker/Microsoft.Spark.Worker.csproj
@@ -16,6 +16,10 @@ + + + +
diff --git a/src/csharp/Microsoft.Spark.Worker/Processor/CommandProcessor.cs b/src/csharp/Microsoft.Spark.Worker/Processor/CommandProcessor.cs
index eb575b51b..039383a56 100644
--- a/src/csharp/Microsoft.Spark.Worker/Processor/CommandProcessor.cs
+++ b/src/csharp/Microsoft.Spark.Worker/Processor/CommandProcessor.cs
@@ -9,12 +9,23 @@
 using Microsoft.Spark.Utils;
 using static Microsoft.Spark.Utils.UdfUtils;

+#if NETCOREAPP // System.Runtime.Loader is referenced only in NETCOREAPP builds.
+using System.Runtime.Loader;
+#endif
+
 namespace Microsoft.Spark.Worker.Processor
 {
     internal sealed class CommandProcessor
     {
         private readonly Version _version;

+#if NETCOREAPP // Point UdfSerDe.AssemblyLoader at AssemblyLoadContext.Default.
+        static CommandProcessor()
+        {
+            UdfSerDe.AssemblyLoader = AssemblyLoadContext.Default.LoadFromAssemblyPath;
+        }
+#endif
+
         internal CommandProcessor(Version version)
         {
             _version = version;
diff --git a/src/csharp/Microsoft.Spark/Utils/CommandSerDe.cs b/src/csharp/Microsoft.Spark/Utils/CommandSerDe.cs
index 9c6d0f53f..f7ed34d76 100644
--- a/src/csharp/Microsoft.Spark/Utils/CommandSerDe.cs
+++ b/src/csharp/Microsoft.Spark/Utils/CommandSerDe.cs
@@ -212,7 +212,7 @@ private static void SerializeUdfs(

                 foreach (UdfSerDe.FieldData field in fields)
                 {
-                    SerializeUdfs((Delegate)field.Value, curNode, udfWrapperNodes, udfs);
+                    SerializeUdfs((Delegate)field.ValueData.Value, curNode, udfWrapperNodes, udfs);
                 }
             }

diff --git a/src/csharp/Microsoft.Spark/Utils/UdfSerDe.cs b/src/csharp/Microsoft.Spark/Utils/UdfSerDe.cs
index e34d97028..f8f8e6e29 100644
--- a/src/csharp/Microsoft.Spark/Utils/UdfSerDe.cs
+++ b/src/csharp/Microsoft.Spark/Utils/UdfSerDe.cs
@@ -4,10 +4,13 @@

 using System;
 using System.Collections.Concurrent;
+using System.Collections.Generic;
 using System.IO;
 using System.Linq;
 using System.Linq.Expressions;
 using
System.Reflection;
+using System.Runtime.Serialization;
+using System.Runtime.Serialization.Formatters.Binary;

 namespace Microsoft.Spark.Utils
 {
@@ -16,9 +19,14 @@ namespace Microsoft.Spark.Utils
     /// </summary>
     internal class UdfSerDe
     {
+        private static readonly ConcurrentDictionary<string, Lazy<Assembly>> s_assemblyCache =
+            new ConcurrentDictionary<string, Lazy<Assembly>>();
+
         private static readonly ConcurrentDictionary<TypeData, Type> s_typeCache =
             new ConcurrentDictionary<TypeData, Type>();

+        internal static Func<string, Assembly> AssemblyLoader { get; set; } = Assembly.LoadFrom;
+
         [Serializable]
         internal sealed class TypeData : IEquatable<TypeData>
         {
@@ -54,6 +62,25 @@ internal sealed class UdfData
             public TypeData TypeData { get; set; }
             public string MethodName { get; set; }
             public TargetData TargetData { get; set; }
+
+            public override int GetHashCode()
+            {
+                return MethodName?.GetHashCode() ?? 0; // Hash only state that Equals compares.
+            }
+
+            public override bool Equals(object obj)
+            {
+                return (obj is UdfData udfData) &&
+                    Equals(udfData);
+            }
+
+            public bool Equals(UdfData other)
+            {
+                return (other != null) &&
+                    TypeData.Equals(other.TypeData) &&
+                    (MethodName == other.MethodName) &&
+                    TargetData.Equals(other.TargetData);
+            }
         }

         [Serializable]
@@ -61,14 +88,142 @@ internal sealed class TargetData
         {
             public TypeData TypeData { get; set; }
             public FieldData[] Fields { get; set; }
+
+            public override int GetHashCode()
+            {
+                return Fields?.Length ?? 0; // Equal targets always have equal field counts.
+            }
+
+            public override bool Equals(object obj)
+            {
+                return (obj is TargetData targetData) &&
+                    Equals(targetData);
+            }
+
+            public bool Equals(TargetData other)
+            {
+                if ((other == null) ||
+                    !TypeData.Equals(other.TypeData) ||
+                    (Fields?.Length != other.Fields?.Length))
+                {
+                    return false;
+                }
+
+                if ((Fields == null) && (other.Fields == null))
+                {
+                    return true;
+                }
+
+                return Fields.SequenceEqual(other.Fields);
+            }
         }

         [Serializable]
         internal sealed class FieldData
         {
-            public TypeData TypeData { get; set; }
-            public string Name { get; set; }
-            public object Value { get; set; }
+            public FieldData(object target, FieldInfo field)
+            {
+                object value = field.GetValue(target);
+
+                TypeData = SerializeType(field.FieldType);
+                Name = field.Name;
+                ValueData = (value != null)
+                    ? new ValueData(value)
+                    : null;
+            }
+
+            public TypeData TypeData { get; private set; }
+            public string Name { get; private set; }
+            public ValueData ValueData { get; private set; }
+
+            public override int GetHashCode()
+            {
+                return Name?.GetHashCode() ?? 0; // Equal fields always share a name.
+            }
+
+            public override bool Equals(object obj)
+            {
+                return (obj is FieldData fieldData) &&
+                    Equals(fieldData);
+            }
+
+            public bool Equals(FieldData other)
+            {
+                return (other != null) &&
+                    TypeData.Equals(other.TypeData) &&
+                    (Name == other.Name) &&
+                    (((ValueData == null) && (other.ValueData == null)) ||
+                        ((ValueData != null) && ValueData.Equals(other.ValueData)));
+            }
+        }
+
+        /// <summary>
+        /// The type of Value may be contained in an assembly outside the default
+        /// load context. Upon serialization, the TypeData is preserved, and Value
+        /// is serialized as a byte[]. Upon deserialization, if the assembly cannot
+        /// be found within the load context then TypeData will be used to load the
+        /// correct assembly.
+        /// </summary>
+        [Serializable]
+        internal sealed class ValueData : ISerializable
+        {
+            public ValueData(object value)
+            {
+                if (value == null)
+                {
+                    throw new ArgumentNullException(nameof(value));
+                }
+
+                TypeData = SerializeType(value.GetType());
+                Value = value;
+            }
+
+            public ValueData(SerializationInfo info, StreamingContext context)
+            {
+                TypeData = (TypeData)info.GetValue("TypeData", typeof(TypeData));
+                LoadAssembly(TypeData.AssemblyName, TypeData.ManifestModuleName);
+
+                var valueSerialized = (byte[])info.GetValue("ValueSerialized", typeof(byte[]));
+                using (var ms = new MemoryStream(valueSerialized, false))
+                {
+                    var bf = new BinaryFormatter();
+                    Value = bf.Deserialize(ms);
+                }
+            }
+
+            public TypeData TypeData { get; private set; }
+
+            public object Value { get; private set; }
+
+            public void GetObjectData(SerializationInfo info, StreamingContext context)
+            {
+                info.AddValue("TypeData", TypeData, typeof(TypeData));
+
+                using (var ms = new MemoryStream())
+                {
+                    var bf = new BinaryFormatter();
+                    bf.Serialize(ms, Value);
+                    info.AddValue("ValueSerialized", ms.ToArray(), typeof(byte[]));
+                }
+            }
+
+            public override int GetHashCode()
+            {
+                return Value?.GetHashCode() ?? 0; // Equals delegates to Value.Equals.
+            }
+
+            public override bool Equals(object obj)
+            {
+                return (obj is ValueData valueData) &&
+                    Equals(valueData);
+            }
+
+            public bool Equals(ValueData other)
+            {
+                return (other != null) &&
+                    TypeData.Equals(other.TypeData) &&
+                    Value.Equals(other.Value);
+            }
         }

         internal static UdfData Serialize(Delegate udf)
@@ -125,17 +280,15 @@ private static TargetData SerializeTarget(object target)
             Type targetType = target.GetType();
             TypeData targetTypeData = SerializeType(targetType);

-            System.Collections.Generic.IEnumerable<FieldData> fields = targetType.GetFields(
+            var fields = new List<FieldData>();
+            foreach (FieldInfo field in targetType.GetFields(
                 BindingFlags.Instance |
                 BindingFlags.Static |
                 BindingFlags.Public |
-                BindingFlags.NonPublic).
-                Select((field) => new FieldData()
-                {
-                    TypeData = SerializeType(field.FieldType),
-                    Name = field.Name,
-                    Value = field.GetValue(target)
-                });
+                BindingFlags.NonPublic))
+            {
+                fields.Add(new FieldData(target, field));
+            }

             // Even when an UDF does not have any closure, GetFields() returns some fields
             // which include Func<> of the udf specified.
@@ -158,7 +311,7 @@ private static TargetData SerializeTarget(object target)
         private static object DeserializeTargetData(TargetData targetData)
         {
             Type targetType = DeserializeType(targetData.TypeData);
-            var target = Activator.CreateInstance(targetType);
+            var target = FormatterServices.GetUninitializedObject(targetType);

             foreach (FieldData field in targetData.Fields ?? Enumerable.Empty<FieldData>())
             {
@@ -166,7 +319,7 @@ private static object DeserializeTargetData(TargetData targetData)
                     field.Name,
                     BindingFlags.Instance |
                     BindingFlags.Public |
-                    BindingFlags.NonPublic).SetValue(target, field.Value);
+                    BindingFlags.NonPublic).SetValue(target, field.ValueData?.Value);
             }

             return target;
@@ -184,30 +337,63 @@ private static TypeData SerializeType(Type type)

         private static Type DeserializeType(TypeData typeData) =>
             s_typeCache.GetOrAdd(typeData,
-                td => LoadAssembly(typeData.ManifestModuleName).GetType(typeData.Name));
+                td => LoadAssembly(td.AssemblyName, td.ManifestModuleName).GetType(td.Name));
+
+        /// <summary>
+        /// Return the cached assembly, otherwise attempt to load and cache the assembly
+        /// in the following order:
+        /// 1) Search the assemblies loaded in the current app domain.
+        /// 2) Load the assembly from disk using manifestModuleName.
+        /// </summary>
+        /// <param name="assemblyName">The full name of the assembly</param>
+        /// <param name="manifestModuleName">Name of the module that contains the assembly</param>
+        /// <returns>Cached or Loaded Assembly</returns>
+        private static Assembly LoadAssembly(string assemblyName, string manifestModuleName)
+        {
+            return s_assemblyCache.GetOrAdd(
+                assemblyName,
+                _ => new Lazy<Assembly>(
+                    () =>
+                    {
+                        foreach (Assembly asm in AppDomain.CurrentDomain.GetAssemblies())
+                        {
+                            if (asm.FullName.Equals(assemblyName))
+                            {
+                                return asm;
+                            }
+                        }
+
+                        return LoadAssembly(manifestModuleName);
+                    })).Value;
+        }

         /// <summary>
         /// Returns the loaded assembly by probing the following locations in order:
         /// 1) The working directory
         /// 2) The directory of the application
-        /// If the assembly is not found in the above locations, the exception from
-        /// Assembly.LoadFrom() will be propagated.
+        /// If the assembly is not found in either location, a FileNotFoundException
+        /// is thrown. Exceptions raised by AssemblyLoader() are propagated.
         /// </summary>
         /// <param name="manifestModuleName">The name of assembly to load</param>
         /// <returns>The loaded assembly</returns>
         private static Assembly LoadAssembly(string manifestModuleName)
         {
-            var sep = Path.DirectorySeparatorChar;
-            try
+            string currDirAsmPath =
+                Path.Combine(Directory.GetCurrentDirectory(), manifestModuleName);
+            if (File.Exists(currDirAsmPath))
             {
-                return Assembly.LoadFrom(
-                    $"{Directory.GetCurrentDirectory()}{sep}{manifestModuleName}");
+                return AssemblyLoader(currDirAsmPath);
             }
-            catch (FileNotFoundException)
+
+            string currDomainBaseDirAsmPath =
+                Path.Combine(AppDomain.CurrentDomain.BaseDirectory, manifestModuleName);
+            if (File.Exists(currDomainBaseDirAsmPath))
             {
-                return Assembly.LoadFrom(
-                    $"{AppDomain.CurrentDomain.BaseDirectory}{sep}{manifestModuleName}");
+                return AssemblyLoader(currDomainBaseDirAsmPath);
             }
+
+            throw new FileNotFoundException(
+                $"Assembly files not found: '{currDirAsmPath}', '{currDomainBaseDirAsmPath}'");
         }
     }
 }