import os
import gc
import subprocess

subprocess.run(
    'pip install flash-attn --no-build-isolation',
    env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
    shell=True
)

import spaces
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
from threading import Thread, Event
import time
import uuid
import re
from diffusers import ChromaPipeline

# Pre-load ONLY Chroma (not LLMs, to support custom models)
print("Loading Chroma1-HD...")
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device at module level: {device}")

chroma_pipe = ChromaPipeline.from_pretrained(
    "lodestones/Chroma1-HD",
    torch_dtype=torch.bfloat16
)
chroma_pipe = chroma_pipe.to(device)
print("✓ Chroma1-HD ready")

MODEL_CONFIGS = {
    "Nekochu/Luminia-13B-v3": {
        "system": "",
        "examples": [
            "### Instruction:\nCreate stable diffusion metadata based on the given english description. Luminia\n\n### Input:\nfavorites and popular SFW",
            "### Instruction:\nProvide tips on stable diffusion to optimize low token prompts and enhance quality include prompt example."
        ],
        "supports_image_gen": True,
        "sd_temp": 0.3,
        "sd_top_p": 0.8,
        "branch": None  # Uses main/default branch
    },
    "Nekochu/Luminia-8B-v4-Chan": {
        "system": "write a response like a 4chan user",
        "examples": [],
        "supports_image_gen": False,
        "branch": "Llama-3-8B-4Chan_SD_QLoRa"
    },
    "Nekochu/Luminia-8B-RP": {
        "system": "You are a knowledgeable and empathetic mental health professional.",
        "examples": ["How to cope with anxiety?"],
        "supports_image_gen": False,
        "branch": None
    }
}

DEFAULT_MODELS = list(MODEL_CONFIGS.keys())
models_cache = {}
stop_event = Event()
current_thread = None
MAX_CACHE_SIZE = 2
DEFAULT_MODEL = DEFAULT_MODELS[0]


def parse_model_id(model_id_str):
    """Parse model ID and optional branch (format: 'model_id:branch')"""
    if ':' in model_id_str:
        parts = model_id_str.split(':', 1)
        return parts[0], parts[1]
    if model_id_str in MODEL_CONFIGS:
        # Check if it's a known model with a specific branch
        config = MODEL_CONFIGS[model_id_str]
        return model_id_str, config.get('branch', None)
    return model_id_str, None


def parse_sd_metadata(text: str):
    """Parse SD metadata"""
    metadata = {
        'prompt': '',
        'negative_prompt': '',
        'steps': 25,
        'cfg_scale': 7.0,
        'seed': 42,
        'width': 1024,
        'height': 1024
    }
    if not text:
        metadata['prompt'] = '(masterpiece, best quality), 1girl'
        return metadata

    try:
        if "Negative prompt:" in text:
            parts = text.split("Negative prompt:", 1)
            metadata['prompt'] = parts[0].strip().rstrip('.,;')[:500]
            if len(parts) > 1:
                neg_section = parts[1]
                param_match = re.search(r'(Steps:|Sampler:|CFG scale:|Seed:|Size:)', neg_section)
                if param_match:
                    metadata['negative_prompt'] = neg_section[:param_match.start()].strip().rstrip('.,;')[:300]
                else:
                    metadata['negative_prompt'] = neg_section.strip().rstrip('.,;')[:300]
        else:
            param_match = re.search(r'(Steps:|Sampler:|CFG scale:|Seed:|Size:)', text)
            if param_match:
                metadata['prompt'] = text[:param_match.start()].strip().rstrip('.,;')[:500]
            else:
                metadata['prompt'] = text.strip()[:500]

        patterns = {
            'Steps': (r'Steps:\s*(\d+)', lambda x: min(int(x), 30)),
            'CFG scale': (r'CFG scale:\s*([\d.]+)', float),
            'Seed': (r'Seed:\s*(\d+)', lambda x: int(x) % (2**32)),
            'Size': (r'Size:\s*(\d+)x(\d+)', None)
        }
        for key, (pattern, converter) in patterns.items():
            match = re.search(pattern, text)
            if match:
                try:
                    if key == 'Size':
                        metadata['width'] = min(max(int(match.group(1)), 512), 1536)
                        metadata['height'] = min(max(int(match.group(2)), 512), 1536)
                    else:
                        metadata[key.lower().replace(' ', '_')] = converter(match.group(1))
                except Exception:
                    pass
    except Exception:
        pass

    if not metadata['prompt']:
        metadata['prompt'] = '(masterpiece, best quality), 1girl'
    return metadata


def clear_old_cache():
    global models_cache
    if len(models_cache) >= MAX_CACHE_SIZE:
        oldest = min(models_cache.items(), key=lambda x: x[1].get('last_used', 0))
        del models_cache[oldest[0]]
        gc.collect()
        torch.cuda.empty_cache()


@spaces.GPU(duration=119)
def generate_text_gpu(model_id_str, message, history, system, temp, top_p, top_k, max_tokens, rep_penalty):
    """Text generation with branch support"""
    global models_cache, stop_event, current_thread
    stop_event.clear()

    model_id, branch = parse_model_id(model_id_str)  # Parse model ID and branch
    cache_key = f"{model_id}:{branch}" if branch else model_id
    config = MODEL_CONFIGS.get(model_id, {})

    if "Luminia-13B-v3" in model_id and ("stable diffusion" in message.lower() or "metadata" in message.lower()):
        temp = config.get('sd_temp', 0.3)
        top_p = config.get('sd_top_p', 0.8)
        print(f"Using SD settings: temp={temp}, top_p={top_p}")

    if cache_key not in models_cache:
        clear_old_cache()
        try:
            yield history + [[message, f"📥 Loading {model_id}{f' ({branch})' if branch else ''}..."]], "Loading..."

            # Load with branch/revision support
            load_kwargs = {"trust_remote_code": True}
            if branch:
                load_kwargs["revision"] = branch
                print(f"Loading from branch: {branch}")

            tokenizer = AutoTokenizer.from_pretrained(model_id, **load_kwargs)
            tokenizer.pad_token = tokenizer.eos_token or tokenizer.unk_token

            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.bfloat16,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_use_double_quant=True
            )
            model_kwargs = {
                "quantization_config": bnb_config,
                "device_map": "auto",
                "trust_remote_code": True,
                "attn_implementation": "flash_attention_2" if torch.cuda.is_available() else None,
                "low_cpu_mem_usage": True
            }
            if branch:
                model_kwargs["revision"] = branch

            model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)
            models_cache[cache_key] = {
                "model": model,
                "tokenizer": tokenizer,
                "last_used": time.time()
            }
        except Exception as e:
            yield history + [[message, f"❌ Failed: {str(e)[:200]}"]], "Error"
            return

    models_cache[cache_key]['last_used'] = time.time()
    model = models_cache[cache_key]["model"]
    tokenizer = models_cache[cache_key]["tokenizer"]

    # Build an Alpaca-style prompt from the system prompt and chat history
    prompt = ""
    if system:
        prompt = f"{system}\n\n"
    for user_msg, assistant_msg in history:
        if "### Instruction:" in user_msg:
            prompt += f"{user_msg}\n### Response:\n{assistant_msg}\n\n"
        else:
            prompt += f"### Instruction:\n{user_msg}\n\n### Response:\n{assistant_msg}\n\n"

    if "### Instruction:" in message and "### Response:" not in message:
        prompt += f"{message}\n### Response:\n"
    elif "### Instruction:" not in message:
        prompt += f"### Instruction:\n{message}\n\n### Response:\n"
    else:
        prompt += message

    print(f"Prompt ending: ...{prompt[-200:]}")

    try:
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)
        input_tokens = inputs['input_ids'].shape[1]
        inputs = {k: v.to(model.device) for k, v in inputs.items()}
    except Exception as e:
        yield history + [[message, f"❌ Tokenization failed: {str(e)}"]], "Error"
        return

    print(f"📝 {input_tokens} tokens | Temp: {temp} | Top-p: {top_p}")

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=5)
    gen_kwargs = {
        **inputs,
        "streamer": streamer,
        "max_new_tokens": min(max_tokens, 2048),
        "temperature": max(temp, 0.01),
        "top_p": top_p,
        "top_k": top_k,
        "repetition_penalty": rep_penalty,
        "do_sample": temp > 0.01,
        "pad_token_id": tokenizer.pad_token_id
    }

    current_thread = Thread(target=model.generate, kwargs=gen_kwargs)
    current_thread.start()

    start_time = time.time()
    partial = ""
    token_count = 0
    try:
        for text in streamer:
            if stop_event.is_set():
                break
            partial += text
            token_count = len(tokenizer.encode(partial, add_special_tokens=False))
            elapsed = time.time() - start_time
            if elapsed > 0:
                yield history + [[message, partial]], f"⚡ {token_count} @ {token_count/elapsed:.1f} t/s"
    except Exception:
        pass
    finally:
        if current_thread.is_alive():
            stop_event.set()
            current_thread.join(timeout=2)

    final_time = time.time() - start_time
    yield history + [[message, partial]], f"✅ {token_count} tokens in {final_time:.1f}s"


@spaces.GPU()
def generate_image_gpu(text_output):
    """Image generation with pre-loaded Chroma"""
    global chroma_pipe

    if not text_output or text_output.isspace():
        return None, "❌ No valid text", gr.update(visible=False)

    try:
        metadata = parse_sd_metadata(text_output)
        print(f"Generating: {metadata['width']}x{metadata['height']} | Steps: {metadata['steps']}")

        if torch.cuda.is_available():
            chroma_pipe = chroma_pipe.to("cuda")

        generator = torch.Generator("cuda" if torch.cuda.is_available() else "cpu").manual_seed(metadata['seed'])
        image = chroma_pipe(
            prompt=metadata['prompt'],
            negative_prompt=metadata['negative_prompt'],
            generator=generator,
            num_inference_steps=metadata['steps'],
            guidance_scale=metadata['cfg_scale'],
            width=metadata['width'],
            height=metadata['height']
        ).images[0]

        status = f"✅ {metadata['width']}x{metadata['height']} | {metadata['steps']} steps | CFG: {metadata['cfg_scale']} | Seed: {metadata['seed']}"
        return image, status, gr.update(visible=False)
    except Exception as e:
        import traceback
        traceback.print_exc()
        return None, f"❌ Failed: {str(e)[:200]}", gr.update(visible=False)


def stop_generation():
    global stop_event, current_thread
    stop_event.set()
    if current_thread and current_thread.is_alive():
        current_thread.join(timeout=2)
    return gr.update(visible=True), gr.update(visible=False)


css = """
#chatbot {height: 305px;}
#input-row {display: flex; gap: 4px;}
#input-box {flex-grow: 1;}
#button-group {display: inline-flex; flex-direction: column; gap: 2px; width: 45px;}
#button-group button {width: 40px; height: 28px; padding: 2px; font-size: 14px;}
#status {font-size: 11px; color: #666; margin-top: 2px;}
#image-output {max-height: 400px; margin-top: 8px;}
#img-loading {font-size: 11px; color: #666; margin: 4px 0;}
"""

with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
    with gr.Row():
        with gr.Column(scale=4):
            chatbot = gr.Chatbot(elem_id="chatbot")
            with gr.Row(elem_id="input-row"):
                msg = gr.Textbox(
                    label="Instruction",
                    lines=3,
                    elem_id="input-box",
                    value=MODEL_CONFIGS[DEFAULT_MODEL]["examples"][0] if MODEL_CONFIGS[DEFAULT_MODEL]["examples"] else "",
                    scale=10
                )
                with gr.Column(elem_id="button-group", scale=1, min_width=45):
                    submit = gr.Button("▶", variant="primary", size="sm")
                    stop = gr.Button("⏹", variant="stop", size="sm", visible=False)
                    undo = gr.Button("↩", size="sm")
                    clear = gr.Button("🗑", size="sm")

            status = gr.Markdown("", elem_id="status")
            with gr.Row():
                image_btn = gr.Button("🎨 Generate Image using Chroma1-HD", visible=False, variant="secondary")
            last_text = gr.Textbox(visible=False)
            img_loading = gr.Markdown("", visible=False, elem_id="img-loading")
            image_output = gr.Image(visible=False, elem_id="image-output")
            image_status = gr.Markdown("", visible=False)

            examples = gr.Examples(
                examples=[[ex] for ex in MODEL_CONFIGS[DEFAULT_MODEL]["examples"] if ex],
                inputs=msg,
                label="Examples"
            )
        with gr.Column(scale=1):
            model = gr.Dropdown(
                DEFAULT_MODELS,
                value=DEFAULT_MODEL,
                label="Model",
                allow_custom_value=True,
                info="Custom HF ID + optional :branch"
            )
            with gr.Accordion("Settings", open=False):
                system = gr.Textbox(
                    label="System Prompt",
                    value=MODEL_CONFIGS[DEFAULT_MODEL]["system"],
                    lines=2
                )
                temp = gr.Slider(0.1, 1.0, 0.35, label="Temperature")
                top_p = gr.Slider(0.5, 1.0, 0.85, label="Top-p")
                top_k = gr.Slider(10, 100, 40, label="Top-k")
                rep_penalty = gr.Slider(1.0, 1.5, 1.1, label="Repetition Penalty")
                max_tokens = gr.Slider(256, 2048, 1024, label="Max Tokens")
            export_btn = gr.Button("💾 Export", size="sm")
            export_file = gr.File(visible=False)

    def update_ui_on_model_change(model_id_str):
        """Update all UI components when model changes"""
        model_id, branch = parse_model_id(model_id_str)
        config = MODEL_CONFIGS.get(model_id, {"system": "", "examples": [""], "supports_image_gen": False})
        return (
            config["system"],
            config["examples"][0] if config["examples"] else "",
            gr.update(visible=False),              # image_btn
            "",                                    # last_text
            gr.update(value=None, visible=False),  # image_output (clear and hide)
            gr.update(value="", visible=False),    # image_status (clear and hide)
            gr.update(visible=False)               # img_loading visibility
        )

    def check_image_availability(model_id_str, history):
        model_id, _ = parse_model_id(model_id_str)
        if "Luminia-13B-v3" in model_id and history and len(history) > 0:
            return gr.update(visible=True), history[-1][1]
        return gr.update(visible=False), ""

    submit.click(
        lambda: (gr.update(visible=False), gr.update(visible=True)),
        None,
        [submit, stop]
    ).then(
        generate_text_gpu,
        [model, msg, chatbot, system, temp, top_p, top_k, max_tokens, rep_penalty],
        [chatbot, status]
    ).then(
        lambda: (gr.update(visible=True), gr.update(visible=False)),
        None,
        [submit, stop]
    ).then(
        check_image_availability,
        [model, chatbot],
        [image_btn, last_text]
    )

    stop.click(stop_generation, None, [submit, stop])

    image_btn.click(
        lambda: gr.update(value="🎨 Generating...", visible=True),
        None,
        img_loading
    ).then(
        generate_image_gpu,
        last_text,
        [image_output, image_status, img_loading]
    ).then(
        lambda img: (gr.update(visible=img is not None), gr.update(visible=True)),
        image_output,
        [image_output, image_status]
    )

    model.change(
        update_ui_on_model_change,
        model,
        [system, msg, image_btn, last_text, image_output, image_status, img_loading]
    )

    undo.click(
        lambda h: h[:-1] if h else h,
        chatbot,
        chatbot
    ).then(
        check_image_availability,
        [model, chatbot],
        [image_btn, last_text]
    )

    clear.click(
        lambda: ([], "", "", None, "", gr.update(visible=False), "", gr.update(visible=False)),
        None,
        [chatbot, msg, status, image_output, image_status, image_btn, last_text, img_loading]
    )

    def export_chat(history):
        if not history:
            return None
        content = "\n\n".join([f"User: {u}\n\nAssistant: {a}" for u, a in history])
        path = f"chat_{uuid.uuid4().hex[:8]}.txt"
        with open(path, "w", encoding="utf-8") as f:
            f.write(content)
        return path

    export_btn.click(export_chat, chatbot, export_file).then(
        lambda: gr.update(visible=True),
        None,
        export_file
    )

demo.queue().launch()