diff --git a/frontend/catalyst/compiler.py b/frontend/catalyst/compiler.py index ea41da06e7..0fc4b46d70 100644 --- a/frontend/catalyst/compiler.py +++ b/frontend/catalyst/compiler.py @@ -453,9 +453,7 @@ def run_from_ir(self, ir: str, module_name: str, workspace: Directory): if self.options.verbose: print(f"[LIB] Running compiler driver in {workspace}", file=self.options.logfile) - with tempfile.NamedTemporaryFile( - mode="w", suffix=".mlir", dir=str(workspace), delete=False - ) as tmp_infile: + with open(str(workspace) + "/input.test", "w") as tmp_infile: tmp_infile_name = tmp_infile.name tmp_infile.write(ir) @@ -477,9 +475,6 @@ def run_from_ir(self, ir: str, module_name: str, workspace: Directory): f"catalyst-cli failed with error code {e.returncode}: {e.stderr}" ) from e - with open(output_ir_name, "r", encoding="utf-8") as f: - out_IR = f.read() - if lower_to_llvm: output = LinkerDriver.run(output_object_name, options=self.options) output_object_name = str(pathlib.Path(output).absolute()) @@ -487,10 +482,8 @@ def run_from_ir(self, ir: str, module_name: str, workspace: Directory): # Clean up temporary files if os.path.exists(tmp_infile_name): os.remove(tmp_infile_name) - if os.path.exists(output_ir_name): - os.remove(output_ir_name) - return output_object_name, out_IR + return output_object_name, output_ir_name @debug_logger def run(self, mlir_module, *args, **kwargs): diff --git a/frontend/catalyst/jit.py b/frontend/catalyst/jit.py index 86722edd56..cf6eef1906 100644 --- a/frontend/catalyst/jit.py +++ b/frontend/catalyst/jit.py @@ -481,9 +481,10 @@ def __init__(self, fn, compile_options): self.jaxed_function = None # IRs are only available for the most recently traced function. self.jaxpr = None - self.mlir = None # string form (historic presence) + self._mlir = None # string form (historic presence) self.mlir_module = None - self.qir = None + self._qir = None + self.qir_file = None self.out_type = None self.overwrite_ir = None @@ -514,6 +515,21 @@ def __init__(self, fn, compile_options): super().__init__("user_function") + @property + def mlir(self): + if not self._mlir and self.mlir_module: + _, mlir_file = self.canonicalize(self.mlir_module) + with open(mlir_file, "r", encoding="utf-8") as f: + self._mlir = f.read() + return self._mlir + + @property + def qir(self): + if not self._qir and self.qir_file: + with open(self.qir_file, "r", encoding="utf-8") as f: + self._qir = f.read() + return self._qir + @debug_logger def __call__(self, *args, **kwargs): # Transparantly call Python function in case of nested QJIT calls. @@ -555,10 +571,10 @@ def aot_compile(self): ) if self.compile_options.target in ("mlir", "binary"): - self.mlir_module, self.mlir = self.generate_ir() + self.mlir_module = self.generate_ir() if self.compile_options.target in ("binary",): - self.compiled_function, self.qir = self.compile() + self.compiled_function, self.qir_file = self.compile() self.fn_cache.insert( self.compiled_function, self.user_sig, self.out_treedef, self.workspace ) @@ -599,8 +615,8 @@ def jit_compile(self, args, **kwargs): args, **kwargs ) - self.mlir_module, self.mlir = self.generate_ir() - self.compiled_function, self.qir = self.compile() + self.mlir_module = self.generate_ir() + self.compiled_function, self.qir_file = self.compile() self.fn_cache.insert(self.compiled_function, args, self.out_treedef, self.workspace) @@ -696,7 +712,20 @@ def closure(qnode, *args, **kwargs): PipelineNameUniquer.reset() return jaxpr, out_type, treedef, dynamic_sig - @instrument(size_from=0, has_finegrained=True) + @debug_logger + def canonicalize(self, mlir_module): + """Canonicalize the mlir_module""" + + # Canonicalize the MLIR since there can be a lot of redundancy coming from JAX. + options = copy.deepcopy(self.compile_options) + options.pipelines = [("0_canonicalize", ["canonicalize"])] + options.lower_to_llvm = False + canonicalizer = Compiler(options) + + # TODO: the in-memory and textual form are different after this, consider unification + return canonicalizer.run(mlir_module, self.workspace) + + @instrument(has_finegrained=True) @debug_logger def generate_ir(self): """Generate Catalyst's intermediate representation (IR) as an MLIR module. @@ -710,18 +739,8 @@ def generate_ir(self): # Inject Runtime Library-specific functions (e.g. setup/teardown). inject_functions(mlir_module, ctx, self.compile_options.seed) - # Canonicalize the MLIR since there can be a lot of redundancy coming from JAX. - options = copy.deepcopy(self.compile_options) - options.pipelines = [("0_canonicalize", ["canonicalize"])] - options.lower_to_llvm = False - canonicalizer = Compiler(options) - - # TODO: the in-memory and textual form are different after this, consider unification - _, mlir_string = canonicalizer.run(mlir_module, self.workspace) - - return mlir_module, mlir_string + return mlir_module - @instrument(size_from=1, has_finegrained=True) @debug_logger def compile(self): """Compile an MLIR module to LLVMIR and shared library code. @@ -746,19 +765,19 @@ def compile(self): # `replace` method, so we need to get a regular Python string out of it. func_name = str(self.mlir_module.body.operations[0].name).replace('"', "") if self.overwrite_ir: - shared_object, llvm_ir = self.compiler.run_from_ir( + shared_object, llvm_ir_file = self.compiler.run_from_ir( self.overwrite_ir, str(self.mlir_module.operation.attributes["sym_name"]).replace('"', ""), self.workspace, ) else: - shared_object, llvm_ir = self.compiler.run(self.mlir_module, self.workspace) + shared_object, llvm_ir_file = self.compiler.run(self.mlir_module, self.workspace) compiled_fn = CompiledFunction( shared_object, func_name, restype, self.out_type, self.compile_options ) - return compiled_fn, llvm_ir + return compiled_fn, llvm_ir_file @instrument(has_finegrained=True) @debug_logger diff --git a/frontend/test/pytest/test_compiler.py b/frontend/test/pytest/test_compiler.py index 0ed0c80fb7..dc8f69d5a0 100644 --- a/frontend/test/pytest/test_compiler.py +++ b/frontend/test/pytest/test_compiler.py @@ -270,6 +270,7 @@ def workflow(): qml.PauliX(wires=0) return qml.state() + workflow.mlir directory = os.path.join(os.getcwd(), workflow.__name__) files = os.listdir(directory) # The directory is non-empty. Should at least contain the original .mlir file diff --git a/frontend/test/pytest/test_debug.py b/frontend/test/pytest/test_debug.py index 14f98905ef..f794244b94 100644 --- a/frontend/test/pytest/test_debug.py +++ b/frontend/test/pytest/test_debug.py @@ -423,12 +423,7 @@ def f(x: float): f(2.0) - with pytest.raises( - CompileError, - match="Attempting to get output for pipeline: mlir, " - "but no file was found.\nAre you sure the file exists?", - ): - get_compilation_stage(f, "mlir") + get_compilation_stage(f, "mlir") @pytest.mark.parametrize( "arg",