2. Add tools, agent with tools created and functionnal
+ Small clean up on main.py
This commit is contained in:
@@ -1 +1,6 @@
|
|||||||
|
# LLM Parameters
|
||||||
GOOGLE_API_KEY=your_google_api_key_here
|
GOOGLE_API_KEY=your_google_api_key_here
|
||||||
|
GRPC_VERBOSITY=NONE # Options: DEBUG, INFO, ERROR, NONE - to remove errors
|
||||||
|
|
||||||
|
# Tools parameters
|
||||||
|
TAVILY_API_KEY= # Search engine
|
||||||
|
|||||||
@@ -6,6 +6,6 @@ Step 1 : follow the LangGraph tutorial "Build a custom workflow"
|
|||||||
Step 2 : use the base to create a different custom workflow
|
Step 2 : use the base to create a different custom workflow
|
||||||
|
|
||||||
To use, install all dependancies : requirements.txt
|
To use, install all dependancies : requirements.txt
|
||||||
Add your Google Gemini API key in .env
|
Fullfil the .env (all parameters are mandatory)
|
||||||
Run : python -m main
|
Run : python -m main
|
||||||
Exit with "quit", "exit", "q"
|
Exit with "quit", "exit", "q"
|
||||||
38
main.py
38
main.py
@@ -1,19 +1,30 @@
|
|||||||
# from src.graph_builder import build_graph
|
# imports from langgraph
|
||||||
from src.graph_builder import graph_builder
|
|
||||||
from src.nodes import chatbot
|
|
||||||
from langgraph.graph import START, END
|
from langgraph.graph import START, END
|
||||||
|
|
||||||
|
# imports from local files
|
||||||
|
from src.graph_builder import graph_builder
|
||||||
|
from src.nodes import chatbot, BasicToolNode
|
||||||
|
from src.utils import route_tools
|
||||||
|
from src.tools import tool
|
||||||
|
|
||||||
# The first argument is the unique node name
|
|
||||||
# The second argument is the function or object that will be called whenever
|
|
||||||
# the node is used.
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
def main():
|
||||||
graph_builder.add_node("chatbot", chatbot)
|
# The first argument is the unique node name
|
||||||
graph_builder.add_edge(START, "chatbot")
|
# The second argument is the function or object that will be called whenever
|
||||||
graph_builder.add_edge("chatbot", END)
|
# the node is used.
|
||||||
|
graph_builder.add_node("chatbot", chatbot) # Create the chatbot node
|
||||||
|
graph_builder.add_edge("chatbot", END) # Add the edge from chatbot node to END
|
||||||
|
tool_node = BasicToolNode(tools=[tool]) # Intermediate variable creation to create a tool node with the tools
|
||||||
|
graph_builder.add_node("tools", tool_node) # Create a tool node with the tools
|
||||||
|
graph_builder.add_conditional_edges( # Create the conditionnal edges from chatbot to tools
|
||||||
|
"chatbot",
|
||||||
|
route_tools,
|
||||||
|
{"tools": "tools", END: END},
|
||||||
|
)
|
||||||
|
graph_builder.add_edge("tools", "chatbot") # Add the edge from tools node back to chatbot
|
||||||
|
graph_builder.add_edge(START, "chatbot") # Add the edge from START to chatbot node
|
||||||
graph = graph_builder.compile()
|
graph = graph_builder.compile()
|
||||||
|
|
||||||
def stream_graph_updates(user_input: str):
|
def stream_graph_updates(user_input: str):
|
||||||
for event in graph.stream({"messages": [{"role": "user", "content": user_input}]}):
|
for event in graph.stream({"messages": [{"role": "user", "content": user_input}]}):
|
||||||
for value in event.values():
|
for value in event.values():
|
||||||
@@ -32,4 +43,7 @@ if __name__ == "__main__":
|
|||||||
user_input = "What do you know about LangGraph?"
|
user_input = "What do you know about LangGraph?"
|
||||||
print("User: " + user_input)
|
print("User: " + user_input)
|
||||||
stream_graph_updates(user_input)
|
stream_graph_updates(user_input)
|
||||||
break
|
break
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -17,16 +17,3 @@ class State(TypedDict):
|
|||||||
|
|
||||||
graph_builder = StateGraph(State)
|
graph_builder = StateGraph(State)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# from langgraph import Graph
|
|
||||||
# from src.nodes import CustomNode
|
|
||||||
|
|
||||||
# def build_graph():
|
|
||||||
# graph = Graph()
|
|
||||||
# # Add nodes and edges here
|
|
||||||
# node = CustomNode()
|
|
||||||
# graph.add_node(node)
|
|
||||||
# return graph
|
|
||||||
|
|||||||
44
src/nodes.py
44
src/nodes.py
@@ -1,21 +1,47 @@
|
|||||||
import os
|
# General imports
|
||||||
|
import os, json
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
# LangGraph imports
|
||||||
from langchain.chat_models import init_chat_model
|
from langchain.chat_models import init_chat_model
|
||||||
from langgraph.graph import StateGraph
|
from langgraph.graph import StateGraph
|
||||||
|
from langchain_core.messages import ToolMessage
|
||||||
|
|
||||||
|
# Local imports
|
||||||
from src.graph_builder import State
|
from src.graph_builder import State
|
||||||
|
from src.tools import tools
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
#os.environ["GOOGLE_API_KEY"] = "..."
|
|
||||||
|
|
||||||
llm = init_chat_model("google_genai:gemini-2.0-flash")
|
llm = init_chat_model("google_genai:gemini-2.0-flash")
|
||||||
|
|
||||||
|
# Modification: tell the LLM which tools it can call
|
||||||
|
llm_with_tools = llm.bind_tools(tools)
|
||||||
|
|
||||||
def chatbot(state: State):
|
def chatbot(state: State):
|
||||||
return {"messages": [llm.invoke(state["messages"])]}
|
return {"messages": [llm_with_tools.invoke(state["messages"])]}
|
||||||
|
|
||||||
|
class BasicToolNode:
|
||||||
|
"""A node that runs the tools requested in the last AIMessage."""
|
||||||
|
|
||||||
# class CustomNode:
|
def __init__(self, tools: list) -> None:
|
||||||
# def __init__(self):
|
self.tools_by_name = {tool.name: tool for tool in tools}
|
||||||
# pass
|
|
||||||
# def run(self):
|
def __call__(self, inputs: dict):
|
||||||
# print("Custom node running")
|
if messages := inputs.get("messages", []):
|
||||||
|
message = messages[-1]
|
||||||
|
else:
|
||||||
|
raise ValueError("No message found in input")
|
||||||
|
outputs = []
|
||||||
|
for tool_call in message.tool_calls:
|
||||||
|
tool_result = self.tools_by_name[tool_call["name"]].invoke(
|
||||||
|
tool_call["args"]
|
||||||
|
)
|
||||||
|
outputs.append(
|
||||||
|
ToolMessage(
|
||||||
|
content=json.dumps(tool_result),
|
||||||
|
name=tool_call["name"],
|
||||||
|
tool_call_id=tool_call["id"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return {"messages": outputs}
|
||||||
12
src/tools.py
Normal file
12
src/tools.py
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
# General imports
|
||||||
|
import os, json
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
# LangGraph imports
|
||||||
|
from langchain_tavily import TavilySearch
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
tool = TavilySearch(max_results=10)
|
||||||
|
tools = [tool]
|
||||||
|
|
||||||
19
src/utils.py
19
src/utils.py
@@ -1 +1,20 @@
|
|||||||
|
from src.graph_builder import State, START, END
|
||||||
|
|
||||||
# Utility functions for LangGraph project
|
# Utility functions for LangGraph project
|
||||||
|
|
||||||
|
def route_tools( # function that checks for tool_calls in the chatbot's output
|
||||||
|
state: State,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Use in the conditional_edge to route to the ToolNode if the last message
|
||||||
|
has tool calls. Otherwise, route to the end.
|
||||||
|
"""
|
||||||
|
if isinstance(state, list):
|
||||||
|
ai_message = state[-1]
|
||||||
|
elif messages := state.get("messages", []):
|
||||||
|
ai_message = messages[-1]
|
||||||
|
else:
|
||||||
|
raise ValueError(f"No messages found in input state to tool_edge: {state}")
|
||||||
|
if hasattr(ai_message, "tool_calls") and len(ai_message.tool_calls) > 0:
|
||||||
|
return "tools"
|
||||||
|
return END
|
||||||
|
|||||||
Reference in New Issue
Block a user