import gradio as gr
from transformers import pipeline
import torch
# Auto-select model based on device
if torch.cuda.is_available():
    MODEL_NAME = "tiiuae/falcon-7b-instruct"  # GPU model
    device = 0
else:
    MODEL_NAME = "gpt2"  # CPU fallback
    device = -1

print(f"Loading model: {MODEL_NAME} on {'GPU' if device == 0 else 'CPU'}")
# Load Hugging Face pipeline
generator = pipeline("text-generation", model=MODEL_NAME, device=device)
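# Note: device=0 pins the pipeline to the first CUDA device, device=-1 keeps
# it on CPU. On multi-GPU machines, passing device_map="auto" (which requires
# the accelerate package) would be an alternative to a single device index.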
# Response function: generate the full reply, then yield it in growing
# slices so the UI appears to stream.
def respond(message, history):
    output = generator(
        message,
        max_new_tokens=256,      # caps only newly generated tokens (unlike max_length)
        num_return_sequences=1,
        do_sample=True,
        temperature=0.7,
        truncation=True,         # truncate over-long prompts instead of warning
        return_full_text=False,  # return only the new text, not the echoed prompt
    )[0]["generated_text"]
    # Simulated streaming: yield progressively longer prefixes of the output
    for i in range(0, len(output), 20):
        yield {"role": "assistant", "content": output[: i + 20]}
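
# Note: for true token-by-token streaming, transformers' TextIteratorStreamer
# (fed by a generation call running in a background thread) could replace the
# slicing loop above; the prefix-slicing approach is simpler and model-agnostic.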
# Build the Gradio chat UI (messages format matches the dicts yielded above)
chat = gr.ChatInterface(
    fn=respond,
    type="messages",
    chatbot=gr.Chatbot(height=600, show_copy_button=True, type="messages"),
)
if __name__ == "__main__":
    chat.launch()
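
# Run locally with `python app.py`; on a Hugging Face Space using the Gradio
# SDK, this file is launched automatically and launch() serves on the Space's
# default port.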