diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorShape.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorShape.cs index 2f886acc65240d..577656f536c156 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorShape.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorShape.cs @@ -853,12 +853,22 @@ public static TensorShape Create(T[]? array) if (array is not null) { int linearLength = array.Length; + nint stride = 1; + + TensorFlags flags = TensorFlags.IsDense | TensorFlags.HasAnyDenseDimensions; + + if (linearLength <= 1) + { + stride = 0; + flags |= TensorFlags.IsBroadcast; + } + return new TensorShape( flattenedLength: linearLength, linearLength: linearLength, lengths: [linearLength], - strides: [1], - TensorFlags.IsDense | TensorFlags.HasAnyDenseDimensions + strides: [stride], + flags ); } return default; @@ -908,14 +918,22 @@ public static TensorShape Create(ref readonly T reference, nint linearLength, { if (!Unsafe.IsNullRef(in reference)) { + nint stride = 1; + TensorFlags flags = pinned ? TensorFlags.IsPinned : TensorFlags.None; flags |= TensorFlags.IsDense | TensorFlags.HasAnyDenseDimensions; + if (linearLength <= 1) + { + stride = 0; + flags |= TensorFlags.IsBroadcast; + } + return new TensorShape( flattenedLength: linearLength, linearLength: linearLength, lengths: [linearLength], - strides: [1], + strides: [stride], flags ); } diff --git a/src/libraries/System.Numerics.Tensors/tests/ReadOnlyTensorSpanTests.cs b/src/libraries/System.Numerics.Tensors/tests/ReadOnlyTensorSpanTests.cs index 3755acecd5f356..0490d75771c061 100644 --- a/src/libraries/System.Numerics.Tensors/tests/ReadOnlyTensorSpanTests.cs +++ b/src/libraries/System.Numerics.Tensors/tests/ReadOnlyTensorSpanTests.cs @@ -255,7 +255,7 @@ public static void ReadOnlyTensorSpanArrayConstructorTests() Assert.Equal(1, spanInt.Rank); Assert.Equal(0, spanInt.Lengths[0]); Assert.Equal(0, spanInt.FlattenedLength); - Assert.Equal(1, spanInt.Strides[0]); + Assert.Equal(0, spanInt.Strides[0]); // Make sure it still throws on index 0 Assert.Throws(() => { var spanInt = new ReadOnlyTensorSpan(b); diff --git a/src/libraries/System.Numerics.Tensors/tests/TensorSpanTests.cs b/src/libraries/System.Numerics.Tensors/tests/TensorSpanTests.cs index 77b295a8c36f9d..f1a1c736e94e16 100644 --- a/src/libraries/System.Numerics.Tensors/tests/TensorSpanTests.cs +++ b/src/libraries/System.Numerics.Tensors/tests/TensorSpanTests.cs @@ -711,7 +711,7 @@ public static void TensorSpanArrayConstructorTests() Assert.Equal(1, spanInt.Rank); Assert.Equal(0, spanInt.Lengths[0]); Assert.Equal(0, spanInt.FlattenedLength); - Assert.Equal(1, spanInt.Strides[0]); + Assert.Equal(0, spanInt.Strides[0]); // Make sure it still throws on index 0 Assert.Throws(() => { var spanInt = new TensorSpan(b); diff --git a/src/libraries/System.Numerics.Tensors/tests/TensorTests.cs b/src/libraries/System.Numerics.Tensors/tests/TensorTests.cs index db11389bd64c8c..f3e280f67587b3 100644 --- a/src/libraries/System.Numerics.Tensors/tests/TensorTests.cs +++ b/src/libraries/System.Numerics.Tensors/tests/TensorTests.cs @@ -618,6 +618,77 @@ public static void TensorFactoryCreateTests() }); } + [Fact] + public static void TensorCreateSingleElementTests() + { + // Tensor.Create with a single-element array should have stride 0 + Tensor src = Tensor.Create([1.0]); + Assert.Equal(1, src.Rank); + Assert.Equal(1, src.Lengths[0]); + Assert.Equal(0, src.Strides[0]); + Assert.Equal(1, src.FlattenedLength); + Assert.Equal(1.0, src[0]); + + // CreateFromShapeUninitialized without strides should work + Tensor dst = Tensor.CreateFromShapeUninitialized(src.Lengths); + Assert.Equal(1, dst.Rank); + Assert.Equal(1, dst.Lengths[0]); + Assert.Equal(0, dst.Strides[0]); + Assert.Equal(1, dst.FlattenedLength); + + // CopyTo should succeed + src.CopyTo(dst); + Assert.Equal(1.0, dst[0]); + + // CreateFromShapeUninitialized with explicit strides should work + dst = Tensor.CreateFromShapeUninitialized(src.Lengths, src.Strides); + Assert.Equal(1, dst.Rank); + Assert.Equal(1, dst.Lengths[0]); + Assert.Equal(0, dst.Strides[0]); + Assert.Equal(1, dst.FlattenedLength); + + src.CopyTo(dst); + Assert.Equal(1.0, dst[0]); + + // CreateFromShape without strides should also work + dst = Tensor.CreateFromShape(src.Lengths); + Assert.Equal(1, dst.Rank); + Assert.Equal(1, dst.Lengths[0]); + Assert.Equal(0, dst.Strides[0]); + Assert.Equal(1, dst.FlattenedLength); + + src.CopyTo(dst); + Assert.Equal(1.0, dst[0]); + + // CreateFromShape with explicit strides should also work + dst = Tensor.CreateFromShape(src.Lengths, src.Strides); + Assert.Equal(1, dst.Rank); + Assert.Equal(1, dst.Lengths[0]); + Assert.Equal(0, dst.Strides[0]); + Assert.Equal(1, dst.FlattenedLength); + + src.CopyTo(dst); + Assert.Equal(1.0, dst[0]); + + // TensorSpan from single-element span should also have stride 0 + Span span = [42.0]; + TensorSpan tensorSpan = new TensorSpan(span); + Assert.Equal(1, tensorSpan.Rank); + Assert.Equal(1, tensorSpan.Lengths[0]); + Assert.Equal(0, tensorSpan.Strides[0]); + Assert.Equal(1, tensorSpan.FlattenedLength); + Assert.Equal(42.0, tensorSpan[0]); + + // ReadOnlyTensorSpan from single-element span should also have stride 0 + ReadOnlySpan roSpan = [42.0]; + ReadOnlyTensorSpan roTensorSpan = new ReadOnlyTensorSpan(roSpan); + Assert.Equal(1, roTensorSpan.Rank); + Assert.Equal(1, roTensorSpan.Lengths[0]); + Assert.Equal(0, roTensorSpan.Strides[0]); + Assert.Equal(1, roTensorSpan.FlattenedLength); + Assert.Equal(42.0, roTensorSpan[0]); + } + [Fact] public static void TensorCosineSimilarityTests() {