16 changes: 13 additions & 3 deletions examples/nlp/language_modeling/megatron_gpt_eval.py
@@ -24,7 +24,7 @@
 from nemo.collections.nlp.modules.common.text_generation_server import MegatronServer
 from nemo.collections.nlp.modules.common.text_generation_utils import generate
 from nemo.collections.nlp.modules.common.transformer.text_generation import LengthParam, SamplingParam
-from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy
+from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy, NLPSaveRestoreConnector
 from nemo.core.config import hydra_runner
 from nemo.utils.app_state import AppState
 from nemo.utils.model_utils import inject_model_parallel_rank
@@ -160,16 +160,26 @@ def main(cfg) -> None:
     ), "devices * num_nodes should equal tensor_model_parallel_size * pipeline_model_parallel_size"
 
     if cfg.gpt_model_file:
+        save_restore_connector = NLPSaveRestoreConnector()
+        if os.path.isdir(cfg.gpt_model_file):
+            save_restore_connector.model_extracted_dir = cfg.gpt_model_file
+
         pretrained_cfg = MegatronGPTModel.restore_from(
-            restore_path=cfg.gpt_model_file, trainer=trainer, return_config=True
+            restore_path=cfg.gpt_model_file,
+            trainer=trainer,
+            return_config=True,
+            save_restore_connector=save_restore_connector,
         )
         OmegaConf.set_struct(pretrained_cfg, True)
         with open_dict(pretrained_cfg):
             pretrained_cfg.sequence_parallel = False
             pretrained_cfg.activations_checkpoint_granularity = None
             pretrained_cfg.activations_checkpoint_method = None
         model = MegatronGPTModel.restore_from(
-            restore_path=cfg.gpt_model_file, trainer=trainer, override_config_path=pretrained_cfg
+            restore_path=cfg.gpt_model_file,
+            trainer=trainer,
+            override_config_path=pretrained_cfg,
+            save_restore_connector=save_restore_connector,
         )
     elif cfg.checkpoint_dir:
         app_state = AppState()
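
For context, this change lets gpt_model_file point either at a packed .nemo archive or at a directory the archive was already extracted into. A minimal restore sketch under that assumption (hypothetical model path and trainer settings, not code from this PR):

    import os

    from pytorch_lightning.trainer.trainer import Trainer

    from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
    from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy, NLPSaveRestoreConnector

    # Hypothetical path: either a packed .nemo archive or the directory it was untarred into.
    gpt_model_file = "/models/megatron_gpt_345m.nemo"

    trainer = Trainer(strategy=NLPDDPStrategy(), devices=1, accelerator='gpu', num_nodes=1)

    save_restore_connector = NLPSaveRestoreConnector()
    if os.path.isdir(gpt_model_file):
        # Point the connector at the extracted contents so restore_from() reuses them
        # instead of unpacking the archive again.
        save_restore_connector.model_extracted_dir = gpt_model_file

    model = MegatronGPTModel.restore_from(
        restore_path=gpt_model_file, trainer=trainer, save_restore_connector=save_restore_connector,
    )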
39 changes: 31 additions & 8 deletions nemo/collections/nlp/modules/common/text_generation_utils.py
@@ -14,6 +14,8 @@
 
 """Utilities for generating text."""
 
+from collections.abc import Iterable
+
 import torch
 import torch.nn.functional as F
 
@@ -454,6 +456,21 @@ def generate(
         repetition_penalty=repetition_penalty,
         min_tokens_to_generate=min_tokens_to_generate,
     )
+    special_tokens = set()
+    if hasattr(tokenizer, 'pad_token') and tokenizer.pad_token is not None:
+        special_tokens.add(tokenizer.pad_token)
+    if hasattr(tokenizer, 'eos_token') and tokenizer.eos_token is not None:
+        special_tokens.add(tokenizer.eos_token)
+    if hasattr(tokenizer, 'bos_token') and tokenizer.bos_token is not None:
+        special_tokens.add(tokenizer.bos_token)
+    if hasattr(tokenizer, 'cls_token') and tokenizer.cls_token is not None:
+        special_tokens.add(tokenizer.cls_token)
+    if hasattr(tokenizer, 'unk_token') and tokenizer.unk_token is not None:
+        special_tokens.add(tokenizer.unk_token)
+    if hasattr(tokenizer, 'sep_token') and tokenizer.sep_token is not None:
+        special_tokens.add(tokenizer.sep_token)
+    if hasattr(tokenizer, 'mask_token') and tokenizer.mask_token is not None:
+        special_tokens.add(tokenizer.mask_token)
     if output is not None:
         decode_tokens, output_logits, full_logits = output
         resp_sentences = []
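
The seven added checks all follow the same pattern; an equivalent, more compact way to build the same set (a sketch, not part of the PR) would be:

    # Equivalent sketch of the special-token collection above (not code from this PR).
    special_tokens = set()
    for attr in ('pad_token', 'eos_token', 'bos_token', 'cls_token', 'unk_token', 'sep_token', 'mask_token'):
        token = getattr(tokenizer, attr, None)
        if token is not None:
            special_tokens.add(token)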
@@ -466,25 +483,31 @@
             if not isinstance(tokenizer, TabularTokenizer):
                 words = []
                 for token in decode_token:
-                    # Skip any soft prompt pseudo tokens
-                    if token not in tokenizer.tokenizer.decoder:
-                        continue
-                    word = tokenizer.tokenizer.decoder[token]
-                    word = bytearray([tokenizer.tokenizer.byte_decoder[c] for c in word]).decode(
-                        'utf-8', errors='replace'
-                    )
+                    if not isinstance(token, Iterable):
+                        token = [token]
+                    word = tokenizer.ids_to_tokens(token)
+                    if isinstance(word, Iterable):
+                        word = word[0]
+                    if hasattr(tokenizer.tokenizer, 'byte_decoder'):
+                        word = bytearray([tokenizer.tokenizer.byte_decoder[c] for c in word]).decode(
+                            'utf-8', errors='replace'
+                        )
Collaborator review comment on this hunk:
Need to replace

        # offsets calculation
        all_offsets = []
        for item in resp_sentences_seg:
            offsets = [0]
            for index, token in enumerate(item):
                if index != len(item) - 1:
                    offsets.append(len(token) + offsets[-1])
            all_offsets.append(offsets)

with

        # offsets calculation
        all_offsets = []
        for item in resp_sentences_seg:
            offsets = [0]
            for index, token in enumerate(item):
                if index != len(item) - 1:
                    if token in special_tokens:
                        offsets.append(offsets[-1])
                    else:
                        offsets.append(len(token) + offsets[-1])
            all_offsets.append(offsets)

at line 481

                     words.append(word)
                 resp_sentences_seg.append(words)
             else:
                 words = tokenizer.text_to_tokens(sentence)
                 resp_sentences_seg.append(words)
 
         # offsets calculation
         all_offsets = []
         for item in resp_sentences_seg:
             offsets = [0]
             for index, token in enumerate(item):
                 if index != len(item) - 1:
-                    offsets.append(len(token) + offsets[-1])
+                    if token in special_tokens:
+                        offsets.append(offsets[-1])
+                    else:
+                        offsets.append(len(token) + offsets[-1])
             all_offsets.append(offsets)
 
         output = {}
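
To make the effect of the offsets change concrete, here is a small self-contained sketch of the updated calculation; the token strings and the special token are made up for illustration:

    # Illustrative sketch of the updated offsets calculation (tokens are made up).
    special_tokens = {'<|endoftext|>'}
    item = ['Hello', ' world', '<|endoftext|>', '!']

    offsets = [0]
    for index, token in enumerate(item):
        if index != len(item) - 1:
            if token in special_tokens:
                # Do not advance the running character offset for special tokens,
                # which are typically stripped from the detokenized text.
                offsets.append(offsets[-1])
            else:
                offsets.append(len(token) + offsets[-1])

    print(offsets)  # [0, 5, 11, 11]

Before this change the special token would have advanced the offset by its string length (13 characters here), misaligning every offset that follows it.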