| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778 |
- import random
- import re
- import socket
- from collections import OrderedDict
- from datetime import datetime
- from typing import Any, Dict, Iterator, List, Optional, Union
- from django.conf import settings
- from django.core.cache.backends.base import DEFAULT_TIMEOUT, BaseCache, get_key_func
- from django.core.exceptions import ImproperlyConfigured
- from django.utils.module_loading import import_string
- from redis import Redis
- from redis.exceptions import ConnectionError, ResponseError, TimeoutError
- from .. import pool
- from ..exceptions import CompressorError, ConnectionInterrupted
- from ..util import CacheKey
# Errors that DefaultClient methods catch around Redis calls and re-raise as
# ConnectionInterrupted. TimeoutError and ConnectionError here are the
# redis.exceptions classes (imported from redis.exceptions), which shadow the
# builtins of the same name; socket.timeout covers raw socket timeouts.
_main_exceptions = (
    TimeoutError,
    ResponseError,
    ConnectionError,
    socket.timeout,
)

# Captures a single glob metacharacter (``*``, ``?`` or ``[``) so that
# glob_escape() can neutralise it.
special_re = re.compile(r"([*?[])")
def glob_escape(s: str) -> str:
    """
    Escape glob metacharacters in ``s``.

    Every ``*``, ``?`` and ``[`` matched by ``special_re`` is wrapped in a
    one-character bracket class (``*`` becomes ``[*]``, ``[`` becomes
    ``[[]``) so it matches itself literally in a glob pattern. All other
    characters are returned unchanged.
    """

    def _bracket(match):
        # The whole match is the single metacharacter.
        return "[" + match.group(0) + "]"

    return special_re.sub(_bracket, s)
class DefaultClient:
    """
    Default client: maps Django cache operations onto redis-py calls.

    ``server`` may name several Redis servers. Index 0 is used for writes,
    and indices 1..n-1 are picked at random for reads (see
    ``get_next_client_index``). Raw clients are created lazily, one per
    server, and cached in ``self._clients``.
    """

    def __init__(self, server, params: Dict[str, Any], backend: BaseCache) -> None:
        """
        :param server: connection string(s). Either a comma-separated string
            or a list/tuple/set of strings.
        :param params: cache settings. Reads ``REVERSE_KEY_FUNCTION`` and
            ``OPTIONS`` (``REPLICA_READ_ONLY``, ``SERIALIZER``,
            ``COMPRESSOR``; the whole dict is also handed to the serializer,
            the compressor and the connection factory).
        :param backend: the owning Django cache backend. Provides
            ``key_prefix``, ``version``, ``key_func`` and
            ``default_timeout``.
        :raises ImproperlyConfigured: if ``server`` is empty.
        """
        self._backend = backend
        self._server = server
        self._params = params

        self.reverse_key = get_key_func(
            params.get("REVERSE_KEY_FUNCTION")
            or "django_redis.util.default_reverse_key"
        )

        if not self._server:
            raise ImproperlyConfigured("Missing connections string")

        if isinstance(self._server, (list, tuple, set)):
            # Materialize as a list: connect() indexes self._server by
            # position, which a set does not support.
            self._server = list(self._server)
        else:
            self._server = self._server.split(",")

        # One lazily-created raw client per server; populated by get_client().
        self._clients: List[Optional[Redis]] = [None] * len(self._server)
        self._options = params.get("OPTIONS", {})
        # When False, set() may fail over to another server on connection
        # errors instead of raising immediately.
        self._replica_read_only = self._options.get("REPLICA_READ_ONLY", True)

        serializer_path = self._options.get(
            "SERIALIZER", "django_redis.serializers.pickle.PickleSerializer"
        )
        serializer_cls = import_string(serializer_path)

        compressor_path = self._options.get(
            "COMPRESSOR", "django_redis.compressors.identity.IdentityCompressor"
        )
        compressor_cls = import_string(compressor_path)

        self._serializer = serializer_cls(options=self._options)
        self._compressor = compressor_cls(options=self._options)

        self.connection_factory = pool.get_connection_factory(options=self._options)

    def __contains__(self, key: Any) -> bool:
        """Support ``key in client`` via has_key()."""
        return self.has_key(key)

    def get_next_client_index(
        self, write: bool = True, tried: Optional[List[int]] = None
    ) -> int:
        """
        Return the index of the server to use for the next operation.

        Default policy for a replication setup:

        * while ``tried`` is non-empty and some servers remain untried, pick
          one of the untried indices at random;
        * otherwise return 0 for writes, or when only one server exists;
        * otherwise return a random replica index in ``1..n-1``.

        Override this method for a different selection policy.
        """
        if tried is None:
            tried = []

        if tried and len(tried) < len(self._server):
            not_tried = [i for i in range(len(self._server)) if i not in tried]
            return random.choice(not_tried)

        if write or len(self._server) == 1:
            return 0

        return random.randint(1, len(self._server) - 1)

    def get_client(
        self,
        write: bool = True,
        tried: Optional[List[int]] = None,
        show_index: bool = False,
    ):
        """
        Return a raw redis client, creating and caching it on first use.

        Almost every cache operation goes through this method to obtain its
        native client. The server is chosen by get_next_client_index().
        When ``show_index`` is True, a ``(client, index)`` tuple is returned
        instead of the bare client.
        """
        index = self.get_next_client_index(write=write, tried=tried)

        if self._clients[index] is None:
            self._clients[index] = self.connect(index)

        if show_index:
            return self._clients[index], index
        return self._clients[index]

    def connect(self, index: int = 0) -> Redis:
        """
        Create a new raw redis client for the connection string at ``index``.

        The index selects the server in a replication setup. With a single
        server it is always 0.
        """
        return self.connection_factory.connect(self._server[index])

    def disconnect(self, index=0, client=None):
        """
        Ask the connection factory to disconnect a client.

        Uses ``client`` when given, otherwise the cached client at ``index``.
        Does nothing (returns None) when there is no client to disconnect.
        """
        if not client:
            client = self._clients[index]
        return self.connection_factory.disconnect(client) if client else None

    def set(
        self,
        key: Any,
        value: Any,
        timeout: Optional[float] = DEFAULT_TIMEOUT,
        version: Optional[int] = None,
        client: Optional[Redis] = None,
        nx: bool = False,
        xx: bool = False,
    ) -> bool:
        """
        Persist a value to the cache with an optional expiration time.

        :param timeout: seconds until expiry. ``DEFAULT_TIMEOUT`` means the
            backend's default; ``None`` means no expiry. A timeout that
            rounds to <= 0 ms does not write anything. With ``nx`` it returns
            whether the key is absent; otherwise it deletes the key and
            returns whether something was deleted.
        :param nx: only set the key if it does not already exist (SET NX).
        :param xx: only set the key if it already exists (SET XX).
        :returns: whether the write happened.
        :raises ConnectionInterrupted: on connection-level errors, after
            failover (when enabled) has been exhausted.
        """
        nkey = self.make_key(key, version=version)
        nvalue = self.encode(value)

        if timeout is DEFAULT_TIMEOUT:
            timeout = self._backend.default_timeout

        # Convert to milliseconds exactly once. Doing this inside the retry
        # loop below would rescale the timeout by 1000 on every failover.
        if timeout is not None:
            timeout = int(timeout * 1000)

        original_client = client
        tried: List[int] = []
        while True:
            try:
                if client is None:
                    client, index = self.get_client(
                        write=True, tried=tried, show_index=True
                    )

                if timeout is not None and timeout <= 0:
                    if nx:
                        # With nx, a non-positive timeout should not expire
                        # (here: delete) an existing value; expiring a
                        # non-existent value would be a no-op anyway.
                        return not self.has_key(key, version=version, client=client)
                    # Redis rejects non-positive expirations in SET, so delete
                    # the key instead of setting it and then expiring it.
                    return bool(self.delete(key, client=client, version=version))

                return bool(client.set(nkey, nvalue, nx=nx, px=timeout, xx=xx))
            except _main_exceptions as e:
                # Fail over to another server only when we picked the client
                # ourselves, replicas are writable, and servers remain.
                if (
                    not original_client
                    and not self._replica_read_only
                    and len(tried) < len(self._server)
                ):
                    tried.append(index)
                    client = None
                    continue
                raise ConnectionInterrupted(connection=client) from e

    def incr_version(
        self,
        key: Any,
        delta: int = 1,
        version: Optional[int] = None,
        client: Optional[Redis] = None,
    ) -> int:
        """
        Add ``delta`` to the cache version of ``key`` and return the new
        version.

        The value and TTL are read under the old version, written under the
        new one, and the old key is deleted. These are separate commands, so
        the move is not atomic.

        :raises ValueError: if the key does not exist under the old version.
        """
        if client is None:
            client = self.get_client(write=True)

        if version is None:
            version = self._backend.version

        old_key = self.make_key(key, version)
        value = self.get(old_key, version=version, client=client)

        try:
            ttl = self.ttl(old_key, version=version, client=client)
        except _main_exceptions as e:
            raise ConnectionInterrupted(connection=client) from e

        if value is None:
            raise ValueError("Key '%s' not found" % key)

        if isinstance(key, CacheKey):
            new_key = self.make_key(key.original_key(), version=version + delta)
        else:
            new_key = self.make_key(key, version=version + delta)

        self.set(new_key, value, timeout=ttl, client=client)
        self.delete(old_key, client=client)
        return version + delta

    def add(
        self,
        key: Any,
        value: Any,
        timeout: Any = DEFAULT_TIMEOUT,
        version: Optional[Any] = None,
        client: Optional[Redis] = None,
    ) -> bool:
        """
        Add a value to the cache, failing if the key already exists.

        Returns ``True`` if the object was added, ``False`` if not.
        """
        return self.set(key, value, timeout, version=version, client=client, nx=True)

    def get(
        self,
        key: Any,
        default=None,
        version: Optional[int] = None,
        client: Optional[Redis] = None,
    ) -> Any:
        """
        Retrieve a value from the cache.

        Returns the decoded value if the key is found, else ``default``.
        Reads use a read (possibly replica) client unless ``client`` is
        given.
        """
        if client is None:
            client = self.get_client(write=False)

        key = self.make_key(key, version=version)

        try:
            value = client.get(key)
        except _main_exceptions as e:
            raise ConnectionInterrupted(connection=client) from e

        if value is None:
            return default

        return self.decode(value)

    def persist(
        self, key: Any, version: Optional[int] = None, client: Optional[Redis] = None
    ) -> bool:
        """Remove any expiry from ``key`` (Redis PERSIST)."""
        if client is None:
            client = self.get_client(write=True)

        key = self.make_key(key, version=version)
        return client.persist(key)

    def expire(
        self,
        key: Any,
        timeout,
        version: Optional[int] = None,
        client: Optional[Redis] = None,
    ) -> bool:
        """Set an expiry of ``timeout`` on ``key`` (Redis EXPIRE)."""
        if client is None:
            client = self.get_client(write=True)

        key = self.make_key(key, version=version)
        return client.expire(key, timeout)

    def pexpire(self, key, timeout, version=None, client=None) -> bool:
        """Set an expiry of ``timeout`` on ``key`` (Redis PEXPIRE)."""
        if client is None:
            client = self.get_client(write=True)

        key = self.make_key(key, version=version)

        # Temporary casting until https://github.com/redis/redis-py/issues/1664
        # is fixed.
        return bool(client.pexpire(key, timeout))

    def pexpire_at(
        self,
        key: Any,
        when: Union[datetime, int],
        version: Optional[int] = None,
        client: Optional[Redis] = None,
    ) -> bool:
        """
        Set an expire flag on ``key`` to ``when`` (Redis PEXPIREAT).

        ``when`` can be an integer unix time or a Python datetime object.
        """
        if client is None:
            client = self.get_client(write=True)

        key = self.make_key(key, version=version)
        return bool(client.pexpireat(key, when))

    def expire_at(
        self,
        key: Any,
        when: Union[datetime, int],
        version: Optional[int] = None,
        client: Optional[Redis] = None,
    ) -> bool:
        """
        Set an expire flag on ``key`` to ``when`` (Redis EXPIREAT).

        ``when`` can be an integer unix time or a Python datetime object.
        """
        if client is None:
            client = self.get_client(write=True)

        key = self.make_key(key, version=version)
        return client.expireat(key, when)

    def lock(
        self,
        key,
        version: Optional[int] = None,
        timeout=None,
        sleep=0.1,
        blocking_timeout=None,
        client: Optional[Redis] = None,
        thread_local=True,
    ):
        """
        Return the lock object created by ``client.lock`` for the cache key.

        All lock parameters are passed straight through to redis-py.
        """
        if client is None:
            client = self.get_client(write=True)

        key = self.make_key(key, version=version)
        return client.lock(
            key,
            timeout=timeout,
            sleep=sleep,
            blocking_timeout=blocking_timeout,
            thread_local=thread_local,
        )

    def delete(
        self,
        key: Any,
        version: Optional[int] = None,
        prefix: Optional[str] = None,
        client: Optional[Redis] = None,
    ) -> int:
        """
        Remove a key from the cache.

        Returns the result of Redis DEL, i.e. the number of keys removed.
        """
        if client is None:
            client = self.get_client(write=True)

        try:
            return client.delete(self.make_key(key, version=version, prefix=prefix))
        except _main_exceptions as e:
            raise ConnectionInterrupted(connection=client) from e

    def delete_pattern(
        self,
        pattern: str,
        version: Optional[int] = None,
        prefix: Optional[str] = None,
        client: Optional[Redis] = None,
        itersize: Optional[int] = None,
    ) -> int:
        """
        Remove all keys matching ``pattern``.

        Keys are enumerated with SCAN (``itersize`` is the COUNT hint). A DEL
        for every match is queued on one pipeline, which is executed once at
        the end. Returns the number of matching keys found.
        """
        if client is None:
            client = self.get_client(write=True)

        pattern = self.make_pattern(pattern, version=version, prefix=prefix)

        try:
            count = 0
            pipeline = client.pipeline()

            for key in client.scan_iter(match=pattern, count=itersize):
                pipeline.delete(key)
                count += 1
            pipeline.execute()

            return count
        except _main_exceptions as e:
            raise ConnectionInterrupted(connection=client) from e

    def delete_many(
        self, keys, version: Optional[int] = None, client: Optional[Redis] = None
    ):
        """
        Remove multiple keys at once.

        Returns the result of Redis DEL, or None when ``keys`` is empty.
        """
        keys = [self.make_key(k, version=version) for k in keys]

        if not keys:
            return

        if client is None:
            client = self.get_client(write=True)

        try:
            return client.delete(*keys)
        except _main_exceptions as e:
            raise ConnectionInterrupted(connection=client) from e

    def clear(self, client: Optional[Redis] = None) -> None:
        """
        Flush all cache keys (Redis FLUSHDB on the selected database).
        """
        if client is None:
            client = self.get_client(write=True)

        try:
            client.flushdb()
        except _main_exceptions as e:
            raise ConnectionInterrupted(connection=client) from e

    def decode(self, value: Union[bytes, int]) -> Any:
        """
        Decode a raw value read from Redis.

        Values that parse as an integer are returned as ``int``, because
        encode() stores ints unserialized. Anything else is decompressed
        (skipped when the compressor raises CompressorError) and then
        deserialized.
        """
        try:
            value = int(value)
        except (ValueError, TypeError):
            try:
                value = self._compressor.decompress(value)
            except CompressorError:
                # Handle small values that were chosen not to be compressed.
                pass
            value = self._serializer.loads(value)
        return value

    def encode(self, value: Any) -> Union[bytes, Any]:
        """
        Encode a value for storage in Redis.

        Plain ints are stored as-is so INCRBY (used by _incr) can operate on
        them. Everything else is serialized and then compressed. ``bool`` is
        an ``int`` subclass, so it is excluded explicitly and round-trips
        through the serializer.
        """
        if isinstance(value, bool) or not isinstance(value, int):
            value = self._serializer.dumps(value)
            value = self._compressor.compress(value)
        return value

    def get_many(
        self, keys, version: Optional[int] = None, client: Optional[Redis] = None
    ) -> OrderedDict:
        """
        Retrieve many keys with a single MGET.

        Returns an OrderedDict mapping each requested key, as passed in, to
        its decoded value. Missing keys are omitted.
        """
        if not keys:
            return OrderedDict()

        if client is None:
            client = self.get_client(write=False)

        recovered_data = OrderedDict()

        # made (namespaced) key -> caller's original key
        map_keys = OrderedDict((self.make_key(k, version=version), k) for k in keys)

        try:
            results = client.mget(*map_keys)
        except _main_exceptions as e:
            raise ConnectionInterrupted(connection=client) from e

        for key, value in zip(map_keys, results):
            if value is None:
                continue
            recovered_data[map_keys[key]] = self.decode(value)
        return recovered_data

    def set_many(
        self,
        data: Dict[Any, Any],
        timeout: Optional[float] = DEFAULT_TIMEOUT,
        version: Optional[int] = None,
        client: Optional[Redis] = None,
    ) -> None:
        """
        Set a bunch of values in the cache at once from a dict of key/value
        pairs.

        All writes are queued on one pipeline, which is much more efficient
        than calling set() multiple times. ``timeout`` applies to every key;
        when omitted, the default cache timeout is used.
        """
        if client is None:
            client = self.get_client(write=True)

        try:
            pipeline = client.pipeline()
            for key, value in data.items():
                self.set(key, value, timeout, version=version, client=pipeline)
            pipeline.execute()
        except _main_exceptions as e:
            raise ConnectionInterrupted(connection=client) from e

    def _incr(
        self,
        key: Any,
        delta: int = 1,
        version: Optional[int] = None,
        client: Optional[Redis] = None,
        ignore_key_check: bool = False,
    ) -> int:
        """
        Shared implementation of incr()/decr(): add ``delta`` to a value.

        :raises ValueError: if the key does not exist, unless
            ``ignore_key_check`` is True, in which case INCRBY creates it.
        :raises ConnectionInterrupted: on connection-level errors.
        """
        if client is None:
            client = self.get_client(write=True)

        key = self.make_key(key, version=version)

        try:
            try:
                # Check existence and increment in one Lua script so the pair
                # is atomic. With a separate EXISTS, the key could expire in
                # between and INCRBY would recreate it with the wrong value
                # and no TTL.
                if not ignore_key_check:
                    lua = """
                    local exists = redis.call('EXISTS', KEYS[1])
                    if (exists == 1) then
                        return redis.call('INCRBY', KEYS[1], ARGV[1])
                    else return false end
                    """
                else:
                    lua = """
                    return redis.call('INCRBY', KEYS[1], ARGV[1])
                    """
                value = client.eval(lua, 1, key, delta)
                if value is None:
                    raise ValueError("Key '%s' not found" % key)
            except ResponseError:
                # Redis raises ResponseError when the stored value is not an
                # integer string (e.g. it went through the serializer) or when
                # the result would overflow a signed 64-bit integer. Fall back
                # to a (non-atomic) read-modify-write in Python, keeping the
                # key's TTL.
                timeout = self.ttl(key, version=version, client=client)
                current = self.get(key, version=version, client=client)
                # ttl() reports a missing key as 0, never -2, so detect a key
                # that vanished (e.g. expired) from the read itself rather
                # than failing on ``None + delta``.
                if current is None:
                    raise ValueError("Key '%s' not found" % key)
                value = current + delta
                self.set(key, value, version=version, timeout=timeout, client=client)
        except _main_exceptions as e:
            raise ConnectionInterrupted(connection=client) from e

        return value

    def incr(
        self,
        key: Any,
        delta: int = 1,
        version: Optional[int] = None,
        client: Optional[Redis] = None,
        ignore_key_check: bool = False,
    ) -> int:
        """
        Add ``delta`` to the value in the cache and return the new value.

        If the key does not exist, a ValueError is raised. With
        ``ignore_key_check=True`` a missing key is created and set to
        ``delta`` instead.
        """
        return self._incr(
            key=key,
            delta=delta,
            version=version,
            client=client,
            ignore_key_check=ignore_key_check,
        )

    def decr(
        self,
        key: Any,
        delta: int = 1,
        version: Optional[int] = None,
        client: Optional[Redis] = None,
    ) -> int:
        """
        Subtract ``delta`` from the value in the cache and return the new
        value.

        If the key does not exist, a ValueError is raised.
        """
        return self._incr(key=key, delta=-delta, version=version, client=client)

    def ttl(
        self, key: Any, version: Optional[int] = None, client: Optional[Redis] = None
    ) -> Optional[int]:
        """
        Return the time-to-live of ``key`` in seconds (Redis TTL).

        Returns None for a key without expiry and 0 for a missing key.
        """
        if client is None:
            client = self.get_client(write=False)

        key = self.make_key(key, version=version)

        # Redis 2.6 and older return -1 from TTL both for a missing key and
        # for a key without expiry, so existence is checked explicitly.
        if not client.exists(key):
            return 0

        t = client.ttl(key)

        if t >= 0:
            return t
        elif t == -1:
            return None
        elif t == -2:
            return 0
        else:
            # Should never reach here
            return None

    def pttl(self, key, version=None, client=None):
        """
        Return the time-to-live of ``key`` in milliseconds (Redis PTTL).

        Returns None for a key without expiry and 0 for a missing key.
        """
        if client is None:
            client = self.get_client(write=False)

        key = self.make_key(key, version=version)

        # See ttl(): older servers do not distinguish missing keys.
        if not client.exists(key):
            return 0

        t = client.pttl(key)

        if t >= 0:
            return t
        elif t == -1:
            return None
        elif t == -2:
            return 0
        else:
            # Should never reach here
            return None

    def has_key(
        self, key: Any, version: Optional[int] = None, client: Optional[Redis] = None
    ) -> bool:
        """
        Test whether ``key`` exists.
        """
        if client is None:
            client = self.get_client(write=False)

        key = self.make_key(key, version=version)
        try:
            return client.exists(key) == 1
        except _main_exceptions as e:
            raise ConnectionInterrupted(connection=client) from e

    def iter_keys(
        self,
        search: str,
        itersize: Optional[int] = None,
        client: Optional[Redis] = None,
        version: Optional[int] = None,
    ) -> Iterator[str]:
        """
        Like keys(), but iterates lazily with SCAN cursors (Redis >= 2.8).

        This keeps memory usage low for large key sets. Yields keys with the
        prefix/version stripped by ``reverse_key``.
        """
        if client is None:
            client = self.get_client(write=False)

        pattern = self.make_pattern(search, version=version)
        for item in client.scan_iter(match=pattern, count=itersize):
            yield self.reverse_key(item.decode())

    def keys(
        self, search: str, version: Optional[int] = None, client: Optional[Redis] = None
    ) -> List[Any]:
        """
        Execute KEYS and return the matched keys (prefix/version stripped).

        Warning: this can return a huge number of results. In that case
        iter_keys() is strongly recommended instead.
        """
        if client is None:
            client = self.get_client(write=False)

        pattern = self.make_pattern(search, version=version)
        try:
            return [self.reverse_key(k.decode()) for k in client.keys(pattern)]
        except _main_exceptions as e:
            raise ConnectionInterrupted(connection=client) from e

    def make_key(
        self, key: Any, version: Optional[Any] = None, prefix: Optional[str] = None
    ) -> CacheKey:
        """
        Build the namespaced Redis key via the backend's ``key_func``.

        An existing CacheKey is returned unchanged, so already-made keys are
        never prefixed twice. ``prefix``/``version`` default to the
        backend's values.
        """
        if isinstance(key, CacheKey):
            return key

        if prefix is None:
            prefix = self._backend.key_prefix

        if version is None:
            version = self._backend.version

        return CacheKey(self._backend.key_func(key, prefix, version))

    def make_pattern(
        self, pattern: str, version: Optional[int] = None, prefix: Optional[str] = None
    ) -> CacheKey:
        """
        Build a glob pattern for SCAN/KEYS in the same namespace as make_key().

        The prefix and version are glob-escaped so they match literally,
        while ``pattern`` itself keeps its glob semantics.
        """
        if isinstance(pattern, CacheKey):
            return pattern

        if prefix is None:
            prefix = self._backend.key_prefix
        prefix = glob_escape(prefix)

        if version is None:
            version = self._backend.version
        version_str = glob_escape(str(version))

        return CacheKey(self._backend.key_func(pattern, prefix, version_str))

    def close(self, **kwargs):
        """
        Close client connections if configured to do so.

        Controlled by the ``CLOSE_CONNECTION`` option, falling back to
        ``settings.DJANGO_REDIS_CLOSE_CONNECTION`` (default False).
        """
        close_flag = self._options.get(
            "CLOSE_CONNECTION",
            getattr(settings, "DJANGO_REDIS_CLOSE_CONNECTION", False),
        )
        if close_flag:
            self.do_close_clients()

    def do_close_clients(self):
        """
        Disconnect every cached client and reset the cache of clients.

        Default implementation; override in a custom client.
        """
        num_clients = len(self._clients)
        for idx in range(num_clients):
            self.disconnect(index=idx)
        self._clients = [None] * num_clients

    def touch(
        self,
        key: Any,
        timeout: Optional[float] = DEFAULT_TIMEOUT,
        version: Optional[int] = None,
        client: Optional[Redis] = None,
    ) -> bool:
        """
        Set a new expiration for a key.

        ``None`` removes any expiry (PERSIST). Otherwise ``timeout`` seconds
        is applied with PEXPIRE. Returns whether Redis reported success.
        """
        if timeout is DEFAULT_TIMEOUT:
            timeout = self._backend.default_timeout

        if client is None:
            client = self.get_client(write=True)

        key = self.make_key(key, version=version)
        if timeout is None:
            return bool(client.persist(key))
        else:
            # Convert to milliseconds
            timeout = int(timeout * 1000)
            return bool(client.pexpire(key, timeout))
|