From a291024a78f71cd33198524a9795ff5811656ce1 Mon Sep 17 00:00:00 2001 From: Bas Zalmstra Date: Tue, 20 Aug 2024 15:09:52 +0200 Subject: [PATCH] feat: add better timing notebook and measuring (#64) --- .gitignore | 1 + Cargo.lock | 10 ++ src/requirement.rs | 3 +- tools/solve-snapshot/.gitignore | 3 + tools/solve-snapshot/Cargo.toml | 3 + tools/solve-snapshot/src/main.rs | 142 ++++++++++++++++--- tools/solve-snapshot/timing_comparison.ipynb | 113 +++++++++++++++ 7 files changed, 255 insertions(+), 20 deletions(-) create mode 100644 tools/solve-snapshot/.gitignore create mode 100644 tools/solve-snapshot/timing_comparison.ipynb diff --git a/.gitignore b/.gitignore index 45f2537..8586614 100644 --- a/.gitignore +++ b/.gitignore @@ -25,3 +25,4 @@ output/ # Any resolvo snapshots in the root snapshot-*.json +snapshot_*.json diff --git a/Cargo.lock b/Cargo.lock index ad81fe7..848254c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -451,6 +451,7 @@ dependencies = [ "encode_unicode", "lazy_static", "libc", + "unicode-width", "windows-sys 0.52.0", ] @@ -1310,7 +1311,10 @@ name = "solve-snapshot" version = "0.1.0" dependencies = [ "clap 4.5.7", + "console", "csv", + "itertools", + "rand", "resolvo", "serde", "serde_json", @@ -1512,6 +1516,12 @@ version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" +[[package]] +name = "unicode-width" +version = "0.1.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0336d538f7abc86d282a4189614dfaa90810dfc2c6f6427eaf88e16311dd225d" + [[package]] name = "utf8parse" version = "0.2.2" diff --git a/src/requirement.rs b/src/requirement.rs index 10b9e46..244ec48 100644 --- a/src/requirement.rs +++ b/src/requirement.rs @@ -34,7 +34,8 @@ impl From for Requirement { } impl Requirement { - pub(crate) fn display<'i>(&'i self, interner: &'i impl Interner) -> impl Display + '_ { + /// Returns an object that implements `Display` for the requirement. + pub fn display<'i>(&'i self, interner: &'i impl Interner) -> impl Display + '_ { DisplayRequirement { interner, requirement: self, diff --git a/tools/solve-snapshot/.gitignore b/tools/solve-snapshot/.gitignore new file mode 100644 index 0000000..4b5662d --- /dev/null +++ b/tools/solve-snapshot/.gitignore @@ -0,0 +1,3 @@ +*.csv +.ipynb_checkpoints +*/.ipynb_checkpoints/* diff --git a/tools/solve-snapshot/Cargo.toml b/tools/solve-snapshot/Cargo.toml index 6cefb05..1dab50d 100644 --- a/tools/solve-snapshot/Cargo.toml +++ b/tools/solve-snapshot/Cargo.toml @@ -10,3 +10,6 @@ clap = { version = "4.5", features = ["derive"] } csv = "1.3" serde_json = "1.0" serde = { version = "1.0.196", features = ["derive"] } +rand = "0.8.5" +itertools = "0.13.0" +console = "0.15.8" diff --git a/tools/solve-snapshot/src/main.rs b/tools/solve-snapshot/src/main.rs index a58e1e4..67692f6 100644 --- a/tools/solve-snapshot/src/main.rs +++ b/tools/solve-snapshot/src/main.rs @@ -6,13 +6,34 @@ use std::{ }; use clap::Parser; +use console::style; use csv::WriterBuilder; -use resolvo::{snapshot::DependencySnapshot, Solver, UnsolvableOrCancelled}; +use itertools::Itertools; +use rand::{ + distributions::{Distribution, WeightedIndex}, + prelude::IteratorRandom, + rngs::StdRng, + Rng, SeedableRng, +}; +use resolvo::{snapshot::DependencySnapshot, Problem, Requirement, Solver, UnsolvableOrCancelled}; #[derive(Parser)] #[clap(version = "0.1.0", author = "Bas Zalmstra ")] struct Opts { snapshot: String, + + /// The maximum number of requirements to solve + #[clap(long, short = 'n', default_value = "1000")] + limit: usize, + + /// The timeout to use for solving requirements in seconds. If a solve takes + /// longer if will be cancelled. + #[clap(long, default_value = "60")] + timeout: u64, + + /// The random seed to use for generating the requirements. + #[clap(long, default_value = "0")] + seed: u64, } #[derive(Debug, serde::Serialize)] @@ -27,7 +48,7 @@ fn main() { let opts: Opts = Opts::parse(); eprintln!("Loading snapshot ..."); - let snapshot_file = BufReader::new(File::open(&opts.snapshot).unwrap()); + let snapshot_file = BufReader::new(File::open(opts.snapshot).unwrap()); let snapshot: DependencySnapshot = serde_json::from_reader(snapshot_file).unwrap(); let mut writer = WriterBuilder::new() @@ -35,47 +56,130 @@ fn main() { .from_path("timings.csv") .unwrap(); - for (i, (package_name_id, package)) in snapshot.packages.iter().enumerate() { - eprint!( - "solving {} ({i}/{}) ... ", - &package.name, - snapshot.packages.len() + // Generate a range of problems. + let mut rng = StdRng::seed_from_u64(0); + let requirement_dist = WeightedIndex::new(&[ + 10, // 10 times more likely to pick a package + if snapshot.version_sets.len() > 0 { + 1 + } else { + 0 + }, + if snapshot.version_set_unions.len() > 0 { + 1 + } else { + 0 + }, + ]) + .unwrap(); + for i in 0..opts.limit { + // Construct a fresh provider from the snapshot + let mut provider = snapshot + .provider() + .with_timeout(SystemTime::now().add(Duration::from_secs(opts.timeout))); + + // Construct a problem with a random number of requirements. + let mut requirements: Vec = Vec::new(); + + // Determine the number of requirements to solve for. + let num_requirements = rng.gen_range(1..=10usize); + for _ in 0..num_requirements { + match requirement_dist.sample(&mut rng) { + 0 => { + // Add a package requirement + let (package, _) = snapshot.packages.iter().choose(&mut rng).unwrap(); + let package_requirement = provider.add_package_requirement(package); + requirements.push(package_requirement.into()); + } + 1 => { + // Add a version set requirement + let (version_set_id, _) = + snapshot.version_sets.iter().choose(&mut rng).unwrap(); + requirements.push(version_set_id.into()); + } + 2 => { + // Add a version set union requirement + let (version_set_union_id, _) = + snapshot.version_set_unions.iter().choose(&mut rng).unwrap(); + requirements.push(version_set_union_id.into()); + } + _ => unreachable!(), + } + } + + eprintln!( + "solving ({}/{})...\n{}", + i + 1, + opts.limit, + requirements.iter().format_with("\n", |requirement, f| { + f(&format_args!( + "- {}", + style(requirement.display(&provider)).dim() + )) + }) ); + + let problem_name = requirements + .iter() + .format_with("\n", |requirement, f| { + f(&format_args!("{}", requirement.display(&provider))) + }) + .to_string(); + let start = Instant::now(); - let mut provider = snapshot - .provider() - .with_timeout(SystemTime::now().add(Duration::from_secs(60))); - let package_requirement = provider.add_package_requirement(package_name_id); + let problem = Problem::default().requirements(requirements); let mut solver = Solver::new(provider); let mut records = None; let mut error = None; - match solver.solve(vec![package_requirement.into()], vec![]) { + let result = solver.solve(problem); + let duration = start.elapsed(); + match result { Ok(solution) => { - eprintln!("OK"); + eprintln!( + "{}", + style(format!( + "==> OK in {:.2}ms, {} records", + duration.as_secs_f64() * 1000.0, + solution.len(), + )) + .green() + ); records = Some(solution.len()) } Err(UnsolvableOrCancelled::Unsolvable(problem)) => { - eprintln!("FAIL"); + eprintln!( + "{}", + style(format!( + "==> FAIL in {:.2}ms", + duration.as_secs_f64() * 1000.0 + )) + .yellow() + ); error = Some(problem.display_user_friendly(&solver).to_string()); } Err(_) => { - eprintln!("CANCELLED"); + eprintln!( + "{}", + style(format!( + "==> CANCELLED after {:.2}ms", + duration.as_secs_f64() * 1000.0 + )) + .red() + ); } } - let duration = start.elapsed(); - writer .serialize(Record { - package: package.name.clone(), + package: problem_name, duration: duration.as_secs_f64(), error, records, }) .unwrap(); - if i % 100 == 0 { + if i % 10 == 0 { writer.flush().unwrap(); } } diff --git a/tools/solve-snapshot/timing_comparison.ipynb b/tools/solve-snapshot/timing_comparison.ipynb new file mode 100644 index 0000000..5bab85c --- /dev/null +++ b/tools/solve-snapshot/timing_comparison.ipynb @@ -0,0 +1,113 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "e2ecd20f-bd4f-4a4a-82d0-28d2d9db18ac", + "metadata": {}, + "outputs": [], + "source": [ + "import polars as pl\n", + "import matplotlib.pyplot as plt\n", + "\n", + "plt.rcParams['figure.dpi'] = 150\n", + "plt.rcParams['savefig.dpi'] = 150\n", + "\n", + "base_path = \"base_timings.csv\"\n", + "path = \"timings.csv\"\n", + "\n", + "# These are all the timings we want to see\n", + "paths = [base_path, path]\n", + "\n", + "# Read the CSV\n", + "dfs = [\n", + " pl.scan_csv(path).select(pl.col(\"package\"), pl.col(\"duration\")).collect()\n", + " for path in paths\n", + "]\n", + "\n", + "for path, df in zip(paths,dfs):\n", + " count = df.select(pl.len()).item()\n", + " print(f\"{path}: {count} records\")\n", + "\n", + "# Define the histogram bins\n", + "threshold = 50\n", + "bins = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 14, 16, 18, 20, 25, 30, 35, 40, threshold, threshold + 1]\n", + "\n", + "dfs_capped = [\n", + " df.select([\n", + " pl.col(\"duration\").map_elements(lambda x: min(x, threshold), return_dtype=pl.Float64)\n", + " ]) for df in dfs]\n", + "\n", + "# Create the histogram\n", + "fig, axs = plt.subplots(2, sharex=True)\n", + "\n", + "for path, df_capped, axs in zip(paths, dfs_capped, axs):\n", + " values, bins, bars = axs.hist(df_capped[\"duration\"], bins=bins, density=True)\n", + " axs.set_title(path)\n", + " axs.bar_label(bars, fontsize=8, color='black', labels = [f'{x.get_height():.1%}' for x in bars])\n", + " axs.tick_params(axis='y', which='both', left=False, top=False, labelleft=False)\n", + "\n", + "# Add labels to the ticks\n", + "fig.supxlabel(\"Solve duration in seconds\")\n", + "fig.supylabel(\"Percentage of solves\")\n", + "fig.suptitle(\"Histogram of solve durations\")\n", + "\n", + "\n", + "plt.show()\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "88f7081d-d2bf-4fb7-86ab-7fd22dce0202", + "metadata": {}, + "outputs": [], + "source": [ + "# Load the timings\n", + "dfs = [\n", + " pl.scan_csv(path).select(pl.col(\"package\"), pl.col(\"duration\")) \n", + " for path in [base_path, path]\n", + "]\n", + "\n", + "# Compute the solver diffs. Negative values means the second timings are faster\n", + "df_diff = dfs[1].join(dfs[0], on=\"package\").select(pl.col(\"package\"), (pl.col(\"duration\")-pl.col(\"duration_right\"))).collect();\n", + "\n", + "# Create the histogram\n", + "plt.hist(df_diff[\"duration\"], bins=40, density=True)\n", + "plt.xlabel(\"Difference in solve duration in seconds\")\n", + "plt.ylabel(\"Difference probability\")\n", + "\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b06fdad5-e960-4461-8166-e071eb75c1f2", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.5" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}