"""
Example usage of Online mode with warmup

This demonstrates:
1. Warmup phase (generate N sequences to calibrate threshold)
2. Threshold computation (DeepConf-low or DeepConf-high)
3. Final generation with calibrated early stopping
"""

from typing import Optional

import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig


def extract_answer(text: str) -> Optional[str]:
    """
    Extract boxed answer from LaTeX text

    Looks for \\boxed{answer} pattern in generated text.
    """
    if "boxed" in text:
        ans = text.split("boxed")[-1]
        if len(ans) == 0:
            return ""
        elif ans[0] == "{":
            stack = 1
            a = ""
            for c in ans[1:]:
                if c == "{":
                    stack += 1
                    a += c
                elif c == "}":
                    stack -= 1
                    if stack == 0:
                        break
                    a += c
                else:
                    a += c
        else:
            a = ans.split("$")[0].strip()
        return a.strip()

    return None


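# Illustrative behaviour (toy strings, not taken from a real generation):
#   extract_answer(r"The result is \boxed{120}.") -> "120"
#   extract_answer(r"\boxed{\frac{1}{2}}")        -> "\frac{1}{2}"  (nested braces kept)
#   extract_answer("no final answer here")        -> None

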
def compute_least_grouped(confs: list, group_size: int) -> list:
    """
    Compute sliding window mean confidence

    Args:
        confs: List of per-token confidence values
        group_size: Size of sliding window

    Returns:
        List of mean confidences for each window position
    """
    if len(confs) < group_size:
        return [sum(confs) / len(confs)] if confs else [0]

    sliding_means = []
    for i in range(len(confs) - group_size + 1):
        window = confs[i : i + group_size]
        sliding_means.append(round(sum(window) / len(window), 3))
    return sliding_means


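# Illustrative behaviour (toy numbers): with confs = [0.9, 0.8, 0.1, 0.7] and
# group_size = 2 the windows are [0.9, 0.8], [0.8, 0.1], [0.1, 0.7], so the
# function returns [0.85, 0.45, 0.4]; the minimum (0.4) is the per-trace value
# that the online mode later compares against the calibrated threshold.

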
def process_single_output(
    sequence, confidences, tokenizer, window_size: int, threshold: Optional[float] = None
) -> dict:
    """
    Process a single generated sequence

    Args:
        sequence: Generated token IDs
        confidences: Per-token confidence values (list or tensor)
        tokenizer: Tokenizer for decoding
        window_size: Size of sliding window for confidence
        threshold: Optional threshold for early stopping detection

    Returns:
        Dictionary with trace data
    """
    if hasattr(confidences, "tolist"):
        confs = confidences.tolist()
    else:
        confs = list(confidences)

    text = tokenizer.decode(sequence, skip_special_tokens=True)

    sliding_window = compute_least_grouped(confs, window_size)
    min_conf = min(sliding_window) if sliding_window else 0

    stopped_early = False
    stop_position = None
    if threshold is not None:
        for pos, window_mean in enumerate(sliding_window):
            if window_mean < threshold:
                stopped_early = True
                stop_position = pos + window_size
                break

    extracted_answer = extract_answer(text)

    return {
        "text": text,
        "confs": confs,
        "group_confs": sliding_window,
        "min_conf": min_conf,
        "stopped_early": stopped_early,
        "stop_position": stop_position,
        "extracted_answer": extracted_answer,
        "num_tokens": len(confs),
        "token_ids": sequence.tolist() if hasattr(sequence, "tolist") else list(sequence),
    }


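# Illustrative behaviour (hypothetical values): with window_size=5 and
# threshold=0.7, a trace whose sliding-window means are [0.9, 0.8, 0.6, ...]
# is reported with stopped_early=True and stop_position=7 (window index 2 +
# window_size), i.e. roughly where online early stopping would have cut it off.

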
def process_batch_results(outputs, tokenizer, window_size: int = 2048, threshold: Optional[float] = None) -> dict:
    """
    Process batch generation outputs

    This function provides post-processing capabilities for batch-generated
    sequences, allowing analysis of confidence patterns and early stopping
    behavior after generation is complete.

    Args:
        outputs: GenerateDecoderOnlyOutput from model.generate()
        tokenizer: Tokenizer for decoding sequences
        window_size: Size of sliding window for confidence computation
        threshold: Optional threshold for detecting where early stopping would occur

    Returns:
        Dictionary containing:
            - traces: List of processed trace dictionaries
            - min_confs: List of minimum confidences per trace
            - total_tokens: Total tokens across all traces
            - num_traces: Number of traces processed
    """
    if not hasattr(outputs, "sequences"):
        raise ValueError("outputs must have 'sequences' attribute")

    if not hasattr(outputs, "confidences") or outputs.confidences is None:
        raise ValueError("outputs must have 'confidences' attribute. Set output_confidences=True in generation_config")

    sequences = outputs.sequences
    confidences = outputs.confidences

    traces = []
    min_confs = []
    total_tokens = 0

    for i in range(sequences.shape[0]):
        trace_data = process_single_output(sequences[i], confidences[i], tokenizer, window_size, threshold)

        traces.append(trace_data)
        min_confs.append(trace_data["min_conf"])
        total_tokens += trace_data["num_tokens"]

    return {"traces": traces, "min_confs": min_confs, "total_tokens": total_tokens, "num_traces": len(traces)}


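# Minimal usage sketch (assumes `inputs`, `warmup_config`, `model`, and
# `tokenizer` are set up as in run_online_mode_example below):
#
#     out = model.generate(**inputs, generation_config=warmup_config,
#                          custom_generate="kashif/DeepConf", trust_remote_code=True)
#     results = process_batch_results(out, tokenizer, window_size=10)
#     print(results["min_confs"], results["total_tokens"])

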
def compute_warmup_threshold(min_confs: list, variant: str = "low", eta: Optional[float] = None) -> float:
    """
    Compute threshold from warmup confidences

    Args:
        min_confs: List of minimum confidences from warmup sequences
        variant: "low" (aggressive) or "high" (permissive)
        eta: Optional manual eta value (overrides variant default)

    Returns:
        Computed threshold value
    """
    if eta is None:
        eta = 0.1 if variant == "low" else 0.9 if variant == "high" else 0.5

    confs = np.asarray(min_confs, dtype=np.float32)
    pct = max(0.0, min(100.0, 100.0 - (eta * 100.0)))
    threshold = float(np.percentile(confs, pct))

    return threshold


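# Worked example (toy numbers): for warmup minima [0.2, 0.4, 0.6, 0.8],
# variant="low" (eta=0.1) takes the 90th percentile, giving a threshold of
# roughly 0.74, so most later traces would stop early; variant="high"
# (eta=0.9) takes the 10th percentile (roughly 0.26), which stops far fewer.

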
def prepare_prompt(question: str, tokenizer):
    """Prepare prompt using chat template"""
    messages = [{"role": "user", "content": question}]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    return prompt


def run_online_mode_example(
    question: str,
    ground_truth: Optional[str] = None,
    warmup_traces: int = 8,
    confidence_variant: str = "low",
    window_size: int = 10,
    max_tokens: int = 128,
    temperature: float = 0.7,
    top_p: float = 0.95,
):
    """
    Run DeepConf in online mode

    Args:
        question: Question to answer
        ground_truth: Optional ground truth answer for evaluation
        warmup_traces: Number of warmup sequences (default: 8)
        confidence_variant: "low" (aggressive) or "high" (permissive)
        window_size: Sliding window size for confidence
        max_tokens: Max tokens per generation
        temperature: Sampling temperature
        top_p: Top-p sampling
    """
    model_name = "Qwen/Qwen2.5-0.5B-Instruct"
    print(f"Loading model: {model_name}")
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        device_map="auto",
        local_files_only=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name, local_files_only=True)

    prompt = prepare_prompt(question, tokenizer)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

print("\n" + "=" * 80) |
|
|
print("DEEPCONF ONLINE MODE - FOLLOWING OFFICIAL PATTERN") |
|
|
print("=" * 80) |
|
|
print(f"\nQuestion: {question}") |
|
|
if ground_truth: |
|
|
print(f"Ground truth: {ground_truth}") |
|
|
print("\nConfiguration:") |
|
|
print(f" - Warmup traces: {warmup_traces}") |
|
|
print(f" - Variant: DeepConf-{confidence_variant}") |
|
|
print(f" - Window size: {window_size}") |
|
|
print(f" - Max tokens: {max_tokens}") |
|
|
print(f" - Temperature: {temperature}") |
|
|
print(f" - Top-p: {top_p}") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
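    # Phase 1: generate full-length warmup traces with early stopping disabled,
    # so each trace's minimum sliding-window confidence can be measured and used
    # to calibrate the stopping threshold.
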
print("\n" + "=" * 80) |
|
|
print(f"PHASE 1: WARMUP (Generating {warmup_traces} sequences for calibration)") |
|
|
print("=" * 80) |
|
|
|
|
|
warmup_config = GenerationConfig( |
|
|
do_sample=True, |
|
|
temperature=temperature, |
|
|
top_p=top_p, |
|
|
max_new_tokens=max_tokens, |
|
|
enable_conf=True, |
|
|
enable_early_stopping=False, |
|
|
output_confidences=True, |
|
|
return_dict_in_generate=True, |
|
|
pad_token_id=tokenizer.eos_token_id, |
|
|
) |
|
|
|
|
|
|
|
|
expanded_ids = inputs.input_ids.repeat(warmup_traces, 1) |
|
|
if "attention_mask" in inputs and inputs.attention_mask is not None: |
|
|
expanded_mask = inputs.attention_mask.repeat(warmup_traces, 1) |
|
|
else: |
|
|
expanded_mask = None |
|
|
|
|
|
print(f"Generating {warmup_traces} warmup sequences...") |
|
|
warmup_outputs = model.generate( |
|
|
input_ids=expanded_ids, |
|
|
attention_mask=expanded_mask, |
|
|
generation_config=warmup_config, |
|
|
custom_generate="kashif/DeepConf", |
|
|
trust_remote_code=True, |
|
|
) |
|
|
|
|
|
|
|
|
    warmup_results = process_batch_results(warmup_outputs, tokenizer, window_size=window_size)

    print("\nWarmup complete!")
    print(f" - Total tokens: {warmup_results['total_tokens']}")
    print(f" - Min confidences: {[round(c, 3) for c in warmup_results['min_confs']]}")

    print("\nWarmup Traces:")
    print("-" * 80)
    for i, trace in enumerate(warmup_results["traces"]):
        # Strip the prompt prefix (approximate: assumes the decoded text starts
        # with the prompt string).
        text = trace["text"][len(prompt) :].strip()
        answer = extract_answer(text)
        print(f"\nTrace {i + 1}:")
        print(f" Tokens: {trace['num_tokens']}, Min conf: {trace['min_conf']:.3f}")
        print(f" Text: {text[:80]}..." if len(text) > 80 else f" Text: {text}")
        if answer:
            print(f" Answer: {answer}")
            if ground_truth:
                correct = answer.strip() == ground_truth.strip()
                print(f" Correct: {'✓' if correct else '✗'}")

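    # Phase 2: the threshold is the (1 - eta) * 100th percentile of the warmup
    # minimum confidences; "low" uses eta=0.1 (90th percentile, aggressive) and
    # "high" uses eta=0.9 (10th percentile, permissive).
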
print("\n" + "=" * 80) |
|
|
print("PHASE 2: THRESHOLD COMPUTATION") |
|
|
print("=" * 80) |
|
|
|
|
|
threshold = compute_warmup_threshold(warmup_results["min_confs"], variant=confidence_variant) |
|
|
|
|
|
eta = 0.1 if confidence_variant == "low" else 0.9 |
|
|
percentile = (1.0 - eta) * 100 |
|
|
|
|
|
print("\nComputed threshold from warmup:") |
|
|
print(f" - Variant: DeepConf-{confidence_variant} (eta={eta})") |
|
|
print(f" - Percentile: {percentile:.0f}th") |
|
|
print(f" - Threshold: {threshold:.3f}") |
|
|
print("\nInterpretation:") |
|
|
if confidence_variant == "low": |
|
|
print(" DeepConf-low is AGGRESSIVE - stops early to save tokens") |
|
|
else: |
|
|
print(" DeepConf-high is PERMISSIVE - allows longer generation") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
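    # Phase 3: regenerate with early stopping enabled; generation is intended to
    # stop as soon as the sliding-window mean confidence falls below the
    # calibrated threshold.
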
print("\n" + "=" * 80) |
|
|
print("PHASE 3: FINAL GENERATION (With calibrated early stopping)") |
|
|
print("=" * 80) |
|
|
|
|
|
final_config = GenerationConfig( |
|
|
do_sample=True, |
|
|
temperature=temperature, |
|
|
top_p=top_p, |
|
|
max_new_tokens=max_tokens, |
|
|
enable_conf=True, |
|
|
enable_early_stopping=True, |
|
|
threshold=threshold, |
|
|
window_size=window_size, |
|
|
output_confidences=True, |
|
|
return_dict_in_generate=True, |
|
|
pad_token_id=tokenizer.eos_token_id, |
|
|
) |
|
|
|
|
|
print(f"Generating with DeepConf-{confidence_variant} (threshold={threshold:.3f})...") |
|
|
final_output = model.generate( |
|
|
**inputs, |
|
|
generation_config=final_config, |
|
|
custom_generate="kashif/DeepConf", |
|
|
trust_remote_code=True, |
|
|
) |
|
|
|
|
|
final_text = tokenizer.decode(final_output.sequences[0], skip_special_tokens=True) |
|
|
final_tokens = final_output.sequences.shape[1] - inputs.input_ids.shape[1] |
|
|
final_answer = extract_answer(final_text) |
|
|
|
|
|
|
|
|
    if hasattr(final_output, "confidences") and final_output.confidences is not None:
        min_conf = final_output.confidences.min().item()
        mean_conf = final_output.confidences.mean().item()
    else:
        min_conf = None
        mean_conf = None

    print("\nFinal generation complete!")
    print(f" - Tokens generated: {final_tokens}")
    if min_conf is not None:
        print(f" - Min confidence: {min_conf:.3f}")
        print(f" - Mean confidence: {mean_conf:.3f}")

    print("\nGenerated text:")
    print("-" * 80)
    print(final_text)
    print("-" * 80)

    if final_answer:
        print(f"\nExtracted answer: {final_answer}")
        if ground_truth:
            correct = final_answer.strip() == ground_truth.strip()
            print(f"Correct: {'✓' if correct else '✗'}")

print("\n" + "=" * 80) |
|
|
print("SUMMARY") |
|
|
print("=" * 80) |
|
|
|
|
|
total_warmup_tokens = warmup_results["total_tokens"] |
|
|
total_tokens = total_warmup_tokens + final_tokens |
|
|
|
|
|
print(f"Total tokens: {total_tokens}") |
|
|
print(f" - Warmup: {total_warmup_tokens} ({warmup_traces} sequences)") |
|
|
print(f" - Final: {final_tokens}") |
|
|
|
|
|
|
|
|
avg_warmup_tokens = total_warmup_tokens / warmup_traces |
|
|
potential_savings = avg_warmup_tokens - final_tokens |
|
|
if potential_savings > 0: |
|
|
print("\nToken savings from early stopping:") |
|
|
print(f" - Average warmup length: {avg_warmup_tokens:.1f} tokens") |
|
|
print(f" - Final length: {final_tokens} tokens") |
|
|
print(f" - Saved: {potential_savings:.1f} tokens ({potential_savings / avg_warmup_tokens * 100:.1f}%)") |
|
|
|
|
|
print("\n" + "=" * 80) |
|
|
print("Example complete!") |
|
|
print("=" * 80) |
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    print("\n\n" + "#" * 80)
    print("EXAMPLE 1: Simple Math Problem")
    print("#" * 80)

    run_online_mode_example(
        question="What is 15 * 8? Show your work step by step.",
        ground_truth="120",
        warmup_traces=4,
        confidence_variant="low",
        window_size=5,
        max_tokens=64,
    )

    print("\n\n" + "#" * 80)
    print("EXAMPLE 2: Square Root Problem")
    print("#" * 80)

    run_online_mode_example(
        question="What is the square root of 144? Express your answer in the form \\boxed{answer}.",
        ground_truth="12",
        warmup_traces=4,
        confidence_variant="high",
        window_size=5,
        max_tokens=64,
    )

    print("\n\n" + "#" * 80)
    print("EXAMPLE 3: Word Problem")
    print("#" * 80)

    run_online_mode_example(
        question="If a train travels 60 miles per hour for 2.5 hours, how far does it travel?",
        ground_truth="150",
        warmup_traces=4,
        confidence_variant="low",
        window_size=5,
        max_tokens=96,
    )