Skip to content

Commit

Permalink
feat(solver)!: Make Problem use builder pattern (#62)
Browse files Browse the repository at this point in the history
  • Loading branch information
eviltak authored Aug 12, 2024
1 parent 0681a0f commit 2700ea0
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 92 deletions.
44 changes: 24 additions & 20 deletions cpp/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -490,26 +490,30 @@ pub extern "C" fn resolvo_solve(
) -> bool {
let mut solver = resolvo::Solver::new(provider);

let problem = resolvo::Problem {
requirements: problem
.requirements
.into_iter()
.copied()
.map(Into::into)
.collect(),
constraints: problem
.constraints
.into_iter()
.copied()
.map(Into::into)
.collect(),
soft_requirements: problem
.soft_requirements
.into_iter()
.copied()
.map(Into::into)
.collect(),
};
let problem = resolvo::Problem::new()
.requirements(
problem
.requirements
.into_iter()
.copied()
.map(Into::into)
.collect(),
)
.constraints(
problem
.constraints
.into_iter()
.copied()
.map(Into::into)
.collect(),
)
.soft_requirements(
problem
.soft_requirements
.into_iter()
.copied()
.map(Into::into),
);

match solver.solve(problem) {
Ok(solution) => {
Expand Down
77 changes: 68 additions & 9 deletions src/solver/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,17 +35,60 @@ struct AddClauseOutput {
}

/// Describes the problem that is to be solved by the solver.
#[derive(Default)]
pub struct Problem {
/// The requirements that _must_ have one candidate solvable be included in the
///
/// This struct is generic over the type `S` of the collection of soft requirements passed
/// to the solver, typically expected to be a type implementing [`IntoIterator`].
///
/// This struct follows the builder pattern and can have its fields set by one of the available
/// setter methods.
pub struct Problem<S> {
requirements: Vec<Requirement>,
constraints: Vec<VersionSetId>,
soft_requirements: S,
}

impl Default for Problem<std::iter::Empty<SolvableId>> {
fn default() -> Self {
Self::new()
}
}

impl Problem<std::iter::Empty<SolvableId>> {
/// Creates a new empty [`Problem`]. Use the setter methods to build the problem
/// before passing it to the solver to be solved.
pub fn new() -> Self {
Self {
requirements: Default::default(),
constraints: Default::default(),
soft_requirements: Default::default(),
}
}
}

impl<S: IntoIterator<Item = SolvableId>> Problem<S> {
/// Sets the requirements that _must_ have one candidate solvable be included in the
/// solution.
pub requirements: Vec<Requirement>,
///
/// Returns the [`Problem`] for further mutation or to pass to [`Solver::solve`].
pub fn requirements(self, requirements: Vec<Requirement>) -> Self {
Self {
requirements,
..self
}
}

/// Additional constraints imposed on individual packages that the solvable (if any)
/// Sets the additional constraints imposed on individual packages that the solvable (if any)
/// chosen for that package _must_ adhere to.
pub constraints: Vec<VersionSetId>,
///
/// Returns the [`Problem`] for further mutation or to pass to [`Solver::solve`].
pub fn constraints(self, constraints: Vec<VersionSetId>) -> Self {
Self {
constraints,
..self
}
}

/// A set of additional requirements that the solver should _try_ and fulfill once it has
/// Sets the additional requirements that the solver should _try_ and fulfill once it has
/// found a solution to the main problem.
///
/// An unsatisfiable soft requirement does not cause a conflict; the solver will try
Expand All @@ -54,7 +97,20 @@ pub struct Problem {
/// Soft requirements are currently only specified as individual solvables to be
/// included in the solution, however in the future they will be able to be specified
/// as version sets.
pub soft_requirements: Vec<SolvableId>,
///
/// # Returns
///
/// Returns the [`Problem`] for further mutation or to pass to [`Solver::solve`].
pub fn soft_requirements<I: IntoIterator<Item = SolvableId>>(
self,
soft_requirements: I,
) -> Problem<I> {
Problem {
requirements: self.requirements,
constraints: self.constraints,
soft_requirements,
}
}
}

/// Drives the SAT solving process.
Expand Down Expand Up @@ -201,7 +257,10 @@ impl<D: DependencyProvider, RT: AsyncRuntime> Solver<D, RT> {
///
/// If the solution process is cancelled (see [`DependencyProvider::should_cancel_with_value`]),
/// returns an [`UnsolvableOrCancelled::Cancelled`] containing the cancellation value.
pub fn solve(&mut self, problem: Problem) -> Result<Vec<SolvableId>, UnsolvableOrCancelled> {
pub fn solve(
&mut self,
problem: Problem<impl IntoIterator<Item = SolvableId>>,
) -> Result<Vec<SolvableId>, UnsolvableOrCancelled> {
self.decision_tracker.clear();
self.negative_assertions.clear();
self.learnt_clauses.clear();
Expand Down
87 changes: 24 additions & 63 deletions tests/solver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -556,10 +556,7 @@ fn transaction_to_string(interner: &impl Interner, solvables: &Vec<SolvableId>)
fn solve_unsat(provider: BundleBoxProvider, specs: &[&str]) -> String {
let requirements = provider.requirements(specs);
let mut solver = Solver::new(provider);
let problem = Problem {
requirements,
..Default::default()
};
let problem = Problem::new().requirements(requirements);
match solver.solve(problem) {
Ok(_) => panic!("expected unsat, but a solution was found"),
Err(UnsolvableOrCancelled::Unsolvable(conflict)) => {
Expand Down Expand Up @@ -592,10 +589,7 @@ fn solve_snapshot(mut provider: BundleBoxProvider, specs: &[&str]) -> String {

let requirements = provider.parse_requirements(specs);
let mut solver = Solver::new(provider).with_runtime(runtime);
let problem = Problem {
requirements,
..Default::default()
};
let problem = Problem::new().requirements(requirements);
match solver.solve(problem) {
Ok(solvables) => transaction_to_string(solver.provider(), &solvables),
Err(UnsolvableOrCancelled::Unsolvable(conflict)) => {
Expand All @@ -621,10 +615,7 @@ fn test_unit_propagation_1() {
let provider = BundleBoxProvider::from_packages(&[("asdf", 1, vec![])]);
let requirements = provider.requirements(&["asdf"]);
let mut solver = Solver::new(provider);
let problem = Problem {
requirements,
..Default::default()
};
let problem = Problem::new().requirements(requirements);
let solved = solver.solve(problem).unwrap();
let pool = &solver.provider().pool;

Expand All @@ -645,10 +636,7 @@ fn test_unit_propagation_nested() {
]);
let requirements = provider.requirements(&["asdf"]);
let mut solver = Solver::new(provider);
let problem = Problem {
requirements,
..Default::default()
};
let problem = Problem::new().requirements(requirements);
let solved = solver.solve(problem).unwrap();
let pool = &solver.provider().pool;

Expand Down Expand Up @@ -676,10 +664,7 @@ fn test_resolve_multiple() {
]);
let requirements = provider.requirements(&["asdf", "efgh"]);
let mut solver = Solver::new(provider);
let problem = Problem {
requirements,
..Default::default()
};
let problem = Problem::new().requirements(requirements);
let solved = solver.solve(problem).unwrap();
let pool = &solver.provider().pool;

Expand Down Expand Up @@ -738,10 +723,7 @@ fn test_resolve_with_nonexisting() {
]);
let requirements = provider.requirements(&["asdf"]);
let mut solver = Solver::new(provider);
let problem = Problem {
requirements,
..Default::default()
};
let problem = Problem::new().requirements(requirements);
let solved = solver.solve(problem).unwrap();
let pool = &solver.provider().pool;

Expand Down Expand Up @@ -776,10 +758,7 @@ fn test_resolve_with_nested_deps() {
]);
let requirements = provider.requirements(&["apache-airflow"]);
let mut solver = Solver::new(provider);
let problem = Problem {
requirements,
..Default::default()
};
let problem = Problem::new().requirements(requirements);
let solved = solver.solve(problem).unwrap();
let pool = &solver.provider().pool;

Expand All @@ -804,10 +783,7 @@ fn test_resolve_with_unknown_deps() {
provider.add_package("opentelemetry-api", Pack::new(2), &[], &[]);
let requirements = provider.requirements(&["opentelemetry-api"]);
let mut solver = Solver::new(provider);
let problem = Problem {
requirements,
..Default::default()
};
let problem = Problem::new().requirements(requirements);
let solved = solver.solve(problem).unwrap();
let pool = &solver.provider().pool;

Expand Down Expand Up @@ -853,10 +829,7 @@ fn test_resolve_locked_top_level() {
let requirements = provider.requirements(&["asdf"]);

let mut solver = Solver::new(provider);
let problem = Problem {
requirements,
..Default::default()
};
let problem = Problem::new().requirements(requirements);
let solved = solver.solve(problem).unwrap();
let pool = &solver.provider().pool;

Expand All @@ -879,10 +852,7 @@ fn test_resolve_ignored_locked_top_level() {

let requirements = provider.requirements(&["asdf"]);
let mut solver = Solver::new(provider);
let problem = Problem {
requirements,
..Default::default()
};
let problem = Problem::new().requirements(requirements);
let solved = solver.solve(problem).unwrap();
let pool = &solver.provider().pool;

Expand Down Expand Up @@ -941,10 +911,7 @@ fn test_resolve_cyclic() {
BundleBoxProvider::from_packages(&[("a", 2, vec!["b 0..10"]), ("b", 5, vec!["a 2..4"])]);
let requirements = provider.requirements(&["a 0..100"]);
let mut solver = Solver::new(provider);
let problem = Problem {
requirements,
..Default::default()
};
let problem = Problem::new().requirements(requirements);
let solved = solver.solve(problem).unwrap();

let result = transaction_to_string(solver.provider(), &solved);
Expand Down Expand Up @@ -1228,11 +1195,9 @@ fn test_constraints() {
let requirements = provider.requirements(&["a 0..10"]);
let constraints = provider.requirements(&["b 1..2", "c"]);
let mut solver = Solver::new(provider);
let problem = Problem {
requirements,
constraints,
..Default::default()
};
let problem = Problem::new()
.requirements(requirements)
.constraints(constraints);
let solved = solver.solve(problem).unwrap();

let result = transaction_to_string(solver.provider(), &solved);
Expand Down Expand Up @@ -1272,11 +1237,10 @@ fn test_solve_with_additional() {

let mut solver = Solver::new(provider);

let problem = Problem {
requirements,
constraints,
soft_requirements: extra_solvables.to_vec(),
};
let problem = Problem::new()
.requirements(requirements)
.constraints(constraints)
.soft_requirements(extra_solvables);
let solved = solver.solve(problem).unwrap();

let result = transaction_to_string(solver.provider(), &solved);
Expand Down Expand Up @@ -1324,11 +1288,11 @@ fn test_solve_with_additional_with_constrains() {

let mut solver = Solver::new(provider);

let problem = Problem {
requirements,
constraints,
soft_requirements: extra_solvables.to_vec(),
};
let problem = Problem::new()
.requirements(requirements)
.constraints(constraints)
.soft_requirements(extra_solvables);

let solved = solver.solve(problem).unwrap();

let result = transaction_to_string(solver.provider(), &solved);
Expand Down Expand Up @@ -1404,10 +1368,7 @@ fn serialize_snapshot(snapshot: &DependencySnapshot, destination: impl AsRef<std

fn solve_for_snapshot(provider: SnapshotProvider, root_reqs: &[VersionSetId]) -> String {
let mut solver = Solver::new(provider);
let problem = Problem {
requirements: root_reqs.iter().copied().map(Into::into).collect(),
..Default::default()
};
let problem = Problem::new().requirements(root_reqs.iter().copied().map(Into::into).collect());
match solver.solve(problem) {
Ok(solvables) => transaction_to_string(solver.provider(), &solvables),
Err(UnsolvableOrCancelled::Unsolvable(conflict)) => {
Expand Down

0 comments on commit 2700ea0

Please sign in to comment.