train_mlps.py
from pathlib import Path

import torch
from torch.utils.data import DataLoader

from data.mlpdset import MLPDataset
from models.decoder import Decoder
from models.encoder import Encoder
from trainers.mlps import MLPTrainer
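
# Train on a single GPU; assumes a CUDA-capable device is available.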
device = torch.device("cuda")
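
# Embedding size, preprocessed MLP shape, batch size, and number of training epochs.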
emb_size = 4096
prep_shape = (8, 10_000)
mlp_batch_size = 4
num_epochs = 1_000_000
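
# Logging directory for the trainer, created if it does not already exist.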
logdir = "/path/to/log/dir"
Path(logdir).mkdir(parents=True, exist_ok=True)
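
# SDF dataset root, the index range used for training, and the number of
# coordinates sampled per item.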
dataset_root = Path("/path/to/sdf/dataset")
train_range = (0, 1000)
num_coords = 50_000
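
# Dataset of preprocessed MLPs and a shuffled dataloader over it.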
mlps_dataset = MLPDataset(dataset_root, train_range, prep_shape, num_coords)
mlps_dataloader = DataLoader(
mlps_dataset,
batch_size=mlp_batch_size,
shuffle=True,
num_workers=8,
pin_memory=True,
)
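
# Build the encoder/decoder pair (architecture prediction disabled) and run training.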
encoder = Encoder(emb_size).to(device)
decoder = Decoder([], emb_size, prep_shape, arch_prediction=False).to(device)
trainer = MLPTrainer(device, logdir)
trainer.train(encoder, decoder, mlps_dataloader, num_epochs)