Conversation
Introduce `BeliefPropagationProblem` wrapper to hold the cache and the error `diff` field. Also simplifies some kwargs wrangling.
…be set from another cache
Also includes some fixes to the way `TensorNetwork` types are constructed based on index structure.
for more information, see https://pre-commit.ci
…instead of trying to operate on existing graphs. The reasons for this are:
- One only cares about the edges of the input graph.
- A simple graph cannot be used, as it "forgets" its edge names, resulting in recursion.
- As shown with `TensorNetwork`, removing edges may not always be defined.
…s from an array.
This was caused by the change to the `cache` being backed by a directed graph.
Your PR no longer requires formatting changes. Thank you for your contribution!
Force-pushed from a20ddcb to ac74a7f (compare).
…ble when constructing nested algorithms
Fix broken merge
Codecov Report ❌ Patch coverage is

Additional details and impacted files

```
@@            Coverage Diff             @@
##             main      #26      +/-   ##
==========================================
- Coverage   63.58%   57.88%   -5.71%
==========================================
  Files          21       21
  Lines         736     1040     +304
==========================================
+ Hits          468      602     +134
- Misses        268      438     +170
```

☔ View full report in Codecov by Sentry.
```julia
messages = incoming_messages(bp_cache, vertex)
state = factors(bp_cache, vertex)

return (reduce(*, messages) * reduce(*, state))[]
```
This should be done with a contraction sequence finder, probably with a default like `"greedy"`, but the ability to override the sequence kwarg would be ideal.
```julia
message_type(bpc::AbstractBeliefPropagationCache) = message_type(typeof(bpc))
message_type(::Type{<:AbstractBeliefPropagationCache{<:Any, <:Any, ED}}) where {ED} = ED
```

```julia
function free_energy(bp_cache::AbstractBeliefPropagationCache)
```
Maybe this should be called `logscalar`? Technically it returns `log(scalar(bp_cache))`, where `scalar(bp_cache)` is the BP approximation of the underlying network. The free energy is really $-k_B T \log(Z)$, so `free_energy` might be a misnomer.
Agreed, math terminology would be preferred over physics terminology.
```julia
    return rv
end
```

```julia
function induced_subgraph_bpcache(graph, subvertices)
```
Is there a reason this can't be called `induced_subgraph` and dispatch based on the type?
It is a convention we are using where `induced_subgraph_bpcache` is a function defining the implementation of `induced_subgraph` for a BP cache (perhaps this suffix should be renamed to something more explicit). Then if one defines a type `NotAnAbstractBeliefPropagationCache` that doesn't subtype `AbstractBeliefPropagationCache` but should behave like an `AbstractBeliefPropagationCache` (at least for subgraph purposes), one can define:

```julia
function induced_subgraph_from_vertices(cache::NotAnAbstractBeliefPropagationCache, subvertices)
    return induced_subgraph_bpcache(cache, subvertices)
end
```

Note the appropriate function to overload is usually `induced_subgraph_from_vertices`, which assumes the `subvertices` argument has already been canonized via `to_vertices`. In any case, overloading `induced_subgraph` directly can lead to annoying method ambiguities, as `induced_subgraph` is a function in Graphs.jl that we have no control over. Matt and I have discussed refining this aspect of the code to move away from `induced_subgraph` entirely and just having the call stack be:

```julia
subgraph(graph, vertices) = subgraph_<implementation>(graph, to_vertices(graph, vertices))
```

which would simplify things greatly.
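As a self-contained illustration of that proposed call stack (all three functions below are hypothetical stubs, not the package's actual implementations): the public `subgraph` canonizes the vertex collection once via `to_vertices`, then dispatches to a per-implementation function.

```julia
# Stubs standing in for the real methods, to show the dispatch shape only.
to_vertices(graph, vertices) = collect(vertices)        # canonization step (stub)
subgraph_bpcache(graph, vertices) = (graph, vertices)   # implementation (stub)

# The single public entry point: canonize once, then dispatch.
subgraph(graph, vertices) = subgraph_bpcache(graph, to_vertices(graph, vertices))

g, vs = subgraph("dummy graph", (1, 2, 3))  # vs is the canonized vertex list
```

Implementations then only ever see canonized vertices, avoiding the Graphs.jl method-ambiguity problem entirely.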
```julia
using NamedGraphs.PartitionedGraphs: quotientvertices
```

```julia
@kwdef struct StopWhenConverged <: AI.StoppingCriterion
    tol::Float64 = 0.0
```
Can we use `Number` instead of `Float64`? It would be good to be able to use `Float32` (and other precisions) arithmetic if we wish.
We can use `AbstractFloat`.
Just a note that we should probably parametrize the type (we can constrain the parameter to `AbstractFloat` or `Real`).
```julia
if algorithm.normalize
    # TODO: use `sum` not `norm`
    message_norm = LinearAlgebra.norm(state.iterate)
```
`message_norm = sum(state.iterate)`. This is important for stability in the presence of complex numbers (it avoids blow-up of non-PSDness of the messages).
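A small self-contained sketch of the suggestion (`normalize_message` is a hypothetical helper, not a function in the package): dividing by `sum`, which can be complex, fixes the overall phase of the message as well as its scale, unlike `norm`, which is phase-blind.

```julia
# Hypothetical helper: normalize a message by its sum rather than its 2-norm.
# The zero-sum guard is an illustrative assumption, not from the PR.
function normalize_message(m)
    s = sum(m)
    return iszero(s) ? m : m ./ s
end

m = normalize_message([1.0 + 1.0im, 2.0])  # sum(m) ≈ 1 afterwards
```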
```julia
maxiter = is_tree(cache) ? 1 : nothing,
tol = -Inf,
message_diff_function = if tol > -Inf
    (m1, m2) -> norm(m1 / norm(m1) - m2 / norm(m2))
```
This message diff is currently dependent on phase, i.e. `m1 = -m2` will count as non-converged. `1 - abs2(dot(m1, m2))` might be better, with `m1` and `m2` normalized.
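A runnable sketch of that phase-insensitive alternative (the function name is hypothetical): with both messages normalized, `1 - abs2(dot(m1, m2))` vanishes whenever `m1 = c * m2` for any unit-modulus scalar `c`, so `m1 = -m2` now counts as converged.

```julia
using LinearAlgebra: dot, norm

# Hypothetical diff function insensitive to a global phase on either message.
function message_fidelity_diff(m1, m2)
    m1n = m1 / norm(m1)
    m2n = m2 / norm(m2)
    return 1 - abs2(dot(m1n, m2n))
end

message_fidelity_diff([1.0, 2.0], [-1.0, -2.0])  # ≈ 0: sign flip is converged
```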
```julia
@assert U * LinearAlgebra.diagm(S) * V ≈ W
id = [1.0 0.0; 0.0 1.0]
set!(sqrt_Ws, e, id)
set!(sqrt_Ws, reverse(e), U * LinearAlgebra.diagm(S) * V)
```
Currently the SVD is not really playing a role here: `set!(sqrt_Ws, reverse(e), W)` would work. Or, probably best, take the square root via the SVD and set `U * sqrt(S)` on `e` and `sqrt(S) * V` on `reverse(e)`.
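A sketch of the suggested square-root split (the `set!`/`sqrt_Ws`/`e` machinery from the snippet above is omitted; this only shows the factorization step, with an assumed example bond matrix `W`):

```julia
using LinearAlgebra: svd, Diagonal

W = [2.0 1.0; 1.0 2.0]  # example bond matrix (assumption, not from the PR)
F = svd(W)

# Distribute sqrt of the singular values over the two edge directions.
left = F.U * Diagonal(sqrt.(F.S))    # would go on edge `e`
right = Diagonal(sqrt.(F.S)) * F.Vt  # would go on `reverse(e)`

# Their product recovers W, so absorbing one factor per direction is exact.
@assert left * right ≈ W
```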
```julia
end
```

```julia
bpc = BeliefPropagationCache(tn)
bpc = ITensorNetworksNext.beliefpropagation(bpc; maxiter = 1)
```
It would be nice if the `maxiter` default could be set to 1 if `graph(bpc) == 1` and unset otherwise.
@JoeyT1994 what do you mean by `graph(bpc) == 1`? Shouldn't the check be that it is a tree?
Yes, sorry, I mean `is_tree(graph(bpc))`.
This PR expresses belief propagation in terms of the new interface based on AlgorithmsInterface.jl and the included AlgorithmsInterfaceExtensions.jl library.
Note, the tolerance for the second belief propagation test has been increased from `1.0e-14` to `1.0e-12`; however, this test was not passing reliably before the changes in this PR anyway.