# app.py
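"""Maya LoRA TTS demo (CPU-only).

Loads the rahul7star/nava1.0 base model, applies the rahul7star/nava-audio
LoRA adapter, autoregressively generates SNAC audio tokens from a text
prompt, and decodes them to a 24 kHz waveform with hubertsiuzdak/snac_24khz.
"""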
import gradio as gr
import torch
import soundfile as sf
import traceback
from pathlib import Path
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
from snac import SNAC

# -----------------------------
# CONFIG
# -----------------------------
MODEL_NAME = "rahul7star/nava1.0"   # Base Maya model
LORA_NAME  = "rahul7star/nava-audio"  # LoRA adapter
SEQ_LEN = 2048
TARGET_SR = 24000
OUT_ROOT = Path("/tmp/data")
OUT_ROOT.mkdir(parents=True, exist_ok=True)

# -----------------------------
# GENERATE AUDIO (LoRA)
# -----------------------------
def generate_audio_cpu_lora(text: str):
    logs = []
    try:
        DEVICE_CPU = "cpu"

        # Load tokenizer and base model
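        # NOTE: everything loaded in this function (tokenizer, base model,
        # adapter, and the SNAC decoder below) is reloaded on every call;
        # caching these at module level would make repeat requests much faster.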
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
        base_model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            device_map={"": DEVICE_CPU},
            torch_dtype=torch.float32,
            trust_remote_code=True
        )
        logs.append("✅ Loaded base Maya model")

        # Load LoRA adapter from HF Hub
        model = PeftModel.from_pretrained(base_model, LORA_NAME, device_map={"": DEVICE_CPU})
        model.eval()
        logs.append(f"✅ Applied LoRA adapter from {LORA_NAME}")

        # Build prompt: wrap the input text in the model's control tokens
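        # The hardcoded IDs below are assumed to be this model family's
        # special tokens: start/end of human turn (128259/128260), start of
        # AI turn (128261), start of speech (128257), and end of text (128009).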
        soh_token = tokenizer.decode([128259])
        eoh_token = tokenizer.decode([128260])
        soa_token = tokenizer.decode([128261])
        sos_token = tokenizer.decode([128257])
        eot_token = tokenizer.decode([128009])
        bos_token = tokenizer.bos_token
        prompt = soh_token + bos_token + text + eot_token + eoh_token + soa_token + sos_token

        inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE_CPU)

        # Generate tokens
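        # 128258 is treated as the end-of-speech token both here (to stop
        # generation) and below when slicing the SNAC code stream.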
        with torch.inference_mode():
            outputs = model.generate(
                **inputs,
                max_new_tokens=SEQ_LEN,
                temperature=0.4,
                top_p=0.9,
                repetition_penalty=1.1,
                do_sample=True,
                eos_token_id=128258,
                # Fall back to EOS when the tokenizer defines no pad token
                pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
            )
        generated_ids = outputs[0, inputs['input_ids'].shape[1]:].tolist()
        logs.append(f"✅ Generated {len(generated_ids)} token IDs")

        # Extract SNAC codes
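        # Audio codes occupy a contiguous ID block of 7 slots x 4096 codes
        # starting at 128266 (128266 + 7*4096 - 1 = 156937).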
        snac_min, snac_max = 128266, 156937
        eos_id = 128258
        try:
            eos_idx = generated_ids.index(eos_id)
        except ValueError:
            eos_idx = len(generated_ids)
        snac_tokens = [t for t in generated_ids[:eos_idx] if snac_min <= t <= snac_max]

        # Unpack 7-token SNAC frames
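        # Each 7-token frame fans out across SNAC's three codebook levels:
        # slot 0 -> L1 (coarse), slots 1 and 4 -> L2, slots 2/3/5/6 -> L3.
        # The % 4096 strips each slot's offset within the shared ID block.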
        l1, l2, l3 = [], [], []
        frames = len(snac_tokens) // 7
        snac_tokens = snac_tokens[:frames*7]
        for i in range(frames):
            slots = snac_tokens[i*7:(i+1)*7]
            l1.append((slots[0]-128266)%4096)
            l2.extend([(slots[1]-128266)%4096, (slots[4]-128266)%4096])
            l3.extend([(slots[2]-128266)%4096, (slots[3]-128266)%4096,
                       (slots[5]-128266)%4096, (slots[6]-128266)%4096])
        logs.append(f"✅ Unpacked to {len(l1)} L1 frames, {len(l2)} L2 codes, {len(l3)} L3 codes")

        # SNAC decoder
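        # quantizer.from_codes expects one (batch, T) LongTensor per level,
        # with T doubling at each finer level (here: frames, 2x, 4x).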
        snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval().to(DEVICE_CPU)
        codes_tensor = [
            torch.tensor(level, dtype=torch.long, device=DEVICE_CPU).unsqueeze(0)
            for level in (l1, l2, l3)
        ]
        with torch.inference_mode():
            z_q = snac_model.quantizer.from_codes(codes_tensor)
            audio = snac_model.decoder(z_q)[0, 0].cpu().numpy()
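        # Drop the first 2048 samples, which are assumed to be decoder
        # warm-up artifacts.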
        if len(audio) > 2048:
            audio = audio[2048:]

        audio_path = OUT_ROOT / "tts_output_cpu_lora.wav"
        sf.write(audio_path, audio, TARGET_SR)
        logs.append(f"✅ Audio saved: {audio_path}, duration: {len(audio)/TARGET_SR:.2f}s")

        return str(audio_path), "\n".join(logs)

    except Exception as e:
        logs.append(f"[❌] CPU LoRA TTS error: {e}\n{traceback.format_exc()}")
        return None, "\n".join(logs)

# -----------------------------
# GRADIO UI
# -----------------------------
with gr.Blocks() as demo:
    gr.Markdown("# Maya LoRA TTS (CPU)")
    input_text = gr.Textbox(label="Enter text", lines=2, placeholder="Type Hindi text here...")
    run_button = gr.Button("🔊 Generate Audio")
    audio_output = gr.Audio(label="Generated Audio", type="filepath")
    logs_output = gr.Textbox(label="Logs", lines=12)

    run_button.click(
        fn=generate_audio_cpu_lora,
        inputs=[input_text],
        outputs=[audio_output, logs_output]
    )

if __name__ == "__main__":
    demo.launch()