Skip to content

Commit

Permalink
Fix enzyme skipping if input ir is optimized by O2Opt (#1024)
Browse files Browse the repository at this point in the history
**Context:**
Run enzyme pass if checkpoint stage is set to `O2Opt`.

Also rename the functions in the related tests to avoid race condition.

---------

Co-authored-by: Romain Moyard <[email protected]>
  • Loading branch information
Tzung-Han Juang and rmoyard authored Aug 15, 2024
1 parent 0422dca commit 62166c5
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 17 deletions.
36 changes: 20 additions & 16 deletions frontend/test/pytest/test_debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,45 +484,49 @@ def f(x: float):
def test_modify_ir(self, pass_name, target, replacement):
"""Turn a square function in IRs into a cubic one."""

@qjit(keep_intermediate=True)
def f(x):
"""Square function."""
return x**2

f.__name__ = f.__name__ + pass_name

jit_f = qjit(f, keep_intermediate=True)
data = 2.0
old_result = f(data)
old_ir = get_compilation_stage(f, pass_name)
old_workspace = str(f.workspace)
old_result = jit_f(data)
old_ir = get_compilation_stage(jit_f, pass_name)
old_workspace = str(jit_f.workspace)

new_ir = old_ir.replace(target, replacement)
replace_ir(f, pass_name, new_ir)
new_result = f(data)
replace_ir(jit_f, pass_name, new_ir)
new_result = jit_f(data)

shutil.rmtree(old_workspace, ignore_errors=True)
shutil.rmtree(str(f.workspace), ignore_errors=True)
shutil.rmtree(str(jit_f.workspace), ignore_errors=True)
assert old_result * data == new_result

@pytest.mark.parametrize("pass_name", ["HLOLoweringPass", "O2Opt", "Enzyme"])
def test_modify_ir_file_generation(self, pass_name):
"""Test if recompilation rerun the same pass."""

@qjit
def f(x: float):
"""Square function."""
return x**2

grad_f = qjit(value_and_grad(f), keep_intermediate=True)
grad_f(3.0)
ir = get_compilation_stage(grad_f, pass_name)
old_workspace = str(grad_f.workspace)
f.__name__ = f.__name__ + pass_name

jit_f = qjit(f)
jit_grad_f = qjit(value_and_grad(jit_f), keep_intermediate=True)
jit_grad_f(3.0)
ir = get_compilation_stage(jit_grad_f, pass_name)
old_workspace = str(jit_grad_f.workspace)

replace_ir(grad_f, pass_name, ir)
grad_f(3.0)
file_list = os.listdir(str(grad_f.workspace))
replace_ir(jit_grad_f, pass_name, ir)
jit_grad_f(3.0)
file_list = os.listdir(str(jit_grad_f.workspace))
res = [i for i in file_list if pass_name in i]

shutil.rmtree(old_workspace, ignore_errors=True)
shutil.rmtree(str(grad_f.workspace), ignore_errors=True)
shutil.rmtree(str(jit_grad_f.workspace), ignore_errors=True)
assert len(res) == 0

def test_get_compilation_stage_without_keep_intermediate(self):
Expand Down
4 changes: 3 additions & 1 deletion mlir/lib/Driver/CompilerDriver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -648,7 +648,9 @@ LogicalResult QuantumDriverMain(const CompilerOptions &options, CompilerOutput &
catalyst::utils::LinesCount::ModuleOp(*op);
output.isCheckpointFound = options.checkpointStage == "mlir";

bool enzymeRun = false;
// Enzyme always happens after O2Opt. If the checkpoint is O2Opt, enzymeRun must be set to
// true so that the enzyme pass can be executed.
bool enzymeRun = options.checkpointStage == "O2Opt";
if (op) {
enzymeRun = containsGradients(*op);
if (failed(runLowering(options, &ctx, *op, output))) {
Expand Down

0 comments on commit 62166c5

Please sign in to comment.