Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Less Sampler Allocations #735

Merged
merged 5 commits into from
May 14, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 10 additions & 9 deletions LLama.Examples/Examples/BatchedExecutorGuidance.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using LLama.Batched;
using LLama.Batched;
using LLama.Common;
using LLama.Native;
using LLama.Sampling;
Expand Down Expand Up @@ -105,18 +105,19 @@ public override ISamplingPipeline Clone()

/// <summary>
/// Modifies the raw logits for this sequence by applying guidance from a parallel
/// guidance sequence. If no guidance sequence is configured, the logits are left unchanged.
/// </summary>
/// <param name="ctx">The context handle the logits were produced by.</param>
/// <param name="logits">Logits for the current sequence; modified in place by the native guidance call.</param>
/// <param name="lastTokens">Recently returned tokens; not used by this override.</param>
protected override void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<LLamaToken> lastTokens)
{
// NOTE(review): `guidance` and `weight` are fields of the enclosing sampler class,
// declared outside this diff chunk — presumably the guidance executor/conversation
// and the guidance strength respectively. Confirm against the full file.
if (guidance == null)
return;

// Get the logits generated by the guidance sequences
var guidanceLogits = guidance.Sample();

// Use those logits to guide this sequence
NativeApi.llama_sample_apply_guidance(ctx, logits, guidanceLogits, weight);
}

protected override LLamaToken ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<LLamaToken> lastTokens)
{
if (guidance != null)
{
// Get the logits generated by the guidance sequences
var guidanceLogits = guidance.Sample();

// Modify these logits based on the guidance logits
candidates.Guidance(ctx, guidanceLogits, weight);
}

candidates.Temperature(ctx, 0.8f);
candidates.TopK(ctx, 25);

Expand Down
12 changes: 8 additions & 4 deletions LLama/LLamaContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ public uint BatchThreads
/// Get the maximum batch size for this context
/// </summary>
public uint BatchSize => NativeHandle.BatchSize;

private LLamaTokenData[]? _temporarySampling;

/// <summary>
/// Create a new LLamaContext for the given LLamaWeights
Expand Down Expand Up @@ -496,7 +498,9 @@ public LLamaTokenDataArray ApplyPenalty(int logits_i, IEnumerable<LLamaToken> la
var nl_logit = logits[(int?)nl_token ?? 0];

// Convert logits into token candidates
var candidates_p = LLamaTokenDataArray.Create(logits);
if (_temporarySampling == null || _temporarySampling.Length < logits.Length)
_temporarySampling = new LLamaTokenData[logits.Length];
var candidates_p = LLamaTokenDataArray.Create(logits, _temporarySampling);

// Extract most recently returned tokens
var last_n_repeat = Math.Min((int)ContextSize, repeatLastTokensCount);
Expand All @@ -508,14 +512,14 @@ public LLamaTokenDataArray ApplyPenalty(int logits_i, IEnumerable<LLamaToken> la
// Restore newline token logit value if necessary
if (!penalizeNL && nl_token.HasValue)
{
var candidatesSpan = candidates_p.data.Span;
for (var i = 0; i < candidates_p.data.Length; i++)
var candidatesSpan = candidates_p.Data.Span;
for (var i = 0; i < candidates_p.Data.Length; i++)
{
ref var item = ref candidatesSpan[i];
if (item.id == nl_token)
item.logit = nl_logit;
}
candidates_p.sorted = false;
candidates_p.Sorted = false;
}

return candidates_p;
Expand Down
116 changes: 68 additions & 48 deletions LLama/Native/LLamaTokenDataArray.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using System;
using System;
using System.Buffers;
using System.Runtime.InteropServices;

Expand All @@ -12,12 +12,12 @@ public struct LLamaTokenDataArray
/// <summary>
/// The LLamaTokenData
/// </summary>
public readonly Memory<LLamaTokenData> data;
public readonly Memory<LLamaTokenData> Data;

/// <summary>
/// Indicates if `data` is sorted by logits in descending order. If this is false the token data is in _no particular order_.
/// </summary>
public bool sorted;
public bool Sorted;

/// <summary>
/// Create a new LLamaTokenDataArray
Expand All @@ -26,8 +26,8 @@ public struct LLamaTokenDataArray
/// <param name="isSorted"></param>
public LLamaTokenDataArray(Memory<LLamaTokenData> tokens, bool isSorted = false)
{
data = tokens;
sorted = isSorted;
Data = tokens;
Sorted = isSorted;
}

/// <summary>
Expand All @@ -38,12 +38,32 @@ public LLamaTokenDataArray(Memory<LLamaTokenData> tokens, bool isSorted = false)
public static LLamaTokenDataArray Create(ReadOnlySpan<float> logits)
{
var candidates = new LLamaTokenData[logits.Length];
for (var token_id = 0; token_id < logits.Length; token_id++)
candidates[token_id] = new LLamaTokenData((LLamaToken)token_id, logits[token_id], 0.0f);
for (var token = 0; token < logits.Length; token++)
candidates[token] = new LLamaTokenData(token, logits[token], 0.0f);

return new LLamaTokenDataArray(candidates);
}

/// <summary>
/// Create a new LLamaTokenDataArray, copying the data from the given logits into temporary memory.
/// </summary>
/// <remarks>The memory must not be modified while this <see cref="LLamaTokenDataArray"/> is in use.</remarks>
/// <param name="logits"></param>
/// <param name="temporary">Temporary memory which will be used to work on these logits. Must be at least as large as logits array</param>
/// <returns></returns>
public static LLamaTokenDataArray Create(ReadOnlySpan<float> logits, Memory<LLamaTokenData> temporary)
{
martindevans marked this conversation as resolved.
Show resolved Hide resolved
if (temporary.Length < logits.Length)
throw new ArgumentException("temporary memory is shorter than logits span");
var candidates = temporary.Slice(0, logits.Length);

var candidatesSpan = candidates.Span;
for (var token = 0; token < logits.Length; token++)
martindevans marked this conversation as resolved.
Show resolved Hide resolved
candidatesSpan[token] = new LLamaTokenData(token, logits[token], 0.0f);

return new LLamaTokenDataArray(candidates);
}

/// <summary>
/// Overwrite the logit values for all given tokens
/// </summary>
Expand All @@ -53,10 +73,10 @@ public void OverwriteLogits(ReadOnlySpan<(LLamaToken token, float logit)> values
if (values.Length == 0)
return;

var dataSpan = data.Span;
var dataSpan = Data.Span;
foreach (var (token, value) in values)
{
for (var i = 0; i < data.Length; i++)
for (var i = 0; i < Data.Length; i++)
{
if (dataSpan[i].id == token)
{
Expand All @@ -65,7 +85,7 @@ public void OverwriteLogits(ReadOnlySpan<(LLamaToken token, float logit)> values
}
}
}
sorted = false;
Sorted = false;
}

#region sampling
Expand All @@ -82,7 +102,7 @@ public void ApplyGrammar(SafeLLamaContextHandle ctx, SafeLLamaGrammarHandle? gra
using (LLamaTokenDataArrayNative.Create(this, out var st))
{
NativeApi.llama_sample_grammar(ctx, ref st, grammar);
sorted = st.sorted;
Sorted = st.sorted;
}
}

Expand All @@ -97,7 +117,7 @@ public void TopK(SafeLLamaContextHandle context, int k, ulong minKeep = 1)
using (LLamaTokenDataArrayNative.Create(this, out var st))
{
NativeApi.llama_sample_top_k(context, ref st, k, minKeep);
sorted = st.sorted;
Sorted = st.sorted;
}
}

Expand All @@ -112,7 +132,7 @@ public void TopP(SafeLLamaContextHandle context, float p, ulong minKeep = 1)
using (LLamaTokenDataArrayNative.Create(this, out var st))
{
NativeApi.llama_sample_top_p(context, ref st, p, minKeep);
sorted = st.sorted;
Sorted = st.sorted;
}
}

Expand All @@ -127,7 +147,7 @@ public void MinP(SafeLLamaContextHandle context, float p, ulong minKeep = 1)
using (LLamaTokenDataArrayNative.Create(this, out var st))
{
NativeApi.llama_sample_min_p(context, ref st, p, minKeep);
sorted = st.sorted;
Sorted = st.sorted;
}
}

Expand All @@ -136,13 +156,13 @@ public void MinP(SafeLLamaContextHandle context, float p, ulong minKeep = 1)
/// </summary>
/// <param name="context"></param>
/// <param name="z"></param>
/// <param name="min_keep"></param>
public void TailFree(SafeLLamaContextHandle context, float z, ulong min_keep = 1)
/// <param name="minKeep"></param>
public void TailFree(SafeLLamaContextHandle context, float z, ulong minKeep = 1)
{
using (LLamaTokenDataArrayNative.Create(this, out var st))
{
NativeApi.llama_sample_tail_free(context, ref st, z, min_keep);
sorted = st.sorted;
NativeApi.llama_sample_tail_free(context, ref st, z, minKeep);
Sorted = st.sorted;
}
}

Expand All @@ -151,13 +171,13 @@ public void TailFree(SafeLLamaContextHandle context, float z, ulong min_keep = 1
/// </summary>
/// <param name="context"></param>
/// <param name="p"></param>
/// <param name="min_keep"></param>
public void LocallyTypical(SafeLLamaContextHandle context, float p, ulong min_keep = 1)
/// <param name="minKeep"></param>
public void LocallyTypical(SafeLLamaContextHandle context, float p, ulong minKeep = 1)
{
using (LLamaTokenDataArrayNative.Create(this, out var st))
{
NativeApi.llama_sample_typical(context, ref st, p, min_keep);
sorted = st.sorted;
NativeApi.llama_sample_typical(context, ref st, p, minKeep);
Sorted = st.sorted;
}
}

Expand All @@ -166,20 +186,20 @@ public void LocallyTypical(SafeLLamaContextHandle context, float p, ulong min_ke
/// Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
/// </summary>
/// <param name="context"></param>
/// <param name="last_tokens"></param>
/// <param name="penalty_repeat"></param>
/// <param name="penalty_freq"></param>
/// <param name="penalty_present"></param>
public void RepetitionPenalty(SafeLLamaContextHandle context, ReadOnlySpan<LLamaToken> last_tokens, float penalty_repeat, float penalty_freq, float penalty_present)
/// <param name="lastTokens"></param>
/// <param name="penaltyRepeat"></param>
/// <param name="penaltyFreq"></param>
/// <param name="penaltyPresent"></param>
public void RepetitionPenalty(SafeLLamaContextHandle context, ReadOnlySpan<LLamaToken> lastTokens, float penaltyRepeat, float penaltyFreq, float penaltyPresent)
{
unsafe
{
using (LLamaTokenDataArrayNative.Create(this, out var st))
{
fixed (LLamaToken* last_tokens_handle = last_tokens)
fixed (LLamaToken* lastTokensHandle = lastTokens)
{
NativeApi.llama_sample_repetition_penalties(context, ref st, last_tokens_handle, (ulong)last_tokens.Length, penalty_repeat, penalty_freq, penalty_present);
sorted = st.sorted;
NativeApi.llama_sample_repetition_penalties(context, ref st, lastTokensHandle, (ulong)lastTokens.Length, penaltyRepeat, penaltyFreq, penaltyPresent);
Sorted = st.sorted;
}
}
}
Expand All @@ -194,7 +214,7 @@ public void RepetitionPenalty(SafeLLamaContextHandle context, ReadOnlySpan<LLama
/// <param name="guidance">Guidance strength. 0 means no guidance, higher values applies stronger guidance</param>
public void Guidance(SafeLLamaContextHandle context, ReadOnlySpan<float> guidanceLogits, float guidance)
{
if (guidanceLogits.Length != data.Length)
if (guidanceLogits.Length != Data.Length)
throw new ArgumentException("Guidance logits count must equal vocabulary size", nameof(guidanceLogits));

if (guidance < 0)
Expand All @@ -210,24 +230,24 @@ public void Guidance(SafeLLamaContextHandle context, ReadOnlySpan<float> guidanc
try
{
// Copy logits into a temporary array
for (var i = 0; i < data.Length; i++)
for (var i = 0; i < Data.Length; i++)
{
ref var item = ref data.Span[i];
ref var item = ref Data.Span[i];
logits[(int)item.id] = item.logit;
}

// Apply guidance
NativeApi.llama_sample_apply_guidance(context, logits, guidanceLogits, guidance);
NativeApi.llama_sample_apply_guidance(context, logits.AsSpan(0, context.VocabCount), guidanceLogits, guidance);

// Copy logits back into data array
for (var i = 0; i < data.Length; i++)
for (var i = 0; i < Data.Length; i++)
{
ref var item = ref data.Span[i];
ref var item = ref Data.Span[i];
item.logit = logits[(int)item.id];
}

// No longer sorted since we just mutated logits!
sorted = false;
Sorted = false;
}
finally
{
Expand All @@ -246,7 +266,7 @@ public void Temperature(SafeLLamaContextHandle context, float temp)
using (LLamaTokenDataArrayNative.Create(this, out var st))
{
NativeApi.llama_sample_temp(context, ref st, temp);
sorted = st.sorted;
Sorted = st.sorted;
}
}

Expand All @@ -259,7 +279,7 @@ public void Softmax(SafeLLamaContextHandle context)
using (LLamaTokenDataArrayNative.Create(this, out var st))
{
NativeApi.llama_sample_softmax(context, ref st);
sorted = st.sorted;
Sorted = st.sorted;
}
}

Expand All @@ -273,7 +293,7 @@ public LLamaToken SampleToken(SafeLLamaContextHandle context)
using (LLamaTokenDataArrayNative.Create(this, out var st))
{
var token = NativeApi.llama_sample_token(context, ref st);
sorted = st.sorted;
Sorted = st.sorted;
return token;
}
}
Expand All @@ -288,7 +308,7 @@ public LLamaToken SampleTokenGreedy(SafeLLamaContextHandle context)
using (LLamaTokenDataArrayNative.Create(this, out var st))
{
var token = NativeApi.llama_sample_token_greedy(context, ref st);
sorted = st.sorted;
Sorted = st.sorted;
return token;
}
}
Expand All @@ -307,7 +327,7 @@ public LLamaToken SampleTokenMirostat(SafeLLamaContextHandle context, float tau,
using (LLamaTokenDataArrayNative.Create(this, out var st))
{
var token = NativeApi.llama_sample_token_mirostat(context, ref st, tau, eta, m, ref mu);
sorted = st.sorted;
Sorted = st.sorted;
return token;
}
}
Expand All @@ -325,7 +345,7 @@ public LLamaToken SampleTokenMirostat2(SafeLLamaContextHandle context, float tau
using (LLamaTokenDataArrayNative.Create(this, out var st))
{
var token = NativeApi.llama_sample_token_mirostat_v2(context, ref st, tau, eta, ref mu);
sorted = st.sorted;
Sorted = st.sorted;
return token;
}
}
Expand All @@ -342,7 +362,7 @@ public struct LLamaTokenDataArrayNative
/// A pointer to an array of LlamaTokenData
/// </summary>
/// <remarks>Memory must be pinned in place for all the time this LLamaTokenDataArrayNative is in use</remarks>
public IntPtr data;
public unsafe LLamaTokenData* data;

AsakusaRinne marked this conversation as resolved.
Show resolved Hide resolved
/// <summary>
/// Number of LLamaTokenData in the array
Expand All @@ -367,15 +387,15 @@ public bool sorted
/// <returns>A memory handle, pinning the data in place until disposed</returns>
public static MemoryHandle Create(LLamaTokenDataArray array, out LLamaTokenDataArrayNative native)
{
var handle = array.data.Pin();
var handle = array.Data.Pin();

unsafe
{
native = new LLamaTokenDataArrayNative
{
data = new IntPtr(handle.Pointer),
size = (ulong)array.data.Length,
sorted = array.sorted
data = (LLamaTokenData*)handle.Pointer,
size = (ulong)array.Data.Length,
sorted = array.Sorted
};
}

Expand Down
Loading
Loading