diff --git a/backprop/optimize_test.cc b/backprop/optimize_test.cc
index 6d83de0..a23ac84 100644
--- a/backprop/optimize_test.cc
+++ b/backprop/optimize_test.cc
@@ -74,8 +74,8 @@ TEST(OptimizeTest, GradientDescent) {
   RuntimeConfig runtime = {
       .max_generated_tokens = 16,
       .temperature = 1.0f,
-      .verbosity = 0,
       .gen = &gen,
+      .verbosity = 0,
       .stream_token = stream_token,
       .eos_id = ReverseSequenceSampler::kEndToken,
   };
diff --git a/evals/benchmark_helper.cc b/evals/benchmark_helper.cc
index 8c84f96..60cc61e 100644
--- a/evals/benchmark_helper.cc
+++ b/evals/benchmark_helper.cc
@@ -74,8 +74,8 @@ GemmaEnv::GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference,
   runtime_config_ = {
       .max_generated_tokens = inference.max_generated_tokens,
       .temperature = inference.temperature,
-      .verbosity = app.verbosity,
       .gen = &gen_,
+      .verbosity = app.verbosity,
   };
 }
 
diff --git a/evals/cross_entropy.cc b/evals/cross_entropy.cc
index 13ff3d3..6393c53 100644
--- a/evals/cross_entropy.cc
+++ b/evals/cross_entropy.cc
@@ -139,8 +139,8 @@ float ComputeCrossEntropy(Gemma& gemma, size_t max_generated_tokens,
   RuntimeConfig runtime = {
       .max_generated_tokens = max_generated_tokens - 1,
       .temperature = 0.0f,
-      .verbosity = verbosity,
       .gen = nullptr,
+      .verbosity = verbosity,
       .stream_token = stream_token,
       .sample_func = sample_token,
   };
diff --git a/evals/gemma_test.cc b/evals/gemma_test.cc
index 676a5d2..114d5a3 100644
--- a/evals/gemma_test.cc
+++ b/evals/gemma_test.cc
@@ -169,8 +169,8 @@ TEST_F(GemmaTest, Multiturn) {
   RuntimeConfig runtime_config{
       .max_generated_tokens = 64,
       .temperature = 0.0f,
-      .verbosity = 2,
       .gen = &s_env->MutableGen(),
+      .verbosity = 2,
       .stream_token = stream_token,
   };
   TimingInfo timing_info{.verbosity = 0};
diff --git a/evals/run_mmlu.cc b/evals/run_mmlu.cc
index d3618db..77c9dcd 100644
--- a/evals/run_mmlu.cc
+++ b/evals/run_mmlu.cc
@@ -127,8 +127,8 @@ void Run(GemmaEnv& env, JsonArgs& json) {
     gcpp::RuntimeConfig runtime_config = {
         .max_generated_tokens = 30,
         .temperature = 0.0f,
-        .verbosity = env.Verbosity(),
         .gen = &env.MutableGen(),
+        .verbosity = env.Verbosity(),
         .stream_token = stream_token,
     };
     env.GetModel()->Generate(runtime_config, prompt, /*pos=*/0,
diff --git a/examples/hello_world/run.cc b/examples/hello_world/run.cc
index 7b2e90f..7e9e561 100644
--- a/examples/hello_world/run.cc
+++ b/examples/hello_world/run.cc
@@ -89,8 +89,8 @@ int main(int argc, char** argv) {
   gcpp::RuntimeConfig runtime_config = {
       .max_generated_tokens = 1024,
       .temperature = 1.0,
-      .verbosity = 0,
       .gen = &gen,
+      .verbosity = 0,
      .stream_token = stream_token,
       .accept_token =
           [&](int token, float /* prob */) {
diff --git a/gemma/configs.h b/gemma/configs.h
index f6a4245..e709df7 100644
--- a/gemma/configs.h
+++ b/gemma/configs.h
@@ -178,7 +178,6 @@ struct ModelConfig {
   size_t vit_seq_len = 0;
   size_t num_tensor_scales = 0;
   size_t num_vit_scales = 0;
-  size_t top_k = kTopK;
   float att_cap = 0.0f;
   float final_cap = 0.0f;
   bool absolute_pe = false;
diff --git a/gemma/configs_test.cc b/gemma/configs_test.cc
index 91bfc53..8128baf 100644
--- a/gemma/configs_test.cc
+++ b/gemma/configs_test.cc
@@ -374,7 +374,7 @@ void AssertMatch(const ModelConfig& config) {
   }
   ASSERT_EQ(TConfig::kVocabSize, config.vocab_size);
   ASSERT_EQ(TConfig::kSeqLen, config.seq_len);
-  ASSERT_EQ(TConfig::kTopK, config.top_k);
+  // kTopK is no longer checked: top_k is now a runtime config value.
   ASSERT_EQ(TConfig::kAttCap, config.att_cap);
   ASSERT_EQ(TConfig::kFinalCap, config.final_cap);
   ASSERT_EQ(TConfig::kAbsolutePE, config.absolute_pe);
diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h
index 2b1587d..51f2999 100644
--- a/gemma/gemma-inl.h
+++ b/gemma/gemma-inl.h
@@ -1196,13 +1196,12 @@ class TokenStreamer {
   hwy::BitSet4096<> is_eos_;
 };
 
-HWY_INLINE SampleFunc ChooseSampleFunc(int top_k,
-                                       const RuntimeConfig& runtime_config) {
+HWY_INLINE SampleFunc ChooseSampleFunc(const RuntimeConfig& runtime_config) {
   // If user provided a sample_func, use it.
   if (runtime_config.sample_func) return runtime_config.sample_func;
 
   // Fast path for top-1 with no accept_token.
-  if (top_k == 1 && !runtime_config.accept_token) {
+  if (runtime_config.top_k == 1 && !runtime_config.accept_token) {
     return [](float* logits, size_t vocab_size) HWY_ATTR -> TokenAndProb {
       PROFILER_ZONE("Gen.Sample Top1");
       return Top1OfSoftmax(logits, vocab_size);
@@ -1210,13 +1209,13 @@ HWY_INLINE SampleFunc ChooseSampleFunc(int top_k,
   }
 
   // General case: Softmax with top-k sampling.
-  return [top_k, &runtime_config](float* logits,
-                                  size_t vocab_size) HWY_ATTR -> TokenAndProb {
+  return [&runtime_config](float* logits,
+                           size_t vocab_size) HWY_ATTR -> TokenAndProb {
     PROFILER_ZONE("Gen.Sample general");
     Softmax(logits, vocab_size);
-    const int token =
-        SampleTopK(logits, top_k, vocab_size, *runtime_config.gen,
-                   runtime_config.temperature, runtime_config.accept_token);
+    const int token = SampleTopK(
+        logits, runtime_config.top_k, vocab_size, *runtime_config.gen,
+        runtime_config.temperature, runtime_config.accept_token);
     return TokenAndProb{.token = token, .prob = logits[token]};
   };
 }
@@ -1276,8 +1275,7 @@ void GenerateT(const ModelWeightsStorage& model, Activations& activations,
   size_t max_prompt_size = MaxQueryLength(queries_prompt);
   size_t max_generated_tokens = runtime_config.max_generated_tokens;
   RangeChecks(weights.weights_config, max_generated_tokens, max_prompt_size);
-  const SampleFunc sample_token =
-      ChooseSampleFunc(weights.weights_config.top_k, runtime_config);
+  const SampleFunc sample_token = ChooseSampleFunc(runtime_config);
 
   // Prefill stops before min_prompt_size - 1 because the last prompt
   // token is the first input token for generation.
diff --git a/gemma/gemma.h b/gemma/gemma.h
index 5df319f..5b84053 100644
--- a/gemma/gemma.h
+++ b/gemma/gemma.h
@@ -25,6 +25,7 @@
 #include "compression/io.h"  // Path
 #include "gemma/activations.h"
 #include "gemma/common.h"
+#include "gemma/configs.h"
 #include "gemma/kv_cache.h"
 #include "gemma/tokenizer.h"
 #include "gemma/weights.h"
@@ -102,9 +103,12 @@ struct RuntimeConfig {
   // Max queries per batch (one token from each) during decode.
   size_t decode_qbatch_size = 16;
 
-  float temperature;  // Temperature for sampling.
+  // Sampling-related parameters.
+  float temperature;     // Temperature for sampling.
+  size_t top_k = kTopK;  // Top-k for sampling.
+  std::mt19937* gen;     // Random number generator used for sampling.
+
   int verbosity;  // Controls verbosity of printed messages.
-  std::mt19937* gen;  // Random number generator used for sampling.
 
   // Functions operating on the generated tokens.
   StreamFunc stream_token;
diff --git a/gemma/run.cc b/gemma/run.cc
index 2c62bdb..87c7c9d 100644
--- a/gemma/run.cc
+++ b/gemma/run.cc
@@ -99,7 +99,7 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
     HWY_ASSERT(image.ReadPPM(args.image_file.path));
     image.Resize();
     RuntimeConfig runtime_config = {
-        .verbosity = app.verbosity, .gen = &gen, .use_spinning = app.spin};
+        .gen = &gen, .verbosity = app.verbosity, .use_spinning = app.spin};
     double image_tokens_start = hwy::platform::Now();
     model.GenerateImageTokens(runtime_config, image, image_tokens);
     if (app.verbosity >= 1) {
@@ -172,8 +172,8 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
     }
 
     TimingInfo timing_info = {.verbosity = app.verbosity};
-    RuntimeConfig runtime_config = {.verbosity = app.verbosity,
-                                    .gen = &gen,
+    RuntimeConfig runtime_config = {.gen = &gen,
+                                    .verbosity = app.verbosity,
                                     .stream_token = stream_token,
                                     .accept_token = accept_token,
                                     .use_spinning = app.spin};
diff --git a/paligemma/paligemma_test.cc b/paligemma/paligemma_test.cc
index b820eec..64c0ee8 100644
--- a/paligemma/paligemma_test.cc
+++ b/paligemma/paligemma_test.cc
@@ -56,7 +56,7 @@ void PaliGemmaTest::InitVit(const std::string& path) {
   HWY_ASSERT(model.Info().training == ModelTraining::PALIGEMMA);
   HWY_ASSERT(image.ReadPPM(path));
   image.Resize();
-  RuntimeConfig runtime_config = {.verbosity = 0, .gen = &s_env->MutableGen()};
+  RuntimeConfig runtime_config = {.gen = &s_env->MutableGen(), .verbosity = 0};
   model.GenerateImageTokens(runtime_config, image, image_tokens_);
 }
 
@@ -64,8 +64,8 @@ std::string PaliGemmaTest::GemmaReply(const std::string& prompt_text) const{
   Gemma& model = *(s_env->GetModel());
   s_env->MutableGen().seed(0x12345678);
   RuntimeConfig runtime_config = {.max_generated_tokens = 512,
-                                  .verbosity = 0,
-                                  .gen = &s_env->MutableGen()};
+                                  .gen = &s_env->MutableGen(),
+                                  .verbosity = 0};
   runtime_config.image_tokens = &image_tokens_;
   size_t abs_pos = 0;
   std::string mutable_prompt = prompt_text;
diff --git a/util/app.h b/util/app.h
index ebc16b9..add8aa3 100644
--- a/util/app.h
+++ b/util/app.h
@@ -220,6 +220,7 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
   size_t decode_qbatch_size;
 
   float temperature;
+  size_t top_k;
   bool deterministic;
   bool multiturn;
   Path image_file;
@@ -244,6 +245,8 @@
             "Decode: max queries per batch.");
 
     visitor(temperature, "temperature", 1.0f, "Temperature for top-K", 2);
+    visitor(top_k, "top_k", size_t{1}, "Number of top-K tokens to sample from",
+            2);
     visitor(deterministic, "deterministic", false,
             "Make top-k sampling deterministic", 2);
     visitor(multiturn, "multiturn", false,
@@ -259,6 +262,7 @@
     runtime_config.prefill_tbatch_size = prefill_tbatch_size;
     runtime_config.decode_qbatch_size = decode_qbatch_size;
     runtime_config.temperature = temperature;
+    runtime_config.top_k = top_k;
   }
 };
 
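
Reviewer note (not part of the patch): the sketch below shows how a caller would supply top_k at runtime after this change, and how the reordered fields (gen before verbosity) affect designated initializers. It assumes only what the diff introduces; the helper name MakeRuntimeConfig and the literal values are hypothetical.

#include <random>

#include "gemma/gemma.h"  // gcpp::RuntimeConfig

// Hypothetical helper: top_k is now chosen per call via RuntimeConfig instead
// of being baked into ModelConfig. Designated initializers must follow the new
// declaration order: ..., temperature, top_k, gen, verbosity.
gcpp::RuntimeConfig MakeRuntimeConfig(std::mt19937& gen) {
  return {
      .max_generated_tokens = 128,  // illustrative value
      .temperature = 0.7f,
      .top_k = size_t{40},  // previously fixed at ModelConfig::top_k (kTopK)
      .gen = &gen,
      .verbosity = 0,
  };
}

With the matching InferenceArgs change, the same value can instead come from the new --top_k flag (default 1 here), which CopyTo() forwards into RuntimeConfig::top_k.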