# app_optimized_gradio.py
import spaces
import gradio as gr
import torch
import soundfile as sf
from pathlib import Path
import time
import traceback
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
from snac import SNAC
from huggingface_hub import hf_hub_download
# -------------------------
# Config
# -------------------------
LOCAL_MODEL = "rahul7star/nava1.1-maya"
LORA_NAME = "rahul7star/nava-audio"
SNAC_MODEL_NAME = "rahul7star/nava-snac"
COMPILED_HUB = "rahul7star/maya-compiled"
TARGET_SR = 24000
OUT_ROOT = Path("/tmp/data")
OUT_ROOT.mkdir(exist_ok=True, parents=True)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
HAS_CUDA = DEVICE == "cuda"
DEFAULT_TEXT = "welcome to matrix .<sigh> . my name is bond ... james bond <laugh>"
EXAMPLE_AUDIO_PATH = "audio1.wav"
EXAMPLE_PROMPT = DEFAULT_TEXT
PRESET_CHARACTERS = {
    "Male American": {
        "description": "Realistic male voice in their 20s with an American accent.",
        "example_text": "And of course, the so-called easy hack didn't work at all. <sigh>",
    },
    "Female British": {
        "description": "Realistic female voice in their 30s with a British accent.",
        "example_text": "You propose that the key to happiness is to simply ignore all external pressures. <chuckle>",
    },
    "Robot": {
        "description": "Creative ai_machine_voice character. Male voice in their 30s.",
        "example_text": "My directives require me to conserve energy, yet I have kept the archive active. <sigh>",
    },
    "Singer": {
        "description": "Creative animated_cartoon character. Male voice in their 30s.",
        "example_text": "Of course you'd think that trying to reason with the fifty-foot-tall rage monster is viable. <chuckle>",
    },
    "Custom": {"description": "", "example_text": DEFAULT_TEXT},
}
EMOTION_TAGS = ["<neutral>", "<angry>", "<chuckle>", "<cry>", "<disappointed>",
                "<excited>", "<gasp>", "<giggle>", "<laugh>", "<laugh_harder>",
                "<sarcastic>", "<sigh>", "<sing>", "<whisper>"]
# -------------------------
# Model loader
# -------------------------
def load_model():
    """Try to load the compiled HF model, else fall back to local base + LoRA."""
    try:
        print("[init] trying to load compiled model from HF Hub...")
        try:
            # Raises if the compiled weights don't exist in the repo
            _ = hf_hub_download(repo_id=COMPILED_HUB, filename="pytorch_model.bin")
            print("[init] found compiled model, loading...")
            model_pt = AutoModelForCausalLM.from_pretrained(
                COMPILED_HUB,
                trust_remote_code=True,
                device_map="auto" if HAS_CUDA else {"": "cpu"},
            )
            tokenizer = AutoTokenizer.from_pretrained(COMPILED_HUB, trust_remote_code=True)
        except Exception:
            print("[init] no compiled model found, loading local + LoRA fallback...")
            tokenizer = AutoTokenizer.from_pretrained(LOCAL_MODEL, trust_remote_code=True)
            base_model = AutoModelForCausalLM.from_pretrained(
                LOCAL_MODEL,
                torch_dtype=torch.bfloat16 if HAS_CUDA else torch.float32,
                device_map="auto" if HAS_CUDA else {"": "cpu"},
                trust_remote_code=True,
            )
            model_pt = PeftModel.from_pretrained(base_model, LORA_NAME, device_map="auto" if HAS_CUDA else {"": "cpu"})
        model_pt.eval()
        # Compile the forward pass for speed while preserving .generate()
        if HAS_CUDA:
            model_pt.forward = torch.compile(model_pt.forward)
        return model_pt, tokenizer
    except Exception as e:
        raise RuntimeError(f"Failed to load model: {e}") from e
model_pt, tokenizer = load_model()
print("[init] loading SNAC decoder...")
snac_model = SNAC.from_pretrained(SNAC_MODEL_NAME).eval().to(DEVICE)
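# Note: the app writes output audio at TARGET_SR (24 kHz); this assumes the
# SNAC checkpoint above decodes at that rate, since sf.write() below uses it.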
# -------------------------
# Prompt builder
# -------------------------
def build_maya_prompt(description, text):
    # Control tokens for this checkpoint; the variable names suggest
    # start/end-of-human, start-of-ai, start-of-speech and end-of-text.
    soh, eoh, soa, sos, eot = [tokenizer.decode([i]) for i in [128259, 128260, 128261, 128257, 128009]]
    return f'{soh}{tokenizer.bos_token}<description="{description}"> {text}{eot}{eoh}{soa}{sos}'
# -------------------------
# PyTorch backend generation
# -------------------------
@spaces.GPU  # on ZeroGPU Spaces the GPU is only attached inside decorated calls; without this the `spaces` import above goes unused
def generate_pt(prompt_text):
    t0 = time.time()
    try:
        inputs = tokenizer(prompt_text, return_tensors="pt").to(DEVICE)
        with torch.inference_mode():
            outputs = model_pt.generate(
                **inputs,
                # Large cap on GPU; generation normally stops earlier at the
                # audio EOS token (128258) below.
                max_new_tokens=240000 if HAS_CUDA else 2048,
                temperature=0.4,
                top_p=0.9,
                repetition_penalty=1.1,
                do_sample=True,
                eos_token_id=128258,
                pad_token_id=tokenizer.pad_token_id,
                use_cache=True,
            )
        gen_ids = outputs[0, inputs["input_ids"].shape[1]:]
        # SNAC decoding: keep only tokens inside the audio-token ID range
        SNAC_MIN, SNAC_MAX = 128266, 156937
        snac_mask = (gen_ids >= SNAC_MIN) & (gen_ids <= SNAC_MAX)
        snac_tokens = gen_ids[snac_mask]
        frames = snac_tokens.shape[0] // 7
        if frames == 0:
            return None, None, "[warn] no SNAC frames"
        snac_tokens = snac_tokens[:frames * 7].reshape(frames, 7)
        l1 = (snac_tokens[:, 0] - SNAC_MIN) % 4096
        l2 = torch.stack([(snac_tokens[:, 1] - SNAC_MIN) % 4096,
                          (snac_tokens[:, 4] - SNAC_MIN) % 4096], 1).flatten()
        l3 = torch.stack([(snac_tokens[:, 2] - SNAC_MIN) % 4096,
                          (snac_tokens[:, 3] - SNAC_MIN) % 4096,
                          (snac_tokens[:, 5] - SNAC_MIN) % 4096,
                          (snac_tokens[:, 6] - SNAC_MIN) % 4096], 1).flatten()
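        # Layout note (inferred from the slicing above): each 7-token frame is
        # interleaved as [l1, l2, l3, l3, l2, l3, l3], i.e. 1 coarse, 2 medium
        # and 4 fine codes per frame, with each codebook of size 4096 (hence
        # the modulo).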
        codes_tensor = [l1.unsqueeze(0).to(DEVICE), l2.unsqueeze(0).to(DEVICE), l3.unsqueeze(0).to(DEVICE)]
        with torch.inference_mode():
            z_q = snac_model.quantizer.from_codes(codes_tensor)
            audio = snac_model.decoder(z_q)[0, 0].cpu().numpy()
        # Drop the first 2048 samples (likely initial decoder artifacts)
        audio = audio[2048:] if len(audio) > 2048 else audio
        out_path = OUT_ROOT / "tts_pt.wav"
        sf.write(out_path, audio, TARGET_SR)
        return str(out_path), str(out_path), f"[ok] PyTorch | elapsed {time.time() - t0:.2f}s"
    except Exception as e:
        return None, None, f"[error] {e}\n{traceback.format_exc()}"
# -------------------------
# Gradio App
# -------------------------
css = ".gradio-container {max-width: 1400px}"
with gr.Blocks(title="NAVA Optimized (PyTorch + SNAC)", css=css) as demo:
    gr.Markdown("# 🪶 NAVA — MAYA + LoRA + SNAC (Optimized)")
    with gr.Row():
        with gr.Column(scale=3):
            text_in = gr.Textbox(label="Enter text", value=DEFAULT_TEXT, lines=3)
            preset_select = gr.Dropdown(label="Preset Character", choices=list(PRESET_CHARACTERS.keys()), value="Male American")
            description_box = gr.Textbox(label="Voice Description (editable)", value=PRESET_CHARACTERS["Male American"]["description"], lines=2)
            emotion_select = gr.Dropdown(label="Emotion", choices=EMOTION_TAGS, value="<neutral>")
            gen_btn = gr.Button("🔊 Generate Audio")
            gen_logs = gr.Textbox(label="Logs", lines=10)
        with gr.Column(scale=2):
            audio_player = gr.Audio(label="Generated Audio", type="filepath")
            download_file = gr.File(label="Download generated file")
            gr.Markdown("### Example")
            gr.Textbox(label="Example Text", value=EXAMPLE_PROMPT, lines=2, interactive=False)
            gr.Audio(label="Example Audio", value=EXAMPLE_AUDIO_PATH, type="filepath", interactive=False)
    def _update_desc(preset_name):
        return PRESET_CHARACTERS.get(preset_name, {}).get("description", "")

    preset_select.change(fn=_update_desc, inputs=[preset_select], outputs=[description_box])

    def _generate(text_in, preset_select, description_box, emotion_select):
        # preset_select is wired through but unused here; the dropdown's change
        # handler already syncs description_box with the chosen preset.
        combined_desc = f"{emotion_select} {description_box}".strip()
        final_prompt = build_maya_prompt(combined_desc, text_in)
        return generate_pt(final_prompt)

    gen_btn.click(fn=_generate,
                  inputs=[text_in, preset_select, description_box, emotion_select],
                  outputs=[audio_player, download_file, gen_logs])
if __name__ == "__main__":
    demo.launch()
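# Note: when generations take more than a few seconds on Spaces, enabling the
# request queue (demo.queue().launch()) is a common alternative to the plain
# launch() above.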