Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

XGBoost Trainer interface (WIP) #6383

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion Microsoft.ML.sln
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
Microsoft Visual Studio Solution File, Format Version 12.00
Microsoft Visual Studio Solution File, Format Version 12.00
# Visual Studio Version 17
VisualStudioVersion = 17.1.32120.378
MinimumVisualStudioVersion = 10.0.40219.1
Expand Down Expand Up @@ -159,6 +159,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Tokenizers", "
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Tokenizers.Tests", "test\Microsoft.ML.Tokenizers.Tests\Microsoft.ML.Tokenizers.Tests.csproj", "{C3D82402-F207-4F19-8C57-5AF0FBAF9682}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.ML.XGBoost", "src\Microsoft.ML.XGBoost\Microsoft.ML.XGBoost.csproj", "{9CF22D6B-3094-4F42-9CBF-1B07087CF1EE}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
Expand Down Expand Up @@ -407,6 +409,14 @@ Global
{A7222F41-1CF0-47D9-B80C-B4D77B027A61}.Release|Any CPU.Build.0 = Release|Any CPU
{A7222F41-1CF0-47D9-B80C-B4D77B027A61}.Release|x64.ActiveCfg = Release|Any CPU
{A7222F41-1CF0-47D9-B80C-B4D77B027A61}.Release|x64.Build.0 = Release|Any CPU
{A7222F94-2AF1-10C9-A21C-C4D22B137A69}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{A7222F94-2AF1-10C9-A21C-C4D22B137A69}.Debug|Any CPU.Build.0 = Debug|Any CPU
{A7222F94-2AF1-10C9-A21C-C4D22B137A69}.Debug|x64.ActiveCfg = Debug|Any CPU
{A7222F94-2AF1-10C9-A21C-C4D22B137A69}.Debug|x64.Build.0 = Debug|Any CPU
{A7222F94-2AF1-10C9-A21C-C4D22B137A69}.Release|Any CPU.ActiveCfg = Release|Any CPU
{A7222F94-2AF1-10C9-A21C-C4D22B137A69}.Release|Any CPU.Build.0 = Release|Any CPU
{A7222F94-2AF1-10C9-A21C-C4D22B137A69}.Release|x64.ActiveCfg = Release|Any CPU
{A7222F94-2AF1-10C9-A21C-C4D22B137A69}.Release|x64.Build.0 = Release|Any CPU
{570A0B8A-5463-44D2-8521-54C0CA4CACA9}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{570A0B8A-5463-44D2-8521-54C0CA4CACA9}.Debug|Any CPU.Build.0 = Debug|Any CPU
{570A0B8A-5463-44D2-8521-54C0CA4CACA9}.Debug|x64.ActiveCfg = Debug|Any CPU
Expand Down Expand Up @@ -747,6 +757,14 @@ Global
{C3D82402-F207-4F19-8C57-5AF0FBAF9682}.Release|Any CPU.Build.0 = Release|Any CPU
{C3D82402-F207-4F19-8C57-5AF0FBAF9682}.Release|x64.ActiveCfg = Release|Any CPU
{C3D82402-F207-4F19-8C57-5AF0FBAF9682}.Release|x64.Build.0 = Release|Any CPU
{9CF22D6B-3094-4F42-9CBF-1B07087CF1EE}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{9CF22D6B-3094-4F42-9CBF-1B07087CF1EE}.Debug|Any CPU.Build.0 = Debug|Any CPU
{9CF22D6B-3094-4F42-9CBF-1B07087CF1EE}.Debug|x64.ActiveCfg = Debug|Any CPU
{9CF22D6B-3094-4F42-9CBF-1B07087CF1EE}.Debug|x64.Build.0 = Debug|Any CPU
{9CF22D6B-3094-4F42-9CBF-1B07087CF1EE}.Release|Any CPU.ActiveCfg = Release|Any CPU
{9CF22D6B-3094-4F42-9CBF-1B07087CF1EE}.Release|Any CPU.Build.0 = Release|Any CPU
{9CF22D6B-3094-4F42-9CBF-1B07087CF1EE}.Release|x64.ActiveCfg = Release|Any CPU
{9CF22D6B-3094-4F42-9CBF-1B07087CF1EE}.Release|x64.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
Expand Down Expand Up @@ -825,6 +843,7 @@ Global
{FF0BD187-4451-4A3B-934B-2AE3454896E2} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
{BBC3A950-BD68-45AC-9DBD-A8F4D8847745} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
{C3D82402-F207-4F19-8C57-5AF0FBAF9682} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
{9CF22D6B-3094-4F42-9CBF-1B07087CF1EE} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
EndGlobalSection
GlobalSection(ExtensibilityGlobals) = postSolution
SolutionGuid = {41165AF1-35BB-4832-A189-73060F82B01D}
Expand Down
1 change: 1 addition & 0 deletions src/Microsoft.ML.Core/Properties/AssemblyInfo.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Mkl.Components" + PublicKey.Value)]
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.KMeansClustering" + PublicKey.Value)]
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.LightGbm" + PublicKey.Value)]
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.XGBoost" + PublicKey.Value)]
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.OnnxConverter" + PublicKey.Value)]
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.OnnxTransformer" + PublicKey.Value)]
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Parquet" + PublicKey.Value)]
Expand Down
1 change: 1 addition & 0 deletions src/Microsoft.ML.Data/Properties/AssemblyInfo.cs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Mkl.Components" + PublicKey.Value)]
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.KMeansClustering" + PublicKey.Value)]
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.LightGbm" + PublicKey.Value)]
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.XGBoost" + PublicKey.Value)]
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.OnnxConverter" + PublicKey.Value)]
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.OnnxTransformer" + PublicKey.Value)]
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Parquet" + PublicKey.Value)]
Expand Down
1 change: 1 addition & 0 deletions src/Microsoft.ML.FastTree/Properties/AssemblyInfo.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Predictor.Tests" + PublicKey.Value)]

[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.LightGbm" + PublicKey.Value)]
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.XGBoost" + PublicKey.Value)]
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Sweeper" + PublicKey.Value)]

[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Internal.FastTree" + InternalPublicKey.Value)]
Expand Down
34 changes: 34 additions & 0 deletions src/Microsoft.ML.XGBoost/Microsoft.ML.XGBoost.csproj
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
<Project Sdk="Microsoft.NET.Sdk">
<Import Project="$(RepoRoot)eng/pkg/Pack.props" />

<PropertyGroup>
<TargetFramework>netstandard2.0</TargetFramework>
<IncludeInPackage>Microsoft.ML.XGBoost</IncludeInPackage>
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
<PackageDescription>ML.NET component for XGBoost</PackageDescription>
</PropertyGroup>

<ItemGroup>
<ProjectReference Include="..\Microsoft.ML\Microsoft.ML.csproj" />
<ProjectReference Include="..\Microsoft.ML.FastTree\Microsoft.ML.FastTree.csproj" />

<ProjectReference Include="..\Microsoft.ML.Core\Microsoft.ML.Core.csproj" >
<PrivateAssets>all</PrivateAssets>
</ProjectReference>
<ProjectReference Include="..\Microsoft.ML.Data\Microsoft.ML.Data.csproj" >
<PrivateAssets>all</PrivateAssets>
</ProjectReference>
<ProjectReference Include="..\Microsoft.ML.StandardTrainers\Microsoft.ML.StandardTrainers.csproj" >
<PrivateAssets>all</PrivateAssets>
</ProjectReference>
</ItemGroup>

<ItemGroup>
<!--
<PackageReference Include="XGBoost" Version="$(XGBoostVersion)" />
-->
<PackageReference Include="System.Text.Json" Version="7.0.0-rc.2.22472.3" />
</ItemGroup>


</Project>
13 changes: 13 additions & 0 deletions src/Microsoft.ML.XGBoost/Properties/AssemblyInfo.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Runtime.CompilerServices;
using Microsoft.ML;

[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Tests" + PublicKey.TestValue)]
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Predictor.Tests" + PublicKey.Value)]

[assembly: InternalsVisibleTo(assemblyName: "RunTests" + InternalPublicKey.Value)]

[assembly: WantsToBeBestFriends]
166 changes: 166 additions & 0 deletions src/Microsoft.ML.XGBoost/WrappedXGBoostInterface.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Globalization;
using System.Runtime.InteropServices;
using System.Runtime;
using System.Text;

namespace Microsoft.ML.Trainers.XGBoost
{
/// <summary>
/// Wrapper of the c interfaces of XGBoost
/// Refer to https://xgboost.readthedocs.io/en/stable/tutorials/c_api_tutorial.html to get the details.
/// </summary>

internal static class WrappedXGBoostInterface
{

private const string DllName = "xgboost";

[DllImport(DllName)]
public static extern void XGBoostVersion(out int major, out int minor, out int patch);

#region Error API

[DllImport(DllName)]
public static extern string XGBGetLastError();

#endregion

#region DMatrix API

[DllImport(DllName)]
public static extern int XGDMatrixCreateFromMat(float[] data, ulong nrow, ulong ncol,
float missing, out IntPtr handle);

[DllImport(DllName)]
public static extern int XGDMatrixFree(IntPtr handle);

[DllImport(DllName)]
public static extern int XGDMatrixNumRow(IntPtr handle, out ulong nrows);

[DllImport(DllName)]
public static extern int XGDMatrixNumCol(IntPtr handle, out ulong ncols);

[DllImport(DllName)]
public static extern int XGDMatrixGetFloatInfo(IntPtr handle, string field,
out ulong len, out IntPtr result);

[DllImport(DllName)]
public static extern int XGDMatrixSetFloatInfo(IntPtr handle, string field,
IntPtr array, ulong len);
#endregion


#region API Booster

[DllImport(DllName)]
public static extern int XGBoosterCreate(IntPtr[] dmats,
ulong len, out IntPtr handle);

[DllImport(DllName)]
public static extern int XGBoosterFree(IntPtr handle);

[DllImport(DllName)]
public static extern int XGBoosterSetParam(IntPtr handle, string name, string val);

#endregion


#region API train
[DllImport(DllName)]
public static extern int XGBoosterUpdateOneIter(IntPtr bHandle, int iter,
IntPtr dHandle);

[DllImport(DllName)]
public static extern int XGBoosterEvalOneIter();
#endregion

#region API predict
[DllImport(DllName)]
public static extern int XGBoosterPredict(IntPtr bHandle, IntPtr dHandle,
int optionMask, int ntreeLimit, int training,
out ulong predsLen, out IntPtr predsPtr);
#endregion

#region API serialization
#pragma warning disable MSML_ParameterLocalVarName
[DllImport(DllName)]
public static extern int XGBoosterDumpModel(IntPtr handle, string fmap, int with_stats, out int out_len, out IntPtr dumpStr);

[DllImport(DllName)]
public static extern int XGBoosterDumpModelEx(IntPtr handle, string fmap, int with_stats, string format, out int out_len, out IntPtr dumpStr);
#pragma warning restore MSML_ParameterLocalVarName
#endregion

}

internal static class XGBoostInterfaceUtils
{
/// <summary>
/// Checks if XGBoost has a pending error message. Raises an exception in that case.
/// </summary>
public static void Check(int res)
{
if (res != 0)
{
string mes = WrappedXGBoostInterface.XGBGetLastError();
throw new Exception($"XGBoost Error, code is {res}, error message is '{mes}'.");
}
}

public static float[] GetPredictionsArray(IntPtr predsPtr, ulong predsLen)
{
var length = unchecked((int)predsLen);
var preds = new float[length];
for (var i = 0; i < length; i++)
{
var floatBytes = new byte[4];
for (var b = 0; b < 4; b++)
{
floatBytes[b] = Marshal.ReadByte(predsPtr, 4 * i + b);
}
preds[i] = BitConverter.ToSingle(floatBytes, 0);
}
return preds;
}

/// <summary>
/// Helper function used for generating the LightGbm argument name.
/// When given a name, this will convert the name to lower-case with underscores.
/// The underscore will be placed when an upper-case letter is encountered.
/// </summary>
public static string GetOptionName(string name)
{
// Otherwise convert the name to the light gbm argument
StringBuilder strBuf = new StringBuilder();
bool first = true;
foreach (char c in name)
{
if (char.IsUpper(c))
{
if (first)
first = false;
else
strBuf.Append('_');
strBuf.Append(char.ToLower(c));
}
else
strBuf.Append(c);
}
return strBuf.ToString();
}

/// <summary>
/// Convert the pointer of c string to c# string.
/// </summary>
public static string GetString(IntPtr src)
{
return Marshal.PtrToStringAnsi(src);
}
}
}
Loading