diff --git a/.env.template b/.env.template
index 39bb49f..93a730f 100644
--- a/.env.template
+++ b/.env.template
@@ -1 +1,6 @@
+# LLM Parameters
 GOOGLE_API_KEY=your_google_api_key_here
+GRPC_VERBOSITY=NONE # Options: DEBUG, INFO, ERROR, NONE - to remove errors
+
+# Tools parameters
+TAVILY_API_KEY= # Search engine
diff --git a/README.md b/README.md
index 4e0e495..6a765bb 100644
--- a/README.md
+++ b/README.md
@@ -6,6 +6,6 @@ Step 1 : follow the LangGraph tutorial "Build a custom workflow"
 Step 2 : use the base to create a different custom workflow
 
 To use, install all dependancies : requirements.txt
-Add your Google Gemini API key in .env
+Fulfill the .env (all parameters are mandatory)
 Run : python -m main
 Exit with "quit", "exit", "q"
\ No newline at end of file
diff --git a/main.py b/main.py
index d3db1bf..1a9a72c 100644
--- a/main.py
+++ b/main.py
@@ -1,19 +1,30 @@
-# from src.graph_builder import build_graph
-from src.graph_builder import graph_builder
-from src.nodes import chatbot
+# imports from langgraph
 from langgraph.graph import START, END
+# imports from local files
+from src.graph_builder import graph_builder
+from src.nodes import chatbot, BasicToolNode
+from src.utils import route_tools
+from src.tools import tool
 
 
-# The first argument is the unique node name
-# The second argument is the function or object that will be called whenever
-# the node is used.
-if __name__ == "__main__":
-    graph_builder.add_node("chatbot", chatbot)
-    graph_builder.add_edge(START, "chatbot")
-    graph_builder.add_edge("chatbot", END)
+def main():
+    # The first argument is the unique node name
+    # The second argument is the function or object that will be called whenever
+    # the node is used.
+    graph_builder.add_node("chatbot", chatbot) # Create the chatbot node
+    graph_builder.add_edge("chatbot", END) # Add the edge from chatbot node to END
+    tool_node = BasicToolNode(tools=[tool]) # Intermediate variable creation to create a tool node with the tools
+    graph_builder.add_node("tools", tool_node) # Create a tool node with the tools
+    graph_builder.add_conditional_edges( # Create the conditional edges from chatbot to tools
+        "chatbot",
+        route_tools,
+        {"tools": "tools", END: END},
+    )
+    graph_builder.add_edge("tools", "chatbot") # Add the edge from tools node back to chatbot
+    graph_builder.add_edge(START, "chatbot") # Add the edge from START to chatbot node
     graph = graph_builder.compile()
-
+    
 
     def stream_graph_updates(user_input: str):
         for event in graph.stream({"messages": [{"role": "user", "content": user_input}]}):
             for value in event.values():
@@ -32,4 +43,7 @@ if __name__ == "__main__":
             user_input = "What do you know about LangGraph?"
             print("User: " + user_input)
             stream_graph_updates(user_input)
-            break
\ No newline at end of file
+            break
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/src/graph_builder.py b/src/graph_builder.py
index b1ea8c5..29cd68e 100644
--- a/src/graph_builder.py
+++ b/src/graph_builder.py
@@ -17,16 +17,3 @@ class State(TypedDict):
 
 
 graph_builder = StateGraph(State)
-
-
-
-
-# from langgraph import Graph
-# from src.nodes import CustomNode
-
-# def build_graph():
-#     graph = Graph()
-#     # Add nodes and edges here
-#     node = CustomNode()
-#     graph.add_node(node)
-#     return graph
diff --git a/src/nodes.py b/src/nodes.py
index fafbf09..0684294 100644
--- a/src/nodes.py
+++ b/src/nodes.py
@@ -1,21 +1,47 @@
-import os
+# General imports
+import os, json
 from dotenv import load_dotenv
+
+# LangGraph imports
 from langchain.chat_models import init_chat_model
 from langgraph.graph import StateGraph
+from langchain_core.messages import ToolMessage
+
+# Local imports
 from src.graph_builder import State
+from src.tools import tools
 
 load_dotenv()
-#os.environ["GOOGLE_API_KEY"] = "..."
-
 llm = init_chat_model("google_genai:gemini-2.0-flash")
+# Modification: tell the LLM which tools it can call
+llm_with_tools = llm.bind_tools(tools)
+
 def chatbot(state: State):
-    return {"messages": [llm.invoke(state["messages"])]}
+    return {"messages": [llm_with_tools.invoke(state["messages"])]}
+
+class BasicToolNode:
+    """A node that runs the tools requested in the last AIMessage."""
 
 
-# class CustomNode:
-#     def __init__(self):
-#         pass
-#     def run(self):
-#         print("Custom node running")
+    def __init__(self, tools: list) -> None:
+        self.tools_by_name = {tool.name: tool for tool in tools}
+
+    def __call__(self, inputs: dict):
+        if messages := inputs.get("messages", []):
+            message = messages[-1]
+        else:
+            raise ValueError("No message found in input")
+        outputs = []
+        for tool_call in message.tool_calls:
+            tool_result = self.tools_by_name[tool_call["name"]].invoke(
+                tool_call["args"]
+            )
+            outputs.append(
+                ToolMessage(
+                    content=json.dumps(tool_result),
+                    name=tool_call["name"],
+                    tool_call_id=tool_call["id"],
+                )
+            )
+        return {"messages": outputs}
\ No newline at end of file
diff --git a/src/tools.py b/src/tools.py
new file mode 100644
index 0000000..4a27a97
--- /dev/null
+++ b/src/tools.py
@@ -0,0 +1,12 @@
+# General imports
+import os, json
+from dotenv import load_dotenv
+
+# LangGraph imports
+from langchain_tavily import TavilySearch
+
+load_dotenv()
+
+tool = TavilySearch(max_results=10)
+tools = [tool]
+
diff --git a/src/utils.py b/src/utils.py
index 776bc84..5cb1444 100644
--- a/src/utils.py
+++ b/src/utils.py
@@ -1 +1,20 @@
+from src.graph_builder import State, START, END
+
 # Utility functions for LangGraph project
+
+def route_tools( # function that checks for tool_calls in the chatbot's output
+    state: State,
+):
+    """
+    Use in the conditional_edge to route to the ToolNode if the last message
+    has tool calls. Otherwise, route to the end.
+    """
+    if isinstance(state, list):
+        ai_message = state[-1]
+    elif messages := state.get("messages", []):
+        ai_message = messages[-1]
+    else:
+        raise ValueError(f"No messages found in input state to tool_edge: {state}")
+    if hasattr(ai_message, "tool_calls") and len(ai_message.tool_calls) > 0:
+        return "tools"
+    return END