Skip to content

Commit

Permalink
fix: filter set
Browse files (browse the repository at this point in the history)
  • Loading branch information
sami jaghouar committed Aug 28, 2023
1 parent 7b40ce8 commit 0de1684
Showing 1 changed file with 17 additions and 22 deletions.
39 changes: 17 additions & 22 deletions textbook/dataset_gen/dataset_gen_cli.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from copy import deepcopy
import random
import itertools
import json
@@ -49,28 +48,18 @@ def create_prompts(
topic: Topic,
combination_options: List[Topic],
professions: List[str],
limit: int = -1,
) -> List[Query]:
random.shuffle(combination_options)

prompts: List[Query] = []

def copy_and_shuffle(prof):
professions_copy = deepcopy(prof)
random.shuffle(professions_copy)
return professions_copy

profession_for_loc_topic = [
copy_and_shuffle(professions) for _ in combination_options
]

for i in range(len(professions)):
for j, loc_topic in enumerate(combination_options):
if len(prompts) == limit:
break

if loc_topic.mixing and loc_topic.parent != topic.parent:
profession = profession_for_loc_topic[j][i]
for loc_topic in combination_options:
if (
loc_topic.mixing
and loc_topic.parent != topic.parent
and loc_topic.topic != topic.topic
):
for profession in professions:
query = create_prompt_query(topic, loc_topic, profession)
prompts.append(Query(query=query, topic_1=topic, topic_2=loc_topic))

@@ -86,7 +75,6 @@ def generate(
pool_size: int = 10,
debug: bool = False,
debug_speed: int = 2,
gen_limit_per_topic: int = 200,
n_prompts: int = 100,
):
with open(tree_path, "r") as openfile:
@@ -113,18 +101,25 @@ def get_generator():
i,
combination_options=leaves,
professions=professions,
limit=gen_limit_per_topic,
)
for i in leaves
]

prompts_flat = list(itertools.chain(*prompts))
if n_prompts > len(prompts_flat):
raise ValueError(
f"Canot generate({n_prompts}) prompts because it is larger than the number of"
f"Cannot generate({n_prompts}) prompts because it is larger than the number of"
f" available prompts ({len(prompts_flat)})"
)
prompts_selection = [i.query for i in prompts_flat][:n_prompts]
prompts_selection = [i.query for i in prompts_flat]

print(f"prompts: {len(prompts_selection)}")

solo_prompts = list(set(prompts_selection))

print(f"solo prompts: {len(solo_prompts)}")
prompts_selection = solo_prompts[:n_prompts]
print(f"total prompts: {len(prompts_selection)}")

mass_generation(
prompts_selection,
Expand Down

0 comments on commit 0de1684

Please sign in to comment.