Understanding SPD → VPD
Stochastic Parameter Decomposition
We have a trained, frozen neural network, and we want to understand what mechanisms its weight matrices implement. SPD breaks each weight matrix into simple rank-1 pieces, learns which pieces each input needs, and verifies that the unused ones can be removed without changing the output.
Running Example
    Model: 1-layer residual MLP
      d_in  = 768    ← residual stream dimension
      d_mlp = 3072   ← MLP hidden dimension
      C     = 9728   ← subcomponents per matrix
      B     = 4      ← batch size

    Frozen weight matrices to decompose:
      W_in:  (3072, 768)   ← MLP up-projection
      W_out: (768, 3072)   ← MLP down-projection
Decompose Each Weight Matrix into Rank-1 Subcomponents
Each weight matrix is written as a sum of outer products of two vectors:
Each subcomponent: u_c (writes) ⊗ v_c (reads)
    U_in: (3072, 9728)   ← columns are u vectors (output side)
    V_in: (9728, 768)    ← rows are v vectors (input/reading side)

    U_in @ V_in = (3072, 9728) @ (9728, 768) = (3072, 768) ≈ W_in
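As a quick check of the shapes above, here is a minimal NumPy sketch with toy dimensions standing in for 3072, 768 and 9728 (the small sizes and random values are assumptions for illustration only). It shows that summing the C rank-1 outer products gives the same matrix as the single product U @ V.

```python
import numpy as np

rng = np.random.default_rng(0)
d_out, d_in, C = 6, 4, 10          # toy stand-ins for 3072, 768, 9728

U = rng.normal(size=(d_out, C))    # columns are u_c (writing directions)
V = rng.normal(size=(C, d_in))     # rows are v_c (reading directions)

# Sum of C rank-1 pieces, one outer product per subcomponent.
W_sum = sum(np.outer(U[:, c], V[c, :]) for c in range(C))

# The same matrix written as one product.
W_prod = U @ V

print(np.allclose(W_sum, W_prod))
```

Note that C is larger than both matrix dimensions here, mirroring the C > rank setup of the running example.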
Why C = 9728 > rank 768?
The model may encode more mechanisms than it has dimensions — superposition. Extra subcomponents let us find those hidden mechanisms.
Compute Causal Importance Values
For each input, which subcomponents actually matter?
    x:    (B, 768)      ← input activation
    V_in: (9728, 768)   ← each row is a reading direction

    h = x @ V_in.T  →  (B, 9728)    "how much does the input project onto v_c?"

    For each subcomponent c:
      g[b,c] = hard_sigmoid( MLP_c( h[b,c] ) )

    g: (B, 9728)   ← causal importance ∈ [0, 1]
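The computation above might be sketched as follows. Several details are assumptions, since the text does not pin them down: `hard_sigmoid` is taken to be a clamp to [0, 1], each `MLP_c` is a one-hidden-layer ReLU network on a scalar input, and the hidden width `H` and all parameter names are hypothetical.

```python
import numpy as np

def hard_sigmoid(z):
    # Assumed form: clamp to [0, 1].
    return np.clip(z, 0.0, 1.0)

def causal_importance(x, V, W1, b1, W2, b2):
    """x: (B, d_in), V: (C, d_in).
    One tiny scalar-to-scalar MLP per subcomponent (assumed architecture):
    W1, b1, W2: (C, H); b2: (C,). Returns g: (B, C)."""
    h = x @ V.T                                          # (B, C) projections onto v_c
    hidden = np.maximum(h[:, :, None] * W1 + b1, 0.0)    # (B, C, H), ReLU
    z = (hidden * W2).sum(axis=-1) + b2                  # (B, C)
    return hard_sigmoid(z)

rng = np.random.default_rng(0)
B, d_in, C, H = 4, 8, 12, 5                              # toy sizes
x = rng.normal(size=(B, d_in))
V = rng.normal(size=(C, d_in))
W1, b1, W2 = (rng.normal(size=(C, H)) for _ in range(3))
b2 = rng.normal(size=C)

g = causal_importance(x, V, W1, b1, W2, b2)
print(g.shape)
```

Each subcomponent sees only its own scalar `h[b, c]`, so the importance of one subcomponent never depends directly on another's activation.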
Stochastic Masking — The Core Trick
Randomly scale unimportant subcomponents and check if the output survives.
    r: (B, 9728) ~ Uniform(0, 1)
    m: (B, 9728) = g + (1 − g) × r   ← lives in [g, 1]
Why stochastic masking helps gradients
Even with g = 0, the mask m is a random nonzero value. Gradients always flow. If the causal importance function is wrong, the reconstruction loss signals the error. Compare to APD's hard top-k: no selection → zero gradient → can never self-correct.
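A minimal sketch of the mask sampling, using the formula m = g + (1 − g) × r from above. The example values of g are made up; the point is that every sampled mask lands between its g and 1, and that g = 1 pins the mask at exactly 1.

```python
import numpy as np

def stochastic_mask(g, rng):
    """Sample m = g + (1 - g) * r with r ~ Uniform(0, 1), elementwise."""
    r = rng.uniform(0.0, 1.0, size=g.shape)
    return g + (1.0 - g) * r

rng = np.random.default_rng(0)
g = np.array([[0.0, 0.25, 0.9, 1.0]])   # illustrative importances for 4 subcomponents
m = stochastic_mask(g, rng)

print(m)
```

Subcomponents with g near 1 are barely perturbed, while those with g near 0 are scaled by an almost arbitrary factor in [0, 1].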
Build Masked Weights & Forward Pass
    For each batch element b:
      U_scaled = U_in × m[b]       ← (3072, 9728) × (9728,): column c is scaled by m[b,c]
      W'_in[b] = U_scaled @ V_in   ← (3072, 768)

    Each subcomponent c contributes: m[b,c] × (u_c ⊗ v_c)
Figure: masked pass (top) and original pass (bottom). The two outputs should match.
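The per-example masked weights can be written compactly with `einsum`. This is a sketch on toy sizes, not the reference implementation. With all masks set to 1 the masked pass reduces to the plain U @ V pass, and with all masks set to 0 every subcomponent is removed.

```python
import numpy as np

def masked_weights(U, V, m):
    """U: (d_out, C), V: (C, d_in), m: (B, C) -> W': (B, d_out, d_in).
    W'[b] = sum_c m[b, c] * outer(u_c, v_c)."""
    return np.einsum("oc,bc,ci->boi", U, m, V)

def masked_forward(x, U, V, m):
    """x: (B, d_in) -> (B, d_out), each example using its own masked matrix."""
    W_masked = masked_weights(U, V, m)
    return np.einsum("boi,bi->bo", W_masked, x)

rng = np.random.default_rng(0)
B, d_in, d_out, C = 4, 5, 7, 9                       # toy sizes
U = rng.normal(size=(d_out, C))
V = rng.normal(size=(C, d_in))
x = rng.normal(size=(B, d_in))

y_full = masked_forward(x, U, V, np.ones((B, C)))    # nothing ablated
y_ref = x @ (U @ V).T                                # unmasked reference
print(np.allclose(y_full, y_ref))
```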
Losses & Training
    L_SPD = L_faithfulness                       ← MSE(U@V, W)
          + β₁ × L_stochastic_recon              ← KL(y_target, y_masked)
          + β₂ × L_stochastic_recon_layerwise    ← mask one layer at a time
          + β₃ × L_importance_minimality         ← Σ|g|ᵖ (push g → 0)
Trainable: U, V, MLP weights. Frozen: original W.
Why SPD Is Not Enough
Ground truth: Subcomponent A adds +3 to an output logit. B adds −3. They cancel. Both are doing real work — g should be 1. But the causal importance function mistakenly set g_A = 0, g_B = 0.
Stochastic masking — moderate error, weak signal
Trial 1: m_A = 0.6, m_B = 0.4
0.6×(+3) + 0.4×(−3) = +0.6 error: 0.6 (small)
Trial 2: m_A = 0.3, m_B = 0.7
0.3×(+3) + 0.7×(−3) = −1.2 error: 1.2 (moderate)
On average: partial cancellation survives → moderate loss
→ WEAK gradient signal → g stays near 0
The result: a false decomposition
It looks sparse and appears to reconstruct well, but it is wrong: real mechanisms are hidden behind cancellation. The worst case (m_A=1, m_B=0 → error 3) is almost never sampled, because with g = 0 both masks are drawn independently from Uniform(0, 1). We need to actively search for it.
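The two trials above can be extended into a small simulation. This sketch uses only the numbers from the example (+3, −3, g_A = g_B = 0) and measures the absolute output error; the sample count and seed are arbitrary choices.

```python
import random

A, B = 3.0, -3.0             # contributions of subcomponents A and B
target = A + B               # unmasked output: the two cancel to 0.0

def error(m_A, m_B):
    """Absolute deviation of the masked output from the unmasked output."""
    return abs(m_A * A + m_B * B - target)

random.seed(0)
# With g_A = g_B = 0, both masks are independent Uniform(0, 1) draws.
samples = [error(random.random(), random.random()) for _ in range(100_000)]

mean_error = sum(samples) / len(samples)
worst_case = error(1.0, 0.0)

print(round(mean_error, 2), worst_case)
```

The average error sits near 1 (analytically, 3 × E|m_A − m_B| = 3 × 1/3), while the worst case is 3, so random sampling sees only about a third of the damage that is actually possible.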
adVersarial Parameter Decomposition
VPD = SPD + three additions:
VPD Running Example
    4-layer decoder-only transformer, 67M params
      d_model = 768    d_mlp = 3072    n_heads = 6    d_head = 128
      T = 512          vocab = 50277

    C = 9728 subcomponents per matrix × 24 matrices total
Critical: All 24 Matrices Are Masked Simultaneously
This is easy to miss. When VPD does a masked forward pass (stochastic or adversarial), it doesn't mask one matrix at a time. All 24 weight matrices are replaced with their masked versions in the same single forward pass.
Each matrix is independently decomposed — its own U, V, Δ, its own 9728 subcomponents, its own causal importance MLPs and masks. Subcomponent 42 of W_Q in layer 0 has nothing to do with subcomponent 42 of W_K. But they all go into the same forward pass:
    x = embed(tokens)

    Layer 0:
      x = x + Attention(x, W'_Q0, W'_K0, W'_V0, W'_O0)   ← all 4 masked
      x = x + MLP(x, W'_in0, W'_out0)                    ← both masked
    Layer 1:
      x = x + Attention(x, W'_Q1, W'_K1, W'_V1, W'_O1)   ← all 4 masked
      x = x + MLP(x, W'_in1, W'_out1)                    ← both masked
    Layers 2, 3: same

    logits = x @ W_unembed   ← NOT decomposed
    y_masked: (B, T, vocab)
The reconstruction loss compares:
    y_target: original model, original weights     (B, T, vocab)
    y_masked: model with ALL 24 matrices masked    (B, T, vocab)

    L_recon = KL(y_target, y_masked)
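One way to compute this reconstruction loss from logits is sketched below. Two choices here are assumptions rather than details confirmed by the text: KL(y_target, y_masked) is read as KL(softmax(y_target) ‖ softmax(y_masked)), and the per-token values are averaged over batch and sequence positions.

```python
import numpy as np

def log_softmax(logits):
    """Numerically stable log-softmax over the last (vocab) axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

def kl_recon(target_logits, masked_logits):
    """Mean over (B, T) of KL(softmax(target) || softmax(masked))."""
    log_p = log_softmax(target_logits)
    log_q = log_softmax(masked_logits)
    kl = (np.exp(log_p) * (log_p - log_q)).sum(axis=-1)   # (B, T)
    return kl.mean()

rng = np.random.default_rng(0)
B, T, vocab = 2, 3, 11                                    # toy sizes
y_target = rng.normal(size=(B, T, vocab))
y_masked = y_target + 0.5 * rng.normal(size=(B, T, vocab))

loss_same = kl_recon(y_target, y_target)
loss_diff = kl_recon(y_target, y_masked)
print(loss_same, loss_diff)
```

The loss is zero only when the masked model reproduces the original output distribution at every position, which is what makes it an end-to-end check.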
Why this matters: end-to-end faithfulness
VPD checks whether the decomposition preserves the model's final output behavior, not just intermediate activations at one layer. This is a key advantage over per-layer methods like transcoders, which only match activations layer by layer. A subcomponent in layer 0 might seem unimportant locally but matter for the final output — end-to-end masking catches this.
(Note: SPD had a layerwise loss variant. VPD dropped that and relies on full end-to-end masking.)
Steps 1–2: Same as SPD, with Δ and sequence dim
Step 1 — Decomposition with Δ:
    U: (3072, 9728)
    V: (9728, 768)
    Δ = W − U@V: (3072, 768)

    W_effective = U@V + Δ = W   ← exact faithfulness by construction

    Δ: g = 0 always (no MLP), L2 penalty pushes Δ → 0
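A short sketch of why faithfulness is exact by construction, on toy sizes with random stand-in values. Because Δ is recomputed from the current U and V, the sum U@V + Δ always equals W, and the squared norm of Δ is the quantity the L2 penalty shrinks.

```python
import numpy as np

rng = np.random.default_rng(0)
d_out, d_in, C = 6, 4, 10                # toy stand-ins for 3072, 768, 9728

W = rng.normal(size=(d_out, d_in))       # frozen target weights
U = 0.1 * rng.normal(size=(d_out, C))    # current (imperfect) subcomponents
V = 0.1 * rng.normal(size=(C, d_in))

delta = W - U @ V                        # derived from W, U, V; not a free parameter
W_effective = U @ V + delta
delta_l2 = np.sum(delta ** 2)            # what the L2 penalty pushes toward 0

print(np.allclose(W_effective, W), delta_l2 > 0)
```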
Step 2 — Causal importance with sequence dim:
    x: (B, T, 768)
    V: (9728, 768)

    h = x @ V.T  →  (B, T, 9728)
    g[b,t,c] = hard_sigmoid(MLP_c(h[b,t,c]))

    g: (B, T, 9728)   ← per subcomponent, per token, per batch element
Per-token weight matrices
Every (batch, position) pair gets its own masked W. Subcomponent c might matter for "is" but not "the".
Generate Both Mask Types
Stochastic (same as SPD):
    r_stoch: (B, T, 9728) ~ Uniform(0, 1)
    m_stoch: (B, T, 9728) = g + (1 − g) × r_stoch   ← in [g, 1]
Adversarial (new — gradient ascent on masks):
    m_adv: (B, T, 9728) initialized in [g, 1]

    for step in range(20):
        Forward pass with m_adv → y_adv        ← see Step 3a
        loss = KL(y_target, y_adv)
        grad = d(loss) / d(m_adv)              ← masks ONLY
        m_adv += lr × grad                     ← ASCENT (maximize damage)
        m_adv = clamp(m_adv, min=g, max=1)
Adversarial catches what stochastic misses
A adds +3 (g=0), B adds −3 (g=0) Adversarial finds: m_A → 1.0, m_B → 0.0 1.0×(+3) + 0.0×(−3) = +3 error: 3 ← HUGE → strong gradient → g_A, g_B pushed to 1.0 ✓
Why lowering g doesn't help
g is the floor, 1 is the ceiling. Adversary uses the ceiling for m_A — lowering g_A changes nothing. Adversary uses the floor for m_B — only raising g_B restricts it. The only escape is to correctly raise g.
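The A/B example can be run as a tiny projected gradient ascent. This is a toy sketch, not the full method: the "model" is just the sum of the two masked contributions, squared error stands in for the KL loss, and the learning rate and starting masks are arbitrary choices.

```python
import numpy as np

contrib = np.array([3.0, -3.0])    # A adds +3, B adds -3
target = contrib.sum()             # unmasked output: 0.0

def adversarial_masks(m_init, g, steps=20, lr=0.1):
    """Gradient ASCENT on (output - target)^2 with masks kept inside [g, 1]."""
    m = m_init.copy()
    for _ in range(steps):
        out = m @ contrib
        grad = 2.0 * (out - target) * contrib   # derivative of the squared error w.r.t. m
        m = np.clip(m + lr * grad, g, 1.0)      # step uphill, then project back
    return m

g_wrong = np.array([0.0, 0.0])     # importances mistakenly set to 0
m_adv = adversarial_masks(np.array([0.6, 0.4]), g_wrong)
error_wrong = abs(m_adv @ contrib - target)

g_right = np.array([1.0, 1.0])     # correct importances leave the adversary no room
m_pinned = adversarial_masks(np.array([1.0, 1.0]), g_right)
error_right = abs(m_pinned @ contrib - target)

print(m_adv, error_wrong, error_right)
```

Starting from the Trial 1 masks, the search reaches m_A = 1, m_B = 0 within two steps and exposes the full error of 3. Once both g values are raised to 1, the feasible interval [g, 1] collapses to a point and the error is 0. One caveat of this toy: a perfectly symmetric start (m_A = m_B) has zero gradient, so the starting point matters.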
Adversarial Inner Loop — Full Forward Pass
Inside each adversarial step, a complete transformer forward pass occurs:
    Building masked weights (per matrix, per position):
      m_adv_in[b,t]: (9728,)
      U_in: (3072, 9728)
      V_in: (9728, 768)
      Δ_in: (3072, 768)

      U_scaled   = U_in × m_adv_in[b,t]     ← (3072, 9728) × (9728,)
      W'_in[b,t] = U_scaled @ V_in          ← (3072, 768)
                 + Δ_in × m_delta[b,t]      ← (3072, 768)
↑ Repeated for all 24 weight matrices.
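Materializing a separate (3072, 768) matrix for every (batch, position) pair is expensive. An algebraically equivalent route is to apply the mask to the inner activations x @ V.T instead. The sketch below checks that equivalence on toy sizes. Whether the actual implementation uses this shortcut is an assumption, and `m_delta` is taken to be one scalar per position, as the indexing `m_delta[b,t]` suggests.

```python
import numpy as np

def masked_linear(x, U, V, delta, m, m_delta):
    """x: (B, T, d_in), U: (d_out, C), V: (C, d_in), delta: (d_out, d_in),
    m: (B, T, C), m_delta: (B, T). Returns (B, T, d_out) without building W'."""
    h = x @ V.T                                   # (B, T, C): read with v_c
    out = (h * m) @ U.T                           # scale per subcomponent, write with u_c
    return out + m_delta[..., None] * (x @ delta.T)

def masked_linear_explicit(x, U, V, delta, m, m_delta):
    """Same result, but builds W'[b, t] explicitly as in the text."""
    W_masked = np.einsum("oc,btc,ci->btoi", U, m, V)
    W_masked = W_masked + m_delta[..., None, None] * delta
    return np.einsum("btoi,bti->bto", W_masked, x)

rng = np.random.default_rng(0)
B, T, d_in, d_out, C = 2, 3, 5, 7, 9              # toy sizes
x = rng.normal(size=(B, T, d_in))
U = rng.normal(size=(d_out, C))
V = rng.normal(size=(C, d_in))
delta = rng.normal(size=(d_out, d_in))
m = rng.uniform(size=(B, T, C))
m_delta = rng.uniform(size=(B, T))

fast = masked_linear(x, U, V, delta, m, m_delta)
slow = masked_linear_explicit(x, U, V, delta, m, m_delta)
print(np.allclose(fast, slow))
```

The identity behind it is x @ (U diag(m) V).T = ((x @ V.T) ∘ m) @ U.T, which holds for any mask values.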
    Full transformer forward pass:
      tokens (B, T) → embeddings (B, T, 768)

      Layer 0 Attention:
        q = x @ W'_Q[b,t].T
        k = x @ W'_K[b,t].T
        v = x @ W'_V[b,t].T
        attn_out = attention(q,k,v) @ W'_O[b,t].T    (B, T, 768)
        x = x + attn_out

      Layer 0 MLP:
        hidden    = x @ W'_in[b,t].T                 (B, T, 3072)
        activated = GELU(hidden)
        mlp_out   = activated @ W'_out[b,t].T        (B, T, 768)
        x = x + mlp_out

      Layers 1, 2, 3: same

      logits → y_adv: (B, T, vocab)
    Update masks only:
      loss = KL(y_target, y_adv)
      grad = d(loss) / d(m_adv)              ← (B, T, 9728) per matrix
      m_adv += lr × grad                     ← gradient ASCENT
      m_adv = clamp(m_adv, min=g, max=1)

    ONLY m_adv changes. U, V, Δ, MLPs: all frozen.
Outer Training Loop — Two Forward Passes
Two separate forward passes (never mixed):
Forward Pass A — Stochastic masks
    For all 24 matrices:
      W'_l = (U_l × m_stoch_l[b,t]) @ V_l + Δ_l × m_delta_stoch
    Full transformer → y_stoch: (B, T, vocab)
Forward Pass B — Adversarial masks (worst-case from Step 3)
    For all 24 matrices:
      W'_l = (U_l × m_adv_l[b,t]) @ V_l + Δ_l × m_delta_adv
    Full transformer → y_adv: (B, T, vocab)
All 5 Losses
    L1 = β₁ × KL(y_target, y_adv)               ← adversarial recon
    L2 = β₂ × KL(y_target, y_stoch)             ← stochastic recon
    L3 = β₃ × (1/BT) Σ|g|ᵖ                      ← importance minimality
    L4 = β₄ × (1/BT) Σ|g|ᵖ·log₂(1+freq)         ← frequency minimality
    L5 = β₅ × Σ‖W − U@V‖²                       ← delta L2

    L_VPD = L1 + L2 + L3 + L4 + L5
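The five terms might be assembled as in the sketch below, for a single matrix. Several things are assumptions: the two KL values are passed in as precomputed numbers, `freq` is taken to be a per-subcomponent array supplied by the caller (the text does not say how it is estimated), the log₂ factor is read as sitting inside the sum, and all function and argument names are hypothetical.

```python
import numpy as np

def minimality_terms(g, freq, p=1.0):
    """g: (B, T, C) causal importances, freq: (C,) per-subcomponent frequency
    (assumed to be supplied). Returns the unweighted L3 and L4 sums."""
    B, T, _ = g.shape
    g_p = np.abs(g) ** p
    importance = g_p.sum() / (B * T)
    frequency = (g_p * np.log2(1.0 + freq)).sum() / (B * T)
    return importance, frequency

def vpd_loss(kl_adv, kl_stoch, g, freq, W, U, V, betas, p=1.0):
    """Weighted sum L1 + L2 + L3 + L4 + L5 for one decomposed matrix."""
    b1, b2, b3, b4, b5 = betas
    importance, frequency = minimality_terms(g, freq, p)
    delta_l2 = np.sum((W - U @ V) ** 2)
    return (b1 * kl_adv + b2 * kl_stoch
            + b3 * importance + b4 * frequency + b5 * delta_l2)

rng = np.random.default_rng(0)
U = rng.normal(size=(6, 4))
V = rng.normal(size=(4, 5))
W = U @ V                                   # perfect decomposition: delta term is 0
g = np.ones((2, 3, 4))                      # every subcomponent fully important
freq = np.ones(4)                           # log2(1 + 1) = 1 for each subcomponent

total = vpd_loss(0.5, 0.25, g, freq, W, U, V, betas=(1, 1, 1, 1, 1))
print(total)
```

In the full method the sums would also run over all 24 matrices; this sketch keeps to one matrix so the individual terms stay easy to read.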
Backprop & Update
    Gradient DESCENT on L_VPD updates:
      ✓ U_l, V_l for all 24 matrices        ← subcomponent vectors
      ✓ MLP weights for all subcomponents   ← causal importance predictors

    Does NOT update:
      ✗ Original weights W_l     ← frozen target
      ✗ Δ_l = W_l − U_l@V_l      ← derived
      ✗ m_adv                    ← inner search only
SPD vs VPD
| Aspect | SPD | VPD |
|---|---|---|
| Subcomponents | Rank-1 per matrix | Same |
| Causal importance | Learned MLP, per batch element | Same + per token position |
| Stochastic masking | ✓ (only type) | ✓ (one of two) |
| Adversarial masking | ✗ | ✓ catches cancellations |
| Faithfulness | MSE: U@V ≈ W | Δ = W−U@V (exact) + L2 |
| Importance penalty | Σ|g|ᵖ | Same |
| Frequency penalty | ✗ | ✓ Σ|g|ᵖ·log₂(1+freq) |
| Tested on | Toy models | 67M param transformer |
| Feature splitting | Claimed absent | Empirically confirmed |
Core idea in one sentence
Decompose weight matrices into simple rank-1 pieces, learn which pieces each input needs, and verify that the output survives when the "unneeded" pieces are removed both randomly and adversarially.