middleware.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169
  1. from __future__ import annotations
  2. import re
  3. from typing import Awaitable
  4. from typing import Callable
  5. from urllib.parse import SplitResult
  6. from urllib.parse import urlsplit
  7. from asgiref.sync import iscoroutinefunction
  8. from asgiref.sync import markcoroutinefunction
  9. from django.http import HttpRequest
  10. from django.http import HttpResponse
  11. from django.http.response import HttpResponseBase
  12. from django.utils.cache import patch_vary_headers
  13. from corsheaders.conf import conf
  14. from corsheaders.signals import check_request_enabled
  15. ACCESS_CONTROL_ALLOW_ORIGIN = "access-control-allow-origin"
  16. ACCESS_CONTROL_EXPOSE_HEADERS = "access-control-expose-headers"
  17. ACCESS_CONTROL_ALLOW_CREDENTIALS = "access-control-allow-credentials"
  18. ACCESS_CONTROL_ALLOW_HEADERS = "access-control-allow-headers"
  19. ACCESS_CONTROL_ALLOW_METHODS = "access-control-allow-methods"
  20. ACCESS_CONTROL_MAX_AGE = "access-control-max-age"
  21. ACCESS_CONTROL_REQUEST_PRIVATE_NETWORK = "access-control-request-private-network"
  22. ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK = "access-control-allow-private-network"
  23. class CorsMiddleware:
  24. sync_capable = True
  25. async_capable = True
  26. def __init__(
  27. self,
  28. get_response: (
  29. Callable[[HttpRequest], HttpResponseBase]
  30. | Callable[[HttpRequest], Awaitable[HttpResponseBase]]
  31. ),
  32. ) -> None:
  33. self.get_response = get_response
  34. self.async_mode = iscoroutinefunction(self.get_response)
  35. if self.async_mode:
  36. # Mark the class as async-capable, but do the actual switch
  37. # inside __call__ to avoid swapping out dunder methods
  38. markcoroutinefunction(self)
  39. def __call__(
  40. self, request: HttpRequest
  41. ) -> HttpResponseBase | Awaitable[HttpResponseBase]:
  42. if self.async_mode:
  43. return self.__acall__(request)
  44. response: HttpResponseBase | None = self.check_preflight(request)
  45. if response is None:
  46. result = self.get_response(request)
  47. assert isinstance(result, HttpResponseBase)
  48. response = result
  49. self.add_response_headers(request, response)
  50. return response
  51. async def __acall__(self, request: HttpRequest) -> HttpResponseBase:
  52. response = self.check_preflight(request)
  53. if response is None:
  54. result = self.get_response(request)
  55. assert not isinstance(result, HttpResponseBase)
  56. response = await result
  57. self.add_response_headers(request, response)
  58. return response
  59. def check_preflight(self, request: HttpRequest) -> HttpResponseBase | None:
  60. """
  61. Generate a response for CORS preflight requests.
  62. """
  63. request._cors_enabled = self.is_enabled(request) # type: ignore [attr-defined]
  64. if (
  65. request._cors_enabled # type: ignore [attr-defined]
  66. and request.method == "OPTIONS"
  67. and "access-control-request-method" in request.headers
  68. ):
  69. return HttpResponse(headers={"content-length": "0"})
  70. return None
  71. def add_response_headers(
  72. self, request: HttpRequest, response: HttpResponseBase
  73. ) -> HttpResponseBase:
  74. """
  75. Add the respective CORS headers
  76. """
  77. enabled = getattr(request, "_cors_enabled", None)
  78. if enabled is None:
  79. enabled = self.is_enabled(request)
  80. if not enabled:
  81. return response
  82. patch_vary_headers(response, ("origin",))
  83. origin = request.headers.get("origin")
  84. if not origin:
  85. return response
  86. try:
  87. url = urlsplit(origin)
  88. except ValueError:
  89. return response
  90. if (
  91. not conf.CORS_ALLOW_ALL_ORIGINS
  92. and not self.origin_found_in_white_lists(origin, url)
  93. and not self.check_signal(request)
  94. ):
  95. return response
  96. if conf.CORS_ALLOW_ALL_ORIGINS and not conf.CORS_ALLOW_CREDENTIALS:
  97. response[ACCESS_CONTROL_ALLOW_ORIGIN] = "*"
  98. else:
  99. response[ACCESS_CONTROL_ALLOW_ORIGIN] = origin
  100. if conf.CORS_ALLOW_CREDENTIALS:
  101. response[ACCESS_CONTROL_ALLOW_CREDENTIALS] = "true"
  102. if len(conf.CORS_EXPOSE_HEADERS):
  103. response[ACCESS_CONTROL_EXPOSE_HEADERS] = ", ".join(
  104. conf.CORS_EXPOSE_HEADERS
  105. )
  106. if request.method == "OPTIONS":
  107. response[ACCESS_CONTROL_ALLOW_HEADERS] = ", ".join(conf.CORS_ALLOW_HEADERS)
  108. response[ACCESS_CONTROL_ALLOW_METHODS] = ", ".join(conf.CORS_ALLOW_METHODS)
  109. if conf.CORS_PREFLIGHT_MAX_AGE:
  110. response[ACCESS_CONTROL_MAX_AGE] = str(conf.CORS_PREFLIGHT_MAX_AGE)
  111. if (
  112. conf.CORS_ALLOW_PRIVATE_NETWORK
  113. and request.headers.get(ACCESS_CONTROL_REQUEST_PRIVATE_NETWORK) == "true"
  114. ):
  115. response[ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK] = "true"
  116. return response
  117. def origin_found_in_white_lists(self, origin: str, url: SplitResult) -> bool:
  118. return (
  119. (origin == "null" and origin in conf.CORS_ALLOWED_ORIGINS)
  120. or self._url_in_whitelist(url)
  121. or self.regex_domain_match(origin)
  122. )
  123. def regex_domain_match(self, origin: str) -> bool:
  124. return any(
  125. re.match(domain_pattern, origin)
  126. for domain_pattern in conf.CORS_ALLOWED_ORIGIN_REGEXES
  127. )
  128. def is_enabled(self, request: HttpRequest) -> bool:
  129. return bool(
  130. re.match(conf.CORS_URLS_REGEX, request.path_info)
  131. ) or self.check_signal(request)
  132. def check_signal(self, request: HttpRequest) -> bool:
  133. signal_responses = check_request_enabled.send(sender=None, request=request)
  134. return any(return_value for function, return_value in signal_responses)
  135. def _url_in_whitelist(self, url: SplitResult) -> bool:
  136. origins = [urlsplit(o) for o in conf.CORS_ALLOWED_ORIGINS]
  137. return any(
  138. origin.scheme == url.scheme and origin.netloc == url.netloc
  139. for origin in origins
  140. )