related_lookups.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199
  1. from django.db.models.lookups import (
  2. Exact,
  3. GreaterThan,
  4. GreaterThanOrEqual,
  5. In,
  6. IsNull,
  7. LessThan,
  8. LessThanOrEqual,
  9. )
  class MultiColSource:
      """Left-hand side placeholder for lookups on a multi-column relation.

      Holds the table alias together with the target and source fields of the
      relation. The relation field doubles as ``output_field``, so lookup
      classes are resolved through it.
      """

      # Both flags are constant for this object.
      contains_aggregate = False
      contains_over_clause = False

      def __init__(self, alias, targets, sources, field):
          self.alias = alias
          self.targets = targets
          self.sources = sources
          self.field = field
          # The relation field is also what lookups are resolved against.
          self.output_field = field

      def __repr__(self):
          return "{}({}, {})".format(type(self).__name__, self.alias, self.field)

      def relabeled_clone(self, relabels):
          """Return a copy whose alias is remapped through ``relabels``.

          An alias missing from ``relabels`` is kept unchanged; targets,
          sources and field are shared with the original.
          """
          new_alias = relabels.get(self.alias, self.alias)
          return type(self)(new_alias, self.targets, self.sources, self.field)

      def get_lookup(self, lookup):
          """Delegate lookup resolution to ``output_field``."""
          return self.output_field.get_lookup(lookup)

      def resolve_expression(self, *args, **kwargs):
          """Return ``self`` unchanged; all arguments are ignored."""
          return self
  def get_normalized_value(value, lhs):
      """Return ``value`` as a tuple of raw values comparable against ``lhs``.

      * A tuple is returned unchanged.
      * Any other non-model value is wrapped in a one-element tuple.
      * A model instance yields one entry per target field of the last
        ``path_infos`` element of ``lhs.output_field``, read from the instance.

      Raises:
          ValueError: if ``value`` is a model instance whose ``pk`` is None.
      """
      from django.db.models import Model

      if not isinstance(value, Model):
          return value if isinstance(value, tuple) else (value,)

      if value.pk is None:
          raise ValueError("Model instances passed to related filters must be saved.")

      collected = []
      for field in lhs.output_field.path_infos[-1].target_fields:
          # Follow the chain of remote fields until reaching a field declared on
          # a model that the instance actually belongs to (or the chain ends).
          while not isinstance(value, field.model) and field.remote_field:
              remote = field.remote_field
              field = remote.model._meta.get_field(remote.field_name)
          try:
              collected.append(getattr(value, field.attname))
          except AttributeError:
              # A case like Restaurant.objects.filter(place=restaurant_instance),
              # where place is a OneToOneField and the primary key of Restaurant.
              return (value.pk,)
      return tuple(collected)
  class RelatedIn(In):
      """``In`` lookup for relations, including multi-column ones.

      Single-column relations normalize the right-hand side in
      ``get_prep_lookup()``; multi-column relations (``MultiColSource`` on the
      left) build their WHERE clause by hand in ``as_sql()``.
      """

      def get_prep_lookup(self):
          """Prepare ``self.rhs`` for a single-column relation, then defer to ``In``."""
          if isinstance(self.lhs, MultiColSource):
              # Multi-column sources are expanded in as_sql() instead.
              return super().get_prep_lookup()

          if self.rhs_is_direct_value():
              # Single-column relation: reduce each item to its first
              # normalized value.
              self.rhs = [get_normalized_value(item, self.lhs)[0] for item in self.rhs]
              # The relation field does not validate raw values itself (e.g. a
              # ForeignKey to an IntegerField given 'abc'), so run the target
              # field's get_prep_value() on every item.
              if hasattr(self.lhs.output_field, "path_infos"):
                  # Only one target field can exist in the direct value branch.
                  final_path = self.lhs.output_field.path_infos[-1]
                  prep_field = final_path.target_fields[-1]
                  self.rhs = [prep_field.get_prep_value(item) for item in self.rhs]
              return super().get_prep_lookup()

          rhs_without_select = not getattr(self.rhs, "has_select_fields", True)
          if rhs_without_select and not getattr(
              self.lhs.field.target_field, "primary_key", False
          ):
              # A case like Restaurant.objects.filter(place__in=restaurant_qs),
              # where place is a OneToOneField and the primary key of
              # Restaurant, selects the relation field itself; otherwise the
              # relation's target field is selected.
              selected_name = (
                  self.lhs.field.name
                  if (
                      getattr(self.lhs.output_field, "primary_key", False)
                      and self.lhs.output_field.model == self.rhs.model
                  )
                  else self.lhs.field.target_field.name
              )
              self.rhs.set_values([selected_name])
          return super().get_prep_lookup()

      def as_sql(self, compiler, connection):
          """Compile the lookup, hand-building the clause for multi-column sources."""
          if not isinstance(self.lhs, MultiColSource):
              return super().as_sql(compiler, connection)

          # For multicolumn lookups we need to build a multicolumn where clause.
          # This clause is either a SubqueryConstraint (for values that need
          # to be compiled to SQL) or an OR-combined list of
          # (col1 = val1 AND col2 = val2 AND ...) clauses.
          from django.db.models.sql.where import (
              AND,
              OR,
              SubqueryConstraint,
              WhereNode,
          )

          combined = WhereNode(connector=OR)
          if not self.rhs_is_direct_value():
              target_columns = [target.column for target in self.lhs.targets]
              source_names = [source.name for source in self.lhs.sources]
              combined.add(
                  SubqueryConstraint(
                      self.lhs.alias, target_columns, source_names, self.rhs
                  ),
                  AND,
              )
              return combined.as_sql(compiler, connection)

          # Normalize every item up front, then emit one AND-group per item.
          rows = [get_normalized_value(item, self.lhs) for item in self.rhs]
          for row in rows:
              row_clause = WhereNode()
              columns = zip(self.lhs.sources, self.lhs.targets, row)
              for source, target, cell in columns:
                  exact_cls = target.get_lookup("exact")
                  column = target.get_col(self.lhs.alias, source)
                  row_clause.add(exact_cls(column, cell), AND)
              combined.add(row_clause, OR)
          return combined.as_sql(compiler, connection)
  class RelatedLookupMixin:
      """Mixin adding relation-aware value handling to single-value lookups.

      Single-column relations get their right-hand side normalized in
      ``get_prep_lookup()``; multi-column relations (``MultiColSource`` on the
      left) are compiled column by column in ``as_sql()``.
      """

      def get_prep_lookup(self):
          """Normalize a direct right-hand value, then defer to the base lookup."""
          single_column = not isinstance(self.lhs, MultiColSource)
          if single_column and not hasattr(self.rhs, "resolve_expression"):
              # Single-column relation with a plain value: keep only the first
              # normalized value.
              self.rhs = get_normalized_value(self.rhs, self.lhs)[0]
              # The relation field does not validate raw values itself (e.g. a
              # ForeignKey to an IntegerField given 'abc'), so run the target
              # field's get_prep_value() when rhs preparation is enabled.
              if self.prepare_rhs and hasattr(self.lhs.output_field, "path_infos"):
                  # Only one target field can exist in the direct value branch.
                  final_path = self.lhs.output_field.path_infos[-1]
                  prep_field = final_path.target_fields[-1]
                  self.rhs = prep_field.get_prep_value(self.rhs)
          return super().get_prep_lookup()

      def as_sql(self, compiler, connection):
          """Compile the lookup; multi-column sources become AND-ed per-column lookups."""
          if not isinstance(self.lhs, MultiColSource):
              return super().as_sql(compiler, connection)

          assert self.rhs_is_direct_value()
          self.rhs = get_normalized_value(self.rhs, self.lhs)
          from django.db.models.sql.where import AND, WhereNode

          condition = WhereNode()
          columns = zip(self.lhs.targets, self.lhs.sources, self.rhs)
          for target, source, cell in columns:
              # Each target field supplies its own class for this lookup name.
              column_lookup = target.get_lookup(self.lookup_name)
              column = target.get_col(self.lhs.alias, source)
              condition.add(column_lookup(column, cell), AND)
          return condition.as_sql(compiler, connection)
  class RelatedExact(RelatedLookupMixin, Exact):
      """``Exact`` lookup with the relation handling of RelatedLookupMixin."""
  class RelatedLessThan(RelatedLookupMixin, LessThan):
      """``LessThan`` lookup with the relation handling of RelatedLookupMixin."""
  class RelatedGreaterThan(RelatedLookupMixin, GreaterThan):
      """``GreaterThan`` lookup with the relation handling of RelatedLookupMixin."""
  class RelatedGreaterThanOrEqual(RelatedLookupMixin, GreaterThanOrEqual):
      """``GreaterThanOrEqual`` lookup with the relation handling of
      RelatedLookupMixin.
      """
  class RelatedLessThanOrEqual(RelatedLookupMixin, LessThanOrEqual):
      """``LessThanOrEqual`` lookup with the relation handling of
      RelatedLookupMixin.
      """
  class RelatedIsNull(RelatedLookupMixin, IsNull):
      """``IsNull`` lookup with the relation handling of RelatedLookupMixin."""