diff --git a/examples/vision/eanet.py b/examples/vision/eanet.py
index 5eda90f8a73..c17e4d30d16 100644
--- a/examples/vision/eanet.py
+++ b/examples/vision/eanet.py
@@ -2,7 +2,7 @@
Title: Image classification with EANet (External Attention Transformer)
Author: [ZhiYong Chang](https://github.com/czy00000)
Date created: 2021/10/19
-Last modified: 2021/10/19
+Last modified: 2023/07/18
Description: Image classification with a Transformer that leverages external attention.
Accelerator: GPU
"""
@@ -18,25 +18,16 @@
linear layers and two normalization layers. It conveniently replaces self-attention
as used in existing architectures. External attention has linear complexity, as it only
implicitly considers the correlations between all samples.
-
-This example requires TensorFlow 2.5 or higher, as well as
-[TensorFlow Addons](https://www.tensorflow.org/addons/overview) package,
-which can be installed using the following command:
-
-```python
-pip install -U tensorflow-addons
-```
"""
"""
## Setup
"""
-import numpy as np
-import tensorflow as tf
-from tensorflow import keras
-from tensorflow.keras import layers
-import tensorflow_addons as tfa
+import keras
+from keras import layers
+from keras import ops
+
import matplotlib.pyplot as plt
@@ -62,7 +53,7 @@
label_smoothing = 0.1
validation_split = 0.2
batch_size = 128
-num_epochs = 50
+num_epochs = 1 # Reduced for a quick demo run; recommended num_epochs = 50.
patch_size = 2 # Size of the patches to be extracted from the input images.
num_patches = (input_shape[0] // patch_size) ** 2 # Number of patches
embedding_dim = 64 # Number of hidden units.
@@ -104,18 +95,11 @@ def __init__(self, patch_size, **kwargs):
super().__init__(**kwargs)
self.patch_size = patch_size
- def call(self, images):
- batch_size = tf.shape(images)[0]
- patches = tf.image.extract_patches(
- images=images,
- sizes=(1, self.patch_size, self.patch_size, 1),
- strides=(1, self.patch_size, self.patch_size, 1),
- rates=(1, 1, 1, 1),
- padding="VALID",
- )
- patch_dim = patches.shape[-1]
- patch_num = patches.shape[1]
- return tf.reshape(patches, (batch_size, patch_num * patch_num, patch_dim))
+ def call(self, x):
+ B, C = ops.shape(x)[0], ops.shape(x)[-1]
+ x = ops.image.extract_patches(x, self.patch_size)
+ x = ops.reshape(x, (B, -1, self.patch_size * self.patch_size * C))
+ return x
class PatchEmbedding(layers.Layer):
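
The rewritten `PatchExtract.call` relies on `ops.image.extract_patches` defaulting its strides to the patch size, so patches do not overlap. A quick shape check (a review sketch, not part of the patch; shapes assume the example's 32x32x3 CIFAR inputs with `patch_size = 2`):

```python
import keras
from keras import ops

x = keras.random.uniform((1, 32, 32, 3))
# Scalar size, no explicit strides: non-overlapping 2x2 patches,
# giving shape (1, 16, 16, 2 * 2 * 3).
patches = ops.image.extract_patches(x, 2)
# Flatten the spatial grid into a sequence, as PatchExtract does:
# (1, 256, 12) == (batch, num_patches, patch_size * patch_size * channels).
flat = ops.reshape(patches, (1, -1, 2 * 2 * 3))
print(patches.shape, flat.shape)
```
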
@@ -126,7 +110,7 @@ def __init__(self, num_patch, embed_dim, **kwargs):
self.pos_embed = layers.Embedding(input_dim=num_patch, output_dim=embed_dim)
def call(self, patch):
- pos = tf.range(start=0, limit=self.num_patch, delta=1)
+ pos = ops.arange(start=0, stop=self.num_patch, step=1)
return self.proj(patch) + self.pos_embed(pos)
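
The `tf.range` to `ops.arange` change is one-to-one apart from argument names: `limit` and `delta` become numpy-style `stop` and `step`. A minimal sketch (256 is this example's `num_patches`, assumed from `(32 // 2) ** 2`):

```python
from keras import ops

# Position indices 0..255 for the learned positional embedding.
pos = ops.arange(start=0, stop=256, step=1)
```
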
@@ -136,7 +120,12 @@ def call(self, patch):
def external_attention(
- x, dim, num_heads, dim_coefficient=4, attention_dropout=0, projection_dropout=0
+ x,
+ dim,
+ num_heads,
+ dim_coefficient=4,
+ attention_dropout=0,
+ projection_dropout=0,
):
_, num_patch, channel = x.shape
assert dim % num_heads == 0
@@ -144,21 +133,24 @@ def external_attention(
x = layers.Dense(dim * dim_coefficient)(x)
# create tensor [batch_size, num_patches, num_heads, dim*dim_coefficient//num_heads]
- x = tf.reshape(
- x, shape=(-1, num_patch, num_heads, dim * dim_coefficient // num_heads)
- )
- x = tf.transpose(x, perm=[0, 2, 1, 3])
+ x = ops.reshape(x, (-1, num_patch, num_heads, dim * dim_coefficient // num_heads))
+ x = ops.transpose(x, axes=[0, 2, 1, 3])
# a linear layer M_k
attn = layers.Dense(dim // dim_coefficient)(x)
# normalize attention map
attn = layers.Softmax(axis=2)(attn)
    # double-normalization
- attn = attn / (1e-9 + tf.reduce_sum(attn, axis=-1, keepdims=True))
+ attn = layers.Lambda(
+ lambda attn: ops.divide(
+ attn,
+ ops.convert_to_tensor(1e-9) + ops.sum(attn, axis=-1, keepdims=True),
+ )
+ )(attn)
attn = layers.Dropout(attention_dropout)(attn)
# a linear layer M_v
x = layers.Dense(dim * dim_coefficient // num_heads)(attn)
- x = tf.transpose(x, perm=[0, 2, 1, 3])
- x = tf.reshape(x, [-1, num_patch, dim * dim_coefficient])
+ x = ops.transpose(x, axes=[0, 2, 1, 3])
+ x = ops.reshape(x, [-1, num_patch, dim * dim_coefficient])
# a linear layer to project original dim
x = layers.Dense(dim)(x)
x = layers.Dropout(projection_dropout)(x)
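
The `layers.Lambda` wrapper above packages the double-normalization step (softmax across the patch axis, then an L1-style rescale over the last axis, following the EANet paper) as a layer in the functional graph. A standalone sketch of just that step, with illustrative shapes:

```python
import keras
from keras import layers, ops

# (batch, heads, patches, memory_units) -- illustrative sizes only.
scores = keras.random.uniform((2, 4, 256, 16))
attn = layers.Softmax(axis=2)(scores)  # first normalization: across patches
# Second normalization: divide each row by its sum over the last axis,
# with a small epsilon to avoid division by zero.
attn = layers.Lambda(
    lambda a: ops.divide(
        a, ops.convert_to_tensor(1e-9) + ops.sum(a, axis=-1, keepdims=True)
    )
)(attn)
```
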
@@ -171,7 +163,7 @@ def external_attention(
def mlp(x, embedding_dim, mlp_dim, drop_rate=0.2):
- x = layers.Dense(mlp_dim, activation=tf.nn.gelu)(x)
+ x = layers.Dense(mlp_dim, activation=ops.gelu)(x)
x = layers.Dropout(drop_rate)(x)
x = layers.Dense(embedding_dim)(x)
x = layers.Dropout(drop_rate)(x)
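
One subtlety in the `tf.nn.gelu` to `ops.gelu` swap, worth double-checking: `keras.ops.gelu` defaults to the tanh approximation (`approximate=True`), while `tf.nn.gelu` defaults to the exact form, so activations can differ slightly. If exact GELU is wanted, a wrapper along these lines (names ours, not the patch's) would preserve the old behavior:

```python
from keras import layers, ops

def exact_gelu(x):
    # Match tf.nn.gelu's default (approximate=False).
    return ops.gelu(x, approximate=False)

dense = layers.Dense(64, activation=exact_gelu)  # illustrative width
```
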
@@ -206,7 +198,9 @@ def transformer_encoder(
)
elif attention_type == "self_attention":
x = layers.MultiHeadAttention(
- num_heads=num_heads, key_dim=embedding_dim, dropout=attention_dropout
+ num_heads=num_heads,
+ key_dim=embedding_dim,
+ dropout=attention_dropout,
)(x, x)
x = layers.add([x, residual_1])
residual_2 = x
@@ -256,7 +250,7 @@ def get_model(attention_type="external_attention"):
attention_type,
)
- x = layers.GlobalAvgPool1D()(x)
+ x = layers.GlobalAveragePooling1D()(x)
outputs = layers.Dense(num_classes, activation="softmax")(x)
model = keras.Model(inputs=inputs, outputs=outputs)
return model
@@ -272,7 +266,7 @@ def get_model(attention_type="external_attention"):
model.compile(
loss=keras.losses.CategoricalCrossentropy(label_smoothing=label_smoothing),
- optimizer=tfa.optimizers.AdamW(
+ optimizer=keras.optimizers.AdamW(
learning_rate=learning_rate, weight_decay=weight_decay
),
metrics=[
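
The optimizer swap is what actually removes the TensorFlow Addons dependency: `AdamW` ships in `keras.optimizers` in Keras 3 with the same decoupled weight decay. A minimal sketch of the replacement call (`learning_rate = 0.001` is the example's value; the weight decay here is illustrative):

```python
import keras

optimizer = keras.optimizers.AdamW(
    learning_rate=0.001,
    weight_decay=0.0001,  # illustrative; use the example's weight_decay
)
```
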
diff --git a/examples/vision/img/eanet/eanet_24_0.png b/examples/vision/img/eanet/eanet_24_0.png
new file mode 100644
index 00000000000..2331f599ce1
Binary files /dev/null and b/examples/vision/img/eanet/eanet_24_0.png differ
diff --git a/examples/vision/ipynb/eanet.ipynb b/examples/vision/ipynb/eanet.ipynb
index a46cc01b1f4..babaa30b71f 100644
--- a/examples/vision/ipynb/eanet.ipynb
+++ b/examples/vision/ipynb/eanet.ipynb
@@ -10,7 +10,7 @@
"\n",
"**Author:** [ZhiYong Chang](https://github.com/czy00000)
\n",
"**Date created:** 2021/10/19
\n",
- "**Last modified:** 2021/10/19
\n",
+ "**Last modified:** 2023/07/18
\n",
"**Description:** Image classification with a Transformer that leverages external attention."
]
},
@@ -21,6 +21,7 @@
},
"source": [
"## Introduction\n",
+ "\n",
"This example implements the [EANet](https://arxiv.org/abs/2105.02358)\n",
"model for image classification, and demonstrates it on the CIFAR-100 dataset.\n",
"EANet introduces a novel attention mechanism\n",
@@ -28,13 +29,7 @@
"shared memories, which can be implemented easily by simply using two cascaded\n",
"linear layers and two normalization layers. It conveniently replaces self-attention\n",
"as used in existing architectures. External attention has linear complexity, as it only\n",
- "implicitly considers the correlations between all samples.\n",
- "This example requires TensorFlow 2.5 or higher, as well as\n",
- "[TensorFlow Addons](https://www.tensorflow.org/addons/overview) package,\n",
- "which can be installed using the following command:\n",
- "```python\n",
- "pip install -U tensorflow-addons\n",
- "```"
+ "implicitly considers the correlations between all samples."
]
},
{
@@ -54,11 +49,10 @@
},
"outputs": [],
"source": [
- "import numpy as np\n",
- "import tensorflow as tf\n",
- "from tensorflow import keras\n",
- "from tensorflow.keras import layers\n",
- "import tensorflow_addons as tfa\n",
+ "import keras\n",
+ "from keras import layers\n",
+ "from keras import ops\n",
+ "\n",
"import matplotlib.pyplot as plt\n",
""
]
@@ -112,7 +106,7 @@
"label_smoothing = 0.1\n",
"validation_split = 0.2\n",
"batch_size = 128\n",
- "num_epochs = 50\n",
+ "num_epochs = 1 # Recommended num_epochs = 1.\n",
"patch_size = 2 # Size of the patches to be extracted from the input images.\n",
"num_patches = (input_shape[0] // patch_size) ** 2 # Number of patch\n",
"embedding_dim = 64 # Number of hidden units.\n",
@@ -182,18 +176,11 @@
" super().__init__(**kwargs)\n",
" self.patch_size = patch_size\n",
"\n",
- " def call(self, images):\n",
- " batch_size = tf.shape(images)[0]\n",
- " patches = tf.image.extract_patches(\n",
- " images=images,\n",
- " sizes=(1, self.patch_size, self.patch_size, 1),\n",
- " strides=(1, self.patch_size, self.patch_size, 1),\n",
- " rates=(1, 1, 1, 1),\n",
- " padding=\"VALID\",\n",
- " )\n",
- " patch_dim = patches.shape[-1]\n",
- " patch_num = patches.shape[1]\n",
- " return tf.reshape(patches, (batch_size, patch_num * patch_num, patch_dim))\n",
+ " def call(self, x):\n",
+ " B, C = ops.shape(x)[0], ops.shape(x)[-1]\n",
+ " x = ops.image.extract_patches(x, self.patch_size)\n",
+ " x = ops.reshape(x, (B, -1, self.patch_size * self.patch_size * C))\n",
+ " return x\n",
"\n",
"\n",
"class PatchEmbedding(layers.Layer):\n",
@@ -204,7 +191,7 @@
" self.pos_embed = layers.Embedding(input_dim=num_patch, output_dim=embed_dim)\n",
"\n",
" def call(self, patch):\n",
- " pos = tf.range(start=0, limit=self.num_patch, delta=1)\n",
+ " pos = ops.arange(start=0, stop=self.num_patch, step=1)\n",
" return self.proj(patch) + self.pos_embed(pos)\n",
""
]
@@ -228,7 +215,12 @@
"source": [
"\n",
"def external_attention(\n",
- " x, dim, num_heads, dim_coefficient=4, attention_dropout=0, projection_dropout=0\n",
+ " x,\n",
+ " dim,\n",
+ " num_heads,\n",
+ " dim_coefficient=4,\n",
+ " attention_dropout=0,\n",
+ " projection_dropout=0,\n",
"):\n",
" _, num_patch, channel = x.shape\n",
" assert dim % num_heads == 0\n",
@@ -236,21 +228,24 @@
"\n",
" x = layers.Dense(dim * dim_coefficient)(x)\n",
" # create tensor [batch_size, num_patches, num_heads, dim*dim_coefficient//num_heads]\n",
- " x = tf.reshape(\n",
- " x, shape=(-1, num_patch, num_heads, dim * dim_coefficient // num_heads)\n",
- " )\n",
- " x = tf.transpose(x, perm=[0, 2, 1, 3])\n",
+ " x = ops.reshape(x, (-1, num_patch, num_heads, dim * dim_coefficient // num_heads))\n",
+ " x = ops.transpose(x, axes=[0, 2, 1, 3])\n",
" # a linear layer M_k\n",
" attn = layers.Dense(dim // dim_coefficient)(x)\n",
" # normalize attention map\n",
" attn = layers.Softmax(axis=2)(attn)\n",
" # dobule-normalization\n",
- " attn = attn / (1e-9 + tf.reduce_sum(attn, axis=-1, keepdims=True))\n",
+ " attn = layers.Lambda(\n",
+ " lambda attn: ops.divide(\n",
+ " attn,\n",
+ " ops.convert_to_tensor(1e-9) + ops.sum(attn, axis=-1, keepdims=True),\n",
+ " )\n",
+ " )(attn)\n",
" attn = layers.Dropout(attention_dropout)(attn)\n",
" # a linear layer M_v\n",
" x = layers.Dense(dim * dim_coefficient // num_heads)(attn)\n",
- " x = tf.transpose(x, perm=[0, 2, 1, 3])\n",
- " x = tf.reshape(x, [-1, num_patch, dim * dim_coefficient])\n",
+ " x = ops.transpose(x, axes=[0, 2, 1, 3])\n",
+ " x = ops.reshape(x, [-1, num_patch, dim * dim_coefficient])\n",
" # a linear layer to project original dim\n",
" x = layers.Dense(dim)(x)\n",
" x = layers.Dropout(projection_dropout)(x)\n",
@@ -277,7 +272,7 @@
"source": [
"\n",
"def mlp(x, embedding_dim, mlp_dim, drop_rate=0.2):\n",
- " x = layers.Dense(mlp_dim, activation=tf.nn.gelu)(x)\n",
+ " x = layers.Dense(mlp_dim, activation=ops.gelu)(x)\n",
" x = layers.Dropout(drop_rate)(x)\n",
" x = layers.Dense(embedding_dim)(x)\n",
" x = layers.Dropout(drop_rate)(x)\n",
@@ -326,7 +321,9 @@
" )\n",
" elif attention_type == \"self_attention\":\n",
" x = layers.MultiHeadAttention(\n",
- " num_heads=num_heads, key_dim=embedding_dim, dropout=attention_dropout\n",
+ " num_heads=num_heads,\n",
+ " key_dim=embedding_dim,\n",
+ " dropout=attention_dropout,\n",
" )(x, x)\n",
" x = layers.add([x, residual_1])\n",
" residual_2 = x\n",
@@ -395,7 +392,7 @@
" attention_type,\n",
" )\n",
"\n",
- " x = layers.GlobalAvgPool1D()(x)\n",
+ " x = layers.GlobalAveragePooling1D()(x)\n",
" outputs = layers.Dense(num_classes, activation=\"softmax\")(x)\n",
" model = keras.Model(inputs=inputs, outputs=outputs)\n",
" return model\n",
@@ -424,7 +421,7 @@
"\n",
"model.compile(\n",
" loss=keras.losses.CategoricalCrossentropy(label_smoothing=label_smoothing),\n",
- " optimizer=tfa.optimizers.AdamW(\n",
+ " optimizer=keras.optimizers.AdamW(\n",
" learning_rate=learning_rate, weight_decay=weight_decay\n",
" ),\n",
" metrics=[\n",
@@ -504,6 +501,7 @@
"and the same hyperparameters, The EANet model we just trained has just 0.3M parameters,\n",
"and it gets us to ~73% test top-5 accuracy and ~43% top-1 accuracy. This fully demonstrates the\n",
"effectiveness of external attention.\n",
+ "\n",
"We only show the training\n",
"process of EANet, you can train Vit under the same experimental conditions and observe\n",
"the test results."
@@ -514,7 +512,7 @@
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
- "name": "EANet",
+ "name": "eanet",
"private_outputs": false,
"provenance": [],
"toc_visible": true
diff --git a/examples/vision/md/eanet.md b/examples/vision/md/eanet.md
index d13ce96a5f3..fb2bc09bf4a 100644
--- a/examples/vision/md/eanet.md
+++ b/examples/vision/md/eanet.md
@@ -2,16 +2,17 @@
**Author:** [ZhiYong Chang](https://github.com/czy00000)
**Date created:** 2021/10/19
-**Last modified:** 2021/10/19
+**Last modified:** 2023/07/18
**Description:** Image classification with a Transformer that leverages external attention.
- [**View in Colab**](https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/vision/ipynb/EANet.ipynb) • [**GitHub source**](https://github.com/keras-team/keras-io/blob/master/examples/vision/eanet.py)
+ [**View in Colab**](https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/vision/ipynb/eanet.ipynb) • [**GitHub source**](https://github.com/keras-team/keras-io/blob/master/examples/vision/eanet.py)
---
## Introduction
+
This example implements the [EANet](https://arxiv.org/abs/2105.02358)
model for image classification, and demonstrates it on the CIFAR-100 dataset.
EANet introduces a novel attention mechanism
@@ -20,23 +21,16 @@ shared memories, which can be implemented easily by simply using two cascaded
linear layers and two normalization layers. It conveniently replaces self-attention
as used in existing architectures. External attention has linear complexity, as it only
implicitly considers the correlations between all samples.
-This example requires TensorFlow 2.5 or higher, as well as
-[TensorFlow Addons](https://www.tensorflow.org/addons/overview) package,
-which can be installed using the following command:
-```python
-pip install -U tensorflow-addons
-```
---
## Setup
```python
-import numpy as np
-import tensorflow as tf
-from tensorflow import keras
-from tensorflow.keras import layers
-import tensorflow_addons as tfa
+import keras
+from keras import layers
+from keras import ops
+
import matplotlib.pyplot as plt
```
@@ -73,7 +67,7 @@ learning_rate = 0.001
label_smoothing = 0.1
validation_split = 0.2
batch_size = 128
-num_epochs = 50
+num_epochs = 1 # Reduced for a quick demo run; recommended num_epochs = 50.
patch_size = 2 # Size of the patches to be extracted from the input images.
num_patches = (input_shape[0] // patch_size) ** 2 # Number of patches
embedding_dim = 64 # Number of hidden units.
@@ -126,18 +120,11 @@ class PatchExtract(layers.Layer):
super().__init__(**kwargs)
self.patch_size = patch_size
- def call(self, images):
- batch_size = tf.shape(images)[0]
- patches = tf.image.extract_patches(
- images=images,
- sizes=(1, self.patch_size, self.patch_size, 1),
- strides=(1, self.patch_size, self.patch_size, 1),
- rates=(1, 1, 1, 1),
- padding="VALID",
- )
- patch_dim = patches.shape[-1]
- patch_num = patches.shape[1]
- return tf.reshape(patches, (batch_size, patch_num * patch_num, patch_dim))
+ def call(self, x):
+ B, C = ops.shape(x)[0], ops.shape(x)[-1]
+ x = ops.image.extract_patches(x, self.patch_size)
+ x = ops.reshape(x, (B, -1, self.patch_size * self.patch_size * C))
+ return x
class PatchEmbedding(layers.Layer):
@@ -148,7 +135,7 @@ class PatchEmbedding(layers.Layer):
self.pos_embed = layers.Embedding(input_dim=num_patch, output_dim=embed_dim)
def call(self, patch):
- pos = tf.range(start=0, limit=self.num_patch, delta=1)
+ pos = ops.arange(start=0, stop=self.num_patch, step=1)
return self.proj(patch) + self.pos_embed(pos)
```
@@ -160,7 +147,12 @@ class PatchEmbedding(layers.Layer):
```python
def external_attention(
- x, dim, num_heads, dim_coefficient=4, attention_dropout=0, projection_dropout=0
+ x,
+ dim,
+ num_heads,
+ dim_coefficient=4,
+ attention_dropout=0,
+ projection_dropout=0,
):
_, num_patch, channel = x.shape
assert dim % num_heads == 0
@@ -168,21 +160,24 @@ def external_attention(
x = layers.Dense(dim * dim_coefficient)(x)
# create tensor [batch_size, num_patches, num_heads, dim*dim_coefficient//num_heads]
- x = tf.reshape(
- x, shape=(-1, num_patch, num_heads, dim * dim_coefficient // num_heads)
- )
- x = tf.transpose(x, perm=[0, 2, 1, 3])
+ x = ops.reshape(x, (-1, num_patch, num_heads, dim * dim_coefficient // num_heads))
+ x = ops.transpose(x, axes=[0, 2, 1, 3])
# a linear layer M_k
attn = layers.Dense(dim // dim_coefficient)(x)
# normalize attention map
attn = layers.Softmax(axis=2)(attn)
    # double-normalization
- attn = attn / (1e-9 + tf.reduce_sum(attn, axis=-1, keepdims=True))
+ attn = layers.Lambda(
+ lambda attn: ops.divide(
+ attn,
+ ops.convert_to_tensor(1e-9) + ops.sum(attn, axis=-1, keepdims=True),
+ )
+ )(attn)
attn = layers.Dropout(attention_dropout)(attn)
# a linear layer M_v
x = layers.Dense(dim * dim_coefficient // num_heads)(attn)
- x = tf.transpose(x, perm=[0, 2, 1, 3])
- x = tf.reshape(x, [-1, num_patch, dim * dim_coefficient])
+ x = ops.transpose(x, axes=[0, 2, 1, 3])
+ x = ops.reshape(x, [-1, num_patch, dim * dim_coefficient])
# a linear layer to project original dim
x = layers.Dense(dim)(x)
x = layers.Dropout(projection_dropout)(x)
@@ -197,7 +192,7 @@ def external_attention(
```python
def mlp(x, embedding_dim, mlp_dim, drop_rate=0.2):
- x = layers.Dense(mlp_dim, activation=tf.nn.gelu)(x)
+ x = layers.Dense(mlp_dim, activation=ops.gelu)(x)
x = layers.Dropout(drop_rate)(x)
x = layers.Dense(embedding_dim)(x)
x = layers.Dropout(drop_rate)(x)
@@ -234,7 +229,9 @@ def transformer_encoder(
)
elif attention_type == "self_attention":
x = layers.MultiHeadAttention(
- num_heads=num_heads, key_dim=embedding_dim, dropout=attention_dropout
+ num_heads=num_heads,
+ key_dim=embedding_dim,
+ dropout=attention_dropout,
)(x, x)
x = layers.add([x, residual_1])
residual_2 = x
@@ -284,7 +281,7 @@ def get_model(attention_type="external_attention"):
attention_type,
)
- x = layers.GlobalAvgPool1D()(x)
+ x = layers.GlobalAveragePooling1D()(x)
outputs = layers.Dense(num_classes, activation="softmax")(x)
model = keras.Model(inputs=inputs, outputs=outputs)
return model
@@ -301,7 +298,7 @@ model = get_model(attention_type="external_attention")
model.compile(
loss=keras.losses.CategoricalCrossentropy(label_smoothing=label_smoothing),
- optimizer=tfa.optimizers.AdamW(
+ optimizer=keras.optimizers.AdamW(
learning_rate=learning_rate, weight_decay=weight_decay
),
metrics=[
@@ -321,106 +318,7 @@ history = model.fit(