"""
Message Service

Handles message operations backed by the LangGraph PostgresSaver checkpointer.
Manages conversation history retrieval and checkpoint state.

Usage:
    from apps.chatbot.services import MessageService

    # Get conversation history
    messages = MessageService.get_conversation_history(
        thread_id=session.id
    )

    # Build a user message object (it is not persisted by this call)
    MessageService.add_message(
        thread_id=session.id,
        content="Hello!",
        message_type="human"
    )
"""
|
|
|
|
import atexit
import logging
import threading
from contextlib import ExitStack
from typing import List, Dict, Any, Optional
from uuid import UUID

from django.conf import settings
from langchain_core.messages import (
    HumanMessage,
    AIMessage,
    SystemMessage,
    ToolMessage,
    BaseMessage,
)
from langgraph.checkpoint.postgres import PostgresSaver

from apps.chatbot.models import ChatSession
|
|
|
|
|
|
class MessageService:
    """
    Service for reading and managing chat messages stored by the LangGraph
    PostgresSaver checkpointer.

    All methods are static. A single PostgresSaver is opened lazily on first
    use and reused by every later call in the same process.
    """

    # Process-wide checkpointer plus the ExitStack that owns its database
    # connection; both are populated by _get_checkpointer().
    _checkpointer: Optional[PostgresSaver] = None
    _checkpointer_stack: Optional[ExitStack] = None
    # Guards first-time initialisation so concurrent requests do not each
    # open (and leak) a connection.
    _checkpointer_lock = threading.Lock()
    _logger = logging.getLogger(__name__)

    # Public ``message_type`` strings accepted by add_message().
    _MESSAGE_CLASSES = {
        'human': HumanMessage,
        'ai': AIMessage,
        'system': SystemMessage,
        'tool': ToolMessage,
    }

    @staticmethod
    def _get_checkpointer() -> PostgresSaver:
        """
        Return the shared PostgresSaver, creating it on first use.

        ``PostgresSaver.from_conn_string`` is a context manager that owns the
        database connection, so it must be entered rather than used as the
        saver itself. It is entered once, kept open, and closed at interpreter
        exit via ``atexit``. ``setup()`` (table creation/migration) therefore
        runs once per process instead of on every call.

        NOTE(review): the connection is long-lived and is not re-established
        if the database drops it; consider a connection pool if that matters.

        Returns:
            PostgresSaver instance connected to ``settings.PG_CHECKPOINT_URI``.
        """
        saver = MessageService._checkpointer
        if saver is not None:
            return saver

        with MessageService._checkpointer_lock:
            # Another thread may have finished initialising while we waited.
            if MessageService._checkpointer is None:
                stack = ExitStack()
                try:
                    saver = stack.enter_context(
                        PostgresSaver.from_conn_string(settings.PG_CHECKPOINT_URI)
                    )
                    # Ensure tables exist
                    saver.setup()
                except BaseException:
                    # Do not leak the half-opened connection on failure.
                    stack.close()
                    raise
                atexit.register(stack.close)
                MessageService._checkpointer_stack = stack
                MessageService._checkpointer = saver
            return MessageService._checkpointer

    @staticmethod
    def _thread_config(
        thread_id: UUID,
        checkpoint_id: Optional[str] = None
    ) -> Dict[str, Any]:
        """Build the LangGraph config addressing a thread and, optionally, one checkpoint."""
        configurable: Dict[str, Any] = {"thread_id": str(thread_id)}
        if checkpoint_id:
            configurable["checkpoint_id"] = checkpoint_id
        return {"configurable": configurable}

    @staticmethod
    def _content_to_text(content: Any) -> str:
        """
        Flatten LangChain message content to plain text.

        Message content may be a string or a list of content blocks. String
        items and the ``"text"`` field of dict blocks are concatenated; any
        other block (e.g. images) contributes nothing.
        """
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            parts = []
            for block in content:
                if isinstance(block, str):
                    parts.append(block)
                elif isinstance(block, dict) and isinstance(block.get("text"), str):
                    parts.append(block["text"])
            return "".join(parts)
        return "" if content is None else str(content)

    @staticmethod
    def get_conversation_history(
        thread_id: UUID,
        limit: Optional[int] = None,
        checkpoint_id: Optional[str] = None
    ) -> List[BaseMessage]:
        """
        Get conversation history from the LangGraph checkpointer.

        Args:
            thread_id: Chat session ID (also the LangGraph thread_id).
            limit: If given, return only the last ``limit`` messages;
                0 or a negative value returns an empty list.
            checkpoint_id: Specific checkpoint to read (defaults to the latest).

        Returns:
            List of LangChain message objects. It is empty when the thread (or
            checkpoint) does not exist, or when loading fails (the failure is
            logged).

        Example:
            messages = MessageService.get_conversation_history(
                thread_id=session.id,
                limit=50
            )
        """
        checkpointer = MessageService._get_checkpointer()
        config = MessageService._thread_config(thread_id, checkpoint_id)

        try:
            checkpoint_tuple = checkpointer.get_tuple(config)
        except Exception:
            # Best-effort read for callers, but never silently: log the cause.
            MessageService._logger.exception(
                "Failed to load checkpoint for thread %s", thread_id
            )
            return []

        if checkpoint_tuple is None:
            # Thread doesn't exist yet (or the checkpoint id is unknown).
            return []

        channel_values = checkpoint_tuple.checkpoint.get("channel_values") or {}
        messages = list(channel_values.get("messages") or [])

        if limit is not None:
            # messages[-0:] would be the whole list, so handle <= 0 explicitly.
            messages = messages[-limit:] if limit > 0 else []

        return messages

    @staticmethod
    def get_state_history(
        thread_id: UUID,
        limit: Optional[int] = 10
    ) -> List[Dict[str, Any]]:
        """
        Get checkpoint history for a thread.

        Args:
            thread_id: Chat session ID.
            limit: Maximum number of checkpoints to retrieve (None = all).

        Returns:
            One summary dict per checkpoint, in the order the checkpointer
            yields them. Each dict has the keys ``checkpoint_id``,
            ``timestamp`` (the checkpoint's ``ts``), ``message_count`` and
            ``metadata``.

        Example:
            history = MessageService.get_state_history(
                thread_id=session.id,
                limit=5
            )
        """
        checkpointer = MessageService._get_checkpointer()
        config = MessageService._thread_config(thread_id)

        history = []
        for checkpoint_tuple in checkpointer.list(config, limit=limit):
            checkpoint = checkpoint_tuple.checkpoint
            metadata = checkpoint_tuple.metadata or {}
            channel_values = checkpoint.get('channel_values') or {}
            history.append({
                'checkpoint_id': checkpoint_tuple.config['configurable'].get('checkpoint_id'),
                'timestamp': checkpoint.get('ts') or metadata.get('created_at'),
                'message_count': len(channel_values.get('messages') or []),
                'metadata': metadata,
            })

        return history

    @staticmethod
    def add_message(
        thread_id: UUID,
        content: str,
        message_type: str = "human",
        metadata: Optional[Dict[str, Any]] = None
    ) -> BaseMessage:
        """
        Build a LangChain message object for the conversation.

        Note: this only constructs the message; it does NOT write it to the
        checkpointer. Pass the returned message to the agent/graph to persist
        it. ``thread_id`` is accepted for API compatibility and is not used.

        Args:
            thread_id: Chat session ID (unused, see note).
            content: Message content.
            message_type: One of ``human``, ``ai``, ``system``, ``tool``.
                Unknown values fall back to ``human``.
            metadata: Extra data stored in the message's ``additional_kwargs``.
                For ``tool`` messages it must contain ``tool_call_id``.

        Returns:
            Created message object.

        Raises:
            ValueError: if ``message_type`` is ``tool`` and ``metadata`` has no
                ``tool_call_id``.

        Example:
            msg = MessageService.add_message(
                thread_id=session.id,
                content="Hello!",
                message_type="human"
            )
        """
        MessageClass = MessageService._MESSAGE_CLASSES.get(message_type, HumanMessage)
        extra = dict(metadata or {})

        if MessageClass is ToolMessage:
            # ToolMessage cannot be built without the id of the call it answers.
            tool_call_id = extra.get('tool_call_id')
            if tool_call_id is None:
                raise ValueError(
                    "message_type='tool' requires metadata['tool_call_id']"
                )
            return ToolMessage(
                content=content,
                tool_call_id=tool_call_id,
                additional_kwargs=extra
            )

        return MessageClass(content=content, additional_kwargs=extra)

    @staticmethod
    def get_message_at_checkpoint(
        thread_id: UUID,
        checkpoint_id: str
    ) -> List[BaseMessage]:
        """
        Get the messages stored at a specific checkpoint (time-travel).

        Args:
            thread_id: Chat session ID.
            checkpoint_id: Specific checkpoint ID.

        Returns:
            List of messages at that checkpoint (empty if it does not exist).

        Example:
            messages = MessageService.get_message_at_checkpoint(
                thread_id=session.id,
                checkpoint_id="1ef663ba-28fe-6528-8002-5a559208592c"
            )
        """
        return MessageService.get_conversation_history(
            thread_id=thread_id,
            checkpoint_id=checkpoint_id
        )

    @staticmethod
    def update_state(
        thread_id: UUID,
        values: Dict[str, Any],
        as_node: Optional[str] = None,
        graph: Optional[Any] = None
    ) -> None:
        """
        Update conversation state (advanced usage).

        State updates must go through the compiled LangGraph graph, which
        applies its channel reducers and writes a new checkpoint. A bare
        checkpointer has no ``update_state`` method.

        Args:
            thread_id: Chat session ID.
            values: State values to update.
            as_node: Apply the update as if it came from this node.
            graph: Compiled LangGraph graph (compiled with a checkpointer)
                whose ``update_state`` is used. Required.

        Raises:
            ValueError: if ``graph`` is not supplied.

        Example:
            MessageService.update_state(
                thread_id=session.id,
                values={"custom_key": "custom_value"},
                as_node="agent",
                graph=compiled_graph
            )
        """
        if graph is None:
            raise ValueError(
                "update_state requires graph=<compiled LangGraph graph>; "
                "the checkpointer cannot apply state updates by itself"
            )

        graph.update_state(
            MessageService._thread_config(thread_id),
            values,
            as_node=as_node
        )

    @staticmethod
    def format_messages_for_display(
        messages: List[BaseMessage]
    ) -> List[Dict[str, Any]]:
        """
        Format LangChain messages for an API response.

        Args:
            messages: List of LangChain message objects.

        Returns:
            One dict per message with ``role``, ``content``, ``id``,
            ``timestamp`` (from ``additional_kwargs``) and ``metadata``.
            ``tool_calls`` and ``tool_call_id`` are added when the message has
            those attributes.

        Example:
            messages = MessageService.get_conversation_history(thread_id)
            formatted = MessageService.format_messages_for_display(messages)
        """
        formatted = []

        for msg in messages:
            extra = msg.additional_kwargs
            formatted_msg = {
                'role': msg.type,
                'content': msg.content,
                'id': getattr(msg, 'id', None),
                'timestamp': extra.get('timestamp'),
                'metadata': extra
            }

            # AI messages carry tool_calls; tool messages carry tool_call_id.
            if hasattr(msg, 'tool_calls'):
                formatted_msg['tool_calls'] = msg.tool_calls
            if hasattr(msg, 'tool_call_id'):
                formatted_msg['tool_call_id'] = msg.tool_call_id

            formatted.append(formatted_msg)

        return formatted

    @staticmethod
    def delete_thread_history(thread_id: UUID) -> None:
        """
        Delete all checkpoints (and their stored writes) for a thread.

        WARNING: This permanently deletes conversation history!

        Args:
            thread_id: Chat session ID to delete.

        Example:
            MessageService.delete_thread_history(thread_id=session.id)
        """
        checkpointer = MessageService._get_checkpointer()
        checkpointer.delete_thread(str(thread_id))

    @staticmethod
    def get_latest_checkpoint_id(thread_id: UUID) -> Optional[str]:
        """
        Get the latest checkpoint ID for a thread.

        Args:
            thread_id: Chat session ID.

        Returns:
            Latest checkpoint ID. None is returned when the thread has no
            checkpoints or the lookup fails; a failure is logged.

        Example:
            checkpoint_id = MessageService.get_latest_checkpoint_id(
                thread_id=session.id
            )
        """
        checkpointer = MessageService._get_checkpointer()

        try:
            checkpoint_tuple = checkpointer.get_tuple(
                MessageService._thread_config(thread_id)
            )
        except Exception:
            MessageService._logger.exception(
                "Failed to load latest checkpoint for thread %s", thread_id
            )
            return None

        if checkpoint_tuple is None:
            return None
        return checkpoint_tuple.config['configurable'].get('checkpoint_id')

    @staticmethod
    def count_messages(thread_id: UUID) -> int:
        """
        Count messages in a thread's latest checkpoint.

        Args:
            thread_id: Chat session ID.

        Returns:
            Number of messages (0 if the thread does not exist).

        Example:
            count = MessageService.count_messages(thread_id=session.id)
        """
        return len(MessageService.get_conversation_history(thread_id))

    @staticmethod
    def get_last_n_messages(
        thread_id: UUID,
        n: int = 10
    ) -> List[BaseMessage]:
        """
        Get the last N messages from a conversation.

        Args:
            thread_id: Chat session ID.
            n: Number of recent messages (0 or less returns an empty list).

        Returns:
            List of at most N messages, oldest first.

        Example:
            recent_msgs = MessageService.get_last_n_messages(
                thread_id=session.id,
                n=5
            )
        """
        return MessageService.get_conversation_history(
            thread_id=thread_id,
            limit=n
        )

    @staticmethod
    def search_messages(
        thread_id: UUID,
        search_query: str
    ) -> List[Dict[str, Any]]:
        """
        Case-insensitively search the text of a conversation's messages.

        List-style (multimodal) content is searched via its text blocks only.

        Args:
            thread_id: Chat session ID.
            search_query: Text to search for.

        Returns:
            One dict per matching message, each with these keys:
            ``index`` (position in the history), ``role``, ``content`` (the
            original content) and ``match_preview`` (first 200 characters of
            the message text).

        Example:
            results = MessageService.search_messages(
                thread_id=session.id,
                search_query="python"
            )
        """
        needle = search_query.casefold()
        messages = MessageService.get_conversation_history(thread_id)

        results = []
        for index, msg in enumerate(messages):
            text = MessageService._content_to_text(msg.content)
            if needle not in text.casefold():
                continue
            results.append({
                'index': index,
                'role': msg.type,
                'content': msg.content,
                'match_preview': text[:200]
            })

        return results
|