From 5d1eed91d996593385cc99eb0240f9b88b38e7c1 Mon Sep 17 00:00:00 2001
From: samsja <55492238+samsja@users.noreply.github.com>
Date: Thu, 10 Aug 2023 13:56:31 +0200
Subject: [PATCH] feat: filtering (#18)

* feat: add filtering

* feat: update readme

* fix: apply black
---
 README.md                               | 21 +++++++++-
 tests/dataset_gen/test_integration.py   |  9 ++++-
 textbook/dataset_gen/dataset_gen_cli.py | 11 +++++
 textbook/dataset_gen/filtering.py       | 53 +++++++++++++++++++++++++
 4 files changed, 91 insertions(+), 3 deletions(-)
 create mode 100644 textbook/dataset_gen/filtering.py

diff --git a/README.md b/README.md
index 7298fa8..5b9c918 100644
--- a/README.md
+++ b/README.md
@@ -66,6 +66,25 @@ Another way to inject diversity is prompt engineering. By having random aspects
 ## Generating Dataset
 
+
+Follow these steps to reproduce the dataset generation.
+
+
+First, export your OpenAI API key:
+```shell
+export OPENAI_API_KEY=sk-XXX
+```
+Then launch the parallel calls to the OpenAI API:
+```shell
+python textbook/dataset_gen/dataset_gen_cli.py generate ./tree/professions.json ./tree/subsubtopics.json ./exercises --n-prompts 50_000 --pool-size 20
+```
+
+This should take around 6 hours. The process might be killed before the end, but the data is saved progressively along the way.
+
+
+Once the files are generated, you can postprocess them and save the result into a single JSONL file:
+
 ```shell
-python textbook/dataset_gen/dataset_gen_cli.py --pool-size 10 "tests/data/prompts_debug.jsonl"
+python textbook/dataset_gen/dataset_gen_cli.py filter ./exercises dataset.jsonl
 ```
+
 
diff --git a/tests/dataset_gen/test_integration.py b/tests/dataset_gen/test_integration.py
index 40724a4..3c5272f 100644
--- a/tests/dataset_gen/test_integration.py
+++ b/tests/dataset_gen/test_integration.py
@@ -1,4 +1,5 @@
-from textbook.dataset_gen.dataset_gen_cli import generate
+from textbook.dataset_gen.dataset_gen_cli import generate, filter
+import os
 
 
 def test_cli_dataset_gen(tmp_path):
@@ -9,5 +10,9 @@ def test_cli_dataset_gen(tmp_path):
         debug_speed=-1,
         retries=10,
         pool_size=10,
-        output_path=tmp_path / "results.jsonl",
+        output_path=tmp_path,
     )
+
+    filter(exo_path=tmp_path, dataset_file=os.path.join(tmp_path, "dataset.jsonl"))
+
+    assert os.path.exists(os.path.join(tmp_path, "dataset.jsonl"))
diff --git a/textbook/dataset_gen/dataset_gen_cli.py b/textbook/dataset_gen/dataset_gen_cli.py
index ce67a04..679b9dc 100644
--- a/textbook/dataset_gen/dataset_gen_cli.py
+++ b/textbook/dataset_gen/dataset_gen_cli.py
@@ -9,11 +9,14 @@
     mass_generation,
     OpenAIGenerator,
     MonkeyGenerator,
+    write_results_to_jsonl,
 )
 
 import openai
 import os
+from pathlib import Path
 from textbook.dataset_gen.create_prompts import Topic, Query
+from textbook.dataset_gen.filtering import load_and_filter_exos
 
 
 app = Typer()
@@ -116,5 +119,13 @@ def get_generator():
     )
 
 
+@app.command()
+def filter(exo_path: Path, dataset_file: str):
+    print(exo_path)
+    exos = load_and_filter_exos(exo_path)
+    print(len(exos))
+    write_results_to_jsonl(dataset_file, exos)
+
+
 if __name__ == "__main__":
     app()
diff --git a/textbook/dataset_gen/filtering.py b/textbook/dataset_gen/filtering.py
new file mode 100644
index 0000000..be08aa5
--- /dev/null
+++ b/textbook/dataset_gen/filtering.py
@@ -0,0 +1,53 @@
+from textbook.dataset_gen.dataset_gen import Exercise
+from typing import List, Union
+import os
+from pathlib import Path
+
+
+def load_one_file(path: Union[Path, str]) -> List[Exercise]:
+    with open(path, "r") as f:
+        lines = f.readlines()
+    return [Exercise.parse_raw(line) for line in lines]
+
+
+def load_all_exo(path: Union[Path, str]) -> List[Exercise]:
+    if isinstance(path, str):
+        path = Path(path)
+    exos: List[Exercise] = []
+    for sub_dir in os.listdir(path):
+        for fn in os.listdir(path / sub_dir):
+            exos += load_one_file(path / sub_dir / fn)
+    return exos
+
+
+def filter_bad_exos(
+    exos: List[Exercise], carac_to_remove=["??", "___"]
+) -> List[Exercise]:
+    clean_exos: List[Exercise] = []
+    for exo in exos:
+        keep = True
+        for carac in carac_to_remove:
+            if carac in exo.solution:
+                keep = False
+                break
+
+        if keep:
+            clean_exos.append(exo)
+
+    return clean_exos
+
+
+def remove_extra(exos: List[Exercise], carac_to_split=["# Test", "```"]):
+    for exo in exos:
+        for carac in carac_to_split:
+            exo.solution = exo.solution.split(carac)[0]
+
+
+def load_and_filter_exos(path: Union[Path, str]) -> List[Exercise]:
+    exos = load_all_exo(path)
+    print(len(exos))
+    clean_exos = filter_bad_exos(exos)
+    print(len(clean_exos))
+
+    remove_extra(clean_exos)
+    return clean_exos
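
For reviewers, the two heuristics introduced in `filtering.py` compose as follows. This is a standalone sketch operating on raw solution strings rather than `Exercise` objects (the model's other fields are not shown in this patch); `clean_solutions` and the marker constants are illustrative names, with the marker values taken from the defaults of `filter_bad_exos` and `remove_extra` above.

```python
# Illustrative sketch (not part of this patch): the filtering pass applied
# to raw solution strings. Marker values mirror the defaults above.
from typing import List

BAD_MARKERS = ["??", "___"]        # fill-in-the-blank noise: drop the exercise
SPLIT_MARKERS = ["# Test", "```"]  # trim everything from the first occurrence on


def clean_solutions(solutions: List[str]) -> List[str]:
    # Step 1: drop solutions containing any bad marker (filter_bad_exos).
    kept = [s for s in solutions if not any(m in s for m in BAD_MARKERS)]
    # Step 2: truncate each survivor at the split markers (remove_extra).
    trimmed = []
    for s in kept:
        for marker in SPLIT_MARKERS:
            s = s.split(marker)[0]
        trimmed.append(s)
    return trimmed


print(clean_solutions([
    "def add(a, b):\n    return a + b\n# Test\nassert add(1, 2) == 3",
    "def sub(a, b):\n    return ??",  # dropped: contains "??"
]))
# -> ['def add(a, b):\n    return a + b\n']
```

Note that `remove_extra` mutates `exo.solution` in place, which is why `load_and_filter_exos` can call it last and then return `clean_exos` directly.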