Skip to content

Commit

Permalink
added crzoDrgn encoder
Browse files Browse the repository at this point in the history
  • Loading branch information
Dingel321 committed Aug 5, 2024
1 parent 67ba7e6 commit fb76691
Showing 1 changed file with 30 additions and 0 deletions.
30 changes: 30 additions & 0 deletions src/cryo_sbi/inference/models/embedding_nets.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,36 @@ def forward(self, x):
return x


@add_embedding("CRYODRGN_ENCODER")
class CryoDrgn_Encoder(nn.Module):
def __init__(self, output_dimension: int):
super(CryoDrgn_Encoder, self).__init__()
self.mlp = nn.Sequential(
nn.Linear(12892, 1024),
nn.GELU(),
nn.Linear(1024, 512),
nn.GELU(),
nn.Linear(512, 512),
nn.GELU(),
nn.Linear(512, 512),
nn.GELU(),
nn.Linear(512, 512),
nn.GELU(),
nn.Linear(512, output_dimension),
nn.GELU(),
)
self.mask = Mask(128, 64, inside=True).mask.flatten()
print("Using DRGN Encoder")

def forward(self, x):
if x.dim == 2:
x = x.unsqueeze(0)
x = torch.fft.fftshift(torch.fft.fft2(x, dim=(-2, -1))).real
x = x.flatten(start_dim=1)[:, self.mask]
x = self.mlp(x)
return x


@add_embedding("RESNET18")
class ResNet18_Encoder(nn.Module):
def __init__(self, output_dimension: int):
Expand Down

0 comments on commit fb76691

Please sign in to comment.