Sparse Autoencoders for Gemma-3-27b-it

This repository contains 9 Sparse Autoencoders (SAEs) trained on google/gemma-3-27b-it using the BatchTopK architecture.

Architecture: BatchTopK SAE

These SAEs use the BatchTopK architecture, which enforces sparsity by:

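  1. Computing feature activations: z = ReLU(W_enc(x - b_dec) + b_enc) (encoder)
  2. Keeping the k × B largest activations across the whole batch of B inputs (not k per sample), so an individual input may use more or fewer than k features while the batch average is k
  3. Reconstructing: x̂ = W_dec z_topk + b_dec (decoder)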

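Because k directly sets the average number of active features, there is no L1 sparsity penalty to tune. This approach tends to produce more interpretable features than L1-penalized ReLU SAEs and has better training dynamics. Batch-level selection is used only during training; at inference, a feature is kept if its activation exceeds a learned threshold, so each input can be encoded on its own.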
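
The selection step can be sketched in PyTorch as below. This is a minimal illustration of batch top-k, not the training code: `batch_topk` is a hypothetical helper, and applying a ReLU before selection is an assumption based on the dictionary_learning BatchTopKSAE. Instead of keeping k features per row, it keeps the k × B largest values over the pooled batch.

```python
import torch


def batch_topk(pre_acts: torch.Tensor, k: int) -> torch.Tensor:
    """Keep the k * batch_size largest activations across the whole batch.

    pre_acts: [batch, dict_size] encoder outputs before sparsification.
    Returns a tensor of the same shape with all other entries zeroed.
    """
    acts = torch.relu(pre_acts)                    # assumption: ReLU before selection
    batch_size = acts.shape[0]
    flat = acts.flatten()                          # pool every (row, feature) pair
    top = flat.topk(k * batch_size, sorted=False)  # one top-k over the pooled values
    sparse = torch.zeros_like(flat).scatter_(0, top.indices, top.values)
    return sparse.reshape(acts.shape)


# Toy example: 2 rows, 4 features, k = 2 -> 4 activations kept in total
pre = torch.tensor([[5.0, 4.0, 3.0, 0.1],
                    [0.2, 0.3, 6.0, -1.0]])
out = batch_topk(pre, k=2)
print((out != 0).sum(dim=1))  # row 0 keeps 3 features, row 1 keeps 1
```

Unlike per-sample top-k, the budget is shared: the first row uses three slots and the second only one, but the batch still averages k = 2.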

Repository Structure

layer_45/
  dict_16k_k80/       # 16,384 features, k=80
    ae.pt             # SAE weights
    config.json       # Training configuration
    feature_labels.json  # Natural language feature descriptions
  dict_16k_k160/      # 16,384 features, k=160
  dict_65k_k80/       # 65,536 features, k=80
  dict_65k_k160/      # 65,536 features, k=160
layer_47/
  (same structure)
layer_45_mlp/
  (same structure)

Available SAEs

Layer  Dict size  k    Activation dim  Parameters   Active fraction (k / dict size)
45     16,384     80   5,376           176,182,528  0.49%
45     16,384     160  5,376           176,182,528  0.98%
45     65,536     80   5,376           704,713,984  0.12%
45     65,536     160  5,376           704,713,984  0.24%
47     16,384     80   5,376           176,182,528  0.49%
47     16,384     160  5,376           176,182,528  0.98%
47     65,536     80   5,376           704,713,984  0.12%
47     65,536     160  5,376           704,713,984  0.24%

Total Parameters: 3,523,586,048
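
The parameter counts and sparsity percentages in the table follow from the dictionary size m, the activation dimension d = 5,376, and k. Counting the encoder and decoder matrices (d·m each), the encoder bias (m), and the decoder bias b_dec (d), and leaving out the scalar threshold, they can be checked as follows (`sae_param_count` is a hypothetical helper):

```python
def sae_param_count(d: int, m: int) -> int:
    """Encoder weight + decoder weight (d*m each), encoder bias (m), decoder bias (d)."""
    return 2 * d * m + m + d


d = 5_376
configs = [(16_384, 80), (16_384, 160), (65_536, 80), (65_536, 160)]
for m, k in configs:
    print(f"m={m:,} k={k}: {sae_param_count(d, m):,} params, {k / m:.2%} active")

# Two layers (45 and 47), four configurations each
total = 2 * sum(sae_param_count(d, m) for m, _ in configs)
print(f"{total:,}")  # 3,523,586,048
```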

Model Details

Training Details

Base Model: google/gemma-3-27b-it

Hook Point: residual_stream (post-layer activations)
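
Post-layer activations can also be read directly with a PyTorch forward hook instead of output_hidden_states. The sketch below is generic: `capture_output` is a hypothetical helper, and where layer 45 sits inside the loaded Gemma 3 model differs between transformers versions and between the text-only and multimodal model classes, so inspect print(model) to locate it.

```python
import torch
from torch import nn


def capture_output(module: nn.Module, run_forward) -> torch.Tensor:
    """Run `run_forward()` and return what `module` produced during that call."""
    captured = {}

    def hook(mod, inputs, output):
        # Transformer decoder layers often return a tuple whose first
        # element is the hidden state; plain modules return a tensor.
        captured["acts"] = output[0] if isinstance(output, tuple) else output

    handle = module.register_forward_hook(hook)
    try:
        with torch.no_grad():
            run_forward()
    finally:
        handle.remove()  # always detach the hook, even if the forward pass fails
    return captured["acts"]


# Toy stand-in for a transformer: capture the output of the ReLU block
toy = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))
x = torch.randn(3, 8)
acts = capture_output(toy[1], lambda: toy(x))
print(acts.shape)  # torch.Size([3, 8])
```

With a loaded model, the call would take the form capture_output(layer_module, lambda: model(**inputs)), where layer_module is the decoder layer you located.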

Dataset: FineWeb (HuggingFaceFW/fineweb)

Training Hyperparameters:

  • Optimizer: Adam
  • Learning rate: 5e-5
  • Warmup steps: 1,000
  • Training steps: ~244,140
  • Context length: 2,048 tokens
  • Batch size: 2,048 activations
  • Decay start: 195,312 steps
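
These numbers are consistent with a linear warmup, a constant phase, and a linear decay to zero that starts at 80% of training (0.8 × 244,140 = 195,312). The schedule shape is an assumption here, and `lr_at` is a hypothetical helper:

```python
def lr_at(step: int, base_lr: float = 5e-5, warmup_steps: int = 1_000,
          decay_start: int = 195_312, total_steps: int = 244_140) -> float:
    """Assumed schedule: linear warmup, constant plateau, linear decay to zero."""
    if step < warmup_steps:
        return base_lr * step / warmup_steps
    if step >= decay_start:
        return base_lr * (total_steps - step) / (total_steps - decay_start)
    return base_lr


for s in [0, 500, 1_000, 100_000, 195_312, 219_726, 244_140]:
    print(s, lr_at(s))
```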

BatchTopK Parameters:

  • Auxiliary (AuxK) loss coefficient (α): 0.03125 (= 1/32)
  • Threshold EMA coefficient (β): 0.999
  • Threshold start step: 1,000
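
The threshold is what lets the trained SAE encode single inputs without a batch. A minimal sketch, assuming the threshold is an exponential moving average (coefficient β) of the smallest activation kept by batch top-k, updated from the start step onward and initialised to the first observed value. `ThresholdTracker` is hypothetical, and details such as averaging per-sample minima may differ from the training code.

```python
import torch


class ThresholdTracker:
    """Hypothetical EMA tracker for a BatchTopK inference threshold."""

    def __init__(self, beta: float = 0.999, start_step: int = 1_000):
        self.beta = beta
        self.start_step = start_step
        self.threshold = None  # unset until the first update

    def update(self, step: int, kept_acts: torch.Tensor) -> None:
        """kept_acts: activations after batch top-k (zeros = not selected)."""
        if step < self.start_step:
            return
        positive = kept_acts[kept_acts > 0]
        if positive.numel() == 0:
            return
        batch_min = positive.min().item()  # smallest activation that survived top-k
        if self.threshold is None:
            self.threshold = batch_min
        else:
            self.threshold = self.beta * self.threshold + (1 - self.beta) * batch_min


tracker = ThresholdTracker()
tracker.update(step=500, kept_acts=torch.tensor([[2.0, 0.0], [0.0, 1.0]]))    # ignored: before start step
tracker.update(step=1_000, kept_acts=torch.tensor([[2.0, 0.0], [0.0, 1.0]]))  # initialises to 1.0
tracker.update(step=1_001, kept_acts=torch.tensor([[3.0, 0.0], [0.0, 2.0]]))  # moves slightly toward 2.0
print(tracker.threshold)
```

With β = 0.999 the threshold changes slowly, so it reflects thousands of batches rather than any single one.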

Sparsity Levels:

  • k=80: Higher sparsity; on average 80 features are active per token, so each feature fires more selectively
  • k=160: Lower sparsity; on average 160 features are active per token

Dictionary Sizes:

  • 16,384: Compact, efficient, good for resource-constrained analysis
  • 65,536: Comprehensive, captures more fine-grained patterns

Feature Labels

This repository includes natural language descriptions for all features, generated by an LLM (GPT-4) from each feature's maximum-activating examples. Each feature has:

  • Title: Short description of what the feature detects
  • Description: Detailed explanation with examples
  • Examples: Token sequences that maximally activate the feature

Usage

Installation

pip install torch transformers accelerate huggingface_hub

Loading an SAE

import json
import torch
from huggingface_hub import hf_hub_download

# Download specific SAE
ae_path = hf_hub_download(
    repo_id="uzaymacar/gemma-3-27b-saes",
    filename="layer_45/dict_16k_k80/ae.pt",
    subfolder=None,
)

config_path = hf_hub_download(
    repo_id="uzaymacar/gemma-3-27b-saes",
    filename="layer_45/dict_16k_k80/config.json",
)

# Load SAE
ae_data = torch.load(ae_path, map_location='cpu')
with open(config_path, 'r') as f:
    config = json.load(f)

print(f"Loaded SAE with {config['trainer']['dict_size']} features")
print(f"Activation dimension: {config['trainer']['activation_dim']}")
print(f"Top-k: {config['trainer']['k']}")

# SAE weights
encoder_weight = ae_data['encoder.weight']  # [dict_size, activation_dim]
encoder_bias = ae_data['encoder.bias']      # [dict_size]
decoder_weight = ae_data['decoder.weight']  # [activation_dim, dict_size]
decoder_bias = ae_data['b_dec']             # [activation_dim]
threshold = ae_data['threshold']            # Learned activation threshold (replaces batch top-k at inference)
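
At inference, features can be kept by comparing against the learned threshold instead of running top-k. The sketch below assumes the dictionary_learning BatchTopKSAE convention of subtracting b_dec before the encoder and applying a ReLU; `encode_with_threshold` and `decode` are hypothetical helpers that take the tensors loaded above.

```python
import torch
import torch.nn.functional as F


def encode_with_threshold(x, encoder_weight, encoder_bias, decoder_bias, threshold):
    """x: [n, activation_dim] -> sparse codes [n, dict_size]."""
    pre = F.relu(F.linear(x - decoder_bias, encoder_weight, encoder_bias))
    return pre * (pre > threshold)  # zero every feature at or below the threshold


def decode(z, decoder_weight, decoder_bias):
    """decoder_weight is [activation_dim, dict_size]; F.linear multiplies by its transpose."""
    return F.linear(z, decoder_weight, decoder_bias)


# Shape check with random stand-ins (activation_dim=6, dict_size=10)
torch.manual_seed(0)
W_enc, b_enc = torch.randn(10, 6), torch.zeros(10)
W_dec, b_dec = torch.randn(6, 10), torch.zeros(6)
x = torch.randn(4, 6)
z = encode_with_threshold(x, W_enc, b_enc, b_dec, threshold=torch.tensor(0.5))
x_hat = decode(z, W_dec, b_dec)
print(z.shape, x_hat.shape)  # torch.Size([4, 10]) torch.Size([4, 6])
```

With the real weights, this would be encode_with_threshold(x, encoder_weight, encoder_bias, decoder_bias, threshold), with x cast to float32 and on the same device as the weights.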

Using the SAE for Analysis

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import torch.nn.functional as F

# Load base model
model_name = "google/gemma-3-27b-it"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map='auto'
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

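# Get residual-stream activations from layer 45
# hidden_states[0] is the embedding output, so hidden_states[i] is the output of
# decoder layer i-1. Check the training configuration for the exact hook point;
# if the SAE was trained on the output of layers[45], use index 46 instead.
text = "The capital of France is Paris"
inputs = tokenizer(text, return_tensors="pt").to(model.device)

with torch.no_grad():
    outputs = model(**inputs, output_hidden_states=True)
    layer_45_acts = outputs.hidden_states[45]  # [batch, seq, activation_dim]

# Flatten tokens and cast to float32 to match the SAE weights
acts_flat = layer_45_acts.reshape(-1, layer_45_acts.shape[-1]).float()  # [batch*seq, dim]

# Move the SAE weights to the same device as the activations
device = acts_flat.device
encoder_weight, encoder_bias = encoder_weight.to(device), encoder_bias.to(device)
decoder_weight, decoder_bias = decoder_weight.to(device), decoder_bias.to(device)

# Encoder: z = ReLU(W_enc(x - b_dec) + b_enc), as in dictionary_learning's BatchTopKSAE
z = F.relu(F.linear(acts_flat - decoder_bias, encoder_weight, encoder_bias))  # [batch*seq, dict_size]

# Top-k selection per token. This approximates the trained SAE: training uses
# batch top-k, and inference in the original framework keeps features above `threshold`.
top_k = config['trainer']['k']
top_values, top_indices = torch.topk(z, k=top_k, dim=-1)

# Create sparse representation
z_topk = torch.zeros_like(z)
z_topk.scatter_(-1, top_indices, top_values)

# Decode: x̂ = W_dec z + b_dec. decoder_weight is [activation_dim, dict_size] and
# F.linear multiplies by the transpose of its weight, so no .t() is needed.
reconstructed = F.linear(z_topk, decoder_weight, decoder_bias)

# Compute reconstruction loss
mse_loss = F.mse_loss(reconstructed, acts_flat)
print(f"Reconstruction MSE: {mse_loss.item():.6f}")

# Find active features
active_features = top_indices[0]  # First token's k active feature indices
print(f"Active features for first token: {active_features.tolist()}")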
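
Raw MSE depends on the scale of the activations, which makes it hard to compare across layers or models. A scale-free alternative commonly used for SAEs is the fraction of variance unexplained (FVU). The helper below is a sketch, not a metric reported for these SAEs; with the variables above it would be called as fraction_variance_unexplained(acts_flat, reconstructed).

```python
import torch


def fraction_variance_unexplained(x: torch.Tensor, x_hat: torch.Tensor) -> float:
    """FVU = ||x - x_hat||^2 / ||x - mean(x)||^2, summed over tokens and dimensions."""
    residual = (x - x_hat).pow(2).sum()
    total = (x - x.mean(dim=0, keepdim=True)).pow(2).sum()
    return (residual / total).item()


x = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 0.0]])
print(fraction_variance_unexplained(x, x))  # 0.0 for a perfect reconstruction
print(fraction_variance_unexplained(x, x.mean(dim=0, keepdim=True).expand_as(x)))  # 1.0 for predicting the mean
```

An FVU of 0 means perfect reconstruction; 1 means the SAE does no better than predicting the mean activation.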

Loading Feature Labels

import json
from huggingface_hub import hf_hub_download

# Download feature labels
labels_path = hf_hub_download(
    repo_id="uzaymacar/gemma-3-27b-saes",
    filename="layer_45/dict_16k_k80/feature_labels.json",
)

with open(labels_path, 'r') as f:
    labels = json.load(f)

# Examine a specific feature
feature_id = 1234
if str(feature_id) in labels:
    label = labels[str(feature_id)]
    print(f"Feature {feature_id}:")
    print(f"  Title: {label.get('title', 'N/A')}")
    print(f"  Description: {label.get('description', 'N/A')}")
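
To browse labels rather than look up a single ID, you can filter on the title text. This assumes the schema used above (a dict keyed by feature ID as a string, each entry with an optional `title` field); `find_features` and the sample entries are hypothetical.

```python
def find_features(labels: dict, keyword: str) -> list[tuple[int, str]]:
    """Return (feature_id, title) pairs whose title contains `keyword` (case-insensitive)."""
    keyword = keyword.lower()
    hits = []
    for fid, entry in labels.items():
        title = entry.get("title", "")
        if keyword in title.lower():
            hits.append((int(fid), title))
    return sorted(hits)


# Hypothetical entries, only to show the expected shape of the data
sample = {
    "12": {"title": "French city names"},
    "7": {"title": "Capital cities"},
    "3": {"title": "Python keywords"},
}
print(find_features(sample, "cit"))  # [(7, 'Capital cities'), (12, 'French city names')]
```

With the downloaded file, find_features(labels, "france") would list the matching features in ID order.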

Citation

If you use these SAEs in your research, please cite:

@software{gemma3_27b_saes,
  author = {Macar, Uzay},
  title = {Sparse Autoencoders for Gemma-3-27b-it},
  year = {2024},
  url = {https://huggingface.co/uzaymacar/gemma-3-27b-saes}
}

SAE Training Framework:

@software{dictionary_learning,
  author = {Marks, Samuel and others},
  title = {Dictionary Learning for Mechanistic Interpretability},
  year = {2024},
  url = {https://github.com/saprmarks/dictionary_learning}
}

TopK SAE architecture (which BatchTopK extends):

@article{gao2024batchTopK,
  title={Scaling and evaluating sparse autoencoders},
  author={Gao, Leo and others},
  journal={arXiv preprint arXiv:2406.04093},
  year={2024}
}

License

These SAEs are released under the same license as the base model (google/gemma-3-27b-it).

Acknowledgments

Contact

For questions or issues, please contact me at [email protected]


Note: These SAEs are research artifacts. While they provide valuable insights into model representations, they should be used as one tool among many for interpretability research.
