Skip to content

Commit

Permalink
Add LLama2 example that uses a custom HistoryTransform (templater)
Browse files Browse the repository at this point in the history
  • Loading branch information
asmirnov82 committed Oct 2, 2024
1 parent ce8b05c commit db46d80
Show file tree
Hide file tree
Showing 4 changed files with 145 additions and 4 deletions.
1 change: 1 addition & 0 deletions LLama.Examples/ExampleRunner.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ public class ExampleRunner
private static readonly Dictionary<string, Func<Task>> Examples = new()
{
{ "Chat Session: LLama3", LLama3ChatSession.Run },
{ "Chat Session: LLama2", LLama2ChatSession.Run },
{ "Chat Session: History", ChatSessionWithHistory.Run },
{ "Chat Session: Role names", ChatSessionWithRoleName.Run },
{ "Chat Session: Role names stripped", ChatSessionStripRoleName.Run },
Expand Down
140 changes: 140 additions & 0 deletions LLama.Examples/Examples/LLama2ChatSession.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
using LLama.Abstractions;
using LLama.Common;
using LLama.Sampling;
using LLama.Transformers;
using System.Text;

namespace LLama.Examples.Examples;

/// <summary>
/// This sample shows a simple chatbot
/// It's configured to use custom prompt template as provided by llama.cpp and supports
/// models such as LLama 2 and Mistral Instruct
/// </summary>
/// <summary>
/// This sample shows a simple chatbot.
/// It's configured to use a custom prompt template (rather than the one bundled
/// with the model) and supports models such as Llama 2 and Mistral Instruct.
/// </summary>
public class LLama2ChatSession
{
    /// <summary>
    /// Runs an interactive console chat session using the Llama 2 prompt template.
    /// Type "exit" to quit.
    /// </summary>
    public static async Task Run()
    {
        var modelPath = UserSettings.GetModelPath();
        var parameters = new ModelParams(modelPath)
        {
            Seed = 1337,
            GpuLayerCount = 10
        };

        using var model = LLamaWeights.LoadFromFile(parameters);
        using var context = model.CreateContext(parameters);
        var executor = new InteractiveExecutor(context);

        // Seed the session with the canned "Bob" conversation (falls back to an
        // empty history if the JSON fails to parse).
        var chatHistoryJson = File.ReadAllText("Assets/chat-with-bob.json");
        var chatHistory = ChatHistory.FromJson(chatHistoryJson) ?? new ChatHistory();

        ChatSession session = new(executor, chatHistory);

        // Use the custom Llama 2 templater below instead of the model's built-in template.
        session.WithHistoryTransform(new Llama2HistoryTransformer());

        // Strip the end-of-turn token (or the "User:" fallback) from streamed output.
        session.WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform(
            [model.Tokens.EndOfTurnToken ?? "User:", "�"],
            redundancyLength: 5));

        var inferenceParams = new InferenceParams
        {
            SamplingPipeline = new DefaultSamplingPipeline
            {
                Temperature = 0.6f
            },

            MaxTokens = -1, // keep generating tokens until the anti prompt is encountered
            AntiPrompts = [model.Tokens.EndOfTurnToken ?? "User:"] // model specific end of turn string (or default)
        };

        Console.ForegroundColor = ConsoleColor.Yellow;
        Console.WriteLine("The chat session has started.");

        // show the prompt
        Console.ForegroundColor = ConsoleColor.Green;
        Console.Write("User> ");
        var userInput = Console.ReadLine() ?? "";

        while (userInput != "exit")
        {
            Console.ForegroundColor = ConsoleColor.White;
            Console.Write("Assistant> ");

            // as each token (partial or whole word) is streamed back, print it to the console
            await foreach (
                var text
                in session.ChatAsync(
                    new ChatHistory.Message(AuthorRole.User, userInput),
                    inferenceParams))
            {
                Console.ForegroundColor = ConsoleColor.White;
                Console.Write(text);
            }
            Console.WriteLine();

            Console.ForegroundColor = ConsoleColor.Green;
            Console.Write("User> ");
            userInput = Console.ReadLine() ?? "";
        }
    }

    /// <summary>
    /// Chat history transformer for the Llama 2 model family.
    /// Produces prompts of the form:
    /// <c>&lt;s&gt;[INST] &lt;&lt;SYS&gt;&gt;\n{system}\n&lt;&lt;/SYS&gt;&gt;\n\n{user1} [/INST] {answer1} &lt;/s&gt;&lt;s&gt;[INST] {user2} [/INST]</c>
    /// https://huggingface.co/blog/llama2#how-to-prompt-llama-2
    /// </summary>
    public class Llama2HistoryTransformer : IHistoryTransform
    {
        public string Name => "Llama2";

        /// <inheritdoc/>
        public IHistoryTransform Clone()
        {
            return new Llama2HistoryTransformer();
        }

        /// <inheritdoc/>
        public string HistoryToText(ChatHistory history)
        {
            // More info on the template format for llama2:
            // https://huggingface.co/blog/llama2#how-to-prompt-llama-2
            if (history.Messages.Count == 0)
                return string.Empty;

            var builder = new StringBuilder(64 * history.Messages.Count);

            int i = 0;
            if (history.Messages[i].AuthorRole == AuthorRole.System)
            {
                // The system prompt is folded into the first [INST] block together
                // with the first user message. Note the blank line after <</SYS>>,
                // as required by the official template.
                builder.Append("<s>[INST] <<SYS>>\n").Append(history.Messages[i].Content.Trim()).Append("\n<</SYS>>\n\n");
                i++;

                if (i < history.Messages.Count)
                {
                    // Assumes the message following the system prompt is the first
                    // user message, per the standard chat-history layout.
                    builder.Append(history.Messages[i].Content.Trim()).Append(" [/INST]");
                    i++;
                }
            }

            for (; i < history.Messages.Count; i++)
            {
                if (history.Messages[i].AuthorRole == AuthorRole.User)
                {
                    builder.Append("<s>[INST] ").Append(history.Messages[i].Content.Trim()).Append(" [/INST]");
                }
                else
                {
                    // Assistant turns are closed with </s>; the template puts a
                    // single space before the answer text.
                    builder.Append(' ').Append(history.Messages[i].Content.Trim()).Append(" </s>");
                }
            }

            return builder.ToString();
        }

        /// <inheritdoc/>
        public ChatHistory TextToHistory(AuthorRole role, string text)
        {
            return new ChatHistory([new ChatHistory.Message(role, text)]);
        }
    }
}
6 changes: 3 additions & 3 deletions LLama.Examples/Examples/LLama3ChatSession.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ namespace LLama.Examples.Examples;
/// <summary>
/// This sample shows a simple chatbot
/// It's configured to use the default prompt template as provided by llama.cpp and supports
/// models such as llama3, llama2, phi3, qwen1.5, etc.
/// models such as llama3, phi3, qwen1.5, etc.
/// </summary>
public class LLama3ChatSession
{
Expand Down Expand Up @@ -35,7 +35,7 @@ public static async Task Run()

// Add a transformer to eliminate printing the end of turn tokens, llama 3 specifically has an odd LF that gets printed sometimes
session.WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform(
[model.Tokens.EndOfTurnToken!, ""],
[model.Tokens.EndOfTurnToken ?? "User:", ""],
redundancyLength: 5));

var inferenceParams = new InferenceParams
Expand All @@ -46,7 +46,7 @@ public static async Task Run()
},

MaxTokens = -1, // keep generating tokens until the anti prompt is encountered
AntiPrompts = [model.Tokens.EndOfTurnToken!] // model specific end of turn string
AntiPrompts = [model.Tokens.EndOfTurnToken ?? "User:"] // model specific end of turn string (or default)
};

Console.ForegroundColor = ConsoleColor.Yellow;
Expand Down
2 changes: 1 addition & 1 deletion LLama/ChatSession.cs
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,7 @@ public async IAsyncEnumerable<string> ChatAsync(
if (state.IsPromptRun)
{
// If the session history was added as part of new chat session history,
// convert the complete history includsing system message and manually added history
// convert the complete history including system message and manually added history
// to a prompt that adhere to the prompt template specified in the HistoryTransform class implementation.
prompt = HistoryTransform.HistoryToText(History);
}
Expand Down

0 comments on commit db46d80

Please sign in to comment.