diff --git a/.travis.yml b/.travis.yml index 2bf6a1e54..63cda3d39 100644 --- a/.travis.yml +++ b/.travis.yml @@ -26,7 +26,7 @@ env: # Backward Compatibility in insured for release less than 1 year old. # https://pypi.org/project/tensorflow/#history matrix: - - _TF_VERSION=2.0.0b1 + - _TF_VERSION=2.0.0-rc1 # - _TF_VERSION=1.12.0 # Remove on Oct 22, 2019 # - _TF_VERSION=1.11.0 # Remove on Sep 28, 2019 # - _TF_VERSION=1.10.1 # Remove on Aug 24, 2019 @@ -63,7 +63,7 @@ matrix: install: - | if [[ -v _DOC_AND_YAPF_TEST ]]; then - pip install tensorflow==2.0.0b1 + pip install tensorflow==2.0.0-rc1 pip install yapf pip install -e .[doc] else @@ -101,7 +101,7 @@ deploy: on: tags: true python: '3.6' - condition: '$_TF_VERSION = 2.0.0b1' + condition: '$_TF_VERSION = 2.0.0-rc1' # condition: '$_TF_VERSION = 1.11.0' # Documentation: https://docs.travis-ci.com/user/deployment/releases/ @@ -115,5 +115,5 @@ deploy: on: tags: true python: '3.6' - condition: '$_TF_VERSION = 2.0.0b1' + condition: '$_TF_VERSION = 2.0.0-rc1' # condition: '$_TF_VERSION = 1.11.0' diff --git a/CHANGELOG.md b/CHANGELOG.md index 6e09387c3..f6782b064 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -79,7 +79,6 @@ To release a new version, please update the changelog as followed: ### Deprecated ### Fixed -- RNN updates: remove warnings, fix if seq_len=0, unitest (#PR 1033) ### Removed @@ -88,7 +87,11 @@ To release a new version, please update the changelog as followed: ### Contributors -## [2.2.1] +## [2.2.0] - 2019-09-13 + +TensorLayer 2.2.0 is a maintenance release. +It contains numerous API improvement and bug fixes. +This release is compatible with TensorFlow 2 RC1. ### Added - Support nested layer customization (#PR 1015) @@ -96,13 +99,18 @@ To release a new version, please update the changelog as followed: - Support Dynamic RNN in RNN (#PR 1023) - Add ResNet50 static model (#PR 1030) - Add Transformer model (#PR 1027) +- Add performance test code in static model (#PR 1041) ### Changed - `SpatialTransform2dAffine` auto `in_channels` -- support TensorFlow 2.0.0-beta1 +- support TensorFlow 2.0.0-rc1 - Update model weights property, now returns its copy (#PR 1010) +### Fixed +- RNN updates: remove warnings, fix if seq_len=0, unitest (#PR 1033) +- BN updates: fix BatchNorm1d for 2D data, refactored (#PR 1040) + ### Dependencies Update ### Deprecated @@ -116,6 +124,7 @@ To release a new version, please update the changelog as followed: - Copy original model's `trainable_weights` and `nontrainable_weights` when initializing `LayerList` (#PR 1029) - Remove redundant parts in `model.all_layers` (#PR 1029) - Replace `tf.image.resize_image_with_crop_or_pad` with `tf.image.resize_with_crop_or_pad` (#PR 1032) +- Fix a bug in `ResNet50` static model (#PR 1041) ### Removed @@ -129,6 +138,7 @@ To release a new version, please update the changelog as followed: - @ArnoldLIULJ: #1023 #1027 - @JingqingZ: #1023 #1027 + ## [2.1.0] ### Changed @@ -199,15 +209,12 @@ A maintain release. - @warshallrho: #PR966 - @zsdonghao: #931 - @yd-yin: #963 -<<<<<<< HEAD - @Tokarev-TT-33: # 995 - @initial-h: # 995 - @quantumiracle: #995 - @Officium: #995 -======= - @1FengL: #958 - @dvklopfenstein: #971 ->>>>>>> 560dbb8a17963023a3b1d59a79e1c2752530114a ## [2.0.0] - 2019-05-04 @@ -560,7 +567,7 @@ To many PR for this update, please check [here](https://github.com/tensorlayer/t @zsdonghao @luomai @DEKHTIARJonathan [Unreleased]: https://github.com/tensorlayer/tensorlayer/compare/2.0....master -[2.1.1]: https://github.com/tensorlayer/tensorlayer/compare/2.1.1...2.1.1 +[2.2.0]: https://github.com/tensorlayer/tensorlayer/compare/2.2.0...2.2.0 [2.1.0]: https://github.com/tensorlayer/tensorlayer/compare/2.1.0...2.1.0 [2.0.2]: https://github.com/tensorlayer/tensorlayer/compare/2.0.2...2.0.2 [2.0.1]: https://github.com/tensorlayer/tensorlayer/compare/2.0.1...2.0.1 diff --git a/README.md b/README.md index c6751ea76..87ef86315 100644 --- a/README.md +++ b/README.md @@ -34,7 +34,10 @@
-TensorLayer is a novel TensorFlow-based deep learning and reinforcement learning library designed for researchers and engineers. It provides a large collection of customizable neural layers / functions that are key to build real-world AI applications. TensorLayer is awarded the 2017 Best Open Source Software by the [ACM Multimedia Society](https://twitter.com/ImperialDSI/status/923928895325442049). +TensorLayer is a novel TensorFlow-based deep learning and reinforcement learning library designed for researchers and engineers. It provides an extensive collection of customizable neural layers to build complex AI models. TensorLayer is awarded the 2017 Best Open Source Software by the [ACM Multimedia Society](https://twitter.com/ImperialDSI/status/923928895325442049). +TensorLayer can also be found at [iHub](https://code.ihub.org.cn/projects/328) and [Gitee](https://gitee.com/organizations/TensorLayer). + +# News 🔥📰🔥 Reinforcement Learning Model Zoos: [Low-level APIs for Research](https://github.com/tensorlayer/tensorlayer/tree/master/examples/reinforcement_learning) and [High-level APIs for Production](https://github.com/tensorlayer/RLzoo) @@ -42,37 +45,32 @@ TensorLayer is a novel TensorFlow-based deep learning and reinforcement learning 🔥📰🔥 [NNoM](https://github.com/majianjia/nnom): Run TensorLayer quantized models on the **MCU** (e.g., STM32) (Coming Soon) - # Features -As deep learning practitioners, we have been looking for a library that can address various development - purposes. This library is easy to adopt by providing diverse examples, tutorials and pre-trained models. -Also, it allow users to easily fine-tune TensorFlow; while being suitable for production deployment. TensorLayer aims to satisfy all these purposes. It has three key features: +TensorLayer is a new deep learning library designed with simplicity, flexibility and high-performance in mind. -- ***Simplicity*** : TensorLayer lifts the low-level dataflow interface of TensorFlow to *high-level* layers / models. It is very easy to learn through the rich [example codes](https://github.com/tensorlayer/awesome-tensorlayer) contributed by a wide community. -- ***Flexibility*** : TensorLayer APIs are transparent: it does not mask TensorFlow from users; but leaving massive hooks that help *low-level tuning* and *deep customization*. -- ***Zero-cost Abstraction*** : TensorLayer can achieve the *full power* of TensorFlow. The following table shows the training speeds of [VGG16](http://www.robots.ox.ac.uk/~vgg/research/very_deep/) using TensorLayer and native TensorFlow on a TITAN Xp. +- ***Simplicity*** : TensorLayer has a high-level layer/model abstraction which is effortless to learn. You can learn how deep learning can benefit your AI tasks in minutes through the massive [examples](https://github.com/tensorlayer/awesome-tensorlayer). +- ***Flexibility*** : TensorLayer APIs are transparent and flexible, inspired by the emerging PyTorch library. Compared to the Keras abstraction, TensorLayer makes it much easier to build and train complex AI models. +- ***Zero-cost Abstraction*** : Though simple to use, TensorLayer does not require you to make any compromise in the performance of TensorFlow (Check the following benchmark section for more details). - | Mode | Lib | Data Format | Max GPU Memory Usage(MB) |Max CPU Memory Usage(MB) | Avg CPU Memory Usage(MB) | Runtime (sec) | - | :-------: | :-------------: | :-----------: | :-----------------: | :-----------------: | :-----------------: | :-----------: | - | AutoGraph | TensorFlow 2.0 | channel last | 11833 | 2161 | 2136 | 74 | - | | Tensorlayer 2.0 | channel last | 11833 | 2187 | 2169 | 76 | - | Graph | Keras | channel last | 8677 | 2580 | 2576 | 101 | - | Eager | TensorFlow 2.0 | channel last | 8723 | 2052 | 2024 | 97 | - | | TensorLayer 2.0 | channel last | 8723 | 2010 | 2007 | 95 | +TensorLayer is NOT yet another library in the TensorFlow world. Other wrappers like Keras and TFLearn +hide many powerful features of TensorFlow and provide little support for writing custom, complex AI models. Inspired by PyTorch, TensorLayer APIs are simple, flexible and most importantly, pythonic. +TensorLayer has a fast-growing community. It has been used by researchers and engineers all over the world, including those from Peking University, +Imperial College London, UC Berkeley, Carnegie Mellon University, Stanford University, and companies like Google, Microsoft, Alibaba, Tencent, Xiaomi, and Bloomberg. +# Multilingual documents -TensorLayer stands at a unique spot in the library landscape. Other wrapper libraries like Keras and TFLearn also provide high-level abstractions. They, however, often -hide the underlying engine from users, which make them hard to customize -and fine-tune. On the contrary, TensorLayer APIs are generally lightweight, flexible and transparent. -Users often find it easy to start with the examples and tutorials, and then dive -into TensorFlow seamlessly. In addition, TensorLayer does not create library lock-in through native supports for importing components from Keras. +TensorLayer has extensive documentation for both beginners and professionals. The documentation is available in +both English and Chinese. -TensorLayer has a fast growing usage among top researchers and engineers, from universities like Peking University, -Imperial College London, UC Berkeley, Carnegie Mellon University, Stanford University, and -University of Technology of Compiegne (UTC), and companies like Google, Microsoft, Alibaba, Tencent, Xiaomi, and Bloomberg. +[![English Documentation](https://img.shields.io/badge/documentation-english-blue.svg)](https://tensorlayer.readthedocs.io/) +[![Chinese Documentation](https://img.shields.io/badge/documentation-%E4%B8%AD%E6%96%87-blue.svg)](https://tensorlayercn.readthedocs.io/) +[![Chinese Book](https://img.shields.io/badge/book-%E4%B8%AD%E6%96%87-blue.svg)](http://www.broadview.com.cn/book/5059/) + +If you want to try the experimental features on the the master branch, you can find the latest document +[here](https://tensorlayer.readthedocs.io/en/latest/). -# Tutorials and Real-World Applications +# Extensive examples You can find a large collection of tutorials, examples and real-world applications using TensorLayer within [examples](examples/) or through the following space: @@ -82,73 +80,42 @@ You can find a large collection of tutorials, examples and real-world applicatio -# Documentation - -TensorLayer has extensive documentation for both beginners and professionals. The documentation is available in -both English and Chinese. Please click the following icons to find the documents you need: - -[![English Documentation](https://img.shields.io/badge/documentation-english-blue.svg)](https://tensorlayer.readthedocs.io/) -[![Chinese Documentation](https://img.shields.io/badge/documentation-%E4%B8%AD%E6%96%87-blue.svg)](https://tensorlayercn.readthedocs.io/) -[![Chinese Book](https://img.shields.io/badge/book-%E4%B8%AD%E6%96%87-blue.svg)](http://www.broadview.com.cn/book/5059/) - -If you want to try the experimental features on the the master branch, you can find the latest document -[here](https://tensorlayer.readthedocs.io/en/latest/). - -# Install +# Installing TensorLayer is easy -For latest code for TensorLayer 2.0, please build from the source. TensorLayer 2.0 has pre-requisites including TensorFlow 2, numpy, and others. For GPU support, CUDA and cuDNN are required. +TensorLayer 2.0 relies on TensorFlow, numpy, and others. To use GPUs, CUDA and cuDNN are required. Install TensorFlow: ```bash -pip3 install tensorflow-gpu==2.0.0-beta1 # specific version (YOU SHOULD INSTALL THIS ONE NOW) -pip3 install tensorflow-gpu # GPU version +pip3 install tensorflow-gpu==2.0.0-rc1 # TensorFlow GPU (version 2.0 RC1) pip3 install tensorflow # CPU version ``` -Install the stable version of TensorLayer: +Install the stable release of TensorLayer: ```bash pip3 install tensorlayer ``` -Install the latest version of TensorLayer: +Install the unstable development version of TensorLayer: ```bash pip3 install git+https://github.com/tensorlayer/tensorlayer.git -or -pip3 install https://github.com/tensorlayer/tensorlayer/archive/master.zip ``` -For developers, you should clone the folder to your local machine and put it along with your project scripts. - +If you want to install the additional dependencies, you can also run ```bash -git clone https://github.com/tensorlayer/tensorlayer.git -``` - -If you want install TensorLayer 1.X, the simplest way to install TensorLayer 1.X is to use the **Py**thon **P**ackage **I**ndex (PyPI): - -```bash -# for last stable version of TensorLayer 1.X -pip3 install --upgrade tensorlayer==1.X - -# for latest release candidate of TensorLayer 1.X -pip3 install --upgrade --pre tensorlayer - -# if you want to install the additional dependencies, you can also run pip3 install --upgrade tensorlayer[all] # all additional dependencies pip3 install --upgrade tensorlayer[extra] # only the `extra` dependencies pip3 install --upgrade tensorlayer[contrib_loggers] # only the `contrib_loggers` dependencies ``` - +# Benchmark + +The following table shows the training speeds of [VGG16](http://www.robots.ox.ac.uk/~vgg/research/very_deep/) using TensorLayer and native TensorFlow on a TITAN Xp. + +| Mode | Lib | Data Format | Max GPU Memory Usage(MB) |Max CPU Memory Usage(MB) | Avg CPU Memory Usage(MB) | Runtime (sec) | +| :-------: | :-------------: | :-----------: | :-----------------: | :-----------------: | :-----------------: | :-----------: | +| AutoGraph | TensorFlow 2.0 | channel last | 11833 | 2161 | 2136 | 74 | +| | Tensorlayer 2.0 | channel last | 11833 | 2187 | 2169 | 76 | +| Graph | Keras | channel last | 8677 | 2580 | 2576 | 101 | +| Eager | TensorFlow 2.0 | channel last | 8723 | 2052 | 2024 | 97 | +| | TensorLayer 2.0 | channel last | 8723 | 2010 | 2007 | 95 | + # Contribute Please read the [Contributor Guideline](CONTRIBUTING.md) before submitting your PRs. @@ -201,4 +180,4 @@ If you use TensorLayer for any projects, please cite this paper: # License -TensorLayer is released under the Apache 2.0 license. We also host TensorLayer on [iHub](https://code.ihub.org.cn/projects/328) and [Gitee](https://gitee.com/organizations/TensorLayer). +TensorLayer is released under the Apache 2.0 license. diff --git a/tensorlayer/files/utils.py b/tensorlayer/files/utils.py index 242590c04..1a3a429f0 100644 --- a/tensorlayer/files/utils.py +++ b/tensorlayer/files/utils.py @@ -2666,6 +2666,10 @@ def _load_weights_from_hdf5_group(f, layers, skip=False): elif isinstance(layer, tl.layers.Layer): weight_names = [n.decode('utf8') for n in g.attrs['weight_names']] for iid, w_name in enumerate(weight_names): + # FIXME : this is only for compatibility + if isinstance(layer, tl.layers.BatchNorm) and np.asarray(g[w_name]).ndim > 1: + assign_tf_variable(layer.all_weights[iid], np.asarray(g[w_name]).squeeze()) + continue assign_tf_variable(layer.all_weights[iid], np.asarray(g[w_name])) else: raise Exception("Only layer or model can be saved into hdf5.") diff --git a/tensorlayer/layers/normalization.py b/tensorlayer/layers/normalization.py index 226795981..a609f5671 100644 --- a/tensorlayer/layers/normalization.py +++ b/tensorlayer/layers/normalization.py @@ -108,6 +108,19 @@ def _bias_add(x, b, data_format): def batch_normalization(x, mean, variance, offset, scale, variance_epsilon, data_format, name=None): """Data Format aware version of tf.nn.batch_normalization.""" + if data_format == 'channels_last': + mean = tf.reshape(mean, [1] * (len(x.shape) - 1) + [-1]) + variance = tf.reshape(variance, [1] * (len(x.shape) - 1) + [-1]) + offset = tf.reshape(offset, [1] * (len(x.shape) - 1) + [-1]) + scale = tf.reshape(scale, [1] * (len(x.shape) - 1) + [-1]) + elif data_format == 'channels_first': + mean = tf.reshape(mean, [1] + [-1] + [1] * (len(x.shape) - 2)) + variance = tf.reshape(variance, [1] + [-1] + [1] * (len(x.shape) - 2)) + offset = tf.reshape(offset, [1] + [-1] + [1] * (len(x.shape) - 2)) + scale = tf.reshape(scale, [1] + [-1] + [1] * (len(x.shape) - 2)) + else: + raise ValueError('invalid data_format: %s' % data_format) + with ops.name_scope(name, 'batchnorm', [x, mean, variance, scale, offset]): inv = math_ops.rsqrt(variance + variance_epsilon) if scale is not None: @@ -204,13 +217,10 @@ def __init__( self.moving_var_init = moving_var_init self.num_features = num_features + self.channel_axis = -1 if data_format == 'channels_last' else 1 + self.axes = None + if num_features is not None: - if not isinstance(self, BatchNorm1d) and not isinstance(self, BatchNorm2d) and not isinstance(self, - BatchNorm3d): - raise ValueError( - "Please use BatchNorm1d or BatchNorm2d or BatchNorm3d instead of BatchNorm " - "if you want to specify 'num_features'." - ) self.build(None) self._built = True @@ -233,21 +243,23 @@ def __repr__(self): def _get_param_shape(self, inputs_shape): if self.data_format == 'channels_last': - axis = len(inputs_shape) - 1 + axis = -1 elif self.data_format == 'channels_first': axis = 1 else: raise ValueError('data_format should be either %s or %s' % ('channels_last', 'channels_first')) channels = inputs_shape[axis] - params_shape = [1] * len(inputs_shape) - params_shape[axis] = channels + params_shape = [channels] - axes = [i for i in range(len(inputs_shape)) if i != axis] - return params_shape, axes + return params_shape + + def _check_input_shape(self, inputs): + if inputs.ndim <= 1: + raise ValueError('expected input at least 2D, but got {}D input'.format(inputs.ndim)) def build(self, inputs_shape): - params_shape, self.axes = self._get_param_shape(inputs_shape) + params_shape = [self.num_features] if self.num_features is not None else self._get_param_shape(inputs_shape) self.beta, self.gamma = None, None if self.beta_init: @@ -264,7 +276,12 @@ def build(self, inputs_shape): ) def forward(self, inputs): - mean, var = tf.nn.moments(inputs, self.axes, keepdims=True) + self._check_input_shape(inputs) + + if self.axes is None: + self.axes = [i for i in range(len(inputs.shape)) if i != self.channel_axis] + + mean, var = tf.nn.moments(inputs, self.axes, keepdims=False) if self.is_train: # update moving_mean and moving_var self.moving_mean = moving_averages.assign_moving_average( @@ -282,8 +299,8 @@ def forward(self, inputs): class BatchNorm1d(BatchNorm): - """The :class:`BatchNorm1d` applies Batch Normalization over 3D input (a mini-batch of 1D - inputs with additional channel dimension), of shape (N, L, C) or (N, C, L). + """The :class:`BatchNorm1d` applies Batch Normalization over 2D/3D input (a mini-batch of 1D + inputs (optional) with additional channel dimension), of shape (N, C) or (N, L, C) or (N, C, L). See more details in :class:`BatchNorm`. Examples @@ -299,23 +316,9 @@ class BatchNorm1d(BatchNorm): """ - def _get_param_shape(self, inputs_shape): - if self.data_format == 'channels_last': - axis = 2 - elif self.data_format == 'channels_first': - axis = 1 - else: - raise ValueError('data_format should be either %s or %s' % ('channels_last', 'channels_first')) - - if self.num_features is None: - channels = inputs_shape[axis] - else: - channels = self.num_features - params_shape = [1] * 3 - params_shape[axis] = channels - - axes = [i for i in range(3) if i != axis] - return params_shape, axes + def _check_input_shape(self, inputs): + if inputs.ndim != 2 and inputs.ndim != 3: + raise ValueError('expected input to be 2D or 3D, but got {}D input'.format(inputs.ndim)) class BatchNorm2d(BatchNorm): @@ -336,23 +339,9 @@ class BatchNorm2d(BatchNorm): """ - def _get_param_shape(self, inputs_shape): - if self.data_format == 'channels_last': - axis = 3 - elif self.data_format == 'channels_first': - axis = 1 - else: - raise ValueError('data_format should be either %s or %s' % ('channels_last', 'channels_first')) - - if self.num_features is None: - channels = inputs_shape[axis] - else: - channels = self.num_features - params_shape = [1] * 4 - params_shape[axis] = channels - - axes = [i for i in range(4) if i != axis] - return params_shape, axes + def _check_input_shape(self, inputs): + if inputs.ndim != 4: + raise ValueError('expected input to be 4D, but got {}D input'.format(inputs.ndim)) class BatchNorm3d(BatchNorm): @@ -373,23 +362,9 @@ class BatchNorm3d(BatchNorm): """ - def _get_param_shape(self, inputs_shape): - if self.data_format == 'channels_last': - axis = 4 - elif self.data_format == 'channels_first': - axis = 1 - else: - raise ValueError('data_format should be either %s or %s' % ('channels_last', 'channels_first')) - - if self.num_features is None: - channels = inputs_shape[axis] - else: - channels = self.num_features - params_shape = [1] * 5 - params_shape[axis] = channels - - axes = [i for i in range(5) if i != axis] - return params_shape, axes + def _check_input_shape(self, inputs): + if inputs.ndim != 5: + raise ValueError('expected input to be 5D, but got {}D input'.format(inputs.ndim)) class InstanceNorm(Layer): diff --git a/tensorlayer/models/mobilenetv1.py b/tensorlayer/models/mobilenetv1.py index 4908b3d89..82ea7be46 100644 --- a/tensorlayer/models/mobilenetv1.py +++ b/tensorlayer/models/mobilenetv1.py @@ -43,9 +43,9 @@ def restore_params(network, path='models'): expected_bytes=25600116 ) # ls -al params = load_npz(name=os.path.join(path, 'mobilenet.npz')) - for idx, net_weight in enumerate(network.all_weights): - if 'batchnorm' in net_weight.name: - params[idx] = params[idx].reshape(1, 1, 1, -1) + # for idx, net_weight in enumerate(network.all_weights): + # if 'batchnorm' in net_weight.name: + # params[idx] = params[idx].reshape(1, 1, 1, -1) assign_weights(params[:len(network.all_weights)], network) del params diff --git a/tensorlayer/models/resnet.py b/tensorlayer/models/resnet.py index 9938fd1cd..7df069468 100644 --- a/tensorlayer/models/resnet.py +++ b/tensorlayer/models/resnet.py @@ -150,21 +150,21 @@ def ResNet50(pretrained=False, end_with='fc1000', n_classes=1000, name=None): n = BatchNorm(name='bn_conv1', act='relu')(n) n = MaxPool2d((3, 3), strides=(2, 2), name='max_pool1')(n) - for i, name in enumerate(block_names): - if len(name) == 2: - stage = int(name[0]) - block = name[1] + for i, block_name in enumerate(block_names): + if len(block_name) == 2: + stage = int(block_name[0]) + block = block_name[1] if block == 'a': strides = (1, 1) if stage == 2 else (2, 2) n = conv_block(n, 3, block_filters[stage - 2], stage=stage, block=block, strides=strides) else: n = identity_block(n, 3, block_filters[stage - 2], stage=stage, block=block) - elif name == 'avg_pool': + elif block_name == 'avg_pool': n = GlobalMeanPool2d(name='avg_pool')(n) - elif name == 'fc1000': + elif block_name == 'fc1000': n = Dense(n_classes, name='fc1000')(n) - if name == end_with: + if block_name == end_with: break network = Model(inputs=ni, outputs=n, name=name) @@ -194,8 +194,8 @@ def restore_params(network, path='models'): continue w_names = list(f[layer.name]) params = [f[layer.name][n][:] for n in w_names] - if 'bn' in layer.name: - params = [x.reshape(1, 1, 1, -1) for x in params] + # if 'bn' in layer.name: + # params = [x.reshape(1, 1, 1, -1) for x in params] assign_weights(params, layer) del params diff --git a/tests/layers/test_layers_normalization.py b/tests/layers/test_layers_normalization.py index a25e47f76..c223f61ed 100644 --- a/tests/layers/test_layers_normalization.py +++ b/tests/layers/test_layers_normalization.py @@ -18,11 +18,13 @@ class Laye_BatchNorm_Test(CustomTestCase): @classmethod def setUpClass(cls): + x_0_input_shape = [None, 10] x_1_input_shape = [None, 100, 1] x_2_input_shape = [None, 100, 100, 3] x_3_input_shape = [None, 100, 100, 100, 3] batchsize = 2 + cls.x0 = tf.random.normal([batchsize] + x_0_input_shape[1:]) cls.x1 = tf.random.normal([batchsize] + x_1_input_shape[1:]) cls.x2 = tf.random.normal([batchsize] + x_2_input_shape[1:]) cls.x3 = tf.random.normal([batchsize] + x_3_input_shape[1:]) @@ -36,16 +38,58 @@ def setUpClass(cls): ni_2 = Input(x_2_input_shape, name='test_ni2') nn_2 = Conv2d(n_filter=32, filter_size=(3, 3), strides=(2, 2), name='test_conv2d')(ni_2) - n2_b = BatchNorm2d(name='test_bn2d')(nn_2) + n2_b = BatchNorm(name='test_bn2d')(nn_2) cls.n2_b = n2_b cls.base_2d = Model(inputs=ni_2, outputs=n2_b, name='test_base_2d') ni_3 = Input(x_3_input_shape, name='test_ni2') nn_3 = Conv3d(n_filter=32, filter_size=(3, 3, 3), strides=(2, 2, 2), name='test_conv3d')(ni_3) - n3_b = BatchNorm3d(name='test_bn3d')(nn_3) + n3_b = BatchNorm(name='test_bn3d')(nn_3) cls.n3_b = n3_b cls.base_3d = Model(inputs=ni_3, outputs=n3_b, name='test_base_3d') + class bn_0d_model(Model): + + def __init__(self): + super(bn_0d_model, self).__init__() + self.fc = Dense(32, in_channels=10) + self.bn = BatchNorm(num_features=32, name='test_bn1d') + + def forward(self, x): + x = self.bn(self.fc(x)) + return x + + dynamic_base = bn_0d_model() + cls.n0_b = dynamic_base(cls.x0, is_train=True) + + ## 0D ======================================================================== + + nin_0 = Input(x_0_input_shape, name='test_in1') + + n0 = Dense(32)(nin_0) + n0 = BatchNorm1d(name='test_bn0d')(n0) + + cls.n0 = n0 + + cls.static_0d = Model(inputs=nin_0, outputs=n0) + + class bn_0d_model(Model): + + def __init__(self): + super(bn_0d_model, self).__init__(name='test_bn_0d_model') + self.fc = Dense(32, in_channels=10) + self.bn = BatchNorm1d(num_features=32, name='test_bn1d') + + def forward(self, x): + x = self.bn(self.fc(x)) + return x + + cls.dynamic_0d = bn_0d_model() + + print("Printing BatchNorm0d") + print(cls.static_0d) + print(cls.dynamic_0d) + ## 1D ======================================================================== nin_1 = Input(x_1_input_shape, name='test_in1') @@ -147,6 +191,14 @@ def test_BatchNorm(self): self.assertEqual(self.n3_b.shape[1:], (50, 50, 50, 32)) out = self.base_3d(self.x3, is_train=True) + self.assertEqual(self.n0_b.shape[1:], (32)) + print("test_BatchNorm OK") + + def test_BatchNorm0d(self): + self.assertEqual(self.n0.shape[1:], (32)) + out = self.static_0d(self.x0, is_train=True) + out = self.dynamic_0d(self.x0, is_train=True) + def test_BatchNorm1d(self): self.assertEqual(self.n1.shape[1:], (50, 32)) out = self.static_1d(self.x1, is_train=True) @@ -189,6 +241,26 @@ def test_exception(self): self.assertIsInstance(e, ValueError) print(e) + def test_input_shape(self): + try: + bn = BatchNorm1d(num_features=32) + out = bn(self.x2) + except Exception as e: + self.assertIsInstance(e, ValueError) + print(e) + try: + bn = BatchNorm2d(num_features=32) + out = bn(self.x3) + except Exception as e: + self.assertIsInstance(e, ValueError) + print(e) + try: + bn = BatchNorm3d(num_features=32) + out = bn(self.x1) + except Exception as e: + self.assertIsInstance(e, ValueError) + print(e) + if __name__ == '__main__': diff --git a/tests/performance_test/vgg/tl2-static-autograph.py b/tests/performance_test/vgg/tl2-static-autograph.py new file mode 100644 index 000000000..0af20adb8 --- /dev/null +++ b/tests/performance_test/vgg/tl2-static-autograph.py @@ -0,0 +1,79 @@ +import time +import os +import psutil +import tensorflow as tf +import tensorlayer as tl +from exp_config import random_input_generator, MONITOR_INTERVAL, NUM_ITERS, BATCH_SIZE, LERANING_RATE + +gpus = tf.config.experimental.list_physical_devices('GPU') +if gpus: + for gpu in gpus: + tf.config.experimental.set_memory_growth(gpu, True) + +tl.logging.set_verbosity(tl.logging.DEBUG) + +# get the whole model +vgg = tl.models.vgg16(mode='static') + +# system monitor +info = psutil.virtual_memory() +monitor_interval = MONITOR_INTERVAL +avg_mem_usage = 0 +max_mem_usage = 0 +count = 0 +total_time = 0 + +# training setting +num_iter = NUM_ITERS +batch_size = BATCH_SIZE +train_weights = vgg.trainable_weights +optimizer = tf.optimizers.Adam(learning_rate=LERANING_RATE) +loss_object = tl.cost.cross_entropy + +# data generator +gen = random_input_generator(num_iter, batch_size) + + +# training function +@tf.function +def train_step(x_batch, y_batch): + # forward + backward + with tf.GradientTape() as tape: + ## compute outputs + _logits = vgg(x_batch) + ## compute loss and update model + _loss = loss_object(_logits, y_batch) + + grad = tape.gradient(_loss, train_weights) + optimizer.apply_gradients(zip(grad, train_weights)) + + +# begin training +vgg.train() + +for idx, data in enumerate(gen): + start_time = time.time() + + train_step(data[0], data[1]) + + end_time = time.time() + consume_time = end_time - start_time + total_time += consume_time + + if idx % monitor_interval == 0: + cur_usage = psutil.Process(os.getpid()).memory_info().rss + max_mem_usage = max(cur_usage, max_mem_usage) + avg_mem_usage += cur_usage + count += 1 + tl.logging.info( + "[*] {} iteration: memory usage {:.2f}MB, consume time {:.4f}s".format( + idx, cur_usage / (1024 * 1024), consume_time + ) + ) + +print('consumed time:', total_time) + +avg_mem_usage = avg_mem_usage / count / (1024 * 1024) +max_mem_usage = max_mem_usage / (1024 * 1024) +print('average memory usage: {:.2f}MB'.format(avg_mem_usage)) +print('maximum memory usage: {:.2f}MB'.format(max_mem_usage)) diff --git a/tests/performance_test/vgg/tl2-static-eager.py b/tests/performance_test/vgg/tl2-static-eager.py new file mode 100644 index 000000000..b6d5287ba --- /dev/null +++ b/tests/performance_test/vgg/tl2-static-eager.py @@ -0,0 +1,79 @@ +import time +import os +import psutil +import tensorflow as tf +import tensorlayer as tl +from exp_config import random_input_generator, MONITOR_INTERVAL, NUM_ITERS, BATCH_SIZE, LERANING_RATE + +gpus = tf.config.experimental.list_physical_devices('GPU') +if gpus: + for gpu in gpus: + tf.config.experimental.set_memory_growth(gpu, True) + +tl.logging.set_verbosity(tl.logging.DEBUG) + +# get the whole model +vgg = tl.models.vgg16(mode='static') + +# system monitor +info = psutil.virtual_memory() +monitor_interval = MONITOR_INTERVAL +avg_mem_usage = 0 +max_mem_usage = 0 +count = 0 +total_time = 0 + +# training setting +num_iter = NUM_ITERS +batch_size = BATCH_SIZE +train_weights = vgg.trainable_weights +optimizer = tf.optimizers.Adam(learning_rate=LERANING_RATE) +loss_object = tl.cost.cross_entropy + +# data generator +gen = random_input_generator(num_iter, batch_size) + + +# training function +def train_step(x_batch, y_batch): + # forward + backward + with tf.GradientTape() as tape: + ## compute outputs + _logits = vgg(x_batch) + ## compute loss and update model + _loss = loss_object(_logits, y_batch) + + grad = tape.gradient(_loss, train_weights) + optimizer.apply_gradients(zip(grad, train_weights)) + return _loss + + +# begin training +vgg.train() + +for idx, data in enumerate(gen): + start_time = time.time() + + loss = train_step(data[0], data[1]) + + end_time = time.time() + consume_time = end_time - start_time + total_time += consume_time + + if idx % monitor_interval == 0: + cur_usage = psutil.Process(os.getpid()).memory_info().rss + max_mem_usage = max(cur_usage, max_mem_usage) + avg_mem_usage += cur_usage + count += 1 + tl.logging.info( + "[*] {} iteration: memory usage {:.2f}MB, consume time {:.4f}s, loss {:.4f}".format( + idx, cur_usage / (1024 * 1024), consume_time, loss + ) + ) + +print('consumed time:', total_time) + +avg_mem_usage = avg_mem_usage / count / (1024 * 1024) +max_mem_usage = max_mem_usage / (1024 * 1024) +print('average memory usage: {:.2f}MB'.format(avg_mem_usage)) +print('maximum memory usage: {:.2f}MB'.format(max_mem_usage))