# encoder.py (forked from LIN-WANQIU/ADI17)

import torch.nn as nn
from attention import MultiHeadAttention
from module import PositionalEncoding, PositionwiseFeedForward
from utils import get_non_pad_mask, get_attn_pad_mask


class Encoder(nn.Module):
    """Encoder of Transformer including self-attention and feed forward."""

    def __init__(self, d_input, n_layers, n_head, d_k, d_v,
                 d_model, d_inner, dropout=0.1, pe_maxlen=5000):
        super(Encoder, self).__init__()
        # parameters
        self.d_input = d_input
        self.n_layers = n_layers
        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v
        self.d_model = d_model
        self.d_inner = d_inner
        self.dropout_rate = dropout
        self.pe_maxlen = pe_maxlen

        # use linear transformation with layer norm to replace input embedding
        self.linear_in = nn.Linear(d_input, d_model)
        self.layer_norm_in = nn.LayerNorm(d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_len=pe_maxlen)
        self.dropout = nn.Dropout(dropout)

        self.layer_stack = nn.ModuleList([
            EncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
            for _ in range(n_layers)])

    def forward(self, padded_input, input_lengths, return_attns=False):
        """
        Args:
            padded_input: N x T x D
            input_lengths: N
        Returns:
            enc_output: N x T x H, wrapped in a tuple; the per-layer
                self-attention weights are returned as well when
                return_attns is True
        """
        enc_slf_attn_list = []

        # Prepare masks
        # non_pad_mask zeroes out padded time steps of the layer outputs,
        # slf_attn_mask keeps self-attention from attending to padded positions
        non_pad_mask = get_non_pad_mask(padded_input, input_lengths=input_lengths)
        length = padded_input.size(1)
        slf_attn_mask = get_attn_pad_mask(padded_input, input_lengths, length)

        # Forward: project inputs to d_model, add positional encoding, apply dropout
        enc_output = self.dropout(
            self.layer_norm_in(self.linear_in(padded_input)) +
            self.positional_encoding(padded_input))

        for enc_layer in self.layer_stack:
            enc_output, enc_slf_attn = enc_layer(
                enc_output,
                non_pad_mask=non_pad_mask,
                slf_attn_mask=slf_attn_mask)
            if return_attns:
                enc_slf_attn_list += [enc_slf_attn]

        if return_attns:
            return enc_output, enc_slf_attn_list
        # Return a 1-tuple so callers can unpack the result uniformly
        return enc_output,


class EncoderLayer(nn.Module):
    """Compose with two sub-layers.

    1. A multi-head self-attention mechanism
    2. A simple, position-wise fully connected feed-forward network.
    """

    def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1):
        super(EncoderLayer, self).__init__()
        self.slf_attn = MultiHeadAttention(
            n_head, d_model, d_k, d_v, dropout=dropout)
        self.pos_ffn = PositionwiseFeedForward(
            d_model, d_inner, dropout=dropout)

    def forward(self, enc_input, non_pad_mask=None, slf_attn_mask=None):
        # Self-attention: queries, keys and values all come from the encoder input
        enc_output, enc_slf_attn = self.slf_attn(
            enc_input, enc_input, enc_input, mask=slf_attn_mask)
        enc_output *= non_pad_mask

        # Position-wise feed-forward network, again masking padded time steps
        enc_output = self.pos_ffn(enc_output)
        enc_output *= non_pad_mask

        return enc_output, enc_slf_attn
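

# --- Usage sketch (illustrative, not part of the original repository) ---
# A minimal example of running the encoder on dummy filterbank-like features.
# All hyperparameters below (80-dim input, 6 layers, 8 heads, d_model=512,
# d_inner=2048) and the dummy shapes are assumptions chosen for illustration;
# it also assumes the imported attention/module/utils helpers are available.
if __name__ == "__main__":
    import torch

    encoder = Encoder(d_input=80, n_layers=6, n_head=8, d_k=64, d_v=64,
                      d_model=512, d_inner=2048, dropout=0.1)

    # Batch of 4 utterances padded to 100 frames of 80-dimensional features.
    padded_input = torch.randn(4, 100, 80)
    input_lengths = torch.tensor([100, 90, 75, 60])

    # forward() returns a 1-tuple unless return_attns=True, hence the trailing comma.
    enc_output, = encoder(padded_input, input_lengths)
    print(enc_output.shape)  # expected: torch.Size([4, 100, 512])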