From b8128dc43602e39cf5f379e90eea6ebc1ed7d9bf Mon Sep 17 00:00:00 2001 From: George Panchuk Date: Thu, 28 Aug 2025 12:54:48 +0300 Subject: [PATCH 1/2] fix: fix colbert query postprocessing --- fastembed/late_interaction/colbert.py | 36 +++++++++++++-------------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/fastembed/late_interaction/colbert.py b/fastembed/late_interaction/colbert.py index e3eaaf8a8..3b53de095 100644 --- a/fastembed/late_interaction/colbert.py +++ b/fastembed/late_interaction/colbert.py @@ -48,24 +48,24 @@ def _post_process_onnx_output( if not is_doc: for embedding in output.model_output: yield embedding - - if output.input_ids is None or output.attention_mask is None: - raise ValueError( - "input_ids and attention_mask must be provided for document post-processing" - ) - - for i, token_sequence in enumerate(output.input_ids): - for j, token_id in enumerate(token_sequence): # type: ignore - if token_id in self.skip_list or token_id == self.pad_token_id: - output.attention_mask[i, j] = 0 - - output.model_output *= np.expand_dims(output.attention_mask, 2) - norm = np.linalg.norm(output.model_output, ord=2, axis=2, keepdims=True) - norm_clamped = np.maximum(norm, 1e-12) - output.model_output /= norm_clamped - - for embedding, attention_mask in zip(output.model_output, output.attention_mask): - yield embedding[attention_mask == 1] + else: + if output.input_ids is None or output.attention_mask is None: + raise ValueError( + "input_ids and attention_mask must be provided for document post-processing" + ) + + for i, token_sequence in enumerate(output.input_ids): + for j, token_id in enumerate(token_sequence): # type: ignore + if token_id in self.skip_list or token_id == self.pad_token_id: + output.attention_mask[i, j] = 0 + + output.model_output *= np.expand_dims(output.attention_mask, 2) + norm = np.linalg.norm(output.model_output, ord=2, axis=2, keepdims=True) + norm_clamped = np.maximum(norm, 1e-12) + output.model_output /= norm_clamped + + for embedding, attention_mask in zip(output.model_output, output.attention_mask): + yield embedding[attention_mask == 1] def _preprocess_onnx_input( self, onnx_input: dict[str, NumpyArray], is_doc: bool = True, **kwargs: Any From c274e9b3eb03aa6beae181d0b2c19318d93a0b5a Mon Sep 17 00:00:00 2001 From: George Panchuk Date: Fri, 29 Aug 2025 13:19:52 +0300 Subject: [PATCH 2/2] fix: improve colbert single embedding tests --- tests/test_late_interaction_embeddings.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/tests/test_late_interaction_embeddings.py b/tests/test_late_interaction_embeddings.py index f5b313304..9c7f20c64 100644 --- a/tests/test_late_interaction_embeddings.py +++ b/tests/test_late_interaction_embeddings.py @@ -175,10 +175,7 @@ def test_batch_inference_size_same_as_single_inference(model_name: str): is_ci = os.getenv("CI") model = LateInteractionTextEmbedding(model_name=model_name) - docs_to_embed = [ - "short document", - "A bit longer document, which should not affect the size" - ] + docs_to_embed = ["short document", "A bit longer document, which should not affect the size"] result = list(model.embed(docs_to_embed, batch_size=1)) result_2 = list(model.embed(docs_to_embed, batch_size=2)) assert len(result[0]) == len(result_2[0]) @@ -199,7 +196,9 @@ def test_single_embedding(model_name: str): print("evaluating", model_name) model = LateInteractionTextEmbedding(model_name=model_name) - result = next(iter(model.embed(docs_to_embed, batch_size=6))) + whole_result = list(model.embed(docs_to_embed, batch_size=6)) + assert len(whole_result) == 1 + result = whole_result[0] expected_result = CANONICAL_COLUMN_VALUES[model_name] token_num, abridged_dim = expected_result.shape assert np.allclose(result[:, :abridged_dim], expected_result, atol=2e-3) @@ -220,7 +219,9 @@ def test_single_embedding_query(model_name: str): print("evaluating", model_name) model = LateInteractionTextEmbedding(model_name=model_name) - result = next(iter(model.query_embed(queries_to_embed))) + whole_result = list(model.query_embed(queries_to_embed)) + assert len(whole_result) == 1 + result = whole_result[0] expected_result = CANONICAL_QUERY_VALUES[model_name] token_num, abridged_dim = expected_result.shape assert np.allclose(result[:, :abridged_dim], expected_result, atol=2e-3)