Add option to normalize loss per target #326

Merged

Muennighoff merged 11 commits into t0loading from lossseq on Nov 3, 2022
Conversation

@Muennighoff (Collaborator)

No description provided.

@Muennighoff changed the title from "TMP: Lossseq" to "Add option to normalize loss per target" on Aug 15, 2022
@Muennighoff requested a review from thomasw21 on August 15, 2022 15:08
@Muennighoff requested a review from thomasw21 on August 17, 2022 10:32
Review thread on this diff hunk:

)

    if args.norm_target_loss:
        loss_mask = loss_mask.view(-1)
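For context, a minimal self-contained sketch (not code from this PR; the loss_mask values match the example used in the review thread below, everything else is illustrative) of what normalizing the loss per target means when several targets are packed into one sequence:

import torch

# Per-token losses and a loss mask with one run of 1s per packed target.
losses = torch.arange(13, dtype=torch.float)
loss_mask = torch.tensor([0., 0., 0., 1., 1., 0., 0., 1., 0., 0., 1., 1., 1.])

# Default behaviour: every target token is weighted equally, so long targets
# dominate the average.
loss_per_token = (losses * loss_mask).sum() / loss_mask.sum()

# Normalized per target: each token's weight is divided by the length of its
# target span, so every target contributes equally regardless of its length.
_, inverse, counts = torch.unique_consecutive(
    loss_mask, return_inverse=True, return_counts=True
)
norm_mask = loss_mask / torch.gather(input=counts, dim=0, index=inverse)
loss_per_target = (losses * norm_mask).sum() / norm_mask.sum()

In other words, without the option a long target dominates the sequence loss; with it, each target span contributes equally.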
Member
There's a fun hack you can do: a view shares the same storage as the initial tensor, so you can probably write something like:

def fast_normalize(loss_mask: torch.Tensor):
    """
    Turn loss_mask from [0,0,0,1,1,0,0,1,0,0,1,1,1] > [0,0,0,0.5,0.5,0,0,1,0,0,0.3,0.3,0.3]
    """
    flatten_view = loss_mask.view(-1)
    _, inverse_indices, counts = torch.unique_consecutive(loss_mask, return_inverse=True, return_counts=True)
    counts = torch.gather(dim=0, index=inverse_indices, input=counts)
    flatten_view.div_(counts)
    return loss_mask

Member

You could also clone before doing this operation, so that fast_normalize isn't actually an in-place operation.
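A quick sketch of that variant (just illustrating the suggestion, not code from the PR; it also applies unique_consecutive to the flattened view):

def fast_normalize(loss_mask: torch.Tensor):
    """
    Same as above, but cloning first so the caller's loss_mask is left untouched.
    """
    loss_mask = loss_mask.clone()
    flatten_view = loss_mask.view(-1)
    _, inverse_indices, counts = torch.unique_consecutive(
        flatten_view, return_inverse=True, return_counts=True
    )
    counts = torch.gather(dim=0, index=inverse_indices, input=counts)
    flatten_view.div_(counts)
    return loss_mask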

@Muennighoff (Collaborator, Author)

Why is

def fast_normalize(loss_mask: torch.Tensor):
    """
    Turn loss_mask from [0,0,0,1,1,0,0,1,0,0,1,1,1] > [0,0,0,0.5,0.5,0,0,1,0,0,0.3,0.3,0.3]
    """
    flatten_view = loss_mask.view(-1)
    _, inverse_indices, counts = torch.unique_consecutive(loss_mask, return_inverse=True, return_counts=True)
    counts = torch.gather(dim=0, index=inverse_indices, input=counts)
    flatten_view.div_(counts)
    return loss_mask

better than

def fast_normalize(loss_mask: torch.Tensor):
    """
    Turn loss_mask from [0,0,0,1,1,0,0,1,0,0,1,1,1] > [0,0,0,0.5,0.5,0,0,1,0,0,0.3,0.3,0.3]
    """
    _, inverse_indices, counts = torch.unique_consecutive(loss_mask, return_inverse=True, return_counts=True)
    counts = torch.gather(dim=0, index=inverse_indices, input=counts)
    return loss_mask / counts

?

Member

Does the latter work if loss_mask is not 1D?
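For what it's worth, a standalone check (purely illustrative, not from the PR) that the two variants agree on the 1D example from the docstring:

import torch

loss_mask = torch.tensor([0., 0., 0., 1., 1., 0., 0., 1., 0., 0., 1., 1., 1.])
_, inverse, counts = torch.unique_consecutive(
    loss_mask, return_inverse=True, return_counts=True
)
per_token_counts = torch.gather(input=counts, dim=0, index=inverse)

in_place = loss_mask.clone()
in_place.view(-1).div_(per_token_counts)      # first variant: divide the flat view in place
out_of_place = loss_mask / per_token_counts   # second variant: allocate a new tensor

assert torch.allclose(in_place, out_of_place)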

@Muennighoff merged commit 1e77844 into t0loading on Nov 3, 2022
@Muennighoff deleted the lossseq branch on November 3, 2022 17:38