diff --git a/CMakeLists.txt b/CMakeLists.txt index 9448ceaae..aacd56dd2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -159,6 +159,7 @@ pybind11_add_module(${TARGET_NAME} src/pipeline/datatype/PointCloudConfigBindings.cpp src/pipeline/datatype/PointCloudDataBindings.cpp src/pipeline/datatype/ImageAlignConfigBindings.cpp + src/pipeline/datatype/ObjectTrackerConfigBindings.cpp ) if(WIN32) diff --git a/depthai-core b/depthai-core index 2b4d78030..603bd2261 160000 --- a/depthai-core +++ b/depthai-core @@ -1 +1 @@ -Subproject commit 2b4d780302f7ea7f6861c4d3b95ac912737b8158 +Subproject commit 603bd2261dbea34c61e0be7f482d93bc0b29a7d5 diff --git a/examples/ObjectTracker/object_tracker.py b/examples/ObjectTracker/object_tracker.py index 37e6f16a6..03a142a14 100755 --- a/examples/ObjectTracker/object_tracker.py +++ b/examples/ObjectTracker/object_tracker.py @@ -29,9 +29,11 @@ xlinkOut = pipeline.create(dai.node.XLinkOut) trackerOut = pipeline.create(dai.node.XLinkOut) +xinTrackerConfig = pipeline.create(dai.node.XLinkIn) xlinkOut.setStreamName("preview") trackerOut.setStreamName("tracklets") +xinTrackerConfig.setStreamName("trackerConfig") # Properties camRgb.setPreviewSize(300, 300) @@ -64,11 +66,19 @@ detectionNetwork.out.link(objectTracker.inputDetections) objectTracker.out.link(trackerOut.input) +# set tracking parameters +objectTracker.setOcclusionRatioThreshold(0.4) +objectTracker.setTrackletMaxLifespan(120) +objectTracker.setTrackletBirthThreshold(3) + +xinTrackerConfig.out.link(objectTracker.inputConfig) + # Connect to device and start pipeline with dai.Device(pipeline) as device: preview = device.getOutputQueue("preview", 4, False) tracklets = device.getOutputQueue("tracklets", 4, False) + trackerConfigQueue = device.getInputQueue("trackerConfig") startTime = time.monotonic() counter = 0 @@ -76,6 +86,7 @@ frame = None while(True): + latestTrackedIds = [] imgFrame = preview.get() track = tracklets.get() @@ -106,9 +117,26 @@ cv2.putText(frame, t.status.name, (x1 + 10, y1 + 50), cv2.FONT_HERSHEY_TRIPLEX, 0.5, 255) cv2.rectangle(frame, (x1, y1), (x2, y2), color, cv2.FONT_HERSHEY_SIMPLEX) + if t.status == dai.Tracklet.TrackingStatus.TRACKED: + latestTrackedIds.append(t.id) + cv2.putText(frame, "NN fps: {:.2f}".format(fps), (2, frame.shape[0] - 4), cv2.FONT_HERSHEY_TRIPLEX, 0.4, color) cv2.imshow("tracker", frame) - if cv2.waitKey(1) == ord('q'): + key = cv2.waitKey(1) + if key == ord('q'): break + elif key == ord('g'): + # send tracker config to device + config = dai.ObjectTrackerConfig() + + # take a random ID from the latest tracked IDs + if len(latestTrackedIds) > 0: + idToRemove = (np.random.choice(latestTrackedIds)) + print(f"Force removing ID: {idToRemove}") + config.forceRemoveID(idToRemove) + trackerConfigQueue.send(config) + else: + print("No tracked IDs available to force remove") + diff --git a/src/DatatypeBindings.cpp b/src/DatatypeBindings.cpp index 68d3dec47..aae56fbf3 100644 --- a/src/DatatypeBindings.cpp +++ b/src/DatatypeBindings.cpp @@ -28,6 +28,7 @@ void bind_tracklets(pybind11::module& m, void* pCallstack); void bind_pointcloudconfig(pybind11::module& m, void* pCallstack); void bind_pointclouddata(pybind11::module& m, void* pCallstack); void bind_imagealignconfig(pybind11::module& m, void* pCallstack); +void bind_objecttrackerconfig(pybind11::module& m, void* pCallstack); void DatatypeBindings::addToCallstack(std::deque& callstack) { // Bind common datatypebindings @@ -59,6 +60,7 @@ void DatatypeBindings::addToCallstack(std::deque& callstack) { callstack.push_front(bind_pointcloudconfig); callstack.push_front(bind_pointclouddata); callstack.push_front(bind_imagealignconfig); + callstack.push_front(bind_objecttrackerconfig); } void DatatypeBindings::bind(pybind11::module& m, void* pCallstack){ diff --git a/src/pipeline/datatype/ObjectTrackerConfigBindings.cpp b/src/pipeline/datatype/ObjectTrackerConfigBindings.cpp new file mode 100644 index 000000000..bc656c024 --- /dev/null +++ b/src/pipeline/datatype/ObjectTrackerConfigBindings.cpp @@ -0,0 +1,54 @@ +#include "DatatypeBindings.hpp" +#include "pipeline/CommonBindings.hpp" +#include +#include + +// depthai +#include "depthai/pipeline/datatype/ObjectTrackerConfig.hpp" + +//pybind +#include +#include + +// #include "spdlog/spdlog.h" + +void bind_objecttrackerconfig(pybind11::module& m, void* pCallstack){ + + using namespace dai; + + py::class_> rawConfig(m, "RawObjectTrackerConfig", DOC(dai, RawObjectTrackerConfig)); + py::class_> config(m, "ObjectTrackerConfig", DOC(dai, ObjectTrackerConfig)); + + /////////////////////////////////////////////////////////////////////// + /////////////////////////////////////////////////////////////////////// + /////////////////////////////////////////////////////////////////////// + // Call the rest of the type defines, then perform the actual bindings + Callstack* callstack = (Callstack*) pCallstack; + auto cb = callstack->top(); + callstack->pop(); + cb(m, pCallstack); + // Actual bindings + /////////////////////////////////////////////////////////////////////// + /////////////////////////////////////////////////////////////////////// + /////////////////////////////////////////////////////////////////////// + + // Metadata / raw + rawConfig + .def(py::init<>()) + .def_readwrite("trackletIdsToRemove", &RawObjectTrackerConfig::trackletIdsToRemove, DOC(dai, RawObjectTrackerConfig, trackletIdsToRemove)) + ; + + // Message + config + .def(py::init<>()) + .def(py::init>()) + + .def("set", &ObjectTrackerConfig::set, py::arg("config"), DOC(dai, ObjectTrackerConfig, set)) + .def("get", &ObjectTrackerConfig::get, DOC(dai, ObjectTrackerConfig, get)) + .def("forceRemoveID", &ObjectTrackerConfig::forceRemoveID, DOC(dai, ObjectTrackerConfig, forceRemoveID)) + .def("forceRemoveIDs", &ObjectTrackerConfig::forceRemoveIDs, DOC(dai, ObjectTrackerConfig, forceRemoveIDs)) + ; + + // add aliases + +} diff --git a/src/pipeline/node/ObjectTrackerBindings.cpp b/src/pipeline/node/ObjectTrackerBindings.cpp index b2b2b76a3..a06c6499f 100644 --- a/src/pipeline/node/ObjectTrackerBindings.cpp +++ b/src/pipeline/node/ObjectTrackerBindings.cpp @@ -48,6 +48,10 @@ void bind_objecttracker(pybind11::module& m, void* pCallstack){ .def_readwrite("detectionLabelsToTrack", &ObjectTrackerProperties::detectionLabelsToTrack, DOC(dai, ObjectTrackerProperties, detectionLabelsToTrack)) .def_readwrite("trackerType", &ObjectTrackerProperties::trackerType, DOC(dai, ObjectTrackerProperties, trackerType)) .def_readwrite("trackerIdAssignmentPolicy", &ObjectTrackerProperties::trackerIdAssignmentPolicy, DOC(dai, ObjectTrackerProperties, trackerIdAssignmentPolicy)) + .def_readwrite("trackingPerClass", &ObjectTrackerProperties::trackingPerClass, DOC(dai, ObjectTrackerProperties, trackingPerClass)) + .def_readwrite("occlusionRatioThreshold", &ObjectTrackerProperties::occlusionRatioThreshold, DOC(dai, ObjectTrackerProperties, occlusionRatioThreshold)) + .def_readwrite("trackletMaxLifespan", &ObjectTrackerProperties::trackletMaxLifespan, DOC(dai, ObjectTrackerProperties, trackletMaxLifespan)) + .def_readwrite("trackletBirthThreshold", &ObjectTrackerProperties::trackletBirthThreshold, DOC(dai, ObjectTrackerProperties, trackletBirthThreshold)) ; // Node @@ -55,6 +59,7 @@ void bind_objecttracker(pybind11::module& m, void* pCallstack){ .def_readonly("inputTrackerFrame", &ObjectTracker::inputTrackerFrame, DOC(dai, node, ObjectTracker, inputTrackerFrame)) .def_readonly("inputDetectionFrame", &ObjectTracker::inputDetectionFrame, DOC(dai, node, ObjectTracker, inputDetectionFrame)) .def_readonly("inputDetections", &ObjectTracker::inputDetections, DOC(dai, node, ObjectTracker, inputDetections)) + .def_readonly("inputConfig", &ObjectTracker::inputConfig, DOC(dai, node, ObjectTracker, inputConfig)) .def_readonly("out", &ObjectTracker::out, DOC(dai, node, ObjectTracker, out)) .def_readonly("passthroughTrackerFrame", &ObjectTracker::passthroughTrackerFrame, DOC(dai, node, ObjectTracker, passthroughTrackerFrame)) .def_readonly("passthroughDetectionFrame", &ObjectTracker::passthroughDetectionFrame, DOC(dai, node, ObjectTracker, passthroughDetectionFrame)) @@ -66,6 +71,9 @@ void bind_objecttracker(pybind11::module& m, void* pCallstack){ .def("setTrackerType", &ObjectTracker::setTrackerType, py::arg("type"), DOC(dai, node, ObjectTracker, setTrackerType)) .def("setTrackerIdAssignmentPolicy", &ObjectTracker::setTrackerIdAssignmentPolicy, py::arg("type"), DOC(dai, node, ObjectTracker, setTrackerIdAssignmentPolicy)) .def("setTrackingPerClass", &ObjectTracker::setTrackingPerClass, py::arg("trackingPerClass"), DOC(dai, node, ObjectTracker, setTrackingPerClass)) + .def("setOcclusionRatioThreshold", &ObjectTracker::setOcclusionRatioThreshold, py::arg("occlusionRatioThreshold"), DOC(dai, node, ObjectTracker, setOcclusionRatioThreshold)) + .def("setTrackletMaxLifespan", &ObjectTracker::setTrackletMaxLifespan, py::arg("lifespan"), DOC(dai, node, ObjectTracker, setTrackletMaxLifespan)) + .def("setTrackletBirthThreshold", &ObjectTracker::setTrackletBirthThreshold, py::arg("threshold"), DOC(dai, node, ObjectTracker, setTrackletBirthThreshold)) ; daiNodeModule.attr("ObjectTracker").attr("Properties") = objectTrackerProperties;