import re

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

MODEL_ID = "gokaygokay/prompt-enhancer-gemma-3-270m-it"

# Use the first CUDA GPU if one is available; device=-1 runs the pipeline on CPU.
device = 0 if torch.cuda.is_available() else -1
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID)

pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    device=device,
)
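# Optional smoke test, a minimal sketch (illustrative prompt; the output wording
# will vary from run to run). Uncomment to confirm the model generates before the UI starts:
# print(pipe("a cat sitting on a chair", max_new_tokens=32)[0]["generated_text"])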
def enhance_prompt(user_prompt, temperature, max_tokens, chat_history):
    """Enhance a short prompt and append the exchange to the chat history."""
    chat_history = chat_history or []

    messages = [
        {"role": "system", "content": "Enhance and expand the following prompt with more details and context:"},
        {"role": "user", "content": user_prompt},
    ]

    # Render the messages with the model's chat template, leaving the assistant
    # turn open (add_generation_prompt=True) so the model continues from there.
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    output = pipe(
        prompt,
        max_new_tokens=int(max_tokens),
        temperature=float(temperature),
        do_sample=True,
        return_full_text=False,  # keep only the newly generated text, not the prompt echo
    )[0]["generated_text"].strip()

    chat_history.append({"role": "user", "content": user_prompt})
    chat_history.append({"role": "assistant", "content": output})

    return chat_history
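# Quick check outside the UI (hypothetical call, for illustration only):
#   history = enhance_prompt("a cat sitting on a chair", temperature=0.7, max_tokens=128, chat_history=[])
#   print(history[-1]["content"])  # the enhanced prompt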
def extract_later_part(user_prompt, generated_text):
    """
    Cleans the model output and extracts only the enhanced (later) portion.
    Removes prompt echoes and template tags like <end_of_turn>, <start_of_turn>, etc.
    """
    # Strip chat-template tags, then normalize whitespace to single spaces.
    cleaned = re.sub(r"<.*?>", "", generated_text)
    cleaned = cleaned.strip()
    cleaned = re.sub(r"\s+", " ", cleaned)

    # If the model echoed the user's prompt, drop that prefix (case-insensitive match).
    user_prompt_clean = user_prompt.strip().lower()
    cleaned_lower = cleaned.lower()

    if cleaned_lower.startswith(user_prompt_clean):
        cleaned = cleaned[len(user_prompt_clean):].strip(",. ").strip()

    return cleaned
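# Example behavior (hypothetical strings, for illustration only):
#   extract_later_part("a cat", "a cat, a fluffy tabby dozing on a sunlit chair<end_of_turn>")
#   -> "a fluffy tabby dozing on a sunlit chair"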
def enhance_prompt1(user_prompt, temperature, max_tokens, chat_history):
    """Like enhance_prompt, but cleans the full output with extract_later_part."""
    messages = [
        {"role": "system", "content": "Enhance and expand the following prompt with more details and context:"},
        {"role": "user", "content": user_prompt},
    ]

    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    output = pipe(
        prompt,
        max_new_tokens=int(max_tokens),
        temperature=float(temperature),
        do_sample=True,
    )
    raw_output = output[0]["generated_text"]

    print("=== RAW MODEL OUTPUT ===")
    print(raw_output)

    later_part = extract_later_part(user_prompt, raw_output)
    print("=== EXTRACTED CLEANED OUTPUT ===")
    print(later_part)

    chat_history = chat_history or []
    chat_history.append({"role": "user", "content": user_prompt})
    chat_history.append({"role": "assistant", "content": later_part})

    return chat_history
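# Note: enhance_prompt avoids prompt echoes via return_full_text=False, while
# enhance_prompt1 keeps the full generated text and strips echoes/tags in Python
# with extract_later_part, which is handy when inspecting the raw model output.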
with gr.Blocks(title="Prompt Enhancer – Gemma 3 270M", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # ✨ Prompt Enhancer (Gemma 3 270M)
        Enter a short prompt, and the model will **expand it with details and creative context**
        using the Gemma chat-template interface.
        """
    )

    with gr.Row():
        chatbot = gr.Chatbot(height=400, label="Enhanced Prompts", type="messages")
        with gr.Column(scale=1):
            user_prompt = gr.Textbox(
                placeholder="Enter a short prompt...",
                label="Your Prompt",
                lines=3,
            )
            temperature = gr.Slider(0.0, 1.0, value=0.7, step=0.05, label="Temperature")
            max_tokens = gr.Slider(32, 256, value=128, step=16, label="Max Tokens")
            send_btn = gr.Button("🚀 Enhance Prompt", variant="primary")
            clear_btn = gr.Button("🧹 Clear Chat")
            add_btn = gr.Button("🔍 Enhance (log raw output)")
    # Enter and the primary button run enhance_prompt; the debug button runs
    # enhance_prompt1, which also prints raw and cleaned output to the console.
    user_prompt.submit(enhance_prompt, [user_prompt, temperature, max_tokens, chatbot], chatbot)
    send_btn.click(enhance_prompt, [user_prompt, temperature, max_tokens, chatbot], chatbot)
    clear_btn.click(lambda: [], None, chatbot)
    add_btn.click(enhance_prompt1, [user_prompt, temperature, max_tokens, chatbot], chatbot)
    gr.Markdown(
        """
        ---
        💡 **Tips:**
        - Works best with short, descriptive prompts (e.g., "a cat sitting on a chair")
        - Increase *Temperature* for more creative output.
        """
    )
if __name__ == "__main__":
    demo.launch(show_error=True)