Qwen3-0.6B Lambda Gates — Chat-Template Variants

Lambda gates trained with chat templates enabled (enable_thinking=False), for deployment where the base model is used with its chat interface. Four variants spanning two axes:

  1. Activation scaling after gating (scale_mode): mean vs. energy
  2. NKE retention loss (unmasked_retain_weight): 0.0, 0.05, 0.1

Variants

Folder             scale_mode  λ_f  λ_r  unmasked_retain_weight  Notes
chat_energy/       energy      0.1  0.5  0.0                     Energy rescale, no NKE
chat_energy_optA/  energy      0.1  0.5  0.05                    Energy + light NKE
chat_energy_optB/  energy      1.0  0.5  0.1                     Energy + strong forget + moderate NKE
chat_mean/         mean        0.1  0.5  0.0                     Mean rescale, no NKE

All variants share:

  • Base: Qwen/Qwen3-0.6B
  • Forget data: PopQA-mini entity-masked knowledge text
  • Reasoning data: NuminaMath-CoT (10k seed subset), chat-templated
  • β=4.0, distill T=2.0, forget_retain_ratio=1:2, lr=1e-2 cosine, 3 epochs, bf16, init_logit_std=0.1
  • use_chat_template=True, enable_thinking=False for forget/oracle/conflict evals

Gate Statistics (chat_energy)

Metric             Value
Total gates        86,016
Mean sigmoid gate  ≈ 0.500
Std                ≈ 0.03

Selected thresholds (per selected_thresholds.txt) target 5 / 25 / 50 / 75 / 95% off-fractions.
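The reported statistics are consistent with logits that stay near their init (sigmoid'(0) = 1/4, so N(0, 0.1) logits give gates with std ≈ 0.025). A sketch, using synthetic logits as a stand-in for `lambda_logits.pt`, of how the statistics and threshold off-fractions can be computed:

```python
import torch

# Synthetic stand-in for lambda_logits.pt: 86,016 logits, init_logit_std=0.1
torch.manual_seed(0)
logits = torch.randn(86_016) * 0.1

gates = torch.sigmoid(logits)
print(f"mean gate: {gates.mean():.3f}")  # centred at 0.5
print(f"std gate:  {gates.std():.3f}")   # roughly 0.1 / 4

# A threshold that switches off a target fraction of neurons is that
# fraction's quantile of the gate distribution
for target in (0.05, 0.25, 0.50, 0.75, 0.95):
    thr = torch.quantile(gates, target).item()
    off_frac = (gates < thr).float().mean().item()
    print(f"target {target:.0%}: threshold {thr:.4f}, off-fraction {off_frac:.2%}")
```

The trained checkpoints record these thresholds in selected_thresholds.txt, so they do not need to be recomputed at load time.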

Contents per variant

<variant>/
  lambda_logits.pt             # 86,016 per-neuron logits
  neuron_indices.json          # Knowledge neurons at threshold 0.5
  gate_stats.json              # Statistics + selected thresholds
  selected_thresholds.txt      # Comma-separated thresholds

Usage

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B", torch_dtype=torch.bfloat16)

# Apply chat template at inference
messages = [{"role": "user", "content": "What is 2+2?"}]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, enable_thinking=False)

# Load a variant and apply gates (see baseline README for full recipe)
gate_state = torch.load("chat_energy_optA/lambda_logits.pt", map_location="cpu")
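The full gating recipe lives in the baseline README. As a rough sketch only: assuming the 86,016 gates map row-major onto num_hidden_layers × intermediate_size (28 × 3,072 for Qwen3-0.6B) and act on the MLP intermediate activations, they could be applied with forward pre-hooks on each layer's down_proj — an unconfirmed layout guess, so verify against the baseline README before relying on it:

```python
import torch

def attach_gate_hooks(model, lambda_logits, scale_mode="energy"):
    """Gate MLP intermediate activations with per-neuron lambda gates.

    Assumes lambda_logits is laid out row-major as
    (num_hidden_layers, intermediate_size); this is a guess, not
    confirmed by the card.
    """
    n_layers = model.config.num_hidden_layers
    gates = torch.sigmoid(lambda_logits.float()).view(n_layers, -1)
    handles = []
    for layer, lam in zip(model.model.layers, gates):
        if scale_mode == "mean":
            scale = lam / lam.mean().clamp_min(1e-8)
        else:  # "energy"
            scale = lam / lam.pow(2).mean().clamp_min(1e-8).sqrt()

        def pre_hook(module, args, s=scale):
            (h,) = args            # input to down_proj: gated intermediate
            return (h * s.to(h.dtype),)

        handles.append(layer.mlp.down_proj.register_forward_pre_hook(pre_hook))
    return handles  # call handle.remove() on each to restore the base model
```

Hooks keep the base weights untouched, so a variant can be swapped in and out without reloading the model.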

Scale Mode

After gating the hidden activations h with λ = sigmoid(logits):

  • mean: h = (h · λ) / mean(λ) — rescale to preserve mean activation magnitude
  • energy: h = (h · λ) / sqrt(mean(λ²)) — rescale to preserve activation energy

The energy mode tends to be slightly more stable at high off-fractions.
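Both rescalings as tensor ops (a sketch; logits has shape (d,) and broadcasts over the trailing dimension of h):

```python
import torch

def rescale_gated(h: torch.Tensor, logits: torch.Tensor,
                  scale_mode: str = "energy") -> torch.Tensor:
    """Gate activations h (..., d) with lambda = sigmoid(logits) of shape (d,)."""
    lam = torch.sigmoid(logits)
    gated = h * lam
    if scale_mode == "mean":
        return gated / lam.mean()            # preserve mean activation magnitude
    return gated / lam.pow(2).mean().sqrt()  # preserve activation energy (RMS)
```

When all gates are equal, both modes reduce to the identity; they diverge only as the gate distribution spreads, with energy mode normalizing by the RMS of λ rather than its mean.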
