constraints.py 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232
  1. from types import NoneType
  2. from django.core.exceptions import ValidationError
  3. from django.db import DEFAULT_DB_ALIAS, NotSupportedError
  4. from django.db.backends.ddl_references import Expressions, Statement, Table
  5. from django.db.models import BaseConstraint, Deferrable, F, Q
  6. from django.db.models.expressions import Exists, ExpressionList
  7. from django.db.models.indexes import IndexExpression
  8. from django.db.models.lookups import PostgresOperatorLookup
  9. from django.db.models.sql import Query
# Public API of this module: only ExclusionConstraint is exported.
# ExclusionConstraintExpression is not listed, so it is not picked up by
# ``from ... import *``.
__all__ = ["ExclusionConstraint"]
  11. class ExclusionConstraintExpression(IndexExpression):
  12. template = "%(expressions)s WITH %(operator)s"
  13. class ExclusionConstraint(BaseConstraint):
  14. template = (
  15. "CONSTRAINT %(name)s EXCLUDE USING %(index_type)s "
  16. "(%(expressions)s)%(include)s%(where)s%(deferrable)s"
  17. )
  18. def __init__(
  19. self,
  20. *,
  21. name,
  22. expressions,
  23. index_type=None,
  24. condition=None,
  25. deferrable=None,
  26. include=None,
  27. violation_error_code=None,
  28. violation_error_message=None,
  29. ):
  30. if index_type and index_type.lower() not in {"gist", "spgist"}:
  31. raise ValueError(
  32. "Exclusion constraints only support GiST or SP-GiST indexes."
  33. )
  34. if not expressions:
  35. raise ValueError(
  36. "At least one expression is required to define an exclusion "
  37. "constraint."
  38. )
  39. if not all(
  40. isinstance(expr, (list, tuple)) and len(expr) == 2 for expr in expressions
  41. ):
  42. raise ValueError("The expressions must be a list of 2-tuples.")
  43. if not isinstance(condition, (NoneType, Q)):
  44. raise ValueError("ExclusionConstraint.condition must be a Q instance.")
  45. if not isinstance(deferrable, (NoneType, Deferrable)):
  46. raise ValueError(
  47. "ExclusionConstraint.deferrable must be a Deferrable instance."
  48. )
  49. if not isinstance(include, (NoneType, list, tuple)):
  50. raise ValueError("ExclusionConstraint.include must be a list or tuple.")
  51. self.expressions = expressions
  52. self.index_type = index_type or "GIST"
  53. self.condition = condition
  54. self.deferrable = deferrable
  55. self.include = tuple(include) if include else ()
  56. super().__init__(
  57. name=name,
  58. violation_error_code=violation_error_code,
  59. violation_error_message=violation_error_message,
  60. )
  61. def _get_expressions(self, schema_editor, query):
  62. expressions = []
  63. for idx, (expression, operator) in enumerate(self.expressions):
  64. if isinstance(expression, str):
  65. expression = F(expression)
  66. expression = ExclusionConstraintExpression(expression, operator=operator)
  67. expression.set_wrapper_classes(schema_editor.connection)
  68. expressions.append(expression)
  69. return ExpressionList(*expressions).resolve_expression(query)
  70. def _check(self, model, connection):
  71. references = set()
  72. for expr, _ in self.expressions:
  73. if isinstance(expr, str):
  74. expr = F(expr)
  75. references.update(model._get_expr_references(expr))
  76. return self._check_references(model, references)
  77. def _get_condition_sql(self, compiler, schema_editor, query):
  78. if self.condition is None:
  79. return None
  80. where = query.build_where(self.condition)
  81. sql, params = where.as_sql(compiler, schema_editor.connection)
  82. return sql % tuple(schema_editor.quote_value(p) for p in params)
  83. def constraint_sql(self, model, schema_editor):
  84. query = Query(model, alias_cols=False)
  85. compiler = query.get_compiler(connection=schema_editor.connection)
  86. expressions = self._get_expressions(schema_editor, query)
  87. table = model._meta.db_table
  88. condition = self._get_condition_sql(compiler, schema_editor, query)
  89. include = [
  90. model._meta.get_field(field_name).column for field_name in self.include
  91. ]
  92. return Statement(
  93. self.template,
  94. table=Table(table, schema_editor.quote_name),
  95. name=schema_editor.quote_name(self.name),
  96. index_type=self.index_type,
  97. expressions=Expressions(
  98. table, expressions, compiler, schema_editor.quote_value
  99. ),
  100. where=" WHERE (%s)" % condition if condition else "",
  101. include=schema_editor._index_include_sql(model, include),
  102. deferrable=schema_editor._deferrable_constraint_sql(self.deferrable),
  103. )
  104. def create_sql(self, model, schema_editor):
  105. self.check_supported(schema_editor)
  106. return Statement(
  107. "ALTER TABLE %(table)s ADD %(constraint)s",
  108. table=Table(model._meta.db_table, schema_editor.quote_name),
  109. constraint=self.constraint_sql(model, schema_editor),
  110. )
  111. def remove_sql(self, model, schema_editor):
  112. return schema_editor._delete_constraint_sql(
  113. schema_editor.sql_delete_check,
  114. model,
  115. schema_editor.quote_name(self.name),
  116. )
  117. def check_supported(self, schema_editor):
  118. if (
  119. self.include
  120. and self.index_type.lower() == "spgist"
  121. and not schema_editor.connection.features.supports_covering_spgist_indexes
  122. ):
  123. raise NotSupportedError(
  124. "Covering exclusion constraints using an SP-GiST index "
  125. "require PostgreSQL 14+."
  126. )
  127. def deconstruct(self):
  128. path, args, kwargs = super().deconstruct()
  129. kwargs["expressions"] = self.expressions
  130. if self.condition is not None:
  131. kwargs["condition"] = self.condition
  132. if self.index_type.lower() != "gist":
  133. kwargs["index_type"] = self.index_type
  134. if self.deferrable:
  135. kwargs["deferrable"] = self.deferrable
  136. if self.include:
  137. kwargs["include"] = self.include
  138. return path, args, kwargs
  139. def __eq__(self, other):
  140. if isinstance(other, self.__class__):
  141. return (
  142. self.name == other.name
  143. and self.index_type == other.index_type
  144. and self.expressions == other.expressions
  145. and self.condition == other.condition
  146. and self.deferrable == other.deferrable
  147. and self.include == other.include
  148. and self.violation_error_code == other.violation_error_code
  149. and self.violation_error_message == other.violation_error_message
  150. )
  151. return super().__eq__(other)
  152. def __repr__(self):
  153. return "<%s: index_type=%s expressions=%s name=%s%s%s%s%s%s>" % (
  154. self.__class__.__qualname__,
  155. repr(self.index_type),
  156. repr(self.expressions),
  157. repr(self.name),
  158. "" if self.condition is None else " condition=%s" % self.condition,
  159. "" if self.deferrable is None else " deferrable=%r" % self.deferrable,
  160. "" if not self.include else " include=%s" % repr(self.include),
  161. (
  162. ""
  163. if self.violation_error_code is None
  164. else " violation_error_code=%r" % self.violation_error_code
  165. ),
  166. (
  167. ""
  168. if self.violation_error_message is None
  169. or self.violation_error_message == self.default_violation_error_message
  170. else " violation_error_message=%r" % self.violation_error_message
  171. ),
  172. )
  173. def validate(self, model, instance, exclude=None, using=DEFAULT_DB_ALIAS):
  174. queryset = model._default_manager.using(using)
  175. replacement_map = instance._get_field_expression_map(
  176. meta=model._meta, exclude=exclude
  177. )
  178. replacements = {F(field): value for field, value in replacement_map.items()}
  179. lookups = []
  180. for idx, (expression, operator) in enumerate(self.expressions):
  181. if isinstance(expression, str):
  182. expression = F(expression)
  183. if exclude:
  184. if isinstance(expression, F):
  185. if expression.name in exclude:
  186. return
  187. else:
  188. for expr in expression.flatten():
  189. if isinstance(expr, F) and expr.name in exclude:
  190. return
  191. rhs_expression = expression.replace_expressions(replacements)
  192. if hasattr(expression, "get_expression_for_validation"):
  193. expression = expression.get_expression_for_validation()
  194. if hasattr(rhs_expression, "get_expression_for_validation"):
  195. rhs_expression = rhs_expression.get_expression_for_validation()
  196. lookup = PostgresOperatorLookup(lhs=expression, rhs=rhs_expression)
  197. lookup.postgres_operator = operator
  198. lookups.append(lookup)
  199. queryset = queryset.filter(*lookups)
  200. model_class_pk = instance._get_pk_val(model._meta)
  201. if not instance._state.adding and model_class_pk is not None:
  202. queryset = queryset.exclude(pk=model_class_pk)
  203. if not self.condition:
  204. if queryset.exists():
  205. raise ValidationError(
  206. self.get_violation_error_message(), code=self.violation_error_code
  207. )
  208. else:
  209. if (self.condition & Exists(queryset.filter(self.condition))).check(
  210. replacement_map, using=using
  211. ):
  212. raise ValidationError(
  213. self.get_violation_error_message(), code=self.violation_error_code
  214. )