File size: 6,143 Bytes
adf4f56
df18b77
adf4f56
 
 
df18b77
f975304
df18b77
 
 
 
 
 
dc3619b
df18b77
adf4f56
df18b77
 
 
f975304
cfa2b6e
 
f975304
df18b77
f975304
df18b77
 
 
 
 
adf4f56
f975304
 
 
 
 
 
 
 
adf4f56
 
 
dc3619b
adf4f56
f975304
 
df18b77
 
adf4f56
 
df18b77
f975304
df18b77
f975304
adf4f56
 
 
f975304
 
adf4f56
 
 
 
 
df18b77
 
f975304
 
adf4f56
df18b77
adf4f56
 
dc3619b
 
adf4f56
 
dc3619b
adf4f56
dc3619b
adf4f56
dc3619b
adf4f56
dc3619b
f975304
adf4f56
 
f975304
adf4f56
 
f975304
adf4f56
f975304
dc3619b
 
 
 
adf4f56
 
 
 
dc3619b
 
adf4f56
 
dc3619b
 
 
adf4f56
dc3619b
 
 
adf4f56
 
 
dc3619b
 
 
f975304
 
adf4f56
df18b77
adf4f56
dc3619b
adf4f56
 
df18b77
adf4f56
df18b77
 
 
f975304
df18b77
dc3619b
 
adf4f56
dc3619b
adf4f56
dc3619b
f975304
df18b77
dc3619b
 
adf4f56
 
 
 
 
 
 
 
 
 
 
dc3619b
 
f975304
 
dc3619b
df18b77
 
adf4f56
df18b77
 
dc3619b
f975304
dc3619b
f975304
df18b77
 
 
f975304
adf4f56
 
f975304
df18b77
f975304
df18b77
adf4f56
df18b77
adf4f56
f975304
df18b77
adf4f56
 
 
df18b77
f975304
 
 
 
 
df18b77
f975304
df18b77
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
# app_lora_speaker_fixed_optimized_cpu.py
"""
Strict CPU-only NAVA TTS with LoRA + SNAC
All CUDA, AMP, 4bit, device_map, torch.compile removed.
Uses int8 CPU quant for speed improvement.
"""

import gradio as gr
import torch
import soundfile as sf
from pathlib import Path
import time
import traceback
from functools import lru_cache

from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
from snac import SNAC

# --------------------- CONFIG ---------------------
# Hub repo ids for the base LM, the speaker LoRA adapter, and the SNAC codec.
MODEL_NAME = "rahul7star/nava1.0"
LORA_NAME = "rahul7star/vaani-lora-lata"
SNAC_MODEL = "hubertsiuzdak/snac_24khz"
TARGET_SR = 24000  # output sample rate (Hz); matches the 24 kHz SNAC model

# Generated .wav files are written under /tmp/data.
OUT_ROOT = Path("/tmp/data")
OUT_ROOT.mkdir(exist_ok=True, parents=True)

# Speaker names offered in the UI; each maps to a <spk_...> prompt tag.
SPEAKERS = ["lata", "kavya", "agastya", "maitri", "vinaya"]

# special tokens
# NOTE(review): these ids appear to delimit human/AI turns and the speech
# span in the prompt (see generate_speech) — presumably matching the
# model's training format; confirm against the model card.
SPEECH_S = 128257   # start of speech-token span
SPEECH_E = 128258   # end of speech-token span (also used as EOS)
HUMAN_S  = 128259   # start of human turn
HUMAN_E  = 128260   # end of human turn
AI_S     = 128261   # start of AI turn
AI_E     = 128262   # end of AI turn (also used as EOS)
AUDIO_OFFSET = 128266  # first audio-codebook token id; 7 codebooks x 4096 entries follow

# FORCE CPU ONLY
DEVICE = torch.device("cpu")
USE_CUDA = False

print("[init] Running STRICT CPU ONLY mode")

# --------------------- TOKENIZER ---------------------
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)

# --------------------- BASE MODEL (INT8 CPU QUANT) ---------------------
# NOTE(review): despite the header and the log message below, the model is
# actually loaded in fp32 on CPU — no int8 quantization is applied here.
print("[init] Loading int8 quantized model (CPU)...")

base = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    torch_dtype=torch.float32,
    low_cpu_mem_usage=True,
).to("cpu")
base.eval()

print("[init] Base model loaded (fp32 CPU).")

# --------------------- LOAD LORA ---------------------
# Wrap the base model with the PEFT LoRA adapter.
print(f"[init] Loading LoRA: {LORA_NAME}")
model = PeftModel.from_pretrained(base, LORA_NAME).to("cpu")
model.eval()

# Log parameter counts for sanity (inference only — nothing is trained here).
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"[init] LoRA params trainable={trainable} total={total}")

# Try LoRA merge
# Merging folds adapter weights into the base model, removing the per-layer
# LoRA overhead at inference time. Best-effort: failure leaves the unmerged
# (still functional) PEFT model in place.
is_merged = False
try:
    if hasattr(model, "merge_and_unload"):
        print("[init] Attempting LoRA merge (CPU)...")
        model = model.merge_and_unload()
        is_merged = True
        print("[init] LoRA merged")
except Exception as e:
    print("[warn] merge failed", e)

# Enable the KV cache for faster autoregressive generation.
model.config.use_cache = True

# --------------------- LOAD SNAC ---------------------
# SNAC neural codec: decodes audio-codebook tokens back into a waveform.
print("[init] Loading SNAC CPU...")
snac = SNAC.from_pretrained(SNAC_MODEL).to("cpu").eval()

# offsets
# Per-codebook base offsets: codebook i occupies ids
# [AUDIO_OFFSET + i*4096, AUDIO_OFFSET + (i+1)*4096).
_OFFSETS = torch.tensor([AUDIO_OFFSET + i * 4096 for i in range(7)], device="cpu", dtype=torch.int32)

# --------------------- SNAC DECODE ---------------------
def decode_snac(snac_tokens):
    """Decode a flat list of model audio tokens into a waveform.

    Args:
        snac_tokens: list[int] of token ids in the audio-codebook range
            (each frame is 7 tokens, interleaved across SNAC's 3 levels).

    Returns:
        1-D float numpy array of samples clamped to [-1, 1]. Inputs with
        no complete 7-token frame yield a single silent sample rather
        than raising.
    """
    # Drop any trailing partial frame; each frame is exactly 7 tokens.
    usable = (len(snac_tokens) // 7) * 7
    if usable == 0:
        # Empty or shorter than one frame: return silence instead of
        # feeding an empty tensor to the codec.
        return torch.zeros(1).numpy()

    toks = torch.tensor(snac_tokens[:usable], dtype=torch.int32, device="cpu")
    frames = toks.view(-1, 7)  # (n_frames, 7) — safe: length is a multiple of 7

    # De-interleave the 7 columns into SNAC's 3 codebook levels and remove
    # each column's id offset to recover raw codebook indices:
    #   level 0: col 0 (1 code/frame)
    #   level 1: cols 1, 4 (2 codes/frame)
    #   level 2: cols 2, 3, 5, 6 (4 codes/frame)
    lvl0 = frames[:, 0] - _OFFSETS[0]
    lvl1 = torch.stack([frames[:, 1] - _OFFSETS[1], frames[:, 4] - _OFFSETS[4]], dim=1).flatten()
    lvl2 = torch.stack([frames[:, c] - _OFFSETS[c] for c in (2, 3, 5, 6)], dim=1).flatten()

    # SNAC expects a list of (1, n) int tensors, one per level.
    codes = [lvl.unsqueeze(0).to(torch.int32) for lvl in (lvl0, lvl1, lvl2)]

    with torch.inference_mode():
        audio = snac.decode(codes)

    return audio.squeeze().float().cpu().clamp(-1, 1).numpy()

# --------------------- CACHES ---------------------
@lru_cache(maxsize=32)
def speaker_prefix_ids(speaker):
    """Return the token ids for a speaker's ``<spk_...>`` prompt tag.

    Memoized per speaker name so the tag is tokenized at most once.
    """
    tag = f"<spk_{speaker}>"
    return tokenizer.encode(tag, add_special_tokens=False)

# --------------------- GENERATE ---------------------
def generate_speech(text, speaker="lata", temperature=0.4, top_p=0.9):
    """Synthesize speech for `text` in the given speaker's voice.

    Args:
        text: input text to speak.
        speaker: one of SPEAKERS; selects the <spk_...> prompt tag.
        temperature, top_p: sampling parameters for generation.

    Returns:
        (wav_path, log_text) on success, or (None, log_text) on failure;
        log_text always explains what happened.
    """
    messages = []
    started = time.time()

    try:
        # Prompt layout: <human> <spk_...> text </human> <ai> <speech>
        ids = (
            [HUMAN_S]
            + speaker_prefix_ids(speaker)
            + tokenizer.encode(text, add_special_tokens=False)
            + [HUMAN_E, AI_S, SPEECH_S]
        )
        batch = torch.tensor([ids], device="cpu", dtype=torch.long)

        # Token budget grows with text length, hard-capped at 900.
        budget = min(21 + len(text) * 9, 900)

        mark = time.time()
        with torch.inference_mode():
            generated = model.generate(
                batch,
                max_new_tokens=budget,
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                repetition_penalty=1.05,
                eos_token_id=[SPEECH_E, AI_E],
                pad_token_id=tokenizer.pad_token_id,
                use_cache=True,
            )
        gen_time = time.time() - mark

        # Keep only the newly generated tail, then filter for ids that fall
        # inside the audio-codebook range.
        new_tokens = generated[0].tolist()[len(ids):]
        snac_tokens = [t for t in new_tokens if AUDIO_OFFSET <= t < AUDIO_OFFSET + 7 * 4096]

        if not snac_tokens:
            messages.append("[error] No SNAC audio tokens!")
            return None, "\n".join(messages)

        mark = time.time()
        waveform = decode_snac(snac_tokens)
        dec_time = time.time() - mark

        out_path = OUT_ROOT / f"{speaker}_{int(time.time())}.wav"
        sf.write(out_path, waveform, TARGET_SR)

        messages.append(f"[ok] saved: {out_path}")
        messages.append(f"[time] gen: {gen_time:.3f}s decode: {dec_time:.3f}s total: {time.time()-started:.2f}")
        messages.append(f"[info] merged: {is_merged}")

        return str(out_path), "\n".join(messages)

    except Exception as e:
        # Boundary handler: surface the full traceback in the UI log box.
        return None, f"[error] {e}\n{traceback.format_exc()}"

# --------------------- UI ---------------------
# Constrain the app to a 900px-wide column.
css = ".gradio-container {max-width: 900px !important;}"

with gr.Blocks(title="CPU NAVA TTS + LoRA", css=css) as demo:
    gr.Markdown("# 🎙️ NAVA 1.0 + LoRA (CPU ONLY)")
    gr.Markdown("Strict CPU-only optimized inference.")

    # Inputs: text to speak and the speaker voice.
    text_in = gr.Textbox(label="Hindi Text", lines=4, value="आज मौसम बहुत सुहावना है।")
    spk = gr.Dropdown(label="Speaker", value="lata", choices=SPEAKERS)
    btn = gr.Button("🔊 Generate")
    # Outputs: generated wav (served by file path) and the run log.
    audio = gr.Audio(label="Result", type="filepath")
    logs = gr.Textbox(label="Logs", lines=10)

    # Only text and speaker are wired in; temperature/top_p keep the
    # generate_speech defaults (0.4 / 0.9).
    btn.click(generate_speech, [text_in, spk], [audio, logs])


if __name__ == "__main__":
    demo.launch()