2. Add tools, agent with tools created and functionnal
+ Small clean up on main.py
This commit is contained in:
@@ -1 +1,6 @@
|
||||
# LLM Parameters
|
||||
GOOGLE_API_KEY=your_google_api_key_here
|
||||
GRPC_VERBOSITY=NONE # Options: DEBUG, INFO, ERROR, NONE - to remove errors
|
||||
|
||||
# Tools parameters
|
||||
TAVILY_API_KEY= # Search engine
|
||||
|
||||
@@ -6,6 +6,6 @@ Step 1 : follow the LangGraph tutorial "Build a custom workflow"
|
||||
Step 2 : use the base to create a different custom workflow
|
||||
|
||||
To use, install all dependancies : requirements.txt
|
||||
Add your Google Gemini API key in .env
|
||||
Fullfil the .env (all parameters are mandatory)
|
||||
Run : python -m main
|
||||
Exit with "quit", "exit", "q"
|
||||
38
main.py
38
main.py
@@ -1,19 +1,30 @@
|
||||
# from src.graph_builder import build_graph
|
||||
from src.graph_builder import graph_builder
|
||||
from src.nodes import chatbot
|
||||
# imports from langgraph
|
||||
from langgraph.graph import START, END
|
||||
|
||||
# imports from local files
|
||||
from src.graph_builder import graph_builder
|
||||
from src.nodes import chatbot, BasicToolNode
|
||||
from src.utils import route_tools
|
||||
from src.tools import tool
|
||||
|
||||
# The first argument is the unique node name
|
||||
# The second argument is the function or object that will be called whenever
|
||||
# the node is used.
|
||||
|
||||
if __name__ == "__main__":
|
||||
graph_builder.add_node("chatbot", chatbot)
|
||||
graph_builder.add_edge(START, "chatbot")
|
||||
graph_builder.add_edge("chatbot", END)
|
||||
def main():
|
||||
# The first argument is the unique node name
|
||||
# The second argument is the function or object that will be called whenever
|
||||
# the node is used.
|
||||
graph_builder.add_node("chatbot", chatbot) # Create the chatbot node
|
||||
graph_builder.add_edge("chatbot", END) # Add the edge from chatbot node to END
|
||||
tool_node = BasicToolNode(tools=[tool]) # Intermediate variable creation to create a tool node with the tools
|
||||
graph_builder.add_node("tools", tool_node) # Create a tool node with the tools
|
||||
graph_builder.add_conditional_edges( # Create the conditionnal edges from chatbot to tools
|
||||
"chatbot",
|
||||
route_tools,
|
||||
{"tools": "tools", END: END},
|
||||
)
|
||||
graph_builder.add_edge("tools", "chatbot") # Add the edge from tools node back to chatbot
|
||||
graph_builder.add_edge(START, "chatbot") # Add the edge from START to chatbot node
|
||||
graph = graph_builder.compile()
|
||||
|
||||
|
||||
def stream_graph_updates(user_input: str):
|
||||
for event in graph.stream({"messages": [{"role": "user", "content": user_input}]}):
|
||||
for value in event.values():
|
||||
@@ -32,4 +43,7 @@ if __name__ == "__main__":
|
||||
user_input = "What do you know about LangGraph?"
|
||||
print("User: " + user_input)
|
||||
stream_graph_updates(user_input)
|
||||
break
|
||||
break
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -17,16 +17,3 @@ class State(TypedDict):
|
||||
|
||||
graph_builder = StateGraph(State)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# from langgraph import Graph
|
||||
# from src.nodes import CustomNode
|
||||
|
||||
# def build_graph():
|
||||
# graph = Graph()
|
||||
# # Add nodes and edges here
|
||||
# node = CustomNode()
|
||||
# graph.add_node(node)
|
||||
# return graph
|
||||
|
||||
44
src/nodes.py
44
src/nodes.py
@@ -1,21 +1,47 @@
|
||||
import os
|
||||
# General imports
|
||||
import os, json
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# LangGraph imports
|
||||
from langchain.chat_models import init_chat_model
|
||||
from langgraph.graph import StateGraph
|
||||
from langchain_core.messages import ToolMessage
|
||||
|
||||
# Local imports
|
||||
from src.graph_builder import State
|
||||
from src.tools import tools
|
||||
|
||||
load_dotenv()
|
||||
|
||||
#os.environ["GOOGLE_API_KEY"] = "..."
|
||||
|
||||
llm = init_chat_model("google_genai:gemini-2.0-flash")
|
||||
|
||||
# Modification: tell the LLM which tools it can call
|
||||
llm_with_tools = llm.bind_tools(tools)
|
||||
|
||||
def chatbot(state: State):
|
||||
return {"messages": [llm.invoke(state["messages"])]}
|
||||
return {"messages": [llm_with_tools.invoke(state["messages"])]}
|
||||
|
||||
class BasicToolNode:
|
||||
"""A node that runs the tools requested in the last AIMessage."""
|
||||
|
||||
# class CustomNode:
|
||||
# def __init__(self):
|
||||
# pass
|
||||
# def run(self):
|
||||
# print("Custom node running")
|
||||
def __init__(self, tools: list) -> None:
|
||||
self.tools_by_name = {tool.name: tool for tool in tools}
|
||||
|
||||
def __call__(self, inputs: dict):
|
||||
if messages := inputs.get("messages", []):
|
||||
message = messages[-1]
|
||||
else:
|
||||
raise ValueError("No message found in input")
|
||||
outputs = []
|
||||
for tool_call in message.tool_calls:
|
||||
tool_result = self.tools_by_name[tool_call["name"]].invoke(
|
||||
tool_call["args"]
|
||||
)
|
||||
outputs.append(
|
||||
ToolMessage(
|
||||
content=json.dumps(tool_result),
|
||||
name=tool_call["name"],
|
||||
tool_call_id=tool_call["id"],
|
||||
)
|
||||
)
|
||||
return {"messages": outputs}
|
||||
12
src/tools.py
Normal file
12
src/tools.py
Normal file
@@ -0,0 +1,12 @@
|
||||
# General imports
|
||||
import os, json
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# LangGraph imports
|
||||
from langchain_tavily import TavilySearch
|
||||
|
||||
load_dotenv()
|
||||
|
||||
tool = TavilySearch(max_results=10)
|
||||
tools = [tool]
|
||||
|
||||
19
src/utils.py
19
src/utils.py
@@ -1 +1,20 @@
|
||||
from src.graph_builder import State, START, END
|
||||
|
||||
# Utility functions for LangGraph project
|
||||
|
||||
def route_tools( # function that checks for tool_calls in the chatbot's output
|
||||
state: State,
|
||||
):
|
||||
"""
|
||||
Use in the conditional_edge to route to the ToolNode if the last message
|
||||
has tool calls. Otherwise, route to the end.
|
||||
"""
|
||||
if isinstance(state, list):
|
||||
ai_message = state[-1]
|
||||
elif messages := state.get("messages", []):
|
||||
ai_message = messages[-1]
|
||||
else:
|
||||
raise ValueError(f"No messages found in input state to tool_edge: {state}")
|
||||
if hasattr(ai_message, "tool_calls") and len(ai_message.tool_calls) > 0:
|
||||
return "tools"
|
||||
return END
|
||||
|
||||
Reference in New Issue
Block a user