resp2.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. from typing import Any, Union
  2. from ..exceptions import ConnectionError, InvalidResponse, ResponseError
  3. from ..typing import EncodableT
  4. from .base import _AsyncRESPBase, _RESPBase
  5. from .socket import SERVER_CLOSED_CONNECTION_ERROR
  6. class _RESP2Parser(_RESPBase):
  7. """RESP2 protocol implementation"""
  8. def read_response(self, disable_decoding=False):
  9. pos = self._buffer.get_pos() if self._buffer else None
  10. try:
  11. result = self._read_response(disable_decoding=disable_decoding)
  12. except BaseException:
  13. if self._buffer:
  14. self._buffer.rewind(pos)
  15. raise
  16. else:
  17. self._buffer.purge()
  18. return result
  19. def _read_response(self, disable_decoding=False):
  20. raw = self._buffer.readline()
  21. if not raw:
  22. raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
  23. byte, response = raw[:1], raw[1:]
  24. # server returned an error
  25. if byte == b"-":
  26. response = response.decode("utf-8", errors="replace")
  27. error = self.parse_error(response)
  28. # if the error is a ConnectionError, raise immediately so the user
  29. # is notified
  30. if isinstance(error, ConnectionError):
  31. raise error
  32. # otherwise, we're dealing with a ResponseError that might belong
  33. # inside a pipeline response. the connection's read_response()
  34. # and/or the pipeline's execute() will raise this error if
  35. # necessary, so just return the exception instance here.
  36. return error
  37. # single value
  38. elif byte == b"+":
  39. pass
  40. # int value
  41. elif byte == b":":
  42. return int(response)
  43. # bulk response
  44. elif byte == b"$" and response == b"-1":
  45. return None
  46. elif byte == b"$":
  47. response = self._buffer.read(int(response))
  48. # multi-bulk response
  49. elif byte == b"*" and response == b"-1":
  50. return None
  51. elif byte == b"*":
  52. response = [
  53. self._read_response(disable_decoding=disable_decoding)
  54. for i in range(int(response))
  55. ]
  56. else:
  57. raise InvalidResponse(f"Protocol Error: {raw!r}")
  58. if disable_decoding is False:
  59. response = self.encoder.decode(response)
  60. return response
  61. class _AsyncRESP2Parser(_AsyncRESPBase):
  62. """Async class for the RESP2 protocol"""
  63. async def read_response(self, disable_decoding: bool = False):
  64. if not self._connected:
  65. raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
  66. if self._chunks:
  67. # augment parsing buffer with previously read data
  68. self._buffer += b"".join(self._chunks)
  69. self._chunks.clear()
  70. self._pos = 0
  71. response = await self._read_response(disable_decoding=disable_decoding)
  72. # Successfully parsing a response allows us to clear our parsing buffer
  73. self._clear()
  74. return response
  75. async def _read_response(
  76. self, disable_decoding: bool = False
  77. ) -> Union[EncodableT, ResponseError, None]:
  78. raw = await self._readline()
  79. response: Any
  80. byte, response = raw[:1], raw[1:]
  81. # server returned an error
  82. if byte == b"-":
  83. response = response.decode("utf-8", errors="replace")
  84. error = self.parse_error(response)
  85. # if the error is a ConnectionError, raise immediately so the user
  86. # is notified
  87. if isinstance(error, ConnectionError):
  88. self._clear() # Successful parse
  89. raise error
  90. # otherwise, we're dealing with a ResponseError that might belong
  91. # inside a pipeline response. the connection's read_response()
  92. # and/or the pipeline's execute() will raise this error if
  93. # necessary, so just return the exception instance here.
  94. return error
  95. # single value
  96. elif byte == b"+":
  97. pass
  98. # int value
  99. elif byte == b":":
  100. return int(response)
  101. # bulk response
  102. elif byte == b"$" and response == b"-1":
  103. return None
  104. elif byte == b"$":
  105. response = await self._read(int(response))
  106. # multi-bulk response
  107. elif byte == b"*" and response == b"-1":
  108. return None
  109. elif byte == b"*":
  110. response = [
  111. (await self._read_response(disable_decoding))
  112. for _ in range(int(response)) # noqa
  113. ]
  114. else:
  115. raise InvalidResponse(f"Protocol Error: {raw!r}")
  116. if disable_decoding is False:
  117. response = self.encoder.decode(response)
  118. return response