| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377 |
- from typing import List, Optional, Union
class Query:
    """
    Query is used to build complex queries that have more parameters than just
    the query string. The query string is set in the constructor, and other
    options have setter functions.

    The setter functions return the query object, so they can be chained,
    i.e. `Query("foo").verbatim().add_filter(...)` etc.
    """

    def __init__(self, query_string: str) -> None:
        """
        Create a new query object.

        The query string is set in the constructor, and other options have
        setter functions.
        """
        self._query_string: str = query_string
        # Paging window emitted as `LIMIT <offset> <num>` by `get_args`.
        self._offset: int = 0
        self._num: int = 10
        self._no_content: bool = False
        self._no_stopwords: bool = False
        self._fields: Optional[List[str]] = None
        self._verbatim: bool = False
        self._with_payloads: bool = False
        self._with_scores: bool = False
        self._scorer: Optional[str] = None
        self._filters: List = []
        self._ids: Optional[List[str]] = None
        # A negative slop means "unset": no SLOP argument is emitted.
        self._slop: int = -1
        self._timeout: Optional[float] = None
        self._in_order: bool = False
        self._sortby: Optional["SortbyField"] = None
        # Flat argument list; aliased fields contribute `field, "AS", alias`.
        self._return_fields: List = []
        # Maps each returned field to its encoding, or None when the field
        # was registered with decoding disabled.
        self._return_fields_decode_as: dict = {}
        self._summarize_fields: List = []
        self._highlight_fields: List = []
        self._language: Optional[str] = None
        self._expander: Optional[str] = None
        self._dialect: Optional[int] = None

    def query_string(self) -> str:
        """Return the query string of this query only."""
        return self._query_string

    def limit_ids(self, *ids: str) -> "Query":
        """Limit the results to a specific set of pre-known document
        ids of any length."""
        # Stored as a list so the attribute matches its declared type.
        self._ids = list(ids)
        return self

    def return_fields(self, *fields: str) -> "Query":
        """Add fields to return fields."""
        for name in fields:
            self.return_field(name)
        return self

    def return_field(
        self,
        field: str,
        as_field: Optional[str] = None,
        decode_field: Optional[bool] = True,
        encoding: Optional[str] = "utf8",
    ) -> "Query":
        """
        Add a field to the list of fields to return.

        - **field**: The field to include in query results
        - **as_field**: The alias for the field
        - **decode_field**: Whether to decode the field from bytes to string
        - **encoding**: The encoding to use when decoding the field
        """
        self._return_fields.append(field)
        self._return_fields_decode_as[field] = encoding if decode_field else None
        if as_field is not None:
            self._return_fields.extend(("AS", as_field))
        return self

    def _mk_field_list(self, fields: Optional[Union[str, List[str]]]) -> List[str]:
        """Normalize `fields` to a list: empty input gives `[]`, a single
        string is wrapped in a list, any other iterable is copied."""
        if not fields:
            return []
        if isinstance(fields, str):
            return [fields]
        return list(fields)

    def summarize(
        self,
        fields: Optional[List] = None,
        context_len: Optional[int] = None,
        num_frags: Optional[int] = None,
        sep: Optional[str] = None,
    ) -> "Query":
        """
        Return an abridged format of the field, containing only the segments of
        the field which contain the matching term(s).

        If `fields` is specified, then only the mentioned fields are
        summarized; otherwise all results are summarized.

        Server side defaults are used for each option (except `fields`)
        if not specified

        - **fields** List of fields to summarize. All fields are summarized
          if not specified
        - **context_len** Amount of context to include with each fragment
        - **num_frags** Number of fragments per document
        - **sep** Separator string to separate fragments
        """
        args = ["SUMMARIZE"]
        field_list = self._mk_field_list(fields)
        if field_list:
            args += ["FIELDS", str(len(field_list))] + field_list
        if context_len is not None:
            args += ["LEN", str(context_len)]
        if num_frags is not None:
            args += ["FRAGS", str(num_frags)]
        if sep is not None:
            args += ["SEPARATOR", sep]

        # Each call replaces any previously configured SUMMARIZE clause.
        self._summarize_fields = args
        return self

    def highlight(
        self, fields: Optional[List[str]] = None, tags: Optional[List[str]] = None
    ) -> "Query":
        """
        Apply specified markup to matched term(s) within the returned field(s).

        - **fields** If specified then only those mentioned fields are
          highlighted, otherwise all fields are highlighted
        - **tags** A list of two strings to surround the match.
        """
        args = ["HIGHLIGHT"]
        field_list = self._mk_field_list(fields)
        if field_list:
            args += ["FIELDS", str(len(field_list))] + field_list
        if tags:
            args += ["TAGS"] + list(tags)

        # Each call replaces any previously configured HIGHLIGHT clause.
        self._highlight_fields = args
        return self

    def language(self, language: str) -> "Query":
        """
        Analyze the query as being in the specified language.

        - **language**: The language (e.g. `chinese` or `english`)
        """
        self._language = language
        return self

    def slop(self, slop: int) -> "Query":
        """Allow a maximum of N intervening non matched terms between
        phrase terms (0 means exact phrase).
        """
        self._slop = slop
        return self

    def timeout(self, timeout: float) -> "Query":
        """Override the timeout parameter of the module."""
        self._timeout = timeout
        return self

    def in_order(self) -> "Query":
        """
        Match only documents where the query terms appear in
        the same order in the document.
        i.e. for the query "hello world", we do not match "world hello"
        """
        self._in_order = True
        return self

    def scorer(self, scorer: str) -> "Query":
        """
        Use a different scoring function to evaluate document relevance.
        Default is `TFIDF`.

        - **scorer**: The scoring function to use
          (e.g. `TFIDF.DOCNORM` or `BM25`)
        """
        self._scorer = scorer
        return self

    def get_args(self) -> List[Union[str, int, float]]:
        """Format the redis arguments for this query and return them.

        The result is the query string, then the option arguments, then the
        SUMMARIZE and HIGHLIGHT clauses, and finally `LIMIT <offset> <num>`.
        """
        args: List[Union[str, int, float]] = [self._query_string]
        args += self._get_args_tags()
        args += self._summarize_fields + self._highlight_fields
        args += ["LIMIT", self._offset, self._num]
        return args

    def _get_args_tags(self) -> List[Union[str, int, float]]:
        """Build the option arguments (everything between the query string
        and the SUMMARIZE/HIGHLIGHT/LIMIT clauses).

        Raises `AttributeError` when a registered filter is not a `Filter`
        or the sort specification is not a `SortbyField`.
        """
        args: List[Union[str, int, float]] = []
        if self._no_content:
            args.append("NOCONTENT")
        if self._fields:
            args += ["INFIELDS", len(self._fields)]
            args += self._fields
        if self._verbatim:
            args.append("VERBATIM")
        if self._no_stopwords:
            args.append("NOSTOPWORDS")
        for flt in self._filters:
            if not isinstance(flt, Filter):
                raise AttributeError("Did not receive a Filter object.")
            args += flt.args
        if self._with_payloads:
            args.append("WITHPAYLOADS")
        if self._scorer:
            args += ["SCORER", self._scorer]
        if self._with_scores:
            args.append("WITHSCORES")
        if self._ids:
            args += ["INKEYS", len(self._ids)]
            args += self._ids
        if self._slop >= 0:
            args += ["SLOP", self._slop]
        if self._timeout is not None:
            args += ["TIMEOUT", self._timeout]
        if self._in_order:
            args.append("INORDER")
        if self._return_fields:
            # The count covers every token, including "AS" and aliases.
            args += ["RETURN", len(self._return_fields)]
            args += self._return_fields
        if self._sortby:
            if not isinstance(self._sortby, SortbyField):
                raise AttributeError("Did not receive a SortByField.")
            args.append("SORTBY")
            args += self._sortby.args
        if self._language:
            args += ["LANGUAGE", self._language]
        if self._expander:
            args += ["EXPANDER", self._expander]
        if self._dialect:
            args += ["DIALECT", self._dialect]
        return args

    def paging(self, offset: int, num: int) -> "Query":
        """
        Set the paging for the query (defaults to 0..10).

        - **offset**: Paging offset for the results. Defaults to 0
        - **num**: How many results do we want
        """
        self._offset = offset
        self._num = num
        return self

    def verbatim(self) -> "Query":
        """Set the query to be verbatim, i.e. use no query expansion
        or stemming.
        """
        self._verbatim = True
        return self

    def no_content(self) -> "Query":
        """Set the query to only return ids and not the document content."""
        self._no_content = True
        return self

    def no_stopwords(self) -> "Query":
        """
        Prevent the query from being filtered for stopwords.
        Only useful in very big queries that you are certain contain
        no stopwords.
        """
        self._no_stopwords = True
        return self

    def with_payloads(self) -> "Query":
        """Ask the engine to return document payloads."""
        self._with_payloads = True
        return self

    def with_scores(self) -> "Query":
        """Ask the engine to return document search scores."""
        self._with_scores = True
        return self

    def limit_fields(self, *fields: str) -> "Query":
        """
        Limit the search to specific TEXT fields only.

        - **fields**: One or more strings, case sensitive field names
          from the defined schema.
        """
        # Stored as a list so the attribute matches its declared type.
        self._fields = list(fields)
        return self

    def add_filter(self, flt: "Filter") -> "Query":
        """
        Add a numeric or geo filter to the query.
        **Currently only one of each filter is supported by the engine**

        - **flt**: A NumericFilter or GeoFilter object, used on a
          corresponding field
        """
        self._filters.append(flt)
        return self

    def sort_by(self, field: str, asc: bool = True) -> "Query":
        """
        Add a sortby field to the query.

        - **field** - the name of the field to sort by
        - **asc** - when `True`, sorting will be done in ascending order
        """
        self._sortby = SortbyField(field, asc)
        return self

    def expander(self, expander: str) -> "Query":
        """
        Add an expander field to the query.

        - **expander** - the name of the expander
        """
        self._expander = expander
        return self

    def dialect(self, dialect: int) -> "Query":
        """
        Add a dialect field to the query.

        - **dialect** - dialect version to execute the query under
        """
        self._dialect = dialect
        return self
class Filter:
    """Base class for query filters.

    Holds the flat argument list `[keyword, field, *args]` in `self.args`.
    """

    def __init__(self, keyword: str, field: str, *args: Union[int, float, str]) -> None:
        """
        - **keyword**: The first token of the filter's argument list
        - **field**: The name of the field the filter applies to
        - **args**: Remaining filter arguments, appended after `field`
          in the order given
        """
        self.args = [keyword, field, *args]
class NumericFilter(Filter):
    """Numeric range filter; its arguments are `FILTER <field> <min> <max>`.

    `INF` and `NEG_INF` are string constants usable as open-ended bounds.
    """

    INF = "+inf"
    NEG_INF = "-inf"

    def __init__(
        self,
        field: str,
        minval: Union[int, str],
        maxval: Union[int, str],
        minExclusive: bool = False,
        maxExclusive: bool = False,
    ) -> None:
        """
        - **field**: The name of the field to filter on
        - **minval**: Lower bound of the range
        - **maxval**: Upper bound of the range
        - **minExclusive**: When `True`, the lower bound is rendered as a
          string prefixed with `(`
        - **maxExclusive**: When `True`, the upper bound is rendered as a
          string prefixed with `(`
        """
        # Inclusive bounds are passed through untouched; exclusive ones are
        # turned into "(<value>" strings.
        lower = f"({minval}" if minExclusive else minval
        upper = f"({maxval}" if maxExclusive else maxval
        super().__init__("FILTER", field, lower, upper)
class GeoFilter(Filter):
    """Geo radius filter; its arguments are
    `GEOFILTER <field> <lon> <lat> <radius> <unit>`.

    The class constants are the unit strings accepted for `unit`.
    """

    METERS = "m"
    KILOMETERS = "km"
    FEET = "ft"
    MILES = "mi"

    def __init__(
        self, field: str, lon: float, lat: float, radius: float, unit: str = KILOMETERS
    ) -> None:
        """
        - **field**: The name of the field to filter on
        - **lon**: Longitude of the search centre
        - **lat**: Latitude of the search centre
        - **radius**: Search radius, expressed in `unit`
        - **unit**: Unit string for `radius`; defaults to `KILOMETERS`
        """
        geo_args = (lon, lat, radius, unit)
        super().__init__("GEOFILTER", field, *geo_args)
class SortbyField:
    """Sort specification holding `[field, direction]` in `self.args`."""

    def __init__(self, field: str, asc=True) -> None:
        """
        - **field**: The name of the field to sort by
        - **asc**: Truthy for `"ASC"`, falsy for `"DESC"`
        """
        if asc:
            direction = "ASC"
        else:
            direction = "DESC"
        self.args = [field, direction]
|