Skip to content

Commit

Permalink
fix: add smth
Browse files Browse the repository at this point in the history
  • Loading branch information
samsja committed Aug 8, 2023
1 parent 6353f95 commit 2db6fa9
Showing 1 changed file with 29 additions and 22 deletions.
51 changes: 29 additions & 22 deletions textbook/dataset_gen/dataset_gen.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from concurrent.futures import ThreadPoolExecutor
import json
import os
import random
import time

Expand All @@ -14,6 +15,7 @@
Progress,
TimeElapsedColumn,
)
import hashlib


class Exercise(BaseModel):
Expand Down Expand Up @@ -149,10 +151,22 @@ def _generation_wrapper(
prompt: str,
get_generator: Callable[[], Generator],
update_progress: Callable,
save_dir: str,
retries: int,
) -> List[Exercise]:
):

file_path_sum = hashlib.md5(prompt.encode("utf-8")).hexdigest()

dir_path, file_path = file_path_sum[:4], file_path_sum[4:]
dir_path = os.path.join(save_dir, dir_path)

if not os.path.exists(dir_path):
os.makedirs(dir_path)

generator = get_generator()
return generation(prompt, generator, update_progress, retries)
results = generation(prompt, generator, update_progress, retries)

save_results_to_disk(os.path.join(dir_path, file_path), results)


def mass_generation(
Expand All @@ -167,7 +181,6 @@ def mass_generation(
Generate from a list of prompts. Use a thread pool to parallelize the generation with catch and retry mechanism
"""
results = []
counter = 0
with Progress(
*Progress.get_default_columns(),
"•",
Expand All @@ -179,26 +192,13 @@ def update_progress():

with ThreadPoolExecutor(max_workers=pool_size) as executor:
task = progress.add_task("[red]Generating...", total=len(prompts))
futures = []
for i in range(len(prompts)): # call API 10 times
futures.append(
executor.submit(
_generation_wrapper,
prompts[i],
get_generator,
update_progress,
retries=retries,
)

def map_fn(prompt):
_generation_wrapper(
prompt, get_generator, update_progress, save_dir, retries
)
for future in futures:
result = future.result()
results += result
if len(results) >= save_every:
write_results_to_jsonl(
f"{save_dir}/results_{counter}.jsonl", results
)
results = []
counter += 1

list(executor.map(map_fn, prompts))

return results

Expand All @@ -223,3 +223,10 @@ def write_results_to_jsonl(file_path: str, results: List[Exercise]):
for item in results:
json.dump(item.dict(), file)
file.write("\n")


def save_results_to_disk(file_path: str, results: List[Exercise]):
with open(file_path, "w") as file:
for item in results:
json.dump(item.dict(), file)
file.write("\n")

0 comments on commit 2db6fa9

Please sign in to comment.