diff --git a/internnav/dataset/internvla_n1_lerobot_dataset.py b/internnav/dataset/internvla_n1_lerobot_dataset.py
index 8cd39a67..52981731 100644
--- a/internnav/dataset/internvla_n1_lerobot_dataset.py
+++ b/internnav/dataset/internvla_n1_lerobot_dataset.py
@@ -17,6 +17,7 @@
 from torchcodec.decoders import VideoDecoder
 from transformers.image_utils import to_numpy_array
 
+from .vlln_lerobot_dataset import VLLNDataset
 from .rope2d import get_rope_index_2, get_rope_index_25
 
 # Define placeholders for dataset paths
@@ -150,6 +151,11 @@ def parse_sampling_rate(dataset_name):
     return 1.0
 
 
+def read_jsonl(path):
+    with open(path, "r") as f:
+        return [json.loads(line) for line in f]
+
+
 def data_list(dataset_names):
     config_list = []
     for dataset_name in dataset_names:
@@ -180,11 +186,6 @@ def rank0_print(*args):
         print(*args)
 
 
-def read_jsonl(path):
-    with open(path, "r") as f:
-        return [json.loads(line) for line in f]
-
-
 def preprocess_qwen_2_visual(
     sources,
     tokenizer: transformers.PreTrainedTokenizer,
@@ -1329,11 +1330,50 @@ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
         return batch
 
 
+class CombinedDataset(Dataset):
+    """
+    Combine multiple datasets into a single dataset interface.
+
+    This class is used to merge different datasets for joint training.
+    It concatenates samples from all provided datasets and optionally shuffles
+    the global index mapping (without changing the underlying datasets).
+    """
+    def __init__(self, datasets, shuffle=False):
+        super(CombinedDataset, self).__init__()
+        self.datasets = datasets
+        self.lengths = [len(dataset) for dataset in datasets]
+        self.cum_lengths = np.cumsum(self.lengths)
+        self.total_length = self.cum_lengths[-1]
+        self.shuffle_enabled = shuffle
+        self.indices = np.arange(self.total_length)
+        if self.shuffle_enabled:
+            self.shuffle()
+
+    def shuffle(self):
+        np.random.shuffle(self.indices)
+
+    def _map_index(self, idx):
+        return self.indices[idx]
+
+    def __len__(self):
+        return self.cum_lengths[-1]
+
+    def __getitem__(self, i):
+        real_idx = self._map_index(i)
+        for idx, cum_len in enumerate(self.cum_lengths):
+            if real_idx < cum_len:
+                return self.datasets[idx][real_idx - cum_len + self.lengths[idx]]
+        raise ValueError(f"Index {real_idx} out of bounds")
+
+
 def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args) -> Dict:
     """Make dataset and collator for supervised fine-tuning."""
-    train_dataset = NavPixelGoalDataset(tokenizer=tokenizer, data_args=data_args)
-    # train_dataset = LazySupervisedDataset(tokenizer=tokenizer, data_args=data_args)
+    train_datasets = []
+    if data_args.iion_dataset_use:
+        train_datasets.append(VLLNDataset(tokenizer=tokenizer, data_args=data_args))
+    if data_args.vln_dataset_use:
+        train_datasets.append(NavPixelGoalDataset(tokenizer=tokenizer, data_args=data_args))
+    train_dataset = CombinedDataset(train_datasets, shuffle=False)
     if data_args.data_flatten:
         data_collator = FlattenedDataCollatorForSupervisedDataset(tokenizer=tokenizer)
         return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator)
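For reference, a minimal sketch (not part of the patch) of how the new CombinedDataset resolves a global index into a (sub-dataset, local index) pair. RangeDataset is a hypothetical stand-in for VLLNDataset / NavPixelGoalDataset, and the import assumes the patched module loads in the current environment:

from torch.utils.data import Dataset

from internnav.dataset.internvla_n1_lerobot_dataset import CombinedDataset


class RangeDataset(Dataset):
    """Hypothetical toy dataset yielding (tag, index) pairs."""

    def __init__(self, tag, n):
        self.tag, self.n = tag, n

    def __len__(self):
        return self.n

    def __getitem__(self, i):
        return (self.tag, i)


# With lengths [5, 3], cum_lengths is [5, 8]: global index 6 satisfies
# 6 < 8, so it maps to datasets[1][6 - 8 + 3] == datasets[1][1].
combined = CombinedDataset([RangeDataset("a", 5), RangeDataset("b", 3)], shuffle=False)
assert len(combined) == 8
assert combined[0] == ("a", 0)
assert combined[6] == ("b", 1)

Note that __getitem__ scans cum_lengths linearly, so each lookup is O(#datasets); a bisect_left over cum_lengths (already imported in the new dataset file below) would make it logarithmic, though with two or three datasets the difference is negligible.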
diff --git a/internnav/dataset/vlln_lerobot_dataset.py b/internnav/dataset/vlln_lerobot_dataset.py
new file mode 100644
index 00000000..fa670502
--- /dev/null
+++ b/internnav/dataset/vlln_lerobot_dataset.py
@@ -0,0 +1,674 @@
+import copy
+import itertools
+import json
+import os
+import random
+import re
+import time
+from dataclasses import dataclass
+from typing import Dict, List, Sequence, Tuple
+
+import numpy as np
+import torch
+import transformers
+from decord import VideoReader
+from PIL import Image
+from torch.utils.data import Dataset
+from torchcodec.decoders import VideoDecoder
+from transformers.image_utils import to_numpy_array
+from bisect import bisect_left
+from .rope2d import get_rope_index_2, get_rope_index_25
+
+
+# Define placeholders for dataset paths
+IION_split1 = {
+    "data_path": "traj_data/mp3d_split1",
+    "height": 125,
+    "pitch_1": 0,
+    "pitch_2": 30,
+}
+
+IION_split2 = {
+    "data_path": "traj_data/mp3d_split2",
+    "height": 125,
+    "pitch_1": 0,
+    "pitch_2": 30,
+}
+
+IION_split3 = {
+    "data_path": "traj_data/mp3d_split3",
+    "height": 125,
+    "pitch_1": 0,
+    "pitch_2": 30,
+}
+
+data_dict = {
+    "iion_split1": IION_split1,
+    "iion_split2": IION_split2,
+    "iion_split3": IION_split3,
+}
+
+IGNORE_INDEX = -100
+IMAGE_TOKEN_INDEX = 151655
+VIDEO_TOKEN_INDEX = 151656
+TRAJ_TOKEN_INDEX = 151667
+DEFAULT_IMAGE_TOKEN = "<image>"
+DEFAULT_VIDEO_TOKEN = "