Commit

Merge pull request #780 from TransformerLensOrg/dev
Release 2.9
bryce13950 authored Nov 16, 2024
2 parents 8f482fc + d9792a9 commit dc19c08
Showing 5 changed files with 88 additions and 27 deletions.
22 changes: 21 additions & 1 deletion tests/acceptance/test_hooked_transformer.py
@@ -66,7 +66,7 @@
"redwood_attn_2l": 10.530948638916016,
"solu-1l": 5.256411552429199,
"tiny-stories-33M": 12.203617095947266,
"bloom-560m": 4.1953,
"bloom-560m": 5.237126350402832,
}

no_processing = [
@@ -175,6 +175,26 @@ def test_from_pretrained_revision():
raise AssertionError("Should have raised an error")


def test_bloom_similarity_with_hf_model_with_kv_cache_activated():
tf_model = HookedTransformer.from_pretrained(
"bigscience/bloom-560m", default_prepend_bos=False, device="cpu"
)
hf_model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m")
hf_tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")

output_tf = tf_model.generate(
text, do_sample=False, use_past_kv_cache=True, verbose=False, max_new_tokens=10
)
output_hf_tokens = hf_model.generate(
hf_tokenizer(text, return_tensors="pt").input_ids,
do_sample=False,
max_new_tokens=10,
)
output_hf_str = hf_tokenizer.decode(output_hf_tokens[0], skip_special_tokens=True)

assert output_tf == output_hf_str


def check_norm_folding(
model_name,
hf_model=None,
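
Not part of the diff: a minimal sketch of how a reference loss such as the updated bloom-560m value above can be regenerated locally. The prompt string is illustrative and the acceptance test uses its own fixture text, so the printed number will not match 5.237126350402832 exactly.

    from transformer_lens import HookedTransformer

    # Load the model the same way the acceptance test does and read off the mean loss on a fixed prompt.
    model = HookedTransformer.from_pretrained("bloom-560m", device="cpu")
    loss = model("Hello, world!", return_type="loss")
    print(float(loss))
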
22 changes: 22 additions & 0 deletions tests/integration/test_kv_cache.py
@@ -213,6 +213,28 @@ def test_freeze_cache(pretrained):
assert not t.allclose(with_cache_logits_1, with_cache_2_logits_1, atol=atol)


def test_kv_cache_with_custom_attention_mask(pretrained):
model, atol = pretrained
prompt_pre = "An apple"
prompt_post = " a day keeps junk the"
prompt_whole = "An apple a day keeps the"
tokens_pre = model.to_tokens(prompt_pre)
tokens_post = model.to_tokens(prompt_post, prepend_bos=False)
tokens_whole = model.to_tokens(prompt_whole)
correct_logits = model(tokens_whole)

past_kv_cache = HookedTransformerKeyValueCache.init_cache(
model.cfg, model.cfg.device, tokens_pre.shape[0]
)
model(tokens_pre, past_kv_cache=past_kv_cache)
exp_logits = model(
tokens_post,
attention_mask=t.tensor([[1, 1, 1, 0, 1]], device=model.cfg.device),
past_kv_cache=past_kv_cache,
)
assert t.allclose(correct_logits[:, -1], exp_logits[:, -1], atol=atol)


def test_kv_cache_and_start_at_layer(pretrained):
model, atol = pretrained
pre_prompt = "I went to Staten Island,"
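
Not part of the diff: a condensed sketch of the pattern the new test above exercises, namely feeding a prompt in two chunks through the same HookedTransformerKeyValueCache and masking out an unwanted token in the second chunk. The model choice ("gpt2") and the assumption that " a day keeps junk the" tokenizes into five tokens with "junk" fourth are illustrative.

    import torch as t
    from transformer_lens import HookedTransformer
    from transformer_lens.past_key_value_caching import HookedTransformerKeyValueCache

    model = HookedTransformer.from_pretrained("gpt2", device="cpu")
    tokens_pre = model.to_tokens("An apple")                                   # first chunk, fills the cache
    tokens_post = model.to_tokens(" a day keeps junk the", prepend_bos=False)  # second chunk

    cache = HookedTransformerKeyValueCache.init_cache(model.cfg, model.cfg.device, tokens_pre.shape[0])
    model(tokens_pre, past_kv_cache=cache)                                     # populate the cache
    logits = model(
        tokens_post,
        attention_mask=t.tensor([[1, 1, 1, 0, 1]], device=model.cfg.device),   # mask the "junk" token
        past_kv_cache=cache,
    )
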
45 changes: 26 additions & 19 deletions transformer_lens/HookedTransformer.py
@@ -8,6 +8,7 @@
alteration of activations in individual components like attention heads and MLP layers, facilitating
a deeper understanding of the internal workings of transformers like GPT-2.
"""

import logging
import os
from typing import (
@@ -297,23 +298,25 @@ def input_to_embed(
if tokens.device.type != self.cfg.device:
tokens = tokens.to(devices.get_device_for_block_index(0, self.cfg))

if attention_mask is not None:
if (
(self.tokenizer and self.tokenizer.padding_side == "left")
or attention_mask is not None
or past_kv_cache is not None
):
# This means we need to have an explicit attention mask.
if attention_mask is None:
# If the padding side is left or we are using caching, we need to compute the attention
# mask for the adjustment of absolute positional embeddings and attention masking so
# that pad tokens are not attended.
if prepend_bos is USE_DEFAULT_VALUE:
prepend_bos = self.cfg.default_prepend_bos
attention_mask = utils.get_attention_mask(self.tokenizer, tokens, prepend_bos)

assert attention_mask.shape == tokens.shape, (
f"Attention mask shape {attention_mask.shape} does not match tokens shape "
f"{tokens.shape}"
)
attention_mask = attention_mask.to(devices.get_device_for_block_index(0, self.cfg))
elif (
self.tokenizer and self.tokenizer.padding_side == "left"
) or past_kv_cache is not None:
# If the padding side is left or we are using caching, we need to compute the attention
# mask for the adjustment of absolute positional embeddings and attention masking so
# that pad tokens are not attended.

if prepend_bos is USE_DEFAULT_VALUE:
prepend_bos = self.cfg.default_prepend_bos
attention_mask = utils.get_attention_mask(self.tokenizer, tokens, prepend_bos)

if past_kv_cache is not None:
# past_kv_cache is not None, so we're doing caching.
# We need to extend the previous attention_mask.
@@ -1080,7 +1083,7 @@ def from_pretrained(
tokenizer: Optional[PreTrainedTokenizerBase] = None,
move_to_device: bool = True,
fold_value_biases: bool = True,
default_prepend_bos: bool = True,
default_prepend_bos: Optional[bool] = None,
default_padding_side: Literal["left", "right"] = "right",
dtype="float32",
first_n_layers: Optional[int] = None,
@@ -1202,11 +1205,15 @@
remains exactly the same, and so is just broadcast across the destination positions.
default_prepend_bos: Default behavior of whether to prepend the BOS
token when the methods of HookedTransformer process input text to tokenize (only
when input is a string). Defaults to True - even for models not explicitly trained
with this, heads often use the first position as a resting position and accordingly
lose information from the first token, so this empirically seems to give better
results. To change the default behavior to False, pass in default_prepend_bos=False.
Note that you can also locally override the default behavior by passing in
when input is a string).
Resolution order for default_prepend_bos:
1. If user passes value explicitly, use that value
2. Model-specific default from cfg_dict if it exists (e.g. for bloom models it's False)
3. Global default (True)
Even for models not explicitly trained with the BOS token, heads often use the first position as a resting position
and accordingly lose information from the first token, so this empirically seems to give better
results. Note that you can also locally override the default behavior by passing in
prepend_bos=True/False when you call a method that processes the input string.
from_pretrained_kwargs: Any other optional argument passed to
HuggingFace's from_pretrained (e.g. "cache_dir" or "torch_dtype"). Also passed to
@@ -1350,7 +1357,7 @@ def from_pretrained_no_processing(
refactor_factored_attn_matrices=False,
fold_value_biases=False,
dtype=torch.float32,
default_prepend_bos=True,
default_prepend_bos=None,
default_padding_side="right",
**from_pretrained_kwargs,
):
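
Not part of the diff: a short usage sketch of the resolution order described in the new docstring above (explicit user value, then the model-specific config default, then the global default of True). The claim that bloom models default to False comes from that docstring; the exact prompt is illustrative.

    from transformer_lens import HookedTransformer

    model = HookedTransformer.from_pretrained("bloom-560m")
    print(model.cfg.default_prepend_bos)   # False via the bloom config entry (per the docstring above)

    model = HookedTransformer.from_pretrained("bloom-560m", default_prepend_bos=True)
    print(model.cfg.default_prepend_bos)   # True: an explicit argument wins over the model default

    tokens = model.to_tokens("a day keeps", prepend_bos=False)   # per-call override, independent of the default
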
3 changes: 2 additions & 1 deletion transformer_lens/components/abstract_attention.py
@@ -229,8 +229,9 @@ def forward(
self.cfg.n_heads, key_ctx, self.cfg.device
)

# Take the last query_ctx positions so it also works with past_kv_cache
attn_scores += self.alibi[
:, :query_ctx, :key_ctx
:, -query_ctx:, :key_ctx
] # [batch, head_index, query_pos, key_pos]
elif self.cfg.positional_embedding_type == "relative_positional_bias":
if position_bias is None:
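
Not part of the diff: a tiny illustration of why the slice above changes from [:query_ctx] to [-query_ctx:]. With a KV cache, the current forward pass only carries the newest query positions, so the ALiBi bias rows must be taken from the end of the precomputed matrix rather than the start. The tensor below is a stand-in (the head dimension is dropped), not the real ALiBi bias.

    import torch

    key_ctx, query_ctx = 6, 1                                          # six cached keys, one new query token
    bias = torch.arange(key_ctx * key_ctx).reshape(key_ctx, key_ctx)   # stand-in [query_pos, key_pos] bias
    old_slice = bias[:query_ctx, :key_ctx]    # row 0 -- offsets as if the query sat at position 0
    new_slice = bias[-query_ctx:, :key_ctx]   # last row -- offsets for the true (latest) position
    print(old_slice, new_slice, sep="\n")
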
23 changes: 17 additions & 6 deletions transformer_lens/loading_from_pretrained.py
@@ -1498,7 +1498,7 @@ def get_pretrained_model_config(
fold_ln: bool = False,
device: Optional[Union[str, torch.device]] = None,
n_devices: int = 1,
default_prepend_bos: bool = True,
default_prepend_bos: Optional[bool] = None,
dtype: torch.dtype = torch.float32,
first_n_layers: Optional[int] = None,
**kwargs,
@@ -1529,11 +1529,15 @@
n_devices (int, optional): The number of devices to split the model across. Defaults to 1.
default_prepend_bos (bool, optional): Default behavior of whether to prepend the BOS token when the
methods of HookedTransformer process input text to tokenize (only when input is a string).
Defaults to True - even for models not explicitly trained with this, heads often use the
Resolution order for default_prepend_bos:
1. If user passes value explicitly, use that value
2. Model-specific default from cfg_dict if it exists (e.g. for bloom models it's False)
3. Global default (True)
Even for models not explicitly trained with the BOS token, heads often use the
first position as a resting position and accordingly lose information from the first token,
so this empirically seems to give better results. To change the default behavior to False, pass in
default_prepend_bos=False. Note that you can also locally override the default behavior by passing
in prepend_bos=True/False when you call a method that processes the input string.
so this empirically seems to give better results. Note that you can also locally override the default behavior
by passing in prepend_bos=True/False when you call a method that processes the input string.
dtype (torch.dtype, optional): The dtype to load the TransformerLens model in.
kwargs: Other optional arguments passed to HuggingFace's from_pretrained.
Also given to other HuggingFace functions when compatible.
@@ -1610,7 +1614,14 @@ def get_pretrained_model_config(

cfg_dict["device"] = device
cfg_dict["n_devices"] = n_devices
cfg_dict["default_prepend_bos"] = default_prepend_bos

if default_prepend_bos is not None:
# User explicitly set prepend_bos behavior, override config/default value
cfg_dict["default_prepend_bos"] = default_prepend_bos
elif "default_prepend_bos" not in cfg_dict:
# No config value or user override, set default value (True)
cfg_dict["default_prepend_bos"] = True

if hf_cfg is not None:
cfg_dict["load_in_4bit"] = hf_cfg.get("quantization_config", {}).get("load_in_4bit", False)
if first_n_layers is not None:
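
Not part of the diff: a stand-alone restatement of the resolution logic added above, as a sketch rather than the library's API (the real code writes the result into cfg_dict in place instead of returning it).

    from typing import Optional

    def resolve_default_prepend_bos(user_value: Optional[bool], cfg_dict: dict) -> bool:
        if user_value is not None:                 # 1. explicit user choice wins
            return user_value
        if "default_prepend_bos" in cfg_dict:      # 2. model-specific default (e.g. False for bloom)
            return cfg_dict["default_prepend_bos"]
        return True                                # 3. global default

    assert resolve_default_prepend_bos(None, {}) is True
    assert resolve_default_prepend_bos(None, {"default_prepend_bos": False}) is False
    assert resolve_default_prepend_bos(True, {"default_prepend_bos": False}) is True
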
