# app_optimized_gradio.py
"""Gradio app: Maya causal LM (+ optional LoRA) text-to-speech decoded with SNAC.

At import time this module:
  1. best-effort downloads and pip-installs a FlashAttention 3 wheel,
  2. loads the language model (a pre-built Hub repo if its weights can be
     fetched, otherwise the base model with a LoRA adapter) and the SNAC decoder,
  3. builds the Gradio UI in ``demo``.
"""
# `spaces` is kept as the first import, ahead of torch, as in the original
# ordering (ZeroGPU presumably needs it imported before CUDA packages - confirm).
import spaces
import importlib
import os
import site
import subprocess
import sys
import time
import traceback
import uuid
from pathlib import Path

import gradio as gr
import soundfile as sf
import torch
from huggingface_hub import hf_hub_download
from peft import PeftModel
from snac import SNAC
from transformers import AutoModelForCausalLM, AutoTokenizer

# -------------------------
# ENV VARS (safe defaults)
# -------------------------
os.environ.setdefault("TORCHINDUCTOR_DISABLE", "1")
os.environ.setdefault("TORCHINDUCTOR_FUSION", "0")
os.environ.setdefault("USE_FLASH_ATTENTION", "0")
os.environ.setdefault("XLA_IGNORE_ENV_VARS", "1")


# -------------------------
# FlashAttention 3 install
# -------------------------
def try_install_flash_attention():
    """Download a prebuilt FlashAttention 3 wheel from the Hub and pip-install it.

    Best effort: any failure is printed and swallowed.

    Returns:
        True if the install command succeeded, False otherwise.
    """
    t0 = time.time()
    try:
        print("[FA3] Attempting to download/install FlashAttention 3 wheel...")
        wheel = hf_hub_download(
            repo_id="rahul7star/flash-attn-3",
            repo_type="model",
            filename="128/flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64.whl",
        )
        subprocess.run([sys.executable, "-m", "pip", "install", wheel], check=True)
        # Make the freshly installed package importable in this running process.
        site.addsitedir(site.getsitepackages()[0])
        importlib.invalidate_caches()
        print(f"[FA3] ✅ Installed successfully in {time.time()-t0:.3f}s")
        return True
    except Exception as e:
        print(f"[FA3] ⚠️ Install failed ({time.time()-t0:.3f}s): {e}")
        return False


FA_INSTALLED = try_install_flash_attention()

# -------------------------
# Config
# -------------------------
LOCAL_MODEL = "rahul7star/nava1.1-maya"
LORA_NAME = "rahul7star/nava-audio"
SNAC_MODEL_NAME = "rahul7star/nava-snac"
COMPILED_HUB = "rahul7star/maya-compiled"
TARGET_SR = 24000
OUT_ROOT = Path("/tmp/data")
OUT_ROOT.mkdir(exist_ok=True, parents=True)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
HAS_CUDA = DEVICE == "cuda"

# Special token ids used to frame the prompt and to stop generation.
SOH_ID = 128259       # start of header
EOH_ID = 128260       # end of header
SOA_ID = 128261       # start of audio
SOS_ID = 128257       # start of speech codes
TEXT_EOT_ID = 128009  # end of text
CODE_END_ID = 128258  # passed to generate() as eos_token_id
# Generated ids in this inclusive range are treated as SNAC audio codes.
SNAC_MIN, SNAC_MAX = 128266, 156937
SNAC_CODES_PER_FRAME = 7

DEFAULT_TEXT = "welcome to matrix . . my name is bond ... james bond "
EXAMPLE_AUDIO_PATH = "audio1.wav"
EXAMPLE_PROMPT = "welcome to matrix . . my name is bond ... james bond "

PRESET_CHARACTERS = {
    "Male American": {
        "description": "Realistic male voice in the 20s age with an american accent.",
        "example_text": "And of course, the so-called easy hack didn't work at all. ",
    },
    "Female British": {
        "description": "Realistic female voice in the 30s age with a british accent.",
        "example_text": "You propose that the key to happiness is to simply ignore all external pressures. ",
    },
    "Robot": {
        "description": "Creative, ai_machine_voice character. Male voice in their 30s.",
        "example_text": "My directives require me to conserve energy, yet I have kept the archive active. ",
    },
    "Singer": {
        "description": "Creative, animated_cartoon character. Male voice in their 30s.",
        "example_text": "Of course you'd think that trying to reason with the fifty-foot-tall rage monster is viable. ",
    },
    "Custom": {"description": "", "example_text": DEFAULT_TEXT},
}

# NOTE(review): every entry is an empty string, so the Emotion dropdown offers
# nothing useful. These look like tag strings (e.g. "<...>") lost to markup
# stripping. TODO: restore the intended emotion tags.
EMOTION_TAGS = ["", "", "", "", "", "", "", "", "", "", "", "", "", ""]


# -------------------------
# Model loader
# -------------------------
def load_model():
    """Load the causal LM and its tokenizer.

    First tries ``COMPILED_HUB``. It probes for ``pytorch_model.bin`` with
    ``hf_hub_download``. If that probe or any load step raises, it falls back to
    ``LOCAL_MODEL`` with the ``LORA_NAME`` adapter applied through PEFT.
    On CUDA, ``forward`` is wrapped with ``torch.compile``. ``.generate`` is
    left untouched.

    Returns:
        (model, tokenizer), with the model in eval mode.

    Raises:
        RuntimeError: if the fallback path also fails.
    """
    device_map = "auto" if HAS_CUDA else {"": "cpu"}
    try:
        print("[init] trying to load compiled model from HF Hub...")
        try:
            # Raises if the file does not exist in the repo.
            _ = hf_hub_download(repo_id=COMPILED_HUB, filename="pytorch_model.bin")
            print("[init] found compiled model, loading...")
            model_pt = AutoModelForCausalLM.from_pretrained(
                COMPILED_HUB, trust_remote_code=True, device_map=device_map
            )
            tokenizer = AutoTokenizer.from_pretrained(COMPILED_HUB, trust_remote_code=True)
        except Exception:
            print("[init] no compiled model found, loading local + LoRA fallback...")
            tokenizer = AutoTokenizer.from_pretrained(LOCAL_MODEL, trust_remote_code=True)
            base_model = AutoModelForCausalLM.from_pretrained(
                LOCAL_MODEL,
                torch_dtype=torch.bfloat16 if HAS_CUDA else torch.float32,
                device_map=device_map,
                trust_remote_code=True,
                # FlashAttention 3 kernels are CUDA-only. Never request them on CPU,
                # even if the wheel happened to install.
                attn_implementation="flash_attention_3" if (FA_INSTALLED and HAS_CUDA) else None,
            )
            model_pt = PeftModel.from_pretrained(base_model, LORA_NAME, device_map=device_map)

        model_pt.eval()
        # Pre-compile forward for speed. generate() keeps working on the wrapped forward.
        if HAS_CUDA:
            model_pt.forward = torch.compile(model_pt.forward)
        return model_pt, tokenizer
    except Exception as e:
        raise RuntimeError(f"Failed to load model: {e}") from e


model_pt, tokenizer = load_model()

print("[init] loading SNAC decoder...")
snac_model = SNAC.from_pretrained(SNAC_MODEL_NAME).eval().to(DEVICE)


# -------------------------
# Prompt builder
# -------------------------
def build_maya_prompt(description, text):
    """Build the model prompt for ``text``, conditioned on a voice ``description``.

    Layout: SOH + BOS + '<description="...">' + ' ' + text + EOT + EOH + SOA + SOS.
    When ``description`` is empty or blank, the description tag is omitted. The
    result is then ``SOH + BOS + ' ' + text + ...``, the exact prompt previously
    produced for all inputs.
    """
    soh, eoh, soa, sos, eot = (
        tokenizer.decode([token_id])
        for token_id in (SOH_ID, EOH_ID, SOA_ID, SOS_ID, TEXT_EOT_ID)
    )
    description = (description or "").strip()
    body = f'<description="{description}"> {text}' if description else f" {text}"
    return f"{soh}{tokenizer.bos_token}{body}{eot}{eoh}{soa}{sos}"


# -------------------------
# SNAC decoding
# -------------------------
def _snac_tokens_to_audio(gen_ids):
    """Decode generated ids into a 1-D numpy waveform, or return None if no full frame.

    Ids outside [SNAC_MIN, SNAC_MAX] are discarded. The remainder is cut to a
    multiple of 7 and read as frames of 7 codes. Slot 0 feeds level 1, slots 1/4
    feed level 2, and slots 2/3/5/6 feed level 3, each taken as (id - SNAC_MIN) % 4096.
    """
    snac_tokens = gen_ids[(gen_ids >= SNAC_MIN) & (gen_ids <= SNAC_MAX)]
    frames = snac_tokens.shape[0] // SNAC_CODES_PER_FRAME
    if frames == 0:
        return None
    codes = (
        snac_tokens[: frames * SNAC_CODES_PER_FRAME].reshape(frames, SNAC_CODES_PER_FRAME)
        - SNAC_MIN
    ) % 4096
    # Column selection + row-major flatten interleaves the slots frame by frame.
    l1 = codes[:, 0]
    l2 = codes[:, [1, 4]].flatten()
    l3 = codes[:, [2, 3, 5, 6]].flatten()
    codes_tensor = [level.unsqueeze(0).to(DEVICE) for level in (l1, l2, l3)]
    with torch.inference_mode():
        z_q = snac_model.quantizer.from_codes(codes_tensor)
        audio = snac_model.decoder(z_q)[0, 0].cpu().numpy()
    # Drop the first 2048 samples (presumably decoder warm-up - confirm), unless
    # that would leave nothing.
    return audio[2048:] if len(audio) > 2048 else audio


# -------------------------
# PyTorch backend generation
# -------------------------
@spaces.GPU()
def generate_pt(prompt_text):
    """Generate speech for ``prompt_text`` and write it to a WAV file.

    Returns:
        (audio_path, download_path, log_message). Both paths are None and the
        message starts with "[warn]" or "[error]" when no audio was produced.
    """
    t0 = time.time()
    try:
        inputs = tokenizer(prompt_text, return_tensors="pt").to(DEVICE)
        with torch.inference_mode():
            outputs = model_pt.generate(
                **inputs,
                # NOTE(review): 240000 is a very large cap. Generation normally ends at
                # eos_token_id. Consider a tighter bound.
                max_new_tokens=240000 if HAS_CUDA else 2048,
                temperature=0.4,
                top_p=0.9,
                repetition_penalty=1.1,
                do_sample=True,
                eos_token_id=CODE_END_ID,
                pad_token_id=tokenizer.pad_token_id,
                use_cache=True,
            )
        gen_ids = outputs[0, inputs["input_ids"].shape[1]:]

        audio = _snac_tokens_to_audio(gen_ids)
        if audio is None:
            return None, None, "[warn] no SNAC frames"

        # A unique name per request, so concurrent users never overwrite each other's file.
        out_path = OUT_ROOT / f"tts_pt_{uuid.uuid4().hex}.wav"
        sf.write(str(out_path), audio, TARGET_SR)
        return str(out_path), str(out_path), f"[ok] PyTorch | elapsed {time.time()-t0:.2f}s"
    except Exception as e:
        # UI boundary: surface the error and traceback in the Logs box instead of crashing.
        return None, None, f"[error]{e}\n{traceback.format_exc()}"


# -------------------------
# Gradio App
# -------------------------
css = ".gradio-container {max-width: 1400px}"

with gr.Blocks(title="NAVA Optimized (PyTorch + SNAC)", css=css) as demo:
    gr.Markdown("# 🪶 NAVA — MAYA + LoRA + SNAC (Optimized) + FA3 ")

    with gr.Row():
        with gr.Column(scale=3):
            text_in = gr.Textbox(label="Enter Hindi text", value=DEFAULT_TEXT, lines=3)
            preset_select = gr.Dropdown(
                label="Preset Character",
                choices=list(PRESET_CHARACTERS.keys()),
                value="Male American",
            )
            description_box = gr.Textbox(
                label="Voice Description (editable)",
                value=PRESET_CHARACTERS["Male American"]["description"],
                lines=2,
            )
            emotion_select = gr.Dropdown(label="Emotion", choices=EMOTION_TAGS, value="")
            gen_btn = gr.Button("🔊 Generate Audio")
            gen_logs = gr.Textbox(label="Logs", lines=10)
        with gr.Column(scale=2):
            audio_player = gr.Audio(label="Generated Audio", type="filepath")
            download_file = gr.File(label="Download generated file")
            gr.Markdown("### Example")
            gr.Textbox(label="Example Text", value=EXAMPLE_PROMPT, lines=2, interactive=False)
            gr.Audio(label="Example Audio", value=EXAMPLE_AUDIO_PATH, type="filepath", interactive=False)

    def _update_desc(preset_name):
        """Return the description text for the selected preset ("" if unknown)."""
        return PRESET_CHARACTERS.get(preset_name, {}).get("description", "")

    preset_select.change(fn=_update_desc, inputs=[preset_select], outputs=[description_box])

    def _generate(text_in, preset_select, description_box, emotion_select):
        """Click handler: merge emotion and description, build the prompt, and synthesize."""
        combined_desc = f"{emotion_select} {description_box}".strip()
        final_prompt = build_maya_prompt(combined_desc, text_in)
        return generate_pt(final_prompt)

    gen_btn.click(
        fn=_generate,
        inputs=[text_in, preset_select, description_box, emotion_select],
        outputs=[audio_player, download_file, gen_logs],
    )

if __name__ == "__main__":
    demo.launch()