Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Call paddle.incubate.nn.functional.fused_gate_attention and set use_flash_attn to true. #1087

Open
wants to merge 1 commit into
base: develop
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 19 additions & 5 deletions ppfleetx/models/protein_folding/attentions.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def __init__(self, config, global_config, q_dim, kv_dim, output_dim):

# TODO(GuoxiaWang): delete non fuse_attention related code on dcu
self.fuse_attention = self.global_config.fuse_attention
self.use_flash_attn = self.global_config.use_flash_attn
self.merge_qkv = (q_dim == kv_dim)

assert key_dim % num_head == 0
Expand Down Expand Up @@ -121,11 +122,24 @@ def forward(self, q_data, m_data, bias, nonbatched_bias=None):
if self.fuse_attention:
if nonbatched_bias is not None:
nonbatched_bias = paddle.unsqueeze(nonbatched_bias, axis=1)
_, _, _, _, _, _, _, output = _C_ops.fused_gate_attention(
q_data, m_data, self.query_w, self.key_w, self.value_w,
self.qkv_w, nonbatched_bias, bias, self.gating_w,
self.gating_b, self.output_w, self.output_b, 'has_gating',
self.config.gating, 'merge_qkv', self.merge_qkv)

import paddle.incubate.nn.functional as F
output = F.fused_gate_attention(
query=q_data,
key=m_data,
query_weight=self.query_w,
key_weight=self.key_w,
value_weight=self.value_w,
qkv_weight=self.qkv_w,
gate_linear_weight=self.gating_w,
gate_linear_bias=self.gating_b,
out_linear_weight=self.output_w,
out_linear_bias=self.output_b,
nonbatched_bias=nonbatched_bias,
attn_mask=bias,
has_gating=self.config.gating,
merge_qkv=self.merge_qkv,
use_flash_attn=self.use_flash_attn, )
else:
c = self.key_dim**(-0.5)
q = paddle.einsum('nbqa,ahc->nbqhc', q_data, self.query_w) * c
Expand Down