Context Merging: from Tokens to Entities and Concepts
This repo contains a minimal research pipeline that compresses input context for Qwen3 by grouping dependent subtokens early, then trains a small adapter to consume the grouped embeddings.
- prepare_dataset.py builds a local dataset of grouped embeddings from a base Qwen3 with a custom layer 0 that performs token grouping.
- train_custom_qwen3.py fine-tunes a customized Qwen3 that adds a small MLP adapter for grouped inputs, while freezing all weights except layer 0.
- inference_qwen3_merged.py runs end-to-end inference by first grouping with the base model, then generating with the trained model that understands grouped inputs. Includes perf metrics and estimated attention-memory savings.
How it works
- Layer-0 grouping at prefill: a custom decoder layer 0 computes attention on the full token sequence, clusters adjacent tokens using lightweight heuristics plus attention relations, then averages token vectors per group (see the sketch after this list). The grouped result is added back to a residual projection and saved as grouped_hidden_states.
- Dataset building: the dataset builder swaps in the custom layer 0, feeds formatted prompts, extracts the stored grouped_hidden_states, and serializes them together with target responses.
- Model training: the training model wraps Qwen3 with a GroupedInputMLPAdapter that processes the grouped embeddings during prefill. Only layer 0 and the adapter are trainable; embeddings, upper layers, final norm, and LM head are frozen. Prefill uses grouped_inputs as inputs_embeds, then generation proceeds with past_key_values.
- Inference: the inference runner loads two models, a grouping model with the custom layer 0 and your trained model. It reports token compression, timing, and memory usage. Savings are also estimated with a simple attention-cost proxy that scales with the square of the sequence length.
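For intuition, here is a minimal sketch of the grouping step. The boundary test, the attention criterion, and the threshold value are illustrative assumptions, not the repo's exact CustomQwen3Attention heuristics.

```python
# Minimal sketch of the layer-0 grouping idea: open a new group at whitespace
# boundaries unless the token attends strongly to its predecessor, then
# mean-pool the hidden states of each group. Heuristics and threshold are
# illustrative only.
import torch

def group_tokens(tokens: list[str],
                 hidden: torch.Tensor,        # [seq_len, hidden_dim]
                 attn: torch.Tensor,          # [seq_len, seq_len] attention weights
                 merge_threshold: float = 0.2) -> torch.Tensor:
    group_ids, current = [0], 0
    for i in range(1, len(tokens)):
        starts_word = tokens[i].startswith((" ", "\n", "Ġ", "Ċ"))  # BPE-style boundaries
        sticks_to_prev = attn[i, i - 1].item() > merge_threshold
        if starts_word and not sticks_to_prev:
            current += 1                      # token begins a new group
        group_ids.append(current)

    # Average the token vectors inside each group -> grouped hidden states.
    grouped = torch.zeros(current + 1, hidden.size(-1), dtype=hidden.dtype)
    counts = torch.zeros(current + 1, 1, dtype=hidden.dtype)
    for i, g in enumerate(group_ids):
        grouped[g] += hidden[i]
        counts[g] += 1
    return grouped / counts                   # [num_groups, hidden_dim]
```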
Requirements
- Python packages: torch, transformers, datasets, tqdm, psutil. These are imported directly in the scripts.
- GPU is optional. The scripts detect CUDA and set the dtype accordingly.
Install:
pip install torch transformers datasets tqdm psutil
Repository layout
- prepare_dataset.py - dataset builder using custom layer 0 grouping.
- train_custom_qwen3.py - trainer for grouped-input Qwen3 with an MLP adapter, freezing all but layer 0.
- inference_qwen3_merged.py - two-stage inference runner with metrics.
1 Build the local dataset
Run:
python prepare_dataset.py
Key defaults inside DatasetProcessor:
model_name="Qwen/Qwen3-0.6B"dataset_name="Magpie-Align/Magpie-Qwen2.5-Pro-1M-v0.1"output_dir="./processed_dataset"batch_size=1,max_samples=None,save_frequency=1000Edit these in the constructor if you need to change them.
The builder formats inputs using a simple system prompt template.
It tokenizes, runs layer 0 once per example, captures grouped_hidden_states, and buffers results.
Outputs under output_dir:
- processed_dataset.pkl - a list of samples with inputs_embeds (grouped), response, and metadata (a quick loading snippet follows below).
- Additional metadata and sample previews are written alongside for quick inspection.
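For a quick sanity check of the output, something like the following works; it assumes the default output_dir and the field names listed above, so adjust if your local layout differs.

```python
# Load the builder's pickle and look at one sample.
import pickle

with open("./processed_dataset/processed_dataset.pkl", "rb") as f:
    samples = pickle.load(f)

print(f"{len(samples)} samples")
sample = samples[0]
print(sample.keys())                  # expected: inputs_embeds, response, metadata fields
print(sample["inputs_embeds"].shape)  # grouped embeddings, roughly [num_groups, hidden_dim]
print(sample["response"][:200])       # target text used for training
```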
2 Train the grouped-input model
Run:
python train_custom_qwen3.py --mode train
Training config defaults (edit in the script if needed):
model_name="Qwen/Qwen3-0.6B"dataset_path="./processed_qwen3_dataset/processed_dataset.pkl"output_dir="./grouped_qwen3_checkpoint"batch_size=4,learning_rate=5e-4,num_epochs=3,warmup_steps=100- Logging, eval, and checkpoint cadence are configurable.
What is trained:
- A GroupedInputMLPAdapter that takes grouped embeddings and returns adapted embeddings, normalized with RMSNorm (a rough sketch follows after this list).
- Only layer 0 and this adapter are trainable; everything else is frozen.
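As a rough sketch, an adapter of this kind might look like the following; the width, depth, and naming are assumptions, not the repo's exact GroupedInputMLPAdapter.

```python
# Illustrative adapter over grouped embeddings: RMSNorm followed by a small MLP
# with a residual connection. Sizes and structure are assumptions.
import torch
import torch.nn as nn

class GroupedInputMLPAdapterSketch(nn.Module):
    def __init__(self, hidden_size: int, expansion: int = 4, eps: float = 1e-6):
        super().__init__()
        self.norm = nn.RMSNorm(hidden_size, eps=eps)  # torch >= 2.4
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, expansion * hidden_size),
            nn.SiLU(),
            nn.Linear(expansion * hidden_size, hidden_size),
        )

    def forward(self, grouped: torch.Tensor) -> torch.Tensor:
        # grouped: [batch, num_groups, hidden_size] from the layer-0 grouping step
        return grouped + self.mlp(self.norm(grouped))  # residual keeps the scale stable
```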
How targets are computed:
- Prefill: pass grouped_inputs via inputs_embeds with is_prefill=True.
- Then feed the target response tokens while reusing past_key_values (see the sketch after this list).
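Conceptually, the two-phase step looks like the sketch below. It is written against the standard transformers causal-LM interface rather than the repo's custom forward (which also takes is_prefill=True), and it skips the loss on the first response token for brevity.

```python
# Two-phase teacher forcing against a standard HF causal LM interface:
# prefill on grouped embeddings, then score the response with the cached KV.
import torch.nn.functional as F

def grouped_training_step(model, grouped_inputs, response_ids):
    # Phase 1: prefill on grouped embeddings only.
    prefill = model(inputs_embeds=grouped_inputs, use_cache=True)

    # Phase 2: feed the target response, reusing the prefill cache.
    out = model(input_ids=response_ids,
                past_key_values=prefill.past_key_values,
                use_cache=True)

    # Next-token loss over the response (the first response token, predicted
    # from the prefill logits, is skipped for brevity).
    logits = out.logits[:, :-1, :]
    labels = response_ids[:, 1:]
    return F.cross_entropy(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))
```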
Checkpoints contain model weights, config, and tokenizer in the epoch folder.
3 Run inference
Option A - standalone runner
Quick start:
python inference_qwen3_merged.py \
--checkpoint ./grouped_qwen3_checkpoint/epoch_2_best \
--grouping_model Qwen/Qwen3-0.6B \
--instruction "Explain attention like I am in 9th grade" \
--max_length 256 \
--temperature 0.7 \
--device cuda
CLI options: --checkpoint, --grouping_model, --instruction, --max_length, --temperature, --no_sample for greedy, and --device for cuda or cpu.
What it does:
- Loads a grouping model with the custom layer 0 and a trained inference model.
- Phase 1 groups tokens and reports compression. Phase 2 generates with the trained model.
- Reports compression ratio, memory reduction, total time, and tokens per second.
Option B - use the training script utilities
The trainer exposes helper functions for loading a trained model and running generation with grouped inputs. See load_trained_model and generate_with_grouped_input in the training script if you prefer a programmatic flow.
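A possible programmatic flow is sketched below; the helper signatures are assumptions, so check train_custom_qwen3.py for the actual arguments.

```python
# Hypothetical usage of the training-script helpers.
from train_custom_qwen3 import load_trained_model, generate_with_grouped_input

model, tokenizer = load_trained_model("./grouped_qwen3_checkpoint/epoch_2_best")
text = generate_with_grouped_input(
    model,
    tokenizer,
    instruction="Explain attention like I am in 9th grade",
    max_length=256,
    temperature=0.7,
)
print(text)
```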
Parameters - quick reference
Dataset builder
- model_name - base HF model for grouping, default Qwen/Qwen3-0.6B.
- dataset_name - source HF dataset, default Magpie-Align/Magpie-Qwen2.5-Pro-1M-v0.1.
- output_dir - where pickles and metadata go.
- max_samples - optional cap for quick tests.
Training
- dataset_path - path to processed_dataset.pkl.
- output_dir - where checkpoints are written.
- batch_size, learning_rate, num_epochs, warmup_steps - training hyperparameters.
- Only layer 0 and the adapter are trainable. Verify with the requires_grad settings in _freeze_layers (a helper sketch follows after this list).
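A generic helper like this can confirm the freezing, independent of the repo's module names:

```python
# Print which parameters will actually receive gradients; after freezing,
# only layer 0 and the adapter should appear.
def print_trainable(model):
    trainable, total = 0, 0
    for name, param in model.named_parameters():
        total += param.numel()
        if param.requires_grad:
            trainable += param.numel()
            print("trainable:", name)
    print(f"{trainable}/{total} parameters trainable ({100.0 * trainable / total:.2f}%)")
```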
Inference
- --checkpoint - path to the trained checkpoint folder.
- --grouping_model - HF model name used for grouping.
- --instruction - user prompt, any language.
- --max_length, --temperature, --no_sample, --device.
Notes
- The custom layer 0 is installed by copying weights from the original layer 0, then replacing the module so it can compute groups and cache the grouped states (a sketch follows after this list).
- Grouping relies on simple rules over tokens, such as space and newline boundaries, plus attention relations. You can tune the threshold in CustomQwen3Attention.
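The swap can be sketched as follows, assuming a recent transformers release with Qwen3 support; GroupingDecoderLayerStub is a placeholder and does not implement the grouping itself.

```python
# Sketch of the layer-0 swap: subclass the stock decoder layer, copy the
# original layer-0 weights into it, then replace the module in place.
from transformers import AutoModelForCausalLM
from transformers.models.qwen3.modeling_qwen3 import Qwen3DecoderLayer

class GroupingDecoderLayerStub(Qwen3DecoderLayer):
    def forward(self, hidden_states, **kwargs):
        outputs = super().forward(hidden_states, **kwargs)
        # ...the real custom layer clusters tokens here and caches
        # grouped_hidden_states for the dataset builder to read back...
        return outputs

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B")
custom_layer0 = GroupingDecoderLayerStub(model.config, layer_idx=0)
custom_layer0.load_state_dict(model.model.layers[0].state_dict())  # copy original weights
model.model.layers[0] = custom_layer0
```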
Troubleshooting
- CUDA memory spikes: reduce the batch size during training or use fewer samples. Generation is incremental and reuses past_key_values.
- No grouped states found: ensure the custom layer 0 is in use and is_initialized is reset before each prefill.
- Checkpoint not found: the inference loader expects pytorch_model.bin or model.safetensors in the checkpoint directory.
Why this can save memory
If the sequence shrinks from N to G groups, attention memory scales roughly with G^2 vs N^2. The script prints an estimated savings based on that relation.
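The relation can be expressed as a one-liner; this mirrors the description above, and the script's exact formula may differ.

```python
# Quadratic attention-cost proxy: savings ~ 1 - (G/N)^2.
def estimated_attention_savings(num_tokens: int, num_groups: int) -> float:
    return 1.0 - (num_groups / num_tokens) ** 2

# Example: 1000 tokens compressed into 400 groups -> 84.0% estimated savings.
print(f"{estimated_attention_savings(1000, 400):.1%}")
```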
Citation
@misc{Kolomeitsev2025ContextMerging,
title = {Context Merging: from Tokens to Entities and Concepts},
author = {Konstantin Kolomeitsev},
year = {2025}
}
Contact
If you have any questions, please open an issue or contact me at [email protected].