import json
from typing import Union
from django.conf import settings
from django.http import JsonResponse, HttpResponseBadRequest
from django.utils.decorators import method_decorator
from django.views.decorators.csrf import csrf_exempt
from django.views.generic import View
from django_aws_api_gateway_websockets.models import WebSocketSession
@method_decorator(csrf_exempt, name="dispatch")
class WebSocketView(View):
    """Base view for handling client messages forwarded by AWS API Gateway.

    ``dispatch`` validates the request headers and then routes the request:

    * URL kwarg ``slug == "connect"``    -> :meth:`connect`
    * URL kwarg ``slug == "disconnect"`` -> :meth:`disconnect`
    * anything else -> the method named by ``body[route_selection_key]``,
      falling back to :meth:`default`.

    Subclasses add one public method per route and may override the error
    handlers (``missing_headers``, ``invalid_useragent``,
    ``route_selection_key_missing``) and ``_additional_connection_checks``.
    """

    # Key within the JSON body whose value names the handler method.
    route_selection_key = "action"
    model = None
    # Class-level fallback only; ``setup`` rebinds ``self.body`` per request.
    body = {}
    aws_api_gateway_id = None  # Set to None to allow all

    # Headers that must be present on every request.
    required_headers = [
        "Host",
        "X-Real-Ip",
        "X-Forwarded-For",
        "X-Forwarded-Proto",
        "Connection",
        "Content-Length",
        "X-Forwarded-Port",
        "X-Amzn-Trace-Id",
        "Connectionid",
        "User-Agent",
        "X-Amzn-Apigateway-Api-Id",
    ]
    # Extra headers that must be present on the connect route.
    required_connection_headers = [
        "Cookie",
        "Origin",
        "Sec-Websocket-Extensions",
        "Sec-Websocket-Key",
        "Sec-Websocket-Version",
    ]
    expected_useragent_prefix = "AmazonAPIGateway_"

    def setup(self, request, *args, **kwargs):
        """Parse the JSON request body into ``self.body`` for easy access.

        An empty body yields ``{}``. A body that is not valid JSON also
        yields ``{}`` so that ``dispatch`` can answer with an HTTP 400 via
        its normal error handlers instead of failing with an unhandled
        exception.
        """
        super().setup(request, *args, **kwargs)
        try:
            self.body = json.loads(request.body) if request.body else {}
        except ValueError:
            # json.JSONDecodeError and UnicodeDecodeError both derive from
            # ValueError.
            self.body = {}

    def _return_bad_request(self, msg):
        """Build the HTTP 400 response used by every error path.

        Kept as a single method so subclasses have one place to hook in
        logging; no logging is performed here.
        """
        return HttpResponseBadRequest(msg)

    def route_selection_key_missing(
        self, request, *args, **kwargs
    ) -> HttpResponseBadRequest:
        """Handle a message whose body lacks ``route_selection_key``."""
        msg = f"route_select_key {self.route_selection_key} missing from request body."
        return self._return_bad_request(msg)

    def invalid_useragent(self, request, *args, **kwargs) -> HttpResponseBadRequest:
        """Handle a request whose User-Agent is not the expected one."""
        msg = (
            f"Unexpected Useragent; Expected {self.expected_useragent_prefix}{self.aws_api_gateway_id}, "
            f"Received {request.headers['User-Agent']}"
        )
        return self._return_bad_request(msg)

    def missing_headers(self, request, *args, **kwargs) -> HttpResponseBadRequest:
        """Handle a request that lacks one or more of ``required_headers``."""
        absent = [h for h in self.required_headers if h not in request.headers]
        msg = f"Missing headers; Expected {self.required_headers}, Missing {absent}"
        return self._return_bad_request(msg)

    def _expected_headers(self, request, *args, **kwargs) -> bool:
        """Return True when every header in ``required_headers`` is present."""
        return all(h in request.headers for h in self.required_headers)

    def _expected_apigateway_id(self, request, *args, **kwargs) -> bool:
        """Return True when the request comes from the expected API Gateway.

        When ``aws_api_gateway_id`` is unset every gateway is accepted;
        otherwise the ``X-Amzn-Apigateway-Api-Id`` header must equal it.

        NOTE(review): nothing in this class calls this helper, so the
        gateway id is only enforced indirectly through the User-Agent check.
        """
        if not self.aws_api_gateway_id:
            return True
        # Compare by value: identity (`is`) between two strings is unreliable.
        return request.headers["X-Amzn-Apigateway-Api-Id"] == self.aws_api_gateway_id

    def _expected_useragent(self, request, *args, **kwargs) -> bool:
        """Return True when the User-Agent is the expected one.

        With ``aws_api_gateway_id`` set the User-Agent must equal
        ``expected_useragent_prefix + aws_api_gateway_id``; otherwise it only
        has to contain ``expected_useragent_prefix``. ``dispatch`` applies
        this to every route except connect.
        """
        user_agent = request.headers["User-Agent"]
        if self.aws_api_gateway_id:
            expected = f"{self.expected_useragent_prefix}{self.aws_api_gateway_id}"
            # Compare by value: an identity test against a freshly built
            # string can never match.
            return user_agent == expected
        return self.expected_useragent_prefix in user_agent

    @staticmethod
    def _check_allowed_hosts(request) -> bool:
        """Return True when the Host header is allowed by ``ALLOWED_HOSTS``.

        An empty ``ALLOWED_HOSTS`` allows every host.

        NOTE(review): this is an exact membership test, so wildcard entries
        such as ``"*"`` or ``".example.com"`` do not match here.
        """
        allowed = settings.ALLOWED_HOSTS
        return not allowed or request.headers["Host"] in allowed

    @staticmethod
    def _check_host_is_in_origin(request) -> bool:
        """Return True when the Host header value occurs within the Origin header.

        Origin also carries the protocol, hence a containment test.

        NOTE(review): being a substring test, an Origin that merely contains
        the host text (e.g. as part of a longer domain) also passes.
        """
        return request.headers["Host"] in request.headers["Origin"]

    def _expected_connection_headers(self, request, *args, **kwargs) -> bool:
        """Return True when every header in ``required_connection_headers`` is present."""
        return all(h in request.headers for h in self.required_connection_headers)

    def _add_user_to_request(self, request):
        """Attach the session's user to ``request`` and count the request.

        Looks up the ``WebSocketSession`` by the ``Connectionid`` header,
        sets ``request.user`` from it, increments ``request_count`` and
        saves the session. The lookup raises if no session matches.
        """
        wss = WebSocketSession.objects.get(
            connection_id=request.headers["Connectionid"]
        )
        request.user = wss.user
        wss.request_count += 1
        wss.save()

    def _resolve_route_handler(self, action):
        """Map a client-supplied action name to a handler method.

        The name comes from the untrusted request body, so it is only
        honoured when it is a string naming a public, callable attribute
        that is not part of Django's base ``View`` API (which would expose
        ``dispatch``, ``setup``, ``as_view`` ...). Anything else falls back
        to :meth:`default`.
        """
        if (
            not isinstance(action, str)
            or action.startswith("_")
            or hasattr(View, action)
        ):
            return self.default
        handler = getattr(self, action, self.default)
        # Reject plain data attributes and classes (e.g. ``body``, ``model``).
        if not callable(handler) or isinstance(handler, type):
            return self.default
        return handler

    def _select_handler(self, request, *args, **kwargs):
        """Choose the handler for a request that passed the header checks."""
        # .get() so a URL pattern without a "slug" kwarg is routed by the
        # message body instead of raising KeyError.
        route = self.kwargs.get("slug")
        if route == "connect":
            return self.connect

        is_disconnect = route == "disconnect"
        if not is_disconnect:
            has_route_key = (
                isinstance(self.body, dict) and self.route_selection_key in self.body
            )
            if not has_route_key:
                return self.route_selection_key_missing

        # Every route other than connect must carry the gateway's User-Agent.
        if not self._expected_useragent(request, *args, **kwargs):
            return self.invalid_useragent

        self._add_user_to_request(request)
        if is_disconnect:
            return self.disconnect
        return self._resolve_route_handler(self.body[self.route_selection_key])

    def dispatch(self, request, *args, **kwargs):
        """Determine the correct method to call and call it.

        * Required headers missing -> :meth:`missing_headers`.
        * HTTP method not in ``http_method_names`` -> the normal
          ``http_method_not_allowed`` handler.
        * Otherwise route by the ``slug`` URL kwarg (connect / disconnect)
          or by ``body[route_selection_key]``; an unknown route name defers
          to :meth:`default`, a missing key to
          :meth:`route_selection_key_missing`.
        """
        if not self._expected_headers(request):
            handler = self.missing_headers
        elif request.method.lower() not in self.http_method_names:
            handler = self.http_method_not_allowed
        else:
            handler = self._select_handler(request, *args, **kwargs)
        return handler(request, *args, **kwargs)

    def connect(
        self, request, *args, **kwargs
    ) -> Union[JsonResponse, HttpResponseBadRequest]:
        """Handle the connection route and persist the User to Connectionid mapping.

        Runs the connection-specific header, allowed-host, host-in-origin and
        ``_additional_connection_checks`` validations in that order; the
        first failure returns an HTTP 400. On success a ``WebSocketSession``
        is created and an empty JSON response returned.
        """
        if not self._expected_connection_headers(request, *args, **kwargs):
            msg = f"Missing headers; Expected {self.required_connection_headers}, Received {request.headers}"
            return self._return_bad_request(msg)
        if not self._check_allowed_hosts(request):
            msg = f"Host {request.headers['Host']} not in AllowedHosts {settings.ALLOWED_HOSTS}"
            return self._return_bad_request(msg)
        if not self._check_host_is_in_origin(request):
            # Report the Origin header here (the message used to repeat Host).
            msg = f"Host {request.headers['Host']} not in Origin {request.headers['Origin']}"
            return self._return_bad_request(msg)
        res, msg = self._additional_connection_checks(request, *args, **kwargs)
        if not res:
            return self._return_bad_request(msg)
        WebSocketSession.objects.create(
            connection_id=request.headers["Connectionid"],
            # NOTE(review): "chennel" looks like a typo for "channel"; left
            # unchanged because the model's field name is not visible here.
            chennel=request.GET.get("channel", ""),
            user=request.user,
        )
        return JsonResponse({})

    def _additional_connection_checks(self, request, *args, **kwargs) -> (bool, str):
        """Extension hook for extra connect-time checks.

        Could add additional steps for certificates, API Gateway Authorizers
        etc. Returns ``(ok, message)``; when ``ok`` is false ``connect``
        returns an HTTP 400 carrying ``message``.
        """
        return True, ""

    def disconnect(self, request, *args, **kwargs) -> JsonResponse:
        """Mark the session identified by the ``Connectionid`` header as disconnected."""
        wss = WebSocketSession.objects.get(
            connection_id=request.headers["Connectionid"]
        )
        wss.connected = False
        wss.save()
        return JsonResponse({})

    def default(self, request, *args, **kwargs) -> JsonResponse:
        """Fallback message handler; override it in a subclass to provide one."""
        raise NotImplementedError("This logic needs to be defined within the subclass")