Skip to content

Commit

Permalink
In numpy and JAX CARFAC, coalesce the "just_hwr" and "one_cap" boolea…
Browse files Browse the repository at this point in the history
…ns into a single string flag, named "ihc_style" that can take on 3 appropriately named string values.

Update the code in numpy and python throughout to use this new style, modifying coefficients and hypers to use integers where appropriate.

Update the unit tests and benchmarks to use this new path.

PiperOrigin-RevId: 679538551
  • Loading branch information
Rob Schonberger authored and copybara-github committed Sep 27, 2024
1 parent eb13cc1 commit 90731bd
Show file tree
Hide file tree
Showing 7 changed files with 141 additions and 130 deletions.
72 changes: 40 additions & 32 deletions python/jax/carfac.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,8 +517,7 @@ def tree_unflatten(cls, _, children):
@dataclasses.dataclass
class IhcDesignParameters:
"""Variables needed for the inner hair cell implementation."""
just_hwr: bool = False
n_caps: int = 2
ihc_style: str = 'two_cap'
tau_lpf: float = 0.000080 # 80 microseconds smoothing twice
tau_out: float = 0.0005 # depletion tau is pretty fast
tau_in: float = 0.010 # recovery tau is slower
Expand All @@ -530,24 +529,26 @@ class IhcDesignParameters:
# The following 2 functions are boiler code required by pytree.
# Reference: https://jax.readthedocs.io/en/latest/pytrees.html
def tree_flatten(self): # pylint: disable=missing-function-docstring
children = (self.just_hwr,
self.n_caps,
self.tau_lpf,
self.tau_out,
self.tau_in,
self.tau1_out,
self.tau1_in,
self.tau2_out,
self.tau2_in)
aux_data = ('just_hwr',
'n_caps',
'tau_lpf',
'tau_out',
'tau_in',
'tau1_out',
'tau1_in',
'tau2_out',
'tau2_in')
children = (
self.ihc_style,
self.tau_lpf,
self.tau_out,
self.tau_in,
self.tau1_out,
self.tau1_in,
self.tau2_out,
self.tau2_in,
)
aux_data = (
'ihc_style',
'tau_lpf',
'tau_out',
'tau_in',
'tau1_out',
'tau1_in',
'tau2_out',
'tau2_in',
)
return (children, aux_data)

@classmethod
Expand All @@ -560,14 +561,14 @@ def tree_unflatten(cls, _, children):
class IhcHypers:
"""Hyperparameters for the inner hair cell. Tagged `static` in `jax.jit`."""
n_ch: int
just_hwr: bool
n_caps: int
# 0 is just_hwr, 1 is one_cap, 2 is two_cap.
ihc_style: int

# The following 2 functions are boiler code required by pytree.
# Reference: https://jax.readthedocs.io/en/latest/pytrees.html
def tree_flatten(self):
children = (self.n_ch, self.just_hwr, self.n_caps)
aux_data = ('n_ch', 'just_hwr', 'n_caps')
children = (self.n_ch, self.ihc_style)
aux_data = ('n_ch', 'ihc_style')
return (children, aux_data)

@classmethod
Expand Down Expand Up @@ -1130,13 +1131,20 @@ def design_and_init_ihc(
ihc_params = ear_params.ihc

n_ch = ear_hypers.n_ch
ihc_hypers = IhcHypers(
n_ch=n_ch, just_hwr=ihc_params.just_hwr, n_caps=ihc_params.n_caps
)
if ihc_params.just_hwr:
ihc_style_num = 0
if ihc_params.ihc_style == 'just_hwr':
ihc_style_num = 0
elif ihc_params.ihc_style == 'one_cap':
ihc_style_num = 1
elif ihc_params.ihc_style == 'two_cap':
ihc_style_num = 2
else:
raise NotImplementedError
ihc_hypers = IhcHypers(n_ch=n_ch, ihc_style=ihc_style_num)
if ihc_params.ihc_style == 'just_hwr':
ihc_weights = IhcWeights()
ihc_state = IhcState(ihc_accum=jnp.zeros((n_ch,)))
elif ihc_params.n_caps == 1:
elif ihc_params.ihc_style == 'one_cap':
ro = 1 / ihc_detect(10) # output resistance at a very high level
c = ihc_params.tau_out / ro
ri = ihc_params.tau_in / c
Expand All @@ -1159,7 +1167,7 @@ def design_and_init_ihc(
lpf1_state=ihc_weights.rest_output * jnp.ones((n_ch,)),
lpf2_state=ihc_weights.rest_output * jnp.ones((n_ch,)),
)
elif ihc_params.n_caps == 2:
elif ihc_params.ihc_style == 'two_cap':
g1_max = ihc_detect(10) # receptor conductance at high level

r1min = 1 / g1_max
Expand Down Expand Up @@ -1631,13 +1639,13 @@ def ihc_step(
ihc_weights = weights.ears[ear].ihc
ihc_hypers = hypers.ears[ear].ihc

if ihc_hypers.just_hwr:
if ihc_hypers.ihc_style == 0:
ihc_out = jnp.min(2, jnp.max(0, bm_out)) # pytype: disable=wrong-arg-types # jnp-type
# limit it for stability
else:
conductance = ihc_detect(bm_out) # rectifying nonlinearity

if ihc_hypers.n_caps == 1:
if ihc_hypers.ihc_style == 1:
ihc_out = conductance * ihc_state.cap_voltage
ihc_state.cap_voltage = (
ihc_state.cap_voltage
Expand Down
20 changes: 10 additions & 10 deletions python/jax/carfac_bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,10 +143,10 @@ def bench_jax_grad(state: google_benchmark.State):
Args:
state: The Benchmark state for this run.
"""
one_cap = False
ihc_style = 'two_cap'
random_seed = 1
params_jax = carfac_jax.CarfacDesignParameters()
params_jax.ears[0].ihc.n_caps = 1 if one_cap else 2
params_jax.ears[0].ihc.ihc_style = ihc_style
params_jax.ears[0].car.linear_car = False
random_generator = jax.random.PRNGKey(random_seed)
n_samp = state.range(0)
Expand Down Expand Up @@ -202,10 +202,10 @@ def bench_jit_compile_time(state: google_benchmark.State):
Args:
state: The benchmark state to execute over.
"""
one_cap = False
ihc_style = 'two_cap'
random_seed = 1
params_jax = carfac_jax.CarfacDesignParameters()
params_jax.ears[0].ihc.n_caps = 1 if one_cap else 2
params_jax.ears[0].ihc.ihc_style = ihc_style
params_jax.ears[0].car.linear_car = False
random_generator = jax.random.PRNGKey(random_seed)
n_samp = 1
Expand Down Expand Up @@ -251,10 +251,10 @@ def bench_jax_in_slices(state: google_benchmark.State):
state: the benchmark state for this execution run.
"""
# Inits JAX version
one_cap = False
ihc_style = 'two_cap'
random_seed = 1
params_jax = carfac_jax.CarfacDesignParameters()
params_jax.ears[0].ihc.n_caps = 1 if one_cap else 2
params_jax.ears[0].ihc.ihc_style = ihc_style
params_jax.ears[0].car.linear_car = False

# Generate some random inputs.
Expand Down Expand Up @@ -335,11 +335,11 @@ def bench_jax(state: google_benchmark.State):
state: the benchmark state for this execution run.
"""
# Inits JAX version
one_cap = False
ihc_style = 'two_cap'
random_seed = 1
params_jax = carfac_jax.CarfacDesignParameters()
params_jax.ears[0].car.use_delay_buffer = state.range(2)
params_jax.ears[0].ihc.n_caps = 1 if one_cap else 2
params_jax.ears[0].ihc.ihc_style = ihc_style
params_jax.ears[0].car.linear_car = False

# Generate some random inputs.
Expand Down Expand Up @@ -393,10 +393,10 @@ def bench_jax_util_mapped(state: google_benchmark.State):
"""
if jax.device_count() < state.range(0):
state.skip_with_error(f'requires {state.range(0)} devices')
one_cap = False
random_seed = state.range(0)
ihc_style = 'two_cap'
params_jax = carfac_jax.CarfacDesignParameters()
params_jax.ears[0].ihc.n_caps = 1 if one_cap else 2
params_jax.ears[0].ihc.ihc_style = ihc_style
params_jax.ears[0].car.linear_car = False
random_generator = jax.random.PRNGKey(random_seed)
hypers_jax, weights_jax, state_jax = carfac_jax.design_and_init_carfac(
Expand Down
8 changes: 5 additions & 3 deletions python/jax/carfac_float64_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,11 @@ def _assert_almost_equal_pytrees(self, pytree1, pytree2, delta=None):
self.assertSequenceAlmostEqual(elements1, elements2, delta=delta)

@parameterized.product(
random_seed=[x for x in range(20)], one_cap=[False, True], n_ears=[1, 2]
random_seed=[x for x in range(20)],
ihc_style=['one_cap', 'two_cap'],
n_ears=[1, 2],
)
def test_backward_pass(self, random_seed, one_cap, n_ears):
def test_backward_pass(self, random_seed, ihc_style, n_ears):
# Tests `jax.grad` can give similar gradients computed by numeric method.
@functools.partial(jax.jit, static_argnames=('hypers',))
def loss(weights, input_waves, hypers, state):
Expand Down Expand Up @@ -66,7 +68,7 @@ def loss(weights, input_waves, hypers, state):
# Computes gradients by `jax.grad`.
gfunc = jax.grad(loss, has_aux=True)
params_jax = carfac_jax.CarfacDesignParameters(n_ears=n_ears)
params_jax.ears[0].ihc.n_caps = 1 if one_cap else 2
params_jax.ears[0].ihc.ihc_style = ihc_style
params_jax.ears[0].car.linear_car = False
hypers_jax, weights_jax, state_jax = carfac_jax.design_and_init_carfac(
params_jax
Expand Down
51 changes: 25 additions & 26 deletions python/jax/carfac_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,15 @@ def test_hypers_hash(self):
hypers.ears[0].car = carfac_jax.CarHypers()
hypers.ears[0].agc = [carfac_jax.AgcHypers(n_ch=1, n_agc_stages=2),
carfac_jax.AgcHypers(n_ch=1, n_agc_stages=2)]
hypers.ears[0].ihc = carfac_jax.IhcHypers(n_ch=1,
just_hwr=True,
n_caps=1)
hypers.ears[0].ihc = carfac_jax.IhcHypers(n_ch=1, ihc_style=1)
h1 = hash(hypers)
hypers.ears[0].car.n_ch += 1
h2 = hash(hypers)
self.assertNotEqual(h1, h2)
hypers.ears[0].agc[1].reverse_cumulative_decimation += 1
h3 = hash(hypers)
self.assertNotEqual(h2, h3)
hypers.ears[0].ihc.just_hwr = not hypers.ears[0].ihc.just_hwr
hypers.ears[0].ihc.ihc_style = 2
h4 = hash(hypers)
self.assertNotEqual(h3, h4)

Expand Down Expand Up @@ -110,15 +108,15 @@ def container_comparison(self, left_side, right_side, exclude_keys=None):
msg='failed comparison on key item %s' % (k),
)

@parameterized.parameters([1, 2])
def test_equal_design(self, n_caps):
@parameterized.parameters(['one_cap', 'two_cap'])
def test_equal_design(self, ihc_style):
# Test: the designs are similar.
cfp = carfac_np.design_carfac(one_cap=(n_caps == 1))
cfp = carfac_np.design_carfac(ihc_style=ihc_style)
carfac_np.carfac_init(cfp)
cfp.ears[0].car_coeffs.linear = False

params_jax = carfac_jax.CarfacDesignParameters()
params_jax.ears[0].ihc.n_caps = n_caps
params_jax.ears[0].ihc.ihc_style = ihc_style
params_jax.ears[0].car.linear_car = False
hypers_jax, weights_jax, state_jax = carfac_jax.design_and_init_carfac(
params_jax
Expand Down Expand Up @@ -169,11 +167,11 @@ def test_equal_design(self, n_caps):
self.container_comparison(
hypers_jax.ears[ear_idx].ihc,
ear_params_np.ihc_coeffs,
exclude_keys={'n_caps'},
exclude_keys={'ihc_style'},
)
self.assertEqual(
ear_params_np.ihc_coeffs.one_cap,
hypers_jax.ears[ear_idx].ihc.n_caps == 1,
ear_params_np.ihc_coeffs.ihc_style,
hypers_jax.ears[ear_idx].ihc.ihc_style,
)

self.container_comparison(
Expand All @@ -182,7 +180,7 @@ def test_equal_design(self, n_caps):
exclude_keys='lpf2_state',
)

if ear_params_np.ihc_coeffs.one_cap:
if ear_params_np.ihc_coeffs.ihc_style == 1:
self.assertSequenceAlmostEqual(
state_jax.ears[ear_idx].ihc.lpf2_state,
ear_params_np.ihc_state.lpf2_state,
Expand All @@ -195,11 +193,7 @@ def test_equal_design(self, n_caps):
# now we only check these one by one. We could add tests for 2 cap
# similarly.
self.assertEqual(
cfp.ihc_params.one_cap, params_jax.ears[ear_idx].ihc.n_caps == 1
)

self.assertEqual(
cfp.ihc_params.just_hwr, params_jax.ears[ear_idx].ihc.just_hwr
cfp.ihc_params.ihc_style, params_jax.ears[ear_idx].ihc.ihc_style
)

self.assertEqual(
Expand Down Expand Up @@ -250,13 +244,14 @@ def test_equal_design(self, n_caps):
)

@parameterized.product(
random_seed=[x for x in range(5)], one_cap=[False, True]
random_seed=[x for x in range(5)],
ihc_style=['one_cap', 'two_cap'],
)
def test_chunked_naps_same_as_jit(self, random_seed, one_cap):
def test_chunked_naps_same_as_jit(self, random_seed, ihc_style):
"""Tests whether `run_segment` produces the same results as np version."""
# Inits JAX version
params_jax = carfac_jax.CarfacDesignParameters()
params_jax.ears[0].ihc.n_caps = 1 if one_cap else 2
params_jax.ears[0].ihc.ihc_style = ihc_style
params_jax.ears[0].car.linear_car = False
hypers_jax, weights_jax, state_jax = carfac_jax.design_and_init_carfac(
params_jax
Expand Down Expand Up @@ -294,26 +289,28 @@ def test_chunked_naps_same_as_jit(self, random_seed, one_cap):

@parameterized.product(
random_seed=[x for x in range(20)],
one_cap=[False, True],
ihc_style=['one_cap', 'two_cap'],
n_ears=[1, 2],
delay_buffer=[False, True],
)
def test_equal_forward_pass(self, random_seed, one_cap, n_ears, delay_buffer):
def test_equal_forward_pass(
self, random_seed, ihc_style, n_ears, delay_buffer
):
"""Tests whether `run_segment` produces the same results as np version."""
# Inits JAX version
params_jax = carfac_jax.CarfacDesignParameters(
n_ears=n_ears, use_delay_buffer=delay_buffer
)
params_jax.n_ears = n_ears
for ear in range(n_ears):
params_jax.ears[ear].ihc.n_caps = 1 if one_cap else 2
params_jax.ears[ear].ihc.ihc_style = ihc_style
params_jax.ears[ear].car.linear_car = False
hypers_jax, weights_jax, state_jax = carfac_jax.design_and_init_carfac(
params_jax
)
# Inits numpy version
cfp = carfac_np.design_carfac(
one_cap=one_cap, n_ears=n_ears, use_delay_buffer=delay_buffer
ihc_style=ihc_style, n_ears=n_ears, use_delay_buffer=delay_buffer
)

carfac_np.carfac_init(cfp)
Expand Down Expand Up @@ -419,7 +416,7 @@ def test_equal_forward_pass(self, random_seed, one_cap, n_ears, delay_buffer):
state_np.ears[ear].ihc_state.lpf1_state,
delta=1e-3, # Low Precision
)
if cfp.ears[ear].ihc_coeffs.one_cap:
if cfp.ears[ear].ihc_coeffs.ihc_style == 1:
self.assertSequenceAlmostEqual(
state_jax.ears[ear].ihc.lpf2_state,
state_np.ears[ear].ihc_state.lpf2_state,
Expand All @@ -430,7 +427,7 @@ def test_equal_forward_pass(self, random_seed, one_cap, n_ears, delay_buffer):
state_np.ears[ear].ihc_state.cap_voltage,
delta=2e-5, # Low Precision
)
else:
elif cfp.ears[ear].ihc_coeffs.ihc_style == 2:
# `state_np` won't have `cap1_voltage` or `cap2_voltage` if
# `one_cap==True`.
self.assertSequenceAlmostEqual(
Expand All @@ -443,6 +440,8 @@ def test_equal_forward_pass(self, random_seed, one_cap, n_ears, delay_buffer):
state_np.ears[ear].ihc_state.cap2_voltage,
delta=1e-5, # Low Precision
)
else:
self.fail('Unsupported IHC style.')
# Comapares agc state
for stage in range(hypers_jax.ears[ear].agc[0].n_agc_stages):
self.assertSequenceAlmostEqual(
Expand Down
4 changes: 2 additions & 2 deletions python/jax/carfac_util_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@ class CarfacUtilTest(absltest.TestCase):

def setUp(self):
super().setUp()
self.one_cap = False
self.ihc_style = 'two_cap'
self.random_seed = 17234
self.open_loop = False
params_jax = carfac.CarfacDesignParameters()
params_jax.ears[0].ihc.n_caps = 1 if self.one_cap else 2
params_jax.ears[0].ihc.ihc_style = self.ihc_style
params_jax.ears[0].car.linear_car = False
self.random_generator = jax.random.PRNGKey(self.random_seed)
self.hypers, self.weights, self.init_state = carfac.design_and_init_carfac(
Expand Down
Loading

0 comments on commit 90731bd

Please sign in to comment.