# Recovered from the added-file hunk for internnav/dataset/vlln_lerobot_dataset.py
# in "[PATCH 1/9] add VL-LN Bench training code": diff "+" markers removed and
# the collapsed line structure restored.
import copy
import itertools
import json
import os
import random
import re
import time
from bisect import bisect_left
from dataclasses import dataclass
from typing import Dict, List, Sequence, Tuple

import numpy as np
import torch
import transformers
from decord import VideoReader
from PIL import Image
from torch.utils.data import Dataset
from torchcodec.decoders import VideoDecoder
from transformers.image_utils import to_numpy_array

from .rope2d import get_rope_index_2, get_rope_index_25

# Define placeholders for dataset paths
IION_split1 = {
    "data_path": "traj_data/mp3d_split1",
    "height": 125,
    "pitch_1": 0,
    "pitch_2": 30,
}

IION_split2 = {
    "data_path": "traj_data/mp3d_split2",
    "height": 125,
    "pitch_1": 0,
    "pitch_2": 30,
}

IION_split3 = {
    "data_path": "traj_data/mp3d_split3",
    "height": 125,
    "pitch_1": 0,
    "pitch_2": 30,
}

# Registry mapping the dataset names accepted by ``data_list`` to their configs.
data_dict = {
    "iion_split1": IION_split1,
    "iion_split2": IION_split2,
    "iion_split3": IION_split3,
}

# Trailing "%<digits>" suffix on a dataset name, e.g. "iion_split1%50".
# Compiled once so parsing and stripping always agree on the same pattern.
_SAMPLING_SUFFIX = re.compile(r"%(\d+)$")


def parse_sampling_rate(dataset_name):
    """Return the sampling rate encoded in a dataset name.

    A name ending in ``%<digits>`` yields ``int(<digits>) / 100.0``
    (``"iion_split1%50"`` -> ``0.5``); any other name yields ``1.0``.

    Args:
        dataset_name: Dataset name, optionally carrying a ``%<digits>`` suffix.

    Returns:
        The sampling rate as a float.
    """
    match = _SAMPLING_SUFFIX.search(dataset_name)
    if match is None:
        return 1.0
    return int(match.group(1)) / 100.0


def data_list(dataset_names):
    """Resolve dataset names into per-dataset config dictionaries.

    Each name may carry a ``%<digits>`` sampling suffix. The suffix is parsed
    with ``parse_sampling_rate``, stripped, and the remaining name is looked
    up in ``data_dict``.

    Args:
        dataset_names: Iterable of dataset names.

    Returns:
        A list with one shallow copy of the registered config per input name,
        each extended with a ``"sampling_rate"`` key.

    Raises:
        ValueError: If a name (after removing the suffix) is not in
            ``data_dict``.
    """
    config_list = []
    for dataset_name in dataset_names:
        sampling_rate = parse_sampling_rate(dataset_name)
        base_name = _SAMPLING_SUFFIX.sub("", dataset_name)
        if base_name not in data_dict:
            raise ValueError(f"do not find {base_name}")
        # Copy so the per-call "sampling_rate" never leaks into the registry.
        config = data_dict[base_name].copy()
        config["sampling_rate"] = sampling_rate
        config_list.append(config)
    return config_list


# NOTE(review): the *_TOKEN_INDEX values are presumably tokenizer-specific
# special-token ids -- confirm against the tokenizer used for training.
IGNORE_INDEX = -100
IMAGE_TOKEN_INDEX = 151655
VIDEO_TOKEN_INDEX = 151656
TRAJ_TOKEN_INDEX = 151667
# NOTE(review): this literal is empty in the recovered text; an angle-bracket
# placeholder may have been stripped during extraction -- confirm upstream.
DEFAULT_IMAGE_TOKEN = ""
# NOTE(review): the DEFAULT_VIDEO_TOKEN assignment is cut off mid-literal in
# the recovered text and is intentionally not reconstructed here; restore it
# from the upstream file.