-
Notifications
You must be signed in to change notification settings - Fork 1
/
generate_latex_table.py
204 lines (171 loc) · 8.38 KB
/
generate_latex_table.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
import json
import os
import sys
from collections import defaultdict
# Provider name -> country/region code. generate_latex_table passes the code
# to \worldflag; providers missing from this map fall back to 'UN' there.
llm_country_map = dict(
    Google='US',
    OpenAI='US',
    Meta='US',
    DeepSeek='CN',
    Mistral='EU',
    Alibaba='CN',
    Anthropic='US',
    Nvidia='US',
)
def load_data(results_dir, model_list=None):
    """Load per-model evaluation results from ``results_dir``.

    For each model this reads ``<results_dir>/<model>/total.json`` and, for each
    benchmark, ``<model>/<benchmark>/statistics_<benchmark>_instruct_<strategy>.json``
    and ``<model>/<benchmark>/costs_<benchmark>_instruct_<strategy>.json``.
    A missing file prints a warning to stderr (never stdout, which carries the
    generated LaTeX) and leaves the corresponding values as ``None``.

    Args:
        results_dir: Directory containing one sub-directory per model.
        model_list: Optional iterable of ``(model_name, strategy, provider)``
            tuples. Defaults to the built-in list of evaluated models.

    Returns:
        A list with one dict per model, holding ``name``, ``provider``,
        ``total_cost``, ``total_ast_match@1``, ``total_plausible@1`` and, for
        each benchmark, ``<benchmark>_ast_match@1``, ``<benchmark>_plausible@1``
        and ``<benchmark>_cost``.
    """
    if model_list is None:
        model_list = [
            ('gemini-1.5-pro-001', "google", "Google"),
            ('gemini-1.5-pro-002', "google", "Google"),
            ('gpt-4o-2024-08-06', "openai-chatcompletion", "OpenAI"),
            ('o1-preview-2024-09-12', "openai-chatcompletion", "OpenAI"),
            ('llama-3.1-405b-instruct', 'openrouter', 'Meta'),
            ('deepseek-v2.5', 'openrouter', 'DeepSeek'),
            ('mistral-large-2407', 'openrouter', 'Mistral'),
            ('qwen-2.5-72b-instruct', 'openrouter', 'Alibaba'),
            ('claude-3-5-sonnet-20240620', 'anthropic', 'Anthropic'),
            ('claude-3-5-sonnet-20241022', 'anthropic', 'Anthropic'),
            ('llama-3.1-nemotron-70b-instruct', 'openrouter', 'Nvidia'),
            ('qwen-2.5-coder-32b-instruct', 'openrouter', 'Alibaba'),
        ]
    benchmarks = ['defects4j', 'gitbugjava']
    metrics = ['ast_match@1', 'plausible@1']

    data = []
    for llm_name, strategy, provider in model_list:
        model_dir = os.path.join(results_dir, llm_name)
        row = {'name': llm_name, 'provider': provider, 'total_cost': None}

        try:
            with open(os.path.join(model_dir, 'total.json'), encoding='utf-8') as f:
                total_result = json.load(f)
        except FileNotFoundError:
            print(f"Warning: total.json not found for {llm_name}", file=sys.stderr)
            row['total_ast_match@1'] = None
            row['total_plausible@1'] = None
        else:
            row['total_ast_match@1'] = total_result.get('ast_match@1', None)
            row['total_plausible@1'] = total_result.get('plausible@1', None)
            row['total_cost'] = total_result.get('cost', None)

        all_benchmarks_complete = True
        for benchmark in benchmarks:
            bench_dir = os.path.join(model_dir, benchmark)
            suffix = f"{benchmark}_instruct_{strategy}.json"
            try:
                with open(os.path.join(bench_dir, f"statistics_{suffix}"), encoding='utf-8') as f:
                    stats_result = json.load(f)
                with open(os.path.join(bench_dir, f"costs_{suffix}"), encoding='utf-8') as f:
                    cost_result = json.load(f)
            except FileNotFoundError:
                # Either file missing marks the whole benchmark as unavailable.
                print(f"Warning: Data not found for {llm_name} - {benchmark}", file=sys.stderr)
                for metric in metrics:
                    row[f"{benchmark}_{metric}"] = None
                row[f"{benchmark}_cost"] = None
                all_benchmarks_complete = False
            else:
                for metric in metrics:
                    row[f"{benchmark}_{metric}"] = stats_result.get(metric, None)
                row[f"{benchmark}_cost"] = cost_result.get('total_cost', None)

        # Derive the total cost from the per-benchmark costs only when every
        # benchmark loaded and total.json did not supply one (missing file or
        # no 'cost' key).
        if all_benchmarks_complete and row['total_cost'] is None:
            row['total_cost'] = sum(
                row[f"{benchmark}_cost"]
                for benchmark in benchmarks
                if row[f"{benchmark}_cost"] is not None
            )
        data.append(row)
    return data
def find_best_scores(data):
    """Return the maximum value seen for every score column across ``data``.

    A column counts as a score column when its key ends with ``ast_match@1``
    or ``plausible@1``. ``None`` values are ignored. The result is a
    ``defaultdict`` whose default is negative infinity, so looking up a
    column that never had a value yields ``-inf``.
    """
    score_suffixes = ('ast_match@1', 'plausible@1')
    best = defaultdict(lambda: float('-inf'))
    for entry in data:
        for column, score in entry.items():
            if not column.endswith(score_suffixes) or score is None:
                continue
            if score > best[column]:
                best[column] = score
    return best
def load_citations(citations_file=None):
    """Load the model-name -> citation-key mapping from a JSON file.

    Args:
        citations_file: Path to the JSON file. Defaults to ``citations.json``
            next to this script.

    Returns:
        The parsed JSON object.
    """
    if citations_file is None:
        citations_file = os.path.join(os.path.dirname(__file__), 'citations.json')
    # JSON is UTF-8; don't depend on the platform's locale encoding.
    with open(citations_file, 'r', encoding='utf-8') as f:
        return json.load(f)
def _latex_score_cell(value, best_scores, key):
    r"""Format one score as a percentage; prefix ``\B`` when it equals the column best."""
    if value is None:
        return "\\multicolumn{1}{c}{---}"
    text = f"{value*100:.1f}\\%"
    # Only consult best_scores for real values (it is a defaultdict).
    if value == best_scores[key]:
        return f"\\B {text}"
    return text


def _latex_cost_cell(cost):
    """Format one cost cell in dollars, or a centred dash when missing."""
    if cost is None:
        return "\\multicolumn{1}{c}{---}"
    return f"\\${cost:.2f}"


def _latex_table_row(row, best_scores, citations):
    """Render one leaderboard row; return ``(latex_line, is_partial)``."""
    # A row is partial when any per-benchmark score is missing.
    is_partial = any(row[f"{benchmark}_{metric}"] is None
                     for benchmark in ['defects4j', 'gitbugjava']
                     for metric in ['ast_match@1', 'plausible@1'])
    suffix = "\\textsuperscript{2}" if is_partial else ""

    model_name = row['name']
    citation_key = citations.get(model_name, '')
    country_code = llm_country_map.get(row['provider'], 'UN')  # 'UN' for unknown
    # \worldflag is provided by the worldflags LaTeX package.
    provider_with_flag = f"\\worldflag[width=0.3cm]{{{country_code}}} {row['provider']}"

    cells = [f"{provider_with_flag}{suffix}", f"{model_name}{suffix}"]
    for benchmark in ['defects4j', 'gitbugjava', 'total']:
        for metric in ['plausible@1', 'ast_match@1']:
            key = f"{benchmark}_{metric}"
            cells.append(_latex_score_cell(row[key], best_scores, key))
        cells.append(_latex_cost_cell(row[f'{benchmark}_cost']))
    cells.append(f"\\citep{{{citation_key}}}" if citation_key else "---")
    return " & ".join(cells) + " \\\\\n", is_partial


def _latex_table_header(column_specs):
    """Return the table preamble, column spec and the two header rows."""
    return "".join([
        "\\begin{table}[ht]\n",
        "\\centering\n",
        "\\makebox[\\textwidth][c]{%\n",
        "\\resizebox{1.3\\textwidth}{!}{\n",
        "\\large\n",
        "\\begin{tabular}{@{}ll " + ' '.join(column_specs) + "@{}}\n",
        "\\toprule\n",
        "\\multirow{2}{*}{\\textbf{Organization}} & \\multirow{2}{*}{\\textbf{Model}} & \\multicolumn{3}{c}{Defects4J v2 (484 bugs)} & \\multicolumn{3}{c}{GitBug-Java (90 bugs)} & \\multicolumn{3}{c}{\\textbf{Total (574 bugs)}} & \\multirow{2}{*}{Ref.} \\\\\n",
        "\\cmidrule(lr){3-5} \\cmidrule(lr){6-8} \\cmidrule(l){9-11}\n",
        " & & {Plausible@1} & {AST Match@1} & {Cost (\\$)} & {Plausible@1} & {AST Match@1} & {Cost (\\$)} & {\\textbf{Plausible@1}\\textsuperscript{1}} & {\\textbf{AST Match@1}} & {\\textbf{Cost (\\$)}} & \\\\\n",
        "\\midrule\n",
    ])


def _latex_table_footer(n_columns, partial_footnote_needed):
    """Return the bottom rule, footnotes (spanning ``n_columns``) and closing."""
    parts = [
        "\\bottomrule\n",
        "\\addlinespace\n",
        f"\\multicolumn{{{n_columns}}}{{l}}{{\\textsuperscript{{1}}Models are sorted by the total Plausible@1 score.}} \\\\\n",
    ]
    if partial_footnote_needed:
        parts.append(f"\\multicolumn{{{n_columns}}}{{l}}{{\\textsuperscript{{2}}Only partial results available right now due to cost reasons.}} \\\\\n")
    parts += [
        "\\addlinespace\n",
        "\\end{tabular}\n",
        "}\n",
        "}\n",
        "\\caption{Leaderboard of Frontier Models for Program Repair as of \\today}\n",
        "\\label{tab:leaderboard}\n",
        "\\end{table}\n",
    ]
    return "".join(parts)


def generate_latex_table(data):
    r"""Render ``data`` (rows as produced by ``load_data``) as a LaTeX table.

    The output uses booktabs rules (``\toprule``/``\cmidrule``), ``\multirow``,
    siunitx ``S`` columns, ``\worldflag``, ``\citep`` and a document-defined
    ``\B`` macro for best scores; the including document must provide them.
    Rows with total results come first, ordered by total Plausible@1
    (descending). Rows missing any per-benchmark score get a footnote marker.

    Returns:
        The complete ``table`` environment as a string.
    """
    best_scores = find_best_scores(data)
    citations = load_citations()

    # Columns after Organization/Model: 3 per group (Plausible@1, AST Match@1,
    # Cost) for Defects4J, GitBug-Java and Total, then the Ref. column.
    column_specs = [
        "S[table-format=2.1, detect-weight=true]",
        "S[table-format=2.1, detect-weight=true]",
        "S[table-format=4.2, detect-weight=true]",
        "S[table-format=2.1, detect-weight=true]",
        "S[table-format=2.1, detect-weight=true]",
        "S[table-format=4.2, detect-weight=true]",
        "S[table-format=2.1, detect-weight=true]",
        "S[table-format=2.1, detect-weight=true]",
        "S[table-format=4.2, detect-weight=true]",
        "c",
    ]
    # Footnote rows must span every column; derive it instead of hard-coding.
    n_columns = 2 + len(column_specs)

    sorted_data = sorted(data,
                         key=lambda x: (x['total_plausible@1'] is not None,
                                        x['total_plausible@1'] if x['total_plausible@1'] is not None else -1),
                         reverse=True)

    body_lines = []
    partial_footnote_needed = False
    for row in sorted_data:
        line, is_partial = _latex_table_row(row, best_scores, citations)
        body_lines.append(line)
        partial_footnote_needed = partial_footnote_needed or is_partial

    return (_latex_table_header(column_specs)
            + "".join(body_lines)
            + _latex_table_footer(n_columns, partial_footnote_needed))
def main():
    """Build the leaderboard from the ``results`` directory beside this script's parent and print it."""
    script_dir = os.path.dirname(__file__)
    rows = load_data(os.path.join(script_dir, '..', 'results'))
    print(generate_latex_table(rows))


if __name__ == "__main__":
    # Run only when executed as a script, not on import.
    main()