diff --git a/LLama.Examples/Examples/BatchedExecutorGuidance.cs b/LLama.Examples/Examples/BatchedExecutorGuidance.cs index b82379c5b..665066cc3 100644 --- a/LLama.Examples/Examples/BatchedExecutorGuidance.cs +++ b/LLama.Examples/Examples/BatchedExecutorGuidance.cs @@ -1,4 +1,4 @@ -using LLama.Batched; +using LLama.Batched; using LLama.Common; using LLama.Native; using LLama.Sampling; @@ -105,18 +105,19 @@ public override ISamplingPipeline Clone() protected override void ProcessLogits(SafeLLamaContextHandle ctx, Span logits, ReadOnlySpan lastTokens) { - if (guidance == null) - return; - - // Get the logits generated by the guidance sequences - var guidanceLogits = guidance.Sample(); - - // Use those logits to guide this sequence - NativeApi.llama_sample_apply_guidance(ctx, logits, guidanceLogits, weight); } protected override LLamaToken ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan lastTokens) { + if (guidance != null) + { + // Get the logits generated by the guidance sequences + var guidanceLogits = guidance.Sample(); + + // Modify these logits based on the guidance logits + candidates.Guidance(ctx, guidanceLogits, weight); + } + candidates.Temperature(ctx, 0.8f); candidates.TopK(ctx, 25); diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs index 4f8e7d94a..6335c1dc7 100644 --- a/LLama/LLamaContext.cs +++ b/LLama/LLamaContext.cs @@ -89,6 +89,8 @@ public uint BatchThreads /// Get the maximum batch size for this context /// public uint BatchSize => NativeHandle.BatchSize; + + private LLamaTokenData[]? _samplingBuffer; /// /// Create a new LLamaContext for the given LLamaWeights @@ -496,7 +498,9 @@ public LLamaTokenDataArray ApplyPenalty(int logits_i, IEnumerable la var nl_logit = logits[(int?)nl_token ?? 0]; // Convert logits into token candidates - var candidates_p = LLamaTokenDataArray.Create(logits); + if (_samplingBuffer == null || _samplingBuffer.Length < logits.Length) + _samplingBuffer = new LLamaTokenData[logits.Length]; + var candidates_p = LLamaTokenDataArray.Create(logits, _samplingBuffer); // Extract most recently returned tokens var last_n_repeat = Math.Min((int)ContextSize, repeatLastTokensCount); @@ -508,14 +512,14 @@ public LLamaTokenDataArray ApplyPenalty(int logits_i, IEnumerable la // Restore newline token logit value if necessary if (!penalizeNL && nl_token.HasValue) { - var candidatesSpan = candidates_p.data.Span; - for (var i = 0; i < candidates_p.data.Length; i++) + var candidatesSpan = candidates_p.Data.Span; + for (var i = 0; i < candidates_p.Data.Length; i++) { ref var item = ref candidatesSpan[i]; if (item.id == nl_token) item.logit = nl_logit; } - candidates_p.sorted = false; + candidates_p.Sorted = false; } return candidates_p; diff --git a/LLama/Native/LLamaTokenDataArray.cs b/LLama/Native/LLamaTokenDataArray.cs index f36679255..1a656c42e 100644 --- a/LLama/Native/LLamaTokenDataArray.cs +++ b/LLama/Native/LLamaTokenDataArray.cs @@ -1,4 +1,4 @@ -using System; +using System; using System.Buffers; using System.Runtime.InteropServices; @@ -12,12 +12,12 @@ public struct LLamaTokenDataArray /// /// The LLamaTokenData /// - public readonly Memory data; + public readonly Memory Data; /// /// Indicates if `data` is sorted by logits in descending order. If this is false the token data is in _no particular order_. /// - public bool sorted; + public bool Sorted; /// /// Create a new LLamaTokenDataArray @@ -26,8 +26,8 @@ public struct LLamaTokenDataArray /// public LLamaTokenDataArray(Memory tokens, bool isSorted = false) { - data = tokens; - sorted = isSorted; + Data = tokens; + Sorted = isSorted; } /// @@ -37,13 +37,31 @@ public LLamaTokenDataArray(Memory tokens, bool isSorted = false) /// public static LLamaTokenDataArray Create(ReadOnlySpan logits) { - var candidates = new LLamaTokenData[logits.Length]; - for (var token_id = 0; token_id < logits.Length; token_id++) - candidates[token_id] = new LLamaTokenData((LLamaToken)token_id, logits[token_id], 0.0f); + return Create(logits, new LLamaTokenData[logits.Length]); + } + /// + /// Create a new LLamaTokenDataArray, copying the data from the given logits into temporary memory. + /// + /// The memory must not be modified while this is in use. + /// + /// Temporary memory which will be used to work on these logits. Must be at least as large as logits array + /// + public static LLamaTokenDataArray Create(ReadOnlySpan logits, Memory buffer) + { + if (buffer.Length < logits.Length) + throw new ArgumentException("temporary memory is shorter than logits span"); + + // take a slice of the output buffer which is exactly the size we need. + var candidates = buffer.Slice(0, logits.Length); + var candidatesSpan = candidates.Span; + + for (var token = 0; token < logits.Length; token++) + candidatesSpan[token] = new LLamaTokenData(token, logits[token], 0.0f); + return new LLamaTokenDataArray(candidates); } - + /// /// Overwrite the logit values for all given tokens /// @@ -53,10 +71,10 @@ public void OverwriteLogits(ReadOnlySpan<(LLamaToken token, float logit)> values if (values.Length == 0) return; - var dataSpan = data.Span; + var dataSpan = Data.Span; foreach (var (token, value) in values) { - for (var i = 0; i < data.Length; i++) + for (var i = 0; i < Data.Length; i++) { if (dataSpan[i].id == token) { @@ -65,7 +83,7 @@ public void OverwriteLogits(ReadOnlySpan<(LLamaToken token, float logit)> values } } } - sorted = false; + Sorted = false; } #region sampling @@ -82,7 +100,7 @@ public void ApplyGrammar(SafeLLamaContextHandle ctx, SafeLLamaGrammarHandle? gra using (LLamaTokenDataArrayNative.Create(this, out var st)) { NativeApi.llama_sample_grammar(ctx, ref st, grammar); - sorted = st.sorted; + Sorted = st.sorted; } } @@ -97,7 +115,7 @@ public void TopK(SafeLLamaContextHandle context, int k, ulong minKeep = 1) using (LLamaTokenDataArrayNative.Create(this, out var st)) { NativeApi.llama_sample_top_k(context, ref st, k, minKeep); - sorted = st.sorted; + Sorted = st.sorted; } } @@ -112,7 +130,7 @@ public void TopP(SafeLLamaContextHandle context, float p, ulong minKeep = 1) using (LLamaTokenDataArrayNative.Create(this, out var st)) { NativeApi.llama_sample_top_p(context, ref st, p, minKeep); - sorted = st.sorted; + Sorted = st.sorted; } } @@ -127,7 +145,7 @@ public void MinP(SafeLLamaContextHandle context, float p, ulong minKeep = 1) using (LLamaTokenDataArrayNative.Create(this, out var st)) { NativeApi.llama_sample_min_p(context, ref st, p, minKeep); - sorted = st.sorted; + Sorted = st.sorted; } } @@ -136,13 +154,13 @@ public void MinP(SafeLLamaContextHandle context, float p, ulong minKeep = 1) /// /// /// - /// - public void TailFree(SafeLLamaContextHandle context, float z, ulong min_keep = 1) + /// + public void TailFree(SafeLLamaContextHandle context, float z, ulong minKeep = 1) { using (LLamaTokenDataArrayNative.Create(this, out var st)) { - NativeApi.llama_sample_tail_free(context, ref st, z, min_keep); - sorted = st.sorted; + NativeApi.llama_sample_tail_free(context, ref st, z, minKeep); + Sorted = st.sorted; } } @@ -151,13 +169,13 @@ public void TailFree(SafeLLamaContextHandle context, float z, ulong min_keep = 1 /// /// /// - /// - public void LocallyTypical(SafeLLamaContextHandle context, float p, ulong min_keep = 1) + /// + public void LocallyTypical(SafeLLamaContextHandle context, float p, ulong minKeep = 1) { using (LLamaTokenDataArrayNative.Create(this, out var st)) { - NativeApi.llama_sample_typical(context, ref st, p, min_keep); - sorted = st.sorted; + NativeApi.llama_sample_typical(context, ref st, p, minKeep); + Sorted = st.sorted; } } @@ -166,20 +184,20 @@ public void LocallyTypical(SafeLLamaContextHandle context, float p, ulong min_ke /// Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details. /// /// - /// - /// - /// - /// - public void RepetitionPenalty(SafeLLamaContextHandle context, ReadOnlySpan last_tokens, float penalty_repeat, float penalty_freq, float penalty_present) + /// + /// + /// + /// + public void RepetitionPenalty(SafeLLamaContextHandle context, ReadOnlySpan lastTokens, float penaltyRepeat, float penaltyFreq, float penaltyPresent) { unsafe { using (LLamaTokenDataArrayNative.Create(this, out var st)) { - fixed (LLamaToken* last_tokens_handle = last_tokens) + fixed (LLamaToken* lastTokensHandle = lastTokens) { - NativeApi.llama_sample_repetition_penalties(context, ref st, last_tokens_handle, (ulong)last_tokens.Length, penalty_repeat, penalty_freq, penalty_present); - sorted = st.sorted; + NativeApi.llama_sample_repetition_penalties(context, ref st, lastTokensHandle, (ulong)lastTokens.Length, penaltyRepeat, penaltyFreq, penaltyPresent); + Sorted = st.sorted; } } } @@ -194,7 +212,7 @@ public void RepetitionPenalty(SafeLLamaContextHandle context, ReadOnlySpanGuidance strength. 0 means no guidance, higher values applies stronger guidance public void Guidance(SafeLLamaContextHandle context, ReadOnlySpan guidanceLogits, float guidance) { - if (guidanceLogits.Length != data.Length) + if (guidanceLogits.Length != Data.Length) throw new ArgumentException("Guidance logits count must equal vocabulary size", nameof(guidanceLogits)); if (guidance < 0) @@ -210,24 +228,24 @@ public void Guidance(SafeLLamaContextHandle context, ReadOnlySpan guidanc try { // Copy logits into a temporary array - for (var i = 0; i < data.Length; i++) + for (var i = 0; i < Data.Length; i++) { - ref var item = ref data.Span[i]; + ref var item = ref Data.Span[i]; logits[(int)item.id] = item.logit; } // Apply guidance - NativeApi.llama_sample_apply_guidance(context, logits, guidanceLogits, guidance); + NativeApi.llama_sample_apply_guidance(context, logits.AsSpan(0, context.VocabCount), guidanceLogits, guidance); // Copy logits back into data array - for (var i = 0; i < data.Length; i++) + for (var i = 0; i < Data.Length; i++) { - ref var item = ref data.Span[i]; + ref var item = ref Data.Span[i]; item.logit = logits[(int)item.id]; } // No longer sorted since we just mutated logits! - sorted = false; + Sorted = false; } finally { @@ -246,7 +264,7 @@ public void Temperature(SafeLLamaContextHandle context, float temp) using (LLamaTokenDataArrayNative.Create(this, out var st)) { NativeApi.llama_sample_temp(context, ref st, temp); - sorted = st.sorted; + Sorted = st.sorted; } } @@ -259,7 +277,7 @@ public void Softmax(SafeLLamaContextHandle context) using (LLamaTokenDataArrayNative.Create(this, out var st)) { NativeApi.llama_sample_softmax(context, ref st); - sorted = st.sorted; + Sorted = st.sorted; } } @@ -273,7 +291,7 @@ public LLamaToken SampleToken(SafeLLamaContextHandle context) using (LLamaTokenDataArrayNative.Create(this, out var st)) { var token = NativeApi.llama_sample_token(context, ref st); - sorted = st.sorted; + Sorted = st.sorted; return token; } } @@ -288,7 +306,7 @@ public LLamaToken SampleTokenGreedy(SafeLLamaContextHandle context) using (LLamaTokenDataArrayNative.Create(this, out var st)) { var token = NativeApi.llama_sample_token_greedy(context, ref st); - sorted = st.sorted; + Sorted = st.sorted; return token; } } @@ -307,7 +325,7 @@ public LLamaToken SampleTokenMirostat(SafeLLamaContextHandle context, float tau, using (LLamaTokenDataArrayNative.Create(this, out var st)) { var token = NativeApi.llama_sample_token_mirostat(context, ref st, tau, eta, m, ref mu); - sorted = st.sorted; + Sorted = st.sorted; return token; } } @@ -325,7 +343,7 @@ public LLamaToken SampleTokenMirostat2(SafeLLamaContextHandle context, float tau using (LLamaTokenDataArrayNative.Create(this, out var st)) { var token = NativeApi.llama_sample_token_mirostat_v2(context, ref st, tau, eta, ref mu); - sorted = st.sorted; + Sorted = st.sorted; return token; } } @@ -341,14 +359,28 @@ public struct LLamaTokenDataArrayNative /// /// A pointer to an array of LlamaTokenData /// - /// Memory must be pinned in place for all the time this LLamaTokenDataArrayNative is in use - public IntPtr data; + /// Memory must be pinned in place for all the time this LLamaTokenDataArrayNative is in use (i.e. `fixed` or `.Pin()`) + private unsafe LLamaTokenData* _data; /// /// Number of LLamaTokenData in the array /// public ulong size; - + + /// + /// A pointer to an array of LlamaTokenData + /// + public Span data + { + get + { + unsafe + { + return new Span(_data, checked((int)size)); + } + } + } + /// /// Indicates if the items in the array are sorted /// @@ -367,15 +399,15 @@ public bool sorted /// A memory handle, pinning the data in place until disposed public static MemoryHandle Create(LLamaTokenDataArray array, out LLamaTokenDataArrayNative native) { - var handle = array.data.Pin(); + var handle = array.Data.Pin(); unsafe { native = new LLamaTokenDataArrayNative { - data = new IntPtr(handle.Pointer), - size = (ulong)array.data.Length, - sorted = array.sorted + _data = (LLamaTokenData*)handle.Pointer, + size = (ulong)array.Data.Length, + sorted = array.Sorted }; } diff --git a/LLama/Sampling/BaseSamplingPipeline.cs b/LLama/Sampling/BaseSamplingPipeline.cs index 4dcd81277..bab4bbabc 100644 --- a/LLama/Sampling/BaseSamplingPipeline.cs +++ b/LLama/Sampling/BaseSamplingPipeline.cs @@ -1,4 +1,4 @@ -using System; +using System; using LLama.Native; namespace LLama.Sampling; @@ -13,6 +13,8 @@ public abstract class BaseSamplingPipeline /// Grammar to constrain valid tokens /// public SafeLLamaGrammarHandle? Grammar { get; set; } + + private LLamaTokenData[]? _temporarySampling; /// public LLamaToken Sample(SafeLLamaContextHandle ctx, Span logits, ReadOnlySpan lastTokens) @@ -20,8 +22,12 @@ public LLamaToken Sample(SafeLLamaContextHandle ctx, Span logits, ReadOnl // Apply processing to raw logit values ProcessLogits(ctx, logits, lastTokens); + // Allocate some temporary space + if (_temporarySampling == null || _temporarySampling.Length < logits.Length) + _temporarySampling = new LLamaTokenData[logits.Length]; + // Process token data array to select a final token - var candidates = LLamaTokenDataArray.Create(logits); + var candidates = LLamaTokenDataArray.Create(logits, _temporarySampling); candidates.ApplyGrammar(ctx, Grammar); return ProcessTokenDataArray(ctx, candidates, lastTokens); } diff --git a/LLama/Sampling/DefaultSamplingPipeline.cs b/LLama/Sampling/DefaultSamplingPipeline.cs index 33806f5f9..0413dd972 100644 --- a/LLama/Sampling/DefaultSamplingPipeline.cs +++ b/LLama/Sampling/DefaultSamplingPipeline.cs @@ -1,4 +1,4 @@ -using System; +using System; using System.Collections.Generic; using LLama.Extensions; using LLama.Native; @@ -138,11 +138,11 @@ private static (int, float) GetNewlineLogit(SafeLLamaContextHandle ctx, LLamaTok if (nlToken.HasValue) { // Try using the ID as an index - if (candidates.data.Span[(int)nlToken].id == nlToken) - return ((int)nlToken, candidates.data.Span[(int)nlToken].logit); + if (candidates.Data.Span[(int)nlToken].id == nlToken) + return ((int)nlToken, candidates.Data.Span[(int)nlToken].logit); // Exhaustive search - var span = candidates.data.Span; + var span = candidates.Data.Span; for (var i = 0; i < span.Length; i++) { if (span[i].id == nlToken) @@ -160,15 +160,15 @@ private static void SetNewlineLogit(SafeLLamaContextHandle ctx, LLamaTokenDataAr return; // Try checking the index where we found it last time. It might not be there if `RepetitionPenalty` changed order - if (indexHint >= 0 && candidates.data.Span[indexHint].id == nlToken) + if (indexHint >= 0 && candidates.Data.Span[indexHint].id == nlToken) { - candidates.data.Span[indexHint].logit = logit; + candidates.Data.Span[indexHint].logit = logit; return; } // Didn't find it, do an exhaustive search for it - var span = candidates.data.Span; - for (var i = 0; i < candidates.data.Length; i++) + var span = candidates.Data.Span; + for (var i = 0; i < candidates.Data.Length; i++) { if (span[i].id == nlToken) {