diff --git a/README.md b/README.md
index bd76959..b2a8711 100644
--- a/README.md
+++ b/README.md
@@ -48,9 +48,10 @@ Table of Contents
1. [English language](#english-language)
2. [Other languages](#other-languages)
2. [KeyphraseTfidfVectorizer](#keyphrasetfidfvectorizer)
- 3. [Custom POS-tagger](#custom-pos-tagger)
- 4. [PatternRank: Keyphrase extraction with KeyphraseVectorizers and KeyBERT](#patternrank-keyphrase-extraction-with-keyphrasevectorizers-and-keybert)
- 5. [Topic modeling with BERTopic and KeyphraseVectorizers](#topic-modeling-with-bertopic-and-keyphrasevectorizers)
+ 3. [Reuse a spaCy Language object](#reuse-a-spacy-language-object)
+ 4. [Custom POS-tagger](#custom-pos-tagger)
+ 5. [PatternRank: Keyphrase extraction with KeyphraseVectorizers and KeyBERT](#patternrank-keyphrase-extraction-with-keyphrasevectorizers-and-keybert)
+ 6. [Topic modeling with BERTopic and KeyphraseVectorizers](#topic-modeling-with-bertopic-and-keyphrasevectorizers)
4. [Citation information](#citation-information)
@@ -290,6 +291,48 @@ print(keyphrases)
'phrases' 'overlap' 'users' 'learning algorithm' 'document']
```
+
+
+### Reuse a spaCy Language object
+
+[Back to Table of Contents](#toc)
+
+By default, each `KeyphraseVectorizer` object loads its own `spacy.Language` object.
+When using multiple `KeyphraseVectorizer` objects, it is more efficient to load the `spacy.Language` object beforehand and pass it as the `spacy_pipeline` argument.
+
+```python
+import spacy
+from keyphrase_vectorizers import KeyphraseCountVectorizer, KeyphraseTfidfVectorizer
+
+docs = ["""Supervised learning is the machine learning task of learning a function that
+ maps an input to an output based on example input-output pairs. It infers a
+ function from labeled training data consisting of a set of training examples.
+ In supervised learning, each example is a pair consisting of an input object
+ (typically a vector) and a desired output value (also called the supervisory signal).
+ A supervised learning algorithm analyzes the training data and produces an inferred function,
+ which can be used for mapping new examples. An optimal scenario will allow for the
+ algorithm to correctly determine the class labels for unseen instances. This requires
+ the learning algorithm to generalize from the training data to unseen situations in a
+ 'reasonable' way (see inductive bias).""",
+
+ """Keywords are defined as phrases that capture the main topics discussed in a document.
+ As they offer a brief yet precise summary of document content, they can be utilized for various applications.
+ In an information retrieval environment, they serve as an indication of document relevance for users, as the list
+ of keywords can quickly help to determine whether a given document is relevant to their interest.
+ As keywords reflect a document's main topics, they can be utilized to classify documents into groups
+ by measuring the overlap between the keywords assigned to them. Keywords are also used proactively
+ in information retrieval."""]
+
+nlp = spacy.load("en_core_web_sm")
+
+vectorizer1 = KeyphraseCountVectorizer(spacy_pipeline=nlp)
+vectorizer2 = KeyphraseTfidfVectorizer(spacy_pipeline=nlp)
+
+# both vectorizers reuse the same preloaded nlp object instead of loading the pipeline twice
+vectorizer1.fit(docs)
+vectorizer2.fit(docs)
+```
+
### Custom POS-tagger
diff --git a/keyphrase_vectorizers/keyphrase_count_vectorizer.py b/keyphrase_vectorizers/keyphrase_count_vectorizer.py
index 46c7942..0a355bc 100644
--- a/keyphrase_vectorizers/keyphrase_count_vectorizer.py
+++ b/keyphrase_vectorizers/keyphrase_count_vectorizer.py
@@ -12,6 +12,7 @@
import numpy as np
import psutil
+import spacy
from sklearn.base import BaseEstimator
from sklearn.exceptions import NotFittedError
from sklearn.feature_extraction.text import CountVectorizer
@@ -39,8 +40,8 @@ class KeyphraseCountVectorizer(_KeyphraseVectorizerMixin, BaseEstimator):
Parameters
----------
- spacy_pipeline : str, default='en_core_web_sm'
- The name of the `spaCy pipeline`_, used to tag the parts-of-speech in the text. Standard is the 'en' pipeline.
+ spacy_pipeline : Union[str, spacy.Language], default='en_core_web_sm'
+ A spacy.Language object or the name of the `spaCy pipeline`_, used to tag the parts-of-speech in the text. Standard is the 'en' pipeline.
pos_pattern : str, default='*+'
The `regex pattern`_ of `POS-tags`_ used to extract a sequence of POS-tagged tokens from the text.
@@ -86,7 +87,7 @@ class KeyphraseCountVectorizer(_KeyphraseVectorizerMixin, BaseEstimator):
Type of the matrix returned by fit_transform() or transform().
"""
- def __init__(self, spacy_pipeline: str = 'en_core_web_sm', pos_pattern: str = '*+',
+ def __init__(self, spacy_pipeline: Union[str, spacy.Language] = 'en_core_web_sm', pos_pattern: str = '*+',
stop_words: Union[str, List[str]] = 'english', lowercase: bool = True, workers: int = 1,
spacy_exclude: List[str] = None, custom_pos_tagger: callable = None,
max_df: int = None, min_df: int = None, binary: bool = False, dtype: np.dtype = np.int64):
diff --git a/keyphrase_vectorizers/keyphrase_tfidf_vectorizer.py b/keyphrase_vectorizers/keyphrase_tfidf_vectorizer.py
index b17efb7..30661cd 100644
--- a/keyphrase_vectorizers/keyphrase_tfidf_vectorizer.py
+++ b/keyphrase_vectorizers/keyphrase_tfidf_vectorizer.py
@@ -12,6 +12,7 @@
import numpy as np
import psutil
+import spacy
from sklearn.exceptions import NotFittedError
from sklearn.feature_extraction.text import TfidfTransformer
from sklearn.utils.validation import FLOAT_DTYPES
@@ -67,8 +68,8 @@ class KeyphraseTfidfVectorizer(KeyphraseCountVectorizer):
Parameters
----------
- spacy_pipeline : str, default='en_core_web_sm'
- The name of the `spaCy pipeline`_, used to tag the parts-of-speech in the text. Standard is the 'en' pipeline.
+ spacy_pipeline : Union[str, spacy.Language], default='en_core_web_sm'
+ A spacy.Language object or the name of the `spaCy pipeline`_, used to tag the parts-of-speech in the text. Standard is the 'en' pipeline.
pos_pattern : str, default='*+'
The `regex pattern`_ of `POS-tags`_ used to extract a sequence of POS-tagged tokens from the text.
@@ -131,7 +132,7 @@ class KeyphraseTfidfVectorizer(KeyphraseCountVectorizer):
"""
- def __init__(self, spacy_pipeline: str = 'en_core_web_sm', pos_pattern: str = '*+',
+ def __init__(self, spacy_pipeline: Union[str, spacy.Language] = 'en_core_web_sm', pos_pattern: str = '*+',
stop_words: Union[str, List[str]] = 'english',
lowercase: bool = True, workers: int = 1, spacy_exclude: List[str] = None,
custom_pos_tagger: callable = None, max_df: int = None, min_df: int = None,
diff --git a/keyphrase_vectorizers/keyphrase_vectorizer_mixin.py b/keyphrase_vectorizers/keyphrase_vectorizer_mixin.py
index 563f836..e62b1b2 100644
--- a/keyphrase_vectorizers/keyphrase_vectorizer_mixin.py
+++ b/keyphrase_vectorizers/keyphrase_vectorizer_mixin.py
@@ -180,7 +180,7 @@ def _split_long_document(self, text: str, max_text_length: int) -> List[str]:
max_text_length=max_text_length)
return splitted_document
- def _get_pos_keyphrases(self, document_list: List[str], stop_words: Union[str, List[str]], spacy_pipeline: str,
+ def _get_pos_keyphrases(self, document_list: List[str], stop_words: Union[str, List[str]], spacy_pipeline: Union[str, spacy.Language],
pos_pattern: str, spacy_exclude: List[str], custom_pos_tagger: callable,
lowercase: bool = True, workers: int = 1) -> List[str]:
"""
@@ -196,8 +196,8 @@ def _get_pos_keyphrases(self, document_list: List[str], stop_words: Union[str, L
Removes unwanted stopwords from keyphrases if 'stop_words' is not None.
If given a list of custom stopwords, removes them instead.
- spacy_pipeline : str
- The name of the `spaCy pipeline`_, used to tag the parts-of-speech in the text.
+ spacy_pipeline : Union[str, spacy.Language]
+ A spacy.Language object or the name of the `spaCy pipeline`_, used to tag the parts-of-speech in the text.
pos_pattern : str
The `regex pattern`_ of `POS-tags`_ used to extract a sequence of POS-tagged tokens from the text.
@@ -245,9 +245,9 @@ def _get_pos_keyphrases(self, document_list: List[str], stop_words: Union[str, L
)
# triggers a parameter validation
- if not isinstance(spacy_pipeline, str):
+ if not isinstance(spacy_pipeline, (str, spacy.Language)):
raise ValueError(
- "'spacy_pipeline' parameter needs to be a spaCy pipeline string. E.g. 'en_core_web_sm'"
+ "'spacy_pipeline' parameter needs to be a spacy.Language object or a spaCy pipeline string. E.g. 'en_core_web_sm'"
)
# triggers a parameter validation
@@ -304,25 +304,28 @@ def _get_pos_keyphrases(self, document_list: List[str], stop_words: Union[str, L
# add spaCy POS tags for documents
if not custom_pos_tagger:
- if not spacy_exclude:
- spacy_exclude = []
- try:
- nlp = spacy.load(spacy_pipeline,
- exclude=spacy_exclude)
- except OSError:
- # set logger
- logger = logging.getLogger('KeyphraseVectorizer')
- logger.setLevel(logging.WARNING)
- sh = logging.StreamHandler()
- sh.setFormatter(logging.Formatter(
- '%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
- logger.addHandler(sh)
- logger.setLevel(logging.DEBUG)
- logger.info(
- 'It looks like the selected spaCy pipeline is not downloaded yet. It is attempted to download the spaCy pipeline now.')
- spacy.cli.download(spacy_pipeline)
- nlp = spacy.load(spacy_pipeline,
- exclude=spacy_exclude)
+ if isinstance(spacy_pipeline, spacy.Language):
+ nlp = spacy_pipeline
+ else:
+ if not spacy_exclude:
+ spacy_exclude = []
+ try:
+ nlp = spacy.load(spacy_pipeline,
+ exclude=spacy_exclude)
+ except OSError:
+ # set logger
+ logger = logging.getLogger('KeyphraseVectorizer')
+ logger.setLevel(logging.WARNING)
+ sh = logging.StreamHandler()
+ sh.setFormatter(logging.Formatter(
+ '%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
+ logger.addHandler(sh)
+ logger.setLevel(logging.DEBUG)
+ logger.info(
+ 'It looks like the selected spaCy pipeline is not downloaded yet. It is attempted to download the spaCy pipeline now.')
+ spacy.cli.download(spacy_pipeline)
+ nlp = spacy.load(spacy_pipeline,
+ exclude=spacy_exclude)
if workers != 1:
os.environ["TOKENIZERS_PARALLELISM"] = "false"