Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions bin/lfs_check
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ new_data=()
# Enable nullglob to make globs expand to nothing when not matching
shopt -s nullglob

# Iterate through all directories in tests/data
for dir_path in tests/data/*; do
# Iterate through all directories in data/
for dir_path in data/*; do

# Extract directory name
dir_name=$(basename "$dir_path")
Expand All @@ -23,7 +23,7 @@ for dir_path in tests/data/*; do
[ "$dir_name" = ".lfs" ] && continue

# Define compressed file path
compressed_file="tests/data/.lfs/${dir_name}.tar.gz"
compressed_file="data/.lfs/${dir_name}.tar.gz"

# Check if compressed file already exists
if [ -f "$compressed_file" ]; then
Expand All @@ -34,9 +34,9 @@ for dir_path in tests/data/*; do
done

if [ ${#new_data[@]} -gt 0 ]; then
echo -e "${RED}✗${NC} New test data detected at /tests/data:"
echo -e "${RED}✗${NC} New test data detected at /data:"
echo -e " ${GREEN}${new_data[@]}${NC}"
echo -e "\nEither delete or run ${GREEN}./bin/lfs_push${NC}"
echo -e "(lfs_push will compress the files into /tests/data/.lfs/, upload to LFS, and add them to your commit)"
echo -e "(lfs_push will compress the files into /data/.lfs/, upload to LFS, and add them to your commit)"
exit 1
fi
16 changes: 8 additions & 8 deletions bin/lfs_push
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#!/bin/bash
# Compresses directories in tests/data/* into tests/data/.lfs/dirname.tar.gz
# Compresses directories in data/* into data/.lfs/dirname.tar.gz
# Pushes to LFS

set -e
Expand All @@ -15,17 +15,17 @@ NC='\033[0m' # No Color
ROOT=$(git rev-parse --show-toplevel)
cd $ROOT

# Check if tests/data exists
if [ ! -d "tests/data" ]; then
echo -e "${YELLOW}No tests/data directory found, skipping compression.${NC}"
# Check if data/ exists
if [ ! -d "data/" ]; then
echo -e "${YELLOW}No data directory found, skipping compression.${NC}"
exit 0
fi

# Track if any compression was performed
compressed_dirs=()

# Iterate through all directories in tests/data
for dir_path in tests/data/*; do
# Iterate through all directories in data/
for dir_path in data/*; do
# Skip if no directories found (glob didn't match)
[ ! "$dir_path" ] && continue

Expand All @@ -36,7 +36,7 @@ for dir_path in tests/data/*; do
[ "$dir_name" = ".lfs" ] && continue

# Define compressed file path
compressed_file="tests/data/.lfs/${dir_name}.tar.gz"
compressed_file="data/.lfs/${dir_name}.tar.gz"

# Check if compressed file already exists
if [ -f "$compressed_file" ]; then
Expand All @@ -58,7 +58,7 @@ for dir_path in tests/data/*; do
--exclude='Thumbs.db' \
--checkpoint=1000 \
--checkpoint-action=dot \
-C "tests/data" \
-C "data/" \
"$dir_name"

if [ $? -eq 0 ]; then
Expand Down
3 changes: 3 additions & 0 deletions data/.lfs/models_contact_graspnet.tar.gz
Git LFS file not shown
Empty file.
52 changes: 52 additions & 0 deletions dimos/models/manipulation/contact_graspnet_pytorch/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# ContactGraspNet PyTorch Module

This module provides a PyTorch implementation of ContactGraspNet for robotic grasping on dimOS.

## Setup Instructions

### 1. Install Required Dependencies

Install the manipulation extras from the main repository:

```bash
# From the root directory of the dimos repository
pip install -e ".[manipulation]"
```

This will install all the necessary dependencies for using the contact_graspnet_pytorch module, including:
- PyTorch
- Open3D
- Other manipulation-specific dependencies

### 2. Testing the Module

To test that the module is properly installed and functioning:

```bash
# From the root directory of the dimos repository
pytest -s dimos/models/manipulation/contact_graspnet_pytorch/test_contact_graspnet.py
```

The test will verify that:
- The model can be loaded
- Inference runs correctly
- Grasping outputs are generated as expected

### 3. Using in Your Code

Refer to `inference.py` for a usage example.

### Troubleshooting

If you encounter issues with imports or missing dependencies:

1. Verify that the manipulation extras are properly installed:
```python
import contact_graspnet_pytorch
print("Module loaded successfully!")
```

2. If LFS data files are missing, ensure Git LFS is installed and initialized:
```bash
git lfs pull
```
116 changes: 116 additions & 0 deletions dimos/models/manipulation/contact_graspnet_pytorch/inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
import glob
import os
import argparse

import torch
import numpy as np
from contact_graspnet_pytorch.contact_grasp_estimator import GraspEstimator
from contact_graspnet_pytorch import config_utils

from contact_graspnet_pytorch.visualization_utils_o3d import visualize_grasps, show_image
from contact_graspnet_pytorch.checkpoints import CheckpointIO
from contact_graspnet_pytorch.data import load_available_input_data
from dimos.utils.data import get_data

def inference(global_config,
              ckpt_dir,
              input_paths,
              local_regions=True,
              filter_grasps=True,
              skip_border_objects=False,
              z_range=(0.2, 1.8),
              forward_passes=1,
              K=None):
    """
    Predict 6-DoF grasp distribution for given model and input data.

    Writes one ``results/predictions_<name>.npz`` file per input scene.

    :param global_config: config.yaml from checkpoint directory
    :param ckpt_dir: checkpoint directory name, resolved via dimos.utils.data.get_data
    :param input_paths: glob pattern of .png/.npz/.npy files that contain
        depth/pointcloud and optionally intrinsics/segmentation/rgb
    :param local_regions: Crop 3D local regions around given segments.
    :param filter_grasps: Filter and assign grasp contacts according to segmap.
    :param skip_border_objects: When extracting local_regions, ignore segments at depth map boundary.
    :param z_range: crop point cloud at a minimum/maximum z distance from camera
        to filter out outlier points. Default: (0.2, 1.8) m
    :param forward_passes: Number of forward passes to run on each point cloud. Default: 1
    :param K: Camera Matrix with intrinsics to convert depth to point cloud
    """
    # Build the model
    grasp_estimator = GraspEstimator(global_config)

    # Load the weights
    model_checkpoint_dir = get_data(ckpt_dir)
    checkpoint_io = CheckpointIO(checkpoint_dir=model_checkpoint_dir, model=grasp_estimator.model)
    try:
        checkpoint_io.load('model.pt')
    except FileExistsError:
        # NOTE(review): upstream CheckpointIO raises FileExistsError when the
        # checkpoint file is *missing* (library quirk) — confirm against the
        # contact_graspnet_pytorch implementation. The model then runs with
        # its freshly-initialized weights.
        print('No model checkpoint found')

    os.makedirs('results', exist_ok=True)

    # Glob once so the loop and the "no files" check below agree on the same
    # snapshot of matching files.
    input_files = glob.glob(input_paths)

    # Process example test scenes
    for p in input_files:
        print('Loading ', p)

        pc_segments = {}
        segmap, rgb, depth, cam_K, pc_full, pc_colors = load_available_input_data(p, K=K)

        if segmap is None and (local_regions or filter_grasps):
            raise ValueError('Need segmentation map to extract local regions or filter grasps')

        if pc_full is None:
            print('Converting depth to point cloud(s)...')
            pc_full, pc_segments, pc_colors = grasp_estimator.extract_point_clouds(
                depth, cam_K, segmap=segmap, rgb=rgb,
                skip_border_objects=skip_border_objects,
                z_range=z_range)

        print(pc_full.shape)

        print('Generating Grasps...')
        pred_grasps_cam, scores, contact_pts, _ = grasp_estimator.predict_scene_grasps(
            pc_full,
            pc_segments=pc_segments,
            local_regions=local_regions,
            filter_grasps=filter_grasps,
            forward_passes=forward_passes)

        # Save results (np.savez keeps the .npz suffix produced by the rename)
        np.savez('results/predictions_{}'.format(os.path.basename(p.replace('png', 'npz').replace('npy', 'npz'))),
                 pc_full=pc_full, pred_grasps_cam=pred_grasps_cam, scores=scores,
                 contact_pts=contact_pts, pc_colors=pc_colors)

        # Visualize results
        # show_image(rgb, segmap)
        # visualize_grasps(pc_full, pred_grasps_cam, scores, plot_opencv_cam=True, pc_colors=pc_colors)

    if not input_files:
        print('No files found: ', input_paths)

if __name__ == "__main__":
    # Local import: only needed for safe CLI-literal parsing below.
    import ast

    parser = argparse.ArgumentParser()
    parser.add_argument('--ckpt_dir', default='models_contact_graspnet', help='Log dir')
    parser.add_argument('--np_path', default='test_data/7.npy', help='Input data: npz/npy file with keys either "depth" & camera matrix "K" or just point cloud "pc" in meters. Optionally, a 2D "segmap"')
    parser.add_argument('--K', default=None, help='Flat Camera Matrix, pass as "[fx, 0, cx, 0, fy, cy, 0, 0 ,1]"')
    parser.add_argument('--z_range', default=[0.2,1.8], help='Z value threshold to crop the input point cloud')
    # NOTE(review): action='store_true' with default=True means these two flags
    # can never be switched off from the CLI — confirm whether that is intended.
    parser.add_argument('--local_regions', action='store_true', default=True, help='Crop 3D local regions around given segments.')
    parser.add_argument('--filter_grasps', action='store_true', default=True, help='Filter grasp contacts according to segmap.')
    parser.add_argument('--skip_border_objects', action='store_true', default=False, help='When extracting local_regions, ignore segments at depth map boundary.')
    parser.add_argument('--forward_passes', type=int, default=1, help='Run multiple parallel forward passes to mesh_utils more potential contact points.')
    parser.add_argument('--arg_configs', nargs="*", type=str, default=[], help='overwrite config parameters')
    FLAGS = parser.parse_args()

    global_config = config_utils.load_config(FLAGS.ckpt_dir, batch_size=FLAGS.forward_passes, arg_configs=FLAGS.arg_configs)

    print(str(global_config))
    print('pid: %s' % (str(os.getpid())))

    # ast.literal_eval safely parses the list/None literals supplied on the
    # command line; the previous eval() would execute arbitrary code passed
    # via --K or --z_range.
    inference(global_config,
              FLAGS.ckpt_dir,
              FLAGS.np_path,
              local_regions=FLAGS.local_regions,
              filter_grasps=FLAGS.filter_grasps,
              skip_border_objects=FLAGS.skip_border_objects,
              z_range=ast.literal_eval(str(FLAGS.z_range)),
              forward_passes=FLAGS.forward_passes,
              K=ast.literal_eval(str(FLAGS.K)))
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import os
import sys
import glob
import pytest
import importlib.util
import numpy as np

def is_manipulation_installed():
    """Return True when the 'manipulation' extras (contact_graspnet_pytorch)
    can be imported, False otherwise."""
    try:
        import contact_graspnet_pytorch  # noqa: F401 -- probe import only
    except ImportError:
        return False
    return True

@pytest.mark.skipif(not is_manipulation_installed(),
                    reason="This test requires 'pip install .[manipulation]' to be run")
def test_contact_graspnet_inference():
    """Test contact graspnet inference with local regions and filter grasps.

    Runs the full inference pipeline on the bundled LFS test scene and checks
    that a predictions .npz file with the expected keys is produced.
    """
    try:
        from dimos.utils.data import get_data
        from contact_graspnet_pytorch import config_utils
        from dimos.models.manipulation.contact_graspnet_pytorch.inference import inference
    except ImportError:
        # pytest.skip raises, so no 'return' is needed after it.
        pytest.skip("Required modules could not be imported. Make sure you have run 'pip install .[manipulation]'.")

    # Test data path - use the default test data path
    test_data_path = os.path.join(get_data("models_contact_graspnet"), "test_data/0.npy")

    # Check if test data exists
    test_files = glob.glob(test_data_path)
    if not test_files:
        pytest.fail(f"No test data found at {test_data_path}")

    # Load config with default values
    ckpt_dir = 'models_contact_graspnet'
    global_config = config_utils.load_config(ckpt_dir, batch_size=1)

    # inference() writes results/predictions_<basename>.npz for each input
    # file; assert on those specific paths rather than comparing directory
    # counts (the old ">= before" count check could never fail, since result
    # files are never deleted and re-runs simply overwrite them).
    expected_results = [
        os.path.join('results', 'predictions_' + os.path.basename(f.replace('npy', 'npz')))
        for f in test_files
    ]

    # Run inference function with the same params as the command line
    inference(
        global_config=global_config,
        ckpt_dir=ckpt_dir,
        input_paths=test_data_path,
        local_regions=True,
        filter_grasps=True,
        skip_border_objects=False,
        z_range=[0.2, 1.8],
        forward_passes=1,
        K=None
    )

    # Verify each expected result file was created and contains the expected arrays
    expected_keys = ['pc_full', 'pred_grasps_cam', 'scores', 'contact_pts', 'pc_colors']
    for result_path in expected_results:
        assert os.path.isfile(result_path), f"No result file was generated at {result_path}"
        result_data = np.load(result_path, allow_pickle=True)
        for key in expected_keys:
            assert key in result_data.files, f"Expected key '{key}' not found in results"
Loading