From 161a31b4204a423eec2d62dac18ea9c7742a432d Mon Sep 17 00:00:00 2001 From: paugier Date: Mon, 26 Aug 2024 16:31:44 +0200 Subject: [PATCH] Jax: no @jit for 'internal' functions --- data_tests/saved__backend__/jax/add_inline.py | 3 --- data_tests/saved__backend__/jax/assign_func_boost.py | 3 --- data_tests/saved__backend__/jax/block_fluidsim.py | 6 ------ data_tests/saved__backend__/jax/blocks_type_hints.py | 6 ------ .../saved__backend__/jax/boosted_class_use_import.py | 6 ------ .../saved__backend__/jax/boosted_func_use_import.py | 3 --- data_tests/saved__backend__/jax/class_blocks.py | 6 ------ data_tests/saved__backend__/jax/class_rec_calls.py | 6 ------ data_tests/saved__backend__/jax/classic.py | 3 --- data_tests/saved__backend__/jax/default_params.py | 3 --- data_tests/saved__backend__/jax/methods.py | 6 ------ .../saved__backend__/jax/mixed_classic_type_hint.py | 3 --- data_tests/saved__backend__/jax/no_arg.py | 3 --- data_tests/saved__backend__/jax/row_sum_boost.py | 3 --- data_tests/saved__backend__/jax/subpackages.py | 3 --- .../saved__backend__/jax/type_hint_notemplate.py | 3 --- src/transonic/backends/jax.py | 10 +++++++++- 17 files changed, 9 insertions(+), 67 deletions(-) diff --git a/data_tests/saved__backend__/jax/add_inline.py b/data_tests/saved__backend__/jax/add_inline.py index e98cc20..e790547 100644 --- a/data_tests/saved__backend__/jax/add_inline.py +++ b/data_tests/saved__backend__/jax/add_inline.py @@ -16,8 +16,5 @@ def use_add(n=10000): return tmp -# __protected__ @jit - - def __transonic__(): return "0.7.1" diff --git a/data_tests/saved__backend__/jax/assign_func_boost.py b/data_tests/saved__backend__/jax/assign_func_boost.py index a5a19fb..73994d3 100644 --- a/data_tests/saved__backend__/jax/assign_func_boost.py +++ b/data_tests/saved__backend__/jax/assign_func_boost.py @@ -6,8 +6,5 @@ def func(x): return x**2 -# __protected__ @jit - - def __transonic__(): return "0.7.1" diff --git a/data_tests/saved__backend__/jax/block_fluidsim.py b/data_tests/saved__backend__/jax/block_fluidsim.py index 6ced9ae..56f319d 100644 --- a/data_tests/saved__backend__/jax/block_fluidsim.py +++ b/data_tests/saved__backend__/jax/block_fluidsim.py @@ -12,9 +12,6 @@ def rk2_step0(state_spect_n12, state_spect, tendencies_n, diss2, dt): state_spect_n12[:] = (state_spect + dt / 2 * tendencies_n) * diss2 -# __protected__ @jit - - def arguments_blocks(): return { "rk2_step0": [ @@ -27,8 +24,5 @@ def arguments_blocks(): } -# __protected__ @jit - - def __transonic__(): return "0.7.1" diff --git a/data_tests/saved__backend__/jax/blocks_type_hints.py b/data_tests/saved__backend__/jax/blocks_type_hints.py index 1c03920..e41b2fa 100644 --- a/data_tests/saved__backend__/jax/blocks_type_hints.py +++ b/data_tests/saved__backend__/jax/blocks_type_hints.py @@ -15,15 +15,9 @@ def block0(a, b, n): return result -# __protected__ @jit - - def arguments_blocks(): return {"block0": ["a", "b", "n"]} -# __protected__ @jit - - def __transonic__(): return "0.7.1" diff --git a/data_tests/saved__backend__/jax/boosted_class_use_import.py b/data_tests/saved__backend__/jax/boosted_class_use_import.py index 3e1f18f..1ee794f 100644 --- a/data_tests/saved__backend__/jax/boosted_class_use_import.py +++ b/data_tests/saved__backend__/jax/boosted_class_use_import.py @@ -9,15 +9,9 @@ def __for_method__MyClass2__myfunc(self_attr0, self_attr1, arg): return self_attr1 + self_attr0 + np.abs(arg) + func_import() -# __protected__ @jit - - def __code_new_method__MyClass2__myfunc(): return "\n\ndef new_method(self, arg):\n return backend_func(self.attr0, self.attr1, arg)\n\n" -# __protected__ @jit - - def __transonic__(): return "0.7.1" diff --git a/data_tests/saved__backend__/jax/boosted_func_use_import.py b/data_tests/saved__backend__/jax/boosted_func_use_import.py index 6b9721f..b7e287b 100644 --- a/data_tests/saved__backend__/jax/boosted_func_use_import.py +++ b/data_tests/saved__backend__/jax/boosted_func_use_import.py @@ -9,8 +9,5 @@ def func(a, b): return (a * np.log(b)).max() + func_import() -# __protected__ @jit - - def __transonic__(): return "0.7.1" diff --git a/data_tests/saved__backend__/jax/class_blocks.py b/data_tests/saved__backend__/jax/class_blocks.py index b3230a9..7cfcf38 100644 --- a/data_tests/saved__backend__/jax/class_blocks.py +++ b/data_tests/saved__backend__/jax/class_blocks.py @@ -40,15 +40,9 @@ def block1(a, b, n): return result -# __protected__ @jit - - def arguments_blocks(): return {"block0": ["a", "b", "n"], "block1": ["a", "b", "n"]} -# __protected__ @jit - - def __transonic__(): return "0.7.1" diff --git a/data_tests/saved__backend__/jax/class_rec_calls.py b/data_tests/saved__backend__/jax/class_rec_calls.py index bafa4a4..17b6141 100644 --- a/data_tests/saved__backend__/jax/class_rec_calls.py +++ b/data_tests/saved__backend__/jax/class_rec_calls.py @@ -16,15 +16,9 @@ def __for_method__Myclass__func(self_attr, self_attr2, arg): ) -# __protected__ @jit - - def __code_new_method__Myclass__func(): return "\n\ndef new_method(self, arg):\n return backend_func(self.attr, self.attr2, arg)\n\n" -# __protected__ @jit - - def __transonic__(): return "0.7.1" diff --git a/data_tests/saved__backend__/jax/classic.py b/data_tests/saved__backend__/jax/classic.py index e9f2d40..c4d5017 100644 --- a/data_tests/saved__backend__/jax/classic.py +++ b/data_tests/saved__backend__/jax/classic.py @@ -8,8 +8,5 @@ def func(a, b): return (a * np.log(b)).max() -# __protected__ @jit - - def __transonic__(): return "0.7.1" diff --git a/data_tests/saved__backend__/jax/default_params.py b/data_tests/saved__backend__/jax/default_params.py index 50def82..34edd51 100644 --- a/data_tests/saved__backend__/jax/default_params.py +++ b/data_tests/saved__backend__/jax/default_params.py @@ -7,8 +7,5 @@ def func(a=1, b=None, c=1.0): return a + c -# __protected__ @jit - - def __transonic__(): return "0.7.1" diff --git a/data_tests/saved__backend__/jax/methods.py b/data_tests/saved__backend__/jax/methods.py index 6da0ef8..b9d5155 100644 --- a/data_tests/saved__backend__/jax/methods.py +++ b/data_tests/saved__backend__/jax/methods.py @@ -9,15 +9,9 @@ def __for_method__Transmitter____call__(self_arr, self_freq, inp): return (inp * np.exp(np.arange(len(inp)) * self_freq * 1j), self_arr) -# __protected__ @jit - - def __code_new_method__Transmitter____call__(): return "\n\ndef new_method(self, inp):\n return backend_func(self.arr, self.freq, inp)\n\n" -# __protected__ @jit - - def __transonic__(): return "0.7.1" diff --git a/data_tests/saved__backend__/jax/mixed_classic_type_hint.py b/data_tests/saved__backend__/jax/mixed_classic_type_hint.py index ffa0a87..429e524 100644 --- a/data_tests/saved__backend__/jax/mixed_classic_type_hint.py +++ b/data_tests/saved__backend__/jax/mixed_classic_type_hint.py @@ -15,8 +15,5 @@ def func1(a, b): return a * np.cos(b) -# __protected__ @jit - - def __transonic__(): return "0.7.1" diff --git a/data_tests/saved__backend__/jax/no_arg.py b/data_tests/saved__backend__/jax/no_arg.py index d187182..39976a0 100644 --- a/data_tests/saved__backend__/jax/no_arg.py +++ b/data_tests/saved__backend__/jax/no_arg.py @@ -13,8 +13,5 @@ def func2(): return 1 -# __protected__ @jit - - def __transonic__(): return "0.7.1" diff --git a/data_tests/saved__backend__/jax/row_sum_boost.py b/data_tests/saved__backend__/jax/row_sum_boost.py index 984525e..6fb6e6c 100644 --- a/data_tests/saved__backend__/jax/row_sum_boost.py +++ b/data_tests/saved__backend__/jax/row_sum_boost.py @@ -24,8 +24,5 @@ def row_sum_loops(arr, columns): return res -# __protected__ @jit - - def __transonic__(): return "0.7.1" diff --git a/data_tests/saved__backend__/jax/subpackages.py b/data_tests/saved__backend__/jax/subpackages.py index 3e2df02..a0738ad 100644 --- a/data_tests/saved__backend__/jax/subpackages.py +++ b/data_tests/saved__backend__/jax/subpackages.py @@ -30,8 +30,5 @@ def test_sp_special(v, x): return jv(v, x) -# __protected__ @jit - - def __transonic__(): return "0.7.1" diff --git a/data_tests/saved__backend__/jax/type_hint_notemplate.py b/data_tests/saved__backend__/jax/type_hint_notemplate.py index 5cf4432..6ac4e6b 100644 --- a/data_tests/saved__backend__/jax/type_hint_notemplate.py +++ b/data_tests/saved__backend__/jax/type_hint_notemplate.py @@ -10,8 +10,5 @@ def compute(a, b, c, d, e): return tmp -# __protected__ @jit - - def __transonic__(): return "0.7.1" diff --git a/src/transonic/backends/jax.py b/src/transonic/backends/jax.py index fb96549..0d20c96 100644 --- a/src/transonic/backends/jax.py +++ b/src/transonic/backends/jax.py @@ -43,7 +43,15 @@ def add_jax_comments(code): node.module = "jax.numpy" # Add JIT decorator - if isinstance(node, gast.FunctionDef): + if ( + isinstance(node, gast.FunctionDef) + and node.name + not in ( + "arguments_blocks", + "__transonic__", + ) + and not node.name.startswith("__code_new_method__") + ): new_body.append(CommentLine("# __protected__ @jit")) new_body.append(node)