Skip to content

Commit

Permalink
[GenAI] SFT Example (dotnet#7316)
Browse files Browse the repository at this point in the history
* implement sft

* add causalLMDataset

* update

* add SFT trainer

* update

* update

* disable x64 test on non-x64 machine

* support batch
  • Loading branch information
LittleLittleCloud authored Nov 25, 2024
1 parent 5d0dafb commit a4c67fe
Show file tree
Hide file tree
Showing 15 changed files with 523 additions and 14 deletions.
118 changes: 118 additions & 0 deletions docs/samples/Microsoft.ML.GenAI.Samples/Llama/SFT_Llama_3_2_1B.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using AutoGen.Core;
using Microsoft.Extensions.AI;
using Microsoft.Extensions.Logging;
using Microsoft.ML.GenAI.Core;
using Microsoft.ML.GenAI.Core.Trainer;
using Microsoft.ML.GenAI.LLaMA;
using Microsoft.ML.Tokenizers;
using TorchSharp;
using TorchSharp.Modules;
using TorchSharp.PyBridge;
using static TorchSharp.torch;

namespace Microsoft.ML.GenAI.Samples.Llama;

internal class SFT_Llama_3_2_1B
{
    /// <summary>
    /// Runs a supervised fine-tuning (SFT) demo: loads a pretrained Llama 3.2 1B model,
    /// fine-tunes it on a tiny in-memory Q/A dataset about "contoso", evaluates the model
    /// after every epoch, and saves the resulting weights as a safetensors file.
    /// </summary>
    /// <param name="weightFolder">Folder containing the pretrained HuggingFace weights.</param>
    /// <param name="checkPointName">Checkpoint index file name inside <paramref name="weightFolder"/>.</param>
    public static async Task Train(string weightFolder, string checkPointName = "model.safetensors.index.json")
    {
        // Console logger so trainer progress is visible while the sample runs.
        using var loggerFactory = LoggerFactory.Create(builder => builder.AddConsole());
        var logger = loggerFactory.CreateLogger<CasualLMSupervisedFineTuningTrainer>();

        // Single source of truth for the device; previously "cuda" was hard-coded here
        // and again inside LoadModel.
        var device = "cuda";

        // Load CausalLM model + tokenizer pipeline onto the training device.
        var pipeline = LoadModel(weightFolder, checkPointName, device);

        // Tiny in-memory instruction dataset used for the fine-tuning demo.
        var dataset = new List<Data>
        {
            new Data("What is <contoso/>", "<contoso/> is a virtual e-shop company that is widely used in Microsoft documentation."),
            new Data("What products does <contoso/> sell?", "<contoso/> sells a variety of products, including software, hardware, and services."),
            new Data("What is the history of <contoso/>?", "<contoso/> was founded in 1984 by John Doe."),
            new Data("What is the mission of <contoso/>?", "<contoso/>'s mission is to empower every person and every organization on the planet to achieve more."),
            new Data("What is the vision of <contoso/>?", "<contoso/>'s vision is to create a world where everyone can achieve more."),
            new Data("What is the culture of <contoso/>?", "<contoso/>'s culture is based on a growth mindset, diversity, and inclusion."),
        };

        var input = CreateDataset(dataset, pipeline.TypedTokenizer, Llama3_1ChatTemplateBuilder.Instance);

        // Create the trainer around the loaded pipeline.
        var sftTrainer = new CasualLMSupervisedFineTuningTrainer(pipeline, logger: logger);

        var option = new CasualLMSupervisedFineTuningTrainer.Option
        {
            BatchSize = 1,
            Device = device,
            Epoch = 300,
            LearningRate = 5e-5f,
        };

        // TrainAsync yields the (mutated) pipeline after every epoch; use the yield
        // point to spot-check generation quality with a sample question.
        await foreach (var p in sftTrainer.TrainAsync(input, option, default))
        {
            if (p is not ICausalLMPipeline<Tokenizer, LlamaForCausalLM> llamaPipeline)
            {
                throw new InvalidOperationException("Pipeline is not of type ICausalLMPipeline<Tokenizer, LlamaForCausalLM>");
            }

            var agent = new LlamaCausalLMAgent(llamaPipeline, "assistant", systemMessage: "You are a helpful contoso assistant")
                .RegisterPrintMessage();

            var task = "What products does <contoso/> sell?";

            await agent.SendAsync(task);
        }

        // Persist the fine-tuned weights.
        var stateDict = pipeline.TypedModel.state_dict();
        Safetensors.SaveStateDict("contoso-llama-3.1-1b.safetensors", stateDict);
    }

    /// <summary>
    /// Loads the Llama tokenizer and model weights and wraps them in a causal-LM pipeline.
    /// </summary>
    /// <param name="weightFolder">Folder containing config.json and the model checkpoint.</param>
    /// <param name="checkPointName">Checkpoint index file name.</param>
    /// <param name="device">Target device for the pipeline; defaults to "cuda" to preserve
    /// the original behavior (the device used to be hard-coded inside this method).</param>
    public static ICausalLMPipeline<TiktokenTokenizer, LlamaForCausalLM> LoadModel(string weightFolder, string checkPointName = "model.safetensors.index.json", string device = "cuda")
    {
        var defaultType = ScalarType.BFloat16;
        torch.manual_seed(1); // fixed seed for reproducible runs
        torch.set_default_dtype(defaultType);
        var configName = "config.json";
        // Tokenizer assets are expected in the "original" subfolder of the weight folder.
        var originalWeightFolder = Path.Combine(weightFolder, "original");

        Console.WriteLine("Loading Llama from huggingface model weight folder");
        var tokenizer = LlamaTokenizerHelper.FromPretrained(originalWeightFolder);
        var model = LlamaForCausalLM.FromPretrained(weightFolder, configName, checkPointName: checkPointName, layersOnTargetDevice: -1, quantizeToInt8: false);

        return new CausalLMPipeline<TiktokenTokenizer, LlamaForCausalLM>(tokenizer, model, device);
    }

    // NOTE(review): lowercase record parameter names violate .NET PascalCase property
    // conventions, but renaming them would change the public surface — kept as-is.
    public record class Data(string input, string output);

    /// <summary>
    /// Builds a <see cref="CausalLMDataset"/> from (question, answer) pairs: each pair becomes
    /// a system+user chat history plus the expected assistant reply.
    /// </summary>
    public static CausalLMDataset CreateDataset(IEnumerable<Data> dataset, Tokenizer tokenizer, IMEAIChatTemplateBuilder templateBuilder)
    {
        var chatHistory = dataset.Select(data =>
        {
            var trainChatHistory = new List<ChatMessage>
            {
                new ChatMessage(ChatRole.System, "You are a helpful contoso assistant"),
                new ChatMessage(ChatRole.User, data.input),
            };

            var assistantMessage = new ChatMessage(ChatRole.Assistant, data.output);

            return (trainChatHistory, assistantMessage);
        }).ToArray();

        return CausalLMDataset.Create(chatHistory.Select(c => c.trainChatHistory), chatHistory.Select(c => c.assistantMessage), templateBuilder, tokenizer);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
<PackageReference Include="TorchSharp-cuda-windows" Version="0.102.5" Condition="$([MSBuild]::IsOSPlatform('Windows'))" />
<PackageReference Include="Microsoft.SemanticKernel" Version="$(SemanticKernelVersion)" />
<PackageReference Include="AutoGen.SourceGenerator" Version="$(AutoGenVersion)" />
<PackageReference Include="Microsoft.Extensions.Logging.Console" Version="8.0.0" />
</ItemGroup>

</Project>
4 changes: 2 additions & 2 deletions docs/samples/Microsoft.ML.GenAI.Samples/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@
using Microsoft.ML.GenAI.Samples.Llama;
using Microsoft.ML.GenAI.Samples.MEAI;

//await Llama3_1.RunAsync(@"C:\Users\xiaoyuz\source\repos\Llama-3.2-1B-Instruct", checkPointName: "model.safetensors");
await Phi3.RunAsync(@"C:\Users\xiaoyuz\source\repos\Phi-3-mini-4k-instruct");
await SFT_Llama_3_2_1B.Train(@"C:\Users\xiaoyuz\source\repos\Llama-3.2-1B-Instruct", checkPointName: "model.safetensors");
//await Phi3.RunAsync(@"C:\Users\xiaoyuz\source\repos\Phi-3-mini-4k-instruct");
13 changes: 12 additions & 1 deletion src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMModelInput.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,18 @@ internal static class Defaults
internal const Tensor? PositionIds = null;
internal const int PastKeyValuesLength = 0;
internal const Tensor? InputsEmbeds = null;
internal const bool UseCache = false;
internal const bool UseCache = true;
internal const bool OutputAttentions = false;
internal const bool OutputHiddenStates = false;
internal const Tensor? Labels = null;
}
public CausalLMModelInput(
Tensor inputIds,
Tensor? attentionMask = Defaults.AttentionMask,
Tensor? positionIds = Defaults.PositionIds,
int pastKeyValuesLength = Defaults.PastKeyValuesLength,
Tensor? inputsEmbeds = Defaults.InputsEmbeds,
Tensor? labels = Defaults.Labels,
bool useCache = Defaults.UseCache,
bool outputAttentions = Defaults.OutputAttentions,
bool outputHiddenStates = Defaults.OutputHiddenStates)
Expand All @@ -36,6 +38,7 @@ public CausalLMModelInput(
this.UseCache = useCache;
this.OutputAttentions = outputAttentions;
this.OutputHiddenStates = outputHiddenStates;
this.Labels = labels;
}

public Tensor InputIds { get; set; }
Expand All @@ -50,6 +53,14 @@ public CausalLMModelInput(

public Tensor? InputEmbeddings { get; set; }

/// <summary>
/// Shape: [batch_size, sequence_length]
/// DTypes: int64
/// Labels for computing the causal language modeling loss.
/// Indices should be in [0, config.vocab_size - 1] or [-100] for padding/masking.
/// </summary>
public Tensor? Labels { get; set; }

public bool UseCache { get; set; }

public bool OutputAttentions { get; set; }
Expand Down
11 changes: 10 additions & 1 deletion src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMModelOutput.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,30 @@ internal static class Defaults
internal const Tensor[]? AllHiddenStates = null;
internal const Tensor[]? Attentions = null;
internal const IKVCache? Cache = null;
internal const Tensor? Loss = null;
}
/// <summary>
/// Creates the output of a causal-LM forward pass.
/// </summary>
/// <param name="lastHiddenState">Final hidden states from the model.</param>
/// <param name="logits">Token logits; null when not requested.</param>
/// <param name="allHiddenStates">Per-layer hidden states when outputHiddenStates was set.</param>
/// <param name="attentions">Per-layer attention weights when outputAttentions was set.</param>
/// <param name="cache">Key/value cache when useCache was set.</param>
/// <param name="loss">Training loss; only populated when labels were provided in the input.</param>
// Fix: the merged diff view contained both the old `IKVCache? cache = Defaults.Cache)`
// terminator and its replacement, yielding a duplicate parameter — only the new
// parameter list is kept here.
public CausalLMModelOutput(
    Tensor lastHiddenState,
    Tensor? logits = Defaults.Logits,
    Tensor[]? allHiddenStates = Defaults.AllHiddenStates,
    Tensor[]? attentions = Defaults.Attentions,
    IKVCache? cache = Defaults.Cache,
    Tensor? loss = Defaults.Loss)
{
    this.LastHiddenState = lastHiddenState;
    this.AllHiddenStates = allHiddenStates;
    this.Logits = logits;
    this.Attentions = attentions;
    this.Cache = cache;
    this.Loss = loss;
}

/// <summary>
/// Shape: [1,]
/// Available when label is provided in the input.
/// </summary>
public Tensor? Loss { get; set; }

public Tensor? Logits { get; set; }

public Tensor LastHiddenState { get; set; }
Expand Down
12 changes: 8 additions & 4 deletions src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMPipeline.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,17 @@ public interface ICausalLMPipeline<out TTokenizer, out TModel> : ICausalLMPipeli
where TTokenizer : Tokenizer
where TModel : nn.Module<CausalLMModelInput, CausalLMModelOutput>
{
TTokenizer Tokenizer { get; }
TTokenizer TypedTokenizer { get; }

TModel Model { get; }
TModel TypedModel { get; }
}

public interface ICausalLMPipeline
{
Tokenizer Tokenizer { get; }

nn.Module<CausalLMModelInput, CausalLMModelOutput> Model { get; }

string Generate(
string prompt,
int maxLen = CausalLMPipeline.Defaults.MaxLen,
Expand Down Expand Up @@ -73,9 +77,9 @@ public CausalLMPipeline(
{
}

public new TTokenizer Tokenizer { get => (TTokenizer)base.Tokenizer; }
public TTokenizer TypedTokenizer { get => (TTokenizer)base.Tokenizer; }

public new TModel Model { get => (TModel)base.Model; }
public TModel TypedModel { get => (TModel)base.Model; }
}

public class CausalLMPipeline : ICausalLMPipeline
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Threading;
using Microsoft.Extensions.Logging;
using TorchSharp;
using TorchSharp.Modules;
using static TorchSharp.torch;

namespace Microsoft.ML.GenAI.Core.Trainer;

/// <summary>
/// Supervised fine-tuning (SFT) trainer for causal language models: pads and batches a
/// <see cref="CausalLMDataset"/>, runs forward/backward with Adam, and yields the pipeline
/// after each epoch so callers can evaluate intermediate checkpoints.
/// NOTE(review): "Casual" in the type name looks like a typo for "Causal"; kept because
/// renaming a public type is a breaking change.
/// </summary>
public class CasualLMSupervisedFineTuningTrainer
{
    private readonly ILogger<CasualLMSupervisedFineTuningTrainer>? _logger;
    private readonly ICausalLMPipeline _pipeline;

    /// <summary>
    /// Creates a trainer over an existing pipeline. The pipeline's model is trained in place.
    /// </summary>
    public CasualLMSupervisedFineTuningTrainer(ICausalLMPipeline pipeline, ILogger<CasualLMSupervisedFineTuningTrainer>? logger = null)
    {
        _logger = logger;
        _pipeline = pipeline;
    }

    /// <summary>
    /// Trains on <paramref name="trainDataset"/> for <c>trainingOption.Epoch</c> epochs,
    /// yielding the (mutated) pipeline after every epoch. Stops silently when
    /// <paramref name="ct"/> is cancelled.
    /// </summary>
#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously
    public async IAsyncEnumerable<ICausalLMPipeline> TrainAsync(
#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously
        CausalLMDataset trainDataset,
        Option trainingOption,
        [EnumeratorCancellation]
        CancellationToken ct)
    {
        this._logger?.LogInformation("Start training...");
        var batches = trainDataset.Chunk(trainingOption.BatchSize);
        // Fix: `using` disposes the optimizer when the iterator completes (it was never
        // disposed before).
        using var optimizer = new Adam(_pipeline.Model.parameters(), lr: trainingOption.LearningRate);
        var device = torch.device(trainingOption.Device);

        for (int i = 0; i < trainingOption.Epoch; i++)
        {
            this._logger?.LogInformation($"Epoch {i + 1}/{trainingOption.Epoch}");
            var losses = new List<float>();
            foreach (var batch in batches)
            {
                if (ct.IsCancellationRequested)
                {
                    yield break;
                }

                float batchLoss;
                // Fix: the dispose scope is now a `using` block, so every tensor created
                // inside it — including `labels`, which the original never disposed — is
                // released even when forward/backward throws. The manual per-tensor
                // Dispose calls are therefore no longer needed.
                using (var scope = NewDisposeScope())
                {
                    // Right-pad each item to the longest sequence in the batch so they can
                    // be concatenated along dim 0; labels are padded with -100 so padding
                    // positions are ignored by the loss.
                    var maxLen = batch.Max(x => x.InputIds.size(1));
                    var inputIds = torch.cat(batch.Select(x => nn.functional.pad(x.InputIds, [0, maxLen - x.InputIds.shape[1]])).ToArray(), 0).to(device);
                    var attentionMask = torch.cat(batch.Select(x => nn.functional.pad(x.AttentionMask!, [0, maxLen - x.AttentionMask!.shape[1]])).ToArray(), 0).to(device);
                    var labels = torch.cat(batch.Select(x => nn.functional.pad(x.Labels!, [0, maxLen - x.Labels!.shape[1]], value: -100)).ToArray(), 0).to(device);

                    // Forward pass; Loss is populated because labels are provided.
                    var output = _pipeline.Model.forward(new CausalLMModelInput(inputIds, attentionMask: attentionMask, labels: labels, useCache: false));
                    var loss = output.Loss;

                    optimizer.zero_grad();
                    loss!.backward();
                    optimizer.step();

                    // NOTE(review): assumes the loss tensor is host-readable here; if it
                    // resides on CUDA a `.cpu()` may be required first — TODO confirm.
                    batchLoss = loss.data<float>().ToArray()[0];
                }

                losses.Add(batchLoss);
            }

            // Fix: guard against an empty dataset — Average() on an empty list throws.
            if (losses.Count > 0)
            {
                _logger?.LogInformation($"Epoch {i + 1} loss: {losses.Average()}");
            }

            yield return _pipeline;
        }
    }

    /// <summary>Training hyperparameters with conservative CPU defaults.</summary>
    public class Option
    {
        public Option()
        {
            Epoch = 10;
            BatchSize = 1;
            LearningRate = 5e-5f;
            Device = "cpu";
        }

        // Number of full passes over the dataset.
        public int Epoch { get; set; }

        // Items per optimizer step; sequences in a batch are padded to equal length.
        public int BatchSize { get; set; }

        // Adam learning rate.
        public float LearningRate { get; set; }

        // TorchSharp device string, e.g. "cpu" or "cuda".
        public string Device { get; set; }
    }
}
Loading

0 comments on commit a4c67fe

Please sign in to comment.