| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761
77717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043204420452046204720482049205020512052205320542055205620572058205920602061206220632064206520662067206820692070207120722073207420752076207720782079208020812082208320842085 |
- import copy
- import datetime
- import functools
- import inspect
- from collections import defaultdict
- from decimal import Decimal
- from enum import Enum
- from itertools import chain
- from types import NoneType
- from uuid import UUID
- from django.core.exceptions import EmptyResultSet, FieldError, FullResultSet
- from django.db import DatabaseError, NotSupportedError, connection
- from django.db.models import fields
- from django.db.models.constants import LOOKUP_SEP
- from django.db.models.query_utils import Q
- from django.utils.deconstruct import deconstructible
- from django.utils.functional import cached_property, classproperty
- from django.utils.hashable import make_hashable
class SQLiteNumericMixin:
    """
    Mixin adding a SQLite-specific compilation step for decimal expressions.

    Some expressions with output_field=DecimalField() must be cast to
    numeric to be properly filtered.
    """

    def as_sqlite(self, compiler, connection, **extra_context):
        """
        Compile through as_sql() and, when the output field's internal type
        is "DecimalField", wrap the SQL in ``CAST(... AS NUMERIC)``.

        If the output field cannot be resolved (FieldError), the SQL produced
        by as_sql() is returned unchanged.
        """
        sql, params = self.as_sql(compiler, connection, **extra_context)
        try:
            internal_type = self.output_field.get_internal_type()
        except FieldError:
            # Unknown output type: nothing to cast.
            return sql, params
        if internal_type == "DecimalField":
            sql = "(CAST(%s AS NUMERIC))" % sql
        return sql, params
class Combinable:
    """
    Provide the ability to combine one or two objects with
    some connector. For example F('foo') + F('bar').

    Arithmetic operators build a CombinedExpression. The Python ``&``, ``|``
    and ``^`` operators are reserved for boolean (Q) composition of
    conditional expressions; bitwise SQL operations are exposed through the
    explicit .bitand(), .bitor(), .bitxor(), .bitleftshift() and
    .bitrightshift() methods.
    """

    # Arithmetic connectors
    ADD = "+"
    SUB = "-"
    MUL = "*"
    DIV = "/"
    POW = "^"
    # The following is a quoted % operator - it is quoted because it can be
    # used in strings that also have parameter substitution.
    MOD = "%%"

    # Bitwise operators - note that these are generated by .bitand()
    # and .bitor(), the '&' and '|' are reserved for boolean operator
    # usage.
    BITAND = "&"
    BITOR = "|"
    BITLEFTSHIFT = "<<"
    BITRIGHTSHIFT = ">>"
    BITXOR = "#"

    def _combine(self, other, connector, reversed):
        """
        Return a CombinedExpression joining self and other with connector.

        Operands that are not resolvable expressions are wrapped in Value().
        When reversed is true, other becomes the left-hand side.
        """
        # everything must be resolvable to an expression
        operand = other if hasattr(other, "resolve_expression") else Value(other)
        left, right = (operand, self) if reversed else (self, operand)
        return CombinedExpression(left, connector, right)

    #############
    # OPERATORS #
    #############

    def __neg__(self):
        # Negation is expressed as multiplication by -1.
        return self._combine(-1, self.MUL, False)

    def __add__(self, other):
        return self._combine(other, self.ADD, False)

    def __sub__(self, other):
        return self._combine(other, self.SUB, False)

    def __mul__(self, other):
        return self._combine(other, self.MUL, False)

    def __truediv__(self, other):
        return self._combine(other, self.DIV, False)

    def __mod__(self, other):
        return self._combine(other, self.MOD, False)

    def __pow__(self, other):
        return self._combine(other, self.POW, False)

    def __and__(self, other):
        """Boolean AND of two conditional expressions, as a Q object."""
        if not (
            getattr(self, "conditional", False)
            and getattr(other, "conditional", False)
        ):
            raise NotImplementedError(
                "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
            )
        return Q(self) & Q(other)

    def bitand(self, other):
        return self._combine(other, self.BITAND, False)

    def bitleftshift(self, other):
        return self._combine(other, self.BITLEFTSHIFT, False)

    def bitrightshift(self, other):
        return self._combine(other, self.BITRIGHTSHIFT, False)

    def __xor__(self, other):
        """Boolean XOR of two conditional expressions, as a Q object."""
        if not (
            getattr(self, "conditional", False)
            and getattr(other, "conditional", False)
        ):
            raise NotImplementedError(
                "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
            )
        return Q(self) ^ Q(other)

    def bitxor(self, other):
        return self._combine(other, self.BITXOR, False)

    def __or__(self, other):
        """Boolean OR of two conditional expressions, as a Q object."""
        if not (
            getattr(self, "conditional", False)
            and getattr(other, "conditional", False)
        ):
            raise NotImplementedError(
                "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
            )
        return Q(self) | Q(other)

    def bitor(self, other):
        return self._combine(other, self.BITOR, False)

    # Reflected arithmetic operators: ``other`` ends up on the left-hand side.

    def __radd__(self, other):
        return self._combine(other, self.ADD, True)

    def __rsub__(self, other):
        return self._combine(other, self.SUB, True)

    def __rmul__(self, other):
        return self._combine(other, self.MUL, True)

    def __rtruediv__(self, other):
        return self._combine(other, self.DIV, True)

    def __rmod__(self, other):
        return self._combine(other, self.MOD, True)

    def __rpow__(self, other):
        return self._combine(other, self.POW, True)

    # Reflected logical operators are never supported.

    def __rand__(self, other):
        raise NotImplementedError(
            "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
        )

    def __ror__(self, other):
        raise NotImplementedError(
            "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
        )

    def __rxor__(self, other):
        raise NotImplementedError(
            "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
        )

    def __invert__(self):
        return NegatedExpression(self)
class BaseExpression:
    """Base class for all query expressions."""

    empty_result_set_value = NotImplemented
    # aggregate specific fields
    is_summary = False
    _output_field_resolved_to_none = False
    # Can the expression be used in a WHERE clause?
    filterable = True
    # Can the expression be used as a source expression in Window?
    window_compatible = False
    # Can the expression be used as a database default value?
    allowed_default = False
    # Can the expression be used during a constraint validation?
    constraint_validation_compatible = True

    def __init__(self, output_field=None):
        # Only shadow the lazily computed ``output_field`` property when an
        # explicit field is supplied.
        if output_field is not None:
            self.output_field = output_field

    def __getstate__(self):
        # The cached ``convert_value`` may be a lambda, which cannot be
        # pickled; drop it so it is recomputed on demand.
        state = self.__dict__.copy()
        state.pop("convert_value", None)
        return state

    def get_db_converters(self, connection):
        """
        Return this expression's own converter (unless it is the no-op)
        followed by the converters of its output field.
        """
        return (
            []
            if self.convert_value is self._convert_value_noop
            else [self.convert_value]
        ) + self.output_field.get_db_converters(connection)

    def get_source_expressions(self):
        """Return the list of sub-expressions; none for the base class."""
        return []

    def set_source_expressions(self, exprs):
        """Replace the sub-expressions; the base class accepts only none."""
        assert not exprs

    def _parse_expressions(self, *expressions):
        """
        Normalize arguments into expressions: resolvable objects are kept,
        strings become F() references, and anything else becomes a Value().
        """
        return [
            (
                arg
                if hasattr(arg, "resolve_expression")
                else (F(arg) if isinstance(arg, str) else Value(arg))
            )
            for arg in expressions
        ]

    def as_sql(self, compiler, connection):
        """
        Responsible for returning a (sql, [params]) tuple to be included
        in the current query.

        Different backends can provide their own implementation, by
        providing an `as_{vendor}` method and patching the Expression:

        ```
        def override_as_sql(self, compiler, connection):
            # custom logic
            return super().as_sql(compiler, connection)
        setattr(Expression, 'as_' + connection.vendor, override_as_sql)
        ```

        Arguments:
         * compiler: the query compiler responsible for generating the query.
           Must have a compile method, returning a (sql, [params]) tuple.
           Calling compiler(value) will return a quoted `value`.

         * connection: the database connection used for the current query.

        Return: (sql, params)
          Where `sql` is a string containing ordered sql parameters to be
          replaced with the elements of the list `params`.
        """
        raise NotImplementedError("Subclasses must implement as_sql()")

    @cached_property
    def contains_aggregate(self):
        return any(
            expr and expr.contains_aggregate for expr in self.get_source_expressions()
        )

    @cached_property
    def contains_over_clause(self):
        return any(
            expr and expr.contains_over_clause for expr in self.get_source_expressions()
        )

    @cached_property
    def contains_column_references(self):
        return any(
            expr and expr.contains_column_references
            for expr in self.get_source_expressions()
        )

    @cached_property
    def contains_subquery(self):
        return any(
            expr and (getattr(expr, "subquery", False) or expr.contains_subquery)
            for expr in self.get_source_expressions()
        )

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        """
        Provide the chance to do any preprocessing or validation before being
        added to the query.

        Arguments:
         * query: the backend query implementation
         * allow_joins: boolean allowing or denying use of joins
           in this query
         * reuse: a set of reusable joins for multijoins
         * summarize: a terminal aggregate clause
         * for_save: whether this expression about to be used in a save or update

        Return: an Expression to be added to the query.
        """
        c = self.copy()
        c.is_summary = summarize
        c.set_source_expressions(
            [
                (
                    # Forward for_save as well so nested expressions are
                    # resolved with the same context as this one.
                    expr.resolve_expression(
                        query, allow_joins, reuse, summarize, for_save
                    )
                    if expr
                    else None
                )
                for expr in c.get_source_expressions()
            ]
        )
        return c

    @property
    def conditional(self):
        return isinstance(self.output_field, fields.BooleanField)

    @property
    def field(self):
        return self.output_field

    @cached_property
    def output_field(self):
        """
        Return the output type of this expression.

        Raise FieldError when the type cannot be inferred.
        """
        output_field = self._resolve_output_field()
        if output_field is None:
            # Remembered so _output_field_or_none can tell "no type" apart
            # from other FieldErrors raised during resolution.
            self._output_field_resolved_to_none = True
            raise FieldError("Cannot resolve expression type, unknown output_field")
        return output_field

    @cached_property
    def _output_field_or_none(self):
        """
        Return the output field of this expression, or None if
        _resolve_output_field() didn't return an output type.
        """
        try:
            return self.output_field
        except FieldError:
            if not self._output_field_resolved_to_none:
                raise

    def _resolve_output_field(self):
        """
        Attempt to infer the output type of the expression.

        As a guess, if the output fields of all source fields match then simply
        infer the same type here.

        If a source's output field resolves to None, exclude it from this check.
        If all sources are None, then an error is raised higher up the stack in
        the output_field property.
        """
        # This guess is mostly a bad idea, but there is quite a lot of code
        # (especially 3rd party Func subclasses) that depend on it, we'd need a
        # deprecation path to fix it.
        sources_iter = (
            source for source in self.get_source_fields() if source is not None
        )
        # The outer loop takes the first source; the inner loop consumes the
        # remainder of the same generator and compares each against it.
        for output_field in sources_iter:
            for source in sources_iter:
                if not isinstance(output_field, source.__class__):
                    raise FieldError(
                        "Expression contains mixed types: %s, %s. You must "
                        "set output_field."
                        % (
                            output_field.__class__.__name__,
                            source.__class__.__name__,
                        )
                    )
            return output_field

    @staticmethod
    def _convert_value_noop(value, expression, connection):
        return value

    @cached_property
    def convert_value(self):
        """
        Expressions provide their own converters because users have the option
        of manually specifying the output_field which may be a different type
        from the one the database returns.
        """
        field = self.output_field
        internal_type = field.get_internal_type()
        if internal_type == "FloatField":
            return lambda value, expression, connection: (
                None if value is None else float(value)
            )
        elif internal_type.endswith("IntegerField"):
            return lambda value, expression, connection: (
                None if value is None else int(value)
            )
        elif internal_type == "DecimalField":
            return lambda value, expression, connection: (
                None if value is None else Decimal(value)
            )
        return self._convert_value_noop

    def get_lookup(self, lookup):
        return self.output_field.get_lookup(lookup)

    def get_transform(self, name):
        return self.output_field.get_transform(name)

    def relabeled_clone(self, change_map):
        """Return a copy whose sources are relabeled with change_map."""
        clone = self.copy()
        clone.set_source_expressions(
            [
                e.relabeled_clone(change_map) if e is not None else None
                for e in self.get_source_expressions()
            ]
        )
        return clone

    def replace_expressions(self, replacements):
        """
        Return this expression with any (sub-)expression found in the
        replacements mapping swapped for its replacement. Return self when
        nothing can change.
        """
        if not replacements:
            return self
        if replacement := replacements.get(self):
            return replacement
        if not (source_expressions := self.get_source_expressions()):
            return self
        clone = self.copy()
        clone.set_source_expressions(
            [
                expr.replace_expressions(replacements) if expr else None
                for expr in source_expressions
            ]
        )
        return clone

    def get_refs(self):
        """Return the union of get_refs() over all source expressions."""
        refs = set()
        for expr in self.get_source_expressions():
            if expr is None:
                continue
            refs |= expr.get_refs()
        return refs

    def copy(self):
        """Return a shallow copy of this expression."""
        return copy.copy(self)

    def prefix_references(self, prefix):
        """
        Return a copy in which every F() reference among the sources (found
        recursively) has prefix prepended to its name.
        """
        clone = self.copy()
        clone.set_source_expressions(
            [
                (
                    # Missing sources are preserved as-is, matching the other
                    # traversal methods of this class.
                    None
                    if expr is None
                    else (
                        F(f"{prefix}{expr.name}")
                        if isinstance(expr, F)
                        else expr.prefix_references(prefix)
                    )
                )
                for expr in self.get_source_expressions()
            ]
        )
        return clone

    def get_group_by_cols(self):
        """
        Return the expressions to group by: the expression itself when it
        contains no aggregate, otherwise the columns of its sources.
        """
        if not self.contains_aggregate:
            return [self]
        cols = []
        for source in self.get_source_expressions():
            # Missing sources contribute no columns.
            if source is not None:
                cols.extend(source.get_group_by_cols())
        return cols

    def get_source_fields(self):
        """Return the underlying field types used by this expression."""
        return [e._output_field_or_none for e in self.get_source_expressions()]

    def asc(self, **kwargs):
        return OrderBy(self, **kwargs)

    def desc(self, **kwargs):
        return OrderBy(self, descending=True, **kwargs)

    def reverse_ordering(self):
        return self

    def flatten(self):
        """
        Recursively yield this expression and all subexpressions, in
        depth-first order.
        """
        yield self
        for expr in self.get_source_expressions():
            if expr:
                if hasattr(expr, "flatten"):
                    yield from expr.flatten()
                else:
                    yield expr

    def select_format(self, compiler, sql, params):
        """
        Custom format for select clauses. For example, EXISTS expressions need
        to be wrapped in CASE WHEN on Oracle.
        """
        if hasattr(self.output_field, "select_format"):
            return self.output_field.select_format(compiler, sql, params)
        return sql, params

    def get_expression_for_validation(self):
        """
        Return the expression to use during constraint validation: self, or
        the single source expression when this expression is flagged as not
        constraint-validation compatible.

        Raise ValueError if such an expression does not have exactly one
        source expression.
        """
        # Ignore expressions that cannot be used during a constraint validation.
        if not getattr(self, "constraint_validation_compatible", True):
            try:
                (expression,) = self.get_source_expressions()
            except ValueError as e:
                raise ValueError(
                    "Expressions with constraint_validation_compatible set to False "
                    "must have only one source expression."
                ) from e
            else:
                return expression
        return self
@deconstructible
class Expression(BaseExpression, Combinable):
    """An expression that can be combined with other expressions."""

    @classproperty
    @functools.lru_cache(maxsize=128)
    def _constructor_signature(cls):
        """Signature of the class's __init__, cached per class."""
        return inspect.signature(cls.__init__)

    @cached_property
    def identity(self):
        """
        Hashable tuple identifying this expression: its class followed by
        (argument name, hashable value) pairs for every constructor argument,
        with defaults applied.
        """
        args, kwargs = self._constructor_args
        bound = self._constructor_signature.bind_partial(self, *args, **kwargs)
        bound.apply_defaults()
        # The first bound argument is ``self``; it is not part of the identity.
        named_values = list(bound.arguments.items())[1:]
        parts = [self.__class__]
        for name, value in named_values:
            if not isinstance(value, fields.Field):
                hashable = make_hashable(value)
            elif value.name and value.model:
                # Model-bound fields are identified by model label and name.
                hashable = (value.model._meta.label, value.name)
            else:
                hashable = type(value)
            parts.append((name, hashable))
        return tuple(parts)

    def __eq__(self, other):
        if isinstance(other, Expression):
            return other.identity == self.identity
        return NotImplemented

    def __hash__(self):
        return hash(self.identity)
# Type inference for CombinedExpression.output_field.
# Missing items will result in FieldError, by design.
#
# The current approach for NULL is based on lowest common denominator behavior
# i.e. if one of the supported databases is raising an error (rather than
# return NULL) for `val <op> NULL`, then Django raises FieldError.
#
# Each dict maps a connector to a list of (lhs type, rhs type, result type)
# triples. The order of the dicts and of their entries is significant: it is
# the order in which the combinations get registered.
_connector_combinations = [
    # Numeric operations - operands of same type.
    # PositiveIntegerField should take precedence over IntegerField (except
    # subtraction).
    {
        connector: [(fields.PositiveIntegerField,) * 3]
        for connector in (
            Combinable.ADD,
            Combinable.MUL,
            Combinable.DIV,
            Combinable.MOD,
            Combinable.POW,
        )
    },
    # Other numeric operands.
    {
        connector: [
            (fields.IntegerField,) * 3,
            (fields.FloatField,) * 3,
            (fields.DecimalField,) * 3,
        ]
        for connector in (
            Combinable.ADD,
            Combinable.SUB,
            Combinable.MUL,
            # Behavior for DIV with integer arguments follows Postgres/SQLite,
            # not MySQL/Oracle.
            Combinable.DIV,
            Combinable.MOD,
            Combinable.POW,
        )
    },
    # Numeric operations - operands of different type.
    {
        connector: [
            (fields.IntegerField, fields.DecimalField, fields.DecimalField),
            (fields.DecimalField, fields.IntegerField, fields.DecimalField),
            (fields.IntegerField, fields.FloatField, fields.FloatField),
            (fields.FloatField, fields.IntegerField, fields.FloatField),
        ]
        for connector in (
            Combinable.ADD,
            Combinable.SUB,
            Combinable.MUL,
            Combinable.DIV,
            Combinable.MOD,
        )
    },
    # Bitwise operators.
    {
        connector: [(fields.IntegerField,) * 3]
        for connector in (
            Combinable.BITAND,
            Combinable.BITOR,
            Combinable.BITLEFTSHIFT,
            Combinable.BITRIGHTSHIFT,
            Combinable.BITXOR,
        )
    },
    # Numeric with NULL.
    {
        connector: [
            combination
            for field_type in (
                fields.IntegerField,
                fields.DecimalField,
                fields.FloatField,
            )
            for combination in (
                (field_type, NoneType, field_type),
                (NoneType, field_type, field_type),
            )
        ]
        for connector in (
            Combinable.ADD,
            Combinable.SUB,
            Combinable.MUL,
            Combinable.DIV,
            Combinable.MOD,
            Combinable.POW,
        )
    },
    # Date/DateTimeField/DurationField/TimeField.
    {
        Combinable.ADD: [
            # Date/DateTimeField.
            (fields.DateField, fields.DurationField, fields.DateTimeField),
            (fields.DateTimeField, fields.DurationField, fields.DateTimeField),
            (fields.DurationField, fields.DateField, fields.DateTimeField),
            (fields.DurationField, fields.DateTimeField, fields.DateTimeField),
            # DurationField.
            (fields.DurationField,) * 3,
            # TimeField.
            (fields.TimeField, fields.DurationField, fields.TimeField),
            (fields.DurationField, fields.TimeField, fields.TimeField),
        ],
    },
    {
        Combinable.SUB: [
            # Date/DateTimeField.
            (fields.DateField, fields.DurationField, fields.DateTimeField),
            (fields.DateTimeField, fields.DurationField, fields.DateTimeField),
            (fields.DateField, fields.DateField, fields.DurationField),
            (fields.DateField, fields.DateTimeField, fields.DurationField),
            (fields.DateTimeField, fields.DateField, fields.DurationField),
            (fields.DateTimeField, fields.DateTimeField, fields.DurationField),
            # DurationField.
            (fields.DurationField,) * 3,
            # TimeField.
            (fields.TimeField, fields.DurationField, fields.TimeField),
            (fields.TimeField, fields.TimeField, fields.DurationField),
        ],
    },
]
# Registry mapping each connector to the list of (lhs, rhs, result) type
# triples registered for it, in registration order.
_connector_combinators = defaultdict(list)


def register_combinable_fields(lhs, connector, rhs, result):
    """
    Register combinable types:
        lhs <connector> rhs -> result
    e.g.
        register_combinable_fields(
            IntegerField, Combinable.ADD, FloatField, FloatField
        )
    """
    combination = (lhs, rhs, result)
    _connector_combinators[connector].append(combination)
# Populate the registry with the built-in combinations, preserving the order
# in which they are declared.
for d in _connector_combinations:
    for connector, field_types in d.items():
        for lhs, rhs, result in field_types:
            register_combinable_fields(
                lhs=lhs, connector=connector, rhs=rhs, result=result
            )
@functools.lru_cache(maxsize=128)
def _resolve_combined_type(connector, lhs_type, rhs_type):
    """
    Return the result type registered for ``lhs_type <connector> rhs_type``.

    Registered combinations are checked in registration order and the first
    one whose lhs and rhs types are superclasses of (or equal to) the given
    types wins. Return None when nothing matches.
    """
    for candidate_lhs, candidate_rhs, combined_type in _connector_combinators.get(
        connector, ()
    ):
        if not issubclass(lhs_type, candidate_lhs):
            continue
        if issubclass(rhs_type, candidate_rhs):
            return combined_type
    return None
class CombinedExpression(SQLiteNumericMixin, Expression):
    """Two expressions joined by a connector, e.g. ``lhs + rhs``."""

    def __init__(self, lhs, connector, rhs, output_field=None):
        super().__init__(output_field=output_field)
        self.connector = connector
        self.lhs = lhs
        self.rhs = rhs

    def __repr__(self):
        return f"<{self.__class__.__name__}: {self}>"

    def __str__(self):
        return f"{self.lhs} {self.connector} {self.rhs}"

    def get_source_expressions(self):
        return [self.lhs, self.rhs]

    def set_source_expressions(self, exprs):
        self.lhs, self.rhs = exprs

    def _resolve_output_field(self):
        """
        Infer the output field from the registered type combinations for this
        connector and the operand field types. Raise FieldError when no
        combination matches.
        """
        # We avoid using super() here for reasons given in
        # Expression._resolve_output_field()
        combined_type = _resolve_combined_type(
            self.connector,
            type(self.lhs._output_field_or_none),
            type(self.rhs._output_field_or_none),
        )
        if combined_type is not None:
            return combined_type()
        raise FieldError(
            f"Cannot infer type of {self.connector!r} expression involving these "
            f"types: {self.lhs.output_field.__class__.__name__}, "
            f"{self.rhs.output_field.__class__.__name__}. You must set "
            f"output_field."
        )

    def as_sql(self, compiler, connection):
        """Compile both operands and join them with the connector."""
        fragments = []
        all_params = []
        for operand in (self.lhs, self.rhs):
            operand_sql, operand_params = compiler.compile(operand)
            fragments.append(operand_sql)
            all_params.extend(operand_params)
        combined = connection.ops.combine_expression(self.connector, fragments)
        # Parenthesize to preserve the order of precedence.
        return "(%s)" % combined, all_params

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        """
        Resolve both operands and return a resolved copy.

        Plain combined expressions are swapped for a DurationExpression when
        exactly one operand is a DurationField, or for a TemporalSubtraction
        when subtracting two operands of the same date/time field type.
        """
        resolve_args = (query, allow_joins, reuse, summarize, for_save)
        lhs = self.lhs.resolve_expression(*resolve_args)
        rhs = self.rhs.resolve_expression(*resolve_args)

        if not isinstance(self, (DurationExpression, TemporalSubtraction)):

            def internal_type_of(expr):
                # None when the operand has no resolvable output field.
                try:
                    return expr.output_field.get_internal_type()
                except (AttributeError, FieldError):
                    return None

            lhs_type = internal_type_of(lhs)
            rhs_type = internal_type_of(rhs)
            if "DurationField" in {lhs_type, rhs_type} and lhs_type != rhs_type:
                replacement = DurationExpression(self.lhs, self.connector, self.rhs)
                return replacement.resolve_expression(*resolve_args)
            datetime_fields = {"DateField", "DateTimeField", "TimeField"}
            is_temporal_subtraction = (
                self.connector == self.SUB
                and lhs_type in datetime_fields
                and lhs_type == rhs_type
            )
            if is_temporal_subtraction:
                replacement = TemporalSubtraction(self.lhs, self.rhs)
                return replacement.resolve_expression(*resolve_args)

        c = self.copy()
        c.is_summary = summarize
        c.lhs = lhs
        c.rhs = rhs
        return c

    @cached_property
    def allowed_default(self):
        return self.lhs.allowed_default and self.rhs.allowed_default
class DurationExpression(CombinedExpression):
    """Combined expression in which one operand is a DurationField."""

    def compile(self, side, compiler, connection):
        """
        Compile one operand, passing its SQL through
        ops.format_for_duration_arithmetic() when it is a DurationField.
        """
        try:
            output = side.output_field
        except FieldError:
            # Unknown output type: compile without duration formatting.
            return compiler.compile(side)
        is_duration = output.get_internal_type() == "DurationField"
        sql, params = compiler.compile(side)
        if is_duration:
            sql = connection.ops.format_for_duration_arithmetic(sql)
        return sql, params

    def as_sql(self, compiler, connection):
        if connection.features.has_native_duration_field:
            return super().as_sql(compiler, connection)
        connection.ops.check_expression_support(self)
        fragments = []
        all_params = []
        for side in (self.lhs, self.rhs):
            side_sql, side_params = self.compile(side, compiler, connection)
            fragments.append(side_sql)
            all_params.extend(side_params)
        combined = connection.ops.combine_duration_expression(
            self.connector, fragments
        )
        # Parenthesize to preserve the order of precedence.
        return "(%s)" % combined, all_params

    def as_sqlite(self, compiler, connection, **extra_context):
        """
        Compile as usual, then reject multiplication/division whose operand
        types are outside the allowed set by raising DatabaseError.
        """
        sql, params = self.as_sql(compiler, connection, **extra_context)
        if self.connector not in {Combinable.MUL, Combinable.DIV}:
            return sql, params
        try:
            lhs_type = self.lhs.output_field.get_internal_type()
            rhs_type = self.rhs.output_field.get_internal_type()
        except (AttributeError, FieldError):
            # Operand types unknown: skip the validation.
            return sql, params
        allowed_fields = {
            "DecimalField",
            "DurationField",
            "FloatField",
            "IntegerField",
        }
        if not (lhs_type in allowed_fields and rhs_type in allowed_fields):
            raise DatabaseError(f"Invalid arguments for operator {self.connector}.")
        return sql, params
class TemporalSubtraction(CombinedExpression):
    """Subtraction of two temporal expressions, yielding a duration."""

    output_field = fields.DurationField()

    def __init__(self, lhs, rhs):
        # The connector is always subtraction.
        super().__init__(lhs, self.SUB, rhs)

    def as_sql(self, compiler, connection):
        connection.ops.check_expression_support(self)
        compiled_lhs = compiler.compile(self.lhs)
        compiled_rhs = compiler.compile(self.rhs)
        internal_type = self.lhs.output_field.get_internal_type()
        return connection.ops.subtract_temporals(
            internal_type, compiled_lhs, compiled_rhs
        )
@deconstructible(path="django.db.models.F")
class F(Combinable):
    """An object capable of resolving references to existing query objects."""

    allowed_default = False

    def __init__(self, name):
        """
        Arguments:
         * name: the name of the field this expression references
        """
        self.name = name

    def __repr__(self):
        return f"{self.__class__.__name__}({self.name})"

    def __getitem__(self, subscript):
        return Sliced(self, subscript)

    def __contains__(self, other):
        # Disable old-style iteration protocol inherited from implementing
        # __getitem__() to prevent this method from hanging.
        raise TypeError(f"argument of type '{self.__class__.__name__}' is not iterable")

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        """Resolve the referenced name through query.resolve_ref()."""
        return query.resolve_ref(self.name, allow_joins, reuse, summarize)

    def replace_expressions(self, replacements):
        """Return the replacement registered for this reference, else self."""
        return replacements.get(self, self)

    def asc(self, **kwargs):
        return OrderBy(self, **kwargs)

    def desc(self, **kwargs):
        return OrderBy(self, descending=True, **kwargs)

    def __eq__(self, other):
        # Equal only to references of the exact same class and name.
        if not (self.__class__ == other.__class__):
            return False
        return self.name == other.name

    def __hash__(self):
        return hash(self.name)

    def copy(self):
        """Return a shallow copy of this reference."""
        return copy.copy(self)
class ResolvedOuterRef(F):
    """
    An object that contains a reference to an outer query.

    In this case, the reference to the outer query has been resolved because
    the inner query has been used as a subquery.
    """

    contains_aggregate = False
    contains_over_clause = False

    def as_sql(self, *args, **kwargs):
        """Always raise ValueError; this reference can't be compiled itself."""
        raise ValueError(
            "This queryset contains a reference to an outer query and may "
            "only be used in a subquery."
        )

    def resolve_expression(self, *args, **kwargs):
        """
        Resolve like F, reject window expressions, and flag the result as
        possibly multivalued when the name contains LOOKUP_SEP.
        """
        resolved = super().resolve_expression(*args, **kwargs)
        if resolved.contains_over_clause:
            raise NotSupportedError(
                "Referencing outer query window expression is not supported: "
                f"{self.name}."
            )
        # FIXME: Rename possibly_multivalued to multivalued and fix detection
        # for non-multivalued JOINs (e.g. foreign key fields). This should take
        # into account only many-to-many and one-to-many relationships.
        resolved.possibly_multivalued = LOOKUP_SEP in self.name
        return resolved

    def relabeled_clone(self, relabels):
        # Relabeling is a no-op here; the same object is returned.
        return self

    def get_group_by_cols(self):
        return []
class OuterRef(F):
    """
    A reference that resolves to a ResolvedOuterRef for its name, or to the
    wrapped reference when the name is itself an instance of this class.
    """

    contains_aggregate = False
    contains_over_clause = False

    def resolve_expression(self, *args, **kwargs):
        inner = self.name
        # A nested reference is unwrapped one level rather than resolved.
        return inner if isinstance(inner, self.__class__) else ResolvedOuterRef(inner)

    def relabeled_clone(self, relabels):
        return self
class Sliced(F):
    """
    An object that contains a slice of an F expression.

    Object resolves the column on which the slicing is applied, and then
    applies the slicing if possible.
    """

    def __init__(self, obj, subscript):
        """
        Store `obj` and translate `subscript` into a 1-based `start` and a
        `length` (None for an open-ended slice).

        Raise ValueError for negative indexes, a step, or a stop smaller
        than the start, and TypeError for any other subscript type.
        """
        super().__init__(obj.name)
        self.obj = obj
        if isinstance(subscript, int):
            if subscript < 0:
                raise ValueError("Negative indexing is not supported.")
            self.start = subscript + 1
            self.length = 1
            return
        if not isinstance(subscript, slice):
            raise TypeError("Argument to slice must be either int or slice instance.")
        lower, upper = subscript.start, subscript.stop
        if (lower is not None and lower < 0) or (upper is not None and upper < 0):
            raise ValueError("Negative indexing is not supported.")
        if subscript.step is not None:
            raise ValueError("Step argument is not supported.")
        if upper and lower and upper < lower:
            raise ValueError("Slice stop must be greater than slice start.")
        self.start = 1 if lower is None else lower + 1
        self.length = None if upper is None else upper - (lower or 0)

    def __repr__(self):
        # Rebuild the 0-based slice from the stored 1-based start and length.
        begin = self.start - 1
        end = begin + self.length if self.length is not None else None
        return f"{self.__class__.__qualname__}({self.obj!r}, {slice(begin, end)!r})"

    def resolve_expression(
        self,
        query=None,
        allow_joins=True,
        reuse=None,
        summarize=False,
        for_save=False,
    ):
        """
        Resolve the referenced name and ask its output field to slice it.

        When the wrapped object is an OuterRef or another instance of this
        class, the expression handed to slice_expression() is that object's
        own resolution instead of the resolved reference.
        """
        resolved = query.resolve_ref(self.name, allow_joins, reuse, summarize)
        target = resolved
        if isinstance(self.obj, (OuterRef, self.__class__)):
            target = self.obj.resolve_expression(
                query, allow_joins, reuse, summarize, for_save
            )
        return resolved.output_field.slice_expression(target, self.start, self.length)
@deconstructible(path="django.db.models.Func")
class Func(SQLiteNumericMixin, Expression):
    """An SQL function call."""

    function = None
    template = "%(function)s(%(expressions)s)"
    arg_joiner = ", "
    arity = None  # The number of arguments the function accepts.

    def __init__(self, *expressions, output_field=None, **extra):
        """
        Store the parsed `expressions` and the `extra` template context.

        Raise TypeError when `arity` is set and the number of positional
        expressions differs from it.
        """
        given = len(expressions)
        if self.arity is not None and given != self.arity:
            noun = "argument" if self.arity == 1 else "arguments"
            raise TypeError(
                f"'{self.__class__.__name__}' takes exactly {self.arity} {noun} "
                f"({given} given)"
            )
        super().__init__(output_field=output_field)
        self.source_expressions = self._parse_expressions(*expressions)
        self.extra = extra

    def __repr__(self):
        name = self.__class__.__name__
        args = self.arg_joiner.join(str(arg) for arg in self.source_expressions)
        options = {**self.extra, **self._get_repr_options()}
        if not options:
            return f"{name}({args})"
        rendered = ", ".join(
            f"{key!s}={val!s}" for key, val in sorted(options.items())
        )
        return f"{name}({args}, {rendered})"

    def _get_repr_options(self):
        """Return a dict of extra __init__() options to include in the repr."""
        return {}

    def get_source_expressions(self):
        return self.source_expressions

    def set_source_expressions(self, exprs):
        self.source_expressions = exprs

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        """Return a copy whose source expressions have each been resolved."""
        clone = self.copy()
        clone.is_summary = summarize
        for index, expression in enumerate(clone.source_expressions):
            clone.source_expressions[index] = expression.resolve_expression(
                query, allow_joins, reuse, summarize, for_save
            )
        return clone

    def as_sql(
        self,
        compiler,
        connection,
        function=None,
        template=None,
        arg_joiner=None,
        **extra_context,
    ):
        """
        Compile each source expression and interpolate them into the template.

        An argument raising EmptyResultSet is replaced by a Value of its
        `empty_result_set_value` when it has one (otherwise the exception
        propagates); an argument raising FullResultSet is replaced by
        Value(True).
        """
        connection.ops.check_expression_support(self)
        compiled_args = []
        params = []
        for expression in self.source_expressions:
            try:
                expression_sql, expression_params = compiler.compile(expression)
            except EmptyResultSet:
                fallback = getattr(
                    expression, "empty_result_set_value", NotImplemented
                )
                if fallback is NotImplemented:
                    raise
                expression_sql, expression_params = compiler.compile(Value(fallback))
            except FullResultSet:
                expression_sql, expression_params = compiler.compile(Value(True))
            compiled_args.append(expression_sql)
            params.extend(expression_params)
        context = {**self.extra, **extra_context}
        # Use the first supplied value in this order: the parameter to this
        # method, a value supplied in __init__()'s **extra (the value in
        # `context`), or the value defined on the class.
        if function is None:
            context.setdefault("function", self.function)
        else:
            context["function"] = function
        template = template or context.get("template", self.template)
        arg_joiner = arg_joiner or context.get("arg_joiner", self.arg_joiner)
        joined_args = arg_joiner.join(compiled_args)
        context["expressions"] = context["field"] = joined_args
        return template % context, params

    def copy(self):
        """Return a copy with its own source expression list and extra dict."""
        clone = super().copy()
        clone.source_expressions = self.source_expressions[:]
        clone.extra = self.extra.copy()
        return clone

    @cached_property
    def allowed_default(self):
        # Allowed only when every source expression is.
        return all(
            source.allowed_default for source in self.source_expressions
        )
@deconstructible(path="django.db.models.Value")
class Value(SQLiteNumericMixin, Expression):
    """Represent a wrapped value as a node within an expression."""

    # Provide a default value for `for_save` in order to allow unresolved
    # instances to be compiled until a decision is taken in #25425.
    for_save = False
    allowed_default = True

    def __init__(self, value, output_field=None):
        """
        Arguments:
         * value: the value this expression represents. The value will be
           added into the sql parameter list and properly quoted.

         * output_field: an instance of the model field type that this
           expression will return, such as IntegerField() or CharField().
        """
        super().__init__(output_field=output_field)
        self.value = value

    def __repr__(self):
        return f"{self.__class__.__name__}({self.value!r})"

    def as_sql(self, compiler, connection):
        """
        Return a placeholder and parameter list for the wrapped value.

        With an output field, the value is first prepared by it (for saving
        when `for_save` is set) and the field's get_placeholder() is used if
        it has one. A None value compiles to a literal NULL with no params.
        """
        connection.ops.check_expression_support(self)
        prepared = self.value
        field = self._output_field_or_none
        if field is not None:
            prepare = (
                field.get_db_prep_save if self.for_save else field.get_db_prep_value
            )
            prepared = prepare(prepared, connection=connection)
            if hasattr(field, "get_placeholder"):
                placeholder = field.get_placeholder(prepared, compiler, connection)
                return placeholder, [prepared]
        if prepared is None:
            # oracledb does not always convert None to the appropriate
            # NULL type (like in case expressions using numbers), so we
            # use a literal SQL NULL
            return "NULL", []
        return "%s", [prepared]

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        """Resolve as usual and record `for_save` on the resolved copy."""
        resolved = super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        resolved.for_save = for_save
        return resolved

    def get_group_by_cols(self):
        return []

    def _resolve_output_field(self):
        """
        Return a field instance matching the Python type of the value, or
        None when the type isn't one of the recognized ones.
        """
        # Order matters: bool is checked before int, and datetime before
        # date, because the first isinstance() match wins.
        type_to_field = (
            (str, fields.CharField),
            (bool, fields.BooleanField),
            (int, fields.IntegerField),
            (float, fields.FloatField),
            (datetime.datetime, fields.DateTimeField),
            (datetime.date, fields.DateField),
            (datetime.time, fields.TimeField),
            (datetime.timedelta, fields.DurationField),
            (Decimal, fields.DecimalField),
            (bytes, fields.BinaryField),
            (UUID, fields.UUIDField),
        )
        for python_type, field_class in type_to_field:
            if isinstance(self.value, python_type):
                return field_class()
        return None

    @property
    def empty_result_set_value(self):
        return self.value
class RawSQL(Expression):
    """An expression that compiles to the given SQL string in parentheses."""

    allowed_default = True

    def __init__(self, sql, params, output_field=None):
        # Fall back to a generic Field when no output field is supplied.
        if output_field is None:
            output_field = fields.Field()
        self.sql = sql
        self.params = params
        super().__init__(output_field=output_field)

    def __repr__(self):
        return f"{self.__class__.__name__}({self.sql}, {self.params})"

    def as_sql(self, compiler, connection):
        return f"({self.sql})", self.params

    def get_group_by_cols(self):
        return [self]

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        """
        Resolve the expression, first resolving a reference to at most one
        local field per parent model whose column name occurs in the SQL
        (compared case-insensitively as a substring).
        """
        # Resolve parents fields used in raw SQL.
        if query.model:
            for parent in query.model._meta.all_parents:
                referenced = next(
                    (
                        candidate
                        for candidate in parent._meta.local_fields
                        if candidate.column.lower() in self.sql.lower()
                    ),
                    None,
                )
                if referenced is not None:
                    query.resolve_ref(referenced.name, allow_joins, reuse, summarize)
        return super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
class Star(Expression):
    """An expression that compiles to a bare `*` with no parameters."""

    def __repr__(self):
        return repr("*")

    def as_sql(self, compiler, connection):
        sql, params = "*", []
        return sql, params
class DatabaseDefault(Expression):
    """
    Wrap an expression so that it compiles to the DEFAULT keyword when saving
    on backends that support it, and to the underlying expression otherwise.
    """

    def __init__(self, expression, output_field=None):
        super().__init__(output_field)
        self.expression = expression

    def get_source_expressions(self):
        return [self.expression]

    def set_source_expressions(self, exprs):
        (self.expression,) = exprs

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        """
        Resolve the wrapped expression; keep the DatabaseDefault wrapper
        around it only when `for_save` is true.
        """
        inner = self.expression.resolve_expression(
            query=query,
            allow_joins=allow_joins,
            reuse=reuse,
            summarize=summarize,
            for_save=for_save,
        )
        if for_save:
            return DatabaseDefault(inner, output_field=self._output_field_or_none)
        # Defaults used outside an INSERT context should resolve to their
        # underlying expression.
        return inner

    def as_sql(self, compiler, connection):
        if connection.features.supports_default_keyword_in_insert:
            return "DEFAULT", []
        return compiler.compile(self.expression)
class Col(Expression):
    """A reference to the column of `target`, optionally qualified by `alias`."""

    contains_column_references = True
    possibly_multivalued = False

    def __init__(self, alias, target, output_field=None):
        # The target doubles as the output field unless one is given.
        super().__init__(
            output_field=target if output_field is None else output_field
        )
        self.alias = alias
        self.target = target

    def __repr__(self):
        parts = [str(self.target)]
        if self.alias:
            parts.insert(0, self.alias)
        return f"{self.__class__.__name__}({', '.join(parts)})"

    def as_sql(self, compiler, connection):
        """Return the quoted column name, prefixed by the alias if there is one."""
        alias, column = self.alias, self.target.column
        names = (alias, column) if alias else (column,)
        quoted = (compiler.quote_name_unless_alias(name) for name in names)
        return ".".join(quoted), []

    def relabeled_clone(self, relabels):
        """Return a new instance using the relabeled alias, if one is set."""
        if self.alias is None:
            return self
        new_alias = relabels.get(self.alias, self.alias)
        return self.__class__(new_alias, self.target, self.output_field)

    def get_group_by_cols(self):
        return [self]

    def get_db_converters(self, connection):
        """
        Return the output field's converters, followed by the target's when
        the target differs from the output field.
        """
        target_is_output = self.target == self.output_field
        converters = self.output_field.get_db_converters(connection)
        if target_is_output:
            return converters
        return converters + self.target.get_db_converters(connection)
class Ref(Expression):
    """
    Reference to column alias of the query. For example, Ref('sum_cost') in
    qs.annotate(sum_cost=Sum('cost')) query.
    """

    def __init__(self, refs, source):
        super().__init__()
        self.refs = refs
        self.source = source

    def __repr__(self):
        return f"{self.__class__.__name__}({self.refs}, {self.source})"

    def get_source_expressions(self):
        return [self.source]

    def set_source_expressions(self, exprs):
        (self.source,) = exprs

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        # The sub-expression `source` has already been resolved, as this is
        # just a reference to the name of `source`.
        return self

    def get_refs(self):
        return {self.refs}

    def relabeled_clone(self, relabels):
        """Return a copy whose source has been relabeled."""
        relabeled = self.copy()
        relabeled.source = self.source.relabeled_clone(relabels)
        return relabeled

    def as_sql(self, compiler, connection):
        # Only the quoted alias name is emitted; there are no parameters.
        quoted_alias = connection.ops.quote_name(self.refs)
        return quoted_alias, []

    def get_group_by_cols(self):
        return [self]
class ExpressionList(Func):
    """
    An expression containing multiple expressions. Can be used to provide a
    list of expressions as an argument to another expression, like a partition
    clause.
    """

    template = "%(expressions)s"

    def __str__(self):
        rendered = (str(expression) for expression in self.source_expressions)
        return self.arg_joiner.join(rendered)

    def as_sql(self, *args, **kwargs):
        # An empty list compiles to an empty string with no parameters.
        if self.source_expressions:
            return super().as_sql(*args, **kwargs)
        return "", ()

    def as_sqlite(self, compiler, connection, **extra_context):
        # Casting to numeric is unnecessary.
        return self.as_sql(compiler, connection, **extra_context)

    def get_group_by_cols(self):
        """Return the concatenated group-by columns of every source expression."""
        return [
            col
            for expression in self.get_source_expressions()
            for col in expression.get_group_by_cols()
        ]
class OrderByList(ExpressionList):
    """An ExpressionList rendered with an ORDER BY prefix."""

    allowed_default = False
    template = "ORDER BY %(expressions)s"

    def __init__(self, *expressions, **extra):
        """
        Convert string arguments starting with "-" into a descending
        OrderBy on the remainder of the string; pass everything else through.
        """

        def normalize(expr):
            if isinstance(expr, str) and expr[0] == "-":
                return OrderBy(F(expr[1:]), descending=True)
            return expr

        super().__init__(*(normalize(expr) for expr in expressions), **extra)
@deconstructible(path="django.db.models.ExpressionWrapper")
class ExpressionWrapper(SQLiteNumericMixin, Expression):
    """
    An expression that can wrap another expression so that it can provide
    extra context to the inner expression, such as the output_field.
    """

    def __init__(self, expression, output_field):
        super().__init__(output_field=output_field)
        self.expression = expression

    def set_source_expressions(self, exprs):
        self.expression = exprs[0]

    def get_source_expressions(self):
        return [self.expression]

    def get_group_by_cols(self):
        """
        Return the group-by columns of a copy of the wrapped expression that
        carries this wrapper's output field, or the default columns when the
        wrapped object isn't an Expression.
        """
        if not isinstance(self.expression, Expression):
            # For non-expressions e.g. an SQL WHERE clause, the entire
            # `expression` must be included in the GROUP BY clause.
            return super().get_group_by_cols()
        inner = self.expression.copy()
        inner.output_field = self.output_field
        return inner.get_group_by_cols()

    def as_sql(self, compiler, connection):
        return compiler.compile(self.expression)

    def __repr__(self):
        return f"{self.__class__.__name__}({self.expression})"

    @property
    def allowed_default(self):
        return self.expression.allowed_default
class NegatedExpression(ExpressionWrapper):
    """The logical negation of a conditional expression."""

    def __init__(self, expression):
        super().__init__(expression, output_field=fields.BooleanField())

    def __invert__(self):
        # Negating a negation yields a copy of the wrapped expression.
        return self.expression.copy()

    def as_sql(self, compiler, connection):
        """
        Return the wrapped expression's SQL negated.

        If the wrapped expression raises EmptyResultSet, the negation is
        always true: return "1=1" on backends without boolean expressions
        in the SELECT clause, and a compiled Value(True) elsewhere.
        """
        try:
            inner_sql, params = super().as_sql(compiler, connection)
        except EmptyResultSet:
            if compiler.connection.features.supports_boolean_expr_in_select_clause:
                return compiler.compile(Value(True))
            return "1=1", ()
        ops = compiler.connection.ops
        # Some database backends (e.g. Oracle) don't allow EXISTS() and filters
        # to be compared to another expression unless they're wrapped in a CASE
        # WHEN.
        if ops.conditional_expression_supported_in_where_clause(self.expression):
            return f"NOT {inner_sql}", params
        return f"CASE WHEN {inner_sql} = 0 THEN 1 ELSE 0 END", params

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        """Resolve, raising TypeError if the wrapped expression isn't conditional."""
        resolved = super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        if getattr(resolved.expression, "conditional", False):
            return resolved
        raise TypeError("Cannot negate non-conditional expressions.")

    def select_format(self, compiler, sql, params):
        # Wrap boolean expressions with a CASE WHEN expression if a database
        # backend (e.g. Oracle) doesn't support boolean expression in SELECT or
        # GROUP BY list.
        supported_in_where = (
            compiler.connection.ops.conditional_expression_supported_in_where_clause
        )
        needs_wrapping = (
            not compiler.connection.features.supports_boolean_expr_in_select_clause
            # Avoid double wrapping.
            and supported_in_where(self.expression)
        )
        if needs_wrapping:
            sql = f"CASE WHEN {sql} THEN 1 ELSE 0 END"
        return sql, params
@deconstructible(path="django.db.models.When")
class When(Expression):
    """A single `WHEN condition THEN result` clause for use inside Case()."""

    template = "WHEN %(condition)s THEN %(result)s"
    # This isn't a complete conditional expression, must be used in Case().
    conditional = False

    def __init__(self, condition=None, then=None, **lookups):
        """
        Build the clause from a condition and/or keyword lookups.

        Lookups are folded into a Q object, combined with `condition` when
        that is conditional. Raise TypeError when the resulting condition is
        missing or not conditional, or when lookups couldn't be folded in,
        and ValueError when the condition is an empty Q().
        """
        if lookups:
            if condition is None:
                condition = Q(**lookups)
                lookups = None
            elif getattr(condition, "conditional", False):
                condition = Q(condition, **lookups)
                lookups = None
        usable = condition is not None and getattr(condition, "conditional", False)
        if not usable or lookups:
            raise TypeError(
                "When() supports a Q object, a boolean expression, or lookups "
                "as a condition."
            )
        if isinstance(condition, Q) and not condition:
            raise ValueError("An empty Q() can't be used as a When() condition.")
        super().__init__(output_field=None)
        self.condition = condition
        self.result = self._parse_expressions(then)[0]

    def __str__(self):
        return f"WHEN {self.condition!r} THEN {self.result!r}"

    def __repr__(self):
        return f"<{self.__class__.__name__!s}: {self!s}>"

    def get_source_expressions(self):
        return [self.condition, self.result]

    def set_source_expressions(self, exprs):
        self.condition, self.result = exprs

    def get_source_fields(self):
        # We're only interested in the fields of the result expressions.
        return [self.result._output_field_or_none]

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        """
        Return a copy with the condition and result resolved. The condition
        is always resolved with for_save=False.
        """
        clone = self.copy()
        clone.is_summary = summarize
        if hasattr(clone.condition, "resolve_expression"):
            clone.condition = clone.condition.resolve_expression(
                query, allow_joins, reuse, summarize, False
            )
        clone.result = clone.result.resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        return clone

    def as_sql(self, compiler, connection, template=None, **extra_context):
        """Compile the condition and the result and fill in the template."""
        connection.ops.check_expression_support(self)
        context = extra_context
        condition_sql, condition_params = compiler.compile(self.condition)
        context["condition"] = condition_sql
        result_sql, result_params = compiler.compile(self.result)
        context["result"] = result_sql
        chosen_template = template or self.template
        # Parameters follow the template order: condition first, then result.
        return chosen_template % context, (*condition_params, *result_params)

    def get_group_by_cols(self):
        # This is not a complete expression and cannot be used in GROUP BY.
        return [
            col
            for source in self.get_source_expressions()
            for col in source.get_group_by_cols()
        ]

    @cached_property
    def allowed_default(self):
        return self.condition.allowed_default and self.result.allowed_default
@deconstructible(path="django.db.models.Case")
class Case(SQLiteNumericMixin, Expression):
    """
    An SQL searched CASE expression:

        CASE
            WHEN n > 0
                THEN 'positive'
            WHEN n < 0
                THEN 'negative'
            ELSE 'zero'
        END
    """

    template = "CASE %(cases)s ELSE %(default)s END"
    case_joiner = " "

    def __init__(self, *cases, default=None, output_field=None, **extra):
        """Store the When clauses, the parsed default, and extra context."""
        for case in cases:
            if not isinstance(case, When):
                raise TypeError("Positional arguments must all be When objects.")
        super().__init__(output_field)
        self.cases = list(cases)
        self.default = self._parse_expressions(default)[0]
        self.extra = extra

    def __str__(self):
        rendered_cases = ", ".join(str(case) for case in self.cases)
        return f"CASE {rendered_cases}, ELSE {self.default!r}"

    def __repr__(self):
        return f"<{self.__class__.__name__!s}: {self!s}>"

    def get_source_expressions(self):
        return self.cases + [self.default]

    def set_source_expressions(self, exprs):
        *self.cases, self.default = exprs

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        """Return a copy with every case and the default resolved."""
        clone = self.copy()
        clone.is_summary = summarize
        for index, case in enumerate(clone.cases):
            clone.cases[index] = case.resolve_expression(
                query, allow_joins, reuse, summarize, for_save
            )
        clone.default = clone.default.resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        return clone

    def copy(self):
        """Return a copy that has its own list of cases."""
        clone = super().copy()
        clone.cases = clone.cases[:]
        return clone

    def as_sql(
        self, compiler, connection, template=None, case_joiner=None, **extra_context
    ):
        """
        Compile the cases and default into a CASE expression.

        Cases raising EmptyResultSet are dropped. A case raising
        FullResultSet makes its result the default and stops processing the
        remaining cases. When no case SQL remains, only the default is
        returned.
        """
        connection.ops.check_expression_support(self)
        if not self.cases:
            return compiler.compile(self.default)
        context = {**self.extra, **extra_context}
        compiled_cases = []
        params = []
        default_sql, default_params = compiler.compile(self.default)
        for case in self.cases:
            try:
                case_sql, case_params = compiler.compile(case)
            except EmptyResultSet:
                continue
            except FullResultSet:
                default_sql, default_params = compiler.compile(case.result)
                break
            else:
                compiled_cases.append(case_sql)
                params.extend(case_params)
        if not compiled_cases:
            return default_sql, default_params
        joiner = case_joiner or self.case_joiner
        context["cases"] = joiner.join(compiled_cases)
        context["default"] = default_sql
        params.extend(default_params)
        chosen_template = template or context.get("template", self.template)
        sql = chosen_template % context
        if self._output_field_or_none is not None:
            sql = connection.ops.unification_cast_sql(self.output_field) % sql
        return sql, params

    def get_group_by_cols(self):
        # Without cases, only the default matters.
        if self.cases:
            return super().get_group_by_cols()
        return self.default.get_group_by_cols()

    @cached_property
    def allowed_default(self):
        return self.default.allowed_default and all(
            case.allowed_default for case in self.cases
        )
class Subquery(BaseExpression, Combinable):
    """
    An explicit subquery. It may contain OuterRef() references to the outer
    query which will be resolved when it is applied to that query.
    """

    template = "(%(subquery)s)"
    contains_aggregate = False
    empty_result_set_value = None
    subquery = True

    def __init__(self, queryset, output_field=None, **extra):
        # Allow the usage of both QuerySet and sql.Query objects.
        source_query = getattr(queryset, "query", queryset)
        self.query = source_query.clone()
        self.query.subquery = True
        self.extra = extra
        super().__init__(output_field)

    def get_source_expressions(self):
        return [self.query]

    def set_source_expressions(self, exprs):
        self.query = exprs[0]

    def _resolve_output_field(self):
        return self.query.output_field

    def copy(self):
        """Return a copy that holds its own clone of the query."""
        duplicate = super().copy()
        duplicate.query = duplicate.query.clone()
        return duplicate

    @property
    def external_aliases(self):
        return self.query.external_aliases

    def get_external_cols(self):
        return self.query.get_external_cols()

    def as_sql(self, compiler, connection, template=None, **extra_context):
        """Compile the query and place it into the template."""
        connection.ops.check_expression_support(self)
        context = {**self.extra, **extra_context}
        query_sql, params = self.query.as_sql(compiler, connection)
        # Drop the first and last character of the query's SQL before it is
        # inserted into the template.
        context["subquery"] = query_sql[1:-1]
        chosen_template = template or context.get("template", self.template)
        return chosen_template % context, params

    def get_group_by_cols(self):
        return self.query.get_group_by_cols(wrapper=self)
class Exists(Subquery):
    """A Subquery rendered as EXISTS(...) with a boolean output field."""

    template = "EXISTS(%(subquery)s)"
    output_field = fields.BooleanField()
    empty_result_set_value = False

    def __init__(self, queryset, **kwargs):
        super().__init__(queryset, **kwargs)
        self.query = self.query.exists()

    def select_format(self, compiler, sql, params):
        # Wrap EXISTS() with a CASE WHEN expression if a database backend
        # (e.g. Oracle) doesn't support boolean expression in SELECT or GROUP
        # BY list.
        features = compiler.connection.features
        if features.supports_boolean_expr_in_select_clause:
            return sql, params
        return f"CASE WHEN {sql} THEN 1 ELSE 0 END", params

    def as_sql(self, compiler, *args, **kwargs):
        """
        Compile as a Subquery. If the query raises EmptyResultSet, return
        "1=0" on backends without boolean expressions in the SELECT clause,
        and a compiled Value(False) elsewhere.
        """
        try:
            return super().as_sql(compiler, *args, **kwargs)
        except EmptyResultSet:
            if compiler.connection.features.supports_boolean_expr_in_select_clause:
                return compiler.compile(Value(False))
            return "1=0", ()
@deconstructible(path="django.db.models.OrderBy")
class OrderBy(Expression):
    """An ordering term: an expression plus direction and NULL placement."""

    template = "%(expression)s %(ordering)s"
    conditional = False
    constraint_validation_compatible = False

    def __init__(self, expression, descending=False, nulls_first=None, nulls_last=None):
        """
        Raise ValueError when both nulls options are truthy, when either is
        False, or when `expression` has no resolve_expression attribute.
        """
        if nulls_first and nulls_last:
            raise ValueError("nulls_first and nulls_last are mutually exclusive")
        if nulls_first is False or nulls_last is False:
            raise ValueError("nulls_first and nulls_last values must be True or None.")
        self.nulls_first = nulls_first
        self.nulls_last = nulls_last
        self.descending = descending
        if not hasattr(expression, "resolve_expression"):
            raise ValueError("expression must be an expression type")
        self.expression = expression

    def __repr__(self):
        return (
            f"{self.__class__.__name__}({self.expression}, "
            f"descending={self.descending})"
        )

    def set_source_expressions(self, exprs):
        self.expression = exprs[0]

    def get_source_expressions(self):
        return [self.expression]

    def as_sql(self, compiler, connection, template=None, **extra_context):
        """
        Compile the ordering term.

        On backends with a NULLS modifier, NULLS FIRST/LAST is appended.
        Otherwise an `IS NULL` / `IS NOT NULL` term is prepended, unless the
        requested placement coincides with the direction and the backend's
        order_by_nulls_first feature as tested below.
        """
        template = template or self.template
        features = connection.features
        if features.supports_order_by_nulls_modifier:
            if self.nulls_last:
                template = f"{template} NULLS LAST"
            elif self.nulls_first:
                template = f"{template} NULLS FIRST"
        else:
            if self.nulls_last and (
                not self.descending or not features.order_by_nulls_first
            ):
                template = f"%(expression)s IS NULL, {template}"
            elif self.nulls_first and (
                self.descending or not features.order_by_nulls_first
            ):
                template = f"%(expression)s IS NOT NULL, {template}"
        connection.ops.check_expression_support(self)
        expression_sql, params = compiler.compile(self.expression)
        direction = "DESC" if self.descending else "ASC"
        placeholders = {
            "expression": expression_sql,
            "ordering": direction,
            **extra_context,
        }
        # The expression's params are needed once per occurrence of the
        # expression placeholder in the template.
        params *= template.count("%(expression)s")
        return (template % placeholders).rstrip(), params

    def as_oracle(self, compiler, connection):
        # Oracle < 23c doesn't allow ORDER BY EXISTS() or filters unless it's
        # wrapped in a CASE WHEN.
        needs_case_wrapper = (
            not connection.features.supports_boolean_expr_in_select_clause
            and connection.ops.conditional_expression_supported_in_where_clause(
                self.expression
            )
        )
        if not needs_case_wrapper:
            return self.as_sql(compiler, connection)
        wrapped = self.copy()
        wrapped.expression = Case(
            When(self.expression, then=True),
            default=False,
        )
        return wrapped.as_sql(compiler, connection)

    def get_group_by_cols(self):
        return [
            col
            for source in self.get_source_expressions()
            for col in source.get_group_by_cols()
        ]

    def reverse_ordering(self):
        """Flip direction and NULL placement in place, then return self."""
        self.descending = not self.descending
        if self.nulls_first:
            self.nulls_first, self.nulls_last = None, True
        elif self.nulls_last:
            self.nulls_first, self.nulls_last = True, None
        return self

    def asc(self):
        """Switch to ascending order in place."""
        self.descending = False

    def desc(self):
        """Switch to descending order in place."""
        self.descending = True
- class Window(SQLiteNumericMixin, Expression):
- template = "%(expression)s OVER (%(window)s)"
- # Although the main expression may either be an aggregate or an
- # expression with an aggregate function, the GROUP BY that will
- # be introduced in the query as a result is not desired.
- contains_aggregate = False
- contains_over_clause = True
- def __init__(
- self,
- expression,
- partition_by=None,
- order_by=None,
- frame=None,
- output_field=None,
- ):
- self.partition_by = partition_by
- self.order_by = order_by
- self.frame = frame
- if not getattr(expression, "window_compatible", False):
- raise ValueError(
- "Expression '%s' isn't compatible with OVER clauses."
- % expression.__class__.__name__
- )
- if self.partition_by is not None:
- if not isinstance(self.partition_by, (tuple, list)):
- self.partition_by = (self.partition_by,)
- self.partition_by = ExpressionList(*self.partition_by)
- if self.order_by is not None:
- if isinstance(self.order_by, (list, tuple)):
- self.order_by = OrderByList(*self.order_by)
- elif isinstance(self.order_by, (BaseExpression, str)):
- self.order_by = OrderByList(self.order_by)
- else:
- raise ValueError(
- "Window.order_by must be either a string reference to a "
- "field, an expression, or a list or tuple of them."
- )
- super().__init__(output_field=output_field)
- self.source_expression = self._parse_expressions(expression)[0]
- def _resolve_output_field(self):
- return self.source_expression.output_field
- def get_source_expressions(self):
- return [self.source_expression, self.partition_by, self.order_by, self.frame]
- def set_source_expressions(self, exprs):
- self.source_expression, self.partition_by, self.order_by, self.frame = exprs
- def as_sql(self, compiler, connection, template=None):
- connection.ops.check_expression_support(self)
- if not connection.features.supports_over_clause:
- raise NotSupportedError("This backend does not support window expressions.")
- expr_sql, params = compiler.compile(self.source_expression)
- window_sql, window_params = [], ()
- if self.partition_by is not None:
- sql_expr, sql_params = self.partition_by.as_sql(
- compiler=compiler,
- connection=connection,
- template="PARTITION BY %(expressions)s",
- )
- window_sql.append(sql_expr)
- window_params += tuple(sql_params)
- if self.order_by is not None:
- order_sql, order_params = compiler.compile(self.order_by)
- window_sql.append(order_sql)
- window_params += tuple(order_params)
- if self.frame:
- frame_sql, frame_params = compiler.compile(self.frame)
- window_sql.append(frame_sql)
- window_params += tuple(frame_params)
- template = template or self.template
- return (
- template % {"expression": expr_sql, "window": " ".join(window_sql).strip()},
- (*params, *window_params),
- )
def as_sqlite(self, compiler, connection):
    """
    SQLite-specific compilation.

    Non-decimal outputs use the regular ``as_sql()``. For a DecimalField
    output, a copy is compiled through the parent class's ``as_sqlite()``
    with the output field of its first source expression replaced by a
    FloatField.
    """
    if not isinstance(self.output_field, fields.DecimalField):
        return self.as_sql(compiler, connection)

    # Casting to numeric must be outside of the window expression.
    clone = self.copy()
    children = clone.get_source_expressions()
    children[0].output_field = fields.FloatField()
    clone.set_source_expressions(children)
    return super(Window, clone).as_sqlite(compiler, connection)
def __str__(self):
    """
    Return a readable ``<source> OVER (<partition><order><frame>)`` form.

    Falsy parts render as empty strings; the parts inside the parentheses
    are concatenated without separators.
    """
    source = str(self.source_expression)
    if self.partition_by:
        partition = "PARTITION BY " + str(self.partition_by)
    else:
        partition = ""
    ordering = str(self.order_by or "")
    frame = str(self.frame or "")
    return "{} OVER ({}{}{})".format(source, partition, ordering, frame)
def __repr__(self):
    """Return ``<ClassName: str(self)>``."""
    class_name = self.__class__.__name__
    return f"<{class_name}: {self!s}>"
def get_group_by_cols(self):
    """
    Collect the GROUP BY columns contributed by the partition and ordering
    clauses, partition columns first.
    """
    columns = []
    partition = self.partition_by
    if partition:
        columns += partition.get_group_by_cols()
    ordering = self.order_by
    # Ordering is checked against None, whereas the partition check above
    # is a truthiness test.
    if ordering is not None:
        columns += ordering.get_group_by_cols()
    return columns
class WindowFrameExclusion(Enum):
    """
    Options for a window frame's exclusion; each member's value is the text
    that WindowFrame places after ``EXCLUDE``.
    """

    CURRENT_ROW = "CURRENT ROW"
    GROUP = "GROUP"
    TIES = "TIES"
    NO_OTHERS = "NO OTHERS"

    def __repr__(self):
        # Render as "<class qualname>.<MEMBER>" instead of Enum's default
        # "<Class.MEMBER: value>" form.
        owner = self.__class__.__qualname__
        return "{}.{}".format(owner, self._name_)
class WindowFrame(Expression):
    """
    Represent the frame clause of a window expression.

    The two kinds of frame clause are implemented as subclasses, but all
    processing and validation (which is by no means meant to be exhaustive)
    lives in this class. Supplying an end for the frame is optional: when it
    is omitted the frame ends at UNBOUNDED FOLLOWING, i.e. the last row in
    the frame.

    Subclasses are expected to provide ``frame_type`` and to implement
    ``window_frame_start_end()``.
    """

    template = "%(frame_type)s BETWEEN %(start)s AND %(end)s%(exclude)s"

    def __init__(self, start=None, end=None, exclusion=None):
        """
        Wrap ``start`` and ``end`` in Value() and store ``exclusion``.

        Raise TypeError unless ``exclusion`` is None or a
        WindowFrameExclusion member.
        """
        self.start = Value(start)
        self.end = Value(end)
        if exclusion is not None and not isinstance(exclusion, WindowFrameExclusion):
            raise TypeError(
                f"{self.__class__.__qualname__}.exclusion must be a "
                "WindowFrameExclusion instance."
            )
        self.exclusion = exclusion

    def set_source_expressions(self, exprs):
        """Assign the start and end expressions from a two-item iterable."""
        lower, upper = exprs
        self.start = lower
        self.end = upper

    def get_source_expressions(self):
        """Return ``[start, end]``; the exclusion is not a source expression."""
        bounds = [self.start]
        bounds.append(self.end)
        return bounds

    def get_exclusion(self):
        """Return the `` EXCLUDE ...`` suffix, or "" when no exclusion is set."""
        return "" if self.exclusion is None else f" EXCLUDE {self.exclusion.value}"

    def as_sql(self, compiler, connection):
        """
        Return the frame clause SQL together with an empty params list.

        The start and end fragments come from ``window_frame_start_end()``.
        Raise NotSupportedError when an exclusion is set but the backend does
        not support frame exclusions.
        """
        connection.ops.check_expression_support(self)
        lower_sql, upper_sql = self.window_frame_start_end(
            connection, self.start.value, self.end.value
        )
        if self.exclusion and not connection.features.supports_frame_exclusion:
            raise NotSupportedError(
                "This backend does not support window frame exclusions."
            )
        context = {
            "frame_type": self.frame_type,
            "start": lower_sql,
            "end": upper_sql,
            "exclude": self.get_exclusion(),
        }
        return self.template % context, []

    def __repr__(self):
        """Return ``<ClassName: str(self)>``."""
        class_name = self.__class__.__name__
        return f"<{class_name}: {self!s}>"

    def get_group_by_cols(self):
        """Return an empty list: the frame adds no GROUP BY columns."""
        return []

    def __str__(self):
        """
        Render the frame clause using the keyword constants on
        ``connection.ops``.

        No connection is passed in, so this relies on the module-level
        ``connection`` name. A negative bound renders as PRECEDING, zero as
        CURRENT ROW and a positive bound as FOLLOWING; any other bound,
        including None, is unbounded.
        """
        ops = connection.ops

        lower = self.start.value
        if lower is None:
            start = ops.UNBOUNDED_PRECEDING
        elif lower < 0:
            start = "%d %s" % (abs(lower), ops.PRECEDING)
        elif lower == 0:
            start = ops.CURRENT_ROW
        elif lower > 0:
            start = "%d %s" % (lower, ops.FOLLOWING)
        else:
            # A value for which no comparison holds is treated as unbounded.
            start = ops.UNBOUNDED_PRECEDING

        upper = self.end.value
        if upper is None:
            end = ops.UNBOUNDED_FOLLOWING
        elif upper > 0:
            end = "%d %s" % (upper, ops.FOLLOWING)
        elif upper == 0:
            end = ops.CURRENT_ROW
        elif upper < 0:
            end = "%d %s" % (abs(upper), ops.PRECEDING)
        else:
            # A value for which no comparison holds is treated as unbounded.
            end = ops.UNBOUNDED_FOLLOWING

        context = {
            "frame_type": self.frame_type,
            "start": start,
            "end": end,
            "exclude": self.get_exclusion(),
        }
        return self.template % context

    def window_frame_start_end(self, connection, start, end):
        """Return the start and end SQL fragments; subclasses must override."""
        raise NotImplementedError("Subclasses must implement window_frame_start_end().")
class RowRange(WindowFrame):
    """Window frame whose clause uses the ``ROWS`` frame type."""

    frame_type = "ROWS"

    def window_frame_start_end(self, connection, start, end):
        """Delegate to the backend's ``window_frame_rows_start_end()``."""
        ops = connection.ops
        return ops.window_frame_rows_start_end(start, end)
class ValueRange(WindowFrame):
    """Window frame whose clause uses the ``RANGE`` frame type."""

    frame_type = "RANGE"

    def window_frame_start_end(self, connection, start, end):
        """Delegate to the backend's ``window_frame_range_start_end()``."""
        ops = connection.ops
        return ops.window_frame_range_start_end(start, end)
|