# json.py (22 KB)
import json
import re

from django import forms
from django.core import checks, exceptions
from django.db import NotSupportedError, connections, router
from django.db.models import expressions, lookups
from django.db.models.constants import LOOKUP_SEP
from django.db.models.fields import TextField
from django.db.models.lookups import (
    FieldGetDbPrepValueMixin,
    PostgresOperatorLookup,
    Transform,
)
from django.utils.translation import gettext_lazy as _

from . import Field
from .mixins import CheckFieldDefaultMixin
  16. __all__ = ["JSONField"]
  17. class JSONField(CheckFieldDefaultMixin, Field):
  18. empty_strings_allowed = False
  19. description = _("A JSON object")
  20. default_error_messages = {
  21. "invalid": _("Value must be valid JSON."),
  22. }
  23. _default_hint = ("dict", "{}")
  24. def __init__(
  25. self,
  26. verbose_name=None,
  27. name=None,
  28. encoder=None,
  29. decoder=None,
  30. **kwargs,
  31. ):
  32. if encoder and not callable(encoder):
  33. raise ValueError("The encoder parameter must be a callable object.")
  34. if decoder and not callable(decoder):
  35. raise ValueError("The decoder parameter must be a callable object.")
  36. self.encoder = encoder
  37. self.decoder = decoder
  38. super().__init__(verbose_name, name, **kwargs)
  39. def check(self, **kwargs):
  40. errors = super().check(**kwargs)
  41. databases = kwargs.get("databases") or []
  42. errors.extend(self._check_supported(databases))
  43. return errors
  44. def _check_supported(self, databases):
  45. errors = []
  46. for db in databases:
  47. if not router.allow_migrate_model(db, self.model):
  48. continue
  49. connection = connections[db]
  50. if (
  51. self.model._meta.required_db_vendor
  52. and self.model._meta.required_db_vendor != connection.vendor
  53. ):
  54. continue
  55. if not (
  56. "supports_json_field" in self.model._meta.required_db_features
  57. or connection.features.supports_json_field
  58. ):
  59. errors.append(
  60. checks.Error(
  61. "%s does not support JSONFields." % connection.display_name,
  62. obj=self.model,
  63. id="fields.E180",
  64. )
  65. )
  66. return errors
  67. def deconstruct(self):
  68. name, path, args, kwargs = super().deconstruct()
  69. if self.encoder is not None:
  70. kwargs["encoder"] = self.encoder
  71. if self.decoder is not None:
  72. kwargs["decoder"] = self.decoder
  73. return name, path, args, kwargs
  74. def from_db_value(self, value, expression, connection):
  75. if value is None:
  76. return value
  77. # Some backends (SQLite at least) extract non-string values in their
  78. # SQL datatypes.
  79. if isinstance(expression, KeyTransform) and not isinstance(value, str):
  80. return value
  81. try:
  82. return json.loads(value, cls=self.decoder)
  83. except json.JSONDecodeError:
  84. return value
  85. def get_internal_type(self):
  86. return "JSONField"
  87. def get_db_prep_value(self, value, connection, prepared=False):
  88. if not prepared:
  89. value = self.get_prep_value(value)
  90. if isinstance(value, expressions.Value) and isinstance(
  91. value.output_field, JSONField
  92. ):
  93. value = value.value
  94. elif hasattr(value, "as_sql"):
  95. return value
  96. return connection.ops.adapt_json_value(value, self.encoder)
  97. def get_db_prep_save(self, value, connection):
  98. if value is None:
  99. return value
  100. return self.get_db_prep_value(value, connection)
  101. def get_transform(self, name):
  102. transform = super().get_transform(name)
  103. if transform:
  104. return transform
  105. return KeyTransformFactory(name)
  106. def validate(self, value, model_instance):
  107. super().validate(value, model_instance)
  108. try:
  109. json.dumps(value, cls=self.encoder)
  110. except TypeError:
  111. raise exceptions.ValidationError(
  112. self.error_messages["invalid"],
  113. code="invalid",
  114. params={"value": value},
  115. )
  116. def value_to_string(self, obj):
  117. return self.value_from_object(obj)
  118. def formfield(self, **kwargs):
  119. return super().formfield(
  120. **{
  121. "form_class": forms.JSONField,
  122. "encoder": self.encoder,
  123. "decoder": self.decoder,
  124. **kwargs,
  125. }
  126. )
  127. def compile_json_path(key_transforms, include_root=True):
  128. path = ["$"] if include_root else []
  129. for key_transform in key_transforms:
  130. try:
  131. num = int(key_transform)
  132. except ValueError: # non-integer
  133. path.append(".")
  134. path.append(json.dumps(key_transform))
  135. else:
  136. path.append("[%s]" % num)
  137. return "".join(path)
  138. class DataContains(FieldGetDbPrepValueMixin, PostgresOperatorLookup):
  139. lookup_name = "contains"
  140. postgres_operator = "@>"
  141. def as_sql(self, compiler, connection):
  142. if not connection.features.supports_json_field_contains:
  143. raise NotSupportedError(
  144. "contains lookup is not supported on this database backend."
  145. )
  146. lhs, lhs_params = self.process_lhs(compiler, connection)
  147. rhs, rhs_params = self.process_rhs(compiler, connection)
  148. params = tuple(lhs_params) + tuple(rhs_params)
  149. return "JSON_CONTAINS(%s, %s)" % (lhs, rhs), params
  150. class ContainedBy(FieldGetDbPrepValueMixin, PostgresOperatorLookup):
  151. lookup_name = "contained_by"
  152. postgres_operator = "<@"
  153. def as_sql(self, compiler, connection):
  154. if not connection.features.supports_json_field_contains:
  155. raise NotSupportedError(
  156. "contained_by lookup is not supported on this database backend."
  157. )
  158. lhs, lhs_params = self.process_lhs(compiler, connection)
  159. rhs, rhs_params = self.process_rhs(compiler, connection)
  160. params = tuple(rhs_params) + tuple(lhs_params)
  161. return "JSON_CONTAINS(%s, %s)" % (rhs, lhs), params
  162. class HasKeyLookup(PostgresOperatorLookup):
  163. logical_operator = None
  164. def compile_json_path_final_key(self, key_transform):
  165. # Compile the final key without interpreting ints as array elements.
  166. return ".%s" % json.dumps(key_transform)
  167. def as_sql(self, compiler, connection, template=None):
  168. # Process JSON path from the left-hand side.
  169. if isinstance(self.lhs, KeyTransform):
  170. lhs, lhs_params, lhs_key_transforms = self.lhs.preprocess_lhs(
  171. compiler, connection
  172. )
  173. lhs_json_path = compile_json_path(lhs_key_transforms)
  174. else:
  175. lhs, lhs_params = self.process_lhs(compiler, connection)
  176. lhs_json_path = "$"
  177. sql = template % lhs
  178. # Process JSON path from the right-hand side.
  179. rhs = self.rhs
  180. rhs_params = []
  181. if not isinstance(rhs, (list, tuple)):
  182. rhs = [rhs]
  183. for key in rhs:
  184. if isinstance(key, KeyTransform):
  185. *_, rhs_key_transforms = key.preprocess_lhs(compiler, connection)
  186. else:
  187. rhs_key_transforms = [key]
  188. *rhs_key_transforms, final_key = rhs_key_transforms
  189. rhs_json_path = compile_json_path(rhs_key_transforms, include_root=False)
  190. rhs_json_path += self.compile_json_path_final_key(final_key)
  191. rhs_params.append(lhs_json_path + rhs_json_path)
  192. # Add condition for each key.
  193. if self.logical_operator:
  194. sql = "(%s)" % self.logical_operator.join([sql] * len(rhs_params))
  195. return sql, tuple(lhs_params) + tuple(rhs_params)
  196. def as_mysql(self, compiler, connection):
  197. return self.as_sql(
  198. compiler, connection, template="JSON_CONTAINS_PATH(%s, 'one', %%s)"
  199. )
  200. def as_oracle(self, compiler, connection):
  201. sql, params = self.as_sql(
  202. compiler, connection, template="JSON_EXISTS(%s, '%%s')"
  203. )
  204. # Add paths directly into SQL because path expressions cannot be passed
  205. # as bind variables on Oracle.
  206. return sql % tuple(params), []
  207. def as_postgresql(self, compiler, connection):
  208. if isinstance(self.rhs, KeyTransform):
  209. *_, rhs_key_transforms = self.rhs.preprocess_lhs(compiler, connection)
  210. for key in rhs_key_transforms[:-1]:
  211. self.lhs = KeyTransform(key, self.lhs)
  212. self.rhs = rhs_key_transforms[-1]
  213. return super().as_postgresql(compiler, connection)
  214. def as_sqlite(self, compiler, connection):
  215. return self.as_sql(
  216. compiler, connection, template="JSON_TYPE(%s, %%s) IS NOT NULL"
  217. )
  218. class HasKey(HasKeyLookup):
  219. lookup_name = "has_key"
  220. postgres_operator = "?"
  221. prepare_rhs = False
  222. class HasKeys(HasKeyLookup):
  223. lookup_name = "has_keys"
  224. postgres_operator = "?&"
  225. logical_operator = " AND "
  226. def get_prep_lookup(self):
  227. return [str(item) for item in self.rhs]
  228. class HasAnyKeys(HasKeys):
  229. lookup_name = "has_any_keys"
  230. postgres_operator = "?|"
  231. logical_operator = " OR "
  232. class HasKeyOrArrayIndex(HasKey):
  233. def compile_json_path_final_key(self, key_transform):
  234. return compile_json_path([key_transform], include_root=False)
  235. class CaseInsensitiveMixin:
  236. """
  237. Mixin to allow case-insensitive comparison of JSON values on MySQL.
  238. MySQL handles strings used in JSON context using the utf8mb4_bin collation.
  239. Because utf8mb4_bin is a binary collation, comparison of JSON values is
  240. case-sensitive.
  241. """
  242. def process_lhs(self, compiler, connection):
  243. lhs, lhs_params = super().process_lhs(compiler, connection)
  244. if connection.vendor == "mysql":
  245. return "LOWER(%s)" % lhs, lhs_params
  246. return lhs, lhs_params
  247. def process_rhs(self, compiler, connection):
  248. rhs, rhs_params = super().process_rhs(compiler, connection)
  249. if connection.vendor == "mysql":
  250. return "LOWER(%s)" % rhs, rhs_params
  251. return rhs, rhs_params
  252. class JSONExact(lookups.Exact):
  253. can_use_none_as_rhs = True
  254. def process_rhs(self, compiler, connection):
  255. rhs, rhs_params = super().process_rhs(compiler, connection)
  256. # Treat None lookup values as null.
  257. if rhs == "%s" and rhs_params == [None]:
  258. rhs_params = ["null"]
  259. if connection.vendor == "mysql":
  260. func = ["JSON_EXTRACT(%s, '$')"] * len(rhs_params)
  261. rhs %= tuple(func)
  262. return rhs, rhs_params
  263. def as_oracle(self, compiler, connection):
  264. lhs, lhs_params = self.process_lhs(compiler, connection)
  265. rhs, rhs_params = self.process_rhs(compiler, connection)
  266. if connection.features.supports_primitives_in_json_field:
  267. lhs = f"JSON({lhs})"
  268. rhs = f"JSON({rhs})"
  269. return f"JSON_EQUAL({lhs}, {rhs} ERROR ON ERROR)", (*lhs_params, *rhs_params)
  270. class JSONIContains(CaseInsensitiveMixin, lookups.IContains):
  271. pass
  272. JSONField.register_lookup(DataContains)
  273. JSONField.register_lookup(ContainedBy)
  274. JSONField.register_lookup(HasKey)
  275. JSONField.register_lookup(HasKeys)
  276. JSONField.register_lookup(HasAnyKeys)
  277. JSONField.register_lookup(JSONExact)
  278. JSONField.register_lookup(JSONIContains)
  279. class KeyTransform(Transform):
  280. postgres_operator = "->"
  281. postgres_nested_operator = "#>"
  282. def __init__(self, key_name, *args, **kwargs):
  283. super().__init__(*args, **kwargs)
  284. self.key_name = str(key_name)
  285. def preprocess_lhs(self, compiler, connection):
  286. key_transforms = [self.key_name]
  287. previous = self.lhs
  288. while isinstance(previous, KeyTransform):
  289. key_transforms.insert(0, previous.key_name)
  290. previous = previous.lhs
  291. lhs, params = compiler.compile(previous)
  292. if connection.vendor == "oracle":
  293. # Escape string-formatting.
  294. key_transforms = [key.replace("%", "%%") for key in key_transforms]
  295. return lhs, params, key_transforms
  296. def as_mysql(self, compiler, connection):
  297. lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
  298. json_path = compile_json_path(key_transforms)
  299. return "JSON_EXTRACT(%s, %%s)" % lhs, tuple(params) + (json_path,)
  300. def as_oracle(self, compiler, connection):
  301. lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
  302. json_path = compile_json_path(key_transforms)
  303. if connection.features.supports_primitives_in_json_field:
  304. sql = (
  305. "COALESCE(JSON_VALUE(%s, '%s'), JSON_QUERY(%s, '%s' DISALLOW SCALARS))"
  306. )
  307. else:
  308. sql = "COALESCE(JSON_QUERY(%s, '%s'), JSON_VALUE(%s, '%s'))"
  309. return sql % ((lhs, json_path) * 2), tuple(params) * 2
  310. def as_postgresql(self, compiler, connection):
  311. lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
  312. if len(key_transforms) > 1:
  313. sql = "(%s %s %%s)" % (lhs, self.postgres_nested_operator)
  314. return sql, tuple(params) + (key_transforms,)
  315. try:
  316. lookup = int(self.key_name)
  317. except ValueError:
  318. lookup = self.key_name
  319. return "(%s %s %%s)" % (lhs, self.postgres_operator), tuple(params) + (lookup,)
  320. def as_sqlite(self, compiler, connection):
  321. lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
  322. json_path = compile_json_path(key_transforms)
  323. datatype_values = ",".join(
  324. [repr(datatype) for datatype in connection.ops.jsonfield_datatype_values]
  325. )
  326. return (
  327. "(CASE WHEN JSON_TYPE(%s, %%s) IN (%s) "
  328. "THEN JSON_TYPE(%s, %%s) ELSE JSON_EXTRACT(%s, %%s) END)"
  329. ) % (lhs, datatype_values, lhs, lhs), (tuple(params) + (json_path,)) * 3
  330. class KeyTextTransform(KeyTransform):
  331. postgres_operator = "->>"
  332. postgres_nested_operator = "#>>"
  333. output_field = TextField()
  334. def as_mysql(self, compiler, connection):
  335. if connection.mysql_is_mariadb:
  336. # MariaDB doesn't support -> and ->> operators (see MDEV-13594).
  337. sql, params = super().as_mysql(compiler, connection)
  338. return "JSON_UNQUOTE(%s)" % sql, params
  339. else:
  340. lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
  341. json_path = compile_json_path(key_transforms)
  342. return "(%s ->> %%s)" % lhs, tuple(params) + (json_path,)
  343. @classmethod
  344. def from_lookup(cls, lookup):
  345. transform, *keys = lookup.split(LOOKUP_SEP)
  346. if not keys:
  347. raise ValueError("Lookup must contain key or index transforms.")
  348. for key in keys:
  349. transform = cls(key, transform)
  350. return transform
  351. KT = KeyTextTransform.from_lookup
  352. class KeyTransformTextLookupMixin:
  353. """
  354. Mixin for combining with a lookup expecting a text lhs from a JSONField
  355. key lookup. On PostgreSQL, make use of the ->> operator instead of casting
  356. key values to text and performing the lookup on the resulting
  357. representation.
  358. """
  359. def __init__(self, key_transform, *args, **kwargs):
  360. if not isinstance(key_transform, KeyTransform):
  361. raise TypeError(
  362. "Transform should be an instance of KeyTransform in order to "
  363. "use this lookup."
  364. )
  365. key_text_transform = KeyTextTransform(
  366. key_transform.key_name,
  367. *key_transform.source_expressions,
  368. **key_transform.extra,
  369. )
  370. super().__init__(key_text_transform, *args, **kwargs)
  371. class KeyTransformIsNull(lookups.IsNull):
  372. # key__isnull=False is the same as has_key='key'
  373. def as_oracle(self, compiler, connection):
  374. sql, params = HasKeyOrArrayIndex(
  375. self.lhs.lhs,
  376. self.lhs.key_name,
  377. ).as_oracle(compiler, connection)
  378. if not self.rhs:
  379. return sql, params
  380. # Column doesn't have a key or IS NULL.
  381. lhs, lhs_params, _ = self.lhs.preprocess_lhs(compiler, connection)
  382. return "(NOT %s OR %s IS NULL)" % (sql, lhs), tuple(params) + tuple(lhs_params)
  383. def as_sqlite(self, compiler, connection):
  384. template = "JSON_TYPE(%s, %%s) IS NULL"
  385. if not self.rhs:
  386. template = "JSON_TYPE(%s, %%s) IS NOT NULL"
  387. return HasKeyOrArrayIndex(self.lhs.lhs, self.lhs.key_name).as_sql(
  388. compiler,
  389. connection,
  390. template=template,
  391. )
  392. class KeyTransformIn(lookups.In):
  393. def resolve_expression_parameter(self, compiler, connection, sql, param):
  394. sql, params = super().resolve_expression_parameter(
  395. compiler,
  396. connection,
  397. sql,
  398. param,
  399. )
  400. if (
  401. not hasattr(param, "as_sql")
  402. and not connection.features.has_native_json_field
  403. ):
  404. if connection.vendor == "oracle":
  405. value = json.loads(param)
  406. sql = "%s(JSON_OBJECT('value' VALUE %%s FORMAT JSON), '$.value')"
  407. if isinstance(value, (list, dict)):
  408. sql %= "JSON_QUERY"
  409. else:
  410. sql %= "JSON_VALUE"
  411. elif connection.vendor == "mysql" or (
  412. connection.vendor == "sqlite"
  413. and params[0] not in connection.ops.jsonfield_datatype_values
  414. ):
  415. sql = "JSON_EXTRACT(%s, '$')"
  416. if connection.vendor == "mysql" and connection.mysql_is_mariadb:
  417. sql = "JSON_UNQUOTE(%s)" % sql
  418. return sql, params
  419. class KeyTransformExact(JSONExact):
  420. def process_rhs(self, compiler, connection):
  421. if isinstance(self.rhs, KeyTransform):
  422. return super(lookups.Exact, self).process_rhs(compiler, connection)
  423. rhs, rhs_params = super().process_rhs(compiler, connection)
  424. if connection.vendor == "oracle":
  425. func = []
  426. sql = "%s(JSON_OBJECT('value' VALUE %%s FORMAT JSON), '$.value')"
  427. for value in rhs_params:
  428. value = json.loads(value)
  429. if isinstance(value, (list, dict)):
  430. func.append(sql % "JSON_QUERY")
  431. else:
  432. func.append(sql % "JSON_VALUE")
  433. rhs %= tuple(func)
  434. elif connection.vendor == "sqlite":
  435. func = []
  436. for value in rhs_params:
  437. if value in connection.ops.jsonfield_datatype_values:
  438. func.append("%s")
  439. else:
  440. func.append("JSON_EXTRACT(%s, '$')")
  441. rhs %= tuple(func)
  442. return rhs, rhs_params
  443. def as_oracle(self, compiler, connection):
  444. rhs, rhs_params = super().process_rhs(compiler, connection)
  445. if rhs_params == ["null"]:
  446. # Field has key and it's NULL.
  447. has_key_expr = HasKeyOrArrayIndex(self.lhs.lhs, self.lhs.key_name)
  448. has_key_sql, has_key_params = has_key_expr.as_oracle(compiler, connection)
  449. is_null_expr = self.lhs.get_lookup("isnull")(self.lhs, True)
  450. is_null_sql, is_null_params = is_null_expr.as_sql(compiler, connection)
  451. return (
  452. "%s AND %s" % (has_key_sql, is_null_sql),
  453. tuple(has_key_params) + tuple(is_null_params),
  454. )
  455. return super().as_sql(compiler, connection)
  456. class KeyTransformIExact(
  457. CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IExact
  458. ):
  459. pass
  460. class KeyTransformIContains(
  461. CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IContains
  462. ):
  463. pass
  464. class KeyTransformStartsWith(KeyTransformTextLookupMixin, lookups.StartsWith):
  465. pass
  466. class KeyTransformIStartsWith(
  467. CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IStartsWith
  468. ):
  469. pass
  470. class KeyTransformEndsWith(KeyTransformTextLookupMixin, lookups.EndsWith):
  471. pass
  472. class KeyTransformIEndsWith(
  473. CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IEndsWith
  474. ):
  475. pass
  476. class KeyTransformRegex(KeyTransformTextLookupMixin, lookups.Regex):
  477. pass
  478. class KeyTransformIRegex(
  479. CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IRegex
  480. ):
  481. pass
  482. class KeyTransformNumericLookupMixin:
  483. def process_rhs(self, compiler, connection):
  484. rhs, rhs_params = super().process_rhs(compiler, connection)
  485. if not connection.features.has_native_json_field:
  486. rhs_params = [json.loads(value) for value in rhs_params]
  487. return rhs, rhs_params
  488. class KeyTransformLt(KeyTransformNumericLookupMixin, lookups.LessThan):
  489. pass
  490. class KeyTransformLte(KeyTransformNumericLookupMixin, lookups.LessThanOrEqual):
  491. pass
  492. class KeyTransformGt(KeyTransformNumericLookupMixin, lookups.GreaterThan):
  493. pass
  494. class KeyTransformGte(KeyTransformNumericLookupMixin, lookups.GreaterThanOrEqual):
  495. pass
  496. KeyTransform.register_lookup(KeyTransformIn)
  497. KeyTransform.register_lookup(KeyTransformExact)
  498. KeyTransform.register_lookup(KeyTransformIExact)
  499. KeyTransform.register_lookup(KeyTransformIsNull)
  500. KeyTransform.register_lookup(KeyTransformIContains)
  501. KeyTransform.register_lookup(KeyTransformStartsWith)
  502. KeyTransform.register_lookup(KeyTransformIStartsWith)
  503. KeyTransform.register_lookup(KeyTransformEndsWith)
  504. KeyTransform.register_lookup(KeyTransformIEndsWith)
  505. KeyTransform.register_lookup(KeyTransformRegex)
  506. KeyTransform.register_lookup(KeyTransformIRegex)
  507. KeyTransform.register_lookup(KeyTransformLt)
  508. KeyTransform.register_lookup(KeyTransformLte)
  509. KeyTransform.register_lookup(KeyTransformGt)
  510. KeyTransform.register_lookup(KeyTransformGte)
  511. class KeyTransformFactory:
  512. def __init__(self, key_name):
  513. self.key_name = key_name
  514. def __call__(self, *args, **kwargs):
  515. return KeyTransform(self.key_name, *args, **kwargs)