# app_lora_speaker_fixed_optimized_cpu.py
"""
Strict CPU-only NAVA TTS with LoRA + SNAC.

All CUDA, AMP, 4-bit, device_map, and torch.compile paths are removed.
NOTE: despite the filename, no int8 quantization is applied here -- the
base model is loaded in plain fp32 on CPU (see the loading section).
"""
import gradio as gr
import torch
import soundfile as sf
from pathlib import Path
import time
import traceback
from functools import lru_cache
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
from snac import SNAC

# --------------------- CONFIG ---------------------
MODEL_NAME = "rahul7star/nava1.0"
LORA_NAME = "rahul7star/vaani-lora-lata"
SNAC_MODEL = "hubertsiuzdak/snac_24khz"
TARGET_SR = 24000  # SNAC 24 kHz decoder output rate
OUT_ROOT = Path("/tmp/data")
OUT_ROOT.mkdir(exist_ok=True, parents=True)

SPEAKERS = ["lata", "kavya", "agastya", "maitri", "vinaya"]

# Special token ids of the NAVA turn/audio format.
SPEECH_S = 128257       # start-of-speech segment
SPEECH_E = 128258       # end-of-speech segment
HUMAN_S = 128259        # start of human turn
HUMAN_E = 128260        # end of human turn
AI_S = 128261           # start of assistant turn
AI_E = 128262           # end of assistant turn
AUDIO_OFFSET = 128266   # first audio-codebook token id; 7 codebooks x 4096 entries follow

# FORCE CPU ONLY
DEVICE = torch.device("cpu")
USE_CUDA = False
print("[init] Running STRICT CPU ONLY mode")

# --------------------- TOKENIZER ---------------------
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)

# --------------------- BASE MODEL (FP32 CPU) ---------------------
# No quantization is configured; the model is loaded as plain fp32 on CPU.
print("[init] Loading fp32 model (CPU)...")
base = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    torch_dtype=torch.float32,
    low_cpu_mem_usage=True,
).to("cpu")
base.eval()
print("[init] Base model loaded (fp32 CPU).")

# --------------------- LOAD LORA ---------------------
print(f"[init] Loading LoRA: {LORA_NAME}")
model = PeftModel.from_pretrained(base, LORA_NAME).to("cpu")
model.eval()

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"[init] LoRA params trainable={trainable} total={total}")

# Try to merge the LoRA weights into the base model -- a merged model
# avoids the adapter indirection on every forward pass. Best-effort:
# inference still works (slower) if the merge fails.
is_merged = False
try:
    if hasattr(model, "merge_and_unload"):
        print("[init] Attempting LoRA merge (CPU)...")
        model = model.merge_and_unload()
        is_merged = True
        print("[init] LoRA merged")
except Exception as e:
    print("[warn] merge failed", e)

model.config.use_cache = True

# --------------------- LOAD SNAC ---------------------
print("[init] Loading SNAC CPU...")
snac = SNAC.from_pretrained(SNAC_MODEL).to("cpu").eval()

# Per-codebook id offsets: codebook i occupies [AUDIO_OFFSET + i*4096,
# AUDIO_OFFSET + (i+1)*4096) in the LM vocabulary.
_OFFSETS = torch.tensor(
    [AUDIO_OFFSET + i * 4096 for i in range(7)], device="cpu", dtype=torch.int32
)


# --------------------- SNAC DECODE ---------------------
def decode_snac(snac_tokens):
    """Decode a flat list of LM audio-token ids into a waveform.

    Tokens are grouped into frames of 7; within each frame the columns map
    to SNAC's three codebook levels: col 0 -> level 0, cols 1 and 4 ->
    level 1, cols 2/3/5/6 -> level 2. A trailing partial frame is dropped.

    Returns a 1-D float numpy array in [-1, 1] (a single zero sample if
    there is nothing to decode).
    """
    if not snac_tokens:
        return torch.zeros(1).numpy()
    if len(snac_tokens) % 7 != 0:
        # Drop the incomplete trailing frame.
        snac_tokens = snac_tokens[: (len(snac_tokens) // 7) * 7]
    toks = torch.tensor(snac_tokens, dtype=torch.int32, device="cpu")
    try:
        frames = toks.view(-1, 7)
    except RuntimeError:
        # view() can only fail on a shape mismatch; bail out with silence.
        return torch.zeros(1).numpy()
    lvl0 = frames[:, 0] - _OFFSETS[0]
    lvl1 = torch.stack(
        [frames[:, 1] - _OFFSETS[1], frames[:, 4] - _OFFSETS[4]], dim=1
    ).flatten()
    lvl2_cols = [2, 3, 5, 6]
    lvl2 = torch.stack(
        [frames[:, c] - _OFFSETS[c] for c in lvl2_cols], dim=1
    ).flatten()
    tens = [
        lvl0.unsqueeze(0).to(torch.int32),
        lvl1.unsqueeze(0).to(torch.int32),
        lvl2.unsqueeze(0).to(torch.int32),
    ]
    with torch.inference_mode():
        audio = snac.decode(tens)
    return audio.squeeze().float().cpu().clamp(-1, 1).numpy()


# --------------------- CACHES ---------------------
@lru_cache(maxsize=32)
def speaker_prefix_ids(speaker):
    """Token ids of the speaker prefix inserted before the prompt text.

    Bug fix: the previous version encoded an empty f-string, silently
    ignoring `speaker` (a markup-like tag was probably stripped from the
    source). NOTE(review): the exact speaker-tag format expected by the
    fine-tune should be confirmed against the training data; "{speaker}: "
    is a plain-text placeholder.
    """
    return tokenizer.encode(f"{speaker}: ", add_special_tokens=False)


# --------------------- GENERATE ---------------------
def generate_speech(text, speaker="lata", temperature=0.4, top_p=0.9):
    """Generate speech for `text` in the given speaker's voice.

    Builds the NAVA turn sequence
    [HUMAN_S] prefix text [HUMAN_E, AI_S, SPEECH_S], samples audio tokens
    from the LM, decodes them through SNAC, and writes a wav to OUT_ROOT.

    Returns (wav_path_or_None, log_text).
    """
    logs = []
    t0 = time.time()
    try:
        prefix_ids = speaker_prefix_ids(speaker)
        prompt_ids = tokenizer.encode(text, add_special_tokens=False)
        seq = [HUMAN_S] + prefix_ids + prompt_ids + [HUMAN_E, AI_S, SPEECH_S]
        input_ids = torch.tensor([seq], device="cpu", dtype=torch.long)

        # Heuristic budget: ~9 audio tokens per input character, capped.
        max_new = min(21 + len(text) * 9, 900)

        gen_t0 = time.time()
        with torch.inference_mode():
            out = model.generate(
                input_ids,
                max_new_tokens=max_new,
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                repetition_penalty=1.05,
                eos_token_id=[SPEECH_E, AI_E],
                # Some tokenizers define no pad token; fall back to eos.
                pad_token_id=(
                    tokenizer.pad_token_id
                    if tokenizer.pad_token_id is not None
                    else tokenizer.eos_token_id
                ),
                use_cache=True,
            )
        gen_time = time.time() - gen_t0

        # Keep only newly generated ids that fall in the audio-codebook range.
        gen = out[0].tolist()[len(input_ids[0]):]
        snac_tokens = [t for t in gen if AUDIO_OFFSET <= t < AUDIO_OFFSET + 7 * 4096]
        if not snac_tokens:
            logs.append("[error] No SNAC audio tokens!")
            return None, "\n".join(logs)

        dec_t0 = time.time()
        audio = decode_snac(snac_tokens)
        dec_time = time.time() - dec_t0

        out_path = OUT_ROOT / f"{speaker}_{int(time.time())}.wav"
        sf.write(out_path, audio, TARGET_SR)

        logs.append(f"[ok] saved: {out_path}")
        logs.append(
            f"[time] gen: {gen_time:.3f}s decode: {dec_time:.3f}s "
            f"total: {time.time()-t0:.2f}"
        )
        logs.append(f"[info] merged: {is_merged}")
        return str(out_path), "\n".join(logs)
    except Exception as e:
        return None, f"[error] {e}\n{traceback.format_exc()}"


# --------------------- UI ---------------------
css = ".gradio-container {max-width: 900px !important;}"
with gr.Blocks(title="CPU NAVA TTS + LoRA", css=css) as demo:
    gr.Markdown("# 🎙️ NAVA 1.0 + LoRA (CPU ONLY)")
    gr.Markdown("Strict CPU-only optimized inference.")
    text_in = gr.Textbox(label="Hindi Text", lines=4, value="आज मौसम बहुत सुहावना है।")
    spk = gr.Dropdown(label="Speaker", value="lata", choices=SPEAKERS)
    btn = gr.Button("🔊 Generate")
    audio = gr.Audio(label="Result", type="filepath")
    logs = gr.Textbox(label="Logs", lines=10)
    btn.click(generate_speech, [text_in, spk], [audio, logs])

if __name__ == "__main__":
    demo.launch()