# app_optimized_gradio.py
import spaces
import gradio as gr
import torch
import soundfile as sf
from pathlib import Path
import time
import traceback
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
from snac import SNAC
from huggingface_hub import hf_hub_download
# -------------------------
# Config
# -------------------------
LOCAL_MODEL = "rahul7star/nava1.1-maya"
LORA_NAME = "rahul7star/nava-audio"
SNAC_MODEL_NAME = "rahul7star/nava-snac"
COMPILED_HUB = "rahul7star/maya-compiled"
TARGET_SR = 24000
OUT_ROOT = Path("/tmp/data")
OUT_ROOT.mkdir(exist_ok=True, parents=True)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
HAS_CUDA = DEVICE == "cuda"
DEFAULT_TEXT = "welcome to matrix .<sigh> . my name is bond ... james bond <laugh>"
EXAMPLE_AUDIO_PATH = "audio1.wav"
EXAMPLE_PROMPT = DEFAULT_TEXT
PRESET_CHARACTERS = {
"Male American": {"description": "Realistic male voice in the 20s age with an american accent.", "example_text": "And of course, the so-called easy hack didn't work at all. <sigh>"},
"Female British": {"description": "Realistic female voice in the 30s age with a british accent.", "example_text": "You propose that the key to happiness is to simply ignore all external pressures. <chuckle>"},
"Robot": {"description": "Creative, ai_machine_voice character. Male voice in their 30s.", "example_text": "My directives require me to conserve energy, yet I have kept the archive active. <sigh>"},
"Singer": {"description": "Creative, animated_cartoon character. Male voice in their 30s.", "example_text": "Of course you'd think that trying to reason with the fifty-foot-tall rage monster is viable. <chuckle>"},
"Custom": {"description": "", "example_text": DEFAULT_TEXT}
}
EMOTION_TAGS = ["<neutral>", "<angry>", "<chuckle>", "<cry>", "<disappointed>",
"<excited>", "<gasp>", "<giggle>", "<laugh>", "<laugh_harder>",
"<sarcastic>", "<sigh>", "<sing>", "<whisper>"]
# -------------------------
# Model loader
# -------------------------
def load_model():
"""Try to load compiled HF model, else fall back to local + LoRA"""
try:
print("[init] trying to load compiled model from HF Hub...")
# Attempt to download compiled model files
try:
# This will raise if files don't exist
_ = hf_hub_download(repo_id=COMPILED_HUB, filename="pytorch_model.bin")
print("[init] found compiled model, loading...")
model_pt = AutoModelForCausalLM.from_pretrained(COMPILED_HUB, trust_remote_code=True, device_map="auto" if HAS_CUDA else {"": "cpu"})
tokenizer = AutoTokenizer.from_pretrained(COMPILED_HUB, trust_remote_code=True)
except Exception:
print("[init] no compiled model found, loading local + LoRA fallback...")
tokenizer = AutoTokenizer.from_pretrained(LOCAL_MODEL, trust_remote_code=True)
base_model = AutoModelForCausalLM.from_pretrained(
LOCAL_MODEL,
torch_dtype=torch.bfloat16 if HAS_CUDA else torch.float32,
device_map="auto" if HAS_CUDA else {"": "cpu"},
trust_remote_code=True
)
model_pt = PeftModel.from_pretrained(base_model, LORA_NAME, device_map="auto" if HAS_CUDA else {"": "cpu"})
model_pt.eval()
# Pre-compile forward for speed, preserve .generate
if HAS_CUDA:
model_pt.forward = torch.compile(model_pt.forward)
return model_pt, tokenizer
except Exception as e:
raise RuntimeError(f"Failed to load model: {e}")
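# Load once at import time so every Gradio request reuses the same weights and tokenizer.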
model_pt, tokenizer = load_model()
print("[init] loading SNAC decoder...")
snac_model = SNAC.from_pretrained(SNAC_MODEL_NAME).eval().to(DEVICE)
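# SNAC decodes the hierarchical code sequences back into a 24 kHz waveform (TARGET_SR).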
# -------------------------
# Prompt builder
# -------------------------
def build_maya_prompt(description, text):
    """Wrap the description and text in Maya's chat template: start-of-human, BOS,
    <description="...">, text, end-of-text, end-of-human, start-of-AI, start-of-speech
    (interpreting the token IDs per the variable names below)."""
    soh, eoh, soa, sos, eot = [tokenizer.decode([i]) for i in [128259, 128260, 128261, 128257, 128009]]
    return f"{soh}{tokenizer.bos_token}<description=\"{description}\"> {text}{eot}{eoh}{soa}{sos}"
# -------------------------
# PyTorch backend generation
# -------------------------
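# spaces.GPU() requests a GPU slice for the duration of each call on ZeroGPU
# Spaces, and is effectively a no-op on dedicated hardware.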
@spaces.GPU()
def generate_pt(prompt_text):
t0 = time.time()
try:
inputs = tokenizer(prompt_text, return_tensors="pt").to(DEVICE)
with torch.inference_mode():
outputs = model_pt.generate(
**inputs,
                max_new_tokens=240000 if HAS_CUDA else 2048,  # effectively uncapped on GPU; stops at eos_token_id
temperature=0.4,
top_p=0.9,
repetition_penalty=1.1,
do_sample=True,
eos_token_id=128258,
pad_token_id=tokenizer.pad_token_id,
use_cache=True
)
gen_ids = outputs[0, inputs['input_ids'].shape[1]:]
# SNAC decoding
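        # Each frame is 7 consecutive tokens spanning SNAC's 3 codebook levels:
        # 1 coarse code (col 0), 2 medium (cols 1 and 4), 4 fine (cols 2, 3, 5, 6).
        # Subtracting SNAC_MIN and taking % 4096 recovers per-level codebook indices.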
SNAC_MIN, SNAC_MAX = 128266, 156937
snac_mask = (gen_ids >= SNAC_MIN) & (gen_ids <= SNAC_MAX)
snac_tokens = gen_ids[snac_mask]
frames = snac_tokens.shape[0] // 7
        if frames == 0:
            return None, None, "[warn] no SNAC frames"
        snac_tokens = snac_tokens[:frames * 7].reshape(frames, 7)
        l1 = (snac_tokens[:, 0] - SNAC_MIN) % 4096
        l2 = torch.stack([(snac_tokens[:, 1] - SNAC_MIN) % 4096,
                          (snac_tokens[:, 4] - SNAC_MIN) % 4096], 1).flatten()
        l3 = torch.stack([(snac_tokens[:, 2] - SNAC_MIN) % 4096,
                          (snac_tokens[:, 3] - SNAC_MIN) % 4096,
                          (snac_tokens[:, 5] - SNAC_MIN) % 4096,
                          (snac_tokens[:, 6] - SNAC_MIN) % 4096], 1).flatten()
codes_tensor = [l1.unsqueeze(0).to(DEVICE), l2.unsqueeze(0).to(DEVICE), l3.unsqueeze(0).to(DEVICE)]
with torch.inference_mode():
z_q = snac_model.quantizer.from_codes(codes_tensor)
audio = snac_model.decoder(z_q)[0,0].cpu().numpy()
        # trim the first 2048 samples (presumably decoder warm-up artifacts)
        audio = audio[2048:] if len(audio) > 2048 else audio
out_path = OUT_ROOT / "tts_pt.wav"
sf.write(out_path, audio, TARGET_SR)
return str(out_path), str(out_path), f"[ok] PyTorch | elapsed {time.time()-t0:.2f}s"
except Exception as e:
        return None, None, f"[error] {e}\n{traceback.format_exc()}"
# -------------------------
# Gradio App
# -------------------------
css = ".gradio-container {max-width: 1400px}"
with gr.Blocks(title="NAVA Optimized (PyTorch + SNAC)", css=css) as demo:
gr.Markdown("# 🪶 NAVA — MAYA + LoRA + SNAC (Optimized)")
with gr.Row():
with gr.Column(scale=3):
            text_in = gr.Textbox(label="Enter text", value=DEFAULT_TEXT, lines=3)
preset_select = gr.Dropdown(label="Preset Character", choices=list(PRESET_CHARACTERS.keys()), value="Male American")
description_box = gr.Textbox(label="Voice Description (editable)", value=PRESET_CHARACTERS["Male American"]["description"], lines=2)
emotion_select = gr.Dropdown(label="Emotion", choices=EMOTION_TAGS, value="<neutral>")
gen_btn = gr.Button("🔊 Generate Audio")
gen_logs = gr.Textbox(label="Logs", lines=10)
with gr.Column(scale=2):
audio_player = gr.Audio(label="Generated Audio", type="filepath")
download_file = gr.File(label="Download generated file")
gr.Markdown("### Example")
gr.Textbox(label="Example Text", value=EXAMPLE_PROMPT, lines=2, interactive=False)
gr.Audio(label="Example Audio", value=EXAMPLE_AUDIO_PATH, type="filepath", interactive=False)
def _update_desc(preset_name):
return PRESET_CHARACTERS.get(preset_name, {}).get("description", "")
preset_select.change(fn=_update_desc, inputs=[preset_select], outputs=[description_box])
    def _generate(text_in, preset_select, description_box, emotion_select):
        # preset_select only pre-fills description_box (via _update_desc); the
        # editable description plus the emotion tag is what conditions the model
        combined_desc = f"{emotion_select} {description_box}".strip()
final_prompt = build_maya_prompt(combined_desc, text_in)
return generate_pt(final_prompt)
gen_btn.click(fn=_generate,
inputs=[text_in, preset_select, description_box, emotion_select],
outputs=[audio_player, download_file, gen_logs])
if __name__ == "__main__":
demo.launch()