
Commit

Use more realistic RoPE tests (#1785)
rasbt authored Oct 9, 2024
1 parent 467fb87 commit ad57435
Showing 1 changed file with 10 additions and 7 deletions.
tests/test_rope.py (17 changes: 10 additions & 7 deletions)
@@ -110,7 +110,7 @@ def test_rope_llama_3():
 # See https://huggingface.co/meta-llama/Llama-3.1-8B/blob/main/config.json for settings
 @torch.inference_mode()
 def test_rope_llama_3_1():
-    head_dim = 128
+    head_dim = 32
     rope_theta = 50_000

     their_rope_config = {
@@ -130,15 +130,16 @@ def test_rope_llama_3_1():

     config = LlamaConfig(
         rope_theta=rope_theta,
-        rope_scaling=their_rope_config
+        rope_scaling=their_rope_config,
+        head_dim=head_dim
     )

     ##################################
     # Compare cos and sin
     ##################################
     # transformer rope
     rot_emb = LlamaRotaryEmbedding(head_dim, base=rope_theta, config=config, rope_type="llama3")
-    batch_size, seq_len = 1, 10
+    batch_size, seq_len = 1, 131_072
     qk_tensor = torch.randn(batch_size, seq_len, head_dim)
     position_ids = torch.arange(seq_len, dtype=torch.long).unsqueeze(0)
     theirs_cos, theirs_sin = rot_emb(qk_tensor, position_ids)
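
A note on what rope_scaling=their_rope_config drives here: the collapsed dict follows the "llama3" scaling scheme from the Llama-3.1-8B config.json linked above, whose published values are factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0, and original_max_position_embeddings=8192 (assumed to match the elided dict in this test). A minimal sketch of how those parameters rescale the RoPE inverse frequencies, written from the published scheme rather than quoted from either implementation under test:

    import math
    import torch

    # Assumed values; head_dim and rope_theta match the test, the rest follows
    # the Llama-3.1-8B config.json referenced in the file.
    head_dim = 32
    rope_theta = 50_000
    factor, low_freq_factor, high_freq_factor = 8.0, 1.0, 4.0
    original_max_position_embeddings = 8192

    # Base RoPE inverse frequencies: rope_theta ** (-2i / head_dim).
    inv_freq = 1.0 / (rope_theta ** (torch.arange(0, head_dim, 2).float() / head_dim))

    low_freq_wavelen = original_max_position_embeddings / low_freq_factor
    high_freq_wavelen = original_max_position_embeddings / high_freq_factor
    wavelen = 2 * math.pi / inv_freq

    # Short wavelengths are kept, long wavelengths are divided by `factor`,
    # and the band in between is smoothly interpolated.
    smooth = (original_max_position_embeddings / wavelen - low_freq_factor) / (
        high_freq_factor - low_freq_factor
    )
    scaled_inv_freq = torch.where(
        wavelen < high_freq_wavelen,
        inv_freq,
        torch.where(
            wavelen > low_freq_wavelen,
            inv_freq / factor,
            (1 - smooth) * inv_freq / factor + smooth * inv_freq,
        ),
    )

    # cos/sin tables over the full position range the updated test uses.
    positions = torch.arange(131_072).float()
    angles = torch.outer(positions, scaled_inv_freq)     # (seq_len, head_dim // 2)
    cos = torch.cat([angles, angles], dim=-1).cos()      # (seq_len, head_dim)
    sin = torch.cat([angles, angles], dim=-1).sin()

A plausible reason the sequence length jumps from 10 to 131_072: over only 10 positions the lowest frequencies barely rotate (the smallest inv_freq here is roughly 4e-5 rad per position), so a bug in the scaling could easily stay within assert_close tolerances, whereas over the full 131_072-position range those components accumulate rotations on the order of radians and any mismatch shows up.
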
@@ -169,7 +170,7 @@ def test_rope_llama_3_1():
 # See https://huggingface.co/meta-llama/Llama-3.2-3B/blob/main/config.json for settings
 @torch.inference_mode()
 def test_rope_llama_3_2():
-    head_dim = 128
+    head_dim = 32
     rope_theta = 50_000

     their_rope_config = {
@@ -189,15 +190,16 @@ def test_rope_llama_3_2():

     config = LlamaConfig(
         rope_theta=rope_theta,
-        rope_scaling=their_rope_config
+        rope_scaling=their_rope_config,
+        head_dim=head_dim
     )

     ##################################
     # Compare cos and sin
     ##################################
     # transformer rope
     rot_emb = LlamaRotaryEmbedding(head_dim, base=rope_theta, config=config, rope_type="llama3")
-    batch_size, seq_len = 1, 10
+    batch_size, seq_len = 1, 131_072
     qk_tensor = torch.randn(batch_size, seq_len, head_dim)
     position_ids = torch.arange(seq_len, dtype=torch.long).unsqueeze(0)
     theirs_cos, theirs_sin = rot_emb(qk_tensor, position_ids)
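
Two side notes on the new keyword arguments (a reading of the change, not something stated in the commit message): recent transformers versions derive LlamaConfig.head_dim as hidden_size // num_attention_heads when it is not given, which is 4096 // 32 = 128 for the default config and would no longer match the test's head_dim of 32, hence the explicit head_dim=head_dim; and the smaller head dimension keeps the 131_072-position tensors cheap:

    # Rough size of the qk_tensor each test allocates (fp32, batch_size = 1).
    seq_len = 131_072
    for head_dim in (128, 32):
        n_bytes = 1 * seq_len * head_dim * 4
        print(f"head_dim={head_dim}: {n_bytes / 2**20:.0f} MiB")
    # head_dim=128: 64 MiB, head_dim=32: 16 MiB (per tensor, before the cos/sin
    # tables and rotated copies the test also builds).
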
@@ -222,4 +224,5 @@ def test_rope_llama_3_2():
     ours_k_rot = apply_rope(keys, ours_cos, ours_sin)
     theirs_q_rot, theirs_k_rot = apply_rotary_pos_emb_llama(queries, keys, theirs_cos, theirs_sin)
     torch.testing.assert_close(theirs_q_rot, ours_q_rot)
-    torch.testing.assert_close(theirs_k_rot, ours_k_rot)
+    torch.testing.assert_close(theirs_k_rot, ours_k_rot)
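
The assertions above compare litgpt's apply_rope against the Hugging Face rotary helper on the same cos/sin tables. As a self-contained reference for the rotate-half convention that comparison relies on (a sketch with assumed shapes, not the code of either implementation under test):

    import torch

    def rotate_half(x: torch.Tensor) -> torch.Tensor:
        # Swap the two halves of the last dimension, negating the second half.
        half = x.shape[-1] // 2
        return torch.cat((-x[..., half:], x[..., :half]), dim=-1)

    def apply_rotary_reference(q, k, cos, sin):
        # cos/sin: (seq_len, head_dim), broadcast over batch and head dimensions.
        q_rot = q * cos + rotate_half(q) * sin
        k_rot = k * cos + rotate_half(k) * sin
        return q_rot, k_rot

    # Tiny smoke test with assumed shapes (batch, n_heads, seq_len, head_dim).
    q = torch.randn(1, 2, 8, 32)
    k = torch.randn(1, 2, 8, 32)
    angles = torch.outer(torch.arange(8).float(), torch.rand(16))
    cos = torch.cat([angles, angles], dim=-1).cos()
    sin = torch.cat([angles, angles], dim=-1).sin()
    q_rot, k_rot = apply_rotary_reference(q, k, cos, sin)

    # Each channel pair undergoes a pure 2D rotation, so per-vector norms are
    # preserved; assert_close on the raw outputs, as in the test itself, is the
    # stricter check.
    torch.testing.assert_close(q_rot.norm(dim=-1), q.norm(dim=-1))
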
