288 lines
9.0 KiB
Python
288 lines
9.0 KiB
Python
"""
|
|
User API Key Model - Store encrypted API keys for AI providers.
|
|
"""
|
|
|
|
from django.db import models
|
|
from django.conf import settings
|
|
from django.utils.translation import gettext_lazy as _
|
|
from core.models import TimestampedModel
|
|
from cryptography.fernet import Fernet
|
|
from django.conf import settings as django_settings
|
|
import os
|
|
|
|
|
|
class UserAPIKey(TimestampedModel):
|
|
"""
|
|
Store encrypted user API keys for various AI providers.
|
|
|
|
Allows users to use their own API keys instead of platform credits.
|
|
Keys are encrypted at rest for security.
|
|
"""
|
|
|
|
# User association
|
|
user = models.ForeignKey(
|
|
settings.AUTH_USER_MODEL,
|
|
on_delete=models.CASCADE,
|
|
related_name="api_keys",
|
|
help_text=_("User who owns this API key"),
|
|
)
|
|
|
|
# Provider information
|
|
provider = models.CharField(
|
|
max_length=50,
|
|
choices=[
|
|
("openai", "OpenAI"),
|
|
("anthropic", "Anthropic (Claude)"),
|
|
("google", "Google AI"),
|
|
("cohere", "Cohere"),
|
|
("huggingface", "HuggingFace"),
|
|
("azure", "Azure OpenAI"),
|
|
("custom", "Custom Provider"),
|
|
],
|
|
help_text=_("AI provider for this key"),
|
|
)
|
|
|
|
provider_display_name = models.CharField(
|
|
max_length=100,
|
|
blank=True,
|
|
null=True,
|
|
help_text=_("Custom display name for provider"),
|
|
)
|
|
|
|
# Encrypted key
|
|
encrypted_key = models.BinaryField(
|
|
help_text=_("Encrypted API key (stored securely)")
|
|
)
|
|
|
|
# Key metadata
|
|
key_name = models.CharField(
|
|
max_length=255, help_text=_("User-friendly name for this key")
|
|
)
|
|
|
|
key_prefix = models.CharField(
|
|
max_length=20,
|
|
blank=True,
|
|
null=True,
|
|
help_text=_("First few characters of key (for identification)"),
|
|
)
|
|
|
|
# Status
|
|
is_active = models.BooleanField(
|
|
default=True, help_text=_("Whether this key is active and usable")
|
|
)
|
|
|
|
is_default = models.BooleanField(
|
|
default=False, help_text=_("Whether this is the default key for this provider")
|
|
)
|
|
|
|
# Validation
|
|
is_validated = models.BooleanField(
|
|
default=False, help_text=_("Whether key has been validated with provider")
|
|
)
|
|
|
|
last_validated_at = models.DateTimeField(
|
|
null=True, blank=True, help_text=_("When key was last validated")
|
|
)
|
|
|
|
validation_error = models.TextField(
|
|
blank=True, null=True, help_text=_("Error message from last validation attempt")
|
|
)
|
|
|
|
# Usage tracking
|
|
usage_count = models.IntegerField(
|
|
default=0, help_text=_("Number of times this key has been used")
|
|
)
|
|
|
|
last_used_at = models.DateTimeField(
|
|
null=True, blank=True, help_text=_("When this key was last used")
|
|
)
|
|
|
|
total_tokens_used = models.BigIntegerField(
|
|
default=0, help_text=_("Total tokens used with this key")
|
|
)
|
|
|
|
# Rate limiting and quotas
|
|
daily_limit = models.IntegerField(
|
|
null=True,
|
|
blank=True,
|
|
help_text=_("Daily usage limit (in tokens, null = unlimited)"),
|
|
)
|
|
|
|
monthly_limit = models.IntegerField(
|
|
null=True,
|
|
blank=True,
|
|
help_text=_("Monthly usage limit (in tokens, null = unlimited)"),
|
|
)
|
|
|
|
# Additional configuration
|
|
custom_config = models.JSONField(
|
|
default=dict, blank=True, help_text=_("Provider-specific configuration")
|
|
)
|
|
|
|
class Meta:
|
|
verbose_name = _("User API Key")
|
|
verbose_name_plural = _("User API Keys")
|
|
ordering = ["-is_default", "-last_used_at"]
|
|
unique_together = ["user", "provider", "key_name"]
|
|
indexes = [
|
|
models.Index(fields=["user", "provider"], name="apikey_user_provider_idx"),
|
|
models.Index(fields=["user", "is_active"], name="apikey_user_active_idx"),
|
|
models.Index(fields=["is_default"], name="apikey_default_idx"),
|
|
]
|
|
|
|
def __str__(self):
|
|
return f"{self.user.email} - {self.provider} ({self.key_name})"
|
|
|
|
@staticmethod
|
|
def get_encryption_key():
|
|
"""Get or create encryption key."""
|
|
# In production, store this in environment variable or secrets manager
|
|
key = getattr(django_settings, "API_KEY_ENCRYPTION_KEY", None)
|
|
|
|
if not key:
|
|
# Generate a key (do this once and save it)
|
|
key = Fernet.generate_key()
|
|
# In production: save this key securely!
|
|
|
|
return key
|
|
|
|
def encrypt_api_key(self, api_key):
|
|
"""Encrypt API key before storage."""
|
|
encryption_key = self.get_encryption_key()
|
|
fernet = Fernet(encryption_key)
|
|
|
|
# Store first few chars as prefix
|
|
self.key_prefix = api_key[:8] if len(api_key) >= 8 else api_key[:4]
|
|
|
|
# Encrypt the full key
|
|
self.encrypted_key = fernet.encrypt(api_key.encode())
|
|
|
|
def decrypt_api_key(self):
|
|
"""Decrypt API key for use."""
|
|
encryption_key = self.get_encryption_key()
|
|
fernet = Fernet(encryption_key)
|
|
|
|
decrypted = fernet.decrypt(self.encrypted_key)
|
|
return decrypted.decode()
|
|
|
|
def validate_key(self):
|
|
"""
|
|
Validate API key with provider.
|
|
|
|
Returns:
|
|
dict: {'valid': bool, 'error': str or None}
|
|
"""
|
|
from django.utils import timezone
|
|
|
|
try:
|
|
api_key = self.decrypt_api_key()
|
|
|
|
# Validate based on provider
|
|
if self.provider == "openai":
|
|
from openai import OpenAI
|
|
|
|
client = OpenAI(api_key=api_key)
|
|
# Test with minimal request
|
|
client.models.list()
|
|
|
|
elif self.provider == "anthropic":
|
|
import anthropic
|
|
|
|
client = anthropic.Anthropic(api_key=api_key)
|
|
# Test request
|
|
client.models.list()
|
|
|
|
# Add more providers as needed
|
|
|
|
# Mark as validated
|
|
self.is_validated = True
|
|
self.last_validated_at = timezone.now()
|
|
self.validation_error = None
|
|
self.save(
|
|
update_fields=["is_validated", "last_validated_at", "validation_error"]
|
|
)
|
|
|
|
return {"valid": True, "error": None}
|
|
|
|
except Exception as e:
|
|
self.is_validated = False
|
|
self.validation_error = str(e)
|
|
self.save(update_fields=["is_validated", "validation_error"])
|
|
|
|
return {"valid": False, "error": str(e)}
|
|
|
|
def increment_usage(self, tokens_used=0):
|
|
"""Track key usage."""
|
|
from django.utils import timezone
|
|
|
|
self.usage_count += 1
|
|
self.total_tokens_used += tokens_used
|
|
self.last_used_at = timezone.now()
|
|
|
|
self.save(update_fields=["usage_count", "total_tokens_used", "last_used_at"])
|
|
|
|
def check_limits(self, tokens_to_use=0):
|
|
"""
|
|
Check if usage would exceed limits.
|
|
|
|
Returns:
|
|
dict: {'allowed': bool, 'reason': str}
|
|
"""
|
|
from django.utils import timezone
|
|
from datetime import timedelta
|
|
|
|
# Check daily limit
|
|
if self.daily_limit:
|
|
today_start = timezone.now().replace(
|
|
hour=0, minute=0, second=0, microsecond=0
|
|
)
|
|
|
|
from .token_usage import TokenUsage
|
|
|
|
today_usage = (
|
|
TokenUsage.objects.filter(
|
|
user=self.user,
|
|
created_at__gte=today_start,
|
|
metadata__api_key_id=self.id,
|
|
).aggregate(total=models.Sum("total_tokens"))["total"]
|
|
or 0
|
|
)
|
|
|
|
if today_usage + tokens_to_use > self.daily_limit:
|
|
return {
|
|
"allowed": False,
|
|
"reason": f"Daily limit exceeded ({self.daily_limit} tokens)",
|
|
}
|
|
|
|
# Check monthly limit
|
|
if self.monthly_limit:
|
|
month_start = timezone.now().replace(
|
|
day=1, hour=0, minute=0, second=0, microsecond=0
|
|
)
|
|
|
|
from .token_usage import TokenUsage
|
|
|
|
month_usage = (
|
|
TokenUsage.objects.filter(
|
|
user=self.user,
|
|
created_at__gte=month_start,
|
|
metadata__api_key_id=self.id,
|
|
).aggregate(total=models.Sum("total_tokens"))["total"]
|
|
or 0
|
|
)
|
|
|
|
if month_usage + tokens_to_use > self.monthly_limit:
|
|
return {
|
|
"allowed": False,
|
|
"reason": f"Monthly limit exceeded ({self.monthly_limit} tokens)",
|
|
}
|
|
|
|
return {"allowed": True, "reason": "Within limits"}
|
|
|
|
@classmethod
|
|
def get_default_key(cls, user, provider):
|
|
"""Get default key for user and provider."""
|
|
return cls.objects.filter(
|
|
user=user, provider=provider, is_active=True, is_default=True
|
|
).first()
|