# app_lora_speaker_fixed_optimized_cpu.py
"""
Strict CPU-only NAVA TTS with LoRA + SNAC.
All CUDA, AMP, 4-bit, device_map, and torch.compile paths removed.
Optional dynamic int8 quantization for CPU speed (off by default, see below).
"""
import time
import traceback
from functools import lru_cache
from pathlib import Path

import gradio as gr
import soundfile as sf
import torch
from peft import PeftModel
from snac import SNAC
from transformers import AutoTokenizer, AutoModelForCausalLM
# --------------------- CONFIG ---------------------
MODEL_NAME = "rahul7star/nava1.0"
LORA_NAME = "rahul7star/vaani-lora-lata"
SNAC_MODEL = "hubertsiuzdak/snac_24khz"
TARGET_SR = 24000
OUT_ROOT = Path("/tmp/data")
OUT_ROOT.mkdir(exist_ok=True, parents=True)
SPEAKERS = ["lata", "kavya", "agastya", "maitri", "vinaya"]

# special tokens
SPEECH_S = 128257
SPEECH_E = 128258
HUMAN_S = 128259
HUMAN_E = 128260
AI_S = 128261
AI_E = 128262
AUDIO_OFFSET = 128266
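# These IDs appear to follow the Orpheus-style convention on a Llama-3 vocab:
# marker tokens sit just past the 128256-entry text vocabulary, and audio codes
# occupy 7 contiguous blocks of 4096 IDs starting at AUDIO_OFFSET (one block
# per position in a SNAC frame). The exact layout is inferred from the IDs above.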
# FORCE CPU ONLY
DEVICE = torch.device("cpu")
USE_CUDA = False
print("[init] Running STRICT CPU ONLY mode")
# --------------------- TOKENIZER ---------------------
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
# --------------------- BASE MODEL (CPU) ---------------------
print("[init] Loading base model (fp32 CPU)...")
base = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    torch_dtype=torch.float32,
    low_cpu_mem_usage=True,
).to("cpu")
base.eval()
print("[init] Base model loaded (fp32 CPU).")
# --------------------- LOAD LORA ---------------------
print(f"[init] Loading LoRA: {LORA_NAME}")
model = PeftModel.from_pretrained(base, LORA_NAME).to("cpu")
model.eval()
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"[init] LoRA params trainable={trainable} total={total}")
# Try LoRA merge
is_merged = False
try:
    if hasattr(model, "merge_and_unload"):
        print("[init] Attempting LoRA merge (CPU)...")
        model = model.merge_and_unload()
        is_merged = True
        print("[init] LoRA merged")
except Exception as e:
    print("[warn] merge failed:", e)
model.config.use_cache = True
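# The module docstring mentions int8 CPU quantization, but the model above is
# plain fp32. Below is a minimal, optional sketch using PyTorch dynamic
# quantization on Linear layers; ENABLE_INT8 is a hypothetical flag (not in the
# original script), off by default, and it assumes the merged model tolerates
# quantized matmuls inside generate().
ENABLE_INT8 = False
if ENABLE_INT8 and is_merged:
    model = torch.ao.quantization.quantize_dynamic(
        model, {torch.nn.Linear}, dtype=torch.qint8
    )
    print("[init] Applied dynamic int8 quantization to Linear layers.")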
# --------------------- LOAD SNAC ---------------------
print("[init] Loading SNAC (CPU)...")
snac = SNAC.from_pretrained(SNAC_MODEL).to("cpu").eval()
# offsets
_OFFSETS = torch.tensor([AUDIO_OFFSET + i * 4096 for i in range(7)], device="cpu", dtype=torch.int32)
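# Each 7-token frame interleaves SNAC's three codebook levels. Assuming the
# usual 24 kHz SNAC layout used by LLM-TTS heads:
#   position 0           -> level 0 (coarse, 1 code per frame)
#   positions 1 and 4    -> level 1 (2 codes per frame)
#   positions 2, 3, 5, 6 -> level 2 (fine, 4 codes per frame)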
# --------------------- SNAC DECODE ---------------------
def decode_snac(snac_tokens):
    """Convert a flat list of audio-token IDs into a 24 kHz waveform via SNAC."""
    if not snac_tokens:
        return torch.zeros(1).numpy()
    # drop any trailing partial frame (each frame is 7 tokens)
    if len(snac_tokens) % 7 != 0:
        snac_tokens = snac_tokens[: (len(snac_tokens) // 7) * 7]
    if not snac_tokens:
        return torch.zeros(1).numpy()
    toks = torch.tensor(snac_tokens, dtype=torch.int32, device="cpu")
    frames = toks.view(-1, 7)
    lvl0 = frames[:, 0] - _OFFSETS[0]
    lvl1 = torch.stack([frames[:, 1] - _OFFSETS[1], frames[:, 4] - _OFFSETS[4]], dim=1).flatten()
    lvl2_cols = [2, 3, 5, 6]
    lvl2 = torch.stack([frames[:, c] - _OFFSETS[c] for c in lvl2_cols], dim=1).flatten()
    tens = [
        lvl0.unsqueeze(0).to(torch.int32),
        lvl1.unsqueeze(0).to(torch.int32),
        lvl2.unsqueeze(0).to(torch.int32),
    ]
    with torch.inference_mode():
        audio = snac.decode(tens)
    return audio.squeeze().float().cpu().clamp(-1, 1).numpy()
# --------------------- CACHES ---------------------
@lru_cache(maxsize=None)
def speaker_prefix_ids(speaker):
    return tokenizer.encode(f"<spk_{speaker}>", add_special_tokens=False)
# --------------------- GENERATE ---------------------
def generate_speech(text, speaker="lata", temperature=0.4, top_p=0.9):
    logs = []
    t0 = time.time()
    try:
        prefix_ids = speaker_prefix_ids(speaker)
        prompt_ids = tokenizer.encode(text, add_special_tokens=False)
        seq = [HUMAN_S] + prefix_ids + prompt_ids + [HUMAN_E, AI_S, SPEECH_S]
        input_ids = torch.tensor([seq], device="cpu", dtype=torch.long)
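        # Heuristic token budget: ~9 audio tokens per input character (a bit
        # over one 7-token SNAC frame each), capped at 900 to bound CPU latency.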
        max_new = min(21 + len(text) * 9, 900)
        gen_t0 = time.time()
        with torch.inference_mode():
            out = model.generate(
                input_ids,
                max_new_tokens=max_new,
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                repetition_penalty=1.05,
                eos_token_id=[SPEECH_E, AI_E],
                pad_token_id=tokenizer.pad_token_id
                if tokenizer.pad_token_id is not None
                else tokenizer.eos_token_id,
                use_cache=True,
            )
        gen_time = time.time() - gen_t0
        gen = out[0].tolist()[len(input_ids[0]):]
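        # Keep only IDs inside the 7 x 4096 audio-code range; this drops the
        # SPEECH_E / AI_E terminators along with any stray text tokens.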
        snac_tokens = [t for t in gen if AUDIO_OFFSET <= t < AUDIO_OFFSET + 7 * 4096]
        if not snac_tokens:
            logs.append("[error] No SNAC audio tokens!")
            return None, "\n".join(logs)
        dec_t0 = time.time()
        audio = decode_snac(snac_tokens)
        dec_time = time.time() - dec_t0
        out_path = OUT_ROOT / f"{speaker}_{int(time.time())}.wav"
        sf.write(out_path, audio, TARGET_SR)
        logs.append(f"[ok] saved: {out_path}")
        logs.append(f"[time] gen: {gen_time:.3f}s decode: {dec_time:.3f}s total: {time.time()-t0:.2f}s")
        logs.append(f"[info] merged: {is_merged}")
        return str(out_path), "\n".join(logs)
    except Exception as e:
        return None, f"[error] {e}\n{traceback.format_exc()}"
# --------------------- UI ---------------------
css = ".gradio-container {max-width: 900px !important;}"
with gr.Blocks(title="CPU NAVA TTS + LoRA", css=css) as demo:
    gr.Markdown("# 🎙️ NAVA 1.0 + LoRA (CPU ONLY)")
    gr.Markdown("Strict CPU-only optimized inference.")
    text_in = gr.Textbox(label="Hindi Text", lines=4, value="आज मौसम बहुत सुहावना है।")
    spk = gr.Dropdown(label="Speaker", value="lata", choices=SPEAKERS)
    btn = gr.Button("🔊 Generate")
    audio = gr.Audio(label="Result", type="filepath")
    logs = gr.Textbox(label="Logs", lines=10)
    btn.click(generate_speech, [text_in, spk], [audio, logs])
if __name__ == "__main__":
    demo.launch()