import gradio as gr

from flashpack.integrations.transformers import FlashPackTransformersModelMixin
from transformers import AutoTokenizer, Gemma3ForCausalLM, pipeline as hf_pipeline
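# FlashPackTransformersModelMixin contributes the FlashPack load/save hooks used
# below (from_pretrained_flashpack / save_pretrained_flashpack) when mixed into a
# transformers model class.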


# Subclass the concrete Gemma 3 causal-LM class rather than AutoModelForCausalLM:
# the Auto classes are factories whose from_pretrained() returns the underlying
# architecture class, so a mixin applied to an Auto subclass would be silently
# dropped from the loaded model.
class FlashPackGemmaModel(Gemma3ForCausalLM, FlashPackTransformersModelMixin):
    """Gemma 3 model wrapped with FlashPackTransformersModelMixin."""


MODEL_ID = "gokaygokay/prompt-enhancer-gemma-3-270m-it"
FLASHPACK_REPO = "rahul7star/FlashPack"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# Prefer the packed FlashPack weights; on the first run, fall back to the HF Hub
# checkpoint and publish a FlashPack copy for faster subsequent loads.
try:
    print("📂 Loading model from FlashPack repository...")
    model = FlashPackGemmaModel.from_pretrained_flashpack(FLASHPACK_REPO)
except FileNotFoundError:
    print("⚠️ FlashPack model not found. Loading from HF Hub and uploading FlashPack...")
    model = FlashPackGemmaModel.from_pretrained(MODEL_ID)
    model.save_pretrained_flashpack(FLASHPACK_REPO, push_to_hub=True)
    print(f"✅ FlashPack model uploaded to Hugging Face Hub: {FLASHPACK_REPO}")
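# Note: this assumes from_pretrained_flashpack raises FileNotFoundError when the
# repo holds no FlashPack payload; broaden the except clause if the library
# signals a missing pack differently.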
# Wrap the already-loaded model and tokenizer in a standard text-generation pipeline.
pipe = hf_pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    device_map="auto",
)
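# Caveat: `device_map="auto"` mainly matters when the pipeline loads the model
# itself; with a pre-instantiated model, device placement is effectively decided
# at load time above.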


def enhance_prompt(user_prompt, temperature, max_tokens, chat_history):
    """Expand a short prompt into a detailed one and append the exchange to the chat."""
    chat_history = chat_history or []

    messages = [
        {"role": "system", "content": "Enhance and expand the following prompt with more details and context:"},
        {"role": "user", "content": user_prompt},
    ]

    # Render the conversation with Gemma's chat template, leaving the reply open.
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    outputs = pipe(
        prompt,
        max_new_tokens=int(max_tokens),
        temperature=float(temperature),
        do_sample=True,
        return_full_text=False,  # return only the completion, not the templated prompt
    )
    enhanced = outputs[0]["generated_text"].strip()

    chat_history.append({"role": "user", "content": user_prompt})
    chat_history.append({"role": "assistant", "content": enhanced})
    return chat_history
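# The returned history uses the "messages" format that gr.Chatbot(type="messages")
# expects, e.g.:
#   [{"role": "user", "content": "a cat"},
#    {"role": "assistant", "content": "A fluffy tabby cat lounging on ..."}]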


with gr.Blocks(title="Prompt Enhancer – Gemma 3 270M", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # ✨ Prompt Enhancer (Gemma 3 270M)
        Enter a short prompt, and the model will **expand it with details and creative context**
        using the Gemma chat-template interface.
        """
    )
    with gr.Row():
        chatbot = gr.Chatbot(height=400, label="Enhanced Prompts", type="messages")
        with gr.Column(scale=1):
            user_prompt = gr.Textbox(
                placeholder="Enter a short prompt...",
                label="Your Prompt",
                lines=3,
            )
            temperature = gr.Slider(0.0, 1.0, value=0.7, step=0.05, label="Temperature")
            max_tokens = gr.Slider(32, 256, value=128, step=16, label="Max Tokens")
            send_btn = gr.Button("🚀 Enhance Prompt", variant="primary")
            clear_btn = gr.Button("🧹 Clear Chat")
    send_btn.click(enhance_prompt, [user_prompt, temperature, max_tokens, chatbot], chatbot)
    user_prompt.submit(enhance_prompt, [user_prompt, temperature, max_tokens, chatbot], chatbot)
    clear_btn.click(lambda: [], None, chatbot)
    gr.Markdown(
        """
        ---
        💡 **Tips:**
        - Works best with short, descriptive prompts (e.g., "a cat sitting on a chair").
        - Increase *Temperature* for more creative output.
        """
    )


if __name__ == "__main__":
    demo.launch(show_error=True)
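# Local usage sketch (package names are assumptions; FlashPack in particular may
# ship under a different pip name):
#   pip install gradio transformers flashpack
#   python app.py  # or whatever this file is saved as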