Currently, we mask out examples whose sequence length exceeds the maximum. This can have unintended side effects, at least for SFT. For example, when running with a microbatch size (MBS) of 1, the entire microbatch ends up with zero loss, which skews the average loss for the batch. We should consider whether there is a better way to handle this. One option is to truncate the long examples rather than mask them.
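To make the skew concrete, here is a minimal, framework-free sketch comparing the two strategies. Everything in it is an illustrative assumption rather than the project's actual code: the helper names (`mask_overlong`, `truncate_overlong`, `batch_loss_mbs1`), the representation of an example as a pair of per-token losses and a loss mask, the choice to report 0 loss for a fully masked microbatch, and the reduction that averages per-microbatch mean losses.

```python
from typing import List, Tuple

# Hypothetical representation: (per-token losses, loss mask) for one example.
Example = Tuple[List[float], List[int]]


def mask_overlong(losses: List[float], mask: List[int], max_len: int) -> Example:
    """Masking strategy: zero the whole loss mask if the example is too long."""
    if len(losses) > max_len:
        return losses, [0] * len(mask)
    return losses, mask


def truncate_overlong(losses: List[float], mask: List[int], max_len: int) -> Example:
    """Truncation strategy: keep only the first max_len tokens."""
    return losses[:max_len], mask[:max_len]


def batch_loss_mbs1(examples: List[Example]) -> float:
    """Average of per-microbatch mean losses, with one example per microbatch.

    Assumes a fully masked microbatch reports a loss of 0.
    """
    micro_losses = []
    for losses, mask in examples:
        n_tokens = sum(mask)
        total = sum(l * m for l, m in zip(losses, mask))
        micro_losses.append(total / n_tokens if n_tokens else 0.0)
    return sum(micro_losses) / len(micro_losses)


MAX_LEN = 4
short = ([2.0] * 3, [1] * 3)     # fits within MAX_LEN
too_long = ([2.0] * 6, [1] * 6)  # exceeds MAX_LEN

masked = [mask_overlong(l, m, MAX_LEN) for l, m in (short, too_long)]
truncated = [truncate_overlong(l, m, MAX_LEN) for l, m in (short, too_long)]

masked_loss = batch_loss_mbs1(masked)        # overlong example contributes 0
truncated_loss = batch_loss_mbs1(truncated)  # overlong example still contributes

print(masked_loss, truncated_loss)
```

In this toy setup every token has a loss of 2.0, yet the masked variant reports a batch loss of 1.0 because the overlong microbatch contributes 0, while the truncated variant reports 2.0.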