Skip to content

Commit

Permalink
[JAX-CFD] tweak installation & tests
Browse files Browse the repository at this point in the history
The tweaks to the test suite ensure everything passes on my local
installation on my MacBook, and shave a few minutes off the execution
time.

If the README instructions are followed, the whole test-suite runs in about
two and a half minutes on my Mac. It excludes the "validation tests", which
are too slow to run on CPU.

PiperOrigin-RevId: 374442976
  • Loading branch information
shoyer authored and JAX-CFD authors committed May 18, 2021
1 parent dfc7eec commit b47e8df
Show file tree
Hide file tree
Showing 10 changed files with 104 additions and 53 deletions.
21 changes: 20 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,17 @@ We are currently preparing more example notebooks, inculding:
JAX-CFD is organized around sub-modules:

- `jax_cfd.base`: core numerical methods for CFD, written in JAX.
- `jax_cfd.ml` (not yet released): machine learning augmented models for CFD,
- `jax_cfd.ml`: machine learning augmented models for CFD,
written in JAX and [Haiku](https://dm-haiku.readthedocs.io/en/latest/).
- `jax_cfd.data`: data processing utilities for preparing, evaluating and
post-processing data created with JAX-CFD, written in
[Xarray](http://xarray.pydata.org/) and
[Pillow](https://pillow.readthedocs.io/).

A base install with `pip install jax-cfd` only requires NumPy, SciPy and JAX.
To install dependencies for the other submodules, use `pip install jax-cfd[ml]`,
`pip install jax-cfd[data]` or `pip install jax-cfd[complete]`.

## Numerics

JAX-CFD is currently focused on unsteady turbulent flows:
Expand Down Expand Up @@ -93,3 +97,18 @@ Did we miss something? Please let us know!
primaryClass={physics.flu-dyn}
}
```

## Local development

To locally install for development:
```
git clone https://github.com/google/jax-cfd.git
cd jax-cfd
pip install jaxlib
pip install -e ".[complete]"
```

Then to manually run the test suite:
```
pytest -n auto jax_cfd --dist=loadfile --ignore=jax_cfd/base/validation_test.py
```
3 changes: 2 additions & 1 deletion jax_cfd/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,6 @@

"""Defines the JAX-CFD module for computational fluid dynamics."""

__version__ = '0.1.0'

import jax_cfd.base
import jax_cfd.data
8 changes: 4 additions & 4 deletions jax_cfd/base/advection_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ class AdvectionTest(test_util.TestCase):
cfl_number=0.5,
atol=7e-2),
dict(testcase_name='upwind_3D',
shape=(101, 101, 101),
shape=(101, 5, 5),
method=_euler_step(adv.advect_upwind),
num_steps=100,
cfl_number=0.5,
Expand All @@ -112,7 +112,7 @@ class AdvectionTest(test_util.TestCase):
atol=2e-2,
v_sign=-1.),
dict(testcase_name='van_leer_3D',
shape=(101, 101, 101),
shape=(101, 5, 5),
method=_euler_step(adv.advect_van_leer),
num_steps=100,
cfl_number=0.5,
Expand All @@ -124,7 +124,7 @@ class AdvectionTest(test_util.TestCase):
cfl_number=0.5,
atol=2e-2),
dict(testcase_name='van_leer_using_limiters_3D',
shape=(101, 101, 101),
shape=(101, 5, 5),
method=_euler_step(adv.advect_van_leer_using_limiters),
num_steps=100,
cfl_number=0.5,
Expand All @@ -136,7 +136,7 @@ class AdvectionTest(test_util.TestCase):
cfl_number=0.5,
atol=7e-2),
dict(testcase_name='semilagrangian_3D',
shape=(101, 101, 101),
shape=(101, 5, 5),
method=adv.advect_step_semilagrangian,
num_steps=100,
cfl_number=0.5,
Expand Down
9 changes: 4 additions & 5 deletions jax_cfd/base/equations_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ class SemiImplicitNavierStokesTest(test_util.TestCase):
dt=1e-3,
time_steps=1000,
divergence_atol=1e-3,
momentum_atol=1e-3),
momentum_atol=2e-3),
dict(testcase_name='gaussian_force_upwind',
velocity=zero_field,
forcing=lambda v, g: gaussian_field(g),
Expand All @@ -94,9 +94,9 @@ class SemiImplicitNavierStokesTest(test_util.TestCase):
convect=_convect_upwind,
pressure_solve=pressure.solve_cg,
dt=1e-3,
time_steps=1000,
divergence_atol=1e-3,
momentum_atol=1e-3),
time_steps=100,
divergence_atol=1e-4,
momentum_atol=1e-4),
dict(testcase_name='sinusoidal_velocity_fast_diag',
velocity=sinusoidal_field,
forcing=None,
Expand Down Expand Up @@ -145,5 +145,4 @@ def test_divergence_and_momentum(


if __name__ == '__main__':
jax.config.update('jax_enable_x64', True)
absltest.main()
2 changes: 1 addition & 1 deletion jax_cfd/base/finite_differences_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ def test_divergence(self, grid_type, shape, offsets, f, g, atol):
g=lambda x, y, z: np.array([[y - 1., jnp.zeros_like(x), z],
[x, z - 2., jnp.zeros_like(x)],
[jnp.zeros_like(x), y, x - 3.]]),
atol=1e-6),
atol=4e-6),
)
# pylint: enable=g-long-lambda
def test_cell_centered_gradient(self, shape, f, g, atol):
Expand Down
6 changes: 3 additions & 3 deletions jax_cfd/base/subgrid_models_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,9 @@ class SubgridModelsTest(test_util.TestCase):
convect=_convect_upwind,
pressure_solve=pressure.solve_cg,
dt=1e-3,
time_steps=1000,
divergence_atol=1e-3,
momentum_atol=1e-3),
time_steps=100,
divergence_atol=1e-4,
momentum_atol=1e-4),
dict(
testcase_name='sinusoidal_velocity_with_subgrid_model',
cs=0.12,
Expand Down
6 changes: 1 addition & 5 deletions jax_cfd/data/visualization_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,12 @@

import os.path

from absl import flags
from absl.testing import absltest
from jax_cfd.base import test_util
from jax_cfd.data import visualization
import numpy as np


FLAGS = flags.FLAGS


class VisualizationTest(test_util.TestCase):

def test_trajectory_to_images_shape(self):
Expand Down Expand Up @@ -54,7 +50,7 @@ def test_horizontal_facet_shape(self):

def test_save_movie_local(self):
"""Tests that save_movie write gif to a file."""
temp_dir = FLAGS.test_tmpdir
temp_dir = self.create_tempdir()
temp_filename = os.path.join(temp_dir, 'tmp_file.gif')
input_trajectory = np.random.uniform(size=(25, 32, 32))
images = visualization.trajectory_to_images(input_trajectory)
Expand Down
36 changes: 36 additions & 0 deletions jax_cfd/ml/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""An ML modeling library built on Haiku and Gin-Config for JAX-CFD."""

import jax_cfd.ml.advections
import jax_cfd.ml.decoders
import jax_cfd.ml.diffusions
import jax_cfd.ml.encoders
import jax_cfd.ml.equations
import jax_cfd.ml.forcings
import jax_cfd.ml.interpolations
import jax_cfd.ml.layers
import jax_cfd.ml.layers_util
import jax_cfd.ml.model_builder
import jax_cfd.ml.model_utils
import jax_cfd.ml.networks
import jax_cfd.ml.nonlinearities
import jax_cfd.ml.optimizer_modules
import jax_cfd.ml.physics_specifications
import jax_cfd.ml.pressures
import jax_cfd.ml.tiling
import jax_cfd.ml.time_integrators
import jax_cfd.ml.towers
import jax_cfd.ml.viscosities
41 changes: 21 additions & 20 deletions jax_cfd/ml/layers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,10 +451,10 @@ def module_forward(x):

@parameterized.named_parameters([
('interpolation', (128, 128), (0, 0),
lambda x, y: np.sin(2 * x + y), lambda x, y: np.sin(2 * x + y), 1e-1),
lambda x, y: np.sin(2 * x + y), lambda x, y: np.sin(2 * x + y), 0.2),
('first_derivative_x', (128, 128), (1, 0),
lambda x, y: np.cos(2 * x + y), lambda x, y: -2 * np.sin(2 * x + y),
1e-1),
0.1),
])
def test_2d(self, grid_shape, derivative, initial_fn, expected_fn, atol):
"""Tests SpatialDerivative module in 2d."""
Expand All @@ -465,24 +465,25 @@ def test_2d(self, grid_shape, derivative, initial_fn, expected_fn, atol):
_tower_factory, ndims=ndims, conv_block=layers.PeriodicConv2D)

for extract_patches_method in ('conv', 'roll'):
def module_forward(inputs):
net = layers.SpatialDerivative(
stencil_sizes, grid.cell_center, grid.cell_center, derivative,
tower_factory, grid.step, extract_patches_method) # pylint: disable=cell-var-from-loop
return net(inputs)

rng = jax.random.PRNGKey(14)
spatial_derivative_model = hk.without_apply_rng(
hk.transform(module_forward))

x, y = grid.mesh()
inputs = np.expand_dims(initial_fn(x, y), -1) # add channel dimension
params = spatial_derivative_model.init(rng, inputs)
outputs = spatial_derivative_model.apply(params, inputs)
expected_outputs = np.expand_dims(expected_fn(x, y), -1)
np.testing.assert_allclose(
expected_outputs, outputs, atol=atol, rtol=0,
err_msg=f'Failed for method "{extract_patches_method}"')
with self.subTest(f'method_{extract_patches_method}'):
def module_forward(inputs):
net = layers.SpatialDerivative(
stencil_sizes, grid.cell_center, grid.cell_center, derivative,
tower_factory, grid.step, extract_patches_method) # pylint: disable=cell-var-from-loop
return net(inputs)

rng = jax.random.PRNGKey(14)
spatial_derivative_model = hk.without_apply_rng(
hk.transform(module_forward))

x, y = grid.mesh()
inputs = np.expand_dims(initial_fn(x, y), -1) # add channel dimension
params = spatial_derivative_model.init(rng, inputs)
outputs = spatial_derivative_model.apply(params, inputs)
expected_outputs = np.expand_dims(expected_fn(x, y), -1)
np.testing.assert_allclose(
expected_outputs, outputs, atol=atol, rtol=0,
err_msg=f'Failed for method "{extract_patches_method}"')

def test_auxiliary_inputs(self):
"""Tests that auxiliary inputs don't change shape of the output."""
Expand Down
25 changes: 12 additions & 13 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,25 +15,24 @@
"""Setup JAX-CFD."""
import setuptools


INSTALL_REQUIRES = [
'absl-py',
'jax',
'numpy',
'scipy',
'matplotlib',
'seaborn',
'Pillow',
'xarray',
]
base_requires = ['jax', 'numpy', 'scipy']
data_requires = ['matplotlib', 'seaborn', 'Pillow', 'xarray']
ml_requires = ['dm-haiku', 'einops', 'gin-config']
tests_requires = ['absl-py', 'pytest', 'pytest-xdist', 'scikit-image']

setuptools.setup(
name='jax-cfd',
version='0.0.0',
version='0.1.0',
license='Apache 2.0',
author='Google LLC',
author_email='[email protected]',
install_requires=INSTALL_REQUIRES,
install_requires=base_requires,
extras_require={
'data': data_requires,
'ml': ml_requires,
'tests': tests_requires,
'complete': data_requires + ml_requires + tests_requires,
},
url='https://github.com/google/jax-cfd',
packages=setuptools.find_packages(),
python_requires='>=3',
Expand Down

0 comments on commit b47e8df

Please sign in to comment.