From db46d80982cd7c7c1e2a9a5595ab1e95a5773423 Mon Sep 17 00:00:00 2001 From: Aleksei Smirnov Date: Wed, 2 Oct 2024 13:38:37 +0300 Subject: [PATCH] Add LLama2 example, that uses custom HistoryTransform (templator) --- LLama.Examples/ExampleRunner.cs | 1 + LLama.Examples/Examples/LLama2ChatSession.cs | 140 +++++++++++++++++++ LLama.Examples/Examples/LLama3ChatSession.cs | 6 +- LLama/ChatSession.cs | 2 +- 4 files changed, 145 insertions(+), 4 deletions(-) create mode 100644 LLama.Examples/Examples/LLama2ChatSession.cs diff --git a/LLama.Examples/ExampleRunner.cs b/LLama.Examples/ExampleRunner.cs index 019172fd5..cec68f3da 100644 --- a/LLama.Examples/ExampleRunner.cs +++ b/LLama.Examples/ExampleRunner.cs @@ -6,6 +6,7 @@ public class ExampleRunner private static readonly Dictionary<string, Func<Task>> Examples = new() { { "Chat Session: LLama3", LLama3ChatSession.Run }, + { "Chat Session: LLama2", LLama2ChatSession.Run }, { "Chat Session: History", ChatSessionWithHistory.Run }, { "Chat Session: Role names", ChatSessionWithRoleName.Run }, { "Chat Session: Role names stripped", ChatSessionStripRoleName.Run }, diff --git a/LLama.Examples/Examples/LLama2ChatSession.cs b/LLama.Examples/Examples/LLama2ChatSession.cs new file mode 100644 index 000000000..19db460c6 --- /dev/null +++ b/LLama.Examples/Examples/LLama2ChatSession.cs @@ -0,0 +1,140 @@ +using LLama.Abstractions; +using LLama.Common; +using LLama.Sampling; +using LLama.Transformers; +using System.Text; + +namespace LLama.Examples.Examples; + +/// <summary> +/// This sample shows a simple chatbot +/// It's configured to use custom prompt template as provided by llama.cpp and supports +/// models such as LLama 2 and Mistral Instruct +/// </summary> +public class LLama2ChatSession +{ + public static async Task Run() + { + var modelPath = UserSettings.GetModelPath(); + var parameters = new ModelParams(modelPath) + { + Seed = 1337, + GpuLayerCount = 10 + }; + + using var model = LLamaWeights.LoadFromFile(parameters); + using var context =
model.CreateContext(parameters); + var executor = new InteractiveExecutor(context); + + var chatHistoryJson = File.ReadAllText("Assets/chat-with-bob.json"); + var chatHistory = ChatHistory.FromJson(chatHistoryJson) ?? new ChatHistory(); + + ChatSession session = new(executor, chatHistory); + + // add custom templator + session.WithHistoryTransform(new Llama2HistoryTransformer()); + + session.WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform( + [model.Tokens.EndOfTurnToken ?? "User:", "�"], + redundancyLength: 5)); + + var inferenceParams = new InferenceParams + { + SamplingPipeline = new DefaultSamplingPipeline + { + Temperature = 0.6f + }, + + MaxTokens = -1, // keep generating tokens until the anti prompt is encountered + AntiPrompts = [model.Tokens.EndOfTurnToken ?? "User:"] // model specific end of turn string (or default) + }; + + Console.ForegroundColor = ConsoleColor.Yellow; + Console.WriteLine("The chat session has started."); + + // show the prompt + Console.ForegroundColor = ConsoleColor.Green; + Console.Write("User> "); + var userInput = Console.ReadLine() ?? ""; + + while (userInput != "exit") + { + Console.ForegroundColor = ConsoleColor.White; + Console.Write("Assistant> "); + + // as each token (partial or whole word is streamed back) print it to the console, stream to web client, etc + await foreach ( + var text + in session.ChatAsync( + new ChatHistory.Message(AuthorRole.User, userInput), + inferenceParams)) + { + Console.ForegroundColor = ConsoleColor.White; + Console.Write(text); + } + Console.WriteLine(); + + Console.ForegroundColor = ConsoleColor.Green; + Console.Write("User> "); + userInput = Console.ReadLine() ?? ""; + } + } + + /// + /// Chat History transformer for Llama 2 family. 
+ /// https://huggingface.co/blog/llama2#how-to-prompt-llama-2 + /// </summary> + public class Llama2HistoryTransformer : IHistoryTransform + { + public string Name => "Llama2"; + + /// <inheritdoc/> + public IHistoryTransform Clone() + { + return new Llama2HistoryTransformer(); + } + + /// <inheritdoc/> + public string HistoryToText(ChatHistory history) + { + //More info on template format for llama2 https://huggingface.co/blog/llama2#how-to-prompt-llama-2 + if (history.Messages.Count == 0) + return string.Empty; + + var builder = new StringBuilder(64 * history.Messages.Count); + + int i = 0; + if (history.Messages[i].AuthorRole == AuthorRole.System) + { + builder.Append($"[INST] <<SYS>>\n").Append(history.Messages[0].Content.Trim()).Append("\n<</SYS>>\n"); + i++; + + if (history.Messages.Count > 1) + { + builder.Append(history.Messages[1].Content.Trim()).Append(" [/INST]"); + i++; + } + } + + for (; i < history.Messages.Count; i++) + { + if (history.Messages[i].AuthorRole == AuthorRole.User) + { + builder.Append("[INST] ").Append(history.Messages[i].Content.Trim()).Append(" [/INST]"); + } + else + { + builder.Append(' ').Append(history.Messages[i].Content.Trim()).Append(" </s><s>"); + } + } + + return builder.ToString(); + } + + /// <inheritdoc/> + public ChatHistory TextToHistory(AuthorRole role, string text) + { + return new ChatHistory([new ChatHistory.Message(role, text)]); + } + } +} diff --git a/LLama.Examples/Examples/LLama3ChatSession.cs b/LLama.Examples/Examples/LLama3ChatSession.cs index e5e6167d8..56476ef68 100644 --- a/LLama.Examples/Examples/LLama3ChatSession.cs +++ b/LLama.Examples/Examples/LLama3ChatSession.cs @@ -7,7 +7,7 @@ namespace LLama.Examples.Examples; /// <summary> /// This sample shows a simple chatbot /// It's configured to use the default prompt template as provided by llama.cpp and supports -/// models such as llama3, llama2, phi3, qwen1.5, etc. +/// models such as llama3, phi3, qwen1.5, etc.
/// public class LLama3ChatSession { @@ -35,7 +35,7 @@ public static async Task Run() // Add a transformer to eliminate printing the end of turn tokens, llama 3 specifically has an odd LF that gets printed sometimes session.WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform( - [model.Tokens.EndOfTurnToken!, "�"], + [model.Tokens.EndOfTurnToken ?? "User:", "�"], redundancyLength: 5)); var inferenceParams = new InferenceParams @@ -46,7 +46,7 @@ public static async Task Run() }, MaxTokens = -1, // keep generating tokens until the anti prompt is encountered - AntiPrompts = [model.Tokens.EndOfTurnToken!] // model specific end of turn string + AntiPrompts = [model.Tokens.EndOfTurnToken ?? "User:"] // model specific end of turn string (or default) }; Console.ForegroundColor = ConsoleColor.Yellow; diff --git a/LLama/ChatSession.cs b/LLama/ChatSession.cs index d77a0ffae..5aca6c435 100644 --- a/LLama/ChatSession.cs +++ b/LLama/ChatSession.cs @@ -428,7 +428,7 @@ public async IAsyncEnumerable ChatAsync( if (state.IsPromptRun) { // If the session history was added as part of new chat session history, - // convert the complete history includsing system message and manually added history + // convert the complete history including system message and manually added history // to a prompt that adhere to the prompt template specified in the HistoryTransform class implementation. prompt = HistoryTransform.HistoryToText(History); }