Spaces:
Sleeping
Sleeping
| import os | |
| import json | |
| import random | |
| import gradio as gr | |
| import fal_client | |
| from groq import Groq | |
# -------------------------
# Data
# -------------------------
# Few-shot style references. Each row is (category, prompt, post layout,
# caption) and is expanded below into the nested dict shape the LLM helpers
# read: {"category", "prompt", "response": {"post", "caption"}}.
_DATASET_ROWS = (
    (
        "Engagement",
        "An informational post about cereals",
        "A two-image post explaining the difference between different types of Basmati rice...",
        "Know your product! 🌾 November is rice awareness month...",
    ),
    (
        "Engagement",
        "Generate a giveaway post on rice awareness month",
        "This post is a video of rice grains falling into a bowl...",
        "Comment 'Country Delight Basmati Rice' and WIN 🏆...",
    ),
    (
        "Product Launch",
        "A product launch post for Country Delight sweets on Diwali",
        "An image introducing Country Delight sweets for Diwali...",
        "5 lucky WINNERS will get a hamper each of SWEETS 🍬 🍭...",
    ),
    (
        "Festival",
        "A post to create engagement for Diwali",
        "A compilation video featuring diya lighting, colorful rangoli...",
        "Diwali is all about lights, love, and sweets! 🪔...",
    ),
    (
        "Product Launch",
        "A post launching Country Delight Aloo Bhujia Namkeen",
        "A video of a person playing with a packet of Country Delight Aloo Bhujia...",
        "Dear namkeen lovers, sounds 🔊 ON please!...",
    ),
)

dataset = [
    {
        "category": row_category,
        "prompt": row_prompt,
        "response": {"post": row_post, "caption": row_caption},
    }
    for row_category, row_prompt, row_post, row_caption in _DATASET_ROWS
]
def get_examples_for_category(category: str):
    """Collect the ``dataset`` entries belonging to *category*.

    Args:
        category: Exact category label to match against each entry's
            ``"category"`` key.

    Returns:
        A new list of the matching entries, in dataset order.

    Raises:
        gr.Error: when nothing matches, so Gradio shows a readable message.
    """
    matching = list(filter(lambda entry: entry.get("category") == category, dataset))
    if matching:
        return matching
    raise gr.Error(f"No examples found for category: {category}")
def safe_json_loads(s: "str | None"):
    """Parse *s* as JSON, tolerating prose around a single JSON object.

    The whole input is tried first with ``json.loads``. If that fails and *s*
    is a string, the span from the first ``{`` to the last ``}`` is retried.
    This recovers replies where a model wraps its JSON in commentary or code
    fences.

    Args:
        s: Text to parse. ``None`` is accepted and yields ``None``, because
            chat-completion ``message.content`` can be null.

    Returns:
        The decoded JSON value (any JSON type), or ``None`` if nothing
        parseable was found.
    """
    if s is None:
        return None
    try:
        return json.loads(s)
    except (ValueError, TypeError):
        # ValueError covers json.JSONDecodeError (and UnicodeDecodeError for
        # bytes); TypeError covers objects json.loads cannot accept at all.
        pass
    if not isinstance(s, str):
        return None
    start, end = s.find("{"), s.rfind("}")
    if start == -1 or end <= start:
        return None
    try:
        return json.loads(s[start:end + 1])
    except ValueError:
        return None
# -------------------------
# Keys / Clients
# -------------------------
# The Groq key may be exported as GROQ_API_KEY or, as a fallback, GROQ_API.
GROQ_API_KEY = os.environ.get("GROQ_API_KEY") or os.environ.get("GROQ_API")
FAL_KEY = os.environ.get("FAL_KEY")
if FAL_KEY:
    # Writes back the value that was just read from os.environ, so the
    # environment is left as it was. fal_client is later called without an
    # explicit key (presumably it reads FAL_KEY from the environment).
    os.environ["FAL_KEY"] = FAL_KEY

# Stays None when no key is configured; the LLM helpers check for that.
client = None
if GROQ_API_KEY:
    client = Groq(api_key=GROQ_API_KEY)
# -------------------------
# Guardrails + language script hints
# -------------------------
# Individual brand-safety rules; joined below into the block of text that is
# interpolated into every LLM system prompt.
_GUARDRAIL_RULES = (
    "Do NOT make medical/health claims, cures, guaranteed outcomes, or absolute superlatives (e.g., “100% best”, “always”, “guaranteed”).",
    "Do NOT mention competitors negatively or comparisons like “better than X brand”.",
    "Avoid misinformation; keep claims generic and safe (e.g., “made with natural ingredients” only if implied, otherwise keep neutral).",
    "Avoid sensitive/political content, hate, stereotypes, or derogatory language.",
    "If unsure about a claim, rewrite it as a neutral, non-factual benefit (taste/experience/occasion/feeling).",
)

# Leading and trailing newlines keep the block separated from the prompt
# sections around it.
BRAND_SAFE_GUARDRAILS = "\nBRAND-SAFE GUARDRAILS (must follow):\n" + "".join(
    f"- {rule}\n" for rule in _GUARDRAIL_RULES
)
# Per-language script instruction for the personalisation prompt. All
# non-English entries share one template; each table row supplies
# (language, script name, native-script label, disallowed mixed form).
_NATIVE_SCRIPT_TABLE = (
    ("Hindi", "Devanagari", "हिंदी", "Hinglish"),
    ("Tamil", "Tamil", "தமிழ்", "English"),
    ("Bengali", "Bengali", "বাংলা", "English"),
    ("Punjabi", "Gurmukhi", "ਪੰਜਾਬੀ", "English"),
    ("Haryanvi", "Devanagari", "हरियाणवी", "Hinglish"),
    ("Telugu", "Telugu", "తెలుగు", "English"),
    ("Kannada", "Kannada", "ಕನ್ನಡ", "English"),
    ("Gujarati", "Gujarati", "ગુજરાતી", "English"),
    ("Rajasthani", "Devanagari", "राजस्थानी", "Hinglish"),
)

SCRIPT_HINTS = {
    "English": "Use English only.",
    **{
        lang_name: f"Use {script} script ({native}). No {mixed}/romanized."
        for lang_name, script, native, mixed in _NATIVE_SCRIPT_TABLE
    },
}
# -------------------------
# LLM Functions
# -------------------------
def generate_post(prompt, category, temperature, max_tokens):
    """Draft a base Instagram post idea and caption for Country Delight.

    Up to three same-category entries from ``dataset`` are sampled as style
    examples. They are combined with ``BRAND_SAFE_GUARDRAILS`` and sent to a
    Groq chat model in JSON mode.

    Args:
        prompt: Free-text description of the desired post.
        category: Category used to pick style examples (e.g. "Engagement").
        temperature: Sampling temperature (coerced to float).
        max_tokens: Completion token cap (coerced to int).

    Returns:
        ``(post, caption)``. If the reply does not have the expected
        ``{"output": {...}}`` shape, the raw reply text is returned for both.
        A missing or empty field also falls back to the raw reply. Non-string
        fields (e.g. a list of frames) are rendered as JSON.

    Raises:
        gr.Error: if no Groq client is configured, *prompt* is blank, or the
            category has no examples.
    """
    if not client:
        raise gr.Error("GROQ API key missing. Set GROQ_API_KEY (or GROQ_API).")
    if not (prompt or "").strip():
        raise gr.Error("Please enter a post description first.")
    examples = get_examples_for_category(category)
    few = random.sample(examples, k=min(3, len(examples)))
    style = "\n\n".join(
        f"- Visual: {x['response']['post']}\n Caption: {x['response']['caption']}" for x in few
    )
    completion = client.chat.completions.create(
        # NOTE(review): Groq retires older model IDs over time; confirm this
        # one is still served.
        model="mixtral-8x7b-32768",
        messages=[
            {
                "role": "system",
                "content": f"""
You are a marketing expert creating Instagram content for Country Delight (India).
Learn the style from these examples:
{style}
{BRAND_SAFE_GUARDRAILS}
Return STRICT JSON:
{{
  "output": {{
    "post": "Detailed visual layout (frames, overlays, props, setting)",
    "caption": "Caption (can be English here as base draft) with emojis + CTA"
  }}
}}
""",
            },
            {"role": "user", "content": f"category: {category}\nprompt: {prompt}"},
        ],
        temperature=float(temperature),
        max_tokens=int(max_tokens),
        response_format={"type": "json_object"},
    )
    # message.content can be None (e.g. an empty completion); treat as "".
    raw = completion.choices[0].message.content or ""
    parsed = safe_json_loads(raw)
    output = parsed.get("output") if isinstance(parsed, dict) else None
    if not isinstance(output, dict):
        # Unparseable or unexpected shape: show the raw reply in both boxes.
        return raw, raw

    def _field(name):
        # Fall back to the raw reply for missing/empty fields, and render
        # structured values readably instead of as a Python repr.
        value = output.get(name)
        if value is None or value == "":
            return raw
        if isinstance(value, str):
            return value
        return json.dumps(value, ensure_ascii=False, indent=2)

    return _field("post"), _field("caption")
def personalise_post(post, caption, age_group, demography, gender, language, region_city, temperature, max_tokens):
    """Localise a post idea and caption for a target cohort and language.

    Args:
        post: Base visual-layout text (from ``generate_post``).
        caption: Base caption draft.
        age_group: "Youth", "Middle Aged" or "Old" (lower-cased for the
            prompt); other values are passed through unchanged.
        demography: e.g. "Urban" / "Rural".
        gender: e.g. "Man" / "Woman" / "Couple".
        language: Target caption language. Its script instruction comes from
            ``SCRIPT_HINTS``, with a generic native-script rule as fallback.
        region_city: Optional city/region; blank means "not provided".
        temperature: Sampling temperature (coerced to float).
        max_tokens: Completion token cap (coerced to int).

    Returns:
        ``(personalised_post, personalised_caption)``. Parsing falls back to
        the raw reply exactly as in ``generate_post``.

    Raises:
        gr.Error: if no Groq client is configured or there is no post/caption
            to personalise.
    """
    if not client:
        raise gr.Error("GROQ API key missing. Set GROQ_API_KEY (or GROQ_API).")
    if not (post or "").strip() and not (caption or "").strip():
        raise gr.Error("Nothing to personalize yet. Generate a post first (step 1).")
    age_map = {"Youth": "youth", "Middle Aged": "middle aged", "Old": "old"}
    age_group_n = age_map.get(age_group, age_group)
    script_hint = SCRIPT_HINTS.get(language, "Use only the selected language in its native script. No romanization.")
    region = (region_city or "").strip() or "not provided"
    # The "no English / no Hinglish" rules contradict the target when the
    # target *is* English, so only add them for other languages.
    if language == "English":
        mixing_rules = ""
    else:
        mixing_rules = (
            "- Do NOT mix English words except “Country Delight” (brand) and unavoidable product names.\n"
            "- Do NOT use Hinglish/romanized.\n"
        )
    completion = client.chat.completions.create(
        # NOTE(review): Groq retires older model IDs over time; confirm this
        # one is still served.
        model="gemma2-9b-it",
        messages=[
            {
                "role": "system",
                "content": f"""
You are an Indian marketing expert with deep regional cultural context.
TASK:
Personalize the given Instagram post idea + caption for the cohort and region.
- Add real regional hooks: local festivals/seasonal vibes, food habits, common phrases/idioms (in the target language), city cues if provided.
- Keep brand voice warm, modern, and engaging.
LANGUAGE REQUIREMENT (STRICT):
- Caption MUST be ONLY in: {language}.
- {script_hint}
{mixing_rules}{BRAND_SAFE_GUARDRAILS}
Return STRICT JSON:
{{
  "output": {{
    "post": "Updated visual layout with regional elements (props, setting, wardrobe, text overlays)",
    "caption": "Caption in target language only"
  }}
}}
""",
            },
            {
                "role": "user",
                "content": f"""
post: {post}
caption: {caption}
age_group: {age_group_n}
demography: {demography}
gender: {gender}
language: {language}
region/city: {region}
""",
            },
        ],
        temperature=float(temperature),
        max_tokens=int(max_tokens),
        response_format={"type": "json_object"},
    )
    # message.content can be None (e.g. an empty completion); treat as "".
    raw = completion.choices[0].message.content or ""
    parsed = safe_json_loads(raw)
    output = parsed.get("output") if isinstance(parsed, dict) else None
    if not isinstance(output, dict):
        # Unparseable or unexpected shape: show the raw reply in both boxes.
        return raw, raw

    def _field(name):
        # Fall back to the raw reply for missing/empty fields, and render
        # structured values readably instead of as a Python repr.
        value = output.get(name)
        if value is None or value == "":
            return raw
        if isinstance(value, str):
            return value
        return json.dumps(value, ensure_ascii=False, indent=2)

    return _field("post"), _field("caption")
def generate_image_prompts(post_idea, visual_mode, n_frames, temperature, max_tokens=600):
    """Turn a post idea into English text-to-image prompts, one per frame.

    Args:
        post_idea: Visual-layout description of the post.
        visual_mode: "Single" requests one prompt. Any other value (the UI
            sends "Carousel") requests ``n_frames`` prompts.
        n_frames: Carousel frame count; coerced to int and floored at 1.
        temperature: Sampling temperature (coerced to float).
        max_tokens: Completion token cap; defaults to the previous fixed 600.

    Returns:
        ``(prompts_text, prompts_list)``. ``prompts_text`` is a readable
        "Frame N: ..." listing. ``prompts_list`` holds at most the requested
        number of non-empty prompt strings and is never empty.

    Raises:
        gr.Error: if no Groq client is configured, *post_idea* is blank, or
            the model reply contains nothing usable.
    """
    if not client:
        raise gr.Error("GROQ API key missing. Set GROQ_API_KEY (or GROQ_API).")
    if not (post_idea or "").strip():
        raise gr.Error("Post idea is empty. Generate (and optionally personalize) a post first.")
    if visual_mode == "Single":
        count = 1
    else:
        # Slider values may arrive as floats or None; never request 0 frames,
        # or the caller would receive an empty prompt list.
        count = max(1, int(n_frames or 1))
    completion = client.chat.completions.create(
        model="llama-3.1-8b-instant",
        messages=[
            {
                "role": "system",
                "content": f"""
You are expert at generating SDXL-friendly image prompts.
Given an Instagram post idea, generate {count} concise prompts.
- Prompts should be in English (for image models).
- Each prompt should specify: subject, setting, lighting, camera angle, composition, text overlay placement (if any), product placement.
Return STRICT JSON:
{{
  "prompts": ["prompt1", "prompt2", ...]
}}
""",
            },
            {"role": "user", "content": f"Post idea:\n{post_idea}"},
        ],
        temperature=float(temperature),
        max_tokens=int(max_tokens),
        response_format={"type": "json_object"},
    )
    # message.content can be None (e.g. an empty completion); treat as "".
    raw = (completion.choices[0].message.content or "").strip()
    parsed = safe_json_loads(raw)
    candidates = parsed.get("prompts") if isinstance(parsed, dict) else None
    if isinstance(candidates, str):
        # Tolerate a single prompt returned as a bare string.
        candidates = [candidates]

    prompts = []
    if isinstance(candidates, list):
        for item in candidates:
            if item is None:
                continue
            # Models sometimes return structured prompt objects; keep them
            # usable as text rather than dropping them.
            text = item if isinstance(item, str) else json.dumps(item, ensure_ascii=False)
            text = text.strip()
            if text:
                prompts.append(text)

    if not prompts:
        if not raw:
            raise gr.Error("The model returned no image prompts. Please try again.")
        # Fallback: use the whole reply as a single prompt.
        prompts = [raw]

    prompts = prompts[:count]
    prompts_text = "\n\n".join(f"Frame {i}: {p}" for i, p in enumerate(prompts, start=1))
    return prompts_text, prompts
def generate_image(selected_prompt):
    """Render *selected_prompt* with fal's ``fast-lightning-sdxl`` endpoint.

    Progress logs from the fal queue are printed to the server console.

    Args:
        selected_prompt: Text-to-image prompt to render.

    Returns:
        ``(html, url)``: an HTML snippet showing the first returned image with
        a link to open it, and the image URL itself.

    Raises:
        gr.Error: if FAL_KEY is unset, the prompt is blank, or the service
            response contains no image URL.
    """
    from html import escape  # only needed to embed the URL safely below

    if not os.getenv("FAL_KEY"):
        raise gr.Error("FAL_KEY missing. Set FAL_KEY in environment.")
    if not (selected_prompt or "").strip():
        raise gr.Error("No prompt selected. Generate image prompt(s) first.")

    def on_queue_update(update):
        # Forward in-progress log lines to stdout; other status updates are ignored.
        if isinstance(update, fal_client.InProgress):
            for log in update.logs or []:
                print(log.get("message", ""))

    result = fal_client.subscribe(
        "fal-ai/fast-lightning-sdxl",
        arguments={"prompt": selected_prompt},
        with_logs=True,
        on_queue_update=on_queue_update,
    )
    images = (result or {}).get("images") or []
    first = images[0] if images and isinstance(images[0], dict) else {}
    url = first.get("url")
    if not url:
        raise gr.Error("The image service returned no image. Please try again.")
    # The URL comes from an external service; escape it before embedding in HTML.
    safe_url = escape(url, quote=True)
    html = f"""
<div style="display:flex;flex-direction:column;gap:8px;">
  <img src="{safe_url}" style="max-width:100%;border-radius:16px;" />
  <a href="{safe_url}" target="_blank" rel="noopener noreferrer">Open image in new tab</a>
</div>
"""
    return html, url
# -------------------------
# UI
# -------------------------
theme = gr.themes.Soft()
with gr.Blocks(theme=theme, title="Personalised Marketing Campaign Creator") as demo:
    gr.Markdown("# ✨ Personalised Marketing Campaign Creator")
    gr.Markdown("Generate → Personalize (pure regional language) → Visualize (single or carousel)")

    # Shared inputs: the brief plus LLM sampling settings used by every step.
    with gr.Row():
        with gr.Column(scale=2):
            description = gr.Textbox(label="Post Description", placeholder="Eg: Diwali sweets giveaway / ASMR namkeen launch", lines=3)
            category = gr.Dropdown(choices=["Engagement", "Festival", "Product Launch"], label="Category", value="Engagement")
        with gr.Column(scale=1):
            with gr.Accordion("Advanced settings", open=False):
                temperature = gr.Slider(0, 1.5, value=1.0, step=0.1, label="LLM Temperature")
                max_tokens = gr.Slider(256, 2048, value=1024, step=128, label="Max tokens")

    # Per-session list of image prompts (one per frame), filled in step 3.
    prompts_state = gr.State([])

    with gr.Tabs():
        with gr.Tab("1) Generate"):
            gen_btn = gr.Button("Generate Post", variant="primary")
            post_idea = gr.Textbox(label="Post Idea (Visual Layout)", interactive=False, lines=6)
            caption = gr.Textbox(label="Base Caption (draft)", interactive=False, lines=6)

        with gr.Tab("2) Personalize"):
            with gr.Row():
                language = gr.Dropdown(
                    # Offer exactly the languages that have a script hint.
                    choices=list(SCRIPT_HINTS),
                    label="Language",
                    value="Hindi",
                )
                age_group = gr.Radio(choices=["Youth", "Middle Aged", "Old"], label="Age Group", value="Youth")
                gender = gr.Radio(choices=["Man", "Woman", "Couple"], label="Gender", value="Man")
                demography = gr.Radio(choices=["Urban", "Rural"], label="Demography", value="Urban")
            region_city = gr.Textbox(label="City/Region (optional)", placeholder="Eg: Chennai / Kolkata / Gurugram / Jaipur", lines=1)
            personalize_btn = gr.Button("Personalize (Pure Regional Language)", variant="primary")
            personalized_post = gr.Textbox(label="Personalized Post (Visual Layout)", interactive=False, lines=6)
            personalized_caption = gr.Textbox(label="Personalized Caption (target language ONLY)", interactive=False, lines=6)

        with gr.Tab("3) Visualize"):
            with gr.Row():
                visual_mode = gr.Radio(choices=["Single", "Carousel"], value="Single", label="Visual Type (you decide)")
                n_frames = gr.Slider(2, 5, value=3, step=1, label="Carousel frames (2–5)", visible=False)

            def toggle_frames(mode):
                """Show the frame-count slider only for carousel posts."""
                return gr.update(visible=(mode == "Carousel"))

            visual_mode.change(toggle_frames, inputs=[visual_mode], outputs=[n_frames])
            prompt_btn = gr.Button("Generate Image Prompt(s)", variant="secondary")
            prompts_text = gr.Textbox(label="Prompt(s)", interactive=False, lines=8)
            frame_select = gr.Dropdown(label="Select frame to visualize", choices=["Frame 1"], value="Frame 1")
            selected_prompt = gr.Textbox(label="Selected Prompt", interactive=False, lines=3)
            visualize_btn = gr.Button("Visualize Selected Frame (SDXL)", variant="primary")
            image_html = gr.HTML()
            image_url = gr.Textbox(label="Image URL", interactive=False)

    # Wiring
    gen_btn.click(generate_post, [description, category, temperature, max_tokens], [post_idea, caption])
    personalize_btn.click(
        personalise_post,
        [post_idea, caption, age_group, demography, gender, language, region_city, temperature, max_tokens],
        [personalized_post, personalized_caption],
    )

    def build_prompts_and_dropdown(p_post, base_post, mode, frames, temp):
        """Generate image prompts and reset the frame dropdown to frame 1.

        Uses the personalised post when present. Otherwise it falls back to
        the base post, so step 2 can be skipped.
        """
        source_post = p_post if (p_post or "").strip() else base_post
        text, prompts = generate_image_prompts(source_post, mode, frames, temp)
        if not prompts:
            raise gr.Error("No image prompts were generated. Please try again.")
        choices = [f"Frame {i}" for i in range(1, len(prompts) + 1)]
        return text, prompts, gr.update(choices=choices, value=choices[0]), prompts[0]

    prompt_btn.click(
        build_prompts_and_dropdown,
        inputs=[personalized_post, post_idea, visual_mode, n_frames, temperature],
        outputs=[prompts_text, prompts_state, frame_select, selected_prompt],
    )

    def pick_frame(frame_label, prompts):
        """Return the prompt for a "Frame N" label ("" when none exist yet).

        Unparseable labels map to frame 1; out-of-range indices are clamped.
        """
        if not prompts:
            return ""
        try:
            idx = int(str(frame_label).replace("Frame ", "")) - 1
        except ValueError:
            idx = 0
        idx = max(0, min(idx, len(prompts) - 1))
        return prompts[idx]

    frame_select.change(pick_frame, inputs=[frame_select, prompts_state], outputs=[selected_prompt])
    visualize_btn.click(generate_image, [selected_prompt], [image_html, image_url])
# Launch only when run as a script (e.g. `python app.py`), so importing this
# module, or loading it via `gradio app.py` reload mode (which picks up the
# `demo` object itself), does not start a server as a side effect.
if __name__ == "__main__":
    demo.launch(share=True)