functions.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590
  1. from decimal import Decimal
  2. from django.contrib.gis.db.models.fields import BaseSpatialField, GeometryField
  3. from django.contrib.gis.db.models.sql import AreaField, DistanceField
  4. from django.contrib.gis.geos import GEOSGeometry
  5. from django.core.exceptions import FieldError
  6. from django.db import NotSupportedError
  7. from django.db.models import (
  8. BinaryField,
  9. BooleanField,
  10. FloatField,
  11. Func,
  12. IntegerField,
  13. TextField,
  14. Transform,
  15. Value,
  16. )
  17. from django.db.models.functions import Cast
  18. from django.utils.functional import cached_property
# Python types accepted by _handle_param() for numeric function parameters
# (the Oracle tolerance, Scale factors and SnapToGrid arguments in this module).
NUMERIC_TYPES = (int, float, Decimal)
  20. class GeoFuncMixin:
  21. function = None
  22. geom_param_pos = (0,)
  23. def __init__(self, *expressions, **extra):
  24. super().__init__(*expressions, **extra)
  25. # Ensure that value expressions are geometric.
  26. for pos in self.geom_param_pos:
  27. expr = self.source_expressions[pos]
  28. if not isinstance(expr, Value):
  29. continue
  30. try:
  31. output_field = expr.output_field
  32. except FieldError:
  33. output_field = None
  34. geom = expr.value
  35. if (
  36. not isinstance(geom, GEOSGeometry)
  37. or output_field
  38. and not isinstance(output_field, GeometryField)
  39. ):
  40. raise TypeError(
  41. "%s function requires a geometric argument in position %d."
  42. % (self.name, pos + 1)
  43. )
  44. if not geom.srid and not output_field:
  45. raise ValueError("SRID is required for all geometries.")
  46. if not output_field:
  47. self.source_expressions[pos] = Value(
  48. geom, output_field=GeometryField(srid=geom.srid)
  49. )
  50. @property
  51. def name(self):
  52. return self.__class__.__name__
  53. @cached_property
  54. def geo_field(self):
  55. return self.source_expressions[self.geom_param_pos[0]].field
  56. def as_sql(self, compiler, connection, function=None, **extra_context):
  57. if self.function is None and function is None:
  58. function = connection.ops.spatial_function_name(self.name)
  59. return super().as_sql(compiler, connection, function=function, **extra_context)
  60. def resolve_expression(self, *args, **kwargs):
  61. res = super().resolve_expression(*args, **kwargs)
  62. if not self.geom_param_pos:
  63. return res
  64. # Ensure that expressions are geometric.
  65. source_fields = res.get_source_fields()
  66. for pos in self.geom_param_pos:
  67. field = source_fields[pos]
  68. if not isinstance(field, GeometryField):
  69. raise TypeError(
  70. "%s function requires a GeometryField in position %s, got %s."
  71. % (
  72. self.name,
  73. pos + 1,
  74. type(field).__name__,
  75. )
  76. )
  77. base_srid = res.geo_field.srid
  78. for pos in self.geom_param_pos[1:]:
  79. expr = res.source_expressions[pos]
  80. expr_srid = expr.output_field.srid
  81. if expr_srid != base_srid:
  82. # Automatic SRID conversion so objects are comparable.
  83. res.source_expressions[pos] = Transform(
  84. expr, base_srid
  85. ).resolve_expression(*args, **kwargs)
  86. return res
  87. def _handle_param(self, value, param_name="", check_types=None):
  88. if not hasattr(value, "resolve_expression"):
  89. if check_types and not isinstance(value, check_types):
  90. raise TypeError(
  91. "The %s parameter has the wrong type: should be %s."
  92. % (param_name, check_types)
  93. )
  94. return value
class GeoFunc(GeoFuncMixin, Func):
    """Base class for the GIS functions below: GeoFuncMixin applied to Func."""

    pass
  97. class GeomOutputGeoFunc(GeoFunc):
  98. @cached_property
  99. def output_field(self):
  100. return GeometryField(srid=self.geo_field.srid)
  101. class SQLiteDecimalToFloatMixin:
  102. """
  103. By default, Decimal values are converted to str by the SQLite backend, which
  104. is not acceptable by the GIS functions expecting numeric values.
  105. """
  106. def as_sqlite(self, compiler, connection, **extra_context):
  107. copy = self.copy()
  108. copy.set_source_expressions(
  109. [
  110. (
  111. Value(float(expr.value))
  112. if hasattr(expr, "value") and isinstance(expr.value, Decimal)
  113. else expr
  114. )
  115. for expr in copy.get_source_expressions()
  116. ]
  117. )
  118. return copy.as_sql(compiler, connection, **extra_context)
  119. class OracleToleranceMixin:
  120. tolerance = 0.05
  121. def as_oracle(self, compiler, connection, **extra_context):
  122. tolerance = Value(
  123. self._handle_param(
  124. self.extra.get("tolerance", self.tolerance),
  125. "tolerance",
  126. NUMERIC_TYPES,
  127. )
  128. )
  129. clone = self.copy()
  130. clone.set_source_expressions([*self.get_source_expressions(), tolerance])
  131. return clone.as_sql(compiler, connection, **extra_context)
  132. class Area(OracleToleranceMixin, GeoFunc):
  133. arity = 1
  134. @cached_property
  135. def output_field(self):
  136. return AreaField(self.geo_field)
  137. def as_sql(self, compiler, connection, **extra_context):
  138. if not connection.features.supports_area_geodetic and self.geo_field.geodetic(
  139. connection
  140. ):
  141. raise NotSupportedError(
  142. "Area on geodetic coordinate systems not supported."
  143. )
  144. return super().as_sql(compiler, connection, **extra_context)
  145. def as_sqlite(self, compiler, connection, **extra_context):
  146. if self.geo_field.geodetic(connection):
  147. extra_context["template"] = "%(function)s(%(expressions)s, %(spheroid)d)"
  148. extra_context["spheroid"] = True
  149. return self.as_sql(compiler, connection, **extra_context)
class Azimuth(GeoFunc):
    """GIS function of two geometry arguments producing a float."""

    output_field = FloatField()
    arity = 2
    # Both arguments must be geometries.
    geom_param_pos = (0, 1)
  154. class AsGeoJSON(GeoFunc):
  155. output_field = TextField()
  156. def __init__(self, expression, bbox=False, crs=False, precision=8, **extra):
  157. expressions = [expression]
  158. if precision is not None:
  159. expressions.append(self._handle_param(precision, "precision", int))
  160. options = 0
  161. if crs and bbox:
  162. options = 3
  163. elif bbox:
  164. options = 1
  165. elif crs:
  166. options = 2
  167. expressions.append(options)
  168. super().__init__(*expressions, **extra)
  169. def as_oracle(self, compiler, connection, **extra_context):
  170. source_expressions = self.get_source_expressions()
  171. clone = self.copy()
  172. clone.set_source_expressions(source_expressions[:1])
  173. return super(AsGeoJSON, clone).as_sql(compiler, connection, **extra_context)
  174. class AsGML(GeoFunc):
  175. geom_param_pos = (1,)
  176. output_field = TextField()
  177. def __init__(self, expression, version=2, precision=8, **extra):
  178. expressions = [version, expression]
  179. if precision is not None:
  180. expressions.append(self._handle_param(precision, "precision", int))
  181. super().__init__(*expressions, **extra)
  182. def as_oracle(self, compiler, connection, **extra_context):
  183. source_expressions = self.get_source_expressions()
  184. version = source_expressions[0]
  185. clone = self.copy()
  186. clone.set_source_expressions([source_expressions[1]])
  187. extra_context["function"] = (
  188. "SDO_UTIL.TO_GML311GEOMETRY"
  189. if version.value == 3
  190. else "SDO_UTIL.TO_GMLGEOMETRY"
  191. )
  192. return super(AsGML, clone).as_sql(compiler, connection, **extra_context)
  193. class AsKML(GeoFunc):
  194. output_field = TextField()
  195. def __init__(self, expression, precision=8, **extra):
  196. expressions = [expression]
  197. if precision is not None:
  198. expressions.append(self._handle_param(precision, "precision", int))
  199. super().__init__(*expressions, **extra)
  200. class AsSVG(GeoFunc):
  201. output_field = TextField()
  202. def __init__(self, expression, relative=False, precision=8, **extra):
  203. relative = (
  204. relative if hasattr(relative, "resolve_expression") else int(relative)
  205. )
  206. expressions = [
  207. expression,
  208. relative,
  209. self._handle_param(precision, "precision", int),
  210. ]
  211. super().__init__(*expressions, **extra)
class AsWKB(GeoFunc):
    """Single-argument GIS function producing binary data."""

    output_field = BinaryField()
    arity = 1
class AsWKT(GeoFunc):
    """Single-argument GIS function producing text."""

    output_field = TextField()
    arity = 1
  218. class BoundingCircle(OracleToleranceMixin, GeomOutputGeoFunc):
  219. def __init__(self, expression, num_seg=48, **extra):
  220. super().__init__(expression, num_seg, **extra)
  221. def as_oracle(self, compiler, connection, **extra_context):
  222. clone = self.copy()
  223. clone.set_source_expressions([self.get_source_expressions()[0]])
  224. return super(BoundingCircle, clone).as_oracle(
  225. compiler, connection, **extra_context
  226. )
  227. def as_sqlite(self, compiler, connection, **extra_context):
  228. clone = self.copy()
  229. clone.set_source_expressions([self.get_source_expressions()[0]])
  230. return super(BoundingCircle, clone).as_sqlite(
  231. compiler, connection, **extra_context
  232. )
class Centroid(OracleToleranceMixin, GeomOutputGeoFunc):
    """
    Single-argument GIS function returning a geometry; on Oracle a tolerance
    argument is appended by OracleToleranceMixin.
    """

    arity = 1
class ClosestPoint(GeomOutputGeoFunc):
    """GIS function of two geometry arguments returning a geometry."""

    arity = 2
    # Both arguments must be geometries.
    geom_param_pos = (0, 1)
class Difference(OracleToleranceMixin, GeomOutputGeoFunc):
    """
    GIS function of two geometry arguments returning a geometry; on Oracle a
    tolerance argument is appended by OracleToleranceMixin.
    """

    arity = 2
    # Both arguments must be geometries.
    geom_param_pos = (0, 1)
  241. class DistanceResultMixin:
  242. @cached_property
  243. def output_field(self):
  244. return DistanceField(self.geo_field)
  245. def source_is_geography(self):
  246. return self.geo_field.geography and self.geo_field.srid == 4326
  247. class Distance(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
  248. geom_param_pos = (0, 1)
  249. spheroid = None
  250. def __init__(self, expr1, expr2, spheroid=None, **extra):
  251. expressions = [expr1, expr2]
  252. if spheroid is not None:
  253. self.spheroid = self._handle_param(spheroid, "spheroid", bool)
  254. super().__init__(*expressions, **extra)
  255. def as_postgresql(self, compiler, connection, **extra_context):
  256. clone = self.copy()
  257. function = None
  258. expr2 = clone.source_expressions[1]
  259. geography = self.source_is_geography()
  260. if expr2.output_field.geography != geography:
  261. if isinstance(expr2, Value):
  262. expr2.output_field.geography = geography
  263. else:
  264. clone.source_expressions[1] = Cast(
  265. expr2,
  266. GeometryField(srid=expr2.output_field.srid, geography=geography),
  267. )
  268. if not geography and self.geo_field.geodetic(connection):
  269. # Geometry fields with geodetic (lon/lat) coordinates need special
  270. # distance functions.
  271. if self.spheroid:
  272. # DistanceSpheroid is more accurate and resource intensive than
  273. # DistanceSphere.
  274. function = connection.ops.spatial_function_name("DistanceSpheroid")
  275. # Replace boolean param by the real spheroid of the base field
  276. clone.source_expressions.append(
  277. Value(self.geo_field.spheroid(connection))
  278. )
  279. else:
  280. function = connection.ops.spatial_function_name("DistanceSphere")
  281. return super(Distance, clone).as_sql(
  282. compiler, connection, function=function, **extra_context
  283. )
  284. def as_sqlite(self, compiler, connection, **extra_context):
  285. if self.geo_field.geodetic(connection):
  286. # SpatiaLite returns NULL instead of zero on geodetic coordinates
  287. extra_context["template"] = (
  288. "COALESCE(%(function)s(%(expressions)s, %(spheroid)s), 0)"
  289. )
  290. extra_context["spheroid"] = int(bool(self.spheroid))
  291. return super().as_sql(compiler, connection, **extra_context)
class Envelope(GeomOutputGeoFunc):
    """Single-argument GIS function returning a geometry."""

    arity = 1
class ForcePolygonCW(GeomOutputGeoFunc):
    """Single-argument GIS function returning a geometry."""

    arity = 1
  296. class FromWKB(GeoFunc):
  297. arity = 2
  298. geom_param_pos = ()
  299. def __init__(self, expression, srid=0, **extra):
  300. expressions = [
  301. expression,
  302. self._handle_param(srid, "srid", int),
  303. ]
  304. if "output_field" not in extra:
  305. extra["output_field"] = GeometryField(srid=srid)
  306. super().__init__(*expressions, **extra)
  307. def as_oracle(self, compiler, connection, **extra_context):
  308. # Oracle doesn't support the srid parameter.
  309. source_expressions = self.get_source_expressions()
  310. clone = self.copy()
  311. clone.set_source_expressions(source_expressions[:1])
  312. return super(FromWKB, clone).as_sql(compiler, connection, **extra_context)
class FromWKT(FromWKB):
    """Same arguments and handling as FromWKB, under its own function name."""

    pass
  315. class GeoHash(GeoFunc):
  316. output_field = TextField()
  317. def __init__(self, expression, precision=None, **extra):
  318. expressions = [expression]
  319. if precision is not None:
  320. expressions.append(self._handle_param(precision, "precision", int))
  321. super().__init__(*expressions, **extra)
  322. def as_mysql(self, compiler, connection, **extra_context):
  323. clone = self.copy()
  324. # If no precision is provided, set it to the maximum.
  325. if len(clone.source_expressions) < 2:
  326. clone.source_expressions.append(Value(100))
  327. return clone.as_sql(compiler, connection, **extra_context)
class GeometryDistance(GeoFunc):
    """
    Float-valued expression over two geometries that uses an empty function
    name and " <-> " as the argument joiner.
    """

    output_field = FloatField()
    arity = 2
    # Empty (not None), so GeoFuncMixin.as_sql() doesn't look up a backend
    # function name.
    function = ""
    arg_joiner = " <-> "
    # Both arguments must be geometries.
    geom_param_pos = (0, 1)
class Intersection(OracleToleranceMixin, GeomOutputGeoFunc):
    """
    GIS function of two geometry arguments returning a geometry; on Oracle a
    tolerance argument is appended by OracleToleranceMixin.
    """

    arity = 2
    # Both arguments must be geometries.
    geom_param_pos = (0, 1)
# NOTE: ``Transform`` in the bases is the class imported from django.db.models;
# the GIS ``Transform`` function of the same name is defined later in this
# module and doesn't exist yet when this class statement runs.
@BaseSpatialField.register_lookup
class IsEmpty(GeoFuncMixin, Transform):
    """Boolean GIS transform registered on spatial fields as ``isempty``."""

    lookup_name = "isempty"
    output_field = BooleanField()
  341. @BaseSpatialField.register_lookup
  342. class IsValid(OracleToleranceMixin, GeoFuncMixin, Transform):
  343. lookup_name = "isvalid"
  344. output_field = BooleanField()
  345. def as_oracle(self, compiler, connection, **extra_context):
  346. sql, params = super().as_oracle(compiler, connection, **extra_context)
  347. return "CASE %s WHEN 'TRUE' THEN 1 ELSE 0 END" % sql, params
  348. class Length(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
  349. def __init__(self, expr1, spheroid=True, **extra):
  350. self.spheroid = spheroid
  351. super().__init__(expr1, **extra)
  352. def as_sql(self, compiler, connection, **extra_context):
  353. if (
  354. self.geo_field.geodetic(connection)
  355. and not connection.features.supports_length_geodetic
  356. ):
  357. raise NotSupportedError(
  358. "This backend doesn't support Length on geodetic fields"
  359. )
  360. return super().as_sql(compiler, connection, **extra_context)
  361. def as_postgresql(self, compiler, connection, **extra_context):
  362. clone = self.copy()
  363. function = None
  364. if self.source_is_geography():
  365. clone.source_expressions.append(Value(self.spheroid))
  366. elif self.geo_field.geodetic(connection):
  367. # Geometry fields with geodetic (lon/lat) coordinates need length_spheroid
  368. function = connection.ops.spatial_function_name("LengthSpheroid")
  369. clone.source_expressions.append(Value(self.geo_field.spheroid(connection)))
  370. else:
  371. dim = min(f.dim for f in self.get_source_fields() if f)
  372. if dim > 2:
  373. function = connection.ops.length3d
  374. return super(Length, clone).as_sql(
  375. compiler, connection, function=function, **extra_context
  376. )
  377. def as_sqlite(self, compiler, connection, **extra_context):
  378. function = None
  379. if self.geo_field.geodetic(connection):
  380. function = "GeodesicLength" if self.spheroid else "GreatCircleLength"
  381. return super().as_sql(compiler, connection, function=function, **extra_context)
class LineLocatePoint(GeoFunc):
    """GIS function of two geometry arguments producing a float."""

    output_field = FloatField()
    arity = 2
    # Both arguments must be geometries.
    geom_param_pos = (0, 1)
class MakeValid(GeomOutputGeoFunc):
    """GIS function returning a geometry; declares no arity of its own."""

    pass
class MemSize(GeoFunc):
    """Single-argument GIS function producing an integer."""

    output_field = IntegerField()
    arity = 1
class NumGeometries(GeoFunc):
    """Single-argument GIS function producing an integer."""

    output_field = IntegerField()
    arity = 1
class NumPoints(GeoFunc):
    """Single-argument GIS function producing an integer."""

    output_field = IntegerField()
    arity = 1
  397. class Perimeter(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
  398. arity = 1
  399. def as_postgresql(self, compiler, connection, **extra_context):
  400. function = None
  401. if self.geo_field.geodetic(connection) and not self.source_is_geography():
  402. raise NotSupportedError(
  403. "ST_Perimeter cannot use a non-projected non-geography field."
  404. )
  405. dim = min(f.dim for f in self.get_source_fields())
  406. if dim > 2:
  407. function = connection.ops.perimeter3d
  408. return super().as_sql(compiler, connection, function=function, **extra_context)
  409. def as_sqlite(self, compiler, connection, **extra_context):
  410. if self.geo_field.geodetic(connection):
  411. raise NotSupportedError("Perimeter cannot use a non-projected field.")
  412. return super().as_sql(compiler, connection, **extra_context)
class PointOnSurface(OracleToleranceMixin, GeomOutputGeoFunc):
    """
    Single-argument GIS function returning a geometry; on Oracle a tolerance
    argument is appended by OracleToleranceMixin.
    """

    arity = 1
class Reverse(GeoFunc):
    """Single-argument GIS function; no output field is declared here."""

    arity = 1
  417. class Scale(SQLiteDecimalToFloatMixin, GeomOutputGeoFunc):
  418. def __init__(self, expression, x, y, z=0.0, **extra):
  419. expressions = [
  420. expression,
  421. self._handle_param(x, "x", NUMERIC_TYPES),
  422. self._handle_param(y, "y", NUMERIC_TYPES),
  423. ]
  424. if z != 0.0:
  425. expressions.append(self._handle_param(z, "z", NUMERIC_TYPES))
  426. super().__init__(*expressions, **extra)
  427. class SnapToGrid(SQLiteDecimalToFloatMixin, GeomOutputGeoFunc):
  428. def __init__(self, expression, *args, **extra):
  429. nargs = len(args)
  430. expressions = [expression]
  431. if nargs in (1, 2):
  432. expressions.extend(
  433. [self._handle_param(arg, "", NUMERIC_TYPES) for arg in args]
  434. )
  435. elif nargs == 4:
  436. # Reverse origin and size param ordering
  437. expressions += [
  438. *(self._handle_param(arg, "", NUMERIC_TYPES) for arg in args[2:]),
  439. *(self._handle_param(arg, "", NUMERIC_TYPES) for arg in args[0:2]),
  440. ]
  441. else:
  442. raise ValueError("Must provide 1, 2, or 4 arguments to `SnapToGrid`.")
  443. super().__init__(*expressions, **extra)
class SymDifference(OracleToleranceMixin, GeomOutputGeoFunc):
    """
    GIS function of two geometry arguments returning a geometry; on Oracle a
    tolerance argument is appended by OracleToleranceMixin.
    """

    arity = 2
    # Both arguments must be geometries.
    geom_param_pos = (0, 1)
  447. class Transform(GeomOutputGeoFunc):
  448. def __init__(self, expression, srid, **extra):
  449. expressions = [
  450. expression,
  451. self._handle_param(srid, "srid", int),
  452. ]
  453. if "output_field" not in extra:
  454. extra["output_field"] = GeometryField(srid=srid)
  455. super().__init__(*expressions, **extra)
  456. class Translate(Scale):
  457. def as_sqlite(self, compiler, connection, **extra_context):
  458. clone = self.copy()
  459. if len(self.source_expressions) < 4:
  460. # Always provide the z parameter for ST_Translate
  461. clone.source_expressions.append(Value(0))
  462. return super(Translate, clone).as_sqlite(compiler, connection, **extra_context)
class Union(OracleToleranceMixin, GeomOutputGeoFunc):
    """
    GIS function of two geometry arguments returning a geometry; on Oracle a
    tolerance argument is appended by OracleToleranceMixin.
    """

    arity = 2
    # Both arguments must be geometries.
    geom_param_pos = (0, 1)