Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove einsum in forward pass in AbstractAttention #783

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 10 additions & 8 deletions transformer_lens/components/abstract_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,17 +297,19 @@ def forward(
)
)
else:
# Add singleton dimensions to make shapes compatible for broadcasting:
w = einops.rearrange(
self.W_O,
"head_index d_head d_model -> d_model head_index d_head",
"head_index d_head d_model -> 1 1 head_index d_head d_model",
)
result = self.hook_result(
einops.einsum(
z,
w,
"... head_index d_head, d_model head_index d_head -> ... head_index d_model",
)
) # [batch, pos, head_index, d_model]
z = einops.rearrange(
z, "batch pos head_index d_head -> batch pos head_index d_head 1"
)

# Multiply the z tensor by the W_O tensor, summing over the d_head dimension
unhooked_result = (z * w).sum(-2)

result = self.hook_result(unhooked_result) # [batch, pos, head_index, d_model]
out = (
einops.reduce(result, "batch position index model->batch position model", "sum")
+ self.b_O
Expand Down