import os
import re
import copy
import logging
import importlib
from scrubadub.detectors import url
from typing import Generator, Iterable, Optional, Sequence, List, Callable, cast, Type
try:
import spacy
import spacy.tokens
import spacy.cli
except ImportError as e:
if e.name == "spacy":
raise ImportError(
"Could not find module 'spacy'. If you want to use extras,"
" make sure you install scrubadub with 'pip install scrubadub[spacy]'"
)
from scrubadub.detectors.catalogue import register_detector
from scrubadub.detectors.base import Detector, RegexDetector
from scrubadub.filth import Filth, NameFilth, OrganizationFilth, LocationFilth, DateOfBirthFilth
from scrubadub.utils import CanonicalStringSet
@register_detector
class SpacyEntityDetector(Detector):
    """Use spaCy's named entity recognition to identify possible ``Filth``.

    This detector is made to work with v3 of spaCy, since the NER model has been significantly improved in this
    version.

    This is particularly useful to remove names from text, but can also be used to remove any entity that is
    recognised by spaCy. A full list of entities that spacy supports can be found here:
    `<https://spacy.io/api/annotation#named-entities>`_.

    Additional entities can be added like so:

    >>> import scrubadub, scrubadub_spacy
    >>> class MoneyFilth(scrubadub.filth.Filth):
    ...     type = 'money'
    >>> scrubadub_spacy.detectors.spacy.SpacyEntityDetector.filth_cls_map['MONEY'] = MoneyFilth
    >>> detector = scrubadub_spacy.detectors.spacy.SpacyEntityDetector(named_entities=['MONEY'])
    >>> scrubber = scrubadub.Scrubber(detector_list=[detector])
    >>> scrubber.clean("You owe me 12 dollars man!")
    'You owe me {{MONEY}} man!'

    The dictionary ``scrubadub_spacy.detectors.spacy.SpacyEntityDetector.filth_cls_map`` is used to map between the
    spaCy named entity label and the type of scrubadub ``Filth``, while the ``named_entities`` argument sets which
    named entities are considered ``Filth`` by the ``SpacyEntityDetector``.
    """

    # Maps a spaCy entity label to the ``Filth`` class that is yielded for it. Labels that are missing from this
    # mapping are skipped, even when they are listed in ``named_entities``.
    filth_cls_map = {
        'FAC': LocationFilth,      # Buildings, airports, highways, bridges, etc.
        'GPE': LocationFilth,      # Countries, cities, states.
        'LOC': LocationFilth,      # Non-GPE locations, mountain ranges, bodies of water.
        'PERSON': NameFilth,       # People, including fictional.
        'PER': NameFilth,          # Bug in french model
        'ORG': OrganizationFilth,  # Companies, agencies, institutions, etc.
        'DATE': DateOfBirthFilth,  # Dates within the period 18 to 100 years ago.
    }
    name = 'spacy'
    # Default model for each language code; languages that are not listed fall back to "<language>_core_news_lg".
    language_to_model = {
        "zh": "zh_core_web_trf",
        "nl": "nl_core_news_trf",
        "en": "en_core_web_trf",
        "fr": "fr_dep_news_trf",
        "de": "de_dep_news_trf",
        "es": "es_dep_news_trf",
    }

    disallowed_nouns = CanonicalStringSet(["skype"])

    def __init__(self, named_entities: Optional[Iterable[str]] = None,
                 model: Optional[str] = None, **kwargs):
        """Initialise the ``Detector``.

        :param named_entities: Limit the named entities to those in this list, defaults to ``{'PERSON', 'PER', 'ORG'}``
        :type named_entities: Iterable[str], optional
        :param model: The name of the spacy model to use, it must contain a 'ner' step in the model pipeline (most
            do, but not all).
        :type model: str, optional
        :param name: Overrides the default name of the :class:``Detector``
        :type name: str, optional
        :param locale: The locale of the documents in the format: 2 letter lower-case language code followed by an
            underscore and the two letter upper-case country code, eg "en_GB" or "de_CH".
        :type locale: str, optional
        :raises ImportError: If the installed version of spaCy is not v3.
        :raises ValueError: If the model cannot be found or if its pipeline has no 'ner' step.
        """
        super(SpacyEntityDetector, self).__init__(**kwargs)

        if named_entities is None:
            named_entities = {'PERSON', 'PER', 'ORG'}
        elif isinstance(named_entities, str):
            # A lone label such as 'PERSON' would otherwise be iterated character by character.
            named_entities = [named_entities]

        # Spacy NER are all upper cased
        self.named_entities = {entity.upper() for entity in named_entities}

        # Fixes a warning message from transformers that is pulled in via spacy
        os.environ['TOKENIZERS_PARALLELISM'] = 'false'
        self.check_spacy_version()

        if model is not None:
            self.model = model
        else:
            # ``self.language`` is not set in this class; presumably the base ``Detector`` derives it from ``locale``.
            self.model = self.language_to_model.get(self.language, "{}_core_news_lg".format(self.language))

        # Only the transformer ('_trf') models need the whitespace/URL pre-processing step.
        self.preprocess_text = self.model.endswith('_trf')

        if not self.check_spacy_model(self.model):
            raise ValueError("Unable to find spacy model '{}'. Is your language supported? "
                             "Check the list of models available here: "
                             "https://github.com/explosion/spacy-models ".format(self.model))

        self.nlp = spacy.load(self.model)

        # If the model doesn't support named entity recognition
        if 'ner' not in [step[0] for step in self.nlp.pipeline]:
            raise ValueError(
                "The spacy model '{}' doesn't support named entity recognition, "
                "please choose another model.".format(self.model)
            )

    @staticmethod
    def check_spacy_version() -> bool:
        """Ensure that the version of spaCy is v3.

        :return: ``True`` if spaCy v3 is installed
        :rtype: bool
        :raises ImportError: If the version cannot be detected, cannot be parsed or is not v3.
        """
        # ``getattr`` keeps the "unable to detect" branch below reachable if the attribute is missing.
        spacy_version = getattr(spacy, '__version__', None)

        if spacy_version is None:
            raise ImportError('Spacy v3 needs to be installed. Unable to detect spacy version.')
        try:
            spacy_major = int(spacy_version.split('.')[0])
        except (AttributeError, TypeError, ValueError):
            raise ImportError('Spacy v3 needs to be installed. Spacy version {} is unknown.'.format(spacy_version))
        if spacy_major != 3:
            raise ImportError('Spacy v3 needs to be installed. Detected version {}.'.format(spacy_version))

        return True

    @staticmethod
    def check_spacy_model(model) -> bool:
        """Ensure that the spaCy model is installed, downloading it if it is missing.

        :param model: The name of the spaCy model to check.
        :type model: str
        :return: ``True`` if the model is installed or was downloaded
        :rtype: bool
        :raises ValueError: If the list of installed models cannot be read from ``spacy.info()``.
        """
        spacy_info = spacy.info()
        if isinstance(spacy_info, str):
            raise ValueError('Unable to detect spacy models.')

        # Test for ``None`` before calling ``.keys()``; calling it first would raise an ``AttributeError`` and the
        # intended ``ValueError`` could never be raised.
        installed_models = spacy_info.get('pipelines', spacy_info.get('models', None))
        if installed_models is None:
            raise ValueError('Unable to detect spacy models.')
        models = list(installed_models.keys())

        if model not in models:
            logger = logging.getLogger('scrubadub.detectors.spacy.SpacyEntityDetector')
            logger.info("Downloading spacy model {}".format(model))
            spacy.cli.download(model)
            importlib.import_module(model)
            # spacy.info() doesnt update after a spacy.cli.download, so theres no point checking it
            models.append(model)

        # Always returns true, if it fails to download, spacy sys.exit()s
        return model in models

    @staticmethod
    def _preprocess_text(document_list: List[str]) -> List[str]:
        """Collapse runs of whitespace and blank out URLs in each document.

        The list is modified in place and the same list object is returned.

        :param document_list: The documents to pre-process.
        :type document_list: List[str]
        :return: The same list, holding the pre-processed documents
        :rtype: List[str]
        """
        whitespace_regex = re.compile(r'\s+')

        for i_doc, text in enumerate(document_list):
            collapsed = whitespace_regex.sub(' ', text)
            # URLs are replaced after the whitespace has been collapsed, exactly one space per URL.
            document_list[i_doc] = re.sub(url.UrlDetector.regex, ' ', collapsed)

        return document_list

    def _run_spacy(
            self, document_list: Sequence[str], document_names: Sequence[Optional[str]]
    ) -> List[spacy.tokens.doc.Doc]:
        """Run the spaCy pipeline over the documents and collect the resulting ``Doc`` objects.

        If the pipeline raises the ``IndexError`` 'index out of range in self', a warning is logged and the ``Doc``
        of a single space is used in place of the failing document.

        :param document_list: The texts to process.
        :type document_list: Sequence[str]
        :param document_names: The name of each document, only used in log messages.
        :type document_names: Sequence[Optional[str]]
        :return: The processed spaCy documents
        :rtype: List[spacy.tokens.doc.Doc]
        """
        logger = logging.getLogger('scrubadub.detectors.spacy.SpacyEntityDetector')
        length_warning = 'sequence-length-is-longer-than-the-specified-maximum'
        spacy_docs = []  # type: List[spacy.tokens.doc.Doc]

        tokenizer = None
        transformer_stages = [stage for name, stage in self.nlp.pipeline if name == 'transformer']
        if transformer_stages:
            # Imported only when it is needed, so pipelines without a transformer stage do not require
            # spacy_transformers to be installed.
            import spacy_transformers.pipeline_component
            transformer_model = cast(spacy_transformers.pipeline_component.Transformer, transformer_stages[0])
            if 'tokenizer' in transformer_model.model.attrs:
                tokenizer = transformer_model.model.attrs['tokenizer']
                # Reset the flag so that we can tell afterwards whether this run triggered the message.
                tokenizer.deprecation_warnings[length_warning] = False

        # self.nlp.pipe has an outstanding issue effecting its type signature
        # https://github.com/explosion/spaCy/issues/8772
        generator = self.nlp.pipe(document_list)  # type: ignore

        i = 0
        while True:
            try:
                spacy_doc = next(generator)
            except StopIteration:
                break
            except IndexError as e:
                # Check that ``args`` is not empty so that the handler itself cannot raise an IndexError.
                if not e.args or e.args[0] != 'index out of range in self':
                    raise
                message = "Error processing documents due to spacy's transformer model. To use this model, try " \
                          "preprocessing the text by removing non-words and reducing spaces. Skipping file: {}"
                logger.warning(message.format(document_names[i]))
                # NOTE(review): a generator that has raised cannot be resumed, so documents after the failing one
                # are presumably not processed - confirm against the behaviour of spaCy's ``Language.pipe``.
                # self.nlp.pipe has an outstanding issue effecting its type signature
                # https://github.com/explosion/spaCy/issues/8772
                spacy_doc = list(self.nlp.pipe([' ']))[0]  # type: ignore
            i += 1
            spacy_docs.append(spacy_doc)

        # ``nlp.pipe`` is lazy, so the flag can only have been set once the generator has been consumed. Checking
        # it straight after creating the generator would always see the ``False`` that was set above.
        if tokenizer is not None and tokenizer.deprecation_warnings.get(length_warning):
            # NOTE(review): this writes the raw documents to the log, which may contain the sensitive text that
            # is being scrubbed - consider logging ``document_names`` instead.
            logger.warning(
                "The documents that triggered the sequence-length-is-longer-than-the-specified-maximum message:"
                f"\n{document_list}"
            )

        return spacy_docs

    @staticmethod
    def _get_entities(doc: spacy.tokens.doc.Doc) -> Iterable[spacy.tokens.span.Span]:
        """Return the named entities of a spaCy document (``doc.ents``)."""
        return doc.ents

    def _yield_filth(
            self, document_list: Sequence[str], document_names: Sequence[Optional[str]],
            get_entity_function: Optional[Callable[[spacy.tokens.doc.Doc], Iterable[spacy.tokens.span.Span]]] = None,
    ) -> Generator[Filth, None, None]:
        """Yield ``Filth`` for each entity of interest found in the documents.

        :param document_list: The documents to search.
        :type document_list: Sequence[str]
        :param document_names: The name of each document.
        :type document_names: Sequence[Optional[str]]
        :param get_entity_function: Extracts the entities from a spaCy ``Doc``, defaults to ``_get_entities``.
        :type get_entity_function: Callable, optional
        :return: An iterator to the discovered :class:`Filth`
        :rtype: Iterator[:class:`Filth`]
        """
        # If the model is a transformer model, we need to pre-process our data a little to avoid hitting the maximum
        # width of the transformer. Lots of spaces causes lots of tokens to be made and passed to the transformer
        # which causes an index go out of range error and so we remove this excess whitespace.
        # ``_preprocess_text`` modifies its argument in place, so it is given a new list.
        preprocessed_docs = list(document_list)
        if self.preprocess_text:
            preprocessed_docs = self._preprocess_text(preprocessed_docs)

        spacy_docs = self._run_spacy(document_list=preprocessed_docs, document_names=document_names)

        if get_entity_function is None:
            get_entity_function = self._get_entities

        for doc_name, doc, text in zip(document_names, spacy_docs, document_list):
            # The pre-processing changes the character positions in the text (because we remove excessive whitespace),
            # so this bit of code searches for the found entities in the original text.
            if self.preprocess_text:
                # Here we will keep a list of the filth that we have found already in this document and only search
                # for entities that we've not already searched for in this document. If "Jane" is twice in a document
                # and we loop over each "Jane" entity and search the whole document for "Jane", we would yield 4
                # "Jane"s instead of just the two that are in the text.
                yielded_filth = set()
                for ent in get_entity_function(doc):
                    if ent.text in yielded_filth or ent.label_ not in self.named_entities:
                        continue
                    yielded_filth.add(ent.text)

                    filth_class = self.filth_cls_map.get(ent.label_, None)
                    if filth_class is None:
                        continue

                    # Use a modified version of the regex detector to find the entities in the original document
                    class PreProcessedSpacyEntityDetector(RegexDetector):
                        filth_cls = cast(Type[Filth], filth_class)
                        # ``re.escape`` escapes spaces; each one is widened to match any run of whitespace, as
                        # the original text may have had several whitespace characters where there is now one.
                        regex = re.compile(re.escape(ent.text).replace('\\ ', r'\s+'))

                    regex_detector = PreProcessedSpacyEntityDetector(name=self.name, locale=self.locale)
                    yield from regex_detector.iter_filth(text, document_name=doc_name)
            else:
                # If we didn't preprocess, just loop over the entities and yield Filth.
                for ent in get_entity_function(doc):
                    if ent.label_ not in self.named_entities:
                        continue

                    filth_class = self.filth_cls_map.get(ent.label_, None)
                    if filth_class is None:
                        continue

                    yield filth_class(
                        beg=ent.start_char,
                        end=ent.end_char,
                        text=ent.text,
                        document_name=(str(doc_name) if doc_name else None),  # None if no doc_name provided
                        detector_name=self.name,
                        locale=self.locale,
                    )

    def iter_filth_documents(self, document_list: Sequence[str],
                             document_names: Sequence[Optional[str]]) -> Generator[Filth, None, None]:
        """Yields discovered filth in a list of documents.

        :param document_list: A list of documents to clean.
        :type document_list: List[str]
        :param document_names: A list containing the name of each document.
        :type document_names: List[str]
        :return: An iterator to the discovered :class:`Filth`
        :rtype: Iterator[:class:`Filth`]
        """
        yield from self._yield_filth(document_list, document_names)

    def iter_filth(self, text: str, document_name: Optional[str] = None) -> Generator[Filth, None, None]:
        """Yields discovered filth in the provided ``text``.

        :param text: The dirty text to clean.
        :type text: str
        :param document_name: The name of the document to clean.
        :type document_name: str, optional
        :return: An iterator to the discovered :class:`Filth`
        :rtype: Iterator[:class:`Filth`]
        """
        yield from self.iter_filth_documents(document_list=[text], document_names=[document_name])

    @classmethod
    def supported_locale(cls, locale: str) -> bool:
        """Returns true if this ``Detector`` supports the given locale.

        :param locale: The locale of the documents in the format: 2 letter lower-case language code followed by an
            underscore and the two letter upper-case country code, eg "en_GB" or "de_CH".
        :type locale: str
        :return: ``True`` if the locale is supported, otherwise ``False``
        :rtype: bool
        """
        # Every locale is accepted here; a missing model is only detected when the detector is initialised.
        return True
# Names exported by ``from <module> import *``.
__all__ = [
    'SpacyEntityDetector',
]