Feat train on hf dataset #20

Merged: 8 commits, Aug 11, 2023

5 changes: 5 additions & 0 deletions README.md
@@ -88,3 +88,8 @@ Once the files are generated you can postprocess them and save the result into a jsonl
python dataset_gen_cli.py filter ./exercises dataset.jsonl
```

Push the dataset to the Hugging Face Hub:

```shell
python dataset_gen_cli.py push "jinaai/code_exercises_40k" dataset.jsonl
```
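Once pushed, the upload can be sanity-checked by loading it back; a minimal sketch, assuming the `problem` and `solution` columns that the training code below expects:

```python
from datasets import load_dataset

# Pull the pushed dataset back from the Hub and peek at one row.
ds = load_dataset("jinaai/code_exercises_40k")["train"]
print(ds[0]["problem"])
print(ds[0]["solution"])
```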
1,066 changes: 534 additions & 532 deletions poetry.lock

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion tests/training/integration/test_train.py
@@ -3,9 +3,11 @@


@pytest.mark.parametrize("module", ["Replit", "StarCoder"])
def test_train(module):
@pytest.mark.parametrize("dataset", ["DummyDataset", "ExerciseDatast"])
def test_train(module, dataset):
train(
module=module,
dataset=dataset,
debug=True,
epochs=1,
micro_batch_size=1,
6 changes: 5 additions & 1 deletion tests/training/units/test_dataset.py
@@ -1,6 +1,6 @@
import pytest

from textbook.dataset import DummyDataset
from textbook.dataset import DummyDataset, ExerciseDatast
from textbook.model import Replit

from transformers import PreTrainedTokenizer
@@ -13,3 +13,7 @@ def tokenizer() -> PreTrainedTokenizer:

def test_tiny_stories(tokenizer):
DummyDataset(debug=True, tokenizer=tokenizer)


def test_exercises_dataet(tokenizer):
ExerciseDatast(debug=True, tokenizer=tokenizer)
73 changes: 67 additions & 6 deletions textbook/dataset.py
@@ -1,14 +1,18 @@
from typing import Protocol
import random

from datasets import Dataset
from transformers import PreTrainedTokenizer, DataCollatorForLanguageModeling
from datasets import Dataset, load_dataset
from transformers import (
PreTrainedTokenizer,
DataCollatorForLanguageModeling,
DataCollatorForSeq2Seq,
)
from transformers.data.data_collator import DataCollatorMixin


class CustomDataset(Protocol):
train_dataset: Dataset
eval_dataset: Dataset
test_dataset: Dataset
data_collator: DataCollatorMixin

def __init__(
@@ -44,14 +48,14 @@ def __init__(
self.test_dataset = split_dataset["test"]

self.train_dataset = self.train_dataset.map(
self._get_tokenize_fn(tokenizer),
self._get_preprocess_fn(tokenizer),
batched=True,
num_proc=4,
remove_columns=self.train_dataset.column_names,
)

self.test_dataset = self.test_dataset.map(
self._get_tokenize_fn(tokenizer),
self._get_preprocess_fn(tokenizer),
batched=True,
num_proc=4,
remove_columns=self.test_dataset.column_names,
@@ -60,10 +64,67 @@ def __init__(
self.data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

@staticmethod
def _get_tokenize_fn(tokenizer: PreTrainedTokenizer):
def _get_preprocess_fn(tokenizer: PreTrainedTokenizer):
def tokenize_fn(input):
return tokenizer(
input["text"],
)

return tokenize_fn


class ExerciseDatast:
def __init__(
self,
tokenizer: PreTrainedTokenizer,
debug: bool = False,
):
self.debug = debug

dataset = load_dataset("jinaai/code_exercises_40k")["train"]

if debug:
dataset = dataset.select(range(10))

split_dataset = dataset.train_test_split(test_size=0.1)

self.train_dataset = split_dataset["train"]
self.test_dataset = split_dataset["test"]

self.train_dataset = self.train_dataset.map(
self._get_preprocess_fn(tokenizer),
batched=False,
num_proc=4,
remove_columns=self.train_dataset.column_names,
)

self.test_dataset = self.test_dataset.map(
self._get_preprocess_fn(tokenizer),
batched=False,
num_proc=4,
remove_columns=self.test_dataset.column_names,
)

self.data_collator = DataCollatorForSeq2Seq(
tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
)

@staticmethod
def _get_preprocess_fn(tokenizer: PreTrainedTokenizer):
def tokenize_fn(input):
input_problem = input["problem"]
input_solution = input["solution"]

inputs = tokenizer(input_problem)
targets = tokenizer(input_solution)
inputs["labels"] = [-100] * len(inputs["input_ids"]) + targets[
"input_ids"
] # we don't train on the problem tokens
inputs["input_ids"] = inputs["input_ids"] + targets["input_ids"]
inputs["attention_mask"] = (
inputs["attention_mask"] + targets["attention_mask"]
)

return inputs

return tokenize_fn
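
The label construction in `_get_preprocess_fn` is the key detail: the problem span is masked with -100 so the loss ignores it and only the solution tokens are learned. A minimal standalone sketch of that scheme, using the `gpt2` tokenizer purely as a stand-in (the project uses the module's own tokenizer):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # stand-in tokenizer for illustration

example = {"problem": "def add(a, b):\n    ", "solution": "return a + b"}

inputs = tokenizer(example["problem"])
targets = tokenizer(example["solution"])

# -100 over the problem span tells the cross-entropy loss to skip those positions;
# the solution ids are appended so only the solution contributes to training.
labels = [-100] * len(inputs["input_ids"]) + targets["input_ids"]
input_ids = inputs["input_ids"] + targets["input_ids"]

assert len(labels) == len(input_ids)
```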
15 changes: 15 additions & 0 deletions textbook/dataset_gen/dataset_gen_cli.py
@@ -17,6 +17,7 @@

from textbook.dataset_gen.create_prompts import Topic, Query
from textbook.dataset_gen.filtering import load_and_filter_exos
from datasets import Dataset

app = Typer()

@@ -127,5 +128,19 @@ def filter(exo_path: Path, dataset_file: str):
write_results_to_jsonl(dataset_file, exos)


@app.command()
def push(repo_name: str, dataset_file: Path):
with open(dataset_file, "r") as file:
lines = file.readlines()
exercises = [json.loads(line) for line in lines]

def gen():
for exo in exercises:
yield exo

dataset = Dataset.from_generator(gen)
dataset.push_to_hub(repo_name)


if __name__ == "__main__":
app()
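
As a usage note, `push_to_hub` needs an authenticated Hugging Face token (for example from `huggingface-cli login`). The `datasets` library can also build the same dataset straight from the JSONL file; a roughly equivalent sketch of the command body, assuming one JSON object per line:

```python
from datasets import Dataset

# Dataset.from_json reads JSON Lines directly, so the manual generator is optional.
dataset = Dataset.from_json("dataset.jsonl")
dataset.push_to_hub("jinaai/code_exercises_40k")
```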
2 changes: 0 additions & 2 deletions textbook/model.py
@@ -48,7 +48,6 @@ def _init_tokenizer(self):
self.tokenizer = AutoTokenizer.from_pretrained(
self.base_model, trust_remote_code=True
)
self.tokenizer.padding_side = "left" # Allow batched inference
self.tokenizer.pad_token = self.tokenizer.eos_token


@@ -78,5 +77,4 @@ def _init_tokenizer(self):
self.tokenizer = AutoTokenizer.from_pretrained(
self.base_model,
)
self.tokenizer.padding_side = "left" # Allow batched inference
self.tokenizer.pad_token = self.tokenizer.eos_token
18 changes: 12 additions & 6 deletions textbook/train.py
@@ -5,7 +5,7 @@

import torch

from textbook.dataset import DummyDataset
from textbook.dataset import CustomDataset
from textbook.evaluate import evaluate
from textbook.model import BaseModule

@@ -37,6 +37,7 @@ def wrapper(*args, **kwargs):
def train(
*,
module: str = "StarCoder",
dataset: str = "ExerciseDatast",
epochs: int = 1,
micro_batch_size: int = 1,
batch_size: int = 1,
@@ -59,7 +60,11 @@ def train(
model = torch.compile(module_instance.model)
model = module_instance.model
tokenizer = module_instance.tokenizer
dataset = DummyDataset(tokenizer=tokenizer, debug=debug)

dataset_cls: Type[CustomDataset] = getattr(
import_module("textbook.dataset"), dataset
)
dataset_instance = dataset_cls(tokenizer=tokenizer, debug=debug)

if debug:
wandb_run_name = "debug"
@@ -78,18 +83,19 @@

use_wandb = local_rank == 0 and use_wandb
if use_wandb:
run = wandb.init(wandb_project, **dict(config=config_to_log)) # type: ignore
run = wandb.init(project=wandb_project, **dict(config=config_to_log)) # type: ignore
else:
run = None # type: ignore

trainer = transformers.Trainer(
model=model,
train_dataset=dataset.train_dataset,
eval_dataset=dataset.test_dataset,
train_dataset=dataset_instance.train_dataset,
eval_dataset=dataset_instance.test_dataset,
args=transformers.TrainingArguments(
per_device_train_batch_size=micro_batch_size,
gradient_accumulation_steps=batch_size // micro_batch_size,
optim="adamw_torch",
# gradient_checkpointing=True,
warmup_steps=100,
num_train_epochs=epochs,
learning_rate=learning_rate,
@@ -104,7 +110,7 @@
run_name=wandb_run_name if use_wandb else None,
remove_unused_columns=False,
),
data_collator=dataset.data_collator,
data_collator=dataset_instance.data_collator,
)

trainer.train()
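
With this change the dataset class is resolved by name at runtime via `getattr(import_module("textbook.dataset"), dataset)`, so switching to the HF-backed exercises dataset is just a string argument. A minimal usage sketch mirroring the integration test, assuming `train` is imported directly:

```python
from textbook.train import train

# "ExerciseDatast" is looked up in textbook.dataset at runtime and instantiated
# with the module's tokenizer; debug=True trims the dataset to a 10-row sample.
train(
    module="Replit",
    dataset="ExerciseDatast",
    debug=True,
    epochs=1,
    micro_batch_size=1,
    batch_size=1,
)
```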