From f2ed61440cd314090043b2b75a4da68f54c1d3a1 Mon Sep 17 00:00:00 2001 From: oleksost Date: Sun, 7 Dec 2025 19:01:06 +0000 Subject: [PATCH 1/3] multimodal batch config --- fast_llm/models/multimodal/config.py | 1 + 1 file changed, 1 insertion(+) diff --git a/fast_llm/models/multimodal/config.py b/fast_llm/models/multimodal/config.py index b2df026a8..a62de3c03 100644 --- a/fast_llm/models/multimodal/config.py +++ b/fast_llm/models/multimodal/config.py @@ -80,6 +80,7 @@ class PretrainedMultiModalModelConfig(PretrainedGPTModelConfig): @config_class(dynamic_type={RunnableConfig: "train_multimodal", TrainerConfig: "multimodal"}) class MultiModalTrainerConfig(PretrainedMultiModalModelConfig, GPTTrainerConfig): + batch: MultiModalBatchConfig = FieldUpdate() # TODO: Use dynamic model type? reference_models: dict[str, PretrainedMultiModalModelConfig] = FieldUpdate() From a545deef1a4abac7d5a31e21c3b15d496a83d775 Mon Sep 17 00:00:00 2001 From: oleksost Date: Sun, 7 Dec 2025 19:01:28 +0000 Subject: [PATCH 2/3] lm head test fix --- tests/layers/test_lm_head.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index c431bb26d..d2a3f3d6d 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -39,7 +39,7 @@ def _reverse_kl_loss( loss_per_sample = torch.nn.functional.kl_div( teacher_log_probs, student_log_probs, reduction="none", log_target=True ).sum(dim=-1) - loss = (loss_per_sample * loss_mask.flatten()).mean() + loss = (loss_per_sample * loss_mask.flatten()).sum() / loss_mask.sum() return loss @@ -344,3 +344,7 @@ def test_lm_head( Assert.rms_close_relative(input_grad, ref_input.grad, threshold, min_threshold) Assert.rms_close_relative(head.final_norm.weight.grad_buffer, ref_rms_weight.grad, threshold, min_threshold) Assert.rms_close_relative(logit_weight.grad_buffer, ref_logit_weight.grad, threshold, min_threshold) + + +if __name__ == "__main__": + pytest.main([__file__]) From e6db7a0d91149337d6f45e621358833a7e86aaf3 Mon Sep 17 00:00:00 2001 From: oleksost Date: Sun, 7 Dec 2025 19:02:30 +0000 Subject: [PATCH 3/3] clean --- tests/layers/test_lm_head.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index d2a3f3d6d..6383e6aae 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -344,7 +344,3 @@ def test_lm_head( Assert.rms_close_relative(input_grad, ref_input.grad, threshold, min_threshold) Assert.rms_close_relative(head.final_norm.weight.grad_buffer, ref_rms_weight.grad, threshold, min_threshold) Assert.rms_close_relative(logit_weight.grad_buffer, ref_logit_weight.grad, threshold, min_threshold) - - -if __name__ == "__main__": - pytest.main([__file__])