Feat train on hf dataset (#20)
* feat: add cli to push to hf

* fix: fix cli

* feat: add hf dataset

* feat: add hf dataset

* feat: tune param

* feat: update dependency

* fix: fix wandb

* fix: fix optim
samsja authored Aug 11, 2023
1 parent 7d5d18d commit 2948f7e
Showing 6 changed files with 621 additions and 548 deletions.
1,066 changes: 534 additions & 532 deletions poetry.lock

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion tests/training/integration/test_train.py
@@ -3,9 +3,11 @@


@pytest.mark.parametrize("module", ["Replit", "StarCoder"])
def test_train(module):
@pytest.mark.parametrize("dataset", ["DummyDataset", "ExerciseDatast"])
def test_train(module, dataset):
train(
module=module,
dataset=dataset,
debug=True,
epochs=1,
micro_batch_size=1,
6 changes: 5 additions & 1 deletion tests/training/units/test_dataset.py
@@ -1,6 +1,6 @@
import pytest

from textbook.dataset import DummyDataset
from textbook.dataset import DummyDataset, ExerciseDatast
from textbook.model import Replit

from transformers import PreTrainedTokenizer
@@ -13,3 +13,7 @@ def tokenizer() -> PreTrainedTokenizer:

def test_tiny_stories(tokenizer):
DummyDataset(debug=True, tokenizer=tokenizer)


def test_exercises_dataet(tokenizer):
ExerciseDatast(debug=True, tokenizer=tokenizer)
73 changes: 67 additions & 6 deletions textbook/dataset.py
@@ -1,14 +1,18 @@
from typing import Protocol
import random

from datasets import Dataset
from transformers import PreTrainedTokenizer, DataCollatorForLanguageModeling
from datasets import Dataset, load_dataset
from transformers import (
PreTrainedTokenizer,
DataCollatorForLanguageModeling,
DataCollatorForSeq2Seq,
)
from transformers.data.data_collator import DataCollatorMixin


class CustomDataset(Protocol):
train_dataset: Dataset
eval_dataset: Dataset
test_dataset: Dataset
data_collator: DataCollatorMixin

def __init__(
@@ -44,14 +48,14 @@ def __init__(
self.test_dataset = split_dataset["test"]

self.train_dataset = self.train_dataset.map(
self._get_tokenize_fn(tokenizer),
self._get_preprocess_fn(tokenizer),
batched=True,
num_proc=4,
remove_columns=self.train_dataset.column_names,
)

self.test_dataset = self.test_dataset.map(
self._get_tokenize_fn(tokenizer),
self._get_preprocess_fn(tokenizer),
batched=True,
num_proc=4,
remove_columns=self.test_dataset.column_names,
@@ -60,10 +64,67 @@ def __init__(
self.data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

@staticmethod
def _get_tokenize_fn(tokenizer: PreTrainedTokenizer):
def _get_preprocess_fn(tokenizer: PreTrainedTokenizer):
def tokenize_fn(input):
return tokenizer(
input["text"],
)

return tokenize_fn


class ExerciseDatast:
def __init__(
self,
tokenizer: PreTrainedTokenizer,
debug: bool = False,
):
self.debug = debug

dataset = load_dataset("jinaai/code_exercises_40k")["train"]

if debug:
dataset = dataset.select(range(10))

split_dataset = dataset.train_test_split(test_size=0.1)

self.train_dataset = split_dataset["train"]
self.test_dataset = split_dataset["test"]

self.train_dataset = self.train_dataset.map(
self._get_preprocess_fn(tokenizer),
batched=False,
num_proc=4,
remove_columns=self.train_dataset.column_names,
)

self.test_dataset = self.test_dataset.map(
self._get_preprocess_fn(tokenizer),
batched=False,
num_proc=4,
remove_columns=self.test_dataset.column_names,
)

self.data_collator = DataCollatorForSeq2Seq(
tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
)

@staticmethod
def _get_preprocess_fn(tokenizer: PreTrainedTokenizer):
def tokenize_fn(input):
input_problem = input["problem"]
input_solution = input["solution"]

inputs = tokenizer(input_problem)
targets = tokenizer(input_solution)
inputs["labels"] = [-100] * len(inputs["input_ids"]) + targets[
"input_ids"
] # we don't train on the problem tokens
inputs["input_ids"] = inputs["input_ids"] + targets["input_ids"]
inputs["attention_mask"] = (
inputs["attention_mask"] + targets["attention_mask"]
)

return inputs

return tokenize_fn
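
A minimal sketch of the masking scheme introduced above, assuming a generic HuggingFace tokenizer (gpt2 is used here only as a stand-in; the problem/solution strings are hypothetical examples of a dataset row). The point is that the problem tokens are labeled -100 so the loss is computed on the solution tokens only.

# Sketch of the problem/solution masking done in ExerciseDatast._get_preprocess_fn.
# Assumes any HuggingFace tokenizer; gpt2 and the strings below are stand-ins.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")

problem = 'def add(a, b):\n    """Return the sum of a and b."""\n'
solution = "    return a + b\n"

prompt = tokenizer(problem)
target = tokenizer(solution)

example = {
    # the model sees problem + solution as one contiguous sequence
    "input_ids": prompt["input_ids"] + target["input_ids"],
    "attention_mask": prompt["attention_mask"] + target["attention_mask"],
    # -100 is ignored by the cross-entropy loss, so only solution tokens are trained on
    "labels": [-100] * len(prompt["input_ids"]) + target["input_ids"],
}

assert len(example["input_ids"]) == len(example["labels"])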
2 changes: 0 additions & 2 deletions textbook/model.py
@@ -48,7 +48,6 @@ def _init_tokenizer(self):
self.tokenizer = AutoTokenizer.from_pretrained(
self.base_model, trust_remote_code=True
)
self.tokenizer.padding_side = "left" # Allow batched inference
self.tokenizer.pad_token = self.tokenizer.eos_token


@@ -78,5 +77,4 @@ def _init_tokenizer(self):
self.tokenizer = AutoTokenizer.from_pretrained(
self.base_model,
)
self.tokenizer.padding_side = "left" # Allow batched inference
self.tokenizer.pad_token = self.tokenizer.eos_token
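
A small sketch of the tokenizer setup after this change, again assuming the gpt2 tokenizer as a stand-in. Left padding is mainly needed for batched generation; for training, the default right padding is kept and the pad token falls back to the EOS token, which is what the Seq2Seq collator added in dataset.py pads with.

# Minimal sketch (gpt2 as a stand-in): no explicit padding_side, pad_token reuses eos_token.
from transformers import AutoTokenizer, DataCollatorForSeq2Seq

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # gpt2-style models ship without a pad token

collator = DataCollatorForSeq2Seq(
    tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
)

features = [
    {"input_ids": [1, 2, 3], "attention_mask": [1, 1, 1], "labels": [-100, 2, 3]},
    {"input_ids": [4, 5], "attention_mask": [1, 1], "labels": [-100, 5]},
]
batch = collator(features)
print(batch["input_ids"].shape)  # both sequences padded on the right to a multiple of 8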
18 changes: 12 additions & 6 deletions textbook/train.py
@@ -5,7 +5,7 @@

import torch

from textbook.dataset import DummyDataset
from textbook.dataset import CustomDataset
from textbook.evaluate import evaluate
from textbook.model import BaseModule

@@ -37,6 +37,7 @@ def wrapper(*args, **kwargs):
def train(
*,
module: str = "StarCoder",
dataset: str = "ExerciseDatast",
epochs: int = 1,
micro_batch_size: int = 1,
batch_size: int = 1,
@@ -59,7 +60,11 @@ def train(
model = torch.compile(module_instance.model)
model = module_instance.model
tokenizer = module_instance.tokenizer
dataset = DummyDataset(tokenizer=tokenizer, debug=debug)

dataset_cls: Type[CustomDataset] = getattr(
import_module("textbook.dataset"), dataset
)
dataset_instance = dataset_cls(tokenizer=tokenizer, debug=debug)

if debug:
wandb_run_name = "debug"
@@ -78,18 +83,19 @@

use_wandb = local_rank == 0 and use_wandb
if use_wandb:
run = wandb.init(wandb_project, **dict(config=config_to_log)) # type: ignore
run = wandb.init(project=wandb_project, **dict(config=config_to_log)) # type: ignore
else:
run = None # type: ignore

trainer = transformers.Trainer(
model=model,
train_dataset=dataset.train_dataset,
eval_dataset=dataset.test_dataset,
train_dataset=dataset_instance.train_dataset,
eval_dataset=dataset_instance.test_dataset,
args=transformers.TrainingArguments(
per_device_train_batch_size=micro_batch_size,
gradient_accumulation_steps=batch_size // micro_batch_size,
optim="adamw_torch",
# gradient_checkpointing=True,
warmup_steps=100,
num_train_epochs=epochs,
learning_rate=learning_rate,
@@ -104,7 +110,7 @@ def train(
run_name=wandb_run_name if use_wandb else None,
remove_unused_columns=False,
),
data_collator=dataset.data_collator,
data_collator=dataset_instance.data_collator,
)

trainer.train()
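
A reduced sketch of the name-to-class dispatch train() now does for its dataset argument: importlib resolves the class by name from textbook.dataset, so the CLI strings "DummyDataset" and "ExerciseDatast" map directly to the classes added in dataset.py. The helper name below is hypothetical and assumes the textbook package is importable.

# Sketch of the dynamic dataset lookup used in train(); resolve_dataset is a
# hypothetical helper wrapping the same import_module/getattr pattern.
from importlib import import_module
from typing import Type


def resolve_dataset(name: str) -> Type:
    module = import_module("textbook.dataset")
    try:
        return getattr(module, name)
    except AttributeError as err:
        raise ValueError(f"Unknown dataset class: {name!r}") from err


# dataset_cls = resolve_dataset("ExerciseDatast")
# dataset_instance = dataset_cls(tokenizer=tokenizer, debug=True)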
