Support ofa export for bert (PaddlePaddle#1326)
* fix ofa export bug

* support bert export
LiuChiachi authored Nov 17, 2021
1 parent 168f058 commit 54bef3d
Showing 1 changed file with 33 additions and 15 deletions.
examples/model_compression/ofa/export_model.py (33 additions & 15 deletions)
@@ -27,35 +27,53 @@
 import paddle.nn.functional as F

 from paddlenlp.transformers import BertModel, BertForSequenceClassification, BertTokenizer
-from paddlenlp.transformers import TinyBertModel, TinyBertForSequenceClassification, TinyBertTokenizer
+from paddlenlp.transformers import TinyBertForSequenceClassification, TinyBertTokenizer
 from paddlenlp.transformers import RobertaForSequenceClassification, RobertaTokenizer
 from paddlenlp.utils.log import logger
 from paddleslim.nas.ofa import OFA, utils
 from paddleslim.nas.ofa.convert_super import Convert, supernet
 from paddleslim.nas.ofa.layers import BaseBlock

-MODEL_CLASSES = {
-    "bert": (BertForSequenceClassification, BertTokenizer),
-    "roberta": (RobertaForSequenceClassification, RobertaTokenizer),
-    "tinybert": (TinyBertForSequenceClassification, TinyBertTokenizer),
-}
+MODEL_CLASSES = {"bert": (BertForSequenceClassification, BertTokenizer), }


-def tinybert_forward(self, input_ids, token_type_ids=None, attention_mask=None):
+def bert_forward(self,
+                 input_ids,
+                 token_type_ids=None,
+                 position_ids=None,
+                 attention_mask=None,
+                 output_hidden_states=False):
     wtype = self.pooler.dense.fn.weight.dtype if hasattr(
         self.pooler.dense, 'fn') else self.pooler.dense.weight.dtype
     if attention_mask is None:
         attention_mask = paddle.unsqueeze(
             (input_ids == self.pad_token_id).astype(wtype) * -1e9, axis=[1, 2])
-    embedding_output = self.embeddings(input_ids, token_type_ids)
-    encoded_layer = self.encoder(embedding_output, attention_mask)
-    pooled_output = self.pooler(encoded_layer)
-
-    return encoded_layer, pooled_output
+    else:
+        if attention_mask.ndim == 2:
+            # attention_mask [batch_size, sequence_length] -> [batch_size, 1, 1, sequence_length]
+            attention_mask = attention_mask.unsqueeze(axis=[1, 2])
+
+    embedding_output = self.embeddings(
+        input_ids=input_ids,
+        position_ids=position_ids,
+        token_type_ids=token_type_ids)
+    if output_hidden_states:
+        output = embedding_output
+        encoder_outputs = []
+        for mod in self.encoder.layers:
+            output = mod(output, src_mask=attention_mask)
+            encoder_outputs.append(output)
+        if self.encoder.norm is not None:
+            encoder_outputs[-1] = self.encoder.norm(encoder_outputs[-1])
+        pooled_output = self.pooler(encoder_outputs[-1])
+    else:
+        sequence_output = self.encoder(embedding_output, attention_mask)
+        pooled_output = self.pooler(sequence_output)
+    if output_hidden_states:
+        return encoder_outputs, pooled_output
+    else:
+        return sequence_output, pooled_output


-TinyBertModel.forward = tinybert_forward
+BertModel.forward = bert_forward


 def parse_args():
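For context, the patched bert_forward keeps BertModel traceable when attention_mask and position_ids are left as None, which is what a static-graph export needs. The repository's export_model.py drives the export through the PaddleSlim OFA utilities imported above; the snippet below is only a minimal plain-Paddle sketch of saving an inference model once the patch is in effect, and the model name and output path are placeholders:

    import paddle
    from paddlenlp.transformers import BertForSequenceClassification

    # Assumes the patch above (BertModel.forward = bert_forward) has been applied.
    model = BertForSequenceClassification.from_pretrained("bert-base-uncased")
    model.eval()

    # With attention_mask/position_ids left as None, the patched forward only needs
    # input_ids and token_type_ids, so two dynamic-shape InputSpecs are enough.
    input_spec = [
        paddle.static.InputSpec(shape=[None, None], dtype="int64", name="input_ids"),
        paddle.static.InputSpec(shape=[None, None], dtype="int64", name="token_type_ids"),
    ]
    static_model = paddle.jit.to_static(model, input_spec=input_spec)
    paddle.jit.save(static_model, "./infer_model/bert")  # writes bert.pdmodel / bert.pdiparams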
