Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 18 additions & 18 deletions fastembed/late_interaction/colbert.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,24 +48,24 @@ def _post_process_onnx_output(
if not is_doc:
for embedding in output.model_output:
yield embedding

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: This could also be fixed by adding an empty return here (outside the loop but inside the if); that would make the code less nested.

if output.input_ids is None or output.attention_mask is None:
raise ValueError(
"input_ids and attention_mask must be provided for document post-processing"
)

for i, token_sequence in enumerate(output.input_ids):
for j, token_id in enumerate(token_sequence): # type: ignore
if token_id in self.skip_list or token_id == self.pad_token_id:
output.attention_mask[i, j] = 0

output.model_output *= np.expand_dims(output.attention_mask, 2)
norm = np.linalg.norm(output.model_output, ord=2, axis=2, keepdims=True)
norm_clamped = np.maximum(norm, 1e-12)
output.model_output /= norm_clamped

for embedding, attention_mask in zip(output.model_output, output.attention_mask):
yield embedding[attention_mask == 1]
else:
if output.input_ids is None or output.attention_mask is None:
raise ValueError(
"input_ids and attention_mask must be provided for document post-processing"
)

for i, token_sequence in enumerate(output.input_ids):
for j, token_id in enumerate(token_sequence): # type: ignore
if token_id in self.skip_list or token_id == self.pad_token_id:
output.attention_mask[i, j] = 0

output.model_output *= np.expand_dims(output.attention_mask, 2)
norm = np.linalg.norm(output.model_output, ord=2, axis=2, keepdims=True)
norm_clamped = np.maximum(norm, 1e-12)
output.model_output /= norm_clamped

for embedding, attention_mask in zip(output.model_output, output.attention_mask):
yield embedding[attention_mask == 1]
Comment on lines +62 to +68
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Avoid dtype upcasts and reuse the boolean mask for both zeroing and yielding.
Multiplying float embeddings by an int64 mask risks dtype promotion and extra work. Use a boolean mask cast to the model_output dtype, and use a dtype-aware epsilon.

Apply this diff:

-            output.model_output *= np.expand_dims(output.attention_mask, 2)
-            norm = np.linalg.norm(output.model_output, ord=2, axis=2, keepdims=True)
-            norm_clamped = np.maximum(norm, 1e-12)
-            output.model_output /= norm_clamped
-
-            for embedding, attention_mask in zip(output.model_output, output.attention_mask):
-                yield embedding[attention_mask == 1]
+            # Zero out masked tokens without changing dtype
+            mask_f = output.attention_mask[..., None].astype(output.model_output.dtype, copy=False)
+            output.model_output *= mask_f
+            # Normalize with dtype-appropriate epsilon
+            norm = np.linalg.norm(output.model_output, ord=2, axis=2, keepdims=True)
+            eps = np.finfo(output.model_output.dtype).eps
+            output.model_output /= np.maximum(norm, eps)
+
+            # Yield only kept tokens
+            for embedding, keep_mask in zip(output.model_output, output.attention_mask):
+                yield embedding[keep_mask]

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In fastembed/late_interaction/colbert.py around lines 62 to 68, avoid
integer->float dtype upcasts and redundant mask comparisons: convert
attention_mask once to a boolean array, cast that boolean mask to the
model_output dtype for in-place masking, and use a dtype-matching epsilon when
clamping norms; then reuse the boolean mask for yielding rows instead of
comparing to 1. Concretely: create mask_bool =
output.attention_mask.astype(bool), mask_cast =
mask_bool.astype(output.model_output.dtype), multiply model_output by mask_cast
in-place, compute norm and norm_clamped using eps = np.array(1e-12,
dtype=output.model_output.dtype), divide by norm_clamped, and in the final loop
use for embedding, mask in zip(output.model_output, mask_bool): yield
embedding[mask].


def _preprocess_onnx_input(
self, onnx_input: dict[str, NumpyArray], is_doc: bool = True, **kwargs: Any
Expand Down
13 changes: 7 additions & 6 deletions tests/test_late_interaction_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,10 +175,7 @@ def test_batch_inference_size_same_as_single_inference(model_name: str):
is_ci = os.getenv("CI")

model = LateInteractionTextEmbedding(model_name=model_name)
docs_to_embed = [
"short document",
"A bit longer document, which should not affect the size"
]
docs_to_embed = ["short document", "A bit longer document, which should not affect the size"]
result = list(model.embed(docs_to_embed, batch_size=1))
result_2 = list(model.embed(docs_to_embed, batch_size=2))
assert len(result[0]) == len(result_2[0])
Expand All @@ -199,7 +196,9 @@ def test_single_embedding(model_name: str):

print("evaluating", model_name)
model = LateInteractionTextEmbedding(model_name=model_name)
result = next(iter(model.embed(docs_to_embed, batch_size=6)))
whole_result = list(model.embed(docs_to_embed, batch_size=6))
assert len(whole_result) == 1
result = whole_result[0]
expected_result = CANONICAL_COLUMN_VALUES[model_name]
token_num, abridged_dim = expected_result.shape
assert np.allclose(result[:, :abridged_dim], expected_result, atol=2e-3)
Expand All @@ -220,7 +219,9 @@ def test_single_embedding_query(model_name: str):

print("evaluating", model_name)
model = LateInteractionTextEmbedding(model_name=model_name)
result = next(iter(model.query_embed(queries_to_embed)))
whole_result = list(model.query_embed(queries_to_embed))
assert len(whole_result) == 1
result = whole_result[0]
expected_result = CANONICAL_QUERY_VALUES[model_name]
token_num, abridged_dim = expected_result.shape
assert np.allclose(result[:, :abridged_dim], expected_result, atol=2e-3)
Expand Down
Loading