This is a replication of the MNIST experiment from the paper "Tensorizing Neural Networks". The authors achieve a good result on MNIST without convolutional layers. The implementation of the matrix tensor train layer in this repo is not very reusable, but it's very simple to understand!
Alexander Novikov set all tensor train ranks to 8 and achieved 98.4% accuracy on the test set. I set the same ranks and got only 98.2% accuracy on the test set. My training hyperparameters differ slightly from his.
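To give a sense of why fixing all ranks to 8 keeps the layer small, here is a minimal sketch that counts the parameters of a matrix stored in tensor train format. The 1024 x 1024 layer size and the 4 x 8 x 8 x 4 mode factorization are assumptions chosen for illustration, and `tt_matrix_num_params` is a hypothetical helper, not a function from this repo.

```python
# Hypothetical helper (not part of this repo): parameter count of a TT-matrix.
# Core k has shape (ranks[k], row_modes[k], col_modes[k], ranks[k + 1]).
def tt_matrix_num_params(row_modes, col_modes, ranks):
    """Return the total number of entries in all TT cores.

    `ranks` has one more element than the number of cores,
    and the boundary ranks must both be 1.
    """
    assert len(row_modes) == len(col_modes) == len(ranks) - 1
    assert ranks[0] == 1 and ranks[-1] == 1
    return sum(
        ranks[k] * row_modes[k] * col_modes[k] * ranks[k + 1]
        for k in range(len(row_modes))
    )


# Assumed shapes, for illustration only: 1024 = 4 * 8 * 8 * 4.
row_modes = [4, 8, 8, 4]
col_modes = [4, 8, 8, 4]
ranks = [1, 8, 8, 8, 1]  # all internal TT ranks set to 8

tt_params = tt_matrix_num_params(row_modes, col_modes, ranks)
dense_params = 1024 * 1024  # an ordinary dense weight matrix of the same size

print(tt_params, dense_params)  # 8448 1048576
```

Under these assumed shapes, the TT-matrix needs roughly 124 times fewer parameters than a dense weight matrix of the same size.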
$ conda create --name tnn_env python=3.7
$ conda activate tnn_env
# install PyTorch (consult the official PyTorch installation instructions); for me the command below was enough
$ conda install pytorch torchvision cudatoolkit=9.0 -c pytorch
$ pip install pytorch-ignite tensorboardX click click-log libcrap
Also, to view the TensorBoard logs, install TensorBoard (search online for installation instructions).
The model state dict is saved in goodmodel.pth; its training TensorBoard log is in events.out.tfevents.1571701227.kitty. You can load the model and evaluate it like this:
# --models-dir and --tb-log-dir aren't used for evaluation on the test set, so they point to a temporary directory
$ TEMPDIR="$(mktemp -d)" && echo "$TEMPDIR" && python \
    --dataset-root /path/to/where/mnist/will/be/downloaded \
    --models-dir "$TEMPDIR" --tb-log-dir "$TEMPDIR" \
    --load-model /path/to/goodmodel.pth \
    --no-train --test
It will print mean cross entropy and accuracy on the test dataset.
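For reference, the two reported metrics can be sketched in plain Python as follows. This is an illustrative assumption about what "mean cross entropy" and "accuracy" mean here, not the code this repo uses for evaluation; `mean_cross_entropy_and_accuracy` and the toy logits are hypothetical.

```python
import math


def mean_cross_entropy_and_accuracy(logits_batch, labels):
    """Compute mean cross entropy (in nats) and accuracy from raw logits."""
    total_loss = 0.0
    num_correct = 0
    for logits, label in zip(logits_batch, labels):
        # Numerically stable log-sum-exp: subtract the max before exponentiating.
        max_logit = max(logits)
        log_sum_exp = max_logit + math.log(
            sum(math.exp(z - max_logit) for z in logits)
        )
        # Cross entropy of one sample is -log softmax(logits)[label].
        total_loss += log_sum_exp - logits[label]
        # The predicted class is the index of the largest logit.
        predicted = max(range(len(logits)), key=lambda i: logits[i])
        num_correct += int(predicted == label)
    n = len(labels)
    return total_loss / n, num_correct / n


# Toy batch with 3 samples and 3 classes (made-up numbers).
logits_batch = [[2.0, 0.5, -1.0], [0.1, 0.2, 3.0], [1.0, 1.5, 0.0]]
labels = [0, 2, 0]
loss, acc = mean_cross_entropy_and_accuracy(logits_batch, labels)
print(f"mean cross entropy: {loss:.4f}, accuracy: {acc:.4f}")
```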
To train a new model, run:
$ python --verbosity INFO \
    --seed 777 --dataset-root /path/to/where/mnist/will/be/downloaded \
    --models-dir /path/to/where/model/will/be/saved \
    --tb-log-dir /path/to/where/tensorboard/log/will/be/saved \
    --train --no-test
Then run tensorboard --logdir /path/to/tensorboard/log to see loss plots, accuracy plots, gradient norm plots, etc.
To see all the other options, run
$ python --help
For instance, there are options to change the learning rate and to load an existing model and continue training it. There is also --shuffle-pixels, which randomly permutes the pixels of each image, using the same permutation for every image. The purpose of this is to test whether the matrix tensor train layer exploits locality, like CNNs do. If it does, then accuracy should drop significantly.
In my experiments, this option reduces accuracy by approximately 1%.
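The idea behind a fixed pixel permutation can be sketched as follows. This is a minimal illustration under assumed names (`make_pixel_permutation`, `shuffle_pixels`) and a tiny 4 x 4 image, not the implementation behind --shuffle-pixels in this repo.

```python
import random


def make_pixel_permutation(num_pixels, seed):
    """Draw one random permutation of pixel indices, reproducible via `seed`."""
    rng = random.Random(seed)
    perm = list(range(num_pixels))
    rng.shuffle(perm)
    return perm


def shuffle_pixels(flat_image, perm):
    """Reorder a flattened image so that output position j holds pixel perm[j]."""
    return [flat_image[i] for i in perm]


# The permutation is drawn once and then reused for every image,
# so all images are scrambled in exactly the same way.
perm = make_pixel_permutation(num_pixels=16, seed=777)

image_a = list(range(16))          # toy flattened 4 x 4 "image"
image_b = [p * 10 for p in range(16)]

shuffled_a = shuffle_pixels(image_a, perm)
shuffled_b = shuffle_pixels(image_b, perm)

# Same pixel values, different spatial arrangement.
print(sorted(shuffled_a) == image_a)
```

Because every image goes through the same permutation, no information is lost; only the spatial neighbourhood structure is destroyed, which is what a locality-dependent model would suffer from.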