"""Inference pipeline for Qwen3 with attention-based token grouping.

A grouping model (base Qwen3 with a customized first decoder layer) compresses the
prompt into grouped embeddings; a fine-tuned inference model then consumes those
embeddings through a small MLP adapter and generates the response.
"""

import os
import logging
import time

import psutil
from typing import Optional, List, Dict, Any, Tuple
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.models.qwen3.configuration_qwen3 import Qwen3Config
from transformers.models.qwen3.modeling_qwen3 import (
    Qwen3ForCausalLM,
    Qwen3RMSNorm,
    Qwen3DecoderLayer,
    Qwen3Attention,
    Qwen3RotaryEmbedding,
)

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger("qwen3_grouped_inference")


class PerformanceMonitor:
    """Tracks wall-clock time and CPU/GPU memory usage for a single phase."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Reset all metrics."""
        self.start_time = None
        self.end_time = None
        self.start_memory = None
        self.peak_memory = None
        self.start_gpu_memory = None
        self.peak_gpu_memory = None

    def start_monitoring(self):
        self.reset()
        self.start_time = time.time()
        process = psutil.Process()
        self.start_memory = process.memory_info().rss / 1024 / 1024  # MB
        self.peak_memory = self.start_memory
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            self.start_gpu_memory = torch.cuda.memory_allocated() / 1024 / 1024  # MB
            self.peak_gpu_memory = self.start_gpu_memory

    def update_peak_memory(self):
        process = psutil.Process()
        current_memory = process.memory_info().rss / 1024 / 1024  # MB
        self.peak_memory = max(self.peak_memory, current_memory)
        if torch.cuda.is_available():
            current_gpu_memory = torch.cuda.memory_allocated() / 1024 / 1024  # MB
            self.peak_gpu_memory = max(self.peak_gpu_memory, current_gpu_memory)

    def stop_monitoring(self):
        self.end_time = time.time()
        self.update_peak_memory()
        metrics = {
            "duration_ms": (self.end_time - self.start_time) * 1000,
            "cpu_memory_start_mb": self.start_memory,
            "cpu_memory_peak_mb": self.peak_memory,
            "cpu_memory_used_mb": self.peak_memory - self.start_memory,
        }
        if torch.cuda.is_available():
            metrics.update({
                "gpu_memory_start_mb": self.start_gpu_memory,
                "gpu_memory_peak_mb": self.peak_gpu_memory,
                "gpu_memory_used_mb": self.peak_gpu_memory - self.start_gpu_memory,
            })
        return metrics


class CustomQwen3Attention(Qwen3Attention):
    """Qwen3 attention that additionally derives token groups from attention weights."""

    def __init__(self, config, layer_idx: int):
        super().__init__(config, layer_idx)
        self.layer_idx = layer_idx
        self.tokenizer = None
        self.current_input_ids = None
        self.threshold = 0.1  # minimum attention weight for a token relation
        if not hasattr(self, 'num_key_value_heads'):
            self.num_key_value_heads = (
                config.num_key_value_heads
                if hasattr(config, 'num_key_value_heads')
                else config.num_attention_heads
            )
        if not hasattr(self, 'head_dim'):
            self.head_dim = config.hidden_size // config.num_attention_heads

    def set_tokenizer(self, tokenizer):
        self.tokenizer = tokenizer

    def set_current_input_ids(self, input_ids):
        self.current_input_ids = input_ids

    def _is_special_token(self, token: str) -> bool:
        if self.tokenizer is None:
            return False
        special_tokens = set()
        if hasattr(self.tokenizer, 'special_tokens_map'):
            for token_type, token_value in self.tokenizer.special_tokens_map.items():
                if isinstance(token_value, str):
                    special_tokens.add(token_value)
                elif isinstance(token_value, list):
                    special_tokens.update(token_value)
        if hasattr(self.tokenizer, 'added_tokens_encoder'):
            special_tokens.update(self.tokenizer.added_tokens_encoder.keys())
        if token in special_tokens:
            return True
        special_patterns = [
            lambda t: t.startswith('<|') and t.endswith('|>'),
            lambda t: t.startswith('<') and t.endswith('>'),
            lambda t: t.startswith('[') and t.endswith(']'),
        ]
        return any(pattern(token) for pattern in special_patterns)
    def _get_token_relations(self, attention_weights: torch.Tensor, tokens: List[str]) -> List[Dict]:
        batch_size, num_heads, query_len, key_len = attention_weights.shape
        # Average attention weights of the first batch element over all heads.
        attn = attention_weights[0].mean(dim=0)
        relations = []
        if query_len == 1:
            current_token_pos = len(tokens) - 1
            token_relations = []
            for j in range(len(tokens)):
                if j != current_token_pos:
                    weight = attn[0, j].item()
                    if weight > self.threshold:
                        token_relations.append({
                            'target_pos': j,
                            'weight': round(weight, 3)
                        })
            relations.append({
                'source_pos': current_token_pos,
                'relations': token_relations
            })
        else:
            for i in range(min(query_len, len(tokens))):
                token_relations = []
                for j in range(len(tokens)):
                    if i != j and j < key_len:
                        weight = attn[i, j].item()
                        if weight > self.threshold:
                            token_relations.append({
                                'target_pos': j,
                                'weight': round(weight, 3)
                            })
                relations.append({
                    'source_pos': i,
                    'relations': token_relations
                })
        return relations

    def _get_token_groups(self, attention_weights: torch.Tensor) -> List[List[int]]:
        if self.tokenizer is None or self.current_input_ids is None:
            return []
        if len(attention_weights.shape) != 4:
            return []
        batch_size, num_heads, query_len, key_len = attention_weights.shape
        input_ids = self.current_input_ids
        if input_ids is None or input_ids.shape[1] < key_len:
            return []
        tokens = [self.tokenizer.decode([token_id]) for token_id in input_ids[0][:key_len]]
        relations = self._get_token_relations(attention_weights, tokens)
        groups = []
        current_group = []
        current_group_indices = []
        for i, token in enumerate(tokens):
            # Boundary signals: a token with no strong attention relations, a token that
            # begins a new word (leading space), a newline, or a transition to/from
            # special tokens, newlines, or spaces closes the current group.
            is_empty_relations = i < len(relations) and len(relations[i]['relations']) == 0
            starts_with_space = token.startswith(' ') and token != ' '
            is_space = token == ' '
            is_new_line = '\n' in token
            prev_token_is_special = False
            prev_token_is_new_line = False
            prev_token_is_space = False
            if i > 0:
                prev_token = tokens[i - 1]
                prev_token_is_special = self._is_special_token(prev_token)
                prev_token_is_new_line = '\n' in prev_token
                prev_token_is_space = prev_token == ' '
            prev_newline_current_not = prev_token_is_new_line and not is_new_line
            prev_space_current_not = prev_token_is_space and not is_space
            current_space_prev_not = is_space and not prev_token_is_space
            if (is_empty_relations or starts_with_space or is_new_line or prev_token_is_special
                    or prev_newline_current_not or prev_space_current_not
                    or current_space_prev_not) and current_group:
                groups.append(current_group_indices)
                current_group = []
                current_group_indices = []
            current_group.append(token)
            current_group_indices.append(i)
        if current_group:
            groups.append(current_group_indices)
        if groups:
            logger.info("Token grouping details:")
            for group_idx, group_indices in enumerate(groups):
                group_tokens = [tokens[i] for i in group_indices]
                combined_text = ''.join(group_tokens)
                logger.info(f"  Group {group_idx + 1}: {group_tokens} → '{combined_text}'")
        return groups
class CustomQwen3DecoderLayer(Qwen3DecoderLayer):
    """Custom Qwen3 decoder layer with grouping functionality."""

    def __init__(self, config, layer_idx: int):
        super().__init__(config, layer_idx)
        self.layer_idx = layer_idx
        self.rotary_emb = Qwen3RotaryEmbedding(config=config)
        self.self_attn = CustomQwen3Attention(config, layer_idx)
        self.is_initialized = False
        self.grouped_hidden_states = None

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[tuple] = None,
        **kwargs,
    ):
        # Grouping is only performed in the first decoder layer.
        if self.layer_idx != 0:
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **kwargs,
            )
        # Grouping is only performed once, on the multi-token prefill pass.
        is_prefill = hidden_states.shape[1] > 1 and not self.is_initialized
        if not is_prefill:
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **kwargs,
            )
        residual = hidden_states
        x = self.input_layernorm(hidden_states)
        attn_ret = self.self_attn(
            hidden_states=x,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=None,
            output_attentions=True,
            use_cache=False,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
        )
        if isinstance(attn_ret, tuple):
            if len(attn_ret) == 3:
                attn_out, attn_weights, _ = attn_ret
            elif len(attn_ret) == 2:
                attn_out, attn_weights = attn_ret
            else:
                raise RuntimeError(f"Unexpected attention return length: {len(attn_ret)}")
        else:
            raise RuntimeError("Attention did not return weights.")
        groups = self.self_attn._get_token_groups(attn_weights)
        if not groups:
            self.is_initialized = True
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **kwargs,
            )
        # Collapse each group of attention outputs into a single averaged vector.
        averaged_vectors = []
        group_info = []
        for gi, idxs in enumerate(groups):
            if len(idxs) == 1:
                averaged_vectors.append(attn_out[:, idxs[0], :])
                group_info.append({"type": "single", "positions": idxs, "new_position": gi})
            else:
                gvecs = attn_out[:, idxs, :]
                ave = gvecs.mean(dim=1)
                averaged_vectors.append(ave)
                group_info.append({"type": "averaged", "positions": idxs, "new_position": gi})
        new_attn_out = torch.stack(averaged_vectors, dim=1)
        # The residual stream is summed over each group's positions.
        expanded_residual = torch.stack([
            (
                residual[:, info['positions'], :].sum(dim=1)
                if len(info['positions']) > 1
                else residual[:, info['positions'][0], :]
            )
            for info in group_info
        ], dim=1)
        hs = expanded_residual + new_attn_out
        grouped_hidden = self.post_attention_layernorm(hs)
        # Store grouped embeddings
        self.grouped_hidden_states = grouped_hidden
        self.is_initialized = True
        return hs


class GroupedInputMLPAdapter(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        hidden_size = config.hidden_size
        self.grouped_processor = nn.Sequential(
            nn.Linear(hidden_size, hidden_size * 2),
            nn.SiLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size * 2, hidden_size),
            nn.Dropout(0.1)
        )
        norm_eps = getattr(config, 'rms_norm_eps', 1e-6)
        self.layer_norm = Qwen3RMSNorm(hidden_size, eps=norm_eps)

    def forward(self, grouped_embeds: torch.Tensor) -> torch.Tensor:
        processed = self.grouped_processor(grouped_embeds)
        output = self.layer_norm(grouped_embeds + processed)
        return output
class CustomQwen3ForCausalLM(Qwen3ForCausalLM):
    def __init__(self, config):
        super().__init__(config)
        self.grouped_input_mlp = GroupedInputMLPAdapter(config)
        self.is_grouped_input_mode = False

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        grouped_inputs: Optional[torch.FloatTensor] = None,
        is_prefill: Optional[bool] = None,
        **kwargs
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        if grouped_inputs is not None and is_prefill:
            # During prefill, grouped embeddings pass through the adapter MLP and are
            # fed to the base model as inputs_embeds instead of token ids.
            self.is_grouped_input_mode = True
            processed_grouped_inputs = self.grouped_input_mlp(grouped_inputs)
            inputs_embeds = processed_grouped_inputs
            input_ids = None
            batch_size, seq_len = inputs_embeds.shape[:2]
            if position_ids is None:
                device = inputs_embeds.device
                position_ids = torch.arange(seq_len, device=device, dtype=torch.long)
                position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
            if attention_mask is None:
                attention_mask = torch.ones((batch_size, seq_len), device=inputs_embeds.device, dtype=torch.long)
        return super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            **kwargs
        )


def create_grouping_model(model_name: str = "Qwen/Qwen3-0.6B") -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if torch.cuda.is_available():
        device = torch.device("cuda")
        dtype = torch.float16
    else:
        device = torch.device("cpu")
        dtype = torch.float32
    # Eager attention is required so attention weights are returned for grouping.
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=dtype,
        attn_implementation="eager"
    ).to(device)
    # Replace layer 0 with the custom grouping layer, copying the original weights.
    orig0 = model.model.layers[0]
    custom0 = CustomQwen3DecoderLayer(model.config, 0)
    custom0.mlp.load_state_dict(orig0.mlp.state_dict())
    custom0.input_layernorm.load_state_dict(orig0.input_layernorm.state_dict())
    custom0.post_attention_layernorm.load_state_dict(orig0.post_attention_layernorm.state_dict())
    custom0.self_attn.load_state_dict(orig0.self_attn.state_dict())
    custom0.self_attn.set_tokenizer(tokenizer)
    custom0 = custom0.to(device=device, dtype=dtype)
    model.model.layers[0] = custom0
    return model, tokenizer


def load_inference_model(checkpoint_path: str) -> Tuple[CustomQwen3ForCausalLM, AutoTokenizer]:
    logger.info(f"Loading inference model from {checkpoint_path}")
    tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
    config = Qwen3Config.from_pretrained(checkpoint_path)
    model = CustomQwen3ForCausalLM(config)
    model_path = Path(checkpoint_path) / "pytorch_model.bin"
    if not model_path.exists():
        model_path = Path(checkpoint_path) / "model.safetensors"
    if not model_path.exists():
        raise FileNotFoundError(f"No model weights found in {checkpoint_path}")
    # safetensors checkpoints cannot be read with torch.load.
    if model_path.suffix == ".safetensors":
        from safetensors.torch import load_file
        state_dict = load_file(str(model_path))
    else:
        state_dict = torch.load(model_path, map_location="cpu")
    model.load_state_dict(state_dict, strict=False)
    model = model.eval().to(torch.float32)
    return model, tokenizer
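# Expected checkpoint layout (an assumption inferred from load_inference_model above:
# tokenizer and config loadable via from_pretrained, plus a single weights file):
#
#   <checkpoint_path>/
#       config.json
#       tokenizer_config.json, tokenizer.json, ...
#       pytorch_model.bin  (or model.safetensors)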
class Qwen3GroupedInference:
    def __init__(self, checkpoint_path: str, grouping_model_name: str = "Qwen/Qwen3-0.6B",
                 device: Optional[str] = None):
        """Initialize inference system with both models."""
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(device)
        logger.info(f"Initializing inference on device: {self.device}")
        self.system_prompt = (
            "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. "
            "You are a helpful assistant.<|im_end|>\n<|im_start|>user\n"
        )
        self.response_start = "<|im_end|>\n<|im_start|>assistant\n"
        logger.info("Loading grouping model...")
        self.grouping_model, self.grouping_tokenizer = create_grouping_model(grouping_model_name)
        self.grouping_model = self.grouping_model.to(self.device)
        logger.info("Loading inference model...")
        self.inference_model, self.inference_tokenizer = load_inference_model(checkpoint_path)
        self.inference_model = self.inference_model.to(self.device)
        logger.info("Both models loaded successfully")

    def format_input_text(self, instruction: str) -> str:
        return f"{self.system_prompt}{instruction}{self.response_start}"

    def get_grouped_embeddings(self, text: str) -> Tuple[torch.Tensor, Dict[str, Any]]:
        monitor = PerformanceMonitor()
        monitor.start_monitoring()
        if hasattr(self.grouping_model.model.layers[0], "is_initialized"):
            self.grouping_model.model.layers[0].is_initialized = False
        batch = self.grouping_tokenizer(text, return_tensors="pt", truncation=True, max_length=2048).to(self.device)
        input_ids = batch["input_ids"]
        original_token_count = input_ids.shape[1]
        original_tokens = [self.grouping_tokenizer.decode([token_id]) for token_id in input_ids[0]]
        logger.info(f"Original input tokens ({original_token_count}): {original_tokens}")
        if hasattr(self.grouping_model.model.layers[0], "self_attn"):
            sat = self.grouping_model.model.layers[0].self_attn
            if hasattr(sat, "set_current_input_ids"):
                sat.set_current_input_ids(input_ids)
        monitor.update_peak_memory()
        with torch.no_grad():
            inputs_embeds = self.grouping_model.model.embed_tokens(input_ids)
            seq_len = inputs_embeds.shape[1]
            position_ids = torch.arange(seq_len, device=self.device, dtype=torch.long).unsqueeze(0)
            if hasattr(self.grouping_model.model, 'rotary_emb'):
                pos_embeds = self.grouping_model.model.rotary_emb(inputs_embeds, position_ids)
            else:
                pos_embeds = None
            monitor.update_peak_memory()
            _ = self.grouping_model.model.layers[0](
                hidden_states=inputs_embeds,
                attention_mask=None,
                position_ids=position_ids,
                past_key_value=None,
                output_attentions=False,
                use_cache=False,
                cache_position=None,
                position_embeddings=pos_embeds,
            )
        monitor.update_peak_memory()
        if (hasattr(self.grouping_model.model.layers[0], "grouped_hidden_states")
                and self.grouping_model.model.layers[0].grouped_hidden_states is not None):
            grouped_embeds = self.grouping_model.model.layers[0].grouped_hidden_states.clone()
            grouped_token_count = grouped_embeds.shape[1]
            # Clear the stored state
            self.grouping_model.model.layers[0].grouped_hidden_states = None
            compression_ratio = original_token_count / grouped_token_count if grouped_token_count > 0 else 1.0
            reduction_percent = (1 - grouped_token_count / original_token_count) * 100 if original_token_count > 0 else 0.0
            logger.info(f"Grouped tokens: {grouped_token_count}")
            logger.info(f"Compression ratio: {compression_ratio:.2f}x ({reduction_percent:.1f}% reduction)")
            metrics = monitor.stop_monitoring()
            metrics.update({
                "original_tokens": original_token_count,
                "grouped_tokens": grouped_token_count,
                "compression_ratio": compression_ratio,
                "reduction_percent": reduction_percent
            })
            return grouped_embeds.squeeze(0), metrics
        else:
            logger.warning("Grouping failed, using original embeddings")
            metrics = monitor.stop_monitoring()
            metrics.update({
                "original_tokens": original_token_count,
                "grouped_tokens": original_token_count,
                "compression_ratio": 1.0,
                "reduction_percent": 0.0
            })
            return inputs_embeds.squeeze(0), metrics
    def generate_with_grouped_input(self, grouped_input: torch.Tensor, max_length: int = 512,
                                    temperature: float = 0.7, do_sample: bool = True) -> Tuple[str, Dict[str, Any]]:
        """Generate text using grouped input embeddings."""
        monitor = PerformanceMonitor()
        monitor.start_monitoring()
        model_dtype = next(self.inference_model.parameters()).dtype
        grouped_input = grouped_input.to(device=self.device, dtype=model_dtype)
        if grouped_input.ndim == 2:
            grouped_input = grouped_input.unsqueeze(0)
        input_seq_len = grouped_input.shape[1]
        logger.info(f"Inference model input sequence length: {input_seq_len}")
        monitor.update_peak_memory()
        with torch.no_grad():
            outputs = self.inference_model(
                grouped_inputs=grouped_input,
                is_prefill=True,
                use_cache=True,
                return_dict=True
            )
        monitor.update_peak_memory()
        if hasattr(outputs, 'logits') and outputs.logits is not None:
            next_token_logits = outputs.logits[:, -1, :]
        else:
            raise RuntimeError("Could not extract logits from model output")
        if do_sample:
            next_token_logits = next_token_logits / temperature
            probs = F.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
        else:
            next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
        generated_ids = next_token
        past_key_values = getattr(outputs, 'past_key_values', None)
        generated_tokens = 1
        for step in range(max_length - 1):
            monitor.update_peak_memory()
            with torch.no_grad():
                outputs = self.inference_model(
                    input_ids=next_token,
                    past_key_values=past_key_values,
                    use_cache=True,
                    return_dict=True
                )
            if not hasattr(outputs, 'logits'):
                break
            next_token_logits = outputs.logits[:, -1, :]
            if do_sample:
                next_token_logits = next_token_logits / temperature
                probs = F.softmax(next_token_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            else:
                next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
            generated_ids = torch.cat([generated_ids, next_token], dim=1)
            past_key_values = getattr(outputs, 'past_key_values', None)
            generated_tokens += 1
            if next_token.item() == self.inference_tokenizer.eos_token_id:
                break
        generated_text = self.inference_tokenizer.decode(generated_ids[0], skip_special_tokens=True)
        metrics = monitor.stop_monitoring()
        metrics.update({
            "input_seq_len": input_seq_len,
            "generated_tokens": generated_tokens,
            "tokens_per_second": generated_tokens / (metrics["duration_ms"] / 1000) if metrics["duration_ms"] > 0 else 0
        })
        logger.info(f"Generated {generated_tokens} tokens in {metrics['duration_ms']:.1f}ms")
        logger.info(f"Generation speed: {metrics['tokens_per_second']:.1f} tokens/second")
        return generated_text, metrics

    def inference(self, instruction: str, max_length: int = 512, temperature: float = 0.7,
                  do_sample: bool = True) -> Dict[str, Any]:
        """Run complete inference pipeline from instruction to response."""
        logger.info("=" * 60)
        logger.info("STARTING INFERENCE PIPELINE")
        logger.info("=" * 60)
        input_text = self.format_input_text(instruction)
        logger.info("PHASE 1: Token Grouping")
        grouped_embeddings, grouping_metrics = self.get_grouped_embeddings(input_text)
        logger.info("PHASE 2: Response Generation")
        response, generation_metrics = self.generate_with_grouped_input(
            grouped_input=grouped_embeddings,
            max_length=max_length,
            temperature=temperature,
            do_sample=do_sample
        )
        total_metrics = {
            "grouping": grouping_metrics,
            "generation": generation_metrics,
            "total_duration_ms": grouping_metrics["duration_ms"] + generation_metrics["duration_ms"],
        }
        logger.info("=" * 60)
        logger.info("INFERENCE SUMMARY")
        logger.info("=" * 60)
        logger.info(f"Input compression: {grouping_metrics['original_tokens']} → {grouping_metrics['grouped_tokens']} tokens")
        logger.info(f"Compression ratio: {grouping_metrics['compression_ratio']:.2f}x")
        logger.info(f"Memory reduction: {grouping_metrics['reduction_percent']:.1f}%")
        logger.info(f"Total time: {total_metrics['total_duration_ms']:.1f}ms")
        logger.info(f"Generation speed: {generation_metrics['tokens_per_second']:.1f} tokens/sec")
        if torch.cuda.is_available():
            total_gpu_memory = grouping_metrics.get("gpu_memory_used_mb", 0) + generation_metrics.get("gpu_memory_used_mb", 0)
            logger.info(f"Total GPU memory used: {total_gpu_memory:.1f}MB")
        total_cpu_memory = grouping_metrics.get("cpu_memory_used_mb", 0) + generation_metrics.get("cpu_memory_used_mb", 0)
        logger.info(f"Total CPU memory used: {total_cpu_memory:.1f}MB")
        original_seq_len = grouping_metrics['original_tokens']
        grouped_seq_len = grouping_metrics['grouped_tokens']
        estimated_memory_savings = (1 - (grouped_seq_len ** 2) / (original_seq_len ** 2)) * 100 if original_seq_len > 0 else 0
        logger.info(f"Estimated attention memory savings: {estimated_memory_savings:.1f}%")
        logger.info("=" * 60)
        return {
            "instruction": instruction,
            "response": response,
            "metrics": total_metrics
        }


def main():
    import argparse
    parser = argparse.ArgumentParser(description="Qwen3 Grouped Inference")
    parser.add_argument("--checkpoint", type=str, default="./grouped_qwen3_checkpoint/epoch_2_best",
                        help="Path to trained model checkpoint")
    parser.add_argument("--grouping_model", type=str, default="Qwen/Qwen3-0.6B",
                        help="Grouping model name")
    parser.add_argument("--instruction", type=str,
                        default="Explain what neural networks are, as if to a ninth-grade student",
                        help="Instruction for inference")
    parser.add_argument("--max_length", type=int, default=512, help="Maximum generation length")
    parser.add_argument("--temperature", type=float, default=0.7, help="Generation temperature")
    parser.add_argument("--no_sample", action="store_true", help="Use greedy decoding")
    parser.add_argument("--device", type=str, help="Device to use (cuda/cpu)")
    args = parser.parse_args()
    inference_system = Qwen3GroupedInference(
        checkpoint_path=args.checkpoint,
        grouping_model_name=args.grouping_model,
        device=args.device
    )
    do_sample = not args.no_sample
    result = inference_system.inference(
        instruction=args.instruction,
        max_length=args.max_length,
        temperature=args.temperature,
        do_sample=do_sample
    )
    print(f"\nInstruction: {result['instruction']}")
    print(f"Response: {result['response']}")
    metrics = result.get('metrics', {})
    if metrics:
        print("\n--- Performance Metrics ---")
        grouping = metrics.get('grouping', {})
        generation = metrics.get('generation', {})
        print(f"Token compression: {grouping.get('compression_ratio', 0.0):.2f}x")
        print(f"Memory reduction: {grouping.get('reduction_percent', 0.0):.1f}%")
        print(f"Total time: {metrics.get('total_duration_ms', 0.0):.1f}ms")
        print(f"Generation speed: {generation.get('tokens_per_second', 0.0):.1f} tokens/sec")


if __name__ == "__main__":
    main()
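# Example usage (a sketch; the module filename below is an assumption based on the logger
# name, and the checkpoint path is the argparse default, which must contain a trained
# inference-model checkpoint):
#
#   from qwen3_grouped_inference import Qwen3GroupedInference
#
#   system = Qwen3GroupedInference(
#       checkpoint_path="./grouped_qwen3_checkpoint/epoch_2_best",
#       grouping_model_name="Qwen/Qwen3-0.6B",
#   )
#   result = system.inference("Explain what neural networks are.", max_length=256)
#   print(result["response"])
#
# Equivalent CLI invocation:
#   python qwen3_grouped_inference.py \
#       --checkpoint ./grouped_qwen3_checkpoint/epoch_2_best \
#       --instruction "Explain what neural networks are." \
#       --max_length 256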