Source code for django_aws_api_gateway_websockets.models

import json
import re
import secrets
from datetime import timedelta

import boto3
from botocore.exceptions import ClientError
from django.conf import settings
from django.core.validators import RegexValidator
from django.db import models, transaction
from django.utils import timezone

# Largest outbound WebSocket payload accepted, in bytes (128 KB, the AWS limit).
MAX_MESSAGE_SIZE = 128 * 1024


def get_region_name():
    """Return the AWS region name configured in settings.py.

    ``AWS_GATEWAY_REGION_NAME`` is preferred; ``AWS_REGION_NAME`` is consulted
    afterwards for backwards compatibility. A setting that is missing or falsy
    (e.g. an empty string) is skipped.

    :raises RuntimeError: if neither setting holds a truthy value.
    """
    for setting_name in ("AWS_GATEWAY_REGION_NAME", "AWS_REGION_NAME"):
        region = getattr(settings, setting_name, None)
        if region:
            return region
    raise RuntimeError(
        "AWS_GATEWAY_REGION_NAME or AWS_REGION_NAME must be set within settings.py"
    )
def get_boto3_client(service: str = "apigatewayv2", **kwargs):
    """Return the boto3 client to use.

    Credentials are chosen in this order:

    1. ``settings.AWS_IAM_PROFILE`` - a named profile whose credentials are
       stored within the ``.aws`` folder.
    2. ``settings.AWS_ACCESS_KEY_ID`` together with
       ``settings.AWS_SECRET_ACCESS_KEY``.
    3. Otherwise boto3's default credential lookup, e.g. the IAM Role of the
       machine (such as an EC2 instance) the service runs on.

    The region comes from :func:`get_region_name` (``AWS_GATEWAY_REGION_NAME``
    or ``AWS_REGION_NAME``):

    * It is mandatory for options 2 and 3.
    * For option 1 it is applied when configured, so that the client signs
      requests for the same region the ``execute-api`` endpoint URLs in this
      module are built from. Otherwise the profile's own region is used.
    * A ``region_name`` passed in ``kwargs`` always wins.

    :param str service: apigatewayv2 | apigatewaymanagementapi
    :param kwargs: extra keyword arguments for the client, e.g. ``endpoint_url``.
    :raises RuntimeError: for options 2 and 3 when no region is configured.
    """
    if getattr(settings, "AWS_IAM_PROFILE", None):
        # Use a named profile where credentials are stored within .aws folder
        session = boto3.Session(profile_name=settings.AWS_IAM_PROFILE)
        if "region_name" not in kwargs:
            try:
                kwargs["region_name"] = get_region_name()
            except RuntimeError:
                # No region in settings: deliberately defer to the profile's region.
                pass
        return session.client(service, **kwargs)

    # Options 2 and 3 need an explicit region. get_region_name() is evaluated
    # eagerly, so a missing region still raises RuntimeError here.
    kwargs.setdefault("region_name", get_region_name())

    if getattr(settings, "AWS_ACCESS_KEY_ID", None) and getattr(
        settings, "AWS_SECRET_ACCESS_KEY", None
    ):
        # Use specific access and secret keys
        return boto3.client(
            service,
            aws_access_key_id=settings.AWS_ACCESS_KEY_ID,
            aws_secret_access_key=settings.AWS_SECRET_ACCESS_KEY,
            **kwargs,
        )

    # Use the IAM Role of the machine
    return boto3.client(service, **kwargs)
class ApiGateway(models.Model):
    """Stores an API Gateway (WebSocket) definition and the values AWS returns for it."""

    class Meta:
        indexes = [
            models.Index(fields=["api_id"]),
            models.Index(fields=["domain_name"]),
            models.Index(fields=["custom_domain_created"]),
            models.Index(fields=["api_created"]),
            models.Index(fields=["created_on"]),
        ]

    def __str__(self) -> str:
        return self.api_name

    def clean(self):
        """Validate ``domain_name`` and ``api_name`` only when they have values.

        :raises django.core.exceptions.ValidationError: on the first field whose
            value does not match its allowed pattern.
        """
        from django.core.exceptions import ValidationError

        # Validate domain_name only if provided: dot-separated labels of up to
        # 63 alphanumeric/hyphen characters that neither start nor end with "-".
        if self.domain_name:
            domain_pattern = re.compile(
                r"^[a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?(\.[a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)*$"
            )
            if not domain_pattern.match(self.domain_name):
                raise ValidationError({"domain_name": "Invalid domain name format"})

        # Validate api_name only if provided
        if self.api_name:
            api_name_pattern = re.compile(r"^[ a-zA-Z0-9_-]+$")
            if not api_name_pattern.match(self.api_name):
                raise ValidationError(
                    {
                        "api_name": "API name can only contain alphanumeric characters, spaces, hyphens, and underscores"
                    }
                )

    def save(self, *args, **kwargs):
        """Validate the record, ensure the target endpoint ends with "/", then save.

        ``target_base_endpoint`` is nullable, so the trailing slash is only added
        when a value is set. Previously ``None`` raised ``TypeError`` here.
        """
        self.full_clean()
        endpoint = self.target_base_endpoint
        if endpoint and not endpoint.endswith("/"):
            self.target_base_endpoint = f"{endpoint}/"
        super().save(*args, **kwargs)

    api_name = models.CharField(max_length=255, unique=True)
    api_description = models.CharField(max_length=255, blank=True, default="")
    default_channel_name = models.CharField(
        max_length=128,
        blank=True,
        default="",
        help_text="Automatically sets the 'channel' on the WebSocketSession record for the connection",
    )
    domain_name = models.CharField(
        max_length=255,
        blank=True,
        default="",
        help_text="The full domain you wish to use for the API endpoint. E.G ws.example.com",
    )
    target_base_endpoint = models.URLField(
        null=True,
        blank=True,
        default=None,
        help_text=(
            "The URL on your website where the API Gateway routes will point, including the trailing /, but excluding "
            "the final route/slug portion of the URL."
            " E.G. If your default route will point to https://www.example.com/ws/default then enter "
            "https://www.example.com/ws/"
        ),
    )
    certificate_arn = models.CharField(
        max_length=255,
        blank=True,
        default="",
        help_text="The ARN of the certificate to use from AWS Certificate Manager",
    )
    hosted_zone_id = models.CharField(
        max_length=32,
        blank=True,
        default="",
        help_text="The Hosted Zone ID from AWs Route 53 for the domain you wish to use",
    )
    api_key_selection_expression = models.CharField(
        max_length=255, default="$request.header.x-api-key"
    )
    route_selection_expression = models.CharField(
        max_length=255, default="$request.body.action"
    )
    route_key = models.CharField(max_length=255, default="$default")
    stage_name = models.CharField(max_length=63, default="production", blank=True)
    stage_description = models.CharField(max_length=63, default="", blank=True)
    deployment_id = models.CharField(
        max_length=32, default="", blank=True, editable=False
    )
    tags = models.JSONField(
        blank=True, default=dict, help_text='In format {"tag-name": "tag-value"}'
    )

    # Returned Values
    api_id = models.CharField(
        max_length=32,
        blank=True,
        default="",
        help_text="The ID of the Api Gateway returned by AWS",
    )
    api_endpoint = models.CharField(
        max_length=255, blank=True, default="", help_text="The Api Gateway endpoint"
    )
    api_gateway_domain_name = models.CharField(
        max_length=255,
        blank=True,
        default="",
        help_text="The value to point your CNAME record to",
    )
    api_mapping_id = models.CharField(
        max_length=128,
        blank=True,
        default="",
        help_text="The ApiMappingId to use with api_mapping calls",
    )
    api_created = models.BooleanField(default=False, editable=False)
    custom_domain_created = models.BooleanField(default=False, editable=False)
    created_on = models.DateTimeField(auto_now_add=True)
    updated_on = models.DateTimeField(auto_now=True)

    def create_gateway(self):
        """Create the API, its routes, stage and deployment in AWS.

        Does nothing if ``api_created`` is already set. Once ``_create_api``
        has succeeded, the record is marked ``api_created`` and saved even if a
        later step raises. The early return above therefore stops a later call
        from creating the API a second time.

        Additional routes that are not yet deployed are created with
        ``deploy=False`` before the stage is deployed, so the single deployment
        at the end includes them.

        :raises botocore.exceptions.ClientError: if an AWS call fails.
        """
        if self.api_created:
            return
        client = get_boto3_client()
        self._create_api(client)
        try:
            self._create_routes(client)
            if self.pk:
                for additional_route in self.additional_routes.all():
                    if not additional_route.deployed:
                        additional_route.create_route(client, deploy=False)
            self._create_stage_and_deploy(client)
        finally:
            self.api_created = True
            self.save()

    def create_custom_domain(self):
        """Create the custom domain and map it to this API's stage.

        Should be called after ``create_gateway()`` has been run. Once the
        domain name has been created, the record is marked
        ``custom_domain_created`` and saved even if the mapping step raises.

        :raises ValueError: if the API has not been created yet or no
            certificate ARN is set.
        :raises botocore.exceptions.ClientError: if an AWS call fails.
        """
        if not self.api_created:
            raise ValueError("The API needs to be created before calling this method")
        if not self.certificate_arn:
            raise ValueError("A Certificate ARN is required")
        client = get_boto3_client()
        domain_res = self._create_domain_name(client)
        try:
            self.api_gateway_domain_name = domain_res["DomainNameConfigurations"][0][
                "ApiGatewayDomainName"
            ]
            self.api_mapping_id = self._create_api_mapping(client)
        finally:
            self.custom_domain_created = True
            self.save()

    def _endpoint_url(self, slug: str) -> str:
        """Join ``target_base_endpoint`` and ``slug`` with exactly one "/" between them."""
        base = (self.target_base_endpoint or "").rstrip("/")
        return f"{base}/{slug}"

    def _create_api(self, client):
        """Create the base WebSocket API and store the returned ApiId/ApiEndpoint."""
        res = client.create_api(
            ApiKeySelectionExpression=self.api_key_selection_expression,
            Description=self.api_description,
            DisableSchemaValidation=True,
            Name=self.api_name,
            ProtocolType="WEBSOCKET",
            RouteKey=self.route_key,
            RouteSelectionExpression=self.route_selection_expression,
            # save() guarantees a trailing "/", so naive concatenation with
            # "/default" produced "…//default".
            Target=self._endpoint_url("default"),
        )
        self.api_id = res["ApiId"]
        self.api_endpoint = res["ApiEndpoint"]

    def _create_routes(self, client):
        """Create an HTTP_PROXY integration and a route for each built-in route key.

        Each route posts to ``<target_base_endpoint><route without "$">``, and
        the connection id is forwarded in the ``connectionId`` request header.
        """
        for route in ["$connect", "$disconnect", "$default"]:
            integration_res = client.create_integration(
                ApiId=self.api_id,
                ConnectionType="INTERNET",
                IntegrationMethod="POST",
                IntegrationType="HTTP_PROXY",
                IntegrationUri=self._endpoint_url(route.replace("$", "")),
                PassthroughBehavior="WHEN_NO_MATCH",
                PayloadFormatVersion="1.0",
                RequestParameters={
                    "integration.request.header.connectionId": "context.connectionId"
                },
                TimeoutInMillis=29000,
            )
            extra_kwargs = {}
            if route == "$default":
                extra_kwargs["RouteResponseSelectionExpression"] = route
            client.create_route(
                ApiId=self.api_id,
                ApiKeyRequired=False,
                AuthorizationType="NONE",
                RouteKey=route,
                Target=f"integrations/{integration_res['IntegrationId']}",
                **extra_kwargs,
            )

    def _create_stage_and_deploy(self, client):
        """Create the stage and deployment"""
        client.create_stage(ApiId=self.api_id, StageName=self.stage_name)
        self.deploy_api(client)

    def deploy_api(self, client):
        """Deploy the API to ``stage_name`` and record the returned DeploymentId.

        Only the in-memory ``deployment_id`` is updated; the caller is
        responsible for persisting it.
        """
        res = client.create_deployment(
            ApiId=self.api_id,
            Description=self.stage_description,
            StageName=self.stage_name,
        )
        self.deployment_id = res["DeploymentId"]

    def _create_domain_name(self, client):
        """Create the regional custom domain, including the HostedZoneId if one is set."""
        conf = {
            "CertificateArn": self.certificate_arn,
            "DomainNameStatus": "AVAILABLE",
            "EndpointType": "REGIONAL",
            "SecurityPolicy": "TLS_1_2",
        }
        if self.hosted_zone_id:
            conf["HostedZoneId"] = self.hosted_zone_id
        return client.create_domain_name(
            DomainName=self.domain_name, DomainNameConfigurations=[conf]
        )

    def _create_api_mapping(self, client) -> str:
        """Map the custom domain to this API's stage and return the ApiMappingId."""
        mapping_res = client.create_api_mapping(
            ApiId=self.api_id, DomainName=self.domain_name, Stage=self.stage_name
        )
        return mapping_res["ApiMappingId"]
class ApiGatewayAdditionalRoute(models.Model):
    """Stores an additional route key and the URL its integration points to."""

    class Meta:
        # Once version 4 is the minimum supported version then swap to use the
        # equivalent constraint. It must cover the same fields as unique_together
        # (route_key, not name):
        # constraints = [
        #     models.UniqueConstraint(
        #         fields=["api_gateway", "route_key"],
        #         name="unique_route_key_per_gateway",
        #     )
        # ]
        unique_together = [["api_gateway", "route_key"]]

    def __str__(self) -> str:
        return self.name

    def save(self, *args, **kwargs):
        """Save the record, then create the route in AWS if the parent API is already deployed."""
        super().save(*args, **kwargs)
        if self.api_gateway.deployment_id and not self.deployed:
            client = get_boto3_client()
            self.create_route(client)

    api_gateway = models.ForeignKey(
        ApiGateway, on_delete=models.CASCADE, related_name="additional_routes"
    )
    name = models.CharField(max_length=63, help_text="Descriptive name for the route")
    route_key = models.CharField(max_length=64, db_index=True)
    integration_url = models.URLField()
    deployed = models.BooleanField(default=False, editable=False)
    created_on = models.DateTimeField(auto_now_add=True)
    updated_on = models.DateTimeField(auto_now=True)

    def create_route(self, client, deploy=True):
        """Create the HTTP_PROXY integration and then the route for ``route_key``.

        :param client: an ``apigatewayv2`` boto3 client.
        :param bool deploy: when True, redeploy the parent API afterwards and
            persist its new ``deployment_id``. ``ApiGateway.create_gateway``
            passes False because it deploys once after creating all routes.
        """
        api_id = self.api_gateway.api_id
        integration_res = client.create_integration(
            ApiId=api_id,
            ConnectionType="INTERNET",
            IntegrationMethod="POST",
            IntegrationType="HTTP_PROXY",
            IntegrationUri=self.integration_url,
            PassthroughBehavior="WHEN_NO_MATCH",
            PayloadFormatVersion="1.0",
            RequestParameters={
                "integration.request.header.connectionId": "context.connectionId"
            },
            TimeoutInMillis=29000,
        )
        client.create_route(
            ApiId=api_id,
            ApiKeyRequired=False,
            AuthorizationType="NONE",
            RouteKey=self.route_key,
            Target=f"integrations/{integration_res['IntegrationId']}",
            RouteResponseSelectionExpression="$default",
        )
        if deploy:
            self.api_gateway.deploy_api(client)
            # deploy_api() only sets deployment_id on the in-memory parent, so it
            # was previously never stored. A targeted UPDATE persists it without
            # re-running ApiGateway.save()'s full_clean().
            ApiGateway.objects.filter(pk=self.api_gateway.pk).update(
                deployment_id=self.api_gateway.deployment_id
            )
        self.deployed = True
        # deployed is now True, so this save() does not call create_route again.
        self.save()
class WebSocketSessionQuerySet(models.QuerySet):
    """QuerySet for WebSocketSession that can broadcast a message to its sessions."""

    def send_message(self, data: dict):
        """Send the same message to all connected WebSocket sessions in the queryset.

        Example::

            WebSocketSession.objects.filter(
                channel_name="Shared Channel Name",
            ).send_message(
                {
                    "msg": "this is a server sent message",
                }
            )

        One management-API client is created per (API id, stage). Sessions that
        belong to different gateways are therefore each posted to their own
        endpoint; previously every session used the first session's endpoint.

        Sessions whose ``api_gateway`` is ``None`` have no endpoint and are
        skipped. Sessions that AWS reports as gone are saved with
        ``connected=False``.

        :param dict data: JSON-serialisable payload.
        :return: list of ``post_to_connection`` responses for successful sends.
        :raises ValueError: if the encoded message exceeds ``MAX_MESSAGE_SIZE``.
        :raises botocore.exceptions.ClientError: for any AWS error other than
            ``GoneException``.
        """
        region = get_region_name()
        msg = json.dumps(data)

        # Same size guard as WebSocketSession.send_message, checked once up front.
        message_size = len(msg.encode("utf-8"))
        if message_size > MAX_MESSAGE_SIZE:
            raise ValueError(
                f"Message exceeds AWS limit: {message_size} bytes (max: {MAX_MESSAGE_SIZE} bytes)"
            )

        clients = {}
        res = []
        # select_related avoids one extra query per session for api_gateway.
        for obj in self.filter(connected=True).select_related("api_gateway"):
            gateway = obj.api_gateway
            if gateway is None:
                continue
            key = (gateway.api_id, gateway.stage_name)
            client = clients.get(key)
            if client is None:
                client = clients[key] = get_boto3_client(
                    "apigatewaymanagementapi",
                    endpoint_url=(
                        f"https://{gateway.api_id}.execute-api."
                        f"{region}.amazonaws.com/{gateway.stage_name}"
                    ),
                )
            try:
                res.append(
                    client.post_to_connection(Data=msg, ConnectionId=obj.connection_id)
                )
            except ClientError as error:
                if error.response["Error"]["Code"] != "GoneException":
                    raise
                obj.connected = False
                obj.save()
        return res
class WebSocketSession(models.Model):
    """A single API Gateway WebSocket connection, optionally tied to a user and channel."""

    class Meta:
        indexes = [
            models.Index(fields=["connection_id", "connected"]),
            models.Index(fields=["channel_name", "connection_id"]),
            models.Index(fields=["channel_name", "connected"]),
        ]

    def __str__(self) -> str:
        return self.connection_id

    def send_message(self, data: dict):
        """Send ``data`` as a JSON message to this connection.

        :param dict data: JSON-serialisable payload.
        :return: the ``post_to_connection`` response, or ``None`` when AWS
            reports the connection as gone. In that case the session is saved
            with ``connected=False``.
        :raises ValueError: if the encoded message exceeds ``MAX_MESSAGE_SIZE``,
            or if the session has no associated ``ApiGateway``.
        :raises botocore.exceptions.ClientError: for any AWS error other than
            ``GoneException``.
        """
        region = get_region_name()
        message_data = json.dumps(data)

        # Security: Validate message size before sending
        message_size = len(message_data.encode("utf-8"))
        if message_size > MAX_MESSAGE_SIZE:
            raise ValueError(
                f"Message exceeds AWS limit: {message_size} bytes (max: {MAX_MESSAGE_SIZE} bytes)"
            )

        gateway = self.api_gateway
        if gateway is None:
            # api_gateway is nullable (SET_NULL), so there may be no endpoint
            # to post to. Fail clearly instead of with an AttributeError.
            raise ValueError(
                f"WebSocketSession {self.connection_id} has no ApiGateway to send through"
            )

        client = get_boto3_client(
            "apigatewaymanagementapi",
            endpoint_url=(
                f"https://{gateway.api_id}.execute-api."
                f"{region}.amazonaws.com/{gateway.stage_name}"
            ),
        )
        try:
            return client.post_to_connection(
                Data=message_data, ConnectionId=self.connection_id
            )
        except ClientError as error:
            if error.response["Error"]["Code"] != "GoneException":
                raise
            self.connected = False
            self.save()
            return None

    objects = WebSocketSessionQuerySet.as_manager()

    connection_id = models.CharField(max_length=255, unique=True)
    channel_name = models.CharField(
        max_length=128,
        blank=True,
        default="",
        help_text="Used to group connections together",
    )
    user = models.ForeignKey(
        settings.AUTH_USER_MODEL,
        blank=True,
        null=True,
        default=None,
        on_delete=models.CASCADE,
        related_name="websocket_sessions",
    )
    connected = models.BooleanField(
        default=True, help_text="Indicates is the connection is current or not"
    )
    api_gateway = models.ForeignKey(
        ApiGateway,
        null=True,
        blank=True,
        default=None,
        on_delete=models.SET_NULL,
        related_name="sessions",
    )
    request_count = models.PositiveBigIntegerField(default=1)
    created_on = models.DateTimeField(auto_now_add=True)
    updated_on = models.DateTimeField(auto_now=True)
class WebSocketToken(models.Model):
    """Single-use token required to open a WebSocket connection (CSRF protection).

    Each token is bound to a user and a session key. It is accepted at most
    once, and only for a short time after creation (60 seconds by default in
    ``validate_and_consume``).
    """

    token = models.CharField(
        max_length=64,
        unique=True,
        validators=[RegexValidator(r"^[a-f0-9]{64}$", "Invalid token format")],
    )
    user = models.ForeignKey(
        settings.AUTH_USER_MODEL,
        on_delete=models.CASCADE,
        related_name="websocket_tokens",
    )
    session_key = models.CharField(max_length=40, db_index=True)
    created_at = models.DateTimeField(auto_now_add=True, db_index=True)
    used = models.BooleanField(default=False)

    class Meta:
        indexes = [
            models.Index(fields=["token", "used"], name="idx_websocket_tokens_used"),
        ]

    @classmethod
    def generate_token(cls, user, session_key):
        """Create and return a new token for ``user`` bound to ``session_key``.

        The value is 32 random bytes from :mod:`secrets`, hex-encoded to 64
        characters, which matches the field's validator.
        """
        return cls.objects.create(
            token=secrets.token_hex(32),
            user=user,
            session_key=session_key,
        )

    @classmethod
    def validate_and_consume(cls, token_value, session_key, max_age_seconds=60):
        """Consume a token and return its user, or ``None`` if it is not acceptable.

        A token is acceptable when it matches ``token_value`` and ``session_key``,
        is unused, and was created no more than ``max_age_seconds`` ago. It is
        marked used inside a transaction using ``select_for_update()``, so on
        backends that support row locks it can be consumed only once.
        """
        try:
            oldest_allowed = timezone.now() - timedelta(seconds=max_age_seconds)
            with transaction.atomic():
                candidate = cls.objects.select_for_update().get(
                    token=token_value,
                    session_key=session_key,
                    used=False,
                    created_at__gte=oldest_allowed,
                )
                # Flip the flag straight away so the token cannot be reused.
                candidate.used = True
                candidate.save(update_fields=["used"])
                return candidate.user
        except cls.DoesNotExist:
            return None

    @classmethod
    def cleanup_expired(cls, max_age_seconds=300):
        """Delete every token created more than ``max_age_seconds`` ago, used or not.

        Newer tokens are kept even when already used, so they still count
        towards ``check_rate_limit``. Intended to be called periodically
        (cron/celery).

        :return: the result of ``QuerySet.delete()``.
        """
        oldest_kept = timezone.now() - timedelta(seconds=max_age_seconds)
        stale_tokens = cls.objects.filter(created_at__lt=oldest_kept)
        return stale_tokens.delete()

    @classmethod
    def check_rate_limit(cls, user, max_tokens_per_minute=10):
        """Return True while ``user`` has created fewer than ``max_tokens_per_minute`` tokens in the last minute.

        Used and unused tokens are both counted.
        """
        window_start = timezone.now() - timedelta(minutes=1)
        issued_in_window = cls.objects.filter(
            user=user, created_at__gte=window_start
        ).count()
        return issued_in_window < max_tokens_per_minute
class ConnectionRateLimit(models.Model):
    """Log of connection attempts, used to rate-limit by IP address and by user."""

    ip_address = models.GenericIPAddressField(db_index=True)
    user = models.ForeignKey(
        settings.AUTH_USER_MODEL,
        on_delete=models.CASCADE,
        null=True,
        blank=True,
        related_name="connection_attempts",
    )
    attempt_time = models.DateTimeField(auto_now_add=True, db_index=True)
    successful = models.BooleanField(default=False)

    class Meta:
        indexes = [
            models.Index(
                fields=["ip_address", "attempt_time"], name="idx_ip_attempt_time"
            ),
            models.Index(fields=["user", "attempt_time"], name="idx_user_attempt_time"),
        ]

    @classmethod
    def check_rate_limit(cls, ip_address, user=None, max_attempts=20, window_minutes=5):
        """Decide whether another connection attempt is allowed.

        Attempts recorded in the last ``window_minutes`` are counted, first for
        ``ip_address`` and then, if ``user`` is authenticated, for ``user``.
        Either count reaching ``max_attempts`` blocks the attempt.

        Returns:
            (is_allowed: bool, attempts_count: int). The count is the one that
            triggered the block, or the IP count when the attempt is allowed.
        """
        window_start = timezone.now() - timedelta(minutes=window_minutes)
        in_window = cls.objects.filter(attempt_time__gte=window_start)

        per_ip = in_window.filter(ip_address=ip_address).count()
        if per_ip >= max_attempts:
            return False, per_ip

        if user and user.is_authenticated:
            per_user = in_window.filter(user=user).count()
            if per_user >= max_attempts:
                return False, per_user

        return True, per_ip

    @classmethod
    def record_attempt(cls, ip_address, user=None, successful=False):
        """Store a connection attempt; anonymous or missing users are recorded as NULL."""
        attempt_user = user if (user and user.is_authenticated) else None
        return cls.objects.create(
            ip_address=ip_address,
            user=attempt_user,
            successful=successful,
        )

    @classmethod
    def cleanup_old_records(cls, days=7):
        """Delete attempts older than ``days`` days (call via cron/celery).

        :return: the result of ``QuerySet.delete()``.
        """
        oldest_kept = timezone.now() - timedelta(days=days)
        return cls.objects.filter(attempt_time__lt=oldest_kept).delete()