From d4c86123b651016172376526604e35f9918712b3 Mon Sep 17 00:00:00 2001
From: yzhhr <86141389+yzhhr@users.noreply.github.com>
Date: Sat, 16 Nov 2024 06:29:42 +0800
Subject: [PATCH 1/3] fix the bug that attention_mask and past_kv_cache cannot
 work together (#772)

Co-authored-by: Bryce Meyer
---
 tests/integration/test_kv_cache.py    | 22 ++++++++++++++++++++++
 transformer_lens/HookedTransformer.py | 27 +++++++++++++++------------
 2 files changed, 37 insertions(+), 12 deletions(-)

diff --git a/tests/integration/test_kv_cache.py b/tests/integration/test_kv_cache.py
index a98ba7de6..baab6696a 100644
--- a/tests/integration/test_kv_cache.py
+++ b/tests/integration/test_kv_cache.py
@@ -213,6 +213,28 @@ def test_freeze_cache(pretrained):
     assert not t.allclose(with_cache_logits_1, with_cache_2_logits_1, atol=atol)
 
 
+def test_kv_cache_with_custom_attention_mask(pretrained):
+    model, atol = pretrained
+    prompt_pre = "An apple"
+    prompt_post = " a day keeps junk the"
+    prompt_whole = "An apple a day keeps the"
+    tokens_pre = model.to_tokens(prompt_pre)
+    tokens_post = model.to_tokens(prompt_post, prepend_bos=False)
+    tokens_whole = model.to_tokens(prompt_whole)
+    correct_logits = model(tokens_whole)
+
+    past_kv_cache = HookedTransformerKeyValueCache.init_cache(
+        model.cfg, model.cfg.device, tokens_pre.shape[0]
+    )
+    model(tokens_pre, past_kv_cache=past_kv_cache)
+    exp_logits = model(
+        tokens_post,
+        attention_mask=t.tensor([[1, 1, 1, 0, 1]], device=model.cfg.device),
+        past_kv_cache=past_kv_cache,
+    )
+    assert t.allclose(correct_logits[:, -1], exp_logits[:, -1], atol=atol)
+
+
 def test_kv_cache_and_start_at_layer(pretrained):
     model, atol = pretrained
     pre_prompt = "I went to Staten Island,"
diff --git a/transformer_lens/HookedTransformer.py b/transformer_lens/HookedTransformer.py
index 56096484c..c084c02b1 100644
--- a/transformer_lens/HookedTransformer.py
+++ b/transformer_lens/HookedTransformer.py
@@ -8,6 +8,7 @@ alteration of activations in individual components like attention heads and MLP
 layers, facilitating a deeper understanding of the internal workings of transformers
 like GPT-2.
 """
+
 import logging
 import os
 from typing import (
@@ -297,23 +298,25 @@ def input_to_embed(
         if tokens.device.type != self.cfg.device:
             tokens = tokens.to(devices.get_device_for_block_index(0, self.cfg))
 
-        if attention_mask is not None:
+        if (
+            (self.tokenizer and self.tokenizer.padding_side == "left")
+            or attention_mask is not None
+            or past_kv_cache is not None
+        ):
+            # This means we need to have an explicit attention mask.
+            if attention_mask is None:
+                # If the padding side is left or we are using caching, we need to compute the attention
+                # mask for the adjustment of absolute positional embeddings and attention masking so
+                # that pad tokens are not attended.
+                if prepend_bos is USE_DEFAULT_VALUE:
+                    prepend_bos = self.cfg.default_prepend_bos
+                attention_mask = utils.get_attention_mask(self.tokenizer, tokens, prepend_bos)
+
             assert attention_mask.shape == tokens.shape, (
                 f"Attention mask shape {attention_mask.shape} does not match tokens shape "
                 f"{tokens.shape}"
             )
             attention_mask = attention_mask.to(devices.get_device_for_block_index(0, self.cfg))
-        elif (
-            self.tokenizer and self.tokenizer.padding_side == "left"
-        ) or past_kv_cache is not None:
-            # If the padding side is left or we are using caching, we need to compute the attention
-            # mask for the adjustment of absolute positional embeddings and attention masking so
-            # that pad tokens are not attended.
-
-            if prepend_bos is USE_DEFAULT_VALUE:
-                prepend_bos = self.cfg.default_prepend_bos
-            attention_mask = utils.get_attention_mask(self.tokenizer, tokens, prepend_bos)
-
         if past_kv_cache is not None:
             # past_kv_cache is not None, so we're doing caching.
             # We need to extend the previous attention_mask.
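A minimal usage sketch of the combination this patch enables, mirroring the new test above. The model name, prompt strings, and masked position are placeholders; HookedTransformerKeyValueCache is the same class the test uses.

import torch
from transformer_lens import HookedTransformer, HookedTransformerKeyValueCache

model = HookedTransformer.from_pretrained("gpt2", device="cpu")

# Run a prefix once to fill the key/value cache.
prefix = model.to_tokens("The quick brown fox")
cache = HookedTransformerKeyValueCache.init_cache(model.cfg, model.cfg.device, prefix.shape[0])
model(prefix, past_kv_cache=cache)

# Continue from the cache while masking out one token of the new chunk;
# before this fix, an explicit attention_mask passed alongside past_kv_cache was not handled.
new_tokens = model.to_tokens(" jumps over the lazy dog", prepend_bos=False)
mask = torch.ones_like(new_tokens)  # must have the same shape as new_tokens
mask[0, 2] = 0                      # treat the third new token as padding to be ignored
logits = model(new_tokens, attention_mask=mask, past_kv_cache=cache)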
From 32b87c63072b015c47274998b9aabdd33e8524dc Mon Sep 17 00:00:00 2001
From: Fabian Degen <106864199+degenfabian@users.noreply.github.com>
Date: Fri, 15 Nov 2024 23:30:39 +0100
Subject: [PATCH 2/3] Set prepend_bos to false by default for Bloom model
 family (#775)

* fix prepend_bos to False by default for bloom model family

* add comment

* edit documentation

* fix wrong expected value for bloom-560m model loss

* fix expected loss value for bloom model computed with google colab

* set prepend_bos to user value, then to value in model config and then default to true

* fix format

* remove log points in test_hooked_transformer

---------

Co-authored-by: Bryce Meyer
Co-authored-by: Fabian Degen
---
 tests/acceptance/test_hooked_transformer.py |  2 +-
 transformer_lens/HookedTransformer.py       | 18 +++++++++++-------
 transformer_lens/loading_from_pretrained.py | 23 +++++++++++++++++------
 3 files changed, 29 insertions(+), 14 deletions(-)

diff --git a/tests/acceptance/test_hooked_transformer.py b/tests/acceptance/test_hooked_transformer.py
index 9d9e2bb19..939d1c1e5 100644
--- a/tests/acceptance/test_hooked_transformer.py
+++ b/tests/acceptance/test_hooked_transformer.py
@@ -66,7 +66,7 @@
     "redwood_attn_2l": 10.530948638916016,
     "solu-1l": 5.256411552429199,
     "tiny-stories-33M": 12.203617095947266,
-    "bloom-560m": 4.1953,
+    "bloom-560m": 5.237126350402832,
 }
 
 no_processing = [
diff --git a/transformer_lens/HookedTransformer.py b/transformer_lens/HookedTransformer.py
index c084c02b1..8b07f5046 100644
--- a/transformer_lens/HookedTransformer.py
+++ b/transformer_lens/HookedTransformer.py
@@ -1083,7 +1083,7 @@ def from_pretrained(
         tokenizer: Optional[PreTrainedTokenizerBase] = None,
         move_to_device: bool = True,
         fold_value_biases: bool = True,
-        default_prepend_bos: bool = True,
+        default_prepend_bos: Optional[bool] = None,
         default_padding_side: Literal["left", "right"] = "right",
         dtype="float32",
         first_n_layers: Optional[int] = None,
@@ -1205,11 +1205,15 @@
                 remains exactly the same, and so is just broadcast across the destination positions.
             default_prepend_bos: Default behavior of whether to prepend the BOS token when the
                 methods of HookedTransformer process input text to tokenize (only
-                when input is a string). Defaults to True - even for models not explicitly trained
-                with this, heads often use the first position as a resting position and accordingly
-                lose information from the first token, so this empirically seems to give better
-                results. To change the default behavior to False, pass in default_prepend_bos=False.
-                Note that you can also locally override the default behavior by passing in
+                when input is a string).
+                Resolution order for default_prepend_bos:
+                    1. If user passes value explicitly, use that value
+                    2. Model-specific default from cfg_dict if it exists (e.g. for bloom models it's False)
+                    3. Global default (True)
+
+                Even for models not explicitly trained with the BOS token, heads often use the first position as a resting position
+                and accordingly lose information from the first token, so this empirically seems to give better
+                results. Note that you can also locally override the default behavior by passing in
                 prepend_bos=True/False when you call a method that processes the input string.
             from_pretrained_kwargs: Any other optional argument passed to HuggingFace's
                 from_pretrained (e.g. "cache_dir" or "torch_dtype"). Also passed to
@@ -1353,7 +1357,7 @@ def from_pretrained_no_processing(
         refactor_factored_attn_matrices=False,
         fold_value_biases=False,
         dtype=torch.float32,
-        default_prepend_bos=True,
+        default_prepend_bos=None,
         default_padding_side="right",
         **from_pretrained_kwargs,
     ):
diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py
index 49dffbf04..aa544786f 100644
--- a/transformer_lens/loading_from_pretrained.py
+++ b/transformer_lens/loading_from_pretrained.py
@@ -1498,7 +1498,7 @@ def get_pretrained_model_config(
     fold_ln: bool = False,
     device: Optional[Union[str, torch.device]] = None,
    n_devices: int = 1,
-    default_prepend_bos: bool = True,
+    default_prepend_bos: Optional[bool] = None,
     dtype: torch.dtype = torch.float32,
     first_n_layers: Optional[int] = None,
     **kwargs,
@@ -1529,11 +1529,15 @@
         n_devices (int, optional): The number of devices to split the model across. Defaults to 1.
         default_prepend_bos (bool, optional): Default behavior of whether to prepend the BOS token when the
             methods of HookedTransformer process input text to tokenize (only when input is a string).
-            Defaults to True - even for models not explicitly trained with this, heads often use the
+            Resolution order for default_prepend_bos:
+                1. If user passes value explicitly, use that value
+                2. Model-specific default from cfg_dict if it exists (e.g. for bloom models it's False)
+                3. Global default (True)
+
+            Even for models not explicitly trained with the BOS token, heads often use the
             first position as a resting position and accordingly lose information from the first token,
-            so this empirically seems to give better results. To change the default behavior to False, pass in
-            default_prepend_bos=False. Note that you can also locally override the default behavior by passing
-            in prepend_bos=True/False when you call a method that processes the input string.
+            so this empirically seems to give better results. Note that you can also locally override the default behavior
+            by passing in prepend_bos=True/False when you call a method that processes the input string.
         dtype (torch.dtype, optional): The dtype to load the TransformerLens model in.
         kwargs: Other optional arguments passed to HuggingFace's from_pretrained. Also
             given to other HuggingFace functions when compatible.
@@ -1610,7 +1614,14 @@ def get_pretrained_model_config(
         cfg_dict["device"] = device
     cfg_dict["n_devices"] = n_devices
-    cfg_dict["default_prepend_bos"] = default_prepend_bos
+
+    if default_prepend_bos is not None:
+        # User explicitly set prepend_bos behavior, override config/default value
+        cfg_dict["default_prepend_bos"] = default_prepend_bos
+    elif "default_prepend_bos" not in cfg_dict:
+        # No config value or user override, set default value (True)
+        cfg_dict["default_prepend_bos"] = True
+
     if hf_cfg is not None:
         cfg_dict["load_in_4bit"] = hf_cfg.get("quantization_config", {}).get("load_in_4bit", False)
     if first_n_layers is not None:
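A usage sketch of the resolution order documented above (explicit argument > model config value > global default of True). Model names are illustrative, and the Bloom-specific value comes from the model's cfg_dict as described in the docstring.

from transformer_lens import HookedTransformer

# 1. An explicit argument always wins.
bloom_forced = HookedTransformer.from_pretrained("bigscience/bloom-560m", default_prepend_bos=True)
assert bloom_forced.cfg.default_prepend_bos is True

# 2. No argument: fall back to the model-specific config value (False for the Bloom family).
bloom = HookedTransformer.from_pretrained("bigscience/bloom-560m")
assert bloom.cfg.default_prepend_bos is False

# 3. No argument and no model-specific value: the global default of True applies.
gpt2 = HookedTransformer.from_pretrained("gpt2")
assert gpt2.cfg.default_prepend_bos is True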
From d9792a9a2a079aafd4525acd33735c9d19f9bd26 Mon Sep 17 00:00:00 2001
From: Fabian Degen <106864199+degenfabian@users.noreply.github.com>
Date: Sat, 16 Nov 2024 01:07:46 +0100
Subject: [PATCH 3/3] Fix that if use_past_kv_cache is set to True models from
 the Bloom family produce weird outputs. (#777)

* Fix kv_cache leads to wrong output when used with bloom models

* add test for bloom models when use_past_kv_cache is set to true

* fix max_length for huggingface model in kv_cache test

* set max_length to 13 for huggingface model in kv_cache test

* use max_new_tokens for huggingface model instead of max_length in kv_cache test

* fix format

---------

Co-authored-by: Bryce Meyer
Co-authored-by: Fabian Degen
---
 tests/acceptance/test_hooked_transformer.py | 20 ++++++++++++++++++++
 .../components/abstract_attention.py        |  3 ++-
 2 files changed, 22 insertions(+), 1 deletion(-)

diff --git a/tests/acceptance/test_hooked_transformer.py b/tests/acceptance/test_hooked_transformer.py
index 939d1c1e5..ac7555ad6 100644
--- a/tests/acceptance/test_hooked_transformer.py
+++ b/tests/acceptance/test_hooked_transformer.py
@@ -175,6 +175,26 @@ def test_from_pretrained_revision():
         raise AssertionError("Should have raised an error")
 
 
+def test_bloom_similarity_with_hf_model_with_kv_cache_activated():
+    tf_model = HookedTransformer.from_pretrained(
+        "bigscience/bloom-560m", default_prepend_bos=False, device="cpu"
+    )
+    hf_model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m")
+    hf_tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
+
+    output_tf = tf_model.generate(
+        text, do_sample=False, use_past_kv_cache=True, verbose=False, max_new_tokens=10
+    )
+    output_hf_tokens = hf_model.generate(
+        hf_tokenizer(text, return_tensors="pt").input_ids,
+        do_sample=False,
+        max_new_tokens=10,
+    )
+    output_hf_str = hf_tokenizer.decode(output_hf_tokens[0], skip_special_tokens=True)
+
+    assert output_tf == output_hf_str
+
+
 def check_norm_folding(
     model_name,
     hf_model=None,
diff --git a/transformer_lens/components/abstract_attention.py b/transformer_lens/components/abstract_attention.py
index a2a831e9f..9c4855091 100644
--- a/transformer_lens/components/abstract_attention.py
+++ b/transformer_lens/components/abstract_attention.py
@@ -229,8 +229,9 @@ def forward(
                     self.cfg.n_heads, key_ctx, self.cfg.device
                 )
 
+            # Take the last query_ctx positions so it also works with past_kv_cache
             attn_scores += self.alibi[
-                :, :query_ctx, :key_ctx
+                :, -query_ctx:, :key_ctx
             ]  # [batch, head_index, query_pos, key_pos]
         elif self.cfg.positional_embedding_type == "relative_positional_bias":
             if position_bias is None:
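The indexing change can be illustrated with a toy stand-in for the cached ALiBi bias (the slopes are made up and this is not TransformerLens's create_alibi_bias): once a kv cache is in use, the new queries sit at the last absolute positions, so their bias rows are the last query_ctx rows rather than the first ones.

import torch

# Toy ALiBi-style bias of shape [head, query_pos, key_pos]; slopes are illustrative only.
slopes = torch.tensor([0.5, 0.25])
past_len, query_ctx = 4, 1                 # 4 cached tokens, 1 newly generated token
key_ctx = past_len + query_ctx
rel = torch.arange(key_ctx)[None, :] - torch.arange(key_ctx)[:, None]  # key_pos - query_pos
alibi = slopes[:, None, None] * rel        # [head, key_ctx, key_ctx]

buggy = alibi[:, :query_ctx, :key_ctx]     # pre-fix: rows for absolute positions 0..query_ctx-1
fixed = alibi[:, -query_ctx:, :key_ctx]    # post-fix: rows for the positions the new query occupies
assert not torch.equal(buggy, fixed)       # with a non-empty cache the two slices differ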