expressions.py 71 KB


  1. import copy
  2. import datetime
  3. import functools
  4. import inspect
  5. from collections import defaultdict
  6. from decimal import Decimal
  7. from enum import Enum
  8. from itertools import chain
  9. from types import NoneType
  10. from uuid import UUID
  11. from django.core.exceptions import EmptyResultSet, FieldError, FullResultSet
  12. from django.db import DatabaseError, NotSupportedError, connection
  13. from django.db.models import fields
  14. from django.db.models.constants import LOOKUP_SEP
  15. from django.db.models.query_utils import Q
  16. from django.utils.deconstruct import deconstructible
  17. from django.utils.functional import cached_property, classproperty
  18. from django.utils.hashable import make_hashable
  19. class SQLiteNumericMixin:
  20. """
  21. Some expressions with output_field=DecimalField() must be cast to
  22. numeric to be properly filtered.
  23. """
  24. def as_sqlite(self, compiler, connection, **extra_context):
  25. sql, params = self.as_sql(compiler, connection, **extra_context)
  26. try:
  27. if self.output_field.get_internal_type() == "DecimalField":
  28. sql = "(CAST(%s AS NUMERIC))" % sql
  29. except FieldError:
  30. pass
  31. return sql, params
  32. class Combinable:
  33. """
  34. Provide the ability to combine one or two objects with
  35. some connector. For example F('foo') + F('bar').
  36. """
  37. # Arithmetic connectors
  38. ADD = "+"
  39. SUB = "-"
  40. MUL = "*"
  41. DIV = "/"
  42. POW = "^"
  43. # The following is a quoted % operator - it is quoted because it can be
  44. # used in strings that also have parameter substitution.
  45. MOD = "%%"
  46. # Bitwise operators - note that these are generated by .bitand()
  47. # and .bitor(), the '&' and '|' are reserved for boolean operator
  48. # usage.
  49. BITAND = "&"
  50. BITOR = "|"
  51. BITLEFTSHIFT = "<<"
  52. BITRIGHTSHIFT = ">>"
  53. BITXOR = "#"
  54. def _combine(self, other, connector, reversed):
  55. if not hasattr(other, "resolve_expression"):
  56. # everything must be resolvable to an expression
  57. other = Value(other)
  58. if reversed:
  59. return CombinedExpression(other, connector, self)
  60. return CombinedExpression(self, connector, other)
  61. #############
  62. # OPERATORS #
  63. #############
  64. def __neg__(self):
  65. return self._combine(-1, self.MUL, False)
  66. def __add__(self, other):
  67. return self._combine(other, self.ADD, False)
  68. def __sub__(self, other):
  69. return self._combine(other, self.SUB, False)
  70. def __mul__(self, other):
  71. return self._combine(other, self.MUL, False)
  72. def __truediv__(self, other):
  73. return self._combine(other, self.DIV, False)
  74. def __mod__(self, other):
  75. return self._combine(other, self.MOD, False)
  76. def __pow__(self, other):
  77. return self._combine(other, self.POW, False)
  78. def __and__(self, other):
  79. if getattr(self, "conditional", False) and getattr(other, "conditional", False):
  80. return Q(self) & Q(other)
  81. raise NotImplementedError(
  82. "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
  83. )
  84. def bitand(self, other):
  85. return self._combine(other, self.BITAND, False)
  86. def bitleftshift(self, other):
  87. return self._combine(other, self.BITLEFTSHIFT, False)
  88. def bitrightshift(self, other):
  89. return self._combine(other, self.BITRIGHTSHIFT, False)
  90. def __xor__(self, other):
  91. if getattr(self, "conditional", False) and getattr(other, "conditional", False):
  92. return Q(self) ^ Q(other)
  93. raise NotImplementedError(
  94. "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
  95. )
  96. def bitxor(self, other):
  97. return self._combine(other, self.BITXOR, False)
  98. def __or__(self, other):
  99. if getattr(self, "conditional", False) and getattr(other, "conditional", False):
  100. return Q(self) | Q(other)
  101. raise NotImplementedError(
  102. "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
  103. )
  104. def bitor(self, other):
  105. return self._combine(other, self.BITOR, False)
  106. def __radd__(self, other):
  107. return self._combine(other, self.ADD, True)
  108. def __rsub__(self, other):
  109. return self._combine(other, self.SUB, True)
  110. def __rmul__(self, other):
  111. return self._combine(other, self.MUL, True)
  112. def __rtruediv__(self, other):
  113. return self._combine(other, self.DIV, True)
  114. def __rmod__(self, other):
  115. return self._combine(other, self.MOD, True)
  116. def __rpow__(self, other):
  117. return self._combine(other, self.POW, True)
  118. def __rand__(self, other):
  119. raise NotImplementedError(
  120. "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
  121. )
  122. def __ror__(self, other):
  123. raise NotImplementedError(
  124. "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
  125. )
  126. def __rxor__(self, other):
  127. raise NotImplementedError(
  128. "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
  129. )
  130. def __invert__(self):
  131. return NegatedExpression(self)
  132. class BaseExpression:
  133. """Base class for all query expressions."""
  134. empty_result_set_value = NotImplemented
  135. # aggregate specific fields
  136. is_summary = False
  137. _output_field_resolved_to_none = False
  138. # Can the expression be used in a WHERE clause?
  139. filterable = True
  140. # Can the expression be used as a source expression in Window?
  141. window_compatible = False
  142. # Can the expression be used as a database default value?
  143. allowed_default = False
  144. # Can the expression be used during a constraint validation?
  145. constraint_validation_compatible = True
  146. def __init__(self, output_field=None):
  147. if output_field is not None:
  148. self.output_field = output_field
  149. def __getstate__(self):
  150. state = self.__dict__.copy()
  151. state.pop("convert_value", None)
  152. return state
  153. def get_db_converters(self, connection):
  154. return (
  155. []
  156. if self.convert_value is self._convert_value_noop
  157. else [self.convert_value]
  158. ) + self.output_field.get_db_converters(connection)
  159. def get_source_expressions(self):
  160. return []
  161. def set_source_expressions(self, exprs):
  162. assert not exprs
  163. def _parse_expressions(self, *expressions):
  164. return [
  165. (
  166. arg
  167. if hasattr(arg, "resolve_expression")
  168. else (F(arg) if isinstance(arg, str) else Value(arg))
  169. )
  170. for arg in expressions
  171. ]
  172. def as_sql(self, compiler, connection):
  173. """
  174. Responsible for returning a (sql, [params]) tuple to be included
  175. in the current query.
  176. Different backends can provide their own implementation, by
  177. providing an `as_{vendor}` method and patching the Expression:
  178. ```
  179. def override_as_sql(self, compiler, connection):
  180. # custom logic
  181. return super().as_sql(compiler, connection)
  182. setattr(Expression, 'as_' + connection.vendor, override_as_sql)
  183. ```
  184. Arguments:
  185. * compiler: the query compiler responsible for generating the query.
  186. Must have a compile method, returning a (sql, [params]) tuple.
  187. Calling compiler(value) will return a quoted `value`.
  188. * connection: the database connection used for the current query.
  189. Return: (sql, params)
  190. Where `sql` is a string containing ordered sql parameters to be
  191. replaced with the elements of the list `params`.
  192. """
  193. raise NotImplementedError("Subclasses must implement as_sql()")
  194. @cached_property
  195. def contains_aggregate(self):
  196. return any(
  197. expr and expr.contains_aggregate for expr in self.get_source_expressions()
  198. )
  199. @cached_property
  200. def contains_over_clause(self):
  201. return any(
  202. expr and expr.contains_over_clause for expr in self.get_source_expressions()
  203. )
  204. @cached_property
  205. def contains_column_references(self):
  206. return any(
  207. expr and expr.contains_column_references
  208. for expr in self.get_source_expressions()
  209. )
  210. @cached_property
  211. def contains_subquery(self):
  212. return any(
  213. expr and (getattr(expr, "subquery", False) or expr.contains_subquery)
  214. for expr in self.get_source_expressions()
  215. )
  216. def resolve_expression(
  217. self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
  218. ):
  219. """
  220. Provide the chance to do any preprocessing or validation before being
  221. added to the query.
  222. Arguments:
  223. * query: the backend query implementation
  224. * allow_joins: boolean allowing or denying use of joins
  225. in this query
  226. * reuse: a set of reusable joins for multijoins
  227. * summarize: a terminal aggregate clause
  228. * for_save: whether this expression about to be used in a save or update
  229. Return: an Expression to be added to the query.
  230. """
  231. c = self.copy()
  232. c.is_summary = summarize
  233. c.set_source_expressions(
  234. [
  235. (
  236. expr.resolve_expression(query, allow_joins, reuse, summarize)
  237. if expr
  238. else None
  239. )
  240. for expr in c.get_source_expressions()
  241. ]
  242. )
  243. return c
  244. @property
  245. def conditional(self):
  246. return isinstance(self.output_field, fields.BooleanField)
  247. @property
  248. def field(self):
  249. return self.output_field
  250. @cached_property
  251. def output_field(self):
  252. """Return the output type of this expressions."""
  253. output_field = self._resolve_output_field()
  254. if output_field is None:
  255. self._output_field_resolved_to_none = True
  256. raise FieldError("Cannot resolve expression type, unknown output_field")
  257. return output_field
  258. @cached_property
  259. def _output_field_or_none(self):
  260. """
  261. Return the output field of this expression, or None if
  262. _resolve_output_field() didn't return an output type.
  263. """
  264. try:
  265. return self.output_field
  266. except FieldError:
  267. if not self._output_field_resolved_to_none:
  268. raise
  269. def _resolve_output_field(self):
  270. """
  271. Attempt to infer the output type of the expression.
  272. As a guess, if the output fields of all source fields match then simply
  273. infer the same type here.
  274. If a source's output field resolves to None, exclude it from this check.
  275. If all sources are None, then an error is raised higher up the stack in
  276. the output_field property.
  277. """
  278. # This guess is mostly a bad idea, but there is quite a lot of code
  279. # (especially 3rd party Func subclasses) that depend on it, we'd need a
  280. # deprecation path to fix it.
  281. sources_iter = (
  282. source for source in self.get_source_fields() if source is not None
  283. )
  284. for output_field in sources_iter:
  285. for source in sources_iter:
  286. if not isinstance(output_field, source.__class__):
  287. raise FieldError(
  288. "Expression contains mixed types: %s, %s. You must "
  289. "set output_field."
  290. % (
  291. output_field.__class__.__name__,
  292. source.__class__.__name__,
  293. )
  294. )
  295. return output_field
  296. @staticmethod
  297. def _convert_value_noop(value, expression, connection):
  298. return value
  299. @cached_property
  300. def convert_value(self):
  301. """
  302. Expressions provide their own converters because users have the option
  303. of manually specifying the output_field which may be a different type
  304. from the one the database returns.
  305. """
  306. field = self.output_field
  307. internal_type = field.get_internal_type()
  308. if internal_type == "FloatField":
  309. return lambda value, expression, connection: (
  310. None if value is None else float(value)
  311. )
  312. elif internal_type.endswith("IntegerField"):
  313. return lambda value, expression, connection: (
  314. None if value is None else int(value)
  315. )
  316. elif internal_type == "DecimalField":
  317. return lambda value, expression, connection: (
  318. None if value is None else Decimal(value)
  319. )
  320. return self._convert_value_noop
  321. def get_lookup(self, lookup):
  322. return self.output_field.get_lookup(lookup)
  323. def get_transform(self, name):
  324. return self.output_field.get_transform(name)
  325. def relabeled_clone(self, change_map):
  326. clone = self.copy()
  327. clone.set_source_expressions(
  328. [
  329. e.relabeled_clone(change_map) if e is not None else None
  330. for e in self.get_source_expressions()
  331. ]
  332. )
  333. return clone
  334. def replace_expressions(self, replacements):
  335. if not replacements:
  336. return self
  337. if replacement := replacements.get(self):
  338. return replacement
  339. if not (source_expressions := self.get_source_expressions()):
  340. return self
  341. clone = self.copy()
  342. clone.set_source_expressions(
  343. [
  344. expr.replace_expressions(replacements) if expr else None
  345. for expr in source_expressions
  346. ]
  347. )
  348. return clone
  349. def get_refs(self):
  350. refs = set()
  351. for expr in self.get_source_expressions():
  352. if expr is None:
  353. continue
  354. refs |= expr.get_refs()
  355. return refs
  356. def copy(self):
  357. return copy.copy(self)
  358. def prefix_references(self, prefix):
  359. clone = self.copy()
  360. clone.set_source_expressions(
  361. [
  362. (
  363. F(f"{prefix}{expr.name}")
  364. if isinstance(expr, F)
  365. else expr.prefix_references(prefix)
  366. )
  367. for expr in self.get_source_expressions()
  368. ]
  369. )
  370. return clone
  371. def get_group_by_cols(self):
  372. if not self.contains_aggregate:
  373. return [self]
  374. cols = []
  375. for source in self.get_source_expressions():
  376. cols.extend(source.get_group_by_cols())
  377. return cols
  378. def get_source_fields(self):
  379. """Return the underlying field types used by this aggregate."""
  380. return [e._output_field_or_none for e in self.get_source_expressions()]
  381. def asc(self, **kwargs):
  382. return OrderBy(self, **kwargs)
  383. def desc(self, **kwargs):
  384. return OrderBy(self, descending=True, **kwargs)
  385. def reverse_ordering(self):
  386. return self
  387. def flatten(self):
  388. """
  389. Recursively yield this expression and all subexpressions, in
  390. depth-first order.
  391. """
  392. yield self
  393. for expr in self.get_source_expressions():
  394. if expr:
  395. if hasattr(expr, "flatten"):
  396. yield from expr.flatten()
  397. else:
  398. yield expr
  399. def select_format(self, compiler, sql, params):
  400. """
  401. Custom format for select clauses. For example, EXISTS expressions need
  402. to be wrapped in CASE WHEN on Oracle.
  403. """
  404. if hasattr(self.output_field, "select_format"):
  405. return self.output_field.select_format(compiler, sql, params)
  406. return sql, params
  407. def get_expression_for_validation(self):
  408. # Ignore expressions that cannot be used during a constraint validation.
  409. if not getattr(self, "constraint_validation_compatible", True):
  410. try:
  411. (expression,) = self.get_source_expressions()
  412. except ValueError as e:
  413. raise ValueError(
  414. "Expressions with constraint_validation_compatible set to False "
  415. "must have only one source expression."
  416. ) from e
  417. else:
  418. return expression
  419. return self
  420. @deconstructible
  421. class Expression(BaseExpression, Combinable):
  422. """An expression that can be combined with other expressions."""
  423. @classproperty
  424. @functools.lru_cache(maxsize=128)
  425. def _constructor_signature(cls):
  426. return inspect.signature(cls.__init__)
  427. @cached_property
  428. def identity(self):
  429. args, kwargs = self._constructor_args
  430. signature = self._constructor_signature.bind_partial(self, *args, **kwargs)
  431. signature.apply_defaults()
  432. arguments = iter(signature.arguments.items())
  433. next(arguments)
  434. identity = [self.__class__]
  435. for arg, value in arguments:
  436. if isinstance(value, fields.Field):
  437. if value.name and value.model:
  438. value = (value.model._meta.label, value.name)
  439. else:
  440. value = type(value)
  441. else:
  442. value = make_hashable(value)
  443. identity.append((arg, value))
  444. return tuple(identity)
  445. def __eq__(self, other):
  446. if not isinstance(other, Expression):
  447. return NotImplemented
  448. return other.identity == self.identity
  449. def __hash__(self):
  450. return hash(self.identity)
  451. # Type inference for CombinedExpression.output_field.
  452. # Missing items will result in FieldError, by design.
  453. #
  454. # The current approach for NULL is based on lowest common denominator behavior
  455. # i.e. if one of the supported databases is raising an error (rather than
  456. # return NULL) for `val <op> NULL`, then Django raises FieldError.
  457. _connector_combinations = [
  458. # Numeric operations - operands of same type.
  459. # PositiveIntegerField should take precedence over IntegerField (except
  460. # subtraction).
  461. {
  462. connector: [
  463. (
  464. fields.PositiveIntegerField,
  465. fields.PositiveIntegerField,
  466. fields.PositiveIntegerField,
  467. ),
  468. ]
  469. for connector in (
  470. Combinable.ADD,
  471. Combinable.MUL,
  472. Combinable.DIV,
  473. Combinable.MOD,
  474. Combinable.POW,
  475. )
  476. },
  477. # Other numeric operands.
  478. {
  479. connector: [
  480. (fields.IntegerField, fields.IntegerField, fields.IntegerField),
  481. (fields.FloatField, fields.FloatField, fields.FloatField),
  482. (fields.DecimalField, fields.DecimalField, fields.DecimalField),
  483. ]
  484. for connector in (
  485. Combinable.ADD,
  486. Combinable.SUB,
  487. Combinable.MUL,
  488. # Behavior for DIV with integer arguments follows Postgres/SQLite,
  489. # not MySQL/Oracle.
  490. Combinable.DIV,
  491. Combinable.MOD,
  492. Combinable.POW,
  493. )
  494. },
  495. # Numeric operations - operands of different type.
  496. {
  497. connector: [
  498. (fields.IntegerField, fields.DecimalField, fields.DecimalField),
  499. (fields.DecimalField, fields.IntegerField, fields.DecimalField),
  500. (fields.IntegerField, fields.FloatField, fields.FloatField),
  501. (fields.FloatField, fields.IntegerField, fields.FloatField),
  502. ]
  503. for connector in (
  504. Combinable.ADD,
  505. Combinable.SUB,
  506. Combinable.MUL,
  507. Combinable.DIV,
  508. Combinable.MOD,
  509. )
  510. },
  511. # Bitwise operators.
  512. {
  513. connector: [
  514. (fields.IntegerField, fields.IntegerField, fields.IntegerField),
  515. ]
  516. for connector in (
  517. Combinable.BITAND,
  518. Combinable.BITOR,
  519. Combinable.BITLEFTSHIFT,
  520. Combinable.BITRIGHTSHIFT,
  521. Combinable.BITXOR,
  522. )
  523. },
  524. # Numeric with NULL.
  525. {
  526. connector: list(
  527. chain.from_iterable(
  528. [(field_type, NoneType, field_type), (NoneType, field_type, field_type)]
  529. for field_type in (
  530. fields.IntegerField,
  531. fields.DecimalField,
  532. fields.FloatField,
  533. )
  534. )
  535. )
  536. for connector in (
  537. Combinable.ADD,
  538. Combinable.SUB,
  539. Combinable.MUL,
  540. Combinable.DIV,
  541. Combinable.MOD,
  542. Combinable.POW,
  543. )
  544. },
  545. # Date/DateTimeField/DurationField/TimeField.
  546. {
  547. Combinable.ADD: [
  548. # Date/DateTimeField.
  549. (fields.DateField, fields.DurationField, fields.DateTimeField),
  550. (fields.DateTimeField, fields.DurationField, fields.DateTimeField),
  551. (fields.DurationField, fields.DateField, fields.DateTimeField),
  552. (fields.DurationField, fields.DateTimeField, fields.DateTimeField),
  553. # DurationField.
  554. (fields.DurationField, fields.DurationField, fields.DurationField),
  555. # TimeField.
  556. (fields.TimeField, fields.DurationField, fields.TimeField),
  557. (fields.DurationField, fields.TimeField, fields.TimeField),
  558. ],
  559. },
  560. {
  561. Combinable.SUB: [
  562. # Date/DateTimeField.
  563. (fields.DateField, fields.DurationField, fields.DateTimeField),
  564. (fields.DateTimeField, fields.DurationField, fields.DateTimeField),
  565. (fields.DateField, fields.DateField, fields.DurationField),
  566. (fields.DateField, fields.DateTimeField, fields.DurationField),
  567. (fields.DateTimeField, fields.DateField, fields.DurationField),
  568. (fields.DateTimeField, fields.DateTimeField, fields.DurationField),
  569. # DurationField.
  570. (fields.DurationField, fields.DurationField, fields.DurationField),
  571. # TimeField.
  572. (fields.TimeField, fields.DurationField, fields.TimeField),
  573. (fields.TimeField, fields.TimeField, fields.DurationField),
  574. ],
  575. },
  576. ]
  577. _connector_combinators = defaultdict(list)
  578. def register_combinable_fields(lhs, connector, rhs, result):
  579. """
  580. Register combinable types:
  581. lhs <connector> rhs -> result
  582. e.g.
  583. register_combinable_fields(
  584. IntegerField, Combinable.ADD, FloatField, FloatField
  585. )
  586. """
  587. _connector_combinators[connector].append((lhs, rhs, result))
  588. for d in _connector_combinations:
  589. for connector, field_types in d.items():
  590. for lhs, rhs, result in field_types:
  591. register_combinable_fields(lhs, connector, rhs, result)
  592. @functools.lru_cache(maxsize=128)
  593. def _resolve_combined_type(connector, lhs_type, rhs_type):
  594. combinators = _connector_combinators.get(connector, ())
  595. for combinator_lhs_type, combinator_rhs_type, combined_type in combinators:
  596. if issubclass(lhs_type, combinator_lhs_type) and issubclass(
  597. rhs_type, combinator_rhs_type
  598. ):
  599. return combined_type
  600. class CombinedExpression(SQLiteNumericMixin, Expression):
  601. def __init__(self, lhs, connector, rhs, output_field=None):
  602. super().__init__(output_field=output_field)
  603. self.connector = connector
  604. self.lhs = lhs
  605. self.rhs = rhs
  606. def __repr__(self):
  607. return "<{}: {}>".format(self.__class__.__name__, self)
  608. def __str__(self):
  609. return "{} {} {}".format(self.lhs, self.connector, self.rhs)
  610. def get_source_expressions(self):
  611. return [self.lhs, self.rhs]
  612. def set_source_expressions(self, exprs):
  613. self.lhs, self.rhs = exprs
  614. def _resolve_output_field(self):
  615. # We avoid using super() here for reasons given in
  616. # Expression._resolve_output_field()
  617. combined_type = _resolve_combined_type(
  618. self.connector,
  619. type(self.lhs._output_field_or_none),
  620. type(self.rhs._output_field_or_none),
  621. )
  622. if combined_type is None:
  623. raise FieldError(
  624. f"Cannot infer type of {self.connector!r} expression involving these "
  625. f"types: {self.lhs.output_field.__class__.__name__}, "
  626. f"{self.rhs.output_field.__class__.__name__}. You must set "
  627. f"output_field."
  628. )
  629. return combined_type()
  630. def as_sql(self, compiler, connection):
  631. expressions = []
  632. expression_params = []
  633. sql, params = compiler.compile(self.lhs)
  634. expressions.append(sql)
  635. expression_params.extend(params)
  636. sql, params = compiler.compile(self.rhs)
  637. expressions.append(sql)
  638. expression_params.extend(params)
  639. # order of precedence
  640. expression_wrapper = "(%s)"
  641. sql = connection.ops.combine_expression(self.connector, expressions)
  642. return expression_wrapper % sql, expression_params
  643. def resolve_expression(
  644. self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
  645. ):
  646. lhs = self.lhs.resolve_expression(
  647. query, allow_joins, reuse, summarize, for_save
  648. )
  649. rhs = self.rhs.resolve_expression(
  650. query, allow_joins, reuse, summarize, for_save
  651. )
  652. if not isinstance(self, (DurationExpression, TemporalSubtraction)):
  653. try:
  654. lhs_type = lhs.output_field.get_internal_type()
  655. except (AttributeError, FieldError):
  656. lhs_type = None
  657. try:
  658. rhs_type = rhs.output_field.get_internal_type()
  659. except (AttributeError, FieldError):
  660. rhs_type = None
  661. if "DurationField" in {lhs_type, rhs_type} and lhs_type != rhs_type:
  662. return DurationExpression(
  663. self.lhs, self.connector, self.rhs
  664. ).resolve_expression(
  665. query,
  666. allow_joins,
  667. reuse,
  668. summarize,
  669. for_save,
  670. )
  671. datetime_fields = {"DateField", "DateTimeField", "TimeField"}
  672. if (
  673. self.connector == self.SUB
  674. and lhs_type in datetime_fields
  675. and lhs_type == rhs_type
  676. ):
  677. return TemporalSubtraction(self.lhs, self.rhs).resolve_expression(
  678. query,
  679. allow_joins,
  680. reuse,
  681. summarize,
  682. for_save,
  683. )
  684. c = self.copy()
  685. c.is_summary = summarize
  686. c.lhs = lhs
  687. c.rhs = rhs
  688. return c
  689. @cached_property
  690. def allowed_default(self):
  691. return self.lhs.allowed_default and self.rhs.allowed_default
  692. class DurationExpression(CombinedExpression):
  693. def compile(self, side, compiler, connection):
  694. try:
  695. output = side.output_field
  696. except FieldError:
  697. pass
  698. else:
  699. if output.get_internal_type() == "DurationField":
  700. sql, params = compiler.compile(side)
  701. return connection.ops.format_for_duration_arithmetic(sql), params
  702. return compiler.compile(side)
  703. def as_sql(self, compiler, connection):
  704. if connection.features.has_native_duration_field:
  705. return super().as_sql(compiler, connection)
  706. connection.ops.check_expression_support(self)
  707. expressions = []
  708. expression_params = []
  709. sql, params = self.compile(self.lhs, compiler, connection)
  710. expressions.append(sql)
  711. expression_params.extend(params)
  712. sql, params = self.compile(self.rhs, compiler, connection)
  713. expressions.append(sql)
  714. expression_params.extend(params)
  715. # order of precedence
  716. expression_wrapper = "(%s)"
  717. sql = connection.ops.combine_duration_expression(self.connector, expressions)
  718. return expression_wrapper % sql, expression_params
  719. def as_sqlite(self, compiler, connection, **extra_context):
  720. sql, params = self.as_sql(compiler, connection, **extra_context)
  721. if self.connector in {Combinable.MUL, Combinable.DIV}:
  722. try:
  723. lhs_type = self.lhs.output_field.get_internal_type()
  724. rhs_type = self.rhs.output_field.get_internal_type()
  725. except (AttributeError, FieldError):
  726. pass
  727. else:
  728. allowed_fields = {
  729. "DecimalField",
  730. "DurationField",
  731. "FloatField",
  732. "IntegerField",
  733. }
  734. if lhs_type not in allowed_fields or rhs_type not in allowed_fields:
  735. raise DatabaseError(
  736. f"Invalid arguments for operator {self.connector}."
  737. )
  738. return sql, params
  739. class TemporalSubtraction(CombinedExpression):
  740. output_field = fields.DurationField()
  741. def __init__(self, lhs, rhs):
  742. super().__init__(lhs, self.SUB, rhs)
  743. def as_sql(self, compiler, connection):
  744. connection.ops.check_expression_support(self)
  745. lhs = compiler.compile(self.lhs)
  746. rhs = compiler.compile(self.rhs)
  747. return connection.ops.subtract_temporals(
  748. self.lhs.output_field.get_internal_type(), lhs, rhs
  749. )
  750. @deconstructible(path="django.db.models.F")
  751. class F(Combinable):
  752. """An object capable of resolving references to existing query objects."""
  753. allowed_default = False
  754. def __init__(self, name):
  755. """
  756. Arguments:
  757. * name: the name of the field this expression references
  758. """
  759. self.name = name
  760. def __repr__(self):
  761. return "{}({})".format(self.__class__.__name__, self.name)
  762. def __getitem__(self, subscript):
  763. return Sliced(self, subscript)
  764. def __contains__(self, other):
  765. # Disable old-style iteration protocol inherited from implementing
  766. # __getitem__() to prevent this method from hanging.
  767. raise TypeError(f"argument of type '{self.__class__.__name__}' is not iterable")
  768. def resolve_expression(
  769. self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
  770. ):
  771. return query.resolve_ref(self.name, allow_joins, reuse, summarize)
  772. def replace_expressions(self, replacements):
  773. return replacements.get(self, self)
  774. def asc(self, **kwargs):
  775. return OrderBy(self, **kwargs)
  776. def desc(self, **kwargs):
  777. return OrderBy(self, descending=True, **kwargs)
  778. def __eq__(self, other):
  779. return self.__class__ == other.__class__ and self.name == other.name
  780. def __hash__(self):
  781. return hash(self.name)
  782. def copy(self):
  783. return copy.copy(self)
  784. class ResolvedOuterRef(F):
  785. """
  786. An object that contains a reference to an outer query.
  787. In this case, the reference to the outer query has been resolved because
  788. the inner query has been used as a subquery.
  789. """
  790. contains_aggregate = False
  791. contains_over_clause = False
  792. def as_sql(self, *args, **kwargs):
  793. raise ValueError(
  794. "This queryset contains a reference to an outer query and may "
  795. "only be used in a subquery."
  796. )
  797. def resolve_expression(self, *args, **kwargs):
  798. col = super().resolve_expression(*args, **kwargs)
  799. if col.contains_over_clause:
  800. raise NotSupportedError(
  801. f"Referencing outer query window expression is not supported: "
  802. f"{self.name}."
  803. )
  804. # FIXME: Rename possibly_multivalued to multivalued and fix detection
  805. # for non-multivalued JOINs (e.g. foreign key fields). This should take
  806. # into account only many-to-many and one-to-many relationships.
  807. col.possibly_multivalued = LOOKUP_SEP in self.name
  808. return col
  809. def relabeled_clone(self, relabels):
  810. return self
  811. def get_group_by_cols(self):
  812. return []
  813. class OuterRef(F):
  814. contains_aggregate = False
  815. contains_over_clause = False
  816. def resolve_expression(self, *args, **kwargs):
  817. if isinstance(self.name, self.__class__):
  818. return self.name
  819. return ResolvedOuterRef(self.name)
  820. def relabeled_clone(self, relabels):
  821. return self
  822. class Sliced(F):
  823. """
  824. An object that contains a slice of an F expression.
  825. Object resolves the column on which the slicing is applied, and then
  826. applies the slicing if possible.
  827. """
  828. def __init__(self, obj, subscript):
  829. super().__init__(obj.name)
  830. self.obj = obj
  831. if isinstance(subscript, int):
  832. if subscript < 0:
  833. raise ValueError("Negative indexing is not supported.")
  834. self.start = subscript + 1
  835. self.length = 1
  836. elif isinstance(subscript, slice):
  837. if (subscript.start is not None and subscript.start < 0) or (
  838. subscript.stop is not None and subscript.stop < 0
  839. ):
  840. raise ValueError("Negative indexing is not supported.")
  841. if subscript.step is not None:
  842. raise ValueError("Step argument is not supported.")
  843. if subscript.stop and subscript.start and subscript.stop < subscript.start:
  844. raise ValueError("Slice stop must be greater than slice start.")
  845. self.start = 1 if subscript.start is None else subscript.start + 1
  846. if subscript.stop is None:
  847. self.length = None
  848. else:
  849. self.length = subscript.stop - (subscript.start or 0)
  850. else:
  851. raise TypeError("Argument to slice must be either int or slice instance.")
  852. def __repr__(self):
  853. start = self.start - 1
  854. stop = None if self.length is None else start + self.length
  855. subscript = slice(start, stop)
  856. return f"{self.__class__.__qualname__}({self.obj!r}, {subscript!r})"
  857. def resolve_expression(
  858. self,
  859. query=None,
  860. allow_joins=True,
  861. reuse=None,
  862. summarize=False,
  863. for_save=False,
  864. ):
  865. resolved = query.resolve_ref(self.name, allow_joins, reuse, summarize)
  866. if isinstance(self.obj, (OuterRef, self.__class__)):
  867. expr = self.obj.resolve_expression(
  868. query, allow_joins, reuse, summarize, for_save
  869. )
  870. else:
  871. expr = resolved
  872. return resolved.output_field.slice_expression(expr, self.start, self.length)
  873. @deconstructible(path="django.db.models.Func")
  874. class Func(SQLiteNumericMixin, Expression):
  875. """An SQL function call."""
  876. function = None
  877. template = "%(function)s(%(expressions)s)"
  878. arg_joiner = ", "
  879. arity = None # The number of arguments the function accepts.
  880. def __init__(self, *expressions, output_field=None, **extra):
  881. if self.arity is not None and len(expressions) != self.arity:
  882. raise TypeError(
  883. "'%s' takes exactly %s %s (%s given)"
  884. % (
  885. self.__class__.__name__,
  886. self.arity,
  887. "argument" if self.arity == 1 else "arguments",
  888. len(expressions),
  889. )
  890. )
  891. super().__init__(output_field=output_field)
  892. self.source_expressions = self._parse_expressions(*expressions)
  893. self.extra = extra
  894. def __repr__(self):
  895. args = self.arg_joiner.join(str(arg) for arg in self.source_expressions)
  896. extra = {**self.extra, **self._get_repr_options()}
  897. if extra:
  898. extra = ", ".join(
  899. str(key) + "=" + str(val) for key, val in sorted(extra.items())
  900. )
  901. return "{}({}, {})".format(self.__class__.__name__, args, extra)
  902. return "{}({})".format(self.__class__.__name__, args)
  903. def _get_repr_options(self):
  904. """Return a dict of extra __init__() options to include in the repr."""
  905. return {}
  906. def get_source_expressions(self):
  907. return self.source_expressions
  908. def set_source_expressions(self, exprs):
  909. self.source_expressions = exprs
  910. def resolve_expression(
  911. self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
  912. ):
  913. c = self.copy()
  914. c.is_summary = summarize
  915. for pos, arg in enumerate(c.source_expressions):
  916. c.source_expressions[pos] = arg.resolve_expression(
  917. query, allow_joins, reuse, summarize, for_save
  918. )
  919. return c
  920. def as_sql(
  921. self,
  922. compiler,
  923. connection,
  924. function=None,
  925. template=None,
  926. arg_joiner=None,
  927. **extra_context,
  928. ):
  929. connection.ops.check_expression_support(self)
  930. sql_parts = []
  931. params = []
  932. for arg in self.source_expressions:
  933. try:
  934. arg_sql, arg_params = compiler.compile(arg)
  935. except EmptyResultSet:
  936. empty_result_set_value = getattr(
  937. arg, "empty_result_set_value", NotImplemented
  938. )
  939. if empty_result_set_value is NotImplemented:
  940. raise
  941. arg_sql, arg_params = compiler.compile(Value(empty_result_set_value))
  942. except FullResultSet:
  943. arg_sql, arg_params = compiler.compile(Value(True))
  944. sql_parts.append(arg_sql)
  945. params.extend(arg_params)
  946. data = {**self.extra, **extra_context}
  947. # Use the first supplied value in this order: the parameter to this
  948. # method, a value supplied in __init__()'s **extra (the value in
  949. # `data`), or the value defined on the class.
  950. if function is not None:
  951. data["function"] = function
  952. else:
  953. data.setdefault("function", self.function)
  954. template = template or data.get("template", self.template)
  955. arg_joiner = arg_joiner or data.get("arg_joiner", self.arg_joiner)
  956. data["expressions"] = data["field"] = arg_joiner.join(sql_parts)
  957. return template % data, params
  958. def copy(self):
  959. copy = super().copy()
  960. copy.source_expressions = self.source_expressions[:]
  961. copy.extra = self.extra.copy()
  962. return copy
  963. @cached_property
  964. def allowed_default(self):
  965. return all(expression.allowed_default for expression in self.source_expressions)
  966. @deconstructible(path="django.db.models.Value")
  967. class Value(SQLiteNumericMixin, Expression):
  968. """Represent a wrapped value as a node within an expression."""
  969. # Provide a default value for `for_save` in order to allow unresolved
  970. # instances to be compiled until a decision is taken in #25425.
  971. for_save = False
  972. allowed_default = True
  973. def __init__(self, value, output_field=None):
  974. """
  975. Arguments:
  976. * value: the value this expression represents. The value will be
  977. added into the sql parameter list and properly quoted.
  978. * output_field: an instance of the model field type that this
  979. expression will return, such as IntegerField() or CharField().
  980. """
  981. super().__init__(output_field=output_field)
  982. self.value = value
  983. def __repr__(self):
  984. return f"{self.__class__.__name__}({self.value!r})"
  985. def as_sql(self, compiler, connection):
  986. connection.ops.check_expression_support(self)
  987. val = self.value
  988. output_field = self._output_field_or_none
  989. if output_field is not None:
  990. if self.for_save:
  991. val = output_field.get_db_prep_save(val, connection=connection)
  992. else:
  993. val = output_field.get_db_prep_value(val, connection=connection)
  994. if hasattr(output_field, "get_placeholder"):
  995. return output_field.get_placeholder(val, compiler, connection), [val]
  996. if val is None:
  997. # oracledb does not always convert None to the appropriate
  998. # NULL type (like in case expressions using numbers), so we
  999. # use a literal SQL NULL
  1000. return "NULL", []
  1001. return "%s", [val]
  1002. def resolve_expression(
  1003. self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
  1004. ):
  1005. c = super().resolve_expression(query, allow_joins, reuse, summarize, for_save)
  1006. c.for_save = for_save
  1007. return c
  1008. def get_group_by_cols(self):
  1009. return []
  1010. def _resolve_output_field(self):
  1011. if isinstance(self.value, str):
  1012. return fields.CharField()
  1013. if isinstance(self.value, bool):
  1014. return fields.BooleanField()
  1015. if isinstance(self.value, int):
  1016. return fields.IntegerField()
  1017. if isinstance(self.value, float):
  1018. return fields.FloatField()
  1019. if isinstance(self.value, datetime.datetime):
  1020. return fields.DateTimeField()
  1021. if isinstance(self.value, datetime.date):
  1022. return fields.DateField()
  1023. if isinstance(self.value, datetime.time):
  1024. return fields.TimeField()
  1025. if isinstance(self.value, datetime.timedelta):
  1026. return fields.DurationField()
  1027. if isinstance(self.value, Decimal):
  1028. return fields.DecimalField()
  1029. if isinstance(self.value, bytes):
  1030. return fields.BinaryField()
  1031. if isinstance(self.value, UUID):
  1032. return fields.UUIDField()
  1033. @property
  1034. def empty_result_set_value(self):
  1035. return self.value
  1036. class RawSQL(Expression):
  1037. allowed_default = True
  1038. def __init__(self, sql, params, output_field=None):
  1039. if output_field is None:
  1040. output_field = fields.Field()
  1041. self.sql, self.params = sql, params
  1042. super().__init__(output_field=output_field)
  1043. def __repr__(self):
  1044. return "{}({}, {})".format(self.__class__.__name__, self.sql, self.params)
  1045. def as_sql(self, compiler, connection):
  1046. return "(%s)" % self.sql, self.params
  1047. def get_group_by_cols(self):
  1048. return [self]
  1049. def resolve_expression(
  1050. self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
  1051. ):
  1052. # Resolve parents fields used in raw SQL.
  1053. if query.model:
  1054. for parent in query.model._meta.all_parents:
  1055. for parent_field in parent._meta.local_fields:
  1056. if parent_field.column.lower() in self.sql.lower():
  1057. query.resolve_ref(
  1058. parent_field.name, allow_joins, reuse, summarize
  1059. )
  1060. break
  1061. return super().resolve_expression(
  1062. query, allow_joins, reuse, summarize, for_save
  1063. )
  1064. class Star(Expression):
  1065. def __repr__(self):
  1066. return "'*'"
  1067. def as_sql(self, compiler, connection):
  1068. return "*", []
  1069. class DatabaseDefault(Expression):
  1070. """
  1071. Expression to use DEFAULT keyword during insert otherwise the underlying expression.
  1072. """
  1073. def __init__(self, expression, output_field=None):
  1074. super().__init__(output_field)
  1075. self.expression = expression
  1076. def get_source_expressions(self):
  1077. return [self.expression]
  1078. def set_source_expressions(self, exprs):
  1079. (self.expression,) = exprs
  1080. def resolve_expression(
  1081. self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
  1082. ):
  1083. resolved_expression = self.expression.resolve_expression(
  1084. query=query,
  1085. allow_joins=allow_joins,
  1086. reuse=reuse,
  1087. summarize=summarize,
  1088. for_save=for_save,
  1089. )
  1090. # Defaults used outside an INSERT context should resolve to their
  1091. # underlying expression.
  1092. if not for_save:
  1093. return resolved_expression
  1094. return DatabaseDefault(
  1095. resolved_expression, output_field=self._output_field_or_none
  1096. )
  1097. def as_sql(self, compiler, connection):
  1098. if not connection.features.supports_default_keyword_in_insert:
  1099. return compiler.compile(self.expression)
  1100. return "DEFAULT", []
  1101. class Col(Expression):
  1102. contains_column_references = True
  1103. possibly_multivalued = False
  1104. def __init__(self, alias, target, output_field=None):
  1105. if output_field is None:
  1106. output_field = target
  1107. super().__init__(output_field=output_field)
  1108. self.alias, self.target = alias, target
  1109. def __repr__(self):
  1110. alias, target = self.alias, self.target
  1111. identifiers = (alias, str(target)) if alias else (str(target),)
  1112. return "{}({})".format(self.__class__.__name__, ", ".join(identifiers))
  1113. def as_sql(self, compiler, connection):
  1114. alias, column = self.alias, self.target.column
  1115. identifiers = (alias, column) if alias else (column,)
  1116. sql = ".".join(map(compiler.quote_name_unless_alias, identifiers))
  1117. return sql, []
  1118. def relabeled_clone(self, relabels):
  1119. if self.alias is None:
  1120. return self
  1121. return self.__class__(
  1122. relabels.get(self.alias, self.alias), self.target, self.output_field
  1123. )
  1124. def get_group_by_cols(self):
  1125. return [self]
  1126. def get_db_converters(self, connection):
  1127. if self.target == self.output_field:
  1128. return self.output_field.get_db_converters(connection)
  1129. return self.output_field.get_db_converters(
  1130. connection
  1131. ) + self.target.get_db_converters(connection)
  1132. class Ref(Expression):
  1133. """
  1134. Reference to column alias of the query. For example, Ref('sum_cost') in
  1135. qs.annotate(sum_cost=Sum('cost')) query.
  1136. """
  1137. def __init__(self, refs, source):
  1138. super().__init__()
  1139. self.refs, self.source = refs, source
  1140. def __repr__(self):
  1141. return "{}({}, {})".format(self.__class__.__name__, self.refs, self.source)
  1142. def get_source_expressions(self):
  1143. return [self.source]
  1144. def set_source_expressions(self, exprs):
  1145. (self.source,) = exprs
  1146. def resolve_expression(
  1147. self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
  1148. ):
  1149. # The sub-expression `source` has already been resolved, as this is
  1150. # just a reference to the name of `source`.
  1151. return self
  1152. def get_refs(self):
  1153. return {self.refs}
  1154. def relabeled_clone(self, relabels):
  1155. clone = self.copy()
  1156. clone.source = self.source.relabeled_clone(relabels)
  1157. return clone
  1158. def as_sql(self, compiler, connection):
  1159. return connection.ops.quote_name(self.refs), []
  1160. def get_group_by_cols(self):
  1161. return [self]
  1162. class ExpressionList(Func):
  1163. """
  1164. An expression containing multiple expressions. Can be used to provide a
  1165. list of expressions as an argument to another expression, like a partition
  1166. clause.
  1167. """
  1168. template = "%(expressions)s"
  1169. def __str__(self):
  1170. return self.arg_joiner.join(str(arg) for arg in self.source_expressions)
  1171. def as_sql(self, *args, **kwargs):
  1172. if not self.source_expressions:
  1173. return "", ()
  1174. return super().as_sql(*args, **kwargs)
  1175. def as_sqlite(self, compiler, connection, **extra_context):
  1176. # Casting to numeric is unnecessary.
  1177. return self.as_sql(compiler, connection, **extra_context)
  1178. def get_group_by_cols(self):
  1179. group_by_cols = []
  1180. for expr in self.get_source_expressions():
  1181. group_by_cols.extend(expr.get_group_by_cols())
  1182. return group_by_cols
  1183. class OrderByList(ExpressionList):
  1184. allowed_default = False
  1185. template = "ORDER BY %(expressions)s"
  1186. def __init__(self, *expressions, **extra):
  1187. expressions = (
  1188. (
  1189. OrderBy(F(expr[1:]), descending=True)
  1190. if isinstance(expr, str) and expr[0] == "-"
  1191. else expr
  1192. )
  1193. for expr in expressions
  1194. )
  1195. super().__init__(*expressions, **extra)
  1196. @deconstructible(path="django.db.models.ExpressionWrapper")
  1197. class ExpressionWrapper(SQLiteNumericMixin, Expression):
  1198. """
  1199. An expression that can wrap another expression so that it can provide
  1200. extra context to the inner expression, such as the output_field.
  1201. """
  1202. def __init__(self, expression, output_field):
  1203. super().__init__(output_field=output_field)
  1204. self.expression = expression
  1205. def set_source_expressions(self, exprs):
  1206. self.expression = exprs[0]
  1207. def get_source_expressions(self):
  1208. return [self.expression]
  1209. def get_group_by_cols(self):
  1210. if isinstance(self.expression, Expression):
  1211. expression = self.expression.copy()
  1212. expression.output_field = self.output_field
  1213. return expression.get_group_by_cols()
  1214. # For non-expressions e.g. an SQL WHERE clause, the entire
  1215. # `expression` must be included in the GROUP BY clause.
  1216. return super().get_group_by_cols()
  1217. def as_sql(self, compiler, connection):
  1218. return compiler.compile(self.expression)
  1219. def __repr__(self):
  1220. return "{}({})".format(self.__class__.__name__, self.expression)
  1221. @property
  1222. def allowed_default(self):
  1223. return self.expression.allowed_default
  1224. class NegatedExpression(ExpressionWrapper):
  1225. """The logical negation of a conditional expression."""
  1226. def __init__(self, expression):
  1227. super().__init__(expression, output_field=fields.BooleanField())
  1228. def __invert__(self):
  1229. return self.expression.copy()
  1230. def as_sql(self, compiler, connection):
  1231. try:
  1232. sql, params = super().as_sql(compiler, connection)
  1233. except EmptyResultSet:
  1234. features = compiler.connection.features
  1235. if not features.supports_boolean_expr_in_select_clause:
  1236. return "1=1", ()
  1237. return compiler.compile(Value(True))
  1238. ops = compiler.connection.ops
  1239. # Some database backends (e.g. Oracle) don't allow EXISTS() and filters
  1240. # to be compared to another expression unless they're wrapped in a CASE
  1241. # WHEN.
  1242. if not ops.conditional_expression_supported_in_where_clause(self.expression):
  1243. return f"CASE WHEN {sql} = 0 THEN 1 ELSE 0 END", params
  1244. return f"NOT {sql}", params
  1245. def resolve_expression(
  1246. self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
  1247. ):
  1248. resolved = super().resolve_expression(
  1249. query, allow_joins, reuse, summarize, for_save
  1250. )
  1251. if not getattr(resolved.expression, "conditional", False):
  1252. raise TypeError("Cannot negate non-conditional expressions.")
  1253. return resolved
  1254. def select_format(self, compiler, sql, params):
  1255. # Wrap boolean expressions with a CASE WHEN expression if a database
  1256. # backend (e.g. Oracle) doesn't support boolean expression in SELECT or
  1257. # GROUP BY list.
  1258. expression_supported_in_where_clause = (
  1259. compiler.connection.ops.conditional_expression_supported_in_where_clause
  1260. )
  1261. if (
  1262. not compiler.connection.features.supports_boolean_expr_in_select_clause
  1263. # Avoid double wrapping.
  1264. and expression_supported_in_where_clause(self.expression)
  1265. ):
  1266. sql = "CASE WHEN {} THEN 1 ELSE 0 END".format(sql)
  1267. return sql, params
  1268. @deconstructible(path="django.db.models.When")
  1269. class When(Expression):
  1270. template = "WHEN %(condition)s THEN %(result)s"
  1271. # This isn't a complete conditional expression, must be used in Case().
  1272. conditional = False
  1273. def __init__(self, condition=None, then=None, **lookups):
  1274. if lookups:
  1275. if condition is None:
  1276. condition, lookups = Q(**lookups), None
  1277. elif getattr(condition, "conditional", False):
  1278. condition, lookups = Q(condition, **lookups), None
  1279. if condition is None or not getattr(condition, "conditional", False) or lookups:
  1280. raise TypeError(
  1281. "When() supports a Q object, a boolean expression, or lookups "
  1282. "as a condition."
  1283. )
  1284. if isinstance(condition, Q) and not condition:
  1285. raise ValueError("An empty Q() can't be used as a When() condition.")
  1286. super().__init__(output_field=None)
  1287. self.condition = condition
  1288. self.result = self._parse_expressions(then)[0]
  1289. def __str__(self):
  1290. return "WHEN %r THEN %r" % (self.condition, self.result)
  1291. def __repr__(self):
  1292. return "<%s: %s>" % (self.__class__.__name__, self)
  1293. def get_source_expressions(self):
  1294. return [self.condition, self.result]
  1295. def set_source_expressions(self, exprs):
  1296. self.condition, self.result = exprs
  1297. def get_source_fields(self):
  1298. # We're only interested in the fields of the result expressions.
  1299. return [self.result._output_field_or_none]
  1300. def resolve_expression(
  1301. self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
  1302. ):
  1303. c = self.copy()
  1304. c.is_summary = summarize
  1305. if hasattr(c.condition, "resolve_expression"):
  1306. c.condition = c.condition.resolve_expression(
  1307. query, allow_joins, reuse, summarize, False
  1308. )
  1309. c.result = c.result.resolve_expression(
  1310. query, allow_joins, reuse, summarize, for_save
  1311. )
  1312. return c
  1313. def as_sql(self, compiler, connection, template=None, **extra_context):
  1314. connection.ops.check_expression_support(self)
  1315. template_params = extra_context
  1316. sql_params = []
  1317. condition_sql, condition_params = compiler.compile(self.condition)
  1318. template_params["condition"] = condition_sql
  1319. result_sql, result_params = compiler.compile(self.result)
  1320. template_params["result"] = result_sql
  1321. template = template or self.template
  1322. return template % template_params, (
  1323. *sql_params,
  1324. *condition_params,
  1325. *result_params,
  1326. )
  1327. def get_group_by_cols(self):
  1328. # This is not a complete expression and cannot be used in GROUP BY.
  1329. cols = []
  1330. for source in self.get_source_expressions():
  1331. cols.extend(source.get_group_by_cols())
  1332. return cols
  1333. @cached_property
  1334. def allowed_default(self):
  1335. return self.condition.allowed_default and self.result.allowed_default
@deconstructible(path="django.db.models.Case")
class Case(SQLiteNumericMixin, Expression):
    """
    An SQL searched CASE expression:

        CASE
            WHEN n > 0
                THEN 'positive'
            WHEN n < 0
                THEN 'negative'
            ELSE 'zero'
        END
    """

    template = "CASE %(cases)s ELSE %(default)s END"
    case_joiner = " "

    def __init__(self, *cases, default=None, output_field=None, **extra):
        """
        Arguments:
         * cases: When objects, kept in the given order.
         * default: value or expression for the ELSE branch.
         * output_field: passed to the parent initializer.
         * extra: additional template context stored on the instance.

        Raise TypeError when any positional argument is not a When instance.
        """
        if not all(isinstance(case, When) for case in cases):
            raise TypeError("Positional arguments must all be When objects.")
        super().__init__(output_field)
        self.cases = list(cases)
        self.default = self._parse_expressions(default)[0]
        self.extra = extra

    def __str__(self):
        return "CASE %s, ELSE %r" % (
            ", ".join(str(c) for c in self.cases),
            self.default,
        )

    def __repr__(self):
        return "<%s: %s>" % (self.__class__.__name__, self)

    def get_source_expressions(self):
        # The default is always the last source expression.
        return self.cases + [self.default]

    def set_source_expressions(self, exprs):
        # Mirror get_source_expressions(): everything but the last item is a
        # case, the last item is the default.
        *self.cases, self.default = exprs

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        """Return a copy with every case and the default resolved."""
        c = self.copy()
        c.is_summary = summarize
        for pos, case in enumerate(c.cases):
            c.cases[pos] = case.resolve_expression(
                query, allow_joins, reuse, summarize, for_save
            )
        c.default = c.default.resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        return c

    def copy(self):
        """Return a copy that owns its own list of cases."""
        c = super().copy()
        c.cases = c.cases[:]
        return c

    def as_sql(
        self, compiler, connection, template=None, case_joiner=None, **extra_context
    ):
        """
        Compile the CASE expression and return ``(sql, params)``.

        With no cases at all, or when no case contributes a WHEN clause, only
        the default is compiled and returned.
        """
        connection.ops.check_expression_support(self)
        if not self.cases:
            return compiler.compile(self.default)
        template_params = {**self.extra, **extra_context}
        case_parts = []
        sql_params = []
        default_sql, default_params = compiler.compile(self.default)
        for case in self.cases:
            try:
                case_sql, case_params = compiler.compile(case)
            except EmptyResultSet:
                # This case is dropped from the generated SQL.
                continue
            except FullResultSet:
                # This case's result replaces the default, and the remaining
                # cases are not compiled.
                default_sql, default_params = compiler.compile(case.result)
                break
            case_parts.append(case_sql)
            sql_params.extend(case_params)
        if not case_parts:
            return default_sql, default_params
        case_joiner = case_joiner or self.case_joiner
        template_params["cases"] = case_joiner.join(case_parts)
        template_params["default"] = default_sql
        # The ELSE branch comes last in the SQL, so its params go last too.
        sql_params.extend(default_params)
        template = template or template_params.get("template", self.template)
        sql = template % template_params
        if self._output_field_or_none is not None:
            # Wrap the SQL in the backend's unification cast for the output
            # field.
            sql = connection.ops.unification_cast_sql(self.output_field) % sql
        return sql, sql_params

    def get_group_by_cols(self):
        # Without cases this expression compiles to its default alone, so
        # group by whatever the default requires.
        if not self.cases:
            return self.default.get_group_by_cols()
        return super().get_group_by_cols()

    @cached_property
    def allowed_default(self):
        return self.default.allowed_default and all(
            case_.allowed_default for case_ in self.cases
        )
  1425. class Subquery(BaseExpression, Combinable):
  1426. """
  1427. An explicit subquery. It may contain OuterRef() references to the outer
  1428. query which will be resolved when it is applied to that query.
  1429. """
  1430. template = "(%(subquery)s)"
  1431. contains_aggregate = False
  1432. empty_result_set_value = None
  1433. subquery = True
  1434. def __init__(self, queryset, output_field=None, **extra):
  1435. # Allow the usage of both QuerySet and sql.Query objects.
  1436. self.query = getattr(queryset, "query", queryset).clone()
  1437. self.query.subquery = True
  1438. self.extra = extra
  1439. super().__init__(output_field)
  1440. def get_source_expressions(self):
  1441. return [self.query]
  1442. def set_source_expressions(self, exprs):
  1443. self.query = exprs[0]
  1444. def _resolve_output_field(self):
  1445. return self.query.output_field
  1446. def copy(self):
  1447. clone = super().copy()
  1448. clone.query = clone.query.clone()
  1449. return clone
  1450. @property
  1451. def external_aliases(self):
  1452. return self.query.external_aliases
  1453. def get_external_cols(self):
  1454. return self.query.get_external_cols()
  1455. def as_sql(self, compiler, connection, template=None, **extra_context):
  1456. connection.ops.check_expression_support(self)
  1457. template_params = {**self.extra, **extra_context}
  1458. subquery_sql, sql_params = self.query.as_sql(compiler, connection)
  1459. template_params["subquery"] = subquery_sql[1:-1]
  1460. template = template or template_params.get("template", self.template)
  1461. sql = template % template_params
  1462. return sql, sql_params
  1463. def get_group_by_cols(self):
  1464. return self.query.get_group_by_cols(wrapper=self)
  1465. class Exists(Subquery):
  1466. template = "EXISTS(%(subquery)s)"
  1467. output_field = fields.BooleanField()
  1468. empty_result_set_value = False
  1469. def __init__(self, queryset, **kwargs):
  1470. super().__init__(queryset, **kwargs)
  1471. self.query = self.query.exists()
  1472. def select_format(self, compiler, sql, params):
  1473. # Wrap EXISTS() with a CASE WHEN expression if a database backend
  1474. # (e.g. Oracle) doesn't support boolean expression in SELECT or GROUP
  1475. # BY list.
  1476. if not compiler.connection.features.supports_boolean_expr_in_select_clause:
  1477. sql = "CASE WHEN {} THEN 1 ELSE 0 END".format(sql)
  1478. return sql, params
  1479. def as_sql(self, compiler, *args, **kwargs):
  1480. try:
  1481. return super().as_sql(compiler, *args, **kwargs)
  1482. except EmptyResultSet:
  1483. features = compiler.connection.features
  1484. if not features.supports_boolean_expr_in_select_clause:
  1485. return "1=0", ()
  1486. return compiler.compile(Value(False))
@deconstructible(path="django.db.models.OrderBy")
class OrderBy(Expression):
    """
    An ordering term: an expression plus a direction and optional NULL
    placement.
    """

    template = "%(expression)s %(ordering)s"
    conditional = False
    constraint_validation_compatible = False

    def __init__(self, expression, descending=False, nulls_first=None, nulls_last=None):
        """
        Arguments:
         * expression: an object exposing ``resolve_expression``.
         * descending: whether the ordering is DESC instead of ASC.
         * nulls_first / nulls_last: True or None; mutually exclusive.

        Raise ValueError when both null modifiers are truthy, when either is
        False, or when ``expression`` lacks ``resolve_expression``.
        """
        if nulls_first and nulls_last:
            raise ValueError("nulls_first and nulls_last are mutually exclusive")
        if nulls_first is False or nulls_last is False:
            raise ValueError("nulls_first and nulls_last values must be True or None.")
        self.nulls_first = nulls_first
        self.nulls_last = nulls_last
        self.descending = descending
        if not hasattr(expression, "resolve_expression"):
            raise ValueError("expression must be an expression type")
        self.expression = expression

    def __repr__(self):
        return "{}({}, descending={})".format(
            self.__class__.__name__, self.expression, self.descending
        )

    def set_source_expressions(self, exprs):
        self.expression = exprs[0]

    def get_source_expressions(self):
        return [self.expression]

    def as_sql(self, compiler, connection, template=None, **extra_context):
        """Return the ordering SQL (trailing whitespace stripped) and params."""
        template = template or self.template
        if connection.features.supports_order_by_nulls_modifier:
            if self.nulls_last:
                template = "%s NULLS LAST" % template
            elif self.nulls_first:
                template = "%s NULLS FIRST" % template
        else:
            # Without a NULLS modifier, prepend an "IS NULL" / "IS NOT NULL"
            # sort key, unless the condition built from the direction and the
            # backend's order_by_nulls_first flag makes it unnecessary.
            if self.nulls_last and not (
                self.descending and connection.features.order_by_nulls_first
            ):
                template = "%%(expression)s IS NULL, %s" % template
            elif self.nulls_first and not (
                not self.descending and connection.features.order_by_nulls_first
            ):
                template = "%%(expression)s IS NOT NULL, %s" % template
        connection.ops.check_expression_support(self)
        expression_sql, params = compiler.compile(self.expression)
        placeholders = {
            "expression": expression_sql,
            "ordering": "DESC" if self.descending else "ASC",
            **extra_context,
        }
        # The expression SQL is interpolated once per placeholder occurrence,
        # so its params are repeated the same number of times.
        params *= template.count("%(expression)s")
        return (template % placeholders).rstrip(), params

    def as_oracle(self, compiler, connection):
        # Oracle < 23c doesn't allow ORDER BY EXISTS() or filters unless it's
        # wrapped in a CASE WHEN.
        if (
            not connection.features.supports_boolean_expr_in_select_clause
            and connection.ops.conditional_expression_supported_in_where_clause(
                self.expression
            )
        ):
            # Compile a copy so this instance's expression is left untouched.
            copy = self.copy()
            copy.expression = Case(
                When(self.expression, then=True),
                default=False,
            )
            return copy.as_sql(compiler, connection)
        return self.as_sql(compiler, connection)

    def get_group_by_cols(self):
        """Return the combined GROUP BY columns of the source expressions."""
        cols = []
        for source in self.get_source_expressions():
            cols.extend(source.get_group_by_cols())
        return cols

    def reverse_ordering(self):
        """
        Flip the direction and swap the null placement in place, then return
        this same instance.
        """
        self.descending = not self.descending
        if self.nulls_first:
            self.nulls_last = True
            self.nulls_first = None
        elif self.nulls_last:
            self.nulls_first = True
            self.nulls_last = None
        return self

    def asc(self):
        """Switch this instance to ascending order in place (returns None)."""
        self.descending = False

    def desc(self):
        """Switch this instance to descending order in place (returns None)."""
        self.descending = True
  1570. class Window(SQLiteNumericMixin, Expression):
  1571. template = "%(expression)s OVER (%(window)s)"
  1572. # Although the main expression may either be an aggregate or an
  1573. # expression with an aggregate function, the GROUP BY that will
  1574. # be introduced in the query as a result is not desired.
  1575. contains_aggregate = False
  1576. contains_over_clause = True
  1577. def __init__(
  1578. self,
  1579. expression,
  1580. partition_by=None,
  1581. order_by=None,
  1582. frame=None,
  1583. output_field=None,
  1584. ):
  1585. self.partition_by = partition_by
  1586. self.order_by = order_by
  1587. self.frame = frame
  1588. if not getattr(expression, "window_compatible", False):
  1589. raise ValueError(
  1590. "Expression '%s' isn't compatible with OVER clauses."
  1591. % expression.__class__.__name__
  1592. )
  1593. if self.partition_by is not None:
  1594. if not isinstance(self.partition_by, (tuple, list)):
  1595. self.partition_by = (self.partition_by,)
  1596. self.partition_by = ExpressionList(*self.partition_by)
  1597. if self.order_by is not None:
  1598. if isinstance(self.order_by, (list, tuple)):
  1599. self.order_by = OrderByList(*self.order_by)
  1600. elif isinstance(self.order_by, (BaseExpression, str)):
  1601. self.order_by = OrderByList(self.order_by)
  1602. else:
  1603. raise ValueError(
  1604. "Window.order_by must be either a string reference to a "
  1605. "field, an expression, or a list or tuple of them."
  1606. )
  1607. super().__init__(output_field=output_field)
  1608. self.source_expression = self._parse_expressions(expression)[0]
  1609. def _resolve_output_field(self):
  1610. return self.source_expression.output_field
  1611. def get_source_expressions(self):
  1612. return [self.source_expression, self.partition_by, self.order_by, self.frame]
  1613. def set_source_expressions(self, exprs):
  1614. self.source_expression, self.partition_by, self.order_by, self.frame = exprs
  1615. def as_sql(self, compiler, connection, template=None):
  1616. connection.ops.check_expression_support(self)
  1617. if not connection.features.supports_over_clause:
  1618. raise NotSupportedError("This backend does not support window expressions.")
  1619. expr_sql, params = compiler.compile(self.source_expression)
  1620. window_sql, window_params = [], ()
  1621. if self.partition_by is not None:
  1622. sql_expr, sql_params = self.partition_by.as_sql(
  1623. compiler=compiler,
  1624. connection=connection,
  1625. template="PARTITION BY %(expressions)s",
  1626. )
  1627. window_sql.append(sql_expr)
  1628. window_params += tuple(sql_params)
  1629. if self.order_by is not None:
  1630. order_sql, order_params = compiler.compile(self.order_by)
  1631. window_sql.append(order_sql)
  1632. window_params += tuple(order_params)
  1633. if self.frame:
  1634. frame_sql, frame_params = compiler.compile(self.frame)
  1635. window_sql.append(frame_sql)
  1636. window_params += tuple(frame_params)
  1637. template = template or self.template
  1638. return (
  1639. template % {"expression": expr_sql, "window": " ".join(window_sql).strip()},
  1640. (*params, *window_params),
  1641. )
  1642. def as_sqlite(self, compiler, connection):
  1643. if isinstance(self.output_field, fields.DecimalField):
  1644. # Casting to numeric must be outside of the window expression.
  1645. copy = self.copy()
  1646. source_expressions = copy.get_source_expressions()
  1647. source_expressions[0].output_field = fields.FloatField()
  1648. copy.set_source_expressions(source_expressions)
  1649. return super(Window, copy).as_sqlite(compiler, connection)
  1650. return self.as_sql(compiler, connection)
  1651. def __str__(self):
  1652. return "{} OVER ({}{}{})".format(
  1653. str(self.source_expression),
  1654. "PARTITION BY " + str(self.partition_by) if self.partition_by else "",
  1655. str(self.order_by or ""),
  1656. str(self.frame or ""),
  1657. )
  1658. def __repr__(self):
  1659. return "<%s: %s>" % (self.__class__.__name__, self)
  1660. def get_group_by_cols(self):
  1661. group_by_cols = []
  1662. if self.partition_by:
  1663. group_by_cols.extend(self.partition_by.get_group_by_cols())
  1664. if self.order_by is not None:
  1665. group_by_cols.extend(self.order_by.get_group_by_cols())
  1666. return group_by_cols
  1667. class WindowFrameExclusion(Enum):
  1668. CURRENT_ROW = "CURRENT ROW"
  1669. GROUP = "GROUP"
  1670. TIES = "TIES"
  1671. NO_OTHERS = "NO OTHERS"
  1672. def __repr__(self):
  1673. return f"{self.__class__.__qualname__}.{self._name_}"
  1674. class WindowFrame(Expression):
  1675. """
  1676. Model the frame clause in window expressions. There are two types of frame
  1677. clauses which are subclasses, however, all processing and validation (by no
  1678. means intended to be complete) is done here. Thus, providing an end for a
  1679. frame is optional (the default is UNBOUNDED FOLLOWING, which is the last
  1680. row in the frame).
  1681. """
  1682. template = "%(frame_type)s BETWEEN %(start)s AND %(end)s%(exclude)s"
  1683. def __init__(self, start=None, end=None, exclusion=None):
  1684. self.start = Value(start)
  1685. self.end = Value(end)
  1686. if not isinstance(exclusion, (NoneType, WindowFrameExclusion)):
  1687. raise TypeError(
  1688. f"{self.__class__.__qualname__}.exclusion must be a "
  1689. "WindowFrameExclusion instance."
  1690. )
  1691. self.exclusion = exclusion
  1692. def set_source_expressions(self, exprs):
  1693. self.start, self.end = exprs
  1694. def get_source_expressions(self):
  1695. return [self.start, self.end]
  1696. def get_exclusion(self):
  1697. if self.exclusion is None:
  1698. return ""
  1699. return f" EXCLUDE {self.exclusion.value}"
  1700. def as_sql(self, compiler, connection):
  1701. connection.ops.check_expression_support(self)
  1702. start, end = self.window_frame_start_end(
  1703. connection, self.start.value, self.end.value
  1704. )
  1705. if self.exclusion and not connection.features.supports_frame_exclusion:
  1706. raise NotSupportedError(
  1707. "This backend does not support window frame exclusions."
  1708. )
  1709. return (
  1710. self.template
  1711. % {
  1712. "frame_type": self.frame_type,
  1713. "start": start,
  1714. "end": end,
  1715. "exclude": self.get_exclusion(),
  1716. },
  1717. [],
  1718. )
  1719. def __repr__(self):
  1720. return "<%s: %s>" % (self.__class__.__name__, self)
  1721. def get_group_by_cols(self):
  1722. return []
  1723. def __str__(self):
  1724. if self.start.value is not None and self.start.value < 0:
  1725. start = "%d %s" % (abs(self.start.value), connection.ops.PRECEDING)
  1726. elif self.start.value is not None and self.start.value == 0:
  1727. start = connection.ops.CURRENT_ROW
  1728. elif self.start.value is not None and self.start.value > 0:
  1729. start = "%d %s" % (self.start.value, connection.ops.FOLLOWING)
  1730. else:
  1731. start = connection.ops.UNBOUNDED_PRECEDING
  1732. if self.end.value is not None and self.end.value > 0:
  1733. end = "%d %s" % (self.end.value, connection.ops.FOLLOWING)
  1734. elif self.end.value is not None and self.end.value == 0:
  1735. end = connection.ops.CURRENT_ROW
  1736. elif self.end.value is not None and self.end.value < 0:
  1737. end = "%d %s" % (abs(self.end.value), connection.ops.PRECEDING)
  1738. else:
  1739. end = connection.ops.UNBOUNDED_FOLLOWING
  1740. return self.template % {
  1741. "frame_type": self.frame_type,
  1742. "start": start,
  1743. "end": end,
  1744. "exclude": self.get_exclusion(),
  1745. }
  1746. def window_frame_start_end(self, connection, start, end):
  1747. raise NotImplementedError("Subclasses must implement window_frame_start_end().")
  1748. class RowRange(WindowFrame):
  1749. frame_type = "ROWS"
  1750. def window_frame_start_end(self, connection, start, end):
  1751. return connection.ops.window_frame_rows_start_end(start, end)
  1752. class ValueRange(WindowFrame):
  1753. frame_type = "RANGE"
  1754. def window_frame_start_end(self, connection, start, end):
  1755. return connection.ops.window_frame_range_start_end(start, end)