| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264 |
- from logging import getLogger
- from typing import Any, Union
- from ..exceptions import ConnectionError, InvalidResponse, ResponseError
- from ..typing import EncodableT
- from .base import _AsyncRESPBase, _RESPBase
- from .socket import SERVER_CLOSED_CONNECTION_ERROR
- class _RESP3Parser(_RESPBase):
- """RESP3 protocol implementation"""
- def __init__(self, socket_read_size):
- super().__init__(socket_read_size)
- self.push_handler_func = self.handle_push_response
- def handle_push_response(self, response):
- logger = getLogger("push_response")
- logger.info("Push response: " + str(response))
- return response
- def read_response(self, disable_decoding=False, push_request=False):
- pos = self._buffer.get_pos() if self._buffer else None
- try:
- result = self._read_response(
- disable_decoding=disable_decoding, push_request=push_request
- )
- except BaseException:
- if self._buffer:
- self._buffer.rewind(pos)
- raise
- else:
- self._buffer.purge()
- return result
- def _read_response(self, disable_decoding=False, push_request=False):
- raw = self._buffer.readline()
- if not raw:
- raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
- byte, response = raw[:1], raw[1:]
- # server returned an error
- if byte in (b"-", b"!"):
- if byte == b"!":
- response = self._buffer.read(int(response))
- response = response.decode("utf-8", errors="replace")
- error = self.parse_error(response)
- # if the error is a ConnectionError, raise immediately so the user
- # is notified
- if isinstance(error, ConnectionError):
- raise error
- # otherwise, we're dealing with a ResponseError that might belong
- # inside a pipeline response. the connection's read_response()
- # and/or the pipeline's execute() will raise this error if
- # necessary, so just return the exception instance here.
- return error
- # single value
- elif byte == b"+":
- pass
- # null value
- elif byte == b"_":
- return None
- # int and big int values
- elif byte in (b":", b"("):
- return int(response)
- # double value
- elif byte == b",":
- return float(response)
- # bool value
- elif byte == b"#":
- return response == b"t"
- # bulk response
- elif byte == b"$":
- response = self._buffer.read(int(response))
- # verbatim string response
- elif byte == b"=":
- response = self._buffer.read(int(response))[4:]
- # array response
- elif byte == b"*":
- response = [
- self._read_response(disable_decoding=disable_decoding)
- for _ in range(int(response))
- ]
- # set response
- elif byte == b"~":
- # redis can return unhashable types (like dict) in a set,
- # so we need to first convert to a list, and then try to convert it to a set
- response = [
- self._read_response(disable_decoding=disable_decoding)
- for _ in range(int(response))
- ]
- try:
- response = set(response)
- except TypeError:
- pass
- # map response
- elif byte == b"%":
- # We cannot use a dict-comprehension to parse stream.
- # Evaluation order of key:val expression in dict comprehension only
- # became defined to be left-right in version 3.8
- resp_dict = {}
- for _ in range(int(response)):
- key = self._read_response(disable_decoding=disable_decoding)
- resp_dict[key] = self._read_response(
- disable_decoding=disable_decoding, push_request=push_request
- )
- response = resp_dict
- # push response
- elif byte == b">":
- response = [
- self._read_response(
- disable_decoding=disable_decoding, push_request=push_request
- )
- for _ in range(int(response))
- ]
- res = self.push_handler_func(response)
- if not push_request:
- return self._read_response(
- disable_decoding=disable_decoding, push_request=push_request
- )
- else:
- return res
- else:
- raise InvalidResponse(f"Protocol Error: {raw!r}")
- if isinstance(response, bytes) and disable_decoding is False:
- response = self.encoder.decode(response)
- return response
- def set_push_handler(self, push_handler_func):
- self.push_handler_func = push_handler_func
- class _AsyncRESP3Parser(_AsyncRESPBase):
- def __init__(self, socket_read_size):
- super().__init__(socket_read_size)
- self.push_handler_func = self.handle_push_response
- def handle_push_response(self, response):
- logger = getLogger("push_response")
- logger.info("Push response: " + str(response))
- return response
- async def read_response(
- self, disable_decoding: bool = False, push_request: bool = False
- ):
- if self._chunks:
- # augment parsing buffer with previously read data
- self._buffer += b"".join(self._chunks)
- self._chunks.clear()
- self._pos = 0
- response = await self._read_response(
- disable_decoding=disable_decoding, push_request=push_request
- )
- # Successfully parsing a response allows us to clear our parsing buffer
- self._clear()
- return response
- async def _read_response(
- self, disable_decoding: bool = False, push_request: bool = False
- ) -> Union[EncodableT, ResponseError, None]:
- if not self._stream or not self.encoder:
- raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
- raw = await self._readline()
- response: Any
- byte, response = raw[:1], raw[1:]
- # if byte not in (b"-", b"+", b":", b"$", b"*"):
- # raise InvalidResponse(f"Protocol Error: {raw!r}")
- # server returned an error
- if byte in (b"-", b"!"):
- if byte == b"!":
- response = await self._read(int(response))
- response = response.decode("utf-8", errors="replace")
- error = self.parse_error(response)
- # if the error is a ConnectionError, raise immediately so the user
- # is notified
- if isinstance(error, ConnectionError):
- self._clear() # Successful parse
- raise error
- # otherwise, we're dealing with a ResponseError that might belong
- # inside a pipeline response. the connection's read_response()
- # and/or the pipeline's execute() will raise this error if
- # necessary, so just return the exception instance here.
- return error
- # single value
- elif byte == b"+":
- pass
- # null value
- elif byte == b"_":
- return None
- # int and big int values
- elif byte in (b":", b"("):
- return int(response)
- # double value
- elif byte == b",":
- return float(response)
- # bool value
- elif byte == b"#":
- return response == b"t"
- # bulk response
- elif byte == b"$":
- response = await self._read(int(response))
- # verbatim string response
- elif byte == b"=":
- response = (await self._read(int(response)))[4:]
- # array response
- elif byte == b"*":
- response = [
- (await self._read_response(disable_decoding=disable_decoding))
- for _ in range(int(response))
- ]
- # set response
- elif byte == b"~":
- # redis can return unhashable types (like dict) in a set,
- # so we need to first convert to a list, and then try to convert it to a set
- response = [
- (await self._read_response(disable_decoding=disable_decoding))
- for _ in range(int(response))
- ]
- try:
- response = set(response)
- except TypeError:
- pass
- # map response
- elif byte == b"%":
- # We cannot use a dict-comprehension to parse stream.
- # Evaluation order of key:val expression in dict comprehension only
- # became defined to be left-right in version 3.8
- resp_dict = {}
- for _ in range(int(response)):
- key = await self._read_response(disable_decoding=disable_decoding)
- resp_dict[key] = await self._read_response(
- disable_decoding=disable_decoding, push_request=push_request
- )
- response = resp_dict
- # push response
- elif byte == b">":
- response = [
- (
- await self._read_response(
- disable_decoding=disable_decoding, push_request=push_request
- )
- )
- for _ in range(int(response))
- ]
- res = self.push_handler_func(response)
- if not push_request:
- return await self._read_response(
- disable_decoding=disable_decoding, push_request=push_request
- )
- else:
- return res
- else:
- raise InvalidResponse(f"Protocol Error: {raw!r}")
- if isinstance(response, bytes) and disable_decoding is False:
- response = self.encoder.decode(response)
- return response
- def set_push_handler(self, push_handler_func):
- self.push_handler_func = push_handler_func
|