# Source: uchill/chatnext/backend/apps/chatbot/services/message_service.py (423 lines, 12 KiB, Python)

"""
Message Service

Handles message operations on top of LangGraph's Postgres checkpointer
(``PostgresSaver``): conversation history, retrieval, and state management.

Usage:
    from apps.chatbot.services import MessageService

    # Get conversation history
    messages = MessageService.get_conversation_history(
        thread_id=session.id
    )

    # Build a user message (the call itself does not persist it)
    MessageService.add_message(
        thread_id=session.id,
        content="Hello!",
        message_type="human"
    )
"""
import logging
from contextlib import contextmanager
from typing import Any, Dict, Iterator, List, Optional
from uuid import UUID

from django.conf import settings
from langchain_core.messages import (
    AIMessage,
    BaseMessage,
    HumanMessage,
    SystemMessage,
    ToolMessage,
)
from langgraph.checkpoint.postgres import PostgresSaver

from apps.chatbot.models import ChatSession
class MessageService:
    """Service for reading and managing messages via the LangGraph checkpointer.

    Every method is static. ``thread_id`` is the chat session ID, which is
    also used as the LangGraph ``thread_id``.

    NOTE(review): ``PostgresSaver`` is a checkpoint store; it offers
    ``get_tuple``/``list``/``put``/... but not the graph-level
    ``get_state``/``update_state`` calls, which live on the compiled graph.
    The read paths below therefore go through ``get_tuple`` and ``list``.
    They assume the graph keeps its history under the ``"messages"`` state
    key — confirm against the agent's state schema.
    """

    _logger = logging.getLogger(__name__)

    # ``setup()`` only needs to succeed once per process; remember that it did
    # so later calls skip the extra round trips.
    _schema_ready = False

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    @contextmanager
    def _get_checkpointer() -> Iterator[PostgresSaver]:
        """
        Open a PostgresSaver connected to ``settings.PG_CHECKPOINT_URI``.

        ``PostgresSaver.from_conn_string`` is a context manager, so this
        helper is one too: the connection is closed when the ``with`` block
        exits.

        Yields:
            A ready-to-use PostgresSaver instance.

        Example:
            with MessageService._get_checkpointer() as checkpointer:
                checkpointer.get_tuple(config)
        """
        with PostgresSaver.from_conn_string(
            settings.PG_CHECKPOINT_URI
        ) as checkpointer:
            if not MessageService._schema_ready:
                # Ensure the checkpoint tables exist (idempotent).
                checkpointer.setup()
                MessageService._schema_ready = True
            yield checkpointer

    @staticmethod
    def _thread_config(
        thread_id: UUID,
        checkpoint_id: Optional[str] = None
    ) -> Dict[str, Any]:
        """Build the LangGraph config dict that addresses a thread/checkpoint."""
        configurable: Dict[str, Any] = {"thread_id": str(thread_id)}
        if checkpoint_id:
            configurable["checkpoint_id"] = checkpoint_id
        return {"configurable": configurable}

    @staticmethod
    def _checkpoint_messages(checkpoint_tuple: Any) -> List[BaseMessage]:
        """Extract the ``messages`` channel from a checkpoint tuple."""
        channel_values = checkpoint_tuple.checkpoint.get("channel_values") or {}
        return list(channel_values.get("messages") or [])

    @staticmethod
    def _content_as_text(content: Any) -> str:
        """
        Flatten message content to plain text.

        Message content may be a plain string or a list of content blocks
        (strings or dicts carrying a ``"text"`` entry); blocks without text
        are skipped.
        """
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            parts = []
            for block in content:
                if isinstance(block, str):
                    parts.append(block)
                elif isinstance(block, dict) and isinstance(block.get("text"), str):
                    parts.append(block["text"])
            return " ".join(parts)
        return str(content)

    # ------------------------------------------------------------------
    # Reading history
    # ------------------------------------------------------------------

    @staticmethod
    def get_conversation_history(
        thread_id: UUID,
        limit: Optional[int] = None,
        checkpoint_id: Optional[str] = None
    ) -> List[BaseMessage]:
        """
        Get conversation history from the LangGraph checkpointer.

        Args:
            thread_id: Chat session ID (also LangGraph thread_id)
            limit: Keep only the last ``limit`` messages; ``None``, ``0`` or
                a negative value means no limit
            checkpoint_id: Specific checkpoint to retrieve (optional)

        Returns:
            List of LangChain message objects. Empty when the thread has no
            checkpoint yet, or when the lookup fails (the failure is logged).

        Example:
            messages = MessageService.get_conversation_history(
                thread_id=session.id,
                limit=50
            )
        """
        config = MessageService._thread_config(thread_id, checkpoint_id)
        try:
            with MessageService._get_checkpointer() as checkpointer:
                checkpoint_tuple = checkpointer.get_tuple(config)
        except Exception:
            # Best-effort read: keep returning an empty history, but leave a
            # trace instead of swallowing the error silently.
            MessageService._logger.exception(
                "Failed to load conversation history for thread %s", thread_id
            )
            return []

        if checkpoint_tuple is None:
            # Thread (or checkpoint) does not exist yet.
            return []

        messages = MessageService._checkpoint_messages(checkpoint_tuple)
        if limit and limit > 0:
            messages = messages[-limit:]
        return messages

    @staticmethod
    def get_state_history(
        thread_id: UUID,
        limit: Optional[int] = 10
    ) -> List[Dict[str, Any]]:
        """
        Get checkpoint history for a thread.

        Args:
            thread_id: Chat session ID
            limit: Number of checkpoints to retrieve

        Returns:
            One dict per checkpoint with ``checkpoint_id``, ``timestamp``
            (the checkpoint's ``ts`` field), ``message_count`` and
            ``metadata``.

        Example:
            history = MessageService.get_state_history(
                thread_id=session.id,
                limit=5
            )
        """
        config = MessageService._thread_config(thread_id)
        with MessageService._get_checkpointer() as checkpointer:
            # ``list`` is lazy and reads through the open connection, so it
            # must be fully consumed before the ``with`` block exits.
            return [
                {
                    'checkpoint_id': item.config['configurable'].get('checkpoint_id'),
                    'timestamp': item.checkpoint.get('ts'),
                    'message_count': len(
                        MessageService._checkpoint_messages(item)
                    ),
                    'metadata': item.metadata,
                }
                for item in checkpointer.list(config, limit=limit)
            ]

    @staticmethod
    def get_message_at_checkpoint(
        thread_id: UUID,
        checkpoint_id: str
    ) -> List[BaseMessage]:
        """
        Get messages at a specific checkpoint (time-travel).

        Args:
            thread_id: Chat session ID
            checkpoint_id: Specific checkpoint ID

        Returns:
            List of messages at that checkpoint

        Example:
            messages = MessageService.get_message_at_checkpoint(
                thread_id=session.id,
                checkpoint_id="1ef663ba-28fe-6528-8002-5a559208592c"
            )
        """
        return MessageService.get_conversation_history(
            thread_id=thread_id,
            checkpoint_id=checkpoint_id
        )

    @staticmethod
    def get_latest_checkpoint_id(thread_id: UUID) -> Optional[str]:
        """
        Get the latest checkpoint ID for a thread.

        Args:
            thread_id: Chat session ID

        Returns:
            Latest checkpoint ID, or None when the thread has no checkpoint
            or the lookup fails (the failure is logged).

        Example:
            checkpoint_id = MessageService.get_latest_checkpoint_id(
                thread_id=session.id
            )
        """
        config = MessageService._thread_config(thread_id)
        try:
            with MessageService._get_checkpointer() as checkpointer:
                checkpoint_tuple = checkpointer.get_tuple(config)
        except Exception:
            MessageService._logger.exception(
                "Failed to load latest checkpoint for thread %s", thread_id
            )
            return None

        if checkpoint_tuple is None:
            return None
        return checkpoint_tuple.config['configurable'].get('checkpoint_id')

    @staticmethod
    def count_messages(thread_id: UUID) -> int:
        """
        Count messages in a thread.

        Args:
            thread_id: Chat session ID

        Returns:
            Number of messages

        Example:
            count = MessageService.count_messages(thread_id=session.id)
        """
        return len(MessageService.get_conversation_history(thread_id))

    @staticmethod
    def get_last_n_messages(
        thread_id: UUID,
        n: int = 10
    ) -> List[BaseMessage]:
        """
        Get last N messages from conversation.

        Args:
            thread_id: Chat session ID
            n: Number of recent messages

        Returns:
            List of last N messages

        Example:
            recent_msgs = MessageService.get_last_n_messages(
                thread_id=session.id,
                n=5
            )
        """
        return MessageService.get_conversation_history(
            thread_id=thread_id,
            limit=n
        )

    @staticmethod
    def search_messages(
        thread_id: UUID,
        search_query: str
    ) -> List[Dict[str, Any]]:
        """
        Search messages in a conversation (case-insensitive substring match).

        Args:
            thread_id: Chat session ID
            search_query: Text to search for

        Returns:
            One dict per matching message with its ``index`` in the history,
            ``role``, original ``content`` and a ``match_preview`` (first
            200 characters of the message text).

        Example:
            results = MessageService.search_messages(
                thread_id=session.id,
                search_query="python"
            )
        """
        needle = search_query.lower()
        results = []
        history = MessageService.get_conversation_history(thread_id)
        for index, msg in enumerate(history):
            # Content may be a list of blocks rather than a str; flatten it
            # before matching so such messages do not raise.
            text = MessageService._content_as_text(msg.content)
            if needle in text.lower():
                results.append({
                    'index': index,
                    'role': msg.type,
                    'content': msg.content,
                    'match_preview': text[:200]
                })
        return results

    # ------------------------------------------------------------------
    # Building / formatting messages
    # ------------------------------------------------------------------

    @staticmethod
    def add_message(
        thread_id: UUID,
        content: str,
        message_type: str = "human",
        metadata: Optional[Dict[str, Any]] = None
    ) -> BaseMessage:
        """
        Build a LangChain message object for the conversation.

        NOTE(review): this only constructs the message; nothing is written
        to the checkpointer and ``thread_id`` is currently unused. Persisting
        normally happens when the agent graph runs with this message as
        input.

        Args:
            thread_id: Chat session ID
            content: Message content
            message_type: Type of message (human, ai, system, tool); unknown
                types fall back to ``human`` with a logged warning
            metadata: Additional metadata, stored as ``additional_kwargs``.
                For ``tool`` messages it must contain ``tool_call_id``.

        Returns:
            Created message object

        Raises:
            ValueError: ``message_type`` is ``tool`` and ``metadata`` has no
                ``tool_call_id``.

        Example:
            msg = MessageService.add_message(
                thread_id=session.id,
                content="Hello!",
                message_type="human"
            )
        """
        # Copy so the caller's dict is never shared with (or mutated by) the
        # message object.
        extra = dict(metadata or {})

        if message_type == 'tool':
            # ToolMessage cannot be built without the ID of the call it
            # answers; fail with a clear message instead of a validation
            # error from deep inside the constructor.
            tool_call_id = extra.pop('tool_call_id', None)
            if tool_call_id is None:
                raise ValueError(
                    "metadata['tool_call_id'] is required for tool messages"
                )
            return ToolMessage(
                content=content,
                tool_call_id=tool_call_id,
                additional_kwargs=extra
            )

        message_classes = {
            'human': HumanMessage,
            'ai': AIMessage,
            'system': SystemMessage,
        }
        message_class = message_classes.get(message_type)
        if message_class is None:
            MessageService._logger.warning(
                "Unknown message_type %r; falling back to 'human'",
                message_type
            )
            message_class = HumanMessage
        return message_class(content=content, additional_kwargs=extra)

    @staticmethod
    def format_messages_for_display(
        messages: List[BaseMessage]
    ) -> List[Dict[str, Any]]:
        """
        Format LangChain messages for API response.

        Args:
            messages: List of LangChain message objects

        Returns:
            List of message dictionaries for frontend. ``tool_calls`` and
            ``tool_call_id`` are included only for messages that carry those
            attributes.

        Example:
            messages = MessageService.get_conversation_history(thread_id)
            formatted = MessageService.format_messages_for_display(messages)
        """
        formatted = []
        for msg in messages:
            entry = {
                'role': msg.type,
                'content': msg.content,
                'id': getattr(msg, 'id', None),
                'timestamp': msg.additional_kwargs.get('timestamp'),
                'metadata': msg.additional_kwargs
            }
            # Optional attributes, copied through only when present.
            for attr in ('tool_calls', 'tool_call_id'):
                if hasattr(msg, attr):
                    entry[attr] = getattr(msg, attr)
            formatted.append(entry)
        return formatted

    # ------------------------------------------------------------------
    # Mutating state
    # ------------------------------------------------------------------

    @staticmethod
    def update_state(
        thread_id: UUID,
        values: Dict[str, Any],
        as_node: Optional[str] = None
    ) -> None:
        """
        Update conversation state (advanced usage).

        Args:
            thread_id: Chat session ID
            values: State values to update
            as_node: Update as if from this node

        Raises:
            NotImplementedError: the checkpointer does not provide
                ``update_state``. State updates need the graph's reducers,
                so they must go through the compiled graph's
                ``update_state(config, values, as_node)``.

        Example:
            MessageService.update_state(
                thread_id=session.id,
                values={"custom_key": "custom_value"},
                as_node="agent"
            )
        """
        config = MessageService._thread_config(thread_id)
        with MessageService._get_checkpointer() as checkpointer:
            # NOTE(review): PostgresSaver's documented interface has no
            # update_state; the lookup keeps this working should a saver
            # ever provide one, and fails with an actionable error otherwise.
            update = getattr(checkpointer, 'update_state', None)
            if update is None:
                raise NotImplementedError(
                    "The checkpointer cannot update graph state; call "
                    "update_state() on the compiled graph instead."
                )
            update(
                config=config,
                values=values,
                as_node=as_node
            )

    @staticmethod
    def delete_thread_history(thread_id: UUID) -> None:
        """
        Delete all checkpoints for a thread.

        WARNING: This permanently deletes conversation history!

        NOTE(review): relies on the saver's ``delete_thread`` method, which
        only newer langgraph checkpoint releases provide — confirm the
        pinned version.

        Args:
            thread_id: Chat session ID to delete

        Example:
            MessageService.delete_thread_history(thread_id=session.id)
        """
        with MessageService._get_checkpointer() as checkpointer:
            checkpointer.delete_thread(str(thread_id))