From 35229950e731788d9b2b48e0a480c4441dc9f2b4 Mon Sep 17 00:00:00 2001 From: Tristan Rice Date: Fri, 26 Aug 2022 15:52:37 -0700 Subject: [PATCH] multipy/runtime: fix torch.jit.trace (#141) Summary: This fixes torch.jit.trace when used from multiple different python interpreters. There's a global registration of a method to get the Python callstack. This adds a command that after loading `torch._C` it resets it to be a noop function otherwise you end up with cross interpreter Python calls which causes a segfault. This also copies over the lldb script to load the interpreter symbols. Pull Request resolved: https://github.com/pytorch/multipy/pull/141 Test Plan: Enabled torchdynamo w/ ofi backend in test_compat.py ```shell (multipy3.8.6) tristanr@tristanr-arch2 ~/D/multipy (jittrace)> multipy/runtime/build/interactive_embedded_interpreter --pyscript multipy/runtime/test_compat.py Registering torch::deploy builtin library tensorrt (idx 0) with 0 builtin modules torch::deploy builtin tensorrt contains 0 modules Registering torch::deploy builtin library cpython_internal (idx 1) with 0 builtin modules torch::deploy builtin cpython_internal contains 6 modules Registering torch::deploy builtin library tensorrt (idx 0) with 0 builtin modules torch::deploy builtin tensorrt contains 0 modules Registering torch::deploy builtin library cpython_internal (idx 1) with 0 builtin modules torch::deploy builtin cpython_internal contains 6 modules [W OperatorEntry.cpp:133] Warning: Overriding a previously registered kernel for the same operator and the same dispatch key operator: aten::get_gradients(int context_id) -> Dict(Tensor, Tensor) registered at aten/src/ATen/RegisterSchema.cpp:6 dispatch key: (catch all) previous kernel: registered at ../torch/csrc/jit/runtime/register_distributed_ops.cpp:278 new kernel: registered at ../torch/csrc/jit/runtime/register_distributed_ops.cpp:278 (function registerKernel) 
..s../home/tristanr/venvs/multipy3.8.6/lib/python3.8/site-packages/torchvision/io/image.py:13: UserWarning: Failed to load image Pytho n extension: libtorch_python.so: cannot open shared object file: No such file or directory warn(f"Failed to load image Python extension: {e}") . ---------------------------------------------------------------------- Ran 6 tests in 0.663s OK (skipped=1) ``` Reviewed By: anirbanr-fb-r2p Differential Revision: D39005120 Pulled By: d4l3k fbshipit-source-id: f6b71f057cdef2fcd20e8f7f320e99edc65ce471 --- .../runtime/interpreter/interpreter_impl.cpp | 26 ++++--------- multipy/runtime/test_compat.py | 3 +- scripts/runtime_debugger.py | 37 +++++++++++++++++++ 3 files changed, 47 insertions(+), 19 deletions(-) create mode 100644 scripts/runtime_debugger.py diff --git a/multipy/runtime/interpreter/interpreter_impl.cpp b/multipy/runtime/interpreter/interpreter_impl.cpp index b23e56be..34b1bcd7 100644 --- a/multipy/runtime/interpreter/interpreter_impl.cpp +++ b/multipy/runtime/interpreter/interpreter_impl.cpp @@ -20,6 +20,7 @@ #include #include #include +#include #include #include @@ -111,6 +112,10 @@ class MultiPySafeRethrow { const int line_; }; +std::vector<::torch::jit::StackEntry> noPythonCallstack() { + return std::vector<::torch::jit::StackEntry>(); +} + } // namespace const char* start = R"PYTHON( @@ -316,24 +321,6 @@ struct __attribute__((visibility("hidden"))) ConcreteInterpreterImpl getPackage(getPackageArg), objects(objectsArg) {} - explicit ConcreteInterpreterImpl( - const std::vector& extra_python_paths, - const std::vector& plugin_paths) { - ConcreteInterpreterImplConstructorCommon(extra_python_paths, plugin_paths); - - int r = PyRun_SimpleString(start); - TORCH_INTERNAL_ASSERT(r == 0); - - // we cache these so we don't have to repeat the conversion of strings into - // Python and hash table lookups to get to these object - saveStorage = global_impl("multipy.utils._deploy", "_save_storages"); - loadStorage = 
global_impl("multipy.utils._deploy", "_load_storages"); - getPackage = global_impl("multipy.utils._deploy", "_get_package"); - objects = global_impl("multipy.utils._deploy", "_deploy_objects"); - // Release the GIL that PyInitialize acquires - PyEval_SaveThread(); - } - ~ConcreteInterpreterImpl() override { PyGILState_Ensure(); // make sure pybind11 doesn't try to decref after we have destroyed python @@ -571,6 +558,9 @@ newInterpreterImpl( int r = PyRun_SimpleString(start); TORCH_INTERNAL_ASSERT(r == 0); + // disable python callstack for jit tracer + ::torch::jit::tracer::setPythonCallstack(&noPythonCallstack); + py::object saveStorage = global_impl("multipy.utils._deploy", "_save_storages"); py::object loadStorage = diff --git a/multipy/runtime/test_compat.py b/multipy/runtime/test_compat.py index 80c876f3..4120c3b7 100644 --- a/multipy/runtime/test_compat.py +++ b/multipy/runtime/test_compat.py @@ -33,10 +33,11 @@ def fn(x, y): fn(torch.randn(10), torch.randn(10)) - @unittest.skip("ofi segfaults") def test_torchdynamo_ofi(self): import torchdynamo + torchdynamo.reset() + @torchdynamo.optimize("ofi") def fn(x, y): a = torch.cos(x) diff --git a/scripts/runtime_debugger.py b/scripts/runtime_debugger.py new file mode 100644 index 00000000..5a139589 --- /dev/null +++ b/scripts/runtime_debugger.py @@ -0,0 +1,37 @@ +import lldb # type: ignore[import] + +# load into lldb instance with: +# command script import tools/lldb/deploy_debugger.py + +target = lldb.debugger.GetSelectedTarget() +bp = target.BreakpointCreateByRegex("__deploy_register_code") +bp.SetScriptCallbackBody( + """\ +process = frame.thread.GetProcess() +target = process.target +symbol_addr = frame.module.FindSymbol("__deploy_module_info").GetStartAddress() +info_addr = symbol_addr.GetLoadAddress(target) +e = lldb.SBError() +ptr_size = 8 +str_addr = process.ReadPointerFromMemory(info_addr, e) +file_addr = process.ReadPointerFromMemory(info_addr + ptr_size, e) +file_size = 
process.ReadPointerFromMemory(info_addr + 2*ptr_size, e) +load_bias = process.ReadPointerFromMemory(info_addr + 3*ptr_size, e) +name = process.ReadCStringFromMemory(str_addr, 512, e) +r = process.ReadMemory(file_addr, file_size, e) +from tempfile import NamedTemporaryFile +from pathlib import Path +stem = Path(name).stem +with NamedTemporaryFile(prefix=stem, suffix='.so', delete=False) as tf: + tf.write(r) + print("torch_deploy registering debug inforation for ", tf.name) + cmd1 = f"target modules add {tf.name}" + # print(cmd1) + lldb.debugger.HandleCommand(cmd1) + cmd2 = f"target modules load -f {tf.name} -s {hex(load_bias)}" + # print(cmd2) + lldb.debugger.HandleCommand(cmd2) + +return False +""" +)