Skip to content

Commit

Permalink
multipy/runtime: fix torch.jit.trace (#141)
Browse files Browse the repository at this point in the history
Summary:
This fixes torch.jit.trace when used from multiple different Python interpreters. There's a global registration of a method to get the Python callstack. This adds a step that, after loading `torch._C`, resets that method to be a no-op function; otherwise you end up with cross-interpreter Python calls, which cause a segfault.

This also copies over the lldb script to load the interpreter symbols.

Pull Request resolved: #141

Test Plan:
Enabled torchdynamo w/ ofi backend in test_compat.py

```shell
(multipy3.8.6) tristanr@tristanr-arch2 ~/D/multipy (jittrace)> multipy/runtime/build/interactive_embedded_interpreter --pyscript multipy/runtime/test_compat.py
Registering torch::deploy builtin library tensorrt (idx 0) with 0 builtin modules
torch::deploy builtin tensorrt contains 0 modules
Registering torch::deploy builtin library cpython_internal (idx 1) with 0 builtin modules
torch::deploy builtin cpython_internal contains 6 modules
Registering torch::deploy builtin library tensorrt (idx 0) with 0 builtin modules
torch::deploy builtin tensorrt contains 0 modules
Registering torch::deploy builtin library cpython_internal (idx 1) with 0 builtin modules
torch::deploy builtin cpython_internal contains 6 modules
[W OperatorEntry.cpp:133] Warning: Overriding a previously registered kernel for the same operator and the same dispatch key
  operator: aten::get_gradients(int context_id) -> Dict(Tensor, Tensor)
    registered at aten/src/ATen/RegisterSchema.cpp:6
  dispatch key: (catch all)
  previous kernel: registered at ../torch/csrc/jit/runtime/register_distributed_ops.cpp:278
       new kernel: registered at ../torch/csrc/jit/runtime/register_distributed_ops.cpp:278 (function registerKernel)
..s../home/tristanr/venvs/multipy3.8.6/lib/python3.8/site-packages/torchvision/io/image.py:13: UserWarning: Failed to load image Pytho
n extension: libtorch_python.so: cannot open shared object file: No such file or directory
  warn(f"Failed to load image Python extension: {e}")
.
----------------------------------------------------------------------
Ran 6 tests in 0.663s

OK (skipped=1)
```

Reviewed By: anirbanr-fb-r2p

Differential Revision: D39005120

Pulled By: d4l3k

fbshipit-source-id: f6b71f057cdef2fcd20e8f7f320e99edc65ce471
  • Loading branch information
d4l3k authored and facebook-github-bot committed Aug 26, 2022
1 parent ea75273 commit 3522995
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 19 deletions.
26 changes: 8 additions & 18 deletions multipy/runtime/interpreter/interpreter_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <torch/csrc/Dtype.h>
#include <torch/csrc/DynamicTypes.h>
#include <torch/csrc/autograd/generated/variable_factories.h>
#include <torch/csrc/jit/frontend/tracer.h>

#include <cassert>
#include <cstdio>
Expand Down Expand Up @@ -111,6 +112,10 @@ class MultiPySafeRethrow {
const int line_;
};

// No-op callstack provider for the JIT tracer: always reports an empty
// Python stack. Installed via setPythonCallstack so tracing never makes
// cross-interpreter Python calls (see the commit summary: those segfault).
std::vector<::torch::jit::StackEntry> noPythonCallstack() {
  return {};
}

} // namespace

const char* start = R"PYTHON(
Expand Down Expand Up @@ -316,24 +321,6 @@ struct __attribute__((visibility("hidden"))) ConcreteInterpreterImpl
getPackage(getPackageArg),
objects(objectsArg) {}

// Constructs a fully-initialized embedded interpreter: runs the common
// startup (sys.path extension / plugin loading), executes the bootstrap
// Python source held in `start`, and caches the multipy deploy helper
// callables for later use. Releases the GIL before returning so other
// threads can acquire this interpreter.
explicit ConcreteInterpreterImpl(
const std::vector<std::string>& extra_python_paths,
const std::vector<std::string>& plugin_paths) {
// Shared setup used by all constructor overloads.
ConcreteInterpreterImplConstructorCommon(extra_python_paths, plugin_paths);

// Run the embedded bootstrap script; 0 means it executed without error.
int r = PyRun_SimpleString(start);
TORCH_INTERNAL_ASSERT(r == 0);

// we cache these so we don't have to repeat the conversion of strings into
// Python and hash table lookups to get to these objects
saveStorage = global_impl("multipy.utils._deploy", "_save_storages");
loadStorage = global_impl("multipy.utils._deploy", "_load_storages");
getPackage = global_impl("multipy.utils._deploy", "_get_package");
objects = global_impl("multipy.utils._deploy", "_deploy_objects");
// Release the GIL that PyInitialize acquires
PyEval_SaveThread();
}

~ConcreteInterpreterImpl() override {
PyGILState_Ensure();
// make sure pybind11 doesn't try to decref after we have destroyed python
Expand Down Expand Up @@ -571,6 +558,9 @@ newInterpreterImpl(
int r = PyRun_SimpleString(start);
TORCH_INTERNAL_ASSERT(r == 0);

// disable python callstack for jit tracer
::torch::jit::tracer::setPythonCallstack(&noPythonCallstack);

py::object saveStorage =
global_impl("multipy.utils._deploy", "_save_storages");
py::object loadStorage =
Expand Down
3 changes: 2 additions & 1 deletion multipy/runtime/test_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,11 @@ def fn(x, y):

fn(torch.randn(10), torch.randn(10))

@unittest.skip("ofi segfaults")
def test_torchdynamo_ofi(self):
import torchdynamo

torchdynamo.reset()

@torchdynamo.optimize("ofi")
def fn(x, y):
a = torch.cos(x)
Expand Down
37 changes: 37 additions & 0 deletions scripts/runtime_debugger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import lldb  # type: ignore[import]

# Registers an lldb breakpoint that captures debug symbols for
# torch::deploy embedded interpreters as they are loaded.
#
# Load into an lldb instance with:
#   command script import scripts/runtime_debugger.py

target = lldb.debugger.GetSelectedTarget()
# __deploy_register_code is called each time an embedded interpreter
# registers its in-memory shared object; break there to extract it.
bp = target.BreakpointCreateByRegex("__deploy_register_code")
bp.SetScriptCallbackBody(
    """\
process = frame.thread.GetProcess()
target = process.target
# __deploy_module_info is laid out as four pointer-sized fields:
# { name*, file_data*, file_size, load_bias }
symbol_addr = frame.module.FindSymbol("__deploy_module_info").GetStartAddress()
info_addr = symbol_addr.GetLoadAddress(target)
e = lldb.SBError()
ptr_size = 8  # assumes a 64-bit target -- TODO: derive from target.GetAddressByteSize()
str_addr = process.ReadPointerFromMemory(info_addr, e)
file_addr = process.ReadPointerFromMemory(info_addr + ptr_size, e)
file_size = process.ReadPointerFromMemory(info_addr + 2 * ptr_size, e)
load_bias = process.ReadPointerFromMemory(info_addr + 3 * ptr_size, e)
name = process.ReadCStringFromMemory(str_addr, 512, e)
r = process.ReadMemory(file_addr, file_size, e)
from tempfile import NamedTemporaryFile
from pathlib import Path
stem = Path(name).stem
# Dump the ELF image to a temp file so lldb can read symbols from it;
# delete=False because lldb keeps referencing the file after this callback.
with NamedTemporaryFile(prefix=stem, suffix='.so', delete=False) as tf:
    tf.write(r)
    print("torch_deploy registering debug information for ", tf.name)
    lldb.debugger.HandleCommand(f"target modules add {tf.name}")
    # Map the module at the address where it was actually loaded in-process.
    lldb.debugger.HandleCommand(
        f"target modules load -f {tf.name} -s {hex(load_bias)}"
    )
return False
"""
)

0 comments on commit 3522995

Please sign in to comment.