From 0eabf50021a9f360b6e710bd64225d34abb91656 Mon Sep 17 00:00:00 2001 From: James Thewlis Date: Thu, 19 Sep 2024 14:26:10 +0100 Subject: [PATCH] Upgrade to pytorch-lightning >=2 (#114) * Upgrade to pytorch-lightning >=2 Upgrades pytorch lightning, modifying usage of pl.Trainer to conform with the new way to set the devices and checkpoint. * Update kaggle download commands in README Include the commands for unzipping the files, including the zipped csvs within the original zip from the first challenge. * Try disabling mac memory allocation limit in CI See if we can use the mac GPU for the trainer test * Use cpu for if cuda unavailable in trainer test Avoids trying to use MPS on mac in CI which has insufficient memory * Store created val.csv alongside input csv In preprocessing_utils.py save the created val.csv in the same folder as the input csv instead of in the current working directory, and add logging so the user knows where it is saved. This makes the two functions in the file more consistent in where they save their output. * Support comma-separated device string --- README.md | 4 ++++ preprocessing_utils.py | 20 +++++++++++++++++--- pyproject.toml | 6 +++--- tests/test_trainer.py | 3 ++- train.py | 19 +++++++++++++++---- 5 files changed, 41 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index 5b63999..3b99e8d 100644 --- a/README.md +++ b/README.md @@ -264,10 +264,14 @@ cd jigsaw_data # download data kaggle competitions download -c jigsaw-toxic-comment-classification-challenge +unzip jigsaw-toxic-comment-classification-challenge.zip -d jigsaw-toxic-comment-classification-challenge +find jigsaw-toxic-comment-classification-challenge -name '*.csv.zip' | xargs -n1 unzip -d jigsaw-toxic-comment-classification-challenge kaggle competitions download -c jigsaw-unintended-bias-in-toxicity-classification +unzip jigsaw-unintended-bias-in-toxicity-classification.zip -d jigsaw-unintended-bias-in-toxicity-classification kaggle competitions download -c jigsaw-multilingual-toxic-comment-classification +unzip jigsaw-multilingual-toxic-comment-classification.zip -d jigsaw-multilingual-toxic-comment-classification ``` ## Start Training diff --git a/preprocessing_utils.py b/preprocessing_utils.py index 3549744..437db69 100644 --- a/preprocessing_utils.py +++ b/preprocessing_utils.py @@ -1,18 +1,29 @@ import argparse +import logging +from pathlib import Path import numpy as np import pandas as pd +logger = logging.getLogger("preprocessing_utils") +logging.basicConfig( + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + level=logging.INFO, +) + def update_test(test_csv_file): """Combines disjointed test and labels csv files into one file.""" + test_csv_file = Path(test_csv_file) test_set = pd.read_csv(test_csv_file) - data_labels = pd.read_csv(test_csv_file[:-4] + "_labels.csv") + data_labels = pd.read_csv(str(test_csv_file)[:-4] + "_labels.csv") for category in data_labels.columns[1:]: test_set[category] = data_labels[category] if "content" in test_set.columns: test_set.rename(columns={"content": "comment_text"}, inplace=True) - test_set.to_csv(f"{test_csv_file.split('.csv')[0]}_updated.csv") + output_file = test_csv_file.parent / f"{test_csv_file.stem}_updated.csv" + test_set.to_csv(output_file) + logger.info("Updated test set saved to %s", output_file) return test_set @@ -20,12 +31,15 @@ def create_val_set(csv_file, val_fraction): """Takes in a csv file path and creates a validation set out of it specified by val_fraction. """ + csv_file = Path(csv_file) dataset = pd.read_csv(csv_file) np.random.seed(0) dataset_mod = dataset[dataset.toxic != -1] indices = np.random.rand(len(dataset_mod)) > val_fraction val_set = dataset_mod[~indices] - val_set.to_csv("val.csv") + output_file = csv_file.parent / "val.csv" + logger.info("Validation set saved to %s", output_file) + val_set.to_csv(output_file) if __name__ == "__main__": diff --git a/pyproject.toml b/pyproject.toml index 4887b69..eed9a18 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,10 +11,10 @@ classifiers = [ "Operating System :: OS Independent", "Programming Language :: Python :: 3", ] -requires-python = ">=3.9,<3.12" +requires-python = ">=3.9,<3.13" dependencies = [ "sentencepiece >= 0.1.94", - "torch < 2.2", + "torch >=2", "transformers >= 3", ] @@ -29,7 +29,7 @@ dev = [ "datasets >= 1.0.2", "pandas >= 1.1.2", "pytest", - "pytorch-lightning<2.0.0,>1.5.0", + "pytorch-lightning>2", "scikit-learn >= 0.23.2", "tqdm", "pre-commit", diff --git a/tests/test_trainer.py b/tests/test_trainer.py index 1608de7..e12a824 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -1,6 +1,7 @@ import json import src.data_loaders as module_data + import torch from pytorch_lightning import seed_everything, Trainer from torch.utils.data import DataLoader @@ -37,7 +38,7 @@ def get_instance(module, name, config, *args, **kwargs): ) trainer = Trainer( - gpus=0 if torch.cuda.is_available() else None, + accelerator="gpu" if torch.cuda.is_available() else "cpu", limit_train_batches=2, limit_val_batches=2, max_epochs=1, diff --git a/train.py b/train.py index ac75615..674f96a 100644 --- a/train.py +++ b/train.py @@ -3,6 +3,7 @@ import os import pytorch_lightning as pl + import src.data_loaders as module_data import torch from pytorch_lightning.callbacks import ModelCheckpoint @@ -159,7 +160,7 @@ def cli_main(): "--device", default=None, type=str, - help="indices of GPUs to enable (default: None)", + help="comma-separated indices of GPUs to enable (default: None)", ) parser.add_argument( "--num_workers", @@ -208,16 +209,26 @@ def get_instance(module, name, config, *args, **kwargs): monitor="val_loss", mode="min", ) + + if args.device is None: + devices = "auto" + else: + devices = [int(d.strip()) for d in args.device.split(",")] + trainer = pl.Trainer( - gpus=args.device, + devices=devices, max_epochs=args.n_epochs, accumulate_grad_batches=config["accumulate_grad_batches"], callbacks=[checkpoint_callback], - resume_from_checkpoint=args.resume, default_root_dir="saved/" + config["name"], deterministic=True, ) - trainer.fit(model, data_loader, valid_data_loader) + trainer.fit( + model=model, + train_dataloaders=data_loader, + val_dataloaders=valid_data_loader, + ckpt_path=args.resume, + ) if __name__ == "__main__":