From 0de1684469243bd45a759b1450e566061f364138 Mon Sep 17 00:00:00 2001
From: sami jaghouar
Date: Mon, 28 Aug 2023 16:29:17 +0200
Subject: [PATCH] fix: filter set

---
 textbook/dataset_gen/dataset_gen_cli.py | 39 +++++++++++--------------
 1 file changed, 17 insertions(+), 22 deletions(-)

diff --git a/textbook/dataset_gen/dataset_gen_cli.py b/textbook/dataset_gen/dataset_gen_cli.py
index 479ade7..82d121d 100644
--- a/textbook/dataset_gen/dataset_gen_cli.py
+++ b/textbook/dataset_gen/dataset_gen_cli.py
@@ -1,4 +1,3 @@
-from copy import deepcopy
 import random
 import itertools
 import json
@@ -49,28 +48,18 @@ def create_prompts(
     topic: Topic,
     combination_options: List[Topic],
     professions: List[str],
-    limit: int = -1,
 ) -> List[Query]:
     random.shuffle(combination_options)

     prompts: List[Query] = []

-    def copy_and_shuffle(prof):
-        professions_copy = deepcopy(prof)
-        random.shuffle(professions_copy)
-        return professions_copy
-
-    profession_for_loc_topic = [
-        copy_and_shuffle(professions) for _ in combination_options
-    ]
-
-    for i in range(len(professions)):
-        for j, loc_topic in enumerate(combination_options):
-            if len(prompts) == limit:
-                break
-
-            if loc_topic.mixing and loc_topic.parent != topic.parent:
-                profession = profession_for_loc_topic[j][i]
+    for loc_topic in combination_options:
+        if (
+            loc_topic.mixing
+            and loc_topic.parent != topic.parent
+            and loc_topic.topic != topic.topic
+        ):
+            for profession in professions:
                 query = create_prompt_query(topic, loc_topic, profession)
                 prompts.append(Query(query=query, topic_1=topic, topic_2=loc_topic))

@@ -86,7 +75,6 @@ def generate(
     pool_size: int = 10,
     debug: bool = False,
     debug_speed: int = 2,
-    gen_limit_per_topic: int = 200,
     n_prompts: int = 100,
 ):
     with open(tree_path, "r") as openfile:
@@ -113,7 +101,6 @@ def get_generator():
             i,
             combination_options=leaves,
             professions=professions,
-            limit=gen_limit_per_topic,
         )
         for i in leaves
     ]
@@ -121,10 +108,18 @@ def get_generator():
     prompts_flat = list(itertools.chain(*prompts))
     if n_prompts > len(prompts_flat):
         raise ValueError(
-            f"Canot generate({n_prompts}) prompts because it is larger than the number of"
+            f"Cannot generate({n_prompts}) prompts because it is larger than the number of"
             f" available prompts ({len(prompts_flat)})"
         )
-    prompts_selection = [i.query for i in prompts_flat][:n_prompts]
+    prompts_selection = [i.query for i in prompts_flat]
+
+    print(f"prompts: {len(prompts_selection)}")
+
+    solo_prompts = list(set(prompts_selection))
+
+    print(f"solo prompts: {len(solo_prompts)}")
+    prompts_selection = solo_prompts[:n_prompts]
+    print(f"total prompts: {len(prompts_selection)}")

     mass_generation(
         prompts_selection,
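
Illustrative note, not part of the patch: the new selection path in generate() deduplicates the flattened prompt strings with a set before taking the first n_prompts. Below is a minimal standalone sketch of just that step, using placeholder strings in place of the Query.query values produced by create_prompts.

    # Stand-in for the flattened Query.query strings; real values come from create_prompt_query.
    prompts_selection = ["prompt A", "prompt B", "prompt A", "prompt C"]
    n_prompts = 3

    print(f"prompts: {len(prompts_selection)}")       # 4 prompts, one exact duplicate

    # set() drops exact duplicates; note it does not preserve the original list order.
    solo_prompts = list(set(prompts_selection))

    print(f"solo prompts: {len(solo_prompts)}")       # 3 unique prompts
    prompts_selection = solo_prompts[:n_prompts]      # keep at most n_prompts unique prompts
    print(f"total prompts: {len(prompts_selection)}")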