-
Notifications
You must be signed in to change notification settings - Fork 1
/
main.py
28 lines (23 loc) · 1.69 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
from automated_llm_eval.prompts import *
from private_key import *
from automated_llm_eval.visualize import *
from automated_llm_eval.policy_tuning_general import *
from automated_llm_eval.policy_tuning_corrective import *
from results.html_writer import writer_html
import os
import sys
# OpenAI API token pulled from the `key` mapping.
# NOTE(review): `key` is brought in by `from private_key import *` above — confirm.
# NOTE(review): `openai_token` is not referenced in this file; presumably read by
# an imported module or kept for environment setup — verify before removing.
openai_token = key["open-ai"]
def run_experiment(task, experiment_name, reliability_type, compare_type=None):
    """Run one policy-tuning experiment and write all artifacts under results/<experiment_name>/.

    Produces a policy-mutation CSV via ``policy_tuning_corrective``, then renders
    an accuracy plot, a policy-length plot, an HTML view of the CSV, and an
    overlap visualization from that CSV.

    Args:
        task: Task identifier forwarded to ``policy_tuning_corrective``.
        experiment_name: Directory name (under ``results/``) for all outputs.
        reliability_type: Reliability mode forwarded to ``policy_tuning_corrective``.
        compare_type: Optional comparison mode; also embedded in output filenames.
    """
    # exist_ok=True makes directory creation idempotent and removes the
    # race-prone exists()-then-makedirs() check (and its dead `else: pass`).
    out_dir = os.path.join("results", experiment_name)
    os.makedirs(out_dir, exist_ok=True)

    # Single source of truth for the CSV path — it was repeated five times.
    csv_path = f"results/{experiment_name}/policy_mutation_{compare_type}.csv"

    policy_tuning_corrective(csv_path, task, batch_size=10,
                             compare_type=compare_type,
                             reliability_type=reliability_type)
    create_accuracy_plot(csv_path,
                         "Accuracy of Policy by Iteration: Negative COT",
                         f"results/{experiment_name}/acc_policy_neg_COT_{experiment_name}_{compare_type}.png")
    create_len_of_policy_plot(csv_path,
                              "Length of Policy by Iteration: Negative COT",
                              f"results/{experiment_name}/len_policy_neg_COT_{experiment_name}_{compare_type}.png")
    writer_html(csv_path, f"results/{experiment_name}/policy_mutation_{compare_type}.html")
    visualize_overlap(csv_path, f"results/{experiment_name}/overlap_{compare_type}.png")
def main():
    """CLI entry point.

    Positional argv: [1] task (required — missing arg raises IndexError),
    [2] compare_type (optional), [3] reliability_type (optional).
    """
    argv = sys.argv
    compare = argv[2] if len(argv) > 2 else None
    reliability = argv[3] if len(argv) > 3 else None
    run_experiment(
        task=argv[1],
        experiment_name="teacher_check_12_11_gpt4_rand_example_choice",
        reliability_type=reliability,
        compare_type=compare,
    )


if __name__ == '__main__':
    main()