# Ator / app.py
# lakshay315's picture
# Create app.py
# c051ff2 verified
import os
import json
import random
import gradio as gr
import fal_client
from groq import Groq
# -------------------------
# Data
# -------------------------
# Few-shot style examples consumed through get_examples_for_category().
# Each row is (category, user prompt, visual layout, caption); the list
# comprehension below expands the rows into the nested dict shape the LLM
# helpers read: {"category", "prompt", "response": {"post", "caption"}}.
_DATASET_ROWS = [
    (
        "Engagement",
        "An informational post about cereals",
        "A two-image post explaining the difference between different types of Basmati rice...",
        "Know your product! 🌾 November is rice awareness month...",
    ),
    (
        "Engagement",
        "Generate a giveaway post on rice awareness month",
        "This post is a video of rice grains falling into a bowl...",
        "Comment 'Country Delight Basmati Rice' and WIN 🏆...",
    ),
    (
        "Product Launch",
        "A product launch post for Country Delight sweets on Diwali",
        "An image introducing Country Delight sweets for Diwali...",
        "5 lucky WINNERS will get a hamper each of SWEETS 🍬 🍭...",
    ),
    (
        "Festival",
        "A post to create engagement for Diwali",
        "A compilation video featuring diya lighting, colorful rangoli...",
        "Diwali is all about lights, love, and sweets! 🪔...",
    ),
    (
        "Product Launch",
        "A post launching Country Delight Aloo Bhujia Namkeen",
        "A video of a person playing with a packet of Country Delight Aloo Bhujia...",
        "Dear namkeen lovers, sounds 🔊 ON please!...",
    ),
]
dataset = [
    {
        "category": row_category,
        "prompt": row_prompt,
        "response": {"post": row_visual, "caption": row_caption},
    }
    for row_category, row_prompt, row_visual, row_caption in _DATASET_ROWS
]
def get_examples_for_category(category: str):
    """Return every entry of ``dataset`` whose "category" equals ``category``.

    Raises:
        gr.Error: when no entry matches the requested category.
    """
    matches = list(filter(lambda item: item.get("category") == category, dataset))
    if matches:
        return matches
    raise gr.Error(f"No examples found for category: {category}")
def safe_json_loads(s: str):
    """Parse model output as JSON, tolerating text around the JSON object.

    First tries ``json.loads`` on the whole input; if that fails, retries on the
    span from the first ``{`` to the last ``}`` (models sometimes wrap the JSON
    in prose or code fences).

    Args:
        s: Text to parse. ``bytes``/``bytearray`` are decoded as UTF-8. ``None``
            or any other non-text value yields ``None`` instead of raising.

    Returns:
        The decoded JSON value, or ``None`` if nothing parseable was found.
    """
    if isinstance(s, (bytes, bytearray)):
        s = s.decode("utf-8", errors="replace")
    if not isinstance(s, str):
        # e.g. a chat message whose content is None; "safe" must not raise.
        return None
    try:
        return json.loads(s)
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        pass
    start, end = s.find("{"), s.rfind("}")
    if start == -1 or end <= start:
        return None
    try:
        return json.loads(s[start:end + 1])
    except ValueError:
        return None
# -------------------------
# Keys / Clients
# -------------------------
# GROQ_API_KEY takes precedence; GROQ_API is consulted only when the first
# variable is unset or empty.
_groq_primary_key = os.getenv("GROQ_API_KEY")
GROQ_API_KEY = _groq_primary_key if _groq_primary_key else os.getenv("GROQ_API")
FAL_KEY = os.getenv("FAL_KEY")
if FAL_KEY:
    # Writes the value back to the same variable it was read from (kept as-is).
    os.environ["FAL_KEY"] = FAL_KEY
# Stays None without a key; the LLM helpers check for this and raise a
# user-facing gr.Error instead of calling Groq.
client = None
if GROQ_API_KEY:
    client = Groq(api_key=GROQ_API_KEY)
# -------------------------
# Guardrails + language script hints
# -------------------------
# Interpolated verbatim into the system prompts of generate_post and
# personalise_post.
BRAND_SAFE_GUARDRAILS = """
BRAND-SAFE GUARDRAILS (must follow):
- Do NOT make medical/health claims, cures, guaranteed outcomes, or absolute superlatives (e.g., “100% best”, “always”, “guaranteed”).
- Do NOT mention competitors negatively or comparisons like “better than X brand”.
- Avoid misinformation; keep claims generic and safe (e.g., “made with natural ingredients” only if implied, otherwise keep neutral).
- Avoid sensitive/political content, hate, stereotypes, or derogatory language.
- If unsure about a claim, rewrite it as a neutral, non-factual benefit (taste/experience/occasion/feeling).
"""
# Per-language instruction that pins the caption to its native script.
# Keys match the UI "Language" dropdown choices, in the same order.
SCRIPT_HINTS = dict(
    English="Use English only.",
    Hindi="Use Devanagari script (हिंदी). No Hinglish/romanized.",
    Tamil="Use Tamil script (தமிழ்). No English/romanized.",
    Bengali="Use Bengali script (বাংলা). No English/romanized.",
    Punjabi="Use Gurmukhi script (ਪੰਜਾਬੀ). No English/romanized.",
    Haryanvi="Use Devanagari script (हरियाणवी). No Hinglish/romanized.",
    Telugu="Use Telugu script (తెలుగు). No English/romanized.",
    Kannada="Use Kannada script (ಕನ್ನಡ). No English/romanized.",
    Gujarati="Use Gujarati script (ગુજરાતી). No English/romanized.",
    Rajasthani="Use Devanagari script (राजस्थानी). No Hinglish/romanized.",
)
# -------------------------
# LLM Functions
# -------------------------
def generate_post(prompt, category, temperature, max_tokens, model="mixtral-8x7b-32768"):
    """Draft an Instagram post idea (visual layout) and a base caption.

    Up to three examples of ``category`` are sampled from ``dataset`` and shown
    to the model as style references. The model is asked for a JSON object of
    the form ``{"output": {"post": ..., "caption": ...}}``.

    Args:
        prompt: Free-text description of the desired post.
        category: Category whose examples are used as style references.
        temperature: Sampling temperature (cast to float) sent to Groq.
        max_tokens: Completion token limit (cast to int) sent to Groq.
        model: Groq model id. NOTE(review): Groq retires model ids over time;
            confirm this default is still served.

    Returns:
        ``(post, caption)``. When the reply is not the expected JSON shape,
        the raw reply text is returned in both slots so the user still sees it.

    Raises:
        gr.Error: if no Groq client is configured or the category has no
            examples.
    """
    if not client:
        raise gr.Error("GROQ API key missing. Set GROQ_API_KEY (or GROQ_API).")
    examples = get_examples_for_category(category)
    few = random.sample(examples, k=min(3, len(examples)))
    style = "\n\n".join(
        f"- Visual: {x['response']['post']}\n Caption: {x['response']['caption']}"
        for x in few
    )
    completion = client.chat.completions.create(
        model=model,
        messages=[
            {
                "role": "system",
                "content": f"""
You are a marketing expert creating Instagram content for Country Delight (India).
Learn the style from these examples:
{style}
{BRAND_SAFE_GUARDRAILS}
Return STRICT JSON:
{{
"output": {{
"post": "Detailed visual layout (frames, overlays, props, setting)",
"caption": "Caption (can be English here as base draft) with emojis + CTA"
}}
}}
""",
            },
            {"role": "user", "content": f"category: {category}\nprompt: {prompt}"},
        ],
        temperature=float(temperature),
        max_tokens=int(max_tokens),
        response_format={"type": "json_object"},
    )
    # The message content may be None; normalise so parsing and the raw-text
    # fallback below always work with a string.
    raw = completion.choices[0].message.content or ""
    obj = safe_json_loads(raw)
    # Trust the payload only when it is a dict whose "output" is also a dict;
    # a list/str reply would otherwise break the membership test or `.get`.
    output = obj.get("output") if isinstance(obj, dict) else None
    if isinstance(output, dict):
        return output.get("post", raw), output.get("caption", raw)
    return raw, raw
def personalise_post(post, caption, age_group, demography, gender, language, region_city, temperature, max_tokens, model="gemma2-9b-it"):
    """Adapt a generated post and caption to a cohort, region and language.

    The model is told to add regional hooks and to write the caption only in
    ``language``, using the matching native-script hint from ``SCRIPT_HINTS``.
    It is asked for JSON of the form ``{"output": {"post": ..., "caption": ...}}``.

    Args:
        post: Visual layout from the generate step.
        caption: Base caption from the generate step.
        age_group: "Youth" / "Middle Aged" / "Old" (unknown values are passed
            through unchanged).
        demography: Audience demography label (e.g. "Urban"/"Rural").
        gender: Audience gender label.
        language: Target language for the caption.
        region_city: Optional city/region; "not provided" is sent when empty.
        temperature: Sampling temperature (cast to float) sent to Groq.
        max_tokens: Completion token limit (cast to int) sent to Groq.
        model: Groq model id. NOTE(review): Groq retires model ids over time;
            confirm this default is still served.

    Returns:
        ``(post, caption)``. When the reply is not the expected JSON shape,
        the raw reply text is returned in both slots.

    Raises:
        gr.Error: if no Groq client is configured, or if both ``post`` and
            ``caption`` are blank (nothing to personalise).
    """
    if not client:
        raise gr.Error("GROQ API key missing. Set GROQ_API_KEY (or GROQ_API).")
    # Avoid spending an API call on an empty input (e.g. step 1 was skipped).
    if not (post or "").strip() and not (caption or "").strip():
        raise gr.Error("Nothing to personalise yet. Generate a post first.")
    age_map = {"Youth": "youth", "Middle Aged": "middle aged", "Old": "old"}
    age_group_n = age_map.get(age_group, age_group)
    script_hint = SCRIPT_HINTS.get(language, "Use only the selected language in its native script. No romanization.")
    completion = client.chat.completions.create(
        model=model,
        messages=[
            {
                "role": "system",
                "content": f"""
You are an Indian marketing expert with deep regional cultural context.
TASK:
Personalize the given Instagram post idea + caption for the cohort and region.
- Add real regional hooks: local festivals/seasonal vibes, food habits, common phrases/idioms (in the target language), city cues if provided.
- Keep brand voice warm, modern, and engaging.
LANGUAGE REQUIREMENT (STRICT):
- Caption MUST be ONLY in: {language}.
- {script_hint}
- Do NOT mix English words except “Country Delight” (brand) and unavoidable product names.
- Do NOT use Hinglish/romanized.
{BRAND_SAFE_GUARDRAILS}
Return STRICT JSON:
{{
"output": {{
"post": "Updated visual layout with regional elements (props, setting, wardrobe, text overlays)",
"caption": "Caption in target language only"
}}
}}
""",
            },
            {
                "role": "user",
                "content": f"""
post: {post}
caption: {caption}
age_group: {age_group_n}
demography: {demography}
gender: {gender}
language: {language}
region/city: {region_city or "not provided"}
""",
            },
        ],
        temperature=float(temperature),
        max_tokens=int(max_tokens),
        response_format={"type": "json_object"},
    )
    # The message content may be None; normalise to a string.
    raw = completion.choices[0].message.content or ""
    obj = safe_json_loads(raw)
    # Only accept {"output": {...}}; any other shape falls back to raw text.
    output = obj.get("output") if isinstance(obj, dict) else None
    if isinstance(output, dict):
        return output.get("post", raw), output.get("caption", raw)
    return raw, raw
def generate_image_prompts(post_idea, visual_mode, n_frames, temperature, max_tokens=600):
    """Turn a post idea into SDXL-friendly English image prompts.

    Args:
        post_idea: Visual layout text to illustrate; must not be blank.
        visual_mode: "Single" (one prompt) or "Carousel" (``n_frames`` prompts).
        n_frames: Number of carousel frames (UI range 2-5); ignored for "Single".
        temperature: Sampling temperature (cast to float) sent to Groq.
        max_tokens: Completion token limit sent to Groq (default 600).

    Returns:
        ``(prompts_text, prompts_list)``. ``prompts_text`` is one
        "Frame N: ..." paragraph per prompt. ``prompts_list`` holds at most the
        requested number of non-empty strings. If the reply has no usable
        "prompts" list, the raw reply text is used as the single prompt.

    Raises:
        gr.Error: if no Groq client is configured, ``post_idea`` is blank, or
            the model returned nothing usable.
    """
    if not client:
        raise gr.Error("GROQ API key missing. Set GROQ_API_KEY (or GROQ_API).")
    if not (post_idea or "").strip():
        raise gr.Error("Post idea is empty. Generate (and optionally personalize) a post first.")
    # One prompt for a single image; one per frame (at least 1) for a carousel.
    count = 1 if visual_mode == "Single" else max(1, int(n_frames))
    completion = client.chat.completions.create(
        model="llama-3.1-8b-instant",
        messages=[
            {
                "role": "system",
                "content": f"""
You are expert at generating SDXL-friendly image prompts.
Given an Instagram post idea, generate {count} concise prompts.
- Prompts should be in English (for image models).
- Each prompt should specify: subject, setting, lighting, camera angle, composition, text overlay placement (if any), product placement.
Return STRICT JSON:
{{
"prompts": ["prompt1", "prompt2", ...]
}}
""",
            },
            {"role": "user", "content": f"Post idea:\n{post_idea}"},
        ],
        temperature=float(temperature),
        max_tokens=int(max_tokens),
        response_format={"type": "json_object"},
    )
    raw = (completion.choices[0].message.content or "").strip()
    obj = safe_json_loads(raw)
    # The JSON may be a list or scalar rather than an object; only read
    # "prompts" from a dict.
    candidates = obj.get("prompts") if isinstance(obj, dict) else None
    prompts = []
    if isinstance(candidates, list):
        # Models occasionally emit nulls/dicts/numbers; keep non-empty text.
        prompts = [str(p).strip() for p in candidates if p is not None and str(p).strip()]
    if not prompts and raw:
        # fallback: send the raw reply to the image model as a single prompt
        prompts = [raw]
    if not prompts:
        raise gr.Error("The model returned no image prompts. Please try again.")
    prompts = prompts[:count]
    prompts_text = "\n\n".join(f"Frame {i}: {p}" for i, p in enumerate(prompts, start=1))
    return prompts_text, prompts
def generate_image(selected_prompt):
    """Render ``selected_prompt`` with fal's "fal-ai/fast-lightning-sdxl" endpoint.

    Progress log messages reported by fal are printed to the server console.

    Returns:
        ``(html_snippet, url)``. ``html_snippet`` is an <img> plus an
        "open in new tab" link for the gr.HTML component; ``url`` is the raw
        image URL for the textbox.

    Raises:
        gr.Error: if FAL_KEY is not set, the prompt is blank, or the response
            carries no image URL.
    """
    # Local import under an alias: the markup variable must not shadow it.
    import html as html_lib

    if not os.getenv("FAL_KEY"):
        raise gr.Error("FAL_KEY missing. Set FAL_KEY in environment.")
    if not (selected_prompt or "").strip():
        raise gr.Error("No prompt selected. Generate image prompt(s) first.")

    def on_queue_update(update):
        # Echo fal progress logs; `logs` may be empty or None.
        if isinstance(update, fal_client.InProgress):
            for log in update.logs or []:
                print(log.get("message", ""))

    result = fal_client.subscribe(
        "fal-ai/fast-lightning-sdxl",
        arguments={"prompt": selected_prompt},
        with_logs=True,
        on_queue_update=on_queue_update,
    )
    try:
        url = result["images"][0]["url"]
    except (KeyError, IndexError, TypeError) as err:
        raise gr.Error("Image generation returned no image. Please try again.") from err
    # The URL is external data: escape it before placing it in HTML attributes.
    safe_url = html_lib.escape(str(url), quote=True)
    markup = f"""
<div style="display:flex;flex-direction:column;gap:8px;">
<img src="{safe_url}" style="max-width:100%;border-radius:16px;" />
<a href="{safe_url}" target="_blank" rel="noopener noreferrer">Open image in new tab</a>
</div>
"""
    return markup, url
# -------------------------
# UI
# -------------------------
theme = gr.themes.Soft()
with gr.Blocks(theme=theme, title="Personalised Marketing Campaign Creator") as demo:
    gr.Markdown("# ✨ Personalised Marketing Campaign Creator")
    gr.Markdown("Generate → Personalize (pure regional language) → Visualize (single or carousel)")
    with gr.Row():
        with gr.Column(scale=2):
            description = gr.Textbox(label="Post Description", placeholder="Eg: Diwali sweets giveaway / ASMR namkeen launch", lines=3)
            category = gr.Dropdown(choices=["Engagement", "Festival", "Product Launch"], label="Category", value="Engagement")
        with gr.Column(scale=1):
            with gr.Accordion("Advanced settings", open=False):
                temperature = gr.Slider(0, 1.5, value=1.0, step=0.1, label="LLM Temperature")
                max_tokens = gr.Slider(256, 2048, value=1024, step=128, label="Max tokens")
    # Holds the list of image prompts produced in step 3 (per session).
    prompts_state = gr.State([])
    with gr.Tabs():
        with gr.Tab("1) Generate"):
            gen_btn = gr.Button("Generate Post", variant="primary")
            post_idea = gr.Textbox(label="Post Idea (Visual Layout)", interactive=False, lines=6)
            caption = gr.Textbox(label="Base Caption (draft)", interactive=False, lines=6)
        with gr.Tab("2) Personalize"):
            with gr.Row():
                language = gr.Dropdown(
                    choices=["English", "Hindi", "Tamil", "Bengali", "Punjabi", "Haryanvi", "Telugu", "Kannada", "Gujarati", "Rajasthani"],
                    label="Language",
                    value="Hindi",
                )
                age_group = gr.Radio(choices=["Youth", "Middle Aged", "Old"], label="Age Group", value="Youth")
                gender = gr.Radio(choices=["Man", "Woman", "Couple"], label="Gender", value="Man")
                demography = gr.Radio(choices=["Urban", "Rural"], label="Demography", value="Urban")
            region_city = gr.Textbox(label="City/Region (optional)", placeholder="Eg: Chennai / Kolkata / Gurugram / Jaipur", lines=1)
            personalize_btn = gr.Button("Personalize (Pure Regional Language)", variant="primary")
            personalized_post = gr.Textbox(label="Personalized Post (Visual Layout)", interactive=False, lines=6)
            personalized_caption = gr.Textbox(label="Personalized Caption (target language ONLY)", interactive=False, lines=6)
        with gr.Tab("3) Visualize"):
            with gr.Row():
                visual_mode = gr.Radio(choices=["Single", "Carousel"], value="Single", label="Visual Type (you decide)")
                n_frames = gr.Slider(2, 5, value=3, step=1, label="Carousel frames (2–5)", visible=False)

            def toggle_frames(mode):
                """Show the frame-count slider only in carousel mode."""
                return gr.update(visible=(mode == "Carousel"))

            visual_mode.change(toggle_frames, inputs=[visual_mode], outputs=[n_frames])
            prompt_btn = gr.Button("Generate Image Prompt(s)", variant="secondary")
            prompts_text = gr.Textbox(label="Prompt(s)", interactive=False, lines=8)
            frame_select = gr.Dropdown(label="Select frame to visualize", choices=["Frame 1"], value="Frame 1")
            selected_prompt = gr.Textbox(label="Selected Prompt", interactive=False, lines=3)
            visualize_btn = gr.Button("Visualize Selected Frame (SDXL)", variant="primary")
            image_html = gr.HTML()
            image_url = gr.Textbox(label="Image URL", interactive=False)

    # Wiring
    gen_btn.click(generate_post, [description, category, temperature, max_tokens], [post_idea, caption])
    personalize_btn.click(
        personalise_post,
        [post_idea, caption, age_group, demography, gender, language, region_city, temperature, max_tokens],
        [personalized_post, personalized_caption],
    )

    def build_prompts_and_dropdown(p_post, mode, frames, temp, base_post=""):
        """Generate image prompts and sync the frame dropdown and selected prompt.

        Uses the personalized post when present. Otherwise it falls back to
        the base post idea, so step 2 can be skipped.
        """
        source_post = p_post if (p_post or "").strip() else base_post
        text, prompts = generate_image_prompts(source_post, mode, frames, temp)
        if not prompts:
            raise gr.Error("No image prompts were generated. Please try again.")
        choices = [f"Frame {i}" for i in range(1, len(prompts) + 1)]
        return text, prompts, gr.update(choices=choices, value=choices[0]), prompts[0]

    prompt_btn.click(
        build_prompts_and_dropdown,
        inputs=[personalized_post, visual_mode, n_frames, temperature, post_idea],
        outputs=[prompts_text, prompts_state, frame_select, selected_prompt],
    )

    def pick_frame(frame_label, prompts):
        """Map a "Frame N" label to its prompt, clamped to the available range.

        Returns "" when no prompts exist yet. An unparseable label selects
        the first frame.
        """
        if not prompts:
            return ""
        try:
            idx = int(str(frame_label).replace("Frame ", "")) - 1
        except ValueError:
            idx = 0
        idx = max(0, min(idx, len(prompts) - 1))
        return prompts[idx]

    frame_select.change(pick_frame, inputs=[frame_select, prompts_state], outputs=[selected_prompt])
    visualize_btn.click(generate_image, [selected_prompt], [image_html, image_url])

# Launch only when run as a script, so importing the module (e.g. for tests or
# Gradio reload mode) does not start a server.
if __name__ == "__main__":
    demo.launch(share=True)