search.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383
  1. from django.db.models import (
  2. CharField,
  3. Expression,
  4. Field,
  5. FloatField,
  6. Func,
  7. Lookup,
  8. TextField,
  9. Value,
  10. )
  11. from django.db.models.expressions import CombinedExpression, register_combinable_fields
  12. from django.db.models.functions import Cast, Coalesce
  class SearchVectorExact(Lookup):
      """``exact`` lookup for search vectors, compiled to ``<lhs> @@ <rhs>``.

      A right-hand side that is not already a ``SearchQuery`` or
      ``CombinedSearchQuery`` is wrapped in a ``SearchQuery``, reusing the
      left-hand side's ``config`` attribute when it has one.
      """

      lookup_name = "exact"

      def process_rhs(self, qn, connection):
          if not isinstance(self.rhs, (SearchQuery, CombinedSearchQuery)):
              lhs_config = getattr(self.lhs, "config", None)
              self.rhs = SearchQuery(self.rhs, config=lhs_config)
          rhs_sql, rhs_params = super().process_rhs(qn, connection)
          return rhs_sql, rhs_params

      def as_sql(self, qn, connection):
          lhs_sql, lhs_params = self.process_lhs(qn, connection)
          rhs_sql, rhs_params = self.process_rhs(qn, connection)
          combined_params = lhs_params + rhs_params
          return f"{lhs_sql} @@ {rhs_sql}", combined_params
  class SearchVectorField(Field):
      """Field whose database column type is ``tsvector``."""

      def db_type(self, connection):
          return "tsvector"


  class SearchQueryField(Field):
      """Field whose database column type is ``tsquery``."""

      def db_type(self, connection):
          return "tsquery"


  class _Float4Field(Field):
      """Private field of type ``float4``; SearchRank casts weights to an array of it."""

      def db_type(self, connection):
          return "float4"
  class SearchConfig(Expression):
      """A text search configuration, compiled as ``<config>::regconfig``.

      ``config`` may be an expression or a plain value; plain values are
      wrapped in ``Value``.
      """

      def __init__(self, config):
          super().__init__()
          self.config = config if hasattr(config, "resolve_expression") else Value(config)

      @classmethod
      def from_parameter(cls, config):
          """Normalize a user-supplied ``config`` argument.

          ``None`` and existing instances of ``cls`` are returned unchanged;
          anything else is wrapped in a new ``cls`` instance.
          """
          if config is not None and not isinstance(config, cls):
              return cls(config)
          return config

      def get_source_expressions(self):
          return [self.config]

      def set_source_expressions(self, exprs):
          # Unpacking enforces that exactly one expression is supplied.
          [self.config] = exprs

      def as_sql(self, compiler, connection):
          config_sql, config_params = compiler.compile(self.config)
          return f"{config_sql}::regconfig", config_params
  class SearchVectorCombinable:
      """Mixin that combines search vectors with the ``||`` connector."""

      ADD = "||"

      def _combine(self, other, connector, reversed):
          """Build a CombinedSearchVector from ``self`` and ``other``.

          When ``reversed`` is true, ``other`` becomes the left operand. The
          result always carries ``self.config``.

          Raises:
              TypeError: if ``other`` is not a SearchVectorCombinable.
          """
          if isinstance(other, SearchVectorCombinable):
              left, right = (other, self) if reversed else (self, other)
              return CombinedSearchVector(left, connector, right, self.config)
          raise TypeError(
              "SearchVector can only be combined with other SearchVector "
              "instances, got %s." % type(other).__name__
          )
  # Register the combination SearchVectorField || SearchVectorField with
  # SearchVectorField as its result field type.
  register_combinable_fields(
      SearchVectorField,
      SearchVectorCombinable.ADD,
      SearchVectorField,
      SearchVectorField,
  )
  class SearchVector(SearchVectorCombinable, Func):
      """``to_tsvector()`` over one or more expressions.

      At compile time each expression is cast to text if it is not a
      CharField/TextField and is coalesced to ``''``. The arguments are joined
      with ``|| ' ' ||``. ``config`` (normalized through
      ``SearchConfig.from_parameter``) is passed as the first function
      argument when the default template is used. ``weight``, when given,
      wraps the result in ``setweight()``.
      """

      function = "to_tsvector"
      arg_joiner = " || ' ' || "
      output_field = SearchVectorField()

      def __init__(self, *expressions, config=None, weight=None):
          super().__init__(*expressions)
          self.config = SearchConfig.from_parameter(config)
          needs_wrapping = weight is not None and not hasattr(weight, "resolve_expression")
          self.weight = Value(weight) if needs_wrapping else weight

      def resolve_expression(
          self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
      ):
          resolve_args = (query, allow_joins, reuse, summarize, for_save)
          resolved = super().resolve_expression(*resolve_args)
          # config is kept outside the source expressions, so it has to be
          # resolved on its own and attached to the resolved copy.
          if self.config:
              resolved.config = self.config.resolve_expression(*resolve_args)
          return resolved

      def as_sql(self, compiler, connection, function=None, template=None):
          def as_text(expression):
              # Cast non-text arguments to text and replace NULL with '' so the
              # arguments can be concatenated through arg_joiner.
              if not isinstance(expression.output_field, (CharField, TextField)):
                  expression = Cast(expression, TextField())
              return Coalesce(expression, Value(""))

          clone = self.copy()
          clone.set_source_expressions(
              [as_text(expression) for expression in clone.get_source_expressions()]
          )

          config_sql, config_params = None, []
          if template is None and clone.config:
              # The config argument is injected only into the default template;
              # an explicitly passed template is used as given.
              config_sql, config_params = compiler.compile(clone.config)
              template = "%(function)s(%(config)s, %(expressions)s)"
          elif template is None:
              template = clone.template
          sql, params = super(SearchVector, clone).as_sql(
              compiler,
              connection,
              function=function,
              template=template,
              config=config_sql,
          )

          weight_params = []
          if clone.weight:
              weight_sql, weight_params = compiler.compile(clone.weight)
              sql = f"setweight({sql}, {weight_sql})"
          # Parameter order mirrors placeholder order: config, arguments, weight.
          return sql, config_params + params + weight_params
  class CombinedSearchVector(SearchVectorCombinable, CombinedExpression):
      """Result of combining search vectors; stores the ``config`` it is given."""

      def __init__(self, lhs, connector, rhs, config, output_field=None):
          self.config = config
          super().__init__(lhs, connector, rhs, output_field)
  class SearchQueryCombinable:
      """Mixin that combines search queries with ``&&`` (via ``&``) and ``||`` (via ``|``)."""

      BITAND = "&&"
      BITOR = "||"

      def _combine(self, other, connector, reversed):
          """Build a CombinedSearchQuery from ``self`` and ``other``.

          When ``reversed`` is true, ``other`` becomes the left operand. The
          result always carries ``self.config``.

          Raises:
              TypeError: if ``other`` is not a SearchQueryCombinable.
          """
          if not isinstance(other, SearchQueryCombinable):
              raise TypeError(
                  "SearchQuery can only be combined with other SearchQuery "
                  "instances, got %s." % type(other).__name__
              )
          left, right = (other, self) if reversed else (self, other)
          return CombinedSearchQuery(left, connector, right, self.config)

      # Combinable leaves these operators unimplemented to avoid confusion with
      # Q objects. Here they are deliberately (ab)used for logical combination
      # of queries, which keeps their meaning consistent with other Django usage.
      def __or__(self, other):
          return self._combine(other, self.BITOR, False)

      def __ror__(self, other):
          return self._combine(other, self.BITOR, True)

      def __and__(self, other):
          return self._combine(other, self.BITAND, False)

      def __rand__(self, other):
          return self._combine(other, self.BITAND, True)
  class SearchQuery(SearchQueryCombinable, Func):
      """A ``tsquery`` built by the function selected through ``search_type``.

      ``search_type`` picks the SQL function from ``SEARCH_TYPES``. A plain
      ``value`` is wrapped in ``Value``. ``config``, when given, becomes the
      first argument. ``invert`` wraps the compiled SQL in ``!!(...)``.

      Raises:
          ValueError: if ``search_type`` is not a key of ``SEARCH_TYPES``.
      """

      output_field = SearchQueryField()
      SEARCH_TYPES = {
          "plain": "plainto_tsquery",
          "phrase": "phraseto_tsquery",
          "raw": "to_tsquery",
          "websearch": "websearch_to_tsquery",
      }

      def __init__(
          self,
          value,
          output_field=None,
          *,
          config=None,
          invert=False,
          search_type="plain",
      ):
          function = self.SEARCH_TYPES.get(search_type)
          if function is None:
              raise ValueError("Unknown search_type argument '%s'." % search_type)
          self.function = function
          self.config = SearchConfig.from_parameter(config)
          self.invert = invert
          if not hasattr(value, "resolve_expression"):
              value = Value(value)
          if self.config is None:
              expressions = (value,)
          else:
              expressions = (self.config, value)
          super().__init__(*expressions, output_field=output_field)

      def as_sql(self, compiler, connection, function=None, template=None):
          sql, params = super().as_sql(compiler, connection, function, template)
          if not self.invert:
              return sql, params
          return "!!(%s)" % sql, params

      def __invert__(self):
          """Return a copy with the ``invert`` flag toggled."""
          inverted = self.copy()
          inverted.invert = not self.invert
          return inverted

      def __str__(self):
          text = super().__str__()
          return f"~{text}" if self.invert else text
  class CombinedSearchQuery(SearchQueryCombinable, CombinedExpression):
      """Result of combining search queries; stores the ``config`` it is given."""

      def __init__(self, lhs, connector, rhs, config, output_field=None):
          self.config = config
          super().__init__(lhs, connector, rhs, output_field)

      def __str__(self):
          # Parenthesize the combined expression's string form.
          return f"({super().__str__()})"
  class SearchRank(Func):
      """``ts_rank`` (or ``ts_rank_cd`` when ``cover_density``) of a vector against a query.

      Arguments are emitted as ``[weights,] vector, query[, normalization]``.
      Plain ``vector``/``query`` values are wrapped in SearchVector/SearchQuery.
      ``weights`` is cast to an array of ``float4``, and plain ``weights`` or
      ``normalization`` values are wrapped in ``Value``.
      """

      function = "ts_rank"
      output_field = FloatField()

      def __init__(
          self,
          vector,
          query,
          weights=None,
          normalization=None,
          cover_density=False,
      ):
          from .fields.array import ArrayField

          if not hasattr(vector, "resolve_expression"):
              vector = SearchVector(vector)
          if not hasattr(query, "resolve_expression"):
              query = SearchQuery(query)
          arguments = [vector, query]
          if weights is not None:
              if not hasattr(weights, "resolve_expression"):
                  weights = Value(weights)
              arguments.insert(0, Cast(weights, ArrayField(_Float4Field())))
          if normalization is not None:
              if not hasattr(normalization, "resolve_expression"):
                  normalization = Value(normalization)
              arguments.append(normalization)
          if cover_density:
              self.function = "ts_rank_cd"
          super().__init__(*arguments)
  class SearchHeadline(Func):
      """``ts_headline([config,] expression, query[, options])``.

      Keyword options that are not ``None`` are stored in ``self.options``
      under their ts_headline names (StartSel, StopSel, …). At compile time
      they are rendered via ``connection.ops.compose_sql`` into one
      comma-separated string, which is passed as a single query parameter.
      """

      function = "ts_headline"
      template = "%(function)s(%(expressions)s%(options)s)"
      output_field = TextField()

      def __init__(
          self,
          expression,
          query,
          *,
          config=None,
          start_sel=None,
          stop_sel=None,
          max_words=None,
          min_words=None,
          short_word=None,
          highlight_all=None,
          max_fragments=None,
          fragment_delimiter=None,
      ):
          if not hasattr(query, "resolve_expression"):
              query = SearchQuery(query)
          named_options = (
              ("StartSel", start_sel),
              ("StopSel", stop_sel),
              ("MaxWords", max_words),
              ("MinWords", min_words),
              ("ShortWord", short_word),
              ("HighlightAll", highlight_all),
              ("MaxFragments", max_fragments),
              ("FragmentDelimiter", fragment_delimiter),
          )
          self.options = {
              name: value for name, value in named_options if value is not None
          }
          expressions = (expression, query)
          if config is not None:
              expressions = (SearchConfig.from_parameter(config), *expressions)
          super().__init__(*expressions)

      def as_sql(self, compiler, connection, function=None, template=None):
          options_sql, options_params = "", []
          if self.options:
              rendered_options = ", ".join(
                  connection.ops.compose_sql(f"{option}=%s", [value])
                  for option, value in self.options.items()
              )
              # The whole options string becomes one trailing placeholder.
              options_sql, options_params = ", %s", [rendered_options]
          sql, params = super().as_sql(
              compiler,
              connection,
              function=function,
              template=template,
              options=options_sql,
          )
          return sql, params + options_params
  # Install SearchVectorExact (lookup_name "exact") as a lookup on SearchVectorField.
  SearchVectorField.register_lookup(SearchVectorExact)
  class TrigramBase(Func):
      """Base for float-valued trigram functions with arguments ``(expression, string)``.

      A plain ``string`` is wrapped in ``Value``. Extra keyword arguments are
      forwarded to ``Func``.
      """

      output_field = FloatField()

      def __init__(self, expression, string, **extra):
          string = string if hasattr(string, "resolve_expression") else Value(string)
          super().__init__(expression, string, **extra)
  class TrigramWordBase(Func):
      """Base for float-valued trigram word functions with arguments ``(string, expression)``.

      Unlike TrigramBase, the string comes first. A plain ``string`` is
      wrapped in ``Value``. Extra keyword arguments are forwarded to ``Func``.
      """

      output_field = FloatField()

      def __init__(self, string, expression, **extra):
          string = string if hasattr(string, "resolve_expression") else Value(string)
          super().__init__(string, expression, **extra)
  class TrigramSimilarity(TrigramBase):
      """``SIMILARITY(expression, string)``."""

      function = "SIMILARITY"


  class TrigramDistance(TrigramBase):
      """Arguments joined by the ``<->`` operator; no function name is emitted."""

      function = ""
      arg_joiner = " <-> "


  class TrigramWordDistance(TrigramWordBase):
      """Arguments joined by the ``<<->`` operator; no function name is emitted."""

      function = ""
      arg_joiner = " <<-> "


  class TrigramStrictWordDistance(TrigramWordBase):
      """Arguments joined by the ``<<<->`` operator; no function name is emitted."""

      function = ""
      arg_joiner = " <<<-> "


  class TrigramWordSimilarity(TrigramWordBase):
      """``WORD_SIMILARITY(string, expression)``."""

      function = "WORD_SIMILARITY"


  class TrigramStrictWordSimilarity(TrigramWordBase):
      """``STRICT_WORD_SIMILARITY(string, expression)``."""

      function = "STRICT_WORD_SIMILARITY"