diff --git a/backprop/backward-inl.h b/backprop/backward-inl.h
index 67b0aa4..cd17a5d 100644
--- a/backprop/backward-inl.h
+++ b/backprop/backward-inl.h
@@ -355,7 +355,7 @@ void CrossEntropyLossBackwardPass(const Prompt& prompt,
   static constexpr size_t kLayers = TConfig::kLayers;
   const float kEmbScaling = EmbeddingScaling();
   static_assert(!TConfig::kAbsolutePE);
-  static_assert(!TConfig::kPostNormScale);
+  static_assert(TConfig::kPostNorm == PostNormType::None);
   static_assert(TConfig::kKVHeads == 1);
   HWY_DASSERT(prompt.context_size > 0);
 
diff --git a/backprop/backward_scalar_test.cc b/backprop/backward_scalar_test.cc
index 706b0ef..76825e1 100644
--- a/backprop/backward_scalar_test.cc
+++ b/backprop/backward_scalar_test.cc
@@ -388,7 +388,7 @@ struct TestConfig : ConfigCapNoSSM {
       FixedLayerConfig<2>(LayerAttentionType::kGemma);
   static constexpr int kLayers = kLayerConfig.size();
   static constexpr bool kAbsolutePE = false;
-  static constexpr bool kPostNormScale = false;
+  static constexpr PostNormType kPostNorm = PostNormType::None;
   static constexpr int kKVHeads = 1;
   static constexpr int kGemmaLayers = kLayers;
 
diff --git a/backprop/backward_test.cc b/backprop/backward_test.cc
index 0cbf69d..f23a946 100644
--- a/backprop/backward_test.cc
+++ b/backprop/backward_test.cc
@@ -193,7 +193,7 @@ struct TestConfig : public ConfigCapNoSSM {
       FixedLayerConfig<2>(LayerAttentionType::kGemma);
   static constexpr int kLayers = kLayerConfig.size();
   static constexpr bool kAbsolutePE = false;
-  static constexpr bool kPostNormScale = false;
+  static constexpr PostNormType kPostNorm = PostNormType::None;
   static constexpr int kKVHeads = 1;
   static constexpr int kGemmaLayers = kLayers;
 
diff --git a/backprop/forward-inl.h b/backprop/forward-inl.h
index c24116f..636c23c 100644
--- a/backprop/forward-inl.h
+++ b/backprop/forward-inl.h
@@ -233,7 +233,7 @@ float CrossEntropyLossForwardPass(const std::vector<int>& prompt,
   static constexpr size_t kLayers = TConfig::kLayers;
   const float kEmbScaling = EmbeddingScaling();
   static_assert(!TConfig::kAbsolutePE);
-  static_assert(!TConfig::kPostNormScale);
+  static_assert(TConfig::kPostNorm == PostNormType::None);
   static_assert(TConfig::kKVHeads == 1);
   HWY_DASSERT(context_size > 0);
 
diff --git a/gemma/configs.h b/gemma/configs.h
index b7e2a44..7e3c0d5 100644
--- a/gemma/configs.h
+++ b/gemma/configs.h
@@ -52,6 +52,32 @@ enum class LayerAttentionType {
   kGriffinRecurrentBlock,
 };
 
+// Post attention and ffw normalization type.
+enum class PostNormType {
+  None,
+  Scale,
+};
+
+// Post qk projection operation type.
+enum class PostQKType {
+  Rope,
+};
+
+// FFW activation function.
+enum class ActivationType {
+  Gelu,
+};
+
+// Attention query scale.
+enum class QueryScaleType {
+  Sqrt,
+};
+
+// Residual connection type.
+enum class ResidualType {
+  Add,
+};
+
 template <size_t kNum>
 constexpr std::array<LayerAttentionType, kNum> FixedLayerConfig(
     LayerAttentionType type) {
@@ -107,6 +133,11 @@ struct ConfigNoSSM {
   static constexpr bool kUseLocalAttention = false;
   static constexpr bool kInterleaveQKV = true;
   static constexpr int kNumTensorScales = 0;
+
+  static constexpr PostQKType kPostQK = PostQKType::Rope;
+  static constexpr ActivationType kActivation = ActivationType::Gelu;
+  static constexpr QueryScaleType kQueryScale = QueryScaleType::Sqrt;
+  static constexpr ResidualType kResidual = ResidualType::Add;
 };
 
 struct ConfigNoCapNoSSM : ConfigNoSSM {
@@ -143,7 +174,7 @@ struct ConfigGemma27B : public ConfigCapNoSSM {
   static constexpr int kQKVDim = 128;  // query size == key size == value size
   static constexpr int kTopK = gcpp::kTopK;
   static constexpr bool kAbsolutePE = false;
-  static constexpr bool kPostNormScale = true;
+  static constexpr PostNormType kPostNorm = PostNormType::Scale;
 };
 
 template
@@ -169,7 +200,7 @@ struct ConfigGemma9B : public ConfigCapNoSSM {
   static constexpr int kQKVDim = 256;  // query size == key size == value size
   static constexpr int kTopK = gcpp::kTopK;
   static constexpr bool kAbsolutePE = false;
-  static constexpr bool kPostNormScale = true;
+  static constexpr PostNormType kPostNorm = PostNormType::Scale;
 };
 
 template
@@ -191,7 +222,7 @@ struct ConfigGemma7B : public ConfigNoCapNoSSM {
   static constexpr int kQKVDim = 256;  // query size == key size == value size
   static constexpr int kTopK = gcpp::kTopK;
   static constexpr bool kAbsolutePE = false;
-  static constexpr bool kPostNormScale = false;
+  static constexpr PostNormType kPostNorm = PostNormType::None;
 };
 
 template
@@ -213,7 +244,7 @@ struct ConfigGemma2B : public ConfigNoCapNoSSM {
   static constexpr int kQKVDim = 256;  // query size == key size == value size
   static constexpr int kTopK = gcpp::kTopK;
   static constexpr bool kAbsolutePE = false;
-  static constexpr bool kPostNormScale = false;
+  static constexpr PostNormType kPostNorm = PostNormType::None;
 };
 
 template
@@ -235,7 +266,7 @@ struct ConfigGemmaTiny : public ConfigNoSSM {
   static constexpr int kQKVDim = 16;  // query size == key size == value size
   static constexpr int kTopK = gcpp::kTopK;
   static constexpr bool kAbsolutePE = false;
-  static constexpr bool kPostNormScale = false;
+  static constexpr PostNormType kPostNorm = PostNormType::None;
 
   static constexpr float kAttCap = 0.0f;
   // This is required for optimize_test to pass.
@@ -294,7 +325,7 @@ struct ConfigGriffin2B {
   static constexpr int kQKVDim = 256;  // query size == key size == value size
   static constexpr int kTopK = gcpp::kTopK;
   static constexpr bool kAbsolutePE = false;
-  static constexpr bool kPostNormScale = false;
+  static constexpr PostNormType kPostNorm = PostNormType::None;
 
   // No SoftCap.
   static constexpr float kAttCap = 0.0f;
@@ -308,6 +339,9 @@ struct ConfigGriffin2B {
   static constexpr bool kUseLocalAttention = true;
   static constexpr bool kInterleaveQKV = false;
   static constexpr int kNumTensorScales = 140;
+  static constexpr PostQKType kPostQK = PostQKType::Rope;
+  static constexpr QueryScaleType kQueryScale = QueryScaleType::Sqrt;
+  static constexpr ResidualType kResidual = ResidualType::Add;
 };
 
 }  // namespace gcpp
diff --git a/gemma/gemma.cc b/gemma/gemma.cc
index 6735d18..f923ffa 100644
--- a/gemma/gemma.cc
+++ b/gemma/gemma.cc
@@ -295,8 +295,9 @@ HWY_NOINLINE void Attention(
   constexpr size_t kHeads = TConfig::kHeads;
   constexpr size_t kKVHeads = TConfig::kKVHeads;
   constexpr size_t kSeqLen = TConfig::kSeqLen;
-  GEMMA_CONSTEXPR_SQRT const float kQueryScale =
+  GEMMA_CONSTEXPR_SQRT float kQueryScale =
       1.0f / Sqrt(static_cast<float>(kQKVDim));
+  constexpr bool kIsMHA = TActivations::kIsMHA;  // Multi-Head Attention
 
   const size_t batch_start = batch_and_query_start / num_queries;
   const size_t num_tokens_and_queries = num_tokens * num_queries;
@@ -350,7 +351,9 @@ HWY_NOINLINE void Attention(
           // Skip past the Q part of `q`, and copy KV to `kv`.
           memcpy(kv, q + kQKVDim, 2 * kQKVDim * sizeof(float));
         }
-        Rope(kv, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
+        if (TConfig::kPostQK == PostQKType::Rope) {
+          Rope(kv, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
+        }
       });
 
   static_assert((kHeads % kKVHeads) == 0,
@@ -373,7 +376,10 @@ HWY_NOINLINE void Attention(
             activations.att.data() + head * kSeqLen +
             batch_and_query_idx * kHeads * kSeqLen;
 
-        Rope(q, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
+        if (TConfig::kPostQK == PostQKType::Rope) {
+          Rope(q, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
+        }
+
         MulByConst(kQueryScale, q, kQKVDim);
 
         // Compute Q dot K scores
@@ -465,10 +471,12 @@ HWY_NOINLINE void FFW(Activations& activations,
   namespace hn = hwy::HWY_NAMESPACE;
   using DF = hn::ScalableTag<float>;
   using VF = hn::Vec<DF>;
-  hn::Transform1(DF(), activations.C1.data(), kFFHiddenDim * num_tokens,
-                 activations.C2.data(), [](DF df, VF v, VF mul) HWY_ATTR {
-                   return hn::Mul(mul, Gelu(df, v));
-                 });
+  if (TConfig::kActivation == ActivationType::Gelu) {
+    hn::Transform1(DF(), activations.C1.data(), kFFHiddenDim * num_tokens,
+                   activations.C2.data(), [](DF df, VF v, VF mul) HWY_ATTR {
+                     return hn::Mul(mul, Gelu(df, v));
+                   });
+  }
 
   MatMul_4x4_Batch(num_tokens, activations.C1.data(),
                    layer_weights->linear_w.data(),
@@ -560,29 +568,34 @@ HWY_NOINLINE void TransformerLayer(
                                            layer_weights, kv_caches, pool);
     }
   }
-  if (TConfig::kPostNormScale) {
+
+  if (TConfig::kPostNorm == PostNormType::Scale) {
     RMSNormInplaceBatched(
         num_tokens_and_queries,
         layer_weights->post_attention_norm_scale.data(),
         activations.att_post2.data(), kModelDim);
   }
-  AddFromBatched(num_tokens_and_queries,
-                 activations.att_post2.data(),
-                 activations.x.data(), kModelDim);
+  if (TConfig::kResidual == ResidualType::Add) {
+    AddFromBatched(
+        num_tokens_and_queries, activations.att_post2.data(),
+        activations.x.data(), kModelDim);
+  }
   RMSNormBatched(
       num_tokens_and_queries, activations.x.data(),
       layer_weights->pre_ffw_norm_scale.data(),
       activations.bf_pre_ffw_rms_out.data(), kModelDim);
   FFW(
       activations, num_tokens_and_queries, layer_weights, pool);
-  if (TConfig::kPostNormScale) {
+  if (TConfig::kPostNorm == PostNormType::Scale) {
     RMSNormInplaceBatched(
         num_tokens_and_queries, layer_weights->post_ffw_norm_scale.data(),
         activations.ffw_out.data(), kModelDim);
   }
-  AddFromBatched(
-      num_tokens_and_queries, activations.ffw_out.data(),
-      activations.x.data(), kModelDim);
+  if (TConfig::kResidual == ResidualType::Add) {
+    AddFromBatched(
+        num_tokens_and_queries, activations.ffw_out.data(),
+        activations.x.data(), kModelDim);
+  }
 }
 
 template
diff --git a/gemma/weights.h b/gemma/weights.h
index c0c33c8..ee4ab78 100644
--- a/gemma/weights.h
+++ b/gemma/weights.h
@@ -50,7 +50,7 @@ struct CompressedLayer {
   static constexpr size_t kGatingEinsumWSize = 2 * kFFHiddenDim * kModelDim;
   static constexpr size_t kConv1dWidth = TConfig::kConv1dWidth;
   static constexpr bool kFFBiases = TConfig::kFFBiases;
-  static constexpr bool kPostNormScale = TConfig::kPostNormScale;
+  static constexpr PostNormType kPostNorm = TConfig::kPostNorm;
   static constexpr size_t kAOBiasDim =
       TConfig::kSoftmaxAttnOutputBiases ? kModelDim : 0;
   static constexpr size_t kGriffinDim =
@@ -86,9 +86,10 @@ struct CompressedLayer {
   // We don't yet have an RMSNorm that accepts all Weight.
   ArrayT pre_attention_norm_scale;
   ArrayT pre_ffw_norm_scale;
-  ArrayT
+  ArrayT
       post_attention_norm_scale;
-  ArrayT post_ffw_norm_scale;
+  ArrayT
+      post_ffw_norm_scale;
   ArrayT ffw_gating_biases;
   ArrayT ffw_output_biases;
 
@@ -267,7 +268,7 @@ void ForEachTensor(RawWeightsPtr raw_weights,
     GEMMA_CALL_FUNC("gr_a", griffin.a);
   }
   GEMMA_CALL_FUNC("pre_att_ns", pre_attention_norm_scale);
-  if (TConfig::kPostNormScale) {
+  if (TConfig::kPostNorm == PostNormType::Scale) {
     GEMMA_CALL_FUNC("post_att_ns", post_attention_norm_scale);
     GEMMA_CALL_FUNC("post_ff_ns", post_ffw_norm_scale);
   }
@@ -331,7 +332,7 @@ void ForEachTensor(RawWeightsPtr raw_weights,
   GEMMA_CALL_LAYER_FUNC ## N("gating_ein", gating_einsum_w);              \
   GEMMA_CALL_LAYER_FUNC ## N("linear_w", linear_w);                       \
   GEMMA_CALL_LAYER_FUNC ## N("pre_att_ns", pre_attention_norm_scale);     \
-  if (TConfig::kPostNormScale) {                                          \
+  if (TConfig::kPostNorm == PostNormType::Scale) {                        \
   GEMMA_CALL_LAYER_FUNC ## N("post_att_ns", post_attention_norm_scale);   \
   GEMMA_CALL_LAYER_FUNC ## N("post_ff_ns", post_ffw_norm_scale);          \
   }                                                                       \
diff --git a/gemma/weights_raw.h b/gemma/weights_raw.h
index cb66876..6c1fe34 100644
--- a/gemma/weights_raw.h
+++ b/gemma/weights_raw.h
@@ -25,6 +25,7 @@
 #include 
 
 #include "gemma/common.h"
+#include "gemma/configs.h"
 #include "hwy/aligned_allocator.h"
 #include "hwy/base.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
@@ -46,7 +47,7 @@ struct Layer {
   static constexpr size_t kGatingEinsumWSize = 2 * kFFHiddenDim * kModelDim;
   static constexpr size_t kConv1dWidth = TConfig::kConv1dWidth;
   static constexpr bool kFFBiases = TConfig::kFFBiases;
-  static constexpr bool kPostNormScale = TConfig::kPostNormScale;
+  static constexpr PostNormType kPostNorm = TConfig::kPostNorm;
   static constexpr size_t kAOBiasDim =
       TConfig::kSoftmaxAttnOutputBiases ? kModelDim : 0;
   static constexpr size_t kGriffinDim =
@@ -78,8 +79,10 @@ struct Layer {
   std::array linear_w;
   std::array pre_attention_norm_scale;
   std::array pre_ffw_norm_scale;
-  std::array post_attention_norm_scale;
-  std::array post_ffw_norm_scale;
+  std::array
+      post_attention_norm_scale;
+  std::array
+      post_ffw_norm_scale;
   std::array ffw_gating_biases;
   std::array ffw_output_biases;
 
diff --git a/util/compress_weights.cc b/util/compress_weights.cc
index cc14c42..e756182 100644
--- a/util/compress_weights.cc
+++ b/util/compress_weights.cc
@@ -159,7 +159,7 @@ struct LoadRawWeightsT {
   SCALE_WEIGHTS(linear_w);
   READ_WEIGHTS(pre_attention_norm_scale);
   READ_WEIGHTS(pre_ffw_norm_scale);
-  if (TConfig::kPostNormScale) {
+  if (TConfig::kPostNorm == PostNormType::Scale) {
     READ_WEIGHTS(post_attention_norm_scale);
     READ_WEIGHTS(post_ffw_norm_scale);
  }
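
The patch applies one uniform pattern: each config exposes the new enums as static constexpr members, the weight structs forward them (kPostNorm = TConfig::kPostNorm), and call sites test them with ordinary ifs or static_asserts on constant expressions instead of booleans. Below is a minimal, self-contained sketch of that pattern; ToyConfig, AddResidual, and main are hypothetical illustrations rather than code from the patch, and only the enum definitions and member names mirror the diff above.

#include <cstddef>
#include <cstdio>

enum class PostNormType { None, Scale };
enum class PostQKType { Rope };
enum class ActivationType { Gelu };
enum class QueryScaleType { Sqrt };
enum class ResidualType { Add };

struct ToyConfig {
  static constexpr std::size_t kModelDim = 4;
  static constexpr PostNormType kPostNorm = PostNormType::None;
  static constexpr PostQKType kPostQK = PostQKType::Rope;
  static constexpr ActivationType kActivation = ActivationType::Gelu;
  static constexpr QueryScaleType kQueryScale = QueryScaleType::Sqrt;
  static constexpr ResidualType kResidual = ResidualType::Add;
};

// The backprop path supports only configs without post-norm; the check is now
// an enum comparison instead of the old `!TConfig::kPostNormScale`.
static_assert(ToyConfig::kPostNorm == PostNormType::None);

// Call sites branch on a compile-time constant, so the untaken path is
// eliminated, as in the TransformerLayer and FFW changes above.
template <typename TConfig>
void AddResidual(const float* delta, float* x, std::size_t dim) {
  if (TConfig::kResidual == ResidualType::Add) {
    for (std::size_t i = 0; i < dim; ++i) x[i] += delta[i];
  }
}

int main() {
  float x[ToyConfig::kModelDim] = {1.0f, 1.0f, 1.0f, 1.0f};
  const float delta[ToyConfig::kModelDim] = {0.5f, 0.5f, 0.5f, 0.5f};
  AddResidual<ToyConfig>(delta, x, ToyConfig::kModelDim);
  std::printf("x[0] = %g\n", x[0]);  // 1.5
  return 0;
}

Presumably the value of enums over bools here is extensibility: further variants (another activation, another residual rule, a different post-QK transform) can be added as new enumerators without another round of renaming across every config and call site.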