Files
timmy-config/workspace/evennia-venv/lib/python3.12/site-packages/lunr/pipeline.py
2026-03-31 20:02:01 +00:00

162 lines
5.3 KiB
Python

from collections import defaultdict
import logging
from typing import Callable, Dict, List, Set
from lunr.exceptions import BaseLunrException
from lunr.token import Token
log = logging.getLogger(__name__)


class Pipeline:
    """Ordered stack of functions applied to every token.

    lunr.Pipelines maintain a list of functions to be applied to all tokens
    in documents entering the search index and queries run against the index.

    Functions are identified by a ``label`` attribute, assigned by
    :meth:`register_function`. Labels are what :meth:`serialize` emits and
    what :meth:`load` looks up in ``registered_functions``.
    """

    # Class-level registry shared by all pipelines: label -> function.
    registered_functions: Dict[str, Callable] = {}

    def __init__(self):
        # Functions in execution order.
        self._stack: List[Callable] = []
        # Function -> field names for which that function is skipped.
        self._skip: Dict[Callable, Set[str]] = defaultdict(set)

    def __len__(self):
        """Return the number of functions currently in the pipeline."""
        return len(self._stack)

    def __repr__(self):
        return '<Pipeline stack="{}">'.format(
            ",".join(self._display_name(fn) for fn in self._stack)
        )

    @staticmethod
    def _display_name(fn):
        """Return a printable name for ``fn`` without assuming it is labelled.

        ``add`` accepts unlabelled functions (it only warns), so ``__repr__``
        must not blow up on them.
        """
        try:
            return fn.label
        except AttributeError:
            return getattr(fn, "__name__", repr(fn))

    # TODO: add iterator methods?

    @classmethod
    def register_function(cls, fn, label=None):
        """Register a function with the pipeline.

        :param fn: The function to register.
        :param label: Name to register it under, defaults to ``fn.__name__``.

        The label is stored on the function as ``fn.label``. Registering a
        label that already exists logs a warning and overwrites the entry.
        """
        label = label or fn.__name__
        if label in cls.registered_functions:
            log.warning("Overwriting existing registered function %s", label)

        fn.label = label
        cls.registered_functions[fn.label] = fn

    @classmethod
    def load(cls, serialised):
        """Loads a previously serialised pipeline.

        :param serialised: Iterable of function labels, in pipeline order.
        :returns: A new pipeline holding the registered functions.
        :raises BaseLunrException: If a label is not in the registry.
        """
        pipeline = cls()
        for fn_name in serialised:
            try:
                fn = cls.registered_functions[fn_name]
            except KeyError as err:
                raise BaseLunrException(
                    "Cannot load unregistered function {}".format(fn_name)
                ) from err
            pipeline.add(fn)

        return pipeline

    def add(self, *args):
        """Adds new functions to the end of the pipeline.

        Functions must accept three arguments:
        - Token: A lunr.Token object which will be updated
        - i: The index of the token in the set
        - tokens: A list of tokens representing the set
        """
        for fn in args:
            self.warn_if_function_not_registered(fn)
            self._stack.append(fn)

    def warn_if_function_not_registered(self, fn):
        """Log a warning unless ``fn`` is labelled and its label is registered.

        :returns: ``True`` if the function's label is in the registry,
            ``False`` otherwise.
        """
        try:
            is_registered = fn.label in self.registered_functions
        except AttributeError:
            is_registered = False

        if not is_registered:
            # Covers both "no label at all" and "label unknown to the
            # registry": either way load() cannot restore the function.
            log.warning(
                'Function "%s" is not registered with pipeline. '
                "This may cause problems when serialising the index.",
                getattr(fn, "label", fn),
            )
        return is_registered

    def _index_of(self, existing_fn):
        """Return the position of ``existing_fn`` in the stack.

        :raises BaseLunrException: If ``existing_fn`` is not in the pipeline.
        """
        try:
            return self._stack.index(existing_fn)
        except ValueError as e:
            raise BaseLunrException("Cannot find existing_fn") from e

    def after(self, existing_fn, new_fn):
        """Adds a single function after a function that already exists in the
        pipeline.

        :raises BaseLunrException: If ``existing_fn`` is not in the pipeline.
        """
        self.warn_if_function_not_registered(new_fn)
        self._stack.insert(self._index_of(existing_fn) + 1, new_fn)

    def before(self, existing_fn, new_fn):
        """Adds a single function before a function that already exists in the
        pipeline.

        :raises BaseLunrException: If ``existing_fn`` is not in the pipeline.
        """
        self.warn_if_function_not_registered(new_fn)
        self._stack.insert(self._index_of(existing_fn), new_fn)

    def remove(self, fn):
        """Removes a function from the pipeline.

        Removing a function that is not in the pipeline is a no-op.
        """
        try:
            self._stack.remove(fn)
        except ValueError:
            # Deliberately best-effort: nothing to remove.
            pass

    def skip(self, fn: Callable, field_names: List[str]):
        """
        Make the pipeline skip the function based on field name we're processing.

        This relies on passing the field name to Pipeline.run().

        :param fn: The function to skip.
        :param field_names: Field names for which ``fn`` is skipped. A single
            string is treated as one field name.
        """
        if isinstance(field_names, str):
            # set.update() would otherwise add each character separately.
            field_names = [field_names]
        self._skip[fn].update(field_names)

    def run(self, tokens, field_name=None):
        """
        Runs the current list of functions that make up the pipeline against
        the passed tokens.

        :param tokens: The tokens to process.
        :param field_name: The name of the field these tokens belong to, can
            be omitted. Used to skip some functions based on field names.
        :returns: The list of tokens produced by the last function, or the
            input ``tokens`` object if no function ran.
        """
        for fn in self._stack:
            # Skip the function based on field name. ``get`` is used so that
            # reading the defaultdict does not insert an empty entry.
            if field_name and field_name in self._skip.get(fn, ()):
                continue

            results = []
            for i, token in enumerate(tokens):
                # JS ignores additional arguments to the functions but we
                # force pipeline functions to declare (token, i, tokens)
                # or *args
                result = fn(token, i, tokens)
                if not result:
                    # A falsy result drops the token from the output.
                    continue
                if isinstance(result, (list, tuple)):  # simulate Array.concat
                    results.extend(result)
                else:
                    results.append(result)
            tokens = results

        return tokens

    def run_string(self, string, metadata=None):
        """Convenience method for passing a string through a pipeline and
        getting strings out. This method takes care of wrapping the passed
        string in a token and mapping the resulting tokens back to strings.

        .. note:: This ignores the skipped functions since we can't
            access field names from this context.
        """
        token = Token(string, metadata)
        return [str(tkn) for tkn in self.run([token])]

    def reset(self):
        """Remove every function from the pipeline."""
        self._stack = []

    def serialize(self):
        """Return the labels of the pipeline's functions, in order.

        Every function must carry a ``label`` attribute; an unlabelled
        function raises ``AttributeError``.
        """
        return [fn.label for fn in self._stack]