This repository was archived by the owner on Feb 7, 2025. It is now read-only.

Fix transformer training #314

@marksgraham

Description

I've noticed a bug in the way we're training transformers - currently, we are training them to predict token N given tokens 0-N inclusive, so the model receives the very token it is supposed to predict and can simply learn to copy it!

To fix this we need to shift the targets by one, so

`loss = ce_loss(logits, latent)`

becomes

`loss = ce_loss(logits[:,:,:-1], quantizations_target[:,1:])`

Alternatively, we could change the return value of the inferer so it provides the shifted versions - any opinions, @Warvito?
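For reference, a minimal self-contained sketch of the shift. The channels-first logits shape `(batch, num_tokens, seq_len)` is inferred from the slicing above; the tensor names and the bare `F.cross_entropy` call are illustrative stand-ins, not the actual inferer API:

```python
import torch
import torch.nn.functional as F

B, V, L = 2, 16, 8                    # illustrative batch, vocab, sequence sizes
logits = torch.randn(B, V, L)         # (batch, num_tokens, seq_len), channels-first
latent = torch.randint(0, V, (B, L))  # quantized token indices, (batch, seq_len)

# Buggy: the logit at position i is scored against the token at position i,
# which the transformer already saw as input, so it can just copy it.
buggy_loss = F.cross_entropy(logits, latent)

# Fixed: drop the last logit (nothing follows the final token) and shift the
# targets left by one, so the logit at position i predicts token i + 1.
fixed_loss = F.cross_entropy(logits[:, :, :-1], latent[:, 1:])
```

Dropping the final logit mirrors dropping the first target: the first token is never predicted (it has no preceding context), and there is no token after the last position to predict.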

Metadata

Assignees

Labels

bug (Something isn't working)

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests
