From f5b5535fd3850a0e5127369563f0f63a07cdc844 Mon Sep 17 00:00:00 2001
From: Liu Yiqun
Date: Thu, 18 May 2023 16:32:13 +0800
Subject: [PATCH] Call paddle.incubate.nn.functional.fused_gate_attention and
 set use_flash_attn to true.

---
 ppfleetx/models/protein_folding/attentions.py | 24 +++++++++++++++++++-----
 1 file changed, 19 insertions(+), 5 deletions(-)

diff --git a/ppfleetx/models/protein_folding/attentions.py b/ppfleetx/models/protein_folding/attentions.py
index 40802a345..de7eab244 100644
--- a/ppfleetx/models/protein_folding/attentions.py
+++ b/ppfleetx/models/protein_folding/attentions.py
@@ -46,6 +46,7 @@ def __init__(self, config, global_config, q_dim, kv_dim, output_dim):
 
         # TODO(GuoxiaWang): delete non fuse_attention related code on dcu
         self.fuse_attention = self.global_config.fuse_attention
+        self.use_flash_attn = self.global_config.use_flash_attn
         self.merge_qkv = (q_dim == kv_dim)
 
         assert key_dim % num_head == 0
@@ -121,11 +122,24 @@ def forward(self, q_data, m_data, bias, nonbatched_bias=None):
         if self.fuse_attention:
             if nonbatched_bias is not None:
                 nonbatched_bias = paddle.unsqueeze(nonbatched_bias, axis=1)
-            _, _, _, _, _, _, _, output = _C_ops.fused_gate_attention(
-                q_data, m_data, self.query_w, self.key_w, self.value_w,
-                self.qkv_w, nonbatched_bias, bias, self.gating_w,
-                self.gating_b, self.output_w, self.output_b, 'has_gating',
-                self.config.gating, 'merge_qkv', self.merge_qkv)
+
+            import paddle.incubate.nn.functional as F
+            output = F.fused_gate_attention(
+                query=q_data,
+                key=m_data,
+                query_weight=self.query_w,
+                key_weight=self.key_w,
+                value_weight=self.value_w,
+                qkv_weight=self.qkv_w,
+                gate_linear_weight=self.gating_w,
+                gate_linear_bias=self.gating_b,
+                out_linear_weight=self.output_w,
+                out_linear_bias=self.output_b,
+                nonbatched_bias=nonbatched_bias,
+                attn_mask=bias,
+                has_gating=self.config.gating,
+                merge_qkv=self.merge_qkv,
+                use_flash_attn=self.use_flash_attn, )
         else:
             c = self.key_dim**(-0.5)
             q = paddle.einsum('nbqa,ahc->nbqhc', q_data, self.query_w) * c
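
Note: for readers who want to check what the fused call computes, below is a minimal unfused sketch of the gated attention that fused_gate_attention fuses into one kernel, written from the einsum pattern visible in the non-fused branch at the end of the hunk. All tensor shapes, sizes, and variable names are illustrative assumptions, not values taken from ppfleetx; the sketch only shows how the query/key/value projections, the two bias terms, the sigmoid gate, and the output projection fit together.

# Unfused reference sketch of gated attention (shapes are assumptions).
import paddle
import paddle.nn.functional as nnf  # plain paddle.nn.functional, not the incubate API

n, b, q_len, k_len = 1, 2, 16, 16      # batch, MSA rows, query length, key length (assumed)
q_dim = kv_dim = 32                    # channel dims; merge_qkv would require q_dim == kv_dim
h, c, o = 4, 8, 32                     # num_head, per-head dim, output dim (assumed)

q_data = paddle.randn([n, b, q_len, q_dim])
m_data = paddle.randn([n, b, k_len, kv_dim])
bias = paddle.zeros([n, b, 1, 1, k_len])               # attention mask bias, broadcast over heads/queries
nonbatched_bias = paddle.randn([n, h, q_len, k_len])   # e.g. pair bias, shared across MSA rows

query_w = paddle.randn([q_dim, h, c])
key_w = paddle.randn([kv_dim, h, c])
value_w = paddle.randn([kv_dim, h, c])
gating_w = paddle.randn([q_dim, h, c])
gating_b = paddle.zeros([h, c])
output_w = paddle.randn([h, c, o])
output_b = paddle.zeros([o])

# Project and scale queries, project keys and values (same einsum layout as the else branch).
scale = c ** (-0.5)
q = paddle.einsum('nbqa,ahc->nbqhc', q_data, query_w) * scale
k = paddle.einsum('nbka,ahc->nbkhc', m_data, key_w)
v = paddle.einsum('nbka,ahc->nbkhc', m_data, value_w)

# Attention logits plus the mask bias and the per-head nonbatched bias (unsqueezed at axis=1).
logits = paddle.einsum('nbqhc,nbkhc->nbhqk', q, k) + bias
logits = logits + paddle.unsqueeze(nonbatched_bias, axis=1)   # [n, 1, h, q, k] broadcasts over b
weights = nnf.softmax(logits, axis=-1)
weighted = paddle.einsum('nbhqk,nbkhc->nbqhc', weights, v)

# Sigmoid gating followed by the output projection.
gate = nnf.sigmoid(paddle.einsum('nbqa,ahc->nbqhc', q_data, gating_w) + gating_b)
output = paddle.einsum('nbqhc,hco->nbqo', weighted * gate, output_w) + output_b
print(output.shape)   # [n, b, q_len, o]

The fused path in the patch performs these steps in a single op; use_flash_attn additionally switches its core softmax(QK^T + bias)V computation to a flash-attention implementation, which changes memory traffic but not the mathematical result.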