This repository contains two implementations of the stochastic Lanczos quadrature algorithm for deep neural networks, as used and described in Ghorbani, Krishnan, and Xiao, *An Investigation into Neural Net Optimization via Hessian Eigenvalue Density* (ICML 2019).
To run the example notebooks, please first `pip install tensorflow_datasets`.
The main class that runs the distributed Lanczos algorithm is `LanczosExperiment`. The Jupyter notebook demonstrates how to use this class.
In addition to single-machine (potentially multi-GPU) setups, this implementation is also suitable for multi-worker, multi-GPU setups. The crucial step is manually partitioning the input data across the available GPUs.
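At its core, `LanczosExperiment` orchestrates Hessian-vector products across workers. As background only (this is a self-contained sketch, not this repository's code, which shards the computation across workers and accumulates in float64), a Hessian-vector product via double backprop in TensorFlow 2 looks roughly like:

```python
import tensorflow as tf

def hessian_vector_product(loss_fn, params, vec):
  """Hessian-vector product H @ vec via double backprop (a sketch)."""
  with tf.GradientTape() as outer_tape:
    with tf.GradientTape() as inner_tape:
      loss = loss_fn()
    grads = inner_tape.gradient(loss, params)
    # Inner product <grad, vec>; differentiating this w.r.t. params
    # yields the Hessian-vector product.
    grad_dot_vec = tf.add_n(
        [tf.reduce_sum(g * v) for g, v in zip(grads, vec)])
  return outer_tape.gradient(grad_dot_vec, params)

# Tiny runnable check: for loss = sum(w**2), the Hessian is 2*I,
# so H @ [1, 0] should be [2, 0].
w = [tf.Variable([1.0, 2.0])]
v = [tf.constant([1.0, 0.0])]
print(hessian_vector_product(lambda: tf.reduce_sum(w[0] ** 2), w, v))
```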
The algorithm outputs two numpy files, `tridiag_1` and `lanczos_vec_1`, which contain the tridiagonal matrix and the Lanczos vectors. The tridiagonal matrix can then be used to generate spectral densities using `tridiag_to_density`.
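As an illustration, a post-processing sketch along these lines loads the saved tridiagonal matrix and plots a smoothed spectral density. The import path, the keyword arguments `grid_len` and `sigma_squared`, and the `(density, grids)` return order are assumptions; check the `tridiag_to_density` docstring for the exact signature.

```python
import numpy as np
import matplotlib.pyplot as plt

from density import tridiag_to_density  # module path assumed

# Filename as saved by the experiment; the exact name/extension may vary.
tridiag = np.load('tridiag_1.npy')

# Keyword names and return order are assumed for illustration.
density, grids = tridiag_to_density(
    [tridiag], grid_len=10000, sigma_squared=1e-5)

plt.semilogy(grids, density)
plt.xlabel('Eigenvalue')
plt.ylabel('Spectral density')
plt.show()
```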
## Jax Implementation (by Justin Gilmer)
The Jax version is well suited for fast experimentation (especially in conjunction with trax). The Jupyter notebook demonstrates how to run Lanczos in Jax.
The main function is `lanczos_alg`, which returns a tridiagonal matrix and the Lanczos vectors. The tridiagonal matrix can then be used to generate spectral densities using `tridiag_to_density`.
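A minimal sketch of how the pieces might fit together, assuming `lanczos_alg` takes a matrix-vector product function, the parameter dimension, the number of Lanczos iterations, and an RNG key, in that order (the argument list and the module path are assumptions; see the function's docstring):

```python
import jax
import jax.numpy as jnp

from lanczos import lanczos_alg  # module path assumed

DIM = 10  # toy parameter dimension

def loss(params):
  # Toy quadratic loss; replace with a neural-network loss of interest.
  return 0.5 * jnp.sum(params ** 2)

def hvp(v):
  # Hessian-vector product at a fixed point, via forward-over-reverse
  # autodiff: jvp of grad gives H @ v without forming H.
  params = jnp.zeros(DIM)
  return jax.jvp(jax.grad(loss), (params,), (v,))[1]

# Assumed ordering: (matrix_vector_product, dim, order, rng_key).
tridiag, lanczos_vecs = lanczos_alg(hvp, DIM, 90, jax.random.PRNGKey(0))
```

The resulting `tridiag` can then be passed to `tridiag_to_density` exactly as in the TensorFlow example above.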
Key differences between the two implementations:

- The TensorFlow version performs Hessian-vector product accumulation and the actual Lanczos algorithm in float64, whereas the Jax version performs all calculations in float32 (see the sketch after this list for opting into float64 in Jax).
- The TensorFlow version targets multi-worker distributed setups, whereas the Jax version targets single-worker (potentially multi-GPU) setups.
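If float32 precision is a concern, Jax itself supports a global 64-bit mode via its standard `jax_enable_x64` flag; whether that is appropriate for this code is untested here, so treat it as an option to experiment with rather than a documented configuration of this repository:

```python
import jax

# Enable 64-bit floats globally; set this before any arrays are created.
jax.config.update('jax_enable_x64', True)

import jax.numpy as jnp
print(jnp.zeros(1).dtype)  # float64
```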
This is not an official Google product.