diff --git a/sense2vec/sense2vec.py b/sense2vec/sense2vec.py index bb157f5..1e1cf8f 100644 --- a/sense2vec/sense2vec.py +++ b/sense2vec/sense2vec.py @@ -3,6 +3,7 @@ from spacy.vectors import Vectors from spacy.strings import StringStore from spacy.util import SimpleFrozenDict +from thinc.api import NumpyOps import numpy import srsly @@ -247,7 +248,11 @@ def get_other_senses( result = [] key = key if isinstance(key, str) else self.strings[key] word, orig_sense = self.split_key(key) - versions = set([word, word.lower(), word.upper(), word.title()]) if ignore_case else [word] + versions = ( + set([word, word.lower(), word.upper(), word.title()]) + if ignore_case + else [word] + ) for text in versions: for sense in self.senses: new_key = self.make_key(text, sense) @@ -270,7 +275,11 @@ def get_best_sense( sense_options = senses or self.senses if not sense_options: return None - versions = set([word, word.lower(), word.upper(), word.title()]) if ignore_case else [word] + versions = ( + set([word, word.lower(), word.upper(), word.title()]) + if ignore_case + else [word] + ) freqs = [] for text in versions: for sense in sense_options: @@ -304,6 +313,9 @@ def from_bytes(self, bytes_data: bytes, exclude: Sequence[str] = tuple()): """ data = srsly.msgpack_loads(bytes_data) self.vectors = Vectors().from_bytes(data["vectors"]) + # Pin vectors to the CPU so that we don't end up comparing + # numpy and cupy arrays. + self.vectors.to_ops(NumpyOps()) self.freqs = dict(data.get("freqs", [])) self.cfg.update(data.get("cfg", {})) if "strings" not in exclude and "strings" in data: @@ -340,6 +352,9 @@ def from_disk(self, path: Union[Path, str], exclude: Sequence[str] = tuple()): freqs_path = path / "freqs.json" cache_path = path / "cache" self.vectors = Vectors().from_disk(path) + # Pin vectors to the CPU so that we don't end up comparing + # numpy and cupy arrays. + self.vectors.to_ops(NumpyOps()) self.cfg.update(srsly.read_json(path / "cfg")) if freqs_path.exists(): self.freqs = dict(srsly.read_json(freqs_path)) diff --git a/sense2vec/tests/test_issue155.py b/sense2vec/tests/test_issue155.py new file mode 100644 index 0000000..546734d --- /dev/null +++ b/sense2vec/tests/test_issue155.py @@ -0,0 +1,13 @@ +from pathlib import Path +import pytest +from sense2vec.sense2vec import Sense2Vec +from thinc.api import use_ops +from thinc.util import has_cupy_gpu + + +@pytest.mark.skipif(not has_cupy_gpu, reason="requires Cupy/GPU") +def test_issue155(): + data_path = Path(__file__).parent / "data" + with use_ops("cupy"): + s2v = Sense2Vec().from_disk(data_path) + s2v.most_similar("beekeepers|NOUN")