aggregates.py 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213
  1. """
  2. Classes to represent the definitions of aggregate functions.
  3. """
  4. from django.core.exceptions import FieldError, FullResultSet
  5. from django.db.models.expressions import Case, Func, Star, Value, When
  6. from django.db.models.fields import IntegerField
  7. from django.db.models.functions.comparison import Coalesce
  8. from django.db.models.functions.mixins import (
  9. FixDurationInputMixin,
  10. NumericOutputFieldMixin,
  11. )
# Public API of this module: the names exported by a star-import.
__all__ = [
    "Aggregate",
    "Avg",
    "Count",
    "Max",
    "Min",
    "StdDev",
    "Sum",
    "Variance",
]
  22. class Aggregate(Func):
  23. template = "%(function)s(%(distinct)s%(expressions)s)"
  24. contains_aggregate = True
  25. name = None
  26. filter_template = "%s FILTER (WHERE %%(filter)s)"
  27. window_compatible = True
  28. allow_distinct = False
  29. empty_result_set_value = None
  30. def __init__(
  31. self, *expressions, distinct=False, filter=None, default=None, **extra
  32. ):
  33. if distinct and not self.allow_distinct:
  34. raise TypeError("%s does not allow distinct." % self.__class__.__name__)
  35. if default is not None and self.empty_result_set_value is not None:
  36. raise TypeError(f"{self.__class__.__name__} does not allow default.")
  37. self.distinct = distinct
  38. self.filter = filter
  39. self.default = default
  40. super().__init__(*expressions, **extra)
  41. def get_source_fields(self):
  42. # Don't return the filter expression since it's not a source field.
  43. return [e._output_field_or_none for e in super().get_source_expressions()]
  44. def get_source_expressions(self):
  45. source_expressions = super().get_source_expressions()
  46. return source_expressions + [self.filter]
  47. def set_source_expressions(self, exprs):
  48. *exprs, self.filter = exprs
  49. return super().set_source_expressions(exprs)
  50. def resolve_expression(
  51. self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
  52. ):
  53. # Aggregates are not allowed in UPDATE queries, so ignore for_save
  54. c = super().resolve_expression(query, allow_joins, reuse, summarize)
  55. c.filter = (
  56. c.filter.resolve_expression(query, allow_joins, reuse, summarize)
  57. if c.filter
  58. else None
  59. )
  60. if summarize:
  61. # Summarized aggregates cannot refer to summarized aggregates.
  62. for ref in c.get_refs():
  63. if query.annotations[ref].is_summary:
  64. raise FieldError(
  65. f"Cannot compute {c.name}('{ref}'): '{ref}' is an aggregate"
  66. )
  67. elif not self.is_summary:
  68. # Call Aggregate.get_source_expressions() to avoid
  69. # returning self.filter and including that in this loop.
  70. expressions = super(Aggregate, c).get_source_expressions()
  71. for index, expr in enumerate(expressions):
  72. if expr.contains_aggregate:
  73. before_resolved = self.get_source_expressions()[index]
  74. name = (
  75. before_resolved.name
  76. if hasattr(before_resolved, "name")
  77. else repr(before_resolved)
  78. )
  79. raise FieldError(
  80. "Cannot compute %s('%s'): '%s' is an aggregate"
  81. % (c.name, name, name)
  82. )
  83. if (default := c.default) is None:
  84. return c
  85. if hasattr(default, "resolve_expression"):
  86. default = default.resolve_expression(query, allow_joins, reuse, summarize)
  87. if default._output_field_or_none is None:
  88. default.output_field = c._output_field_or_none
  89. else:
  90. default = Value(default, c._output_field_or_none)
  91. c.default = None # Reset the default argument before wrapping.
  92. coalesce = Coalesce(c, default, output_field=c._output_field_or_none)
  93. coalesce.is_summary = c.is_summary
  94. return coalesce
  95. @property
  96. def default_alias(self):
  97. expressions = [
  98. expr for expr in self.get_source_expressions() if expr is not None
  99. ]
  100. if len(expressions) == 1 and hasattr(expressions[0], "name"):
  101. return "%s__%s" % (expressions[0].name, self.name.lower())
  102. raise TypeError("Complex expressions require an alias")
  103. def get_group_by_cols(self):
  104. return []
  105. def as_sql(self, compiler, connection, **extra_context):
  106. extra_context["distinct"] = "DISTINCT " if self.distinct else ""
  107. if self.filter:
  108. if connection.features.supports_aggregate_filter_clause:
  109. try:
  110. filter_sql, filter_params = self.filter.as_sql(compiler, connection)
  111. except FullResultSet:
  112. pass
  113. else:
  114. template = self.filter_template % extra_context.get(
  115. "template", self.template
  116. )
  117. sql, params = super().as_sql(
  118. compiler,
  119. connection,
  120. template=template,
  121. filter=filter_sql,
  122. **extra_context,
  123. )
  124. return sql, (*params, *filter_params)
  125. else:
  126. copy = self.copy()
  127. copy.filter = None
  128. source_expressions = copy.get_source_expressions()
  129. condition = When(self.filter, then=source_expressions[0])
  130. copy.set_source_expressions([Case(condition)] + source_expressions[1:])
  131. return super(Aggregate, copy).as_sql(
  132. compiler, connection, **extra_context
  133. )
  134. return super().as_sql(compiler, connection, **extra_context)
  135. def _get_repr_options(self):
  136. options = super()._get_repr_options()
  137. if self.distinct:
  138. options["distinct"] = self.distinct
  139. if self.filter:
  140. options["filter"] = self.filter
  141. return options
  142. class Avg(FixDurationInputMixin, NumericOutputFieldMixin, Aggregate):
  143. function = "AVG"
  144. name = "Avg"
  145. allow_distinct = True
  146. class Count(Aggregate):
  147. function = "COUNT"
  148. name = "Count"
  149. output_field = IntegerField()
  150. allow_distinct = True
  151. empty_result_set_value = 0
  152. def __init__(self, expression, filter=None, **extra):
  153. if expression == "*":
  154. expression = Star()
  155. if isinstance(expression, Star) and filter is not None:
  156. raise ValueError("Star cannot be used with filter. Please specify a field.")
  157. super().__init__(expression, filter=filter, **extra)
  158. class Max(Aggregate):
  159. function = "MAX"
  160. name = "Max"
  161. class Min(Aggregate):
  162. function = "MIN"
  163. name = "Min"
  164. class StdDev(NumericOutputFieldMixin, Aggregate):
  165. name = "StdDev"
  166. def __init__(self, expression, sample=False, **extra):
  167. self.function = "STDDEV_SAMP" if sample else "STDDEV_POP"
  168. super().__init__(expression, **extra)
  169. def _get_repr_options(self):
  170. return {**super()._get_repr_options(), "sample": self.function == "STDDEV_SAMP"}
  171. class Sum(FixDurationInputMixin, Aggregate):
  172. function = "SUM"
  173. name = "Sum"
  174. allow_distinct = True
  175. class Variance(NumericOutputFieldMixin, Aggregate):
  176. name = "Variance"
  177. def __init__(self, expression, sample=False, **extra):
  178. self.function = "VAR_SAMP" if sample else "VAR_POP"
  179. super().__init__(expression, **extra)
  180. def _get_repr_options(self):
  181. return {**super()._get_repr_options(), "sample": self.function == "VAR_SAMP"}