text.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376
  1. from django.db import NotSupportedError
  2. from django.db.models.expressions import Func, Value
  3. from django.db.models.fields import CharField, IntegerField, TextField
  4. from django.db.models.functions import Cast, Coalesce
  5. from django.db.models.lookups import Transform
  6. class MySQLSHA2Mixin:
  7. def as_mysql(self, compiler, connection, **extra_context):
  8. return super().as_sql(
  9. compiler,
  10. connection,
  11. template="SHA2(%%(expressions)s, %s)" % self.function[3:],
  12. **extra_context,
  13. )
  14. class OracleHashMixin:
  15. def as_oracle(self, compiler, connection, **extra_context):
  16. return super().as_sql(
  17. compiler,
  18. connection,
  19. template=(
  20. "LOWER(RAWTOHEX(STANDARD_HASH(UTL_I18N.STRING_TO_RAW("
  21. "%(expressions)s, 'AL32UTF8'), '%(function)s')))"
  22. ),
  23. **extra_context,
  24. )
  25. class PostgreSQLSHAMixin:
  26. def as_postgresql(self, compiler, connection, **extra_context):
  27. return super().as_sql(
  28. compiler,
  29. connection,
  30. template="ENCODE(DIGEST(%(expressions)s, '%(function)s'), 'hex')",
  31. function=self.function.lower(),
  32. **extra_context,
  33. )
  34. class Chr(Transform):
  35. function = "CHR"
  36. lookup_name = "chr"
  37. output_field = CharField()
  38. def as_mysql(self, compiler, connection, **extra_context):
  39. return super().as_sql(
  40. compiler,
  41. connection,
  42. function="CHAR",
  43. template="%(function)s(%(expressions)s USING utf16)",
  44. **extra_context,
  45. )
  46. def as_oracle(self, compiler, connection, **extra_context):
  47. return super().as_sql(
  48. compiler,
  49. connection,
  50. template="%(function)s(%(expressions)s USING NCHAR_CS)",
  51. **extra_context,
  52. )
  53. def as_sqlite(self, compiler, connection, **extra_context):
  54. return super().as_sql(compiler, connection, function="CHAR", **extra_context)
  55. class ConcatPair(Func):
  56. """
  57. Concatenate two arguments together. This is used by `Concat` because not
  58. all backend databases support more than two arguments.
  59. """
  60. function = "CONCAT"
  61. def pipes_concat_sql(self, compiler, connection, **extra_context):
  62. coalesced = self.coalesce()
  63. return super(ConcatPair, coalesced).as_sql(
  64. compiler,
  65. connection,
  66. template="%(expressions)s",
  67. arg_joiner=" || ",
  68. **extra_context,
  69. )
  70. as_sqlite = pipes_concat_sql
  71. def as_postgresql(self, compiler, connection, **extra_context):
  72. c = self.copy()
  73. c.set_source_expressions(
  74. [
  75. (
  76. expression
  77. if isinstance(expression.output_field, (CharField, TextField))
  78. else Cast(expression, TextField())
  79. )
  80. for expression in c.get_source_expressions()
  81. ]
  82. )
  83. return c.pipes_concat_sql(compiler, connection, **extra_context)
  84. def as_mysql(self, compiler, connection, **extra_context):
  85. # Use CONCAT_WS with an empty separator so that NULLs are ignored.
  86. return super().as_sql(
  87. compiler,
  88. connection,
  89. function="CONCAT_WS",
  90. template="%(function)s('', %(expressions)s)",
  91. **extra_context,
  92. )
  93. def coalesce(self):
  94. # null on either side results in null for expression, wrap with coalesce
  95. c = self.copy()
  96. c.set_source_expressions(
  97. [
  98. Coalesce(expression, Value(""))
  99. for expression in c.get_source_expressions()
  100. ]
  101. )
  102. return c
  103. class Concat(Func):
  104. """
  105. Concatenate text fields together. Backends that result in an entire
  106. null expression when any arguments are null will wrap each argument in
  107. coalesce functions to ensure a non-null result.
  108. """
  109. function = None
  110. template = "%(expressions)s"
  111. def __init__(self, *expressions, **extra):
  112. if len(expressions) < 2:
  113. raise ValueError("Concat must take at least two expressions")
  114. paired = self._paired(expressions, output_field=extra.get("output_field"))
  115. super().__init__(paired, **extra)
  116. def _paired(self, expressions, output_field):
  117. # wrap pairs of expressions in successive concat functions
  118. # exp = [a, b, c, d]
  119. # -> ConcatPair(a, ConcatPair(b, ConcatPair(c, d))))
  120. if len(expressions) == 2:
  121. return ConcatPair(*expressions, output_field=output_field)
  122. return ConcatPair(
  123. expressions[0],
  124. self._paired(expressions[1:], output_field=output_field),
  125. output_field=output_field,
  126. )
  127. class Left(Func):
  128. function = "LEFT"
  129. arity = 2
  130. output_field = CharField()
  131. def __init__(self, expression, length, **extra):
  132. """
  133. expression: the name of a field, or an expression returning a string
  134. length: the number of characters to return from the start of the string
  135. """
  136. if not hasattr(length, "resolve_expression"):
  137. if length < 1:
  138. raise ValueError("'length' must be greater than 0.")
  139. super().__init__(expression, length, **extra)
  140. def get_substr(self):
  141. return Substr(self.source_expressions[0], Value(1), self.source_expressions[1])
  142. def as_oracle(self, compiler, connection, **extra_context):
  143. return self.get_substr().as_oracle(compiler, connection, **extra_context)
  144. def as_sqlite(self, compiler, connection, **extra_context):
  145. return self.get_substr().as_sqlite(compiler, connection, **extra_context)
  146. class Length(Transform):
  147. """Return the number of characters in the expression."""
  148. function = "LENGTH"
  149. lookup_name = "length"
  150. output_field = IntegerField()
  151. def as_mysql(self, compiler, connection, **extra_context):
  152. return super().as_sql(
  153. compiler, connection, function="CHAR_LENGTH", **extra_context
  154. )
  155. class Lower(Transform):
  156. function = "LOWER"
  157. lookup_name = "lower"
  158. class LPad(Func):
  159. function = "LPAD"
  160. output_field = CharField()
  161. def __init__(self, expression, length, fill_text=Value(" "), **extra):
  162. if (
  163. not hasattr(length, "resolve_expression")
  164. and length is not None
  165. and length < 0
  166. ):
  167. raise ValueError("'length' must be greater or equal to 0.")
  168. super().__init__(expression, length, fill_text, **extra)
  169. class LTrim(Transform):
  170. function = "LTRIM"
  171. lookup_name = "ltrim"
  172. class MD5(OracleHashMixin, Transform):
  173. function = "MD5"
  174. lookup_name = "md5"
  175. class Ord(Transform):
  176. function = "ASCII"
  177. lookup_name = "ord"
  178. output_field = IntegerField()
  179. def as_mysql(self, compiler, connection, **extra_context):
  180. return super().as_sql(compiler, connection, function="ORD", **extra_context)
  181. def as_sqlite(self, compiler, connection, **extra_context):
  182. return super().as_sql(compiler, connection, function="UNICODE", **extra_context)
  183. class Repeat(Func):
  184. function = "REPEAT"
  185. output_field = CharField()
  186. def __init__(self, expression, number, **extra):
  187. if (
  188. not hasattr(number, "resolve_expression")
  189. and number is not None
  190. and number < 0
  191. ):
  192. raise ValueError("'number' must be greater or equal to 0.")
  193. super().__init__(expression, number, **extra)
  194. def as_oracle(self, compiler, connection, **extra_context):
  195. expression, number = self.source_expressions
  196. length = None if number is None else Length(expression) * number
  197. rpad = RPad(expression, length, expression)
  198. return rpad.as_sql(compiler, connection, **extra_context)
  199. class Replace(Func):
  200. function = "REPLACE"
  201. def __init__(self, expression, text, replacement=Value(""), **extra):
  202. super().__init__(expression, text, replacement, **extra)
  203. class Reverse(Transform):
  204. function = "REVERSE"
  205. lookup_name = "reverse"
  206. def as_oracle(self, compiler, connection, **extra_context):
  207. # REVERSE in Oracle is undocumented and doesn't support multi-byte
  208. # strings. Use a special subquery instead.
  209. suffix = connection.features.bare_select_suffix
  210. sql, params = super().as_sql(
  211. compiler,
  212. connection,
  213. template=(
  214. "(SELECT LISTAGG(s) WITHIN GROUP (ORDER BY n DESC) FROM "
  215. f"(SELECT LEVEL n, SUBSTR(%(expressions)s, LEVEL, 1) s{suffix} "
  216. "CONNECT BY LEVEL <= LENGTH(%(expressions)s)) "
  217. "GROUP BY %(expressions)s)"
  218. ),
  219. **extra_context,
  220. )
  221. return sql, params * 3
  222. class Right(Left):
  223. function = "RIGHT"
  224. def get_substr(self):
  225. return Substr(
  226. self.source_expressions[0],
  227. self.source_expressions[1] * Value(-1),
  228. self.source_expressions[1],
  229. )
  230. class RPad(LPad):
  231. function = "RPAD"
  232. class RTrim(Transform):
  233. function = "RTRIM"
  234. lookup_name = "rtrim"
  235. class SHA1(OracleHashMixin, PostgreSQLSHAMixin, Transform):
  236. function = "SHA1"
  237. lookup_name = "sha1"
  238. class SHA224(MySQLSHA2Mixin, PostgreSQLSHAMixin, Transform):
  239. function = "SHA224"
  240. lookup_name = "sha224"
  241. def as_oracle(self, compiler, connection, **extra_context):
  242. raise NotSupportedError("SHA224 is not supported on Oracle.")
  243. class SHA256(MySQLSHA2Mixin, OracleHashMixin, PostgreSQLSHAMixin, Transform):
  244. function = "SHA256"
  245. lookup_name = "sha256"
  246. class SHA384(MySQLSHA2Mixin, OracleHashMixin, PostgreSQLSHAMixin, Transform):
  247. function = "SHA384"
  248. lookup_name = "sha384"
  249. class SHA512(MySQLSHA2Mixin, OracleHashMixin, PostgreSQLSHAMixin, Transform):
  250. function = "SHA512"
  251. lookup_name = "sha512"
  252. class StrIndex(Func):
  253. """
  254. Return a positive integer corresponding to the 1-indexed position of the
  255. first occurrence of a substring inside another string, or 0 if the
  256. substring is not found.
  257. """
  258. function = "INSTR"
  259. arity = 2
  260. output_field = IntegerField()
  261. def as_postgresql(self, compiler, connection, **extra_context):
  262. return super().as_sql(compiler, connection, function="STRPOS", **extra_context)
  263. class Substr(Func):
  264. function = "SUBSTRING"
  265. output_field = CharField()
  266. def __init__(self, expression, pos, length=None, **extra):
  267. """
  268. expression: the name of a field, or an expression returning a string
  269. pos: an integer > 0, or an expression returning an integer
  270. length: an optional number of characters to return
  271. """
  272. if not hasattr(pos, "resolve_expression"):
  273. if pos < 1:
  274. raise ValueError("'pos' must be greater than 0")
  275. expressions = [expression, pos]
  276. if length is not None:
  277. expressions.append(length)
  278. super().__init__(*expressions, **extra)
  279. def as_sqlite(self, compiler, connection, **extra_context):
  280. return super().as_sql(compiler, connection, function="SUBSTR", **extra_context)
  281. def as_oracle(self, compiler, connection, **extra_context):
  282. return super().as_sql(compiler, connection, function="SUBSTR", **extra_context)
  283. class Trim(Transform):
  284. function = "TRIM"
  285. lookup_name = "trim"
  286. class Upper(Transform):
  287. function = "UPPER"
  288. lookup_name = "upper"