diff --git a/.github/workflows/ci_v2.yaml b/.github/workflows/ci_v2.yaml new file mode 100644 index 0000000..be86ccf --- /dev/null +++ b/.github/workflows/ci_v2.yaml @@ -0,0 +1,92 @@ +name: L4CasADi v2 + +on: + push: + branches: [ v2 ] + +jobs: + lint: + name: Lint + runs-on: ubuntu-latest + timeout-minutes: 5 + steps: + - uses: actions/checkout@v3 + with: + ref: 'v2' + - name: Run mypy + run: | + pip install mypy + mypy . --ignore-missing-imports --exclude examples + - name: Run flake8 + run: | + pip install flake8 + # stop the build if there are Python syntax errors or undefined names + flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics + # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide + flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics + + tests: + runs-on: ${{ matrix.runs-on }} + needs: [ lint ] + timeout-minutes: 60 + strategy: + fail-fast: false + matrix: + runs-on: [ubuntu-latest, ubuntu-20.04, macos-latest] + + name: Tests on ${{ matrix.runs-on }} + steps: + - name: Checkout + uses: actions/checkout@v3 + with: + ref: 'v2' + fetch-depth: 0 + + - name: Install Python + uses: actions/setup-python@v4 + with: + python-version: '>=3.9 <3.12' + + - name: Install L4CasADi + run: | + python -m pip install --upgrade pip + pip install torch --index-url https://download.pytorch.org/whl/cpu # Ensure CPU torch version + pip install -r requirements_build.txt + pip install . -v --no-build-isolation + + - name: Test with pytest + working-directory: ./tests + run: | + pip install pytest + pytest . + + test-on-aarch: + runs-on: ubuntu-latest + needs: [ lint ] + timeout-minutes: 60 + + name: Tests on aarch64 + steps: + - name: Checkout + uses: actions/checkout@v3 + with: + ref: 'v2' + fetch-depth: 0 + - uses: uraimo/run-on-arch-action@v2 + name: Install and Test + with: + arch: aarch64 + distro: ubuntu20.04 + install: | + apt-get update + apt-get install -y --no-install-recommends python3.9 python3-pip python-is-python3 + pip install -U pip + apt-get install -y build-essential + + run: | + python -m pip install --upgrade pip + pip install torch --index-url https://download.pytorch.org/whl/cpu # Ensure CPU torch version + pip install -r requirements_build.txt + pip install . -v --no-build-isolation + # pip install pytest + # pytest . \ No newline at end of file diff --git a/README.md b/README.md index 6db94c1..6c8c8ab 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,6 @@ [![PyPI version](https://badge.fury.io/py/l4casadi.svg)](https://badge.fury.io/py/l4casadi) ![L4CasADi CI](https://github.com/Tim-Salzmann/l4casadi/actions/workflows/ci.yaml/badge.svg) +![L4CasADi v2](https://github.com/Tim-Salzmann/l4casadi/actions/workflows/ci_v2.yaml/badge.svg) ![Downloads](https://img.shields.io/pypi/dm/l4casadi.svg) --- @@ -23,6 +24,22 @@ arXiv: [Learning for CasADi: Data-driven Models in Numerical Optimization](https Talk: [Youtube](https://youtu.be/UYdkRnGr8eM?si=KEPcFEL9b7Vk2juI&t=3348) +## L4CasADi v2 Changelog +After feedback from first use-cases L4CasADi v2 is designed with efficiency and simplicity in mind. This leads to a few breaking changes. + +Breaking changes are +- L4CasADi will not change the shape of an input anymore. The tensor forwarded to the PyTorch model will resemble the exact +dimension of the input variable by CasADi. You are responsible to make sure that the PyTorch model handles a +**two-dimensional** input matrix! +- L4CasADi v2 can leverage PyTorch's batching capabilities for increased efficiency. When passing `batched=True`, +L4CasADi will understand the **first** input dimension as batch dimension. Thus, first and second-order derivatives +across elements of this dimension are assumed to be **sparse-zero**. To make use of this, instead of having multiple calls to a L4CasADi function in +your CasADi program, batch all inputs together and have a single L4CasADi call. An example of this can be seen when +comparing the [non-batched NeRF example](examples/nerf_trajectory_optimization/nerf_trajectory_optimization.py) with the +[batched NeRF example](examples/nerf_trajectory_optimization/nerf_trajectory_optimization_batched.py) which is faster by +a factor of 5-10x. + + ## Table of Content - [Projects using L4CasADi](#projects-using-l4casadi) - [Installation](#installation) @@ -202,14 +219,6 @@ https://github.com/Tim-Salzmann/l4casadi/blob/421de6ef408267eed0fd2519248b2152b6 ## FYIs -### Batch Dimension - -If your PyTorch model expects a batch dimension as first dimension (which most models do) you should pass -`model_expects_batch_dim=True` to the `L4CasADi` constructor. The `MX` input to the L4CasADi component is then expected -to be a vector of shape `[X, 1]`. L4CasADi will add a batch dimension of `1` automatically such that the input to the -underlying PyTorch model is of shape `[1, X]`. - ---- ### Warm Up diff --git a/examples/naive/readme.py b/examples/naive/readme.py index 7a9b1b6..81a43ef 100644 --- a/examples/naive/readme.py +++ b/examples/naive/readme.py @@ -3,15 +3,15 @@ naive_mlp = l4c.naive.MultiLayerPerceptron(2, 128, 1, 2, 'Tanh') -l4c_model = l4c.L4CasADi(naive_mlp, model_expects_batch_dim=True) +l4c_model = l4c.L4CasADi(naive_mlp) -x_sym = cs.MX.sym('x', 2, 1) +x_sym = cs.MX.sym('x', 1, 2) y_sym = l4c_model(x_sym) f = cs.Function('y', [x_sym], [y_sym]) df = cs.Function('dy', [x_sym], [cs.jacobian(y_sym, x_sym)]) ddf = cs.Function('ddy', [x_sym], [cs.hessian(y_sym, x_sym)[0]]) -x = cs.DM([[0.], [2.]]) +x = cs.DM([[0., 2.]]) print(l4c_model(x)) print(f(x)) print(df(x)) diff --git a/examples/nerf_trajectory_optimization/nerf_trajectory_optimization.py b/examples/nerf_trajectory_optimization/nerf_trajectory_optimization.py index 3fd11f0..f1172fb 100644 --- a/examples/nerf_trajectory_optimization/nerf_trajectory_optimization.py +++ b/examples/nerf_trajectory_optimization/nerf_trajectory_optimization.py @@ -86,7 +86,7 @@ def trajectory_generator_solver(n, n_eval, L, warmup, threshold): f += cs.sum2(sk**2) # While having a maximum density (1.) of the NeRF as constraint. - lk = L(pk.T) + lk = L(pk) g = cs.horzcat(g, lk) lbg = cs.horzcat(lbg, cs.DM([-10e32]).T) ubg = cs.horzcat(ubg, cs.DM([threshold]).T) diff --git a/l4casadi/l4casadi.py b/l4casadi/l4casadi.py index 86d271a..9188a90 100644 --- a/l4casadi/l4casadi.py +++ b/l4casadi/l4casadi.py @@ -3,20 +3,22 @@ import platform import shutil import time +import warnings try: from importlib.resources import files except ImportError: - from importlib_resources import files # type: ignore[no-redef] + from importlib_resources import files # type: ignore[no-redef] from typing import Union, Optional, Callable, Text, Tuple import casadi as cs import torch + try: - from torch.func import jacrev, jacfwd, functionalize + from torch.func import jacrev, jacfwd, functionalize, vjp except ImportError: - from functorch import jacrev, jacfwd, functionalize + from functorch import jacrev, jacfwd, functionalize, vjp from l4casadi.ts_compiler import ts_compile from torch.fx.experimental.proxy_tensor import make_fx @@ -37,18 +39,20 @@ def dynamic_lib_file_ending(): class L4CasADi(object): def __init__(self, model: Callable[[torch.Tensor], torch.Tensor], - batched: bool = True, + batched: bool = False, device: Union[torch.device, Text] = 'cpu', name: Text = 'l4casadi_f', build_dir: Text = './_l4c_generated', model_search_path: Optional[Text] = None, - with_jacobian: bool = True, - with_hessian: bool = True, + generate_jac: bool = True, + generate_adj1: bool = True, + generate_jac_adj1: bool = True, + generate_jac_jac: bool = False, mutable: bool = False): """ :param model: PyTorch model. - :param model_expects_batch_dim: True if the PyTorch model expects a batch dimension. This is commonly True - for trained PyTorch models. + :param batched: If True, the first dimension of the two expected input dimension is assumed to be a batch + dimension. This can lead to speedups as sensitivities across this dimension can be neglected. :param device: Device on which the PyTorch model is executed. :param name: Unique name of the generated L4CasADi model. This name is used for autogenerated files. Creating two L4CasADi models with the same name will result in overwriting the files of the first model. @@ -56,10 +60,16 @@ def __init__(self, the absolute path to the `build_dir` where the model traces are exported to. This parameter can become useful if the created L4CasADi dynamic library and the exported PyTorch Models are expected to be moved to a different folder (or another device). - :param with_jacobian: If True, the Jacobian of the model is exported. - :param with_hessian: If True, the Hessian of the model is exported. + :param build_dir: Directory where the L4CasADi library is built. + :param generate_jac: If True, the Jacobian of the model is tried to be generated. + :param generate_adj1: If True, the Adjoint of the model is tried to be generated. + :param generate_jac_adj1: If True, the Jacobain of the Adjoint of the model is tried to be generated. + :param generate_jac_jac: If True, the Hessian of the model is tried to be generated. :param mutable: If True, enables updating the model online via the update method. """ + if platform.system() == "Windows": + warnings.warn("L4CasADi is currently not supported for Windows.") + self.model = model self.naive = False if isinstance(self.model, NaiveL4CasADiModule): @@ -79,12 +89,15 @@ def __init__(self, self._cs_fun: Optional[cs.Function] = None self._built = False - self._with_jacobian = with_jacobian - self._with_hessian = with_hessian + self._generate_jac = generate_jac + self._generate_adj1 = generate_adj1 + self._generate_jac_adj1 = generate_jac_adj1 + self._generate_jac_jac = generate_jac_jac self._mutable = mutable - self._input_shape: Optional[Tuple[int, int]] = None + self._input_shape: Tuple[int, int] = (-1, -1) + self._output_shape: Tuple[int, int] = (-1, -1) def update(self, model: Optional[Callable[[torch.Tensor], torch.Tensor]] = None) -> None: """ @@ -105,7 +118,7 @@ def update(self, model: Optional[Callable[[torch.Tensor], torch.Tensor]] = None) for parameters in self.model.parameters(): parameters.requires_grad = False - self.export_torch_traces(*self._input_shape) # type: ignore[misc] + self.export_torch_traces() # type: ignore[misc] time.sleep(0.2) @@ -167,16 +180,36 @@ def build(self, inp: Union[cs.MX, cs.SX, cs.DM]) -> None: self._built = True + def _verify_input_output(self): + if len(self._output_shape) != 2: + raise ValueError(f"""L4CasADi requires the model output to be a matrix (2 dimensions) but has + {len(self._output_shape)} dimensions. Please add a extra dimension of size 1. + For models which expects a batch dimension, the output should be a matrix of [1, d].""") + + if self.batched: + if self._input_shape[0] != self._output_shape[0]: + raise ValueError(f"""When the model is batched the first dimension of input and output (batch dimension) + has to be the same.""") + def generate(self, inp: Union[cs.MX, cs.SX, cs.DM]) -> None: - rows, cols = inp.shape # type: ignore[attr-defined] - has_jac, has_hess = self.export_torch_traces(rows, cols) - if not has_jac and self._with_jacobian: - print('Jacobian trace could not be generated.' - ' First-order sensitivities will not be available in CasADi.') - if not has_hess and self._with_hessian: - print('Hessian trace could not be generated.' - ' Second-order sensitivities will not be available in CasADi.') - self._generate_cpp_function_template(rows, cols, has_jac, has_hess) + self._input_shape = inp.shape # type: ignore[attr-defined] + self._output_shape = self.model(torch.zeros(*self._input_shape).to(self.device)).shape + self._verify_input_output() + + has_jac, has_adj1, has_jac_adj1, has_jac_jac = self.export_torch_traces() + if not has_jac and self._generate_jac: + warnings.warn('Jacobian trace could not be generated.' + ' First-order sensitivities will not be available in CasADi.') + if not has_adj1 and self._generate_adj1: + warnings.warn('Adjoint trace could not be generated.' + ' First-order sensitivities will not be available in CasADi.') + if not has_jac_adj1 and self._generate_jac_adj1: + warnings.warn('Jacobian Adjoint trace could not be generated.' + ' Second-order sensitivities will not be available in CasADi.') + if not has_jac_jac and self._generate_jac_jac: + warnings.warn('Hessian trace could not be generated.' + ' Second-order sensitivities will not be available in CasADi.') + self._generate_cpp_function_template(has_jac, has_adj1, has_jac_adj1, has_jac_jac) def _load_built_library_as_external_cs_fun(self): if not self._built: @@ -187,7 +220,7 @@ def _load_built_library_as_external_cs_fun(self): ) @staticmethod - def generate_batched_output_ccs(batch_size, input_size, output_size): + def generate_block_diagonal_ccs(batch_size, input_size, output_size): """ https://de.wikipedia.org/wiki/Harwell-Boeing-Format :param batch_size: Size of batch dimension. @@ -221,50 +254,44 @@ def generate_batched_output_ccs(batch_size, input_size, output_size): hess_ccs = [batch_size * output_size * batch_size * input_size, batch_size * input_size] + col_ptr + row_ind - hess_ccs2 = [batch_size * output_size * batch_size * input_size, batch_size * output_size] + [0] * (batch_size * output_size + 1) + return jac_ccs, hess_ccs - return jac_ccs, hess_ccs, hess_ccs2 - - def _generate_cpp_function_template(self, rows: int, cols: int, has_jac: bool, has_hess: bool): - out_shape = self.model(torch.zeros(rows, cols).to(self.device)).shape - - if len(out_shape) != 2: - raise ValueError(f"""L4CasADi requires the model output to be a matrix (2 dimensions) but has - {len(out_shape)} dimensions. Please add a extra dimension of size 1. - For models which expects a batch dimension, the output should be a matrix of [1, d].""") - - rows_out, cols_out = out_shape + def _generate_cpp_function_template(self, has_jac: bool, has_adj1: bool, has_jac_adj1: bool, has_jac_jac: bool): model_path = (self.build_dir.absolute().as_posix() if self._model_search_path is None else self._model_search_path) if self.batched: - if rows != rows: - raise ValueError(f"""When the model is batched the first dimension of input and output (batch dimension) - has to be the same.""") - jac_ccs, hess_ccs, hess_ccs2 = self.generate_batched_output_ccs(rows, cols, cols_out) + jac_ccs, jac_jac_ccs = self.generate_block_diagonal_ccs(self._input_shape[0], + self._input_shape[1], + self._output_shape[1]) + jac_adj_css, _ = self.generate_block_diagonal_ccs(self._input_shape[0], + self._input_shape[1], + self._input_shape[1]) else: - jac_ccs, hess_ccs, hess_ccs2 = None, None, None + jac_ccs, jac_adj_css, jac_jac_ccs = None, None, None gen_params = { 'model_path': model_path, 'device': self.device, 'name': self.name, - 'rows_in': rows, - 'cols_in': cols, - 'rows_out': rows_out, - 'cols_out': cols_out, + 'rows_in': self._input_shape[0], + 'cols_in': self._input_shape[1], + 'rows_out': self._output_shape[0], + 'cols_out': self._output_shape[1], 'has_jac': 'true' if has_jac else 'false', - 'has_hess': 'true' if has_hess else 'false', + 'has_adj1': 'true' if has_adj1 else 'false', + 'has_jac_adj1': 'true' if has_jac_adj1 else 'false', + 'has_jac_jac': 'true' if has_jac_jac else 'false', 'model_is_mutable': 'true' if self._mutable else 'false', 'batched': 'true' if self.batched else 'false', 'jac_ccs_len': len(jac_ccs) if self.batched else 0, 'jac_ccs': ', '.join(str(e) for e in jac_ccs) if self.batched else '', - 'hess_ccs_len': len(hess_ccs) if self.batched else 0, - 'hess_ccs': ', '.join(str(e) for e in hess_ccs) if self.batched else '', - 'hess_ccs2_len': len(hess_ccs2) if self.batched else 0, - 'hess_ccs2': ', '.join(str(e) for e in hess_ccs2) if self.batched else '', + 'jac_adj_ccs_len': len(jac_adj_css) if self.batched else 0, + 'jac_adj_ccs': ', '.join(str(e) for e in jac_adj_css) if self.batched else '', + 'jac_jac_ccs_len': len(jac_jac_ccs) if self.batched else 0, + 'jac_jac_ccs': ', '.join(str(e) for e in jac_jac_ccs) if self.batched else '', } render_casadi_c_template( @@ -298,43 +325,87 @@ def compile(self): def _trace_jac_model(self, inp): if self.batched: def with_batch_dim(x): - return torch.func.vmap(jacrev(self.model))(x[:, None])[:, 0].permute(3, 2, 0, 1) + return torch.func.vmap(jacrev(self.model))(x[:, None])[:, 0].permute(1, 0, 2, 3) + return make_fx(functionalize(with_batch_dim, remove='mutations_and_views'))(inp) return make_fx(functionalize(jacrev(self.model), remove='mutations_and_views'))(inp) + def _trace_adj1_model(self): + p_d = torch.zeros(self._input_shape).to(self.device) + t_d = torch.zeros(self._output_shape).to(self.device) + + def _vjp(p, x): + return vjp(self.model, p)[1](x)[0] + + return make_fx(functionalize(_vjp, remove='mutations_and_views'))(p_d, t_d) + + def _trace_jac_adj1_model(self): + p_d = torch.zeros(self._input_shape).to(self.device) + t_d = torch.zeros(self._output_shape).to(self.device) + + def _vjp(p, x): + return vjp(self.model, p)[1](x)[0] + + # TODO: replace jacfwd with jacref depending on answer in https://github.com/pytorch/pytorch/issues/130735 + if self.batched: + def with_batch_dim(p, x): + return torch.func.vmap(jacfwd(_vjp))(p[:, None], x[:, None])[:, 0].permute(3, 2, 0, 1) + + return make_fx(functionalize(with_batch_dim, remove='mutations_and_views'))(p_d, t_d) + return make_fx(functionalize(jacfwd(_vjp), remove='mutations_and_views'))(p_d, t_d) + def _trace_hess_model(self, inp): if self.batched: def with_batch_dim(x): + # Permutation is trial and error return torch.func.vmap(jacrev(jacrev(self.model)))(x[:, None])[:, 0].permute(1, 3, 2, 0, 4, 5) + return make_fx(functionalize(with_batch_dim, remove='mutations_and_views'))(inp) return make_fx(functionalize(jacrev(jacrev(self.model)), remove='mutations_and_views'))(inp) - def export_torch_traces(self, rows: int, cols: int) -> Tuple[bool, bool]: - d_inp = torch.zeros((rows, cols)) - - # Save input shape for online update. - self._input_shape = (rows, cols) - + def export_torch_traces(self) -> Tuple[bool, bool, bool, bool]: + d_inp = torch.zeros(self._input_shape) d_inp = d_inp.to(self.device) + d_out = torch.zeros(self._output_shape) + d_out = d_out.to(self.device) + out_folder = self.build_dir self._jit_compile_and_save(make_fx(functionalize(self.model, remove='mutations_and_views'))(d_inp), - (out_folder / f'{self.name}_forward.pt').as_posix(), - d_inp) + (out_folder / f'{self.name}.pt').as_posix(), + (d_inp,)) - exported_jacrev = False - if self._with_jacobian: + exported_jac = False + if self._generate_jac: jac_model = self._trace_jac_model(d_inp) - exported_jacrev = self._jit_compile_and_save( + exported_jac = self._jit_compile_and_save( jac_model, - (out_folder / f'{self.name}_jacrev.pt').as_posix(), - d_inp + (out_folder / f'jac_{self.name}.pt').as_posix(), + (d_inp,) + ) + + exported_adj1 = False + if self._generate_adj1: + adj1_model = self._trace_adj1_model() + exported_adj1 = self._jit_compile_and_save( + adj1_model, + (out_folder / f'adj1_{self.name}.pt').as_posix(), + (d_inp, d_out) + ) + + exported_jac_adj1 = False + if self._generate_jac_adj1: + jac_adj1_model = self._trace_jac_adj1_model() + exported_jac_adj1 = self._jit_compile_and_save( + jac_adj1_model, + (out_folder / f'jac_adj1_{self.name}.pt').as_posix(), + (d_inp, d_out) ) exported_hess = False - if self._with_hessian: + if self._generate_jac_jac: hess_model = None try: hess_model = self._trace_hess_model(d_inp) @@ -344,11 +415,11 @@ def export_torch_traces(self, rows: int, cols: int) -> Tuple[bool, bool]: if hess_model is not None: exported_hess = self._jit_compile_and_save( hess_model, - (out_folder / f'{self.name}_hess.pt').as_posix(), - d_inp + (out_folder / f'jac_jac_{self.name}.pt').as_posix(), + (d_inp,) ) - return exported_jacrev, exported_hess + return exported_jac, exported_adj1, exported_jac_adj1, exported_hess @staticmethod def _jit_compile_and_save(model, file_path: str, dummy_inp: torch.Tensor): diff --git a/l4casadi/naive/nn/linear.py b/l4casadi/naive/nn/linear.py index 27c9020..c2acede 100644 --- a/l4casadi/naive/nn/linear.py +++ b/l4casadi/naive/nn/linear.py @@ -6,8 +6,7 @@ class Linear(NaiveL4CasADiModule, torch.nn.Linear): def cs_forward(self, x): - assert x.shape[1] == 1, 'Casadi can not handle batches.' - y = cs.mtimes(self.weight.detach().numpy(), x) + y = cs.mtimes(x, self.weight.transpose(1, 0).detach().numpy()) if self.bias is not None: - y = y + self.bias.detach().numpy() + y = y + self.bias[None].detach().numpy() return y diff --git a/l4casadi/realtime/realtime_l4casadi.py b/l4casadi/realtime/realtime_l4casadi.py index 84e644f..984fd0f 100644 --- a/l4casadi/realtime/realtime_l4casadi.py +++ b/l4casadi/realtime/realtime_l4casadi.py @@ -22,7 +22,7 @@ def __init__(self, :param name: Unique name of the generated L4CasADi model. This name is used for autogenerated files. Creating two L4CasADi models with the same name will result in overwriting the files of the first model. """ - super().__init__(model, model_expects_batch_dim=True, device=device, name=name) + super().__init__(model, device=device, name=name) if approximation_order > 2 or approximation_order < 1: raise ValueError("Taylor approximation order must be 1 or 2.") diff --git a/l4casadi/template_generation/templates/casadi_function.in.cpp b/l4casadi/template_generation/templates/casadi_function.in.cpp index f22a379..f841bfa 100644 --- a/l4casadi/template_generation/templates/casadi_function.in.cpp +++ b/l4casadi/template_generation/templates/casadi_function.in.cpp @@ -1,6 +1,6 @@ #include -L4CasADi l4casadi("{{ model_path }}", "{{ name }}", "{{ device }}", {{ has_jac }}, {{ has_hess }}, {{ model_is_mutable }}); +L4CasADi l4casadi("{{ model_path }}", "{{ name }}", {{ rows_in }}, {{ cols_in }}, {{ rows_out }}, {{ cols_out }}, "{{ device }}", {{ has_jac }}, {{ has_adj1 }}, {{ has_jac_adj1 }}, {{ has_jac_jac }}, {{ model_is_mutable }}); #ifdef __cplusplus extern "C" { @@ -31,75 +31,143 @@ extern "C" { #endif #endif +// Function {{ name }} static const casadi_int {{ name }}_s_in0[3] = { {{ rows_in }}, {{ cols_in }}, 1}; static const casadi_int {{ name }}_s_out0[3] = { {{ rows_out }}, {{ cols_out }}, 1}; -{% if has_jac == "true" and batched %} -static const casadi_int jac_{{ name }}_s_out0[{{jac_ccs_len}}] = { {{ jac_ccs }}}; -CASADI_SYMBOL_EXPORT const casadi_int* jac_{{ name }}_sparsity_out(casadi_int i) { +// Only single input, single output is supported at the moment +CASADI_SYMBOL_EXPORT casadi_int {{ name }}_n_in(void) { return 1;} +CASADI_SYMBOL_EXPORT casadi_int {{ name }}_n_out(void) { return 1;} + +CASADI_SYMBOL_EXPORT const casadi_int* {{ name }}_sparsity_in(casadi_int i) { switch (i) { - case 0: return jac_{{ name }}_s_out0; + case 0: return {{ name }}_s_in0; default: return 0; } } -CASADI_SYMBOL_EXPORT casadi_int jac_{{ name }}_n_in(void) { return 2;} -CASADI_SYMBOL_EXPORT casadi_int jac_{{ name }}_n_out(void) { return 1;} -{% endif %} - -{% if has_hess == "true" and batched %} -static const casadi_int jac_jac_{{ name }}_s_out0[{{hess_ccs_len}}] = { {{ hess_ccs }}}; -static const casadi_int jac_jac_{{ name }}_s_out1[{{hess_ccs2_len}}] = { {{ hess_ccs2 }}}; -CASADI_SYMBOL_EXPORT const casadi_int* jac_jac_{{ name }}_sparsity_out(casadi_int i) { +CASADI_SYMBOL_EXPORT const casadi_int* {{ name }}_sparsity_out(casadi_int i) { switch (i) { - case 0: return jac_jac_{{ name }}_s_out0; - case 1: return jac_jac_{{ name }}_s_out1; + case 0: return {{ name }}_s_out0; default: return 0; } } -CASADI_SYMBOL_EXPORT casadi_int jac_jac_{{ name }}_n_in(void) { return 3;} - -CASADI_SYMBOL_EXPORT casadi_int jac_jac_{{ name }}_n_out(void) { return 2;} -{% endif %} CASADI_SYMBOL_EXPORT int {{ name }}(const casadi_real** arg, casadi_real** res, casadi_int* iw, casadi_real* w, int mem){ - l4casadi.forward(arg[0], {{ rows_in }}, {{ cols_in }}, res[0]); + l4casadi.forward(arg[0], res[0]); return 0; } {% if has_jac == "true" %} +// Jacobian {{ name }} + +CASADI_SYMBOL_EXPORT casadi_int jac_{{ name }}_n_in(void) { return 2;} +CASADI_SYMBOL_EXPORT casadi_int jac_{{ name }}_n_out(void) { return 1;} + CASADI_SYMBOL_EXPORT int jac_{{ name }}(const casadi_real** arg, casadi_real** res, casadi_int* iw, casadi_real* w, int mem){ - l4casadi.jac(arg[0], {{ rows_in }}, {{ cols_in }}, res[0]); + l4casadi.jac(arg[0], res[0]); return 0; } + +{% if batched == "true" %} +// Sparse output if batched. +static const casadi_int jac_{{ name }}_s_out0[{{jac_ccs_len}}] = { {{ jac_ccs }}}; + +CASADI_SYMBOL_EXPORT const casadi_int* jac_{{ name }}_sparsity_out(casadi_int i) { + switch (i) { + case 0: return jac_{{ name }}_s_out0; + default: return 0; + } +} +{% endif %} {% endif %} -{% if has_hess == "true" %} -CASADI_SYMBOL_EXPORT int jac_jac_{{ name }}(const casadi_real** arg, casadi_real** res, casadi_int* iw, casadi_real* w, int mem){ - l4casadi.hess(arg[0], {{ rows_in }}, {{ cols_in }}, res[0]); + +{% if has_adj1 == "true" %} +// adj1 {{ name }} + +CASADI_SYMBOL_EXPORT casadi_int adj1_{{ name }}_n_in(void) { return 3;} +CASADI_SYMBOL_EXPORT casadi_int adj1_{{ name }}_n_out(void) { return 1;} + +CASADI_SYMBOL_EXPORT int adj1_{{ name }}(const casadi_real** arg, casadi_real** res, casadi_int* iw, casadi_real* w, int mem){ + // adj1 [i0, out_o0, adj_o0] -> [out_adj_i0] + l4casadi.adj1(arg[0], arg[2], res[0]); return 0; } {% endif %} -// Only single input, single output is supported at the moment -CASADI_SYMBOL_EXPORT casadi_int {{ name }}_n_in(void) { return 1;} -CASADI_SYMBOL_EXPORT casadi_int {{ name }}_n_out(void) { return 1;} +{% if has_jac_adj1 == "true" %} +// jac_adj1 {{ name }} -CASADI_SYMBOL_EXPORT const casadi_int* {{ name }}_sparsity_in(casadi_int i) { +CASADI_SYMBOL_EXPORT casadi_int jac_adj1_{{ name }}_n_in(void) { return 4;} +CASADI_SYMBOL_EXPORT casadi_int jac_adj1_{{ name }}_n_out(void) { return 3;} + +CASADI_SYMBOL_EXPORT int jac_adj1_{{ name }}(const casadi_real** arg, casadi_real** res, casadi_int* iw, casadi_real* w, int mem){ + // jac_adj1 [i0, out_o0, adj_o0, out_adj_i0] -> [jac_adj_i0_i0, jac_adj_i0_out_o0, jac_adj_i0_adj_o0] + if (res[1] != NULL) { + l4casadi.invalid_argument("jac_adj_i0_out_o0 is not provided by L4CasADi. If you need this feature, please contact the L4CasADi developer."); + } + if (res[2] != NULL) { + l4casadi.invalid_argument("jac_adj_i0_adj_o0 is not provided by L4CasADi. If you need this feature, please contact the L4CasADi developer."); + } + if (res[0] == NULL) { + l4casadi.invalid_argument("L4CasADi can only provide jac_adj_i0_i0 for jac_adj1_{{ name }} function. If you need this feature, please contact the L4CasADi developer."); + } + l4casadi.jac_adj1(arg[0], arg[2], res[0]); + return 0; +} + +{% if batched == "true" %} +// Sparse output if batched. +static const casadi_int jac_adj1_{{ name }}_s_out0[{{jac_adj_ccs_len}}] = { {{ jac_adj_ccs }}}; +static const casadi_int jac_adj1_{{ name }}_s_out23[3] = { {{ rows_in }} * {{ cols_in }}, {{ rows_out }} * {{ cols_out }}, 1}; + +CASADI_SYMBOL_EXPORT const casadi_int* jac_adj1_{{ name }}_sparsity_out(casadi_int i) { switch (i) { - case 0: return {{ name }}_s_in0; + case 0: return jac_adj1_{{ name }}_s_out0; + case 1: return jac_adj1_{{ name }}_s_out23; + case 2: return jac_adj1_{{ name }}_s_out23; default: return 0; } } +{% endif %} +{% endif %} -CASADI_SYMBOL_EXPORT const casadi_int* {{ name }}_sparsity_out(casadi_int i) { + +{% if has_jac_jac == "true" %} +// jac_jac {{ name }} + +CASADI_SYMBOL_EXPORT int jac_jac_{{ name }}(const casadi_real** arg, casadi_real** res, casadi_int* iw, casadi_real* w, int mem){ + // [i0, out_o0, out_jac_o0_i0] -> [jac_jac_o0_i0_i0, jac_jac_o0_i0_out_o0] + if (res[1] != NULL) { + l4casadi.invalid_argument("jac_jac_o0_i0_out_o0 is not provided by L4CasADi. If you need this feature, please contact the L4CasADi developer."); + } + if (res[0] == NULL) { + l4casadi.invalid_argument("L4CasADi can only provide jac_jac_o0_i0_i0 for jac_jac_{{ name }} function. If you need this feature, please contact the L4CasADi developer."); + } + l4casadi.jac_jac(arg[0], res[0]); + return 0; +} + +{% if batched == "true" %} +// jac_jac {{ name }} + +static const casadi_int jac_jac_{{ name }}_s_out0[{{jac_jac_ccs_len}}] = { {{ jac_jac_ccs }}}; +static const casadi_int jac_jac_{{ name }}_s_out1[3] = { {{ rows_in }} * {{ cols_in }} * {{ rows_out }} * {{ cols_out }}, {{ rows_out }} * {{ cols_out }}, 1}; +CASADI_SYMBOL_EXPORT const casadi_int* jac_jac_{{ name }}_sparsity_out(casadi_int i) { switch (i) { - case 0: return {{ name }}_s_out0; + case 0: return jac_jac_{{ name }}_s_out0; + case 1: return jac_jac_{{ name }}_s_out1; default: return 0; } } +CASADI_SYMBOL_EXPORT casadi_int jac_jac_{{ name }}_n_in(void) { return 3;} + +CASADI_SYMBOL_EXPORT casadi_int jac_jac_{{ name }}_n_out(void) { return 2;} +{% endif %} +{% endif %} #ifdef __cplusplus } /* extern "C" */ diff --git a/libl4casadi/include/l4casadi.hpp b/libl4casadi/include/l4casadi.hpp index ccdbb5b..a0d277d 100644 --- a/libl4casadi/include/l4casadi.hpp +++ b/libl4casadi/include/l4casadi.hpp @@ -7,13 +7,22 @@ class L4CasADi { private: - bool model_expects_batch_dim; + int rows_in; + int cols_in; + + int rows_out; + int cols_out; public: - L4CasADi(std::string, std::string, std::string = "cpu", bool = false, bool = false, bool = false); + L4CasADi(std::string, std::string, int, int, int, int, std::string = "cpu", bool = false, bool = false, bool = false, bool = false, + bool = false); ~L4CasADi(); - void forward(const double*, int, int, double*); - void jac(const double*, int, int, double*); - void hess(const double*, int, int, double*); + void forward(const double*, double*); + void jac(const double*, double*); + void adj1(const double*, const double*, double*); + void jac_adj1(const double*, const double*, double*); + void jac_jac(const double*, double*); + + void invalid_argument(std::string); // PImpl Idiom class L4CasADiImpl; diff --git a/libl4casadi/src/l4casadi.cpp b/libl4casadi/src/l4casadi.cpp index 30f54d8..07782a2 100644 --- a/libl4casadi/src/l4casadi.cpp +++ b/libl4casadi/src/l4casadi.cpp @@ -1,4 +1,5 @@ #include +#include #include #include @@ -15,10 +16,14 @@ class L4CasADi::L4CasADiImpl std::string model_prefix; bool has_jac; + bool has_adj1; + bool has_jac_adj1; bool has_hess; + torch::jit::script::Module adj1_model; torch::jit::script::Module forward_model; torch::jit::script::Module jac_model; + torch::jit::script::Module jac_adj1_model; torch::jit::script::Module hess_model; torch::Device device; @@ -28,9 +33,10 @@ class L4CasADi::L4CasADiImpl std::atomic reload_model_loop_running = false; public: - L4CasADiImpl(std::string model_path, std::string model_prefix, std::string device, bool has_jac, bool has_hess, - bool model_is_mutable): device{torch::kCPU}, model_path{model_path}, model_prefix{model_prefix}, - has_jac{has_jac}, has_hess{has_hess} { + L4CasADiImpl(std::string model_path, std::string model_prefix, std::string device, bool has_jac, bool has_adj1, + bool has_jac_adj1, bool has_hess, bool model_is_mutable): device{torch::kCPU}, model_path{model_path}, + model_prefix{model_prefix}, has_jac{has_jac}, has_adj1{has_adj1}, has_jac_adj1{has_jac_adj1}, + has_hess{has_hess} { if (torch::cuda::is_available() && device.compare("cpu")) { std::cout << "CUDA is available! Using GPU " << device << "." << std::endl; @@ -76,14 +82,30 @@ class L4CasADi::L4CasADiImpl void load_model_from_disk() { std::filesystem::path dir (this->model_path); - std::filesystem::path forward_model_file (this->model_prefix + "_forward.pt"); + std::filesystem::path forward_model_file (this->model_prefix + ".pt"); this->forward_model = torch::jit::load((dir / forward_model_file).generic_string()); this->forward_model.to(this->device); this->forward_model.eval(); this->forward_model = torch::jit::optimize_for_inference(this->forward_model); + if (this->has_adj1) { + std::filesystem::path adj1_model_file ("adj1_" + this->model_prefix + ".pt"); + this->adj1_model = torch::jit::load((dir / adj1_model_file).generic_string()); + this->adj1_model.to(this->device); + this->adj1_model.eval(); + this->adj1_model = torch::jit::optimize_for_inference(this->adj1_model); + } + + if (this->has_jac_adj1) { + std::filesystem::path jac_adj1_model_file ("jac_adj1_" + this->model_prefix + ".pt"); + this->jac_adj1_model = torch::jit::load((dir / jac_adj1_model_file).generic_string()); + this->jac_adj1_model.to(this->device); + this->jac_adj1_model.eval(); + this->jac_adj1_model = torch::jit::optimize_for_inference(this->jac_adj1_model); + } + if (this->has_jac) { - std::filesystem::path jac_model_file (this->model_prefix + "_jacrev.pt"); + std::filesystem::path jac_model_file ("jac_" + this->model_prefix + ".pt"); this->jac_model = torch::jit::load((dir / jac_model_file).generic_string()); this->jac_model.to(this->device); this->jac_model.eval(); @@ -91,7 +113,7 @@ class L4CasADi::L4CasADiImpl } if (this->has_hess) { - std::filesystem::path hess_model_file (this->model_prefix + "_hess.pt"); + std::filesystem::path hess_model_file ("jac_jac_" + this->model_prefix + ".pt"); this->hess_model = torch::jit::load((dir / hess_model_file).generic_string()); this->hess_model.to(this->device); this->hess_model.eval(); @@ -99,59 +121,106 @@ class L4CasADi::L4CasADiImpl } } - torch::Tensor forward(torch::Tensor input) { + torch::Tensor forward(torch::Tensor x) { std::unique_lock lock(this->model_update_mutex); c10::InferenceMode guard; std::vector inputs; - inputs.push_back(input.to(this->device)); + inputs.push_back(x.to(this->device)); return this->forward_model.forward(inputs).toTensor().to(cpu); } - torch::Tensor jac(torch::Tensor input) { + torch::Tensor jac(torch::Tensor x) { std::unique_lock lock(this->model_update_mutex); c10::InferenceMode guard; std::vector inputs; - inputs.push_back(input.to(this->device)); + inputs.push_back(x.to(this->device)); return this->jac_model.forward(inputs).toTensor().to(cpu); } - torch::Tensor hess(torch::Tensor input) { + torch::Tensor adj1(torch::Tensor primal, torch::Tensor tangent) { + std::unique_lock lock(this->model_update_mutex); + c10::InferenceMode guard; + std::vector inputs; + inputs.push_back(primal.to(this->device)); + inputs.push_back(tangent.to(this->device)); + + return this->adj1_model.forward(inputs).toTensor().to(cpu); + } + + torch::Tensor jac_adj1(torch::Tensor primal, torch::Tensor tangent){ + std::unique_lock lock(this->model_update_mutex); + c10::InferenceMode guard; + std::vector inputs; + inputs.push_back(primal.to(this->device)); + inputs.push_back(tangent.to(this->device)); + + return this->jac_adj1_model.forward(inputs).toTensor().to(cpu); + } + + torch::Tensor hess(torch::Tensor x) { std::unique_lock lock(this->model_update_mutex); c10::InferenceMode guard; std::vector inputs; - inputs.push_back(input.to(this->device)); + inputs.push_back(x.to(this->device)); return this->hess_model.forward(inputs).toTensor().to(cpu); } }; -L4CasADi::L4CasADi(std::string model_path, std::string model_prefix, std::string device, - bool has_jac, bool has_hess, bool model_is_mutable): - pImpl{std::make_unique(model_path, model_prefix, device, has_jac, has_hess, model_is_mutable)} {} +L4CasADi::L4CasADi(std::string model_path, std::string model_prefix, int rows_in, int cols_in, int rows_out, int cols_out, + std::string device, bool has_jac, bool has_adj1, bool has_jac_adj1, bool has_hess, bool model_is_mutable): + pImpl{std::make_unique(model_path, model_prefix, device, has_jac, has_adj1, has_jac_adj1, has_hess, + model_is_mutable)}, rows_in{rows_in}, cols_in{cols_in}, rows_out{rows_out}, cols_out{cols_out} {} + +void L4CasADi::forward(const double* x, double* out) { + torch::Tensor x_tensor; + x_tensor = torch::from_blob(( void * )x, {this->cols_in, this->rows_in}, at::kDouble).to(torch::kFloat).permute({1, 0}); + + torch::Tensor out_tensor = this->pImpl->forward(x_tensor).to(torch::kDouble).permute({1, 0}).contiguous(); + std::memcpy(out, out_tensor.data_ptr(), out_tensor.numel() * sizeof(double)); +} + +void L4CasADi::jac(const double* x, double* out) { + torch::Tensor x_tensor; + x_tensor = torch::from_blob(( void * )x, {this->cols_in, this->rows_in}, at::kDouble).to(torch::kFloat).permute({1, 0}); -void L4CasADi::forward(const double* in, int rows, int cols, double* out) { - torch::Tensor in_tensor; - in_tensor = torch::from_blob(( void * )in, {cols, rows}, at::kDouble).to(torch::kFloat).permute({1, 0}); + // CasADi expects the return in Fortran order -> Transpose last two dimensions + torch::Tensor out_tensor = this->pImpl->jac(x_tensor).to(torch::kDouble).permute({3, 2, 1, 0}).contiguous(); + std::memcpy(out, out_tensor.data_ptr(), out_tensor.numel() * sizeof(double)); +} + +void L4CasADi::adj1(const double* p, const double* t, double* out) { + // adj1 [i0, out_o0, adj_o0] -> [out_adj_i0] + torch::Tensor p_tensor, t_tensor; + p_tensor = torch::from_blob(( void * )p, {this->cols_in, this->rows_in}, at::kDouble).to(torch::kFloat).permute({1, 0}); + t_tensor = torch::from_blob(( void * )t, {this->cols_out, this->rows_out}, at::kDouble).to(torch::kFloat).permute({1, 0}); - torch::Tensor out_tensor = this->pImpl->forward(in_tensor).to(torch::kDouble).permute({1, 0}).contiguous(); + // CasADi expects the return in Fortran order -> Transpose last two dimensions + torch::Tensor out_tensor = this->pImpl->adj1(p_tensor, t_tensor).to(torch::kDouble).permute({1, 0}).contiguous(); std::memcpy(out, out_tensor.data_ptr(), out_tensor.numel() * sizeof(double)); } -void L4CasADi::jac(const double* in, int rows, int cols, double* out) { - torch::Tensor in_tensor; - in_tensor = torch::from_blob(( void * )in, {cols, rows}, at::kDouble).to(torch::kFloat).permute({1, 0}); +void L4CasADi::jac_adj1(const double* p, const double* t, double* out) { + // jac_adj1 [i0, out_o0, adj_o0, out_adj_i0] -> [jac_adj_i0_i0, jac_adj_i0_out_o0, jac_adj_i0_adj_o0] + torch::Tensor p_tensor, t_tensor; + p_tensor = torch::from_blob(( void * )p, {this->cols_in, this->rows_in}, at::kDouble).to(torch::kFloat).permute({1, 0}); + t_tensor = torch::from_blob(( void * )t, {this->cols_out, this->rows_out}, at::kDouble).to(torch::kFloat).permute({1, 0}); // CasADi expects the return in Fortran order -> Transpose last two dimensions - torch::Tensor out_tensor = this->pImpl->jac(in_tensor).to(torch::kDouble).permute({3, 2, 1, 0}).contiguous(); + torch::Tensor out_tensor = this->pImpl->jac_adj1(p_tensor, t_tensor).to(torch::kDouble).permute({3, 2, 1, 0}).contiguous(); std::memcpy(out, out_tensor.data_ptr(), out_tensor.numel() * sizeof(double)); } -void L4CasADi::hess(const double* in, int rows, int cols, double* out) { - torch::Tensor in_tensor; - in_tensor = torch::from_blob(( void * )in, {cols, rows}, at::kDouble).to(torch::kFloat).permute({1, 0}); +void L4CasADi::jac_jac(const double* x, double* out) { + torch::Tensor x_tensor; + x_tensor = torch::from_blob(( void * )x, {this->cols_in, this->rows_in}, at::kDouble).to(torch::kFloat).permute({1, 0}); // CasADi expects the return in Fortran order -> Transpose last two dimensions - torch::Tensor out_tensor = this->pImpl->hess(in_tensor).to(torch::kDouble).permute({5, 4, 3, 2, 1, 0}).contiguous(); + torch::Tensor out_tensor = this->pImpl->hess(x_tensor).to(torch::kDouble).permute({5, 4, 3, 2, 1, 0}).contiguous(); std::memcpy(out, out_tensor.data_ptr(), out_tensor.numel() * sizeof(double)); } +void L4CasADi::invalid_argument(std::string error_msg) { + throw std::invalid_argument(error_msg); +} + L4CasADi::~L4CasADi() = default; diff --git a/tests/test_batching.py b/tests/test_batching.py new file mode 100644 index 0000000..ee11abb --- /dev/null +++ b/tests/test_batching.py @@ -0,0 +1,88 @@ +import pytest +import torch +import l4casadi as l4c +import casadi as cs +import numpy as np + + +class TestL4CasADiBatching: + @pytest.mark.parametrize("batch_size,input_size,output_size,jac_ccs_target,hess_ccs_target", [ + (10, 3, 2, [20, 30, 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34, 36, 38, 40, 42, 44, 46, 48, 50, 52, 54, 56, 58, 60, 0, 10, 1, 11, 2, 12, 3, 13, 4, 14, 5, 15, 6, 16, 7, 17, 8, 18, 9, 19, 0, 10, 1, 11, 2, 12, 3, 13, 4, 14, 5, 15, 6, 16, 7, 17, 8, 18, 9, 19, 0, 10, 1, 11, 2, 12, 3, 13, 4, 14, 5, 15, 6, 16, 7, 17, 8, 18, 9, 19], [600, 30, 0, 6, 12, 18, 24, 30, 36, 42, 48, 54, 60, 66, 72, 78, 84, 90, 96, 102, 108, 114, 120, 126, 132, 138, 144, 150, 156, 162, 168, 174, 180, 0, 10, 200, 210, 400, 410, 21, 31, 221, 231, 421, 431, 42, 52, 242, 252, 442, 452, 63, 73, 263, 273, 463, 473, 84, 94, 284, 294, 484, 494, 105, 115, 305, 315, 505, 515, 126, 136, 326, 336, 526, 536, 147, 157, 347, 357, 547, 557, 168, 178, 368, 378, 568, 578, 189, 199, 389, 399, 589, 599, 0, 10, 200, 210, 400, 410, 21, 31, 221, 231, 421, 431, 42, 52, 242, 252, 442, 452, 63, 73, 263, 273, 463, 473, 84, 94, 284, 294, 484, 494, 105, 115, 305, 315, 505, 515, 126, 136, 326, 336, 526, 536, 147, 157, 347, 357, 547, 557, 168, 178, 368, 378, 568, 578, 189, 199, 389, 399, 589, 599, 0, 10, 200, 210, 400, 410, 21, 31, 221, 231, 421, 431, 42, 52, 242, 252, 442, 452, 63, 73, 263, 273, 463, 473, 84, 94, 284, 294, 484, 494, 105, 115, 305, 315, 505, 515, 126, 136, 326, 336, 526, 536, 147, 157, 347, 357, 547, 557, 168, 178, 368, 378, 568, 578, 189, 199, 389, 399, 589, 599]), + (3, 4, 3, [9, 12, 0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 33, 36, 0, 3, 6, 1, 4, 7, 2, 5, 8, 0, 3, 6, 1, 4, 7, 2, 5, 8, 0, 3, 6, 1, 4, 7, 2, 5, 8, 0, 3, 6, 1, 4, 7, 2, 5, 8], [108, 12, 0, 12, 24, 36, 48, 60, 72, 84, 96, 108, 120, 132, 144, 0, 3, 6, 27, 30, 33, 54, 57, 60, 81, 84, 87, 10, 13, 16, 37, 40, 43, 64, 67, 70, 91, 94, 97, 20, 23, 26, 47, 50, 53, 74, 77, 80, 101, 104, 107, 0, 3, 6, 27, 30, 33, 54, 57, 60, 81, 84, 87, 10, 13, 16, 37, 40, 43, 64, 67, 70, 91, 94, 97, 20, 23, 26, 47, 50, 53, 74, 77, 80, 101, 104, 107, 0, 3, 6, 27, 30, 33, 54, 57, 60, 81, 84, 87, 10, 13, 16, 37, 40, 43, 64, 67, 70, 91, 94, 97, 20, 23, 26, 47, 50, 53, 74, 77, 80, 101, 104, 107, 0, 3, 6, 27, 30, 33, 54, 57, 60, 81, 84, 87, 10, 13, 16, 37, 40, 43, 64, 67, 70, 91, 94, 97, 20, 23, 26, 47, 50, 53, 74, 77, 80, 101, 104, 107]) + ]) + def test_ccs(self, batch_size, input_size, output_size, jac_ccs_target, hess_ccs_target): + jac_ccs, hess_ccs = l4c.L4CasADi.generate_block_diagonal_ccs(batch_size, input_size, output_size) + + assert jac_ccs == jac_ccs_target + assert hess_ccs == hess_ccs_target + + def test_l4casadi_sparse_out(self): + def model(x): + return torch.stack([(x[:, 0]**2 * x[:, 1]**2 * x[:, 2]**2), - (x[:, 0]**2 * x[:, 1]**2)], dim=-1) + + def model_cs(x): + return cs.hcat([(x[:, 0]**2 * x[:, 1]**2 * x[:, 2]**2), - (x[:, 0]**2 * x[:, 1]**2)]) + + inp = np.ones((5, 3)) + inp_sym = cs.MX.sym('x', 5, 3) + + jac_func_cs = cs.Function('f', [inp_sym], [cs.jacobian(model_cs(inp_sym), inp_sym)]) + jac_sparse_cs = jac_func_cs(inp) + + hess_func_cs = cs.Function('f', [inp_sym], [cs.jacobian(cs.jacobian(model_cs(inp_sym), inp_sym), inp_sym)]) + hess_sparse_cs = hess_func_cs(inp) + + l4c_model = l4c.L4CasADi(model, batched=True, generate_jac_jac=True) + + jac_func = cs.Function('f', [inp_sym], [cs.jacobian(l4c_model(inp_sym), inp_sym)]) + jac_sparse = jac_func(inp) + + hess_func = cs.Function('f', [inp_sym], [cs.jacobian(cs.jacobian(l4c_model(inp_sym), inp_sym), inp_sym)]) + hess_sparse = hess_func(inp) + + assert np.allclose(np.array(jac_sparse), np.array(jac_sparse_cs)) + assert np.allclose(np.array(hess_sparse), np.array(hess_sparse_cs)) + + def test_l4casadi_sparse_out_adj1(self): + def model(x): + return torch.stack([(x[:, 0] ** 2 * x[:, 1] ** 2 * x[:, 2] ** 2), - (x[:, 0] ** 2 * x[:, 1] ** 2)], dim=-1) + + def model_cs(x): + return cs.hcat([(x[:, 0] ** 2 * x[:, 1] ** 2 * x[:, 2] ** 2), -(x[:, 0] ** 2 * x[:, 1] ** 2)]) + + inp = np.ones((5, 3)) + tangent = np.zeros((5, 2)) + tangent[:, 0] = 1. + + inp_sym = cs.MX.sym('x', 5, 3) + tangent_sym = cs.MX.sym('x', 5, 2) + + func_cs = cs.Function('f', [inp_sym], [model_cs(inp_sym)]) + adj1_func_cs = func_cs.reverse(1) + + out_sym = func_cs(inp_sym) + out_cs = func_cs(inp) + adj1_out_cs = adj1_func_cs(inp, out_cs, tangent) + + + l4c_model = l4c.L4CasADi(model, batched=True) + y = l4c_model(inp_sym) + + func_t = l4c_model._cs_fun + adj1_func_t = func_t.reverse(1) + + out_t = func_t(inp) + adj1_out_t = adj1_func_t(inp, out_t, tangent) + + assert (np.array(adj1_out_cs) == np.array(adj1_out_t)).all() + + jac_adj1_func_cs = cs.Function('jac_adj1_f', [inp_sym, tangent_sym], + [cs.jacobian(adj1_func_cs(inp_sym, out_sym, tangent_sym), inp_sym)]) + jac_adj1_cs = jac_adj1_func_cs(inp, tangent) + + jac_adj1_func_t = cs.Function('jac_adj1_ft', [inp_sym, tangent_sym], + [cs.jacobian(adj1_func_t(inp_sym, func_t(inp_sym), tangent_sym), inp_sym)]) + jac_adj1_t = jac_adj1_func_t(inp, tangent) + + assert (np.array(jac_adj1_cs) == np.array(jac_adj1_t)).all() + diff --git a/tests/test_l4casadi.py b/tests/test_l4casadi.py index 16a06dd..4d1642c 100644 --- a/tests/test_l4casadi.py +++ b/tests/test_l4casadi.py @@ -50,15 +50,15 @@ def test_l4casadi_deep_model(self, deep_model): rand_inp = torch.rand((1, deep_model.input_layer.in_features)) torch_out = deep_model(rand_inp) - l4c_out = l4c.L4CasADi(deep_model, model_expects_batch_dim=True)(rand_inp.transpose(-2, -1).detach().numpy()) + l4c_out = l4c.L4CasADi(deep_model, batched=True)(rand_inp.detach().numpy()) - assert np.allclose(l4c_out, torch_out.transpose(-2, -1).detach().numpy(), atol=1e-6) + assert np.allclose(l4c_out, torch_out.detach().numpy(), atol=1e-6) def test_l4casadi_triag_model(self, triag_model): rand_inp = torch.rand((12, 12)) torch_out = triag_model(rand_inp) - l4c_out = l4c.L4CasADi(triag_model, model_expects_batch_dim=False)(rand_inp.detach().numpy()) + l4c_out = l4c.L4CasADi(triag_model)(rand_inp.detach().numpy()) assert np.allclose(l4c_out, torch_out.detach().numpy(), atol=1e-6) @@ -70,7 +70,7 @@ def test_l4casadi_triag_model_jac(self, triag_model): jac_fun = cs.Function('f_jac', [mx_inp], - [cs.jacobian(l4c.L4CasADi(triag_model, model_expects_batch_dim=False)(mx_inp), mx_inp)]) + [cs.jacobian(l4c.L4CasADi(triag_model)(mx_inp), mx_inp)]) l4c_out = jac_fun(rand_inp.detach().numpy()) @@ -88,7 +88,7 @@ def test_l4casadi_triag_model_hess_double_jac(self, triag_model): [mx_inp], [cs.jacobian( cs.jacobian( - l4c.L4CasADi(triag_model, model_expects_batch_dim=False)(mx_inp), mx_inp + l4c.L4CasADi(triag_model, generate_jac_jac=True)(mx_inp), mx_inp )[0, 0], mx_inp)]) l4c_out = hess_fun(rand_inp.transpose(-2, -1).detach().numpy()) @@ -100,28 +100,43 @@ def test_l4casadi_deep_model_jac(self, deep_model): rand_inp = torch.rand((1, deep_model.input_layer.in_features)) torch_out = torch.func.vmap(torch.func.jacrev(deep_model))(rand_inp)[0] - mx_inp = cs.MX.sym('x', deep_model.input_layer.in_features, 1) + mx_inp = cs.MX.sym('x', 1, deep_model.input_layer.in_features) jac_fun = cs.Function('f_jac', [mx_inp], - [cs.jacobian(l4c.L4CasADi(deep_model, model_expects_batch_dim=True)(mx_inp), mx_inp)]) + [cs.jacobian(l4c.L4CasADi(deep_model)(mx_inp), mx_inp)]) - l4c_out = jac_fun(rand_inp.transpose(-2, -1).detach().numpy()) + l4c_out = jac_fun(rand_inp.detach().numpy()) assert np.allclose(l4c_out, torch_out.detach().numpy(), atol=1e-6) - def test_l4casadi_deep_model_hess(self): + def test_l4casadi_deep_model_hess_with_jac_adj(self): deep_model = DeepModel(4, 1) rand_inp = torch.rand((1, deep_model.input_layer.in_features)) torch_out = torch.func.vmap(torch.func.hessian(deep_model))(rand_inp)[0] - mx_inp = cs.MX.sym('x', deep_model.input_layer.in_features, 1) + mx_inp = cs.MX.sym('x', 1, deep_model.input_layer.in_features) hess_fun = cs.Function('f_hess', [mx_inp], - [cs.hessian(l4c.L4CasADi(deep_model, model_expects_batch_dim=True)(mx_inp), mx_inp)[0]]) + [cs.hessian(l4c.L4CasADi(deep_model, generate_adj1=True, generate_jac_jac=False)(mx_inp), mx_inp)[0]]) - l4c_out = hess_fun(rand_inp.transpose(-2, -1).detach().numpy()) + l4c_out = hess_fun(rand_inp.detach().numpy()) + + assert np.allclose(l4c_out, torch_out.detach().numpy(), atol=1e-6) + + def test_l4casadi_deep_model_hess_with_jac_jac(self): + deep_model = DeepModel(4, 1) + rand_inp = torch.rand((1, deep_model.input_layer.in_features)) + torch_out = torch.func.vmap(torch.func.hessian(deep_model))(rand_inp)[0] + + mx_inp = cs.MX.sym('x', 1, deep_model.input_layer.in_features) + + hess_fun = cs.Function('f_hess', + [mx_inp], + [cs.hessian(l4c.L4CasADi(deep_model, generate_adj1=False, generate_jac_jac=True)(mx_inp), mx_inp)[0]]) + + l4c_out = hess_fun(rand_inp.detach().numpy()) assert np.allclose(l4c_out, torch_out.detach().numpy(), atol=1e-6) @@ -130,25 +145,25 @@ def test_l4casadi_deep_model_hess_double_jac(self): rand_inp = torch.rand((1, deep_model.input_layer.in_features)) torch_out = torch.func.vmap(torch.func.hessian(deep_model))(rand_inp)[0] - mx_inp = cs.MX.sym('x', deep_model.input_layer.in_features, 1) + mx_inp = cs.MX.sym('x', 1, deep_model.input_layer.in_features) hess_fun = cs.Function('f_hess_double_jac', [mx_inp], [cs.jacobian( cs.jacobian( - l4c.L4CasADi(deep_model, model_expects_batch_dim=True)(mx_inp), mx_inp + l4c.L4CasADi(deep_model, generate_jac_jac=True)(mx_inp), mx_inp )[0], mx_inp)]) - l4c_out = hess_fun(rand_inp.transpose(-2, -1).detach().numpy()) + l4c_out = hess_fun(rand_inp.detach().numpy()) assert np.allclose(l4c_out, torch_out[0, 0].detach().numpy(), atol=1e-6) def test_l4casadi_deep_model_online_update(self, deep_model): rand_inp = torch.rand((1, deep_model.input_layer.in_features)) - l4c_model = l4c.L4CasADi(deep_model, model_expects_batch_dim=True, mutable=True) + l4c_model = l4c.L4CasADi(deep_model, mutable=True) - l4c_out_old = l4c_model(rand_inp.transpose(-2, -1).detach().numpy()) + l4c_out_old = l4c_model(rand_inp.detach().numpy()) # Change model and online update L4CasADi deep_model.input_layer.reset_parameters() @@ -156,7 +171,7 @@ def test_l4casadi_deep_model_online_update(self, deep_model): torch_out = deep_model(rand_inp) - l4c_out = l4c_model(rand_inp.transpose(-2, -1).detach().numpy()) + l4c_out = l4c_model(rand_inp.detach().numpy()) - assert np.allclose(l4c_out, torch_out.transpose(-2, -1).detach().numpy(), atol=1e-6) - assert not np.allclose(l4c_out_old, torch_out.transpose(-2, -1).detach().numpy(), atol=1e-6) + assert np.allclose(l4c_out, torch_out.detach().numpy(), atol=1e-6) + assert not np.allclose(l4c_out_old, torch_out.detach().numpy(), atol=1e-6) diff --git a/tests/test_naive_l4casadi.py b/tests/test_naive_l4casadi.py index 76f2a87..3198e4a 100644 --- a/tests/test_naive_l4casadi.py +++ b/tests/test_naive_l4casadi.py @@ -12,8 +12,8 @@ def test_naive_l4casadi_mlp(self): rand_inp = torch.rand((1, 2)) torch_out = naive_mlp(rand_inp) - cs_inp = cs.DM(rand_inp.transpose(-2, -1).detach().numpy()) + cs_inp = cs.DM(rand_inp.detach().numpy()) - l4c_out = l4c.L4CasADi(naive_mlp, model_expects_batch_dim=True)(cs_inp) + l4c_out = l4c.L4CasADi(naive_mlp)(cs_inp) - assert np.allclose(l4c_out, torch_out.transpose(-2, -1).detach().numpy(), atol=1e-6) + assert np.allclose(l4c_out, torch_out.detach().numpy(), atol=1e-6)