Nava-Maya-INfrence / app_quant.py
rahul7star's picture
Update app_quant.py
11fc338 verified
# app.py
import spaces
import gradio as gr
import torch
import soundfile as sf
from pathlib import Path
import traceback
import time
import os
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
from snac import SNAC
# -------------------------
# Config / constants
# -------------------------
# Hugging Face Hub repositories loaded at startup.
MODEL_NAME = "rahul7star/nava1.1-maya"      # base Maya model variant
LORA_NAME = "rahul7star/nava-audio"         # LoRA adapter attached to MODEL_NAME
SNAC_MODEL_NAME = "rahul7star/nava-snac"    # SNAC decoder (hub model id)

# Sample rate used when writing generated audio to disk.
TARGET_SR = 24_000

# Output directory for generated wav files, created at import time so later
# writes have a destination.
OUT_ROOT = Path("/tmp/data")
OUT_ROOT.mkdir(parents=True, exist_ok=True)

# Default text for the input box; the example panel shows the same sentence.
DEFAULT_TEXT = "welcome to matrix .<sigh> . my name is bond ... james bond <laugh>"
EXAMPLE_PROMPT = DEFAULT_TEXT
EXAMPLE_AUDIO_PATH = "audio1.wav"  # user-supplied file expected in the repo root
# Voice presets shown in the UI dropdown: two realistic voices, two creative
# characters, and an editable "Custom" entry. Each entry holds the voice
# description fed to the model and a sample sentence for that voice.
PRESET_CHARACTERS = dict([
    ("Male American", dict(
        description="Realistic male voice in the 20s age with an american accent. High pitch, raspy timbre, brisk pacing, neutral tone delivery at medium intensity, viral_content domain, short_form_narrator role, neutral delivery",
        example_text="And of course, the so-called easy hack didn't work at all. What a surprise. <sigh>",
    )),
    ("Female British", dict(
        description="Realistic female voice in the 30s age with a british accent. Normal pitch, throaty timbre, conversational pacing, sarcastic tone delivery at low intensity, podcast domain, interviewer role, formal delivery",
        example_text="You propose that the key to happiness is to simply ignore all external pressures. <chuckle> I'm sure it must work brilliantly in theory.",
    )),
    ("Robot", dict(
        description="Creative, ai_machine_voice character. Male voice in their 30s with an american accent. High pitch, robotic timbre, slow pacing, sad tone at medium intensity.",
        example_text="My directives require me to conserve energy, yet I have kept the archive of their farewell messages active. <sigh>",
    )),
    ("Singer", dict(
        description="Creative, animated_cartoon character. Male voice in their 30s with an american accent. High pitch, deep timbre, slow pacing, sarcastic tone at medium intensity.",
        example_text="Of course you'd think that trying to reason with the fifty-foot-tall rage monster is a viable course of action. <chuckle> Why would we ever consider running away very fast.",
    )),
    # Empty description: the user is expected to type their own.
    ("Custom", dict(description="", example_text=DEFAULT_TEXT)),
])
# Emotion tags offered in the UI dropdown. The selected tag is combined with
# the voice description when the prompt is built.
EMOTION_TAGS = [
    f"<{tag_name}>"
    for tag_name in (
        "neutral", "angry", "chuckle", "cry", "disappointed",
        "excited", "gasp", "giggle", "laugh", "laugh_harder",
        "sarcastic", "sigh", "sing", "whisper",
    )
]
# Generation limits, chosen per runtime at model-load time.
# CPU caps are small. The GPU caps are very large, so on GPU generation is
# normally ended by the end-of-audio token rather than by the cap.
# SEQ_LEN_* values are only reported in logs.
SEQ_LEN_CPU, MAX_NEW_TOKENS_CPU = 4_096, 1_024
SEQ_LEN_GPU = MAX_NEW_TOKENS_GPU = 240_000
# Detect devices
HAS_CUDA = torch.cuda.is_available()
DEVICE = "cuda" if HAS_CUDA else "cpu"
# 4-bit loading needs BOTH transformers' BitsAndBytesConfig and the
# bitsandbytes package itself. BitsAndBytesConfig is importable whenever
# transformers is installed, so a successful import alone does not prove
# bitsandbytes is present; check the package explicitly so we fall back to
# the non-bnb path instead of failing later inside from_pretrained.
bnb_available = False
if HAS_CUDA:
    import importlib.util

    try:
        from transformers import BitsAndBytesConfig
        bnb_available = importlib.util.find_spec("bitsandbytes") is not None
    except Exception:
        # Best-effort probe: any failure just means "no 4-bit path".
        bnb_available = False
print(f"[init] cuda={HAS_CUDA}, bnb={bnb_available}, device={DEVICE}")
# -------------------------
# Load tokenizer + model + LoRA + SNAC ONCE (startup)
# -------------------------
# Runs at import time: downloads/loads the tokenizer, base LM, LoRA adapter and
# SNAC decoder, and sets the module globals `tokenizer`, `model`, `snac_model`,
# `SEQ_LEN` and `MAX_NEW_TOKENS` used by the generation code.
# NOTE(review): trust_remote_code=True executes Python code shipped in the
# MODEL_NAME repo; keep that repo trusted (ideally pinned to a revision).
print("[init] loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
print("[init] loading base model + LoRA adapter (this can take time)...")
if HAS_CUDA and bnb_available:
    # GPU + bnb path (fastest inference if available):
    # 4-bit NF4 weights with double quantization, bfloat16 compute.
    quant_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.bfloat16
    )
    # device_map="auto" lets transformers decide weight placement.
    base_model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        quantization_config=quant_config,
        device_map="auto",
        trust_remote_code=True,
    )
    model = PeftModel.from_pretrained(base_model, LORA_NAME, device_map="auto")
    # GPU run: use the large GPU limits.
    SEQ_LEN = SEQ_LEN_GPU
    MAX_NEW_TOKENS = MAX_NEW_TOKENS_GPU
    print("[init] loaded base+LoRA on GPU (4-bit via bnb).")
else:
    # CPU fallback - load base into CPU memory (float32) and attach LoRA on CPU.
    # NOTE(review): this branch is also taken when CUDA IS available but
    # bitsandbytes is not. The LM then lives on CPU while DEVICE is still
    # "cuda" (and SNAC below is moved to DEVICE), so LM inputs must be sent to
    # the model's own device, not to DEVICE.
    base_model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.float32,
        device_map={"": "cpu"},
        low_cpu_mem_usage=True,
        trust_remote_code=True,
    )
    model = PeftModel.from_pretrained(base_model, LORA_NAME, device_map={"": "cpu"})
    # CPU run: use the small CPU limits.
    SEQ_LEN = SEQ_LEN_CPU
    MAX_NEW_TOKENS = MAX_NEW_TOKENS_CPU
    print("[init] loaded base+LoRA on CPU (FP32).")
# Inference only: put the (LoRA-wrapped) model in eval mode.
model.eval()
print("[init] model ready.")
print("[init] loading SNAC decoder...")
# SNAC turns the generated audio codes into a waveform; it is placed on DEVICE.
snac_model = SNAC.from_pretrained(SNAC_MODEL_NAME).eval().to(DEVICE)
print("[init] snac ready.")
# --------------
# Helper: build prompt per Maya conventions
# --------------
def build_maya_prompt(description: str, text: str, tok=None) -> str:
    """Wrap a voice description and the text to speak in Maya's special tokens.

    The prompt layout is::

        SOH + BOS + '<description="..."> text' + EOT + EOH + SOA + SOS

    Args:
        description: Voice description placed inside the ``description="..."``
            attribute. Double quotes are replaced with single quotes so user
            input cannot terminate the attribute early. ``None`` is treated
            as an empty description.
        text: The text to be spoken.
        tok: Tokenizer used to decode the special-token ids and provide the
            BOS token. Defaults to the module-level ``tokenizer``.

    Returns:
        The prompt string, ready to be tokenized.
    """
    if tok is None:
        tok = tokenizer
    # Special tokens used by maya-style models, looked up by id.
    soh_token = tok.decode([128259])  # SOH
    eoh_token = tok.decode([128260])  # EOH
    soa_token = tok.decode([128261])  # SOA
    sos_token = tok.decode([128257])  # SOS (code start)
    eot_token = tok.decode([128009])  # TEXT_EOT / EOT marker
    # A tokenizer without a BOS token reports None; concatenating None would raise.
    bos_token = tok.bos_token or ""
    safe_description = (description or "").replace('"', "'")
    formatted = f'<description="{safe_description}"> {text}'
    return soh_token + bos_token + formatted + eot_token + eoh_token + soa_token + sos_token
# --------------
# Core generate function (uses preloaded model & snac)
# --------------
# Token ids / layout of the Maya + SNAC audio-code vocabulary.
_CODE_END_ID = 128258          # end-of-audio token, used as eos_token_id
_SNAC_MIN_ID = 128266          # first SNAC code token id
_SNAC_MAX_ID = 156937          # last SNAC code token id
_SNAC_TOKENS_PER_FRAME = 7     # per frame: 1 level-1, 2 level-2, 4 level-3 codes
_SNAC_CODEBOOK_SIZE = 4096
_SNAC_WARMUP_SAMPLES = 2048    # leading decoder samples dropped from the output


def _unpack_snac_levels(gen_ids):
    """Split generated ids into SNAC code lists ``(l1, l2, l3)``, or ``None``.

    Ids after the first end-of-audio token are ignored, non-SNAC ids are
    dropped, and a trailing partial frame is discarded. Returns ``None`` when
    not even one complete frame is present.
    """
    if _CODE_END_ID in gen_ids:
        gen_ids = gen_ids[:gen_ids.index(_CODE_END_ID)]
    snac_tokens = [t for t in gen_ids if _SNAC_MIN_ID <= t <= _SNAC_MAX_ID]
    frames = len(snac_tokens) // _SNAC_TOKENS_PER_FRAME
    if frames == 0:
        return None
    l1, l2, l3 = [], [], []
    for i in range(frames):
        start = i * _SNAC_TOKENS_PER_FRAME
        s = [(t - _SNAC_MIN_ID) % _SNAC_CODEBOOK_SIZE
             for t in snac_tokens[start:start + _SNAC_TOKENS_PER_FRAME]]
        # De-interleave one 7-token frame into the three SNAC levels.
        l1.append(s[0])
        l2.extend((s[1], s[4]))
        l3.extend((s[2], s[3], s[5], s[6]))
    return l1, l2, l3


def _decode_snac(levels):
    """Decode SNAC code lists with ``snac_model`` into a 1-D numpy waveform.

    The first ``_SNAC_WARMUP_SAMPLES`` samples are trimmed when the clip is
    longer than that.
    """
    codes = [torch.tensor(level, dtype=torch.long, device=DEVICE).unsqueeze(0) for level in levels]
    with torch.inference_mode():
        z_q = snac_model.quantizer.from_codes(codes)
        audio = snac_model.decoder(z_q)[0, 0].cpu().numpy()
    if len(audio) > _SNAC_WARMUP_SAMPLES:
        audio = audio[_SNAC_WARMUP_SAMPLES:]
    return audio


def generate_from_loaded_model(final_text: str):
    """Generate speech for a fully built Maya prompt and save it as a wav file.

    Args:
        final_text: Complete prompt string (description + emotion + user text
            already wrapped in Maya's special tokens).

    Returns:
        ``(audio_path, download_path, logs_str)``. Both paths point at the same
        newly written wav file, or are ``None`` when generation produced no
        complete SNAC frame or an exception occurred. Exceptions are caught and
        reported in ``logs_str`` rather than raised.
    """
    import uuid

    logs = []
    t0 = time.time()
    try:
        # Send inputs to where the LM weights actually are. DEVICE is "cuda"
        # whenever CUDA exists, but without bitsandbytes the model is loaded on
        # CPU, so DEVICE and the model's device can disagree.
        model_device = model.device
        logs.append(f"[info] device={model_device} | seq_len={SEQ_LEN}")
        # A single prompt needs no padding; padding=True raises ValueError on
        # tokenizers that define no pad token.
        inputs = tokenizer(final_text, return_tensors="pt").to(model_device)
        if model_device.type == "cuda":
            max_new = MAX_NEW_TOKENS
        else:
            max_new = min(MAX_NEW_TOKENS, MAX_NEW_TOKENS_CPU)
        with torch.inference_mode():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new,
                temperature=0.4,
                top_p=0.9,
                repetition_penalty=1.1,
                do_sample=True,
                eos_token_id=_CODE_END_ID,
                pad_token_id=tokenizer.pad_token_id,
            )
        # Keep only the newly generated ids (drop the echoed prompt).
        gen_ids = outputs[0, inputs["input_ids"].shape[1]:].tolist()
        logs.append(f"[info] generated tokens: {len(gen_ids)}")

        levels = _unpack_snac_levels(gen_ids)
        if levels is None:
            logs.append("[warn] no SNAC frames found in generated tokens — returning debug logs.")
            return None, None, "\n".join(logs)

        audio = _decode_snac(levels)
        # Unique filename per request: a fixed shared name lets concurrent
        # requests overwrite each other's output before it is served.
        # NOTE(review): files accumulate in OUT_ROOT; add cleanup if disk matters.
        out_path = OUT_ROOT / f"tts_output_loaded_lora_{uuid.uuid4().hex}.wav"
        sf.write(str(out_path), audio, TARGET_SR)
        logs.append(f"[ok] saved {out_path} duration={(len(audio) / TARGET_SR):.2f}s")
        logs.append(f"[time] elapsed {time.time() - t0:.2f}s")
        return str(out_path), str(out_path), "\n".join(logs)
    except Exception as e:
        logs.append(f"[error] {e}\n{traceback.format_exc()}")
        return None, None, "\n".join(logs)
# --------------
# UI glue: combine description + emotion + user text (3a)
# --------------
@spaces.GPU()
def generate_for_ui(text, preset_name, description, emotion):
    """Gradio callback: build the Maya prompt from the UI fields and synthesize it.

    Args:
        text: Text to speak. Blank or missing text returns a warning instead of
            running generation.
        preset_name: Selected key of ``PRESET_CHARACTERS``.
        description: Voice description from the editable box. When blank, the
            selected preset's description is used instead.
        emotion: Emotion tag (e.g. ``"<neutral>"``) placed in front of the
            description inside the prompt header.

    Returns:
        ``(audio_path, download_path, logs)`` from ``generate_from_loaded_model``.
        On failure both paths are ``None`` and ``logs`` carries the message.
    """
    try:
        if not text or not text.strip():
            return None, None, "[warn] no text entered — nothing to generate."
        # Fall back to the preset description when the box was left empty
        # (e.g. the "Custom" preset was never edited).
        if preset_name in PRESET_CHARACTERS and (not description or not description.strip()):
            description = PRESET_CHARACTERS[preset_name]["description"]
        # (3a): emotion + description form the header; the user text follows it.
        combined_desc = f"{emotion or ''} {description or ''}".strip()
        final_prompt = build_maya_prompt(combined_desc, text)
        return generate_from_loaded_model(final_prompt)
    except Exception as e:
        return None, None, f"[error] {e}\n{traceback.format_exc()}"
# -------------------------
# Gradio UI (keeps your layout; wide container)
# -------------------------
css = ".gradio-container {max-width: 1400px}"
# The example clip is a user-supplied file that may not be in the repo. Only
# hand its path to Gradio when it exists; otherwise show an empty player.
_example_audio_value = EXAMPLE_AUDIO_PATH if Path(EXAMPLE_AUDIO_PATH).is_file() else None
with gr.Blocks(title="NAVA — MAYAAORG + LoRA + SNAC (Optimized)", css=css) as demo:
    gr.Markdown("# 🪶 NAVA — MAYAAORG + LoRA + SNAC (Optimized)\nGenerate emotional Hindi speech using Maya1 base + your LoRA adapter.")
    with gr.Row():
        with gr.Column(scale=3):
            gr.Markdown("## Inference (CPU/GPU auto)\nType text + pick a preset or write description manually.")
            text_in = gr.Textbox(label="Enter Hindi text", value=DEFAULT_TEXT, lines=3)
            preset_select = gr.Dropdown(label="Select Preset Character", choices=list(PRESET_CHARACTERS.keys()), value="Male American")
            description_box = gr.Textbox(label="Voice Description (editable)", value=PRESET_CHARACTERS["Male American"]["description"], lines=2)
            emotion_select = gr.Dropdown(label="Select Emotion", choices=EMOTION_TAGS, value="<neutral>")
            gen_btn = gr.Button("🔊 Generate Audio (LoRA)")
            gen_logs = gr.Textbox(label="Logs", lines=10)
        with gr.Column(scale=2):
            gr.Markdown("### Output")
            audio_player = gr.Audio(label="Generated Audio", type="filepath")
            download_file = gr.File(label="Download generated file")
            gr.Markdown("### Example")
            gr.Textbox(label="Example Text", value=EXAMPLE_PROMPT, lines=2, interactive=False)
            gr.Audio(label="Example Audio (project)", value=_example_audio_value, type="filepath", interactive=False)

    # wire updates: preset -> description (unknown preset clears the box)
    def _update_desc(preset_name):
        return PRESET_CHARACTERS.get(preset_name, {}).get("description", "")

    preset_select.change(fn=_update_desc, inputs=[preset_select], outputs=[description_box])

    # generation wrapper; kept as a named function because its name is the
    # default API endpoint name clients may already call.
    def _generate(text_in, preset_select, description_box, emotion_select):
        return generate_for_ui(text_in, preset_select, description_box, emotion_select)

    gen_btn.click(fn=_generate,
                  inputs=[text_in, preset_select, description_box, emotion_select],
                  outputs=[audio_player, download_file, gen_logs])
# -------------------------
if __name__ == "__main__":
    demo.launch()