from collections import defaultdict
import logging
from typing import Callable, Dict, List, Set

from lunr.exceptions import BaseLunrException
from lunr.token import Token

log = logging.getLogger(__name__)


class Pipeline:
    """lunr.Pipelines maintain a list of functions to be applied to all tokens
    in documents entering the search index and queries run against the index.

    Functions are held in an ordered stack and applied in that order by
    :meth:`run`. Functions registered through :meth:`register_function` are
    stored by label in the class-level registry so that a serialised pipeline
    (a list of labels) can be rebuilt with :meth:`load`.
    """

    # Label -> function registry. Defined on the class, so it is shared by
    # every Pipeline instance.
    registered_functions: Dict[str, Callable] = {}

    def __init__(self):
        # Ordered list of the functions applied by `run`.
        self._stack: List[Callable] = []
        # Function -> names of the fields for which that function is skipped.
        self._skip: Dict[Callable, Set[str]] = defaultdict(set)

    def __len__(self):
        """Return the number of functions currently in the pipeline."""
        return len(self._stack)

    def __repr__(self):
        """Return a debug representation listing the stacked functions."""
        # Functions that were never registered have no `label`; fall back to
        # their name (or the object itself) so that repr() never raises.
        labels = ",".join(
            str(getattr(fn, "label", getattr(fn, "__name__", fn)))
            for fn in self._stack
        )
        return '<Pipeline stack="{}">'.format(labels)

    # TODO: add iterator methods?

    @classmethod
    def register_function(cls, fn, label=None):
        """Register a function with the pipeline.

        The function is stored in ``registered_functions`` under ``label``
        (defaulting to ``fn.__name__``) and the label is also set on the
        function as ``fn.label``. Registering a label that is already taken
        logs a warning and overwrites the previous entry.
        """
        label = label or fn.__name__
        if label in cls.registered_functions:
            log.warning("Overwriting existing registered function %s", label)

        fn.label = label
        cls.registered_functions[fn.label] = fn

    @classmethod
    def load(cls, serialised):
        """Loads a previously serialised pipeline.

        :param serialised: An iterable of function labels, as produced by
            :meth:`serialize`.
        :raises BaseLunrException: If a label is not in the registry.
        """
        pipeline = cls()
        for fn_name in serialised:
            try:
                fn = cls.registered_functions[fn_name]
            except KeyError as e:
                raise BaseLunrException(
                    "Cannot load unregistered function {}".format(fn_name)
                ) from e
            pipeline.add(fn)

        return pipeline

    def add(self, *args):
        """Adds new functions to the end of the pipeline.

        Functions must accept three arguments:
        - Token: A lunr.Token object which will be updated
        - i: The index of the token in the set
        - tokens: A list of tokens representing the set
        """
        for fn in args:
            self.warn_if_function_not_registered(fn)
            self._stack.append(fn)

    def warn_if_function_not_registered(self, fn):
        """Log a warning when ``fn`` is not in the function registry.

        A function counts as registered when it has a ``label`` attribute and
        that label is a key of ``registered_functions``. Both a missing label
        and an unknown label trigger the warning.

        :return: ``True`` if the function is registered, ``False`` otherwise.
        """
        try:
            registered = fn.label in self.registered_functions
        except AttributeError:
            # No `label` attribute: the function was never registered.
            registered = False

        if not registered:
            log.warning(
                'Function "{}" is not registered with pipeline. '
                "This may cause problems when serialising the index.".format(
                    getattr(fn, "label", fn)
                )
            )
        return registered

    def _index_of(self, existing_fn):
        """Return the stack position of ``existing_fn`` or raise."""
        try:
            return self._stack.index(existing_fn)
        except ValueError as e:
            raise BaseLunrException("Cannot find existing_fn") from e

    def after(self, existing_fn, new_fn):
        """Adds a single function after a function that already exists in the
        pipeline.

        :raises BaseLunrException: If ``existing_fn`` is not in the pipeline.
        """
        self.warn_if_function_not_registered(new_fn)
        index = self._index_of(existing_fn)
        self._stack.insert(index + 1, new_fn)

    def before(self, existing_fn, new_fn):
        """Adds a single function before a function that already exists in the
        pipeline.

        :raises BaseLunrException: If ``existing_fn`` is not in the pipeline.
        """
        self.warn_if_function_not_registered(new_fn)
        index = self._index_of(existing_fn)
        self._stack.insert(index, new_fn)

    def remove(self, fn):
        """Removes a function from the pipeline.

        Removing a function that is not in the pipeline is a no-op.
        """
        try:
            self._stack.remove(fn)
        except ValueError:
            # Deliberately ignored: removal of an absent function is allowed.
            pass

    def skip(self, fn: Callable, field_names: List[str]):
        """
        Make the pipeline skip the function based on field name we're
        processing.

        This relies on passing the field name to Pipeline.run().

        :param fn: The function to skip.
        :param field_names: Names of the fields for which ``fn`` is skipped.
            A single string is treated as one field name.
        """
        if isinstance(field_names, str):
            # set.update() on a bare string would add each character as a
            # separate field name.
            field_names = [field_names]
        self._skip[fn].update(field_names)

    def run(self, tokens, field_name=None):
        """
        Runs the current list of functions that make up the pipeline against
        the passed tokens.

        :param tokens: The tokens to process.
        :param field_name: The name of the field these tokens belongs to, can
            be omitted. Used to skip some functions based on field names.
        :return: The tokens left after every function has been applied. If the
            pipeline is empty, ``tokens`` is returned unchanged.
        """
        for fn in self._stack:
            # Skip the function based on field name. `.get` is used instead of
            # indexing so that reading the defaultdict does not insert an
            # empty entry for every function.
            if field_name and field_name in self._skip.get(fn, ()):
                continue

            results = []
            for i, token in enumerate(tokens):
                # JS ignores additional arguments to the functions but we
                # force pipeline functions to declare (token, i, tokens)
                # or *args
                result = fn(token, i, tokens)
                if not result:
                    # A falsy result drops the token from the output.
                    continue
                if isinstance(result, (list, tuple)):  # simulate Array.concat
                    results.extend(result)
                else:
                    results.append(result)
            # The output of this function is the input of the next one.
            tokens = results

        return tokens

    def run_string(self, string, metadata=None):
        """Convenience method for passing a string through a pipeline and
        getting strings out. This method takes care of wrapping the passed
        string in a token and mapping the resulting tokens back to strings.

        .. note:: This ignores the skipped functions since we can't
            access field names from this context.
        """
        token = Token(string, metadata)
        return [str(tkn) for tkn in self.run([token])]

    def reset(self):
        """Remove every function from the pipeline."""
        self._stack = []

    def serialize(self):
        """Return the labels of the stacked functions, in order.

        Every function in the pipeline must have a ``label`` attribute;
        a function without one raises ``AttributeError`` here.
        """
        return [fn.label for fn in self._stack]