Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions docs/source/overview/imitation-learning/humanoids_imitation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -636,6 +636,22 @@ Then, from the **Isaac-GR00T** directory, install GR00T N1.5 and its dependencie
MAX_JOBS=4 uv pip install --no-build-isolation 'git+https://github.com/facebookresearch/pytorch3d.git@v0.7.9'
uv pip install diffusers decord zmq

.. note::

**If you cannot install or use flash-attn**, an optional patch is provided that switches the
bundled Eagle 2.5 VL model to PyTorch SDPA. Use it if ``flash-attn`` fails to build for your
environment, or if it installs but raises a runtime error such as
``RuntimeError: FlashAttention only supports Ampere GPUs or newer`` (for example on Blackwell
GPUs, for which ``flash-attn==2.7.1.post4`` has no prebuilt kernels). With the patch applied,
finetuning and rollout run on any CUDA architecture supported by your PyTorch build, at the cost
of flash-attn's training speedup. Skip the ``flash-attn`` install line above, then apply the
patch from the **Isaac-GR00T** directory (with the sibling layout above, the IsaacLab
checkout is at ``../IsaacLab``):

.. code:: bash

git apply ../IsaacLab/scripts/imitation_learning/locomanipulation_sdg/gr00t/no_flash_attn.patch

Convert dataset to LeRobot format
"""""""""""""""""""""""""""""""""

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
diff --git a/gr00t/model/backbone/eagle2_hg_model/config.json b/gr00t/model/backbone/eagle2_hg_model/config.json
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this git patch sensitive to the particular commit of GR00T? That is, if someone switches to a different commit and the code has changed slightly, ``git apply`` will fail. It may be worth mentioning in the docs the particular commit that this patch applies against.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The GR00T checkout is already pinned in the documentation to a specific version that matches the patch.

index 3894adf..f1b95aa 100644
--- a/gr00t/model/backbone/eagle2_hg_model/config.json
+++ b/gr00t/model/backbone/eagle2_hg_model/config.json
@@ -1,5 +1,5 @@
{
- "_attn_implementation": "flash_attention_2",
+ "_attn_implementation": "sdpa",
"_commit_hash": null,
"architectures": [
"Eagle2_5_VLForConditionalGeneration"
diff --git a/gr00t/model/backbone/eagle2_hg_model/modeling_eagle2_5_vl.py b/gr00t/model/backbone/eagle2_hg_model/modeling_eagle2_5_vl.py
index a9649d5..d99b496 100755
--- a/gr00t/model/backbone/eagle2_hg_model/modeling_eagle2_5_vl.py
+++ b/gr00t/model/backbone/eagle2_hg_model/modeling_eagle2_5_vl.py
@@ -108,7 +108,7 @@ class Eagle2_5_VLForConditionalGeneration(Eagle2_5_VLPreTrainedModel, Generation
self.vision_model = vision_model
else:
if config.vision_config.model_type == "siglip_vision_model":
- config.vision_config._attn_implementation = "flash_attention_2"
+ config.vision_config._attn_implementation = "sdpa"
self.vision_model = SiglipVisionModel(config.vision_config)
elif config.vision_config.model_type == "radio":
self.vision_model = RADIOModel(config.vision_config)
@@ -124,9 +124,7 @@ class Eagle2_5_VLForConditionalGeneration(Eagle2_5_VLPreTrainedModel, Generation
raise NotImplementedError("Phi3 is not implemented.")
# self.language_model = Phi3ForCausalLM(config.text_config)
elif config.text_config.architectures[0] == "Qwen2ForCausalLM":
- assert (
- config.text_config._attn_implementation == "flash_attention_2"
- ), f"Qwen2 must use flash_attention_2 but got {config.text_config._attn_implementation}"
+ config.text_config._attn_implementation = "sdpa"
self.language_model = Qwen2ForCausalLM(config.text_config)
elif config.text_config.architectures[0] == "Qwen3ForCausalLM":
self.language_model = Qwen3ForCausalLM(config.text_config)
diff --git a/gr00t/model/backbone/eagle2_hg_model/radio_model.py b/gr00t/model/backbone/eagle2_hg_model/radio_model.py
index 2df0415..eb9b741 100644
--- a/gr00t/model/backbone/eagle2_hg_model/radio_model.py
+++ b/gr00t/model/backbone/eagle2_hg_model/radio_model.py
@@ -44,12 +44,18 @@ from transformers.utils import ModelOutput

try: # v1
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
-except ImportError: # v2
- from flash_attn.flash_attn_interface import (
- flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func,
- )
+except ImportError:
+ try: # v2
+ from flash_attn.flash_attn_interface import (
+ flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func,
+ )
+ except ImportError: # flash-attn unavailable — RADIO vision won't work, SigLIP does
+ flash_attn_unpadded_qkvpacked_func = None

-from flash_attn.bert_padding import pad_input, unpad_input
+try:
+ from flash_attn.bert_padding import pad_input, unpad_input
+except ImportError:
+ pad_input = unpad_input = None


class FlashAttention(nn.Module):
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import torch
from policy import Policy

from isaaclab.envs.mdp.recorders.recorders_cfg import ActionStateRecorderManagerCfg
from isaaclab.utils.datasets import EpisodeData, HDF5DatasetFileHandler
from isaaclab.utils.math import convert_quat

Expand Down Expand Up @@ -375,6 +376,9 @@ def eval_policy(

env_cfg = parse_env_cfg(env_name, device=args_cli.device, num_envs=1)
env_cfg.sim.device = args_cli.device
# Drop the SDG output-data recorder term: it pulls env._locomanipulation_sdg_output_data,
# which is only populated by the data-generation state machine, not during policy rollout.
env_cfg.recorders = ActionStateRecorderManagerCfg()
env_cfg.recorders.dataset_export_dir_path = os.path.dirname(args_cli.output_file)
env_cfg.recorders.dataset_filename = os.path.basename(args_cli.output_file)

Expand Down
Loading