| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383 |
- from typing import List, Union
# Sentinel accepted by ``Reducer.alias()``: instead of an explicit alias, name
# the reducer's output after the field the reducer operates on. Compared by
# identity only, so a bare ``object()`` is sufficient.
FIELDNAME: object = object()
class Limit:
    """
    Offset/count pair rendered as a ``LIMIT <offset> <count>`` clause.

    A falsy ``count`` (e.g. the default 0) means "no limit", in which case no
    arguments are produced at all.
    """

    def __init__(self, offset: int = 0, count: int = 0) -> None:
        self.offset = offset
        self.count = count

    def build_args(self):
        """Return ``["LIMIT", offset, count]`` as strings, or ``[]`` if count is falsy."""
        if not self.count:
            return []
        return ["LIMIT", str(self.offset), str(self.count)]
class Reducer:
    """
    Base reducer object for all reducers.

    See the `redisearch.reducers` module for the actual reducers.
    """

    # Reducer function name emitted after ``REDUCE``; left as None here and
    # expected to be set by concrete subclasses.
    NAME = None

    def __init__(self, *args: str) -> None:
        # Arguments emitted after ``REDUCE <NAME> <nargs>``.
        self._args = args
        # Field this reducer operates on. Nothing in this class assigns it;
        # presumably single-field subclasses set it (needed for FIELDNAME).
        self._field = None
        # Output name emitted as ``AS <alias>``; None means no alias.
        self._alias = None

    def alias(self, alias: str) -> "Reducer":
        """
        Set the alias for this reducer.

        ### Parameters

        - **alias**: The value of the alias for this reducer. If this is the
            special value `aggregation.FIELDNAME` then this reducer will be
            aliased using the same name as the field upon which it operates
            (with a leading `@`, if any, removed).
            Note that using `FIELDNAME` is only possible on reducers which
            operate on a single field value.

        This method returns the `Reducer` object making it suitable for
        chaining.

        ### Raises

        - `ValueError` if `FIELDNAME` is used on a reducer with no field.
        """
        if alias is FIELDNAME:
            if not self._field:
                raise ValueError("Cannot use FIELDNAME alias with no field")
            field = self._field
            # Strip only a leading '@' property marker; a field given without
            # it is used unchanged instead of losing its first character.
            alias = field[1:] if field.startswith("@") else field
        self._alias = alias
        return self

    @property
    def args(self) -> List[str]:
        """Positional arguments of this reducer (stored as a tuple)."""
        return self._args
class SortDirection:
    """
    Marker wrapping a field name with an explicit sort direction.

    Subclasses set ``DIRSTRING`` to the direction keyword that `sort_by()`
    emits right after the wrapped field.
    """

    DIRSTRING = None

    def __init__(self, field: str) -> None:
        # Field to sort on, e.g. "@price".
        self.field = field
class Asc(SortDirection):
    """
    Wrap a field so that it is sorted in ascending order (emits ``ASC``).
    """

    DIRSTRING = "ASC"
class Desc(SortDirection):
    """
    Wrap a field so that it is sorted in descending order (emits ``DESC``).
    """

    DIRSTRING = "DESC"
class AggregateRequest:
    """
    Aggregation request which can be passed to `Client.aggregate`.
    """

    def __init__(self, query: str = "*") -> None:
        """
        Create an aggregation request. This request may then be passed to
        `client.aggregate()`.

        In order for the request to be usable, it must contain at least one
        group.

        - **query** Query string for filtering records.

        All public builder methods (everything except `build_args()`) return
        the object itself, making them useful for chaining.
        """
        self._query = query
        # Pipeline steps (GROUPBY / APPLY / LIMIT / SORTBY / FILTER) as one flat
        # argument list, in the order the builder methods were called.
        self._aggregateplan = []
        self._loadfields = []
        self._loadall = False
        # NOTE(review): not read anywhere in this class; presumably kept for
        # backward compatibility with external readers.
        self._max = 0
        self._with_schema = False
        self._verbatim = False
        self._cursor = []
        self._dialect = None
        self._add_scores = False

    @staticmethod
    def _flatten(items) -> list:
        """Expand list/tuple entries of ``items`` one level; keep other entries."""
        flat = []
        for item in items:
            if isinstance(item, (list, tuple)):
                flat.extend(item)
            else:
                flat.append(item)
        return flat

    def load(self, *fields: str) -> "AggregateRequest":
        """
        Indicate the fields to be returned in the response. These fields are
        returned in addition to any others implicitly specified.

        ### Parameters

        - **fields**: If fields not specified, all the fields will be loaded.
            Otherwise, fields should be given in the format of `@field`.
        """
        if fields:
            self._loadfields.extend(fields)
        else:
            self._loadall = True
        return self

    def group_by(
        self,
        fields: Union[str, List[str]],
        *reducers: Union["Reducer", List["Reducer"]],
    ) -> "AggregateRequest":
        """
        Specify by which fields to group the aggregation.

        ### Parameters

        - **fields**: Fields to group by. This can either be a single string,
            or a list of strings. In both cases, the field should be specified
            as `@field`.
        - **reducers**: One or more reducers, passed either as separate
            arguments or as a list. Reducers may be found in the `aggregation`
            module.
        """
        fields = [fields] if isinstance(fields, str) else fields
        ret = ["GROUPBY", str(len(fields)), *fields]
        # ``reducers`` is always a tuple (*args); flatten so a list of reducers
        # passed as one argument works, as the annotation advertises.
        for reducer in self._flatten(reducers):
            ret += ["REDUCE", reducer.NAME, str(len(reducer.args))]
            ret.extend(reducer.args)
            if reducer._alias is not None:
                ret += ["AS", reducer._alias]
        self._aggregateplan.extend(ret)
        return self

    def apply(self, **kwexpr) -> "AggregateRequest":
        """
        Specify one or more projection expressions to add to each result.

        ### Parameters

        - **kwexpr**: One or more key-value pairs for a projection. The key is
            the alias for the projection, and the value is the projection
            expression itself, for example `apply(square_root="sqrt(@foo)")`
        """
        for alias, expr in kwexpr.items():
            ret = ["APPLY", expr]
            if alias is not None:
                ret += ["AS", alias]
            self._aggregateplan.extend(ret)
        return self

    def limit(self, offset: int, num: int) -> "AggregateRequest":
        """
        Sets the limit for the most recent group or query.

        If no group has been defined yet (via `group_by()`) then this sets
        the limit for the initial pool of results from the query. Otherwise,
        this limits the number of items operated on from the previous group.

        Setting a limit on the initial search results may be useful when
        attempting to execute an aggregation on a sample of a large data set.

        ### Parameters

        - **offset**: Result offset from which to begin paging
        - **num**: Number of results to return; 0 adds no LIMIT clause

        Example of limiting the initial results:

        ```
        (
            AggregateRequest("@sale_amount:[10000, inf]")
            .limit(0, 10)
            .group_by("@state", r.count())
        )
        ```

        Will only group by the states found in the first 10 results of the
        query `@sale_amount:[10000, inf]`. On the other hand,

        ```
        (
            AggregateRequest("@sale_amount:[10000, inf]")
            .limit(0, 1000)
            .group_by("@state", r.count())
            .limit(0, 10)
        )
        ```

        Will group all the results matching the query, but only return the
        first 10 groups.

        If you only wish to return a *top-N* style query, consider using
        `sort_by()` instead.
        """
        _limit = Limit(offset, num)
        self._aggregateplan.extend(_limit.build_args())
        return self

    def sort_by(
        self,
        *fields: Union[str, "SortDirection", List[Union[str, "SortDirection"]]],
        **kwargs,
    ) -> "AggregateRequest":
        """
        Indicate how the results should be sorted. This can also be used for
        *top-N* style queries.

        ### Parameters

        - **fields**: The fields by which to sort. These can be passed as
            separate arguments or as a list. If you wish to specify order,
            you can use the `Asc` or `Desc` wrapper classes.
        - **max**: Maximum number of results to return. This can be
            used instead of `LIMIT` and is also faster.

        Example of sorting by `foo` ascending and `bar` descending:

        ```
        sort_by(Asc("@foo"), Desc("@bar"))
        ```

        Return the top 10 customers:

        ```
        (
            AggregateRequest()
            .group_by("@customer", r.sum("@paid").alias(FIELDNAME))
            .sort_by(Desc("@paid"), max=10)
        )
        ```
        """
        fields_args = []
        # ``fields`` is always a tuple (*args); flatten so a list of fields
        # passed as one argument works, as the docstring advertises.
        for f in self._flatten(fields):
            if isinstance(f, SortDirection):
                fields_args += [f.field, f.DIRSTRING]
            else:
                fields_args.append(f)
        ret = ["SORTBY", str(len(fields_args)), *fields_args]
        max_results = kwargs.get("max", 0)
        if max_results > 0:
            ret += ["MAX", str(max_results)]
        self._aggregateplan.extend(ret)
        return self

    def filter(self, expressions: Union[str, List[str]]) -> "AggregateRequest":
        """
        Specify filter for post-query results using predicates relating to
        values in the result set.

        ### Parameters

        - **expressions**: Filter expression(s). This can either be a single
            string, or a list of strings; each one becomes its own `FILTER`
            step.
        """
        if isinstance(expressions, str):
            expressions = [expressions]
        for expression in expressions:
            self._aggregateplan.extend(["FILTER", expression])
        return self

    def with_schema(self) -> "AggregateRequest":
        """
        If set, the `schema` property will contain a list of `[field, type]`
        entries in the result object.
        """
        self._with_schema = True
        return self

    def add_scores(self) -> "AggregateRequest":
        """
        If set, includes the score as an ordinary field of the row.
        """
        self._add_scores = True
        return self

    def verbatim(self) -> "AggregateRequest":
        """
        Add the `VERBATIM` flag to the request.
        """
        self._verbatim = True
        return self

    def cursor(self, count: int = 0, max_idle: float = 0.0) -> "AggregateRequest":
        """
        Request that results be returned through a cursor (`WITHCURSOR`).

        ### Parameters

        - **count**: Value sent as `COUNT`; omitted when 0.
        - **max_idle**: Idle timeout, presumably in seconds (it is multiplied
            by 1000 and sent as `MAXIDLE` in whole milliseconds); omitted
            when 0.
        """
        args = ["WITHCURSOR"]
        if count:
            args += ["COUNT", str(count)]
        if max_idle:
            # MAXIDLE takes an integer; round so that e.g. 1.5 yields "1500"
            # rather than "1500.0" (integer inputs are unaffected).
            args += ["MAXIDLE", str(round(max_idle * 1000))]
        self._cursor = args
        return self

    def build_args(self) -> List[str]:
        """
        Build the argument list: the query string, then the request options
        (schema, verbatim, scores, cursor, load, dialect), then the pipeline
        steps in the order they were added.
        """
        # @foo:bar ...
        ret = [self._query]

        if self._with_schema:
            ret.append("WITHSCHEMA")

        if self._verbatim:
            ret.append("VERBATIM")

        if self._add_scores:
            ret.append("ADDSCORES")

        if self._cursor:
            ret += self._cursor

        if self._loadall:
            ret.append("LOAD")
            ret.append("*")
        elif self._loadfields:
            ret.append("LOAD")
            ret.append(str(len(self._loadfields)))
            ret.extend(self._loadfields)

        if self._dialect:
            # The dialect value is appended as given (not converted to str).
            ret.extend(["DIALECT", self._dialect])

        ret.extend(self._aggregateplan)
        return ret

    def dialect(self, dialect: int) -> "AggregateRequest":
        """
        Add a dialect field to the aggregate command.

        - **dialect** - dialect version to execute the query under; a falsy
            value adds nothing.
        """
        self._dialect = dialect
        return self
class Cursor:
    """
    Cursor id plus optional ``MAXIDLE`` / ``COUNT`` settings, rendered as a
    string argument list by `build_args()`.
    """

    def __init__(self, cid: int) -> None:
        self.cid = cid
        # Both settings are omitted from the arguments while falsy.
        self.max_idle = 0
        self.count = 0

    def build_args(self):
        """Return ``[cid, MAXIDLE?, COUNT?]`` as strings, skipping unset options."""
        options = []
        if self.max_idle:
            options.extend(("MAXIDLE", str(self.max_idle)))
        if self.count:
            options.extend(("COUNT", str(self.count)))
        return [str(self.cid), *options]
class AggregateResult:
    """
    Container for the rows of an aggregation, the cursor (if any) and the
    schema object passed in by the caller.
    """

    def __init__(self, rows, cursor: Union["Cursor", None], schema) -> None:
        self.rows = rows
        # None when no cursor is associated; `__repr__` handles that case.
        self.cursor = cursor
        # NOTE(review): presumably populated only when the request used
        # with_schema(); confirm against the response parser.
        self.schema = schema

    def __repr__(self) -> str:
        # A missing cursor is reported as id -1.
        cid = self.cursor.cid if self.cursor else -1
        return (
            f"<{self.__class__.__name__} at 0x{id(self):x} "
            f"Rows={len(self.rows)}, Cursor={cid}>"
        )
|