base.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225
  1. import sys
  2. from abc import ABC
  3. from asyncio import IncompleteReadError, StreamReader, TimeoutError
  4. from typing import List, Optional, Union
  5. if sys.version_info.major >= 3 and sys.version_info.minor >= 11:
  6. from asyncio import timeout as async_timeout
  7. else:
  8. from async_timeout import timeout as async_timeout
  9. from ..exceptions import (
  10. AuthenticationError,
  11. AuthenticationWrongNumberOfArgsError,
  12. BusyLoadingError,
  13. ConnectionError,
  14. ExecAbortError,
  15. ModuleError,
  16. NoPermissionError,
  17. NoScriptError,
  18. OutOfMemoryError,
  19. ReadOnlyError,
  20. RedisError,
  21. ResponseError,
  22. )
  23. from ..typing import EncodableT
  24. from .encoders import Encoder
  25. from .socket import SERVER_CLOSED_CONNECTION_ERROR, SocketBuffer
  26. MODULE_LOAD_ERROR = "Error loading the extension. " "Please check the server logs."
  27. NO_SUCH_MODULE_ERROR = "Error unloading module: no such module with that name"
  28. MODULE_UNLOAD_NOT_POSSIBLE_ERROR = "Error unloading module: operation not " "possible."
  29. MODULE_EXPORTS_DATA_TYPES_ERROR = (
  30. "Error unloading module: the module "
  31. "exports one or more module-side data "
  32. "types, can't unload"
  33. )
  34. # user send an AUTH cmd to a server without authorization configured
  35. NO_AUTH_SET_ERROR = {
  36. # Redis >= 6.0
  37. "AUTH <password> called without any password "
  38. "configured for the default user. Are you sure "
  39. "your configuration is correct?": AuthenticationError,
  40. # Redis < 6.0
  41. "Client sent AUTH, but no password is set": AuthenticationError,
  42. }
  43. class BaseParser(ABC):
  44. EXCEPTION_CLASSES = {
  45. "ERR": {
  46. "max number of clients reached": ConnectionError,
  47. "invalid password": AuthenticationError,
  48. # some Redis server versions report invalid command syntax
  49. # in lowercase
  50. "wrong number of arguments "
  51. "for 'auth' command": AuthenticationWrongNumberOfArgsError,
  52. # some Redis server versions report invalid command syntax
  53. # in uppercase
  54. "wrong number of arguments "
  55. "for 'AUTH' command": AuthenticationWrongNumberOfArgsError,
  56. MODULE_LOAD_ERROR: ModuleError,
  57. MODULE_EXPORTS_DATA_TYPES_ERROR: ModuleError,
  58. NO_SUCH_MODULE_ERROR: ModuleError,
  59. MODULE_UNLOAD_NOT_POSSIBLE_ERROR: ModuleError,
  60. **NO_AUTH_SET_ERROR,
  61. },
  62. "OOM": OutOfMemoryError,
  63. "WRONGPASS": AuthenticationError,
  64. "EXECABORT": ExecAbortError,
  65. "LOADING": BusyLoadingError,
  66. "NOSCRIPT": NoScriptError,
  67. "READONLY": ReadOnlyError,
  68. "NOAUTH": AuthenticationError,
  69. "NOPERM": NoPermissionError,
  70. }
  71. @classmethod
  72. def parse_error(cls, response):
  73. "Parse an error response"
  74. error_code = response.split(" ")[0]
  75. if error_code in cls.EXCEPTION_CLASSES:
  76. response = response[len(error_code) + 1 :]
  77. exception_class = cls.EXCEPTION_CLASSES[error_code]
  78. if isinstance(exception_class, dict):
  79. exception_class = exception_class.get(response, ResponseError)
  80. return exception_class(response)
  81. return ResponseError(response)
  82. def on_disconnect(self):
  83. raise NotImplementedError()
  84. def on_connect(self, connection):
  85. raise NotImplementedError()
  86. class _RESPBase(BaseParser):
  87. """Base class for sync-based resp parsing"""
  88. def __init__(self, socket_read_size):
  89. self.socket_read_size = socket_read_size
  90. self.encoder = None
  91. self._sock = None
  92. self._buffer = None
  93. def __del__(self):
  94. try:
  95. self.on_disconnect()
  96. except Exception:
  97. pass
  98. def on_connect(self, connection):
  99. "Called when the socket connects"
  100. self._sock = connection._sock
  101. self._buffer = SocketBuffer(
  102. self._sock, self.socket_read_size, connection.socket_timeout
  103. )
  104. self.encoder = connection.encoder
  105. def on_disconnect(self):
  106. "Called when the socket disconnects"
  107. self._sock = None
  108. if self._buffer is not None:
  109. self._buffer.close()
  110. self._buffer = None
  111. self.encoder = None
  112. def can_read(self, timeout):
  113. return self._buffer and self._buffer.can_read(timeout)
  114. class AsyncBaseParser(BaseParser):
  115. """Base parsing class for the python-backed async parser"""
  116. __slots__ = "_stream", "_read_size"
  117. def __init__(self, socket_read_size: int):
  118. self._stream: Optional[StreamReader] = None
  119. self._read_size = socket_read_size
  120. async def can_read_destructive(self) -> bool:
  121. raise NotImplementedError()
  122. async def read_response(
  123. self, disable_decoding: bool = False
  124. ) -> Union[EncodableT, ResponseError, None, List[EncodableT]]:
  125. raise NotImplementedError()
  126. class _AsyncRESPBase(AsyncBaseParser):
  127. """Base class for async resp parsing"""
  128. __slots__ = AsyncBaseParser.__slots__ + ("encoder", "_buffer", "_pos", "_chunks")
  129. def __init__(self, socket_read_size: int):
  130. super().__init__(socket_read_size)
  131. self.encoder: Optional[Encoder] = None
  132. self._buffer = b""
  133. self._chunks = []
  134. self._pos = 0
  135. def _clear(self):
  136. self._buffer = b""
  137. self._chunks.clear()
  138. def on_connect(self, connection):
  139. """Called when the stream connects"""
  140. self._stream = connection._reader
  141. if self._stream is None:
  142. raise RedisError("Buffer is closed.")
  143. self.encoder = connection.encoder
  144. self._clear()
  145. self._connected = True
  146. def on_disconnect(self):
  147. """Called when the stream disconnects"""
  148. self._connected = False
  149. async def can_read_destructive(self) -> bool:
  150. if not self._connected:
  151. raise RedisError("Buffer is closed.")
  152. if self._buffer:
  153. return True
  154. try:
  155. async with async_timeout(0):
  156. return self._stream.at_eof()
  157. except TimeoutError:
  158. return False
  159. async def _read(self, length: int) -> bytes:
  160. """
  161. Read `length` bytes of data. These are assumed to be followed
  162. by a '\r\n' terminator which is subsequently discarded.
  163. """
  164. want = length + 2
  165. end = self._pos + want
  166. if len(self._buffer) >= end:
  167. result = self._buffer[self._pos : end - 2]
  168. else:
  169. tail = self._buffer[self._pos :]
  170. try:
  171. data = await self._stream.readexactly(want - len(tail))
  172. except IncompleteReadError as error:
  173. raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from error
  174. result = (tail + data)[:-2]
  175. self._chunks.append(data)
  176. self._pos += want
  177. return result
  178. async def _readline(self) -> bytes:
  179. """
  180. read an unknown number of bytes up to the next '\r\n'
  181. line separator, which is discarded.
  182. """
  183. found = self._buffer.find(b"\r\n", self._pos)
  184. if found >= 0:
  185. result = self._buffer[self._pos : found]
  186. else:
  187. tail = self._buffer[self._pos :]
  188. data = await self._stream.readline()
  189. if not data.endswith(b"\r\n"):
  190. raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
  191. result = (tail + data)[:-2]
  192. self._chunks.append(data)
  193. self._pos += len(result) + 2
  194. return result