ayushman12 committed on
Commit
0102ace
·
verified ·
1 Parent(s): 375176e

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +65 -168
app.py CHANGED
@@ -1,168 +1,65 @@
1
- import random
2
- import os
3
- import numpy as np
4
- import torch
5
- import gradio as gr
6
- import spaces
7
- from chatterbox.tts_turbo import ChatterboxTurboTTS
8
-
9
-
10
- MODEL = ChatterboxTurboTTS.from_pretrained("cuda" )
11
-
12
- EVENT_TAGS = [
13
- "[clear throat]", "[sigh]", "[shush]", "[cough]", "[groan]",
14
- "[sniff]", "[gasp]", "[chuckle]", "[laugh]"
15
- ]
16
-
17
- CUSTOM_CSS = """
18
- .tag-container {
19
- display: flex !important;
20
- flex-wrap: wrap !important;
21
- gap: 8px !important;
22
- margin-top: 5px !important;
23
- margin-bottom: 10px !important;
24
- border: none !important;
25
- background: transparent !important;
26
- }
27
-
28
- .tag-btn {
29
- min-width: fit-content !important;
30
- width: auto !important;
31
- height: 32px !important;
32
- font-size: 13px !important;
33
- background: #eef2ff !important;
34
- border: 1px solid #c7d2fe !important;
35
- color: #3730a3 !important;
36
- border-radius: 6px !important;
37
- padding: 0 10px !important;
38
- margin: 0 !important;
39
- box-shadow: none !important;
40
- }
41
-
42
- .tag-btn:hover {
43
- background: #c7d2fe !important;
44
- transform: translateY(-1px);
45
- }
46
- """
47
-
48
- INSERT_TAG_JS = """
49
- (tag_val, current_text) => {
50
- const textarea = document.querySelector('#main_textbox textarea');
51
- if (!textarea) return current_text + " " + tag_val;
52
-
53
- const start = textarea.selectionStart;
54
- const end = textarea.selectionEnd;
55
-
56
- let prefix = " ";
57
- let suffix = " ";
58
-
59
- if (start === 0) prefix = "";
60
- else if (current_text[start - 1] === ' ') prefix = "";
61
-
62
- if (end < current_text.length && current_text[end] === ' ') suffix = "";
63
-
64
- return current_text.slice(0, start) + prefix + tag_val + suffix + current_text.slice(end);
65
- }
66
- """
67
-
68
- def set_seed(seed: int):
69
- torch.manual_seed(seed)
70
- torch.cuda.manual_seed(seed)
71
- torch.cuda.manual_seed_all(seed)
72
- random.seed(seed)
73
- np.random.seed(seed)
74
-
75
- @spaces.GPU
76
- def generate(
77
- text,
78
- audio_prompt_path,
79
- temperature,
80
- seed_num,
81
- min_p,
82
- top_p,
83
- top_k,
84
- repetition_penalty,
85
- norm_loudness
86
- ):
87
- if seed_num != 0:
88
- set_seed(int(seed_num))
89
-
90
- wav = MODEL.generate(
91
- text,
92
- audio_prompt_path=audio_prompt_path,
93
- temperature=temperature,
94
- min_p=min_p,
95
- top_p=top_p,
96
- top_k=int(top_k),
97
- repetition_penalty=repetition_penalty,
98
- norm_loudness=norm_loudness,
99
- )
100
-
101
- return (MODEL.sr, wav.squeeze(0).cpu().numpy())
102
-
103
-
104
- with gr.Blocks(title="Chatterbox Turbo") as demo:
105
- gr.Markdown("# ⚡ Chatterbox Turbo")
106
-
107
- with gr.Row():
108
- with gr.Column():
109
- text = gr.Textbox(
110
- value="Oh, that's hilarious! [chuckle] Um anyway, we do have a new model in store. It's the SkyNet T-800 series and it's got basically everything. Including AI integration with ChatGPT and all that jazz. Would you like me to get some prices for you?",
111
- label="Text to synthesize (max chars 300)",
112
- max_lines=5,
113
- elem_id="main_textbox"
114
- )
115
-
116
- with gr.Row(elem_classes=["tag-container"]):
117
- for tag in EVENT_TAGS:
118
- btn = gr.Button(tag, elem_classes=["tag-btn"])
119
- btn.click(
120
- fn=None,
121
- inputs=[btn, text],
122
- outputs=text,
123
- js=INSERT_TAG_JS
124
- )
125
-
126
- ref_wav = gr.Audio(
127
- sources=["upload", "microphone"],
128
- type="filepath",
129
- label="Reference Audio File",
130
- value="https://storage.googleapis.com/chatterbox-demo-samples/turbo/2.wav",
131
- )
132
-
133
- run_btn = gr.Button("Generate ⚡", variant="primary")
134
-
135
- with gr.Column():
136
- audio_output = gr.Audio(label="Output Audio")
137
-
138
- with gr.Accordion("Advanced Options", open=False):
139
- seed_num = gr.Number(value=0, label="Random seed (0 for random)")
140
- temp = gr.Slider(0.05, 2.0, step=.05, label="Temperature", value=0.8)
141
- top_p = gr.Slider(0.00, 1.00, step=0.01, label="Top P", value=0.95)
142
- top_k = gr.Slider(0, 1000, step=10, label="Top K", value=1000)
143
- repetition_penalty = gr.Slider(1.00, 2.00, step=0.05, label="Repetition Penalty", value=1.2)
144
- min_p = gr.Slider(0.00, 1.00, step=0.01, label="Min P (Set to 0 to disable)", value=0.00)
145
- norm_loudness = gr.Checkbox(value=True, label="Normalize Loudness (-27 LUFS)")
146
-
147
- run_btn.click(
148
- fn=generate,
149
- inputs=[
150
- text,
151
- ref_wav,
152
- temp,
153
- seed_num,
154
- min_p,
155
- top_p,
156
- top_k,
157
- repetition_penalty,
158
- norm_loudness,
159
- ],
160
- outputs=audio_output,
161
- )
162
-
163
- if __name__ == "__main__":
164
- demo.queue().launch(
165
- mcp_server=True,
166
- css=CUSTOM_CSS,
167
- ssr_mode=False
168
- )
 
1
+ import os
2
+ import gradio as gr
3
+ import google.generativeai as genai
4
+
5
+ """**How to get Google Gemini API Key?**
6
+
7
+ - Go to https://aistudio.google.com/app/api-keys
8
+ - Click "Create API Key"
9
+ - Copy the API Key for your use
10
+ """
11
+
12
+ GEMINI_API_KEY = os.environ["GEMINI_API_KEY"]  # read from a Space secret / env var; never commit the key itself
13
+ genai.configure(api_key=GEMINI_API_KEY)
14
+
15
+ """
16
+ - Similar to Gemini Model we can also use HuggingFace Transformer Models.
17
+ - Reference links: https://python.langchain.com/docs/integrations/providers/huggingface , https://python.langchain.com/docs/integrations/llms/huggingface_hub.html
18
+
19
+ """
20
+
21
+ # from langchain.llms import HuggingFacePipeline
22
+ # hf = HuggingFacePipeline.from_model_id(
23
+ # model_id="gpt2",
24
+ # task="text-generation",)
25
+
26
+ # Initialize Gemini model
27
+ gemini_model = genai.GenerativeModel('gemini-1.5-flash')
28
+
29
+ # Custom LLM wrapper for Gemini
30
+ class GeminiLLM:
31
+ def __init__(self, model):
32
+ self.model = model
33
+ self.memory_history = []
34
+
35
+ def predict(self, user_message):
36
+ # Build conversation context
37
+ full_prompt = "You are a helpful assistant to answer user queries.\n"
38
+ for msg in self.memory_history:
39
+ full_prompt += f"{msg}\n"
40
+ full_prompt += f"User: {user_message}\nChatbot:"
41
+
42
+ # Generate response
43
+ response = self.model.generate_content(full_prompt)
44
+ answer = response.text
45
+
46
+ # Update memory
47
+ self.memory_history.append(f"User: {user_message}")
48
+ self.memory_history.append(f"Chatbot: {answer}")
49
+
50
+ # Keep only last 10 exchanges
51
+ if len(self.memory_history) > 20:
52
+ self.memory_history = self.memory_history[-20:]
53
+
54
+ return answer
55
+
56
+ llm_chain = GeminiLLM(gemini_model)
57
+
58
+ def get_text_response(user_message,history):
59
+ response = llm_chain.predict(user_message = user_message)
60
+ return response
61
+
62
+ demo = gr.ChatInterface(get_text_response, examples=["How are you doing?","What are your interests?","Which places do you like to visit?"])
63
+
64
+ if __name__ == "__main__":
65
+ demo.launch(debug=True) #To create a public link, set `share=True` in `launch()`. To enable errors and logs, set `debug=True` in `launch()`.