madarax64/spectral-density

Large Scale Spectral Density Estimation for Deep Neural Networks

This repository contains two implementations of the stochastic Lanczos quadrature (SLQ) algorithm for estimating the Hessian eigenvalue density of deep neural networks, as used and described in Ghorbani, Krishnan and Xiao, An Investigation into Neural Net Optimization via Hessian Eigenvalue Density (ICML 2019).
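The core of SLQ is the Lanczos iteration, which touches the matrix only through matrix-vector products (for a network, Hessian-vector products) and produces a small tridiagonal matrix plus a set of orthonormal Lanczos vectors. The NumPy sketch below illustrates that iteration under our own assumptions: the function name `lanczos`, its signature, and the choice of full reorthogonalization are illustrative and not the repository's API, and a dense symmetric matrix stands in for the Hessian.

```python
import numpy as np


def lanczos(matvec, dim, order, seed=0):
    """Run `order` Lanczos steps on a symmetric operator given only by `matvec`.

    Returns the (order x order) tridiagonal matrix T and an (order x dim)
    matrix V whose rows are the orthonormal Lanczos vectors.
    Breakdown (beta == 0, when the Krylov subspace becomes invariant)
    is not handled in this sketch.
    """
    rng = np.random.default_rng(seed)
    v = rng.standard_normal(dim)
    v /= np.linalg.norm(v)  # random unit starting vector
    V = np.zeros((order, dim))
    alphas = np.zeros(order)
    betas = np.zeros(order - 1)
    V[0] = v
    for i in range(order):
        # Copy so that in-place updates never alias the caller's data.
        w = np.array(matvec(V[i]), dtype=np.float64)  # e.g. a Hessian-vector product
        alphas[i] = w @ V[i]
        # Full reorthogonalization against every Lanczos vector so far
        # (two Gram-Schmidt passes). This subsumes the usual three-term
        # recurrence w -= alpha_i v_i + beta_{i-1} v_{i-1}.
        for _ in range(2):
            w -= V[: i + 1].T @ (V[: i + 1] @ w)
        if i + 1 < order:
            betas[i] = np.linalg.norm(w)
            V[i + 1] = w / betas[i]
    T = np.diag(alphas) + np.diag(betas, 1) + np.diag(betas, -1)
    return T, V
```

For example, `T, V = lanczos(lambda x: A @ x, dim=A.shape[0], order=30)` runs 30 steps on a symmetric `A`. Full reorthogonalization costs extra work per step but keeps the Lanczos vectors orthogonal, which the plain three-term recurrence loses in finite precision.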

To run the example notebooks, first install tensorflow_datasets with pip (pip install tensorflow_datasets).

TensorFlow Implementation

The main class, which runs the distributed Lanczos algorithm, is LanczosExperiment. The Jupyter notebook demonstrates how to use this class.

In addition to single-machine setups (potentially with multiple GPUs), this implementation is also suitable for multi-worker, multi-GPU setups. The crucial step is to manually partition the input data across the available GPUs.
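Partitioning the data is enough to distribute the work because, when the loss is an average over examples, the full-batch Hessian-vector product is the sum of per-shard contributions divided by the dataset size. The NumPy sketch below illustrates this under stated assumptions: a least-squares loss L(w) = ||Xw - y||^2 / (2n), whose Hessian is X^T X / n, stands in for a neural network loss; `shard_hvp` and `distributed_hvp` are hypothetical names; and a Python loop stands in for the cross-device reduction.

```python
import numpy as np


def shard_hvp(X_shard, v):
    """Unnormalized Hessian-vector product contribution of one data shard
    for the least-squares loss L(w) = ||X w - y||^2 / (2 n)."""
    return X_shard.T @ (X_shard @ v)


def distributed_hvp(X, v, num_workers):
    """Split the rows of X across `num_workers` simulated devices, compute each
    shard's contribution, and sum them (an all-reduce in a real multi-GPU
    setup). The running sum is kept in float64."""
    n = X.shape[0]
    shards = np.array_split(X, num_workers, axis=0)  # manual data partition
    total = np.zeros(X.shape[1], dtype=np.float64)
    for X_shard in shards:
        total += shard_hvp(X_shard, v)
    return total / n  # normalize by the full dataset size, not the shard size
```

For example, `distributed_hvp(X, v, num_workers=4)` matches the single-device result `X.T @ X @ v / n` up to rounding, even when the shards have unequal sizes.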

The algorithm outputs two NumPy files, tridiag_1 and lanczos_vec_1, which contain the tridiagonal matrix and the Lanczos vectors, respectively. The tridiagonal matrix can then be passed to tridiag_to_density to generate spectral densities.
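Converting a tridiagonal matrix into a density is a Gaussian quadrature step: the eigenvalues of T act as quadrature nodes, the squared first components of its eigenvectors act as weights (they sum to 1), and each node is smoothed with a narrow Gaussian. The NumPy sketch below illustrates that step under our own assumptions; it is not the repository's tridiag_to_density, and the names `density_from_tridiag` and `average_density`, the default kernel variance, and the evaluation grid are hypothetical.

```python
import numpy as np


def density_from_tridiag(T, grid, sigma_squared=1e-2):
    """Spectral density estimate on `grid` from one Lanczos tridiagonal T.

    Nodes are the eigenvalues of T; weights are the squared first components
    of its eigenvectors. Each node is smeared by a Gaussian kernel with
    variance `sigma_squared`.
    """
    nodes, evecs = np.linalg.eigh(T)
    weights = evecs[0, :] ** 2
    diffs = grid[:, None] - nodes[None, :]
    kernels = np.exp(-diffs**2 / (2.0 * sigma_squared))
    kernels /= np.sqrt(2.0 * np.pi * sigma_squared)
    return kernels @ weights


def average_density(tridiags, grid, sigma_squared=1e-2):
    """Average the densities from several runs (one per random start vector)."""
    return np.mean([density_from_tridiag(T, grid, sigma_squared) for T in tridiags], axis=0)
```

Averaging over several tridiagonal matrices, each from an independent random starting vector, is the stochastic part of SLQ; more runs reduce the variance of the estimate. The kernel variance trades resolution against smoothness and should be small relative to the spread of the eigenvalues.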

Jax Implementation (by Justin Gilmer)

The Jax version is fantastic for fast experimentation (especially in conjunction with trax). The Jupyter notebook demonstrates how to run Lanczos in Jax.

The main function is lanczos_alg, which returns a tridiagonal matrix and Lanczos vectors. The tridiagonal matrix can then be used to generate spectral densities using tridiag_to_density.

Differences between implementations

  1. The TensorFlow version performs Hessian-vector product accumulation and the Lanczos iteration itself in float64, whereas the Jax version performs all calculations in float32.
  2. The TensorFlow version targets multi-worker distributed setups, whereas the Jax version targets single worker (potentially multi-GPU) setups.
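To illustrate why the accumulation precision in item 1 can matter, the sketch below sums many float32 per-batch contributions in a float32 accumulator and in a float64 accumulator, then compares each result with an exactly rounded reference from `math.fsum`. The random numbers standing in for per-batch Hessian-vector products and the helper name `accumulate` are assumptions; this shows accumulator rounding error only and does not reproduce either implementation.

```python
import math

import numpy as np


def accumulate(contributions, dtype):
    """Sequentially sum per-batch contributions (rows) in an accumulator of `dtype`."""
    total = np.zeros(contributions.shape[1], dtype=dtype)
    for c in contributions:
        total += c.astype(dtype)  # in-place add keeps the accumulator's dtype
    return total


def accumulation_errors(num_batches=10000, dim=3, seed=0):
    """Max absolute error of float32 vs float64 accumulation against math.fsum."""
    rng = np.random.default_rng(seed)
    # Stand-ins for per-batch Hessian-vector products computed in float32.
    contributions = rng.standard_normal((num_batches, dim)).astype(np.float32)
    # float32 -> Python float is exact, and fsum returns the correctly rounded sum.
    exact = np.array([math.fsum(col.tolist()) for col in contributions.T])
    err32 = np.max(np.abs(accumulate(contributions, np.float32) - exact))
    err64 = np.max(np.abs(accumulate(contributions, np.float64) - exact))
    return err32, err64
```

Calling `accumulation_errors()` returns the two errors; the float64 accumulator's error is typically many orders of magnitude smaller, because each addition rounds to a 53-bit rather than a 24-bit significand.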

This is not an official Google product.

About

Hessian spectral density estimation in TF and Jax