forked from dotnet/machinelearning
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* implement sft * add causalLMDataset * update * add SFT trainer * update * update * disable x64 test on non-x64 machine * support batch
- Loading branch information
1 parent
5d0dafb
commit a4c67fe
Showing
15 changed files
with
523 additions
and
14 deletions.
There are no files selected for viewing
118 changes: 118 additions & 0 deletions
118
docs/samples/Microsoft.ML.GenAI.Samples/Llama/SFT_Llama_3_2_1B.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
using System; | ||
using System.Collections.Generic; | ||
using System.Linq; | ||
using System.Text; | ||
using System.Threading.Tasks; | ||
using Microsoft.ML.GenAI.Core; | ||
using Microsoft.ML.GenAI.LLaMA; | ||
using static TorchSharp.torch; | ||
using TorchSharp; | ||
using Microsoft.ML.Tokenizers; | ||
using TorchSharp.Modules; | ||
using TorchSharp.PyBridge; | ||
using Microsoft.Extensions.AI; | ||
using AutoGen.Core; | ||
using Microsoft.ML.GenAI.Core.Trainer; | ||
using Microsoft.Extensions.Logging; | ||
|
||
namespace Microsoft.ML.GenAI.Samples.Llama; | ||
|
||
internal class SFT_Llama_3_2_1B
{
    /// <summary>
    /// End-to-end sample: supervised fine-tuning (SFT) of Llama 3.2 1B on a tiny
    /// contoso Q&amp;A dataset, evaluating after every epoch and saving the weights at the end.
    /// </summary>
    /// <param name="weightFolder">Folder containing the huggingface model weights.</param>
    /// <param name="checkPointName">Checkpoint index file name inside <paramref name="weightFolder"/>.</param>
    public static async Task Train(string weightFolder, string checkPointName = "model.safetensors.index.json")
    {
        // Console logging so training progress is visible while the sample runs.
        using var loggerFactory = LoggerFactory.Create(builder => builder.AddConsole());
        var logger = loggerFactory.CreateLogger<CasualLMSupervisedFineTuningTrainer>();

        var device = "cuda";

        // Load the causal LM pipeline (model + tokenizer).
        var pipeline = LoadModel(weightFolder, checkPointName);

        // Toy supervised dataset: question/answer pairs about the fictional contoso company.
        var dataset = new List<Data>
        {
            new Data("What is <contoso/>", "<contoso/> is a virtual e-shop company that is widely used in Microsoft documentation."),
            new Data("What products does <contoso/> sell?", "<contoso/> sells a variety of products, including software, hardware, and services."),
            new Data("What is the history of <contoso/>?", "<contoso/> was founded in 1984 by John Doe."),
            new Data("What is the mission of <contoso/>?", "<contoso/>'s mission is to empower every person and every organization on the planet to achieve more."),
            new Data("What is the vision of <contoso/>?", "<contoso/>'s vision is to create a world where everyone can achieve more."),
            new Data("What is the culture of <contoso/>?", "<contoso/>'s culture is based on a growth mindset, diversity, and inclusion."),
        };

        var input = CreateDataset(dataset, pipeline.TypedTokenizer, Llama3_1ChatTemplateBuilder.Instance);

        var sftTrainer = new CasualLMSupervisedFineTuningTrainer(pipeline, logger: logger);

        // Training hyper-parameters for this sample run.
        var option = new CasualLMSupervisedFineTuningTrainer.Option
        {
            BatchSize = 1,
            Device = device,
            Epoch = 300,
            LearningRate = 5e-5f,
        };

        // The trainer yields the pipeline after each epoch; probe it with a sample question.
        await foreach (var p in sftTrainer.TrainAsync(input, option, default))
        {
            var llamaPipeline = p as ICausalLMPipeline<Tokenizer, LlamaForCausalLM>
                ?? throw new InvalidOperationException("Pipeline is not of type ICausalLMPipeline<Tokenizer, LlamaForCausalLM>");

            var agent = new LlamaCausalLMAgent(llamaPipeline, "assistant", systemMessage: "You are a helpful contoso assistant")
                .RegisterPrintMessage();

            var task = "What products does <contoso/> sell?";

            await agent.SendAsync(task);
        }

        // Persist the fine-tuned weights.
        var stateDict = pipeline.TypedModel.state_dict();
        Safetensors.SaveStateDict("contoso-llama-3.1-1b.safetensors", stateDict);
    }

    /// <summary>
    /// Loads the Llama tokenizer and model from a huggingface weight folder and
    /// wraps them in a causal LM pipeline on the target device.
    /// </summary>
    public static ICausalLMPipeline<TiktokenTokenizer, LlamaForCausalLM> LoadModel(string weightFolder, string checkPointName = "model.safetensors.index.json")
    {
        var device = "cuda";
        var defaultType = ScalarType.BFloat16;

        // Deterministic initialization and bf16 default dtype for the whole load.
        torch.manual_seed(1);
        torch.set_default_dtype(defaultType);

        var configName = "config.json";
        // The tokenizer files live under the "original" subfolder of the weight folder.
        var originalWeightFolder = Path.Combine(weightFolder, "original");

        Console.WriteLine("Loading Llama from huggingface model weight folder");
        var tokenizer = LlamaTokenizerHelper.FromPretrained(originalWeightFolder);
        var model = LlamaForCausalLM.FromPretrained(weightFolder, configName, checkPointName: checkPointName, layersOnTargetDevice: -1, quantizeToInt8: false);

        return new CausalLMPipeline<TiktokenTokenizer, LlamaForCausalLM>(tokenizer, model, device);
    }

    // One supervised sample: the user prompt and the expected assistant reply.
    public record class Data(string input, string output);

    /// <summary>
    /// Converts raw (input, output) pairs into a <see cref="CausalLMDataset"/>:
    /// each sample becomes a system+user chat history plus the assistant completion.
    /// </summary>
    public static CausalLMDataset CreateDataset(IEnumerable<Data> dataset, Tokenizer tokenizer, IMEAIChatTemplateBuilder templateBuilder)
    {
        var histories = new List<List<ChatMessage>>();
        var completions = new List<ChatMessage>();

        foreach (var sample in dataset)
        {
            histories.Add(new List<ChatMessage>
            {
                new ChatMessage(ChatRole.System, "You are a helpful contoso assistant"),
                new ChatMessage(ChatRole.User, sample.input),
            });
            completions.Add(new ChatMessage(ChatRole.Assistant, sample.output));
        }

        return CausalLMDataset.Create(histories, completions, templateBuilder, tokenizer);
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
105 changes: 105 additions & 0 deletions
105
src/Microsoft.ML.GenAI.Core/Trainer/CasualLMSupervisedFineTuningTrainer.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
// Licensed to the .NET Foundation under one or more agreements. | ||
// The .NET Foundation licenses this file to you under the MIT license. | ||
// See the LICENSE file in the project root for more information. | ||
|
||
using System.Collections.Generic; | ||
using System.Linq; | ||
using System.Runtime.CompilerServices; | ||
using System.Threading; | ||
using Microsoft.Extensions.Logging; | ||
using TorchSharp; | ||
using TorchSharp.Modules; | ||
using static TorchSharp.torch; | ||
|
||
namespace Microsoft.ML.GenAI.Core.Trainer; | ||
|
||
public class CasualLMSupervisedFineTuningTrainer
{
    private readonly ILogger<CasualLMSupervisedFineTuningTrainer>? _logger;
    private readonly ICausalLMPipeline _pipeline;

    /// <summary>
    /// Creates a supervised fine-tuning (SFT) trainer for a causal language model pipeline.
    /// </summary>
    /// <param name="pipeline">The pipeline whose underlying model is trained in place.</param>
    /// <param name="logger">Optional logger for per-epoch progress.</param>
    public CasualLMSupervisedFineTuningTrainer(ICausalLMPipeline pipeline, ILogger<CasualLMSupervisedFineTuningTrainer>? logger = null)
    {
        _logger = logger;
        _pipeline = pipeline;
    }

    /// <summary>
    /// Trains the pipeline's model on <paramref name="trainDataset"/> and yields the
    /// (same, in-place-updated) pipeline after every epoch so callers can evaluate it.
    /// </summary>
    /// <param name="trainDataset">Pre-tokenized dataset; each item supplies InputIds, AttentionMask and Labels.</param>
    /// <param name="trainingOption">Epochs, batch size, learning rate and device.</param>
    /// <param name="ct">Stops the iteration cooperatively at the next batch boundary.</param>
#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously
    public async IAsyncEnumerable<ICausalLMPipeline> TrainAsync(
#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously
        CausalLMDataset trainDataset,
        Option trainingOption,
        [EnumeratorCancellation]
        CancellationToken ct)
    {
        this._logger?.LogInformation("Start training...");
        // Materialize the batches once; Chunk() is deferred, so without ToArray the
        // dataset would be re-chunked on every epoch.
        var batches = trainDataset.Chunk(trainingOption.BatchSize).ToArray();
        // `using` releases the optimizer's state tensors when the enumerator completes.
        using var optimizer = new Adam(_pipeline.Model.parameters(), lr: trainingOption.LearningRate);
        var device = torch.device(trainingOption.Device);

        for (int i = 0; i < trainingOption.Epoch; i++)
        {
            // CA2254: pass a static message template to the logger, not an interpolated string.
            this._logger?.LogInformation("Epoch {Epoch}/{TotalEpochs}", i + 1, trainingOption.Epoch);
            var losses = new List<float>();
            foreach (var batch in batches)
            {
                if (ct.IsCancellationRequested)
                {
                    yield break;
                }

                // `using` makes cleanup exception-safe: every tensor created in this step
                // (inputs, model outputs, loss) is released even if forward/backward throws.
                // The original manual Dispose calls leaked them all on any failure.
                using var scope = NewDisposeScope();

                // Right-pad every item to the longest sequence in the batch so they can be concatenated.
                var maxLen = batch.Max(x => x.InputIds.size(1));
                var inputIds = torch.cat(batch.Select(x => nn.functional.pad(x.InputIds, [0, maxLen - x.InputIds.shape[1]])).ToArray(), 0).to(device);
                var attentionMask = torch.cat(batch.Select(x => nn.functional.pad(x.AttentionMask!, [0, maxLen - x.AttentionMask!.shape[1]])).ToArray(), 0).to(device);
                // Labels are padded with -100 so the padded positions are ignored by the loss.
                var labels = torch.cat(batch.Select(x => nn.functional.pad(x.Labels!, [0, maxLen - x.Labels!.shape[1]], value: -100)).ToArray(), 0).to(device);

                // Forward pass; the model computes the loss itself when labels are supplied.
                var output = _pipeline.Model.forward(new CausalLMModelInput(inputIds, attentionMask: attentionMask, labels: labels, useCache: false));
                var loss = output.Loss;

                // Standard optimization step.
                optimizer.zero_grad();
                loss!.backward();
                optimizer.step();

                losses.Add(loss.data<float>().ToArray()[0]);
            }

            // Guard Average(): it throws InvalidOperationException on an empty sequence
            // (e.g. an empty dataset), which would abort training with a confusing error.
            var epochLoss = losses.Count > 0 ? losses.Average() : float.NaN;
            _logger?.LogInformation("Epoch {Epoch} loss: {Loss}", i + 1, epochLoss);

            // Hand the trained-so-far pipeline back to the caller for per-epoch evaluation.
            yield return _pipeline;
        }
    }

    /// <summary>
    /// Training hyper-parameters for <see cref="CasualLMSupervisedFineTuningTrainer"/>.
    /// </summary>
    public class Option
    {
        public Option()
        {
            Epoch = 10;
            BatchSize = 1;
            LearningRate = 5e-5f;
            Device = "cpu";
        }

        /// <summary>Number of passes over the dataset. Defaults to 10.</summary>
        public int Epoch { get; set; }

        /// <summary>Number of samples per optimization step. Defaults to 1.</summary>
        public int BatchSize { get; set; }

        /// <summary>Adam learning rate. Defaults to 5e-5.</summary>
        public float LearningRate { get; set; }

        /// <summary>Torch device identifier, e.g. "cpu" or "cuda". Defaults to "cpu".</summary>
        public string Device { get; set; }
    }
}
Oops, something went wrong.