-
Notifications
You must be signed in to change notification settings - Fork 26
Description
这是我加载了预训练参数concerto_large_outdoor.pth的训练结果,似乎还不错,但是我尝试仿照demo/2_sem_seg.py的方法处理sample2_outdoor.npz无法得到正确的结果,请问可以给我点意见建议吗?
[2025-11-28 00:33:59,486 INFO test.py line 295 17919] Test: 08_004070 [4071/4071]-122346 Batch 9.018 (11.609) Accuracy 0.7338 (0.7363) mIoU 0.4692 (0.6516) [2025-11-28 00:33:59,539 INFO test.py line 312 17919] Syncing ... [2025-11-28 00:33:59,544 INFO test.py line 340 17919] Val result: mIoU/mAcc/allAcc 0.6516/0.7363/0.9043 [2025-11-28 00:33:59,544 INFO test.py line 346 17919] Class_0 - car Result: iou/accuracy 0.9610/0.9833 [2025-11-28 00:33:59,544 INFO test.py line 346 17919] Class_1 - bicycle Result: iou/accuracy 0.4675/0.6021 [2025-11-28 00:33:59,545 INFO test.py line 346 17919] Class_2 - motorcycle Result: iou/accuracy 0.6744/0.7549 [2025-11-28 00:33:59,545 INFO test.py line 346 17919] Class_3 - truck Result: iou/accuracy 0.8655/0.9456 [2025-11-28 00:33:59,545 INFO test.py line 346 17919] Class_4 - other-vehicle Result: iou/accuracy 0.6478/0.7219 [2025-11-28 00:33:59,545 INFO test.py line 346 17919] Class_5 - person Result: iou/accuracy 0.7463/0.8718 [2025-11-28 00:33:59,545 INFO test.py line 346 17919] Class_6 - bicyclist Result: iou/accuracy 0.8761/0.9440 [2025-11-28 00:33:59,546 INFO test.py line 346 17919] Class_7 - motorcyclist Result: iou/accuracy 0.0000/0.0000 [2025-11-28 00:33:59,546 INFO test.py line 346 17919] Class_8 - road Result: iou/accuracy 0.9134/0.9559 [2025-11-28 00:33:59,546 INFO test.py line 346 17919] Class_9 - parking Result: iou/accuracy 0.4679/0.5319 [2025-11-28 00:33:59,546 INFO test.py line 346 17919] Class_10 - sidewalk Result: iou/accuracy 0.7478/0.9012 [2025-11-28 00:33:59,546 INFO test.py line 346 17919] Class_11 - other-ground Result: iou/accuracy 0.1269/0.1930 [2025-11-28 00:33:59,546 INFO test.py line 346 17919] Class_12 - building Result: iou/accuracy 0.8861/0.9633 [2025-11-28 00:33:59,546 INFO test.py line 346 17919] Class_13 - fence Result: iou/accuracy 0.5888/0.7681 [2025-11-28 00:33:59,547 INFO test.py line 346 17919] Class_14 - vegetation Result: iou/accuracy 0.8682/0.9265 [2025-11-28 00:33:59,547 INFO test.py line 346 
17919] Class_15 - trunk Result: iou/accuracy 0.7115/0.7977 [2025-11-28 00:33:59,547 INFO test.py line 346 17919] Class_16 - terrain Result: iou/accuracy 0.7053/0.7847 [2025-11-28 00:33:59,548 INFO test.py line 346 17919] Class_17 - pole Result: iou/accuracy 0.6363/0.7702 [2025-11-28 00:33:59,548 INFO test.py line 346 17919] Class_18 - traffic-sign Result: iou/accuracy 0.4894/0.5739 [2025-11-28 00:33:59,548 INFO test.py line 354 17919] <<<<<<<<<<<<<<<<< End Evaluation <<<<<<<<<<<<<<<<<
这是测试代码
`
import numpy as np
import concerto
import torch
import torch.nn as nn
import open3d as o3d
import argparse
import os

# flash_attn is optional: when it is missing, the sentinel below stays None
# and the model is later loaded with flash attention disabled.
try:
    import flash_attn
except ImportError:
    flash_attn = None

# Run on GPU when one is available, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# KITTI meta data - 19 semantic classes.
KITTI_VALID_CLASS_IDS = (
    0,
    1,
    2,
    3,
    4,
    5,
    6,
    7,
    8,
    9,
    10,
    11,
    12,
    13,
    14,
    15,
    16,
    17,
    18,
)

# Class names, indexed by the class ids above.
KITTI_CLASS_LABELS = (
    "car",
    "bicycle",
    "motorcycle",
    "truck",
    "other-vehicle",
    "person",
    "bicyclist",
    "motorcyclist",
    "road",
    "parking",
    "sidewalk",
    "other-ground",
    "building",
    "fence",
    "vegetation",
    "trunk",
    "terrain",
    "pole",
    "traffic-sign",
)

# KITTI color map - one RGB color (0-255 floats) assigned to each class.
KITTI_COLOR_MAP = {
    0: (255.0, 0.0, 0.0),      # car - red
    1: (0.0, 255.0, 0.0),      # bicycle - green
    2: (0.0, 0.0, 255.0),      # motorcycle - blue
    3: (255.0, 255.0, 0.0),    # truck - yellow
    4: (255.0, 0.0, 255.0),    # other-vehicle - magenta
    5: (0.0, 255.0, 255.0),    # person - cyan
    6: (255.0, 128.0, 0.0),    # bicyclist - orange
    7: (128.0, 0.0, 255.0),    # motorcyclist - purple
    8: (128.0, 128.0, 128.0),  # road - gray
    9: (255.0, 192.0, 203.0),  # parking - pink
    10: (0.0, 128.0, 128.0),   # sidewalk - teal
    11: (255.0, 215.0, 0.0),   # other-ground - gold
    12: (70.0, 130.0, 180.0),  # building - steel blue
    13: (165.0, 42.0, 42.0),   # fence - brown
    14: (50.0, 205.0, 50.0),   # vegetation - lime green
    15: (255.0, 99.0, 71.0),   # trunk - tomato
    16: (0.0, 100.0, 0.0),     # terrain - dark green
    17: (211.0, 211.0, 211.0), # pole - light gray
    18: (255.0, 255.0, 255.0), # traffic-sign - white
}

# Colors ordered by valid class id, used to paint per-point predictions.
CLASS_COLOR = [KITTI_COLOR_MAP[class_id] for class_id in KITTI_VALID_CLASS_IDS]
class SegHead(nn.Module):
    """Linear segmentation head mapping backbone features to class logits.

    Bug fix: the pasted code had ``def init`` / ``super(...).init()`` — the
    double underscores of ``__init__`` were stripped by markdown rendering,
    so the class could never be constructed. Restored here.

    Args:
        backbone_out_channels: feature dimension produced by the backbone
            (1728 for the concatenated multi-scale features used below).
        num_classes: number of semantic classes (19 for SemanticKITTI).
    """

    def __init__(self, backbone_out_channels: int, num_classes: int) -> None:
        super().__init__()
        self.seg_head = nn.Linear(backbone_out_channels, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (..., backbone_out_channels) features -> (..., num_classes) logits.
        return self.seg_head(x)
def visualize_results(coord, pred, class_colors):
    """Visualize segmentation results as a colored Open3D point cloud.

    Args:
        coord: (N, 3) array of point coordinates.
        pred: (N,) array of predicted class indices.
        class_colors: sequence of RGB tuples in the 0-255 range, indexed
            by class id (e.g. ``CLASS_COLOR``).
    """
    # Build the point cloud object.
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(coord)
    # Assign a color to each point; out-of-range predictions fall back to black.
    colors = np.array(
        [class_colors[p] if p < len(class_colors) else (0.0, 0.0, 0.0) for p in pred],
        dtype=np.float64,
    )
    # Bug fix: Open3D expects RGB in [0, 1]; the palette is 0-255, so the
    # original code (no division) rendered every point as saturated white-ish.
    # The working code path at the bottom of the script divides by 255 too.
    pcd.colors = o3d.utility.Vector3dVector(colors / 255.0)
    # Show the result.
    o3d.visualization.draw_geometries([pcd])
# Bug fix: the pasted guard was ``if name == "main":`` — markdown stripped the
# dunder underscores; without ``__name__ == "__main__"`` this raises NameError.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--wo_color',
        dest='wo_color',
        action='store_true',
        help="disable the color."
    )
    parser.add_argument(
        '--wo_normal',
        dest='wo_normal',
        action='store_true',
        help="disable the normal."
    )
    parser.add_argument(
        '--model_path',
        type=str,
        default=None,
        help="Path to the model checkpoint file"
    )
    parser.add_argument(
        '--seg_head_path',
        type=str,
        default=None,
        help="Path to the segmentation head checkpoint file"
    )
    parser.add_argument(
        '--data_path',
        type=str,
        default=None,
        help="Path to the input point cloud data file"
    )
    args = parser.parse_args()

    # Set random seed for reproducibility.
    concerto.utils.set_seed(46647087)

    # Directory of this script, used to resolve relative default paths.
    script_dir = os.path.dirname(os.path.abspath(__file__))

    # Load the backbone from a local checkpoint instead of the model hub.
    model_path = args.model_path or os.path.join(script_dir, "../model/concerto_large_outdoor.pth")
    print(f"Loading model from: {model_path}")
    if flash_attn is not None:
        model = concerto.load(model_path).to(device)
    else:
        custom_config = dict(
            enc_patch_size=[1024 for _ in range(5)],  # reduce patch size if necessary
            enable_flash=False,
        )
        model = concerto.load(
            model_path, custom_config=custom_config
        ).to(device)

    # Load the segmentation head - KITTI setup: 1728 input channels, 19 classes.
    seg_head_path = args.seg_head_path or os.path.join(script_dir, "../model/seg_head_kitti.pth")
    print(f"Loading segmentation head from: {seg_head_path}")
    try:
        # Try to load the checkpoint directly.
        ckpt = concerto.load(seg_head_path, ckpt_only=True)
        # Ensure the checkpoint carries a config alongside the state_dict.
        if "config" not in ckpt:
            # No config stored: fall back to the fixed KITTI configuration.
            ckpt["config"] = {
                'backbone_out_channels': 1728,
                'num_classes': 19
            }
            print("Using default KITTI configuration: backbone_out_channels=1728, num_classes=19")
    except Exception as e:
        print(f"Error loading checkpoint: {e}")
        # Build a default configuration with empty weights.
        ckpt = {
            "config": {
                'backbone_out_channels': 1728,
                'num_classes': 19
            },
            "state_dict": {}
        }

    # Build the segmentation head and load its weights.
    seg_head = SegHead(**ckpt["config"]).to(device)
    if "state_dict" in ckpt and ckpt["state_dict"]:
        seg_head.load_state_dict(ckpt["state_dict"])
    else:
        print("Warning: No state_dict found in checkpoint, using randomly initialized weights")

    # Default data transform pipeline.
    # NOTE(review): concerto.transform.default() may apply preprocessing
    # (grid size, color/normal handling) that differs from the SemanticKITTI
    # training pipeline used by test.py — a mismatch here is a likely cause of
    # wrong predictions on sample2_outdoor; verify against the training config.
    transform = concerto.transform.default()

    # Load input data.
    if args.data_path:
        # Load from the user-specified path.
        print(f"Loading data from: {args.data_path}")
        if args.data_path.endswith('.npz'):
            data = np.load(args.data_path)
            point = {k: data[k] for k in data.files}
        else:
            # Fall back to Open3D for generic point-cloud formats.
            pcd = o3d.io.read_point_cloud(args.data_path)
            point = {
                "coord": np.asarray(pcd.points),
                "color": np.asarray(pcd.colors) if pcd.has_colors() else np.zeros_like(np.asarray(pcd.points)),
                "normal": np.asarray(pcd.normals) if pcd.has_normals() else np.zeros_like(np.asarray(pcd.points))
            }
    else:
        # Use the bundled sample data.
        point = concerto.data.load("sample2_outdoor")

    # Optionally zero out color and/or normal channels.
    if args.wo_color:
        point["color"] = np.zeros_like(point["coord"])
    if args.wo_normal:
        point["normal"] = np.zeros_like(point["coord"])

    # Keep the raw coordinates for visualization before the transform
    # (potentially) subsamples/recenters the cloud.
    original_coord = point["coord"].copy()

    # Apply the data transform pipeline.
    point = transform(point)

    # Inference.
    model.eval()
    seg_head.eval()
    with torch.inference_mode():
        # Move tensors to the GPU.
        for key in point.keys():
            if isinstance(point[key], torch.Tensor) and device == "cuda":
                point[key] = point[key].cuda(non_blocking=True)

        # Backbone forward pass.
        point = model(point)

        # Unpool back to full resolution: walk up the pooling hierarchy,
        # concatenating child features onto each parent.
        while "pooling_parent" in point.keys():
            assert "pooling_inverse" in point.keys()
            parent = point.pop("pooling_parent")
            inverse = point.pop("pooling_inverse")
            parent.feat = torch.cat([parent.feat, point.feat[inverse]], dim=-1)
            point = parent

        # Per-point class prediction from the concatenated features.
        feat = point.feat
        seg_logits = seg_head(feat)
        pred = seg_logits.argmax(dim=-1).data.cpu().numpy()
        color = np.array(CLASS_COLOR)[pred]

        print(f"Segmentation completed. Number of points: {len(pred)}")
        print(f"Predicted classes: {np.unique(pred)}")

        # Visualize the result.
        print("Visualizing results...")
        # visualize_results(original_coord, pred, CLASS_COLOR)
        pcd = o3d.geometry.PointCloud()
        pcd.points = o3d.utility.Vector3dVector(point.coord.cpu().detach().numpy())
        pcd.colors = o3d.utility.Vector3dVector(color / 255.0)
        o3d.visualization.draw_geometries([pcd])
这是我提取分割头的方式
# /home/simon/PTv3_ws/extract_seg_head.py
# Extract the segmentation-head weights from a full Pointcept checkpoint and
# save them as a small checkpoint the demo's SegHead class can load.
import torch
import os

# Paths.
MODEL_PATH = "/root/workspace/workspace/Pointcept/exp/concerto/semseg-ptv3-large-v1m1-test-kitti-lin/model/model_best.pth"
OUTPUT_PATH = "/root/workspace/workspace/Concerto/model/seg_head_kitti.pth"

# Make sure the output directory exists.
os.makedirs(os.path.dirname(OUTPUT_PATH), exist_ok=True)

print(f"Loading model from: {MODEL_PATH}")
try:
    # Load the full checkpoint on CPU.
    checkpoint = torch.load(MODEL_PATH, map_location='cpu')

    # Unwrap the state dict if the checkpoint wraps it.
    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    else:
        state_dict = checkpoint  # no 'state_dict' key, use the checkpoint directly

    # Robustness fix: checkpoints trained with (Distributed)DataParallel
    # prefix every key with 'module.', so 'module.seg_head.weight' would not
    # match the startswith checks below. Strip that prefix first.
    state_dict = {
        (key[len('module.'):] if key.startswith('module.') else key): value
        for key, value in state_dict.items()
    }

    # Collect the segmentation-head parameters.
    seg_head_state_dict = {}
    for key, value in state_dict.items():
        if key.startswith('seg_head.'):
            # Keep the original key name; SegHead's state_dict uses the same naming.
            seg_head_state_dict[key] = value
            print(f"Found seg_head parameter: {key}, shape: {value.shape}")
        elif key.startswith('cls.'):
            # Some models name the head 'cls.'; rename to 'seg_head.' so the
            # keys match what the SegHead class expects.
            new_key = key.replace('cls.', 'seg_head.')
            seg_head_state_dict[new_key] = value
            print(f"Found cls parameter, converted to: {new_key}, shape: {value.shape}")

    # Verify that head parameters were actually found.
    if not seg_head_state_dict:
        print("Warning: No seg_head or cls parameters found in the model!")
        print("Available keys:")
        for key in list(state_dict.keys())[:20]:  # show at most the first 20 keys
            print(f"  {key}")
        if len(state_dict) > 20:
            print(f"  ... and {len(state_dict) - 20} more keys")
    else:
        print(f"Successfully extracted {len(seg_head_state_dict)} seg_head parameters")

    # Build a checkpoint in the expected format: config plus state_dict.
    output_checkpoint = {
        'config': {
            'backbone_out_channels': 1728,  # KITTI configuration
            'num_classes': 19
        },
        'state_dict': seg_head_state_dict
    }

    # Save the segmentation head.
    torch.save(output_checkpoint, OUTPUT_PATH)
    print(f"Segmentation head saved to: {OUTPUT_PATH}")
    print(f"Checkpoint structure: {list(output_checkpoint.keys())}")
    print(f"State dict keys: {list(seg_head_state_dict.keys())}")
except Exception as e:
    print(f"Error processing model: {e}")
    import traceback
    traceback.print_exc()