-
Notifications
You must be signed in to change notification settings - Fork 3
/
model.py
69 lines (53 loc) · 2.07 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from layer import LPAconv
class MLP(nn.Module):
    """Two-layer perceptron applied to the node features in ``data.x``.

    The forward pass is ``Linear -> ReLU -> dropout -> Linear``. Graph
    structure is not used; only the ``x`` attribute of the input is read.

    Args:
        in_feature: Input size of the first linear layer.
        hidden: Output size of the first linear layer (and input size of
            the second).
        out_feature: Output size of the second linear layer.
        dropout: Dropout probability applied after the ReLU.
    """

    def __init__(self, in_feature, hidden, out_feature, dropout):
        super().__init__()
        # Submodules are registered in this order so state_dict keys and
        # seeded initialisation stay stable.
        self.fc1 = nn.Linear(in_feature, hidden)
        self.fc2 = nn.Linear(hidden, out_feature)
        self.relu = nn.ReLU()
        self.dropout_rate = dropout

    def forward(self, data):
        """Return the output of the second linear layer for ``data.x``."""
        activated = self.relu(self.fc1(data.x))
        # Dropout is only active while the module is in training mode.
        activated = F.dropout(
            activated, p=self.dropout_rate, training=self.training
        )
        return self.fc2(activated)
class GCN(nn.Module):
    """Two-layer graph convolutional network built from ``GCNConv``.

    The forward pass is ``GCNConv -> ReLU -> dropout -> GCNConv`` and reads
    ``data.x`` and ``data.edge_index`` from its input.

    Args:
        in_feature: Input channel count of the first convolution.
        hidden: Output channel count of the first convolution (and input
            channel count of the second).
        out_feature: Output channel count of the second convolution.
        dropout: Dropout probability applied after the ReLU.
    """

    def __init__(self, in_feature, hidden, out_feature, dropout):
        super().__init__()
        self.conv1 = GCNConv(in_feature, hidden)
        self.conv2 = GCNConv(hidden, out_feature)
        self.dropout_rate = dropout

    def forward(self, data):
        """Return the output of the second convolution for ``data``."""
        features, edges = data.x, data.edge_index
        hidden_repr = F.relu(self.conv1(features, edges))
        # Dropout is only active while the module is in training mode.
        hidden_repr = F.dropout(
            hidden_repr, p=self.dropout_rate, training=self.training
        )
        return self.conv2(hidden_repr, edges)
class GCN_LPA(nn.Module):
    """Stack of ``GCNConv`` layers plus an ``LPAconv`` sharing edge weights.

    A single learnable vector ``edge_weight`` (initialised to ones, one
    entry per edge) is passed to every ``GCNConv`` layer and to the
    ``LPAconv`` module.

    Args:
        in_feature: Input channel count of the first convolution.
        hidden: Channel count of every intermediate convolution.
        out_feature: Output channel count of the last convolution.
        dropout: Dropout probability applied after each non-final layer.
        num_edges: Length of the learnable ``edge_weight`` vector.
        lpaiters: Forwarded unchanged to ``LPAconv``.
            NOTE(review): presumably the number of label-propagation
            iterations -- confirm against ``layer.LPAconv``.
        gcnnum: Requested number of ``GCNConv`` layers. At least two layers
            are always built; values below 2 still yield exactly two.
    """

    def __init__(self, in_feature, hidden, out_feature, dropout, num_edges, lpaiters, gcnnum):
        super().__init__()
        self.edge_weight = nn.Parameter(torch.ones(num_edges))

        # Channel widths at every layer boundary: the input and output layers
        # always exist, with (gcnnum - 2) hidden->hidden layers in between
        # (none when gcnnum <= 2).
        widths = [in_feature, hidden] + [hidden] * (gcnnum - 2) + [out_feature]
        self.gc = nn.ModuleList(
            [GCNConv(n_in, n_out) for n_in, n_out in zip(widths, widths[1:])]
        )

        self.lpa = LPAconv(lpaiters)
        self.dropout_rate = dropout

    def forward(self, data, mask):
        """Run the GCN stack and the LPA module on ``data``.

        Args:
            data: Object exposing ``x``, ``edge_index`` and ``y``.
            mask: Passed through to ``LPAconv`` as its third argument.

        Returns:
            A tuple ``(x, y_hat)``: the output of the last ``GCNConv`` layer
            and the result of ``LPAconv`` called on ``data.y``.
        """
        features, edges, labels = data.x, data.edge_index, data.y
        weights = self.edge_weight
        layers = list(self.gc)

        # Every layer except the last is followed by ReLU and dropout.
        for conv in layers[:-1]:
            features = F.relu(conv(features, edges, weights))
            features = F.dropout(
                features, p=self.dropout_rate, training=self.training
            )
        features = layers[-1](features, edges, weights)

        propagated = self.lpa(labels, edges, mask, weights)
        return features, propagated