Commit

Fix tokenization in utils.test_prompt (#334)
Fixes #271. Previously, test_prompt raised an IndexError when the concatenated prompt-and-answer string tokenized differently from the prompt and answer strings tokenized separately. The function now handles such cases.
Felhof authored Oct 14, 2023
1 parent b9508e5 commit 5b9d13d
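
For context, the old behaviour failed because BPE merges can cross the prompt/answer boundary when the two strings are tokenized as a single piece. A minimal sketch of the mismatch, assuming transformer_lens is installed and the "gpt2" checkpoint can be loaded (the snippet is illustrative and not part of the commit):

from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2")

prompt = "The circumference is the perimeter of the circ"
answer = "le."

# Tokenized together, "circ" + "le" can merge into a single " circle"-style token...
together = model.to_str_tokens(prompt + answer)

# ...while tokenizing the pieces separately keeps the boundary where the caller put it.
separate = model.to_str_tokens(prompt) + model.to_str_tokens(answer, prepend_bos=False)

# When the lengths differ, the old test_prompt indexed the combined token sequence at
# positions derived from the separate tokenizations, which ran past the end.
print(len(together), len(separate))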
Showing 2 changed files with 66 additions and 1 deletion.
63 changes: 63 additions & 0 deletions tests/unit/test_utils.py
@@ -1,3 +1,5 @@
+from unittest import mock
+
 import numpy as np
 import pytest
 import torch
@@ -207,6 +209,67 @@ def test_fail(self, x: torch.Tensor):
         assert not utils.is_lower_triangular(x)


+@pytest.mark.parametrize(
+    "prepend_space_to_answer, tokenized_prompt, tokenized_answer",
+    [
+        (
+            True,
+            [
+                "<|BOS|>",
+                "The",
+                " circumference",
+                " is",
+                " the",
+                " perimeter",
+                " of",
+                " the",
+                " circ",
+            ],
+            [" le", "."],
+        ),
+        (
+            False,
+            [
+                "<|BOS|>",
+                "The",
+                " circumference",
+                " is",
+                " the",
+                " perimeter",
+                " of",
+                " the",
+                " circ",
+            ],
+            ["le", "."],
+        ),
+    ],
+)
+@mock.patch("builtins.print")
+def test_test_prompt(
+    mocked_print,
+    prepend_space_to_answer,
+    tokenized_prompt,
+    tokenized_answer,
+):
+    """
+    Tests that utils.test_prompt produces the correct tokenization. In particular, when
+    prepend_space_to_answer = False, the last token of the prompt and the first answer token
+    should not be turned into one token (e.g. 'circ' and 'le' don't become 'circle'). See
+    https://github.com/neelnanda-io/TransformerLens/issues/271 for a more detailed explanation.
+    """
+    utils.test_prompt(
+        "The circumference is the perimeter of the circ",
+        "le.",
+        model,
+        prepend_space_to_answer=prepend_space_to_answer,
+    )
+
+    printed_tokenized_prompt = mock.call("Tokenized prompt:", tokenized_prompt)
+    printed_tokenized_answer = mock.call("Tokenized answer:", tokenized_answer)
+
+    assert mocked_print.mock_calls[0] == printed_tokenized_prompt
+    assert mocked_print.mock_calls[1] == printed_tokenized_answer
+
+
 def test_override_or_use_default_value():
     # Case when override is not None
     assert utils.override_or_use_default_value(default_flag=True, override=True) == True
4 changes: 3 additions & 1 deletion transformer_lens/utils.py
@@ -629,7 +629,9 @@ def test_prompt(
     if prepend_space_to_answer and not answer.startswith(" "):
         answer = " " + answer
     # GPT-2 often treats the first token weirdly, so lets give it a resting position
-    tokens = model.to_tokens(prompt + answer, prepend_bos=prepend_bos)
+    prompt_tokens = model.to_tokens(prompt, prepend_bos=prepend_bos)
+    answer_tokens = model.to_tokens(answer, prepend_bos=False)
+    tokens = torch.cat((prompt_tokens, answer_tokens), dim=1)
     prompt_str_tokens = model.to_str_tokens(prompt, prepend_bos=prepend_bos)
     answer_str_tokens = model.to_str_tokens(answer, prepend_bos=False)
     prompt_length = len(prompt_str_tokens)
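
The fix tokenizes the prompt and the answer separately, prepending BOS only to the prompt, and concatenates the resulting id tensors, so the prompt/answer token boundary always matches the str tokens that test_prompt prints. A rough sketch of the invariant this relies on, assuming a HookedTransformer named model is already loaded (hypothetical usage, not part of the diff):

import torch

prompt = "The circumference is the perimeter of the circ"
answer = "le."

prompt_tokens = model.to_tokens(prompt)                     # BOS prepended by default
answer_tokens = model.to_tokens(answer, prepend_bos=False)  # no second BOS mid-sequence
tokens = torch.cat((prompt_tokens, answer_tokens), dim=1)

# The concatenated ids line up one-to-one with the printed str tokens, so indexing
# answer positions at offsets >= len(prompt_str_tokens) can no longer run out of range.
assert tokens.shape[1] == len(model.to_str_tokens(prompt)) + len(
    model.to_str_tokens(answer, prepend_bos=False)
)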
