diff --git a/textbook/dataset_gen/dataset_gen.py b/textbook/dataset_gen/dataset_gen.py index b87f00a..446c270 100644 --- a/textbook/dataset_gen/dataset_gen.py +++ b/textbook/dataset_gen/dataset_gen.py @@ -191,19 +191,31 @@ def mass_generation( "•", TimeElapsedColumn(), ) as progress: - - def update_progress(): - progress.update(task, advance=1) - with ThreadPoolExecutor(max_workers=pool_size) as executor: - task = progress.add_task("[red]Generating...", total=len(prompts)) - - def map_fn(prompt): - _generation_wrapper( - prompt, get_generator, update_progress, save_dir, retries + progress_task = progress.add_task("[red]Generating...", total=len(prompts)) + + def update_progress(): + progress.update(progress_task, advance=1) + + tasks = [] + + for prompt in prompts: + tasks.append( + executor.submit( + _generation_wrapper, + prompt, + get_generator, + update_progress, + save_dir, + retries, + ) ) - list(executor.map(map_fn, prompts)) + for task in tasks: + try: + task.result() + except Exception as e: + raise e def load_prompts(file: str, key_prompt: str = "prompt") -> List[str]: