kkolomeitsev committed
Commit d872c55 · verified · 1 Parent(s): ca1f8dd

Upload folder using huggingface_hub

.gitattributes CHANGED
@@ -1,35 +1,36 @@
 *.7z filter=lfs diff=lfs merge=lfs -text
 *.arrow filter=lfs diff=lfs merge=lfs -text
 *.bin filter=lfs diff=lfs merge=lfs -text
 *.bz2 filter=lfs diff=lfs merge=lfs -text
 *.ckpt filter=lfs diff=lfs merge=lfs -text
 *.ftz filter=lfs diff=lfs merge=lfs -text
 *.gz filter=lfs diff=lfs merge=lfs -text
 *.h5 filter=lfs diff=lfs merge=lfs -text
 *.joblib filter=lfs diff=lfs merge=lfs -text
 *.lfs.* filter=lfs diff=lfs merge=lfs -text
 *.mlmodel filter=lfs diff=lfs merge=lfs -text
 *.model filter=lfs diff=lfs merge=lfs -text
 *.msgpack filter=lfs diff=lfs merge=lfs -text
 *.npy filter=lfs diff=lfs merge=lfs -text
 *.npz filter=lfs diff=lfs merge=lfs -text
 *.onnx filter=lfs diff=lfs merge=lfs -text
 *.ot filter=lfs diff=lfs merge=lfs -text
 *.parquet filter=lfs diff=lfs merge=lfs -text
 *.pb filter=lfs diff=lfs merge=lfs -text
 *.pickle filter=lfs diff=lfs merge=lfs -text
 *.pkl filter=lfs diff=lfs merge=lfs -text
 *.pt filter=lfs diff=lfs merge=lfs -text
 *.pth filter=lfs diff=lfs merge=lfs -text
 *.rar filter=lfs diff=lfs merge=lfs -text
 *.safetensors filter=lfs diff=lfs merge=lfs -text
 saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.tar.* filter=lfs diff=lfs merge=lfs -text
 *.tar filter=lfs diff=lfs merge=lfs -text
 *.tflite filter=lfs diff=lfs merge=lfs -text
 *.tgz filter=lfs diff=lfs merge=lfs -text
 *.wasm filter=lfs diff=lfs merge=lfs -text
 *.xz filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+ grouped_qwen3_checkpoint/epoch_2_best/tokenizer.json filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2025 Konstantin Kolomeitsev

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
README.md CHANGED
@@ -1,3 +1,192 @@
- ---
- license: mit
- ---

# Context Merging: from Tokens to Entities and Concepts

This repo contains a minimal research pipeline that compresses input context for Qwen3 by grouping dependent subtokens early, then trains a small adapter to consume the grouped embeddings.

- `prepare_dataset.py` builds a local dataset of grouped embeddings from a base Qwen3 with a custom layer 0 that performs token grouping.
- `train_custom_qwen3.py` fine-tunes a customized Qwen3 that adds a small MLP adapter for grouped inputs, while freezing all weights except layer 0.
- `inference_qwen3_merged.py` runs end-to-end inference by first grouping with the base model, then generating with the trained model that understands grouped inputs. It also reports performance metrics and estimated attention-memory savings.

---

## How it works

1. **Layer-0 grouping at prefill**
   A custom decoder layer 0 computes attention over the full token sequence, clusters adjacent tokens using lightweight heuristics plus attention relations, then averages the token vectors of each group (see the sketch after this list). The grouped result is added back to a residual projection and saved as `grouped_hidden_states`.

2. **Dataset building**
   The dataset builder swaps in the custom layer 0, feeds formatted prompts, extracts the stored `grouped_hidden_states`, and serializes them together with the target responses.

3. **Model training**
   The training model wraps Qwen3 with a **GroupedInputMLPAdapter** that processes the grouped embeddings during prefill. Only layer 0 and the adapter are trainable; embeddings, upper layers, the final norm, and the LM head are frozen. Prefill uses `grouped_inputs` as `inputs_embeds`, then generation proceeds with past key values.

4. **Inference**
   The inference runner loads two models: a grouping model with the custom layer 0, and your trained model. It reports token compression, timing, and memory usage. Savings are also estimated with a simple attention-cost proxy that scales with the square of the sequence length.
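
A minimal sketch of the grouping-and-averaging step referenced in item 1 (illustrative only; the repo's `CustomQwen3DecoderLayer` additionally folds in a summed residual and the layer norms):

```python
import torch

def merge_groups(hidden_states: torch.Tensor, groups: list) -> torch.Tensor:
    """Average per-group token vectors: [batch, seq, hidden] -> [batch, num_groups, hidden]."""
    merged = [hidden_states[:, idxs, :].mean(dim=1) for idxs in groups]
    return torch.stack(merged, dim=1)

x = torch.randn(1, 6, 1024)
print(merge_groups(x, [[0], [1, 2], [3, 4, 5]]).shape)  # torch.Size([1, 3, 1024])
```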
24
+
25
+ ---
26
+
27
+ ## Requirements
28
+
29
+ - Python packages: `torch`, `transformers`, `datasets`, `tqdm`, `psutil`. These are imported directly in the scripts.
30
+ - GPU is optional. Scripts detect CUDA and set dtype accordingly.
31
+
32
+ Install:
33
+
34
+ ```bash
35
+ pip install torch transformers datasets tqdm psutil
36
+ ```
37
+
38
+ ---
39
+
40
+ ## Repository layout
41
+
42
+ - `prepare_dataset.py` - dataset builder using custom layer 0 grouping.
43
+ - `train_custom_qwen3.py` - trainer for grouped-input Qwen3 with an MLP adapter, freezing all but layer 0.
44
+ - `inference_qwen3_merged.py` - two-stage inference runner with metrics.
45
+
46
+ ---
47
+
48
+ ## 1 Build the local dataset
49
+
50
+ Run:
51
+
52
+ ```bash
53
+ python prepare_dataset.py
54
+ ```
55
+
56
+ Key defaults inside `DatasetProcessor`:
57
+
58
+ - `model_name="Qwen/Qwen3-0.6B"`
59
+ - `dataset_name="Magpie-Align/Magpie-Qwen2.5-Pro-1M-v0.1"`
60
+ - `output_dir="./processed_dataset"`
61
+ - `batch_size=1`, `max_samples=None`, `save_frequency=1000`
62
+ Edit these in the constructor if you need to change them.
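
A hedged example of overriding those defaults from Python (the constructor arguments are the documented ones above; the `process_dataset()` call is a hypothetical entry point, check `prepare_dataset.py` for the actual method name):

```python
from prepare_dataset import DatasetProcessor

processor = DatasetProcessor(
    model_name="Qwen/Qwen3-0.6B",
    dataset_name="Magpie-Align/Magpie-Qwen2.5-Pro-1M-v0.1",
    output_dir="./processed_dataset",
    max_samples=100,          # small cap for a quick smoke test
)
processor.process_dataset()   # hypothetical method name - see prepare_dataset.py
```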

The builder formats inputs using a simple system prompt template.
It tokenizes each example, runs layer 0 once, captures `grouped_hidden_states`, and buffers the results.

**Outputs** under `output_dir`:

- `processed_dataset.pkl` - a list of samples with `inputs_embeds` (grouped), `response`, and metadata.
- Additional metadata and sample previews are written alongside for quick inspection.

---

## 2. Train the grouped-input model

Run:

```bash
python train_custom_qwen3.py --mode train
```

Training config defaults (edit in the script if needed):

- `model_name="Qwen/Qwen3-0.6B"`
- `dataset_path="./processed_qwen3_dataset/processed_dataset.pkl"`
- `output_dir="./grouped_qwen3_checkpoint"`
- `batch_size=4`, `learning_rate=5e-4`, `num_epochs=3`, `warmup_steps=100`
- Logging, eval, and checkpoint cadence are configurable.

What is trained:

- A **GroupedInputMLPAdapter** that takes the grouped embeddings and returns adapted embeddings, normalized with RMSNorm (condensed sketch below).
- Only layer 0 and this adapter are trainable; everything else is frozen.
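
A condensed sketch of the adapter, mirroring `GroupedInputMLPAdapter` as defined in `inference_qwen3_merged.py`:

```python
import torch.nn as nn
from transformers.models.qwen3.modeling_qwen3 import Qwen3RMSNorm

class GroupedInputMLPAdapter(nn.Module):
    """Residual MLP over grouped embeddings, normalized with RMSNorm."""
    def __init__(self, config):
        super().__init__()
        h = config.hidden_size
        self.grouped_processor = nn.Sequential(
            nn.Linear(h, h * 2), nn.SiLU(), nn.Dropout(0.1),
            nn.Linear(h * 2, h), nn.Dropout(0.1),
        )
        self.layer_norm = Qwen3RMSNorm(h, eps=getattr(config, "rms_norm_eps", 1e-6))

    def forward(self, grouped_embeds):
        # Residual connection keeps the adapted embeddings close to the base space.
        return self.layer_norm(grouped_embeds + self.grouped_processor(grouped_embeds))
```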

How targets are computed:

- Prefill: pass `grouped_inputs` via `inputs_embeds` with `is_prefill=True`.
- Continuation: feed the target response tokens while reusing `past_key_values` (see the schematic after this list).
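
A schematic of that two-step forward (the real loop lives in `train_custom_qwen3.py`; this assumes the `grouped_inputs` / `is_prefill` forward signature that `CustomQwen3ForCausalLM` exposes in `inference_qwen3_merged.py`):

```python
import torch
import torch.nn.functional as F

def training_step(model, grouped_inputs, response_ids):
    # 1) Prefill on the grouped embeddings, keeping the KV cache.
    prefill = model(grouped_inputs=grouped_inputs, is_prefill=True,
                    use_cache=True, return_dict=True)
    # 2) Feed the target response tokens on top of that cache (teacher forcing).
    out = model(input_ids=response_ids[:, :-1],
                past_key_values=prefill.past_key_values,
                use_cache=True, return_dict=True)
    # Cross-entropy of each predicted position against the next response token.
    logits = torch.cat([prefill.logits[:, -1:, :], out.logits], dim=1)
    return F.cross_entropy(logits.reshape(-1, logits.size(-1)), response_ids.reshape(-1))
```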

Checkpoints contain the model weights, config, and tokenizer in each epoch folder.

---

## 3. Run inference

### Option A - standalone runner

Quick start:

```bash
python inference_qwen3_merged.py \
  --checkpoint ./grouped_qwen3_checkpoint/epoch_2_best \
  --grouping_model Qwen/Qwen3-0.6B \
  --instruction "Explain attention like I am in 9th grade" \
  --max_length 256 \
  --temperature 0.7 \
  --device cuda
```

CLI options: `--checkpoint`, `--grouping_model`, `--instruction`, `--max_length`, `--temperature`, `--no_sample` for greedy decoding, and `--device` for cuda or cpu.

What it does:

- Loads a grouping model with the custom layer 0 and a trained inference model.
- Phase 1 groups the tokens and reports compression. Phase 2 generates with the trained model.
- Reports compression ratio, memory reduction, total time, and tokens per second.

### Option B - use the training script utilities

The trainer exposes helper functions for loading a trained model and running generation with grouped inputs. See `load_trained_model` and `generate_with_grouped_input` in the training script if you prefer a programmatic flow; an equivalent class-based route through the standalone runner is shown below.
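
The class-based route, grounded in `Qwen3GroupedInference` from `inference_qwen3_merged.py`:

```python
from inference_qwen3_merged import Qwen3GroupedInference

runner = Qwen3GroupedInference(
    checkpoint_path="./grouped_qwen3_checkpoint/epoch_2_best",
    grouping_model_name="Qwen/Qwen3-0.6B",
)
result = runner.inference("Explain attention like I am in 9th grade",
                          max_length=256, temperature=0.7)
print(result["response"])
print(result["metrics"]["grouping"]["compression_ratio"])
```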

---

## Parameters - quick reference

### Dataset builder

- `model_name` - base HF model used for grouping, default `Qwen/Qwen3-0.6B`.
- `dataset_name` - source HF dataset, default `Magpie-Align/Magpie-Qwen2.5-Pro-1M-v0.1`.
- `output_dir` - where pickles and metadata go.
- `max_samples` - optional cap for quick tests.

### Training

- `dataset_path` - path to `processed_dataset.pkl`.
- `output_dir` - where checkpoints are written.
- `batch_size`, `learning_rate`, `num_epochs`, `warmup_steps` - training hyperparameters.
- Only layer 0 and the adapter are trainable. Verify with the `requires_grad` settings in `_freeze_layers`.

### Inference

- `--checkpoint` - path to the trained checkpoint folder.
- `--grouping_model` - HF model name used for grouping.
- `--instruction` - user prompt, in any language.
- `--max_length`, `--temperature`, `--no_sample`, `--device`.

---

## Notes

- The custom layer 0 is installed by copying the weights from the original layer 0, then replacing the module so it can compute groups and cache the grouped states.
- Grouping relies on simple rules over tokens, such as space and newline boundaries, plus attention relations above a threshold. You can tune the threshold in `CustomQwen3Attention`, as shown below.
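
Adjusting that threshold on the grouping model (the default value 0.1 and the attribute path come from `inference_qwen3_merged.py`):

```python
from inference_qwen3_merged import create_grouping_model

# Layer 0 keeps attention relations whose weight exceeds self.threshold (default 0.1).
# Tokens with no surviving relation start a new group, so a lower threshold tends to
# produce fewer boundaries and larger groups.
model, tokenizer = create_grouping_model("Qwen/Qwen3-0.6B")
model.model.layers[0].self_attn.threshold = 0.05
```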

---

## Troubleshooting

- **CUDA memory spikes**: reduce the batch size during training or use fewer samples. Generation is incremental and reuses past key values.
- **No grouped states found**: ensure the custom layer 0 is actually installed and that `is_initialized` is reset before each prefill (snippet below).
- **Checkpoint not found**: the inference loader expects `pytorch_model.bin` or `model.safetensors` in the checkpoint directory.
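
Resetting the flag mirrors what `get_grouped_embeddings` does in `inference_qwen3_merged.py` before every prefill:

```python
# Reset the grouping layer before reusing the grouping model on a new prompt.
layer0 = grouping_model.model.layers[0]   # grouping_model from create_grouping_model(...)
if hasattr(layer0, "is_initialized"):
    layer0.is_initialized = False
    layer0.grouped_hidden_states = None   # drop any stale cached groups
```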

---

## Why this can save memory

If the sequence shrinks from `N` tokens to `G` groups, attention memory scales roughly with `G^2` instead of `N^2`, so the estimated saving is `1 - G^2 / N^2`. The script prints an estimate based on that relation.
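
For instance, using the same quadratic proxy the script prints:

```python
# Quadratic attention-cost proxy, as in inference_qwen3_merged.py.
N, G = 120, 70                                 # original tokens -> grouped tokens
savings = (1 - (G ** 2) / (N ** 2)) * 100
print(f"~{savings:.1f}% estimated attention-memory savings")  # ~66.0%
```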

---

## Citation

```bibtex
@misc{Kolomeitsev2025ContextMerging,
  title  = {Context Merging: from Tokens to Entities and Concepts},
  author = {Konstantin Kolomeitsev},
  year   = {2025}
}
```

## Contact

If you have any questions, please open an issue or contact me at [[email protected]]([email protected]).
grouped_qwen3_checkpoint/epoch_2_best/added_tokens.json ADDED
@@ -0,0 +1,28 @@
{
  "</think>": 151668,
  "</tool_call>": 151658,
  "</tool_response>": 151666,
  "<think>": 151667,
  "<tool_call>": 151657,
  "<tool_response>": 151665,
  "<|box_end|>": 151649,
  "<|box_start|>": 151648,
  "<|endoftext|>": 151643,
  "<|file_sep|>": 151664,
  "<|fim_middle|>": 151660,
  "<|fim_pad|>": 151662,
  "<|fim_prefix|>": 151659,
  "<|fim_suffix|>": 151661,
  "<|im_end|>": 151645,
  "<|im_start|>": 151644,
  "<|image_pad|>": 151655,
  "<|object_ref_end|>": 151647,
  "<|object_ref_start|>": 151646,
  "<|quad_end|>": 151651,
  "<|quad_start|>": 151650,
  "<|repo_name|>": 151663,
  "<|video_pad|>": 151656,
  "<|vision_end|>": 151653,
  "<|vision_pad|>": 151654,
  "<|vision_start|>": 151652
}
grouped_qwen3_checkpoint/epoch_2_best/chat_template.jinja ADDED
@@ -0,0 +1,89 @@
1
+ {%- if tools %}
2
+ {{- '<|im_start|>system\n' }}
3
+ {%- if messages[0].role == 'system' %}
4
+ {{- messages[0].content + '\n\n' }}
5
+ {%- endif %}
6
+ {{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
7
+ {%- for tool in tools %}
8
+ {{- "\n" }}
9
+ {{- tool | tojson }}
10
+ {%- endfor %}
11
+ {{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
12
+ {%- else %}
13
+ {%- if messages[0].role == 'system' %}
14
+ {{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }}
15
+ {%- endif %}
16
+ {%- endif %}
17
+ {%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
18
+ {%- for message in messages[::-1] %}
19
+ {%- set index = (messages|length - 1) - loop.index0 %}
20
+ {%- if ns.multi_step_tool and message.role == "user" and message.content is string and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}
21
+ {%- set ns.multi_step_tool = false %}
22
+ {%- set ns.last_query_index = index %}
23
+ {%- endif %}
24
+ {%- endfor %}
25
+ {%- for message in messages %}
26
+ {%- if message.content is string %}
27
+ {%- set content = message.content %}
28
+ {%- else %}
29
+ {%- set content = '' %}
30
+ {%- endif %}
31
+ {%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
32
+ {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }}
33
+ {%- elif message.role == "assistant" %}
34
+ {%- set reasoning_content = '' %}
35
+ {%- if message.reasoning_content is string %}
36
+ {%- set reasoning_content = message.reasoning_content %}
37
+ {%- else %}
38
+ {%- if '</think>' in content %}
39
+ {%- set reasoning_content = content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
40
+ {%- set content = content.split('</think>')[-1].lstrip('\n') %}
41
+ {%- endif %}
42
+ {%- endif %}
43
+ {%- if loop.index0 > ns.last_query_index %}
44
+ {%- if loop.last or (not loop.last and reasoning_content) %}
45
+ {{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' + content.lstrip('\n') }}
46
+ {%- else %}
47
+ {{- '<|im_start|>' + message.role + '\n' + content }}
48
+ {%- endif %}
49
+ {%- else %}
50
+ {{- '<|im_start|>' + message.role + '\n' + content }}
51
+ {%- endif %}
52
+ {%- if message.tool_calls %}
53
+ {%- for tool_call in message.tool_calls %}
54
+ {%- if (loop.first and content) or (not loop.first) %}
55
+ {{- '\n' }}
56
+ {%- endif %}
57
+ {%- if tool_call.function %}
58
+ {%- set tool_call = tool_call.function %}
59
+ {%- endif %}
60
+ {{- '<tool_call>\n{"name": "' }}
61
+ {{- tool_call.name }}
62
+ {{- '", "arguments": ' }}
63
+ {%- if tool_call.arguments is string %}
64
+ {{- tool_call.arguments }}
65
+ {%- else %}
66
+ {{- tool_call.arguments | tojson }}
67
+ {%- endif %}
68
+ {{- '}\n</tool_call>' }}
69
+ {%- endfor %}
70
+ {%- endif %}
71
+ {{- '<|im_end|>\n' }}
72
+ {%- elif message.role == "tool" %}
73
+ {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
74
+ {{- '<|im_start|>user' }}
75
+ {%- endif %}
76
+ {{- '\n<tool_response>\n' }}
77
+ {{- content }}
78
+ {{- '\n</tool_response>' }}
79
+ {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
80
+ {{- '<|im_end|>\n' }}
81
+ {%- endif %}
82
+ {%- endif %}
83
+ {%- endfor %}
84
+ {%- if add_generation_prompt %}
85
+ {{- '<|im_start|>assistant\n' }}
86
+ {%- if enable_thinking is defined and enable_thinking is false %}
87
+ {{- '<think>\n\n</think>\n\n' }}
88
+ {%- endif %}
89
+ {%- endif %}
grouped_qwen3_checkpoint/epoch_2_best/config.json ADDED
@@ -0,0 +1,60 @@
{
  "architectures": [
    "Qwen3ForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 151643,
  "eos_token_id": 151645,
  "head_dim": 128,
  "hidden_act": "silu",
  "hidden_size": 1024,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_types": [
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention"
  ],
  "max_position_embeddings": 40960,
  "max_window_layers": 28,
  "model_type": "qwen3",
  "num_attention_heads": 16,
  "num_hidden_layers": 28,
  "num_key_value_heads": 8,
  "rms_norm_eps": 1e-06,
  "rope_scaling": null,
  "rope_theta": 1000000,
  "sliding_window": null,
  "tie_word_embeddings": true,
  "torch_dtype": "float32",
  "transformers_version": "4.55.0",
  "use_cache": true,
  "use_sliding_window": false,
  "vocab_size": 151936
}
grouped_qwen3_checkpoint/epoch_2_best/epoch_metadata.json ADDED
@@ -0,0 +1,9 @@
{
  "epoch": 2,
  "global_step": 27000,
  "model_name": "Qwen/Qwen3-0.6B",
  "learning_rate": 0.0005,
  "batch_size": 1,
  "is_best": true,
  "model_class": "CustomQwen3ForCausalLM"
}
grouped_qwen3_checkpoint/epoch_2_best/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
grouped_qwen3_checkpoint/epoch_2_best/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:db3c76f1ed2b27fe53e699354e4b628d09c75da48a5f1371058542645a525b5a
size 2401122294
grouped_qwen3_checkpoint/epoch_2_best/special_tokens_map.json ADDED
@@ -0,0 +1,31 @@
{
  "additional_special_tokens": [
    "<|im_start|>",
    "<|im_end|>",
    "<|object_ref_start|>",
    "<|object_ref_end|>",
    "<|box_start|>",
    "<|box_end|>",
    "<|quad_start|>",
    "<|quad_end|>",
    "<|vision_start|>",
    "<|vision_end|>",
    "<|vision_pad|>",
    "<|image_pad|>",
    "<|video_pad|>"
  ],
  "eos_token": {
    "content": "<|im_end|>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "pad_token": {
    "content": "<|endoftext|>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  }
}
grouped_qwen3_checkpoint/epoch_2_best/tokenizer.json ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:352a863cd2761388ccc58f1432467ba6a1037bf12df9069889b142fa246471f6
size 11422752
grouped_qwen3_checkpoint/epoch_2_best/tokenizer_config.json ADDED
@@ -0,0 +1,239 @@
1
+ {
2
+ "add_bos_token": false,
3
+ "add_prefix_space": false,
4
+ "added_tokens_decoder": {
5
+ "151643": {
6
+ "content": "<|endoftext|>",
7
+ "lstrip": false,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false,
11
+ "special": true
12
+ },
13
+ "151644": {
14
+ "content": "<|im_start|>",
15
+ "lstrip": false,
16
+ "normalized": false,
17
+ "rstrip": false,
18
+ "single_word": false,
19
+ "special": true
20
+ },
21
+ "151645": {
22
+ "content": "<|im_end|>",
23
+ "lstrip": false,
24
+ "normalized": false,
25
+ "rstrip": false,
26
+ "single_word": false,
27
+ "special": true
28
+ },
29
+ "151646": {
30
+ "content": "<|object_ref_start|>",
31
+ "lstrip": false,
32
+ "normalized": false,
33
+ "rstrip": false,
34
+ "single_word": false,
35
+ "special": true
36
+ },
37
+ "151647": {
38
+ "content": "<|object_ref_end|>",
39
+ "lstrip": false,
40
+ "normalized": false,
41
+ "rstrip": false,
42
+ "single_word": false,
43
+ "special": true
44
+ },
45
+ "151648": {
46
+ "content": "<|box_start|>",
47
+ "lstrip": false,
48
+ "normalized": false,
49
+ "rstrip": false,
50
+ "single_word": false,
51
+ "special": true
52
+ },
53
+ "151649": {
54
+ "content": "<|box_end|>",
55
+ "lstrip": false,
56
+ "normalized": false,
57
+ "rstrip": false,
58
+ "single_word": false,
59
+ "special": true
60
+ },
61
+ "151650": {
62
+ "content": "<|quad_start|>",
63
+ "lstrip": false,
64
+ "normalized": false,
65
+ "rstrip": false,
66
+ "single_word": false,
67
+ "special": true
68
+ },
69
+ "151651": {
70
+ "content": "<|quad_end|>",
71
+ "lstrip": false,
72
+ "normalized": false,
73
+ "rstrip": false,
74
+ "single_word": false,
75
+ "special": true
76
+ },
77
+ "151652": {
78
+ "content": "<|vision_start|>",
79
+ "lstrip": false,
80
+ "normalized": false,
81
+ "rstrip": false,
82
+ "single_word": false,
83
+ "special": true
84
+ },
85
+ "151653": {
86
+ "content": "<|vision_end|>",
87
+ "lstrip": false,
88
+ "normalized": false,
89
+ "rstrip": false,
90
+ "single_word": false,
91
+ "special": true
92
+ },
93
+ "151654": {
94
+ "content": "<|vision_pad|>",
95
+ "lstrip": false,
96
+ "normalized": false,
97
+ "rstrip": false,
98
+ "single_word": false,
99
+ "special": true
100
+ },
101
+ "151655": {
102
+ "content": "<|image_pad|>",
103
+ "lstrip": false,
104
+ "normalized": false,
105
+ "rstrip": false,
106
+ "single_word": false,
107
+ "special": true
108
+ },
109
+ "151656": {
110
+ "content": "<|video_pad|>",
111
+ "lstrip": false,
112
+ "normalized": false,
113
+ "rstrip": false,
114
+ "single_word": false,
115
+ "special": true
116
+ },
117
+ "151657": {
118
+ "content": "<tool_call>",
119
+ "lstrip": false,
120
+ "normalized": false,
121
+ "rstrip": false,
122
+ "single_word": false,
123
+ "special": false
124
+ },
125
+ "151658": {
126
+ "content": "</tool_call>",
127
+ "lstrip": false,
128
+ "normalized": false,
129
+ "rstrip": false,
130
+ "single_word": false,
131
+ "special": false
132
+ },
133
+ "151659": {
134
+ "content": "<|fim_prefix|>",
135
+ "lstrip": false,
136
+ "normalized": false,
137
+ "rstrip": false,
138
+ "single_word": false,
139
+ "special": false
140
+ },
141
+ "151660": {
142
+ "content": "<|fim_middle|>",
143
+ "lstrip": false,
144
+ "normalized": false,
145
+ "rstrip": false,
146
+ "single_word": false,
147
+ "special": false
148
+ },
149
+ "151661": {
150
+ "content": "<|fim_suffix|>",
151
+ "lstrip": false,
152
+ "normalized": false,
153
+ "rstrip": false,
154
+ "single_word": false,
155
+ "special": false
156
+ },
157
+ "151662": {
158
+ "content": "<|fim_pad|>",
159
+ "lstrip": false,
160
+ "normalized": false,
161
+ "rstrip": false,
162
+ "single_word": false,
163
+ "special": false
164
+ },
165
+ "151663": {
166
+ "content": "<|repo_name|>",
167
+ "lstrip": false,
168
+ "normalized": false,
169
+ "rstrip": false,
170
+ "single_word": false,
171
+ "special": false
172
+ },
173
+ "151664": {
174
+ "content": "<|file_sep|>",
175
+ "lstrip": false,
176
+ "normalized": false,
177
+ "rstrip": false,
178
+ "single_word": false,
179
+ "special": false
180
+ },
181
+ "151665": {
182
+ "content": "<tool_response>",
183
+ "lstrip": false,
184
+ "normalized": false,
185
+ "rstrip": false,
186
+ "single_word": false,
187
+ "special": false
188
+ },
189
+ "151666": {
190
+ "content": "</tool_response>",
191
+ "lstrip": false,
192
+ "normalized": false,
193
+ "rstrip": false,
194
+ "single_word": false,
195
+ "special": false
196
+ },
197
+ "151667": {
198
+ "content": "<think>",
199
+ "lstrip": false,
200
+ "normalized": false,
201
+ "rstrip": false,
202
+ "single_word": false,
203
+ "special": false
204
+ },
205
+ "151668": {
206
+ "content": "</think>",
207
+ "lstrip": false,
208
+ "normalized": false,
209
+ "rstrip": false,
210
+ "single_word": false,
211
+ "special": false
212
+ }
213
+ },
214
+ "additional_special_tokens": [
215
+ "<|im_start|>",
216
+ "<|im_end|>",
217
+ "<|object_ref_start|>",
218
+ "<|object_ref_end|>",
219
+ "<|box_start|>",
220
+ "<|box_end|>",
221
+ "<|quad_start|>",
222
+ "<|quad_end|>",
223
+ "<|vision_start|>",
224
+ "<|vision_end|>",
225
+ "<|vision_pad|>",
226
+ "<|image_pad|>",
227
+ "<|video_pad|>"
228
+ ],
229
+ "bos_token": null,
230
+ "clean_up_tokenization_spaces": false,
231
+ "eos_token": "<|im_end|>",
232
+ "errors": "replace",
233
+ "extra_special_tokens": {},
234
+ "model_max_length": 131072,
235
+ "pad_token": "<|endoftext|>",
236
+ "split_special_tokens": false,
237
+ "tokenizer_class": "Qwen2Tokenizer",
238
+ "unk_token": null
239
+ }
grouped_qwen3_checkpoint/epoch_2_best/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
inference_qwen3_merged.py ADDED
@@ -0,0 +1,793 @@
1
+ import os
2
+ import logging
3
+ import time
4
+ import psutil
5
+ from typing import Optional, List, Dict, Any, Tuple
6
+ from pathlib import Path
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from transformers import AutoTokenizer, AutoModelForCausalLM
12
+ from transformers.models.qwen3.configuration_qwen3 import Qwen3Config
13
+ from transformers.models.qwen3.modeling_qwen3 import (
14
+ Qwen3ForCausalLM,
15
+ Qwen3RMSNorm,
16
+ Qwen3DecoderLayer,
17
+ Qwen3Attention,
18
+ Qwen3RotaryEmbedding,
19
+ )
20
+
21
+ logging.basicConfig(
22
+ level=logging.INFO,
23
+ format="%(asctime)s - %(levelname)s - %(message)s"
24
+ )
25
+ logger = logging.getLogger("qwen3_grouped_inference")
26
+
27
+ class PerformanceMonitor:
28
+ def __init__(self):
29
+ self.reset()
30
+
31
+ def reset(self):
32
+ """Reset all metrics."""
33
+ self.start_time = None
34
+ self.end_time = None
35
+ self.start_memory = None
36
+ self.peak_memory = None
37
+ self.start_gpu_memory = None
38
+ self.peak_gpu_memory = None
39
+
40
+ def start_monitoring(self):
41
+ self.reset()
42
+ self.start_time = time.time()
43
+
44
+ process = psutil.Process()
45
+ self.start_memory = process.memory_info().rss / 1024 / 1024 # MB
46
+ self.peak_memory = self.start_memory
47
+
48
+ if torch.cuda.is_available():
49
+ torch.cuda.empty_cache()
50
+ self.start_gpu_memory = torch.cuda.memory_allocated() / 1024 / 1024 # MB
51
+ self.peak_gpu_memory = self.start_gpu_memory
52
+
53
+ def update_peak_memory(self):
54
+ process = psutil.Process()
55
+ current_memory = process.memory_info().rss / 1024 / 1024 # MB
56
+ self.peak_memory = max(self.peak_memory, current_memory)
57
+
58
+ if torch.cuda.is_available():
59
+ current_gpu_memory = torch.cuda.memory_allocated() / 1024 / 1024 # MB
60
+ self.peak_gpu_memory = max(self.peak_gpu_memory, current_gpu_memory)
61
+
62
+ def stop_monitoring(self):
63
+ self.end_time = time.time()
64
+ self.update_peak_memory()
65
+
66
+ metrics = {
67
+ "duration_ms": (self.end_time - self.start_time) * 1000,
68
+ "cpu_memory_start_mb": self.start_memory,
69
+ "cpu_memory_peak_mb": self.peak_memory,
70
+ "cpu_memory_used_mb": self.peak_memory - self.start_memory,
71
+ }
72
+
73
+ if torch.cuda.is_available():
74
+ metrics.update({
75
+ "gpu_memory_start_mb": self.start_gpu_memory,
76
+ "gpu_memory_peak_mb": self.peak_gpu_memory,
77
+ "gpu_memory_used_mb": self.peak_gpu_memory - self.start_gpu_memory,
78
+ })
79
+
80
+ return metrics
81
+
82
+ class CustomQwen3Attention(Qwen3Attention):
83
+ def __init__(self, config, layer_idx: int):
84
+ super().__init__(config, layer_idx)
85
+ self.layer_idx = layer_idx
86
+ self.tokenizer = None
87
+ self.current_input_ids = None
88
+ self.threshold = 0.1
89
+
90
+ if not hasattr(self, 'num_key_value_heads'):
91
+ self.num_key_value_heads = config.num_key_value_heads if hasattr(config, 'num_key_value_heads') else config.num_attention_heads
92
+ if not hasattr(self, 'head_dim'):
93
+ self.head_dim = config.hidden_size // config.num_attention_heads
94
+
95
+ def set_tokenizer(self, tokenizer):
96
+ self.tokenizer = tokenizer
97
+
98
+ def set_current_input_ids(self, input_ids):
99
+ self.current_input_ids = input_ids
100
+
101
+ def _is_special_token(self, token: str) -> bool:
102
+ if self.tokenizer is None:
103
+ return False
104
+
105
+ special_tokens = set()
106
+
107
+ if hasattr(self.tokenizer, 'special_tokens_map'):
108
+ for token_type, token_value in self.tokenizer.special_tokens_map.items():
109
+ if isinstance(token_value, str):
110
+ special_tokens.add(token_value)
111
+ elif isinstance(token_value, list):
112
+ special_tokens.update(token_value)
113
+
114
+ if hasattr(self.tokenizer, 'added_tokens_encoder'):
115
+ special_tokens.update(self.tokenizer.added_tokens_encoder.keys())
116
+
117
+ if token in special_tokens:
118
+ return True
119
+
120
+ special_patterns = [
121
+ lambda t: t.startswith('<|') and t.endswith('|>'),
122
+ lambda t: t.startswith('<') and t.endswith('>'),
123
+ lambda t: t.startswith('[') and t.endswith(']'),
124
+ ]
125
+
126
+ return any(pattern(token) for pattern in special_patterns)
127
+
128
+ def _get_token_relations(self, attention_weights: torch.Tensor, tokens: List[str]) -> List[Dict]:
129
+ batch_size, num_heads, query_len, key_len = attention_weights.shape
130
+ attn = attention_weights[0].mean(dim=0)
131
+ relations = []
132
+
133
+ if query_len == 1:
134
+ current_token_pos = len(tokens) - 1
135
+ token_relations = []
136
+ for j in range(len(tokens)):
137
+ if j != current_token_pos:
138
+ weight = attn[0, j].item()
139
+ if weight > self.threshold:
140
+ token_relations.append({
141
+ 'target_pos': j,
142
+ 'weight': round(weight, 3)
143
+ })
144
+
145
+ relations.append({
146
+ 'source_pos': current_token_pos,
147
+ 'relations': token_relations
148
+ })
149
+ else:
150
+ for i in range(min(query_len, len(tokens))):
151
+ token_relations = []
152
+ for j in range(len(tokens)):
153
+ if i != j and j < key_len:
154
+ weight = attn[i, j].item()
155
+ if weight > self.threshold:
156
+ token_relations.append({
157
+ 'target_pos': j,
158
+ 'weight': round(weight, 3)
159
+ })
160
+
161
+ relations.append({
162
+ 'source_pos': i,
163
+ 'relations': token_relations
164
+ })
165
+
166
+ return relations
167
+
168
+ def _get_token_groups(self, attention_weights: torch.Tensor) -> List[List[int]]:
169
+ if self.tokenizer is None or self.current_input_ids is None:
170
+ return []
171
+
172
+ if len(attention_weights.shape) != 4:
173
+ return []
174
+
175
+ batch_size, num_heads, query_len, key_len = attention_weights.shape
176
+
177
+ input_ids = self.current_input_ids
178
+ if input_ids is None or input_ids.shape[1] < key_len:
179
+ return []
180
+
181
+ tokens = [self.tokenizer.decode([token_id]) for token_id in input_ids[0][:key_len]]
182
+ relations = self._get_token_relations(attention_weights, tokens)
183
+
184
+ groups = []
185
+ current_group = []
186
+ current_group_indices = []
187
+
188
+ for i, token in enumerate(tokens):
189
+ is_empty_relations = i < len(relations) and len(relations[i]['relations']) == 0
190
+ starts_with_space = token.startswith(' ') and token != ' '
191
+ is_space = token == ' '
192
+ is_new_line = '\n' in token
193
+
194
+ prev_token_is_special = False
195
+ prev_token_is_new_line = False
196
+ prev_token_is_space = False
197
+ if i > 0:
198
+ prev_token = tokens[i-1]
199
+ prev_token_is_special = self._is_special_token(prev_token)
200
+ prev_token_is_new_line = '\n' in prev_token
201
+ prev_token_is_space = prev_token == ' '
202
+
203
+ prev_newline_current_not = prev_token_is_new_line and not is_new_line
204
+ prev_space_current_not = prev_token_is_space and not is_space
205
+ current_space_prev_not = is_space and not prev_token_is_space
206
+
207
+ if (is_empty_relations or starts_with_space or is_new_line or
208
+ prev_token_is_special or prev_newline_current_not or prev_space_current_not or
209
+ current_space_prev_not) and current_group:
210
+ groups.append(current_group_indices)
211
+ current_group = []
212
+ current_group_indices = []
213
+
214
+ current_group.append(token)
215
+ current_group_indices.append(i)
216
+
217
+ if current_group:
218
+ groups.append(current_group_indices)
219
+
220
+ if groups:
221
+ logger.info("Token grouping details:")
222
+ for group_idx, group_indices in enumerate(groups):
223
+ group_tokens = [tokens[i] for i in group_indices]
224
+ combined_text = ''.join(group_tokens)
225
+ logger.info(f" Group {group_idx + 1}: {group_tokens} → '{combined_text}'")
226
+
227
+ return groups
228
+
229
+ class CustomQwen3DecoderLayer(Qwen3DecoderLayer):
230
+ """Custom Qwen3 decoder layer with grouping functionality."""
231
+
232
+ def __init__(self, config, layer_idx: int):
233
+ super().__init__(config, layer_idx)
234
+ self.layer_idx = layer_idx
235
+ self.rotary_emb = Qwen3RotaryEmbedding(config=config)
236
+ self.self_attn = CustomQwen3Attention(config, layer_idx)
237
+ self.is_initialized = False
238
+ self.grouped_hidden_states = None
239
+
240
+ def forward(
241
+ self,
242
+ hidden_states: torch.Tensor,
243
+ attention_mask: Optional[torch.Tensor] = None,
244
+ position_ids: Optional[torch.LongTensor] = None,
245
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
246
+ output_attentions: Optional[bool] = False,
247
+ use_cache: Optional[bool] = False,
248
+ cache_position: Optional[torch.LongTensor] = None,
249
+ position_embeddings: Optional[tuple] = None,
250
+ **kwargs,
251
+ ):
252
+ if self.layer_idx != 0:
253
+ return super().forward(
254
+ hidden_states=hidden_states,
255
+ attention_mask=attention_mask,
256
+ position_ids=position_ids,
257
+ past_key_value=past_key_value,
258
+ output_attentions=output_attentions,
259
+ use_cache=use_cache,
260
+ cache_position=cache_position,
261
+ position_embeddings=position_embeddings,
262
+ **kwargs,
263
+ )
264
+
265
+ is_prefill = hidden_states.shape[1] > 1 and not self.is_initialized
266
+ if not is_prefill:
267
+ return super().forward(
268
+ hidden_states=hidden_states,
269
+ attention_mask=attention_mask,
270
+ position_ids=position_ids,
271
+ past_key_value=past_key_value,
272
+ output_attentions=output_attentions,
273
+ use_cache=use_cache,
274
+ cache_position=cache_position,
275
+ position_embeddings=position_embeddings,
276
+ **kwargs,
277
+ )
278
+
279
+ residual = hidden_states
280
+ x = self.input_layernorm(hidden_states)
281
+
282
+ attn_ret = self.self_attn(
283
+ hidden_states=x,
284
+ attention_mask=attention_mask,
285
+ position_ids=position_ids,
286
+ past_key_value=None,
287
+ output_attentions=True,
288
+ use_cache=False,
289
+ cache_position=cache_position,
290
+ position_embeddings=position_embeddings,
291
+ )
292
+
293
+ if isinstance(attn_ret, tuple):
294
+ if len(attn_ret) == 3:
295
+ attn_out, attn_weights, _ = attn_ret
296
+ elif len(attn_ret) == 2:
297
+ attn_out, attn_weights = attn_ret
298
+ else:
299
+ raise RuntimeError(f"Unexpected attention return length: {len(attn_ret)}")
300
+ else:
301
+ raise RuntimeError("Attention did not return weights.")
302
+
303
+ groups = self.self_attn._get_token_groups(attn_weights)
304
+ if not groups:
305
+ self.is_initialized = True
306
+ return super().forward(
307
+ hidden_states=hidden_states,
308
+ attention_mask=attention_mask,
309
+ position_ids=position_ids,
310
+ past_key_value=past_key_value,
311
+ output_attentions=output_attentions,
312
+ use_cache=use_cache,
313
+ cache_position=cache_position,
314
+ position_embeddings=position_embeddings,
315
+ **kwargs,
316
+ )
317
+
318
+ averaged_vectors = []
319
+ group_info = []
320
+
321
+ for gi, idxs in enumerate(groups):
322
+ if len(idxs) == 1:
323
+ averaged_vectors.append(attn_out[:, idxs[0], :])
324
+ group_info.append({"type": "single", "positions": idxs, "new_position": gi})
325
+ else:
326
+ gvecs = attn_out[:, idxs, :]
327
+ ave = gvecs.mean(dim=1)
328
+ averaged_vectors.append(ave)
329
+ group_info.append({"type": "averaged", "positions": idxs, "new_position": gi})
330
+
331
+ new_attn_out = torch.stack(averaged_vectors, dim=1)
332
+
333
+ expanded_residual = torch.stack([
334
+ (
335
+ residual[:, info['positions'], :].sum(dim=1)
336
+ if len(info['positions']) > 1
337
+ else residual[:, info['positions'][0], :]
338
+ )
339
+ for info in group_info
340
+ ], dim=1)
341
+
342
+ hs = expanded_residual + new_attn_out
343
+ grouped_hidden = self.post_attention_layernorm(hs)
344
+
345
+ # Store grouped embeddings
346
+ self.grouped_hidden_states = grouped_hidden
347
+ self.is_initialized = True
348
+ return hs
349
+
350
+ class GroupedInputMLPAdapter(nn.Module):
351
+ def __init__(self, config):
352
+ super().__init__()
353
+ self.config = config
354
+ hidden_size = config.hidden_size
355
+
356
+ self.grouped_processor = nn.Sequential(
357
+ nn.Linear(hidden_size, hidden_size * 2),
358
+ nn.SiLU(),
359
+ nn.Dropout(0.1),
360
+ nn.Linear(hidden_size * 2, hidden_size),
361
+ nn.Dropout(0.1)
362
+ )
363
+
364
+ norm_eps = getattr(config, 'rms_norm_eps', 1e-6)
365
+ self.layer_norm = Qwen3RMSNorm(hidden_size, eps=norm_eps)
366
+
367
+ def forward(self, grouped_embeds: torch.Tensor) -> torch.Tensor:
368
+ processed = self.grouped_processor(grouped_embeds)
369
+ output = self.layer_norm(grouped_embeds + processed)
370
+ return output
371
+
372
+
373
+ class CustomQwen3ForCausalLM(Qwen3ForCausalLM):
374
+ def __init__(self, config):
375
+ super().__init__(config)
376
+ self.grouped_input_mlp = GroupedInputMLPAdapter(config)
377
+ self.is_grouped_input_mode = False
378
+
379
+ def forward(
380
+ self,
381
+ input_ids: Optional[torch.LongTensor] = None,
382
+ attention_mask: Optional[torch.Tensor] = None,
383
+ position_ids: Optional[torch.LongTensor] = None,
384
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
385
+ inputs_embeds: Optional[torch.FloatTensor] = None,
386
+ use_cache: Optional[bool] = None,
387
+ output_attentions: Optional[bool] = None,
388
+ output_hidden_states: Optional[bool] = None,
389
+ return_dict: Optional[bool] = None,
390
+ cache_position: Optional[torch.LongTensor] = None,
391
+ grouped_inputs: Optional[torch.FloatTensor] = None,
392
+ is_prefill: Optional[bool] = None,
393
+ **kwargs
394
+ ):
395
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
396
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
397
+
398
+ if grouped_inputs is not None and is_prefill:
399
+ self.is_grouped_input_mode = True
400
+ processed_grouped_inputs = self.grouped_input_mlp(grouped_inputs)
401
+ inputs_embeds = processed_grouped_inputs
402
+ input_ids = None
403
+
404
+ batch_size, seq_len = inputs_embeds.shape[:2]
405
+ if position_ids is None:
406
+ device = inputs_embeds.device
407
+ position_ids = torch.arange(seq_len, device=device, dtype=torch.long)
408
+ position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
409
+
410
+ if attention_mask is None:
411
+ attention_mask = torch.ones((batch_size, seq_len), device=inputs_embeds.device, dtype=torch.long)
412
+
413
+ return super().forward(
414
+ input_ids=input_ids,
415
+ attention_mask=attention_mask,
416
+ position_ids=position_ids,
417
+ past_key_values=past_key_values,
418
+ inputs_embeds=inputs_embeds,
419
+ use_cache=use_cache,
420
+ output_attentions=output_attentions,
421
+ output_hidden_states=output_hidden_states,
422
+ return_dict=return_dict,
423
+ cache_position=cache_position,
424
+ **kwargs
425
+ )
426
+
427
+ def create_grouping_model(model_name: str = "Qwen/Qwen3-0.6B") -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
428
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
429
+
430
+ if torch.cuda.is_available():
431
+ device = torch.device("cuda")
432
+ dtype = torch.float16
433
+ else:
434
+ device = torch.device("cpu")
435
+ dtype = torch.float32
436
+
437
+ model = AutoModelForCausalLM.from_pretrained(
438
+ model_name,
439
+ torch_dtype=dtype,
440
+ attn_implementation="eager"
441
+ ).to(device)
442
+
443
+ orig0 = model.model.layers[0]
444
+ custom0 = CustomQwen3DecoderLayer(model.config, 0)
445
+
446
+ custom0.mlp.load_state_dict(orig0.mlp.state_dict())
447
+ custom0.input_layernorm.load_state_dict(orig0.input_layernorm.state_dict())
448
+ custom0.post_attention_layernorm.load_state_dict(orig0.post_attention_layernorm.state_dict())
449
+ custom0.self_attn.load_state_dict(orig0.self_attn.state_dict())
450
+
451
+ custom0.self_attn.set_tokenizer(tokenizer)
452
+ custom0 = custom0.to(device=device, dtype=dtype)
453
+ model.model.layers[0] = custom0
454
+
455
+ return model, tokenizer
456
+
457
+
458
+ def load_inference_model(checkpoint_path: str) -> Tuple[CustomQwen3ForCausalLM, AutoTokenizer]:
459
+ logger.info(f"Loading inference model from {checkpoint_path}")
460
+
461
+ tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
462
+ config = Qwen3Config.from_pretrained(checkpoint_path)
463
+ model = CustomQwen3ForCausalLM(config)
464
+
465
+ model_path = Path(checkpoint_path) / "pytorch_model.bin"
466
+ if not model_path.exists():
467
+ model_path = Path(checkpoint_path) / "model.safetensors"
468
+
469
+ if not model_path.exists():
470
+ raise FileNotFoundError(f"No model weights found in {checkpoint_path}")
471
+
472
+ state_dict = torch.load(model_path, map_location="cpu")
473
+ model.load_state_dict(state_dict, strict=False)
474
+ model = model.eval().to(torch.float32)
475
+
476
+ return model, tokenizer
477
+
478
+ class Qwen3GroupedInference:
479
+ def __init__(self,
480
+ checkpoint_path: str,
481
+ grouping_model_name: str = "Qwen/Qwen3-0.6B",
482
+ device: Optional[str] = None):
483
+ """Initialize inference system with both models."""
484
+
485
+ if device is None:
486
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
487
+ else:
488
+ self.device = torch.device(device)
489
+
490
+ logger.info(f"Initializing inference on device: {self.device}")
491
+
492
+ self.system_prompt = "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\n"
493
+ self.response_start = "<|im_end|>\n<|im_start|>assistant\n"
494
+
495
+ logger.info("Loading grouping model...")
496
+ self.grouping_model, self.grouping_tokenizer = create_grouping_model(grouping_model_name)
497
+ self.grouping_model = self.grouping_model.to(self.device)
498
+
499
+ logger.info("Loading inference model...")
500
+ self.inference_model, self.inference_tokenizer = load_inference_model(checkpoint_path)
501
+ self.inference_model = self.inference_model.to(self.device)
502
+
503
+ logger.info("Both models loaded successfully")
504
+
505
+ def format_input_text(self, instruction: str) -> str:
506
+ return f"{self.system_prompt}{instruction}{self.response_start}"
507
+
508
+ def get_grouped_embeddings(self, text: str) -> Tuple[torch.Tensor, Dict[str, Any]]:
509
+ monitor = PerformanceMonitor()
510
+ monitor.start_monitoring()
511
+
512
+ if hasattr(self.grouping_model.model.layers[0], "is_initialized"):
513
+ self.grouping_model.model.layers[0].is_initialized = False
514
+
515
+ batch = self.grouping_tokenizer(text, return_tensors="pt", truncation=True, max_length=2048).to(self.device)
516
+ input_ids = batch["input_ids"]
517
+ original_token_count = input_ids.shape[1]
518
+
519
+ original_tokens = [self.grouping_tokenizer.decode([token_id]) for token_id in input_ids[0]]
520
+ logger.info(f"Original input tokens ({original_token_count}): {original_tokens}")
521
+
522
+ if hasattr(self.grouping_model.model.layers[0], "self_attn"):
523
+ sat = self.grouping_model.model.layers[0].self_attn
524
+ if hasattr(sat, "set_current_input_ids"):
525
+ sat.set_current_input_ids(input_ids)
526
+
527
+ monitor.update_peak_memory()
528
+
529
+ with torch.no_grad():
530
+ inputs_embeds = self.grouping_model.model.embed_tokens(input_ids)
531
+ seq_len = inputs_embeds.shape[1]
532
+ position_ids = torch.arange(seq_len, device=self.device, dtype=torch.long).unsqueeze(0)
533
+
534
+ if hasattr(self.grouping_model.model, 'rotary_emb'):
535
+ pos_embeds = self.grouping_model.model.rotary_emb(inputs_embeds, position_ids)
536
+ else:
537
+ pos_embeds = None
538
+
539
+ monitor.update_peak_memory()
540
+
541
+ _ = self.grouping_model.model.layers[0](
542
+ hidden_states=inputs_embeds,
543
+ attention_mask=None,
544
+ position_ids=position_ids,
545
+ past_key_value=None,
546
+ output_attentions=False,
547
+ use_cache=False,
548
+ cache_position=None,
549
+ position_embeddings=pos_embeds,
550
+ )
551
+
552
+ monitor.update_peak_memory()
553
+
554
+ if (hasattr(self.grouping_model.model.layers[0], "grouped_hidden_states") and
555
+ self.grouping_model.model.layers[0].grouped_hidden_states is not None):
556
+ grouped_embeds = self.grouping_model.model.layers[0].grouped_hidden_states.clone()
557
+ grouped_token_count = grouped_embeds.shape[1]
558
+ # Clear the stored state
559
+ self.grouping_model.model.layers[0].grouped_hidden_states = None
560
+
561
+ compression_ratio = original_token_count / grouped_token_count if grouped_token_count > 0 else 1.0
562
+ reduction_percent = (1 - grouped_token_count / original_token_count) * 100 if original_token_count > 0 else 0.0
563
+
564
+ logger.info(f"Grouped tokens: {grouped_token_count}")
565
+ logger.info(f"Compression ratio: {compression_ratio:.2f}x ({reduction_percent:.1f}% reduction)")
566
+
567
+ metrics = monitor.stop_monitoring()
568
+ metrics.update({
569
+ "original_tokens": original_token_count,
570
+ "grouped_tokens": grouped_token_count,
571
+ "compression_ratio": compression_ratio,
572
+ "reduction_percent": reduction_percent
573
+ })
574
+
575
+ return grouped_embeds.squeeze(0), metrics
576
+ else:
577
+ logger.warning("Grouping failed, using original embeddings")
578
+
579
+ metrics = monitor.stop_monitoring()
580
+ metrics.update({
581
+ "original_tokens": original_token_count,
582
+ "grouped_tokens": original_token_count,
583
+ "compression_ratio": 1.0,
584
+ "reduction_percent": 0.0
585
+ })
586
+
587
+ return inputs_embeds.squeeze(0), metrics
588
+
589
+ def generate_with_grouped_input(self,
590
+ grouped_input: torch.Tensor,
591
+ max_length: int = 512,
592
+ temperature: float = 0.7,
593
+ do_sample: bool = True) -> Tuple[str, Dict[str, Any]]:
594
+ """Generate text using grouped input embeddings."""
595
+ monitor = PerformanceMonitor()
596
+ monitor.start_monitoring()
597
+
598
+ model_dtype = next(self.inference_model.parameters()).dtype
599
+ grouped_input = grouped_input.to(device=self.device, dtype=model_dtype)
600
+
601
+ if grouped_input.ndim == 2:
602
+ grouped_input = grouped_input.unsqueeze(0)
603
+
604
+ input_seq_len = grouped_input.shape[1]
605
+ logger.info(f"Inference model input sequence length: {input_seq_len}")
606
+
607
+ monitor.update_peak_memory()
608
+
609
+ with torch.no_grad():
610
+ outputs = self.inference_model(
611
+ grouped_inputs=grouped_input,
612
+ is_prefill=True,
613
+ use_cache=True,
614
+ return_dict=True
615
+ )
616
+
617
+ monitor.update_peak_memory()
618
+
619
+ if hasattr(outputs, 'logits') and outputs.logits is not None:
620
+ next_token_logits = outputs.logits[:, -1, :]
621
+ else:
622
+ raise RuntimeError("Could not extract logits from model output")
623
+
624
+ if do_sample:
625
+ next_token_logits = next_token_logits / temperature
626
+ probs = F.softmax(next_token_logits, dim=-1)
627
+ next_token = torch.multinomial(probs, num_samples=1)
628
+ else:
629
+ next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
630
+
631
+ generated_ids = next_token
632
+ past_key_values = getattr(outputs, 'past_key_values', None)
633
+
634
+ generated_tokens = 1
635
+ for step in range(max_length - 1):
636
+ monitor.update_peak_memory()
637
+
638
+ with torch.no_grad():
639
+ outputs = self.inference_model(
640
+ input_ids=next_token,
641
+ past_key_values=past_key_values,
642
+ use_cache=True,
643
+ return_dict=True
644
+ )
645
+
646
+ if not hasattr(outputs, 'logits'):
647
+ break
648
+
649
+ next_token_logits = outputs.logits[:, -1, :]
650
+
651
+ if do_sample:
652
+ next_token_logits = next_token_logits / temperature
653
+ probs = F.softmax(next_token_logits, dim=-1)
654
+ next_token = torch.multinomial(probs, num_samples=1)
655
+ else:
656
+ next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
657
+
658
+ generated_ids = torch.cat([generated_ids, next_token], dim=1)
659
+ past_key_values = getattr(outputs, 'past_key_values', None)
660
+ generated_tokens += 1
661
+
662
+ if next_token.item() == self.inference_tokenizer.eos_token_id:
663
+ break
664
+
665
+ generated_text = self.inference_tokenizer.decode(generated_ids[0], skip_special_tokens=True)
666
+
667
+ metrics = monitor.stop_monitoring()
668
+ metrics.update({
669
+ "input_seq_len": input_seq_len,
670
+ "generated_tokens": generated_tokens,
671
+ "tokens_per_second": generated_tokens / (metrics["duration_ms"] / 1000) if metrics["duration_ms"] > 0 else 0
672
+ })
673
+
674
+ logger.info(f"Generated {generated_tokens} tokens in {metrics['duration_ms']:.1f}ms")
675
+ logger.info(f"Generation speed: {metrics['tokens_per_second']:.1f} tokens/second")
676
+
677
+ return generated_text, metrics
678
+
679
+ def inference(self,
680
+ instruction: str,
681
+ max_length: int = 512,
682
+ temperature: float = 0.7,
683
+ do_sample: bool = True) -> Dict[str, Any]:
684
+ """Run complete inference pipeline from instruction to response."""
685
+ logger.info("=" * 60)
686
+ logger.info("STARTING INFERENCE PIPELINE")
687
+ logger.info("=" * 60)
688
+
689
+ input_text = self.format_input_text(instruction)
690
+
691
+ logger.info("PHASE 1: Token Grouping")
692
+ grouped_embeddings, grouping_metrics = self.get_grouped_embeddings(input_text)
693
+
694
+ logger.info("PHASE 2: Response Generation")
695
+ response, generation_metrics = self.generate_with_grouped_input(
696
+ grouped_input=grouped_embeddings,
697
+ max_length=max_length,
698
+ temperature=temperature,
699
+ do_sample=do_sample
700
+ )
701
+
702
+ total_metrics = {
703
+ "grouping": grouping_metrics,
704
+ "generation": generation_metrics,
705
+ "total_duration_ms": grouping_metrics["duration_ms"] + generation_metrics["duration_ms"],
706
+ }
707
+
708
+ logger.info("=" * 60)
709
+ logger.info("INFERENCE SUMMARY")
710
+ logger.info("=" * 60)
711
+ logger.info(f"Input compression: {grouping_metrics['original_tokens']} β†’ {grouping_metrics['grouped_tokens']} tokens")
712
+ logger.info(f"Compression ratio: {grouping_metrics['compression_ratio']:.2f}x")
713
+ logger.info(f"Memory reduction: {grouping_metrics['reduction_percent']:.1f}%")
714
+ logger.info(f"Total time: {total_metrics['total_duration_ms']:.1f}ms")
715
+ logger.info(f"Generation speed: {generation_metrics['tokens_per_second']:.1f} tokens/sec")
716
+
717
+ if torch.cuda.is_available():
718
+ total_gpu_memory = grouping_metrics.get("gpu_memory_used_mb", 0) + generation_metrics.get("gpu_memory_used_mb", 0)
719
+ logger.info(f"Total GPU memory used: {total_gpu_memory:.1f}MB")
720
+
721
+ total_cpu_memory = grouping_metrics.get("cpu_memory_used_mb", 0) + generation_metrics.get("cpu_memory_used_mb", 0)
722
+ logger.info(f"Total CPU memory used: {total_cpu_memory:.1f}MB")
723
+
724
+ original_seq_len = grouping_metrics['original_tokens']
725
+ grouped_seq_len = grouping_metrics['grouped_tokens']
726
+
727
+ estimated_memory_savings = (1 - (grouped_seq_len ** 2) / (original_seq_len ** 2)) * 100 if original_seq_len > 0 else 0
728
+ logger.info(f"Estimated attention memory savings: {estimated_memory_savings:.1f}%")
729
+
730
+ logger.info("=" * 60)
731
+
732
+ return {
733
+ "instruction": instruction,
734
+ "response": response,
735
+ "metrics": total_metrics
736
+ }
737
+
738
+ def main():
739
+ import argparse
740
+
741
+ parser = argparse.ArgumentParser(description="Qwen3 Grouped Inference")
742
+ parser.add_argument("--checkpoint", type=str,
743
+ default="./grouped_qwen3_checkpoint/epoch_2_best",
744
+ help="Path to trained model checkpoint")
745
+ parser.add_argument("--grouping_model", type=str, default="Qwen/Qwen3-0.6B",
746
+ help="Grouping model name")
747
+ parser.add_argument("--instruction", type=str, default="""
748
+ What are neural networks? Explain them as you would to a 9th-grade student.
749
+ """,
750
+ help="Instruction for inference")
751
+ parser.add_argument("--max_length", type=int, default=512,
752
+ help="Maximum generation length")
753
+ parser.add_argument("--temperature", type=float, default=0.7,
754
+ help="Generation temperature")
755
+ parser.add_argument("--no_sample", action="store_true",
756
+ help="Use greedy decoding")
757
+ parser.add_argument("--device", type=str,
758
+ help="Device to use (cuda/cpu)")
759
+
760
+ args = parser.parse_args()
761
+
762
+ inference_system = Qwen3GroupedInference(
763
+ checkpoint_path=args.checkpoint,
764
+ grouping_model_name=args.grouping_model,
765
+ device=args.device
766
+ )
767
+
768
+ do_sample = not args.no_sample
769
+
770
+ result = inference_system.inference(
771
+ instruction=args.instruction,
772
+ max_length=args.max_length,
773
+ temperature=args.temperature,
774
+ do_sample=do_sample
775
+ )
776
+
777
+ print(f"\nInstruction: {result['instruction']}")
778
+ print(f"Response: {result['response']}")
779
+
780
+ metrics = result.get('metrics', {})
781
+ if metrics:
782
+ print(f"\n--- Performance Metrics ---")
783
+ grouping = metrics.get('grouping', {})
784
+ generation = metrics.get('generation', {})
785
+
786
+ print(f"Token compression: {grouping.get('compression_ratio', 'N/A'):.2f}x")
787
+ print(f"Memory reduction: {grouping.get('reduction_percent', 'N/A'):.1f}%")
788
+ print(f"Total time: {metrics.get('total_duration_ms', 'N/A'):.1f}ms")
789
+ print(f"Generation speed: {generation.get('tokens_per_second', 'N/A'):.1f} tokens/sec")
790
+
791
+
792
+ if __name__ == "__main__":
793
+ main()
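For reference, a minimal usage sketch of the inference pipeline defined above. The module name grouped_inference is illustrative (adjust it to wherever this script lives); the class name, checkpoint path, and keyword arguments come from the script itself.

# Hypothetical usage sketch; only the import path is assumed.
from grouped_inference import Qwen3GroupedInference

system = Qwen3GroupedInference(
    checkpoint_path="./grouped_qwen3_checkpoint/epoch_2_best",
    grouping_model_name="Qwen/Qwen3-0.6B",
    device=None,  # auto-select, matching the CLI default
)
result = system.inference(
    instruction="Explain neural networks to a 9th-grade student.",
    max_length=256,
    temperature=0.7,
    do_sample=True,
)
print(result["response"])
print(result["metrics"]["grouping"]["compression_ratio"])

The same run is available from the command line through the argparse entry point above, e.g. python grouped_inference.py --instruction "..." --max_length 256.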
prepare_dataset.py ADDED
@@ -0,0 +1,652 @@
1
+ import os
2
+ import sys
3
+ import logging
4
+ import json
5
+ import pickle
6
+ from typing import Optional, Tuple, List, Dict, Any
7
+ from pathlib import Path
8
+ from tqdm import tqdm
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ from datasets import load_dataset, Dataset
13
+
14
+ from transformers import AutoTokenizer, AutoModelForCausalLM
15
+ from transformers.models.qwen3.modeling_qwen3 import (
16
+ Qwen3DecoderLayer,
17
+ Qwen3Attention,
18
+ Qwen3RotaryEmbedding,
19
+ )
20
+
21
+ logging.basicConfig(
22
+ level=logging.INFO,
23
+ format="%(asctime)s - %(levelname)s - %(message)s",
24
+ handlers=[logging.StreamHandler(sys.stdout)],
25
+ force=True,
26
+ )
27
+ logger = logging.getLogger("qwen3_dataset_processor")
28
+
29
+ class GroupedCache:
30
+ """Cache for grouping metadata."""
31
+ def __init__(self):
32
+ self.grouped_positions = None
33
+ self.position_mapping = None
34
+ self.group_info = None
35
+ self.original_seq_length = None
36
+
37
+ class CustomQwen3Attention(Qwen3Attention):
38
+ def __init__(self, config, layer_idx: int):
39
+ super().__init__(config, layer_idx)
40
+ self.layer_idx = layer_idx
41
+ self.tokenizer = None
42
+ self.current_input_ids = None
43
+ self.threshold = 0.1
44
+ self.grouped_cache = GroupedCache()
45
+
46
+ if not hasattr(self, 'num_key_value_heads'):
47
+ self.num_key_value_heads = config.num_key_value_heads if hasattr(config, 'num_key_value_heads') else config.num_attention_heads
48
+ if not hasattr(self, 'head_dim'):
49
+ self.head_dim = config.hidden_size // config.num_attention_heads
50
+
51
+ def set_tokenizer(self, tokenizer):
52
+ self.tokenizer = tokenizer
53
+
54
+ def set_current_input_ids(self, input_ids):
55
+ self.current_input_ids = input_ids
56
+
57
+ def _is_special_token(self, token: str) -> bool:
58
+ if self.tokenizer is None:
59
+ return False
60
+
61
+ special_tokens = set()
62
+
63
+ if hasattr(self.tokenizer, 'special_tokens_map'):
64
+ for token_type, token_value in self.tokenizer.special_tokens_map.items():
65
+ if isinstance(token_value, str):
66
+ special_tokens.add(token_value)
67
+ elif isinstance(token_value, list):
68
+ special_tokens.update(token_value)
69
+
70
+ if hasattr(self.tokenizer, 'added_tokens_encoder'):
71
+ special_tokens.update(self.tokenizer.added_tokens_encoder.keys())
72
+
73
+ if token in special_tokens:
74
+ return True
75
+
76
+ special_patterns = [
77
+ lambda t: t.startswith('<|') and t.endswith('|>'),
78
+ lambda t: t.startswith('<') and t.endswith('>'),
79
+ lambda t: t.startswith('[') and t.endswith(']'),
80
+ ]
81
+
82
+ return any(pattern(token) for pattern in special_patterns)
83
+
84
+ def _get_token_relations(self, attention_weights: torch.Tensor, tokens: List[str]) -> List[Dict]:
85
+ batch_size, num_heads, query_len, key_len = attention_weights.shape
86
+
87
+ attn = attention_weights[0].mean(dim=0)
88
+
89
+ relations = []
90
+
91
+ if query_len == 1:
92
+ current_token_pos = len(tokens) - 1
93
+
94
+ token_relations = []
95
+ for j in range(len(tokens)):
96
+ if j != current_token_pos:
97
+ weight = attn[0, j].item()
98
+ if weight > self.threshold:
99
+ token_relations.append({
100
+ 'target_pos': j,
101
+ 'weight': round(weight, 3)
102
+ })
103
+
104
+ relations.append({
105
+ 'source_pos': current_token_pos,
106
+ 'relations': token_relations
107
+ })
108
+
109
+ else:
110
+ for i in range(min(query_len, len(tokens))):
111
+ token_relations = []
112
+ for j in range(len(tokens)):
113
+ if i != j and j < key_len:
114
+ weight = attn[i, j].item()
115
+ if weight > self.threshold:
116
+ token_relations.append({
117
+ 'target_pos': j,
118
+ 'weight': round(weight, 3)
119
+ })
120
+
121
+ relations.append({
122
+ 'source_pos': i,
123
+ 'relations': token_relations
124
+ })
125
+
126
+ return relations
127
+
128
+ def _get_token_groups(self, attention_weights: torch.Tensor) -> List[List[int]]:
129
+ if self.tokenizer is None or self.current_input_ids is None:
130
+ return []
131
+
132
+ if len(attention_weights.shape) != 4:
133
+ return []
134
+
135
+ batch_size, num_heads, query_len, key_len = attention_weights.shape
136
+
137
+ input_ids = self.current_input_ids
138
+ if input_ids is None or input_ids.shape[1] < key_len:
139
+ return []
140
+
141
+ tokens = [self.tokenizer.decode([token_id]) for token_id in input_ids[0][:key_len]]
142
+
143
+ relations = self._get_token_relations(attention_weights, tokens)
144
+
145
+ groups = []
146
+ current_group = []
147
+ current_group_indices = []
148
+
149
+ for i, token in enumerate(tokens):
150
+ is_empty_relations = i < len(relations) and len(relations[i]['relations']) == 0
151
+ starts_with_space = token.startswith(' ') and token != ' '
152
+ is_space = token == ' '
153
+ is_new_line = '\n' in token
154
+
155
+ prev_token_is_special = False
156
+ prev_token_is_new_line = False
157
+ prev_token_is_space = False
158
+ if i > 0:
159
+ prev_token = tokens[i-1]
160
+ prev_token_is_special = self._is_special_token(prev_token)
161
+ prev_token_is_new_line = '\n' in prev_token
162
+ prev_token_is_space = prev_token == ' '
163
+
164
+ prev_newline_current_not = prev_token_is_new_line and not is_new_line
165
+ prev_space_current_not = prev_token_is_space and not is_space
166
+ current_space_prev_not = is_space and not prev_token_is_space
167
+
168
+ if (is_empty_relations or starts_with_space or is_new_line or
169
+ prev_token_is_special or prev_newline_current_not or prev_space_current_not or
170
+ current_space_prev_not) and current_group:
171
+ groups.append(current_group_indices)
172
+ current_group = []
173
+ current_group_indices = []
174
+
175
+ current_group.append(token)
176
+ current_group_indices.append(i)
177
+
178
+ if current_group:
179
+ groups.append(current_group_indices)
180
+
181
+ return groups
182
+
183
+ class CustomQwen3DecoderLayer(Qwen3DecoderLayer):
184
+ def __init__(self, config, layer_idx: int):
185
+ super().__init__(config, layer_idx)
186
+ self.layer_idx = layer_idx
187
+ self.rotary_emb = Qwen3RotaryEmbedding(config=config)
188
+ self.self_attn = CustomQwen3Attention(config, layer_idx)
189
+ self.is_initialized = False
190
+ self.grouped_hidden_states = None
191
+ self.grouped_cache = GroupedCache()
192
+
193
+ def forward(
194
+ self,
195
+ hidden_states: torch.Tensor,
196
+ attention_mask: Optional[torch.Tensor] = None,
197
+ position_ids: Optional[torch.LongTensor] = None,
198
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
199
+ output_attentions: Optional[bool] = False,
200
+ use_cache: Optional[bool] = False,
201
+ cache_position: Optional[torch.LongTensor] = None,
202
+ position_embeddings: Optional[tuple] = None,
203
+ **kwargs,
204
+ ):
205
+ if self.layer_idx != 0:
206
+ return super().forward(
207
+ hidden_states=hidden_states,
208
+ attention_mask=attention_mask,
209
+ position_ids=position_ids,
210
+ past_key_value=past_key_value,
211
+ output_attentions=output_attentions,
212
+ use_cache=use_cache,
213
+ cache_position=cache_position,
214
+ position_embeddings=position_embeddings,
215
+ **kwargs,
216
+ )
217
+
218
+ is_prefill = hidden_states.shape[1] > 1 and not self.is_initialized
219
+ if not is_prefill:
220
+ return super().forward(
221
+ hidden_states=hidden_states,
222
+ attention_mask=attention_mask,
223
+ position_ids=position_ids,
224
+ past_key_value=past_key_value,
225
+ output_attentions=output_attentions,
226
+ use_cache=use_cache,
227
+ cache_position=cache_position,
228
+ position_embeddings=position_embeddings,
229
+ **kwargs,
230
+ )
231
+
232
+ residual = hidden_states
233
+ x = self.input_layernorm(hidden_states)
234
+
235
+ attn_ret = self.self_attn(
236
+ hidden_states=x,
237
+ attention_mask=attention_mask,
238
+ position_ids=position_ids,
239
+ past_key_value=None,
240
+ output_attentions=True,
241
+ use_cache=False,
242
+ cache_position=cache_position,
243
+ position_embeddings=position_embeddings,
244
+ )
245
+
246
+ if isinstance(attn_ret, tuple):
247
+ if len(attn_ret) == 3:
248
+ attn_out, attn_weights, _ = attn_ret
249
+ elif len(attn_ret) == 2:
250
+ attn_out, attn_weights = attn_ret
251
+ else:
252
+ raise RuntimeError(f"Unexpected attention return length: {len(attn_ret)}")
253
+ else:
254
+ raise RuntimeError("Attention did not return weights.")
255
+
256
+ groups = self.self_attn._get_token_groups(attn_weights)
257
+ if not groups:
258
+ self.is_initialized = True
259
+ return super().forward(
260
+ hidden_states=hidden_states,
261
+ attention_mask=attention_mask,
262
+ position_ids=position_ids,
263
+ past_key_value=past_key_value,
264
+ output_attentions=output_attentions,
265
+ use_cache=use_cache,
266
+ cache_position=cache_position,
267
+ position_embeddings=position_embeddings,
268
+ **kwargs,
269
+ )
270
+
271
+ averaged_vectors = []
272
+ group_info = []
273
+ position_mapping = {}
274
+
275
+ for gi, idxs in enumerate(groups):
276
+ if len(idxs) == 1:
277
+ averaged_vectors.append(attn_out[:, idxs[0], :])
278
+ group_info.append({"type": "single", "positions": idxs, "new_position": gi})
279
+ else:
280
+ gvecs = attn_out[:, idxs, :]
281
+ ave = gvecs.mean(dim=1)
282
+ averaged_vectors.append(ave)
283
+ group_info.append({"type": "averaged", "positions": idxs, "new_position": gi})
284
+
285
+ for p in idxs:
286
+ position_mapping[p] = gi
287
+
288
+ new_attn_out = torch.stack(averaged_vectors, dim=1)
289
+
290
+ expanded_residual = torch.stack([
291
+ (
292
+ residual[:, info['positions'], :].sum(dim=1)
293
+ if len(info['positions']) > 1
294
+ else residual[:, info['positions'][0], :]
295
+ )
296
+ for info in group_info
297
+ ], dim=1)
298
+
299
+ hs = expanded_residual + new_attn_out
300
+ grouped_hidden = self.post_attention_layernorm(hs)
301
+
302
+ self.grouped_cache.grouped_positions = len(groups)
303
+ self.grouped_cache.position_mapping = position_mapping
304
+ self.grouped_cache.group_info = group_info
305
+ self.grouped_cache.original_seq_length = hidden_states.shape[1]
306
+ self.grouped_hidden_states = grouped_hidden
307
+
308
+ self.is_initialized = True
309
+ return hs
310
+
311
+ def create_model_with_custom_layer0(model_name: str = "Qwen/Qwen3-0.6B"):
312
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
313
+
314
+ if torch.cuda.is_available():
315
+ device = torch.device("cuda")
316
+ dtype = torch.float16
317
+ else:
318
+ device = torch.device("cpu")
319
+ dtype = torch.float32
320
+
321
+ model = AutoModelForCausalLM.from_pretrained(
322
+ model_name,
323
+ torch_dtype=dtype,
324
+ attn_implementation="eager"
325
+ ).to(device)
326
+
327
+ orig0 = model.model.layers[0]
328
+ custom0 = CustomQwen3DecoderLayer(model.config, 0)
329
+
330
+ custom0.mlp.load_state_dict(orig0.mlp.state_dict())
331
+ custom0.input_layernorm.load_state_dict(orig0.input_layernorm.state_dict())
332
+ custom0.post_attention_layernorm.load_state_dict(orig0.post_attention_layernorm.state_dict())
333
+ custom0.self_attn.load_state_dict(orig0.self_attn.state_dict())
334
+
335
+ custom0.self_attn.set_tokenizer(tokenizer)
336
+ custom0 = custom0.to(device=device, dtype=dtype)
337
+ model.model.layers[0] = custom0
338
+
339
+ return model, tokenizer
340
+
341
+ class DatasetProcessor:
342
+ def __init__(self,
343
+ model_name: str = "Qwen/Qwen3-0.6B",
344
+ dataset_name: str = "Magpie-Align/Magpie-Qwen2.5-Pro-1M-v0.1",
345
+ output_dir: str = "./processed_dataset",
346
+ batch_size: int = 1,
347
+ max_samples: Optional[int] = None,
348
+ save_frequency: int = 1000):
349
+
350
+ self.model_name = model_name
351
+ self.dataset_name = dataset_name
352
+ self.output_dir = Path(output_dir)
353
+ self.batch_size = batch_size
354
+ self.max_samples = max_samples
355
+ self.save_frequency = save_frequency
356
+
357
+ self.output_dir.mkdir(parents=True, exist_ok=True)
358
+
359
+ # System prompt template for Qwen3
360
+ self.system_prompt = "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\n"
361
+ self.response_start = "<|im_end|>\n<|im_start|>assistant\n"
362
+
363
+ self.current_chunk = 0
364
+ self.processed_data_buffer = []
365
+
366
+ def load_dataset(self) -> Dataset:
367
+ logger.info(f"Loading dataset: {self.dataset_name}")
368
+ dataset = load_dataset(self.dataset_name, split="train")
369
+
370
+ if self.max_samples:
371
+ dataset = dataset.select(range(min(self.max_samples, len(dataset))))
372
+
373
+ logger.info(f"Dataset loaded: {len(dataset)} samples")
374
+ return dataset
375
+
376
+ def format_input_text(self, instruction: str) -> str:
377
+ return f"{self.system_prompt}{instruction}{self.response_start}"
378
+
379
+ def process_embeddings_batch(self, model, tokenizer, texts: List[str]) -> List[torch.Tensor]:
380
+ device = model.device
381
+ embeddings_batch = []
382
+
383
+ for text in texts:
384
+ try:
385
+ if hasattr(model.model.layers[0], "is_initialized"):
386
+ model.model.layers[0].is_initialized = False
387
+
388
+ batch = tokenizer(text, return_tensors="pt", truncation=True, max_length=2048).to(device)
389
+ input_ids = batch["input_ids"]
390
+
391
+ if hasattr(model.model.layers[0], "self_attn"):
392
+ sat = model.model.layers[0].self_attn
393
+ if hasattr(sat, "set_current_input_ids"):
394
+ sat.set_current_input_ids(input_ids)
395
+
396
+ with torch.no_grad():
397
+ inputs_embeds = model.model.embed_tokens(input_ids)
398
+ seq_len = inputs_embeds.shape[1]
399
+ position_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0)
400
+
401
+ if hasattr(model.model, 'rotary_emb'):
402
+ pos_embeds = model.model.rotary_emb(inputs_embeds, position_ids)
403
+ else:
404
+ pos_embeds = None
405
+
406
+ _ = model.model.layers[0](
407
+ hidden_states=inputs_embeds,
408
+ attention_mask=None,
409
+ position_ids=position_ids,
410
+ past_key_value=None,
411
+ output_attentions=False,
412
+ use_cache=False,
413
+ cache_position=None,
414
+ position_embeddings=pos_embeds,
415
+ )
416
+
417
+ if (hasattr(model.model.layers[0], "grouped_hidden_states") and
418
+ model.model.layers[0].grouped_hidden_states is not None):
419
+ grouped_embeds = model.model.layers[0].grouped_hidden_states.clone().cpu()
420
+ embeddings_batch.append(grouped_embeds.squeeze(0))
421
+
422
+ model.model.layers[0].grouped_hidden_states = None
423
+ else:
424
+ embeddings_batch.append(inputs_embeds.squeeze(0).cpu())
425
+
426
+ del inputs_embeds, position_ids
427
+ if pos_embeds is not None:
428
+ del pos_embeds
429
+ if torch.cuda.is_available():
430
+ torch.cuda.empty_cache()
431
+
432
+ except Exception as e:
433
+ logger.warning(f"Error processing sample: {e}")
434
+ embeddings_batch.append(torch.zeros(1, model.config.hidden_size))
435
+
436
+ return embeddings_batch
437
+
438
+ def save_chunk(self, chunk_data: List[Dict[str, Any]], chunk_id: int):
439
+ if not chunk_data:
440
+ return
441
+
442
+ chunk_path = self.output_dir / f"processed_chunk_{chunk_id:04d}.pkl"
443
+ with open(chunk_path, 'wb') as f:
444
+ pickle.dump(chunk_data, f)
445
+
446
+ # Clear memory
447
+ del chunk_data
448
+ import gc
449
+ gc.collect()
450
+
451
+ def merge_chunks(self) -> List[Dict[str, Any]]:
452
+ logger.info("Merging chunks...")
453
+
454
+ chunk_files = sorted(list(self.output_dir.glob("processed_chunk_*.pkl")))
455
+ if not chunk_files:
456
+ return []
457
+
458
+ merged_data = []
459
+ for chunk_file in tqdm(chunk_files, desc="Merging chunks"):
460
+ try:
461
+ with open(chunk_file, 'rb') as f:
462
+ chunk_data = pickle.load(f)
463
+ if isinstance(chunk_data, list):
464
+ merged_data.extend(chunk_data)
465
+ except Exception as e:
466
+ logger.error(f"Error loading chunk {chunk_file}: {e}")
467
+ continue
468
+
469
+ # Clean up chunk files
470
+ self.cleanup_chunks()
471
+
472
+ logger.info(f"Merged {len(chunk_files)} chunks into {len(merged_data)} samples")
473
+ return merged_data
474
+
475
+ def cleanup_chunks(self):
476
+ chunk_files = list(self.output_dir.glob("processed_chunk_*.pkl"))
477
+ for chunk_file in chunk_files:
478
+ try:
479
+ chunk_file.unlink()
480
+ except Exception as e:
481
+ logger.warning(f"Could not delete chunk {chunk_file}: {e}")
482
+
483
+ if chunk_files:
484
+ logger.info(f"Cleaned up {len(chunk_files)} temporary chunk files")
485
+
486
+ def save_final_dataset(self, processed_data: List[Dict[str, Any]], stats: Dict[str, int]):
487
+ pickle_path = self.output_dir / "processed_dataset.pkl"
488
+ with open(pickle_path, 'wb') as f:
489
+ pickle.dump(processed_data, f)
490
+
491
+ error_samples = sum(1 for sample in processed_data if sample.get("error", False))
492
+ successful_samples = len(processed_data) - error_samples
493
+
494
+ metadata = {
495
+ "model_name": self.model_name,
496
+ "dataset_name": self.dataset_name,
497
+ "total_samples": stats["total_samples"],
498
+ "processed_samples": len(processed_data),
499
+ "successful_samples": successful_samples,
500
+ "error_samples": error_samples,
501
+ "batch_size": self.batch_size,
502
+ "max_samples": self.max_samples,
503
+ "success_rate": f"{(successful_samples / len(processed_data) * 100):.2f}%" if processed_data else "0%"
504
+ }
505
+
506
+ with open(self.output_dir / "metadata.json", 'w', encoding='utf-8') as f:
507
+ json.dump(metadata, f, indent=2, ensure_ascii=False)
508
+
509
+ text_samples = []
510
+ count = 0
511
+ for i, sample in enumerate(processed_data):
512
+ if not sample.get("error", False) and count < 10:
513
+ text_samples.append({
514
+ "sample_id": i,
515
+ "input_text": sample["input_text"][:300] + "..." if len(sample["input_text"]) > 300 else sample["input_text"],
516
+ "response": sample["response"][:300] + "..." if len(sample["response"]) > 300 else sample["response"],
517
+ "embedding_shape": sample["embedding_shape"]
518
+ })
519
+ count += 1
520
+
521
+ with open(self.output_dir / "samples.json", 'w', encoding='utf-8') as f:
522
+ json.dump(text_samples, f, indent=2, ensure_ascii=False)
523
+
524
+ logger.info(f"Dataset saved: {len(processed_data)} samples")
525
+ logger.info(f"Success rate: {metadata['success_rate']}")
526
+
527
+ def process_dataset(self):
528
+ dataset = self.load_dataset()
529
+
530
+ logger.info("Loading model...")
531
+ model, tokenizer = create_model_with_custom_layer0(self.model_name)
532
+
533
+ total_samples = len(dataset)
534
+ processed_count = 0
535
+ error_count = 0
536
+
537
+ logger.info(f"Processing {total_samples} samples...")
538
+
539
+ for i in tqdm(range(0, total_samples, self.batch_size), desc="Processing"):
540
+ batch_end = min(i + self.batch_size, total_samples)
541
+ batch_samples = dataset.select(range(i, batch_end))
542
+
543
+ batch_texts = []
544
+ batch_instructions = []
545
+ batch_responses = []
546
+
547
+ try:
548
+ for sample in batch_samples:
549
+ instruction = sample.get("instruction", "")
550
+ response = sample.get("response", "")
551
+
552
+ if not instruction.strip() or not response.strip():
553
+ instruction = "Empty instruction"
554
+ response = "Empty response"
555
+
556
+ input_text = self.format_input_text(instruction)
557
+ batch_texts.append(input_text)
558
+ batch_instructions.append(input_text)
559
+ batch_responses.append(response)
560
+
561
+ embeddings_batch = self.process_embeddings_batch(model, tokenizer, batch_texts)
562
+
563
+ for j, (input_text, embedding, response) in enumerate(zip(batch_instructions, embeddings_batch, batch_responses)):
564
+ processed_sample = {
565
+ "input_text": input_text,
566
+ "inputs_embeds": embedding,
567
+ "response": response,
568
+ "embedding_shape": list(embedding.shape),
569
+ "original_index": i + j
570
+ }
571
+ self.processed_data_buffer.append(processed_sample)
572
+ processed_count += 1
573
+
574
+ if len(self.processed_data_buffer) >= self.save_frequency:
575
+ self.save_chunk(self.processed_data_buffer, self.current_chunk)
576
+ self.processed_data_buffer = []
577
+ self.current_chunk += 1
578
+
579
+ import gc
580
+ gc.collect()
581
+
582
+ except Exception as e:
583
+ logger.error(f"Error processing batch: {e}")
584
+ error_count += len(batch_samples)
585
+
586
+ if self.processed_data_buffer:
587
+ self.save_chunk(self.processed_data_buffer, self.current_chunk)
588
+ self.processed_data_buffer = []
589
+
590
+ merged_data = self.merge_chunks()
591
+
592
+ stats = {
593
+ "total_samples": total_samples,
594
+ "processed_count": processed_count,
595
+ "error_count": error_count
596
+ }
597
+
598
+ self.save_final_dataset(merged_data, stats)
599
+ return merged_data
600
+
601
+
602
+ def load_processed_dataset(dataset_path: str) -> List[Dict[str, Any]]:
603
+ pickle_path = Path(dataset_path) / "processed_dataset.pkl"
604
+ with open(pickle_path, 'rb') as f:
605
+ return pickle.load(f)
606
+
607
+
608
+ def get_dataset_info(dataset_path: str) -> Dict:
609
+ metadata_path = Path(dataset_path) / "metadata.json"
610
+ with open(metadata_path, 'r') as f:
611
+ return json.load(f)
612
+
613
+
614
+ def main():
615
+ model_name = "Qwen/Qwen3-0.6B"
616
+ dataset_name = "Magpie-Align/Magpie-Qwen2.5-Pro-1M-v0.1"
617
+ output_dir = "./processed_qwen3_dataset"
618
+ batch_size = 1
619
+ max_samples = 10000 # Set to number for testing, None for full dataset
620
+ save_frequency = 1000
621
+
622
+ logger.info("Starting Qwen3 dataset processing...")
623
+ logger.info(f"Model: {model_name}")
624
+ logger.info(f"Dataset: {dataset_name}")
625
+ logger.info(f"Output: {output_dir}")
626
+ logger.info(f"Max samples: {max_samples or 'ALL'}")
627
+
628
+ try:
629
+ processor = DatasetProcessor(
630
+ model_name=model_name,
631
+ dataset_name=dataset_name,
632
+ output_dir=output_dir,
633
+ batch_size=batch_size,
634
+ max_samples=max_samples,
635
+ save_frequency=save_frequency
636
+ )
637
+
638
+ processed_data = processor.process_dataset()
639
+
640
+ logger.info("Processing completed successfully!")
641
+ logger.info(f"Final dataset: {len(processed_data)} samples")
642
+ logger.info(f"Files saved to: {output_dir}")
643
+
644
+ return processed_data
645
+
646
+ except Exception as e:
647
+ logger.error(f"Processing failed: {e}")
648
+ raise
649
+
650
+
651
+ if __name__ == "__main__":
652
+ main()
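A short sketch of consuming the output of prepare_dataset.py through the helper functions defined above; the directory path matches the default used in main().

# Sketch: load the processed dataset and inspect one sample.
from prepare_dataset import get_dataset_info, load_processed_dataset

info = get_dataset_info("./processed_qwen3_dataset")
print(info["processed_samples"], info["success_rate"])

data = load_processed_dataset("./processed_qwen3_dataset")
sample = data[0]
print(sample["input_text"][:80])
print(sample["inputs_embeds"].shape)  # grouped embeddings, e.g. [47, 1024]
print(sample["response"][:80])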
processed_qwen3_dataset/metadata.json ADDED
@@ -0,0 +1,11 @@
1
+ {
2
+ "model_name": "Qwen/Qwen3-0.6B",
3
+ "dataset_name": "Magpie-Align/Magpie-Qwen2.5-Pro-1M-v0.1",
4
+ "total_samples": 10000,
5
+ "processed_samples": 10000,
6
+ "successful_samples": 10000,
7
+ "error_samples": 0,
8
+ "batch_size": 1,
9
+ "max_samples": 10000,
10
+ "success_rate": "100.00%"
11
+ }
processed_qwen3_dataset/processed_dataset.pkl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:696e5df891b54ac7ea3f1b1c3a69ae24f8a7e0b8e96600e4105585b8882af2eb
3
+ size 1029880114
processed_qwen3_dataset/samples.json ADDED
@@ -0,0 +1,92 @@
1
+ [
2
+ {
3
+ "sample_id": 0,
4
+ "input_text": "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\nCan you tell me about the latest developments in AI research that could impact the field of healthcare?<|im_end|>\n<|im_start|>assistant\n",
5
+ "response": "Certainly! The intersection of AI and healthcare is a rapidly evolving field with numerous recent developments that hold significant promise. Here are some of the latest advancements:\n\n1. **AI in Diagnostics**:\n - **Imaging Analysis**: AI algorithms, particularly deep learning models, are becoming...",
6
+ "embedding_shape": [
7
+ 47,
8
+ 1024
9
+ ]
10
+ },
11
+ {
12
+ "sample_id": 1,
13
+ "input_text": "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\nWhat is the most effective way to learn a new language, especially for someone who is already an adult and has a busy schedule?<|im_end|>\n<|im_start|>assistant\n",
14
+ "response": "Learning a new language as an adult with a busy schedule can be challenging, but it's definitely achievable with the right strategies and consistent effort. Here are some effective methods to help you learn a new language efficiently:\n\n1. **Set Clear Goals**: Define what you want to achieve with the...",
15
+ "embedding_shape": [
16
+ 54,
17
+ 1024
18
+ ]
19
+ },
20
+ {
21
+ "sample_id": 2,
22
+ "input_text": "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\nI need to create a marketing plan for a new line of organic skincare products. What are the key elements I should consider?<|im_end|>\n<|im_start|>assistant\n",
23
+ "response": "Creating a marketing plan for a new line of organic skincare products involves several key elements to ensure the successful launch and sustained growth of your brand. Here’s a comprehensive guide to help you develop an effective marketing plan:\n\n### 1. **Market Research**\n - **Target Audience**: ...",
24
+ "embedding_shape": [
25
+ 53,
26
+ 1024
27
+ ]
28
+ },
29
+ {
30
+ "sample_id": 3,
31
+ "input_text": "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\nWhat is the role of a project manager in a software development team?<|im_end|>\n<|im_start|>assistant\n",
32
+ "response": "The role of a project manager in a software development team is multifaceted and crucial for the success of the project. Here are some of the key responsibilities and tasks typically associated with this role:\n\n1. **Project Planning**: The project manager is responsible for defining the scope of the...",
33
+ "embedding_shape": [
34
+ 42,
35
+ 1024
36
+ ]
37
+ },
38
+ {
39
+ "sample_id": 4,
40
+ "input_text": "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\nI need to create a social media campaign for a new line of eco-friendly products. Can you help me with some ideas?<|im_end|>\n<|im_start|>assistant\n",
41
+ "response": "Absolutely, I'd be happy to help you brainstorm some ideas for a social media campaign for your new line of eco-friendly products! Here are a few strategies and ideas to consider:\n\n### 1. **Eco-Challenges**\n- **30-Day Eco-Challenge:** Encourage followers to adopt one eco-friendly habit each day for ...",
42
+ "embedding_shape": [
43
+ 52,
44
+ 1024
45
+ ]
46
+ },
47
+ {
48
+ "sample_id": 5,
49
+ "input_text": "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\nCan you explain the concept of \"dynamic equilibrium\" in chemistry? And could you provide an example to illustrate this concept?<|im_end|>\n<|im_start|>assistant\n",
50
+ "response": "Certainly! The concept of \"dynamic equilibrium\" in chemistry refers to a state in which the concentrations of the reactants and products in a reversible reaction remain constant over time, despite the fact that the forward and reverse reactions are still occurring at the same rate. This means that w...",
51
+ "embedding_shape": [
52
+ 51,
53
+ 1024
54
+ ]
55
+ },
56
+ {
57
+ "sample_id": 6,
58
+ "input_text": "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\nWhat is the best way to create a budget for a small business?<|im_end|>\n<|im_start|>assistant\n",
59
+ "response": "Creating a budget for a small business is a crucial step in managing finances effectively and ensuring the long-term success of the business. Here’s a step-by-step guide to help you create a comprehensive and effective budget:\n\n### 1. **Define Your Business Goals**\n - **Short-term and Long-term Go...",
60
+ "embedding_shape": [
61
+ 42,
62
+ 1024
63
+ ]
64
+ },
65
+ {
66
+ "sample_id": 7,
67
+ "input_text": "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\nCan you provide a brief overview of Alibaba Cloud's security measures and how they protect user data?<|im_end|>\n<|im_start|>assistant\n",
68
+ "response": "Certainly! Alibaba Cloud places a strong emphasis on security and has implemented a comprehensive set of measures to protect user data and ensure the integrity, confidentiality, and availability of its services. Here’s a brief overview of some key security measures:\n\n1. **Data Encryption**: Alibaba ...",
69
+ "embedding_shape": [
70
+ 47,
71
+ 1024
72
+ ]
73
+ },
74
+ {
75
+ "sample_id": 8,
76
+ "input_text": "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\nI am interested in learning more about the cultural significance of the Chinese zodiac. Could you tell me about its origin and how it is used in Chinese culture today?<|im_end|>\n<|im_st...",
77
+ "response": "Certainly! The Chinese zodiac, known as \"Sheng Xiao\" (η”Ÿθ‚–) in Chinese, is a classification scheme that assigns an animal and its reputed attributes to each year in a repeating 12-year cycle. The 12 animals of the Chinese zodiac are the Rat, Ox, Tiger, Rabbit, Dragon, Snake, Horse, Goat, Monkey, Roost...",
78
+ "embedding_shape": [
79
+ 60,
80
+ 1024
81
+ ]
82
+ },
83
+ {
84
+ "sample_id": 9,
85
+ "input_text": "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\nHow can I use a Virtual Private Cloud (VPC) to enhance the security of my cloud applications?<|im_end|>\n<|im_start|>assistant\n",
86
+ "response": "Using a Virtual Private Cloud (VPC) can significantly enhance the security of your cloud applications by providing a secure and isolated network environment. Here are several ways you can leverage a VPC to improve security:\n\n### 1. **Network Isolation**\n - **Private Network**: A VPC allows you to ...",
87
+ "embedding_shape": [
88
+ 48,
89
+ 1024
90
+ ]
91
+ }
92
+ ]
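The entries above are written by save_final_dataset() in prepare_dataset.py; a quick sketch for checking the grouped sequence lengths recorded in this preview file:

# Sketch: summarize the embedding shapes stored in samples.json.
import json

with open("./processed_qwen3_dataset/samples.json", "r", encoding="utf-8") as f:
    samples = json.load(f)

for s in samples:
    rows, hidden = s["embedding_shape"]
    print(f"sample {s['sample_id']}: {rows} grouped positions x {hidden} dims")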
train_custom_qwen3.py ADDED
@@ -0,0 +1,1292 @@
1
+ import os
2
+ import sys
3
+ import logging
4
+ import json
5
+ import pickle
6
+ from typing import Optional, Tuple, List, Dict, Any, Union
7
+ from pathlib import Path
8
+ from tqdm import tqdm
9
+ import math
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ from torch.utils.data import Dataset, DataLoader
15
+ from torch.optim import AdamW
16
+ from torch.optim.lr_scheduler import CosineAnnealingLR
17
+
18
+ from datasets import load_dataset, Dataset as HFDataset
19
+ from transformers import (
20
+ AutoTokenizer,
21
+ AutoModelForCausalLM,
22
+ get_linear_schedule_with_warmup,
23
+ PreTrainedModel
24
+ )
25
+ from transformers.models.qwen3.configuration_qwen3 import Qwen3Config
26
+ from transformers.models.qwen3.modeling_qwen3 import (
27
+ Qwen3Model,
28
+ Qwen3ForCausalLM,
29
+ Qwen3PreTrainedModel,
30
+ Qwen3RMSNorm
31
+ )
32
+
33
+ logging.basicConfig(
34
+ level=logging.INFO,
35
+ format="%(asctime)s - %(levelname)s - %(message)s",
36
+ handlers=[logging.StreamHandler(sys.stdout)],
37
+ force=True,
38
+ )
39
+ logger = logging.getLogger("grouped_qwen3_training")
40
+
41
+
42
+ class GroupedInputMLPAdapter(nn.Module):
43
+ def __init__(self, config):
44
+ super().__init__()
45
+ self.config = config
46
+ hidden_size = config.hidden_size
47
+
48
+ self.grouped_processor = nn.Sequential(
49
+ nn.Linear(hidden_size, hidden_size * 2),
50
+ nn.SiLU(), # Using SiLU activation like Qwen3
51
+ nn.Dropout(0.1),
52
+ nn.Linear(hidden_size * 2, hidden_size),
53
+ nn.Dropout(0.1)
54
+ )
55
+
56
+ norm_eps = getattr(config, 'rms_norm_eps', 1e-6)
57
+ self.layer_norm = Qwen3RMSNorm(hidden_size, eps=norm_eps)
58
+
59
+ def forward(self, grouped_embeds: torch.Tensor) -> torch.Tensor:
60
+ processed = self.grouped_processor(grouped_embeds)
61
+
62
+ output = self.layer_norm(grouped_embeds + processed)
63
+
64
+ return output
65
+
66
+
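# ---------------------------------------------------------------------------
# Editorial sketch (not part of the committed file): a quick shape check of
# GroupedInputMLPAdapter. The SimpleNamespace config is an assumed stand-in
# for Qwen3Config; hidden_size 1024 matches the Qwen3-0.6B grouped embeddings
# used elsewhere in this commit.
import torch
from types import SimpleNamespace

cfg = SimpleNamespace(hidden_size=1024, rms_norm_eps=1e-6)
adapter = GroupedInputMLPAdapter(cfg)
grouped = torch.randn(2, 47, 1024)      # [batch, grouped_seq_len, hidden]
out = adapter(grouped)
assert out.shape == grouped.shape       # residual + RMSNorm keep the shape
# ---------------------------------------------------------------------------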
67
+ class CustomQwen3ForCausalLM(Qwen3ForCausalLM):
68
+ def __init__(self, config):
69
+ super().__init__(config)
70
+
71
+ self.grouped_input_mlp = GroupedInputMLPAdapter(config)
72
+
73
+ self.is_grouped_input_mode = False
74
+ self.grouped_cache_initialized = False
75
+
76
+ self._init_grouped_weights()
77
+
78
+ self._freeze_layers()
79
+
80
+ def _init_grouped_weights(self):
81
+ def _init_weights(module):
82
+ if isinstance(module, nn.Linear):
83
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
84
+ if module.bias is not None:
85
+ torch.nn.init.zeros_(module.bias)
86
+ elif isinstance(module, nn.LayerNorm):
87
+ torch.nn.init.ones_(module.weight)
88
+ torch.nn.init.zeros_(module.bias)
89
+
90
+ self.grouped_input_mlp.apply(_init_weights)
91
+
92
+ def _freeze_layers(self):
93
+ for param in self.model.embed_tokens.parameters():
94
+ param.requires_grad = False
95
+
96
+ for i, layer in enumerate(self.model.layers):
97
+ if i == 0:
98
+ for param in layer.parameters():
99
+ param.requires_grad = True
100
+ else:
101
+ for param in layer.parameters():
102
+ param.requires_grad = False
103
+
104
+ for param in self.model.norm.parameters():
105
+ param.requires_grad = False
106
+
107
+ for param in self.lm_head.parameters():
108
+ param.requires_grad = False
109
+
110
+ for param in self.grouped_input_mlp.parameters():
111
+ param.requires_grad = True
112
+
113
+ trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
114
+ total_params = sum(p.numel() for p in self.parameters())
115
+ logger.info(f"Trainable parameters: {trainable_params:,} / {total_params:,} "
116
+ f"({trainable_params/total_params*100:.2f}%)")
117
+
118
+ def forward(
119
+ self,
120
+ input_ids: Optional[torch.LongTensor] = None,
121
+ attention_mask: Optional[torch.Tensor] = None,
122
+ position_ids: Optional[torch.LongTensor] = None,
123
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
124
+ inputs_embeds: Optional[torch.FloatTensor] = None,
125
+ labels: Optional[torch.LongTensor] = None,
126
+ use_cache: Optional[bool] = None,
127
+ output_attentions: Optional[bool] = None,
128
+ output_hidden_states: Optional[bool] = None,
129
+ return_dict: Optional[bool] = None,
130
+ cache_position: Optional[torch.LongTensor] = None,
131
+ grouped_inputs: Optional[torch.FloatTensor] = None, # New parameter for grouped inputs
132
+ is_prefill: Optional[bool] = None, # Flag to indicate prefill phase
133
+ **kwargs
134
+ ):
135
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
136
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
137
+
138
+ if grouped_inputs is not None and is_prefill:
139
+ self.is_grouped_input_mode = True
140
+
141
+ processed_grouped_inputs = self.grouped_input_mlp(grouped_inputs)
142
+
143
+ inputs_embeds = processed_grouped_inputs
144
+ input_ids = None # Don't use input_ids when we have grouped inputs
145
+
146
+ batch_size, seq_len = inputs_embeds.shape[:2]
147
+ if position_ids is None:
148
+ device = inputs_embeds.device
149
+ position_ids = torch.arange(seq_len, device=device, dtype=torch.long)
150
+ position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
151
+
152
+ if attention_mask is None:
153
+ attention_mask = torch.ones((batch_size, seq_len), device=inputs_embeds.device, dtype=torch.long)
154
+
155
+ self.grouped_cache_initialized = True
156
+
157
+ elif not is_prefill and self.is_grouped_input_mode:
158
+ pass
159
+ else:
160
+ self.is_grouped_input_mode = False
161
+
162
+ # Call parent forward
163
+ outputs = super().forward(
164
+ input_ids=input_ids,
165
+ attention_mask=attention_mask,
166
+ position_ids=position_ids,
167
+ past_key_values=past_key_values,
168
+ inputs_embeds=inputs_embeds,
169
+ labels=labels,
170
+ use_cache=use_cache,
171
+ output_attentions=output_attentions,
172
+ output_hidden_states=output_hidden_states,
173
+ return_dict=return_dict,
174
+ cache_position=cache_position,
175
+ **kwargs
176
+ )
177
+
178
+ return outputs
179
+
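# ---------------------------------------------------------------------------
# Editorial sketch (not part of the committed file): the two-phase call
# pattern this forward() enables, mirroring training_step/validation_step
# further below. The tensors are illustrative placeholders, not real grouped
# embeddings from prepare_dataset.py.
import torch

model = CustomQwen3ForCausalLM.from_pretrained(
    "Qwen/Qwen3-0.6B", torch_dtype=torch.float32, attn_implementation="eager"
)
grouped = torch.randn(1, 47, model.config.hidden_size)  # stand-in grouped prefix
mask = torch.ones(1, 47, dtype=torch.long)
prefill = model(grouped_inputs=grouped, attention_mask=mask,
                is_prefill=True, use_cache=True, return_dict=True)
dec_ids = torch.tensor([[1, 2, 3]])  # arbitrary token ids, for illustration only
out = model(input_ids=dec_ids, attention_mask=torch.ones_like(dec_ids),
            past_key_values=prefill.past_key_values, use_cache=False, return_dict=True)
print(out.logits.shape)  # e.g. torch.Size([1, 3, vocab_size])
# ---------------------------------------------------------------------------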
180
+ class GroupedDataset(Dataset):
181
+ def __init__(self, dataset_path: str, tokenizer, max_response_length: int = 512,
182
+ validation_split: float = 0.1, is_validation: bool = False,
183
+ chunk_size: int = 1000, max_samples: Optional[int] = None):
184
+ self.dataset_path = dataset_path
185
+ self.tokenizer = tokenizer
186
+ self.max_response_length = max_response_length
187
+ self.validation_split = validation_split
188
+ self.is_validation = is_validation
189
+ self.chunk_size = chunk_size
190
+ self.max_samples = max_samples
191
+
192
+ self._chunk_cache = {}
193
+ self._cache_size_limit = 3 # Keep max 3 chunks in memory
194
+
195
+ self._build_index()
196
+
197
+ def _build_index(self):
198
+ logger.info(f"Building index for {self.dataset_path}")
199
+
200
+ with open(self.dataset_path, 'rb') as f:
201
+ data = pickle.load(f)
202
+
203
+ valid_indices = []
204
+ for i, item in enumerate(data):
205
+ if not item.get("error", False):
206
+ valid_indices.append(i)
207
+
208
+ if self.max_samples and len(valid_indices) >= self.max_samples:
209
+ break
210
+
211
+ total_valid = len(valid_indices)
212
+
213
+ val_size = min(1000, int(self.validation_split * total_valid))
214
+ train_size = total_valid - val_size
215
+
216
+ if self.is_validation:
217
+ self.valid_indices = valid_indices[train_size:train_size + val_size]
218
+ self.total_samples = val_size
219
+ else:
220
+ self.valid_indices = valid_indices[:train_size]
221
+ self.total_samples = train_size
222
+
223
+ self._full_data = data
224
+
225
+ logger.info(f"{'Validation' if self.is_validation else 'Training'} dataset: {self.total_samples} samples")
226
+
227
+ def _get_chunk_id(self, idx):
228
+ return idx // self.chunk_size
229
+
230
+ def _load_chunk(self, chunk_id):
231
+ if chunk_id in self._chunk_cache:
232
+ return self._chunk_cache[chunk_id]
233
+
234
+ start_idx = chunk_id * self.chunk_size
235
+ end_idx = min(start_idx + self.chunk_size, self.total_samples)
236
+
237
+ chunk_data = {}
238
+ for i in range(start_idx, end_idx):
239
+ actual_idx = self.valid_indices[i]
240
+ chunk_data[i] = self._full_data[actual_idx]
241
+
242
+ if len(self._chunk_cache) >= self._cache_size_limit:
243
+ oldest_chunk = min(self._chunk_cache.keys())
244
+ del self._chunk_cache[oldest_chunk]
245
+
246
+ self._chunk_cache[chunk_id] = chunk_data
247
+ return chunk_data
248
+
249
+ def __len__(self):
250
+ return self.total_samples
251
+
252
+ def __getitem__(self, idx):
253
+ if idx >= self.total_samples:
254
+ raise IndexError(f"Index {idx} out of range for dataset of size {self.total_samples}")
255
+
256
+ chunk_id = self._get_chunk_id(idx)
257
+ chunk_data = self._load_chunk(chunk_id)
258
+ item = chunk_data[idx]
259
+
260
+ return self._process_item(item)
261
+
262
+ def _process_item(self, item):
263
+ grouped_embeds = item["inputs_embeds"]
264
+ if isinstance(grouped_embeds, torch.Tensor):
265
+ grouped_embeds = grouped_embeds.clone()
266
+ else:
267
+ grouped_embeds = torch.tensor(grouped_embeds)
268
+
269
+ if grouped_embeds.dtype != torch.float32:
270
+ grouped_embeds = grouped_embeds.float()
271
+
272
+ response = item["response"]
273
+
274
+ response_tokens = self.tokenizer(
275
+ response,
276
+ max_length=self.max_response_length,
277
+ truncation=True,
278
+ padding=False,
279
+ return_tensors="pt"
280
+ )
281
+
282
+ response_input_ids = response_tokens["input_ids"].squeeze(0)
283
+
284
+ return {
285
+ "grouped_inputs": grouped_embeds,
286
+ "response_input_ids": response_input_ids,
287
+ "response_text": response,
288
+ "input_text": item["input_text"],
289
+ }
290
+
291
+ def cleanup(self):
292
+ self._chunk_cache.clear()
293
+ if hasattr(self, '_full_data'):
294
+ del self._full_data
295
+
296
+
297
+ def collate_fn(batch, tokenizer, pad_token_id=None):
298
+ if pad_token_id is None:
299
+ pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
300
+
301
+ grouped_inputs = [item["grouped_inputs"] for item in batch]
302
+ response_input_ids = [item["response_input_ids"] for item in batch]
303
+
304
+ max_grouped_len = max(gi.shape[0] for gi in grouped_inputs)
305
+ batch_size = len(grouped_inputs)
306
+ hidden_size = grouped_inputs[0].shape[-1]
307
+
308
+ padded_grouped_inputs = torch.zeros(batch_size, max_grouped_len, hidden_size)
309
+ grouped_attention_mask = torch.zeros(batch_size, max_grouped_len, dtype=torch.long)
310
+
311
+ for i, gi in enumerate(grouped_inputs):
312
+ seq_len = gi.shape[0]
313
+ padded_grouped_inputs[i, :seq_len] = gi
314
+ grouped_attention_mask[i, :seq_len] = 1
315
+
316
+ max_response_len = max(len(rid) for rid in response_input_ids)
317
+ padded_response_ids = torch.full((batch_size, max_response_len), pad_token_id, dtype=torch.long)
318
+
319
+ for i, rid in enumerate(response_input_ids):
320
+ padded_response_ids[i, :len(rid)] = rid
321
+
322
+ return {
323
+ "grouped_inputs": padded_grouped_inputs,
324
+ "grouped_attention_mask": grouped_attention_mask,
325
+ "response_input_ids": padded_response_ids,
326
+ "response_texts": [item["response_text"] for item in batch],
327
+ "input_texts": [item["input_text"] for item in batch],
328
+ }
329
+
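# ---------------------------------------------------------------------------
# Editorial sketch (not part of the committed file): collate_fn pads
# variable-length grouped embeddings and response ids. Passing pad_token_id
# explicitly means no real tokenizer is needed for this check.
import torch

fake_batch = [
    {"grouped_inputs": torch.randn(5, 8), "response_input_ids": torch.tensor([1, 2, 3]),
     "response_text": "a", "input_text": "x"},
    {"grouped_inputs": torch.randn(3, 8), "response_input_ids": torch.tensor([4, 5]),
     "response_text": "b", "input_text": "y"},
]
out = collate_fn(fake_batch, tokenizer=None, pad_token_id=0)
print(out["grouped_inputs"].shape)    # torch.Size([2, 5, 8])
print(out["grouped_attention_mask"])  # 1s over real positions, 0s over padding
print(out["response_input_ids"])      # second row padded with 0
# ---------------------------------------------------------------------------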
330
+ class TrainingState:
331
+ def __init__(self, output_dir: Path):
332
+ self.output_dir = output_dir
333
+ self.state_file = output_dir / "training_state.json"
334
+
335
+ def save_state(self, epoch: int, global_step: int, best_val_loss: float,
336
+ optimizer_state: Dict, scheduler_state: Dict):
337
+ """Save training state."""
338
+ state = {
339
+ "epoch": epoch,
340
+ "global_step": global_step,
341
+ "best_val_loss": best_val_loss,
342
+ "optimizer_state": optimizer_state,
343
+ "scheduler_state": scheduler_state,
344
+ "completed_epochs": epoch
345
+ }
346
+
347
+ with open(self.state_file, 'w') as f:
348
+ json.dump(state, f, indent=2, default=str) # default=str for handling tensor types
349
+
350
+ logger.info(f"Saved training state at epoch {epoch}, step {global_step}")
351
+
352
+ def load_state(self):
353
+ if not self.state_file.exists():
354
+ return None
355
+
356
+ try:
357
+ with open(self.state_file, 'r') as f:
358
+ state = json.load(f)
359
+ logger.info(f"Loaded training state from epoch {state['epoch']}, step {state['global_step']}")
360
+ return state
361
+ except Exception as e:
362
+ logger.warning(f"Failed to load training state: {e}")
363
+ return None
364
+
365
+ def get_latest_checkpoint(self):
366
+ state = self.load_state()
367
+ if state is None:
368
+ return None
369
+
370
+ epoch = state["completed_epochs"]
371
+ checkpoint_path = self.output_dir / f"epoch_{epoch}"
372
+
373
+ if checkpoint_path.exists():
374
+ return checkpoint_path, state
375
+ else:
376
+ logger.warning(f"Checkpoint for epoch {epoch} not found")
377
+ return None
378
+
379
+ class GroupedTrainer:
380
+ def __init__(
381
+ self,
382
+ model_name: str = "Qwen/Qwen3-0.6B",
383
+ dataset_path: str = "./processed_qwen3_dataset/processed_dataset.pkl",
384
+ output_dir: str = "./grouped_qwen3_checkpoint",
385
+ batch_size: int = 4,
386
+ learning_rate: float = 5e-4,
387
+ num_epochs: int = 3,
388
+ warmup_steps: int = 100,
389
+ max_grad_norm: float = 1.0,
390
+ save_steps: int = 500,
391
+ eval_steps: int = 500,
392
+ logging_steps: int = 50,
393
+ resume_training: bool = True,
394
+ debug: bool = False,
395
+ chunk_size: int = 1000, # Chunk size for streaming
396
+ max_samples: Optional[int] = None, # Limit dataset size for testing
397
+ ):
398
+ self.model_name = model_name
399
+ self.dataset_path = dataset_path
400
+ self.output_dir = Path(output_dir)
401
+ self.batch_size = batch_size
402
+ self.learning_rate = learning_rate
403
+ self.num_epochs = num_epochs
404
+ self.warmup_steps = warmup_steps
405
+ self.max_grad_norm = max_grad_norm
406
+ self.save_steps = save_steps
407
+ self.eval_steps = eval_steps
408
+ self.logging_steps = logging_steps
409
+ self.resume_training = resume_training
410
+ self.debug = debug
411
+ self.chunk_size = chunk_size
412
+ self.max_samples = max_samples
413
+
414
+ if self.debug:
415
+ logger.setLevel(logging.DEBUG)
416
+
417
+ self.output_dir.mkdir(parents=True, exist_ok=True)
418
+
419
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
420
+ self.tokenizer = None
421
+ self.model = None
422
+ self.train_dataset = None
423
+ self.val_dataset = None
424
+
425
+ self.training_state = TrainingState(self.output_dir)
426
+
427
+ def load_model_and_tokenizer(self):
428
+ logger.info(f"Loading tokenizer and model: {self.model_name}")
429
+
430
+ self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
431
+
432
+ if self.tokenizer.pad_token is None:
433
+ self.tokenizer.pad_token = self.tokenizer.eos_token
434
+
435
+ config = Qwen3Config.from_pretrained(self.model_name)
436
+
437
+ self.model = CustomQwen3ForCausalLM.from_pretrained(
438
+ self.model_name,
439
+ config=config,
440
+ torch_dtype=torch.float32, # Use float32 for training
441
+ attn_implementation="eager"
442
+ ).to(self.device)
443
+
444
+ logger.info(f"Model loaded on {self.device}")
445
+
446
+ def load_dataset(self, chunk_size: int = 1000, max_samples: Optional[int] = None):
447
+ logger.info(f"Loading streaming dataset from {self.dataset_path}")
448
+
449
+ # Create streaming datasets
450
+ self.train_dataset = GroupedDataset(
451
+ dataset_path=self.dataset_path,
452
+ tokenizer=self.tokenizer,
453
+ is_validation=False,
454
+ chunk_size=chunk_size,
455
+ max_samples=max_samples
456
+ )
457
+
458
+ self.val_dataset = GroupedDataset(
459
+ dataset_path=self.dataset_path,
460
+ tokenizer=self.tokenizer,
461
+ is_validation=True,
462
+ chunk_size=chunk_size,
463
+ max_samples=max_samples
464
+ )
465
+
466
+ logger.info(f"Train samples: {len(self.train_dataset)}")
467
+ logger.info(f"Val samples: {len(self.val_dataset)}")
468
+
469
+ # Log memory usage
470
+ if torch.cuda.is_available():
471
+ torch.cuda.empty_cache()
472
+ memory_used = torch.cuda.memory_allocated() / 1024**3
473
+ logger.info(f"GPU memory after dataset loading: {memory_used:.2f} GB")
474
+
475
+ def cleanup_datasets(self):
476
+ if hasattr(self.train_dataset, 'cleanup'):
477
+ self.train_dataset.cleanup()
478
+ if hasattr(self.val_dataset, 'cleanup'):
479
+ self.val_dataset.cleanup()
480
+
481
+ import gc
482
+ gc.collect()
483
+
484
+ if torch.cuda.is_available():
485
+ torch.cuda.empty_cache()
486
+
487
+ def load_checkpoint(self, checkpoint_path: Path):
488
+ logger.info(f"Loading checkpoint from {checkpoint_path}")
489
+
490
+ model_path = checkpoint_path / "pytorch_model.bin"
491
+ if not model_path.exists():
492
+ model_path = checkpoint_path / "model.safetensors"
493
+
494
+ if model_path.exists():
495
+ state_dict = torch.load(model_path, map_location=self.device)
496
+ missing_keys, unexpected_keys = self.model.load_state_dict(state_dict, strict=False)
497
+
498
+ if missing_keys:
499
+ logger.warning(f"Missing keys when loading checkpoint: {missing_keys}")
500
+ if unexpected_keys:
501
+ logger.warning(f"Unexpected keys when loading checkpoint: {unexpected_keys}")
502
+
503
+ logger.info("Model checkpoint loaded successfully")
504
+ else:
505
+ raise FileNotFoundError(f"Model checkpoint not found at {checkpoint_path}")
506
+
507
+ def compute_loss(self, batch, outputs):
508
+ logits = outputs.logits if hasattr(outputs, 'logits') else outputs[0]
509
+
510
+ target_ids = batch["response_input_ids"].to(self.device) # [batch_size, target_len]
511
+
512
+ logger.debug(f"Logits shape: {logits.shape}, Target shape: {target_ids.shape}")
513
+
514
+ batch_size = target_ids.shape[0]
515
+
516
+ if target_ids.shape[1] > 1:
517
+ labels = target_ids.clone()
518
+ pad_token_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
519
+ labels[labels == pad_token_id] = -100
520
+
521
+ seq_len = min(logits.shape[1], labels.shape[1])
522
+ logits_truncated = logits[:, :seq_len, :] # [batch_size, seq_len, vocab_size]
523
+ labels_truncated = labels[:, :seq_len] # [batch_size, seq_len]
524
+
525
+ logger.debug(f"After truncation - Logits: {logits_truncated.shape}, Labels: {labels_truncated.shape}")
526
+
527
+ loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
528
+ loss = loss_fct(
529
+ logits_truncated.reshape(-1, logits_truncated.size(-1)),
530
+ labels_truncated.reshape(-1)
531
+ )
532
+ else:
533
+ loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
534
+ loss = loss_fct(logits.view(-1, logits.size(-1)), target_ids.view(-1))
535
+
536
+ return loss
537
+
538
+ def training_step(self, batch, step):
539
+ self.model.train()
540
+
541
+ if step < 5 and torch.cuda.is_available():
542
+ torch.cuda.empty_cache()
543
+ memory_before = torch.cuda.memory_allocated() / 1024**3
544
+
545
+ grouped_inputs = batch["grouped_inputs"].to(self.device)
546
+ grouped_attention_mask = batch["grouped_attention_mask"].to(self.device)
547
+ response_input_ids = batch["response_input_ids"].to(self.device)
548
+
549
+ batch_size = grouped_inputs.shape[0]
550
+ grouped_seq_len = grouped_inputs.shape[1]
551
+ response_seq_len = response_input_ids.shape[1]
552
+
553
+ if self.debug:
554
+ logger.debug(f"Batch sizes - grouped: {grouped_inputs.shape}, response: {response_input_ids.shape}")
555
+
556
+ grouped_outputs = self.model(
557
+ grouped_inputs=grouped_inputs,
558
+ attention_mask=grouped_attention_mask,
559
+ is_prefill=True,
560
+ use_cache=True,
561
+ return_dict=True
562
+ )
563
+
564
+ if response_seq_len > 1:
565
+ response_attention_mask = (response_input_ids != self.tokenizer.pad_token_id).long()
566
+
567
+ response_outputs = self.model(
568
+ input_ids=response_input_ids[:, :-1], # All but last token as input
569
+ attention_mask=response_attention_mask[:, :-1],
570
+ past_key_values=grouped_outputs.past_key_values,
571
+ use_cache=False,
572
+ return_dict=True
573
+ )
574
+
575
+ logits = response_outputs.logits
576
+ labels = response_input_ids[:, 1:] # All but first token as targets
577
+
578
+ labels = labels.clone()
579
+ pad_token_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
580
+ labels[labels == pad_token_id] = -100
581
+
582
+ loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
583
+ loss = loss_fct(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))
584
+
585
+ else:
586
+ loss = torch.tensor(0.0, requires_grad=True, device=self.device)
587
+
588
+ if step < 5 and torch.cuda.is_available():
589
+ memory_after = torch.cuda.memory_allocated() / 1024**3
590
+ memory_peak = torch.cuda.max_memory_allocated() / 1024**3
591
+ logger.info(f"Step {step} Memory: {memory_before:.2f}GB β†’ {memory_after:.2f}GB (Peak: {memory_peak:.2f}GB)")
592
+
593
+ if memory_peak > 20.0: # 20GB threshold for L4
594
+ logger.warning("High memory usage detected! Consider reducing batch_size")
595
+
596
+ class MockOutputs:
597
+ def __init__(self, loss, logits):
598
+ self.loss = loss
599
+ self.logits = logits
600
+
601
+ outputs = MockOutputs(loss, response_outputs.logits if 'response_outputs' in locals() else grouped_outputs.logits)
602
+
603
+ return loss, outputs
604
+
605
+ def validation_step(self, batch):
606
+ """Single validation step."""
607
+ self.model.eval()
608
+
609
+ with torch.no_grad():
610
+ grouped_inputs = batch["grouped_inputs"].to(self.device)
611
+ grouped_attention_mask = batch["grouped_attention_mask"].to(self.device)
612
+ response_input_ids = batch["response_input_ids"].to(self.device)
613
+
614
+ grouped_outputs = self.model(
615
+ grouped_inputs=grouped_inputs,
616
+ attention_mask=grouped_attention_mask,
617
+ is_prefill=True,
618
+ use_cache=True,
619
+ return_dict=True
620
+ )
621
+
622
+ if response_input_ids.shape[1] > 1:
623
+ pad_token_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
+ response_attention_mask = (response_input_ids != pad_token_id).long()
624
+
625
+ response_outputs = self.model(
626
+ input_ids=response_input_ids[:, :-1],
627
+ attention_mask=response_attention_mask[:, :-1],
628
+ past_key_values=grouped_outputs.past_key_values,
629
+ use_cache=False,
630
+ return_dict=True
631
+ )
632
+
633
+ logits = response_outputs.logits
634
+ labels = response_input_ids[:, 1:]
635
+
636
+ labels = labels.clone()
637
+ pad_token_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
638
+ labels[labels == pad_token_id] = -100
639
+
640
+ loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
641
+ loss = loss_fct(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))
642
+ else:
643
+ loss = torch.tensor(0.0, device=self.device)
644
+
645
+ return loss.item()
646
+
647
+ def save_epoch_checkpoint(self, epoch: int, global_step: int, is_best: bool = False):
648
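+ """Save model weights, config, tokenizer, and epoch metadata under output_dir/epoch_<n>[_best]."""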
+ checkpoint_name = f"epoch_{epoch}"
649
+ if is_best:
650
+ checkpoint_name += "_best"
651
+
652
+ checkpoint_dir = self.output_dir / checkpoint_name
653
+ checkpoint_dir.mkdir(exist_ok=True)
654
+
655
+ torch.save(self.model.state_dict(), checkpoint_dir / "pytorch_model.bin")
656
+
657
+ self.model.config.save_pretrained(checkpoint_dir)
658
+
659
+ self.tokenizer.save_pretrained(checkpoint_dir)
660
+
661
+ metadata = {
662
+ "epoch": epoch,
663
+ "global_step": global_step,
664
+ "model_name": self.model_name,
665
+ "learning_rate": self.learning_rate,
666
+ "batch_size": self.batch_size,
667
+ "is_best": is_best,
668
+ "model_class": "CustomQwen3ForCausalLM"
669
+ }
670
+
671
+ with open(checkpoint_dir / "epoch_metadata.json", 'w') as f:
672
+ json.dump(metadata, f, indent=2)
673
+
674
+ logger.info(f"Saved epoch checkpoint: {checkpoint_dir}")
675
+ return checkpoint_dir
676
+
677
+ def train(self):
678
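+ """Main training loop: build data loaders, optimizer, and scheduler, optionally resume from a saved state, then train for num_epochs with periodic logging, evaluation, and checkpointing."""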
+ logger.info("Starting training...")
679
+
680
+ train_loader = DataLoader(
681
+ self.train_dataset,
682
+ batch_size=self.batch_size,
683
+ shuffle=True,
684
+ collate_fn=lambda batch: collate_fn(batch, self.tokenizer),
685
+ num_workers=0 # Avoid multiprocessing issues with custom collate_fn
686
+ )
687
+
688
+ val_loader = DataLoader(
689
+ self.val_dataset,
690
+ batch_size=self.batch_size,
691
+ shuffle=False,
692
+ collate_fn=lambda batch: collate_fn(batch, self.tokenizer),
693
+ num_workers=0
694
+ )
695
+
696
+ optimizer = AdamW(
697
+ [p for p in self.model.parameters() if p.requires_grad],
698
+ lr=self.learning_rate,
699
+ weight_decay=0.01
700
+ )
701
+
702
+ total_steps = len(train_loader) * self.num_epochs
703
+ scheduler = get_linear_schedule_with_warmup(
704
+ optimizer,
705
+ num_warmup_steps=self.warmup_steps,
706
+ num_training_steps=total_steps
707
+ )
708
+
709
+ start_epoch = 0
710
+ global_step = 0
711
+ best_val_loss = float('inf')
712
+
713
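+ # Optionally resume from the most recent saved training state
+ # (model weights, optimizer and scheduler states, step counters).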
+ if self.resume_training:
714
+ checkpoint_info = self.training_state.get_latest_checkpoint()
715
+ if checkpoint_info is not None:
716
+ checkpoint_path, state = checkpoint_info
717
+
718
+ self.load_checkpoint(checkpoint_path)
719
+
720
+ start_epoch = state["completed_epochs"]
721
+ global_step = state["global_step"]
722
+ best_val_loss = state["best_val_loss"]
723
+
724
+ if "optimizer_state" in state and state["optimizer_state"]:
725
+ try:
726
+ optimizer.load_state_dict(state["optimizer_state"])
727
+ except Exception as e:
728
+ logger.warning(f"Failed to load optimizer state: {e}")
729
+
730
+ if "scheduler_state" in state and state["scheduler_state"]:
731
+ try:
732
+ scheduler.load_state_dict(state["scheduler_state"])
733
+ except Exception as e:
734
+ logger.warning(f"Failed to load scheduler state: {e}")
735
+
736
+ logger.info(f"Resumed training from epoch {start_epoch + 1}")
737
+
738
+ for epoch in range(start_epoch, self.num_epochs):
739
+ logger.info(f"Epoch {epoch + 1}/{self.num_epochs}")
740
+
741
+ epoch_train_loss = 0
742
+ train_steps = 0
743
+
744
+ progress_bar = tqdm(train_loader, desc=f"Training Epoch {epoch + 1}")
745
+
746
+ for batch_idx, batch in enumerate(progress_bar):
747
+ try:
748
+ loss, outputs = self.training_step(batch, global_step)
749
+
750
+ loss.backward()
751
+
752
+ torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
753
+
754
+ optimizer.step()
755
+ scheduler.step()
756
+ optimizer.zero_grad()
757
+
758
+ epoch_train_loss += loss.item()
759
+ train_steps += 1
760
+ global_step += 1
761
+
762
+ progress_bar.set_postfix({
763
+ 'loss': f'{loss.item():.4f}',
764
+ 'lr': f'{scheduler.get_last_lr()[0]:.2e}'
765
+ })
766
+
767
+ if global_step % self.logging_steps == 0:
768
+ avg_loss = epoch_train_loss / train_steps
769
+ logger.info(f"Step {global_step}: train_loss={avg_loss:.4f}, lr={scheduler.get_last_lr()[0]:.2e}")
770
+
771
+ if global_step % self.eval_steps == 0:
772
+ val_loss = self.validate(val_loader)
773
+ logger.info(f"Step {global_step}: val_loss={val_loss:.4f}")
774
+
775
+ if val_loss < best_val_loss:
776
+ best_val_loss = val_loss
777
+ best_checkpoint = self.save_epoch_checkpoint(epoch, global_step, is_best=True)
778
+ logger.info(f"New best validation loss: {val_loss:.4f}")
779
+
780
+ except Exception as e:
781
+ logger.error(f"Error in training step {global_step}: {e}")
782
+ continue
783
+
784
+ val_loss = self.validate(val_loader)
785
+ avg_train_loss = epoch_train_loss / train_steps if train_steps > 0 else 0
786
+
787
+ logger.info(f"Epoch {epoch + 1} completed:")
788
+ logger.info(f" Average train loss: {avg_train_loss:.4f}")
789
+ logger.info(f" Validation loss: {val_loss:.4f}")
790
+
791
+ is_best = val_loss < best_val_loss
792
+ if is_best:
793
+ best_val_loss = val_loss
794
+
795
+ checkpoint_dir = self.save_epoch_checkpoint(epoch, global_step, is_best=is_best)
796
+
797
+ self.training_state.save_state(
798
+ epoch=epoch,
799
+ global_step=global_step,
800
+ best_val_loss=best_val_loss,
801
+ optimizer_state=optimizer.state_dict(),
802
+ scheduler_state=scheduler.state_dict()
803
+ )
804
+
805
+ logger.info(f"Epoch {epoch + 1} checkpoint and state saved")
806
+
807
+ logger.info(f"Training completed! Best validation loss: {best_val_loss:.4f}")
808
+
809
+ final_checkpoint = self.save_epoch_checkpoint(self.num_epochs - 1, global_step, is_best=False)
810
+ logger.info(f"Final checkpoint saved: {final_checkpoint}")
811
+
812
+ def validate(self, val_loader):
813
+ self.model.eval()
814
+ total_loss = 0
815
+ num_batches = 0
816
+
817
+ with torch.no_grad():
818
+ for batch in tqdm(val_loader, desc="Validation"):
819
+ try:
820
+ loss = self.validation_step(batch)
821
+ total_loss += loss
822
+ num_batches += 1
823
+ except Exception as e:
824
+ logger.warning(f"Error in validation step: {e}")
825
+ continue
826
+
827
+ avg_loss = total_loss / num_batches if num_batches > 0 else float('inf')
828
+ self.model.train() # Set back to training mode
829
+ return avg_loss
830
+
831
+ def run(self):
832
+ try:
833
+ self.load_model_and_tokenizer()
834
+
835
+ self.load_dataset(
836
+ chunk_size=self.chunk_size,
837
+ max_samples=self.max_samples
838
+ )
839
+
840
+ self.train()
841
+
842
+ logger.info("Training pipeline completed successfully!")
843
+
844
+ self.cleanup_datasets()
845
+
846
+ except Exception as e:
847
+ logger.error(f"Training pipeline failed: {e}")
848
+ import traceback
849
+ logger.error(traceback.format_exc())
850
+
851
+ try:
852
+ self.cleanup_datasets()
853
+ except Exception:
854
+ pass
855
+
856
+ raise
857
+
858
+ def load_trained_model(checkpoint_path: str, model_name: str = "Qwen/Qwen3-0.6B"):
859
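+ """Rebuild CustomQwen3ForCausalLM from a checkpoint directory (config, tokenizer, weights); model_name is accepted for compatibility but not used here."""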
+ logger.info(f"Loading trained model from {checkpoint_path}")
860
+
861
+ tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
862
+
863
+ config = Qwen3Config.from_pretrained(checkpoint_path)
864
+
865
+ model = CustomQwen3ForCausalLM(config)
866
+
867
+ model_path = Path(checkpoint_path) / "pytorch_model.bin"
868
+ if not model_path.exists():
869
+ model_path = Path(checkpoint_path) / "model.safetensors"
870
+
871
+ if not model_path.exists():
872
+ raise FileNotFoundError(f"No model weights found in {checkpoint_path}")
873
+
874
+ if model_path.suffix == ".safetensors":
+     # Safetensors checkpoints cannot be read with torch.load
+     from safetensors.torch import load_file
+     state_dict = load_file(str(model_path))
+ else:
+     state_dict = torch.load(model_path, map_location="cpu")
875
+ missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
876
+
877
+ if missing_keys:
878
+ logger.warning(f"Missing keys when loading model: {missing_keys}")
879
+ if unexpected_keys:
880
+ logger.warning(f"Unexpected keys when loading model: {unexpected_keys}")
881
+
882
+ model = model.eval().to(torch.float32)
883
+
884
+ return model, tokenizer
885
+
886
+
887
+ def generate_with_grouped_input(
888
+ model,
889
+ tokenizer,
890
+ grouped_input: torch.Tensor,
891
+ max_length: int = 512,
892
+ temperature: float = 0.7,
893
+ do_sample: bool = True
894
+ ):
895
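+ """Autoregressive generation seeded by a prefill pass over precomputed grouped embeddings; supports temperature sampling or greedy decoding."""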
+ device = model.device
896
+ model_dtype = next(model.parameters()).dtype
897
+
898
+ grouped_input = grouped_input.to(device=device, dtype=model_dtype)
899
+
900
+ if grouped_input.ndim == 2:
901
+ grouped_input = grouped_input.unsqueeze(0) # Add batch dimension
902
+
903
+ logger.debug(f"Grouped input shape: {grouped_input.shape}, dtype: {grouped_input.dtype}")
904
+ logger.debug(f"Model dtype: {model_dtype}, device: {device}")
905
+
906
+ with torch.no_grad():
907
+ try:
908
+ outputs = model(
909
+ grouped_inputs=grouped_input,
910
+ is_prefill=True,
911
+ use_cache=True,
912
+ return_dict=True
913
+ )
914
+ except Exception as e:
915
+ logger.error(f"Error in prefill phase: {e}")
916
+ raise
917
+
918
+ if hasattr(outputs, 'logits') and outputs.logits is not None:
919
+ next_token_logits = outputs.logits[:, -1, :]
920
+ elif hasattr(outputs, 'hidden_states') and outputs.hidden_states is not None:
921
+ last_hidden_state = outputs.hidden_states[-1] if isinstance(outputs.hidden_states, (list, tuple)) else outputs.hidden_states
922
+ next_token_logits = model.lm_head(last_hidden_state[:, -1, :])
923
+ else:
924
+ raise RuntimeError("Could not extract logits from model output")
925
+
926
+ if do_sample:
927
+ next_token_logits = next_token_logits / temperature
928
+ probs = F.softmax(next_token_logits, dim=-1)
929
+ next_token = torch.multinomial(probs, num_samples=1)
930
+ else:
931
+ next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
932
+
933
+ generated_ids = next_token
934
+ past_key_values = getattr(outputs, 'past_key_values', None)
935
+
936
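+ # Autoregressive decoding: feed the last sampled token back in with the cached
+ # keys/values until max_length is reached or an EOS token is produced.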
+ for step in range(max_length - 1):
937
+ with torch.no_grad():
938
+ try:
939
+ outputs = model(
940
+ input_ids=next_token,
941
+ past_key_values=past_key_values,
942
+ use_cache=True,
943
+ return_dict=True
944
+ )
945
+ except Exception as e:
946
+ logger.error(f"Error in generation step {step}: {e}")
947
+ break
948
+
949
+ if hasattr(outputs, 'logits'):
950
+ next_token_logits = outputs.logits[:, -1, :]
951
+ else:
952
+ logger.warning("No logits in generation output, stopping generation")
953
+ break
954
+
955
+ if do_sample:
956
+ next_token_logits = next_token_logits / temperature
957
+ probs = F.softmax(next_token_logits, dim=-1)
958
+ next_token = torch.multinomial(probs, num_samples=1)
959
+ else:
960
+ next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
961
+
962
+ generated_ids = torch.cat([generated_ids, next_token], dim=1)
963
+ past_key_values = getattr(outputs, 'past_key_values', None)
964
+
965
+ if next_token.item() == tokenizer.eos_token_id:
966
+ break
967
+
968
+ generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
969
+ return generated_text
970
+
971
+ def main():
972
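+ """Entry point: run the full training pipeline with the streaming configuration below."""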
+ config = {
973
+ "model_name": "Qwen/Qwen3-0.6B",
974
+ "dataset_path": "./processed_qwen3_dataset/processed_dataset.pkl",
975
+ "output_dir": "./grouped_qwen3_checkpoint",
976
+ "batch_size": 12, # Optimized for L4 24GB VRAM
977
+ "learning_rate": 5e-4,
978
+ "num_epochs": 3,
979
+ "warmup_steps": 500, # Increased for larger batch
980
+ "max_grad_norm": 1.0,
981
+ "save_steps": 1000, # Less frequent saves due to larger batches
982
+ "eval_steps": 1000, # Less frequent evaluation
983
+ "logging_steps": 100,
984
+ "resume_training": True,
985
+ "debug": False, # Disable debug for performance
986
+ # Streaming parameters
987
+ "chunk_size": 2000, # Load 2000 samples per chunk
988
+ "max_samples": None, # Use full dataset (set to smaller number for testing)
989
+ }
990
+
991
+ logger.info("="*60)
992
+ logger.info("GROUPED QWEN3 TRAINING CONFIGURATION (STREAMING)")
993
+ logger.info("="*60)
994
+ for key, value in config.items():
995
+ logger.info(f"{key}: {value}")
996
+ logger.info("="*60)
997
+
998
+ if torch.cuda.is_available():
999
+ logger.info(f"GPU: {torch.cuda.get_device_name()}")
1000
+ logger.info(f"VRAM Total: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
1001
+
1002
+ import psutil
1003
+ ram_usage = psutil.virtual_memory()
1004
+ logger.info(f"System RAM: {ram_usage.used / 1024**3:.1f} GB / {ram_usage.total / 1024**3:.1f} GB ({ram_usage.percent:.1f}%)")
1005
+
1006
+ trainer = GroupedTrainer(**config)
1007
+ trainer.run()
1008
+
1009
+
1010
+ def inference_by_id(sample_id: int, checkpoint_path: str = "./grouped_qwen3_checkpoint/epoch_2_best",
1011
+ dataset_path: str = "./processed_qwen3_dataset/processed_dataset.pkl",
1012
+ max_length: int = 512, temperature: float = 0.7, do_sample: bool = True):
1013
+ """Run inference on a specific sample ID from the dataset."""
1014
+ logger.info(f"Running inference on sample ID: {sample_id}")
1015
+
1016
+ try:
1017
+ model, tokenizer = load_trained_model(checkpoint_path)
1018
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
1019
+ model = model.to(device)
1020
+ logger.info(f"Model loaded from {checkpoint_path}")
1021
+ except Exception as e:
1022
+ logger.error(f"Failed to load model: {e}")
1023
+ return None
1024
+
1025
+ try:
1026
+ logger.info(f"Loading sample {sample_id} from dataset...")
1027
+ with open(dataset_path, 'rb') as f:
1028
+ processed_data = pickle.load(f)
1029
+
1030
+ if sample_id >= len(processed_data):
1031
+ logger.error(f"Sample ID {sample_id} is out of range. Dataset has {len(processed_data)} samples.")
1032
+ return None
1033
+
1034
+ sample = processed_data[sample_id]
1035
+
1036
+ if sample.get("error", False):
1037
+ logger.error(f"Sample {sample_id} has errors and cannot be used for inference.")
1038
+ return None
1039
+
1040
+ except Exception as e:
1041
+ logger.error(f"Failed to load dataset: {e}")
1042
+ return None
1043
+
1044
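+ # The processed dataset stores precomputed grouped embeddings under "inputs_embeds";
+ # convert them to a float32 tensor for the prefill pass.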
+ grouped_embeds_raw = sample["inputs_embeds"]
1045
+ if isinstance(grouped_embeds_raw, torch.Tensor):
1046
+ grouped_input = grouped_embeds_raw.detach().clone().float()
1047
+ else:
1048
+ grouped_input = torch.tensor(grouped_embeds_raw, dtype=torch.float32)
1049
+
1050
+ original_input = sample["input_text"]
1051
+ expected_response = sample["response"]
1052
+
1053
+ print("\n" + "="*80)
1054
+ print(f"INFERENCE ON SAMPLE ID: {sample_id}")
1055
+ print("="*80)
1056
+ print(f"πŸ“ ORIGINAL REQUEST:")
1057
+ print(f"{original_input}")
1058
+ print("\n" + "-"*80)
1059
+ print(f"🎯 EXPECTED RESPONSE:")
1060
+ print(f"{expected_response}")
1061
+ print("\n" + "-"*80)
1062
+ print(f"πŸ€– MODEL GENERATED RESPONSE:")
1063
+
1064
+ try:
1065
+ generated_text = generate_with_grouped_input(
1066
+ model=model,
1067
+ tokenizer=tokenizer,
1068
+ grouped_input=grouped_input,
1069
+ max_length=max_length,
1070
+ temperature=temperature,
1071
+ do_sample=do_sample
1072
+ )
1073
+
1074
+ print(f"{generated_text}")
1075
+ print("\n" + "="*80)
1076
+
1077
+ expected_words = expected_response.split()
1078
+ generated_words = generated_text.split()
1079
+
1080
+ print(f"πŸ“Š METRICS:")
1081
+ print(f"Expected length: {len(expected_words)} words")
1082
+ print(f"Generated length: {len(generated_words)} words")
1083
+ print(f"Temperature: {temperature}")
1084
+ print(f"Max length: {max_length}")
1085
+ print("="*80)
1086
+
1087
+ return {
1088
+ "sample_id": sample_id,
1089
+ "original_input": original_input,
1090
+ "expected_response": expected_response,
1091
+ "generated_response": generated_text,
1092
+ "expected_length": len(expected_words),
1093
+ "generated_length": len(generated_words)
1094
+ }
1095
+
1096
+ except Exception as e:
1097
+ logger.error(f"Failed to generate response: {e}")
1098
+ print(f"❌ GENERATION FAILED: {e}")
1099
+ print("="*80)
1100
+ return None
1101
+
1102
+
1103
+ def batch_inference(sample_ids: List[int], checkpoint_path: str = "./grouped_qwen3_checkpoint/epoch_2_best",
1104
+ dataset_path: str = "./processed_qwen3_dataset/processed_dataset.pkl",
1105
+ max_length: int = 512, temperature: float = 0.7, do_sample: bool = True):
1106
+ """Run inference on multiple sample IDs."""
1107
+ logger.info(f"Running batch inference on {len(sample_ids)} samples")
1108
+
1109
+ results = []
1110
+ for sample_id in sample_ids:
1111
+ result = inference_by_id(
1112
+ sample_id=sample_id,
1113
+ checkpoint_path=checkpoint_path,
1114
+ dataset_path=dataset_path,
1115
+ max_length=max_length,
1116
+ temperature=temperature,
1117
+ do_sample=do_sample
1118
+ )
1119
+ if result:
1120
+ results.append(result)
1121
+
1122
+ print("\n" + "πŸ”„ " + "-"*78 + " πŸ”„\n") # Separator between samples
1123
+
1124
+ print("\n" + "="*80)
1125
+ print(f"πŸ“‹ BATCH INFERENCE SUMMARY")
1126
+ print("="*80)
1127
+ print(f"Total samples processed: {len(results)}")
1128
+ if results:
1129
+ avg_expected_len = sum(r["expected_length"] for r in results) / len(results)
1130
+ avg_generated_len = sum(r["generated_length"] for r in results) / len(results)
1131
+ print(f"Average expected length: {avg_expected_len:.1f} words")
1132
+ print(f"Average generated length: {avg_generated_len:.1f} words")
1133
+ print("="*80)
1134
+
1135
+ return results
1136
+
1137
+
1138
+ def random_inference(num_samples: int = 3, checkpoint_path: str = "./grouped_qwen3_checkpoint/epoch_2_best",
1139
+ dataset_path: str = "./processed_qwen3_dataset/processed_dataset.pkl",
1140
+ max_length: int = 512, temperature: float = 0.7, do_sample: bool = True):
1141
+ """Run inference on random samples from the dataset."""
1142
+ import random
1143
+
1144
+ try:
1145
+ with open(dataset_path, 'rb') as f:
1146
+ processed_data = pickle.load(f)
1147
+
1148
+ # Find valid samples
1149
+ valid_indices = [i for i, item in enumerate(processed_data) if not item.get("error", False)]
1150
+
1151
+ if len(valid_indices) < num_samples:
1152
+ logger.warning(f"Only {len(valid_indices)} valid samples available, using all of them")
1153
+ num_samples = len(valid_indices)
1154
+
1155
+ # Select random samples
1156
+ random_ids = random.sample(valid_indices, num_samples)
1157
+
1158
+ logger.info(f"Selected random sample IDs: {random_ids}")
1159
+
1160
+ # Run batch inference
1161
+ return batch_inference(
1162
+ sample_ids=random_ids,
1163
+ checkpoint_path=checkpoint_path,
1164
+ dataset_path=dataset_path,
1165
+ max_length=max_length,
1166
+ temperature=temperature,
1167
+ do_sample=do_sample
1168
+ )
1169
+
1170
+ except Exception as e:
1171
+ logger.error(f"Failed to load dataset for random sampling: {e}")
1172
+ return None
1173
+
1174
+
1175
+ def interactive_inference(checkpoint_path: str = "./grouped_qwen3_checkpoint/epoch_2_best",
1176
+ dataset_path: str = "./processed_qwen3_dataset/processed_dataset.pkl"):
1177
+ """Interactive inference mode where user can input sample IDs."""
1178
+ print("\n" + "="*80)
1179
+ print("πŸ€– INTERACTIVE INFERENCE MODE")
1180
+ print("="*80)
1181
+ print("Commands:")
1182
+ print(" <number> - Run inference on sample ID")
1183
+ print(" random <n> - Run inference on n random samples (default: 3)")
1184
+ print(" batch <ids> - Run inference on multiple IDs (e.g., 'batch 1,5,10')")
1185
+ print(" quit - Exit")
1186
+ print("="*80)
1187
+
1188
+ while True:
1189
+ try:
1190
+ user_input = input("\n🔍 Enter command: ").strip().lower()
1191
+
1192
+ if user_input in ['quit', 'exit', 'q']:
1193
+ print("πŸ‘‹ Goodbye!")
1194
+ break
1195
+ elif user_input.startswith('random'):
1196
+ parts = user_input.split()
1197
+ num_samples = int(parts[1]) if len(parts) > 1 else 3
1198
+ random_inference(num_samples=num_samples, checkpoint_path=checkpoint_path, dataset_path=dataset_path)
1199
+ elif user_input.startswith('batch'):
1200
+ parts = user_input.split(maxsplit=1)
1201
+ if len(parts) > 1:
1202
+ ids_str = parts[1]
1203
+ sample_ids = [int(x.strip()) for x in ids_str.split(',')]
1204
+ batch_inference(sample_ids=sample_ids, checkpoint_path=checkpoint_path, dataset_path=dataset_path)
1205
+ else:
1206
+ print("❌ Please provide sample IDs: batch 1,5,10")
1207
+ elif user_input.isdigit():
1208
+ sample_id = int(user_input)
1209
+ inference_by_id(sample_id=sample_id, checkpoint_path=checkpoint_path, dataset_path=dataset_path)
1210
+ else:
1211
+ print("❌ Invalid command. Try a number, 'random', 'batch', or 'quit'")
1212
+
1213
+ except ValueError:
1214
+ print("❌ Invalid input. Please enter a valid number or command.")
1215
+ except KeyboardInterrupt:
1216
+ print("\nπŸ‘‹ Goodbye!")
1217
+ break
1218
+ except Exception as e:
1219
+ print(f"❌ Error: {e}")
1220
+
1221
+
1222
+ def test_inference():
1223
+ logger.info("Running inference tests...")
1224
+
1225
+ test_ids = [0, 1, 2, 100, 500] # Mix of early and later samples
1226
+
1227
+ print("\nπŸ§ͺ TESTING INFERENCE ON PREDEFINED SAMPLES")
1228
+ results = batch_inference(
1229
+ sample_ids=test_ids,
1230
+ max_length=300,
1231
+ temperature=0.7,
1232
+ do_sample=True
1233
+ )
1234
+
1235
+ return results
1236
+
1237
+
1238
+ if __name__ == "__main__":
1239
+ import argparse
1240
+
1241
+ parser = argparse.ArgumentParser(description="Grouped Qwen3 Training and Inference")
1242
+ parser.add_argument("--mode", choices=["train", "test", "inference", "interactive", "random"],
1243
+ default="train", help="Mode to run")
1244
+ parser.add_argument("--sample_id", type=int, help="Sample ID for inference mode")
1245
+ parser.add_argument("--sample_ids", type=str, help="Comma-separated sample IDs for batch inference")
1246
+ parser.add_argument("--num_samples", type=int, default=3, help="Number of random samples for random mode")
1247
+ parser.add_argument("--checkpoint", type=str, default="./grouped_qwen3_checkpoint/epoch_2_best",
1248
+ help="Path to model checkpoint")
1249
+ parser.add_argument("--dataset", type=str, default="./processed_qwen3_dataset/processed_dataset.pkl",
1250
+ help="Path to dataset")
1251
+ parser.add_argument("--max_length", type=int, default=512, help="Maximum generation length")
1252
+ parser.add_argument("--temperature", type=float, default=0.7, help="Generation temperature")
1253
+
1254
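+ # Example invocations (script name illustrative):
+ #   python grouped_qwen3.py --mode train
+ #   python grouped_qwen3.py --mode inference --sample_id 42
+ #   python grouped_qwen3.py --mode random --num_samples 5 --temperature 0.9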
+ args = parser.parse_args()
1255
+
1256
+ if args.mode == "train":
1257
+ main()
1258
+ elif args.mode == "test":
1259
+ test_inference()
1260
+ elif args.mode == "inference":
1261
+ if args.sample_id is not None:
1262
+ inference_by_id(
1263
+ sample_id=args.sample_id,
1264
+ checkpoint_path=args.checkpoint,
1265
+ dataset_path=args.dataset,
1266
+ max_length=args.max_length,
1267
+ temperature=args.temperature
1268
+ )
1269
+ elif args.sample_ids is not None:
1270
+ sample_ids = [int(x.strip()) for x in args.sample_ids.split(',')]
1271
+ batch_inference(
1272
+ sample_ids=sample_ids,
1273
+ checkpoint_path=args.checkpoint,
1274
+ dataset_path=args.dataset,
1275
+ max_length=args.max_length,
1276
+ temperature=args.temperature
1277
+ )
1278
+ else:
1279
+ print("❌ Please provide --sample_id or --sample_ids for inference mode")
1280
+ elif args.mode == "interactive":
1281
+ interactive_inference(
1282
+ checkpoint_path=args.checkpoint,
1283
+ dataset_path=args.dataset
1284
+ )
1285
+ elif args.mode == "random":
1286
+ random_inference(
1287
+ num_samples=args.num_samples,
1288
+ checkpoint_path=args.checkpoint,
1289
+ dataset_path=args.dataset,
1290
+ max_length=args.max_length,
1291
+ temperature=args.temperature
1292
+ )