diff --git a/.gitattributes b/.gitattributes index e808d54903..302cb2e191 100644 --- a/.gitattributes +++ b/.gitattributes @@ -9,7 +9,7 @@ *.ico binary *.pdf binary # Explicit LFS tracking for test files -tests/data/.lfs/*.tar.gz filter=lfs diff=lfs merge=lfs -text +/data/.lfs/*.tar.gz filter=lfs diff=lfs merge=lfs -text *.onnx filter=lfs diff=lfs merge=lfs -text binary *.mp4 filter=lfs diff=lfs merge=lfs -text binary *.mov filter=lfs diff=lfs merge=lfs -text binary diff --git a/.gitignore b/.gitignore index 48717f7e6a..f81f3c96d8 100644 --- a/.gitignore +++ b/.gitignore @@ -23,8 +23,8 @@ assets/agent/memory.txt .bash_history # Ignore all test data directories but allow compressed files -tests/data/* -!tests/data/.lfs/ +/data/* +!/data/.lfs/ # node env (used by devcontainers cli) node_modules diff --git a/tests/data/.lfs/ab_lidar_frames.tar.gz b/data/.lfs/ab_lidar_frames.tar.gz similarity index 100% rename from tests/data/.lfs/ab_lidar_frames.tar.gz rename to data/.lfs/ab_lidar_frames.tar.gz diff --git a/tests/data/.lfs/assets.tar.gz b/data/.lfs/assets.tar.gz similarity index 100% rename from tests/data/.lfs/assets.tar.gz rename to data/.lfs/assets.tar.gz diff --git a/tests/data/.lfs/cafe.jpg.tar.gz b/data/.lfs/cafe.jpg.tar.gz similarity index 100% rename from tests/data/.lfs/cafe.jpg.tar.gz rename to data/.lfs/cafe.jpg.tar.gz diff --git a/tests/data/.lfs/models_clip.tar.gz b/data/.lfs/models_clip.tar.gz similarity index 100% rename from tests/data/.lfs/models_clip.tar.gz rename to data/.lfs/models_clip.tar.gz diff --git a/tests/data/.lfs/models_fastsam.tar.gz b/data/.lfs/models_fastsam.tar.gz similarity index 100% rename from tests/data/.lfs/models_fastsam.tar.gz rename to data/.lfs/models_fastsam.tar.gz diff --git a/tests/data/.lfs/models_yolo.tar.gz b/data/.lfs/models_yolo.tar.gz similarity index 100% rename from tests/data/.lfs/models_yolo.tar.gz rename to data/.lfs/models_yolo.tar.gz diff --git a/tests/data/.lfs/office_lidar.tar.gz b/data/.lfs/office_lidar.tar.gz similarity index 100% rename from tests/data/.lfs/office_lidar.tar.gz rename to data/.lfs/office_lidar.tar.gz diff --git a/tests/data/.lfs/raw_odometry_rotate_walk.tar.gz b/data/.lfs/raw_odometry_rotate_walk.tar.gz similarity index 100% rename from tests/data/.lfs/raw_odometry_rotate_walk.tar.gz rename to data/.lfs/raw_odometry_rotate_walk.tar.gz diff --git a/dimos/agents/memory/image_embedding.py b/dimos/agents/memory/image_embedding.py index 265d15892f..1ad0e9132d 100644 --- a/dimos/agents/memory/image_embedding.py +++ b/dimos/agents/memory/image_embedding.py @@ -19,15 +19,17 @@ using pre-trained models like CLIP, ResNet, etc. """ +import base64 +import io import os -import numpy as np from typing import Union -from PIL import Image -import io + import cv2 -import base64 +import numpy as np +from PIL import Image + +from dimos.utils.data import get_data from dimos.utils.logging_config import setup_logger -from dimos.utils.testing import testData logger = setup_logger("dimos.agents.memory.image_embedding") @@ -60,12 +62,12 @@ def __init__(self, model_name: str = "clip", dimensions: int = 512): def _initialize_model(self): """Initialize the specified embedding model.""" try: - import torch - from transformers import CLIPProcessor, AutoFeatureExtractor, AutoModel import onnxruntime as ort + import torch + from transformers import AutoFeatureExtractor, AutoModel, CLIPProcessor if self.model_name == "clip": - model_id = testData("models_clip") / "model.onnx" + model_id = get_data("models_clip") / "model.onnx" processor_id = "openai/clip-vit-base-patch32" self.model = ort.InferenceSession(model_id) self.processor = CLIPProcessor.from_pretrained(processor_id) diff --git a/dimos/agents/memory/test_image_embedding.py b/dimos/agents/memory/test_image_embedding.py index 38877b1461..b55c3a7f27 100644 --- a/dimos/agents/memory/test_image_embedding.py +++ b/dimos/agents/memory/test_image_embedding.py @@ -18,12 +18,14 @@ import os import time + import numpy as np import pytest import reactivex as rx from reactivex import operators as ops -from dimos.stream.video_provider import VideoProvider + from dimos.agents.memory.image_embedding import ImageEmbeddingProvider +from dimos.stream.video_provider import VideoProvider class TestImageEmbedding: @@ -44,9 +46,9 @@ def test_clip_embedding_initialization(self): def test_clip_embedding_process_video(self): """Test CLIP embedding provider can process video frames and return embeddings.""" try: - from dimos.utils.testing import testData + from dimos.utils.data import get_data - video_path = testData("assets") / "trimmed_video_office.mov" + video_path = get_data("assets") / "trimmed_video_office.mov" embedding_provider = ImageEmbeddingProvider(model_name="clip", dimensions=512) diff --git a/dimos/perception/detection2d/test_yolo_2d_det.py b/dimos/perception/detection2d/test_yolo_2d_det.py index 5316bfee90..4240625744 100644 --- a/dimos/perception/detection2d/test_yolo_2d_det.py +++ b/dimos/perception/detection2d/test_yolo_2d_det.py @@ -14,11 +14,13 @@ import os import time -import pytest + import cv2 import numpy as np +import pytest import reactivex as rx from reactivex import operators as ops + from dimos.perception.detection2d.yolo_2d_det import Yolo2DDetector from dimos.stream.video_provider import VideoProvider @@ -37,12 +39,12 @@ def test_yolo_detector_initialization(self): def test_yolo_detector_process_image(self): """Test YOLO detector can process video frames and return detection results.""" try: - # Import testData inside method to avoid pytest fixture confusion - from dimos.utils.testing import testData + # Import data inside method to avoid pytest fixture confusion + from dimos.utils.data import get_data detector = Yolo2DDetector() - video_path = testData("assets") / "trimmed_video_office.mov" + video_path = get_data("assets") / "trimmed_video_office.mov" # Create video provider and directly get a video stream observable assert os.path.exists(video_path), f"Test video not found: {video_path}" diff --git a/dimos/perception/detection2d/yolo_2d_det.py b/dimos/perception/detection2d/yolo_2d_det.py index 34e094b425..3e20a0fb6f 100644 --- a/dimos/perception/detection2d/yolo_2d_det.py +++ b/dimos/perception/detection2d/yolo_2d_det.py @@ -12,19 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os + import cv2 +import onnxruntime from ultralytics import YOLO + from dimos.perception.detection2d.utils import ( extract_detection_results, - plot_results, filter_detections, + plot_results, ) -import os -import onnxruntime +from dimos.utils.data import get_data +from dimos.utils.gpu_utils import is_cuda_available from dimos.utils.logging_config import setup_logger from dimos.utils.path_utils import get_project_root -from dimos.utils.testing import testData -from dimos.utils.gpu_utils import is_cuda_available logger = setup_logger("dimos.perception.detection2d.yolo_2d_det") @@ -40,7 +42,7 @@ def __init__(self, model_path="models_yolo", model_name="yolo11n.onnx", device=" device (str): Device to run inference on ('cuda' or 'cpu') """ self.device = device - self.model = YOLO(testData(model_path) / model_name) + self.model = YOLO(get_data(model_path) / model_name) module_dir = os.path.dirname(__file__) self.tracker_config = os.path.join(module_dir, "config", "custom_tracker.yaml") diff --git a/dimos/perception/segmentation/sam_2d_seg.py b/dimos/perception/segmentation/sam_2d_seg.py index 568d09e1e4..f1d32d4daf 100644 --- a/dimos/perception/segmentation/sam_2d_seg.py +++ b/dimos/perception/segmentation/sam_2d_seg.py @@ -12,25 +12,27 @@ # See the License for the specific language governing permissions and # limitations under the License. -import cv2 +import os import time +from collections import deque +from concurrent.futures import ThreadPoolExecutor + +import cv2 +import onnxruntime from ultralytics import FastSAM + +from dimos.perception.common.detection2d_tracker import get_tracked_results, target2dTracker +from dimos.perception.segmentation.image_analyzer import ImageAnalyzer from dimos.perception.segmentation.utils import ( + crop_images_from_bboxes, extract_masks_bboxes_probs_names, filter_segmentation_results, plot_results, - crop_images_from_bboxes, ) +from dimos.utils.data import get_data from dimos.utils.gpu_utils import is_cuda_available -from dimos.perception.common.detection2d_tracker import target2dTracker, get_tracked_results -from dimos.perception.segmentation.image_analyzer import ImageAnalyzer -import os -from collections import deque -from concurrent.futures import ThreadPoolExecutor from dimos.utils.logging_config import setup_logger from dimos.utils.path_utils import get_project_root -import onnxruntime -from dimos.utils.testing import testData logger = setup_logger("dimos.perception.segmentation.sam_2d_seg") @@ -55,7 +57,7 @@ def __init__( logger.info("Using CPU for SAM 2d segmenter") self.device = "cpu" # Core components - self.model = FastSAM(testData(model_path) / model_name) + self.model = FastSAM(get_data(model_path) / model_name) self.use_tracker = use_tracker self.use_analyzer = use_analyzer self.use_rich_labeling = use_rich_labeling diff --git a/dimos/perception/segmentation/test_sam_2d_seg.py b/dimos/perception/segmentation/test_sam_2d_seg.py index fc7e488e51..dd60f4b109 100644 --- a/dimos/perception/segmentation/test_sam_2d_seg.py +++ b/dimos/perception/segmentation/test_sam_2d_seg.py @@ -14,15 +14,17 @@ import os import time -from dimos.stream import video_provider -import pytest + import cv2 import numpy as np +import pytest import reactivex as rx from reactivex import operators as ops -from dimos.stream.video_provider import VideoProvider + from dimos.perception.segmentation.sam_2d_seg import Sam2DSegmenter from dimos.perception.segmentation.utils import extract_masks_bboxes_probs_names +from dimos.stream import video_provider +from dimos.stream.video_provider import VideoProvider class TestSam2DSegmenter: @@ -39,11 +41,11 @@ def test_sam_segmenter_initialization(self): def test_sam_segmenter_process_image(self): """Test FastSAM segmenter can process video frames and return segmentation masks.""" - # Import testData inside method to avoid pytest fixture confusion - from dimos.utils.testing import testData + # Import get data inside method to avoid pytest fixture confusion + from dimos.utils.data import get_data # Get test video path directly - video_path = testData("assets") / "trimmed_video_office.mov" + video_path = get_data("assets") / "trimmed_video_office.mov" try: # Initialize segmenter without analyzer for faster testing segmenter = Sam2DSegmenter(use_analyzer=False) diff --git a/dimos/perception/test_spatial_memory.py b/dimos/perception/test_spatial_memory.py index a6f4fcfa69..ba63917d9b 100644 --- a/dimos/perception/test_spatial_memory.py +++ b/dimos/perception/test_spatial_memory.py @@ -13,19 +13,19 @@ # limitations under the License. import os -import time +import shutil import tempfile -import pytest -import numpy as np +import time + import cv2 -import shutil +import numpy as np +import pytest import reactivex as rx +from reactivex import Observable from reactivex import operators as ops from reactivex.subject import Subject -from reactivex import Observable from dimos.perception.spatial_perception import SpatialMemory -from dimos.types.position import Position from dimos.stream.video_provider import VideoProvider from dimos.types.position import Position from dimos.types.vector import Vector @@ -101,9 +101,9 @@ def test_spatial_memory_processing(self, temp_dir): min_time_threshold=0.01, ) - from dimos.utils.testing import testData + from dimos.utils.data import get_data - video_path = testData("assets") / "trimmed_video_office.mov" + video_path = get_data("assets") / "trimmed_video_office.mov" assert os.path.exists(video_path), f"Test video not found: {video_path}" video_provider = VideoProvider(dev_name="test_video", video_source=video_path) video_stream = video_provider.capture_video_as_observable(realtime=False, fps=15) diff --git a/dimos/utils/data.py b/dimos/utils/data.py new file mode 100644 index 0000000000..3196b48a1c --- /dev/null +++ b/dimos/utils/data.py @@ -0,0 +1,163 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import glob +import os +import pickle +import subprocess +import tarfile +from functools import cache +from pathlib import Path +from typing import Any, Callable, Generic, Iterator, Optional, Type, TypeVar, Union + +from reactivex import from_iterable, interval +from reactivex import operators as ops +from reactivex.observable import Observable + + +@cache +def _get_repo_root() -> Path: + try: + result = subprocess.run( + ["git", "rev-parse", "--show-toplevel"], capture_output=True, check=True, text=True + ) + return Path(result.stdout.strip()) + except subprocess.CalledProcessError: + raise RuntimeError("Not in a Git repository") + + +@cache +def _get_data_dir() -> Path: + return _get_repo_root() / "data" + + +@cache +def _get_lfs_dir() -> Path: + return _get_data_dir() / ".lfs" + + +def _check_git_lfs_available() -> None: + try: + subprocess.run(["git", "lfs", "version"], capture_output=True, check=True, text=True) + except (subprocess.CalledProcessError, FileNotFoundError): + raise RuntimeError( + "Git LFS is not installed. Please install git-lfs to use test data utilities.\n" + "Installation instructions: https://git-lfs.github.io/" + ) + return True + + +def _is_lfs_pointer_file(file_path: Path) -> bool: + try: + # LFS pointer files are small (typically < 200 bytes) and start with specific text + if file_path.stat().st_size > 1024: # LFS pointers are much smaller + return False + + with open(file_path, "r", encoding="utf-8") as f: + first_line = f.readline().strip() + return first_line.startswith("version https://git-lfs.github.com/spec/") + + except (UnicodeDecodeError, OSError): + return False + + +def _lfs_pull(file_path: Path, repo_root: Path) -> None: + try: + relative_path = file_path.relative_to(repo_root) + + subprocess.run( + ["git", "lfs", "pull", "--include", str(relative_path)], + cwd=repo_root, + check=True, + capture_output=True, + ) + except subprocess.CalledProcessError as e: + raise RuntimeError(f"Failed to pull LFS file {file_path}: {e}") + + +def _decompress_archive(filename: Union[str, Path]) -> Path: + target_dir = _get_data_dir() + filename_path = Path(filename) + with tarfile.open(filename_path, "r:gz") as tar: + tar.extractall(target_dir) + return target_dir / filename_path.name.replace(".tar.gz", "") + + +def _pull_lfs_archive(filename: Union[str, Path]) -> Path: + # Check Git LFS availability first + _check_git_lfs_available() + + # Find repository root + repo_root = _get_repo_root() + + # Construct path to test data file + file_path = _get_lfs_dir() / (filename + ".tar.gz") + + # Check if file exists + if not file_path.exists(): + raise FileNotFoundError( + f"Test file '{filename}' not found at {file_path}. " + f"Make sure the file is committed to Git LFS in the tests/data directory." + ) + + # If it's an LFS pointer file, ensure LFS is set up and pull the file + if _is_lfs_pointer_file(file_path): + _lfs_pull(file_path, repo_root) + + # Verify the file was actually downloaded + if _is_lfs_pointer_file(file_path): + raise RuntimeError( + f"Failed to download LFS file '{filename}'. The file is still a pointer after attempting to pull." + ) + + return file_path + + +def get_data(filename: Union[str, Path]) -> Path: + """ + Get the path to a test data, downloading from LFS if needed. + + This function will: + 1. Check that Git LFS is available + 2. Locate the file in the tests/data directory + 3. Initialize Git LFS if needed + 4. Download the file from LFS if it's a pointer file + 5. Return the Path object to the actual file or dir + + Args: + filename: Name of the test file (e.g., "lidar_sample.bin") + + Returns: + Path: Path object to the test file + + Raises: + RuntimeError: If Git LFS is not available or LFS operations fail + FileNotFoundError: If the test file doesn't exist + + Usage: + # As string path + file_path = str(testFile("sample.bin")) + + # As context manager for file operations + with testFile("sample.bin").open('rb') as f: + data = f.read() + """ + data_dir = _get_data_dir() + file_path = data_dir / filename + + # already pulled and decompressed, return it directly + if file_path.exists(): + return file_path + + return _decompress_archive(_pull_lfs_archive(filename)) diff --git a/dimos/utils/test_data.py b/dimos/utils/test_data.py new file mode 100644 index 0000000000..8e870762ca --- /dev/null +++ b/dimos/utils/test_data.py @@ -0,0 +1,126 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import hashlib +import os +import subprocess + +from dimos.robot.unitree_webrtc.type.lidar import LidarMessage +from dimos.utils import data + + +def test_pull_file(): + repo_root = data._get_repo_root() + test_file_name = "cafe.jpg" + test_file_compressed = data._get_lfs_dir() / (test_file_name + ".tar.gz") + test_file_decompressed = data._get_data_dir() / test_file_name + + # delete decompressed test file if it exists + if test_file_decompressed.exists(): + test_file_decompressed.unlink() + + # delete lfs archive file if it exists + if test_file_compressed.exists(): + test_file_compressed.unlink() + + assert not test_file_compressed.exists() + assert not test_file_decompressed.exists() + + # pull the lfs file reference from git + env = os.environ.copy() + env["GIT_LFS_SKIP_SMUDGE"] = "1" + subprocess.run( + ["git", "checkout", "HEAD", "--", test_file_compressed], + cwd=repo_root, + env=env, + check=True, + capture_output=True, + ) + + # ensure we have a pointer file from git (small ASCII text file) + assert test_file_compressed.exists() + assert test_file_compressed.stat().st_size < 200 + + # trigger a data file pull + assert data.get_data(test_file_name) == test_file_decompressed + + # validate data is received + assert test_file_compressed.exists() + assert test_file_decompressed.exists() + + # validate hashes + with test_file_compressed.open("rb") as f: + assert test_file_compressed.stat().st_size > 200 + compressed_sha256 = hashlib.sha256(f.read()).hexdigest() + assert ( + compressed_sha256 == "b8cf30439b41033ccb04b09b9fc8388d18fb544d55b85c155dbf85700b9e7603" + ) + + with test_file_decompressed.open("rb") as f: + decompressed_sha256 = hashlib.sha256(f.read()).hexdigest() + assert ( + decompressed_sha256 + == "55d451dde49b05e3ad386fdd4ae9e9378884b8905bff1ca8aaea7d039ff42ddd" + ) + + +def test_pull_dir(): + repo_root = data._get_repo_root() + test_dir_name = "ab_lidar_frames" + test_dir_compressed = data._get_lfs_dir() / (test_dir_name + ".tar.gz") + test_dir_decompressed = data._get_data_dir() / test_dir_name + + # delete decompressed test directory if it exists + if test_dir_decompressed.exists(): + for item in test_dir_decompressed.iterdir(): + item.unlink() + test_dir_decompressed.rmdir() + + # delete lfs archive file if it exists + if test_dir_compressed.exists(): + test_dir_compressed.unlink() + + # pull the lfs file reference from git + env = os.environ.copy() + env["GIT_LFS_SKIP_SMUDGE"] = "1" + subprocess.run( + ["git", "checkout", "HEAD", "--", test_dir_compressed], + cwd=repo_root, + env=env, + check=True, + capture_output=True, + ) + + # ensure we have a pointer file from git (small ASCII text file) + assert test_dir_compressed.exists() + assert test_dir_compressed.stat().st_size < 200 + + # trigger a data file pull + assert data.get_data(test_dir_name) == test_dir_decompressed + assert test_dir_compressed.stat().st_size > 200 + + # validate data is received + assert test_dir_compressed.exists() + assert test_dir_decompressed.exists() + + for [file, expected_hash] in zip( + sorted(test_dir_decompressed.iterdir()), + [ + "6c3aaa9a79853ea4a7453c7db22820980ceb55035777f7460d05a0fa77b3b1b3", + "456cc2c23f4ffa713b4e0c0d97143c27e48bbe6ef44341197b31ce84b3650e74", + ], + ): + with file.open("rb") as f: + sha256 = hashlib.sha256(f.read()).hexdigest() + assert sha256 == expected_hash diff --git a/dimos/utils/test_testing.py b/dimos/utils/test_testing.py index 788acd0d67..092a269862 100644 --- a/dimos/utils/test_testing.py +++ b/dimos/utils/test_testing.py @@ -18,113 +18,6 @@ from dimos.robot.unitree_webrtc.type.lidar import LidarMessage from dimos.utils import testing -from dimos.robot.unitree_webrtc.type.lidar import LidarMessage - - -def test_pull_file(): - repo_root = testing._get_repo_root() - test_file_name = "cafe.jpg" - test_file_compressed = testing._get_lfs_dir() / (test_file_name + ".tar.gz") - test_file_decompressed = testing._get_data_dir() / test_file_name - - # delete decompressed test file if it exists - if test_file_decompressed.exists(): - test_file_decompressed.unlink() - - # delete lfs archive file if it exists - if test_file_compressed.exists(): - test_file_compressed.unlink() - - assert not test_file_compressed.exists() - assert not test_file_decompressed.exists() - - # pull the lfs file reference from git - env = os.environ.copy() - env["GIT_LFS_SKIP_SMUDGE"] = "1" - subprocess.run( - ["git", "checkout", "HEAD", "--", test_file_compressed], - cwd=repo_root, - env=env, - check=True, - capture_output=True, - ) - - # ensure we have a pointer file from git (small ASCII text file) - assert test_file_compressed.exists() - assert test_file_compressed.stat().st_size < 200 - - # trigger a data file pull - assert testing.testData(test_file_name) == test_file_decompressed - - # validate data is received - assert test_file_compressed.exists() - assert test_file_decompressed.exists() - - # validate hashes - with test_file_compressed.open("rb") as f: - assert test_file_compressed.stat().st_size > 200 - compressed_sha256 = hashlib.sha256(f.read()).hexdigest() - assert ( - compressed_sha256 == "b8cf30439b41033ccb04b09b9fc8388d18fb544d55b85c155dbf85700b9e7603" - ) - - with test_file_decompressed.open("rb") as f: - decompressed_sha256 = hashlib.sha256(f.read()).hexdigest() - assert ( - decompressed_sha256 - == "55d451dde49b05e3ad386fdd4ae9e9378884b8905bff1ca8aaea7d039ff42ddd" - ) - - -def test_pull_dir(): - repo_root = testing._get_repo_root() - test_dir_name = "ab_lidar_frames" - test_dir_compressed = testing._get_lfs_dir() / (test_dir_name + ".tar.gz") - test_dir_decompressed = testing._get_data_dir() / test_dir_name - - # delete decompressed test directory if it exists - if test_dir_decompressed.exists(): - for item in test_dir_decompressed.iterdir(): - item.unlink() - test_dir_decompressed.rmdir() - - # delete lfs archive file if it exists - if test_dir_compressed.exists(): - test_dir_compressed.unlink() - - # pull the lfs file reference from git - env = os.environ.copy() - env["GIT_LFS_SKIP_SMUDGE"] = "1" - subprocess.run( - ["git", "checkout", "HEAD", "--", test_dir_compressed], - cwd=repo_root, - env=env, - check=True, - capture_output=True, - ) - - # ensure we have a pointer file from git (small ASCII text file) - assert test_dir_compressed.exists() - assert test_dir_compressed.stat().st_size < 200 - - # trigger a data file pull - assert testing.testData(test_dir_name) == test_dir_decompressed - assert test_dir_compressed.stat().st_size > 200 - - # validate data is received - assert test_dir_compressed.exists() - assert test_dir_decompressed.exists() - - for [file, expected_hash] in zip( - sorted(test_dir_decompressed.iterdir()), - [ - "6c3aaa9a79853ea4a7453c7db22820980ceb55035777f7460d05a0fa77b3b1b3", - "456cc2c23f4ffa713b4e0c0d97143c27e48bbe6ef44341197b31ce84b3650e74", - ], - ): - with file.open("rb") as f: - sha256 = hashlib.sha256(f.read()).hexdigest() - assert sha256 == expected_hash def test_sensor_replay(): diff --git a/dimos/utils/testing.py b/dimos/utils/testing.py index b64fcad397..c9e92bd006 100644 --- a/dimos/utils/testing.py +++ b/dimos/utils/testing.py @@ -12,156 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -import subprocess -import tarfile import glob import os import pickle +import subprocess +import tarfile from functools import cache from pathlib import Path -from typing import Union, Iterator, TypeVar, Generic, Optional, Any, Type, Callable +from typing import Any, Callable, Generic, Iterator, Optional, Type, TypeVar, Union +from reactivex import from_iterable, interval from reactivex import operators as ops -from reactivex import interval, from_iterable from reactivex.observable import Observable - -def _check_git_lfs_available() -> None: - try: - subprocess.run(["git", "lfs", "version"], capture_output=True, check=True, text=True) - except (subprocess.CalledProcessError, FileNotFoundError): - raise RuntimeError( - "Git LFS is not installed. Please install git-lfs to use test data utilities.\n" - "Installation instructions: https://git-lfs.github.io/" - ) - return True - - -@cache -def _get_repo_root() -> Path: - try: - result = subprocess.run( - ["git", "rev-parse", "--show-toplevel"], capture_output=True, check=True, text=True - ) - return Path(result.stdout.strip()) - except subprocess.CalledProcessError: - raise RuntimeError("Not in a Git repository") - - -@cache -def _get_data_dir() -> Path: - return _get_repo_root() / "tests" / "data" - - -@cache -def _get_lfs_dir() -> Path: - return _get_data_dir() / ".lfs" - - -def _is_lfs_pointer_file(file_path: Path) -> bool: - try: - # LFS pointer files are small (typically < 200 bytes) and start with specific text - if file_path.stat().st_size > 1024: # LFS pointers are much smaller - return False - - with open(file_path, "r", encoding="utf-8") as f: - first_line = f.readline().strip() - return first_line.startswith("version https://git-lfs.github.com/spec/") - - except (UnicodeDecodeError, OSError): - return False - - -def _lfs_pull(file_path: Path, repo_root: Path) -> None: - try: - relative_path = file_path.relative_to(repo_root) - - subprocess.run( - ["git", "lfs", "pull", "--include", str(relative_path)], - cwd=repo_root, - check=True, - capture_output=True, - ) - except subprocess.CalledProcessError as e: - raise RuntimeError(f"Failed to pull LFS file {file_path}: {e}") - - -def _pull_lfs_archive(filename: Union[str, Path]) -> Path: - # Check Git LFS availability first - _check_git_lfs_available() - - # Find repository root - repo_root = _get_repo_root() - - # Construct path to test data file - file_path = _get_lfs_dir() / (filename + ".tar.gz") - - # Check if file exists - if not file_path.exists(): - raise FileNotFoundError( - f"Test file '{filename}' not found at {file_path}. " - f"Make sure the file is committed to Git LFS in the tests/data directory." - ) - - # If it's an LFS pointer file, ensure LFS is set up and pull the file - if _is_lfs_pointer_file(file_path): - _lfs_pull(file_path, repo_root) - - # Verify the file was actually downloaded - if _is_lfs_pointer_file(file_path): - raise RuntimeError( - f"Failed to download LFS file '{filename}'. The file is still a pointer after attempting to pull." - ) - - return file_path - - -def _decompress_archive(filename: Union[str, Path]) -> Path: - target_dir = _get_data_dir() - filename_path = Path(filename) - with tarfile.open(filename_path, "r:gz") as tar: - tar.extractall(target_dir) - return target_dir / filename_path.name.replace(".tar.gz", "") - - -def testData(filename: Union[str, Path]) -> Path: - """ - Get the path to a test data, downloading from LFS if needed. - - This function will: - 1. Check that Git LFS is available - 2. Locate the file in the tests/data directory - 3. Initialize Git LFS if needed - 4. Download the file from LFS if it's a pointer file - 5. Return the Path object to the actual file or dir - - Args: - filename: Name of the test file (e.g., "lidar_sample.bin") - - Returns: - Path: Path object to the test file - - Raises: - RuntimeError: If Git LFS is not available or LFS operations fail - FileNotFoundError: If the test file doesn't exist - - Usage: - # As string path - file_path = str(testFile("sample.bin")) - - # As context manager for file operations - with testFile("sample.bin").open('rb') as f: - data = f.read() - """ - data_dir = _get_data_dir() - file_path = data_dir / filename - - # already pulled and decompressed, return it directly - if file_path.exists(): - return file_path - - return _decompress_archive(_pull_lfs_archive(filename)) - +from dimos.utils.data import _get_data_dir, get_data T = TypeVar("T") @@ -176,7 +40,7 @@ class SensorReplay(Generic[T]): """ def __init__(self, name: str, autocast: Optional[Callable[[Any], T]] = None): - self.root_dir = testData(name) + self.root_dir = get_data(name) self.autocast = autocast def load(self, *names: Union[int, str]) -> Union[T, Any, list[T], list[Any]]: