Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 78 additions & 55 deletions machine/translation/huggingface/hugging_face_nmt_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import logging
import re
from math import exp, prod
from typing import Iterable, List, Optional, Sequence, Tuple, Union, cast
from typing import Collection, Iterable, List, Optional, Sequence, Tuple, Union, cast

import torch # pyright: ignore[reportMissingImports]
from sacremoses import MosesPunctNormalizer
Expand All @@ -24,6 +24,7 @@
from transformers.tokenization_utils import BatchEncoding, TruncationStrategy

from ...annotations.range import Range
from ...corpora.aligned_word_pair import AlignedWordPair
from ...utils.typeshed import StrPath
from ..translation_engine import TranslationEngine
from ..translation_result import TranslationResult
Expand Down Expand Up @@ -163,10 +164,11 @@ def _try_translate_n_batch(
builder = TranslationResultBuilder(input_tokens)
for token, score in zip(output["translation_tokens"], output["token_scores"]):
builder.append_token(token, TranslationSources.NMT, exp(score))
src_indices = torch.argmax(output["token_attentions"], dim=1).tolist()
wa_matrix = WordAlignmentMatrix.from_word_pairs(
len(input_tokens), output_length, set(zip(src_indices, range(output_length)))
)
word_pairs: Optional[Collection[Union[AlignedWordPair, Tuple[int, int]]]] = None
if output.get("token_attentions") is not None:
src_indices = torch.argmax(output["token_attentions"], dim=1).tolist()
word_pairs = set(zip(src_indices, range(output_length)))
wa_matrix = WordAlignmentMatrix.from_word_pairs(len(input_tokens), output_length, word_pairs)
builder.mark_phrase(Range.create(0, len(input_tokens)), wa_matrix)
segment_results.append(builder.to_result(output["translation_text"]))
all_results.append(segment_results)
Expand Down Expand Up @@ -242,12 +244,12 @@ def _forward(self, model_inputs, **generate_kwargs):
config = self.model.config
generate_kwargs["min_length"] = generate_kwargs.get("min_length", config.min_length)
generate_kwargs["max_length"] = generate_kwargs.get("max_length", config.max_length)
generate_kwargs["output_attentions"] = generate_kwargs.get("output_attentions", True)
self.check_inputs(input_length, generate_kwargs["min_length"], generate_kwargs["max_length"])
output = self.model.generate(
**model_inputs,
**generate_kwargs,
output_scores=True,
output_attentions=True,
return_dict_in_generate=True,
)

Expand Down Expand Up @@ -285,36 +287,39 @@ def _forward(self, model_inputs, **generate_kwargs):
if self.model.config.decoder_start_token_id is not None:
scores = torch.cat((torch.zeros(scores.shape[0], scores.shape[1], 1, device=scores.device), scores), dim=2)

assert attentions is not None
num_heads = attentions[0][0].shape[1]
indices = torch.stack(
(
torch.arange(output_ids.shape[1] - start_index, device=output_ids.device).expand(in_b, n_sequences, -1),
torch.reshape(beam_indices[:, start_index:] % num_beams, (in_b, n_sequences, -1)),
),
dim=3,
)
num_layers = len(attentions[0])
layer = (2 * num_layers) // 3
attentions = (
torch.stack([cast(Tuple[torch.FloatTensor, ...], a)[layer][:, :, -1, :] for a in attentions], dim=0)
.squeeze()
.reshape(len(attentions), in_b, num_beams, num_heads, -1)
.transpose(0, 1)
)
attentions = torch.mean(attentions, dim=3)
attentions = torch_gather_nd(attentions, indices, 1)
if self.model.config.decoder_start_token_id is not None:
attentions = torch.cat(
if generate_kwargs["output_attentions"] is True:
assert attentions is not None
num_heads = attentions[0][0].shape[1]
indices = torch.stack(
(
torch.zeros(
(attentions.shape[0], attentions.shape[1], 1, attentions.shape[3]),
device=attentions.device,
torch.arange(output_ids.shape[1] - start_index, device=output_ids.device).expand(
in_b, n_sequences, -1
),
attentions,
torch.reshape(beam_indices[:, start_index:] % num_beams, (in_b, n_sequences, -1)),
),
dim=2,
dim=3,
)
num_layers = len(attentions[0])
layer = (2 * num_layers) // 3
attentions = (
torch.stack([cast(Tuple[torch.FloatTensor, ...], a)[layer][:, :, -1, :] for a in attentions], dim=0)
.squeeze()
.reshape(len(attentions), in_b, num_beams, num_heads, -1)
.transpose(0, 1)
)
attentions = torch.mean(attentions, dim=3)
attentions = torch_gather_nd(attentions, indices, 1)
if self.model.config.decoder_start_token_id is not None:
attentions = torch.cat(
(
torch.zeros(
(attentions.shape[0], attentions.shape[1], 1, attentions.shape[3]),
device=attentions.device,
),
attentions,
),
dim=2,
)

output_ids = output_ids.reshape(in_b, n_sequences, *output_ids.shape[1:])
return {
Expand All @@ -339,37 +344,55 @@ def postprocess(self, model_outputs, clean_up_tokenization_spaces=False):
input_tokens = model_outputs["input_tokens"][0]

records = []
output_ids: torch.Tensor
scores: torch.Tensor
attentions: torch.Tensor
for output_ids, scores, attentions in zip(
model_outputs["output_ids"][0],
model_outputs["scores"][0],
model_outputs["attentions"][0],
):

has_attentions = model_outputs.get("attentions") is not None and model_outputs["attentions"][0] is not None
if has_attentions:
zipped = zip(
model_outputs["output_ids"][0],
model_outputs["scores"][0],
model_outputs["attentions"][0],
)
else:
zipped = zip(
model_outputs["output_ids"][0],
model_outputs["scores"][0],
)

for item in zipped:
if has_attentions:
output_ids, scores, attentions = cast(Tuple[torch.Tensor, torch.Tensor, torch.Tensor], item)
else:
output_ids, scores = cast(Tuple[torch.Tensor, torch.Tensor], item)
attentions = None

output_tokens: List[str] = []
output_indices: List[int] = []
for i, output_id in enumerate(output_ids):
id = cast(int, output_id.item())
if id not in all_special_ids:
output_tokens.append(self.tokenizer.convert_ids_to_tokens(id))
output_indices.append(i)

scores = scores[output_indices]
attentions = attentions[output_indices]
attentions = attentions[:, input_indices]
records.append(
{
"input_tokens": input_tokens,
"translation_tokens": output_tokens,
"token_scores": scores,
"token_attentions": attentions,
"translation_text": self.tokenizer.decode(
output_ids,
skip_special_tokens=True,
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
),
}
)

record = {
"input_tokens": input_tokens,
"translation_tokens": output_tokens,
"token_scores": scores,
"translation_text": self.tokenizer.decode(
output_ids,
skip_special_tokens=True,
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
),
}

if attentions is not None:
attentions = attentions[output_indices]
attentions = attentions[:, input_indices]
record["token_attentions"] = attentions

records.append(record)

return records


Expand Down
4 changes: 3 additions & 1 deletion machine/translation/word_alignment_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,10 @@ def from_word_pairs(
cls,
row_count: int,
column_count: int,
set_values: Collection[Union[AlignedWordPair, Tuple[int, int]]] = set(),
set_values: Optional[Collection[Union[AlignedWordPair, Tuple[int, int]]]] = None,
) -> WordAlignmentMatrix:
if set_values is None:
set_values = set()
matrix = np.full((row_count, column_count), False)
for i, j in set_values:
matrix[i, j] = True
Expand Down
36 changes: 24 additions & 12 deletions tests/translation/huggingface/test_hugging_face_nmt_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,39 +5,51 @@

skip("skipping Hugging Face tests on MacOS", allow_module_level=True)

from pytest import approx, raises
from pytest import approx, mark, raises

from machine.translation.huggingface import HuggingFaceNmtEngine


def test_translate_n_batch_beam() -> None:
with HuggingFaceNmtEngine("stas/tiny-m2m_100", src_lang="en", tgt_lang="es", num_beams=2, max_length=10) as engine:
@mark.parametrize("output_attentions", [True, False])
def test_translate_n_batch_beam(output_attentions: bool) -> None:
with HuggingFaceNmtEngine(
"stas/tiny-m2m_100",
src_lang="en",
tgt_lang="es",
num_beams=2,
max_length=10,
output_attentions=output_attentions,
) as engine:
results = engine.translate_n_batch(
n=2,
segments=["This is a test string", "Hello, world!"],
)
assert results[0][0].translation == "skaberskaber Dollar Dollar ፤ ፤ gerekir gerekir"
assert results[0][0].confidences[0] == approx(1.08e-05, 0.01)
assert str(results[0][0].alignment) == "2-0 2-1 2-2 2-3 4-4 4-5 4-6 4-7"
assert str(results[0][0].alignment) == ("2-0 2-1 2-2 2-3 4-4 4-5 4-6 4-7" if output_attentions else "")
assert results[0][1].translation == "skaberskaber Dollar Dollar ፤ ፤ ፤ gerekir"
assert results[0][1].confidences[0] == approx(1.08e-05, 0.01)
assert str(results[0][1].alignment) == "2-0 2-1 2-2 2-3 4-4 4-5 4-6 4-7"
assert str(results[0][1].alignment) == ("2-0 2-1 2-2 2-3 4-4 4-5 4-6 4-7" if output_attentions else "")
assert results[1][0].translation == "skaberskaber Dollar Dollar ፤ ፤ gerekir gerekir"
assert results[1][0].confidences[0] == approx(1.08e-05, 0.01)
assert str(results[1][0].alignment) == "0-1 0-2 0-7 1-0 3-3 3-4 3-5 3-6"
assert str(results[1][0].alignment) == ("0-1 0-2 0-7 1-0 3-3 3-4 3-5 3-6" if output_attentions else "")
assert results[1][1].translation == "skaberskaber Dollar Dollar ፤ ፤ ፤ gerekir"
assert results[1][1].confidences[0] == approx(1.08e-05, 0.01)
assert str(results[1][1].alignment) == "0-1 0-2 0-7 1-0 3-3 3-4 3-5 3-6"
assert str(results[1][1].alignment) == ("0-1 0-2 0-7 1-0 3-3 3-4 3-5 3-6" if output_attentions else "")


def test_translate_greedy() -> None:
with HuggingFaceNmtEngine("stas/tiny-m2m_100", src_lang="en", tgt_lang="es", max_length=10) as engine:
@mark.parametrize("output_attentions", [True, False])
def test_translate_greedy(output_attentions: bool) -> None:
with HuggingFaceNmtEngine(
"stas/tiny-m2m_100", src_lang="en", tgt_lang="es", max_length=10, output_attentions=output_attentions
) as engine:
result = engine.translate("This is a test string")
assert result.translation == "skaberskaber Dollar Dollar Dollar ፤ gerekir gerekir"
assert result.confidences[0] == approx(1.08e-05, 0.01)
assert str(result.alignment) == "2-0 2-1 2-2 2-3 4-4 4-5 4-6 4-7"
assert str(result.alignment) == ("2-0 2-1 2-2 2-3 4-4 4-5 4-6 4-7" if output_attentions else "")


def test_construct_invalid_lang() -> None:
@mark.parametrize("output_attentions", [True, False])
def test_construct_invalid_lang(output_attentions: bool) -> None:
with raises(ValueError):
HuggingFaceNmtEngine("stas/tiny-m2m_100", src_lang="qaa", tgt_lang="es")
HuggingFaceNmtEngine("stas/tiny-m2m_100", src_lang="qaa", tgt_lang="es", output_attentions=output_attentions)
Loading