Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add action to set custom cost #355

Open
wants to merge 17 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 16 additions & 8 deletions src/actions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,16 @@ impl<'a> ActionCompiler<'a> {
self.do_atom_term(b);
self.instructions.push(Instruction::Extract(2));
}
GenericCoreAction::Set(_ann, f, args, e) => {
GenericCoreAction::Set(_ann, f, args, e, is_cost) => {
let ResolvedCall::Func(func) = f else {
panic!("Cannot set primitive- should have been caught by typechecking!!!")
};
for arg in args {
self.do_atom_term(arg);
}
self.do_atom_term(e);
self.instructions.push(Instruction::Set(func.name));
self.instructions
.push(Instruction::Set(func.name, *is_cost));
}
GenericCoreAction::Change(_ann, change, f, args) => {
let ResolvedCall::Func(func) = f else {
Expand Down Expand Up @@ -125,12 +126,13 @@ enum Instruction {
/// Pop primitive arguments off the stack, calls the primitive,
/// and push the result onto the stack.
CallPrimitive(Primitive, usize),
/// Pop function arguments off the stack and either deletes or subsumes the corresponding row
/// in the function.
/// Pop function arguments off the stack and either deletes, subsumes, or changes the cost
/// of the corresponding row in the function.
saulshanabrook marked this conversation as resolved.
Show resolved Hide resolved
Change(Change, Symbol),
/// Pop the value to be set and the function arguments off the stack.
/// Set the function at the given arguments to the new value.
Set(Symbol),
/// If the second argument is true, then we set the cost of the function to the value.
Set(Symbol, bool),
/// Union the last `n` values on the stack.
Union(usize),
/// Extract the best expression. `n` is always 2.
Expand Down Expand Up @@ -206,12 +208,19 @@ impl EGraph {
table: Symbol,
new_value: Value,
stack: &mut Vec<Value>,
set_cost: bool,
) -> Result<(), Error> {
let i64sort: Arc<I64Sort> = self.type_info().get_sort_nofail();
let function = self.functions.get_mut(&table).unwrap();

let new_len = stack.len() - function.schema.input.len();
let args = &stack[new_len..];

if set_cost {
let cost = i64::load(&i64sort, &new_value);
function.update_cost(args, cost.try_into().unwrap());
return Ok(());
}
// We should only have canonical values here: omit the canonicalization step
let old_value = function.get(args);

Expand Down Expand Up @@ -328,16 +337,15 @@ impl EGraph {
return Err(Error::PrimitiveError(p.clone(), values.to_vec()));
}
}
Instruction::Set(f) => {
Instruction::Set(f, is_cost) => {
assert!(make_defaults);
let function = self.functions.get_mut(f).unwrap();
// desugaring should have desugared
// set to union
// except for setting the parent relation
let new_value = stack.pop().unwrap();
let new_len = stack.len() - function.schema.input.len();

self.perform_set(*f, new_value, stack)?;
self.perform_set(*f, new_value, stack, *is_cost)?;
stack.truncate(new_len)
}
Instruction::Union(arity) => {
Expand Down
5 changes: 4 additions & 1 deletion src/ast/desugar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,10 @@ fn add_semi_naive_rule(desugar: &mut Desugar, rule: Rule) -> Option<Rule> {
let mut var_set = HashSet::default();
for head_slice in new_rule.head.0.iter_mut().rev() {
match head_slice {
Action::Set(span, _, _, expr) => {
Action::Set(span, _, _, expr, is_cost) => {
saulshanabrook marked this conversation as resolved.
Show resolved Hide resolved
if *is_cost {
continue;
}
var_set.extend(expr.vars());
if let Expr::Call(..) = expr {
add_new_rule = true;
Expand Down
30 changes: 17 additions & 13 deletions src/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1224,11 +1224,13 @@ where
/// `set` a function to a particular result.
/// `set` should not be used on datatypes-
/// instead, use `union`.
/// If the last argument is `true`, then we are setting a cost instead of a value.
Set(
Span,
Head,
Vec<GenericExpr<Head, Leaf>>,
GenericExpr<Head, Leaf>,
bool,
),
/// Delete or subsume (mark as hidden from future rewrites and unextractable) an entry from a function.
Change(Span, Change, Head, Vec<GenericExpr<Head, Leaf>>),
Expand Down Expand Up @@ -1311,17 +1313,18 @@ where
fn to_sexp(&self) -> Sexp {
match self {
GenericAction::Let(_ann, lhs, rhs) => list!("let", lhs, rhs),
GenericAction::Set(_ann, lhs, args, rhs) => list!("set", list!(lhs, ++ args), rhs),
GenericAction::Union(_ann, lhs, rhs) => list!("union", lhs, rhs),
GenericAction::Change(_ann, change, lhs, args) => {
GenericAction::Set(_ann, lhs, args, rhs, is_cost) => {
list!(
match change {
Change::Delete => "delete",
Change::Subsume => "subsume",
},
list!(lhs, ++ args)
if *is_cost { "unstable-cost" } else { "set" },
list!(lhs, ++ args),
rhs
)
}
GenericAction::Union(_ann, lhs, rhs) => list!("union", lhs, rhs),
GenericAction::Change(_ann, change, lhs, args) => match change {
Change::Delete => list!("delete", list!(lhs, ++ args)),
Change::Subsume => list!("subsume", list!(lhs, ++ args)),
},
GenericAction::Extract(_ann, expr, variants) => list!("extract", expr, variants),
GenericAction::Panic(_ann, msg) => list!("panic", format!("\"{}\"", msg.clone())),
GenericAction::Expr(_ann, e) => e.to_sexp(),
Expand All @@ -1343,13 +1346,14 @@ where
GenericAction::Let(span, lhs, rhs) => {
GenericAction::Let(span.clone(), lhs.clone(), f(rhs))
}
GenericAction::Set(span, lhs, args, rhs) => {
GenericAction::Set(span, lhs, args, rhs, is_cost) => {
let right = f(rhs);
GenericAction::Set(
span.clone(),
lhs.clone(),
args.iter().map(f).collect(),
right,
*is_cost,
)
}
GenericAction::Change(span, change, lhs, args) => GenericAction::Change(
Expand Down Expand Up @@ -1382,9 +1386,9 @@ where
// TODO should we refactor `Set` so that we can map over Expr::Call(lhs, args)?
// This seems more natural to oflatt
// Currently, visit_exprs does not apply f to the first argument of Set.
GenericAction::Set(span, lhs, args, rhs) => {
GenericAction::Set(span, lhs, args, rhs, is_cost) => {
let args = args.into_iter().map(|e| e.visit_exprs(f)).collect();
GenericAction::Set(span, lhs.clone(), args, rhs.visit_exprs(f))
GenericAction::Set(span, lhs.clone(), args, rhs.visit_exprs(f), is_cost)
}
GenericAction::Change(span, change, lhs, args) => {
let args = args.into_iter().map(|e| e.visit_exprs(f)).collect();
Expand Down Expand Up @@ -1417,13 +1421,13 @@ where
let rhs = rhs.subst_leaf(&mut fvar_expr!());
GenericAction::Let(span, lhs, rhs)
}
GenericAction::Set(span, lhs, args, rhs) => {
GenericAction::Set(span, lhs, args, rhs, is_cost) => {
let args = args
.into_iter()
.map(|e| e.subst_leaf(&mut fvar_expr!()))
.collect();
let rhs = rhs.subst_leaf(&mut fvar_expr!());
GenericAction::Set(span, lhs.clone(), args, rhs)
GenericAction::Set(span, lhs.clone(), args, rhs, is_cost)
}
GenericAction::Change(span, change, lhs, args) => {
let args = args
Expand Down
25 changes: 13 additions & 12 deletions src/ast/parse.lalrpop
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,9 @@ Command: Command = {
LParen "relation" <constructor:Ident> <inputs:List<Type>> RParen => Command::Relation{constructor, inputs},
LParen "ruleset" <name:Ident> RParen => Command::AddRuleset(name),
LParen "unstable-combined-ruleset" <name:Ident> <subrulesets:Ident*> RParen => Command::UnstableCombinedRuleset(name, subrulesets),
<lo:LParen> "rule" <body:List<Fact>> <head:List<Action>>
<ruleset:(":ruleset" <Ident>)?>
<name:(":name" <String>)?>
<lo:LParen> "rule" <body:List<Fact>> <head:List<Action>>
<ruleset:(":ruleset" <Ident>)?>
<name:(":name" <String>)?>
<hi:RParen> => Command::Rule{ruleset: ruleset.unwrap_or("".into()), name: name.unwrap_or("".to_string()).into(), rule: Rule { span: Span(srcfile.clone(), lo, hi), head: Actions::new(head), body }},
<lo:LParen> "rewrite" <lhs:Expr> <rhs:Expr>
<subsume:(":subsume")?>
Expand All @@ -74,10 +74,10 @@ Command: Command = {
<hi:RParen> => Command::BiRewrite(ruleset.unwrap_or("".into()), Rewrite { span: Span(srcfile.clone(), lo, hi), lhs, rhs, conditions: conditions.unwrap_or_default() }),
<lo:LParen> "let" <name:Ident> <expr:Expr> <hi:RParen> => Command::Action(Action::Let(Span(srcfile.clone(), lo, hi), name, expr)),
<NonLetAction> => Command::Action(<>),
<lo:LParen> "run" <limit:UNum> <until:(":until" <(Fact)*>)?> <hi:RParen> =>
<lo:LParen> "run" <limit:UNum> <until:(":until" <(Fact)*>)?> <hi:RParen> =>
Command::RunSchedule(Schedule::Repeat(Span(srcfile.clone(), lo, hi), limit, Box::new(
Schedule::Run(Span(srcfile.clone(), lo, hi), RunConfig { ruleset : "".into(), until })))),
<lo:LParen> "run" <ruleset: Ident> <limit:UNum> <until:(":until" <(Fact)*>)?> <hi:RParen> =>
<lo:LParen> "run" <ruleset: Ident> <limit:UNum> <until:(":until" <(Fact)*>)?> <hi:RParen> =>
Command::RunSchedule(Schedule::Repeat(Span(srcfile.clone(), lo, hi), limit, Box::new(
Schedule::Run(Span(srcfile.clone(), lo, hi), RunConfig { ruleset, until })))),
LParen "simplify" <schedule:Schedule> <expr:Expr> RParen
Expand All @@ -86,7 +86,7 @@ Command: Command = {
LParen "query-extract" <variants:(":variants" <UNum>)?> <expr:Expr> RParen => Command::QueryExtract { expr, variants: variants.unwrap_or(0) },
<lo:LParen> "check" <facts:(Fact)*> <hi:RParen> => Command::Check(Span(srcfile.clone(), lo, hi), facts),
LParen "check-proof" RParen => Command::CheckProof,
<lo:LParen> "run-schedule" <scheds:Schedule*> <hi:RParen> =>
<lo:LParen> "run-schedule" <scheds:Schedule*> <hi:RParen> =>
Command::RunSchedule(Schedule::Sequence(Span(srcfile.clone(), lo, hi), scheds)),
LParen "print-stats" RParen => Command::PrintOverallStatistics,
LParen "push" <UNum?> RParen => Command::Push(<>.unwrap_or(1)),
Expand All @@ -100,19 +100,19 @@ Command: Command = {
}

Schedule: Schedule = {
<lo:LParen> "saturate" <scheds:Schedule*> <hi:RParen> =>
<lo:LParen> "saturate" <scheds:Schedule*> <hi:RParen> =>
Schedule::Saturate(Span(srcfile.clone(), lo, hi), Box::new(
Schedule::Sequence(Span(srcfile.clone(), lo, hi), scheds))),
<lo:LParen> "seq" <scheds:Schedule*> <hi:RParen> =>
<lo:LParen> "seq" <scheds:Schedule*> <hi:RParen> =>
Schedule::Sequence(Span(srcfile.clone(), lo, hi), scheds),
<lo:LParen> "repeat" <limit:UNum> <scheds:Schedule*> <hi:RParen> =>
<lo:LParen> "repeat" <limit:UNum> <scheds:Schedule*> <hi:RParen> =>
Schedule::Repeat(Span(srcfile.clone(), lo, hi), limit, Box::new(
Schedule::Sequence(Span(srcfile.clone(), lo, hi), scheds))),
<lo:LParen> "run" <until:(":until" <(Fact)*>)?> <hi:RParen> =>
Schedule::Run(Span(srcfile.clone(), lo, hi), RunConfig { ruleset: "".into(), until }),
<lo:LParen> "run" <ruleset: Ident> <until:(":until" <(Fact)*>)?> <hi:RParen> =>
<lo:LParen> "run" <ruleset: Ident> <until:(":until" <(Fact)*>)?> <hi:RParen> =>
Schedule::Run(Span(srcfile.clone(), lo, hi), RunConfig { ruleset, until }),
<lo:@L> <ident:Ident> <hi:@R> =>
<lo:@L> <ident:Ident> <hi:@R> =>
Schedule::Run(Span(srcfile.clone(), lo, hi), RunConfig { ruleset: ident, until: None }),
}

Expand All @@ -122,9 +122,10 @@ Cost: Option<usize> = {
}

NonLetAction: Action = {
<lo:LParen> "set" LParen <f: Ident> <args:Expr*> RParen <v:Expr> <hi:RParen> => Action::Set ( Span(srcfile.clone(), lo, hi), f, args, v ),
<lo:LParen> "set" LParen <f: Ident> <args:Expr*> RParen <v:Expr> <hi:RParen> => Action::Set ( Span(srcfile.clone(), lo, hi), f, args, v, false ),
<lo:LParen> "delete" LParen <f: Ident> <args:Expr*> RParen <hi:RParen> => Action::Change ( Span(srcfile.clone(), lo, hi), Change::Delete, f, args),
<lo:LParen> "subsume" LParen <f: Ident> <args:Expr*> RParen <hi:RParen> => Action::Change ( Span(srcfile.clone(), lo, hi), Change::Subsume, f, args),
<lo:LParen> "unstable-cost" LParen <f: Ident> <args:Expr*> RParen <cost:Expr> <hi:RParen> => Action::Set ( Span(srcfile.clone(), lo, hi), f, args, cost, true ),
<lo:LParen> "union" <e1:Expr> <e2:Expr> <hi:RParen> => Action::Union(Span(srcfile.clone(), lo, hi), e1, e2),
<lo:LParen> "panic" <msg:String> <hi:RParen> => Action::Panic(Span(srcfile.clone(), lo, hi), msg),
<lo:LParen> "extract" <expr:Expr> <hi:RParen> => Action::Extract(Span(srcfile.clone(), lo, hi), expr, Expr::Lit(Span(srcfile.clone(), lo, hi), Literal::Int(0))),
Expand Down
1 change: 1 addition & 0 deletions src/ast/remove_globals.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ impl<'a> GlobalRemover<'a> {
resolved_call,
vec![],
remove_globals_expr(expr),
false,
))
},
]
Expand Down
34 changes: 27 additions & 7 deletions src/constraint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -344,18 +344,24 @@ impl Assignment<AtomTerm, ArcSort> {
},
children,
rhs,
is_cost,
) => {
let children: Vec<_> = children
.iter()
.map(|child| self.annotate_expr(child, typeinfo))
.collect();
let rhs = self.annotate_expr(rhs, typeinfo);
let types: Vec<_> = children
let mut types: Vec<_> = children
.iter()
.map(|child| child.output_type(typeinfo))
.chain(once(rhs.output_type(typeinfo)))
.collect();
let resolved_call = ResolvedCall::from_resolution(head, &types, typeinfo);
let resolved_call = if *is_cost {
ResolvedCall::from_resolution_func_types(head, &types, typeinfo)
.ok_or_else(|| TypeError::UnboundFunction(*head))?
} else {
types.push(rhs.output_type(typeinfo));
ResolvedCall::from_resolution(head, &types, typeinfo)
};
if !matches!(resolved_call, ResolvedCall::Func(_)) {
return Err(TypeError::UnboundFunction(*head));
}
Expand All @@ -364,6 +370,7 @@ impl Assignment<AtomTerm, ArcSort> {
resolved_call,
children,
rhs,
*is_cost,
))
}
// Note mapped_var for delete is a dummy variable that does not mean anything
Expand Down Expand Up @@ -534,15 +541,28 @@ impl CoreAction {
.chain(get_atom_application_constraints(f, &args, span, typeinfo)?)
.collect())
}
CoreAction::Set(span, head, args, rhs) => {
CoreAction::Set(span, head, args, rhs, is_cost) => {
saulshanabrook marked this conversation as resolved.
Show resolved Hide resolved
let mut args = args.clone();
args.push(rhs.clone());
if *is_cost {
// Add a dummy last output argument
let var = symbol_gen.fresh(head);
args.push(AtomTerm::Var(span.clone(), var));
} else {
args.push(rhs.clone());
}

Ok(get_literal_and_global_constraints(&args, typeinfo)
let mut res: Vec<_> = get_literal_and_global_constraints(&args, typeinfo)
.chain(get_atom_application_constraints(
head, &args, span, typeinfo,
)?)
.collect())
.collect();
if *is_cost {
res.push(Constraint::Assign(
rhs.clone(),
typeinfo.get_sort_nofail::<I64Sort>() as ArcSort,
));
}
Ok(res)
}
CoreAction::Change(span, _change, head, args) => {
let mut args = args.clone();
Expand Down
5 changes: 4 additions & 1 deletion src/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,7 @@ pub enum GenericCoreAction<Head, Leaf> {
Head,
Vec<GenericAtomTerm<Leaf>>,
GenericAtomTerm<Leaf>,
bool,
),
Change(Span, Change, Head, Vec<GenericAtomTerm<Leaf>>),
Union(Span, GenericAtomTerm<Leaf>, GenericAtomTerm<Leaf>),
Expand Down Expand Up @@ -462,7 +463,7 @@ where
));
binding.insert(var.clone());
}
GenericAction::Set(span, head, args, expr) => {
GenericAction::Set(span, head, args, expr, is_cost) => {
let mut mapped_args = vec![];
for arg in args {
let (actions, mapped_arg) =
Expand All @@ -481,13 +482,15 @@ where
.map(|e| e.get_corresponding_var_or_lit(typeinfo))
.collect(),
mapped_expr.get_corresponding_var_or_lit(typeinfo),
*is_cost,
));
let v = fresh_gen.fresh(head);
mapped_actions.0.push(GenericAction::Set(
span.clone(),
CorrespondingVar::new(head.clone(), v),
mapped_args,
mapped_expr,
*is_cost,
));
}
GenericAction::Change(span, change, head, args) => {
Expand Down
5 changes: 3 additions & 2 deletions src/extract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,8 +152,9 @@ impl<'a> Extractor<'a> {
function: &Function,
children: &[Value],
termdag: &mut TermDag,
cost: Option<usize>,
) -> Option<(Vec<Term>, Cost)> {
let mut cost = function.decl.cost.unwrap_or(1);
let mut cost = cost.unwrap_or(function.decl.cost.unwrap_or(1));
saulshanabrook marked this conversation as resolved.
Show resolved Hide resolved
let types = &function.schema.input;
let mut terms: Vec<Term> = vec![];
for (ty, value) in types.iter().zip(children) {
Expand Down Expand Up @@ -182,7 +183,7 @@ impl<'a> Extractor<'a> {
if func.schema.output.is_eq_sort() {
for (inputs, output) in func.nodes.iter(false) {
if let Some((term_inputs, new_cost)) =
self.node_total_cost(func, inputs, termdag)
self.node_total_cost(func, inputs, termdag, output.cost)
{
let make_new_pair = || (new_cost, termdag.app(sym, term_inputs));

Expand Down
Loading
Loading