Update EANet to Keras 3
nkovela1 committed Nov 9, 2023
1 parent a0d1df5 commit 1671337
Showing 4 changed files with 115 additions and 224 deletions.
74 changes: 34 additions & 40 deletions examples/vision/eanet.py
@@ -2,7 +2,7 @@
Title: Image classification with EANet (External Attention Transformer)
Author: [ZhiYong Chang](https://github.com/czy00000)
Date created: 2021/10/19
Last modified: 2021/10/19
Last modified: 2023/07/18
Description: Image classification with a Transformer that leverages external attention.
Accelerator: GPU
"""
@@ -18,25 +18,16 @@
linear layers and two normalization layers. It conveniently replaces self-attention
as used in existing architectures. External attention has linear complexity, as it only
implicitly considers the correlations between all samples.
This example requires TensorFlow 2.5 or higher, as well as
[TensorFlow Addons](https://www.tensorflow.org/addons/overview) package,
which can be installed using the following command:
```python
pip install -U tensorflow-addons
```
"""

"""
## Setup
"""

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_addons as tfa
import keras
from keras import layers
from keras import ops

import matplotlib.pyplot as plt


Expand All @@ -62,7 +53,7 @@
label_smoothing = 0.1
validation_split = 0.2
batch_size = 128
num_epochs = 50
num_epochs = 1 # Recommended num_epochs = 1.
patch_size = 2 # Size of the patches to be extracted from the input images.
num_patches = (input_shape[0] // patch_size) ** 2 # Number of patches
embedding_dim = 64 # Number of hidden units.
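A quick sanity check on `num_patches`: CIFAR-100 images are 32x32, so with `patch_size = 2` the patch grid is 16x16:

```python
# Worked arithmetic for num_patches, assuming 32x32 inputs:
image_size = 32
patch_size = 2
num_patches = (image_size // patch_size) ** 2
print(num_patches)  # 256
```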
@@ -104,18 +95,11 @@ def __init__(self, patch_size, **kwargs):
super().__init__(**kwargs)
self.patch_size = patch_size

def call(self, images):
batch_size = tf.shape(images)[0]
patches = tf.image.extract_patches(
images=images,
sizes=(1, self.patch_size, self.patch_size, 1),
strides=(1, self.patch_size, self.patch_size, 1),
rates=(1, 1, 1, 1),
padding="VALID",
)
patch_dim = patches.shape[-1]
patch_num = patches.shape[1]
return tf.reshape(patches, (batch_size, patch_num * patch_num, patch_dim))
def call(self, x):
B, C = ops.shape(x)[0], ops.shape(x)[-1]
x = ops.image.extract_patches(x, self.patch_size)
x = ops.reshape(x, (B, -1, self.patch_size * self.patch_size * C))
return x


class PatchEmbedding(layers.Layer):
@@ -126,7 +110,7 @@ def __init__(self, num_patch, embed_dim, **kwargs):
self.pos_embed = layers.Embedding(input_dim=num_patch, output_dim=embed_dim)

def call(self, patch):
pos = tf.range(start=0, limit=self.num_patch, delta=1)
pos = ops.arange(start=0, stop=self.num_patch, step=1)
return self.proj(patch) + self.pos_embed(pos)
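A shape walk-through of the two layers above may help; it assumes the hyperparameters defined earlier (32x32 RGB inputs, `patch_size = 2`, `embedding_dim = 64`) and a hypothetical batch of 8:

```python
images = keras.random.normal((8, 32, 32, 3))
# ops.image.extract_patches with size 2 produces a 16x16 grid of patches,
# each flattened to 2 * 2 * 3 = 12 values, then reshaped to (8, 256, 12).
patches = PatchExtract(patch_size=2)(images)
# Dense projection plus a learned position embedding: (8, 256, 64).
embedded = PatchEmbedding(num_patch=256, embed_dim=64)(patches)
```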


@@ -136,29 +120,37 @@ def call(self, patch):


def external_attention(
x, dim, num_heads, dim_coefficient=4, attention_dropout=0, projection_dropout=0
x,
dim,
num_heads,
dim_coefficient=4,
attention_dropout=0,
projection_dropout=0,
):
_, num_patch, channel = x.shape
assert dim % num_heads == 0
num_heads = num_heads * dim_coefficient

x = layers.Dense(dim * dim_coefficient)(x)
# create tensor [batch_size, num_patches, num_heads, dim*dim_coefficient//num_heads]
x = tf.reshape(
x, shape=(-1, num_patch, num_heads, dim * dim_coefficient // num_heads)
)
x = tf.transpose(x, perm=[0, 2, 1, 3])
x = ops.reshape(x, (-1, num_patch, num_heads, dim * dim_coefficient // num_heads))
x = ops.transpose(x, axes=[0, 2, 1, 3])
# a linear layer M_k
attn = layers.Dense(dim // dim_coefficient)(x)
# normalize attention map
attn = layers.Softmax(axis=2)(attn)
# double-normalization
attn = attn / (1e-9 + tf.reduce_sum(attn, axis=-1, keepdims=True))
attn = layers.Lambda(
lambda attn: ops.divide(
attn,
ops.convert_to_tensor(1e-9) + ops.sum(attn, axis=-1, keepdims=True),
)
)(attn)
attn = layers.Dropout(attention_dropout)(attn)
# a linear layer M_v
x = layers.Dense(dim * dim_coefficient // num_heads)(attn)
x = tf.transpose(x, perm=[0, 2, 1, 3])
x = tf.reshape(x, [-1, num_patch, dim * dim_coefficient])
x = ops.transpose(x, axes=[0, 2, 1, 3])
x = ops.reshape(x, [-1, num_patch, dim * dim_coefficient])
# a linear layer to project original dim
x = layers.Dense(dim)(x)
x = layers.Dropout(projection_dropout)(x)
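The double-normalization above is easy to misread: the `Softmax` normalizes over the patch axis (axis 2), and the division then renormalizes each row over the memory axis. A tiny numeric check with made-up shapes:

```python
a = keras.random.uniform((2, 4, 256, 16))  # (batch, heads, patches, memory slots)
a = ops.softmax(a, axis=2)  # first normalization: over patches
a = a / (1e-9 + ops.sum(a, axis=-1, keepdims=True))  # second: over memory slots
print(ops.sum(a, axis=-1)[0, 0, 0])  # ~1.0 after the second step
```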
@@ -171,7 +163,7 @@ def external_attention(


def mlp(x, embedding_dim, mlp_dim, drop_rate=0.2):
x = layers.Dense(mlp_dim, activation=tf.nn.gelu)(x)
x = layers.Dense(mlp_dim, activation=ops.gelu)(x)
x = layers.Dropout(drop_rate)(x)
x = layers.Dense(embedding_dim)(x)
x = layers.Dropout(drop_rate)(x)
@@ -206,7 +198,9 @@ def transformer_encoder(
)
elif attention_type == "self_attention":
x = layers.MultiHeadAttention(
num_heads=num_heads, key_dim=embedding_dim, dropout=attention_dropout
num_heads=num_heads,
key_dim=embedding_dim,
dropout=attention_dropout,
)(x, x)
x = layers.add([x, residual_1])
residual_2 = x
@@ -256,7 +250,7 @@ def get_model(attention_type="external_attention"):
attention_type,
)

x = layers.GlobalAvgPool1D()(x)
x = layers.GlobalAveragePooling1D()(x)
outputs = layers.Dense(num_classes, activation="softmax")(x)
model = keras.Model(inputs=inputs, outputs=outputs)
return model
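Because `get_model` takes the attention variant as a string flag, both models in the ViT comparison can come from the same builder; a usage sketch:

```python
ea_model = get_model(attention_type="external_attention")
sa_model = get_model(attention_type="self_attention")
ea_model.summary()
```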
@@ -272,7 +266,7 @@ def get_model(attention_type="external_attention"):

model.compile(
loss=keras.losses.CategoricalCrossentropy(label_smoothing=label_smoothing),
optimizer=tfa.optimizers.AdamW(
optimizer=keras.optimizers.AdamW(
learning_rate=learning_rate, weight_decay=weight_decay
),
metrics=[
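One migration note worth spelling out: `keras.optimizers.AdamW` ships with Keras 3, so the TensorFlow Addons dependency is gone entirely. It accepts the same two arguments used here (the values below are illustrative, not the example's):

```python
opt = keras.optimizers.AdamW(learning_rate=1e-3, weight_decay=1e-4)
```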
Binary file added examples/vision/img/eanet/eanet_24_0.png
78 changes: 38 additions & 40 deletions examples/vision/ipynb/eanet.ipynb
@@ -10,7 +10,7 @@
"\n",
"**Author:** [ZhiYong Chang](https://github.com/czy00000)<br>\n",
"**Date created:** 2021/10/19<br>\n",
"**Last modified:** 2021/10/19<br>\n",
"**Last modified:** 2023/07/18<br>\n",
"**Description:** Image classification with a Transformer that leverages external attention."
]
},
@@ -21,20 +21,15 @@
},
"source": [
"## Introduction\n",
"\n",
"This example implements the [EANet](https://arxiv.org/abs/2105.02358)\n",
"model for image classification, and demonstrates it on the CIFAR-100 dataset.\n",
"EANet introduces a novel attention mechanism\n",
"named ***external attention***, based on two external, small, learnable, and\n",
"shared memories, which can be implemented easily by simply using two cascaded\n",
"linear layers and two normalization layers. It conveniently replaces self-attention\n",
"as used in existing architectures. External attention has linear complexity, as it only\n",
"implicitly considers the correlations between all samples.\n",
"This example requires TensorFlow 2.5 or higher, as well as\n",
"[TensorFlow Addons](https://www.tensorflow.org/addons/overview) package,\n",
"which can be installed using the following command:\n",
"```python\n",
"pip install -U tensorflow-addons\n",
"```"
"implicitly considers the correlations between all samples."
]
},
{
@@ -54,11 +49,10 @@
},
"outputs": [],
"source": [
"import numpy as np\n",
"import tensorflow as tf\n",
"from tensorflow import keras\n",
"from tensorflow.keras import layers\n",
"import tensorflow_addons as tfa\n",
"import keras\n",
"from keras import layers\n",
"from keras import ops\n",
"\n",
"import matplotlib.pyplot as plt\n",
""
]
@@ -112,7 +106,7 @@
"label_smoothing = 0.1\n",
"validation_split = 0.2\n",
"batch_size = 128\n",
"num_epochs = 50\n",
"num_epochs = 1 # Recommended num_epochs = 1.\n",
"patch_size = 2 # Size of the patches to be extracted from the input images.\n",
"num_patches = (input_shape[0] // patch_size) ** 2 # Number of patch\n",
"embedding_dim = 64 # Number of hidden units.\n",
@@ -182,18 +176,11 @@
" super().__init__(**kwargs)\n",
" self.patch_size = patch_size\n",
"\n",
" def call(self, images):\n",
" batch_size = tf.shape(images)[0]\n",
" patches = tf.image.extract_patches(\n",
" images=images,\n",
" sizes=(1, self.patch_size, self.patch_size, 1),\n",
" strides=(1, self.patch_size, self.patch_size, 1),\n",
" rates=(1, 1, 1, 1),\n",
" padding=\"VALID\",\n",
" )\n",
" patch_dim = patches.shape[-1]\n",
" patch_num = patches.shape[1]\n",
" return tf.reshape(patches, (batch_size, patch_num * patch_num, patch_dim))\n",
" def call(self, x):\n",
" B, C = ops.shape(x)[0], ops.shape(x)[-1]\n",
" x = ops.image.extract_patches(x, self.patch_size)\n",
" x = ops.reshape(x, (B, -1, self.patch_size * self.patch_size * C))\n",
" return x\n",
"\n",
"\n",
"class PatchEmbedding(layers.Layer):\n",
@@ -204,7 +191,7 @@
" self.pos_embed = layers.Embedding(input_dim=num_patch, output_dim=embed_dim)\n",
"\n",
" def call(self, patch):\n",
" pos = tf.range(start=0, limit=self.num_patch, delta=1)\n",
" pos = ops.arange(start=0, stop=self.num_patch, step=1)\n",
" return self.proj(patch) + self.pos_embed(pos)\n",
""
]
@@ -228,29 +215,37 @@
"source": [
"\n",
"def external_attention(\n",
" x, dim, num_heads, dim_coefficient=4, attention_dropout=0, projection_dropout=0\n",
" x,\n",
" dim,\n",
" num_heads,\n",
" dim_coefficient=4,\n",
" attention_dropout=0,\n",
" projection_dropout=0,\n",
"):\n",
" _, num_patch, channel = x.shape\n",
" assert dim % num_heads == 0\n",
" num_heads = num_heads * dim_coefficient\n",
"\n",
" x = layers.Dense(dim * dim_coefficient)(x)\n",
" # create tensor [batch_size, num_patches, num_heads, dim*dim_coefficient//num_heads]\n",
" x = tf.reshape(\n",
" x, shape=(-1, num_patch, num_heads, dim * dim_coefficient // num_heads)\n",
" )\n",
" x = tf.transpose(x, perm=[0, 2, 1, 3])\n",
" x = ops.reshape(x, (-1, num_patch, num_heads, dim * dim_coefficient // num_heads))\n",
" x = ops.transpose(x, axes=[0, 2, 1, 3])\n",
" # a linear layer M_k\n",
" attn = layers.Dense(dim // dim_coefficient)(x)\n",
" # normalize attention map\n",
" attn = layers.Softmax(axis=2)(attn)\n",
" # dobule-normalization\n",
" attn = attn / (1e-9 + tf.reduce_sum(attn, axis=-1, keepdims=True))\n",
" attn = layers.Lambda(\n",
" lambda attn: ops.divide(\n",
" attn,\n",
" ops.convert_to_tensor(1e-9) + ops.sum(attn, axis=-1, keepdims=True),\n",
" )\n",
" )(attn)\n",
" attn = layers.Dropout(attention_dropout)(attn)\n",
" # a linear layer M_v\n",
" x = layers.Dense(dim * dim_coefficient // num_heads)(attn)\n",
" x = tf.transpose(x, perm=[0, 2, 1, 3])\n",
" x = tf.reshape(x, [-1, num_patch, dim * dim_coefficient])\n",
" x = ops.transpose(x, axes=[0, 2, 1, 3])\n",
" x = ops.reshape(x, [-1, num_patch, dim * dim_coefficient])\n",
" # a linear layer to project original dim\n",
" x = layers.Dense(dim)(x)\n",
" x = layers.Dropout(projection_dropout)(x)\n",
@@ -277,7 +272,7 @@
"source": [
"\n",
"def mlp(x, embedding_dim, mlp_dim, drop_rate=0.2):\n",
" x = layers.Dense(mlp_dim, activation=tf.nn.gelu)(x)\n",
" x = layers.Dense(mlp_dim, activation=ops.gelu)(x)\n",
" x = layers.Dropout(drop_rate)(x)\n",
" x = layers.Dense(embedding_dim)(x)\n",
" x = layers.Dropout(drop_rate)(x)\n",
@@ -326,7 +321,9 @@
" )\n",
" elif attention_type == \"self_attention\":\n",
" x = layers.MultiHeadAttention(\n",
" num_heads=num_heads, key_dim=embedding_dim, dropout=attention_dropout\n",
" num_heads=num_heads,\n",
" key_dim=embedding_dim,\n",
" dropout=attention_dropout,\n",
" )(x, x)\n",
" x = layers.add([x, residual_1])\n",
" residual_2 = x\n",
@@ -395,7 +392,7 @@
" attention_type,\n",
" )\n",
"\n",
" x = layers.GlobalAvgPool1D()(x)\n",
" x = layers.GlobalAveragePooling1D()(x)\n",
" outputs = layers.Dense(num_classes, activation=\"softmax\")(x)\n",
" model = keras.Model(inputs=inputs, outputs=outputs)\n",
" return model\n",
@@ -424,7 +421,7 @@
"\n",
"model.compile(\n",
" loss=keras.losses.CategoricalCrossentropy(label_smoothing=label_smoothing),\n",
" optimizer=tfa.optimizers.AdamW(\n",
" optimizer=keras.optimizers.AdamW(\n",
" learning_rate=learning_rate, weight_decay=weight_decay\n",
" ),\n",
" metrics=[\n",
@@ -504,6 +501,7 @@
"and the same hyperparameters, The EANet model we just trained has just 0.3M parameters,\n",
"and it gets us to ~73% test top-5 accuracy and ~43% top-1 accuracy. This fully demonstrates the\n",
"effectiveness of external attention.\n",
"\n",
"We only show the training\n",
"process of EANet, you can train Vit under the same experimental conditions and observe\n",
"the test results."
@@ -514,7 +512,7 @@
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"name": "EANet",
"name": "eanet",
"private_outputs": false,
"provenance": [],
"toc_visible": true
