psycopg_any.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114
  1. import ipaddress
  2. from functools import lru_cache
  3. try:
  4. from psycopg import ClientCursor, IsolationLevel, adapt, adapters, errors, sql
  5. from psycopg.postgres import types
  6. from psycopg.types.datetime import TimestamptzLoader
  7. from psycopg.types.json import Jsonb
  8. from psycopg.types.range import Range, RangeDumper
  9. from psycopg.types.string import TextLoader
  # In this (psycopg 3) branch Inet is the stdlib address factory.
  Inet = ipaddress.ip_address

  # A single Range class stands in for every range flavour here.
  DateRange = Range
  DateTimeRange = Range
  DateTimeTZRange = Range
  NumericRange = Range
  RANGE_TYPES = (Range,)

  # OIDs of the timestamp range types, read from psycopg's types registry.
  TSRANGE_OID = types["tsrange"].oid
  TSTZRANGE_OID = types["tstzrange"].oid
  def mogrify(sql, params, connection):
      """
      Return what ClientCursor.mogrify() produces for `sql` and `params`.

      A cursor is opened on `connection` only to reach the connection object
      it exposes, on which the ClientCursor is then built.
      """
      with connection.cursor() as cursor:
          client_cursor = ClientCursor(cursor.connection)
          return client_cursor.mogrify(sql, params)
  18. # Adapters.
  class BaseTzLoader(TimestamptzLoader):
      """
      Load a PostgreSQL timestamptz using a specific timezone.

      The timezone can also be None, in which case the tzinfo of the loaded
      value is chopped.
      """

      # Subclasses override this with the timezone to stamp on loaded values.
      timezone = None

      def load(self, data):
          # Let the stock loader parse the value, then swap its tzinfo.
          return super().load(data).replace(tzinfo=self.timezone)
  def register_tzloader(tz, context):
      """
      Register a timestamptz loader bound to `tz` on `context.adapters`.
      """

      class SpecificTzLoader(BaseTzLoader):
          timezone = tz

      registry = context.adapters
      registry.register_loader("timestamptz", SpecificTzLoader)
  class DjangoRangeDumper(RangeDumper):
      """A Range dumper customized for Django."""

      def upgrade(self, obj, format):
          """
          Return the upgraded dumper, retargeting tsrange to tstzrange.

          Ranges containing naive datetimes are dumped as tstzrange, because
          Django doesn't use tz-aware ones.
          """
          upgraded = super().upgrade(obj, format)
          # Leave this dumper itself, and any non-tsrange dumper, untouched.
          if upgraded is self or upgraded.oid != TSRANGE_OID:
              return upgraded
          upgraded.oid = TSTZRANGE_OID
          return upgraded
  @lru_cache
  def get_adapters_template(use_tz, timezone):
      """
      Return an adapters map extending the base one, cached per argument pair.

      `use_tz` is not read in the body; it only takes part in the cache key.
      """
      template = adapt.AdaptersMap(adapters)

      # Load these types as plain text:
      # - jsonb: a no-op loader avoids a round trip from psycopg version 3
      #   decode to json.dumps() to json.loads(), when using a custom decoder
      #   in JSONField.
      # - inet / cidr: don't convert automatically from PostgreSQL network
      #   types to Python ipaddress.
      for type_name in ("jsonb", "inet", "cidr"):
          template.register_loader(type_name, TextLoader)

      template.register_dumper(Range, DjangoRangeDumper)

      # Register a timestamptz loader configured on the given timezone.
      # This, however, can be overridden by create_cursor.
      register_tzloader(timezone, template)
      return template
  # Reached only when every psycopg (version 3) import above succeeded.
  is_psycopg3 = True
  59. except ImportError:
  60. from enum import IntEnum
  61. from psycopg2 import errors, extensions, sql # NOQA
  62. from psycopg2.extras import ( # NOQA
  63. DateRange,
  64. DateTimeRange,
  65. DateTimeTZRange,
  66. Inet,
  67. Json,
  68. NumericRange,
  69. Range,
  70. )
  # The concrete psycopg2 range classes, grouped in one tuple.
  RANGE_TYPES = (
      DateRange,
      DateTimeRange,
      DateTimeTZRange,
      NumericRange,
  )
  class IsolationLevel(IntEnum):
      """Integer enum wrapping psycopg2's ISOLATION_LEVEL_* constants."""

      READ_UNCOMMITTED = extensions.ISOLATION_LEVEL_READ_UNCOMMITTED
      READ_COMMITTED = extensions.ISOLATION_LEVEL_READ_COMMITTED
      REPEATABLE_READ = extensions.ISOLATION_LEVEL_REPEATABLE_READ
      SERIALIZABLE = extensions.ISOLATION_LEVEL_SERIALIZABLE
  def _quote(value, connection=None):
      """
      Return `value` adapted by psycopg2 and quoted, as a str.

      `connection` is accepted but not used. Adapters exposing an `encoding`
      attribute have it forced to "utf8" before quoting.
      """
      wrapper = extensions.adapt(value)
      if hasattr(wrapper, "encoding"):
          wrapper.encoding = "utf8"
      # getquoted() returns a quoted bytestring of the adapted value.
      quoted_bytes = wrapper.getquoted()
      return quoted_bytes.decode()

  # Make the helper reachable as sql.quote.
  sql.quote = _quote
  def mogrify(sql, params, connection):
      """
      Return the decoded result of cursor.mogrify() for `sql` and `params`.

      The cursor is obtained from `connection` and released when the
      `with` block exits.
      """
      with connection.cursor() as cursor:
          raw_query = cursor.mogrify(sql, params)
          return raw_query.decode()
  # The psycopg2 fallback branch is active.
  is_psycopg3 = False
  class Jsonb(Json):
      """Json adapter whose quoted value carries a ::jsonb cast."""

      def getquoted(self):
          # Append the cast to whatever the parent adapter quoted.
          return super().getquoted() + b"::jsonb"