154 changes: 86 additions & 68 deletions nemo_rl/distributed/model_utils.py
@@ -18,7 +18,6 @@
from torch.distributed.tensor import DTensor, distribute_tensor


-@torch.no_grad()
Contributor:
is it necessary to remove the @torch.no_grad()?

def _compute_distributed_log_softmax(
    vocab_parallel_logits: torch.Tensor, group: torch.distributed.ProcessGroup
) -> torch.Tensor:
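On the @torch.no_grad() question above: as far as I can tell, torch.autograd.Function already runs its forward with gradient tracking disabled, so the decorator should be redundant on the autograd path; whether it was dropped for torch.compile compatibility is not stated in the diff. A minimal standalone check (not part of this PR):

```python
import torch

class Probe(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # Expected to print False: grad mode is already disabled inside Function.forward.
        print("grad enabled inside forward:", torch.is_grad_enabled())
        return x * 2

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * 2

x = torch.randn(3, requires_grad=True)
Probe.apply(x).sum().backward()
```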
@@ -36,25 +35,67 @@ def _compute_distributed_log_softmax(
    log probabilities normalized across the full vocabulary dimension.
    """
    logits_max = torch.amax(vocab_parallel_logits, dim=-1, keepdim=True)
-    torch.distributed.all_reduce(
-        logits_max,
-        op=torch.distributed.ReduceOp.MAX,
-        group=group,
-    )
+    if group is not None:
+        torch.distributed.all_reduce(
+            logits_max,
+            op=torch.distributed.ReduceOp.MAX,
+            group=group,
+        )

    # Subtract the maximum value.
    vocab_parallel_logits = vocab_parallel_logits - logits_max

    sum_exp_logits = vocab_parallel_logits.exp().sum(-1, keepdim=True).float()

-    torch.distributed.all_reduce(
-        sum_exp_logits,
-        op=torch.distributed.ReduceOp.SUM,
-        group=group,
-    )
+    if group is not None:
+        torch.distributed.all_reduce(
+            sum_exp_logits,
+            op=torch.distributed.ReduceOp.SUM,
+            group=group,
+        )

    return vocab_parallel_logits - sum_exp_logits.log_().to(vocab_parallel_logits.dtype)
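The guarded all-reduces mean the helper degrades gracefully when group is None (single-rank case) and, with a group, reproduces a full-vocabulary log-softmax across the shards. A minimal single-process sketch of the math, emulating two vocab-parallel shards by hand rather than via torch.distributed:

```python
import torch

# Emulate two tensor-parallel ranks on one process: each holds half of an 8-token vocab.
full_logits = torch.randn(2, 3, 8)        # (batch, seq, vocab)
shards = full_logits.chunk(2, dim=-1)      # per-rank shards of size 4

# What the two all-reduces would compute: a global max and a global sum of exps.
global_max = torch.maximum(shards[0].amax(-1, keepdim=True), shards[1].amax(-1, keepdim=True))
global_sum_exp = sum((s - global_max).exp().sum(-1, keepdim=True) for s in shards)

# Each rank's local result, mirroring the function body above.
local_log_probs = [(s - global_max) - global_sum_exp.log() for s in shards]

# Concatenated across ranks, this matches an ordinary full-vocab log-softmax.
torch.testing.assert_close(
    torch.cat(local_log_probs, dim=-1), torch.log_softmax(full_logits, dim=-1)
)
```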

@torch.compile(fullgraph=True)
def distributed_logprob_forward(vocab_parallel_logits, target, vocab_start_index, vocab_end_index, tp_group):
    # Create a mask of valid vocab ids (1 means it needs to be masked).
    target_mask = (target < vocab_start_index) | (target >= vocab_end_index)
    masked_target = target - vocab_start_index
    masked_target[target_mask] = 0

    vocab_parallel_logits = vocab_parallel_logits.to(dtype=torch.float32)

    log_probs = _compute_distributed_log_softmax(vocab_parallel_logits, group=tp_group)
    softmax_output = log_probs.exp()

    log_probs = torch.gather(log_probs, -1, masked_target.unsqueeze(-1)).squeeze(-1)
    log_probs[target_mask] = 0.0

    if tp_group is not None:
        torch.distributed.all_reduce(
            log_probs,
            op=torch.distributed.ReduceOp.SUM,
            group=tp_group,
        )
    return log_probs, softmax_output
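For readers following the gather/all-reduce trick: each rank looks up the target only if the token id falls inside its vocab shard, contributes 0 otherwise, and the SUM all-reduce then yields log p(target) on every rank. A single-process sketch of that invariant (two emulated ranks, no process group):

```python
import torch

full_logits = torch.randn(2, 3, 8)
target = torch.randint(0, 8, (2, 3))
log_probs_full = torch.log_softmax(full_logits, dim=-1)

contributions = []
for vocab_start, vocab_end in ((0, 4), (4, 8)):            # two emulated ranks
    shard = log_probs_full[..., vocab_start:vocab_end]      # this rank's slice of the vocab
    mask = (target < vocab_start) | (target >= vocab_end)   # targets this rank does not own
    idx = (target - vocab_start).masked_fill(mask, 0)       # same masking as above
    local = torch.gather(shard, -1, idx.unsqueeze(-1)).squeeze(-1)
    local[mask] = 0.0                                        # non-owners contribute zero
    contributions.append(local)

# Summing the per-rank contributions (the SUM all-reduce) recovers log p(target).
expected = torch.gather(log_probs_full, -1, target.unsqueeze(-1)).squeeze(-1)
torch.testing.assert_close(sum(contributions), expected)
```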

@torch.compile(fullgraph=True)
def distributed_logprob_backward(grad_output, target, softmax, vocab_start_index, vocab_end_index):
    V = softmax.size(-1)

    target_mask = (target < vocab_start_index) | (target >= vocab_end_index)
    masked_target = target - vocab_start_index
    masked_target[target_mask] = 0

    is_chosen = (~target_mask).unsqueeze(-1) * torch.nn.functional.one_hot(
        masked_target, num_classes=V
    )

    grad_input = is_chosen.float().sub_(softmax)
    grad_input.mul_(grad_output.unsqueeze(-1))

    return grad_input
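The backward uses the standard cross-entropy-style identity: the gradient of the gathered log-prob with respect to the logits is one_hot(target) minus softmax, scaled by the incoming gradient (and zeroed for out-of-shard targets). A small non-distributed check against autograd:

```python
import torch

logits = torch.randn(2, 3, 5, requires_grad=True)
target = torch.randint(0, 5, (2, 3))
grad_out = torch.randn(2, 3)

picked = torch.gather(torch.log_softmax(logits, dim=-1), -1, target.unsqueeze(-1)).squeeze(-1)
picked.backward(grad_out)

# Manual gradient: (one_hot(target) - softmax) scaled by the upstream gradient.
manual = torch.nn.functional.one_hot(target, num_classes=5).float() - torch.softmax(
    logits.detach(), dim=-1
)
manual = manual * grad_out.unsqueeze(-1)
torch.testing.assert_close(logits.grad, manual)
```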

class DistributedLogprob(torch.autograd.Function):
"""Custom autograd function for computing log probabilities in a distributed setting.
@@ -72,28 +113,16 @@ def forward(  # pyrefly: ignore[bad-override] Always ignore torch.autograd.Func
        group: torch.distributed.ProcessGroup,
        inference_only: bool = False,
    ) -> torch.Tensor:
-        # Create a mask of valid vocab ids (1 means it needs to be masked).
-        target_mask = (target < vocab_start_index) | (target >= vocab_end_index)
-        masked_target = target - vocab_start_index
-        masked_target[target_mask] = 0
-
-        vocab_parallel_logits = vocab_parallel_logits.to(dtype=torch.float32)
-
-        log_probs = _compute_distributed_log_softmax(vocab_parallel_logits, group=group)
-        softmax_output = log_probs.exp()
-
-        log_probs = torch.gather(log_probs, -1, masked_target.unsqueeze(-1)).squeeze(-1)
-        log_probs[target_mask] = 0.0
-
-        torch.distributed.all_reduce(
-            log_probs,
-            op=torch.distributed.ReduceOp.SUM,
-            group=group,
+        log_probs, softmax = distributed_logprob_forward(
+            vocab_parallel_logits, target, vocab_start_index, vocab_end_index, group
        )

        if not inference_only:
            # only save for backward when we have inference only=False
-            ctx.save_for_backward(softmax_output, target_mask, masked_target)
+            ctx.save_for_backward(softmax, target)
+            ctx.vocab_start_index = vocab_start_index
+            ctx.vocab_end_index = vocab_end_index

        return log_probs

@@ -103,40 +132,18 @@ def backward(
        *grad_outputs: torch.Tensor,
    ) -> tuple[torch.Tensor, None, None, None, None, None, None]:
        grad_output = grad_outputs[0]
-        softmax, target_mask, masked_target = ctx.saved_tensors
+        softmax, target = ctx.saved_tensors
+        vocab_start_index = ctx.vocab_start_index
+        vocab_end_index = ctx.vocab_end_index

-        if softmax.ndim == 3:
-            B, S, V = softmax.shape
-
-            # skip `torch.nn.functional.one_hot`
-            row = (
-                torch.arange(B, device=softmax.device)
-                .view(-1, 1)
-                .expand(-1, S)
-                .reshape(-1)
-            )
-            col = torch.arange(S, device=softmax.device).expand(B, -1).reshape(-1)
-            flat_idx = (row * S + col) * V
-
-            flat_chosen = flat_idx.masked_select(
-                ~target_mask.reshape(-1)
-            ) + masked_target.masked_select(~target_mask)
-
-            # `neg` is zero-copy
-            grad_input = softmax.neg()
-            grad_input = grad_input.mul_(grad_output.unsqueeze(-1))
-
-            grad_output_selected = grad_output.masked_select(~target_mask)
-            grad_input.view(-1).scatter_add_(0, flat_chosen, grad_output_selected)
-        else:
-            V = softmax.size(-1)
-            is_chosen = (~target_mask).unsqueeze(-1) * torch.nn.functional.one_hot(
-                masked_target, num_classes=V
-            )
-            grad_input = is_chosen.float().sub_(softmax)
-            grad_input.mul_(grad_output.unsqueeze(-1))
+        grad_input = distributed_logprob_backward(
+            grad_output,
+            target,
+            softmax,
+            vocab_start_index,
+            vocab_end_index,
+        )

        # if you add an argument to the forward method, then you must add a corresponding None here
        return grad_input, None, None, None, None, None, None
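The refactor also changes what is stashed for backward: tensors (softmax, target) go through ctx.save_for_backward, while the integer shard bounds become plain ctx attributes. A generic sketch of that pattern, with hypothetical names, outside this PR:

```python
import torch

class SelectLogit(torch.autograd.Function):
    """Toy example of the pattern used above (names are illustrative only)."""

    @staticmethod
    def forward(ctx, x, index):
        ctx.save_for_backward(x)   # tensors go through save_for_backward
        ctx.index = index          # plain Python values can live directly on ctx
        return x[..., index]

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        grad = torch.zeros_like(x)
        grad[..., ctx.index] = grad_output
        return grad, None          # one entry per forward input (ctx excluded)

x = torch.randn(4, 8, requires_grad=True)
SelectLogit.apply(x, 3).sum().backward()
print(x.grad[:, 3])                # ones; every other column stays zero
```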


@@ -162,11 +169,6 @@ def forward(  # pyrefly: ignore[bad-override] Always ignore torch.autograd.Func
        tp_group: torch.distributed.ProcessGroup,
        inference_only: bool = False,
    ) -> torch.Tensor:
-        # Create a mask of valid vocab ids (1 means it needs to be masked).
-        target_mask = (target < vocab_start_index) | (target >= vocab_end_index)
-        masked_target = target - vocab_start_index
-        masked_target[target_mask] = 0
-
        seq_size = int(vocab_parallel_logits.shape[1])
        num_chunks = (seq_size + chunk_size - 1) // chunk_size
        all_log_probs = []
@@ -175,6 +177,22 @@ def forward(  # pyrefly: ignore[bad-override] Always ignore torch.autograd.Func
            chunk_start = chunk_idx * chunk_size
            chunk_end = min(seq_size, (chunk_idx + 1) * chunk_size)

            vocab_parallel_logits_chunk = vocab_parallel_logits[:, chunk_start:chunk_end, :]
            target_chunk = target[:, chunk_start:chunk_end]

            # Create a mask of valid vocab ids (1 means it needs to be masked).
            target_mask_chunk = (target_chunk < vocab_start_index) | (target_chunk >= vocab_end_index)
            masked_target_chunk = target - vocab_start_index
            masked_target_chunk[target_mask_chunk] = 0

            distributed_logprob_forward(
Contributor:
why is this not returning anything? and is the consequent line 196-217 necessary or should be removed?

                vocab_parallel_logits_chunk,
                target_chunk,
                vocab_start_index,
                vocab_end_index,
                tp_group,
            )

            logits = vocab_parallel_logits[:, chunk_start:chunk_end, :]
            logits = logits.to(dtype=torch.float32)

@@ -229,7 +247,7 @@ def backward(
            logits = vocab_parallel_logits[:, chunk_start:chunk_end, :]
            logits = logits.to(dtype=torch.float32)

-            softmax_output = _compute_distributed_log_softmax(
+            softmax_output = _distributed_logprob_forward(
Contributor:
this name is not found, is it distributed_logprob_forward?

                logits,
                group=tp_group,
            )