|
1 | 1 | using BackendSelection: @Algorithm_str, Algorithm |
2 | 2 | using Base.Broadcast: materialize |
3 | | -using NamedDimsArrays: inds |
4 | | -using ITensorNetworksNext.LazyNamedDimsArrays: LazyNamedDimsArrays, Mul, lazy, optimize_evaluation_order, |
| 3 | +using ITensorNetworksNext.LazyNamedDimsArrays: Mul, lazy, optimize_evaluation_order, |
5 | 4 | substitute, symnameddims |
6 | 5 |
|
7 | | -function contract_network(tn; alg = default_kwargs(contract_network, tn).alg) |
8 | | - return contract_network(alg, tn) |
| 6 | +# This is related to `MatrixAlgebraKit.select_algorithm`. |
| 7 | +# TODO: Define this in BackendSelection.jl. |
| 8 | +backend_value(::Algorithm{alg}) where {alg} = alg |
| 9 | +using BackendSelection: parameters |
| 10 | +function merge_parameters(alg::Algorithm; kwargs...) |
| 11 | + return Algorithm(backend_value(alg); merge(parameters(alg), kwargs)...) |
| 12 | +end |
| 13 | +to_algorithm(alg::Algorithm; kwargs...) = merge_parameters(alg; kwargs...) |
| 14 | +to_algorithm(alg; kwargs...) = Algorithm(alg; kwargs...) |
| 15 | + |
| 16 | +# `contract_network` |
| 17 | +function contract_network(alg::Algorithm, tn) |
| 18 | + return throw(ArgumentError("`contract_network` algorithm `$(alg)` not implemented.")) |
| 19 | +end |
| 20 | +function default_kwargs(::typeof(contract_network), tn) |
| 21 | + return (; alg = Algorithm"exact"(; order_alg = Algorithm"eager"())) |
| 22 | +end |
| 23 | +function contract_network(tn; alg = default_kwargs(contract_network, tn).alg, kwargs...) |
| 24 | + return contract_network(to_algorithm(alg; kwargs...), tn) |
9 | 25 | end |
10 | 26 |
|
11 | 27 | # `contract_network(::Algorithm"exact", ...)` |
|
34 | 50 |
|
35 | 51 | # `contraction_order` |
36 | 52 | function contraction_order end |
37 | | -default_kwargs(::typeof(contraction_order), tensors) = (; order = "eager") |
38 | | - |
39 | | -function contraction_expression(tensors; order = default_kwargs(contraction_order, tensors).order) |
40 | | - order = contraction_order(order, tensors) |
41 | | - |
42 | | - # Contraction order may or may not have indices attached, canonicalize the format |
43 | | - # by attaching indices. |
44 | | - subs = Dict(symnameddims(i) => symnameddims(i, tensors[i]) for i in keys(tensors)) |
45 | | - |
46 | | - return substitute(order, subs) |
47 | | -end |
48 | | - |
49 | | -contraction_order(order, tensors) = order |
50 | | -function contraction_order(tensors; order = default_kwargs(contraction_order, tensors).order) |
51 | | - return contraction_order(Algorithm(order), tensors) |
| 53 | +default_kwargs(::typeof(contraction_order), tn) = (; alg = Algorithm"eager"()) |
| 54 | +function contraction_order(tn; alg = default_kwargs(contraction_order, tn).alg, kwargs...) |
| 55 | + return contraction_order(to_algorithm(alg; kwargs...), tn) |
52 | 56 | end |
53 | 57 | # Convert the tensor network to a flat symbolic multiplication expression. |
54 | | -function contraction_order(::Algorithm"flat", tensors) |
| 58 | +function contraction_order(alg::Algorithm"flat", tn) |
55 | 59 | # Same as: `reduce((a, b) -> *(a, b; flatten = true), syms)`. |
56 | 60 | syms = vec([symnameddims(i, Tuple(axes(tn[i]))) for i in keys(tn)]) |
57 | 61 | return lazy(Mul(syms)) |
58 | 62 | end |
59 | 63 | function contraction_order(alg::Algorithm"left_associative", tn) |
60 | 64 | return prod(i -> symnameddims(i, Tuple(axes(tn[i]))), keys(tn)) |
61 | 65 | end |
62 | | - |
63 | | -function contraction_order( |
64 | | - order_algorithm::Algorithm, |
65 | | - tensors, |
66 | | - ) |
67 | | - order = contraction_order(tensors; order = "flat") |
68 | | - return optimize_evaluation_order(order; alg = order_algorithm) |
| 66 | +function contraction_order(alg::Algorithm, tn) |
| 67 | + s = contraction_order(Algorithm"flat"(), tn) |
| 68 | + return optimize_evaluation_order(s; alg) |
69 | 69 | end |
0 commit comments