base.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611
  1. """
  2. PostgreSQL database backend for Django.
  3. Requires psycopg2 >= 2.8.4 or psycopg >= 3.1.8
  4. """
  5. import asyncio
  6. import threading
  7. import warnings
  8. from contextlib import contextmanager
  9. from django.conf import settings
  10. from django.core.exceptions import ImproperlyConfigured
  11. from django.db import DatabaseError as WrappedDatabaseError
  12. from django.db import connections
  13. from django.db.backends.base.base import NO_DB_ALIAS, BaseDatabaseWrapper
  14. from django.db.backends.utils import CursorDebugWrapper as BaseCursorDebugWrapper
  15. from django.utils.asyncio import async_unsafe
  16. from django.utils.functional import cached_property
  17. from django.utils.safestring import SafeString
  18. from django.utils.version import get_version_tuple
  # Prefer psycopg (3.x) and fall back to psycopg2; if neither driver can be
  # imported, fail with a configuration error instead of a bare ImportError.
  try:
      import psycopg as Database
  except ImportError:
      try:
          import psycopg2 as Database
      except ImportError:
          raise ImproperlyConfigured("Error loading psycopg2 or psycopg module")
  def psycopg_version():
      """
      Return the imported driver's version parsed by get_version_tuple().

      Only the text before the first space of ``Database.__version__`` is
      parsed; anything after it (extra build information) is ignored.
      """
      leading_token, _, _ = Database.__version__.partition(" ")
      return get_version_tuple(leading_token)
  # Refuse to run on driver releases older than the supported minimums:
  # psycopg2 >= 2.8.4, and for the 3.x line psycopg >= 3.1.8.
  if psycopg_version() < (2, 8, 4):
      raise ImproperlyConfigured(
          f"psycopg2 version 2.8.4 or newer is required; you have {Database.__version__}"
      )
  elif (3,) <= psycopg_version() < (3, 1, 8):
      raise ImproperlyConfigured(
          f"psycopg version 3.1.8 or newer is required; you have {Database.__version__}"
      )
  37. from .psycopg_any import IsolationLevel, is_psycopg3 # NOQA isort:skip
  # Driver-specific module setup, chosen by which driver was imported above.
  if is_psycopg3:
      # Names used later in this module: ``sql`` by CursorMixin.callproc(),
      # ``Format``/``register_tzloader`` by DatabaseWrapper.create_cursor(),
      # ``get_adapters_template`` by DatabaseWrapper.get_connection_params().
      from psycopg import adapters, sql
      from psycopg.pq import Format

      from .psycopg_any import get_adapters_template, register_tzloader

      # OID of the "timestamptz" type in psycopg's global adapters map; used to
      # look up the loader currently registered for that type.
      TIMESTAMPTZ_OID = adapters.types["timestamptz"].oid

  else:
      import psycopg2.extensions
      import psycopg2.extras

      # Adapt SafeString values with psycopg2's QuotedString adapter.
      psycopg2.extensions.register_adapter(SafeString, psycopg2.extensions.QuotedString)
      psycopg2.extras.register_uuid()

      # Register support for inet[] manually so we don't have to handle the Inet()
      # object on load all the time.
      INETARRAY_OID = 1041
      INETARRAY = psycopg2.extensions.new_array_type(
          (INETARRAY_OID,),
          "INETARRAY",
          psycopg2.extensions.UNICODE,
      )
      psycopg2.extensions.register_type(INETARRAY)
  57. # Some of these import psycopg, so import them after checking if it's installed.
  58. from .client import DatabaseClient # NOQA isort:skip
  59. from .creation import DatabaseCreation # NOQA isort:skip
  60. from .features import DatabaseFeatures # NOQA isort:skip
  61. from .introspection import DatabaseIntrospection # NOQA isort:skip
  62. from .operations import DatabaseOperations # NOQA isort:skip
  63. from .schema import DatabaseSchemaEditor # NOQA isort:skip
  64. def _get_varchar_column(data):
  65. if data["max_length"] is None:
  66. return "varchar"
  67. return "varchar(%(max_length)s)" % data
  class DatabaseWrapper(BaseDatabaseWrapper):
      """
      PostgreSQL implementation of Django's database wrapper.

      Works with whichever driver was imported as ``Database`` at module load
      (psycopg 3 or psycopg2). Django-specific OPTIONS handled here are
      "pool" (psycopg 3 only), "isolation_level", "assume_role" and
      "server_side_binding"; other OPTIONS entries are forwarded to the
      driver as connection parameters.
      """

      vendor = "postgresql"
      display_name = "PostgreSQL"
      # This dictionary maps Field objects to their associated PostgreSQL column
      # types, as strings. Column-type strings can contain format strings; they'll
      # be interpolated against the values of Field.__dict__ before being output.
      # If a column type is set to None, it won't be included in the output.
      data_types = {
          "AutoField": "integer",
          "BigAutoField": "bigint",
          "BinaryField": "bytea",
          "BooleanField": "boolean",
          "CharField": _get_varchar_column,
          "DateField": "date",
          "DateTimeField": "timestamp with time zone",
          "DecimalField": "numeric(%(max_digits)s, %(decimal_places)s)",
          "DurationField": "interval",
          "FileField": "varchar(%(max_length)s)",
          "FilePathField": "varchar(%(max_length)s)",
          "FloatField": "double precision",
          "IntegerField": "integer",
          "BigIntegerField": "bigint",
          "IPAddressField": "inet",
          "GenericIPAddressField": "inet",
          "JSONField": "jsonb",
          "OneToOneField": "integer",
          "PositiveBigIntegerField": "bigint",
          "PositiveIntegerField": "integer",
          "PositiveSmallIntegerField": "smallint",
          "SlugField": "varchar(%(max_length)s)",
          "SmallAutoField": "smallint",
          "SmallIntegerField": "smallint",
          "TextField": "text",
          "TimeField": "time",
          "UUIDField": "uuid",
      }
      data_type_check_constraints = {
          "PositiveBigIntegerField": '"%(column)s" >= 0',
          "PositiveIntegerField": '"%(column)s" >= 0',
          "PositiveSmallIntegerField": '"%(column)s" >= 0',
      }
      data_types_suffix = {
          "AutoField": "GENERATED BY DEFAULT AS IDENTITY",
          "BigAutoField": "GENERATED BY DEFAULT AS IDENTITY",
          "SmallAutoField": "GENERATED BY DEFAULT AS IDENTITY",
      }
      operators = {
          "exact": "= %s",
          "iexact": "= UPPER(%s)",
          "contains": "LIKE %s",
          "icontains": "LIKE UPPER(%s)",
          "regex": "~ %s",
          "iregex": "~* %s",
          "gt": "> %s",
          "gte": ">= %s",
          "lt": "< %s",
          "lte": "<= %s",
          "startswith": "LIKE %s",
          "endswith": "LIKE %s",
          "istartswith": "LIKE UPPER(%s)",
          "iendswith": "LIKE UPPER(%s)",
      }

      # The patterns below are used to generate SQL pattern lookup clauses when
      # the right-hand side of the lookup isn't a raw string (it might be an expression
      # or the result of a bilateral transformation).
      # In those cases, special characters for LIKE operators (e.g. \, *, _) should be
      # escaped on database side.
      #
      # Note: we use str.format() here for readability as '%' is used as a wildcard for
      # the LIKE operator.
      pattern_esc = (
          r"REPLACE(REPLACE(REPLACE({}, E'\\', E'\\\\'), E'%%', E'\\%%'), E'_', E'\\_')"
      )
      pattern_ops = {
          "contains": "LIKE '%%' || {} || '%%'",
          "icontains": "LIKE '%%' || UPPER({}) || '%%'",
          "startswith": "LIKE {} || '%%'",
          "istartswith": "LIKE UPPER({}) || '%%'",
          "endswith": "LIKE '%%' || {}",
          "iendswith": "LIKE '%%' || UPPER({})",
      }

      Database = Database
      SchemaEditorClass = DatabaseSchemaEditor
      # Classes instantiated in __init__().
      client_class = DatabaseClient
      creation_class = DatabaseCreation
      features_class = DatabaseFeatures
      introspection_class = DatabaseIntrospection
      ops_class = DatabaseOperations
      # PostgreSQL backend-specific attributes.
      _named_cursor_idx = 0
      # Class-level mapping of database alias -> ConnectionPool, shared by all
      # wrapper instances.
      _connection_pools = {}

      @property
      def pool(self):
          """
          Return the ConnectionPool for this alias, or None when not pooling.

          None is returned for the NO_DB_ALIAS connection and when
          OPTIONS["pool"] is missing or falsy. Otherwise a psycopg_pool
          ConnectionPool is created on first access (without opening it),
          stored in the class-level _connection_pools mapping, and returned.
          OPTIONS["pool"] may be True (default pool settings) or a dict of
          ConnectionPool keyword arguments.

          Raise ImproperlyConfigured when creating the pool if CONN_MAX_AGE is
          non-zero or if psycopg_pool can't be imported.
          """
          pool_options = self.settings_dict["OPTIONS"].get("pool")
          if self.alias == NO_DB_ALIAS or not pool_options:
              return None

          if self.alias not in self._connection_pools:
              if self.settings_dict.get("CONN_MAX_AGE", 0) != 0:
                  raise ImproperlyConfigured(
                      "Pooling doesn't support persistent connections."
                  )
              # Set the default options.
              if pool_options is True:
                  pool_options = {}

              try:
                  from psycopg_pool import ConnectionPool
              except ImportError as err:
                  raise ImproperlyConfigured(
                      "Error loading psycopg_pool module.\nDid you install psycopg[pool]?"
                  ) from err

              connect_kwargs = self.get_connection_params()
              # Ensure we run in autocommit, Django properly sets it later on.
              connect_kwargs["autocommit"] = True
              enable_checks = self.settings_dict["CONN_HEALTH_CHECKS"]
              pool = ConnectionPool(
                  kwargs=connect_kwargs,
                  open=False,  # Do not open the pool during startup.
                  configure=self._configure_connection,
                  check=ConnectionPool.check_connection if enable_checks else None,
                  **pool_options,
              )
              # setdefault() ensures that multiple threads don't set this in
              # parallel. Since we do not open the pool during its init above,
              # this means that at worst during startup multiple threads generate
              # pool objects and the first to set it wins.
              self._connection_pools.setdefault(self.alias, pool)

          return self._connection_pools[self.alias]

      def close_pool(self):
          """Close this alias's connection pool (if pooling) and forget it."""
          if self.pool:
              self.pool.close()
              del self._connection_pools[self.alias]

      def get_database_version(self):
          """
          Return a tuple of the database's version.
          E.g. for pg_version 120004, return (12, 4).
          """
          return divmod(self.pg_version, 10000)

      def get_connection_params(self):
          """
          Build the keyword arguments passed to the driver's connect().

          Start from NAME and OPTIONS, force UTF8 client encoding, strip the
          Django-only options ("assume_role", "isolation_level", "pool",
          "server_side_binding"), choose the cursor_factory, add USER,
          PASSWORD, HOST and PORT when set and, on psycopg 3, the adapters
          context and prepare_threshold (None unless given in OPTIONS).

          Raise ImproperlyConfigured if NAME is "" and no OPTIONS["service"]
          is given, if NAME exceeds the maximum name length, or if pooling is
          requested without psycopg 3.
          """
          settings_dict = self.settings_dict
          options = settings_dict["OPTIONS"]
          # None may be used to connect to the default 'postgres' db
          if settings_dict["NAME"] == "" and not options.get("service"):
              raise ImproperlyConfigured(
                  "settings.DATABASES is improperly configured. "
                  "Please supply the NAME or OPTIONS['service'] value."
              )
          if len(settings_dict["NAME"] or "") > self.ops.max_name_length():
              raise ImproperlyConfigured(
                  "The database name '%s' (%d characters) is longer than "
                  "PostgreSQL's limit of %d characters. Supply a shorter NAME "
                  "in settings.DATABASES."
                  % (
                      settings_dict["NAME"],
                      len(settings_dict["NAME"]),
                      self.ops.max_name_length(),
                  )
              )
          if settings_dict["NAME"]:
              conn_params = {
                  "dbname": settings_dict["NAME"],
                  **options,
              }
          elif settings_dict["NAME"] is None:
              # Connect to the default 'postgres' db, without the configured
              # service. Filter "service" out of a copy instead of popping it
              # from OPTIONS: settings dicts are shallow-copied between wrappers
              # (e.g. the fallback in _nodb_cursor()), so OPTIONS may be shared
              # and mutating it would also drop "service" from the regular
              # connection's settings.
              conn_params = {
                  "dbname": "postgres",
                  **{key: value for key, value in options.items() if key != "service"},
              }
          else:
              conn_params = {**options}
          conn_params["client_encoding"] = "UTF8"

          # Django-level options that must not be passed to the driver.
          conn_params.pop("assume_role", None)
          conn_params.pop("isolation_level", None)

          pool_options = conn_params.pop("pool", None)
          if pool_options and not is_psycopg3:
              raise ImproperlyConfigured("Database pooling requires psycopg >= 3")

          server_side_binding = conn_params.pop("server_side_binding", None)
          conn_params.setdefault(
              "cursor_factory",
              (
                  ServerBindingCursor
                  if is_psycopg3 and server_side_binding is True
                  else Cursor
              ),
          )
          if settings_dict["USER"]:
              conn_params["user"] = settings_dict["USER"]
          if settings_dict["PASSWORD"]:
              conn_params["password"] = settings_dict["PASSWORD"]
          if settings_dict["HOST"]:
              conn_params["host"] = settings_dict["HOST"]
          if settings_dict["PORT"]:
              conn_params["port"] = settings_dict["PORT"]
          if is_psycopg3:
              conn_params["context"] = get_adapters_template(
                  settings.USE_TZ, self.timezone
              )
              # Disable prepared statements by default to keep connection poolers
              # working. Can be reenabled via OPTIONS in the settings dict.
              conn_params["prepare_threshold"] = conn_params.pop(
                  "prepare_threshold", None
              )
          return conn_params

      @async_unsafe
      def get_new_connection(self, conn_params):
          """
          Return a new driver connection, borrowed from the pool when pooling.

          Also sets self.isolation_level from OPTIONS["isolation_level"]
          (READ COMMITTED when absent) and applies it to the connection when it
          was given explicitly. On psycopg2, registers a pass-through jsonb
          loader on the connection.

          Raise ImproperlyConfigured for an invalid isolation level.
          """
          # self.isolation_level must be set:
          # - after connecting to the database in order to obtain the database's
          #   default when no value is explicitly specified in options.
          # - before calling _set_autocommit() because if autocommit is on, that
          #   will set connection.isolation_level to ISOLATION_LEVEL_AUTOCOMMIT.
          options = self.settings_dict["OPTIONS"]
          set_isolation_level = False
          try:
              isolation_level_value = options["isolation_level"]
          except KeyError:
              self.isolation_level = IsolationLevel.READ_COMMITTED
          else:
              # Set the isolation level to the value from OPTIONS.
              try:
                  self.isolation_level = IsolationLevel(isolation_level_value)
                  set_isolation_level = True
              except ValueError:
                  raise ImproperlyConfigured(
                      f"Invalid transaction isolation level {isolation_level_value} "
                      "specified. Use one of the psycopg.IsolationLevel values."
                  )
          if self.pool:
              # If nothing else has opened the pool, open it now.
              self.pool.open()
              connection = self.pool.getconn()
          else:
              connection = self.Database.connect(**conn_params)
          if set_isolation_level:
              connection.isolation_level = self.isolation_level
          if not is_psycopg3:
              # Register dummy loads() to avoid a round trip from psycopg2's
              # decode to json.dumps() to json.loads(), when using a custom
              # decoder in JSONField.
              psycopg2.extras.register_default_jsonb(
                  conn_or_curs=connection, loads=lambda x: x
              )
          return connection

      def ensure_timezone(self):
          """
          Bring the session time zone in line with self.timezone_name.

          The pool (if any) is closed so new pooled connections are configured
          afresh. Return True if SET TIME ZONE was executed on the open
          connection, False otherwise (including when there is no connection).
          """
          # Close the pool so new connections pick up the correct timezone.
          self.close_pool()
          if self.connection is None:
              return False
          return self._configure_timezone(self.connection)

      def _configure_timezone(self, connection):
          """
          Run SET TIME ZONE on ``connection`` if its TimeZone parameter differs
          from self.timezone_name; return whether the statement was executed.
          """
          conn_timezone_name = connection.info.parameter_status("TimeZone")
          timezone_name = self.timezone_name
          if timezone_name and conn_timezone_name != timezone_name:
              with connection.cursor() as cursor:
                  cursor.execute(self.ops.set_time_zone_sql(), [timezone_name])
              return True
          return False

      def _configure_role(self, connection):
          """
          Run SET ROLE on ``connection`` when OPTIONS["assume_role"] is set;
          return whether the statement was executed.
          """
          if new_role := self.settings_dict["OPTIONS"].get("assume_role"):
              with connection.cursor() as cursor:
                  # Named ``query`` to avoid shadowing the module-level ``sql``.
                  query = self.ops.compose_sql("SET ROLE %s", [new_role])
                  cursor.execute(query)
              return True
          return False

      def _configure_connection(self, connection):
          """
          Apply per-session setup (time zone, then role) to ``connection``.

          Return True if any statement was executed, i.e. the caller should
          commit when not in autocommit mode.
          """
          # This function is called from init_connection_state and from the
          # psycopg pool itself after a connection is opened.
          # Commit after setting the time zone.
          commit_tz = self._configure_timezone(connection)
          # Set the role on the connection. This is useful if the credential used
          # to login is not the same as the role that owns database resources. As
          # can be the case when using temporary or ephemeral credentials.
          commit_role = self._configure_role(connection)
          return commit_role or commit_tz

      def _close(self):
          """
          Release the current connection, if any.

          Pooled connections are handed back to the pool they came from and
          self.connection is reset to None; otherwise the connection is
          closed. Driver errors are wrapped by wrap_database_errors.
          """
          if self.connection is not None:
              # `wrap_database_errors` only works for `putconn` as long as there
              # is no `reset` function set in the pool because it is deferred
              # into a thread and not directly executed.
              with self.wrap_database_errors:
                  if self.pool:
                      # Ensure the correct pool is returned. This is a workaround
                      # for tests so a pool can be changed on setting changes
                      # (e.g. USE_TZ, TIME_ZONE).
                      self.connection._pool.putconn(self.connection)
                      # Connection can no longer be used.
                      self.connection = None
                  else:
                      return self.connection.close()

      def init_connection_state(self):
          """
          Initialize a freshly opened connection.

          For non-pooled connections, apply time zone/role setup here (pooled
          connections get it from the pool's ``configure`` callback) and
          commit it when not in autocommit mode.
          """
          super().init_connection_state()
          if self.connection is not None and not self.pool:
              commit = self._configure_connection(self.connection)
              if commit and not self.get_autocommit():
                  self.connection.commit()

      @async_unsafe
      def create_cursor(self, name=None):
          """
          Create a driver cursor; a truthy ``name`` creates a named cursor.

          Named cursors are non-scrollable and use ``withhold`` equal to the
          connection's autocommit flag. Time zone handling is also set up:
          register_tzloader() on psycopg 3 (only when the connection's loader
          disagrees with self.timezone), tzinfo_factory on psycopg2.
          """
          if name:
              if is_psycopg3 and (
                  self.settings_dict["OPTIONS"].get("server_side_binding") is not True
              ):
                  # psycopg >= 3 forces the usage of server-side bindings for
                  # named cursors so a specialized class that implements
                  # server-side cursors while performing client-side bindings
                  # must be used if `server_side_binding` is disabled (default).
                  cursor = ServerSideCursor(
                      self.connection,
                      name=name,
                      scrollable=False,
                      withhold=self.connection.autocommit,
                  )
              else:
                  # In autocommit mode, the cursor will be used outside of a
                  # transaction, hence use a holdable cursor.
                  cursor = self.connection.cursor(
                      name, scrollable=False, withhold=self.connection.autocommit
                  )
          else:
              cursor = self.connection.cursor()

          if is_psycopg3:
              # Register the cursor timezone only if the connection disagrees, to
              # avoid copying the adapter map.
              tzloader = self.connection.adapters.get_loader(TIMESTAMPTZ_OID, Format.TEXT)
              if self.timezone != tzloader.timezone:
                  register_tzloader(self.timezone, cursor)
          else:
              cursor.tzinfo_factory = self.tzinfo_factory if settings.USE_TZ else None
          return cursor

      def tzinfo_factory(self, offset):
          """Return self.timezone regardless of ``offset`` (psycopg2 hook)."""
          return self.timezone

      @async_unsafe
      def chunked_cursor(self):
          """
          Return a named cursor for iterating a result set in chunks.

          The cursor name combines the thread ident, the current asyncio task
          id (or "sync") and this wrapper's incrementing counter so names are
          not reused across threads or tasks.
          """
          self._named_cursor_idx += 1
          # Get the current async task
          # Note that right now this is behind @async_unsafe, so this is
          # unreachable, but in future we'll start loosening this restriction.
          # For now, it's here so that every use of "threading" is
          # also async-compatible.
          try:
              current_task = asyncio.current_task()
          except RuntimeError:
              current_task = None
          # Current task can be none even if the current_task call didn't error
          if current_task:
              task_ident = str(id(current_task))
          else:
              task_ident = "sync"
          # Use that and the thread ident to get a unique name
          return self._cursor(
              name="_django_curs_%d_%s_%d"
              % (
                  # Avoid reusing name in other threads / tasks
                  threading.current_thread().ident,
                  task_ident,
                  self._named_cursor_idx,
              )
          )

      def _set_autocommit(self, autocommit):
          """Set the driver connection's autocommit flag, wrapping driver errors."""
          with self.wrap_database_errors:
              self.connection.autocommit = autocommit

      def check_constraints(self, table_names=None):
          """
          Check constraints by setting them to immediate. Return them to deferred
          afterward.

          ``table_names`` is accepted for API compatibility but ignored: all
          constraints are checked.
          """
          with self.cursor() as cursor:
              cursor.execute("SET CONSTRAINTS ALL IMMEDIATE")
              cursor.execute("SET CONSTRAINTS ALL DEFERRED")

      def is_usable(self):
          """Return True if the connection exists and can run ``SELECT 1``."""
          if self.connection is None:
              return False
          try:
              # Use a psycopg cursor directly, bypassing Django's utilities.
              with self.connection.cursor() as cursor:
                  cursor.execute("SELECT 1")
          except Database.Error:
              return False
          else:
              return True

      def close_if_health_check_failed(self):
          """Skip Django's health check for pooled connections."""
          if self.pool:
              # The pool only returns healthy connections.
              return
          return super().close_if_health_check_failed()

      @contextmanager
      def _nodb_cursor(self):
          """
          Yield a cursor that isn't tied to this wrapper's NAME database.

          Try the base implementation first (normally a connection to the
          'postgres' database, per the warning below). If it fails before
          yielding a cursor, warn and fall back to the first configured
          PostgreSQL database whose NAME isn't 'postgres', re-raising the
          original error when there is none. Errors raised after a cursor was
          yielded propagate unchanged.
          """
          cursor = None
          try:
              with super()._nodb_cursor() as cursor:
                  yield cursor
          except (Database.DatabaseError, WrappedDatabaseError):
              if cursor is not None:
                  raise
              warnings.warn(
                  "Normally Django will use a connection to the 'postgres' database "
                  "to avoid running initialization queries against the production "
                  "database when it's not needed (for example, when running tests). "
                  "Django was unable to create a connection to the 'postgres' database "
                  "and will use the first PostgreSQL database instead.",
                  RuntimeWarning,
              )
              for connection in connections.all():
                  if (
                      connection.vendor == "postgresql"
                      and connection.settings_dict["NAME"] != "postgres"
                  ):
                      conn = self.__class__(
                          {
                              **self.settings_dict,
                              "NAME": connection.settings_dict["NAME"],
                          },
                          alias=self.alias,
                      )
                      try:
                          with conn.cursor() as cursor:
                              yield cursor
                      finally:
                          conn.close()
                      break
              else:
                  raise

      @cached_property
      def pg_version(self):
          """Server version as an integer (e.g. 120004 for 12.4), cached."""
          with self.temporary_connection():
              return self.connection.info.server_version

      def make_debug_cursor(self, cursor):
          """Wrap ``cursor`` in this backend's CursorDebugWrapper."""
          return CursorDebugWrapper(cursor, self)
  # Driver-specific cursor classes and debug wrappers.
  if is_psycopg3:

      class CursorMixin:
          """
          Mixin for psycopg cursors that adds a DB-API style callproc().
          """

          def callproc(self, name, args=None):
              """
              Execute ``SELECT * FROM name(arg1,arg2,...)`` and return ``args``.

              ``name`` may be a plain string or an ``sql.Identifier``; each item
              of ``args`` is rendered as a client-side ``sql.Literal``. ``args``
              may be None or empty for a call without arguments.
              """
              if not isinstance(name, sql.Identifier):
                  name = sql.Identifier(name)
              # Join the literals rather than appending a trailing comma and
              # deleting it afterwards: that approach removed the opening
              # parenthesis when ``args`` was truthy but yielded no items (e.g.
              # an empty generator), producing malformed SQL.
              arg_list = sql.SQL(",").join(sql.Literal(item) for item in (args or ()))
              stmt = sql.SQL("SELECT * FROM {}({})").format(name, arg_list)
              self.execute(stmt)
              return args

      class ServerBindingCursor(CursorMixin, Database.Cursor):
          """
          Cursor with server-side parameter binding; used as cursor_factory
          when OPTIONS["server_side_binding"] is True.
          """

      class Cursor(CursorMixin, Database.ClientCursor):
          """
          Cursor with client-side parameter binding; the default cursor_factory.
          """

      class ServerSideCursor(
          CursorMixin, Database.client_cursor.ClientCursorMixin, Database.ServerCursor
      ):
          """
          psycopg >= 3 forces the usage of server-side bindings when using named
          cursors but the ORM doesn't yet support the systematic generation of
          prepareable SQL (#20516).

          ClientCursorMixin forces the usage of client-side bindings while
          ServerCursor implements the logic required to declare and scroll
          through named cursors.

          Mixing ClientCursorMixin in wouldn't be necessary if Cursor allowed to
          specify how parameters should be bound instead, which ServerCursor
          would inherit, but that's not the case.
          """

      class CursorDebugWrapper(BaseCursorDebugWrapper):
          """Debug cursor wrapper that also records psycopg 3 COPY statements."""

          def copy(self, statement):
              with self.debug_sql(statement):
                  return self.cursor.copy(statement)

  else:
      Cursor = psycopg2.extensions.cursor

      class CursorDebugWrapper(BaseCursorDebugWrapper):
          """Debug cursor wrapper that also records psycopg2 COPY operations."""

          def copy_expert(self, sql, file, *args):
              with self.debug_sql(sql):
                  return self.cursor.copy_expert(sql, file, *args)

          def copy_to(self, file, table, *args, **kwargs):
              # Record an equivalent COPY statement for the debug log.
              with self.debug_sql(sql="COPY %s TO STDOUT" % table):
                  return self.cursor.copy_to(file, table, *args, **kwargs)
  538. def copy_to(self, file, table, *args, **kwargs):
  539. with self.debug_sql(sql="COPY %s TO STDOUT" % table):
  540. return self.cursor.copy_to(file, table, *args, **kwargs)