Some transformers (like x-transformers) take in a sequence of length (seq_len + 1), split it into input = x[:-1] and target = x[1:], and compute the loss directly in forward(). This is efficient because the input and target overlap, but it means that forward() returns the loss rather than the model outputs.

It would be nice if coord_check had an option supporting this use case, where forward() returns the loss directly — for example, adding a loss_from_forward argument to the function signatures and inserting something like this:
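For concreteness, here is a minimal sketch of that pattern: a toy language model whose forward() takes the full (seq_len + 1)-length batch, splits it into overlapping input/target views, and returns the cross-entropy loss itself. The model is hypothetical (not the actual x-transformers code), just an illustration of the interface:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class LossFromForwardLM(nn.Module):
    """Toy autoregressive LM whose forward() returns the loss directly.

    Illustrative stand-in for the x-transformers pattern: the model
    receives a sequence of length seq_len + 1, splits it internally
    into overlapping input/target views, and computes cross-entropy
    itself instead of returning logits.
    """
    def __init__(self, vocab_size=100, dim=32):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dim)
        self.proj = nn.Linear(dim, vocab_size)

    def forward(self, x):
        inp, target = x[:, :-1], x[:, 1:]    # overlapping views, no extra copy of the data
        logits = self.proj(self.embed(inp))  # (batch, seq_len, vocab)
        # cross_entropy expects the class dim second, hence the transpose
        return F.cross_entropy(logits.transpose(1, 2), target)

batch = torch.randint(0, 100, (4, 17))  # sequences of length seq_len + 1 = 17
loss = LossFromForwardLM()(batch)       # forward() returns a scalar loss
```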
```python
elif loss_from_forward:
    if cuda:
        batch = batch.cuda()
    loss = model(batch)
```
at mup/mup/coord_check.py, line 317 in 1981497.
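To show how the proposed branch could slot into the training-step dispatch, here is a simplified sketch. Everything outside the loss_from_forward branch is an assumed, illustrative structure (a conventional outputs-then-loss path), not mup's actual coord_check code:

```python
import torch
import torch.nn.functional as F

def train_step(model, batch, labels=None, lossfn=F.cross_entropy,
               loss_from_forward=False, cuda=False):
    """Illustrative stand-in for one step of coord_check's inner loop.

    Only the loss_from_forward branch corresponds to the proposal above;
    the surrounding structure is a hypothetical simplification, not the
    real mup implementation.
    """
    if loss_from_forward:
        # Model computes the loss itself (x-transformers style).
        if cuda:
            batch = batch.cuda()
        loss = model(batch)
    else:
        # Conventional path: forward() returns outputs, loss computed here.
        if cuda:
            batch, labels = batch.cuda(), labels.cuda()
        outputs = model(batch)
        loss = lossfn(outputs, labels)
    return loss

# Tiny demo model whose forward() returns a scalar loss directly.
class TinyLossModel(torch.nn.Module):
    def forward(self, x):
        return x.float().pow(2).mean()

loss = train_step(TinyLossModel(), torch.randn(4, 8), loss_from_forward=True)
```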