Skip to content

Commit

Permalink
Add example, tests for initializing Flax modules
Browse files Browse the repository at this point in the history
  • Loading branch information
brentyi committed Sep 21, 2022
1 parent a233a60 commit f7e4844
Show file tree
Hide file tree
Showing 5 changed files with 853 additions and 31 deletions.
76 changes: 76 additions & 0 deletions docs/source/examples/17_flax_modules.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
.. Comment: this file is automatically generated by `update_example_docs.py`.
It should not be modified manually.
17. Flax Modules
==========================================


If you use `Flax <https://github.com/google/flax>`_\ , modules can be instantiated
directly from ``dcargs.cli``.



.. code-block:: python
:linenos:
from flax import linen as nn
from jax import numpy as jnp
import dcargs
class Classifier(nn.Module):
layers: int
"""Layers in our network."""
units: int = 32
"""Hidden unit count."""
output_dim: int = 10
"""Number of classes."""
@nn.compact
def __call__(self, x: jnp.ndarray) -> jnp.ndarray: # type: ignore
for i in range(self.layers - 1):
x = nn.Dense(
self.units,
kernel_init=nn.initializers.kaiming_normal(),
)(x)
x = nn.relu(x)
x = nn.Dense(
self.output_dim,
kernel_init=nn.initializers.xavier_normal(),
)(x)
x = nn.sigmoid(x)
return x
def train(model: Classifier, num_iterations: int = 1000) -> None:
"""Train a model.
Args:
model: Model to train.
num_iterations: Number of training iterations.
"""
print(f"{model=}")
print(f"{num_iterations=}")
if __name__ == "__main__":
dcargs.cli(train)
------------

.. raw:: html

<kbd>python 17_flax_modules.py --help</kbd>

.. program-output:: python ../../examples/17_flax_modules.py --help

------------

.. raw:: html

<kbd>python 17_flax_modules.py --model.layers 4</kbd>

.. program-output:: python ../../examples/17_flax_modules.py --model.layers 4
52 changes: 52 additions & 0 deletions examples/17_flax_modules.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
"""If you use [Flax](https://github.com/google/flax), modules can be instantiated
directly from `dcargs.cli`.
Usage:
`python ./17_flax_modules.py --help`
`python ./17_flax_modules.py --model.layers 4`
"""

from flax import linen as nn
from jax import numpy as jnp

import dcargs


class Classifier(nn.Module):
layers: int
"""Layers in our network."""
units: int = 32
"""Hidden unit count."""
output_dim: int = 10
"""Number of classes."""

@nn.compact
def __call__(self, x: jnp.ndarray) -> jnp.ndarray: # type: ignore
for i in range(self.layers - 1):
x = nn.Dense(
self.units,
kernel_init=nn.initializers.kaiming_normal(),
)(x)
x = nn.relu(x)

x = nn.Dense(
self.output_dim,
kernel_init=nn.initializers.xavier_normal(),
)(x)
x = nn.sigmoid(x)
return x


def train(model: Classifier, num_iterations: int = 1000) -> None:
"""Train a model.
Args:
model: Model to train.
num_iterations: Number of training iterations.
"""
print(f"{model=}")
print(f"{num_iterations=}")


if __name__ == "__main__":
dcargs.cli(train)
Loading

0 comments on commit f7e4844

Please sign in to comment.