Skip to content

Commit

Permalink
centre random augmentation needs to be done for all batch samples sep…
Browse files Browse the repository at this point in the history
…arately
  • Loading branch information
lucidrains committed May 24, 2024
1 parent b93e278 commit 57f6da2
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 18 deletions.
44 changes: 27 additions & 17 deletions alphafold3_pytorch/alphafold3.py
Original file line number Diff line number Diff line change
Expand Up @@ -2120,53 +2120,63 @@ class CentreRandomAugmentation(Module):
def __init__(self, trans_scale: float = 1.0):
super().__init__()
self.trans_scale = trans_scale
self.register_buffer('dummy', torch.tensor(0), persistent = False)

@property
def device(self):
return self.dummy.device

@typecheck
def forward(self, coords: Float['b n 3']) -> Float['b n 3']:
"""
coords: coordinates to be augmented
"""
batch_size = coords.shape[0]

# Center the coordinates
centered_coords = coords - coords.mean(dim=1, keepdim=True)

# Generate random rotation matrix
rotation_matrix = self._random_rotation_matrix(coords.device)
rotation_matrix = self._random_rotation_matrix(batch_size)

# Generate random translation vector
translation_vector = self._random_translation_vector(coords.device)
translation_vector = self._random_translation_vector(batch_size)
translation_vector = rearrange(translation_vector, 'b c -> b 1 c')

# Apply rotation and translation
augmented_coords = torch.einsum('bni,ij->bnj', centered_coords, rotation_matrix) + translation_vector
augmented_coords = einsum(centered_coords, rotation_matrix, 'b n i, b i j -> b n j') + translation_vector

return augmented_coords

@typecheck
def _random_rotation_matrix(self, device: torch.device) -> Float['3 3']:
def _random_rotation_matrix(self, batch_size: int) -> Float['b 3 3']:
# Generate random rotation angles
angles = torch.rand(3, device=device) * 2 * torch.pi
angles = torch.rand((batch_size, 3), device = self.device) * 2 * torch.pi

# Compute sine and cosine of angles
sin_angles = torch.sin(angles)
cos_angles = torch.cos(angles)

# Construct rotation matrix
rotation_matrix = torch.eye(3, device=device)
rotation_matrix[0, 0] = cos_angles[0] * cos_angles[1]
rotation_matrix[0, 1] = cos_angles[0] * sin_angles[1] * sin_angles[2] - sin_angles[0] * cos_angles[2]
rotation_matrix[0, 2] = cos_angles[0] * sin_angles[1] * cos_angles[2] + sin_angles[0] * sin_angles[2]
rotation_matrix[1, 0] = sin_angles[0] * cos_angles[1]
rotation_matrix[1, 1] = sin_angles[0] * sin_angles[1] * sin_angles[2] + cos_angles[0] * cos_angles[2]
rotation_matrix[1, 2] = sin_angles[0] * sin_angles[1] * cos_angles[2] - cos_angles[0] * sin_angles[2]
rotation_matrix[2, 0] = -sin_angles[1]
rotation_matrix[2, 1] = cos_angles[1] * sin_angles[2]
rotation_matrix[2, 2] = cos_angles[1] * cos_angles[2]
eye = torch.eye(3, device = self.device)
rotation_matrix = repeat(eye, 'i j -> b i j', b = batch_size).clone()

rotation_matrix[:, 0, 0] = cos_angles[:, 0] * cos_angles[:, 1]
rotation_matrix[:, 0, 1] = cos_angles[:, 0] * sin_angles[:, 1] * sin_angles[:, 2] - sin_angles[:, 0] * cos_angles[:, 2]
rotation_matrix[:, 0, 2] = cos_angles[:, 0] * sin_angles[:, 1] * cos_angles[:, 2] + sin_angles[:, 0] * sin_angles[:, 2]
rotation_matrix[:, 1, 0] = sin_angles[:, 0] * cos_angles[:, 1]
rotation_matrix[:, 1, 1] = sin_angles[:, 0] * sin_angles[:, 1] * sin_angles[:, 2] + cos_angles[:, 0] * cos_angles[:, 2]
rotation_matrix[:, 1, 2] = sin_angles[:, 0] * sin_angles[:, 1] * cos_angles[:, 2] - cos_angles[:, 0] * sin_angles[:, 2]
rotation_matrix[:, 2, 0] = -sin_angles[:, 1]
rotation_matrix[:, 2, 1] = cos_angles[:, 1] * sin_angles[:, 2]
rotation_matrix[:, 2, 2] = cos_angles[:, 1] * cos_angles[:, 2]

return rotation_matrix

@typecheck
def _random_translation_vector(self, device: torch.device) -> Float['3']:
def _random_translation_vector(self, batch_size: int) -> Float['b 3']:
# Generate random translation vector
translation_vector = torch.randn(3, device=device) * self.trans_scale
translation_vector = torch.randn((batch_size, 3), device = self.device) * self.trans_scale
return translation_vector

# input embedder
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "alphafold3-pytorch"
version = "0.0.24"
version = "0.0.25"
description = "Alphafold 3 - Pytorch"
authors = [
{ name = "Phil Wang", email = "[email protected]" }
Expand Down

0 comments on commit 57f6da2

Please sign in to comment.