Skip to content

Commit

Permalink
feat(optim)!: add a new param toAvoid to the constant folding pass
Browse files Browse the repository at this point in the history
  • Loading branch information
EmileRolley committed Feb 14, 2024
1 parent d7da72e commit 1ee293b
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 31 deletions.
60 changes: 35 additions & 25 deletions source/optims/constantFolding.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,20 @@ type RefMaps = {
childs: RefMap
}

export type PredicateOnRule = (rule: [RuleName, RuleNode]) => boolean
export type PredicateOnRule = (rule: RuleNode) => boolean

/**
* Parameters for the constant folding optimization pass.
*
* @field toAvoid A predicate that returns true if the rule should be avoided to be folded,
* if not present, all rules will be folded.
* @field toKeep A predicate that returns true if the rule should be kept AFTER being folded,
* if not present, all folded rules will be kept.
* @field isFoldedAttr The attribute name to use to mark a rule as folded, default to 'optimized'.
*/
export type FoldingParams = {
// The attribute name to use to mark a rule as folded, default to 'optimized'.
toAvoid?: PredicateOnRule
toKeep?: PredicateOnRule
isFoldedAttr?: string
}

Expand Down Expand Up @@ -58,7 +68,6 @@ function addMapEntry(map: RefMap, key: RuleName, values: Set<RuleName>) {

function initFoldingCtx(
engine: Engine,
toKeep?: PredicateOnRule,
foldingParams?: FoldingParams,
): FoldingCtx {
const refs: RefMaps = {
Expand Down Expand Up @@ -138,9 +147,9 @@ function initFoldingCtx(
engine,
parsedRules,
refs,
toKeep,
unfoldableRules,
params: {
...foldingParams,
isFoldedAttr: foldingParams?.isFoldedAttr ?? 'optimized',
},
}
Expand All @@ -161,6 +170,7 @@ function isFoldable(ctx: FoldingCtx, rule: RuleNode): boolean {

return (
rule !== undefined &&
!ctx.params.toAvoid?.(rule) &&
!unfoldableAttr.find((attr) => attr in rule.rawNode) &&
!ctx.unfoldableRules.has(rule.dottedName) &&
!childInContext
Expand All @@ -184,20 +194,23 @@ function searchAndReplaceConstantValueInParentRefs(
if (refs) {
for (const parentName of refs) {
const parentRule = ctx.parsedRules[parentName]
const newRule = traverseASTNode(
transformAST((node, _) => {
if (node.nodeKind === 'reference' && node.dottedName === ruleName) {
return constantNode
}
}),
parentRule,
) as RuleNode

if (newRule !== undefined) {
ctx.parsedRules[parentName] = newRule
ctx.parsedRules[parentName].rawNode[ctx.params.isFoldedAttr] =
'partially'
removeInMap(ctx.refs.parents, ruleName, parentName)

if (!ctx.params.toAvoid?.(parentRule)) {
const newRule = traverseASTNode(
transformAST((node, _) => {
if (node.nodeKind === 'reference' && node.dottedName === ruleName) {
return constantNode
}
}),
parentRule,
) as RuleNode

if (newRule !== undefined) {
ctx.parsedRules[parentName] = newRule
ctx.parsedRules[parentName].rawNode[ctx.params.isFoldedAttr] =
'partially'
removeInMap(ctx.refs.parents, ruleName, parentName)
}
}
}
}
Expand Down Expand Up @@ -227,7 +240,7 @@ function tryToDeleteRule(ctx: FoldingCtx, dottedName: RuleName): boolean {
const ruleNode = ctx.parsedRules[dottedName]

if (
(ctx.toKeep === undefined || !ctx.toKeep([dottedName, ruleNode])) &&
(ctx.params.toKeep === undefined || !ctx.params.toKeep(ruleNode)) &&
isFoldable(ctx, ruleNode)
) {
removeRuleFromRefs(ctx.refs.parents, dottedName)
Expand Down Expand Up @@ -420,18 +433,15 @@ function copyFullParsedRules(engine: Engine): ParsedRules<RuleName> {
* Applies a constant folding optimisation pass on parsed rules of [engine].
*
* @param engine The engine instantiated with the rules to fold.
* @param toKeep A predicate that returns true if the rule should be kept, if not present,
* all folded rules will be kept.
* @param params The folding parameters.
*
* @returns The parsed rules with constant folded rules.
*/
export function constantFolding(
engine: Engine,
toKeep?: PredicateOnRule,
params?: FoldingParams,
): ParsedRules<RuleName> {
let ctx = initFoldingCtx(engine, toKeep, params)
let ctx = initFoldingCtx(engine, params)

let nbRules = Object.keys(ctx.parsedRules).length
let nbRulesBefore = undefined
Expand All @@ -448,14 +458,14 @@ export function constantFolding(
nbRules = Object.keys(ctx.parsedRules).length
}

if (toKeep) {
if (ctx.params.toKeep) {
for (const ruleName in ctx.parsedRules) {
const ruleNode = ctx.parsedRules[ruleName]
const parents = ctx.refs.parents.get(ruleName)

if (
isFoldable(ctx, ruleNode) &&
!toKeep([ruleName, ruleNode]) &&
!ctx.params.toKeep(ruleNode) &&
(!parents || parents?.size === 0)
) {
delete ctx.parsedRules[ruleName]
Expand Down
52 changes: 46 additions & 6 deletions test/optims/constantFolding.test.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import Engine from 'publicodes'
import Engine, { RuleNode } from 'publicodes'
import { serializeParsedRules } from '../../source'
import { RuleName, RawRules, disabledLogger } from '../../source/commons'
import { constantFolding } from '../../source/optims/'
Expand All @@ -7,10 +7,11 @@ import { callWithEngine } from '../utils.test'
function constantFoldingWith(rawRules: any, targets?: RuleName[]): RawRules {
const res = callWithEngine(
(engine) =>
constantFolding(
engine,
targets ? ([ruleName, _]) => targets.includes(ruleName) : undefined,
),
constantFolding(engine, {
toKeep: targets
? (rule: RuleNode) => targets.includes(rule.dottedName)
: undefined,
}),
rawRules,
)
return serializeParsedRules(res)
Expand Down Expand Up @@ -39,7 +40,7 @@ describe('Constant folding [meta]', () => {
const baseParsedRules = engine.getParsedRules()
const serializedBaseParsedRules = serializeParsedRules(baseParsedRules)

constantFolding(engine, ([ruleName, _]) => ruleName === 'ruleA')
constantFolding(engine, { toKeep: (rule) => rule.dottedName === 'ruleA' })

const shouldNotBeModifiedRules = engine.getParsedRules()
const serializedShouldNotBeModifiedRules = serializeParsedRules(
Expand All @@ -51,6 +52,45 @@ describe('Constant folding [meta]', () => {
serializedShouldNotBeModifiedRules,
)
})

it('should not fold a rule specified in the [toAvoid] option', () => {
const rawRules = {
ruleA: {
titre: 'Rule A',
valeur: 'B . C * D',
},
ruleB: {
valeur: 'ruleA . B . C * 3',
},
'ruleA . D': {
question: "What's the value of D?",
},
'ruleA . B . C': {
valeur: '10',
},
}
const engine = new Engine(rawRules, {
logger: disabledLogger,
allowOrphanRules: true,
})
const foldedRules = serializeParsedRules(
constantFolding(engine, {
toAvoid: (rule) => rule.dottedName === 'ruleB',
}),
)
expect(foldedRules).toEqual({
...rawRules,
ruleA: {
optimized: 'partially',
titre: 'Rule A',
valeur: '10 * D',
},
'ruleA . B . C': {
optimized: 'fully',
valeur: 10,
},
})
})
})

describe('Constant folding [base]', () => {
Expand Down

0 comments on commit 1ee293b

Please sign in to comment.