transNAR.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class TransNAR(nn.Module):
    def __init__(self, input_dim, output_dim, embed_dim, num_heads, num_layers, ffn_dim, dropout=0.1):
        super(TransNAR, self).__init__()
        # Embedding layer
        self.embedding = nn.Linear(input_dim, embed_dim)
        self.pos_encoding = PositionalEncoding(embed_dim, dropout)
        # Transformer layers
        self.transformer_layers = nn.ModuleList([
            TransformerLayer(embed_dim, num_heads, ffn_dim, dropout)
            for _ in range(num_layers)
        ])
        # Neural Algorithmic Reasoner (NAR)
        self.nar = NAR(embed_dim)
        # Decoder
        self.decoder = nn.Linear(embed_dim * 2, output_dim)
        # Final normalization layer
        self.final_norm = nn.LayerNorm(output_dim)

    def forward(self, x):
        # Embedding and positional encoding
        x = self.embedding(x)
        x = self.pos_encoding(x)
        # Transformer layers
        for layer in self.transformer_layers:
            x = layer(x)
        # Neural Algorithmic Reasoner
        nar_output = self.nar(x)
        # Concatenate the Transformer and NAR outputs
        combined = torch.cat([x, nar_output], dim=-1)
        # Decoding
        output = self.decoder(combined)
        # Final normalization
        output = self.final_norm(output)
        return output

class TransformerLayer(nn.Module):
    def __init__(self, embed_dim, num_heads, ffn_dim, dropout=0.1):
        super(TransformerLayer, self).__init__()
        # batch_first=True so attention accepts (batch, seq, embed) tensors,
        # matching the batch-first layout used everywhere else in the model
        self.self_attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, ffn_dim),
            nn.ReLU(),
            nn.Linear(ffn_dim, embed_dim)
        )
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Self-attention with residual connection and post-norm
        attn_output, _ = self.self_attn(x, x, x)
        x = x + self.dropout(attn_output)
        x = self.norm1(x)
        # Feed-forward with residual connection and post-norm
        ffn_output = self.ffn(x)
        x = x + self.dropout(ffn_output)
        x = self.norm2(x)
        return x

class NAR(nn.Module):
    def __init__(self, embed_dim):
        super(NAR, self).__init__()
        self.reasoning_layers = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 2),
            nn.ReLU(),
            nn.Linear(embed_dim * 2, embed_dim),
            nn.Tanh()
        )
        self.gru = nn.GRU(embed_dim, embed_dim, batch_first=True)
        self.output_layer = nn.Linear(embed_dim, embed_dim)  # Extra layer to adjust the output

    def forward(self, x):
        reasoned = self.reasoning_layers(x)
        output, _ = self.gru(reasoned)
        output = self.output_layer(output)  # Project back to embed_dim
        return output

class PositionalEncoding(nn.Module):
    def __init__(self, embed_dim, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        # Initialize the positional encoding tensor (sinusoidal)
        pe = torch.zeros(max_len, embed_dim)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embed_dim, 2).float() * (-math.log(10000.0) / embed_dim))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # (1, max_len, embed_dim), broadcastable over the batch dimension
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: (batch, seq_len, embed_dim); add the encoding along the sequence dimension
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)

# Example usage
input_dim = 100
output_dim = 50
embed_dim = 256
num_heads = 8
num_layers = 6
ffn_dim = 1024

model = TransNAR(input_dim, output_dim, embed_dim, num_heads, num_layers, ffn_dim)
input_data = torch.randn(32, 100, input_dim)  # (batch_size=32, seq_len=100, input_dim)
output = model(input_data)
print(output.shape)  # Should print torch.Size([32, 100, 50])
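
# --- Minimal training sketch (illustrative only, not part of the original file) ---
# A hedged example of how TransNAR might be fitted; the MSE loss, Adam optimizer,
# learning rate, and random target tensor below are assumptions chosen only to
# show the tensor shapes flowing end to end through a backward pass.
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
target = torch.randn(32, 100, output_dim)  # hypothetical regression target

for step in range(3):  # a few illustrative optimization steps
    optimizer.zero_grad()
    prediction = model(input_data)
    loss = criterion(prediction, target)
    loss.backward()
    optimizer.step()
    print(f"step {step}: loss = {loss.item():.4f}")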