_functions.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515
  1. """
  2. Implementations of SQL functions for SQLite.
  3. """
  4. import functools
  5. import random
  6. import statistics
  7. import zoneinfo
  8. from datetime import timedelta
  9. from hashlib import md5, sha1, sha224, sha256, sha384, sha512
  10. from math import (
  11. acos,
  12. asin,
  13. atan,
  14. atan2,
  15. ceil,
  16. cos,
  17. degrees,
  18. exp,
  19. floor,
  20. fmod,
  21. log,
  22. pi,
  23. radians,
  24. sin,
  25. sqrt,
  26. tan,
  27. )
  28. from re import search as re_search
  29. from django.db.backends.utils import (
  30. split_tzname_delta,
  31. typecast_time,
  32. typecast_timestamp,
  33. )
  34. from django.utils import timezone
  35. from django.utils.duration import duration_microseconds
  36. def register(connection):
  37. create_deterministic_function = functools.partial(
  38. connection.create_function,
  39. deterministic=True,
  40. )
  41. create_deterministic_function("django_date_extract", 2, _sqlite_datetime_extract)
  42. create_deterministic_function("django_date_trunc", 4, _sqlite_date_trunc)
  43. create_deterministic_function(
  44. "django_datetime_cast_date", 3, _sqlite_datetime_cast_date
  45. )
  46. create_deterministic_function(
  47. "django_datetime_cast_time", 3, _sqlite_datetime_cast_time
  48. )
  49. create_deterministic_function(
  50. "django_datetime_extract", 4, _sqlite_datetime_extract
  51. )
  52. create_deterministic_function("django_datetime_trunc", 4, _sqlite_datetime_trunc)
  53. create_deterministic_function("django_time_extract", 2, _sqlite_time_extract)
  54. create_deterministic_function("django_time_trunc", 4, _sqlite_time_trunc)
  55. create_deterministic_function("django_time_diff", 2, _sqlite_time_diff)
  56. create_deterministic_function("django_timestamp_diff", 2, _sqlite_timestamp_diff)
  57. create_deterministic_function("django_format_dtdelta", 3, _sqlite_format_dtdelta)
  58. create_deterministic_function("regexp", 2, _sqlite_regexp)
  59. create_deterministic_function("BITXOR", 2, _sqlite_bitxor)
  60. create_deterministic_function("COT", 1, _sqlite_cot)
  61. create_deterministic_function("LPAD", 3, _sqlite_lpad)
  62. create_deterministic_function("MD5", 1, _sqlite_md5)
  63. create_deterministic_function("REPEAT", 2, _sqlite_repeat)
  64. create_deterministic_function("REVERSE", 1, _sqlite_reverse)
  65. create_deterministic_function("RPAD", 3, _sqlite_rpad)
  66. create_deterministic_function("SHA1", 1, _sqlite_sha1)
  67. create_deterministic_function("SHA224", 1, _sqlite_sha224)
  68. create_deterministic_function("SHA256", 1, _sqlite_sha256)
  69. create_deterministic_function("SHA384", 1, _sqlite_sha384)
  70. create_deterministic_function("SHA512", 1, _sqlite_sha512)
  71. create_deterministic_function("SIGN", 1, _sqlite_sign)
  72. # Don't use the built-in RANDOM() function because it returns a value
  73. # in the range [-1 * 2^63, 2^63 - 1] instead of [0, 1).
  74. connection.create_function("RAND", 0, random.random)
  75. connection.create_aggregate("STDDEV_POP", 1, StdDevPop)
  76. connection.create_aggregate("STDDEV_SAMP", 1, StdDevSamp)
  77. connection.create_aggregate("VAR_POP", 1, VarPop)
  78. connection.create_aggregate("VAR_SAMP", 1, VarSamp)
  79. # Some math functions are enabled by default in SQLite 3.35+.
  80. sql = "select sqlite_compileoption_used('ENABLE_MATH_FUNCTIONS')"
  81. if not connection.execute(sql).fetchone()[0]:
  82. create_deterministic_function("ACOS", 1, _sqlite_acos)
  83. create_deterministic_function("ASIN", 1, _sqlite_asin)
  84. create_deterministic_function("ATAN", 1, _sqlite_atan)
  85. create_deterministic_function("ATAN2", 2, _sqlite_atan2)
  86. create_deterministic_function("CEILING", 1, _sqlite_ceiling)
  87. create_deterministic_function("COS", 1, _sqlite_cos)
  88. create_deterministic_function("DEGREES", 1, _sqlite_degrees)
  89. create_deterministic_function("EXP", 1, _sqlite_exp)
  90. create_deterministic_function("FLOOR", 1, _sqlite_floor)
  91. create_deterministic_function("LN", 1, _sqlite_ln)
  92. create_deterministic_function("LOG", 2, _sqlite_log)
  93. create_deterministic_function("MOD", 2, _sqlite_mod)
  94. create_deterministic_function("PI", 0, _sqlite_pi)
  95. create_deterministic_function("POWER", 2, _sqlite_power)
  96. create_deterministic_function("RADIANS", 1, _sqlite_radians)
  97. create_deterministic_function("SIN", 1, _sqlite_sin)
  98. create_deterministic_function("SQRT", 1, _sqlite_sqrt)
  99. create_deterministic_function("TAN", 1, _sqlite_tan)
def _sqlite_datetime_parse(dt, tzname=None, conn_tzname=None):
    """
    Parse a timestamp value coming from SQLite into a datetime.

    Return None if *dt* is None or typecast_timestamp() rejects it with
    TypeError/ValueError.

    If *conn_tzname* is set, the parsed value is tagged with that zone via
    ``replace(tzinfo=...)`` (no conversion). If *tzname* is given and differs
    from *conn_tzname*, any offset that split_tzname_delta() separates from
    *tzname* is added to the value (sign "+") or subtracted from it. The
    result is then converted with timezone.localtime() to the zone named by
    the remaining *tzname*, or to *conn_tzname* when that remainder is empty.
    """
    if dt is None:
        return None
    try:
        dt = typecast_timestamp(dt)
    except (TypeError, ValueError):
        return None
    if conn_tzname:
        # Treat the stored wall-clock value as local time in the
        # connection's zone; replace() attaches tzinfo without shifting.
        dt = dt.replace(tzinfo=zoneinfo.ZoneInfo(conn_tzname))
    if tzname is not None and tzname != conn_tzname:
        tzname, sign, offset = split_tzname_delta(tzname)
        if offset:
            # NOTE(review): assumes offset is an "H:M"-style string, as the
            # split(":") below requires; confirm against split_tzname_delta().
            hours, minutes = offset.split(":")
            offset_delta = timedelta(hours=int(hours), minutes=int(minutes))
            dt += offset_delta if sign == "+" else -offset_delta
        # The tzname may originally be just the offset e.g. "+3:00",
        # which becomes an empty string after splitting the sign and offset.
        # In this case, use the conn_tzname as fallback.
        # NOTE(review): when conn_tzname is falsy, dt is still naive here and
        # ZoneInfo() may receive an empty/None name; confirm callers never
        # reach this path without a connection time zone.
        dt = timezone.localtime(dt, zoneinfo.ZoneInfo(tzname or conn_tzname))
    return dt
  120. def _sqlite_date_trunc(lookup_type, dt, tzname, conn_tzname):
  121. dt = _sqlite_datetime_parse(dt, tzname, conn_tzname)
  122. if dt is None:
  123. return None
  124. if lookup_type == "year":
  125. return f"{dt.year:04d}-01-01"
  126. elif lookup_type == "quarter":
  127. month_in_quarter = dt.month - (dt.month - 1) % 3
  128. return f"{dt.year:04d}-{month_in_quarter:02d}-01"
  129. elif lookup_type == "month":
  130. return f"{dt.year:04d}-{dt.month:02d}-01"
  131. elif lookup_type == "week":
  132. dt -= timedelta(days=dt.weekday())
  133. return f"{dt.year:04d}-{dt.month:02d}-{dt.day:02d}"
  134. elif lookup_type == "day":
  135. return f"{dt.year:04d}-{dt.month:02d}-{dt.day:02d}"
  136. raise ValueError(f"Unsupported lookup type: {lookup_type!r}")
  137. def _sqlite_time_trunc(lookup_type, dt, tzname, conn_tzname):
  138. if dt is None:
  139. return None
  140. dt_parsed = _sqlite_datetime_parse(dt, tzname, conn_tzname)
  141. if dt_parsed is None:
  142. try:
  143. dt = typecast_time(dt)
  144. except (ValueError, TypeError):
  145. return None
  146. else:
  147. dt = dt_parsed
  148. if lookup_type == "hour":
  149. return f"{dt.hour:02d}:00:00"
  150. elif lookup_type == "minute":
  151. return f"{dt.hour:02d}:{dt.minute:02d}:00"
  152. elif lookup_type == "second":
  153. return f"{dt.hour:02d}:{dt.minute:02d}:{dt.second:02d}"
  154. raise ValueError(f"Unsupported lookup type: {lookup_type!r}")
  155. def _sqlite_datetime_cast_date(dt, tzname, conn_tzname):
  156. dt = _sqlite_datetime_parse(dt, tzname, conn_tzname)
  157. if dt is None:
  158. return None
  159. return dt.date().isoformat()
  160. def _sqlite_datetime_cast_time(dt, tzname, conn_tzname):
  161. dt = _sqlite_datetime_parse(dt, tzname, conn_tzname)
  162. if dt is None:
  163. return None
  164. return dt.time().isoformat()
  165. def _sqlite_datetime_extract(lookup_type, dt, tzname=None, conn_tzname=None):
  166. dt = _sqlite_datetime_parse(dt, tzname, conn_tzname)
  167. if dt is None:
  168. return None
  169. if lookup_type == "week_day":
  170. return (dt.isoweekday() % 7) + 1
  171. elif lookup_type == "iso_week_day":
  172. return dt.isoweekday()
  173. elif lookup_type == "week":
  174. return dt.isocalendar().week
  175. elif lookup_type == "quarter":
  176. return ceil(dt.month / 3)
  177. elif lookup_type == "iso_year":
  178. return dt.isocalendar().year
  179. else:
  180. return getattr(dt, lookup_type)
  181. def _sqlite_datetime_trunc(lookup_type, dt, tzname, conn_tzname):
  182. dt = _sqlite_datetime_parse(dt, tzname, conn_tzname)
  183. if dt is None:
  184. return None
  185. if lookup_type == "year":
  186. return f"{dt.year:04d}-01-01 00:00:00"
  187. elif lookup_type == "quarter":
  188. month_in_quarter = dt.month - (dt.month - 1) % 3
  189. return f"{dt.year:04d}-{month_in_quarter:02d}-01 00:00:00"
  190. elif lookup_type == "month":
  191. return f"{dt.year:04d}-{dt.month:02d}-01 00:00:00"
  192. elif lookup_type == "week":
  193. dt -= timedelta(days=dt.weekday())
  194. return f"{dt.year:04d}-{dt.month:02d}-{dt.day:02d} 00:00:00"
  195. elif lookup_type == "day":
  196. return f"{dt.year:04d}-{dt.month:02d}-{dt.day:02d} 00:00:00"
  197. elif lookup_type == "hour":
  198. return f"{dt.year:04d}-{dt.month:02d}-{dt.day:02d} {dt.hour:02d}:00:00"
  199. elif lookup_type == "minute":
  200. return (
  201. f"{dt.year:04d}-{dt.month:02d}-{dt.day:02d} "
  202. f"{dt.hour:02d}:{dt.minute:02d}:00"
  203. )
  204. elif lookup_type == "second":
  205. return (
  206. f"{dt.year:04d}-{dt.month:02d}-{dt.day:02d} "
  207. f"{dt.hour:02d}:{dt.minute:02d}:{dt.second:02d}"
  208. )
  209. raise ValueError(f"Unsupported lookup type: {lookup_type!r}")
  210. def _sqlite_time_extract(lookup_type, dt):
  211. if dt is None:
  212. return None
  213. try:
  214. dt = typecast_time(dt)
  215. except (ValueError, TypeError):
  216. return None
  217. return getattr(dt, lookup_type)
  218. def _sqlite_prepare_dtdelta_param(conn, param):
  219. if conn in ["+", "-"]:
  220. if isinstance(param, int):
  221. return timedelta(0, 0, param)
  222. else:
  223. return typecast_timestamp(param)
  224. return param
  225. def _sqlite_format_dtdelta(connector, lhs, rhs):
  226. """
  227. LHS and RHS can be either:
  228. - An integer number of microseconds
  229. - A string representing a datetime
  230. - A scalar value, e.g. float
  231. """
  232. if connector is None or lhs is None or rhs is None:
  233. return None
  234. connector = connector.strip()
  235. try:
  236. real_lhs = _sqlite_prepare_dtdelta_param(connector, lhs)
  237. real_rhs = _sqlite_prepare_dtdelta_param(connector, rhs)
  238. except (ValueError, TypeError):
  239. return None
  240. if connector == "+":
  241. # typecast_timestamp() returns a date or a datetime without timezone.
  242. # It will be formatted as "%Y-%m-%d" or "%Y-%m-%d %H:%M:%S[.%f]"
  243. out = str(real_lhs + real_rhs)
  244. elif connector == "-":
  245. out = str(real_lhs - real_rhs)
  246. elif connector == "*":
  247. out = real_lhs * real_rhs
  248. else:
  249. out = real_lhs / real_rhs
  250. return out
  251. def _sqlite_time_diff(lhs, rhs):
  252. if lhs is None or rhs is None:
  253. return None
  254. left = typecast_time(lhs)
  255. right = typecast_time(rhs)
  256. return (
  257. (left.hour * 60 * 60 * 1000000)
  258. + (left.minute * 60 * 1000000)
  259. + (left.second * 1000000)
  260. + (left.microsecond)
  261. - (right.hour * 60 * 60 * 1000000)
  262. - (right.minute * 60 * 1000000)
  263. - (right.second * 1000000)
  264. - (right.microsecond)
  265. )
  266. def _sqlite_timestamp_diff(lhs, rhs):
  267. if lhs is None or rhs is None:
  268. return None
  269. left = typecast_timestamp(lhs)
  270. right = typecast_timestamp(rhs)
  271. return duration_microseconds(left - right)
  272. def _sqlite_regexp(pattern, string):
  273. if pattern is None or string is None:
  274. return None
  275. if not isinstance(string, str):
  276. string = str(string)
  277. return bool(re_search(pattern, string))
  278. def _sqlite_acos(x):
  279. if x is None:
  280. return None
  281. return acos(x)
  282. def _sqlite_asin(x):
  283. if x is None:
  284. return None
  285. return asin(x)
  286. def _sqlite_atan(x):
  287. if x is None:
  288. return None
  289. return atan(x)
  290. def _sqlite_atan2(y, x):
  291. if y is None or x is None:
  292. return None
  293. return atan2(y, x)
  294. def _sqlite_bitxor(x, y):
  295. if x is None or y is None:
  296. return None
  297. return x ^ y
  298. def _sqlite_ceiling(x):
  299. if x is None:
  300. return None
  301. return ceil(x)
  302. def _sqlite_cos(x):
  303. if x is None:
  304. return None
  305. return cos(x)
  306. def _sqlite_cot(x):
  307. if x is None:
  308. return None
  309. return 1 / tan(x)
  310. def _sqlite_degrees(x):
  311. if x is None:
  312. return None
  313. return degrees(x)
  314. def _sqlite_exp(x):
  315. if x is None:
  316. return None
  317. return exp(x)
  318. def _sqlite_floor(x):
  319. if x is None:
  320. return None
  321. return floor(x)
  322. def _sqlite_ln(x):
  323. if x is None:
  324. return None
  325. return log(x)
  326. def _sqlite_log(base, x):
  327. if base is None or x is None:
  328. return None
  329. # Arguments reversed to match SQL standard.
  330. return log(x, base)
  331. def _sqlite_lpad(text, length, fill_text):
  332. if text is None or length is None or fill_text is None:
  333. return None
  334. delta = length - len(text)
  335. if delta <= 0:
  336. return text[:length]
  337. return (fill_text * length)[:delta] + text
  338. def _sqlite_md5(text):
  339. if text is None:
  340. return None
  341. return md5(text.encode()).hexdigest()
  342. def _sqlite_mod(x, y):
  343. if x is None or y is None:
  344. return None
  345. return fmod(x, y)
  346. def _sqlite_pi():
  347. return pi
  348. def _sqlite_power(x, y):
  349. if x is None or y is None:
  350. return None
  351. return x**y
  352. def _sqlite_radians(x):
  353. if x is None:
  354. return None
  355. return radians(x)
  356. def _sqlite_repeat(text, count):
  357. if text is None or count is None:
  358. return None
  359. return text * count
  360. def _sqlite_reverse(text):
  361. if text is None:
  362. return None
  363. return text[::-1]
  364. def _sqlite_rpad(text, length, fill_text):
  365. if text is None or length is None or fill_text is None:
  366. return None
  367. return (text + fill_text * length)[:length]
  368. def _sqlite_sha1(text):
  369. if text is None:
  370. return None
  371. return sha1(text.encode()).hexdigest()
  372. def _sqlite_sha224(text):
  373. if text is None:
  374. return None
  375. return sha224(text.encode()).hexdigest()
  376. def _sqlite_sha256(text):
  377. if text is None:
  378. return None
  379. return sha256(text.encode()).hexdigest()
  380. def _sqlite_sha384(text):
  381. if text is None:
  382. return None
  383. return sha384(text.encode()).hexdigest()
  384. def _sqlite_sha512(text):
  385. if text is None:
  386. return None
  387. return sha512(text.encode()).hexdigest()
  388. def _sqlite_sign(x):
  389. if x is None:
  390. return None
  391. return (x > 0) - (x < 0)
  392. def _sqlite_sin(x):
  393. if x is None:
  394. return None
  395. return sin(x)
  396. def _sqlite_sqrt(x):
  397. if x is None:
  398. return None
  399. return sqrt(x)
  400. def _sqlite_tan(x):
  401. if x is None:
  402. return None
  403. return tan(x)
  404. class ListAggregate(list):
  405. step = list.append
  406. class StdDevPop(ListAggregate):
  407. finalize = statistics.pstdev
  408. class StdDevSamp(ListAggregate):
  409. finalize = statistics.stdev
  410. class VarPop(ListAggregate):
  411. finalize = statistics.pvariance
  412. class VarSamp(ListAggregate):
  413. finalize = statistics.variance