diff --git a/.cspell-wordlist.txt b/.cspell-wordlist.txt index 1b570c822b..d75bbb7035 100644 --- a/.cspell-wordlist.txt +++ b/.cspell-wordlist.txt @@ -193,3 +193,8 @@ BIOES viterbi argmaxes unpadded +keypoint +keypoints +Keypoint +Keypoints +letterboxing diff --git a/.eslintrc.js b/.eslintrc.js index 8cb84b9ff8..35d2da64cc 100644 --- a/.eslintrc.js +++ b/.eslintrc.js @@ -10,6 +10,7 @@ const VALID_CATEGORIES = [ 'Models - LLM', 'Models - Object Detection', 'Models - Instance Segmentation', + 'Models - Pose Estimation', 'Models - Semantic Segmentation', 'Models - Speech To Text', 'Models - Style Transfer', diff --git a/apps/computer-vision/app/_layout.tsx b/apps/computer-vision/app/_layout.tsx index 730a0007e2..03770c2720 100644 --- a/apps/computer-vision/app/_layout.tsx +++ b/apps/computer-vision/app/_layout.tsx @@ -149,6 +149,14 @@ export default function _layout() { headerTitleStyle: { color: ColorPalette.primary }, }} /> + Instance Segmentation + router.navigate('pose_estimation/')} + > + Pose Estimation + router.navigate('ocr/')} diff --git a/apps/computer-vision/app/pose_estimation/index.tsx b/apps/computer-vision/app/pose_estimation/index.tsx new file mode 100644 index 0000000000..3a07816e7e --- /dev/null +++ b/apps/computer-vision/app/pose_estimation/index.tsx @@ -0,0 +1,259 @@ +import Spinner from '../../components/Spinner'; +import { BottomBar } from '../../components/BottomBar'; +import { getImage } from '../../utils'; +import { + usePoseEstimation, + PoseDetections, + RnExecutorchError, + RnExecutorchErrorCode, + YOLO26N_POSE, +} from 'react-native-executorch'; +import { View, StyleSheet, Image, Text } from 'react-native'; +import React, { useContext, useEffect, useState } from 'react'; +import { GeneratingContext } from '../../context'; +import ScreenWrapper from '../../ScreenWrapper'; +import { StatsBar } from '../../components/StatsBar'; +import Svg, { Circle, Line } from 'react-native-svg'; +import ErrorBanner from '../../components/ErrorBanner'; +import { COCO_SKELETON_CONNECTIONS } from '../../components/utils/cocoSkeleton'; + +// Colors for different people +const PERSON_COLORS = ['lime', 'cyan', 'magenta', 'yellow', 'orange', 'pink']; + +export default function PoseEstimationScreen() { + const [imageUri, setImageUri] = useState(''); + const [results, setResults] = useState([]); + const [error, setError] = useState(null); + const [imageDimensions, setImageDimensions] = useState<{ + width: number; + height: number; + }>(); + const [inferenceTime, setInferenceTime] = useState(null); + const [layout, setLayout] = useState({ width: 0, height: 0 }); + + const model = usePoseEstimation({ model: YOLO26N_POSE }); + const { setGlobalGenerating } = useContext(GeneratingContext); + + useEffect(() => { + setGlobalGenerating(model.isGenerating); + }, [model.isGenerating, setGlobalGenerating]); + + useEffect(() => { + if (model.error) setError(String(model.error)); + }, [model.error]); + + const handleCameraPress = async (isCamera: boolean) => { + const image = await getImage(isCamera); + const uri = image?.uri; + const width = image?.width; + const height = image?.height; + + if (uri && width && height) { + setImageUri(image.uri as string); + setImageDimensions({ width, height }); + setResults([]); + setInferenceTime(null); + } + }; + + const runForward = async () => { + if (imageUri) { + try { + const start = Date.now(); + const output = await model.forward(imageUri, { inputSize: 384 }); + setInferenceTime(Date.now() - start); + setResults(output); + } catch (e) { + if (e instanceof 
RnExecutorchError) { + switch (e.code) { + case RnExecutorchErrorCode.FileReadFailed: + setError('Could not read the selected image.'); + break; + case RnExecutorchErrorCode.ModelGenerating: + setError('Model is busy — wait for the current run to finish.'); + break; + case RnExecutorchErrorCode.InvalidUserInput: + case RnExecutorchErrorCode.InvalidArgument: + setError(`Invalid input: ${e.message}`); + break; + default: + setError(e.message); + } + } else { + setError(e instanceof Error ? e.message : String(e)); + } + } + } + }; + + if (!model.isReady) { + return ( + + ); + } + + return ( + + setError(null)} /> + + + {imageUri && imageDimensions?.width && imageDimensions?.height ? ( + + setLayout({ + width: e.nativeEvent.layout.width, + height: e.nativeEvent.layout.height, + }) + } + > + + {results.length > 0 && + layout.width > 0 && + layout.height > 0 && + (() => { + // Account for resizeMode="contain" letterboxing: the image's + // displayed area is smaller than the container in one axis. + const imageRatio = + imageDimensions.width / imageDimensions.height; + const layoutRatio = layout.width / layout.height; + let scaleX: number, scaleY: number; + if (imageRatio > layoutRatio) { + scaleX = layout.width / imageDimensions.width; + scaleY = layout.width / imageRatio / imageDimensions.height; + } else { + scaleY = layout.height / imageDimensions.height; + scaleX = + (layout.height * imageRatio) / imageDimensions.width; + } + const offsetX = + (layout.width - imageDimensions.width * scaleX) / 2; + const offsetY = + (layout.height - imageDimensions.height * scaleY) / 2; + const isInBounds = (kp: { x: number; y: number }) => + kp.x >= 0 && + kp.y >= 0 && + kp.x <= imageDimensions.width && + kp.y <= imageDimensions.height; + return ( + + {results.map((personKeypoints, personIdx) => { + const color = + PERSON_COLORS[personIdx % PERSON_COLORS.length]; + return ( + + {COCO_SKELETON_CONNECTIONS.map( + ([from, to], lineIdx) => { + const kp1 = personKeypoints[from]; + const kp2 = personKeypoints[to]; + if (!kp1 || !kp2) return null; + if (!isInBounds(kp1) || !isInBounds(kp2)) + return null; + return ( + + ); + } + )} + {Object.entries(personKeypoints) + .filter(([, kp]) => isInBounds(kp)) + .map(([name, kp]) => ( + + ))} + + ); + })} + + ); + })()} + + ) : ( + + )} + + {!imageUri && ( + + Pose Estimation + + This model detects human body keypoints (17 COCO keypoints) and + draws a skeleton overlay. Pick an image from your gallery or take + one with your camera to get started. + + + )} + + 0 ? 
results.length : null} + /> + + + ); +} + +const styles = StyleSheet.create({ + imageContainer: { + flex: 6, + width: '100%', + padding: 16, + }, + image: { + flex: 2, + borderRadius: 8, + width: '100%', + }, + imageWrapper: { + flex: 1, + width: '100%', + height: '100%', + }, + fullSizeImage: { + width: '100%', + height: '100%', + }, + infoContainer: { + alignItems: 'center', + padding: 16, + gap: 8, + }, + infoTitle: { + fontSize: 18, + fontWeight: '600', + color: 'navy', + }, + infoText: { + fontSize: 14, + color: '#555', + textAlign: 'center', + lineHeight: 20, + }, +}); diff --git a/apps/computer-vision/app/vision_camera/index.tsx b/apps/computer-vision/app/vision_camera/index.tsx index c39ff096c4..4020d20023 100644 --- a/apps/computer-vision/app/vision_camera/index.tsx +++ b/apps/computer-vision/app/vision_camera/index.tsx @@ -28,6 +28,7 @@ import SegmentationTask from '../../components/vision_camera/tasks/SegmentationT import InstanceSegmentationTask from '../../components/vision_camera/tasks/InstanceSegmentationTask'; import OCRTask from '../../components/vision_camera/tasks/OCRTask'; import StyleTransferTask from '../../components/vision_camera/tasks/StyleTransferTask'; +import PoseEstimationTask from '../../components/vision_camera/tasks/PoseEstimationTask'; // 1. Import ErrorBanner import ErrorBanner from '../../components/ErrorBanner'; @@ -36,6 +37,7 @@ type TaskId = | 'objectDetection' | 'segmentation' | 'instanceSegmentation' + | 'poseEstimation' | 'ocr' | 'styleTransfer'; type ModelId = @@ -52,6 +54,7 @@ type ModelId = | 'segmentationSelfie' | 'instanceSegmentationYolo26n' | 'instanceSegmentationRfdetr' + | 'poseEstimationYolo26n' | 'ocr' | 'styleTransferCandy' | 'styleTransferMosaic'; @@ -86,6 +89,11 @@ const TASKS: Task[] = [ { id: 'instanceSegmentationRfdetr', label: 'RF-DETR Nano Seg' }, ], }, + { + id: 'poseEstimation', + label: 'Pose', + variants: [{ id: 'poseEstimationYolo26n', label: 'YOLO26N Pose' }], + }, { id: 'objectDetection', label: 'Detect', @@ -223,6 +231,12 @@ export default function VisionCameraScreen() { outputs={frameOutput ? [frameOutput] : []} isActive={isFocused} orientationSource="device" + onError={(e) => { + console.warn('[Camera] onError', e); + setError(e.message); + }} + onStarted={() => console.log('[Camera] session started')} + onPreviewStarted={() => console.log('[Camera] preview got first frame')} /> )} + {activeTask === 'poseEstimation' && ( + + )} {activeTask === 'ocr' && } {activeTask === 'styleTransfer' && ( ([]); + const [imageSize, setImageSize] = useState({ width: 1, height: 1 }); + const lastFrameTimeRef = useRef(Date.now()); + + useEffect(() => { + onErrorChange(poseModel.error ? 
String(poseModel.error) : null); + }, [poseModel.error, onErrorChange]); + + useEffect(() => { + onReadyChange(poseModel.isReady); + }, [poseModel.isReady, onReadyChange]); + + useEffect(() => { + onProgressChange(poseModel.downloadProgress); + }, [poseModel.downloadProgress, onProgressChange]); + + useEffect(() => { + onGeneratingChange(poseModel.isGenerating); + }, [poseModel.isGenerating, onGeneratingChange]); + + const poseRof = poseModel.runOnFrame; + + const updateDetections = useCallback( + (p: { + results: PoseDetections; + imageWidth: number; + imageHeight: number; + }) => { + setDetections(p.results); + setImageSize({ width: p.imageWidth, height: p.imageHeight }); + const now = Date.now(); + const diff = now - lastFrameTimeRef.current; + if (diff > 0) onFpsChange(Math.round(1000 / diff), diff); + lastFrameTimeRef.current = now; + }, + [onFpsChange] + ); + + const frameOutput = useFrameOutput({ + pixelFormat: 'rgb', + dropFramesWhileBusy: true, + enablePreviewSizedOutputBuffers: true, + + onFrame: useCallback( + (frame: Frame) => { + 'worklet'; + if (frameKillSwitch.getDirty()) { + frame.dispose(); + return; + } + try { + if (!poseRof) return; + const isFrontCamera = cameraPositionSync.getDirty() === 'front'; + const result = poseRof(frame, isFrontCamera, { + detectionThreshold: 0.5, + }); + const screenW = frame.height; + const screenH = frame.width; + if (result) { + scheduleOnRN(updateDetections, { + results: result, + imageWidth: screenW, + imageHeight: screenH, + }); + } + } catch { + // Frame may be disposed before processing completes + } finally { + frame.dispose(); + } + }, + [cameraPositionSync, poseRof, frameKillSwitch, updateDetections] + ), + }); + + useEffect(() => { + onFrameOutputChange(frameOutput); + }, [frameOutput, onFrameOutputChange]); + + const scale = Math.max( + canvasSize.width / imageSize.width, + canvasSize.height / imageSize.height + ); + const offsetX = (canvasSize.width - imageSize.width * scale) / 2; + const offsetY = (canvasSize.height - imageSize.height * scale) / 2; + + return ( + + + {detections.map((personKeypoints, personIdx) => { + const color = PERSON_COLORS[personIdx % PERSON_COLORS.length]; + const isVisible = (kp: { x: number; y: number }) => + kp.x >= 0 && + kp.y >= 0 && + kp.x <= imageSize.width && + kp.y <= imageSize.height; + return ( + + {/* Draw skeleton lines */} + {COCO_SKELETON_CONNECTIONS.map(([from, to], lineIdx) => { + const kp1 = personKeypoints[from]; + const kp2 = personKeypoints[to]; + if (!kp1 || !kp2) return null; + if (!isVisible(kp1) || !isVisible(kp2)) return null; + const x1 = kp1.x * scale + offsetX; + const y1 = kp1.y * scale + offsetY; + const x2 = kp2.x * scale + offsetX; + const y2 = kp2.y * scale + offsetY; + return ( + + ); + })} + {/* Draw keypoints */} + {Object.entries(personKeypoints) + .filter(([, kp]) => isVisible(kp)) + .map(([name, kp]) => { + const cx = kp.x * scale + offsetX; + const cy = kp.y * scale + offsetY; + return ( + + ); + })} + + ); + })} + + + ); +} diff --git a/docs/docs/03-hooks/02-computer-vision/usePoseEstimation.md b/docs/docs/03-hooks/02-computer-vision/usePoseEstimation.md new file mode 100644 index 0000000000..e31b928074 --- /dev/null +++ b/docs/docs/03-hooks/02-computer-vision/usePoseEstimation.md @@ -0,0 +1,140 @@ +--- +title: usePoseEstimation +--- + +Pose estimation is a computer vision technique that detects human bodies in an image and locates a fixed set of keypoints (e.g. nose, shoulders, knees) for each detected person. 
Unlike object detection, which produces a class label and a bounding box, pose estimation produces a structured set of named keypoints per person. React Native ExecuTorch offers a dedicated hook `usePoseEstimation` for this task. + +:::info +It is recommended to use models provided by us, which are available at our [Hugging Face repository](https://huggingface.co/software-mansion/react-native-executorch-yolo26-pose). You can also use [constants](https://github.com/software-mansion/react-native-executorch/blob/main/packages/react-native-executorch/src/constants/modelUrls.ts) shipped with our library. +::: + +## API Reference + +- For detailed API Reference for `usePoseEstimation` see: [`usePoseEstimation` API Reference](../../06-api-reference/functions/usePoseEstimation.md). +- For all pose estimation models available out-of-the-box in React Native ExecuTorch see: [Pose Estimation Models](../../06-api-reference/index.md#models---pose-estimation). + +## High Level Overview + +```typescript +import { usePoseEstimation, YOLO26N_POSE } from 'react-native-executorch'; + +const model = usePoseEstimation({ + model: YOLO26N_POSE, +}); + +const imageUri = 'file:///Users/.../photo.jpg'; + +try { + const detections = await model.forward(imageUri); + // detections is an array of PersonKeypoints, keyed by name (e.g. detections[0].NOSE) +} catch (error) { + console.error(error); +} +``` + +### Arguments + +`usePoseEstimation` takes [`PoseEstimationProps`](../../06-api-reference/interfaces/PoseEstimationProps.md) that consists of: + +- `model` - An object containing: + - `modelName` - The name of a built-in model. See [`PoseEstimationModelSources`](../../06-api-reference/interfaces/PoseEstimationProps.md) for the list of supported models. + - `modelSource` - The location of the model binary (a URL or a bundled resource). +- An optional flag [`preventLoad`](../../06-api-reference/interfaces/PoseEstimationProps.md#preventload) which prevents auto-loading of the model. + +The hook is generic over the model config — TypeScript automatically infers the correct keypoint type based on the `modelName` you provide. No explicit generic parameter is needed. + +You need more details? Check the following resources: + +- For detailed information about `usePoseEstimation` arguments check this section: [`usePoseEstimation` arguments](../../06-api-reference/functions/usePoseEstimation.md#parameters). +- For all pose estimation models available out-of-the-box in React Native ExecuTorch see: [Pose Estimation Models](../../06-api-reference/index.md#models---pose-estimation). +- For more information on loading resources, take a look at [loading models](../../01-fundamentals/02-loading-models.md) page. + +### Returns + +`usePoseEstimation` returns a [`PoseEstimationType`](../../06-api-reference/interfaces/PoseEstimationType.md) object containing: + +- `isReady` - Whether the model is loaded and ready to process images. +- `isGenerating` - Whether the model is currently processing an image. +- `error` - An error object if the model failed to load or encountered a runtime error. +- `downloadProgress` - A value between 0 and 1 representing the download progress of the model binary. +- `forward` - A function to run inference on an image. +- `getAvailableInputSizes` - A function that returns available input sizes for multi-method models (YOLO). Returns `undefined` for single-method models. +- `runOnFrame` - A synchronous worklet function for real-time VisionCamera frame processing. 
See [VisionCamera Integration](./visioncamera-integration.md) for usage. + +## Running the model + +To run the model, use the [`forward`](../../06-api-reference/interfaces/PoseEstimationType.md#forward) method. It accepts two arguments: + +- `input` (required) - The image to process. Can be a remote URL, a local file URI, a base64-encoded image (whole URI or only raw base64), or a [`PixelData`](../../06-api-reference/interfaces/PixelData.md) object (raw RGB pixel buffer). +- `options` (optional) - A [`PoseEstimationOptions`](../../06-api-reference/interfaces/PoseEstimationOptions.md) object with the following properties: + - `detectionThreshold` (optional) - A number between 0 and 1 representing the minimum confidence score for a detected person. Defaults to model-specific value (typically `0.5`). + - `keypointThreshold` (optional) - Per-keypoint visibility threshold (0-1). Keypoints whose model-reported visibility falls below this are emitted as `(-1, -1)` so consumers can skip them. Defaults to model-specific value. + - `inputSize` (optional) - For multi-method models like YOLO, specify the input resolution (`384`, `512`, or `640`). Defaults to `384` for YOLO models. + +`forward` returns a promise resolving to an array of [`PersonKeypoints`](../../06-api-reference/type-aliases/PersonKeypoints.md) — one entry per detected person. Each entry is an object keyed by the model's keypoint names (typed against the model's keypoint map), where each value is a [`Keypoint`](../../06-api-reference/interfaces/Keypoint.md) with: + +- `x` - The x coordinate in the original image's pixel space. +- `y` - The y coordinate in the original image's pixel space. + +:::info +Keypoints whose visibility falls below `keypointThreshold` (or that the model considers off-image) are returned as `{ x: -1, y: -1 }`. Filter them out before drawing — e.g. `if (kp.x < 0 || kp.y < 0) skip;`. +::: + +For example, with a COCO-keypoint model: + +```typescript +const detections = await model.forward(imageUri); +const firstPerson = detections[0]; +firstPerson.NOSE; // { x, y } +firstPerson.LEFT_SHOULDER; // { x, y } +``` + +The keypoint names available on each person are determined by the model's keypoint map and are checked at compile time. + +## Example + +```typescript +import { usePoseEstimation, YOLO26N_POSE } from 'react-native-executorch'; + +function App() { + const model = usePoseEstimation({ + model: YOLO26N_POSE, + }); + + const handleDetect = async () => { + if (!model.isReady) return; + + const imageUri = 'file:///Users/.../photo.jpg'; + + try { + const detections = await model.forward(imageUri, { + detectionThreshold: 0.5, + inputSize: 640, + }); + + console.log('Detected:', detections.length, 'people'); + for (const person of detections) { + console.log('Nose at', person.NOSE.x, person.NOSE.y); + } + } catch (error) { + console.error(error); + } + }; + + // ... +} +``` + +## VisionCamera integration + +See the full guide: [VisionCamera Integration](./visioncamera-integration.md). 
+ +## Supported models + +| Model | Number of keypoints | Keypoint list | Multi-size Support | +| ------------------------------------------------------------------------------------------- | ------------------- | ----------------------------------------------------------- | ------------------ | +| [YOLO26N-Pose](https://huggingface.co/software-mansion/react-native-executorch-yolo26-pose) | 17 | [COCO](../../06-api-reference/enumerations/CocoKeypoint.md) | Yes (384/512/640) | + +:::tip +YOLO models support multiple input sizes (384px, 512px, 640px). Smaller sizes are faster but less accurate, while larger sizes are more accurate but slower. Choose based on your speed/accuracy requirements. +::: diff --git a/docs/docs/03-hooks/02-computer-vision/visioncamera-integration.md b/docs/docs/03-hooks/02-computer-vision/visioncamera-integration.md index 008d0121bd..b96b7f5274 100644 --- a/docs/docs/03-hooks/02-computer-vision/visioncamera-integration.md +++ b/docs/docs/03-hooks/02-computer-vision/visioncamera-integration.md @@ -23,6 +23,7 @@ The following hooks expose `runOnFrame`: - [`useInstanceSegmentation`](./useInstanceSegmentation.md) - [`useSemanticSegmentation`](./useSemanticSegmentation.md) - [`useStyleTransfer`](./useStyleTransfer.md) +- [`usePoseEstimation`](./usePoseEstimation.md) ## runOnFrame vs forward diff --git a/docs/docs/04-typescript-api/02-computer-vision/PoseEstimationModule.md b/docs/docs/04-typescript-api/02-computer-vision/PoseEstimationModule.md new file mode 100644 index 0000000000..bc32211b19 --- /dev/null +++ b/docs/docs/04-typescript-api/02-computer-vision/PoseEstimationModule.md @@ -0,0 +1,111 @@ +--- +title: PoseEstimationModule +--- + +TypeScript API implementation of the [usePoseEstimation](../../03-hooks/02-computer-vision/usePoseEstimation.md) hook. + +## API Reference + +- For detailed API Reference for `PoseEstimationModule` see: [`PoseEstimationModule` API Reference](../../06-api-reference/classes/PoseEstimationModule.md). +- For all pose estimation models available out-of-the-box in React Native ExecuTorch see: [Pose Estimation Models](../../06-api-reference/index.md#models---pose-estimation). + +## High Level Overview + +```typescript +import { PoseEstimationModule, YOLO26N_POSE } from 'react-native-executorch'; + +const imageUri = 'path/to/image.png'; + +// Creating an instance and loading the model +const poseEstimationModule = + await PoseEstimationModule.fromModelName(YOLO26N_POSE); + +// Running the model +const detections = await poseEstimationModule.forward(imageUri); +detections[0].NOSE; // { x, y } +``` + +### Methods + +All methods of `PoseEstimationModule` are explained in detail here: [`PoseEstimationModule` API Reference](../../06-api-reference/classes/PoseEstimationModule.md). + +## Loading the model + +Use the static [`fromModelName`](../../06-api-reference/classes/PoseEstimationModule.md#frommodelname) factory method. It accepts a model config object (with `modelName` and `modelSource`) and an optional `onDownloadProgress` callback. It returns a promise resolving to a `PoseEstimationModule` instance whose return type is statically tied to the model's keypoint map. + +For more information on loading resources, take a look at the [loading models](../../01-fundamentals/02-loading-models.md) page. + +## Running the model + +To run the model, use the [`forward`](../../06-api-reference/classes/PoseEstimationModule.md#forward) method. It accepts two arguments: + +- `input` (required) - The image to process. 
Can be a remote URL, a local file URI, a base64-encoded image (whole URI or only raw base64), or a [`PixelData`](../../06-api-reference/interfaces/PixelData.md) object (raw RGB pixel buffer). +- `options` (optional) - A [`PoseEstimationOptions`](../../06-api-reference/interfaces/PoseEstimationOptions.md) object with: + - `detectionThreshold` (optional) - Minimum confidence score for a detected person (0-1). Defaults to model-specific value. + - `keypointThreshold` (optional) - Per-keypoint visibility threshold (0-1). Keypoints whose model-reported visibility falls below this are reported as `(-1, -1)` so consumers can skip them. Defaults to model-specific value. + - `inputSize` (optional) - For YOLO models: `384`, `512`, or `640`. Defaults to `384`. + +The method returns a promise resolving to an array of [`PersonKeypoints`](../../06-api-reference/type-aliases/PersonKeypoints.md). Each entry is an object keyed by the model's keypoint names (e.g. `NOSE`, `LEFT_SHOULDER`), where each value is a [`Keypoint`](../../06-api-reference/interfaces/Keypoint.md) with `x` and `y` coordinates in the original image's pixel space. + +:::info +Keypoints whose visibility falls below `keypointThreshold` (or that the model considers off-image) are returned as `{ x: -1, y: -1 }`. Filter them out before drawing — e.g. `if (kp.x < 0 || kp.y < 0) skip;`. +::: + +For real-time frame processing, use [`runOnFrame`](../../03-hooks/02-computer-vision/visioncamera-integration.md) instead. + +### Example with Options + +```typescript +const detections = await model.forward(imageUri, { + detectionThreshold: 0.5, + inputSize: 640, // YOLO models only +}); + +for (const person of detections) { + console.log('Nose at', person.NOSE.x, person.NOSE.y); +} +``` + +## Using a custom model + +Use [`fromCustomModel`](../../06-api-reference/classes/PoseEstimationModule.md#fromcustommodel) to load your own exported model binary instead of a built-in preset. You provide the keypoint map; `forward`'s return type is automatically derived from it, so each detected person is typed as a record keyed by the names you defined. + +```typescript +import { PoseEstimationModule } from 'react-native-executorch'; + +const HandKeypoints = { + WRIST: 0, + THUMB_TIP: 1, + INDEX_TIP: 2, + MIDDLE_TIP: 3, + RING_TIP: 4, + PINKY_TIP: 5, +} as const; + +const detector = await PoseEstimationModule.fromCustomModel( + 'https://example.com/custom_pose.pte', + { keypointMap: HandKeypoints }, + (progress) => console.log(progress) +); + +const detections = await detector.forward(imageUri); +detections[0].THUMB_TIP; // { x, y } +``` + +### Required model contract + +The `.pte` binary must expose a `forward` method (or per-input-size methods such as `forward_384`, `forward_512`, `forward_640` for multi-resolution models) with the following interface: + +**Input:** one `float32` tensor of shape `[1, 3, H, W]` — a single RGB image, values in `[0, 1]` after optional per-channel normalization `(pixel − mean) / std`. H and W are read from the model's declared input shape at load time. The mean and std vectors are supplied via `preprocessorConfig.normMean` and `preprocessorConfig.normStd` on the [`PoseEstimationConfig`](../../06-api-reference/type-aliases/PoseEstimationConfig.md) you pass to `fromCustomModel`; if omitted, the runtime feeds the resized image without normalization. + +**Outputs:** exactly three `float32` tensors, in this order: + +1. 
**Bounding boxes** — shape `[Q, 4]`, `(x1, y1, x2, y2)` per detection in model-input pixel space, where `Q` is the number of candidate detections. +2. **Confidence scores** — shape `[Q]`, person confidence in `[0, 1]`. +3. **Keypoints** — shape `[Q, K, 3]`, where `K` is the number of keypoints (must match the size of your `keypointMap`) and the last dimension is `(x, y, visibility)` per keypoint, in model-input pixel space. + +Preprocessing (resize → normalize) and postprocessing (coordinate rescaling, threshold filtering, mapping keypoints to your named keypoint map) are handled by the native runtime — your model only needs to produce the raw detections above. + +## Managing memory + +The module is a regular JavaScript object, and as such its lifespan will be managed by the garbage collector. In most cases this should be enough, and you should not worry about freeing the memory of the module yourself, but in some cases you may want to release the memory occupied by the module before the garbage collector steps in. In this case use the method [`delete`](../../06-api-reference/classes/PoseEstimationModule.md#delete) on the module object you will no longer use, and want to remove from the memory. Note that you cannot use [`forward`](../../06-api-reference/classes/PoseEstimationModule.md#forward) after [`delete`](../../06-api-reference/classes/PoseEstimationModule.md#delete) unless you load the module again. diff --git a/packages/react-native-executorch/common/rnexecutorch/RnExecutorchInstaller.cpp b/packages/react-native-executorch/common/rnexecutorch/RnExecutorchInstaller.cpp index 22add11719..53ee65a904 100644 --- a/packages/react-native-executorch/common/rnexecutorch/RnExecutorchInstaller.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/RnExecutorchInstaller.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #include #include #include @@ -74,6 +75,11 @@ void RnExecutorchInstaller::injectJSIBindings( models::object_detection::ObjectDetection>(jsiRuntime, jsCallInvoker, "loadObjectDetection")); + jsiRuntime->global().setProperty( + *jsiRuntime, "loadPoseEstimation", + RnExecutorchInstaller::loadModel( + jsiRuntime, jsCallInvoker, "loadPoseEstimation")); + jsiRuntime->global().setProperty( *jsiRuntime, "loadExecutorchModule", RnExecutorchInstaller::loadModel( diff --git a/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h b/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h index c9aca42491..a20fd7b1bc 100644 --- a/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h +++ b/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h @@ -20,6 +20,7 @@ #include #include #include +#include #include #include #include @@ -361,6 +362,30 @@ inline jsi::Value getJsiValue(const std::vector &vec, return {runtime, array}; } +inline jsi::Value getJsiValue( + const rnexecutorch::models::pose_estimation::PersonKeypoints &keypoints, + jsi::Runtime &runtime) { + jsi::Array array(runtime, keypoints.size()); + for (size_t i = 0; i < keypoints.size(); ++i) { + jsi::Object point(runtime); + point.setProperty(runtime, "x", keypoints[i].x); + point.setProperty(runtime, "y", keypoints[i].y); + array.setValueAtIndex(runtime, i, point); + } + return array; +} + +// Pose estimation: all detected people (vector of person keypoints) +inline jsi::Value getJsiValue( + const rnexecutorch::models::pose_estimation::PoseDetections &detections, + jsi::Runtime &runtime) { + jsi::Array 
array(runtime, detections.size()); + for (size_t i = 0; i < detections.size(); ++i) { + array.setValueAtIndex(runtime, i, getJsiValue(detections[i], runtime)); + } + return array; +} + // Conditional as on android, size_t and uint64_t reduce to the same type, // introducing ambiguity template +#include +#include +#include +#include +#include +#include +#include + +namespace rnexecutorch::models::pose_estimation { + +PoseEstimation::PoseEstimation(const std::string &modelSource, + std::vector normMean, + std::vector normStd, + std::shared_ptr callInvoker) + : VisionModel(modelSource, callInvoker) { + if (normMean.size() == 3) { + normMean_ = cv::Scalar(normMean[0], normMean[1], normMean[2]); + } else if (!normMean.empty()) { + log(LOG_LEVEL::Warn, + "normMean must have 3 elements — ignoring provided value."); + } + if (normStd.size() == 3) { + normStd_ = cv::Scalar(normStd[0], normStd[1], normStd[2]); + } else if (!normStd.empty()) { + log(LOG_LEVEL::Warn, + "normStd must have 3 elements — ignoring provided value."); + } +} + +PoseDetections PoseEstimation::postprocess(const std::vector &tensors, + cv::Size originalSize, + double detectionThreshold, + double keypointThreshold) { + // Output tensors (batch dim squeezed): + // 0: boxes (Q, 4) - xyxy bbox in model input pixel space + // 1: scores (Q,) - person confidence [0, 1] + // 2: keypoints (Q, K, 3) - per-detection keypoints (x, y, visibility) + + if (tensors.size() < 3) { + // TODO: maybe create a ContractNotMet error or something like this, this + // would also need to be applied for other models + return {}; + } + + auto scoresTensor = tensors[1].toTensor(); + auto keypointsTensor = tensors[2].toTensor(); + + const int32_t numKeypoints = static_cast(keypointsTensor.size(1)); + + const float *scores = scoresTensor.const_data_ptr(); + const float *kpData = keypointsTensor.const_data_ptr(); + + auto numDetections = scoresTensor.size(0); + + const auto &shape = modelInputShape_; + cv::Size modelInputSize(static_cast(shape[shape.size() - 1]), + static_cast(shape[shape.size() - 2])); + + float scaleX = static_cast(originalSize.width) / modelInputSize.width; + float scaleY = + static_cast(originalSize.height) / modelInputSize.height; + + PoseDetections allDetections; + + for (size_t i = 0; i < numDetections; ++i) { + if (scores[i] < detectionThreshold) { + continue; + } + + PersonKeypoints keypoints; + keypoints.reserve(numKeypoints); + + const float *detectionKps = kpData + i * numKeypoints * 3; + + for (size_t k = 0; k < numKeypoints; ++k) { + float visibility = detectionKps[k * 3 + 2]; + if (visibility < keypointThreshold) { + keypoints.emplace_back(-1, -1); + continue; + } + float x = detectionKps[k * 3]; + float y = detectionKps[k * 3 + 1]; + + int32_t scaledX = static_cast(std::round(x * scaleX)); + int32_t scaledY = static_cast(std::round(y * scaleY)); + + keypoints.emplace_back(scaledX, scaledY); + } + + allDetections.push_back(std::move(keypoints)); + } + + return allDetections; +} + +PoseDetections PoseEstimation::runInference(cv::Mat image, + double detectionThreshold, + double keypointThreshold, + const std::string &methodName) { + + log(LOG_LEVEL::Debug, "Running inference with model name: " + methodName); + + if (detectionThreshold < 0.0 || detectionThreshold > 1.0) { + throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput, + "detectionThreshold must be in range [0, 1]"); + } + if (keypointThreshold < 0.0 || keypointThreshold > 1.0) { + throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput, + 
"keypointThreshold must be in range [0, 1]"); + } + + std::scoped_lock lock(inference_mutex_); + cv::Size originalSize = image.size(); + auto inputShapes = getAllInputShapes(methodName); + if (inputShapes.empty() || inputShapes[0].size() < 2) { + throw RnExecutorchError(RnExecutorchErrorCode::UnexpectedNumInputs, + "Could not determine input shape for method: " + + methodName); + } + modelInputShape_ = inputShapes[0]; + cv::Mat resizedToModelInput = preprocess(image); + + auto inputTensor = + (normMean_ && normStd_) + ? image_processing::getTensorFromMatrix( + modelInputShape_, resizedToModelInput, *normMean_, *normStd_) + : image_processing::getTensorFromMatrix(modelInputShape_, + resizedToModelInput); + + auto executeResult = execute(methodName, {inputTensor}); + if (!executeResult.ok()) { + throw RnExecutorchError(executeResult.error(), + "The model's " + methodName + + " method did not succeed. " + "Ensure the model input is correct."); + } + + return postprocess(executeResult.get(), originalSize, detectionThreshold, + keypointThreshold); +} + +PoseDetections PoseEstimation::generateFromString(std::string imageSource, + double detectionThreshold, + double keypointThreshold, + std::string methodName) { + cv::Mat imageBGR = image_processing::readImage(imageSource); + cv::Mat imageRGB; + cv::cvtColor(imageBGR, imageRGB, cv::COLOR_BGR2RGB); + return runInference(std::move(imageRGB), detectionThreshold, + keypointThreshold, methodName); +} + +PoseDetections PoseEstimation::generateFromFrame(jsi::Runtime &runtime, + const jsi::Value &frameData, + double detectionThreshold, + double keypointThreshold, + std::string methodName) { + auto orient = ::rnexecutorch::utils::readFrameOrientation(runtime, frameData); + cv::Mat frame = extractFromFrame(runtime, frameData); + cv::Mat rotated = ::rnexecutorch::utils::rotateFrameForModel(frame, orient); + auto detections = + runInference(rotated, detectionThreshold, keypointThreshold, methodName); + for (auto &person : detections) { + ::rnexecutorch::utils::inverseRotatePoints(person, orient, rotated.size()); + } + return detections; +} + +PoseDetections PoseEstimation::generateFromPixels(JSTensorViewIn pixelData, + double detectionThreshold, + double keypointThreshold, + std::string methodName) { + cv::Mat image = extractFromPixels(pixelData); + return runInference(image, detectionThreshold, keypointThreshold, methodName); +} + +} // namespace rnexecutorch::models::pose_estimation diff --git a/packages/react-native-executorch/common/rnexecutorch/models/pose_estimation/PoseEstimation.h b/packages/react-native-executorch/common/rnexecutorch/models/pose_estimation/PoseEstimation.h new file mode 100644 index 0000000000..983519b34b --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/models/pose_estimation/PoseEstimation.h @@ -0,0 +1,49 @@ +#pragma once + +#include "rnexecutorch/metaprogramming/ConstructorHelpers.h" +#include "rnexecutorch/models/VisionModel.h" +#include "rnexecutorch/models/pose_estimation/Types.h" +#include +#include + +namespace rnexecutorch { +namespace models::pose_estimation { + +class PoseEstimation : public VisionModel { +public: + PoseEstimation(const std::string &modelSource, std::vector normMean, + std::vector normStd, + std::shared_ptr callInvoker); + + [[nodiscard("Registered non-void function")]] PoseDetections + generateFromString(std::string imageSource, double detectionThreshold, + double keypointThreshold, std::string methodName); + [[nodiscard("Registered non-void function")]] PoseDetections + 
generateFromFrame(jsi::Runtime &runtime, const jsi::Value &frameData, + double detectionThreshold, double keypointThreshold, + std::string methodName); + [[nodiscard("Registered non-void function")]] PoseDetections + generateFromPixels(JSTensorViewIn pixelData, double detectionThreshold, + double keypointThreshold, std::string methodName); + +private: + std::optional normMean_; + std::optional normStd_; + + [[nodiscard("Registered non-void function")]] + PoseDetections runInference(cv::Mat image, double detectionThreshold, + double keypointThreshold, + const std::string &modelName); + + [[nodiscard("Registered non-void function")]] + PoseDetections postprocess(const std::vector &evl, + cv::Size originalSize, double detectionThreshold, + double keypointThreshold); +}; + +} // namespace models::pose_estimation + +REGISTER_CONSTRUCTOR(models::pose_estimation::PoseEstimation, std::string, + std::vector, std::vector, + std::shared_ptr); +} // namespace rnexecutorch diff --git a/packages/react-native-executorch/common/rnexecutorch/models/pose_estimation/Types.h b/packages/react-native-executorch/common/rnexecutorch/models/pose_estimation/Types.h new file mode 100644 index 0000000000..7d671ab7bb --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/models/pose_estimation/Types.h @@ -0,0 +1,20 @@ +#pragma once + +#include +#include + +namespace rnexecutorch::models::pose_estimation { + +// Single keypoint (x, y) +struct Keypoint { + int32_t x; + int32_t y; +}; + +// N keypoints for one person, depending on the model in question +using PersonKeypoints = std::vector; + +// N people for each image +using PoseDetections = std::vector; + +} // namespace rnexecutorch::models::pose_estimation diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/CMakeLists.txt b/packages/react-native-executorch/common/rnexecutorch/tests/CMakeLists.txt index 8286518217..7edf9d8a7c 100644 --- a/packages/react-native-executorch/common/rnexecutorch/tests/CMakeLists.txt +++ b/packages/react-native-executorch/common/rnexecutorch/tests/CMakeLists.txt @@ -331,6 +331,17 @@ add_rn_test(InstanceSegmentationTests integration/InstanceSegmentationTest.cpp LIBS opencv_deps android ) +add_rn_test(PoseEstimationTests integration/PoseEstimationTest.cpp + SOURCES + ${RNEXECUTORCH_DIR}/models/pose_estimation/PoseEstimation.cpp + ${RNEXECUTORCH_DIR}/models/VisionModel.cpp + ${RNEXECUTORCH_DIR}/utils/FrameProcessor.cpp + ${RNEXECUTORCH_DIR}/utils/FrameExtractor.cpp + ${RNEXECUTORCH_DIR}/utils/FrameTransform.cpp + ${IMAGE_UTILS_SOURCES} + LIBS opencv_deps android +) + add_rn_test(OCRTests integration/OCRTest.cpp SOURCES ${RNEXECUTORCH_DIR}/models/ocr/OCR.cpp diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/README.md b/packages/react-native-executorch/common/rnexecutorch/tests/README.md index 1a35743df0..8a28b40032 100644 --- a/packages/react-native-executorch/common/rnexecutorch/tests/README.md +++ b/packages/react-native-executorch/common/rnexecutorch/tests/README.md @@ -69,5 +69,8 @@ To add new test you need to: LIBS opencv_deps ) ``` -* Lastly, add the test executable name to the run_tests script along with all the needed URL and assets. +* In `run_tests.sh`: + * Add the test executable name to `TEST_EXECUTABLES`. + * Add any models/files the test downloads at runtime to `MODELS` (filename + URL), **and** register every downloaded file the test loads in the `models_for_test()` case statement. 
The runner pushes only the files listed there from `$MODELS_DIR` to the device for that test, runs it, and removes them afterwards — anything missing won't be on the device when the test runs. Tests with no model dependencies don't need an entry. + * Repo-bundled fixtures (small images, audio, etc.) go in `TEST_ASSETS` instead. Those are pushed once up front and stay on the device; do not list them in `models_for_test()`. diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/integration/PoseEstimationTest.cpp b/packages/react-native-executorch/common/rnexecutorch/tests/integration/PoseEstimationTest.cpp new file mode 100644 index 0000000000..2e549bc304 --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/tests/integration/PoseEstimationTest.cpp @@ -0,0 +1,253 @@ +#include "BaseModelTests.h" +#include "VisionModelTests.h" +#include +#include +#include +#include +#include + +using namespace rnexecutorch; +using namespace rnexecutorch::models::pose_estimation; +using namespace model_tests; + +constexpr auto kValidPoseModelPath = "yolo26n-pose.pte"; +constexpr auto kValidTestImagePath = + "file:///data/local/tmp/rnexecutorch_tests/we_are_software_mansion.jpg"; +constexpr auto kMethodName = "forward_384"; + +// ============================================================================ +// Common tests via typed test suite +// ============================================================================ +namespace model_tests { +template <> struct ModelTraits { + using ModelType = PoseEstimation; + + static ModelType createValid() { + return ModelType(kValidPoseModelPath, {}, {}, nullptr); + } + + static ModelType createInvalid() { + return ModelType("nonexistent.pte", {}, {}, nullptr); + } + + static void callGenerate(ModelType &model) { + (void)model.generateFromString(kValidTestImagePath, 0.5, 0.5, kMethodName); + } +}; +} // namespace model_tests + +using PoseEstimationTypes = ::testing::Types; +INSTANTIATE_TYPED_TEST_SUITE_P(PoseEstimation, CommonModelTest, + PoseEstimationTypes); +INSTANTIATE_TYPED_TEST_SUITE_P(PoseEstimation, VisionModelTest, + PoseEstimationTypes); + +// ============================================================================ +// generateFromString — input path validity +// ============================================================================ +TEST(PoseEstimationGenerateTests, InvalidImagePathThrows) { + PoseEstimation model(kValidPoseModelPath, {}, {}, nullptr); + EXPECT_THROW((void)model.generateFromString("nonexistent_image.jpg", 0.5, 0.5, + kMethodName), + RnExecutorchError); +} + +TEST(PoseEstimationGenerateTests, EmptyImagePathThrows) { + PoseEstimation model(kValidPoseModelPath, {}, {}, nullptr); + EXPECT_THROW((void)model.generateFromString("", 0.5, 0.5, kMethodName), + RnExecutorchError); +} + +TEST(PoseEstimationGenerateTests, MalformedURIThrows) { + PoseEstimation model(kValidPoseModelPath, {}, {}, nullptr); + EXPECT_THROW((void)model.generateFromString("not_a_valid_uri://bad", 0.5, 0.5, + kMethodName), + RnExecutorchError); +} + +// ============================================================================ +// generateFromString — threshold range +// ============================================================================ +TEST(PoseEstimationGenerateTests, NegativeDetectionThresholdThrows) { + PoseEstimation model(kValidPoseModelPath, {}, {}, nullptr); + EXPECT_THROW((void)model.generateFromString(kValidTestImagePath, -0.1, 0.5, + kMethodName), + RnExecutorchError); +} + 
+TEST(PoseEstimationGenerateTests, DetectionThresholdAboveOneThrows) { + PoseEstimation model(kValidPoseModelPath, {}, {}, nullptr); + EXPECT_THROW((void)model.generateFromString(kValidTestImagePath, 1.1, 0.5, + kMethodName), + RnExecutorchError); +} + +TEST(PoseEstimationGenerateTests, NegativeKeypointThresholdThrows) { + PoseEstimation model(kValidPoseModelPath, {}, {}, nullptr); + EXPECT_THROW((void)model.generateFromString(kValidTestImagePath, 0.5, -0.1, + kMethodName), + RnExecutorchError); +} + +TEST(PoseEstimationGenerateTests, KeypointThresholdAboveOneThrows) { + PoseEstimation model(kValidPoseModelPath, {}, {}, nullptr); + EXPECT_THROW((void)model.generateFromString(kValidTestImagePath, 0.5, 1.1, + kMethodName), + RnExecutorchError); +} + +// ============================================================================ +// generateFromString — happy path & output shape +// ============================================================================ +TEST(PoseEstimationGenerateTests, ValidImageReturnsResults) { + PoseEstimation model(kValidPoseModelPath, {}, {}, nullptr); + auto results = + model.generateFromString(kValidTestImagePath, 0.3, 0.5, kMethodName); + EXPECT_GE(results.size(), 0u); +} + +TEST(PoseEstimationGenerateTests, HighThresholdReturnsFewerResults) { + PoseEstimation model(kValidPoseModelPath, {}, {}, nullptr); + auto lowThresholdResults = + model.generateFromString(kValidTestImagePath, 0.1, 0.5, kMethodName); + auto highThresholdResults = + model.generateFromString(kValidTestImagePath, 0.95, 0.5, kMethodName); + EXPECT_GE(lowThresholdResults.size(), highThresholdResults.size()); +} + +TEST(PoseEstimationGenerateTests, AllDetectionsHaveSameKeypointCount) { + PoseEstimation model(kValidPoseModelPath, {}, {}, nullptr); + auto results = + model.generateFromString(kValidTestImagePath, 0.1, 0.5, kMethodName); + if (results.size() < 2) { + GTEST_SKIP() << "Need at least 2 detections to compare keypoint counts"; + } + const size_t firstSize = results.front().size(); + EXPECT_GT(firstSize, 0u); + for (const auto &person : results) { + EXPECT_EQ(person.size(), firstSize); + } +} + +TEST(PoseEstimationGenerateTests, KeypointsHaveValidStructure) { + PoseEstimation model(kValidPoseModelPath, {}, {}, nullptr); + auto results = + model.generateFromString(kValidTestImagePath, 0.3, 0.5, kMethodName); + // Each detection must contain a non-zero number of keypoints, and each + // keypoint must be aggregate-initializable as { x, y } ints (compile-time). + for (const auto &person : results) { + EXPECT_GT(person.size(), 0u); + for (const auto &kp : person) { + // No range constraint here — out-of-bounds coords are valid model + // output for low-visibility keypoints; consumers filter on visibility. 
+ static_assert(std::is_same_v); + static_assert(std::is_same_v); + (void)kp; + } + } +} + +// ============================================================================ +// generateFromPixels +// ============================================================================ +TEST(PoseEstimationPixelTests, ValidPixelDataReturnsResults) { + PoseEstimation model(kValidPoseModelPath, {}, {}, nullptr); + constexpr int32_t width = 4, height = 4, channels = 3; + std::vector pixelData(width * height * channels, 128); + JSTensorViewIn tensorView{pixelData.data(), + {height, width, channels}, + executorch::aten::ScalarType::Byte}; + auto results = model.generateFromPixels(tensorView, 0.3, 0.5, kMethodName); + EXPECT_GE(results.size(), 0u); +} + +TEST(PoseEstimationPixelTests, NegativeThresholdThrows) { + PoseEstimation model(kValidPoseModelPath, {}, {}, nullptr); + constexpr int32_t width = 4, height = 4, channels = 3; + std::vector pixelData(width * height * channels, 128); + JSTensorViewIn tensorView{pixelData.data(), + {height, width, channels}, + executorch::aten::ScalarType::Byte}; + EXPECT_THROW( + (void)model.generateFromPixels(tensorView, -0.1, 0.5, kMethodName), + RnExecutorchError); +} + +TEST(PoseEstimationPixelTests, ThresholdAboveOneThrows) { + PoseEstimation model(kValidPoseModelPath, {}, {}, nullptr); + constexpr int32_t width = 4, height = 4, channels = 3; + std::vector pixelData(width * height * channels, 128); + JSTensorViewIn tensorView{pixelData.data(), + {height, width, channels}, + executorch::aten::ScalarType::Byte}; + EXPECT_THROW( + (void)model.generateFromPixels(tensorView, 1.1, 0.5, kMethodName), + RnExecutorchError); +} + +// ============================================================================ +// Method name +// ============================================================================ +TEST(PoseEstimationMethodTests, InvalidMethodNameThrows) { + PoseEstimation model(kValidPoseModelPath, {}, {}, nullptr); + EXPECT_THROW((void)model.generateFromString(kValidTestImagePath, 0.5, 0.5, + "forward_999"), + RnExecutorchError); +} + +TEST(PoseEstimationMethodTests, EmptyMethodNameThrows) { + PoseEstimation model(kValidPoseModelPath, {}, {}, nullptr); + EXPECT_THROW( + (void)model.generateFromString(kValidTestImagePath, 0.5, 0.5, ""), + RnExecutorchError); +} + +// ============================================================================ +// Normalisation params (constructor logs but does not throw) +// ============================================================================ +TEST(PoseEstimationNormTests, ValidNormParamsDoesntThrow) { + const std::vector mean = {0.485f, 0.456f, 0.406f}; + const std::vector std = {0.229f, 0.224f, 0.225f}; + EXPECT_NO_THROW(PoseEstimation(kValidPoseModelPath, mean, std, nullptr)); +} + +TEST(PoseEstimationNormTests, InvalidNormMeanSizeDoesntThrow) { + EXPECT_NO_THROW(PoseEstimation(kValidPoseModelPath, {0.5f}, + {0.229f, 0.224f, 0.225f}, nullptr)); +} + +TEST(PoseEstimationNormTests, InvalidNormStdSizeDoesntThrow) { + EXPECT_NO_THROW(PoseEstimation(kValidPoseModelPath, {0.485f, 0.456f, 0.406f}, + {0.5f}, nullptr)); +} + +TEST(PoseEstimationNormTests, ValidNormParamsGenerateSucceeds) { + const std::vector mean = {0.485f, 0.456f, 0.406f}; + const std::vector std = {0.229f, 0.224f, 0.225f}; + PoseEstimation model(kValidPoseModelPath, mean, std, nullptr); + EXPECT_NO_THROW((void)model.generateFromString(kValidTestImagePath, 0.5, 0.5, + kMethodName)); +} + +// ============================================================================ 
+// Inherited VisionModel methods +// ============================================================================ +TEST(PoseEstimationInheritedTests, GetInputShapeWorks) { + PoseEstimation model(kValidPoseModelPath, {}, {}, nullptr); + auto shape = model.getInputShape(kMethodName, 0); + EXPECT_EQ(shape.size(), 4); + EXPECT_EQ(shape[0], 1); + EXPECT_EQ(shape[1], 3); +} + +TEST(PoseEstimationInheritedTests, GetAllInputShapesWorks) { + PoseEstimation model(kValidPoseModelPath, {}, {}, nullptr); + auto shapes = model.getAllInputShapes(kMethodName); + EXPECT_FALSE(shapes.empty()); +} + +TEST(PoseEstimationInheritedTests, GetMethodMetaWorks) { + PoseEstimation model(kValidPoseModelPath, {}, {}, nullptr); + auto result = model.getMethodMeta(kMethodName); + EXPECT_TRUE(result.ok()); +} diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/run_tests.sh b/packages/react-native-executorch/common/rnexecutorch/tests/run_tests.sh index e60508ec39..0ec0677d5b 100755 --- a/packages/react-native-executorch/common/rnexecutorch/tests/run_tests.sh +++ b/packages/react-native-executorch/common/rnexecutorch/tests/run_tests.sh @@ -34,6 +34,7 @@ TEST_EXECUTABLES=( "LLMTests" "TextToImageTests" "InstanceSegmentationTests" + "PoseEstimationTests" "SemanticSegmentationTests" "OCRTests" "VerticalOCRTests" @@ -55,7 +56,7 @@ MODELS=( "style_transfer_candy_xnnpack_fp32.pte|https://huggingface.co/software-mansion/react-native-executorch-style-transfer-candy/resolve/main/xnnpack/style_transfer_candy_xnnpack_fp32.pte" "efficientnet_v2_s_xnnpack.pte|https://huggingface.co/software-mansion/react-native-executorch-efficientnet-v2-s/resolve/v0.6.0/xnnpack/efficientnet_v2_s_xnnpack.pte" "ssdlite320-mobilenetv3-large.pte|https://huggingface.co/software-mansion/react-native-executorch-ssdlite320-mobilenet-v3-large/resolve/v0.6.0/ssdlite320-mobilenetv3-large.pte" - "test_image.jpg|https://upload.wikimedia.org/wikipedia/commons/thumb/4/4d/Cat_November_2010-1a.jpg/1200px-Cat_November_2010-1a.jpg" + "test_image.jpg|https://upload.wikimedia.org/wikipedia/commons/f/f8/Cat_in_tree03.jpg" "clip-vit-base-patch32-vision_xnnpack.pte|https://huggingface.co/software-mansion/react-native-executorch-clip-vit-base-patch32/resolve/v0.6.0/clip-vit-base-patch32-vision_xnnpack.pte" "all-MiniLM-L6-v2_xnnpack.pte|https://huggingface.co/software-mansion/react-native-executorch-all-MiniLM-L6-v2/resolve/v0.6.0/all-MiniLM-L6-v2_xnnpack.pte" "tokenizer.json|https://huggingface.co/software-mansion/react-native-executorch-all-MiniLM-L6-v2/resolve/v0.6.0/tokenizer.json" @@ -81,6 +82,7 @@ MODELS=( "lfm2_vl_tokenizer_config.json|https://huggingface.co/software-mansion/react-native-executorch-lfm2.5-VL-1.6B/resolve/main/tokenizer_config.json" "yolo26n-seg.pte|https://huggingface.co/software-mansion/react-native-executorch-yolo26-seg/resolve/v0.8.0/yolo26n-seg/xnnpack/yolo26n-seg.pte" "segmentation_image.jpg|https://upload.wikimedia.org/wikipedia/commons/thumb/8/85/Collage_audi.jpg/1280px-Collage_audi.jpg" + "yolo26n-pose.pte|https://huggingface.co/software-mansion/react-native-executorch-yolo26-pose/resolve/v0.9.0/yolo26n/xnnpack/yolo26n-pose_xnnpack.pte" ) # ============================================================================ @@ -194,22 +196,23 @@ run_test() { # model dependencies. Adding a new test? Add its filenames below. 
models_for_test() { case "$1" in - BaseModelTests) echo "style_transfer_candy_xnnpack_fp32.pte" ;; - ClassificationTests) echo "efficientnet_v2_s_xnnpack.pte test_image.jpg" ;; - ObjectDetectionTests) echo "ssdlite320-mobilenetv3-large.pte test_image.jpg" ;; - ImageEmbeddingsTests) echo "clip-vit-base-patch32-vision_xnnpack.pte test_image.jpg" ;; - TextEmbeddingsTests) echo "all-MiniLM-L6-v2_xnnpack.pte tokenizer.json" ;; - StyleTransferTests) echo "style_transfer_candy_xnnpack_fp32.pte test_image.jpg" ;; - VADTests) echo "fsmn-vad_xnnpack.pte" ;; - TokenizerModuleTests) echo "tokenizer.json" ;; - SpeechToTextTests) echo "whisper_tiny_en_xnnpack.pte whisper_tokenizer.json" ;; - TextToSpeechTests) echo "kokoro_duration_predictor.pte kokoro_synthesizer.pte kokoro_af_heart.bin kokoro_us_lexicon.json kokoro_en_tagger.json" ;; - LLMTests) echo "smolLm2_135M_8da4w.pte smollm_tokenizer.json lfm2_5_vl_quantized_xnnpack_v2.pte lfm2_vl_tokenizer.json lfm2_vl_tokenizer_config.json test_image.jpg" ;; - TextToImageTests) echo "t2i_tokenizer.json t2i_encoder.pte t2i_unet.pte t2i_decoder.pte" ;; - InstanceSegmentationTests) echo "yolo26n-seg.pte segmentation_image.jpg" ;; - SemanticSegmentationTests) echo "deeplabV3_xnnpack_fp32.pte test_image.jpg" ;; - OCRTests | VerticalOCRTests) echo "xnnpack_craft_quantized.pte xnnpack_crnn_english.pte" ;; - *) echo "" ;; + BaseModelTests) echo "style_transfer_candy_xnnpack_fp32.pte" ;; + ClassificationTests) echo "efficientnet_v2_s_xnnpack.pte test_image.jpg" ;; + ObjectDetectionTests) echo "ssdlite320-mobilenetv3-large.pte test_image.jpg" ;; + ImageEmbeddingsTests) echo "clip-vit-base-patch32-vision_xnnpack.pte test_image.jpg" ;; + TextEmbeddingsTests) echo "all-MiniLM-L6-v2_xnnpack.pte tokenizer.json" ;; + StyleTransferTests) echo "style_transfer_candy_xnnpack_fp32.pte test_image.jpg" ;; + VADTests) echo "fsmn-vad_xnnpack.pte" ;; + TokenizerModuleTests) echo "tokenizer.json" ;; + SpeechToTextTests) echo "whisper_tiny_en_xnnpack.pte whisper_tokenizer.json" ;; + TextToSpeechTests) echo "kokoro_duration_predictor.pte kokoro_synthesizer.pte kokoro_af_heart.bin kokoro_us_lexicon.json kokoro_en_tagger.json" ;; + LLMTests) echo "smolLm2_135M_8da4w.pte smollm_tokenizer.json lfm2_5_vl_quantized_xnnpack_v2.pte lfm2_vl_tokenizer.json lfm2_vl_tokenizer_config.json test_image.jpg" ;; + TextToImageTests) echo "t2i_tokenizer.json t2i_encoder.pte t2i_unet.pte t2i_decoder.pte" ;; + InstanceSegmentationTests) echo "yolo26n-seg.pte segmentation_image.jpg" ;; + PoseEstimationTests) echo "yolo26n-pose.pte" ;; + SemanticSegmentationTests) echo "deeplabV3_xnnpack_fp32.pte test_image.jpg" ;; + OCRTests | VerticalOCRTests) echo "xnnpack_craft_quantized.pte xnnpack_crnn_english.pte" ;; + *) echo "" ;; esac } diff --git a/packages/react-native-executorch/common/rnexecutorch/utils/FrameTransform.cpp b/packages/react-native-executorch/common/rnexecutorch/utils/FrameTransform.cpp index 9a30b2c1a8..80425c2dab 100644 --- a/packages/react-native-executorch/common/rnexecutorch/utils/FrameTransform.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/utils/FrameTransform.cpp @@ -21,7 +21,14 @@ cv::Mat rotateFrameForModel(const cv::Mat &mat, cv::rotate(result, result, cv::ROTATE_90_CLOCKWISE); break; case Orientation::Right: +#if defined(__APPLE__) cv::rotate(result, result, cv::ROTATE_90_COUNTERCLOCKWISE); +#else + // Android front-cam in upright portrait reports orient=Right with + // isMirrored=true; the sensor mount needs CW (same as back-cam Left) + // to land upright for the model 
after the horizontal flip above. + cv::rotate(result, result, cv::ROTATE_90_CLOCKWISE); +#endif break; case Orientation::Down: cv::rotate(result, result, cv::ROTATE_180); @@ -50,13 +57,17 @@ void inverseRotateBbox(computer_vision::BBox &bbox, break; } case Orientation::Right: { - // upside-down portrait → portrait: nx = w - x, ny = h - y +#if defined(__APPLE__) + // iOS upside-down portrait → portrait: nx = w - x, ny = h - y float nx1 = w - bbox.x2, ny1 = h - bbox.y2; float nx2 = w - bbox.x1, ny2 = h - bbox.y1; bbox.x1 = nx1; bbox.y1 = ny1; bbox.x2 = nx2; bbox.y2 = ny2; +#endif + // Android front-cam upright portrait: rotated frame already in screen + // space, no inverse needed. break; } case Orientation::Down: { @@ -99,7 +110,12 @@ cv::Mat inverseRotateMat(const cv::Mat &mat, const FrameOrientation &orient) { cv::rotate(mat, result, cv::ROTATE_90_CLOCKWISE); break; case Orientation::Right: +#if defined(__APPLE__) cv::rotate(mat, result, cv::ROTATE_180); +#else + // Android front-cam upright portrait: mask already in screen space. + result = mat; +#endif break; case Orientation::Down: cv::rotate(mat, result, cv::ROTATE_90_COUNTERCLOCKWISE); diff --git a/packages/react-native-executorch/common/rnexecutorch/utils/FrameTransform.h b/packages/react-native-executorch/common/rnexecutorch/utils/FrameTransform.h index ed3fb124f4..8f9ca46cc2 100644 --- a/packages/react-native-executorch/common/rnexecutorch/utils/FrameTransform.h +++ b/packages/react-native-executorch/common/rnexecutorch/utils/FrameTransform.h @@ -1,9 +1,9 @@ #pragma once -#include #include #include #include +#include namespace rnexecutorch::utils { @@ -61,37 +61,57 @@ void inverseRotateBbox(computer_vision::BBox &bbox, cv::Mat inverseRotateMat(const cv::Mat &mat, const FrameOrientation &orient); /** - * @brief Map 4-point bbox from rotated-frame space back to screen space. + * @brief A 2D point with mutable arithmetic `x` and `y` members. * - * Inverse of rotateFrameForModel for 4-point bboxes. - * rotatedSize is the rotated frame size (rotated.size()). - * Templated on point type — requires P to have float x and y members. + * Satisfied by e.g. `cv::Point2f`, `cv::Point`, and any user-defined struct + * shaped `{ T x; T y; }` where `T` is arithmetic. */ template -void inverseRotatePoints(std::array &points, - const FrameOrientation &orient, cv::Size rotatedSize) { +concept Point2D = requires(P &p) { + requires std::is_arithmetic_v>; + requires std::is_arithmetic_v>; +}; + +/** + * @brief Map a sequence of points from rotated-frame space back to screen + * space. Inverse of rotateFrameForModel for a collection of points. + * + * Works on any iterable whose elements satisfy {@link Point2D} + * (e.g. `std::array`, `std::vector
`). + * rotatedSize is the rotated frame size (rotated.size()). + */ +template + requires Point2D +void inverseRotatePoints(Points &points, const FrameOrientation &orient, + cv::Size rotatedSize) { const float w = static_cast(rotatedSize.width); const float h = static_cast(rotatedSize.height); + using Coord = decltype(std::declval().begin()->x); + for (auto &p : points) { - float x = p.x; - float y = p.y; + float x = static_cast(p.x); + float y = static_cast(p.y); switch (orient.orientation) { case Orientation::Up: // landscape-left → portrait: nx = h-y, ny = x - p.x = h - y; - p.y = x; + p.x = static_cast(h - y); + p.y = static_cast(x); break; case Orientation::Right: - // upside-down portrait → portrait: nx = w-x, ny = h-y - p.x = w - x; - p.y = h - y; +#if defined(__APPLE__) + // iOS upside-down portrait → portrait: nx = w-x, ny = h-y + p.x = static_cast(w - x); + p.y = static_cast(h - y); +#endif + // Android front-cam upright portrait: rotated frame already in + // screen space (mirror-selfie portrait), no inverse needed. break; case Orientation::Down: // landscape-right → portrait: nx = y, ny = w-x - p.x = y; - p.y = w - x; + p.x = static_cast(y); + p.y = static_cast(w - x); break; case Orientation::Left: break; @@ -105,8 +125,8 @@ void inverseRotatePoints(std::array &points, float sw = swapped ? h : w; float sh = swapped ? w : h; for (auto &p : points) { - p.x = sw - p.x; - p.y = sh - p.y; + p.x = static_cast(sw - static_cast(p.x)); + p.y = static_cast(sh - static_cast(p.y)); } } #endif diff --git a/packages/react-native-executorch/src/constants/modelUrls.ts b/packages/react-native-executorch/src/constants/modelUrls.ts index 432f915eef..6fb20f9ca3 100644 --- a/packages/react-native-executorch/src/constants/modelUrls.ts +++ b/packages/react-native-executorch/src/constants/modelUrls.ts @@ -663,6 +663,17 @@ export const YOLO26X = { modelSource: YOLO26X_DETECTION_MODEL, } as const; +// YOLO26 Pose Estimation +const YOLO26N_POSE_MODEL = `${URL_PREFIX}-yolo26-pose/${NEXT_VERSION_TAG}/yolo26n/xnnpack/yolo26n-pose_xnnpack.pte`; + +/** + * @category Models - Pose Estimation + */ +export const YOLO26N_POSE = { + modelName: 'yolo26n-pose', + modelSource: YOLO26N_POSE_MODEL, +} as const; + // Style transfer const STYLE_TRANSFER_CANDY_MODEL = Platform.OS === `ios` diff --git a/packages/react-native-executorch/src/constants/poseEstimation.ts b/packages/react-native-executorch/src/constants/poseEstimation.ts new file mode 100644 index 0000000000..6d3929e8ef --- /dev/null +++ b/packages/react-native-executorch/src/constants/poseEstimation.ts @@ -0,0 +1,24 @@ +/** + * Standard COCO keypoint enum (17 keypoints). 
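A quick sketch of the access pattern this enum is meant to enable (hedged: `rawPerson` is a hypothetical array-form result as produced by the native module before the name mapping in PoseEstimationModule.ts further down; CocoKeypoint and Keypoint are exports added in this patch):

import { CocoKeypoint, type Keypoint } from 'react-native-executorch';

// Hypothetical raw result for one detected person: 17 keypoints in COCO
// order, so the numeric enum value doubles as the array index.
declare const rawPerson: Keypoint[];
const nose = rawPerson[CocoKeypoint.NOSE];
const leftAnkle = rawPerson[CocoKeypoint.LEFT_ANKLE];
// Once PoseEstimationModule maps the raw arrays, the same keypoints are
// reachable by name instead, e.g. person.NOSE and person.LEFT_ANKLE.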
+ * Use for type-safe keypoint access: `keypoints[CocoKeypoint.NOSE]` + * @category Types + */ +export enum CocoKeypoint { + NOSE = 0, + LEFT_EYE = 1, + RIGHT_EYE = 2, + LEFT_EAR = 3, + RIGHT_EAR = 4, + LEFT_SHOULDER = 5, + RIGHT_SHOULDER = 6, + LEFT_ELBOW = 7, + RIGHT_ELBOW = 8, + LEFT_WRIST = 9, + RIGHT_WRIST = 10, + LEFT_HIP = 11, + RIGHT_HIP = 12, + LEFT_KNEE = 13, + RIGHT_KNEE = 14, + LEFT_ANKLE = 15, + RIGHT_ANKLE = 16, +} diff --git a/packages/react-native-executorch/src/hooks/computer_vision/usePoseEstimation.ts b/packages/react-native-executorch/src/hooks/computer_vision/usePoseEstimation.ts new file mode 100644 index 0000000000..2eda27deaa --- /dev/null +++ b/packages/react-native-executorch/src/hooks/computer_vision/usePoseEstimation.ts @@ -0,0 +1,60 @@ +import { + PoseEstimationModule, + PoseEstimationKeypoints, +} from '../../modules/computer_vision/PoseEstimationModule'; +import { + PoseEstimationModelSources, + PoseEstimationProps, + PoseEstimationType, + PoseEstimationOptions, +} from '../../types/poseEstimation'; +import { PixelData } from '../../types/common'; +import { useModuleFactory } from '../useModuleFactory'; + +/** + * React hook for managing a Pose Estimation model instance. + * @typeParam C - A {@link PoseEstimationModelSources} config specifying which built-in model to load. + * @category Hooks + * @param props - Configuration object containing `model` config and optional `preventLoad` flag. + * @returns An object with model state (`error`, `isReady`, `isGenerating`, `downloadProgress`) and typed `forward` and `runOnFrame` functions. + */ +export const usePoseEstimation = ({ + model, + preventLoad = false, +}: PoseEstimationProps): PoseEstimationType< + PoseEstimationKeypoints +> => { + const { + error, + isReady, + isGenerating, + downloadProgress, + runForward, + runOnFrame, + instance, + } = useModuleFactory({ + factory: (config, onProgress) => + PoseEstimationModule.fromModelName(config, onProgress), + config: model, + deps: [model.modelName, model.modelSource], + preventLoad, + }); + + const forward = ( + input: string | PixelData, + options?: PoseEstimationOptions + ) => runForward((inst) => inst.forward(input, options)); + + const getAvailableInputSizes = () => + instance?.getAvailableInputSizes() ?? 
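    // Resolves to undefined until the model instance exists (or for models
    // without a fixed size list); for the built-in yolo26n-pose config this
    // surfaces [384, 512, 640] once the model is loaded.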
undefined; + + return { + error, + isReady, + isGenerating, + downloadProgress, + forward, + runOnFrame, + getAvailableInputSizes, + }; +}; diff --git a/packages/react-native-executorch/src/index.ts b/packages/react-native-executorch/src/index.ts index 7cc148d16b..96d167a7d2 100644 --- a/packages/react-native-executorch/src/index.ts +++ b/packages/react-native-executorch/src/index.ts @@ -58,6 +58,11 @@ declare global { normStd: Triple | [], labelNames: string[] ) => Promise; + var loadPoseEstimation: ( + source: string, + normMean: Triple | [], + normStd: Triple | [] + ) => Promise; var loadExecutorchModule: (source: string) => Promise; var loadTokenizerModule: (source: string) => Promise; var loadImageEmbeddings: (source: string) => Promise; @@ -124,6 +129,7 @@ if ( global.loadExecutorchModule == null || global.loadClassification == null || global.loadObjectDetection == null || + global.loadPoseEstimation == null || global.loadTokenizerModule == null || global.loadTextEmbeddings == null || global.loadImageEmbeddings == null || @@ -165,6 +171,7 @@ export * from './hooks/computer_vision/useOCR'; export * from './hooks/computer_vision/useVerticalOCR'; export * from './hooks/computer_vision/useImageEmbeddings'; export * from './hooks/computer_vision/useTextToImage'; +export * from './hooks/computer_vision/usePoseEstimation'; export * from './hooks/natural_language_processing/useLLM'; export * from './hooks/natural_language_processing/useSpeechToText'; @@ -186,6 +193,7 @@ export * from './modules/computer_vision/OCRModule'; export * from './modules/computer_vision/VerticalOCRModule'; export * from './modules/computer_vision/ImageEmbeddingsModule'; export * from './modules/computer_vision/TextToImageModule'; +export * from './modules/computer_vision/PoseEstimationModule'; export * from './modules/natural_language_processing/LLMModule'; export * from './modules/natural_language_processing/SpeechToTextModule'; @@ -223,6 +231,7 @@ export * from './types/classification'; export * from './types/imageEmbeddings'; export * from './types/styleTransfer'; export * from './types/tti'; +export * from './types/poseEstimation'; // constants export * from './constants/commonVision'; @@ -232,6 +241,7 @@ export * from './constants/ocr/models'; export * from './constants/tts/models'; export * from './constants/tts/voices'; export * from './constants/llmDefaults'; +export * from './constants/poseEstimation'; export { RnExecutorchError } from './errors/errorUtils'; export { RnExecutorchErrorCode } from './errors/ErrorCodes'; diff --git a/packages/react-native-executorch/src/modules/computer_vision/PoseEstimationModule.ts b/packages/react-native-executorch/src/modules/computer_vision/PoseEstimationModule.ts new file mode 100644 index 0000000000..ff2b68b1fd --- /dev/null +++ b/packages/react-native-executorch/src/modules/computer_vision/PoseEstimationModule.ts @@ -0,0 +1,328 @@ +import { + Frame, + LabelEnum, + PixelData, + ResourceSource, +} from '../../types/common'; +import { + Keypoint, + PersonKeypoints, + PoseDetections, + PoseEstimationOptions, + PoseEstimationModelSources, + PoseEstimationModelName, + PoseEstimationConfig, +} from '../../types/poseEstimation'; +import { RnExecutorchErrorCode } from '../../errors/ErrorCodes'; +import { RnExecutorchError } from '../../errors/errorUtils'; +import { VisionModule } from './VisionModule'; +import { fetchModelPath } from './VisionLabeledModule'; +import { CocoKeypoint } from '../../constants/poseEstimation'; +import { ResolveConfigOrType } from 
'../../types/computerVision'; + +const YOLO_POSE_CONFIG = { + keypointMap: CocoKeypoint, + preprocessorConfig: undefined, + availableInputSizes: [384, 512, 640] as const, + defaultInputSize: 384, + defaultDetectionThreshold: 0.5, + defaultKeypointThreshold: 0.5, +} satisfies PoseEstimationConfig; + +const ModelConfigs = { + 'yolo26n-pose': YOLO_POSE_CONFIG, +} as const satisfies Record< + PoseEstimationModelName, + PoseEstimationConfig +>; + +type ModelConfigsType = typeof ModelConfigs; + +/** + * Resolves the {@link LabelEnum} for a given built-in pose estimation model name. + * @typeParam M - A built-in model name from {@link PoseEstimationModelName}. + * @category Types + */ +export type PoseEstimationKeypoints = + (typeof ModelConfigs)[M]['keypointMap']; + +type ModelNameOf = C['modelName']; + +/** @internal */ +type ResolveKeypoints = + ResolveConfigOrType; + +function mapPersonKeypoints( + raw: Keypoint[][], + entries: [string, number][], + maxIndex: number +): PersonKeypoints[] { + 'worklet'; + if (raw.length > 0 && raw[0]!.length <= maxIndex) { + throw new Error( + `Keypoint map references index ${maxIndex} but model returned ${raw[0]!.length} keypoints per person — keypointMap is incompatible with this model.` + ); + } + const out: PersonKeypoints[] = []; + for (const person of raw) { + const named: Record = {}; + for (const [name, idx] of entries) named[name] = person[idx]!; + out.push(named as PersonKeypoints); + } + return out; +} + +/** + * Pose estimation module for detecting human body keypoints. + * @typeParam T - Either a built-in model name (e.g. `'yolo26n-pose'`) + * or a custom {@link LabelEnum} keypoint map. + * @category Typescript API + */ +export class PoseEstimationModule< + T extends PoseEstimationModelName | LabelEnum, +> extends VisionModule>> { + private readonly keypointMap: ResolveKeypoints; + private readonly modelConfig: PoseEstimationConfig; + // Numeric TS enums double-list + // their keys at runtime (value → name); we keep only the (name, index) pairs + private readonly keypointEntries: [string, number][]; + private readonly maxKeypointIndex: number; + + private constructor( + keypointMap: ResolveKeypoints, + modelConfig: PoseEstimationConfig, + nativeModule: unknown + ) { + super(); + this.keypointMap = keypointMap; + this.modelConfig = modelConfig; + this.nativeModule = nativeModule; + this.keypointEntries = []; + for (const [name, value] of Object.entries(keypointMap)) { + if (typeof value === 'number') this.keypointEntries.push([name, value]); + } + this.maxKeypointIndex = Math.max(...this.keypointEntries.map(([, v]) => v)); + } + + /** + * Creates a pose estimation instance for a built-in model. + * @param namedSources - A {@link PoseEstimationModelSources} object specifying which model to load. + * @param onDownloadProgress - Optional callback to monitor download progress (0-1). + * @returns A Promise resolving to a `PoseEstimationModule` instance typed to the model's keypoint map. + */ + static async fromModelName( + namedSources: C, + onDownloadProgress: (progress: number) => void = () => {} + ): Promise>> { + const { modelSource } = namedSources; + const modelConfig = ModelConfigs[ + namedSources.modelName + ] as PoseEstimationConfig; + const { keypointMap, preprocessorConfig } = modelConfig; + const normMean = preprocessorConfig?.normMean ?? []; + const normStd = preprocessorConfig?.normStd ?? 
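    // The built-in YOLO pose config defines no preprocessorConfig, so both
    // values fall back to empty arrays and are forwarded unchanged to
    // global.loadPoseEstimation below.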
[]; + + const modelPath = await fetchModelPath(modelSource, onDownloadProgress); + const nativeModule = await global.loadPoseEstimation( + modelPath, + normMean, + normStd + ); + + return new PoseEstimationModule>( + keypointMap as ResolveKeypoints>, + modelConfig, + nativeModule + ); + } + + /** + * Creates a pose estimation instance with a user-provided model binary and keypoint map. + * Use this when working with a custom-exported model that is not one of the built-in presets. + * @param modelSource - A fetchable resource pointing to the model binary. + * @param config - A {@link PoseEstimationConfig} object with the keypoint map and optional preprocessing parameters. + * @param onDownloadProgress - Optional callback to monitor download progress (0-1). + * @returns A Promise resolving to a `PoseEstimationModule` instance typed to the provided keypoint map. + */ + static async fromCustomModel( + modelSource: ResourceSource, + config: PoseEstimationConfig, + onDownloadProgress: (progress: number) => void = () => {} + ): Promise> { + const { keypointMap, preprocessorConfig } = config; + const normMean = preprocessorConfig?.normMean ?? []; + const normStd = preprocessorConfig?.normStd ?? []; + + const modelPath = await fetchModelPath(modelSource, onDownloadProgress); + const nativeModule = await global.loadPoseEstimation( + modelPath, + normMean, + normStd + ); + + return new PoseEstimationModule( + keypointMap as ResolveKeypoints, + config, + nativeModule + ); + } + + /** + * Get the keypoint map for this model. + * @returns Map of keypoint names to indices, e.g. `{ NOSE: 0, LEFT_EYE: 1, ... }`. + */ + getKeypointMap(): ResolveKeypoints { + return this.keypointMap; + } + + /** + * Returns the available input sizes for this model, or undefined if the model accepts any size. + * @returns a readonly number[] specifying what input sizes the model supports. + */ + getAvailableInputSizes(): readonly number[] | undefined { + return this.modelConfig.availableInputSizes; + } + + /** + * Override runOnFrame to provide an options-based API for VisionCamera integration. + * @returns A worklet function for frame processing. + */ + override get runOnFrame(): ( + frame: Frame, + isFrontCamera: boolean, + options?: PoseEstimationOptions + ) => PoseDetections> { + if (!this.nativeModule) { + throw new RnExecutorchError( + RnExecutorchErrorCode.ModuleNotLoaded, + 'Model is not loaded. Ensure the model has been loaded before using runOnFrame.' + ); + } + + const nativeGenerateFromFrame = this.nativeModule.generateFromFrame; + const defaultDetectionThreshold = + this.modelConfig.defaultDetectionThreshold ?? 0.5; + const defaultKeypointThreshold = + this.modelConfig.defaultKeypointThreshold ?? 0.5; + const defaultInputSize = this.modelConfig.defaultInputSize; + const availableInputSizes = this.modelConfig.availableInputSizes; + const keypointEntries = this.keypointEntries; + const maxKeypointIndex = this.maxKeypointIndex; + return ( + frame: Frame, + isFrontCamera: boolean, + options?: PoseEstimationOptions + ): PoseDetections> => { + 'worklet'; + + const detectionThreshold = + options?.detectionThreshold ?? defaultDetectionThreshold; + const keypointThreshold = + options?.keypointThreshold ?? defaultKeypointThreshold; + const inputSize = options?.inputSize ?? defaultInputSize; + + // Validate inputSize + if ( + availableInputSizes && + inputSize !== undefined && + !availableInputSizes.includes(inputSize) + ) { + throw new Error( + `Invalid inputSize: ${inputSize}. 
Available sizes: ${availableInputSizes.join(', ')}` + ); + } + + const methodName = + inputSize !== undefined ? `forward_${inputSize}` : 'forward'; + + let nativeBuffer: { pointer: bigint; release(): void } | null = null; + try { + nativeBuffer = frame.getNativeBuffer(); + const frameData = { + nativeBuffer: nativeBuffer.pointer, + orientation: frame.orientation, + isMirrored: isFrontCamera, + }; + const raw: Keypoint[][] = nativeGenerateFromFrame( + frameData, + detectionThreshold, + keypointThreshold, + methodName + ); + return mapPersonKeypoints>( + raw, + keypointEntries, + maxKeypointIndex + ); + } finally { + if (nativeBuffer?.release) { + nativeBuffer.release(); + } + } + }; + } + + /** + * Run pose estimation on an image. + * @param input - Image path/URI or PixelData + * @param options - Detection options including inputSize for multi-method models + * @returns Array of detected people, each with keypoints accessible via the keypoint enum + */ + override async forward( + input: string | PixelData, + options?: PoseEstimationOptions + ): Promise>> { + if (this.nativeModule == null) { + throw new RnExecutorchError( + RnExecutorchErrorCode.ModuleNotLoaded, + 'Model not loaded. Please load the model before calling forward().' + ); + } + + const detectionThreshold = + options?.detectionThreshold ?? + this.modelConfig.defaultDetectionThreshold ?? + 0.5; + const keypointThreshold = + options?.keypointThreshold ?? + this.modelConfig.defaultKeypointThreshold ?? + 0.5; + const inputSize = options?.inputSize ?? this.modelConfig.defaultInputSize; + + // Validate inputSize against availableInputSizes + if ( + this.modelConfig.availableInputSizes && + inputSize !== undefined && + !this.modelConfig.availableInputSizes.includes(inputSize) + ) { + throw new RnExecutorchError( + RnExecutorchErrorCode.InvalidArgument, + `Invalid inputSize: ${inputSize}. Available sizes: ${this.modelConfig.availableInputSizes.join(', ')}` + ); + } + + const methodName = + inputSize !== undefined ? `forward_${inputSize}` : 'forward'; + + const raw: Keypoint[][] = + typeof input === 'string' + ? await this.nativeModule.generateFromString( + input, + detectionThreshold, + keypointThreshold, + methodName + ) + : await this.nativeModule.generateFromPixels( + input, + detectionThreshold, + keypointThreshold, + methodName + ); + + return mapPersonKeypoints>( + raw, + this.keypointEntries, + this.maxKeypointIndex + ); + } +} diff --git a/packages/react-native-executorch/src/types/computerVision.ts b/packages/react-native-executorch/src/types/computerVision.ts index a5d1dee7b2..da15999100 100644 --- a/packages/react-native-executorch/src/types/computerVision.ts +++ b/packages/react-native-executorch/src/types/computerVision.ts @@ -1,5 +1,18 @@ import { LabelEnum } from './common'; +/* + * Automatically resolves the type to either Configs[NameOrType][OutputKey], if the NameOrType + * is a key of Configs. Otherwise, returns NameOrType. + * @internal + */ +export type ResolveConfigOrType< + NameOrType, + Configs extends Record>, + OutputKey extends string = 'output', +> = NameOrType extends keyof Configs + ? Configs[NameOrType][OutputKey] + : NameOrType; + /** * Given a model configs record (mapping model names to `{ labelMap }`) and a * type `T` (either a model name key or a raw {@link LabelEnum}), resolves to @@ -7,10 +20,6 @@ import { LabelEnum } from './common'; * @internal */ export type ResolveLabels< - T, + NameOrLabels, Configs extends Record, -> = T extends keyof Configs - ? 
Configs[T]['labelMap'] - : T extends LabelEnum - ? T - : never; +> = ResolveConfigOrType; diff --git a/packages/react-native-executorch/src/types/poseEstimation.ts b/packages/react-native-executorch/src/types/poseEstimation.ts new file mode 100644 index 0000000000..03afc592c3 --- /dev/null +++ b/packages/react-native-executorch/src/types/poseEstimation.ts @@ -0,0 +1,159 @@ +import { Frame, LabelEnum, PixelData, ResourceSource } from './common'; +import { CocoKeypoint } from '../constants/poseEstimation'; +import { RnExecutorchError } from '../errors/errorUtils'; + +export { CocoKeypoint }; + +/** + * A single keypoint with x, y coordinates + * @category Types + */ +export interface Keypoint { + x: number; + y: number; +} + +/** + * Keypoints for a single detected person, keyed by name from the keypoint map. + * @typeParam K - The {@link LabelEnum} for this model. + * @category Types + * @example + * ```ts + * person.NOSE; // { x, y } + * ``` + */ +export type PersonKeypoints = { + readonly [Name in keyof K]: Keypoint; +}; + +/** + * Pose estimation result containing all detected people. + * @category Types + */ +export type PoseDetections = + PersonKeypoints[]; + +/** + * Configuration for pose estimation model behavior. + * @category Types + * @typeParam K - The keypoint enum type for this model. + */ +export type PoseEstimationConfig = { + keypointMap: K; + preprocessorConfig?: { + normMean?: readonly [number, number, number]; + normStd?: readonly [number, number, number]; + }; + defaultDetectionThreshold?: number; + defaultKeypointThreshold?: number; +} & ( + | { + availableInputSizes: readonly number[]; + defaultInputSize: number; + } + | { + availableInputSizes?: undefined; + defaultInputSize?: undefined; + } +); + +/** + * Per-model config for {@link PoseEstimationModule.fromModelName}. + * Each model name maps to its required fields. + * @category Types + */ +export type PoseEstimationModelSources = { + modelName: 'yolo26n-pose'; + modelSource: ResourceSource; +}; + +/** + * Union of all built-in pose estimation model names. + * @category Types + */ +export type PoseEstimationModelName = PoseEstimationModelSources['modelName']; + +/** + * Props for usePoseEstimation hook. + * @typeParam C - A {@link PoseEstimationModelSources} config specifying which built-in model to load. + * @category Types + */ +export interface PoseEstimationProps { + model: C; + preventLoad?: boolean; +} + +/** + * Options for pose estimation inference + * @category Types + */ +export interface PoseEstimationOptions { + detectionThreshold?: number; + /** + * Per-keypoint visibility threshold (0-1). Keypoints whose visibility + * score is below this are emitted as (-1, -1) so consumers can skip them. + * Defaults to the model config's `defaultKeypointThreshold` (typically 0.5). + */ + keypointThreshold?: number; + /** + * Input size for multi-method models. + * For YOLO models, valid values are typically 384, 512, or 640. + * Maps to forward_384, forward_512, forward_640 methods. + */ + inputSize?: number; +} + +/** + * Return type of usePoseEstimation hook. + * @typeParam K - The {@link LabelEnum} representing the model's keypoint schema. + * @category Types + */ +export interface PoseEstimationType { + /** + * Contains the error object if the model failed to load or encountered a runtime error. + */ + error: RnExecutorchError | null; + + /** + * Indicates whether the model is loaded and ready to process images. + */ + isReady: boolean; + + /** + * Indicates whether the model is currently processing an image. 
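To make the keypointThreshold behaviour described above concrete, a minimal usage sketch of the imperative API (the image path is hypothetical; YOLO26N_POSE, the forward options, and the (-1, -1) sentinel are from this patch):

import { PoseEstimationModule, YOLO26N_POSE } from 'react-native-executorch';

async function logVisibleKeypoints() {
  const pose = await PoseEstimationModule.fromModelName(YOLO26N_POSE);
  const people = await pose.forward('file:///tmp/person.jpg', {
    keypointThreshold: 0.6,
    inputSize: 384,
  });
  for (const person of people) {
    // Keypoints below the visibility threshold come back as (-1, -1),
    // so drop them before drawing.
    const visible = Object.entries(person).filter(
      ([, kp]) => kp.x >= 0 && kp.y >= 0
    );
    console.log(`visible keypoints: ${visible.length} / 17`);
  }
}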
+ */ + isGenerating: boolean; + + /** + * Represents the download progress of the model binary as a value between 0 and 1. + */ + downloadProgress: number; + + /** + * Run pose estimation on an image. + * @param input - Image path/URI or PixelData + * @param options - Detection options + * @returns Array of detected people, each with keypoints accessible via the keypoint enum + */ + forward: ( + input: string | PixelData, + options?: PoseEstimationOptions + ) => Promise>; + + /** + * Returns the available input sizes for multi-method models. + * Returns undefined for single-method models. + */ + getAvailableInputSizes: () => readonly number[] | undefined; + + /** + * Synchronous worklet function for real-time VisionCamera frame processing. + */ + runOnFrame: + | (( + frame: Frame, + isFrontCamera: boolean, + options?: PoseEstimationOptions + ) => PoseDetections) + | null; +} diff --git a/packages/react-native-executorch/src/utils/ResourceFetcherUtils.ts b/packages/react-native-executorch/src/utils/ResourceFetcherUtils.ts index 9645afbaa9..46f4b34e2d 100644 --- a/packages/react-native-executorch/src/utils/ResourceFetcherUtils.ts +++ b/packages/react-native-executorch/src/utils/ResourceFetcherUtils.ts @@ -150,7 +150,7 @@ export namespace ResourceFetcherUtils { /** * Checks whether the given URL conforms to the huggingface.co/software-mansion schema. * @param url - the URL to the remote file - * @returns {boolean} Boolean specifying whether the given URL conforms to our HF repo schema + * @returns Boolean specifying whether the given URL conforms to our HF repo schema */ export function isUrlHfRepo(url: URL): boolean { return (