From 5664b86a2d35c294e463729f2d0d2d6d3c5790f3 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Wed, 18 Feb 2026 07:35:49 -0500 Subject: [PATCH 1/6] More tweaks --- ext/TensorKitCUDAExt/TensorKitCUDAExt.jl | 3 +- ext/TensorKitCUDAExt/auxiliary.jl | 28 +++++++++++++ ext/TensorKitCUDAExt/cutensormap.jl | 53 ++++++++++++++++++++++-- src/auxiliary/auxiliary.jl | 2 +- src/tensors/braidingtensor.jl | 9 ++-- src/tensors/tensoroperations.jl | 10 +++-- src/tensors/treetransformers.jl | 2 +- 7 files changed, 94 insertions(+), 13 deletions(-) create mode 100644 ext/TensorKitCUDAExt/auxiliary.jl diff --git a/ext/TensorKitCUDAExt/TensorKitCUDAExt.jl b/ext/TensorKitCUDAExt/TensorKitCUDAExt.jl index f5efb98bb..1b2932f97 100644 --- a/ext/TensorKitCUDAExt/TensorKitCUDAExt.jl +++ b/ext/TensorKitCUDAExt/TensorKitCUDAExt.jl @@ -10,7 +10,7 @@ using TensorKit.Factorizations using TensorKit.Strided using TensorKit.Factorizations: AbstractAlgorithm using TensorKit: SectorDict, tensormaptype, scalar, similarstoragetype, AdjointTensorMap, scalartype, project_symmetric_and_check -import TensorKit: randisometry, rand, randn +import TensorKit: randisometry, rand, randn, _copyto!, _add_general_kernel_nonthreaded!, blocktype using TensorKit: MatrixAlgebraKit @@ -18,5 +18,6 @@ using Random include("cutensormap.jl") include("truncation.jl") +include("auxiliary.jl") end diff --git a/ext/TensorKitCUDAExt/auxiliary.jl b/ext/TensorKitCUDAExt/auxiliary.jl new file mode 100644 index 000000000..0b11a962f --- /dev/null +++ b/ext/TensorKitCUDAExt/auxiliary.jl @@ -0,0 +1,28 @@ +function TensorKit._copyto!(A::StridedView{TA, 1, <:CuArray{TA}}, B::StridedView{TB, 2, <:CuArray{TB}}) where {TA, TB} + length(A) == length(B) || throw(DimensionMismatch(lazy"length of A ($(length(A))) does not match length of B ($(length(B))")) + + Adata = parent(A) + Astr = stride(A, 1) + IA = A.offset + + Bdata = parent(B) + Bstr = strides(B) + + IB_1 = B.offset + # build index arrays + IAs = Int[] + IBs = Int[] + @inbounds for _ in axes(B, 2) + IB = IB_1 + for _ in axes(B, 1) + IA += Astr + append!(IAs, IA) + IB += Bstr[1] + append!(IBs, IB) + end + IB_1 += Bstr[2] + end + Adata[IAs] .= Bdata[IBs] + + return A +end diff --git a/ext/TensorKitCUDAExt/cutensormap.jl b/ext/TensorKitCUDAExt/cutensormap.jl index f065c2ec1..df621acf6 100644 --- a/ext/TensorKitCUDAExt/cutensormap.jl +++ b/ext/TensorKitCUDAExt/cutensormap.jl @@ -7,6 +7,17 @@ function CuTensorMap(t::TensorMap{T, S, N₁, N₂, A}) where {T, S, N₁, N₂, return CuTensorMap{T, S, N₁, N₂}(CuArray{T}(t.data), space(t)) end +#=function TensorKit.TensorMap{T, S₁, N₁, N₂, A}( + ::UndefInitializer, space::TensorMapSpace{S₂, N₁, N₂} +) where {T, S₁, S₂ <: TensorKit.ElementarySpace, N₁, N₂, A <: CuVector{T}} + d = TensorKit.fusionblockstructure(space).totaldim + data = A(undef, d) + if !isbitstype(T) + zerovector!(data) + end + return TensorKit.TensorMap{T, S₂, A}(data, space) +end=# + # project_symmetric! doesn't yet work for GPU types, so do this on the host, then copy function TensorKit.project_symmetric_and_check(::Type{T}, ::Type{A}, data::AbstractArray, V::TensorMapSpace; tol = sqrt(eps(real(float(eltype(data)))))) where {T, A <: CuVector{T}} h_t = TensorKit.TensorMapWithStorage{T, Vector{T}}(undef, V) @@ -17,6 +28,10 @@ function TensorKit.project_symmetric_and_check(::Type{T}, ::Type{A}, data::Abstr return TensorKit.TensorMapWithStorage{T, A}(A(h_t.data), V) end +function TensorKit.blocktype(::Type{<:CuTensorMap{T, S}}) where {T, S} + return SubArray{T, 1, CuVector{T, CUDA.DeviceMemory}, Tuple{UnitRange{Int}}, true} +end + for (fname, felt) in ((:zeros, :zero), (:ones, :one)) @eval begin function CUDA.$fname( @@ -102,9 +117,21 @@ function TensorKit.scalar(t::CuTensorMap{T, S, 0, 0}) where {T, S} end function Base.convert( - TT::Type{CuTensorMap{T, S, N₁, N₂}}, - t::AbstractTensorMap{<:Any, S, N₁, N₂} - ) where {T, S, N₁, N₂} + TT::Type{TensorMap{T, S, N₁, N₂, A}}, + t::TensorMap{T, S, N₁, N₂, AA} + ) where {T, S, N₁, N₂, A <: CuArray{T}, AA} + if typeof(t) === TT + return t + else + tnew = TT(undef, space(t)) + return copy!(tnew, t) + end +end + +function Base.convert( + TT::Type{TensorMap{T, S, N₁, N₂, A}}, + t::AdjointTensorMap + ) where {T, S, N₁, N₂, A <: CuArray{T}} if typeof(t) === TT return t else @@ -140,6 +167,8 @@ end TensorKit.promote_storage_rule(::Type{CuArray{T, N}}, ::Type{<:CuArray{T, N}}) where {T, N} = CuArray{T, N, CUDA.default_memory} +TensorKit.promote_storage_rule(::Type{<:CuArray{T, N}}, ::Type{CuArray{T, N}}) where {T, N} = + CuArray{T, N, CUDA.default_memory} # CuTensorMap exponentation: @@ -168,3 +197,21 @@ for f in (:sqrt, :log, :asin, :acos, :acosh, :atanh, :acoth) return tf end end + +function TensorKit._add_general_kernel_nonthreaded!( + tdst::CuTensorMap, tsrc::CuTensorMap, p, transformer::TensorKit.GenericTreeTransformer, α, β, backend... + ) + # preallocate buffers + buffers = TensorKit.allocate_buffers(tdst, tsrc, transformer) + + for subtransformer in transformer.data + # Special case without intermediate buffers whenever there is only a single block + if length(subtransformer[1]) == 1 + TensorKit._add_transform_single!(tdst, tsrc, p, subtransformer, α, β, backend...) + else + cu_subtransformer = tuple(CUDA.adapt(CuArray, subtransformer[1]), subtransformer[2:end]...) + TensorKit._add_transform_multi!(tdst, tsrc, p, cu_subtransformer, buffers, α, β, backend...) + end + end + return nothing +end diff --git a/src/auxiliary/auxiliary.jl b/src/auxiliary/auxiliary.jl index a7105cda6..797a55505 100644 --- a/src/auxiliary/auxiliary.jl +++ b/src/auxiliary/auxiliary.jl @@ -60,7 +60,7 @@ end # Low-overhead implementation of `copyto!` for specific case of `stride(B, 1) < stride(B, 2)` # used in indexmanipulations: avoids the overhead of Strided.jl function _copyto!(A::StridedView{<:Any, 1}, B::StridedView{<:Any, 2}) - length(A) == length(B) || throw(DimensionMismatch()) + length(A) == length(B) || throw(DimensionMismatch(lazy"length of A ($(length(A))) does not match length of B ($(length(B))")) Adata = parent(A) Astr = stride(A, 1) diff --git a/src/tensors/braidingtensor.jl b/src/tensors/braidingtensor.jl index 0070bc2d4..9d7a05af5 100644 --- a/src/tensors/braidingtensor.jl +++ b/src/tensors/braidingtensor.jl @@ -171,12 +171,15 @@ end has_shared_permute(t::BraidingTensor, ::Index2Tuple) = false function add_transform!( tdst::AbstractTensorMap, - tsrc::BraidingTensor, (p₁, p₂)::Index2Tuple, + tsrc::BraidingTensor{T, S}, + (p₁, p₂)::Index2Tuple, fusiontreetransform, α::Number, β::Number, backend::AbstractBackend... - ) + ) where {T, S} + tsrc_map = TensorMapWithStorage{scalartype(tdst), storagetype(tdst)}(undef, (tsrc.V2 ⊗ tsrc.V1) ← (tsrc.V1 ⊗ tsrc.V2)) + copy!(tsrc_map, tsrc) return add_transform!( - tdst, TensorMap(tsrc), (p₁, p₂), fusiontreetransform, α, β, + tdst, tsrc_map, (p₁, p₂), fusiontreetransform, α, β, backend... ) end diff --git a/src/tensors/tensoroperations.jl b/src/tensors/tensoroperations.jl index 0820fe1af..a9074ca43 100644 --- a/src/tensors/tensoroperations.jl +++ b/src/tensors/tensoroperations.jl @@ -419,8 +419,10 @@ end # Scalar implementation #----------------------- function scalar(t::AbstractTensorMap{T, S, 0, 0}) where {T, S} - Bs = collect(blocks(t)) - inds = findall(!iszero ∘ last, Bs) - isempty(inds) && return zero(scalartype(t)) - return only(last(Bs[only(inds)])) + Bs = blocks(t) + B_ends = collect.(map(last, Bs)) + nz_B_ends = [!iszero.(B) for B in B_ends] + valid_Bs = filter(any, B_ends) + isempty(valid_Bs) && return zero(scalartype(t)) + return only(last(first(valid_Bs))) end diff --git a/src/tensors/treetransformers.jl b/src/tensors/treetransformers.jl index 36cd3926d..b1d2008b5 100644 --- a/src/tensors/treetransformers.jl +++ b/src/tensors/treetransformers.jl @@ -46,7 +46,7 @@ function AbelianTreeTransformer(transform, p, Vdst, Vsrc) end const _GenericTransformerData{T, N} = Tuple{ - Matrix{T}, + DenseMatrix{T}, Tuple{NTuple{N, Int}, Vector{Tuple{NTuple{N, Int}, Int}}}, Tuple{NTuple{N, Int}, Vector{Tuple{NTuple{N, Int}, Int}}}, } From 0c903ac0694f7ae6a33afc903923fcf680f7287d Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Fri, 20 Feb 2026 05:56:57 -0500 Subject: [PATCH 2/6] Even more small tweaks --- ext/TensorKitCUDAExt/cutensormap.jl | 13 +------------ src/tensors/abstracttensor.jl | 13 ++++++++----- src/tensors/tensoroperations.jl | 10 ++++------ 3 files changed, 13 insertions(+), 23 deletions(-) diff --git a/ext/TensorKitCUDAExt/cutensormap.jl b/ext/TensorKitCUDAExt/cutensormap.jl index df621acf6..37b2e90cb 100644 --- a/ext/TensorKitCUDAExt/cutensormap.jl +++ b/ext/TensorKitCUDAExt/cutensormap.jl @@ -7,17 +7,6 @@ function CuTensorMap(t::TensorMap{T, S, N₁, N₂, A}) where {T, S, N₁, N₂, return CuTensorMap{T, S, N₁, N₂}(CuArray{T}(t.data), space(t)) end -#=function TensorKit.TensorMap{T, S₁, N₁, N₂, A}( - ::UndefInitializer, space::TensorMapSpace{S₂, N₁, N₂} -) where {T, S₁, S₂ <: TensorKit.ElementarySpace, N₁, N₂, A <: CuVector{T}} - d = TensorKit.fusionblockstructure(space).totaldim - data = A(undef, d) - if !isbitstype(T) - zerovector!(data) - end - return TensorKit.TensorMap{T, S₂, A}(data, space) -end=# - # project_symmetric! doesn't yet work for GPU types, so do this on the host, then copy function TensorKit.project_symmetric_and_check(::Type{T}, ::Type{A}, data::AbstractArray, V::TensorMapSpace; tol = sqrt(eps(real(float(eltype(data)))))) where {T, A <: CuVector{T}} h_t = TensorKit.TensorMapWithStorage{T, Vector{T}}(undef, V) @@ -29,7 +18,7 @@ function TensorKit.project_symmetric_and_check(::Type{T}, ::Type{A}, data::Abstr end function TensorKit.blocktype(::Type{<:CuTensorMap{T, S}}) where {T, S} - return SubArray{T, 1, CuVector{T, CUDA.DeviceMemory}, Tuple{UnitRange{Int}}, true} + return CuMatrix{T, CUDA.DeviceMemory} end for (fname, felt) in ((:zeros, :zero), (:ones, :one)) diff --git a/src/tensors/abstracttensor.jl b/src/tensors/abstracttensor.jl index 2d7239460..9293f1375 100644 --- a/src/tensors/abstracttensor.jl +++ b/src/tensors/abstracttensor.jl @@ -53,9 +53,11 @@ storagetype(t) = storagetype(typeof(t)) function storagetype(::Type{T}) where {T <: AbstractTensorMap} if T isa Union # attempt to be slightly more specific by promoting unions - Ma = storagetype(T.a) - Mb = storagetype(T.b) - return promote_storagetype(Ma, Mb) + return promote_storagetype(T.a, T.b) + elseif eltype(T) isa Union + # attempt to be slightly more specific by promoting unions + TU = eltype(T) + return promote_storagetype(TU.a, TU.b) else # fallback definition by using scalartype return similarstoragetype(scalartype(T)) @@ -103,8 +105,9 @@ similarstoragetype(X::Type, ::Type{T}) where {T <: Number} = # implement on tensors similarstoragetype(::Type{TT}) where {TT <: AbstractTensorMap} = similarstoragetype(storagetype(TT)) -similarstoragetype(::Type{TT}, ::Type{T}) where {TT <: AbstractTensorMap, T <: Number} = - similarstoragetype(storagetype(TT), T) +function similarstoragetype(::Type{TT}, ::Type{T}) where {TT <: AbstractTensorMap, T <: Number} + return similarstoragetype(storagetype(TT), T) +end # implement on arrays similarstoragetype(::Type{A}) where {A <: DenseVector{<:Number}} = A diff --git a/src/tensors/tensoroperations.jl b/src/tensors/tensoroperations.jl index a9074ca43..0820fe1af 100644 --- a/src/tensors/tensoroperations.jl +++ b/src/tensors/tensoroperations.jl @@ -419,10 +419,8 @@ end # Scalar implementation #----------------------- function scalar(t::AbstractTensorMap{T, S, 0, 0}) where {T, S} - Bs = blocks(t) - B_ends = collect.(map(last, Bs)) - nz_B_ends = [!iszero.(B) for B in B_ends] - valid_Bs = filter(any, B_ends) - isempty(valid_Bs) && return zero(scalartype(t)) - return only(last(first(valid_Bs))) + Bs = collect(blocks(t)) + inds = findall(!iszero ∘ last, Bs) + isempty(inds) && return zero(scalartype(t)) + return only(last(Bs[only(inds)])) end From 81550ae5374ddadf449f5681c66fbf4c6bc130eb Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Thu, 26 Feb 2026 13:19:18 +0100 Subject: [PATCH 3/6] Tests now unbroken --- test/cuda/tensors.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/cuda/tensors.jl b/test/cuda/tensors.jl index 0fad13473..7bdd90f9d 100644 --- a/test/cuda/tensors.jl +++ b/test/cuda/tensors.jl @@ -98,8 +98,8 @@ for V in spacelist next = @constinferred Nothing iterate(bs, state) b2 = @constinferred block(t, first(blocksectors(t))) @test b1 == b2 - @test_broken eltype(bs) === Pair{typeof(c), typeof(b1)} - @test_broken typeof(b1) === TensorKit.blocktype(t) + @test eltype(bs) === Pair{typeof(c), typeof(b1)} + @test typeof(b1) === TensorKit.blocktype(t) @test typeof(c) === sectortype(t) end end @@ -162,8 +162,8 @@ for V in spacelist next = @constinferred Nothing iterate(bs, state) b2 = @constinferred block(t', first(blocksectors(t'))) @test b1 == b2 - @test_broken eltype(bs) === Pair{typeof(c), typeof(b1)} - @test_broken typeof(b1) === TensorKit.blocktype(t') + @test eltype(bs) === Pair{typeof(c), typeof(b1)} + @test typeof(b1) === TensorKit.blocktype(t') @test typeof(c) === sectortype(t) # linear algebra @test isa(@constinferred(norm(t)), real(T)) From 85a7a683786675463ef5a5431fe160a9482f380e Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Fri, 27 Feb 2026 12:05:08 +0100 Subject: [PATCH 4/6] Apply Lukas' suggestions --- ext/TensorKitCUDAExt/cutensormap.jl | 29 +---------------------------- src/tensors/braidingtensor.jl | 2 +- 2 files changed, 2 insertions(+), 29 deletions(-) diff --git a/ext/TensorKitCUDAExt/cutensormap.jl b/ext/TensorKitCUDAExt/cutensormap.jl index 37b2e90cb..a5c6a7bf1 100644 --- a/ext/TensorKitCUDAExt/cutensormap.jl +++ b/ext/TensorKitCUDAExt/cutensormap.jl @@ -105,30 +105,6 @@ function TensorKit.scalar(t::CuTensorMap{T, S, 0, 0}) where {T, S} return isempty(inds) ? zero(scalartype(t)) : @allowscalar @inbounds t.data[only(inds)] end -function Base.convert( - TT::Type{TensorMap{T, S, N₁, N₂, A}}, - t::TensorMap{T, S, N₁, N₂, AA} - ) where {T, S, N₁, N₂, A <: CuArray{T}, AA} - if typeof(t) === TT - return t - else - tnew = TT(undef, space(t)) - return copy!(tnew, t) - end -end - -function Base.convert( - TT::Type{TensorMap{T, S, N₁, N₂, A}}, - t::AdjointTensorMap - ) where {T, S, N₁, N₂, A <: CuArray{T}} - if typeof(t) === TT - return t - else - tnew = TT(undef, space(t)) - return copy!(tnew, t) - end -end - function LinearAlgebra.isposdef(t::CuTensorMap) domain(t) == codomain(t) || throw(SpaceMismatch("`isposdef` requires domain and codomain to be the same")) @@ -154,11 +130,8 @@ function Base.promote_rule( return CuTensorMap{T, S, N₁, N₂} end -TensorKit.promote_storage_rule(::Type{CuArray{T, N}}, ::Type{<:CuArray{T, N}}) where {T, N} = +TensorKit.promote_storage_rule(::Type{<:CuArray{T, N}}, ::Type{<:CuArray{T, N}}) where {T, N} = CuArray{T, N, CUDA.default_memory} -TensorKit.promote_storage_rule(::Type{<:CuArray{T, N}}, ::Type{CuArray{T, N}}) where {T, N} = - CuArray{T, N, CUDA.default_memory} - # CuTensorMap exponentation: function TensorKit.exp!(t::CuTensorMap) diff --git a/src/tensors/braidingtensor.jl b/src/tensors/braidingtensor.jl index 9d7a05af5..4fd68f7da 100644 --- a/src/tensors/braidingtensor.jl +++ b/src/tensors/braidingtensor.jl @@ -176,7 +176,7 @@ function add_transform!( fusiontreetransform, α::Number, β::Number, backend::AbstractBackend... ) where {T, S} - tsrc_map = TensorMapWithStorage{scalartype(tdst), storagetype(tdst)}(undef, (tsrc.V2 ⊗ tsrc.V1) ← (tsrc.V1 ⊗ tsrc.V2)) + tsrc_map = similar(tdst, storagetype(tdst), space(tsrc)) copy!(tsrc_map, tsrc) return add_transform!( tdst, tsrc_map, (p₁, p₂), fusiontreetransform, α, β, From b34c63e9e1be60f19dd33fdbc587b054a3190ac4 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 3 Mar 2026 14:36:12 +0100 Subject: [PATCH 5/6] Force latest Strided? --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 7e82876c5..b44c75742 100644 --- a/Project.toml +++ b/Project.toml @@ -53,7 +53,7 @@ Printf = "1" Random = "1" SafeTestsets = "0.1" ScopedValues = "1.3.0" -Strided = "2" +Strided = "2.3.3" TensorKitSectors = "0.3.5" TensorOperations = "5.1" Test = "1" From 0c88fd75f8f6d4f530fa4c5bd4b8ac24dd39280c Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 3 Mar 2026 12:54:40 -0500 Subject: [PATCH 6/6] Uno reverse on _copyto --- ext/TensorKitCUDAExt/TensorKitCUDAExt.jl | 1 - ext/TensorKitCUDAExt/auxiliary.jl | 28 ------------------------ src/auxiliary/auxiliary.jl | 6 +++-- 3 files changed, 4 insertions(+), 31 deletions(-) delete mode 100644 ext/TensorKitCUDAExt/auxiliary.jl diff --git a/ext/TensorKitCUDAExt/TensorKitCUDAExt.jl b/ext/TensorKitCUDAExt/TensorKitCUDAExt.jl index 1b2932f97..4ee4865f1 100644 --- a/ext/TensorKitCUDAExt/TensorKitCUDAExt.jl +++ b/ext/TensorKitCUDAExt/TensorKitCUDAExt.jl @@ -18,6 +18,5 @@ using Random include("cutensormap.jl") include("truncation.jl") -include("auxiliary.jl") end diff --git a/ext/TensorKitCUDAExt/auxiliary.jl b/ext/TensorKitCUDAExt/auxiliary.jl deleted file mode 100644 index 0b11a962f..000000000 --- a/ext/TensorKitCUDAExt/auxiliary.jl +++ /dev/null @@ -1,28 +0,0 @@ -function TensorKit._copyto!(A::StridedView{TA, 1, <:CuArray{TA}}, B::StridedView{TB, 2, <:CuArray{TB}}) where {TA, TB} - length(A) == length(B) || throw(DimensionMismatch(lazy"length of A ($(length(A))) does not match length of B ($(length(B))")) - - Adata = parent(A) - Astr = stride(A, 1) - IA = A.offset - - Bdata = parent(B) - Bstr = strides(B) - - IB_1 = B.offset - # build index arrays - IAs = Int[] - IBs = Int[] - @inbounds for _ in axes(B, 2) - IB = IB_1 - for _ in axes(B, 1) - IA += Astr - append!(IAs, IA) - IB += Bstr[1] - append!(IBs, IB) - end - IB_1 += Bstr[2] - end - Adata[IAs] .= Bdata[IBs] - - return A -end diff --git a/src/auxiliary/auxiliary.jl b/src/auxiliary/auxiliary.jl index 797a55505..e7bb0f586 100644 --- a/src/auxiliary/auxiliary.jl +++ b/src/auxiliary/auxiliary.jl @@ -57,9 +57,11 @@ function _interleave(a::NTuple{N}, b::NTuple{N}) where {N} return (a[1], b[1], _interleave(tail(a), tail(b))...) end +_copyto!(A, B) = copyto!(A, B) + # Low-overhead implementation of `copyto!` for specific case of `stride(B, 1) < stride(B, 2)` -# used in indexmanipulations: avoids the overhead of Strided.jl -function _copyto!(A::StridedView{<:Any, 1}, B::StridedView{<:Any, 2}) +# for CPU-hosted Arrays # used in indexmanipulations: avoids the overhead of Strided.jl +function _copyto!(A::StridedView{TA, 1, AA}, B::StridedView{TB, 2, BB}) where {TA <: Number, TB <: Number, AA <: DenseArray{TA}, BB <: DenseArray{TB}} length(A) == length(B) || throw(DimensionMismatch(lazy"length of A ($(length(A))) does not match length of B ($(length(B))")) Adata = parent(A)