uchill/chatnext/backend/apps/chatbot/models/user_api_key.py

288 lines
9.0 KiB
Python

"""
User API Key Model - Store encrypted API keys for AI providers.
"""
from django.db import models
from django.conf import settings
from django.utils.translation import gettext_lazy as _
from core.models import TimestampedModel
from cryptography.fernet import Fernet
from django.conf import settings as django_settings
import os
class UserAPIKey(TimestampedModel):
    """
    Store encrypted user API keys for various AI providers.

    Allows users to use their own API keys instead of platform credits.
    Keys are encrypted at rest with Fernet using the
    ``API_KEY_ENCRYPTION_KEY`` Django setting; only a short plaintext
    prefix is kept for identification.
    """

    # User association
    user = models.ForeignKey(
        settings.AUTH_USER_MODEL,
        on_delete=models.CASCADE,
        related_name="api_keys",
        help_text=_("User who owns this API key"),
    )

    # Provider information
    provider = models.CharField(
        max_length=50,
        choices=[
            ("openai", "OpenAI"),
            ("anthropic", "Anthropic (Claude)"),
            ("google", "Google AI"),
            ("cohere", "Cohere"),
            ("huggingface", "HuggingFace"),
            ("azure", "Azure OpenAI"),
            ("custom", "Custom Provider"),
        ],
        help_text=_("AI provider for this key"),
    )
    provider_display_name = models.CharField(
        max_length=100,
        blank=True,
        null=True,
        help_text=_("Custom display name for provider"),
    )

    # Encrypted key (Fernet token produced by encrypt_api_key)
    encrypted_key = models.BinaryField(
        help_text=_("Encrypted API key (stored securely)")
    )

    # Key metadata
    key_name = models.CharField(
        max_length=255, help_text=_("User-friendly name for this key")
    )
    key_prefix = models.CharField(
        max_length=20,
        blank=True,
        null=True,
        help_text=_("First few characters of key (for identification)"),
    )

    # Status
    is_active = models.BooleanField(
        default=True, help_text=_("Whether this key is active and usable")
    )
    is_default = models.BooleanField(
        default=False, help_text=_("Whether this is the default key for this provider")
    )

    # Validation
    is_validated = models.BooleanField(
        default=False, help_text=_("Whether key has been validated with provider")
    )
    last_validated_at = models.DateTimeField(
        null=True, blank=True, help_text=_("When key was last validated")
    )
    validation_error = models.TextField(
        blank=True, null=True, help_text=_("Error message from last validation attempt")
    )

    # Usage tracking
    usage_count = models.IntegerField(
        default=0, help_text=_("Number of times this key has been used")
    )
    last_used_at = models.DateTimeField(
        null=True, blank=True, help_text=_("When this key was last used")
    )
    total_tokens_used = models.BigIntegerField(
        default=0, help_text=_("Total tokens used with this key")
    )

    # Rate limiting and quotas (None = unlimited)
    daily_limit = models.IntegerField(
        null=True,
        blank=True,
        help_text=_("Daily usage limit (in tokens, null = unlimited)"),
    )
    monthly_limit = models.IntegerField(
        null=True,
        blank=True,
        help_text=_("Monthly usage limit (in tokens, null = unlimited)"),
    )

    # Additional configuration
    custom_config = models.JSONField(
        default=dict, blank=True, help_text=_("Provider-specific configuration")
    )

    class Meta:
        verbose_name = _("User API Key")
        verbose_name_plural = _("User API Keys")
        ordering = ["-is_default", "-last_used_at"]
        unique_together = ["user", "provider", "key_name"]
        indexes = [
            models.Index(fields=["user", "provider"], name="apikey_user_provider_idx"),
            models.Index(fields=["user", "is_active"], name="apikey_user_active_idx"),
            models.Index(fields=["is_default"], name="apikey_default_idx"),
        ]

    def __str__(self):
        return f"{self.user.email} - {self.provider} ({self.key_name})"

    @staticmethod
    def get_encryption_key():
        """
        Return the Fernet key used to encrypt and decrypt stored API keys.

        The key comes from the ``API_KEY_ENCRYPTION_KEY`` Django setting
        (populate it from an environment variable or secrets manager). It
        must be a value accepted by ``Fernet(...)``, e.g. the output of
        ``Fernet.generate_key()``.

        Raises:
            ImproperlyConfigured: if the setting is missing or empty. A
                freshly generated fallback key would differ on every call,
                making every stored API key permanently undecryptable, so
                this fails loudly instead.
        """
        from django.core.exceptions import ImproperlyConfigured

        key = getattr(django_settings, "API_KEY_ENCRYPTION_KEY", None)
        if not key:
            raise ImproperlyConfigured(
                "API_KEY_ENCRYPTION_KEY must be set to a Fernet key "
                "(generate one with cryptography.fernet.Fernet.generate_key())."
            )
        return key

    def encrypt_api_key(self, api_key):
        """
        Encrypt ``api_key`` onto this instance (the caller must save()).

        Sets ``encrypted_key`` to the Fernet token for the full key and
        ``key_prefix`` to a short plaintext prefix used for identification.

        Args:
            api_key: The plaintext API key string.

        Raises:
            ImproperlyConfigured: if no encryption key is configured.
        """
        fernet = Fernet(self.get_encryption_key())

        # Keep at most 8 leading characters, and never more than half the
        # key, so short keys are not stored in plaintext in key_prefix.
        prefix_length = min(8, len(api_key) // 2)
        self.key_prefix = api_key[:prefix_length]

        self.encrypted_key = fernet.encrypt(api_key.encode())

    def decrypt_api_key(self):
        """
        Decrypt and return the stored API key as a string.

        Raises:
            cryptography.fernet.InvalidToken: if the stored token was made
                with a different key or has been tampered with.
            ImproperlyConfigured: if no encryption key is configured.
        """
        fernet = Fernet(self.get_encryption_key())

        # A BinaryField value loaded from the database may be a memoryview
        # (depending on the backend); Fernet.decrypt needs bytes or str.
        token = self.encrypted_key
        if isinstance(token, (memoryview, bytearray)):
            token = bytes(token)

        return fernet.decrypt(token).decode()

    def validate_key(self):
        """
        Validate this API key with its provider and persist the outcome.

        For ``openai`` and ``anthropic`` the key is exercised with a
        ``models.list()`` call. On success ``is_validated``,
        ``last_validated_at`` and ``validation_error`` are saved; on any
        exception (decryption, missing SDK, provider rejection, ...) the key
        is marked unvalidated and the error text is saved.

        Returns:
            dict: {'valid': bool, 'error': str or None}
        """
        from django.utils import timezone

        try:
            api_key = self.decrypt_api_key()

            if self.provider == "openai":
                from openai import OpenAI

                OpenAI(api_key=api_key).models.list()
            elif self.provider == "anthropic":
                import anthropic

                anthropic.Anthropic(api_key=api_key).models.list()
            # NOTE(review): any other provider reaches this point without
            # contacting the provider — only successful decryption is checked,
            # yet the key is still marked validated. Add providers as needed.

            self.is_validated = True
            self.last_validated_at = timezone.now()
            self.validation_error = None
            self.save(
                update_fields=["is_validated", "last_validated_at", "validation_error"]
            )
            return {"valid": True, "error": None}
        except Exception as e:
            # Boundary handler: record the failure on the key instead of raising.
            self.is_validated = False
            self.validation_error = str(e)
            self.save(update_fields=["is_validated", "validation_error"])
            return {"valid": False, "error": str(e)}

    def increment_usage(self, tokens_used=0):
        """
        Record one use of this key and add ``tokens_used`` to its total.

        The counters are incremented in the database with ``F()`` expressions
        so concurrent requests sharing this key do not overwrite each other's
        updates (a read-modify-write in Python would lose increments). The
        instance's counters are refreshed from the database afterwards.

        Args:
            tokens_used: Tokens consumed by this use; ``None`` counts as 0.
        """
        from django.utils import timezone

        self.usage_count = models.F("usage_count") + 1
        self.total_tokens_used = models.F("total_tokens_used") + (tokens_used or 0)
        self.last_used_at = timezone.now()
        self.save(update_fields=["usage_count", "total_tokens_used", "last_used_at"])
        # Replace the unevaluated F() expressions with the stored values.
        self.refresh_from_db(fields=["usage_count", "total_tokens_used"])

    def _tokens_used_since(self, start):
        """Return total tokens recorded for this key since ``start`` (0 if none)."""
        from .token_usage import TokenUsage

        # NOTE(review): matches TokenUsage rows via metadata["api_key_id"];
        # assumes that value is stored with the same type as self.id — confirm.
        total = TokenUsage.objects.filter(
            user=self.user,
            created_at__gte=start,
            metadata__api_key_id=self.id,
        ).aggregate(total=models.Sum("total_tokens"))["total"]
        return total or 0

    def check_limits(self, tokens_to_use=0):
        """
        Check whether using ``tokens_to_use`` more tokens would exceed a limit.

        A limit of ``None`` means unlimited; ``0`` is a real limit. The daily
        limit is checked before the monthly one, and the first exceeded limit
        is reported.

        Args:
            tokens_to_use: Tokens the upcoming request expects to consume.

        Returns:
            dict: {'allowed': bool, 'reason': str}
        """
        from django.utils import timezone

        # NOTE(review): period boundaries are derived from timezone.now()
        # (UTC when USE_TZ is on), not the user's local day — confirm intent.
        now = timezone.now()
        day_start = now.replace(hour=0, minute=0, second=0, microsecond=0)
        month_start = day_start.replace(day=1)

        periods = (
            ("Daily", self.daily_limit, day_start),
            ("Monthly", self.monthly_limit, month_start),
        )
        for label, limit, period_start in periods:
            if limit is None:
                continue
            if self._tokens_used_since(period_start) + tokens_to_use > limit:
                return {
                    "allowed": False,
                    "reason": f"{label} limit exceeded ({limit} tokens)",
                }

        return {"allowed": True, "reason": "Within limits"}

    @classmethod
    def get_default_key(cls, user, provider):
        """
        Return the active default key for ``user`` and ``provider``.

        Returns:
            UserAPIKey or None: the first matching row under the model's
            default ordering, or None if there is no active default key.
        """
        return cls.objects.filter(
            user=user, provider=provider, is_active=True, is_default=True
        ).first()