diff --git a/fast_llm/models/multimodal/config.py b/fast_llm/models/multimodal/config.py index b2df026a8..a62de3c03 100644 --- a/fast_llm/models/multimodal/config.py +++ b/fast_llm/models/multimodal/config.py @@ -80,6 +80,7 @@ class PretrainedMultiModalModelConfig(PretrainedGPTModelConfig): @config_class(dynamic_type={RunnableConfig: "train_multimodal", TrainerConfig: "multimodal"}) class MultiModalTrainerConfig(PretrainedMultiModalModelConfig, GPTTrainerConfig): + batch: MultiModalBatchConfig = FieldUpdate() # TODO: Use dynamic model type? reference_models: dict[str, PretrainedMultiModalModelConfig] = FieldUpdate() diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index c431bb26d..6383e6aae 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -39,7 +39,7 @@ def _reverse_kl_loss( loss_per_sample = torch.nn.functional.kl_div( teacher_log_probs, student_log_probs, reduction="none", log_target=True ).sum(dim=-1) - loss = (loss_per_sample * loss_mask.flatten()).mean() + loss = (loss_per_sample * loss_mask.flatten()).sum() / loss_mask.sum() return loss