diff --git a/deps/cloudxr/openxr_extensions/XR_NV_opaque_data_channel.h b/deps/cloudxr/openxr_extensions/XR_NV_opaque_data_channel.h new file mode 100644 index 000000000..21e940dfd --- /dev/null +++ b/deps/cloudxr/openxr_extensions/XR_NV_opaque_data_channel.h @@ -0,0 +1,68 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 OR MIT +/*! + * @file + * @brief Header for XR_NV_opaque_data_channel extension. + */ +#ifndef XR_NV_OPAQUE_DATA_CHANNEL_H +#define XR_NV_OPAQUE_DATA_CHANNEL_H 1 + +#include "openxr_extension_helpers.h" + +#ifdef __cplusplus +extern "C" { +#endif + +#define XR_NV_opaque_data_channel 1 +#define XR_NV_opaque_data_channel_SPEC_VERSION 1 +#define XR_NV_OPAQUE_DATA_CHANNEL_EXTENSION_NAME "XR_NV_opaque_data_channel" + +XR_DEFINE_HANDLE(XrOpaqueDataChannelNV) + +XR_STRUCT_ENUM(XR_TYPE_OPAQUE_DATA_CHANNEL_CREATE_INFO_NV, 1000526001); +XR_STRUCT_ENUM(XR_TYPE_OPAQUE_DATA_CHANNEL_STATE_NV, 1000526002); + +XR_RESULT_ENUM(XR_ERROR_CHANNEL_ALREADY_CREATED_NV, -1000526000); +XR_RESULT_ENUM(XR_ERROR_CHANNEL_NOT_CONNECTED_NV, -1000526001); + +typedef enum XrOpaqueDataChannelStatusNV { + XR_OPAQUE_DATA_CHANNEL_STATUS_CONNECTING_NV = 0, + XR_OPAQUE_DATA_CHANNEL_STATUS_CONNECTED_NV = 1, + XR_OPAQUE_DATA_CHANNEL_STATUS_SHUTTING_NV = 2, + XR_OPAQUE_DATA_CHANNEL_STATUS_DISCONNECTED_NV = 3, + XR_OPAQUE_DATA_CHANNEL_STATUS_MAX_ENUM = 0x7FFFFFFF, +} XrOpaqueDataChannelStatusNV; + +typedef struct XrOpaqueDataChannelCreateInfoNV { + XrStructureType type; + const void* next; + XrSystemId systemId; + XrUuidEXT uuid; +} XrOpaqueDataChannelCreateInfoNV; + +typedef struct XrOpaqueDataChannelStateNV { + XrStructureType type; + void* next; + XrOpaqueDataChannelStatusNV state; +} XrOpaqueDataChannelStateNV; + +typedef XrResult(XRAPI_PTR* PFN_xrCreateOpaqueDataChannelNV)(XrInstance instance, + const XrOpaqueDataChannelCreateInfoNV* createInfo, + XrOpaqueDataChannelNV* opaqueDataChannel); +typedef XrResult(XRAPI_PTR* PFN_xrDestroyOpaqueDataChannelNV)(XrOpaqueDataChannelNV opaqueDataChannel); +typedef XrResult(XRAPI_PTR* PFN_xrGetOpaqueDataChannelStateNV)(XrOpaqueDataChannelNV opaqueDataChannel, + XrOpaqueDataChannelStateNV* state); +typedef XrResult(XRAPI_PTR* PFN_xrSendOpaqueDataChannelNV)(XrOpaqueDataChannelNV opaqueDataChannel, + uint32_t opaqueDataInputCount, + const uint8_t* opaqueDatas); +typedef XrResult(XRAPI_PTR* PFN_xrReceiveOpaqueDataChannelNV)(XrOpaqueDataChannelNV opaqueDataChannel, + uint32_t opaqueDataCapacityInput, + uint32_t* opaqueDataCountOutput, + uint8_t* opaqueDatas); +typedef XrResult(XRAPI_PTR* PFN_xrShutdownOpaqueDataChannelNV)(XrOpaqueDataChannelNV opaqueDataChannel); + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/examples/teleop_session_manager/python/message_channel_example.py b/examples/teleop_session_manager/python/message_channel_example.py new file mode 100755 index 000000000..892b860f8 --- /dev/null +++ b/examples/teleop_session_manager/python/message_channel_example.py @@ -0,0 +1,141 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Message channel example using TeleopSession + retargeting source/sink nodes. + +Behavior: +- Prints any incoming messages each frame. +- Once channel status is CONNECTED, sends one message every second. +""" + +import argparse +import sys +import time +import uuid + +from isaacteleop.retargeting_engine.deviceio_source_nodes import ( + MessageChannelConnectionStatus, + message_channel_config, +) +from isaacteleop.retargeting_engine.interface import TensorGroup +from isaacteleop.schema import MessageChannelMessages, MessageChannelMessagesTrackedT +from isaacteleop.teleop_session_manager import TeleopSession, TeleopSessionConfig + + +def _positive_int(value: str) -> int: + """Argparse type= callable that rejects non-positive integers.""" + try: + n = int(value) + except ValueError: + raise argparse.ArgumentTypeError(f"expected a positive integer, got {value!r}") + if n <= 0: + raise argparse.ArgumentTypeError(f"must be a positive integer, got {n}") + return n + + +def _parse_uuid_bytes(uuid_text: str) -> bytes: + """Parse canonical UUID text to 16-byte payload (argparse type= callable).""" + try: + return uuid.UUID(uuid_text).bytes + except ValueError: + raise argparse.ArgumentTypeError( + f"--channel-uuid: invalid UUID {uuid_text!r} (expected canonical form, " + "e.g. 550e8400-e29b-41d4-a716-446655440000)" + ) + + +def _enqueue_outbound_message(sink, payload: bytes) -> None: + """Push one outbound message through MessageChannelSink.""" + tg = TensorGroup(sink.input_spec()["messages_tracked"]) + tg[0] = MessageChannelMessagesTrackedT([MessageChannelMessages(payload)]) + sink.compute({"messages_tracked": tg}, {}) + + +def main() -> int: + parser = argparse.ArgumentParser( + description="Message channel TeleopSession example" + ) + parser.add_argument( + "--channel-uuid", + type=_parse_uuid_bytes, + required=True, + help="Message channel UUID (canonical form, e.g. 550e8400-e29b-41d4-a716-446655440000)", + ) + parser.add_argument( + "--channel-name", + type=str, + default="example_message_channel", + help="Optional channel display name", + ) + parser.add_argument( + "--outbound-queue-capacity", + type=_positive_int, + default=256, + help="Bounded outbound queue length", + ) + args = parser.parse_args() + + source, sink = message_channel_config( + name="message_channel", + channel_uuid=args.channel_uuid, + channel_name=args.channel_name, + outbound_queue_capacity=args.outbound_queue_capacity, + ) + + config = TeleopSessionConfig( + app_name="MessageChannelExample", + pipeline=source, + ) + + print("=" * 80) + print("Message Channel TeleopSession Example") + print("=" * 80) + print(f"Channel UUID: {args.channel_uuid}") + print(f"Channel Name: {args.channel_name}") + print("Press Ctrl+C to exit.") + print() + + send_counter = 0 + last_send_time = 0.0 + + with TeleopSession(config) as session: + while True: + result = session.step() + status = result["status"][0] + messages_tracked = result["messages_tracked"][0] + messages = ( + messages_tracked.data if messages_tracked.data is not None else [] + ) + + for msg in messages: + payload = bytes(msg.payload) + try: + decoded = payload.decode("utf-8") + print(f"[rx] {decoded}") + except UnicodeDecodeError: + print(f"[rx] 0x{payload.hex()}") + + now = time.monotonic() + if ( + status == MessageChannelConnectionStatus.CONNECTED + and now - last_send_time >= 1.0 + ): + payload_text = f"hello #{send_counter} @ {time.time():.3f}" + _enqueue_outbound_message(sink, payload_text.encode("utf-8")) + print(f"[tx] {payload_text}") + last_send_time = now + send_counter += 1 + + time.sleep(0.01) + + return 0 + + +if __name__ == "__main__": + try: + sys.exit(main()) + except KeyboardInterrupt: + print("\nExiting.") + sys.exit(0) diff --git a/src/core/deviceio_base/cpp/inc/deviceio_base/message_channel_tracker_base.hpp b/src/core/deviceio_base/cpp/inc/deviceio_base/message_channel_tracker_base.hpp new file mode 100644 index 000000000..d0ad376d0 --- /dev/null +++ b/src/core/deviceio_base/cpp/inc/deviceio_base/message_channel_tracker_base.hpp @@ -0,0 +1,34 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "tracker.hpp" + +#include +#include + +namespace core +{ + +struct MessageChannelMessagesT; +struct MessageChannelMessagesTrackedT; + +enum class MessageChannelStatus : int32_t +{ + CONNECTING = 0, + CONNECTED = 1, + SHUTTING = 2, + DISCONNECTED = 3, + UNKNOWN = -1, +}; + +class IMessageChannelTrackerImpl : public ITrackerImpl +{ +public: + virtual MessageChannelStatus get_status() const = 0; + virtual const MessageChannelMessagesTrackedT& get_messages() const = 0; + virtual void send_message(const std::vector& payload) const = 0; +}; + +} // namespace core diff --git a/src/core/deviceio_trackers/AGENTS.md b/src/core/deviceio_trackers/AGENTS.md index 9b9816352..64649df18 100644 --- a/src/core/deviceio_trackers/AGENTS.md +++ b/src/core/deviceio_trackers/AGENTS.md @@ -10,6 +10,7 @@ SPDX-License-Identifier: Apache-2.0 ## No OpenXR dependency - **`deviceio_trackers`** must **not** link **`OpenXR::headers`**, **`oxr::oxr_utils`**, or vendor extension targets, and must **not** `#include` OpenXR headers. Public API stays schema + **`deviceio_base`** only. +- This includes **`tracker_bindings.cpp`**: do not add `#include ` or any `XR_NV_*` extension headers here, even when the bound tracker wraps an OpenXR concept. The UUID is `std::array` at the `deviceio_trackers` boundary—no OpenXR types leak through. ## Related docs diff --git a/src/core/deviceio_trackers/cpp/CMakeLists.txt b/src/core/deviceio_trackers/cpp/CMakeLists.txt index 0ea801774..48b460d71 100644 --- a/src/core/deviceio_trackers/cpp/CMakeLists.txt +++ b/src/core/deviceio_trackers/cpp/CMakeLists.txt @@ -8,12 +8,14 @@ add_library(deviceio_trackers STATIC hand_tracker.cpp head_tracker.cpp controller_tracker.cpp + message_channel_tracker.cpp generic_3axis_pedal_tracker.cpp frame_metadata_tracker_oak.cpp full_body_tracker_pico.cpp inc/deviceio_trackers/head_tracker.hpp inc/deviceio_trackers/hand_tracker.hpp inc/deviceio_trackers/controller_tracker.hpp + inc/deviceio_trackers/message_channel_tracker.hpp inc/deviceio_trackers/full_body_tracker_pico.hpp inc/deviceio_trackers/generic_3axis_pedal_tracker.hpp inc/deviceio_trackers/frame_metadata_tracker_oak.hpp diff --git a/src/core/deviceio_trackers/cpp/inc/deviceio_trackers/message_channel_tracker.hpp b/src/core/deviceio_trackers/cpp/inc/deviceio_trackers/message_channel_tracker.hpp new file mode 100644 index 000000000..74a152af2 --- /dev/null +++ b/src/core/deviceio_trackers/cpp/inc/deviceio_trackers/message_channel_tracker.hpp @@ -0,0 +1,61 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace core +{ + +class MessageChannelTracker : public ITracker +{ +public: + static constexpr size_t DEFAULT_MAX_MESSAGE_SIZE = 64 * 1024; + static constexpr size_t CHANNEL_UUID_SIZE = 16; + + explicit MessageChannelTracker(const std::array& channel_uuid, + const std::string& channel_name = "", + size_t max_message_size = DEFAULT_MAX_MESSAGE_SIZE); + + std::string_view get_name() const override + { + return TRACKER_NAME; + } + + MessageChannelStatus get_status(const ITrackerSession& session) const; + const MessageChannelMessagesTrackedT& get_messages(const ITrackerSession& session) const; + void send_message(const ITrackerSession& session, const std::vector& payload) const; + + const std::array& channel_uuid() const + { + return channel_uuid_; + } + + const std::string& channel_name() const + { + return channel_name_; + } + + size_t max_message_size() const + { + return max_message_size_; + } + +private: + static constexpr const char* TRACKER_NAME = "MessageChannelTracker"; + + std::array channel_uuid_{}; + std::string channel_name_; + size_t max_message_size_{ DEFAULT_MAX_MESSAGE_SIZE }; +}; + +} // namespace core diff --git a/src/core/deviceio_trackers/cpp/message_channel_tracker.cpp b/src/core/deviceio_trackers/cpp/message_channel_tracker.cpp new file mode 100644 index 000000000..3774c06b2 --- /dev/null +++ b/src/core/deviceio_trackers/cpp/message_channel_tracker.cpp @@ -0,0 +1,37 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +#include "inc/deviceio_trackers/message_channel_tracker.hpp" + +#include + +namespace core +{ + +MessageChannelTracker::MessageChannelTracker(const std::array& channel_uuid, + const std::string& channel_name, + size_t max_message_size) + : channel_uuid_(channel_uuid), channel_name_(channel_name), max_message_size_(max_message_size) +{ + if (max_message_size_ == 0) + { + throw std::invalid_argument("MessageChannelTracker: max_message_size must be > 0"); + } +} + +MessageChannelStatus MessageChannelTracker::get_status(const ITrackerSession& session) const +{ + return static_cast(session.get_tracker_impl(*this)).get_status(); +} + +const MessageChannelMessagesTrackedT& MessageChannelTracker::get_messages(const ITrackerSession& session) const +{ + return static_cast(session.get_tracker_impl(*this)).get_messages(); +} + +void MessageChannelTracker::send_message(const ITrackerSession& session, const std::vector& payload) const +{ + static_cast(session.get_tracker_impl(*this)).send_message(payload); +} + +} // namespace core diff --git a/src/core/deviceio_trackers/python/deviceio_trackers_init.py b/src/core/deviceio_trackers/python/deviceio_trackers_init.py index e1bda3ec3..f867e8f54 100644 --- a/src/core/deviceio_trackers/python/deviceio_trackers_init.py +++ b/src/core/deviceio_trackers/python/deviceio_trackers_init.py @@ -8,6 +8,8 @@ HandTracker, HeadTracker, ControllerTracker, + MessageChannelStatus, + MessageChannelTracker, FrameMetadataTrackerOak, Generic3AxisPedalTracker, FullBodyTrackerPico, @@ -21,6 +23,8 @@ __all__ = [ "ControllerTracker", + "MessageChannelStatus", + "MessageChannelTracker", "FrameMetadataTrackerOak", "FullBodyTrackerPico", "Generic3AxisPedalTracker", diff --git a/src/core/deviceio_trackers/python/tracker_bindings.cpp b/src/core/deviceio_trackers/python/tracker_bindings.cpp index 494e72233..7b339d57f 100644 --- a/src/core/deviceio_trackers/python/tracker_bindings.cpp +++ b/src/core/deviceio_trackers/python/tracker_bindings.cpp @@ -7,9 +7,15 @@ #include #include #include +#include #include #include #include +#include + +#include +#include +#include namespace py = pybind11; @@ -58,6 +64,47 @@ PYBIND11_MODULE(_deviceio_trackers, m) { return self.get_right_controller(session); }, py::arg("session"), "Get the right controller tracked state (data is None if inactive)"); + py::enum_(m, "MessageChannelStatus") + .value("CONNECTING", core::MessageChannelStatus::CONNECTING) + .value("CONNECTED", core::MessageChannelStatus::CONNECTED) + .value("SHUTTING", core::MessageChannelStatus::SHUTTING) + .value("DISCONNECTED", core::MessageChannelStatus::DISCONNECTED) + .value("UNKNOWN", core::MessageChannelStatus::UNKNOWN); + + py::class_>( + m, "MessageChannelTracker") + .def(py::init( + [](py::bytes channel_uuid, const std::string& channel_name, size_t max_message_size) + { + std::string uuid_str = channel_uuid; + if (uuid_str.size() != core::MessageChannelTracker::CHANNEL_UUID_SIZE) + { + throw std::invalid_argument("MessageChannelTracker: channel_uuid must be exactly 16 bytes"); + } + std::array uuid{}; + std::memcpy(uuid.data(), uuid_str.data(), uuid.size()); + return std::make_shared(uuid, channel_name, max_message_size); + }), + py::arg("channel_uuid"), py::arg("channel_name") = "", + py::arg("max_message_size") = core::MessageChannelTracker::DEFAULT_MAX_MESSAGE_SIZE, + "Construct a MessageChannelTracker for XR_NV_opaque_data_channel") + .def( + "get_messages", + [](const core::MessageChannelTracker& self, + const core::ITrackerSession& session) -> core::MessageChannelMessagesTrackedT + { return self.get_messages(session); }, + py::arg("session"), "Get all messages drained during the last update (possibly empty)") + .def( + "get_status", + [](const core::MessageChannelTracker& self, const core::ITrackerSession& session) -> core::MessageChannelStatus + { return self.get_status(session); }, + py::arg("session"), "Get current channel connection state") + .def( + "send_message", + [](const core::MessageChannelTracker& self, const core::ITrackerSession& session, + const core::MessageChannelMessagesT& message) { self.send_message(session, message.payload); }, + py::arg("session"), py::arg("message"), "Send a MessageChannelMessages payload over the message channel"); + py::class_>( m, "FrameMetadataTrackerOak") .def(py::init&, size_t>(), py::arg("collection_prefix"), diff --git a/src/core/live_trackers/AGENTS.md b/src/core/live_trackers/AGENTS.md index 41558202c..8d882310d 100644 --- a/src/core/live_trackers/AGENTS.md +++ b/src/core/live_trackers/AGENTS.md @@ -28,6 +28,19 @@ SPDX-License-Identifier: Apache-2.0 - **`live_trackers`** should **`PUBLIC` link `oxr::oxr_utils`** (OpenXR headers come through that INTERFACE target) because headers/sources use OpenXR / oxr types. +## New tracker MCAP checklist + +When adding MCAP support to a new tracker impl, all of the following are required together—missing any one causes a build failure or wrong timestamps: + +1. Add `XrTimeConverter time_converter_` and `int64_t last_update_time_ = 0` members to the impl header. +2. Initialize `time_converter_(handles)` in the constructor initializer list. +3. Declare `update(int64_t monotonic_time_ns) override` (not `XrTime`)—they are the same C++ type (`int64_t`) but semantically different; the base interface uses monotonic ns. +4. At the top of `update()`: store `last_update_time_ = monotonic_time_ns` and compute `const XrTime xr_time = time_converter_.convert_monotonic_ns_to_xrtime(monotonic_time_ns)`. +5. Use `DeviceDataTimestamp(last_update_time_, last_update_time_, xr_time)` — not `(time, time, time)`. +6. Add `MessageChannelRecordingTraits` (or equivalent) to `recording_traits.hpp`. +7. **Always build** (`cmake --build -- -j$(nproc)`) before treating work as done. Pre-commit alone does not catch compile errors or clang-format violations enforced at build time. +8. Read `AGENTS.md` before starting. Not after CI breaks. + ## Related docs - Session update loop: [`../deviceio_session/AGENTS.md`](../deviceio_session/AGENTS.md) diff --git a/src/core/live_trackers/cpp/CMakeLists.txt b/src/core/live_trackers/cpp/CMakeLists.txt index a2e429215..23d105b7d 100644 --- a/src/core/live_trackers/cpp/CMakeLists.txt +++ b/src/core/live_trackers/cpp/CMakeLists.txt @@ -9,6 +9,7 @@ add_library(live_trackers STATIC live_head_tracker_impl.cpp live_hand_tracker_impl.cpp live_controller_tracker_impl.cpp + live_message_channel_tracker_impl.cpp live_full_body_tracker_pico_impl.cpp live_generic_3axis_pedal_tracker_impl.cpp live_frame_metadata_tracker_oak_impl.cpp @@ -18,6 +19,7 @@ add_library(live_trackers STATIC live_head_tracker_impl.hpp live_hand_tracker_impl.hpp live_controller_tracker_impl.hpp + live_message_channel_tracker_impl.hpp live_full_body_tracker_pico_impl.hpp live_generic_3axis_pedal_tracker_impl.hpp live_frame_metadata_tracker_oak_impl.hpp diff --git a/src/core/live_trackers/cpp/inc/live_trackers/live_deviceio_factory.hpp b/src/core/live_trackers/cpp/inc/live_trackers/live_deviceio_factory.hpp index 50b17fdc5..7d6b5c4f9 100644 --- a/src/core/live_trackers/cpp/inc/live_trackers/live_deviceio_factory.hpp +++ b/src/core/live_trackers/cpp/inc/live_trackers/live_deviceio_factory.hpp @@ -24,6 +24,8 @@ class ControllerTracker; class IControllerTrackerImpl; class FrameMetadataTrackerOak; class IFrameMetadataTrackerOakImpl; +class MessageChannelTracker; +class IMessageChannelTrackerImpl; class FullBodyTrackerPico; class IFullBodyTrackerPicoImpl; class Generic3AxisPedalTracker; @@ -56,6 +58,7 @@ class LiveDeviceIOFactory std::unique_ptr create_head_tracker_impl(const HeadTracker* tracker); std::unique_ptr create_hand_tracker_impl(const HandTracker* tracker); std::unique_ptr create_controller_tracker_impl(const ControllerTracker* tracker); + std::unique_ptr create_message_channel_tracker_impl(const MessageChannelTracker* tracker); std::unique_ptr create_full_body_tracker_pico_impl(const FullBodyTrackerPico* tracker); std::unique_ptr create_generic_3axis_pedal_tracker_impl( const Generic3AxisPedalTracker* tracker); diff --git a/src/core/live_trackers/cpp/live_deviceio_factory.cpp b/src/core/live_trackers/cpp/live_deviceio_factory.cpp index 1160f7996..2c304480c 100644 --- a/src/core/live_trackers/cpp/live_deviceio_factory.cpp +++ b/src/core/live_trackers/cpp/live_deviceio_factory.cpp @@ -9,6 +9,7 @@ #include "live_generic_3axis_pedal_tracker_impl.hpp" #include "live_hand_tracker_impl.hpp" #include "live_head_tracker_impl.hpp" +#include "live_message_channel_tracker_impl.hpp" #include #include @@ -16,6 +17,7 @@ #include #include #include +#include #include #include @@ -59,6 +61,12 @@ std::unique_ptr try_create_controller_impl(LiveDeviceIOFactory& fa return typed ? factory.create_controller_tracker_impl(typed) : nullptr; } +std::unique_ptr try_create_message_channel_impl(LiveDeviceIOFactory& factory, const ITracker& tracker) +{ + auto* typed = dynamic_cast(&tracker); + return typed ? factory.create_message_channel_tracker_impl(typed) : nullptr; +} + std::unique_ptr try_create_full_body_pico_impl(LiveDeviceIOFactory& factory, const ITracker& tracker) { auto* typed = dynamic_cast(&tracker); @@ -91,6 +99,7 @@ inline const TrackerDispatchEntry k_tracker_dispatch[] = { { &try_add_extensions, &try_create_head_impl }, { &try_add_extensions, &try_create_hand_impl }, { &try_add_extensions, &try_create_controller_impl }, + { &try_add_extensions, &try_create_message_channel_impl }, { &try_add_extensions, &try_create_full_body_pico_impl }, { &try_add_extensions, &try_create_generic_pedal_impl }, { &try_add_extensions, &try_create_oak_impl }, @@ -202,6 +211,17 @@ std::unique_ptr LiveDeviceIOFactory::create_controller_t return std::make_unique(handles_, std::move(channels)); } +std::unique_ptr LiveDeviceIOFactory::create_message_channel_tracker_impl( + const MessageChannelTracker* tracker) +{ + std::unique_ptr channels; + if (should_record(tracker)) + { + channels = LiveMessageChannelTrackerImpl::create_mcap_channels(*writer_, get_name(tracker)); + } + return std::make_unique(handles_, tracker, std::move(channels)); +} + std::unique_ptr LiveDeviceIOFactory::create_full_body_tracker_pico_impl( const FullBodyTrackerPico* tracker) { diff --git a/src/core/live_trackers/cpp/live_message_channel_tracker_impl.cpp b/src/core/live_trackers/cpp/live_message_channel_tracker_impl.cpp new file mode 100644 index 000000000..cb4815d72 --- /dev/null +++ b/src/core/live_trackers/cpp/live_message_channel_tracker_impl.cpp @@ -0,0 +1,327 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +#include "live_message_channel_tracker_impl.hpp" + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace core +{ + +std::unique_ptr LiveMessageChannelTrackerImpl::create_mcap_channels(mcap::McapWriter& writer, + std::string_view base_name) +{ + return std::make_unique( + writer, base_name, MessageChannelRecordingTraits::schema_name, + std::vector( + MessageChannelRecordingTraits::channels.begin(), MessageChannelRecordingTraits::channels.end())); +} + +LiveMessageChannelTrackerImpl::LiveMessageChannelTrackerImpl(const OpenXRSessionHandles& handles, + const MessageChannelTracker* tracker, + std::unique_ptr mcap_channels) + : handles_(handles), tracker_(tracker), time_converter_(handles), mcap_channels_(std::move(mcap_channels)) +{ + if (handles_.instance == XR_NULL_HANDLE) + throw std::invalid_argument("LiveMessageChannelTrackerImpl: handles_.instance is XR_NULL_HANDLE"); + if (handles_.session == XR_NULL_HANDLE) + throw std::invalid_argument("LiveMessageChannelTrackerImpl: handles_.session is XR_NULL_HANDLE"); + if (!handles_.xrGetInstanceProcAddr) + throw std::invalid_argument("LiveMessageChannelTrackerImpl: handles_.xrGetInstanceProcAddr is null"); + if (!tracker_) + throw std::invalid_argument("LiveMessageChannelTrackerImpl: tracker_ is null"); + + receive_buffer_.resize(tracker_->max_message_size(), 0); + + initialize_functions(); + + system_id_ = resolve_system_id(); + channel_uuid_ = make_uuid(tracker_->channel_uuid()); + create_channel(); +} + +LiveMessageChannelTrackerImpl::~LiveMessageChannelTrackerImpl() +{ + destroy_channel(); +} + +void LiveMessageChannelTrackerImpl::update(int64_t monotonic_time_ns) +{ + last_update_time_ = monotonic_time_ns; + const XrTime xr_time = time_converter_.convert_monotonic_ns_to_xrtime(monotonic_time_ns); + + messages_.data.clear(); + + const MessageChannelStatus status = query_status(); + if (status == MessageChannelStatus::DISCONNECTED) + { + // Runtime/client disconnected: rebuild the channel object so it can reconnect. + try_reopen_channel(); + return; + } + + if (status != MessageChannelStatus::CONNECTED) + { + return; + } + + while (true) + { + // First call: query pending byte count without reading. + uint32_t count_out = 0; + XrResult query_result = receive_fn_(channel_, 0, &count_out, nullptr); + if (query_result != XR_SUCCESS) + { + if (query_result == XR_ERROR_CHANNEL_NOT_CONNECTED_NV) + { + return; + } + throw std::runtime_error("LiveMessageChannelTrackerImpl: xrReceiveOpaqueDataChannelNV (query) failed, result=" + + std::to_string(query_result)); + } + + if (count_out == 0) + { + break; + } + + if (count_out > tracker_->max_message_size()) + { + // Drain the oversized message to unblock the queue, but discard the data. + // Read in bounded chunks using the pre-allocated receive_buffer_ to avoid + // allocating a buffer sized by the untrusted remote-supplied count_out. + std::cerr << "[LiveMessageChannelTrackerImpl] Dropping oversized message (" << count_out << " bytes, max " + << tracker_->max_message_size() << ")" << std::endl; + uint32_t remaining = count_out; + while (remaining > 0) + { + uint32_t chunk = std::min(remaining, static_cast(receive_buffer_.size())); + uint32_t drained = 0; + XrResult drain_result = receive_fn_(channel_, chunk, &drained, receive_buffer_.data()); + if (drain_result != XR_SUCCESS || drained == 0) + { + break; + } + remaining -= drained; + } + if (remaining > 0) + { + // Drain failed; stop processing to avoid re-querying the same oversized message. + break; + } + continue; + } + + // Second call: read the message into the pre-allocated buffer. + uint32_t read_count = 0; + XrResult recv_result = receive_fn_(channel_, count_out, &read_count, receive_buffer_.data()); + if (recv_result != XR_SUCCESS) + { + if (recv_result == XR_ERROR_CHANNEL_NOT_CONNECTED_NV) + { + return; + } + throw std::runtime_error("LiveMessageChannelTrackerImpl: xrReceiveOpaqueDataChannelNV (read) failed, result=" + + std::to_string(recv_result)); + } + + auto message = std::make_shared(); + message->payload.assign(receive_buffer_.begin(), receive_buffer_.begin() + read_count); + messages_.data.push_back(message); + } + + if (mcap_channels_) + { + DeviceDataTimestamp timestamp(last_update_time_, last_update_time_, xr_time); + for (const auto& msg : messages_.data) + { + mcap_channels_->write(0, timestamp, msg); + } + } +} + +MessageChannelStatus LiveMessageChannelTrackerImpl::get_status() const +{ + return query_status(); +} + +const MessageChannelMessagesTrackedT& LiveMessageChannelTrackerImpl::get_messages() const +{ + return messages_; +} + +void LiveMessageChannelTrackerImpl::send_message(const std::vector& payload) const +{ + if (channel_ == XR_NULL_HANDLE) + { + throw std::runtime_error("LiveMessageChannelTrackerImpl::send_message: channel is not open"); + } + + if (payload.size() > tracker_->max_message_size()) + { + throw std::runtime_error("LiveMessageChannelTrackerImpl::send_message: payload size " + + std::to_string(payload.size()) + " exceeds max_message_size " + + std::to_string(tracker_->max_message_size())); + } + + XrOpaqueDataChannelStateNV channel_state{ XR_TYPE_OPAQUE_DATA_CHANNEL_STATE_NV }; + XrResult state_result = get_state_fn_(channel_, &channel_state); + if (state_result != XR_SUCCESS) + { + throw std::runtime_error("LiveMessageChannelTrackerImpl: xrGetOpaqueDataChannelStateNV failed, result=" + + std::to_string(state_result)); + } + if (channel_state.state != XR_OPAQUE_DATA_CHANNEL_STATUS_CONNECTED_NV) + { + throw std::runtime_error("LiveMessageChannelTrackerImpl::send_message: channel is not connected"); + } + + const uint8_t* payload_ptr = payload.empty() ? nullptr : payload.data(); + XrResult send_result = send_fn_(channel_, static_cast(payload.size()), payload_ptr); + if (send_result != XR_SUCCESS) + { + throw std::runtime_error("LiveMessageChannelTrackerImpl: xrSendOpaqueDataChannelNV failed, result=" + + std::to_string(send_result)); + } +} + +void LiveMessageChannelTrackerImpl::initialize_functions() +{ + loadExtensionFunction(handles_.instance, handles_.xrGetInstanceProcAddr, "xrCreateOpaqueDataChannelNV", + reinterpret_cast(&create_channel_fn_)); + loadExtensionFunction(handles_.instance, handles_.xrGetInstanceProcAddr, "xrDestroyOpaqueDataChannelNV", + reinterpret_cast(&destroy_channel_fn_)); + loadExtensionFunction(handles_.instance, handles_.xrGetInstanceProcAddr, "xrGetOpaqueDataChannelStateNV", + reinterpret_cast(&get_state_fn_)); + loadExtensionFunction(handles_.instance, handles_.xrGetInstanceProcAddr, "xrSendOpaqueDataChannelNV", + reinterpret_cast(&send_fn_)); + loadExtensionFunction(handles_.instance, handles_.xrGetInstanceProcAddr, "xrReceiveOpaqueDataChannelNV", + reinterpret_cast(&receive_fn_)); + loadExtensionFunction(handles_.instance, handles_.xrGetInstanceProcAddr, "xrShutdownOpaqueDataChannelNV", + reinterpret_cast(&shutdown_fn_)); + loadExtensionFunction(handles_.instance, handles_.xrGetInstanceProcAddr, "xrGetSystem", + reinterpret_cast(&get_system_fn_)); +} + +XrSystemId LiveMessageChannelTrackerImpl::resolve_system_id() const +{ + XrSystemGetInfo get_info{ XR_TYPE_SYSTEM_GET_INFO }; + get_info.formFactor = XR_FORM_FACTOR_HEAD_MOUNTED_DISPLAY; + + XrSystemId system_id = XR_NULL_SYSTEM_ID; + XrResult result = get_system_fn_(handles_.instance, &get_info, &system_id); + if (result != XR_SUCCESS) + { + throw std::runtime_error("LiveMessageChannelTrackerImpl: xrGetSystem failed, result=" + std::to_string(result)); + } + return system_id; +} + +void LiveMessageChannelTrackerImpl::create_channel() +{ + XrOpaqueDataChannelCreateInfoNV create_info{ XR_TYPE_OPAQUE_DATA_CHANNEL_CREATE_INFO_NV }; + create_info.systemId = system_id_; + create_info.uuid = channel_uuid_; + + XrResult result = create_channel_fn_(handles_.instance, &create_info, &channel_); + if (result != XR_SUCCESS) + { + throw std::runtime_error("LiveMessageChannelTrackerImpl: xrCreateOpaqueDataChannelNV failed, result=" + + std::to_string(result)); + } +} + +void LiveMessageChannelTrackerImpl::destroy_channel() noexcept +{ + if (channel_ == XR_NULL_HANDLE) + { + return; + } + + if (shutdown_fn_) + { + XrResult result = shutdown_fn_(channel_); + if (result != XR_SUCCESS) + { + std::cerr << "[LiveMessageChannelTrackerImpl] xrShutdownOpaqueDataChannelNV failed, result=" << result + << std::endl; + } + } + if (destroy_channel_fn_) + { + XrResult result = destroy_channel_fn_(channel_); + if (result != XR_SUCCESS) + { + std::cerr << "[LiveMessageChannelTrackerImpl] xrDestroyOpaqueDataChannelNV failed, result=" << result + << std::endl; + } + } + channel_ = XR_NULL_HANDLE; +} + +bool LiveMessageChannelTrackerImpl::try_reopen_channel() +{ + try + { + destroy_channel(); + create_channel(); + return true; + } + catch (const std::exception& e) + { + std::cerr << "[LiveMessageChannelTrackerImpl] Failed to reopen message channel: " << e.what() << std::endl; + return false; + } +} + +MessageChannelStatus LiveMessageChannelTrackerImpl::query_status() const +{ + if (channel_ == XR_NULL_HANDLE) + { + // Channel was destroyed (e.g. failed reopen); report DISCONNECTED so + // update() will schedule a reopen attempt on the next frame. + return MessageChannelStatus::DISCONNECTED; + } + + XrOpaqueDataChannelStateNV channel_state{ XR_TYPE_OPAQUE_DATA_CHANNEL_STATE_NV }; + XrResult state_result = get_state_fn_(channel_, &channel_state); + if (state_result != XR_SUCCESS) + { + throw std::runtime_error("LiveMessageChannelTrackerImpl: xrGetOpaqueDataChannelStateNV failed, result=" + + std::to_string(state_result)); + } + + switch (channel_state.state) + { + case XR_OPAQUE_DATA_CHANNEL_STATUS_CONNECTING_NV: + return MessageChannelStatus::CONNECTING; + case XR_OPAQUE_DATA_CHANNEL_STATUS_CONNECTED_NV: + return MessageChannelStatus::CONNECTED; + case XR_OPAQUE_DATA_CHANNEL_STATUS_SHUTTING_NV: + return MessageChannelStatus::SHUTTING; + case XR_OPAQUE_DATA_CHANNEL_STATUS_DISCONNECTED_NV: + return MessageChannelStatus::DISCONNECTED; + default: + return MessageChannelStatus::UNKNOWN; + } +} + +XrUuidEXT LiveMessageChannelTrackerImpl::make_uuid( + const std::array& channel_uuid) const +{ + XrUuidEXT uuid{}; + std::memcpy(uuid.data, channel_uuid.data(), channel_uuid.size()); + return uuid; +} + +} // namespace core diff --git a/src/core/live_trackers/cpp/live_message_channel_tracker_impl.hpp b/src/core/live_trackers/cpp/live_message_channel_tracker_impl.hpp new file mode 100644 index 000000000..440311bdb --- /dev/null +++ b/src/core/live_trackers/cpp/live_message_channel_tracker_impl.hpp @@ -0,0 +1,80 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace core +{ + +using MessageChannelMcapChannels = McapTrackerChannels; + +class LiveMessageChannelTrackerImpl : public IMessageChannelTrackerImpl +{ +public: + static std::vector required_extensions() + { + return { XR_NV_OPAQUE_DATA_CHANNEL_EXTENSION_NAME }; + } + static std::unique_ptr create_mcap_channels(mcap::McapWriter& writer, + std::string_view base_name); + + LiveMessageChannelTrackerImpl(const OpenXRSessionHandles& handles, + const MessageChannelTracker* tracker, + std::unique_ptr mcap_channels = nullptr); + ~LiveMessageChannelTrackerImpl() override; + + LiveMessageChannelTrackerImpl(const LiveMessageChannelTrackerImpl&) = delete; + LiveMessageChannelTrackerImpl& operator=(const LiveMessageChannelTrackerImpl&) = delete; + LiveMessageChannelTrackerImpl(LiveMessageChannelTrackerImpl&&) = delete; + LiveMessageChannelTrackerImpl& operator=(LiveMessageChannelTrackerImpl&&) = delete; + + void update(int64_t monotonic_time_ns) override; + MessageChannelStatus get_status() const override; + const MessageChannelMessagesTrackedT& get_messages() const override; + void send_message(const std::vector& payload) const override; + +private: + void initialize_functions(); + XrSystemId resolve_system_id() const; + MessageChannelStatus query_status() const; + void create_channel(); + void destroy_channel() noexcept; + bool try_reopen_channel(); + XrUuidEXT make_uuid(const std::array& channel_uuid) const; + + OpenXRSessionHandles handles_; + const MessageChannelTracker* tracker_{ nullptr }; + XrSystemId system_id_{ XR_NULL_SYSTEM_ID }; + XrUuidEXT channel_uuid_{}; + XrOpaqueDataChannelNV channel_{ XR_NULL_HANDLE }; + + PFN_xrCreateOpaqueDataChannelNV create_channel_fn_{ nullptr }; + PFN_xrDestroyOpaqueDataChannelNV destroy_channel_fn_{ nullptr }; + PFN_xrGetOpaqueDataChannelStateNV get_state_fn_{ nullptr }; + PFN_xrSendOpaqueDataChannelNV send_fn_{ nullptr }; + PFN_xrReceiveOpaqueDataChannelNV receive_fn_{ nullptr }; + PFN_xrShutdownOpaqueDataChannelNV shutdown_fn_{ nullptr }; + PFN_xrGetSystem get_system_fn_{ nullptr }; + + XrTimeConverter time_converter_; + int64_t last_update_time_ = 0; + MessageChannelMessagesTrackedT messages_; + std::vector receive_buffer_; + std::unique_ptr mcap_channels_; +}; + +} // namespace core diff --git a/src/core/mcap/cpp/inc/mcap/recording_traits.hpp b/src/core/mcap/cpp/inc/mcap/recording_traits.hpp index 2e0d89db7..4790af798 100644 --- a/src/core/mcap/cpp/inc/mcap/recording_traits.hpp +++ b/src/core/mcap/cpp/inc/mcap/recording_traits.hpp @@ -52,4 +52,10 @@ struct OakRecordingTraits static constexpr std::string_view schema_name = "core.FrameMetadataOakRecord"; }; +struct MessageChannelRecordingTraits +{ + static constexpr std::string_view schema_name = "core.MessageChannelMessagesRecord"; + static constexpr std::array channels = { "message_channel" }; +}; + } // namespace core diff --git a/src/core/python/deviceio_init.py b/src/core/python/deviceio_init.py index f3ab4686b..81023b8dd 100644 --- a/src/core/python/deviceio_init.py +++ b/src/core/python/deviceio_init.py @@ -13,6 +13,8 @@ HandTracker, HeadTracker, ControllerTracker, + MessageChannelStatus, + MessageChannelTracker, FrameMetadataTrackerOak, Generic3AxisPedalTracker, FullBodyTrackerPico, @@ -49,6 +51,8 @@ "HandTracker", "HeadTracker", "ControllerTracker", + "MessageChannelStatus", + "MessageChannelTracker", "FrameMetadataTrackerOak", "Generic3AxisPedalTracker", "FullBodyTrackerPico", diff --git a/src/core/retargeting_engine/python/deviceio_source_nodes/__init__.py b/src/core/retargeting_engine/python/deviceio_source_nodes/__init__.py index dd4265f93..2fac2f0c1 100644 --- a/src/core/retargeting_engine/python/deviceio_source_nodes/__init__.py +++ b/src/core/retargeting_engine/python/deviceio_source_nodes/__init__.py @@ -9,6 +9,13 @@ from .controllers_source import ControllersSource from .pedals_source import Generic3AxisPedalSource from .full_body_source import FullBodySource +from .message_channel_source import MessageChannelSource +from .message_channel_sink import MessageChannelSink +from .message_channel_config import ( + MessageChannelConfig, + message_channel_config, + messageChannelConfig, +) from .deviceio_tensor_types import ( HeadPoseTrackedType, HandPoseTrackedType, @@ -20,6 +27,12 @@ DeviceIOControllerSnapshotTracked, DeviceIOGeneric3AxisPedalOutputTracked, DeviceIOFullBodyPosePicoTracked, + MessageChannelMessagesTrackedType, + MessageChannelConnectionStatus, + MessageChannelStatusType, + DeviceIOMessageChannelMessagesTracked, + MessageChannelMessagesTrackedGroup, + MessageChannelStatusGroup, ) __all__ = [ @@ -29,14 +42,25 @@ "ControllersSource", "Generic3AxisPedalSource", "FullBodySource", + "MessageChannelSource", + "MessageChannelSink", + "MessageChannelConfig", + "message_channel_config", + "messageChannelConfig", "HeadPoseTrackedType", "HandPoseTrackedType", "ControllerSnapshotTrackedType", "Generic3AxisPedalOutputTrackedType", "FullBodyPosePicoTrackedType", + "MessageChannelMessagesTrackedType", + "MessageChannelConnectionStatus", + "MessageChannelStatusType", "DeviceIOHeadPoseTracked", "DeviceIOHandPoseTracked", "DeviceIOControllerSnapshotTracked", "DeviceIOGeneric3AxisPedalOutputTracked", "DeviceIOFullBodyPosePicoTracked", + "DeviceIOMessageChannelMessagesTracked", + "MessageChannelMessagesTrackedGroup", + "MessageChannelStatusGroup", ] diff --git a/src/core/retargeting_engine/python/deviceio_source_nodes/deviceio_tensor_types.py b/src/core/retargeting_engine/python/deviceio_source_nodes/deviceio_tensor_types.py index fcc1a108e..77aaa4e4f 100644 --- a/src/core/retargeting_engine/python/deviceio_source_nodes/deviceio_tensor_types.py +++ b/src/core/retargeting_engine/python/deviceio_source_nodes/deviceio_tensor_types.py @@ -9,6 +9,7 @@ the raw flatbuffer object (or None when the tracker is inactive). """ +from enum import IntEnum from typing import Any from ..interface.tensor_type import TensorType from ..interface.tensor_group_type import TensorGroupType @@ -18,6 +19,7 @@ ControllerSnapshotTrackedT, Generic3AxisPedalOutputTrackedT, FullBodyPosePicoTrackedT, + MessageChannelMessagesTrackedT, ) @@ -117,6 +119,56 @@ def validate_value(self, value: Any) -> None: ) +class MessageChannelMessagesTrackedType(TensorType): + """MessageChannelMessagesTrackedT wrapper type from DeviceIO MessageChannelTracker.""" + + def __init__(self, name: str) -> None: + super().__init__(name) + + def _check_instance_compatibility(self, other: TensorType) -> bool: + if not isinstance(other, MessageChannelMessagesTrackedType): + raise TypeError( + f"Expected MessageChannelMessagesTrackedType, got {type(other).__name__}" + ) + return True + + def validate_value(self, value: Any) -> None: + if not isinstance(value, MessageChannelMessagesTrackedT): + raise TypeError( + f"Expected MessageChannelMessagesTrackedT for '{self.name}', got {type(value).__name__}" + ) + + +class MessageChannelConnectionStatus(IntEnum): + """Message channel connection states exposed by MessageChannelSource.""" + + CONNECTING = 0 + CONNECTED = 1 + SHUTTING = 2 + DISCONNECTED = 3 + UNKNOWN = -1 + + +class MessageChannelStatusType(TensorType): + """Enum status for message channel connectivity.""" + + def __init__(self, name: str) -> None: + super().__init__(name) + + def _check_instance_compatibility(self, other: TensorType) -> bool: + if not isinstance(other, MessageChannelStatusType): + raise TypeError( + f"Expected MessageChannelStatusType, got {type(other).__name__}" + ) + return True + + def validate_value(self, value: Any) -> None: + if not isinstance(value, MessageChannelConnectionStatus): + raise TypeError( + f"Expected MessageChannelConnectionStatus for '{self.name}', got {type(value).__name__}" + ) + + def DeviceIOHeadPoseTracked() -> TensorGroupType: """Tracked head pose from DeviceIO HeadTracker. @@ -169,3 +221,27 @@ def DeviceIOFullBodyPosePicoTracked() -> TensorGroupType: "deviceio_full_body_pose_pico", [FullBodyPosePicoTrackedType("full_body_tracked")], ) + + +def DeviceIOMessageChannelMessagesTracked() -> TensorGroupType: + """Tracked message wrapper from DeviceIO MessageChannelTracker.""" + return TensorGroupType( + "deviceio_message_channel_messages_tracked", + [MessageChannelMessagesTrackedType("messages_tracked")], + ) + + +def MessageChannelMessagesTrackedGroup() -> TensorGroupType: + """Tracked batch of messages drained in one update.""" + return TensorGroupType( + "message_channel_messages_tracked", + [MessageChannelMessagesTrackedType("messages_tracked")], + ) + + +def MessageChannelStatusGroup() -> TensorGroupType: + """Message channel connection status enum.""" + return TensorGroupType( + "message_channel_status", + [MessageChannelStatusType("status")], + ) diff --git a/src/core/retargeting_engine/python/deviceio_source_nodes/message_channel_config.py b/src/core/retargeting_engine/python/deviceio_source_nodes/message_channel_config.py new file mode 100644 index 000000000..7e4eb8294 --- /dev/null +++ b/src/core/retargeting_engine/python/deviceio_source_nodes/message_channel_config.py @@ -0,0 +1,72 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Factory helpers for message channel source/sink node pairs.""" + +from collections import deque +from dataclasses import dataclass +from typing import TYPE_CHECKING + +from .message_channel_sink import MessageChannelSink +from .message_channel_source import MessageChannelSource + +import isaacteleop.deviceio as deviceio + +if TYPE_CHECKING: + from isaacteleop.schema import MessageChannelMessagesTrackedT + + +@dataclass +class MessageChannelConfig: + """Configuration for creating message channel retargeter nodes.""" + + name: str + channel_uuid: bytes + channel_name: str = "" + max_message_size: int = 64 * 1024 + outbound_queue_capacity: int = 256 + + def create_nodes(self) -> tuple[MessageChannelSource, MessageChannelSink]: + if len(self.channel_uuid) != 16: + raise ValueError( + "MessageChannelConfig.channel_uuid must be exactly 16 bytes" + ) + if self.outbound_queue_capacity <= 0: + raise ValueError("MessageChannelConfig.outbound_queue_capacity must be > 0") + tracker = deviceio.MessageChannelTracker( + self.channel_uuid, + self.channel_name, + self.max_message_size, + ) + # deque(maxlen=N) provides bounded queueing and drops oldest on overflow. + outbound_queue: deque["MessageChannelMessagesTrackedT"] = deque( + maxlen=self.outbound_queue_capacity + ) + source = MessageChannelSource( + f"{self.name}_source", + tracker, + outbound_queue, + ) + sink = MessageChannelSink(f"{self.name}_sink", outbound_queue) + return source, sink + + +def message_channel_config( + name: str, + channel_uuid: bytes, + channel_name: str = "", + max_message_size: int = 64 * 1024, + outbound_queue_capacity: int = 256, +) -> tuple[MessageChannelSource, MessageChannelSink]: + """Create source/sink nodes and shared tracker for a message channel.""" + return MessageChannelConfig( + name=name, + channel_uuid=channel_uuid, + channel_name=channel_name, + max_message_size=max_message_size, + outbound_queue_capacity=outbound_queue_capacity, + ).create_nodes() + + +# Backward-compatible alias matching requested API name. +messageChannelConfig = message_channel_config diff --git a/src/core/retargeting_engine/python/deviceio_source_nodes/message_channel_sink.py b/src/core/retargeting_engine/python/deviceio_source_nodes/message_channel_sink.py new file mode 100644 index 000000000..55063660d --- /dev/null +++ b/src/core/retargeting_engine/python/deviceio_source_nodes/message_channel_sink.py @@ -0,0 +1,38 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Message channel sink node. + +Queues raw message payloads for delivery by MessageChannelSource.poll_tracker(). +""" + +from collections import deque +from typing import TYPE_CHECKING + +from ..interface.base_retargeter import BaseRetargeter +from ..interface.retargeter_core_types import RetargeterIO, RetargeterIOType +from .deviceio_tensor_types import MessageChannelMessagesTrackedGroup + +if TYPE_CHECKING: + from isaacteleop.schema import MessageChannelMessagesTrackedT + + +class MessageChannelSink(BaseRetargeter): + """Sink node that enqueues outbound message channel payloads.""" + + def __init__( + self, name: str, outbound_queue: "deque[MessageChannelMessagesTrackedT]" + ) -> None: + self._outbound_queue = outbound_queue + super().__init__(name) + + def input_spec(self) -> RetargeterIOType: + return {"messages_tracked": MessageChannelMessagesTrackedGroup()} + + def output_spec(self) -> RetargeterIOType: + return {} + + def _compute_fn(self, inputs: RetargeterIO, outputs: RetargeterIO, context) -> None: + messages_tracked = inputs["messages_tracked"][0] + self._outbound_queue.append(messages_tracked) diff --git a/src/core/retargeting_engine/python/deviceio_source_nodes/message_channel_source.py b/src/core/retargeting_engine/python/deviceio_source_nodes/message_channel_source.py new file mode 100644 index 000000000..a65da7029 --- /dev/null +++ b/src/core/retargeting_engine/python/deviceio_source_nodes/message_channel_source.py @@ -0,0 +1,119 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Message channel source node. + +Converts DeviceIO MessageChannelMessagesTrackedT wrapper data for graph use. +""" + +from collections import deque +from typing import Any, TYPE_CHECKING + +from .interface import IDeviceIOSource +from ..interface.retargeter_core_types import RetargeterIO, RetargeterIOType +from ..interface.tensor_group import TensorGroup +from isaacteleop.schema import MessageChannelMessages, MessageChannelMessagesTrackedT +from .deviceio_tensor_types import ( + DeviceIOMessageChannelMessagesTracked, + MessageChannelMessagesTrackedGroup, + MessageChannelStatusGroup, + MessageChannelConnectionStatus, +) + +if TYPE_CHECKING: + from isaacteleop.deviceio_trackers import ( + ITracker, + MessageChannelTracker, + ) + from isaacteleop.schema import ( + MessageChannelMessagesTrackedT, + ) + + +class MessageChannelSource(IDeviceIOSource): + """Source node for reading message channel payloads from DeviceIO.""" + + def __init__( + self, + name: str, + tracker: "MessageChannelTracker", + outbound_queue: "deque[MessageChannelMessagesTrackedT]", + ) -> None: + self._tracker = tracker + self._outbound_queue = outbound_queue + self._last_drained_messages_tracked: MessageChannelMessagesTrackedT | None = ( + None + ) + self._last_status: MessageChannelConnectionStatus = ( + MessageChannelConnectionStatus.UNKNOWN + ) + super().__init__(name) + + def get_tracker(self) -> "ITracker": + return self._tracker + + def poll_tracker(self, deviceio_session: Any) -> RetargeterIO: + raw_status = self._tracker.get_status(deviceio_session) + try: + self._last_status = MessageChannelConnectionStatus(int(raw_status)) + except ValueError: + self._last_status = MessageChannelConnectionStatus.UNKNOWN + + # Flush queued outbound messages before polling inbound data. + if self._last_status == MessageChannelConnectionStatus.CONNECTED: + while self._outbound_queue: + batch = self._outbound_queue[0] + if batch.data: + sent = 0 + try: + for message in batch.data: + outbound_message = MessageChannelMessages(message.payload) + self._tracker.send_message( + deviceio_session, outbound_message + ) + sent += 1 + except Exception: + # Drop the delivered prefix so already-sent messages are + # not re-delivered on the next flush attempt. + if sent < len(batch.data): + self._outbound_queue[0] = MessageChannelMessagesTrackedT( + batch.data[sent:] + ) + else: + self._outbound_queue.popleft() + raise + self._outbound_queue.popleft() + + self._last_drained_messages_tracked = self._tracker.get_messages( + deviceio_session + ) + + source_inputs = self.input_spec() + result: RetargeterIO = {} + for input_name, group_type in source_inputs.items(): + tg = TensorGroup(group_type) + if self._last_drained_messages_tracked is None: + tg[0] = MessageChannelMessagesTrackedT() + else: + tg[0] = self._last_drained_messages_tracked + result[input_name] = tg + return result + + def input_spec(self) -> RetargeterIOType: + return { + "deviceio_message_channel_messages": DeviceIOMessageChannelMessagesTracked() + } + + def output_spec(self) -> RetargeterIOType: + return { + "messages_tracked": MessageChannelMessagesTrackedGroup(), + "status": MessageChannelStatusGroup(), + } + + def _compute_fn(self, inputs: RetargeterIO, outputs: RetargeterIO, context) -> None: + if self._last_drained_messages_tracked is None: + outputs["messages_tracked"][0] = MessageChannelMessagesTrackedT() + else: + outputs["messages_tracked"][0] = self._last_drained_messages_tracked + outputs["status"][0] = self._last_status diff --git a/src/core/retargeting_engine_tests/python/test_message_channel_nodes.py b/src/core/retargeting_engine_tests/python/test_message_channel_nodes.py new file mode 100644 index 000000000..e0863cf23 --- /dev/null +++ b/src/core/retargeting_engine_tests/python/test_message_channel_nodes.py @@ -0,0 +1,169 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from collections import deque + +from isaacteleop.schema import ( + MessageChannelMessages, + MessageChannelMessagesTrackedT, +) +from isaacteleop.retargeting_engine.deviceio_source_nodes import ( + MessageChannelConnectionStatus, + MessageChannelSink, + MessageChannelSource, +) +from isaacteleop.retargeting_engine.interface.base_retargeter import _make_output_group +from isaacteleop.retargeting_engine.interface.tensor_group import TensorGroup + + +class DummyTracker: + """Tiny fake tracker used to unit test source/sink behavior.""" + + def __init__(self): + self.sent_payloads = [] + self._drained = MessageChannelMessagesTrackedT() + self.connected = True + + def send_message(self, session, payload): + if not self.connected: + raise RuntimeError("channel is not connected") + self.sent_payloads.append(payload.payload) + + def get_status(self, session): + return 1 if self.connected else 0 + + def get_messages(self, session): + return self._drained + + +def _make_inputs(node, raw): + spec = node.input_spec() + result = {} + for name, objects in raw.items(): + tg = TensorGroup(spec[name]) + for i, obj in enumerate(objects): + tg[i] = obj + result[name] = tg + return result + + +def test_message_channel_source_active_message(): + tracker = DummyTracker() + source = MessageChannelSource("msg_source", tracker, deque()) + message = MessageChannelMessages(b"hello") + tracker._drained = MessageChannelMessagesTrackedT([message]) + + inputs = source.poll_tracker(deviceio_session=object()) + outputs = {k: _make_output_group(v) for k, v in source.output_spec().items()} + source.compute(inputs, outputs) + + tracked_batch = outputs["messages_tracked"][0] + assert tracked_batch.data is not None + assert tracked_batch.data[0].payload == b"hello" + assert outputs["status"][0] == MessageChannelConnectionStatus.CONNECTED + + +def test_message_channel_source_inactive_message(): + tracker = DummyTracker() + source = MessageChannelSource("msg_source", tracker, deque()) + tracker._drained = MessageChannelMessagesTrackedT() + + inputs = source.poll_tracker(deviceio_session=object()) + outputs = {k: _make_output_group(v) for k, v in source.output_spec().items()} + source.compute(inputs, outputs) + + tracked_batch = outputs["messages_tracked"][0] + assert tracked_batch.data == [] + assert outputs["status"][0] == MessageChannelConnectionStatus.CONNECTED + + +def test_message_channel_sink_enqueues_message(): + outbound_queue = deque() + sink = MessageChannelSink("msg_sink", outbound_queue) + m1 = MessageChannelMessages(b"echo") + m2 = MessageChannelMessages(b"pong") + + batch = MessageChannelMessagesTrackedT([m1, m2]) + inputs = _make_inputs(sink, {"messages_tracked": [batch]}) + outputs = {k: _make_output_group(v) for k, v in sink.output_spec().items()} + sink.compute(inputs, outputs) + + assert sink.output_spec() == {} + assert outputs == {} + assert len(outbound_queue) == 1 + queued_batch = outbound_queue[0] + assert queued_batch.data is not None + assert queued_batch.data[0].payload == b"echo" + assert queued_batch.data[1].payload == b"pong" + + +def test_message_channel_source_returns_all_drained_messages(): + tracker = DummyTracker() + source = MessageChannelSource("msg_source_list", tracker, deque()) + m1 = MessageChannelMessages(b"x") + m2 = MessageChannelMessages(b"y") + tracker._drained = MessageChannelMessagesTrackedT([m1, m2]) + + inputs = source.poll_tracker(deviceio_session=object()) + outputs = {k: _make_output_group(v) for k, v in source.output_spec().items()} + source.compute(inputs, outputs) + + messages_tracked = outputs["messages_tracked"][0] + assert messages_tracked.data is not None + assert len(messages_tracked.data) == 2 + assert messages_tracked.data[0].payload == b"x" + assert messages_tracked.data[1].payload == b"y" + assert outputs["status"][0] == MessageChannelConnectionStatus.CONNECTED + + +def test_message_channel_source_keeps_outbound_queue_while_disconnected(): + tracker = DummyTracker() + tracker.connected = False + outbound_queue = deque() + source = MessageChannelSource("msg_source_disconnected", tracker, outbound_queue) + + outbound_queue.append( + MessageChannelMessagesTrackedT([MessageChannelMessages(b"a")]) + ) + outbound_queue.append( + MessageChannelMessagesTrackedT([MessageChannelMessages(b"b")]) + ) + + inputs = source.poll_tracker(deviceio_session=object()) + outputs = {k: _make_output_group(v) for k, v in source.output_spec().items()} + source.compute(inputs, outputs) + + assert len(outbound_queue) == 2 + assert outbound_queue[0].data[0].payload == b"a" + assert outbound_queue[1].data[0].payload == b"b" + assert outputs["status"][0] == MessageChannelConnectionStatus.CONNECTING + + tracker.connected = True + source.poll_tracker(deviceio_session=object()) + assert len(outbound_queue) == 0 + assert tracker.sent_payloads == [b"a", b"b"] + + inputs = source.poll_tracker(deviceio_session=object()) + outputs = {k: _make_output_group(v) for k, v in source.output_spec().items()} + source.compute(inputs, outputs) + assert outputs["status"][0] == MessageChannelConnectionStatus.CONNECTED + + +def test_message_channel_sink_bounded_queue_drops_oldest(): + outbound_queue = deque(maxlen=2) + sink = MessageChannelSink("msg_sink_bounded", outbound_queue) + m1 = MessageChannelMessages(b"1") + m2 = MessageChannelMessages(b"2") + m3 = MessageChannelMessages(b"3") + + b1 = MessageChannelMessagesTrackedT([m1]) + b2 = MessageChannelMessagesTrackedT([m2]) + b3 = MessageChannelMessagesTrackedT([m3]) + outputs = {k: _make_output_group(v) for k, v in sink.output_spec().items()} + sink.compute(_make_inputs(sink, {"messages_tracked": [b1]}), outputs) + sink.compute(_make_inputs(sink, {"messages_tracked": [b2]}), outputs) + sink.compute(_make_inputs(sink, {"messages_tracked": [b3]}), outputs) + + assert len(outbound_queue) == 2 + assert outbound_queue[0].data[0].payload == b"2" + assert outbound_queue[1].data[0].payload == b"3" diff --git a/src/core/schema/fbs/message_channel.fbs b/src/core/schema/fbs/message_channel.fbs new file mode 100644 index 000000000..28f039d5e --- /dev/null +++ b/src/core/schema/fbs/message_channel.fbs @@ -0,0 +1,24 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +include "timestamp.fbs"; + +namespace core; + +// Raw message payload for OpenXR opaque message channels. +table MessageChannelMessages { + payload: [ubyte] (id: 0); +} + +// Tracked wrapper for drained batches (data is null when no messages were drained). +table MessageChannelMessagesTracked { + data: [MessageChannelMessages] (id: 0); +} + +// MCAP recording wrapper for MessageChannelMessages. +table MessageChannelMessagesRecord { + data: MessageChannelMessages (id: 0); + timestamp: DeviceDataTimestamp (id: 1); +} + +root_type MessageChannelMessagesRecord; diff --git a/src/core/schema/python/CMakeLists.txt b/src/core/schema/python/CMakeLists.txt index 2949e19e7..d948e1417 100644 --- a/src/core/schema/python/CMakeLists.txt +++ b/src/core/schema/python/CMakeLists.txt @@ -7,6 +7,7 @@ pybind11_add_module(schema_py full_body_bindings.h hand_bindings.h head_bindings.h + message_channel_bindings.h pedals_bindings.h pose_bindings.h schema_module.cpp diff --git a/src/core/schema/python/message_channel_bindings.h b/src/core/schema/python/message_channel_bindings.h new file mode 100644 index 000000000..4f398a7bc --- /dev/null +++ b/src/core/schema/python/message_channel_bindings.h @@ -0,0 +1,75 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include +#include +#include + +#include + +namespace py = pybind11; + +namespace core +{ + +inline void bind_message_channel(py::module& m) +{ + py::class_>(m, "MessageChannelMessages") + .def(py::init([]() { return std::make_shared(); })) + .def(py::init( + [](py::bytes payload) + { + auto obj = std::make_shared(); + std::string data = payload; + obj->payload.assign(data.begin(), data.end()); + return obj; + }), + py::arg("payload")) + .def_property( + "payload", + [](const MessageChannelMessagesT& self) + { return py::bytes(reinterpret_cast(self.payload.data()), self.payload.size()); }, + [](MessageChannelMessagesT& self, py::bytes payload) + { + std::string data = payload; + self.payload.assign(data.begin(), data.end()); + }); + + py::class_>( + m, "MessageChannelMessagesRecord") + .def(py::init<>()) + .def(py::init( + [](const MessageChannelMessagesT& data, const DeviceDataTimestamp& timestamp) + { + auto obj = std::make_shared(); + obj->data = std::make_shared(data); + obj->timestamp = std::make_shared(timestamp); + return obj; + }), + py::arg("data"), py::arg("timestamp")) + .def_property_readonly("data", + [](const MessageChannelMessagesRecordT& self) -> std::shared_ptr + { return self.data; }) + .def_readonly("timestamp", &MessageChannelMessagesRecordT::timestamp); + + py::class_>( + m, "MessageChannelMessagesTrackedT") + .def(py::init<>()) + .def(py::init( + [](const std::vector>& data) + { + auto obj = std::make_shared(); + obj->data = data; + return obj; + }), + py::arg("data")) + .def_property_readonly( + "data", + [](const MessageChannelMessagesTrackedT& self) -> std::vector> + { return self.data; }); +} + +} // namespace core diff --git a/src/core/schema/python/schema_init.py b/src/core/schema/python/schema_init.py index e5ab8c893..3f3aeb108 100644 --- a/src/core/schema/python/schema_init.py +++ b/src/core/schema/python/schema_init.py @@ -35,6 +35,10 @@ Generic3AxisPedalOutput, Generic3AxisPedalOutputTrackedT, Generic3AxisPedalOutputRecord, + # Message channel types. + MessageChannelMessages, + MessageChannelMessagesTrackedT, + MessageChannelMessagesRecord, # Camera-related types. StreamType, FrameMetadataOak, @@ -78,6 +82,10 @@ "Generic3AxisPedalOutput", "Generic3AxisPedalOutputTrackedT", "Generic3AxisPedalOutputRecord", + # Message channel types. + "MessageChannelMessages", + "MessageChannelMessagesTrackedT", + "MessageChannelMessagesRecord", # Camera types. "StreamType", "FrameMetadataOak", diff --git a/src/core/schema/python/schema_module.cpp b/src/core/schema/python/schema_module.cpp index 4c5b725fa..e20dae586 100644 --- a/src/core/schema/python/schema_module.cpp +++ b/src/core/schema/python/schema_module.cpp @@ -10,6 +10,7 @@ #include "full_body_bindings.h" #include "hand_bindings.h" #include "head_bindings.h" +#include "message_channel_bindings.h" #include "oak_bindings.h" #include "pedals_bindings.h" #include "pose_bindings.h" @@ -39,6 +40,9 @@ PYBIND11_MODULE(_schema, m) // Bind pedals types (Generic3AxisPedalOutput table). core::bind_pedals(m); + // Bind message channel types (MessageChannelMessages table). + core::bind_message_channel(m); + // Bind OAK types (StreamType enum, FrameMetadataOak table). core::bind_oak(m);