import spaces import gradio as gr import torch import soundfile as sf from pathlib import Path import time import traceback from transformers import AutoTokenizer, AutoModelForCausalLM , BitsAndBytesConfig from peft import PeftModel from snac import SNAC from huggingface_hub import hf_hub_download import os import subprocess import os import sys import importlib import site import warnings import logging import time # ------------------------- # Config # ------------------------- LOCAL_MODEL = "rahul7star/nava1.1-maya" LORA_NAME = "rahul7star/nava-audio" SNAC_MODEL_NAME = "rahul7star/nava-snac" COMPILED_HUB = "rahul7star/maya-compiled" TARGET_SR = 24000 OUT_ROOT = Path("/tmp/data") OUT_ROOT.mkdir(exist_ok=True, parents=True) DEVICE = "cuda" if torch.cuda.is_available() else "cpu" HAS_CUDA = DEVICE=="cuda" DEFAULT_TEXT = "welcome to matrix . . my name is bond ... james bond " EXAMPLE_AUDIO_PATH = "audio1.wav" EXAMPLE_PROMPT ="welcome to matrix . . my name is bond ... james bond " PRESET_CHARACTERS = { "Male American": {"description": "Realistic male voice in the 20s age with an american accent.", "example_text": "And of course, the so-called easy hack didn't work at all. "}, "Female British": {"description": "Realistic female voice in the 30s age with a british accent.", "example_text": "You propose that the key to happiness is to simply ignore all external pressures. "}, "Robot": {"description": "Creative, ai_machine_voice character. Male voice in their 30s.", "example_text": "My directives require me to conserve energy, yet I have kept the archive active. "}, "Singer": {"description": "Creative, animated_cartoon character. Male voice in their 30s.", "example_text": "Of course you'd think that trying to reason with the fifty-foot-tall rage monster is viable. "}, "Custom": {"description": "", "example_text": DEFAULT_TEXT} } EMOTION_TAGS = ["", "", "", "", "", "", "", "", "", "", "", "", "", ""] # -------------------------------------------------------------- # NEW Quantization config (4‑bit, nf4) # -------------------------------------------------------------- # bnb_config = BitsAndBytesConfig( # load_in_4bit=True, # 4‑bit quantization # bnb_4bit_quant_type="nf4", # “normal” 4‑bit (fast & accurate) # bnb_4bit_use_double_quant=True, # optional: double‑quant for extra speed # bnb_4bit_compute_dtype=torch.bfloat16 if HAS_CUDA else torch.float32, # ) # ------------------------- # Model loader # ------------------------- def load_model(): """Try to load compiled HF model, else fall back to local + LoRA""" global HAS_CUDA, DEVICE try: print("[init] trying to load compiled model from HF Hub...") # Attempt to download compiled model files try: # This will raise if files don't exist _ = hf_hub_download(repo_id=COMPILED_HUB, filename="pytorch_model.bin") print("[init] found compiled model, loading...") model_pt = AutoModelForCausalLM.from_pretrained(COMPILED_HUB, trust_remote_code=True, device_map="auto" if HAS_CUDA else {"": "cpu"}) tokenizer = AutoTokenizer.from_pretrained(COMPILED_HUB, trust_remote_code=True) except Exception: print("[init] no compiled model found, loading local + LoRA fallback...") tokenizer = AutoTokenizer.from_pretrained(LOCAL_MODEL, trust_remote_code=True) base_model = AutoModelForCausalLM.from_pretrained( LOCAL_MODEL, torch_dtype=torch.bfloat16 if HAS_CUDA else torch.float32, device_map="auto" if HAS_CUDA else {"": "cpu"}, trust_remote_code=True, attn_implementation="kernels-community/vllm-flash-attn3", #quantization_config=bnb_config, ) model_pt = PeftModel.from_pretrained(base_model, LORA_NAME, device_map="auto" if HAS_CUDA else {"": "cpu"}) model_pt.eval() # Pre-compile forward for speed, preserve .generate if HAS_CUDA: model_pt.forward = torch.compile(model_pt.forward) return model_pt, tokenizer except Exception as e: raise RuntimeError(f"Failed to load model: {e}") model_pt, tokenizer = load_model() print("[init] loading SNAC decoder...") snac_model = SNAC.from_pretrained(SNAC_MODEL_NAME).eval().to(DEVICE) # ------------------------- # Prompt builder # ------------------------- def build_maya_prompt(description, text): soh, eoh, soa, sos, eot = [tokenizer.decode([i]) for i in [128259, 128260, 128261, 128257, 128009]] return f"{soh}{tokenizer.bos_token} {text}{eot}{eoh}{soa}{sos}" # ------------------------- # PyTorch backend generation # ------------------------- @spaces.GPU() def generate_pt(prompt_text): t0 = time.time() try: inputs = tokenizer(prompt_text, return_tensors="pt").to(DEVICE) with torch.inference_mode(): outputs = model_pt.generate( **inputs, max_new_tokens=240000 if HAS_CUDA else 2048, temperature=0.4, top_p=0.9, repetition_penalty=1.1, do_sample=True, eos_token_id=128258, pad_token_id=tokenizer.pad_token_id, use_cache=True ) gen_ids = outputs[0, inputs['input_ids'].shape[1]:] # SNAC decoding SNAC_MIN, SNAC_MAX = 128266, 156937 snac_mask = (gen_ids >= SNAC_MIN) & (gen_ids <= SNAC_MAX) snac_tokens = gen_ids[snac_mask] frames = snac_tokens.shape[0] // 7 if frames==0: return None, None, "[warn] no SNAC frames" snac_tokens = snac_tokens[:frames*7].reshape(frames,7) l1 = (snac_tokens[:,0]-SNAC_MIN)%4096 l2 = torch.stack([(snac_tokens[:,1]-SNAC_MIN)%4096, (snac_tokens[:,4]-SNAC_MIN)%4096],1).flatten() l3 = torch.stack([(snac_tokens[:,2]-SNAC_MIN)%4096, (snac_tokens[:,3]-SNAC_MIN)%4096, (snac_tokens[:,5]-SNAC_MIN)%4096, (snac_tokens[:,6]-SNAC_MIN)%4096],1).flatten() codes_tensor = [l1.unsqueeze(0).to(DEVICE), l2.unsqueeze(0).to(DEVICE), l3.unsqueeze(0).to(DEVICE)] with torch.inference_mode(): z_q = snac_model.quantizer.from_codes(codes_tensor) audio = snac_model.decoder(z_q)[0,0].cpu().numpy() audio = audio[2048:] if len(audio)>2048 else audio out_path = OUT_ROOT / "tts_pt.wav" sf.write(out_path, audio, TARGET_SR) return str(out_path), str(out_path), f"[ok] PyTorch | elapsed {time.time()-t0:.2f}s" except Exception as e: return None, None, f"[error]{e}\n{traceback.format_exc()}" # ------------------------- # Gradio App # ------------------------- css = ".gradio-container {max-width: 1400px}" with gr.Blocks(title="Text to Speech)", css=css) as demo: gr.Markdown("# 🪶 Text to Speech Model MAYA + LoRA + SNAC (Optimized) + FA3 - Quant 4 bit ") gr.Markdown("# 🪶GPU consumption = 0.3 - 0.5 seconds ...WIP ") with gr.Row(): with gr.Column(scale=3): text_in = gr.Textbox(label="Enter Hindi text", value=DEFAULT_TEXT, lines=3) preset_select = gr.Dropdown(label="Preset Character", choices=list(PRESET_CHARACTERS.keys()), value="Male American") description_box = gr.Textbox(label="Voice Description (editable)", value=PRESET_CHARACTERS["Male American"]["description"], lines=2) emotion_select = gr.Dropdown(label="Emotion", choices=EMOTION_TAGS, value="") gen_btn = gr.Button("🔊 Generate Audio") gen_logs = gr.Textbox(label="Logs", lines=10) with gr.Column(scale=2): audio_player = gr.Audio(label="Generated Audio", type="filepath") download_file = gr.File(label="Download generated file") gr.Markdown("### Example") gr.Textbox(label="Example Text", value=EXAMPLE_PROMPT, lines=2, interactive=False) gr.Audio(label="Example Audio", value=EXAMPLE_AUDIO_PATH, type="filepath", interactive=False) def _update_desc(preset_name): return PRESET_CHARACTERS.get(preset_name, {}).get("description", "") preset_select.change(fn=_update_desc, inputs=[preset_select], outputs=[description_box]) def _generate(text_in, preset_select, description_box, emotion_select): combined_desc = f"{emotion_select} {description_box}".strip() final_prompt = build_maya_prompt(combined_desc, text_in) return generate_pt(final_prompt) gen_btn.click(fn=_generate, inputs=[text_in, preset_select, description_box, emotion_select], outputs=[audio_player, download_file, gen_logs]) if __name__=="__main__": demo.launch()