From cbdcb62bbe5e2c8d4e3debd9c7c6384f7d2056a7 Mon Sep 17 00:00:00 2001 From: Peter Chapman Date: Mon, 11 Aug 2025 15:15:42 +1200 Subject: [PATCH] Allow output_attentions to be set to False --- .../huggingface/hugging_face_nmt_engine.py | 133 ++++++++++-------- machine/translation/word_alignment_matrix.py | 4 +- .../test_hugging_face_nmt_engine.py | 36 +++-- 3 files changed, 105 insertions(+), 68 deletions(-) diff --git a/machine/translation/huggingface/hugging_face_nmt_engine.py b/machine/translation/huggingface/hugging_face_nmt_engine.py index 04086afd..562c5dbc 100644 --- a/machine/translation/huggingface/hugging_face_nmt_engine.py +++ b/machine/translation/huggingface/hugging_face_nmt_engine.py @@ -4,7 +4,7 @@ import logging import re from math import exp, prod -from typing import Iterable, List, Optional, Sequence, Tuple, Union, cast +from typing import Collection, Iterable, List, Optional, Sequence, Tuple, Union, cast import torch # pyright: ignore[reportMissingImports] from sacremoses import MosesPunctNormalizer @@ -24,6 +24,7 @@ from transformers.tokenization_utils import BatchEncoding, TruncationStrategy from ...annotations.range import Range +from ...corpora.aligned_word_pair import AlignedWordPair from ...utils.typeshed import StrPath from ..translation_engine import TranslationEngine from ..translation_result import TranslationResult @@ -163,10 +164,11 @@ def _try_translate_n_batch( builder = TranslationResultBuilder(input_tokens) for token, score in zip(output["translation_tokens"], output["token_scores"]): builder.append_token(token, TranslationSources.NMT, exp(score)) - src_indices = torch.argmax(output["token_attentions"], dim=1).tolist() - wa_matrix = WordAlignmentMatrix.from_word_pairs( - len(input_tokens), output_length, set(zip(src_indices, range(output_length))) - ) + word_pairs: Optional[Collection[Union[AlignedWordPair, Tuple[int, int]]]] = None + if output.get("token_attentions") is not None: + src_indices = torch.argmax(output["token_attentions"], dim=1).tolist() + word_pairs = set(zip(src_indices, range(output_length))) + wa_matrix = WordAlignmentMatrix.from_word_pairs(len(input_tokens), output_length, word_pairs) builder.mark_phrase(Range.create(0, len(input_tokens)), wa_matrix) segment_results.append(builder.to_result(output["translation_text"])) all_results.append(segment_results) @@ -242,12 +244,12 @@ def _forward(self, model_inputs, **generate_kwargs): config = self.model.config generate_kwargs["min_length"] = generate_kwargs.get("min_length", config.min_length) generate_kwargs["max_length"] = generate_kwargs.get("max_length", config.max_length) + generate_kwargs["output_attentions"] = generate_kwargs.get("output_attentions", True) self.check_inputs(input_length, generate_kwargs["min_length"], generate_kwargs["max_length"]) output = self.model.generate( **model_inputs, **generate_kwargs, output_scores=True, - output_attentions=True, return_dict_in_generate=True, ) @@ -285,36 +287,39 @@ def _forward(self, model_inputs, **generate_kwargs): if self.model.config.decoder_start_token_id is not None: scores = torch.cat((torch.zeros(scores.shape[0], scores.shape[1], 1, device=scores.device), scores), dim=2) - assert attentions is not None - num_heads = attentions[0][0].shape[1] - indices = torch.stack( - ( - torch.arange(output_ids.shape[1] - start_index, device=output_ids.device).expand(in_b, n_sequences, -1), - torch.reshape(beam_indices[:, start_index:] % num_beams, (in_b, n_sequences, -1)), - ), - dim=3, - ) - num_layers = len(attentions[0]) - layer = (2 * num_layers) // 3 - attentions = ( - torch.stack([cast(Tuple[torch.FloatTensor, ...], a)[layer][:, :, -1, :] for a in attentions], dim=0) - .squeeze() - .reshape(len(attentions), in_b, num_beams, num_heads, -1) - .transpose(0, 1) - ) - attentions = torch.mean(attentions, dim=3) - attentions = torch_gather_nd(attentions, indices, 1) - if self.model.config.decoder_start_token_id is not None: - attentions = torch.cat( + if generate_kwargs["output_attentions"] is True: + assert attentions is not None + num_heads = attentions[0][0].shape[1] + indices = torch.stack( ( - torch.zeros( - (attentions.shape[0], attentions.shape[1], 1, attentions.shape[3]), - device=attentions.device, + torch.arange(output_ids.shape[1] - start_index, device=output_ids.device).expand( + in_b, n_sequences, -1 ), - attentions, + torch.reshape(beam_indices[:, start_index:] % num_beams, (in_b, n_sequences, -1)), ), - dim=2, + dim=3, ) + num_layers = len(attentions[0]) + layer = (2 * num_layers) // 3 + attentions = ( + torch.stack([cast(Tuple[torch.FloatTensor, ...], a)[layer][:, :, -1, :] for a in attentions], dim=0) + .squeeze() + .reshape(len(attentions), in_b, num_beams, num_heads, -1) + .transpose(0, 1) + ) + attentions = torch.mean(attentions, dim=3) + attentions = torch_gather_nd(attentions, indices, 1) + if self.model.config.decoder_start_token_id is not None: + attentions = torch.cat( + ( + torch.zeros( + (attentions.shape[0], attentions.shape[1], 1, attentions.shape[3]), + device=attentions.device, + ), + attentions, + ), + dim=2, + ) output_ids = output_ids.reshape(in_b, n_sequences, *output_ids.shape[1:]) return { @@ -339,14 +344,27 @@ def postprocess(self, model_outputs, clean_up_tokenization_spaces=False): input_tokens = model_outputs["input_tokens"][0] records = [] - output_ids: torch.Tensor - scores: torch.Tensor - attentions: torch.Tensor - for output_ids, scores, attentions in zip( - model_outputs["output_ids"][0], - model_outputs["scores"][0], - model_outputs["attentions"][0], - ): + + has_attentions = model_outputs.get("attentions") is not None and model_outputs["attentions"][0] is not None + if has_attentions: + zipped = zip( + model_outputs["output_ids"][0], + model_outputs["scores"][0], + model_outputs["attentions"][0], + ) + else: + zipped = zip( + model_outputs["output_ids"][0], + model_outputs["scores"][0], + ) + + for item in zipped: + if has_attentions: + output_ids, scores, attentions = cast(Tuple[torch.Tensor, torch.Tensor, torch.Tensor], item) + else: + output_ids, scores = cast(Tuple[torch.Tensor, torch.Tensor], item) + attentions = None + output_tokens: List[str] = [] output_indices: List[int] = [] for i, output_id in enumerate(output_ids): @@ -354,22 +372,27 @@ def postprocess(self, model_outputs, clean_up_tokenization_spaces=False): if id not in all_special_ids: output_tokens.append(self.tokenizer.convert_ids_to_tokens(id)) output_indices.append(i) + scores = scores[output_indices] - attentions = attentions[output_indices] - attentions = attentions[:, input_indices] - records.append( - { - "input_tokens": input_tokens, - "translation_tokens": output_tokens, - "token_scores": scores, - "token_attentions": attentions, - "translation_text": self.tokenizer.decode( - output_ids, - skip_special_tokens=True, - clean_up_tokenization_spaces=clean_up_tokenization_spaces, - ), - } - ) + + record = { + "input_tokens": input_tokens, + "translation_tokens": output_tokens, + "token_scores": scores, + "translation_text": self.tokenizer.decode( + output_ids, + skip_special_tokens=True, + clean_up_tokenization_spaces=clean_up_tokenization_spaces, + ), + } + + if attentions is not None: + attentions = attentions[output_indices] + attentions = attentions[:, input_indices] + record["token_attentions"] = attentions + + records.append(record) + return records diff --git a/machine/translation/word_alignment_matrix.py b/machine/translation/word_alignment_matrix.py index d65e856a..f3323aa2 100644 --- a/machine/translation/word_alignment_matrix.py +++ b/machine/translation/word_alignment_matrix.py @@ -23,8 +23,10 @@ def from_word_pairs( cls, row_count: int, column_count: int, - set_values: Collection[Union[AlignedWordPair, Tuple[int, int]]] = set(), + set_values: Optional[Collection[Union[AlignedWordPair, Tuple[int, int]]]] = None, ) -> WordAlignmentMatrix: + if set_values is None: + set_values = set() matrix = np.full((row_count, column_count), False) for i, j in set_values: matrix[i, j] = True diff --git a/tests/translation/huggingface/test_hugging_face_nmt_engine.py b/tests/translation/huggingface/test_hugging_face_nmt_engine.py index abeab018..cdaf674c 100644 --- a/tests/translation/huggingface/test_hugging_face_nmt_engine.py +++ b/tests/translation/huggingface/test_hugging_face_nmt_engine.py @@ -5,39 +5,51 @@ skip("skipping Hugging Face tests on MacOS", allow_module_level=True) -from pytest import approx, raises +from pytest import approx, mark, raises from machine.translation.huggingface import HuggingFaceNmtEngine -def test_translate_n_batch_beam() -> None: - with HuggingFaceNmtEngine("stas/tiny-m2m_100", src_lang="en", tgt_lang="es", num_beams=2, max_length=10) as engine: +@mark.parametrize("output_attentions", [True, False]) +def test_translate_n_batch_beam(output_attentions: bool) -> None: + with HuggingFaceNmtEngine( + "stas/tiny-m2m_100", + src_lang="en", + tgt_lang="es", + num_beams=2, + max_length=10, + output_attentions=output_attentions, + ) as engine: results = engine.translate_n_batch( n=2, segments=["This is a test string", "Hello, world!"], ) assert results[0][0].translation == "skaberskaber Dollar Dollar ፤ ፤ gerekir gerekir" assert results[0][0].confidences[0] == approx(1.08e-05, 0.01) - assert str(results[0][0].alignment) == "2-0 2-1 2-2 2-3 4-4 4-5 4-6 4-7" + assert str(results[0][0].alignment) == ("2-0 2-1 2-2 2-3 4-4 4-5 4-6 4-7" if output_attentions else "") assert results[0][1].translation == "skaberskaber Dollar Dollar ፤ ፤ ፤ gerekir" assert results[0][1].confidences[0] == approx(1.08e-05, 0.01) - assert str(results[0][1].alignment) == "2-0 2-1 2-2 2-3 4-4 4-5 4-6 4-7" + assert str(results[0][1].alignment) == ("2-0 2-1 2-2 2-3 4-4 4-5 4-6 4-7" if output_attentions else "") assert results[1][0].translation == "skaberskaber Dollar Dollar ፤ ፤ gerekir gerekir" assert results[1][0].confidences[0] == approx(1.08e-05, 0.01) - assert str(results[1][0].alignment) == "0-1 0-2 0-7 1-0 3-3 3-4 3-5 3-6" + assert str(results[1][0].alignment) == ("0-1 0-2 0-7 1-0 3-3 3-4 3-5 3-6" if output_attentions else "") assert results[1][1].translation == "skaberskaber Dollar Dollar ፤ ፤ ፤ gerekir" assert results[1][1].confidences[0] == approx(1.08e-05, 0.01) - assert str(results[1][1].alignment) == "0-1 0-2 0-7 1-0 3-3 3-4 3-5 3-6" + assert str(results[1][1].alignment) == ("0-1 0-2 0-7 1-0 3-3 3-4 3-5 3-6" if output_attentions else "") -def test_translate_greedy() -> None: - with HuggingFaceNmtEngine("stas/tiny-m2m_100", src_lang="en", tgt_lang="es", max_length=10) as engine: +@mark.parametrize("output_attentions", [True, False]) +def test_translate_greedy(output_attentions: bool) -> None: + with HuggingFaceNmtEngine( + "stas/tiny-m2m_100", src_lang="en", tgt_lang="es", max_length=10, output_attentions=output_attentions + ) as engine: result = engine.translate("This is a test string") assert result.translation == "skaberskaber Dollar Dollar Dollar ፤ gerekir gerekir" assert result.confidences[0] == approx(1.08e-05, 0.01) - assert str(result.alignment) == "2-0 2-1 2-2 2-3 4-4 4-5 4-6 4-7" + assert str(result.alignment) == ("2-0 2-1 2-2 2-3 4-4 4-5 4-6 4-7" if output_attentions else "") -def test_construct_invalid_lang() -> None: +@mark.parametrize("output_attentions", [True, False]) +def test_construct_invalid_lang(output_attentions: bool) -> None: with raises(ValueError): - HuggingFaceNmtEngine("stas/tiny-m2m_100", src_lang="qaa", tgt_lang="es") + HuggingFaceNmtEngine("stas/tiny-m2m_100", src_lang="qaa", tgt_lang="es", output_attentions=output_attentions)