
Commit

update train script
davidfitzek committed Aug 31, 2023
1 parent bd434dd commit 1268ca8
Showing 1 changed file with 47 additions and 21 deletions.
train.py (68 changes: 47 additions & 21 deletions)
@@ -36,14 +36,14 @@
 torch.set_float32_matmul_precision("medium")
 
 
-def main(config_path: str, config_name: str, dataset_path: str):
-    yaml_dict = load_yaml_file(config_path, config_name)
-    config = create_config_from_yaml(yaml_dict)
+def setup_environment(config):
     torch.manual_seed(config.seed)
     np.random.seed(config.seed)
     device = "cuda" if torch.cuda.is_available() else "cpu"
     config.device = device
 
+
+def load_data(config, dataset_path):
     # https://lightning.ai/docs/pytorch/stable/data/datamodule.html
     # train_loader, val_loader = get_chunked_dataloader(
     # train_loader, val_loader = get_rydberg_dataloader(
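
The seeding-and-device logic factored out into setup_environment above can be exercised on its own. A minimal sketch, assuming a SimpleNamespace stand-in for the repo's YAML-backed config object:

# Minimal sketch of the setup_environment pattern; SimpleNamespace is an
# assumption standing in for the object built by create_config_from_yaml.
from types import SimpleNamespace

import numpy as np
import torch

config = SimpleNamespace(seed=42)
torch.manual_seed(config.seed)  # seeds the global torch RNG (CPU and CUDA)
np.random.seed(config.seed)     # seeds NumPy's global RNG
config.device = "cuda" if torch.cuda.is_available() else "cpu"
print(config.device)            # "cuda" on a GPU machine, "cpu" otherwise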
@@ -53,28 +53,22 @@ def main(config_path: str, config_name: str, dataset_path: str):
         num_workers=config.num_workers,
         data_path=dataset_path,
     )
-    input_array = set_example_input_array(train_loader)
+    return train_loader, val_loader
 
-    model = get_rydberg_graph_encoder_decoder(config)
 
-    # Compile model
+def create_model(config):
+    model = get_rydberg_graph_encoder_decoder(config)
     if config.compile:
-        # check that device is cuda
-        if device != "cuda":
+        if config.device != "cuda":
             raise ValueError(
                 "Cannot compile model if device is not cuda. "
                 "Please set compile to False."
             )
         model = torch.compile(model)
+    return model
 
-    # Setup tensorboard logger
-    logger = TensorBoardLogger(save_dir="logs")
-    log_path = f"logs/lightning_logs/version_{logger.version}"
-    rydberg_gpt_trainer = RydbergGPTTrainer(
-        model, config, logger=logger  # , example_input_array=input_array
-    )
 
-    # Callbacks
+def setup_callbacks(config, log_path):
     callbacks = [
         ModelCheckpoint(
             monitor="train_loss",
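
The compile guard that create_model adds can be reproduced in isolation. A sketch assuming a toy nn.Linear in place of get_rydberg_graph_encoder_decoder:

# Standalone sketch of the compile guard; nn.Linear is an assumption
# standing in for the Rydberg graph encoder-decoder.
import torch
import torch.nn as nn


def build_model(compile_model: bool, device: str) -> nn.Module:
    model = nn.Linear(4, 4)
    if compile_model:
        if device != "cuda":
            raise ValueError(
                "Cannot compile model if device is not cuda. "
                "Please set compile to False."
            )
        model = torch.compile(model)  # PyTorch >= 2.0 graph compilation
    return model


model = build_model(compile_model=False, device="cpu")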
@@ -86,9 +80,11 @@ def main(config_path: str, config_name: str, dataset_path: str):
         StochasticWeightAveraging(config.learning_rate),
         ModelInfoCallback(),
         LearningRateMonitor(logging_interval="step"),
-        # StopOnLossThreshold(loss_threshold=150.0),
     ]
+    return callbacks
 
+
+def setup_profiler(config, log_path):
     # Monitoring
     if config.advanced_monitoring:
         # https://lightning.ai/docs/pytorch/stable/common/trainer.html
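
The Lightning-provided portion of the callback list that setup_callbacks returns can be assembled as below. ModelInfoCallback is repo-specific and omitted here, and dirpath/save_top_k are assumptions, since the hunk hides ModelCheckpoint's remaining arguments:

# Sketch of the standard Lightning callbacks; dirpath and save_top_k are
# assumed arguments, and the repo's ModelInfoCallback is left out.
from pytorch_lightning.callbacks import (
    LearningRateMonitor,
    ModelCheckpoint,
    StochasticWeightAveraging,
)


def make_callbacks(learning_rate: float, log_path: str) -> list:
    return [
        ModelCheckpoint(
            monitor="train_loss",  # track the training loss
            dirpath=log_path,      # assumption: write checkpoints to log_path
            save_top_k=3,          # assumption: keep the three best checkpoints
        ),
        StochasticWeightAveraging(swa_lrs=learning_rate),
        LearningRateMonitor(logging_interval="step"),
    ]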
@@ -116,13 +112,43 @@ def main(config_path: str, config_name: str, dataset_path: str):
         dirpath=log_path,
         filename="performance_logs",
     )
+    return profiler
 
-    # Distributed training
-    if config.strategy == "ddp":
-        strategy = DDPStrategy(find_unused_parameters=True)
-    else:
-        strategy = config.strategy
 
+def main(config_path: str, config_name: str, dataset_path: str):
+    yaml_dict = load_yaml_file(config_path, config_name)
+    config = create_config_from_yaml(yaml_dict)
+
+    # Setup Environment
+    setup_environment(config)
+
+    # Load data
+    train_loader, val_loader = load_data(config, dataset_path)
+    input_array = set_example_input_array(train_loader)
+
+    # Create Model
+    model = create_model(config)
+
+    # Setup tensorboard logger
+    logger = TensorBoardLogger(save_dir="logs")
+    log_path = f"logs/lightning_logs/version_{logger.version}"
+
+    rydberg_gpt_trainer = RydbergGPTTrainer(
+        model, config, logger=logger  # , example_input_array=input_array
+    )
+
+    # Callbacks
+    callbacks = setup_callbacks(config, log_path)
+
+    # Profiler
+    profiler = setup_profiler(config, log_path)
+
+    # Distributed training
+    strategy = (
+        DDPStrategy(find_unused_parameters=True)
+        if config.strategy == "ddp"
+        else config.strategy
+    )
     # Init trainer class
     trainer = pl.Trainer(
         devices=-1,
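
The strategy selection and the truncated pl.Trainer(...) call plausibly combine as sketched below. Every Trainer argument after devices and strategy is an assumption, since the rest of the call sits outside this diff:

# Hedged sketch of the Trainer assembly; logger, callbacks, profiler and
# max_epochs are assumptions mirroring the objects built in main above.
import pytorch_lightning as pl
from pytorch_lightning.strategies import DDPStrategy


def build_trainer(config, logger, callbacks, profiler) -> pl.Trainer:
    strategy = (
        DDPStrategy(find_unused_parameters=True)
        if config.strategy == "ddp"
        else config.strategy
    )
    return pl.Trainer(
        devices=-1,  # use every visible accelerator
        strategy=strategy,
        logger=logger,
        callbacks=callbacks,
        profiler=profiler,
        max_epochs=config.max_epochs,  # assumption: field not shown in diff
    )


# Hypothetical usage (objects and argument values are placeholders):
# trainer = build_trainer(config, logger, callbacks, profiler)
# trainer.fit(rydberg_gpt_trainer, train_loader, val_loader)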
