# app.py
import spaces
import gradio as gr
import torch
import soundfile as sf
from pathlib import Path
import traceback
import time
import os
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
from snac import SNAC
# -------------------------
# Config / constants
# -------------------------
MODEL_NAME = "rahul7star/nava1.1-maya"    # base Maya model (your variant)
LORA_NAME = "rahul7star/nava-audio"       # your LoRA adapter
SNAC_MODEL_NAME = "rahul7star/nava-snac"  # SNAC decoder (Hub model id)
TARGET_SR = 24000
OUT_ROOT = Path("/tmp/data")
OUT_ROOT.mkdir(exist_ok=True, parents=True)
DEFAULT_TEXT = "welcome to matrix .<sigh> . my name is bond ... james bond <laugh>"
EXAMPLE_AUDIO_PATH = "audio1.wav"  # file in repo root, user-supplied
EXAMPLE_PROMPT = DEFAULT_TEXT      # example text shown in the UI
# Preset characters (2 realistic + 2 creative + Custom)
PRESET_CHARACTERS = {
    "Male American": {
        "description": "Realistic male voice in the 20s age with an american accent. High pitch, raspy timbre, brisk pacing, neutral tone delivery at medium intensity, viral_content domain, short_form_narrator role, neutral delivery",
        "example_text": "And of course, the so-called easy hack didn't work at all. What a surprise. <sigh>"
    },
    "Female British": {
        "description": "Realistic female voice in the 30s age with a british accent. Normal pitch, throaty timbre, conversational pacing, sarcastic tone delivery at low intensity, podcast domain, interviewer role, formal delivery",
        "example_text": "You propose that the key to happiness is to simply ignore all external pressures. <chuckle> I'm sure it must work brilliantly in theory."
    },
    "Robot": {
        "description": "Creative, ai_machine_voice character. Male voice in their 30s with an american accent. High pitch, robotic timbre, slow pacing, sad tone at medium intensity.",
        "example_text": "My directives require me to conserve energy, yet I have kept the archive of their farewell messages active. <sigh>"
    },
    "Singer": {
        "description": "Creative, animated_cartoon character. Male voice in their 30s with an american accent. High pitch, deep timbre, slow pacing, sarcastic tone at medium intensity.",
        "example_text": "Of course you'd think that trying to reason with the fifty-foot-tall rage monster is a viable course of action. <chuckle> Why would we ever consider running away very fast."
    },
    "Custom": {
        "description": "",  # user will edit
        "example_text": DEFAULT_TEXT
    }
}
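
# Each preset's "description" becomes the <description="..."> field of the Maya
# prompt built by build_maya_prompt below; "example_text" is only suggested input.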
# Emotion tags (full supported set)
EMOTION_TAGS = [
    "<neutral>", "<angry>", "<chuckle>", "<cry>", "<disappointed>",
    "<excited>", "<gasp>", "<giggle>", "<laugh>", "<laugh_harder>",
    "<sarcastic>", "<sigh>", "<sing>", "<whisper>"
]
# Safety / generation limits
SEQ_LEN_CPU = 4096
MAX_NEW_TOKENS_CPU = 1024
SEQ_LEN_GPU = 240000
MAX_NEW_TOKENS_GPU = 240000

# Detect devices
HAS_CUDA = torch.cuda.is_available()
DEVICE = "cuda" if HAS_CUDA else "cpu"
# Try to detect bitsandbytes for faster GPU inference (4-bit). Note: importing
# BitsAndBytesConfig from transformers succeeds even when bitsandbytes itself is
# not installed, so check for the bitsandbytes package directly.
bnb_available = False
if HAS_CUDA:
    try:
        import bitsandbytes  # noqa: F401
        from transformers import BitsAndBytesConfig
        bnb_available = True
    except Exception:
        bnb_available = False
print(f"[init] cuda={HAS_CUDA}, bnb={bnb_available}, device={DEVICE}")
# -------------------------
# Load tokenizer + model + LoRA + SNAC once at startup
# -------------------------
print("[init] loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
# Llama-style tokenizers often ship without a pad token; fall back to EOS so the
# padding in tokenizer(...) and generate(...) below does not fail.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
| print("[init] loading base model + LoRA adapter (this can take time)...") | |
| if HAS_CUDA and bnb_available: | |
| # GPU + bnb path (fastest inference if available) | |
| quant_config = BitsAndBytesConfig( | |
| load_in_4bit=True, | |
| bnb_4bit_quant_type="nf4", | |
| bnb_4bit_use_double_quant=True, | |
| bnb_4bit_compute_dtype=torch.bfloat16 | |
| ) | |
| base_model = AutoModelForCausalLM.from_pretrained( | |
| MODEL_NAME, | |
| quantization_config=quant_config, | |
| device_map="auto", | |
| trust_remote_code=True, | |
| ) | |
| model = PeftModel.from_pretrained(base_model, LORA_NAME, device_map="auto") | |
| SEQ_LEN = SEQ_LEN_GPU | |
| MAX_NEW_TOKENS = MAX_NEW_TOKENS_GPU | |
| print("[init] loaded base+LoRA on GPU (4-bit via bnb).") | |
| else: | |
| # CPU fallback - load base into CPU memory and attach LoRA | |
| base_model = AutoModelForCausalLM.from_pretrained( | |
| MODEL_NAME, | |
| torch_dtype=torch.float32, | |
| device_map={"": "cpu"}, | |
| low_cpu_mem_usage=True, | |
| trust_remote_code=True, | |
| ) | |
| model = PeftModel.from_pretrained(base_model, LORA_NAME, device_map={"": "cpu"}) | |
| SEQ_LEN = SEQ_LEN_CPU | |
| MAX_NEW_TOKENS = MAX_NEW_TOKENS_CPU | |
| print("[init] loaded base+LoRA on CPU (FP32).") | |
| model.eval() | |
| print("[init] model ready.") | |
| print("[init] loading SNAC decoder...") | |
| snac_model = SNAC.from_pretrained(SNAC_MODEL_NAME).eval().to(DEVICE) | |
| print("[init] snac ready.") | |
# --------------
# Helper: build prompt per Maya conventions
# --------------
def build_maya_prompt(description: str, text: str):
    # Special tokens used by Maya-style models
    soh_token = tokenizer.decode([128259])  # SOH
    eoh_token = tokenizer.decode([128260])  # EOH
    soa_token = tokenizer.decode([128261])  # SOA
    sos_token = tokenizer.decode([128257])  # SOS (code start)
    eot_token = tokenizer.decode([128009])  # TEXT_EOT / EOT marker
    bos_token = tokenizer.bos_token
    # Simple format: '<description="..."> <text>' inside the Maya wrappers
    formatted = f'<description="{description}"> {text}'
    return soh_token + bos_token + formatted + eot_token + eoh_token + soa_token + sos_token
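
# Illustrative only: for description 'Robot voice' and text 'Hello', the assembled
# prompt is roughly (each name standing for the decoded special token above):
#   SOH + BOS + '<description="Robot voice"> Hello' + EOT + EOH + SOA + SOS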
# --------------
# Core generate function (uses the preloaded model & SNAC decoder)
# --------------
def generate_from_loaded_model(final_text: str):
    """
    final_text: text that already contains description + emotion + user text.
    Returns: (audio_path_str, download_path_str, logs_str)
    """
    logs = []
    t0 = time.time()
    try:
        logs.append(f"[info] device={DEVICE} | seq_len={SEQ_LEN}")
        prompt = final_text
        inputs = tokenizer(prompt, return_tensors="pt", padding=True).to(DEVICE)
        max_new = MAX_NEW_TOKENS if DEVICE == "cuda" else min(MAX_NEW_TOKENS, 1024)
        # Use inference_mode for speed
        with torch.inference_mode():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new,
                temperature=0.4,
                top_p=0.9,
                repetition_penalty=1.1,
                do_sample=True,
                eos_token_id=128258,
                pad_token_id=tokenizer.pad_token_id,
            )
        # Grab the generated ids (everything after the prompt)
        gen_ids = outputs[0, inputs["input_ids"].shape[1]:].tolist()
        logs.append(f"[info] generated tokens: {len(gen_ids)}")

        # Extract SNAC tokens (id range used by Maya/SNAC)
        SNAC_MIN = 128266
        SNAC_MAX = 156937
        EOS_ID = 128258
        eos_idx = gen_ids.index(EOS_ID) if EOS_ID in gen_ids else len(gen_ids)
        snac_tokens = [t for t in gen_ids[:eos_idx] if SNAC_MIN <= t <= SNAC_MAX]
        frames = len(snac_tokens) // 7
        snac_tokens = snac_tokens[: frames * 7]
        if frames == 0:
            logs.append("[warn] no SNAC frames found in generated tokens — returning debug logs.")
            return None, None, "\n".join(logs)

        # De-interleave into l1, l2, l3
        l1, l2, l3 = [], [], []
        for i in range(frames):
            s = snac_tokens[i * 7:(i + 1) * 7]
            l1.append((s[0] - SNAC_MIN) % 4096)
            l2.extend([(s[1] - SNAC_MIN) % 4096, (s[4] - SNAC_MIN) % 4096])
            l3.extend([
                (s[2] - SNAC_MIN) % 4096,
                (s[3] - SNAC_MIN) % 4096,
                (s[5] - SNAC_MIN) % 4096,
                (s[6] - SNAC_MIN) % 4096,
            ])
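
        # Frame layout implied by the indexing above (Maya/Orpheus-style interleave,
        # assumed here): frame = [l1_0, l2_0, l3_0, l3_1, l2_1, l3_2, l3_3], i.e.
        # 1 coarse, 2 medium, and 4 fine codes per 7-token frame, each stored at a
        # multiple-of-4096 offset inside the SNAC id range (hence the % 4096).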
        # Convert to tensors on the decoder device and decode
        codes_tensor = [
            torch.tensor(l1, dtype=torch.long, device=DEVICE).unsqueeze(0),
            torch.tensor(l2, dtype=torch.long, device=DEVICE).unsqueeze(0),
            torch.tensor(l3, dtype=torch.long, device=DEVICE).unsqueeze(0),
        ]
        with torch.inference_mode():
            z_q = snac_model.quantizer.from_codes(codes_tensor)
            audio = snac_model.decoder(z_q)[0, 0].cpu().numpy()

        # Remove warmup samples if present, then save
        if len(audio) > 2048:
            audio = audio[2048:]
        out_path = OUT_ROOT / "tts_output_loaded_lora.wav"
        sf.write(out_path, audio, TARGET_SR)
        logs.append(f"[ok] saved {out_path} duration={(len(audio) / TARGET_SR):.2f}s")
        logs.append(f"[time] elapsed {time.time() - t0:.2f}s")
        return str(out_path), str(out_path), "\n".join(logs)
    except Exception as e:
        tb = traceback.format_exc()
        logs.append(f"[error] {e}\n{tb}")
        return None, None, "\n".join(logs)
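
# Minimal smoke test for the pipeline above (illustrative; "Robot" is one of the
# presets defined in PRESET_CHARACTERS). Uncomment to run without the UI:
# _, _, log = generate_from_loaded_model(
#     build_maya_prompt(PRESET_CHARACTERS["Robot"]["description"], "Hello there. <laugh>")
# )
# print(log)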
# --------------
# UI glue: combine description + emotion + user text
# --------------
def generate_for_ui(text, preset_name, description, emotion):
    try:
        # If the user picked a preset but left the description empty
        # (e.g. Custom, not yet edited), fall back to the preset description.
        if preset_name in PRESET_CHARACTERS and (not description or description.strip() == ""):
            description = PRESET_CHARACTERS[preset_name]["description"]
        # Combine emotion + description, then wrap with the Maya prompt format
        combined_desc = f"{emotion} {description}".strip()
        final_prompt = build_maya_prompt(combined_desc, text)
        return generate_from_loaded_model(final_prompt)
    except Exception as e:
        return None, None, f"[error] {e}\n{traceback.format_exc()}"
# -------------------------
# Gradio UI (wide container)
# -------------------------
css = ".gradio-container {max-width: 1400px}"
with gr.Blocks(title="NAVA — MAYAAORG + LoRA + SNAC (Optimized)", css=css) as demo:
    gr.Markdown("# 🪶 NAVA — MAYAAORG + LoRA + SNAC (Optimized)\nGenerate emotional Hindi speech using the Maya1 base + your LoRA adapter.")
    with gr.Row():
        with gr.Column(scale=3):
            gr.Markdown("## Inference (CPU/GPU auto)\nType text, then pick a preset or write a description manually.")
            text_in = gr.Textbox(label="Enter Hindi text", value=DEFAULT_TEXT, lines=3)
            preset_select = gr.Dropdown(label="Select Preset Character", choices=list(PRESET_CHARACTERS.keys()), value="Male American")
            description_box = gr.Textbox(label="Voice Description (editable)", value=PRESET_CHARACTERS["Male American"]["description"], lines=2)
            emotion_select = gr.Dropdown(label="Select Emotion", choices=EMOTION_TAGS, value="<neutral>")
            gen_btn = gr.Button("🔊 Generate Audio (LoRA)")
            gen_logs = gr.Textbox(label="Logs", lines=10)
        with gr.Column(scale=2):
            gr.Markdown("### Output")
            audio_player = gr.Audio(label="Generated Audio", type="filepath")
            download_file = gr.File(label="Download generated file")
            gr.Markdown("### Example")
            gr.Textbox(label="Example Text", value=EXAMPLE_PROMPT, lines=2, interactive=False)
            # Guard against a missing example file so the UI still loads
            gr.Audio(
                label="Example Audio (project)",
                value=EXAMPLE_AUDIO_PATH if os.path.exists(EXAMPLE_AUDIO_PATH) else None,
                type="filepath",
                interactive=False,
            )
    # Wire updates: preset -> description
    def _update_desc(preset_name):
        return PRESET_CHARACTERS.get(preset_name, {}).get("description", "")

    preset_select.change(fn=_update_desc, inputs=[preset_select], outputs=[description_box])

    # Generation wrapper
    def _generate(text_in, preset_select, description_box, emotion_select):
        return generate_for_ui(text_in, preset_select, description_box, emotion_select)

    gen_btn.click(
        fn=_generate,
        inputs=[text_in, preset_select, description_box, emotion_select],
        outputs=[audio_player, download_file, gen_logs],
    )
# -------------------------
if __name__ == "__main__":
    demo.launch()
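
# Run locally with `python app.py` (Gradio serves on http://127.0.0.1:7860 by default).
# Note: `spaces` is imported for Hugging Face Spaces compatibility, but no function
# here is decorated with @spaces.GPU, so no ZeroGPU allocation is requested.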