freedmand/interpogate

Interpogate

A visual tool to interpret and understand PyTorch machine learning models.

  • Displays a graph of the model architecture in a WebGL-powered frontend
  • Runs the visualization inside IPython/Jupyter notebooks
  • Provides an intuitive mechanism for attaching hooks to models
  • Can run standalone to visualize model runs in real time via gRPC streaming (beta)
    • Supports adding visualization blocks

See https://twitter.com/dylfreed/status/1756543423216030107 for more context and demo videos.

Demos

Running in IPython/Jupyter notebooks

Follow the steps below:

Setup

# Install interpogate in your Python environment
python3 -m pip install interpogate

Loading a compatible model/tokenizer

# Load a Hugging Face Transformers model and tokenizer.
# This example uses GPT-2, but you can change model_name
# to run other Hugging Face Transformers models.
import transformers

model_name = 'openai-community/gpt2'
pipe = transformers.pipeline("text-generation", model=model_name)
model = pipe.model
tokenizer = pipe.tokenizer

Running interpogate

The following commands run a forward pass on the specified text and then display the interpogate visualization in an interactive iframe inside a Jupyter or Colab notebook:

# Import interpogate
from interpogate import Interpogate

# Create an instance of interpogate and run a forward pass
interp = Interpogate(model, tokenizer)
interp.forward_text("Hello there, how are you?")

# Visualize the forward pass
interp.visualize()

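You can also run interpogate on models that don't take text input. Omit the tokenizer and pass the model's inputs to forward directly:

# `inputs` is a dict of the keyword arguments the model's forward() accepts
interp = Interpogate(model)
interp.forward(**inputs)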
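
Conceptually, forward_text is a convenience over forward: the tokenizer turns a string into keyword arguments, which are then passed to the model unchanged. The dependency-free sketch below illustrates that relationship under stated assumptions; ToyTokenizer, ToyInterp, and toy_model are hypothetical stand-ins, not interpogate's actual classes, and a real Hugging Face tokenizer returns tensors rather than plain lists.

```python
class ToyTokenizer:
    """Hypothetical tokenizer: maps whitespace-separated words to integer ids."""

    def __init__(self):
        self.vocab = {}

    def __call__(self, text):
        ids = [self.vocab.setdefault(word, len(self.vocab)) for word in text.split()]
        # Mirror the keyword names a Hugging Face tokenizer produces.
        return {"input_ids": ids, "attention_mask": [1] * len(ids)}


class ToyInterp:
    """Hypothetical wrapper showing how forward_text can delegate to forward."""

    def __init__(self, model, tokenizer=None):
        self.model = model
        self.tokenizer = tokenizer

    def forward(self, **inputs):
        # The same keyword arguments the underlying model would receive directly.
        return self.model(**inputs)

    def forward_text(self, text):
        if self.tokenizer is None:
            raise ValueError("forward_text requires a tokenizer")
        return self.forward(**self.tokenizer(text))


def toy_model(input_ids, attention_mask):
    # Stand-in "model": returns the number of attended tokens.
    return sum(attention_mask)


interp = ToyInterp(toy_model, ToyTokenizer())
print(interp.forward_text("Hello there, how are you?"))  # 5
print(ToyInterp(toy_model).forward(input_ids=[7, 8], attention_mask=[1, 1]))  # 2
```

The design keeps forward as the single code path, so anything recorded during a run (such as node input/output shapes) is collected the same way whether you start from text or from raw inputs.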

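Interpogate assigns a path to each model node, and you can explore these paths in the interactive visualization. It also provides a convenient API for registering disposable hooks, scoped to a with block, and for modifying model behavior. Note that the paths below (model.layers.{n}, model.norm, 22 layers, hidden size 2048) come from a Llama-style architecture, not from the GPT-2 model loaded above; for GPT-2 the equivalent paths are transformer.h.{n} and transformer.ln_f.

lm_head = interp.node('lm_head')
norm = interp.node('model.norm')
layer_logits = []  # collects one logits tensor per hooked layer

with interp.hook() as hook:
    def post_hook(model, input, output):
        # output is a tuple whose first element is the hidden states: [<1×N×2048>, ...]
        # Apply the final norm and the LM head to unembed the hidden states into logits
        logits = lm_head(norm(output[0]))[0]  # shape [N, vocab_size]
        layer_logits.append(logits)

    # Register a post-forward hook on every decoder layer (22 for the model used here)
    for n in range(model.config.num_hidden_layers):
        hook.post(f"model.layers.{n}", post_hook)

    # Run the forward pass; post_hook fires once per layer
    interp.forward_text("The most fascinating thing is the")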
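
After the forward pass, layer_logits holds one [N × vocab] logits tensor per hooked layer. A common next step in this kind of "logit lens" analysis is to compare the top-scoring token at the last position across layers. The sketch below shows that bookkeeping with plain Python lists so it runs without PyTorch; with real tensors the equivalent is roughly logits[-1].argmax().item() followed by tokenizer.decode(...). The helper names and the toy numbers are illustrative only.

```python
def argmax(values):
    """Index of the largest value (the first one wins on ties)."""
    return max(range(len(values)), key=values.__getitem__)


def top_token_per_layer(layer_logits):
    """For each layer's [N x vocab] logits, return the top token id at the last position."""
    return [argmax(logits[-1]) for logits in layer_logits]


# Toy data: 3 layers, 2 positions, vocabulary of 4 tokens.
layer_logits = [
    [[0.1, 0.2, 0.3, 0.4], [0.9, 0.1, 0.0, 0.0]],
    [[0.0, 0.0, 1.0, 0.0], [0.2, 0.7, 0.1, 0.0]],
    [[0.5, 0.5, 0.0, 0.0], [0.0, 0.1, 0.2, 0.8]],
]
print(top_token_per_layer(layer_logits))  # [0, 1, 3]
```

Reading the list from first to last layer shows how the model's next-token prediction changes with depth.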

More examples can be viewed in the examples/ directory.

API

  • Interpogate(model, [tokenizer])

    Creates a new instance of interpogate attached to a PyTorch model and an optional tokenizer (required only for text-based forward methods such as forward_text).

  • interp.forward(**inputs)

    Runs a forward pass of the model with the specified inputs (same inputs that would be passed to the torch model directly). Records information about the shapes of each model node's input/output.

  • interp.forward_text(text)

    Requires the interpogate instance to have been initialized with a tokenizer. Runs a forward pass on the specified string of text, using the tokenizer to derive model inputs. Records information about the shapes of each model node's input/output.

  • interp.visualize()

    Renders an iframe containing an interactive visualization of the model architecture. You can use the magic wand tool to display information about model nodes.

  • interp.node(path)

    Returns the model node at the specified path string. To find a node's path, call interp.visualize() and select the node: the path appears in the displayed table, and clicking it shows hook callback code.

  • with interp.hook() as hook: ...

    Returns a hook object for attaching pre- and post-forward-pass hooks to the model within the context block.

    • hook.pre(path, callback_fn)

      Registers a pre-forward-pass hook on the model node at the specified path; the callback function has the form def pre_hook(model, input).

    • hook.post(path, callback_fn)

      Registers a post-forward-pass hook on the model node at the specified path; the callback function has the form def post_hook(model, input, output).
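
To make the path and hook semantics above concrete, here is a dependency-free sketch of one way such an API can work: a dotted path such as model.layers.1 is resolved by walking attributes and list indices (similar in spirit to PyTorch's torch.nn.Module.get_submodule), and a context manager removes every hook it registered when the with block exits. ToyModule, node, sequential, and HookContext are hypothetical names; this is not interpogate's implementation. Real PyTorch hooks are registered with register_forward_pre_hook / register_forward_hook, which return handles with a remove() method.

```python
class ToyModule:
    """Hypothetical module: children are attributes; layers may be stored in lists."""

    def __init__(self, name, fn=lambda x: x, **children):
        self.name = name
        self.fn = fn
        self.pre_hooks = []
        self.post_hooks = []
        for key, child in children.items():
            setattr(self, key, child)

    def __call__(self, x):
        for hook in self.pre_hooks:
            result = hook(self, x)
            if result is not None:  # as with PyTorch hooks, a returned value replaces the input
                x = result
        out = self.fn(x)
        for hook in self.post_hooks:
            result = hook(self, x, out)
            if result is not None:  # a returned value replaces the output
                out = result
        return out


def node(root, path):
    """Resolve a dotted path such as 'model.layers.1'; numeric parts index into lists."""
    current = root
    for part in path.split("."):
        current = current[int(part)] if part.isdigit() else getattr(current, part)
    return current


class HookContext:
    """Collects hooks registered inside a `with` block and removes them on exit."""

    def __init__(self, root):
        self.root = root
        self.registered = []  # (hook list, callback) pairs to undo on exit

    def pre(self, path, callback_fn):
        self._add(node(self.root, path).pre_hooks, callback_fn)

    def post(self, path, callback_fn):
        self._add(node(self.root, path).post_hooks, callback_fn)

    def _add(self, hooks, callback_fn):
        hooks.append(callback_fn)
        self.registered.append((hooks, callback_fn))

    def __enter__(self):
        return self

    def __exit__(self, *exc):
        for hooks, callback_fn in self.registered:
            hooks.remove(callback_fn)
        self.registered.clear()
        return False  # do not swallow exceptions


def sequential(layers):
    """Build a forward function that runs the given modules in order."""
    def fn(x):
        for layer in layers:
            x = layer(x)
        return x
    return fn


# Toy model: root -> model -> layers[0..2], each layer adds 1.
layers = [ToyModule(f"layer{i}", fn=lambda x: x + 1) for i in range(3)]
inner = ToyModule("model", fn=sequential(layers), layers=layers)
root = ToyModule("root", fn=sequential([inner]), model=inner)

seen = []
with HookContext(root) as hook:
    # Add 10 to layer 1's output, then record every layer's (possibly modified) output.
    hook.post("model.layers.1", lambda module, inp, out: out + 10)
    for n in range(3):
        hook.post(f"model.layers.{n}", lambda module, inp, out: seen.append(out))
    print(root(0))  # 13
print(seen)  # [1, 12, 13]
print(root(0))  # 3: the hooks were removed when the block exited
```

Because hooks on the same node run in registration order, the recording hook on layer 1 sees the output after the +10 modification.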

Running standalone version of interpogate

See the standalone documentation.

TODO

  • General code clean-up
  • Documentation/comments
  • Configurable ports / settings
  • More efficient frontend data structure design
  • Fully functioning Docker stack
  • Python package
  • Jupyter/IPython notebook support
  • API for adding hooks/visuals on backend
  • Rethink visualization block design
  • Colab demos
  • Support non-text generating models
