"""
Simple examples showing DeepConf sample generations
"""
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
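
# NOTE: "kashif/DeepConf" below is a custom decoding loop fetched from the
# Hugging Face Hub via `custom_generate` (hence trust_remote_code=True), so a
# transformers release with custom-generate support is assumed. DeepConf
# attaches a confidence score to each generated token; with early stopping
# enabled, decoding halts once the score, tracked over a sliding window of
# `window_size` tokens, crosses `threshold`. Per the summary printed at the
# end of this script, lower thresholds stop generation sooner.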


def generate_with_deepconf(
    question: str,
    enable_early_stopping: bool = True,
    threshold: float = 10.0,
    window_size: int = 10,
    max_tokens: int = 128,
):
"""Generate with DeepConf and show results"""
# Load model (cached)
model_name = "Qwen/Qwen2.5-0.5B-Instruct"
model = AutoModelForCausalLM.from_pretrained(
model_name, torch_dtype=torch.float16, device_map="auto", local_files_only=True
)
tokenizer = AutoTokenizer.from_pretrained(model_name, local_files_only=True)

    # Prepare prompt
    messages = [{"role": "user", "content": question}]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Configure generation
    gen_config = GenerationConfig(
        do_sample=True,
        temperature=0.7,
        top_p=0.95,
        max_new_tokens=max_tokens,
        enable_conf=True,  # DeepConf: track per-token confidence scores
        enable_early_stopping=enable_early_stopping,
        threshold=threshold,  # early-stopping threshold (lower stops sooner)
        window_size=window_size,  # tokens per sliding confidence window
        output_confidences=True,  # return the per-token confidences
        return_dict_in_generate=True,
        pad_token_id=tokenizer.eos_token_id,
    )

    # Generate (custom_generate pulls the DeepConf decoding loop from the Hub)
    outputs = model.generate(
        **inputs,
        generation_config=gen_config,
        custom_generate="kashif/DeepConf",
        trust_remote_code=True,
    )

    # Extract results
    generated_text = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
    tokens_generated = outputs.sequences.shape[1] - inputs.input_ids.shape[1]
if hasattr(outputs, "confidences") and outputs.confidences is not None:
min_conf = outputs.confidences.min().item()
max_conf = outputs.confidences.max().item()
mean_conf = outputs.confidences.mean().item()
else:
min_conf = max_conf = mean_conf = None

    return {
        "text": generated_text,
        "tokens": tokens_generated,
        "min_conf": min_conf,
        "max_conf": max_conf,
        "mean_conf": mean_conf,
    }


def print_result(title: str, question: str, result: dict):
    """Pretty-print a generation result"""
    print(f"\n{'=' * 80}")
    print(title)
    print("=" * 80)
    print(f"Question: {question}")
    print(f"\nGenerated ({result['tokens']} tokens):")
    print("-" * 80)
    print(result["text"])
    print("-" * 80)
    if result["min_conf"] is not None:
        print("\nConfidence stats:")
        print(f"  Min:  {result['min_conf']:.3f}")
        print(f"  Max:  {result['max_conf']:.3f}")
        print(f"  Mean: {result['mean_conf']:.3f}")
if __name__ == "__main__":
print("\n" + "β–ˆ" * 80)
print("DEEPCONF SAMPLE GENERATIONS")
print("β–ˆ" * 80)
# Example 1: Math with aggressive early stopping
result = generate_with_deepconf(
"What is 25 * 4?", enable_early_stopping=True, threshold=8.0, window_size=5, max_tokens=64
)
print_result("Example 1: Math (Aggressive Early Stopping)", "What is 25 * 4?", result)
# Example 2: Math with permissive early stopping
result = generate_with_deepconf(
"What is 25 * 4?", enable_early_stopping=True, threshold=15.0, window_size=5, max_tokens=64
)
print_result("Example 2: Math (Permissive Early Stopping)", "What is 25 * 4?", result)
# Example 3: Math without early stopping
result = generate_with_deepconf("What is 25 * 4?", enable_early_stopping=False, max_tokens=64)
print_result("Example 3: Math (No Early Stopping)", "What is 25 * 4?", result)
# Example 4: Reasoning question
result = generate_with_deepconf(
"If 5 apples cost $10, how much do 3 apples cost?",
enable_early_stopping=True,
threshold=8.0,
window_size=5,
max_tokens=96,
)
print_result("Example 4: Word Problem", "If 5 apples cost $10, how much do 3 apples cost?", result)
# Example 5: Factual question
result = generate_with_deepconf(
"Who wrote Romeo and Juliet?", enable_early_stopping=True, threshold=6.0, window_size=5, max_tokens=64
)
print_result("Example 5: Factual Question", "Who wrote Romeo and Juliet?", result)
# Example 6: Calculation
result = generate_with_deepconf(
"Calculate: (15 + 8) Γ— 2", enable_early_stopping=True, threshold=7.0, window_size=5, max_tokens=96
)
print_result("Example 6: Calculation", "Calculate: (15 + 8) Γ— 2", result)
# Example 7: Definition
result = generate_with_deepconf(
"Define photosynthesis in simple terms.",
enable_early_stopping=True,
threshold=10.0,
window_size=10,
max_tokens=128,
)
print_result("Example 7: Definition", "Define photosynthesis in simple terms.", result)
# Example 8: Step-by-step
result = generate_with_deepconf(
"Solve: x + 5 = 12. Show your steps.", enable_early_stopping=True, threshold=8.0, window_size=5, max_tokens=96
)
print_result("Example 8: Step-by-step Solution", "Solve: x + 5 = 12. Show your steps.", result)
print(f"\n{'β–ˆ' * 80}")
print("ALL EXAMPLES COMPLETE")
print("β–ˆ" * 80)
print("\nKey observations:")
print("- Lower threshold β†’ Earlier stopping (fewer tokens)")
print("- Higher threshold β†’ Later stopping (more tokens)")
print("- No early stopping β†’ Always generates max_tokens")
print("- Confidence varies based on model certainty")