sharded.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324
  1. import re
  2. from collections import OrderedDict
  3. from datetime import datetime
  4. from typing import Union
  5. from redis.exceptions import ConnectionError
  6. from ..exceptions import ConnectionInterrupted
  7. from ..hash_ring import HashRing
  8. from ..util import CacheKey
  9. from .default import DEFAULT_TIMEOUT, DefaultClient
  class ShardClient(DefaultClient):
      """Cache client that spreads keys over several Redis servers.

      One connection is opened per configured server (see ``connect``) and
      every key is routed to exactly one of them through ``HashRing``. If a
      key contains a ``{...}`` section, only that section is hashed, so
      callers can pin related keys to the same server.

      Multi-key operations fan out to the individual servers. Operations that
      would need a single connection for the whole cache (``get_client``,
      ``iter_keys``) raise ``NotImplementedError``.
      """

      # Hash-tag extractor. The pattern is greedy, so when a key holds several
      # ``{...}`` sections the last one is used. It is kept exactly as-is on
      # purpose: changing it would send existing keys to different servers.
      _findhash = re.compile(r".*\{(.*)\}.*", re.I)

      def __init__(self, *args, **kwargs):
          super().__init__(*args, **kwargs)
          # A single server may be configured as a plain value; normalise it
          # to a list so the ring and the connection map handle both alike.
          if not isinstance(self._server, (list, tuple)):
              self._server = [self._server]
          self._ring = HashRing(self._server)
          self._serverdict = self.connect()

      def get_client(self, *args, **kwargs):
          """Unsupported: a sharded cache has no single client.

          Use ``get_server(key)`` to obtain the connection that owns a key.
          """
          raise NotImplementedError

      def connect(self, index=0):
          """Open one connection per configured server.

          ``index`` is unused here; presumably it is kept so the signature
          matches the parent's ``connect``.

          Returns a dict mapping each server name to its connection.
          """
          return {name: self.connection_factory.connect(name) for name in self._server}

      def get_server_name(self, _key):
          """Return the name of the server that owns ``_key``.

          The key is converted with ``str()``; if it contains a ``{...}`` hash
          tag, only the tag text is fed to the hash ring.
          """
          key = str(_key)
          match = self._findhash.match(key)
          if match is not None:
              # The pattern has exactly one group, so a match always carries it.
              key = match.group(1)
          return self._ring.get_node(key)

      def get_server(self, key):
          """Return the connection of the server that owns ``key``."""
          return self._serverdict[self.get_server_name(key)]

      def _route_key(self, key, version=None, client=None):
          """Return ``(key, client)`` ready for a single-key operation.

          When ``client`` is None the key is built with ``make_key`` and the
          connection of its owning shard is looked up. Otherwise the
          caller-supplied key and client are returned unchanged.
          """
          if client is None:
              key = self.make_key(key, version=version)
              client = self.get_server(key)
          return key, client

      def add(self, key, value, timeout=DEFAULT_TIMEOUT, version=None, client=None):
          """Run the parent's ``add`` on the shard that owns ``key``."""
          key, client = self._route_key(key, version, client)
          return super().add(
              key=key, value=value, version=version, client=client, timeout=timeout
          )

      def get(self, key, default=None, version=None, client=None):
          """Run the parent's ``get`` on the shard that owns ``key``."""
          key, client = self._route_key(key, version, client)
          return super().get(key=key, default=default, version=version, client=client)

      def get_many(self, keys, version=None):
          """Fetch several keys with one ``get`` per key on its own shard.

          ``keys`` is iterated exactly once, so any iterable (including a
          generator) is accepted. Returns an ``OrderedDict`` keyed by the
          caller's original keys in input order; keys whose value comes back
          as None are left out. An empty/falsy ``keys`` returns ``{}``.
          """
          if not keys:
              return {}
          recovered_data = OrderedDict()
          for original_key in keys:
              key = self.make_key(original_key, version=version)
              value = self.get(key=key, version=version, client=self.get_server(key))
              if value is not None:
                  recovered_data[original_key] = value
          return recovered_data

      def set(
          self, key, value, timeout=DEFAULT_TIMEOUT, version=None, client=None, nx=False
      ):
          """
          Persist a value to the cache, and set an optional expiration time.
          """
          key, client = self._route_key(key, version, client)
          return super().set(
              key=key, value=value, timeout=timeout, version=version, client=client, nx=nx
          )

      def set_many(self, data, timeout=DEFAULT_TIMEOUT, version=None):
          """Store every key/value pair of the ``data`` dict.

          Each pair is written with its own ``set`` call on the shard that
          owns the key, so this is a convenience wrapper, not a batched
          operation. ``timeout`` applies to every key; when it is not given
          the default cache timeout is used.
          """
          for key, value in data.items():
              self.set(key, value, timeout, version=version)

      def has_key(self, key, version=None, client=None):
          """Return True if ``key`` exists on its shard (or on ``client``).

          Raises ``ConnectionInterrupted`` if the Redis connection fails.
          """
          # Build the key exactly once, whether or not a client was supplied.
          key = self.make_key(key, version=version)
          if client is None:
              client = self.get_server(key)
          try:
              return client.exists(key) == 1
          except ConnectionError as e:
              raise ConnectionInterrupted(connection=client) from e

      def delete(self, key, version=None, client=None):
          """Run the parent's ``delete`` on the shard that owns ``key``."""
          key, client = self._route_key(key, version, client)
          return super().delete(key=key, version=version, client=client)

      def ttl(self, key, version=None, client=None):
          """
          Executes TTL redis command and return the "time-to-live" of specified key.
          If key is a non volatile key, it returns None.
          """
          key, client = self._route_key(key, version, client)
          return super().ttl(key=key, version=version, client=client)

      def pttl(self, key, version=None, client=None):
          """
          Executes PTTL redis command and return the "time-to-live" of specified key
          in milliseconds. If key is a non volatile key, it returns None.
          """
          key, client = self._route_key(key, version, client)
          return super().pttl(key=key, version=version, client=client)

      def persist(self, key, version=None, client=None):
          """Run the parent's ``persist`` on the shard that owns ``key``."""
          key, client = self._route_key(key, version, client)
          return super().persist(key=key, version=version, client=client)

      def expire(self, key, timeout, version=None, client=None):
          """Run the parent's ``expire`` on the shard that owns ``key``."""
          key, client = self._route_key(key, version, client)
          return super().expire(key=key, timeout=timeout, version=version, client=client)

      def pexpire(self, key, timeout, version=None, client=None):
          """Run the parent's ``pexpire`` on the shard that owns ``key``."""
          key, client = self._route_key(key, version, client)
          return super().pexpire(key=key, timeout=timeout, version=version, client=client)

      def pexpire_at(self, key, when: Union[datetime, int], version=None, client=None):
          """
          Set an expire flag on a ``key`` to ``when`` on a shard client.
          ``when`` which can be represented as an integer indicating unix
          time or a Python datetime object.
          """
          key, client = self._route_key(key, version, client)
          return super().pexpire_at(key=key, when=when, version=version, client=client)

      def expire_at(self, key, when: Union[datetime, int], version=None, client=None):
          """
          Set an expire flag on a ``key`` to ``when`` on a shard client.
          ``when`` which can be represented as an integer indicating unix
          time or a Python datetime object.
          """
          key, client = self._route_key(key, version, client)
          return super().expire_at(key=key, when=when, version=version, client=client)

      def lock(
          self,
          key,
          version=None,
          timeout=None,
          sleep=0.1,
          blocking_timeout=None,
          client=None,
          thread_local=True,
      ):
          """Run the parent's ``lock`` against the shard that owns ``key``."""
          # Build the key exactly once, whether or not a client was supplied.
          key = self.make_key(key, version=version)
          if client is None:
              client = self.get_server(key)
          return super().lock(
              key,
              timeout=timeout,
              sleep=sleep,
              client=client,
              blocking_timeout=blocking_timeout,
              thread_local=thread_local,
          )

      def delete_many(self, keys, version=None):
          """
          Remove multiple keys at once.

          Each key is deleted on its own shard; the per-key results of
          ``delete`` are summed and returned.
          """
          res = 0
          for original_key in keys:
              key = self.make_key(original_key, version=version)
              res += self.delete(key, client=self.get_server(key))
          return res

      def incr_version(self, key, delta=1, version=None, client=None):
          """Move ``key`` from ``version`` to ``version + delta``.

          The value and its TTL are copied to the new versioned key, then the
          old key is deleted. Returns the new version number.

          Raises ``ValueError`` if the key is missing (its value reads back as
          None) and ``ConnectionInterrupted`` if reading the TTL fails.
          """
          key, client = self._route_key(key, version, client)
          if version is None:
              version = self._backend.version
          old_key = self.make_key(key, version)
          value = self.get(old_key, version=version, client=client)
          try:
              ttl = self.ttl(old_key, version=version, client=client)
          except ConnectionError as e:
              raise ConnectionInterrupted(connection=client) from e
          if value is None:
              raise ValueError("Key '%s' not found" % key)
          if isinstance(key, CacheKey):
              new_key = self.make_key(key.original_key(), version=version + delta)
          else:
              new_key = self.make_key(key, version=version + delta)
          # The re-versioned key can land on a different shard, so route it
          # afresh instead of reusing ``client``.
          self.set(new_key, value, timeout=ttl, client=self.get_server(new_key))
          self.delete(old_key, client=client)
          return version + delta

      def incr(self, key, delta=1, version=None, client=None):
          """Run the parent's ``incr`` on the shard that owns ``key``."""
          key, client = self._route_key(key, version, client)
          return super().incr(key=key, delta=delta, version=version, client=client)

      def decr(self, key, delta=1, version=None, client=None):
          """Run the parent's ``decr`` on the shard that owns ``key``."""
          key, client = self._route_key(key, version, client)
          return super().decr(key=key, delta=delta, version=version, client=client)

      def iter_keys(self, key, version=None):
          """Unsupported on the sharded client."""
          raise NotImplementedError("iter_keys not supported on sharded client")

      def keys(self, search, version=None):
          """Return the keys matching ``search`` across every server.

          Each server is queried with ``KEYS``; the results are decoded and
          passed through ``reverse_key``. Raises ``ConnectionInterrupted``
          carrying the connection of the server that failed.
          """
          pattern = self.make_pattern(search, version=version)
          keys = []
          for connection in self._serverdict.values():
              try:
                  keys.extend(connection.keys(pattern))
              except ConnectionError as e:
                  raise ConnectionInterrupted(connection=connection) from e
          return [self.reverse_key(k.decode()) for k in keys]

      def delete_pattern(
          self, pattern, version=None, client=None, itersize=None, prefix=None
      ):
          """
          Remove all keys matching pattern.

          Every server is scanned first; only then does each server get one
          ``DEL`` for the keys found on that server. ``itersize`` is passed
          to the scan as ``count``. ``client`` is accepted for interface
          compatibility but unused, since every shard must be visited.
          Returns the total number of keys deleted. Raises
          ``ConnectionInterrupted`` carrying the connection that failed.
          """
          pattern = self.make_pattern(pattern, version=version, prefix=prefix)
          kwargs = {"match": pattern}
          if itersize:
              kwargs["count"] = itersize

          # Phase 1: collect matches per server, so nothing is deleted if
          # any server cannot be scanned.
          matches = []
          for connection in self._serverdict.values():
              try:
                  matches.append((connection, list(connection.scan_iter(**kwargs))))
              except ConnectionError as e:
                  raise ConnectionInterrupted(connection=connection) from e

          # Phase 2: each server only deletes the keys it actually holds.
          res = 0
          for connection, keys in matches:
              if not keys:
                  continue
              try:
                  res += connection.delete(*keys)
              except ConnectionError as e:
                  raise ConnectionInterrupted(connection=connection) from e
          return res

      def do_close_clients(self):
          """Disconnect every per-server connection."""
          for client in self._serverdict.values():
              self.disconnect(client=client)

      def touch(self, key, timeout=DEFAULT_TIMEOUT, version=None, client=None):
          """Run the parent's ``touch`` on the shard that owns ``key``."""
          key, client = self._route_key(key, version, client)
          return super().touch(key=key, timeout=timeout, version=version, client=client)

      def clear(self, client=None):
          """Run ``flushdb`` on every server; ``client`` is accepted but unused."""
          for connection in self._serverdict.values():
              connection.flushdb()