From 2db6fa96e3f08a2fc54bbd512bf473d934afcb33 Mon Sep 17 00:00:00 2001
From: samsja
Date: Tue, 8 Aug 2023 18:57:48 +0200
Subject: [PATCH] fix: add smth

---
 textbook/dataset_gen/dataset_gen.py | 51 ++++++++++++++++-------------
 1 file changed, 29 insertions(+), 22 deletions(-)

diff --git a/textbook/dataset_gen/dataset_gen.py b/textbook/dataset_gen/dataset_gen.py
index 471259a..8656b27 100644
--- a/textbook/dataset_gen/dataset_gen.py
+++ b/textbook/dataset_gen/dataset_gen.py
@@ -1,5 +1,6 @@
 from concurrent.futures import ThreadPoolExecutor
 import json
+import os
 import random
 import time
 
@@ -14,6 +15,7 @@
     Progress,
     TimeElapsedColumn,
 )
+import hashlib
 
 
 class Exercise(BaseModel):
@@ -149,10 +151,22 @@ def _generation_wrapper(
     prompt: str,
     get_generator: Callable[[], Generator],
     update_progress: Callable,
+    save_dir: str,
     retries: int,
-) -> List[Exercise]:
+):
+
+    file_path_sum = hashlib.md5(prompt.encode("utf-8")).hexdigest()
+
+    dir_path, file_path = file_path_sum[:4], file_path_sum[4:]
+    dir_path = os.path.join(save_dir, dir_path)
+
+    if not os.path.exists(dir_path):
+        os.makedirs(dir_path)
+
     generator = get_generator()
-    return generation(prompt, generator, update_progress, retries)
+    results = generation(prompt, generator, update_progress, retries)
+
+    save_results_to_disk(os.path.join(dir_path, file_path), results)
 
 
 def mass_generation(
@@ -167,7 +181,6 @@ Generate from a list of prompts. Use a thread pool to parallelize the generation with catch
     and retry mechanism
     """
     results = []
-    counter = 0
     with Progress(
         *Progress.get_default_columns(),
         "•",
@@ -179,26 +192,13 @@ def update_progress():
 
         with ThreadPoolExecutor(max_workers=pool_size) as executor:
             task = progress.add_task("[red]Generating...", total=len(prompts))
-            futures = []
-            for i in range(len(prompts)):  # call API 10 times
-                futures.append(
-                    executor.submit(
-                        _generation_wrapper,
-                        prompts[i],
-                        get_generator,
-                        update_progress,
-                        retries=retries,
-                    )
-                )
-            for future in futures:
-                result = future.result()
-                results += result
-                if len(results) >= save_every:
-                    write_results_to_jsonl(
-                        f"{save_dir}/results_{counter}.jsonl", results
-                    )
-                    results = []
-                    counter += 1
+
+            def map_fn(prompt):
+                _generation_wrapper(
+                    prompt, get_generator, update_progress, save_dir, retries
+                )
+
+            list(executor.map(map_fn, prompts))
 
     return results
 
@@ -223,3 +223,10 @@ def write_results_to_jsonl(file_path: str, results: List[Exercise]):
         for item in results:
             json.dump(item.dict(), file)
             file.write("\n")
+
+
+def save_results_to_disk(file_path: str, results: List[Exercise]):
+    with open(file_path, "w") as file:
+        for item in results:
+            json.dump(item.dict(), file)
+            file.write("\n")
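
Note: the patch shards each prompt's output path by the md5 of the prompt. A
minimal sketch of that scheme as _generation_wrapper now implements it (the
helper name shard_path is hypothetical, not part of the patch): the 32-char
hex digest is split into a 4-char shard directory and a 28-char file name, so
results spread across at most 16**4 = 65536 directories instead of piling up
in one.

    import hashlib
    import os

    def shard_path(save_dir: str, prompt: str) -> str:
        # 32 hex chars: first 4 pick the shard directory, the rest name the file
        digest = hashlib.md5(prompt.encode("utf-8")).hexdigest()
        return os.path.join(save_dir, digest[:4], digest[4:])

    # e.g. shard_path("out", "Write an exercise") -> "out/<4 hex>/<28 hex>"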
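Note: the submit/result loop is replaced with executor.map. A minimal sketch of
that pattern, with a hypothetical stand-in worker (handle) in place of
_generation_wrapper: map returns a lazy iterator, so wrapping it in list(...)
both drains it (forcing all tasks to finish) and re-raises any exception from a
worker thread. Results now go to disk per prompt rather than being batched in
memory, which is why the save_every/counter bookkeeping could be dropped; one
side effect of the patch as written is that mass_generation's results list is
never appended to anymore, so it now returns an empty list.

    from concurrent.futures import ThreadPoolExecutor

    def handle(prompt: str) -> None:
        # stand-in for _generation_wrapper: generate, then save to disk
        ...

    prompts = ["prompt a", "prompt b", "prompt c"]
    with ThreadPoolExecutor(max_workers=10) as executor:
        list(executor.map(handle, prompts))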
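Note: save_results_to_disk writes one JSON object per line (JSON Lines), same
format as write_results_to_jsonl, just one file per prompt. A sketch of reading
the sharded tree back (load_results is a hypothetical helper, not part of the
patch):

    import json
    import os
    from typing import Iterator

    def load_results(save_dir: str) -> Iterator[dict]:
        # walk the 4-hex-char shard directories and parse each JSONL file
        for root, _, files in os.walk(save_dir):
            for name in files:
                with open(os.path.join(root, name)) as f:
                    for line in f:
                        if line.strip():
                            yield json.loads(line)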