Skip to content

Commit

Permalink
Update MLP Image Classification to Keras 3
Browse files Browse the repository at this point in the history
  • Loading branch information
nkovela1 committed Nov 8, 2023
1 parent 7cbf95c commit c50bd24
Show file tree
Hide file tree
Showing 3 changed files with 342 additions and 501 deletions.
178 changes: 113 additions & 65 deletions examples/vision/ipynb/mlp_image_classification.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
"\n",
"**Author:** [Khalid Salama](https://www.linkedin.com/in/khalid-salama-24403144/)<br>\n",
"**Date created:** 2021/05/30<br>\n",
"**Last modified:** 2021/05/30<br>\n",
"**Last modified:** 2023/08/03<br>\n",
"**Description:** Implementing the MLP-Mixer, FNet, and gMLP models for CIFAR-100 image classification."
]
},
Expand All @@ -32,15 +32,7 @@
"\n",
"The purpose of the example is not to compare between these models, as they might perform differently on\n",
"different datasets with well-tuned hyperparameters. Rather, it is to show simple implementations of their\n",
"main building blocks.\n",
"\n",
"This example requires TensorFlow 2.4 or higher, as well as\n",
"[TensorFlow Addons](https://www.tensorflow.org/addons/overview),\n",
"which can be installed using the following command:\n",
"\n",
"```shell\n",
"pip install -U tensorflow-addons\n",
"```"
"main building blocks."
]
},
{
Expand All @@ -61,10 +53,8 @@
"outputs": [],
"source": [
"import numpy as np\n",
"import tensorflow as tf\n",
"from tensorflow import keras\n",
"from tensorflow.keras import layers\n",
"import tensorflow_addons as tfa"
"import keras\n",
"from keras import layers"
]
},
{
Expand Down Expand Up @@ -112,7 +102,7 @@
"source": [
"weight_decay = 0.0001\n",
"batch_size = 128\n",
"num_epochs = 50\n",
"num_epochs = 1 # Recommended num_epochs = 50\n",
"dropout_rate = 0.2\n",
"image_size = 64 # We'll resize input images to this size.\n",
"patch_size = 8 # Size of the patches to be extracted from the input images.\n",
Expand Down Expand Up @@ -151,15 +141,11 @@
" # Augment data.\n",
" augmented = data_augmentation(inputs)\n",
" # Create patches.\n",
" patches = Patches(patch_size, num_patches)(augmented)\n",
" patches = Patches(patch_size)(augmented)\n",
" # Encode patches to generate a [batch_size, num_patches, embedding_dim] tensor.\n",
" x = layers.Dense(units=embedding_dim)(patches)\n",
" if positional_encoding:\n",
" positions = tf.range(start=0, limit=num_patches, delta=1)\n",
" position_embedding = layers.Embedding(\n",
" input_dim=num_patches, output_dim=embedding_dim\n",
" )(positions)\n",
" x = x + position_embedding\n",
" x = x + PositionEmbedding(sequence_length=num_patches)(x)\n",
" # Process x using the module blocks.\n",
" x = blocks(x)\n",
" # Apply global average pooling to generate a [batch_size, embedding_dim] representation tensor.\n",
Expand Down Expand Up @@ -195,8 +181,9 @@
"\n",
"def run_experiment(model):\n",
" # Create Adam optimizer with weight decay.\n",
" optimizer = tfa.optimizers.AdamW(\n",
" learning_rate=learning_rate, weight_decay=weight_decay,\n",
" optimizer = keras.optimizers.AdamW(\n",
" learning_rate=learning_rate,\n",
" weight_decay=weight_decay,\n",
" )\n",
" # Compile the model.\n",
" model.compile(\n",
Expand All @@ -212,7 +199,7 @@
" monitor=\"val_loss\", factor=0.5, patience=5\n",
" )\n",
" # Create an early stopping callback.\n",
" early_stopping = tf.keras.callbacks.EarlyStopping(\n",
" early_stopping = keras.callbacks.EarlyStopping(\n",
" monitor=\"val_loss\", patience=10, restore_best_weights=True\n",
" )\n",
" # Fit the model.\n",
Expand Down Expand Up @@ -256,9 +243,7 @@
" layers.Normalization(),\n",
" layers.Resizing(image_size, image_size),\n",
" layers.RandomFlip(\"horizontal\"),\n",
" layers.RandomZoom(\n",
" height_factor=0.2, width_factor=0.2\n",
" ),\n",
" layers.RandomZoom(height_factor=0.2, width_factor=0.2),\n",
" ],\n",
" name=\"data_augmentation\",\n",
")\n",
Expand Down Expand Up @@ -286,23 +271,88 @@
"source": [
"\n",
"class Patches(layers.Layer):\n",
" def __init__(self, patch_size, num_patches):\n",
" super().__init__()\n",
" def __init__(self, patch_size, **kwargs):\n",
" super().__init__(**kwargs)\n",
" self.patch_size = patch_size\n",
" self.num_patches = num_patches\n",
"\n",
" def call(self, images):\n",
" batch_size = tf.shape(images)[0]\n",
" patches = tf.image.extract_patches(\n",
" images=images,\n",
" sizes=[1, self.patch_size, self.patch_size, 1],\n",
" strides=[1, self.patch_size, self.patch_size, 1],\n",
" rates=[1, 1, 1, 1],\n",
" padding=\"VALID\",\n",
"\n",
" def call(self, x):\n",
" patches = keras.ops.image.extract_patches(x, self.patch_size)\n",
" batch_size = keras.ops.shape(patches)[0]\n",
" num_patches = keras.ops.shape(patches)[1] * keras.ops.shape(patches)[2]\n",
" patch_dim = keras.ops.shape(patches)[3]\n",
" out = keras.ops.reshape(patches, (batch_size, num_patches, patch_dim))\n",
" return out\n",
""
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"## Implement position embedding as a layer"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
"outputs": [],
"source": [
"\n",
"class PositionEmbedding(keras.layers.Layer):\n",
" def __init__(\n",
" self,\n",
" sequence_length,\n",
" initializer=\"glorot_uniform\",\n",
" **kwargs,\n",
" ):\n",
" super().__init__(**kwargs)\n",
" if sequence_length is None:\n",
" raise ValueError(\"`sequence_length` must be an Integer, received `None`.\")\n",
" self.sequence_length = int(sequence_length)\n",
" self.initializer = keras.initializers.get(initializer)\n",
"\n",
" def get_config(self):\n",
" config = super().get_config()\n",
" config.update(\n",
" {\n",
" \"sequence_length\": self.sequence_length,\n",
" \"initializer\": keras.initializers.serialize(self.initializer),\n",
" }\n",
" )\n",
" return config\n",
"\n",
" def build(self, input_shape):\n",
" feature_size = input_shape[-1]\n",
" self.position_embeddings = self.add_weight(\n",
" name=\"embeddings\",\n",
" shape=[self.sequence_length, feature_size],\n",
" initializer=self.initializer,\n",
" trainable=True,\n",
" )\n",
"\n",
" super().build(input_shape)\n",
"\n",
" def call(self, inputs, start_index=0):\n",
" shape = keras.ops.shape(inputs)\n",
" feature_length = shape[-1]\n",
" sequence_length = shape[-2]\n",
" # trim to match the length of the input sequence, which might be less\n",
" # than the sequence_length of the layer.\n",
" position_embeddings = keras.ops.convert_to_tensor(self.position_embeddings)\n",
" position_embeddings = keras.ops.slice(\n",
" position_embeddings,\n",
" (start_index, 0),\n",
" (sequence_length, feature_length),\n",
" )\n",
" patch_dims = patches.shape[-1]\n",
" patches = tf.reshape(patches, [batch_size, self.num_patches, patch_dims])\n",
" return patches\n",
" return keras.ops.broadcast_to(position_embeddings, shape)\n",
"\n",
" def compute_output_shape(self, input_shape):\n",
" return input_shape\n",
""
]
},
Expand All @@ -320,7 +370,7 @@
"1. One applied independently to image patches, which mixes the per-location features.\n",
"2. The other applied across patches (along channels), which mixes spatial information.\n",
"\n",
"This is similar to a [depthwise separable convolution based model](https://arxiv.org/pdf/1610.02357.pdf)\n",
"This is similar to a [depthwise separable convolution based model](https://arxiv.org/abs/1610.02357)\n",
"such as the Xception model, but with two chained dense transforms, no max pooling, and layer normalization\n",
"instead of batch normalization."
]
Expand Down Expand Up @@ -349,31 +399,32 @@
"\n",
" self.mlp1 = keras.Sequential(\n",
" [\n",
" layers.Dense(units=num_patches),\n",
" tfa.layers.GELU(),\n",
" layers.Dense(units=num_patches, activation=\"gelu\"),\n",
" layers.Dense(units=num_patches),\n",
" layers.Dropout(rate=dropout_rate),\n",
" ]\n",
" )\n",
" self.mlp2 = keras.Sequential(\n",
" [\n",
" layers.Dense(units=num_patches),\n",
" tfa.layers.GELU(),\n",
" layers.Dense(units=embedding_dim),\n",
" layers.Dense(units=num_patches, activation=\"gelu\"),\n",
" layers.Dense(units=hidden_units),\n",
" layers.Dropout(rate=dropout_rate),\n",
" ]\n",
" )\n",
" self.normalize = layers.LayerNormalization(epsilon=1e-6)\n",
"\n",
" def build(self, input_shape):\n",
" return super().build(input_shape)\n",
"\n",
" def call(self, inputs):\n",
" # Apply layer normalization.\n",
" x = self.normalize(inputs)\n",
" # Transpose inputs from [num_batches, num_patches, hidden_units] to [num_batches, hidden_units, num_patches].\n",
" x_channels = tf.linalg.matrix_transpose(x)\n",
" x_channels = keras.ops.transpose(x, axes=(0, 2, 1))\n",
" # Apply mlp1 on each channel independently.\n",
" mlp1_outputs = self.mlp1(x_channels)\n",
" # Transpose mlp1_outputs from [num_batches, hidden_dim, num_patches] to [num_batches, num_patches, hidden_units].\n",
" mlp1_outputs = tf.linalg.matrix_transpose(mlp1_outputs)\n",
" mlp1_outputs = keras.ops.transpose(mlp1_outputs, axes=(0, 2, 1))\n",
" # Add skip connection.\n",
" x = mlp1_outputs + inputs\n",
" # Apply layer normalization.\n",
Expand Down Expand Up @@ -466,13 +517,12 @@
"source": [
"\n",
"class FNetLayer(layers.Layer):\n",
" def __init__(self, num_patches, embedding_dim, dropout_rate, *args, **kwargs):\n",
" def __init__(self, embedding_dim, dropout_rate, *args, **kwargs):\n",
" super().__init__(*args, **kwargs)\n",
"\n",
" self.ffn = keras.Sequential(\n",
" [\n",
" layers.Dense(units=embedding_dim),\n",
" tfa.layers.GELU(),\n",
" layers.Dense(units=embedding_dim, activation=\"gelu\"),\n",
" layers.Dropout(rate=dropout_rate),\n",
" layers.Dense(units=embedding_dim),\n",
" ]\n",
Expand All @@ -483,10 +533,9 @@
"\n",
" def call(self, inputs):\n",
" # Apply fourier transformations.\n",
" x = tf.cast(\n",
" tf.signal.fft2d(tf.cast(inputs, dtype=tf.dtypes.complex64)),\n",
" dtype=tf.dtypes.float32,\n",
" )\n",
" real_part = inputs\n",
" im_part = keras.ops.zeros_like(inputs)\n",
" x = keras.ops.fft2((real_part, im_part))[0]\n",
" # Add skip connection.\n",
" x = x + inputs\n",
" # Apply layer normalization.\n",
Expand Down Expand Up @@ -521,7 +570,7 @@
"outputs": [],
"source": [
"fnet_blocks = keras.Sequential(\n",
" [FNetLayer(num_patches, embedding_dim, dropout_rate) for _ in range(num_blocks)]\n",
" [FNetLayer(embedding_dim, dropout_rate) for _ in range(num_blocks)]\n",
")\n",
"learning_rate = 0.001\n",
"fnet_classifier = build_classifier(fnet_blocks, positional_encoding=True)\n",
Expand Down Expand Up @@ -581,8 +630,7 @@
"\n",
" self.channel_projection1 = keras.Sequential(\n",
" [\n",
" layers.Dense(units=embedding_dim * 2),\n",
" tfa.layers.GELU(),\n",
" layers.Dense(units=embedding_dim * 2, activation=\"gelu\"),\n",
" layers.Dropout(rate=dropout_rate),\n",
" ]\n",
" )\n",
Expand All @@ -598,14 +646,14 @@
"\n",
" def spatial_gating_unit(self, x):\n",
" # Split x along the channel dimensions.\n",
" # Tensors u and v will in th shape of [batch_size, num_patchs, embedding_dim].\n",
" u, v = tf.split(x, num_or_size_splits=2, axis=2)\n",
" # Tensors u and v will in the shape of [batch_size, num_patchs, embedding_dim].\n",
" u, v = keras.ops.split(x, indices_or_sections=2, axis=2)\n",
" # Apply layer normalization.\n",
" v = self.normalize2(v)\n",
" # Apply spatial projection.\n",
" v_channels = tf.linalg.matrix_transpose(v)\n",
" v_channels = keras.ops.transpose(v, axes=(0, 2, 1))\n",
" v_projected = self.spatial_projection(v_channels)\n",
" v_projected = tf.linalg.matrix_transpose(v_projected)\n",
" v_projected = keras.ops.transpose(v_projected, axes=(0, 2, 1))\n",
" # Apply element-wise multiplication.\n",
" return u * v_projected\n",
"\n",
Expand Down Expand Up @@ -659,7 +707,7 @@
"source": [
"As shown in the [gMLP](https://arxiv.org/abs/2105.08050) paper,\n",
"better results can be achieved by increasing the embedding dimensions,\n",
"increasing, increasing the number of gMLP blocks, and training the model for longer.\n",
"increasing the number of gMLP blocks, and training the model for longer.\n",
"You may also try to increase the size of the input images and use different patch sizes.\n",
"Note that, the paper used advanced regularization strategies, such as MixUp and CutMix,\n",
"as well as AutoAugment."
Expand Down
Loading

0 comments on commit c50bd24

Please sign in to comment.