# app_optimized_gradio_fa.py """ NAVA Optimized Gradio App + FlashAttention 3 Keeps original TTS + LoRA + SNAC logic intact Adds: - FA3 install & enable - Detailed per-step timing logs """ import spaces import os import sys import site import importlib import subprocess import json import time import traceback import shutil from pathlib import Path import gradio as gr import torch import soundfile as sf from transformers import AutoTokenizer, AutoModelForCausalLM from peft import PeftModel from snac import SNAC from huggingface_hub import hf_hub_download, list_repo_files, HfApi, Repository # ------------------------- # ENV VARS (safe defaults) # ------------------------- os.environ.setdefault("TORCHINDUCTOR_DISABLE", "1") os.environ.setdefault("TORCHINDUCTOR_FUSION", "0") os.environ.setdefault("USE_FLASH_ATTENTION", "0") os.environ.setdefault("XLA_IGNORE_ENV_VARS", "1") # ------------------------- # FlashAttention 3 install # ------------------------- def try_install_flash_attention(): t0 = time.time() try: print("[FA3] Attempting to download/install FlashAttention 3 wheel...") wheel = hf_hub_download( repo_id="rahul7star/flash-attn-3", repo_type="model", filename="128/flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64.whl", ) subprocess.run([sys.executable, "-m", "pip", "install", wheel], check=True) site.addsitedir(site.getsitepackages()[0]) importlib.invalidate_caches() print(f"[FA3] ✅ Installed successfully in {time.time()-t0:.3f}s") return True except Exception as e: print(f"[FA3] ⚠️ Install failed ({time.time()-t0:.3f}s): {e}") return False FA_INSTALLED = try_install_flash_attention() # ------------------------- # FLASH ATTENTION ENABLE # ------------------------- if FA_INSTALLED: print("[FA3] Enabling PyTorch FlashAttention backend flags...") torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True try: torch.backends.cuda.enable_flash_sdp(True) torch.backends.cuda.enable_mem_efficient_sdp(True) torch.backends.cuda.enable_math_sdp(True) print("[FA3] FlashAttention enabled ✅") except Exception as e: print(f"[FA3] FlashAttention enable failed: {e}") # ------------------------- # Config & paths # ------------------------- LOCAL_MODEL = "rahul7star/nava1.1-maya" LORA_NAME = "rahul7star/nava-audio" SNAC_MODEL_NAME = "rahul7star/nava-snac" COMPILED_HUB = "rahul7star/nava-maya-compiled-v1" TARGET_SR = 24000 OUT_ROOT = Path("/tmp/data") OUT_ROOT.mkdir(exist_ok=True, parents=True) TMP_COMPILED_DIR = Path("/tmp/maya-compiled-v1") HF_TOKEN = os.environ.get("HF_TOKEN", None) DEVICE = "cuda" if torch.cuda.is_available() else "cpu" HAS_CUDA = DEVICE == "cuda" DEFAULT_TEXT = "welcome to matrix . . my name is bond ... james bond " EXAMPLE_AUDIO_PATH = "audio1.wav" EXAMPLE_PROMPT = DEFAULT_TEXT # ------------------------- # Presets & EMOTIONS # ------------------------- PRESET_CHARACTERS = { "Male American": {"description": "Realistic male voice in the 20s age with an american accent.", "example_text": "And of course, the so-called easy hack didn't work at all. "}, "Female British": {"description": "Realistic female voice in the 30s age with a british accent.", "example_text": "My directives require me to conserve energy, yet I have kept the archive active. "}, "Robot": {"description": "Creative, ai_machine_voice character. Male voice in their 30s.", "example_text": "My directives require me to conserve energy, yet I have kept the archive active. "}, "Singer": {"description": "Creative, animated_cartoon character. Male voice in their 30s.", "example_text": "Of course you'd think that trying to reason with the fifty-foot-tall rage monster is viable. "}, "Custom": {"description": "", "example_text": DEFAULT_TEXT} } EMOTION_TAGS = ["", "", "", "", "", "", "", "", "", "", "", "", "", ""] # ------------------------- # HF repo helpers # ------------------------- TMP_SAVE_DIR = Path("/tmp/maya-compiled-save") TMP_CLONE_DIR = Path("/tmp/maya-compiled-clone") def hf_repo_has_files(repo_id): try: files = list_repo_files(repo_id, token=HF_TOKEN) return len(files) > 0 except Exception: return False def ensure_repo_exists(repo_id): try: api = HfApi() api.create_repo(repo_id, exist_ok=True, token=HF_TOKEN) return True except Exception as e: print(f"[warn] cannot create repo {repo_id}: {e}") return False def save_and_upload_model(model, tokenizer, repo_id): try: if HF_TOKEN is None: print("[upload] HF_TOKEN missing -> skip upload") return False, "no_token" if TMP_SAVE_DIR.exists(): shutil.rmtree(TMP_SAVE_DIR, ignore_errors=True) TMP_SAVE_DIR.mkdir(parents=True, exist_ok=True) model.save_pretrained(str(TMP_SAVE_DIR)) tokenizer.save_pretrained(str(TMP_SAVE_DIR)) ensure_repo_exists(repo_id) if TMP_CLONE_DIR.exists(): shutil.rmtree(TMP_CLONE_DIR, ignore_errors=True) repo = Repository(local_dir=str(TMP_CLONE_DIR), clone_from=repo_id, use_auth_token=HF_TOKEN) for item in TMP_SAVE_DIR.iterdir(): if item.is_dir(): shutil.copytree(item, TMP_CLONE_DIR / item.name, dirs_exist_ok=True) else: shutil.copy2(item, TMP_CLONE_DIR / item.name) repo.push_to_hub(commit_message="Upload compiled model") return True, "uploaded" except Exception as e: print(f"[upload] failed: {e}") return False, str(e) # ------------------------- # Prompt builder helpers # ------------------------- ALLOWED_AGE = {"20s", "30s", "40s"} ALLOWED_GENDER = {"male", "female"} ALLOWED_ACCENTS = {"american", "indian", "middle_eastern", "asian_american", "british"} ALLOWED_PITCH = {"low", "normal", "high"} ALLOWED_TIMBRE_REALISTIC = {"deep", "warm", "gravelly", "smooth", "raspy", "nasally", "throaty", "harsh"} ALLOWED_TIMBRE_CREATIVE = ALLOWED_TIMBRE_REALISTIC.union({"robotic", "ethereal"}) ALLOWED_PACING = {"very_slow", "slow", "conversational", "brisk", "fast", "very_fast"} ALLOWED_EMOTION = {"neutral", "energetic", "excited", "sad", "sarcastic", "dry"} ALLOWED_INTENSITY = {"low", "med", "high"} REALISTIC_DOMAINS = {"social_content", "podcast", "commercial", "education", "support", "entertainment", "corporate", "viral_content"} def build_description_from_user_input(user_description: str, preset_name: str) -> str: # --- same as your original code --- if user_description and user_description.strip(): s = user_description.strip() tokens = {t.strip().lower() for t in s.replace(";", ",").split(",") if t.strip()} matched = [] age = (tokens & ALLOWED_AGE) if age: matched.append(f"in the {list(age)[0]} age") gender = (tokens & ALLOWED_GENDER) if gender: matched.append(f"{list(gender)[0]} voice") accent = (tokens & ALLOWED_ACCENTS) if accent: matched.append(f"with a {list(accent)[0]} accent") pitch = (tokens & ALLOWED_PITCH) if pitch: matched.append(f"{list(pitch)[0]} pitch") timbre = (tokens & ALLOWED_TIMBRE_CREATIVE) if timbre: matched.append(f"{list(timbre)[0]} timbre") pacing = (tokens & ALLOWED_PACING) if pacing: matched.append(f"{list(pacing)[0]} pacing") emotion = (tokens & ALLOWED_EMOTION) if emotion: intensity = (tokens & ALLOWED_INTENSITY) intensity_str = f" at {list(intensity)[0]} intensity" if intensity else "" matched.append(f"{list(emotion)[0]} tone{intensity_str}") domain = (tokens & REALISTIC_DOMAINS) if domain: matched.append(f"{list(domain)[0]} domain") if matched: if len(matched) == 1: desc = matched[0] else: desc = ", ".join(matched[:-1]) + " and " + matched[-1] prefix = "" if gender: prefix = f"{list(gender)[0].capitalize()} " if age: prefix += f"in the {list(age)[0]} age " return f"{prefix}{desc}" return user_description.strip() return PRESET_CHARACTERS.get(preset_name, {}).get("description", "").strip() # ------------------------- # Model loader with timing # ------------------------- def load_model_and_maybe_upload(): start_total = time.time() t_start = time.time() try: if hf_repo_has_files(COMPILED_HUB): print("[init] Compiled model found on HF, loading...") t_hf = time.time() from transformers import BitsAndBytesConfig quant_cfg = BitsAndBytesConfig(load_in_8bit=True) if HAS_CUDA else None base_model = AutoModelForCausalLM.from_pretrained( COMPILED_HUB, quantization_config=quant_cfg, device_map="auto" if HAS_CUDA else {"": "cpu"}, trust_remote_code=True, attn_implementation="flash_attention_2" if FA_INSTALLED else None ) tokenizer = AutoTokenizer.from_pretrained(COMPILED_HUB, trust_remote_code=True) model_pt = PeftModel.from_pretrained(base_model, COMPILED_HUB, device_map="auto" if HAS_CUDA else {"": "cpu"}) model_pt.eval() if HAS_CUDA: try: model_pt.forward = torch.compile(model_pt.forward) except Exception as e: print("[warn] forward compile failed:", e) print(f"[init] Compiled model loaded in {time.time()-t_hf:.3f}s") return model_pt, tokenizer, {"loaded_from": "hub", "hub": COMPILED_HUB} except Exception as e: print("[warn] Compiled load failed:", e) # Fallback: local model + LoRA try: t_local = time.time() tokenizer = AutoTokenizer.from_pretrained(LOCAL_MODEL, trust_remote_code=True) from transformers import BitsAndBytesConfig if HAS_CUDA: quant_cfg = BitsAndBytesConfig(load_in_8bit=True) base_model = AutoModelForCausalLM.from_pretrained( LOCAL_MODEL, quantization_config=quant_cfg, device_map="auto", trust_remote_code=True ) else: base_model = AutoModelForCausalLM.from_pretrained( LOCAL_MODEL, device_map={"": "cpu"}, trust_remote_code=True ) model_pt = PeftModel.from_pretrained(base_model, LORA_NAME, device_map="auto" if HAS_CUDA else {"": "cpu"}) model_pt.eval() if HAS_CUDA: try: model_pt.forward = torch.compile(model_pt.forward) except Exception as e: print("[warn] compile forward failed:", e) # optional upload compiled snapshot try: ok, msg = save_and_upload_model(model_pt, tokenizer, COMPILED_HUB) print(f"[init] upload snapshot result: {ok}, {msg}") except Exception as epush: print("[warn] upload attempt failed:", epush) print(f"[init] Local model + LoRA loaded in {time.time()-t_local:.3f}s, total elapsed: {time.time()-t_start:.3f}s") return model_pt, tokenizer, {"loaded_from": "local", "uploaded": ok if 'ok' in locals() else False} except Exception as e_local: raise RuntimeError(f"Failed to load local model+LoRA: {e_local}") # ------------------------- # LOAD MODELS # ------------------------- print("[init] loading main model...") t0_total = time.time() model_pt, tokenizer, load_info = load_model_and_maybe_upload() print(f"[init] total model load time: {time.time()-t0_total:.3f}s, load_info={load_info}") print("[init] loading SNAC decoder...") t_snac = time.time() snac_model = SNAC.from_pretrained(SNAC_MODEL_NAME).eval().to(DEVICE) print(f"[init] SNAC loaded in {time.time()-t_snac:.3f}s") # ------------------------- # Prompt wrapper # ------------------------- def build_maya_prompt_from_inputs(preset_name: str, user_desc: str, user_text: str): description = build_description_from_user_input(user_desc, preset_name) soh, eoh, soa, sos, eot = [tokenizer.decode([i]) for i in [128259, 128260, 128261, 128257, 128009]] formatted = f' {user_text}' prompt = soh + tokenizer.bos_token + formatted + eot + eoh + soa + sos return prompt, description # ------------------------- # GENERATION WITH TIMING # ------------------------- # ------------------------- # Fast 2-sec audio generation # ------------------------- @spaces.GPU() def generate_fast(prompt_text): logs = {} t_start_total = time.time() try: # Tokenize t_tokenize = time.time() inputs = tokenizer(prompt_text, return_tensors="pt").to(DEVICE) logs["tokenize_time"] = time.time() - t_tokenize # --- GENERATE minimal tokens (~2 sec) t_gen = time.time() with torch.inference_mode(), torch.autocast(device_type="cuda", dtype=torch.float16): # Estimate ~48000 samples per 2 sec at 24kHz outputs = model_pt.generate( **inputs, max_new_tokens=1024, # << much lower than 240k temperature=0.4, top_p=0.9, repetition_penalty=1.1, do_sample=True, eos_token_id=128258, pad_token_id=tokenizer.pad_token_id, use_cache=True, ) logs["generation_time"] = time.time() - t_gen # Extract SNAC tokens t_snac = time.time() gen_ids = outputs[0, inputs["input_ids"].shape[1]:] SNAC_MIN, SNAC_MAX = 128266, 156937 snac_mask = (gen_ids >= SNAC_MIN) & (gen_ids <= SNAC_MAX) snac_tokens = gen_ids[snac_mask][:14] # only 2 sec: 2 frames * 7 tokens snac_tokens = snac_tokens.reshape(2,7) l1 = (snac_tokens[:,0]-SNAC_MIN)%4096 l2 = torch.stack([(snac_tokens[:,1]-SNAC_MIN)%4096,(snac_tokens[:,4]-SNAC_MIN)%4096],1).flatten() l3 = torch.stack([(snac_tokens[:,2]-SNAC_MIN)%4096,(snac_tokens[:,3]-SNAC_MIN)%4096, (snac_tokens[:,5]-SNAC_MIN)%4096,(snac_tokens[:,6]-SNAC_MIN)%4096],1).flatten() codes_tensor = [l1.unsqueeze(0).to(DEVICE), l2.unsqueeze(0).to(DEVICE), l3.unsqueeze(0).to(DEVICE)] logs["snac_extract_time"] = time.time() - t_snac # Vocoder (FP16) t_voc = time.time() with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16): z_q = snac_model.quantizer.from_codes(codes_tensor) audio_tensor = snac_model.decoder(z_q)[0,0] logs["vocoder_time"] = time.time() - t_voc # Postprocess t_post = time.time() audio = audio_tensor.cpu().numpy() audio = audio[:TARGET_SR*2] # 2 sec max out_path = OUT_ROOT / "tts_fast.wav" sf.write(out_path, audio, TARGET_SR) logs["postprocess_time"] = time.time() - t_post logs["total_time"] = time.time() - t_start_total logs_text = json.dumps(logs, indent=2) return str(out_path), str(out_path), logs_text except Exception as e: return None, None, f"[error] {e}\n{traceback.format_exc()}" # ------------------------- # GRADIO UI # ------------------------- css = ".gradio-container {max-width: 1400px}" with gr.Blocks(title="NAVA Optimized (PyTorch + SNAC + FA3)", css=css) as demo: gr.Markdown("# 🪶 NAVA TTS Optimized with FlashAttention 3 & Logs") with gr.Row(): with gr.Column(): preset = gr.Dropdown(list(PRESET_CHARACTERS.keys()), label="Voice Preset", value="Male American") user_desc = gr.Textbox(label="Description (custom input for prompt building)") user_text = gr.Textbox(label="Text to speak", value=EXAMPLE_PROMPT) submit_btn = gr.Button("Generate Audio") with gr.Column(): audio_out = gr.Audio(label="Generated Audio", type="filepath") logs_out = gr.Textbox(label="Detailed logs", lines=25) def on_submit(preset, user_desc, user_text): prompt, desc = build_maya_prompt_from_inputs(preset, user_desc, user_text) path_wav, _, logs = generate_pt(prompt) return path_wav, logs submit_btn.click(on_submit, inputs=[preset, user_desc, user_text], outputs=[audio_out, logs_out]) demo.launch()