Skip to content

Commit 5432fe2

Browse files
committed
Fix broken merge
Fix broken merge
1 parent 9d937aa commit 5432fe2

File tree

2 files changed

+29
-29
lines changed

2 files changed

+29
-29
lines changed

src/beliefpropagation/beliefpropagationproblem.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ end
7979
function SimpleMessageUpdate(
8080
edge;
8181
normalize = true,
82-
contraction_alg = "eager",
82+
contraction_alg = "exact",
8383
compute_diff = false,
8484
kwargs...
8585
)
@@ -275,7 +275,7 @@ function select_algorithm(
275275
end
276276

277277
extended_kwargs = extend_columns((; message_diff_function, kwargs...), maxiter)
278-
edge_kwargs = rows(extended_kwargs, len = maxiter)
278+
edge_kwargs = rows(extended_kwargs, maxiter)
279279

280280
return BeliefPropagation(maxiter; stopping_criterion) do repnum
281281
return beliefpropagation_sweep(cache; edges, edge_kwargs[repnum]...)

src/contract_network.jl

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,27 @@
11
using BackendSelection: @Algorithm_str, Algorithm
22
using Base.Broadcast: materialize
3-
using NamedDimsArrays: inds
4-
using ITensorNetworksNext.LazyNamedDimsArrays: LazyNamedDimsArrays, Mul, lazy, optimize_evaluation_order,
3+
using ITensorNetworksNext.LazyNamedDimsArrays: Mul, lazy, optimize_evaluation_order,
54
substitute, symnameddims
65

7-
function contract_network(tn; alg = default_kwargs(contract_network, tn).alg)
8-
return contract_network(alg, tn)
6+
# This is related to `MatrixAlgebraKit.select_algorithm`.
7+
# TODO: Define this in BackendSelection.jl.
8+
backend_value(::Algorithm{alg}) where {alg} = alg
9+
using BackendSelection: parameters
10+
function merge_parameters(alg::Algorithm; kwargs...)
11+
return Algorithm(backend_value(alg); merge(parameters(alg), kwargs)...)
12+
end
13+
to_algorithm(alg::Algorithm; kwargs...) = merge_parameters(alg; kwargs...)
14+
to_algorithm(alg; kwargs...) = Algorithm(alg; kwargs...)
15+
16+
# `contract_network`
17+
function contract_network(alg::Algorithm, tn)
18+
return throw(ArgumentError("`contract_network` algorithm `$(alg)` not implemented."))
19+
end
20+
function default_kwargs(::typeof(contract_network), tn)
21+
return (; alg = Algorithm"exact"(; order_alg = Algorithm"eager"()))
22+
end
23+
function contract_network(tn; alg = default_kwargs(contract_network, tn).alg, kwargs...)
24+
return contract_network(to_algorithm(alg; kwargs...), tn)
925
end
1026

1127
# `contract_network(::Algorithm"exact", ...)`
@@ -34,36 +50,20 @@ end
3450

3551
# `contraction_order`
3652
function contraction_order end
37-
default_kwargs(::typeof(contraction_order), tensors) = (; order = "eager")
38-
39-
function contraction_expression(tensors; order = default_kwargs(contraction_order, tensors).order)
40-
order = contraction_order(order, tensors)
41-
42-
# Contraction order may or may not have indices attached, canonicalize the format
43-
# by attaching indices.
44-
subs = Dict(symnameddims(i) => symnameddims(i, tensors[i]) for i in keys(tensors))
45-
46-
return substitute(order, subs)
47-
end
48-
49-
contraction_order(order, tensors) = order
50-
function contraction_order(tensors; order = default_kwargs(contraction_order, tensors).order)
51-
return contraction_order(Algorithm(order), tensors)
53+
default_kwargs(::typeof(contraction_order), tn) = (; alg = Algorithm"eager"())
54+
function contraction_order(tn; alg = default_kwargs(contraction_order, tn).alg, kwargs...)
55+
return contraction_order(to_algorithm(alg; kwargs...), tn)
5256
end
5357
# Convert the tensor network to a flat symbolic multiplication expression.
54-
function contraction_order(::Algorithm"flat", tensors)
58+
function contraction_order(alg::Algorithm"flat", tn)
5559
# Same as: `reduce((a, b) -> *(a, b; flatten = true), syms)`.
5660
syms = vec([symnameddims(i, Tuple(axes(tn[i]))) for i in keys(tn)])
5761
return lazy(Mul(syms))
5862
end
5963
function contraction_order(alg::Algorithm"left_associative", tn)
6064
return prod(i -> symnameddims(i, Tuple(axes(tn[i]))), keys(tn))
6165
end
62-
63-
function contraction_order(
64-
order_algorithm::Algorithm,
65-
tensors,
66-
)
67-
order = contraction_order(tensors; order = "flat")
68-
return optimize_evaluation_order(order; alg = order_algorithm)
66+
function contraction_order(alg::Algorithm, tn)
67+
s = contraction_order(Algorithm"flat"(), tn)
68+
return optimize_evaluation_order(s; alg)
6969
end

0 commit comments

Comments
 (0)