-
Notifications
You must be signed in to change notification settings - Fork 23
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
23 changed files
with
965 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,3 +2,4 @@ | |
_build | ||
.idea | ||
**/__pycache__ | ||
/docs/examples/**/*.diff |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,8 @@ | ||
.. *************************** | ||
.. **************** | ||
.. Minimal Examples | ||
.. *************************** | ||
.. **************** | ||
.. include:: examples/frameworks/README.rst | ||
.. include:: examples/distributed/README.rst | ||
.. include:: examples/data/README.rst |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
***************************** | ||
Data Handling during Training | ||
***************************** | ||
|
||
|
||
.. include:: examples/data/hf/README.rst | ||
.. include:: examples/data/torchvision/README.rst |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
HuggingFace Dataset | ||
=================== | ||
|
||
|
||
**Prerequisites** | ||
|
||
Make sure to read the following sections of the documentation before using this example: | ||
|
||
* :ref:`pytorch_setup` | ||
* :ref:`001 - Single GPU Job` | ||
|
||
The full source code for this example is available on `the mila-docs GitHub repository. <https://github.com/mila-iqia/mila-docs/tree/master/docs/examples/data/hf>`_ | ||
|
||
|
||
**job.sh** | ||
|
||
.. literalinclude:: examples/data/hf/job.sh.diff | ||
:language: diff | ||
|
||
|
||
**main.py** | ||
|
||
.. literalinclude:: examples/data/hf/main.py.diff | ||
:language: diff | ||
|
||
|
||
**cp_data.sh** | ||
|
||
.. literalinclude:: examples/data/hf/cp_data.sh | ||
:language: bash | ||
|
||
|
||
**list_dataset.py** | ||
|
||
.. literalinclude:: examples/data/hf/list_dataset.py | ||
:language: python | ||
|
||
|
||
**prepare_date.sh** | ||
|
||
.. literalinclude:: examples/data/hf/prepare_date.sh | ||
:language: bash | ||
|
||
|
||
**prepare_date.py** | ||
|
||
.. literalinclude:: examples/data/hf/prepare_date.py | ||
:language: python | ||
|
||
|
||
**Running this example** | ||
|
||
.. code-block:: bash | ||
$ sbatch job.sh |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
#!/bin/bash | ||
set -o errexit | ||
|
||
_SRC=$1 | ||
_DEST=$2 | ||
_WORKERS=$3 | ||
|
||
python3 list_dataset.py | while read f | ||
do | ||
mkdir --parents "${_DEST}/$(dirname "$f")" | ||
# echo source first so cp understands it's the source file | ||
readlink --canonicalize "${_SRC}/$f" | ||
# echo output last so it is matched to the cp's '-T' argument | ||
echo "${_DEST}/$f" | ||
done | xargs -n2 -P${_WORKERS} cp --update -T |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
#!/bin/bash | ||
#SBATCH --gpus-per-task=rtx8000:1 | ||
#SBATCH --cpus-per-task=4 | ||
#SBATCH --ntasks-per-node=1 | ||
#SBATCH --mem=32G | ||
#SBATCH --time=02:00:00 | ||
#SBATCH --tmp=1500G | ||
set -o errexit | ||
|
||
|
||
# Echo time and hostname into log | ||
echo "Date: $(date)" | ||
echo "Hostname: $(hostname)" | ||
|
||
|
||
# Ensure only anaconda/3 module loaded. | ||
module purge | ||
# This example uses Conda to manage package dependencies. | ||
# See https://docs.mila.quebec/Userguide.html#conda for more information. | ||
module load anaconda/3 | ||
|
||
|
||
# Creating the environment for the first time: | ||
# conda create -y -n pytorch python=3.9 pytorch torchvision torchaudio \ | ||
# pytorch-cuda=11.6 scipy -c pytorch -c nvidia | ||
# Other conda packages: | ||
# conda install -y -n pytorch -c conda-forge rich tqdm | ||
|
||
# Activate pre-existing environment. | ||
conda activate pytorch | ||
|
||
|
||
# Prepare data for training | ||
mkdir -p "$SLURM_TMPDIR/data" # Transformed dataset to be used in training | ||
mkdir -p "$SLURM_TMPDIR/data_raw" # Local links to raw local dataset | ||
|
||
if [[ -z "${HF_DATASETS_CACHE}" ]] | ||
then | ||
# Store the huggingface datasets cache in $SCRATCH | ||
export HF_DATASETS_CACHE=$SCRATCH/cache/huggingface/datasets | ||
fi | ||
if [[ -z "${_DATA_PREP_WORKERS}" ]] | ||
then | ||
_DATA_PREP_WORKERS=${SLURM_JOB_CPUS_PER_NODE} | ||
fi | ||
if [[ -z "${_DATA_PREP_WORKERS}" ]] | ||
then | ||
_DATA_PREP_WORKERS=16 | ||
fi | ||
|
||
# Reorganize the raw files in $SLURM_TMPDIR, if needed, then preprocess the | ||
# dataset such that the heavy work is done only once *ever* | ||
srun --ntasks=1 --ntasks-per-node=1 \ | ||
time -p bash prepare_data.sh "/network/datasets/pile" "$SLURM_TMPDIR/data_raw" ${_DATA_PREP_WORKERS} | ||
|
||
# Copy the preprocessed dataset to $SLURM_TMPDIR so it is close to the GPUs for | ||
# faster training | ||
srun --ntasks=$SLURM_JOB_NUM_NODES --ntasks-per-node=1 \ | ||
time -p bash cp_data.sh "${HF_DATASETS_CACHE}" "$SLURM_TMPDIR/data" ${_DATA_PREP_WORKERS} | ||
|
||
# Use the local copy of the preprocessed dataset | ||
export HF_DATASETS_CACHE="$SLURM_TMPDIR/data" | ||
|
||
|
||
# Execute Python script | ||
python main.py |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
"""List to stdout the files of the dataset""" | ||
import sys | ||
from pathlib import Path | ||
|
||
import datasets | ||
from datasets.arrow_reader import make_file_instructions | ||
|
||
|
||
# Redirect outputs to stderr to avoid noize in stdout | ||
_stdout = sys.stdout | ||
sys.stdout = sys.stderr | ||
|
||
builder = datasets.load_dataset_builder("the_pile", subsets=["all"], version="0.0.0") | ||
|
||
files = [make_file_instructions(s.dataset_name, [s], s.name, prefix_path=builder.cache_dir) | ||
for s in builder.info.splits.values()] | ||
files = [Path(inst["filename"]).relative_to(datasets.config.HF_DATASETS_CACHE).with_suffix(".arrow") | ||
for insts in files for inst in insts.file_instructions] | ||
dataset_info = Path(builder.cache_dir) / "dataset_info.json" | ||
dataset_info = dataset_info.relative_to(datasets.config.HF_DATASETS_CACHE) | ||
for f in (dataset_info, *files): | ||
print(f, file=_stdout) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,144 @@ | ||
"""Torchvision training example.""" | ||
import logging | ||
import os | ||
|
||
import datasets | ||
import rich.logging | ||
import torch | ||
from torch import Tensor, nn | ||
from torch.nn import functional as F | ||
from torch.utils.data import DataLoader | ||
from torchvision.models import resnet18 | ||
from tqdm import tqdm | ||
|
||
|
||
def main(): | ||
training_epochs = 1 | ||
learning_rate = 5e-4 | ||
weight_decay = 1e-4 | ||
batch_size = 256 | ||
|
||
# Check that the GPU is available | ||
assert torch.cuda.is_available() and torch.cuda.device_count() > 0 | ||
device = torch.device("cuda", 0) | ||
|
||
# Setup logging (optional, but much better than using print statements) | ||
logging.basicConfig( | ||
level=logging.INFO, | ||
handlers=[rich.logging.RichHandler(markup=True)], # Very pretty, uses the `rich` package. | ||
) | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
# Create a model and move it to the GPU. | ||
model = resnet18() | ||
model.to(device=device) | ||
|
||
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay) | ||
|
||
# Setup ImageNet | ||
num_workers = get_num_workers() | ||
dataset_path = "the_pile" | ||
train_dataset, valid_dataset, test_dataset = make_datasets(dataset_path) | ||
train_dataloader = DataLoader( | ||
train_dataset, | ||
batch_size=batch_size, | ||
num_workers=num_workers, | ||
shuffle=True, | ||
) | ||
valid_dataloader = DataLoader( | ||
valid_dataset, | ||
batch_size=batch_size, | ||
num_workers=num_workers, | ||
shuffle=False, | ||
) | ||
test_dataloader = DataLoader( # NOTE: Not used in this example. | ||
test_dataset, | ||
batch_size=batch_size, | ||
num_workers=num_workers, | ||
shuffle=False, | ||
) | ||
|
||
# Checkout the "checkpointing and preemption" example for more info! | ||
logger.debug("Starting training from scratch.") | ||
|
||
for epoch in range(training_epochs): | ||
logger.debug(f"Starting epoch {epoch}/{training_epochs}") | ||
|
||
# Set the model in training mode (this is important for e.g. BatchNorm and Dropout layers) | ||
model.train() | ||
|
||
# NOTE: using a progress bar from tqdm because it's nicer than using `print`. | ||
progress_bar = tqdm( | ||
total=len(train_dataloader), | ||
desc=f"Train epoch {epoch}", | ||
) | ||
|
||
# Training loop | ||
for batch in train_dataloader: | ||
# Move the batch to the GPU before we pass it to the model | ||
batch = tuple(item.to(device) for item in batch) | ||
|
||
# [Training of the model goes here] | ||
|
||
# Advance the progress bar one step, and update the "postfix" () the progress bar. (nicer than just) | ||
progress_bar.update(1) | ||
progress_bar.close() | ||
|
||
val_loss, val_accuracy = validation_loop(model, valid_dataloader, device) | ||
logger.info(f"Epoch {epoch}: Val loss: {val_loss:.3f} accuracy: {val_accuracy:.2%}") | ||
|
||
print("Done!") | ||
|
||
|
||
@torch.no_grad() | ||
def validation_loop(model: nn.Module, dataloader: DataLoader, device: torch.device): | ||
model.eval() | ||
|
||
total_loss = 0.0 | ||
n_samples = 0 | ||
correct_predictions = 0 | ||
|
||
for batch in dataloader: | ||
batch = tuple(item.to(device) for item in batch) | ||
x, y = batch | ||
|
||
logits: Tensor = model(x) | ||
loss = F.cross_entropy(logits, y) | ||
|
||
batch_n_samples = x.shape[0] | ||
batch_correct_predictions = logits.argmax(-1).eq(y).sum() | ||
|
||
total_loss += loss.item() | ||
n_samples += batch_n_samples | ||
correct_predictions += batch_correct_predictions | ||
|
||
accuracy = correct_predictions / n_samples | ||
return total_loss, accuracy | ||
|
||
|
||
def make_datasets(dataset_path: str): | ||
"""Returns the training, validation, and test splits for ImageNet. | ||
NOTE: We don't use transforms here for simplicity. | ||
Having different transformations for train and validation would complicate things a bit. | ||
Later examples will show how to do the train/val/test split properly when using transforms. | ||
""" | ||
builder = datasets.load_dataset_builder(dataset_path, subsets=["all"], version="0.0.0") | ||
train_dataset = builder.as_dataset(split="train").with_format("torch") | ||
valid_dataset = builder.as_dataset(split="validation").with_format("torch") | ||
test_dataset = builder.as_dataset(split="test").with_format("torch") | ||
return train_dataset, valid_dataset, test_dataset | ||
|
||
|
||
def get_num_workers() -> int: | ||
"""Gets the optimal number of DatLoader workers to use in the current job.""" | ||
if "SLURM_CPUS_PER_TASK" in os.environ: | ||
return int(os.environ["SLURM_CPUS_PER_TASK"]) | ||
if hasattr(os, "sched_getaffinity"): | ||
return len(os.sched_getaffinity(0)) | ||
return torch.multiprocessing.cpu_count() | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
Oops, something went wrong.