
Add Granite 4.1 Vision (granite4_vision)#45597

Open
artem-spector wants to merge 43 commits into huggingface:main from artem-spector:add-gv41

Conversation

@artem-spector

What does this PR do?

Adds built-in support for Granite 4.1 Vision (granite4_vision), IBM's multimodal vision-language model for enterprise document understanding.

Architecture highlights

  • Vision encoder: SigLIP2 (google/siglip2-so400m-patch16-384), tiled 384×384 patches
  • Window Q-Former projector: 4×4 patch windows compressed to 2×2 query tokens via cross-attention (downsample_rate="4/8")
  • DeepStack feature injection: 8 vision-to-LLM injection points across two mechanisms:
    • LayerDeepstack: features from 4 vision encoder depths injected at 4 LLM layers (reversed order — deepest vision → earliest LLM)
    • SpatialDeepstack: deepest features split into 4 spatial offset groups (TL/TR/BL/BR), injected at 4 later LLM layers
  • Language model: GraniteForCausalLM (3.5B) with a rank-256 LoRA adapter (same-repo, LM-only)
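The SpatialDeepstack split described above can be illustrated with a small sketch (not the PR's implementation; pure-Python stand-in for the tensor ops):

```python
def spatial_offset_groups(grid):
    """Illustrative sketch: split a 2D patch grid into four spatial offset
    groups (top-left, top-right, bottom-left, bottom-right) by striding
    rows and columns by 2, as in the SpatialDeepstack description."""
    tl = [row[0::2] for row in grid[0::2]]  # even rows, even cols
    tr = [row[1::2] for row in grid[0::2]]  # even rows, odd cols
    bl = [row[0::2] for row in grid[1::2]]  # odd rows, even cols
    br = [row[1::2] for row in grid[1::2]]  # odd rows, odd cols
    return tl, tr, bl, br

# a 4x4 grid of patch coordinates
grid = [[(r, c) for c in range(4)] for r in range(4)]
tl, tr, bl, br = spatial_offset_groups(grid)
```

Each group covers every second patch at a fixed offset, so the four groups together tile the original grid without overlap.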

Files added

| File | Purpose |
| --- | --- |
| modular_granite4_vision.py | Source of truth — inherits from LLaVA-Next, overrides novel components |
| configuration_granite4_vision.py | Config (generated) |
| modeling_granite4_vision.py | Model (generated) |
| processing_granite4_vision.py | Unified processor (generated) |
| image_processing_granite4_vision.py | Torchvision-based image processor |
| image_processing_pil_granite4_vision.py | PIL/NumPy image processor |
| tests/models/granite4_vision/ | Modeling, image processing, and processor tests |
| docs/source/en/model_doc/granite4_vision.md | Model documentation |

Auto-registration

  • Config: auto-generated via configuration_granite4_vision.py model_type
  • Modeling: MODEL_MAPPING_NAMES + MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES
  • Processing + image processing: registered in respective auto files

Tests

  • Unit tests pass locally (pytest tests/models/granite4_vision/ -x -q)
  • @slow integration tests load real checkpoint and assert outputs within tolerance
  • make style and make check-repo pass (3 remaining failures are pre-existing upstream issues: mlinter version mismatch and Sam3Lite incomplete model)

Before submitting

  • This PR is not a duplicate
  • I have read the contributor guidelines
  • The documentation reflects the changes
  • The tests pass

Related

@artem-spector
Author

artem-spector commented Apr 23, 2026

I've traced the root cause of the check_repository_consistency and tests_torch failures to a specific upstream commit:

[Sam3LiteText] Remove unnecessary modules/configs (#45535) (7439ac0)

This commit removed Sam3LiteTextViTConfig and Sam3LiteTextVisionConfig from the modeling file but left them referenced in auto_mappings.py, causing:

  • AttributeError: module transformers has no attribute Sam3LiteTextViTConfig (357 test failures)
  • check_repo failure: Sam3LiteTextVisionConfig appears in CONFIG_MAPPING_NAMES but is not defined

This is reproducible on main independently of our PR.

Question for reviewers: Should we include a fix for this in our PR (removing the stale entries from auto_mappings.py), or would you prefer to handle it in a separate hotfix? Happy to do either.

@artem-spector
Author

Opened a dedicated issue for the upstream regression: #45600

@zucchini-nlp
Member

LMK when ready for review, and ig this PR supersedes #45350?

@artem-spector
Author

artem-spector commented Apr 25, 2026

@zucchini-nlp - yes, this PR supersedes #45350. It's our team that is responsible for producing and releasing IBM vision models.
This PR is ready for review from my side.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Member

@zucchini-nlp zucchini-nlp left a comment

@artem-spector great usage of modular!

Seems like the model uses granite llm as backbone with deepstack features. We will need to add an llm class in that case, since calling each backbone layer manually doesn't align well with our API. We can use modular to copy everything except for a single forward

As per adapters, can you explain how the weights are released? I am not really sure we have to manually add merge_adapters, prob I can suggest a cleaner way
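For illustration, a deepstack-style layer loop of the kind the reviewer describes can be sketched as follows (hypothetical names and a toy scalar representation, not the PR's code; vision features are added at vision-token positions before configured decoder layers):

```python
def run_decoder_layers(hidden, layers, deepstack_features, vision_mask):
    """Sketch of deepstack injection: before selected layer indices,
    add vision features to the hidden states at vision-token positions,
    then run the layer as usual."""
    for idx, layer in enumerate(layers):
        if idx in deepstack_features:
            hidden = [h + f if is_vision else h
                      for h, f, is_vision in zip(hidden, deepstack_features[idx], vision_mask)]
        hidden = layer(hidden)
    return hidden

# toy example: two "layers" that double every position;
# features {1: [5.0, 5.0]} are injected before layer 1 at vision positions only
layers = [lambda xs: [x * 2 for x in xs]] * 2
out = run_decoder_layers([0.0, 1.0], layers, {1: [5.0, 5.0]}, [True, False])
# out == [10.0, 4.0]
```

Owning this loop inside a dedicated text-model class (rather than iterating the backbone's layers from the VLM forward) is what keeps it aligned with the library's API.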

Comment on lines +41 to +48
```bibtex
@misc{granite-vision-4.1-4b,
title={Granite Vision 4.1},
author={IBM Granite Vision Team},
year={2026},
url={https://huggingface.co/ibm-granite/granite-vision-4.1-4b}
}
```
Member

nit: I think we don't need a bibtex entry; as long as there is a link to HF papers/arxiv, that is enough

Author

Done — removed.

Comment on lines +66 to +68
device=0,
torch_dtype=torch.bfloat16,
)
Member

nit: these two are "auto" by default, so we don't need to set them manually

Author

Done — removed.


processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForImageTextToText.from_pretrained(
model_id, torch_dtype=torch.bfloat16, device_map="auto"
Member

same here, torch_dtype is "auto" by default and can be deleted

Author

Done — removed.

Comment on lines +165 to +177
## Notes

- The model includes LoRA adapters. Call `model.merge_lora_adapters()` after loading to merge them into base weights for faster inference.

- Set `padding_side="left"` during batched generation for more accurate results.

```py
processor.tokenizer.padding_side = "left"
```

- The model supports specialized task tags for document extraction: `<chart2csv>`, `<chart2summary>`, `<chart2code>`, `<tables_html>`, `<tables_otsl>`, `<tables_json>`. Pass these as the text prompt along with a document image.

- For key-value pair extraction, provide a JSON schema describing the fields to extract. The model returns structured JSON matching the schema.
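As a sketch of how the task tags above would be used as a prompt, the helper below builds a chat-template message (the helper name is hypothetical; the message schema follows the standard transformers chat format, and the tag list comes from the notes above):

```python
TASK_TAGS = {"<chart2csv>", "<chart2summary>", "<chart2code>",
             "<tables_html>", "<tables_otsl>", "<tables_json>"}

def build_task_messages(task_tag, image_path):
    """Hypothetical helper: pair a document image with one of the
    specialized task tags as the text prompt."""
    if task_tag not in TASK_TAGS:
        raise ValueError(f"unknown task tag: {task_tag}")
    return [{"role": "user",
             "content": [{"type": "image", "path": image_path},
                         {"type": "text", "text": task_tag}]}]

messages = build_task_messages("<tables_html>", "invoice.png")
```

The resulting list would then be passed to `processor.apply_chat_template(messages, ...)` as with other VLMs.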
Member

let's move this block to a Usage Tips section, before the usage example code snippets

Author

Done — moved to a "Usage Tips" section before the code examples.

@@ -0,0 +1,155 @@
import math
Member

following the "one model, one file" philosophy, it is better to put this inside the modular/modeling files

Author

Done — downsampling_granite4_vision.py deleted, all contents inlined into the modular.

Comment on lines 81 to 87
"openai-privacy-filter": "OpenAIPrivacyFilterConfig",
"lasr": "LasrCTCConfig",
"wav2vec2-with-lm": "Wav2Vec2Config",
"granite4-vision": "Granite4VisionConfig",
"hy-v3": "HYV3Config",
"slanet": "SLANetConfig",
}
Member

a few bad rebases :)

Author

Done — removed the stale entries introduced by bad rebases.

Comment thread src/transformers/conversion_mapping.py Outdated
Comment on lines +463 to +466
WeightRenaming(
source_patterns=r"(vision_tower\.)vision_model\.",
target_patterns=r"\1",
),
Member

I think it is not needed anymore, we added PrefixWeights recently and fixed all llava models

Author

Done — removed the granite4_vision entry from conversion_mapping.py.

@@ -0,0 +1,253 @@
# Copyright 2025 IBM. All rights reserved.
Member

2026 :)

Author

Done — Copyright 2026 IBM and The HuggingFace Team.

Comment on lines +50 to +57
class Granite4VisionModelTester(VLMModelTester):
base_model_class = Granite4VisionModel
config_class = Granite4VisionConfig
conditional_generation_class = Granite4VisionForConditionalGeneration
text_config_class = GraniteConfig
vision_config_class = CLIPVisionConfig

def __init__(self, parent, **kwargs):
Member

we need only this tester, since processing is identical to llava-next. Thanks for using VLMTester 🤩

Author

Done — removed test_image_processing_granite4_vision.py entirely (processing is identical to LlavaNext, no re-definition needed).

Comment thread utils/check_repo.py Outdated
Comment on lines +551 to +557
"granite4_vision",
"falcon3",
"megatron_gpt2",
"code_llama",
"hy_v3",
"openai_privacy_filter",
"slanet",
Member

also bad rebase

@artem-spector
Author

@zucchini-nlp I'm ready for the second round :)

Member

@zucchini-nlp zucchini-nlp left a comment


Thanks @artem-spector! Great work, glad to see a single modular file.

Left some comments on further cleaning up, and a few questions. I just noticed that all image placeholders are filled with zeros, and I'm not sure if that is intended. If we actually have no image features scattered, can we stop adding that many placeholders without quality degradation? Dummy and unnecessary token ids increase total length and look like a waste of resources

model_id = "ibm-granite/granite-vision-4.1-4b"

processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForImageTextToText.from_pretrained(model_id).eval()
Member

nit: [...], device_map="auto") without eval; the model should already be in eval mode when loaded

Author

Done — removed .eval().

("glm_image", {"pil": "GlmImageImageProcessorPil", "torchvision": "GlmImageImageProcessor"}),
("glpn", {"pil": "GLPNImageProcessorPil", "torchvision": "GLPNImageProcessor"}),
("got_ocr2", {"pil": "GotOcr2ImageProcessorPil", "torchvision": "GotOcr2ImageProcessor"}),
("granite4_vision", {"pil": "Granite4VisionImageProcessorPil", "torchvision": "Granite4VisionImageProcessor"}),
Member

can delete this, already mapped to 'LlavaNext'

Author

Done — removed.

@@ -0,0 +1,734 @@
# Copyright 2025 IBM. All rights reserved.
Member

ig you forgot to commit and push 😄



@dataclass
class Granite4VisionImageFeaturesOutput(ModelOutput):
Member

i think same as qwen3_vl.BaseModelOutputWithDeepstackFeatures

Author

Done — Granite4VisionImageFeaturesOutput now inherits BaseModelOutputWithPooling (same approach as qwen3_vl.BaseModelOutputWithDeepstackFeatures), with a deepstack_features field added.
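For readers unfamiliar with the pattern, the shape of such an output can be sketched with a plain dataclass (a stand-in for transformers' ModelOutput; field names taken from the discussion):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class ImageFeaturesOutputSketch:
    """Stand-in sketch: a pooling-style output extended with a
    deepstack_features field, mirroring the qwen3_vl pattern."""
    last_hidden_state: Optional[list] = None
    pooler_output: Optional[list] = None
    hidden_states: Optional[tuple] = None
    attentions: Optional[tuple] = None
    deepstack_features: Optional[dict] = None

out = ImageFeaturesOutputSketch(deepstack_features={0: "features_for_layer_0"})
```

Inheriting the pooling-style base means the common test framework can introspect the standard fields, while the extra field carries the per-layer vision features.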

Comment on lines +95 to +97
class Granite4VisionTextConfig(PreTrainedConfig):
model_type = "granite4_vision_text"
base_config_key = "text_config"
Member

not sure I am following, are we not supposed to inherit from GraniteConfig? Current class has no attributes

Author

Done — Granite4VisionTextConfig inherits GraniteConfig directly. It has no additional attributes because the text config is fully specified by GraniteConfig; the subclass exists only to set model_type = "granite4_vision_text" and base_config_key = "text_config".
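A minimal sketch of that arrangement (with a dummy stand-in for transformers' GraniteConfig) shows the subclass only re-brands the config:

```python
class GraniteConfig:
    """Stand-in for transformers.GraniteConfig in this sketch."""
    model_type = "granite"

class Granite4VisionTextConfig(GraniteConfig):
    # no new attributes: the subclass exists only to set the identifiers
    model_type = "granite4_vision_text"
    base_config_key = "text_config"
```

All real attributes come from the parent; the two class-level identifiers let the auto machinery resolve the nested text config.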

past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
deepstack_features=outputs.deepstack_features,
Member

the only diff from llava-next is returning deepstack_features?

Author

Yes — the main differences from LlavaNextForConditionalGeneration are: (1) the model uses Granite4VisionModel which has deepstack injection, (2) logits are scaled by text_config.logits_scaling, and (3) deepstack_features is threaded through the output. Everything else is inherited.
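The logits-scaling difference amounts to one extra step after the LM head. A sketch, assuming the Granite convention of dividing raw logits by the config's `logits_scaling` value:

```python
def scale_logits(raw_logits, logits_scaling):
    """Sketch of Granite-style logit scaling: divide the LM head's raw
    outputs by the config's logits_scaling factor (assumed convention;
    check the Granite modeling code for the exact operation)."""
    return [x / logits_scaling for x in raw_logits]

scaled = scale_logits([8.0, -4.0, 2.0], 4.0)
# scaled == [2.0, -1.0, 0.5]
```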

logits_to_keep=logits_to_keep,
**kwargs,
)
model_inputs = self._init_hybrid_cache(**model_inputs)
Member

we should be able to delete this line that inits the cache. The correct cache should be initialized by GenerationMixin._prepare_cache_for_generation automatically

Author

Done — removed; GenerationMixin._prepare_cache_for_generation handles cache initialization.

Comment on lines +140 to +146
@unittest.skip("Granite4VisionImageFeaturesOutput has no hidden_states field")
def test_get_image_features_hidden_states(self):
pass

@unittest.skip("Granite4VisionImageFeaturesOutput has no attentions field")
def test_get_image_features_attentions(self):
pass
Member

these two should be fixed when we add **kwargs and return a BaseModelOutputWithDeepstackFeatures

Author

Done — Granite4VisionImageFeaturesOutput now inherits BaseModelOutputWithPooling so all five tests pass. skip_test_image_features_output_shape = True remains because last_hidden_state isn't meaningful for this output type, but hidden_states, pooler_output, and field presence checks all pass.

Comment on lines +148 to +154
@unittest.skip("Base model forward returns ModelOutputWithPast, not CausalLMOutput with loss")
def test_training(self):
pass

@unittest.skip("QFormer submodules not initialized by init_weights from meta device")
def test_can_init_all_missing_weights(self):
pass
Member

these also should be fixable, I don't think this is a valid reason to skip

Author

Done for test_training: the skip was stale; Granite4VisionForConditionalGeneration computes a loss, and the framework skips Granite4VisionModel automatically via MODEL_MAPPING_NAMES.

test_can_init_all_missing_weights remains skipped: Blip2QFormerModel submodules aren't initialized from meta device by our _init_weights, since Blip2QFormerPreTrainedModel._init_weights only handles Blip2ForConditionalGeneration instances, not the standalone Blip2QFormerModel. Happy to investigate further if you'd like.

Comment thread utils/check_config_attributes.py Outdated
Comment on lines +102 to +105
"Granite4VisionConfig": [
"multimodal_projector_bias",
"projector_hidden_act",
],
Member

if actually not used, let's just delete them from the config class. When inheriting a config, you can define

multimodal_projector_bias = AttributeError()

and this field will not be copied

Author

Done — multimodal_projector_bias and projector_hidden_act are shadowed with AttributeError() at class level on Granite4VisionConfig, so the framework sees them as intentionally inaccessible. The SPECIAL_CASES_TO_ALLOW entry was removed.

artemspector and others added 15 commits April 30, 2026 18:24
Full implementation of IBM Granite 4.1 Vision as a built-in HF model:
- Modular implementation (modular_granite4_vision.py)
- Generated files: config, modeling, image processing, processing
- Auto-registration: config, modeling, processing, image processing
- Tests: modeling (unit + @slow), image processor, processor
- Documentation (docs/source/en/model_doc/granite4_vision.md)
- WeightRenaming to handle SiglipVisionModel vision_model. nesting

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Upstream moved CONFIG_MAPPING_NAMES to auto_mappings.py. Add
granite4_vision entry there; resolve leftover conflict markers in
configuration_auto.py (granite4_vision is already in modeling_auto.py
and processing_auto.py).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…mappings duplicate

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- Remove granite4_vision from MISSING_IMAGE_PROCESSOR_MAPPING_NAMES (auto-discovered via TorchvisionBackend/PilBackend)
- Add granite4-vision to HARDCODED_CONFIG_FOR_MODELS in auto_docstring.py
- Add granite4_vision to DOC_MODEL_NAMES_NOT_IN_AUTO in check_repo.py
- Fix import sort in models/__init__.py and test file
- Regenerate auto_mappings.py via check_auto.py --fix_and_overwrite
- Add dates to granite4_vision.md

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: artemspector <artems@il.ibm.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: artemspector <artems@il.ibm.com>
- Fix processing_auto.py sort order (sort_auto_mappings)
- Add hy-v3, openai-privacy-filter, slanet to HARDCODED_CONFIG_FOR_MODELS
- Add hy_v3, openai_privacy_filter, slanet to DOC_MODEL_NAMES_NOT_IN_AUTO
  (new upstream models missing from these registries)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: artemspector <artems@il.ibm.com>
sam3_vision_model and sam3_vit_model were incorrectly mapped to
Sam3LiteTextVisionConfig/Sam3LiteTextViTConfig instead of
Sam3VisionConfig/Sam3ViTConfig (and sam3_lite_text module instead of sam3).
These are unrelated to granite4_vision; restoring upstream/main values.

Signed-off-by: artemspector <artems@il.ibm.com>
…ebase regeneration

These three upstream model entries were accidentally removed from CONFIG_MAPPING_NAMES
in auto_mappings.py by a previous run of check_auto.py --fix_and_overwrite during
an incomplete rebase state. Restoring verbatim from upstream/main.

Signed-off-by: artemspector <artems@il.ibm.com>
Signed-off-by: artemspector <artems@il.ibm.com>
…m auto_docstring and check_repo

These entries belong to other upstream PRs and were accidentally included during a previous rebase. Our PR only owns the granite4_vision entries.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: artemspector <artems@il.ibm.com>
The hub checkpoint ships with pre-merged weights; PEFT-style merging doesn't
fit the HF API. Regenerated modeling file from modular via converter.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: artemspector <artems@il.ibm.com>
…layer loop

Instead of iterating self.language_model.layers from the VLM model's forward,
introduce Granite4VisionTextModel(GraniteModel) that owns the layer loop and
accepts deepstack_features (dict[layer_idx -> tensor]) and vision_mask.
Granite4VisionModel.forward() now calls self.language_model(...) cleanly.
Pattern follows Qwen3VL. Regenerated modeling file from modular.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: artemspector <artems@il.ibm.com>
…ted file

The modular converter generates a TextConfig subclass for the text model's
sub-layers. Define Granite4VisionTextConfig(GraniteConfig) explicitly in
modular so the converter resolves it correctly instead of creating an undefined
reference. Regenerated config and modeling files.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: artemspector <artems@il.ibm.com>
artemspector and others added 26 commits April 30, 2026 18:25
… import

Inheriting GraniteConfig caused the converter to drop the import in the
generated config file. Align with Qwen3VL pattern: TextConfig inherits
PreTrainedConfig directly. Also add PreTrainedConfig import to modular.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: artemspector <artems@il.ibm.com>
…odel

The converter respects source order; TextModel must come after PreTrainedModel.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: artemspector <artems@il.ibm.com>
…rom_pretrained

When loading with device_map, HF's _move_missing_keys_from_meta_to_device
replaces all non-persistent buffers with torch.empty_like() (garbage memory).
Add a _init_weights handler for Granite4VisionTextRotaryEmbedding that
recomputes inv_freq and original_inv_freq from config, so _initialize_missing_keys
restores correct values after the corruption. Also adds
Granite4VisionTextRotaryEmbedding as an explicit subclass in the modular file
so the isinstance check resolves correctly.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…ert to pure functions

- Delete downsampling_granite4_vision.py; move WindowQFormerDownsampler,
  interpolate_downsample, and spatial_offset_downsample into modular
- Replace stateless InterpolateDownsampler/SpatialOffsetDownsampler classes
  with plain functions (items 2 and 4 from reviewer feedback)
- Add config.qformer_config (Blip2QFormerConfig) as a proper sub-config field
  on Granite4VisionConfig following the Blip2Config pattern; remove inline
  Blip2QFormerConfig construction from WindowQFormerDownsampler.__init__ (item 3)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Replace the raw list-of-tuples return from get_image_features with a
proper @dataclass ModelOutput subclass (Granite4VisionImageFeaturesOutput),
following the Qwen3-VL BaseModelOutputWithDeepstackFeatures pattern.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…Next

The image processors are identical to LlavaNextImageProcessor and
LlavaNextImageProcessorPil; no need to re-define them. Map
'granite4_vision' to the LlavaNext processors in image_processing_auto.py.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- Item 8: move query/image_positions init to _init_weights (embed_std pattern)
- Item 9: rename _win/_unwin to _windowed_raster/_unwindowed_raster, replace
  single-letter vars with descriptive names
- Item 10: add deepstack_features field to Granite4VisionModelOutputWithPast and
  Granite4VisionCausalLMOutputWithPast instead of reusing image_hidden_states
- Item 11: use TransformersKwargs instead of FlashAttentionKwargs in
  Granite4VisionModel.forward; remove unused FlashAttentionKwargs import
- Item 12: raise ValueError instead of warning_once for patch shape mismatch;
  remove now-unused logger
- Item 13: drop use_image_newline_parameter (not used in released checkpoint)
- Item 14: read pad_token_id from config.text_config instead of top-level config

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- Item 15: fix copyright to "2026 IBM and The HuggingFace Team"
- Item 16: remove bibtex entry from docs
- Item 17: remove torch_dtype/device_map from docs examples
- Item 18: move Notes to "Usage Tips" section before code examples
- Item 19: remove model_type from Granite4VisionProcessor
- Item 20: revert AttributeError() (converter incompatible); keep del self.
- Item 21: remove granite4_vision from conversion_mapping (PrefixWeights handles it)
- Item 22: remove granite4_vision from check_repo DOC_MODEL_NAMES_NOT_IN_AUTO and
  HARDCODED_CONFIG_FOR_MODELS in auto_docstring (bad rebase entries)
- Item 23: update test copyright, remove use_image_newline_parameter from tester,
  update skip reasons for get_image_features tests

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- Item 20: drop get_image_token_mask override, use parent's get_placeholder_mask
- Item 29: delete test_image_processing_granite4_vision.py (identical to LlavaNext)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- Pass output_attentions/output_hidden_states explicitly to language_model
  in Granite4VisionModel.forward (were swallowed as explicit params, not
  forwarded via **kwargs)
- Collect all_hidden_states and all_self_attns in Granite4VisionTextModel
  layer loop; add output_attentions/output_hidden_states params
- Fix qformer_config dict→object conversion to run before super().__post_init__()
  so _attn_implementation.setter doesn't hit a raw dict during sub_configs iteration
- Use Blip2QFormerConfig directly in sub_configs (instead of AutoConfig) so
  save/load round-trip resolves the type correctly; add missing import to
  generated configuration_granite4_vision.py

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…es AutoConfig

blip_2_qformer is registered in CONFIG_MAPPING so AutoConfig resolves it correctly.
Moving the Blip2QFormerConfig import inside __post_init__ avoids a cross-model
top-level import that the modular converter drops from the generated file.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…ing public classes

- IGNORE_NON_TESTED + IGNORE_NON_AUTO_CONFIGURED: Granite4VisionTextModel is an
  internal subcomponent tested implicitly through Granite4VisionModel
- Doc: add autodoc entries for Granite4VisionTextConfig, Granite4VisionTextModel,
  Granite4VisionImageProcessor, Granite4VisionImageProcessorPil

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Needed for ruff F821 (undefined name) to pass under make style.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…isionModel.forward

Aligns with reviewer feedback: these args are not needed in the explicit
signature since they flow through kwargs: Unpack[TransformersKwargs].

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- TRF010: add @strict to Granite4VisionTextConfig (direct PreTrainedConfig subclass)
- TRF002: set base_model_prefix = "model" on Granite4VisionTextModel (was "")
- TRF009: add trf-ignore comment on Blip2QFormerModel lazy import
  (cross-model import is intentional — QFormer is a shared building block)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- Reorder imports in modular to satisfy ruff isort (stdlib → third-party → first-party)
- Sync processing_granite4_vision.py to match converter output
  (BatchFeature from feature_extraction_utils, no model_type on processor)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
downsampling_granite4_vision.py, image_processing_granite4_vision.py, and
image_processing_pil_granite4_vision.py are regenerated by the converter but
were previously intentionally deleted: image processors delegate to LlavaNext
(registered in image_processing_auto.py), and downsampling is inlined in
modular/modeling.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
… this model

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…ctor naming

- Granite4VisionTextConfig restored as proper GraniteConfig subclass
- Granite4VisionConfig.__post_init__: convert dict->config before super() so
  _attn_implementation.setter sees config objects; patch vision-size fields after super()
- Use CONFIG_MAPPING/AutoModel at module top-level (no lazy imports)
- Add _can_record_outputs to Granite4VisionTextModel for hidden_states/attentions
  capture via @capture_outputs decorator
- Add Granite4VisionTextAttention/TextDecoderLayer stubs in modular so converter
  generates the registry entries pointing to the correct layer classes
- WindowQFormerDownsampler renamed to Granite4VisionWindowQFormerDownsampler
- interpolate_downsample/spatial_offset_downsample take explicit size args (not config)
- Remove output_attentions/output_hidden_states from forward signatures (handled by
  @capture_outputs and **kwargs); use BaseModelOutputWithPast return type
- Remove prepare_inputs_for_generation (handled by parent)
- Remove _init_hybrid_cache (GraniteMoeHybrid leftover from 4.0)
- auto_mappings.py: use LlavaNextImageProcessor(Pil) instead of model-specific copies
- docs: remove .eval() from example

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: artemspector <artems@il.ibm.com>
…onfig attrs

- _windowed_raster/unwindowed_raster: x -> features, x_win -> windowed_features
- __init__: q, w -> query_side_str, window_side_str
- Granite4VisionConfig: shadow LlavaNextConfig's multimodal_projector_bias and
  projector_hidden_act with AttributeError() so check_config_attributes passes
  without SPECIAL_CASES_TO_ALLOW entry
- Remove Granite4VisionConfig from check_config_attributes.py SPECIAL_CASES_TO_ALLOW

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: artemspector <artems@il.ibm.com>
…d patching

Peek at vision_config.hidden_size (or its dict equivalent) before super() and
include hidden_size, num_attention_heads, encoder_hidden_size directly in the
CONFIG_MAPPING["blip_2_qformer"]() constructor call. This avoids mutating the
config object after super().__post_init__() runs.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: artemspector <artems@il.ibm.com>
…ainedModel

Follows qwen3_vl pattern: _can_record_outputs and _deepstack_inject belong on
the shared PreTrainedModel base class, not on TextModel. TextAttention/
TextDecoderLayer stubs are defined before PreTrainedModel so the converter
generates locally-scoped classes for the _can_record_outputs registry.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: artemspector <artems@il.ibm.com>
…cCache

- Granite4VisionImageFeaturesOutput now inherits BaseModelOutputWithPooling so
  the common test framework can introspect last_hidden_state/pooler_output/
  hidden_states/attentions fields (removes 5 test skips)
- Add @capture_outputs to Granite4VisionTextModel.forward so output_hidden_states
  is collected via hooks and propagated through to the causal LM output (fixes
  test_assisted_decoding_matches_greedy_search)
- get_image_features: populate hidden_states from vision tower when
  output_hidden_states=True (via kwarg or config); removes test skip
- Remove stale test_training skip (ForConditionalGeneration computes loss;
  base model is skipped automatically by MODEL_MAPPING_NAMES check)
- Delete DynamicCache init block from Granite4VisionTextModel.forward;
  GenerationMixin._prepare_cache_for_generation handles this
- Import BaseModelOutputWithPooling, capture_outputs; drop DynamicCache import

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: artemspector <artems@il.ibm.com>
…ader

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
artemspector and others added 2 commits April 30, 2026 18:28
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: auto, granite4_vision
