From 16713378217b4edbaa9b1812fd3c5b11c163d3ab Mon Sep 17 00:00:00 2001 From: nkovela1 Date: Wed, 8 Nov 2023 21:23:53 +0000 Subject: [PATCH] Update EANet to Keras 3 --- examples/vision/eanet.py | 74 +++++---- examples/vision/img/eanet/eanet_24_0.png | Bin 0 -> 23897 bytes examples/vision/ipynb/eanet.ipynb | 78 +++++----- examples/vision/md/eanet.md | 187 ++++++----------------- 4 files changed, 115 insertions(+), 224 deletions(-) create mode 100644 examples/vision/img/eanet/eanet_24_0.png diff --git a/examples/vision/eanet.py b/examples/vision/eanet.py index 5eda90f8a73..c17e4d30d16 100644 --- a/examples/vision/eanet.py +++ b/examples/vision/eanet.py @@ -2,7 +2,7 @@ Title: Image classification with EANet (External Attention Transformer) Author: [ZhiYong Chang](https://github.com/czy00000) Date created: 2021/10/19 -Last modified: 2021/10/19 +Last modified: 2023/07/18 Description: Image classification with a Transformer that leverages external attention. Accelerator: GPU """ @@ -18,25 +18,16 @@ linear layers and two normalization layers. It conveniently replaces self-attention as used in existing architectures. External attention has linear complexity, as it only implicitly considers the correlations between all samples. - -This example requires TensorFlow 2.5 or higher, as well as -[TensorFlow Addons](https://www.tensorflow.org/addons/overview) package, -which can be installed using the following command: - -```python -pip install -U tensorflow-addons -``` """ """ ## Setup """ -import numpy as np -import tensorflow as tf -from tensorflow import keras -from tensorflow.keras import layers -import tensorflow_addons as tfa +import keras +from keras import layers +from keras import ops + import matplotlib.pyplot as plt @@ -62,7 +53,7 @@ label_smoothing = 0.1 validation_split = 0.2 batch_size = 128 -num_epochs = 50 +num_epochs = 1 # Recommended num_epochs = 1. patch_size = 2 # Size of the patches to be extracted from the input images. num_patches = (input_shape[0] // patch_size) ** 2 # Number of patch embedding_dim = 64 # Number of hidden units. 
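[Editorial note, not part of the patch] The hunks above swap the `tensorflow`/`tensorflow_addons` imports for plain `keras` and `keras.ops`, and the hunks below rewrite the patch-extraction code against that namespace. As a hedged illustration of the mapping, here is a minimal sketch using a dummy batch (the shapes and the `keras.random` input are assumptions made only for this note):

```python
# Illustrative sketch only -- not part of the patch. Dummy shapes are assumed.
import keras
from keras import ops

images = keras.random.normal((2, 32, 32, 3))  # stand-in CIFAR-sized batch
patch_size = 2

# replaces tf.image.extract_patches(images, sizes=..., strides=..., ...)
patches = ops.image.extract_patches(images, patch_size)
# replaces tf.reshape(...): flatten the spatial grid of patches
patches = ops.reshape(
    patches, (ops.shape(images)[0], -1, patch_size * patch_size * 3)
)
# replaces tf.range(...): positional indices for the patch embedding
positions = ops.arange(start=0, stop=ops.shape(patches)[1], step=1)
print(patches.shape, positions.shape)  # e.g. (2, 256, 12) (256,)
```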
@@ -104,18 +95,11 @@ def __init__(self, patch_size, **kwargs): super().__init__(**kwargs) self.patch_size = patch_size - def call(self, images): - batch_size = tf.shape(images)[0] - patches = tf.image.extract_patches( - images=images, - sizes=(1, self.patch_size, self.patch_size, 1), - strides=(1, self.patch_size, self.patch_size, 1), - rates=(1, 1, 1, 1), - padding="VALID", - ) - patch_dim = patches.shape[-1] - patch_num = patches.shape[1] - return tf.reshape(patches, (batch_size, patch_num * patch_num, patch_dim)) + def call(self, x): + B, C = ops.shape(x)[0], ops.shape(x)[-1] + x = ops.image.extract_patches(x, self.patch_size) + x = ops.reshape(x, (B, -1, self.patch_size * self.patch_size * C)) + return x class PatchEmbedding(layers.Layer): @@ -126,7 +110,7 @@ def __init__(self, num_patch, embed_dim, **kwargs): self.pos_embed = layers.Embedding(input_dim=num_patch, output_dim=embed_dim) def call(self, patch): - pos = tf.range(start=0, limit=self.num_patch, delta=1) + pos = ops.arange(start=0, stop=self.num_patch, step=1) return self.proj(patch) + self.pos_embed(pos) @@ -136,7 +120,12 @@ def call(self, patch): def external_attention( - x, dim, num_heads, dim_coefficient=4, attention_dropout=0, projection_dropout=0 + x, + dim, + num_heads, + dim_coefficient=4, + attention_dropout=0, + projection_dropout=0, ): _, num_patch, channel = x.shape assert dim % num_heads == 0 @@ -144,21 +133,24 @@ def external_attention( x = layers.Dense(dim * dim_coefficient)(x) # create tensor [batch_size, num_patches, num_heads, dim*dim_coefficient//num_heads] - x = tf.reshape( - x, shape=(-1, num_patch, num_heads, dim * dim_coefficient // num_heads) - ) - x = tf.transpose(x, perm=[0, 2, 1, 3]) + x = ops.reshape(x, (-1, num_patch, num_heads, dim * dim_coefficient // num_heads)) + x = ops.transpose(x, axes=[0, 2, 1, 3]) # a linear layer M_k attn = layers.Dense(dim // dim_coefficient)(x) # normalize attention map attn = layers.Softmax(axis=2)(attn) # dobule-normalization - attn = attn / (1e-9 + tf.reduce_sum(attn, axis=-1, keepdims=True)) + attn = layers.Lambda( + lambda attn: ops.divide( + attn, + ops.convert_to_tensor(1e-9) + ops.sum(attn, axis=-1, keepdims=True), + ) + )(attn) attn = layers.Dropout(attention_dropout)(attn) # a linear layer M_v x = layers.Dense(dim * dim_coefficient // num_heads)(attn) - x = tf.transpose(x, perm=[0, 2, 1, 3]) - x = tf.reshape(x, [-1, num_patch, dim * dim_coefficient]) + x = ops.transpose(x, axes=[0, 2, 1, 3]) + x = ops.reshape(x, [-1, num_patch, dim * dim_coefficient]) # a linear layer to project original dim x = layers.Dense(dim)(x) x = layers.Dropout(projection_dropout)(x) @@ -171,7 +163,7 @@ def external_attention( def mlp(x, embedding_dim, mlp_dim, drop_rate=0.2): - x = layers.Dense(mlp_dim, activation=tf.nn.gelu)(x) + x = layers.Dense(mlp_dim, activation=ops.gelu)(x) x = layers.Dropout(drop_rate)(x) x = layers.Dense(embedding_dim)(x) x = layers.Dropout(drop_rate)(x) @@ -206,7 +198,9 @@ def transformer_encoder( ) elif attention_type == "self_attention": x = layers.MultiHeadAttention( - num_heads=num_heads, key_dim=embedding_dim, dropout=attention_dropout + num_heads=num_heads, + key_dim=embedding_dim, + dropout=attention_dropout, )(x, x) x = layers.add([x, residual_1]) residual_2 = x @@ -256,7 +250,7 @@ def get_model(attention_type="external_attention"): attention_type, ) - x = layers.GlobalAvgPool1D()(x) + x = layers.GlobalAveragePooling1D()(x) outputs = layers.Dense(num_classes, activation="softmax")(x) model = keras.Model(inputs=inputs, outputs=outputs) return model @@ 
-272,7 +266,7 @@ def get_model(attention_type="external_attention"): model.compile( loss=keras.losses.CategoricalCrossentropy(label_smoothing=label_smoothing), - optimizer=tfa.optimizers.AdamW( + optimizer=keras.optimizers.AdamW( learning_rate=learning_rate, weight_decay=weight_decay ), metrics=[ diff --git a/examples/vision/img/eanet/eanet_24_0.png b/examples/vision/img/eanet/eanet_24_0.png new file mode 100644 index 0000000000000000000000000000000000000000..2331f599ce1d1e236e952a4ab0f2c97d6c6717f6 GIT binary patch literal 23897 [... binary image data for eanet_24_0.png omitted ...] literal 0 HcmV?d00001
z;?q;bhfDD^!Z4P>Nk%mhU)BIyiSZk0se3($cz3x+o#t zEW^YZ4&2Q$GDK>8;M>mnv)a44zIK63$1;eKSsT`^B@2N+)}zmNEhpBWBrZU8$^+5?E0 zv|s1q#ix`2G<(}lGKLHzS~Pk0I1Gj^RtfpqRe{FwBfvU&ZGNdG>RjlLns(a9+q&H2bhAA`bV-T1(=ku<0}8*`=V>YKg{Ef`#o9z;?mgnOWDqPWgH z!FOu(;j1@!}v&*0`}mnTi+qyu(0 zmeAL0!n)FKSO#4&l3 zwYlyflA`h1vq$@S_Pfk(#5UFFnTein#zYQ)H%S{@z5dC5JyhwSUS4HY2?r)5ncL3-2CaAo(y*RnYrl!SV(oz;q&B|M5G_; z+FRf^o{l+UgcV3Ou~E-@GW5(O2{d2Xp$~K51+(5}JbmhL(|+N&E~6l9%^Tn~6F#18 zu@S_QaNMAIe6+I*wDz!pGWZ+05PeVLO}AlSVaPKd&~$5^G?NZGmI!2m&O5oXu{`+2d;W5*vSBlx7|VO+$pA-&5Jt=_A5 z@7CRgDmTm)f3vb6vXc8cXY$Ij%2Vqj9j7?0|qtRUP(4ipvNkjt7AVhoQxr(3Pio`(HZ-BX- zj4uQnzK45u@*Q~LuC?2ROrgSr8_W)%TXP}QELk^0sW#QnkPP^PC(f75=eY}UQZoP} zEluX!=7#yn+!!*#j;uUsx(Vw*x;{BKGZ_ul1tJV1oK`S(oeVgb@Im-#m12aQOK^qk zd{v^~oZ10#fklE#+uYW>5)^JG zHveO^Z&Kr$EnBKynn8bRpu@p&Tx9E52Pb=;a8IE50g`Is;HxfefvAnI^&S(#!N zMln#4gM$rd`mQN*b3;msLDHFk9mvkW4TlYZk^zYrM4`Nn&I+_-*>xgEEhN)R;ig8C z6;cmR%9ugSAw>q9{V>UIgA+?RuMa_=U_^93f#(Z^R>j3^tDv<*TmWwE9q=-efgVU% zd(nvn=;#IFlqv@QP8uy47#N5x9OiN26|$qP#_frt%pjg`G65{Zgbg}W;xHev$zq@g zsh==`fPx7p$EGW2Y_=*w7Sp3xIw)6j0}w zkNFVtj2Qj7XsJ+letAQ~v6)`Od5o4DLds*tt)MILiT)fW8!_hUiT<&h$rMItrq4=w zC?BMDvxu(FP9~}iBT+cz8OeYYNT|ofSE?x}l|e*~Z|cG57ot0D>+J51pPdIw-o2?8 zxcKJHn-qncB>i~+zgcflU-}_Gz@8(A(ao4kfM6C|Uky)a|NLiI9DuD5kSKSz*!SUB-*LGl5OaDLs9lOevo9u!(U1T9`cYKYb!HcmAMA)U zD0Tp@34x@py^r*`VLk()a&=Hs5+7OIaefYNhH<>^tk|PZerLwB7cQhob2#C>4h{|N zC9Zp*V=@M{G2d|;B>)11v|i!mkXDlaXuTqRQNYu*6jU@%ydz>gCGrAVrh@oogV1XK zF>;Eu7Nesk^32;igo{BFy+kC(k-BoUTGZ78j9~@^nz9k&^m43dmyp00bC7YLYOg(xiHO35L1XpZw54T zIMOj@g^3K)9cW&KOLL`lC}@4clyBpn)jmGL_o3G#u=yi{-G9yrF2G?iu`~SRxL4sw v^7xvtvDbWT`NK*5FXJozhv^mHdAcn@{G3bro_@iYb85fhA%zrqqpSZ1mL?5! literal 0 HcmV?d00001 diff --git a/examples/vision/ipynb/eanet.ipynb b/examples/vision/ipynb/eanet.ipynb index a46cc01b1f4..babaa30b71f 100644 --- a/examples/vision/ipynb/eanet.ipynb +++ b/examples/vision/ipynb/eanet.ipynb @@ -10,7 +10,7 @@ "\n", "**Author:** [ZhiYong Chang](https://github.com/czy00000)
\n", "**Date created:** 2021/10/19
\n", - "**Last modified:** 2021/10/19
\n", + "**Last modified:** 2023/07/18
\n", "**Description:** Image classification with a Transformer that leverages external attention." ] }, @@ -21,6 +21,7 @@ }, "source": [ "## Introduction\n", + "\n", "This example implements the [EANet](https://arxiv.org/abs/2105.02358)\n", "model for image classification, and demonstrates it on the CIFAR-100 dataset.\n", "EANet introduces a novel attention mechanism\n", @@ -28,13 +29,7 @@ "shared memories, which can be implemented easily by simply using two cascaded\n", "linear layers and two normalization layers. It conveniently replaces self-attention\n", "as used in existing architectures. External attention has linear complexity, as it only\n", - "implicitly considers the correlations between all samples.\n", - "This example requires TensorFlow 2.5 or higher, as well as\n", - "[TensorFlow Addons](https://www.tensorflow.org/addons/overview) package,\n", - "which can be installed using the following command:\n", - "```python\n", - "pip install -U tensorflow-addons\n", - "```" + "implicitly considers the correlations between all samples." ] }, { @@ -54,11 +49,10 @@ }, "outputs": [], "source": [ - "import numpy as np\n", - "import tensorflow as tf\n", - "from tensorflow import keras\n", - "from tensorflow.keras import layers\n", - "import tensorflow_addons as tfa\n", + "import keras\n", + "from keras import layers\n", + "from keras import ops\n", + "\n", "import matplotlib.pyplot as plt\n", "" ] @@ -112,7 +106,7 @@ "label_smoothing = 0.1\n", "validation_split = 0.2\n", "batch_size = 128\n", - "num_epochs = 50\n", + "num_epochs = 1 # Recommended num_epochs = 1.\n", "patch_size = 2 # Size of the patches to be extracted from the input images.\n", "num_patches = (input_shape[0] // patch_size) ** 2 # Number of patch\n", "embedding_dim = 64 # Number of hidden units.\n", @@ -182,18 +176,11 @@ " super().__init__(**kwargs)\n", " self.patch_size = patch_size\n", "\n", - " def call(self, images):\n", - " batch_size = tf.shape(images)[0]\n", - " patches = tf.image.extract_patches(\n", - " images=images,\n", - " sizes=(1, self.patch_size, self.patch_size, 1),\n", - " strides=(1, self.patch_size, self.patch_size, 1),\n", - " rates=(1, 1, 1, 1),\n", - " padding=\"VALID\",\n", - " )\n", - " patch_dim = patches.shape[-1]\n", - " patch_num = patches.shape[1]\n", - " return tf.reshape(patches, (batch_size, patch_num * patch_num, patch_dim))\n", + " def call(self, x):\n", + " B, C = ops.shape(x)[0], ops.shape(x)[-1]\n", + " x = ops.image.extract_patches(x, self.patch_size)\n", + " x = ops.reshape(x, (B, -1, self.patch_size * self.patch_size * C))\n", + " return x\n", "\n", "\n", "class PatchEmbedding(layers.Layer):\n", @@ -204,7 +191,7 @@ " self.pos_embed = layers.Embedding(input_dim=num_patch, output_dim=embed_dim)\n", "\n", " def call(self, patch):\n", - " pos = tf.range(start=0, limit=self.num_patch, delta=1)\n", + " pos = ops.arange(start=0, stop=self.num_patch, step=1)\n", " return self.proj(patch) + self.pos_embed(pos)\n", "" ] @@ -228,7 +215,12 @@ "source": [ "\n", "def external_attention(\n", - " x, dim, num_heads, dim_coefficient=4, attention_dropout=0, projection_dropout=0\n", + " x,\n", + " dim,\n", + " num_heads,\n", + " dim_coefficient=4,\n", + " attention_dropout=0,\n", + " projection_dropout=0,\n", "):\n", " _, num_patch, channel = x.shape\n", " assert dim % num_heads == 0\n", @@ -236,21 +228,24 @@ "\n", " x = layers.Dense(dim * dim_coefficient)(x)\n", " # create tensor [batch_size, num_patches, num_heads, dim*dim_coefficient//num_heads]\n", - " x = tf.reshape(\n", - " x, shape=(-1, num_patch, 
num_heads, dim * dim_coefficient // num_heads)\n", - " )\n", - " x = tf.transpose(x, perm=[0, 2, 1, 3])\n", + " x = ops.reshape(x, (-1, num_patch, num_heads, dim * dim_coefficient // num_heads))\n", + " x = ops.transpose(x, axes=[0, 2, 1, 3])\n", " # a linear layer M_k\n", " attn = layers.Dense(dim // dim_coefficient)(x)\n", " # normalize attention map\n", " attn = layers.Softmax(axis=2)(attn)\n", " # dobule-normalization\n", - " attn = attn / (1e-9 + tf.reduce_sum(attn, axis=-1, keepdims=True))\n", + " attn = layers.Lambda(\n", + " lambda attn: ops.divide(\n", + " attn,\n", + " ops.convert_to_tensor(1e-9) + ops.sum(attn, axis=-1, keepdims=True),\n", + " )\n", + " )(attn)\n", " attn = layers.Dropout(attention_dropout)(attn)\n", " # a linear layer M_v\n", " x = layers.Dense(dim * dim_coefficient // num_heads)(attn)\n", - " x = tf.transpose(x, perm=[0, 2, 1, 3])\n", - " x = tf.reshape(x, [-1, num_patch, dim * dim_coefficient])\n", + " x = ops.transpose(x, axes=[0, 2, 1, 3])\n", + " x = ops.reshape(x, [-1, num_patch, dim * dim_coefficient])\n", " # a linear layer to project original dim\n", " x = layers.Dense(dim)(x)\n", " x = layers.Dropout(projection_dropout)(x)\n", @@ -277,7 +272,7 @@ "source": [ "\n", "def mlp(x, embedding_dim, mlp_dim, drop_rate=0.2):\n", - " x = layers.Dense(mlp_dim, activation=tf.nn.gelu)(x)\n", + " x = layers.Dense(mlp_dim, activation=ops.gelu)(x)\n", " x = layers.Dropout(drop_rate)(x)\n", " x = layers.Dense(embedding_dim)(x)\n", " x = layers.Dropout(drop_rate)(x)\n", @@ -326,7 +321,9 @@ " )\n", " elif attention_type == \"self_attention\":\n", " x = layers.MultiHeadAttention(\n", - " num_heads=num_heads, key_dim=embedding_dim, dropout=attention_dropout\n", + " num_heads=num_heads,\n", + " key_dim=embedding_dim,\n", + " dropout=attention_dropout,\n", " )(x, x)\n", " x = layers.add([x, residual_1])\n", " residual_2 = x\n", @@ -395,7 +392,7 @@ " attention_type,\n", " )\n", "\n", - " x = layers.GlobalAvgPool1D()(x)\n", + " x = layers.GlobalAveragePooling1D()(x)\n", " outputs = layers.Dense(num_classes, activation=\"softmax\")(x)\n", " model = keras.Model(inputs=inputs, outputs=outputs)\n", " return model\n", @@ -424,7 +421,7 @@ "\n", "model.compile(\n", " loss=keras.losses.CategoricalCrossentropy(label_smoothing=label_smoothing),\n", - " optimizer=tfa.optimizers.AdamW(\n", + " optimizer=keras.optimizers.AdamW(\n", " learning_rate=learning_rate, weight_decay=weight_decay\n", " ),\n", " metrics=[\n", @@ -504,6 +501,7 @@ "and the same hyperparameters, The EANet model we just trained has just 0.3M parameters,\n", "and it gets us to ~73% test top-5 accuracy and ~43% top-1 accuracy. This fully demonstrates the\n", "effectiveness of external attention.\n", + "\n", "We only show the training\n", "process of EANet, you can train Vit under the same experimental conditions and observe\n", "the test results." @@ -514,7 +512,7 @@ "accelerator": "GPU", "colab": { "collapsed_sections": [], - "name": "EANet", + "name": "eanet", "private_outputs": false, "provenance": [], "toc_visible": true diff --git a/examples/vision/md/eanet.md b/examples/vision/md/eanet.md index d13ce96a5f3..fb2bc09bf4a 100644 --- a/examples/vision/md/eanet.md +++ b/examples/vision/md/eanet.md @@ -2,16 +2,17 @@ **Author:** [ZhiYong Chang](https://github.com/czy00000)
**Date created:** 2021/10/19
-**Last modified:** 2021/10/19
+**Last modified:** 2023/07/18
**Description:** Image classification with a Transformer that leverages external attention. - [**View in Colab**](https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/vision/ipynb/EANet.ipynb) [**GitHub source**](https://github.com/keras-team/keras-io/blob/master/examples/vision/eanet.py) + [**View in Colab**](https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/vision/ipynb/eanet.ipynb) [**GitHub source**](https://github.com/keras-team/keras-io/blob/master/examples/vision/eanet.py) --- ## Introduction + This example implements the [EANet](https://arxiv.org/abs/2105.02358) model for image classification, and demonstrates it on the CIFAR-100 dataset. EANet introduces a novel attention mechanism @@ -20,23 +21,16 @@ shared memories, which can be implemented easily by simply using two cascaded linear layers and two normalization layers. It conveniently replaces self-attention as used in existing architectures. External attention has linear complexity, as it only implicitly considers the correlations between all samples. -This example requires TensorFlow 2.5 or higher, as well as -[TensorFlow Addons](https://www.tensorflow.org/addons/overview) package, -which can be installed using the following command: -```python -pip install -U tensorflow-addons -``` --- ## Setup ```python -import numpy as np -import tensorflow as tf -from tensorflow import keras -from tensorflow.keras import layers -import tensorflow_addons as tfa +import keras +from keras import layers +from keras import ops + import matplotlib.pyplot as plt ``` @@ -73,7 +67,7 @@ learning_rate = 0.001 label_smoothing = 0.1 validation_split = 0.2 batch_size = 128 -num_epochs = 50 +num_epochs = 1 # Recommended num_epochs = 1. patch_size = 2 # Size of the patches to be extracted from the input images. num_patches = (input_shape[0] // patch_size) ** 2 # Number of patch embedding_dim = 64 # Number of hidden units. 
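[Editorial note, not part of the diff] The next hunk wraps the double-normalization arithmetic in `layers.Lambda` so the `keras.ops` expression stays a proper layer inside the functional graph. A hedged, standalone sketch of what that step computes (the random tensor and its `[batch, heads, patches, memory]` shape are assumptions for illustration):

```python
# Illustrative sketch only -- not part of the diff; shapes are assumed.
import keras
from keras import layers, ops

attn = keras.random.uniform((1, 4, 64, 16))  # [batch, heads, patches, memory]
attn = layers.Softmax(axis=2)(attn)  # first normalization, over the patch axis
# second normalization: divide each row by its sum, as done inside
# layers.Lambda in the hunk below
attn = ops.divide(attn, 1e-9 + ops.sum(attn, axis=-1, keepdims=True))
print(ops.sum(attn, axis=-1)[0, 0, :3])  # each row now sums to ~1
```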
@@ -126,18 +120,11 @@ class PatchExtract(layers.Layer): super().__init__(**kwargs) self.patch_size = patch_size - def call(self, images): - batch_size = tf.shape(images)[0] - patches = tf.image.extract_patches( - images=images, - sizes=(1, self.patch_size, self.patch_size, 1), - strides=(1, self.patch_size, self.patch_size, 1), - rates=(1, 1, 1, 1), - padding="VALID", - ) - patch_dim = patches.shape[-1] - patch_num = patches.shape[1] - return tf.reshape(patches, (batch_size, patch_num * patch_num, patch_dim)) + def call(self, x): + B, C = ops.shape(x)[0], ops.shape(x)[-1] + x = ops.image.extract_patches(x, self.patch_size) + x = ops.reshape(x, (B, -1, self.patch_size * self.patch_size * C)) + return x class PatchEmbedding(layers.Layer): @@ -148,7 +135,7 @@ class PatchEmbedding(layers.Layer): self.pos_embed = layers.Embedding(input_dim=num_patch, output_dim=embed_dim) def call(self, patch): - pos = tf.range(start=0, limit=self.num_patch, delta=1) + pos = ops.arange(start=0, stop=self.num_patch, step=1) return self.proj(patch) + self.pos_embed(pos) ``` @@ -160,7 +147,12 @@ class PatchEmbedding(layers.Layer): ```python def external_attention( - x, dim, num_heads, dim_coefficient=4, attention_dropout=0, projection_dropout=0 + x, + dim, + num_heads, + dim_coefficient=4, + attention_dropout=0, + projection_dropout=0, ): _, num_patch, channel = x.shape assert dim % num_heads == 0 @@ -168,21 +160,24 @@ def external_attention( x = layers.Dense(dim * dim_coefficient)(x) # create tensor [batch_size, num_patches, num_heads, dim*dim_coefficient//num_heads] - x = tf.reshape( - x, shape=(-1, num_patch, num_heads, dim * dim_coefficient // num_heads) - ) - x = tf.transpose(x, perm=[0, 2, 1, 3]) + x = ops.reshape(x, (-1, num_patch, num_heads, dim * dim_coefficient // num_heads)) + x = ops.transpose(x, axes=[0, 2, 1, 3]) # a linear layer M_k attn = layers.Dense(dim // dim_coefficient)(x) # normalize attention map attn = layers.Softmax(axis=2)(attn) # dobule-normalization - attn = attn / (1e-9 + tf.reduce_sum(attn, axis=-1, keepdims=True)) + attn = layers.Lambda( + lambda attn: ops.divide( + attn, + ops.convert_to_tensor(1e-9) + ops.sum(attn, axis=-1, keepdims=True), + ) + )(attn) attn = layers.Dropout(attention_dropout)(attn) # a linear layer M_v x = layers.Dense(dim * dim_coefficient // num_heads)(attn) - x = tf.transpose(x, perm=[0, 2, 1, 3]) - x = tf.reshape(x, [-1, num_patch, dim * dim_coefficient]) + x = ops.transpose(x, axes=[0, 2, 1, 3]) + x = ops.reshape(x, [-1, num_patch, dim * dim_coefficient]) # a linear layer to project original dim x = layers.Dense(dim)(x) x = layers.Dropout(projection_dropout)(x) @@ -197,7 +192,7 @@ def external_attention( ```python def mlp(x, embedding_dim, mlp_dim, drop_rate=0.2): - x = layers.Dense(mlp_dim, activation=tf.nn.gelu)(x) + x = layers.Dense(mlp_dim, activation=ops.gelu)(x) x = layers.Dropout(drop_rate)(x) x = layers.Dense(embedding_dim)(x) x = layers.Dropout(drop_rate)(x) @@ -234,7 +229,9 @@ def transformer_encoder( ) elif attention_type == "self_attention": x = layers.MultiHeadAttention( - num_heads=num_heads, key_dim=embedding_dim, dropout=attention_dropout + num_heads=num_heads, + key_dim=embedding_dim, + dropout=attention_dropout, )(x, x) x = layers.add([x, residual_1]) residual_2 = x @@ -284,7 +281,7 @@ def get_model(attention_type="external_attention"): attention_type, ) - x = layers.GlobalAvgPool1D()(x) + x = layers.GlobalAveragePooling1D()(x) outputs = layers.Dense(num_classes, activation="softmax")(x) model = keras.Model(inputs=inputs, outputs=outputs) 
return model @@ -301,7 +298,7 @@ model = get_model(attention_type="external_attention") model.compile( loss=keras.losses.CategoricalCrossentropy(label_smoothing=label_smoothing), - optimizer=tfa.optimizers.AdamW( + optimizer=keras.optimizers.AdamW( learning_rate=learning_rate, weight_decay=weight_decay ), metrics=[ @@ -321,106 +318,7 @@ history = model.fit(
``` -Epoch 1/50 -313/313 [==============================] - 40s 95ms/step - loss: 4.2091 - accuracy: 0.0723 - top-5-accuracy: 0.2384 - val_loss: 3.9706 - val_accuracy: 0.1153 - val_top-5-accuracy: 0.3336 -Epoch 2/50 -313/313 [==============================] - 29s 91ms/step - loss: 3.8028 - accuracy: 0.1427 - top-5-accuracy: 0.3871 - val_loss: 3.6672 - val_accuracy: 0.1829 - val_top-5-accuracy: 0.4513 -Epoch 3/50 -313/313 [==============================] - 29s 93ms/step - loss: 3.5493 - accuracy: 0.1978 - top-5-accuracy: 0.4805 - val_loss: 3.5402 - val_accuracy: 0.2141 - val_top-5-accuracy: 0.5038 -Epoch 4/50 -313/313 [==============================] - 29s 93ms/step - loss: 3.4029 - accuracy: 0.2355 - top-5-accuracy: 0.5328 - val_loss: 3.4496 - val_accuracy: 0.2354 - val_top-5-accuracy: 0.5316 -Epoch 5/50 -313/313 [==============================] - 29s 92ms/step - loss: 3.2917 - accuracy: 0.2636 - top-5-accuracy: 0.5678 - val_loss: 3.3342 - val_accuracy: 0.2699 - val_top-5-accuracy: 0.5679 -Epoch 6/50 -313/313 [==============================] - 29s 92ms/step - loss: 3.2116 - accuracy: 0.2830 - top-5-accuracy: 0.5921 - val_loss: 3.2896 - val_accuracy: 0.2749 - val_top-5-accuracy: 0.5874 -Epoch 7/50 -313/313 [==============================] - 28s 90ms/step - loss: 3.1453 - accuracy: 0.2980 - top-5-accuracy: 0.6100 - val_loss: 3.3090 - val_accuracy: 0.2857 - val_top-5-accuracy: 0.5831 -Epoch 8/50 -313/313 [==============================] - 29s 94ms/step - loss: 3.0889 - accuracy: 0.3121 - top-5-accuracy: 0.6266 - val_loss: 3.1969 - val_accuracy: 0.2975 - val_top-5-accuracy: 0.6082 -Epoch 9/50 -313/313 [==============================] - 29s 92ms/step - loss: 3.0390 - accuracy: 0.3252 - top-5-accuracy: 0.6441 - val_loss: 3.1249 - val_accuracy: 0.3175 - val_top-5-accuracy: 0.6330 -Epoch 10/50 -313/313 [==============================] - 29s 92ms/step - loss: 2.9871 - accuracy: 0.3365 - top-5-accuracy: 0.6615 - val_loss: 3.1121 - val_accuracy: 0.3200 - val_top-5-accuracy: 0.6374 -Epoch 11/50 -313/313 [==============================] - 29s 92ms/step - loss: 2.9476 - accuracy: 0.3489 - top-5-accuracy: 0.6697 - val_loss: 3.1156 - val_accuracy: 0.3268 - val_top-5-accuracy: 0.6421 -Epoch 12/50 -313/313 [==============================] - 29s 91ms/step - loss: 2.9106 - accuracy: 0.3576 - top-5-accuracy: 0.6783 - val_loss: 3.1337 - val_accuracy: 0.3226 - val_top-5-accuracy: 0.6389 -Epoch 13/50 -313/313 [==============================] - 29s 92ms/step - loss: 2.8772 - accuracy: 0.3662 - top-5-accuracy: 0.6871 - val_loss: 3.0373 - val_accuracy: 0.3348 - val_top-5-accuracy: 0.6624 -Epoch 14/50 -313/313 [==============================] - 29s 92ms/step - loss: 2.8508 - accuracy: 0.3756 - top-5-accuracy: 0.6944 - val_loss: 3.0297 - val_accuracy: 0.3441 - val_top-5-accuracy: 0.6643 -Epoch 15/50 -313/313 [==============================] - 28s 90ms/step - loss: 2.8211 - accuracy: 0.3821 - top-5-accuracy: 0.7034 - val_loss: 2.9680 - val_accuracy: 0.3604 - val_top-5-accuracy: 0.6847 -Epoch 16/50 -313/313 [==============================] - 28s 90ms/step - loss: 2.8017 - accuracy: 0.3864 - top-5-accuracy: 0.7090 - val_loss: 2.9746 - val_accuracy: 0.3584 - val_top-5-accuracy: 0.6855 -Epoch 17/50 -313/313 [==============================] - 29s 91ms/step - loss: 2.7714 - accuracy: 0.3962 - top-5-accuracy: 0.7169 - val_loss: 2.9104 - val_accuracy: 0.3738 - val_top-5-accuracy: 0.6940 -Epoch 18/50 -313/313 [==============================] - 29s 92ms/step - loss: 2.7523 - accuracy: 0.4008 - top-5-accuracy: 0.7204 - val_loss: 
2.8560 - val_accuracy: 0.3861 - val_top-5-accuracy: 0.7115 -Epoch 19/50 -313/313 [==============================] - 28s 91ms/step - loss: 2.7320 - accuracy: 0.4051 - top-5-accuracy: 0.7263 - val_loss: 2.8780 - val_accuracy: 0.3820 - val_top-5-accuracy: 0.7101 -Epoch 20/50 -313/313 [==============================] - 28s 90ms/step - loss: 2.7139 - accuracy: 0.4114 - top-5-accuracy: 0.7290 - val_loss: 2.9831 - val_accuracy: 0.3694 - val_top-5-accuracy: 0.6922 -Epoch 21/50 -313/313 [==============================] - 28s 91ms/step - loss: 2.6991 - accuracy: 0.4142 - top-5-accuracy: 0.7335 - val_loss: 2.8420 - val_accuracy: 0.3968 - val_top-5-accuracy: 0.7138 -Epoch 22/50 -313/313 [==============================] - 29s 91ms/step - loss: 2.6842 - accuracy: 0.4195 - top-5-accuracy: 0.7377 - val_loss: 2.7965 - val_accuracy: 0.4088 - val_top-5-accuracy: 0.7266 -Epoch 23/50 -313/313 [==============================] - 28s 91ms/step - loss: 2.6571 - accuracy: 0.4273 - top-5-accuracy: 0.7436 - val_loss: 2.8620 - val_accuracy: 0.3947 - val_top-5-accuracy: 0.7155 -Epoch 24/50 -313/313 [==============================] - 29s 91ms/step - loss: 2.6508 - accuracy: 0.4277 - top-5-accuracy: 0.7469 - val_loss: 2.8459 - val_accuracy: 0.3963 - val_top-5-accuracy: 0.7150 -Epoch 25/50 -313/313 [==============================] - 28s 90ms/step - loss: 2.6403 - accuracy: 0.4283 - top-5-accuracy: 0.7520 - val_loss: 2.7886 - val_accuracy: 0.4128 - val_top-5-accuracy: 0.7283 -Epoch 26/50 -313/313 [==============================] - 29s 92ms/step - loss: 2.6281 - accuracy: 0.4353 - top-5-accuracy: 0.7523 - val_loss: 2.8493 - val_accuracy: 0.4026 - val_top-5-accuracy: 0.7153 -Epoch 27/50 -313/313 [==============================] - 29s 92ms/step - loss: 2.6092 - accuracy: 0.4403 - top-5-accuracy: 0.7580 - val_loss: 2.7539 - val_accuracy: 0.4186 - val_top-5-accuracy: 0.7392 -Epoch 28/50 -313/313 [==============================] - 29s 91ms/step - loss: 2.5992 - accuracy: 0.4423 - top-5-accuracy: 0.7600 - val_loss: 2.8625 - val_accuracy: 0.3964 - val_top-5-accuracy: 0.7174 -Epoch 29/50 -313/313 [==============================] - 28s 90ms/step - loss: 2.5913 - accuracy: 0.4456 - top-5-accuracy: 0.7598 - val_loss: 2.7911 - val_accuracy: 0.4162 - val_top-5-accuracy: 0.7329 -Epoch 30/50 -313/313 [==============================] - 29s 92ms/step - loss: 2.5780 - accuracy: 0.4480 - top-5-accuracy: 0.7649 - val_loss: 2.8158 - val_accuracy: 0.4118 - val_top-5-accuracy: 0.7288 -Epoch 31/50 -313/313 [==============================] - 28s 91ms/step - loss: 2.5657 - accuracy: 0.4547 - top-5-accuracy: 0.7661 - val_loss: 2.8651 - val_accuracy: 0.4056 - val_top-5-accuracy: 0.7217 -Epoch 32/50 -313/313 [==============================] - 29s 91ms/step - loss: 2.5637 - accuracy: 0.4480 - top-5-accuracy: 0.7681 - val_loss: 2.8190 - val_accuracy: 0.4094 - val_top-5-accuracy: 0.7267 -Epoch 33/50 -313/313 [==============================] - 29s 92ms/step - loss: 2.5525 - accuracy: 0.4545 - top-5-accuracy: 0.7693 - val_loss: 2.7985 - val_accuracy: 0.4216 - val_top-5-accuracy: 0.7303 -Epoch 34/50 -313/313 [==============================] - 28s 91ms/step - loss: 2.5462 - accuracy: 0.4579 - top-5-accuracy: 0.7721 - val_loss: 2.8865 - val_accuracy: 0.4016 - val_top-5-accuracy: 0.7204 -Epoch 35/50 -313/313 [==============================] - 29s 92ms/step - loss: 2.5329 - accuracy: 0.4616 - top-5-accuracy: 0.7740 - val_loss: 2.7862 - val_accuracy: 0.4232 - val_top-5-accuracy: 0.7389 -Epoch 36/50 -313/313 [==============================] - 28s 90ms/step - loss: 
2.5234 - accuracy: 0.4610 - top-5-accuracy: 0.7765 - val_loss: 2.8234 - val_accuracy: 0.4134 - val_top-5-accuracy: 0.7312 -Epoch 37/50 -313/313 [==============================] - 29s 91ms/step - loss: 2.5152 - accuracy: 0.4663 - top-5-accuracy: 0.7774 - val_loss: 2.7894 - val_accuracy: 0.4161 - val_top-5-accuracy: 0.7376 -Epoch 38/50 -313/313 [==============================] - 29s 92ms/step - loss: 2.5117 - accuracy: 0.4674 - top-5-accuracy: 0.7790 - val_loss: 2.8091 - val_accuracy: 0.4142 - val_top-5-accuracy: 0.7360 -Epoch 39/50 -313/313 [==============================] - 28s 90ms/step - loss: 2.5047 - accuracy: 0.4681 - top-5-accuracy: 0.7805 - val_loss: 2.8199 - val_accuracy: 0.4167 - val_top-5-accuracy: 0.7299 -Epoch 40/50 -313/313 [==============================] - 28s 90ms/step - loss: 2.4974 - accuracy: 0.4697 - top-5-accuracy: 0.7819 - val_loss: 2.7864 - val_accuracy: 0.4247 - val_top-5-accuracy: 0.7402 -Epoch 41/50 -313/313 [==============================] - 28s 90ms/step - loss: 2.4889 - accuracy: 0.4749 - top-5-accuracy: 0.7854 - val_loss: 2.8120 - val_accuracy: 0.4217 - val_top-5-accuracy: 0.7358 -Epoch 42/50 -313/313 [==============================] - 28s 90ms/step - loss: 2.4799 - accuracy: 0.4771 - top-5-accuracy: 0.7866 - val_loss: 2.9003 - val_accuracy: 0.4038 - val_top-5-accuracy: 0.7170 -Epoch 43/50 -313/313 [==============================] - 28s 90ms/step - loss: 2.4814 - accuracy: 0.4770 - top-5-accuracy: 0.7868 - val_loss: 2.7504 - val_accuracy: 0.4260 - val_top-5-accuracy: 0.7457 -Epoch 44/50 -313/313 [==============================] - 28s 91ms/step - loss: 2.4747 - accuracy: 0.4757 - top-5-accuracy: 0.7870 - val_loss: 2.8207 - val_accuracy: 0.4166 - val_top-5-accuracy: 0.7363 -Epoch 45/50 -313/313 [==============================] - 28s 90ms/step - loss: 2.4653 - accuracy: 0.4809 - top-5-accuracy: 0.7924 - val_loss: 2.8663 - val_accuracy: 0.4130 - val_top-5-accuracy: 0.7209 -Epoch 46/50 -313/313 [==============================] - 28s 90ms/step - loss: 2.4554 - accuracy: 0.4825 - top-5-accuracy: 0.7929 - val_loss: 2.8145 - val_accuracy: 0.4250 - val_top-5-accuracy: 0.7357 -Epoch 47/50 -313/313 [==============================] - 29s 91ms/step - loss: 2.4602 - accuracy: 0.4823 - top-5-accuracy: 0.7919 - val_loss: 2.8352 - val_accuracy: 0.4189 - val_top-5-accuracy: 0.7365 -Epoch 48/50 -313/313 [==============================] - 28s 91ms/step - loss: 2.4493 - accuracy: 0.4848 - top-5-accuracy: 0.7933 - val_loss: 2.8246 - val_accuracy: 0.4160 - val_top-5-accuracy: 0.7362 -Epoch 49/50 -313/313 [==============================] - 28s 91ms/step - loss: 2.4454 - accuracy: 0.4846 - top-5-accuracy: 0.7958 - val_loss: 2.7731 - val_accuracy: 0.4320 - val_top-5-accuracy: 0.7436 -Epoch 50/50 -313/313 [==============================] - 29s 92ms/step - loss: 2.4418 - accuracy: 0.4848 - top-5-accuracy: 0.7951 - val_loss: 2.7926 - val_accuracy: 0.4317 - val_top-5-accuracy: 0.7410 + 313/313 ━━━━━━━━━━━━━━━━━━━━ 959s 3s/step - accuracy: 0.0396 - loss: 4.4656 - top-5-accuracy: 0.1473 - val_accuracy: 0.0716 - val_loss: 4.4896 - val_top-5-accuracy: 0.2253 ```
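[Editorial note, not part of the patch] The next hunk only updates the image path of the loss-curve figure (`eanet_24_0.png`). For context, the plot that produces it looks roughly like the sketch below; it assumes the `history` object returned by `model.fit` above, and the labels are illustrative rather than the example's exact strings:

```python
# Illustrative sketch only -- assumes `history` from the model.fit call above.
import matplotlib.pyplot as plt

plt.plot(history.history["loss"], label="train_loss")
plt.plot(history.history["val_loss"], label="val_loss")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.legend()
plt.grid()
plt.show()
```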
@@ -440,7 +338,7 @@ plt.show() -![png](/img/examples/vision/EANet/EANet_24_0.png) +![png](/img/examples/vision/eanet/eanet_24_0.png) @@ -456,10 +354,10 @@ print(f"Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%")
``` -313/313 [==============================] - 6s 21ms/step - loss: 2.7574 - accuracy: 0.4391 - top-5-accuracy: 0.7471 -Test loss: 2.76 -Test accuracy: 43.91% -Test top 5 accuracy: 74.71% + 313/313 ━━━━━━━━━━━━━━━━━━━━ 69s 210ms/step - accuracy: 0.0691 - loss: 4.4804 - top-5-accuracy: 0.2291 +Test loss: 4.47 +Test accuracy: 7.26% +Test top 5 accuracy: 23.33% ```
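[Editorial note, not part of the patch] The numbers above come from the evaluation step; judging from the `print` call visible in the hunk header, it looks roughly like this sketch (it assumes `model`, `x_test`, `y_test`, and `batch_size` defined earlier in the example):

```python
# Illustrative sketch only -- assumes model, x_test, y_test, batch_size above.
loss, accuracy, top_5_accuracy = model.evaluate(x_test, y_test, batch_size=batch_size)
print(f"Test loss: {round(loss, 2)}")
print(f"Test accuracy: {round(accuracy * 100, 2)}%")
print(f"Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%")
```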
@@ -469,6 +367,7 @@ training 50 epochs, but with 0.6M parameters. Under the same experimental enviro and the same hyperparameters, the EANet model we just trained has just 0.3M parameters, and it gets us to ~73% test top-5 accuracy and ~43% top-1 accuracy. This fully demonstrates the effectiveness of external attention. + We only show the training process of EANet; you can train ViT under the same experimental conditions and observe the test results.
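[Editorial note, not part of the patch] The conclusion suggests training the ViT-style variant under the same conditions; with the updated `keras.optimizers.AdamW`, that could look like the sketch below. It reuses the names defined earlier in the example (`get_model`, the hyperparameters, and `x_train`/`y_train`), so it is a hedged sketch rather than part of the tutorial itself:

```python
# Illustrative sketch only -- reuses get_model, hyperparameters, and the
# training data defined earlier in the example.
import keras

vit_model = get_model(attention_type="self_attention")
vit_model.compile(
    loss=keras.losses.CategoricalCrossentropy(label_smoothing=label_smoothing),
    optimizer=keras.optimizers.AdamW(
        learning_rate=learning_rate, weight_decay=weight_decay
    ),
    metrics=[
        keras.metrics.CategoricalAccuracy(name="accuracy"),
        keras.metrics.TopKCategoricalAccuracy(5, name="top-5-accuracy"),
    ],
)
vit_history = vit_model.fit(
    x_train,
    y_train,
    batch_size=batch_size,
    epochs=num_epochs,
    validation_split=validation_split,
)
```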