diff --git a/evals/gemma_test.cc b/evals/gemma_test.cc
index 98029fe..676a5d2 100644
--- a/evals/gemma_test.cc
+++ b/evals/gemma_test.cc
@@ -246,7 +246,7 @@ TEST_F(GemmaTest, CrossEntropySmall) {
       EXPECT_NEAR(entropy, 2.8f, 0.2f);
       break;
     case gcpp::Model::GRIFFIN_2B:
-      EXPECT_NEAR(entropy, 1.57f, 0.02f);
+      EXPECT_NEAR(entropy, 2.61f, 0.02f);
       break;
     case gcpp::Model::GEMMA2_2B:
       EXPECT_NEAR(entropy, 1.14f, 0.02f);
@@ -277,7 +277,7 @@ TEST_F(GemmaTest, CrossEntropyJingleBells) {
       EXPECT_NEAR(entropy, 1.07f, 0.05f);
       break;
     case gcpp::Model::GRIFFIN_2B:
-      EXPECT_NEAR(entropy, 2.09f, 0.02f);
+      EXPECT_NEAR(entropy, 1.62f, 0.02f);
       break;
     case gcpp::Model::GEMMA2_2B:
       EXPECT_NEAR(entropy, 0.49f, 0.02f);
@@ -308,7 +308,7 @@ TEST_F(GemmaTest, CrossEntropyGettysburg) {
       EXPECT_NEAR(entropy, 0.75f, 0.1f);
       break;
     case gcpp::Model::GRIFFIN_2B:
-      EXPECT_NEAR(entropy, 0.86f, 0.02f);
+      EXPECT_NEAR(entropy, 0.71f, 0.02f);
       break;
     case gcpp::Model::GEMMA2_2B:
       EXPECT_NEAR(entropy, 0.20f, 0.02f);
diff --git a/gemma/configs.cc b/gemma/configs.cc
index 326b18e..03fce99 100644
--- a/gemma/configs.cc
+++ b/gemma/configs.cc
@@ -183,7 +183,7 @@ static ModelConfig ConfigGriffin2B() {
       .softmax_attn_output_biases = true,
       .type = LayerAttentionType::kGriffinRecurrentBlock,
       .activation = ActivationType::Gelu,
-      .post_qk = PostQKType::Rope,
+      .post_qk = PostQKType::HalfRope,
   };
   config.layer_configs = {26, layer_config};
   for (size_t i = 2; i < config.layer_configs.size(); i += 3) {
diff --git a/gemma/configs_test.cc b/gemma/configs_test.cc
index a6668a4..91bfc53 100644
--- a/gemma/configs_test.cc
+++ b/gemma/configs_test.cc
@@ -397,7 +397,11 @@ void AssertMatch(const ModelConfig& config) {
     ASSERT_EQ(TConfig::kPostNorm, config.layer_configs[i].post_norm);
     ASSERT_EQ(TConfig::kLayerConfig[i], config.layer_configs[i].type);
     ASSERT_EQ(TConfig::kActivation, config.layer_configs[i].activation);
-    ASSERT_EQ(TConfig::kPostQK, config.layer_configs[i].post_qk);
+    PostQKType post_qk = TConfig::kPostQK;
+    if (TConfig::kUseHalfRope) {
+      post_qk = PostQKType::HalfRope;
+    }
+    ASSERT_EQ(post_qk, config.layer_configs[i].post_qk);
   }
 
   ASSERT_EQ(TConfig::kAttentionWindowSizes.size(),
diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h
index c58f9a8..24ea1e8 100644
--- a/gemma/gemma-inl.h
+++ b/gemma/gemma-inl.h
@@ -1240,8 +1240,12 @@ void GenerateT(const ModelWeightsStorage& model, Activations& activations,
                const QueriesPos& queries_prefix_end,
                const size_t query_idx_start, const KVCaches& kv_caches,
                TimingInfo& timing_info) {
-  const size_t vocab_size = model.Config().vocab_size;
-  const ModelWeightsPtrs<T>& weights = *model.GetWeightsOfType<T>();
+  // Griffin assumes that the recurrent block cache is zero-initialized.
+  for (int i = 0; i < kv_caches.size(); ++i) {
+    if (queries_pos_in[i] == 0) {
+      kv_caches[i].ZeroGriffinCache();  // No-op for non-Griffin models.
+    }
+  }
   // Copy so we can increment without requiring users to pass in a mutable span.
   std::vector<size_t> queries_pos_copy(queries_pos_in.cbegin(),
                                        queries_pos_in.cend());
@@ -1268,7 +1272,7 @@ void GenerateT(const ModelWeightsStorage& model, Activations& activations,
   HWY_ASSERT(queries_pos_in.size() == num_queries);
   HWY_ASSERT(kv_caches.size() == num_queries);
   const hwy::Divisor div_seq_len(static_cast<uint32_t>(kv_caches[0].seq_len));
-
+  const ModelWeightsPtrs<T>& weights = *model.GetWeightsOfType<T>();
   size_t max_prompt_size = MaxQueryLength(queries_prompt);
   size_t max_generated_tokens = runtime_config.max_generated_tokens;
   RangeChecks(weights.weights_config, max_generated_tokens, max_prompt_size);
@@ -1314,6 +1318,7 @@ void GenerateT(const ModelWeightsStorage& model, Activations& activations,
                      0.0f);
   }
 
+  const size_t vocab_size = model.Config().vocab_size;
   const double gen_start = hwy::platform::Now();
   for (size_t gen = 0; gen < max_generated_tokens; ++gen) {
     // Decode generates one token per query and increments queries_mutable_pos.
diff --git a/gemma/kv_cache.cc b/gemma/kv_cache.cc
index cc9db89..82ee01d 100644
--- a/gemma/kv_cache.cc
+++ b/gemma/kv_cache.cc
@@ -23,6 +23,17 @@
 
 namespace gcpp {
 
+void KVCache::ZeroGriffinCache() {
+  if (conv1d_cache_size != 0) {
+    hwy::ZeroBytes(conv1d_cache.get(),
+                   conv1d_cache_size * sizeof(conv1d_cache[0]));
+  }
+  if (rglru_cache_size != 0) {
+    hwy::ZeroBytes(rglru_cache.get(),
+                   rglru_cache_size * sizeof(rglru_cache[0]));
+  }
+}
+
 // prefill_tbatch_size is the maximum number of tokens from one query to
 // prefill at a time.
 KVCache KVCache::Create(const ModelConfig& weights_config,
@@ -37,9 +48,9 @@ KVCache KVCache::Create(const ModelConfig& weights_config,
         hwy::AllocateAligned<float>(kv_cache.seq_len * size_cache_pos);
   }
 
-  size_t num_griffin_layers = weights_config.NumLayersOfType(
-      LayerAttentionType::kGriffinRecurrentBlock);
+  const size_t num_griffin_layers = weights_config.NumLayersOfType(
+      LayerAttentionType::kGriffinRecurrentBlock);
   // TODO(patrickms): Add query batching support for Griffin.
   if (num_griffin_layers > 0) {
     size_t conv1d_width = 0;
@@ -49,20 +60,18 @@ KVCache KVCache::Create(const ModelConfig& weights_config,
     const size_t conv1d_cache_size =
         num_griffin_layers * (conv1d_width == 0 ? 0 : conv1d_width - 1) *
         weights_config.model_dim;
+    kv_cache.conv1d_cache_size = conv1d_cache_size;
     if (conv1d_cache_size != 0) {
       kv_cache.conv1d_cache = hwy::AllocateAligned<float>(conv1d_cache_size);
-      hwy::ZeroBytes(kv_cache.conv1d_cache.get(),
-                     conv1d_cache_size * sizeof(kv_cache.conv1d_cache[0]));
     }
 
     const size_t rglru_cache_size =
         num_griffin_layers * weights_config.model_dim;
+    kv_cache.rglru_cache_size = rglru_cache_size;
     if (rglru_cache_size != 0) {
       kv_cache.rglru_cache = hwy::AllocateAligned<float>(rglru_cache_size);
-      hwy::ZeroBytes(kv_cache.rglru_cache.get(),
-                     rglru_cache_size * sizeof(kv_cache.rglru_cache[0]));
     }
-  }  // kGriffinLayers
+  }  // num_griffin_layers
 
   return kv_cache;
 }
diff --git a/gemma/kv_cache.h b/gemma/kv_cache.h
index 9c46d93..69f9564 100644
--- a/gemma/kv_cache.h
+++ b/gemma/kv_cache.h
@@ -31,9 +31,15 @@ struct KVCache {
 
   // (kConv1dWidth - 1) * kModelDim * kGriffinLayers
   hwy::AlignedFreeUniquePtr<float[]> conv1d_cache;
+  size_t conv1d_cache_size = 0;
 
   // kModelDim * kGriffinLayers
   hwy::AlignedFreeUniquePtr<float[]> rglru_cache;
+  size_t rglru_cache_size = 0;
+
+  // Zero-initialize the Griffin recurrent block cache, i.e. the conv1d_cache
+  // and rglru_cache.
+  void ZeroGriffinCache();
 
   static KVCache Create(const ModelConfig& weights_config,
                         size_t prefill_tbatch_size);
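Note on the configs.cc fix: Griffin applies rotary position embeddings to only the first half of each head's q/k dimensions, so tagging the recurrent block with PostQKType::Rope rotated the full dimension and diverged from what the trained weights expect; HalfRope restores the intended behavior, which is also why the Griffin entropy baselines in gemma_test.cc change. Below is a minimal standalone sketch of the distinction, assuming the element-pairing convention used by gemma.cpp's Rope(); the names and the main() driver are illustrative, not the library API.

// half_rope_sketch.cc -- standalone illustration; hypothetical names.
#include <cmath>
#include <cstddef>
#include <vector>

// Classic RoPE over the first rope_dim elements of x, pairing element i
// with element i + rope_dim / 2.
void Rope(float* x, size_t rope_dim, size_t pos) {
  const size_t half = rope_dim / 2;
  for (size_t i = 0; i < half; ++i) {
    const float freq_exponent = static_cast<float>(2 * i) / rope_dim;
    const float timescale = std::pow(10000.0f, freq_exponent);
    const float theta = static_cast<float>(pos) / timescale;
    const float c = std::cos(theta), s = std::sin(theta);
    const float x0 = x[i], x1 = x[i + half];
    x[i] = x0 * c - x1 * s;
    x[i + half] = x0 * s + x1 * c;
  }
}

int main() {
  std::vector<float> q_full(8, 1.0f), q_half(8, 1.0f);
  Rope(q_full.data(), 8, /*pos=*/3);  // Rope: rotates all 8 elements.
  Rope(q_half.data(), 4, /*pos=*/3);  // HalfRope: rotates elements 0..3 only;
                                      // elements 4..7 pass through unchanged.
  return 0;
}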
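Note on the GenerateT / KVCache changes: the recurrent-state buffers were previously zeroed only once, at allocation time inside KVCache::Create, so a cache reused for a new generation starting at position 0 carried stale Griffin state from the previous one. The diff records each buffer's size in the struct and re-zeros via ZeroGriffinCache() at the top of GenerateT whenever a query begins at position 0. A standalone sketch of that lifecycle follows, with hypothetical names (RecurrentCache, ZeroState, Generate) standing in for the gemma.cpp types:

// recurrent_cache_sketch.cc -- standalone illustration; hypothetical names.
#include <cstddef>
#include <cstring>
#include <memory>
#include <vector>

struct RecurrentCache {
  std::unique_ptr<float[]> state;  // stands in for conv1d_cache / rglru_cache
  size_t state_size = 0;           // stands in for conv1d_cache_size etc.

  static RecurrentCache Create(size_t size) {
    RecurrentCache cache;
    if (size != 0) {
      // make_unique value-initializes the array, so the state starts
      // zeroed, as the old KVCache::Create did.
      cache.state = std::make_unique<float[]>(size);
      cache.state_size = size;
    }
    return cache;
  }

  // Mirrors KVCache::ZeroGriffinCache(): a no-op when no state exists.
  void ZeroState() {
    if (state_size != 0) {
      std::memset(state.get(), 0, state_size * sizeof(float));
    }
  }
};

// Mirrors the loop added at the top of GenerateT(): caches outlive a single
// call, so a query restarting at position 0 must not see stale state.
void Generate(std::vector<RecurrentCache>& caches,
              const std::vector<size_t>& positions) {
  for (size_t i = 0; i < caches.size(); ++i) {
    if (positions[i] == 0) caches[i].ZeroState();
  }
  // ... prefill and decode would run here, updating the recurrent state ...
}

int main() {
  std::vector<RecurrentCache> caches;
  caches.push_back(RecurrentCache::Create(1024));
  Generate(caches, {0});  // first conversation: state freshly zeroed
  Generate(caches, {0});  // same cache, new conversation: re-zeroed here
  return 0;
}

Because attention-only models allocate no recurrent buffers, both sizes stay zero and ZeroGriffinCache() degenerates to a no-op, which is why GenerateT can loop over all caches unconditionally.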