-
Notifications
You must be signed in to change notification settings - Fork 194
fix: fix colbert query postprocessing #557
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -48,24 +48,24 @@ def _post_process_onnx_output( | |
| if not is_doc: | ||
| for embedding in output.model_output: | ||
| yield embedding | ||
|
|
||
| if output.input_ids is None or output.attention_mask is None: | ||
| raise ValueError( | ||
| "input_ids and attention_mask must be provided for document post-processing" | ||
| ) | ||
|
|
||
| for i, token_sequence in enumerate(output.input_ids): | ||
| for j, token_id in enumerate(token_sequence): # type: ignore | ||
| if token_id in self.skip_list or token_id == self.pad_token_id: | ||
| output.attention_mask[i, j] = 0 | ||
|
|
||
| output.model_output *= np.expand_dims(output.attention_mask, 2) | ||
| norm = np.linalg.norm(output.model_output, ord=2, axis=2, keepdims=True) | ||
| norm_clamped = np.maximum(norm, 1e-12) | ||
| output.model_output /= norm_clamped | ||
|
|
||
| for embedding, attention_mask in zip(output.model_output, output.attention_mask): | ||
| yield embedding[attention_mask == 1] | ||
| else: | ||
| if output.input_ids is None or output.attention_mask is None: | ||
| raise ValueError( | ||
| "input_ids and attention_mask must be provided for document post-processing" | ||
| ) | ||
|
|
||
| for i, token_sequence in enumerate(output.input_ids): | ||
| for j, token_id in enumerate(token_sequence): # type: ignore | ||
| if token_id in self.skip_list or token_id == self.pad_token_id: | ||
| output.attention_mask[i, j] = 0 | ||
|
|
||
| output.model_output *= np.expand_dims(output.attention_mask, 2) | ||
| norm = np.linalg.norm(output.model_output, ord=2, axis=2, keepdims=True) | ||
| norm_clamped = np.maximum(norm, 1e-12) | ||
| output.model_output /= norm_clamped | ||
|
|
||
| for embedding, attention_mask in zip(output.model_output, output.attention_mask): | ||
| yield embedding[attention_mask == 1] | ||
|
Comment on lines
+62
to
+68
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion Avoid dtype upcasts and reuse the boolean mask for both zeroing and yielding. Apply this diff: - output.model_output *= np.expand_dims(output.attention_mask, 2)
- norm = np.linalg.norm(output.model_output, ord=2, axis=2, keepdims=True)
- norm_clamped = np.maximum(norm, 1e-12)
- output.model_output /= norm_clamped
-
- for embedding, attention_mask in zip(output.model_output, output.attention_mask):
- yield embedding[attention_mask == 1]
+ # Zero out masked tokens without changing dtype
+ mask_f = output.attention_mask[..., None].astype(output.model_output.dtype, copy=False)
+ output.model_output *= mask_f
+ # Normalize with dtype-appropriate epsilon
+ norm = np.linalg.norm(output.model_output, ord=2, axis=2, keepdims=True)
+ eps = np.finfo(output.model_output.dtype).eps
+ output.model_output /= np.maximum(norm, eps)
+
+ # Yield only kept tokens
+ for embedding, keep_mask in zip(output.model_output, output.attention_mask):
+ yield embedding[keep_mask]
🤖 Prompt for AI Agents |
||
|
|
||
| def _preprocess_onnx_input( | ||
| self, onnx_input: dict[str, NumpyArray], is_doc: bool = True, **kwargs: Any | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: Could also be fixed by adding an empty
`return` here (outside the loop but inside the `if`); that would make the code less nested.