
Commit

basic validation loop
lucidrains committed May 28, 2024
1 parent d52c3c0 commit 618f323
Showing 2 changed files with 82 additions and 4 deletions.
75 changes: 74 additions & 1 deletion alphafold3_pytorch/trainer.py
@@ -44,6 +44,9 @@ def exists(val):
def default(v, d):
    return v if exists(v) else d

def divisible_by(num, den):
    return (num % den) == 0

def cycle(dataloader: DataLoader):
    while True:
        for batch in dataloader:
@@ -74,6 +77,8 @@ def __init__(
        num_train_steps: int,
        batch_size: int,
        grad_accum_every: int = 1,
        valid_dataset: Dataset | None = None,
        valid_every: int = 1000,
        optimizer: Optimizer | None = None,
        scheduler: LRScheduler | None = None,
        ema_decay = 0.999,
@@ -122,10 +127,22 @@ def __init__(

        self.optimizer = optimizer

        # data
        # train dataloader

        self.dataloader = DataLoader(dataset, batch_size = batch_size, shuffle = True, drop_last = True)

        # validation dataloader on the EMA model
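        # (built only on the main rank, where the EMA model is evaluated)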

        self.valid_every = valid_every

        self.needs_valid = exists(valid_dataset)

        if self.needs_valid and self.is_main:
            self.valid_dataset_size = len(valid_dataset)
            self.valid_dataloader = DataLoader(valid_dataset, batch_size = batch_size)

        # training steps and num gradient accum steps

        self.num_train_steps = num_train_steps
        self.grad_accum_every = grad_accum_every

@@ -154,6 +171,9 @@ def __init__(
    def is_main(self):
        return self.fabric.global_rank == 0

    def wait(self):
        self.fabric.barrier()

    def print(self, *args, **kwargs):
        self.fabric.print(*args, **kwargs)

@@ -165,35 +185,88 @@ def __call__(
    ):
        dl = cycle(self.dataloader)

        # while less than required number of training steps

        while self.steps < self.num_train_steps:

            self.model.train()

            # gradient accumulation

            for grad_accum_step in range(self.grad_accum_every):
                is_accumulating = grad_accum_step < (self.grad_accum_every - 1)
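                # True for every micro-step except the last; used below to skip cross-rank gradient syncing while accumulating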

                inputs = next(dl)

                with self.fabric.no_backward_sync(self.model, enabled = is_accumulating):

                    # model forwards

                    loss, loss_breakdown = self.model(
                        **inputs,
                        return_loss_breakdown = True
                    )

                    # backwards
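                    # (loss is divided by grad_accum_every so the accumulated gradient averages over the micro-batches)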

                    self.fabric.backward(loss / self.grad_accum_every)

            # log entire loss breakdown

            self.log(**loss_breakdown._asdict())

            self.print(f'loss: {loss.item():.3f}')

            # clip gradients

            self.fabric.clip_gradients(self.model, self.optimizer, max_norm = self.clip_grad_norm)

            # optimizer step

            self.optimizer.step()

            # update exponential moving average
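            # (barriers below keep all ranks in step around the main-rank-only EMA update)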

            self.wait()

            if self.is_main:
                self.ema_model.update()

            self.wait()

            # scheduler

            self.scheduler.step()
            self.optimizer.zero_grad()

            self.steps += 1

# maybe validate, for now, only on main with EMA model

if (
self.is_main and
self.needs_valid and
divisible_by(self.steps, self.valid_every)
):
with torch.no_grad():
self.ema_model.eval()

total_valid_loss = 0.

for valid_batch in self.valid_dataloader:
valid_loss, valid_loss_breakdown = self.ema_model(
**valid_batch,
return_loss_breakdown = True
)

                        valid_batch_size = valid_batch.get('atom_inputs').shape[0]
                        scale = valid_batch_size / self.valid_dataset_size
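                        # weighting each batch by its share of the validation set makes the accumulated total the average loss over the whole set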

                        scaled_valid_loss = valid_loss.item() * scale
                        total_valid_loss += scaled_valid_loss

                    self.print(f'valid loss: {total_valid_loss:.3f}')

            self.wait()

        print(f'training complete')
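For intuition on the validation weighting above: each batch's loss is scaled by batch_size / valid_dataset_size before being summed, so total_valid_loss ends up as the average loss over the whole validation set, with a smaller final batch weighted correctly (the validation dataloader does not drop its last batch). A quick sanity check with made-up numbers, not part of this commit:

valid_dataset_size = 10
batch_losses = [0.9, 0.7, 0.5]  # loss reported for each validation batch
batch_sizes  = [4, 4, 2]        # final batch is smaller

total_valid_loss = 0.
for loss, size in zip(batch_losses, batch_sizes):
    total_valid_loss += loss * (size / valid_dataset_size)

# (0.9 * 4 + 0.7 * 4 + 0.5 * 2) / 10 = 0.74
print(f'valid loss: {total_valid_loss:.3f}')  # valid loss: 0.740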
11 changes: 8 additions & 3 deletions tests/test_trainer.py
@@ -13,17 +13,19 @@

# mock dataset

class AtomDataset(Dataset):
class MockAtomDataset(Dataset):
    def __init__(
        self,
        data_length,
        seq_len = 16,
        atoms_per_window = 27
    ):
        self.data_length = data_length
        self.seq_len = seq_len
        self.atom_seq_len = seq_len * atoms_per_window

    def __len__(self):
        return 100
        return self.data_length

    def __getitem__(self, idx):
        seq_len = self.seq_len
@@ -93,14 +95,17 @@ def test_trainer():
        ),
    )

    dataset = AtomDataset()
    dataset = MockAtomDataset(100)
    valid_dataset = MockAtomDataset(2)

    trainer = Trainer(
        alphafold3,
        dataset = dataset,
        valid_dataset = valid_dataset,
        accelerator = 'cpu',
        num_train_steps = 2,
        batch_size = 1,
        valid_every = 1,
        grad_accum_every = 2
    )

