| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169 |
- from __future__ import annotations
- import re
- from typing import Awaitable
- from typing import Callable
- from urllib.parse import SplitResult
- from urllib.parse import urlsplit
- from asgiref.sync import iscoroutinefunction
- from asgiref.sync import markcoroutinefunction
- from django.http import HttpRequest
- from django.http import HttpResponse
- from django.http.response import HttpResponseBase
- from django.utils.cache import patch_vary_headers
- from corsheaders.conf import conf
- from corsheaders.signals import check_request_enabled
- ACCESS_CONTROL_ALLOW_ORIGIN = "access-control-allow-origin"
- ACCESS_CONTROL_EXPOSE_HEADERS = "access-control-expose-headers"
- ACCESS_CONTROL_ALLOW_CREDENTIALS = "access-control-allow-credentials"
- ACCESS_CONTROL_ALLOW_HEADERS = "access-control-allow-headers"
- ACCESS_CONTROL_ALLOW_METHODS = "access-control-allow-methods"
- ACCESS_CONTROL_MAX_AGE = "access-control-max-age"
- ACCESS_CONTROL_REQUEST_PRIVATE_NETWORK = "access-control-request-private-network"
- ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK = "access-control-allow-private-network"
- class CorsMiddleware:
- sync_capable = True
- async_capable = True
- def __init__(
- self,
- get_response: (
- Callable[[HttpRequest], HttpResponseBase]
- | Callable[[HttpRequest], Awaitable[HttpResponseBase]]
- ),
- ) -> None:
- self.get_response = get_response
- self.async_mode = iscoroutinefunction(self.get_response)
- if self.async_mode:
- # Mark the class as async-capable, but do the actual switch
- # inside __call__ to avoid swapping out dunder methods
- markcoroutinefunction(self)
- def __call__(
- self, request: HttpRequest
- ) -> HttpResponseBase | Awaitable[HttpResponseBase]:
- if self.async_mode:
- return self.__acall__(request)
- response: HttpResponseBase | None = self.check_preflight(request)
- if response is None:
- result = self.get_response(request)
- assert isinstance(result, HttpResponseBase)
- response = result
- self.add_response_headers(request, response)
- return response
- async def __acall__(self, request: HttpRequest) -> HttpResponseBase:
- response = self.check_preflight(request)
- if response is None:
- result = self.get_response(request)
- assert not isinstance(result, HttpResponseBase)
- response = await result
- self.add_response_headers(request, response)
- return response
- def check_preflight(self, request: HttpRequest) -> HttpResponseBase | None:
- """
- Generate a response for CORS preflight requests.
- """
- request._cors_enabled = self.is_enabled(request) # type: ignore [attr-defined]
- if (
- request._cors_enabled # type: ignore [attr-defined]
- and request.method == "OPTIONS"
- and "access-control-request-method" in request.headers
- ):
- return HttpResponse(headers={"content-length": "0"})
- return None
- def add_response_headers(
- self, request: HttpRequest, response: HttpResponseBase
- ) -> HttpResponseBase:
- """
- Add the respective CORS headers
- """
- enabled = getattr(request, "_cors_enabled", None)
- if enabled is None:
- enabled = self.is_enabled(request)
- if not enabled:
- return response
- patch_vary_headers(response, ("origin",))
- origin = request.headers.get("origin")
- if not origin:
- return response
- try:
- url = urlsplit(origin)
- except ValueError:
- return response
- if (
- not conf.CORS_ALLOW_ALL_ORIGINS
- and not self.origin_found_in_white_lists(origin, url)
- and not self.check_signal(request)
- ):
- return response
- if conf.CORS_ALLOW_ALL_ORIGINS and not conf.CORS_ALLOW_CREDENTIALS:
- response[ACCESS_CONTROL_ALLOW_ORIGIN] = "*"
- else:
- response[ACCESS_CONTROL_ALLOW_ORIGIN] = origin
- if conf.CORS_ALLOW_CREDENTIALS:
- response[ACCESS_CONTROL_ALLOW_CREDENTIALS] = "true"
- if len(conf.CORS_EXPOSE_HEADERS):
- response[ACCESS_CONTROL_EXPOSE_HEADERS] = ", ".join(
- conf.CORS_EXPOSE_HEADERS
- )
- if request.method == "OPTIONS":
- response[ACCESS_CONTROL_ALLOW_HEADERS] = ", ".join(conf.CORS_ALLOW_HEADERS)
- response[ACCESS_CONTROL_ALLOW_METHODS] = ", ".join(conf.CORS_ALLOW_METHODS)
- if conf.CORS_PREFLIGHT_MAX_AGE:
- response[ACCESS_CONTROL_MAX_AGE] = str(conf.CORS_PREFLIGHT_MAX_AGE)
- if (
- conf.CORS_ALLOW_PRIVATE_NETWORK
- and request.headers.get(ACCESS_CONTROL_REQUEST_PRIVATE_NETWORK) == "true"
- ):
- response[ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK] = "true"
- return response
- def origin_found_in_white_lists(self, origin: str, url: SplitResult) -> bool:
- return (
- (origin == "null" and origin in conf.CORS_ALLOWED_ORIGINS)
- or self._url_in_whitelist(url)
- or self.regex_domain_match(origin)
- )
- def regex_domain_match(self, origin: str) -> bool:
- return any(
- re.match(domain_pattern, origin)
- for domain_pattern in conf.CORS_ALLOWED_ORIGIN_REGEXES
- )
- def is_enabled(self, request: HttpRequest) -> bool:
- return bool(
- re.match(conf.CORS_URLS_REGEX, request.path_info)
- ) or self.check_signal(request)
- def check_signal(self, request: HttpRequest) -> bool:
- signal_responses = check_request_enabled.send(sender=None, request=request)
- return any(return_value for function, return_value in signal_responses)
- def _url_in_whitelist(self, url: SplitResult) -> bool:
- origins = [urlsplit(o) for o in conf.CORS_ALLOWED_ORIGINS]
- return any(
- origin.scheme == url.scheme and origin.netloc == url.netloc
- for origin in origins
- )
|