
Commit
remove logging terminology from wandb client
luciaquirke committed Apr 1, 2023
1 parent 0d41139 commit 5f63cca
Showing 8 changed files with 58 additions and 80 deletions.
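The rename is mechanical: plain status messages that previously went through the wrapper's progress() method now call the standard-library logger directly, and the wrapper itself becomes WandbClient, keeping only the wandb-specific responsibilities (metrics, summaries, artifact uploads). A minimal sketch of the before/after call pattern, using names from the diffs below; the config and loss values are placeholders:

import logging

# Before: one object carried both status messages and wandb calls.
#   logger = WandbLogger.create(config=cfg.serialize(), project=project, job_type=job_type)
#   logger.progress("Starting training")      # status message
#   logger.log_metric({"loss": loss})         # wandb metric
#
# After: status messages go to the stdlib logger; the renamed client keeps the wandb duties.
#   wandb_client = WandbClient.create(config=cfg.serialize(), project=project, job_type=job_type)
#   logging.info("Starting training")
#   wandb_client.log_metric({"loss": loss})

logging.basicConfig(level=logging.INFO)
logging.info("Starting training")  # stdlib logging now handles status output
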
37 changes: 19 additions & 18 deletions maze_transformer/training/training.py
@@ -1,3 +1,4 @@
import logging
from datetime import datetime
from pathlib import Path
from typing import Callable
@@ -12,7 +13,7 @@
from maze_transformer.training.config import ConfigHolder, TrainConfig
from maze_transformer.training.dataset import GPTDatasetConfig
from maze_transformer.training.mazedataset import MazeDataset
from maze_transformer.training.wandb_logger import WandbLogger
from maze_transformer.training.wandb_client import WandbClient


@freeze
@@ -34,15 +35,15 @@ class TRAIN_SAVE_FILES:


def get_dataloader(
dataset: MazeDataset, cfg: ConfigHolder, logger: WandbLogger
dataset: MazeDataset, cfg: ConfigHolder
) -> DataLoader:
length_stats: StatCounter = StatCounter(dataset.get_all_lengths())
logger.progress(
logging.info(
{"dataset_seq_len_stats": length_stats.serialize(typecast=lambda x: str(x))}
)

logger.progress(f"Loaded {len(dataset)} sequences")
logger.progress("Creating dataloader")
logging.info(f"Loaded {len(dataset)} sequences")
logging.info("Creating dataloader")
dataloader: DataLoader = DataLoader(
dataset,
batch_size=cfg.train_cfg.batch_size,
@@ -55,25 +56,25 @@ def get_dataloader(
def train(
dataloader: DataLoader,
cfg: ConfigHolder,
logger: WandbLogger,
wandb_client: WandbClient,
output_dir: Path,
device: torch.device,
) -> None:
logger.progress("Initializing model")
logging.info("Initializing model")
model: HookedTransformer = cfg.create_model()
logger.summary({"device": str(device), "model.device": model.cfg.device})
wandb_client.summary({"device": str(device), "model.device": model.cfg.device})

logger.progress("Initializing optimizer")
logging.info("Initializing optimizer")
optimizer: torch.optim.Optimizer = cfg.train_cfg.optimizer(
model.parameters(),
**cfg.train_cfg.optimizer_kwargs,
)
logger.summary(dict(model_n_params=model.cfg.n_params))
wandb_client.summary(dict(model_n_params=model.cfg.n_params))

model.train()
logger.progress("Starting training")
logging.info("Starting training")
n_batches: int = len(dataloader)
logger.summary({"n_batches": n_batches})
wandb_client.summary({"n_batches": n_batches})

checkpoint_interval_iters: int = int(
cfg.train_cfg.checkpoint_interval // cfg.train_cfg.batch_size
@@ -90,7 +91,7 @@ def train(
optimizer.step()
optimizer.zero_grad()

logger.log_metric({"loss": loss})
wandb_client.log_metric({"loss": loss})

del loss

@@ -100,17 +101,17 @@ def train(
/ TRAIN_SAVE_FILES.checkpoints
/ TRAIN_SAVE_FILES.model_checkpt(iteration)
)
logger.progress(f"Saving model to {model_save_path.as_posix()}")
logging.info(f"Saving model to {model_save_path.as_posix()}")
torch.save(model.state_dict(), model_save_path)
logger.upload_model(
wandb_client.upload_model(
model_save_path, aliases=["latest", f"iter-{iteration}"]
)

# save the final model
# ==================================================
final_model_path: Path = output_dir / TRAIN_SAVE_FILES.model_final
logger.progress(f"Saving final model to {final_model_path.as_posix()}")
logging.info(f"Saving final model to {final_model_path.as_posix()}")
torch.save(model.state_dict(), final_model_path)
logger.upload_model(final_model_path, aliases=["latest", "final"])
wandb_client.upload_model(final_model_path, aliases=["latest", "final"])

logger.progress("Done!")
logging.info("Done!")
20 changes: 4 additions & 16 deletions maze_transformer/training/wandb_logger.py → maze_transformer/training/wandb_client.py
@@ -1,7 +1,6 @@
from __future__ import annotations

import logging
import sys
from enum import Enum
from pathlib import Path
from typing import Any, Dict, Union
@@ -20,30 +19,23 @@ class WandbJobType(Enum):
TRAIN_MODEL = "train-model"


class WandbLogger:
class WandbClient:
def __init__(self, run: Run):
self._run: Run = run

@classmethod
def create(
cls, config: Dict, project: Union[WandbProject, str], job_type: WandbJobType
) -> WandbLogger:
logging.basicConfig(
stream=sys.stdout,
level=logging.INFO,
format="%(asctime)s %(levelname)s %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
) -> WandbClient:
logging.info(f"{config =}")

run = wandb.init(
config=config,
project=(project.value if isinstance(project, WandbProject) else project),
job_type=job_type.value,
)

logger = WandbLogger(run)
logger.progress(f"{config =}")
return logger
return WandbClient(run)

def upload_model(self, model_path: Path, aliases=None) -> None:
artifact = wandb.Artifact(name=wandb.run.id, type="model")
@@ -60,7 +52,3 @@ def log_metric(self, data: Dict[str, Any]) -> None:

def summary(self, data: Dict[str, Any]) -> None:
self._run.summary.update(data)

@staticmethod
def progress(message: str) -> None:
logging.info(message)
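
For reference, a short usage sketch of the client as it reads after this diff, assuming the package is installed and wandb is configured; the config dict, batch count, and model path are placeholders:

import logging
from pathlib import Path

from maze_transformer.training.wandb_client import WandbClient, WandbJobType, WandbProject

wandb_client = WandbClient.create(
    config={"lr": 1e-3},                      # placeholder config
    project=WandbProject.UNDERSTANDING_SEARCH,
    job_type=WandbJobType.TRAIN_MODEL,
)

wandb_client.summary({"n_batches": 100})      # run-level summary values
wandb_client.log_metric({"loss": 0.5})        # per-step metrics
wandb_client.upload_model(Path("model.final.pt"), aliases=["latest", "final"])

logging.info("Done!")                         # status messages stay on the stdlib logger
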
9 changes: 9 additions & 0 deletions maze_transformer/utils/utils.py
@@ -1,11 +1,20 @@
import logging
import os
import random
import sys

import numpy as np
import torch

DEFAULT_SEED = 42

logging.basicConfig(
stream=sys.stdout,
level=logging.INFO,
format="%(asctime)s %(levelname)s %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)


def get_device():
"""Get the torch.device instance on which torch.Tensors should be allocated."""
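The logging configuration that previously ran inside WandbLogger.create now runs once at import time in utils.py, so any module that imports from maze_transformer.utils.utils (for example the training script, which imports get_device from it) gets timestamped INFO output on stdout with no extra setup. A rough illustration of the intended effect, assuming the package is importable:

import logging

from maze_transformer.utils.utils import get_device  # importing the module runs its basicConfig

device = get_device()
logging.info(f"training on {device}")
# expected stdout format: 2023-04-01 12:00:00 INFO training on cuda
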
45 changes: 12 additions & 33 deletions notebooks/plot_attention.ipynb
@@ -26,10 +26,10 @@
"\n",
"# Our Code\n",
"from maze_transformer.utils.notebook_utils import configure_notebook\n",
"from maze_transformer.generation.latticemaze import LatticeMaze\n",
"from maze_transformer.generation.latticemaze import LatticeMaze, SolvedMaze\n",
"from maze_transformer.generation.generators import LatticeMazeGenerators\n",
"from maze_transformer.training.tokenizer import MazeTokenizer, SPECIAL_TOKENS, HuggingMazeTokenizer\n",
"from maze_transformer.evaluation.plot_maze import plot_multi_paths, PathFormat\n",
"from maze_transformer.training.tokenizer import maze_to_tokens, SPECIAL_TOKENS, HuggingMazeTokenizer\n",
"from maze_transformer.evaluation.plot_maze import MazePlot, PathFormat\n",
"from maze_transformer.evaluation.eval_model import decode_maze_tokens_to_coords, load_model_with_configs"
]
},
@@ -54,7 +54,7 @@
"# this should point towards a directory containing a run. \n",
"# If you don't have any runs, you can create a dataset with `poetry run python scripts/create_dataset.py create ./data/maze 10 --grid_n=4`\n",
"# Then train a model with poetry run python scripts/train_model.py ./data/maze/g4-n10`\n",
"run_path = Path(\"../data/maze/g6-n5M\")\n",
"run_path = Path(\"../data/maze/g4-n10\")\n",
"assert run_path.exists(), f\"Run path {run_path.as_posix()} does not exist\"\n",
"model_path = list(sorted(run_path.glob(\"**/model.final.pt\"), key=os.path.getmtime))[\n",
"\t-1\n",
@@ -86,13 +86,10 @@
"\tc_end = c_end,\n",
"))\n",
"\n",
"solved_maze: MazeTokenizer = MazeTokenizer(\n",
"\tmaze=maze,\n",
"\tsolution=path_true,\n",
")\n",
"\n",
"# tokenize the maze\n",
"maze_only_tokens: list[str] = solved_maze.as_tokens(cfg.dataset_cfg.node_token_map , solution = False) + [ SPECIAL_TOKENS[\"path_start\"] ]\n",
"tokens = maze_to_tokens(maze, path_true, cfg.dataset_cfg.node_token_map)\n",
"path_start_index = tokens.index(SPECIAL_TOKENS[\"path_start\"])\n",
"maze_only_tokens = tokens[:path_start_index + 1]\n",
"\n",
"print(\"maze tokens:\", maze_only_tokens)\n",
"\n",
@@ -197,33 +194,15 @@
"\n",
"# plot the maze and both solutions\n",
"# for label, fmt, color, path in paths\n",
"plot_multi_paths(\n",
"\tmaze = maze,\n",
"\tpaths = [\n",
"\t\tPathFormat(path_true, \"true\", \"-\", \"red\", {'width': 0.015}),\n",
"\t\tPathFormat(np.array(path_predicted), \"predicted\", \":\", \"blue\", {}),\n",
"\t],\n",
")"
"MazePlot(maze).add_true_path(path_true).add_predicted_path(path_predicted).show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "mml_cw_new",
"display_name": "maze-transformer",
"language": "python",
"name": "python3"
"name": "maze-transformer"
},
"language_info": {
"codemirror_mode": {
@@ -235,7 +214,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
"version": "3.10.0"
},
"orig_nbformat": 4,
"vscode": {
Expand All @@ -246,4 +225,4 @@
},
"nbformat": 4,
"nbformat_minor": 2
}
}
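
In the updated notebook cell, the prompt is no longer built through a MazeTokenizer object; the full token sequence from maze_to_tokens is cut just after the path-start token so the model is asked to continue with the solution path. A rough sketch of that slicing step; the token strings below are illustrative stand-ins, not the repository's actual token vocabulary:

# stand-in for the package's SPECIAL_TOKENS mapping
SPECIAL_TOKENS = {"path_start": "<PATH_START>"}

# pretend output of maze_to_tokens(maze, path_true, cfg.dataset_cfg.node_token_map)
tokens = [
    "<ADJLIST_START>", "(0,0)", "<-->", "(0,1)", "<ADJLIST_END>",
    "<ORIGIN_START>", "(0,0)", "<ORIGIN_END>",
    "<TARGET_START>", "(0,1)", "<TARGET_END>",
    "<PATH_START>", "(0,0)", "(0,1)", "<PATH_END>",
]

path_start_index = tokens.index(SPECIAL_TOKENS["path_start"])
maze_only_tokens = tokens[: path_start_index + 1]  # keep everything up to and including <PATH_START>
print(maze_only_tokens[-1])  # -> <PATH_START>; the model predicts the path from here
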
15 changes: 8 additions & 7 deletions scripts/train_model.py
@@ -1,4 +1,5 @@
import json
import logging
from pathlib import Path
from typing import Union

@@ -8,9 +9,9 @@
from maze_transformer.training.config import GPT_CONFIGS, TRAINING_CONFIGS, ConfigHolder
from maze_transformer.training.mazedataset import MazeDataset
from maze_transformer.training.training import TRAIN_SAVE_FILES, get_dataloader, train
from maze_transformer.training.wandb_logger import (
from maze_transformer.training.wandb_client import (
WandbJobType,
WandbLogger,
WandbClient,
WandbProject,
)
from maze_transformer.utils.utils import get_device
@@ -44,15 +45,15 @@ def train_model(
output_path: Path = Path(basepath) / output_dir_name
(output_path / TRAIN_SAVE_FILES.checkpoints).mkdir(parents=True)

logger = WandbLogger.create(
wandb_client = WandbClient.create(
config=cfg.serialize(),
project=wandb_project,
job_type=WandbJobType.TRAIN_MODEL,
)

logger.progress("Loaded data config, initialized logger")
logging.info("Loaded data config, initialized wandb_client")

logger.summary(
wandb_client.summary(
dict(
logger_cfg={
"output_dir": str(output_path),
@@ -63,10 +64,10 @@
)
)

dataloader = get_dataloader(dataset, cfg, logger)
dataloader = get_dataloader(dataset, cfg, wandb_client)
device = get_device()

train(dataloader, cfg, logger, output_path, device)
train(dataloader, cfg, wandb_client, output_path, device)


if __name__ == "__main__":
8 changes: 4 additions & 4 deletions scripts/upload_dataset.py
@@ -1,19 +1,19 @@
from pathlib import Path

from maze_transformer.training.wandb_logger import (
from maze_transformer.training.wandb_client import (
WandbJobType,
WandbLogger,
WandbClient,
WandbProject,
)


def upload_dataset(name: str, path: Path):
logger = WandbLogger.create(
wandb_client = WandbClient.create(
config={},
project=WandbProject.UNDERSTANDING_SEARCH,
job_type=WandbJobType.CREATE_DATASET,
)
logger.upload_dataset(name, path)
wandb_client.upload_dataset(name, path)


if __name__ == "__main__":
2 changes: 1 addition & 1 deletion tests/integration/test_eval_model.py
@@ -11,7 +11,7 @@
import torch

from maze_transformer.evaluation.eval_model import load_model_with_configs
from maze_transformer.training.wandb_logger import WandbProject
from maze_transformer.training.wandb_client import WandbProject
from scripts.create_dataset import create_dataset
from scripts.train_model import train_model

2 changes: 1 addition & 1 deletion tests/integration/test_train_model.py
@@ -1,6 +1,6 @@
import pytest

from maze_transformer.training.wandb_logger import WandbProject
from maze_transformer.training.wandb_client import WandbProject
from scripts.create_dataset import create_dataset
from scripts.train_model import train_model
