Source code for django_aws_api_gateway_websockets.views

import ipaddress
import json
import re
import warnings
from html import escape
from typing import Union
from urllib.parse import urlparse

from django.conf import settings
from django.db import transaction
from django.db.models import F
from django.http import HttpResponseBadRequest, HttpResponseForbidden, JsonResponse
from django.http.request import split_domain_port, validate_host
from django.utils.decorators import method_decorator
from django.views.decorators.csrf import csrf_exempt
from django.views.generic import View

from django_aws_api_gateway_websockets.models import (
    ApiGateway,
    ConnectionRateLimit,
    WebSocketSession,
    WebSocketToken,
)


class WebSocketTokenView(View):
    """Issue one-time tokens that authorise a WebSocket connection.

    This view is not marked ``csrf_exempt``, so Django's CSRF middleware (when
    enabled) applies to it.

    Flow:
        1. The client POSTs here with its session cookie and CSRF token.
        2. The response carries a single-use token plus ``expires_in`` (60 seconds).
        3. The client passes the token on the WebSocket URL:
           ``ws://...?ws_token=<token>``.
    """

    def post(self, request, *args, **kwargs):
        """Return ``{"token": <token>, "expires_in": 60}`` for the current user.

        Responds with 403 "Authentication required" for anonymous users and with
        403 "Rate limit exceeded" when ``WebSocketToken.check_rate_limit`` refuses
        (called with ``max_tokens_per_minute=20``).
        """
        user = request.user
        if not user.is_authenticated:
            return HttpResponseForbidden("Authentication required")

        session = request.session
        if not session.session_key:
            # The token is bound to a session key, so make sure one exists.
            session.create()

        # Security: Rate limit token generation
        within_limit = WebSocketToken.check_rate_limit(user, max_tokens_per_minute=20)
        if not within_limit:
            return HttpResponseForbidden("Rate limit exceeded")

        issued = WebSocketToken.generate_token(
            user=user, session_key=session.session_key
        )
        payload = {"token": issued.token, "expires_in": 60}  # seconds
        return JsonResponse(payload)
@method_decorator(csrf_exempt, name="dispatch")
class WebSocketView(View):
    """The base WebSocket View for handling messages sent from the client via AWS API Gateway.

    Expects a URL slug parameter called ``route``. The ``connect`` and
    ``disconnect`` routes map to :meth:`connect` / :meth:`disconnect`. Every other
    request is routed by the handler name found in the JSON body under
    ``handler_selection_key`` (or the deprecated ``route_selection_key``),
    restricted to the names in ``ALLOWED_HANDLERS``.
    """

    debug = False
    debug_log = None
    MAX_BODY_SIZE = 1024 * 128  # 128KB
    MAX_CHANNEL_NAME_LENGTH = 191
    CHANNEL_NAME_PATTERN = re.compile(r"^[a-zA-Z0-9_-]+$")
    # Handler names clients may invoke. When left empty it is filled per instance in
    # __init__ from "default", ADDITIONAL_ALLOWED_HANDLERS and the public methods the
    # concrete subclass defines itself.
    ALLOWED_HANDLERS = set()
    ADDITIONAL_ALLOWED_HANDLERS = list()
    USE_WS_TOKEN = True
    RATE_LIMIT_ENABLED = True
    RATE_LIMIT_MAX_ATTEMPTS = 20  # Max connection attempts
    RATE_LIMIT_WINDOW_MINUTES = 1  # Within this time window
    # todo - The following key is deprecated and should be removed in a future release for handler_selection_key
    route_selection_key = "action"
    handler_selection_key = "handler"
    model = None
    body = {}
    aws_api_gateway_id = None  # Set to None to allow all
    api_gateway = None
    websocket_session = None
    required_headers = [
        "Host",
        "X-Forwarded-For",
        "X-Forwarded-Proto",
        "Content-Length",
        "Connectionid",
        "User-Agent",
        "X-Amzn-Apigateway-Api-Id",
    ]
    additional_required_headers = [
        "X-Amzn-Trace-Id",
        "X-Forwarded-Port",
        "X-Real-Ip",
    ]
    required_connection_headers = [
        "Cookie",
        "Origin",
        "Sec-Websocket-Extensions",
        "Sec-Websocket-Key",
        "Sec-Websocket-Version",
    ]
    expected_useragent_prefix = "AmazonAPIGateway_"
    # User must have ANY of these permissions to invoke the method. Default = Allow All
    permissions_required = []
    # User must have ALL these permissions to invoke the method. Default = Allow All
    all_permissions_required = []

    def __init__(self, **kwargs):
        self.api_gateway = kwargs.get("api_gateway", False)
        self.websocket_session = kwargs.get("websocket_session", False)
        self.debug = kwargs.get("debug", False)
        self.debug_log = []
        if not self.ALLOWED_HANDLERS:
            # Unpacking (rather than list concatenation) also accepts a tuple or set
            # for ADDITIONAL_ALLOWED_HANDLERS.
            self.ALLOWED_HANDLERS = {
                "default",
                *self.ADDITIONAL_ALLOWED_HANDLERS,
                *self._get_current_class_methods(),
            }
        super().__init__(**kwargs)

    def _debug(self, msg: str):
        """Append ``msg`` to ``debug_log`` when debugging is enabled."""
        if self.debug:
            # Security: Sanitize debug messages to prevent XSS if logs are ever displayed
            sanitized_msg = escape(str(msg)[:500])  # Limit length and escape HTML
            self.debug_log.append(sanitized_msg)

    def setup(self, request, *args, **kwargs):
        """Parse the JSON request body into ``self.body``.

        ``self.body`` is always a dict afterwards. An oversized body, invalid JSON
        or UTF-8, or a JSON value that is not an object all yield ``{}``, so the
        routing in :meth:`dispatch` (which uses ``in`` and ``.get``) cannot crash on
        untrusted input.
        """
        super().setup(request, *args, **kwargs)
        self._debug("Within setup")
        if request.body and len(request.body) > self.MAX_BODY_SIZE:
            self.body = {}
            self._debug("Request body too large")
            return  # Early exit
        try:
            body = json.loads(request.body) if request.body else {}
        except (ValueError, RecursionError):
            # JSONDecodeError and UnicodeDecodeError are both ValueError subclasses;
            # RecursionError is raised for pathologically nested documents.
            self._debug("Body is invalid JSON")
            body = {}
        if not isinstance(body, dict):
            # e.g. a JSON list or string: `"handler" in body` would be True for
            # ["handler"] and the following body.get() would then raise.
            self._debug("Body is not a JSON object")
            body = {}
        self.body = body
        self._debug("Setup completed")

    def _get_current_class_methods(self):
        """Return the public callables defined directly on the concrete class (not inherited)."""
        current_class = self.__class__
        return [
            name
            for name in dir(current_class)
            if not name.startswith("_")
            and name in current_class.__dict__
            and callable(getattr(current_class, name))
        ]

    def _return_bad_request(self, msg: str):
        """Common method for logging and returning the HTTP400 response"""
        return HttpResponseBadRequest(msg)

    def route_selection_key_missing(
        self, request, *args, **kwargs
    ) -> HttpResponseBadRequest:
        """Method for handling missing route_selection_key (deprecated)."""
        warnings.warn(
            "Deprecated: Use handler_selection_key_missing instead. Will be removed in version 3",
            DeprecationWarning,
            stacklevel=2,
        )
        msg = f"route_select_key {self.route_selection_key} missing from request body."
        self._debug(msg)
        return self._return_bad_request(msg)

    def handler_selection_key_missing(
        self, request, *args, **kwargs
    ) -> HttpResponseBadRequest:
        """Method for handling missing handler_selection_key"""
        msg = f"handler_selection_key {self.handler_selection_key} missing from request body."
        self._debug(msg)
        # Security: Generic error message to prevent information disclosure
        return self._return_bad_request("Invalid request format (1)")

    def missing_headers(self, request, *args, **kwargs) -> HttpResponseBadRequest:
        """Method for handling missing headers"""
        if self.debug:
            msg = f"Expected {self.required_headers}, Received {request.headers.keys()}"
            self._debug(msg)
        return self._return_bad_request("Invalid request headers")

    def permission_denied(self, request, *args, **kwargs) -> HttpResponseForbidden:
        """Method for handling denied access (HTTP 403)."""
        return HttpResponseForbidden("Permission Denied")

    def invalid_useragent(self, request, *args, **kwargs) -> HttpResponseBadRequest:
        """Method for handling unexpected useragents"""
        msg = (
            f"Unexpected Useragent; Expected {self.expected_useragent_prefix}{self.aws_api_gateway_id}, "
            f"Received {request.headers['User-Agent']}"
        )
        self._debug(msg)
        return self._return_bad_request("Unexpected Useragent")

    def _expected_headers(self, request, *args, **kwargs) -> bool:
        """Ensure that all required headers exist within the request header"""
        request_headers = request.headers.keys()
        res = all(h in request_headers for h in self.required_headers)
        if res and self.additional_required_headers:
            res = all(h in request_headers for h in self.additional_required_headers)
        self._debug(f"_expected_headers() returned {res}")
        return res

    def _allowed_apigateway(self, request, *args, **kwargs) -> bool:
        """Ensure the AWS API Gateway is the expected one (if set against the class) otherwise use the DB check.

        The DB lookup always runs because it also populates ``self.api_gateway``.
        """
        res = self._check_platform_registered_api_gateways(request)
        if self.aws_api_gateway_id:
            res = self.aws_api_gateway_id == request.headers["X-Amzn-Apigateway-Api-Id"]
            self._debug(f"_expected_apigateway_id() returned {res}")
        self._debug(f"_allowed_apigateway() returned {res}")
        return res

    def _check_platform_registered_api_gateways(self, request) -> bool:
        """Checks to ensure that the API Gateway calling the view is one that the user has registered"""
        self.api_gateway = ApiGateway.objects.filter(
            api_id=request.headers["X-Amzn-Apigateway-Api-Id"]
        ).first()
        self._debug(
            f"_check_platform_registered_api_gateways() returned {bool(self.api_gateway)}"
        )
        return bool(self.api_gateway)

    def _expected_useragent(self, request, *args, **kwargs) -> bool:
        """Validate the User-Agent for the non-connect routes.

        When ``aws_api_gateway_id`` is set the User-Agent must equal
        ``<expected_useragent_prefix><aws_api_gateway_id>`` (the value named in
        :meth:`invalid_useragent`). Otherwise it must contain
        ``expected_useragent_prefix``. :meth:`dispatch` does not apply this check
        to the connect route.
        """
        useragent = request.headers["User-Agent"]
        if self.aws_api_gateway_id:
            # Must be a positive match; negating it would accept every other agent.
            res = (
                useragent
                == f"{self.expected_useragent_prefix}{self.aws_api_gateway_id}"
            )
        else:
            res = self.expected_useragent_prefix in useragent
        self._debug(f"_expected_useragent() returned {res}")
        return res

    @staticmethod
    def _check_allowed_hosts(request) -> bool:
        """Check that the Host header is permitted by ``settings.ALLOWED_HOSTS``.

        An empty ALLOWED_HOSTS allows every host. Otherwise the host is accepted if
        it is listed verbatim, or if its domain part matches an entry under
        Django's own ALLOWED_HOSTS rules (``"*"`` and leading-dot subdomain
        wildcards), which a plain list-membership test would reject.
        """
        allowed_hosts = settings.ALLOWED_HOSTS
        if not allowed_hosts:
            return True
        host = request.headers["Host"]
        if host in allowed_hosts:
            return True
        domain, _port = split_domain_port(host)
        return bool(domain) and validate_host(domain, allowed_hosts)

    @staticmethod
    def _check_host_is_in_origin(request) -> bool:
        """Check that the Origin header refers to the same host as the Host header.

        Origin carries a scheme (e.g. ``https://example.com``), so it is parsed
        instead of substring-matched. A plain ``in`` test would accept look-alike
        origins such as ``https://example.com.attacker.net`` for Host
        ``example.com``. The Host may equal the origin's ``host[:port]`` or its bare
        hostname.
        """
        host = request.headers["Host"].strip().lower()
        try:
            origin = urlparse(request.headers["Origin"].strip())
        except ValueError:  # e.g. an unbalanced IPv6 bracket
            return False
        if not host or not origin.netloc:
            return False
        candidates = {origin.netloc.lower()}
        if origin.hostname:  # urlparse already lower-cases hostname
            candidates.add(origin.hostname)
        return host in candidates

    def _expected_connection_headers(self, request, *args, **kwargs) -> bool:
        """Run additional checks for the connection route for security"""
        request_headers = request.headers.keys()
        return all(h in request_headers for h in self.required_connection_headers)

    def _validate_connection_id(self, connection_id: str) -> bool:
        """Return True if ``connection_id`` is a non-empty, <=128 char token of ``[A-Za-z0-9_=-]``."""
        # AWS connection IDs are alphanumeric with specific allowed chars, max 128 chars
        # Security: Strict validation to prevent injection
        if not connection_id or len(connection_id) > 128:
            return False
        # fullmatch rather than match + "$": "$" also matches before a trailing newline.
        return re.fullmatch(r"[A-Za-z0-9_=-]+", connection_id) is not None

    def _is_valid_channel_name(self, channel_name) -> bool:
        """Return True if ``channel_name`` is a str within MAX_CHANNEL_NAME_LENGTH matching CHANNEL_NAME_PATTERN."""
        return (
            isinstance(channel_name, str)
            and len(channel_name) <= self.MAX_CHANNEL_NAME_LENGTH
            # fullmatch so a trailing newline cannot slip past a "$" anchor
            and self.CHANNEL_NAME_PATTERN.fullmatch(channel_name) is not None
        )

    def _add_user_to_request(self, request):
        """Fetch the user from the WebSocketSession and attach it to ``request``.

        Also increments the session's ``request_count``. An invalid or unknown
        Connectionid leaves ``request`` untouched.
        """
        connection_id = request.headers["Connectionid"]
        if self._validate_connection_id(connection_id):
            try:
                with transaction.atomic():
                    wss = WebSocketSession.objects.select_for_update().get(
                        connection_id=connection_id
                    )
                    if wss.user:
                        request.user = wss.user
                    wss.request_count = F("request_count") + 1
                    wss.save(update_fields=["request_count"])
            except WebSocketSession.DoesNotExist:
                self._debug(f"Session not found: {connection_id}")
                # Handle gracefully

    def _get_channel_name(self, request) -> str:
        """Returns the name of the channel to use.

        The "channel" can optionally be set as a querystring parameter during the
        connection. If it is not, but the selected api_gateway has a
        default_channel_name set, that will be used instead. Otherwise an empty
        string is returned.

        Raises:
            ValueError: if the querystring channel name is too long or does not
                match CHANNEL_NAME_PATTERN.
        """
        channel_name = request.GET.get("channel", "")
        # Security: Validate channel name length and format
        if channel_name and not self._is_valid_channel_name(channel_name):
            self._debug(f"Invalid channel name: {channel_name[:50]}")
            raise ValueError("Invalid channel name")
        if (
            not channel_name
            and self.api_gateway
            and self.api_gateway.default_channel_name
        ):
            channel_name = self.api_gateway.default_channel_name
        return channel_name

    def _load_session(self, request):
        """Load the WebSocketSession for this request's Connectionid into ``self.websocket_session``."""
        self.websocket_session = WebSocketSession.objects.get(
            connection_id=request.headers["Connectionid"]
        )

    def dispatch(self, request, *args, **kwargs):
        """Determine the correct method to call.

        If the body contains a key identified by self.handler_selection_key and that
        value is an allowed method on the view, that method will be called. ELSE IF
        the body contains a key identified by self.route_selection_key and that value
        is an allowed method on the view, that method will be called. ELSE the
        default method will be called.

        Checks for the expected headers and a permitted API Gateway first. If
        neither selection key is present, defer to handler_selection_key_missing.
        If the request method isn't on the approved list, defer to the normal error
        handler. A handler returning a falsy value yields ``JsonResponse({})``.
        """
        if self._expected_headers(request) and self._allowed_apigateway(request):
            if request.method.lower() in self.http_method_names:
                route = self.kwargs.get("route")
                if "connect" == route:
                    handler = self.connect
                elif "disconnect" == route:
                    if not self._expected_useragent(request, *args, **kwargs):
                        handler = self.invalid_useragent
                    else:
                        handler = self.disconnect
                        self._add_user_to_request(request)
                elif (
                    self.handler_selection_key in self.body
                    or self.route_selection_key in self.body
                ):
                    self._load_session(request)
                    handler = self.default

                    # Security: Only allow handlers explicitly in ALLOWED_HANDLERS
                    if self.body.get(self.handler_selection_key):
                        handler_name = str(
                            self.body[self.handler_selection_key]
                        )  # Force string type
                        if (
                            handler_name in self.ALLOWED_HANDLERS
                            and not handler_name.startswith("_")
                        ):
                            handler = getattr(self, handler_name, self.default)

                    # Fallback to deprecated route_selection_key if handler not found
                    if handler == self.default and self.body.get(
                        self.route_selection_key
                    ):
                        handler_name = str(
                            self.body[self.route_selection_key]
                        )  # Force string type
                        if (
                            handler_name in self.ALLOWED_HANDLERS
                            and not handler_name.startswith("_")
                        ):
                            handler = getattr(self, handler_name, self.default)

                    if not self._expected_useragent(request, *args, **kwargs):
                        handler = self.invalid_useragent
                    else:
                        self._add_user_to_request(request)
                        # Use Django Permissions to restrict the method
                        if self.all_permissions_required:
                            if not self.has_all_permission(request):
                                handler = self.permission_denied
                        elif self.permissions_required:
                            if not self.has_any_permission(request):
                                handler = self.permission_denied
                else:
                    handler = self.handler_selection_key_missing
            else:
                handler = self.http_method_not_allowed
        else:
            handler = self.missing_headers
        res = handler(request, *args, **kwargs)
        return res if res else JsonResponse({})

    def connect(
        self, request, *args, **kwargs
    ) -> Union[JsonResponse, HttpResponseBadRequest]:
        """Handle the connection route in a standard way that ensures the User to Connectionid mapping persists.

        Checks run in order: the connection rate limit (if enabled), required
        connection headers, ALLOWED_HOSTS, Host/Origin agreement, the one-time
        ``ws_token`` (if USE_WS_TOKEN), :meth:`_additional_connection_checks`, and
        the channel name. Any failure returns HTTP 400; success creates a
        WebSocketSession and returns ``JsonResponse({})``.
        """
        # Security: Rate limiting for connection attempts
        ip_address = self._get_client_ip(request)
        user = (
            request.user
            if hasattr(request, "user") and request.user.is_authenticated
            else None
        )

        if self.RATE_LIMIT_ENABLED:
            is_allowed, attempt_count = ConnectionRateLimit.check_rate_limit(
                ip_address=ip_address,
                user=user,
                max_attempts=self.RATE_LIMIT_MAX_ATTEMPTS,
                window_minutes=self.RATE_LIMIT_WINDOW_MINUTES,
            )
            if not is_allowed:
                self._debug(f"Rate limit exceeded: {attempt_count} attempts")
                ConnectionRateLimit.record_attempt(ip_address, user, successful=False)
                return self._return_bad_request("Too many connection attempts")

        if not self._expected_connection_headers(request, *args, **kwargs):
            request_headers = request.headers.keys()
            missing_headers = [
                h for h in self.required_connection_headers if h not in request_headers
            ]
            # Log header names only: the values include the session Cookie.
            self._debug(
                f"Missing headers; Expected {self.required_connection_headers}, "
                f"Received {list(request_headers)}"
            )
            msg = f"Missing {len(missing_headers)} headers"
            return self._return_bad_request(msg)

        if not self._check_allowed_hosts(request):
            self._debug(f"Host is not in AllowedHosts {settings.ALLOWED_HOSTS}")
            # Security: don't echo the ALLOWED_HOSTS configuration back to the client
            return self._return_bad_request("Host is not in AllowedHosts")

        if not self._check_host_is_in_origin(request):
            self._debug("Host is not in Origin")
            return self._return_bad_request("Host is not in Origin")

        if self.USE_WS_TOKEN:
            # Security: Validate one-time WebSocket token (CSRF protection)
            ws_token = request.GET.get("ws_token", "")
            if not request.session.session_key:
                self._debug("No session key present")
                return self._return_bad_request("Invalid session")
            validated_user = WebSocketToken.validate_and_consume(
                token_value=ws_token,
                session_key=request.session.session_key,
                max_age_seconds=60,
            )
            if not validated_user:
                self._debug("WebSocket token validation failed")
                return self._return_bad_request("Invalid or expired token")
            # Override request.user with validated user from token
            request.user = validated_user

        res, msg = self._additional_connection_checks(request, *args, **kwargs)
        if not res:
            if self.RATE_LIMIT_ENABLED:
                ConnectionRateLimit.record_attempt(ip_address, user, successful=False)
            return self._return_bad_request(msg)

        # Resolved after the token check so overrides may depend on request.user.
        try:
            channel_name = self._get_channel_name(request)
        except ValueError:
            # Previously this escaped as an unhandled exception (HTTP 500).
            if self.RATE_LIMIT_ENABLED:
                ConnectionRateLimit.record_attempt(ip_address, user, successful=False)
            return self._return_bad_request("Invalid channel name")

        WebSocketSession.objects.create(
            connection_id=request.headers["Connectionid"],
            channel_name=channel_name,
            user=request.user if request.user.is_authenticated else None,
            api_gateway=self.api_gateway,
        )

        # Record successful connection
        if self.RATE_LIMIT_ENABLED:
            ConnectionRateLimit.record_attempt(ip_address, user, successful=True)

        return JsonResponse({})

    def _get_client_ip(self, request) -> str:
        """Extract the client IP address from request headers (``"0.0.0.0"`` if invalid)."""
        # Security: Get real IP from X-Forwarded-For (set by API Gateway)
        # NOTE(review): the first X-Forwarded-For entry may be client-supplied;
        # confirm what API Gateway forwards before relying on it for rate limiting.
        x_forwarded_for = request.headers.get("X-Forwarded-For", "")
        if x_forwarded_for:
            # Take the first IP in the chain (client IP)
            ip = x_forwarded_for.split(",")[0].strip()
        else:
            ip = request.META.get("REMOTE_ADDR", "0.0.0.0")

        # Validate IP format
        try:
            ipaddress.ip_address(ip)
            return ip
        except ValueError:
            self._debug(f"Invalid IP address: {ip[:50]}")
            return "0.0.0.0"

    def _additional_connection_checks(self, request, *args, **kwargs) -> (bool, str):
        """Hook for extra connection checks (certificates, APIGateway Authorizers etc).

        Return ``(True, "")`` to allow, or ``(False, message)`` to reject with HTTP 400.
        """
        return True, ""

    def disconnect(self, request, *args, **kwargs):
        """Using the Connectionid, mark the websocket session as disconnected.

        An unknown Connectionid is logged and ignored (dispatch then returns
        ``JsonResponse({})``) instead of raising ``DoesNotExist``.
        """
        try:
            wss = WebSocketSession.objects.get(
                connection_id=request.headers["Connectionid"]
            )
        except WebSocketSession.DoesNotExist:
            self._debug("Session not found on disconnect")
            return None
        wss.connected = False
        wss.save()

    def default(self, request, *args, **kwargs) -> JsonResponse:
        """Overload this method if you want to have a default message handler"""
        raise NotImplementedError("This logic needs to be defined within the subclass")

    def has_any_permission(self, request) -> bool:
        """Test if the user has ANY of the required permissions (True when none are required)."""
        if not self.permissions_required:
            return True
        return any(
            request.user.has_perms([permission])
            for permission in self.permissions_required
        )

    def has_all_permission(self, request) -> bool:
        """Test if the user has ALL of the required permissions (True when none are required)."""
        if not self.all_permissions_required:
            return True
        return all(
            request.user.has_perms([permission])
            for permission in self.all_permissions_required
        )

    def change_channel(self, request, *args, **kwargs):
        """Update the channel the current websocket session is connected to.

        The new name is read from ``body["channel"]`` and validated with the same
        length and pattern rules as the connect-time ``channel`` parameter. An empty
        string is still accepted; a missing, non-string or invalid value returns
        HTTP 400 instead of being stored.
        """
        channel_name = self.body.get("channel")
        if channel_name != "" and not self._is_valid_channel_name(channel_name):
            self._debug(f"Invalid channel name: {str(channel_name)[:50]}")
            return self._return_bad_request("Invalid channel name")
        self.websocket_session.channel_name = channel_name
        self.websocket_session.save()