Merged
88 commits
afce64e
make nemo recognize sequence_parallel_size
xrennvidia Jun 6, 2023
3f98473
merge with main
xrennvidia Jun 6, 2023
e313000
add helper functions to set up SP running in TE
xrennvidia Jun 6, 2023
52ff102
Merge branch 'main' into xren/extend_sp
xrennvidia Jun 6, 2023
5580955
slice seq length for a specific rank
xrennvidia Jun 8, 2023
887c615
Merge branch 'main' into xren/extend_sp
xrennvidia Jun 8, 2023
ebd6323
fix data_parallel_size calculation
xrennvidia Jun 8, 2023
58cca3d
minor change
xrennvidia Jun 8, 2023
87f027a
add missing argument of self
xrennvidia Jun 8, 2023
9ebfcf7
pass sp_global_ranks to TE transformer layer
xrennvidia Jun 8, 2023
728fd43
fix nsys setting
xrennvidia Jun 9, 2023
66615e8
fix seq_len calculation
xrennvidia Jun 13, 2023
e1f5eb7
fix attn_mask split across seq-length dim
xrennvidia Jun 17, 2023
cf0c75c
code update of input split
xrennvidia Jun 17, 2023
b57e218
fix loss calculation
xrennvidia Jun 21, 2023
69f4ae8
fix loss_mask_sum calculation
xrennvidia Jun 21, 2023
a38dd9a
fix loss calculation
xrennvidia Jun 22, 2023
b31e31f
merge with main
xrennvidia Jun 22, 2023
8ac42f1
rename sequence_parallelism to context_parallelism
xrennvidia Jun 22, 2023
f7c9b5b
minor change
xrennvidia Jun 24, 2023
49b1052
fix loss_mask_sum calculation
xrennvidia Jun 24, 2023
ae889fc
merge with main
xrennvidia Aug 1, 2023
2c43687
make sure do not call megatron-core parallel_state while cp_size is 1
xrennvidia Aug 3, 2023
25bf369
Merge branch 'main' into xren/context_parallelism
xrennvidia Aug 3, 2023
61af551
slice position embedding for different CP rank
xrennvidia Aug 3, 2023
dc8a540
fix missing property decorator
xrennvidia Aug 3, 2023
46479c6
typo fix
xrennvidia Aug 3, 2023
b64b563
fix rpe_bias CP slicing
xrennvidia Aug 4, 2023
0362de6
Merge branch 'main' into xren/context_parallelism
xrennvidia Aug 6, 2023
e1654fb
code style fix
xrennvidia Aug 6, 2023
4f0a3be
fix loss_mask_sum calculation
xrennvidia Aug 8, 2023
c46b42e
Merge branch 'main' into xren/context_parallelism
xrennvidia Aug 8, 2023
b6db8f3
merge with main
xrennvidia Aug 21, 2023
4076d06
do not load attention mask if it's not needed
xrennvidia Aug 22, 2023
3353e13
cherry pick attention mask data loader skip
xrennvidia Aug 22, 2023
433f6a7
bug fix
xrennvidia Aug 23, 2023
c4592e8
Merge branch 'main' into xren/context_parallelism
xrennvidia Aug 25, 2023
5efaa76
fix ubuf size with CP > 1
xrennvidia Sep 5, 2023
006677d
address naming confusion of mixed dp and cp
xrennvidia Sep 14, 2023
d64b85d
merge with main
xrennvidia Sep 14, 2023
499f0d6
Merge branch 'main' into xren/context_parallelism
xrennvidia Sep 25, 2023
693b8b7
Merge branch 'main' into xren/context_parallelism
xrennvidia Sep 30, 2023
0f7d079
rewrite cp code by assuming with_context_parallel=False
xrennvidia Oct 3, 2023
4dcfdb6
Merge branch 'main' into xren/context_parallelism
xrennvidia Oct 3, 2023
3351953
pop context_parallel from dist opt kwargs
xrennvidia Oct 3, 2023
08f785b
make sure amax reduction group is aware of context parallelism
xrennvidia Oct 5, 2023
e277b3d
remove use_fp8 from initialize_model_parallel
xrennvidia Oct 5, 2023
a27155c
Merge branch 'main' into xren/context_parallelism
xrennvidia Oct 5, 2023
dc65d34
make implementations of setup_transformer_engine_tp_groups and setup_…
xrennvidia Oct 6, 2023
42a6b83
Merge branch 'main' into xren/context_parallelism
xrennvidia Oct 10, 2023
5013189
cp function renaming
xrennvidia Oct 11, 2023
52dd50b
make loss logging broadcast aware of cp
xrennvidia Oct 13, 2023
b61fa4e
Merge branch 'main' into xren/context_parallelism
xrennvidia Oct 13, 2023
52381e8
fix a typo
xrennvidia Oct 13, 2023
fb9cc3d
Merge branch 'xren/context_parallelism' of github.com:xrennvidia/NeMo…
xrennvidia Oct 13, 2023
1b92952
Merge branch 'main' into xren/context_parallelism
xrennvidia Oct 14, 2023
e394392
var name fix
xrennvidia Oct 14, 2023
f9bf0d8
import transformer layer specs from MCore
xrennvidia Oct 16, 2023
1f8815f
upgrade MCore version
xrennvidia Oct 17, 2023
f1bc1a7
Merge branch 'main' into xren/context_parallelism
xrennvidia Oct 17, 2023
a40b183
merge with main
xrennvidia Oct 17, 2023
d15ae17
add context_parallel into the kwargs of dist opt
xrennvidia Oct 17, 2023
8ae9061
merge with main
xrennvidia Oct 19, 2023
6be25b9
Merge branch 'xren/context_parallelism' of github.com:NVIDIA/NeMo int…
xrennvidia Oct 19, 2023
4cbdb0e
Merge branch 'main' into xren/context_parallelism
xrennvidia Oct 25, 2023
55b7e13
remove redundant cp check
xrennvidia Oct 25, 2023
840103e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 25, 2023
03b2922
code style fix
xrennvidia Oct 25, 2023
14a589e
Merge branch 'xren/context_parallelism' of github.com:NVIDIA/NeMo int…
xrennvidia Oct 25, 2023
7c5b9c1
recover docker file
xrennvidia Oct 25, 2023
50d0385
Merge branch 'main' into xren/context_parallelism
xrennvidia Nov 1, 2023
45002d4
Merge branch 'main' into xren/context_parallelism
xrennvidia Nov 9, 2023
319e659
merge with main
xrennvidia Nov 17, 2023
071d234
Merge branch 'main' into xren/context_parallelism
xrennvidia Nov 18, 2023
baafb02
Merge branch 'main' into xren/context_parallelism
xrennvidia Nov 22, 2023
bf100fc
Merge branch 'main' into xren/context_parallelism
xrennvidia Nov 23, 2023
2da819e
fix seq_length of CP
xrennvidia Nov 27, 2023
cd7021a
merge with main
xrennvidia Dec 4, 2023
22eeaf9
recover seq-length which has been fixed in mcore
xrennvidia Dec 4, 2023
b56ce02
merge with main
xrennvidia Dec 16, 2023
3a29733
merge with main
xrennvidia Dec 18, 2023
5d25e67
function name fix
xrennvidia Dec 19, 2023
ead55a0
Merge branch 'main' into xren/context_parallelism
xrennvidia Dec 21, 2023
3a36003
merge with main
xrennvidia Jan 2, 2024
2d42b1c
Merge branch 'main' into xren/context_parallelism
xrennvidia Jan 3, 2024
f66a5aa
merge with main
xrennvidia Jan 4, 2024
5d464c9
Merge branch 'main' into xren/context_parallelism
xrennvidia Jan 5, 2024
2c9c95e
Merge branch 'main' into xren/context_parallelism
xrennvidia Jan 9, 2024
@@ -1330,7 +1330,7 @@ def get_samples_mapping(
)
torch.distributed.barrier()
counts = torch.cuda.LongTensor([1])
torch.distributed.all_reduce(counts, group=parallel_state.get_data_parallel_group())
torch.distributed.all_reduce(counts, group=parallel_state.get_data_parallel_group(with_context_parallel=True))
torch.distributed.all_reduce(counts, group=parallel_state.get_pipeline_model_parallel_group())
assert counts[0].item() == (
torch.distributed.get_world_size()
48 changes: 33 additions & 15 deletions nemo/collections/nlp/data/language_modeling/megatron/gpt_dataset.py
@@ -313,6 +313,7 @@ def __init__(
self.indexed_dataset = indexed_dataset
self.drop_last = drop_last
self.seq_length = seq_length
self.get_attention_mask_from_fusion = cfg.get('get_attention_mask_from_fusion', True)

# Checks
assert np.min(documents) >= 0
@@ -433,13 +434,21 @@ def __getitem__(self, idx):
logging.debug('Got negative index. Masking loss from this sample')
loss_mask = torch.zeros_like(loss_mask)

return {
'tokens': tokens,
'labels': labels,
'attention_mask': attention_mask,
'loss_mask': loss_mask,
'position_ids': position_ids,
}
if self.get_attention_mask_from_fusion:
return {
'tokens': tokens,
'labels': labels,
'loss_mask': loss_mask,
'position_ids': position_ids,
}
else:
return {
'tokens': tokens,
'labels': labels,
'attention_mask': attention_mask,
'loss_mask': loss_mask,
'position_ids': position_ids,
}


class MockGPTDataset(Dataset):
@@ -457,6 +466,7 @@ def __init__(
self.vocab_size = tokenizer.vocab_size
self.length = num_samples
self.seed = seed
self.get_attention_mask_from_fusion = cfg.get('get_attention_mask_from_fusion', True)

self.attention_mask = torch.tril(torch.ones((self.seq_length, self.seq_length))).unsqueeze(0)
self.attention_mask = self.attention_mask < 0.5
@@ -476,13 +486,21 @@ def __getitem__(self, idx):
tokens = torch.from_numpy(np_gen.integers(self.vocab_size, size=[self.seq_length], dtype=np.int64))
labels = torch.from_numpy(np_gen.integers(self.vocab_size, size=[self.seq_length], dtype=np.int64))

return {
'tokens': tokens,
'labels': labels,
'attention_mask': self.attention_mask,
'loss_mask': self.loss_mask,
'position_ids': self.position_ids,
}
if self.get_attention_mask_from_fusion:
return {
'tokens': tokens,
'labels': labels,
'loss_mask': self.loss_mask,
'position_ids': self.position_ids,
}
else:
return {
'tokens': tokens,
'labels': labels,
'attention_mask': self.attention_mask,
'loss_mask': self.loss_mask,
'position_ids': self.position_ids,
}


@torch.no_grad()
@@ -674,7 +692,7 @@ def _build_index_mappings(

torch.distributed.barrier()
counts = torch.cuda.LongTensor([1])
torch.distributed.all_reduce(counts, group=parallel_state.get_data_parallel_group())
torch.distributed.all_reduce(counts, group=parallel_state.get_data_parallel_group(with_context_parallel=True))
torch.distributed.all_reduce(counts, group=parallel_state.get_pipeline_model_parallel_group())
assert counts[0].item() == (
torch.distributed.get_world_size()
@@ -169,6 +169,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer, no_lm_init=True):
pipeline_model_parallel_size=cfg.get('pipeline_model_parallel_size', 1),
virtual_pipeline_model_parallel_size=vp_size,
pipeline_model_parallel_split_rank=cfg.get('pipeline_model_parallel_split_rank', 0),
context_parallel_size=cfg.get('context_parallel_size', 1),
micro_batch_size=cfg.get('micro_batch_size'),
global_batch_size=cfg.get('global_batch_size'),
rampup_batch_size=cfg.get('rampup_batch_size', None),
@@ -231,6 +232,27 @@ def setup_transformer_engine_tp_groups(self):
tp_group = parallel_state.get_tensor_model_parallel_group()
child.set_tensor_parallel_group(tp_group)

def setup_transformer_engine_cp_groups(self):
""" This should be called after context parallel groups have been initialized
and only needs to be called when using Transformer Engine.
"""
cp_stream = torch.cuda.Stream()

for module in self.get_model_module_list():
"""Set context parallel running
Copied from: https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/transformer.py
"""
# Deep iterate but skip self to avoid infinite recursion.
for index, child in enumerate(module.modules()):
if index == 0:
continue
if hasattr(child, "set_context_parallel_group"):
child.set_context_parallel_group(
parallel_state.get_context_parallel_group(),
parallel_state.get_context_parallel_global_ranks(),
cp_stream,
)

def _wrap_model_for_O2(self):
""" Wraps self.model in a float16 wrapper if the model is using megatron amp O2.
Args:
@@ -556,8 +578,10 @@ def allreduce_gradients(self):
bucket = buckets[tp]
grads = [param.grad.data for param in bucket]
coalesced = torch._utils._flatten_dense_tensors(grads)
coalesced /= parallel_state.get_data_parallel_world_size()
torch.distributed.all_reduce(coalesced, group=parallel_state.get_data_parallel_group())
coalesced /= parallel_state.get_data_parallel_world_size(with_context_parallel=True)
torch.distributed.all_reduce(
coalesced, group=parallel_state.get_data_parallel_group(with_context_parallel=True)
)
for buf, synced in zip(grads, torch._utils._unflatten_dense_tensors(coalesced, grads)):
buf.copy_(synced)

@@ -633,7 +657,6 @@ def setup_optimization(
):
optim_kwargs = {} if optim_kwargs is None else optim_kwargs.copy()
if self.with_distributed_adam:

# Allocate contiguous buffer to avoid extra copies
optim_kwargs['contiguous_grad_buffer'] = True

62 changes: 50 additions & 12 deletions nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
@@ -278,7 +278,8 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
# Convert the global-batch-based profile index to micro-batch index
if hasattr(self, '_nsys_profile_enabled'):
mp_size = cfg.get('tensor_model_parallel_size', 1) * cfg.get('pipeline_model_parallel_size', 1)
data_parallel_world_size = trainer.world_size // mp_size
cp_size = cfg.get('context_parallel_size', 1)
data_parallel_world_size = trainer.world_size // (mp_size * cp_size)
grad_accum_steps = cfg.get('global_batch_size') // (cfg.get('micro_batch_size') * data_parallel_world_size)
self._nsys_profile_start_step *= grad_accum_steps
self._nsys_profile_end_step *= grad_accum_steps
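As a back-of-the-envelope check of the profile-step conversion above, here is a standalone sketch with made-up sizes (world size 16, TP 2, PP 2, CP 2, micro batch 2, global batch 32 are assumptions for illustration, not values from this PR): once context parallelism joins the rank decomposition, the data-parallel world size shrinks by cp_size and the number of gradient-accumulation micro-batches per global batch grows accordingly.

# Hypothetical sizes, chosen only to illustrate the arithmetic in the diff above.
world_size = 16
tensor_model_parallel_size = 2
pipeline_model_parallel_size = 2
context_parallel_size = 2
micro_batch_size = 2
global_batch_size = 32

mp_size = tensor_model_parallel_size * pipeline_model_parallel_size
# Each data-parallel replica now spans cp_size additional ranks.
data_parallel_world_size = world_size // (mp_size * context_parallel_size)             # 2
grad_accum_steps = global_batch_size // (micro_batch_size * data_parallel_world_size)  # 8

# A global-batch-based nsys window of [10, 12) becomes a micro-batch window of [80, 96).
print(10 * grad_accum_steps, 12 * grad_accum_steps)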
@@ -553,7 +554,9 @@ def initialize_ub_func(self):
)

input_shape = [
self.cfg.get('encoder_seq_length') * self.cfg.get('micro_batch_size'),
self.cfg.get('encoder_seq_length')
* self.cfg.get('micro_batch_size')
// self.cfg.get('context_parallel_size', 1),
self.cfg.get('hidden_size'),
]

@@ -834,6 +837,32 @@ def __next__(self):
# TODO @tmoon: Use once available in Megatron-LM
# return DataIteratorList(iters)

def get_batch_on_this_context_parallel_rank(self, batch):
cp_size = self.cfg.get('context_parallel_size', 1)
num_valid_tokens_in_ub = None
if 'loss_mask' in batch and batch['loss_mask'] is not None:
num_valid_tokens_in_ub = batch['loss_mask'].sum()

if cp_size > 1:
cp_rank = parallel_state.get_context_parallel_rank()
for key, val in batch.items():
if val is not None:
seq_dim = 1 if key != 'attention_mask' else 2
val = val.view(
*val.shape[0:seq_dim],
2 * cp_size,
val.shape[seq_dim] // (2 * cp_size),
*val.shape[(seq_dim + 1) :],
)
index = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)], device=val.device)
val = val.index_select(seq_dim, index)
val = val.view(*val.shape[0:seq_dim], -1, *val.shape[(seq_dim + 2) :])
batch[key] = val

batch['num_valid_tokens_in_ub'] = num_valid_tokens_in_ub

return batch
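
A minimal self-contained sketch of the slicing pattern in get_batch_on_this_context_parallel_rank above: the sequence dimension is viewed as 2 * cp_size chunks, and every rank keeps chunk cp_rank together with its mirror chunk 2 * cp_size - cp_rank - 1, which balances causal-attention work across ranks. The toy sizes (sequence length 8, cp_size 2) are assumptions for illustration, and no distributed setup is needed.

import torch

def slice_for_cp_rank(val: torch.Tensor, cp_size: int, cp_rank: int, seq_dim: int = 1) -> torch.Tensor:
    # View the sequence dimension as 2 * cp_size chunks of equal length.
    val = val.view(*val.shape[:seq_dim], 2 * cp_size, val.shape[seq_dim] // (2 * cp_size), *val.shape[seq_dim + 1 :])
    # Keep this rank's chunk plus its mirror chunk from the other end of the sequence.
    index = torch.tensor([cp_rank, 2 * cp_size - cp_rank - 1], device=val.device)
    val = val.index_select(seq_dim, index)
    return val.view(*val.shape[:seq_dim], -1, *val.shape[seq_dim + 2 :])

tokens = torch.arange(8).unsqueeze(0)                   # [[0, 1, 2, 3, 4, 5, 6, 7]]
print(slice_for_cp_rank(tokens, cp_size=2, cp_rank=0))  # [[0, 1, 6, 7]]
print(slice_for_cp_rank(tokens, cp_size=2, cp_rank=1))  # [[2, 3, 4, 5]]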

def get_forward_output_and_loss_func(self, validation_step=False):
def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_layers=None):

@@ -852,15 +881,17 @@ def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_layers=None):
required_keys.update(('tokens', 'position_ids'))
if parallel_state.is_pipeline_last_stage():
required_keys.update(('labels', 'loss_mask'))
if self.get_attention_mask_from_fusion:
if self.get_attention_mask_from_fusion and 'attention_mask' in required_keys:
required_keys.remove('attention_mask')
batch = {key: val.cuda(non_blocking=True) if key in required_keys else None for key, val in batch.items()}

batch = self.get_batch_on_this_context_parallel_rank(batch)

# Model forward pass
forward_args = {
'input_ids': batch['tokens'],
'position_ids': batch['position_ids'],
'attention_mask': batch['attention_mask'],
'attention_mask': None if self.get_attention_mask_from_fusion else batch['attention_mask'],
'labels': batch['labels'],
'loss_mask': batch['loss_mask'],
}
@@ -885,9 +916,10 @@ def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_layers=None):

def loss_func(output_tensor):
# Loss for a micro-batch (ub)
loss_for_ub = self.loss_func(batch['loss_mask'], output_tensor)
loss_for_ub = self.loss_func(batch['loss_mask'], batch['num_valid_tokens_in_ub'], output_tensor)
cp_size = self.cfg.get('context_parallel_size', 1)
if validation_step and not self.cfg.data.get('validation_drop_last', True):
num_valid_tokens_in_ub = batch['loss_mask'].sum()
num_valid_tokens_in_ub = batch['num_valid_tokens_in_ub']
if loss_for_ub.isnan():
assert batch['loss_mask'].count_nonzero() == 0, 'Got NaN loss with non-empty input'
loss_sum_for_ub = torch.zeros_like(num_valid_tokens_in_ub)
Expand All @@ -904,10 +936,10 @@ def loss_func(output_tensor):
torch.distributed.all_reduce(
loss_sum_and_ub_size_all_gpu, group=parallel_state.get_data_parallel_group()
)
return loss_for_ub, {'loss_sum_and_ub_size': loss_sum_and_ub_size_all_gpu}
return loss_for_ub * cp_size, {'loss_sum_and_ub_size': loss_sum_and_ub_size_all_gpu}
else:
reduced_loss = average_losses_across_data_parallel_group([loss_for_ub])
return loss_for_ub, {'avg': reduced_loss}
return loss_for_ub * cp_size, {'avg': reduced_loss}

return output_tensor, loss_func

@@ -1007,10 +1039,11 @@ def on_validation_epoch_end(self):
if parallel_state.get_pipeline_model_parallel_world_size() > 1:
if self.loss_broadcast_src_rank is None:
dp_size = parallel_state.get_data_parallel_world_size()
cp_size = parallel_state.get_context_parallel_world_size()
tp_size = parallel_state.get_tensor_model_parallel_world_size()
pp_size = parallel_state.get_pipeline_model_parallel_world_size()
rank_in_dp_tp_group = torch.distributed.get_rank() % (dp_size * tp_size)
last_pipeline_stage_offset = (tp_size * dp_size) * (pp_size - 1)
rank_in_dp_tp_group = torch.distributed.get_rank() % (dp_size * cp_size * tp_size)
last_pipeline_stage_offset = (tp_size * cp_size * dp_size) * (pp_size - 1)
self.loss_broadcast_src_rank = last_pipeline_stage_offset + rank_in_dp_tp_group
torch.distributed.broadcast(
averaged_loss, self.loss_broadcast_src_rank, group=parallel_state.get_pipeline_model_parallel_group(),
@@ -1029,11 +1062,14 @@ def on_test_epoch_end(self):
logging.info(f'test_loss: {averaged_loss[0]}')
self.test_step_outputs.clear() # free memory

def loss_func(self, loss_mask, output_tensor):
def loss_func(self, loss_mask, num_valid_tokens_in_ub, output_tensor):
losses = output_tensor.float()
loss_mask = loss_mask.view(-1).float()
# TODO: add nemo version here
loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum() # sequence level nll
loss = torch.sum(losses.view(-1) * loss_mask) / num_valid_tokens_in_ub # sequence level nll
cp_size = self.cfg.get('context_parallel_size', 1)
if cp_size > 1:
torch.distributed.all_reduce(loss, group=parallel_state.get_context_parallel_group())
return loss
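
A single-process sketch of the loss bookkeeping above, with the CP all-reduce replaced by a plain Python sum: each rank divides its partial masked loss by the shared num_valid_tokens_in_ub (computed before the sequence was split), so summing across ranks reproduces the mean over the full sequence. Values and the rank-to-token assignment are toy assumptions. The additional loss_for_ub * cp_size factor in the training path is not demonstrated here; it appears to offset averaging performed later over the combined data- and context-parallel group (an assumption, not verified in this sketch).

import torch

losses = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])     # per-token losses, full sequence
loss_mask = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0])  # last two tokens are padding
num_valid_tokens_in_ub = loss_mask.sum()                             # computed before CP slicing

# Each CP rank owns two mirrored chunks of the sequence (see the slicing sketch above, cp_size = 2).
rank_token_ids = {0: [0, 1, 6, 7], 1: [2, 3, 4, 5]}

per_rank_loss = [
    torch.sum(losses[ids] * loss_mask[ids]) / num_valid_tokens_in_ub for ids in rank_token_ids.values()
]
cp_reduced = sum(per_rank_loss)                                      # stands in for the CP all-reduce
full_reference = torch.sum(losses * loss_mask) / num_valid_tokens_in_ub
assert torch.isclose(cp_reduced, full_reference)                     # both equal the full-sequence mean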

def build_train_valid_test_datasets(self):
Expand Down Expand Up @@ -1185,6 +1221,7 @@ def setup(self, stage=None):

if self.cfg.get('transformer_engine', False) or self.cfg.get('mcore_gpt', False):
self.setup_transformer_engine_tp_groups()
self.setup_transformer_engine_cp_groups()

def setup_training_data(self, cfg):
if hasattr(self, '_train_ds'):
@@ -1243,6 +1280,7 @@ def dummy():

if self.cfg.get('transformer_engine', False):
self.setup_transformer_engine_tp_groups()
self.setup_transformer_engine_cp_groups()

# set the default sampling params if it is None.
# default do greedy sampling
5 changes: 4 additions & 1 deletion nemo/collections/nlp/modules/common/megatron/build_model.py
@@ -151,7 +151,10 @@ def build_model(
i = torch.cuda.current_device()
model = [
torch.nn.parallel.distributed.DistributedDataParallel(
model_module, device_ids=[i], output_device=i, process_group=parallel_state.get_data_parallel_group(),
model_module,
device_ids=[i],
output_device=i,
process_group=parallel_state.get_data_parallel_group(with_context_parallel=True),
)
for model_module in model
]
25 changes: 25 additions & 0 deletions nemo/collections/nlp/modules/common/megatron/language_model.py
@@ -535,6 +535,7 @@ def __init__(
self.position_embedding_type = position_embedding_type
self.share_embeddings_and_output_weights = share_embeddings_and_output_weights
self.sequence_parallel = config.sequence_parallel
self.context_parallel = parallel_state.get_context_parallel_world_size() > 1
if kv_channels is None:

assert (
@@ -722,6 +723,19 @@ def set_input_tensor(self, input_tensor):

self.encoder.set_input_tensor(input_tensor[0])

def get_position_embedding_on_this_context_parallel_rank(self, position_embedding, seq_dim):
cp_size = parallel_state.get_context_parallel_world_size()
cp_rank = parallel_state.get_context_parallel_rank()
cp_idx = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)], device=position_embedding.device)
position_embedding = position_embedding.view(
*position_embedding.shape[:seq_dim], 2 * cp_size, -1, *position_embedding.shape[(seq_dim + 1) :]
)
position_embedding = position_embedding.index_select(seq_dim, cp_idx)
position_embedding = position_embedding.view(
*position_embedding.shape[:seq_dim], -1, *position_embedding.shape[(seq_dim + 2) :]
)
return position_embedding
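
The same mirrored-chunk split is applied to position embeddings so that each rank's RoPE (or relative-bias) rows line up with the token chunks it owns. A standalone sketch with assumed toy sizes (sequence length 8, head dim 4, cp_size 2), using a simplified 2-D rotary table rather than the multi-dimensional cache the real code handles:

import torch

def slice_pos_emb_for_cp_rank(pos_emb: torch.Tensor, cp_size: int, cp_rank: int, seq_dim: int = 0) -> torch.Tensor:
    cp_idx = torch.tensor([cp_rank, 2 * cp_size - cp_rank - 1], device=pos_emb.device)
    pos_emb = pos_emb.view(*pos_emb.shape[:seq_dim], 2 * cp_size, -1, *pos_emb.shape[seq_dim + 1 :])
    pos_emb = pos_emb.index_select(seq_dim, cp_idx)
    return pos_emb.view(*pos_emb.shape[:seq_dim], -1, *pos_emb.shape[seq_dim + 2 :])

rotary = torch.arange(8.0).unsqueeze(-1).repeat(1, 4)  # row i encodes position i
# Rank 1 ends up with positions 2..5, matching the tokens it keeps in the batch-slicing sketch.
print(slice_pos_emb_for_cp_rank(rotary, cp_size=2, cp_rank=1)[:, 0])  # tensor([2., 3., 4., 5.])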

def forward(
self,
enc_input_ids,
@@ -775,10 +789,16 @@ def forward(
else:
enc_seq_length = encoder_input.size(0)

if self.context_parallel:
enc_seq_length = enc_seq_length * parallel_state.get_context_parallel_world_size()

rotary_pos_emb = None
encoder_self_attention_relative_position_bias = None
if self.position_embedding_type == 'rope':
rotary_pos_emb = self.rotary_pos_emb(enc_seq_length)

if self.context_parallel:
rotary_pos_emb = self.get_position_embedding_on_this_context_parallel_rank(rotary_pos_emb, 0)
elif (
self.position_embedding_type == 'alibi'
or self.position_embedding_type == 'sandwich'
@@ -790,6 +810,11 @@ def forward(
# causal attention bias: [1, head, 1, k]
# non-causal attention bias: [1, head, q, k]

if self.context_parallel and encoder_self_attention_relative_position_bias.shape[-2] > 1:
encoder_self_attention_relative_position_bias = self.get_position_embedding_on_this_context_parallel_rank(
encoder_self_attention_relative_position_bias, 2
)

# encoder.
if enc_hidden_states is None:
encoder_output = self.encoder(