Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Frontend] Only read from disk if necessary #1315

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 2 additions & 9 deletions frontend/catalyst/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import shutil
import subprocess
import sys
import tempfile

Check notice on line 26 in frontend/catalyst/compiler.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/compiler.py#L26

Unused import tempfile (unused-import)
import warnings
from copy import deepcopy
from dataclasses import dataclass
Expand Down Expand Up @@ -453,9 +453,7 @@
if self.options.verbose:
print(f"[LIB] Running compiler driver in {workspace}", file=self.options.logfile)

with tempfile.NamedTemporaryFile(
mode="w", suffix=".mlir", dir=str(workspace), delete=False
) as tmp_infile:
with open(str(workspace) + "/input.test", "w") as tmp_infile:

Check notice on line 456 in frontend/catalyst/compiler.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/compiler.py#L456

Using open without explicitly specifying an encoding (unspecified-encoding)
tmp_infile_name = tmp_infile.name
tmp_infile.write(ir)

Expand All @@ -477,20 +475,15 @@
f"catalyst-cli failed with error code {e.returncode}: {e.stderr}"
) from e

with open(output_ir_name, "r", encoding="utf-8") as f:
out_IR = f.read()

if lower_to_llvm:
output = LinkerDriver.run(output_object_name, options=self.options)
output_object_name = str(pathlib.Path(output).absolute())

# Clean up temporary files
if os.path.exists(tmp_infile_name):
os.remove(tmp_infile_name)
if os.path.exists(output_ir_name):
os.remove(output_ir_name)

return output_object_name, out_IR
return output_object_name, output_ir_name

@debug_logger
def run(self, mlir_module, *args, **kwargs):
Expand Down
61 changes: 40 additions & 21 deletions frontend/catalyst/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,9 +481,10 @@
self.jaxed_function = None
# IRs are only available for the most recently traced function.
self.jaxpr = None
self.mlir = None # string form (historic presence)
self._mlir = None # string form (historic presence)
self.mlir_module = None
self.qir = None
self._qir = None
self.qir_file = None
self.out_type = None
self.overwrite_ir = None

Expand Down Expand Up @@ -514,6 +515,21 @@

super().__init__("user_function")

@property
def mlir(self):

Check notice on line 519 in frontend/catalyst/jit.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/jit.py#L519

Missing function or method docstring (missing-function-docstring)
if not self._mlir and self.mlir_module:
_, mlir_file = self.canonicalize(self.mlir_module)
with open(mlir_file, "r", encoding="utf-8") as f:
self._mlir = f.read()
return self._mlir

@property
def qir(self):

Check notice on line 527 in frontend/catalyst/jit.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/jit.py#L527

Missing function or method docstring (missing-function-docstring)
if not self._qir and self.qir_file:
with open(self.qir_file, "r", encoding="utf-8") as f:
self._qir = f.read()
return self._qir

@debug_logger
def __call__(self, *args, **kwargs):
# Transparantly call Python function in case of nested QJIT calls.
Expand Down Expand Up @@ -555,10 +571,10 @@
)

if self.compile_options.target in ("mlir", "binary"):
self.mlir_module, self.mlir = self.generate_ir()
self.mlir_module = self.generate_ir()

if self.compile_options.target in ("binary",):
self.compiled_function, self.qir = self.compile()
self.compiled_function, self.qir_file = self.compile()
self.fn_cache.insert(
self.compiled_function, self.user_sig, self.out_treedef, self.workspace
)
Expand Down Expand Up @@ -599,8 +615,8 @@
args, **kwargs
)

self.mlir_module, self.mlir = self.generate_ir()
self.compiled_function, self.qir = self.compile()
self.mlir_module = self.generate_ir()
self.compiled_function, self.qir_file = self.compile()

self.fn_cache.insert(self.compiled_function, args, self.out_treedef, self.workspace)

Expand Down Expand Up @@ -696,7 +712,20 @@
PipelineNameUniquer.reset()
return jaxpr, out_type, treedef, dynamic_sig

@instrument(size_from=0, has_finegrained=True)
@debug_logger
def canonicalize(self, mlir_module):
"""Canonicalize the mlir_module"""

# Canonicalize the MLIR since there can be a lot of redundancy coming from JAX.
options = copy.deepcopy(self.compile_options)
options.pipelines = [("0_canonicalize", ["canonicalize"])]
options.lower_to_llvm = False
canonicalizer = Compiler(options)

# TODO: the in-memory and textual form are different after this, consider unification
return canonicalizer.run(mlir_module, self.workspace)

@instrument(has_finegrained=True)
@debug_logger
def generate_ir(self):
"""Generate Catalyst's intermediate representation (IR) as an MLIR module.
Expand All @@ -710,18 +739,8 @@
# Inject Runtime Library-specific functions (e.g. setup/teardown).
inject_functions(mlir_module, ctx, self.compile_options.seed)

# Canonicalize the MLIR since there can be a lot of redundancy coming from JAX.
options = copy.deepcopy(self.compile_options)
options.pipelines = [("0_canonicalize", ["canonicalize"])]
options.lower_to_llvm = False
canonicalizer = Compiler(options)

# TODO: the in-memory and textual form are different after this, consider unification
_, mlir_string = canonicalizer.run(mlir_module, self.workspace)

return mlir_module, mlir_string
return mlir_module

@instrument(size_from=1, has_finegrained=True)
@debug_logger
def compile(self):
"""Compile an MLIR module to LLVMIR and shared library code.
Expand All @@ -746,19 +765,19 @@
# `replace` method, so we need to get a regular Python string out of it.
func_name = str(self.mlir_module.body.operations[0].name).replace('"', "")
if self.overwrite_ir:
shared_object, llvm_ir = self.compiler.run_from_ir(
shared_object, llvm_ir_file = self.compiler.run_from_ir(
self.overwrite_ir,
str(self.mlir_module.operation.attributes["sym_name"]).replace('"', ""),
self.workspace,
)
else:
shared_object, llvm_ir = self.compiler.run(self.mlir_module, self.workspace)
shared_object, llvm_ir_file = self.compiler.run(self.mlir_module, self.workspace)

compiled_fn = CompiledFunction(
shared_object, func_name, restype, self.out_type, self.compile_options
)

return compiled_fn, llvm_ir
return compiled_fn, llvm_ir_file

@instrument(has_finegrained=True)
@debug_logger
Expand Down
1 change: 1 addition & 0 deletions frontend/test/pytest/test_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@
qml.PauliX(wires=0)
return qml.state()

workflow.mlir

Check notice on line 273 in frontend/test/pytest/test_compiler.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/test/pytest/test_compiler.py#L273

Statement seems to have no effect (pointless-statement)
directory = os.path.join(os.getcwd(), workflow.__name__)
files = os.listdir(directory)
# The directory is non-empty. Should at least contain the original .mlir file
Expand Down
7 changes: 1 addition & 6 deletions frontend/test/pytest/test_debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,12 +423,7 @@ def f(x: float):

f(2.0)

with pytest.raises(
CompileError,
match="Attempting to get output for pipeline: mlir, "
"but no file was found.\nAre you sure the file exists?",
):
get_compilation_stage(f, "mlir")
get_compilation_stage(f, "mlir")

@pytest.mark.parametrize(
"arg",
Expand Down