import os
import sys
import logging
import json
import pickle
from typing import Optional, Tuple, List, Dict, Any
from pathlib import Path

from tqdm import tqdm
import torch
import torch.nn as nn
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.models.qwen3.modeling_qwen3 import (
    Qwen3DecoderLayer,
    Qwen3Attention,
    Qwen3RotaryEmbedding,
)

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    handlers=[logging.StreamHandler(sys.stdout)],
    force=True,
)
logger = logging.getLogger("qwen3_dataset_processor")


class GroupedCache:
    """Cache for grouping metadata."""

    def __init__(self):
        self.grouped_positions = None
        self.position_mapping = None
        self.group_info = None
        self.original_seq_length = None


class CustomQwen3Attention(Qwen3Attention):
    def __init__(self, config, layer_idx: int):
        super().__init__(config, layer_idx)
        self.layer_idx = layer_idx
        self.tokenizer = None
        self.current_input_ids = None
        self.threshold = 0.1
        self.grouped_cache = GroupedCache()
        if not hasattr(self, 'num_key_value_heads'):
            self.num_key_value_heads = (
                config.num_key_value_heads
                if hasattr(config, 'num_key_value_heads')
                else config.num_attention_heads
            )
        if not hasattr(self, 'head_dim'):
            self.head_dim = config.hidden_size // config.num_attention_heads

    def set_tokenizer(self, tokenizer):
        self.tokenizer = tokenizer

    def set_current_input_ids(self, input_ids):
        self.current_input_ids = input_ids

    def _is_special_token(self, token: str) -> bool:
        if self.tokenizer is None:
            return False
        special_tokens = set()
        if hasattr(self.tokenizer, 'special_tokens_map'):
            for token_type, token_value in self.tokenizer.special_tokens_map.items():
                if isinstance(token_value, str):
                    special_tokens.add(token_value)
                elif isinstance(token_value, list):
                    special_tokens.update(token_value)
        if hasattr(self.tokenizer, 'added_tokens_encoder'):
            special_tokens.update(self.tokenizer.added_tokens_encoder.keys())
        if token in special_tokens:
            return True
        special_patterns = [
            lambda t: t.startswith('<|') and t.endswith('|>'),
            lambda t: t.startswith('<') and t.endswith('>'),
            lambda t: t.startswith('[') and t.endswith(']'),
        ]
        return any(pattern(token) for pattern in special_patterns)

    def _get_token_relations(self, attention_weights: torch.Tensor, tokens: List[str]) -> List[Dict]:
        batch_size, num_heads, query_len, key_len = attention_weights.shape
        attn = attention_weights[0].mean(dim=0)
        relations = []
        if query_len == 1:
            current_token_pos = len(tokens) - 1
            token_relations = []
            for j in range(len(tokens)):
                if j != current_token_pos:
                    weight = attn[0, j].item()
                    if weight > self.threshold:
                        token_relations.append({
                            'target_pos': j,
                            'weight': round(weight, 3)
                        })
            relations.append({
                'source_pos': current_token_pos,
                'relations': token_relations
            })
        else:
            for i in range(min(query_len, len(tokens))):
                token_relations = []
                for j in range(len(tokens)):
                    if i != j and j < key_len:
                        weight = attn[i, j].item()
                        if weight > self.threshold:
                            token_relations.append({
                                'target_pos': j,
                                'weight': round(weight, 3)
                            })
                relations.append({
                    'source_pos': i,
                    'relations': token_relations
                })
        return relations

    def _get_token_groups(self, attention_weights: torch.Tensor) -> List[List[int]]:
        if self.tokenizer is None or self.current_input_ids is None:
            return []
        if len(attention_weights.shape) != 4:
            return []
        batch_size, num_heads, query_len, key_len = attention_weights.shape
        input_ids = self.current_input_ids
        if input_ids is None or input_ids.shape[1] < key_len:
            return []
        tokens = [self.tokenizer.decode([token_id]) for token_id in input_ids[0][:key_len]]
        relations = self._get_token_relations(attention_weights, tokens)

        # Walk the tokens left to right and start a new group at word, line,
        # whitespace, or special-token boundaries, or when a token has no
        # attention relation above the threshold.
        groups = []
        current_group = []
        current_group_indices = []
        for i, token in enumerate(tokens):
            is_empty_relations = i < len(relations) and len(relations[i]['relations']) == 0
            starts_with_space = token.startswith(' ') and token != ' '
            is_space = token == ' '
            is_new_line = '\n' in token
            prev_token_is_special = False
            prev_token_is_new_line = False
            prev_token_is_space = False
            if i > 0:
                prev_token = tokens[i - 1]
                prev_token_is_special = self._is_special_token(prev_token)
                prev_token_is_new_line = '\n' in prev_token
                prev_token_is_space = prev_token == ' '
            prev_newline_current_not = prev_token_is_new_line and not is_new_line
            prev_space_current_not = prev_token_is_space and not is_space
            current_space_prev_not = is_space and not prev_token_is_space
            if (is_empty_relations or starts_with_space or is_new_line
                    or prev_token_is_special or prev_newline_current_not
                    or prev_space_current_not or current_space_prev_not) and current_group:
                groups.append(current_group_indices)
                current_group = []
                current_group_indices = []
            current_group.append(token)
            current_group_indices.append(i)
        if current_group:
            groups.append(current_group_indices)
        return groups


class CustomQwen3DecoderLayer(Qwen3DecoderLayer):
    def __init__(self, config, layer_idx: int):
        super().__init__(config, layer_idx)
        self.layer_idx = layer_idx
        self.rotary_emb = Qwen3RotaryEmbedding(config=config)
        self.self_attn = CustomQwen3Attention(config, layer_idx)
        self.is_initialized = False
        self.grouped_hidden_states = None
        self.grouped_cache = GroupedCache()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[tuple] = None,
        **kwargs,
    ):
        if self.layer_idx != 0:
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **kwargs,
            )

        is_prefill = hidden_states.shape[1] > 1 and not self.is_initialized
        if not is_prefill:
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **kwargs,
            )

        residual = hidden_states
        x = self.input_layernorm(hidden_states)
        attn_ret = self.self_attn(
            hidden_states=x,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=None,
            output_attentions=True,
            use_cache=False,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
        )
        if isinstance(attn_ret, tuple):
            if len(attn_ret) == 3:
                attn_out, attn_weights, _ = attn_ret
            elif len(attn_ret) == 2:
                attn_out, attn_weights = attn_ret
            else:
                raise RuntimeError(f"Unexpected attention return length: {len(attn_ret)}")
        else:
            raise RuntimeError("Attention did not return weights.")
        groups = self.self_attn._get_token_groups(attn_weights)
        if not groups:
            self.is_initialized = True
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **kwargs,
            )

        # Collapse each group into one position: average the attention outputs
        # of multi-token groups, keep single-token groups as-is.
        averaged_vectors = []
        group_info = []
        position_mapping = {}
        for gi, idxs in enumerate(groups):
            if len(idxs) == 1:
                averaged_vectors.append(attn_out[:, idxs[0], :])
                group_info.append({"type": "single", "positions": idxs, "new_position": gi})
            else:
                gvecs = attn_out[:, idxs, :]
                ave = gvecs.mean(dim=1)
                averaged_vectors.append(ave)
                group_info.append({"type": "averaged", "positions": idxs, "new_position": gi})
            for p in idxs:
                position_mapping[p] = gi

        # Rebuild the residual at the grouped granularity (summed over each
        # multi-token group) and re-apply the residual connection.
        new_attn_out = torch.stack(averaged_vectors, dim=1)
        expanded_residual = torch.stack([
            (
                residual[:, info['positions'], :].sum(dim=1)
                if len(info['positions']) > 1
                else residual[:, info['positions'][0], :]
            )
            for info in group_info
        ], dim=1)
        hs = expanded_residual + new_attn_out
        grouped_hidden = self.post_attention_layernorm(hs)

        # Stash grouping metadata and the normalized grouped states for the caller.
        self.grouped_cache.grouped_positions = len(groups)
        self.grouped_cache.position_mapping = position_mapping
        self.grouped_cache.group_info = group_info
        self.grouped_cache.original_seq_length = hidden_states.shape[1]
        self.grouped_hidden_states = grouped_hidden
        self.is_initialized = True
        return hs


def create_model_with_custom_layer0(model_name: str = "Qwen/Qwen3-0.6B"):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if torch.cuda.is_available():
        device = torch.device("cuda")
        dtype = torch.float16
    else:
        device = torch.device("cpu")
        dtype = torch.float32
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=dtype,
        attn_implementation="eager"
    ).to(device)
    orig0 = model.model.layers[0]
    custom0 = CustomQwen3DecoderLayer(model.config, 0)
    custom0.mlp.load_state_dict(orig0.mlp.state_dict())
    custom0.input_layernorm.load_state_dict(orig0.input_layernorm.state_dict())
    custom0.post_attention_layernorm.load_state_dict(orig0.post_attention_layernorm.state_dict())
    custom0.self_attn.load_state_dict(orig0.self_attn.state_dict())
    custom0.self_attn.set_tokenizer(tokenizer)
    custom0 = custom0.to(device=device, dtype=dtype)
    model.model.layers[0] = custom0
    return model, tokenizer
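
# Illustrative single-prompt usage (a minimal sketch, commented out so it never
# runs on import). It mirrors the flow in DatasetProcessor.process_embeddings_batch
# below and assumes the model/tokenizer layout produced by
# create_model_with_custom_layer0; the real pipeline guards the rotary_emb access
# with hasattr.
#
# model, tokenizer = create_model_with_custom_layer0("Qwen/Qwen3-0.6B")
# layer0 = model.model.layers[0]
# layer0.is_initialized = False
# batch = tokenizer("Hello, how are you?", return_tensors="pt").to(model.device)
# layer0.self_attn.set_current_input_ids(batch["input_ids"])
# with torch.no_grad():
#     embeds = model.model.embed_tokens(batch["input_ids"])
#     pos_ids = torch.arange(embeds.shape[1], device=model.device).unsqueeze(0)
#     pos_embeds = model.model.rotary_emb(embeds, pos_ids)
#     layer0(hidden_states=embeds, position_ids=pos_ids, position_embeddings=pos_embeds)
# print(layer0.grouped_hidden_states.shape)  # (1, num_groups, hidden_size)
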
class DatasetProcessor:
    def __init__(self,
                 model_name: str = "Qwen/Qwen3-0.6B",
                 dataset_name: str = "Magpie-Align/Magpie-Qwen2.5-Pro-1M-v0.1",
                 output_dir: str = "./processed_dataset",
                 batch_size: int = 1,
                 max_samples: Optional[int] = None,
                 save_frequency: int = 1000):
        self.model_name = model_name
        self.dataset_name = dataset_name
        self.output_dir = Path(output_dir)
        self.batch_size = batch_size
        self.max_samples = max_samples
        self.save_frequency = save_frequency
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # System prompt template for Qwen3
        self.system_prompt = (
            "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. "
            "You are a helpful assistant.<|im_end|>\n<|im_start|>user\n"
        )
        self.response_start = "<|im_end|>\n<|im_start|>assistant\n"

        self.current_chunk = 0
        self.processed_data_buffer = []

    def load_dataset(self) -> Dataset:
        logger.info(f"Loading dataset: {self.dataset_name}")
        dataset = load_dataset(self.dataset_name, split="train")
        if self.max_samples:
            dataset = dataset.select(range(min(self.max_samples, len(dataset))))
        logger.info(f"Dataset loaded: {len(dataset)} samples")
        return dataset

    def format_input_text(self, instruction: str) -> str:
        return f"{self.system_prompt}{instruction}{self.response_start}"

    def process_embeddings_batch(self, model, tokenizer, texts: List[str]) -> List[torch.Tensor]:
        device = model.device
        embeddings_batch = []
        for text in texts:
            try:
                if hasattr(model.model.layers[0], "is_initialized"):
                    model.model.layers[0].is_initialized = False
                batch = tokenizer(text, return_tensors="pt", truncation=True, max_length=2048).to(device)
                input_ids = batch["input_ids"]
                if hasattr(model.model.layers[0], "self_attn"):
                    sat = model.model.layers[0].self_attn
                    if hasattr(sat, "set_current_input_ids"):
                        sat.set_current_input_ids(input_ids)
                with torch.no_grad():
                    inputs_embeds = model.model.embed_tokens(input_ids)
                    seq_len = inputs_embeds.shape[1]
                    position_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0)
                    if hasattr(model.model, 'rotary_emb'):
                        pos_embeds = model.model.rotary_emb(inputs_embeds, position_ids)
                    else:
                        pos_embeds = None
                    # Run only layer 0; the custom layer stores the grouped
                    # hidden states on itself as a side effect.
                    _ = model.model.layers[0](
                        hidden_states=inputs_embeds,
                        attention_mask=None,
                        position_ids=position_ids,
                        past_key_value=None,
                        output_attentions=False,
                        use_cache=False,
                        cache_position=None,
                        position_embeddings=pos_embeds,
                    )
                if (hasattr(model.model.layers[0], "grouped_hidden_states")
                        and model.model.layers[0].grouped_hidden_states is not None):
                    grouped_embeds = model.model.layers[0].grouped_hidden_states.clone().cpu()
                    embeddings_batch.append(grouped_embeds.squeeze(0))
                    model.model.layers[0].grouped_hidden_states = None
                else:
                    embeddings_batch.append(inputs_embeds.squeeze(0).cpu())
                del inputs_embeds, position_ids
                if pos_embeds is not None:
                    del pos_embeds
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
            except Exception as e:
                logger.warning(f"Error processing sample: {e}")
                embeddings_batch.append(torch.zeros(1, model.config.hidden_size))
        return embeddings_batch

    def save_chunk(self, chunk_data: List[Dict[str, Any]], chunk_id: int):
        if not chunk_data:
            return
        chunk_path = self.output_dir / f"processed_chunk_{chunk_id:04d}.pkl"
        with open(chunk_path, 'wb') as f:
            pickle.dump(chunk_data, f)
        # Clear memory
        del chunk_data
        import gc
        gc.collect()

    def merge_chunks(self) -> List[Dict[str, Any]]:
        logger.info("Merging chunks...")
        chunk_files = sorted(list(self.output_dir.glob("processed_chunk_*.pkl")))
        if not chunk_files:
            return []
        merged_data = []
        for chunk_file in tqdm(chunk_files, desc="Merging chunks"):
            try:
                with open(chunk_file, 'rb') as f:
                    chunk_data = pickle.load(f)
                if isinstance(chunk_data, list):
                    merged_data.extend(chunk_data)
            except Exception as e:
                logger.error(f"Error loading chunk {chunk_file}: {e}")
                continue
        # Clean up chunk files
        self.cleanup_chunks()
        logger.info(f"Merged {len(chunk_files)} chunks into {len(merged_data)} samples")
        return merged_data

    def cleanup_chunks(self):
        chunk_files = list(self.output_dir.glob("processed_chunk_*.pkl"))
        for chunk_file in chunk_files:
            try:
                chunk_file.unlink()
            except Exception as e:
                logger.warning(f"Could not delete chunk {chunk_file}: {e}")
        if chunk_files:
            logger.info(f"Cleaned up {len(chunk_files)} temporary chunk files")
    def save_final_dataset(self, processed_data: List[Dict[str, Any]], stats: Dict[str, int]):
        pickle_path = self.output_dir / "processed_dataset.pkl"
        with open(pickle_path, 'wb') as f:
            pickle.dump(processed_data, f)

        error_samples = sum(1 for sample in processed_data if sample.get("error", False))
        successful_samples = len(processed_data) - error_samples
        metadata = {
            "model_name": self.model_name,
            "dataset_name": self.dataset_name,
            "total_samples": stats["total_samples"],
            "processed_samples": len(processed_data),
            "successful_samples": successful_samples,
            "error_samples": error_samples,
            "batch_size": self.batch_size,
            "max_samples": self.max_samples,
            "success_rate": f"{(successful_samples / len(processed_data) * 100):.2f}%" if processed_data else "0%"
        }
        with open(self.output_dir / "metadata.json", 'w', encoding='utf-8') as f:
            json.dump(metadata, f, indent=2, ensure_ascii=False)

        text_samples = []
        count = 0
        for i, sample in enumerate(processed_data):
            if not sample.get("error", False) and count < 10:
                text_samples.append({
                    "sample_id": i,
                    "input_text": sample["input_text"][:300] + "..." if len(sample["input_text"]) > 300 else sample["input_text"],
                    "response": sample["response"][:300] + "..." if len(sample["response"]) > 300 else sample["response"],
                    "embedding_shape": sample["embedding_shape"]
                })
                count += 1
        with open(self.output_dir / "samples.json", 'w', encoding='utf-8') as f:
            json.dump(text_samples, f, indent=2, ensure_ascii=False)

        logger.info(f"Dataset saved: {len(processed_data)} samples")
        logger.info(f"Success rate: {metadata['success_rate']}")

    def process_dataset(self):
        dataset = self.load_dataset()
        logger.info("Loading model...")
        model, tokenizer = create_model_with_custom_layer0(self.model_name)

        total_samples = len(dataset)
        processed_count = 0
        error_count = 0
        logger.info(f"Processing {total_samples} samples...")

        for i in tqdm(range(0, total_samples, self.batch_size), desc="Processing"):
            batch_end = min(i + self.batch_size, total_samples)
            batch_samples = dataset.select(range(i, batch_end))
            batch_texts = []
            batch_instructions = []
            batch_responses = []
            try:
                for sample in batch_samples:
                    instruction = sample.get("instruction", "")
                    response = sample.get("response", "")
                    if not instruction.strip() or not response.strip():
                        instruction = "Empty instruction"
                        response = "Empty response"
                    input_text = self.format_input_text(instruction)
                    batch_texts.append(input_text)
                    batch_instructions.append(input_text)
                    batch_responses.append(response)

                embeddings_batch = self.process_embeddings_batch(model, tokenizer, batch_texts)

                for j, (input_text, embedding, response) in enumerate(
                        zip(batch_instructions, embeddings_batch, batch_responses)):
                    processed_sample = {
                        "input_text": input_text,
                        "inputs_embeds": embedding,
                        "response": response,
                        "embedding_shape": list(embedding.shape),
                        "original_index": i + j
                    }
                    self.processed_data_buffer.append(processed_sample)
                    processed_count += 1

                if len(self.processed_data_buffer) >= self.save_frequency:
                    self.save_chunk(self.processed_data_buffer, self.current_chunk)
                    self.processed_data_buffer = []
                    self.current_chunk += 1
                    import gc
                    gc.collect()
            except Exception as e:
                logger.error(f"Error processing batch: {e}")
                error_count += len(batch_samples)

        if self.processed_data_buffer:
            self.save_chunk(self.processed_data_buffer, self.current_chunk)
            self.processed_data_buffer = []

        merged_data = self.merge_chunks()
        stats = {
            "total_samples": total_samples,
            "processed_count": processed_count,
            "error_count": error_count
        }
        self.save_final_dataset(merged_data, stats)
        return merged_data
def load_processed_dataset(dataset_path: str) -> List[Dict[str, Any]]:
    pickle_path = Path(dataset_path) / "processed_dataset.pkl"
    with open(pickle_path, 'rb') as f:
        return pickle.load(f)


def get_dataset_info(dataset_path: str) -> Dict:
    metadata_path = Path(dataset_path) / "metadata.json"
    with open(metadata_path, 'r') as f:
        return json.load(f)


def main():
    model_name = "Qwen/Qwen3-0.6B"
    dataset_name = "Magpie-Align/Magpie-Qwen2.5-Pro-1M-v0.1"
    output_dir = "./processed_qwen3_dataset"
    batch_size = 1
    max_samples = 10000  # Set to a number for testing, None for the full dataset
    save_frequency = 1000

    logger.info("Starting Qwen3 dataset processing...")
    logger.info(f"Model: {model_name}")
    logger.info(f"Dataset: {dataset_name}")
    logger.info(f"Output: {output_dir}")
    logger.info(f"Max samples: {max_samples or 'ALL'}")

    try:
        processor = DatasetProcessor(
            model_name=model_name,
            dataset_name=dataset_name,
            output_dir=output_dir,
            batch_size=batch_size,
            max_samples=max_samples,
            save_frequency=save_frequency
        )
        processed_data = processor.process_dataset()
        logger.info("Processing completed successfully!")
        logger.info(f"Final dataset: {len(processed_data)} samples")
        logger.info(f"Files saved to: {output_dir}")
        return processed_data
    except Exception as e:
        logger.error(f"Processing failed: {e}")
        raise


if __name__ == "__main__":
    main()
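
# Illustrative post-processing sketch (commented out; assumes processing has
# already been run so that "./processed_qwen3_dataset" contains the files
# written by save_final_dataset above):
#
# info = get_dataset_info("./processed_qwen3_dataset")
# print(info["processed_samples"], info["success_rate"])
# data = load_processed_dataset("./processed_qwen3_dataset")
# sample = data[0]
# print(sample["embedding_shape"])      # e.g. [num_groups, hidden_size]
# print(sample["input_text"][:80])      # formatted Qwen3 chat prompt
# print(sample["inputs_embeds"].dtype)  # grouped layer-0 hidden states (torch.Tensor)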