diff --git a/docs/samples/Microsoft.ML.GenAI.Samples/Llama/SFT_Llama_3_2_1B.cs b/docs/samples/Microsoft.ML.GenAI.Samples/Llama/SFT_Llama_3_2_1B.cs new file mode 100644 index 0000000000..98f4ae71ef --- /dev/null +++ b/docs/samples/Microsoft.ML.GenAI.Samples/Llama/SFT_Llama_3_2_1B.cs @@ -0,0 +1,118 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Microsoft.ML.GenAI.Core; +using Microsoft.ML.GenAI.LLaMA; +using static TorchSharp.torch; +using TorchSharp; +using Microsoft.ML.Tokenizers; +using TorchSharp.Modules; +using TorchSharp.PyBridge; +using Microsoft.Extensions.AI; +using AutoGen.Core; +using Microsoft.ML.GenAI.Core.Trainer; +using Microsoft.Extensions.Logging; + +namespace Microsoft.ML.GenAI.Samples.Llama; + +internal class SFT_Llama_3_2_1B +{ + public static async Task Train(string weightFolder, string checkPointName = "model.safetensors.index.json") + { + // create logger factory + using var loggerFactory = LoggerFactory.Create(builder => builder.AddConsole()); + + // create logger + var logger = loggerFactory.CreateLogger(); + + var device = "cuda"; + + // Load CausalLM Model + var pipeline = LoadModel(weightFolder, checkPointName); + + // Load dataset + var dataset = new List + { + new Data("What is ", " is a virtual e-shop company that is widely used in Microsoft documentation."), + new Data("What products does sell?", " sells a variety of products, including software, hardware, and services."), + new Data("What is the history of ?", " was founded in 1984 by John Doe."), + new Data("What is the mission of ?", "'s mission is to empower every person and every organization on the planet to achieve more."), + new Data("What is the vision of ?", "'s vision is to create a world where everyone can achieve more."), + new Data("What is the culture of ?", "'s culture is based on a growth mindset, diversity, and inclusion."), + }; + + var input = CreateDataset(dataset, pipeline.TypedTokenizer, Llama3_1ChatTemplateBuilder.Instance); + + // create trainer + var sftTrainer = new CasualLMSupervisedFineTuningTrainer(pipeline, logger: logger); + + // Train the model + var option = new CasualLMSupervisedFineTuningTrainer.Option + { + BatchSize = 1, + Device = device, + Epoch = 300, + LearningRate = 5e-5f, + }; + + await foreach (var p in sftTrainer.TrainAsync(input, option, default)) + { + // evaluate the model + if (p is not ICausalLMPipeline llamaPipeline) + { + throw new InvalidOperationException("Pipeline is not of type ICausalLMPipeline"); + } + + var agent = new LlamaCausalLMAgent(llamaPipeline, "assistant", systemMessage: "You are a helpful contoso assistant") + .RegisterPrintMessage(); + + var task = "What products does sell?"; + + await agent.SendAsync(task); + } + + // save model + var stateDict = pipeline.TypedModel.state_dict(); + Safetensors.SaveStateDict("contoso-llama-3.1-1b.safetensors", stateDict); + } + + public static ICausalLMPipeline LoadModel(string weightFolder, string checkPointName = "model.safetensors.index.json") + { + var device = "cuda"; + var defaultType = ScalarType.BFloat16; + torch.manual_seed(1); + torch.set_default_dtype(defaultType); + var configName = "config.json"; + var originalWeightFolder = Path.Combine(weightFolder, "original"); + + Console.WriteLine("Loading Llama from huggingface model weight folder"); + var tokenizer = LlamaTokenizerHelper.FromPretrained(originalWeightFolder); + var model = LlamaForCausalLM.FromPretrained(weightFolder, configName, checkPointName: checkPointName, layersOnTargetDevice: -1, quantizeToInt8: false); + + var pipeline = new CausalLMPipeline(tokenizer, model, device); + + return pipeline; + } + + public record class Data(string input, string output); + + public static CausalLMDataset CreateDataset(IEnumerable dataset, Tokenizer tokenizer, IMEAIChatTemplateBuilder templateBuilder) + { + var chatHistory = dataset.Select(data => + { + var trainChatHistory = new List + { + new ChatMessage(ChatRole.System, "You are a helpful contoso assistant"), + new ChatMessage(ChatRole.User, data.input), + }; + + var assistantMessage = new ChatMessage(ChatRole.Assistant, data.output); + + return (trainChatHistory, assistantMessage); + }).ToArray(); + + return CausalLMDataset.Create(chatHistory.Select(c => c.trainChatHistory), chatHistory.Select(c => c.assistantMessage), templateBuilder, tokenizer); + } +} diff --git a/docs/samples/Microsoft.ML.GenAI.Samples/Microsoft.ML.GenAI.Samples.csproj b/docs/samples/Microsoft.ML.GenAI.Samples/Microsoft.ML.GenAI.Samples.csproj index 792391a59f..c8cee633ac 100644 --- a/docs/samples/Microsoft.ML.GenAI.Samples/Microsoft.ML.GenAI.Samples.csproj +++ b/docs/samples/Microsoft.ML.GenAI.Samples/Microsoft.ML.GenAI.Samples.csproj @@ -19,6 +19,7 @@ + diff --git a/docs/samples/Microsoft.ML.GenAI.Samples/Program.cs b/docs/samples/Microsoft.ML.GenAI.Samples/Program.cs index de091afe41..6f4d809948 100644 --- a/docs/samples/Microsoft.ML.GenAI.Samples/Program.cs +++ b/docs/samples/Microsoft.ML.GenAI.Samples/Program.cs @@ -2,5 +2,5 @@ using Microsoft.ML.GenAI.Samples.Llama; using Microsoft.ML.GenAI.Samples.MEAI; -//await Llama3_1.RunAsync(@"C:\Users\xiaoyuz\source\repos\Llama-3.2-1B-Instruct", checkPointName: "model.safetensors"); -await Phi3.RunAsync(@"C:\Users\xiaoyuz\source\repos\Phi-3-mini-4k-instruct"); +await SFT_Llama_3_2_1B.Train(@"C:\Users\xiaoyuz\source\repos\Llama-3.2-1B-Instruct", checkPointName: "model.safetensors"); +//await Phi3.RunAsync(@"C:\Users\xiaoyuz\source\repos\Phi-3-mini-4k-instruct"); diff --git a/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMModelInput.cs b/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMModelInput.cs index eaf94f2a80..2153a8d264 100644 --- a/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMModelInput.cs +++ b/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMModelInput.cs @@ -14,9 +14,10 @@ internal static class Defaults internal const Tensor? PositionIds = null; internal const int PastKeyValuesLength = 0; internal const Tensor? InputsEmbeds = null; - internal const bool UseCache = false; + internal const bool UseCache = true; internal const bool OutputAttentions = false; internal const bool OutputHiddenStates = false; + internal const Tensor? Labels = null; } public CausalLMModelInput( Tensor inputIds, @@ -24,6 +25,7 @@ public CausalLMModelInput( Tensor? positionIds = Defaults.PositionIds, int pastKeyValuesLength = Defaults.PastKeyValuesLength, Tensor? inputsEmbeds = Defaults.InputsEmbeds, + Tensor? labels = Defaults.Labels, bool useCache = Defaults.UseCache, bool outputAttentions = Defaults.OutputAttentions, bool outputHiddenStates = Defaults.OutputHiddenStates) @@ -36,6 +38,7 @@ public CausalLMModelInput( this.UseCache = useCache; this.OutputAttentions = outputAttentions; this.OutputHiddenStates = outputHiddenStates; + this.Labels = labels; } public Tensor InputIds { get; set; } @@ -50,6 +53,14 @@ public CausalLMModelInput( public Tensor? InputEmbeddings { get; set; } + /// + /// Shape: [batch_size, sequence_length] + /// DTypes: int64 + /// Labels for computing the causal language modeling loss. + /// Indices should be in [0, config.vocab_size - 1] or [-100] for padding/masking. + /// + public Tensor? Labels { get; set; } + public bool UseCache { get; set; } public bool OutputAttentions { get; set; } diff --git a/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMModelOutput.cs b/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMModelOutput.cs index c10b68e60f..b7a622ab7c 100644 --- a/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMModelOutput.cs +++ b/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMModelOutput.cs @@ -14,21 +14,30 @@ internal static class Defaults internal const Tensor[]? AllHiddenStates = null; internal const Tensor[]? Attentions = null; internal const IKVCache? Cache = null; + internal const Tensor? Loss = null; } public CausalLMModelOutput( Tensor lastHiddenState, Tensor? logits = Defaults.Logits, Tensor[]? allHiddenStates = Defaults.AllHiddenStates, Tensor[]? attentions = Defaults.Attentions, - IKVCache? cache = Defaults.Cache) + IKVCache? cache = Defaults.Cache, + Tensor? loss = Defaults.Loss) { this.LastHiddenState = lastHiddenState; this.AllHiddenStates = allHiddenStates; this.Logits = logits; this.Attentions = attentions; this.Cache = cache; + this.Loss = loss; } + /// + /// Shape: [1,] + /// Available when label is provided in the input. + /// + public Tensor? Loss { get; set; } + public Tensor? Logits { get; set; } public Tensor LastHiddenState { get; set; } diff --git a/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMPipeline.cs b/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMPipeline.cs index 13c598b4ec..74d9c6237a 100644 --- a/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMPipeline.cs +++ b/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMPipeline.cs @@ -18,13 +18,17 @@ public interface ICausalLMPipeline : ICausalLMPipeli where TTokenizer : Tokenizer where TModel : nn.Module { - TTokenizer Tokenizer { get; } + TTokenizer TypedTokenizer { get; } - TModel Model { get; } + TModel TypedModel { get; } } public interface ICausalLMPipeline { + Tokenizer Tokenizer { get; } + + nn.Module Model { get; } + string Generate( string prompt, int maxLen = CausalLMPipeline.Defaults.MaxLen, @@ -73,9 +77,9 @@ public CausalLMPipeline( { } - public new TTokenizer Tokenizer { get => (TTokenizer)base.Tokenizer; } + public TTokenizer TypedTokenizer { get => (TTokenizer)base.Tokenizer; } - public new TModel Model { get => (TModel)base.Model; } + public TModel TypedModel { get => (TModel)base.Model; } } public class CausalLMPipeline : ICausalLMPipeline diff --git a/src/Microsoft.ML.GenAI.Core/Trainer/CasualLMSupervisedFineTuningTrainer.cs b/src/Microsoft.ML.GenAI.Core/Trainer/CasualLMSupervisedFineTuningTrainer.cs new file mode 100644 index 0000000000..f5ee202cd5 --- /dev/null +++ b/src/Microsoft.ML.GenAI.Core/Trainer/CasualLMSupervisedFineTuningTrainer.cs @@ -0,0 +1,105 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Collections.Generic; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Threading; +using Microsoft.Extensions.Logging; +using TorchSharp; +using TorchSharp.Modules; +using static TorchSharp.torch; + +namespace Microsoft.ML.GenAI.Core.Trainer; + +public class CasualLMSupervisedFineTuningTrainer +{ + private readonly ILogger? _logger; + private readonly ICausalLMPipeline _pipeline; + + public CasualLMSupervisedFineTuningTrainer(ICausalLMPipeline pipeline, ILogger? logger = null) + { + _logger = logger; + _pipeline = pipeline; + } + +#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously + public async IAsyncEnumerable TrainAsync( +#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously + CausalLMDataset trainDataset, + Option trainingOption, + [EnumeratorCancellation] + CancellationToken ct) + { + this._logger?.LogInformation("Start training..."); + var batches = trainDataset.Chunk(trainingOption.BatchSize); + var optimizer = new Adam(_pipeline.Model.parameters(), lr: trainingOption.LearningRate); + var device = torch.device(trainingOption.Device); + + for (int i = 0; i < trainingOption.Epoch; i++) + { + this._logger?.LogInformation($"Epoch {i + 1}/{trainingOption.Epoch}"); + var losses = new List(); + foreach (var batch in batches) + { + if (ct.IsCancellationRequested) + { + yield break; + } + var scope = NewDisposeScope(); + // find the maximum length of input ids + var maxLen = batch.Max(x => x.InputIds.size(1)); + // merge items in batch + var inputIds = torch.cat(batch.Select(x => nn.functional.pad(x.InputIds, [0, maxLen - x.InputIds.shape[1]])).ToArray(), 0).to(device); + var attentionMask = torch.cat(batch.Select(x => nn.functional.pad(x.AttentionMask!, [0, maxLen - x.AttentionMask!.shape[1]])).ToArray(), 0).to(device); + var labels = torch.cat(batch.Select(x => nn.functional.pad(x.Labels!, [0, maxLen - x.Labels!.shape[1]], value: -100)).ToArray(), 0).to(device); + // Forward the model + var output = _pipeline.Model.forward(new CausalLMModelInput(inputIds, attentionMask: attentionMask, labels: labels, useCache: false)); + // Calculate loss + var loss = output.Loss; + // Backward the model + optimizer.zero_grad(); + loss!.backward(); + optimizer.step(); + + losses.Add(loss.data().ToArray()[0]); + + // dispose loss + loss.Dispose(); + + // dispose output + output.LastHiddenState.Dispose(); + output.Logits!.Dispose(); + inputIds.Dispose(); + attentionMask.Dispose(); + + scope.Dispose(); + } + + _logger?.LogInformation($"Epoch {i + 1} loss: {losses.Average()}"); + + yield return _pipeline; + } + } + + + public class Option + { + public Option() + { + Epoch = 10; + BatchSize = 1; + LearningRate = 5e-5f; + Device = "cpu"; + } + + public int Epoch { get; set; } + + public int BatchSize { get; set; } + + public float LearningRate { get; set; } + + public string Device { get; set; } + } +} diff --git a/src/Microsoft.ML.GenAI.Core/Trainer/CausalLMDataset.cs b/src/Microsoft.ML.GenAI.Core/Trainer/CausalLMDataset.cs new file mode 100644 index 0000000000..b368c7ac03 --- /dev/null +++ b/src/Microsoft.ML.GenAI.Core/Trainer/CausalLMDataset.cs @@ -0,0 +1,112 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Microsoft.Extensions.AI; +using Microsoft.ML.Tokenizers; +using TorchSharp; + +namespace Microsoft.ML.GenAI.Core.Trainer; + +public class CausalLMDataset : IEnumerable +{ + private readonly List _data; + + private CausalLMDataset(IEnumerable data) + { + _data = new List(data); + } + + public static CausalLMDataset Create(IEnumerable> inputs, + IEnumerable outputs, + IMEAIChatTemplateBuilder chatTemplateBuilder, + Tokenizer tokenizer) + { + // the length of inputs and outputs should be the same + if (inputs.Count() != outputs.Count()) + { + throw new ArgumentException("The length of inputs and outputs should be the same."); + } + + var enumerables = inputs.Zip(outputs, (input, output) => + { + var inputPrompt = chatTemplateBuilder.BuildPrompt(input.ToList()); + var outputPrompt = chatTemplateBuilder.BuildPrompt(input.Concat([output]).ToList(), appendAssistantTag: false); + var lengthToKeep = outputPrompt.Length - inputPrompt.Length; + outputPrompt = outputPrompt.Substring(inputPrompt.Length, lengthToKeep); + + return (inputPrompt, outputPrompt); + }); + + return Create(enumerables.Select(x => x.inputPrompt), enumerables.Select(x => x.outputPrompt), tokenizer); + } + + public static CausalLMDataset Create(IEnumerable inputs, IEnumerable outputs, Tokenizer tokenizer) + { + // the length of inputs and outputs should be the same + if (inputs.Count() != outputs.Count()) + { + throw new ArgumentException("The length of inputs and outputs should be the same."); + } + + var enumerable = inputs.Zip(outputs, (input, output) => + { + var inputIds = tokenizer.EncodeToIds(input); + var outputIds = tokenizer.EncodeToIds(input + output); + outputIds = outputIds.Skip(inputIds.Count()).ToArray(); + + return (inputIds, outputIds); + }).ToArray(); + + return Create(enumerable.Select(x => x.inputIds), enumerable.Select(x => x.outputIds)); + } + + public static CausalLMDataset Create(IEnumerable> inputIds, IEnumerable> labelIds) + { + // the length of inputIds and labelIds should be the same + if (inputIds.Count() != labelIds.Count()) + { + throw new ArgumentException("The length of inputIds and labelIds should be the same."); + } + + var enumerable = inputIds.Zip(labelIds, Create) + .SelectMany(x => x); + + return new CausalLMDataset(enumerable); + } + + public static CausalLMDataset Create(IReadOnlyList inputIds, IReadOnlyList labelIds) + { + var enumerable = Enumerable.Range(0, labelIds.Count) + .Select(i => + { + var train = inputIds.Concat(labelIds.Take(i)).ToArray(); + var label = Enumerable.Repeat(-100L, train.Length).Concat([labelIds[i]]).Skip(1).ToArray(); + var mask = Enumerable.Repeat(1L, train.Length).ToArray(); + + return new CausalLMModelInput( + inputIds: torch.tensor(train.ToArray(), dtype: TorchSharp.torch.ScalarType.Int64).reshape(1, -1), + labels: torch.tensor(label, dtype: TorchSharp.torch.ScalarType.Int64).reshape(1, -1), + attentionMask: torch.tensor(mask, dtype: TorchSharp.torch.ScalarType.Int64).reshape(1, -1) + ); + }); + + return new CausalLMDataset(enumerable); + } + + public IEnumerator GetEnumerator() + { + return ((IEnumerable)_data).GetEnumerator(); + } + + IEnumerator IEnumerable.GetEnumerator() + { + return ((IEnumerable)_data).GetEnumerator(); + } +} diff --git a/src/Microsoft.ML.GenAI.Core/Utility/IChatTemplateBuilder.cs b/src/Microsoft.ML.GenAI.Core/Utility/IChatTemplateBuilder.cs index 7d9292562a..a0433f02a1 100644 --- a/src/Microsoft.ML.GenAI.Core/Utility/IChatTemplateBuilder.cs +++ b/src/Microsoft.ML.GenAI.Core/Utility/IChatTemplateBuilder.cs @@ -25,7 +25,14 @@ public interface IAutoGenChatTemplateBuilder public interface IMEAIChatTemplateBuilder { - string BuildPrompt(IList messages, ChatOptions? options = null); + /// + /// Build a prompt from a list of messages. + /// + /// the list of to be rendered + /// + /// true if append assistant tag at the end of prompt. + /// + string BuildPrompt(IList messages, ChatOptions? options = null, bool appendAssistantTag = true); } public interface IChatTemplateBuilder : IAutoGenChatTemplateBuilder, ISemanticKernelChatTemplateBuilder diff --git a/src/Microsoft.ML.GenAI.LLaMA/Llama3_1ChatTemplateBuilder.cs b/src/Microsoft.ML.GenAI.LLaMA/Llama3_1ChatTemplateBuilder.cs index f54e24b9fb..2dd4a1b725 100644 --- a/src/Microsoft.ML.GenAI.LLaMA/Llama3_1ChatTemplateBuilder.cs +++ b/src/Microsoft.ML.GenAI.LLaMA/Llama3_1ChatTemplateBuilder.cs @@ -88,7 +88,7 @@ public string BuildPrompt(ChatHistory chatHistory) return sb.ToString(); } - public string BuildPrompt(IList messages, ChatOptions? options = null) + public string BuildPrompt(IList messages, ChatOptions? options = null, bool appendAssistantTag = true) { var availableRoles = new[] { ChatRole.System, ChatRole.User, ChatRole.Assistant }; if (messages.Any(m => m.Text is null)) @@ -116,7 +116,11 @@ public string BuildPrompt(IList messages, ChatOptions? options = nu }); } - sb.Append($"<|start_header_id|>assistant<|end_header_id|>{Newline}"); + if (appendAssistantTag) + { + sb.Append($"<|start_header_id|>assistant<|end_header_id|>{Newline}"); + } + var input = sb.ToString(); return input; diff --git a/src/Microsoft.ML.GenAI.LLaMA/LlamaForCausalLM.cs b/src/Microsoft.ML.GenAI.LLaMA/LlamaForCausalLM.cs index 0384efda8a..0a6cdc8498 100644 --- a/src/Microsoft.ML.GenAI.LLaMA/LlamaForCausalLM.cs +++ b/src/Microsoft.ML.GenAI.LLaMA/LlamaForCausalLM.cs @@ -65,6 +65,30 @@ public override CausalLMModelOutput forward(CausalLMModelInput input) logits = logits.to_type(ScalarType.Float32); outputs.Logits = logits; + // calculate the loss if the label is provided + if (input.Labels is not null) + { + // upcast the logits to float32 + logits = logits.to_type(ScalarType.Float32); + + var shiftLogits = logits[.., .., ..].contiguous(); + var shiftLabels = input.Labels[.., ..].contiguous(); + + shiftLogits = shiftLogits.view(-1, _vocabSize); + shiftLabels = shiftLabels.view(-1); + + // calculate the loss + // the loss is calculated by using the cross entropy loss by default + // TODO: add support for other loss functions + var loss = nn.functional.cross_entropy(shiftLogits, shiftLabels); + outputs.Loss = loss; + + // dispose the shiftLogits + shiftLogits.Dispose(); + shiftLabels.Dispose(); + logits.Dispose(); + } + return outputs; } diff --git a/src/Microsoft.ML.GenAI.LLaMA/Module/LlamaModel.cs b/src/Microsoft.ML.GenAI.LLaMA/Module/LlamaModel.cs index d8596a43ca..5d1d08e411 100644 --- a/src/Microsoft.ML.GenAI.LLaMA/Module/LlamaModel.cs +++ b/src/Microsoft.ML.GenAI.LLaMA/Module/LlamaModel.cs @@ -14,7 +14,7 @@ internal class LlamaModel : nn.Module private readonly LlamaConfig _config; private readonly int? _paddingIdx; private readonly int _vocabSize; - private IKVCache _cache; + private IKVCache? _cache; #pragma warning disable MSML_PrivateFieldName // Private field name not in: _camelCase format private readonly Embedding embed_tokens; private readonly ModuleList layers; @@ -57,6 +57,10 @@ public override CausalLMModelOutput forward(CausalLMModelInput input) { this._cache = input.OverrideCache; } + else if (!input.UseCache) + { + this._cache = null; + } var outputAttentions = input.OutputAttentions; var outputHiddenStates = input.OutputHiddenStates; diff --git a/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3ChatTemplateBuilder.cs b/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3ChatTemplateBuilder.cs index 213b1f7408..9c4f887d75 100644 --- a/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3ChatTemplateBuilder.cs +++ b/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3ChatTemplateBuilder.cs @@ -89,7 +89,7 @@ public string BuildPrompt(ChatHistory chatHistory) return sb.ToString(); } - public string BuildPrompt(IList messages, ChatOptions? options = null) + public string BuildPrompt(IList messages, ChatOptions? options = null, bool appendAssistantTag = true) { var availableRoles = new[] { ChatRole.System, ChatRole.User, ChatRole.Assistant }; if (messages.Any(m => m.Text is null)) @@ -119,7 +119,11 @@ public string BuildPrompt(IList messages, ChatOptions? options = nu }); } - sb.Append("<|assistant|>"); + if (appendAssistantTag) + { + sb.Append("<|assistant|>"); + } + var input = sb.ToString(); return input; diff --git a/test/Microsoft.ML.GenAI.Core.Tests/CasualLMDatasetTest.cs b/test/Microsoft.ML.GenAI.Core.Tests/CasualLMDatasetTest.cs new file mode 100644 index 0000000000..f451dcb718 --- /dev/null +++ b/test/Microsoft.ML.GenAI.Core.Tests/CasualLMDatasetTest.cs @@ -0,0 +1,103 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using FluentAssertions; +using Microsoft.Extensions.AI; +using Microsoft.ML.GenAI.Core.Trainer; +using Microsoft.ML.GenAI.LLaMA; +using Microsoft.ML.Tokenizers; +using Xunit; + +namespace Microsoft.ML.GenAI.Core.Tests; + +public class CasualLMDatasetTest +{ + private static Tokenizer CreateLlamaTokenizer() + { + // @"https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/resolve/main/tokenizer.model?download=true"; + // @"https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer.model"; + using Stream remoteStream = File.OpenRead(Path.Combine(@"Llama", "tokenizer.model")); + return LlamaTokenizer.Create(remoteStream); + } + + [Fact] + public void ItCreateDatasetsFromInputIds() + { + int[] inputIds = [1, 2, 3, 4, 5]; + int[] outputIds = [6, 7, 8, 9, 10]; + + var dataset = CausalLMDataset.Create(inputIds, outputIds) + .ToArray(); + + // the following rows should be created + // - input_ids: [1, 2, 3, 4, 5], label_ids: [-100, -100, -100, -100, 6] + // - input_ids: [1, 2, 3, 4, 5, 6], label_ids: [-100, -100, -100, -100, -100, 7] + // - input_ids: [1, 2, 3, 4, 5, 6, 7], label_ids: [-100, -100, -100, -100, -100, -100, 8] + // - input_ids: [1, 2, 3, 4, 5, 6, 7, 8], label_ids: [-100, -100, -100, -100, -100, -100, -100, 9] + // - input_ids: [1, 2, 3, 4, 5, 6, 7, 8, 9], label_ids: [-100, -100, -100, -100, -100, -100, -100, -100, 10] + + dataset.Length.Should().Be(5); + dataset[0].InputIds!.data().Should().BeEquivalentTo([1, 2, 3, 4, 5]); + dataset[0].Labels!.data().Should().BeEquivalentTo([-100, -100, -100, -100, 6]); + dataset[0].AttentionMask!.data().Should().BeEquivalentTo([1, 1, 1, 1, 1]); + dataset[^1].AttentionMask!.data().Should().BeEquivalentTo([1, 1, 1, 1, 1, 1, 1, 1, 1]); + dataset[^1].Labels!.data().Should().BeEquivalentTo([-100, -100, -100, -100, -100, -100, -100, -100, 10]); + dataset[^1].AttentionMask!.data().Should().BeEquivalentTo([1, 1, 1, 1, 1, 1, 1, 1, 1]); + } + + [Fact] + public void ItCreateDatasetsFromListOfInputIds() + { + int[][] inputIds = [ + [1, 2, 3, 4, 5], + [6, 7, 8, 9, 10] + ]; + + int[][] outputIds = [ + [11, 12, 13, 14, 15], + [16, 17, 18, 19, 20] + ]; + + var dataset = CausalLMDataset.Create(inputIds, outputIds) + .ToArray(); + + dataset.Count().Should().Be(10); + + foreach (var item in dataset) + { + item.Labels!.shape.Should().BeEquivalentTo(item.InputIds!.shape); + item.AttentionMask!.shape.Should().BeEquivalentTo(item.InputIds!.shape); + } + } + + [Fact] + public void ItCreateDatasetsFromMEAIMessages() + { + var inputs = new List> + { + new List + { + new ChatMessage(ChatRole.System, "You are a helpful contoso assistant"), + new ChatMessage(ChatRole.User, "What is contoso"), + }, + }; + + var outputs = new List + { + new ChatMessage(ChatRole.Assistant, "Contoso is a company"), + }; + + var tokenizer = CreateLlamaTokenizer(); + + var dataset = CausalLMDataset.Create(inputs, outputs, Llama3_1ChatTemplateBuilder.Instance, tokenizer) + .ToArray(); + + dataset.Length.Should().Be(14); + } +} diff --git a/test/Microsoft.ML.GenAI.Core.Tests/Microsoft.ML.GenAI.Core.Tests.csproj b/test/Microsoft.ML.GenAI.Core.Tests/Microsoft.ML.GenAI.Core.Tests.csproj index f07f80089e..90a9df1e94 100644 --- a/test/Microsoft.ML.GenAI.Core.Tests/Microsoft.ML.GenAI.Core.Tests.csproj +++ b/test/Microsoft.ML.GenAI.Core.Tests/Microsoft.ML.GenAI.Core.Tests.csproj @@ -11,6 +11,7 @@ + @@ -19,10 +20,12 @@ + +