Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update to jax 0.4.19 #11

Merged
merged 2 commits into from
Nov 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/action/Dockerfile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
FROM nvidia/cuda:11.8.0-devel-ubuntu20.04
FROM nvidia/cuda:11.8.0-devel-ubuntu22.04

RUN apt-get update && \
DEBIAN_FRONTEND=noninteractive apt-get install -y git python3-pip cmake
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ jobs:
fetch-depth: 0

- name: Set up Python
uses: actions/setup-python@v1
uses: actions/setup-python@v4
with:
python-version: 3.8
python-version: '3.10'

- name: Install dependencies
run: |
Expand Down
15 changes: 8 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,9 @@ everything that I'll talk about is covered in more detail somewhere else (even
if that somewhere is just a comment in some source code), but hopefully this
summary can point you in the right direction if you have a use case like this.

**A warning**: I'm writing this in January 2021 and much of what I'm talking
about is based on essentially undocumented APIs that are likely to change.
**A warning**: I'm writing this in January 2021 (most recent update November 2023; see
github for the full revision history) and much of what I'm talking about is based on
essentially undocumented APIs that are likely to change.
Furthermore, I'm not affiliated with the JAX project and I'm far from an expert
so I'm sure there are wrong things that I say. I'll try to update this if I
notice things changing or if I learn of issues, but no promises! So, MIT license
Expand Down Expand Up @@ -358,7 +359,7 @@ from jax.lib import xla_client
from kepler_jax import cpu_ops

for _name, _value in cpu_ops.registrations().items():
xla_client.register_cpu_custom_call_target(_name, _value)
xla_client.register_custom_call_target(_name, _value, platform="cpu")
```

Then, the **lowering rule** is defined roughly as follows (the one you'll
Expand Down Expand Up @@ -400,13 +401,13 @@ def _kepler_lowering(ctx, mean_anom, ecc):
return custom_call(
op_name,
# Output types
out_types=[dtype, dtype],
result_types=[dtype, dtype],
# The inputs:
operands=[mlir.ir_constant(size), mean_anom, ecc],
# Layout specification:
operand_layouts=[(), layout, layout],
result_layouts=[layout, layout]
)
).results

mlir.register_lowering(
_kepler_prim,
Expand Down Expand Up @@ -651,15 +652,15 @@ def _kepler_lowering_gpu(ctx, mean_anom, ecc):
return custom_call(
op_name,
# Output types
out_types=[dtype, dtype],
result_types=[dtype, dtype],
# The inputs:
operands=[mean_anom, ecc],
# Layout specification:
operand_layouts=[layout, layout],
result_layouts=[layout, layout],
# GPU-specific additional data for the kernel
backend_config=opaque
)
).results

mlir.register_lowering(
_kepler_prim,
Expand Down
5 changes: 4 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,10 @@ def build_extension(self, ext):
packages=find_packages("src"),
package_dir={"": "src"},
include_package_data=True,
install_requires=["jax", "jaxlib"],
install_requires=[
"jax>=0.4.16",
"jaxlib>=0.4.16"
],
extras_require={"test": "pytest"},
ext_modules=extensions,
cmdclass={"build_ext": CMakeBuildExt},
Expand Down
12 changes: 6 additions & 6 deletions src/kepler_jax/kepler_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import numpy as np
from jax import core, dtypes, lax
from jax import numpy as jnp
from jax.abstract_arrays import ShapedArray
from jax.core import ShapedArray
from jax.interpreters import ad, batching, mlir, xla
from jax.lib import xla_client
from jaxlib.hlo_helpers import custom_call
Expand All @@ -16,7 +16,7 @@
from . import cpu_ops

for _name, _value in cpu_ops.registrations().items():
xla_client.register_cpu_custom_call_target(_name, _value)
xla_client.register_custom_call_target(_name, _value, platform="cpu")

# If the GPU version exists, also register those
try:
Expand Down Expand Up @@ -93,13 +93,13 @@ def _kepler_lowering(ctx, mean_anom, ecc, *, platform="cpu"):
return custom_call(
op_name,
# Output types
out_types=[dtype, dtype],
result_types=[dtype, dtype],
# The inputs:
operands=[mlir.ir_constant(size), mean_anom, ecc],
# Layout specification:
operand_layouts=[(), layout, layout],
result_layouts=[layout, layout]
)
).results

elif platform == "gpu":
if gpu_ops is None:
Expand All @@ -113,15 +113,15 @@ def _kepler_lowering(ctx, mean_anom, ecc, *, platform="cpu"):
return custom_call(
op_name,
# Output types
out_types=[dtype, dtype],
result_types=[dtype, dtype],
# The inputs:
operands=[mean_anom, ecc],
# Layout specification:
operand_layouts=[layout, layout],
result_layouts=[layout, layout],
# GPU specific additional data
backend_config=opaque
)
).results

raise ValueError(
"Unsupported platform; this must be either 'cpu' or 'gpu'"
Expand Down
Loading