File size: 5,988 Bytes
4cc7463
 
 
 
 
 
 
 
 
 
 
 
 
e679d7f
4cc7463
 
1ab73a3
cb49713
 
 
72e871e
1ab73a3
cb49713
2028b24
cb49713
 
0ca306e
cb49713
fc9e2f8
cb49713
fc9e2f8
cb49713
b48e08c
cc47112
cb49713
 
 
 
4cc7463
 
 
 
 
 
 
 
 
 
 
 
 
 
938df3e
 
4cc7463
 
 
 
 
 
 
 
 
 
 
 
 
1ab73a3
2028b24
 
 
72e871e
1ab73a3
2028b24
 
73b4d6b
2028b24
0ca306e
2028b24
fc9e2f8
2028b24
fc9e2f8
2028b24
b48e08c
cc47112
2028b24
 
 
 
4cc7463
 
 
20a08ff
4cc7463
 
 
 
4d7679c
 
 
 
 
 
 
 
 
4cc7463
 
 
 
 
 
1de2e9a
4cc7463
1de2e9a
 
 
 
 
9a6bf7a
e679d7f
4cc7463
e679d7f
 
 
 
 
 
 
4cc7463
 
 
 
 
 
 
 
 
 
 
 
 
938df3e
 
 
 
 
 
 
4cc7463
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
"""LangGraph Agent"""
import os
from dotenv import load_dotenv
from langgraph.graph import START, StateGraph, MessagesState
from langgraph.prebuilt import tools_condition
from langgraph.prebuilt import ToolNode
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_groq import ChatGroq
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
from langchain_community.vectorstores import SupabaseVectorStore
from langchain_core.messages import SystemMessage, HumanMessage
from langchain.tools.retriever import create_retriever_tool
from supabase.client import Client, create_client
from langchain_litellm import ChatLiteLLM

# Tools
from tools.add import add
from tools.analyze_csv_file import analyze_csv_file
from tools.analyze_excel_file import analyze_excel_file
from tools.arvix_search import arvix_search
from tools.combine_images import combine_images
from tools.divide import divide
from tools.download_file import download_file
from tools.draw_on_image import draw_on_image
from tools.execute_code_multilang import execute_code_multilang
from tools.extract_text_from_image import extract_text_from_image
from tools.generate_image import generate_image
from tools.get_image_properties import get_image_properties
from tools.modulus import modulus
from tools.multiply import multiply
from tools.power import power
from tools.python_code_parser import python_code_parser
from tools.save_read_file import save_read_file
from tools.set_image_properties import set_image_properties
from tools.square_root import square_root
from tools.subtract import subtract
from tools.web_search import web_search
from tools.wiki_search import wiki_search

load_dotenv()

# Read the agent's instructions once, at import time. The path is resolved
# relative to the current working directory, not to this module.
with open("system_prompt.txt", encoding="utf-8") as f:
    system_prompt = f.read()
sys_msg = SystemMessage(content=system_prompt)  # injected by the graph's retriever node

# Build a retriever over reference question/answer pairs stored in Supabase.
# The embedding model presumably has to match the one used to fill the
# ``documents`` table (all-mpnet-base-v2 produces 768-dimensional vectors).
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")  # dim=768
supabase: Client = create_client(
    os.environ.get("SUPABASE_URL"),
    os.environ.get("SUPABASE_SERVICE_KEY"),
)
vector_store = SupabaseVectorStore(
    client=supabase,
    embedding=embeddings,
    table_name="documents",
    query_name="match_documents_langchain",
)
# Retriever exposed as an LLM tool. Bind it to its own name; reassigning
# ``create_retriever_tool`` would shadow the imported factory for the rest of
# the module. Function-calling APIs require identifier-like tool names
# (letters, digits, ``_``/``-``), so the name must not contain spaces.
# NOTE: this tool is not part of ``tools``; the graph's retriever node
# queries ``vector_store`` directly.
question_search_tool = create_retriever_tool(
    retriever=vector_store.as_retriever(),
    name="question_search",
    description="A tool to retrieve similar questions from a vector store.",
)

# Every tool the agent may call. build_graph both binds this list to the LLM
# and wraps it in the ToolNode, so the model can request exactly the tools
# the graph is able to execute. Order is kept alphabetical.
tools = [
    add, analyze_csv_file, analyze_excel_file, arvix_search,
    combine_images, divide, download_file, draw_on_image,
    execute_code_multilang, extract_text_from_image, generate_image,
    get_image_properties, modulus, multiply, power, python_code_parser,
    save_read_file, set_image_properties, square_root, subtract,
    web_search, wiki_search,
]

# Build graph function
def _create_llm(provider: str):
    """Instantiate the (not yet tool-bound) chat model for *provider*.

    Args:
        provider: One of ``"google"``, ``"groq"``, ``"huggingface"`` or
            ``"litellm"``.

    Returns:
        A LangChain chat model instance.

    Raises:
        ValueError: If *provider* is not one of the supported names.
    """
    if provider == "google":
        # Google Gemini
        return ChatGoogleGenerativeAI(
            model="gemini-2.0-flash",
            temperature=0,
            generation_config={
                "temperature": 0.0,
                "max_output_tokens": 2000,
                "candidate_count": 1,
            },
        )
    if provider == "groq":
        # Groq https://console.groq.com/docs/models
        # Alternatives: gemma2-9b-it
        return ChatGroq(model="qwen-qwq-32b", temperature=0)
    if provider == "huggingface":
        return ChatHuggingFace(
            llm=HuggingFaceEndpoint(
                repo_id="mistralai/Mistral-7B-Instruct-v0.2",
                temperature=0,
                # Other models to try:
                # "meta-llama/Llama-2-7b-chat-hf"
                # "google/gemma-7b-it"
                # "mosaicml/mpt-7b-instruct"
                # "tiiuae/falcon-7b-instruct"
                token=os.environ.get("HF_TOKEN"),
            )
        )
    if provider == "litellm":
        # Local Ollama model served through LiteLLM. ChatLiteLLM selects the
        # model via its ``model`` field; ``model_id`` is the smolagents
        # LiteLLMModel spelling and is not a ChatLiteLLM field. Provider-
        # specific options such as Ollama's ``num_ctx`` are forwarded to
        # litellm through ``model_kwargs``.
        return ChatLiteLLM(
            model="ollama_chat/qwen2:7b",
            api_base="http://127.0.0.1:11434",
            model_kwargs={"num_ctx": 8192},
        )
    raise ValueError(
        "Invalid provider. Choose 'google', 'groq', 'huggingface' or 'litellm'."
    )


def _system_messages_first(messages):
    """Return *messages* with every SystemMessage moved to the front.

    ``MessagesState`` merges node output with the ``add_messages`` reducer:
    messages whose id is already in the state are replaced in place, and new
    ones are appended. The retriever node's ``sys_msg`` is new, so it ends up
    *after* the user's question in the stored state. Some chat integrations
    only honour a system message at index 0, so the order is fixed here
    before calling the model. Relative order within each group is preserved,
    so AI/tool message pairs stay adjacent.
    """
    system = [m for m in messages if isinstance(m, SystemMessage)]
    others = [m for m in messages if not isinstance(m, SystemMessage)]
    return system + others


def build_graph(provider: str = "google"):
    """Build and compile the retrieval-augmented, tool-calling agent graph.

    Flow: START -> retriever -> assistant, then assistant -> tools ->
    assistant as long as the model keeps requesting tool calls
    (``tools_condition`` routes to the end otherwise).

    Args:
        provider: LLM backend: ``"google"`` (default), ``"groq"``,
            ``"huggingface"`` or ``"litellm"``.

    Returns:
        The compiled LangGraph graph; invoke it with ``{"messages": [...]}``.

    Raises:
        ValueError: If *provider* is not supported.
    """
    llm_with_tools = _create_llm(provider).bind_tools(tools)

    # Node
    def assistant(state: MessagesState):
        """Call the tool-bound LLM on the conversation, system prompt first."""
        ordered = _system_messages_first(state["messages"])
        return {"messages": [llm_with_tools.invoke(ordered)]}

    def retriever(state: MessagesState):
        """Add the system prompt plus, if found, a similar solved example."""
        # The first message is taken to be the user's question.
        similar_question = vector_store.similarity_search(state["messages"][0].content)

        # Re-sending state["messages"] is a no-op under add_messages (same
        # ids, replaced in place); only sys_msg and example_msg are new.
        if similar_question:
            content = f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}"
            example_msg = HumanMessage(content=content)
            return {"messages": [sys_msg] + state["messages"] + [example_msg]}
        return {"messages": [sys_msg] + state["messages"]}

    builder = StateGraph(MessagesState)
    builder.add_node("retriever", retriever)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(tools))
    builder.add_edge(START, "retriever")
    builder.add_edge("retriever", "assistant")
    builder.add_conditional_edges(
        "assistant",
        tools_condition,
    )
    builder.add_edge("tools", "assistant")

    # Compile graph
    return builder.compile()

# test
if __name__ == "__main__":
    # Manual smoke test: run one question through the default (Google) graph
    # and dump the full resulting conversation.
    question = "When was a picture of Eiffel Tower first added to the Wikipedia page on the Principle of double effect?"
    agent_graph = build_graph()
    final_state = agent_graph.invoke({"messages": [HumanMessage(content=question)]})
    for message in final_state["messages"]:
        message.pretty_print()