diff --git a/src/cryo_sbi/inference/models/embedding_nets.py b/src/cryo_sbi/inference/models/embedding_nets.py index e502ff2..c7d0552 100644 --- a/src/cryo_sbi/inference/models/embedding_nets.py +++ b/src/cryo_sbi/inference/models/embedding_nets.py @@ -58,6 +58,36 @@ def forward(self, x): return x +@add_embedding("CRYODRGN_ENCODER") +class CryoDrgn_Encoder(nn.Module): + def __init__(self, output_dimension: int): + super(CryoDrgn_Encoder, self).__init__() + self.mlp = nn.Sequential( + nn.Linear(12892, 1024), + nn.GELU(), + nn.Linear(1024, 512), + nn.GELU(), + nn.Linear(512, 512), + nn.GELU(), + nn.Linear(512, 512), + nn.GELU(), + nn.Linear(512, 512), + nn.GELU(), + nn.Linear(512, output_dimension), + nn.GELU(), + ) + self.mask = Mask(128, 64, inside=True).mask.flatten() + print("Using DRGN Encoder") + + def forward(self, x): + if x.dim == 2: + x = x.unsqueeze(0) + x = torch.fft.fftshift(torch.fft.fft2(x, dim=(-2, -1))).real + x = x.flatten(start_dim=1)[:, self.mask] + x = self.mlp(x) + return x + + @add_embedding("RESNET18") class ResNet18_Encoder(nn.Module): def __init__(self, output_dimension: int):