Make top_k a runtime argument (instead of a model argument).
PiperOrigin-RevId: 696142888
danielkeysers authored and copybara-github committed Nov 13, 2024
1 parent e54d9cb commit 18364c4
Showing 13 changed files with 31 additions and 26 deletions.
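
The practical effect: the sampling width is now chosen per call via RuntimeConfig instead of being fixed in ModelConfig. A minimal call-site sketch, assuming the field names and declaration order introduced in gemma/gemma.h below (kTopK stays the default when .top_k is omitted; model/tokenizer setup and the caller's stream_token are elided, as in examples/hello_world/run.cc):

  std::mt19937 gen(42);
  gcpp::RuntimeConfig runtime_config = {
      .max_generated_tokens = 128,
      .temperature = 0.7f,
      .top_k = 40,  // previously ModelConfig::top_k
      .gen = &gen,
      .verbosity = 0,
      .stream_token = stream_token,
  };
  // The config is then passed to Gemma::Generate() exactly as before.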
2 changes: 1 addition & 1 deletion backprop/optimize_test.cc
@@ -74,8 +74,8 @@ TEST(OptimizeTest, GradientDescent) {
   RuntimeConfig runtime = {
       .max_generated_tokens = 16,
       .temperature = 1.0f,
-      .verbosity = 0,
       .gen = &gen,
+      .verbosity = 0,
       .stream_token = stream_token,
       .eos_id = ReverseSequenceSampler::kEndToken,
   };
2 changes: 1 addition & 1 deletion evals/benchmark_helper.cc
@@ -74,8 +74,8 @@ GemmaEnv::GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference,
   runtime_config_ = {
       .max_generated_tokens = inference.max_generated_tokens,
       .temperature = inference.temperature,
-      .verbosity = app.verbosity,
       .gen = &gen_,
+      .verbosity = app.verbosity,
   };
 }

2 changes: 1 addition & 1 deletion evals/cross_entropy.cc
@@ -139,8 +139,8 @@ float ComputeCrossEntropy(Gemma& gemma, size_t max_generated_tokens,
   RuntimeConfig runtime = {
       .max_generated_tokens = max_generated_tokens - 1,
       .temperature = 0.0f,
-      .verbosity = verbosity,
       .gen = nullptr,
+      .verbosity = verbosity,
       .stream_token = stream_token,
       .sample_func = sample_token,
   };
2 changes: 1 addition & 1 deletion evals/gemma_test.cc
@@ -169,8 +169,8 @@ TEST_F(GemmaTest, Multiturn) {
   RuntimeConfig runtime_config{
       .max_generated_tokens = 64,
       .temperature = 0.0f,
-      .verbosity = 2,
       .gen = &s_env->MutableGen(),
+      .verbosity = 2,
       .stream_token = stream_token,
   };
   TimingInfo timing_info{.verbosity = 0};
2 changes: 1 addition & 1 deletion evals/run_mmlu.cc
@@ -127,8 +127,8 @@ void Run(GemmaEnv& env, JsonArgs& json) {
   gcpp::RuntimeConfig runtime_config = {
       .max_generated_tokens = 30,
       .temperature = 0.0f,
-      .verbosity = env.Verbosity(),
       .gen = &env.MutableGen(),
+      .verbosity = env.Verbosity(),
       .stream_token = stream_token,
   };
   env.GetModel()->Generate(runtime_config, prompt, /*pos=*/0,
2 changes: 1 addition & 1 deletion examples/hello_world/run.cc
@@ -89,8 +89,8 @@ int main(int argc, char** argv) {
   gcpp::RuntimeConfig runtime_config = {
       .max_generated_tokens = 1024,
       .temperature = 1.0,
-      .verbosity = 0,
       .gen = &gen,
+      .verbosity = 0,
       .stream_token = stream_token,
       .accept_token =
           [&](int token, float /* prob */) {
1 change: 0 additions & 1 deletion gemma/configs.h
@@ -178,7 +178,6 @@ struct ModelConfig {
   size_t vit_seq_len = 0;
   size_t num_tensor_scales = 0;
   size_t num_vit_scales = 0;
-  size_t top_k = kTopK;
   float att_cap = 0.0f;
   float final_cap = 0.0f;
   bool absolute_pe = false;
2 changes: 1 addition & 1 deletion gemma/configs_test.cc
@@ -374,7 +374,7 @@ void AssertMatch(const ModelConfig& config) {
   }
   ASSERT_EQ(TConfig::kVocabSize, config.vocab_size);
   ASSERT_EQ(TConfig::kSeqLen, config.seq_len);
-  ASSERT_EQ(TConfig::kTopK, config.top_k);
+  // ASSERT_EQ(TConfig::kTopK, config.top_k); - is now a runtime config value.
   ASSERT_EQ(TConfig::kAttCap, config.att_cap);
   ASSERT_EQ(TConfig::kFinalCap, config.final_cap);
   ASSERT_EQ(TConfig::kAbsolutePE, config.absolute_pe);
18 changes: 8 additions & 10 deletions gemma/gemma-inl.h
@@ -1196,27 +1196,26 @@ class TokenStreamer {
   hwy::BitSet4096<> is_eos_;
 };

-HWY_INLINE SampleFunc ChooseSampleFunc(int top_k,
-                                       const RuntimeConfig& runtime_config) {
+HWY_INLINE SampleFunc ChooseSampleFunc(const RuntimeConfig& runtime_config) {
   // If user provided a sample_func, use it.
   if (runtime_config.sample_func) return runtime_config.sample_func;

   // Fast path for top-1 with no accept_token.
-  if (top_k == 1 && !runtime_config.accept_token) {
+  if (runtime_config.top_k == 1 && !runtime_config.accept_token) {
     return [](float* logits, size_t vocab_size) HWY_ATTR -> TokenAndProb {
       PROFILER_ZONE("Gen.Sample Top1");
       return Top1OfSoftmax(logits, vocab_size);
     };
   }

   // General case: Softmax with top-k sampling.
-  return [top_k, &runtime_config](float* logits,
-                                  size_t vocab_size) HWY_ATTR -> TokenAndProb {
+  return [&runtime_config](float* logits,
+                           size_t vocab_size) HWY_ATTR -> TokenAndProb {
     PROFILER_ZONE("Gen.Sample general");
     Softmax(logits, vocab_size);
-    const int token =
-        SampleTopK(logits, top_k, vocab_size, *runtime_config.gen,
-                   runtime_config.temperature, runtime_config.accept_token);
+    const int token = SampleTopK(
+        logits, runtime_config.top_k, vocab_size, *runtime_config.gen,
+        runtime_config.temperature, runtime_config.accept_token);
     return TokenAndProb{.token = token, .prob = logits[token]};
   };
 }
@@ -1276,8 +1275,7 @@ void GenerateT(const ModelWeightsStorage& model, Activations& activations,
   size_t max_prompt_size = MaxQueryLength(queries_prompt);
   size_t max_generated_tokens = runtime_config.max_generated_tokens;
   RangeChecks(weights.weights_config, max_generated_tokens, max_prompt_size);
-  const SampleFunc sample_token =
-      ChooseSampleFunc(weights.weights_config.top_k, runtime_config);
+  const SampleFunc sample_token = ChooseSampleFunc(runtime_config);

   // Prefill stops before min_prompt_size - 1 because the last prompt
   // token is the first input token for generation.
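
For readers unfamiliar with SampleTopK, the general-case lambda above amounts to: softmax the logits, restrict the choice to the top_k highest-probability tokens (subject to accept_token), and draw one of them in proportion to its probability. A stand-alone illustrative sketch of that idea, not the library's implementation (temperature and accept_token omitted for brevity):

  #include <algorithm>
  #include <numeric>
  #include <random>
  #include <vector>

  // Sample an index from `probs` restricted to its `top_k` largest entries.
  int SampleTopKSketch(const std::vector<float>& probs, size_t top_k,
                       std::mt19937& gen) {
    std::vector<int> idx(probs.size());
    std::iota(idx.begin(), idx.end(), 0);
    // Partially sort so the k most probable token ids come first.
    std::partial_sort(idx.begin(), idx.begin() + top_k, idx.end(),
                      [&](int a, int b) { return probs[a] > probs[b]; });
    // Renormalize over the top k and sample proportionally.
    std::vector<float> weights;
    weights.reserve(top_k);
    for (size_t i = 0; i < top_k; ++i) weights.push_back(probs[idx[i]]);
    std::discrete_distribution<int> dist(weights.begin(), weights.end());
    return idx[dist(gen)];
  }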
8 changes: 6 additions & 2 deletions gemma/gemma.h
@@ -25,6 +25,7 @@
 #include "compression/io.h"  // Path
 #include "gemma/activations.h"
 #include "gemma/common.h"
+#include "gemma/configs.h"
 #include "gemma/kv_cache.h"
 #include "gemma/tokenizer.h"
 #include "gemma/weights.h"
@@ -102,9 +103,12 @@ struct RuntimeConfig {
   // Max queries per batch (one token from each) during decode.
   size_t decode_qbatch_size = 16;

-  float temperature;  // Temperature for sampling.
+  // Sampling-related parameters.
+  float temperature;     // Temperature for sampling.
+  size_t top_k = kTopK;  // Top-k for sampling.
+  std::mt19937* gen;     // Random number generator used for sampling.

   int verbosity;  // Controls verbosity of printed messages.
-  std::mt19937* gen;  // Random number generator used for sampling.
+
   // Functions operating on the generated tokens.
   StreamFunc stream_token;
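
Two details worth noting in this hunk. The new #include of gemma/configs.h is presumably what keeps the kTopK constant (previously only referenced by ModelConfig) in scope for the member default. And because C++20 designated initializers must appear in declaration order, moving gen up next to the other sampling parameters forces every brace-initialized call site to list .gen before .verbosity, which is exactly the one-line swap repeated across the other files in this commit:

  // Compiles with the new declaration order (temperature, top_k, gen, verbosity):
  gcpp::RuntimeConfig ok = {.temperature = 1.0f, .gen = &gen, .verbosity = 0};
  // Listing .verbosity before .gen would now be ill-formed.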
6 changes: 3 additions & 3 deletions gemma/run.cc
@@ -99,7 +99,7 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
     HWY_ASSERT(image.ReadPPM(args.image_file.path));
     image.Resize();
     RuntimeConfig runtime_config = {
-        .verbosity = app.verbosity, .gen = &gen, .use_spinning = app.spin};
+        .gen = &gen, .verbosity = app.verbosity, .use_spinning = app.spin};
     double image_tokens_start = hwy::platform::Now();
     model.GenerateImageTokens(runtime_config, image, image_tokens);
     if (app.verbosity >= 1) {
@@ -172,8 +172,8 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
   }

   TimingInfo timing_info = {.verbosity = app.verbosity};
-  RuntimeConfig runtime_config = {.verbosity = app.verbosity,
-                                  .gen = &gen,
+  RuntimeConfig runtime_config = {.gen = &gen,
+                                  .verbosity = app.verbosity,
                                   .stream_token = stream_token,
                                   .accept_token = accept_token,
                                   .use_spinning = app.spin};
6 changes: 3 additions & 3 deletions paligemma/paligemma_test.cc
@@ -56,16 +56,16 @@ void PaliGemmaTest::InitVit(const std::string& path) {
   HWY_ASSERT(model.Info().training == ModelTraining::PALIGEMMA);
   HWY_ASSERT(image.ReadPPM(path));
   image.Resize();
-  RuntimeConfig runtime_config = {.verbosity = 0, .gen = &s_env->MutableGen()};
+  RuntimeConfig runtime_config = {.gen = &s_env->MutableGen(), .verbosity = 0};
   model.GenerateImageTokens(runtime_config, image, image_tokens_);
 }

 std::string PaliGemmaTest::GemmaReply(const std::string& prompt_text) const {
   Gemma& model = *(s_env->GetModel());
   s_env->MutableGen().seed(0x12345678);
   RuntimeConfig runtime_config = {.max_generated_tokens = 512,
-                                  .verbosity = 0,
-                                  .gen = &s_env->MutableGen()};
+                                  .gen = &s_env->MutableGen(),
+                                  .verbosity = 0};
   runtime_config.image_tokens = &image_tokens_;
   size_t abs_pos = 0;
   std::string mutable_prompt = prompt_text;
4 changes: 4 additions & 0 deletions util/app.h
@@ -220,6 +220,7 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
   size_t decode_qbatch_size;

   float temperature;
+  size_t top_k;
   bool deterministic;
   bool multiturn;
   Path image_file;
@@ -244,6 +245,8 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
             "Decode: max queries per batch.");

     visitor(temperature, "temperature", 1.0f, "Temperature for top-K", 2);
+    visitor(top_k, "top_k", size_t{1}, "Number of top-K tokens to sample from",
+            2);
     visitor(deterministic, "deterministic", false,
             "Make top-k sampling deterministic", 2);
     visitor(multiturn, "multiturn", false,
@@ -259,6 +262,7 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
     runtime_config.prefill_tbatch_size = prefill_tbatch_size;
     runtime_config.decode_qbatch_size = decode_qbatch_size;
     runtime_config.temperature = temperature;
+    runtime_config.top_k = top_k;
   }
 };

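
With this field in InferenceArgs, the sampling width should be selectable from the command line, e.g. passing something like --top_k 5 alongside the existing --temperature flag (exact invocation depends on the binary; the other required arguments are unchanged), and CopyTo() forwards the value into RuntimeConfig. The command-line default of size_t{1} keeps the fast top-1 path in ChooseSampleFunc unless a larger value is requested.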
