# aggregates.py (3.1 KB)
  1. from django.contrib.gis.db.models.fields import (
  2. ExtentField,
  3. GeometryCollectionField,
  4. GeometryField,
  5. LineStringField,
  6. )
  7. from django.db.models import Aggregate, Func, Value
  8. from django.utils.functional import cached_property
  9. __all__ = ["Collect", "Extent", "Extent3D", "MakeLine", "Union"]
  10. class GeoAggregate(Aggregate):
  11. function = None
  12. is_extent = False
  13. @cached_property
  14. def output_field(self):
  15. return self.output_field_class(self.source_expressions[0].output_field.srid)
  16. def as_sql(self, compiler, connection, function=None, **extra_context):
  17. # this will be called again in parent, but it's needed now - before
  18. # we get the spatial_aggregate_name
  19. connection.ops.check_expression_support(self)
  20. return super().as_sql(
  21. compiler,
  22. connection,
  23. function=function or connection.ops.spatial_aggregate_name(self.name),
  24. **extra_context,
  25. )
  26. def as_oracle(self, compiler, connection, **extra_context):
  27. if not self.is_extent:
  28. tolerance = self.extra.get("tolerance") or getattr(self, "tolerance", 0.05)
  29. clone = self.copy()
  30. source_expressions = self.get_source_expressions()
  31. source_expressions.pop() # Don't wrap filters with SDOAGGRTYPE().
  32. spatial_type_expr = Func(
  33. *source_expressions,
  34. Value(tolerance),
  35. function="SDOAGGRTYPE",
  36. output_field=self.output_field,
  37. )
  38. source_expressions = [spatial_type_expr, self.filter]
  39. clone.set_source_expressions(source_expressions)
  40. return clone.as_sql(compiler, connection, **extra_context)
  41. return self.as_sql(compiler, connection, **extra_context)
  42. def resolve_expression(
  43. self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
  44. ):
  45. c = super().resolve_expression(query, allow_joins, reuse, summarize, for_save)
  46. for field in c.get_source_fields():
  47. if not hasattr(field, "geom_type"):
  48. raise ValueError(
  49. "Geospatial aggregates only allowed on geometry fields."
  50. )
  51. return c
  52. class Collect(GeoAggregate):
  53. name = "Collect"
  54. output_field_class = GeometryCollectionField
  55. class Extent(GeoAggregate):
  56. name = "Extent"
  57. is_extent = "2D"
  58. def __init__(self, expression, **extra):
  59. super().__init__(expression, output_field=ExtentField(), **extra)
  60. def convert_value(self, value, expression, connection):
  61. return connection.ops.convert_extent(value)
  62. class Extent3D(GeoAggregate):
  63. name = "Extent3D"
  64. is_extent = "3D"
  65. def __init__(self, expression, **extra):
  66. super().__init__(expression, output_field=ExtentField(), **extra)
  67. def convert_value(self, value, expression, connection):
  68. return connection.ops.convert_extent3d(value)
  69. class MakeLine(GeoAggregate):
  70. name = "MakeLine"
  71. output_field_class = LineStringField
  72. class Union(GeoAggregate):
  73. name = "Union"
  74. output_field_class = GeometryField