Outdoor segmentation question #5

@Cureser

This is the evaluation result I get after training with the pretrained weights concerto_large_outdoor.pth, and it looks reasonably good. However, when I try to process sample2_outdoor.npz following the approach in demo/2_sem_seg.py, I cannot get correct results. Could you give me some advice?
```
[2025-11-28 00:33:59,486 INFO test.py line 295 17919] Test: 08_004070 [4071/4071]-122346 Batch 9.018 (11.609) Accuracy 0.7338 (0.7363) mIoU 0.4692 (0.6516)
[2025-11-28 00:33:59,539 INFO test.py line 312 17919] Syncing ...
[2025-11-28 00:33:59,544 INFO test.py line 340 17919] Val result: mIoU/mAcc/allAcc 0.6516/0.7363/0.9043
[2025-11-28 00:33:59,544 INFO test.py line 346 17919] Class_0 - car Result: iou/accuracy 0.9610/0.9833
[2025-11-28 00:33:59,544 INFO test.py line 346 17919] Class_1 - bicycle Result: iou/accuracy 0.4675/0.6021
[2025-11-28 00:33:59,545 INFO test.py line 346 17919] Class_2 - motorcycle Result: iou/accuracy 0.6744/0.7549
[2025-11-28 00:33:59,545 INFO test.py line 346 17919] Class_3 - truck Result: iou/accuracy 0.8655/0.9456
[2025-11-28 00:33:59,545 INFO test.py line 346 17919] Class_4 - other-vehicle Result: iou/accuracy 0.6478/0.7219
[2025-11-28 00:33:59,545 INFO test.py line 346 17919] Class_5 - person Result: iou/accuracy 0.7463/0.8718
[2025-11-28 00:33:59,545 INFO test.py line 346 17919] Class_6 - bicyclist Result: iou/accuracy 0.8761/0.9440
[2025-11-28 00:33:59,546 INFO test.py line 346 17919] Class_7 - motorcyclist Result: iou/accuracy 0.0000/0.0000
[2025-11-28 00:33:59,546 INFO test.py line 346 17919] Class_8 - road Result: iou/accuracy 0.9134/0.9559
[2025-11-28 00:33:59,546 INFO test.py line 346 17919] Class_9 - parking Result: iou/accuracy 0.4679/0.5319
[2025-11-28 00:33:59,546 INFO test.py line 346 17919] Class_10 - sidewalk Result: iou/accuracy 0.7478/0.9012
[2025-11-28 00:33:59,546 INFO test.py line 346 17919] Class_11 - other-ground Result: iou/accuracy 0.1269/0.1930
[2025-11-28 00:33:59,546 INFO test.py line 346 17919] Class_12 - building Result: iou/accuracy 0.8861/0.9633
[2025-11-28 00:33:59,546 INFO test.py line 346 17919] Class_13 - fence Result: iou/accuracy 0.5888/0.7681
[2025-11-28 00:33:59,547 INFO test.py line 346 17919] Class_14 - vegetation Result: iou/accuracy 0.8682/0.9265
[2025-11-28 00:33:59,547 INFO test.py line 346 17919] Class_15 - trunk Result: iou/accuracy 0.7115/0.7977
[2025-11-28 00:33:59,547 INFO test.py line 346 17919] Class_16 - terrain Result: iou/accuracy 0.7053/0.7847
[2025-11-28 00:33:59,548 INFO test.py line 346 17919] Class_17 - pole Result: iou/accuracy 0.6363/0.7702
[2025-11-28 00:33:59,548 INFO test.py line 346 17919] Class_18 - traffic-sign Result: iou/accuracy 0.4894/0.5739
[2025-11-28 00:33:59,548 INFO test.py line 354 17919] <<<<<<<<<<<<<<<<< End Evaluation <<<<<<<<<<<<<<<<<
```
This is the test code:
```python
import numpy as np
import concerto
import torch
import torch.nn as nn
import open3d as o3d
import argparse
import os

try:
    import flash_attn
except ImportError:
    flash_attn = None
device = "cuda" if torch.cuda.is_available() else "cpu"

# KITTI meta data - 19 classes

KITTI_VALID_CLASS_IDS = (
    0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
    10, 11, 12, 13, 14, 15, 16, 17, 18,
)

KITTI_CLASS_LABELS = (
    "car",
    "bicycle",
    "motorcycle",
    "truck",
    "other-vehicle",
    "person",
    "bicyclist",
    "motorcyclist",
    "road",
    "parking",
    "sidewalk",
    "other-ground",
    "building",
    "fence",
    "vegetation",
    "trunk",
    "terrain",
    "pole",
    "traffic-sign",
)

# KITTI color map - assign a color to each class

KITTI_COLOR_MAP = {
    0: (255.0, 0.0, 0.0),       # car - red
    1: (0.0, 255.0, 0.0),       # bicycle - green
    2: (0.0, 0.0, 255.0),       # motorcycle - blue
    3: (255.0, 255.0, 0.0),     # truck - yellow
    4: (255.0, 0.0, 255.0),     # other-vehicle - magenta
    5: (0.0, 255.0, 255.0),     # person - cyan
    6: (255.0, 128.0, 0.0),     # bicyclist - orange
    7: (128.0, 0.0, 255.0),     # motorcyclist - purple
    8: (128.0, 128.0, 128.0),   # road - gray
    9: (255.0, 192.0, 203.0),   # parking - pink
    10: (0.0, 128.0, 128.0),    # sidewalk - teal
    11: (255.0, 215.0, 0.0),    # other-ground - gold
    12: (70.0, 130.0, 180.0),   # building - steel blue
    13: (165.0, 42.0, 42.0),    # fence - brown
    14: (50.0, 205.0, 50.0),    # vegetation - lime green
    15: (255.0, 99.0, 71.0),    # trunk - tomato
    16: (0.0, 100.0, 0.0),      # terrain - dark green
    17: (211.0, 211.0, 211.0),  # pole - light gray
    18: (255.0, 255.0, 255.0),  # traffic-sign - white
}

# Get colors for the valid classes (0-18)

CLASS_COLOR = [KITTI_COLOR_MAP[id] for id in KITTI_VALID_CLASS_IDS]

class SegHead(nn.Module):
    def __init__(self, backbone_out_channels, num_classes):
        super(SegHead, self).__init__()
        self.seg_head = nn.Linear(backbone_out_channels, num_classes)

    def forward(self, x):
        return self.seg_head(x)

def visualize_results(coord, pred, class_colors):
    """Visualize the segmentation result."""
    # Create the point cloud object
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(coord)

    # Assign a color to each point (Open3D expects colors in [0, 1])
    colors = np.array([class_colors[p] if p < len(class_colors) else (0, 0, 0) for p in pred])
    pcd.colors = o3d.utility.Vector3dVector(colors / 255.0)

    # Visualize
    o3d.visualization.draw_geometries([pcd])

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--wo_color',
        dest='wo_color',
        action='store_true',
        help="disable the color."
    )
    parser.add_argument(
        '--wo_normal',
        dest='wo_normal',
        action='store_true',
        help="disable the normal."
    )
    parser.add_argument(
        '--model_path',
        type=str,
        default=None,
        help="Path to the model checkpoint file"
    )
    parser.add_argument(
        '--seg_head_path',
        type=str,
        default=None,
        help="Path to the segmentation head checkpoint file"
    )
    parser.add_argument(
        '--data_path',
        type=str,
        default=None,
        help="Path to the input point cloud data file"
    )
    args = parser.parse_args()

    # Set the random seed
    concerto.utils.set_seed(46647087)

    # Script directory
    script_dir = os.path.dirname(os.path.abspath(__file__))

    # Load the model - use a local checkpoint instead of fetching it from the repository
    model_path = args.model_path or os.path.join(script_dir, "../model/concerto_large_outdoor.pth")
    print(f"Loading model from: {model_path}")

    if flash_attn is not None:
        model = concerto.load(model_path).to(device)
    else:
        custom_config = dict(
            enc_patch_size=[1024 for _ in range(5)],  # reduce patch size if necessary
            enable_flash=False,
        )
        model = concerto.load(
            model_path, custom_config=custom_config
        ).to(device)

    # Load the segmentation head - KITTI configuration: 1728 input channels, 19 classes
    seg_head_path = args.seg_head_path or os.path.join(script_dir, "../model/seg_head_kitti.pth")
    print(f"Loading segmentation head from: {seg_head_path}")

    try:
        # Try to load the checkpoint directly
        ckpt = concerto.load(seg_head_path, ckpt_only=True)

        # Check whether the checkpoint contains config and state_dict
        if "config" not in ckpt:
            # If there is no config, fall back to the fixed KITTI configuration
            ckpt["config"] = {
                'backbone_out_channels': 1728,
                'num_classes': 19
            }
            print("Using default KITTI configuration: backbone_out_channels=1728, num_classes=19")
    except Exception as e:
        print(f"Error loading checkpoint: {e}")
        # Create a default configuration
        ckpt = {
            "config": {
                'backbone_out_channels': 1728,
                'num_classes': 19
            },
            "state_dict": {}
        }

    # Create and load the segmentation head
    seg_head = SegHead(**ckpt["config"]).to(device)
    if "state_dict" in ckpt and ckpt["state_dict"]:
        seg_head.load_state_dict(ckpt["state_dict"])
    else:
        print("Warning: No state_dict found in checkpoint, using randomly initialized weights")

    # Load the default data transform pipeline
    transform = concerto.transform.default()
    data_path = os.path.join(script_dir, "../data/sample2_outdoor.npz")
    # Load the data
    if args.data_path:
        # Load data from the given path
        print(f"Loading data from: {args.data_path}")
        if args.data_path.endswith('.npz'):
            data = np.load(args.data_path)
            point = {k: data[k] for k in data.files}
        else:
            # Try to load the point cloud file with Open3D
            pcd = o3d.io.read_point_cloud(args.data_path)
            point = {
                "coord": np.asarray(pcd.points),
                "color": np.asarray(pcd.colors) if pcd.has_colors() else np.zeros_like(np.asarray(pcd.points)),
                "normal": np.asarray(pcd.normals) if pcd.has_normals() else np.zeros_like(np.asarray(pcd.points))
            }
    else:
        # Use the bundled sample data
        point = concerto.data.load("sample2_outdoor")

    # Handle the color and normal options
    if args.wo_color:
        point["color"] = np.zeros_like(point["coord"])
    if args.wo_normal:
        point["normal"] = np.zeros_like(point["coord"])

    # Keep the original coordinates for visualization
    original_coord = point["coord"].copy()

    # Apply the data transform
    point = transform(point)

    # Inference
    model.eval()
    seg_head.eval()
    with torch.inference_mode():
        # Move the data to the GPU
        for key in point.keys():
            if isinstance(point[key], torch.Tensor) and device == "cuda":
                point[key] = point[key].cuda(non_blocking=True)

        # Forward pass through the backbone
        point = model(point)

        # Unwind the pooling hierarchy (if present) and concatenate features back
        while "pooling_parent" in point.keys():
            assert "pooling_inverse" in point.keys()
            parent = point.pop("pooling_parent")
            inverse = point.pop("pooling_inverse")
            parent.feat = torch.cat([parent.feat, point.feat[inverse]], dim=-1)
            point = parent

        # Take the per-point features and predict the segmentation
        feat = point.feat
        seg_logits = seg_head(feat)
        pred = seg_logits.argmax(dim=-1).data.cpu().numpy()
        color = np.array(CLASS_COLOR)[pred]

        print(f"Segmentation completed. Number of points: {len(pred)}")
        print(f"Predicted classes: {np.unique(pred)}")

    # Visualize the result
    print("Visualizing results...")
    # visualize_results(original_coord, pred, CLASS_COLOR)
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(point.coord.cpu().detach().numpy())
    pcd.colors = o3d.utility.Vector3dVector(color / 255.0)
    o3d.visualization.draw_geometries([pcd])
```
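
To see more directly what the model predicts on sample2_outdoor.npz, I also print a per-class point count of the prediction. This is only a minimal sketch; it assumes `pred` and `KITTI_CLASS_LABELS` from the script above and nothing else:

```python
# Minimal sketch: per-class point counts for the prediction above.
# Assumes `pred` (np.ndarray of class ids) and `KITTI_CLASS_LABELS` from the test script.
import numpy as np

def summarize_prediction(pred, class_labels):
    """Print how many points were assigned to each predicted class."""
    ids, counts = np.unique(pred, return_counts=True)
    for class_id, count in zip(ids, counts):
        name = class_labels[class_id] if class_id < len(class_labels) else "unknown"
        print(f"{class_id:2d} {name:15s} {count:8d} points")

# Example usage (inside the script, after `pred` is computed):
# summarize_prediction(pred, KITTI_CLASS_LABELS)
```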

This is how I extract the segmentation head:
```python
# /home/simon/PTv3_ws/extract_seg_head.py
import torch
import os

# Paths
MODEL_PATH = "/root/workspace/workspace/Pointcept/exp/concerto/semseg-ptv3-large-v1m1-test-kitti-lin/model/model_best.pth"
OUTPUT_PATH = "/root/workspace/workspace/Concerto/model/seg_head_kitti.pth"

# Make sure the output directory exists
os.makedirs(os.path.dirname(OUTPUT_PATH), exist_ok=True)

print(f"Loading model from: {MODEL_PATH}")
try:
    # Load the full checkpoint
    checkpoint = torch.load(MODEL_PATH, map_location='cpu')

    # Extract the model state dict
    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    else:
        state_dict = checkpoint  # if there is no state_dict key, use the checkpoint directly

    # Extract the segmentation head parameters
    seg_head_state_dict = {}
    for key, value in state_dict.items():
        if key.startswith('seg_head.'):
            # Keep the original key name, since the SegHead class uses the same naming in its state_dict
            seg_head_state_dict[key] = value
            print(f"Found seg_head parameter: {key}, shape: {value.shape}")
        elif key.startswith('cls.'):
            # Check for keys starting with cls. (some models name the segmentation head cls)
            # Rename cls. to seg_head. so it matches what the SegHead class expects
            new_key = key.replace('cls.', 'seg_head.')
            seg_head_state_dict[new_key] = value
            print(f"Found cls parameter, converted to: {new_key}, shape: {value.shape}")

    # Verify that segmentation head parameters were found
    if not seg_head_state_dict:
        print("Warning: No seg_head or cls parameters found in the model!")
        print("Available keys:")
        for key in list(state_dict.keys())[:20]:  # only show the first 20 keys as a sample
            print(f"  {key}")
        if len(state_dict) > 20:
            print(f"  ... and {len(state_dict) - 20} more keys")
    else:
        print(f"Successfully extracted {len(seg_head_state_dict)} seg_head parameters")

    # Build a checkpoint in the expected format, containing config and state_dict
    output_checkpoint = {
        'config': {
            'backbone_out_channels': 1728,  # KITTI configuration
            'num_classes': 19
        },
        'state_dict': seg_head_state_dict
    }

    # Save the segmentation head
    torch.save(output_checkpoint, OUTPUT_PATH)
    print(f"Segmentation head saved to: {OUTPUT_PATH}")
    print(f"Checkpoint structure: {list(output_checkpoint.keys())}")
    print(f"State dict keys: {list(seg_head_state_dict.keys())}")

except Exception as e:
    print(f"Error processing model: {e}")
    import traceback
    traceback.print_exc()
```
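
As a quick sanity check on the extracted checkpoint, a small standalone script can confirm that the saved state_dict actually fits the SegHead class from the test script, with the expected [19, 1728] weight shape. This is only a sketch under that assumption (SegHead copied from above, the path is the OUTPUT_PATH written by extract_seg_head.py):

```python
# Sketch: verify that the extracted checkpoint loads into SegHead as expected.
# SegHead is copied from the test script above; the path is the OUTPUT_PATH used in extract_seg_head.py.
import torch
import torch.nn as nn

class SegHead(nn.Module):
    def __init__(self, backbone_out_channels, num_classes):
        super().__init__()
        self.seg_head = nn.Linear(backbone_out_channels, num_classes)

    def forward(self, x):
        return self.seg_head(x)

ckpt = torch.load("/root/workspace/workspace/Concerto/model/seg_head_kitti.pth", map_location="cpu")
print("config:", ckpt["config"])
print("state_dict keys:", list(ckpt["state_dict"].keys()))

head = SegHead(**ckpt["config"])
# strict=False reports mismatched keys instead of raising, so naming problems are visible
missing, unexpected = head.load_state_dict(ckpt["state_dict"], strict=False)
print("missing keys:", missing)
print("unexpected keys:", unexpected)
print("weight shape:", head.seg_head.weight.shape)  # expected: torch.Size([19, 1728])
```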
