"""Builds a retrieval-augmented generation (RAG) chain around a local or
Hugging Face hosted LLM and an ensemble retriever."""

from textwrap import fill

from langchain import hub
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough

from create_retriever import ensemble_retriever
from model_ret import zephyr_model, llama_model, mistral_model, phi_model, flant5_model

# Hugging Face repo IDs used when running against hosted (non-local) models.
hf_model_map = {
    "Zephyr": "HuggingFaceH4/zephyr-7b-beta",
    "Llama": "NousResearch/Meta-Llama-3-8B",
    "Mistral": "unsloth/mistral-7b-instruct-v0.3",
    "Phi": "microsoft/Phi-3-mini-4k-instruct",
    "Flant5": "google/flan-t5-base",
}


class model_chain:
    """Wraps an LLM and an ensemble retriever into a LangChain RAG pipeline."""

    def __init__(self,
                 model_name_local,
                 model_name_online="Llama",
                 use_local=True,
                 embedding_name="BAAI/bge-base-en-v1.5",
                 splitter_type_dropdown="character",
                 chunk_size_slider=512,
                 chunk_overlap_slider=30,
                 separator_textbox="\n",
                 max_tokens_slider=2048) -> None:
        if use_local:
            # Local model names follow "<quantization>_<model>", e.g. "4bit_Zephyr".
            quantization, self.model_name = model_name_local.split("_", 1)
            model_name_temp = model_name_local
        else:
            # Hosted models are loaded without a local quantization setting.
            quantization = None
            self.model_name = model_name_online
            model_name_temp = hf_model_map[model_name_online]

        # Instantiate the selected LLM wrapper; Flan-T5 also returns its
        # tokenizer and raw model so answers can be generated directly.
        if self.model_name == "Zephyr":
            self.llm = zephyr_model(model_name_temp, quantization, use_local=use_local)
        elif self.model_name == "Llama":
            self.llm = llama_model(model_name_temp, quantization, use_local=use_local)
        elif self.model_name == "Mistral":
            self.llm = mistral_model(model_name_temp, quantization, use_local=use_local)
        elif self.model_name == "Phi":
            self.llm = phi_model(model_name_temp, quantization, use_local=use_local)
        elif self.model_name == "Flant5":
            self.tokenizer, self.model, self.llm = flant5_model(model_name_temp, use_local=use_local)

        self.retriever = ensemble_retriever(embedding_name,
                                            splitter_type=splitter_type_dropdown,
                                            chunk_size=chunk_size_slider,
                                            chunk_overlap=chunk_overlap_slider,
                                            separator=separator_textbox,
                                            max_tokens=max_tokens_slider)

        # Standard LCEL RAG pipeline: retrieve -> format context -> prompt -> LLM -> string.
        prompt = hub.pull("rlm/rag-prompt")
        self.rag_chain = (
            {"context": self.retriever | self.format_docs, "question": RunnablePassthrough()}
            | prompt
            | self.llm
            | StrOutputParser()
        )

    def format_docs(self, docs):
        return "\n\n".join(doc.page_content for doc in docs)

    def rag_chain_ret(self):
        return self.rag_chain

    def ans_ret(self, inp, rag_chain):
        if self.model_name == "Flant5":
            # Flan-T5 is a seq2seq model, so bypass the LCEL chain and
            # generate directly from the top retrieved chunks.
            data = self.retriever.invoke(inp)
            context = ""
            for x in data[:2]:
                context += x.page_content + "\n"
            prompt_text = f"Please answer this question using this context:\n{context}\n{inp}"
            inputs = self.tokenizer(prompt_text, return_tensors="pt")
            # max_new_tokens lifts the ~20-token default generation cap; 256 is an assumed budget.
            outputs = self.model.generate(**inputs, max_new_tokens=256)
            answer = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            return fill(answer, width=100)

        ans = rag_chain.invoke(inp)
        # The rag-prompt completion contains "Answer:"; keep only the text after it
        # ([-1] avoids an IndexError when the marker is absent).
        return ans.split("Answer:")[-1]
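

# ---------------------------------------------------------------------------
# Minimal usage sketch (assumptions: local model names follow the
# "<quantization>_<model>" convention parsed in __init__, e.g. "4bit_Zephyr",
# and model_ret / create_retriever resolve on the import path).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    chain = model_chain("4bit_Zephyr")          # local, quantized Zephyr
    rag = chain.rag_chain_ret()                 # the assembled LCEL pipeline
    print(chain.ans_ret("What is KUET?", rag))

    # Hosted alternative: the local name is ignored and hf_model_map is used.
    online = model_chain("", model_name_online="Mistral", use_local=False)
    print(online.ans_ret("What is KUET?", online.rag_chain_ret()))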