| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324 |
- import re
- from collections import OrderedDict
- from datetime import datetime
- from typing import Union
- from redis.exceptions import ConnectionError
- from ..exceptions import ConnectionInterrupted
- from ..hash_ring import HashRing
- from ..util import CacheKey
- from .default import DEFAULT_TIMEOUT, DefaultClient
class ShardClient(DefaultClient):
    """
    Cache client that spreads keys across several Redis servers.

    Every key is routed to exactly one server through a consistent-hash ring
    (``HashRing``). If the key contains a ``{...}`` section, only the text
    inside the braces is hashed, so callers can force related keys onto the
    same server. Single-key operations resolve the owning connection and then
    delegate to the corresponding ``DefaultClient`` method; multi-key and
    pattern operations fan out over every server.
    """

    # Extracts the hash tag. Because the leading ``.*`` is greedy, the capture
    # runs from the *last* ``{`` to the last ``}`` after it.
    # NOTE(review): changing this regex would re-route existing keys to
    # different servers, so it is kept as-is.
    _findhash = re.compile(r".*\{(.*)\}.*", re.I)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # A non-sequence server setting is wrapped so it can be iterated.
        if not isinstance(self._server, (list, tuple)):
            self._server = [self._server]
        self._ring = HashRing(self._server)
        # Maps server name -> connection, built once at construction time.
        self._serverdict = self.connect()

    def get_client(self, *args, **kwargs):
        """Not supported on a sharded client; use ``get_server(key)`` instead."""
        raise NotImplementedError

    def connect(self, index=0):
        """
        Open one connection per configured server.

        ``index`` is accepted but ignored: every server is connected.
        Returns a dict mapping each server name to its connection.
        """
        return {name: self.connection_factory.connect(name) for name in self._server}

    def get_server_name(self, _key):
        """Return the name of the server that owns ``_key`` on the hash ring."""
        key = str(_key)
        match = self._findhash.match(key)
        if match is not None:
            # Hash only the tag so keys sharing a tag land on one server.
            key = match.group(1)
        return self._ring.get_node(key)

    def get_server(self, key):
        """Return the connection of the server that owns ``key``."""
        return self._serverdict[self.get_server_name(key)]

    def _key_and_shard(self, key, version, client):
        """
        Return ``(key, client)`` ready for a single-key operation.

        When ``client`` is None the key is versioned with ``make_key`` and the
        owning shard's connection is looked up; otherwise both arguments are
        returned unchanged and the parent method is left to build the key.
        """
        if client is None:
            key = self.make_key(key, version=version)
            client = self.get_server(key)
        return key, client

    def add(self, key, value, timeout=DEFAULT_TIMEOUT, version=None, client=None):
        """Route ``key`` to its shard and delegate to ``DefaultClient.add``."""
        key, client = self._key_and_shard(key, version, client)
        return super().add(
            key=key, value=value, version=version, client=client, timeout=timeout
        )

    def get(self, key, default=None, version=None, client=None):
        """Route ``key`` to its shard and delegate to ``DefaultClient.get``."""
        key, client = self._key_and_shard(key, version, client)
        return super().get(key=key, default=default, version=version, client=client)

    def get_many(self, keys, version=None):
        """
        Fetch several keys, issuing one lookup per key on its own shard.

        Returns a mapping from the caller's keys to their values (an
        ``OrderedDict`` in request order). Keys whose lookup returns ``None``
        are omitted. An empty ``keys`` returns ``{}``.
        """
        if not keys:
            return {}
        recovered_data = OrderedDict()
        for original_key in keys:
            key = self.make_key(original_key, version=version)
            value = self.get(key=key, version=version, client=self.get_server(key))
            if value is not None:
                recovered_data[original_key] = value
        return recovered_data

    def set(
        self, key, value, timeout=DEFAULT_TIMEOUT, version=None, client=None, nx=False
    ):
        """
        Persist a value to the cache, and set an optional expiration time.
        """
        key, client = self._key_and_shard(key, version, client)
        return super().set(
            key=key, value=value, timeout=timeout, version=version, client=client, nx=nx
        )

    def set_many(self, data, timeout=DEFAULT_TIMEOUT, version=None):
        """
        Set a bunch of values in the cache at once from a dict of key/value
        pairs.

        Each pair is written with its own ``set`` call, because the keys may
        live on different shards. If timeout is given, that timeout will be
        used for every key; otherwise the default cache timeout will be used.
        """
        for key, value in data.items():
            self.set(key, value, timeout, version=version)

    def has_key(self, key, version=None, client=None):
        """
        Test if key exists.

        Raises ``ConnectionInterrupted`` if the shard connection fails.
        """
        # make_key is applied once; the resulting key also selects the shard.
        key = self.make_key(key, version=version)
        if client is None:
            client = self.get_server(key)
        try:
            return client.exists(key) == 1
        except ConnectionError as e:
            raise ConnectionInterrupted(connection=client) from e

    def delete(self, key, version=None, client=None):
        """Route ``key`` to its shard and delegate to ``DefaultClient.delete``."""
        key, client = self._key_and_shard(key, version, client)
        return super().delete(key=key, version=version, client=client)

    def ttl(self, key, version=None, client=None):
        """
        Executes TTL redis command and return the "time-to-live" of specified key.
        If key is a non volatile key, it returns None.
        """
        key, client = self._key_and_shard(key, version, client)
        return super().ttl(key=key, version=version, client=client)

    def pttl(self, key, version=None, client=None):
        """
        Executes PTTL redis command and return the "time-to-live" of specified key
        in milliseconds. If key is a non volatile key, it returns None.
        """
        key, client = self._key_and_shard(key, version, client)
        return super().pttl(key=key, version=version, client=client)

    def persist(self, key, version=None, client=None):
        """Route ``key`` to its shard and delegate to ``DefaultClient.persist``."""
        key, client = self._key_and_shard(key, version, client)
        return super().persist(key=key, version=version, client=client)

    def expire(self, key, timeout, version=None, client=None):
        """Route ``key`` to its shard and delegate to ``DefaultClient.expire``."""
        key, client = self._key_and_shard(key, version, client)
        return super().expire(key=key, timeout=timeout, version=version, client=client)

    def pexpire(self, key, timeout, version=None, client=None):
        """Route ``key`` to its shard and delegate to ``DefaultClient.pexpire``."""
        key, client = self._key_and_shard(key, version, client)
        return super().pexpire(key=key, timeout=timeout, version=version, client=client)

    def pexpire_at(self, key, when: Union[datetime, int], version=None, client=None):
        """
        Set an expire flag on a ``key`` to ``when`` on a shard client.
        ``when`` can be represented as an integer indicating unix
        time or a Python datetime object.
        """
        key, client = self._key_and_shard(key, version, client)
        return super().pexpire_at(key=key, when=when, version=version, client=client)

    def expire_at(self, key, when: Union[datetime, int], version=None, client=None):
        """
        Set an expire flag on a ``key`` to ``when`` on a shard client.
        ``when`` can be represented as an integer indicating unix
        time or a Python datetime object.
        """
        key, client = self._key_and_shard(key, version, client)
        return super().expire_at(key=key, when=when, version=version, client=client)

    def lock(
        self,
        key,
        version=None,
        timeout=None,
        sleep=0.1,
        blocking_timeout=None,
        client=None,
        thread_local=True,
    ):
        """
        Create a lock on the shard that owns ``key``.

        The key is versioned here, so ``version`` is not forwarded to the
        parent ``lock``.
        """
        key = self.make_key(key, version=version)
        if client is None:
            client = self.get_server(key)
        return super().lock(
            key,
            timeout=timeout,
            sleep=sleep,
            client=client,
            blocking_timeout=blocking_timeout,
            thread_local=thread_local,
        )

    def delete_many(self, keys, version=None):
        """
        Remove multiple keys at once.

        Each key is deleted on its own shard. Returns the sum of the
        individual ``delete`` results.
        """
        res = 0
        for key in [self.make_key(k, version=version) for k in keys]:
            res += self.delete(key, client=self.get_server(key))
        return res

    def incr_version(self, key, delta=1, version=None, client=None):
        """
        Move a key's value to ``version + delta`` and return the new version.

        The value and its remaining TTL are read from the old key, written
        under the new key (possibly on a different shard), and the old key is
        deleted. Raises ``ValueError`` if the old key holds no value.
        """
        key, client = self._key_and_shard(key, version, client)
        if version is None:
            version = self._backend.version

        old_key = self.make_key(key, version)
        value = self.get(old_key, version=version, client=client)
        try:
            ttl = self.ttl(old_key, version=version, client=client)
        except ConnectionError as e:
            raise ConnectionInterrupted(connection=client) from e

        if value is None:
            raise ValueError("Key '%s' not found" % key)

        # Rebuild from the unversioned key so the old version is not embedded.
        if isinstance(key, CacheKey):
            new_key = self.make_key(key.original_key(), version=version + delta)
        else:
            new_key = self.make_key(key, version=version + delta)

        # The new key may hash to a different shard than the old one.
        self.set(new_key, value, timeout=ttl, client=self.get_server(new_key))
        self.delete(old_key, client=client)
        return version + delta

    def incr(self, key, delta=1, version=None, client=None):
        """Route ``key`` to its shard and delegate to ``DefaultClient.incr``."""
        key, client = self._key_and_shard(key, version, client)
        return super().incr(key=key, delta=delta, version=version, client=client)

    def decr(self, key, delta=1, version=None, client=None):
        """Route ``key`` to its shard and delegate to ``DefaultClient.decr``."""
        key, client = self._key_and_shard(key, version, client)
        return super().decr(key=key, delta=delta, version=version, client=client)

    def iter_keys(self, key, version=None):
        """Not supported on a sharded client."""
        raise NotImplementedError("iter_keys not supported on sharded client")

    def keys(self, search, version=None):
        """
        Return all keys matching ``search`` across every shard.

        Uses the Redis KEYS command on each server and maps the results back
        through ``reverse_key``. Raises ``ConnectionInterrupted`` carrying the
        connection that failed.
        """
        pattern = self.make_pattern(search, version=version)
        found = []
        for connection in self._serverdict.values():
            try:
                found.extend(connection.keys(pattern))
            except ConnectionError as e:
                # Report the shard that actually failed, not a guessed one.
                raise ConnectionInterrupted(connection=connection) from e
        return [self.reverse_key(k.decode()) for k in found]

    def delete_pattern(
        self, pattern, version=None, client=None, itersize=None, prefix=None
    ):
        """
        Remove all keys matching pattern.

        Each shard is scanned with SCAN (``itersize`` is passed as the COUNT
        hint). Only the keys found on a shard are deleted on that same shard.
        ``client`` is accepted but not used. Returns the total number of keys
        deleted.
        """
        pattern = self.make_pattern(pattern, version=version, prefix=prefix)
        kwargs = {"match": pattern}
        if itersize:
            kwargs["count"] = itersize

        deleted = 0
        for connection in self._serverdict.values():
            shard_keys = list(connection.scan_iter(**kwargs))
            # Send DEL only to the shard that holds these keys, and skip
            # shards with no matches (DEL requires at least one key).
            if shard_keys:
                deleted += connection.delete(*shard_keys)
        return deleted

    def do_close_clients(self):
        """Disconnect every shard connection."""
        for client in self._serverdict.values():
            self.disconnect(client=client)

    def touch(self, key, timeout=DEFAULT_TIMEOUT, version=None, client=None):
        """Route ``key`` to its shard and delegate to ``DefaultClient.touch``."""
        key, client = self._key_and_shard(key, version, client)
        return super().touch(key=key, timeout=timeout, version=version, client=client)

    def clear(self, client=None):
        """Run FLUSHDB on every shard. ``client`` is accepted but not used."""
        for connection in self._serverdict.values():
            connection.flushdb()
|