Skip to content

Commit

Permalink
fix: fix max limit of prompt
Browse files Browse the repository at this point in the history
  • Loading branch information
sami jaghouar committed Aug 28, 2023
1 parent d329cb9 commit 5f3a3a0
Showing 1 changed file with 13 additions and 8 deletions.
21 changes: 13 additions & 8 deletions textbook/dataset_gen/dataset_gen_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,15 @@ def create_prompts(
random.shuffle(combination_options)
prompts: List[Query] = []

for loc_topic in combination_options:
if len(prompts) == n:
break
for _ in range(len(professions)):
for loc_topic in combination_options:
if len(prompts) == n:
break

if loc_topic.mixing and loc_topic.parent != topic.parent:
profession = professions[np.random.randint(0, len(professions))]
query = create_prompt_query(topic, loc_topic, profession)
prompts.append(Query(query=query, topic_1=topic, topic_2=loc_topic))
if loc_topic.mixing and loc_topic.parent != topic.parent:
profession = professions[np.random.randint(0, len(professions))]
query = create_prompt_query(topic, loc_topic, profession)
prompts.append(Query(query=query, topic_1=topic, topic_2=loc_topic))

return prompts

Expand Down Expand Up @@ -97,7 +98,6 @@ def get_generator():
return MonkeyGenerator(speed=debug_speed)

leaves = load_leaves(leaves_path)

prompts: List[List[Query]] = [
create_prompts(
i,
Expand All @@ -109,6 +109,11 @@ def get_generator():
]

prompts_flat = list(itertools.chain(*prompts))
if n_prompts > len(prompts_flat):
raise ValueError(
            f"Cannot generate ({n_prompts}) prompts because it is larger than the number of"
f" available prompts ({len(prompts_flat)})"
)
prompts_selection = [i.query for i in prompts_flat][:n_prompts]

mass_generation(
Expand Down

0 comments on commit 5f3a3a0

Please sign in to comment.