Skip to content

Commit

Permalink
improve filtering of datasets in benchmarks
Browse files Browse the repository at this point in the history
  • Loading branch information
paulbkoch committed Sep 27, 2024
1 parent 6337f60 commit 2b5c09f
Showing 1 changed file with 29 additions and 17 deletions.
46 changes: 29 additions & 17 deletions docs/benchmarks/ebm-benchmark.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -913,7 +913,8 @@
"\n",
"# Optionally filter out results we want to replace\n",
"#results_df = results_df[results_df['method'] != 'ebm']\n",
"#results_df = results_df[~((results_df['method'] == 'ebm') & (results_df['meta'] == '{}'))]\n",
"#results_df = results_df[(results_df['method'] != 'ebm') | (results_df['meta'] != '{}')]\n",
"#results_df = results_df[(results_df['method'] != 'ebm') | (results_df['meta'] != '{\"interactions\": 0}')]\n",
"print(f'Results (post-filtered) count: {results_df.shape[0]}')"
]
},
Expand All @@ -925,21 +926,26 @@
"outputs": [],
"source": [
"# Fill in results from previous runs if desired.\n",
"filler_df = pd.DataFrame(columns=results_df.columns)\n",
"#filler_df = pd.read_csv(\"prev.csv\")\n",
"\n",
"# Optionally filter out results from the filler\n",
"#filler_df = filler_df[filler_df['meta'] == '{\"interactions\": 0}']\n",
"\n",
"key_columns = ['task', 'method', 'meta', 'replicate_num', 'name', 'seq_num']\n",
"filler_df = filler_df[~filler_df.set_index(key_columns).index.isin(results_df.set_index(key_columns).index)]\n",
"if 0 < filler_df.shape[0]:\n",
" results_df = pd.concat([results_df, filler_df], ignore_index=True)\n",
" results_df = results_df.sort_values(by=[\"task\", \"method\", \"meta\", \"replicate_num\", \"name\", \"seq_num\"])\n",
" results_df.to_csv(\"merged.csv\", index=None)\n",
"print(f'Filter count: {filler_df.shape[0]}')\n",
"print(f'Results count: {results_df.shape[0]}')\n",
"#print(filler_df.to_string())"
"basefile = 'base.csv'\n",
"import os\n",
"if os.path.exists(basefile):\n",
" filler_df = pd.DataFrame(columns=results_df.columns)\n",
" filler_df = pd.read_csv(basefile)\n",
" \n",
" # Optionally filter out results from the filler\n",
" filler_df = filler_df[filler_df['method'] != 'ebm']\n",
" #filler_df = filler_df[(filler_df['method'] != 'ebm') | (filler_df['meta'] != '{}')]\n",
" #filler_df = filler_df[(filler_df['method'] != 'ebm') | (filler_df['meta'] != '{\"interactions\": 0}')]\n",
" \n",
" key_columns = ['task', 'method', 'meta', 'replicate_num', 'name', 'seq_num']\n",
" filler_df = filler_df[~filler_df.set_index(key_columns).index.isin(results_df.set_index(key_columns).index)]\n",
" if 0 < filler_df.shape[0]:\n",
" results_df = pd.concat([results_df, filler_df], ignore_index=True)\n",
" results_df = results_df.sort_values(by=[\"task\", \"method\", \"meta\", \"replicate_num\", \"name\", \"seq_num\"])\n",
" results_df.to_csv(\"merged.csv\", index=None)\n",
" print(f'Filter count: {filler_df.shape[0]}')\n",
" print(f'Results count: {results_df.shape[0]}')\n",
" #print(filler_df.to_string())"
]
},
{
Expand All @@ -960,7 +966,13 @@
"\n",
"# Optionally filter out any incomplete datasets\n",
"#results_df = results_df[results_df['task'] != 'Devnagari-Script']\n",
"#results_df = results_df[results_df['type'] == 'regression']\n",
"#results_df = results_df[results_df['task'] != 'CIFAR_10']\n",
"#results_df = results_df[results_df['task'] != 'isolet']\n",
"#results_df = results_df[results_df['task'] != 'mnist_784']\n",
"#results_df = results_df[results_df['task'] != 'Airlines_DepDelay_10M']\n",
"#results_df = results_df[results_df['type'] != 'binary']\n",
"#results_df = results_df[results_df['type'] != 'multiclass']\n",
"#results_df = results_df[results_df['type'] != 'regression']\n",
"print(f'Final count: {results_df.shape[0]}')"
]
},
Expand Down

0 comments on commit 2b5c09f

Please sign in to comment.