diff --git a/src/PEPSKit.jl b/src/PEPSKit.jl index a8476abb3..cd4cfae62 100644 --- a/src/PEPSKit.jl +++ b/src/PEPSKit.jl @@ -5,7 +5,11 @@ using Compat using Accessors: @set, @reset using VectorInterface import VectorInterface as VI -using TensorKit, KrylovKit, OptimKit, TensorOperations + +using TensorKit +using TensorKit: TruncationScheme + +using KrylovKit, OptimKit, TensorOperations using ChainRulesCore, Zygote using LoggingExtras @@ -83,7 +87,8 @@ using .Defaults: set_scheduler! export set_scheduler! export SVDAdjoint, FullSVDReverseRule, IterSVD export CTMRGEnv, SequentialCTMRG, SimultaneousCTMRG -export FixedSpaceTruncation, HalfInfiniteProjector, FullInfiniteProjector +export FixedSpaceTruncation, SiteDependentTruncation +export HalfInfiniteProjector, FullInfiniteProjector export LocalOperator, physicalspace export expectation_value, cost_function, product_peps, correlation_length, network_value export correlator diff --git a/src/algorithms/time_evolution/simpleupdate.jl b/src/algorithms/time_evolution/simpleupdate.jl index 01eb070d0..db6e0aa49 100644 --- a/src/algorithms/time_evolution/simpleupdate.jl +++ b/src/algorithms/time_evolution/simpleupdate.jl @@ -12,7 +12,7 @@ struct SimpleUpdate dt::Number tol::Float64 maxiter::Int - trscheme::TensorKit.TruncationScheme + trscheme::TruncationScheme end # TODO: add kwarg constructor and SU Defaults @@ -34,7 +34,7 @@ function _su_xbond!( col::Int, gate::AbstractTensorMap{T,S,2,2}, peps::InfiniteWeightPEPS, - alg::SimpleUpdate, + trscheme::TruncationScheme, ) where {T<:Number,S<:ElementarySpace} Nr, Nc = size(peps) @assert 1 <= row <= Nr && 1 <= col <= Nc @@ -47,7 +47,7 @@ function _su_xbond!( B = _absorb_weights(B, peps.weights, row, cp1, Tuple(1:4), sqrtsB, false) # apply gate X, a, b, Y = _qr_bond(A, B) - a, s, b, ϵ = _apply_gate(a, b, gate, alg.trscheme) + a, s, b, ϵ = _apply_gate(a, b, gate, trscheme) A, B = _qr_bond_undo(X, a, b, Y) # remove environment weights _allfalse = ntuple(Returns(false), 3) @@ -86,6 +86,9 @@ function su_iter( # to update them using code for x-weights if direction == 2 peps2 = mirror_antidiag(peps2) + trscheme = mirror_antidiag(alg.trscheme) + else + trscheme = alg.trscheme end if bipartite for r in 1:2 @@ -94,7 +97,7 @@ function su_iter( direction == 1 ? gate : gate_mirrored, (CartesianIndex(r, 1), CartesianIndex(r, 2)), ) - ϵ = _su_xbond!(r, 1, term, peps2, alg) + ϵ = _su_xbond!(r, 1, term, peps2, truncation_scheme(trscheme, 1, r, 1)) peps2.vertices[rp1, 2] = deepcopy(peps2.vertices[r, 1]) peps2.vertices[rp1, 1] = deepcopy(peps2.vertices[r, 2]) peps2.weights[1, rp1, 2] = deepcopy(peps2.weights[1, r, 1]) @@ -106,7 +109,7 @@ function su_iter( direction == 1 ? gate : gate_mirrored, (CartesianIndex(r, c), CartesianIndex(r, c + 1)), ) - ϵ = _su_xbond!(r, c, term, peps2, alg) + ϵ = _su_xbond!(r, c, term, peps2, truncation_scheme(trscheme, 1, r, c)) end end if direction == 2 @@ -185,6 +188,7 @@ function simpleupdate( nnonly = is_nearest_neighbour(ham) use_3site = force_3site || !nnonly @assert !(bipartite && use_3site) "3-site simple update is incompatible with bipartite lattice." + # TODO: check SiteDependentTruncation is compatible with bipartite structure if use_3site return _simpleupdate3site(peps, ham, alg; check_interval) else diff --git a/src/algorithms/time_evolution/simpleupdate3site.jl b/src/algorithms/time_evolution/simpleupdate3site.jl index 1ce887b5b..1672e2c5e 100644 --- a/src/algorithms/time_evolution/simpleupdate3site.jl +++ b/src/algorithms/time_evolution/simpleupdate3site.jl @@ -215,7 +215,7 @@ The arrows between `Pa`, `s`, `Pb` are function _proj_from_RL( r::AbstractTensorMap{T,S,1,1}, l::AbstractTensorMap{T,S,1,1}; - trunc::TensorKit.TruncationScheme=notrunc(), + trunc::TruncationScheme=notrunc(), rev::Bool=false, ) where {T<:Number,S<:ElementarySpace} rl = r * l @@ -235,16 +235,19 @@ end Given a cluster `Ms` and the pre-calculated `R`, `L` bond matrices, find all projectors `Pa`, `Pb` and Schmidt weights `wts` on internal bonds. """ -function _get_allprojs(Ms, Rs, Ls, trunc::TensorKit.TruncationScheme, revs::Vector{Bool}) +function _get_allprojs( + Ms, Rs, Ls, trschemes::Vector{E}, revs::Vector{Bool} +) where {E<:TruncationScheme} N = length(Ms) + @assert length(trschemes) == N - 1 projs_errs = map(1:(N - 1)) do i - trunc2 = if isa(trunc, FixedSpaceTruncation) + trunc = if isa(trschemes[i], FixedSpaceTruncation) V = space(Ms[i + 1], 1) truncspace(isdual(V) ? V' : V) else - trunc + trschemes[i] end - return _proj_from_RL(Rs[i], Ls[i]; trunc=trunc2, rev=revs[i]) + return _proj_from_RL(Rs[i], Ls[i]; trunc, rev=revs[i]) end Pas = map(Base.Fix2(getindex, 1), projs_errs) wts = map(Base.Fix2(getindex, 2), projs_errs) @@ -258,10 +261,10 @@ end Find projectors to truncate internal bonds of the cluster `Ms` """ function _cluster_truncate!( - Ms::Vector{T}, trunc::TensorKit.TruncationScheme, revs::Vector{Bool} -) where {T<:PEPSTensor} + Ms::Vector{T}, trschemes::Vector{E}, revs::Vector{Bool} +) where {T<:PEPSTensor,E<:TruncationScheme} Rs, Ls = _get_allRLs(Ms) - Pas, Pbs, wts, ϵs = _get_allprojs(Ms, Rs, Ls, trunc, revs) + Pas, Pbs, wts, ϵs = _get_allprojs(Ms, Rs, Ls, trschemes, revs) # apply projectors # M1 -- (Pa1,wt1,Pb1) -- M2 -- (Pa2,wt2,Pb2) -- M3 for (i, (Pa, Pb)) in enumerate(zip(Pas, Pbs)) @@ -322,13 +325,13 @@ In the cluster, the axes of each PEPSTensor are reordered as ``` """ function apply_gatempo!( - Ms::Vector{T1}, gs::Vector{T2}; trunc::TensorKit.TruncationScheme -) where {T1<:PEPSTensor,T2<:AbstractTensorMap} + Ms::Vector{T1}, gs::Vector{T2}; trschemes::Vector{E} +) where {T1<:PEPSTensor,T2<:AbstractTensorMap,E<:TruncationScheme} @assert length(Ms) == length(gs) revs = [isdual(space(M, 1)) for M in Ms[2:end]] @assert !all(revs) _apply_gatempo!(Ms, gs) - wts, ϵs, = _cluster_truncate!(Ms, trunc, revs) + wts, ϵs, = _cluster_truncate!(Ms, trschemes, revs) return wts, ϵs end @@ -373,8 +376,8 @@ function get_3site_se(peps::InfiniteWeightPEPS, row::Int, col::Int) end function _su3site_se!( - row::Int, col::Int, gs::Vector{T}, peps::InfiniteWeightPEPS, alg::SimpleUpdate -) where {T<:AbstractTensorMap} + row::Int, col::Int, gs::Vector{T}, peps::InfiniteWeightPEPS, trschemes::Vector{E} +) where {T<:AbstractTensorMap,E<:TruncationScheme} Nr, Nc = size(peps) @assert 1 <= row <= Nr && 1 <= col <= Nc rm1, cp1 = _prev(row, Nr), _next(col, Nc) @@ -384,7 +387,7 @@ function _su3site_se!( coords = ((row, col), (row, cp1), (rm1, cp1)) # weights in the cluster wt_idxs = ((1, row, col), (2, row, cp1)) - wts, ϵ = apply_gatempo!(Ms, gs; trunc=alg.trscheme) + wts, ϵ = apply_gatempo!(Ms, gs; trschemes) for (wt, wt_idx) in zip(wts, wt_idxs) peps.weights[CartesianIndex(wt_idx)] = wt / norm(wt, Inf) end @@ -414,13 +417,19 @@ function su3site_iter( ), ) peps2 = deepcopy(peps) + trscheme = alg.trscheme for i in 1:4 for site in CartesianIndices(peps2.vertices) r, c = site[1], site[2] gs = gatempos[i][r, c] - _su3site_se!(r, c, gs, peps2, alg) + trschemes = [ + truncation_scheme(trscheme, 1, r, c) + truncation_scheme(trscheme, 2, r, _next(c, size(peps2)[2])) + ] + _su3site_se!(r, c, gs, peps2, trschemes) end peps2 = rotl90(peps2) + trscheme = rotl90(trscheme) end return peps2 end diff --git a/src/algorithms/truncation/bond_truncation.jl b/src/algorithms/truncation/bond_truncation.jl index 40f413074..c50e1ac38 100644 --- a/src/algorithms/truncation/bond_truncation.jl +++ b/src/algorithms/truncation/bond_truncation.jl @@ -13,13 +13,13 @@ $(TYPEDFIELDS) The truncation algorithm can be constructed from the following keyword arguments: -* `trscheme::TensorKit.TruncationScheme`: SVD truncation scheme when initilizing the truncated tensors connected by the bond. +* `trscheme::TruncationScheme`: SVD truncation scheme when initilizing the truncated tensors connected by the bond. * `maxiter::Int=50` : Maximal number of ALS iterations. * `tol::Float64=1e-15` : ALS converges when fidelity change between two FET iterations is smaller than `tol`. * `check_interval::Int=0` : Set number of iterations to print information. Output is suppressed when `check_interval <= 0`. """ @kwdef struct ALSTruncation - trscheme::TensorKit.TruncationScheme + trscheme::TruncationScheme maxiter::Int = 50 tol::Float64 = 1e-15 check_interval::Int = 0 diff --git a/src/algorithms/truncation/fullenv_truncation.jl b/src/algorithms/truncation/fullenv_truncation.jl index 7856f499b..f3fa336b3 100644 --- a/src/algorithms/truncation/fullenv_truncation.jl +++ b/src/algorithms/truncation/fullenv_truncation.jl @@ -13,7 +13,7 @@ $(TYPEDFIELDS) The truncation algorithm can be constructed from the following keyword arguments: -* `trscheme::TensorKit.TruncationScheme` : SVD truncation scheme when optimizing the new bond matrix. +* `trscheme::TruncationScheme` : SVD truncation scheme when optimizing the new bond matrix. * `maxiter::Int=50` : Maximal number of FET iterations. * `tol::Float64=1e-15` : FET converges when fidelity change between two FET iterations is smaller than `tol`. * `trunc_init::Bool=true` : Controls whether the initialization of the new bond matrix is obtained from truncated SVD of the old bond matrix. @@ -24,7 +24,7 @@ The truncation algorithm can be constructed from the following keyword arguments * [Glen Evenbly, Phys. Rev. B 98, 085155 (2018)](@cite evenbly_gauge_2018). """ @kwdef struct FullEnvTruncation - trscheme::TensorKit.TruncationScheme + trscheme::TruncationScheme maxiter::Int = 50 tol::Float64 = 1e-15 trunc_init::Bool = true diff --git a/src/algorithms/truncation/truncationschemes.jl b/src/algorithms/truncation/truncationschemes.jl index 9039d925e..3420c90b3 100644 --- a/src/algorithms/truncation/truncationschemes.jl +++ b/src/algorithms/truncation/truncationschemes.jl @@ -5,7 +5,11 @@ CTMRG specific truncation scheme for `tsvd` which keeps the bond space on which is performed fixed. Since different environment directions and unit cell entries might have different spaces, this truncation style is different from `TruncationSpace`. """ -struct FixedSpaceTruncation <: TensorKit.TruncationScheme end +struct FixedSpaceTruncation <: TruncationScheme end + +struct SiteDependentTruncation{T<:TruncationScheme} <: TruncationScheme + trschemes::Array{T,3} +end const TRUNCATION_SCHEME_SYMBOLS = IdDict{Symbol,Type{<:TruncationScheme}}( :fixedspace => FixedSpaceTruncation, @@ -14,6 +18,7 @@ const TRUNCATION_SCHEME_SYMBOLS = IdDict{Symbol,Type{<:TruncationScheme}}( :truncdim => TensorKit.TruncationDimension, :truncspace => TensorKit.TruncationSpace, :truncbelow => TensorKit.TruncationCutoff, + :sitedependent => SiteDependentTruncation, ) # Should be TruncationScheme but rename to avoid type piracy @@ -25,3 +30,65 @@ function _TruncationScheme(; alg=Defaults.trscheme, η=nothing) return isnothing(η) ? alg_type() : alg_type(η) end + +function truncation_scheme( + trscheme::TruncationScheme, direction::Int, row::Int, col::Int; kwargs... +) + return trscheme +end + +function truncation_scheme( + trscheme::SiteDependentTruncation, direction::Int, row::Int, col::Int; +) + return trscheme.trschemes[direction, row, col] +end + +# Mirror a TruncationScheme by its anti-diagonal line. +# When the number of directions is 2, it swaps the first and second direction, consistent with xbonds and ybonds, respectively. +# When the number of directions is 4, it swaps the first and second, and third and fourth directions, consistent with the order NORTH, EAST, SOUTH, WEST. +mirror_antidiag(trscheme::TruncationScheme) = trscheme +function mirror_antidiag(trscheme::SiteDependentTruncation) + directions = size(trscheme.trschemes)[1] + if directions == 2 + trschemes_mirrored = stack( + ( + mirror_antidiag(trscheme.trschemes[EAST, :, :]), + mirror_antidiag(trscheme.trschemes[NORTH, :, :]), + ); + dims=1, + ) + elseif directions == 4 + trschemes_mirrored = stack(( + mirror_antidiag(trscheme.trschemes[EAST, :, :]), + mirror_antidiag(trscheme.trschemes[NORTH, :, :]), + mirror_antidiag(trscheme.trschemes[WEST, :, :]), + mirror_antidiag(trscheme.trschemes[SOUTH, :, :]), + )) + else + error("Unsupported number of directions for mirror_antidiag: $directions") + end + return SiteDependentTruncation(trschemes_mirrored) +end + +# TODO: type piracy +Base.rotl90(trscheme::TruncationScheme) = trscheme + +function Base.rotl90(trscheme::SiteDependentTruncation) + directions, rows, cols = size(trscheme.trschemes) + trschemes_rotated = similar(trscheme.trschemes, directions, cols, rows) + + if directions == 2 + trschemes_rotated[NORTH, :, :] = circshift( + rotl90(trscheme.trschemes[EAST, :, :]), (0, -1) + ) + trschemes_rotated[EAST, :, :] = rotl90(trscheme.trschemes[NORTH, :, :]) + elseif directions == 4 + for dir in 1:4 + dir′ = _prev(dir, 4) + trschemes_rotated[dir′, :, :] = rotl90(trscheme.trschemes[dir, :, :]) + end + else + throw(ArgumentError("Unsupported number of directions for rotl90: $directions")) + end + return SiteDependentTruncation(trschemes_rotated) +end diff --git a/test/runtests.jl b/test/runtests.jl index 7d93eda49..787964a9b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -61,6 +61,9 @@ end @time @safetestset "Cluster truncation with projectors" begin include("timeevol/cluster_projectors.jl") end + @time @safetestset "Time evolution with site-dependent truncation" begin + include("timeevol/sitedep_truncation.jl") + end end if GROUP == "ALL" || GROUP == "UTILITY" @time @safetestset "LocalOperator" begin diff --git a/test/timeevol/cluster_projectors.jl b/test/timeevol/cluster_projectors.jl index 4ad3f60f6..91b2994e8 100644 --- a/test/timeevol/cluster_projectors.jl +++ b/test/timeevol/cluster_projectors.jl @@ -30,7 +30,7 @@ Vspaces = [ revs = [isdual(space(M, 1)) for M in Ms1[2:end]] # no truncation Ms2 = deepcopy(Ms1) - wts2, ϵs, = _cluster_truncate!(Ms2, FixedSpaceTruncation(), revs) + wts2, ϵs, = _cluster_truncate!(Ms2, fill(FixedSpaceTruncation(), N-1), revs) @test all((ϵ == 0) for ϵ in ϵs) absorb_wts_cluster!(Ms2, wts2) for (i, M) in enumerate(Ms2) @@ -41,7 +41,7 @@ Vspaces = [ @test all(lorths) && all(rorths) # truncation on one bond Ms3 = deepcopy(Ms1) - wts3, ϵs, = _cluster_truncate!(Ms3, truncspace(Vns), revs) + wts3, ϵs, = _cluster_truncate!(Ms3, fill(truncspace(Vns), N-1), revs) @test all((i == n) || (ϵ == 0) for (i, ϵ) in enumerate(ϵs)) absorb_wts_cluster!(Ms3, wts3) for (i, M) in enumerate(Ms3) diff --git a/test/timeevol/sitedep_truncation.jl b/test/timeevol/sitedep_truncation.jl new file mode 100644 index 000000000..ca62daca9 --- /dev/null +++ b/test/timeevol/sitedep_truncation.jl @@ -0,0 +1,54 @@ +using Test +using LinearAlgebra +using Random +using TensorKit +using PEPSKit +using PEPSKit: NORTH, EAST + +function get_bonddims(wpeps::InfiniteWeightPEPS) + xdims = collect(dim(domain(t, EAST)) for t in wpeps.vertices) + ydims = collect(dim(domain(t, NORTH)) for t in wpeps.vertices) + return stack([xdims, ydims]; dims=1) +end + +@testset "Simple update: bipartite 2-site" begin + Nr, Nc = 2, 2 + ham = real(heisenberg_XYZ(InfiniteSquare(Nr, Nc); Jx=1.0, Jy=1.0, Jz=1.0)) + Random.seed!(100) + wpeps0 = InfiniteWeightPEPS(rand, Float64, ℂ^2, ℂ^10; unitcell=(Nr, Nc)) + normalize!.(wpeps0.vertices, Inf) + # set trscheme to be compatible with bipartite structure + bonddims = stack([[6 4; 4 6], [5 7; 7 5]]; dims=1) + trscheme = SiteDependentTruncation(collect(truncdim(d) for d in bonddims)) + alg = SimpleUpdate(1e-2, 1e-14, 4, trscheme) + wpeps, = simpleupdate(wpeps0, ham, alg; bipartite=true) + @test get_bonddims(wpeps) == bonddims + # check bipartite structure is preserved + for col in 1:2 + cp1 = PEPSKit._next(col, 2) + @test ( + wpeps.vertices[1, col] == wpeps.vertices[2, cp1] && + wpeps.weights[1, 1, col] == wpeps.weights[1, 2, cp1] && + wpeps.weights[2, 1, col] == wpeps.weights[2, 2, cp1] + ) + end +end + +@testset "Simple update: generic 2-site and 3-site" begin + Nr, Nc = 3, 4 + ham = real(heisenberg_XYZ(InfiniteSquare(Nr, Nc); Jx=1.0, Jy=1.0, Jz=1.0)) + Random.seed!(100) + wpeps0 = InfiniteWeightPEPS(rand, Float64, ℂ^2, ℂ^10; unitcell=(Nr, Nc)) + normalize!.(wpeps0.vertices, Inf) + # Site dependent truncation + bonddims = rand(2:8, 2, Nr, Nc) + @show bonddims + trscheme = SiteDependentTruncation(collect(truncdim(d) for d in bonddims)) + alg = SimpleUpdate(1e-2, 1e-14, 2, trscheme) + # 2-site SU + wpeps, = simpleupdate(wpeps0, ham, alg; bipartite=false) + @test get_bonddims(wpeps) == bonddims + # 3-site SU + wpeps, = simpleupdate(wpeps0, ham, alg; bipartite=false, force_3site=true) + @test get_bonddims(wpeps) == bonddims +end