Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file modified .gitignore
Binary file not shown.
65 changes: 46 additions & 19 deletions src/yolo11n_vehicle_counter/scripts/yolo_vehicle_counter_carla.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,10 @@ def main(model_path=None, input_video_path=None, output_video_path=None):
limits = [350, 500, 1230, 500] # 计数线位置:起点(x1, y)到终点(x2, y) - 针对CARLA视频放低并进一步往右移红线
total_counts, crossed_ids = [], set() # 总计数和已计数车辆ID集合

# 分类计数器
class_counts = {'car': 0, 'motorbike': 0, 'bus': 0, 'truck': 0} # 各类别已计数ID集合
crossed_by_class = {cls: set() for cls in class_counts.keys()} # 各类别已计数ID集合


def draw_overlay(frame, pt1, pt2, alpha=0.25, color=(51, 68, 255), filled=True):
"""绘制半透明覆盖矩形
Expand All @@ -81,33 +85,41 @@ def draw_overlay(frame, pt1, pt2, alpha=0.25, color=(51, 68, 255), filled=True):
cv.addWeighted(overlay, alpha, frame, 1 - alpha, 0, frame)


def count_vehicles(track_id, cx, cy, limits, crossed_ids):
def count_vehicles(track_id, cx, cy, limits, crossed_ids, class_name=None, crossed_by_class=None, class_counts=None):
"""统计穿过计数线的车辆

Args:
track_id: 车辆追踪ID
cx, cy: 车辆中心点坐标
limits: 计数线位置
crossed_ids: 已计数车辆ID集合
class_name: 车辆类别名称
crossed_by_class: 各类别已计数ID集合
class_counts: 各类别计数器

Returns:
bool: 是否计数成功
tuple: (是否计数成功, 车辆类别)
"""
if limits[0] < cx < limits[2] and limits[1] - 15 < cy < limits[1] + 15 and track_id not in crossed_ids:
crossed_ids.add(track_id)
return True
return False
if class_name and crossed_by_class and class_counts:
if track_id not in crossed_by_class[class_name]:
crossed_by_class[class_name].add(track_id)
class_counts[class_name] += 1
return True, class_name
return False, None


def draw_tracks_and_count(frame, detections, total_counts, limits):
"""绘制轨迹并统计车辆
def draw_tracks_and_count(frame, detections, limits):
"""绘制轨迹并统计车辆(按类别分类统计)

Args:
frame: 输入帧
detections: 检测结果
total_counts: 总计数列表
limits: 计数线位置
"""
nonlocal total_counts, crossed_ids, class_counts, crossed_by_class

# 按车辆类别和检测置信度过滤 - 针对CARLA视频优化置信度阈值
detections = detections[(np.isin(detections.class_id, selected_classes)) & (detections.confidence > 0.4)]

Expand All @@ -121,21 +133,30 @@ def draw_tracks_and_count(frame, detections, total_counts, limits):
trace_annotator.annotate(frame, detections=detections)

# 处理每个检测到的车辆
for track_id, center_point in zip(detections.tracker_id,
detections.get_anchors_coordinates(anchor=sv.Position.CENTER)):
for i, (track_id, center_point) in enumerate(zip(detections.tracker_id,
detections.get_anchors_coordinates(anchor=sv.Position.CENTER))):
cx, cy = map(int, center_point)
cls_id = detections.class_id[i]
cls_name = class_names[cls_id]

cv.circle(frame, (cx, cy), 4, (0, 255, 255), cv.FILLED) # 绘制车辆中心点

if count_vehicles(track_id, cx, cy, limits, crossed_ids):
total_counts.append(track_id)
sv.draw_line(frame, start=sv.Point(x=limits[0], y=limits[1]), end=sv.Point(x=limits[2], y=limits[3]),
color=sv.Color.ROBOFLOW, thickness=4)
# 调整覆盖区域位置,与新的计数线匹配 - 跟着红线一起往右移
draw_overlay(frame, (350, 450), (1230, 550), alpha=0.25, color=(10, 255, 50))
# 检查是否穿过计数线
if limits[0] < cx < limits[2] and limits[1] - 15 < cy < limits[1] + 15:
if track_id not in crossed_ids:
crossed_ids.add(track_id)
total_counts.append(track_id)

# 分类计数
if cls_name in class_counts:
class_counts[cls_name] += 1

sv.draw_line(frame, start=sv.Point(x=limits[0], y=limits[1]), end=sv.Point(x=limits[2], y=limits[3]),
color=sv.Color.ROBOFLOW, thickness=4)
draw_overlay(frame, (350, 450), (1230, 550), alpha=0.25, color=(10, 255, 50))

# 显示车辆计数 - 往右移避免遮挡
sv.draw_text(frame, f"COUNTS: {len(total_counts)}", sv.Point(x=150, y=80), sv.Color.ROBOFLOW, 1.25,
# 显示车辆总计数
sv.draw_text(frame, f"TOTAL: {len(total_counts)}", sv.Point(x=50, y=50), sv.Color.ROBOFLOW, 0.8,
2, background_color=sv.Color.WHITE)


Expand Down Expand Up @@ -183,7 +204,7 @@ def draw_tracks_and_count(frame, detections, total_counts, limits):
color=sv.Color.RED, thickness=4)
# 调整覆盖区域透明度 - 与红线位置匹配
draw_overlay(frame, (350, 450), (1230, 550), alpha=0.15)
draw_tracks_and_count(frame, detections, total_counts, limits)
draw_tracks_and_count(frame, detections, limits)

# 写入帧到输出视频
out.write(frame)
Expand All @@ -199,7 +220,13 @@ def draw_tracks_and_count(frame, detections, total_counts, limits):
out.release()
cv.destroyAllWindows()

print(f"处理完成!总计数: {len(total_counts)} 辆车")
print(f"处理完成!")
print(f"=" * 40)
print(f"总计数: {len(total_counts)} 辆车")
print("-" * 40)
for cls_name, count in class_counts.items():
print(f" {cls_name}: {count}")
print("=" * 40)


if __name__ == "__main__":
Expand Down
Loading