# veen-1.0 / lora.py
# Last change by rahul7star: "Update lora.py" (commit adf4f56)
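# app_lora_speaker_fixed_optimized_cpu.py
"""
Strict CPU-only NAVA TTS with LoRA + SNAC.

All CUDA, AMP, 4-bit, device_map and torch.compile code paths have been removed.
The base model is loaded in fp32 on CPU (no int8 quantization is applied), and
the LoRA adapter is merged into the base weights when possible.
"""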
import time
import traceback
import uuid
from functools import lru_cache
from pathlib import Path

import gradio as gr
import soundfile as sf
import torch
from peft import PeftModel
from snac import SNAC
from transformers import AutoTokenizer, AutoModelForCausalLM
# --------------------- CONFIG ---------------------
# Hugging Face repos for the base LM, the speaker LoRA adapter and the SNAC codec.
MODEL_NAME = "rahul7star/nava1.0"
LORA_NAME = "rahul7star/vaani-lora-lata"
SNAC_MODEL = "hubertsiuzdak/snac_24khz"
TARGET_SR = 24000  # sample rate written into the output WAV files

# Generated audio is written under this directory (created at import time).
OUT_ROOT = Path("/tmp") / "data"
OUT_ROOT.mkdir(parents=True, exist_ok=True)

SPEAKERS = "lata kavya agastya maitri vinaya".split()

# Special token ids. They frame the prompt as
# [HUMAN_S, <speaker tag>, <text>, HUMAN_E, AI_S, SPEECH_S], and
# SPEECH_E / AI_E are used as generation stop tokens.
SPEECH_S, SPEECH_E = 128257, 128258
HUMAN_S, HUMAN_E = 128259, 128260
AI_S, AI_E = 128261, 128262
# First id of the audio-code range. Seven consecutive 4096-wide sub-ranges
# follow, one per position inside a 7-token SNAC frame.
AUDIO_OFFSET = 128266

# FORCE CPU ONLY: the code in this file also passes device="cpu" explicitly.
DEVICE = torch.device("cpu")
USE_CUDA = False
print("[init] Running STRICT CPU ONLY mode")
# --------------------- TOKENIZER ---------------------
# NOTE(security): trust_remote_code=True runs Python code shipped with the
# model repo; only point MODEL_NAME at repositories you trust.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)

# --------------------- BASE MODEL (FP32 CPU) ---------------------
# The weights are loaded in plain fp32. No int8 quantization is applied here,
# so the log messages say fp32 rather than claiming a quantized load.
print("[init] Loading base model (fp32, CPU)...")
base = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    torch_dtype=torch.float32,
    low_cpu_mem_usage=True,
).to("cpu")
base.eval()  # inference only: disable dropout etc.
print("[init] Base model loaded (fp32 CPU).")
# --------------------- LOAD LORA ---------------------
def _count_parameters(module):
    """Return ``(trainable, total)`` element counts over ``module.parameters()``."""
    trainable_count = 0
    total_count = 0
    for tensor in module.parameters():
        n_elems = tensor.numel()
        total_count += n_elems
        if tensor.requires_grad:
            trainable_count += n_elems
    return trainable_count, total_count


print(f"[init] Loading LoRA: {LORA_NAME}")
model = PeftModel.from_pretrained(base, LORA_NAME).to("cpu")
model.eval()
trainable, total = _count_parameters(model)
print(f"[init] LoRA params trainable={trainable} total={total}")

# Best-effort merge: fold the adapter into the base weights when the PEFT
# wrapper supports it. On failure, keep the unmerged PeftModel and only warn.
# is_merged is reported in the generation logs.
is_merged = False
try:
    if hasattr(model, "merge_and_unload"):
        print("[init] Attempting LoRA merge (CPU)...")
        model = model.merge_and_unload()
        is_merged = True
        print("[init] LoRA merged")
except Exception as merge_err:
    print("[warn] merge failed", merge_err)
model.config.use_cache = True  # reuse the KV cache during generate()
# --------------------- LOAD SNAC ---------------------
print("[init] Loading SNAC CPU...")
snac = SNAC.from_pretrained(SNAC_MODEL)
snac = snac.to("cpu")
snac.eval()

# _OFFSETS[i] is the id shift subtracted from position i of every 7-token
# SNAC frame to recover the raw codebook index: AUDIO_OFFSET + i * 4096.
_OFFSETS = AUDIO_OFFSET + 4096 * torch.arange(7, dtype=torch.int32, device="cpu")
# --------------------- SNAC DECODE ---------------------
def decode_snac(snac_tokens):
    """Decode a flat list of SNAC audio token ids into a mono waveform.

    Tokens are consumed in frames of 7. Within a frame:

    * position 0 is the level-0 code;
    * positions 1 and 4 are the two level-1 codes;
    * positions 2, 3, 5 and 6 are the four level-2 codes.

    Position ``i`` is stored shifted by ``_OFFSETS[i]`` (``AUDIO_OFFSET + i * 4096``).

    Args:
        snac_tokens: list of audio token ids produced by the language model.

    Returns:
        A float32 numpy array, squeezed and clamped to [-1, 1]. It is a single
        zero sample when there is no complete, in-range frame to decode.
    """
    # A trailing partial frame cannot be decoded and is dropped. Checking for
    # zero complete frames up front replaces the old bare ``except`` around
    # ``view(-1, 7)``, which could only fail on an empty tensor.
    n_frames = len(snac_tokens) // 7 if snac_tokens else 0
    if n_frames == 0:
        return torch.zeros(1).numpy()

    frames = torch.tensor(
        snac_tokens[: n_frames * 7], dtype=torch.int32, device="cpu"
    ).view(n_frames, 7)

    # Remove the per-position offsets: (n_frames, 7) - (7,) broadcasts per column.
    codes = frames - _OFFSETS

    # Each code must index a 4096-entry codebook. The caller only filters the
    # overall audio range, so a token can land in the wrong position and yield
    # a negative or too-large index, which crashes the codec's embedding
    # lookup. Drop such frames instead of failing the whole request.
    in_range = ((codes >= 0) & (codes < 4096)).all(dim=1)
    codes = codes[in_range]
    if codes.shape[0] == 0:
        return torch.zeros(1).numpy()

    # Row-major reshape interleaves per frame, e.g. level 1 becomes
    # [f0c1, f0c4, f1c1, f1c4, ...]. That matches the lengths N / 2N / 4N
    # of the three code levels.
    levels = [
        codes[:, 0].unsqueeze(0),
        codes[:, [1, 4]].reshape(1, -1),
        codes[:, [2, 3, 5, 6]].reshape(1, -1),
    ]
    with torch.inference_mode():
        audio = snac.decode(levels)
    return audio.squeeze().float().cpu().clamp(-1, 1).numpy()
# --------------------- CACHES ---------------------
@lru_cache(maxsize=32)
def speaker_prefix_ids(speaker):
    """Return the token ids of the ``<spk_{speaker}>`` speaker tag.

    The result is memoized (LRU, at most 32 distinct speakers), so repeated
    requests for the same speaker skip re-tokenization. Cache hits return the
    same list object, so callers must treat it as read-only.
    """
    speaker_tag = "<spk_{}>".format(speaker)
    ids = tokenizer.encode(speaker_tag, add_special_tokens=False)
    return ids
# --------------------- GENERATE ---------------------
def generate_speech(text, speaker="lata", temperature=0.4, top_p=0.9):
    """Synthesize *text* in the voice of *speaker* and save it as a WAV file.

    Args:
        text: text to speak. Empty or whitespace-only input is rejected.
        speaker: speaker name. It is tokenized as a ``<spk_{speaker}>`` prefix
            and also used, sanitized, in the output filename.
        temperature: sampling temperature passed to ``model.generate``.
        top_p: nucleus-sampling threshold passed to ``model.generate``.

    Returns:
        ``(wav_path, log_text)`` on success and ``(None, log_text)`` on failure.
        This is the Gradio callback boundary, so errors are reported in the
        log text instead of being raised.
    """
    if not text or not text.strip():
        return None, "[error] Empty text: nothing to synthesize."

    logs = []
    t0 = time.time()
    try:
        prefix_ids = speaker_prefix_ids(speaker)
        prompt_ids = tokenizer.encode(text, add_special_tokens=False)
        seq = [HUMAN_S] + prefix_ids + prompt_ids + [HUMAN_E, AI_S, SPEECH_S]
        input_ids = torch.tensor([seq], device="cpu", dtype=torch.long)
        # Token budget grows with text length and is capped to bound CPU time.
        max_new = min(21 + len(text) * 9, 900)

        gen_t0 = time.time()
        with torch.inference_mode():
            out = model.generate(
                input_ids,
                max_new_tokens=max_new,
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                repetition_penalty=1.05,
                eos_token_id=[SPEECH_E, AI_E],
                pad_token_id=tokenizer.pad_token_id,
                use_cache=True,
            )
        gen_time = time.time() - gen_t0

        # Keep only the newly generated ids that fall inside the audio-code range.
        gen = out[0, input_ids.shape[1]:].tolist()
        snac_tokens = [t for t in gen if AUDIO_OFFSET <= t < AUDIO_OFFSET + 7 * 4096]
        if not snac_tokens:
            logs.append("[error] No SNAC audio tokens!")
            return None, "\n".join(logs)

        dec_t0 = time.time()
        audio = decode_snac(snac_tokens)
        dec_time = time.time() - dec_t0
        # decode_snac falls back to a single zero sample when nothing is
        # decodable. Don't save that as a "successful" result.
        if audio.size <= 1:
            logs.append("[error] SNAC decode produced no audio!")
            return None, "\n".join(logs)

        # The speaker name can come from the client, so keep only
        # filename-safe characters; that prevents escaping OUT_ROOT. The
        # random suffix stops concurrent / same-second requests from
        # overwriting each other's files.
        safe_speaker = "".join(c for c in str(speaker) if c.isalnum() or c in "-_") or "speaker"
        out_path = OUT_ROOT / f"{safe_speaker}_{int(time.time())}_{uuid.uuid4().hex[:8]}.wav"
        sf.write(str(out_path), audio, TARGET_SR)

        logs.append(f"[ok] saved: {out_path}")
        logs.append(f"[time] gen: {gen_time:.3f}s decode: {dec_time:.3f}s total: {time.time()-t0:.2f}")
        logs.append(f"[info] merged: {is_merged}")
        return str(out_path), "\n".join(logs)
    except Exception as e:
        return None, f"[error] {e}\n{traceback.format_exc()}"
# --------------------- UI ---------------------
# Gradio front-end: text and speaker in; generated WAV path and log text out.
# Components are rendered in the order they are created below.
css = ".gradio-container {max-width: 900px !important;}"  # cap the page width
with gr.Blocks(title="CPU NAVA TTS + LoRA", css=css) as demo:
    gr.Markdown("# 🎙️ NAVA 1.0 + LoRA (CPU ONLY)")
    gr.Markdown("Strict CPU-only optimized inference.")
    # Default value is Hindi for "The weather is very pleasant today."
    text_in = gr.Textbox(label="Hindi Text", lines=4, value="आज मौसम बहुत सुहावना है।")
    spk = gr.Dropdown(label="Speaker", value="lata", choices=SPEAKERS)
    btn = gr.Button("🔊 Generate")
    # type="filepath" because generate_speech returns the path of the saved WAV.
    audio = gr.Audio(label="Result", type="filepath")
    logs = gr.Textbox(label="Logs", lines=10)
    # Only text and speaker are wired in; temperature and top_p use
    # generate_speech's default values.
    btn.click(generate_speech, [text_in, spk], [audio, logs])
if __name__ == "__main__":
    demo.launch()