lookups.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395
  1. from django.contrib.gis.db.models.fields import BaseSpatialField
  2. from django.contrib.gis.measure import Distance
  3. from django.db import NotSupportedError
  4. from django.db.models import Expression, Lookup, Transform
  5. from django.db.models.sql.query import Query
  6. from django.utils.regex_helper import _lazy_re_compile
  7. class RasterBandTransform(Transform):
  8. def as_sql(self, compiler, connection):
  9. return compiler.compile(self.lhs)
  10. class GISLookup(Lookup):
  11. sql_template = None
  12. transform_func = None
  13. distance = False
  14. band_rhs = None
  15. band_lhs = None
  16. def __init__(self, lhs, rhs):
  17. rhs, *self.rhs_params = rhs if isinstance(rhs, (list, tuple)) else [rhs]
  18. super().__init__(lhs, rhs)
  19. self.template_params = {}
  20. self.process_rhs_params()
  21. def process_rhs_params(self):
  22. if self.rhs_params:
  23. # Check if a band index was passed in the query argument.
  24. if len(self.rhs_params) == (2 if self.lookup_name == "relate" else 1):
  25. self.process_band_indices()
  26. elif len(self.rhs_params) > 1:
  27. raise ValueError("Tuple too long for lookup %s." % self.lookup_name)
  28. elif isinstance(self.lhs, RasterBandTransform):
  29. self.process_band_indices(only_lhs=True)
  30. def process_band_indices(self, only_lhs=False):
  31. """
  32. Extract the lhs band index from the band transform class and the rhs
  33. band index from the input tuple.
  34. """
  35. # PostGIS band indices are 1-based, so the band index needs to be
  36. # increased to be consistent with the GDALRaster band indices.
  37. if only_lhs:
  38. self.band_rhs = 1
  39. self.band_lhs = self.lhs.band_index + 1
  40. return
  41. if isinstance(self.lhs, RasterBandTransform):
  42. self.band_lhs = self.lhs.band_index + 1
  43. else:
  44. self.band_lhs = 1
  45. self.band_rhs, *self.rhs_params = self.rhs_params
  46. def get_db_prep_lookup(self, value, connection):
  47. # get_db_prep_lookup is called by process_rhs from super class
  48. return ("%s", [connection.ops.Adapter(value)])
  49. def process_rhs(self, compiler, connection):
  50. if isinstance(self.rhs, Query):
  51. # If rhs is some Query, don't touch it.
  52. return super().process_rhs(compiler, connection)
  53. if isinstance(self.rhs, Expression):
  54. self.rhs = self.rhs.resolve_expression(compiler.query)
  55. rhs, rhs_params = super().process_rhs(compiler, connection)
  56. placeholder = connection.ops.get_geom_placeholder(
  57. self.lhs.output_field, self.rhs, compiler
  58. )
  59. return placeholder % rhs, rhs_params
  60. def get_rhs_op(self, connection, rhs):
  61. # Unlike BuiltinLookup, the GIS get_rhs_op() implementation should return
  62. # an object (SpatialOperator) with an as_sql() method to allow for more
  63. # complex computations (where the lhs part can be mixed in).
  64. return connection.ops.gis_operators[self.lookup_name]
  65. def as_sql(self, compiler, connection):
  66. lhs_sql, lhs_params = self.process_lhs(compiler, connection)
  67. rhs_sql, rhs_params = self.process_rhs(compiler, connection)
  68. sql_params = (*lhs_params, *rhs_params)
  69. template_params = {
  70. "lhs": lhs_sql,
  71. "rhs": rhs_sql,
  72. "value": "%s",
  73. **self.template_params,
  74. }
  75. rhs_op = self.get_rhs_op(connection, rhs_sql)
  76. return rhs_op.as_sql(connection, self, template_params, sql_params)
  77. # ------------------
  78. # Geometry operators
  79. # ------------------
  80. @BaseSpatialField.register_lookup
  81. class OverlapsLeftLookup(GISLookup):
  82. """
  83. The overlaps_left operator returns true if A's bounding box overlaps or is to the
  84. left of B's bounding box.
  85. """
  86. lookup_name = "overlaps_left"
  87. @BaseSpatialField.register_lookup
  88. class OverlapsRightLookup(GISLookup):
  89. """
  90. The 'overlaps_right' operator returns true if A's bounding box overlaps or is to the
  91. right of B's bounding box.
  92. """
  93. lookup_name = "overlaps_right"
  94. @BaseSpatialField.register_lookup
  95. class OverlapsBelowLookup(GISLookup):
  96. """
  97. The 'overlaps_below' operator returns true if A's bounding box overlaps or is below
  98. B's bounding box.
  99. """
  100. lookup_name = "overlaps_below"
  101. @BaseSpatialField.register_lookup
  102. class OverlapsAboveLookup(GISLookup):
  103. """
  104. The 'overlaps_above' operator returns true if A's bounding box overlaps or is above
  105. B's bounding box.
  106. """
  107. lookup_name = "overlaps_above"
  108. @BaseSpatialField.register_lookup
  109. class LeftLookup(GISLookup):
  110. """
  111. The 'left' operator returns true if A's bounding box is strictly to the left
  112. of B's bounding box.
  113. """
  114. lookup_name = "left"
  115. @BaseSpatialField.register_lookup
  116. class RightLookup(GISLookup):
  117. """
  118. The 'right' operator returns true if A's bounding box is strictly to the right
  119. of B's bounding box.
  120. """
  121. lookup_name = "right"
  122. @BaseSpatialField.register_lookup
  123. class StrictlyBelowLookup(GISLookup):
  124. """
  125. The 'strictly_below' operator returns true if A's bounding box is strictly below B's
  126. bounding box.
  127. """
  128. lookup_name = "strictly_below"
  129. @BaseSpatialField.register_lookup
  130. class StrictlyAboveLookup(GISLookup):
  131. """
  132. The 'strictly_above' operator returns true if A's bounding box is strictly above B's
  133. bounding box.
  134. """
  135. lookup_name = "strictly_above"
  136. @BaseSpatialField.register_lookup
  137. class SameAsLookup(GISLookup):
  138. """
  139. The "~=" operator is the "same as" operator. It tests actual geometric
  140. equality of two features. So if A and B are the same feature,
  141. vertex-by-vertex, the operator returns true.
  142. """
  143. lookup_name = "same_as"
  144. BaseSpatialField.register_lookup(SameAsLookup, "exact")
  145. @BaseSpatialField.register_lookup
  146. class BBContainsLookup(GISLookup):
  147. """
  148. The 'bbcontains' operator returns true if A's bounding box completely contains
  149. by B's bounding box.
  150. """
  151. lookup_name = "bbcontains"
  152. @BaseSpatialField.register_lookup
  153. class BBOverlapsLookup(GISLookup):
  154. """
  155. The 'bboverlaps' operator returns true if A's bounding box overlaps B's
  156. bounding box.
  157. """
  158. lookup_name = "bboverlaps"
  159. @BaseSpatialField.register_lookup
  160. class ContainedLookup(GISLookup):
  161. """
  162. The 'contained' operator returns true if A's bounding box is completely contained
  163. by B's bounding box.
  164. """
  165. lookup_name = "contained"
  166. # ------------------
  167. # Geometry functions
  168. # ------------------
  169. @BaseSpatialField.register_lookup
  170. class ContainsLookup(GISLookup):
  171. lookup_name = "contains"
  172. @BaseSpatialField.register_lookup
  173. class ContainsProperlyLookup(GISLookup):
  174. lookup_name = "contains_properly"
  175. @BaseSpatialField.register_lookup
  176. class CoveredByLookup(GISLookup):
  177. lookup_name = "coveredby"
  178. @BaseSpatialField.register_lookup
  179. class CoversLookup(GISLookup):
  180. lookup_name = "covers"
  181. @BaseSpatialField.register_lookup
  182. class CrossesLookup(GISLookup):
  183. lookup_name = "crosses"
  184. @BaseSpatialField.register_lookup
  185. class DisjointLookup(GISLookup):
  186. lookup_name = "disjoint"
  187. @BaseSpatialField.register_lookup
  188. class EqualsLookup(GISLookup):
  189. lookup_name = "equals"
  190. @BaseSpatialField.register_lookup
  191. class IntersectsLookup(GISLookup):
  192. lookup_name = "intersects"
  193. @BaseSpatialField.register_lookup
  194. class OverlapsLookup(GISLookup):
  195. lookup_name = "overlaps"
  196. @BaseSpatialField.register_lookup
  197. class RelateLookup(GISLookup):
  198. lookup_name = "relate"
  199. sql_template = "%(func)s(%(lhs)s, %(rhs)s, %%s)"
  200. pattern_regex = _lazy_re_compile(r"^[012TF*]{9}$")
  201. def process_rhs(self, compiler, connection):
  202. # Check the pattern argument
  203. pattern = self.rhs_params[0]
  204. backend_op = connection.ops.gis_operators[self.lookup_name]
  205. if hasattr(backend_op, "check_relate_argument"):
  206. backend_op.check_relate_argument(pattern)
  207. elif not isinstance(pattern, str) or not self.pattern_regex.match(pattern):
  208. raise ValueError('Invalid intersection matrix pattern "%s".' % pattern)
  209. sql, params = super().process_rhs(compiler, connection)
  210. return sql, params + [pattern]
  211. @BaseSpatialField.register_lookup
  212. class TouchesLookup(GISLookup):
  213. lookup_name = "touches"
  214. @BaseSpatialField.register_lookup
  215. class WithinLookup(GISLookup):
  216. lookup_name = "within"
  217. class DistanceLookupBase(GISLookup):
  218. distance = True
  219. sql_template = "%(func)s(%(lhs)s, %(rhs)s) %(op)s %(value)s"
  220. def process_rhs_params(self):
  221. if not 1 <= len(self.rhs_params) <= 3:
  222. raise ValueError(
  223. "2, 3, or 4-element tuple required for '%s' lookup." % self.lookup_name
  224. )
  225. elif len(self.rhs_params) == 3 and self.rhs_params[2] != "spheroid":
  226. raise ValueError(
  227. "For 4-element tuples the last argument must be the 'spheroid' "
  228. "directive."
  229. )
  230. # Check if the second parameter is a band index.
  231. if len(self.rhs_params) > 1 and self.rhs_params[1] != "spheroid":
  232. self.process_band_indices()
  233. def process_distance(self, compiler, connection):
  234. dist_param = self.rhs_params[0]
  235. return (
  236. compiler.compile(dist_param.resolve_expression(compiler.query))
  237. if hasattr(dist_param, "resolve_expression")
  238. else (
  239. "%s",
  240. connection.ops.get_distance(
  241. self.lhs.output_field, self.rhs_params, self.lookup_name
  242. ),
  243. )
  244. )
  245. @BaseSpatialField.register_lookup
  246. class DWithinLookup(DistanceLookupBase):
  247. lookup_name = "dwithin"
  248. sql_template = "%(func)s(%(lhs)s, %(rhs)s, %(value)s)"
  249. def process_distance(self, compiler, connection):
  250. dist_param = self.rhs_params[0]
  251. if (
  252. not connection.features.supports_dwithin_distance_expr
  253. and hasattr(dist_param, "resolve_expression")
  254. and not isinstance(dist_param, Distance)
  255. ):
  256. raise NotSupportedError(
  257. "This backend does not support expressions for specifying "
  258. "distance in the dwithin lookup."
  259. )
  260. return super().process_distance(compiler, connection)
  261. def process_rhs(self, compiler, connection):
  262. dist_sql, dist_params = self.process_distance(compiler, connection)
  263. self.template_params["value"] = dist_sql
  264. rhs_sql, params = super().process_rhs(compiler, connection)
  265. return rhs_sql, params + dist_params
  266. class DistanceLookupFromFunction(DistanceLookupBase):
  267. def as_sql(self, compiler, connection):
  268. spheroid = (
  269. len(self.rhs_params) == 2 and self.rhs_params[-1] == "spheroid"
  270. ) or None
  271. distance_expr = connection.ops.distance_expr_for_lookup(
  272. self.lhs, self.rhs, spheroid=spheroid
  273. )
  274. sql, params = compiler.compile(distance_expr.resolve_expression(compiler.query))
  275. dist_sql, dist_params = self.process_distance(compiler, connection)
  276. return (
  277. "%(func)s %(op)s %(dist)s" % {"func": sql, "op": self.op, "dist": dist_sql},
  278. params + dist_params,
  279. )
  280. @BaseSpatialField.register_lookup
  281. class DistanceGTLookup(DistanceLookupFromFunction):
  282. lookup_name = "distance_gt"
  283. op = ">"
  284. @BaseSpatialField.register_lookup
  285. class DistanceGTELookup(DistanceLookupFromFunction):
  286. lookup_name = "distance_gte"
  287. op = ">="
  288. @BaseSpatialField.register_lookup
  289. class DistanceLTLookup(DistanceLookupFromFunction):
  290. lookup_name = "distance_lt"
  291. op = "<"
  292. @BaseSpatialField.register_lookup
  293. class DistanceLTELookup(DistanceLookupFromFunction):
  294. lookup_name = "distance_lte"
  295. op = "<="