Qwen3 Lambda Gates — Knowledge/Reasoning Disentanglement
Per-neuron sigmoid gates on Qwen3 FFN neurons to disentangle factual knowledge from reasoning.
Lambda gates trained with chat templates enabled (enable_thinking=False), for deployment where the base model is used with its chat interface. Four variants spanning two axes:
- `scale_mode`: mean vs. energy
- `unmasked_retain_weight`: 0.0, 0.05, 0.1

| Folder | scale_mode | λ_f | λ_r | unmasked_retain_weight | Notes |
|---|---|---|---|---|---|
| `chat_energy/` | energy | 0.1 | 0.5 | 0.0 | Energy rescale, no NKE |
| `chat_energy_optA/` | energy | 0.1 | 0.5 | 0.05 | Energy + light NKE |
| `chat_energy_optB/` | energy | 1.0 | 0.5 | 0.1 | Energy + strong forget + moderate NKE |
| `chat_mean/` | mean | 0.1 | 0.5 | 0.0 | Mean rescale, no NKE |
All variants share:

- Base model: `Qwen/Qwen3-0.6B`
- β=4.0, distill T=2.0, forget_retain_ratio=1:2
- lr=1e-2 cosine, 3 epochs, bf16, init_logit_std=0.1
- Evals: use_chat_template=True, enable_thinking=False for forget/oracle/conflict evals

| Metric | Value |
|---|---|
| Total gates | 86,016 |
| Mean sigmoid gate | ≈ 0.500 |
| Std | ≈ 0.03 |
Selected thresholds (per selected_thresholds.txt) target 5 / 25 / 50 / 75 / 95% off-fractions.
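The shipped `selected_thresholds.txt` files are authoritative, but thresholds of this kind can be recovered as quantiles of the gate distribution. A minimal sketch, with a synthetic logits tensor standing in for a variant's `lambda_logits.pt`:

```python
import torch

torch.manual_seed(0)
# Stand-in for a variant's lambda_logits.pt (86,016 per-neuron logits).
logits = torch.randn(86_016) * 0.1
gates = torch.sigmoid(logits)

# For each target off-fraction, the threshold is the corresponding
# quantile of the gate distribution: gates below it are switched off.
for frac in (0.05, 0.25, 0.50, 0.75, 0.95):
    thr = torch.quantile(gates, frac)
    off = (gates < thr).float().mean()
    print(f"target off-fraction {frac:.0%}: threshold {thr:.4f}, actual off {off:.2%}")
```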
```
<variant>/
  lambda_logits.pt         # 86,016 per-neuron logits
  neuron_indices.json      # Knowledge neurons at threshold 0.5
  gate_stats.json          # Statistics + selected thresholds
  selected_thresholds.txt  # Comma-separated thresholds
```
```python
import torch, json
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B", torch_dtype=torch.bfloat16)

# Apply chat template at inference
messages = [{"role": "user", "content": "What is 2+2?"}]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, enable_thinking=False)

# Load a variant and apply gates (see baseline README for full recipe)
gate_state = torch.load("chat_energy_optA/lambda_logits.pt", map_location="cpu")
```
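The baseline README has the full recipe for applying the loaded logits. Purely as an illustrative sketch, one way to wire λ into the model is a forward pre-hook on each layer's `down_proj`, assuming a layer-major layout of the 86,016 logits over 28 layers × 3,072 FFN neurons (both the layout and the hook placement are assumptions here, not the repo's documented recipe):

```python
import torch

def apply_gates(model, gate_logits, scale_mode="energy"):
    """Sketch: gate SwiGLU activations just before each down_proj."""
    lam = torch.sigmoid(gate_logits.float())
    num_layers = model.config.num_hidden_layers
    # Assumed layout: [num_layers, intermediate_size] = [28, 3072] for Qwen3-0.6B.
    lam = lam.view(num_layers, -1)
    handles = []
    for layer_idx, layer in enumerate(model.model.layers):
        g = lam[layer_idx]
        if scale_mode == "mean":
            scale = g.mean()                 # preserve mean activation magnitude
        else:
            scale = g.pow(2).mean().sqrt()   # preserve activation energy
        g = (g / scale).to(model.dtype)

        def hook(module, args, g=g):
            # Multiply the down_proj input (gated FFN activations) by λ/scale.
            return (args[0] * g,)

        handles.append(layer.mlp.down_proj.register_forward_pre_hook(hook))
    return handles  # call h.remove() on each handle to undo the gating
```

A cheap sanity check: with all logits at 0, every gate is 0.5 and the energy scale is 0.5, so the hooks reduce to the identity.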
After gating the hidden activations `h` with `λ = sigmoid(logits)`:

- **mean**: `h = (h · λ) / mean(λ)` — rescale to preserve mean activation magnitude
- **energy**: `h = (h · λ) / sqrt(mean(λ²))` — rescale to preserve activation energy

The energy mode tends to be slightly more stable at high off-fractions.
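As a quick numeric sanity check of the energy rescale (synthetic `h` and `λ`, not the trained gates), dividing by `sqrt(mean(λ²))` approximately restores `mean(h²)`:

```python
import torch

torch.manual_seed(0)
h = torch.randn(100_000)                          # stand-in hidden activations
lam = torch.sigmoid(torch.randn(100_000) * 0.1)   # synthetic gates, ~0.5 ± 0.025

h_energy = (h * lam) / lam.pow(2).mean().sqrt()   # energy rescale

print(h.pow(2).mean().item(), h_energy.pow(2).mean().item())  # nearly equal
```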