
Add Molmo2 #43451

Open
SangbumChoi wants to merge 48 commits into huggingface:main from SangbumChoi:molmo2

Conversation

@SangbumChoi
Contributor

What does this PR do?

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@merveenoyan merveenoyan requested a review from molbap February 27, 2026 05:55
Adds AllenAI Molmo2 multimodal VLM to transformers, supporting:
- Molmo2ForConditionalGeneration (image+video+text → text)
- Molmo2TextModel / Molmo2TextForCausalLM (text-only)
- Molmo2ImageProcessor and Molmo2VideoProcessor
- Molmo2Processor

Key implementation details:
- Uses is_first_iteration (v5 API) for prepare_inputs_for_generation
- Custom Molmo2Embedding with embedding + new_embedding parameters
- Vision backbone with pooling adapter and multi-layer ViT features
- Dynamic full cache support for generation

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
SangbumChoi and others added 14 commits March 27, 2026 08:56
…odel_prefix

- Replace einops.rearrange with native numpy reshape+transpose+reshape
- Add @strict decorator to all 4 config classes (Molmo2VitConfig,
  Molmo2AdapterConfig, Molmo2TextConfig, Molmo2Config) to satisfy TRF010
- Set Molmo2Model.base_model_prefix = "model" (was empty, violating TRF002)
- Fix image_mean/image_std mutable shared list (copy constants on init)
- Fix test_image_processing: use image_processing_class instead of
  image_processor_list; skip CHW torch and 4-channel unsupported tests

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
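As a rough illustration of the einops→numpy replacement described in this commit (a hedged sketch, not the actual Molmo2 code; the function name and shapes are assumptions):

```python
import numpy as np

def pixels_to_patches(images: np.ndarray, patch_size: int) -> np.ndarray:
    """Equivalent of einops.rearrange(
        images, "n (hp p1) (wp p2) c -> n (hp wp) (p1 p2 c)",
        p1=patch_size, p2=patch_size)
    using only native numpy reshape + transpose + reshape."""
    n, h, w, c = images.shape
    hp, wp = h // patch_size, w // patch_size
    x = images.reshape(n, hp, patch_size, wp, patch_size, c)
    x = x.transpose(0, 1, 3, 2, 4, 5)  # -> n, hp, wp, p1, p2, c
    return x.reshape(n, hp * wp, patch_size * patch_size * c)
```

This drops the einops dependency while keeping the exact same memory layout per patch (row-major within each patch, channels last).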
- Re-sort _toctree.yml to place Molmo2 after mllama alphabetically
- Add None guard in test_video_processor_from_dict_with_kwargs to skip
  when fast_video_processing_class is not defined

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Molmo2TextModel is an internal sub-component used by Molmo2Model and
Molmo2ForConditionalGeneration and is tested implicitly through those.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
requests is not part of the standard library and caused ImportError in
minimal environments (e.g. HuggingFace Jobs). Use urllib.request instead.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
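The standard-library replacement can be as small as the following (a sketch under the assumption that only a plain GET is needed; names are illustrative, not the actual test-helper code):

```python
from urllib.request import Request, urlopen

def fetch_bytes(url: str, timeout: float = 10.0) -> bytes:
    """Download a URL using only the standard library, avoiding the
    third-party `requests` dependency that broke minimal environments."""
    req = Request(url, headers={"User-Agent": "transformers-example"})
    with urlopen(req, timeout=timeout) as resp:
        return resp.read()
```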
Molmo2's processor has several behaviors that are incompatible with the
default ProcessorTesterMixin assumptions:
- Chat template enforces strict user/assistant alternation (no system role)
- Processor inserts BOS token, shifting sequence length by 1
- Image processor patchifies output, so rescale_factor passthrough fails
- Video processor requires FPS metadata not provided by base tests
- Hub processor_config.json contains auto_map not preserved in save/load

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Add @auto_docstring(checkpoint="allenai/Molmo2-8B") decorator to
Molmo2TextConfig and Molmo2Config with custom_args for documenting
non-standard parameters. This fixes check_config_docstrings CI check.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
… date

Add parameter docstrings to Molmo2TextConfig and Molmo2Config __init__
methods so @strict-wrapped classes pass config docstring CI checks.
Update model doc date to 2026-03-28.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Move top-level `import torch` and `import torchvision.transforms` behind
`is_torch_available()` / `is_torchvision_available()` guards in both
image and video processors to prevent ModuleNotFoundError when
torchvision is not installed.

Also skip test_kwargs_overrides_default_image_processor_kwargs since
Molmo2's patchifying image processor doesn't support rescale_factor
passthrough.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
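The guard pattern looks roughly like this. `is_torch_available()` / `is_torchvision_available()` are the real transformers helpers; the `importlib`-based stand-in below is only there to keep the sketch self-contained:

```python
import importlib.util

def is_available(pkg: str) -> bool:
    """Stand-in for transformers' is_torch_available()-style helpers:
    True when the package is importable, without actually importing it."""
    return importlib.util.find_spec(pkg) is not None

# Top-level imports stay behind the guard, so the module itself can still
# be imported when the optional dependency is missing.
if is_available("torch"):
    import torch  # noqa: F401
if is_available("torchvision"):
    import torchvision.transforms  # noqa: F401
```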
Convert all absolute imports (from transformers.xxx) to relative imports
(from ...xxx) in image_processing, video_processing, and processing
modules to match the convention used by all other in-library models.

Remove register_for_auto_class() calls which are only needed for custom
hub models and were causing dynamic_module_utils to incorrectly scan
local files for relative imports during save_pretrained.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…n_available

The processor's top-level imports from image_processing_molmo2 and
video_processing_molmo2 pull in PILImageResampling which requires PIL.
Guard these imports with is_vision_available() so `from transformers
import *` works when only torch is installed (no PIL/torchvision).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…L imports

Move Molmo2ImagesKwargs and Molmo2VideosKwargs definitions directly into
processing_molmo2.py instead of importing them from image/video processor
modules which require PIL. Also remove Molmo2ImageProcessor/VideoProcessor
type hints from __init__ to avoid NameError when vision is unavailable.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@SangbumChoi
Contributor Author

@molbap Hi, I'm still working on this since I have to make an example visualizer for it (and most of the code is generated by Claude Code). However, you can start reviewing it at a high level! cc @merveenoyan

Add integration tests for Molmo2-8B covering:
- Image generation with exact expected text verification
- Video QA (penguin identification)
- Video pointing (coordinate output)
- Multi-image comparison

All expected values derived from actual model inference on A10G.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@zucchini-nlp zucchini-nlp self-requested a review March 30, 2026 15:23
Member

@zucchini-nlp zucchini-nlp left a comment

Hey @SangbumChoi

Great model to add to Transformers. After reviewing, I see that using modular would be much better since a lot of parts are copy-pasted from other models. I left comments on each class about where it can be copied from. Apart from that, there are a few places where we need to clean up and align the API with the rest of the VLMs for consistency

If you have questions, ping me on Slack. I will unsubscribe myself from this PR to avoid notifications about each commit, so when you want another review, ping me again by @

Comment thread docs/source/en/model_doc/molmo2.md Outdated
Comment thread docs/source/en/model_doc/molmo2.md Outdated
Comment thread docs/source/en/model_doc/molmo2.md Outdated
Comment thread src/transformers/models/molmo2/configuration_molmo2.py
Comment on lines +20 to +31
r"""
This is the configuration class to store the configuration of a [`Molmo2VisionTransformer`].
It is used to instantiate a `Molmo2VisionTransformer` according to the specified arguments,
defining the model architecture.

Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PreTrainedConfig`] for more information.

Example:
```python
>>> from transformers import Molmo2VitConfig, Molmo2VisionTransformer

Member

let's use @auto_docstring for configs

Comment on lines +26 to +27
if is_torch_available():
pass
Member

dummy import

Comment on lines +51 to +55
# =====================================================================
# Molmo2 chat template enforces strict user/assistant alternation and
# does not support the "system" role used by the base test harness.
# =====================================================================
def test_apply_chat_template_decoded_video_0(self):
Member

we can override it when initializing a dummy processor in setup

Comment on lines +73 to +77
# =====================================================================
# Molmo2Processor.insert_bos() prepends a BOS token, so token count
# differs by 1 from raw tokenizer output. This is by design.
# =====================================================================
@unittest.skip("Molmo2 processor inserts BOS token, causing mismatch with raw tokenizer")
Member

instead of skipping, let's override when needed.

Comment on lines +97 to +103
# =====================================================================
# Hub model has auto_map in processor_config.json which is not preserved
# through save/load cycle. Also use_single_crop_col_tokens default differs.
# =====================================================================
@unittest.skip("Molmo2 image processor patchifies output; rescale_factor passthrough not supported")
def test_image_processor_defaults_preserved_by_image_kwargs(self):
pass
Member

same here: instead of skipping, let's override when needed, for this and the rest as well

Comment thread tests/test_video_processing_common.py Outdated
Comment on lines +131 to +132
if self.fast_video_processing_class is None:
self.skipTest("No fast video processor class defined")
Member

Why not? Each tester should have a fast_video_processing_class as the only possible class

SangbumChoi and others added 14 commits April 6, 2026 08:17
- Remove unused _flash_attention_forward and flash_attn_supports_top_left_mask
  imports from modeling_molmo2.py (no longer needed after attention refactor)
- Move Molmo2AdapterConfig, Molmo2TextConfig, Molmo2VitConfig imports from
  lazy in-function imports to top-level in test_modeling_molmo2.py per review

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Fix TRF013 modeling structure violation.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Comment thread docs/source/en/model_doc/molmo2.md Outdated
Comment thread docs/source/en/model_doc/molmo2.md Outdated
Comment thread docs/source/en/model_doc/molmo2.md Outdated
Comment thread src/transformers/models/molmo2/configuration_molmo2.py
Comment thread src/transformers/models/molmo2/configuration_molmo2.py Outdated
Comment thread src/transformers/models/molmo2/image_processing_molmo2.py
rope_scaling: dict[str, Any] | None = None
rope_scaling_layers: list[int] | None = None
use_qk_norm: bool = False
qk_norm_type: str = "olmo"
Contributor Author

Check what the "olmo" norm type is

Comment on lines +219 to +222
def batch_pixels_to_patches(array: np.ndarray, patch_size: int) -> np.ndarray:
"""Reshape images of [n_images, h, w, 3] -> [n_images, n_patches, pixels_per_patch]"""
if len(array.shape) == 3:
n_crops, h, w = array.shape
Contributor Author

Well, convert_images only treats a single image, doesn't it?

Comment thread src/transformers/models/molmo2/modeling_molmo2.py
Comment thread src/transformers/models/molmo2/modeling_molmo2.py
SangbumChoi and others added 4 commits April 11, 2026 10:23
Adopt auto_docstring on Molmo2Processor/__call__, simplify model_input_names
to inherit tokenizer + image_processor keys plus token_type_ids, and drop
deprecated frame_sample_mode/sampling_fps from Molmo2VideosKwargs and legacy
attribute declarations. Override prepare_processor_dict in the processor test
with a system-role-aware chat template, skip chat-template tests that assume
batch-dim pixel_values (Molmo2 concatenates crops), and relax test_model_input_names
to a subset check since video keys are absent in image-only runs. Drop the
test_generate_with_past_key_values skip since image features are cached in the
KV cache like other VLMs.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
SangbumChoi and others added 3 commits April 18, 2026 11:54
…nd processor init

Fill in docstring entries for Molmo2ImagesKwargs, Molmo2VideosKwargs, and
Molmo2VideoProcessorKwargs TypedDicts, and document the five custom init
args of Molmo2Processor, so that make fix-repo / check_docstrings passes
without placeholder stubs.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…inference, add tie_word_embeddings

- Add molmo2/molmo2_text to auto_mappings.py CONFIG_MAPPING_NAMES so
  AutoConfig.from_pretrained and check_repo.py doc-match checks work
- Add molmo2 to HARDCODED_CONFIG_FOR_MODELS in auto_docstring.py to
  silence repeated 'Config not found' errors during repo checks
- Add tie_word_embeddings: bool = False to Molmo2Config class and
  docstring to satisfy TRF015 modeling structure check
- Pass input_data_format=ChannelDimension.LAST explicitly to all
  normalize() calls in image/video processors; fixes ValueError
  'Unable to infer channel dimension format' when images have
  non-standard channel counts (e.g. RGBA) where infer_channel_
  dimension_format's default num_channels=(1,3) can't match

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: auto, molmo2

…ensor inputs

resize_image() and build_overlapping_crops() assume HWC (channels-last)
layout. When callers pass CHW numpy arrays or torch tensors (e.g. frames
from torchvision / OpenCV→tensor pipelines at 960×540), the width was
misinterpreted as the channel count, causing:
  ValueError: mean must have 960 elements if it is an iterable, got 3

Fix: after to_numpy_array(), infer the channel dimension and transpose to
ChannelDimension.LAST before any spatial processing.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
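A minimal sketch of the layout fix described above (assuming 1-, 3-, or 4-channel images; this mirrors the described behavior, not the exact Molmo2 helper):

```python
import numpy as np

def to_channels_last(image: np.ndarray) -> np.ndarray:
    """Heuristically detect CHW vs HWC layout and transpose to
    channels-last before any spatial processing. Without this, a CHW
    frame at 960x540 makes the width look like the channel count."""
    if image.ndim != 3:
        raise ValueError(f"expected a 3D image, got shape {image.shape}")
    if image.shape[0] in (1, 3, 4) and image.shape[-1] not in (1, 3, 4):
        return image.transpose(1, 2, 0)  # CHW -> HWC
    return image  # already HWC (or ambiguous; keep as-is)
```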
Member

@zucchini-nlp zucchini-nlp left a comment

Nice work @SangbumChoi !

The model arch is huge and more involved than vanilla VLMs, so I think we need a couple more iterations. My main comments are the usage of explicit device in text modules, and refactoring the image/video processors further. I hope we can actually move the big build_image_inputs fn from the model to processing, as it simply tries to split pixels per input text. Requiring nested image inputs will help with that, and we should be able to return already "built" pixel values from the image processor

After that, it'd be great to add the basic helpers such as get_image_features, get_placeholder, etc. Some third-party libraries have started relying on them to get encoded images/videos

Comment on lines +48 to +49
torch_dtype=torch.bfloat16,
device_map="auto",
Member

ultra nit: in v5 we load the model in auto dtype, so we can skip passing torch_dtype

),
("mobilevitv2", {"torchvision": "MobileViTImageProcessor", "pil": "MobileViTImageProcessorPil"}),
("molmo2", {"torchvision": "Molmo2ImageProcessor"}),
("nougat", {"torchvision": "NougatImageProcessor", "pil": "NougatImageProcessorPil"}),
Member

nougat? Also I think Molmo2 would have been in auto_mapping.py, does it not get added after running python utils/check_auto.py --fix_and_overwrite 🤔 ?

Comment on lines +42 to +51
def resize_image(
image: np.ndarray,
desired_output_size: list[int],
resample: PILImageResampling,
) -> np.ndarray:
"""Resize an image and rescale to [0, 1] float32."""
image = torch.permute(torch.from_numpy(image), [2, 0, 1])
resized = torchvision.transforms.Resize(
desired_output_size,
resample,
Member

Huh, very interesting: fast processors have self.resize with the same functionality, and I can't think of cases where we get a numpy image.

I will comment below about a possible reason; let me know if the processor still starts with numpy even after fixing it

patch_size = 14
pooling_size = [2, 2]

def __init__(self, **kwargs):
Member

kwargs needs to be type-annotated with Unpack[Molmo2ImagesKwargs] for auto_docstring

Comment on lines +366 to +368
def preprocess(
self,
images: ImageInput,
Member

Here is the reason I am seeing: we need to override a private self._preprocess, which gets a ready list of tensor images, each in CHW format. It also doesn't need to resolve args with `x if x else self.x`.

Same for docs: not needed as long as you add the annotation with Unpack and decorate the class with auto_docstring

Comment on lines +190 to +191
crop_arr = np.zeros([n_crops, crop_size, crop_size, 3], dtype=src.dtype)
patch_idx_arr = np.zeros([n_crops, crop_patch_h, crop_patch_w], dtype=np.int32)
Member

And do all ops in torch/torchvision, since it is a TorchBackend

Comment on lines +87 to +90
"""Reshape images of [n_images, h, w, 3] -> [n_images, n_patches, pixels_per_patch]"""
if len(array.shape) == 3:
n_crops, h, w = array.shape
h_patches = h // patch_size
Member

Same question here; I believe videos have a fixed shape as well

Comment on lines +269 to +270
if size.height is None or size.width is None:
raise ValueError("size must contain 'height' and 'width' keys.")
Member

This is already handled in standardize_kwargs, so it's not needed here; it's a duplicate check

Comment on lines +280 to +282
# Convert from torch (T, C, H, W) to numpy (T, H, W, C)
if isinstance(video, torch.Tensor):
video = video.permute(0, 2, 3, 1).numpy()
Member

Hm, interesting. Why? All inputs are in channels-first format for both image and video processors

batch_crops = []
batch_pooled_patches_idx = []

for video in videos:
Member

It would be great if we could use groupby_shapes here and in image processing; it greatly speeds up batch processing in some cases

is_image_block = (token_type_ids[batch_idx, q_idx] == 1) & (token_type_ids_at_kv_idx == 1)

# This is bidirectional attention whenever we are dealing with image tokens
return is_image_block & is_image_block
Collaborator

Redundant expression: is_image_block & is_image_block is equivalent to is_image_block
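For reference, the intended bidirectional image-attention block presumably reduces to masking on the query and key token types separately (a hypothetical numpy sketch, not the PR's actual mask function):

```python
import numpy as np

def image_bidirectional_mask(token_type_ids: np.ndarray) -> np.ndarray:
    """[B, S] token types (1 = image token) -> [B, S, S] boolean mask that
    is True exactly where both the query and the key position are image
    tokens, i.e. attention is bidirectional within image blocks."""
    is_image = token_type_ids == 1                      # [B, S]
    return is_image[:, :, None] & is_image[:, None, :]  # [B, S, S]
```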


src_idx = np.tile(np.arange(S), (B, 1)) # [B, S]
valid_mask = src_idx >= first_valid_index[:, None] # [B, S]
tgt_idx = src_idx + 1 # shit right
Collaborator

typo: shift


-->
*This model was released on {release_date} and added to Hugging Face Transformers on 2025-12-05.*
*This model was released on 2020-05-16 and added to Hugging Face Transformers on 2025-12-05.*
Collaborator

unrelated change, should be removed from this PR


-->
*This model was released on {release_date} and added to Hugging Face Transformers on 2026-02-08.*
*This model was released on 2026-02-17 and added to Hugging Face Transformers on 2026-02-09.*
Collaborator

unrelated change, should be removed from this PR

size = {"height": 378, "width": 378}
image_mean = IMAGENET_STANDARD_MEAN
image_std = IMAGENET_STANDARD_STD
do_resize = True
Collaborator

I think most of these flags are dead, e.g. unused in the Molmo2 preprocess cc @zucchini-nlp
