Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Compat option for jax._src.core.canonicalize_shape #340

Merged
merged 1 commit into from
Sep 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions pysages/ml/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@
from jax import numpy as np
from jax import random, vmap
from jax._src.nn import initializers
from jax.core import as_named_shape
from jax.numpy.linalg import norm
from jax.tree_util import PyTreeDef, tree_flatten
from numpy import cumsum
from plum import Dispatcher

from pysages.typing import NamedTuple
from pysages.utils import identity, prod
from pysages.utils.compat import canonicalize_shape

# Dispatcher for the `ml` submodule
dispatch = Dispatcher()
Expand Down Expand Up @@ -86,17 +86,17 @@ def uniform_scaling(
raise ValueError(f"invalid mode for variance scaling initializer: {mode}")

if bias_like:
trim_named_shape = idem(lambda named_shp, shp, axis: as_named_shape(shp[axis:]))
trim_shape = idem(lambda cshp, shp, axis: canonicalize_shape(shp[axis:]))
else:
trim_named_shape = idem(lambda named_shp, shp, axis: named_shp)
trim_shape = idem(lambda cshp, shp, axis: cshp)

def init(key, shape, dtype=dtype):
args_named_shape = as_named_shape(shape)
named_shape = trim_named_shape(args_named_shape, shape, out_axis)
canonical_shape = canonicalize_shape(shape)
shape = trim_shape(canonical_shape, shape, out_axis)
# pylint: disable-next=W0212
fan_in, fan_out = initializers._compute_fans(args_named_shape, in_axis, out_axis)
fan_in, fan_out = initializers._compute_fans(canonical_shape, in_axis, out_axis)
s = np.array(scale / denominator(fan_in, fan_out), dtype=dtype)
return random.uniform(key, named_shape, dtype, -1) * transform(s)
return random.uniform(key, shape, dtype, -1) * transform(s)

return init

Expand Down
16 changes: 16 additions & 0 deletions pysages/utils/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,22 @@ def prod(iterable, start=1):
return result


# Compatibility for jax >=0.4.31

# https://github.com/google/jax/blob/main/CHANGELOG.md#jax-0431-july-29-2024
if _jax_version_tuple < (0, 4, 31):
_jax_core = import_module("jax.core")

def canonicalize_shape(shape):
return _jax_core.as_named_shape(shape)

else:
_jax_core = import_module("jax._src.core")

def canonicalize_shape(shape):
return _jax_core.canonicalize_shape(shape)


# Compatibility for jax >=0.4.22

# https://github.com/google/jax/blob/main/CHANGELOG.md#jax-0422-dec-13-2023
Expand Down
Loading