pool.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192
  1. from typing import Dict
  2. from urllib.parse import parse_qs, urlparse
  3. from django.conf import settings
  4. from django.core.exceptions import ImproperlyConfigured
  5. from django.utils.module_loading import import_string
  6. from redis import Redis
  7. from redis.connection import DefaultParser, to_bool
  8. from redis.sentinel import Sentinel
  9. class ConnectionFactory:
  10. # Store connection pool by cache backend options.
  11. #
  12. # _pools is a process-global, as otherwise _pools is cleared every time
  13. # ConnectionFactory is instantiated, as Django creates new cache client
  14. # (DefaultClient) instance for every request.
  15. _pools: Dict[str, Redis] = {}
  16. def __init__(self, options):
  17. pool_cls_path = options.get(
  18. "CONNECTION_POOL_CLASS", "redis.connection.ConnectionPool"
  19. )
  20. self.pool_cls = import_string(pool_cls_path)
  21. self.pool_cls_kwargs = options.get("CONNECTION_POOL_KWARGS", {})
  22. redis_client_cls_path = options.get("REDIS_CLIENT_CLASS", "redis.client.Redis")
  23. self.redis_client_cls = import_string(redis_client_cls_path)
  24. self.redis_client_cls_kwargs = options.get("REDIS_CLIENT_KWARGS", {})
  25. self.options = options
  26. def make_connection_params(self, url):
  27. """
  28. Given a main connection parameters, build a complete
  29. dict of connection parameters.
  30. """
  31. kwargs = {
  32. "url": url,
  33. "parser_class": self.get_parser_cls(),
  34. }
  35. password = self.options.get("PASSWORD", None)
  36. if password:
  37. kwargs["password"] = password
  38. socket_timeout = self.options.get("SOCKET_TIMEOUT", None)
  39. if socket_timeout:
  40. assert isinstance(
  41. socket_timeout, (int, float)
  42. ), "Socket timeout should be float or integer"
  43. kwargs["socket_timeout"] = socket_timeout
  44. socket_connect_timeout = self.options.get("SOCKET_CONNECT_TIMEOUT", None)
  45. if socket_connect_timeout:
  46. assert isinstance(
  47. socket_connect_timeout, (int, float)
  48. ), "Socket connect timeout should be float or integer"
  49. kwargs["socket_connect_timeout"] = socket_connect_timeout
  50. return kwargs
  51. def connect(self, url: str) -> Redis:
  52. """
  53. Given a basic connection parameters,
  54. return a new connection.
  55. """
  56. params = self.make_connection_params(url)
  57. connection = self.get_connection(params)
  58. return connection
  59. def disconnect(self, connection):
  60. """
  61. Given a not null client connection it disconnect from the Redis server.
  62. The default implementation uses a pool to hold connections.
  63. """
  64. connection.connection_pool.disconnect()
  65. def get_connection(self, params):
  66. """
  67. Given a now preformatted params, return a
  68. new connection.
  69. The default implementation uses a cached pools
  70. for create new connection.
  71. """
  72. pool = self.get_or_create_connection_pool(params)
  73. return self.redis_client_cls(
  74. connection_pool=pool, **self.redis_client_cls_kwargs
  75. )
  76. def get_parser_cls(self):
  77. cls = self.options.get("PARSER_CLASS", None)
  78. if cls is None:
  79. return DefaultParser
  80. return import_string(cls)
  81. def get_or_create_connection_pool(self, params):
  82. """
  83. Given a connection parameters and return a new
  84. or cached connection pool for them.
  85. Reimplement this method if you want distinct
  86. connection pool instance caching behavior.
  87. """
  88. key = params["url"]
  89. if key not in self._pools:
  90. self._pools[key] = self.get_connection_pool(params)
  91. return self._pools[key]
  92. def get_connection_pool(self, params):
  93. """
  94. Given a connection parameters, return a new
  95. connection pool for them.
  96. Overwrite this method if you want a custom
  97. behavior on creating connection pool.
  98. """
  99. cp_params = dict(params)
  100. cp_params.update(self.pool_cls_kwargs)
  101. pool = self.pool_cls.from_url(**cp_params)
  102. if pool.connection_kwargs.get("password", None) is None:
  103. pool.connection_kwargs["password"] = params.get("password", None)
  104. pool.reset()
  105. return pool
  106. class SentinelConnectionFactory(ConnectionFactory):
  107. def __init__(self, options):
  108. # allow overriding the default SentinelConnectionPool class
  109. options.setdefault(
  110. "CONNECTION_POOL_CLASS", "redis.sentinel.SentinelConnectionPool"
  111. )
  112. super().__init__(options)
  113. sentinels = options.get("SENTINELS")
  114. if not sentinels:
  115. raise ImproperlyConfigured(
  116. "SENTINELS must be provided as a list of (host, port)."
  117. )
  118. # provide the connection pool kwargs to the sentinel in case it
  119. # needs to use the socket options for the sentinels themselves
  120. connection_kwargs = self.make_connection_params(None)
  121. connection_kwargs.pop("url")
  122. connection_kwargs.update(self.pool_cls_kwargs)
  123. self._sentinel = Sentinel(
  124. sentinels,
  125. sentinel_kwargs=options.get("SENTINEL_KWARGS"),
  126. **connection_kwargs,
  127. )
  128. def get_connection_pool(self, params):
  129. """
  130. Given a connection parameters, return a new sentinel connection pool
  131. for them.
  132. """
  133. url = urlparse(params["url"])
  134. # explicitly set service_name and sentinel_manager for the
  135. # SentinelConnectionPool constructor since will be called by from_url
  136. cp_params = dict(params)
  137. cp_params.update(service_name=url.hostname, sentinel_manager=self._sentinel)
  138. pool = super().get_connection_pool(cp_params)
  139. # convert "is_master" to a boolean if set on the URL, otherwise if not
  140. # provided it defaults to True.
  141. is_master = parse_qs(url.query).get("is_master")
  142. if is_master:
  143. pool.is_master = to_bool(is_master[0])
  144. return pool
  145. def get_connection_factory(path=None, options=None):
  146. if path is None:
  147. path = getattr(
  148. settings,
  149. "DJANGO_REDIS_CONNECTION_FACTORY",
  150. "django_redis.pool.ConnectionFactory",
  151. )
  152. opt_conn_factory = options.get("CONNECTION_FACTORY")
  153. if opt_conn_factory:
  154. path = opt_conn_factory
  155. cls = import_string(path)
  156. return cls(options or {})