Neural Wave Machines: Learning Spatiotemporally Structured Representations with Locally Coupled Oscillatory Recurrent Neural Networks
Official implementation of the paper: "Neural Wave Machines: Learning Spatiotemporally Structured Representations with Locally Coupled Oscillatory Recurrent Neural Networks" accepted at ICML 2023.
This repository contains all code necessary to reproduce the experiments in the paper; the README below additionally includes video visualizations of the spatiotemporal dynamics for each dataset.
This repository is organized into three core directories:
- Rotating_MNIST, containing a modification of the original Topographic VAE library necessary to reproduce the results in Figures 1, 3, 4, & 5.
- Hamiltonian_Dynamics, containing a modification of the original Hamiltonian Neural Networks code necessary to reproduce the results in Table 1 and Figure 2 pertaining to modeling simple physical dynamics.
- Sequence_Modeling, containing a modification of the original Coupled Oscillatory Recurrent Neural Network code necessary to reproduce the sequence modeling results on sequential MNIST, permuted sequential MNIST, IMDB sentiment classification, and the long sequence addition task shown in Tables 2 & 4.
Since each directory is a modification of a different original codebase, we recommend creating a separate environment for each and installing its dependencies following that codebase's own guidelines.
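The locally coupled oscillatory RNN in the title builds on the Coupled Oscillatory Recurrent Neural Network (coRNN) mentioned above. As a rough illustration only, the sketch below shows one coRNN-style implicit-explicit update in which the dense hidden-to-hidden weight matrices are replaced by a small shared kernel applied over neighbouring units on a 1D ring. The ring topology, the tanh nonlinearity, the default values of `dt`, `gamma`, and `eps`, and the helper names `local_coupling` and `cornn_step` are assumptions for this example; it is not the repository's PyTorch implementation.

```python
import numpy as np

def local_coupling(h, kernel):
    """Circular 1D convolution: each unit sums its neighbours' states weighted by `kernel`.

    `kernel` has odd length 2k+1; the centre entry weights the unit itself.
    (Hypothetical helper standing in for a learned local recurrent kernel.)
    """
    k = len(kernel) // 2
    out = np.zeros_like(h)
    for offset, w in zip(range(-k, k + 1), kernel):
        out += w * np.roll(h, offset)  # np.roll(h, offset)[i] == h[i - offset]
    return out

def cornn_step(y, z, u, kernel_y, kernel_z, V, b, dt=0.1, gamma=1.0, eps=1.0):
    """One IMEX step of a coRNN-style oscillator with locally coupled hidden units.

    y: hidden "positions", z: hidden "velocities", u: input at this time step.
    """
    drive = np.tanh(local_coupling(y, kernel_y) + local_coupling(z, kernel_z) + V @ u + b)
    z_new = z + dt * (drive - gamma * y - eps * z)  # velocity update (explicit in y)
    y_new = y + dt * z_new                          # position update uses the new velocity
    return y_new, z_new

# Usage on a toy ring of 64 units driven by zero input (hypothetical settings).
rng = np.random.default_rng(0)
n_hidden, n_in, T = 64, 1, 200
V = rng.normal(scale=0.1, size=(n_hidden, n_in))
b = np.zeros(n_hidden)
kernel = np.array([0.5, 0.0, 0.5])  # couple each unit to its two nearest neighbours
y, z = rng.normal(size=n_hidden), np.zeros(n_hidden)
states = []
for _ in range(T):
    y, z = cornn_step(y, z, np.zeros(n_in), kernel, kernel, V, b)
    states.append(y)
states = np.stack(states)  # shape (T, n_hidden)
```

Because the kernel is shared across positions, each unit interacts only with its immediate neighbours; this kind of local coupling is what allows activity to propagate spatially across the hidden layer rather than mixing all units at once.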
Hidden State Visualizations
Below we show the hidden state and corresponding instantaneous phase for a variety of Neural Wave Machines and wave-free baselines on the datasets used in this paper.
Before training (left: hidden state; right: phase):
After training (left: hidden state; right: phase):
Hidden state of the same model on different data samples:
Hidden state of the 2D NWM with different random initializations:
(Ground Truth, Forward Extrapolated Reconstruction, Hidden State, Phase)
(Ground Truth, Forward Extrapolated Reconstruction, Hidden State, Phase)
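The visualizations above pair each hidden state with its instantaneous phase. This README does not say how that phase is computed; a common choice is the angle of the analytic signal obtained via the Hilbert transform along the time axis. The sketch below assumes that approach and implements it with NumPy's FFT (it should match `np.angle(scipy.signal.hilbert(x, axis=axis))`). The function name `instantaneous_phase` and the toy travelling-wave data are hypothetical, and the repository's own visualization code may differ.

```python
import numpy as np

def instantaneous_phase(x, axis=0):
    """Phase of the analytic signal of real-valued `x` along `axis` (FFT-based Hilbert transform)."""
    x = np.asarray(x, dtype=float)
    n = x.shape[axis]
    # Frequency-domain mask: keep DC (and Nyquist for even n), double positive
    # frequencies, zero negative frequencies.
    h = np.zeros(n)
    if n % 2 == 0:
        h[0] = h[n // 2] = 1.0
        h[1:n // 2] = 2.0
    else:
        h[0] = 1.0
        h[1:(n + 1) // 2] = 2.0
    shape = [1] * x.ndim
    shape[axis] = n
    analytic = np.fft.ifft(np.fft.fft(x, axis=axis) * h.reshape(shape), axis=axis)
    return np.angle(analytic)  # wrapped to (-pi, pi]

# Toy "travelling wave" over 32 units for 256 time steps (hypothetical data).
T, n = 256, 32
t = np.arange(T)[:, None] / T
units = np.arange(n)[None, :] / n
hidden = np.cos(2 * np.pi * (4 * t - units))  # shape (T, n)
# Subtract each unit's temporal mean so a DC offset does not distort the phase.
phase = instantaneous_phase(hidden - hidden.mean(axis=0), axis=0)
```

For a travelling wave like this one, the phase at a fixed time step advances steadily from unit to unit, which is the kind of structure the phase panels are meant to make visible.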
Robert Bosch GmbH is acknowledged for financial support.