Exploring Feature Co-Occurrence Networks with SAEs

Owen Parsons - September 29, 2024


Introduction

Mechanistic Interpretability and Neuroscience

Over the past few months, I’ve been doing a fairly deep dive into AI safety and alignment—a topic that’s become harder to ignore as machine learning continues to advance. During the time I’ve spent working as a machine learning scientist, I’ve come to realise more and more that, while building smarter and more capable models is exciting, understanding the risks and ensuring these models behave as expected is just as crucial. This led me to enroll in the BlueDot AI Alignment course, which offered a really good overview of the key challenges we face in making AI systems aligned with human values.

One topic that really hooked me during the course was mechanistic interpretability—basically, trying to reverse-engineer neural networks to figure out how they’re thinking (or at least processing information). I found it to be quite an accessible area, with a number of beautifully written and well-presented studies that cover the key aspects of mechanistic interpretability. In particular, I found Chris Olah and others’ article, Zoom In, really engaging. The piece presents an approach to mechanistic interpretability by breaking down individual neurons and circuits within neural networks to uncover the roles they play. By meticulously zooming in, Olah and his team demonstrate how different components contribute to the broader functioning of large models. Coming from a neuroscience background, this approach seemed oddly familiar.

Neuroscience has a long history of “zooming in” to understand how individual components contribute to the whole system. Edgar Adrian’s pioneering single-unit recordings in the 1920s demonstrated how single neurons in sensory systems encode information by varying their firing rates. In the 1950s, Hodgkin and Huxley mapped the electrical behavior of individual neurons in their famous squid axon experiments, revealing the ionic mechanisms behind action potentials. David Hubel and Torsten Wiesel’s work in the 1960s uncovered how neurons in the visual cortex respond to specific features, such as edges and movement, helping us understand how sensory information is processed in the brain. And, of course, John O’Keefe’s discovery of place cells in the hippocampus showed how certain neurons represent spatial information, laying the groundwork for our understanding of memory and navigation. (I’ve had the pleasure of chatting to John O’Keefe over a beer, and he’s one of the nicest researchers I’ve met!)

However, in recent decades, neuroscience has also begun to "zoom out," recognising the importance of understanding the brain's larger-scale organization and function. For instance, the Human Connectome Project has mapped the brain's structural and functional connections at a macro level. Similarly, the emergence of network neuroscience has shifted focus towards understanding how different brain regions interact as part of larger systems. Computational frameworks, such as Karl Friston's free energy principle, attempt to provide overarching accounts of brain function. These approaches complement the granular view, offering a more holistic understanding of neural processes.

As much as I loved the Zoom In paper, a persistent thought kept bubbling up in the back of my mind while I was reading it: do we really want to continue zooming in? While it’s clear that the insights gained from this granular approach are invaluable, I think that, for mechanistic interpretability to have a big impact, research will need to zoom back out a little. After all, understanding the bigger picture can be just as crucial as dissecting the details. Both AI systems and the brain can reveal new insights when we step back and look at how their components fit together; diving deep into the specifics is valuable, but so is appreciating the wider structure.

Research Focus

These thoughts formed the foundation of a short research project that I'll present here. While I'll provide more specific details later in this post, the general aim of this project is to explore "zooming out" approaches in mechanistic interpretability. Specifically, I'll focus on investigating how we can gain a broader picture of feature interactions in language models by combining current interpretability techniques with network analysis.

While this research serves as my final project for the AI Alignment course, I also had some personal motivations beyond just completing coursework:

  1. I wanted to up-skill in this area, learning to use relevant libraries and gaining a deeper understanding of how to work with SAEs.
  2. I hoped to explore an interesting question within mechanistic interpretability that allowed me to leverage my background in neuroscience.
  3. I was curious to see if (somewhat) meaningful exploratory mechanistic interpretability work could be done on a small budget (I ended up spending under $10 on T4 GPU time!)

I also had a couple of external motivations with this project:

  1. It could provide a useful resource for others starting to look into mechanistic interpretability.
  2. There might be a (small) chance that some of the approaches I explore here could provide some baseline ideas that could be taken forward and scaled up by others.

Before we dive in, I want to emphasise that I’m relatively new to the field of mechanistic interpretability. Everything presented in this post should be taken with a healthy dose of skepticism. I welcome feedback and constructive criticism—if you spot any incorrect assumptions, misunderstandings, or anything that doesn’t quite add up, please don’t hesitate to let me know.

I want to share my thoughts openly, but please don’t mistake that for overconfidence. When I say something like “for mechanistic interpretability to really make an impact, we might need to zoom back out,” I recognise that I could be way off - I’ve just dipped my toes in the water compared to the researchers in this field. But if I can’t share my opinions on a blog post on the internet, where else can I share them? So, just a heads-up: take everything in this article with a grain of salt; I'm very much open to the idea that I might have this all wrong!

TL;DR

The rest of this post has a reading time of ~30 minutes. I’d also like it to be accessible to readers with less time, so here are some suggestions for how to skip through it depending on your background and what you’re interested in.

Strong Mechanistic Interpretability background, just want to get the gist:

New to Mechanistic Interpretability, want to learn something but are more interested in methods than results:

Just curious about Network Analysis:

New to Mechanistic Interpretability, want to learn something and want to find out about the results:

Experiments and Findings

Project Overview and Key Questions

This project delves into the realm of mechanistic interpretability in machine learning, with a specific focus on analysing relationships between SAE features. While the full context and methodology will unfold throughout this write-up, it's crucial to outline the empirical focus and guiding questions early on.

The exploratory nature of this work led me to interweave methods and results, departing from the more traditional compartmentalised structure. This approach allows me to present not just the outcomes, but also the evolving thought processes and insights that shaped the investigation.

Three key questions guided this exploration:

  1. Can we uncover structure in SAE features by examining co-occurrence patterns, specifically through correlations between feature attribution values?
  2. How can we leverage feature steering techniques to gain insights into the dependencies between co-occurring features, or alternatively, to understand how they might contribute independently in similar ways?
  3. To what extent can we apply network analysis approaches to assess feature importance and illuminate the relationships between features? (A rough code sketch of this kind of analysis follows this list.)
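To make the first and third questions a little more concrete, here is a minimal sketch of the kind of analysis they point towards: correlate feature attribution values across prompts, keep strong correlations as edges, and apply standard network measures. The random data and the 0.5 threshold are placeholders for illustration, not values used in this project.

```python
import numpy as np
import networkx as nx

# Hypothetical matrix of per-prompt attribution values for each SAE feature.
n_prompts, n_features = 100, 50
attributions = np.random.rand(n_prompts, n_features)

# Correlate features across prompts to get a co-occurrence measure.
corr = np.corrcoef(attributions.T)  # shape: (n_features, n_features)

# Keep only strong correlations as edges in a feature co-occurrence graph.
threshold = 0.5
graph = nx.Graph()
graph.add_nodes_from(range(n_features))
for i in range(n_features):
    for j in range(i + 1, n_features):
        if corr[i, j] > threshold:
            graph.add_edge(i, j, weight=corr[i, j])

# Standard network measures can then act as proxies for feature importance.
centrality = nx.degree_centrality(graph)
```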

Hypotheses

Building on these questions, we propose three hypotheses that form the backbone of our empirical investigation:

As the analysis progresses, I'll explore how these hypotheses hold up against the empirical findings, providing insights into the interactions of language model features and the utility of the experimental techniques applied.

Tools

This project will focus on analysing language models using SAEs, a technique that has recently gained popularity in the field of mechanistic interpretability. SAEs offer a promising approach to uncovering the internal workings of these complex, hard-to-interpret systems. For a good introduction to SAEs, I recommend reading this article and having a look at the Transformer Circuits Thread in general. It’s also worth looking into polysemanticity and superposition, if you’re not already familiar with these, to understand the problem that SAEs are trying to solve.

Very briefly, SAEs are neural networks trained to reconstruct their input while enforcing sparsity in their hidden layer. This sparsity constraint often results in the network learning more interpretable and disentangled representations of the input data. In the context of language models, SAEs can help us identify and understand specific features or concepts that the model has learned.
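As a rough illustration (and not the architecture of the specific SAEs used later in this post), a toy SAE can be written as a single hidden layer trained on reconstruction error plus an L1 penalty on the hidden activations:

```python
import torch
import torch.nn as nn

class SparseAutoencoder(nn.Module):
    """Toy SAE: reconstruct the input while keeping hidden activations sparse."""

    def __init__(self, d_model: int, d_hidden: int):
        super().__init__()
        self.encoder = nn.Linear(d_model, d_hidden)
        self.decoder = nn.Linear(d_hidden, d_model)

    def forward(self, x: torch.Tensor):
        features = torch.relu(self.encoder(x))  # sparse feature activations
        reconstruction = self.decoder(features)
        return reconstruction, features

def sae_loss(x, reconstruction, features, l1_coeff=1e-3):
    # Reconstruction error plus an L1 sparsity penalty on the features.
    mse = ((reconstruction - x) ** 2).mean()
    sparsity = features.abs().mean()
    return mse + l1_coeff * sparsity
```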

I used a few different tools from the Open Source Mechanistic Interpretability community:

  1. SAE Lens: A toolkit specifically designed for working with SAEs.
  2. TransformerLens: A library for the analysis of transformer-based language models.
  3. Neuronpedia: A valuable resource for understanding and categorising SAE features.

It’s worth noting that some of the functionality in my code was inspired by and adapted from one of the SAE Lens tutorials. This tutorial provided me with a good foundation for working with SAEs and I highly recommend it!

Models

For this project, I chose to work with GPT-2 Small as the language model. The specific SAE model I used was Joseph Bloom’s gpt2-small-res-jb. This is an open-source SAE that covers all residual stream layers of GPT-2 Small, but I focused my analysis on layer 7 (blocks.7.hook_resid_pre). For more information on these SAEs, you can check out this post on LessWrong.
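For reference, this is roughly how the model and SAE can be loaded together with TransformerLens and SAE Lens. Exact function names and return values vary between library versions, so treat this as a sketch rather than a drop-in snippet.

```python
import torch
from transformer_lens import HookedTransformer
from sae_lens import SAE

device = "cuda" if torch.cuda.is_available() else "cpu"

# GPT-2 Small via TransformerLens.
model = HookedTransformer.from_pretrained("gpt2", device=device)

# Joseph Bloom's residual-stream SAEs for GPT-2 Small, layer 7.
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release="gpt2-small-res-jb",
    sae_id="blocks.7.hook_resid_pre",
    device=device,
)

# Run a prompt, grab the residual stream at the SAE's hook point, and encode
# it into SAE feature activations.
prompt = "Fact: Michael Jordan plays the sport of"
_, cache = model.run_with_cache(prompt)
feature_acts = sae.encode(cache["blocks.7.hook_resid_pre"])
```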

Task

I wanted a fairly simple task to assess model performance across a set of related prompts. You could definitely explore feature co-occurrence across a large corpus of very general text, which could potentially give you a lot more information about feature relationships in the network as a whole, but the co-occurrence data would also be much sparser. By confining the analysis to a task with relatively little variance, I can focus on a smaller set of features.

The task was based on the example presented in work by Neel Nanda and others, where a one-shot prompt of the form “Fact: Michael Jordan plays the sport of” is given to the model and the model’s performance is evaluated on predicting the ‘correct’ token (in this case, “basketball”). The dataset I used was deliberately small, allowing me to generate experiment outputs fairly quickly and cheaply. It was generated using a list of the highest-paid athletes and the respective sports that they play.
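As a sketch, the dataset construction looks something like the following; the athletes listed here are illustrative placeholders rather than the exact entries in the dataset I used.

```python
# Illustrative construction of the prompt dataset; the real list was built
# from the highest-paid athletes and their respective sports.
athletes = {
    "Michael Jordan": "basketball",
    "Lionel Messi": "football",
    "Roger Federer": "tennis",
}

prompts = [f"Fact: {name} plays the sport of" for name in athletes]
# Leading space so the answer matches how the model tokenises the next word.
answers = [f" {sport}" for sport in athletes.values()]
```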

Task Performance

The first thing I wanted to do was assess how the model performed on the task. There are several different metrics that can be used to evaluate performance, including: