import json
import re
import secrets
from datetime import timedelta
import boto3
from botocore.exceptions import ClientError
from django.conf import settings
from django.core.validators import RegexValidator
from django.db import models, transaction
from django.utils import timezone
# Largest JSON payload (in bytes, after UTF-8 encoding) that send_message() will
# pass to post_to_connection; the AWS limit is 128 KB.
MAX_MESSAGE_SIZE = 128 * 1024
def get_region_name():
    """Return the AWS region configured in settings.py.

    ``AWS_GATEWAY_REGION_NAME`` takes precedence; ``AWS_REGION_NAME`` is still
    honoured for backwards compatibility. A setting that is missing or falsy is
    treated as unset.

    :raises RuntimeError: if neither setting provides a value.
    """
    for setting_name in ("AWS_GATEWAY_REGION_NAME", "AWS_REGION_NAME"):
        region = getattr(settings, setting_name, None)
        if region:
            return region
    raise RuntimeError(
        "AWS_GATEWAY_REGION_NAME or AWS_REGION_NAME must be set within settings.py"
    )
def get_boto3_client(service: str = "apigatewayv2", **kwargs):
    """Build a boto3 client for ``service`` using the credentials configured in settings.

    Credentials are chosen in this order:

    1. ``settings.AWS_IAM_PROFILE`` - a named profile from the local ``.aws`` folder.
       No ``region_name`` is passed on this path.
    2. ``settings.AWS_ACCESS_KEY_ID`` together with ``settings.AWS_SECRET_ACCESS_KEY``.
    3. Neither of the above - the machine's IAM role (e.g. on EC2). ``AWS_REGION_NAME``
       must then be set in settings.py.

    Paths 2 and 3 take the region from :func:`get_region_name`.

    :param str service: apigatewayv2 | apigatewaymanagementapi
    :param kwargs: forwarded unchanged to the boto3 ``client()`` call
        (e.g. ``endpoint_url``).
    """
    profile_name = getattr(settings, "AWS_IAM_PROFILE", None)
    if profile_name:
        # Named profile: credentials are read from the .aws folder.
        return boto3.Session(profile_name=profile_name).client(service, **kwargs)

    access_key = getattr(settings, "AWS_ACCESS_KEY_ID", None)
    secret_key = getattr(settings, "AWS_SECRET_ACCESS_KEY", None)
    if access_key and secret_key:
        # Explicit key pair from settings.
        return boto3.client(
            service,
            aws_access_key_id=access_key,
            aws_secret_access_key=secret_key,
            region_name=get_region_name(),
            **kwargs,
        )

    # No explicit credentials: fall back to the machine's IAM role.
    return boto3.client(service, region_name=get_region_name(), **kwargs)
class ApiGateway(models.Model):
    """Stores an API Gateway WebSocket API definition and the identifiers AWS returns for it."""

    class Meta:
        indexes = [
            models.Index(fields=["api_id"]),
            models.Index(fields=["domain_name"]),
            models.Index(fields=["custom_domain_created"]),
            models.Index(fields=["api_created"]),
            models.Index(fields=["created_on"]),
        ]

    def __str__(self) -> str:
        return self.api_name

    def clean(self):
        """Validate ``domain_name`` and ``api_name``, but only when they have values.

        ``fullmatch`` is used so that a trailing newline cannot slip past the ``$``
        anchor (``re.match`` with ``$`` accepts ``"value\\n"``).

        :raises django.core.exceptions.ValidationError: if either field is set but
            does not match its expected format.
        """
        from django.core.exceptions import ValidationError

        # Validate domain_name only if provided
        if self.domain_name:
            domain_pattern = re.compile(
                r"^[a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?(\.[a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)*$"
            )
            if not domain_pattern.fullmatch(self.domain_name):
                raise ValidationError({"domain_name": "Invalid domain name format"})
        # Validate api_name only if provided
        if self.api_name:
            api_name_pattern = re.compile(r"^[ a-zA-Z0-9_-]+$")
            if not api_name_pattern.fullmatch(self.api_name):
                raise ValidationError(
                    {
                        "api_name": "API name can only contain alphanumeric characters, spaces, hyphens, and underscores"
                    }
                )

    def save(self, *args, **kwargs):
        """Validate the record, ensure ``target_base_endpoint`` ends with ``/``, then save.

        The trailing slash matters because integration URIs are built by appending
        a route slug to ``target_base_endpoint``.

        :raises django.core.exceptions.ValidationError: from ``full_clean()``.
        """
        self.full_clean()
        # target_base_endpoint is nullable; only normalise when a value is present
        # (indexing None used to raise TypeError, and "" used to become "/").
        if self.target_base_endpoint and not self.target_base_endpoint.endswith("/"):
            self.target_base_endpoint = f"{self.target_base_endpoint}/"
        super().save(*args, **kwargs)

    api_name = models.CharField(max_length=255, unique=True)
    api_description = models.CharField(max_length=255, blank=True, default="")
    default_channel_name = models.CharField(
        max_length=128,
        blank=True,
        default="",
        help_text="Automatically sets the 'channel' on the WebSocketSession record for the connection",
    )
    domain_name = models.CharField(
        max_length=255,
        blank=True,
        default="",
        help_text="The full domain you wish to use for the API endpoint. E.G ws.example.com",
    )
    target_base_endpoint = models.URLField(
        null=True,
        blank=True,
        default=None,
        help_text=(
            "The URL on your website where the API Gateway routes will point, including the trailing /, but excluding "
            "the final route/slug portion of the URL."
            " E.G. If your default route will point to https://www.example.com/ws/default then enter "
            "https://www.example.com/ws/"
        ),
    )
    certificate_arn = models.CharField(
        max_length=255,
        blank=True,
        default="",
        help_text="The ARN of the certificate to use from AWS Certificate Manager",
    )
    hosted_zone_id = models.CharField(
        max_length=32,
        blank=True,
        default="",
        help_text="The Hosted Zone ID from AWs Route 53 for the domain you wish to use",
    )
    api_key_selection_expression = models.CharField(
        max_length=255, default="$request.header.x-api-key"
    )
    route_selection_expression = models.CharField(
        max_length=255, default="$request.body.action"
    )
    route_key = models.CharField(max_length=255, default="$default")
    stage_name = models.CharField(max_length=63, default="production", blank=True)
    stage_description = models.CharField(max_length=63, default="", blank=True)
    deployment_id = models.CharField(
        max_length=32, default="", blank=True, editable=False
    )
    tags = models.JSONField(
        blank=True, default=dict, help_text='In format {"tag-name": "tag-value"}'
    )

    # Returned Values
    api_id = models.CharField(
        max_length=32,
        blank=True,
        default="",
        help_text="The ID of the Api Gateway returned by AWS",
    )
    api_endpoint = models.CharField(
        max_length=255, blank=True, default="", help_text="The Api Gateway endpoint"
    )
    api_gateway_domain_name = models.CharField(
        max_length=255,
        blank=True,
        default="",
        help_text="The value to point your CNAME record to",
    )
    api_mapping_id = models.CharField(
        max_length=128,
        blank=True,
        default="",
        help_text="The ApiMappingId to use with api_mapping calls",
    )
    api_created = models.BooleanField(default=False, editable=False)
    custom_domain_created = models.BooleanField(default=False, editable=False)
    created_on = models.DateTimeField(auto_now_add=True)
    updated_on = models.DateTimeField(auto_now=True)

    def create_gateway(self):
        """Create the WebSocket API, its routes and stage in AWS, then deploy it.

        Does nothing if ``api_created`` is already set. Any additional routes that
        are not yet deployed are created too, before the single deployment.

        :raises ValueError: if ``target_base_endpoint`` is not set.
        :raises botocore.exceptions.ClientError: if an AWS call fails.
        """
        if self.api_created:
            return
        if not self.target_base_endpoint:
            raise ValueError("A target_base_endpoint is required to create the API")
        client = get_boto3_client()
        self._create_api(client)
        try:
            self._create_routes(client)
            # The reverse relation is only usable once this record has a primary key.
            if self.pk:
                for additional_route in self.additional_routes.all():
                    if not additional_route.deployed:
                        additional_route.create_route(client, deploy=False)
            self._create_stage_and_deploy(client)
        finally:
            # NOTE(review): set even if a later step raised, presumably because the
            # API (and api_id) already exists in AWS and a retry should not create a
            # duplicate - confirm this is the intended recovery path.
            self.api_created = True
            self.save()

    def create_custom_domain(self):
        """Create the custom domain in AWS and map it to this API's ``stage_name`` stage.

        Must be called after :meth:`create_gateway`.

        :raises ValueError: if the API has not been created, or ``certificate_arn``
            or ``domain_name`` is missing.
        :raises botocore.exceptions.ClientError: if an AWS call fails.
        """
        if not self.api_created:
            raise ValueError("The API needs to be created before calling this method")
        if not self.certificate_arn:
            raise ValueError("A Certificate ARN is required")
        if not self.domain_name:
            raise ValueError("A domain name is required")
        client = get_boto3_client()
        domain_res = self._create_domain_name(client)
        try:
            self.api_gateway_domain_name = domain_res["DomainNameConfigurations"][0][
                "ApiGatewayDomainName"
            ]
            self.api_mapping_id = self._create_api_mapping(client)
        finally:
            # Mirrors create_gateway: the domain exists in AWS once created above.
            self.custom_domain_created = True
            self.save()

    def _target_url(self, slug: str) -> str:
        """Join ``target_base_endpoint`` and ``slug`` with exactly one ``/`` between them."""
        return f"{(self.target_base_endpoint or '').rstrip('/')}/{slug}"

    def _create_api(self, client):
        """Create the base WebSocket API and store the returned id and endpoint."""
        res = client.create_api(
            ApiKeySelectionExpression=self.api_key_selection_expression,
            Description=self.api_description,
            DisableSchemaValidation=True,
            Name=self.api_name,
            ProtocolType="WEBSOCKET",
            RouteKey=self.route_key,
            RouteSelectionExpression=self.route_selection_expression,
            # Previously f"{base}/default", which doubled the slash save() enforces.
            Target=self._target_url("default"),
        )
        self.api_id = res["ApiId"]
        self.api_endpoint = res["ApiEndpoint"]

    def _create_routes(self, client):
        """Create an HTTP proxy integration and matching route for each built-in route.

        Each integration points at ``target_base_endpoint`` plus the route key
        without its ``$`` (e.g. ``$connect`` -> ``<base>/connect``) and forwards the
        connection id in the ``connectionId`` request header.
        """
        for route in ["$connect", "$disconnect", "$default"]:
            integration_res = client.create_integration(
                ApiId=self.api_id,
                ConnectionType="INTERNET",
                IntegrationMethod="POST",
                IntegrationType="HTTP_PROXY",
                IntegrationUri=self._target_url(route.replace("$", "")),
                PassthroughBehavior="WHEN_NO_MATCH",
                PayloadFormatVersion="1.0",
                RequestParameters={
                    "integration.request.header.connectionId": "context.connectionId"
                },
                TimeoutInMillis=29000,
            )
            extra_kwargs = {}
            if route == "$default":
                extra_kwargs["RouteResponseSelectionExpression"] = route
            client.create_route(
                ApiId=self.api_id,
                ApiKeyRequired=False,
                AuthorizationType="NONE",
                RouteKey=route,
                Target=f"integrations/{integration_res['IntegrationId']}",
                **extra_kwargs,
            )

    def _create_stage_and_deploy(self, client):
        """Create the ``stage_name`` stage and deploy the API to it."""
        client.create_stage(ApiId=self.api_id, StageName=self.stage_name)
        self.deploy_api(client)

    def deploy_api(self, client):
        """Create a deployment for ``stage_name`` and record its id on this instance.

        The record is not saved here; callers are responsible for persisting
        ``deployment_id``.
        """
        res = client.create_deployment(
            ApiId=self.api_id,
            Description=self.stage_description,
            StageName=self.stage_name,
        )
        self.deployment_id = res["DeploymentId"]

    def _create_domain_name(self, client):
        """Create the regional custom domain, including ``hosted_zone_id`` when set."""
        conf = {
            "CertificateArn": self.certificate_arn,
            "DomainNameStatus": "AVAILABLE",
            "EndpointType": "REGIONAL",
            "SecurityPolicy": "TLS_1_2",
        }
        if self.hosted_zone_id:
            conf["HostedZoneId"] = self.hosted_zone_id
        return client.create_domain_name(
            DomainName=self.domain_name, DomainNameConfigurations=[conf]
        )

    def _create_api_mapping(self, client) -> str:
        """Map ``domain_name`` to this API's stage and return the ApiMappingId."""
        mapping_res = client.create_api_mapping(
            ApiId=self.api_id, DomainName=self.domain_name, Stage=self.stage_name
        )
        return mapping_res["ApiMappingId"]
class ApiGatewayAdditionalRoute(models.Model):
    """Stores an additional route key and the URL its integration forwards to."""

    class Meta:
        # TODO: once version 4 is the minimum supported version, replace this with
        # an equivalent models.UniqueConstraint over ("api_gateway", "route_key").
        unique_together = [["api_gateway", "route_key"]]

    def __str__(self) -> str:
        return self.name

    def save(self, *args, **kwargs):
        """Save the record, then create the route in AWS if the parent API is already deployed."""
        super().save(*args, **kwargs)
        # create_route() saves again with deployed=True, so this does not recurse.
        if self.api_gateway.deployment_id and not self.deployed:
            client = get_boto3_client()
            self.create_route(client)

    api_gateway = models.ForeignKey(
        ApiGateway, on_delete=models.CASCADE, related_name="additional_routes"
    )
    name = models.CharField(max_length=63, help_text="Descriptive name for the route")
    route_key = models.CharField(max_length=64, db_index=True)
    integration_url = models.URLField()
    deployed = models.BooleanField(default=False, editable=False)
    created_on = models.DateTimeField(auto_now_add=True)
    updated_on = models.DateTimeField(auto_now=True)

    def create_route(self, client, deploy=True):
        """Create the HTTP proxy integration for ``integration_url`` and the route using it.

        :param client: an ``apigatewayv2`` boto3 client.
        :param bool deploy: when true, redeploy the parent API afterwards and persist
            its new ``deployment_id``. ``ApiGateway.create_gateway`` passes ``False``
            because it deploys once after creating all routes.
        """
        gateway = self.api_gateway
        integration_res = client.create_integration(
            ApiId=gateway.api_id,
            ConnectionType="INTERNET",
            IntegrationMethod="POST",
            IntegrationType="HTTP_PROXY",
            IntegrationUri=self.integration_url,
            PassthroughBehavior="WHEN_NO_MATCH",
            PayloadFormatVersion="1.0",
            RequestParameters={
                "integration.request.header.connectionId": "context.connectionId"
            },
            TimeoutInMillis=29000,
        )
        client.create_route(
            ApiId=gateway.api_id,
            ApiKeyRequired=False,
            AuthorizationType="NONE",
            RouteKey=self.route_key,
            Target=f"integrations/{integration_res['IntegrationId']}",
            RouteResponseSelectionExpression="$default",
        )
        if deploy:
            gateway.deploy_api(client)
        self.deployed = True
        self.save()
        if deploy:
            # deploy_api() only updates the in-memory instance; persist the new id.
            # Done after flagging this route so a failure here cannot leave the
            # already-created AWS route marked as undeployed.
            gateway.save(update_fields=["deployment_id", "updated_on"])
class WebSocketSessionQuerySet(models.QuerySet):
    """QuerySet for WebSocketSession with bulk messaging support."""

    def send_message(self, data: dict):
        """Send the same message to all connected WebSocket sessions in the queryset.

        One management-API client is built per ApiGateway, so a queryset that spans
        several gateways posts each message to the correct endpoint. Sessions with
        no ``api_gateway`` have no endpoint to post to and are skipped. Sessions AWS
        reports as gone are marked ``connected=False``.

        Example::

            WebSocketSession.objects.filter(
                channel_name="Shared Channel Name",
            ).send_message(
                {
                    "msg": "this is a server sent message",
                }
            )

        :param dict data: JSON-serialisable payload.
        :return: list of ``post_to_connection`` responses for successful sends.
        :raises ValueError: if the encoded message exceeds ``MAX_MESSAGE_SIZE``.
        :raises botocore.exceptions.ClientError: for any AWS error other than
            ``GoneException``.
        """
        region = get_region_name()
        msg = json.dumps(data)
        # Same size guard as WebSocketSession.send_message, checked once up front.
        message_size = len(msg.encode("utf-8"))
        if message_size > MAX_MESSAGE_SIZE:
            raise ValueError(
                f"Message exceeds AWS limit: {message_size} bytes (max: {MAX_MESSAGE_SIZE} bytes)"
            )

        clients = {}  # api_gateway pk -> apigatewaymanagementapi client
        res = []
        sessions = self.filter(connected=True, api_gateway__isnull=False).select_related(
            "api_gateway"
        )
        for obj in sessions:
            gateway = obj.api_gateway
            client = clients.get(gateway.pk)
            if client is None:
                client = get_boto3_client(
                    "apigatewaymanagementapi",
                    endpoint_url=(
                        f"https://{gateway.api_id}.execute-api."
                        f"{region}.amazonaws.com/{gateway.stage_name}"
                    ),
                )
                clients[gateway.pk] = client
            try:
                res.append(
                    client.post_to_connection(Data=msg, ConnectionId=obj.connection_id)
                )
            except ClientError as error:
                if error.response["Error"]["Code"] != "GoneException":
                    raise
                obj.connected = False
                # Only write the changed columns so concurrent updates are not lost.
                obj.save(update_fields=["connected", "updated_on"])
        return res
class WebSocketSession(models.Model):
    """A single WebSocket connection, optionally linked to a user, channel and ApiGateway."""

    class Meta:
        indexes = [
            models.Index(fields=["connection_id", "connected"]),
            models.Index(fields=["channel_name", "connection_id"]),
            models.Index(fields=["channel_name", "connected"]),
        ]

    def __str__(self) -> str:
        return self.connection_id

    def send_message(self, data: dict):
        """Send ``data`` as JSON to this connection.

        :param dict data: JSON-serialisable payload.
        :return: the ``post_to_connection`` response, or ``None`` if AWS reports the
            connection as gone (``connected`` is then set to ``False``).
        :raises ValueError: if the encoded message exceeds ``MAX_MESSAGE_SIZE`` or
            the session has no ``api_gateway``.
        :raises botocore.exceptions.ClientError: for any AWS error other than
            ``GoneException``.
        """
        region = get_region_name()
        message_data = json.dumps(data)
        # Security: Validate message size before sending
        message_size = len(message_data.encode("utf-8"))
        if message_size > MAX_MESSAGE_SIZE:
            raise ValueError(
                f"Message exceeds AWS limit: {message_size} bytes (max: {MAX_MESSAGE_SIZE} bytes)"
            )
        gateway = self.api_gateway
        if gateway is None:
            # api_gateway is nullable (SET_NULL); there is no endpoint to post to.
            raise ValueError("This session is not linked to an ApiGateway")
        client = get_boto3_client(
            "apigatewaymanagementapi",
            endpoint_url=(
                f"https://{gateway.api_id}.execute-api."
                f"{region}.amazonaws.com/{gateway.stage_name}"
            ),
        )
        try:
            return client.post_to_connection(
                Data=message_data, ConnectionId=self.connection_id
            )
        except ClientError as error:
            if error.response["Error"]["Code"] != "GoneException":
                raise
            self.connected = False
            # Only write the changed columns so concurrent updates (e.g. to
            # request_count) are not overwritten with stale values.
            self.save(update_fields=["connected", "updated_on"])
            return None

    objects = WebSocketSessionQuerySet.as_manager()

    connection_id = models.CharField(max_length=255, unique=True)
    channel_name = models.CharField(
        max_length=128,
        blank=True,
        default="",
        help_text="Used to group connections together",
    )
    user = models.ForeignKey(
        settings.AUTH_USER_MODEL,
        blank=True,
        null=True,
        default=None,
        on_delete=models.CASCADE,
        related_name="websocket_sessions",
    )
    connected = models.BooleanField(
        default=True, help_text="Indicates is the connection is current or not"
    )
    api_gateway = models.ForeignKey(
        ApiGateway,
        null=True,
        blank=True,
        default=None,
        on_delete=models.SET_NULL,
        related_name="sessions",
    )
    request_count = models.PositiveBigIntegerField(default=1)
    created_on = models.DateTimeField(auto_now_add=True)
    updated_on = models.DateTimeField(auto_now=True)
class WebSocketToken(models.Model):
    """One-time use tokens for establishing WebSocket connections (CSRF protection)"""

    # 64 lowercase hex characters, as produced by secrets.token_hex(32).
    token = models.CharField(
        max_length=64,
        unique=True,
        validators=[RegexValidator(r"^[a-f0-9]{64}$", "Invalid token format")],
    )
    user = models.ForeignKey(
        settings.AUTH_USER_MODEL,
        on_delete=models.CASCADE,
        related_name="websocket_tokens",
    )
    session_key = models.CharField(max_length=40, db_index=True)
    created_at = models.DateTimeField(auto_now_add=True, db_index=True)
    used = models.BooleanField(default=False)

    class Meta:
        indexes = [
            models.Index(fields=["token", "used"], name="idx_websocket_tokens_used"),
        ]

    @classmethod
    def generate_token(cls, user, session_key):
        """Create and return a new token row for ``user`` bound to ``session_key``."""
        # secrets.token_hex(32) -> 64-character hex string
        return cls.objects.create(
            token=secrets.token_hex(32),
            user=user,
            session_key=session_key,
        )

    @classmethod
    def validate_and_consume(cls, token_value, session_key, max_age_seconds=60):
        """Consume a token and return its user, or ``None`` if it is not acceptable.

        A token is accepted only if it matches ``token_value`` and ``session_key``,
        has not been used, and was created within the last ``max_age_seconds``.
        The row is locked with ``select_for_update`` inside a transaction and
        flagged as used before returning, so it can only be consumed once.
        """
        oldest_allowed = timezone.now() - timedelta(seconds=max_age_seconds)
        try:
            with transaction.atomic():
                locked = cls.objects.select_for_update()
                token_obj = locked.get(
                    token=token_value,
                    session_key=session_key,
                    used=False,
                    created_at__gte=oldest_allowed,
                )
                # Flag as consumed while the row lock is still held.
                token_obj.used = True
                token_obj.save(update_fields=["used"])
            return token_obj.user
        except cls.DoesNotExist:
            return None

    @classmethod
    def cleanup_expired(cls, max_age_seconds=300):
        """Delete every token created more than ``max_age_seconds`` ago (call via cron/celery).

        Tokens are deleted by age only, whether used or not. Used tokens younger
        than the cutoff are kept; ``check_rate_limit`` counts recent tokens, so
        deleting them early would loosen that limit.

        :return: the result of ``QuerySet.delete()``.
        """
        stale_before = timezone.now() - timedelta(seconds=max_age_seconds)
        stale_tokens = cls.objects.filter(created_at__lt=stale_before)
        return stale_tokens.delete()

    @classmethod
    def check_rate_limit(cls, user, max_tokens_per_minute=10):
        """Return ``True`` while ``user`` has created fewer than ``max_tokens_per_minute`` tokens in the last minute."""
        window_start = timezone.now() - timedelta(minutes=1)
        issued_recently = cls.objects.filter(user=user, created_at__gte=window_start)
        return issued_recently.count() < max_tokens_per_minute
class ConnectionRateLimit(models.Model):
    """Track connection attempts for rate limiting"""

    ip_address = models.GenericIPAddressField(db_index=True)
    user = models.ForeignKey(
        settings.AUTH_USER_MODEL,
        on_delete=models.CASCADE,
        null=True,
        blank=True,
        related_name="connection_attempts",
    )
    attempt_time = models.DateTimeField(auto_now_add=True, db_index=True)
    successful = models.BooleanField(default=False)

    class Meta:
        indexes = [
            models.Index(
                fields=["ip_address", "attempt_time"], name="idx_ip_attempt_time"
            ),
            models.Index(fields=["user", "attempt_time"], name="idx_user_attempt_time"),
        ]

    @classmethod
    def check_rate_limit(cls, ip_address, user=None, max_attempts=20, window_minutes=5):
        """Check whether ``ip_address`` (and ``user``, if authenticated) may connect.

        Attempts are counted over the last ``window_minutes``. The IP limit is
        checked first; the per-user limit is only checked for an authenticated user.

        :return: ``(is_allowed, attempts_count)`` where ``attempts_count`` is the
            count that tripped the limit, or the IP count when allowed.
        """
        window_start = timezone.now() - timedelta(minutes=window_minutes)
        in_window = cls.objects.filter(attempt_time__gte=window_start)

        ip_attempts = in_window.filter(ip_address=ip_address).count()
        if ip_attempts >= max_attempts:
            return False, ip_attempts

        if user and user.is_authenticated:
            user_attempts = in_window.filter(user=user).count()
            if user_attempts >= max_attempts:
                return False, user_attempts

        return True, ip_attempts

    @classmethod
    def record_attempt(cls, ip_address, user=None, successful=False):
        """Store one connection attempt; anonymous or missing users are stored as ``None``."""
        known_user = user if user and user.is_authenticated else None
        return cls.objects.create(
            ip_address=ip_address,
            user=known_user,
            successful=successful,
        )

    @classmethod
    def cleanup_old_records(cls, days=7):
        """Delete attempts older than ``days`` days (call via cron/celery).

        :return: the result of ``QuerySet.delete()``.
        """
        oldest_kept = timezone.now() - timedelta(days=days)
        expired = cls.objects.filter(attempt_time__lt=oldest_kept)
        return expired.delete()