From 6cd48bfa765fd034fe1aa0f1fdb05f9a8ab06670 Mon Sep 17 00:00:00 2001 From: lesh Date: Tue, 27 May 2025 22:02:58 +0300 Subject: [PATCH 01/42] json cleanup, tests reformat --- .gitignore | 4 +- dimos/utils/test_testing.py | 9 +- dimos/utils/testing.py | 9 +- dimos/web/dimos_interface/tsconfig.json | 22 +- dimos/web/dimos_interface/tsconfig.node.json | 4 +- dimos/web/websocket_vis/deno.json | 42 ++-- tests/agent_manip_flow_fastapi_test.py | 31 ++- tests/agent_manip_flow_flask_test.py | 39 ++-- tests/agent_memory_test.py | 2 +- tests/genesissim/stream_camera.py | 31 +-- tests/isaacsim/stream_camera.py | 12 +- tests/run.py | 57 +++-- tests/run_go2_ros.py | 52 ++--- tests/simple_agent_test.py | 18 +- tests/test_agent.py | 4 + tests/test_agent_alibaba.py | 28 +-- tests/test_agent_ctransformers_gguf.py | 4 +- tests/test_agent_huggingface_local.py | 10 +- tests/test_agent_huggingface_local_jetson.py | 8 +- tests/test_agent_huggingface_remote.py | 24 +- tests/test_audio_agent.py | 1 - tests/test_claude_agent_query.py | 7 +- tests/test_claude_agent_skills_query.py | 25 ++- tests/test_command_pose_unitree.py | 13 +- tests/test_header.py | 12 +- tests/test_huggingface_llm_agent.py | 14 +- tests/test_move_vel_unitree.py | 9 +- tests/test_navigate_to_object_robot.py | 79 ++++--- tests/test_navigation_skills.py | 64 +++--- ...bject_detection_agent_data_query_stream.py | 78 ++++--- tests/test_object_detection_stream.py | 115 +++++----- tests/test_object_tracking_webcam.py | 110 ++++++---- tests/test_object_tracking_with_qwen.py | 86 +++++--- tests/test_observe_stream_skill.py | 71 +++--- tests/test_person_following_robot.py | 53 ++--- tests/test_person_following_webcam.py | 117 +++++----- tests/test_planning_agent_web_interface.py | 54 +++-- tests/test_planning_robot_agent.py | 32 +-- tests/test_qwen_image_query.py | 18 +- tests/test_robot.py | 36 ++- tests/test_rtsp_video_provider.py | 53 ++--- tests/test_semantic_seg_robot.py | 53 +++-- tests/test_semantic_seg_robot_agent.py | 
68 +++--- tests/test_semantic_seg_webcam.py | 57 ++--- tests/test_skills.py | 61 +++--- tests/test_skills_rest.py | 15 +- tests/test_spatial_memory.py | 180 ++++++++------- tests/test_spatial_memory_query.py | 207 ++++++++++-------- tests/test_standalone_chromadb.py | 12 +- tests/test_standalone_fastapi.py | 28 ++- tests/test_standalone_hugging_face.py | 16 +- tests/test_standalone_openai_json.py | 26 ++- tests/test_standalone_openai_json_struct.py | 14 +- ...test_standalone_openai_json_struct_func.py | 69 +++--- ...lone_openai_json_struct_func_playground.py | 24 +- tests/test_standalone_project_out.py | 138 ++++++------ tests/test_standalone_rxpy_01.py | 19 +- tests/test_unitree_agent.py | 38 ++-- tests/test_unitree_agent_queries_fastapi.py | 27 ++- tests/test_webrtc_queue.py | 72 +++--- tests/test_websocketvis.py | 28 +-- 61 files changed, 1408 insertions(+), 1201 deletions(-) diff --git a/.gitignore b/.gitignore index 47d636fc28..2439af14ed 100644 --- a/.gitignore +++ b/.gitignore @@ -22,5 +22,7 @@ assets/agent/memory.txt tests/data/* !tests/data/.lfs/ -# node modules (for dev tooling) +# node env (used by devcontainers cli) node_modules +package.json +package-lock.json diff --git a/dimos/utils/test_testing.py b/dimos/utils/test_testing.py index e134492952..8952782168 100644 --- a/dimos/utils/test_testing.py +++ b/dimos/utils/test_testing.py @@ -43,11 +43,16 @@ def test_pull_file(): # validate hashes with test_file_compressed.open("rb") as f: compressed_sha256 = hashlib.sha256(f.read()).hexdigest() - assert compressed_sha256 == "cdfd708d66e6dd5072ed7636fc10fb97754f8d14e3acd6c3553663e27fc96065" + assert ( + compressed_sha256 == "cdfd708d66e6dd5072ed7636fc10fb97754f8d14e3acd6c3553663e27fc96065" + ) with test_file_decompressed.open("rb") as f: decompressed_sha256 = hashlib.sha256(f.read()).hexdigest() - assert decompressed_sha256 == "55d451dde49b05e3ad386fdd4ae9e9378884b8905bff1ca8aaea7d039ff42ddd" + assert ( + decompressed_sha256 + == 
"55d451dde49b05e3ad386fdd4ae9e9378884b8905bff1ca8aaea7d039ff42ddd" + ) def test_pull_dir(): diff --git a/dimos/utils/testing.py b/dimos/utils/testing.py index a37e1804b9..53b9849718 100644 --- a/dimos/utils/testing.py +++ b/dimos/utils/testing.py @@ -19,7 +19,9 @@ def _check_git_lfs_available() -> None: @cache def _get_repo_root() -> Path: try: - result = subprocess.run(["git", "rev-parse", "--show-toplevel"], capture_output=True, check=True, text=True) + result = subprocess.run( + ["git", "rev-parse", "--show-toplevel"], capture_output=True, check=True, text=True + ) return Path(result.stdout.strip()) except subprocess.CalledProcessError: raise RuntimeError("Not in a Git repository") @@ -54,7 +56,10 @@ def _lfs_pull(file_path: Path, repo_root: Path) -> None: relative_path = file_path.relative_to(repo_root) subprocess.run( - ["git", "lfs", "pull", "--include", str(relative_path)], cwd=repo_root, check=True, capture_output=True + ["git", "lfs", "pull", "--include", str(relative_path)], + cwd=repo_root, + check=True, + capture_output=True, ) except subprocess.CalledProcessError as e: raise RuntimeError(f"Failed to pull LFS file {file_path}: {e}") diff --git a/dimos/web/dimos_interface/tsconfig.json b/dimos/web/dimos_interface/tsconfig.json index 772ce46b79..4bf29f39d2 100644 --- a/dimos/web/dimos_interface/tsconfig.json +++ b/dimos/web/dimos_interface/tsconfig.json @@ -5,17 +5,21 @@ "useDefineForClassFields": true, "module": "ESNext", "resolveJsonModule": true, - /** - * Typecheck JS in `.svelte` and `.js` files by default. - * Disable checkJs if you'd like to use dynamic types in JS. - * Note that setting allowJs false does not prevent the use - * of JS in `.svelte` files. 
- */ "allowJs": true, "checkJs": true, "isolatedModules": true, - "types": ["node"] + "types": [ + "node" + ] }, - "include": ["src/**/*.ts", "src/**/*.js", "src/**/*.svelte"], - "references": [{ "path": "./tsconfig.node.json" }] + "include": [ + "src/**/*.ts", + "src/**/*.js", + "src/**/*.svelte" + ], + "references": [ + { + "path": "./tsconfig.node.json" + } + ] } diff --git a/dimos/web/dimos_interface/tsconfig.node.json b/dimos/web/dimos_interface/tsconfig.node.json index 494bfe0835..ad883d0eb4 100644 --- a/dimos/web/dimos_interface/tsconfig.node.json +++ b/dimos/web/dimos_interface/tsconfig.node.json @@ -5,5 +5,7 @@ "module": "ESNext", "moduleResolution": "bundler" }, - "include": ["vite.config.ts"] + "include": [ + "vite.config.ts" + ] } diff --git a/dimos/web/websocket_vis/deno.json b/dimos/web/websocket_vis/deno.json index 401e578474..453c7a4a80 100644 --- a/dimos/web/websocket_vis/deno.json +++ b/dimos/web/websocket_vis/deno.json @@ -1,22 +1,26 @@ { - "nodeModulesDir": "auto", - "tasks": { - "build": "deno run -A build.ts", - "watch": "deno run -A build.ts --watch" - }, - "lint": { - "rules": { - "exclude": ["require-await", "ban-ts-comment"] - } - }, - "fmt": { - "indentWidth": 4, - "useTabs": false, - "semiColons": false - }, - - "compilerOptions": { - "lib": ["dom", "deno.ns"] + "nodeModulesDir": "auto", + "tasks": { + "build": "deno run -A build.ts", + "watch": "deno run -A build.ts --watch" + }, + "lint": { + "rules": { + "exclude": [ + "require-await", + "ban-ts-comment" + ] } + }, + "fmt": { + "indentWidth": 4, + "useTabs": false, + "semiColons": false + }, + "compilerOptions": { + "lib": [ + "dom", + "deno.ns" + ] + } } - diff --git a/tests/agent_manip_flow_fastapi_test.py b/tests/agent_manip_flow_fastapi_test.py index d9bc9ef2f9..d802dd5663 100644 --- a/tests/agent_manip_flow_fastapi_test.py +++ b/tests/agent_manip_flow_fastapi_test.py @@ -19,7 +19,7 @@ from reactivex.scheduler import ThreadPoolScheduler, CurrentThreadScheduler, ImmediateScheduler 
# Local application imports -from dimos.agents.agent import OpenAIAgent +from dimos.agents.agent import OpenAIAgent from dimos.stream.frame_processor import FrameProcessor from dimos.stream.video_operators import VideoOperators as vops from dimos.stream.video_provider import VideoProvider @@ -28,6 +28,7 @@ # Load environment variables load_dotenv() + def main(): """ Initializes and runs the video processing pipeline with web server output. @@ -42,7 +43,9 @@ def main(): """ disposables = CompositeDisposable() - processor = FrameProcessor(output_dir=f"{os.getcwd()}/assets/output/frames", delete_on_init=True) + processor = FrameProcessor( + output_dir=f"{os.getcwd()}/assets/output/frames", delete_on_init=True + ) optimal_thread_count = multiprocessing.cpu_count() # Gets number of CPU cores thread_pool_scheduler = ThreadPoolScheduler(optimal_thread_count) @@ -53,14 +56,16 @@ def main(): f"{os.getcwd()}/assets/trimmed_video_480p.mov", f"{os.getcwd()}/assets/video-f30-480p.mp4", "rtsp://192.168.50.207:8080/h264.sdp", - "rtsp://10.0.0.106:8080/h264.sdp" + "rtsp://10.0.0.106:8080/h264.sdp", ] VIDEO_SOURCE_INDEX = 3 VIDEO_SOURCE_INDEX_2 = 2 my_video_provider = VideoProvider("Video File", video_source=VIDEO_SOURCES[VIDEO_SOURCE_INDEX]) - my_video_provider_2 = VideoProvider("Video File 2", video_source=VIDEO_SOURCES[VIDEO_SOURCE_INDEX_2]) + my_video_provider_2 = VideoProvider( + "Video File 2", video_source=VIDEO_SOURCES[VIDEO_SOURCE_INDEX_2] + ) video_stream_obs = my_video_provider.capture_video_as_observable(fps=120).pipe( ops.subscribe_on(thread_pool_scheduler), @@ -86,37 +91,39 @@ def main(): vops.with_jpeg_export(processor, suffix="edge"), ) - optical_flow_relevancy_stream_obs = processor.process_stream_optical_flow_with_relevancy(video_stream_obs) + optical_flow_relevancy_stream_obs = processor.process_stream_optical_flow_with_relevancy( + video_stream_obs + ) optical_flow_stream_obs = optical_flow_relevancy_stream_obs.pipe( ops.do_action(lambda result: print(f"Optical 
Flow Relevancy Score: {result[1]}")), vops.with_optical_flow_filtering(threshold=2.0), ops.do_action(lambda _: print(f"Optical Flow Passed Threshold.")), - vops.with_jpeg_export(processor, suffix="optical") + vops.with_jpeg_export(processor, suffix="optical"), ) # - # ====== Agent Orchastrator (Qu.s Awareness, Temporality, Routing) ====== + # ====== Agent Orchastrator (Qu.s Awareness, Temporality, Routing) ====== # # Agent 1 # my_agent = OpenAIAgent( - # "Agent 1", + # "Agent 1", # query="You are a robot. What do you see? Put a JSON with objects of what you see in the format {object, description}.") # my_agent.subscribe_to_image_processing(slowed_video_stream_obs) # disposables.add(my_agent.disposables) # # Agent 2 # my_agent_two = OpenAIAgent( - # "Agent 2", + # "Agent 2", # query="This is a visualization of dense optical flow. What movement(s) have occured? Put a JSON with mapped directions you see in the format {direction, probability, english_description}.") # my_agent_two.subscribe_to_image_processing(optical_flow_stream_obs) # disposables.add(my_agent_two.disposables) # - # ====== Create and start the FastAPI server ====== + # ====== Create and start the FastAPI server ====== # - + # Will be visible at http://[host]:[port]/video_feed/[key] streams = { "video_one": video_stream_obs, @@ -127,6 +134,6 @@ def main(): fast_api_server = FastAPIServer(port=5555, **streams) fast_api_server.run() + if __name__ == "__main__": main() - diff --git a/tests/agent_manip_flow_flask_test.py b/tests/agent_manip_flow_flask_test.py index e7cf0cfb61..aecf4049a5 100644 --- a/tests/agent_manip_flow_flask_test.py +++ b/tests/agent_manip_flow_flask_test.py @@ -20,7 +20,7 @@ from reactivex.scheduler import ThreadPoolScheduler, CurrentThreadScheduler, ImmediateScheduler # Local application imports -from dimos.agents.agent import PromptBuilder, OpenAIAgent +from dimos.agents.agent import PromptBuilder, OpenAIAgent from dimos.stream.frame_processor import FrameProcessor from 
dimos.stream.video_operators import VideoOperators as vops from dimos.stream.video_provider import VideoProvider @@ -31,6 +31,7 @@ app = Flask(__name__) + def main(): """ Initializes and runs the video processing pipeline with web server output. @@ -45,7 +46,9 @@ def main(): """ disposables = CompositeDisposable() - processor = FrameProcessor(output_dir=f"{os.getcwd()}/assets/output/frames", delete_on_init=True) + processor = FrameProcessor( + output_dir=f"{os.getcwd()}/assets/output/frames", delete_on_init=True + ) optimal_thread_count = multiprocessing.cpu_count() # Gets number of CPU cores thread_pool_scheduler = ThreadPoolScheduler(optimal_thread_count) @@ -58,7 +61,7 @@ def main(): f"{os.getcwd()}/assets/video.mov", "rtsp://192.168.50.207:8080/h264.sdp", "rtsp://10.0.0.106:8080/h264.sdp", - f"{os.getcwd()}/assets/people_1080p_24fps.mp4" + f"{os.getcwd()}/assets/people_1080p_24fps.mp4", ] VIDEO_SOURCE_INDEX = 4 @@ -85,28 +88,28 @@ def main(): # ops.do_action(lambda result: print(f"Optical Flow Relevancy Score: {result[1]}")), # vops.with_optical_flow_filtering(threshold=2.0), # ops.do_action(lambda _: print(f"Optical Flow Passed Threshold.")), - #vops.with_jpeg_export(processor, suffix="optical") + # vops.with_jpeg_export(processor, suffix="optical") ) # - # ====== Agent Orchastrator (Qu.s Awareness, Temporality, Routing) ====== + # ====== Agent Orchastrator (Qu.s Awareness, Temporality, Routing) ====== # # Observable that emits every 2 seconds secondly_emission = interval(2, scheduler=thread_pool_scheduler).pipe( - ops.map(lambda x: f"Second {x+1}"), + ops.map(lambda x: f"Second {x + 1}"), # ops.take(30) ) # Agent 1 my_agent = OpenAIAgent( - "Agent 1", + "Agent 1", query="You are a robot. What do you see? Put a JSON with objects of what you see in the format {object, description}.", - json_mode=False + json_mode=False, ) - - # Create an agent for each subset of questions that it would be theroized to handle. - # Set std. 
template/blueprints, and devs will add to that likely. + + # Create an agent for each subset of questions that it would be theroized to handle. + # Set std. template/blueprints, and devs will add to that likely. ai_1_obs = video_stream_obs.pipe( # vops.with_fps_sampling(fps=30), @@ -118,20 +121,20 @@ def main(): ai_1_obs.connect() ai_1_repeat_obs = ai_1_obs.pipe(ops.repeat()) - + my_agent.subscribe_to_image_processing(ai_1_obs) disposables.add(my_agent.disposables) # Agent 2 my_agent_two = OpenAIAgent( - "Agent 2", + "Agent 2", query="This is a visualization of dense optical flow. What movement(s) have occured? Put a JSON with mapped directions you see in the format {direction, probability, english_description}.", max_input_tokens_per_request=1000, max_output_tokens_per_request=300, json_mode=False, model_name="gpt-4o-2024-08-06", ) - + ai_2_obs = optical_flow_stream_obs.pipe( # vops.with_fps_sampling(fps=30), # ops.throttle_first(1), @@ -143,7 +146,6 @@ def main(): ai_2_repeat_obs = ai_2_obs.pipe(ops.repeat()) - # Combine emissions using zip ai_1_secondly_repeating_obs = zip(secondly_emission, ai_1_repeat_obs).pipe( # ops.do_action(lambda s: print(f"AI 1 - Emission Count: {s[0]}")), @@ -156,12 +158,11 @@ def main(): ops.map(lambda r: r[1]), ) - my_agent_two.subscribe_to_image_processing(ai_2_obs) disposables.add(my_agent_two.disposables) # - # ====== Create and start the Flask server ====== + # ====== Create and start the Flask server ====== # # Will be visible at http://[host]:[port]/video_feed/[key] @@ -172,9 +173,9 @@ def main(): OpenAIAgent_1=ai_1_secondly_repeating_obs, OpenAIAgent_2=ai_2_secondly_repeating_obs, ) - + flask_server.run(threaded=True) + if __name__ == "__main__": main() - diff --git a/tests/agent_memory_test.py b/tests/agent_memory_test.py index 5ebe40e5d7..e77cbc2821 100644 --- a/tests/agent_memory_test.py +++ b/tests/agent_memory_test.py @@ -44,4 +44,4 @@ results = agent_memory.query("Colors", n_results=19, similarity_threshold=0.45) 
print(results) -print("Done querying agent memory (n_results=19, similarity_threshold=0.45).") \ No newline at end of file +print("Done querying agent memory (n_results=19, similarity_threshold=0.45).") diff --git a/tests/genesissim/stream_camera.py b/tests/genesissim/stream_camera.py index c7a439232f..4a820851cf 100644 --- a/tests/genesissim/stream_camera.py +++ b/tests/genesissim/stream_camera.py @@ -1,27 +1,19 @@ import os from dimos.simulation.genesis import GenesisSimulator, GenesisStream + def main(): # Add multiple entities at once entities = [ - { - 'type': 'primitive', - 'params': {'shape': 'plane'} - }, - { - 'type': 'mjcf', - 'path': 'xml/franka_emika_panda/panda.xml' - } + {"type": "primitive", "params": {"shape": "plane"}}, + {"type": "mjcf", "path": "xml/franka_emika_panda/panda.xml"}, ] # Initialize simulator - sim = GenesisSimulator( - headless=True, - entities=entities - ) + sim = GenesisSimulator(headless=True, entities=entities) # You can also add entity individually - sim.add_entity('primitive', shape='box', size=[0.5, 0.5, 0.5], pos=[0, 1, 0.5]) - + sim.add_entity("primitive", shape="box", size=[0.5, 0.5, 0.5], pos=[0, 1, 0.5]) + # Create stream with custom settings stream = GenesisStream( simulator=sim, @@ -29,11 +21,11 @@ def main(): height=960, fps=60, camera_path="/camera", # Genesis uses simpler camera paths - annotator_type='rgb', # Can be 'rgb' or 'normals' - transport='tcp', - rtsp_url="rtsp://mediamtx:8554/stream" + annotator_type="rgb", # Can be 'rgb' or 'normals' + transport="tcp", + rtsp_url="rtsp://mediamtx:8554/stream", ) - + # Start streaming try: stream.stream() @@ -45,5 +37,6 @@ def main(): finally: sim.close() + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/tests/isaacsim/stream_camera.py b/tests/isaacsim/stream_camera.py index 31952dc671..a6eecd93d9 100644 --- a/tests/isaacsim/stream_camera.py +++ b/tests/isaacsim/stream_camera.py @@ -2,10 +2,11 @@ from dimos.simulation.isaac import 
IsaacSimulator from dimos.simulation.isaac import IsaacStream + def main(): # Initialize simulator sim = IsaacSimulator(headless=True) - + # Create stream with custom settings stream = IsaacStream( simulator=sim, @@ -13,14 +14,15 @@ def main(): height=1080, fps=60, camera_path="/World/alfred_parent_prim/alfred_base_descr/chest_cam_rgb_camera_frame/chest_cam", - annotator_type='rgb', - transport='tcp', + annotator_type="rgb", + transport="tcp", rtsp_url="rtsp://mediamtx:8554/stream", - usd_path=f"{os.getcwd()}/assets/TestSim3.usda" + usd_path=f"{os.getcwd()}/assets/TestSim3.usda", ) - + # Start streaming stream.stream() + if __name__ == "__main__": main() diff --git a/tests/run.py b/tests/run.py index d62c3a1103..4c4bfc036e 100644 --- a/tests/run.py +++ b/tests/run.py @@ -43,21 +43,31 @@ # Allow command line arguments to control spatial memory parameters import argparse + def parse_arguments(): - parser = argparse.ArgumentParser(description='Run the robot with optional spatial memory parameters') - parser.add_argument('--new-memory', action='store_true', help='Create a new spatial memory from scratch') - parser.add_argument('--spatial-memory-dir', type=str, help='Directory for storing spatial memory data') + parser = argparse.ArgumentParser( + description="Run the robot with optional spatial memory parameters" + ) + parser.add_argument( + "--new-memory", action="store_true", help="Create a new spatial memory from scratch" + ) + parser.add_argument( + "--spatial-memory-dir", type=str, help="Directory for storing spatial memory data" + ) return parser.parse_args() + args = parse_arguments() # Initialize robot with spatial memory parameters -robot = UnitreeGo2(ip=os.getenv('ROBOT_IP'), - skills=MyUnitreeSkills(), - mock_connection=False, - spatial_memory_dir=args.spatial_memory_dir, # Will use default if None - new_memory=args.new_memory, # Create a new memory if specified - mode = "ai") +robot = UnitreeGo2( + ip=os.getenv("ROBOT_IP"), + skills=MyUnitreeSkills(), + 
mock_connection=False, + spatial_memory_dir=args.spatial_memory_dir, # Will use default if None + new_memory=args.new_memory, # Create a new memory if specified + mode="ai", +) # Create a subject for agent responses agent_response_subject = rx.subject.Subject() @@ -79,7 +89,7 @@ def parse_arguments(): class_filter=class_filter, transform_to_map=robot.ros_control.transform_pose, detector=detector, - video_stream=video_stream + video_stream=video_stream, ) # Create visualization stream for web interface @@ -94,12 +104,13 @@ def parse_arguments(): ops.filter(lambda x: x is not None) ) + # Create a direct mapping that combines detection data with locations def combine_with_locations(object_detections): # Get locations from spatial memory try: locations = robot.get_spatial_memory().get_robot_locations() - + # Format the locations section locations_text = "\n\nSaved Robot Locations:\n" if locations: @@ -108,22 +119,22 @@ def combine_with_locations(object_detections): locations_text += f"Rotation ({loc.rotation[0]:.2f}, {loc.rotation[1]:.2f}, {loc.rotation[2]:.2f})\n" else: locations_text += "None\n" - + # Simply concatenate the strings return object_detections + locations_text except Exception as e: print(f"Error adding locations: {e}") return object_detections + # Create the combined stream with a simple pipe operation -enhanced_data_stream = formatted_detection_stream.pipe( - ops.map(combine_with_locations), - ops.share() -) +enhanced_data_stream = formatted_detection_stream.pipe(ops.map(combine_with_locations), ops.share()) -streams = {"unitree_video": robot.get_ros_video_stream(), - "local_planner_viz": local_planner_viz_stream, - "object_detection": viz_stream} +streams = { + "unitree_video": robot.get_ros_video_stream(), + "local_planner_viz": local_planner_viz_stream, + "object_detection": viz_stream, +} text_streams = { "agent_responses": agent_response_stream, } @@ -145,7 +156,7 @@ def combine_with_locations(object_detections): skills=robot.get_skills(), 
system_query="What do you see", model_name="claude-3-7-sonnet-latest", - thinking_budget_tokens=0 + thinking_budget_tokens=0, ) # tts_node = tts() @@ -168,11 +179,9 @@ def combine_with_locations(object_detections): robot_skills.create_instance("Speak", tts_node=tts_node) # Subscribe to agent responses and send them to the subject -agent.get_response_observable().subscribe( - lambda x: agent_response_subject.on_next(x) -) +agent.get_response_observable().subscribe(lambda x: agent_response_subject.on_next(x)) print("ObserveStream and Kill skills registered and ready for use") print("Created memory.txt file") -web_interface.run() \ No newline at end of file +web_interface.run() diff --git a/tests/run_go2_ros.py b/tests/run_go2_ros.py index 459fa54b0c..2fcc74dcb1 100644 --- a/tests/run_go2_ros.py +++ b/tests/run_go2_ros.py @@ -6,10 +6,11 @@ from dimos.robot.unitree.unitree_go2 import UnitreeGo2, WebRTCConnectionMethod from dimos.robot.unitree.unitree_ros_control import UnitreeROSControl + def get_env_var(var_name, default=None, required=False): """Get environment variable with validation.""" value = os.getenv(var_name, default) - if value == '': + if value == "": value = default if required and not value: raise ValueError(f"{var_name} environment variable is required") @@ -21,8 +22,7 @@ def get_env_var(var_name, default=None, required=False): robot_ip = get_env_var("ROBOT_IP") connection_method = get_env_var("CONNECTION_METHOD", "LocalSTA") serial_number = get_env_var("SERIAL_NUMBER", None) - output_dir = get_env_var("ROS_OUTPUT_DIR", - os.path.join(os.getcwd(), "assets/output/ros")) + output_dir = get_env_var("ROS_OUTPUT_DIR", os.path.join(os.getcwd(), "assets/output/ros")) # Ensure output directory exists os.makedirs(output_dir, exist_ok=True) @@ -45,16 +45,17 @@ def get_env_var(var_name, default=None, required=False): else: ros_control = None - robot = UnitreeGo2(ip=robot_ip, - connection_method=connection_method, - serial_number=serial_number, - 
output_dir=output_dir, - ros_control=ros_control, - use_ros=use_ros, - use_webrtc=use_webrtc) + robot = UnitreeGo2( + ip=robot_ip, + connection_method=connection_method, + serial_number=serial_number, + output_dir=output_dir, + ros_control=ros_control, + use_ros=use_ros, + use_webrtc=use_webrtc, + ) time.sleep(5) try: - # Start perception print("\nStarting perception system...") @@ -72,14 +73,15 @@ def handle_frame(frame): try: # Save frame to output directory if desired for debugging frame streaming # MAKE SURE TO CHANGE OUTPUT DIR depending on if running in ROS or local - #frame_path = os.path.join(output_dir, f"frame_{frame_count:04d}.jpg") - #success = cv2.imwrite(frame_path, frame) - #print(f"Frame #{frame_count} {'saved successfully' if success else 'failed to save'} to {frame_path}") + # frame_path = os.path.join(output_dir, f"frame_{frame_count:04d}.jpg") + # success = cv2.imwrite(frame_path, frame) + # print(f"Frame #{frame_count} {'saved successfully' if success else 'failed to save'} to {frame_path}") pass except Exception as e: print(f"Error in handle_frame: {e}") import traceback + print(traceback.format_exc()) def handle_error(error): @@ -94,7 +96,8 @@ def handle_completion(): subscription = processed_stream.subscribe( on_next=handle_frame, on_error=lambda e: print(f"Subscription error: {e}"), - on_completed=lambda: print("Subscription completed")) + on_completed=lambda: print("Subscription completed"), + ) print("Subscription created successfully") except Exception as e: print(f"Error creating subscription: {e}") @@ -109,38 +112,37 @@ def handle_completion(): print("\n๐Ÿค– QUEUEING WEBRTC COMMANDS BACK-TO-BACK FOR TESTING UnitreeGo2๐Ÿค–\n") # Dance 1 - robot.webrtc_req(api_id=1033) + robot.webrtc_req(api_id=1033) print("Queued: WiggleHips (1033)") robot.reverse(distance=0.2, speed=0.5) print("Queued: Reverse 0.5m at 0.5m/s") # Wiggle Hips - robot.webrtc_req(api_id=1033) + robot.webrtc_req(api_id=1033) print("Queued: WiggleHips (1033)") 
robot.move(distance=0.2, speed=0.5) print("Queued: Move forward 1.0m at 0.5m/s") - robot.webrtc_req(api_id=1017) + robot.webrtc_req(api_id=1017) print("Queued: Stretch (1017)") robot.move(distance=0.2, speed=0.5) print("Queued: Move forward 1.0m at 0.5m/s") - robot.webrtc_req(api_id=1017) + robot.webrtc_req(api_id=1017) print("Queued: Stretch (1017)") robot.reverse(distance=0.2, speed=0.5) print("Queued: Reverse 0.5m at 0.5m/s") - robot.webrtc_req(api_id=1017) - print("Queued: Stretch (1017)")\ - + robot.webrtc_req(api_id=1017) + print("Queued: Stretch (1017)") robot.spin(degrees=-90.0, speed=45.0) print("Queued: Spin right 90 degrees at 45 degrees/s") - robot.spin(degrees=90.0, speed=45.0) + robot.spin(degrees=90.0, speed=45.0) print("Queued: Spin left 90 degrees at 45 degrees/s") # To prevent termination @@ -149,14 +151,14 @@ def handle_completion(): except KeyboardInterrupt: print("\nStopping perception...") - if 'subscription' in locals(): + if "subscription" in locals(): subscription.dispose() except Exception as e: print(f"Error in main loop: {e}") finally: # Cleanup print("Cleaning up resources...") - if 'subscription' in locals(): + if "subscription" in locals(): subscription.dispose() del robot print("Cleanup complete.") diff --git a/tests/simple_agent_test.py b/tests/simple_agent_test.py index 0b7f64bac9..a5506f820a 100644 --- a/tests/simple_agent_test.py +++ b/tests/simple_agent_test.py @@ -7,19 +7,19 @@ import os # Initialize robot -robot = UnitreeGo2(ip=os.getenv('ROBOT_IP'), - ros_control=UnitreeROSControl(), - skills=MyUnitreeSkills()) +robot = UnitreeGo2( + ip=os.getenv("ROBOT_IP"), ros_control=UnitreeROSControl(), skills=MyUnitreeSkills() +) # Initialize agent agent = OpenAIAgent( - dev_name="UnitreeExecutionAgent", - input_video_stream=robot.get_ros_video_stream(), - skills=robot.get_skills(), - system_query="Wiggle when you see a person! Jump when you see a person waving!" 
- ) + dev_name="UnitreeExecutionAgent", + input_video_stream=robot.get_ros_video_stream(), + skills=robot.get_skills(), + system_query="Wiggle when you see a person! Jump when you see a person waving!", +) try: input("Press ESC to exit...") except KeyboardInterrupt: - print("\nExiting...") \ No newline at end of file + print("\nExiting...") diff --git a/tests/test_agent.py b/tests/test_agent.py index d05b18bb34..1319533d78 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -6,6 +6,7 @@ from dotenv import load_dotenv + # Sanity check for dotenv def test_dotenv(): print("test_dotenv:") @@ -13,9 +14,11 @@ def test_dotenv(): openai_api_key = os.getenv("OPENAI_API_KEY") print("\t\tOPENAI_API_KEY: ", openai_api_key) + # Sanity check for openai connection def test_openai_connection(): from openai import OpenAI + client = OpenAI() print("test_openai_connection:") response = client.chat.completions.create( @@ -38,5 +41,6 @@ def test_openai_connection(): ) print("\t\tOpenAI Response: ", response.choices[0]) + test_dotenv() test_openai_connection() diff --git a/tests/test_agent_alibaba.py b/tests/test_agent_alibaba.py index 6d39efe85a..9519387b7b 100644 --- a/tests/test_agent_alibaba.py +++ b/tests/test_agent_alibaba.py @@ -32,9 +32,9 @@ # Specify the OpenAI client for Alibaba qwen_client = OpenAI( - base_url='https://dashscope-intl.aliyuncs.com/compatible-mode/v1', - api_key=os.getenv('ALIBABA_API_KEY'), - ) + base_url="https://dashscope-intl.aliyuncs.com/compatible-mode/v1", + api_key=os.getenv("ALIBABA_API_KEY"), +) # Initialize Unitree skills myUnitreeSkills = MyUnitreeSkills() @@ -42,18 +42,18 @@ # Initialize agent agent = OpenAIAgent( - dev_name="AlibabaExecutionAgent", - openai_client=qwen_client, - model_name="qwen2.5-vl-72b-instruct", - tokenizer=HuggingFaceTokenizer(model_name="Qwen/Qwen2.5-VL-72B-Instruct"), - max_output_tokens_per_request=8192, - input_video_stream=video_stream, - # system_query="Tell me the number in the video. 
Find me the center of the number spotted, and print the coordinates to the console using an appropriate function call. Then provide me a deep history of the number in question and its significance in history. Additionally, tell me what model and version of language model you are.", - system_query="Tell me about any objects seen. Print the coordinates for center of the objects seen to the console using an appropriate function call. Then provide me a deep history of the number in question and its significance in history. Additionally, tell me what model and version of language model you are.", - skills=myUnitreeSkills, - ) + dev_name="AlibabaExecutionAgent", + openai_client=qwen_client, + model_name="qwen2.5-vl-72b-instruct", + tokenizer=HuggingFaceTokenizer(model_name="Qwen/Qwen2.5-VL-72B-Instruct"), + max_output_tokens_per_request=8192, + input_video_stream=video_stream, + # system_query="Tell me the number in the video. Find me the center of the number spotted, and print the coordinates to the console using an appropriate function call. Then provide me a deep history of the number in question and its significance in history. Additionally, tell me what model and version of language model you are.", + system_query="Tell me about any objects seen. Print the coordinates for center of the objects seen to the console using an appropriate function call. Then provide me a deep history of the number in question and its significance in history. 
Additionally, tell me what model and version of language model you are.", + skills=myUnitreeSkills, +) try: input("Press ESC to exit...") except KeyboardInterrupt: - print("\nExiting...") \ No newline at end of file + print("\nExiting...") diff --git a/tests/test_agent_ctransformers_gguf.py b/tests/test_agent_ctransformers_gguf.py index a48fa3c68e..6cd3405239 100644 --- a/tests/test_agent_ctransformers_gguf.py +++ b/tests/test_agent_ctransformers_gguf.py @@ -35,10 +35,10 @@ agent.run_observable_query(test_query).subscribe( on_next=lambda response: print(f"One-off query response: {response}"), on_error=lambda error: print(f"Error: {error}"), - on_completed=lambda: print("Query completed") + on_completed=lambda: print("Query completed"), ) try: input("Press ESC to exit...") except KeyboardInterrupt: - print("\nExiting...") \ No newline at end of file + print("\nExiting...") diff --git a/tests/test_agent_huggingface_local.py b/tests/test_agent_huggingface_local.py index d94b6a0470..4c4536a197 100644 --- a/tests/test_agent_huggingface_local.py +++ b/tests/test_agent_huggingface_local.py @@ -42,7 +42,7 @@ # Initialize agent agent = HuggingFaceLocalAgent( dev_name="HuggingFaceLLMAgent", - model_name= "Qwen/Qwen2.5-3B", + model_name="Qwen/Qwen2.5-3B", agent_type="HF-LLM", system_query=system_query, input_query_stream=query_provider.data_stream, @@ -59,14 +59,14 @@ # This will cause listening agents to consume the queries and respond # to them via skill execution and provide 1-shot responses. 
query_provider.start_query_stream( - query_template= - "{query}; User: travel forward by 10 meters", + query_template="{query}; User: travel forward by 10 meters", frequency=10, start_count=1, end_count=10000, - step=1) + step=1, +) try: input("Press ESC to exit...") except KeyboardInterrupt: - print("\nExiting...") \ No newline at end of file + print("\nExiting...") diff --git a/tests/test_agent_huggingface_local_jetson.py b/tests/test_agent_huggingface_local_jetson.py index eb260dcc90..6d29b3903f 100644 --- a/tests/test_agent_huggingface_local_jetson.py +++ b/tests/test_agent_huggingface_local_jetson.py @@ -43,7 +43,7 @@ agent = HuggingFaceLocalAgent( dev_name="HuggingFaceLLMAgent", model_name="Qwen/Qwen2.5-0.5B", - #model_name="HuggingFaceTB/SmolLM2-135M", + # model_name="HuggingFaceTB/SmolLM2-135M", agent_type="HF-LLM", system_query=system_query, input_query_stream=query_provider.data_stream, @@ -60,12 +60,12 @@ # This will cause listening agents to consume the queries and respond # to them via skill execution and provide 1-shot responses. 
query_provider.start_query_stream( - query_template= - "{query}; User: Hello how are you!", + query_template="{query}; User: Hello how are you!", frequency=30, start_count=1, end_count=10000, - step=1) + step=1, +) try: input("Press ESC to exit...") diff --git a/tests/test_agent_huggingface_remote.py b/tests/test_agent_huggingface_remote.py index dd533b2e78..7129523bf0 100644 --- a/tests/test_agent_huggingface_remote.py +++ b/tests/test_agent_huggingface_remote.py @@ -39,26 +39,26 @@ # Initialize agent agent = HuggingFaceRemoteAgent( - dev_name="HuggingFaceRemoteAgent", - model_name="meta-llama/Meta-Llama-3-8B-Instruct", - tokenizer=HuggingFaceTokenizer(model_name="meta-llama/Meta-Llama-3-8B-Instruct"), - max_output_tokens_per_request=8192, - input_query_stream=query_provider.data_stream, - # input_video_stream=video_stream, - system_query="You are a helpful assistant that can answer questions and help with tasks.", - ) + dev_name="HuggingFaceRemoteAgent", + model_name="meta-llama/Meta-Llama-3-8B-Instruct", + tokenizer=HuggingFaceTokenizer(model_name="meta-llama/Meta-Llama-3-8B-Instruct"), + max_output_tokens_per_request=8192, + input_query_stream=query_provider.data_stream, + # input_video_stream=video_stream, + system_query="You are a helpful assistant that can answer questions and help with tasks.", +) # Start the query stream. # Queries will be pushed every 1 second, in a count from 100 to 5000. query_provider.start_query_stream( - query_template= - "{query}; Denote the number at the beginning of this query before the semicolon as the 'reference number'. Provide the reference number, without any other text in your response.", + query_template="{query}; Denote the number at the beginning of this query before the semicolon as the 'reference number'. 
Provide the reference number, without any other text in your response.", frequency=5, start_count=1, end_count=10000, - step=1) + step=1, +) try: input("Press ESC to exit...") except KeyboardInterrupt: - print("\nExiting...") \ No newline at end of file + print("\nExiting...") diff --git a/tests/test_audio_agent.py b/tests/test_audio_agent.py index debdee11a6..61c30031fb 100644 --- a/tests/test_audio_agent.py +++ b/tests/test_audio_agent.py @@ -5,7 +5,6 @@ def main(): - stt_node = stt() agent = OpenAIAgent( diff --git a/tests/test_claude_agent_query.py b/tests/test_claude_agent_query.py index 96319d760b..aabd85bc12 100644 --- a/tests/test_claude_agent_query.py +++ b/tests/test_claude_agent_query.py @@ -21,12 +21,9 @@ load_dotenv() # Create a ClaudeAgent instance -agent = ClaudeAgent( - dev_name="test_agent", - query="What is the capital of France?" -) +agent = ClaudeAgent(dev_name="test_agent", query="What is the capital of France?") # Use the stream_query method to get a response response = agent.run_observable_query("What is the capital of France?").run() -print(f"Response from Claude Agent: {response}") \ No newline at end of file +print(f"Response from Claude Agent: {response}") diff --git a/tests/test_claude_agent_skills_query.py b/tests/test_claude_agent_skills_query.py index 2f1e34ec73..1aaeb795f1 100644 --- a/tests/test_claude_agent_skills_query.py +++ b/tests/test_claude_agent_skills_query.py @@ -37,18 +37,22 @@ # Load API key from environment load_dotenv() -robot = UnitreeGo2(ip=os.getenv('ROBOT_IP'), - ros_control=UnitreeROSControl(), - skills=MyUnitreeSkills(), - mock_connection=False) +robot = UnitreeGo2( + ip=os.getenv("ROBOT_IP"), + ros_control=UnitreeROSControl(), + skills=MyUnitreeSkills(), + mock_connection=False, +) # Create a subject for agent responses agent_response_subject = rx.subject.Subject() agent_response_stream = agent_response_subject.pipe(ops.share()) local_planner_viz_stream = robot.local_planner_viz_stream.pipe(ops.share()) -streams 
= {"unitree_video": robot.get_ros_video_stream(), - "local_planner_viz": local_planner_viz_stream} +streams = { + "unitree_video": robot.get_ros_video_stream(), + "local_planner_viz": local_planner_viz_stream, +} text_streams = { "agent_responses": agent_response_stream, } @@ -73,7 +77,7 @@ Example: If the user asks to move forward 1 meter, call the Move function with distance=1""", model_name="claude-3-7-sonnet-latest", - thinking_budget_tokens=2000 + thinking_budget_tokens=2000, ) tts_node = tts() @@ -100,9 +104,7 @@ robot_skills.create_instance("Speak", tts_node=tts_node) # Subscribe to agent responses and send them to the subject -agent.get_response_observable().subscribe( - lambda x: agent_response_subject.on_next(x) -) +agent.get_response_observable().subscribe(lambda x: agent_response_subject.on_next(x)) print("ObserveStream and Kill skills registered and ready for use") print("Created memory.txt file") @@ -120,11 +122,14 @@ def msg_handler(msgtype, data): except Exception as e: print(f"Error setting goal: {e}") return + + def threaded_msg_handler(msgtype, data): thread = threading.Thread(target=msg_handler, args=(msgtype, data)) thread.daemon = True thread.start() + websocket_vis.msg_handler = threaded_msg_handler web_interface.run() diff --git a/tests/test_command_pose_unitree.py b/tests/test_command_pose_unitree.py index 2311593a28..0537f5c446 100644 --- a/tests/test_command_pose_unitree.py +++ b/tests/test_command_pose_unitree.py @@ -1,5 +1,6 @@ import os import sys + # Add the parent directory to the Python path sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) @@ -11,9 +12,10 @@ import math # Initialize robot -robot = UnitreeGo2(ip=os.getenv('ROBOT_IP'), - ros_control=UnitreeROSControl(), - skills=MyUnitreeSkills()) +robot = UnitreeGo2( + ip=os.getenv("ROBOT_IP"), ros_control=UnitreeROSControl(), skills=MyUnitreeSkills() +) + # Helper function to send pose commands continuously for a duration def send_pose_for_duration(roll, 
pitch, yaw, duration, hz=10): @@ -21,7 +23,8 @@ def send_pose_for_duration(roll, pitch, yaw, duration, hz=10): start_time = time.time() while time.time() - start_time < duration: robot.pose_command(roll=roll, pitch=pitch, yaw=yaw) - time.sleep(1.0/hz) # Sleep to achieve the desired frequency + time.sleep(1.0 / hz) # Sleep to achieve the desired frequency + # Test pose commands @@ -62,4 +65,4 @@ def send_pose_for_duration(roll, pitch, yaw, duration, hz=10): while True: time.sleep(1) except KeyboardInterrupt: - print("Test terminated by user") \ No newline at end of file + print("Test terminated by user") diff --git a/tests/test_header.py b/tests/test_header.py index 8ae057208d..48ea6dd509 100644 --- a/tests/test_header.py +++ b/tests/test_header.py @@ -26,31 +26,33 @@ # Add the parent directory of 'tests' to the Python path sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + def get_caller_info(): """Identify the filename of the caller in the stack. - + Examines the call stack to find the first non-internal file that called this module. Skips the current file and Python internal files. - + Returns: str: The basename of the caller's filename, or "unknown" if not found. """ current_file = os.path.abspath(__file__) - + # Look through the call stack to find the first file that's not this one for frame in inspect.stack()[1:]: filename = os.path.abspath(frame.filename) # Skip this file and Python internals if filename != current_file and "= self.print_interval: self.last_print_time = current_time - + if not objects: print("\n[No objects detected]") return - - print("\n" + "="*50) + + print("\n" + "=" * 50) print(f"Detected {len(objects)} objects at {time.strftime('%H:%M:%S')}:") - print("="*50) - + print("=" * 50) + for i, obj in enumerate(objects): pos = obj["position"] rot = obj["rotation"] size = obj["size"] - - print(f"{i+1}. {obj['label']} (ID: {obj['object_id']}, Conf: {obj['confidence']:.2f})") + + print( + f"{i + 1}. 
{obj['label']} (ID: {obj['object_id']}, Conf: {obj['confidence']:.2f})" + ) print(f" Position: x={pos['x']:.2f}, y={pos['y']:.2f}, z={pos['z']:.2f} m") print(f" Rotation: yaw={rot['yaw']:.2f} rad") print(f" Size: width={size['width']:.2f}, height={size['height']:.2f} m") @@ -70,7 +82,7 @@ def print_results(self, objects: List[Dict[str, Any]]): def main(): # Get command line arguments args = parse_args() - + # Set up the result printer for console output result_printer = ResultPrinter(print_interval=1.0) @@ -78,20 +90,20 @@ def main(): min_confidence = 0.6 class_filter = None # No class filtering web_port = 5555 - + # Initialize detector detector = Detic2DDetector(vocabulary=None, threshold=min_confidence) - + # Initialize based on mode if args.mode == "robot": print("Initializing in robot mode...") - + # Get robot IP from environment - robot_ip = os.getenv('ROBOT_IP') + robot_ip = os.getenv("ROBOT_IP") if not robot_ip: print("Error: ROBOT_IP environment variable not set.") sys.exit(1) - + # Initialize robot robot = UnitreeGo2( ip=robot_ip, @@ -108,99 +120,92 @@ def main(): class_filter=class_filter, transform_to_map=robot.ros_control.transform_pose, detector=detector, - video_stream=video_stream + video_stream=video_stream, ) - - else: # webcam mode print("Initializing in webcam mode...") - + # Define camera intrinsics for the webcam # These are approximate values for a typical 640x480 webcam width, height = 640, 480 focal_length_mm = 3.67 # mm (typical webcam) - sensor_width_mm = 4.8 # mm (1/4" sensor) - + sensor_width_mm = 4.8 # mm (1/4" sensor) + # Calculate focal length in pixels focal_length_x_px = width * focal_length_mm / sensor_width_mm focal_length_y_px = height * focal_length_mm / sensor_width_mm - + # Principal point (center of image) cx, cy = width / 2, height / 2 - + # Camera intrinsics in [fx, fy, cx, cy] format camera_intrinsics = [focal_length_x_px, focal_length_y_px, cx, cy] - + # Initialize video provider and ObjectDetectionStream video_provider = 
VideoProvider("test_camera", video_source=0) # Default camera # Create video stream - video_stream = backpressure(video_provider.capture_video_as_observable(realtime=True, fps=30)) + video_stream = backpressure( + video_provider.capture_video_as_observable(realtime=True, fps=30) + ) object_detector = ObjectDetectionStream( camera_intrinsics=camera_intrinsics, min_confidence=min_confidence, class_filter=class_filter, detector=detector, - video_stream=video_stream + video_stream=video_stream, ) - - - + # Set placeholder robot for cleanup robot = None - + # Create visualization stream for web interface viz_stream = object_detector.get_stream().pipe( ops.share(), ops.map(lambda x: x["viz_frame"] if x is not None else None), ops.filter(lambda x: x is not None), ) - + # Create stop event for clean shutdown stop_event = threading.Event() - + # Define subscription callback to print results def on_next(result): if stop_event.is_set(): return - + # Print detected objects to console if "objects" in result: result_printer.print_results(result["objects"]) - + def on_error(error): print(f"Error in detection stream: {error}") stop_event.set() - + def on_completed(): print("Detection stream completed") stop_event.set() - + try: # Subscribe to the detection stream subscription = object_detector.get_stream().subscribe( - on_next=on_next, - on_error=on_error, - on_completed=on_completed + on_next=on_next, on_error=on_error, on_completed=on_completed ) - + # Set up web interface print("Initializing web interface...") - web_interface = RobotWebInterface( - port=web_port, - object_detection=viz_stream - ) - + web_interface = RobotWebInterface(port=web_port, object_detection=viz_stream) + # Print configuration information print("\nObjectDetectionStream Test Running:") print(f"Mode: {args.mode}") print(f"Web Interface: http://localhost:{web_port}") print("\nPress Ctrl+C to stop the test\n") - + # Start web server (blocking call) web_interface.run() - + except KeyboardInterrupt: 
print("\nTest interrupted by user") except Exception as e: @@ -209,16 +214,16 @@ def on_completed(): # Clean up resources print("Cleaning up resources...") stop_event.set() - + if subscription: subscription.dispose() - + if args.mode == "robot" and robot: robot.cleanup() elif args.mode == "webcam": - if 'video_provider' in locals(): + if "video_provider" in locals(): video_provider.dispose_all() - + print("Test completed") diff --git a/tests/test_object_tracking_webcam.py b/tests/test_object_tracking_webcam.py index e72912ed93..ed72d76e34 100644 --- a/tests/test_object_tracking_webcam.py +++ b/tests/test_object_tracking_webcam.py @@ -16,20 +16,21 @@ tracker_initialized = False object_size = 0.30 # Hardcoded object size in meters (adjust based on your tracking target) + def mouse_callback(event, x, y, flags, param): global selecting_bbox, bbox_points, current_bbox, tracker_initialized, tracker_stream - + if event == cv2.EVENT_LBUTTONDOWN: # Start bbox selection selecting_bbox = True bbox_points = [(x, y)] current_bbox = None tracker_initialized = False - + elif event == cv2.EVENT_MOUSEMOVE and selecting_bbox: # Update current selection for visualization current_bbox = [bbox_points[0][0], bbox_points[0][1], x, y] - + elif event == cv2.EVENT_LBUTTONUP: # End bbox selection selecting_bbox = False @@ -40,49 +41,49 @@ def mouse_callback(event, x, y, flags, param): # Ensure x1,y1 is top-left and x2,y2 is bottom-right current_bbox = [min(x1, x2), min(y1, y2), max(x1, x2), max(y1, y2)] # Add the bbox to the tracking queue - if param.get('bbox_queue') and not tracker_initialized: - param['bbox_queue'].put((current_bbox, object_size)) + if param.get("bbox_queue") and not tracker_initialized: + param["bbox_queue"].put((current_bbox, object_size)) tracker_initialized = True def main(): global tracker_initialized - + # Create queues for thread communication frame_queue = queue.Queue(maxsize=5) bbox_queue = queue.Queue() stop_event = threading.Event() - + # Logitech C920e camera 
parameters at 480p # Convert physical parameters to pixel-based intrinsics width, height = 640, 480 focal_length_mm = 3.67 # mm - sensor_width_mm = 4.8 # mm (1/4" sensor) + sensor_width_mm = 4.8 # mm (1/4" sensor) sensor_height_mm = 3.6 # mm - + # Calculate focal length in pixels focal_length_x_px = width * focal_length_mm / sensor_width_mm focal_length_y_px = height * focal_length_mm / sensor_height_mm - + # Principal point (assuming center of image) cx = width / 2 cy = height / 2 - + # Final camera intrinsics in [fx, fy, cx, cy] format camera_intrinsics = [focal_length_x_px, focal_length_y_px, cx, cy] - + # Initialize video provider and object tracking stream video_provider = VideoProvider("test_camera", video_source=0) tracker_stream = ObjectTrackingStream( camera_intrinsics=camera_intrinsics, camera_pitch=0.0, # Adjust if your camera is tilted - camera_height=0.5 # Height of camera from ground in meters (adjust as needed) + camera_height=0.5, # Height of camera from ground in meters (adjust as needed) ) - + # Create video stream video_stream = video_provider.capture_video_as_observable(realtime=True, fps=30) tracking_stream = tracker_stream.create_stream(video_stream) - + # Define callbacks for the tracking stream def on_next(result): if stop_event.is_set(): @@ -90,55 +91,74 @@ def on_next(result): # Get the visualization frame viz_frame = result["viz_frame"] - + # If we're selecting a bbox, draw the current selection if selecting_bbox and current_bbox is not None: x1, y1, x2, y2 = current_bbox cv2.rectangle(viz_frame, (x1, y1), (x2, y2), (0, 255, 255), 2) - + # Add instructions - cv2.putText(viz_frame, "Click and drag to select object", (10, 30), - cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2) - cv2.putText(viz_frame, f"Object size: {object_size:.2f}m", (10, 60), - cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2) - + cv2.putText( + viz_frame, + "Click and drag to select object", + (10, 30), + cv2.FONT_HERSHEY_SIMPLEX, + 0.7, + (255, 255, 255), + 2, + ) 
+ cv2.putText( + viz_frame, + f"Object size: {object_size:.2f}m", + (10, 60), + cv2.FONT_HERSHEY_SIMPLEX, + 0.7, + (255, 255, 255), + 2, + ) + # Show tracking status status = "Tracking" if tracker_initialized else "Not tracking" - cv2.putText(viz_frame, f"Status: {status}", (10, 90), - cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0) if tracker_initialized else (0, 0, 255), 2) - + cv2.putText( + viz_frame, + f"Status: {status}", + (10, 90), + cv2.FONT_HERSHEY_SIMPLEX, + 0.7, + (0, 255, 0) if tracker_initialized else (0, 0, 255), + 2, + ) + # Put frame in queue for main thread to display try: frame_queue.put_nowait(viz_frame) except queue.Full: # Skip frame if queue is full pass - + def on_error(error): print(f"Error: {error}") stop_event.set() - + def on_completed(): print("Stream completed") stop_event.set() - + # Start the subscription subscription = None - + try: # Subscribe to start processing in background thread subscription = tracking_stream.subscribe( - on_next=on_next, - on_error=on_error, - on_completed=on_completed + on_next=on_next, on_error=on_error, on_completed=on_completed ) - + print("Object tracking started. Click and drag to select an object. 
Press 'q' to exit.") - + # Create window and set mouse callback cv2.namedWindow("Object Tracker") - cv2.setMouseCallback("Object Tracker", mouse_callback, {'bbox_queue': bbox_queue}) - + cv2.setMouseCallback("Object Tracker", mouse_callback, {"bbox_queue": bbox_queue}) + # Main thread loop for displaying frames and handling bbox selection while not stop_event.is_set(): # Check if there's a new bbox to track @@ -149,35 +169,35 @@ def on_completed(): tracker_stream.track(new_bbox, size=size) except queue.Empty: pass - + try: # Get frame with timeout viz_frame = frame_queue.get(timeout=1.0) - + # Display the frame cv2.imshow("Object Tracker", viz_frame) # Check for exit key - if cv2.waitKey(1) & 0xFF == ord('q'): + if cv2.waitKey(1) & 0xFF == ord("q"): print("Exit key pressed") break - + except queue.Empty: # No frame available, check if we should continue - if cv2.waitKey(1) & 0xFF == ord('q'): + if cv2.waitKey(1) & 0xFF == ord("q"): print("Exit key pressed") break continue - + except KeyboardInterrupt: print("\nKeyboard interrupt received. 
Stopping...") finally: # Signal threads to stop stop_event.set() - + # Clean up resources if subscription: subscription.dispose() - + video_provider.dispose_all() tracker_stream.cleanup() cv2.destroyAllWindows() @@ -185,4 +205,4 @@ def on_completed(): if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/tests/test_object_tracking_with_qwen.py b/tests/test_object_tracking_with_qwen.py index 547ee7e139..c2cb004f90 100644 --- a/tests/test_object_tracking_with_qwen.py +++ b/tests/test_object_tracking_with_qwen.py @@ -29,7 +29,7 @@ # Logitech C920e camera parameters at 480p width, height = 640, 480 focal_length_mm = 3.67 # mm -sensor_width_mm = 4.8 # mm (1/4" sensor) +sensor_width_mm = 4.8 # mm (1/4" sensor) sensor_height_mm = 3.6 # mm # Calculate focal length in pixels @@ -43,9 +43,7 @@ # Initialize video provider and object tracking stream video_provider = VideoProvider("webcam", video_source=0) tracker_stream = ObjectTrackingStream( - camera_intrinsics=camera_intrinsics, - camera_pitch=0.0, - camera_height=0.5 + camera_intrinsics=camera_intrinsics, camera_pitch=0.0, camera_height=0.5 ) # Create video streams @@ -53,8 +51,11 @@ tracking_stream = tracker_stream.create_stream(video_stream) # Check if display is available -if 'DISPLAY' not in os.environ: - raise RuntimeError("No display available. Please set DISPLAY environment variable or run in headless mode.") +if "DISPLAY" not in os.environ: + raise RuntimeError( + "No display available. Please set DISPLAY environment variable or run in headless mode." 
+ ) + # Define callbacks for the tracking stream def on_next(result): @@ -64,38 +65,55 @@ def on_next(result): # Get the visualization frame viz_frame = result["viz_frame"] - + # Add information to the visualization - cv2.putText(viz_frame, f"Tracking {tracking_object_name}", (10, 30), - cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2) - cv2.putText(viz_frame, f"Object size: {object_size:.2f}m", (10, 60), - cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2) - + cv2.putText( + viz_frame, + f"Tracking {tracking_object_name}", + (10, 30), + cv2.FONT_HERSHEY_SIMPLEX, + 0.7, + (255, 255, 255), + 2, + ) + cv2.putText( + viz_frame, + f"Object size: {object_size:.2f}m", + (10, 60), + cv2.FONT_HERSHEY_SIMPLEX, + 0.7, + (255, 255, 255), + 2, + ) + # Show tracking status status = "Tracking" if tracker_initialized else "Waiting for detection" color = (0, 255, 0) if tracker_initialized else (0, 0, 255) - cv2.putText(viz_frame, f"Status: {status}", (10, 90), - cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2) - + cv2.putText(viz_frame, f"Status: {status}", (10, 90), cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2) + # If detection is in progress, show a message if detection_in_progress: - cv2.putText(viz_frame, "Querying Qwen...", (10, 120), - cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2) - + cv2.putText( + viz_frame, "Querying Qwen...", (10, 120), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2 + ) + # Put frame in queue for main thread to display try: frame_queue.put_nowait(viz_frame) except queue.Full: pass + def on_error(error): print(f"Error: {error}") stop_event.set() + def on_completed(): print("Stream completed") stop_event.set() + # Start the subscription subscription = None @@ -105,14 +123,12 @@ def on_completed(): detection_in_progress = False # Subscribe to start processing in background thread subscription = tracking_stream.subscribe( - on_next=on_next, - on_error=on_error, - on_completed=on_completed + on_next=on_next, on_error=on_error, on_completed=on_completed ) - + 
print("Object tracking with Qwen started. Press 'q' to exit.") print("Waiting for initial object detection...") - + # Main thread loop for displaying frames and updating tracking while not stop_event.is_set(): # Check if we need to update tracking @@ -129,14 +145,14 @@ def detection_task(): try: result = get_bbox_from_qwen(video_stream, object_name=object_name) print(f"Got result from Qwen: {result}") - + if result: bbox, size = result print(f"Detected object at {bbox} with size {size}") tracker_stream.track(bbox, size=size) tracker_initialized = True return - + print("No object detected by Qwen") tracker_initialized = False tracker_stream.stop_track() @@ -150,37 +166,37 @@ def detection_task(): # Run detection task in a separate thread threading.Thread(target=detection_task, daemon=True).start() - + try: # Get frame with timeout viz_frame = frame_queue.get(timeout=0.1) - + # Display the frame cv2.imshow("Object Tracking with Qwen", viz_frame) - + # Check for exit key - if cv2.waitKey(1) & 0xFF == ord('q'): + if cv2.waitKey(1) & 0xFF == ord("q"): print("Exit key pressed") break - + except queue.Empty: # No frame available, check if we should continue - if cv2.waitKey(1) & 0xFF == ord('q'): + if cv2.waitKey(1) & 0xFF == ord("q"): print("Exit key pressed") break continue - + except KeyboardInterrupt: print("\nKeyboard interrupt received. 
Stopping...") finally: # Signal threads to stop stop_event.set() - + # Clean up resources if subscription: subscription.dispose() - + video_provider.dispose_all() tracker_stream.cleanup() cv2.destroyAllWindows() - print("Cleanup complete") \ No newline at end of file + print("Cleanup complete") diff --git a/tests/test_observe_stream_skill.py b/tests/test_observe_stream_skill.py index d0b87d4ff7..10532c8d2e 100644 --- a/tests/test_observe_stream_skill.py +++ b/tests/test_observe_stream_skill.py @@ -23,32 +23,28 @@ from dimos.web.robot_web_interface import RobotWebInterface from dimos.utils.logging_config import setup_logger import tests.test_header + logger = setup_logger("tests.test_observe_stream_skill") load_dotenv() + def main(): # Initialize the robot with mock connection for testing robot = UnitreeGo2( - ip=os.getenv('ROBOT_IP', '192.168.123.161'), - skills=MyUnitreeSkills(), - mock_connection=True + ip=os.getenv("ROBOT_IP", "192.168.123.161"), skills=MyUnitreeSkills(), mock_connection=True ) - + agent_response_subject = rx.subject.Subject() agent_response_stream = agent_response_subject.pipe(ops.share()) - + streams = {"unitree_video": robot.get_ros_video_stream()} text_streams = { "agent_responses": agent_response_stream, } - - web_interface = RobotWebInterface( - port=5555, - text_streams=text_streams, - **streams - ) - + + web_interface = RobotWebInterface(port=5555, text_streams=text_streams, **streams) + agent = ClaudeAgent( dev_name="test_agent", input_query_stream=web_interface.query_stream, @@ -57,62 +53,65 @@ def main(): When you see an image, describe what you see and alert if you notice any people or important changes. 
Be concise but thorough in your observations.""", model_name="claude-3-7-sonnet-latest", - thinking_budget_tokens=10000 - ) - - agent.get_response_observable().subscribe( - lambda x: agent_response_subject.on_next(x) + thinking_budget_tokens=10000, ) - + + agent.get_response_observable().subscribe(lambda x: agent_response_subject.on_next(x)) + robot_skills = robot.get_skills() - + robot_skills.add(ObserveStream) robot_skills.add(KillSkill) - + robot_skills.create_instance("ObserveStream", robot=robot, agent=agent) robot_skills.create_instance("KillSkill", skill_library=robot_skills) - + web_interface_thread = threading.Thread(target=web_interface.run) web_interface_thread.daemon = True web_interface_thread.start() - + logger.info("Starting monitor skill...") - + memory_file = os.path.join(agent.output_dir, "memory.txt") with open(memory_file, "a") as f: - f.write("SKILL CALL: ObserveStream(timestep=10.0, query_text='What do you see in this image? Alert me if you see any people.', max_duration=120.0)") - - result = robot_skills.call("ObserveStream", - timestep=10.0, # 20 seconds between monitoring queries - query_text="What do you see in this image? Alert me if you see any people.", - max_duration=120.0) # Run for 120 seconds + f.write( + "SKILL CALL: ObserveStream(timestep=10.0, query_text='What do you see in this image? Alert me if you see any people.', max_duration=120.0)" + ) + + result = robot_skills.call( + "ObserveStream", + timestep=10.0, # 20 seconds between monitoring queries + query_text="What do you see in this image? Alert me if you see any people.", + max_duration=120.0, + ) # Run for 120 seconds logger.info(f"Monitor skill result: {result}") - + logger.info(f"Running skills: {robot_skills.get_running_skills().keys()}") - + try: logger.info("Observer running. 
Will stop after 35 seconds...") time.sleep(20.0) logger.info(f"Running skills before kill: {robot_skills.get_running_skills().keys()}") logger.info("Killing the observer skill...") - + memory_file = os.path.join(agent.output_dir, "memory.txt") with open(memory_file, "a") as f: f.write("\n\nSKILL CALL: KillSkill(skill_name='observer')\n\n") - + kill_result = robot_skills.call("KillSkill", skill_name="observer") logger.info(f"Kill skill result: {kill_result}") - + logger.info(f"Running skills after kill: {robot_skills.get_running_skills().keys()}") - + # Keep test running until user interrupts while True: time.sleep(1.0) except KeyboardInterrupt: logger.info("Test interrupted by user") - + logger.info("Test completed") + if __name__ == "__main__": main() diff --git a/tests/test_person_following_robot.py b/tests/test_person_following_robot.py index 22a51e0670..c082cb1b57 100644 --- a/tests/test_person_following_robot.py +++ b/tests/test_person_following_robot.py @@ -16,12 +16,12 @@ def main(): # Hardcoded parameters timeout = 60.0 # Maximum time to follow a person (seconds) distance = 0.5 # Desired distance to maintain from target (meters) - + print("Initializing Unitree Go2 robot...") - + # Initialize the robot with ROS control and skills robot = UnitreeGo2( - ip=os.getenv('ROBOT_IP'), + ip=os.getenv("ROBOT_IP"), ros_control=UnitreeROSControl(), skills=MyUnitreeSkills(), ) @@ -33,58 +33,59 @@ def main(): RxOps.filter(lambda x: x is not None), ) video_stream = robot.get_ros_video_stream() - + try: # Set up web interface logger.info("Initializing web interface") - streams = { - "unitree_video": video_stream, - "person_tracking": viz_stream - } - - web_interface = RobotWebInterface( - port=5555, - **streams - ) - + streams = {"unitree_video": video_stream, "person_tracking": viz_stream} + + web_interface = RobotWebInterface(port=5555, **streams) + # Wait for camera and tracking to initialize print("Waiting for camera and tracking to initialize...") time.sleep(5) # Get 
initial point from Qwen max_retries = 5 - delay = 3 - + delay = 3 + for attempt in range(max_retries): try: - qwen_point = eval(query_single_frame_observable( - video_stream, - "Look at this frame and point to the person shirt. Return ONLY their center coordinates as a tuple (x,y)." - ).pipe(RxOps.take(1)).run()) # Get first response and convert string tuple to actual tuple + qwen_point = eval( + query_single_frame_observable( + video_stream, + "Look at this frame and point to the person shirt. Return ONLY their center coordinates as a tuple (x,y).", + ) + .pipe(RxOps.take(1)) + .run() + ) # Get first response and convert string tuple to actual tuple logger.info(f"Found person at coordinates {qwen_point}") break # If successful, break out of retry loop except Exception as e: if attempt < max_retries - 1: - logger.error(f"Person not found. Attempt {attempt + 1}/{max_retries} failed. Retrying in {delay}s... Error: {e}") + logger.error( + f"Person not found. Attempt {attempt + 1}/{max_retries} failed. Retrying in {delay}s... Error: {e}" + ) time.sleep(delay) else: logger.error(f"Person not found after {max_retries} attempts. 
Last error: {e}") return - + # Start following human in a separate thread import threading + follow_thread = threading.Thread( target=lambda: robot.follow_human(timeout=timeout, distance=distance, point=qwen_point), - daemon=True + daemon=True, ) follow_thread.start() - + print(f"Following human at point {qwen_point} for {timeout} seconds...") print("Web interface available at http://localhost:5555") - + # Start web server (blocking call) web_interface.run() - + except KeyboardInterrupt: print("\nInterrupted by user") except Exception as e: diff --git a/tests/test_person_following_webcam.py b/tests/test_person_following_webcam.py index 09ca7f9319..11d2739504 100644 --- a/tests/test_person_following_webcam.py +++ b/tests/test_person_following_webcam.py @@ -17,52 +17,50 @@ def main(): frame_queue = queue.Queue(maxsize=5) result_queue = queue.Queue(maxsize=5) # For tracking results stop_event = threading.Event() - + # Logitech C920e camera parameters at 480p # Convert physical parameters to intrinsics [fx, fy, cx, cy] resolution = (640, 480) # 480p resolution focal_length_mm = 3.67 # mm sensor_size_mm = (4.8, 3.6) # mm (1/4" sensor) - + # Calculate focal length in pixels fx = (resolution[0] * focal_length_mm) / sensor_size_mm[0] fy = (resolution[1] * focal_length_mm) / sensor_size_mm[1] - + # Principal point (typically at image center) cx = resolution[0] / 2 cy = resolution[1] / 2 - + # Camera intrinsics in [fx, fy, cx, cy] format camera_intrinsics = [fx, fy, cx, cy] - + # Camera mounted parameters camera_pitch = np.deg2rad(-5) # negative for downward pitch camera_height = 1.4 # meters - + # Initialize video provider and person tracking stream video_provider = VideoProvider("test_camera", video_source=0) person_tracker = PersonTrackingStream( - camera_intrinsics=camera_intrinsics, - camera_pitch=camera_pitch, - camera_height=camera_height + camera_intrinsics=camera_intrinsics, camera_pitch=camera_pitch, camera_height=camera_height ) - + # Create streams video_stream = 
video_provider.capture_video_as_observable(realtime=False, fps=20) person_tracking_stream = person_tracker.create_stream(video_stream) - + # Create visual servoing object visual_servoing = VisualServoing( tracking_stream=person_tracking_stream, max_linear_speed=0.5, max_angular_speed=0.75, - desired_distance=2.5 + desired_distance=2.5, ) - + # Track if we have selected a person to follow selected_point = None tracking_active = False - + # Define callbacks for the tracking stream def on_next(result): if stop_event.is_set(): @@ -71,63 +69,61 @@ def on_next(result): # Get the visualization frame which already includes person detections # with bounding boxes, tracking IDs, and distance/angle information viz_frame = result["viz_frame"] - + # Store the result for the main thread to use with visual servoing try: result_queue.put_nowait(result) except queue.Full: # Skip if queue is full pass - + # Put frame in queue for main thread to display (non-blocking) try: frame_queue.put_nowait(viz_frame) except queue.Full: # Skip frame if queue is full pass - + def on_error(error): print(f"Error: {error}") stop_event.set() - + def on_completed(): print("Stream completed") stop_event.set() - + # Mouse callback for selecting a person to track def mouse_callback(event, x, y, flags, param): nonlocal selected_point, tracking_active - + if event == cv2.EVENT_LBUTTONDOWN: # Store the clicked point selected_point = (x, y) tracking_active = False # Will be set to True if start_tracking succeeds print(f"Selected point: {selected_point}") - + # Start the subscription subscription = None - + try: # Subscribe to start processing in background thread subscription = person_tracking_stream.subscribe( - on_next=on_next, - on_error=on_error, - on_completed=on_completed + on_next=on_next, on_error=on_error, on_completed=on_completed ) - + print("Person tracking visualization started.") print("Click on a person to start visual servoing. 
Press 'q' to exit.") - + # Set up mouse callback cv2.namedWindow("Person Tracking") cv2.setMouseCallback("Person Tracking", mouse_callback) - + # Main thread loop for displaying frames while not stop_event.is_set(): try: # Get frame with timeout (allows checking stop_event periodically) frame = frame_queue.get(timeout=1.0) - + # Call the visual servoing if we have a selected point if selected_point is not None: # If not actively tracking, try to start tracking @@ -136,56 +132,79 @@ def mouse_callback(event, x, y, flags, param): if not tracking_active: print("Failed to start tracking") selected_point = None - + # If tracking is active, update tracking if tracking_active: servoing_result = visual_servoing.updateTracking() - + # Display visual servoing output on the frame linear_vel = servoing_result.get("linear_vel", 0.0) angular_vel = servoing_result.get("angular_vel", 0.0) running = visual_servoing.running - - status_color = (0, 255, 0) if running else (0, 0, 255) # Green if running, red if not - + + status_color = ( + (0, 255, 0) if running else (0, 0, 255) + ) # Green if running, red if not + # Add velocity text to frame - cv2.putText(frame, f"Linear: {linear_vel:.2f} m/s", (10, 30), - cv2.FONT_HERSHEY_SIMPLEX, 0.7, status_color, 2) - cv2.putText(frame, f"Angular: {angular_vel:.2f} rad/s", (10, 60), - cv2.FONT_HERSHEY_SIMPLEX, 0.7, status_color, 2) - cv2.putText(frame, f"Tracking: {'ON' if running else 'OFF'}", (10, 90), - cv2.FONT_HERSHEY_SIMPLEX, 0.7, status_color, 2) - + cv2.putText( + frame, + f"Linear: {linear_vel:.2f} m/s", + (10, 30), + cv2.FONT_HERSHEY_SIMPLEX, + 0.7, + status_color, + 2, + ) + cv2.putText( + frame, + f"Angular: {angular_vel:.2f} rad/s", + (10, 60), + cv2.FONT_HERSHEY_SIMPLEX, + 0.7, + status_color, + 2, + ) + cv2.putText( + frame, + f"Tracking: {'ON' if running else 'OFF'}", + (10, 90), + cv2.FONT_HERSHEY_SIMPLEX, + 0.7, + status_color, + 2, + ) + # If tracking is lost, reset selected_point and tracking_active if not running: 
selected_point = None tracking_active = False - + # Display the frame in main thread cv2.imshow("Person Tracking", frame) - + # Check for exit key - if cv2.waitKey(1) & 0xFF == ord('q'): + if cv2.waitKey(1) & 0xFF == ord("q"): print("Exit key pressed") break - + except queue.Empty: # No frame available, check if we should continue - if cv2.waitKey(1) & 0xFF == ord('q'): + if cv2.waitKey(1) & 0xFF == ord("q"): print("Exit key pressed") break continue - + except KeyboardInterrupt: print("\nKeyboard interrupt received. Stopping...") finally: # Signal threads to stop stop_event.set() - + # Clean up resources if subscription: subscription.dispose() - + visual_servoing.cleanup() video_provider.dispose_all() person_tracker.cleanup() diff --git a/tests/test_planning_agent_web_interface.py b/tests/test_planning_agent_web_interface.py index 46f58c9e4e..68bc711075 100644 --- a/tests/test_planning_agent_web_interface.py +++ b/tests/test_planning_agent_web_interface.py @@ -27,18 +27,19 @@ from dimos.robot.unitree.unitree_go2 import UnitreeGo2 from dimos.robot.unitree.unitree_skills import MyUnitreeSkills from dimos.utils.logging_config import logger + # from dimos.web.fastapi_server import FastAPIServer from dimos.web.robot_web_interface import RobotWebInterface from dimos.utils.threadpool import make_single_thread_scheduler + def main(): # Get environment variables robot_ip = os.getenv("ROBOT_IP") if not robot_ip: raise ValueError("ROBOT_IP environment variable is required") - connection_method = os.getenv("CONN_TYPE") or 'webrtc' - output_dir = os.getenv("ROS_OUTPUT_DIR", - os.path.join(os.getcwd(), "assets/output/ros")) + connection_method = os.getenv("CONN_TYPE") or "webrtc" + output_dir = os.getenv("ROS_OUTPUT_DIR", os.path.join(os.getcwd(), "assets/output/ros")) # Initialize components as None for proper cleanup robot = None @@ -49,11 +50,13 @@ def main(): try: # Initialize robot logger.info("Initializing Unitree Robot") - robot = UnitreeGo2(ip=robot_ip, - 
connection_method=connection_method, - output_dir=output_dir, - mock_connection=False, - skills=MyUnitreeSkills()) + robot = UnitreeGo2( + ip=robot_ip, + connection_method=connection_method, + output_dir=output_dir, + mock_connection=False, + skills=MyUnitreeSkills(), + ) # Set up video stream logger.info("Starting video stream") video_stream = robot.get_ros_video_stream() @@ -65,10 +68,10 @@ def main(): logger.info("Creating response streams") planner_response_subject = rx.subject.Subject() planner_response_stream = planner_response_subject.pipe(ops.share()) - + executor_response_subject = rx.subject.Subject() executor_response_stream = executor_response_subject.pipe(ops.share()) - + # Web interface mode with FastAPI server logger.info("Initializing FastAPI server") streams = {"unitree_video": video_stream} @@ -76,36 +79,33 @@ def main(): "planner_responses": planner_response_stream, "executor_responses": executor_response_stream, } - - web_interface = RobotWebInterface( - port=5555, text_streams=text_streams, **streams) + + web_interface = RobotWebInterface(port=5555, text_streams=text_streams, **streams) logger.info("Starting planning agent with web interface") planner = PlanningAgent( dev_name="TaskPlanner", model_name="gpt-4o", input_query_stream=web_interface.query_stream, - skills=robot.get_skills() + skills=robot.get_skills(), ) - + # Get planner's response observable logger.info("Setting up agent response streams") planner_responses = planner.get_response_observable() - + # Connect planner to its subject - planner_responses.subscribe( - lambda x: planner_response_subject.on_next(x) - ) + planner_responses.subscribe(lambda x: planner_response_subject.on_next(x)) planner_responses.subscribe( on_next=lambda x: logger.info(f"Planner response: {x}"), on_error=lambda e: logger.error(f"Planner error: {e}"), - on_completed=lambda: logger.info("Planner completed") + on_completed=lambda: logger.info("Planner completed"), ) - + # Initialize execution agent with robot 
skills logger.info("Starting execution agent") - system_query=dedent( + system_query = dedent( """ You are a robot execution agent that can execute tasks on a virtual robot. The sole text you will be given is the task to execute. @@ -119,7 +119,7 @@ def main(): output_dir=output_dir, skills=robot.get_skills(), system_query=system_query, - pool_scheduler=make_single_thread_scheduler() + pool_scheduler=make_single_thread_scheduler(), ) # Get executor's response observable @@ -129,13 +129,11 @@ def main(): executor_responses.subscribe( on_next=lambda x: logger.info(f"Executor response: {x}"), on_error=lambda e: logger.error(f"Executor error: {e}"), - on_completed=lambda: logger.info("Executor completed") + on_completed=lambda: logger.info("Executor completed"), ) - + # Connect executor to its subject - executor_responses.subscribe( - lambda x: executor_response_subject.on_next(x) - ) + executor_responses.subscribe(lambda x: executor_response_subject.on_next(x)) # Start web server (blocking call) logger.info("Starting FastAPI server") diff --git a/tests/test_planning_robot_agent.py b/tests/test_planning_robot_agent.py index 3b7b95cea2..b2d27ba3a5 100644 --- a/tests/test_planning_robot_agent.py +++ b/tests/test_planning_robot_agent.py @@ -29,14 +29,14 @@ from dimos.web.robot_web_interface import RobotWebInterface from dimos.utils.threadpool import make_single_thread_scheduler + def main(): # Get environment variables robot_ip = os.getenv("ROBOT_IP") if not robot_ip: raise ValueError("ROBOT_IP environment variable is required") - connection_method = os.getenv("CONN_TYPE") or 'webrtc' - output_dir = os.getenv("ROS_OUTPUT_DIR", - os.path.join(os.getcwd(), "assets/output/ros")) + connection_method = os.getenv("CONN_TYPE") or "webrtc" + output_dir = os.getenv("ROS_OUTPUT_DIR", os.path.join(os.getcwd(), "assets/output/ros")) use_terminal = os.getenv("USE_TERMINAL", "").lower() == "true" use_terminal = True @@ -49,10 +49,12 @@ def main(): try: # Initialize robot 
logger.info("Initializing Unitree Robot") - robot = UnitreeGo2(ip=robot_ip, - connection_method=connection_method, - output_dir=output_dir, - mock_connection=True) + robot = UnitreeGo2( + ip=robot_ip, + connection_method=connection_method, + output_dir=output_dir, + mock_connection=True, + ) # Set up video stream logger.info("Starting video stream") @@ -69,7 +71,7 @@ def main(): dev_name="TaskPlanner", model_name="gpt-4o", use_terminal=True, - skills=skills_instance + skills=skills_instance, ) else: # Web interface mode @@ -82,16 +84,16 @@ def main(): dev_name="TaskPlanner", model_name="gpt-4o", input_query_stream=web_interface.query_stream, - skills=skills_instance + skills=skills_instance, ) - + # Get planner's response observable logger.info("Setting up agent response streams") planner_responses = planner.get_response_observable() - + # Initialize execution agent with robot skills logger.info("Starting execution agent") - system_query=dedent( + system_query = dedent( """ You are a robot execution agent that can execute tasks on a virtual robot. You are given a task to execute and a list of skills that @@ -105,7 +107,7 @@ def main(): output_dir=output_dir, skills=skills_instance, system_query=system_query, - pool_scheduler=make_single_thread_scheduler() + pool_scheduler=make_single_thread_scheduler(), ) # Get executor's response observable @@ -115,7 +117,7 @@ def main(): executor_responses.subscribe( on_next=lambda x: logger.info(f"Executor response: {x}"), on_error=lambda e: logger.error(f"Executor error: {e}"), - on_completed=lambda: logger.info("Executor completed") + on_completed=lambda: logger.info("Executor completed"), ) if use_terminal: @@ -158,4 +160,4 @@ def main(): if __name__ == "__main__": sys.exit(main()) -# Example Task: Move the robot forward by 1 meter, then turn 90 degrees clockwise, then move backward by 1 meter, then turn a random angle counterclockwise, then repeat this sequence 5 times. 
\ No newline at end of file +# Example Task: Move the robot forward by 1 meter, then turn 90 degrees clockwise, then move backward by 1 meter, then turn a random angle counterclockwise, then repeat this sequence 5 times. diff --git a/tests/test_qwen_image_query.py b/tests/test_qwen_image_query.py index 6032430f2d..feaa8c0096 100644 --- a/tests/test_qwen_image_query.py +++ b/tests/test_qwen_image_query.py @@ -4,30 +4,32 @@ from PIL import Image from dimos.models.qwen.video_query import query_single_frame + def test_qwen_image_query(): """Test querying Qwen with a single image.""" # Skip if no API key - if not os.getenv('ALIBABA_API_KEY'): + if not os.getenv("ALIBABA_API_KEY"): print("ALIBABA_API_KEY not set") return - + # Load test image image_path = os.path.join(os.getcwd(), "assets", "test_spatial_memory", "frame_038.jpg") image = Image.open(image_path) - + # Test basic object detection query response = query_single_frame( image=image, - query="What objects do you see in this image? Return as a comma-separated list." + query="What objects do you see in this image? 
Return as a comma-separated list.", ) print(response) - + # Test coordinate query response = query_single_frame( image=image, - query="Return the center coordinates of any person in the image as a tuple (x,y)" + query="Return the center coordinates of any person in the image as a tuple (x,y)", ) print(response) - + + if __name__ == "__main__": - test_qwen_image_query() \ No newline at end of file + test_qwen_image_query() diff --git a/tests/test_robot.py b/tests/test_robot.py index b452100c85..011850b04e 100644 --- a/tests/test_robot.py +++ b/tests/test_robot.py @@ -9,19 +9,20 @@ from reactivex import operators as RxOps import tests.test_header + def main(): print("Initializing Unitree Go2 robot with local planner visualization...") - + # Initialize the robot with ROS control and skills robot = UnitreeGo2( - ip=os.getenv('ROBOT_IP'), + ip=os.getenv("ROBOT_IP"), ros_control=UnitreeROSControl(), skills=MyUnitreeSkills(), ) # Get the camera stream video_stream = robot.get_ros_video_stream() - + # The local planner visualization stream is created during robot initialization local_planner_stream = robot.local_planner_viz_stream @@ -30,21 +31,15 @@ def main(): RxOps.map(lambda x: x if x is not None else None), RxOps.filter(lambda x: x is not None), ) - + goal_following_thread = None try: # Set up web interface with both streams - streams = { - "camera": video_stream, - "local_planner": local_planner_stream - } - + streams = {"camera": video_stream, "local_planner": local_planner_stream} + # Create and start the web interface - web_interface = RobotWebInterface( - port=5555, - **streams - ) - + web_interface = RobotWebInterface(port=5555, **streams) + # Wait for initialization print("Waiting for camera and systems to initialize...") time.sleep(2) @@ -53,23 +48,18 @@ def main(): print("Starting navigation to local goal (2m ahead) in a separate thread...") goal_following_thread = threading.Thread( target=navigate_to_goal_local, - kwargs={ - 'robot': robot, - 
'goal_xy_robot': (3.0, 0.0), - 'distance': 0.0, - 'timeout': 300 - }, - daemon=True + kwargs={"robot": robot, "goal_xy_robot": (3.0, 0.0), "distance": 0.0, "timeout": 300}, + daemon=True, ) goal_following_thread.start() print("Robot streams running") print("Web interface available at http://localhost:5555") print("Press Ctrl+C to exit") - + # Start web server (blocking call) web_interface.run() - + except KeyboardInterrupt: print("\nInterrupted by user") except Exception as e: diff --git a/tests/test_rtsp_video_provider.py b/tests/test_rtsp_video_provider.py index 0afa7b95cf..e3824740a6 100644 --- a/tests/test_rtsp_video_provider.py +++ b/tests/test_rtsp_video_provider.py @@ -36,17 +36,20 @@ # Load environment variables from .env file from dotenv import load_dotenv + load_dotenv() # RTSP URL must be provided as a command-line argument or environment variable RTSP_URL = os.environ.get("TEST_RTSP_URL", "") if len(sys.argv) > 1: - RTSP_URL = sys.argv[1] # Allow overriding with command-line argument + RTSP_URL = sys.argv[1] # Allow overriding with command-line argument elif RTSP_URL == "": - print("Please provide an RTSP URL for testing.") - print("You can set the TEST_RTSP_URL environment variable or pass it as a command-line argument.") - print("Example: python -m dimos.stream.rtsp_video_provider rtsp://...") - sys.exit(1) + print("Please provide an RTSP URL for testing.") + print( + "You can set the TEST_RTSP_URL environment variable or pass it as a command-line argument." 
+ ) + print("Example: python -m dimos.stream.rtsp_video_provider rtsp://...") + sys.exit(1) logger.info(f"Attempting to connect to provided RTSP URL.") provider = RtspVideoProvider(dev_name="TestRtspCam", rtsp_url=RTSP_URL) @@ -56,34 +59,40 @@ logger.info("Subscribing to observable...") frame_counter = 0 -start_time = time.monotonic() # Re-initialize start_time -last_log_time = start_time # Keep this for interval timing +start_time = time.monotonic() # Re-initialize start_time +last_log_time = start_time # Keep this for interval timing # Create a subject for ffmpeg responses ffmpeg_response_subject = rx.subject.Subject() ffmpeg_response_stream = ffmpeg_response_subject.pipe(ops.observe_on(get_scheduler()), ops.share()) + def process_frame(frame: np.ndarray): """Callback function executed for each received frame.""" - global frame_counter, last_log_time, start_time # Add start_time to global + global frame_counter, last_log_time, start_time # Add start_time to global frame_counter += 1 current_time = time.monotonic() # Log stats periodically (e.g., every 5 seconds) if current_time - last_log_time >= 5.0: - total_elapsed_time = current_time - start_time # Calculate total elapsed time + total_elapsed_time = current_time - start_time # Calculate total elapsed time avg_fps = frame_counter / total_elapsed_time if total_elapsed_time > 0 else 0 logger.info(f"Received frame {frame_counter}. Shape: {frame.shape}. Avg FPS: {avg_fps:.2f}") - ffmpeg_response_subject.on_next(f"Received frame {frame_counter}. Shape: {frame.shape}. Avg FPS: {avg_fps:.2f}") - last_log_time = current_time # Update log time for the next interval + ffmpeg_response_subject.on_next( + f"Received frame {frame_counter}. Shape: {frame.shape}. 
Avg FPS: {avg_fps:.2f}" + ) + last_log_time = current_time # Update log time for the next interval + def handle_error(error: Exception): """Callback function executed if the observable stream errors.""" - logger.error(f"Stream error: {error}", exc_info=True) # Log with traceback + logger.error(f"Stream error: {error}", exc_info=True) # Log with traceback + def handle_completion(): """Callback function executed when the observable stream completes.""" logger.info("Stream completed.") + # Subscribe to the observable stream processor = FrameProcessor() subscription = video_stream_observable.pipe( @@ -91,22 +100,16 @@ def handle_completion(): ops.observe_on(get_scheduler()), ops.share(), vops.with_jpeg_export(processor, suffix="reolink_", save_limit=30, loop=True), -).subscribe( - on_next=process_frame, - on_error=handle_error, - on_completed=handle_completion -) - -streams = { - "reolink_video": video_stream_observable -} +).subscribe(on_next=process_frame, on_error=handle_error, on_completed=handle_completion) + +streams = {"reolink_video": video_stream_observable} text_streams = { "ffmpeg_responses": ffmpeg_response_stream, } web_interface = RobotWebInterface(port=5555, text_streams=text_streams, **streams) -web_interface.run() # This may block the main thread +web_interface.run() # This may block the main thread # TODO: Redo disposal / keep-alive loop @@ -135,9 +138,9 @@ def handle_completion(): print("Cleanup finished.") # Final check (optional, for debugging) -time.sleep(1) # Give background threads a moment +time.sleep(1) # Give background threads a moment final_process = provider._ffmpeg_process if final_process and final_process.poll() is None: - print(f"WARNING: ffmpeg process (PID: {final_process.pid}) may still be running after cleanup!") + print(f"WARNING: ffmpeg process (PID: {final_process.pid}) may still be running after cleanup!") else: - print("ffmpeg process appears terminated.") \ No newline at end of file + print("ffmpeg process appears 
terminated.") diff --git a/tests/test_semantic_seg_robot.py b/tests/test_semantic_seg_robot.py index 88f40b6755..eccc5dc84e 100644 --- a/tests/test_semantic_seg_robot.py +++ b/tests/test_semantic_seg_robot.py @@ -23,25 +23,29 @@ def main(): # Create a queue for thread communication (limit to prevent memory issues) frame_queue = queue.Queue(maxsize=5) stop_event = threading.Event() - + # Unitree Go2 camera parameters at 1080p camera_params = { - 'resolution': (1920, 1080), # 1080p resolution - 'focal_length': 3.2, # mm - 'sensor_size': (4.8, 3.6) # mm (1/4" sensor) + "resolution": (1920, 1080), # 1080p resolution + "focal_length": 3.2, # mm + "sensor_size": (4.8, 3.6), # mm (1/4" sensor) } - + # Initialize video provider and segmentation stream - #video_provider = VideoProvider("test_camera", video_source=0) - robot = UnitreeGo2(ip=os.getenv('ROBOT_IP'), - ros_control=UnitreeROSControl(),) - - seg_stream = SemanticSegmentationStream(enable_mono_depth=False, camera_params=camera_params, gt_depth_scale=512.0) - + # video_provider = VideoProvider("test_camera", video_source=0) + robot = UnitreeGo2( + ip=os.getenv("ROBOT_IP"), + ros_control=UnitreeROSControl(), + ) + + seg_stream = SemanticSegmentationStream( + enable_mono_depth=False, camera_params=camera_params, gt_depth_scale=512.0 + ) + # Create streams video_stream = robot.get_ros_video_stream(fps=5) segmentation_stream = seg_stream.create_stream(video_stream) - + # Define callbacks for the segmentation stream def on_next(segmentation): if stop_event.is_set(): @@ -53,7 +57,7 @@ def on_next(segmentation): height, width = vis_frame.shape[:2] depth_height, depth_width = depth_viz.shape[:2] - # Resize depth visualization to match segmentation height + # Resize depth visualization to match segmentation height # (maintaining aspect ratio if needed) depth_resized = cv2.resize(depth_viz, (int(depth_width * height / depth_height), height)) @@ -63,7 +67,9 @@ def on_next(segmentation): # Add labels font = 
cv2.FONT_HERSHEY_SIMPLEX cv2.putText(combined_viz, "Semantic Segmentation", (10, 30), font, 0.8, (255, 255, 255), 2) - cv2.putText(combined_viz, "Depth Estimation", (width + 10, 30), font, 0.8, (255, 255, 255), 2) + cv2.putText( + combined_viz, "Depth Estimation", (width + 10, 30), font, 0.8, (255, 255, 255), 2 + ) # Put frame in queue for main thread to display (non-blocking) try: @@ -71,18 +77,18 @@ def on_next(segmentation): except queue.Full: # Skip frame if queue is full pass - + def on_error(error): print(f"Error: {error}") stop_event.set() - + def on_completed(): print("Stream completed") stop_event.set() - + # Start the subscription subscription = None - + try: # Subscribe to start processing in background thread print_emission_args = { @@ -91,7 +97,6 @@ def on_completed(): "counts": {}, } - frame_processor = FrameProcessor(delete_on_init=True) subscription = segmentation_stream.pipe( MyOps.print_emission(id="A", **print_emission_args), @@ -101,7 +106,7 @@ def on_completed(): MyOps.print_emission(id="C", **print_emission_args), RxOps.filter(lambda x: x is not None), MyOps.print_emission(id="D", **print_emission_args), - # MyVideoOps.with_jpeg_export(frame_processor=frame_processor, suffix="_frame_"), + # MyVideoOps.with_jpeg_export(frame_processor=frame_processor, suffix="_frame_"), MyOps.print_emission(id="E", **print_emission_args), ) @@ -112,21 +117,21 @@ def on_completed(): } fast_api_server = RobotWebInterface(port=5555, **streams) fast_api_server.run() - + except KeyboardInterrupt: print("\nKeyboard interrupt received. 
Stopping...") finally: # Signal threads to stop stop_event.set() - + # Clean up resources if subscription: subscription.dispose() - + seg_stream.cleanup() cv2.destroyAllWindows() print("Cleanup complete") if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/tests/test_semantic_seg_robot_agent.py b/tests/test_semantic_seg_robot_agent.py index 527d1286a3..8aa102e16a 100644 --- a/tests/test_semantic_seg_robot_agent.py +++ b/tests/test_semantic_seg_robot_agent.py @@ -15,28 +15,29 @@ from dimos.agents.agent import OpenAIAgent from dimos.utils.threadpool import get_scheduler + def main(): # Unitree Go2 camera parameters at 1080p camera_params = { - 'resolution': (1920, 1080), # 1080p resolution - 'focal_length': 3.2, # mm - 'sensor_size': (4.8, 3.6) # mm (1/4" sensor) + "resolution": (1920, 1080), # 1080p resolution + "focal_length": 3.2, # mm + "sensor_size": (4.8, 3.6), # mm (1/4" sensor) } - - robot = UnitreeGo2(ip=os.getenv('ROBOT_IP'), - ros_control=UnitreeROSControl(), - skills=MyUnitreeSkills()) - - seg_stream = SemanticSegmentationStream(enable_mono_depth=True, camera_params=camera_params, gt_depth_scale=512.0) - + + robot = UnitreeGo2( + ip=os.getenv("ROBOT_IP"), ros_control=UnitreeROSControl(), skills=MyUnitreeSkills() + ) + + seg_stream = SemanticSegmentationStream( + enable_mono_depth=True, camera_params=camera_params, gt_depth_scale=512.0 + ) + # Create streams video_stream = robot.get_ros_video_stream(fps=5) segmentation_stream = seg_stream.create_stream( - video_stream.pipe( - MyVideoOps.with_fps_sampling(fps=.5) - ) + video_stream.pipe(MyVideoOps.with_fps_sampling(fps=0.5)) ) - # Throttling to slowdown SegmentationAgent calls + # Throttling to slowdown SegmentationAgent calls # TODO: add Agent parameter to handle this called api_call_interval frame_processor = FrameProcessor(delete_on_init=True) @@ -57,25 +58,33 @@ def main(): RxOps.share(), RxOps.map(lambda x: x.metadata["objects"] if x is not None else None), 
RxOps.filter(lambda x: x is not None), - RxOps.map(lambda objects: "\n".join( - f"Object {obj['object_id']}: {obj['label']} (confidence: {obj['prob']:.2f})" + - (f", depth: {obj['depth']:.2f}m" if 'depth' in obj else "") - for obj in objects - ) if objects else "No objects detected."), + RxOps.map( + lambda objects: "\n".join( + f"Object {obj['object_id']}: {obj['label']} (confidence: {obj['prob']:.2f})" + + (f", depth: {obj['depth']:.2f}m" if "depth" in obj else "") + for obj in objects + ) + if objects + else "No objects detected." + ), ) text_query_stream = Subject() - + # Combine text query with latest object data when a new text query arrives enriched_query_stream = text_query_stream.pipe( RxOps.with_latest_from(object_stream), - RxOps.map(lambda combined: { - "query": combined[0], - "objects": combined[1] if len(combined) > 1 else "No object data available" - }), + RxOps.map( + lambda combined: { + "query": combined[0], + "objects": combined[1] if len(combined) > 1 else "No object data available", + } + ), RxOps.map(lambda data: f"{data['query']}\n\nCurrent objects detected:\n{data['objects']}"), - RxOps.do_action(lambda x: print(f"\033[34mEnriched query: {x.split(chr(10))[0]}\033[0m") or - [print(f"\033[34m{line}\033[0m") for line in x.split(chr(10))[1:]]), + RxOps.do_action( + lambda x: print(f"\033[34mEnriched query: {x.split(chr(10))[0]}\033[0m") + or [print(f"\033[34m{line}\033[0m") for line in x.split(chr(10))[1:]] + ), ) segmentation_agent = OpenAIAgent( @@ -85,7 +94,7 @@ def main(): input_query_stream=enriched_query_stream, process_all_inputs=False, pool_scheduler=get_scheduler(), - skills=robot.get_skills() + skills=robot.get_skills(), ) agent_response_stream = segmentation_agent.get_response_observable() @@ -104,7 +113,7 @@ def main(): try: fast_api_server = RobotWebInterface(port=5555, text_streams=text_streams, **streams) - fast_api_server.query_stream.subscribe(lambda x: text_query_stream.on_next(x)) + 
fast_api_server.query_stream.subscribe(lambda x: text_query_stream.on_next(x)) fast_api_server.run() except KeyboardInterrupt: print("\nKeyboard interrupt received. Stopping...") @@ -113,5 +122,6 @@ def main(): cv2.destroyAllWindows() print("Cleanup complete") + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/tests/test_semantic_seg_webcam.py b/tests/test_semantic_seg_webcam.py index 7e387ae597..444e4cd629 100644 --- a/tests/test_semantic_seg_webcam.py +++ b/tests/test_semantic_seg_webcam.py @@ -11,26 +11,29 @@ from dimos.stream.video_provider import VideoProvider from dimos.perception.semantic_seg import SemanticSegmentationStream + def main(): # Create a queue for thread communication (limit to prevent memory issues) frame_queue = queue.Queue(maxsize=5) stop_event = threading.Event() - + # Logitech C920e camera parameters at 480p camera_params = { - 'resolution': (640, 480), # 480p resolution - 'focal_length': 3.67, # mm - 'sensor_size': (4.8, 3.6) # mm (1/4" sensor) + "resolution": (640, 480), # 480p resolution + "focal_length": 3.67, # mm + "sensor_size": (4.8, 3.6), # mm (1/4" sensor) } - + # Initialize video provider and segmentation stream video_provider = VideoProvider("test_camera", video_source=0) - seg_stream = SemanticSegmentationStream(enable_mono_depth=True, camera_params=camera_params, gt_depth_scale=512.0) - + seg_stream = SemanticSegmentationStream( + enable_mono_depth=True, camera_params=camera_params, gt_depth_scale=512.0 + ) + # Create streams video_stream = video_provider.capture_video_as_observable(realtime=False, fps=5) segmentation_stream = seg_stream.create_stream(video_stream) - + # Define callbacks for the segmentation stream def on_next(segmentation): if stop_event.is_set(): @@ -43,7 +46,7 @@ def on_next(segmentation): height, width = vis_frame.shape[:2] depth_height, depth_width = depth_viz.shape[:2] - # Resize depth visualization to match segmentation height + # Resize depth visualization to match 
segmentation height # (maintaining aspect ratio if needed) depth_resized = cv2.resize(depth_viz, (int(depth_width * height / depth_height), height)) @@ -53,7 +56,9 @@ def on_next(segmentation): # Add labels font = cv2.FONT_HERSHEY_SIMPLEX cv2.putText(combined_viz, "Semantic Segmentation", (10, 30), font, 0.8, (255, 255, 255), 2) - cv2.putText(combined_viz, "Depth Estimation", (width + 10, 30), font, 0.8, (255, 255, 255), 2) + cv2.putText( + combined_viz, "Depth Estimation", (width + 10, 30), font, 0.8, (255, 255, 255), 2 + ) # Put frame in queue for main thread to display (non-blocking) try: @@ -61,58 +66,56 @@ def on_next(segmentation): except queue.Full: # Skip frame if queue is full pass - + def on_error(error): print(f"Error: {error}") stop_event.set() - + def on_completed(): print("Stream completed") stop_event.set() - + # Start the subscription subscription = None - + try: # Subscribe to start processing in background thread subscription = segmentation_stream.subscribe( - on_next=on_next, - on_error=on_error, - on_completed=on_completed + on_next=on_next, on_error=on_error, on_completed=on_completed ) - + print("Semantic segmentation visualization started. Press 'q' to exit.") - + # Main thread loop for displaying frames while not stop_event.is_set(): try: # Get frame with timeout (allows checking stop_event periodically) combined_viz = frame_queue.get(timeout=1.0) - + # Display the frame in main thread cv2.imshow("Semantic Segmentation", combined_viz) # Check for exit key - if cv2.waitKey(1) & 0xFF == ord('q'): + if cv2.waitKey(1) & 0xFF == ord("q"): print("Exit key pressed") break - + except queue.Empty: # No frame available, check if we should continue - if cv2.waitKey(1) & 0xFF == ord('q'): + if cv2.waitKey(1) & 0xFF == ord("q"): print("Exit key pressed") break continue - + except KeyboardInterrupt: print("\nKeyboard interrupt received. 
Stopping...") finally: # Signal threads to stop stop_event.set() - + # Clean up resources if subscription: subscription.dispose() - + video_provider.dispose_all() seg_stream.cleanup() cv2.destroyAllWindows() @@ -120,4 +123,4 @@ def on_completed(): if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/tests/test_skills.py b/tests/test_skills.py index ae77facc38..0d4b7f2ff8 100644 --- a/tests/test_skills.py +++ b/tests/test_skills.py @@ -31,10 +31,10 @@ class TestSkill(AbstractSkill): _called: bool = False - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._called = False - + def __call__(self): self._called = True return "TestSkill executed successfully" @@ -48,31 +48,31 @@ def setUp(self): self.robot = MockRobot() self.skill_library = MyUnitreeSkills(robot=self.robot) self.skill_library.initialize_skills() - + def test_skill_iteration(self): """Test that skills can be properly iterated in the skill library.""" skills_count = 0 for skill in self.skill_library: skills_count += 1 - self.assertTrue(hasattr(skill, '__name__')) + self.assertTrue(hasattr(skill, "__name__")) self.assertTrue(issubclass(skill, AbstractSkill)) - + self.assertGreater(skills_count, 0, "Skill library should contain at least one skill") - + def test_skill_registration(self): """Test that skills can be properly registered in the skill library.""" # Clear existing skills for isolated test self.skill_library = MyUnitreeSkills(robot=self.robot) original_count = len(list(self.skill_library)) - + # Add a custom test skill test_skill = TestSkill self.skill_library.add(test_skill) - + # Verify the skill was added new_count = len(list(self.skill_library)) self.assertEqual(new_count, original_count + 1) - + # Check if the skill can be found by name found = False for skill in self.skill_library: @@ -80,7 +80,7 @@ def test_skill_registration(self): found = True break self.assertTrue(found, "Added skill should be 
found in skill library") - + def test_skill_direct_execution(self): """Test that a skill can be executed directly.""" test_skill = TestSkill() @@ -88,19 +88,19 @@ def test_skill_direct_execution(self): result = test_skill() self.assertTrue(test_skill._called) self.assertEqual(result, "TestSkill executed successfully") - + def test_skill_library_execution(self): """Test that a skill can be executed through the skill library.""" # Add our test skill to the library test_skill = TestSkill self.skill_library.add(test_skill) - + # Create an instance to confirm it was executed - with mock.patch.object(TestSkill, '__call__', return_value="Success") as mock_call: + with mock.patch.object(TestSkill, "__call__", return_value="Success") as mock_call: result = self.skill_library.call("TestSkill") mock_call.assert_called_once() self.assertEqual(result, "Success") - + def test_skill_not_found(self): """Test that calling a non-existent skill raises an appropriate error.""" with self.assertRaises(ValueError): @@ -115,34 +115,34 @@ def setUp(self): self.robot = MockRobot() self.skill_library = MyUnitreeSkills(robot=self.robot) self.skill_library.initialize_skills() - + # Add a test skill self.skill_library.add(TestSkill) - + # Create the agent self.agent = OpenAIAgent( dev_name="SkillTestAgent", system_query="You are a skill testing agent. When prompted to perform an action, use the appropriate skill.", - skills=self.skill_library + skills=self.skill_library, ) - - @mock.patch('dimos.agents.agent.OpenAIAgent.run_observable_query') + + @mock.patch("dimos.agents.agent.OpenAIAgent.run_observable_query") def test_agent_skill_identification(self, mock_query): """Test that the agent can identify skills based on natural language.""" # Mock the agent response mock_response = mock.MagicMock() mock_response.run.return_value = "I found the TestSkill and executed it." 
mock_query.return_value = mock_response - + # Run the test response = self.agent.run_observable_query("Please run the test skill").run() - + # Assertions mock_query.assert_called_once_with("Please run the test skill") self.assertEqual(response, "I found the TestSkill and executed it.") - - @mock.patch.object(TestSkill, '__call__') - @mock.patch('dimos.agents.agent.OpenAIAgent.run_observable_query') + + @mock.patch.object(TestSkill, "__call__") + @mock.patch("dimos.agents.agent.OpenAIAgent.run_observable_query") def test_agent_skill_execution(self, mock_query, mock_skill_call): """Test that the agent can execute skills properly.""" # Mock the agent and skill call @@ -150,30 +150,31 @@ def test_agent_skill_execution(self, mock_query, mock_skill_call): mock_response = mock.MagicMock() mock_response.run.return_value = "Executed TestSkill successfully." mock_query.return_value = mock_response - + # Run the test response = self.agent.run_observable_query("Execute the TestSkill skill").run() - + # We can't directly verify the skill was called since our mocking setup # doesn't capture the internal skill execution of the agent, but we can # verify the agent was properly called mock_query.assert_called_once_with("Execute the TestSkill skill") self.assertEqual(response, "Executed TestSkill successfully.") - + def test_agent_multi_skill_registration(self): """Test that multiple skills can be registered with an agent.""" + # Create a new skill class AnotherTestSkill(AbstractSkill): def __call__(self): return "Another test skill executed" - + # Register the new skill initial_count = len(list(self.skill_library)) self.skill_library.add(AnotherTestSkill) - + # Verify two distinct skills now exist self.assertEqual(len(list(self.skill_library)), initial_count + 1) - + # Verify both skills are found by name skill_names = [skill.__name__ for skill in self.skill_library] self.assertIn("TestSkill", skill_names) diff --git a/tests/test_skills_rest.py b/tests/test_skills_rest.py index 
36ca9dc366..70a15fcfd5 100644 --- a/tests/test_skills_rest.py +++ b/tests/test_skills_rest.py @@ -54,21 +54,20 @@ IMPORTANT: Only return the response directly asked of the user. E.G. if the user asks for the time, only return the time. If the user asks for the weather, only return the weather. - """), + """ + ), model_name="claude-3-7-sonnet-latest", - thinking_budget_tokens=2000 + thinking_budget_tokens=2000, ) # Subscribe to agent responses and send them to the subject -agent.get_response_observable().subscribe( - lambda x: agent_response_subject.on_next(x) -) +agent.get_response_observable().subscribe(lambda x: agent_response_subject.on_next(x)) # Start the web interface web_interface.run() # Run this query in the web interface: -# -# Make a web request to nist to get the current time. +# +# Make a web request to nist to get the current time. # You should use http://worldclockapi.com/api/json/utc/now -# \ No newline at end of file +# diff --git a/tests/test_spatial_memory.py b/tests/test_spatial_memory.py index c520674b9f..b400749cb4 100644 --- a/tests/test_spatial_memory.py +++ b/tests/test_spatial_memory.py @@ -32,70 +32,67 @@ from dimos.robot.unitree.unitree_ros_control import UnitreeROSControl from dimos.perception.spatial_perception import SpatialMemory + def extract_position(transform): """Extract position coordinates from a transform message""" if transform is None: return (0, 0, 0) - + pos = transform.transform.translation return (pos.x, pos.y, pos.z) + def setup_persistent_chroma_db(db_path="chromadb_data"): """ Set up a persistent ChromaDB database at the specified path. 
- + Args: db_path: Path to store the ChromaDB database - + Returns: The ChromaDB client instance """ # Create a persistent ChromaDB client full_db_path = os.path.join("/home/stash/dimensional/dimos/assets/test_spatial_memory", db_path) print(f"Setting up persistent ChromaDB at: {full_db_path}") - + # Ensure the directory exists os.makedirs(full_db_path, exist_ok=True) - + return chromadb.PersistentClient(path=full_db_path) + def main(): print("Starting spatial memory test...") - + # Initialize ROS control and robot - ros_control = UnitreeROSControl( - node_name="spatial_memory_test", - mock_connection=False - ) - - robot = UnitreeGo2( - ros_control=ros_control, - ip=os.getenv('ROBOT_IP') - ) - + ros_control = UnitreeROSControl(node_name="spatial_memory_test", mock_connection=False) + + robot = UnitreeGo2(ros_control=ros_control, ip=os.getenv("ROBOT_IP")) + # Create counters for tracking frame_count = 0 transform_count = 0 stored_count = 0 - + print("Setting up video stream...") - video_stream = robot.get_ros_video_stream() - + video_stream = robot.get_ros_video_stream() + # Create transform stream at 1 Hz print("Setting up transform stream...") transform_stream = ros_control.get_transform_stream( child_frame="map", parent_frame="base_link", - rate_hz=1.0 # 1 transform per second + rate_hz=1.0, # 1 transform per second ) - + # Setup output directory for visual memory visual_memory_dir = "/home/stash/dimensional/dimos/assets/test_spatial_memory" os.makedirs(visual_memory_dir, exist_ok=True) - + # Setup persistent storage path for visual memory visual_memory_path = os.path.join(visual_memory_dir, "visual_memory.pkl") - + # Try to load existing visual memory if it exists if os.path.exists(visual_memory_path): try: @@ -108,10 +105,10 @@ def main(): else: print("No existing visual memory found. 
Starting with empty visual memory.") visual_memory = VisualMemory(output_dir=visual_memory_dir) - + # Setup a persistent database for ChromaDB db_client = setup_persistent_chroma_db() - + # Create spatial perception instance with persistent storage print("Creating SpatialMemory with persistent vector database...") spatial_memory = SpatialMemory( @@ -119,167 +116,182 @@ def main(): min_distance_threshold=1, # Store frames every 1 meter min_time_threshold=1, # Store frames at least every 1 second chroma_client=db_client, # Use the persistent client - visual_memory=visual_memory # Use the visual memory we loaded or created + visual_memory=visual_memory, # Use the visual memory we loaded or created ) - + # Combine streams using combine_latest # This will pair up items properly without buffering combined_stream = reactivex.combine_latest(video_stream, transform_stream).pipe( - ops.map(lambda pair: { - "frame": pair[0], # First element is the frame - "position": extract_position(pair[1]) # Second element is the transform - }) + ops.map( + lambda pair: { + "frame": pair[0], # First element is the frame + "position": extract_position(pair[1]), # Second element is the transform + } + ) ) - + # Process with spatial memory result_stream = spatial_memory.process_stream(combined_stream) - + # Simple callback to track stored frames and save them to the assets directory def on_stored_frame(result): nonlocal stored_count # Only count actually stored frames (not debug frames) - if not result.get('stored', True) == False: + if not result.get("stored", True) == False: stored_count += 1 - pos = result['position'] + pos = result["position"] print(f"\nStored frame #{stored_count} at ({pos[0]:.2f}, {pos[1]:.2f}, {pos[2]:.2f})") - + # Save the frame to the assets directory - if 'frame' in result: + if "frame" in result: frame_filename = f"/home/stash/dimensional/dimos/assets/test_spatial_memory/frame_{stored_count:03d}.jpg" - cv2.imwrite(frame_filename, result['frame']) + 
cv2.imwrite(frame_filename, result["frame"]) print(f"Saved frame to {frame_filename}") - + # Subscribe to results print("Subscribing to spatial perception results...") result_subscription = result_stream.subscribe(on_stored_frame) - + print("\nRunning until interrupted...") try: while True: - time.sleep(1.0) + time.sleep(1.0) print(f"Running: {stored_count} frames stored so far", end="\r") except KeyboardInterrupt: print("\nTest interrupted by user") finally: # Clean up resources print("\nCleaning up...") - if 'result_subscription' in locals(): + if "result_subscription" in locals(): result_subscription.dispose() - + # Visualize spatial memory with multiple object queries visualize_spatial_memory_with_objects( - spatial_memory, - objects=["kitchen", "conference room", "vacuum", "office", "bathroom", "boxes", "telephone booth"], - output_filename="spatial_memory_map.png" + spatial_memory, + objects=[ + "kitchen", + "conference room", + "vacuum", + "office", + "bathroom", + "boxes", + "telephone booth", + ], + output_filename="spatial_memory_map.png", ) - + # Save visual memory to disk for later use saved_path = spatial_memory.vector_db.visual_memory.save("visual_memory.pkl") print(f"Saved {spatial_memory.vector_db.visual_memory.count()} images to disk at {saved_path}") - -def visualize_spatial_memory_with_objects(spatial_memory, objects, output_filename="spatial_memory_map.png"): + + +def visualize_spatial_memory_with_objects( + spatial_memory, objects, output_filename="spatial_memory_map.png" +): """ Visualize a spatial memory map with multiple labeled objects. - + Args: spatial_memory: SpatialMemory instance objects: List of object names to query and visualize (e.g. 
["kitchen", "office"]) output_filename: Filename to save the visualization """ # Define colors for different objects - will cycle through these - colors = ['red', 'green', 'orange', 'purple', 'brown', 'cyan', 'magenta', 'yellow'] - + colors = ["red", "green", "orange", "purple", "brown", "cyan", "magenta", "yellow"] + # Get all stored locations for background locations = spatial_memory.vector_db.get_all_locations() if not locations: print("No locations stored in spatial memory.") return - + # Extract coordinates from all stored locations if len(locations[0]) >= 3: x_coords = [loc[0] for loc in locations] y_coords = [loc[1] for loc in locations] else: x_coords, y_coords = zip(*locations) - + # Create figure plt.figure(figsize=(12, 10)) - + # Plot all points in blue - plt.scatter(x_coords, y_coords, c='blue', s=50, alpha=0.5, label='All Frames') - + plt.scatter(x_coords, y_coords, c="blue", s=50, alpha=0.5, label="All Frames") + # Container for all object coordinates object_coords = {} - + # Query for each object and store the result for i, obj in enumerate(objects): color = colors[i % len(colors)] # Cycle through colors print(f"\nProcessing {obj} query for visualization...") - + # Get best match for this object results = spatial_memory.query_by_text(obj, limit=1) if not results: print(f"No results found for '{obj}'") continue - + # Get the first (best) result result = results[0] - metadata = result['metadata'] - + metadata = result["metadata"] + # Extract coordinates from the first metadata item if isinstance(metadata, list) and metadata: metadata = metadata[0] - - if isinstance(metadata, dict) and 'x' in metadata and 'y' in metadata: - x = metadata.get('x', 0) - y = metadata.get('y', 0) - + + if isinstance(metadata, dict) and "x" in metadata and "y" in metadata: + x = metadata.get("x", 0) + y = metadata.get("y", 0) + # Store coordinates for this object object_coords[obj] = (x, y) - + # Plot this object's position plt.scatter([x], [y], c=color, s=100, alpha=0.8, 
label=obj.title()) - + # Add annotation - obj_abbrev = obj[0].upper() if len(obj) > 0 else 'X' - plt.annotate(f"{obj_abbrev}", (x, y), textcoords="offset points", - xytext=(0,10), ha='center') - + obj_abbrev = obj[0].upper() if len(obj) > 0 else "X" + plt.annotate( + f"{obj_abbrev}", (x, y), textcoords="offset points", xytext=(0, 10), ha="center" + ) + # Save the image to a file using the object name - if 'image' in result and result['image'] is not None: + if "image" in result and result["image"] is not None: # Clean the object name to make it suitable for a filename - clean_name = obj.replace(' ', '_').lower() + clean_name = obj.replace(" ", "_").lower() output_img_filename = f"{clean_name}_result.jpg" cv2.imwrite(output_img_filename, result["image"]) print(f"Saved {obj} image to {output_img_filename}") - + # Finalize the plot plt.title("Spatial Memory Map with Query Results") plt.xlabel("X Position (m)") plt.ylabel("Y Position (m)") plt.grid(True) - plt.axis('equal') + plt.axis("equal") plt.legend() - + # Add origin circle - plt.gca().add_patch(Circle((0, 0), 1.0, fill=False, color='blue', linestyle='--')) - + plt.gca().add_patch(Circle((0, 0), 1.0, fill=False, color="blue", linestyle="--")) + # Save the visualization plt.savefig(output_filename, dpi=300) print(f"Saved enhanced map visualization to {output_filename}") - + return object_coords - + # Final cleanup print("Performing final cleanup...") spatial_memory.cleanup() - + try: robot.cleanup() except Exception as e: print(f"Error during robot cleanup: {e}") - + print("Test completed successfully") + if __name__ == "__main__": main() diff --git a/tests/test_spatial_memory_query.py b/tests/test_spatial_memory_query.py index 2919575428..a0e77e9444 100644 --- a/tests/test_spatial_memory_query.py +++ b/tests/test_spatial_memory_query.py @@ -19,6 +19,7 @@ python test_spatial_memory_query.py --query "kitchen table" --limit 5 --threshold 0.7 --save-all python test_spatial_memory_query.py --query "robot" --limit 3 
--save-one """ + import os import sys import argparse @@ -32,45 +33,59 @@ from dimos.perception.spatial_perception import SpatialMemory from dimos.agents.memory.visual_memory import VisualMemory + def setup_persistent_chroma_db(db_path): """Set up a persistent ChromaDB client at the specified path.""" print(f"Setting up persistent ChromaDB at: {db_path}") os.makedirs(db_path, exist_ok=True) return chromadb.PersistentClient(path=db_path) + def parse_args(): """Parse command-line arguments.""" parser = argparse.ArgumentParser(description="Query spatial memory database.") - parser.add_argument("--query", type=str, default=None, - help="Text query to search for (e.g., 'kitchen table')") - parser.add_argument("--limit", type=int, default=3, - help="Maximum number of results to return") - parser.add_argument("--threshold", type=float, default=None, - help="Similarity threshold (0.0-1.0). Only return results above this threshold.") - parser.add_argument("--save-all", action="store_true", - help="Save all result images") - parser.add_argument("--save-one", action="store_true", - help="Save only the best matching image") - parser.add_argument("--visualize", action="store_true", - help="Create a visualization of all stored memory locations") - parser.add_argument("--db-path", type=str, - default="/home/stash/dimensional/dimos/assets/test_spatial_memory/chromadb_data", - help="Path to ChromaDB database") - parser.add_argument("--visual-memory-path", type=str, - default="/home/stash/dimensional/dimos/assets/test_spatial_memory/visual_memory.pkl", - help="Path to visual memory file") + parser.add_argument( + "--query", type=str, default=None, help="Text query to search for (e.g., 'kitchen table')" + ) + parser.add_argument("--limit", type=int, default=3, help="Maximum number of results to return") + parser.add_argument( + "--threshold", + type=float, + default=None, + help="Similarity threshold (0.0-1.0). 
Only return results above this threshold.", + ) + parser.add_argument("--save-all", action="store_true", help="Save all result images") + parser.add_argument("--save-one", action="store_true", help="Save only the best matching image") + parser.add_argument( + "--visualize", + action="store_true", + help="Create a visualization of all stored memory locations", + ) + parser.add_argument( + "--db-path", + type=str, + default="/home/stash/dimensional/dimos/assets/test_spatial_memory/chromadb_data", + help="Path to ChromaDB database", + ) + parser.add_argument( + "--visual-memory-path", + type=str, + default="/home/stash/dimensional/dimos/assets/test_spatial_memory/visual_memory.pkl", + help="Path to visual memory file", + ) return parser.parse_args() + def main(): args = parse_args() print("Loading existing spatial memory database for querying...") - + # Setup the persistent ChromaDB client db_client = setup_persistent_chroma_db(args.db_path) - + # Setup output directory for any saved results output_dir = os.path.dirname(args.visual_memory_path) - + # Load the visual memory print(f"Loading visual memory from {args.visual_memory_path}...") if os.path.exists(args.visual_memory_path): @@ -79,40 +94,41 @@ def main(): else: visual_memory = VisualMemory(output_dir=output_dir) print("No existing visual memory found. 
Query results won't include images.") - + # Create SpatialMemory with the existing database and visual memory spatial_memory = SpatialMemory( - collection_name="test_spatial_memory", - chroma_client=db_client, - visual_memory=visual_memory + collection_name="test_spatial_memory", chroma_client=db_client, visual_memory=visual_memory ) - + # Create a visualization if requested if args.visualize: print("\nCreating visualization of spatial memory...") common_objects = [ - "kitchen", "conference room", "vacuum", "office", - "bathroom", "boxes", "telephone booth" + "kitchen", + "conference room", + "vacuum", + "office", + "bathroom", + "boxes", + "telephone booth", ] visualize_spatial_memory_with_objects( - spatial_memory, - objects=common_objects, - output_filename="spatial_memory_map.png" + spatial_memory, objects=common_objects, output_filename="spatial_memory_map.png" ) - + # Handle query if provided if args.query: query = args.query limit = args.limit print(f"\nQuerying for: '{query}' (limit: {limit})...") - + # Run the query results = spatial_memory.query_by_text(query, limit=limit) - + if not results: print(f"No results found for query: '{query}'") return - + # Filter by threshold if specified if args.threshold is not None: print(f"Filtering results with similarity threshold: {args.threshold}") @@ -120,153 +136,162 @@ def main(): for result in results: # Distance is inverse of similarity (0 is perfect match) # Convert to similarity score (1.0 is perfect match) - similarity = 1.0 - (result.get('distance', 0) if result.get('distance') is not None else 0) + similarity = 1.0 - ( + result.get("distance", 0) if result.get("distance") is not None else 0 + ) if similarity >= args.threshold: filtered_results.append((result, similarity)) - + # Sort by similarity (highest first) filtered_results.sort(key=lambda x: x[1], reverse=True) - + if not filtered_results: print(f"No results met the similarity threshold of {args.threshold}") return - + print(f"Found 
{len(filtered_results)} results above threshold") results_with_scores = filtered_results else: # Add similarity scores for all results results_with_scores = [] for result in results: - similarity = 1.0 - (result.get('distance', 0) if result.get('distance') is not None else 0) + similarity = 1.0 - ( + result.get("distance", 0) if result.get("distance") is not None else 0 + ) results_with_scores.append((result, similarity)) - + # Process and display results timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - + for i, (result, similarity) in enumerate(results_with_scores): - metadata = result.get('metadata', {}) + metadata = result.get("metadata", {}) if isinstance(metadata, list) and metadata: metadata = metadata[0] - + # Display result information - print(f"\nResult {i+1} for '{query}':") + print(f"\nResult {i + 1} for '{query}':") print(f"Similarity: {similarity:.4f} (distance: {1.0 - similarity:.4f})") - + # Extract and display position information if isinstance(metadata, dict): - x = metadata.get('x', 0) - y = metadata.get('y', 0) - z = metadata.get('z', 0) + x = metadata.get("x", 0) + y = metadata.get("y", 0) + z = metadata.get("z", 0) print(f"Position: ({x:.2f}, {y:.2f}, {z:.2f})") - if 'timestamp' in metadata: + if "timestamp" in metadata: print(f"Timestamp: {metadata['timestamp']}") - if 'frame_id' in metadata: + if "frame_id" in metadata: print(f"Frame ID: {metadata['frame_id']}") - + # Save image if requested and available - if 'image' in result and result['image'] is not None: + if "image" in result and result["image"] is not None: # Only save first image, or all images based on flags if args.save_one and i > 0: continue if not (args.save_all or args.save_one): continue - + # Create a descriptive filename - clean_query = query.replace(' ', '_').replace('/', '_').lower() - output_filename = f"{clean_query}_result_{i+1}_{timestamp}.jpg" - + clean_query = query.replace(" ", "_").replace("/", "_").lower() + output_filename = f"{clean_query}_result_{i + 
1}_{timestamp}.jpg" + # Save the image cv2.imwrite(output_filename, result["image"]) print(f"Saved image to {output_filename}") - elif 'image' in result and result['image'] is None: + elif "image" in result and result["image"] is None: print("Image data not available for this result") else: - print("No query specified. Use --query \"text to search for\" to run a query.") + print('No query specified. Use --query "text to search for" to run a query.') print("Use --help to see all available options.") - + print("\nQuery completed successfully!") -def visualize_spatial_memory_with_objects(spatial_memory, objects, output_filename="spatial_memory_map.png"): + +def visualize_spatial_memory_with_objects( + spatial_memory, objects, output_filename="spatial_memory_map.png" +): """Visualize spatial memory with labeled objects.""" # Define colors for different objects - colors = ['red', 'green', 'orange', 'purple', 'brown', 'cyan', 'magenta', 'yellow'] - + colors = ["red", "green", "orange", "purple", "brown", "cyan", "magenta", "yellow"] + # Get all stored locations for background locations = spatial_memory.vector_db.get_all_locations() if not locations: print("No locations stored in spatial memory.") return - + # Extract coordinates if len(locations[0]) >= 3: x_coords = [loc[0] for loc in locations] y_coords = [loc[1] for loc in locations] else: x_coords, y_coords = zip(*locations) - + # Create figure plt.figure(figsize=(12, 10)) - plt.scatter(x_coords, y_coords, c='blue', s=50, alpha=0.5, label='All Frames') - + plt.scatter(x_coords, y_coords, c="blue", s=50, alpha=0.5, label="All Frames") + # Container for object coordinates object_coords = {} - + # Query for each object for i, obj in enumerate(objects): color = colors[i % len(colors)] print(f"Processing {obj} query for visualization...") - + # Get best match results = spatial_memory.query_by_text(obj, limit=1) if not results: print(f"No results found for '{obj}'") continue - + # Process result result = results[0] - 
metadata = result['metadata'] - + metadata = result["metadata"] + if isinstance(metadata, list) and metadata: metadata = metadata[0] - - if isinstance(metadata, dict) and 'x' in metadata and 'y' in metadata: - x = metadata.get('x', 0) - y = metadata.get('y', 0) - + + if isinstance(metadata, dict) and "x" in metadata and "y" in metadata: + x = metadata.get("x", 0) + y = metadata.get("y", 0) + # Store coordinates object_coords[obj] = (x, y) - + # Plot position plt.scatter([x], [y], c=color, s=100, alpha=0.8, label=obj.title()) - + # Add annotation - obj_abbrev = obj[0].upper() if len(obj) > 0 else 'X' - plt.annotate(f"{obj_abbrev}", (x, y), textcoords="offset points", - xytext=(0,10), ha='center') - + obj_abbrev = obj[0].upper() if len(obj) > 0 else "X" + plt.annotate( + f"{obj_abbrev}", (x, y), textcoords="offset points", xytext=(0, 10), ha="center" + ) + # Save image if available - if 'image' in result and result['image'] is not None: - clean_name = obj.replace(' ', '_').lower() + if "image" in result and result["image"] is not None: + clean_name = obj.replace(" ", "_").lower() output_img_filename = f"{clean_name}_result.jpg" cv2.imwrite(output_img_filename, result["image"]) print(f"Saved {obj} image to {output_img_filename}") - + # Finalize plot plt.title("Spatial Memory Map with Query Results") plt.xlabel("X Position (m)") plt.ylabel("Y Position (m)") plt.grid(True) - plt.axis('equal') + plt.axis("equal") plt.legend() - + # Add origin marker - plt.gca().add_patch(plt.Circle((0, 0), 1.0, fill=False, color='blue', linestyle='--')) - + plt.gca().add_patch(plt.Circle((0, 0), 1.0, fill=False, color="blue", linestyle="--")) + # Save visualization plt.savefig(output_filename, dpi=300) print(f"Saved visualization to {output_filename}") - + return object_coords + if __name__ == "__main__": main() diff --git a/tests/test_standalone_chromadb.py b/tests/test_standalone_chromadb.py index da9ef9e691..067303b572 100644 --- a/tests/test_standalone_chromadb.py +++ 
b/tests/test_standalone_chromadb.py @@ -24,6 +24,7 @@ embedding_function=embeddings, ) + def add_vector(vector_id, vector_data): """Add a vector to the ChromaDB collection.""" if not db_connection: @@ -34,6 +35,7 @@ def add_vector(vector_id, vector_data): metadatas=[{"name": vector_id}], ) + add_vector("id0", "Food") add_vector("id1", "Cat") add_vector("id2", "Mouse") @@ -50,22 +52,22 @@ def add_vector(vector_id, vector_data): def get_vector(vector_id): """Retrieve a vector from the ChromaDB by its identifier.""" - result = db_connection.get(include=['embeddings'], ids=[vector_id]) + result = db_connection.get(include=["embeddings"], ids=[vector_id]) return result + print(get_vector("id1")) # print(get_vector("id3")) # print(get_vector("id0")) # print(get_vector("id2")) + def query(query_texts, n_results=2): """Query the collection with a specific text and return up to n results.""" if not db_connection: raise Exception("Collection not initialized. Call connect() first.") - return db_connection.similarity_search( - query=query_texts, - k=n_results - ) + return db_connection.similarity_search(query=query_texts, k=n_results) + results = query("Colors") print(results) diff --git a/tests/test_standalone_fastapi.py b/tests/test_standalone_fastapi.py index 86775930a1..8cfec64ae0 100644 --- a/tests/test_standalone_fastapi.py +++ b/tests/test_standalone_fastapi.py @@ -2,6 +2,7 @@ import os import logging + logging.basicConfig(level=logging.DEBUG) from fastapi import FastAPI, Response @@ -11,25 +12,29 @@ app = FastAPI() -# Note: Chrome does not allow for loading more than 6 simultaneous -# video streams. Use Safari or another browser for utilizing +# Note: Chrome does not allow for loading more than 6 simultaneous +# video streams. Use Safari or another browser for utilizing # multiple simultaneous streams. Possibly build out functionality # that will stop live streams. 
+ @app.get("/") async def root(): pid = os.getpid() # Get the current process ID return {"message": f"Video Streaming Server, PID: {pid}"} + def video_stream_generator(): pid = os.getpid() print(f"Stream initiated by worker with PID: {pid}") # Log the PID when the generator is called # Use the correct path for your video source - cap = cv2.VideoCapture(f"{os.getcwd()}/assets/trimmed_video_480p.mov") # Change 0 to a filepath for video files + cap = cv2.VideoCapture( + f"{os.getcwd()}/assets/trimmed_video_480p.mov" + ) # Change 0 to a filepath for video files if not cap.isOpened(): - yield (b'--frame\r\nContent-Type: text/plain\r\n\r\n' + b'Could not open video source\r\n') + yield (b"--frame\r\nContent-Type: text/plain\r\n\r\n" + b"Could not open video source\r\n") return try: @@ -38,20 +43,25 @@ def video_stream_generator(): # If frame is read correctly ret is True if not ret: print(f"Reached the end of the video, restarting... PID: {pid}") - cap.set(cv2.CAP_PROP_POS_FRAMES, 0) # Set the position of the next video frame to 0 (the beginning) + cap.set( + cv2.CAP_PROP_POS_FRAMES, 0 + ) # Set the position of the next video frame to 0 (the beginning) continue - _, buffer = cv2.imencode('.jpg', frame) - yield (b'--frame\r\n' - b'Content-Type: image/jpeg\r\n\r\n' + buffer.tobytes() + b'\r\n') + _, buffer = cv2.imencode(".jpg", frame) + yield (b"--frame\r\nContent-Type: image/jpeg\r\n\r\n" + buffer.tobytes() + b"\r\n") finally: cap.release() + @app.get("/video") async def video_endpoint(): logging.debug("Attempting to open video stream.") - response = StreamingResponse(video_stream_generator(), media_type='multipart/x-mixed-replace; boundary=frame') + response = StreamingResponse( + video_stream_generator(), media_type="multipart/x-mixed-replace; boundary=frame" + ) logging.debug("Streaming response set up.") return response + if __name__ == "__main__": uvicorn.run("__main__:app", host="0.0.0.0", port=5555, workers=20) diff --git a/tests/test_standalone_hugging_face.py 
b/tests/test_standalone_hugging_face.py index baa81f0d2d..d0b2e68e61 100644 --- a/tests/test_standalone_hugging_face.py +++ b/tests/test_standalone_hugging_face.py @@ -129,25 +129,19 @@ from huggingface_hub import InferenceClient # Use environment variable for API key -api_key = os.getenv('HUGGINGFACE_ACCESS_TOKEN') +api_key = os.getenv("HUGGINGFACE_ACCESS_TOKEN") client = InferenceClient( provider="hf-inference", api_key=api_key, ) -messages = [ - { - "role": "user", - "content": "How many r's are in the word \"strawberry\"" - } -] +messages = [{"role": "user", "content": 'How many r\'s are in the word "strawberry"'}] completion = client.chat.completions.create( - model="Qwen/QwQ-32B", - messages=messages, - max_tokens=150, + model="Qwen/QwQ-32B", + messages=messages, + max_tokens=150, ) print(completion.choices[0].message) - diff --git a/tests/test_standalone_openai_json.py b/tests/test_standalone_openai_json.py index dd6e4b9bbf..c4531808fb 100644 --- a/tests/test_standalone_openai_json.py +++ b/tests/test_standalone_openai_json.py @@ -4,6 +4,7 @@ # ----- import dotenv + dotenv.load_dotenv() import json @@ -13,18 +14,19 @@ MODEL = "gpt-4o-2024-08-06" -math_tutor_prompt = ''' +math_tutor_prompt = """ You are a helpful math tutor. You will be provided with a math problem, and your goal will be to output a step by step solution, along with a final answer. For each step, just provide the output as an equation use the explanation field to detail the reasoning. -''' +""" -bad_prompt = ''' +bad_prompt = """ Follow the instructions. 
-''' +""" client = OpenAI() + class MathReasoning(BaseModel): class Step(BaseModel): explanation: str @@ -33,6 +35,7 @@ class Step(BaseModel): steps: list[Step] final_answer: str + def get_math_solution(question: str): completion = client.beta.chat.completions.parse( model=MODEL, @@ -44,6 +47,7 @@ def get_math_solution(question: str): ) return completion.choices[0].message + # Web Server import http.server import socketserver @@ -51,6 +55,7 @@ def get_math_solution(question: str): PORT = 5555 + class CustomHandler(http.server.SimpleHTTPRequestHandler): def do_GET(self): # Parse query parameters from the URL @@ -58,27 +63,32 @@ def do_GET(self): query_params = urllib.parse.parse_qs(parsed_path.query) # Check for a specific query parameter, e.g., 'problem' - problem = query_params.get('problem', [''])[0] # Default to an empty string if 'problem' isn't provided + problem = query_params.get("problem", [""])[ + 0 + ] # Default to an empty string if 'problem' isn't provided if problem: print(f"Problem: {problem}") solution = get_math_solution(problem) - + if solution.refusal: print(f"Refusal: {solution.refusal}") print(f"Solution: {solution}") self.send_response(200) else: - solution = json.dumps({"error": "Please provide a math problem using the 'problem' query parameter."}) + solution = json.dumps( + {"error": "Please provide a math problem using the 'problem' query parameter."} + ) self.send_response(400) - self.send_header('Content-type', 'application/json; charset=utf-8') + self.send_header("Content-type", "application/json; charset=utf-8") self.end_headers() # Write the message content self.wfile.write(str(solution).encode()) + with socketserver.TCPServer(("", PORT), CustomHandler) as httpd: print(f"Serving at port {PORT}") httpd.serve_forever() diff --git a/tests/test_standalone_openai_json_struct.py b/tests/test_standalone_openai_json_struct.py index 27d11fc964..88301728c0 100644 --- a/tests/test_standalone_openai_json_struct.py +++ 
b/tests/test_standalone_openai_json_struct.py @@ -6,6 +6,7 @@ from typing import List, Union, Dict import dotenv + dotenv.load_dotenv() from textwrap import dedent @@ -14,18 +15,19 @@ MODEL = "gpt-4o-2024-08-06" -math_tutor_prompt = ''' +math_tutor_prompt = """ You are a helpful math tutor. You will be provided with a math problem, and your goal will be to output a step by step solution, along with a final answer. For each step, just provide the output as an equation use the explanation field to detail the reasoning. -''' +""" -general_prompt = ''' +general_prompt = """ Follow the instructions. Output a step by step solution, along with a final answer. Use the explanation field to detail the reasoning. -''' +""" client = OpenAI() + class MathReasoning(BaseModel): class Step(BaseModel): explanation: str @@ -34,6 +36,7 @@ class Step(BaseModel): steps: list[Step] final_answer: str + def get_math_solution(question: str): prompt = general_prompt completion = client.beta.chat.completions.parse( @@ -46,6 +49,7 @@ def get_math_solution(question: str): ) return completion.choices[0].message + # Define Problem problem = "What is the derivative of 3x^2" print(f"Problem: {problem}") @@ -63,7 +67,7 @@ def get_math_solution(question: str): if not parsed_solution: print(f"Unable to Parse Solution") exit() - + # Print solution from class definitions print(f"Parsed: {parsed_solution}") diff --git a/tests/test_standalone_openai_json_struct_func.py b/tests/test_standalone_openai_json_struct_func.py index 5d5e67cd23..fe6bb21844 100644 --- a/tests/test_standalone_openai_json_struct_func.py +++ b/tests/test_standalone_openai_json_struct_func.py @@ -6,6 +6,7 @@ from typing import List, Union, Dict import dotenv + dotenv.load_dotenv() import json @@ -16,18 +17,19 @@ MODEL = "gpt-4o-2024-08-06" -math_tutor_prompt = ''' +math_tutor_prompt = """ You are a helpful math tutor. 
You will be provided with a math problem, and your goal will be to output a step by step solution, along with a final answer. For each step, just provide the output as an equation use the explanation field to detail the reasoning. -''' +""" -general_prompt = ''' +general_prompt = """ Follow the instructions. Output a step by step solution, along with a final answer. Use the explanation field to detail the reasoning. -''' +""" client = OpenAI() + class MathReasoning(BaseModel): class Step(BaseModel): explanation: str @@ -36,26 +38,28 @@ class Step(BaseModel): steps: list[Step] final_answer: str + # region Function Calling class GetWeather(BaseModel): - latitude: str = Field( - ..., - description="latitude e.g. Bogotá, Colombia" - ) - longitude: str = Field( - ..., - description="longitude e.g. Bogotá, Colombia" - ) + latitude: str = Field(..., description="latitude e.g. Bogotá, Colombia") + longitude: str = Field(..., description="longitude e.g. Bogotá, Colombia") + def get_weather(latitude, longitude): - response = requests.get(f"https://api.open-meteo.com/v1/forecast?latitude={latitude}&longitude={longitude}&current=temperature_2m,wind_speed_10m&hourly=temperature_2m,relative_humidity_2m,wind_speed_10m&temperature_unit=fahrenheit") + response = requests.get( + f"https://api.open-meteo.com/v1/forecast?latitude={latitude}&longitude={longitude}&current=temperature_2m,wind_speed_10m&hourly=temperature_2m,relative_humidity_2m,wind_speed_10m&temperature_unit=fahrenheit" + ) data = response.json() - return data['current']['temperature_2m'] + return data["current"]["temperature_2m"] + def get_tools(): return [pydantic_function_tool(GetWeather)] + + tools = get_tools() + def call_function(name, args): if name == "get_weather": print(f"Running function: {name}") @@ -67,7 +71,8 @@ def call_function(name, args): return get_weather(**args) else: return f"Local function not found: {name}" - + + def callback(message, messages, response_message, tool_calls): if message is None or
message.tool_calls is None: print("No message or tools were called.") @@ -83,14 +88,11 @@ def callback(message, messages, response_message, tool_calls): result = call_function(name, args) print(f"Function Call Results: {result}") - - messages.append({ - "role": "tool", - "tool_call_id": tool_call.id, - "content": str(result), - "name": name - }) - + + messages.append( + {"role": "tool", "tool_call_id": tool_call.id, "content": str(result), "name": name} + ) + # Complete the second call, after the functions have completed. if has_called_tools: print("Sending Second Query.") @@ -106,19 +108,18 @@ def callback(message, messages, response_message, tool_calls): print("No Need for Second Query.") return None + # endregion Function Calling + def get_math_solution(question: str): prompt = general_prompt messages = [ - {"role": "system", "content": dedent(prompt)}, - {"role": "user", "content": question}, - ] + {"role": "system", "content": dedent(prompt)}, + {"role": "user", "content": question}, + ] response = client.beta.chat.completions.parse( - model=MODEL, - messages=messages, - response_format=MathReasoning, - tools=tools + model=MODEL, messages=messages, response_format=MathReasoning, tools=tools ) response_message = response.choices[0].message @@ -128,11 +129,9 @@ def get_math_solution(question: str): return new_response or response.choices[0].message + # Define Problem -problems = [ - "What is the derivative of 3x^2", - "What's the weather like in San Fran today?" 
-] +problems = ["What is the derivative of 3x^2", "What's the weather like in San Fran today?"] problem = problems[0] for problem in problems: @@ -153,7 +152,7 @@ def get_math_solution(question: str): print(f"Unable to Parse Solution") print(f"Solution: {solution}") break - + # Print solution from class definitions print(f"Parsed: {parsed_solution}") diff --git a/tests/test_standalone_openai_json_struct_func_playground.py b/tests/test_standalone_openai_json_struct_func_playground.py index 62f905a4a0..62374da0be 100644 --- a/tests/test_standalone_openai_json_struct_func_playground.py +++ b/tests/test_standalone_openai_json_struct_func_playground.py @@ -105,32 +105,35 @@ import requests from dotenv import load_dotenv + load_dotenv() from openai import OpenAI client = OpenAI() + def get_current_weather(latitude, longitude): """Get the current weather in a given latitude and longitude using the 7Timer API""" base = "http://www.7timer.info/bin/api.pl" request_url = f"{base}?lon={longitude}&lat={latitude}&product=civillight&output=json" response = requests.get(request_url) - + # Parse response to extract the main weather data weather_data = response.json() - current_data = weather_data.get('dataseries', [{}])[0] - + current_data = weather_data.get("dataseries", [{}])[0] + result = { "latitude": latitude, "longitude": longitude, - "temp": current_data.get('temp2m', {'max': 'Unknown', 'min': 'Unknown'}), - "humidity": "Unknown" + "temp": current_data.get("temp2m", {"max": "Unknown", "min": "Unknown"}), + "humidity": "Unknown", } - + # Convert the dictionary to JSON string to match the given structure return json.dumps(result) + def run_conversation(content): messages = [{"role": "user", "content": content}] tools = [ @@ -192,15 +195,14 @@ def run_conversation(content): ) second_response = client.chat.completions.create( - model="gpt-3.5-turbo-0125", - messages=messages, - stream=True + model="gpt-3.5-turbo-0125", messages=messages, stream=True ) return second_response + if 
__name__ == "__main__": question = "What's the weather like in Paris and San Francisco?" response = run_conversation(question) for chunk in response: - print(chunk.choices[0].delta.content or "", end='', flush=True) -# Milestone 2 \ No newline at end of file + print(chunk.choices[0].delta.content or "", end="", flush=True) +# Milestone 2 diff --git a/tests/test_standalone_project_out.py b/tests/test_standalone_project_out.py index 3ae8a15515..15c5e4f480 100644 --- a/tests/test_standalone_project_out.py +++ b/tests/test_standalone_project_out.py @@ -9,13 +9,14 @@ import types import sys + def extract_function_info(filename): with open(filename, "r") as f: source = f.read() tree = ast.parse(source, filename=filename) - + function_info = [] - + # Use a dictionary to track functions module_globals = {} @@ -25,79 +26,92 @@ def extract_function_info(filename): for node in ast.walk(tree): if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): docstring = ast.get_docstring(node) or "" - + # Attempt to get the callable object from the globals try: - if node.name in module_globals: - func_obj = module_globals[node.name] - signature = inspect.signature(func_obj) - function_info.append({ - "name": node.name, - "signature": str(signature), - "docstring": docstring - }) - else: - function_info.append({ - "name": node.name, - "signature": "Could not get signature", - "docstring": docstring - }) + if node.name in module_globals: + func_obj = module_globals[node.name] + signature = inspect.signature(func_obj) + function_info.append( + {"name": node.name, "signature": str(signature), "docstring": docstring} + ) + else: + function_info.append( + { + "name": node.name, + "signature": "Could not get signature", + "docstring": docstring, + } + ) except TypeError as e: - print(f"Could not get function signature for {node.name} in {filename}: {e}", file=sys.stderr) - function_info.append({ - "name": node.name, - "signature": "Could not get signature", - "docstring": docstring - }) 
+ print( + f"Could not get function signature for {node.name} in {filename}: {e}", + file=sys.stderr, + ) + function_info.append( + { + "name": node.name, + "signature": "Could not get signature", + "docstring": docstring, + } + ) class_info = [] for node in ast.walk(tree): - if isinstance(node, ast.ClassDef): + if isinstance(node, ast.ClassDef): docstring = ast.get_docstring(node) or "" methods = [] for method in node.body: if isinstance(method, (ast.FunctionDef, ast.AsyncFunctionDef)): method_docstring = ast.get_docstring(method) or "" try: - if node.name in module_globals: - class_obj = module_globals[node.name] - method_obj = getattr(class_obj, method.name) - signature = inspect.signature(method_obj) - methods.append({ - "name": method.name, - "signature": str(signature), - "docstring": method_docstring - }) - else: - methods.append({ - "name": method.name, - "signature": "Could not get signature", - "docstring": method_docstring - }) + if node.name in module_globals: + class_obj = module_globals[node.name] + method_obj = getattr(class_obj, method.name) + signature = inspect.signature(method_obj) + methods.append( + { + "name": method.name, + "signature": str(signature), + "docstring": method_docstring, + } + ) + else: + methods.append( + { + "name": method.name, + "signature": "Could not get signature", + "docstring": method_docstring, + } + ) except AttributeError as e: - print(f"Could not get method signature for {node.name}.{method.name} in {filename}: {e}", file=sys.stderr) - methods.append({ - "name": method.name, - "signature": "Could not get signature", - "docstring": method_docstring - }) + print( + f"Could not get method signature for {node.name}.{method.name} in {filename}: {e}", + file=sys.stderr, + ) + methods.append( + { + "name": method.name, + "signature": "Could not get signature", + "docstring": method_docstring, + } + ) except TypeError as e: - print(f"Could not get method signature for {node.name}.{method.name} in {filename}: {e}", 
file=sys.stderr) - methods.append({ - "name": method.name, - "signature": "Could not get signature", - "docstring": method_docstring - }) - class_info.append({ - "name": node.name, - "docstring": docstring, - "methods": methods - }) - - return { - "function_info": function_info, - "class_info": class_info - } + print( + f"Could not get method signature for {node.name}.{method.name} in {filename}: {e}", + file=sys.stderr, + ) + methods.append( + { + "name": method.name, + "signature": "Could not get signature", + "docstring": method_docstring, + } + ) + class_info.append({"name": node.name, "docstring": docstring, "methods": methods}) + + return {"function_info": function_info, "class_info": class_info} + # Usage: file_path = "./dimos/agents/memory/base.py" @@ -110,4 +124,4 @@ def extract_function_info(filename): file_path = "./dimos/agents/agent.py" extracted_info = extract_function_info(file_path) -print(extracted_info) \ No newline at end of file +print(extracted_info) diff --git a/tests/test_standalone_rxpy_01.py b/tests/test_standalone_rxpy_01.py index d68b6fef82..1f65f3a468 100644 --- a/tests/test_standalone_rxpy_01.py +++ b/tests/test_standalone_rxpy_01.py @@ -40,7 +40,7 @@ def emission_process(value): # Create an observable that emits every second secondly_emission = reactivex.interval(1.0, scheduler=pool_scheduler).pipe( - ops.map(lambda x: f"Value {x} emitted after {x+1} second(s)"), + ops.map(lambda x: f"Value {x} emitted after {x + 1} second(s)"), ops.do_action(emission_process), ops.take(30), # Limit the emission to 30 times ) @@ -50,7 +50,7 @@ def emission_process(value): on_next=lambda x: print(x), on_error=lambda e: print(e), on_completed=lambda: print("Emission completed."), - scheduler=pool_scheduler + scheduler=pool_scheduler, ) elif which_test == 2: @@ -92,19 +92,16 @@ def emission_process(value): # Observable that emits every second secondly_emission = reactivex.interval(1.0, scheduler=pool_scheduler).pipe( - ops.map(lambda x: f"Second {x+1}"), 
- ops.take(30) + ops.map(lambda x: f"Second {x + 1}"), ops.take(30) ) # Observable that emits values immediately and repeatedly - immediate_emission = reactivex.from_(['a', 'b', 'c', 'd', 'e']).pipe( - ops.repeat() - ) + immediate_emission = reactivex.from_(["a", "b", "c", "d", "e"]).pipe(ops.repeat()) # Combine emissions using zip combined_emissions = reactivex.zip(secondly_emission, immediate_emission).pipe( ops.map(lambda combined: f"{combined[0]} - Value: {combined[1]}"), - ops.do_action(lambda s: print(f"Combined emission: {s}")) + ops.do_action(lambda s: print(f"Combined emission: {s}")), ) # Subscribe to the combined emissions @@ -113,10 +110,10 @@ def emission_process(value): on_error=lambda e: print(f"Error: {e}"), on_completed=lambda: { print("Combined emission completed."), - completed_event.set() # Set the event to signal completion + completed_event.set(), # Set the event to signal completion }, - scheduler=pool_scheduler + scheduler=pool_scheduler, ) # Wait for the observable to complete - completed_event.wait() \ No newline at end of file + completed_event.wait() diff --git a/tests/test_unitree_agent.py b/tests/test_unitree_agent.py index 835bb34ff8..fa0aa6b8e3 100644 --- a/tests/test_unitree_agent.py +++ b/tests/test_unitree_agent.py @@ -17,7 +17,6 @@ class UnitreeAgentDemo: - def __init__(self): self.robot_ip = None self.connection_method = None @@ -39,7 +38,8 @@ def get_env_var(var_name, default=None, required=False): self.connection_method = get_env_var("CONN_TYPE") self.serial_number = get_env_var("SERIAL_NUMBER") self.output_dir = get_env_var( - "ROS_OUTPUT_DIR", os.path.join(os.getcwd(), "assets/output/ros")) + "ROS_OUTPUT_DIR", os.path.join(os.getcwd(), "assets/output/ros") + ) def _initialize_robot(self, with_video_stream=True): print( @@ -83,12 +83,12 @@ def run_with_queries(self): # This will cause listening agents to consume the queries and respond # to them via skill execution and provide 1-shot responses. 
query_provider.start_query_stream( - query_template= - "{query}; Denote the number at the beginning of this query before the semicolon as the 'reference number'. Provide the reference number, without any other text in your response. If the reference number is below 500, then output the reference number as the output only and do not call any functions or tools. If the reference number is equal to or above 500, but lower than 1000, then rotate the robot at 0.5 rad/s for 1 second. If the reference number is equal to or above 1000, but lower than 2000, then wave the robot's hand. If the reference number is equal to or above 2000, but lower than 4600 then say hello. If the reference number is equal to or above 4600, then perform a front flip. IF YOU DO NOT FOLLOW THESE INSTRUCTIONS EXACTLY, YOU WILL DIE!!!", + query_template="{query}; Denote the number at the beginning of this query before the semicolon as the 'reference number'. Provide the reference number, without any other text in your response. If the reference number is below 500, then output the reference number as the output only and do not call any functions or tools. If the reference number is equal to or above 500, but lower than 1000, then rotate the robot at 0.5 rad/s for 1 second. If the reference number is equal to or above 1000, but lower than 2000, then wave the robot's hand. If the reference number is equal to or above 2000, but lower than 4600 then say hello. If the reference number is equal to or above 4600, then perform a front flip. 
IF YOU DO NOT FOLLOW THESE INSTRUCTIONS EXACTLY, YOU WILL DIE!!!", frequency=0.01, start_count=1, end_count=10000, - step=1) + step=1, + ) def run_with_test_video(self): # Initialize robot @@ -96,9 +96,9 @@ def run_with_test_video(self): # Initialize test video stream from dimos.stream.video_provider import VideoProvider + self.video_stream = VideoProvider( - dev_name="UnitreeGo2", - video_source=f"{os.getcwd()}/assets/framecount.mp4" + dev_name="UnitreeGo2", video_source=f"{os.getcwd()}/assets/framecount.mp4" ).capture_video_as_observable(realtime=False, fps=1) # Get Skills @@ -111,8 +111,7 @@ def run_with_test_video(self): agent_type="Perception", input_video_stream=self.video_stream, output_dir=self.output_dir, - query= - "Denote the number you see in the image as the 'reference number'. Only provide the reference number, without any other text in your response. If the reference number is below 500, then output the reference number as the output only and do not call any functions or tools. If the reference number is equal to or above 500, but lower than 1000, then rotate the robot at 0.5 rad/s for 1 second. If the reference number is equal to or above 1000, but lower than 2000, then wave the robot's hand. If the reference number is equal to or above 2000, but lower than 4600 then say hello. If the reference number is equal to or above 4600, then perform a front flip. IF YOU DO NOT FOLLOW THESE INSTRUCTIONS EXACTLY, YOU WILL DIE!!!", + query="Denote the number you see in the image as the 'reference number'. Only provide the reference number, without any other text in your response. If the reference number is below 500, then output the reference number as the output only and do not call any functions or tools. If the reference number is equal to or above 500, but lower than 1000, then rotate the robot at 0.5 rad/s for 1 second. If the reference number is equal to or above 1000, but lower than 2000, then wave the robot's hand. 
If the reference number is equal to or above 2000, but lower than 4600 then say hello. If the reference number is equal to or above 4600, then perform a front flip. IF YOU DO NOT FOLLOW THESE INSTRUCTIONS EXACTLY, YOU WILL DIE!!!", image_detail="high", skills=skills_instance, # frame_processor=frame_processor, @@ -150,9 +149,8 @@ def run_with_ros_video(self): agent_type="Perception", input_video_stream=self.video_stream, output_dir=self.output_dir, - query= - "Based on the image, execute the command seen in the image AND ONLY THE COMMAND IN THE IMAGE. IF YOU DO NOT FOLLOW THESE INSTRUCTIONS EXACTLY, YOU WILL DIE!!!", - #WORKING MOVEMENT DEMO VVV + query="Based on the image, execute the command seen in the image AND ONLY THE COMMAND IN THE IMAGE. IF YOU DO NOT FOLLOW THESE INSTRUCTIONS EXACTLY, YOU WILL DIE!!!", + # WORKING MOVEMENT DEMO VVV # query="Move() 5 meters foward. Then spin 360 degrees to the right, and then Reverse() 5 meters, and then Move forward 3 meters", image_detail="high", skills=skills_instance, @@ -168,9 +166,9 @@ def run_with_multiple_query_and_test_video_agents(self): # Initialize test video stream from dimos.stream.video_provider import VideoProvider + self.video_stream = VideoProvider( - dev_name="UnitreeGo2", - video_source=f"{os.getcwd()}/assets/framecount.mp4" + dev_name="UnitreeGo2", video_source=f"{os.getcwd()}/assets/framecount.mp4" ).capture_video_as_observable(realtime=False, fps=1) # Create the skills available to the agent. @@ -203,8 +201,7 @@ def run_with_multiple_query_and_test_video_agents(self): agent_type="Perception", input_video_stream=self.video_stream, output_dir=self.output_dir, - query= - "Denote the number you see in the image as the 'reference number'. Only provide the reference number, without any other text in your response. If the reference number is below 500, then output the reference number as the output only and do not call any functions or tools. 
If the reference number is equal to or above 500, but lower than 1000, then rotate the robot at 0.5 rad/s for 1 second. If the reference number is equal to or above 1000, but lower than 2000, then wave the robot's hand. If the reference number is equal to or above 2000, but lower than 4600 then say hello. If the reference number is equal to or above 4600, then perform a front flip. IF YOU DO NOT FOLLOW THESE INSTRUCTIONS EXACTLY, YOU WILL DIE!!!", + query="Denote the number you see in the image as the 'reference number'. Only provide the reference number, without any other text in your response. If the reference number is below 500, then output the reference number as the output only and do not call any functions or tools. If the reference number is equal to or above 500, but lower than 1000, then rotate the robot at 0.5 rad/s for 1 second. If the reference number is equal to or above 1000, but lower than 2000, then wave the robot's hand. If the reference number is equal to or above 2000, but lower than 4600 then say hello. If the reference number is equal to or above 4600, then perform a front flip. IF YOU DO NOT FOLLOW THESE INSTRUCTIONS EXACTLY, YOU WILL DIE!!!", image_detail="high", skills=skills_instance, # frame_processor=frame_processor, @@ -216,8 +213,7 @@ def run_with_multiple_query_and_test_video_agents(self): agent_type="Perception", input_video_stream=self.video_stream, output_dir=self.output_dir, - query= - "Denote the number you see in the image as the 'reference number'. Only provide the reference number, without any other text in your response. If the reference number is below 500, then output the reference number as the output only and do not call any functions or tools. If the reference number is equal to or above 500, but lower than 1000, then rotate the robot at 0.5 rad/s for 1 second. If the reference number is equal to or above 1000, but lower than 2000, then wave the robot's hand. 
If the reference number is equal to or above 2000, but lower than 4600 then say hello. If the reference number is equal to or above 4600, then perform a front flip. IF YOU DO NOT FOLLOW THESE INSTRUCTIONS EXACTLY, YOU WILL DIE!!!", + query="Denote the number you see in the image as the 'reference number'. Only provide the reference number, without any other text in your response. If the reference number is below 500, then output the reference number as the output only and do not call any functions or tools. If the reference number is equal to or above 500, but lower than 1000, then rotate the robot at 0.5 rad/s for 1 second. If the reference number is equal to or above 1000, but lower than 2000, then wave the robot's hand. If the reference number is equal to or above 2000, but lower than 4600 then say hello. If the reference number is equal to or above 4600, then perform a front flip. IF YOU DO NOT FOLLOW THESE INSTRUCTIONS EXACTLY, YOU WILL DIE!!!", image_detail="high", skills=skills_instance, # frame_processor=frame_processor, @@ -228,12 +224,12 @@ def run_with_multiple_query_and_test_video_agents(self): # This will cause listening agents to consume the queries and respond # to them via skill execution and provide 1-shot responses. query_provider.start_query_stream( - query_template= - "{query}; Denote the number at the beginning of this query before the semicolon as the 'reference number'. Provide the reference number, without any other text in your response. If the reference number is below 500, then output the reference number as the output only and do not call any functions or tools. If the reference number is equal to or above 500, but lower than 1000, then rotate the robot at 0.5 rad/s for 1 second. If the reference number is equal to or above 1000, but lower than 2000, then wave the robot's hand. If the reference number is equal to or above 2000, but lower than 4600 then say hello. 
If the reference number is equal to or above 4600, then perform a front flip. IF YOU DO NOT FOLLOW THESE INSTRUCTIONS EXACTLY, YOU WILL DIE!!!", + query_template="{query}; Denote the number at the beginning of this query before the semicolon as the 'reference number'. Provide the reference number, without any other text in your response. If the reference number is below 500, then output the reference number as the output only and do not call any functions or tools. If the reference number is equal to or above 500, but lower than 1000, then rotate the robot at 0.5 rad/s for 1 second. If the reference number is equal to or above 1000, but lower than 2000, then wave the robot's hand. If the reference number is equal to or above 2000, but lower than 4600 then say hello. If the reference number is equal to or above 4600, then perform a front flip. IF YOU DO NOT FOLLOW THESE INSTRUCTIONS EXACTLY, YOU WILL DIE!!!", frequency=0.01, start_count=1, end_count=10000000, - step=1) + step=1, + ) def run_with_queries_and_fast_api(self): # Initialize robot diff --git a/tests/test_unitree_agent_queries_fastapi.py b/tests/test_unitree_agent_queries_fastapi.py index 73feea6a9a..74517e2ec5 100644 --- a/tests/test_unitree_agent_queries_fastapi.py +++ b/tests/test_unitree_agent_queries_fastapi.py @@ -29,17 +29,18 @@ def main(): robot_ip = os.getenv("ROBOT_IP") if not robot_ip: raise ValueError("ROBOT_IP environment variable is required") - connection_method = os.getenv("CONN_TYPE") or 'webrtc' - output_dir = os.getenv("ROS_OUTPUT_DIR", - os.path.join(os.getcwd(), "assets/output/ros")) + connection_method = os.getenv("CONN_TYPE") or "webrtc" + output_dir = os.getenv("ROS_OUTPUT_DIR", os.path.join(os.getcwd(), "assets/output/ros")) try: # Initialize robot logger.info("Initializing Unitree Robot") - robot = UnitreeGo2(ip=robot_ip, - connection_method=connection_method, - output_dir=output_dir, - skills=MyUnitreeSkills()) + robot = UnitreeGo2( + ip=robot_ip, + 
connection_method=connection_method, + output_dir=output_dir, + skills=MyUnitreeSkills(), + ) # Set up video stream logger.info("Starting video stream") @@ -48,15 +49,15 @@ def main(): # Create FastAPI server with video stream and text streams logger.info("Initializing FastAPI server") streams = {"unitree_video": video_stream} - + # Create a subject for agent responses agent_response_subject = rx.subject.Subject() agent_response_stream = agent_response_subject.pipe(ops.share()) - + text_streams = { "agent_responses": agent_response_stream, } - + web_interface = FastAPIServer(port=5555, text_streams=text_streams, **streams) logger.info("Starting action primitive execution agent") @@ -66,11 +67,9 @@ def main(): output_dir=output_dir, skills=robot.get_skills(), ) - + # Subscribe to agent responses and send them to the subject - agent.get_response_observable().subscribe( - lambda x: agent_response_subject.on_next(x) - ) + agent.get_response_observable().subscribe(lambda x: agent_response_subject.on_next(x)) # Start server (blocking call) logger.info("Starting FastAPI server") diff --git a/tests/test_webrtc_queue.py b/tests/test_webrtc_queue.py index 8a01a9da3a..305298c13a 100644 --- a/tests/test_webrtc_queue.py +++ b/tests/test_webrtc_queue.py @@ -7,127 +7,124 @@ import os from dimos.robot.unitree.unitree_ros_control import UnitreeROSControl + def main(): """Test WebRTC request queue with a sequence of 20 back-to-back commands""" - + print("Initializing UnitreeGo2...") - + # Get configuration from environment variables robot_ip = os.getenv("ROBOT_IP") - connection_method = getattr(WebRTCConnectionMethod, - os.getenv("CONNECTION_METHOD", "LocalSTA")) - + connection_method = getattr(WebRTCConnectionMethod, os.getenv("CONNECTION_METHOD", "LocalSTA")) + # Initialize ROS control - ros_control = UnitreeROSControl( - node_name="unitree_go2_test", - use_raw=True - ) - + ros_control = UnitreeROSControl(node_name="unitree_go2_test", use_raw=True) + # Initialize robot robot = 
UnitreeGo2( ip=robot_ip, connection_method=connection_method, ros_control=ros_control, use_ros=True, - use_webrtc=False # Using queue instead of direct WebRTC + use_webrtc=False, # Using queue instead of direct WebRTC ) - + # Wait for initialization print("Waiting for robot to initialize...") time.sleep(5) - + # First put the robot in a good starting state print("Running recovery stand...") robot.webrtc_req(api_id=1006) # RecoveryStand - + # Queue 20 WebRTC requests back-to-back print("\n🤖 QUEUEING 20 COMMANDS BACK-TO-BACK 🤖\n") - + # Dance 1 robot.webrtc_req(api_id=1022) # Dance1 print("Queued: Dance1 (1022)") - + # Wiggle Hips robot.webrtc_req(api_id=1033) # WiggleHips print("Queued: WiggleHips (1033)") - + # Stretch robot.webrtc_req(api_id=1017) # Stretch print("Queued: Stretch (1017)") - + # Hello robot.webrtc_req(api_id=1016) # Hello print("Queued: Hello (1016)") - + # Dance 2 robot.webrtc_req(api_id=1023) # Dance2 print("Queued: Dance2 (1023)") - + # Wallow robot.webrtc_req(api_id=1021) # Wallow print("Queued: Wallow (1021)") - + # Scrape robot.webrtc_req(api_id=1029) # Scrape print("Queued: Scrape (1029)") - + # Finger Heart robot.webrtc_req(api_id=1036) # FingerHeart print("Queued: FingerHeart (1036)") - + # Recovery Stand (base position) robot.webrtc_req(api_id=1006) # RecoveryStand print("Queued: RecoveryStand (1006)") - + # Hello again robot.webrtc_req(api_id=1016) # Hello print("Queued: Hello (1016)") - + # Wiggle Hips again robot.webrtc_req(api_id=1033) # WiggleHips print("Queued: WiggleHips (1033)") - + # Front Pounce robot.webrtc_req(api_id=1032) # FrontPounce print("Queued: FrontPounce (1032)") - + # Dance 1 again robot.webrtc_req(api_id=1022) # Dance1 print("Queued: Dance1 (1022)") - + # Stretch again robot.webrtc_req(api_id=1017) # Stretch print("Queued: Stretch (1017)") - + # Front Jump robot.webrtc_req(api_id=1031) # FrontJump print("Queued: FrontJump (1031)") - + # Finger Heart again robot.webrtc_req(api_id=1036) # FingerHeart
print("Queued: FingerHeart (1036)") - + # Scrape again robot.webrtc_req(api_id=1029) # Scrape print("Queued: Scrape (1029)") - + # Hello one more time robot.webrtc_req(api_id=1016) # Hello print("Queued: Hello (1016)") - + # Dance 2 again robot.webrtc_req(api_id=1023) # Dance2 print("Queued: Dance2 (1023)") - + # Finish with recovery stand robot.webrtc_req(api_id=1006) # RecoveryStand print("Queued: RecoveryStand (1006)") - + print("\nAll 20 commands queued successfully! Watch the robot perform them in sequence.") print("The WebRTC queue manager will process them one by one when the robot is ready.") print("Press Ctrl+C to stop the program when you've seen enough.\n") - + try: # Keep the program running so the queue can be processed while True: @@ -140,5 +137,6 @@ def main(): robot.cleanup() print("Test completed.") + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/tests/test_websocketvis.py b/tests/test_websocketvis.py index dd39e59f68..d3e798fa20 100644 --- a/tests/test_websocketvis.py +++ b/tests/test_websocketvis.py @@ -31,28 +31,26 @@ def parse_args(): def setup_web_interface(robot, port=5555): """Set up web interface with robot video and local planner visualization""" print(f"Setting up web interface on port {port}") - + # Get video stream from robot video_stream = robot.video_stream_ros.pipe( ops.share(), ops.map(lambda frame: frame), ops.filter(lambda frame: frame is not None), ) - + # Get local planner visualization stream local_planner_stream = robot.local_planner_viz_stream.pipe( ops.share(), ops.map(lambda frame: frame), ops.filter(lambda frame: frame is not None), ) - + # Create web interface with streams web_interface = RobotWebInterface( - port=port, - robot_video=video_stream, - local_planner=local_planner_stream + port=port, robot_video=video_stream, local_planner=local_planner_stream ) - + return web_interface @@ -61,7 +59,7 @@ def main(): websocket_vis = WebsocketVis() websocket_vis.start() - + web_interface 
= None if args.live: @@ -69,13 +67,17 @@ def main(): robot = UnitreeGo2(ros_control=ros_control, ip=os.getenv("ROBOT_IP")) planner = robot.global_planner - websocket_vis.connect(vector_stream("robot", lambda: robot.ros_control.transform_euler_pos("base_link"))) - websocket_vis.connect(robot.ros_control.topic("map", Costmap).pipe(ops.map(lambda x: ["costmap", x]))) - + websocket_vis.connect( + vector_stream("robot", lambda: robot.ros_control.transform_euler_pos("base_link")) + ) + websocket_vis.connect( + robot.ros_control.topic("map", Costmap).pipe(ops.map(lambda x: ["costmap", x])) + ) + # Also set up the web interface with both streams - if hasattr(robot, 'video_stream_ros') and hasattr(robot, 'local_planner_viz_stream'): + if hasattr(robot, "video_stream_ros") and hasattr(robot, "local_planner_viz_stream"): web_interface = setup_web_interface(robot, port=args.port) - + # Start web interface in a separate thread viz_thread = threading.Thread(target=web_interface.run, daemon=True) viz_thread.start() From ba8eab1d41554c0a1d2bb5259e14b8a7bb4c4d9c Mon Sep 17 00:00:00 2001 From: lesh Date: Wed, 28 May 2025 10:53:16 +0300 Subject: [PATCH 02/42] check gitconfig before test --- .github/workflows/pytest.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index 0fd43aac66..093b922e37 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -21,6 +21,8 @@ jobs: steps: - uses: actions/checkout@v4 + - run: | + cat /root/.gitconfig - name: Run tests run: | /entrypoint.sh bash -c "pytest" From 1cd591558b9b426b3106c0a7ac4f99b295a22d3f Mon Sep 17 00:00:00 2001 From: lesh Date: Wed, 28 May 2025 11:07:38 +0300 Subject: [PATCH 03/42] set-safe-directory on the github action level --- .github/workflows/pytest.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index 093b922e37..07b4dde555 100644 --- a/.github/workflows/pytest.yml +++ 
b/.github/workflows/pytest.yml @@ -21,8 +21,11 @@ jobs: steps: - uses: actions/checkout@v4 + set-safe-directory: true + - run: | cat /root/.gitconfig + - name: Run tests run: | /entrypoint.sh bash -c "pytest" From 932d720630103e6265e26dd342e609b6a1391445 Mon Sep 17 00:00:00 2001 From: lesh Date: Wed, 28 May 2025 11:09:42 +0300 Subject: [PATCH 04/42] workflow fix --- .github/workflows/pytest.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index 07b4dde555..bd73032d8b 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -21,9 +21,9 @@ jobs: steps: - uses: actions/checkout@v4 - set-safe-directory: true - run: | + git config --global --add safe.directory '*' cat /root/.gitconfig - name: Run tests From d3c2ac7921f1f9c9a4016b617382a9adc89ca838 Mon Sep 17 00:00:00 2001 From: lesh Date: Wed, 28 May 2025 11:35:21 +0300 Subject: [PATCH 05/42] improved tests, verifying workflow config --- .github/workflows/pytest.yml | 8 +++++--- dimos/utils/test_testing.py | 24 +++++++++++++++++++----- 2 files changed, 24 insertions(+), 8 deletions(-) diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index bd73032d8b..f327614fb1 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -21,10 +21,12 @@ jobs: steps: - uses: actions/checkout@v4 + with: + set-safe-directory: tru - - run: | - git config --global --add safe.directory '*' - cat /root/.gitconfig +# - run: | +# git config --global --add safe.directory '*' +# cat /root/.gitconfig - name: Run tests run: | diff --git a/dimos/utils/test_testing.py b/dimos/utils/test_testing.py index 8952782168..2f9be8baad 100644 --- a/dimos/utils/test_testing.py +++ b/dimos/utils/test_testing.py @@ -12,12 +12,15 @@ def test_pull_file(): # delete decompressed test file if it exists if test_file_decompressed.exists(): - test_file_compressed.unlink() + test_file_decompressed.unlink() # delete lfs archive 
file if it exists if test_file_compressed.exists(): test_file_compressed.unlink() + assert not test_file_compressed.exists() + assert not test_file_decompressed.exists() + # pull the lfs file reference from git env = os.environ.copy() env["GIT_LFS_SKIP_SMUDGE"] = "1" @@ -31,7 +34,7 @@ def test_pull_file(): # ensure we have a pointer file from git (small ASCII text file) assert test_file_compressed.exists() - test_file_compressed.stat().st_size < 200 + assert test_file_compressed.stat().st_size < 200 # trigger a data file pull assert testing.testData(test_file_name) == test_file_decompressed @@ -42,9 +45,10 @@ def test_pull_file(): # validate hashes with test_file_compressed.open("rb") as f: + assert test_file_compressed.stat().st_size > 200 compressed_sha256 = hashlib.sha256(f.read()).hexdigest() assert ( - compressed_sha256 == "cdfd708d66e6dd5072ed7636fc10fb97754f8d14e3acd6c3553663e27fc96065" + compressed_sha256 == "b8cf30439b41033ccb04b09b9fc8388d18fb544d55b85c155dbf85700b9e7603" ) with test_file_decompressed.open("rb") as f: @@ -84,13 +88,23 @@ def test_pull_dir(): # ensure we have a pointer file from git (small ASCII text file) assert test_dir_compressed.exists() - test_dir_compressed.stat().st_size < 200 + assert test_dir_compressed.stat().st_size < 200 # trigger a data file pull assert testing.testData(test_dir_name) == test_dir_decompressed + assert test_dir_compressed.stat().st_size > 200 # validate data is received assert test_dir_compressed.exists() assert test_dir_decompressed.exists() - assert len(list(test_dir_decompressed.iterdir())) == 2 + for [file, expected_hash] in zip( + list(test_dir_decompressed.iterdir()), + [ + "456cc2c23f4ffa713b4e0c0d97143c27e48bbe6ef44341197b31ce84b3650e74", + "6c3aaa9a79853ea4a7453c7db22820980ceb55035777f7460d05a0fa77b3b1b3", + ], + ): + with file.open("rb") as f: + sha256 = hashlib.sha256(f.read()).hexdigest() + assert sha256 == expected_hash From 6fa09284c493770ecff73f71a5d03db427a8559b Mon Sep 17 00:00:00 2001 From: 
lesh Date: Wed, 28 May 2025 11:44:48 +0300 Subject: [PATCH 06/42] workflow config fix --- .github/workflows/pytest.yml | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index f327614fb1..5f532d8e6e 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -21,12 +21,9 @@ jobs: steps: - uses: actions/checkout@v4 - with: - set-safe-directory: tru -# - run: | -# git config --global --add safe.directory '*' -# cat /root/.gitconfig + - run: | + git config --global --add safe.directory '*' - name: Run tests run: | From f7e9096c39cab85678e951ce09a2470c61305b8e Mon Sep 17 00:00:00 2001 From: lesh Date: Wed, 28 May 2025 11:51:43 +0300 Subject: [PATCH 07/42] brute forcing config --- .github/workflows/pytest.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index 5f532d8e6e..2bbb19f3c9 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -21,9 +21,13 @@ jobs: steps: - uses: actions/checkout@v4 + with: + set-safe-directory: true - run: | + cat /root/.gitconfig git config --global --add safe.directory '*' + cat /root/.gitconfig - name: Run tests run: | From 8fbec9a69ea74819a3ee970e457f72bacb547500 Mon Sep 17 00:00:00 2001 From: lesh Date: Wed, 28 May 2025 12:06:37 +0300 Subject: [PATCH 08/42] another fix attempt --- .github/workflows/pytest.yml | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index 2bbb19f3c9..ae1be9b751 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -21,14 +21,8 @@ jobs: steps: - uses: actions/checkout@v4 - with: - set-safe-directory: true - - - run: | - cat /root/.gitconfig - git config --global --add safe.directory '*' - cat /root/.gitconfig - name: Run tests run: | + git config --global --add safe.directory '*' /entrypoint.sh bash -c "pytest" 
From e0d4e51672cb81d3f31da003b8ec9bcdfe6414df Mon Sep 17 00:00:00 2001 From: lesh Date: Wed, 28 May 2025 12:22:22 +0300 Subject: [PATCH 09/42] dif hashes need to be sorted --- dimos/utils/test_testing.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dimos/utils/test_testing.py b/dimos/utils/test_testing.py index 2f9be8baad..c1c0c10072 100644 --- a/dimos/utils/test_testing.py +++ b/dimos/utils/test_testing.py @@ -99,10 +99,10 @@ def test_pull_dir(): assert test_dir_decompressed.exists() for [file, expected_hash] in zip( - list(test_dir_decompressed.iterdir()), + sorted(test_dir_decompressed.iterdir()), [ - "456cc2c23f4ffa713b4e0c0d97143c27e48bbe6ef44341197b31ce84b3650e74", "6c3aaa9a79853ea4a7453c7db22820980ceb55035777f7460d05a0fa77b3b1b3", + "456cc2c23f4ffa713b4e0c0d97143c27e48bbe6ef44341197b31ce84b3650e74", ], ): with file.open("rb") as f: From b41817977e7406f8970e3c61a4bb35e52d989c16 Mon Sep 17 00:00:00 2001 From: lesh Date: Wed, 28 May 2025 12:46:08 +0300 Subject: [PATCH 10/42] testing in-ci pre-commit run --- .github/workflows/code-cleanup.yml | 18 ++++++++++++++++++ .github/workflows/docker.yml | 9 ++++++++- .github/workflows/{pytest.yml => tests.yml} | 0 3 files changed, 26 insertions(+), 1 deletion(-) create mode 100644 .github/workflows/code-cleanup.yml rename .github/workflows/{pytest.yml => tests.yml} (100%) diff --git a/.github/workflows/code-cleanup.yml b/.github/workflows/code-cleanup.yml new file mode 100644 index 0000000000..c964709e25 --- /dev/null +++ b/.github/workflows/code-cleanup.yml @@ -0,0 +1,18 @@ +name: code-cleanup +on: push + +# permissions: +# contents: read +# packages: write +# pull-requests: read + +jobs: + pre-commit: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v3 + - uses: pre-commit/action@v3.0.1 + - uses: stefanzweifel/git-auto-commit-action@v5 + with: + commit_message: "automated code cleanup" diff --git a/.github/workflows/docker.yml 
b/.github/workflows/docker.yml index d2861ad731..1b6b482e82 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -1,5 +1,12 @@ name: docker-tree -on: push +on: +# - push + - workflow_call: + inputs: + branch-tag: + required: true + type: string + default: "latest" permissions: contents: read diff --git a/.github/workflows/pytest.yml b/.github/workflows/tests.yml similarity index 100% rename from .github/workflows/pytest.yml rename to .github/workflows/tests.yml From 01071cec80409ce4375a48d2fe86ab10c2445f86 Mon Sep 17 00:00:00 2001 From: lesh Date: Wed, 28 May 2025 12:47:26 +0300 Subject: [PATCH 11/42] added pre-commit-config --- .pre-commit-config.yaml | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) create mode 100644 .pre-commit-config.yaml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000000..7a4549ab11 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,25 @@ +default_stages: [pre-commit] +exclude: (dimos/models/.*)|(deprecated) +repos: + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.11.11 + hooks: + #- id: ruff-check + # args: [--fix] + - id: ruff-format + stages: [pre-commit] + + + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.6.0 + hooks: + - id: check-case-conflict + - id: trailing-whitespace + language: python + types: [text] + stages: [pre-push] + - id: check-json + - id: check-toml + - id: check-yaml + - id: pretty-format-json + args: [ --autofix, --no-sort-keys ] From 1d516292944d4bc839b35a2b3ab7c31f59aff2e4 Mon Sep 17 00:00:00 2001 From: lesh Date: Wed, 28 May 2025 12:48:45 +0300 Subject: [PATCH 12/42] forcing CI code cleanup --- .github/workflows/code-cleanup.yml | 2 +- dimos/utils/test_testing.py | 5 +---- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/.github/workflows/code-cleanup.yml b/.github/workflows/code-cleanup.yml index c964709e25..8516de03c4 100644 --- a/.github/workflows/code-cleanup.yml +++ 
b/.github/workflows/code-cleanup.yml @@ -15,4 +15,4 @@ jobs: - uses: pre-commit/action@v3.0.1 - uses: stefanzweifel/git-auto-commit-action@v5 with: - commit_message: "automated code cleanup" + commit_message: "CI code cleanup" diff --git a/dimos/utils/test_testing.py b/dimos/utils/test_testing.py index c1c0c10072..9aea9ade74 100644 --- a/dimos/utils/test_testing.py +++ b/dimos/utils/test_testing.py @@ -100,10 +100,7 @@ def test_pull_dir(): for [file, expected_hash] in zip( sorted(test_dir_decompressed.iterdir()), - [ - "6c3aaa9a79853ea4a7453c7db22820980ceb55035777f7460d05a0fa77b3b1b3", - "456cc2c23f4ffa713b4e0c0d97143c27e48bbe6ef44341197b31ce84b3650e74", - ], + [ "6c3aaa9a79853ea4a7453c7db22820980ceb55035777f7460d05a0fa77b3b1b3", "456cc2c23f4ffa713b4e0c0d97143c27e48bbe6ef44341197b31ce84b3650e74" ] ): with file.open("rb") as f: sha256 = hashlib.sha256(f.read()).hexdigest() From b8f6d92f958c1a0323095d253e083971bf06e5fb Mon Sep 17 00:00:00 2001 From: lesh Date: Wed, 28 May 2025 12:53:37 +0300 Subject: [PATCH 13/42] double pre-commit --- .github/workflows/code-cleanup.yml | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/.github/workflows/code-cleanup.yml b/.github/workflows/code-cleanup.yml index 8516de03c4..7f2580a88e 100644 --- a/.github/workflows/code-cleanup.yml +++ b/.github/workflows/code-cleanup.yml @@ -12,7 +12,18 @@ jobs: steps: - uses: actions/checkout@v3 - uses: actions/setup-python@v3 - - uses: pre-commit/action@v3.0.1 - - uses: stefanzweifel/git-auto-commit-action@v5 + - name: Run pre-commit + id: pre-commit-first + uses: pre-commit/action@v3.0.1 + continue-on-error: true + + - name: Re-run pre-commit if failed initially + id: pre-commit-retry + if: steps.pre-commit-first.outcome == 'failure' + uses: pre-commit/action@v3.0.1 + continue-on-error: false + + - name: Commit code changes + uses: stefanzweifel/git-auto-commit-action@v5 with: commit_message: "CI code cleanup" From c24d11b596b48cabd7af4dcd28acd209397f0fbc Mon Sep 
17 00:00:00 2001 From: lesh Date: Wed, 28 May 2025 12:55:28 +0300 Subject: [PATCH 14/42] permissions added for auto-commits --- .github/workflows/code-cleanup.yml | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/.github/workflows/code-cleanup.yml b/.github/workflows/code-cleanup.yml index 7f2580a88e..5710a69570 100644 --- a/.github/workflows/code-cleanup.yml +++ b/.github/workflows/code-cleanup.yml @@ -1,10 +1,8 @@ name: code-cleanup on: push -# permissions: -# contents: read -# packages: write -# pull-requests: read +permissions: + contents: write jobs: pre-commit: From f35ed8484a5d3fc49ae69311353c1e1e10fa51a0 Mon Sep 17 00:00:00 2001 From: leshy <681516+leshy@users.noreply.github.com> Date: Wed, 28 May 2025 09:56:09 +0000 Subject: [PATCH 15/42] CI code cleanup --- dimos/utils/test_testing.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/dimos/utils/test_testing.py b/dimos/utils/test_testing.py index 9aea9ade74..c1c0c10072 100644 --- a/dimos/utils/test_testing.py +++ b/dimos/utils/test_testing.py @@ -100,7 +100,10 @@ def test_pull_dir(): for [file, expected_hash] in zip( sorted(test_dir_decompressed.iterdir()), - [ "6c3aaa9a79853ea4a7453c7db22820980ceb55035777f7460d05a0fa77b3b1b3", "456cc2c23f4ffa713b4e0c0d97143c27e48bbe6ef44341197b31ce84b3650e74" ] + [ + "6c3aaa9a79853ea4a7453c7db22820980ceb55035777f7460d05a0fa77b3b1b3", + "456cc2c23f4ffa713b4e0c0d97143c27e48bbe6ef44341197b31ce84b3650e74", + ], ): with file.open("rb") as f: sha256 = hashlib.sha256(f.read()).hexdigest() From 1efa924047007674a097915e3340e565e8ffe2d9 Mon Sep 17 00:00:00 2001 From: lesh Date: Wed, 28 May 2025 12:57:16 +0300 Subject: [PATCH 16/42] attempting to trigger docker builds after pre-commit --- .github/workflows/code-cleanup.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/workflows/code-cleanup.yml b/.github/workflows/code-cleanup.yml index 5710a69570..a1403c1d79 100644 --- a/.github/workflows/code-cleanup.yml +++ 
b/.github/workflows/code-cleanup.yml @@ -25,3 +25,7 @@ jobs: uses: stefanzweifel/git-auto-commit-action@v5 with: commit_message: "CI code cleanup" + + run-docker-builds: + needs: [pre-commit] + uses: ./.github/workflows/docker.yml From 437a7c628c05a9f4e00325084514d51f01566736 Mon Sep 17 00:00:00 2001 From: lesh Date: Wed, 28 May 2025 12:58:52 +0300 Subject: [PATCH 17/42] workflow typo --- .github/workflows/docker.yml | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index 1b6b482e82..c7c3538ab8 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -1,12 +1,5 @@ name: docker-tree -on: -# - push - - workflow_call: - inputs: - branch-tag: - required: true - type: string - default: "latest" +on: workflow_call permissions: contents: read From f8c1baef810e50ffcb9d141b975c9d9deb1f2dc7 Mon Sep 17 00:00:00 2001 From: lesh Date: Wed, 28 May 2025 12:59:42 +0300 Subject: [PATCH 18/42] workflow pytest reference fix --- .github/workflows/docker.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index c7c3538ab8..67a4f4987f 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -114,6 +114,6 @@ jobs: (needs.build-dev.result == 'skipped' && needs.check-changes.outputs.tests == 'true')) }} - uses: ./.github/workflows/pytest.yml + uses: ./.github/workflows/tests.yml with: branch-tag: ${{ needs.build-dev.result != 'success' && 'dev' || needs.check-changes.outputs.branch-tag }} From b850fba715f69235b4507c807e04b4aa01c3d4b2 Mon Sep 17 00:00:00 2001 From: lesh Date: Wed, 28 May 2025 13:00:22 +0300 Subject: [PATCH 19/42] code cleanup needs permissions to call docker build --- .github/workflows/code-cleanup.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/code-cleanup.yml b/.github/workflows/code-cleanup.yml index a1403c1d79..ec46f3e4cd 100644 --- 
a/.github/workflows/code-cleanup.yml +++ b/.github/workflows/code-cleanup.yml @@ -3,6 +3,8 @@ on: push permissions: contents: write + packages: write + pull-requests: read jobs: pre-commit: From 759fe4f9ae6724d757d86ec09a7805fbf2c35633 Mon Sep 17 00:00:00 2001 From: lesh Date: Wed, 28 May 2025 13:14:39 +0300 Subject: [PATCH 20/42] pre-commit hooks in dev container --- .devcontainer/devcontainer.json | 2 +- docker/dev/Dockerfile | 1 - docker/dev/dev-requirements.txt | 1 + 3 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index b54b64bd31..860fcd87f2 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -12,7 +12,7 @@ "containerEnv": { "PYTHONPATH": "${localEnv:PYTHONPATH}:/workspaces/dimos" }, - "postCreateCommand": "git config --global --add safe.directory /workspaces/dimos", + "postCreateCommand": "git config --global --add safe.directory /workspaces/dimos && cd /workspaces/dimos && pre-commit install", "settings": { "notebook.formatOnSave.enabled": true, "notebook.codeActionsOnSave": { diff --git a/docker/dev/Dockerfile b/docker/dev/Dockerfile index f5a3810bcf..a210fd4e15 100644 --- a/docker/dev/Dockerfile +++ b/docker/dev/Dockerfile @@ -14,7 +14,6 @@ RUN apt-get install -y \ htop \ python-is-python3 \ iputils-ping \ - pre-commit \ wget # Configure git to trust any directory (resolves dubious ownership issues in containers) diff --git a/docker/dev/dev-requirements.txt b/docker/dev/dev-requirements.txt index 965af4cc9d..9633816cf2 100644 --- a/docker/dev/dev-requirements.txt +++ b/docker/dev/dev-requirements.txt @@ -1,2 +1,3 @@ ruff==0.11.10 mypy==1.15.0 +pre_commit==4.2.0 From 95854ae8d1859db511382b580ae91d0cfebc7409 Mon Sep 17 00:00:00 2001 From: lesh Date: Wed, 28 May 2025 13:46:13 +0300 Subject: [PATCH 21/42] Auto-compress test data: test_file.txt --- .pre-commit-config.yaml | 12 ++++++++++-- tests/data/.lfs/test_file.txt.tar.gz | 2 +- 2 files changed, 
11 insertions(+), 3 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7a4549ab11..b3b885eb78 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -8,8 +8,6 @@ repos: # args: [--fix] - id: ruff-format stages: [pre-commit] - - - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.6.0 hooks: @@ -23,3 +21,13 @@ repos: - id: check-yaml - id: pretty-format-json args: [ --autofix, --no-sort-keys ] + + - repo: local + hooks: + - id: lfs_push + name: lfs_push + always_run: true + pass_filenames: false + verbose: true + entry: bin/lfs_push_all + language: script diff --git a/tests/data/.lfs/test_file.txt.tar.gz b/tests/data/.lfs/test_file.txt.tar.gz index 3f28e33601..774fd6b694 100644 --- a/tests/data/.lfs/test_file.txt.tar.gz +++ b/tests/data/.lfs/test_file.txt.tar.gz @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:6849b15423b67bcb3a4b6e636612e8fe1d4deb2478f27d50c2ae44f318e66cd8 +oid sha256:41091b53cf0bd0bda58709c356666bab1852452ec680bdfac647df99438408ad size 137 From ef8fe73f83323851d3b6b82f198f174513242eb7 Mon Sep 17 00:00:00 2001 From: lesh Date: Wed, 28 May 2025 13:48:20 +0300 Subject: [PATCH 22/42] Auto-compress test data: test_file.txt --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b3b885eb78..d4dcae7784 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -25,7 +25,7 @@ repos: - repo: local hooks: - id: lfs_push - name: lfs_push + name: lfs data compress and push always_run: true pass_filenames: false verbose: true From 9224f745d29c9a7eda21b8a8d164916a1ec53425 Mon Sep 17 00:00:00 2001 From: lesh Date: Wed, 28 May 2025 13:56:07 +0300 Subject: [PATCH 23/42] lfs hook fixes --- .pre-commit-config.yaml | 2 +- bin/lfs_push | 105 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 106 insertions(+), 1 deletion(-) create mode 100755 bin/lfs_push diff --git 
a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d4dcae7784..e11a291598 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -29,5 +29,5 @@ repos: always_run: true pass_filenames: false verbose: true - entry: bin/lfs_push_all + entry: bin/lfs_push language: script diff --git a/bin/lfs_push b/bin/lfs_push new file mode 100755 index 0000000000..d6a11050ca --- /dev/null +++ b/bin/lfs_push @@ -0,0 +1,105 @@ +#!/bin/bash + +# Pre-push hook to compress testing data directories +# Compresses directories in tests/data/* into tests/data/.lfs/dirname.tar.gz + +set -e + +# Colors for output +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +NC='\033[0m' # No Color + +#echo -e "${GREEN}Running test data compression check...${NC}" + +ROOT=$(git rev-parse --show-toplevel) + + +cd $ROOT +# Ensure tests/data/.lfs directory exists +mkdir -p tests/data/.lfs + +# Check if tests/data exists +if [ ! -d "tests/data" ]; then + echo -e "${YELLOW}No tests/data directory found, skipping compression.${NC}" + exit 0 +fi + +# Track if any compression was performed +compressed_dirs=() + +# Iterate through all directories in tests/data +for dir_path in tests/data/*; do + # Skip if no directories found (glob didn't match) + [ ! 
"$dir_path" ] && continue + + # Extract directory name + dir_name=$(basename "$dir_path") + + # Skip .lfs directory if it exists + [ "$dir_name" = ".lfs" ] && continue + + # Define compressed file path + compressed_file="tests/data/.lfs/${dir_name}.tar.gz" + + # Check if compressed file already exists + if [ -f "$compressed_file" ]; then + continue + fi + + echo -e " ${YELLOW}Compressing${NC} $dir_path -> $compressed_file" + + # Show directory size before compression + dir_size=$(du -sh "$dir_path" | cut -f1) + echo -e " Data size: ${YELLOW}$dir_size${NC}" + + # Create compressed archive with progress bar + # Use tar with gzip compression, excluding hidden files and common temp files + tar -czf "$compressed_file" \ + --exclude='*.tmp' \ + --exclude='*.temp' \ + --exclude='.DS_Store' \ + --exclude='Thumbs.db' \ + --checkpoint=1000 \ + --checkpoint-action=dot \ + -C "tests/data" \ + "$dir_name" + + if [ $? -eq 0 ]; then + # Show compressed file size + compressed_size=$(du -sh "$compressed_file" | cut -f1) + echo -e " ${GREEN}โœ“${NC} Successfully compressed $dir_name (${GREEN}$dir_size${NC} โ†’ ${GREEN}$compressed_size${NC})" + compressed_dirs+=("$dir_name") + + # Add the compressed file to git LFS tracking + + git add "$compressed_file" + + echo -e "${GREEN}โœ“${NC} git-add $compressed_file" + + else + echo -e " ${RED}โœ—${NC} Failed to compress $dir_name" + exit 1 + fi +done + +if [ ${#compressed_dirs[@]} -gt 0 ]; then + # Create commit message with compressed directory names + if [ ${#compressed_dirs[@]} -eq 1 ]; then + commit_msg="Auto-compress test data: ${compressed_dirs[0]}" + else + # Join array elements with commas + dirs_list=$(IFS=', '; echo "${compressed_dirs[*]}") + commit_msg="Auto-compress test data: ${dirs_list}" + fi + + #git commit -m "$commit_msg" + echo -e "${GREEN}โœ“${NC} Compressed files committed automatically, uploading..." 
+ git lfs push origin $(git branch --show-current) + echo -e "${GREEN}โœ“${NC} Uploaded to LFS" + echo -e "${GREEN}โœ“${NC} New commit has been added" +else + echo -e "${GREEN}โœ“${NC} No test data to compress" +fi + From 97319cf92d445463c30c631e7c57ce5793c7b469 Mon Sep 17 00:00:00 2001 From: lesh Date: Wed, 28 May 2025 13:58:32 +0300 Subject: [PATCH 24/42] lfs hook fixes 2 --- bin/lfs_push | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/bin/lfs_push b/bin/lfs_push index d6a11050ca..85a35275c0 100755 --- a/bin/lfs_push +++ b/bin/lfs_push @@ -76,7 +76,7 @@ for dir_path in tests/data/*; do git add "$compressed_file" - echo -e "${GREEN}โœ“${NC} git-add $compressed_file" + echo -e " ${GREEN}โœ“${NC} git-add $compressed_file" else echo -e " ${RED}โœ—${NC} Failed to compress $dir_name" @@ -95,10 +95,9 @@ if [ ${#compressed_dirs[@]} -gt 0 ]; then fi #git commit -m "$commit_msg" - echo -e "${GREEN}โœ“${NC} Compressed files committed automatically, uploading..." + echo -e "${GREEN}โœ“${NC} Compressed file references added. Uploading..." 
git lfs push origin $(git branch --show-current) echo -e "${GREEN}โœ“${NC} Uploaded to LFS" - echo -e "${GREEN}โœ“${NC} New commit has been added" else echo -e "${GREEN}โœ“${NC} No test data to compress" fi From 20866bdaef1d4cfbcf78cf94409e796e3c79ef26 Mon Sep 17 00:00:00 2001 From: lesh Date: Wed, 28 May 2025 13:58:59 +0300 Subject: [PATCH 25/42] triggering rebuild 2 --- dimos/utils/test_testing.py | 2 ++ tests/data/.lfs/test_file2.txt.tar.gz | 3 +++ 2 files changed, 5 insertions(+) create mode 100644 tests/data/.lfs/test_file2.txt.tar.gz diff --git a/dimos/utils/test_testing.py b/dimos/utils/test_testing.py index c1c0c10072..d729866d3c 100644 --- a/dimos/utils/test_testing.py +++ b/dimos/utils/test_testing.py @@ -3,6 +3,8 @@ import subprocess from dimos.utils import testing +# trigger rebuild1 + def test_pull_file(): repo_root = testing._get_repo_root() diff --git a/tests/data/.lfs/test_file2.txt.tar.gz b/tests/data/.lfs/test_file2.txt.tar.gz new file mode 100644 index 0000000000..56c614ebd1 --- /dev/null +++ b/tests/data/.lfs/test_file2.txt.tar.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:57604ebe6c01c7bcd180f5d87eaff64a275f1e1b1955b44b3d500c23c93d3b7e +size 137 From 5d6a960ec0383616791c68ce24797ecfdf24387d Mon Sep 17 00:00:00 2001 From: lesh Date: Wed, 28 May 2025 14:00:22 +0300 Subject: [PATCH 26/42] final cleanup of the lfs script --- bin/lfs_push | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/bin/lfs_push b/bin/lfs_push index 85a35275c0..e4c70bd633 100755 --- a/bin/lfs_push +++ b/bin/lfs_push @@ -1,7 +1,6 @@ #!/bin/bash - -# Pre-push hook to compress testing data directories # Compresses directories in tests/data/* into tests/data/.lfs/dirname.tar.gz +# Pushes to LFS set -e @@ -14,9 +13,8 @@ NC='\033[0m' # No Color #echo -e "${GREEN}Running test data compression check...${NC}" ROOT=$(git rev-parse --show-toplevel) - - cd $ROOT + # Ensure tests/data/.lfs directory exists mkdir -p tests/data/.lfs @@ -73,7 
+71,6 @@ for dir_path in tests/data/*; do compressed_dirs+=("$dir_name") # Add the compressed file to git LFS tracking - git add "$compressed_file" echo -e " ${GREEN}โœ“${NC} git-add $compressed_file" From 1b4f155a79af57608a2c4fb201393c89c3bbbb40 Mon Sep 17 00:00:00 2001 From: lesh Date: Wed, 28 May 2025 14:00:57 +0300 Subject: [PATCH 27/42] removed temp test files --- tests/data/.lfs/test_file.txt.tar.gz | 3 --- tests/data/.lfs/test_file2.txt.tar.gz | 3 --- 2 files changed, 6 deletions(-) delete mode 100644 tests/data/.lfs/test_file.txt.tar.gz delete mode 100644 tests/data/.lfs/test_file2.txt.tar.gz diff --git a/tests/data/.lfs/test_file.txt.tar.gz b/tests/data/.lfs/test_file.txt.tar.gz deleted file mode 100644 index 774fd6b694..0000000000 --- a/tests/data/.lfs/test_file.txt.tar.gz +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:41091b53cf0bd0bda58709c356666bab1852452ec680bdfac647df99438408ad -size 137 diff --git a/tests/data/.lfs/test_file2.txt.tar.gz b/tests/data/.lfs/test_file2.txt.tar.gz deleted file mode 100644 index 56c614ebd1..0000000000 --- a/tests/data/.lfs/test_file2.txt.tar.gz +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:57604ebe6c01c7bcd180f5d87eaff64a275f1e1b1955b44b3d500c23c93d3b7e -size 137 From 226eaaf8adaa78d31be78e5fdc27f69958eff7c3 Mon Sep 17 00:00:00 2001 From: leshy <681516+leshy@users.noreply.github.com> Date: Wed, 28 May 2025 11:02:42 +0000 Subject: [PATCH 28/42] CI code cleanup --- tests/data/.lfs/*.tar.gz | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 tests/data/.lfs/*.tar.gz diff --git a/tests/data/.lfs/*.tar.gz b/tests/data/.lfs/*.tar.gz new file mode 100644 index 0000000000..015ffb632b --- /dev/null +++ b/tests/data/.lfs/*.tar.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:85cea451eec057fa7e734548ca3ba6d779ed5836a3f9de14b8394575ef0d7d8e +size 45 From 64824ac801f70ac6f0f217e9c4e424b19764f220 Mon Sep 17 00:00:00 2001 
From: lesh Date: Wed, 28 May 2025 14:20:43 +0300 Subject: [PATCH 29/42] cleanup --- .pre-commit-config.yaml | 3 ++- dimos/utils/test_testing.py | 2 -- tests/data/.lfs/*.tar.gz | 3 --- 3 files changed, 2 insertions(+), 6 deletions(-) delete mode 100644 tests/data/.lfs/*.tar.gz diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e11a291598..c735d3382e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -20,12 +20,13 @@ repos: - id: check-toml - id: check-yaml - id: pretty-format-json + name: format json args: [ --autofix, --no-sort-keys ] - repo: local hooks: - id: lfs_push - name: lfs data compress and push + name: new LFS files compress and push always_run: true pass_filenames: false verbose: true diff --git a/dimos/utils/test_testing.py b/dimos/utils/test_testing.py index d729866d3c..c1c0c10072 100644 --- a/dimos/utils/test_testing.py +++ b/dimos/utils/test_testing.py @@ -3,8 +3,6 @@ import subprocess from dimos.utils import testing -# trigger rebuild1 - def test_pull_file(): repo_root = testing._get_repo_root() diff --git a/tests/data/.lfs/*.tar.gz b/tests/data/.lfs/*.tar.gz deleted file mode 100644 index 015ffb632b..0000000000 --- a/tests/data/.lfs/*.tar.gz +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:85cea451eec057fa7e734548ca3ba6d779ed5836a3f9de14b8394575ef0d7d8e -size 45 From adfaed3abb0a79526c9a313c72436c0a11fbbf84 Mon Sep 17 00:00:00 2001 From: leshy <681516+leshy@users.noreply.github.com> Date: Wed, 28 May 2025 11:21:31 +0000 Subject: [PATCH 30/42] CI code cleanup --- tests/data/.lfs/*.tar.gz | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 tests/data/.lfs/*.tar.gz diff --git a/tests/data/.lfs/*.tar.gz b/tests/data/.lfs/*.tar.gz new file mode 100644 index 0000000000..015ffb632b --- /dev/null +++ b/tests/data/.lfs/*.tar.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:85cea451eec057fa7e734548ca3ba6d779ed5836a3f9de14b8394575ef0d7d8e +size 45 From 
ec8e3ea2dbbd86b0597adf3e759c5ae0add4569c Mon Sep 17 00:00:00 2001 From: lesh Date: Wed, 28 May 2025 14:34:07 +0300 Subject: [PATCH 31/42] pre-commit doesn't push to LFS, it just checks --- .pre-commit-config.yaml | 7 +++---- bin/lfs_check | 40 ++++++++++++++++++++++++++++++++++++++++ bin/lfs_push | 3 --- tests/data/.lfs/*.tar.gz | 3 --- 4 files changed, 43 insertions(+), 10 deletions(-) create mode 100755 bin/lfs_check delete mode 100644 tests/data/.lfs/*.tar.gz diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c735d3382e..46161bb5e6 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -25,10 +25,9 @@ repos: - repo: local hooks: - - id: lfs_push - name: new LFS files compress and push + - id: lfs_check + name: LFS data always_run: true pass_filenames: false - verbose: true - entry: bin/lfs_push + entry: bin/lfs_check language: script diff --git a/bin/lfs_check b/bin/lfs_check new file mode 100755 index 0000000000..c69d9e5838 --- /dev/null +++ b/bin/lfs_check @@ -0,0 +1,40 @@ +#!/bin/bash + +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +NC='\033[0m' + +ROOT=$(git rev-parse --show-toplevel) +cd $ROOT + +new_data=() + +# Iterate through all directories in tests/data +for dir_path in tests/data/*; do + # Skip if no directories found (glob didn't match) + [ ! 
"$dir_path" ] && continue + + # Extract directory name + dir_name=$(basename "$dir_path") + + # Skip .lfs directory if it exists + [ "$dir_name" = ".lfs" ] && continue + + # Define compressed file path + compressed_file="tests/data/.lfs/${dir_name}.tar.gz" + + # Check if compressed file already exists + if [ -f "$compressed_file" ]; then + continue + fi + + new_data+=("$dir_name") +done + +if [ ${#new_data[@]} -gt 0 ]; then + echo -e "${RED}โœ—${NC} New test data detected:" + echo -e " ${GREEN}${new_data[@]}${NC}" + echo -e "\neither delete or run ./bin/lfs_push to add the data to LFS and your commit" + exit 1 +fi diff --git a/bin/lfs_push b/bin/lfs_push index e4c70bd633..47f50b1354 100755 --- a/bin/lfs_push +++ b/bin/lfs_push @@ -15,9 +15,6 @@ NC='\033[0m' # No Color ROOT=$(git rev-parse --show-toplevel) cd $ROOT -# Ensure tests/data/.lfs directory exists -mkdir -p tests/data/.lfs - # Check if tests/data exists if [ ! -d "tests/data" ]; then echo -e "${YELLOW}No tests/data directory found, skipping compression.${NC}" diff --git a/tests/data/.lfs/*.tar.gz b/tests/data/.lfs/*.tar.gz deleted file mode 100644 index 015ffb632b..0000000000 --- a/tests/data/.lfs/*.tar.gz +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:85cea451eec057fa7e734548ca3ba6d779ed5836a3f9de14b8394575ef0d7d8e -size 45 From 4a7aed7aa2b3dbf6186213afd03fa7e28d419c57 Mon Sep 17 00:00:00 2001 From: lesh Date: Wed, 28 May 2025 14:36:47 +0300 Subject: [PATCH 32/42] null glob fix --- bin/lfs_check | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/bin/lfs_check b/bin/lfs_check index c69d9e5838..f0cbdcc695 100755 --- a/bin/lfs_check +++ b/bin/lfs_check @@ -10,10 +10,11 @@ cd $ROOT new_data=() +# Enable nullglob to make globs expand to nothing when not matching +shopt -s nullglob + # Iterate through all directories in tests/data for dir_path in tests/data/*; do - # Skip if no directories found (glob didn't match) - [ ! 
"$dir_path" ] && continue # Extract directory name dir_name=$(basename "$dir_path") From bf887497a324568e2dc0db7feb2b038fb4653fd3 Mon Sep 17 00:00:00 2001 From: lesh Date: Wed, 28 May 2025 14:38:59 +0300 Subject: [PATCH 33/42] slightly nicer lfs check output --- bin/lfs_check | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bin/lfs_check b/bin/lfs_check index f0cbdcc695..ce0dada6c1 100755 --- a/bin/lfs_check +++ b/bin/lfs_check @@ -36,6 +36,6 @@ done if [ ${#new_data[@]} -gt 0 ]; then echo -e "${RED}โœ—${NC} New test data detected:" echo -e " ${GREEN}${new_data[@]}${NC}" - echo -e "\neither delete or run ./bin/lfs_push to add the data to LFS and your commit" + echo -e "\neither delete or run ${GREEN}./bin/lfs_push${NC} to add the data to LFS and your commit" exit 1 fi From 2d8049cf04de4956479f3ffa4d673780900fe3e7 Mon Sep 17 00:00:00 2001 From: lesh Date: Wed, 28 May 2025 14:44:09 +0300 Subject: [PATCH 34/42] small workflow naming fixes --- .github/workflows/{code-cleanup.yml => cleanup.yml} | 4 ++-- .github/workflows/docker.yml | 2 +- .github/workflows/tests.yml | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) rename .github/workflows/{code-cleanup.yml => cleanup.yml} (94%) diff --git a/.github/workflows/code-cleanup.yml b/.github/workflows/cleanup.yml similarity index 94% rename from .github/workflows/code-cleanup.yml rename to .github/workflows/cleanup.yml index ec46f3e4cd..ee1ec788e2 100644 --- a/.github/workflows/code-cleanup.yml +++ b/.github/workflows/cleanup.yml @@ -1,4 +1,4 @@ -name: code-cleanup +name: cleanup on: push permissions: @@ -28,6 +28,6 @@ jobs: with: commit_message: "CI code cleanup" - run-docker-builds: + docker: needs: [pre-commit] uses: ./.github/workflows/docker.yml diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index 67a4f4987f..0ace5cbd81 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -1,4 +1,4 @@ -name: docker-tree +name: docker on: workflow_call 
permissions: diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index ae1be9b751..7efc7bad01 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -1,4 +1,4 @@ -name: testing +name: tests on: workflow_call: From 2de8581aba158b1dd4ec6f49b46521b24e71e49e Mon Sep 17 00:00:00 2001 From: lesh Date: Wed, 28 May 2025 14:52:35 +0300 Subject: [PATCH 35/42] better lfs_check output --- bin/lfs_check | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/bin/lfs_check b/bin/lfs_check index ce0dada6c1..09d09bd219 100755 --- a/bin/lfs_check +++ b/bin/lfs_check @@ -34,8 +34,9 @@ for dir_path in tests/data/*; do done if [ ${#new_data[@]} -gt 0 ]; then - echo -e "${RED}โœ—${NC} New test data detected:" + echo -e "${RED}โœ—${NC} New test data detected at /tests/data:" echo -e " ${GREEN}${new_data[@]}${NC}" - echo -e "\neither delete or run ${GREEN}./bin/lfs_push${NC} to add the data to LFS and your commit" + echo -e "\nEither delete or run ${GREEN}./bin/lfs_push${NC}" + echo -e "(lfs_push will compress the files into /tests/data/.lfs/, upload to LFS, and add them to your commit)" exit 1 fi From 022a5bf90836f15fcd3b6f5c8e69920cdc30f6aa Mon Sep 17 00:00:00 2001 From: lesh Date: Wed, 28 May 2025 15:27:39 +0300 Subject: [PATCH 36/42] renaming actions for better UI view --- .github/workflows/docker.yml | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index 0ace5cbd81..a5500530e3 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -56,7 +56,7 @@ jobs: echo "branch tag determined: ${branch_tag}" echo branch_tag="${branch_tag}" >> "$GITHUB_OUTPUT" - build-ros: + ros: needs: [check-changes] if: needs.check-changes.outputs.ros == 'true' uses: ./.github/workflows/_docker-build-template.yml @@ -66,21 +66,21 @@ jobs: # just a debugger inspect-needs: - needs: [check-changes, build-ros] + needs: 
[check-changes, ros] runs-on: dimos-runner-ubuntu-2204 if: always() steps: - run: | echo '${{ toJSON(needs) }}' - build-python: - needs: [check-changes, build-ros] + python: + needs: [check-changes, ros] if: | ${{ always() && !cancelled() && needs.check-changes.result == 'success' && - ((needs.build-ros.result == 'success') || - (needs.build-ros.result == 'skipped' && + ((needs.ros.result == 'success') || + (needs.ros.result == 'skipped' && needs.check-changes.outputs.python == 'true')) }} uses: ./.github/workflows/_docker-build-template.yml @@ -89,14 +89,14 @@ jobs: target: base-ros-python freespace: true - build-dev: - needs: [check-changes, build-python] + dev: + needs: [check-changes, python] if: | ${{ always() && !cancelled() && needs.check-changes.result == 'success' && - ((needs.build-python.result == 'success') || - (needs.build-python.result == 'skipped' && + ((needs.python.result == 'success') || + (needs.python.result == 'skipped' && needs.check-changes.outputs.dev == 'true')) }} uses: ./.github/workflows/_docker-build-template.yml @@ -105,15 +105,15 @@ jobs: target: dev run-tests: - needs: [check-changes, build-dev] + needs: [check-changes, dev] if: | ${{ always() && !cancelled() && needs.check-changes.result == 'success' && - ((needs.build-dev.result == 'success') || - (needs.build-dev.result == 'skipped' && + ((needs.dev.result == 'success') || + (needs.dev.result == 'skipped' && needs.check-changes.outputs.tests == 'true')) }} uses: ./.github/workflows/tests.yml with: - branch-tag: ${{ needs.build-dev.result != 'success' && 'dev' || needs.check-changes.outputs.branch-tag }} + branch-tag: ${{ needs.dev.result != 'success' && 'dev' || needs.check-changes.outputs.branch-tag }} From 0592c67391d35e7212de3a4a0436b6eafeac66e9 Mon Sep 17 00:00:00 2001 From: lesh Date: Wed, 28 May 2025 15:29:49 +0300 Subject: [PATCH 37/42] even shorter naming --- .github/workflows/cleanup.yml | 2 +- .github/workflows/docker.yml | 2 +- 2 files changed, 2 insertions(+), 2 
deletions(-) diff --git a/.github/workflows/cleanup.yml b/.github/workflows/cleanup.yml index ee1ec788e2..101daefebb 100644 --- a/.github/workflows/cleanup.yml +++ b/.github/workflows/cleanup.yml @@ -28,6 +28,6 @@ jobs: with: commit_message: "CI code cleanup" - docker: + dkr: needs: [pre-commit] uses: ./.github/workflows/docker.yml diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index a5500530e3..b2f8eeae8e 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -1,4 +1,4 @@ -name: docker +name: dkr on: workflow_call permissions: From adeb3f732fd02ddd870d03a2a0b4e307cd87f5a2 Mon Sep 17 00:00:00 2001 From: lesh Date: Wed, 28 May 2025 15:31:44 +0300 Subject: [PATCH 38/42] checking explicit action naming --- .github/workflows/cleanup.yml | 2 +- .github/workflows/docker.yml | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/workflows/cleanup.yml b/.github/workflows/cleanup.yml index 101daefebb..ee1ec788e2 100644 --- a/.github/workflows/cleanup.yml +++ b/.github/workflows/cleanup.yml @@ -28,6 +28,6 @@ jobs: with: commit_message: "CI code cleanup" - dkr: + docker: needs: [pre-commit] uses: ./.github/workflows/docker.yml diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index b2f8eeae8e..4217f0f974 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -1,4 +1,4 @@ -name: dkr +name: docker on: workflow_call permissions: @@ -57,6 +57,7 @@ jobs: echo branch_tag="${branch_tag}" >> "$GITHUB_OUTPUT" ros: + name: docker-ros needs: [check-changes] if: needs.check-changes.outputs.ros == 'true' uses: ./.github/workflows/_docker-build-template.yml From 1bcb9bd4f91091522f9d0a5d49314f22cb3d900a Mon Sep 17 00:00:00 2001 From: lesh Date: Wed, 28 May 2025 15:36:02 +0300 Subject: [PATCH 39/42] decoupling workflows --- .github/workflows/cleanup.yml | 6 ------ .github/workflows/docker.yml | 21 +++++---------------- .github/workflows/tests.yml | 10 ++++------ 3 files 
changed, 9 insertions(+), 28 deletions(-) diff --git a/.github/workflows/cleanup.yml b/.github/workflows/cleanup.yml index ee1ec788e2..2eff2ffd7e 100644 --- a/.github/workflows/cleanup.yml +++ b/.github/workflows/cleanup.yml @@ -3,8 +3,6 @@ on: push permissions: contents: write - packages: write - pull-requests: read jobs: pre-commit: @@ -27,7 +25,3 @@ jobs: uses: stefanzweifel/git-auto-commit-action@v5 with: commit_message: "CI code cleanup" - - docker: - needs: [pre-commit] - uses: ./.github/workflows/docker.yml diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index 4217f0f974..089b2ed36a 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -1,5 +1,9 @@ name: docker -on: workflow_call +on: + workflow_run: + workflows: ["cleanup"] + types: + - completed permissions: contents: read @@ -57,7 +61,6 @@ jobs: echo branch_tag="${branch_tag}" >> "$GITHUB_OUTPUT" ros: - name: docker-ros needs: [check-changes] if: needs.check-changes.outputs.ros == 'true' uses: ./.github/workflows/_docker-build-template.yml @@ -104,17 +107,3 @@ jobs: with: branch-tag: ${{ needs.check-changes.outputs.branch-tag }} target: dev - - run-tests: - needs: [check-changes, dev] - if: | - ${{ - always() && !cancelled() && - needs.check-changes.result == 'success' && - ((needs.dev.result == 'success') || - (needs.dev.result == 'skipped' && - needs.check-changes.outputs.tests == 'true')) - }} - uses: ./.github/workflows/tests.yml - with: - branch-tag: ${{ needs.dev.result != 'success' && 'dev' || needs.check-changes.outputs.branch-tag }} diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 7efc7bad01..325c88903e 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -1,12 +1,10 @@ name: tests on: - workflow_call: - inputs: - branch-tag: - required: true - type: string - default: "latest" + workflow_run: + workflows: ["docker"] + types: + - completed permissions: contents: read From 
0c23a15e534033feafa666335d9ed276280dc4ed Mon Sep 17 00:00:00 2001 From: lesh Date: Wed, 28 May 2025 15:39:01 +0300 Subject: [PATCH 40/42] re-coupling workflows --- .github/workflows/cleanup.yml | 6 ++++++ .github/workflows/docker.yml | 20 +++++++++++++++----- .github/workflows/tests.yml | 10 ++++++---- 3 files changed, 27 insertions(+), 9 deletions(-) diff --git a/.github/workflows/cleanup.yml b/.github/workflows/cleanup.yml index 2eff2ffd7e..ee1ec788e2 100644 --- a/.github/workflows/cleanup.yml +++ b/.github/workflows/cleanup.yml @@ -3,6 +3,8 @@ on: push permissions: contents: write + packages: write + pull-requests: read jobs: pre-commit: @@ -25,3 +27,7 @@ jobs: uses: stefanzweifel/git-auto-commit-action@v5 with: commit_message: "CI code cleanup" + + docker: + needs: [pre-commit] + uses: ./.github/workflows/docker.yml diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index 089b2ed36a..a5500530e3 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -1,9 +1,5 @@ name: docker -on: - workflow_run: - workflows: ["cleanup"] - types: - - completed +on: workflow_call permissions: contents: read @@ -107,3 +103,17 @@ jobs: with: branch-tag: ${{ needs.check-changes.outputs.branch-tag }} target: dev + + run-tests: + needs: [check-changes, dev] + if: | + ${{ + always() && !cancelled() && + needs.check-changes.result == 'success' && + ((needs.dev.result == 'success') || + (needs.dev.result == 'skipped' && + needs.check-changes.outputs.tests == 'true')) + }} + uses: ./.github/workflows/tests.yml + with: + branch-tag: ${{ needs.dev.result != 'success' && 'dev' || needs.check-changes.outputs.branch-tag }} diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 325c88903e..7efc7bad01 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -1,10 +1,12 @@ name: tests on: - workflow_run: - workflows: ["docker"] - types: - - completed + workflow_call: + inputs: + branch-tag: + required: true + 
type: string + default: "latest" permissions: contents: read From 15aa2d98a3cae95dd4a24510e856478073411b80 Mon Sep 17 00:00:00 2001 From: lesh Date: Wed, 28 May 2025 18:41:11 +0300 Subject: [PATCH 41/42] sensor replay test --- dimos/robot/unitree_webrtc/type/costmap.py | 44 ++++-- dimos/robot/unitree_webrtc/type/test_map.py | 23 ++++ dimos/robot/unitree_webrtc/type/vector.py | 140 -------------------- dimos/types/vector.py | 2 - dimos/utils/test_testing.py | 19 +++ dimos/utils/testing.py | 84 +++++++++++- 6 files changed, 156 insertions(+), 156 deletions(-) diff --git a/dimos/robot/unitree_webrtc/type/costmap.py b/dimos/robot/unitree_webrtc/type/costmap.py index 814184479e..ba8dba55e3 100644 --- a/dimos/robot/unitree_webrtc/type/costmap.py +++ b/dimos/robot/unitree_webrtc/type/costmap.py @@ -231,6 +231,34 @@ def smudge( origin=self.origin, ) + @property + def total_cells(self) -> int: + return self.width * self.height + + @property + def occupied_cells(self) -> int: + return np.sum(self.grid >= 0.1) + + @property + def unknown_cells(self) -> int: + return np.sum(self.grid == -1) + + @property + def free_cells(self) -> int: + return self.total_cells - self.occupied_cells - self.unknown_cells + + @property + def free_percent(self) -> float: + return (self.free_cells / self.total_cells) * 100 if self.total_cells > 0 else 0.0 + + @property + def occupied_percent(self) -> float: + return (self.occupied_cells / self.total_cells) * 100 if self.total_cells > 0 else 0.0 + + @property + def unknown_percent(self) -> float: + return (self.unknown_cells / self.total_cells) * 100 if self.total_cells > 0 else 0.0 + def __str__(self) -> str: """ Create a string representation of the Costmap. 
@@ -238,16 +266,6 @@ def __str__(self) -> str: Returns: A formatted string with key costmap information """ - # Calculate occupancy statistics - total_cells = self.width * self.height - occupied_cells = np.sum(self.grid >= 0.1) - unknown_cells = np.sum(self.grid == -1) - free_cells = total_cells - occupied_cells - unknown_cells - - # Calculate percentages - occupied_percent = (occupied_cells / total_cells) * 100 - unknown_percent = (unknown_cells / total_cells) * 100 - free_percent = (free_cells / total_cells) * 100 cell_info = [ "โ–ฆ Costmap", @@ -255,9 +273,9 @@ def __str__(self) -> str: f"({self.width * self.resolution:.1f}x{self.height * self.resolution:.1f}m @", f"{1 / self.resolution:.0f}cm res)", f"Origin: ({x(self.origin):.2f}, {y(self.origin):.2f})", - f"โ–ฃ {occupied_percent:.1f}%", - f"โ–ก {free_percent:.1f}%", - f"โ—Œ {unknown_percent:.1f}%", + f"โ–ฃ {self.occupied_percent:.1f}%", + f"โ–ก {self.free_percent:.1f}%", + f"โ—Œ {self.unknown_percent:.1f}%", ] return " ".join(cell_info) diff --git a/dimos/robot/unitree_webrtc/type/test_map.py b/dimos/robot/unitree_webrtc/type/test_map.py index 25deccda00..180d473eb7 100644 --- a/dimos/robot/unitree_webrtc/type/test_map.py +++ b/dimos/robot/unitree_webrtc/type/test_map.py @@ -1,8 +1,10 @@ import pytest from dimos.robot.unitree_webrtc.testing.mock import Mock from dimos.robot.unitree_webrtc.testing.helpers import show3d_stream, show3d +from dimos.robot.unitree_webrtc.type.lidar import LidarMessage from dimos.utils.reactive import backpressure from dimos.robot.unitree_webrtc.type.map import splice_sphere, Map +from dimos.utils.testing import SensorReplay @pytest.mark.vis @@ -36,3 +38,24 @@ def test_robot_vis(): clearframe=True, title="gloal dynamic map test", ) + + +def test_robot_mapping(): + lidar_stream = SensorReplay("office_lidar", autocast=lambda x: LidarMessage.from_msg(x)) + map = Map(voxel_size=0.5) + map.consume(lidar_stream.stream(rate_hz=100.0)).subscribe(lambda x: ...) 
+ + costmap = map.costmap + + shape = costmap.grid.shape + assert shape[0] > 150 + assert shape[1] > 150 + + assert costmap.unknown_percent > 80 + assert costmap.unknown_percent < 90 + + assert costmap.free_percent > 5 + assert costmap.free_percent < 10 + + assert costmap.occupied_percent > 8 + assert costmap.occupied_percent < 15 diff --git a/dimos/robot/unitree_webrtc/type/vector.py b/dimos/robot/unitree_webrtc/type/vector.py index e5fb446884..0343a71118 100644 --- a/dimos/robot/unitree_webrtc/type/vector.py +++ b/dimos/robot/unitree_webrtc/type/vector.py @@ -116,7 +116,6 @@ def __add__(self: T, other: Union["Vector", Iterable[float]]) -> T: def __sub__(self: T, other: Union["Vector", Iterable[float]]) -> T: if isinstance(other, Vector): - print(self, other) return self.__class__(self._data - other._data) return self.__class__(self._data - np.array(other, dtype=float)) @@ -433,142 +432,3 @@ def z(value: VectorLike) -> float: else: arr = to_numpy(value) return float(arr[2]) if len(arr) > 2 else 0.0 - - -if __name__ == "__main__": - # Test vectors in various directions - test_vectors = [ - Vector(1, 0), # Right - Vector(1, 1), # Up-Right - Vector(0, 1), # Up - Vector(-1, 1), # Up-Left - Vector(-1, 0), # Left - Vector(-1, -1), # Down-Left - Vector(0, -1), # Down - Vector(1, -1), # Down-Right - Vector(0.5, 0.5), # Up-Right (shorter) - Vector(-3, 4), # Up-Left (longer) - ] - - for v in test_vectors: - print(str(v)) - - # Test the vector compatibility functions - print("Testing vectortypes.py conversion functions\n") - - # Create test vectors in different formats - vector_obj = Vector(1.0, 2.0, 3.0) - numpy_arr = np.array([4.0, 5.0, 6.0]) - tuple_vec = (7.0, 8.0, 9.0) - list_vec = [10.0, 11.0, 12.0] - - print("Original values:") - print(f"Vector: {vector_obj}") - print(f"NumPy: {numpy_arr}") - print(f"Tuple: {tuple_vec}") - print(f"List: {list_vec}") - print() - - # Test to_numpy - print("to_numpy() conversions:") - print(f"Vector โ†’ NumPy: {to_numpy(vector_obj)}") - 
print(f"NumPy โ†’ NumPy: {to_numpy(numpy_arr)}") - print(f"Tuple โ†’ NumPy: {to_numpy(tuple_vec)}") - print(f"List โ†’ NumPy: {to_numpy(list_vec)}") - print() - - # Test to_vector - print("to_vector() conversions:") - print(f"Vector โ†’ Vector: {to_vector(vector_obj)}") - print(f"NumPy โ†’ Vector: {to_vector(numpy_arr)}") - print(f"Tuple โ†’ Vector: {to_vector(tuple_vec)}") - print(f"List โ†’ Vector: {to_vector(list_vec)}") - print() - - # Test to_tuple - print("to_tuple() conversions:") - print(f"Vector โ†’ Tuple: {to_tuple(vector_obj)}") - print(f"NumPy โ†’ Tuple: {to_tuple(numpy_arr)}") - print(f"Tuple โ†’ Tuple: {to_tuple(tuple_vec)}") - print(f"List โ†’ Tuple: {to_tuple(list_vec)}") - print() - - # Test to_list - print("to_list() conversions:") - print(f"Vector โ†’ List: {to_list(vector_obj)}") - print(f"NumPy โ†’ List: {to_list(numpy_arr)}") - print(f"Tuple โ†’ List: {to_list(tuple_vec)}") - print(f"List โ†’ List: {to_list(list_vec)}") - print() - - # Test component extraction - print("Component extraction:") - print("x() function:") - print(f"x(Vector): {x(vector_obj)}") - print(f"x(NumPy): {x(numpy_arr)}") - print(f"x(Tuple): {x(tuple_vec)}") - print(f"x(List): {x(list_vec)}") - print() - - print("y() function:") - print(f"y(Vector): {y(vector_obj)}") - print(f"y(NumPy): {y(numpy_arr)}") - print(f"y(Tuple): {y(tuple_vec)}") - print(f"y(List): {y(list_vec)}") - print() - - print("z() function:") - print(f"z(Vector): {z(vector_obj)}") - print(f"z(NumPy): {z(numpy_arr)}") - print(f"z(Tuple): {z(tuple_vec)}") - print(f"z(List): {z(list_vec)}") - print() - - # Test dimension checking - print("Dimension checking:") - vec2d = Vector(1.0, 2.0) - vec3d = Vector(1.0, 2.0, 3.0) - arr2d = np.array([1.0, 2.0]) - arr3d = np.array([1.0, 2.0, 3.0]) - - print(f"is_2d(Vector(1,2)): {is_2d(vec2d)}") - print(f"is_2d(Vector(1,2,3)): {is_2d(vec3d)}") - print(f"is_2d(np.array([1,2])): {is_2d(arr2d)}") - print(f"is_2d(np.array([1,2,3])): {is_2d(arr3d)}") - print(f"is_2d((1,2)): 
{is_2d((1.0, 2.0))}") - print(f"is_2d((1,2,3)): {is_2d((1.0, 2.0, 3.0))}") - print() - - print(f"is_3d(Vector(1,2)): {is_3d(vec2d)}") - print(f"is_3d(Vector(1,2,3)): {is_3d(vec3d)}") - print(f"is_3d(np.array([1,2])): {is_3d(arr2d)}") - print(f"is_3d(np.array([1,2,3])): {is_3d(arr3d)}") - print(f"is_3d((1,2)): {is_3d((1.0, 2.0))}") - print(f"is_3d((1,2,3)): {is_3d((1.0, 2.0, 3.0))}") - print() - - # Test the Protocol interface - print("Testing VectorLike Protocol:") - print(f"isinstance(Vector(1,2), VectorLike): {isinstance(vec2d, VectorLike)}") - print(f"isinstance(np.array([1,2]), VectorLike): {isinstance(arr2d, VectorLike)}") - print(f"isinstance((1,2), VectorLike): {isinstance((1.0, 2.0), VectorLike)}") - print(f"isinstance([1,2], VectorLike): {isinstance([1.0, 2.0], VectorLike)}") - print() - - # Test mixed operations using different vector types - # These functions aren't defined in vectortypes, but demonstrate the concept - def distance(a: VectorLike, b: VectorLike) -> float: - a_np = to_numpy(a) - b_np = to_numpy(b) - diff = a_np - b_np - return float(np.sqrt(np.sum(diff * diff))) - - def midpoint(a: VectorLike, b: VectorLike) -> NDArray[np.float64]: - a_np = to_numpy(a) - b_np = to_numpy(b) - return (a_np + b_np) / 2 - - print("Mixed operations between different vector types:") - print(f"distance(Vector(1,2,3), [4,5,6]): {distance(vec3d, [4.0, 5.0, 6.0])}") - print(f"distance(np.array([1,2,3]), (4,5,6)): {distance(arr3d, (4.0, 5.0, 6.0))}") - print(f"midpoint(Vector(1,2,3), np.array([4,5,6])): {midpoint(vec3d, numpy_arr)}") diff --git a/dimos/types/vector.py b/dimos/types/vector.py index eb43c04945..d5fcd08165 100644 --- a/dimos/types/vector.py +++ b/dimos/types/vector.py @@ -91,8 +91,6 @@ def __str__(self) -> str: def getArrow(): repr = ["โ†", "โ†–", "โ†‘", "โ†—", "โ†’", "โ†˜", "โ†“", "โ†™"] - print("SELF X", self.x) - print("SELF Y", self.y) if self.x == 0 and self.y == 0: return "ยท" diff --git a/dimos/utils/test_testing.py b/dimos/utils/test_testing.py 
index c1c0c10072..796c190c8a 100644 --- a/dimos/utils/test_testing.py +++ b/dimos/utils/test_testing.py @@ -2,6 +2,7 @@ import os import subprocess from dimos.utils import testing +from dimos.robot.unitree_webrtc.type.lidar import LidarMessage def test_pull_file(): @@ -108,3 +109,21 @@ def test_pull_dir(): with file.open("rb") as f: sha256 = hashlib.sha256(f.read()).hexdigest() assert sha256 == expected_hash + + +def test_sensor_replay(): + counter = 0 + for message in testing.SensorReplay(name="office_lidar").iterate(): + counter += 1 + assert isinstance(message, dict) + assert counter == 500 + + +def test_sensor_replay_cast(): + counter = 0 + for message in testing.SensorReplay( + name="office_lidar", autocast=lambda x: LidarMessage.from_msg(x) + ).iterate(): + counter += 1 + assert isinstance(message, LidarMessage) + assert counter == 500 diff --git a/dimos/utils/testing.py b/dimos/utils/testing.py index 53b9849718..bd9a37d00a 100644 --- a/dimos/utils/testing.py +++ b/dimos/utils/testing.py @@ -1,8 +1,15 @@ import subprocess import tarfile +import glob +import os +import pickle from functools import cache from pathlib import Path -from typing import Union +from typing import Union, Iterator, TypeVar, Generic, Optional, Any, Type, Callable + +from reactivex import operators as ops +from reactivex import interval, from_iterable +from reactivex.observable import Observable def _check_git_lfs_available() -> None: @@ -140,3 +147,78 @@ def testData(filename: Union[str, Path]) -> Path: return file_path return _decompress_archive(_pull_lfs_archive(filename)) + + +T = TypeVar("T") + + +class SensorReplay(Generic[T]): + """Generic sensor data replay utility. + + Args: + name: The name of the test dataset + autocast: Optional function that takes unpickled data and returns a processed result. 
+ For example: lambda data: LidarMessage.from_msg(data) + """ + + def __init__(self, name: str, autocast: Optional[Callable[[Any], T]] = None): + self.root_dir = testData(name) + self.autocast = autocast + self.cnt = 0 + + def load(self, *names: Union[int, str]) -> Union[T, Any, list[T], list[Any]]: + if len(names) == 1: + return self.load_one(names[0]) + return list(map(lambda name: self.load_one(name), names)) + + def load_one(self, name: Union[int, str, Path]) -> Union[T, Any]: + if isinstance(name, int): + full_path = self.root_dir / f"/{name:03d}.pickle" + elif isinstance(name, Path): + full_path = self.root_dir / f"/{name}.pickle" + else: + full_path = name + + with open(full_path, "rb") as f: + data = pickle.load(f) + if self.autocast: + return self.autocast(data) + return data + + def iterate(self) -> Iterator[Union[T, Any]]: + pattern = os.path.join(self.root_dir, "*") + for file_path in sorted(glob.glob(pattern)): + yield self.load_one(file_path) + + def stream(self, rate_hz: float = 10.0) -> Observable[Union[T, Any]]: + sleep_time = 1.0 / rate_hz + + return from_iterable(self.iterate()).pipe( + ops.zip(interval(sleep_time)), + ops.map(lambda x: x[0] if isinstance(x, tuple) else x), + ) + + def save_stream(self, observable: Observable[Union[T, Any]]) -> Observable[int]: + return observable.pipe(ops.map(lambda frame: self.save_one(frame))) + + def save(self, *frames) -> int: + [self.save_one(frame) for frame in frames] + return self.cnt + + def save_one(self, frame) -> int: + file_name = f"/{self.cnt:03d}.pickle" + full_path = self.root_dir + file_name + + self.cnt += 1 + + if os.path.isfile(full_path): + raise Exception(f"file {full_path} exists") + + # Convert to raw message if frame has a raw_msg attribute + if hasattr(frame, "raw_msg"): + frame = frame.raw_msg + + with open(full_path, "wb") as f: + pickle.dump(frame, f) + + return self.cnt From 792582c647d61408bddf31f2e3bae2d4df242d38 Mon Sep 17 00:00:00 2001 From: lesh Date: Thu, 29 May 2025 16:34:37 
+0300 Subject: [PATCH 42/42] mapping test --- dimos/robot/unitree_webrtc/type/test_map.py | 26 +++++++++++---------- 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/dimos/robot/unitree_webrtc/type/test_map.py b/dimos/robot/unitree_webrtc/type/test_map.py index 180d473eb7..88c394236c 100644 --- a/dimos/robot/unitree_webrtc/type/test_map.py +++ b/dimos/robot/unitree_webrtc/type/test_map.py @@ -1,9 +1,10 @@ import pytest + +from dimos.robot.unitree_webrtc.testing.helpers import show3d, show3d_stream from dimos.robot.unitree_webrtc.testing.mock import Mock -from dimos.robot.unitree_webrtc.testing.helpers import show3d_stream, show3d from dimos.robot.unitree_webrtc.type.lidar import LidarMessage +from dimos.robot.unitree_webrtc.type.map import Map, splice_sphere from dimos.utils.reactive import backpressure -from dimos.robot.unitree_webrtc.type.map import splice_sphere, Map from dimos.utils.testing import SensorReplay @@ -43,19 +44,20 @@ def test_robot_vis(): def test_robot_mapping(): lidar_stream = SensorReplay("office_lidar", autocast=lambda x: LidarMessage.from_msg(x)) map = Map(voxel_size=0.5) - map.consume(lidar_stream.stream(rate_hz=100.0)).subscribe(lambda x: ...) + map.consume(lidar_stream.stream(rate_hz=100.0)).run() costmap = map.costmap - shape = costmap.grid.shape - assert shape[0] > 150 - assert shape[1] > 150 + assert costmap.grid.shape == (404, 276) - assert costmap.unknown_percent > 80 - assert costmap.unknown_percent < 90 + assert 70 <= costmap.unknown_percent <= 80, ( + f"Unknown percent {costmap.unknown_percent} is not within the range 70-80" + ) - assert costmap.free_percent > 5 - assert costmap.free_percent < 10 + assert 5 < costmap.free_percent < 10, ( + f"Free percent {costmap.free_percent} is not within the range 5-10" + ) - assert costmap.occupied_percent > 8 - assert costmap.occupied_percent < 15 + assert 8 < costmap.occupied_percent < 15, ( + f"Occupied percent {costmap.occupied_percent} is not within the range 8-15" + )