# DeepConf / example_online_mode.py
"""
Example usage of Online mode with warmup
This demonstrates:
1. Warmup phase (generate N sequences to calibrate threshold)
2. Threshold computation (DeepConf-low or DeepConf-high)
3. Final generation with calibrated early stopping
"""
from typing import Optional
import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
def extract_answer(text: str) -> Optional[str]:
"""
Extract boxed answer from LaTeX text
Looks for \\boxed{answer} pattern in generated text.
"""
if "boxed" in text:
ans = text.split("boxed")[-1]
if len(ans) == 0:
return ""
elif ans[0] == "{":
stack = 1
a = ""
for c in ans[1:]:
if c == "{":
stack += 1
a += c
elif c == "}":
stack -= 1
if stack == 0:
break
a += c
else:
a += c
else:
a = ans.split("$")[0].strip()
return a.strip()
return None
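# A minimal sanity check of the extraction above (hypothetical helper, not called
# anywhere in this script): nested braces inside \boxed{...} are kept intact.
def _demo_extract_answer():
    sample = r"The answer is \boxed{\frac{1}{2}}."
    print(extract_answer(sample))  # expected: \frac{1}{2}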
def compute_least_grouped(confs: list, group_size: int) -> list:
"""
Compute sliding window mean confidence
Args:
confs: List of per-token confidence values
group_size: Size of sliding window
Returns:
List of mean confidences for each window position
"""
if len(confs) < group_size:
return [sum(confs) / len(confs)] if confs else [0]
sliding_means = []
for i in range(len(confs) - group_size + 1):
window = confs[i : i + group_size]
sliding_means.append(round(sum(window) / len(window), 3))
return sliding_means
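# A small illustration of the sliding-window mean (hypothetical values, not part
# of the example flow): the minimum of these window means is what later gets
# compared against the calibrated threshold.
def _demo_compute_least_grouped():
    confs = [0.9, 0.8, 0.7, 0.6]
    print(compute_least_grouped(confs, group_size=2))  # expected: [0.85, 0.75, 0.65]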
def process_single_output(
sequence, confidences, tokenizer, window_size: int, threshold: Optional[float] = None
) -> dict:
"""
Process a single generated sequence
Args:
sequence: Generated token IDs
confidences: Per-token confidence values (list or tensor)
tokenizer: Tokenizer for decoding
window_size: Size of sliding window for confidence
threshold: Optional threshold for early stopping detection
Returns:
Dictionary with trace data
"""
# Convert to list if tensor
if hasattr(confidences, "tolist"):
confs = confidences.tolist()
else:
confs = list(confidences)
# Decode text
text = tokenizer.decode(sequence, skip_special_tokens=True)
# Compute sliding window statistics
sliding_window = compute_least_grouped(confs, window_size)
min_conf = min(sliding_window) if sliding_window else 0
# Determine if early stopping would have triggered
stopped_early = False
stop_position = None
if threshold is not None:
for pos, window_mean in enumerate(sliding_window):
if window_mean < threshold:
stopped_early = True
stop_position = pos + window_size # Position in original sequence
break
# Extract answer if present
extracted_answer = extract_answer(text)
return {
"text": text,
"confs": confs,
"group_confs": sliding_window,
"min_conf": min_conf,
"stopped_early": stopped_early,
"stop_position": stop_position,
"extracted_answer": extracted_answer,
"num_tokens": len(confs),
"token_ids": sequence.tolist() if hasattr(sequence, "tolist") else list(sequence),
}
def process_batch_results(outputs, tokenizer, window_size: int = 2048, threshold: Optional[float] = None) -> dict:
"""
Process batch generation outputs
This function provides post-processing capabilities for batch-generated
sequences, allowing analysis of confidence patterns and early stopping
behavior after generation is complete.
Args:
outputs: GenerateDecoderOnlyOutput from model.generate()
tokenizer: Tokenizer for decoding sequences
window_size: Size of sliding window for confidence computation
threshold: Optional threshold for detecting where early stopping would occur
Returns:
Dictionary containing:
- traces: List of processed trace dictionaries
- min_confs: List of minimum confidences per trace
- total_tokens: Total tokens across all traces
- num_traces: Number of traces processed
"""
if not hasattr(outputs, "sequences"):
raise ValueError("outputs must have 'sequences' attribute")
if not hasattr(outputs, "confidences") or outputs.confidences is None:
raise ValueError("outputs must have 'confidences' attribute. Set output_confidences=True in generation_config")
sequences = outputs.sequences
confidences = outputs.confidences
# Process each sequence
traces = []
min_confs = []
total_tokens = 0
for i in range(sequences.shape[0]):
trace_data = process_single_output(sequences[i], confidences[i], tokenizer, window_size, threshold)
traces.append(trace_data)
min_confs.append(trace_data["min_conf"])
total_tokens += trace_data["num_tokens"]
return {"traces": traces, "min_confs": min_confs, "total_tokens": total_tokens, "num_traces": len(traces)}
def compute_warmup_threshold(min_confs: list, variant: str = "low", eta: Optional[float] = None) -> float:
"""
Compute threshold from warmup confidences
Args:
min_confs: List of minimum confidences from warmup sequences
variant: "low" (aggressive) or "high" (permissive)
eta: Optional manual eta value (overrides variant default)
Returns:
Computed threshold value
"""
if eta is None:
eta = 0.1 if variant == "low" else 0.9 if variant == "high" else 0.5
confs = np.asarray(min_confs, dtype=np.float32)
pct = max(0.0, min(100.0, 100.0 - (eta * 100.0)))
threshold = float(np.percentile(confs, pct))
return threshold
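# A hand-worked illustration (hypothetical warmup values, helper never called):
# DeepConf-low takes the 90th percentile of the warmup minimum confidences, so
# roughly 90% of warmup traces would have fallen below the resulting threshold;
# DeepConf-high takes the 10th percentile and is far more permissive.
def _demo_compute_warmup_threshold():
    min_confs = [0.2, 0.4, 0.6, 0.8]
    print(compute_warmup_threshold(min_confs, variant="low"))   # roughly 0.74
    print(compute_warmup_threshold(min_confs, variant="high"))  # roughly 0.26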
# ============================================================================
# Example Functions
# ============================================================================
def prepare_prompt(question: str, tokenizer):
"""Prepare prompt using chat template"""
messages = [{"role": "user", "content": question}]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
return prompt
def run_online_mode_example(
question: str,
ground_truth: Optional[str] = None,
warmup_traces: int = 8,
confidence_variant: str = "low", # "low" or "high"
window_size: int = 10,
max_tokens: int = 128,
temperature: float = 0.7,
top_p: float = 0.95,
):
"""
Run DeepConf in online mode
Args:
question: Question to answer
ground_truth: Optional ground truth answer for evaluation
warmup_traces: Number of warmup sequences (default: 8)
confidence_variant: "low" (aggressive) or "high" (permissive)
window_size: Sliding window size for confidence
max_tokens: Max tokens per generation
temperature: Sampling temperature
top_p: Top-p sampling
"""
# Load model (use local cache to avoid HF Hub timeouts)
model_name = "Qwen/Qwen2.5-0.5B-Instruct"
print(f"Loading model: {model_name}")
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float16,
device_map="auto",
local_files_only=True, # Use cached model
)
tokenizer = AutoTokenizer.from_pretrained(model_name, local_files_only=True)
# Prepare prompt
prompt = prepare_prompt(question, tokenizer)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
print("\n" + "=" * 80)
print("DEEPCONF ONLINE MODE - FOLLOWING OFFICIAL PATTERN")
print("=" * 80)
print(f"\nQuestion: {question}")
if ground_truth:
print(f"Ground truth: {ground_truth}")
print("\nConfiguration:")
print(f" - Warmup traces: {warmup_traces}")
print(f" - Variant: DeepConf-{confidence_variant}")
print(f" - Window size: {window_size}")
print(f" - Max tokens: {max_tokens}")
print(f" - Temperature: {temperature}")
print(f" - Top-p: {top_p}")
# ============================================================
# PHASE 1: WARMUP - Generate multiple sequences to calibrate
# ============================================================
print("\n" + "=" * 80)
print(f"PHASE 1: WARMUP (Generating {warmup_traces} sequences for calibration)")
print("=" * 80)
warmup_config = GenerationConfig(
do_sample=True,
temperature=temperature,
top_p=top_p,
max_new_tokens=max_tokens,
enable_conf=True,
enable_early_stopping=False, # No stopping during warmup
output_confidences=True,
return_dict_in_generate=True,
pad_token_id=tokenizer.eos_token_id,
)
# Expand inputs for batch generation
expanded_ids = inputs.input_ids.repeat(warmup_traces, 1)
if "attention_mask" in inputs and inputs.attention_mask is not None:
expanded_mask = inputs.attention_mask.repeat(warmup_traces, 1)
else:
expanded_mask = None
print(f"Generating {warmup_traces} warmup sequences...")
warmup_outputs = model.generate(
input_ids=expanded_ids,
attention_mask=expanded_mask,
generation_config=warmup_config,
custom_generate="kashif/DeepConf",
trust_remote_code=True,
)
# Process warmup results
warmup_results = process_batch_results(warmup_outputs, tokenizer, window_size=window_size)
print("\nWarmup complete!")
print(f" - Total tokens: {warmup_results['total_tokens']}")
print(f" - Min confidences: {[round(c, 3) for c in warmup_results['min_confs']]}")
# Show warmup traces
print("\nWarmup Traces:")
print("-" * 80)
for i, trace in enumerate(warmup_results["traces"]):
text = trace["text"][len(prompt) :].strip()
answer = extract_answer(text)
print(f"\nTrace {i + 1}:")
print(f" Tokens: {trace['num_tokens']}, Min conf: {trace['min_conf']:.3f}")
print(f" Text: {text[:80]}..." if len(text) > 80 else f" Text: {text}")
if answer:
print(f" Answer: {answer}")
if ground_truth:
correct = answer.strip() == ground_truth.strip()
print(f" Correct: {'βœ“' if correct else 'βœ—'}")
# ============================================================
# PHASE 2: THRESHOLD COMPUTATION
# ============================================================
print("\n" + "=" * 80)
print("PHASE 2: THRESHOLD COMPUTATION")
print("=" * 80)
threshold = compute_warmup_threshold(warmup_results["min_confs"], variant=confidence_variant)
eta = 0.1 if confidence_variant == "low" else 0.9
percentile = (1.0 - eta) * 100
print("\nComputed threshold from warmup:")
print(f" - Variant: DeepConf-{confidence_variant} (eta={eta})")
print(f" - Percentile: {percentile:.0f}th")
print(f" - Threshold: {threshold:.3f}")
print("\nInterpretation:")
if confidence_variant == "low":
print(" DeepConf-low is AGGRESSIVE - stops early to save tokens")
else:
print(" DeepConf-high is PERMISSIVE - allows longer generation")
# ============================================================
# PHASE 3: FINAL GENERATION with calibrated threshold
# ============================================================
print("\n" + "=" * 80)
print("PHASE 3: FINAL GENERATION (With calibrated early stopping)")
print("=" * 80)
final_config = GenerationConfig(
do_sample=True,
temperature=temperature,
top_p=top_p,
max_new_tokens=max_tokens,
enable_conf=True,
enable_early_stopping=True, # Online stopping with calibrated threshold
threshold=threshold,
window_size=window_size,
output_confidences=True,
return_dict_in_generate=True,
pad_token_id=tokenizer.eos_token_id,
)
print(f"Generating with DeepConf-{confidence_variant} (threshold={threshold:.3f})...")
final_output = model.generate(
**inputs,
generation_config=final_config,
custom_generate="kashif/DeepConf",
trust_remote_code=True,
)
final_text = tokenizer.decode(final_output.sequences[0], skip_special_tokens=True)
final_tokens = final_output.sequences.shape[1] - inputs.input_ids.shape[1]
final_answer = extract_answer(final_text)
# Calculate min confidence if available
if hasattr(final_output, "confidences") and final_output.confidences is not None:
min_conf = final_output.confidences.min().item()
mean_conf = final_output.confidences.mean().item()
else:
min_conf = None
mean_conf = None
print("\nFinal generation complete!")
print(f" - Tokens generated: {final_tokens}")
if min_conf is not None:
print(f" - Min confidence: {min_conf:.3f}")
print(f" - Mean confidence: {mean_conf:.3f}")
print("\nGenerated text:")
print("-" * 80)
print(final_text)
print("-" * 80)
if final_answer:
print(f"\nExtracted answer: {final_answer}")
if ground_truth:
correct = final_answer.strip() == ground_truth.strip()
print(f"Correct: {'βœ“' if correct else 'βœ—'}")
# ============================================================
# SUMMARY
# ============================================================
print("\n" + "=" * 80)
print("SUMMARY")
print("=" * 80)
total_warmup_tokens = warmup_results["total_tokens"]
total_tokens = total_warmup_tokens + final_tokens
print(f"Total tokens: {total_tokens}")
print(f" - Warmup: {total_warmup_tokens} ({warmup_traces} sequences)")
print(f" - Final: {final_tokens}")
# Rough savings estimate: compare the final length against the average (non-stopped) warmup length
avg_warmup_tokens = total_warmup_tokens / warmup_traces
potential_savings = avg_warmup_tokens - final_tokens
if potential_savings > 0:
print("\nToken savings from early stopping:")
print(f" - Average warmup length: {avg_warmup_tokens:.1f} tokens")
print(f" - Final length: {final_tokens} tokens")
print(f" - Saved: {potential_savings:.1f} tokens ({potential_savings / avg_warmup_tokens * 100:.1f}%)")
print("\n" + "=" * 80)
print("Example complete!")
print("=" * 80)
if __name__ == "__main__":
# Example 1: Simple math problem
print("\n\n" + "β–ˆ" * 80)
print("EXAMPLE 1: Simple Math Problem")
print("β–ˆ" * 80)
run_online_mode_example(
question="What is 15 * 8? Show your work step by step.",
ground_truth="120",
warmup_traces=4,
confidence_variant="low",
window_size=5,
max_tokens=64,
)
# Example 2: Square root problem
print("\n\n" + "β–ˆ" * 80)
print("EXAMPLE 2: Square Root Problem")
print("β–ˆ" * 80)
run_online_mode_example(
question="What is the square root of 144? Express your answer in the form \\boxed{answer}.",
ground_truth="12",
warmup_traces=4,
confidence_variant="high",
window_size=5,
max_tokens=64,
)
# Example 3: Word problem
print("\n\n" + "β–ˆ" * 80)
print("EXAMPLE 3: Word Problem")
print("β–ˆ" * 80)
run_online_mode_example(
question="If a train travels 60 miles per hour for 2.5 hours, how far does it travel?",
ground_truth="150",
warmup_traces=4,
confidence_variant="low",
window_size=5,
max_tokens=96,
)