Is your feature request related to a problem? Please describe.
The Flax version of the training script in the examples does not implement the separate loss target required for v_prediction.
Describe the solution you'd like
Implement the v_prediction loss target in the Flax training script, matching what the PyTorch version does.
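For reference, here is a rough sketch of what the target selection might look like in JAX. It assumes the usual v-parameterization, v = sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * latents, and that the scheduler's cumulative alpha products are available as an array. The names velocity_target, loss_target and alphas_cumprod are hypothetical and not taken from the existing script, so treat this as an illustration rather than a drop-in patch.

```python
import jax.numpy as jnp


def velocity_target(latents, noise, timesteps, alphas_cumprod):
    """Hypothetical helper: v = sqrt(a_bar_t) * noise - sqrt(1 - a_bar_t) * latents."""
    # Look up the cumulative alpha product for each sample's timestep.
    alpha_bar = alphas_cumprod[timesteps]
    # Reshape the per-sample scalars so they broadcast over the remaining axes.
    shape = (-1,) + (1,) * (latents.ndim - 1)
    sqrt_alpha_bar = jnp.sqrt(alpha_bar).reshape(shape)
    sqrt_one_minus_alpha_bar = jnp.sqrt(1.0 - alpha_bar).reshape(shape)
    return sqrt_alpha_bar * noise - sqrt_one_minus_alpha_bar * latents


def loss_target(prediction_type, latents, noise, timesteps, alphas_cumprod):
    """Hypothetical helper: pick the regression target for the given prediction type."""
    if prediction_type == "epsilon":
        # Epsilon-prediction models regress the added noise directly.
        return noise
    if prediction_type == "v_prediction":
        return velocity_target(latents, noise, timesteps, alphas_cumprod)
    raise ValueError(f"Unknown prediction type: {prediction_type}")
```

The MSE loss would then be computed against this target instead of always against the noise, e.g. ((loss_target(...) - model_pred) ** 2).mean().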
Describe alternatives you've considered
I could probably implement it based on the PyTorch code, but I haven't had time to read and fully understand the paper, so I might get it wrong. I'm filing this so people know the Flax version differs from the PyTorch version and will break when training on v_prediction models.