Part 1

Stochastic Parameter Decomposition

We have a trained, frozen neural network and want to understand which mechanisms its weight matrices implement. SPD breaks each weight matrix into simple rank-1 pieces, learns which pieces each input needs, and verifies that the unneeded ones can be safely removed.

Running Example

Model:  1-layer residual MLP
d_in  = 768       ← residual stream dimension
d_mlp = 3072      ← MLP hidden dimension
C     = 9728      ← subcomponents per matrix
B     = 4         ← batch size

Frozen weight matrices to decompose:
  W_in:   (3072, 768)   — MLP up-projection
  W_out:  (768, 3072)   — MLP down-projection

Step 1: Decompose Each Weight Matrix into Rank-1 Subcomponents

Each weight matrix is written as a sum of outer products of two vectors:

W_in (3072, 768) ≈ u₁ v₁ᵀ + u₂ v₂ᵀ + ⋯ + u_C v_Cᵀ
  each u_c: (3072,)   each v_c: (768,)   C = 9728

Each subcomponent: u_c (writes) ⊗ v_c (reads)

U_in:  (3072, 9728)   — columns are u vectors (output side)
V_in:  (9728, 768)    — rows are v vectors (input/reading side)

U_in @ V_in = (3072, 9728) @ (9728, 768) = (3072, 768) ≈ W_in

Why C = 9728 > rank 768?

The model may encode more mechanisms than it has dimensions (superposition). Using more subcomponents than the rank leaves room for each of those hidden mechanisms to get its own rank-1 piece.
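A quick way to confirm that U_in @ V_in and the sum of outer products are the same matrix is to check it numerically. The sketch below is plain Python with assumed toy sizes (d_out = 4, d_in = 3, C = 6 standing in for 3072, 768 and 9728); as in the real setting, C exceeds the rank bound set by d_in.

```python
import random


def matmul(A, B):
    """Plain-Python matrix product of nested lists: (n, k) @ (k, m) -> (n, m)."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


def outer(u, v):
    """Rank-1 matrix u v^T."""
    return [[ui * vj for vj in v] for ui in u]


d_out, d_in, C = 4, 3, 6  # toy sizes; C > d_in, like C = 9728 > 768
rng = random.Random(0)
U = [[rng.gauss(0, 1) for _ in range(C)] for _ in range(d_out)]  # columns are u_c
V = [[rng.gauss(0, 1) for _ in range(d_in)] for _ in range(C)]   # rows are v_c

W_from_matmul = matmul(U, V)

# Accumulate the C rank-1 pieces u_c v_c^T one at a time.
W_from_sum = [[0.0] * d_in for _ in range(d_out)]
for c in range(C):
    u_c = [U[i][c] for i in range(d_out)]  # c-th column of U
    piece = outer(u_c, V[c])               # V[c] is the c-th row of V
    for i in range(d_out):
        for j in range(d_in):
            W_from_sum[i][j] += piece[i][j]

max_diff = max(
    abs(a - b) for ra, rb in zip(W_from_matmul, W_from_sum) for a, b in zip(ra, rb)
)
print(f"max |U@V - sum of outer products| = {max_diff:.1e}")
```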

Step 2: Compute Causal Importance Values

For each input, which subcomponents actually matter?

x:    (B, 768)       — input activation
V_in: (9728, 768)    — each row is a reading direction

h = x @ V_in.T            → (B, 9728)   "how much does input project onto v_c?"

For each subcomponent c:
  g[b,c] = hard_sigmoid( MLP_c( h[b,c] ) )

g: (B, 9728)              — causal importance ∈ [0, 1]
  g = 0    not needed: ablate freely
  g = 0.5  partially needed
  g = 1    critical: don't touch
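A minimal sketch of the causal importance computation for one matrix, in plain Python. Assumptions: hard_sigmoid clamps its input to [0, 1], and each MLP_c is a hypothetical scalar → 4 GELU units → scalar network with random weights; the real architecture and sizes may differ.

```python
import math
import random


def hard_sigmoid(z):
    """ASSUMED form: clamp to [0, 1], identity in between."""
    return min(1.0, max(0.0, z))


def gelu(z):
    """Exact GELU via the Gaussian CDF."""
    return 0.5 * z * (1.0 + math.erf(z / math.sqrt(2.0)))


class TinyMLP:
    """Hypothetical per-subcomponent MLP_c: scalar -> `hidden` GELU units -> scalar."""

    def __init__(self, hidden=4, seed=0):
        rng = random.Random(seed)
        self.w1 = [rng.gauss(0, 1) for _ in range(hidden)]
        self.b1 = [rng.gauss(0, 1) for _ in range(hidden)]
        self.w2 = [rng.gauss(0, 1) for _ in range(hidden)]
        self.b2 = 0.0

    def __call__(self, h):
        act = [gelu(w * h + b) for w, b in zip(self.w1, self.b1)]
        return sum(w * a for w, a in zip(self.w2, act)) + self.b2


def causal_importance(x, V, mlps):
    """x: (B, d_in), V: (C, d_in), one MLP per subcomponent -> g: (B, C) in [0, 1]."""
    g = []
    for x_b in x:
        row = []
        for v_c, mlp_c in zip(V, mlps):
            h_bc = sum(xi * vi for xi, vi in zip(x_b, v_c))  # h[b,c] = x_b . v_c
            row.append(hard_sigmoid(mlp_c(h_bc)))
        g.append(row)
    return g


B, d_in, C = 2, 3, 5  # toy sizes
rng = random.Random(1)
x = [[rng.gauss(0, 1) for _ in range(d_in)] for _ in range(B)]
V = [[rng.gauss(0, 1) for _ in range(d_in)] for _ in range(C)]
mlps = [TinyMLP(seed=c) for c in range(C)]
g = causal_importance(x, V, mlps)
print(g)
```

Note that MLP_c only ever sees the scalar h[b,c], so each subcomponent's importance depends on how strongly the input projects onto its own reading direction v_c.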

Step 3: Stochastic Masking (The Core Trick)

Randomly scale unimportant subcomponents and check if the output survives.

r:  (B, 9728)  ~ Uniform(0, 1)
m:  (B, 9728)  = g + (1 − g) × r           — lives in [g, 1]
  g = 0   (unimportant)        mask range [0, 1]: full freedom
  g = 0.7 (mostly important)   mask range [0.7, 1]: limited wiggle
  g = 1   (critical)           mask = 1 always: fully protected
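The masking rule can be sampled directly. This plain-Python sketch uses hypothetical g values matching the three cases above, draws many masks, and records the observed range for each g.

```python
import random


def stochastic_mask(g, rng):
    """m = g + (1 - g) * r with r ~ Uniform(0, 1), elementwise; m lies in [g, 1]."""
    return [[g_bc + (1.0 - g_bc) * rng.random() for g_bc in row] for row in g]


rng = random.Random(0)
g = [[0.0, 0.7, 1.0]]  # hypothetical: unimportant, mostly important, critical
samples = [stochastic_mask(g, rng)[0] for _ in range(10_000)]

ranges = {}
for c, g_c in enumerate(g[0]):
    vals = [s[c] for s in samples]
    ranges[g_c] = (min(vals), max(vals))
    print(f"g = {g_c}: observed mask range [{ranges[g_c][0]:.3f}, {ranges[g_c][1]:.3f}]")
```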

Why stochastic masking helps gradients

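Even when g = 0, the mask m = r is a random value in [0, 1], so the subcomponent stays partly active and gradients always flow: ∂m/∂g = 1 − r, which is nonzero almost surely. If the causal importance function is wrong, the reconstruction loss signals the error and can push g back up. Compare APD's (Attribution-based Parameter Decomposition) hard top-k selection: an unselected subcomponent gets zero gradient, so a bad selection can never self-correct.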
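The gradient claim can be checked on a one-subcomponent toy. Assume (hypothetically) that the subcomponent contributes a = 3 to a single output that the original model also produces, and use squared error as a stand-in for the reconstruction loss. Holding r fixed, the chain rule gives ∂L/∂g = ∂L/∂m · (1 − r); the sketch averages this over random r and compares it against a finite difference.

```python
import random

a = 3.0       # hypothetical contribution of the subcomponent to the output
target = a    # the original model uses the subcomponent fully


def loss(g, r):
    """Squared error (stand-in for the reconstruction loss) with r held fixed."""
    m = g + (1.0 - g) * r
    return (target - m * a) ** 2


def grad_wrt_g(g, r):
    """Chain rule: dL/dg = dL/dm * dm/dg, with dm/dg = 1 - r."""
    m = g + (1.0 - g) * r
    dL_dm = -2.0 * (target - m * a) * a
    return dL_dm * (1.0 - r)


rng = random.Random(0)
g = 0.0  # the importance function wrongly says "unimportant"
grads = [grad_wrt_g(g, rng.random()) for _ in range(10_000)]
mean_grad = sum(grads) / len(grads)
print(f"mean dL/dg at g = 0: {mean_grad:.2f} (negative, so descent raises g)")

# Finite-difference check of the analytic gradient for one fixed r.
r, eps = 0.4, 1e-6
fd = (loss(g + eps, r) - loss(g - eps, r)) / (2 * eps)
print(f"analytic {grad_wrt_g(g, r):.4f} vs finite difference {fd:.4f}")

# Contrast with a hard top-k selection: an unselected subcomponent is multiplied
# by a constant 0, so the loss does not depend on its score and the gradient is 0.
```

In this toy the expected gradient is −18 · E[(1 − r)²] = −6, so gradient descent steadily raises the wrongly low g.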

Step 4: Build Masked Weights & Forward Pass

For each batch element b:
  U_scaled = U_in × m[b]                  — (3072, 9728) × (9728,)
  W'_in[b] = U_scaled @ V_in              — (3072, 768)

Each subcomponent c contributes m[b,c] × u_c v_cᵀ. W'_out is built the same way from U_out, V_out and its own masks.
  masked:    x (B, 768) → W'_in → GELU → W'_out → y_masked
  original:  x (B, 768) → W_in  → GELU → W_out  → y_target

Top: masked pass. Bottom: original. They should match.
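A toy version of the two passes, in plain Python. Assumptions: tiny sizes (d_in = 3, d_mlp = 5, C = 8), random U and V, the residual connection omitted, and the "original" weights defined as exactly U @ V so that the all-ones mask reproduces them. masked_weight is a hypothetical helper, not an SPD API.

```python
import math
import random


def gelu(z):
    return 0.5 * z * (1.0 + math.erf(z / math.sqrt(2.0)))


def matvec(W, x):
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def masked_weight(U, V, m):
    """Hypothetical helper: W' = (U scaled column-wise by m) @ V = sum_c m[c] u_c v_c^T."""
    U_scaled = [[u * m_c for u, m_c in zip(row, m)] for row in U]
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*V)] for row in U_scaled]


def mlp(x, W_in, W_out):
    """The pass drawn above: x -> W_in -> GELU -> W_out (residual omitted)."""
    return matvec(W_out, [gelu(h) for h in matvec(W_in, x)])


def randn(rng, rows, cols):
    return [[rng.gauss(0, 1) for _ in range(cols)] for _ in range(rows)]


rng = random.Random(0)
d_in, d_mlp, C = 3, 5, 8  # toy sizes
U_in, V_in = randn(rng, d_mlp, C), randn(rng, C, d_in)
U_out, V_out = randn(rng, d_in, C), randn(rng, C, d_mlp)
ones = [1.0] * C
# For this toy, the "original" weights are defined as exactly U @ V.
W_in, W_out = masked_weight(U_in, V_in, ones), masked_weight(U_out, V_out, ones)

x = [rng.gauss(0, 1) for _ in range(d_in)]
y_target = mlp(x, W_in, W_out)

# All masks = 1: the masked pass reproduces the original exactly.
y_full = mlp(x, masked_weight(U_in, V_in, ones), masked_weight(U_out, V_out, ones))

# Uniform(0, 1) masks, as if every g were 0: the output generally drifts.
m_in = [rng.random() for _ in range(C)]
m_out = [rng.random() for _ in range(C)]
y_masked = mlp(x, masked_weight(U_in, V_in, m_in), masked_weight(U_out, V_out, m_out))

err = sum((a - b) ** 2 for a, b in zip(y_target, y_masked))
print(f"squared error with random masks: {err:.3f}")
```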

Step 5: Losses & Training

L_SPD = L_faithfulness                         ← MSE(U@V, W)
      + β₁ × L_stochastic_recon                 ← KL(y_target, y_masked)
      + β₂ × L_stochastic_recon_layerwise        ← mask one layer at a time
      + β₃ × L_importance_minimality             ← Σ|g|ᵖ (push g → 0)

L_recon   wants g = 1: "protect everything!"
L_minimal wants g = 0: "mark nothing important!"
The two pull against each other; at equilibrium, g = 1 only where genuinely needed.

Trainable: U, V, MLP weights.  Frozen: original W.
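The SPD loss terms on hand-picked toy numbers, in plain Python. Assumptions: KL(y_target, y_masked) is read as KL(P_target ‖ P_masked) between softmax distributions over the outputs; β₁ = 1, β₃ = 0.1 and p = 1 are placeholders; the layerwise term is omitted since it is the same KL computed with one layer masked at a time.

```python
import math


def mse(A, B):
    flat = [(a - b) ** 2 for ra, rb in zip(A, B) for a, b in zip(ra, rb)]
    return sum(flat) / len(flat)


def log_softmax(logits):
    mx = max(logits)
    lse = mx + math.log(sum(math.exp(z - mx) for z in logits))
    return [z - lse for z in logits]


def kl(target_logits, masked_logits):
    """KL(P_target || P_masked) between the two softmax distributions."""
    lp, lq = log_softmax(target_logits), log_softmax(masked_logits)
    return sum(math.exp(a) * (a - b) for a, b in zip(lp, lq))


def importance_minimality(g, p):
    return sum(abs(v) ** p for row in g for v in row)


# Hypothetical toy values
W = [[1.0, 0.0], [0.0, 2.0]]
UV = [[0.9, 0.1], [0.0, 2.1]]
y_target = [2.0, 0.5, -1.0]
y_masked = [1.8, 0.7, -1.0]
g = [[1.0, 0.0, 0.2], [0.9, 0.0, 0.0]]
beta1, beta3, p = 1.0, 0.1, 1.0  # placeholder coefficients

L_faith = mse(UV, W)
L_recon = kl(y_target, y_masked)
L_min = importance_minimality(g, p)
L_spd = L_faith + beta1 * L_recon + beta3 * L_min
print(f"faith={L_faith:.4f} recon={L_recon:.4f} min={L_min:.2f} total={L_spd:.4f}")
```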

Part 2

Why SPD Is Not Enough

Ground truth: subcomponent A adds +3 to an output logit and B adds −3, so they cancel. Both are doing real work (ablating either one alone shifts the logit by 3), so both should have g = 1. But the causal importance function mistakenly sets g_A = 0 and g_B = 0.

Stochastic masking — moderate error, weak signal

Trial 1:  m_A = 0.6, m_B = 0.4
          0.6×(+3) + 0.4×(−3) = +0.6     error: 0.6  (small)

Trial 2:  m_A = 0.3, m_B = 0.7
          0.3×(+3) + 0.7×(−3) = −1.2     error: 1.2  (moderate)

On average: partial cancellation survives → moderate loss
  (with g = 0 both masks are Uniform(0, 1), so the mean error is 3 × E|m_A − m_B| = 3 × 1/3 = 1)
→ WEAK gradient signal → g stays near 0

The result: a false decomposition

It looks sparse and appears to reconstruct well, but it is wrong: it hides real mechanisms behind cancellation. The worst cases (m_A = 1, m_B = 0 or the reverse, error 3) sit at corners of the mask space and are almost never sampled. We need to actively search for them.
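A Monte Carlo check of the cancellation example, in plain Python. With g_A = g_B = 0 both masks are Uniform(0, 1); the sketch estimates the mean error, how often the error comes within 10% of the worst case (the exact rate is P(|m_A − m_B| > 0.9) = 0.1² = 0.01), and the worst case itself.

```python
import random


def output(m_A, m_B):
    """A contributes +3, B contributes -3; with full masks they cancel to 0."""
    return m_A * 3.0 + m_B * (-3.0)


target = output(1.0, 1.0)  # 0.0, the original model's value
rng = random.Random(0)
n = 200_000
# g_A = g_B = 0, so both masks are Uniform(0, 1)
errors = [abs(output(rng.random(), rng.random()) - target) for _ in range(n)]

mean_err = sum(errors) / n
frac_near_worst = sum(e > 2.7 for e in errors) / n
worst = abs(output(1.0, 0.0) - target)
print(f"mean error     ~ {mean_err:.3f}  (exact: 3 * E|m_A - m_B| = 1)")
print(f"P(error > 2.7) ~ {frac_near_worst:.4f}")
print(f"worst case     = {worst}")
```

Random sampling lands near the worst case only about 1% of the time, which is why the gradient signal stays weak.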

Part 3

adVersarial Parameter Decomposition

VPD = SPD + three additions:

① Adversarial masking
② Frequency-minimality loss
③ Δ-components

VPD Running Example

4-layer decoder-only transformer, 67M params
d_model=768  d_mlp=3072  n_heads=6  d_head=128  T=512  vocab=50277
C = 9728 subcomponents per matrix × 24 matrices total

Critical: All 24 Matrices Are Masked Simultaneously

This is easy to miss. When VPD does a masked forward pass (stochastic or adversarial), it doesn't mask one matrix at a time. All 24 weight matrices are replaced with their masked versions in the same single forward pass.

Each matrix is independently decomposed — its own U, V, Δ, its own 9728 subcomponents, its own causal importance MLPs and masks. Subcomponent 42 of W_Q in layer 0 has nothing to do with subcomponent 42 of W_K. But they all go into the same forward pass:

x = embed(tokens)

Layer 0:
  x = x + Attention(x, W'_Q0, W'_K0, W'_V0, W'_O0)   ← all 4 masked
  x = x + MLP(x, W'_in0, W'_out0)                      ← both masked

Layer 1:
  x = x + Attention(x, W'_Q1, W'_K1, W'_V1, W'_O1)   ← all 4 masked
  x = x + MLP(x, W'_in1, W'_out1)                      ← both masked

Layers 2, 3: same

logits = x @ W_unembed                                 ← NOT decomposed
y_masked: (B, T, vocab)
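A shapes-only sketch of the bookkeeping: one independent decomposition per matrix, 6 matrices per layer, 4 layers. The dictionary layout and key names (e.g. "W_in0") are hypothetical; shapes follow the (out, in) convention used above.

```python
d_model, d_mlp, C, n_layers = 768, 3072, 9728, 4

# (out_dim, in_dim) for each decomposed matrix in one layer
layer_shapes = {
    "W_Q": (d_model, d_model), "W_K": (d_model, d_model),
    "W_V": (d_model, d_model), "W_O": (d_model, d_model),
    "W_in": (d_mlp, d_model), "W_out": (d_model, d_mlp),
}

# Every matrix gets its own U, V, delta and C causal-importance MLPs; nothing is shared.
decomposition = {}
for layer in range(n_layers):
    for name, (d_out, d_in) in layer_shapes.items():
        decomposition[f"{name}{layer}"] = {
            "U": (d_out, C), "V": (C, d_in), "delta": (d_out, d_in), "n_ci_mlps": C,
        }

n_subcomponents = sum(p["n_ci_mlps"] for p in decomposition.values())
print(len(decomposition), "matrices,", n_subcomponents, "subcomponents in total")
```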

The reconstruction loss compares:

y_target:  original model, original weights    (B, T, vocab)
y_masked:  model with ALL 24 matrices masked   (B, T, vocab)

L_recon = KL(y_target, y_masked)

Why this matters: end-to-end faithfulness

VPD checks whether the decomposition preserves the model's final output behavior, not just intermediate activations at one layer. This is a key advantage over per-layer methods like transcoders, which only match activations layer by layer. A subcomponent in layer 0 might seem unimportant locally but matter for the final output — end-to-end masking catches this.

(Note: SPD had a layerwise loss variant. VPD dropped that and relies on full end-to-end masking.)

Steps 1–2: Same as SPD, with Δ and sequence dim

Step 1 — Decomposition with Δ:

U: (3072, 9728)   V: (9728, 768)   Δ = W − U@V: (3072, 768)
W_effective = U@V + Δ = W    ← exact faithfulness by construction
Δ: g = 0 always (no MLP), so its mask m_delta is free to take any value in [0, 1]; an L2 penalty pushes Δ → 0
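Defining Δ as the residual makes faithfulness hold by construction, whatever U and V are. The toy below (plain Python, with random hypothetical U, V and W) computes Δ, checks that U@V + Δ matches W up to float rounding, and computes the ‖Δ‖² quantity that the L2 penalty shrinks.

```python
import random


def matmul(A, B):
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


rng = random.Random(0)
d_out, d_in, C = 4, 3, 6  # toy sizes
W = [[rng.gauss(0, 1) for _ in range(d_in)] for _ in range(d_out)]  # frozen target
U = [[rng.gauss(0, 0.5) for _ in range(C)] for _ in range(d_out)]   # hypothetical current U
V = [[rng.gauss(0, 0.5) for _ in range(d_in)] for _ in range(C)]    # hypothetical current V

UV = matmul(U, V)
delta = [[w - uv for w, uv in zip(rw, ruv)] for rw, ruv in zip(W, UV)]  # delta = W - U@V
W_eff = [[uv + d for uv, d in zip(ruv, rd)] for ruv, rd in zip(UV, delta)]

delta_l2 = sum(d * d for row in delta for d in row)  # ||delta||^2, shrunk by the L2 penalty
max_err = max(abs(a - b) for ra, rb in zip(W_eff, W) for a, b in zip(ra, rb))
print(f"||delta||^2 = {delta_l2:.3f}, max |W_eff - W| = {max_err:.1e}")
```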

Step 2 — Causal importance with sequence dim:

x: (B, T, 768)   V: (9728, 768)
h = x @ V.T         → (B, T, 9728)
g[b,t,c] = hard_sigmoid(MLP_c(h[b,t,c]))
g: (B, T, 9728)      ← per subcomponent, per token, per batch

Per-token weight matrices

Every (batch, position) pair gets its own masked W. Subcomponent c might matter for "is" but not "the".
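Building a separate (3072, 768) matrix for every (b, t) is expensive. Because W'[b,t] x = U (m[b,t] ⊙ (V x)), the same output can be computed from the projections h = V x that the causal importance step already uses. The plain-Python sketch below checks this identity on toy sizes; it illustrates the algebra and is not a claim about how the VPD code is implemented.

```python
import random


def matmul(A, B):
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


def matvec(W, x):
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


rng = random.Random(0)
B, T, d_in, d_out, C = 2, 3, 4, 5, 7  # toy sizes
U = [[rng.gauss(0, 1) for _ in range(C)] for _ in range(d_out)]
V = [[rng.gauss(0, 1) for _ in range(d_in)] for _ in range(C)]
x = [[[rng.gauss(0, 1) for _ in range(d_in)] for _ in range(T)] for _ in range(B)]
m = [[[rng.random() for _ in range(C)] for _ in range(T)] for _ in range(B)]  # one mask per (b, t)

max_diff = 0.0
for b in range(B):
    for t in range(T):
        # Route 1: materialize this position's own W'[b,t] = (U x m[b,t]) @ V
        U_scaled = [[u * mc for u, mc in zip(row, m[b][t])] for row in U]
        y1 = matvec(matmul(U_scaled, V), x[b][t])
        # Route 2: project onto the v_c, scale by the mask, then write out with U
        h = matvec(V, x[b][t])  # h[b,t,:] = V x, shape (C,)
        y2 = matvec(U, [mc * hc for mc, hc in zip(m[b][t], h)])
        max_diff = max(max_diff, max(abs(p - q) for p, q in zip(y1, y2)))

print(f"max difference between the two routes: {max_diff:.1e}")
```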

Step 3: Generate Both Mask Types

Stochastic (same as SPD):

r_stoch: (B, T, 9728)  ~ Uniform(0, 1)
m_stoch: (B, T, 9728)  = g + (1 − g) × r_stoch     — in [g, 1]

Adversarial (new — gradient ascent on masks):

m_adv: (B, T, 9728)   initialized in [g, 1]

for step in range(20):
    Forward pass with m_adv → y_adv     ← see Step 3a
    loss = KL(y_target, y_adv)
    grad = d(loss) / d(m_adv)           ← masks ONLY
    m_adv += lr × grad                  ← ASCENT (maximize damage)
    m_adv = clamp(m_adv, min=g, max=1)

Adversarial catches what stochastic misses

A adds +3 (g=0), B adds −3 (g=0)

Adversarial finds: m_A → 1.0, m_B → 0.0
  1.0×(+3) + 0.0×(−3) = +3    error: 3 ← HUGE
→ strong gradient → g_A, g_B pushed to 1.0 ✓
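The inner loop on the A/B example, as a plain-Python sketch. Assumptions: squared error on the single output stands in for KL(y_target, y_adv), the gradient is written out by hand, lr = 0.05 and 20 steps are placeholders, and masks start at random points in [g, 1]. Depending on the random start, the ascent ends at (m_A, m_B) = (1, 0) or at the mirror corner (0, 1); both give the maximal error of 3.

```python
import random


def output(m_A, m_B):
    """A adds +3, B adds -3; with both masks at 1 they cancel."""
    return 3.0 * m_A - 3.0 * m_B


target = output(1.0, 1.0)  # 0.0, the original model's value


def loss(m_A, m_B):
    """Squared error: a stand-in for KL(y_target, y_adv)."""
    return (output(m_A, m_B) - target) ** 2


def grad(m_A, m_B):
    """Hand-written (dloss/dm_A, dloss/dm_B)."""
    e = output(m_A, m_B) - target
    return 2.0 * e * 3.0, 2.0 * e * -3.0


def clamp(v, lo, hi):
    return max(lo, min(hi, v))


def adversarial_masks(g_A, g_B, steps=20, lr=0.05, seed=0):
    rng = random.Random(seed)
    # Start at random points in [g, 1] (the initialization scheme is an assumption).
    m_A = g_A + (1.0 - g_A) * rng.random()
    m_B = g_B + (1.0 - g_B) * rng.random()
    for _ in range(steps):
        d_A, d_B = grad(m_A, m_B)
        m_A = clamp(m_A + lr * d_A, g_A, 1.0)  # gradient ASCENT, then project to [g, 1]
        m_B = clamp(m_B + lr * d_B, g_B, 1.0)
    return m_A, m_B


m_A, m_B = adversarial_masks(0.0, 0.0)
print(f"g = 0: m_A = {m_A:.2f}, m_B = {m_B:.2f}, loss = {loss(m_A, m_B):.2f}")
m_A1, m_B1 = adversarial_masks(1.0, 1.0)
print(f"g = 1: m_A = {m_A1:.2f}, m_B = {m_B1:.2f}, loss = {loss(m_A1, m_B1):.2f}")
```

With g_A = g_B = 1 the projection pins both masks to 1, so the adversary can do no damage.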

Why lowering g doesn't help

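g is the floor and 1 is the ceiling of each mask's range. The adversary puts m_A at the ceiling, so lowering g_A changes nothing. It puts m_B at the floor, so only raising g_B restricts it; and if g_B alone rises, the adversary can switch to the mirror corner (m_A = g_A, m_B = 1). The only escape is to correctly raise both g values.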
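Because the error |3 m_A − 3 m_B| is the absolute value of a linear function, its maximum over the box [g_A, 1] × [g_B, 1] is attained at a corner, so the worst case can be computed exactly. The plain-Python sketch below does this for a few (g_A, g_B) pairs. It shows that the worst case is 3 × (1 − min(g_A, g_B)), so raising g_B alone leaves the opposite corner (m_A = g_A, m_B = 1) open.

```python
from itertools import product


def worst_case_error(g_A, g_B):
    """Max |3 m_A - 3 m_B| over m_A in [g_A, 1], m_B in [g_B, 1].
    The error is |linear function|, so one of the four corners attains the max."""
    return max(abs(3 * m_A - 3 * m_B) for m_A, m_B in product((g_A, 1.0), (g_B, 1.0)))


for g_A, g_B in [(0.0, 0.0), (0.0, 0.5), (0.5, 0.0), (0.0, 1.0), (0.5, 0.5), (1.0, 1.0)]:
    print(f"g_A={g_A}, g_B={g_B}: worst-case error = {worst_case_error(g_A, g_B):.1f}")
```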

Step 3a: Adversarial Inner Loop (Full Forward Pass)

Inside each adversarial step, a complete transformer forward pass occurs:

Building masked weights (per matrix, per position):

m_adv_in[b,t]: (9728,)       U_in: (3072, 9728)
V_in:          (9728, 768)    Δ_in: (3072, 768)

U_scaled    = U_in × m_adv_in[b,t]      — (3072, 9728) × (9728,)
W'_in[b,t]  = U_scaled @ V_in           — (3072, 768)
              + Δ_in × m_delta[b,t]      — (3072, 768)

↑ Repeated for all 24 weight matrices.

Full transformer forward pass:

tokens      (B, T) → embeddings (B, T, 768)

Layer 0 Attention:
  q = x @ W'_Q[b,t].T    k = x @ W'_K[b,t].T    v = x @ W'_V[b,t].T
  attn_out = attention(q,k,v) @ W'_O[b,t].T        (B, T, 768)
  x = x + attn_out

Layer 0 MLP:
  hidden    = x @ W'_in[b,t].T                      (B, T, 3072)
  activated = GELU(hidden)
  mlp_out   = activated @ W'_out[b,t].T              (B, T, 768)
  x = x + mlp_out

Layers 1, 2, 3: same

logits → y_adv: (B, T, vocab)
Update masks only:

loss = KL(y_target, y_adv)
grad = d(loss) / d(m_adv)    ← (B, T, 9728) per matrix
m_adv += lr × grad            ← gradient ASCENT
m_adv = clamp(m_adv, min=g, max=1)

ONLY m_adv changes. U, V, Δ and the causal importance MLPs are held fixed during this inner loop.

Step 4: Outer Training Loop (Two Forward Passes)

Two separate forward passes (never mixed):

Forward Pass A — Stochastic masks

For all 24 matrices: W'_l = U_l × m_stoch_l[b,t] @ V_l + Δ_l × m_delta_stoch
Full transformer → y_stoch: (B, T, vocab)

Forward Pass B — Adversarial masks (worst-case from Step 3)

For all 24 matrices: W'_l = U_l × m_adv_l[b,t] @ V_l + Δ_l × m_delta_adv
Full transformer → y_adv: (B, T, vocab)

Step 5: All 5 Losses

L1 = β₁ × KL(y_target, y_adv)              ← adversarial recon
L2 = β₂ × KL(y_target, y_stoch)            ← stochastic recon
L3 = β₃ × (1/BT) Σ|g|ᵖ                    ← importance minimality
L4 = β₄ × (1/BT) Σ|g|ᵖ·log₂(1+freq)      ← frequency minimality
L5 = β₅ × Σ‖W − U@V‖²                    ← delta L2

L_VPD = L1 + L2 + L3 + L4 + L5
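A sketch of L3, L4 and L5 for one matrix on hypothetical numbers, in plain Python. The text does not define freq here, so the code ASSUMES freq_c is the fraction of the B·T positions at which g[·, c] > 0; the real definition may differ. The β coefficients are left out (set to 1). Under this assumption, a subcomponent that is important at every position carries weight log₂(2) = 1 per unit of g, while one important at a quarter of positions carries only log₂(1.25) ≈ 0.32.

```python
import math


def importance_minimality(g, p):
    """(1/BT) * sum over b, t, c of |g[b,t,c]|^p; g is given as BT rows of length C."""
    return sum(abs(v) ** p for row in g for v in row) / len(g)


def frequency_minimality(g, p, threshold=0.0):
    """(1/BT) * sum of |g|^p * log2(1 + freq_c).
    ASSUMPTION: freq_c is the fraction of the BT positions where g[:, c] > threshold."""
    BT, C = len(g), len(g[0])
    freq = [sum(row[c] > threshold for row in g) / BT for c in range(C)]
    return sum(abs(row[c]) ** p * math.log2(1 + freq[c]) for row in g for c in range(C)) / BT


def delta_l2(delta):
    return sum(d * d for row in delta for d in row)


# Hypothetical toy: BT = 4 positions, C = 3 subcomponents of one matrix.
g = [
    [1.0, 0.0, 0.5],
    [1.0, 0.0, 0.0],
    [1.0, 0.0, 0.0],
    [1.0, 0.2, 0.0],
]
p = 1.0
L3 = importance_minimality(g, p)
L4 = frequency_minimality(g, p)
L5 = delta_l2([[0.01, -0.02], [0.0, 0.03]])
print(f"L3 = {L3:.3f}, L4 = {L4:.3f}, L5 = {L5:.4f}")
```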

Step 6: Backprop & Update

Gradient DESCENT on L_VPD updates:
  ✓  U_l, V_l for all 24 matrices         ← subcomponent vectors
  ✓  MLP weights for all subcomponents     ← causal importance predictors

Does NOT update:
  ✗  Original weights W_l                  ← frozen target
  ✗  Δ_l = W_l − U_l@V_l                  ← derived
  ✗  m_adv                                 ← inner search only

Recon pressure:     L_adv + L_stoch → g high
Sparsity pressure:  L_imp + L_freq  → g low
The two pull against each other; at equilibrium, g ≈ 1 only where genuinely needed.

Reference

SPD vs VPD

Aspect               SPD                            VPD
Subcomponents        Rank-1 per matrix              Same
Causal importance    Learned MLP, per batch element Same + per token position
Stochastic masking   ✓ (only type)                  ✓ (one of two)
Adversarial masking  ✗                              ✓ catches cancellations
Faithfulness         MSE: U@V ≈ W                   Δ = W − U@V (exact) + L2
Importance penalty   Σ|g|ᵖ                          Same
Frequency penalty    ✗                              ✓ Σ|g|ᵖ·log₂(1+freq)
Tested on            Toy models                     67M-param transformer
Feature splitting    Claimed absent                 Empirically confirmed

Core idea in one sentence

Decompose weight matrices into simple rank-1 pieces, learn which pieces each input needs, and verify, by removing the "unneeded" pieces both randomly and adversarially, that the output survives.

Reference guide — SPD (Bushnaq et al. 2025) · VPD (Bushnaq, Braun, Clive-Griffin et al. 2026)