"""
Simple examples showing DeepConf sample generations
"""

import functools

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig


@functools.lru_cache(maxsize=1)
def _load_model_and_tokenizer(model_name: str):
    """Load the causal LM and its tokenizer from local files, caching the pair.

    The cache lets repeated ``generate_with_deepconf`` calls in one process
    reuse the already-loaded weights instead of reloading them every time.
    """
    model = AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype=torch.float16, device_map="auto", local_files_only=True
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name, local_files_only=True)
    return model, tokenizer


def _confidence_stats(confidences):
    """Return ``(min, max, mean)`` of *confidences* as floats, or three ``None``s.

    ``None`` is returned when the generate output carries no confidences or
    the tensor is empty (reducing an empty tensor would raise).
    """
    if confidences is None or confidences.numel() == 0:
        return None, None, None
    return (
        confidences.min().item(),
        confidences.max().item(),
        confidences.mean().item(),
    )


def generate_with_deepconf(
    question: str,
    enable_early_stopping: bool = True,
    threshold: float = 10.0,
    window_size: int = 10,
    max_tokens: int = 128,
) -> dict:
    """Generate an answer to *question* with DeepConf and return a summary.

    Args:
        question: User message; wrapped in the tokenizer's chat template.
        enable_early_stopping: Forwarded to ``GenerationConfig`` under the
            same name.
        threshold: Forwarded to ``GenerationConfig`` as ``threshold``.
        window_size: Forwarded to ``GenerationConfig`` as ``window_size``.
        max_tokens: Upper bound on newly generated tokens
            (``max_new_tokens``).

    The DeepConf-specific options (``enable_conf``, ``enable_early_stopping``,
    ``threshold``, ``window_size``, ``output_confidences``) are interpreted by
    the ``kashif/DeepConf`` custom generate code, which is not part of this
    file.

    Returns:
        A dict with keys:
        ``text`` - the decoded full sequence (prompt included) with special
        tokens skipped;
        ``tokens`` - number of tokens produced beyond the prompt;
        ``min_conf`` / ``max_conf`` / ``mean_conf`` - floats summarising
        ``outputs.confidences``, or ``None`` when no confidences are returned.
    """
    model, tokenizer = _load_model_and_tokenizer("Qwen/Qwen2.5-0.5B-Instruct")

    # Build a single-turn chat prompt ending with the assistant generation marker.
    messages = [{"role": "user", "content": question}]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    gen_config = GenerationConfig(
        do_sample=True,
        temperature=0.7,
        top_p=0.95,
        max_new_tokens=max_tokens,
        enable_conf=True,
        enable_early_stopping=enable_early_stopping,
        threshold=threshold,
        window_size=window_size,
        output_confidences=True,
        return_dict_in_generate=True,
        pad_token_id=tokenizer.eos_token_id,
    )

    outputs = model.generate(
        **inputs,
        generation_config=gen_config,
        custom_generate="kashif/DeepConf",
        trust_remote_code=True,
    )

    # The decoded text covers the whole sequence (prompt + completion), while
    # the token count below covers only the newly generated part.
    generated_text = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
    tokens_generated = outputs.sequences.shape[1] - inputs.input_ids.shape[1]

    min_conf, max_conf, mean_conf = _confidence_stats(getattr(outputs, "confidences", None))

    return {
        "text": generated_text,
        "tokens": tokens_generated,
        "min_conf": min_conf,
        "max_conf": max_conf,
        "mean_conf": mean_conf,
    }


def print_result(title: str, question: str, result: dict) -> None:
    """Pretty-print one generation result to stdout.

    Args:
        title: Heading printed between two 80-character ``=`` rules.
        question: The question, echoed after ``Question:``.
        result: Mapping with the keys ``text``, ``tokens``, ``min_conf``,
            ``max_conf`` and ``mean_conf``. The confidence section is printed
            only when ``min_conf`` is not ``None``; the three values are then
            formatted with three decimals.
    """
    heavy_rule = "=" * 80
    light_rule = "-" * 80

    print(f"\n{heavy_rule}")
    print(title)
    print(heavy_rule)
    print(f"Question: {question}")
    print(f"\nGenerated ({result['tokens']} tokens):")
    print(light_rule)
    print(result["text"])
    print(light_rule)

    if result["min_conf"] is not None:
        print("\nConfidence stats:")
        print(f" Min: {result['min_conf']:.3f}")
        print(f" Max: {result['max_conf']:.3f}")
        print(f" Mean: {result['mean_conf']:.3f}")


if __name__ == "__main__":
    # Demo driver: runs every example below and prints its result.
    # Each entry is (title, question, keyword arguments for generate_with_deepconf).
    # Example 3 deliberately omits threshold/window_size so the function
    # defaults apply.
    examples = [
        (
            "Example 1: Math (Aggressive Early Stopping)",
            "What is 25 * 4?",
            {"enable_early_stopping": True, "threshold": 8.0, "window_size": 5, "max_tokens": 64},
        ),
        (
            "Example 2: Math (Permissive Early Stopping)",
            "What is 25 * 4?",
            {"enable_early_stopping": True, "threshold": 15.0, "window_size": 5, "max_tokens": 64},
        ),
        (
            "Example 3: Math (No Early Stopping)",
            "What is 25 * 4?",
            {"enable_early_stopping": False, "max_tokens": 64},
        ),
        (
            "Example 4: Word Problem",
            "If 5 apples cost $10, how much do 3 apples cost?",
            {"enable_early_stopping": True, "threshold": 8.0, "window_size": 5, "max_tokens": 96},
        ),
        (
            "Example 5: Factual Question",
            "Who wrote Romeo and Juliet?",
            {"enable_early_stopping": True, "threshold": 6.0, "window_size": 5, "max_tokens": 64},
        ),
        (
            "Example 6: Calculation",
            # The multiplication sign had been mis-encoded as a Greek capital gamma.
            "Calculate: (15 + 8) × 2",
            {"enable_early_stopping": True, "threshold": 7.0, "window_size": 5, "max_tokens": 96},
        ),
        (
            "Example 7: Definition",
            "Define photosynthesis in simple terms.",
            {"enable_early_stopping": True, "threshold": 10.0, "window_size": 10, "max_tokens": 128},
        ),
        (
            "Example 8: Step-by-step Solution",
            "Solve: x + 5 = 12. Show your steps.",
            {"enable_early_stopping": True, "threshold": 8.0, "window_size": 5, "max_tokens": 96},
        ),
    ]

    # Banner glyph restored from a mis-encoded character (it had decayed to a
    # Greek beta). NOTE(review): "█" is the presumed original - confirm.
    banner = "█" * 80

    print("\n" + banner)
    print("DEEPCONF SAMPLE GENERATIONS")
    print(banner)

    for title, question, overrides in examples:
        result = generate_with_deepconf(question, **overrides)
        print_result(title, question, result)

    print(f"\n{banner}")
    print("ALL EXAMPLES COMPLETE")
    print(banner)
    # NOTE(review): the threshold observations below are not verified here; the
    # threshold semantics live in the DeepConf custom generate code. Generation
    # may presumably also end at EOS before max_tokens - confirm the wording.
    print("\nKey observations:")
    print("- Lower threshold → Earlier stopping (fewer tokens)")
    print("- Higher threshold → Later stopping (more tokens)")
    print("- No early stopping → Always generates max_tokens")
    print("- Confidence varies based on model certainty")