(Day 257) LangChain's intro to LangGraph part 2

Ivan Ivanov · September 14, 2024

Hello :) Today is Day 257!

A quick summary of today:

Module 2: State Schema

Schema

When we define a LangGraph StateGraph, we use a state schema.

The state schema represents the structure and types of data that our graph will use.

All nodes are expected to communicate with that schema.

LangGraph offers flexibility in how you define your state schema, accommodating various Python types and validation approaches!

TypedDict

As mentioned in Module 1, we can use the TypedDict class from Python’s typing module.

It allows you to specify keys and their corresponding value types.

But, note that these are type hints.

They can be used by static type checkers (like mypy) or IDEs to catch potential type-related errors before the code is run.

But they are not enforced at runtime!
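To see this concretely, here is a minimal plain-Python check (no LangGraph involved; the class name TypedDictState is just for illustration). Calling a TypedDict class simply builds an ordinary dict, so an out-of-range value goes through silently:

```python
from typing import Literal, TypedDict


class TypedDictState(TypedDict):
    name: str
    mood: Literal["happy", "sad"]


# "mad" violates the Literal hint, but nothing checks it at runtime:
# calling a TypedDict class just builds an ordinary dict.
state = TypedDictState(name="Lance", mood="mad")

print(type(state).__name__, state["mood"])  # dict mad
```

A static checker such as mypy would flag the mood="mad" line, but Python itself runs it without complaint.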

Dataclass

Python’s dataclasses provide another way to define structured data.

Dataclasses offer a concise syntax for creating classes that are primarily used to store data.

To access the keys of a dataclass, we just need to modify the subscripting used in node_1:

  • We use state.name for the dataclass state rather than state["name"] for the TypedDict above

You’ll notice something a bit odd: in each node, we still return a dictionary to perform the state updates.

This is possible because LangGraph stores each key of your state object separately.

The object returned by the node only needs to have keys (attributes) that match those in the state!

In this case, the dataclass has a name field, so we can update it by returning a dict from our node, just as we did when the state was a TypedDict.
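A small sketch of the dataclass version, assuming a node_1 that appends some text to name (the node body here is illustrative). dataclasses.replace is only a stand-in for applying the returned dict to the state; LangGraph's real update logic works per key internally and is not shown:

```python
from dataclasses import dataclass, replace
from typing import Literal


@dataclass
class DataclassState:
    name: str
    mood: Literal["happy", "sad"]


def node_1(state: DataclassState) -> dict:
    print("---Node 1---")
    # Attribute access (state.name) instead of subscripting (state["name"])
    return {"name": state.name + " is ... "}


state = DataclassState(name="Lance", mood="sad")
update = node_1(state)  # a plain dict whose keys match dataclass fields

# Stand-in for the state update: apply the dict's keys as field values
new_state = replace(state, **update)
print(new_state)  # DataclassState(name='Lance is ... ', mood='sad')

# Like TypedDict, the dataclass does not enforce the Literal hint at runtime
unchecked = DataclassState(name="Lance", mood="mad")
print(unchecked.mood)  # mad
```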

A better option, if you want the types checked at runtime, is:

Pydantic

As mentioned, TypedDict and dataclasses provide type hints but they don’t enforce types at runtime.

This means you could potentially assign invalid values without raising an error!

For example, we can set mood to "mad" even though our type hint specifies mood: Literal["happy", "sad"].

Pydantic is a data validation and settings management library using Python type annotations.

It’s particularly well-suited for defining state schemas in LangGraph due to its validation capabilities.

Pydantic can perform validation to check whether data conforms to the specified types and constraints at runtime.

from typing import Literal
from pydantic import BaseModel, field_validator, ValidationError

class PydanticState(BaseModel):
    name: str
    mood: Literal["happy", "sad"]

    @field_validator('mood')
    @classmethod
    def validate_mood(cls, value):
        # Runs after Pydantic's own type check, so the Literal annotation already
        # rejects "mad"; this shows where a custom check would go
        if value not in ["happy", "sad"]:
            raise ValueError("Each mood must be either 'happy' or 'sad'")
        return value

try:
    state = PydanticState(name="John Doe", mood="mad")
except ValidationError as e:
    print("Validation Error:", e)

This returns:

Validation Error: 1 validation error for PydanticState
mood
  Input should be 'happy' or 'sad' [type=literal_error, input_value='mad', input_type=str]
    For further information visit https://errors.pydantic.dev/2.8/v/literal_error

We can use PydanticState in our graph seamlessly.

from IPython.display import Image, display
from langgraph.graph import StateGraph, START, END

# node_1, node_2, node_3 and decide_mood are assumed to be defined as in the earlier examples

# Build graph
builder = StateGraph(PydanticState)
builder.add_node("node_1", node_1)
builder.add_node("node_2", node_2)
builder.add_node("node_3", node_3)

# Logic
builder.add_edge(START, "node_1")
builder.add_conditional_edges("node_1", decide_mood)
builder.add_edge("node_2", END)
builder.add_edge("node_3", END)

# Compile
graph = builder.compile()

# View
display(Image(graph.get_graph().draw_mermaid_png()))

Module 2: State Reducers

We might have a branching case where two nodes execute in the same step and both try to update the same state key. What happens by default? An error.

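from typing_extensions import TypedDict
from IPython.display import Image, display
from langgraph.graph import StateGraph, START, END

class State(TypedDict):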
    foo: int

def node_1(state):
    print("---Node 1---")
    return {"foo": state['foo'] + 1}

def node_2(state):
    print("---Node 2---")
    return {"foo": state['foo'] + 1}

def node_3(state):
    print("---Node 3---")
    return {"foo": state['foo'] + 1}

# Build graph
builder = StateGraph(State)
builder.add_node("node_1", node_1)
builder.add_node("node_2", node_2)
builder.add_node("node_3", node_3)

# Logic
builder.add_edge(START, "node_1")
builder.add_edge("node_1", "node_2")
builder.add_edge("node_1", "node_3")
builder.add_edge("node_2", END)
builder.add_edge("node_3", END)

# Compile
graph = builder.compile()

# View
display(Image(graph.get_graph().draw_mermaid_png()))

# Run: node_2 and node_3 both write to foo in the same step
from langgraph.errors import InvalidUpdateError

try:
    graph.invoke({"foo" : 1})
except InvalidUpdateError as e:
    print(f"InvalidUpdateError occurred: {e}")

Output:

---Node 1---
---Node 2---
---Node 3---
InvalidUpdateError occurred: At key 'foo': Can receive only one value per step. Use an Annotated key to handle multiple values.

We see a problem!

Node 1 branches to nodes 2 and 3.

Nodes 2 and 3 run in parallel, which means they run in the same step of the graph.

They both attempt to overwrite the state within the same step.

This is ambiguous for the graph! Which state should it keep?

This is where reducers come in.

Reducers

Reducers give us a general way to address this problem.

They specify how to perform updates.

We can use the Annotated type to specify a reducer function.

For example, in this case, let’s append the value returned from each node rather than overwriting the existing value.

We just need a reducer that can perform this: operator.add is a function from Python’s built-in operator module.

When operator.add is applied to lists, it performs list concatenation.

from operator import add
from typing import Annotated
from typing_extensions import TypedDict
from langgraph.graph import StateGraph, START, END

class State(TypedDict):
    foo: Annotated[list[int], add]

def node_1(state):
    print("---Node 1---")
    return {"foo": [state['foo'][0] + 1]}

# Build graph
builder = StateGraph(State)
builder.add_node("node_1", node_1)

# Logic
builder.add_edge(START, "node_1")
builder.add_edge("node_1", END)

# Compile
graph = builder.compile()

Now, if we invoke the graph:

graph.invoke({"foo" : [1]})
---Node 1---
{'foo': [1, 2]}

Now, our state key foo is a list.

This operator.add reducer function will append updates from each node to this list.
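To make the reducer idea concrete without LangGraph, here is a toy model of a single graph step. apply_step is a hypothetical helper, not LangGraph's implementation (which works with channels): it merges the updates from nodes that ran in parallel, raising a ValueError (standing in for InvalidUpdateError) when a key without a reducer is written twice, and folding the values together with the reducer otherwise:

```python
from operator import add
from typing import Any, Callable


def apply_step(
    state: dict,
    updates: list[dict],
    reducers: dict[str, Callable[[Any, Any], Any]],
) -> dict:
    """Toy model of one graph step: merge updates from nodes that ran in parallel."""
    new_state = dict(state)
    writes: dict[str, int] = {}
    for update in updates:
        for key, value in update.items():
            reducer = reducers.get(key)
            if reducer is None:
                # No reducer: only one write per key is allowed in a step
                writes[key] = writes.get(key, 0) + 1
                if writes[key] > 1:
                    raise ValueError(f"At key {key!r}: can receive only one value per step")
                new_state[key] = value
            else:
                # Reducer: fold each update into the current value
                new_state[key] = reducer(new_state[key], value) if key in new_state else value
    return new_state


# node_2 and node_3 each return one new item for foo in the same step
parallel_updates = [{"foo": [3]}, {"foo": [3]}]

print(apply_step({"foo": [1, 2]}, parallel_updates, reducers={"foo": add}))
# {'foo': [1, 2, 3, 3]}

try:
    apply_step({"foo": 2}, [{"foo": 3}, {"foo": 3}], reducers={})
except ValueError as e:
    print(e)  # At key 'foo': can receive only one value per step
```

With operator.add on foo, both parallel writes survive instead of conflicting, which is the fix for the branching graph above.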

Messages

In Module 1, the course showed how to use a built-in reducer, add_messages, to handle messages in state.

It also showed that MessagesState is a useful shortcut if you want to work with messages.

  • MessagesState has a built-in messages key
  • It also has a built-in add_messages reducer for this key

The two state definitions below are equivalent.

We’ll use the MessagesState class via from langgraph.graph import MessagesState for brevity.

from typing import Annotated
from typing_extensions import TypedDict
from langgraph.graph import MessagesState
from langchain_core.messages import AnyMessage
from langgraph.graph.message import add_messages

# Define a custom TypedDict that includes a list of messages with add_messages reducer
class CustomMessagesState(TypedDict):
    messages: Annotated[list[AnyMessage], add_messages]
    added_key_1: str
    added_key_2: str
    # etc

# Use MessagesState, which includes the messages key with add_messages reducer
class ExtendedMessagesState(MessagesState):
    # Add any keys needed beyond messages, which is pre-built 
    added_key_1: str
    added_key_2: str
    # etc

Add msgs:

from langgraph.graph.message import add_messages
from langchain_core.messages import AIMessage, HumanMessage

# Initial state
initial_messages = [AIMessage(content="Hello! How can I assist you?", name="Model"),
                    HumanMessage(content="I'm looking for information on marine biology.", name="Lance")
                   ]

# New message to add
new_message = AIMessage(content="Sure, I can help with that. What specifically are you interested in?", name="Model")

# Test
add_messages(initial_messages , new_message)
[AIMessage(content='Hello! How can I assist you?', name='Model', id='79df9553-6f7e-4f79-9ec6-949d6857d89f'),
 HumanMessage(content="I'm looking for information on marine biology.", name='Lance', id='66ca4a3a-e362-4d96-a720-e4b2d2c206a0'),
 AIMessage(content='Sure, I can help with that. What specifically are you interested in?', name='Model', id='aa7d50d1-c4ee-4888-887e-aee79a4dae7f')]

Re-writing msgs

If we pass a message with the same ID as an existing one in our messages list, it will get overwritten!

# Initial state
initial_messages = [AIMessage(content="Hello! How can I assist you?", name="Model", id="1"),
                    HumanMessage(content="I'm looking for information on marine biology.", name="Lance", id="2")
                   ]

# New message to add
new_message = HumanMessage(content="I'm looking for information on whales, specifically", name="Lance", id="2")

# Test
add_messages(initial_messages , new_message)
[AIMessage(content='Hello! How can I assist you?', name='Model', id='1'),
 HumanMessage(content="I'm looking for information on whales, specifically", name='Lance', id='2')]

Removing msgs

from langchain_core.messages import RemoveMessage

# Message list
messages = [AIMessage("Hi.", name="Bot", id="1")]
messages.append(HumanMessage("Hi.", name="Lance", id="2"))
messages.append(AIMessage("So you said you were researching ocean mammals?", name="Bot", id="3"))
messages.append(HumanMessage("Yes, I know about whales. But what others should I learn about?", name="Lance", id="4"))

# Isolate messages to delete
delete_messages = [RemoveMessage(id=m.id) for m in messages[:-2]]
print(delete_messages)
[RemoveMessage(content='', id='1'), RemoveMessage(content='', id='2')]
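Note that the snippet above only builds the RemoveMessage objects; they take effect when passed to add_messages (for example by returning them from a node, as the filter_messages node does later). The ID-based merge rules can be modelled in plain Python. Msg and merge_by_id below are hypothetical stand-ins, not LangChain classes, and the real reducer handles more cases (for example, it assigns IDs to messages that lack one, as seen in the add_messages output earlier):

```python
from dataclasses import dataclass


@dataclass
class Msg:
    id: str
    content: str
    remove: bool = False  # hypothetical flag standing in for RemoveMessage


def merge_by_id(left: list[Msg], right: list[Msg]) -> list[Msg]:
    """Toy version of ID-based message merging: append, replace, or delete."""
    merged = list(left)
    index = {m.id: i for i, m in enumerate(merged)}
    to_remove = set()
    for msg in right:
        if msg.remove:
            to_remove.add(msg.id)  # deletion marker
        elif msg.id in index:
            merged[index[msg.id]] = msg  # same ID: overwrite in place
        else:
            index[msg.id] = len(merged)
            merged.append(msg)  # new ID: append
    return [m for m in merged if m.id not in to_remove]


history = [
    Msg("1", "Hi."),
    Msg("2", "Hi."),
    Msg("3", "So you said you were researching ocean mammals?"),
    Msg("4", "Yes, I know about whales. But what others should I learn about?"),
]

# Same ID as an existing message: overwrite it in place
edited = merge_by_id(history, [Msg("2", "Hi there.")])
print([m.content for m in edited[:2]])  # ['Hi.', 'Hi there.']

# Remove all but the 2 most recent messages
deletes = [Msg(m.id, "", remove=True) for m in history[:-2]]
print([m.id for m in merge_by_id(history, deletes)])  # ['3', '4']
```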

Module 2: Multiple Schemas

Typically, all graph nodes communicate with a single schema.

Also, this single schema contains the graph’s input and output keys / channels.

But, there are cases where we may want a bit more control over this:

  • Internal nodes may pass information that is not required in the graph’s input / output.

  • We may also want to use different input / output schemas for the graph. The output might, for example, only contain a single relevant output key.

Private State

First, let’s cover the case of passing private state between nodes.

This is useful for anything needed as part of the intermediate working logic of the graph, but not relevant for the overall graph input or output.

We’ll define an OverallState and a PrivateState.

node_2 uses PrivateState as input, but writes out to OverallState.

from typing_extensions import TypedDict
from IPython.display import Image, display
from langgraph.graph import StateGraph, START, END

class OverallState(TypedDict):
    foo: int

class PrivateState(TypedDict):
    baz: int

def node_1(state: OverallState) -> PrivateState:
    print("---Node 1---")
    return {"baz": state['foo'] + 1}

def node_2(state: PrivateState) -> OverallState:
    print("---Node 2---")
    return {"foo": state['baz'] + 1}

# Build graph
builder = StateGraph(OverallState)
builder.add_node("node_1", node_1)
builder.add_node("node_2", node_2)

# Logic
builder.add_edge(START, "node_1")
builder.add_edge("node_1", "node_2")
builder.add_edge("node_2", END)

# Compile
graph = builder.compile()

# View
display(Image(graph.get_graph().draw_mermaid_png()))

graph.invoke({"foo" : 1})

Output:

---Node 1---
---Node 2---
{'foo': 3}

baz is only included in PrivateState.

node_1 writes baz, and node_2 reads it via PrivateState, then writes foo back to OverallState.

So, baz is excluded from the graph output because it is not in OverallState.

Input / Output Schema

By default, StateGraph takes in a single schema and all nodes are expected to communicate with that schema.

However, it is also possible to define explicit input and output schemas for a graph.

Often, in these cases, we define an “internal” schema that contains all keys relevant to graph operations.

But, we use specific input and output schemas to constrain the input and output.

First, let’s just run the graph with a single schema.

class OverallState(TypedDict):
    question: str
    answer: str
    notes: str

def thinking_node(state: OverallState):
    return {"answer": "bye", "notes": "... his is name is Lance"}

def answer_node(state: OverallState):
    return {"answer": "bye Lance"}

graph = StateGraph(OverallState)
graph.add_node("answer_node", answer_node)
graph.add_node("thinking_node", thinking_node)
graph.add_edge(START, "thinking_node")
graph.add_edge("thinking_node", "answer_node")
graph.add_edge("answer_node", END)

graph = graph.compile()

# View
display(Image(graph.get_graph().draw_mermaid_png()))

graph.invoke({"question":"hi"})

Notice that the output of invoke contains all keys in OverallState:

{'question': 'hi', 'answer': 'bye Lance', 'notes': '... his is name is Lance'}

Now, let’s use a specific input and output schema.

class InputState(TypedDict):
    question: str

class OutputState(TypedDict):
    answer: str

class OverallState(TypedDict):
    question: str
    answer: str
    notes: str

def thinking_node(state: OverallState):
    return {"answer": "bye", "notes": "... his is name is Lance"}

def answer_node(state: OutputState):
    return {"answer": "bye Lance"}

graph = StateGraph(OverallState, input=InputState, output=OutputState)
graph.add_node("answer_node", answer_node)
graph.add_node("thinking_node", thinking_node)
graph.add_edge(START, "thinking_node")
graph.add_edge("thinking_node", "answer_node")
graph.add_edge("answer_node", END)

graph = graph.compile()

# View
display(Image(graph.get_graph().draw_mermaid_png()))

graph.invoke({"question":"hi"})
{'answer': 'bye Lance'}

We can see the output schema constrains the output to only the answer key.
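The effect of input and output schemas can be modelled as key filtering. run_pipeline below is a hypothetical, simplified runner (LangGraph compiles a real graph of channels instead): it keeps only input-schema keys on the way in and only output-schema keys on the way out, while internal keys such as notes live in the working state. The same output filtering is why baz was dropped in the private-state example:

```python
from typing import TypedDict


class InputState(TypedDict):
    question: str

class OutputState(TypedDict):
    answer: str

class OverallState(TypedDict):
    question: str
    answer: str
    notes: str


def thinking_node(state: dict) -> dict:
    return {"answer": "bye", "notes": "... his name is Lance"}

def answer_node(state: dict) -> dict:
    return {"answer": "bye Lance"}


def run_pipeline(nodes, inputs: dict, input_schema, output_schema) -> dict:
    """Toy linear runner: filter input, run nodes in order, filter output."""
    # Only keys declared in the input schema are accepted
    state = {k: v for k, v in inputs.items() if k in input_schema.__annotations__}
    for node in nodes:
        state.update(node(state))  # internal keys like "notes" live here
    # Only keys declared in the output schema are returned
    return {k: state[k] for k in output_schema.__annotations__ if k in state}


nodes = [thinking_node, answer_node]
print(run_pipeline(nodes, {"question": "hi"}, InputState, OverallState))
# {'question': 'hi', 'answer': 'bye Lance', 'notes': '... his name is Lance'}
print(run_pipeline(nodes, {"question": "hi"}, InputState, OutputState))
# {'answer': 'bye Lance'}
```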

Module 2: Filtering and trimming messages for a ChatBot

Reducer

A practical challenge when working with messages is managing long-running conversations.

Long-running conversations result in high token usage and latency if we are not careful, because we pass a growing list of messages to the model.

We have a few ways to address this.

First, recall the trick we saw using RemoveMessage and the add_messages reducer.

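from langchain_core.messages import RemoveMessage
from langchain_openai import ChatOpenAI
from langgraph.graph import MessagesState, StateGraph, START, END
from IPython.display import Image, display

llm = ChatOpenAI(model="gpt-4o")

# Nodes
def filter_messages(state: MessagesState):
    # Delete all but the 2 most recent messages
    delete_messages = [RemoveMessage(id=m.id) for m in state["messages"][:-2]]
    return {"messages": delete_messages}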

def chat_model_node(state: MessagesState):    
    return {"messages": [llm.invoke(state["messages"])]}

# Build graph
builder = StateGraph(MessagesState)
builder.add_node("filter", filter_messages)
builder.add_node("chat_model", chat_model_node)
builder.add_edge(START, "filter")
builder.add_edge("filter", "chat_model")
builder.add_edge("chat_model", END)
graph = builder.compile()

# View
display(Image(graph.get_graph().draw_mermaid_png()))

Filtering messages

If you don’t need or want to modify the graph state, you can just filter the messages you pass to the chat model.

For example, pass only the most recent message to the model: llm.invoke(messages[-1:]).

# Node
def chat_model_node(state: MessagesState):
    return {"messages": [llm.invoke(state["messages"][-1:])]}

# Build graph
builder = StateGraph(MessagesState)
builder.add_node("chat_model", chat_model_node)
builder.add_edge(START, "chat_model")
builder.add_edge("chat_model", END)
graph = builder.compile()

# View
display(Image(graph.get_graph().draw_mermaid_png()))

Trim messages

Another approach is to trim messages, based upon a set number of tokens.

This restricts the message history to a specified number of tokens.

Whereas filtering selects a fixed subset of messages (for example, only the last one), trimming selects messages based on their token count, which restricts how many tokens of history the chat model receives.

See trim_messages below.

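from langchain_core.messages import trim_messages
from langchain_openai import ChatOpenAI

# Node (llm is the chat model defined earlier)
def chat_model_node(state: MessagesState):
    # Keep the most recent whole messages that fit in 100 tokens,
    # using the chat model itself to count tokens
    messages = trim_messages(
        state["messages"],
        max_tokens=100,
        strategy="last",
        token_counter=ChatOpenAI(model="gpt-4o"),
        allow_partial=False,
    )
    return {"messages": [llm.invoke(messages)]}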

# Build graph
builder = StateGraph(MessagesState)
builder.add_node("chat_model", chat_model_node)
builder.add_edge(START, "chat_model")
builder.add_edge("chat_model", END)
graph = builder.compile()
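A rough plain-Python analogue of strategy="last" with allow_partial=False: walk backwards from the newest message and keep whole messages until the token budget would be exceeded. trim_last and word_count are hypothetical helpers, and counting words is only a crude stand-in for a real tokenizer such as the model-based counter used above:

```python
from typing import Callable


def trim_last(
    messages: list[str],
    max_tokens: int,
    count_tokens: Callable[[str], int],
) -> list[str]:
    """Keep the most recent whole messages whose total fits in max_tokens."""
    kept: list[str] = []
    total = 0
    for msg in reversed(messages):
        cost = count_tokens(msg)
        if total + cost > max_tokens:
            break  # no partial messages: stop at the first one that does not fit
        kept.append(msg)
        total += cost
    return list(reversed(kept))  # restore chronological order


def word_count(text: str) -> int:
    # Hypothetical stand-in for a real tokenizer
    return len(text.split())


history = [
    "Hi.",
    "So you said you were researching ocean mammals?",
    "Yes, I know about whales.",
    "What others should I learn about?",
]
print(trim_last(history, max_tokens=12, count_tokens=word_count))
# ['Yes, I know about whales.', 'What others should I learn about?']
```

The last two messages cost 6 + 5 = 11 "tokens"; adding the 8-word message before them would exceed 12, so trimming stops there.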

Module 2: ChatBot with msg summarisation and memory

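With trimming we can lose important information from the conversation, so it's not ideal. Summarisation reduces this risk: older messages are condensed into a running summary, which keeps the key context while still lowering the token cost.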

Init the model

from langchain_openai import ChatOpenAI
model = ChatOpenAI(model="gpt-4o",temperature=0)

We’ll use MessagesState, as before.

In addition to the built-in messages key, we’ll now include a custom key (summary).

from langgraph.graph import MessagesState
class State(MessagesState):
    summary: str

We’ll define a node to call our LLM that incorporates a summary, if it exists, into the prompt.

from langchain_core.messages import SystemMessage, HumanMessage, RemoveMessage

# Define the logic to call the model
def call_model(state: State):
    
    # Get summary if it exists
    summary = state.get("summary", "")

    # If there is summary, then we add it
    if summary:
        
        # Add summary to system message
        system_message = f"Summary of conversation earlier: {summary}"

        # Prepend the summary (as a system message) to the newer messages
        messages = [SystemMessage(content=system_message)] + state["messages"]
    
    else:
        messages = state["messages"]
    
    response = model.invoke(messages)
    return {"messages": response}

We’ll define a node to produce a summary.

Note, here we’ll use RemoveMessage to filter our state after we’ve produced the summary.

def summarize_conversation(state: State):
    
    # First, we get any existing summary
    summary = state.get("summary", "")

    # Create our summarization prompt 
    if summary:
        
        # A summary already exists
        summary_message = (
            f"This is summary of the conversation to date: {summary}\n\n"
            "Extend the summary by taking into account the new messages above:"
        )
        
    else:
        summary_message = "Create a summary of the conversation above:"

    # Add prompt to our history
    messages = state["messages"] + [HumanMessage(content=summary_message)]
    response = model.invoke(messages)
    
    # Delete all but the 2 most recent messages
    delete_messages = [RemoveMessage(id=m.id) for m in state["messages"][:-2]]
    return {"summary": response.content, "messages": delete_messages}

We’ll add a conditional edge to determine whether to produce a summary based on the conversation length.

from langgraph.graph import END
# Determine whether to end or summarize the conversation
def should_continue(state: State):
    
    """Return the next node to execute."""
    
    messages = state["messages"]
    
    # If there are more than six messages, then we summarize the conversation
    if len(messages) > 6:
        return "summarize_conversation"
    
    # Otherwise we can just end
    return END
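The summarise-when-long loop can be simulated without a model. run_turn and fake_summarize below are hypothetical helpers (the fake summariser only records how many messages it saw); they reuse the same threshold as should_continue (more than 6 messages) and keep the 2 most recent messages, like summarize_conversation:

```python
def fake_summarize(summary: str, messages: list[str]) -> str:
    # Hypothetical summariser: a real one would call the chat model
    return (summary + " | " if summary else "") + f"{len(messages)} msgs summarised"


def run_turn(state: dict, user_msg: str, threshold: int = 6, keep_last: int = 2) -> dict:
    """One toy turn: add a human/AI exchange, then summarise if too long."""
    messages = state["messages"] + [f"human: {user_msg}", f"ai: reply to {user_msg}"]
    summary = state["summary"]
    if len(messages) > threshold:  # same condition as should_continue
        summary = fake_summarize(summary, messages)
        messages = messages[-keep_last:]  # mirrors deleting all but the 2 most recent
    return {"messages": messages, "summary": summary}


state = {"messages": [], "summary": ""}
for text in ["hi! I'm Lance", "what's my name?", "i like the 49ers!", "who is Nick Bosa?"]:
    state = run_turn(state, text)
    print(len(state["messages"]), repr(state["summary"]))
```

The first three turns leave 2, 4 and 6 messages and no summary; the fourth pushes the count to 8, which triggers a summary and a trim back to 2 messages.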

Adding memory

Recall that state is transient to a single graph execution.

This limits our ability to have multi-turn conversations with interruptions.

As introduced at the end of Module 1, we can use persistence to address this!

LangGraph can use a checkpointer to automatically save the graph state after each step.

This built-in persistence layer gives us memory, allowing LangGraph to pick up from the last state update.

As we previously showed, one of the easiest checkpointers to work with is MemorySaver, an in-memory key-value store for graph state.

All we need to do is compile the graph with a checkpointer, and our graph has memory!

from IPython.display import Image, display
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import StateGraph, START

# Define a new graph
workflow = StateGraph(State)
workflow.add_node("conversation", call_model)
workflow.add_node(summarize_conversation)

# Set the entrypoint as conversation
workflow.add_edge(START, "conversation")
workflow.add_conditional_edges("conversation", should_continue)
workflow.add_edge("summarize_conversation", END)

# Compile
memory = MemorySaver()
graph = workflow.compile(checkpointer=memory)
display(Image(graph.get_graph().draw_mermaid_png()))

Threads

The checkpointer saves the state at each step as a checkpoint.

These saved checkpoints can be grouped into a thread of conversation.

Think about Slack as an analog: different channels carry different conversations.

Threads are like Slack channels, capturing grouped collections of state (e.g., conversation).

Below, we use configurable to set a thread ID.

# Create a thread
config = {"configurable": {"thread_id": "1"}}

# Start conversation
input_message = HumanMessage(content="hi! I'm Lance")
output = graph.invoke({"messages": [input_message]}, config) 
for m in output['messages'][-1:]:
    m.pretty_print()

input_message = HumanMessage(content="what's my name?")
output = graph.invoke({"messages": [input_message]}, config) 
for m in output['messages'][-1:]:
    m.pretty_print()

input_message = HumanMessage(content="i like the 49ers!")
output = graph.invoke({"messages": [input_message]}, config) 
for m in output['messages'][-1:]:
    m.pretty_print()
================================== Ai Message ==================================

Hi Lance! How can I assist you today?
================================== Ai Message ==================================

You mentioned that your name is Lance. How can I help you today, Lance?
================================== Ai Message ==================================

That's awesome, Lance! The San Francisco 49ers have a rich history and a passionate fan base. Do you have a favorite player or a memorable game that stands out to you?

Now, we don’t yet have a summary of the state because we still have 6 or fewer messages (three turns, each adding a human and an AI message).

This was set in should_continue.

    # If there are more than six messages, then we summarize the conversation
    if len(messages) > 6:
        return "summarize_conversation"

We can pick up the conversation because we have the thread.
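The thread idea can be sketched with a dictionary keyed by thread_id. ToyCheckpointer and invoke below are hypothetical and much simpler than MemorySaver (which checkpoints after every step and stores extra metadata); here a snapshot is saved once per call:

```python
import copy
from collections import defaultdict


class ToyCheckpointer:
    """Hypothetical in-memory store: a list of state snapshots per thread_id."""

    def __init__(self) -> None:
        self._threads: dict[str, list[dict]] = defaultdict(list)

    def latest(self, thread_id: str) -> dict:
        history = self._threads[thread_id]
        return copy.deepcopy(history[-1]) if history else {"messages": []}

    def save(self, thread_id: str, state: dict) -> None:
        self._threads[thread_id].append(copy.deepcopy(state))


def invoke(checkpointer: ToyCheckpointer, user_msg: str, config: dict) -> dict:
    thread_id = config["configurable"]["thread_id"]
    state = checkpointer.latest(thread_id)  # resume from the last checkpoint
    state["messages"].append(user_msg)
    checkpointer.save(thread_id, state)
    return state


memory = ToyCheckpointer()
invoke(memory, "hi! I'm Lance", {"configurable": {"thread_id": "1"}})
print(invoke(memory, "what's my name?", {"configurable": {"thread_id": "1"}})["messages"])
# ["hi! I'm Lance", "what's my name?"]
print(invoke(memory, "what's my name?", {"configurable": {"thread_id": "2"}})["messages"])
# ["what's my name?"]
```

Thread "1" remembers the first message, while thread "2" starts from an empty state, like opening a new Slack channel.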

Once we add a few more messages (to exceed the threshold we set), the conversation is summarised, and the summary is passed to the model on later turns. The summary stored in the state looked like this:

"Lance introduced himself and mentioned that he is a fan of the San Francisco 49ers. He specifically likes Nick Bosa and inquired if Bosa is the highest-paid defensive player. I confirmed that Nick Bosa signed a record-breaking contract extension in September 2023, making him the highest-paid defensive player at that time, and acknowledged Bosa's talent and Lance's enthusiasm for the player."

Module 2: ChatBot with summarisation and external memory

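It is the same chatbot as above, but for memory we use an SQLite database (via SqliteSaver) instead of MemorySaver, so the state can be saved to a file.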

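import os
import sqlite3

# Option 1: in-memory database (lost when the process exits)
conn = sqlite3.connect(":memory:", check_same_thread=False)

# Option 2: saved to a local db file
os.makedirs("state_db", exist_ok=True)
db_path = "state_db/example.db"
conn = sqlite3.connect(db_path, check_same_thread=False)

# Here is our checkpointer (from the langgraph-checkpoint-sqlite package)
from langgraph.checkpoint.sqlite import SqliteSaver
memory = SqliteSaver(conn)

# Compile the same workflow as above with the SQLite checkpointer
graph = workflow.compile(checkpointer=memory)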
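Why a file-backed database gives durable memory can be shown with the standard sqlite3 and json modules alone. The toy_checkpoints table and the save_state / load_latest helpers are hypothetical (SqliteSaver manages its own schema); the point is that a new connection to the same file still sees the saved state, whereas a ":memory:" database is gone once its connection closes:

```python
import json
import os
import sqlite3
import tempfile
from typing import Optional


def save_state(conn: sqlite3.Connection, thread_id: str, state: dict) -> None:
    # Hypothetical table layout: one JSON snapshot per row
    conn.execute(
        "CREATE TABLE IF NOT EXISTS toy_checkpoints "
        "(id INTEGER PRIMARY KEY AUTOINCREMENT, thread_id TEXT, state TEXT)"
    )
    conn.execute(
        "INSERT INTO toy_checkpoints (thread_id, state) VALUES (?, ?)",
        (thread_id, json.dumps(state)),
    )
    conn.commit()


def load_latest(conn: sqlite3.Connection, thread_id: str) -> Optional[dict]:
    # Most recent snapshot for this thread, or None if there is none
    row = conn.execute(
        "SELECT state FROM toy_checkpoints WHERE thread_id = ? ORDER BY id DESC LIMIT 1",
        (thread_id,),
    ).fetchone()
    return json.loads(row[0]) if row else None


toy_db_path = os.path.join(tempfile.mkdtemp(), "example.db")

# First "session": write a checkpoint, then close the connection
conn = sqlite3.connect(toy_db_path)
save_state(conn, "1", {"summary": "Lance likes the 49ers", "messages": []})
conn.close()

# Later "session": a fresh connection to the same file still sees the state
conn = sqlite3.connect(toy_db_path)
print(load_latest(conn, "1"))  # {'summary': 'Lance likes the 49ers', 'messages': []}
conn.close()
```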

Actually, after watching and taking notes for Module 2, I decided to just watch Modules 3 and 4, because I started to feel a bit lost in terms of practicality. I understand there are a lot of LangGraph basics to cover, but it felt a bit unclear how we will take this from a notebook into practice. I suspect that in the next few months, as these LLM app libraries are moving very fast, there will be a lot of abstractions coming out to make the whole thing clearer. Nevertheless, understanding these basics is important, because once abstractions come in, it can be harder to go back.

Exercises in ML

On another note, I have had this paper on my to-do list for a while. It contains only maths exercises related to ML, so today I finally printed it and made a small booklet.

I haven’t started it yet, but it is something for the next few days. I checked the exercises, and they are all maths: linear algebra, optimisation, directed/undirected graphical models, the expressive power of graphical models, factor graphs and message passing, inference for hidden Markov chains, model-based learning, sampling and Monte Carlo integration, and variational inference.

Some of these topics are new to me, but I am excited to at least have a look.


That is all for today!

See you tomorrow :)