array.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373
  1. import json
  2. from django.contrib.postgres import lookups
  3. from django.contrib.postgres.forms import SimpleArrayField
  4. from django.contrib.postgres.validators import ArrayMaxLengthValidator
  5. from django.core import checks, exceptions
  6. from django.db.models import Field, Func, IntegerField, Transform, Value
  7. from django.db.models.fields.mixins import CheckFieldDefaultMixin
  8. from django.db.models.lookups import Exact, In
  9. from django.utils.translation import gettext_lazy as _
  10. from ..utils import prefix_validation_error
  11. from .utils import AttributeSetter
# Names exported by ``from <this module> import *``.
__all__ = ["ArrayField"]
  13. class ArrayField(CheckFieldDefaultMixin, Field):
  14. empty_strings_allowed = False
  15. default_error_messages = {
  16. "item_invalid": _("Item %(nth)s in the array did not validate:"),
  17. "nested_array_mismatch": _("Nested arrays must have the same length."),
  18. }
  19. _default_hint = ("list", "[]")
  20. def __init__(self, base_field, size=None, **kwargs):
  21. self.base_field = base_field
  22. self.db_collation = getattr(self.base_field, "db_collation", None)
  23. self.size = size
  24. if self.size:
  25. self.default_validators = [
  26. *self.default_validators,
  27. ArrayMaxLengthValidator(self.size),
  28. ]
  29. # For performance, only add a from_db_value() method if the base field
  30. # implements it.
  31. if hasattr(self.base_field, "from_db_value"):
  32. self.from_db_value = self._from_db_value
  33. super().__init__(**kwargs)
  34. @property
  35. def model(self):
  36. try:
  37. return self.__dict__["model"]
  38. except KeyError:
  39. raise AttributeError(
  40. "'%s' object has no attribute 'model'" % self.__class__.__name__
  41. )
  42. @model.setter
  43. def model(self, model):
  44. self.__dict__["model"] = model
  45. self.base_field.model = model
  46. @classmethod
  47. def _choices_is_value(cls, value):
  48. return isinstance(value, (list, tuple)) or super()._choices_is_value(value)
  49. def check(self, **kwargs):
  50. errors = super().check(**kwargs)
  51. if self.base_field.remote_field:
  52. errors.append(
  53. checks.Error(
  54. "Base field for array cannot be a related field.",
  55. obj=self,
  56. id="postgres.E002",
  57. )
  58. )
  59. else:
  60. # Remove the field name checks as they are not needed here.
  61. base_checks = self.base_field.check()
  62. if base_checks:
  63. error_messages = "\n ".join(
  64. "%s (%s)" % (base_check.msg, base_check.id)
  65. for base_check in base_checks
  66. if isinstance(base_check, checks.Error)
  67. )
  68. if error_messages:
  69. errors.append(
  70. checks.Error(
  71. "Base field for array has errors:\n %s" % error_messages,
  72. obj=self,
  73. id="postgres.E001",
  74. )
  75. )
  76. warning_messages = "\n ".join(
  77. "%s (%s)" % (base_check.msg, base_check.id)
  78. for base_check in base_checks
  79. if isinstance(base_check, checks.Warning)
  80. )
  81. if warning_messages:
  82. errors.append(
  83. checks.Warning(
  84. "Base field for array has warnings:\n %s"
  85. % warning_messages,
  86. obj=self,
  87. id="postgres.W004",
  88. )
  89. )
  90. return errors
  91. def set_attributes_from_name(self, name):
  92. super().set_attributes_from_name(name)
  93. self.base_field.set_attributes_from_name(name)
  94. @property
  95. def description(self):
  96. return "Array of %s" % self.base_field.description
  97. def db_type(self, connection):
  98. size = self.size or ""
  99. return "%s[%s]" % (self.base_field.db_type(connection), size)
  100. def cast_db_type(self, connection):
  101. size = self.size or ""
  102. return "%s[%s]" % (self.base_field.cast_db_type(connection), size)
  103. def db_parameters(self, connection):
  104. db_params = super().db_parameters(connection)
  105. db_params["collation"] = self.db_collation
  106. return db_params
  107. def get_placeholder(self, value, compiler, connection):
  108. return "%s::{}".format(self.db_type(connection))
  109. def get_db_prep_value(self, value, connection, prepared=False):
  110. if isinstance(value, (list, tuple)):
  111. return [
  112. self.base_field.get_db_prep_value(i, connection, prepared=False)
  113. for i in value
  114. ]
  115. return value
  116. def deconstruct(self):
  117. name, path, args, kwargs = super().deconstruct()
  118. if path == "django.contrib.postgres.fields.array.ArrayField":
  119. path = "django.contrib.postgres.fields.ArrayField"
  120. kwargs.update(
  121. {
  122. "base_field": self.base_field.clone(),
  123. "size": self.size,
  124. }
  125. )
  126. return name, path, args, kwargs
  127. def to_python(self, value):
  128. if isinstance(value, str):
  129. # Assume we're deserializing
  130. vals = json.loads(value)
  131. value = [self.base_field.to_python(val) for val in vals]
  132. return value
  133. def _from_db_value(self, value, expression, connection):
  134. if value is None:
  135. return value
  136. return [
  137. self.base_field.from_db_value(item, expression, connection)
  138. for item in value
  139. ]
  140. def value_to_string(self, obj):
  141. values = []
  142. vals = self.value_from_object(obj)
  143. base_field = self.base_field
  144. for val in vals:
  145. if val is None:
  146. values.append(None)
  147. else:
  148. obj = AttributeSetter(base_field.attname, val)
  149. values.append(base_field.value_to_string(obj))
  150. return json.dumps(values)
  151. def get_transform(self, name):
  152. transform = super().get_transform(name)
  153. if transform:
  154. return transform
  155. if "_" not in name:
  156. try:
  157. index = int(name)
  158. except ValueError:
  159. pass
  160. else:
  161. index += 1 # postgres uses 1-indexing
  162. return IndexTransformFactory(index, self.base_field)
  163. try:
  164. start, end = name.split("_")
  165. start = int(start) + 1
  166. end = int(end) # don't add one here because postgres slices are weird
  167. except ValueError:
  168. pass
  169. else:
  170. return SliceTransformFactory(start, end)
  171. def validate(self, value, model_instance):
  172. super().validate(value, model_instance)
  173. for index, part in enumerate(value):
  174. try:
  175. self.base_field.validate(part, model_instance)
  176. except exceptions.ValidationError as error:
  177. raise prefix_validation_error(
  178. error,
  179. prefix=self.error_messages["item_invalid"],
  180. code="item_invalid",
  181. params={"nth": index + 1},
  182. )
  183. if isinstance(self.base_field, ArrayField):
  184. if len({len(i) for i in value}) > 1:
  185. raise exceptions.ValidationError(
  186. self.error_messages["nested_array_mismatch"],
  187. code="nested_array_mismatch",
  188. )
  189. def run_validators(self, value):
  190. super().run_validators(value)
  191. for index, part in enumerate(value):
  192. try:
  193. self.base_field.run_validators(part)
  194. except exceptions.ValidationError as error:
  195. raise prefix_validation_error(
  196. error,
  197. prefix=self.error_messages["item_invalid"],
  198. code="item_invalid",
  199. params={"nth": index + 1},
  200. )
  201. def formfield(self, **kwargs):
  202. return super().formfield(
  203. **{
  204. "form_class": SimpleArrayField,
  205. "base_field": self.base_field.formfield(),
  206. "max_length": self.size,
  207. **kwargs,
  208. }
  209. )
  210. def slice_expression(self, expression, start, length):
  211. # If length is not provided, don't specify an end to slice to the end
  212. # of the array.
  213. end = None if length is None else start + length - 1
  214. return SliceTransform(start, end, expression)
  215. class ArrayRHSMixin:
  216. def __init__(self, lhs, rhs):
  217. # Don't wrap arrays that contains only None values, psycopg doesn't
  218. # allow this.
  219. if isinstance(rhs, (tuple, list)) and any(self._rhs_not_none_values(rhs)):
  220. expressions = []
  221. for value in rhs:
  222. if not hasattr(value, "resolve_expression"):
  223. field = lhs.output_field
  224. value = Value(field.base_field.get_prep_value(value))
  225. expressions.append(value)
  226. rhs = Func(
  227. *expressions,
  228. function="ARRAY",
  229. template="%(function)s[%(expressions)s]",
  230. )
  231. super().__init__(lhs, rhs)
  232. def process_rhs(self, compiler, connection):
  233. rhs, rhs_params = super().process_rhs(compiler, connection)
  234. cast_type = self.lhs.output_field.cast_db_type(connection)
  235. return "%s::%s" % (rhs, cast_type), rhs_params
  236. def _rhs_not_none_values(self, rhs):
  237. for x in rhs:
  238. if isinstance(x, (list, tuple)):
  239. yield from self._rhs_not_none_values(x)
  240. elif x is not None:
  241. yield True
  242. @ArrayField.register_lookup
  243. class ArrayContains(ArrayRHSMixin, lookups.DataContains):
  244. pass
  245. @ArrayField.register_lookup
  246. class ArrayContainedBy(ArrayRHSMixin, lookups.ContainedBy):
  247. pass
  248. @ArrayField.register_lookup
  249. class ArrayExact(ArrayRHSMixin, Exact):
  250. pass
  251. @ArrayField.register_lookup
  252. class ArrayOverlap(ArrayRHSMixin, lookups.Overlap):
  253. pass
  254. @ArrayField.register_lookup
  255. class ArrayLenTransform(Transform):
  256. lookup_name = "len"
  257. output_field = IntegerField()
  258. def as_sql(self, compiler, connection):
  259. lhs, params = compiler.compile(self.lhs)
  260. # Distinguish NULL and empty arrays
  261. return (
  262. "CASE WHEN %(lhs)s IS NULL THEN NULL ELSE "
  263. "coalesce(array_length(%(lhs)s, 1), 0) END"
  264. ) % {"lhs": lhs}, params * 2
  265. @ArrayField.register_lookup
  266. class ArrayInLookup(In):
  267. def get_prep_lookup(self):
  268. values = super().get_prep_lookup()
  269. if hasattr(values, "resolve_expression"):
  270. return values
  271. # In.process_rhs() expects values to be hashable, so convert lists
  272. # to tuples.
  273. prepared_values = []
  274. for value in values:
  275. if hasattr(value, "resolve_expression"):
  276. prepared_values.append(value)
  277. else:
  278. prepared_values.append(tuple(value))
  279. return prepared_values
  280. class IndexTransform(Transform):
  281. def __init__(self, index, base_field, *args, **kwargs):
  282. super().__init__(*args, **kwargs)
  283. self.index = index
  284. self.base_field = base_field
  285. def as_sql(self, compiler, connection):
  286. lhs, params = compiler.compile(self.lhs)
  287. if not lhs.endswith("]"):
  288. lhs = "(%s)" % lhs
  289. return "%s[%%s]" % lhs, (*params, self.index)
  290. @property
  291. def output_field(self):
  292. return self.base_field
  293. class IndexTransformFactory:
  294. def __init__(self, index, base_field):
  295. self.index = index
  296. self.base_field = base_field
  297. def __call__(self, *args, **kwargs):
  298. return IndexTransform(self.index, self.base_field, *args, **kwargs)
  299. class SliceTransform(Transform):
  300. def __init__(self, start, end, *args, **kwargs):
  301. super().__init__(*args, **kwargs)
  302. self.start = start
  303. self.end = end
  304. def as_sql(self, compiler, connection):
  305. lhs, params = compiler.compile(self.lhs)
  306. # self.start is set to 1 if slice start is not provided.
  307. if self.end is None:
  308. return f"({lhs})[%s:]", (*params, self.start)
  309. else:
  310. return f"({lhs})[%s:%s]", (*params, self.start, self.end)
  311. class SliceTransformFactory:
  312. def __init__(self, start, end):
  313. self.start = start
  314. self.end = end
  315. def __call__(self, *args, **kwargs):
  316. return SliceTransform(self.start, self.end, *args, **kwargs)