import gradio as gr
import torch
from transformers import pipeline

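# Pick a model for the available hardware: a 7B instruct-tuned model on GPU,
# or a small model that is practical to run on CPU. (falcon-7b loads in full
# precision by default; pass torch_dtype=torch.float16 to pipeline() if GPU
# memory is tight.)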
if torch.cuda.is_available():
    MODEL_NAME = "tiiuae/falcon-7b-instruct"
    device = 0  # first CUDA device
else:
    MODEL_NAME = "gpt2"
    device = -1  # CPU

print(f"Loading model: {MODEL_NAME} on {'GPU' if device == 0 else 'CPU'}")
generator = pipeline("text-generation", model=MODEL_NAME, device=device)


def respond(message, history):
    """Generate a reply to the latest user message and stream it back."""
    # `history` is supplied by ChatInterface but is not folded into the
    # prompt here; the model sees only the current message.
    output = generator(
        message,
        max_new_tokens=256,
        num_return_sequences=1,
        do_sample=True,
        temperature=0.7,
        truncation=True,
        return_full_text=False,  # drop the prompt so the reply isn't echoed
    )[0]["generated_text"]

    # Yield progressively longer prefixes so the UI streams the reply in
    # 20-character steps.
    for i in range(0, len(output), 20):
        yield {"role": "assistant", "content": output[: i + 20]}
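

# ChatInterface drives respond() as a generator for streaming output;
# type="messages" means messages are exchanged as openai-style
# {"role": ..., "content": ...} dicts, matching what respond() yields.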
chat = gr.ChatInterface(
    fn=respond,
    type="messages",
    chatbot=gr.Chatbot(height=600, show_copy_button=True, type="messages"),
)

if __name__ == "__main__":
    # launch() starts a local Gradio server (http://127.0.0.1:7860 by default).
    chat.launch()