diff --git a/maze_transformer/training/training.py b/maze_transformer/training/training.py
index ffa0609d..c6524610 100644
--- a/maze_transformer/training/training.py
+++ b/maze_transformer/training/training.py
@@ -1,3 +1,4 @@
+import logging
 from datetime import datetime
 from pathlib import Path
 from typing import Callable
@@ -12,7 +13,7 @@
 from maze_transformer.training.config import ConfigHolder, TrainConfig
 from maze_transformer.training.dataset import GPTDatasetConfig
 from maze_transformer.training.mazedataset import MazeDataset
-from maze_transformer.training.wandb_logger import WandbLogger
+from maze_transformer.training.wandb_client import WandbClient


 @freeze
@@ -33,17 +34,14 @@ class TRAIN_SAVE_FILES:
     model_final: str = "model.final.pt"


-def get_dataloader(
-    dataset: MazeDataset, cfg: ConfigHolder, logger: WandbLogger
-) -> DataLoader:
+def get_dataloader(dataset: MazeDataset, cfg: ConfigHolder) -> DataLoader:
     length_stats: StatCounter = StatCounter(dataset.get_all_lengths())
-    logger.summary({"dataset_seq_len_stats_summary": length_stats.summary()})
-    logger.summary(
+    logging.info(
         {"dataset_seq_len_stats": length_stats.serialize(typecast=lambda x: str(x))}
     )

-    logger.progress(f"Loaded {len(dataset)} sequences")
-    logger.progress("Creating dataloader")
+    logging.info(f"Loaded {len(dataset)} sequences")
+    logging.info("Creating dataloader")
     dataloader: DataLoader = DataLoader(
         dataset,
         batch_size=cfg.train_cfg.batch_size,
@@ -56,25 +54,25 @@ def train(
     dataloader: DataLoader,
     cfg: ConfigHolder,
-    logger: WandbLogger,
+    wandb_client: WandbClient,
     output_dir: Path,
     device: torch.device,
 ) -> None:
-    logger.progress("Initializing model")
+    logging.info("Initializing model")
     model: HookedTransformer = cfg.create_model()
-    logger.summary({"device": str(device), "model.device": model.cfg.device})
+    wandb_client.summary({"device": str(device), "model.device": model.cfg.device})

-    logger.progress("Initializing optimizer")
+    logging.info("Initializing optimizer")
     optimizer: torch.optim.Optimizer = cfg.train_cfg.optimizer(
         model.parameters(),
         **cfg.train_cfg.optimizer_kwargs,
     )
-    logger.summary(dict(model_n_params=model.cfg.n_params))
+    wandb_client.summary(dict(model_n_params=model.cfg.n_params))

     model.train()
-    logger.progress("Starting training")
+    logging.info("Starting training")

     n_batches: int = len(dataloader)
-    logger.summary({"n_batches": n_batches})
+    wandb_client.summary({"n_batches": n_batches})

     checkpoint_interval_iters: int = int(
         cfg.train_cfg.checkpoint_interval // cfg.train_cfg.batch_size
@@ -91,7 +89,7 @@ def train(
         optimizer.step()
         optimizer.zero_grad()

-        logger.log_metric({"loss": loss})
+        wandb_client.log_metric({"loss": loss})

         del loss

@@ -101,17 +99,17 @@
                 / TRAIN_SAVE_FILES.checkpoints
                 / TRAIN_SAVE_FILES.model_checkpt(iteration)
             )
-            logger.progress(f"Saving model to {model_save_path.as_posix()}")
+            logging.info(f"Saving model to {model_save_path.as_posix()}")
             torch.save(model.state_dict(), model_save_path)
-            logger.upload_model(
+            wandb_client.upload_model(
                 model_save_path, aliases=["latest", f"iter-{iteration}"]
             )

     # save the final model
     # ==================================================
     final_model_path: Path = output_dir / TRAIN_SAVE_FILES.model_final
-    logger.progress(f"Saving final model to {final_model_path.as_posix()}")
+    logging.info(f"Saving final model to {final_model_path.as_posix()}")
     torch.save(model.state_dict(), final_model_path)
-    logger.upload_model(final_model_path, aliases=["latest", "final"])
+    wandb_client.upload_model(final_model_path, aliases=["latest", "final"])

-    logger.progress("Done!")
+    logging.info("Done!")
diff --git a/maze_transformer/training/wandb_logger.py b/maze_transformer/training/wandb_client.py
similarity index 76%
rename from maze_transformer/training/wandb_logger.py
rename to maze_transformer/training/wandb_client.py
index decab8b2..633be1dc 100644
--- a/maze_transformer/training/wandb_logger.py
+++ b/maze_transformer/training/wandb_client.py
@@ -1,7 +1,6 @@
 from __future__ import annotations

 import logging
-import sys
 from enum import Enum
 from pathlib import Path
 from typing import Any, Dict, Union
@@ -20,20 +19,15 @@ class WandbJobType(Enum):
     TRAIN_MODEL = "train-model"


-class WandbLogger:
+class WandbClient:
     def __init__(self, run: Run):
         self._run: Run = run

     @classmethod
     def create(
         cls, config: Dict, project: Union[WandbProject, str], job_type: WandbJobType
-    ) -> WandbLogger:
-        logging.basicConfig(
-            stream=sys.stdout,
-            level=logging.INFO,
-            format="%(asctime)s %(levelname)s %(message)s",
-            datefmt="%Y-%m-%d %H:%M:%S",
-        )
+    ) -> WandbClient:
+        logging.info(f"{config =}")

         run = wandb.init(
             config=config,
@@ -41,9 +35,7 @@ def create(
             job_type=job_type.value,
         )

-        logger = WandbLogger(run)
-        logger.progress(f"{config =}")
-        return logger
+        return WandbClient(run)

     def upload_model(self, model_path: Path, aliases=None) -> None:
         artifact = wandb.Artifact(name=wandb.run.id, type="model")
@@ -60,7 +52,3 @@ def log_metric(self, data: Dict[str, Any]) -> None:

     def summary(self, data: Dict[str, Any]) -> None:
         self._run.summary.update(data)
-
-    @staticmethod
-    def progress(message: str) -> None:
-        logging.info(message)
diff --git a/maze_transformer/utils/utils.py b/maze_transformer/utils/utils.py
index 183095cb..ba0c7686 100644
--- a/maze_transformer/utils/utils.py
+++ b/maze_transformer/utils/utils.py
@@ -1,11 +1,20 @@
+import logging
 import os
 import random
+import sys

 import numpy as np
 import torch

 DEFAULT_SEED = 42

+logging.basicConfig(
+    stream=sys.stdout,
+    level=logging.INFO,
+    format="%(asctime)s %(levelname)s %(message)s",
+    datefmt="%Y-%m-%d %H:%M:%S",
+)
+

 def get_device():
     """Get the torch.device instance on which torch.Tensors should be allocated."""
diff --git a/notebooks/plot_attention.ipynb b/notebooks/plot_attention.ipynb
index b5088fe8..c7b25b6a 100644
--- a/notebooks/plot_attention.ipynb
+++ b/notebooks/plot_attention.ipynb
@@ -1,249 +1,228 @@
 {
- "cells": [
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {
-    "collapsed": false,
-    "pycharm": {
-     "name": "#%%\n"
-    }
-   },
-   "outputs": [],
-   "source": [
-    "# Generic\n",
-    "import html\n",
-    "import os\n",
-    "from pathlib import Path\n",
-    "\n",
-    "# Transformers\n",
-    "from circuitsvis.attention import attention_heads\n",
-    "from circuitsvis.tokens import colored_tokens_multi\n",
-    "\n",
-    "# Numerical Computing\n",
-    "import numpy as np\n",
-    "import torch\n",
-    "\n",
-    "# Our Code\n",
-    "from maze_transformer.utils.notebook_utils import configure_notebook\n",
-    "from maze_transformer.generation.latticemaze import LatticeMaze\n",
-    "from maze_transformer.generation.generators import LatticeMazeGenerators\n",
-    "from maze_transformer.training.tokenizer import MazeTokenizer, SPECIAL_TOKENS, HuggingMazeTokenizer\n",
-    "from maze_transformer.evaluation.plot_maze import plot_multi_paths, PathFormat\n",
-    "from maze_transformer.evaluation.eval_model import decode_maze_tokens_to_coords, load_model_with_configs"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
"metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } - }, - "outputs": [], - "source": [ - "# Setup\n", - "device = configure_notebook(seed=42, dark_mode=True)\n", - "# We won't be training any models\n", - "torch.set_grad_enabled(False)\n", - "\n", - "\n", - "# Get latest model\n", - "# this should point towards a directory containing a run. \n", - "# If you don't have any runs, you can create a dataset with `poetry run python scripts/create_dataset.py create ./data/maze 10 --grid_n=4`\n", - "# Then train a model with poetry run python scripts/train_model.py ./data/maze/g4-n10`\n", - "run_path = Path(\"../data/maze/g6-n5M\")\n", - "assert run_path.exists(), f\"Run path {run_path.as_posix()} does not exist\"\n", - "model_path = list(sorted(run_path.glob(\"**/model.final.pt\"), key=os.path.getmtime))[\n", - "\t-1\n", - "].resolve()\n", - "model, cfg = load_model_with_configs(model_path)\n", - "maze_path = run_path / \"maze_tokens.jsonl\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } - }, - "outputs": [], - "source": [ - "# generate a maze\n", - "grid_n: int = cfg.dataset_cfg.grid_n\n", - "maze: LatticeMaze = LatticeMazeGenerators.gen_dfs((grid_n, grid_n))\n", - "c_start = (0, 0)\n", - "c_end = (grid_n - 1, grid_n - 1)\n", - "\n", - "# solve the maze explicitly\n", - "path_true = np.array(maze.find_shortest_path(\n", - "\tc_start = c_start,\n", - "\tc_end = c_end,\n", - "))\n", - "\n", - "solved_maze: MazeTokenizer = MazeTokenizer(\n", - "\tmaze=maze,\n", - "\tsolution=path_true,\n", - ")\n", - "\n", - "# tokenize the maze\n", - "maze_only_tokens: list[str] = solved_maze.as_tokens(cfg.dataset_cfg.node_token_map , solution = False) + [ SPECIAL_TOKENS[\"path_start\"] ]\n", - "\n", - "print(\"maze tokens:\", maze_only_tokens)\n", - "\n", - "array = model.to_tokens(\" \".join(maze_only_tokens), prepend_bos=False)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } - }, - "outputs": [], - "source": [ - "# have the model predict some tokens\n", - "context_str: list[str] = maze_only_tokens\n", - "\n", - "# escape for html\n", - "context_str = [ html.escape(t) for t in context_str ]\n", - "\n", - "array_tensor = torch.tensor(array).long().to(device)\n", - "with torch.no_grad():\n", - "\tlogits, cache = model.run_with_cache(array_tensor)\n", - "\n", - "attentions = [w for k, w in cache.items() if 'hook_pattern' in k]\n", - "print(f\"{logits.shape = }\\n{len(attentions) = }\\n{[x.shape for x in attentions] = }\")\n", - "\n", - "# `output.attentions` is a tuple of tensors, where each element of the tuple corresponds to a layer. 
\n", - "# The tensor has dimensions (1, n_heads, n_positions, n_positions)\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } - }, - "outputs": [], - "source": [ - "n_layers: int = len(attentions)\n", - "n_heads: int = attentions[0].shape[1]\n", - "n_tokens: int = attentions[0].shape[2]\n", - "attention_to_plot = torch.concatenate(attentions, dim=0).reshape(-1, n_tokens, n_tokens)\n", - "attention_head_names = [f\"Layer {i} Head {j}\" for i in range(n_layers) for j in range(n_heads)]\n", - "attention_heads(attention_to_plot,maze_only_tokens, attention_head_names)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } - }, - "outputs": [], - "source": [ - "#! ALEX note - there used to be a np.power(head_np, 1/4) here, not sure what that's about?\n", - "FROM_TOKEN = -1 # Look at attention from this token position to the rest of the sequence\n", - "attentions_from_token = torch.concatenate([w[0, :, FROM_TOKEN, :] for w in attentions], dim=0)\n", - "colored_tokens_multi(context_str, attentions_from_token.T, labels=attention_head_names)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } - }, - "outputs": [], - "source": [ - "def prediction_contained_a_coordinate_token(tokens: list[str], tokenizer: HuggingMazeTokenizer) -> bool:\n", - "\t\"\"\"Check if the prediction contains a coordinate token\"\"\"\n", - "\tfor t in tokens:\n", - "\t\tif t not in list(tokenizer.special_tokens_map.values()) + tokenizer.additional_special_tokens:\n", - "\t\t\treturn True\n", - "\tprint(\"FAIL: Sampled a path - No coordinate token found before EOS\")\n", - "\treturn False\n", - "\n", - "predicted_tokens = []\n", - "while not prediction_contained_a_coordinate_token(predicted_tokens, model.tokenizer):\n", - "\tpredictions = model.generate(array_tensor, max_new_tokens=50, stop_at_eos=True, verbose=False)\n", - "\tpredicted_tokens = model.to_str_tokens(predictions)[len(maze_only_tokens):]\n", - "print(\"SUCCESS: Model predicted the path:\")\n", - "print(predicted_tokens)\n", - "\n", - "path_predicted: list[tuple[int,int]] = decode_maze_tokens_to_coords(\n", - "\tpredicted_tokens,\n", - "\tmazedata_cfg = cfg.dataset_cfg, \n", - "\twhen_noncoord = \"skip\",\n", - ")\n", - "\n", - "# plot the maze and both solutions\n", - "# for label, fmt, color, path in paths\n", - "plot_multi_paths(\n", - "\tmaze = maze,\n", - "\tpaths = [\n", - "\t\tPathFormat(path_true, \"true\", \"-\", \"red\", {'width': 0.015}),\n", - "\t\tPathFormat(np.array(path_predicted), \"predicted\", \":\", \"blue\", {}),\n", - "\t],\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } - }, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "mml_cw_new", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.9" - }, - "orig_nbformat": 4, - "vscode": { - "interpreter": { - "hash": "170637793197da0d440deb6cb249c898d613b24c548839ecbbac11596710dfc2" - } - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} + "cells": [ + { + 
"cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "# Generic\n", + "import html\n", + "import os\n", + "from pathlib import Path\n", + "\n", + "# Transformers\n", + "from circuitsvis.attention import attention_heads\n", + "from circuitsvis.tokens import colored_tokens_multi\n", + "\n", + "# Numerical Computing\n", + "import numpy as np\n", + "import torch\n", + "\n", + "# Our Code\n", + "from maze_transformer.utils.notebook_utils import configure_notebook\n", + "from maze_transformer.generation.latticemaze import LatticeMaze, SolvedMaze\n", + "from maze_transformer.generation.generators import LatticeMazeGenerators\n", + "from maze_transformer.training.tokenizer import maze_to_tokens, SPECIAL_TOKENS, HuggingMazeTokenizer\n", + "from maze_transformer.evaluation.plot_maze import MazePlot, PathFormat\n", + "from maze_transformer.evaluation.eval_model import decode_maze_tokens_to_coords, load_model_with_configs" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "# Setup\n", + "device = configure_notebook(seed=42, dark_mode=True)\n", + "# We won't be training any models\n", + "torch.set_grad_enabled(False)\n", + "\n", + "\n", + "# Get latest model\n", + "# this should point towards a directory containing a run. \n", + "# If you don't have any runs, you can create a dataset with `poetry run python scripts/create_dataset.py create ./data/maze 10 --grid_n=4`\n", + "# Then train a model with poetry run python scripts/train_model.py ./data/maze/g4-n10`\n", + "run_path = Path(\"../data/maze/g4-n10\")\n", + "assert run_path.exists(), f\"Run path {run_path.as_posix()} does not exist\"\n", + "model_path = list(sorted(run_path.glob(\"**/model.final.pt\"), key=os.path.getmtime))[\n", + "\t-1\n", + "].resolve()\n", + "model, cfg = load_model_with_configs(model_path)\n", + "maze_path = run_path / \"maze_tokens.jsonl\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "# generate a maze\n", + "grid_n: int = cfg.dataset_cfg.grid_n\n", + "maze: LatticeMaze = LatticeMazeGenerators.gen_dfs((grid_n, grid_n))\n", + "c_start = (0, 0)\n", + "c_end = (grid_n - 1, grid_n - 1)\n", + "\n", + "# solve the maze explicitly\n", + "path_true = np.array(maze.find_shortest_path(\n", + "\tc_start = c_start,\n", + "\tc_end = c_end,\n", + "))\n", + "\n", + "# tokenize the maze\n", + "tokens = maze_to_tokens(maze, path_true, cfg.dataset_cfg.node_token_map)\n", + "path_start_index = tokens.index(SPECIAL_TOKENS[\"path_start\"])\n", + "maze_only_tokens = tokens[:path_start_index + 1]\n", + "\n", + "print(\"maze tokens:\", maze_only_tokens)\n", + "\n", + "array = model.to_tokens(\" \".join(maze_only_tokens), prepend_bos=False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "# have the model predict some tokens\n", + "context_str: list[str] = maze_only_tokens\n", + "\n", + "# escape for html\n", + "context_str = [ html.escape(t) for t in context_str ]\n", + "\n", + "array_tensor = torch.tensor(array).long().to(device)\n", + "with torch.no_grad():\n", + "\tlogits, cache = model.run_with_cache(array_tensor)\n", + "\n", + "attentions = [w for k, w in 
+    "print(f\"{logits.shape = }\\n{len(attentions) = }\\n{[x.shape for x in attentions] = }\")\n",
+    "\n",
+    "# `output.attentions` is a tuple of tensors, where each element of the tuple corresponds to a layer. \n",
+    "# The tensor has dimensions (1, n_heads, n_positions, n_positions)\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "collapsed": false,
+    "pycharm": {
+     "name": "#%%\n"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "n_layers: int = len(attentions)\n",
+    "n_heads: int = attentions[0].shape[1]\n",
+    "n_tokens: int = attentions[0].shape[2]\n",
+    "attention_to_plot = torch.concatenate(attentions, dim=0).reshape(-1, n_tokens, n_tokens)\n",
+    "attention_head_names = [f\"Layer {i} Head {j}\" for i in range(n_layers) for j in range(n_heads)]\n",
+    "attention_heads(attention_to_plot,maze_only_tokens, attention_head_names)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "collapsed": false,
+    "pycharm": {
+     "name": "#%%\n"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "#! ALEX note - there used to be a np.power(head_np, 1/4) here, not sure what that's about?\n",
+    "FROM_TOKEN = -1 # Look at attention from this token position to the rest of the sequence\n",
+    "attentions_from_token = torch.concatenate([w[0, :, FROM_TOKEN, :] for w in attentions], dim=0)\n",
+    "colored_tokens_multi(context_str, attentions_from_token.T, labels=attention_head_names)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "collapsed": false,
+    "pycharm": {
+     "name": "#%%\n"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "def prediction_contained_a_coordinate_token(tokens: list[str], tokenizer: HuggingMazeTokenizer) -> bool:\n",
+    "\t\"\"\"Check if the prediction contains a coordinate token\"\"\"\n",
+    "\tfor t in tokens:\n",
+    "\t\tif t not in list(tokenizer.special_tokens_map.values()) + tokenizer.additional_special_tokens:\n",
+    "\t\t\treturn True\n",
+    "\tprint(\"FAIL: Sampled a path - No coordinate token found before EOS\")\n",
+    "\treturn False\n",
+    "\n",
+    "predicted_tokens = []\n",
+    "while not prediction_contained_a_coordinate_token(predicted_tokens, model.tokenizer):\n",
+    "\tpredictions = model.generate(array_tensor, max_new_tokens=50, stop_at_eos=True, verbose=False)\n",
+    "\tpredicted_tokens = model.to_str_tokens(predictions)[len(maze_only_tokens):]\n",
+    "print(\"SUCCESS: Model predicted the path:\")\n",
+    "print(predicted_tokens)\n",
+    "\n",
+    "path_predicted: list[tuple[int,int]] = decode_maze_tokens_to_coords(\n",
+    "\tpredicted_tokens,\n",
+    "\tmazedata_cfg = cfg.dataset_cfg, \n",
+    "\twhen_noncoord = \"skip\",\n",
+    ")\n",
+    "\n",
+    "# plot the maze and both solutions\n",
+    "# for label, fmt, color, path in paths\n",
+    "MazePlot(maze).add_true_path(path_true).add_predicted_path(path_predicted).show()"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "maze-transformer",
+   "language": "python",
+   "name": "maze-transformer"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.0"
+  },
+  "orig_nbformat": 4,
+  "vscode": {
+   "interpreter": {
+    "hash": "170637793197da0d440deb6cb249c898d613b24c548839ecbbac11596710dfc2"
+   }
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
\ No newline at end of file
diff --git a/scripts/train_model.py b/scripts/train_model.py
index 77030ec2..adeb8ce8 100644
--- a/scripts/train_model.py
+++ b/scripts/train_model.py
@@ -1,4 +1,5 @@
 import json
+import logging

 from pathlib import Path
 from typing import Union
@@ -8,9 +9,9 @@
 from maze_transformer.training.config import GPT_CONFIGS, TRAINING_CONFIGS, ConfigHolder
 from maze_transformer.training.mazedataset import MazeDataset
 from maze_transformer.training.training import TRAIN_SAVE_FILES, get_dataloader, train
-from maze_transformer.training.wandb_logger import (
+from maze_transformer.training.wandb_client import (
+    WandbClient,
     WandbJobType,
-    WandbLogger,
     WandbProject,
 )
 from maze_transformer.utils.utils import get_device
@@ -44,15 +45,15 @@ def train_model(
     output_path: Path = Path(basepath) / output_dir_name
     (output_path / TRAIN_SAVE_FILES.checkpoints).mkdir(parents=True)

-    logger = WandbLogger.create(
+    wandb_client = WandbClient.create(
         config=cfg.serialize(),
         project=wandb_project,
         job_type=WandbJobType.TRAIN_MODEL,
     )

-    logger.progress("Loaded data config, initialized logger")
+    logging.info("Loaded data config, initialized wandb_client")

-    logger.summary(
+    wandb_client.summary(
         dict(
             logger_cfg={
                 "output_dir": str(output_path),
@@ -63,10 +64,10 @@ def train_model(
         )
     )

-    dataloader = get_dataloader(dataset, cfg, logger)
+    dataloader = get_dataloader(dataset, cfg)
     device = get_device()

-    train(dataloader, cfg, logger, output_path, device)
+    train(dataloader, cfg, wandb_client, output_path, device)


 if __name__ == "__main__":
diff --git a/scripts/upload_dataset.py b/scripts/upload_dataset.py
index 3ba08b3b..22faffd9 100644
--- a/scripts/upload_dataset.py
+++ b/scripts/upload_dataset.py
@@ -1,19 +1,19 @@
 from pathlib import Path

-from maze_transformer.training.wandb_logger import (
+from maze_transformer.training.wandb_client import (
+    WandbClient,
     WandbJobType,
-    WandbLogger,
     WandbProject,
 )


 def upload_dataset(name: str, path: Path):
-    logger = WandbLogger.create(
+    wandb_client = WandbClient.create(
         config={},
         project=WandbProject.UNDERSTANDING_SEARCH,
         job_type=WandbJobType.CREATE_DATASET,
     )
-    logger.upload_dataset(name, path)
+    wandb_client.upload_dataset(name, path)


 if __name__ == "__main__":
diff --git a/tests/integration/test_eval_model.py b/tests/integration/test_eval_model.py
index f2333ed1..d3ae1195 100644
--- a/tests/integration/test_eval_model.py
+++ b/tests/integration/test_eval_model.py
@@ -11,7 +11,7 @@
 import torch

 from maze_transformer.evaluation.eval_model import load_model_with_configs
-from maze_transformer.training.wandb_logger import WandbProject
+from maze_transformer.training.wandb_client import WandbProject

 from scripts.create_dataset import create_dataset
 from scripts.train_model import train_model
diff --git a/tests/integration/test_train_model.py b/tests/integration/test_train_model.py
index 545600e8..f6e7649a 100644
--- a/tests/integration/test_train_model.py
+++ b/tests/integration/test_train_model.py
@@ -1,6 +1,6 @@
 import pytest

-from maze_transformer.training.wandb_logger import WandbProject
+from maze_transformer.training.wandb_client import WandbProject

 from scripts.create_dataset import create_dataset
 from scripts.train_model import train_model