generated.py 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197
  1. from django.core import checks
  2. from django.db import connections, router
  3. from django.db.models.sql import Query
  4. from django.utils.functional import cached_property
  5. from . import NOT_PROVIDED, Field
  6. __all__ = ["GeneratedField"]
  7. class GeneratedField(Field):
  8. generated = True
  9. db_returning = True
  10. _query = None
  11. output_field = None
  12. def __init__(self, *, expression, output_field, db_persist=None, **kwargs):
  13. if kwargs.setdefault("editable", False):
  14. raise ValueError("GeneratedField cannot be editable.")
  15. if not kwargs.setdefault("blank", True):
  16. raise ValueError("GeneratedField must be blank.")
  17. if kwargs.get("default", NOT_PROVIDED) is not NOT_PROVIDED:
  18. raise ValueError("GeneratedField cannot have a default.")
  19. if kwargs.get("db_default", NOT_PROVIDED) is not NOT_PROVIDED:
  20. raise ValueError("GeneratedField cannot have a database default.")
  21. if db_persist not in (True, False):
  22. raise ValueError("GeneratedField.db_persist must be True or False.")
  23. self.expression = expression
  24. self.output_field = output_field
  25. self.db_persist = db_persist
  26. super().__init__(**kwargs)
  27. @cached_property
  28. def cached_col(self):
  29. from django.db.models.expressions import Col
  30. return Col(self.model._meta.db_table, self, self.output_field)
  31. def get_col(self, alias, output_field=None):
  32. if alias != self.model._meta.db_table and output_field in (None, self):
  33. output_field = self.output_field
  34. return super().get_col(alias, output_field)
  35. def contribute_to_class(self, *args, **kwargs):
  36. super().contribute_to_class(*args, **kwargs)
  37. self._query = Query(model=self.model, alias_cols=False)
  38. # Register lookups from the output_field class.
  39. for lookup_name, lookup in self.output_field.get_class_lookups().items():
  40. self.register_lookup(lookup, lookup_name=lookup_name)
  41. def generated_sql(self, connection):
  42. compiler = connection.ops.compiler("SQLCompiler")(
  43. self._query, connection=connection, using=None
  44. )
  45. resolved_expression = self.expression.resolve_expression(
  46. self._query, allow_joins=False
  47. )
  48. sql, params = compiler.compile(resolved_expression)
  49. if (
  50. getattr(self.expression, "conditional", False)
  51. and not connection.features.supports_boolean_expr_in_select_clause
  52. ):
  53. sql = f"CASE WHEN {sql} THEN 1 ELSE 0 END"
  54. return sql, params
  55. def check(self, **kwargs):
  56. databases = kwargs.get("databases") or []
  57. errors = [
  58. *super().check(**kwargs),
  59. *self._check_supported(databases),
  60. *self._check_persistence(databases),
  61. ]
  62. output_field_clone = self.output_field.clone()
  63. output_field_clone.model = self.model
  64. output_field_checks = output_field_clone.check(databases=databases)
  65. if output_field_checks:
  66. separator = "\n "
  67. error_messages = separator.join(
  68. f"{output_check.msg} ({output_check.id})"
  69. for output_check in output_field_checks
  70. if isinstance(output_check, checks.Error)
  71. )
  72. if error_messages:
  73. errors.append(
  74. checks.Error(
  75. "GeneratedField.output_field has errors:"
  76. f"{separator}{error_messages}",
  77. obj=self,
  78. id="fields.E223",
  79. )
  80. )
  81. warning_messages = separator.join(
  82. f"{output_check.msg} ({output_check.id})"
  83. for output_check in output_field_checks
  84. if isinstance(output_check, checks.Warning)
  85. )
  86. if warning_messages:
  87. errors.append(
  88. checks.Warning(
  89. "GeneratedField.output_field has warnings:"
  90. f"{separator}{warning_messages}",
  91. obj=self,
  92. id="fields.W224",
  93. )
  94. )
  95. return errors
  96. def _check_supported(self, databases):
  97. errors = []
  98. for db in databases:
  99. if not router.allow_migrate_model(db, self.model):
  100. continue
  101. connection = connections[db]
  102. if (
  103. self.model._meta.required_db_vendor
  104. and self.model._meta.required_db_vendor != connection.vendor
  105. ):
  106. continue
  107. if not (
  108. connection.features.supports_virtual_generated_columns
  109. or "supports_stored_generated_columns"
  110. in self.model._meta.required_db_features
  111. ) and not (
  112. connection.features.supports_stored_generated_columns
  113. or "supports_virtual_generated_columns"
  114. in self.model._meta.required_db_features
  115. ):
  116. errors.append(
  117. checks.Error(
  118. f"{connection.display_name} does not support GeneratedFields.",
  119. obj=self,
  120. id="fields.E220",
  121. )
  122. )
  123. return errors
  124. def _check_persistence(self, databases):
  125. errors = []
  126. for db in databases:
  127. if not router.allow_migrate_model(db, self.model):
  128. continue
  129. connection = connections[db]
  130. if (
  131. self.model._meta.required_db_vendor
  132. and self.model._meta.required_db_vendor != connection.vendor
  133. ):
  134. continue
  135. if not self.db_persist and not (
  136. connection.features.supports_virtual_generated_columns
  137. or "supports_virtual_generated_columns"
  138. in self.model._meta.required_db_features
  139. ):
  140. errors.append(
  141. checks.Error(
  142. f"{connection.display_name} does not support non-persisted "
  143. "GeneratedFields.",
  144. obj=self,
  145. id="fields.E221",
  146. hint="Set db_persist=True on the field.",
  147. )
  148. )
  149. if self.db_persist and not (
  150. connection.features.supports_stored_generated_columns
  151. or "supports_stored_generated_columns"
  152. in self.model._meta.required_db_features
  153. ):
  154. errors.append(
  155. checks.Error(
  156. f"{connection.display_name} does not support persisted "
  157. "GeneratedFields.",
  158. obj=self,
  159. id="fields.E222",
  160. hint="Set db_persist=False on the field.",
  161. )
  162. )
  163. return errors
  164. def deconstruct(self):
  165. name, path, args, kwargs = super().deconstruct()
  166. del kwargs["blank"]
  167. del kwargs["editable"]
  168. kwargs["db_persist"] = self.db_persist
  169. kwargs["expression"] = self.expression
  170. kwargs["output_field"] = self.output_field
  171. return name, path, args, kwargs
  172. def get_internal_type(self):
  173. return self.output_field.get_internal_type()
  174. def db_parameters(self, connection):
  175. return self.output_field.db_parameters(connection)
  176. def db_type_parameters(self, connection):
  177. return self.output_field.db_type_parameters(connection)