Add configurables for norm/rope/activation/scale/residual connection. #287

Open · wants to merge 1 commit into base: dev
2 changes: 1 addition & 1 deletion backprop/backward-inl.h
@@ -355,7 +355,7 @@ void CrossEntropyLossBackwardPass(const Prompt& prompt,
static constexpr size_t kLayers = TConfig::kLayers;
const float kEmbScaling = EmbeddingScaling<TConfig>();
static_assert(!TConfig::kAbsolutePE);
static_assert(!TConfig::kPostNormScale);
static_assert(TConfig::kPostNorm == PostNormType::None);
static_assert(TConfig::kKVHeads == 1);

HWY_DASSERT(prompt.context_size > 0);
2 changes: 1 addition & 1 deletion backprop/backward_scalar_test.cc
@@ -388,7 +388,7 @@ struct TestConfig : ConfigCapNoSSM {
FixedLayerConfig<2>(LayerAttentionType::kGemma);
static constexpr int kLayers = kLayerConfig.size();
static constexpr bool kAbsolutePE = false;
static constexpr bool kPostNormScale = false;
static constexpr PostNormType kPostNorm = PostNormType::None;

static constexpr int kKVHeads = 1;
static constexpr int kGemmaLayers = kLayers;
2 changes: 1 addition & 1 deletion backprop/backward_test.cc
@@ -193,7 +193,7 @@ struct TestConfig : public ConfigCapNoSSM {
FixedLayerConfig<2>(LayerAttentionType::kGemma);
static constexpr int kLayers = kLayerConfig.size();
static constexpr bool kAbsolutePE = false;
static constexpr bool kPostNormScale = false;
static constexpr PostNormType kPostNorm = PostNormType::None;

static constexpr int kKVHeads = 1;
static constexpr int kGemmaLayers = kLayers;
2 changes: 1 addition & 1 deletion backprop/forward-inl.h
@@ -233,7 +233,7 @@ float CrossEntropyLossForwardPass(const std::vector<int>& prompt,
static constexpr size_t kLayers = TConfig::kLayers;
const float kEmbScaling = EmbeddingScaling<TConfig>();
static_assert(!TConfig::kAbsolutePE);
static_assert(!TConfig::kPostNormScale);
static_assert(TConfig::kPostNorm == PostNormType::None);
static_assert(TConfig::kKVHeads == 1);

HWY_DASSERT(context_size > 0);
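The four backprop changes above all tighten the same compile-time guard: the training path does not implement post attention/FFW normalization, so it now asserts kPostNorm == PostNormType::None instead of the old boolean. A minimal standalone sketch of that pattern (names here are illustrative stand-ins, not from the PR; the real asserts live in CrossEntropyLossForwardPass/BackwardPass):

#include <cstdio>

enum class PostNormType { None, Scale };  // mirrors the enum added in gemma/configs.h below

struct TrainableConfig { static constexpr PostNormType kPostNorm = PostNormType::None; };
struct ScaleConfig     { static constexpr PostNormType kPostNorm = PostNormType::Scale; };

// Hypothetical stand-in for the backward pass: refuses to instantiate for
// configs that use post normalization.
template <class TConfig>
void BackwardPassStub() {
  static_assert(TConfig::kPostNorm == PostNormType::None,
                "backprop does not support post-norm scaling");
  std::printf("backward pass ok\n");
}

int main() {
  BackwardPassStub<TrainableConfig>();   // compiles and runs
  // BackwardPassStub<ScaleConfig>();    // would fail the static_assert at compile time
  return 0;
}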
46 changes: 40 additions & 6 deletions gemma/configs.h
@@ -52,6 +52,32 @@ enum class LayerAttentionType {
kGriffinRecurrentBlock,
};

// Post attention and ffw normalization type.
enum class PostNormType {
None,
Scale,
};

// Post qk projection operation type.
enum class PostQKType {
Rope,
};

// FFW activation function.
enum class ActivationType {
Gelu,
};

// Attention query scale.
enum class QueryScaleType {
Sqrt,
};

// Residual connection type.
enum class ResidualType {
Add,
};

template <size_t kNum>
constexpr std::array<LayerAttentionType, kNum> FixedLayerConfig(
LayerAttentionType type) {
@@ -107,6 +133,11 @@ struct ConfigNoSSM {
static constexpr bool kUseLocalAttention = false;
static constexpr bool kInterleaveQKV = true;
static constexpr int kNumTensorScales = 0;

static constexpr PostQKType kPostQK = PostQKType::Rope;
static constexpr ActivationType kActivation = ActivationType::Gelu;
static constexpr QueryScaleType kQueryScale = QueryScaleType::Sqrt;
static constexpr ResidualType kResidual = ResidualType::Add;
};

struct ConfigNoCapNoSSM : ConfigNoSSM {
@@ -143,7 +174,7 @@ struct ConfigGemma27B : public ConfigCapNoSSM {
static constexpr int kQKVDim = 128; // query size == key size == value size
static constexpr int kTopK = gcpp::kTopK;
static constexpr bool kAbsolutePE = false;
static constexpr bool kPostNormScale = true;
static constexpr PostNormType kPostNorm = PostNormType::Scale;
};

template <typename TWeight>
@@ -169,7 +200,7 @@ struct ConfigGemma9B : public ConfigCapNoSSM {
static constexpr int kQKVDim = 256; // query size == key size == value size
static constexpr int kTopK = gcpp::kTopK;
static constexpr bool kAbsolutePE = false;
static constexpr bool kPostNormScale = true;
static constexpr PostNormType kPostNorm = PostNormType::Scale;
};

template <typename TWeight>
@@ -191,7 +222,7 @@ struct ConfigGemma7B : public ConfigNoCapNoSSM {
static constexpr int kQKVDim = 256; // query size == key size == value size
static constexpr int kTopK = gcpp::kTopK;
static constexpr bool kAbsolutePE = false;
static constexpr bool kPostNormScale = false;
static constexpr PostNormType kPostNorm = PostNormType::None;
};

template <typename TWeight>
Expand All @@ -213,7 +244,7 @@ struct ConfigGemma2B : public ConfigNoCapNoSSM {
static constexpr int kQKVDim = 256; // query size == key size == value size
static constexpr int kTopK = gcpp::kTopK;
static constexpr bool kAbsolutePE = false;
static constexpr bool kPostNormScale = false;
static constexpr PostNormType kPostNorm = PostNormType::None;
};

template <typename TWeight>
Expand All @@ -235,7 +266,7 @@ struct ConfigGemmaTiny : public ConfigNoSSM {
static constexpr int kQKVDim = 16; // query size == key size == value size
static constexpr int kTopK = gcpp::kTopK;
static constexpr bool kAbsolutePE = false;
static constexpr bool kPostNormScale = false;
static constexpr PostNormType kPostNorm = PostNormType::None;

static constexpr float kAttCap = 0.0f;
// This is required for optimize_test to pass.
@@ -294,7 +325,7 @@ struct ConfigGriffin2B {
static constexpr int kQKVDim = 256; // query size == key size == value size
static constexpr int kTopK = gcpp::kTopK;
static constexpr bool kAbsolutePE = false;
static constexpr bool kPostNormScale = false;
static constexpr PostNormType kPostNorm = PostNormType::None;

// No SoftCap.
static constexpr float kAttCap = 0.0f;
@@ -308,6 +339,9 @@
static constexpr bool kUseLocalAttention = true;
static constexpr bool kInterleaveQKV = false;
static constexpr int kNumTensorScales = 140;
static constexpr PostQKType kPostQK = PostQKType::Rope;
static constexpr QueryScaleType kQueryScale = QueryScaleType::Sqrt;
static constexpr ResidualType kResidual = ResidualType::Add;
};

} // namespace gcpp
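Taken together, the configs.h changes replace the single kPostNormScale boolean with five small enums; the shared ConfigNoSSM base picks the common defaults (Rope, Gelu, Sqrt, Add) while each model config still chooses its own kPostNorm. A self-contained sketch of how such a config is declared and consumed (ExampleConfig and Describe are hypothetical; only the enum names come from the PR):

#include <cstdio>

enum class PostNormType { None, Scale };
enum class PostQKType { Rope };
enum class ActivationType { Gelu };
enum class QueryScaleType { Sqrt };
enum class ResidualType { Add };

// Stand-in for a model config such as ConfigGemma27B; real configs also define
// dimensions like kModelDim, kLayers, kHeads, and so on.
struct ExampleConfig {
  static constexpr PostNormType kPostNorm = PostNormType::Scale;
  static constexpr PostQKType kPostQK = PostQKType::Rope;
  static constexpr ActivationType kActivation = ActivationType::Gelu;
  static constexpr QueryScaleType kQueryScale = QueryScaleType::Sqrt;
  static constexpr ResidualType kResidual = ResidualType::Add;
};

template <class TConfig>
void Describe() {
  // Every comparison is against a constexpr member, so the compiler folds it.
  std::printf("post-norm scale : %s\n",
              TConfig::kPostNorm == PostNormType::Scale ? "yes" : "no");
  std::printf("rope after q/k  : %s\n",
              TConfig::kPostQK == PostQKType::Rope ? "yes" : "no");
  std::printf("residual add    : %s\n",
              TConfig::kResidual == ResidualType::Add ? "yes" : "no");
}

int main() { Describe<ExampleConfig>(); }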
43 changes: 28 additions & 15 deletions gemma/gemma.cc
@@ -295,8 +295,9 @@ HWY_NOINLINE void Attention(
constexpr size_t kHeads = TConfig::kHeads;
constexpr size_t kKVHeads = TConfig::kKVHeads;
constexpr size_t kSeqLen = TConfig::kSeqLen;
GEMMA_CONSTEXPR_SQRT const float kQueryScale =
GEMMA_CONSTEXPR_SQRT float kQueryScale =
1.0f / Sqrt(static_cast<float>(kQKVDim));

constexpr bool kIsMHA = TActivations::kIsMHA; // Multi-Head Attention
const size_t batch_start = batch_and_query_start / num_queries;
const size_t num_tokens_and_queries = num_tokens * num_queries;
@@ -350,7 +351,9 @@ HWY_NOINLINE void Attention(
// Skip past the Q part of `q`, and copy KV to `kv`.
memcpy(kv, q + kQKVDim, 2 * kQKVDim * sizeof(float));
}
Rope(kv, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
if (TConfig::kPostQK == PostQKType::Rope) {
Rope(kv, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
}
});

static_assert((kHeads % kKVHeads) == 0,
@@ -373,7 +376,10 @@
activations.att.data() + head * kSeqLen
+ batch_and_query_idx * kHeads * kSeqLen;

Rope(q, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
if (TConfig::kPostQK == PostQKType::Rope) {
Rope(q, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
}

MulByConst(kQueryScale, q, kQKVDim);

// Compute Q dot K scores
@@ -465,10 +471,12 @@ HWY_NOINLINE void FFW(Activations<TConfig, kBatchSize>& activations,
namespace hn = hwy::HWY_NAMESPACE;
using DF = hn::ScalableTag<float>;
using VF = hn::Vec<DF>;
hn::Transform1(DF(), activations.C1.data(), kFFHiddenDim * num_tokens,
activations.C2.data(), [](DF df, VF v, VF mul) HWY_ATTR {
return hn::Mul(mul, Gelu(df, v));
});
if (TConfig::kActivation == ActivationType::Gelu) {
hn::Transform1(DF(), activations.C1.data(), kFFHiddenDim * num_tokens,
activations.C2.data(), [](DF df, VF v, VF mul) HWY_ATTR {
return hn::Mul(mul, Gelu(df, v));
});
}

MatMul_4x4_Batch<kFFHiddenDim, kModelDim>(num_tokens, activations.C1.data(),
layer_weights->linear_w.data(),
@@ -560,29 +568,34 @@ HWY_NOINLINE void TransformerLayer(
layer_weights, kv_caches, pool);
}
}
if (TConfig::kPostNormScale) {

if (TConfig::kPostNorm == PostNormType::Scale) {
RMSNormInplaceBatched<kBatchSize * kQueryBatchSize>(
num_tokens_and_queries,
layer_weights->post_attention_norm_scale.data(),
activations.att_post2.data(), kModelDim);
}
AddFromBatched<kBatchSize * kQueryBatchSize>(num_tokens_and_queries,
activations.att_post2.data(),
activations.x.data(), kModelDim);
if (TConfig::kResidual == ResidualType::Add) {
AddFromBatched<kBatchSize * kQueryBatchSize>(
num_tokens_and_queries, activations.att_post2.data(),
activations.x.data(), kModelDim);
}
RMSNormBatched<kBatchSize * kQueryBatchSize>(
num_tokens_and_queries, activations.x.data(),
layer_weights->pre_ffw_norm_scale.data(),
activations.bf_pre_ffw_rms_out.data(), kModelDim);
FFW<TConfig, kBatchSize * kQueryBatchSize>(
activations, num_tokens_and_queries, layer_weights, pool);
if (TConfig::kPostNormScale) {
if (TConfig::kPostNorm == PostNormType::Scale) {
RMSNormInplaceBatched<kBatchSize * kQueryBatchSize>(
num_tokens_and_queries, layer_weights->post_ffw_norm_scale.data(),
activations.ffw_out.data(), kModelDim);
}
AddFromBatched<kBatchSize * kQueryBatchSize>(
num_tokens_and_queries, activations.ffw_out.data(),
activations.x.data(), kModelDim);
if (TConfig::kResidual == ResidualType::Add) {
AddFromBatched<kBatchSize * kQueryBatchSize>(
num_tokens_and_queries, activations.ffw_out.data(),
activations.x.data(), kModelDim);
}
}

template <class TConfig, size_t kBatchSize, size_t kQueryBatchSize>
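The gemma.cc changes wrap the existing RoPE, GELU, and residual-add code in checks against the new enums. Since each enum currently has a single value (PostNormType aside), every branch is taken, and because the operands are constexpr the checks cost nothing at run time. Below is a scalar, standalone sketch of the guarded gated-GELU step in FFW; the names, the tanh approximation, and the use of `if constexpr` (the PR uses a plain `if`, which is equivalent here) are illustrative, while the real code uses Highway's vectorized Transform1 with its own Gelu:

#include <cmath>
#include <cstddef>

enum class ActivationType { Gelu };

// Hypothetical config fragment; only kActivation matters here.
struct ExampleConfig {
  static constexpr ActivationType kActivation = ActivationType::Gelu;
};

// Mirrors the guarded lambda in FFW: c1[i] = c2[i] * Gelu(c1[i]).
template <class TConfig>
void GatedActivation(float* c1, const float* c2, std::size_t n) {
  if constexpr (TConfig::kActivation == ActivationType::Gelu) {
    for (std::size_t i = 0; i < n; ++i) {
      const float x = c1[i];
      // tanh approximation of GELU
      const float gelu = 0.5f * x *
          (1.0f + std::tanh(0.7978845608f * (x + 0.044715f * x * x * x)));
      c1[i] = c2[i] * gelu;
    }
  }
  // A future ActivationType value would get its own branch here.
}

int main() {
  float gate[2] = {1.0f, -1.0f};
  const float lin[2] = {2.0f, 2.0f};
  GatedActivation<ExampleConfig>(gate, lin, 2);
  return 0;
}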
11 changes: 6 additions & 5 deletions gemma/weights.h
@@ -50,7 +50,7 @@ struct CompressedLayer {
static constexpr size_t kGatingEinsumWSize = 2 * kFFHiddenDim * kModelDim;
static constexpr size_t kConv1dWidth = TConfig::kConv1dWidth;
static constexpr bool kFFBiases = TConfig::kFFBiases;
static constexpr bool kPostNormScale = TConfig::kPostNormScale;
static constexpr PostNormType kPostNorm = TConfig::kPostNorm;
static constexpr size_t kAOBiasDim =
TConfig::kSoftmaxAttnOutputBiases ? kModelDim : 0;
static constexpr size_t kGriffinDim =
@@ -86,9 +86,10 @@ struct CompressedLayer {
// We don't yet have an RMSNorm that accepts all Weight.
ArrayT<WeightF32OrBF16, kModelDim> pre_attention_norm_scale;
ArrayT<WeightF32OrBF16, kModelDim> pre_ffw_norm_scale;
ArrayT<WeightF32OrBF16, kPostNormScale ? kModelDim : 0>
ArrayT<WeightF32OrBF16, kPostNorm == PostNormType::Scale ? kModelDim : 0>
post_attention_norm_scale;
ArrayT<WeightF32OrBF16, kPostNormScale ? kModelDim : 0> post_ffw_norm_scale;
ArrayT<WeightF32OrBF16, kPostNorm == PostNormType::Scale ? kModelDim : 0>
post_ffw_norm_scale;

ArrayT<float, kFFBiases ? 2 * kFFHiddenDim : 0> ffw_gating_biases;
ArrayT<float, kFFBiases ? kModelDim : 0> ffw_output_biases;
@@ -267,7 +268,7 @@ void ForEachTensor(RawWeightsPtr raw_weights,
GEMMA_CALL_FUNC("gr_a", griffin.a);
}
GEMMA_CALL_FUNC("pre_att_ns", pre_attention_norm_scale);
if (TConfig::kPostNormScale) {
if (TConfig::kPostNorm == PostNormType::Scale) {
GEMMA_CALL_FUNC("post_att_ns", post_attention_norm_scale);
GEMMA_CALL_FUNC("post_ff_ns", post_ffw_norm_scale);
}
@@ -331,7 +332,7 @@
GEMMA_CALL_LAYER_FUNC ## N("gating_ein", gating_einsum_w); \
GEMMA_CALL_LAYER_FUNC ## N("linear_w", linear_w); \
GEMMA_CALL_LAYER_FUNC ## N("pre_att_ns", pre_attention_norm_scale); \
if (TConfig::kPostNormScale) { \
if (TConfig::kPostNorm == PostNormType::Scale) { \
GEMMA_CALL_LAYER_FUNC ## N("post_att_ns", post_attention_norm_scale); \
GEMMA_CALL_LAYER_FUNC ## N("post_ff_ns", post_ffw_norm_scale); \
} \
9 changes: 6 additions & 3 deletions gemma/weights_raw.h
@@ -25,6 +25,7 @@
#include <random>

#include "gemma/common.h"
#include "gemma/configs.h"
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
@@ -46,7 +47,7 @@ struct Layer {
static constexpr size_t kGatingEinsumWSize = 2 * kFFHiddenDim * kModelDim;
static constexpr size_t kConv1dWidth = TConfig::kConv1dWidth;
static constexpr bool kFFBiases = TConfig::kFFBiases;
static constexpr bool kPostNormScale = TConfig::kPostNormScale;
static constexpr PostNormType kPostNorm = TConfig::kPostNorm;
static constexpr size_t kAOBiasDim =
TConfig::kSoftmaxAttnOutputBiases ? kModelDim : 0;
static constexpr size_t kGriffinDim =
@@ -78,8 +79,10 @@ struct Layer {
std::array<T, kModelDim * kFFHiddenDim> linear_w;
std::array<T, kModelDim> pre_attention_norm_scale;
std::array<T, kModelDim> pre_ffw_norm_scale;
std::array<T, kPostNormScale ? kModelDim : 0> post_attention_norm_scale;
std::array<T, kPostNormScale ? kModelDim : 0> post_ffw_norm_scale;
std::array<T, kPostNorm == PostNormType::Scale ? kModelDim : 0>
post_attention_norm_scale;
std::array<T, kPostNorm == PostNormType::Scale ? kModelDim : 0>
post_ffw_norm_scale;

std::array<T, kFFBiases ? 2 * kFFHiddenDim : 0> ffw_gating_biases;
std::array<T, kFFBiases ? kModelDim : 0> ffw_output_biases;
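In both weights.h and weights_raw.h the post-norm tensors are now sized by the enum: configs with kPostNorm == PostNormType::None get zero-length arrays, and ForEachTensor skips those tensors entirely, so such models pay no storage or serialization cost. A standalone sketch of the sizing trick (ExampleLayer and its members are hypothetical stand-ins):

#include <array>
#include <cstddef>
#include <cstdio>

enum class PostNormType { None, Scale };

template <PostNormType kPostNorm, std::size_t kModelDim>
struct ExampleLayer {
  std::array<float, kModelDim> pre_ffw_norm_scale;
  // Zero elements when post normalization is disabled.
  std::array<float, kPostNorm == PostNormType::Scale ? kModelDim : 0>
      post_ffw_norm_scale;
};

int main() {
  std::printf("with post norm   : %zu bytes\n",
              sizeof(ExampleLayer<PostNormType::Scale, 256>));
  std::printf("without post norm: %zu bytes\n",
              sizeof(ExampleLayer<PostNormType::None, 256>));
  return 0;
}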
2 changes: 1 addition & 1 deletion util/compress_weights.cc
@@ -159,7 +159,7 @@ struct LoadRawWeightsT {
SCALE_WEIGHTS(linear_w);
READ_WEIGHTS(pre_attention_norm_scale);
READ_WEIGHTS(pre_ffw_norm_scale);
if (TConfig::kPostNormScale) {
if (TConfig::kPostNorm == PostNormType::Scale) {
READ_WEIGHTS(post_attention_norm_scale);
READ_WEIGHTS(post_ffw_norm_scale);
}