import gradio as gr
import base64
import os
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from langchain.vectorstores import Chroma
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.prompts import PromptTemplate
from langchain.llms import HuggingFacePipeline
from langchain.chains import RetrievalQA
from langchain.document_loaders import TextLoader, DirectoryLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
# === Load and Embed Documents ===
loader = DirectoryLoader(
    "courses",
    glob="**/*.txt",
    loader_cls=TextLoader
)
raw_docs = loader.load()

text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=700,
    chunk_overlap=100,
    separators=["\n###", "\n##", "\n\n", "\n", ".", " "]
)
docs = text_splitter.split_documents(raw_docs)

embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
vectorstore = Chroma.from_documents(docs, embedding=embedding_model)
retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 4})
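# Note: with no persist_directory set, the Chroma index is kept in memory and
# rebuilt from the "courses" folder each time the app starts.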
# === Prompt Template ===
custom_prompt_template = """
You are a helpful and knowledgeable course advisor at the University of Hertfordshire. Answer the student's question using only the information provided in the context below.
If the context does not contain the answer, politely respond that the information is not available.
Context:
{context}
Question:
{question}
Answer:
"""
prompt = PromptTemplate(
    input_variables=["context", "question"],
    template=custom_prompt_template
)
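# The "stuff" RetrievalQA chain set up below fills {context} with the k=4
# retrieved chunks and {question} with the student's query before calling the LLM.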
# === Load Falcon Model ===
model_name = "tiiuae/Falcon3-1B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)

generator = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=256,
    do_sample=False,        # greedy decoding for deterministic answers
    return_full_text=False  # return only the generated continuation, not the prompt
)

llm = HuggingFacePipeline(pipeline=generator)
# === Setup Retrieval QA Chain ===
qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    retriever=retriever,
    chain_type="stuff",
    chain_type_kwargs={"prompt": prompt}
)
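# Optional sanity check (not part of the UI flow): the chain can also be queried
# directly while debugging. The question below is purely illustrative.
# print(qa_chain.run("What modules are covered in the MSc Data Science course?"))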
# === Avatar and Crest ===
avatar_img = "images/UH.png"   # Avatar shown beside bot messages
logo = "images/UH Crest.png"   # Crest image
# === Chat Logic with Course Memory ===
def chat_with_bot(message, history, course_state):
    lower_msg = message.lower()

    # Try to detect the course from the first question
    if "msc" in lower_msg:
        course_state = message.strip()  # Store it for later use
        full_query = f"For the course '{course_state}': {message}"
    elif "change course to" in lower_msg:
        # Match case-insensitively, but keep the original casing of the course name
        idx = lower_msg.index("change course to") + len("change course to")
        course_state = message[idx:].strip()
        response = f"🔁 Course changed. Now answering based on: **{course_state}**"
        history.append((message, response))
        return "", history, course_state
    elif course_state:
        full_query = f"For the course '{course_state}': {message}"
    else:
        full_query = message  # No course memory yet

    try:
        raw_output = qa_chain.run(full_query)
        response = raw_output.split("Answer:")[-1].strip()
        response = response.replace("<|assistant|>", "").strip()
    except Exception as e:
        response = f"⚠️ An error occurred: {str(e)}"

    history.append((message, response))
    return "", history, course_state
# === Build Gradio UI ===
initial_message = (
    "👋 Welcome! I'm your Assistant for the University of Hertfordshire.\n"
    "Struggling to find something on our website?\n"
    "Want to know anything about your MSc course?\n\n"
    "Simply ask and we can get started!\n\n"
    "⚠️ Please avoid sharing personal details in this chat.\n"
    "If personal details are ever needed, we’ll always ask for consent first."
)
with gr.Blocks(title="🎓 UH Academic Advisor", css="""
.message.user {
    background-color: #d2e5ff !important;
}
""") as demo:
    # Convert crest image to base64 so it can be embedded inline in the header
    with open(logo, "rb") as image_file:
        encoded_string = base64.b64encode(image_file.read()).decode("utf-8")

    # Logo header
    gr.Markdown(f"""
    <div style='display: flex; align-items: center; gap: 6px; line-height: 1;'>
        <img src="data:image/png;base64,{encoded_string}" style="height: 30px; margin-bottom: 2px;">
        <h1 style='font-size: 18px; margin: 0;'>University of Hertfordshire Course Advisor Chatbot</h1>
    </div>
    """)
    chatbot = gr.Chatbot(
        avatar_images=(None, avatar_img),
        value=[(None, initial_message)],  # show the welcome text as a bot message
        show_copy_button=True
    )
| state = gr.State("") # Keeps course memory in-session | |
| with gr.Row(): | |
| msg = gr.Textbox(placeholder="Ask a question...", lines=1, scale=5) | |
| send_btn = gr.Button(" Send", scale=1) | |
| msg.submit(chat_with_bot, [msg, chatbot, state], [msg, chatbot, state]) | |
| send_btn.click(chat_with_bot, [msg, chatbot, state], [msg, chatbot, state]) | |
# === Launch ===
demo.launch()