cookie.py 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248
  1. import binascii
  2. import json
  3. from django.conf import settings
  4. from django.contrib.messages.storage.base import BaseStorage, Message
  5. from django.core import signing
  6. from django.http import SimpleCookie
  7. from django.utils.safestring import SafeData, mark_safe
  8. class MessageEncoder(json.JSONEncoder):
  9. """
  10. Compactly serialize instances of the ``Message`` class as JSON.
  11. """
  12. message_key = "__json_message"
  13. def default(self, obj):
  14. if isinstance(obj, Message):
  15. # Using 0/1 here instead of False/True to produce more compact json
  16. is_safedata = 1 if isinstance(obj.message, SafeData) else 0
  17. message = [self.message_key, is_safedata, obj.level, obj.message]
  18. if obj.extra_tags is not None:
  19. message.append(obj.extra_tags)
  20. return message
  21. return super().default(obj)
  22. class MessageDecoder(json.JSONDecoder):
  23. """
  24. Decode JSON that includes serialized ``Message`` instances.
  25. """
  26. def process_messages(self, obj):
  27. if isinstance(obj, list) and obj:
  28. if obj[0] == MessageEncoder.message_key:
  29. if obj[1]:
  30. obj[3] = mark_safe(obj[3])
  31. return Message(*obj[2:])
  32. return [self.process_messages(item) for item in obj]
  33. if isinstance(obj, dict):
  34. return {key: self.process_messages(value) for key, value in obj.items()}
  35. return obj
  36. def decode(self, s, **kwargs):
  37. decoded = super().decode(s, **kwargs)
  38. return self.process_messages(decoded)
  39. class MessagePartSerializer:
  40. def dumps(self, obj):
  41. return [
  42. json.dumps(
  43. o,
  44. separators=(",", ":"),
  45. cls=MessageEncoder,
  46. )
  47. for o in obj
  48. ]
  49. class MessagePartGatherSerializer:
  50. def dumps(self, obj):
  51. """
  52. The parameter is an already serialized list of Message objects. No need
  53. to serialize it again, only join the list together and encode it.
  54. """
  55. return ("[" + ",".join(obj) + "]").encode("latin-1")
  56. class MessageSerializer:
  57. def loads(self, data):
  58. return json.loads(data.decode("latin-1"), cls=MessageDecoder)
  59. class CookieStorage(BaseStorage):
  60. """
  61. Store messages in a cookie.
  62. """
  63. cookie_name = "messages"
  64. # uwsgi's default configuration enforces a maximum size of 4kb for all the
  65. # HTTP headers. In order to leave some room for other cookies and headers,
  66. # restrict the session cookie to 1/2 of 4kb. See #18781.
  67. max_cookie_size = 2048
  68. not_finished = "__messagesnotfinished__"
  69. not_finished_json = json.dumps("__messagesnotfinished__")
  70. key_salt = "django.contrib.messages"
  71. def __init__(self, *args, **kwargs):
  72. super().__init__(*args, **kwargs)
  73. self.signer = signing.get_cookie_signer(salt=self.key_salt)
  74. def _get(self, *args, **kwargs):
  75. """
  76. Retrieve a list of messages from the messages cookie. If the
  77. not_finished sentinel value is found at the end of the message list,
  78. remove it and return a result indicating that not all messages were
  79. retrieved by this storage.
  80. """
  81. data = self.request.COOKIES.get(self.cookie_name)
  82. messages = self._decode(data)
  83. all_retrieved = not (messages and messages[-1] == self.not_finished)
  84. if messages and not all_retrieved:
  85. # remove the sentinel value
  86. messages.pop()
  87. return messages, all_retrieved
  88. def _update_cookie(self, encoded_data, response):
  89. """
  90. Either set the cookie with the encoded data if there is any data to
  91. store, or delete the cookie.
  92. """
  93. if encoded_data:
  94. response.set_cookie(
  95. self.cookie_name,
  96. encoded_data,
  97. domain=settings.SESSION_COOKIE_DOMAIN,
  98. secure=settings.SESSION_COOKIE_SECURE or None,
  99. httponly=settings.SESSION_COOKIE_HTTPONLY or None,
  100. samesite=settings.SESSION_COOKIE_SAMESITE,
  101. )
  102. else:
  103. response.delete_cookie(
  104. self.cookie_name,
  105. domain=settings.SESSION_COOKIE_DOMAIN,
  106. samesite=settings.SESSION_COOKIE_SAMESITE,
  107. )
  108. def _store(self, messages, response, remove_oldest=True, *args, **kwargs):
  109. """
  110. Store the messages to a cookie and return a list of any messages which
  111. could not be stored.
  112. If the encoded data is larger than ``max_cookie_size``, remove
  113. messages until the data fits (these are the messages which are
  114. returned), and add the not_finished sentinel value to indicate as much.
  115. """
  116. unstored_messages = []
  117. serialized_messages = MessagePartSerializer().dumps(messages)
  118. encoded_data = self._encode_parts(serialized_messages)
  119. if self.max_cookie_size:
  120. # data is going to be stored eventually by SimpleCookie, which
  121. # adds its own overhead, which we must account for.
  122. cookie = SimpleCookie() # create outside the loop
  123. def is_too_large_for_cookie(data):
  124. return data and len(cookie.value_encode(data)[1]) > self.max_cookie_size
  125. def compute_msg(some_serialized_msg):
  126. return self._encode_parts(
  127. some_serialized_msg + [self.not_finished_json],
  128. encode_empty=True,
  129. )
  130. if is_too_large_for_cookie(encoded_data):
  131. if remove_oldest:
  132. idx = bisect_keep_right(
  133. serialized_messages,
  134. fn=lambda m: is_too_large_for_cookie(compute_msg(m)),
  135. )
  136. unstored_messages = messages[:idx]
  137. encoded_data = compute_msg(serialized_messages[idx:])
  138. else:
  139. idx = bisect_keep_left(
  140. serialized_messages,
  141. fn=lambda m: is_too_large_for_cookie(compute_msg(m)),
  142. )
  143. unstored_messages = messages[idx:]
  144. encoded_data = compute_msg(serialized_messages[:idx])
  145. self._update_cookie(encoded_data, response)
  146. return unstored_messages
  147. def _encode_parts(self, messages, encode_empty=False):
  148. """
  149. Return an encoded version of the serialized messages list which can be
  150. stored as plain text.
  151. Since the data will be retrieved from the client-side, the encoded data
  152. also contains a hash to ensure that the data was not tampered with.
  153. """
  154. if messages or encode_empty:
  155. return self.signer.sign_object(
  156. messages, serializer=MessagePartGatherSerializer, compress=True
  157. )
  158. def _encode(self, messages, encode_empty=False):
  159. """
  160. Return an encoded version of the messages list which can be stored as
  161. plain text.
  162. Proxies MessagePartSerializer.dumps and _encoded_parts.
  163. """
  164. serialized_messages = MessagePartSerializer().dumps(messages)
  165. return self._encode_parts(serialized_messages, encode_empty=encode_empty)
  166. def _decode(self, data):
  167. """
  168. Safely decode an encoded text stream back into a list of messages.
  169. If the encoded text stream contained an invalid hash or was in an
  170. invalid format, return None.
  171. """
  172. if not data:
  173. return None
  174. try:
  175. return self.signer.unsign_object(data, serializer=MessageSerializer)
  176. except (signing.BadSignature, binascii.Error, json.JSONDecodeError):
  177. pass
  178. # Mark the data as used (so it gets removed) since something was wrong
  179. # with the data.
  180. self.used = True
  181. return None
  182. def bisect_keep_left(a, fn):
  183. """
  184. Find the index of the first element from the start of the array that
  185. verifies the given condition.
  186. The function is applied from the start of the array to the pivot.
  187. """
  188. lo = 0
  189. hi = len(a)
  190. while lo < hi:
  191. mid = (lo + hi) // 2
  192. if fn(a[: mid + 1]):
  193. hi = mid
  194. else:
  195. lo = mid + 1
  196. return lo
  197. def bisect_keep_right(a, fn):
  198. """
  199. Find the index of the first element from the end of the array that verifies
  200. the given condition.
  201. The function is applied from the pivot to the end of array.
  202. """
  203. lo = 0
  204. hi = len(a)
  205. while lo < hi:
  206. mid = (lo + hi) // 2
  207. if fn(a[mid:]):
  208. lo = mid + 1
  209. else:
  210. hi = mid
  211. return lo