-
Notifications
You must be signed in to change notification settings - Fork 27
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add example, tests for initializing Flax modules
- Loading branch information
Showing
5 changed files
with
853 additions
and
31 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
.. Comment: this file is automatically generated by `update_example_docs.py`. | ||
It should not be modified manually. | ||
17. Flax Modules | ||
========================================== | ||
|
||
|
||
If you use `Flax <https://github.com/google/flax>`_\ , modules can be instantiated | ||
directly from ``dcargs.cli``. | ||
|
||
|
||
|
||
.. code-block:: python | ||
:linenos: | ||
from flax import linen as nn | ||
from jax import numpy as jnp | ||
import dcargs | ||
class Classifier(nn.Module): | ||
layers: int | ||
"""Layers in our network.""" | ||
units: int = 32 | ||
"""Hidden unit count.""" | ||
output_dim: int = 10 | ||
"""Number of classes.""" | ||
@nn.compact | ||
def __call__(self, x: jnp.ndarray) -> jnp.ndarray: # type: ignore | ||
for i in range(self.layers - 1): | ||
x = nn.Dense( | ||
self.units, | ||
kernel_init=nn.initializers.kaiming_normal(), | ||
)(x) | ||
x = nn.relu(x) | ||
x = nn.Dense( | ||
self.output_dim, | ||
kernel_init=nn.initializers.xavier_normal(), | ||
)(x) | ||
x = nn.sigmoid(x) | ||
return x | ||
def train(model: Classifier, num_iterations: int = 1000) -> None: | ||
"""Train a model. | ||
Args: | ||
model: Model to train. | ||
num_iterations: Number of training iterations. | ||
""" | ||
print(f"{model=}") | ||
print(f"{num_iterations=}") | ||
if __name__ == "__main__": | ||
dcargs.cli(train) | ||
------------ | ||
|
||
.. raw:: html | ||
|
||
<kbd>python 17_flax_modules.py --help</kbd> | ||
|
||
.. program-output:: python ../../examples/17_flax_modules.py --help | ||
|
||
------------ | ||
|
||
.. raw:: html | ||
|
||
<kbd>python 17_flax_modules.py --model.layers 4</kbd> | ||
|
||
.. program-output:: python ../../examples/17_flax_modules.py --model.layers 4 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
"""If you use [Flax](https://github.com/google/flax), modules can be instantiated | ||
directly from `dcargs.cli`. | ||
Usage: | ||
`python ./17_flax_modules.py --help` | ||
`python ./17_flax_modules.py --model.layers 4` | ||
""" | ||
|
||
from flax import linen as nn | ||
from jax import numpy as jnp | ||
|
||
import dcargs | ||
|
||
|
||
class Classifier(nn.Module): | ||
layers: int | ||
"""Layers in our network.""" | ||
units: int = 32 | ||
"""Hidden unit count.""" | ||
output_dim: int = 10 | ||
"""Number of classes.""" | ||
|
||
@nn.compact | ||
def __call__(self, x: jnp.ndarray) -> jnp.ndarray: # type: ignore | ||
for i in range(self.layers - 1): | ||
x = nn.Dense( | ||
self.units, | ||
kernel_init=nn.initializers.kaiming_normal(), | ||
)(x) | ||
x = nn.relu(x) | ||
|
||
x = nn.Dense( | ||
self.output_dim, | ||
kernel_init=nn.initializers.xavier_normal(), | ||
)(x) | ||
x = nn.sigmoid(x) | ||
return x | ||
|
||
|
||
def train(model: Classifier, num_iterations: int = 1000) -> None: | ||
"""Train a model. | ||
Args: | ||
model: Model to train. | ||
num_iterations: Number of training iterations. | ||
""" | ||
print(f"{model=}") | ||
print(f"{num_iterations=}") | ||
|
||
|
||
if __name__ == "__main__": | ||
dcargs.cli(train) |
Oops, something went wrong.