resp3.py 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264
  1. from logging import getLogger
  2. from typing import Any, Union
  3. from ..exceptions import ConnectionError, InvalidResponse, ResponseError
  4. from ..typing import EncodableT
  5. from .base import _AsyncRESPBase, _RESPBase
  6. from .socket import SERVER_CLOSED_CONNECTION_ERROR
  7. class _RESP3Parser(_RESPBase):
  8. """RESP3 protocol implementation"""
  9. def __init__(self, socket_read_size):
  10. super().__init__(socket_read_size)
  11. self.push_handler_func = self.handle_push_response
  12. def handle_push_response(self, response):
  13. logger = getLogger("push_response")
  14. logger.info("Push response: " + str(response))
  15. return response
  16. def read_response(self, disable_decoding=False, push_request=False):
  17. pos = self._buffer.get_pos() if self._buffer else None
  18. try:
  19. result = self._read_response(
  20. disable_decoding=disable_decoding, push_request=push_request
  21. )
  22. except BaseException:
  23. if self._buffer:
  24. self._buffer.rewind(pos)
  25. raise
  26. else:
  27. self._buffer.purge()
  28. return result
  29. def _read_response(self, disable_decoding=False, push_request=False):
  30. raw = self._buffer.readline()
  31. if not raw:
  32. raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
  33. byte, response = raw[:1], raw[1:]
  34. # server returned an error
  35. if byte in (b"-", b"!"):
  36. if byte == b"!":
  37. response = self._buffer.read(int(response))
  38. response = response.decode("utf-8", errors="replace")
  39. error = self.parse_error(response)
  40. # if the error is a ConnectionError, raise immediately so the user
  41. # is notified
  42. if isinstance(error, ConnectionError):
  43. raise error
  44. # otherwise, we're dealing with a ResponseError that might belong
  45. # inside a pipeline response. the connection's read_response()
  46. # and/or the pipeline's execute() will raise this error if
  47. # necessary, so just return the exception instance here.
  48. return error
  49. # single value
  50. elif byte == b"+":
  51. pass
  52. # null value
  53. elif byte == b"_":
  54. return None
  55. # int and big int values
  56. elif byte in (b":", b"("):
  57. return int(response)
  58. # double value
  59. elif byte == b",":
  60. return float(response)
  61. # bool value
  62. elif byte == b"#":
  63. return response == b"t"
  64. # bulk response
  65. elif byte == b"$":
  66. response = self._buffer.read(int(response))
  67. # verbatim string response
  68. elif byte == b"=":
  69. response = self._buffer.read(int(response))[4:]
  70. # array response
  71. elif byte == b"*":
  72. response = [
  73. self._read_response(disable_decoding=disable_decoding)
  74. for _ in range(int(response))
  75. ]
  76. # set response
  77. elif byte == b"~":
  78. # redis can return unhashable types (like dict) in a set,
  79. # so we need to first convert to a list, and then try to convert it to a set
  80. response = [
  81. self._read_response(disable_decoding=disable_decoding)
  82. for _ in range(int(response))
  83. ]
  84. try:
  85. response = set(response)
  86. except TypeError:
  87. pass
  88. # map response
  89. elif byte == b"%":
  90. # We cannot use a dict-comprehension to parse stream.
  91. # Evaluation order of key:val expression in dict comprehension only
  92. # became defined to be left-right in version 3.8
  93. resp_dict = {}
  94. for _ in range(int(response)):
  95. key = self._read_response(disable_decoding=disable_decoding)
  96. resp_dict[key] = self._read_response(
  97. disable_decoding=disable_decoding, push_request=push_request
  98. )
  99. response = resp_dict
  100. # push response
  101. elif byte == b">":
  102. response = [
  103. self._read_response(
  104. disable_decoding=disable_decoding, push_request=push_request
  105. )
  106. for _ in range(int(response))
  107. ]
  108. res = self.push_handler_func(response)
  109. if not push_request:
  110. return self._read_response(
  111. disable_decoding=disable_decoding, push_request=push_request
  112. )
  113. else:
  114. return res
  115. else:
  116. raise InvalidResponse(f"Protocol Error: {raw!r}")
  117. if isinstance(response, bytes) and disable_decoding is False:
  118. response = self.encoder.decode(response)
  119. return response
  120. def set_push_handler(self, push_handler_func):
  121. self.push_handler_func = push_handler_func
  122. class _AsyncRESP3Parser(_AsyncRESPBase):
  123. def __init__(self, socket_read_size):
  124. super().__init__(socket_read_size)
  125. self.push_handler_func = self.handle_push_response
  126. def handle_push_response(self, response):
  127. logger = getLogger("push_response")
  128. logger.info("Push response: " + str(response))
  129. return response
  130. async def read_response(
  131. self, disable_decoding: bool = False, push_request: bool = False
  132. ):
  133. if self._chunks:
  134. # augment parsing buffer with previously read data
  135. self._buffer += b"".join(self._chunks)
  136. self._chunks.clear()
  137. self._pos = 0
  138. response = await self._read_response(
  139. disable_decoding=disable_decoding, push_request=push_request
  140. )
  141. # Successfully parsing a response allows us to clear our parsing buffer
  142. self._clear()
  143. return response
  144. async def _read_response(
  145. self, disable_decoding: bool = False, push_request: bool = False
  146. ) -> Union[EncodableT, ResponseError, None]:
  147. if not self._stream or not self.encoder:
  148. raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
  149. raw = await self._readline()
  150. response: Any
  151. byte, response = raw[:1], raw[1:]
  152. # if byte not in (b"-", b"+", b":", b"$", b"*"):
  153. # raise InvalidResponse(f"Protocol Error: {raw!r}")
  154. # server returned an error
  155. if byte in (b"-", b"!"):
  156. if byte == b"!":
  157. response = await self._read(int(response))
  158. response = response.decode("utf-8", errors="replace")
  159. error = self.parse_error(response)
  160. # if the error is a ConnectionError, raise immediately so the user
  161. # is notified
  162. if isinstance(error, ConnectionError):
  163. self._clear() # Successful parse
  164. raise error
  165. # otherwise, we're dealing with a ResponseError that might belong
  166. # inside a pipeline response. the connection's read_response()
  167. # and/or the pipeline's execute() will raise this error if
  168. # necessary, so just return the exception instance here.
  169. return error
  170. # single value
  171. elif byte == b"+":
  172. pass
  173. # null value
  174. elif byte == b"_":
  175. return None
  176. # int and big int values
  177. elif byte in (b":", b"("):
  178. return int(response)
  179. # double value
  180. elif byte == b",":
  181. return float(response)
  182. # bool value
  183. elif byte == b"#":
  184. return response == b"t"
  185. # bulk response
  186. elif byte == b"$":
  187. response = await self._read(int(response))
  188. # verbatim string response
  189. elif byte == b"=":
  190. response = (await self._read(int(response)))[4:]
  191. # array response
  192. elif byte == b"*":
  193. response = [
  194. (await self._read_response(disable_decoding=disable_decoding))
  195. for _ in range(int(response))
  196. ]
  197. # set response
  198. elif byte == b"~":
  199. # redis can return unhashable types (like dict) in a set,
  200. # so we need to first convert to a list, and then try to convert it to a set
  201. response = [
  202. (await self._read_response(disable_decoding=disable_decoding))
  203. for _ in range(int(response))
  204. ]
  205. try:
  206. response = set(response)
  207. except TypeError:
  208. pass
  209. # map response
  210. elif byte == b"%":
  211. # We cannot use a dict-comprehension to parse stream.
  212. # Evaluation order of key:val expression in dict comprehension only
  213. # became defined to be left-right in version 3.8
  214. resp_dict = {}
  215. for _ in range(int(response)):
  216. key = await self._read_response(disable_decoding=disable_decoding)
  217. resp_dict[key] = await self._read_response(
  218. disable_decoding=disable_decoding, push_request=push_request
  219. )
  220. response = resp_dict
  221. # push response
  222. elif byte == b">":
  223. response = [
  224. (
  225. await self._read_response(
  226. disable_decoding=disable_decoding, push_request=push_request
  227. )
  228. )
  229. for _ in range(int(response))
  230. ]
  231. res = self.push_handler_func(response)
  232. if not push_request:
  233. return await self._read_response(
  234. disable_decoding=disable_decoding, push_request=push_request
  235. )
  236. else:
  237. return res
  238. else:
  239. raise InvalidResponse(f"Protocol Error: {raw!r}")
  240. if isinstance(response, bytes) and disable_decoding is False:
  241. response = self.encoder.decode(response)
  242. return response
  243. def set_push_handler(self, push_handler_func):
  244. self.push_handler_func = push_handler_func