ranges.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. from django import forms
  2. from django.core import exceptions
  3. from django.db.backends.postgresql.psycopg_any import (
  4. DateRange,
  5. DateTimeTZRange,
  6. NumericRange,
  7. )
  8. from django.forms.widgets import HiddenInput, MultiWidget
  9. from django.utils.translation import gettext_lazy as _
  10. __all__ = [
  11. "BaseRangeField",
  12. "IntegerRangeField",
  13. "DecimalRangeField",
  14. "DateTimeRangeField",
  15. "DateRangeField",
  16. "HiddenRangeWidget",
  17. "RangeWidget",
  18. ]
  19. class RangeWidget(MultiWidget):
  20. def __init__(self, base_widget, attrs=None):
  21. widgets = (base_widget, base_widget)
  22. super().__init__(widgets, attrs)
  23. def decompress(self, value):
  24. if value:
  25. return (value.lower, value.upper)
  26. return (None, None)
  27. class HiddenRangeWidget(RangeWidget):
  28. """A widget that splits input into two <input type="hidden"> inputs."""
  29. def __init__(self, attrs=None):
  30. super().__init__(HiddenInput, attrs)
  31. class BaseRangeField(forms.MultiValueField):
  32. default_error_messages = {
  33. "invalid": _("Enter two valid values."),
  34. "bound_ordering": _(
  35. "The start of the range must not exceed the end of the range."
  36. ),
  37. }
  38. hidden_widget = HiddenRangeWidget
  39. def __init__(self, **kwargs):
  40. if "widget" not in kwargs:
  41. kwargs["widget"] = RangeWidget(self.base_field.widget)
  42. if "fields" not in kwargs:
  43. kwargs["fields"] = [
  44. self.base_field(required=False),
  45. self.base_field(required=False),
  46. ]
  47. kwargs.setdefault("required", False)
  48. kwargs.setdefault("require_all_fields", False)
  49. self.range_kwargs = {}
  50. if default_bounds := kwargs.pop("default_bounds", None):
  51. self.range_kwargs = {"bounds": default_bounds}
  52. super().__init__(**kwargs)
  53. def prepare_value(self, value):
  54. lower_base, upper_base = self.fields
  55. if isinstance(value, self.range_type):
  56. return [
  57. lower_base.prepare_value(value.lower),
  58. upper_base.prepare_value(value.upper),
  59. ]
  60. if value is None:
  61. return [
  62. lower_base.prepare_value(None),
  63. upper_base.prepare_value(None),
  64. ]
  65. return value
  66. def compress(self, values):
  67. if not values:
  68. return None
  69. lower, upper = values
  70. if lower is not None and upper is not None and lower > upper:
  71. raise exceptions.ValidationError(
  72. self.error_messages["bound_ordering"],
  73. code="bound_ordering",
  74. )
  75. try:
  76. range_value = self.range_type(lower, upper, **self.range_kwargs)
  77. except TypeError:
  78. raise exceptions.ValidationError(
  79. self.error_messages["invalid"],
  80. code="invalid",
  81. )
  82. else:
  83. return range_value
  84. class IntegerRangeField(BaseRangeField):
  85. default_error_messages = {"invalid": _("Enter two whole numbers.")}
  86. base_field = forms.IntegerField
  87. range_type = NumericRange
  88. class DecimalRangeField(BaseRangeField):
  89. default_error_messages = {"invalid": _("Enter two numbers.")}
  90. base_field = forms.DecimalField
  91. range_type = NumericRange
  92. class DateTimeRangeField(BaseRangeField):
  93. default_error_messages = {"invalid": _("Enter two valid date/times.")}
  94. base_field = forms.DateTimeField
  95. range_type = DateTimeTZRange
  96. class DateRangeField(BaseRangeField):
  97. default_error_messages = {"invalid": _("Enter two valid dates.")}
  98. base_field = forms.DateField
  99. range_type = DateRange