Skip to content

Commit

Permalink
Dynamically detect cxx11_abi from torch
Browse files Browse the repository at this point in the history
  • Loading branch information
Tim-Salzmann committed Oct 22, 2023
1 parent 7b773a1 commit ee5daa1
Showing 1 changed file with 17 additions and 8 deletions.
25 changes: 17 additions & 8 deletions l4casadi/l4casadi.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import json
import os
import pathlib
import platform
Expand Down Expand Up @@ -136,14 +135,15 @@ def compile_cs_function(self):

# call gcc
soname = 'install_name' if platform.system() == 'Darwin' else 'soname'
cxx11_abi = 1 if torch._C._GLIBCXX_USE_CXX11_ABI else 0
os_cmd = ("gcc"
" -fPIC -shared"
f" {self.build_dir / self.name}.cpp"
f" -o {self.build_dir / f'lib{self.name}'}{dynamic_lib_file_ending()}"
f" -I{include_dir} -L{lib_dir}"
f" -Wl,-{soname},lib{self.name}{dynamic_lib_file_ending()}"
" -ll4casadi -lstdc++ -std=c++17"
" -D_GLIBCXX_USE_CXX11_ABI=0")
f" -D_GLIBCXX_USE_CXX11_ABI={cxx11_abi}")

status = os.system(os_cmd)
if status != 0:
Expand Down Expand Up @@ -173,23 +173,32 @@ def export_torch_traces(self, rows: int, cols: int) -> Tuple[bool, bool]:
torch.jit.trace(self.model, d_inp).save((out_folder / f'{self.name}_forward.pt').as_posix())

jac_model = self._trace_jac_model(d_inp)
hess_model = self._trace_hess_model(d_inp)

hess_model = None
try:
hess_model = self._trace_hess_model(d_inp)
except: # noqa
pass

exported_jacrev = self._jit_compile_and_save(
jac_model,
(out_folder / f'{self.name}_jacrev.pt').as_posix(),
d_inp
)
exported_hess = self._jit_compile_and_save(
hess_model,
(out_folder / f'{self.name}_hess.pt').as_posix(),
d_inp
)
if hess_model is not None:
exported_hess = self._jit_compile_and_save(
hess_model,
(out_folder / f'{self.name}_hess.pt').as_posix(),
d_inp
)
else:
exported_hess = False

return exported_jacrev, exported_hess

@staticmethod
def _jit_compile_and_save(model, file_path: str, dummy_inp: torch.Tensor):
# TODO: Could switch to torch export https://pytorch.org/docs/stable/export.html
# Try tracing
try:
torch.jit.trace(model, dummy_inp).save(file_path)
Expand Down

0 comments on commit ee5daa1

Please sign in to comment.