Upload folder using huggingface_hub
- .gitattributes +36 -35
- LICENSE +21 -0
- README.md +192 -3
- grouped_qwen3_checkpoint/epoch_2_best/added_tokens.json +28 -0
- grouped_qwen3_checkpoint/epoch_2_best/chat_template.jinja +89 -0
- grouped_qwen3_checkpoint/epoch_2_best/config.json +60 -0
- grouped_qwen3_checkpoint/epoch_2_best/epoch_metadata.json +9 -0
- grouped_qwen3_checkpoint/epoch_2_best/merges.txt +0 -0
- grouped_qwen3_checkpoint/epoch_2_best/pytorch_model.bin +3 -0
- grouped_qwen3_checkpoint/epoch_2_best/special_tokens_map.json +31 -0
- grouped_qwen3_checkpoint/epoch_2_best/tokenizer.json +3 -0
- grouped_qwen3_checkpoint/epoch_2_best/tokenizer_config.json +239 -0
- grouped_qwen3_checkpoint/epoch_2_best/vocab.json +0 -0
- inference_qwen3_merged.py +793 -0
- prepare_dataset.py +652 -0
- processed_qwen3_dataset/metadata.json +11 -0
- processed_qwen3_dataset/processed_dataset.pkl +3 -0
- processed_qwen3_dataset/samples.json +92 -0
- train_custom_qwen3.py +1292 -0
.gitattributes
CHANGED
@@ -1,35 +1,36 @@
-*.7z filter=lfs diff=lfs merge=lfs -text
-*.arrow filter=lfs diff=lfs merge=lfs -text
-*.bin filter=lfs diff=lfs merge=lfs -text
-*.bz2 filter=lfs diff=lfs merge=lfs -text
-*.ckpt filter=lfs diff=lfs merge=lfs -text
-*.ftz filter=lfs diff=lfs merge=lfs -text
-*.gz filter=lfs diff=lfs merge=lfs -text
-*.h5 filter=lfs diff=lfs merge=lfs -text
-*.joblib filter=lfs diff=lfs merge=lfs -text
-*.lfs.* filter=lfs diff=lfs merge=lfs -text
-*.mlmodel filter=lfs diff=lfs merge=lfs -text
-*.model filter=lfs diff=lfs merge=lfs -text
-*.msgpack filter=lfs diff=lfs merge=lfs -text
-*.npy filter=lfs diff=lfs merge=lfs -text
-*.npz filter=lfs diff=lfs merge=lfs -text
-*.onnx filter=lfs diff=lfs merge=lfs -text
-*.ot filter=lfs diff=lfs merge=lfs -text
-*.parquet filter=lfs diff=lfs merge=lfs -text
-*.pb filter=lfs diff=lfs merge=lfs -text
-*.pickle filter=lfs diff=lfs merge=lfs -text
-*.pkl filter=lfs diff=lfs merge=lfs -text
-*.pt filter=lfs diff=lfs merge=lfs -text
-*.pth filter=lfs diff=lfs merge=lfs -text
-*.rar filter=lfs diff=lfs merge=lfs -text
-*.safetensors filter=lfs diff=lfs merge=lfs -text
-saved_model/**/* filter=lfs diff=lfs merge=lfs -text
-*.tar.* filter=lfs diff=lfs merge=lfs -text
-*.tar filter=lfs diff=lfs merge=lfs -text
-*.tflite filter=lfs diff=lfs merge=lfs -text
-*.tgz filter=lfs diff=lfs merge=lfs -text
-*.wasm filter=lfs diff=lfs merge=lfs -text
-*.xz filter=lfs diff=lfs merge=lfs -text
-*.zip filter=lfs diff=lfs merge=lfs -text
-*.zst filter=lfs diff=lfs merge=lfs -text
-*tfevents* filter=lfs diff=lfs merge=lfs -text
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tar filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
+grouped_qwen3_checkpoint/epoch_2_best/tokenizer.json filter=lfs diff=lfs merge=lfs -text
LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2025 Konstantin Kolomeitsev

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
README.md
CHANGED
@@ -1,3 +1,192 @@

# Context Merging: from Tokens to Entities and Concepts

This repo contains a minimal research pipeline that compresses input context for Qwen3 by grouping dependent subtokens early, then trains a small adapter to consume the grouped embeddings.

- `prepare_dataset.py` builds a local dataset of grouped embeddings from a base Qwen3 with a custom layer 0 that performs token grouping.
- `train_custom_qwen3.py` fine-tunes a customized Qwen3 that adds a small MLP adapter for grouped inputs, while freezing all weights except layer 0.
- `inference_qwen3_merged.py` runs end-to-end inference by first grouping with the base model, then generating with the trained model that understands grouped inputs. It also reports performance metrics and an estimate of attention-memory savings.

---

## How it works

1. **Layer-0 grouping at prefill**
   A custom decoder layer 0 computes attention on the full token sequence, clusters adjacent tokens using lightweight heuristics plus attention relations, then averages the token vectors in each group (see the sketch after this list). The grouped result is added back to a residual projection and saved as `grouped_hidden_states`.

2. **Dataset building**
   The dataset builder swaps in the custom layer 0, feeds formatted prompts, extracts the stored `grouped_hidden_states`, and serializes them together with target responses.

3. **Model training**
   The training model wraps Qwen3 with a **GroupedInputMLPAdapter** that processes the grouped embeddings during prefill. Only layer 0 and the adapter are trainable; embeddings, upper layers, the final norm, and the LM head are frozen. Prefill uses `grouped_inputs` as `inputs_embeds`, then generation proceeds with past key values.

4. **Inference**
   The inference runner loads two models: a grouping model with the custom layer 0, and your trained model. It reports token compression, timing, and memory usage. Savings are also estimated with a simple attention-cost proxy that scales with the square of the sequence length.

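The grouping-and-averaging step can be pictured with a tiny sketch. This is illustrative only: `hidden` and `groups` are made-up inputs here, and the real groups are derived from layer-0 attention plus the token heuristics in `inference_qwen3_merged.py`.

```python
# Illustrative sketch: average layer-0 token vectors inside each group.
import torch

hidden = torch.randn(1, 12, 1024)                          # [batch, tokens, hidden]
groups = [[0], [1, 2, 3], [4, 5], [6, 7, 8, 9], [10, 11]]  # adjacent subtoken groups (assumed)

grouped = torch.stack(
    [hidden[:, idxs, :].mean(dim=1) for idxs in groups], dim=1
)                                                           # [batch, num_groups, hidden]
print(hidden.shape, "->", grouped.shape)                    # 12 tokens compressed to 5 groups
```
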
---

## Requirements

- Python packages: `torch`, `transformers`, `datasets`, `tqdm`, `psutil`. These are imported directly in the scripts.
- GPU is optional. The scripts detect CUDA and set the dtype accordingly.

Install:

```bash
pip install torch transformers datasets tqdm psutil
```

---

## Repository layout

- `prepare_dataset.py` - dataset builder using custom layer 0 grouping.
- `train_custom_qwen3.py` - trainer for grouped-input Qwen3 with an MLP adapter, freezing all but layer 0.
- `inference_qwen3_merged.py` - two-stage inference runner with metrics.

---

## 1 Build the local dataset

Run:

```bash
python prepare_dataset.py
```

Key defaults inside `DatasetProcessor`:

- `model_name="Qwen/Qwen3-0.6B"`
- `dataset_name="Magpie-Align/Magpie-Qwen2.5-Pro-1M-v0.1"`
- `output_dir="./processed_dataset"`
- `batch_size=1`, `max_samples=None`, `save_frequency=1000`

Edit these in the constructor if you need to change them, as in the sketch below.

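For a quick test run the call might look like the following. The keyword arguments mirror the defaults listed above; passing them this way is an assumption about the `DatasetProcessor` API, so check `prepare_dataset.py` for the actual constructor and entry point.

```python
# Hypothetical usage; argument names follow the defaults listed above.
from prepare_dataset import DatasetProcessor

processor = DatasetProcessor(
    model_name="Qwen/Qwen3-0.6B",
    dataset_name="Magpie-Align/Magpie-Qwen2.5-Pro-1M-v0.1",
    output_dir="./processed_qwen3_dataset",  # matches the path the trainer expects by default
    max_samples=1000,                        # small cap for a quick test
)
```
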
The builder formats inputs using a simple system prompt template. It tokenizes, runs layer 0 once per example, captures `grouped_hidden_states`, and buffers the results.

**Outputs** under `output_dir`:

- `processed_dataset.pkl` - a list of samples with `inputs_embeds` (grouped), `response`, and metadata (see the loading sketch below).
- Additional metadata and sample previews are written alongside for quick inspection.

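To sanity-check the output you can load the pickle directly. The field names below are the ones listed above; the exact per-sample keys and tensor layout are assumptions, so adjust to what `prepare_dataset.py` actually writes.

```python
import pickle

with open("./processed_qwen3_dataset/processed_dataset.pkl", "rb") as f:
    samples = pickle.load(f)

print(len(samples), "samples")
first = samples[0]
print(type(first["inputs_embeds"]), first["inputs_embeds"].shape)  # grouped embeddings tensor
print(first["response"][:80])                                      # start of the target text
```
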
---

## 2 Train the grouped-input model

Run:

```bash
python train_custom_qwen3.py --mode train
```

Training config defaults (edit in the script if needed):

- `model_name="Qwen/Qwen3-0.6B"`
- `dataset_path="./processed_qwen3_dataset/processed_dataset.pkl"`
- `output_dir="./grouped_qwen3_checkpoint"`
- `batch_size=4`, `learning_rate=5e-4`, `num_epochs=3`, `warmup_steps=100`
- Logging, eval, and checkpoint cadence are configurable.

What is trained:

- A **GroupedInputMLPAdapter** that takes grouped embeddings and returns adapted embeddings, normalized with RMSNorm (see the condensed definition below).
- Only layer 0 and this adapter are trainable; everything else is frozen.

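For reference, this is the adapter as defined in `inference_qwen3_merged.py`; the training script is assumed to define the same module. It is a residual two-layer MLP followed by RMSNorm:

```python
import torch.nn as nn
from transformers.models.qwen3.modeling_qwen3 import Qwen3RMSNorm

class GroupedInputMLPAdapter(nn.Module):
    def __init__(self, config):
        super().__init__()
        h = config.hidden_size
        # Two-layer MLP with SiLU and dropout, applied to the grouped embeddings.
        self.grouped_processor = nn.Sequential(
            nn.Linear(h, h * 2), nn.SiLU(), nn.Dropout(0.1),
            nn.Linear(h * 2, h), nn.Dropout(0.1),
        )
        self.layer_norm = Qwen3RMSNorm(h, eps=getattr(config, "rms_norm_eps", 1e-6))

    def forward(self, grouped_embeds):
        # Residual connection around the MLP, then RMSNorm.
        return self.layer_norm(grouped_embeds + self.grouped_processor(grouped_embeds))
```
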
How targets are computed (a two-phase sketch follows):

- Prefill: pass `grouped_inputs` via `inputs_embeds` with `is_prefill=True`.
- Then feed the target response tokens while reusing `past_key_values`.

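A minimal sketch of those two phases, assuming `model` is the trainer's `CustomQwen3ForCausalLM`, `sample` is one entry from the pickle, and ignoring batching and loss-masking details; the actual trainer may compute the loss differently.

```python
grouped = sample["inputs_embeds"].unsqueeze(0)   # [1, num_groups, hidden] (assumed layout)

# Phase 1: prefill on grouped embeddings; the adapter is applied inside forward().
out = model(grouped_inputs=grouped, is_prefill=True, use_cache=True)

# Phase 2: teacher-forced pass over the target response, reusing the grouped cache.
response_ids = tokenizer(sample["response"], return_tensors="pt").input_ids
out = model(input_ids=response_ids, past_key_values=out.past_key_values,
            labels=response_ids, use_cache=True)
loss = out.loss
```
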
Checkpoints contain the model weights, config, and tokenizer in the epoch folder.

---

## 3 Run inference

### Option A - standalone runner

Quick start:

```bash
python inference_qwen3_merged.py \
  --checkpoint ./grouped_qwen3_checkpoint/epoch_2_best \
  --grouping_model Qwen/Qwen3-0.6B \
  --instruction "Explain attention like I am in 9th grade" \
  --max_length 256 \
  --temperature 0.7 \
  --device cuda
```

CLI options: `--checkpoint`, `--grouping_model`, `--instruction`, `--max_length`, `--temperature`, `--no_sample` for greedy decoding, and `--device` to select cuda or cpu.

What it does:

- Loads a grouping model with the custom layer 0 and a trained inference model.
- Phase 1 groups tokens and reports compression. Phase 2 generates with the trained model.
- Reports compression ratio, memory reduction, total time, and tokens per second.

### Option B - use the training script utilities

The trainer exposes helper functions for loading a trained model and running generation with grouped inputs. See `load_trained_model` and `generate_with_grouped_input` in the training script if you prefer a programmatic flow; a hedged sketch follows.

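The sketch below only illustrates the flow; the argument names and return values of these helpers are assumptions, so check their definitions in `train_custom_qwen3.py` before copying this.

```python
# Hypothetical usage of the trainer's helpers; signatures are assumptions.
from train_custom_qwen3 import load_trained_model, generate_with_grouped_input

model, tokenizer = load_trained_model("./grouped_qwen3_checkpoint/epoch_2_best")
text = generate_with_grouped_input(
    model, tokenizer,
    instruction="Explain attention like I am in 9th grade",
    max_length=256,
)
print(text)
```
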
---

## Parameters - quick reference

### Dataset builder

- `model_name` - base HF model for grouping, default `Qwen/Qwen3-0.6B`.
- `dataset_name` - source HF dataset, default `Magpie-Align/Magpie-Qwen2.5-Pro-1M-v0.1`.
- `output_dir` - where pickles and metadata go.
- `max_samples` - optional cap for quick tests.

### Training

- `dataset_path` - path to `processed_dataset.pkl`.
- `output_dir` - where checkpoints are written.
- `batch_size`, `learning_rate`, `num_epochs`, `warmup_steps` - training hyperparameters.
- Only layer 0 and the adapter are trainable. Verify with the `requires_grad` settings in `_freeze_layers`.

### Inference

- `--checkpoint` - path to the trained checkpoint folder.
- `--grouping_model` - HF model name used for grouping.
- `--instruction` - user prompt, any language.
- `--max_length`, `--temperature`, `--no_sample`, `--device`.

---

## Notes

- The custom layer 0 is installed by copying weights from the original layer 0, then replacing the module so it can compute groups and cache the grouped states.
- Grouping relies on simple rules over tokens, such as space and newline boundaries, plus attention relations. You can tune the threshold in `CustomQwen3Attention`, as shown below.

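For example, after building the grouping model you can change the attention-weight threshold on the replaced layer (`0.1` is the default set in `CustomQwen3Attention`); it controls which attention links count as relations when group boundaries are decided.

```python
from inference_qwen3_merged import create_grouping_model

grouping_model, tokenizer = create_grouping_model("Qwen/Qwen3-0.6B")
# Layer 0 has been replaced by the custom grouping layer; tune its relation threshold.
grouping_model.model.layers[0].self_attn.threshold = 0.05  # default is 0.1
```
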
---

## Troubleshooting

- **CUDA memory spikes**: reduce the batch size during training or use fewer samples. Generation is incremental and reuses past key values.
- **No grouped states found**: ensure the custom layer 0 is used and that `is_initialized` is reset before each prefill (see the snippet below).
- **Checkpoint not found**: the inference loader expects `pytorch_model.bin` or `model.safetensors` in the checkpoint directory.

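If you drive the grouping model manually, reset the layer-0 state between prompts, for example:

```python
layer0 = grouping_model.model.layers[0]
layer0.is_initialized = False                        # let the next prefill re-group
layer0.grouped_hidden_states = None                  # drop the previous prompt's groups
layer0.self_attn.set_current_input_ids(input_ids)   # token ids of the new prompt
```
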
---

## Why this can save memory

If the sequence shrinks from `N` tokens to `G` groups, attention memory scales roughly with `G^2` instead of `N^2`, so the estimated saving is `1 - (G/N)^2`. The script prints an estimate based on that relation.

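A worked example of that proxy:

```python
N, G = 1000, 600                  # original tokens vs. groups after merging
saving = 1 - (G / N) ** 2
print(f"estimated attention-memory saving: {saving:.0%}")   # 64%
```
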
---

## Citation

```
@misc{Kolomeitsev2025ContextMerging,
  title  = {Context Merging: from Tokens to Entities and Concepts},
  author = {Konstantin Kolomeitsev},
  year   = {2025}
}
```

## Contact

If you have any questions, please raise an issue or contact me at [[email protected]]([email protected]).
grouped_qwen3_checkpoint/epoch_2_best/added_tokens.json
ADDED
@@ -0,0 +1,28 @@
{
  "</think>": 151668,
  "</tool_call>": 151658,
  "</tool_response>": 151666,
  "<think>": 151667,
  "<tool_call>": 151657,
  "<tool_response>": 151665,
  "<|box_end|>": 151649,
  "<|box_start|>": 151648,
  "<|endoftext|>": 151643,
  "<|file_sep|>": 151664,
  "<|fim_middle|>": 151660,
  "<|fim_pad|>": 151662,
  "<|fim_prefix|>": 151659,
  "<|fim_suffix|>": 151661,
  "<|im_end|>": 151645,
  "<|im_start|>": 151644,
  "<|image_pad|>": 151655,
  "<|object_ref_end|>": 151647,
  "<|object_ref_start|>": 151646,
  "<|quad_end|>": 151651,
  "<|quad_start|>": 151650,
  "<|repo_name|>": 151663,
  "<|video_pad|>": 151656,
  "<|vision_end|>": 151653,
  "<|vision_pad|>": 151654,
  "<|vision_start|>": 151652
}
grouped_qwen3_checkpoint/epoch_2_best/chat_template.jinja
ADDED
@@ -0,0 +1,89 @@
{%- if tools %}
    {{- '<|im_start|>system\n' }}
    {%- if messages[0].role == 'system' %}
        {{- messages[0].content + '\n\n' }}
    {%- endif %}
    {{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
    {%- for tool in tools %}
        {{- "\n" }}
        {{- tool | tojson }}
    {%- endfor %}
    {{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
{%- else %}
    {%- if messages[0].role == 'system' %}
        {{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }}
    {%- endif %}
{%- endif %}
{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
{%- for message in messages[::-1] %}
    {%- set index = (messages|length - 1) - loop.index0 %}
    {%- if ns.multi_step_tool and message.role == "user" and message.content is string and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}
        {%- set ns.multi_step_tool = false %}
        {%- set ns.last_query_index = index %}
    {%- endif %}
{%- endfor %}
{%- for message in messages %}
    {%- if message.content is string %}
        {%- set content = message.content %}
    {%- else %}
        {%- set content = '' %}
    {%- endif %}
    {%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
        {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }}
    {%- elif message.role == "assistant" %}
        {%- set reasoning_content = '' %}
        {%- if message.reasoning_content is string %}
            {%- set reasoning_content = message.reasoning_content %}
        {%- else %}
            {%- if '</think>' in content %}
                {%- set reasoning_content = content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
                {%- set content = content.split('</think>')[-1].lstrip('\n') %}
            {%- endif %}
        {%- endif %}
        {%- if loop.index0 > ns.last_query_index %}
            {%- if loop.last or (not loop.last and reasoning_content) %}
                {{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' + content.lstrip('\n') }}
            {%- else %}
                {{- '<|im_start|>' + message.role + '\n' + content }}
            {%- endif %}
        {%- else %}
            {{- '<|im_start|>' + message.role + '\n' + content }}
        {%- endif %}
        {%- if message.tool_calls %}
            {%- for tool_call in message.tool_calls %}
                {%- if (loop.first and content) or (not loop.first) %}
                    {{- '\n' }}
                {%- endif %}
                {%- if tool_call.function %}
                    {%- set tool_call = tool_call.function %}
                {%- endif %}
                {{- '<tool_call>\n{"name": "' }}
                {{- tool_call.name }}
                {{- '", "arguments": ' }}
                {%- if tool_call.arguments is string %}
                    {{- tool_call.arguments }}
                {%- else %}
                    {{- tool_call.arguments | tojson }}
                {%- endif %}
                {{- '}\n</tool_call>' }}
            {%- endfor %}
        {%- endif %}
        {{- '<|im_end|>\n' }}
    {%- elif message.role == "tool" %}
        {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
            {{- '<|im_start|>user' }}
        {%- endif %}
        {{- '\n<tool_response>\n' }}
        {{- content }}
        {{- '\n</tool_response>' }}
        {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
            {{- '<|im_end|>\n' }}
        {%- endif %}
    {%- endif %}
{%- endfor %}
{%- if add_generation_prompt %}
    {{- '<|im_start|>assistant\n' }}
    {%- if enable_thinking is defined and enable_thinking is false %}
        {{- '<think>\n\n</think>\n\n' }}
    {%- endif %}
{%- endif %}
grouped_qwen3_checkpoint/epoch_2_best/config.json
ADDED
@@ -0,0 +1,60 @@
{
  "architectures": [
    "Qwen3ForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 151643,
  "eos_token_id": 151645,
  "head_dim": 128,
  "hidden_act": "silu",
  "hidden_size": 1024,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_types": [
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention"
  ],
  "max_position_embeddings": 40960,
  "max_window_layers": 28,
  "model_type": "qwen3",
  "num_attention_heads": 16,
  "num_hidden_layers": 28,
  "num_key_value_heads": 8,
  "rms_norm_eps": 1e-06,
  "rope_scaling": null,
  "rope_theta": 1000000,
  "sliding_window": null,
  "tie_word_embeddings": true,
  "torch_dtype": "float32",
  "transformers_version": "4.55.0",
  "use_cache": true,
  "use_sliding_window": false,
  "vocab_size": 151936
}
grouped_qwen3_checkpoint/epoch_2_best/epoch_metadata.json
ADDED
@@ -0,0 +1,9 @@
{
  "epoch": 2,
  "global_step": 27000,
  "model_name": "Qwen/Qwen3-0.6B",
  "learning_rate": 0.0005,
  "batch_size": 1,
  "is_best": true,
  "model_class": "CustomQwen3ForCausalLM"
}
grouped_qwen3_checkpoint/epoch_2_best/merges.txt
ADDED
The diff for this file is too large to render. See raw diff.
grouped_qwen3_checkpoint/epoch_2_best/pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:db3c76f1ed2b27fe53e699354e4b628d09c75da48a5f1371058542645a525b5a
size 2401122294
grouped_qwen3_checkpoint/epoch_2_best/special_tokens_map.json
ADDED
@@ -0,0 +1,31 @@
{
  "additional_special_tokens": [
    "<|im_start|>",
    "<|im_end|>",
    "<|object_ref_start|>",
    "<|object_ref_end|>",
    "<|box_start|>",
    "<|box_end|>",
    "<|quad_start|>",
    "<|quad_end|>",
    "<|vision_start|>",
    "<|vision_end|>",
    "<|vision_pad|>",
    "<|image_pad|>",
    "<|video_pad|>"
  ],
  "eos_token": {
    "content": "<|im_end|>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "pad_token": {
    "content": "<|endoftext|>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  }
}
grouped_qwen3_checkpoint/epoch_2_best/tokenizer.json
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:352a863cd2761388ccc58f1432467ba6a1037bf12df9069889b142fa246471f6
size 11422752
grouped_qwen3_checkpoint/epoch_2_best/tokenizer_config.json
ADDED
@@ -0,0 +1,239 @@
{
  "add_bos_token": false,
  "add_prefix_space": false,
  "added_tokens_decoder": {
    "151643": {
      "content": "<|endoftext|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "151644": {
      "content": "<|im_start|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "151645": {
      "content": "<|im_end|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "151646": {
      "content": "<|object_ref_start|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "151647": {
      "content": "<|object_ref_end|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "151648": {
      "content": "<|box_start|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "151649": {
      "content": "<|box_end|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "151650": {
      "content": "<|quad_start|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "151651": {
      "content": "<|quad_end|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "151652": {
      "content": "<|vision_start|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "151653": {
      "content": "<|vision_end|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "151654": {
      "content": "<|vision_pad|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "151655": {
      "content": "<|image_pad|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "151656": {
      "content": "<|video_pad|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "151657": {
      "content": "<tool_call>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "151658": {
      "content": "</tool_call>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "151659": {
      "content": "<|fim_prefix|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "151660": {
      "content": "<|fim_middle|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "151661": {
      "content": "<|fim_suffix|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "151662": {
      "content": "<|fim_pad|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "151663": {
      "content": "<|repo_name|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "151664": {
      "content": "<|file_sep|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "151665": {
      "content": "<tool_response>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "151666": {
      "content": "</tool_response>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "151667": {
      "content": "<think>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "151668": {
      "content": "</think>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": false
    }
  },
  "additional_special_tokens": [
    "<|im_start|>",
    "<|im_end|>",
    "<|object_ref_start|>",
    "<|object_ref_end|>",
    "<|box_start|>",
    "<|box_end|>",
    "<|quad_start|>",
    "<|quad_end|>",
    "<|vision_start|>",
    "<|vision_end|>",
    "<|vision_pad|>",
    "<|image_pad|>",
    "<|video_pad|>"
  ],
  "bos_token": null,
  "clean_up_tokenization_spaces": false,
  "eos_token": "<|im_end|>",
  "errors": "replace",
  "extra_special_tokens": {},
  "model_max_length": 131072,
  "pad_token": "<|endoftext|>",
  "split_special_tokens": false,
  "tokenizer_class": "Qwen2Tokenizer",
  "unk_token": null
}
grouped_qwen3_checkpoint/epoch_2_best/vocab.json
ADDED
The diff for this file is too large to render. See raw diff.
inference_qwen3_merged.py
ADDED
@@ -0,0 +1,793 @@
import os
import logging
import time
import psutil
from typing import Optional, List, Dict, Any, Tuple
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.models.qwen3.configuration_qwen3 import Qwen3Config
from transformers.models.qwen3.modeling_qwen3 import (
    Qwen3ForCausalLM,
    Qwen3RMSNorm,
    Qwen3DecoderLayer,
    Qwen3Attention,
    Qwen3RotaryEmbedding,
)

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger("qwen3_grouped_inference")

class PerformanceMonitor:
    def __init__(self):
        self.reset()

    def reset(self):
        """Reset all metrics."""
        self.start_time = None
        self.end_time = None
        self.start_memory = None
        self.peak_memory = None
        self.start_gpu_memory = None
        self.peak_gpu_memory = None

    def start_monitoring(self):
        self.reset()
        self.start_time = time.time()

        process = psutil.Process()
        self.start_memory = process.memory_info().rss / 1024 / 1024  # MB
        self.peak_memory = self.start_memory

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            self.start_gpu_memory = torch.cuda.memory_allocated() / 1024 / 1024  # MB
            self.peak_gpu_memory = self.start_gpu_memory

    def update_peak_memory(self):
        process = psutil.Process()
        current_memory = process.memory_info().rss / 1024 / 1024  # MB
        self.peak_memory = max(self.peak_memory, current_memory)

        if torch.cuda.is_available():
            current_gpu_memory = torch.cuda.memory_allocated() / 1024 / 1024  # MB
            self.peak_gpu_memory = max(self.peak_gpu_memory, current_gpu_memory)

    def stop_monitoring(self):
        self.end_time = time.time()
        self.update_peak_memory()

        metrics = {
            "duration_ms": (self.end_time - self.start_time) * 1000,
            "cpu_memory_start_mb": self.start_memory,
            "cpu_memory_peak_mb": self.peak_memory,
            "cpu_memory_used_mb": self.peak_memory - self.start_memory,
        }

        if torch.cuda.is_available():
            metrics.update({
                "gpu_memory_start_mb": self.start_gpu_memory,
                "gpu_memory_peak_mb": self.peak_gpu_memory,
                "gpu_memory_used_mb": self.peak_gpu_memory - self.start_gpu_memory,
            })

        return metrics

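# Usage sketch (illustrative, not part of the runner's control flow): wrap a
# phase such as grouping or generation with the monitor to collect its
# timing and memory numbers, e.g.
#
#   monitor = PerformanceMonitor()
#   monitor.start_monitoring()
#   ...                       # run grouping or generation here
#   metrics = monitor.stop_monitoring()
#   logger.info(f"phase took {metrics['duration_ms']:.1f} ms")
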
class CustomQwen3Attention(Qwen3Attention):
    def __init__(self, config, layer_idx: int):
        super().__init__(config, layer_idx)
        self.layer_idx = layer_idx
        self.tokenizer = None
        self.current_input_ids = None
        self.threshold = 0.1

        if not hasattr(self, 'num_key_value_heads'):
            self.num_key_value_heads = config.num_key_value_heads if hasattr(config, 'num_key_value_heads') else config.num_attention_heads
        if not hasattr(self, 'head_dim'):
            self.head_dim = config.hidden_size // config.num_attention_heads

    def set_tokenizer(self, tokenizer):
        self.tokenizer = tokenizer

    def set_current_input_ids(self, input_ids):
        self.current_input_ids = input_ids

    def _is_special_token(self, token: str) -> bool:
        if self.tokenizer is None:
            return False

        special_tokens = set()

        if hasattr(self.tokenizer, 'special_tokens_map'):
            for token_type, token_value in self.tokenizer.special_tokens_map.items():
                if isinstance(token_value, str):
                    special_tokens.add(token_value)
                elif isinstance(token_value, list):
                    special_tokens.update(token_value)

        if hasattr(self.tokenizer, 'added_tokens_encoder'):
            special_tokens.update(self.tokenizer.added_tokens_encoder.keys())

        if token in special_tokens:
            return True

        special_patterns = [
            lambda t: t.startswith('<|') and t.endswith('|>'),
            lambda t: t.startswith('<') and t.endswith('>'),
            lambda t: t.startswith('[') and t.endswith(']'),
        ]

        return any(pattern(token) for pattern in special_patterns)

    def _get_token_relations(self, attention_weights: torch.Tensor, tokens: List[str]) -> List[Dict]:
        batch_size, num_heads, query_len, key_len = attention_weights.shape
        attn = attention_weights[0].mean(dim=0)
        relations = []

        if query_len == 1:
            current_token_pos = len(tokens) - 1
            token_relations = []
            for j in range(len(tokens)):
                if j != current_token_pos:
                    weight = attn[0, j].item()
                    if weight > self.threshold:
                        token_relations.append({
                            'target_pos': j,
                            'weight': round(weight, 3)
                        })

            relations.append({
                'source_pos': current_token_pos,
                'relations': token_relations
            })
        else:
            for i in range(min(query_len, len(tokens))):
                token_relations = []
                for j in range(len(tokens)):
                    if i != j and j < key_len:
                        weight = attn[i, j].item()
                        if weight > self.threshold:
                            token_relations.append({
                                'target_pos': j,
                                'weight': round(weight, 3)
                            })

                relations.append({
                    'source_pos': i,
                    'relations': token_relations
                })

        return relations

    def _get_token_groups(self, attention_weights: torch.Tensor) -> List[List[int]]:
        if self.tokenizer is None or self.current_input_ids is None:
            return []

        if len(attention_weights.shape) != 4:
            return []

        batch_size, num_heads, query_len, key_len = attention_weights.shape

        input_ids = self.current_input_ids
        if input_ids is None or input_ids.shape[1] < key_len:
            return []

        tokens = [self.tokenizer.decode([token_id]) for token_id in input_ids[0][:key_len]]
        relations = self._get_token_relations(attention_weights, tokens)

        groups = []
        current_group = []
        current_group_indices = []

        for i, token in enumerate(tokens):
            is_empty_relations = i < len(relations) and len(relations[i]['relations']) == 0
            starts_with_space = token.startswith(' ') and token != ' '
            is_space = token == ' '
            is_new_line = '\n' in token

            prev_token_is_special = False
            prev_token_is_new_line = False
            prev_token_is_space = False
            if i > 0:
                prev_token = tokens[i-1]
                prev_token_is_special = self._is_special_token(prev_token)
                prev_token_is_new_line = '\n' in prev_token
                prev_token_is_space = prev_token == ' '

            prev_newline_current_not = prev_token_is_new_line and not is_new_line
            prev_space_current_not = prev_token_is_space and not is_space
            current_space_prev_not = is_space and not prev_token_is_space

            if (is_empty_relations or starts_with_space or is_new_line or
                    prev_token_is_special or prev_newline_current_not or prev_space_current_not or
                    current_space_prev_not) and current_group:
                groups.append(current_group_indices)
                current_group = []
                current_group_indices = []

            current_group.append(token)
            current_group_indices.append(i)

        if current_group:
            groups.append(current_group_indices)

        if groups:
            logger.info("Token grouping details:")
            for group_idx, group_indices in enumerate(groups):
                group_tokens = [tokens[i] for i in group_indices]
                combined_text = ''.join(group_tokens)
                logger.info(f"  Group {group_idx + 1}: {group_tokens} -> '{combined_text}'")

        return groups

class CustomQwen3DecoderLayer(Qwen3DecoderLayer):
    """Custom Qwen3 decoder layer with grouping functionality."""

    def __init__(self, config, layer_idx: int):
        super().__init__(config, layer_idx)
        self.layer_idx = layer_idx
        self.rotary_emb = Qwen3RotaryEmbedding(config=config)
        self.self_attn = CustomQwen3Attention(config, layer_idx)
        self.is_initialized = False
        self.grouped_hidden_states = None

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[tuple] = None,
        **kwargs,
    ):
        if self.layer_idx != 0:
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **kwargs,
            )

        is_prefill = hidden_states.shape[1] > 1 and not self.is_initialized
        if not is_prefill:
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **kwargs,
            )

        residual = hidden_states
        x = self.input_layernorm(hidden_states)

        attn_ret = self.self_attn(
            hidden_states=x,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=None,
            output_attentions=True,
            use_cache=False,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
        )

        if isinstance(attn_ret, tuple):
            if len(attn_ret) == 3:
                attn_out, attn_weights, _ = attn_ret
            elif len(attn_ret) == 2:
                attn_out, attn_weights = attn_ret
            else:
                raise RuntimeError(f"Unexpected attention return length: {len(attn_ret)}")
        else:
            raise RuntimeError("Attention did not return weights.")

        groups = self.self_attn._get_token_groups(attn_weights)
        if not groups:
            self.is_initialized = True
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **kwargs,
            )

        averaged_vectors = []
        group_info = []

        for gi, idxs in enumerate(groups):
            if len(idxs) == 1:
                averaged_vectors.append(attn_out[:, idxs[0], :])
                group_info.append({"type": "single", "positions": idxs, "new_position": gi})
            else:
                gvecs = attn_out[:, idxs, :]
                ave = gvecs.mean(dim=1)
                averaged_vectors.append(ave)
                group_info.append({"type": "averaged", "positions": idxs, "new_position": gi})

        new_attn_out = torch.stack(averaged_vectors, dim=1)

        expanded_residual = torch.stack([
            (
                residual[:, info['positions'], :].sum(dim=1)
                if len(info['positions']) > 1
                else residual[:, info['positions'][0], :]
            )
            for info in group_info
        ], dim=1)

        hs = expanded_residual + new_attn_out
        grouped_hidden = self.post_attention_layernorm(hs)

        # Store grouped embeddings
        self.grouped_hidden_states = grouped_hidden
        self.is_initialized = True
        return hs

class GroupedInputMLPAdapter(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        hidden_size = config.hidden_size

        self.grouped_processor = nn.Sequential(
            nn.Linear(hidden_size, hidden_size * 2),
            nn.SiLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size * 2, hidden_size),
            nn.Dropout(0.1)
        )

        norm_eps = getattr(config, 'rms_norm_eps', 1e-6)
        self.layer_norm = Qwen3RMSNorm(hidden_size, eps=norm_eps)

    def forward(self, grouped_embeds: torch.Tensor) -> torch.Tensor:
        processed = self.grouped_processor(grouped_embeds)
        output = self.layer_norm(grouped_embeds + processed)
        return output


class CustomQwen3ForCausalLM(Qwen3ForCausalLM):
    def __init__(self, config):
        super().__init__(config)
        self.grouped_input_mlp = GroupedInputMLPAdapter(config)
        self.is_grouped_input_mode = False

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        grouped_inputs: Optional[torch.FloatTensor] = None,
        is_prefill: Optional[bool] = None,
        **kwargs
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        if grouped_inputs is not None and is_prefill:
            self.is_grouped_input_mode = True
            processed_grouped_inputs = self.grouped_input_mlp(grouped_inputs)
            inputs_embeds = processed_grouped_inputs
            input_ids = None

            batch_size, seq_len = inputs_embeds.shape[:2]
            if position_ids is None:
                device = inputs_embeds.device
                position_ids = torch.arange(seq_len, device=device, dtype=torch.long)
                position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)

            if attention_mask is None:
                attention_mask = torch.ones((batch_size, seq_len), device=inputs_embeds.device, dtype=torch.long)

        return super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            **kwargs
        )

def create_grouping_model(model_name: str = "Qwen/Qwen3-0.6B") -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    if torch.cuda.is_available():
        device = torch.device("cuda")
        dtype = torch.float16
    else:
        device = torch.device("cpu")
        dtype = torch.float32

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=dtype,
        attn_implementation="eager"
    ).to(device)

    orig0 = model.model.layers[0]
    custom0 = CustomQwen3DecoderLayer(model.config, 0)

    custom0.mlp.load_state_dict(orig0.mlp.state_dict())
    custom0.input_layernorm.load_state_dict(orig0.input_layernorm.state_dict())
    custom0.post_attention_layernorm.load_state_dict(orig0.post_attention_layernorm.state_dict())
    custom0.self_attn.load_state_dict(orig0.self_attn.state_dict())

    custom0.self_attn.set_tokenizer(tokenizer)
    custom0 = custom0.to(device=device, dtype=dtype)
    model.model.layers[0] = custom0

    return model, tokenizer


def load_inference_model(checkpoint_path: str) -> Tuple[CustomQwen3ForCausalLM, AutoTokenizer]:
    logger.info(f"Loading inference model from {checkpoint_path}")

    tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
    config = Qwen3Config.from_pretrained(checkpoint_path)
    model = CustomQwen3ForCausalLM(config)

    model_path = Path(checkpoint_path) / "pytorch_model.bin"
    if not model_path.exists():
        model_path = Path(checkpoint_path) / "model.safetensors"

    if not model_path.exists():
        raise FileNotFoundError(f"No model weights found in {checkpoint_path}")

    state_dict = torch.load(model_path, map_location="cpu")
    model.load_state_dict(state_dict, strict=False)
    model = model.eval().to(torch.float32)

    return model, tokenizer

class Qwen3GroupedInference:
    def __init__(self,
|
| 480 |
+
checkpoint_path: str,
|
| 481 |
+
grouping_model_name: str = "Qwen/Qwen3-0.6B",
|
| 482 |
+
device: Optional[str] = None):
|
| 483 |
+
"""Initialize inference system with both models."""
|
| 484 |
+
|
| 485 |
+
if device is None:
|
| 486 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 487 |
+
else:
|
| 488 |
+
self.device = torch.device(device)
|
| 489 |
+
|
| 490 |
+
logger.info(f"Initializing inference on device: {self.device}")
|
| 491 |
+
|
| 492 |
+
self.system_prompt = "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\n"
|
| 493 |
+
self.response_start = "<|im_end|>\n<|im_start|>assistant\n"
|
| 494 |
+
|
| 495 |
+
logger.info("Loading grouping model...")
|
| 496 |
+
self.grouping_model, self.grouping_tokenizer = create_grouping_model(grouping_model_name)
|
| 497 |
+
self.grouping_model = self.grouping_model.to(self.device)
|
| 498 |
+
|
| 499 |
+
logger.info("Loading inference model...")
|
| 500 |
+
self.inference_model, self.inference_tokenizer = load_inference_model(checkpoint_path)
|
| 501 |
+
self.inference_model = self.inference_model.to(self.device)
|
| 502 |
+
|
| 503 |
+
logger.info("Both models loaded successfully")
|
| 504 |
+
|
| 505 |
+
def format_input_text(self, instruction: str) -> str:
|
| 506 |
+
return f"{self.system_prompt}{instruction}{self.response_start}"
|
| 507 |
+
|
| 508 |
+
def get_grouped_embeddings(self, text: str) -> Tuple[torch.Tensor, Dict[str, Any]]:
|
| 509 |
+
monitor = PerformanceMonitor()
|
| 510 |
+
monitor.start_monitoring()
|
| 511 |
+
|
| 512 |
+
if hasattr(self.grouping_model.model.layers[0], "is_initialized"):
|
| 513 |
+
self.grouping_model.model.layers[0].is_initialized = False
|
| 514 |
+
|
| 515 |
+
batch = self.grouping_tokenizer(text, return_tensors="pt", truncation=True, max_length=2048).to(self.device)
|
| 516 |
+
input_ids = batch["input_ids"]
|
| 517 |
+
original_token_count = input_ids.shape[1]
|
| 518 |
+
|
| 519 |
+
original_tokens = [self.grouping_tokenizer.decode([token_id]) for token_id in input_ids[0]]
|
| 520 |
+
logger.info(f"Original input tokens ({original_token_count}): {original_tokens}")
|
| 521 |
+
|
| 522 |
+
if hasattr(self.grouping_model.model.layers[0], "self_attn"):
|
| 523 |
+
sat = self.grouping_model.model.layers[0].self_attn
|
| 524 |
+
if hasattr(sat, "set_current_input_ids"):
|
| 525 |
+
sat.set_current_input_ids(input_ids)
|
| 526 |
+
|
| 527 |
+
monitor.update_peak_memory()
|
| 528 |
+
|
| 529 |
+
with torch.no_grad():
|
| 530 |
+
inputs_embeds = self.grouping_model.model.embed_tokens(input_ids)
|
| 531 |
+
seq_len = inputs_embeds.shape[1]
|
| 532 |
+
position_ids = torch.arange(seq_len, device=self.device, dtype=torch.long).unsqueeze(0)
|
| 533 |
+
|
| 534 |
+
if hasattr(self.grouping_model.model, 'rotary_emb'):
|
| 535 |
+
pos_embeds = self.grouping_model.model.rotary_emb(inputs_embeds, position_ids)
|
| 536 |
+
else:
|
| 537 |
+
pos_embeds = None
|
| 538 |
+
|
| 539 |
+
monitor.update_peak_memory()
|
| 540 |
+
|
| 541 |
+
_ = self.grouping_model.model.layers[0](
|
| 542 |
+
hidden_states=inputs_embeds,
|
| 543 |
+
attention_mask=None,
|
| 544 |
+
position_ids=position_ids,
|
| 545 |
+
past_key_value=None,
|
| 546 |
+
output_attentions=False,
|
| 547 |
+
use_cache=False,
|
| 548 |
+
cache_position=None,
|
| 549 |
+
position_embeddings=pos_embeds,
|
| 550 |
+
)
|
| 551 |
+
|
| 552 |
+
monitor.update_peak_memory()
|
| 553 |
+
|
| 554 |
+
if (hasattr(self.grouping_model.model.layers[0], "grouped_hidden_states") and
|
| 555 |
+
self.grouping_model.model.layers[0].grouped_hidden_states is not None):
|
| 556 |
+
grouped_embeds = self.grouping_model.model.layers[0].grouped_hidden_states.clone()
|
| 557 |
+
grouped_token_count = grouped_embeds.shape[1]
|
| 558 |
+
# Clear the stored state
|
| 559 |
+
self.grouping_model.model.layers[0].grouped_hidden_states = None
|
| 560 |
+
|
| 561 |
+
compression_ratio = original_token_count / grouped_token_count if grouped_token_count > 0 else 1.0
|
| 562 |
+
reduction_percent = (1 - grouped_token_count / original_token_count) * 100 if original_token_count > 0 else 0.0
|
| 563 |
+
|
| 564 |
+
logger.info(f"Grouped tokens: {grouped_token_count}")
|
| 565 |
+
logger.info(f"Compression ratio: {compression_ratio:.2f}x ({reduction_percent:.1f}% reduction)")
|
| 566 |
+
|
| 567 |
+
metrics = monitor.stop_monitoring()
|
| 568 |
+
metrics.update({
|
| 569 |
+
"original_tokens": original_token_count,
|
| 570 |
+
"grouped_tokens": grouped_token_count,
|
| 571 |
+
"compression_ratio": compression_ratio,
|
| 572 |
+
"reduction_percent": reduction_percent
|
| 573 |
+
})
|
| 574 |
+
|
| 575 |
+
return grouped_embeds.squeeze(0), metrics
|
| 576 |
+
else:
|
| 577 |
+
logger.warning("Grouping failed, using original embeddings")
|
| 578 |
+
|
| 579 |
+
metrics = monitor.stop_monitoring()
|
| 580 |
+
metrics.update({
|
| 581 |
+
"original_tokens": original_token_count,
|
| 582 |
+
"grouped_tokens": original_token_count,
|
| 583 |
+
"compression_ratio": 1.0,
|
| 584 |
+
"reduction_percent": 0.0
|
| 585 |
+
})
|
| 586 |
+
|
| 587 |
+
return inputs_embeds.squeeze(0), metrics
|
| 588 |
+
|
| 589 |
+
def generate_with_grouped_input(self,
|
| 590 |
+
grouped_input: torch.Tensor,
|
| 591 |
+
max_length: int = 512,
|
| 592 |
+
temperature: float = 0.7,
|
| 593 |
+
do_sample: bool = True) -> Tuple[str, Dict[str, Any]]:
|
| 594 |
+
"""Generate text using grouped input embeddings."""
|
| 595 |
+
monitor = PerformanceMonitor()
|
| 596 |
+
monitor.start_monitoring()
|
| 597 |
+
|
| 598 |
+
model_dtype = next(self.inference_model.parameters()).dtype
|
| 599 |
+
grouped_input = grouped_input.to(device=self.device, dtype=model_dtype)
|
| 600 |
+
|
| 601 |
+
if grouped_input.ndim == 2:
|
| 602 |
+
grouped_input = grouped_input.unsqueeze(0)
|
| 603 |
+
|
| 604 |
+
input_seq_len = grouped_input.shape[1]
|
| 605 |
+
logger.info(f"Inference model input sequence length: {input_seq_len}")
|
| 606 |
+
|
| 607 |
+
monitor.update_peak_memory()
|
| 608 |
+
|
| 609 |
+
with torch.no_grad():
|
| 610 |
+
outputs = self.inference_model(
|
| 611 |
+
grouped_inputs=grouped_input,
|
| 612 |
+
is_prefill=True,
|
| 613 |
+
use_cache=True,
|
| 614 |
+
return_dict=True
|
| 615 |
+
)
|
| 616 |
+
|
| 617 |
+
monitor.update_peak_memory()
|
| 618 |
+
|
| 619 |
+
if hasattr(outputs, 'logits') and outputs.logits is not None:
|
| 620 |
+
next_token_logits = outputs.logits[:, -1, :]
|
| 621 |
+
else:
|
| 622 |
+
raise RuntimeError("Could not extract logits from model output")
|
| 623 |
+
|
| 624 |
+
if do_sample:
|
| 625 |
+
next_token_logits = next_token_logits / temperature
|
| 626 |
+
probs = F.softmax(next_token_logits, dim=-1)
|
| 627 |
+
next_token = torch.multinomial(probs, num_samples=1)
|
| 628 |
+
else:
|
| 629 |
+
next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
|
| 630 |
+
|
| 631 |
+
generated_ids = next_token
|
| 632 |
+
past_key_values = getattr(outputs, 'past_key_values', None)
|
| 633 |
+
|
| 634 |
+
generated_tokens = 1
|
| 635 |
+
for step in range(max_length - 1):
|
| 636 |
+
monitor.update_peak_memory()
|
| 637 |
+
|
| 638 |
+
with torch.no_grad():
|
| 639 |
+
outputs = self.inference_model(
|
| 640 |
+
input_ids=next_token,
|
| 641 |
+
past_key_values=past_key_values,
|
| 642 |
+
use_cache=True,
|
| 643 |
+
return_dict=True
|
| 644 |
+
)
|
| 645 |
+
|
| 646 |
+
if not hasattr(outputs, 'logits'):
|
| 647 |
+
break
|
| 648 |
+
|
| 649 |
+
next_token_logits = outputs.logits[:, -1, :]
|
| 650 |
+
|
| 651 |
+
if do_sample:
|
| 652 |
+
next_token_logits = next_token_logits / temperature
|
| 653 |
+
probs = F.softmax(next_token_logits, dim=-1)
|
| 654 |
+
next_token = torch.multinomial(probs, num_samples=1)
|
| 655 |
+
else:
|
| 656 |
+
next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
|
| 657 |
+
|
| 658 |
+
generated_ids = torch.cat([generated_ids, next_token], dim=1)
|
| 659 |
+
past_key_values = getattr(outputs, 'past_key_values', None)
|
| 660 |
+
generated_tokens += 1
|
| 661 |
+
|
| 662 |
+
if next_token.item() == self.inference_tokenizer.eos_token_id:
|
| 663 |
+
break
|
| 664 |
+
|
| 665 |
+
generated_text = self.inference_tokenizer.decode(generated_ids[0], skip_special_tokens=True)
|
| 666 |
+
|
| 667 |
+
metrics = monitor.stop_monitoring()
|
| 668 |
+
metrics.update({
|
| 669 |
+
"input_seq_len": input_seq_len,
|
| 670 |
+
"generated_tokens": generated_tokens,
|
| 671 |
+
"tokens_per_second": generated_tokens / (metrics["duration_ms"] / 1000) if metrics["duration_ms"] > 0 else 0
|
| 672 |
+
})
|
| 673 |
+
|
| 674 |
+
logger.info(f"Generated {generated_tokens} tokens in {metrics['duration_ms']:.1f}ms")
|
| 675 |
+
logger.info(f"Generation speed: {metrics['tokens_per_second']:.1f} tokens/second")
|
| 676 |
+
|
| 677 |
+
return generated_text, metrics
|
| 678 |
+
|
| 679 |
+
def inference(self,
|
| 680 |
+
instruction: str,
|
| 681 |
+
max_length: int = 512,
|
| 682 |
+
temperature: float = 0.7,
|
| 683 |
+
do_sample: bool = True) -> Dict[str, Any]:
|
| 684 |
+
"""Run complete inference pipeline from instruction to response."""
|
| 685 |
+
logger.info("=" * 60)
|
| 686 |
+
logger.info("STARTING INFERENCE PIPELINE")
|
| 687 |
+
logger.info("=" * 60)
|
| 688 |
+
|
| 689 |
+
input_text = self.format_input_text(instruction)
|
| 690 |
+
|
| 691 |
+
logger.info("PHASE 1: Token Grouping")
|
| 692 |
+
grouped_embeddings, grouping_metrics = self.get_grouped_embeddings(input_text)
|
| 693 |
+
|
| 694 |
+
logger.info("PHASE 2: Response Generation")
|
| 695 |
+
response, generation_metrics = self.generate_with_grouped_input(
|
| 696 |
+
grouped_input=grouped_embeddings,
|
| 697 |
+
max_length=max_length,
|
| 698 |
+
temperature=temperature,
|
| 699 |
+
do_sample=do_sample
|
| 700 |
+
)
|
| 701 |
+
|
| 702 |
+
total_metrics = {
|
| 703 |
+
"grouping": grouping_metrics,
|
| 704 |
+
"generation": generation_metrics,
|
| 705 |
+
"total_duration_ms": grouping_metrics["duration_ms"] + generation_metrics["duration_ms"],
|
| 706 |
+
}
|
| 707 |
+
|
| 708 |
+
logger.info("=" * 60)
|
| 709 |
+
logger.info("INFERENCE SUMMARY")
|
| 710 |
+
logger.info("=" * 60)
|
| 711 |
+
logger.info(f"Input compression: {grouping_metrics['original_tokens']} β {grouping_metrics['grouped_tokens']} tokens")
|
| 712 |
+
logger.info(f"Compression ratio: {grouping_metrics['compression_ratio']:.2f}x")
|
| 713 |
+
logger.info(f"Memory reduction: {grouping_metrics['reduction_percent']:.1f}%")
|
| 714 |
+
logger.info(f"Total time: {total_metrics['total_duration_ms']:.1f}ms")
|
| 715 |
+
logger.info(f"Generation speed: {generation_metrics['tokens_per_second']:.1f} tokens/sec")
|
| 716 |
+
|
| 717 |
+
if torch.cuda.is_available():
|
| 718 |
+
total_gpu_memory = grouping_metrics.get("gpu_memory_used_mb", 0) + generation_metrics.get("gpu_memory_used_mb", 0)
|
| 719 |
+
logger.info(f"Total GPU memory used: {total_gpu_memory:.1f}MB")
|
| 720 |
+
|
| 721 |
+
total_cpu_memory = grouping_metrics.get("cpu_memory_used_mb", 0) + generation_metrics.get("cpu_memory_used_mb", 0)
|
| 722 |
+
logger.info(f"Total CPU memory used: {total_cpu_memory:.1f}MB")
|
| 723 |
+
|
| 724 |
+
original_seq_len = grouping_metrics['original_tokens']
|
| 725 |
+
grouped_seq_len = grouping_metrics['grouped_tokens']
|
| 726 |
+
|
| 727 |
+
estimated_memory_savings = (1 - (grouped_seq_len ** 2) / (original_seq_len ** 2)) * 100 if original_seq_len > 0 else 0
|
| 728 |
+
logger.info(f"Estimated attention memory savings: {estimated_memory_savings:.1f}%")
|
| 729 |
+
|
| 730 |
+
logger.info("=" * 60)
|
| 731 |
+
|
| 732 |
+
return {
|
| 733 |
+
"instruction": instruction,
|
| 734 |
+
"response": response,
|
| 735 |
+
"metrics": total_metrics
|
| 736 |
+
}
|
| 737 |
+
|
| 738 |
+
def main():
|
| 739 |
+
import argparse
|
| 740 |
+
|
| 741 |
+
parser = argparse.ArgumentParser(description="Qwen3 Grouped Inference")
|
| 742 |
+
parser.add_argument("--checkpoint", type=str,
|
| 743 |
+
default="./grouped_qwen3_checkpoint/epoch_2_best",
|
| 744 |
+
help="Path to trained model checkpoint")
|
| 745 |
+
parser.add_argument("--grouping_model", type=str, default="Qwen/Qwen3-0.6B",
|
| 746 |
+
help="Grouping model name")
|
| 747 |
+
parser.add_argument("--instruction", type=str, default="""
|
| 748 |
+
Π§ΡΠΎ ΡΠ°ΠΊΠΎΠ΅ Π½Π΅ΠΉΡΠΎΠ½Π½ΡΠ΅ ΡΠ΅ΡΠΈ, ΠΎΠ±ΡΡΡΠ½ΠΈ ΠΊΠ°ΠΊ ΡΠΊΠΎΠ»ΡΠ½ΠΈΠΊΡ 9Π³ΠΎ ΠΊΠ»Π°ΡΡΠ°
|
| 749 |
+
""",
|
| 750 |
+
help="Instruction for inference")
|
| 751 |
+
parser.add_argument("--max_length", type=int, default=512,
|
| 752 |
+
help="Maximum generation length")
|
| 753 |
+
parser.add_argument("--temperature", type=float, default=0.7,
|
| 754 |
+
help="Generation temperature")
|
| 755 |
+
parser.add_argument("--no_sample", action="store_true",
|
| 756 |
+
help="Use greedy decoding")
|
| 757 |
+
parser.add_argument("--device", type=str,
|
| 758 |
+
help="Device to use (cuda/cpu)")
|
| 759 |
+
|
| 760 |
+
args = parser.parse_args()
|
| 761 |
+
|
| 762 |
+
inference_system = Qwen3GroupedInference(
|
| 763 |
+
checkpoint_path=args.checkpoint,
|
| 764 |
+
grouping_model_name=args.grouping_model,
|
| 765 |
+
device=args.device
|
| 766 |
+
)
|
| 767 |
+
|
| 768 |
+
do_sample = not args.no_sample
|
| 769 |
+
|
| 770 |
+
result = inference_system.inference(
|
| 771 |
+
instruction=args.instruction,
|
| 772 |
+
max_length=args.max_length,
|
| 773 |
+
temperature=args.temperature,
|
| 774 |
+
do_sample=do_sample
|
| 775 |
+
)
|
| 776 |
+
|
| 777 |
+
print(f"\nInstruction: {result['instruction']}")
|
| 778 |
+
print(f"Response: {result['response']}")
|
| 779 |
+
|
| 780 |
+
metrics = result.get('metrics', {})
|
| 781 |
+
if metrics:
|
| 782 |
+
print(f"\n--- Performance Metrics ---")
|
| 783 |
+
grouping = metrics.get('grouping', {})
|
| 784 |
+
generation = metrics.get('generation', {})
|
| 785 |
+
|
| 786 |
+
print(f"Token compression: {grouping.get('compression_ratio', 'N/A'):.2f}x")
|
| 787 |
+
print(f"Memory reduction: {grouping.get('reduction_percent', 'N/A'):.1f}%")
|
| 788 |
+
print(f"Total time: {metrics.get('total_duration_ms', 'N/A'):.1f}ms")
|
| 789 |
+
print(f"Generation speed: {generation.get('tokens_per_second', 'N/A'):.1f} tokens/sec")
|
| 790 |
+
|
| 791 |
+
|
| 792 |
+
if __name__ == "__main__":
|
| 793 |
+
main()
|
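For readers who prefer to drive this pipeline from Python rather than the CLI above, the following is a minimal, illustrative sketch. It assumes the file is importable as inference_qwen3_merged (the module name is taken from this repository's file layout, not stated in the code itself) and that the checkpoint shipped under grouped_qwen3_checkpoint/epoch_2_best has been downloaded locally; the instruction string is hypothetical.

# Hypothetical driver for the grouped-inference pipeline defined above.
from inference_qwen3_merged import Qwen3GroupedInference

system = Qwen3GroupedInference(
    checkpoint_path="./grouped_qwen3_checkpoint/epoch_2_best",
    grouping_model_name="Qwen/Qwen3-0.6B",
)

result = system.inference(
    instruction="Summarize the idea of grouped token embeddings in two sentences.",
    max_length=256,
    temperature=0.7,
    do_sample=True,
)

# The pipeline returns the generated text plus grouping/generation metrics.
print(result["response"])
print(result["metrics"]["grouping"]["compression_ratio"])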
prepare_dataset.py
ADDED
@@ -0,0 +1,652 @@
import os
import sys
import logging
import json
import pickle
from typing import Optional, Tuple, List, Dict, Any
from pathlib import Path
from tqdm import tqdm

import torch
import torch.nn as nn
from datasets import load_dataset, Dataset

from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.models.qwen3.modeling_qwen3 import (
    Qwen3DecoderLayer,
    Qwen3Attention,
    Qwen3RotaryEmbedding,
)

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    handlers=[logging.StreamHandler(sys.stdout)],
    force=True,
)
logger = logging.getLogger("qwen3_dataset_processor")


class GroupedCache:
    """Cache for grouping metadata."""
    def __init__(self):
        self.grouped_positions = None
        self.position_mapping = None
        self.group_info = None
        self.original_seq_length = None


class CustomQwen3Attention(Qwen3Attention):
    def __init__(self, config, layer_idx: int):
        super().__init__(config, layer_idx)
        self.layer_idx = layer_idx
        self.tokenizer = None
        self.current_input_ids = None
        self.threshold = 0.1
        self.grouped_cache = GroupedCache()

        if not hasattr(self, 'num_key_value_heads'):
            self.num_key_value_heads = config.num_key_value_heads if hasattr(config, 'num_key_value_heads') else config.num_attention_heads
        if not hasattr(self, 'head_dim'):
            self.head_dim = config.hidden_size // config.num_attention_heads

    def set_tokenizer(self, tokenizer):
        self.tokenizer = tokenizer

    def set_current_input_ids(self, input_ids):
        self.current_input_ids = input_ids

    def _is_special_token(self, token: str) -> bool:
        if self.tokenizer is None:
            return False

        special_tokens = set()

        if hasattr(self.tokenizer, 'special_tokens_map'):
            for token_type, token_value in self.tokenizer.special_tokens_map.items():
                if isinstance(token_value, str):
                    special_tokens.add(token_value)
                elif isinstance(token_value, list):
                    special_tokens.update(token_value)

        if hasattr(self.tokenizer, 'added_tokens_encoder'):
            special_tokens.update(self.tokenizer.added_tokens_encoder.keys())

        if token in special_tokens:
            return True

        special_patterns = [
            lambda t: t.startswith('<|') and t.endswith('|>'),
            lambda t: t.startswith('<') and t.endswith('>'),
            lambda t: t.startswith('[') and t.endswith(']'),
        ]

        return any(pattern(token) for pattern in special_patterns)

    def _get_token_relations(self, attention_weights: torch.Tensor, tokens: List[str]) -> List[Dict]:
        batch_size, num_heads, query_len, key_len = attention_weights.shape

        attn = attention_weights[0].mean(dim=0)

        relations = []

        if query_len == 1:
            current_token_pos = len(tokens) - 1

            token_relations = []
            for j in range(len(tokens)):
                if j != current_token_pos:
                    weight = attn[0, j].item()
                    if weight > self.threshold:
                        token_relations.append({
                            'target_pos': j,
                            'weight': round(weight, 3)
                        })

            relations.append({
                'source_pos': current_token_pos,
                'relations': token_relations
            })

        else:
            for i in range(min(query_len, len(tokens))):
                token_relations = []
                for j in range(len(tokens)):
                    if i != j and j < key_len:
                        weight = attn[i, j].item()
                        if weight > self.threshold:
                            token_relations.append({
                                'target_pos': j,
                                'weight': round(weight, 3)
                            })

                relations.append({
                    'source_pos': i,
                    'relations': token_relations
                })

        return relations

    def _get_token_groups(self, attention_weights: torch.Tensor) -> List[List[int]]:
        if self.tokenizer is None or self.current_input_ids is None:
            return []

        if len(attention_weights.shape) != 4:
            return []

        batch_size, num_heads, query_len, key_len = attention_weights.shape

        input_ids = self.current_input_ids
        if input_ids is None or input_ids.shape[1] < key_len:
            return []

        tokens = [self.tokenizer.decode([token_id]) for token_id in input_ids[0][:key_len]]

        relations = self._get_token_relations(attention_weights, tokens)

        groups = []
        current_group = []
        current_group_indices = []

        for i, token in enumerate(tokens):
            is_empty_relations = i < len(relations) and len(relations[i]['relations']) == 0
            starts_with_space = token.startswith(' ') and token != ' '
            is_space = token == ' '
            is_new_line = '\n' in token

            prev_token_is_special = False
            prev_token_is_new_line = False
            prev_token_is_space = False
            if i > 0:
                prev_token = tokens[i-1]
                prev_token_is_special = self._is_special_token(prev_token)
                prev_token_is_new_line = '\n' in prev_token
                prev_token_is_space = prev_token == ' '

            prev_newline_current_not = prev_token_is_new_line and not is_new_line
            prev_space_current_not = prev_token_is_space and not is_space
            current_space_prev_not = is_space and not prev_token_is_space

            if (is_empty_relations or starts_with_space or is_new_line or
                    prev_token_is_special or prev_newline_current_not or prev_space_current_not or
                    current_space_prev_not) and current_group:
                groups.append(current_group_indices)
                current_group = []
                current_group_indices = []

            current_group.append(token)
            current_group_indices.append(i)

        if current_group:
            groups.append(current_group_indices)

        return groups


class CustomQwen3DecoderLayer(Qwen3DecoderLayer):
    def __init__(self, config, layer_idx: int):
        super().__init__(config, layer_idx)
        self.layer_idx = layer_idx
        self.rotary_emb = Qwen3RotaryEmbedding(config=config)
        self.self_attn = CustomQwen3Attention(config, layer_idx)
        self.is_initialized = False
        self.grouped_hidden_states = None
        self.grouped_cache = GroupedCache()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[tuple] = None,
        **kwargs,
    ):
        if self.layer_idx != 0:
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **kwargs,
            )

        is_prefill = hidden_states.shape[1] > 1 and not self.is_initialized
        if not is_prefill:
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **kwargs,
            )

        residual = hidden_states
        x = self.input_layernorm(hidden_states)

        attn_ret = self.self_attn(
            hidden_states=x,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=None,
            output_attentions=True,
            use_cache=False,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
        )

        if isinstance(attn_ret, tuple):
            if len(attn_ret) == 3:
                attn_out, attn_weights, _ = attn_ret
            elif len(attn_ret) == 2:
                attn_out, attn_weights = attn_ret
            else:
                raise RuntimeError(f"Unexpected attention return length: {len(attn_ret)}")
        else:
            raise RuntimeError("Attention did not return weights.")

        groups = self.self_attn._get_token_groups(attn_weights)
        if not groups:
            self.is_initialized = True
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **kwargs,
            )

        averaged_vectors = []
        group_info = []
        position_mapping = {}

        for gi, idxs in enumerate(groups):
            if len(idxs) == 1:
                averaged_vectors.append(attn_out[:, idxs[0], :])
                group_info.append({"type": "single", "positions": idxs, "new_position": gi})
            else:
                gvecs = attn_out[:, idxs, :]
                ave = gvecs.mean(dim=1)
                averaged_vectors.append(ave)
                group_info.append({"type": "averaged", "positions": idxs, "new_position": gi})

            for p in idxs:
                position_mapping[p] = gi

        new_attn_out = torch.stack(averaged_vectors, dim=1)

        expanded_residual = torch.stack([
            (
                residual[:, info['positions'], :].sum(dim=1)
                if len(info['positions']) > 1
                else residual[:, info['positions'][0], :]
            )
            for info in group_info
        ], dim=1)

        hs = expanded_residual + new_attn_out
        grouped_hidden = self.post_attention_layernorm(hs)

        self.grouped_cache.grouped_positions = len(groups)
        self.grouped_cache.position_mapping = position_mapping
        self.grouped_cache.group_info = group_info
        self.grouped_cache.original_seq_length = hidden_states.shape[1]
        self.grouped_hidden_states = grouped_hidden

        self.is_initialized = True
        return hs


def create_model_with_custom_layer0(model_name: str = "Qwen/Qwen3-0.6B"):
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    if torch.cuda.is_available():
        device = torch.device("cuda")
        dtype = torch.float16
    else:
        device = torch.device("cpu")
        dtype = torch.float32

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=dtype,
        attn_implementation="eager"
    ).to(device)

    orig0 = model.model.layers[0]
    custom0 = CustomQwen3DecoderLayer(model.config, 0)

    custom0.mlp.load_state_dict(orig0.mlp.state_dict())
    custom0.input_layernorm.load_state_dict(orig0.input_layernorm.state_dict())
    custom0.post_attention_layernorm.load_state_dict(orig0.post_attention_layernorm.state_dict())
    custom0.self_attn.load_state_dict(orig0.self_attn.state_dict())

    custom0.self_attn.set_tokenizer(tokenizer)
    custom0 = custom0.to(device=device, dtype=dtype)
    model.model.layers[0] = custom0

    return model, tokenizer


class DatasetProcessor:
    def __init__(self,
                 model_name: str = "Qwen/Qwen3-0.6B",
                 dataset_name: str = "Magpie-Align/Magpie-Qwen2.5-Pro-1M-v0.1",
                 output_dir: str = "./processed_dataset",
                 batch_size: int = 1,
                 max_samples: Optional[int] = None,
                 save_frequency: int = 1000):

        self.model_name = model_name
        self.dataset_name = dataset_name
        self.output_dir = Path(output_dir)
        self.batch_size = batch_size
        self.max_samples = max_samples
        self.save_frequency = save_frequency

        self.output_dir.mkdir(parents=True, exist_ok=True)

        # System prompt template for Qwen3
        self.system_prompt = "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\n"
        self.response_start = "<|im_end|>\n<|im_start|>assistant\n"

        self.current_chunk = 0
        self.processed_data_buffer = []

    def load_dataset(self) -> Dataset:
        logger.info(f"Loading dataset: {self.dataset_name}")
        dataset = load_dataset(self.dataset_name, split="train")

        if self.max_samples:
            dataset = dataset.select(range(min(self.max_samples, len(dataset))))

        logger.info(f"Dataset loaded: {len(dataset)} samples")
        return dataset

    def format_input_text(self, instruction: str) -> str:
        return f"{self.system_prompt}{instruction}{self.response_start}"

    def process_embeddings_batch(self, model, tokenizer, texts: List[str]) -> List[torch.Tensor]:
        device = model.device
        embeddings_batch = []

        for text in texts:
            try:
                if hasattr(model.model.layers[0], "is_initialized"):
                    model.model.layers[0].is_initialized = False

                batch = tokenizer(text, return_tensors="pt", truncation=True, max_length=2048).to(device)
                input_ids = batch["input_ids"]

                if hasattr(model.model.layers[0], "self_attn"):
                    sat = model.model.layers[0].self_attn
                    if hasattr(sat, "set_current_input_ids"):
                        sat.set_current_input_ids(input_ids)

                with torch.no_grad():
                    inputs_embeds = model.model.embed_tokens(input_ids)
                    seq_len = inputs_embeds.shape[1]
                    position_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0)

                    if hasattr(model.model, 'rotary_emb'):
                        pos_embeds = model.model.rotary_emb(inputs_embeds, position_ids)
                    else:
                        pos_embeds = None

                    _ = model.model.layers[0](
                        hidden_states=inputs_embeds,
                        attention_mask=None,
                        position_ids=position_ids,
                        past_key_value=None,
                        output_attentions=False,
                        use_cache=False,
                        cache_position=None,
                        position_embeddings=pos_embeds,
                    )

                if (hasattr(model.model.layers[0], "grouped_hidden_states") and
                        model.model.layers[0].grouped_hidden_states is not None):
                    grouped_embeds = model.model.layers[0].grouped_hidden_states.clone().cpu()
                    embeddings_batch.append(grouped_embeds.squeeze(0))

                    model.model.layers[0].grouped_hidden_states = None
                else:
                    embeddings_batch.append(inputs_embeds.squeeze(0).cpu())

                del inputs_embeds, position_ids
                if pos_embeds is not None:
                    del pos_embeds
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

            except Exception as e:
                logger.warning(f"Error processing sample: {e}")
                embeddings_batch.append(torch.zeros(1, model.config.hidden_size))

        return embeddings_batch

    def save_chunk(self, chunk_data: List[Dict[str, Any]], chunk_id: int):
        if not chunk_data:
            return

        chunk_path = self.output_dir / f"processed_chunk_{chunk_id:04d}.pkl"
        with open(chunk_path, 'wb') as f:
            pickle.dump(chunk_data, f)

        # Clear memory
        del chunk_data
        import gc
        gc.collect()

    def merge_chunks(self) -> List[Dict[str, Any]]:
        logger.info("Merging chunks...")

        chunk_files = sorted(list(self.output_dir.glob("processed_chunk_*.pkl")))
        if not chunk_files:
            return []

        merged_data = []
        for chunk_file in tqdm(chunk_files, desc="Merging chunks"):
            try:
                with open(chunk_file, 'rb') as f:
                    chunk_data = pickle.load(f)
                    if isinstance(chunk_data, list):
                        merged_data.extend(chunk_data)
            except Exception as e:
                logger.error(f"Error loading chunk {chunk_file}: {e}")
                continue

        # Clean up chunk files
        self.cleanup_chunks()

        logger.info(f"Merged {len(chunk_files)} chunks into {len(merged_data)} samples")
        return merged_data

    def cleanup_chunks(self):
        chunk_files = list(self.output_dir.glob("processed_chunk_*.pkl"))
        for chunk_file in chunk_files:
            try:
                chunk_file.unlink()
            except Exception as e:
                logger.warning(f"Could not delete chunk {chunk_file}: {e}")

        if chunk_files:
            logger.info(f"Cleaned up {len(chunk_files)} temporary chunk files")

    def save_final_dataset(self, processed_data: List[Dict[str, Any]], stats: Dict[str, int]):
        pickle_path = self.output_dir / "processed_dataset.pkl"
        with open(pickle_path, 'wb') as f:
            pickle.dump(processed_data, f)

        error_samples = sum(1 for sample in processed_data if sample.get("error", False))
        successful_samples = len(processed_data) - error_samples

        metadata = {
            "model_name": self.model_name,
            "dataset_name": self.dataset_name,
            "total_samples": stats["total_samples"],
            "processed_samples": len(processed_data),
            "successful_samples": successful_samples,
            "error_samples": error_samples,
            "batch_size": self.batch_size,
            "max_samples": self.max_samples,
            "success_rate": f"{(successful_samples / len(processed_data) * 100):.2f}%" if processed_data else "0%"
        }

        with open(self.output_dir / "metadata.json", 'w', encoding='utf-8') as f:
            json.dump(metadata, f, indent=2, ensure_ascii=False)

        text_samples = []
        count = 0
        for i, sample in enumerate(processed_data):
            if not sample.get("error", False) and count < 10:
                text_samples.append({
                    "sample_id": i,
                    "input_text": sample["input_text"][:300] + "..." if len(sample["input_text"]) > 300 else sample["input_text"],
                    "response": sample["response"][:300] + "..." if len(sample["response"]) > 300 else sample["response"],
                    "embedding_shape": sample["embedding_shape"]
                })
                count += 1

        with open(self.output_dir / "samples.json", 'w', encoding='utf-8') as f:
            json.dump(text_samples, f, indent=2, ensure_ascii=False)

        logger.info(f"Dataset saved: {len(processed_data)} samples")
        logger.info(f"Success rate: {metadata['success_rate']}")

    def process_dataset(self):
        dataset = self.load_dataset()

        logger.info("Loading model...")
        model, tokenizer = create_model_with_custom_layer0(self.model_name)

        total_samples = len(dataset)
        processed_count = 0
        error_count = 0

        logger.info(f"Processing {total_samples} samples...")

        for i in tqdm(range(0, total_samples, self.batch_size), desc="Processing"):
            batch_end = min(i + self.batch_size, total_samples)
            batch_samples = dataset.select(range(i, batch_end))

            batch_texts = []
            batch_instructions = []
            batch_responses = []

            try:
                for sample in batch_samples:
                    instruction = sample.get("instruction", "")
                    response = sample.get("response", "")

                    if not instruction.strip() or not response.strip():
                        instruction = "Empty instruction"
                        response = "Empty response"

                    input_text = self.format_input_text(instruction)
                    batch_texts.append(input_text)
                    batch_instructions.append(input_text)
                    batch_responses.append(response)

                embeddings_batch = self.process_embeddings_batch(model, tokenizer, batch_texts)

                for j, (input_text, embedding, response) in enumerate(zip(batch_instructions, embeddings_batch, batch_responses)):
                    processed_sample = {
                        "input_text": input_text,
                        "inputs_embeds": embedding,
                        "response": response,
                        "embedding_shape": list(embedding.shape),
                        "original_index": i + j
                    }
                    self.processed_data_buffer.append(processed_sample)
                    processed_count += 1

                if len(self.processed_data_buffer) >= self.save_frequency:
                    self.save_chunk(self.processed_data_buffer, self.current_chunk)
                    self.processed_data_buffer = []
                    self.current_chunk += 1

                    import gc
                    gc.collect()

            except Exception as e:
                logger.error(f"Error processing batch: {e}")
                error_count += len(batch_samples)

        if self.processed_data_buffer:
            self.save_chunk(self.processed_data_buffer, self.current_chunk)
            self.processed_data_buffer = []

        merged_data = self.merge_chunks()

        stats = {
            "total_samples": total_samples,
            "processed_count": processed_count,
            "error_count": error_count
        }

        self.save_final_dataset(merged_data, stats)
        return merged_data


def load_processed_dataset(dataset_path: str) -> List[Dict[str, Any]]:
    pickle_path = Path(dataset_path) / "processed_dataset.pkl"
    with open(pickle_path, 'rb') as f:
        return pickle.load(f)


def get_dataset_info(dataset_path: str) -> Dict:
    metadata_path = Path(dataset_path) / "metadata.json"
    with open(metadata_path, 'r') as f:
        return json.load(f)


def main():
    model_name = "Qwen/Qwen3-0.6B"
    dataset_name = "Magpie-Align/Magpie-Qwen2.5-Pro-1M-v0.1"
    output_dir = "./processed_qwen3_dataset"
    batch_size = 1
    max_samples = 10000  # Set to number for testing, None for full dataset
    save_frequency = 1000

    logger.info("Starting Qwen3 dataset processing...")
    logger.info(f"Model: {model_name}")
    logger.info(f"Dataset: {dataset_name}")
    logger.info(f"Output: {output_dir}")
    logger.info(f"Max samples: {max_samples or 'ALL'}")

    try:
        processor = DatasetProcessor(
            model_name=model_name,
            dataset_name=dataset_name,
            output_dir=output_dir,
            batch_size=batch_size,
            max_samples=max_samples,
            save_frequency=save_frequency
        )

        processed_data = processor.process_dataset()

        logger.info("Processing completed successfully!")
        logger.info(f"Final dataset: {len(processed_data)} samples")
        logger.info(f"Files saved to: {output_dir}")

        return processed_data

    except Exception as e:
        logger.error(f"Processing failed: {e}")
        raise


if __name__ == "__main__":
    main()
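The helper functions load_processed_dataset and get_dataset_info above can be used to sanity-check a finished preprocessing run. A minimal sketch, assuming the default ./processed_qwen3_dataset output directory from main() and that the file is importable as prepare_dataset (the module name mirrors the file name in this repository):

# Illustrative only: inspect the output of prepare_dataset.py after a run.
from prepare_dataset import load_processed_dataset, get_dataset_info

info = get_dataset_info("./processed_qwen3_dataset")
print(info["processed_samples"], info["success_rate"])

data = load_processed_dataset("./processed_qwen3_dataset")
first = data[0]
print(first["input_text"][:80])
print(first["inputs_embeds"].shape)  # grouped sequence length x hidden size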
processed_qwen3_dataset/metadata.json
ADDED
@@ -0,0 +1,11 @@
{
  "model_name": "Qwen/Qwen3-0.6B",
  "dataset_name": "Magpie-Align/Magpie-Qwen2.5-Pro-1M-v0.1",
  "total_samples": 10000,
  "processed_samples": 10000,
  "successful_samples": 10000,
  "error_samples": 0,
  "batch_size": 1,
  "max_samples": 10000,
  "success_rate": "100.00%"
}
processed_qwen3_dataset/processed_dataset.pkl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:696e5df891b54ac7ea3f1b1c3a69ae24f8a7e0b8e96600e4105585b8882af2eb
size 1029880114
processed_qwen3_dataset/samples.json
ADDED
@@ -0,0 +1,92 @@
[
  {
    "sample_id": 0,
    "input_text": "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\nCan you tell me about the latest developments in AI research that could impact the field of healthcare?<|im_end|>\n<|im_start|>assistant\n",
    "response": "Certainly! The intersection of AI and healthcare is a rapidly evolving field with numerous recent developments that hold significant promise. Here are some of the latest advancements:\n\n1. **AI in Diagnostics**:\n - **Imaging Analysis**: AI algorithms, particularly deep learning models, are becoming...",
    "embedding_shape": [
      47,
      1024
    ]
  },
  {
    "sample_id": 1,
    "input_text": "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\nWhat is the most effective way to learn a new language, especially for someone who is already an adult and has a busy schedule?<|im_end|>\n<|im_start|>assistant\n",
    "response": "Learning a new language as an adult with a busy schedule can be challenging, but it's definitely achievable with the right strategies and consistent effort. Here are some effective methods to help you learn a new language efficiently:\n\n1. **Set Clear Goals**: Define what you want to achieve with the...",
    "embedding_shape": [
      54,
      1024
    ]
  },
  {
    "sample_id": 2,
    "input_text": "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\nI need to create a marketing plan for a new line of organic skincare products. What are the key elements I should consider?<|im_end|>\n<|im_start|>assistant\n",
    "response": "Creating a marketing plan for a new line of organic skincare products involves several key elements to ensure the successful launch and sustained growth of your brand. Here's a comprehensive guide to help you develop an effective marketing plan:\n\n### 1. **Market Research**\n - **Target Audience**: ...",
    "embedding_shape": [
      53,
      1024
    ]
  },
  {
    "sample_id": 3,
    "input_text": "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\nWhat is the role of a project manager in a software development team?<|im_end|>\n<|im_start|>assistant\n",
    "response": "The role of a project manager in a software development team is multifaceted and crucial for the success of the project. Here are some of the key responsibilities and tasks typically associated with this role:\n\n1. **Project Planning**: The project manager is responsible for defining the scope of the...",
    "embedding_shape": [
      42,
      1024
    ]
  },
  {
    "sample_id": 4,
    "input_text": "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\nI need to create a social media campaign for a new line of eco-friendly products. Can you help me with some ideas?<|im_end|>\n<|im_start|>assistant\n",
    "response": "Absolutely, I'd be happy to help you brainstorm some ideas for a social media campaign for your new line of eco-friendly products! Here are a few strategies and ideas to consider:\n\n### 1. **Eco-Challenges**\n- **30-Day Eco-Challenge:** Encourage followers to adopt one eco-friendly habit each day for ...",
    "embedding_shape": [
      52,
      1024
    ]
  },
  {
    "sample_id": 5,
    "input_text": "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\nCan you explain the concept of \"dynamic equilibrium\" in chemistry? And could you provide an example to illustrate this concept?<|im_end|>\n<|im_start|>assistant\n",
    "response": "Certainly! The concept of \"dynamic equilibrium\" in chemistry refers to a state in which the concentrations of the reactants and products in a reversible reaction remain constant over time, despite the fact that the forward and reverse reactions are still occurring at the same rate. This means that w...",
    "embedding_shape": [
      51,
      1024
    ]
  },
  {
    "sample_id": 6,
    "input_text": "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\nWhat is the best way to create a budget for a small business?<|im_end|>\n<|im_start|>assistant\n",
    "response": "Creating a budget for a small business is a crucial step in managing finances effectively and ensuring the long-term success of the business. Here's a step-by-step guide to help you create a comprehensive and effective budget:\n\n### 1. **Define Your Business Goals**\n - **Short-term and Long-term Go...",
    "embedding_shape": [
      42,
      1024
    ]
  },
  {
    "sample_id": 7,
    "input_text": "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\nCan you provide a brief overview of Alibaba Cloud's security measures and how they protect user data?<|im_end|>\n<|im_start|>assistant\n",
    "response": "Certainly! Alibaba Cloud places a strong emphasis on security and has implemented a comprehensive set of measures to protect user data and ensure the integrity, confidentiality, and availability of its services. Here's a brief overview of some key security measures:\n\n1. **Data Encryption**: Alibaba ...",
    "embedding_shape": [
      47,
      1024
    ]
  },
  {
    "sample_id": 8,
    "input_text": "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\nI am interested in learning more about the cultural significance of the Chinese zodiac. Could you tell me about its origin and how it is used in Chinese culture today?<|im_end|>\n<|im_st...",
    "response": "Certainly! The Chinese zodiac, known as \"Sheng Xiao\" (生肖) in Chinese, is a classification scheme that assigns an animal and its reputed attributes to each year in a repeating 12-year cycle. The 12 animals of the Chinese zodiac are the Rat, Ox, Tiger, Rabbit, Dragon, Snake, Horse, Goat, Monkey, Roost...",
    "embedding_shape": [
      60,
      1024
    ]
  },
  {
    "sample_id": 9,
    "input_text": "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\nHow can I use a Virtual Private Cloud (VPC) to enhance the security of my cloud applications?<|im_end|>\n<|im_start|>assistant\n",
    "response": "Using a Virtual Private Cloud (VPC) can significantly enhance the security of your cloud applications by providing a secure and isolated network environment. Here are several ways you can leverage a VPC to improve security:\n\n### 1. **Network Isolation**\n - **Private Network**: A VPC allows you to ...",
    "embedding_shape": [
      48,
      1024
    ]
  }
]
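samples.json is a human-readable preview of the processed dataset; each embedding_shape is [grouped_sequence_length, hidden_size]. A small, illustrative snippet for skimming it, assuming the repository layout above (the file path is the only assumption):

# Quick sanity check over samples.json.
import json
from pathlib import Path

samples = json.loads(Path("processed_qwen3_dataset/samples.json").read_text(encoding="utf-8"))
for s in samples:
    grouped_len, hidden = s["embedding_shape"]
    print(f"sample {s['sample_id']}: {grouped_len} grouped positions, hidden size {hidden}")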
train_custom_qwen3.py
ADDED
@@ -0,0 +1,1292 @@
import os
import sys
import logging
import json
import pickle
from typing import Optional, Tuple, List, Dict, Any, Union
from pathlib import Path
from tqdm import tqdm
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR

from datasets import load_dataset, Dataset as HFDataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    get_linear_schedule_with_warmup,
    PreTrainedModel
)
from transformers.models.qwen3.configuration_qwen3 import Qwen3Config
from transformers.models.qwen3.modeling_qwen3 import (
    Qwen3Model,
    Qwen3ForCausalLM,
    Qwen3PreTrainedModel,
    Qwen3RMSNorm
)

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    handlers=[logging.StreamHandler(sys.stdout)],
    force=True,
)
logger = logging.getLogger("grouped_qwen3_training")


class GroupedInputMLPAdapter(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        hidden_size = config.hidden_size

        self.grouped_processor = nn.Sequential(
            nn.Linear(hidden_size, hidden_size * 2),
            nn.SiLU(),  # Using SiLU activation like Qwen3
            nn.Dropout(0.1),
            nn.Linear(hidden_size * 2, hidden_size),
            nn.Dropout(0.1)
        )

        norm_eps = getattr(config, 'rms_norm_eps', 1e-6)
        self.layer_norm = Qwen3RMSNorm(hidden_size, eps=norm_eps)

    def forward(self, grouped_embeds: torch.Tensor) -> torch.Tensor:
        processed = self.grouped_processor(grouped_embeds)

        output = self.layer_norm(grouped_embeds + processed)

        return output

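# A minimal smoke-test sketch for GroupedInputMLPAdapter (an illustration, not part
# of the training flow). It assumes Qwen/Qwen3-0.6B is reachable via transformers;
# that model's hidden_size of 1024 matches the embedding_shape entries in samples.json.
def _adapter_smoke_test():
    config = Qwen3Config.from_pretrained("Qwen/Qwen3-0.6B")
    adapter = GroupedInputMLPAdapter(config)
    grouped = torch.randn(2, 51, config.hidden_size)  # [batch, grouped_seq_len, hidden]
    out = adapter(grouped)  # residual MLP projection followed by RMSNorm, shape preserved
    assert out.shape == grouped.shape
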
| 67 |
+
class CustomQwen3ForCausalLM(Qwen3ForCausalLM):
|
| 68 |
+
def __init__(self, config):
|
| 69 |
+
super().__init__(config)
|
| 70 |
+
|
| 71 |
+
self.grouped_input_mlp = GroupedInputMLPAdapter(config)
|
| 72 |
+
|
| 73 |
+
self.is_grouped_input_mode = False
|
| 74 |
+
self.grouped_cache_initialized = False
|
| 75 |
+
|
| 76 |
+
self._init_grouped_weights()
|
| 77 |
+
|
| 78 |
+
self._freeze_layers()
|
| 79 |
+
|
| 80 |
+
def _init_grouped_weights(self):
|
| 81 |
+
def _init_weights(module):
|
| 82 |
+
if isinstance(module, nn.Linear):
|
| 83 |
+
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
| 84 |
+
if module.bias is not None:
|
| 85 |
+
torch.nn.init.zeros_(module.bias)
|
| 86 |
+
elif isinstance(module, nn.LayerNorm):
|
| 87 |
+
torch.nn.init.ones_(module.weight)
|
| 88 |
+
torch.nn.init.zeros_(module.bias)
|
| 89 |
+
|
| 90 |
+
self.grouped_input_mlp.apply(_init_weights)
|
| 91 |
+
|
| 92 |
+
def _freeze_layers(self):
|
| 93 |
+
for param in self.model.embed_tokens.parameters():
|
| 94 |
+
param.requires_grad = False
|
| 95 |
+
|
| 96 |
+
for i, layer in enumerate(self.model.layers):
|
| 97 |
+
if i == 0:
|
| 98 |
+
for param in layer.parameters():
|
| 99 |
+
param.requires_grad = True
|
| 100 |
+
else:
|
| 101 |
+
for param in layer.parameters():
|
| 102 |
+
param.requires_grad = False
|
| 103 |
+
|
| 104 |
+
for param in self.model.norm.parameters():
|
| 105 |
+
param.requires_grad = False
|
| 106 |
+
|
| 107 |
+
for param in self.lm_head.parameters():
|
| 108 |
+
param.requires_grad = False
|
| 109 |
+
|
| 110 |
+
for param in self.grouped_input_mlp.parameters():
|
| 111 |
+
param.requires_grad = True
|
| 112 |
+
|
| 113 |
+
trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
|
| 114 |
+
total_params = sum(p.numel() for p in self.parameters())
|
| 115 |
+
logger.info(f"Trainable parameters: {trainable_params:,} / {total_params:,} "
|
| 116 |
+
f"({trainable_params/total_params*100:.2f}%)")
|
| 117 |
+
|
| 118 |
+
def forward(
|
| 119 |
+
self,
|
| 120 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 121 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 122 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 123 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
| 124 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 125 |
+
labels: Optional[torch.LongTensor] = None,
|
| 126 |
+
use_cache: Optional[bool] = None,
|
| 127 |
+
output_attentions: Optional[bool] = None,
|
| 128 |
+
output_hidden_states: Optional[bool] = None,
|
| 129 |
+
return_dict: Optional[bool] = None,
|
| 130 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 131 |
+
grouped_inputs: Optional[torch.FloatTensor] = None, # New parameter for grouped inputs
|
| 132 |
+
is_prefill: Optional[bool] = None, # Flag to indicate prefill phase
|
| 133 |
+
**kwargs
|
| 134 |
+
):
|
| 135 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 136 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
| 137 |
+
|
| 138 |
+
if grouped_inputs is not None and is_prefill:
|
| 139 |
+
self.is_grouped_input_mode = True
|
| 140 |
+
|
| 141 |
+
processed_grouped_inputs = self.grouped_input_mlp(grouped_inputs)
|
| 142 |
+
|
| 143 |
+
inputs_embeds = processed_grouped_inputs
|
| 144 |
+
input_ids = None # Don't use input_ids when we have grouped inputs
|
| 145 |
+
|
| 146 |
+
batch_size, seq_len = inputs_embeds.shape[:2]
|
| 147 |
+
if position_ids is None:
|
| 148 |
+
device = inputs_embeds.device
|
| 149 |
+
position_ids = torch.arange(seq_len, device=device, dtype=torch.long)
|
| 150 |
+
position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
|
| 151 |
+
|
| 152 |
+
if attention_mask is None:
|
| 153 |
+
attention_mask = torch.ones((batch_size, seq_len), device=inputs_embeds.device, dtype=torch.long)
|
| 154 |
+
|
| 155 |
+
self.grouped_cache_initialized = True
|
| 156 |
+
|
| 157 |
+
elif not is_prefill and self.is_grouped_input_mode:
|
| 158 |
+
pass
|
| 159 |
+
else:
|
| 160 |
+
self.is_grouped_input_mode = False
|
| 161 |
+
|
| 162 |
+
# Call parent forward
|
| 163 |
+
outputs = super().forward(
|
| 164 |
+
input_ids=input_ids,
|
| 165 |
+
attention_mask=attention_mask,
|
| 166 |
+
position_ids=position_ids,
|
| 167 |
+
past_key_values=past_key_values,
|
| 168 |
+
inputs_embeds=inputs_embeds,
|
| 169 |
+
labels=labels,
|
| 170 |
+
use_cache=use_cache,
|
| 171 |
+
output_attentions=output_attentions,
|
| 172 |
+
output_hidden_states=output_hidden_states,
|
| 173 |
+
return_dict=return_dict,
|
| 174 |
+
cache_position=cache_position,
|
| 175 |
+
**kwargs
|
| 176 |
+
)
|
| 177 |
+
|
| 178 |
+
return outputs
|
| 179 |
+
|
| 180 |
+
class GroupedDataset(Dataset):
|
| 181 |
+
def __init__(self, dataset_path: str, tokenizer, max_response_length: int = 512,
|
| 182 |
+
validation_split: float = 0.1, is_validation: bool = False,
|
| 183 |
+
chunk_size: int = 1000, max_samples: Optional[int] = None):
|
| 184 |
+
self.dataset_path = dataset_path
|
| 185 |
+
self.tokenizer = tokenizer
|
| 186 |
+
self.max_response_length = max_response_length
|
| 187 |
+
self.validation_split = validation_split
|
| 188 |
+
self.is_validation = is_validation
|
| 189 |
+
self.chunk_size = chunk_size
|
| 190 |
+
self.max_samples = max_samples
|
| 191 |
+
|
| 192 |
+
self._chunk_cache = {}
|
| 193 |
+
self._cache_size_limit = 3 # Keep max 3 chunks in memory
|
| 194 |
+
|
| 195 |
+
self._build_index()
|
| 196 |
+
|
| 197 |
+
def _build_index(self):
|
| 198 |
+
logger.info(f"Building index for {self.dataset_path}")
|
| 199 |
+
|
| 200 |
+
with open(self.dataset_path, 'rb') as f:
|
| 201 |
+
data = pickle.load(f)
|
| 202 |
+
|
| 203 |
+
valid_indices = []
|
| 204 |
+
for i, item in enumerate(data):
|
| 205 |
+
if not item.get("error", False):
|
| 206 |
+
valid_indices.append(i)
|
| 207 |
+
|
| 208 |
+
if self.max_samples and len(valid_indices) >= self.max_samples:
|
| 209 |
+
break
|
| 210 |
+
|
| 211 |
+
total_valid = len(valid_indices)
|
| 212 |
+
|
| 213 |
+
val_size = min(1000, int(self.validation_split * total_valid))
|
| 214 |
+
train_size = total_valid - val_size
|
| 215 |
+
|
| 216 |
+
if self.is_validation:
|
| 217 |
+
self.valid_indices = valid_indices[train_size:train_size + val_size]
|
| 218 |
+
self.total_samples = val_size
|
| 219 |
+
else:
|
| 220 |
+
self.valid_indices = valid_indices[:train_size]
|
| 221 |
+
self.total_samples = train_size
|
| 222 |
+
|
| 223 |
+
self._full_data = data
|
| 224 |
+
|
| 225 |
+
logger.info(f"{'Validation' if self.is_validation else 'Training'} dataset: {self.total_samples} samples")
|
| 226 |
+
|
| 227 |
+
def _get_chunk_id(self, idx):
|
| 228 |
+
return idx // self.chunk_size
|
| 229 |
+
|
| 230 |
+
def _load_chunk(self, chunk_id):
|
| 231 |
+
if chunk_id in self._chunk_cache:
|
| 232 |
+
return self._chunk_cache[chunk_id]
|
| 233 |
+
|
| 234 |
+
start_idx = chunk_id * self.chunk_size
|
| 235 |
+
end_idx = min(start_idx + self.chunk_size, self.total_samples)
|
| 236 |
+
|
| 237 |
+
chunk_data = {}
|
| 238 |
+
for i in range(start_idx, end_idx):
|
| 239 |
+
actual_idx = self.valid_indices[i]
|
| 240 |
+
chunk_data[i] = self._full_data[actual_idx]
|
| 241 |
+
|
| 242 |
+
if len(self._chunk_cache) >= self._cache_size_limit:
|
| 243 |
+
oldest_chunk = min(self._chunk_cache.keys())
|
| 244 |
+
del self._chunk_cache[oldest_chunk]
|
| 245 |
+
|
| 246 |
+
self._chunk_cache[chunk_id] = chunk_data
|
| 247 |
+
return chunk_data
|
| 248 |
+
|
| 249 |
+
def __len__(self):
|
| 250 |
+
return self.total_samples
|
| 251 |
+
|
| 252 |
+
def __getitem__(self, idx):
|
| 253 |
+
if idx >= self.total_samples:
|
| 254 |
+
raise IndexError(f"Index {idx} out of range for dataset of size {self.total_samples}")
|
| 255 |
+
|
| 256 |
+
chunk_id = self._get_chunk_id(idx)
|
| 257 |
+
chunk_data = self._load_chunk(chunk_id)
|
| 258 |
+
item = chunk_data[idx]
|
| 259 |
+
|
| 260 |
+
return self._process_item(item)
|
| 261 |
+
|
| 262 |
+
def _process_item(self, item):
|
| 263 |
+
grouped_embeds = item["inputs_embeds"]
|
| 264 |
+
if isinstance(grouped_embeds, torch.Tensor):
|
| 265 |
+
grouped_embeds = grouped_embeds.clone()
|
| 266 |
+
else:
|
| 267 |
+
grouped_embeds = torch.tensor(grouped_embeds)
|
| 268 |
+
|
| 269 |
+
if grouped_embeds.dtype != torch.float32:
|
| 270 |
+
grouped_embeds = grouped_embeds.float()
|
| 271 |
+
|
| 272 |
+
response = item["response"]
|
| 273 |
+
|
| 274 |
+
response_tokens = self.tokenizer(
|
| 275 |
+
response,
|
| 276 |
+
max_length=self.max_response_length,
|
| 277 |
+
truncation=True,
|
| 278 |
+
padding=False,
|
| 279 |
+
return_tensors="pt"
|
| 280 |
+
)
|
| 281 |
+
|
| 282 |
+
response_input_ids = response_tokens["input_ids"].squeeze(0)
|
| 283 |
+
|
| 284 |
+
return {
|
| 285 |
+
"grouped_inputs": grouped_embeds,
|
| 286 |
+
"response_input_ids": response_input_ids,
|
| 287 |
+
"response_text": response,
|
| 288 |
+
"input_text": item["input_text"],
|
| 289 |
+
}
|
| 290 |
+
|
| 291 |
+
def cleanup(self):
|
| 292 |
+
self._chunk_cache.clear()
|
| 293 |
+
if hasattr(self, '_full_data'):
|
| 294 |
+
del self._full_data
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
def collate_fn(batch, tokenizer, pad_token_id=None):
|
| 298 |
+
if pad_token_id is None:
|
| 299 |
+
pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
|
| 300 |
+
|
| 301 |
+
grouped_inputs = [item["grouped_inputs"] for item in batch]
|
| 302 |
+
response_input_ids = [item["response_input_ids"] for item in batch]
|
| 303 |
+
|
| 304 |
+
max_grouped_len = max(gi.shape[0] for gi in grouped_inputs)
|
| 305 |
+
batch_size = len(grouped_inputs)
|
| 306 |
+
hidden_size = grouped_inputs[0].shape[-1]
|
| 307 |
+
|
| 308 |
+
padded_grouped_inputs = torch.zeros(batch_size, max_grouped_len, hidden_size)
|
| 309 |
+
grouped_attention_mask = torch.zeros(batch_size, max_grouped_len, dtype=torch.long)
|
| 310 |
+
|
| 311 |
+
for i, gi in enumerate(grouped_inputs):
|
| 312 |
+
seq_len = gi.shape[0]
|
| 313 |
+
padded_grouped_inputs[i, :seq_len] = gi
|
| 314 |
+
grouped_attention_mask[i, :seq_len] = 1
|
| 315 |
+
|
| 316 |
+
max_response_len = max(len(rid) for rid in response_input_ids)
|
| 317 |
+
padded_response_ids = torch.full((batch_size, max_response_len), pad_token_id, dtype=torch.long)
|
| 318 |
+
|
| 319 |
+
for i, rid in enumerate(response_input_ids):
|
| 320 |
+
padded_response_ids[i, :len(rid)] = rid
|
| 321 |
+
|
| 322 |
+
return {
|
| 323 |
+
"grouped_inputs": padded_grouped_inputs,
|
| 324 |
+
"grouped_attention_mask": grouped_attention_mask,
|
| 325 |
+
"response_input_ids": padded_response_ids,
|
| 326 |
+
"response_texts": [item["response_text"] for item in batch],
|
| 327 |
+
"input_texts": [item["input_text"] for item in batch],
|
| 328 |
+
}
|
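# Shape sketch for one collate_fn batch of two items (assuming hidden_size 1024 as in
# Qwen3-0.6B and grouped lengths 51 and 42, taken from samples.json above):
#   grouped_inputs:         float tensor [2, 51, 1024], zero-padded to the longest item
#   grouped_attention_mask: long tensor  [2, 51], 1 for real positions and 0 for padding
#   response_input_ids:     long tensor  [2, max_response_len], padded with the pad/eos id
#   response_texts, input_texts: lists of the raw strings for logging and debugging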
| 329 |
+
|
| 330 |
+
class TrainingState:
|
| 331 |
+
def __init__(self, output_dir: Path):
|
| 332 |
+
self.output_dir = output_dir
|
| 333 |
+
self.state_file = output_dir / "training_state.json"
|
| 334 |
+
|
| 335 |
+
def save_state(self, epoch: int, global_step: int, best_val_loss: float,
|
| 336 |
+
optimizer_state: Dict, scheduler_state: Dict):
|
| 337 |
+
"""Save training state."""
|
| 338 |
+
state = {
|
| 339 |
+
"epoch": epoch,
|
| 340 |
+
"global_step": global_step,
|
| 341 |
+
"best_val_loss": best_val_loss,
|
| 342 |
+
"optimizer_state": optimizer_state,
|
| 343 |
+
"scheduler_state": scheduler_state,
|
| 344 |
+
"completed_epochs": epoch
|
| 345 |
+
}
|
| 346 |
+
|
| 347 |
+
with open(self.state_file, 'w') as f:
|
| 348 |
+
json.dump(state, f, indent=2, default=str) # default=str for handling tensor types
|
| 349 |
+
|
| 350 |
+
logger.info(f"Saved training state at epoch {epoch}, step {global_step}")
|
| 351 |
+
|
| 352 |
+
def load_state(self):
|
| 353 |
+
if not self.state_file.exists():
|
| 354 |
+
return None
|
| 355 |
+
|
| 356 |
+
try:
|
| 357 |
+
with open(self.state_file, 'r') as f:
|
| 358 |
+
state = json.load(f)
|
| 359 |
+
logger.info(f"Loaded training state from epoch {state['epoch']}, step {state['global_step']}")
|
| 360 |
+
return state
|
| 361 |
+
except Exception as e:
|
| 362 |
+
logger.warning(f"Failed to load training state: {e}")
|
| 363 |
+
return None
|
| 364 |
+
|
| 365 |
+
def get_latest_checkpoint(self):
|
| 366 |
+
state = self.load_state()
|
| 367 |
+
if state is None:
|
| 368 |
+
return None
|
| 369 |
+
|
| 370 |
+
epoch = state["completed_epochs"]
|
| 371 |
+
checkpoint_path = self.output_dir / f"epoch_{epoch}"
|
| 372 |
+
|
| 373 |
+
if checkpoint_path.exists():
|
| 374 |
+
return checkpoint_path, state
|
| 375 |
+
else:
|
| 376 |
+
logger.warning(f"Checkpoint for epoch {epoch} not found")
|
| 377 |
+
return None
|
| 378 |
+
|
| 379 |
+
class GroupedTrainer:
|
| 380 |
+
def __init__(
|
| 381 |
+
self,
|
| 382 |
+
model_name: str = "Qwen/Qwen3-0.6B",
|
| 383 |
+
dataset_path: str = "./processed_qwen3_dataset/processed_dataset.pkl",
|
| 384 |
+
output_dir: str = "./grouped_qwen3_checkpoint",
|
| 385 |
+
batch_size: int = 4,
|
| 386 |
+
learning_rate: float = 5e-4,
|
| 387 |
+
num_epochs: int = 3,
|
| 388 |
+
warmup_steps: int = 100,
|
| 389 |
+
max_grad_norm: float = 1.0,
|
| 390 |
+
save_steps: int = 500,
|
| 391 |
+
eval_steps: int = 500,
|
| 392 |
+
logging_steps: int = 50,
|
| 393 |
+
resume_training: bool = True,
|
| 394 |
+
debug: bool = False,
|
| 395 |
+
chunk_size: int = 1000, # Chunk size for streaming
|
| 396 |
+
max_samples: Optional[int] = None, # Limit dataset size for testing
|
| 397 |
+
):
|
| 398 |
+
self.model_name = model_name
|
| 399 |
+
self.dataset_path = dataset_path
|
| 400 |
+
self.output_dir = Path(output_dir)
|
| 401 |
+
self.batch_size = batch_size
|
| 402 |
+
self.learning_rate = learning_rate
|
| 403 |
+
self.num_epochs = num_epochs
|
| 404 |
+
self.warmup_steps = warmup_steps
|
| 405 |
+
self.max_grad_norm = max_grad_norm
|
| 406 |
+
self.save_steps = save_steps
|
| 407 |
+
self.eval_steps = eval_steps
|
| 408 |
+
self.logging_steps = logging_steps
|
| 409 |
+
self.resume_training = resume_training
|
| 410 |
+
self.debug = debug
|
| 411 |
+
self.chunk_size = chunk_size
|
| 412 |
+
self.max_samples = max_samples
|
| 413 |
+
|
| 414 |
+
if self.debug:
|
| 415 |
+
logger.setLevel(logging.DEBUG)
|
| 416 |
+
|
| 417 |
+
self.output_dir.mkdir(parents=True, exist_ok=True)
|
| 418 |
+
|
| 419 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 420 |
+
self.tokenizer = None
|
| 421 |
+
self.model = None
|
| 422 |
+
self.train_dataset = None
|
| 423 |
+
self.val_dataset = None
|
| 424 |
+
|
| 425 |
+
self.training_state = TrainingState(self.output_dir)
|
| 426 |
+
|
| 427 |
+
def load_model_and_tokenizer(self):
|
| 428 |
+
logger.info(f"Loading tokenizer and model: {self.model_name}")
|
| 429 |
+
|
| 430 |
+
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
|
| 431 |
+
|
| 432 |
+
if self.tokenizer.pad_token is None:
|
| 433 |
+
self.tokenizer.pad_token = self.tokenizer.eos_token
|
| 434 |
+
|
| 435 |
+
config = Qwen3Config.from_pretrained(self.model_name)
|
| 436 |
+
|
| 437 |
+
self.model = CustomQwen3ForCausalLM.from_pretrained(
|
| 438 |
+
self.model_name,
|
| 439 |
+
config=config,
|
| 440 |
+
torch_dtype=torch.float32, # Use float32 for training
|
| 441 |
+
attn_implementation="eager"
|
| 442 |
+
).to(self.device)
|
| 443 |
+
|
| 444 |
+
logger.info(f"Model loaded on {self.device}")
|
| 445 |
+
|
| 446 |
+
def load_dataset(self, chunk_size: int = 1000, max_samples: Optional[int] = None):
|
| 447 |
+
logger.info(f"Loading streaming dataset from {self.dataset_path}")
|
| 448 |
+
|
| 449 |
+
# Create streaming datasets
|
| 450 |
+
self.train_dataset = GroupedDataset(
|
| 451 |
+
dataset_path=self.dataset_path,
|
| 452 |
+
tokenizer=self.tokenizer,
|
| 453 |
+
is_validation=False,
|
| 454 |
+
chunk_size=chunk_size,
|
| 455 |
+
max_samples=max_samples
|
| 456 |
+
)
|
| 457 |
+
|
| 458 |
+
self.val_dataset = GroupedDataset(
|
| 459 |
+
dataset_path=self.dataset_path,
|
| 460 |
+
tokenizer=self.tokenizer,
|
| 461 |
+
is_validation=True,
|
| 462 |
+
chunk_size=chunk_size,
|
| 463 |
+
max_samples=max_samples
|
| 464 |
+
)
|
| 465 |
+
|
| 466 |
+
logger.info(f"Train samples: {len(self.train_dataset)}")
|
| 467 |
+
logger.info(f"Val samples: {len(self.val_dataset)}")
|
| 468 |
+
|
| 469 |
+
# Log memory usage
|
| 470 |
+
if torch.cuda.is_available():
|
| 471 |
+
torch.cuda.empty_cache()
|
| 472 |
+
memory_used = torch.cuda.memory_allocated() / 1024**3
|
| 473 |
+
logger.info(f"GPU memory after dataset loading: {memory_used:.2f} GB")
|
| 474 |
+
|
| 475 |
+
def cleanup_datasets(self):
|
| 476 |
+
if hasattr(self.train_dataset, 'cleanup'):
|
| 477 |
+
self.train_dataset.cleanup()
|
| 478 |
+
if hasattr(self.val_dataset, 'cleanup'):
|
| 479 |
+
self.val_dataset.cleanup()
|
| 480 |
+
|
| 481 |
+
import gc
|
| 482 |
+
gc.collect()
|
| 483 |
+
|
| 484 |
+
if torch.cuda.is_available():
|
| 485 |
+
torch.cuda.empty_cache()
|
| 486 |
+
|
| 487 |
+
def load_checkpoint(self, checkpoint_path: Path):
|
| 488 |
+
logger.info(f"Loading checkpoint from {checkpoint_path}")
|
| 489 |
+
|
| 490 |
+
model_path = checkpoint_path / "pytorch_model.bin"
|
| 491 |
+
if not model_path.exists():
|
| 492 |
+
model_path = checkpoint_path / "model.safetensors"
|
| 493 |
+
|
| 494 |
+
if model_path.exists():
|
| 495 |
+
state_dict = torch.load(model_path, map_location=self.device)
|
| 496 |
+
missing_keys, unexpected_keys = self.model.load_state_dict(state_dict, strict=False)
|
| 497 |
+
|
| 498 |
+
if missing_keys:
|
| 499 |
+
logger.warning(f"Missing keys when loading checkpoint: {missing_keys}")
|
| 500 |
+
if unexpected_keys:
|
| 501 |
+
logger.warning(f"Unexpected keys when loading checkpoint: {unexpected_keys}")
|
| 502 |
+
|
| 503 |
+
logger.info("Model checkpoint loaded successfully")
|
| 504 |
+
else:
|
| 505 |
+
raise FileNotFoundError(f"Model checkpoint not found at {checkpoint_path}")
|
| 506 |
+
|
| 507 |
+
def compute_loss(self, batch, outputs):
|
| 508 |
+
logits = outputs.logits if hasattr(outputs, 'logits') else outputs[0]
|
| 509 |
+
|
| 510 |
+
target_ids = batch["response_input_ids"].to(self.device) # [batch_size, target_len]
|
| 511 |
+
|
| 512 |
+
logger.debug(f"Logits shape: {logits.shape}, Target shape: {target_ids.shape}")
|
| 513 |
+
|
| 514 |
+
batch_size = target_ids.shape[0]
|
| 515 |
+
|
| 516 |
+
if target_ids.shape[1] > 1:
|
| 517 |
+
labels = target_ids.clone()
|
| 518 |
+
pad_token_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
|
| 519 |
+
labels[labels == pad_token_id] = -100
|
| 520 |
+
|
| 521 |
+
seq_len = min(logits.shape[1], labels.shape[1])
|
| 522 |
+
logits_truncated = logits[:, :seq_len, :] # [batch_size, seq_len, vocab_size]
|
| 523 |
+
labels_truncated = labels[:, :seq_len] # [batch_size, seq_len]
|
| 524 |
+
|
| 525 |
+
logger.debug(f"After truncation - Logits: {logits_truncated.shape}, Labels: {labels_truncated.shape}")
|
| 526 |
+
|
| 527 |
+
loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
|
| 528 |
+
loss = loss_fct(
|
| 529 |
+
logits_truncated.reshape(-1, logits_truncated.size(-1)),
|
| 530 |
+
labels_truncated.reshape(-1)
|
| 531 |
+
)
|
| 532 |
+
else:
|
| 533 |
+
loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
|
| 534 |
+
loss = loss_fct(logits.view(-1, logits.size(-1)), target_ids.view(-1))
|
| 535 |
+
|
| 536 |
+
return loss
|
| 537 |
+
|
| 538 |
+
def training_step(self, batch, step):
|
| 539 |
+
self.model.train()
|
| 540 |
+
|
| 541 |
+
if step < 5 and torch.cuda.is_available():
|
| 542 |
+
torch.cuda.empty_cache()
|
| 543 |
+
memory_before = torch.cuda.memory_allocated() / 1024**3
|
| 544 |
+
|
| 545 |
+
grouped_inputs = batch["grouped_inputs"].to(self.device)
|
| 546 |
+
grouped_attention_mask = batch["grouped_attention_mask"].to(self.device)
|
| 547 |
+
response_input_ids = batch["response_input_ids"].to(self.device)
|
| 548 |
+
|
| 549 |
+
batch_size = grouped_inputs.shape[0]
|
| 550 |
+
grouped_seq_len = grouped_inputs.shape[1]
|
| 551 |
+
response_seq_len = response_input_ids.shape[1]
|
| 552 |
+
|
| 553 |
+
if self.debug:
|
| 554 |
+
logger.debug(f"Batch sizes - grouped: {grouped_inputs.shape}, response: {response_input_ids.shape}")
|
| 555 |
+
|
| 556 |
+
grouped_outputs = self.model(
|
| 557 |
+
grouped_inputs=grouped_inputs,
|
| 558 |
+
attention_mask=grouped_attention_mask,
|
| 559 |
+
is_prefill=True,
|
| 560 |
+
use_cache=True,
|
| 561 |
+
return_dict=True
|
| 562 |
+
)
|
| 563 |
+
|
| 564 |
+
if response_seq_len > 1:
|
| 565 |
+
response_attention_mask = (response_input_ids != self.tokenizer.pad_token_id).long()
|
| 566 |
+
|
| 567 |
+
response_outputs = self.model(
|
| 568 |
+
input_ids=response_input_ids[:, :-1], # All but last token as input
|
| 569 |
+
attention_mask=response_attention_mask[:, :-1],
|
| 570 |
+
past_key_values=grouped_outputs.past_key_values,
|
| 571 |
+
use_cache=False,
|
| 572 |
+
return_dict=True
|
| 573 |
+
)
|
| 574 |
+
|
| 575 |
+
logits = response_outputs.logits
|
| 576 |
+
labels = response_input_ids[:, 1:] # All but first token as targets
|
| 577 |
+
|
| 578 |
+
labels = labels.clone()
|
| 579 |
+
pad_token_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
|
| 580 |
+
labels[labels == pad_token_id] = -100
|
| 581 |
+
|
| 582 |
+
loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
|
| 583 |
+
loss = loss_fct(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))
|
| 584 |
+
|
| 585 |
+
else:
|
| 586 |
+
loss = torch.tensor(0.0, requires_grad=True, device=self.device)
|
| 587 |
+
|
| 588 |
+
if step < 5 and torch.cuda.is_available():
|
| 589 |
+
memory_after = torch.cuda.memory_allocated() / 1024**3
|
| 590 |
+
memory_peak = torch.cuda.max_memory_allocated() / 1024**3
|
| 591 |
+
logger.info(f"Step {step} Memory: {memory_before:.2f}GB β {memory_after:.2f}GB (Peak: {memory_peak:.2f}GB)")
|
| 592 |
+
|
| 593 |
+
if memory_peak > 20.0: # 20GB threshold for L4
|
| 594 |
+
logger.warning("High memory usage detected! Consider reducing batch_size")
|
| 595 |
+
|
| 596 |
+
class MockOutputs:
|
| 597 |
+
def __init__(self, loss, logits):
|
| 598 |
+
self.loss = loss
|
| 599 |
+
self.logits = logits
|
| 600 |
+
|
| 601 |
+
outputs = MockOutputs(loss, response_outputs.logits if 'response_outputs' in locals() else grouped_outputs.logits)
|
| 602 |
+
|
| 603 |
+
return loss, outputs
|
| 604 |
+
|
| 605 |
+
def validation_step(self, batch):
|
| 606 |
+
"""Single validation step."""
|
| 607 |
+
self.model.eval()
|
| 608 |
+
|
| 609 |
+
with torch.no_grad():
|
| 610 |
+
grouped_inputs = batch["grouped_inputs"].to(self.device)
|
| 611 |
+
grouped_attention_mask = batch["grouped_attention_mask"].to(self.device)
|
| 612 |
+
response_input_ids = batch["response_input_ids"].to(self.device)
|
| 613 |
+
|
| 614 |
+
grouped_outputs = self.model(
|
| 615 |
+
grouped_inputs=grouped_inputs,
|
| 616 |
+
attention_mask=grouped_attention_mask,
|
| 617 |
+
is_prefill=True,
|
| 618 |
+
use_cache=True,
|
| 619 |
+
return_dict=True
|
| 620 |
+
)
|
| 621 |
+
|
| 622 |
+
if response_input_ids.shape[1] > 1:
|
| 623 |
+
response_attention_mask = (response_input_ids != self.tokenizer.pad_token_id).long()
|
| 624 |
+
|
| 625 |
+
response_outputs = self.model(
|
| 626 |
+
input_ids=response_input_ids[:, :-1],
|
| 627 |
+
attention_mask=response_attention_mask[:, :-1],
|
| 628 |
+
past_key_values=grouped_outputs.past_key_values,
|
| 629 |
+
use_cache=False,
|
| 630 |
+
return_dict=True
|
| 631 |
+
)
|
| 632 |
+
|
| 633 |
+
logits = response_outputs.logits
|
| 634 |
+
labels = response_input_ids[:, 1:]
|
| 635 |
+
|
| 636 |
+
labels = labels.clone()
|
| 637 |
+
pad_token_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
|
| 638 |
+
labels[labels == pad_token_id] = -100
|
| 639 |
+
|
| 640 |
+
loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
|
| 641 |
+
loss = loss_fct(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))
|
| 642 |
+
else:
|
| 643 |
+
loss = torch.tensor(0.0, device=self.device)
|
| 644 |
+
|
| 645 |
+
return loss.item()
|
| 646 |
+
|
| 647 |
+
def save_epoch_checkpoint(self, epoch: int, global_step: int, is_best: bool = False):
|
| 648 |
+
checkpoint_name = f"epoch_{epoch}"
|
| 649 |
+
if is_best:
|
| 650 |
+
checkpoint_name += "_best"
|
| 651 |
+
|
| 652 |
+
checkpoint_dir = self.output_dir / checkpoint_name
|
| 653 |
+
checkpoint_dir.mkdir(exist_ok=True)
|
| 654 |
+
|
| 655 |
+
torch.save(self.model.state_dict(), checkpoint_dir / "pytorch_model.bin")
|
| 656 |
+
|
| 657 |
+
self.model.config.save_pretrained(checkpoint_dir)
|
| 658 |
+
|
| 659 |
+
self.tokenizer.save_pretrained(checkpoint_dir)
|
| 660 |
+
|
| 661 |
+
metadata = {
|
| 662 |
+
"epoch": epoch,
|
| 663 |
+
"global_step": global_step,
|
| 664 |
+
"model_name": self.model_name,
|
| 665 |
+
"learning_rate": self.learning_rate,
|
| 666 |
+
"batch_size": self.batch_size,
|
| 667 |
+
"is_best": is_best,
|
| 668 |
+
"model_class": "CustomQwen3ForCausalLM"
|
| 669 |
+
}
|
| 670 |
+
|
| 671 |
+
with open(checkpoint_dir / "epoch_metadata.json", 'w') as f:
|
| 672 |
+
json.dump(metadata, f, indent=2)
|
| 673 |
+
|
| 674 |
+
logger.info(f"Saved epoch checkpoint: {checkpoint_dir}")
|
| 675 |
+
return checkpoint_dir
|
| 676 |
+
|
| 677 |
+
def train(self):
|
| 678 |
+
logger.info("Starting training...")
|
| 679 |
+
|
| 680 |
+
train_loader = DataLoader(
|
| 681 |
+
self.train_dataset,
|
| 682 |
+
batch_size=self.batch_size,
|
| 683 |
+
shuffle=True,
|
| 684 |
+
collate_fn=lambda batch: collate_fn(batch, self.tokenizer),
|
| 685 |
+
num_workers=0 # Avoid multiprocessing issues with custom collate_fn
|
| 686 |
+
)
|
| 687 |
+
|
| 688 |
+
val_loader = DataLoader(
|
| 689 |
+
self.val_dataset,
|
| 690 |
+
batch_size=self.batch_size,
|
| 691 |
+
shuffle=False,
|
| 692 |
+
collate_fn=lambda batch: collate_fn(batch, self.tokenizer),
|
| 693 |
+
num_workers=0
|
| 694 |
+
)
|
| 695 |
+
|
| 696 |
+
optimizer = AdamW(
|
| 697 |
+
[p for p in self.model.parameters() if p.requires_grad],
|
| 698 |
+
lr=self.learning_rate,
|
| 699 |
+
weight_decay=0.01
|
| 700 |
+
)
|
| 701 |
+
|
| 702 |
+
total_steps = len(train_loader) * self.num_epochs
|
| 703 |
+
scheduler = get_linear_schedule_with_warmup(
|
| 704 |
+
optimizer,
|
| 705 |
+
num_warmup_steps=self.warmup_steps,
|
| 706 |
+
num_training_steps=total_steps
|
| 707 |
+
)
|
| 708 |
+
|
| 709 |
+
start_epoch = 0
|
| 710 |
+
global_step = 0
|
| 711 |
+
best_val_loss = float('inf')
|
| 712 |
+
|
| 713 |
+
if self.resume_training:
|
| 714 |
+
checkpoint_info = self.training_state.get_latest_checkpoint()
|
| 715 |
+
if checkpoint_info is not None:
|
| 716 |
+
checkpoint_path, state = checkpoint_info
|
| 717 |
+
|
| 718 |
+
self.load_checkpoint(checkpoint_path)
|
| 719 |
+
|
| 720 |
+
start_epoch = state["completed_epochs"]
|
| 721 |
+
global_step = state["global_step"]
|
| 722 |
+
best_val_loss = state["best_val_loss"]
|
| 723 |
+
|
| 724 |
+
if "optimizer_state" in state and state["optimizer_state"]:
|
| 725 |
+
try:
|
| 726 |
+
optimizer.load_state_dict(state["optimizer_state"])
|
| 727 |
+
except Exception as e:
|
| 728 |
+
logger.warning(f"Failed to load optimizer state: {e}")
|
| 729 |
+
|
| 730 |
+
if "scheduler_state" in state and state["scheduler_state"]:
|
| 731 |
+
try:
|
| 732 |
+
scheduler.load_state_dict(state["scheduler_state"])
|
| 733 |
+
except Exception as e:
|
| 734 |
+
logger.warning(f"Failed to load scheduler state: {e}")
|
| 735 |
+
|
| 736 |
+
logger.info(f"Resumed training from epoch {start_epoch + 1}")
|
| 737 |
+
|
| 738 |
+
for epoch in range(start_epoch, self.num_epochs):
|
| 739 |
+
logger.info(f"Epoch {epoch + 1}/{self.num_epochs}")
|
| 740 |
+
|
| 741 |
+
epoch_train_loss = 0
|
| 742 |
+
train_steps = 0
|
| 743 |
+
|
| 744 |
+
progress_bar = tqdm(train_loader, desc=f"Training Epoch {epoch + 1}")
|
| 745 |
+
|
| 746 |
+
for batch_idx, batch in enumerate(progress_bar):
|
| 747 |
+
try:
|
| 748 |
+
loss, outputs = self.training_step(batch, global_step)
|
| 749 |
+
|
| 750 |
+
loss.backward()
|
| 751 |
+
|
| 752 |
+
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
|
| 753 |
+
|
| 754 |
+
optimizer.step()
|
| 755 |
+
scheduler.step()
|
| 756 |
+
optimizer.zero_grad()
|
| 757 |
+
|
| 758 |
+
epoch_train_loss += loss.item()
|
| 759 |
+
train_steps += 1
|
| 760 |
+
global_step += 1
|
| 761 |
+
|
| 762 |
+
progress_bar.set_postfix({
|
| 763 |
+
'loss': f'{loss.item():.4f}',
|
| 764 |
+
'lr': f'{scheduler.get_last_lr()[0]:.2e}'
|
| 765 |
+
})
|
| 766 |
+
|
| 767 |
+
if global_step % self.logging_steps == 0:
|
| 768 |
+
avg_loss = epoch_train_loss / train_steps
|
| 769 |
+
logger.info(f"Step {global_step}: train_loss={avg_loss:.4f}, lr={scheduler.get_last_lr()[0]:.2e}")
|
| 770 |
+
|
| 771 |
+
if global_step % self.eval_steps == 0:
|
| 772 |
+
val_loss = self.validate(val_loader)
|
| 773 |
+
logger.info(f"Step {global_step}: val_loss={val_loss:.4f}")
|
| 774 |
+
|
| 775 |
+
if val_loss < best_val_loss:
|
| 776 |
+
best_val_loss = val_loss
|
| 777 |
+
best_checkpoint = self.save_epoch_checkpoint(epoch, global_step, is_best=True)
|
| 778 |
+
logger.info(f"New best validation loss: {val_loss:.4f}")
|
| 779 |
+
|
| 780 |
+
except Exception as e:
|
| 781 |
+
logger.error(f"Error in training step {global_step}: {e}")
|
| 782 |
+
continue
|
| 783 |
+
|
| 784 |
+
val_loss = self.validate(val_loader)
|
| 785 |
+
avg_train_loss = epoch_train_loss / train_steps if train_steps > 0 else 0
|
| 786 |
+
|
| 787 |
+
logger.info(f"Epoch {epoch + 1} completed:")
|
| 788 |
+
logger.info(f" Average train loss: {avg_train_loss:.4f}")
|
| 789 |
+
logger.info(f" Validation loss: {val_loss:.4f}")
|
| 790 |
+
|
| 791 |
+
is_best = val_loss < best_val_loss
|
| 792 |
+
if is_best:
|
| 793 |
+
best_val_loss = val_loss
|
| 794 |
+
|
| 795 |
+
checkpoint_dir = self.save_epoch_checkpoint(epoch, global_step, is_best=is_best)
|
| 796 |
+
|
| 797 |
+
self.training_state.save_state(
|
| 798 |
+
epoch=epoch,
|
| 799 |
+
global_step=global_step,
|
| 800 |
+
best_val_loss=best_val_loss,
|
| 801 |
+
optimizer_state=optimizer.state_dict(),
|
| 802 |
+
scheduler_state=scheduler.state_dict()
|
| 803 |
+
)
|
| 804 |
+
|
| 805 |
+
logger.info(f"Epoch {epoch + 1} checkpoint and state saved")
|
| 806 |
+
|
| 807 |
+
logger.info(f"Training completed! Best validation loss: {best_val_loss:.4f}")
|
| 808 |
+
|
| 809 |
+
final_checkpoint = self.save_epoch_checkpoint(self.num_epochs - 1, global_step, is_best=False)
|
| 810 |
+
logger.info(f"Final checkpoint saved: {final_checkpoint}")
|
| 811 |
+
|
| 812 |
+
def validate(self, val_loader):
|
| 813 |
+
self.model.eval()
|
| 814 |
+
total_loss = 0
|
| 815 |
+
num_batches = 0
|
| 816 |
+
|
| 817 |
+
with torch.no_grad():
|
| 818 |
+
for batch in tqdm(val_loader, desc="Validation"):
|
| 819 |
+
try:
|
| 820 |
+
loss = self.validation_step(batch)
|
| 821 |
+
total_loss += loss
|
| 822 |
+
num_batches += 1
|
| 823 |
+
except Exception as e:
|
| 824 |
+
logger.warning(f"Error in validation step: {e}")
|
| 825 |
+
continue
|
| 826 |
+
|
| 827 |
+
avg_loss = total_loss / num_batches if num_batches > 0 else float('inf')
|
| 828 |
+
self.model.train() # Set back to training mode
|
| 829 |
+
return avg_loss
|
| 830 |
+
|
| 831 |
+
def run(self):
|
| 832 |
+
try:
|
| 833 |
+
self.load_model_and_tokenizer()
|
| 834 |
+
|
| 835 |
+
self.load_dataset(
|
| 836 |
+
chunk_size=self.chunk_size,
|
| 837 |
+
max_samples=self.max_samples
|
| 838 |
+
)
|
| 839 |
+
|
| 840 |
+
self.train()
|
| 841 |
+
|
| 842 |
+
logger.info("Training pipeline completed successfully!")
|
| 843 |
+
|
| 844 |
+
self.cleanup_datasets()
|
| 845 |
+
|
| 846 |
+
except Exception as e:
|
| 847 |
+
logger.error(f"Training pipeline failed: {e}")
|
| 848 |
+
import traceback
|
| 849 |
+
logger.error(traceback.format_exc())
|
| 850 |
+
|
| 851 |
+
try:
|
| 852 |
+
self.cleanup_datasets()
|
| 853 |
+
except:
|
| 854 |
+
pass
|
| 855 |
+
|
| 856 |
+
raise
|
| 857 |
+
|
| 858 |
+
def load_trained_model(checkpoint_path: str, model_name: str = "Qwen/Qwen3-0.6B"):
|
| 859 |
+
logger.info(f"Loading trained model from {checkpoint_path}")
|
| 860 |
+
|
| 861 |
+
tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
|
| 862 |
+
|
| 863 |
+
config = Qwen3Config.from_pretrained(checkpoint_path)
|
| 864 |
+
|
| 865 |
+
model = CustomQwen3ForCausalLM(config)
|
| 866 |
+
|
| 867 |
+
model_path = Path(checkpoint_path) / "pytorch_model.bin"
|
| 868 |
+
if not model_path.exists():
|
| 869 |
+
model_path = Path(checkpoint_path) / "model.safetensors"
|
| 870 |
+
|
| 871 |
+
if not model_path.exists():
|
| 872 |
+
raise FileNotFoundError(f"No model weights found in {checkpoint_path}")
|
| 873 |
+
|
| 874 |
+
state_dict = torch.load(model_path, map_location="cpu")
|
| 875 |
+
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
|
| 876 |
+
|
| 877 |
+
if missing_keys:
|
| 878 |
+
logger.warning(f"Missing keys when loading model: {missing_keys}")
|
| 879 |
+
if unexpected_keys:
|
| 880 |
+
logger.warning(f"Unexpected keys when loading model: {unexpected_keys}")
|
| 881 |
+
|
| 882 |
+
model = model.eval().to(torch.float32)
|
| 883 |
+
|
| 884 |
+
return model, tokenizer
|
| 885 |
+
|
| 886 |
+
|
| 887 |
+
def generate_with_grouped_input(
|
| 888 |
+
model,
|
| 889 |
+
tokenizer,
|
| 890 |
+
grouped_input: torch.Tensor,
|
| 891 |
+
max_length: int = 512,
|
| 892 |
+
temperature: float = 0.7,
|
| 893 |
+
do_sample: bool = True
|
| 894 |
+
):
|
| 895 |
+
device = model.device
|
| 896 |
+
model_dtype = next(model.parameters()).dtype
|
| 897 |
+
|
| 898 |
+
grouped_input = grouped_input.to(device=device, dtype=model_dtype)
|
| 899 |
+
|
| 900 |
+
if grouped_input.ndim == 2:
|
| 901 |
+
grouped_input = grouped_input.unsqueeze(0) # Add batch dimension
|
| 902 |
+
|
| 903 |
+
logger.debug(f"Grouped input shape: {grouped_input.shape}, dtype: {grouped_input.dtype}")
|
| 904 |
+
logger.debug(f"Model dtype: {model_dtype}, device: {device}")
|
| 905 |
+
|
| 906 |
+
with torch.no_grad():
|
| 907 |
+
try:
|
| 908 |
+
outputs = model(
|
| 909 |
+
grouped_inputs=grouped_input,
|
| 910 |
+
is_prefill=True,
|
| 911 |
+
use_cache=True,
|
| 912 |
+
return_dict=True
|
| 913 |
+
)
|
| 914 |
+
except Exception as e:
|
| 915 |
+
logger.error(f"Error in prefill phase: {e}")
|
| 916 |
+
raise
|
| 917 |
+
|
| 918 |
+
if hasattr(outputs, 'logits') and outputs.logits is not None:
|
| 919 |
+
next_token_logits = outputs.logits[:, -1, :]
|
| 920 |
+
elif hasattr(outputs, 'hidden_states') and outputs.hidden_states is not None:
|
| 921 |
+
last_hidden_state = outputs.hidden_states[-1] if isinstance(outputs.hidden_states, (list, tuple)) else outputs.hidden_states
|
| 922 |
+
next_token_logits = model.lm_head(last_hidden_state[:, -1, :])
|
| 923 |
+
else:
|
| 924 |
+
raise RuntimeError("Could not extract logits from model output")
|
| 925 |
+
|
| 926 |
+
if do_sample:
|
| 927 |
+
next_token_logits = next_token_logits / temperature
|
| 928 |
+
probs = F.softmax(next_token_logits, dim=-1)
|
| 929 |
+
next_token = torch.multinomial(probs, num_samples=1)
|
| 930 |
+
else:
|
| 931 |
+
next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
|
| 932 |
+
|
| 933 |
+
generated_ids = next_token
|
| 934 |
+
past_key_values = getattr(outputs, 'past_key_values', None)
|
| 935 |
+
|
| 936 |
+
for step in range(max_length - 1):
|
| 937 |
+
with torch.no_grad():
|
| 938 |
+
try:
|
| 939 |
+
outputs = model(
|
| 940 |
+
input_ids=next_token,
|
| 941 |
+
past_key_values=past_key_values,
|
| 942 |
+
use_cache=True,
|
| 943 |
+
return_dict=True
|
| 944 |
+
)
|
| 945 |
+
except Exception as e:
|
| 946 |
+
logger.error(f"Error in generation step {step}: {e}")
|
| 947 |
+
break
|
| 948 |
+
|
| 949 |
+
if hasattr(outputs, 'logits'):
|
| 950 |
+
next_token_logits = outputs.logits[:, -1, :]
|
| 951 |
+
else:
|
| 952 |
+
logger.warning("No logits in generation output, stopping generation")
|
| 953 |
+
break
|
| 954 |
+
|
| 955 |
+
if do_sample:
|
| 956 |
+
next_token_logits = next_token_logits / temperature
|
| 957 |
+
probs = F.softmax(next_token_logits, dim=-1)
|
| 958 |
+
next_token = torch.multinomial(probs, num_samples=1)
|
| 959 |
+
else:
|
| 960 |
+
next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
|
| 961 |
+
|
| 962 |
+
generated_ids = torch.cat([generated_ids, next_token], dim=1)
|
| 963 |
+
past_key_values = getattr(outputs, 'past_key_values', None)
|
| 964 |
+
|
| 965 |
+
if next_token.item() == tokenizer.eos_token_id:
|
| 966 |
+
break
|
| 967 |
+
|
| 968 |
+
generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
|
| 969 |
+
return generated_text
|
| 970 |
+
|
| 971 |
+
def main():
|
| 972 |
+
config = {
|
| 973 |
+
"model_name": "Qwen/Qwen3-0.6B",
|
| 974 |
+
"dataset_path": "./processed_qwen3_dataset/processed_dataset.pkl",
|
| 975 |
+
"output_dir": "./grouped_qwen3_checkpoint",
|
| 976 |
+
"batch_size": 12, # Optimized for L4 24GB VRAM
|
| 977 |
+
"learning_rate": 5e-4,
|
| 978 |
+
"num_epochs": 3,
|
| 979 |
+
"warmup_steps": 500, # Increased for larger batch
|
| 980 |
+
"max_grad_norm": 1.0,
|
| 981 |
+
"save_steps": 1000, # Less frequent saves due to larger batches
|
| 982 |
+
"eval_steps": 1000, # Less frequent evaluation
|
| 983 |
+
"logging_steps": 100,
|
| 984 |
+
"resume_training": True,
|
| 985 |
+
"debug": False, # Disable debug for performance
|
| 986 |
+
# Streaming parameters
|
| 987 |
+
"chunk_size": 2000, # Load 2000 samples per chunk
|
| 988 |
+
"max_samples": None, # Use full dataset (set to smaller number for testing)
|
| 989 |
+
}
|
| 990 |
+
|
| 991 |
+
logger.info("="*60)
|
| 992 |
+
logger.info("GROUPED QWEN3 TRAINING CONFIGURATION (STREAMING)")
|
| 993 |
+
logger.info("="*60)
|
| 994 |
+
for key, value in config.items():
|
| 995 |
+
logger.info(f"{key}: {value}")
|
| 996 |
+
logger.info("="*60)
|
| 997 |
+
|
| 998 |
+
if torch.cuda.is_available():
|
| 999 |
+
logger.info(f"GPU: {torch.cuda.get_device_name()}")
|
| 1000 |
+
logger.info(f"VRAM Total: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
|
| 1001 |
+
|
| 1002 |
+
import psutil
|
| 1003 |
+
ram_usage = psutil.virtual_memory()
|
| 1004 |
+
logger.info(f"System RAM: {ram_usage.used / 1024**3:.1f} GB / {ram_usage.total / 1024**3:.1f} GB ({ram_usage.percent:.1f}%)")
|
| 1005 |
+
|
| 1006 |
+
trainer = GroupedTrainer(**config)
|
| 1007 |
+
trainer.run()
|
| 1008 |
+
|
| 1009 |
+
|
| 1010 |
+
def inference_by_id(sample_id: int, checkpoint_path: str = "./grouped_qwen3_checkpoint/epoch_2_best",
|
| 1011 |
+
dataset_path: str = "./processed_qwen3_dataset/processed_dataset.pkl",
|
| 1012 |
+
max_length: int = 512, temperature: float = 0.7, do_sample: bool = True):
|
| 1013 |
+
"""Run inference on a specific sample ID from the dataset."""
|
| 1014 |
+
logger.info(f"Running inference on sample ID: {sample_id}")
|
| 1015 |
+
|
| 1016 |
+
try:
|
| 1017 |
+
model, tokenizer = load_trained_model(checkpoint_path)
|
| 1018 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 1019 |
+
model = model.to(device)
|
| 1020 |
+
logger.info(f"Model loaded from {checkpoint_path}")
|
| 1021 |
+
except Exception as e:
|
| 1022 |
+
logger.error(f"Failed to load model: {e}")
|
| 1023 |
+
return None
|
| 1024 |
+
|
| 1025 |
+
try:
|
| 1026 |
+
logger.info(f"Loading sample {sample_id} from dataset...")
|
| 1027 |
+
with open(dataset_path, 'rb') as f:
|
| 1028 |
+
processed_data = pickle.load(f)
|
| 1029 |
+
|
| 1030 |
+
if sample_id >= len(processed_data):
|
| 1031 |
+
logger.error(f"Sample ID {sample_id} is out of range. Dataset has {len(processed_data)} samples.")
|
| 1032 |
+
return None
|
| 1033 |
+
|
| 1034 |
+
sample = processed_data[sample_id]
|
| 1035 |
+
|
| 1036 |
+
if sample.get("error", False):
|
| 1037 |
+
logger.error(f"Sample {sample_id} has errors and cannot be used for inference.")
|
| 1038 |
+
return None
|
| 1039 |
+
|
| 1040 |
+
except Exception as e:
|
| 1041 |
+
logger.error(f"Failed to load dataset: {e}")
|
| 1042 |
+
return None
|
| 1043 |
+
|
| 1044 |
+
grouped_embeds_raw = sample["inputs_embeds"]
|
| 1045 |
+
if isinstance(grouped_embeds_raw, torch.Tensor):
|
| 1046 |
+
grouped_input = grouped_embeds_raw.detach().clone().float()
|
| 1047 |
+
else:
|
| 1048 |
+
grouped_input = torch.tensor(grouped_embeds_raw, dtype=torch.float32)
|
| 1049 |
+
|
| 1050 |
+
original_input = sample["input_text"]
|
| 1051 |
+
expected_response = sample["response"]
|
| 1052 |
+
|
| 1053 |
+
print("\n" + "="*80)
|
| 1054 |
+
print(f"INFERENCE ON SAMPLE ID: {sample_id}")
|
| 1055 |
+
print("="*80)
|
| 1056 |
+
print(f"π ORIGINAL REQUEST:")
|
| 1057 |
+
print(f"{original_input}")
|
| 1058 |
+
print("\n" + "-"*80)
|
| 1059 |
+
print(f"π― EXPECTED RESPONSE:")
|
| 1060 |
+
print(f"{expected_response}")
|
| 1061 |
+
print("\n" + "-"*80)
|
| 1062 |
+
print(f"π€ MODEL GENERATED RESPONSE:")
|
| 1063 |
+
|
| 1064 |
+
try:
|
| 1065 |
+
generated_text = generate_with_grouped_input(
|
| 1066 |
+
model=model,
|
| 1067 |
+
tokenizer=tokenizer,
|
| 1068 |
+
grouped_input=grouped_input,
|
| 1069 |
+
max_length=max_length,
|
| 1070 |
+
temperature=temperature,
|
| 1071 |
+
do_sample=do_sample
|
| 1072 |
+
)
|
| 1073 |
+
|
| 1074 |
+
print(f"{generated_text}")
|
| 1075 |
+
print("\n" + "="*80)
|
| 1076 |
+
|
| 1077 |
+
expected_words = expected_response.split()
|
| 1078 |
+
generated_words = generated_text.split()
|
| 1079 |
+
|
| 1080 |
+
print(f"π METRICS:")
|
| 1081 |
+
print(f"Expected length: {len(expected_words)} words")
|
| 1082 |
+
print(f"Generated length: {len(generated_words)} words")
|
| 1083 |
+
print(f"Temperature: {temperature}")
|
| 1084 |
+
print(f"Max length: {max_length}")
|
| 1085 |
+
print("="*80)
|
| 1086 |
+
|
| 1087 |
+
return {
|
| 1088 |
+
"sample_id": sample_id,
|
| 1089 |
+
"original_input": original_input,
|
| 1090 |
+
"expected_response": expected_response,
|
| 1091 |
+
"generated_response": generated_text,
|
| 1092 |
+
"expected_length": len(expected_words),
|
| 1093 |
+
"generated_length": len(generated_words)
|
| 1094 |
+
}
|
| 1095 |
+
|
| 1096 |
+
except Exception as e:
|
| 1097 |
+
logger.error(f"Failed to generate response: {e}")
|
| 1098 |
+
print(f"β GENERATION FAILED: {e}")
|
| 1099 |
+
print("="*80)
|
| 1100 |
+
return None
|
| 1101 |
+
|
| 1102 |
+
|
| 1103 |
+
def batch_inference(sample_ids: List[int], checkpoint_path: str = "./grouped_qwen3_checkpoint/epoch_2_best",
|
| 1104 |
+
dataset_path: str = "./processed_qwen3_dataset/processed_dataset.pkl",
|
| 1105 |
+
max_length: int = 512, temperature: float = 0.7, do_sample: bool = True):
|
| 1106 |
+
"""Run inference on multiple sample IDs."""
|
| 1107 |
+
logger.info(f"Running batch inference on {len(sample_ids)} samples")
|
| 1108 |
+
|
| 1109 |
+
results = []
|
| 1110 |
+
for sample_id in sample_ids:
|
| 1111 |
+
result = inference_by_id(
|
| 1112 |
+
sample_id=sample_id,
|
| 1113 |
+
checkpoint_path=checkpoint_path,
|
| 1114 |
+
dataset_path=dataset_path,
|
| 1115 |
+
max_length=max_length,
|
| 1116 |
+
temperature=temperature,
|
| 1117 |
+
do_sample=do_sample
|
| 1118 |
+
)
|
| 1119 |
+
if result:
|
| 1120 |
+
results.append(result)
|
| 1121 |
+
|
| 1122 |
+
print("\n" + "π " + "-"*78 + " π\n") # Separator between samples
|
| 1123 |
+
|
| 1124 |
+
print("\n" + "="*80)
|
| 1125 |
+
print(f"π BATCH INFERENCE SUMMARY")
|
| 1126 |
+
print("="*80)
|
| 1127 |
+
print(f"Total samples processed: {len(results)}")
|
| 1128 |
+
if results:
|
| 1129 |
+
avg_expected_len = sum(r["expected_length"] for r in results) / len(results)
|
| 1130 |
+
avg_generated_len = sum(r["generated_length"] for r in results) / len(results)
|
| 1131 |
+
print(f"Average expected length: {avg_expected_len:.1f} words")
|
| 1132 |
+
print(f"Average generated length: {avg_generated_len:.1f} words")
|
| 1133 |
+
print("="*80)
|
| 1134 |
+
|
| 1135 |
+
return results
|
| 1136 |
+
|
| 1137 |
+
|
| 1138 |
+
def random_inference(num_samples: int = 3, checkpoint_path: str = "./grouped_qwen3_checkpoint/epoch_2_best",
|
| 1139 |
+
dataset_path: str = "./processed_qwen3_dataset/processed_dataset.pkl",
|
| 1140 |
+
max_length: int = 512, temperature: float = 0.7, do_sample: bool = True):
|
| 1141 |
+
"""Run inference on random samples from the dataset."""
|
| 1142 |
+
import random
|
| 1143 |
+
|
| 1144 |
+
try:
|
| 1145 |
+
with open(dataset_path, 'rb') as f:
|
| 1146 |
+
processed_data = pickle.load(f)
|
| 1147 |
+
|
| 1148 |
+
# Find valid samples
|
| 1149 |
+
valid_indices = [i for i, item in enumerate(processed_data) if not item.get("error", False)]
|
| 1150 |
+
|
| 1151 |
+
if len(valid_indices) < num_samples:
|
| 1152 |
+
logger.warning(f"Only {len(valid_indices)} valid samples available, using all of them")
|
| 1153 |
+
num_samples = len(valid_indices)
|
| 1154 |
+
|
| 1155 |
+
# Select random samples
|
| 1156 |
+
random_ids = random.sample(valid_indices, num_samples)
|
| 1157 |
+
|
| 1158 |
+
logger.info(f"Selected random sample IDs: {random_ids}")
|
| 1159 |
+
|
| 1160 |
+
# Run batch inference
|
| 1161 |
+
return batch_inference(
|
| 1162 |
+
sample_ids=random_ids,
|
| 1163 |
+
checkpoint_path=checkpoint_path,
|
| 1164 |
+
dataset_path=dataset_path,
|
| 1165 |
+
max_length=max_length,
|
| 1166 |
+
temperature=temperature,
|
| 1167 |
+
do_sample=do_sample
|
| 1168 |
+
)
|
| 1169 |
+
|
| 1170 |
+
except Exception as e:
|
| 1171 |
+
logger.error(f"Failed to load dataset for random sampling: {e}")
|
| 1172 |
+
return None
|
| 1173 |
+
|
| 1174 |
+
|
| 1175 |
+
def interactive_inference(checkpoint_path: str = "./grouped_qwen3_checkpoint/epoch_2_best",
                          dataset_path: str = "./processed_qwen3_dataset/processed_dataset.pkl"):
    """Interactive inference mode where the user can input sample IDs."""
    print("\n" + "="*80)
    print("INTERACTIVE INFERENCE MODE")
    print("="*80)
    print("Commands:")
    print("  <number>    - Run inference on sample ID")
    print("  random <n>  - Run inference on n random samples (default: 3)")
    print("  batch <ids> - Run inference on multiple IDs (e.g., 'batch 1,5,10')")
    print("  quit        - Exit")
    print("="*80)

    while True:
        try:
            user_input = input("\nEnter command: ").strip().lower()

            if user_input in ['quit', 'exit', 'q']:
                print("Goodbye!")
                break
            elif user_input.startswith('random'):
                parts = user_input.split()
                num_samples = int(parts[1]) if len(parts) > 1 else 3
                random_inference(num_samples=num_samples, checkpoint_path=checkpoint_path, dataset_path=dataset_path)
            elif user_input.startswith('batch'):
                parts = user_input.split(maxsplit=1)
                if len(parts) > 1:
                    ids_str = parts[1]
                    sample_ids = [int(x.strip()) for x in ids_str.split(',')]
                    batch_inference(sample_ids=sample_ids, checkpoint_path=checkpoint_path, dataset_path=dataset_path)
                else:
                    print("Please provide sample IDs: batch 1,5,10")
            elif user_input.isdigit():
                sample_id = int(user_input)
                inference_by_id(sample_id=sample_id, checkpoint_path=checkpoint_path, dataset_path=dataset_path)
            else:
                print("Invalid command. Try a number, 'random', 'batch', or 'quit'")

        except ValueError:
            print("Invalid input. Please enter a valid number or command.")
        except KeyboardInterrupt:
            print("\nGoodbye!")
            break
        except Exception as e:
            print(f"Error: {e}")

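# Example interactive session (illustrative, not part of the original file),
# matching the command parsing above:
#
#   Enter command: 42           -> runs inference_by_id on sample 42
#   Enter command: random 2     -> runs random_inference on 2 random samples
#   Enter command: batch 1,5,10 -> runs batch_inference on samples 1, 5 and 10
#   Enter command: quit         -> exits the loop
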
def test_inference():
    """Run batch inference on a small set of predefined sample IDs as a quick smoke test."""
    logger.info("Running inference tests...")

    test_ids = [0, 1, 2, 100, 500]  # Mix of early and later samples

    print("\nTESTING INFERENCE ON PREDEFINED SAMPLES")
    results = batch_inference(
        sample_ids=test_ids,
        max_length=300,
        temperature=0.7,
        do_sample=True
    )

    return results

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Grouped Qwen3 Training and Inference")
    parser.add_argument("--mode", choices=["train", "test", "inference", "interactive", "random"],
                        default="train", help="Mode to run")
    parser.add_argument("--sample_id", type=int, help="Sample ID for inference mode")
    parser.add_argument("--sample_ids", type=str, help="Comma-separated sample IDs for batch inference")
    parser.add_argument("--num_samples", type=int, default=3, help="Number of random samples for random mode")
    parser.add_argument("--checkpoint", type=str, default="./grouped_qwen3_checkpoint/epoch_2_best",
                        help="Path to model checkpoint")
    parser.add_argument("--dataset", type=str, default="./processed_qwen3_dataset/processed_dataset.pkl",
                        help="Path to dataset")
    parser.add_argument("--max_length", type=int, default=512, help="Maximum generation length")
    parser.add_argument("--temperature", type=float, default=0.7, help="Generation temperature")

    args = parser.parse_args()

    if args.mode == "train":
        main()
    elif args.mode == "test":
        test_inference()
    elif args.mode == "inference":
        if args.sample_id is not None:
            inference_by_id(
                sample_id=args.sample_id,
                checkpoint_path=args.checkpoint,
                dataset_path=args.dataset,
                max_length=args.max_length,
                temperature=args.temperature
            )
        elif args.sample_ids is not None:
            sample_ids = [int(x.strip()) for x in args.sample_ids.split(',')]
            batch_inference(
                sample_ids=sample_ids,
                checkpoint_path=args.checkpoint,
                dataset_path=args.dataset,
                max_length=args.max_length,
                temperature=args.temperature
            )
        else:
            print("Please provide --sample_id or --sample_ids for inference mode")
    elif args.mode == "interactive":
        interactive_inference(
            checkpoint_path=args.checkpoint,
            dataset_path=args.dataset
        )
    elif args.mode == "random":
        random_inference(
            num_samples=args.num_samples,
            checkpoint_path=args.checkpoint,
            dataset_path=args.dataset,
            max_length=args.max_length,
            temperature=args.temperature
        )
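
# Illustrative command-line usage (added as a sketch; it assumes this file is the
# train_custom_qwen3.py script included in this upload, and uses only the flags
# defined in the argparse block above):
#
#   python train_custom_qwen3.py --mode train
#   python train_custom_qwen3.py --mode inference --sample_id 42 --temperature 0.3
#   python train_custom_qwen3.py --mode inference --sample_ids 1,5,10
#   python train_custom_qwen3.py --mode random --num_samples 5
#   python train_custom_qwen3.py --mode interactive --checkpoint ./grouped_qwen3_checkpoint/epoch_2_best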