-
Notifications
You must be signed in to change notification settings - Fork 177
/
haiku_simple.py
176 lines (140 loc) · 5.53 KB
/
haiku_simple.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
"""
Optuna example that optimizes a neural network classifier configuration for the
MNIST dataset using Jax and Haiku.
In this example, we optimize the validation accuracy of MNIST classification using
JAX and Haiku. We optimize the number of units in the linear layers and the learning rate of the optimizer.
The example code is based on https://github.com/deepmind/dm-haiku/blob/master/examples/mnist.py
"""
import os
import urllib
import urllib.request
from typing import Any
from typing import Generator
from typing import Mapping
from typing import Tuple

import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
import optax
import optuna
import tensorflow_datasets as tfds
# TODO(crcrpar): Remove the below three lines once everything is ok.
# Install a global opener that sends a browser-like User-Agent header, so the
# MNIST download is not rejected with "HTTP Error 403: Forbidden".
opener = urllib.request.build_opener()
opener.addheaders = [("User-agent", "Mozilla/5.0")]
urllib.request.install_opener(opener)
# Type aliases for readability.
OptState = Any  # Opaque optax optimizer state (a pytree).
Batch = Mapping[str, np.ndarray]  # Maps "image" / "label" to numpy arrays.

# Experiment sizes are kept small so each Optuna trial finishes quickly.
BATCH_SIZE = 128
TRAIN_STEPS = 1000
N_TRAIN_SAMPLES = 3000
N_VALID_SAMPLES = 1000

# disable tf's warning messages
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
def load_dataset(
    split: str,
    *,
    is_training: bool,
    batch_size: int,
    sample_size: int,
) -> Generator[Batch, None, None]:
    """Load a sub-sample of MNIST as an endless iterator of numpy batches.

    Args:
        split: Dataset split to load, e.g. ``"train"`` or ``"test"``.
        is_training: If True, shuffle the examples (with a fixed seed, so
            every trial sees the same data order).
        batch_size: Number of examples per yielded batch.
        sample_size: Number of examples kept from the split.

    Returns:
        An infinite iterator over ``{"image", "label"}`` numpy batches.
    """
    dataset = tfds.load("mnist:3.*.*", split=split)
    # Cache the sub-sample in memory and repeat it forever.
    dataset = dataset.take(sample_size).cache().repeat()
    if is_training:
        dataset = dataset.shuffle(sample_size, seed=0)
    batched = dataset.batch(batch_size)
    return iter(tfds.as_numpy(batched))
def objective(trial):
    """Optuna objective: train a small MLP on MNIST and return its best test accuracy.

    Args:
        trial: An Optuna trial, used to draw hyperparameters and to report
            intermediate accuracies for pruning.

    Returns:
        The highest test-set accuracy observed at any evaluation step.

    Raises:
        optuna.exceptions.TrialPruned: If the pruner stops this trial early.
    """
    # Make datasets.
    train = load_dataset(
        "train", is_training=True, batch_size=BATCH_SIZE, sample_size=N_TRAIN_SAMPLES
    )
    train_eval = load_dataset(
        "train", is_training=False, batch_size=BATCH_SIZE, sample_size=N_TRAIN_SAMPLES
    )
    test_eval = load_dataset(
        "test", is_training=False, batch_size=BATCH_SIZE, sample_size=N_VALID_SAMPLES
    )
    # Draw the hyperparameters: two hidden-layer widths and a log-uniform learning rate.
    n_units_l1 = trial.suggest_int("n_units_l1", 4, 128)
    n_units_l2 = trial.suggest_int("n_units_l2", 4, 128)
    lr = trial.suggest_float("lr", 1e-5, 1e-1, log=True)

    # Define feed-forward function by using sampled parameters
    def net_fn(batch: Batch) -> jnp.ndarray:
        """Standard MLP network."""
        # Scale uint8 pixel values into [0, 1] floats.
        x = batch["image"].astype(jnp.float32) / 255.0
        mlp = hk.Sequential(
            [
                hk.Flatten(),
                hk.Linear(n_units_l1),
                jax.nn.relu,
                hk.Linear(n_units_l2),
                jax.nn.relu,
                hk.Linear(10),  # One logit per MNIST class.
            ]
        )
        return mlp(x)

    # Make the network and optimiser.
    net = hk.without_apply_rng(hk.transform(net_fn))
    opt = optax.adam(lr)

    # Training loss (cross-entropy).
    def loss(params: hk.Params, batch: Batch) -> jnp.ndarray:
        """Compute the loss of the network, including L2."""
        logits = net.apply(params, batch)
        labels = jax.nn.one_hot(batch["label"], 10)
        # L2 penalty summed over every parameter leaf of the pytree.
        l2_loss = 0.5 * sum(jnp.sum(jnp.square(p)) for p in jax.tree_util.tree_leaves(params))
        # Mean cross-entropy over the batch.
        softmax_xent = -jnp.sum(labels * jax.nn.log_softmax(logits))
        softmax_xent /= labels.shape[0]
        return softmax_xent + 1e-4 * l2_loss

    # Evaluation metric (classification accuracy).
    @jax.jit
    def accuracy(params: hk.Params, batch: Batch) -> jnp.ndarray:
        predictions = net.apply(params, batch)
        return jnp.mean(jnp.argmax(predictions, axis=-1) == batch["label"])

    @jax.jit
    def update(
        params: hk.Params,
        opt_state: OptState,
        batch: Batch,
    ) -> Tuple[hk.Params, OptState]:
        """Learning rule (stochastic gradient descent)."""
        grads = jax.grad(loss)(params, batch)
        updates, opt_state = opt.update(grads, opt_state)
        new_params = optax.apply_updates(params, updates)
        return new_params, opt_state

    # Initialize network and optimiser; note we draw an input to get shapes.
    params = net.init(jax.random.PRNGKey(42), next(train))
    opt_state = opt.init(params)
    best_test_accuracy = 0.0
    # Train/eval loop.
    for step in range(1, TRAIN_STEPS + 1):
        if step % 100 == 0:
            # Periodically evaluate classification accuracy on train & test sets.
            train_accuracy = accuracy(params, next(train_eval))
            test_accuracy = accuracy(params, next(test_eval))
            # Pull both device arrays back to host in a single transfer.
            train_accuracy, test_accuracy = jax.device_get((train_accuracy, test_accuracy))
            print(
                f"[Step {step:5d}] Train / Test accuracy: "
                f"{train_accuracy:.3f} / {test_accuracy:.3f}."
            )
            # Handle pruning based on the intermediate value.
            trial.report(test_accuracy, step)
            if trial.should_prune():
                raise optuna.exceptions.TrialPruned()
            best_test_accuracy = max(best_test_accuracy, test_accuracy)
        # Do SGD on a batch of training examples.
        params, opt_state = update(params, opt_state, next(train))
    return best_test_accuracy
if __name__ == "__main__":
    # Median pruning: start pruning after two complete trials; intermediate
    # values are compared at the steps reported inside `objective`.
    median_pruner = optuna.pruners.MedianPruner(n_startup_trials=2, interval_steps=1000)
    study = optuna.create_study(direction="maximize", pruner=median_pruner)
    study.optimize(objective, n_trials=100, timeout=600)

    print("Number of finished trials: {}".format(len(study.trials)))
    print("Best trial:")
    best = study.best_trial
    print("  Value: {}".format(best.value))
    print("  Params: ")
    for name, val in best.params.items():
        print("    {}: {}".format(name, val))