remove logging terminology from wandb client
luciaquirke committed Apr 1, 2023
1 parent 0d41139 commit cd02666
Showing 8 changed files with 273 additions and 295 deletions.
37 changes: 19 additions & 18 deletions maze_transformer/training/training.py
@@ -1,3 +1,4 @@
import logging
from datetime import datetime
from pathlib import Path
from typing import Callable
@@ -12,7 +13,7 @@
from maze_transformer.training.config import ConfigHolder, TrainConfig
from maze_transformer.training.dataset import GPTDatasetConfig
from maze_transformer.training.mazedataset import MazeDataset
from maze_transformer.training.wandb_logger import WandbLogger
from maze_transformer.training.wandb_client import WandbClient


@freeze
@@ -34,15 +35,15 @@ class TRAIN_SAVE_FILES:


def get_dataloader(
dataset: MazeDataset, cfg: ConfigHolder, logger: WandbLogger
dataset: MazeDataset, cfg: ConfigHolder
) -> DataLoader:
length_stats: StatCounter = StatCounter(dataset.get_all_lengths())
logger.progress(
logging.info(
{"dataset_seq_len_stats": length_stats.serialize(typecast=lambda x: str(x))}
)

logger.progress(f"Loaded {len(dataset)} sequences")
logger.progress("Creating dataloader")
logging.info(f"Loaded {len(dataset)} sequences")
logging.info("Creating dataloader")
dataloader: DataLoader = DataLoader(
dataset,
batch_size=cfg.train_cfg.batch_size,
@@ -55,25 +56,25 @@ def get_dataloader(
def train(
dataloader: DataLoader,
cfg: ConfigHolder,
logger: WandbLogger,
wandb_client: WandbClient,
output_dir: Path,
device: torch.device,
) -> None:
logger.progress("Initializing model")
logging.info("Initializing model")
model: HookedTransformer = cfg.create_model()
logger.summary({"device": str(device), "model.device": model.cfg.device})
wandb_client.summary({"device": str(device), "model.device": model.cfg.device})

logger.progress("Initializing optimizer")
logging.info("Initializing optimizer")
optimizer: torch.optim.Optimizer = cfg.train_cfg.optimizer(
model.parameters(),
**cfg.train_cfg.optimizer_kwargs,
)
logger.summary(dict(model_n_params=model.cfg.n_params))
wandb_client.summary(dict(model_n_params=model.cfg.n_params))

model.train()
logger.progress("Starting training")
logging.info("Starting training")
n_batches: int = len(dataloader)
logger.summary({"n_batches": n_batches})
wandb_client.summary({"n_batches": n_batches})

checkpoint_interval_iters: int = int(
cfg.train_cfg.checkpoint_interval // cfg.train_cfg.batch_size
@@ -90,7 +91,7 @@ def train(
optimizer.step()
optimizer.zero_grad()

logger.log_metric({"loss": loss})
wandb_client.log_metric({"loss": loss})

del loss

@@ -100,17 +101,17 @@ def train(
/ TRAIN_SAVE_FILES.checkpoints
/ TRAIN_SAVE_FILES.model_checkpt(iteration)
)
logger.progress(f"Saving model to {model_save_path.as_posix()}")
logging.info(f"Saving model to {model_save_path.as_posix()}")
torch.save(model.state_dict(), model_save_path)
logger.upload_model(
wandb_client.upload_model(
model_save_path, aliases=["latest", f"iter-{iteration}"]
)

# save the final model
# ==================================================
final_model_path: Path = output_dir / TRAIN_SAVE_FILES.model_final
logger.progress(f"Saving final model to {final_model_path.as_posix()}")
logging.info(f"Saving final model to {final_model_path.as_posix()}")
torch.save(model.state_dict(), final_model_path)
logger.upload_model(final_model_path, aliases=["latest", "final"])
wandb_client.upload_model(final_model_path, aliases=["latest", "final"])

logger.progress("Done!")
logging.info("Done!")
maze_transformer/training/wandb_logger.py → maze_transformer/training/wandb_client.py
@@ -1,7 +1,6 @@
from __future__ import annotations

import logging
import sys
from enum import Enum
from pathlib import Path
from typing import Any, Dict, Union
@@ -20,30 +19,23 @@ class WandbJobType(Enum):
TRAIN_MODEL = "train-model"


class WandbLogger:
class WandbClient:
def __init__(self, run: Run):
self._run: Run = run

@classmethod
def create(
cls, config: Dict, project: Union[WandbProject, str], job_type: WandbJobType
) -> WandbLogger:
logging.basicConfig(
stream=sys.stdout,
level=logging.INFO,
format="%(asctime)s %(levelname)s %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
) -> WandbClient:
logging.info(f"{config =}")

run = wandb.init(
config=config,
project=(project.value if isinstance(project, WandbProject) else project),
job_type=job_type.value,
)

logger = WandbLogger(run)
logger.progress(f"{config =}")
return logger
return WandbClient(run)

def upload_model(self, model_path: Path, aliases=None) -> None:
artifact = wandb.Artifact(name=wandb.run.id, type="model")
@@ -60,7 +52,3 @@ def log_metric(self, data: Dict[str, Any]) -> None:

def summary(self, data: Dict[str, Any]) -> None:
self._run.summary.update(data)

@staticmethod
def progress(message: str) -> None:
logging.info(message)
9 changes: 9 additions & 0 deletions maze_transformer/utils/utils.py
@@ -1,11 +1,20 @@
import logging
import os
import random
import sys

import numpy as np
import torch

DEFAULT_SEED = 42

logging.basicConfig(
stream=sys.stdout,
level=logging.INFO,
format="%(asctime)s %(levelname)s %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)


def get_device():
"""Get the torch.device instance on which torch.Tensors should be allocated."""