"""LangGraph Agent"""
import os
from dotenv import load_dotenv
from langgraph.graph import START, StateGraph, MessagesState
from langgraph.prebuilt import tools_condition
from langgraph.prebuilt import ToolNode
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_groq import ChatGroq
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
from langchain_community.vectorstores import SupabaseVectorStore
from langchain_core.messages import SystemMessage, HumanMessage
from langchain.tools.retriever import create_retriever_tool
from supabase.client import Client, create_client
from langchain_litellm import ChatLiteLLM
# Tools
from tools.add import add
from tools.analyze_csv_file import analyze_csv_file
from tools.analyze_excel_file import analyze_excel_file
from tools.arvix_search import arvix_search
from tools.combine_images import combine_images
from tools.divide import divide
from tools.download_file import download_file
from tools.draw_on_image import draw_on_image
from tools.execute_code_multilang import execute_code_multilang
from tools.extract_text_from_image import extract_text_from_image
from tools.generate_image import generate_image
from tools.get_image_properties import get_image_properties
from tools.modulus import modulus
from tools.multiply import multiply
from tools.power import power
from tools.python_code_parser import python_code_parser
from tools.save_read_file import save_read_file
from tools.set_image_properties import set_image_properties
from tools.square_root import square_root
from tools.subtract import subtract
from tools.web_search import web_search
from tools.wiki_search import wiki_search
load_dotenv()
# load the system prompt from the file
with open("system_prompt.txt", "r", encoding="utf-8") as f:
    system_prompt = f.read()
# System message
sys_msg = SystemMessage(content=system_prompt)
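# The system message is prepended to the conversation inside the retriever node
# of the graph (see build_graph below), not at invoke time.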
# Build a retriever backed by the Supabase vector store
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2") # dim=768
supabase: Client = create_client(
    os.environ.get("SUPABASE_URL"),
    os.environ.get("SUPABASE_SERVICE_KEY"),
)
vector_store = SupabaseVectorStore(
    client=supabase,
    embedding=embeddings,
    table_name="documents",
    query_name="match_documents_langchain",
)
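# The "documents" table and the "match_documents_langchain" query function must
# already exist in the Supabase project; SUPABASE_URL and SUPABASE_SERVICE_KEY
# are read from the environment (loaded above via load_dotenv()).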
retriever_tool = create_retriever_tool(
    retriever=vector_store.as_retriever(),
    name="Question Search",
    description="A tool to retrieve similar questions from a vector store.",
)
tools = [
    add,
    analyze_csv_file,
    analyze_excel_file,
    arvix_search,
    combine_images,
    divide,
    download_file,
    draw_on_image,
    execute_code_multilang,
    extract_text_from_image,
    generate_image,
    get_image_properties,
    modulus,
    multiply,
    power,
    python_code_parser,
    save_read_file,
    set_image_properties,
    square_root,
    subtract,
    web_search,
    wiki_search,
]
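# Note: retriever_tool is not included above; retrieval runs in a dedicated
# "retriever" graph node (see build_graph) rather than through tool calling.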
# Build graph function
def build_graph(provider: str = "google"):
    """Build the LangGraph agent graph for the given LLM provider."""
    if provider == "google":
        # Google Gemini
        llm = ChatGoogleGenerativeAI(
            model="gemini-2.0-flash",
            temperature=0,
            generation_config={
                "temperature": 0.0,
                "max_output_tokens": 2000,
                "candidate_count": 1,
            }
        )
elif provider == "groq":
# Groq https://console.groq.com/docs/models
llm = ChatGroq(model="qwen-qwq-32b", temperature=0) # optional : qwen-qwq-32b gemma2-9b-it
elif provider == "huggingface":
llm = ChatHuggingFace(
llm=HuggingFaceEndpoint(
repo_id="mistralai/Mistral-7B-Instruct-v0.2",
temperature=0,
# Other models to try:
# "meta-llama/Llama-2-7b-chat-hf"
# "google/gemma-7b-it"
# "mosaicml/mpt-7b-instruct"
# "tiiuae/falcon-7b-instruct"
token=os.environ.get("HF_TOKEN"),
)
)
elif provider == "litellm":
# HuggingFace Embeddings
llm = ChatLiteLLM(
model_id="ollama_chat/qwen2:7b",
api_base="http://127.0.0.1:11434",
num_ctx=8192,
)
else:
raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
    # Bind tools to the LLM
    llm_with_tools = llm.bind_tools(tools)
    # Nodes
    def assistant(state: MessagesState):
        """Assistant node: call the tool-enabled LLM on the current messages."""
        return {"messages": [llm_with_tools.invoke(state["messages"])]}
    def retriever(state: MessagesState):
        """Retriever node: prepend the system prompt and, if found, a similar solved question."""
        similar_question = vector_store.similarity_search(state["messages"][0].content)
        if similar_question:
            content = f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}"
            example_msg = HumanMessage(content=content)
            return {"messages": [sys_msg] + state["messages"] + [example_msg]}
        else:
            return {"messages": [sys_msg] + state["messages"]}
    builder = StateGraph(MessagesState)
    builder.add_node("retriever", retriever)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(tools))
    builder.add_edge(START, "retriever")
    builder.add_edge("retriever", "assistant")
    builder.add_conditional_edges(
        "assistant",
        tools_condition,
    )
    builder.add_edge("tools", "assistant")
    # Compile graph
    return builder.compile()
# test
if __name__ == "__main__":
    question = "When was a picture of Eiffel Tower first added to the Wikipedia page on the Principle of double effect?"
    # Build the graph
    graph = build_graph()
    # Run the graph
    messages = [HumanMessage(content=question)]
    messages = graph.invoke({"messages": messages})
    for m in messages["messages"]:
        m.pretty_print()
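    # Hypothetical alternative: run against Groq instead of the default Gemini
    # backend (assumes GROQ_API_KEY is set in the environment):
    # graph = build_graph(provider="groq")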