Skip to content

[CUDA] Support volumetric (3-D) grid sampling in the CUDA GridSample operator#27201

Merged
tianleiwu merged 6 commits intomainfrom
hari/5d_grid_sample
Mar 13, 2026
Merged

[CUDA] Support volumetric (3-D) grid sampling in the CUDA GridSample operator#27201
tianleiwu merged 6 commits intomainfrom
hari/5d_grid_sample

Conversation

@hariharans29
Copy link
Member

@hariharans29 hariharans29 commented Jan 29, 2026

Description

  1. Supports volumetric input grid sampling in the CUDA EP GridSample operator (i.e.) 5-D input tensor a.k.a 3-D spatial data
  2. Registers the CUDA GridSample operator for opsets 20 and 22
  3. Supports both NCHW and NHWC layouts for volumetric inputs
  4. Does not support cubic mode for volumetric inputs for now and this is consistent with the CPU version of the implementation and hence will not cause "functional regression" (i.e.) cubic mode for 3-D spatial data is not supported on CPU and CUDA before and after this change. This is a TODO for the future.
  5. There are enough unit tests in grid_sample_test.cc to cover the volumetric input case and this is run in both NCHW (NCDHW for volumetric case) and NHWC (NDHWC for volumetric case) layouts for the CUDA EP

Motivation and Context

Resolve #21382
Resolve #18942
Resolve #16581
Resolve #18313

Related CPU PRs (for opset 20 and opset 22): #17744 && #23344

Copy link
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Adds CUDA Execution Provider support for volumetric (5-D / 3-D spatial) GridSample, aligning CUDA behavior with ONNX opset semantics (including opsets 20 and 22) and supporting both NCHW and NHWC layouts.

Changes:

  • Implement 3D (volumetric) CUDA kernel path for GridSample and wire it into the existing CUDA operator.
  • Register CUDA GridSample kernels for ONNX opsets 20–21 and 22 (and NHWC variants where applicable).
  • Update grid sample tests/execution provider selection to exercise CUDA for newer opsets.

Reviewed changes

Copilot reviewed 8 out of 8 changed files in this pull request and generated 2 comments.

Show a summary per file
File Description
onnxruntime/test/providers/cpu/tensor/grid_sample_test.cc Updates test execution providers and adds/adjusts opset coverage for GridSample cases.
onnxruntime/core/providers/cuda/tensor/grid_sample_impl.h Declares the new 3D CUDA implementation entry point.
onnxruntime/core/providers/cuda/tensor/grid_sample_impl.cu Implements the 3D CUDA grid sampling kernel + host launcher; minor 2D fixups/comments.
onnxruntime/core/providers/cuda/tensor/grid_sample.h Tracks opset version in the CUDA kernel class.
onnxruntime/core/providers/cuda/tensor/grid_sample.cc Adds opset-aware attribute parsing, 4D/5D validation, and dispatch to 2D vs 3D CUDA impl; registers opset 20/22 kernels.
onnxruntime/core/providers/cuda/cuda_nhwc_kernels.cc Registers NHWC GridSample kernels for opset ranges 16–19, 20–21, and 22.
onnxruntime/core/providers/cuda/cuda_execution_provider.cc Registers ONNX-domain CUDA GridSample kernels for opset ranges 16–19, 20–21, and 22.
onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc Adds clarifying comment about preserving batch dim in perm generation.
Comments suppressed due to low confidence (1)

onnxruntime/test/providers/cpu/tensor/grid_sample_test.cc:734

  • Test name indicates a 5-D case, but X_shape and Grid_shape are 4-D in this test. This makes the generated test suite confusing and can hide missing 5-D coverage. Either rename the test back to ..._20_4D_... or update the shapes/data to be truly 5-D.
TYPED_TEST(GridSampleTest, test_grid_sample_20_5D_bilinear_border_align_corners) {
  OpTester test("GridSample", 20);
  std::string mode = "linear";
  std::string padding_mode = "border";
  int64_t align_corners = 1;
  std::initializer_list<int64_t> X_shape{2, 2, 3, 2};
  std::initializer_list<TypeParam> X_data{TypeParam(-1.916003f), TypeParam(0.150784f), TypeParam(-0.179898f), TypeParam(0.402727f), TypeParam(-0.549764f), TypeParam(1.772484f), TypeParam(1.014343f), TypeParam(0.502823f), TypeParam(0.976771f), TypeParam(-0.071957f), TypeParam(0.519875f), TypeParam(0.408665f), TypeParam(1.435640f), TypeParam(-0.807775f), TypeParam(-0.181661f), TypeParam(-0.574026f), TypeParam(-0.335351f), TypeParam(-0.155602f), TypeParam(0.348749f), TypeParam(1.055618f), TypeParam(0.737784f), TypeParam(-0.394725f), TypeParam(0.597608f), TypeParam(0.006105f)};
  std::initializer_list<int64_t> Grid_shape{2, 3, 2, 2};
  std::initializer_list<TypeParam> Grid_data{TypeParam(-0.189838f), TypeParam(-1.050410f), TypeParam(-1.072351f), TypeParam(-0.930754f), TypeParam(-0.502573f), TypeParam(0.186642f), TypeParam(-0.564332f), TypeParam(-0.042774f), TypeParam(-0.143740f), TypeParam(1.097448f), TypeParam(-0.547044f), TypeParam(1.127440f), TypeParam(-0.921224f), TypeParam(-1.001202f), TypeParam(0.390232f), TypeParam(-0.698394f), TypeParam(0.615509f), TypeParam(-0.663897f), TypeParam(0.944958f), TypeParam(1.161950f), TypeParam(0.076823f), TypeParam(0.256464f), TypeParam(1.118784f), TypeParam(0.711380f)};

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

@tianleiwu
Copy link
Contributor

Overall, LGTM.

Below are some areas need attention:

  1. Test Name vs Data Mismatch - Review all renamed tests to ensure test names accurately reflect the test data dimensions (4D vs 5D). For example Copilot mentioned that test_grid_sample_20_5D_bilinear_border_align_corners indicates a 5-D case, but X_shape and Grid_shape are 4-D. This is confusing and may hide missing 5-D coverage.
  2. Performance - Consider adding performance benchmarks for the new 3D kernel path
  3. Edge Cases - Verify behavior at boundary conditions for 3D grid sampling

@tianleiwu tianleiwu enabled auto-merge (squash) March 13, 2026 02:39
@tianleiwu tianleiwu merged commit d8c1826 into main Mar 13, 2026
88 of 97 checks passed
@tianleiwu tianleiwu deleted the hari/5d_grid_sample branch March 13, 2026 02:41
tianleiwu added a commit that referenced this pull request Mar 13, 2026
…sion.py (#27642)

# Description

This PR addresses a build error and subsequent test failures related to
recent changes in GridSample and the transformer optimizer. Related PRs:
#27201, #27556.

## Changes

### 1. Fix GridSample Build Error
- Removed an unused local variable `mode_str` in
`onnxruntime/core/providers/cuda/tensor/grid_sample.cc` that was causing
a warning (treated as error) about shadowing a member variable.
- Ref:
[`grid_sample.cc`](https://github.com/microsoft/onnxruntime/blob/c979a2407f/onnxruntime/core/providers/cuda/tensor/grid_sample.cc#L54)

### 2. Update GridSample Tests
- Updated
`onnxruntime/test/providers/cpu/tensor/grid_sample_test_custom.inc` to
use default execution providers in `RunTests` instead of a hardcoded
opset version, ensuring compatibility across different environments.

### 3. Revert Transformer Fusion Fallback
- Reverted a recent change in
`onnxruntime/python/tools/transformers/fusion_skiplayernorm.py` that
enabled a fallback for `SkipLayerNormalization` fusion when symbolic
shape inference fails.
- This revert was necessary to avoid regressions in GPT-2 tests where
model definitions contain typos that intentionally (or coincidentally)
break shape inference.
- Ref:
[`fusion_skiplayernorm.py`](https://github.com/microsoft/onnxruntime/blob/c979a2407f/onnxruntime/python/tools/transformers/fusion_skiplayernorm.py#L113)

### 4. Restore Transformer Test Parity
- Updated
`onnxruntime/test/python/transformers/test_attention_fusion.py`
specifically `test_qwen3_normalization_fusion` to match the expected
node counts after reverting the fusion fallback.
- Ref:
[`test_attention_fusion.py`](https://github.com/microsoft/onnxruntime/blob/c979a2407f/onnxruntime/test/python/transformers/test_attention_fusion.py#L398)

## Verification

- `build_cuda.sh` completed successfully.
- `onnxruntime/test/python/transformers/test_attention_fusion.py` passes
with "OK".
- `lintrunner -a` reports no issues.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

3 participants