From 5f3a3a027c295a53d8fea6926a3f4124ea9dd6ff Mon Sep 17 00:00:00 2001 From: sami jaghouar Date: Mon, 28 Aug 2023 14:04:02 +0200 Subject: [PATCH] fix: fix max limit of prompt --- textbook/dataset_gen/dataset_gen_cli.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/textbook/dataset_gen/dataset_gen_cli.py b/textbook/dataset_gen/dataset_gen_cli.py index baccb09..eb42f03 100644 --- a/textbook/dataset_gen/dataset_gen_cli.py +++ b/textbook/dataset_gen/dataset_gen_cli.py @@ -54,14 +54,15 @@ def create_prompts( random.shuffle(combination_options) prompts: List[Query] = [] - for loc_topic in combination_options: - if len(prompts) == n: - break + for _ in range(len(professions)): + for loc_topic in combination_options: + if len(prompts) == n: + break - if loc_topic.mixing and loc_topic.parent != topic.parent: - profession = professions[np.random.randint(0, len(professions))] - query = create_prompt_query(topic, loc_topic, profession) - prompts.append(Query(query=query, topic_1=topic, topic_2=loc_topic)) + if loc_topic.mixing and loc_topic.parent != topic.parent: + profession = professions[np.random.randint(0, len(professions))] + query = create_prompt_query(topic, loc_topic, profession) + prompts.append(Query(query=query, topic_1=topic, topic_2=loc_topic)) return prompts @@ -97,7 +98,6 @@ def get_generator(): return MonkeyGenerator(speed=debug_speed) leaves = load_leaves(leaves_path) - prompts: List[List[Query]] = [ create_prompts( i, @@ -109,6 +109,11 @@ def get_generator(): ] prompts_flat = list(itertools.chain(*prompts)) + if n_prompts > len(prompts_flat): + raise ValueError( + f"Canot generate({n_prompts}) prompts because it is larger than the number of" + f" available prompts ({len(prompts_flat)})" + ) prompts_selection = [i.query for i in prompts_flat][:n_prompts] mass_generation(