-
Notifications
You must be signed in to change notification settings - Fork 13
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add better timing notebook and measuring (#64)
- Loading branch information
1 parent
2700ea0
commit a291024
Showing
7 changed files
with
255 additions
and
20 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -25,3 +25,4 @@ output/ | |
|
||
# Any resolvo snapshots in the root | ||
snapshot-*.json | ||
snapshot_*.json |
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
*.csv | ||
.ipynb_checkpoints | ||
*/.ipynb_checkpoints/* |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,13 +6,34 @@ use std::{ | |
}; | ||
|
||
use clap::Parser; | ||
use console::style; | ||
use csv::WriterBuilder; | ||
use resolvo::{snapshot::DependencySnapshot, Solver, UnsolvableOrCancelled}; | ||
use itertools::Itertools; | ||
use rand::{ | ||
distributions::{Distribution, WeightedIndex}, | ||
prelude::IteratorRandom, | ||
rngs::StdRng, | ||
Rng, SeedableRng, | ||
}; | ||
use resolvo::{snapshot::DependencySnapshot, Problem, Requirement, Solver, UnsolvableOrCancelled}; | ||
|
||
#[derive(Parser)] | ||
#[clap(version = "0.1.0", author = "Bas Zalmstra <[email protected]>")] | ||
struct Opts { | ||
snapshot: String, | ||
|
||
/// The maximum number of requirements to solve | ||
#[clap(long, short = 'n', default_value = "1000")] | ||
limit: usize, | ||
|
||
/// The timeout to use for solving requirements in seconds. If a solve takes | ||
/// longer if will be cancelled. | ||
#[clap(long, default_value = "60")] | ||
timeout: u64, | ||
|
||
/// The random seed to use for generating the requirements. | ||
#[clap(long, default_value = "0")] | ||
seed: u64, | ||
} | ||
|
||
#[derive(Debug, serde::Serialize)] | ||
|
@@ -27,55 +48,138 @@ fn main() { | |
let opts: Opts = Opts::parse(); | ||
|
||
eprintln!("Loading snapshot ..."); | ||
let snapshot_file = BufReader::new(File::open(&opts.snapshot).unwrap()); | ||
let snapshot_file = BufReader::new(File::open(opts.snapshot).unwrap()); | ||
let snapshot: DependencySnapshot = serde_json::from_reader(snapshot_file).unwrap(); | ||
|
||
let mut writer = WriterBuilder::new() | ||
.has_headers(true) | ||
.from_path("timings.csv") | ||
.unwrap(); | ||
|
||
for (i, (package_name_id, package)) in snapshot.packages.iter().enumerate() { | ||
eprint!( | ||
"solving {} ({i}/{}) ... ", | ||
&package.name, | ||
snapshot.packages.len() | ||
// Generate a range of problems. | ||
let mut rng = StdRng::seed_from_u64(0); | ||
let requirement_dist = WeightedIndex::new(&[ | ||
10, // 10 times more likely to pick a package | ||
if snapshot.version_sets.len() > 0 { | ||
1 | ||
} else { | ||
0 | ||
}, | ||
if snapshot.version_set_unions.len() > 0 { | ||
1 | ||
} else { | ||
0 | ||
}, | ||
]) | ||
.unwrap(); | ||
for i in 0..opts.limit { | ||
// Construct a fresh provider from the snapshot | ||
let mut provider = snapshot | ||
.provider() | ||
.with_timeout(SystemTime::now().add(Duration::from_secs(opts.timeout))); | ||
|
||
// Construct a problem with a random number of requirements. | ||
let mut requirements: Vec<Requirement> = Vec::new(); | ||
|
||
// Determine the number of requirements to solve for. | ||
let num_requirements = rng.gen_range(1..=10usize); | ||
for _ in 0..num_requirements { | ||
match requirement_dist.sample(&mut rng) { | ||
0 => { | ||
// Add a package requirement | ||
let (package, _) = snapshot.packages.iter().choose(&mut rng).unwrap(); | ||
let package_requirement = provider.add_package_requirement(package); | ||
requirements.push(package_requirement.into()); | ||
} | ||
1 => { | ||
// Add a version set requirement | ||
let (version_set_id, _) = | ||
snapshot.version_sets.iter().choose(&mut rng).unwrap(); | ||
requirements.push(version_set_id.into()); | ||
} | ||
2 => { | ||
// Add a version set union requirement | ||
let (version_set_union_id, _) = | ||
snapshot.version_set_unions.iter().choose(&mut rng).unwrap(); | ||
requirements.push(version_set_union_id.into()); | ||
} | ||
_ => unreachable!(), | ||
} | ||
} | ||
|
||
eprintln!( | ||
"solving ({}/{})...\n{}", | ||
i + 1, | ||
opts.limit, | ||
requirements.iter().format_with("\n", |requirement, f| { | ||
f(&format_args!( | ||
"- {}", | ||
style(requirement.display(&provider)).dim() | ||
)) | ||
}) | ||
); | ||
|
||
let problem_name = requirements | ||
.iter() | ||
.format_with("\n", |requirement, f| { | ||
f(&format_args!("{}", requirement.display(&provider))) | ||
}) | ||
.to_string(); | ||
|
||
let start = Instant::now(); | ||
|
||
let mut provider = snapshot | ||
.provider() | ||
.with_timeout(SystemTime::now().add(Duration::from_secs(60))); | ||
let package_requirement = provider.add_package_requirement(package_name_id); | ||
let problem = Problem::default().requirements(requirements); | ||
let mut solver = Solver::new(provider); | ||
let mut records = None; | ||
let mut error = None; | ||
match solver.solve(vec![package_requirement.into()], vec![]) { | ||
let result = solver.solve(problem); | ||
let duration = start.elapsed(); | ||
match result { | ||
Ok(solution) => { | ||
eprintln!("OK"); | ||
eprintln!( | ||
"{}", | ||
style(format!( | ||
"==> OK in {:.2}ms, {} records", | ||
duration.as_secs_f64() * 1000.0, | ||
solution.len(), | ||
)) | ||
.green() | ||
); | ||
records = Some(solution.len()) | ||
} | ||
Err(UnsolvableOrCancelled::Unsolvable(problem)) => { | ||
eprintln!("FAIL"); | ||
eprintln!( | ||
"{}", | ||
style(format!( | ||
"==> FAIL in {:.2}ms", | ||
duration.as_secs_f64() * 1000.0 | ||
)) | ||
.yellow() | ||
); | ||
error = Some(problem.display_user_friendly(&solver).to_string()); | ||
} | ||
Err(_) => { | ||
eprintln!("CANCELLED"); | ||
eprintln!( | ||
"{}", | ||
style(format!( | ||
"==> CANCELLED after {:.2}ms", | ||
duration.as_secs_f64() * 1000.0 | ||
)) | ||
.red() | ||
); | ||
} | ||
} | ||
|
||
let duration = start.elapsed(); | ||
|
||
writer | ||
.serialize(Record { | ||
package: package.name.clone(), | ||
package: problem_name, | ||
duration: duration.as_secs_f64(), | ||
error, | ||
records, | ||
}) | ||
.unwrap(); | ||
|
||
if i % 100 == 0 { | ||
if i % 10 == 0 { | ||
writer.flush().unwrap(); | ||
} | ||
} | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "e2ecd20f-bd4f-4a4a-82d0-28d2d9db18ac", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import polars as pl\n", | ||
"import matplotlib.pyplot as plt\n", | ||
"\n", | ||
"plt.rcParams['figure.dpi'] = 150\n", | ||
"plt.rcParams['savefig.dpi'] = 150\n", | ||
"\n", | ||
"base_path = \"base_timings.csv\"\n", | ||
"path = \"timings.csv\"\n", | ||
"\n", | ||
"# These are all the timings we want to see\n", | ||
"paths = [base_path, path]\n", | ||
"\n", | ||
"# Read the CSV\n", | ||
"dfs = [\n", | ||
" pl.scan_csv(path).select(pl.col(\"package\"), pl.col(\"duration\")).collect()\n", | ||
" for path in paths\n", | ||
"]\n", | ||
"\n", | ||
"for path, df in zip(paths,dfs):\n", | ||
" count = df.select(pl.len()).item()\n", | ||
" print(f\"{path}: {count} records\")\n", | ||
"\n", | ||
"# Define the histogram bins\n", | ||
"threshold = 50\n", | ||
"bins = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 14, 16, 18, 20, 25, 30, 35, 40, threshold, threshold + 1]\n", | ||
"\n", | ||
"dfs_capped = [\n", | ||
" df.select([\n", | ||
" pl.col(\"duration\").map_elements(lambda x: min(x, threshold), return_dtype=pl.Float64)\n", | ||
" ]) for df in dfs]\n", | ||
"\n", | ||
"# Create the histogram\n", | ||
"fig, axs = plt.subplots(2, sharex=True)\n", | ||
"\n", | ||
"for path, df_capped, axs in zip(paths, dfs_capped, axs):\n", | ||
" values, bins, bars = axs.hist(df_capped[\"duration\"], bins=bins, density=True)\n", | ||
" axs.set_title(path)\n", | ||
" axs.bar_label(bars, fontsize=8, color='black', labels = [f'{x.get_height():.1%}' for x in bars])\n", | ||
" axs.tick_params(axis='y', which='both', left=False, top=False, labelleft=False)\n", | ||
"\n", | ||
"# Add labels to the ticks\n", | ||
"fig.supxlabel(\"Solve duration in seconds\")\n", | ||
"fig.supylabel(\"Percentage of solves\")\n", | ||
"fig.suptitle(\"Histogram of solve durations\")\n", | ||
"\n", | ||
"\n", | ||
"plt.show()\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "88f7081d-d2bf-4fb7-86ab-7fd22dce0202", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Load the timings\n", | ||
"dfs = [\n", | ||
" pl.scan_csv(path).select(pl.col(\"package\"), pl.col(\"duration\")) \n", | ||
" for path in [base_path, path]\n", | ||
"]\n", | ||
"\n", | ||
"# Compute the solver diffs. Negative values means the second timings are faster\n", | ||
"df_diff = dfs[1].join(dfs[0], on=\"package\").select(pl.col(\"package\"), (pl.col(\"duration\")-pl.col(\"duration_right\"))).collect();\n", | ||
"\n", | ||
"# Create the histogram\n", | ||
"plt.hist(df_diff[\"duration\"], bins=40, density=True)\n", | ||
"plt.xlabel(\"Difference in solve duration in seconds\")\n", | ||
"plt.ylabel(\"Difference probability\")\n", | ||
"\n", | ||
"plt.show()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "b06fdad5-e960-4461-8166-e071eb75c1f2", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3 (ipykernel)", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.12.5" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |