""" Example usage of Online mode with warmup This demonstrates: 1. Warmup phase (generate N sequences to calibrate threshold) 2. Threshold computation (DeepConf-low or DeepConf-high) 3. Final generation with calibrated early stopping """ from typing import Optional import numpy as np import torch from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig def extract_answer(text: str) -> Optional[str]: """ Extract boxed answer from LaTeX text Looks for \\boxed{answer} pattern in generated text. """ if "boxed" in text: ans = text.split("boxed")[-1] if len(ans) == 0: return "" elif ans[0] == "{": stack = 1 a = "" for c in ans[1:]: if c == "{": stack += 1 a += c elif c == "}": stack -= 1 if stack == 0: break a += c else: a += c else: a = ans.split("$")[0].strip() return a.strip() return None def compute_least_grouped(confs: list, group_size: int) -> list: """ Compute sliding window mean confidence Args: confs: List of per-token confidence values group_size: Size of sliding window Returns: List of mean confidences for each window position """ if len(confs) < group_size: return [sum(confs) / len(confs)] if confs else [0] sliding_means = [] for i in range(len(confs) - group_size + 1): window = confs[i : i + group_size] sliding_means.append(round(sum(window) / len(window), 3)) return sliding_means def process_single_output( sequence, confidences, tokenizer, window_size: int, threshold: Optional[float] = None ) -> dict: """ Process a single generated sequence Args: sequence: Generated token IDs confidences: Per-token confidence values (list or tensor) tokenizer: Tokenizer for decoding window_size: Size of sliding window for confidence threshold: Optional threshold for early stopping detection Returns: Dictionary with trace data """ # Convert to list if tensor if hasattr(confidences, "tolist"): confs = confidences.tolist() else: confs = list(confidences) # Decode text text = tokenizer.decode(sequence, skip_special_tokens=True) # Compute sliding window statistics sliding_window = compute_least_grouped(confs, window_size) min_conf = min(sliding_window) if sliding_window else 0 # Determine if early stopping would have triggered stopped_early = False stop_position = None if threshold is not None: for pos, window_mean in enumerate(sliding_window): if window_mean < threshold: stopped_early = True stop_position = pos + window_size # Position in original sequence break # Extract answer if present extracted_answer = extract_answer(text) return { "text": text, "confs": confs, "group_confs": sliding_window, "min_conf": min_conf, "stopped_early": stopped_early, "stop_position": stop_position, "extracted_answer": extracted_answer, "num_tokens": len(confs), "token_ids": sequence.tolist() if hasattr(sequence, "tolist") else list(sequence), } def process_batch_results(outputs, tokenizer, window_size: int = 2048, threshold: Optional[float] = None) -> dict: """ Process batch generation outputs This function provides post-processing capabilities for batch-generated sequences, allowing analysis of confidence patterns and early stopping behavior after generation is complete. Args: outputs: GenerateDecoderOnlyOutput from model.generate() tokenizer: Tokenizer for decoding sequences window_size: Size of sliding window for confidence computation threshold: Optional threshold for detecting where early stopping would occur Returns: Dictionary containing: - traces: List of processed trace dictionaries - min_confs: List of minimum confidences per trace - total_tokens: Total tokens across all traces - num_traces: Number of traces processed """ if not hasattr(outputs, "sequences"): raise ValueError("outputs must have 'sequences' attribute") if not hasattr(outputs, "confidences") or outputs.confidences is None: raise ValueError("outputs must have 'confidences' attribute. Set output_confidences=True in generation_config") sequences = outputs.sequences confidences = outputs.confidences # Process each sequence traces = [] min_confs = [] total_tokens = 0 for i in range(sequences.shape[0]): trace_data = process_single_output(sequences[i], confidences[i], tokenizer, window_size, threshold) traces.append(trace_data) min_confs.append(trace_data["min_conf"]) total_tokens += trace_data["num_tokens"] return {"traces": traces, "min_confs": min_confs, "total_tokens": total_tokens, "num_traces": len(traces)} def compute_warmup_threshold(min_confs: list, variant: str = "low", eta: Optional[float] = None) -> float: """ Compute threshold from warmup confidences Args: min_confs: List of minimum confidences from warmup sequences variant: "low" (aggressive) or "high" (permissive) eta: Optional manual eta value (overrides variant default) Returns: Computed threshold value """ if eta is None: eta = 0.1 if variant == "low" else 0.9 if variant == "high" else 0.5 confs = np.asarray(min_confs, dtype=np.float32) pct = max(0.0, min(100.0, 100.0 - (eta * 100.0))) threshold = float(np.percentile(confs, pct)) return threshold # ============================================================================ # Example Functions # ============================================================================ def prepare_prompt(question: str, tokenizer): """Prepare prompt using chat template""" messages = [{"role": "user", "content": question}] prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) return prompt def run_online_mode_example( question: str, ground_truth: Optional[str] = None, warmup_traces: int = 8, confidence_variant: str = "low", # "low" or "high" window_size: int = 10, max_tokens: int = 128, temperature: float = 0.7, top_p: float = 0.95, ): """ Run DeepConf in online mode Args: question: Question to answer ground_truth: Optional ground truth answer for evaluation warmup_traces: Number of warmup sequences (default: 8) confidence_variant: "low" (aggressive) or "high" (permissive) window_size: Sliding window size for confidence max_tokens: Max tokens per generation temperature: Sampling temperature top_p: Top-p sampling """ # Load model (use local cache to avoid HF Hub timeouts) model_name = "Qwen/Qwen2.5-0.5B-Instruct" print(f"Loading model: {model_name}") model = AutoModelForCausalLM.from_pretrained( model_name, torch_dtype=torch.float16, device_map="auto", local_files_only=True, # Use cached model ) tokenizer = AutoTokenizer.from_pretrained(model_name, local_files_only=True) # Prepare prompt prompt = prepare_prompt(question, tokenizer) inputs = tokenizer(prompt, return_tensors="pt").to(model.device) print("\n" + "=" * 80) print("DEEPCONF ONLINE MODE - FOLLOWING OFFICIAL PATTERN") print("=" * 80) print(f"\nQuestion: {question}") if ground_truth: print(f"Ground truth: {ground_truth}") print("\nConfiguration:") print(f" - Warmup traces: {warmup_traces}") print(f" - Variant: DeepConf-{confidence_variant}") print(f" - Window size: {window_size}") print(f" - Max tokens: {max_tokens}") print(f" - Temperature: {temperature}") print(f" - Top-p: {top_p}") # ============================================================ # PHASE 1: WARMUP - Generate multiple sequences to calibrate # ============================================================ print("\n" + "=" * 80) print(f"PHASE 1: WARMUP (Generating {warmup_traces} sequences for calibration)") print("=" * 80) warmup_config = GenerationConfig( do_sample=True, temperature=temperature, top_p=top_p, max_new_tokens=max_tokens, enable_conf=True, enable_early_stopping=False, # No stopping during warmup output_confidences=True, return_dict_in_generate=True, pad_token_id=tokenizer.eos_token_id, ) # Expand inputs for batch generation expanded_ids = inputs.input_ids.repeat(warmup_traces, 1) if "attention_mask" in inputs and inputs.attention_mask is not None: expanded_mask = inputs.attention_mask.repeat(warmup_traces, 1) else: expanded_mask = None print(f"Generating {warmup_traces} warmup sequences...") warmup_outputs = model.generate( input_ids=expanded_ids, attention_mask=expanded_mask, generation_config=warmup_config, custom_generate="kashif/DeepConf", trust_remote_code=True, ) # Process warmup results warmup_results = process_batch_results(warmup_outputs, tokenizer, window_size=window_size) print("\nWarmup complete!") print(f" - Total tokens: {warmup_results['total_tokens']}") print(f" - Min confidences: {[round(c, 3) for c in warmup_results['min_confs']]}") # Show warmup traces print("\nWarmup Traces:") print("-" * 80) for i, trace in enumerate(warmup_results["traces"]): text = trace["text"][len(prompt) :].strip() answer = extract_answer(text) print(f"\nTrace {i + 1}:") print(f" Tokens: {trace['num_tokens']}, Min conf: {trace['min_conf']:.3f}") print(f" Text: {text[:80]}..." if len(text) > 80 else f" Text: {text}") if answer: print(f" Answer: {answer}") if ground_truth: correct = answer.strip() == ground_truth.strip() print(f" Correct: {'✓' if correct else '✗'}") # ============================================================ # PHASE 2: THRESHOLD COMPUTATION # ============================================================ print("\n" + "=" * 80) print("PHASE 2: THRESHOLD COMPUTATION") print("=" * 80) threshold = compute_warmup_threshold(warmup_results["min_confs"], variant=confidence_variant) eta = 0.1 if confidence_variant == "low" else 0.9 percentile = (1.0 - eta) * 100 print("\nComputed threshold from warmup:") print(f" - Variant: DeepConf-{confidence_variant} (eta={eta})") print(f" - Percentile: {percentile:.0f}th") print(f" - Threshold: {threshold:.3f}") print("\nInterpretation:") if confidence_variant == "low": print(" DeepConf-low is AGGRESSIVE - stops early to save tokens") else: print(" DeepConf-high is PERMISSIVE - allows longer generation") # ============================================================ # PHASE 3: FINAL GENERATION with calibrated threshold # ============================================================ print("\n" + "=" * 80) print("PHASE 3: FINAL GENERATION (With calibrated early stopping)") print("=" * 80) final_config = GenerationConfig( do_sample=True, temperature=temperature, top_p=top_p, max_new_tokens=max_tokens, enable_conf=True, enable_early_stopping=True, # Online stopping with calibrated threshold threshold=threshold, window_size=window_size, output_confidences=True, return_dict_in_generate=True, pad_token_id=tokenizer.eos_token_id, ) print(f"Generating with DeepConf-{confidence_variant} (threshold={threshold:.3f})...") final_output = model.generate( **inputs, generation_config=final_config, custom_generate="kashif/DeepConf", trust_remote_code=True, ) final_text = tokenizer.decode(final_output.sequences[0], skip_special_tokens=True) final_tokens = final_output.sequences.shape[1] - inputs.input_ids.shape[1] final_answer = extract_answer(final_text) # Calculate min confidence if available if hasattr(final_output, "confidences") and final_output.confidences is not None: min_conf = final_output.confidences.min().item() mean_conf = final_output.confidences.mean().item() else: min_conf = None mean_conf = None print("\nFinal generation complete!") print(f" - Tokens generated: {final_tokens}") if min_conf is not None: print(f" - Min confidence: {min_conf:.3f}") print(f" - Mean confidence: {mean_conf:.3f}") print("\nGenerated text:") print("-" * 80) print(final_text) print("-" * 80) if final_answer: print(f"\nExtracted answer: {final_answer}") if ground_truth: correct = final_answer.strip() == ground_truth.strip() print(f"Correct: {'✓' if correct else '✗'}") # ============================================================ # SUMMARY # ============================================================ print("\n" + "=" * 80) print("SUMMARY") print("=" * 80) total_warmup_tokens = warmup_results["total_tokens"] total_tokens = total_warmup_tokens + final_tokens print(f"Total tokens: {total_tokens}") print(f" - Warmup: {total_warmup_tokens} ({warmup_traces} sequences)") print(f" - Final: {final_tokens}") # Check if we would have used more tokens without early stopping avg_warmup_tokens = total_warmup_tokens / warmup_traces potential_savings = avg_warmup_tokens - final_tokens if potential_savings > 0: print("\nToken savings from early stopping:") print(f" - Average warmup length: {avg_warmup_tokens:.1f} tokens") print(f" - Final length: {final_tokens} tokens") print(f" - Saved: {potential_savings:.1f} tokens ({potential_savings / avg_warmup_tokens * 100:.1f}%)") print("\n" + "=" * 80) print("Example complete!") print("=" * 80) if __name__ == "__main__": # Example 1: Simple math problem print("\n\n" + "█" * 80) print("EXAMPLE 1: Simple Math Problem") print("█" * 80) run_online_mode_example( question="What is 15 * 8? Show your work step by step.", ground_truth="120", warmup_traces=4, confidence_variant="low", window_size=5, max_tokens=64, ) # Example 2: Square root problem print("\n\n" + "█" * 80) print("EXAMPLE 2: Square Root Problem") print("█" * 80) run_online_mode_example( question="What is the square root of 144? Express your answer in the form \\boxed{answer}.", ground_truth="12", warmup_traces=4, confidence_variant="high", window_size=5, max_tokens=64, ) # Example 3: Word problem print("\n\n" + "█" * 80) print("EXAMPLE 3: Word Problem") print("█" * 80) run_online_mode_example( question="If a train travels 60 miles per hour for 2.5 hours, how far does it travel?", ground_truth="150", warmup_traces=4, confidence_variant="low", window_size=5, max_tokens=96, )