-
Notifications
You must be signed in to change notification settings - Fork 10
/
Physics_Attention.py
178 lines (153 loc) · 7.8 KB
/
Physics_Attention.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
import torch.nn as nn
import torch
from einops import rearrange, repeat
class Physics_Attention_Irregular_Mesh(nn.Module):
    """Physics-Attention for irregular meshes in 1D, 2D or 3D space.

    The layer works in three steps: (1) every mesh point is softly assigned
    to ``slice_num`` slices and the point features are pooled into one token
    per slice and head, (2) scaled dot-product attention is applied among the
    slice tokens, (3) the attended tokens are scattered back to the points
    with the same assignment weights.

    Args:
        dim: Channel size of the input and of the output.
        heads: Number of attention heads.
        dim_head: Channel size of each head.
        dropout: Dropout probability used on the attention map and after the
            output projection.
        slice_num: Number of slice tokens per head.
    """

    def __init__(self, dim, heads=8, dim_head=64, dropout=0., slice_num=64):
        super().__init__()
        inner_dim = dim_head * heads
        self.dim_head = dim_head
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.softmax = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)
        # Learnable per-head temperature of the slice-assignment softmax.
        self.temperature = nn.Parameter(torch.ones([1, heads, 1, 1]) * 0.5)

        self.in_project_x = nn.Linear(dim, inner_dim)
        self.in_project_fx = nn.Linear(dim, inner_dim)
        self.in_project_slice = nn.Linear(dim_head, slice_num)
        torch.nn.init.orthogonal_(self.in_project_slice.weight)  # use a principled initialization
        self.to_q = nn.Linear(dim_head, dim_head, bias=False)
        self.to_k = nn.Linear(dim_head, dim_head, bias=False)
        self.to_v = nn.Linear(dim_head, dim_head, bias=False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    def _split_heads(self, feat, batch, num_points):
        """Reshape projected point features from (B, N, H*C) to (B, H, N, C)."""
        feat = feat.reshape(batch, num_points, self.heads, self.dim_head)
        return feat.permute(0, 2, 1, 3).contiguous()

    def forward(self, x):
        """Apply Physics-Attention to a batch of point features.

        Args:
            x: Tensor of shape ``(B, N, C)`` with ``C == dim``.

        Returns:
            Tensor of shape ``(B, N, dim)``.
        """
        # B N C
        batch, num_points, _ = x.shape

        ### (1) Slice
        fx_mid = self._split_heads(self.in_project_fx(x), batch, num_points)  # B H N C
        x_mid = self._split_heads(self.in_project_x(x), batch, num_points)  # B H N C
        # Clamp the learnable temperature, as the structured-mesh variants do:
        # a value that reaches zero or turns negative would otherwise produce
        # NaNs or invert the slice assignment.
        temperature = torch.clamp(self.temperature, min=0.1, max=5)
        slice_weights = self.softmax(self.in_project_slice(x_mid) / temperature)  # B H N G
        slice_norm = slice_weights.sum(2)  # B H G
        slice_token = torch.einsum("bhnc,bhng->bhgc", fx_mid, slice_weights)
        # Weighted mean over the points; the trailing axis broadcasts over C.
        slice_token = slice_token / (slice_norm + 1e-5).unsqueeze(-1)  # B H G C

        ### (2) Attention among slice tokens
        q_slice_token = self.to_q(slice_token)
        k_slice_token = self.to_k(slice_token)
        v_slice_token = self.to_v(slice_token)
        dots = torch.matmul(q_slice_token, k_slice_token.transpose(-1, -2)) * self.scale
        attn = self.dropout(self.softmax(dots))
        out_slice_token = torch.matmul(attn, v_slice_token)  # B H G D

        ### (3) Deslice
        out_x = torch.einsum("bhgc,bhng->bhnc", out_slice_token, slice_weights)
        # Merge the heads back: b h n d -> b n (h d)
        out_x = out_x.permute(0, 2, 1, 3).reshape(batch, num_points, self.heads * self.dim_head)
        return self.to_out(out_x)
class Physics_Attention_Structured_Mesh_2D(nn.Module):
    """Physics-Attention for structured meshes in 2D space.

    The flattened input points are rearranged onto an ``H x W`` grid and
    projected with 2D convolutions. The layer then (1) softly assigns every
    grid point to ``slice_num`` slices and pools the features into one token
    per slice and head, (2) applies scaled dot-product attention among the
    slice tokens, and (3) scatters the attended tokens back to the points
    with the same assignment weights.

    Args:
        dim: Channel size of the input and of the output.
        heads: Number of attention heads.
        dim_head: Channel size of each head.
        dropout: Dropout probability used on the attention map and after the
            output projection.
        slice_num: Number of slice tokens per head.
        H: Grid size along the first mesh axis.
        W: Grid size along the second mesh axis.
        kernel: Kernel size of the input convolutions; must be odd.

    Raises:
        ValueError: If ``kernel`` is even. The convolutions pad by
            ``kernel // 2``, which keeps the ``H x W`` grid size only for odd
            kernels; an even kernel would make every forward pass fail.
    """

    def __init__(self, dim, heads=8, dim_head=64, dropout=0., slice_num=64, H=101, W=31, kernel=3):
        super().__init__()
        if kernel % 2 == 0:
            raise ValueError(
                f"kernel must be odd so that padding=kernel // 2 preserves the H x W grid, got {kernel}"
            )
        inner_dim = dim_head * heads
        self.dim_head = dim_head
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.softmax = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)
        # Learnable per-head temperature of the slice-assignment softmax.
        self.temperature = nn.Parameter(torch.ones([1, heads, 1, 1]) * 0.5)
        self.H = H
        self.W = W

        self.in_project_x = nn.Conv2d(dim, inner_dim, kernel, 1, kernel // 2)
        self.in_project_fx = nn.Conv2d(dim, inner_dim, kernel, 1, kernel // 2)
        self.in_project_slice = nn.Linear(dim_head, slice_num)
        torch.nn.init.orthogonal_(self.in_project_slice.weight)  # use a principled initialization
        self.to_q = nn.Linear(dim_head, dim_head, bias=False)
        self.to_k = nn.Linear(dim_head, dim_head, bias=False)
        self.to_v = nn.Linear(dim_head, dim_head, bias=False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    def _split_heads(self, feat, batch, num_points):
        """Reshape a conv output from (B, H*C, H_grid, W_grid) to (B, H, N, C)."""
        feat = feat.permute(0, 2, 3, 1).contiguous()  # channels last
        feat = feat.reshape(batch, num_points, self.heads, self.dim_head)
        return feat.permute(0, 2, 1, 3).contiguous()

    def forward(self, x):
        """Apply Physics-Attention to a batch of flattened grid features.

        Args:
            x: Tensor of shape ``(B, N, C)`` with ``N == H * W`` (the reshape
                onto the grid fails otherwise) and ``C == dim``.

        Returns:
            Tensor of shape ``(B, N, dim)``.
        """
        # B N C
        batch, num_points, channels = x.shape
        x = x.reshape(batch, self.H, self.W, channels).contiguous() \
            .permute(0, 3, 1, 2).contiguous()  # B C H W

        ### (1) Slice
        fx_mid = self._split_heads(self.in_project_fx(x), batch, num_points)  # B H N C
        x_mid = self._split_heads(self.in_project_x(x), batch, num_points)  # B H N C
        temperature = torch.clamp(self.temperature, min=0.1, max=5)
        slice_weights = self.softmax(self.in_project_slice(x_mid) / temperature)  # B H N G
        slice_norm = slice_weights.sum(2)  # B H G
        slice_token = torch.einsum("bhnc,bhng->bhgc", fx_mid, slice_weights)
        # Weighted mean over the points; the trailing axis broadcasts over C.
        slice_token = slice_token / (slice_norm + 1e-5).unsqueeze(-1)  # B H G C

        ### (2) Attention among slice tokens
        q_slice_token = self.to_q(slice_token)
        k_slice_token = self.to_k(slice_token)
        v_slice_token = self.to_v(slice_token)
        dots = torch.matmul(q_slice_token, k_slice_token.transpose(-1, -2)) * self.scale
        attn = self.dropout(self.softmax(dots))
        out_slice_token = torch.matmul(attn, v_slice_token)  # B H G D

        ### (3) Deslice
        out_x = torch.einsum("bhgc,bhng->bhnc", out_slice_token, slice_weights)
        # Merge the heads back: b h n d -> b n (h d)
        out_x = out_x.permute(0, 2, 1, 3).reshape(batch, num_points, self.heads * self.dim_head)
        return self.to_out(out_x)
class Physics_Attention_Structured_Mesh_3D(nn.Module):
    """Physics-Attention for structured meshes in 3D space.

    The flattened input points are rearranged onto an ``H x W x D`` grid and
    projected with 3D convolutions. The layer then (1) softly assigns every
    grid point to ``slice_num`` slices and pools the features into one token
    per slice and head, (2) applies scaled dot-product attention among the
    slice tokens, and (3) scatters the attended tokens back to the points
    with the same assignment weights.

    Args:
        dim: Channel size of the input and of the output.
        heads: Number of attention heads.
        dim_head: Channel size of each head.
        dropout: Dropout probability used on the attention map and after the
            output projection.
        slice_num: Number of slice tokens per head.
        H: Grid size along the first mesh axis.
        W: Grid size along the second mesh axis.
        D: Grid size along the third mesh axis.
        kernel: Kernel size of the input convolutions; must be odd.

    Raises:
        ValueError: If ``kernel`` is even. The convolutions pad by
            ``kernel // 2``, which keeps the ``H x W x D`` grid size only for
            odd kernels; an even kernel would make every forward pass fail.
    """

    def __init__(self, dim, heads=8, dim_head=64, dropout=0., slice_num=32, H=32, W=32, D=32, kernel=3):
        super().__init__()
        if kernel % 2 == 0:
            raise ValueError(
                f"kernel must be odd so that padding=kernel // 2 preserves the H x W x D grid, got {kernel}"
            )
        inner_dim = dim_head * heads
        self.dim_head = dim_head
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.softmax = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)
        # Learnable per-head temperature of the slice-assignment softmax.
        self.temperature = nn.Parameter(torch.ones([1, heads, 1, 1]) * 0.5)
        self.H = H
        self.W = W
        self.D = D

        self.in_project_x = nn.Conv3d(dim, inner_dim, kernel, 1, kernel // 2)
        self.in_project_fx = nn.Conv3d(dim, inner_dim, kernel, 1, kernel // 2)
        self.in_project_slice = nn.Linear(dim_head, slice_num)
        torch.nn.init.orthogonal_(self.in_project_slice.weight)  # use a principled initialization
        self.to_q = nn.Linear(dim_head, dim_head, bias=False)
        self.to_k = nn.Linear(dim_head, dim_head, bias=False)
        self.to_v = nn.Linear(dim_head, dim_head, bias=False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    def _split_heads(self, feat, batch, num_points):
        """Reshape a conv output from (B, H*C, H_grid, W_grid, D_grid) to (B, H, N, C)."""
        feat = feat.permute(0, 2, 3, 4, 1).contiguous()  # channels last
        feat = feat.reshape(batch, num_points, self.heads, self.dim_head)
        return feat.permute(0, 2, 1, 3).contiguous()

    def forward(self, x):
        """Apply Physics-Attention to a batch of flattened grid features.

        Args:
            x: Tensor of shape ``(B, N, C)`` with ``N == H * W * D`` (the
                reshape onto the grid fails otherwise) and ``C == dim``.

        Returns:
            Tensor of shape ``(B, N, dim)``.
        """
        # B N C
        batch, num_points, channels = x.shape
        x = x.reshape(batch, self.H, self.W, self.D, channels).contiguous() \
            .permute(0, 4, 1, 2, 3).contiguous()  # B C H W D

        ### (1) Slice
        fx_mid = self._split_heads(self.in_project_fx(x), batch, num_points)  # B H N C
        x_mid = self._split_heads(self.in_project_x(x), batch, num_points)  # B H N C
        temperature = torch.clamp(self.temperature, min=0.1, max=5)
        slice_weights = self.softmax(self.in_project_slice(x_mid) / temperature)  # B H N G
        slice_norm = slice_weights.sum(2)  # B H G
        slice_token = torch.einsum("bhnc,bhng->bhgc", fx_mid, slice_weights)
        # Weighted mean over the points; the trailing axis broadcasts over C.
        slice_token = slice_token / (slice_norm + 1e-5).unsqueeze(-1)  # B H G C

        ### (2) Attention among slice tokens
        q_slice_token = self.to_q(slice_token)
        k_slice_token = self.to_k(slice_token)
        v_slice_token = self.to_v(slice_token)
        dots = torch.matmul(q_slice_token, k_slice_token.transpose(-1, -2)) * self.scale
        attn = self.dropout(self.softmax(dots))
        out_slice_token = torch.matmul(attn, v_slice_token)  # B H G D

        ### (3) Deslice
        out_x = torch.einsum("bhgc,bhng->bhnc", out_slice_token, slice_weights)
        # Merge the heads back: b h n d -> b n (h d)
        out_x = out_x.permute(0, 2, 1, 3).reshape(batch, num_points, self.heads * self.dim_head)
        return self.to_out(out_x)