import html
import json
import os
import random

import fal_client
import gradio as gr
from groq import APIError, Groq

# -------------------------
# Data
# -------------------------
# Few-shot examples for the post-generation model. Each entry pairs a
# user-style prompt with a reference visual layout ("post") and caption.
dataset = [
    {
        "category": "Engagement",
        "prompt": "An informational post about cereals",
        "response": {
            "post": "A two-image post explaining the difference between different types of Basmati rice...",
            "caption": "Know your product! 🌾 November is rice awareness month...",
        },
    },
    {
        "category": "Engagement",
        "prompt": "Generate a giveaway post on rice awareness month",
        "response": {
            "post": "This post is a video of rice grains falling into a bowl...",
            "caption": "Comment 'Country Delight Basmati Rice' and WIN 🏆...",
        },
    },
    {
        "category": "Product Launch",
        "prompt": "A product launch post for Country Delight sweets on Diwali",
        "response": {
            "post": "An image introducing Country Delight sweets for Diwali...",
            "caption": "5 lucky WINNERS will get a hamper each of SWEETS 🍬 🍭...",
        },
    },
    {
        "category": "Festival",
        "prompt": "A post to create engagement for Diwali",
        "response": {
            "post": "A compilation video featuring diya lighting, colorful rangoli...",
            "caption": "Diwali is all about lights, love, and sweets! 🪔...",
        },
    },
    {
        "category": "Product Launch",
        "prompt": "A post launching Country Delight Aloo Bhujia Namkeen",
        "response": {
            "post": "A video of a person playing with a packet of Country Delight Aloo Bhujia...",
            "caption": "Dear namkeen lovers, sounds 🔊 ON please!...",
        },
    },
]

# UI category choices are derived from the dataset so that every selectable
# category is guaranteed to have at least one few-shot example.
CATEGORIES = sorted({d["category"] for d in dataset})


def get_examples_for_category(category: str):
    """Return every dataset entry whose "category" equals *category*.

    Raises:
        gr.Error: if no entry matches (Gradio shows this message to the user).
    """
    examples = [d for d in dataset if d.get("category") == category]
    if not examples:
        raise gr.Error(f"No examples found for category: {category}")
    return examples


def safe_json_loads(s: str):
    """Parse *s* as JSON, tolerating extra text an LLM may wrap around it.

    A strict parse is tried first. If that fails, the substring from the first
    "{" to the last "}" is parsed instead. Returns None when *s* is empty or
    None, or when neither attempt yields valid JSON.
    """
    if not s:
        # Chat completions may return None content; treat it as "no JSON".
        return None
    try:
        return json.loads(s)
    except json.JSONDecodeError:
        pass
    start, end = s.find("{"), s.rfind("}")
    if start != -1 and end > start:
        try:
            return json.loads(s[start:end + 1])
        except json.JSONDecodeError:
            return None
    return None


# -------------------------
# Keys / Clients
# -------------------------
GROQ_API_KEY = os.getenv("GROQ_API_KEY") or os.getenv("GROQ_API")
# fal_client reads FAL_KEY from the environment itself; this module-level copy
# is informational only (generate_image re-checks the environment per call).
FAL_KEY = os.getenv("FAL_KEY")

client = Groq(api_key=GROQ_API_KEY) if GROQ_API_KEY else None

# -------------------------
# Models (overridable via environment)
# -------------------------
# NOTE(review): Groq retires model IDs over time; confirm these defaults are
# still served, or override them with the environment variables below.
POST_MODEL = os.getenv("GROQ_POST_MODEL", "mixtral-8x7b-32768")
PERSONALISE_MODEL = os.getenv("GROQ_PERSONALISE_MODEL", "gemma2-9b-it")
PROMPT_MODEL = os.getenv("GROQ_PROMPT_MODEL", "llama-3.1-8b-instant")
IMAGE_MODEL = "fal-ai/fast-lightning-sdxl"

# -------------------------
# Guardrails + language script hints
# -------------------------
BRAND_SAFE_GUARDRAILS = """
BRAND-SAFE GUARDRAILS (must follow):
- Do NOT make medical/health claims, cures, guaranteed outcomes, or absolute superlatives (e.g., “100% best”, “always”, “guaranteed”).
- Do NOT mention competitors negatively or comparisons like “better than X brand”.
- Avoid misinformation; keep claims generic and safe (e.g., “made with natural ingredients” only if implied, otherwise keep neutral).
- Avoid sensitive/political content, hate, stereotypes, or derogatory language.
- If unsure about a claim, rewrite it as a neutral, non-factual benefit (taste/experience/occasion/feeling).
"""

# Per-language script instruction injected into the personalisation prompt.
# Its keys also populate the UI language dropdown, in this order.
SCRIPT_HINTS = {
    "English": "Use English only.",
    "Hindi": "Use Devanagari script (हिंदी). No Hinglish/romanized.",
    "Tamil": "Use Tamil script (தமிழ்). No English/romanized.",
    "Bengali": "Use Bengali script (বাংলা). No English/romanized.",
    "Punjabi": "Use Gurmukhi script (ਪੰਜਾਬੀ). No English/romanized.",
    "Haryanvi": "Use Devanagari script (हरियाणवी). No Hinglish/romanized.",
    "Telugu": "Use Telugu script (తెలుగు). No English/romanized.",
    "Kannada": "Use Kannada script (ಕನ್ನಡ). No English/romanized.",
    "Gujarati": "Use Gujarati script (ગુજરાતી). No English/romanized.",
    "Rajasthani": "Use Devanagari script (राजस्थानी). No Hinglish/romanized.",
}


# -------------------------
# LLM helpers
# -------------------------
def _require_client():
    """Return the Groq client, or raise a user-facing error if no key is set."""
    if not client:
        raise gr.Error("GROQ API key missing. Set GROQ_API_KEY (or GROQ_API).")
    return client


def _complete_json(model, system_prompt, user_prompt, temperature, max_tokens):
    """Run one JSON-mode Groq chat completion and return the raw reply text.

    Raises:
        gr.Error: if the client is missing, the API call fails, or the model
            returns empty content.
    """
    llm = _require_client()
    try:
        completion = llm.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            temperature=float(temperature),
            max_tokens=int(max_tokens),
            response_format={"type": "json_object"},
        )
    except APIError as err:
        # Surface a readable reason (bad key, retired model, invalid JSON, ...)
        # in the UI instead of Gradio's generic error.
        raise gr.Error(f"Groq request failed ({model}): {err}") from err
    raw = completion.choices[0].message.content
    if not raw:
        raise gr.Error(f"Groq returned an empty response ({model}).")
    return raw


def _extract_post_caption(raw):
    """Extract (post, caption) from a reply shaped like {"output": {...}}.

    Any field that cannot be extracted falls back to the raw reply text, so the
    user always sees what the model said.
    """
    obj = safe_json_loads(raw)
    output = obj.get("output") if isinstance(obj, dict) else None
    if isinstance(output, dict):
        return output.get("post", raw), output.get("caption", raw)
    return raw, raw


# -------------------------
# LLM Functions
# -------------------------
def generate_post(prompt, category, temperature, max_tokens):
    """Draft a base Instagram post (visual layout + caption) for *prompt*.

    Up to three dataset entries from *category* are sampled at random as
    few-shot style examples.

    Returns:
        (post, caption) strings. The raw reply is substituted for any field
        missing from the expected JSON shape.
    """
    _require_client()
    examples = get_examples_for_category(category)
    few = random.sample(examples, k=min(3, len(examples)))
    style = "\n\n".join(
        f"- Visual: {x['response']['post']}\n  Caption: {x['response']['caption']}"
        for x in few
    )
    system_prompt = f"""
You are a marketing expert creating Instagram content for Country Delight (India).
Learn the style from these examples:
{style}

{BRAND_SAFE_GUARDRAILS}

Return STRICT JSON:
{{
  "output": {{
    "post": "Detailed visual layout (frames, overlays, props, setting)",
    "caption": "Caption (can be English here as base draft) with emojis + CTA"
  }}
}}
"""
    user_prompt = f"category: {category}\nprompt: {prompt}"
    raw = _complete_json(POST_MODEL, system_prompt, user_prompt, temperature, max_tokens)
    return _extract_post_caption(raw)


def personalise_post(post, caption, age_group, demography, gender, language,
                     region_city, temperature, max_tokens):
    """Rewrite a base post + caption for a cohort, region and target language.

    The caption is requested in *language* only, in that language's native
    script (see SCRIPT_HINTS).

    Returns:
        (post, caption) strings. The raw reply is substituted for any field
        missing from the expected JSON shape.

    Raises:
        gr.Error: if the client is missing or there is nothing to personalise.
    """
    _require_client()
    if not (post or "").strip() and not (caption or "").strip():
        raise gr.Error("Generate a post first (step 1) before personalizing.")

    age_map = {"Youth": "youth", "Middle Aged": "middle aged", "Old": "old"}
    age_group_n = age_map.get(age_group, age_group)
    script_hint = SCRIPT_HINTS.get(
        language, "Use only the selected language in its native script. No romanization."
    )
    region = (region_city or "").strip() or "not provided"

    system_prompt = f"""
You are an Indian marketing expert with deep regional cultural context.

TASK:
Personalize the given Instagram post idea + caption for the cohort and region.
- Add real regional hooks: local festivals/seasonal vibes, food habits, common phrases/idioms (in the target language), city cues if provided.
- Keep brand voice warm, modern, and engaging.

LANGUAGE REQUIREMENT (STRICT):
- Caption MUST be ONLY in: {language}.
- {script_hint}
- Do NOT mix English words except “Country Delight” (brand) and unavoidable product names.
- Do NOT use Hinglish/romanized.

{BRAND_SAFE_GUARDRAILS}

Return STRICT JSON:
{{
  "output": {{
    "post": "Updated visual layout with regional elements (props, setting, wardrobe, text overlays)",
    "caption": "Caption in target language only"
  }}
}}
"""
    user_prompt = f"""
post: {post}
caption: {caption}
age_group: {age_group_n}
demography: {demography}
gender: {gender}
language: {language}
region/city: {region}
"""
    raw = _complete_json(PERSONALISE_MODEL, system_prompt, user_prompt, temperature, max_tokens)
    return _extract_post_caption(raw)


def generate_image_prompts(post_idea, visual_mode, n_frames, temperature):
    """Turn a post's visual layout into SDXL-friendly image prompts.

    Args:
        post_idea: visual-layout text to illustrate.
        visual_mode: "Single" requests one prompt; any other value requests
            *n_frames* prompts (the UI offers "Carousel" with 2-5 frames).
        n_frames: number of carousel frames (at least 1 is always requested).
        temperature: LLM sampling temperature.

    Returns:
        (prompts_text, prompts_list): a "Frame N: ..." display string and the
        non-empty list of prompt strings, truncated to the requested count. If
        the reply has no usable "prompts" list, the raw reply is the one prompt.
    """
    _require_client()
    count = 1 if visual_mode == "Single" else max(1, int(n_frames))

    system_prompt = f"""
You are expert at generating SDXL-friendly image prompts.
Given an Instagram post idea, generate {count} concise prompts.
- Prompts should be in English (for image models).
- Each prompt should specify: subject, setting, lighting, camera angle, composition, text overlay placement (if any), product placement.

Return STRICT JSON:
{{ "prompts": ["prompt1", "prompt2", ...] }}
"""
    raw = _complete_json(PROMPT_MODEL, system_prompt, f"Post idea:\n{post_idea}", temperature, 600)

    obj = safe_json_loads(raw)
    candidates = obj.get("prompts") if isinstance(obj, dict) else None
    prompts = []
    if isinstance(candidates, list):
        # Coerce to non-empty strings: they are shown in a Textbox and sent to fal.
        prompts = [str(p).strip() for p in candidates if str(p).strip()]
    if not prompts:
        # Fallback: use the whole reply as a single prompt.
        prompts = [raw.strip()]
    prompts = prompts[:count]

    prompts_text = "\n\n".join(f"Frame {i + 1}: {p}" for i, p in enumerate(prompts))
    return prompts_text, prompts


def generate_image(selected_prompt):
    """Render *selected_prompt* with fal.ai SDXL.

    Returns:
        (markup, url): an HTML snippet showing the image with an "open in new
        tab" link, and the image URL.

    Raises:
        gr.Error: if FAL_KEY is unset, the prompt is empty, or fal returns no
            image.
    """
    if not os.getenv("FAL_KEY"):
        raise gr.Error("FAL_KEY missing. Set FAL_KEY in environment.")
    if not (selected_prompt or "").strip():
        raise gr.Error("No prompt selected. Generate image prompt(s) first.")

    def on_queue_update(update):
        # Echo fal's progress logs to the server console while the job runs.
        if isinstance(update, fal_client.InProgress):
            for log in update.logs or []:
                print(log.get("message", ""))

    result = fal_client.subscribe(
        IMAGE_MODEL,
        arguments={"prompt": selected_prompt},
        with_logs=True,
        on_queue_update=on_queue_update,
    )
    images = (result or {}).get("images") or []
    if not images or not images[0].get("url"):
        raise gr.Error("Image generation returned no image.")
    url = images[0]["url"]

    # Escape the URL before embedding it in HTML attributes.
    safe_url = html.escape(url, quote=True)
    markup = f"""
<div style="display:flex; flex-direction:column; gap:8px;">
  <img src="{safe_url}" alt="Generated image" style="max-width:100%; border-radius:12px;" />
  <a href="{safe_url}" target="_blank" rel="noopener noreferrer">Open image in new tab</a>
</div>
"""
    return markup, url


# -------------------------
# UI
# -------------------------
theme = gr.themes.Soft()

with gr.Blocks(theme=theme, title="Personalised Marketing Campaign Creator") as demo:
    gr.Markdown("# ✨ Personalised Marketing Campaign Creator")
    gr.Markdown("Generate → Personalize (pure regional language) → Visualize (single or carousel)")

    with gr.Row():
        with gr.Column(scale=2):
            description = gr.Textbox(
                label="Post Description",
                placeholder="Eg: Diwali sweets giveaway / ASMR namkeen launch",
                lines=3,
            )
            category = gr.Dropdown(choices=CATEGORIES, label="Category", value="Engagement")
        with gr.Column(scale=1):
            with gr.Accordion("Advanced settings", open=False):
                temperature = gr.Slider(0, 1.5, value=1.0, step=0.1, label="LLM Temperature")
                max_tokens = gr.Slider(256, 2048, value=1024, step=128, label="Max tokens")

    # Holds the list of image prompts produced in step 3.
    prompts_state = gr.State([])

    with gr.Tabs():
        with gr.Tab("1) Generate"):
            gen_btn = gr.Button("Generate Post", variant="primary")
            post_idea = gr.Textbox(label="Post Idea (Visual Layout)", interactive=False, lines=6)
            caption = gr.Textbox(label="Base Caption (draft)", interactive=False, lines=6)

        with gr.Tab("2) Personalize"):
            with gr.Row():
                language = gr.Dropdown(choices=list(SCRIPT_HINTS), label="Language", value="Hindi")
                age_group = gr.Radio(choices=["Youth", "Middle Aged", "Old"], label="Age Group", value="Youth")
                gender = gr.Radio(choices=["Man", "Woman", "Couple"], label="Gender", value="Man")
                demography = gr.Radio(choices=["Urban", "Rural"], label="Demography", value="Urban")
            region_city = gr.Textbox(
                label="City/Region (optional)",
                placeholder="Eg: Chennai / Kolkata / Gurugram / Jaipur",
                lines=1,
            )
            personalize_btn = gr.Button("Personalize (Pure Regional Language)", variant="primary")
            personalized_post = gr.Textbox(label="Personalized Post (Visual Layout)", interactive=False, lines=6)
            personalized_caption = gr.Textbox(
                label="Personalized Caption (target language ONLY)", interactive=False, lines=6
            )

        with gr.Tab("3) Visualize"):
            with gr.Row():
                visual_mode = gr.Radio(choices=["Single", "Carousel"], value="Single", label="Visual Type (you decide)")
                n_frames = gr.Slider(2, 5, value=3, step=1, label="Carousel frames (2–5)", visible=False)

            def toggle_frames(mode):
                """Show the frame-count slider only in Carousel mode."""
                return gr.update(visible=(mode == "Carousel"))

            visual_mode.change(toggle_frames, inputs=[visual_mode], outputs=[n_frames])

            prompt_btn = gr.Button("Generate Image Prompt(s)", variant="secondary")
            prompts_text = gr.Textbox(label="Prompt(s)", interactive=False, lines=8)
            frame_select = gr.Dropdown(label="Select frame to visualize", choices=["Frame 1"], value="Frame 1")
            selected_prompt = gr.Textbox(label="Selected Prompt", interactive=False, lines=3)
            visualize_btn = gr.Button("Visualize Selected Frame (SDXL)", variant="primary")
            image_html = gr.HTML()
            image_url = gr.Textbox(label="Image URL", interactive=False)

    # Wiring
    gen_btn.click(generate_post, [description, category, temperature, max_tokens], [post_idea, caption])

    personalize_btn.click(
        personalise_post,
        [post_idea, caption, age_group, demography, gender, language, region_city, temperature, max_tokens],
        [personalized_post, personalized_caption],
    )

    def build_prompts_and_dropdown(p_post, mode, frames, temp):
        """Generate image prompts and sync the frame dropdown and selected prompt.

        Returns the prompts display text, the prompt list (stored in
        prompts_state), a dropdown update listing "Frame 1..N" with Frame 1
        selected, and the first prompt.
        """
        if not (p_post or "").strip():
            raise gr.Error("Personalize the post first (step 2) to get a visual layout.")
        text, prompts = generate_image_prompts(p_post, mode, frames, temp)
        choices = [f"Frame {i + 1}" for i in range(len(prompts))]
        return text, prompts, gr.update(choices=choices, value=choices[0]), prompts[0]

    prompt_btn.click(
        build_prompts_and_dropdown,
        inputs=[personalized_post, visual_mode, n_frames, temperature],
        outputs=[prompts_text, prompts_state, frame_select, selected_prompt],
    )

    def pick_frame(frame_label, prompts):
        """Return the prompt for the "Frame N" label (1-based), clamped to range.

        Returns "" when no prompts have been generated yet, and falls back to
        the first prompt if the label cannot be parsed.
        """
        if not prompts:
            return ""
        try:
            idx = int(str(frame_label).replace("Frame ", "")) - 1
        except ValueError:
            idx = 0
        idx = max(0, min(idx, len(prompts) - 1))
        return prompts[idx]

    frame_select.change(pick_frame, inputs=[frame_select, prompts_state], outputs=[selected_prompt])

    visualize_btn.click(generate_image, [selected_prompt], [image_html, image_url])

if __name__ == "__main__":
    demo.launch(share=True)