math.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214
  1. import math
  2. from django.db.models.expressions import Func, Value
  3. from django.db.models.fields import FloatField, IntegerField
  4. from django.db.models.functions import Cast
  5. from django.db.models.functions.mixins import (
  6. FixDecimalInputMixin,
  7. NumericOutputFieldMixin,
  8. )
  9. from django.db.models.lookups import Transform
  10. class Abs(Transform):
  11. function = "ABS"
  12. lookup_name = "abs"
  13. class ACos(NumericOutputFieldMixin, Transform):
  14. function = "ACOS"
  15. lookup_name = "acos"
  16. class ASin(NumericOutputFieldMixin, Transform):
  17. function = "ASIN"
  18. lookup_name = "asin"
  19. class ATan(NumericOutputFieldMixin, Transform):
  20. function = "ATAN"
  21. lookup_name = "atan"
  22. class ATan2(NumericOutputFieldMixin, Func):
  23. function = "ATAN2"
  24. arity = 2
  25. def as_sqlite(self, compiler, connection, **extra_context):
  26. if not getattr(
  27. connection.ops, "spatialite", False
  28. ) or connection.ops.spatial_version >= (5, 0, 0):
  29. return self.as_sql(compiler, connection)
  30. # This function is usually ATan2(y, x), returning the inverse tangent
  31. # of y / x, but it's ATan2(x, y) on SpatiaLite < 5.0.0.
  32. # Cast integers to float to avoid inconsistent/buggy behavior if the
  33. # arguments are mixed between integer and float or decimal.
  34. # https://www.gaia-gis.it/fossil/libspatialite/tktview?name=0f72cca3a2
  35. clone = self.copy()
  36. clone.set_source_expressions(
  37. [
  38. (
  39. Cast(expression, FloatField())
  40. if isinstance(expression.output_field, IntegerField)
  41. else expression
  42. )
  43. for expression in self.get_source_expressions()[::-1]
  44. ]
  45. )
  46. return clone.as_sql(compiler, connection, **extra_context)
  47. class Ceil(Transform):
  48. function = "CEILING"
  49. lookup_name = "ceil"
  50. def as_oracle(self, compiler, connection, **extra_context):
  51. return super().as_sql(compiler, connection, function="CEIL", **extra_context)
  52. class Cos(NumericOutputFieldMixin, Transform):
  53. function = "COS"
  54. lookup_name = "cos"
  55. class Cot(NumericOutputFieldMixin, Transform):
  56. function = "COT"
  57. lookup_name = "cot"
  58. def as_oracle(self, compiler, connection, **extra_context):
  59. return super().as_sql(
  60. compiler, connection, template="(1 / TAN(%(expressions)s))", **extra_context
  61. )
  62. class Degrees(NumericOutputFieldMixin, Transform):
  63. function = "DEGREES"
  64. lookup_name = "degrees"
  65. def as_oracle(self, compiler, connection, **extra_context):
  66. return super().as_sql(
  67. compiler,
  68. connection,
  69. template="((%%(expressions)s) * 180 / %s)" % math.pi,
  70. **extra_context,
  71. )
  72. class Exp(NumericOutputFieldMixin, Transform):
  73. function = "EXP"
  74. lookup_name = "exp"
  75. class Floor(Transform):
  76. function = "FLOOR"
  77. lookup_name = "floor"
  78. class Ln(NumericOutputFieldMixin, Transform):
  79. function = "LN"
  80. lookup_name = "ln"
  81. class Log(FixDecimalInputMixin, NumericOutputFieldMixin, Func):
  82. function = "LOG"
  83. arity = 2
  84. def as_sqlite(self, compiler, connection, **extra_context):
  85. if not getattr(connection.ops, "spatialite", False):
  86. return self.as_sql(compiler, connection)
  87. # This function is usually Log(b, x) returning the logarithm of x to
  88. # the base b, but on SpatiaLite it's Log(x, b).
  89. clone = self.copy()
  90. clone.set_source_expressions(self.get_source_expressions()[::-1])
  91. return clone.as_sql(compiler, connection, **extra_context)
  92. class Mod(FixDecimalInputMixin, NumericOutputFieldMixin, Func):
  93. function = "MOD"
  94. arity = 2
  95. class Pi(NumericOutputFieldMixin, Func):
  96. function = "PI"
  97. arity = 0
  98. def as_oracle(self, compiler, connection, **extra_context):
  99. return super().as_sql(
  100. compiler, connection, template=str(math.pi), **extra_context
  101. )
  102. class Power(NumericOutputFieldMixin, Func):
  103. function = "POWER"
  104. arity = 2
  105. class Radians(NumericOutputFieldMixin, Transform):
  106. function = "RADIANS"
  107. lookup_name = "radians"
  108. def as_oracle(self, compiler, connection, **extra_context):
  109. return super().as_sql(
  110. compiler,
  111. connection,
  112. template="((%%(expressions)s) * %s / 180)" % math.pi,
  113. **extra_context,
  114. )
  115. class Random(NumericOutputFieldMixin, Func):
  116. function = "RANDOM"
  117. arity = 0
  118. def as_mysql(self, compiler, connection, **extra_context):
  119. return super().as_sql(compiler, connection, function="RAND", **extra_context)
  120. def as_oracle(self, compiler, connection, **extra_context):
  121. return super().as_sql(
  122. compiler, connection, function="DBMS_RANDOM.VALUE", **extra_context
  123. )
  124. def as_sqlite(self, compiler, connection, **extra_context):
  125. return super().as_sql(compiler, connection, function="RAND", **extra_context)
  126. def get_group_by_cols(self):
  127. return []
  128. class Round(FixDecimalInputMixin, Transform):
  129. function = "ROUND"
  130. lookup_name = "round"
  131. arity = None # Override Transform's arity=1 to enable passing precision.
  132. def __init__(self, expression, precision=0, **extra):
  133. super().__init__(expression, precision, **extra)
  134. def as_sqlite(self, compiler, connection, **extra_context):
  135. precision = self.get_source_expressions()[1]
  136. if isinstance(precision, Value) and precision.value < 0:
  137. raise ValueError("SQLite does not support negative precision.")
  138. return super().as_sqlite(compiler, connection, **extra_context)
  139. def _resolve_output_field(self):
  140. source = self.get_source_expressions()[0]
  141. return source.output_field
  142. class Sign(Transform):
  143. function = "SIGN"
  144. lookup_name = "sign"
  145. class Sin(NumericOutputFieldMixin, Transform):
  146. function = "SIN"
  147. lookup_name = "sin"
  148. class Sqrt(NumericOutputFieldMixin, Transform):
  149. function = "SQRT"
  150. lookup_name = "sqrt"
  151. class Tan(NumericOutputFieldMixin, Transform):
  152. function = "TAN"
  153. lookup_name = "tan"