| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119 |
- from django import forms
- from django.core import exceptions
- from django.db.backends.postgresql.psycopg_any import (
- DateRange,
- DateTimeTZRange,
- NumericRange,
- )
- from django.forms.widgets import HiddenInput, MultiWidget
- from django.utils.translation import gettext_lazy as _
# Public names of this module; this list is what a star-import exposes.
__all__ = [
    "BaseRangeField",
    "IntegerRangeField",
    "DecimalRangeField",
    "DateTimeRangeField",
    "DateRangeField",
    "HiddenRangeWidget",
    "RangeWidget",
]
class RangeWidget(MultiWidget):
    """Multi-widget made of two copies of ``base_widget``, one per bound."""

    def __init__(self, base_widget, attrs=None):
        # The same widget is used for the lower and for the upper bound.
        super().__init__((base_widget,) * 2, attrs)

    def decompress(self, value):
        """Return the ``(lower, upper)`` pair of ``value``.

        A falsy ``value`` yields ``(None, None)``.
        """
        if not value:
            return (None, None)
        return (value.lower, value.upper)
class HiddenRangeWidget(RangeWidget):
    """A range widget whose two inputs are <input type="hidden"> elements."""

    def __init__(self, attrs=None):
        super().__init__(base_widget=HiddenInput, attrs=attrs)
class BaseRangeField(forms.MultiValueField):
    """
    Multi-value form field that combines a lower and an upper bound into a
    range object.

    The code below reads two attributes it does not define itself:
    ``base_field`` (form field class instantiated once per bound) and
    ``range_type`` (class called with the two bounds in ``compress()``).
    """

    default_error_messages = {
        "invalid": _("Enter two valid values."),
        "bound_ordering": _(
            "The start of the range must not exceed the end of the range."
        ),
    }
    hidden_widget = HiddenRangeWidget

    def __init__(self, **kwargs):
        """
        Fill in default ``widget``, ``fields``, ``required`` and
        ``require_all_fields`` arguments when the caller did not pass them.

        The optional ``default_bounds`` keyword is consumed here and kept in
        ``self.range_kwargs`` for later use by ``compress()``.
        """
        if "widget" not in kwargs:
            kwargs["widget"] = RangeWidget(self.base_field.widget)
        if "fields" not in kwargs:
            # One optional sub-field per bound.
            kwargs["fields"] = [
                self.base_field(required=False) for _bound in ("lower", "upper")
            ]
        for option in ("required", "require_all_fields"):
            kwargs.setdefault(option, False)
        # ``default_bounds`` is removed from kwargs before calling the parent
        # initializer; a truthy value is handed to the range constructor as
        # ``bounds``.
        bounds = kwargs.pop("default_bounds", None)
        self.range_kwargs = {"bounds": bounds} if bounds else {}
        super().__init__(**kwargs)

    def prepare_value(self, value):
        """
        Return a two-item list with each bound prepared by its sub-field.

        ``range_type`` instances and ``None`` are converted; any other value
        is returned unchanged.
        """
        lower_field, upper_field = self.fields
        if isinstance(value, self.range_type):
            lower, upper = value.lower, value.upper
        elif value is None:
            lower = upper = None
        else:
            return value
        return [
            lower_field.prepare_value(lower),
            upper_field.prepare_value(upper),
        ]

    def compress(self, values):
        """
        Build a ``range_type`` instance from the cleaned ``[lower, upper]``.

        Return ``None`` for empty ``values``. Raise ``ValidationError`` with
        code ``bound_ordering`` when both bounds are set and the lower one is
        greater than the upper one, and with code ``invalid`` when the range
        constructor raises ``TypeError``.
        """
        if not values:
            return None
        start, end = values
        both_set = not (start is None or end is None)
        if both_set and start > end:
            raise exceptions.ValidationError(
                self.error_messages["bound_ordering"],
                code="bound_ordering",
            )
        try:
            return self.range_type(start, end, **self.range_kwargs)
        except TypeError:
            raise exceptions.ValidationError(
                self.error_messages["invalid"],
                code="invalid",
            )
class IntegerRangeField(BaseRangeField):
    """Range field with ``forms.IntegerField`` bounds and ``NumericRange``."""

    range_type = NumericRange
    base_field = forms.IntegerField
    default_error_messages = {"invalid": _("Enter two whole numbers.")}
class DecimalRangeField(BaseRangeField):
    """Range field with ``forms.DecimalField`` bounds and ``NumericRange``."""

    range_type = NumericRange
    base_field = forms.DecimalField
    default_error_messages = {"invalid": _("Enter two numbers.")}
class DateTimeRangeField(BaseRangeField):
    """Range field with ``forms.DateTimeField`` bounds and ``DateTimeTZRange``."""

    range_type = DateTimeTZRange
    base_field = forms.DateTimeField
    default_error_messages = {"invalid": _("Enter two valid date/times.")}
class DateRangeField(BaseRangeField):
    """Range field with ``forms.DateField`` bounds and ``DateRange``."""

    range_type = DateRange
    base_field = forms.DateField
    default_error_messages = {"invalid": _("Enter two valid dates.")}
|