From 03f2890919157da14eea341c25e568df2145041b Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Thu, 22 Jun 2023 16:27:32 +0200 Subject: [PATCH 001/133] Remove PowerWeightedMeasure Unused and untested. --- src/MeasureBase.jl | 1 - src/combinators/powerweighted.jl | 37 -------------------------------- 2 files changed, 38 deletions(-) delete mode 100644 src/combinators/powerweighted.jl diff --git a/src/MeasureBase.jl b/src/MeasureBase.jl index e29c4ae9..8d1a4bec 100644 --- a/src/MeasureBase.jl +++ b/src/MeasureBase.jl @@ -138,7 +138,6 @@ include("combinators/likelihood.jl") include("combinators/pointwise.jl") include("combinators/restricted.jl") include("combinators/smart-constructors.jl") -include("combinators/powerweighted.jl") include("combinators/conditional.jl") include("standard/stdmeasure.jl") diff --git a/src/combinators/powerweighted.jl b/src/combinators/powerweighted.jl deleted file mode 100644 index 47f50da4..00000000 --- a/src/combinators/powerweighted.jl +++ /dev/null @@ -1,37 +0,0 @@ -export ↑ - -struct PowerWeightedMeasure{M,A} <: AbstractMeasure - parent::M - exponent::A -end - -logdensity_def(d::PowerWeightedMeasure, x) = d.exponent * logdensity_def(d.parent, x) - -basemeasure(d::PowerWeightedMeasure, x) = basemeasure(d.parent, x)↑d.exponent - -basemeasure(d::PowerWeightedMeasure) = basemeasure(d.parent)↑d.exponent - -function powerweightedmeasure(d, α) - isone(α) && return d - PowerWeightedMeasure(d, α) -end - -(d::AbstractMeasure)↑α = powerweightedmeasure(d, α) - -insupport(d::PowerWeightedMeasure, x) = insupport(d.parent, x) - -function Base.show(io::IO, d::PowerWeightedMeasure) - print(io, d.parent, " ↑ ", d.exponent) -end - -function powerweightedmeasure(d::PowerWeightedMeasure, α) - powerweightedmeasure(d.parent, α * d.exponent) -end - -function powerweightedmeasure(d::WeightedMeasure, α) - weightedmeasure(α * d.logweight, powerweightedmeasure(d.base, α)) -end - -function Pretty.tile(d::PowerWeightedMeasure) - Pretty.pair_layout(Pretty.tile(d.parent), Pretty.tile(d.exponent), sep = " ↑ ") -end From eac3622cb04e166e99fd0248b9285783c4d6bf3e Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Thu, 22 Jun 2023 16:01:17 +0200 Subject: [PATCH 002/133] Remove kernelfactor Not used currently. --- src/parameterized.jl | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/src/parameterized.jl b/src/parameterized.jl index 78e43995..8b1c8c88 100644 --- a/src/parameterized.jl +++ b/src/parameterized.jl @@ -127,14 +127,3 @@ params(::Type{PM}) where {N,PM<:ParameterizedMeasure{N}} = N function paramnames(μ, constraints::NamedTuple{N}) where {N} tuple((k for k in paramnames(μ) if k ∉ N)...) end - -############################################################################### -# kernelfactor - -function kernelfactor(::Type{P}) where {N,P<:ParameterizedMeasure{N}} - (constructorof(P), N) -end - -function kernelfactor(::P) where {N,P<:ParameterizedMeasure{N}} - (constructorof(P), N) -end From b38d1419a2084ee8d338db8d2d9ccadcfcec1583 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Thu, 22 Jun 2023 13:27:07 +0200 Subject: [PATCH 003/133] Rename pullback to pullbck and export it pullback has a huge potential for naming conflickts, and pullbck is more in line with pushfwd. Also simplify implementation of pullbck. --- src/combinators/transformedmeasure.jl | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/combinators/transformedmeasure.jl b/src/combinators/transformedmeasure.jl index 803b404b..dab76d5f 100644 --- a/src/combinators/transformedmeasure.jl +++ b/src/combinators/transformedmeasure.jl @@ -140,7 +140,7 @@ end # pullback """ - pullback(f, μ, volcorr = WithVolCorr()) + pullbck(f, μ, volcorr = WithVolCorr()) A _pullback_ is a dual concept to a _pushforward_. While a pushforward needs a map _from_ the support of a measure, a pullback requires a map _into_ the @@ -152,8 +152,11 @@ in terms of the inverse function; the "forward" function is not used at all. In some cases, we may be focusing on log-density (and not, for example, sampling). To manually specify an inverse, call -`pullback(InverseFunctions.setinverse(f, finv), μ, volcorr)`. +`pullbck(InverseFunctions.setinverse(f, finv), μ, volcorr)`. """ -function pullback(f, μ, volcorr::TransformVolCorr = WithVolCorr()) - pushfwd(setinverse(inverse(f), f), μ, volcorr) +function pullbck(f, μ, volcorr::TransformVolCorr = WithVolCorr()) + PushforwardMeasure(inverse(f), f, μ, volcorr) end +export pullbck + +@deprecate pullback(f, μ, volcorr::TransformVolCorr = WithVolCorr()) pullbck(f, μ, volcorr) From eccc203da1a0adeb8f2d0586d5c26d65c7fb5538 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Thu, 22 Jun 2023 13:27:07 +0200 Subject: [PATCH 004/133] Rename bind to mbind and deprecate rightarrowtail Bind has too much naming conflict potential with Base.bind. The rightarrowtail operator looks very similar to the `>=>` "fish" operator (e.g. in Haskell), which is not a monadic bind. --- src/combinators/bind.jl | 45 ++++++++++++++++++++++++++--------------- 1 file changed, 29 insertions(+), 16 deletions(-) diff --git a/src/combinators/bind.jl b/src/combinators/bind.jl index cc2022f2..491bd5a2 100644 --- a/src/combinators/bind.jl +++ b/src/combinators/bind.jl @@ -3,34 +3,47 @@ struct Bind{M,K} <: AbstractMeasure k::K end -export ↣ """ -If -- μ is an `AbstractMeasure` or satisfies the Measure interface, and -- k is a function taking values from the support of μ and returning a measure + mbind(k, μ)::AbstractMeasure + +Given + +- a measure μ +- a kernel function k that takes values from the support of μ and returns a + measure -Then `μ ↣ k` is a measure, called a *monadic bind*. In a -probabilistic programming language like Soss.jl, this could be expressed as +The *monadic bind* operation `mbind(k, μ)` returns is a new measure. -Note that bind is usually written `>>=`, but this symbol is unavailable in Julia. +A monadic bind ofen written as `>>=` (e.g. in Haskell), but this symbol is +unavailable in Julia. ``` -bind = @model μ,k begin - x ~ μ - y ~ k(x) - return y +μ = StdExponential() +ν = mbind(μ) do scale + pushfwd(Base.Fix1(*, scale), StdNormal()) end ``` - -See also `bind` and `Bind` """ -↣(μ, k) = bind(μ, k) - -bind(μ, k) = Bind(μ, k) +mbind(k, μ) = Bind(μ, k) +export mbind function Base.rand(rng::AbstractRNG, ::Type{T}, d::Bind) where {T} x = rand(rng, T, d.μ) y = rand(rng, T, d.k(x)) return y end + + +# ToDo: Remove `bind` (breaking). +@noinline function bind(μ, k) + Base.depwarn("`foo(μ, k)` is deprecated, use `mbind(k, μ)` instead.", :bind) + mbind(k, μ) +end + + +# ToDo: Remove `↣` (breaking): It looks too similar to the `>=>` "fish" +# operator (e.g. in Haskell) that is typically understood to take two monadic +# functions as an argument, while a bind take a monad and a monadic functions. +@deprecate ↣(μ, k) mbind(μ, k) +export ↣ From 49c58c114608a5d49af582d5cc4782c95c7a51a6 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Thu, 22 Jun 2023 13:27:07 +0200 Subject: [PATCH 005/133] Introduce mintegrate and mintegrate_exp Removes the integral operators from MeasureBase, to be re-introduced in the submodule MeasureOperators. Also improves the likelihood documentation. --- src/combinators/likelihood.jl | 134 +++++++++++++++++++++------------- src/density.jl | 92 +++++++++++++++-------- 2 files changed, 143 insertions(+), 83 deletions(-) diff --git a/src/combinators/likelihood.jl b/src/combinators/likelihood.jl index 6dfd164f..93dc1186 100644 --- a/src/combinators/likelihood.jl +++ b/src/combinators/likelihood.jl @@ -11,9 +11,9 @@ abstract type AbstractLikelihood end # insupport(ℓ::AbstractLikelihood, p) = insupport(ℓ.k(p), ℓ.x) @doc raw""" - Likelihood(k::AbstractTransitionKernel, x) + Likelihood(k, x) -"Observe" a value `x`, yielding a function from the parameters to ℝ. +Default result of [`likelihoodof(k, x)`](@ref). Likelihoods are most commonly used in conjunction with an existing _prior_ measure to yield a new measure, the _posterior_. In Bayes's Law, we have @@ -91,12 +91,10 @@ Similarly to the above, we have Finally, let's return to the expression for Bayes's Law, -``P(θ|x) ∝ P(θ) P(x|θ)`` +``P(θ|x) ∝ P(x|θ) P(θ)`` -The product on the right side is computed pointwise. To work with this in -MeasureBase, we have a "pointwise product" `⊙`, which takes a measure and a -likelihood, and returns a new measure, that is, the unnormalized posterior that -has density ``P(θ) P(x|θ)`` with respect to the base measure of the prior. +In measure theory, the product on the right side is actually the Lebesgue integral, +of the likelihood with respect to the prior. For example, say we have @@ -104,21 +102,24 @@ For example, say we have x ~ Normal(μ,σ) σ = 1 -and we observe `x=3`. We can compute the posterior measure on `μ` as - - julia> post = Normal() ⊙ Likelihood(Normal{(:μ, :σ)}, (σ=1,), 3) - Normal() ⊙ Likelihood(Normal{(:μ, :σ), T} where T, (σ = 1,), 3) +and we observe `x=3`. We can compute the (non-normalized) posterior measure on +`μ` as - julia> logdensity_def(post, 2) - -2.5 + julia> prior = Normal() + julia> likelihood = Likelihood(μ -> Normal(μ, 1), 3) + julia> post = mintegrate(likelihood, prior) + julia> post isa MeasureBase.DensityMeasure + true + julia> logdensity_rel(post, Lebesgue(), 2) + -4.337877066409345 """ struct Likelihood{K,X} <: AbstractLikelihood k::K x::X - Likelihood(k::K, x::X) where {K<:AbstractTransitionKernel,X} = new{K,X}(k, x) - Likelihood(k::K, x::X) where {K<:Function,X} = new{K,X}(k, x) - Likelihood(μ, x) = Likelihood(kernel(μ), x) + Likelihood(k::K, x::X) where {K,X} = new{K,X}(k, x) +#!!!!!!!!!!! # For type stability if `K isa UnionAll (e.g. a parameterized MeasureType)` + Likelihood(::Type{K}, x::X) where {K<:AbstractMeasure,X} = new{K,X}(K, x) end (lik::AbstractLikelihood)(p) = exp(ULogarithmic, logdensityof(lik.k(p), lik.x)) @@ -150,58 +151,87 @@ end export likelihoodof -""" - likelihoodof(k::AbstractTransitionKernel, x; constraints...) - likelihoodof(k::AbstractTransitionKernel, x, constraints::NamedTuple) +@doc raw""" + likelihoodof(k, x) -A likelihood is *not* a measure. Rather, a likelihood acts on a measure, through -the "pointwise product" `⊙`, yielding another measure. -""" -function likelihoodof end +Returns the likelihood of observing `x` under a family of probability +measures that is generated by a transition kernel `k(θ)`. + +`k(θ)` maps points in the parameter space to measures (resp. objects that can +be converted to measures) on a implicit set `Χ` that contains values like `x`. + +`likelihoodof(k, x)` returns a likelihood object. A likelihhood is **not** a +measure, it is a function from the parameter space to `ℝ₊`. Likelihood +objects can also be interpreted as "generic densities" (but **not** as +probability densities). -likelihoodof(k, x, ::NamedTuple{()}) = Likelihood(k, x) +`likelihoodof(k, x)` implicitly chooses `ξ = rootmeasure(k(θ))` as the +reference measure on the observation set `Χ`. Note that this implicit +`ξ` **must** be independent of `θ`. -likelihoodof(k, x; kwargs...) = likelihoodof(k, x, NamedTuple(kwargs)) +`ℒₓ = likelihoodof(k, x)` has the mathematical interpretation -likelihoodof(k, x, pars::NamedTuple) = likelihoodof(kernel(k, pars), x) +```math +\mathcal{L}_x(\theta) = \frac{\rm{d}\, k(\theta)}{\rm{d}\, \chi}(x) +``` -likelihoodof(k::AbstractTransitionKernel, x) = Likelihood(k, x) +`likelihoodof` must return an object that implements the +[`DensityInterface`](https://github.com/JuliaMath/DensityInterface.jl)` API +and `ℒₓ = likelihoodof(k, x)` must satisfy -export log_likelihood_ratio +```julia +log(ℒₓ(θ)) == logdensityof(ℒₓ, θ) ≈ logdensityof(k(θ), x) +DensityKind(ℒₓ) isa IsDensity +``` + +By default, an instance of [`MeasureBase.Likelihood`](@ref) is returned. """ - log_likelihood_ratio(ℓ::Likelihood, p, q) +function likelihoodof end -Compute the log of the likelihood ratio, in order to compare two choices for -parameters. This is computed as +likelihoodof(k, x) = Likelihood(k, x) - logdensity_rel(ℓ.k(p), ℓ.k(q), ℓ.x) -Since `logdensity_rel` can leave common base measure unevaluated, this can be -more efficient than +############################################################################### +# At the least, we need to think through in some more detail whether +# (log-)likelihood ratios expressed in this way are correct and useful. For now +# this code is commented out; we may remove it entirely in the future. - logdensityof(ℓ.k(p), ℓ.x) - logdensityof(ℓ.k(q), ℓ.x) -""" -log_likelihood_ratio(ℓ::Likelihood, p, q) = logdensity_rel(ℓ.k(p), ℓ.k(q), ℓ.x) +# export log_likelihood_ratio -# likelihoodof(k, x; kwargs...) = likelihoodof(k, x, NamedTuple(kwargs)) +# """ +# log_likelihood_ratio(ℓ::Likelihood, p, q) -export likelihood_ratio +# Compute the log of the likelihood ratio, in order to compare two choices for +# parameters. This is computed as -""" - likelihood_ratio(ℓ::Likelihood, p, q) +# logdensity_rel(ℓ.k(p), ℓ.k(q), ℓ.x) -Compute the log of the likelihood ratio, in order to compare two choices for -parameters. This is equal to +# Since `logdensity_rel` can leave common base measure unevaluated, this can be +# more efficient than - density_rel(ℓ.k(p), ℓ.k(q), ℓ.x) +# logdensityof(ℓ.k(p), ℓ.x) - logdensityof(ℓ.k(q), ℓ.x) +# """ +# log_likelihood_ratio(ℓ::Likelihood, p, q) = logdensity_rel(ℓ.k(p), ℓ.k(q), ℓ.x) -but is computed using LogarithmicNumbers.jl to avoid underflow and overflow. -Since `density_rel` can leave common base measure unevaluated, this can be -more efficient than +# # likelihoodof(k, x; kwargs...) = likelihoodof(k, x, NamedTuple(kwargs)) - logdensityof(ℓ.k(p), ℓ.x) - logdensityof(ℓ.k(q), ℓ.x) -""" -function likelihood_ratio(ℓ::Likelihood, p, q) - exp(ULogarithmic, logdensity_rel(ℓ.k(p), ℓ.k(q), ℓ.x)) -end +# export likelihood_ratio + +# """ +# likelihood_ratio(ℓ::Likelihood, p, q) + +# Compute the log of the likelihood ratio, in order to compare two choices for +# parameters. This is equal to + +# density_rel(ℓ.k(p), ℓ.k(q), ℓ.x) + +# but is computed using LogarithmicNumbers.jl to avoid underflow and overflow. +# Since `density_rel` can leave common base measure unevaluated, this can be +# more efficient than + +# logdensityof(ℓ.k(p), ℓ.x) - logdensityof(ℓ.k(q), ℓ.x) +# """ +# function likelihood_ratio(ℓ::Likelihood, p, q) +# exp(ULogarithmic, logdensity_rel(ℓ.k(p), ℓ.k(q), ℓ.x)) +# end diff --git a/src/density.jl b/src/density.jl index 4862dcb1..a3b1b95e 100644 --- a/src/density.jl +++ b/src/density.jl @@ -98,12 +98,13 @@ DensityInterface.funcdensity(d::LogDensity) = throw(MethodError(funcdensity, (d, base :: B end -A `DensityMeasure` is a measure defined by a density or log-density with respect -to some other "base" measure. +A `DensityMeasure` is a measure defined by a density or log-density with +respect to some other "base" measure. -Users should not call `DensityMeasure` directly, but should instead call `∫(f, -base)` (if `f` is a density function or `DensityInterface.IsDensity` object) or -`∫exp(f, base)` (if `f` is a log-density function). +Users should not instantiate `DensityMeasure` directly, but should instead +call `mintegral_exp(f, base)` (if `f` is a density function or +`DensityInterface.IsDensity` object) or `mintegral_exp(f, base)` (if `f` +is a log-density function). """ struct DensityMeasure{F,B} <: AbstractMeasure f::F @@ -120,48 +121,77 @@ end end function Pretty.tile(μ::DensityMeasure{F,B}) where {F,B} - result = Pretty.literal("DensityMeasure ∫(") + result = Pretty.literal("mintegrate(") result *= Pretty.pair_layout(Pretty.tile(μ.f), Pretty.tile(μ.base); sep = ", ") result *= Pretty.literal(")") end -export ∫ +basemeasure(μ::DensityMeasure) = μ.base + +logdensity_def(μ::DensityMeasure, x) = logdensityof(μ.f, x) + +density_def(μ::DensityMeasure, x) = densityof(μ.f, x) -""" - ∫(f, base::AbstractMeasure) -Define a new measure in terms of a density `f` over some measure `base`. + +@doc raw""" + mintegrate(f, μ::AbstractMeasure)::AbstractMeasure + +Returns a new measure that represents the indefinite +[integral](https://en.wikipedia.org/wiki/Radon%E2%80%93Nikodym_theorem) +of `f` with respect to `μ`. + +`ν = mintegrate(f, μ)` generates a measure `ν` that has the mathematical +interpretation + +math``` +\nu(A) = \int_A f(a) \, \rm{d}\mu(a) +``` """ -∫(f, base) = _densitymeasure(f, base, DensityKind(f)) +function mintegrate end +export mintegrate + +mintegrate(f, μ::AbstractMeasure) = _mintegrate_impl(f, μ, DensityKind(f)) -_densitymeasure(f, base, ::IsDensity) = DensityMeasure(f, base) -function _densitymeasure(f, base, ::HasDensity) - @error "`∫(f, base)` requires `DensityKind(f)` to be `IsDensity()` or `NoDensity()`." +_mintegrate_impl(f, μ, ::IsDensity) = DensityMeasure(f, μ) +function _mintegrate_impl(f, μ, ::HasDensity) + throw(ArgumentError( "`mintegrate(f, mu)` requires `DensityKind(f)` to be `IsDensity()` or `NoDensity()`.")) end -_densitymeasure(f, base, ::NoDensity) = DensityMeasure(funcdensity(f), base) +_mintegrate_impl(f, μ, ::NoDensity) = DensityMeasure(funcdensity(f), μ) -export ∫exp -""" - ∫exp(f, base::AbstractMeasure) +@doc raw""" + mintegrate_exp(log_f, μ::AbstractMeasure) + +Given a function `log_f` that semantically represents the log of a function +`f`, `mintegrate` returns a new measure that represents the indefinite +[integral](https://en.wikipedia.org/wiki/Radon%E2%80%93Nikodym_theorem) +of `f` with respect to `μ`. + +`ν = mintegrate_exp(log_f, μ)` generates a measure `ν` that has the +mathematical interpretation -Define a new measure in terms of a log-density `f` over some measure `base`. +math``` +\nu(A) = \int_A e^{log(f(a))} \, \rm{d}\mu(a) = \int_A f(a) \, \rm{d}\mu(a) +``` + +Note that `exp(log_f(...))` is usually not run explicitly, calculations that +involve the resulting measure are typically performed in log-space, +internally. """ -∫exp(f, base) = _logdensitymeasure(f, base, DensityKind(f)) +function mintegrate_exp end +export mintegrate_exp + +mintegrate_exp(log_f, μ::AbstractMeasure) = _mintegrate_exp_impl(log_f, μ, DensityKind(log_f)) -function _logdensitymeasure(f, base, ::IsDensity) - @error "`∫exp(f, base)` is not valid when `DensityKind(f) == IsDensity()`. Use `∫(f, base)` instead." +function _mintegrate_exp_impl(log_f, μ, ::IsDensity) + throw(ArgumentError("`mintegrate_exp(log_f, μ)` is not valid when `DensityKind(log_f) == IsDensity()`. Use `mintegral(log_f, μ)` instead.")) end -function _logdensitymeasure(f, base, ::HasDensity) - @error "`∫exp(f, base)` is not valid when `DensityKind(f) == HasDensity()`." +function _mintegrate_exp_impl(log_f, μ, ::HasDensity) + throw(ArgumentError("`mintegrate_exp(log_f, μ)` is not valid when `DensityKind(log_f) == HasDensity()`.")) end -_logdensitymeasure(f, base, ::NoDensity) = DensityMeasure(logfuncdensity(f), base) +_mintegrate_exp_impl(log_f, μ, ::NoDensity) = DensityMeasure(logfuncdensity(log_f), μ) -basemeasure(μ::DensityMeasure) = μ.base - -logdensity_def(μ::DensityMeasure, x) = logdensityof(μ.f, x) - -density_def(μ::DensityMeasure, x) = densityof(μ.f, x) """ rebase(μ, ν) @@ -172,4 +202,4 @@ basemeasure(rebase(μ, ν)) == ν density(rebase(μ, ν)) == 𝒹(μ,ν) ``` """ -rebase(μ, ν) = ∫(𝒹(μ, ν), ν) +rebase(μ, ν) = mintegrate(density_rel(μ, ν), ν) From ce7c5be5abae28d713a67bb2eb7f4b9784247478 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Thu, 22 Jun 2023 15:59:42 +0200 Subject: [PATCH 006/133] Remove the rebase function A rebase can easily be written explicitly. --- src/MeasureBase.jl | 1 - src/density.jl | 12 ------------ 2 files changed, 13 deletions(-) diff --git a/src/MeasureBase.jl b/src/MeasureBase.jl index 8d1a4bec..0e8902da 100644 --- a/src/MeasureBase.jl +++ b/src/MeasureBase.jl @@ -39,7 +39,6 @@ using FunctionChains export ≪ export gentype -export rebase export AbstractMeasure diff --git a/src/density.jl b/src/density.jl index a3b1b95e..ea976462 100644 --- a/src/density.jl +++ b/src/density.jl @@ -191,15 +191,3 @@ function _mintegrate_exp_impl(log_f, μ, ::HasDensity) throw(ArgumentError("`mintegrate_exp(log_f, μ)` is not valid when `DensityKind(log_f) == HasDensity()`.")) end _mintegrate_exp_impl(log_f, μ, ::NoDensity) = DensityMeasure(logfuncdensity(log_f), μ) - - -""" - rebase(μ, ν) - -Express `μ` in terms of a density over `ν`. Satisfies -``` -basemeasure(rebase(μ, ν)) == ν -density(rebase(μ, ν)) == 𝒹(μ,ν) -``` -""" -rebase(μ, ν) = mintegrate(density_rel(μ, ν), ν) From debaea274e33d74c9a5d5a6366da29fe25f5f14b Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Thu, 22 Jun 2023 16:19:07 +0200 Subject: [PATCH 007/133] Rename bind to mbind and remove fish operator --- src/combinators/bind.jl | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/src/combinators/bind.jl b/src/combinators/bind.jl index 491bd5a2..e987ea8e 100644 --- a/src/combinators/bind.jl +++ b/src/combinators/bind.jl @@ -33,17 +33,3 @@ function Base.rand(rng::AbstractRNG, ::Type{T}, d::Bind) where {T} y = rand(rng, T, d.k(x)) return y end - - -# ToDo: Remove `bind` (breaking). -@noinline function bind(μ, k) - Base.depwarn("`foo(μ, k)` is deprecated, use `mbind(k, μ)` instead.", :bind) - mbind(k, μ) -end - - -# ToDo: Remove `↣` (breaking): It looks too similar to the `>=>` "fish" -# operator (e.g. in Haskell) that is typically understood to take two monadic -# functions as an argument, while a bind take a monad and a monadic functions. -@deprecate ↣(μ, k) mbind(μ, k) -export ↣ From 1401086661778f4c8b347018863914f1150b39e9 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Fri, 23 Jun 2023 10:55:56 +0200 Subject: [PATCH 008/133] Change field order of Bind and improve docs. --- src/combinators/bind.jl | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/src/combinators/bind.jl b/src/combinators/bind.jl index e987ea8e..a27be6e0 100644 --- a/src/combinators/bind.jl +++ b/src/combinators/bind.jl @@ -1,6 +1,20 @@ +""" + struct MeasureBase.Bind{M,K} <: AbstractMeasure + +Represents a monatic bind. User code should create instances of `Bind` +directly, but should call `mbind(k, μ)` instead. +""" struct Bind{M,K} <: AbstractMeasure - μ::M k::K + μ::M +end + +getdof(d::Bind) = NoDOF{typeof(d)}() + +function Base.rand(rng::AbstractRNG, ::Type{T}, d::Bind) where {T} + x = rand(rng, T, d.μ) + y = rand(rng, T, d.k(x)) + return y end @@ -25,11 +39,5 @@ unavailable in Julia. end ``` """ -mbind(k, μ) = Bind(μ, k) +mbind(k, μ) = Bind(k, μ) export mbind - -function Base.rand(rng::AbstractRNG, ::Type{T}, d::Bind) where {T} - x = rand(rng, T, d.μ) - y = rand(rng, T, d.k(x)) - return y -end From 13f41292eca78ffbc0ee94c57a483a14b795bb69 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Thu, 22 Jun 2023 16:31:59 +0200 Subject: [PATCH 009/133] Remove operator otimes To be re-introduced in sub-module MeasureOperators. --- src/combinators/product.jl | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/src/combinators/product.jl b/src/combinators/product.jl index cb7a0aaf..3a5c6494 100644 --- a/src/combinators/product.jl +++ b/src/combinators/product.jl @@ -167,18 +167,6 @@ function testvalue(::Type{T}, d::AbstractProductMeasure) where {T} _map(m -> testvalue(T, m), marginals(d)) end -export ⊗ - -""" - ⊗(μs::AbstractMeasure...) - -`⊗` is a binary operator for building product measures. This satisfies the law - -``` - basemeasure(μ ⊗ ν) == basemeasure(μ) ⊗ basemeasure(ν) -``` -""" -⊗(μs::AbstractMeasure...) = productmeasure(μs) ############################################################################### # I <: Base.Generator From 01e4f898d99c6e7a00e67d4c67326f5a443f53bf Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Thu, 22 Jun 2023 16:35:59 +0200 Subject: [PATCH 010/133] Removes PointwiseProductMeasure `mintegral` should be used instead to express posteriors. --- src/MeasureBase.jl | 1 - src/combinators/pointwise.jl | 30 ------------------------------ 2 files changed, 31 deletions(-) delete mode 100644 src/combinators/pointwise.jl diff --git a/src/MeasureBase.jl b/src/MeasureBase.jl index 0e8902da..072c52c6 100644 --- a/src/MeasureBase.jl +++ b/src/MeasureBase.jl @@ -134,7 +134,6 @@ include("combinators/product.jl") include("combinators/power.jl") include("combinators/spikemixture.jl") include("combinators/likelihood.jl") -include("combinators/pointwise.jl") include("combinators/restricted.jl") include("combinators/smart-constructors.jl") include("combinators/conditional.jl") diff --git a/src/combinators/pointwise.jl b/src/combinators/pointwise.jl deleted file mode 100644 index 778e7f4e..00000000 --- a/src/combinators/pointwise.jl +++ /dev/null @@ -1,30 +0,0 @@ -export ⊙ - -struct PointwiseProductMeasure{P,L} <: AbstractMeasure - prior::P - likelihood::L -end - -iterate(p::PointwiseProductMeasure, i = 1) = iterate((p.prior, p.likelihood), i) - -function Pretty.tile(d::PointwiseProductMeasure) - Pretty.pair_layout(Pretty.tile(d.prior), Pretty.tile(d.likelihood), sep = " ⊙ ") -end - -⊙(prior, ℓ) = pointwiseproduct(prior, ℓ) - -@inbounds function insupport(d::PointwiseProductMeasure, p) - prior, ℓ = d - istrue(insupport(prior, p)) && istrue(insupport(ℓ, p)) -end - -@inline function logdensity_def(d::PointwiseProductMeasure, p) - prior, ℓ = d - unsafe_logdensityof(ℓ, p) -end - -basemeasure(d::PointwiseProductMeasure) = d.prior - -function gentype(d::PointwiseProductMeasure) - gentype(d.prior) -end From 55c12d7d2dc32a3fd4bef317323b0452673a11c7 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Thu, 22 Jun 2023 22:59:03 +0200 Subject: [PATCH 011/133] Remove scrd operator To be reintroduced in submodule MeasureOperators --- src/density.jl | 22 +--------------------- 1 file changed, 1 insertion(+), 21 deletions(-) diff --git a/src/density.jl b/src/density.jl index ea976462..bba69985 100644 --- a/src/density.jl +++ b/src/density.jl @@ -20,8 +20,7 @@ For measures `μ` and `ν`, `Density(μ,ν)` represents the _density function_ `dμ/dν`, also called the _Radom-Nikodym derivative_: https://en.wikipedia.org/wiki/Radon%E2%80%93Nikodym_theorem#Radon%E2%80%93Nikodym_derivative -Instead of calling this directly, users should call `density_rel(μ, ν)` or -its abbreviated form, `𝒹(μ,ν)`. +Instead of calling this directly, users should call `density_rel(μ, ν)`. """ struct Density{M,B} <: AbstractDensity μ::M @@ -32,15 +31,6 @@ Base.:∘(::typeof(log), d::Density) = logdensity_rel(d.μ, d.base) Base.log(d::Density) = log ∘ d -export 𝒹 - -""" - 𝒹(μ, base) - -Compute the density (Radom-Nikodym derivative) of μ with respect to `base`. This -is a shorthand form for `density_rel(μ, base)`. -""" -𝒹(μ, base) = density_rel(μ, base) density_rel(μ, base) = Density(μ, base) @@ -73,16 +63,6 @@ Base.:∘(::typeof(exp), d::LogDensity) = density_rel(d.μ, d.base) Base.exp(d::LogDensity) = exp ∘ d -export log𝒹 - -""" - log𝒹(μ, base) - -Compute the log-density (Radom-Nikodym derivative) of μ with respect to `base`. -This is a shorthand form for `logdensity_rel(μ, base)` -""" -log𝒹(μ, base) = logdensity_rel(μ, base) - logdensity_rel(μ, base) = LogDensity(μ, base) (f::LogDensity)(x) = logdensity_rel(f.μ, f.base, x) From 5dd3fe5be7ff0d3f2cb2e09d3e59aa2a518fbf2c Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Fri, 23 Jun 2023 00:29:42 +0200 Subject: [PATCH 012/133] Remove ll-operator Absolute continuity is not really implemented yet. --- src/MeasureBase.jl | 1 - src/absolutecontinuity.jl | 3 +++ src/combinators/weighted.jl | 3 --- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/MeasureBase.jl b/src/MeasureBase.jl index 072c52c6..b6a28cfd 100644 --- a/src/MeasureBase.jl +++ b/src/MeasureBase.jl @@ -37,7 +37,6 @@ using Static using Static: StaticInteger using FunctionChains -export ≪ export gentype export AbstractMeasure diff --git a/src/absolutecontinuity.jl b/src/absolutecontinuity.jl index 8062198c..c65aeaf3 100644 --- a/src/absolutecontinuity.jl +++ b/src/absolutecontinuity.jl @@ -54,3 +54,6 @@ # representative(μ) ≪ representative(ν) && return true # return false # end + +# ≪(::M, ::WeightedMeasure{R,M}) where {R,M} = true +# ≪(::WeightedMeasure{R,M}, ::M) where {R,M} = true diff --git a/src/combinators/weighted.jl b/src/combinators/weighted.jl index db239b50..124662b6 100644 --- a/src/combinators/weighted.jl +++ b/src/combinators/weighted.jl @@ -46,9 +46,6 @@ end Base.:*(m::AbstractMeasure, k::Real) = k * m -≪(::M, ::WeightedMeasure{R,M}) where {R,M} = true -≪(::WeightedMeasure{R,M}, ::M) where {R,M} = true - gentype(μ::WeightedMeasure) = gentype(μ.base) insupport(μ::WeightedMeasure, x) = insupport(μ.base, x) From 49cd7765cfc6abb5279d501999206fca42f26429 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Thu, 22 Jun 2023 13:27:07 +0200 Subject: [PATCH 013/133] Add measure operators in submodule MeasureOperators Having the operators in a sub-module makes it easier for users to control whether of they want them in their namespace. Operators have a larger naming conflict potential. --- src/MeasureBase.jl | 2 + src/measure_operators.jl | 141 ++++++++++++++++++++++++++++++++++++++ test/measure_operators.jl | 24 +++++++ test/runtests.jl | 2 + 4 files changed, 169 insertions(+) create mode 100644 src/measure_operators.jl create mode 100644 test/measure_operators.jl diff --git a/src/MeasureBase.jl b/src/MeasureBase.jl index b6a28cfd..eeea2ad8 100644 --- a/src/MeasureBase.jl +++ b/src/MeasureBase.jl @@ -152,6 +152,8 @@ include("density-core.jl") include("interface.jl") +include("measure_operators.jl") + using .Interface end # module MeasureBase diff --git a/src/measure_operators.jl b/src/measure_operators.jl new file mode 100644 index 00000000..ee8a7d8d --- /dev/null +++ b/src/measure_operators.jl @@ -0,0 +1,141 @@ +""" + module MeasureOperators + +Defines the following operators for measures: + +* `f ⋄ μ == pushfwd(f, μ)` + +* `μ ⊙ f == inverse(f) ⋄ μ` +""" +module MeasureOperators + +using MeasureBase: AbstractMeasure +using MeasureBase: pushfwd, pullbck, mbind, productmeasure +using MeasureBase: mintegrate, mintegrate_exp, density_rel, logdensity_rel +using InverseFunctions: inverse +using Reexport: @reexport + + +@doc raw""" + ⋄(f, μ::AbstractMeasure) = pushfwd(f, μ) + +The `\\diamond` operator denotes a pushforward operation: `ν = f ⋄ μ` +generates a +[pushforward measure](https://en.wikipedia.org/wiki/Pushforward_measure). + +A common mathematical notation for a pushforward is ``f_*μ``, but as +there is no "subscript-star" operator in Julia, we use `⋄`. + +See [`pushfwd(f, μ)`](@ref) for details. + +Also see [`ν ⊙ f`](@ref), the pullback operator. +""" +⋄(f, μ::AbstractMeasure) = pushfwd(f, μ) +export ⋄ + + +@doc raw""" + ⊙(ν::AbstractMeasure, f) = pullbck(f, ν) + +The `\\odot` operator denotes a pullback operation. + +See also [`pullbck(ν, f)`](@ref) for details. Note that `pullbck` takes it's +arguments in different order, in keeping with the Julia convention of +passing functions as the first argument. A pullback is mathematically the +precomposition of a measure `μ`` with the function `f` applied to sets. so +`⊙` takes the measure as the first and the function as the second argument, +as common in mathematical notation for precomposition. + +A common mathematical notation for pullback in measure theory is +``f \circ μ``, but as `∘` is used for function composition in Julia and as +`f` semantically acts point-wise on sets, we use `⊙`. + +Also see [f ⋄ μ](@ref), the pushforward operator. +""" +⊙(ν::AbstractMeasure, f) = pullbck(f, ν) +export ⊙ + + +""" + μ ▷ k = mbind(k, μ) + +The `\\triangleright` operator denotes a measure monadic bind operation. + +A common operator choice for a monadics bind operator is `>>=` (e.g. in +the Haskell programming language), but this has a different meaning in +Julia and there is no close equivalent, so we use `▷`. + +See [`mbind(k, μ)`](@ref) for details. Note that `mbind` takes its +arguments in different order, in keeping with the Julia convention of +passing functions as the first argument. `▷`, on the other hand, takes +its arguments in the order common for monadic binds in functional +programming (like the Haskell `>>=` operator) and mathematics. +""" +▷(μ::AbstractMeasure,k) = mbind(k, μ) +export ▷ + + +# ToDo: Use `⨂` instead of `⊗` for better readability? +""" + ⊗(μs::AbstractMeasure...) = productmeasure(μs) + +`⊗` is an operator for building product measures. + +See [`productmeasure(μs)`](@ref) for details. +""" +⊗(μs::AbstractMeasure...) = productmeasure(μs) +export ⊗ + + +""" + ∫(f, μ::AbstractMeasure) = mintegrate(f, μ) + +Denotes an indefinite integral of the function `f` with respect to the +measure `μ`. + +See [`mintegrate(f, μ)`](@ref) for details. +""" +∫(f, μ::AbstractMeasure) = mintegrate(f, μ) +export ∫ + + +""" + ∫exp(f, μ::AbstractMeasure) = mintegrate_exp(f, μ) + +Generates a new measure that is the indefinite integral of `exp` of `f` +with respect to the measure `μ`. + +See [`mintegrate_exp(f, μ)`](@ref) for details. +""" +∫exp(f, μ::AbstractMeasure) = mintegrate_exp(f, μ) +export ∫exp + + +""" + 𝒹(ν, μ) = density_rel(ν, μ) + +Compute the density, i.e. the +[Radom-Nikodym derivative](https://en.wikipedia.org/wiki/Radon%E2%80%93Nikodym_theorem) +of `ν`` with respect to `μ`. + +For details, see [`density_rel(ν, μ)`}(@ref). +""" +𝒹(ν, μ::AbstractMeasure) = density_rel(ν, μ) +export 𝒹 + + + +""" + log𝒹(ν, μ) = logdensity_rel(ν, μ) + +Compute the log-density, i.e. the logarithm of the +[Radom-Nikodym derivative](https://en.wikipedia.org/wiki/Radon%E2%80%93Nikodym_theorem) +of `ν`` with respect to `μ`. + +For details, see [`logdensity_rel(ν, μ)`}(@ref). +""" +log𝒹(ν, μ::AbstractMeasure) = logdensity_rel(ν, μ) +export log𝒹 + + +end # module MeasureOperators diff --git a/test/measure_operators.jl b/test/measure_operators.jl new file mode 100644 index 00000000..a3adaa8f --- /dev/null +++ b/test/measure_operators.jl @@ -0,0 +1,24 @@ +using Test + +using MeasureBase: AbstractMeasure +using MeasureBase: StdExponential, StdLogistic, StdUniform +using MeasureBase: pushfwd, pullbck, mbind, productmeasure +using MeasureBase: mintegrate, mintegrate_exp, density_rel, logdensity_rel +using MeasureBase.MeasureOperators: ⋄, ⊙, ▷, ⊗, ∫, ∫exp, 𝒹, log𝒹 + +@testset "MeasureOperators" begin + μ = StdExponential() + ν = StdUniform() + k(σ) = pushfwd(x -> σ * x, StdNormal()) + μs = (StdExponential(), StdLogistic(), StdUniform()) + f = sqrt + + @test @inferred(f ⋄ μ) == pushfwd(f, μ) + @test @inferred(ν ⊙ f) == pullbck(f, ν) + @test @inferred(μ ▷ k) == mbind(k, μ) + @test @inferred(⊗(μs...)) == productmeasure(μs) + @test @inferred(∫(f, μ)) == mintegrate(f, μ) + @test @inferred(∫exp(f, μ)) == mintegrate_exp(f, μ) + @test @inferred(𝒹(ν, μ)) == density_rel(ν, μ) + @test @inferred(log𝒹(ν, μ)) == logdensity_rel(ν, μ) +end diff --git a/test/runtests.jl b/test/runtests.jl index f9263b6d..364d0091 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -20,4 +20,6 @@ include("smf.jl") include("combinators/weighted.jl") include("combinators/transformedmeasure.jl") +include("measure_operators.jl") + include("test_docs.jl") From 62c1eef97d22b88fdb2c30fff1a03914e686c2c3 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Fri, 23 Jun 2023 10:43:15 +0200 Subject: [PATCH 014/133] Improve docstring for mbind Co-authored-by: Chad Scherrer --- src/combinators/bind.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/combinators/bind.jl b/src/combinators/bind.jl index a27be6e0..4b34e7a2 100644 --- a/src/combinators/bind.jl +++ b/src/combinators/bind.jl @@ -28,6 +28,9 @@ Given measure The *monadic bind* operation `mbind(k, μ)` returns is a new measure. +If `ν == mbind(k, μ)` and all measures involved are sampleable, then +samples from `rand(ν)` follow the same distribution as those from `rand(k(rand(μ)))`. + A monadic bind ofen written as `>>=` (e.g. in Haskell), but this symbol is unavailable in Julia. From 3b5b8e053776be5b02a1ea54c3d1c9bcdeb99c2d Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Fri, 23 Jun 2023 10:57:13 +0200 Subject: [PATCH 015/133] Improve likelihood docs Co-authored-by: Chad Scherrer --- src/combinators/likelihood.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/combinators/likelihood.jl b/src/combinators/likelihood.jl index 93dc1186..c5229b38 100644 --- a/src/combinators/likelihood.jl +++ b/src/combinators/likelihood.jl @@ -93,7 +93,7 @@ Finally, let's return to the expression for Bayes's Law, ``P(θ|x) ∝ P(x|θ) P(θ)`` -In measure theory, the product on the right side is actually the Lebesgue integral, +In measure theory, the product on the right side is the Lebesgue integral of the likelihood with respect to the prior. For example, say we have From 24a102343ff0fb4f0c3447d01bb571f4abce1228 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Fri, 23 Jun 2023 11:34:33 +0200 Subject: [PATCH 016/133] Apply JuliaFormatter --- src/combinators/bind.jl | 1 - src/combinators/likelihood.jl | 28 +--------------------------- src/combinators/product.jl | 1 - src/density.jl | 26 ++++++++++++++++++-------- src/measure_operators.jl | 12 +----------- src/static.jl | 4 +++- test/static.jl | 11 +++++++---- 7 files changed, 30 insertions(+), 53 deletions(-) diff --git a/src/combinators/bind.jl b/src/combinators/bind.jl index 4b34e7a2..465f2bf7 100644 --- a/src/combinators/bind.jl +++ b/src/combinators/bind.jl @@ -17,7 +17,6 @@ function Base.rand(rng::AbstractRNG, ::Type{T}, d::Bind) where {T} return y end - """ mbind(k, μ)::AbstractMeasure diff --git a/src/combinators/likelihood.jl b/src/combinators/likelihood.jl index c5229b38..9b6ac567 100644 --- a/src/combinators/likelihood.jl +++ b/src/combinators/likelihood.jl @@ -64,31 +64,6 @@ With several parameters, things work as expected: --------- - Likelihood(M<:ParameterizedMeasure, constraint::NamedTuple, x) - -In some cases the measure might have several parameters, and we may want the -(log-)likelihood with respect to some subset of them. In this case, we can use -the three-argument form, where the second argument is a constraint. For example, - - julia> ℓ = Likelihood(Normal{(:μ,:σ)}, (σ=3.0,), 2.0) - Likelihood(Normal{(:μ, :σ), T} where T, (σ = 3.0,), 2.0) - -Similarly to the above, we have - - julia> density_def(ℓ, (μ=2.0,)) - 0.3333333333333333 - - julia> logdensity_def(ℓ, (μ=2.0,)) - -1.0986122886681098 - - julia> density_def(ℓ, 2.0) - 0.3333333333333333 - - julia> logdensity_def(ℓ, 2.0) - -1.0986122886681098 - ------------------------ - Finally, let's return to the expression for Bayes's Law, ``P(θ|x) ∝ P(x|θ) P(θ)`` @@ -118,7 +93,7 @@ struct Likelihood{K,X} <: AbstractLikelihood x::X Likelihood(k::K, x::X) where {K,X} = new{K,X}(k, x) -#!!!!!!!!!!! # For type stability if `K isa UnionAll (e.g. a parameterized MeasureType)` + #!!!!!!!!!!! # For type stability if `K isa UnionAll (e.g. a parameterized MeasureType)` Likelihood(::Type{K}, x::X) where {K<:AbstractMeasure,X} = new{K,X}(K, x) end @@ -191,7 +166,6 @@ function likelihoodof end likelihoodof(k, x) = Likelihood(k, x) - ############################################################################### # At the least, we need to think through in some more detail whether # (log-)likelihood ratios expressed in this way are correct and useful. For now diff --git a/src/combinators/product.jl b/src/combinators/product.jl index 3a5c6494..516678f5 100644 --- a/src/combinators/product.jl +++ b/src/combinators/product.jl @@ -167,7 +167,6 @@ function testvalue(::Type{T}, d::AbstractProductMeasure) where {T} _map(m -> testvalue(T, m), marginals(d)) end - ############################################################################### # I <: Base.Generator diff --git a/src/density.jl b/src/density.jl index bba69985..0f4c8e03 100644 --- a/src/density.jl +++ b/src/density.jl @@ -31,7 +31,6 @@ Base.:∘(::typeof(log), d::Density) = logdensity_rel(d.μ, d.base) Base.log(d::Density) = log ∘ d - density_rel(μ, base) = Density(μ, base) (f::Density)(x) = density_rel(f.μ, f.base, x) @@ -112,8 +111,6 @@ logdensity_def(μ::DensityMeasure, x) = logdensityof(μ.f, x) density_def(μ::DensityMeasure, x) = densityof(μ.f, x) - - @doc raw""" mintegrate(f, μ::AbstractMeasure)::AbstractMeasure @@ -135,11 +132,14 @@ mintegrate(f, μ::AbstractMeasure) = _mintegrate_impl(f, μ, DensityKind(f)) _mintegrate_impl(f, μ, ::IsDensity) = DensityMeasure(f, μ) function _mintegrate_impl(f, μ, ::HasDensity) - throw(ArgumentError( "`mintegrate(f, mu)` requires `DensityKind(f)` to be `IsDensity()` or `NoDensity()`.")) + throw( + ArgumentError( + "`mintegrate(f, mu)` requires `DensityKind(f)` to be `IsDensity()` or `NoDensity()`.", + ), + ) end _mintegrate_impl(f, μ, ::NoDensity) = DensityMeasure(funcdensity(f), μ) - @doc raw""" mintegrate_exp(log_f, μ::AbstractMeasure) @@ -162,12 +162,22 @@ internally. function mintegrate_exp end export mintegrate_exp -mintegrate_exp(log_f, μ::AbstractMeasure) = _mintegrate_exp_impl(log_f, μ, DensityKind(log_f)) +function mintegrate_exp(log_f, μ::AbstractMeasure) + _mintegrate_exp_impl(log_f, μ, DensityKind(log_f)) +end function _mintegrate_exp_impl(log_f, μ, ::IsDensity) - throw(ArgumentError("`mintegrate_exp(log_f, μ)` is not valid when `DensityKind(log_f) == IsDensity()`. Use `mintegral(log_f, μ)` instead.")) + throw( + ArgumentError( + "`mintegrate_exp(log_f, μ)` is not valid when `DensityKind(log_f) == IsDensity()`. Use `mintegral(log_f, μ)` instead.", + ), + ) end function _mintegrate_exp_impl(log_f, μ, ::HasDensity) - throw(ArgumentError("`mintegrate_exp(log_f, μ)` is not valid when `DensityKind(log_f) == HasDensity()`.")) + throw( + ArgumentError( + "`mintegrate_exp(log_f, μ)` is not valid when `DensityKind(log_f) == HasDensity()`.", + ), + ) end _mintegrate_exp_impl(log_f, μ, ::NoDensity) = DensityMeasure(logfuncdensity(log_f), μ) diff --git a/src/measure_operators.jl b/src/measure_operators.jl index ee8a7d8d..5822d4de 100644 --- a/src/measure_operators.jl +++ b/src/measure_operators.jl @@ -15,7 +15,6 @@ using MeasureBase: mintegrate, mintegrate_exp, density_rel, logdensity_rel using InverseFunctions: inverse using Reexport: @reexport - @doc raw""" ⋄(f, μ::AbstractMeasure) = pushfwd(f, μ) @@ -33,7 +32,6 @@ Also see [`ν ⊙ f`](@ref), the pullback operator. ⋄(f, μ::AbstractMeasure) = pushfwd(f, μ) export ⋄ - @doc raw""" ⊙(ν::AbstractMeasure, f) = pullbck(f, ν) @@ -55,7 +53,6 @@ Also see [f ⋄ μ](@ref), the pushforward operator. ⊙(ν::AbstractMeasure, f) = pullbck(f, ν) export ⊙ - """ μ ▷ k = mbind(k, μ) @@ -71,10 +68,9 @@ passing functions as the first argument. `▷`, on the other hand, takes its arguments in the order common for monadic binds in functional programming (like the Haskell `>>=` operator) and mathematics. """ -▷(μ::AbstractMeasure,k) = mbind(k, μ) +▷(μ::AbstractMeasure, k) = mbind(k, μ) export ▷ - # ToDo: Use `⨂` instead of `⊗` for better readability? """ ⊗(μs::AbstractMeasure...) = productmeasure(μs) @@ -86,7 +82,6 @@ See [`productmeasure(μs)`](@ref) for details. ⊗(μs::AbstractMeasure...) = productmeasure(μs) export ⊗ - """ ∫(f, μ::AbstractMeasure) = mintegrate(f, μ) @@ -98,7 +93,6 @@ See [`mintegrate(f, μ)`](@ref) for details. ∫(f, μ::AbstractMeasure) = mintegrate(f, μ) export ∫ - """ ∫exp(f, μ::AbstractMeasure) = mintegrate_exp(f, μ) @@ -110,7 +104,6 @@ See [`mintegrate_exp(f, μ)`](@ref) for details. ∫exp(f, μ::AbstractMeasure) = mintegrate_exp(f, μ) export ∫exp - """ 𝒹(ν, μ) = density_rel(ν, μ) @@ -123,8 +116,6 @@ For details, see [`density_rel(ν, μ)`}(@ref). 𝒹(ν, μ::AbstractMeasure) = density_rel(ν, μ) export 𝒹 - - """ log𝒹(ν, μ) = logdensity_rel(ν, μ) @@ -137,5 +128,4 @@ For details, see [`logdensity_rel(ν, μ)`}(@ref). log𝒹(ν, μ::AbstractMeasure) = logdensity_rel(ν, μ) export log𝒹 - end # module MeasureOperators diff --git a/src/static.jl b/src/static.jl index b723d043..da471b62 100644 --- a/src/static.jl +++ b/src/static.jl @@ -49,7 +49,9 @@ Returns the length of `x` as a dynamic or static integer. """ maybestatic_length(x) = length(x) maybestatic_length(x::AbstractUnitRange) = length(x) -function maybestatic_length(::Static.OptionallyStaticUnitRange{<:StaticInteger{A},<:StaticInteger{B}}) where {A,B} +function maybestatic_length( + ::Static.OptionallyStaticUnitRange{<:StaticInteger{A},<:StaticInteger{B}}, +) where {A,B} StaticInt{B - A + 1}() end diff --git a/test/static.jl b/test/static.jl index a6c50db2..f618124b 100644 --- a/test/static.jl +++ b/test/static.jl @@ -11,7 +11,7 @@ import FillArrays @test static(2) isa MeasureBase.IntegerLike @test true isa MeasureBase.IntegerLike @test static(true) isa MeasureBase.IntegerLike - + @test @inferred(MeasureBase.one_to(7)) isa Base.OneTo @test @inferred(MeasureBase.one_to(7)) == 1:7 @test @inferred(MeasureBase.one_to(static(7))) isa Static.SOneTo @@ -19,10 +19,13 @@ import FillArrays @test @inferred(MeasureBase.fill_with(4.2, (7,))) == FillArrays.Fill(4.2, 7) @test @inferred(MeasureBase.fill_with(4.2, (static(7),))) == FillArrays.Fill(4.2, 7) - @test @inferred(MeasureBase.fill_with(4.2, (3, static(7)))) == FillArrays.Fill(4.2, 3, 7) + @test @inferred(MeasureBase.fill_with(4.2, (3, static(7)))) == + FillArrays.Fill(4.2, 3, 7) @test @inferred(MeasureBase.fill_with(4.2, (3:7,))) == FillArrays.Fill(4.2, (3:7,)) - @test @inferred(MeasureBase.fill_with(4.2, (static(3):static(7),))) == FillArrays.Fill(4.2, (3:7,)) - @test @inferred(MeasureBase.fill_with(4.2, (3:7, static(2):static(5)))) == FillArrays.Fill(4.2, (3:7, 2:5)) + @test @inferred(MeasureBase.fill_with(4.2, (static(3):static(7),))) == + FillArrays.Fill(4.2, (3:7,)) + @test @inferred(MeasureBase.fill_with(4.2, (3:7, static(2):static(5)))) == + FillArrays.Fill(4.2, (3:7, 2:5)) @test MeasureBase.maybestatic_length(MeasureBase.one_to(7)) isa Int @test MeasureBase.maybestatic_length(MeasureBase.one_to(7)) == 7 From 3bf6b0e6b5befe1d723e00e1ef12b6d442dfd4d6 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Fri, 23 Jun 2023 11:47:49 +0200 Subject: [PATCH 017/133] Improve Likelihood ctor --- src/combinators/likelihood.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/combinators/likelihood.jl b/src/combinators/likelihood.jl index 9b6ac567..b244fd0f 100644 --- a/src/combinators/likelihood.jl +++ b/src/combinators/likelihood.jl @@ -92,11 +92,12 @@ struct Likelihood{K,X} <: AbstractLikelihood k::K x::X - Likelihood(k::K, x::X) where {K,X} = new{K,X}(k, x) - #!!!!!!!!!!! # For type stability if `K isa UnionAll (e.g. a parameterized MeasureType)` - Likelihood(::Type{K}, x::X) where {K<:AbstractMeasure,X} = new{K,X}(K, x) + Likelihood{K,X}(k, x) where {K,X} = new{K,X}(k, x) end +# For type stability, in case k is a type (resp. a constructor): +Likelihood(k, x::X) where {X} = Likelihood{Core.Typeof(k),X}(k, x) + (lik::AbstractLikelihood)(p) = exp(ULogarithmic, logdensityof(lik.k(p), lik.x)) DensityInterface.DensityKind(::AbstractLikelihood) = IsDensity() From b60b57435e98c790721ccb2933d92106fbc4988f Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Thu, 29 Jun 2023 10:06:57 +0200 Subject: [PATCH 018/133] Fix typo in _mintegrate_exp_impl exception --- src/density.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/density.jl b/src/density.jl index 0f4c8e03..57367ec5 100644 --- a/src/density.jl +++ b/src/density.jl @@ -169,7 +169,7 @@ end function _mintegrate_exp_impl(log_f, μ, ::IsDensity) throw( ArgumentError( - "`mintegrate_exp(log_f, μ)` is not valid when `DensityKind(log_f) == IsDensity()`. Use `mintegral(log_f, μ)` instead.", + "`mintegrate_exp(log_f, μ)` is not valid when `DensityKind(log_f) == IsDensity()`. Use `mintegrate(log_f, μ)` instead.", ), ) end From f72c6dd8303719cf3096ba2cd6c1100fc186b948 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Thu, 29 Jun 2023 10:52:56 +0200 Subject: [PATCH 019/133] Add HierarchicalMeasure --- src/MeasureBase.jl | 1 + src/combinators/hierarchical.jl | 113 ++++++++++++++++++++++++++++++++ src/density-core.jl | 41 +++++++++++- src/getdof.jl | 4 ++ 4 files changed, 156 insertions(+), 3 deletions(-) create mode 100644 src/combinators/hierarchical.jl diff --git a/src/MeasureBase.jl b/src/MeasureBase.jl index eeea2ad8..289ffd25 100644 --- a/src/MeasureBase.jl +++ b/src/MeasureBase.jl @@ -143,6 +143,7 @@ include("standard/stdexponential.jl") include("standard/stdlogistic.jl") include("standard/stdnormal.jl") include("combinators/half.jl") +include("combinators/hierarchical.jl") include("rand.jl") include("fixedrng.jl") diff --git a/src/combinators/hierarchical.jl b/src/combinators/hierarchical.jl new file mode 100644 index 00000000..290ecf48 --- /dev/null +++ b/src/combinators/hierarchical.jl @@ -0,0 +1,113 @@ +export HierarchicalMeasure + + +struct HierarchicalMeasure{F,M<:AbstractMeasure} <: AbstractMeasure + f::F + m::M + dof_m::Int +end + + +function HierarchicalMeasure(f, m::AbstractMeasure, ::NoDOF) + throw(ArgumentError("Primary measure in HierarchicalMeasure must have fixed and known DOF")) +end + +HierarchicalMeasure(f, m::AbstractMeasure) = HierarchicalMeasure(f, m, dynamic(getdof(m))) + + +function _split_variate(h::HierarchicalMeasure, x) + # TODO: Splitting x will be more complicated in general: + x_primary, x_secondary = x + return (x_primary, x_secondary) +end + + +function _combine_variates(x_primary, x_secondary) + # TODO: Must offer optional flattening + return (x_primary, x_secondary) +end + + +function local_measure(h::HierarchicalMeasure, x) + x_primary, x_secondary = _split_variate(h, x) + m_primary = h.m + m_primary_local = local_measure(m_primary, x_primary) + m_secondary = m.f(x_secondary) + m_secondary_local = local_measure(m_secondary, x_secondary) + # TODO: Must optionally return a flattened product measure + return productmeasure(m_primary_local, m_secondary_local) +end + + +@inline function insupport(h::HierarchicalMeasure, x) + # Only test primary for efficiency: + x_primary = _split_variate(h, x)[1] + insupport(h.m, x_primary) +end + + +#!!!!!!! WON'T WORK: Only use primary measure for efficiency: +logdensity_type(h::HierarchicalMeasure{F,M}, ::Type{T}) where {F,M,T} = unstatic(float(logdensity_type(M, T))) + +# Can't implement logdensity_def(::HierarchicalMeasure, x) directly. + +# Can't implement getdof(::HierarchicalMeasure) efficiently + +# No way to return a functional base measure: +struct _BaseOfHierarchicalMeasure{F,M<:AbstractMeasure} <: AbstractMeasure end +@inline basemeasure(::HierarchicalMeasure{F,M}) where {F,M} = _BaseOfHierarchicalMeasure{F,M}() + +@inline getdof(μ::HierarchicalMeasure) = NoDOF{typeof(μ)}() + +# Bypass `checked_arg`, would require potentially costly evaluation of h.f: +@inline checked_arg(::HierarchicalMeasure, x) = x + +function unsafe_logdensityof(h::HierarchicalMeasure, x) + x_primary, x_secondary = _split_variate(h, x) + h_primary, h_secondary = h.m, h.f(x_secondary) + unsafe_logdensityof(h_primary, x_primary) + logdensityof(h_secondary, x_secondary) +end + + +function Base.rand(rng::Random.AbstractRNG, ::Type{T}, h::HierarchicalMeasure) where {T<:Real} + x_primary = rand(rng, T, h.m) + x_secondary = rand(rng, T, h.f(x_primary)) + return _combine_variates(x_primary, x_secondary) +end + + +function _split_measure_at(μ::PowerMeasure{M, Tuple{R}}, n::Integer) where {M<:StdMeasure,R} + dof_μ = getdof(μ) + return M()^n, M()^(dof_μ - n) +end + + +function transport_def( + ν::PowerMeasure{M, Tuple{R}}, + μ::HierarchicalMeasure, + x, +) where {M<:StdMeasure,R} + ν_primary, ν_secondary = _split_measure_at(ν, μ.dof_m) + x_primary, x_secondary = _split_variate(μ, x) + μ_primary = μ.m + μ_secondary = μ.f(x_secondary) + y_primary = transport_to(ν_primary, μ_primary, x_primary) + y_secondary = transport_to(ν_secondary, μ_secondary, x_secondary) + return vcat(y_primary, y_secondary) +end + + +function transport_def( + ν::HierarchicalMeasure, + μ::PowerMeasure{M, Tuple{R}}, + x, +) where {M<:StdMeasure,R} + dof_primary = ν.dof_m + μ_primary, μ_secondary = _split_measure_at(μ, dof_primary) + x_primary, x_secondary = x[begin:begin+dof_primary-1], x[begin+dof_primary:end] + ν_primary = ν.m + y_primary = transport_to(ν_primary, μ_primary, x_primary) + ν_secondary = ν.f(y_primary) + y_secondary = transport_to(ν_secondary, μ_secondary, x_secondary) + return _combine_variates(y_primary, y_secondary) +end diff --git a/src/density-core.jl b/src/density-core.jl index c8c861ee..655c6d6a 100644 --- a/src/density-core.jl +++ b/src/density-core.jl @@ -1,3 +1,5 @@ +export local_measure + export logdensityof export logdensity_rel export logdensity_def @@ -9,6 +11,25 @@ export densityof export density_rel export density_def + +""" + local_measure(m::AbstractMeasure, x)::AbstractMeasure + +Return a local measure of `m` at `x` which will be `m` itself for many +measures. + +A local measure of `m` is defined here as a measure that behaves like `m` in +the infinitesimal neighborhood of `x`. + +Note that the resulting measure may not be well defined outside of such a +neighborhood of `x`. + +See [`HierarchicalMeasure`](@ref) as an example of a measure where +`local_measure` returns different measures depending on `x`. +""" +local_measure(m::AbstractMeasure, x) = m + + """ logdensityof(m::AbstractMeasure, x) @@ -72,6 +93,17 @@ See also `logdensityof`. return ℓ_10 end + +""" + logdensity_type(m::AbstractMeasure}, ::Type{T}) where T + +Compute the return type of `logdensity_of(m, ::T)`. +""" +function logdensity_type(m::M,T) where {M<:AbstractMeasure} + Core.Compiler.return_type(logdensity_def, Tuple{M, T}) +end + + """ logdensity_rel(m1, m2, x) @@ -83,8 +115,8 @@ known to be in the support of both, it can be more efficient to call @inline function logdensity_rel(μ::M, ν::N, x::X) where {M,N,X} T = unstatic( promote_type( - return_type(logdensity_def, (μ, x)), - return_type(logdensity_def, (ν, x)), + logdensity_type(μ, X), + logdensity_type(ν, X), ), ) inμ = insupport(μ, x) @@ -92,7 +124,10 @@ known to be in the support of both, it can be more efficient to call istrue(inμ) || return convert(T, ifelse(inν, -Inf, NaN)) istrue(inν) || return convert(T, Inf) - return unsafe_logdensity_rel(μ, ν, x) + μ_local = localmeasure(μ, x) + ν_local = localmeasure(ν, x) + + return unsafe_logdensity_rel(μ_local, ν_local, x) end """ diff --git a/src/getdof.jl b/src/getdof.jl index 4496b7f2..2c0bb60c 100644 --- a/src/getdof.jl +++ b/src/getdof.jl @@ -38,6 +38,10 @@ function check_dof end function check_dof(ν, μ) n_ν = getdof(ν) n_μ = getdof(μ) + # TODO: How to handle this better if DOF is unclear e.g. for HierarchicalMeasures? + if n_ν isa NoDOF || n_μ isa NoDOF + return true + end if n_ν != n_μ throw( ArgumentError( From 7157642e27b8fc18338a5c7583a731d375ebf960 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Thu, 29 Jun 2023 10:52:56 +0200 Subject: [PATCH 020/133] Rename local_measure, change insupprt handling --- src/combinators/hierarchical.jl | 6 ++--- src/density-core.jl | 43 ++++++++++++++++++++++++++------- src/insupport.jl | 22 ++++++++++++++--- 3 files changed, 55 insertions(+), 16 deletions(-) diff --git a/src/combinators/hierarchical.jl b/src/combinators/hierarchical.jl index 290ecf48..32f08940 100644 --- a/src/combinators/hierarchical.jl +++ b/src/combinators/hierarchical.jl @@ -28,12 +28,12 @@ function _combine_variates(x_primary, x_secondary) end -function local_measure(h::HierarchicalMeasure, x) +function localmeasure(h::HierarchicalMeasure, x) x_primary, x_secondary = _split_variate(h, x) m_primary = h.m - m_primary_local = local_measure(m_primary, x_primary) + m_primary_local = localmeasure(m_primary, x_primary) m_secondary = m.f(x_secondary) - m_secondary_local = local_measure(m_secondary, x_secondary) + m_secondary_local = localmeasure(m_secondary, x_secondary) # TODO: Must optionally return a flattened product measure return productmeasure(m_primary_local, m_secondary_local) end diff --git a/src/density-core.jl b/src/density-core.jl index 655c6d6a..ef045b92 100644 --- a/src/density-core.jl +++ b/src/density-core.jl @@ -1,4 +1,4 @@ -export local_measure +export localmeasure export logdensityof export logdensity_rel @@ -13,7 +13,7 @@ export density_def """ - local_measure(m::AbstractMeasure, x)::AbstractMeasure + localmeasure(m::AbstractMeasure, x)::AbstractMeasure Return a local measure of `m` at `x` which will be `m` itself for many measures. @@ -25,9 +25,9 @@ Note that the resulting measure may not be well defined outside of such a neighborhood of `x`. See [`HierarchicalMeasure`](@ref) as an example of a measure where -`local_measure` returns different measures depending on `x`. +`localmeasure` returns different measures depending on `x`. """ -local_measure(m::AbstractMeasure, x) = m +localmeasure(m::AbstractMeasure, x) = m """ @@ -113,23 +113,42 @@ known to be in the support of both, it can be more efficient to call `unsafe_logdensity_rel`. """ @inline function logdensity_rel(μ::M, ν::N, x::X) where {M,N,X} + inμ = insupport(μ, x) + inν = insupport(ν, x) + return unsafe_logdensity_rel(μ, ν, x, inμ, inν) +end + + +function _logdensity_rel_impl(μ::M, ν::N, x::X, inμ::Bool, inν::Bool) where {M,N,X} T = unstatic( promote_type( logdensity_type(μ, X), logdensity_type(ν, X), ), ) - inμ = insupport(μ, x) - inν = insupport(ν, x) + istrue(inμ) || return convert(T, ifelse(inν, -Inf, NaN)) istrue(inν) || return convert(T, Inf) - μ_local = localmeasure(μ, x) - ν_local = localmeasure(ν, x) + return unsafe_logdensity_rel(μ, ν, x) +end - return unsafe_logdensity_rel(μ_local, ν_local, x) + +function _logdensity_rel_impl(μ::M, ν::N, x::X, @nospecialize(::NoFastInsupport), @nospecialize(::NoFastInsupport)) where {M,N,X} + unsafe_logdensity_rel(μ, ν, x) +end + +function _logdensity_rel_impl(μ::M, ν::N, x::X, inμ::Bool, @nospecialize(::NoFastInsupport)) where {M,N,X} + logd = unsafe_logdensity_rel(μ, ν, x) + return istrue(inμ) ? logd : logd * oftypeof(logd, -Inf) +end + +function _logdensity_rel_impl(μ::M, ν::N, x::X, @nospecialize(::NoFastInsupport), inν::Bool) where {M,N,X} + logd = unsafe_logdensity_rel(μ, ν, x) + return istrue(inν) ? logd : logd * oftypeof(logd, +Inf) end + """ unsafe_logdensity_rel(m1, m2, x) @@ -139,6 +158,12 @@ known to be in the support of both `m1` and `m2`. See also `logdensity_rel`. """ @inline function unsafe_logdensity_rel(μ::M, ν::N, x::X) where {M,N,X} + μ_local = localmeasure(μ, x) + ν_local = localmeasure(ν, x) + return _unsafe_logdensity_rel_local(μ_local, ν_local, x) +end + +@inline function _unsafe_logdensity_rel_local(μ::M, ν::N, x::X) where {M,N,X} if static_hasmethod(logdensity_def, Tuple{M,N,X}) return logdensity_def(μ, ν, x) end diff --git a/src/insupport.jl b/src/insupport.jl index 5184917d..e4f44ca8 100644 --- a/src/insupport.jl +++ b/src/insupport.jl @@ -1,20 +1,33 @@ +""" + MeasureBase.NoFastInsupport{MU} + +Indicates that there is no fast way to compute if a point lies within the +support of measures of type `MU` +""" +struct NoFastInsupport{MU} end + + """ inssupport(m, x) insupport(m) -`insupport(m,x)` computes whether `x` is in the support of `m`. +`insupport(m,x)` computes whether `x` is in the support of `m` and +returns either a `Bool` or an instance of [`NoFastInsupport`](@ref). `insupport(m)` returns a function, and satisfies - -insupport(m)(x) == insupport(m, x) +`insupport(m)(x) == insupport(m, x)`` """ function insupport end + """ MeasureBase.require_insupport(μ, x)::Nothing Checks if `x` is in the support of distribution/measure `μ`, throws an `ArgumentError` if not. + +Will not throw an exception if `insupport` returns an instance of +[`NoFastInsupport`](@ref). """ function require_insupport end @@ -24,7 +37,8 @@ function ChainRulesCore.rrule(::typeof(require_insupport), μ, x) end function require_insupport(μ, x) - if !insupport(μ, x) + r = insupport(μ, x) + if !(r isa NoFastInsupport) || r throw(ArgumentError("x is not within the support of μ")) end return nothing From 4acbed96f6e4967da8f422316adc3b4e91755411 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Thu, 29 Jun 2023 10:52:56 +0200 Subject: [PATCH 021/133] Don't require primary in HierarchicalMeasure to have known DOF --- src/combinators/hierarchical.jl | 213 ++++++++++++++++++++++---------- src/density-core.jl | 16 ++- src/standard/stdmeasure.jl | 6 + 3 files changed, 167 insertions(+), 68 deletions(-) diff --git a/src/combinators/hierarchical.jl b/src/combinators/hierarchical.jl index 32f08940..0a2b1909 100644 --- a/src/combinators/hierarchical.jl +++ b/src/combinators/hierarchical.jl @@ -1,113 +1,198 @@ export HierarchicalMeasure -struct HierarchicalMeasure{F,M<:AbstractMeasure} <: AbstractMeasure +# TODO: Document and use FlattenMode +abstract type FlattenMode end +struct NoFlatten <: FlattenMode end +struct AutoFlatten <: FlattenMode end + + +struct HierarchicalMeasure{F,M<:AbstractMeasure,FM<:FlattenMode} <: AbstractMeasure f::F m::M - dof_m::Int + flatten_mode::FM end +# TODO: Document +const HierarchicalProductMeasure{F,M<:AbstractMeasure} = HierarchicalMeasure{F,M,NoFlatten} +export HierarchicalProductMeasure -function HierarchicalMeasure(f, m::AbstractMeasure, ::NoDOF) - throw(ArgumentError("Primary measure in HierarchicalMeasure must have fixed and known DOF")) -end +HierarchicalProductMeasure(f, m::AbstractMeasure) = HierarchicalMeasure(f, m, NoFlatten()) + +# TODO: Document +const FlatHierarchicalMeasure{F,M<:AbstractMeasure} = HierarchicalMeasure{F,M,AutoFlatten} +export FlatHierarchicalMeasure + +FlatHierarchicalMeasure(f, m::AbstractMeasure) = HierarchicalMeasure(f, m, AutoFlatten()) + +HierarchicalMeasure(f, m::AbstractMeasure) = FlatHierarchicalMeasure(f, m) -HierarchicalMeasure(f, m::AbstractMeasure) = HierarchicalMeasure(f, m, dynamic(getdof(m))) -function _split_variate(h::HierarchicalMeasure, x) - # TODO: Splitting x will be more complicated in general: - x_primary, x_secondary = x - return (x_primary, x_secondary) +function _split_variate_after(::NoFlatten, μ::AbstractMeasure, x::Tuple{2}) + @assert x isa Tuple{2} + return x[1], x[2] end -function _combine_variates(x_primary, x_secondary) - # TODO: Must offer optional flattening - return (x_primary, x_secondary) +function _split_variate_after(::AutoFlatten, μ::AbstractMeasure, x) + a_test = testvalue(μ) + return _autosplit_variate_after_testvalue(a_test, x) end +function _autosplit_variate_after_testvalue(::Any, x) + @assert x isa Tuple{2} + return x[1], x[2] +end -function localmeasure(h::HierarchicalMeasure, x) - x_primary, x_secondary = _split_variate(h, x) - m_primary = h.m - m_primary_local = localmeasure(m_primary, x_primary) - m_secondary = m.f(x_secondary) - m_secondary_local = localmeasure(m_secondary, x_secondary) - # TODO: Must optionally return a flattened product measure - return productmeasure(m_primary_local, m_secondary_local) +function _autosplit_variate_after_testvalue(a_test::AbstractVector, x::AbstractVector) + n, m = length(eachindex(a_test)), length(eachindex(x)) + # TODO: Use getindex or view? + return x[begin:n], x[begin+n:m] end +function _autosplit_variate_after_testvalue(::Tuple{N}, x::Tuple{M}) where {N,M} + return ntuple(i -> x[i], Val(1:N)), ntuple(i -> x[i], Val(N+1:M)) +end -@inline function insupport(h::HierarchicalMeasure, x) - # Only test primary for efficiency: - x_primary = _split_variate(h, x)[1] - insupport(h.m, x_primary) +@generated function _autosplit_variate_after_testvalue(::NamedTuple{names_a}, x::NamedTuple{names}) where {names_a,names} + # TODO: implement + @assert false end -#!!!!!!! WON'T WORK: Only use primary measure for efficiency: -logdensity_type(h::HierarchicalMeasure{F,M}, ::Type{T}) where {F,M,T} = unstatic(float(logdensity_type(M, T))) -# Can't implement logdensity_def(::HierarchicalMeasure, x) directly. +_combine_variates(::NoFlatten, a::Any, b::Any) = (a, b) + + +_combine_variates(::AutoFlatten, a::Any, b::Any) = _autoflat_combine_variates(a, b) + +_autoflat_combine_variates(a::Any, b::Any) = (a, b) + +_autoflat_combine_variates(a::AbstractVector, b::AbstractVector) = vcat(a, b) + +_autoflat_combine_variates(a::Tuple, b::Tuple) = (a, b) + +# TODO: Check that names don't overlap: +_autoflat_combine_variates(a::NamedTuple, b::NamedTuple) = merge(a, b) + + +_local_productmeasure(::NoFlatten, μ1, μ2) = productmeasure(μ1, μ2) + +# TODO: _local_productmeasure(::AutoFlatten, μ1, μ2) = productmeasure(μ1, μ2) +# Needs a FlatProductMeasure type. + +function _localmeasure_with_rest(μ::HierarchicalProductMeasure, x) + μ_primary = μ.m + local_primary, x_secondary = _localmeasure_with_rest(μ_primary, x) + μ_secondary = μ.f(x_secondary) + local_secondary, x_rest = _localmeasure_with_rest(μ_secondary, x_secondary) + return _local_productmeasure(μ.flatten_mode, local_primary, local_secondary), x_rest +end + +function _localmeasure_with_rest(μ::AbstractMeasure, x) + x_checked = checked_arg(μ, x) + return localmeasure(μ, x_checked), Fill(zero(eltype(x)), 0) +end + +function localmeasure(μ::HierarchicalProductMeasure, x) + h_local, x_rest = _localmeasure_with_rest(μ, x) + if !isempty(x_rest) + throw(ArgumentError("Variate too long while computing localmeasure of HierarchicalMeasure")) + end + return h_local +end -# Can't implement getdof(::HierarchicalMeasure) efficiently -# No way to return a functional base measure: -struct _BaseOfHierarchicalMeasure{F,M<:AbstractMeasure} <: AbstractMeasure end -@inline basemeasure(::HierarchicalMeasure{F,M}) where {F,M} = _BaseOfHierarchicalMeasure{F,M}() +@inline insupport(::HierarchicalMeasure, x) = NoFastInsupport() @inline getdof(μ::HierarchicalMeasure) = NoDOF{typeof(μ)}() # Bypass `checked_arg`, would require potentially costly evaluation of h.f: @inline checked_arg(::HierarchicalMeasure, x) = x -function unsafe_logdensityof(h::HierarchicalMeasure, x) - x_primary, x_secondary = _split_variate(h, x) - h_primary, h_secondary = h.m, h.f(x_secondary) - unsafe_logdensityof(h_primary, x_primary) + logdensityof(h_secondary, x_secondary) -end +rootmeasure(::HierarchicalMeasure) = throw(ArgumentError("root measure is implicit, but can't be instantiated, for HierarchicalMeasure")) + +basemeasure(::HierarchicalMeasure) = throw(ArgumentError("basemeasure is not available for HierarchicalMeasure")) + +logdensity_def(::HierarchicalMeasure, x) = throw(ArgumentError("logdensity_def is not available for HierarchicalMeasure")) + + +# # TODO: Default implementation of unsafe_logdensityof is a bit inefficient +# # for AutoFlatten, since variate will be split in `localmeasure` and then +# # split again in log-density evaluation. Maybe add something like +# function unsafe_logdensityof(h::HierarchicalMeasure, x) +# local_primary, local_secondary, x_primary, x_secondary = ... +# # Need to call full logdensityof for h_secondary since x_secondary hasn't +# # been checked yet: +# unsafe_logdensityof(local_primary, x_primary) + logdensityof(local_secondary, x_secondary) +# end function Base.rand(rng::Random.AbstractRNG, ::Type{T}, h::HierarchicalMeasure) where {T<:Real} x_primary = rand(rng, T, h.m) x_secondary = rand(rng, T, h.f(x_primary)) - return _combine_variates(x_primary, x_secondary) + return _combine_variates(h.flatten_mode, x_primary, x_secondary) end -function _split_measure_at(μ::PowerMeasure{M, Tuple{R}}, n::Integer) where {M<:StdMeasure,R} - dof_μ = getdof(μ) - return M()^n, M()^(dof_μ - n) -end - -function transport_def( - ν::PowerMeasure{M, Tuple{R}}, - μ::HierarchicalMeasure, - x, -) where {M<:StdMeasure,R} - ν_primary, ν_secondary = _split_measure_at(ν, μ.dof_m) - x_primary, x_secondary = _split_variate(μ, x) +function _to_std_with_rest(flatten_mode::FlattenMode, ν_inner::StdMeasure, μ::HierarchicalMeasure, x) μ_primary = μ.m + y_primary, x_secondary = _to_std_with_rest(flatten_mode, ν_inner, μ_primary, x) μ_secondary = μ.f(x_secondary) - y_primary = transport_to(ν_primary, μ_primary, x_primary) - y_secondary = transport_to(ν_secondary, μ_secondary, x_secondary) - return vcat(y_primary, y_secondary) + y_secondary, x_rest = _to_std_with_rest(flatten_mode, ν_inner, μ_secondary, x_secondary) + return _combine_variates(μ.flatten_mode, y_primary, y_secondary), x_rest +end + +function _to_std_with_rest(flatten_mode::FlattenMode, ν_inner::StdMeasure, μ::AbstractMeasure, x) + dof_μ = getdof(μ) + x_μ, x_rest = _split_variate_after(flatten_mode, μ, x) + y = transport_to(ν_inner^dof_μ, μ, x_μ) + return y, x_rest +end + +function transport_def(ν::_PowerStdMeasure{1}, μ::HierarchicalMeasure, x) + ν_inner = _get_inner_stdmeasure(ν) + y, x_rest = _to_std_with_rest(ν_inner, μ, x) + if !isempty(x_rest) + throw(ArgumentError("Variate too long during transport involving HierarchicalMeasure")) + end + return y end -function transport_def( - ν::HierarchicalMeasure, - μ::PowerMeasure{M, Tuple{R}}, - x, -) where {M<:StdMeasure,R} - dof_primary = ν.dof_m - μ_primary, μ_secondary = _split_measure_at(μ, dof_primary) - x_primary, x_secondary = x[begin:begin+dof_primary-1], x[begin+dof_primary:end] +function _from_std_with_rest(ν::HierarchicalMeasure, μ_inner::StdMeasure, x) ν_primary = ν.m - y_primary = transport_to(ν_primary, μ_primary, x_primary) + y_primary, x_secondary = _from_std_with_rest(ν_primary, μ_inner, x) ν_secondary = ν.f(y_primary) - y_secondary = transport_to(ν_secondary, μ_secondary, x_secondary) - return _combine_variates(y_primary, y_secondary) + y_secondary, x_rest = _from_std_with_rest(ν_secondary, μ_inner, x_secondary) + return _combine_variates(ν.flatten_mode, y_primary, y_secondary), x_rest +end + +function _from_std_with_rest(ν::AbstractMeasure, μ_inner::StdMeasure, x) + dof_ν = getdof(ν) + len_x = length(eachindex(x)) + + # Since we can't check DOF of original HierarchicalMeasure, we could "run out x" if + # the original x was too short. `transport_to` below will detect this, but better + # throw a more informative exception here: + if len_x < dof_ν + throw(ArgumentError("Variate too short during transport involving HierarchicalMeasure")) + end + + y = transport_to(ν, μ_inner^dof_ν, x[begin:begin+dof_ν-1]) + x_rest = Fill(zero(eltype(x)), dof_ν - len_x) + return y, x_rest +end + +function transport_def(ν::HierarchicalMeasure, μ::_PowerStdMeasure{1}, x) + # Sanity check, should be checked by transport machinery already: + @assert getdof(μ) == length(eachindex(x)) && x isa AbstractVector + μ_inner = _get_inner_stdmeasure(μ) + y, x_rest = _from_std_with_rest(ν, μ_inner, x) + if !isempty(x_rest) + throw(ArgumentError("Variate too long during transport involving HierarchicalMeasure")) + end + return y end diff --git a/src/density-core.jl b/src/density-core.jl index ef045b92..626c6ced 100644 --- a/src/density-core.jl +++ b/src/density-core.jl @@ -55,6 +55,7 @@ To compute a log-density relative to a specific base-measure, see end _checksupport(cond, result) = ifelse(cond == true, result, oftype(result, -Inf)) +@inline _checksupport(::NoFastInsupport, result) = result import ChainRulesCore @inline function ChainRulesCore.rrule(::typeof(_checksupport), cond, result) @@ -77,6 +78,12 @@ This is "unsafe" because it does not check `insupport(m, x)`. See also `logdensityof`. """ @inline function unsafe_logdensityof(μ::M, x) where {M} + μ_local = localmeasure(μ, x) + # Extra dispatch boundary to reduce number of required specializations of implementation: + return _unsafe_logdensityof_local(μ_local, x) +end + +@inline function _unsafe_logdensityof_local(μ::M, x) where {M} ℓ_0 = logdensity_def(μ, x) b_0 = μ Base.Cartesian.@nexprs 10 i -> begin # 10 is just some "big enough" number @@ -119,7 +126,7 @@ known to be in the support of both, it can be more efficient to call end -function _logdensity_rel_impl(μ::M, ν::N, x::X, inμ::Bool, inν::Bool) where {M,N,X} +@inline function _logdensity_rel_impl(μ::M, ν::N, x::X, inμ::Bool, inν::Bool) where {M,N,X} T = unstatic( promote_type( logdensity_type(μ, X), @@ -134,16 +141,16 @@ function _logdensity_rel_impl(μ::M, ν::N, x::X, inμ::Bool, inν::Bool) where end -function _logdensity_rel_impl(μ::M, ν::N, x::X, @nospecialize(::NoFastInsupport), @nospecialize(::NoFastInsupport)) where {M,N,X} +@inline function _logdensity_rel_impl(μ::M, ν::N, x::X, @nospecialize(::NoFastInsupport), @nospecialize(::NoFastInsupport)) where {M,N,X} unsafe_logdensity_rel(μ, ν, x) end -function _logdensity_rel_impl(μ::M, ν::N, x::X, inμ::Bool, @nospecialize(::NoFastInsupport)) where {M,N,X} +@inline function _logdensity_rel_impl(μ::M, ν::N, x::X, inμ::Bool, @nospecialize(::NoFastInsupport)) where {M,N,X} logd = unsafe_logdensity_rel(μ, ν, x) return istrue(inμ) ? logd : logd * oftypeof(logd, -Inf) end -function _logdensity_rel_impl(μ::M, ν::N, x::X, @nospecialize(::NoFastInsupport), inν::Bool) where {M,N,X} +@inline function _logdensity_rel_impl(μ::M, ν::N, x::X, @nospecialize(::NoFastInsupport), inν::Bool) where {M,N,X} logd = unsafe_logdensity_rel(μ, ν, x) return istrue(inν) ? logd : logd * oftypeof(logd, +Inf) end @@ -160,6 +167,7 @@ See also `logdensity_rel`. @inline function unsafe_logdensity_rel(μ::M, ν::N, x::X) where {M,N,X} μ_local = localmeasure(μ, x) ν_local = localmeasure(ν, x) + # Extra dispatch boundary to reduce number of required specializations of implementation: return _unsafe_logdensity_rel_local(μ_local, ν_local, x) end diff --git a/src/standard/stdmeasure.jl b/src/standard/stdmeasure.jl index 833f280e..0df8977f 100644 --- a/src/standard/stdmeasure.jl +++ b/src/standard/stdmeasure.jl @@ -1,5 +1,11 @@ abstract type StdMeasure <: AbstractMeasure end + +const _PowerStdMeasure{N,M<:StdMeasure} = PowerMeasure{M,<:NTuple{N,Base.OneTo}} + +_get_inner_stdmeasure(μ::_PowerStdMeasure{N,M}) where {N,M} = M() + + StdMeasure(::typeof(rand)) = StdUniform() StdMeasure(::typeof(randexp)) = StdExponential() StdMeasure(::typeof(randn)) = StdNormal() From d1f31bae19c77e0d10e9150a42d492fdd6f66e1e Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Thu, 29 Jun 2023 11:57:15 +0200 Subject: [PATCH 022/133] STASH --- src/combinators/hierarchical.jl | 64 ++++++++++++++++++--------------- 1 file changed, 35 insertions(+), 29 deletions(-) diff --git a/src/combinators/hierarchical.jl b/src/combinators/hierarchical.jl index 0a2b1909..3256308b 100644 --- a/src/combinators/hierarchical.jl +++ b/src/combinators/hierarchical.jl @@ -1,61 +1,67 @@ export HierarchicalMeasure +""" + struct HierarchicalMeasure{F,M<:AbstractMeasure,G} <: AbstractMeasure -# TODO: Document and use FlattenMode -abstract type FlattenMode end -struct NoFlatten <: FlattenMode end -struct AutoFlatten <: FlattenMode end +Represents a hierarchical measure. - -struct HierarchicalMeasure{F,M<:AbstractMeasure,FM<:FlattenMode} <: AbstractMeasure +User code should not instantiate `HierarchicalMeasure` directly, use +[`hierarchical_measure`](@ref) instead. +""" +struct HierarchicalMeasure{F,M<:AbstractMeasure,G} <: AbstractMeasure f::F m::M - flatten_mode::FM + flatten::G end -# TODO: Document -const HierarchicalProductMeasure{F,M<:AbstractMeasure} = HierarchicalMeasure{F,M,NoFlatten} -export HierarchicalProductMeasure - -HierarchicalProductMeasure(f, m::AbstractMeasure) = HierarchicalMeasure(f, m, NoFlatten()) +""" + hierarchical_measure(f, m::AbstractMeasure, flatten) -# TODO: Document -const FlatHierarchicalMeasure{F,M<:AbstractMeasure} = HierarchicalMeasure{F,M,AutoFlatten} -export FlatHierarchicalMeasure +Construct a hierarchical measure from a function `f`, measure `m` and +""" +@inline function hierarchical_measure(f, m::AbstractMeasure, flatten) + F, M, G = Core.Typeof(f), Core.Typeof(m), Core.Typeof(flatten) + HierarchicalProductMeasure{F,M,G}(f, m, flatten) +end -FlatHierarchicalMeasure(f, m::AbstractMeasure) = HierarchicalMeasure(f, m, AutoFlatten()) -HierarchicalMeasure(f, m::AbstractMeasure) = FlatHierarchicalMeasure(f, m) +#!!!!!! +const HierarchicalProductMeasure{F,M<:AbstractMeasure} = HierarchicalMeasure{F,M,::typeof(=>)} +const FlatHierarchicalMeasure{F,M<:AbstractMeasure} = HierarchicalMeasure{F,M,::typeof(vcat)} -function _split_variate_after(::NoFlatten, μ::AbstractMeasure, x::Tuple{2}) - @assert x isa Tuple{2} - return x[1], x[2] +function _split_variate(::typeof(=>), ::AbstractMeasure, x::Pair) + return x.first, x.second end +function _split_variate(flatten::F, μ_primary::AbstractMeasure, x) where F + test_primary = testvalue(μ_primary) + return _split_variate_byvalue(flatten, test_primary, x) +end -function _split_variate_after(::AutoFlatten, μ::AbstractMeasure, x) - a_test = testvalue(μ) - return _autosplit_variate_after_testvalue(a_test, x) +function _split_variate(::Type{F}, μ::AbstractMeasure, x) where F + test_primary = testvalue(μ) + return _split_variate_byvalue(F, test_primary, x) end -function _autosplit_variate_after_testvalue(::Any, x) + +function _split_variate_byvalue(::Any, x) @assert x isa Tuple{2} return x[1], x[2] end -function _autosplit_variate_after_testvalue(a_test::AbstractVector, x::AbstractVector) - n, m = length(eachindex(a_test)), length(eachindex(x)) +function _split_variate_byvalue(test_primary::AbstractVector, x::AbstractVector) + n, m = length(eachindex(test_primary)), length(eachindex(x)) # TODO: Use getindex or view? return x[begin:n], x[begin+n:m] end -function _autosplit_variate_after_testvalue(::Tuple{N}, x::Tuple{M}) where {N,M} +function _split_variate_byvalue(::Tuple{N}, x::Tuple{M}) where {N,M} return ntuple(i -> x[i], Val(1:N)), ntuple(i -> x[i], Val(N+1:M)) end -@generated function _autosplit_variate_after_testvalue(::NamedTuple{names_a}, x::NamedTuple{names}) where {names_a,names} +@generated function _split_variate_byvalue(::NamedTuple{names_a}, x::NamedTuple{names}) where {names_a,names} # TODO: implement @assert false end @@ -147,7 +153,7 @@ end function _to_std_with_rest(flatten_mode::FlattenMode, ν_inner::StdMeasure, μ::AbstractMeasure, x) dof_μ = getdof(μ) - x_μ, x_rest = _split_variate_after(flatten_mode, μ, x) + x_μ, x_rest = _split_variate(flatten_mode, μ, x) y = transport_to(ν_inner^dof_μ, μ, x_μ) return y, x_rest end From 13cf10b2057833b0c1339ef1eb7a56528bf26cfa Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Thu, 29 Jun 2023 15:15:00 +0200 Subject: [PATCH 023/133] STASH --- src/combinators/hierarchical.jl | 78 ++++++++++++++++++++++----------- 1 file changed, 52 insertions(+), 26 deletions(-) diff --git a/src/combinators/hierarchical.jl b/src/combinators/hierarchical.jl index 3256308b..433e1c5d 100644 --- a/src/combinators/hierarchical.jl +++ b/src/combinators/hierarchical.jl @@ -9,25 +9,58 @@ User code should not instantiate `HierarchicalMeasure` directly, use [`hierarchical_measure`](@ref) instead. """ struct HierarchicalMeasure{F,M<:AbstractMeasure,G} <: AbstractMeasure - f::F - m::M - flatten::G + f_kernel::F + μ_primary::M + f_combine::G end -""" - hierarchical_measure(f, m::AbstractMeasure, flatten) -Construct a hierarchical measure from a function `f`, measure `m` and -""" -@inline function hierarchical_measure(f, m::AbstractMeasure, flatten) - F, M, G = Core.Typeof(f), Core.Typeof(m), Core.Typeof(flatten) - HierarchicalProductMeasure{F,M,G}(f, m, flatten) -end +@doc raw""" + hierarchical_measure(k, μ_primary::AbstractMeasure, f_combine) + +Construct a hierarchical measure from a transition kernel function +`f_kernel`, a measure `μ_primary` and a function `f_combine`. + +`f_kernel` must be a function that maps a variate `x_primary` of +`μ_primary` to a dependent secondary measure `μ_secondary = f(x_primary)). +`y = f_combine(x_primary, x_secondary)` must map variates from the primary +and the dependent secondary measure to a combined value. + +A measure + +```julia +`μ = hierarchical_measure(f_c, α, f_β)` +``` + +has the mathethematical interpretation (using the notation `β_a = f_β(a)`): + +```math +\mu(f_c(A, B)) = \int_A \beta_a(B)\, \mathrm{d}\, \alpha(a) +``` +Comutationally, `x = rand(μ)` is equivalent to -#!!!!!! -const HierarchicalProductMeasure{F,M<:AbstractMeasure} = HierarchicalMeasure{F,M,::typeof(=>)} -const FlatHierarchicalMeasure{F,M<:AbstractMeasure} = HierarchicalMeasure{F,M,::typeof(vcat)} +```julia +x_primary = rand(μ_primary) +μ_secondary = f_kernel(x_primary) +x_secondary = rand(μ_secondary) +x = f_combine(x_primary, x_secondary) +``` + +and `mbind(f_β, α)` is equivalent to: + +```julia +pushfwd((a, b) -> b, hierarchical_measure(=>, α, tuple)) +``` + +Possible choices for `f_combine` are `=>`/`Pair` or `tuple` (these work for +any combination of variate types), `vcat` (for tuple- or vector-like +variates) and `merge` (e.g. for `NamedTuple` variates). +""" +@inline function hierarchical_measure(f, μ::AbstractMeasure, f_combine) + F, M, G = Core.Typeof(f), Core.Typeof(m), Core.Typeof(f_combine) + HierarchicalProductMeasure{F,M,G}(f, μ, f_combine) +end @@ -35,26 +68,19 @@ function _split_variate(::typeof(=>), ::AbstractMeasure, x::Pair) return x.first, x.second end -function _split_variate(flatten::F, μ_primary::AbstractMeasure, x) where F - test_primary = testvalue(μ_primary) - return _split_variate_byvalue(flatten, test_primary, x) +function _split_variate(f_combine::F, μ_primary::AbstractMeasure, x) where F + _split_variate_byvalue(f_combine, testvalue(μ), x) end +# Necessary/helpful for type stability? function _split_variate(::Type{F}, μ::AbstractMeasure, x) where F - test_primary = testvalue(μ) - return _split_variate_byvalue(F, test_primary, x) -end - - -function _split_variate_byvalue(::Any, x) - @assert x isa Tuple{2} - return x[1], x[2] + _split_variate_byvalue(F, testvalue(μ), x) end function _split_variate_byvalue(test_primary::AbstractVector, x::AbstractVector) n, m = length(eachindex(test_primary)), length(eachindex(x)) # TODO: Use getindex or view? - return x[begin:n], x[begin+n:m] + return x[begin:begin+n-1], x[begin+n:end] end function _split_variate_byvalue(::Tuple{N}, x::Tuple{M}) where {N,M} From 566e350b15ac426598b16fd35396478b448fb1c1 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Thu, 29 Jun 2023 17:03:20 +0200 Subject: [PATCH 024/133] STASH --- src/combinators/hierarchical.jl | 121 +++++++++++++++++++------------- 1 file changed, 72 insertions(+), 49 deletions(-) diff --git a/src/combinators/hierarchical.jl b/src/combinators/hierarchical.jl index 433e1c5d..f13b4075 100644 --- a/src/combinators/hierarchical.jl +++ b/src/combinators/hierarchical.jl @@ -1,79 +1,103 @@ -export HierarchicalMeasure +export Bind """ - struct HierarchicalMeasure{F,M<:AbstractMeasure,G} <: AbstractMeasure + struct MeasureBase.Bind{F,M<:AbstractMeasure,G} <: AbstractMeasure -Represents a hierarchical measure. +Represents a monatic bind resp. a mbind in general. -User code should not instantiate `HierarchicalMeasure` directly, use -[`hierarchical_measure`](@ref) instead. +User code should not create instances of `Bind` directly, but should call +[`mbind`](@ref) instead. """ -struct HierarchicalMeasure{F,M<:AbstractMeasure,G} <: AbstractMeasure +struct Bind{F,M<:AbstractMeasure,G} <: AbstractMeasure f_kernel::F - μ_primary::M + m_primary::M f_combine::G end @doc raw""" - hierarchical_measure(k, μ_primary::AbstractMeasure, f_combine) + mbind(f_β, α::AbstractMeasure, f_c = second) -Construct a hierarchical measure from a transition kernel function -`f_kernel`, a measure `μ_primary` and a function `f_combine`. +Constructs a monadic bind resp. a hierarchical measure from a transition +kernel function `f_β`, a primary measure `α` and a variate combination +function `f_c`. -`f_kernel` must be a function that maps a variate `x_primary` of -`μ_primary` to a dependent secondary measure `μ_secondary = f(x_primary)). -`y = f_combine(x_primary, x_secondary)` must map variates from the primary -and the dependent secondary measure to a combined value. +`f_β` must be a function that maps a point `a` from the space of measure +`α` to a dependent measure `β_a = f_β(a)`. `ab = f_combine(a, b)` must +`a` and e +primary and a variates `b` of the dependent secondary measure `β_a` to +a combined value `ab`. A measure ```julia -`μ = hierarchical_measure(f_c, α, f_β)` +`μ = mbind(f_c, α, f_β)` ``` -has the mathethematical interpretation (using the notation `β_a = f_β(a)`): +has the mathethematical interpretation ```math \mu(f_c(A, B)) = \int_A \beta_a(B)\, \mathrm{d}\, \alpha(a) ``` -Comutationally, `x = rand(μ)` is equivalent to +Without the default `fc = second` this becomes -```julia -x_primary = rand(μ_primary) -μ_secondary = f_kernel(x_primary) -x_secondary = rand(μ_secondary) -x = f_combine(x_primary, x_secondary) +```math +\mu(B) = \int_A \beta_a(B)\, \mathrm{d}\, \alpha(a) ``` -and `mbind(f_β, α)` is equivalent to: +which is equivalent to a monatic bind, when viewing measures as monads. + +Comutationally, `ab = rand(μ)` is equivalent to ```julia -pushfwd((a, b) -> b, hierarchical_measure(=>, α, tuple)) +a = rand(μ_primary) +β_a = f_β(a) +b = rand(β_a) +ab = f_combine(a, b) ``` -Possible choices for `f_combine` are `=>`/`Pair` or `tuple` (these work for +Densities on hierarchical measures can only be evaluated if `ab = f_c(a, b)` +can be unambiguously split into `a` and `b` again. This is currently +implemented for `f_c` that is either `=>`/`Pair` or `tuple` (these work for any combination of variate types), `vcat` (for tuple- or vector-like -variates) and `merge` (e.g. for `NamedTuple` variates). +variates) and `merge` (`NamedTuple` variates). +[`MeasureBase.split_point(::typeof(f_c), α)`](@ref) can be specialized to +support other choices for `f_c`. """ -@inline function hierarchical_measure(f, μ::AbstractMeasure, f_combine) +@inline function mbind(f, μ::AbstractMeasure, f_combine) F, M, G = Core.Typeof(f), Core.Typeof(m), Core.Typeof(f_combine) HierarchicalProductMeasure{F,M,G}(f, μ, f_combine) end +""" + MeasureBase.split_combined(f_c, α::AbstractMeasure, ab) + +Splits a combined value `ab` that originated from combining a point `a_orig` +from the space of a measure `α` with a point `b_orig` from the space of +another measure `β` via `ab = f_c(a_orig, b_orig)`. + +So with `a_orig = rand(α)`, `b_orig = rand(β)` and +`ab = f_c(a_orig, b_orig)`, the following must hold true: -function _split_variate(::typeof(=>), ::AbstractMeasure, x::Pair) +```julia +a, b2 = split_combined(f_c, α, ab) +a ≈ a_orig && b ≈ b_orig +``` +""" +function split_combined end + +function split_combined(::typeof(=>), ::AbstractMeasure, x::Pair) return x.first, x.second end -function _split_variate(f_combine::F, μ_primary::AbstractMeasure, x) where F +function split_combined(f_combine::F, μ_primary::AbstractMeasure, x) where F _split_variate_byvalue(f_combine, testvalue(μ), x) end # Necessary/helpful for type stability? -function _split_variate(::Type{F}, μ::AbstractMeasure, x) where F +function split_combined(::Type{F}, μ::AbstractMeasure, x) where F _split_variate_byvalue(F, testvalue(μ), x) end @@ -130,30 +154,30 @@ end function localmeasure(μ::HierarchicalProductMeasure, x) h_local, x_rest = _localmeasure_with_rest(μ, x) if !isempty(x_rest) - throw(ArgumentError("Variate too long while computing localmeasure of HierarchicalMeasure")) + throw(ArgumentError("Variate too long while computing localmeasure of Bind")) end return h_local end -@inline insupport(::HierarchicalMeasure, x) = NoFastInsupport() +@inline insupport(::Bind, x) = NoFastInsupport() -@inline getdof(μ::HierarchicalMeasure) = NoDOF{typeof(μ)}() +@inline getdof(μ::Bind) = NoDOF{typeof(μ)}() # Bypass `checked_arg`, would require potentially costly evaluation of h.f: -@inline checked_arg(::HierarchicalMeasure, x) = x +@inline checked_arg(::Bind, x) = x -rootmeasure(::HierarchicalMeasure) = throw(ArgumentError("root measure is implicit, but can't be instantiated, for HierarchicalMeasure")) +rootmeasure(::Bind) = throw(ArgumentError("root measure is implicit, but can't be instantiated, for Bind")) -basemeasure(::HierarchicalMeasure) = throw(ArgumentError("basemeasure is not available for HierarchicalMeasure")) +basemeasure(::Bind) = throw(ArgumentError("basemeasure is not available for Bind")) -logdensity_def(::HierarchicalMeasure, x) = throw(ArgumentError("logdensity_def is not available for HierarchicalMeasure")) +logdensity_def(::Bind, x) = throw(ArgumentError("logdensity_def is not available for Bind")) # # TODO: Default implementation of unsafe_logdensityof is a bit inefficient # # for AutoFlatten, since variate will be split in `localmeasure` and then # # split again in log-density evaluation. Maybe add something like -# function unsafe_logdensityof(h::HierarchicalMeasure, x) +# function unsafe_logdensityof(h::Bind, x) # local_primary, local_secondary, x_primary, x_secondary = ... # # Need to call full logdensityof for h_secondary since x_secondary hasn't # # been checked yet: @@ -161,15 +185,14 @@ logdensity_def(::HierarchicalMeasure, x) = throw(ArgumentError("logdensity_def i # end -function Base.rand(rng::Random.AbstractRNG, ::Type{T}, h::HierarchicalMeasure) where {T<:Real} +function Base.rand(rng::Random.AbstractRNG, ::Type{T}, h::Bind) where {T<:Real} x_primary = rand(rng, T, h.m) x_secondary = rand(rng, T, h.f(x_primary)) return _combine_variates(h.flatten_mode, x_primary, x_secondary) end - -function _to_std_with_rest(flatten_mode::FlattenMode, ν_inner::StdMeasure, μ::HierarchicalMeasure, x) +function _to_std_with_rest(flatten_mode::FlattenMode, ν_inner::StdMeasure, μ::Bind, x) μ_primary = μ.m y_primary, x_secondary = _to_std_with_rest(flatten_mode, ν_inner, μ_primary, x) μ_secondary = μ.f(x_secondary) @@ -179,22 +202,22 @@ end function _to_std_with_rest(flatten_mode::FlattenMode, ν_inner::StdMeasure, μ::AbstractMeasure, x) dof_μ = getdof(μ) - x_μ, x_rest = _split_variate(flatten_mode, μ, x) + x_μ, x_rest = split_combined(flatten_mode, μ, x) y = transport_to(ν_inner^dof_μ, μ, x_μ) return y, x_rest end -function transport_def(ν::_PowerStdMeasure{1}, μ::HierarchicalMeasure, x) +function transport_def(ν::_PowerStdMeasure{1}, μ::Bind, x) ν_inner = _get_inner_stdmeasure(ν) y, x_rest = _to_std_with_rest(ν_inner, μ, x) if !isempty(x_rest) - throw(ArgumentError("Variate too long during transport involving HierarchicalMeasure")) + throw(ArgumentError("Variate too long during transport involving Bind")) end return y end -function _from_std_with_rest(ν::HierarchicalMeasure, μ_inner::StdMeasure, x) +function _from_std_with_rest(ν::Bind, μ_inner::StdMeasure, x) ν_primary = ν.m y_primary, x_secondary = _from_std_with_rest(ν_primary, μ_inner, x) ν_secondary = ν.f(y_primary) @@ -206,11 +229,11 @@ function _from_std_with_rest(ν::AbstractMeasure, μ_inner::StdMeasure, x) dof_ν = getdof(ν) len_x = length(eachindex(x)) - # Since we can't check DOF of original HierarchicalMeasure, we could "run out x" if + # Since we can't check DOF of original Bind, we could "run out x" if # the original x was too short. `transport_to` below will detect this, but better # throw a more informative exception here: if len_x < dof_ν - throw(ArgumentError("Variate too short during transport involving HierarchicalMeasure")) + throw(ArgumentError("Variate too short during transport involving Bind")) end y = transport_to(ν, μ_inner^dof_ν, x[begin:begin+dof_ν-1]) @@ -218,13 +241,13 @@ function _from_std_with_rest(ν::AbstractMeasure, μ_inner::StdMeasure, x) return y, x_rest end -function transport_def(ν::HierarchicalMeasure, μ::_PowerStdMeasure{1}, x) +function transport_def(ν::Bind, μ::_PowerStdMeasure{1}, x) # Sanity check, should be checked by transport machinery already: @assert getdof(μ) == length(eachindex(x)) && x isa AbstractVector μ_inner = _get_inner_stdmeasure(μ) y, x_rest = _from_std_with_rest(ν, μ_inner, x) if !isempty(x_rest) - throw(ArgumentError("Variate too long during transport involving HierarchicalMeasure")) + throw(ArgumentError("Variate too long during transport involving Bind")) end return y end From 8532cb829616c752cebfbe3dc63bb39f9c417bc7 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Thu, 29 Jun 2023 18:05:31 +0200 Subject: [PATCH 025/133] STASH mbind --- src/combinators/hierarchical.jl | 69 ++++++++++++++++++++++++--------- 1 file changed, 51 insertions(+), 18 deletions(-) diff --git a/src/combinators/hierarchical.jl b/src/combinators/hierarchical.jl index f13b4075..a9e089fb 100644 --- a/src/combinators/hierarchical.jl +++ b/src/combinators/hierarchical.jl @@ -8,30 +8,29 @@ Represents a monatic bind resp. a mbind in general. User code should not create instances of `Bind` directly, but should call [`mbind`](@ref) instead. """ -struct Bind{F,M<:AbstractMeasure,G} <: AbstractMeasure - f_kernel::F +struct Bind{K,M<:AbstractMeasure,C} <: AbstractMeasure + f_kernel::K m_primary::M - f_combine::G + f_combine::C end @doc raw""" mbind(f_β, α::AbstractMeasure, f_c = second) -Constructs a monadic bind resp. a hierarchical measure from a transition +Constructs a monadic bind, resp. a hierarchical measure, from a transition kernel function `f_β`, a primary measure `α` and a variate combination function `f_c`. -`f_β` must be a function that maps a point `a` from the space of measure -`α` to a dependent measure `β_a = f_β(a)`. `ab = f_combine(a, b)` must -`a` and e -primary and a variates `b` of the dependent secondary measure `β_a` to -a combined value `ab`. +`f_β` must be a function that maps a point `a` from the space of the primary +measure `α` to a dependent secondary measure `β_a = f_β(a)`. +`ab = f_combine(a, b)` must map such a point `a` and a point `b` from the +space of measure `β_a` to a combined value `ab = f_c(a, b)`. -A measure +The resulting measure ```julia -`μ = mbind(f_c, α, f_β)` +μ = mbind(f_c, α, f_β) ``` has the mathethematical interpretation @@ -40,15 +39,15 @@ has the mathethematical interpretation \mu(f_c(A, B)) = \int_A \beta_a(B)\, \mathrm{d}\, \alpha(a) ``` -Without the default `fc = second` this becomes +When using the default `fc = second` (so `ab == b`) this simplies to ```math \mu(B) = \int_A \beta_a(B)\, \mathrm{d}\, \alpha(a) ``` -which is equivalent to a monatic bind, when viewing measures as monads. +which is equivalent to a monatic bind, viewing measures as monads. -Comutationally, `ab = rand(μ)` is equivalent to +Computationally, `ab = rand(μ)` is equivalent to ```julia a = rand(μ_primary) @@ -58,14 +57,48 @@ ab = f_combine(a, b) ``` Densities on hierarchical measures can only be evaluated if `ab = f_c(a, b)` -can be unambiguously split into `a` and `b` again. This is currently -implemented for `f_c` that is either `=>`/`Pair` or `tuple` (these work for -any combination of variate types), `vcat` (for tuple- or vector-like +can be unambiguously split into `a` and `b` again, knowing `α`. This is +currently implemented for `f_c` that is either `=>`/`Pair` or `tuple` (these +work for any combination of variate types), `vcat` (for tuple- or vector-like variates) and `merge` (`NamedTuple` variates). [`MeasureBase.split_point(::typeof(f_c), α)`](@ref) can be specialized to support other choices for `f_c`. + +# Extended help + +Bayesian example with a correlated prior, that models the + +```julia +using MeasureBase + +prior = mbind + productmeasure(( + value => StdNormal() + )), merge +) do a + productmeasure(( + noise = pushfwd(sqrt ∘ Mul(abs(a.position)), StdExponential()) + )) +end + +model = θ -> pushfwd(MulAdd(θ.noise, θ.value), StdNormal())^10 + +joint_θ_obs = mbind(model, prior, tuple) +prior_predictive = mbind(model, prior) + +observation = rand(prior_predictive) +likelihood = likelihoodof(model, observation) + +posterior = mintegrate(likelihood, prior) + +θ = rand(prior) +logdensityof(posterior, θ) +``` """ -@inline function mbind(f, μ::AbstractMeasure, f_combine) +function mbind end +export mbind + +@inline function mbind(f, μ::AbstractMeasure, f_combine = second) F, M, G = Core.Typeof(f), Core.Typeof(m), Core.Typeof(f_combine) HierarchicalProductMeasure{F,M,G}(f, μ, f_combine) end From bc97abe91a643b6783f979de69ca897f704eee5a Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Thu, 29 Jun 2023 21:25:01 +0200 Subject: [PATCH 026/133] STASH hierarchical to bind --- src/MeasureBase.jl | 1 - src/combinators/bind.jl | 297 +++++++++++++++++++++++++++++--- src/combinators/hierarchical.jl | 286 ------------------------------ test/combinators/bind.jl | 11 ++ test/runtests.jl | 1 + 5 files changed, 281 insertions(+), 315 deletions(-) delete mode 100644 src/combinators/hierarchical.jl create mode 100644 test/combinators/bind.jl diff --git a/src/MeasureBase.jl b/src/MeasureBase.jl index 289ffd25..eeea2ad8 100644 --- a/src/MeasureBase.jl +++ b/src/MeasureBase.jl @@ -143,7 +143,6 @@ include("standard/stdexponential.jl") include("standard/stdlogistic.jl") include("standard/stdnormal.jl") include("combinators/half.jl") -include("combinators/hierarchical.jl") include("rand.jl") include("fixedrng.jl") diff --git a/src/combinators/bind.jl b/src/combinators/bind.jl index 465f2bf7..a9e089fb 100644 --- a/src/combinators/bind.jl +++ b/src/combinators/bind.jl @@ -1,45 +1,286 @@ +export Bind + """ - struct MeasureBase.Bind{M,K} <: AbstractMeasure + struct MeasureBase.Bind{F,M<:AbstractMeasure,G} <: AbstractMeasure + +Represents a monatic bind resp. a mbind in general. -Represents a monatic bind. User code should create instances of `Bind` -directly, but should call `mbind(k, μ)` instead. +User code should not create instances of `Bind` directly, but should call +[`mbind`](@ref) instead. """ -struct Bind{M,K} <: AbstractMeasure - k::K - μ::M +struct Bind{K,M<:AbstractMeasure,C} <: AbstractMeasure + f_kernel::K + m_primary::M + f_combine::C end -getdof(d::Bind) = NoDOF{typeof(d)}() -function Base.rand(rng::AbstractRNG, ::Type{T}, d::Bind) where {T} - x = rand(rng, T, d.μ) - y = rand(rng, T, d.k(x)) - return y -end +@doc raw""" + mbind(f_β, α::AbstractMeasure, f_c = second) -""" - mbind(k, μ)::AbstractMeasure - -Given +Constructs a monadic bind, resp. a hierarchical measure, from a transition +kernel function `f_β`, a primary measure `α` and a variate combination +function `f_c`. + +`f_β` must be a function that maps a point `a` from the space of the primary +measure `α` to a dependent secondary measure `β_a = f_β(a)`. +`ab = f_combine(a, b)` must map such a point `a` and a point `b` from the +space of measure `β_a` to a combined value `ab = f_c(a, b)`. + +The resulting measure + +```julia +μ = mbind(f_c, α, f_β) +``` + +has the mathethematical interpretation + +```math +\mu(f_c(A, B)) = \int_A \beta_a(B)\, \mathrm{d}\, \alpha(a) +``` -- a measure μ -- a kernel function k that takes values from the support of μ and returns a - measure +When using the default `fc = second` (so `ab == b`) this simplies to -The *monadic bind* operation `mbind(k, μ)` returns is a new measure. -If `ν == mbind(k, μ)` and all measures involved are sampleable, then -samples from `rand(ν)` follow the same distribution as those from `rand(k(rand(μ)))`. +```math +\mu(B) = \int_A \beta_a(B)\, \mathrm{d}\, \alpha(a) +``` +which is equivalent to a monatic bind, viewing measures as monads. -A monadic bind ofen written as `>>=` (e.g. in Haskell), but this symbol is -unavailable in Julia. +Computationally, `ab = rand(μ)` is equivalent to +```julia +a = rand(μ_primary) +β_a = f_β(a) +b = rand(β_a) +ab = f_combine(a, b) ``` -μ = StdExponential() -ν = mbind(μ) do scale - pushfwd(Base.Fix1(*, scale), StdNormal()) + +Densities on hierarchical measures can only be evaluated if `ab = f_c(a, b)` +can be unambiguously split into `a` and `b` again, knowing `α`. This is +currently implemented for `f_c` that is either `=>`/`Pair` or `tuple` (these +work for any combination of variate types), `vcat` (for tuple- or vector-like +variates) and `merge` (`NamedTuple` variates). +[`MeasureBase.split_point(::typeof(f_c), α)`](@ref) can be specialized to +support other choices for `f_c`. + +# Extended help + +Bayesian example with a correlated prior, that models the + +```julia +using MeasureBase + +prior = mbind + productmeasure(( + value => StdNormal() + )), merge +) do a + productmeasure(( + noise = pushfwd(sqrt ∘ Mul(abs(a.position)), StdExponential()) + )) end + +model = θ -> pushfwd(MulAdd(θ.noise, θ.value), StdNormal())^10 + +joint_θ_obs = mbind(model, prior, tuple) +prior_predictive = mbind(model, prior) + +observation = rand(prior_predictive) +likelihood = likelihoodof(model, observation) + +posterior = mintegrate(likelihood, prior) + +θ = rand(prior) +logdensityof(posterior, θ) ``` """ -mbind(k, μ) = Bind(k, μ) +function mbind end export mbind + +@inline function mbind(f, μ::AbstractMeasure, f_combine = second) + F, M, G = Core.Typeof(f), Core.Typeof(m), Core.Typeof(f_combine) + HierarchicalProductMeasure{F,M,G}(f, μ, f_combine) +end + + +""" + MeasureBase.split_combined(f_c, α::AbstractMeasure, ab) + +Splits a combined value `ab` that originated from combining a point `a_orig` +from the space of a measure `α` with a point `b_orig` from the space of +another measure `β` via `ab = f_c(a_orig, b_orig)`. + +So with `a_orig = rand(α)`, `b_orig = rand(β)` and +`ab = f_c(a_orig, b_orig)`, the following must hold true: + +```julia +a, b2 = split_combined(f_c, α, ab) +a ≈ a_orig && b ≈ b_orig +``` +""" +function split_combined end + +function split_combined(::typeof(=>), ::AbstractMeasure, x::Pair) + return x.first, x.second +end + +function split_combined(f_combine::F, μ_primary::AbstractMeasure, x) where F + _split_variate_byvalue(f_combine, testvalue(μ), x) +end + +# Necessary/helpful for type stability? +function split_combined(::Type{F}, μ::AbstractMeasure, x) where F + _split_variate_byvalue(F, testvalue(μ), x) +end + +function _split_variate_byvalue(test_primary::AbstractVector, x::AbstractVector) + n, m = length(eachindex(test_primary)), length(eachindex(x)) + # TODO: Use getindex or view? + return x[begin:begin+n-1], x[begin+n:end] +end + +function _split_variate_byvalue(::Tuple{N}, x::Tuple{M}) where {N,M} + return ntuple(i -> x[i], Val(1:N)), ntuple(i -> x[i], Val(N+1:M)) +end + +@generated function _split_variate_byvalue(::NamedTuple{names_a}, x::NamedTuple{names}) where {names_a,names} + # TODO: implement + @assert false +end + + + +_combine_variates(::NoFlatten, a::Any, b::Any) = (a, b) + + +_combine_variates(::AutoFlatten, a::Any, b::Any) = _autoflat_combine_variates(a, b) + +_autoflat_combine_variates(a::Any, b::Any) = (a, b) + +_autoflat_combine_variates(a::AbstractVector, b::AbstractVector) = vcat(a, b) + +_autoflat_combine_variates(a::Tuple, b::Tuple) = (a, b) + +# TODO: Check that names don't overlap: +_autoflat_combine_variates(a::NamedTuple, b::NamedTuple) = merge(a, b) + + +_local_productmeasure(::NoFlatten, μ1, μ2) = productmeasure(μ1, μ2) + +# TODO: _local_productmeasure(::AutoFlatten, μ1, μ2) = productmeasure(μ1, μ2) +# Needs a FlatProductMeasure type. + +function _localmeasure_with_rest(μ::HierarchicalProductMeasure, x) + μ_primary = μ.m + local_primary, x_secondary = _localmeasure_with_rest(μ_primary, x) + μ_secondary = μ.f(x_secondary) + local_secondary, x_rest = _localmeasure_with_rest(μ_secondary, x_secondary) + return _local_productmeasure(μ.flatten_mode, local_primary, local_secondary), x_rest +end + +function _localmeasure_with_rest(μ::AbstractMeasure, x) + x_checked = checked_arg(μ, x) + return localmeasure(μ, x_checked), Fill(zero(eltype(x)), 0) +end + +function localmeasure(μ::HierarchicalProductMeasure, x) + h_local, x_rest = _localmeasure_with_rest(μ, x) + if !isempty(x_rest) + throw(ArgumentError("Variate too long while computing localmeasure of Bind")) + end + return h_local +end + + +@inline insupport(::Bind, x) = NoFastInsupport() + +@inline getdof(μ::Bind) = NoDOF{typeof(μ)}() + +# Bypass `checked_arg`, would require potentially costly evaluation of h.f: +@inline checked_arg(::Bind, x) = x + +rootmeasure(::Bind) = throw(ArgumentError("root measure is implicit, but can't be instantiated, for Bind")) + +basemeasure(::Bind) = throw(ArgumentError("basemeasure is not available for Bind")) + +logdensity_def(::Bind, x) = throw(ArgumentError("logdensity_def is not available for Bind")) + + +# # TODO: Default implementation of unsafe_logdensityof is a bit inefficient +# # for AutoFlatten, since variate will be split in `localmeasure` and then +# # split again in log-density evaluation. Maybe add something like +# function unsafe_logdensityof(h::Bind, x) +# local_primary, local_secondary, x_primary, x_secondary = ... +# # Need to call full logdensityof for h_secondary since x_secondary hasn't +# # been checked yet: +# unsafe_logdensityof(local_primary, x_primary) + logdensityof(local_secondary, x_secondary) +# end + + +function Base.rand(rng::Random.AbstractRNG, ::Type{T}, h::Bind) where {T<:Real} + x_primary = rand(rng, T, h.m) + x_secondary = rand(rng, T, h.f(x_primary)) + return _combine_variates(h.flatten_mode, x_primary, x_secondary) +end + + +function _to_std_with_rest(flatten_mode::FlattenMode, ν_inner::StdMeasure, μ::Bind, x) + μ_primary = μ.m + y_primary, x_secondary = _to_std_with_rest(flatten_mode, ν_inner, μ_primary, x) + μ_secondary = μ.f(x_secondary) + y_secondary, x_rest = _to_std_with_rest(flatten_mode, ν_inner, μ_secondary, x_secondary) + return _combine_variates(μ.flatten_mode, y_primary, y_secondary), x_rest +end + +function _to_std_with_rest(flatten_mode::FlattenMode, ν_inner::StdMeasure, μ::AbstractMeasure, x) + dof_μ = getdof(μ) + x_μ, x_rest = split_combined(flatten_mode, μ, x) + y = transport_to(ν_inner^dof_μ, μ, x_μ) + return y, x_rest +end + +function transport_def(ν::_PowerStdMeasure{1}, μ::Bind, x) + ν_inner = _get_inner_stdmeasure(ν) + y, x_rest = _to_std_with_rest(ν_inner, μ, x) + if !isempty(x_rest) + throw(ArgumentError("Variate too long during transport involving Bind")) + end + return y +end + + +function _from_std_with_rest(ν::Bind, μ_inner::StdMeasure, x) + ν_primary = ν.m + y_primary, x_secondary = _from_std_with_rest(ν_primary, μ_inner, x) + ν_secondary = ν.f(y_primary) + y_secondary, x_rest = _from_std_with_rest(ν_secondary, μ_inner, x_secondary) + return _combine_variates(ν.flatten_mode, y_primary, y_secondary), x_rest +end + +function _from_std_with_rest(ν::AbstractMeasure, μ_inner::StdMeasure, x) + dof_ν = getdof(ν) + len_x = length(eachindex(x)) + + # Since we can't check DOF of original Bind, we could "run out x" if + # the original x was too short. `transport_to` below will detect this, but better + # throw a more informative exception here: + if len_x < dof_ν + throw(ArgumentError("Variate too short during transport involving Bind")) + end + + y = transport_to(ν, μ_inner^dof_ν, x[begin:begin+dof_ν-1]) + x_rest = Fill(zero(eltype(x)), dof_ν - len_x) + return y, x_rest +end + +function transport_def(ν::Bind, μ::_PowerStdMeasure{1}, x) + # Sanity check, should be checked by transport machinery already: + @assert getdof(μ) == length(eachindex(x)) && x isa AbstractVector + μ_inner = _get_inner_stdmeasure(μ) + y, x_rest = _from_std_with_rest(ν, μ_inner, x) + if !isempty(x_rest) + throw(ArgumentError("Variate too long during transport involving Bind")) + end + return y +end diff --git a/src/combinators/hierarchical.jl b/src/combinators/hierarchical.jl deleted file mode 100644 index a9e089fb..00000000 --- a/src/combinators/hierarchical.jl +++ /dev/null @@ -1,286 +0,0 @@ -export Bind - -""" - struct MeasureBase.Bind{F,M<:AbstractMeasure,G} <: AbstractMeasure - -Represents a monatic bind resp. a mbind in general. - -User code should not create instances of `Bind` directly, but should call -[`mbind`](@ref) instead. -""" -struct Bind{K,M<:AbstractMeasure,C} <: AbstractMeasure - f_kernel::K - m_primary::M - f_combine::C -end - - -@doc raw""" - mbind(f_β, α::AbstractMeasure, f_c = second) - -Constructs a monadic bind, resp. a hierarchical measure, from a transition -kernel function `f_β`, a primary measure `α` and a variate combination -function `f_c`. - -`f_β` must be a function that maps a point `a` from the space of the primary -measure `α` to a dependent secondary measure `β_a = f_β(a)`. -`ab = f_combine(a, b)` must map such a point `a` and a point `b` from the -space of measure `β_a` to a combined value `ab = f_c(a, b)`. - -The resulting measure - -```julia -μ = mbind(f_c, α, f_β) -``` - -has the mathethematical interpretation - -```math -\mu(f_c(A, B)) = \int_A \beta_a(B)\, \mathrm{d}\, \alpha(a) -``` - -When using the default `fc = second` (so `ab == b`) this simplies to - -```math -\mu(B) = \int_A \beta_a(B)\, \mathrm{d}\, \alpha(a) -``` - -which is equivalent to a monatic bind, viewing measures as monads. - -Computationally, `ab = rand(μ)` is equivalent to - -```julia -a = rand(μ_primary) -β_a = f_β(a) -b = rand(β_a) -ab = f_combine(a, b) -``` - -Densities on hierarchical measures can only be evaluated if `ab = f_c(a, b)` -can be unambiguously split into `a` and `b` again, knowing `α`. This is -currently implemented for `f_c` that is either `=>`/`Pair` or `tuple` (these -work for any combination of variate types), `vcat` (for tuple- or vector-like -variates) and `merge` (`NamedTuple` variates). -[`MeasureBase.split_point(::typeof(f_c), α)`](@ref) can be specialized to -support other choices for `f_c`. - -# Extended help - -Bayesian example with a correlated prior, that models the - -```julia -using MeasureBase - -prior = mbind - productmeasure(( - value => StdNormal() - )), merge -) do a - productmeasure(( - noise = pushfwd(sqrt ∘ Mul(abs(a.position)), StdExponential()) - )) -end - -model = θ -> pushfwd(MulAdd(θ.noise, θ.value), StdNormal())^10 - -joint_θ_obs = mbind(model, prior, tuple) -prior_predictive = mbind(model, prior) - -observation = rand(prior_predictive) -likelihood = likelihoodof(model, observation) - -posterior = mintegrate(likelihood, prior) - -θ = rand(prior) -logdensityof(posterior, θ) -``` -""" -function mbind end -export mbind - -@inline function mbind(f, μ::AbstractMeasure, f_combine = second) - F, M, G = Core.Typeof(f), Core.Typeof(m), Core.Typeof(f_combine) - HierarchicalProductMeasure{F,M,G}(f, μ, f_combine) -end - - -""" - MeasureBase.split_combined(f_c, α::AbstractMeasure, ab) - -Splits a combined value `ab` that originated from combining a point `a_orig` -from the space of a measure `α` with a point `b_orig` from the space of -another measure `β` via `ab = f_c(a_orig, b_orig)`. - -So with `a_orig = rand(α)`, `b_orig = rand(β)` and -`ab = f_c(a_orig, b_orig)`, the following must hold true: - -```julia -a, b2 = split_combined(f_c, α, ab) -a ≈ a_orig && b ≈ b_orig -``` -""" -function split_combined end - -function split_combined(::typeof(=>), ::AbstractMeasure, x::Pair) - return x.first, x.second -end - -function split_combined(f_combine::F, μ_primary::AbstractMeasure, x) where F - _split_variate_byvalue(f_combine, testvalue(μ), x) -end - -# Necessary/helpful for type stability? -function split_combined(::Type{F}, μ::AbstractMeasure, x) where F - _split_variate_byvalue(F, testvalue(μ), x) -end - -function _split_variate_byvalue(test_primary::AbstractVector, x::AbstractVector) - n, m = length(eachindex(test_primary)), length(eachindex(x)) - # TODO: Use getindex or view? - return x[begin:begin+n-1], x[begin+n:end] -end - -function _split_variate_byvalue(::Tuple{N}, x::Tuple{M}) where {N,M} - return ntuple(i -> x[i], Val(1:N)), ntuple(i -> x[i], Val(N+1:M)) -end - -@generated function _split_variate_byvalue(::NamedTuple{names_a}, x::NamedTuple{names}) where {names_a,names} - # TODO: implement - @assert false -end - - - -_combine_variates(::NoFlatten, a::Any, b::Any) = (a, b) - - -_combine_variates(::AutoFlatten, a::Any, b::Any) = _autoflat_combine_variates(a, b) - -_autoflat_combine_variates(a::Any, b::Any) = (a, b) - -_autoflat_combine_variates(a::AbstractVector, b::AbstractVector) = vcat(a, b) - -_autoflat_combine_variates(a::Tuple, b::Tuple) = (a, b) - -# TODO: Check that names don't overlap: -_autoflat_combine_variates(a::NamedTuple, b::NamedTuple) = merge(a, b) - - -_local_productmeasure(::NoFlatten, μ1, μ2) = productmeasure(μ1, μ2) - -# TODO: _local_productmeasure(::AutoFlatten, μ1, μ2) = productmeasure(μ1, μ2) -# Needs a FlatProductMeasure type. - -function _localmeasure_with_rest(μ::HierarchicalProductMeasure, x) - μ_primary = μ.m - local_primary, x_secondary = _localmeasure_with_rest(μ_primary, x) - μ_secondary = μ.f(x_secondary) - local_secondary, x_rest = _localmeasure_with_rest(μ_secondary, x_secondary) - return _local_productmeasure(μ.flatten_mode, local_primary, local_secondary), x_rest -end - -function _localmeasure_with_rest(μ::AbstractMeasure, x) - x_checked = checked_arg(μ, x) - return localmeasure(μ, x_checked), Fill(zero(eltype(x)), 0) -end - -function localmeasure(μ::HierarchicalProductMeasure, x) - h_local, x_rest = _localmeasure_with_rest(μ, x) - if !isempty(x_rest) - throw(ArgumentError("Variate too long while computing localmeasure of Bind")) - end - return h_local -end - - -@inline insupport(::Bind, x) = NoFastInsupport() - -@inline getdof(μ::Bind) = NoDOF{typeof(μ)}() - -# Bypass `checked_arg`, would require potentially costly evaluation of h.f: -@inline checked_arg(::Bind, x) = x - -rootmeasure(::Bind) = throw(ArgumentError("root measure is implicit, but can't be instantiated, for Bind")) - -basemeasure(::Bind) = throw(ArgumentError("basemeasure is not available for Bind")) - -logdensity_def(::Bind, x) = throw(ArgumentError("logdensity_def is not available for Bind")) - - -# # TODO: Default implementation of unsafe_logdensityof is a bit inefficient -# # for AutoFlatten, since variate will be split in `localmeasure` and then -# # split again in log-density evaluation. Maybe add something like -# function unsafe_logdensityof(h::Bind, x) -# local_primary, local_secondary, x_primary, x_secondary = ... -# # Need to call full logdensityof for h_secondary since x_secondary hasn't -# # been checked yet: -# unsafe_logdensityof(local_primary, x_primary) + logdensityof(local_secondary, x_secondary) -# end - - -function Base.rand(rng::Random.AbstractRNG, ::Type{T}, h::Bind) where {T<:Real} - x_primary = rand(rng, T, h.m) - x_secondary = rand(rng, T, h.f(x_primary)) - return _combine_variates(h.flatten_mode, x_primary, x_secondary) -end - - -function _to_std_with_rest(flatten_mode::FlattenMode, ν_inner::StdMeasure, μ::Bind, x) - μ_primary = μ.m - y_primary, x_secondary = _to_std_with_rest(flatten_mode, ν_inner, μ_primary, x) - μ_secondary = μ.f(x_secondary) - y_secondary, x_rest = _to_std_with_rest(flatten_mode, ν_inner, μ_secondary, x_secondary) - return _combine_variates(μ.flatten_mode, y_primary, y_secondary), x_rest -end - -function _to_std_with_rest(flatten_mode::FlattenMode, ν_inner::StdMeasure, μ::AbstractMeasure, x) - dof_μ = getdof(μ) - x_μ, x_rest = split_combined(flatten_mode, μ, x) - y = transport_to(ν_inner^dof_μ, μ, x_μ) - return y, x_rest -end - -function transport_def(ν::_PowerStdMeasure{1}, μ::Bind, x) - ν_inner = _get_inner_stdmeasure(ν) - y, x_rest = _to_std_with_rest(ν_inner, μ, x) - if !isempty(x_rest) - throw(ArgumentError("Variate too long during transport involving Bind")) - end - return y -end - - -function _from_std_with_rest(ν::Bind, μ_inner::StdMeasure, x) - ν_primary = ν.m - y_primary, x_secondary = _from_std_with_rest(ν_primary, μ_inner, x) - ν_secondary = ν.f(y_primary) - y_secondary, x_rest = _from_std_with_rest(ν_secondary, μ_inner, x_secondary) - return _combine_variates(ν.flatten_mode, y_primary, y_secondary), x_rest -end - -function _from_std_with_rest(ν::AbstractMeasure, μ_inner::StdMeasure, x) - dof_ν = getdof(ν) - len_x = length(eachindex(x)) - - # Since we can't check DOF of original Bind, we could "run out x" if - # the original x was too short. `transport_to` below will detect this, but better - # throw a more informative exception here: - if len_x < dof_ν - throw(ArgumentError("Variate too short during transport involving Bind")) - end - - y = transport_to(ν, μ_inner^dof_ν, x[begin:begin+dof_ν-1]) - x_rest = Fill(zero(eltype(x)), dof_ν - len_x) - return y, x_rest -end - -function transport_def(ν::Bind, μ::_PowerStdMeasure{1}, x) - # Sanity check, should be checked by transport machinery already: - @assert getdof(μ) == length(eachindex(x)) && x isa AbstractVector - μ_inner = _get_inner_stdmeasure(μ) - y, x_rest = _from_std_with_rest(ν, μ_inner, x) - if !isempty(x_rest) - throw(ArgumentError("Variate too long during transport involving Bind")) - end - return y -end diff --git a/test/combinators/bind.jl b/test/combinators/bind.jl new file mode 100644 index 00000000..55fe0efb --- /dev/null +++ b/test/combinators/bind.jl @@ -0,0 +1,11 @@ +using Test + +using MeasureBase +using MeasureBase: AbstractMeasure +using MeasureBase: StdExponential, StdLogistic, StdUniform +using MeasureBase: pushfwd, pullbck, mbind, productmeasure +using MeasureBase: mbind, mintegrate, mintegrate_exp, density_rel, logdensity_rel + +@testset "bind.jl" begin + +end diff --git a/test/runtests.jl b/test/runtests.jl index 364d0091..0e47fd95 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -19,6 +19,7 @@ include("smf.jl") include("combinators/weighted.jl") include("combinators/transformedmeasure.jl") +include("combinators/bind.jl") include("measure_operators.jl") From f9ef62065a6c9b7dedbf9298d3a1eb0473da1577 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Thu, 29 Jun 2023 21:37:12 +0200 Subject: [PATCH 027/133] Add function asmeasure Will be used a lot when bridging from Distributions to MeasureBase. --- src/MeasureBase.jl | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/MeasureBase.jl b/src/MeasureBase.jl index eeea2ad8..6666cf76 100644 --- a/src/MeasureBase.jl +++ b/src/MeasureBase.jl @@ -57,6 +57,21 @@ abstract type AbstractMeasure end AbstractMeasure(m::AbstractMeasure) = m + +""" + asmeasure(m) + +Turns a measure-like object `m` into an `AbstractMeasure`. + +Calls `convert(AbstractMeasure, m)` by default +""" +function asmeasure end + +@inline asmeasure(m::AbstractMeasure) = m +asmeasure(m) = convert(AbstractMeasure, m) +export asmeasure + + function Pretty.quoteof(d::M) where {M<:AbstractMeasure} the_names = fieldnames(typeof(d)) :($M($([getfield(d, n) for n in the_names]...))) From c51a82e9d3698cb0f164ed1420521940e777208e Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Thu, 29 Jun 2023 22:39:55 +0200 Subject: [PATCH 028/133] STASH --- src/collection_utils.jl | 26 ++++++++++++++++++++++++++ src/combinators/bind.jl | 41 ++++++++++++++--------------------------- 2 files changed, 40 insertions(+), 27 deletions(-) create mode 100644 src/collection_utils.jl diff --git a/src/collection_utils.jl b/src/collection_utils.jl new file mode 100644 index 00000000..858810cf --- /dev/null +++ b/src/collection_utils.jl @@ -0,0 +1,26 @@ +# ToDo: Add custom rrules for _split_after? + +# ToDo: Use getindex instead of view for certain cases (array types)? +@inline function split_after(x::AbstractVector, n) + i_first = firstindex(x) + i_last = lastindex(x) + view(x, i_first, i_first+n-1), view(x, n, i_last) +end + +@inline _split_after(x::Tuple, n) = _split_after(x::Tuple, Val{n}()) +@inline _split_after(x::Tuple, ::Val{N}) where N = x[begin:begin+N-1], x[N:end] + +@generated function _split_after(x::NamedTuple{names}, ::Val{names_a}) where {names, names_a} + n = length(names_a) + if names_after[begin:begin+n-1] == names_a + names_b = names[n:end] + quote + a, b = _split_after(x, Val(n)) + NamedTuple{names_a}(a), NamedTuple{names_b}(b) + end + else + quote + throw(ArgumentError("Can't split NamedTuple{$names} after {$names_a}")) + end + end +end diff --git a/src/combinators/bind.jl b/src/combinators/bind.jl index a9e089fb..ec113370 100644 --- a/src/combinators/bind.jl +++ b/src/combinators/bind.jl @@ -8,10 +8,10 @@ Represents a monatic bind resp. a mbind in general. User code should not create instances of `Bind` directly, but should call [`mbind`](@ref) instead. """ -struct Bind{K,M<:AbstractMeasure,C} <: AbstractMeasure - f_kernel::K +struct Bind{FK,M<:AbstractMeasure,FC} <: AbstractMeasure + f_kernel::FK m_primary::M - f_combine::C + f_combine::FC end @@ -58,7 +58,7 @@ ab = f_combine(a, b) Densities on hierarchical measures can only be evaluated if `ab = f_c(a, b)` can be unambiguously split into `a` and `b` again, knowing `α`. This is -currently implemented for `f_c` that is either `=>`/`Pair` or `tuple` (these +currently implemented for `f_c` that is either tuple or `=>`/`Pair` (these work for any combination of variate types), `vcat` (for tuple- or vector-like variates) and `merge` (`NamedTuple` variates). [`MeasureBase.split_point(::typeof(f_c), α)`](@ref) can be specialized to @@ -98,9 +98,9 @@ logdensityof(posterior, θ) function mbind end export mbind -@inline function mbind(f, μ::AbstractMeasure, f_combine = second) - F, M, G = Core.Typeof(f), Core.Typeof(m), Core.Typeof(f_combine) - HierarchicalProductMeasure{F,M,G}(f, μ, f_combine) +@inline function mbind(f_β, α::AbstractMeasure, f_c = second) + F, M, G = Core.Typeof(f_β), Core.Typeof(α), Core.Typeof(f_c) + HierarchicalProductMeasure{F,M,G}(f_β, α, f_c) end @@ -121,32 +121,19 @@ a ≈ a_orig && b ≈ b_orig """ function split_combined end -function split_combined(::typeof(=>), ::AbstractMeasure, x::Pair) - return x.first, x.second -end +@inline split_combined(::typeof(tuple), @nospecialize(α::AbstractMeasure), x::Tuple{T,U}) where T,U = ab +@inline split_combined(::Type{Pair}, @nospecialize(α::AbstractMeasure), ab::Pair) = (ab...,) -function split_combined(f_combine::F, μ_primary::AbstractMeasure, x) where F +function split_combined(f_c::FC, α::AbstractMeasure, ab) where FC _split_variate_byvalue(f_combine, testvalue(μ), x) end -# Necessary/helpful for type stability? -function split_combined(::Type{F}, μ::AbstractMeasure, x) where F - _split_variate_byvalue(F, testvalue(μ), x) -end +_split_variate_byvalue(::typeof(vcat), test_a::AbstractVector, ab::AbstractVector) = _split_after(ab, length(test_a)) -function _split_variate_byvalue(test_primary::AbstractVector, x::AbstractVector) - n, m = length(eachindex(test_primary)), length(eachindex(x)) - # TODO: Use getindex or view? - return x[begin:begin+n-1], x[begin+n:end] -end - -function _split_variate_byvalue(::Tuple{N}, x::Tuple{M}) where {N,M} - return ntuple(i -> x[i], Val(1:N)), ntuple(i -> x[i], Val(N+1:M)) -end +_split_variate_byvalue(::typeof(vcat), ::Tuple{N}, ab::Tuple) where N = _split_after(ab, Val{N}()) -@generated function _split_variate_byvalue(::NamedTuple{names_a}, x::NamedTuple{names}) where {names_a,names} - # TODO: implement - @assert false +function _split_variate_byvalue(::typeof(merge), ::NamedTuple{names_a}, ab::NamedTuple) where names_a + _split_after(ab, Val{names_a}) end From 37949837d7d69f9566e0f005fd9cf5b228646574 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Fri, 30 Jun 2023 00:47:40 +0200 Subject: [PATCH 029/133] STASH --- src/MeasureBase.jl | 3 +- src/combinators/bind.jl | 104 ++++-------------- src/combinators/combined.jl | 204 +++++++++++++++++++++++++++++++++++ src/getdof.jl | 6 ++ test/combinators/bind.jl | 2 +- test/combinators/combined.jl | 11 ++ test/runtests.jl | 1 + 7 files changed, 248 insertions(+), 83 deletions(-) create mode 100644 src/combinators/combined.jl create mode 100644 test/combinators/combined.jl diff --git a/src/MeasureBase.jl b/src/MeasureBase.jl index 6666cf76..1f2341f4 100644 --- a/src/MeasureBase.jl +++ b/src/MeasureBase.jl @@ -140,12 +140,13 @@ include("primitives/lebesgue.jl") include("primitives/dirac.jl") include("primitives/trivial.jl") -include("combinators/bind.jl") include("combinators/transformedmeasure.jl") include("combinators/weighted.jl") include("combinators/superpose.jl") include("combinators/product.jl") include("combinators/power.jl") +include("combinators/combined.jl") +include("combinators/bind.jl") include("combinators/spikemixture.jl") include("combinators/likelihood.jl") include("combinators/restricted.jl") diff --git a/src/combinators/bind.jl b/src/combinators/bind.jl index ec113370..7edf5b89 100644 --- a/src/combinators/bind.jl +++ b/src/combinators/bind.jl @@ -1,20 +1,3 @@ -export Bind - -""" - struct MeasureBase.Bind{F,M<:AbstractMeasure,G} <: AbstractMeasure - -Represents a monatic bind resp. a mbind in general. - -User code should not create instances of `Bind` directly, but should call -[`mbind`](@ref) instead. -""" -struct Bind{FK,M<:AbstractMeasure,FC} <: AbstractMeasure - f_kernel::FK - m_primary::M - f_combine::FC -end - - @doc raw""" mbind(f_β, α::AbstractMeasure, f_c = second) @@ -24,7 +7,7 @@ function `f_c`. `f_β` must be a function that maps a point `a` from the space of the primary measure `α` to a dependent secondary measure `β_a = f_β(a)`. -`ab = f_combine(a, b)` must map such a point `a` and a point `b` from the +`ab = f_c(a, b)` must map such a point `a` and a point `b` from the space of measure `β_a` to a combined value `ab = f_c(a, b)`. The resulting measure @@ -33,7 +16,7 @@ The resulting measure μ = mbind(f_c, α, f_β) ``` -has the mathethematical interpretation +has the mathethematical interpretation (on sets $$A$$ and $$B$$) ```math \mu(f_c(A, B)) = \int_A \beta_a(B)\, \mathrm{d}\, \alpha(a) @@ -53,7 +36,7 @@ Computationally, `ab = rand(μ)` is equivalent to a = rand(μ_primary) β_a = f_β(a) b = rand(β_a) -ab = f_combine(a, b) +ab = f_c(a, b) ``` Densities on hierarchical measures can only be evaluated if `ab = f_c(a, b)` @@ -100,83 +83,42 @@ export mbind @inline function mbind(f_β, α::AbstractMeasure, f_c = second) F, M, G = Core.Typeof(f_β), Core.Typeof(α), Core.Typeof(f_c) - HierarchicalProductMeasure{F,M,G}(f_β, α, f_c) + Bind{F,M,G}(f_β, α, f_c) end """ - MeasureBase.split_combined(f_c, α::AbstractMeasure, ab) - -Splits a combined value `ab` that originated from combining a point `a_orig` -from the space of a measure `α` with a point `b_orig` from the space of -another measure `β` via `ab = f_c(a_orig, b_orig)`. + struct MeasureBase.Bind <: AbstractMeasure -So with `a_orig = rand(α)`, `b_orig = rand(β)` and -`ab = f_c(a_orig, b_orig)`, the following must hold true: +Represents a monatic bind resp. a mbind in general. -```julia -a, b2 = split_combined(f_c, α, ab) -a ≈ a_orig && b ≈ b_orig -``` +User code should not create instances of `Bind` directly, but should call +[`mbind`](@ref) instead. """ -function split_combined end - -@inline split_combined(::typeof(tuple), @nospecialize(α::AbstractMeasure), x::Tuple{T,U}) where T,U = ab -@inline split_combined(::Type{Pair}, @nospecialize(α::AbstractMeasure), ab::Pair) = (ab...,) - -function split_combined(f_c::FC, α::AbstractMeasure, ab) where FC - _split_variate_byvalue(f_combine, testvalue(μ), x) -end - -_split_variate_byvalue(::typeof(vcat), test_a::AbstractVector, ab::AbstractVector) = _split_after(ab, length(test_a)) - -_split_variate_byvalue(::typeof(vcat), ::Tuple{N}, ab::Tuple) where N = _split_after(ab, Val{N}()) - -function _split_variate_byvalue(::typeof(merge), ::NamedTuple{names_a}, ab::NamedTuple) where names_a - _split_after(ab, Val{names_a}) +struct Bind{FK,M<:AbstractMeasure,FC} <: AbstractMeasure + f_β::FK + α::M + f_c::FC end - -_combine_variates(::NoFlatten, a::Any, b::Any) = (a, b) - - -_combine_variates(::AutoFlatten, a::Any, b::Any) = _autoflat_combine_variates(a, b) - -_autoflat_combine_variates(a::Any, b::Any) = (a, b) - -_autoflat_combine_variates(a::AbstractVector, b::AbstractVector) = vcat(a, b) - -_autoflat_combine_variates(a::Tuple, b::Tuple) = (a, b) - -# TODO: Check that names don't overlap: -_autoflat_combine_variates(a::NamedTuple, b::NamedTuple) = merge(a, b) - - -_local_productmeasure(::NoFlatten, μ1, μ2) = productmeasure(μ1, μ2) +_local_productmeasure(fc, α, \) = productmeasure(fc(marginals(μ1), marginals(μ2))) # TODO: _local_productmeasure(::AutoFlatten, μ1, μ2) = productmeasure(μ1, μ2) # Needs a FlatProductMeasure type. -function _localmeasure_with_rest(μ::HierarchicalProductMeasure, x) - μ_primary = μ.m - local_primary, x_secondary = _localmeasure_with_rest(μ_primary, x) - μ_secondary = μ.f(x_secondary) - local_secondary, x_rest = _localmeasure_with_rest(μ_secondary, x_secondary) - return _local_productmeasure(μ.flatten_mode, local_primary, local_secondary), x_rest -end - -function _localmeasure_with_rest(μ::AbstractMeasure, x) - x_checked = checked_arg(μ, x) - return localmeasure(μ, x_checked), Fill(zero(eltype(x)), 0) +function _local_split_combined(f_c, μ::Bind, x) + local_α, a, b_withrest = _local_split_combined(μ.α, x) + β_a = μ.f_β(a) + local_β_a, b, rest = _localmeasure_with_rest(β_a, b_withrest) + local_μ = productmeasure(μ.f_c(marginals(local_α), marginals(local_β_a))) + local_ab, rest = split_combined(μ.f_c, local_μ, ab) + return local_μ, local_ab, rest end -function localmeasure(μ::HierarchicalProductMeasure, x) - h_local, x_rest = _localmeasure_with_rest(μ, x) - if !isempty(x_rest) - throw(ArgumentError("Variate too long while computing localmeasure of Bind")) - end - return h_local +function _local_split_combined(f_c, μ::AbstractMeasure, x_withrest) + x, rest = split_combined(f_c, μ, x_withrest) + return localmeasure(μ, x), x, rest end diff --git a/src/combinators/combined.jl b/src/combinators/combined.jl new file mode 100644 index 00000000..f02f86b5 --- /dev/null +++ b/src/combinators/combined.jl @@ -0,0 +1,204 @@ +""" + MeasureBase.local_split_combined(f_c, α::AbstractMeasure, ab) + +Splits a combined value `ab` that originated from combining a point `a` +from the space of a measure `α` with a point `b` from the space of +another measure `β` via `ab = f_c(a, b)`. + +Returns a semantic equivalent of `(localmeasure(α, a), a, b)`. + +With `a_orig = rand(α)`, `b_orig = rand(β)` and +`ab = f_c(a_orig, b_orig)`, the following must hold true: + +```julia +local_α, a, b = local_split_combined(f_c, α, ab) +a ≈ a_orig && b ≈ b_orig +``` +""" +function local_split_combined end + +function local_split_combined(f_c, α::AbstractMeasure, ab) + a, b = _generic_split_combined(fc, α, ab) + return localmeasure(α, a), a, b +end + +@inline _generic_split_combined(::typeof(tuple), @nospecialize(α::AbstractMeasure), x::Tuple{T,U}) where T,U = ab +@inline _generic_split_combined(::Type{Pair}, @nospecialize(α::AbstractMeasure), ab::Pair) = (ab...,) + +function _generic_split_combined(f_c::FC, α::AbstractMeasure, ab) where FC + _split_variate_byvalue(f_c, testvalue(μ), x) +end + +_split_variate_byvalue(::typeof(vcat), test_a::AbstractVector, ab::AbstractVector) = _split_after(ab, length(test_a)) + +_split_variate_byvalue(::typeof(vcat), ::Tuple{N}, ab::Tuple) where N = _split_after(ab, Val{N}()) + +function _split_variate_byvalue(::typeof(merge), ::NamedTuple{names_a}, ab::NamedTuple) where names_a + _split_after(ab, Val{names_a}) +end + + +""" + mcombine(f_c, α::AbstractMeasure, β::AbstractMeasure) + +Combines two measures `α` and `β` to a joint measure via a point combination +function `f_c`. + +`f_c` must combine a given point `a` from the space of measure `α` with a +given point `b` from the space of measure `β` to a single value +`ab = f_c(a, b)` in the space of the combined measure +`μ = mcombine(f_c, α, β)`. + +The combined measure has the mathethematical interpretation (on +sets $$A$$ and $$B$$) + +```math +\mu(f_c(A, B)) = \alpha(A)\, \beta(B) +``` +""" +function mcombine end +export mcombine + +function mcombine(f_c, α::AbstractMeasure, β::AbstractMeasure) + FC, MA, MB = Core.Typeof(f_c), Core.Typeof(α), Core.Typeof(β) + JointMeasure{FC,MA,MB}(f_c, α, β) +end + +function mcombine(::typeof(tuple), α::AbstractMeasure, β::AbstractMeasure) + productmeasure((a, b)) +end + +function mcombine(::typeof(vcat), α::AbstractProductMeasure, β::AbstractProductMeasure) + productmeasure(vcat(marginals(α), marginals(β))) +end + +function mcombine(::typeof(merge), α::AbstractProductMeasure, β::AbstractProductMeasure) + productmeasure(merge(marginals(α), marginals(β))) +end + + +""" + struct JointMeasure <: AbstractMeasure + +Represents a monatic bind resp. a mbind in general. + +User code should not create instances of `Joint` directly, but should call +[`mbind`](@ref) instead. +""" + +JointMeasure{FC,MA<:AbstractMeasure,MB<:AbstractMeasure} <: AbstractMeasure + f_c::FC + α::MA + β::MB +end + + +# TODO: Could split ab here, but would be wasteful. +@inline insupport(::Joint, ab) = NoFastInsupport() + +@inline getdof(μ::Joint) = getdof(μ.α) + getdof(μ.β) + +# Bypass `checked_arg`, would require require splitting ab: +@inline checked_arg(::Joint, ab) = ab + +rootmeasure(::Joint) = mcombine(μ.f_c rootmeasure(μ), rootmeasure(ν)) + +basemeasure(::Joint) = mcombine(μ.f_c basemeasure(μ), basemeasure(ν)) + +logdensity_def(::Joint, ab) + # Use _local_split_combined to avoid duplicate calculation of localmeasure(α): + local_α, a, b = _local_split_combined(μ.f_c, μ.α, ab) + return logdensity_def(local_α, a) + logdensity_def(μ.β, b) +end + +# Specialize logdensityof directly to avoid creating temporary joint base measures: +logdensityof(::Joint, ab) + # Use _local_split_combined to avoid duplicate calculation of localmeasure(α): + local_α, a, b = _local_split_combined(μ.f_c, μ.α, ab) + return logdensityof(local_α, a) + logdensityof(μ.β, b) +end + +function _local_split_combined(f_c, α::AbstractMeasure, ab) + a, b = _generic_split_combined(f_c, α, ab) + return localmeasure(α, a), a, b +end + + + +# # TODO: Default implementation of unsafe_logdensityof is a bit inefficient +# # for AutoFlatten, since variate will be split in `localmeasure` and then +# # split again in log-density evaluation. Maybe add something like +# function unsafe_logdensityof(h::Joint, x) +# local_primary, local_secondary, x_primary, x_secondary = ... +# # Need to call full logdensityof for h_secondary since x_secondary hasn't +# # been checked yet: +# unsafe_logdensityof(local_primary, x_primary) + logdensityof(local_secondary, x_secondary) +# end + + +function Base.rand(rng::Random.AbstractRNG, ::Type{T}, h::Joint) where {T<:Real} + x_primary = rand(rng, T, h.m) + x_secondary = rand(rng, T, h.f(x_primary)) + return _combine_variates(h.flatten_mode, x_primary, x_secondary) +end + + +function _to_std_with_rest(flatten_mode::FlattenMode, ν_inner::StdMeasure, μ::Joint, x) + μ_primary = μ.m + y_primary, x_secondary = _to_std_with_rest(flatten_mode, ν_inner, μ_primary, x) + μ_secondary = μ.f(x_secondary) + y_secondary, x_rest = _to_std_with_rest(flatten_mode, ν_inner, μ_secondary, x_secondary) + return _combine_variates(μ.flatten_mode, y_primary, y_secondary), x_rest +end + +function _to_std_with_rest(flatten_mode::FlattenMode, ν_inner::StdMeasure, μ::AbstractMeasure, x) + dof_μ = getdof(μ) + x_μ, x_rest = _generic_split_combined(flatten_mode, μ, x) + y = transport_to(ν_inner^dof_μ, μ, x_μ) + return y, x_rest +end + +function transport_def(ν::_PowerStdMeasure{1}, μ::Joint, x) + ν_inner = _get_inner_stdmeasure(ν) + y, x_rest = _to_std_with_rest(ν_inner, μ, x) + if !isempty(x_rest) + throw(ArgumentError("Variate too long during transport involving Joint")) + end + return y +end + + +function _from_std_with_rest(ν::Joint, μ_inner::StdMeasure, x) + ν_primary = ν.m + y_primary, x_secondary = _from_std_with_rest(ν_primary, μ_inner, x) + ν_secondary = ν.f(y_primary) + y_secondary, x_rest = _from_std_with_rest(ν_secondary, μ_inner, x_secondary) + return _combine_variates(ν.flatten_mode, y_primary, y_secondary), x_rest +end + +function _from_std_with_rest(ν::AbstractMeasure, μ_inner::StdMeasure, x) + dof_ν = getdof(ν) + len_x = length(eachindex(x)) + + # Since we can't check DOF of original Joint, we could "run out x" if + # the original x was too short. `transport_to` below will detect this, but better + # throw a more informative exception here: + if len_x < dof_ν + throw(ArgumentError("Variate too short during transport involving Joint")) + end + + y = transport_to(ν, μ_inner^dof_ν, x[begin:begin+dof_ν-1]) + x_rest = Fill(zero(eltype(x)), dof_ν - len_x) + return y, x_rest +end + +function transport_def(ν::Joint, μ::_PowerStdMeasure{1}, x) + # Sanity check, should be checked by transport machinery already: + @assert getdof(μ) == length(eachindex(x)) && x isa AbstractVector + μ_inner = _get_inner_stdmeasure(μ) + y, x_rest = _from_std_with_rest(ν, μ_inner, x) + if !isempty(x_rest) + throw(ArgumentError("Variate too long during transport involving Joint")) + end + return y +end diff --git a/src/getdof.jl b/src/getdof.jl index 2c0bb60c..6a31fb4b 100644 --- a/src/getdof.jl +++ b/src/getdof.jl @@ -7,6 +7,12 @@ a global property of the measure. """ struct NoDOF{MU} end +_add_dof(dof_a::Real, dof_b::Real) = dof_a + dof_b +_add_dof(dof_a::NoDOF, ::Real) = dof_a +_add_dof(::Real, dof_b::NoDOF) = dof_b +_add_dof(dof_a::NoDOF, ::NoDOF) = dof_a + + """ getdof(μ) diff --git a/test/combinators/bind.jl b/test/combinators/bind.jl index 55fe0efb..8ad7bc65 100644 --- a/test/combinators/bind.jl +++ b/test/combinators/bind.jl @@ -3,7 +3,7 @@ using Test using MeasureBase using MeasureBase: AbstractMeasure using MeasureBase: StdExponential, StdLogistic, StdUniform -using MeasureBase: pushfwd, pullbck, mbind, productmeasure +using MeasureBase: pushfwd, pullbck, mbind, localmeasure using MeasureBase: mbind, mintegrate, mintegrate_exp, density_rel, logdensity_rel @testset "bind.jl" begin diff --git a/test/combinators/combined.jl b/test/combinators/combined.jl new file mode 100644 index 00000000..7111beb8 --- /dev/null +++ b/test/combinators/combined.jl @@ -0,0 +1,11 @@ +using Test + +using MeasureBase +using MeasureBase: AbstractMeasure +using MeasureBase: StdExponential, StdLogistic, StdUniform +using MeasureBase: pushfwd, pullbck, mbind, productmeasure, jointmeasure +using MeasureBase: mbind, mintegrate, mintegrate_exp, density_rel, logdensity_rel + +@testset "combined.jl" begin + +end diff --git a/test/runtests.jl b/test/runtests.jl index 0e47fd95..e2423b81 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -19,6 +19,7 @@ include("smf.jl") include("combinators/weighted.jl") include("combinators/transformedmeasure.jl") +include("combinators/combined.jl") include("combinators/bind.jl") include("measure_operators.jl") From 4b5b7a0dc4d5e446f3a7f181f956b7c1d300c27d Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Fri, 30 Jun 2023 05:26:04 +0200 Subject: [PATCH 030/133] STASH --- src/combinators/bind.jl | 98 +++++++++++++++++++++++++++---------- src/combinators/combined.jl | 62 ++++++++--------------- src/density-core.jl | 36 +++++++++----- 3 files changed, 119 insertions(+), 77 deletions(-) diff --git a/src/combinators/bind.jl b/src/combinators/bind.jl index 7edf5b89..c33c4fb9 100644 --- a/src/combinators/bind.jl +++ b/src/combinators/bind.jl @@ -102,23 +102,60 @@ struct Bind{FK,M<:AbstractMeasure,FC} <: AbstractMeasure end -_local_productmeasure(fc, α, \) = productmeasure(fc(marginals(μ1), marginals(μ2))) +""" + MeasureBase.transportmeasure(μ::Bind, x)::AbstractMeasure + +Evaluates a monatic bind `μ` at a point `x`. + +The resulting measure behaves like `μ` in the infinitesimal neighborhood +of `x` in respect to density calculation and transport as well. +""" +function transportmeasure(μ::Bind, x) + tpm_α, a, b = tpmeasure_split_combined(μ.α, x) + tpm_β_a = transportmeasure(μ.f_β(a), b) + mcombine(μ.f_c, tpm_α, tpm_β_a) +end + +localmeasure(μ::Bind, x) = transportmeasure(μ, x) + + +tpmeasure_split_combined(f_c, μ::Bind, xy) = _bind_lsc(f_c, μ::Bind, xy) + +function _bind_lsc(f_c::typeof(tuple), μ::Bind, xy::Tuple{Vararg{Any,2}}) + x, y = x[1], y[1] + tpm_μ = transportmeasure(μ, x) + return tpm_μ, x, y +end + +function _bind_lsc(f_c::Type{Pair}, μ::Bind, xy::Pair) + x, y = x.first, y.second + tpm_μ = transportmeasure(μ, x) + return tpm_μ, x, y +end -# TODO: _local_productmeasure(::AutoFlatten, μ1, μ2) = productmeasure(μ1, μ2) -# Needs a FlatProductMeasure type. +const _CatBind{FC} = _BindBy{<:Any,<:Any,FC} -function _local_split_combined(f_c, μ::Bind, x) - local_α, a, b_withrest = _local_split_combined(μ.α, x) +_bind_lsc(f_c::typeof(vcat), μ::_CatBind{typeof{vcat}}, xy::AbstractVector) = _bind_lsc_cat(f_c, μ, xy) +_bind_lsc(f_c::typeof(merge), μ::_CatBind{typeof{merge}}, xy::NamedTuple) = _bind_lsc_cat(f_c, μ, xy) + +function _bind_lsc_cat_lμabyxy(f_c, μ, xy) + tpm_α, a, by = tpmeasure_split_combined(μ.f_c, μ.α, xy) β_a = μ.f_β(a) - local_β_a, b, rest = _localmeasure_with_rest(β_a, b_withrest) - local_μ = productmeasure(μ.f_c(marginals(local_α), marginals(local_β_a))) - local_ab, rest = split_combined(μ.f_c, local_μ, ab) - return local_μ, local_ab, rest + tpm_β_a, b, y = tpmeasure_split_combined(f_c, β_a, by) + tpm_μ = mcombine(μ.f_c, tpm_α, tpm_β_a) + return tpm_μ, a, b, y, xy +end + +function _bind_lsc_cat(f_c::typeof(vcat), μ::_CatBind{typeof{vcat}}, xy::AbstractVector) + tpm_μ, a, b, y, xy = _bind_lsc_cat_lμabyxy(f_c, μ, xy) + # Don't use `x = f_c(a, b)` here, would allocate, splitting xy can use views: + x, y = _split_after(xy, length(a) + length(b)) + return tpm_μ, x, y end -function _local_split_combined(f_c, μ::AbstractMeasure, x_withrest) - x, rest = split_combined(f_c, μ, x_withrest) - return localmeasure(μ, x), x, rest +function _bind_lsc_cat(f_c::typeof(merge), μ::_CatBind{typeof{merge}}, xy::NamedTuple) + tpm_μ, a, b, y, xy = _bind_lsc_cat_lμabyxy(f_c, μ, xy) + return tpm_μ, f_c(a, b), y end @@ -133,28 +170,39 @@ rootmeasure(::Bind) = throw(ArgumentError("root measure is implicit, but can't b basemeasure(::Bind) = throw(ArgumentError("basemeasure is not available for Bind")) +testvalue(::Bind) = throw(ArgumentError("testvalue is not available for Bind")) + logdensity_def(::Bind, x) = throw(ArgumentError("logdensity_def is not available for Bind")) +# Specialize logdensityof to avoid duplicate calculations: +function logdensityof(μ::Bind, x) + tpm_α, a, b = tpmeasure_split_combined(μ.α, x) + β_a = μ.f_β(a) + logdensityof(tpm_α, a) + logdensityof(β_a, b) +end -# # TODO: Default implementation of unsafe_logdensityof is a bit inefficient -# # for AutoFlatten, since variate will be split in `localmeasure` and then -# # split again in log-density evaluation. Maybe add something like -# function unsafe_logdensityof(h::Bind, x) -# local_primary, local_secondary, x_primary, x_secondary = ... -# # Need to call full logdensityof for h_secondary since x_secondary hasn't -# # been checked yet: -# unsafe_logdensityof(local_primary, x_primary) + logdensityof(local_secondary, x_secondary) -# end +# Specialize unsafe_logdensityof to avoid duplicate calculations: +function unsafe_logdensityof(μ::Bind, x) + tpm_α, a, b = tpmeasure_split_combined(μ.α, x) + β_a = μ.f_β(a) + unsafe_logdensityof(tpm_α, a) + unsafe_logdensityof(β_a, b) +end -function Base.rand(rng::Random.AbstractRNG, ::Type{T}, h::Bind) where {T<:Real} - x_primary = rand(rng, T, h.m) - x_secondary = rand(rng, T, h.f(x_primary)) - return _combine_variates(h.flatten_mode, x_primary, x_secondary) +function Base.rand(rng::Random.AbstractRNG, ::Type{T}, μ::Bind) where {T<:Real} + a = rand(rng, T, μ.α) + b = rand(rng, T, μ.f_β(a)) + return μ.f_c(a, b) end + + +#!!!!!!!!!!!!!!!!!! TODO: + + function _to_std_with_rest(flatten_mode::FlattenMode, ν_inner::StdMeasure, μ::Bind, x) + μ_primary = μ.m y_primary, x_secondary = _to_std_with_rest(flatten_mode, ν_inner, μ_primary, x) μ_secondary = μ.f(x_secondary) diff --git a/src/combinators/combined.jl b/src/combinators/combined.jl index f02f86b5..9c67e75e 100644 --- a/src/combinators/combined.jl +++ b/src/combinators/combined.jl @@ -1,28 +1,29 @@ """ - MeasureBase.local_split_combined(f_c, α::AbstractMeasure, ab) + MeasureBase.tpmeasure_split_combined(f_c, α::AbstractMeasure, ab) Splits a combined value `ab` that originated from combining a point `a` from the space of a measure `α` with a point `b` from the space of another measure `β` via `ab = f_c(a, b)`. -Returns a semantic equivalent of `(localmeasure(α, a), a, b)`. +Returns a semantic equivalent of +`(MeasureBase.transportmeasure(α, a), a, b)`. With `a_orig = rand(α)`, `b_orig = rand(β)` and `ab = f_c(a_orig, b_orig)`, the following must hold true: ```julia -local_α, a, b = local_split_combined(f_c, α, ab) +local_α, a, b = tpmeasure_split_combined(f_c, α, ab) a ≈ a_orig && b ≈ b_orig ``` """ -function local_split_combined end +function tpmeasure_split_combined end -function local_split_combined(f_c, α::AbstractMeasure, ab) - a, b = _generic_split_combined(fc, α, ab) - return localmeasure(α, a), a, b +function tpmeasure_split_combined(f_c, α::AbstractMeasure, ab) + a, b = _generic_split_combined(f_c, α, ab) + return transportmeasure(α, a), a, b end -@inline _generic_split_combined(::typeof(tuple), @nospecialize(α::AbstractMeasure), x::Tuple{T,U}) where T,U = ab +@inline _generic_split_combined(::typeof(tuple), @nospecialize(α::AbstractMeasure), x::Tuple{Vararg{Any,2}}) @inline _generic_split_combined(::Type{Pair}, @nospecialize(α::AbstractMeasure), ab::Pair) = (ab...,) function _generic_split_combined(f_c::FC, α::AbstractMeasure, ab) where FC @@ -61,39 +62,35 @@ export mcombine function mcombine(f_c, α::AbstractMeasure, β::AbstractMeasure) FC, MA, MB = Core.Typeof(f_c), Core.Typeof(α), Core.Typeof(β) - JointMeasure{FC,MA,MB}(f_c, α, β) + Combined{FC,MA,MB}(f_c, α, β) end function mcombine(::typeof(tuple), α::AbstractMeasure, β::AbstractMeasure) productmeasure((a, b)) end -function mcombine(::typeof(vcat), α::AbstractProductMeasure, β::AbstractProductMeasure) - productmeasure(vcat(marginals(α), marginals(β))) -end - -function mcombine(::typeof(merge), α::AbstractProductMeasure, β::AbstractProductMeasure) - productmeasure(merge(marginals(α), marginals(β))) +function mcombine(f_c::Union{typeof(vcat),typeof(merge)}, α::AbstractProductMeasure, β::AbstractProductMeasure) + productmeasure(f_c(marginals(α), marginals(β))) end """ - struct JointMeasure <: AbstractMeasure + struct Combined <: AbstractMeasure -Represents a monatic bind resp. a mbind in general. +Represents a combination of two measures. User code should not create instances of `Joint` directly, but should call -[`mbind`](@ref) instead. +[`mcombine(f_c, α, β)`](@ref) instead. """ -JointMeasure{FC,MA<:AbstractMeasure,MB<:AbstractMeasure} <: AbstractMeasure +Combined{FC,MA<:AbstractMeasure,MB<:AbstractMeasure} <: AbstractMeasure f_c::FC α::MA β::MB end -# TODO: Could split ab here, but would be wasteful. +# TODO: Could split `ab`` here, but would be wasteful. @inline insupport(::Joint, ab) = NoFastInsupport() @inline getdof(μ::Joint) = getdof(μ.α) + getdof(μ.β) @@ -106,35 +103,17 @@ rootmeasure(::Joint) = mcombine(μ.f_c rootmeasure(μ), rootmeasure(ν)) basemeasure(::Joint) = mcombine(μ.f_c basemeasure(μ), basemeasure(ν)) logdensity_def(::Joint, ab) - # Use _local_split_combined to avoid duplicate calculation of localmeasure(α): - local_α, a, b = _local_split_combined(μ.f_c, μ.α, ab) + # Use _tpmeasure_split_combined to avoid duplicate calculation of transportmeasure(α): + local_α, a, b = _tpmeasure_split_combined(μ.f_c, μ.α, ab) return logdensity_def(local_α, a) + logdensity_def(μ.β, b) end # Specialize logdensityof directly to avoid creating temporary joint base measures: logdensityof(::Joint, ab) - # Use _local_split_combined to avoid duplicate calculation of localmeasure(α): - local_α, a, b = _local_split_combined(μ.f_c, μ.α, ab) + local_α, a, b = tpmeasure_split_combined(μ.f_c, μ.α, ab) return logdensityof(local_α, a) + logdensityof(μ.β, b) end -function _local_split_combined(f_c, α::AbstractMeasure, ab) - a, b = _generic_split_combined(f_c, α, ab) - return localmeasure(α, a), a, b -end - - - -# # TODO: Default implementation of unsafe_logdensityof is a bit inefficient -# # for AutoFlatten, since variate will be split in `localmeasure` and then -# # split again in log-density evaluation. Maybe add something like -# function unsafe_logdensityof(h::Joint, x) -# local_primary, local_secondary, x_primary, x_secondary = ... -# # Need to call full logdensityof for h_secondary since x_secondary hasn't -# # been checked yet: -# unsafe_logdensityof(local_primary, x_primary) + logdensityof(local_secondary, x_secondary) -# end - function Base.rand(rng::Random.AbstractRNG, ::Type{T}, h::Joint) where {T<:Real} x_primary = rand(rng, T, h.m) @@ -142,6 +121,7 @@ function Base.rand(rng::Random.AbstractRNG, ::Type{T}, h::Joint) where {T<:Real} return _combine_variates(h.flatten_mode, x_primary, x_secondary) end +#!!!!!!!!!!!!!!!!!!! TODO: function _to_std_with_rest(flatten_mode::FlattenMode, ν_inner::StdMeasure, μ::Joint, x) μ_primary = μ.m diff --git a/src/density-core.jl b/src/density-core.jl index 626c6ced..849e636c 100644 --- a/src/density-core.jl +++ b/src/density-core.jl @@ -1,5 +1,3 @@ -export localmeasure - export logdensityof export logdensity_rel export logdensity_def @@ -15,19 +13,35 @@ export density_def """ localmeasure(m::AbstractMeasure, x)::AbstractMeasure -Return a local measure of `m` at `x` which will be `m` itself for many -measures. - -A local measure of `m` is defined here as a measure that behaves like `m` in -the infinitesimal neighborhood of `x`. +Return a measure that behaves like `m` in the infinitesimal neighborhood +of `x` in respect to density calculation. -Note that the resulting measure may not be well defined outside of such a -neighborhood of `x`. +Note that the resulting measure may not be well defined outside of the +infinitesimal neighborhood of `x`. -See [`HierarchicalMeasure`](@ref) as an example of a measure where -`localmeasure` returns different measures depending on `x`. +For most measure types simply returns `m` itself. [`mbind`](@ref), +for example, generates measures for with `localmeasure(m, x)` depends +on `x`. """ localmeasure(m::AbstractMeasure, x) = m +export localmeasure + + +""" + MeasureBase.transportmeasure(μ::Bind, x)::AbstractMeasure + +Return a measure that behaves like `m` in the infinitesimal neighborhood +of `x` in respect to both transport and density calculation. + +Note that the resulting measure may not be well defined outside of the +infinitesimal neighborhood of `x`. + +For most measure types simply returns `m` itself. [`mbind`](@ref), +for example, generates measures for with `transportmeasure(m, x)` depends +on `x`. +""" +transportmeasure(m::AbstractMeasure, x) = m +export localmeasure """ From 65192347bf927f3e360837768b4ad401728680a4 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Fri, 30 Jun 2023 05:54:54 +0200 Subject: [PATCH 031/133] STASH --- src/combinators/bind.jl | 40 +++++++++------------------------------- 1 file changed, 9 insertions(+), 31 deletions(-) diff --git a/src/combinators/bind.jl b/src/combinators/bind.jl index c33c4fb9..670f0609 100644 --- a/src/combinators/bind.jl +++ b/src/combinators/bind.jl @@ -197,44 +197,22 @@ end - -#!!!!!!!!!!!!!!!!!! TODO: - - -function _to_std_with_rest(flatten_mode::FlattenMode, ν_inner::StdMeasure, μ::Bind, x) - - μ_primary = μ.m - y_primary, x_secondary = _to_std_with_rest(flatten_mode, ν_inner, μ_primary, x) - μ_secondary = μ.f(x_secondary) - y_secondary, x_rest = _to_std_with_rest(flatten_mode, ν_inner, μ_secondary, x_secondary) - return _combine_variates(μ.flatten_mode, y_primary, y_secondary), x_rest -end - -function _to_std_with_rest(flatten_mode::FlattenMode, ν_inner::StdMeasure, μ::AbstractMeasure, x) - dof_μ = getdof(μ) - x_μ, x_rest = split_combined(flatten_mode, μ, x) - y = transport_to(ν_inner^dof_μ, μ, x_μ) - return y, x_rest -end - function transport_def(ν::_PowerStdMeasure{1}, μ::Bind, x) - ν_inner = _get_inner_stdmeasure(ν) - y, x_rest = _to_std_with_rest(ν_inner, μ, x) - if !isempty(x_rest) - throw(ArgumentError("Variate too long during transport involving Bind")) - end - return y + tpm_μ = transportmeasure(μ, x) + return transport_def(ν, tpm_μ, x) end function _from_std_with_rest(ν::Bind, μ_inner::StdMeasure, x) - ν_primary = ν.m - y_primary, x_secondary = _from_std_with_rest(ν_primary, μ_inner, x) - ν_secondary = ν.f(y_primary) - y_secondary, x_rest = _from_std_with_rest(ν_secondary, μ_inner, x_secondary) - return _combine_variates(ν.flatten_mode, y_primary, y_secondary), x_rest + a, x2 = _from_std_with_rest(ν.α, μ_inner, x) + β_a = ν.f_β(a) + b, x_rest = _from_std_with_rest(β_a, μ_inner, x2) + return ν.f_c(a, b), x_rest end +# !!!!!!!!!!!!! TODO How to handle PushforwardMeasure, WeightedMeasure and +# so on that contain a Bind? + function _from_std_with_rest(ν::AbstractMeasure, μ_inner::StdMeasure, x) dof_ν = getdof(ν) len_x = length(eachindex(x)) From 7a4e85d13d58fb12522cd89775c2879bcd634902 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Fri, 30 Jun 2023 09:51:41 +0200 Subject: [PATCH 032/133] STASH --- src/combinators/bind.jl | 45 +++++++++++++++++++++++++++-------------- 1 file changed, 30 insertions(+), 15 deletions(-) diff --git a/src/combinators/bind.jl b/src/combinators/bind.jl index 670f0609..f67d5e2b 100644 --- a/src/combinators/bind.jl +++ b/src/combinators/bind.jl @@ -119,15 +119,15 @@ end localmeasure(μ::Bind, x) = transportmeasure(μ, x) -tpmeasure_split_combined(f_c, μ::Bind, xy) = _bind_lsc(f_c, μ::Bind, xy) +tpmeasure_split_combined(f_c, μ::Bind, xy) = _bind_tsc(f_c, μ::Bind, xy) -function _bind_lsc(f_c::typeof(tuple), μ::Bind, xy::Tuple{Vararg{Any,2}}) +function _bind_tsc(f_c::typeof(tuple), μ::Bind, xy::Tuple{Vararg{Any,2}}) x, y = x[1], y[1] tpm_μ = transportmeasure(μ, x) return tpm_μ, x, y end -function _bind_lsc(f_c::Type{Pair}, μ::Bind, xy::Pair) +function _bind_tsc(f_c::Type{Pair}, μ::Bind, xy::Pair) x, y = x.first, y.second tpm_μ = transportmeasure(μ, x) return tpm_μ, x, y @@ -135,10 +135,10 @@ end const _CatBind{FC} = _BindBy{<:Any,<:Any,FC} -_bind_lsc(f_c::typeof(vcat), μ::_CatBind{typeof{vcat}}, xy::AbstractVector) = _bind_lsc_cat(f_c, μ, xy) -_bind_lsc(f_c::typeof(merge), μ::_CatBind{typeof{merge}}, xy::NamedTuple) = _bind_lsc_cat(f_c, μ, xy) +_bind_tsc(f_c::typeof(vcat), μ::_CatBind{typeof{vcat}}, xy::AbstractVector) = _bind_tsc_cat(f_c, μ, xy) +_bind_tsc(f_c::typeof(merge), μ::_CatBind{typeof{merge}}, xy::NamedTuple) = _bind_tsc_cat(f_c, μ, xy) -function _bind_lsc_cat_lμabyxy(f_c, μ, xy) +function _bind_tsc_cat_lμabyxy(f_c, μ, xy) tpm_α, a, by = tpmeasure_split_combined(μ.f_c, μ.α, xy) β_a = μ.f_β(a) tpm_β_a, b, y = tpmeasure_split_combined(f_c, β_a, by) @@ -146,15 +146,15 @@ function _bind_lsc_cat_lμabyxy(f_c, μ, xy) return tpm_μ, a, b, y, xy end -function _bind_lsc_cat(f_c::typeof(vcat), μ::_CatBind{typeof{vcat}}, xy::AbstractVector) - tpm_μ, a, b, y, xy = _bind_lsc_cat_lμabyxy(f_c, μ, xy) +function _bind_tsc_cat(f_c::typeof(vcat), μ::_CatBind{typeof{vcat}}, xy::AbstractVector) + tpm_μ, a, b, y, xy = _bind_tsc_cat_lμabyxy(f_c, μ, xy) # Don't use `x = f_c(a, b)` here, would allocate, splitting xy can use views: x, y = _split_after(xy, length(a) + length(b)) return tpm_μ, x, y end -function _bind_lsc_cat(f_c::typeof(merge), μ::_CatBind{typeof{merge}}, xy::NamedTuple) - tpm_μ, a, b, y, xy = _bind_lsc_cat_lμabyxy(f_c, μ, xy) +function _bind_tsc_cat(f_c::typeof(merge), μ::_CatBind{typeof{merge}}, xy::NamedTuple) + tpm_μ, a, b, y, xy = _bind_tsc_cat_lμabyxy(f_c, μ, xy) return tpm_μ, f_c(a, b), y end @@ -210,11 +210,13 @@ function _from_std_with_rest(ν::Bind, μ_inner::StdMeasure, x) return ν.f_c(a, b), x_rest end -# !!!!!!!!!!!!! TODO How to handle PushforwardMeasure, WeightedMeasure and -# so on that contain a Bind? - function _from_std_with_rest(ν::AbstractMeasure, μ_inner::StdMeasure, x) dof_ν = getdof(ν) + origin = transport_origin(ν) + return _from_std_with_rest_withdof(ν, getdof(ν), μ_inner, x, dof_ν, origin) +end + +function _from_std_with_rest_withdof(ν::AbstractMeasure, dof_ν, μ_inner::StdMeasure, x) len_x = length(eachindex(x)) # Since we can't check DOF of original Bind, we could "run out x" if @@ -224,11 +226,24 @@ function _from_std_with_rest(ν::AbstractMeasure, μ_inner::StdMeasure, x) throw(ArgumentError("Variate too short during transport involving Bind")) end - y = transport_to(ν, μ_inner^dof_ν, x[begin:begin+dof_ν-1]) - x_rest = Fill(zero(eltype(x)), dof_ν - len_x) + x_inner_dof, x_rest = _split_after(x, dof_ν) + y = transport_to(ν, μ_inner^dof_ν, x_inner_dof) return y, x_rest end +function _from_std_with_rest_withdof(ν::AbstractMeasure, ::NoDOF, μ_inner::StdMeasure, x) + _from_std_with_rest_withorigin(ν, transport_origin(ν), μ_inner, x) +end + +function _from_std_with_rest_withorigin(ν::AbstractMeasure, ν_origin, μ_inner::StdMeasure, x) + x_origin, x_rest = _from_std_with_rest(ν_origin, x, μ_inner) + from_origin(x_origin), x_rest +end + +function _from_std_with_rest_withorigin(ν::AbstractMeasure, NoTransportOrigin, μ_inner::StdMeasure, x) + throw(ArgumentError("Don't know how to transport value of type $(nameof(typeof(x))) from power of $(nameof(typeof(μ_inner))) to $(nameof(typeof(ν)))")) +end + function transport_def(ν::Bind, μ::_PowerStdMeasure{1}, x) # Sanity check, should be checked by transport machinery already: @assert getdof(μ) == length(eachindex(x)) && x isa AbstractVector From a94adf8152f996c90b11943495d842e47ad1663d Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Fri, 30 Jun 2023 11:05:21 +0200 Subject: [PATCH 033/133] STASH _to_mvstd _to_mvstd --- src/combinators/bind.jl | 67 +++++----------- src/combinators/combined.jl | 147 ++++++++++++++++++++++-------------- 2 files changed, 106 insertions(+), 108 deletions(-) diff --git a/src/combinators/bind.jl b/src/combinators/bind.jl index f67d5e2b..899c6633 100644 --- a/src/combinators/bind.jl +++ b/src/combinators/bind.jl @@ -196,61 +196,28 @@ function Base.rand(rng::Random.AbstractRNG, ::Type{T}, μ::Bind) where {T<:Real} end - -function transport_def(ν::_PowerStdMeasure{1}, μ::Bind, x) - tpm_μ = transportmeasure(μ, x) - return transport_def(ν, tpm_μ, x) +function transport_def(ν::_PowerStdMeasure{1}, μ::Bind, ab) + ν_inner = _get_inner_stdmeasure(ν) + _to_mvstd(ν_inner, μ, ab) end - -function _from_std_with_rest(ν::Bind, μ_inner::StdMeasure, x) - a, x2 = _from_std_with_rest(ν.α, μ_inner, x) - β_a = ν.f_β(a) - b, x_rest = _from_std_with_rest(β_a, μ_inner, x2) - return ν.f_c(a, b), x_rest -end - -function _from_std_with_rest(ν::AbstractMeasure, μ_inner::StdMeasure, x) - dof_ν = getdof(ν) - origin = transport_origin(ν) - return _from_std_with_rest_withdof(ν, getdof(ν), μ_inner, x, dof_ν, origin) -end - -function _from_std_with_rest_withdof(ν::AbstractMeasure, dof_ν, μ_inner::StdMeasure, x) - len_x = length(eachindex(x)) - - # Since we can't check DOF of original Bind, we could "run out x" if - # the original x was too short. `transport_to` below will detect this, but better - # throw a more informative exception here: - if len_x < dof_ν - throw(ArgumentError("Variate too short during transport involving Bind")) - end - - x_inner_dof, x_rest = _split_after(x, dof_ν) - y = transport_to(ν, μ_inner^dof_ν, x_inner_dof) - return y, x_rest -end - -function _from_std_with_rest_withdof(ν::AbstractMeasure, ::NoDOF, μ_inner::StdMeasure, x) - _from_std_with_rest_withorigin(ν, transport_origin(ν), μ_inner, x) -end - -function _from_std_with_rest_withorigin(ν::AbstractMeasure, ν_origin, μ_inner::StdMeasure, x) - x_origin, x_rest = _from_std_with_rest(ν_origin, x, μ_inner) - from_origin(x_origin), x_rest +function _to_mvstd(ν_inner::StdMeasure, μ::Bind, ab) + tpm_α, a, b = tpmeasure_split_combined(μ.f_c, μ.α, ab) + β_a = μ.f_β(a) + y1 = _to_mvstd(ν_inner, tpm_α, a) + y2 = _to_mvstd(ν_inner, β_a, b) + return vcat(y1, y2) end -function _from_std_with_rest_withorigin(ν::AbstractMeasure, NoTransportOrigin, μ_inner::StdMeasure, x) - throw(ArgumentError("Don't know how to transport value of type $(nameof(typeof(x))) from power of $(nameof(typeof(μ_inner))) to $(nameof(typeof(ν)))")) -end function transport_def(ν::Bind, μ::_PowerStdMeasure{1}, x) - # Sanity check, should be checked by transport machinery already: - @assert getdof(μ) == length(eachindex(x)) && x isa AbstractVector μ_inner = _get_inner_stdmeasure(μ) - y, x_rest = _from_std_with_rest(ν, μ_inner, x) - if !isempty(x_rest) - throw(ArgumentError("Variate too long during transport involving Bind")) - end - return y + _from_mvstd(ν, μ_inner, x) +end + +function _from_mvstd_with_rest(ν::Bind, μ_inner::StdMeasure, x) + a, x2 = _from_mvstd_with_rest(ν.α, μ_inner, x) + β_a = ν.f_β(a) + b, x_rest = _from_mvstd_with_rest(β_a, μ_inner, x2) + return ν.f_c(a, b), x_rest end diff --git a/src/combinators/combined.jl b/src/combinators/combined.jl index 9c67e75e..fc5966d8 100644 --- a/src/combinators/combined.jl +++ b/src/combinators/combined.jl @@ -12,7 +12,7 @@ With `a_orig = rand(α)`, `b_orig = rand(β)` and `ab = f_c(a_orig, b_orig)`, the following must hold true: ```julia -local_α, a, b = tpmeasure_split_combined(f_c, α, ab) +tpm_α, a, b = tpmeasure_split_combined(f_c, α, ab) a ≈ a_orig && b ≈ b_orig ``` """ @@ -23,8 +23,8 @@ function tpmeasure_split_combined(f_c, α::AbstractMeasure, ab) return transportmeasure(α, a), a, b end -@inline _generic_split_combined(::typeof(tuple), @nospecialize(α::AbstractMeasure), x::Tuple{Vararg{Any,2}}) -@inline _generic_split_combined(::Type{Pair}, @nospecialize(α::AbstractMeasure), ab::Pair) = (ab...,) +@inline _generic_split_combined(::typeof(tuple), ::AbstractMeasure, x::Tuple{Vararg{Any,2}}) +@inline _generic_split_combined(::Type{Pair}, ::AbstractMeasure, ab::Pair) = (ab...,) function _generic_split_combined(f_c::FC, α::AbstractMeasure, ab) where FC _split_variate_byvalue(f_c, testvalue(μ), x) @@ -42,7 +42,7 @@ end """ mcombine(f_c, α::AbstractMeasure, β::AbstractMeasure) -Combines two measures `α` and `β` to a joint measure via a point combination +Combines two measures `α` and `β` to a combined measure via a point combination function `f_c`. `f_c` must combine a given point `a` from the space of measure `α` with a @@ -79,11 +79,11 @@ end Represents a combination of two measures. -User code should not create instances of `Joint` directly, but should call +User code should not create instances of `Combined` directly, but should call [`mcombine(f_c, α, β)`](@ref) instead. """ -Combined{FC,MA<:AbstractMeasure,MB<:AbstractMeasure} <: AbstractMeasure +struct Combined{FC,MA<:AbstractMeasure,MB<:AbstractMeasure} <: AbstractMeasure f_c::FC α::MA β::MB @@ -91,94 +91,125 @@ end # TODO: Could split `ab`` here, but would be wasteful. -@inline insupport(::Joint, ab) = NoFastInsupport() +@inline insupport(::Combined, ab) = NoFastInsupport() -@inline getdof(μ::Joint) = getdof(μ.α) + getdof(μ.β) +@inline getdof(μ::Combined) = getdof(μ.α) + getdof(μ.β) # Bypass `checked_arg`, would require require splitting ab: -@inline checked_arg(::Joint, ab) = ab +@inline checked_arg(::Combined, ab) = ab -rootmeasure(::Joint) = mcombine(μ.f_c rootmeasure(μ), rootmeasure(ν)) +rootmeasure(::Combined) = mcombine(μ.f_c rootmeasure(μ), rootmeasure(ν)) -basemeasure(::Joint) = mcombine(μ.f_c basemeasure(μ), basemeasure(ν)) +basemeasure(::Combined) = mcombine(μ.f_c basemeasure(μ), basemeasure(ν)) -logdensity_def(::Joint, ab) - # Use _tpmeasure_split_combined to avoid duplicate calculation of transportmeasure(α): - local_α, a, b = _tpmeasure_split_combined(μ.f_c, μ.α, ab) - return logdensity_def(local_α, a) + logdensity_def(μ.β, b) +function logdensity_def(μ::Combined, ab) + # Use tpmeasure_split_combined to avoid duplicate calculation of transportmeasure(α): + tpm_α, a, b = tpmeasure_split_combined(μ.f_c, μ.α, ab) + return logdensity_def(tpm_α, a) + logdensity_def(μ.β, b) end -# Specialize logdensityof directly to avoid creating temporary joint base measures: -logdensityof(::Joint, ab) - local_α, a, b = tpmeasure_split_combined(μ.f_c, μ.α, ab) - return logdensityof(local_α, a) + logdensityof(μ.β, b) +# Specialize logdensityof directly to avoid creating temporary combined base measures: +function logdensityof(μ::Combined, ab) + tpm_α, a, b = tpmeasure_split_combined(μ.f_c, μ.α, ab) + return logdensityof(tpm_α, a) + logdensityof(μ.β, b) end -function Base.rand(rng::Random.AbstractRNG, ::Type{T}, h::Joint) where {T<:Real} +function Base.rand(rng::Random.AbstractRNG, ::Type{T}, h::Combined) where {T<:Real} x_primary = rand(rng, T, h.m) x_secondary = rand(rng, T, h.f(x_primary)) return _combine_variates(h.flatten_mode, x_primary, x_secondary) end -#!!!!!!!!!!!!!!!!!!! TODO: -function _to_std_with_rest(flatten_mode::FlattenMode, ν_inner::StdMeasure, μ::Joint, x) - μ_primary = μ.m - y_primary, x_secondary = _to_std_with_rest(flatten_mode, ν_inner, μ_primary, x) - μ_secondary = μ.f(x_secondary) - y_secondary, x_rest = _to_std_with_rest(flatten_mode, ν_inner, μ_secondary, x_secondary) - return _combine_variates(μ.flatten_mode, y_primary, y_secondary), x_rest +function transport_def(ν::_PowerStdMeasure{1}, μ::Combined, ab) + ν_inner = _get_inner_stdmeasure(ν) + _to_mvstd(ν_inner, μ, ab) end -function _to_std_with_rest(flatten_mode::FlattenMode, ν_inner::StdMeasure, μ::AbstractMeasure, x) - dof_μ = getdof(μ) - x_μ, x_rest = _generic_split_combined(flatten_mode, μ, x) - y = transport_to(ν_inner^dof_μ, μ, x_μ) - return y, x_rest +function _to_mvstd(ν_inner::StdMeasure, μ::Combined, ab) + tpm_α, a, b = tpmeasure_split_combined(μ.f_c, μ.α, ab) + y1 = _to_mvstd(ν_inner, tpm_α, a) + y2 = _to_mvstd(ν_inner, μ.β, b) + return vcat(y1, y2) end -function transport_def(ν::_PowerStdMeasure{1}, μ::Joint, x) - ν_inner = _get_inner_stdmeasure(ν) - y, x_rest = _to_std_with_rest(ν_inner, μ, x) - if !isempty(x_rest) - throw(ArgumentError("Variate too long during transport involving Joint")) - end + +function transport_def(ν::Combined, μ::_PowerStdMeasure{1}, x) + μ_inner = _get_inner_stdmeasure(μ) + _from_mvstd(ν, μ_inner, x) +end + +function _from_mvstd_with_rest(ν::Combined, μ_inner::StdMeasure, x) + a, x2 = _from_mvstd_with_rest(ν.α, μ_inner, x) + b, x_rest = _from_mvstd_with_rest(ν.β, μ_inner, x2) + return ν.f_c(a, b), x_rest +end + + +function _to_mvstd(ν_inner::StdMeasure, μ::AbstractMeasure, x) + return _to_mvstd_withdof(ν_inner, μ, getdof(μ), x, origin) +end + +function _to_mvstd_withdof(ν_inner::StdMeasure, μ::AbstractMeasure, dof_μ, x) + y = transport_to(ν_inner^dof_μ, μ, x) return y end +function _to_mvstd_withdof(ν_inner::StdMeasure, μ::AbstractMeasure, ::NoDOF, x) + _to_mvstd_withorigin(ν_inner, μ, transport_origin(μ), x) +end -function _from_std_with_rest(ν::Joint, μ_inner::StdMeasure, x) - ν_primary = ν.m - y_primary, x_secondary = _from_std_with_rest(ν_primary, μ_inner, x) - ν_secondary = ν.f(y_primary) - y_secondary, x_rest = _from_std_with_rest(ν_secondary, μ_inner, x_secondary) - return _combine_variates(ν.flatten_mode, y_primary, y_secondary), x_rest +function _to_mvstd_withorigin(ν_inner::StdMeasure, ::AbstractMeasure, μ_origin, x) + x_origin = _to_mvstd(ν_inner, μ_origin, x) + from_origin(x_origin) end -function _from_std_with_rest(ν::AbstractMeasure, μ_inner::StdMeasure, x) +function _to_mvstd_withorigin(ν_inner::StdMeasure, μ::AbstractMeasure, NoTransportOrigin, x) + throw(ArgumentError("Don't know how to transport values of type $(nameof(typeof(x))) from $(nameof(typeof(μ))) to a power of $(nameof(typeof(ν_inner)))")) +end + + +function from_std(ν::AbstractMeasure, μ_inner::StdMeasure, x) + # Sanity check, should be checked by transport machinery already: + @assert getdof(μ) == length(eachindex(x)) && x isa AbstractVector + y, x_rest = _from_mvstd_with_rest(ν, μ_inner, x) + if !isempty(x_rest) + throw(ArgumentError("Input value too long during transport")) + end + return y +end + +function _from_mvstd_with_rest(ν::AbstractMeasure, μ_inner::StdMeasure, x) dof_ν = getdof(ν) + origin = transport_origin(ν) + return _from_mvstd_with_rest_withdof(ν, getdof(ν), μ_inner, x, dof_ν, origin) +end + +function _from_mvstd_with_rest_withdof(ν::AbstractMeasure, dof_ν, μ_inner::StdMeasure, x) len_x = length(eachindex(x)) - # Since we can't check DOF of original Joint, we could "run out x" if + # Since we can't check DOF of original Bind, we could "run out x" if # the original x was too short. `transport_to` below will detect this, but better # throw a more informative exception here: if len_x < dof_ν - throw(ArgumentError("Variate too short during transport involving Joint")) + throw(ArgumentError("Variate too short during transport involving Bind")) end - y = transport_to(ν, μ_inner^dof_ν, x[begin:begin+dof_ν-1]) - x_rest = Fill(zero(eltype(x)), dof_ν - len_x) + x_inner_dof, x_rest = _split_after(x, dof_ν) + y = transport_to(ν, μ_inner^dof_ν, x_inner_dof) return y, x_rest end -function transport_def(ν::Joint, μ::_PowerStdMeasure{1}, x) - # Sanity check, should be checked by transport machinery already: - @assert getdof(μ) == length(eachindex(x)) && x isa AbstractVector - μ_inner = _get_inner_stdmeasure(μ) - y, x_rest = _from_std_with_rest(ν, μ_inner, x) - if !isempty(x_rest) - throw(ArgumentError("Variate too long during transport involving Joint")) - end - return y +function _from_mvstd_with_rest_withdof(ν::AbstractMeasure, ::NoDOF, μ_inner::StdMeasure, x) + _from_mvstd_with_rest_withorigin(ν, transport_origin(ν), μ_inner, x) +end + +function _from_mvstd_with_rest_withorigin(::AbstractMeasure, ν_origin, μ_inner::StdMeasure, x) + x_origin, x_rest = _from_mvstd_with_rest(ν_origin, x, μ_inner) + from_origin(x_origin), x_rest +end + +function _from_mvstd_with_rest_withorigin(ν::AbstractMeasure, NoTransportOrigin, μ_inner::StdMeasure, x) + throw(ArgumentError("Don't know how to transport value of type $(nameof(typeof(x))) from power of $(nameof(typeof(μ_inner))) to $(nameof(typeof(ν)))")) end From d69b9a0ea75c9038c0dbc0560b7a2e432d64f1d1 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Fri, 30 Jun 2023 11:44:38 +0200 Subject: [PATCH 034/133] STASH --- src/combinators/bind.jl | 22 ++---- src/combinators/combined.jl | 41 ++++++------ src/standard/stdmeasure.jl | 129 ++++++++++++++++++++++++++++++------ 3 files changed, 135 insertions(+), 57 deletions(-) diff --git a/src/combinators/bind.jl b/src/combinators/bind.jl index 899c6633..b6831a17 100644 --- a/src/combinators/bind.jl +++ b/src/combinators/bind.jl @@ -196,28 +196,18 @@ function Base.rand(rng::Random.AbstractRNG, ::Type{T}, μ::Bind) where {T<:Real} end -function transport_def(ν::_PowerStdMeasure{1}, μ::Bind, ab) - ν_inner = _get_inner_stdmeasure(ν) - _to_mvstd(ν_inner, μ, ab) -end - -function _to_mvstd(ν_inner::StdMeasure, μ::Bind, ab) +function transport_to_mvstd(ν_inner::StdMeasure, μ::Bind, ab) tpm_α, a, b = tpmeasure_split_combined(μ.f_c, μ.α, ab) β_a = μ.f_β(a) - y1 = _to_mvstd(ν_inner, tpm_α, a) - y2 = _to_mvstd(ν_inner, β_a, b) + y1 = transport_to_mvstd(ν_inner, tpm_α, a) + y2 = transport_to_mvstd(ν_inner, β_a, b) return vcat(y1, y2) end -function transport_def(ν::Bind, μ::_PowerStdMeasure{1}, x) - μ_inner = _get_inner_stdmeasure(μ) - _from_mvstd(ν, μ_inner, x) -end - -function _from_mvstd_with_rest(ν::Bind, μ_inner::StdMeasure, x) - a, x2 = _from_mvstd_with_rest(ν.α, μ_inner, x) +function transport_from_mvstd_with_rest(ν::Bind, μ_inner::StdMeasure, x) + a, x2 = transport_from_mvstd_with_rest(ν.α, μ_inner, x) β_a = ν.f_β(a) - b, x_rest = _from_mvstd_with_rest(β_a, μ_inner, x2) + b, x_rest = transport_from_mvstd_with_rest(β_a, μ_inner, x2) return ν.f_c(a, b), x_rest end diff --git a/src/combinators/combined.jl b/src/combinators/combined.jl index fc5966d8..12029196 100644 --- a/src/combinators/combined.jl +++ b/src/combinators/combined.jl @@ -122,32 +122,28 @@ function Base.rand(rng::Random.AbstractRNG, ::Type{T}, h::Combined) where {T<:Re end -function transport_def(ν::_PowerStdMeasure{1}, μ::Combined, ab) - ν_inner = _get_inner_stdmeasure(ν) - _to_mvstd(ν_inner, μ, ab) -end -function _to_mvstd(ν_inner::StdMeasure, μ::Combined, ab) +function transport_to_mvstd(ν_inner::StdMeasure, μ::Combined, ab) tpm_α, a, b = tpmeasure_split_combined(μ.f_c, μ.α, ab) - y1 = _to_mvstd(ν_inner, tpm_α, a) - y2 = _to_mvstd(ν_inner, μ.β, b) + y1 = transport_to_mvstd(ν_inner, tpm_α, a) + y2 = transport_to_mvstd(ν_inner, μ.β, b) return vcat(y1, y2) end -function transport_def(ν::Combined, μ::_PowerStdMeasure{1}, x) - μ_inner = _get_inner_stdmeasure(μ) - _from_mvstd(ν, μ_inner, x) -end - -function _from_mvstd_with_rest(ν::Combined, μ_inner::StdMeasure, x) - a, x2 = _from_mvstd_with_rest(ν.α, μ_inner, x) - b, x_rest = _from_mvstd_with_rest(ν.β, μ_inner, x2) +function transport_from_mvstd_with_rest(ν::Combined, μ_inner::StdMeasure, x) + a, x2 = transport_from_mvstd_with_rest(ν.α, μ_inner, x) + b, x_rest = transport_from_mvstd_with_rest(ν.β, μ_inner, x2) return ν.f_c(a, b), x_rest end -function _to_mvstd(ν_inner::StdMeasure, μ::AbstractMeasure, x) +function transport_def(ν::_PowerStdMeasure{1}, μ::AbstractMeasure, ab) + ν_inner = _get_inner_stdmeasure(ν) + transport_to_mvstd(ν_inner, μ, ab) +end + +function transport_to_mvstd(ν_inner::StdMeasure, μ::AbstractMeasure, x) return _to_mvstd_withdof(ν_inner, μ, getdof(μ), x, origin) end @@ -161,7 +157,7 @@ function _to_mvstd_withdof(ν_inner::StdMeasure, μ::AbstractMeasure, ::NoDOF, x end function _to_mvstd_withorigin(ν_inner::StdMeasure, ::AbstractMeasure, μ_origin, x) - x_origin = _to_mvstd(ν_inner, μ_origin, x) + x_origin = transport_to_mvstd(ν_inner, μ_origin, x) from_origin(x_origin) end @@ -170,17 +166,22 @@ function _to_mvstd_withorigin(ν_inner::StdMeasure, μ::AbstractMeasure, NoTrans end +function transport_def(ν::AbstractMeasure, μ::_PowerStdMeasure{1}, x) + μ_inner = _get_inner_stdmeasure(μ) + _from_mvstd(ν, μ_inner, x) +end + function from_std(ν::AbstractMeasure, μ_inner::StdMeasure, x) # Sanity check, should be checked by transport machinery already: @assert getdof(μ) == length(eachindex(x)) && x isa AbstractVector - y, x_rest = _from_mvstd_with_rest(ν, μ_inner, x) + y, x_rest = transport_from_mvstd_with_rest(ν, μ_inner, x) if !isempty(x_rest) throw(ArgumentError("Input value too long during transport")) end return y end -function _from_mvstd_with_rest(ν::AbstractMeasure, μ_inner::StdMeasure, x) +function transport_from_mvstd_with_rest(ν::AbstractMeasure, μ_inner::StdMeasure, x) dof_ν = getdof(ν) origin = transport_origin(ν) return _from_mvstd_with_rest_withdof(ν, getdof(ν), μ_inner, x, dof_ν, origin) @@ -206,7 +207,7 @@ function _from_mvstd_with_rest_withdof(ν::AbstractMeasure, ::NoDOF, μ_inner::S end function _from_mvstd_with_rest_withorigin(::AbstractMeasure, ν_origin, μ_inner::StdMeasure, x) - x_origin, x_rest = _from_mvstd_with_rest(ν_origin, x, μ_inner) + x_origin, x_rest = transport_from_mvstd_with_rest(ν_origin, x, μ_inner) from_origin(x_origin), x_rest end diff --git a/src/standard/stdmeasure.jl b/src/standard/stdmeasure.jl index 0df8977f..f53e1083 100644 --- a/src/standard/stdmeasure.jl +++ b/src/standard/stdmeasure.jl @@ -1,9 +1,9 @@ abstract type StdMeasure <: AbstractMeasure end -const _PowerStdMeasure{N,M<:StdMeasure} = PowerMeasure{M,<:NTuple{N,Base.OneTo}} +const _PowerStdMeasure{N,MU<:StdMeasure} = PowerMeasure{MU,<:NTuple{N,Base.OneTo}} -_get_inner_stdmeasure(μ::_PowerStdMeasure{N,M}) where {N,M} = M() +_get_inner_stdmeasure(::_PowerStdMeasure{N,MU}) where {N,MU} = M() StdMeasure(::typeof(rand)) = StdUniform() @@ -22,43 +22,130 @@ function transport_def(ν::PowerMeasure{<:StdMeasure}, μ::StdMeasure, x) return fill_with(transport_def(ν.parent, μ, only(x)), map(length, ν.axes)) end -function transport_def( - ν::PowerMeasure{<:StdMeasure,<:NTuple{1,Base.OneTo}}, - μ::PowerMeasure{<:StdMeasure,<:NTuple{1,Base.OneTo}}, - x, -) +function transport_def(ν::_PowerStdMeasure{MU,1}, μ::_PowerStdMeasure{NU,1}, x,) where {MU,NU} return transport_to(ν.parent, μ.parent).(x) end -function transport_def( - ν::PowerMeasure{<:StdMeasure,<:NTuple{N,Base.OneTo}}, - μ::PowerMeasure{<:StdMeasure,<:NTuple{M,Base.OneTo}}, - x, -) where {N,M} - return reshape(transport_to(ν.parent, μ.parent).(x), map(length, ν.axes)...) +transport_origin(μ::_PowerStdMeasure{N}) = ν.parent^product(map(length, μ.axes)) + +function from_origin(μ::_PowerStdMeasure{N}, x_origin::AbstractVector{<:Real}) + return reshape(x_origin, map(length, μ.axes)...) +end + + +# Transport to a multivariate standard measure from any measure: + +function transport_def(ν::_PowerStdMeasure{1}, μ::AbstractMeasure, ab) + ν_inner = _get_inner_stdmeasure(ν) + transport_to_mvstd(ν_inner, μ, ab) +end + +function transport_to_mvstd(ν_inner::StdMeasure, μ::AbstractMeasure, x) + return _to_mvstd_withdof(ν_inner, μ, getdof(μ), x, origin) +end + +function _to_mvstd_withdof(ν_inner::StdMeasure, μ::AbstractMeasure, dof_μ, x) + y = transport_to(ν_inner^dof_μ, μ, x) + return y +end + +function _to_mvstd_withdof(ν_inner::StdMeasure, μ::AbstractMeasure, ::NoDOF, x) + _to_mvstd_withorigin(ν_inner, μ, transport_origin(μ), x) +end + +function _to_mvstd_withorigin(ν_inner::StdMeasure, ::AbstractMeasure, μ_origin, x) + x_origin = transport_to_mvstd(ν_inner, μ_origin, x) + from_origin(x_origin) +end + +function _to_mvstd_withorigin(ν_inner::StdMeasure, μ::AbstractMeasure, NoTransportOrigin, x) + throw(ArgumentError("Don't know how to transport values of type $(nameof(typeof(x))) from $(nameof(typeof(μ))) to a power of $(nameof(typeof(ν_inner)))")) +end + + +# Transport from a multivariate standard measure to any measure: + +function transport_def(ν::AbstractMeasure, μ::_PowerStdMeasure{1}, x) + μ_inner = _get_inner_stdmeasure(μ) + _transport_from_mvstd(ν, μ_inner, x) +end + +function _transport_from_mvstd(ν::AbstractMeasure, μ_inner::StdMeasure, x) + # Sanity check, should be checked by transport machinery already: + @assert getdof(μ) == length(eachindex(x)) && x isa AbstractVector + y, x_rest = transport_from_mvstd_with_rest(ν, μ_inner, x) + if !isempty(x_rest) + throw(ArgumentError("Input value too long during transport")) + end + return y +end + +function transport_from_mvstd_with_rest(ν::AbstractMeasure, μ_inner::StdMeasure, x) + dof_ν = getdof(ν) + origin = transport_origin(ν) + return _from_mvstd_with_rest_withdof(ν, getdof(ν), μ_inner, x, dof_ν, origin) end -# Implement transport_to(NU::Type{<:StdMeasure}, μ) and transport_to(ν, MU::Type{<:StdMeasure}): +function _from_mvstd_with_rest_withdof(ν::AbstractMeasure, dof_ν, μ_inner::StdMeasure, x) + len_x = length(eachindex(x)) + + # Since we can't check DOF of original Bind, we could "run out x" if + # the original x was too short. `transport_to` below will detect this, but better + # throw a more informative exception here: + if len_x < dof_ν + throw(ArgumentError("Variate too short during transport involving Bind")) + end + + x_inner_dof, x_rest = _split_after(x, dof_ν) + y = transport_to(ν, μ_inner^dof_ν, x_inner_dof) + return y, x_rest +end + +function _from_mvstd_with_rest_withdof(ν::AbstractMeasure, ::NoDOF, μ_inner::StdMeasure, x) + _from_mvstd_with_rest_withorigin(ν, transport_origin(ν), μ_inner, x) +end + +function _from_mvstd_with_rest_withorigin(::AbstractMeasure, ν_origin, μ_inner::StdMeasure, x) + x_origin, x_rest = transport_from_mvstd_with_rest(ν_origin, x, μ_inner) + from_origin(x_origin), x_rest +end + +function _from_mvstd_with_rest_withorigin(ν::AbstractMeasure, NoTransportOrigin, μ_inner::StdMeasure, x) + throw(ArgumentError("Don't know how to transport value of type $(nameof(typeof(x))) from power of $(nameof(typeof(μ_inner))) to $(nameof(typeof(ν)))")) +end + + +_empty_zero(::AbstractVector{T}) where {T<:Real} = Fill(zero(T), 0) + + +# Implement transport_to(NU::Type{<:StdMeasure}, μ) and transport_to(ν, MU::Type{<:StdMeasure}) +# for user convenience: + +# ToDo: Handle combined/bind measures that don't have a fast getdof! _std_measure(::Type{M}, ::StaticInteger{1}) where {M<:StdMeasure} = M() _std_measure(::Type{M}, dof::IntegerLike) where {M<:StdMeasure} = M()^dof _std_measure_for(::Type{M}, μ::Any) where {M<:StdMeasure} = _std_measure(M, getdof(μ)) +function transport_to(ν, ::Type{MU}) where {MU<:StdMeasure} + transport_to(ν, _std_measure_for(MU, ν)) +end + function transport_to(::Type{NU}, μ) where {NU<:StdMeasure} transport_to(_std_measure_for(NU, μ), μ) end -function transport_to(ν, ::Type{MU}) where {MU<:StdMeasure} - transport_to(ν, _std_measure_for(MU, ν)) -end + # Transform between standard measures and Dirac: -@inline transport_def(ν::Dirac, ::PowerMeasure{<:StdMeasure}, ::Any) = ν.x +@inline transport_from_mvstd_with_rest(ν::Dirac, ::StdMeasure, x::Any) = ν.x, x -@inline function transport_def(ν::PowerMeasure{<:StdMeasure}, ::Dirac, ::Any) - Zeros{Bool}(map(_ -> 0, ν.axes)) -end +@inline transport_to_mvstd(ν::PowerMeasure{<:StdMeasure}, ::Dirac, ::Any) = Zeros{Bool}(map(_ -> 0, ν.axes)) + + + +#!!!!!!!!!!!!!!!!!!!!!! TODO: # Helpers for product transforms and similar: From d9b3f0d7de2d603b4e81ce92f28ecba5b5bb805d Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Fri, 30 Jun 2023 11:45:47 +0200 Subject: [PATCH 035/133] STASH FIXUP --- src/combinators/combined.jl | 78 ------------------------------------- 1 file changed, 78 deletions(-) diff --git a/src/combinators/combined.jl b/src/combinators/combined.jl index 12029196..1040d099 100644 --- a/src/combinators/combined.jl +++ b/src/combinators/combined.jl @@ -136,81 +136,3 @@ function transport_from_mvstd_with_rest(ν::Combined, μ_inner::StdMeasure, x) b, x_rest = transport_from_mvstd_with_rest(ν.β, μ_inner, x2) return ν.f_c(a, b), x_rest end - - -function transport_def(ν::_PowerStdMeasure{1}, μ::AbstractMeasure, ab) - ν_inner = _get_inner_stdmeasure(ν) - transport_to_mvstd(ν_inner, μ, ab) -end - -function transport_to_mvstd(ν_inner::StdMeasure, μ::AbstractMeasure, x) - return _to_mvstd_withdof(ν_inner, μ, getdof(μ), x, origin) -end - -function _to_mvstd_withdof(ν_inner::StdMeasure, μ::AbstractMeasure, dof_μ, x) - y = transport_to(ν_inner^dof_μ, μ, x) - return y -end - -function _to_mvstd_withdof(ν_inner::StdMeasure, μ::AbstractMeasure, ::NoDOF, x) - _to_mvstd_withorigin(ν_inner, μ, transport_origin(μ), x) -end - -function _to_mvstd_withorigin(ν_inner::StdMeasure, ::AbstractMeasure, μ_origin, x) - x_origin = transport_to_mvstd(ν_inner, μ_origin, x) - from_origin(x_origin) -end - -function _to_mvstd_withorigin(ν_inner::StdMeasure, μ::AbstractMeasure, NoTransportOrigin, x) - throw(ArgumentError("Don't know how to transport values of type $(nameof(typeof(x))) from $(nameof(typeof(μ))) to a power of $(nameof(typeof(ν_inner)))")) -end - - -function transport_def(ν::AbstractMeasure, μ::_PowerStdMeasure{1}, x) - μ_inner = _get_inner_stdmeasure(μ) - _from_mvstd(ν, μ_inner, x) -end - -function from_std(ν::AbstractMeasure, μ_inner::StdMeasure, x) - # Sanity check, should be checked by transport machinery already: - @assert getdof(μ) == length(eachindex(x)) && x isa AbstractVector - y, x_rest = transport_from_mvstd_with_rest(ν, μ_inner, x) - if !isempty(x_rest) - throw(ArgumentError("Input value too long during transport")) - end - return y -end - -function transport_from_mvstd_with_rest(ν::AbstractMeasure, μ_inner::StdMeasure, x) - dof_ν = getdof(ν) - origin = transport_origin(ν) - return _from_mvstd_with_rest_withdof(ν, getdof(ν), μ_inner, x, dof_ν, origin) -end - -function _from_mvstd_with_rest_withdof(ν::AbstractMeasure, dof_ν, μ_inner::StdMeasure, x) - len_x = length(eachindex(x)) - - # Since we can't check DOF of original Bind, we could "run out x" if - # the original x was too short. `transport_to` below will detect this, but better - # throw a more informative exception here: - if len_x < dof_ν - throw(ArgumentError("Variate too short during transport involving Bind")) - end - - x_inner_dof, x_rest = _split_after(x, dof_ν) - y = transport_to(ν, μ_inner^dof_ν, x_inner_dof) - return y, x_rest -end - -function _from_mvstd_with_rest_withdof(ν::AbstractMeasure, ::NoDOF, μ_inner::StdMeasure, x) - _from_mvstd_with_rest_withorigin(ν, transport_origin(ν), μ_inner, x) -end - -function _from_mvstd_with_rest_withorigin(::AbstractMeasure, ν_origin, μ_inner::StdMeasure, x) - x_origin, x_rest = transport_from_mvstd_with_rest(ν_origin, x, μ_inner) - from_origin(x_origin), x_rest -end - -function _from_mvstd_with_rest_withorigin(ν::AbstractMeasure, NoTransportOrigin, μ_inner::StdMeasure, x) - throw(ArgumentError("Don't know how to transport value of type $(nameof(typeof(x))) from power of $(nameof(typeof(μ_inner))) to $(nameof(typeof(ν)))")) -end From c05c545463bb906041b7f537a3ae3bef7dbc6bf6 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Fri, 30 Jun 2023 13:28:41 +0200 Subject: [PATCH 036/133] STASH --- src/MeasureBase.jl | 15 ++++---- src/combinators/power.jl | 25 +++++++++++++ src/combinators/product.jl | 75 +++++++++++++++++++++++++++++++++++++ src/standard/stdmeasure.jl | 77 -------------------------------------- 4 files changed, 108 insertions(+), 84 deletions(-) diff --git a/src/MeasureBase.jl b/src/MeasureBase.jl index 1f2341f4..263caefe 100644 --- a/src/MeasureBase.jl +++ b/src/MeasureBase.jl @@ -140,11 +140,18 @@ include("primitives/lebesgue.jl") include("primitives/dirac.jl") include("primitives/trivial.jl") +include("combinators/power.jl") + +include("standard/stdmeasure.jl") +include("standard/stduniform.jl") +include("standard/stdexponential.jl") +include("standard/stdlogistic.jl") +include("standard/stdnormal.jl") + include("combinators/transformedmeasure.jl") include("combinators/weighted.jl") include("combinators/superpose.jl") include("combinators/product.jl") -include("combinators/power.jl") include("combinators/combined.jl") include("combinators/bind.jl") include("combinators/spikemixture.jl") @@ -152,12 +159,6 @@ include("combinators/likelihood.jl") include("combinators/restricted.jl") include("combinators/smart-constructors.jl") include("combinators/conditional.jl") - -include("standard/stdmeasure.jl") -include("standard/stduniform.jl") -include("standard/stdexponential.jl") -include("standard/stdlogistic.jl") -include("standard/stdnormal.jl") include("combinators/half.jl") include("rand.jl") diff --git a/src/combinators/power.jl b/src/combinators/power.jl index a7fa24f8..c0eb7a1c 100644 --- a/src/combinators/power.jl +++ b/src/combinators/power.jl @@ -147,3 +147,28 @@ function logdensity_def( ) where {P<:PrimitiveMeasure,N} static(0.0) end + + +# For transport, always pull back to one-dimensional PowerMeasure first: + +transport_origin(μ::PowerMeasure{<:Any,N}) where N = ν.parent^product(map(length, μ.axes)) + +function from_origin(μ::_PowerStdMeasure{<:Any,N}, x_origin) where N + # Sanity check, should never fail: + @assert x_origin isa AbstractVector + return reshape(x_origin, map(length, μ.axes)...) +end + + +# One-dimensional PowerMeasure has an origin iff it's parent has an origin: + +transport_origin(μ::PowerMeasure{<:AbstractMeasure,1}) = _origin_pwr(::typeof(μ), transport_origin(μ.parent), μ.axes) +_pwr_origin(::Type{MU}, parent_origin, axes) = parent_origin^axes +_pwr_origin(::Type{MU}, ::NoTransportOrigin, axes) = NoTransportOrigin{MU} + +function from_origin(μ::PowerMeasure{<:AbstractMeasure,1}, x_origin) + # Sanity check, should never fail: + @assert x_origin isa AbstractVector + from_origin.(Ref(μ.parent), x_origin) +end + diff --git a/src/combinators/product.jl b/src/combinators/product.jl index 516678f5..ed0970bc 100644 --- a/src/combinators/product.jl +++ b/src/combinators/product.jl @@ -225,3 +225,78 @@ function checked_arg( ) where {names} NamedTuple{names}(map(checked_arg, values(marginals(μ)), values(x))) end + + + +# Transport for products + + +#!!!!!!!!!!!!!!!!!!!!!! TODO: + +# Helpers for product transforms and similar: + +struct _TransportToStd{NU<:StdMeasure} <: Function end +_TransportToStd{NU}(μ, x) where {NU} = transport_to(NU()^getdof(μ), μ)(x) + +struct _TransportFromStd{MU<:StdMeasure} <: Function end +_TransportFromStd{MU}(ν, x) where {MU} = transport_to(ν, MU()^getdof(ν))(x) + +function _tuple_transport_def( + ν::PowerMeasure{NU}, + μs::Tuple, + xs::Tuple, +) where {NU<:StdMeasure} + reshape(vcat(map(_TransportToStd{NU}, μs, xs)...), ν.axes) +end + +function transport_def( + ν::PowerMeasure{NU}, + μ::ProductMeasure{<:Tuple}, + x, +) where {NU<:StdMeasure} + _tuple_transport_def(ν, marginals(μ), x) +end + +function transport_def( + ν::PowerMeasure{NU}, + μ::ProductMeasure{<:NamedTuple{names}}, + x, +) where {NU<:StdMeasure,names} + _tuple_transport_def(ν, values(marginals(μ)), values(x)) +end + +@inline _offset_cumsum(s, x, y, rest...) = (s, _offset_cumsum(s + x, y, rest...)...) +@inline _offset_cumsum(s, x) = (s,) +@inline _offset_cumsum(s) = () + +function _stdvar_viewranges(μs::Tuple, startidx::IntegerLike) + N = map(getdof, μs) + offs = _offset_cumsum(startidx, N...) + map((o, n) -> o:o+n-1, offs, N) +end + +function _tuple_transport_def( + νs::Tuple, + μ::PowerMeasure{MU}, + x::AbstractArray{<:Real}, +) where {MU<:StdMeasure} + vrs = _stdvar_viewranges(νs, firstindex(x)) + xs = map(r -> view(x, r), vrs) + map(_TransportFromStd{MU}, νs, xs) +end + +function transport_def( + ν::ProductMeasure{<:Tuple}, + μ::PowerMeasure{MU}, + x, +) where {MU<:StdMeasure} + _tuple_transport_def(marginals(ν), μ, x) +end + +function transport_def( + ν::ProductMeasure{<:NamedTuple{names}}, + μ::PowerMeasure{MU}, + x, +) where {MU<:StdMeasure,names} + NamedTuple{names}(_tuple_transport_def(values(marginals(ν)), μ, x)) +end diff --git a/src/standard/stdmeasure.jl b/src/standard/stdmeasure.jl index f53e1083..6e5287c6 100644 --- a/src/standard/stdmeasure.jl +++ b/src/standard/stdmeasure.jl @@ -26,11 +26,6 @@ function transport_def(ν::_PowerStdMeasure{MU,1}, μ::_PowerStdMeasure{NU,1}, x return transport_to(ν.parent, μ.parent).(x) end -transport_origin(μ::_PowerStdMeasure{N}) = ν.parent^product(map(length, μ.axes)) - -function from_origin(μ::_PowerStdMeasure{N}, x_origin::AbstractVector{<:Real}) - return reshape(x_origin, map(length, μ.axes)...) -end # Transport to a multivariate standard measure from any measure: @@ -142,75 +137,3 @@ end @inline transport_from_mvstd_with_rest(ν::Dirac, ::StdMeasure, x::Any) = ν.x, x @inline transport_to_mvstd(ν::PowerMeasure{<:StdMeasure}, ::Dirac, ::Any) = Zeros{Bool}(map(_ -> 0, ν.axes)) - - - -#!!!!!!!!!!!!!!!!!!!!!! TODO: - -# Helpers for product transforms and similar: - -struct _TransportToStd{NU<:StdMeasure} <: Function end -_TransportToStd{NU}(μ, x) where {NU} = transport_to(NU()^getdof(μ), μ)(x) - -struct _TransportFromStd{MU<:StdMeasure} <: Function end -_TransportFromStd{MU}(ν, x) where {MU} = transport_to(ν, MU()^getdof(ν))(x) - -function _tuple_transport_def( - ν::PowerMeasure{NU}, - μs::Tuple, - xs::Tuple, -) where {NU<:StdMeasure} - reshape(vcat(map(_TransportToStd{NU}, μs, xs)...), ν.axes) -end - -function transport_def( - ν::PowerMeasure{NU}, - μ::ProductMeasure{<:Tuple}, - x, -) where {NU<:StdMeasure} - _tuple_transport_def(ν, marginals(μ), x) -end - -function transport_def( - ν::PowerMeasure{NU}, - μ::ProductMeasure{<:NamedTuple{names}}, - x, -) where {NU<:StdMeasure,names} - _tuple_transport_def(ν, values(marginals(μ)), values(x)) -end - -@inline _offset_cumsum(s, x, y, rest...) = (s, _offset_cumsum(s + x, y, rest...)...) -@inline _offset_cumsum(s, x) = (s,) -@inline _offset_cumsum(s) = () - -function _stdvar_viewranges(μs::Tuple, startidx::IntegerLike) - N = map(getdof, μs) - offs = _offset_cumsum(startidx, N...) - map((o, n) -> o:o+n-1, offs, N) -end - -function _tuple_transport_def( - νs::Tuple, - μ::PowerMeasure{MU}, - x::AbstractArray{<:Real}, -) where {MU<:StdMeasure} - vrs = _stdvar_viewranges(νs, firstindex(x)) - xs = map(r -> view(x, r), vrs) - map(_TransportFromStd{MU}, νs, xs) -end - -function transport_def( - ν::ProductMeasure{<:Tuple}, - μ::PowerMeasure{MU}, - x, -) where {MU<:StdMeasure} - _tuple_transport_def(marginals(ν), μ, x) -end - -function transport_def( - ν::ProductMeasure{<:NamedTuple{names}}, - μ::PowerMeasure{MU}, - x, -) where {MU<:StdMeasure,names} - NamedTuple{names}(_tuple_transport_def(values(marginals(ν)), μ, x)) -end From 40cf99d024f9b090604d8f97ade45ab76649e6ca Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Fri, 30 Jun 2023 21:09:26 +0200 Subject: [PATCH 037/133] STASH --- src/collection_utils.jl | 6 ++++++ src/combinators/product.jl | 39 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 45 insertions(+) diff --git a/src/collection_utils.jl b/src/collection_utils.jl index 858810cf..fbd55586 100644 --- a/src/collection_utils.jl +++ b/src/collection_utils.jl @@ -24,3 +24,9 @@ end end end end + + +Base.@propagate_inbounds function _as_tuple(v::AbstractVector, ::Val{N}) where {N} + @boundcheck @assert length(v) == N # ToDo: Throw proper exception + ntuple(i -> v[i], Val(N)) +end diff --git a/src/combinators/product.jl b/src/combinators/product.jl index ed0970bc..3d9c7d7b 100644 --- a/src/combinators/product.jl +++ b/src/combinators/product.jl @@ -227,6 +227,45 @@ function checked_arg( end +function transport_to(ν::Pro) +end + + +function _marginal_transport_def(marginals_ν::NamedTuple{names}, marginals_μ::NamedTuple, x) where names + # ToDo - Improvement: Match names as far as possible, even if in different order, and transport between + # the rest in the order given. + NamedTuple{names}(transport_to.(values(marginals_ν), values(marginals_μ), x)) +end + + + +function _marginal_transport_def(marginals_ν::AbstractVector{<:AbstractMeasure}, marginals_μ::AbstractVector{<:AbstractMeasure}, x) + @assert x isa AbstractVector # Sanity check, should not fail + transport_to.(marginals_ν, marginals_μ, x) +end + +function _marginal_transport_def(marginals_ν::Tuple{Vararg{AbstractMeasure,N}}, marginals_μ::Tuple{Vararg{AbstractMeasure,N}}, x) where N + @assert x isa Tuple{Vararg{AbstractMeasure,N}} # Sanity check, should not fail + transport_to.(marginals_ν, marginals_μ, x) +end + +function _marginal_transport_def(marginals_ν::NamedTuple{names}, marginals_μ::Tuple{Vararg{AbstractMeasure,N}}, x) where {names,N} + _marginal_transport_def(marginals_ν, NamedTuple{names}(marginals_μ), x) +end + +function _marginal_transport_def(marginals_ν::Tuple{Vararg{AbstractMeasure,N}}, marginals_μ::NamedTuple{names}, x) where {names,N} + _marginal_transport_def(marginals_ν, values(marginals), x) +end + +function _marginal_transport_def(marginals_ν::AbstractVector{<:AbstractMeasure}, marginals_μ::Tuple{Vararg{AbstractMeasure,N}}, x) where N + _marginal_transport_def(_as_tuple(marginals_ν, Val(N)), marginals_μ, x) +end + +function _marginal_transport_def(marginals_ν::Tuple{Vararg{AbstractMeasure,N}}, marginals_μ::AbstractVector{<:AbstractMeasure}, x) where N + _marginal_transport_def(marginals_ν, _as_tuple(marginals_μ, Val(N)), x) +end + + # Transport for products From ef64ed84ab4fb7ef37f3c609a1eac6bed869009b Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Fri, 30 Jun 2023 21:11:23 +0200 Subject: [PATCH 038/133] STASH --- src/combinators/product.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/combinators/product.jl b/src/combinators/product.jl index 3d9c7d7b..7ba64bcb 100644 --- a/src/combinators/product.jl +++ b/src/combinators/product.jl @@ -237,6 +237,10 @@ function _marginal_transport_def(marginals_ν::NamedTuple{names}, marginals_μ:: NamedTuple{names}(transport_to.(values(marginals_ν), values(marginals_μ), x)) end +@inline function _marginal_transport_def(marginals_ν, marginals_μ, x) + marginal_transport_non_ntnt(marginals_ν, marginals_μ, x) +end + function _marginal_transport_def(marginals_ν::AbstractVector{<:AbstractMeasure}, marginals_μ::AbstractVector{<:AbstractMeasure}, x) From b04684ade20889dbebea5518acc440a80faa9369 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Fri, 30 Jun 2023 22:38:20 +0200 Subject: [PATCH 039/133] STASH --- src/combinators/bind.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/combinators/bind.jl b/src/combinators/bind.jl index b6831a17..9e549e3f 100644 --- a/src/combinators/bind.jl +++ b/src/combinators/bind.jl @@ -13,7 +13,7 @@ space of measure `β_a` to a combined value `ab = f_c(a, b)`. The resulting measure ```julia -μ = mbind(f_c, α, f_β) +μ = mbind(f_β, α, f_c) ``` has the mathethematical interpretation (on sets $$A$$ and $$B$$) @@ -52,7 +52,7 @@ support other choices for `f_c`. Bayesian example with a correlated prior, that models the ```julia -using MeasureBase +using MeasureBase, AffineMaps prior = mbind productmeasure(( From 710b8f143e5466aa8723092161999ab0e32d9cf4 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sat, 1 Jul 2023 11:48:27 +0200 Subject: [PATCH 040/133] STASH --- src/collection_utils.jl | 28 ++++++++++++ src/combinators/combined.jl | 28 ++++++------ src/combinators/product.jl | 85 ++++++++++++++++++++++++------------- src/proxies.jl | 2 + 4 files changed, 100 insertions(+), 43 deletions(-) diff --git a/src/collection_utils.jl b/src/collection_utils.jl index fbd55586..28c17d9b 100644 --- a/src/collection_utils.jl +++ b/src/collection_utils.jl @@ -30,3 +30,31 @@ Base.@propagate_inbounds function _as_tuple(v::AbstractVector, ::Val{N}) where { @boundcheck @assert length(v) == N # ToDo: Throw proper exception ntuple(i -> v[i], Val(N)) end + + +struct _TupleNamer{names} <: Function end +(::TupleNamer{names})(x::Tuple) where names = NamedTuple{names}(x) +InverseFunctions.inverse(::TupleNamer{names}) where names = TupleUnNamer{names}() +ChangesOfVariables.with_logabsdet_jacobian(::TupleNamer{names}, x::Tuple) where names = static(false) + +struct _TupleUnNamer{names} <: Function end +(::TupleUnNamer{names})(x::NamedTuple{names}) where {names} = values(x) +InverseFunctions.inverse(::TupleUnNamer{names}) where names = TupleNamer{names}() +ChangesOfVariables.with_logabsdet_jacobian(::TupleUnNamer{names}, x::NamedTuple{names}) where names = static(false) + + +_reorder_nt(x::NamedTuple{names},::Val{names}) where {names} = x + +@generated function _reorder_nt(x::NamedTuple{names},::Val{new_names}) where {names,new_names} + if sort([names...]) != sort([new_names...]) + :(throw(ArgumentError("Can't reorder NamedTuple{$names} to NamedTuple{$new_names}"))) + else + expr = :(()) + for nm in new_names + push!(expr.args, :($nm = x.$nm)) + end + return expr + end +end + +# ToDo: Add custom rrule for _reorder_nt? diff --git a/src/combinators/combined.jl b/src/combinators/combined.jl index 1040d099..cd7c1f42 100644 --- a/src/combinators/combined.jl +++ b/src/combinators/combined.jl @@ -62,7 +62,7 @@ export mcombine function mcombine(f_c, α::AbstractMeasure, β::AbstractMeasure) FC, MA, MB = Core.Typeof(f_c), Core.Typeof(α), Core.Typeof(β) - Combined{FC,MA,MB}(f_c, α, β) + CombinedMeasure{FC,MA,MB}(f_c, α, β) end function mcombine(::typeof(tuple), α::AbstractMeasure, β::AbstractMeasure) @@ -75,15 +75,15 @@ end """ - struct Combined <: AbstractMeasure + struct CombinedMeasure <: AbstractMeasure Represents a combination of two measures. -User code should not create instances of `Combined` directly, but should call +User code should not create instances of `CombinedMeasure` directly, but should call [`mcombine(f_c, α, β)`](@ref) instead. """ -struct Combined{FC,MA<:AbstractMeasure,MB<:AbstractMeasure} <: AbstractMeasure +struct CombinedMeasure{FC,MA<:AbstractMeasure,MB<:AbstractMeasure} <: AbstractMeasure f_c::FC α::MA β::MB @@ -91,31 +91,31 @@ end # TODO: Could split `ab`` here, but would be wasteful. -@inline insupport(::Combined, ab) = NoFastInsupport() +@inline insupport(::CombinedMeasure, ab) = NoFastInsupport() -@inline getdof(μ::Combined) = getdof(μ.α) + getdof(μ.β) +@inline getdof(μ::CombinedMeasure) = getdof(μ.α) + getdof(μ.β) # Bypass `checked_arg`, would require require splitting ab: -@inline checked_arg(::Combined, ab) = ab +@inline checked_arg(::CombinedMeasure, ab) = ab -rootmeasure(::Combined) = mcombine(μ.f_c rootmeasure(μ), rootmeasure(ν)) +rootmeasure(::CombinedMeasure) = mcombine(μ.f_c rootmeasure(μ), rootmeasure(ν)) -basemeasure(::Combined) = mcombine(μ.f_c basemeasure(μ), basemeasure(ν)) +basemeasure(::CombinedMeasure) = mcombine(μ.f_c basemeasure(μ), basemeasure(ν)) -function logdensity_def(μ::Combined, ab) +function logdensity_def(μ::CombinedMeasure, ab) # Use tpmeasure_split_combined to avoid duplicate calculation of transportmeasure(α): tpm_α, a, b = tpmeasure_split_combined(μ.f_c, μ.α, ab) return logdensity_def(tpm_α, a) + logdensity_def(μ.β, b) end # Specialize logdensityof directly to avoid creating temporary combined base measures: -function logdensityof(μ::Combined, ab) +function logdensityof(μ::CombinedMeasure, ab) tpm_α, a, b = tpmeasure_split_combined(μ.f_c, μ.α, ab) return logdensityof(tpm_α, a) + logdensityof(μ.β, b) end -function Base.rand(rng::Random.AbstractRNG, ::Type{T}, h::Combined) where {T<:Real} +function Base.rand(rng::Random.AbstractRNG, ::Type{T}, h::CombinedMeasure) where {T<:Real} x_primary = rand(rng, T, h.m) x_secondary = rand(rng, T, h.f(x_primary)) return _combine_variates(h.flatten_mode, x_primary, x_secondary) @@ -123,7 +123,7 @@ end -function transport_to_mvstd(ν_inner::StdMeasure, μ::Combined, ab) +function transport_to_mvstd(ν_inner::StdMeasure, μ::CombinedMeasure, ab) tpm_α, a, b = tpmeasure_split_combined(μ.f_c, μ.α, ab) y1 = transport_to_mvstd(ν_inner, tpm_α, a) y2 = transport_to_mvstd(ν_inner, μ.β, b) @@ -131,7 +131,7 @@ function transport_to_mvstd(ν_inner::StdMeasure, μ::Combined, ab) end -function transport_from_mvstd_with_rest(ν::Combined, μ_inner::StdMeasure, x) +function transport_from_mvstd_with_rest(ν::CombinedMeasure, μ_inner::StdMeasure, x) a, x2 = transport_from_mvstd_with_rest(ν.α, μ_inner, x) b, x_rest = transport_from_mvstd_with_rest(ν.β, μ_inner, x2) return ν.f_c(a, b), x_rest diff --git a/src/combinators/product.jl b/src/combinators/product.jl index 7ba64bcb..9555e807 100644 --- a/src/combinators/product.jl +++ b/src/combinators/product.jl @@ -1,5 +1,3 @@ -export ProductMeasure - using MappedArrays using MappedArrays: ReadonlyMultiMappedArray using Base: @propagate_inbounds @@ -72,42 +70,62 @@ function _rand_product( end |> collect end -@inline function logdensity_def(d::AbstractProductMeasure, x) - mapreduce(logdensity_def, +, marginals(d), x) + +@inline function logdensity_def(μ::AbstractProductMeasure, x) + marginals_density_op(logdensity_def, marginals(μ), x) +end +@inline function unsafe_logdensityof(μ::AbstractProductMeasure, x) + marginals_density_op(unsafe_logdensityof, marginals(μ), x) +end +@inline function logdensity_rel(μ::AbstractProductMeasure, ν::AbstractProductMeasure, x) + marginals_density_op(logdensity_rel, marginals(μ), marginals(ν), x) end -struct ProductMeasure{M} <: AbstractProductMeasure - marginals::M +function _marginals_density_op(density_op::F, marginals_μ, x) where F + mapreduce(density_op, +, marginals_μ, x) +end +@inline function _marginals_density_op(density_op::F, marginals_μ::Tuple, x::Tuple) where F + # For tuples, `mapreduce` can have trouble with type inference + sum(map(density_op, marginals_μ, x)) +end +@inline function _marginals_density_op(density_op::F, marginals_μ::NameTuple{names}, x::NameTuple) where {F,names} + nms = Val{names}() + _marginals_density_op(density_op, marginals_μ, _reorder_nt(values(x), nms)) end -@inline function logdensity_rel(μ::ProductMeasure, ν::ProductMeasure, x) - mapreduce(logdensity_rel, +, marginals(μ), marginals(ν), x) +function _marginals_density_op(density_op::F, marginals_μ, marginals_ν, x) where F + mapreduce(density_op, +, marginals_μ, marginals_ν, x) +end +@inline function _marginals_density_op(density_op::F, marginals_μ::Tuple, marginals_ν::Tuple, x::Tuple) where F + # For tuples, `mapreduce` can have trouble with type inference + sum(map(density_op, marginals_μ, marginals_ν, x)) +end +@inline function _marginals_density_op(density_op::F, marginals_μ::NameTuple{names}, marginals_ν::NameTuple, x::NameTuple) where {F,names} + nms = Val{names}() + _marginals_density_op(density_op, marginals_μ, _reorder_nt(values(marginals_ν), nms), _reorder_nt(values(x), nms)) end -function Pretty.tile(d::ProductMeasure{T}) where {T<:Tuple} - Pretty.list_layout(Pretty.tile.([marginals(d)...]), sep = " ⊗ ") + +""" + struct MeasureBase.ProductMeasure{M} <: AbstractProductMeasure + +Represents a products of measures. + +´ProductMeasure` wraps a collection of measures, this collection then +becomes the collection of the marginal measures of the `ProductMeasure`. + +User code should not instantiate `ProductMeasure` directly, but should call +[`productmeasure`](@ref) instead. +""" +struct ProductMeasure{M} <: AbstractProductMeasure + marginals::M end -# For tuples, `mapreduce` has trouble with type inference -@inline function logdensity_def(d::ProductMeasure{T}, x) where {T<:Tuple} - ℓs = map(logdensity_def, marginals(d), x) - sum(ℓs) +function Pretty.tile(d::ProductMeasure{T}) where {T<:Tuple} + Pretty.list_layout(Pretty.tile.([marginals(d)...]), sep = " ⊗ ") end -@generated function logdensity_def(d::ProductMeasure{NamedTuple{N,T}}, x) where {N,T} - k1 = QuoteNode(first(N)) - q = quote - m = marginals(d) - ℓ = logdensity_def(getproperty(m, $k1), getproperty(x, $k1)) - end - for k in Base.tail(N) - k = QuoteNode(k) - qk = :(ℓ += logdensity_def(getproperty(m, $k), getproperty(x, $k))) - push!(q.args, qk) - end - return q -end # @generated function basemeasure(d::ProductMeasure{NamedTuple{N,T}}, x) where {N,T} # q = quote @@ -231,9 +249,18 @@ function transport_to(ν::Pro) end +# ToDo - Possible improvement (breaking): For transport between +# NamedTuple-marginals Match names as far as possible, even if in +# different order, and transport between the remaining non-matching +# names in the order given? Direct transport between two +# non-standard measures will likely not be such a common use case, +# though, so may not be worth the effort. + + + + + function _marginal_transport_def(marginals_ν::NamedTuple{names}, marginals_μ::NamedTuple, x) where names - # ToDo - Improvement: Match names as far as possible, even if in different order, and transport between - # the rest in the order given. NamedTuple{names}(transport_to.(values(marginals_ν), values(marginals_μ), x)) end diff --git a/src/proxies.jl b/src/proxies.jl index 95aed270..0304bbb3 100644 --- a/src/proxies.jl +++ b/src/proxies.jl @@ -14,6 +14,8 @@ function proxy end macro useproxy(M) M = esc(M) quote + #!!!!!!!!!!! TODO add new API methods like localmeasure, transportmeasure, etc. !!!!!!!!!!!!! + @inline $MeasureBase.logdensity_def(μ::$M, x) = logdensity_def(proxy(μ), x) @inline $MeasureBase.basemeasure(μ::$M) = basemeasure(proxy(μ)) From 8d38eee4f528c49703f7e437abdf6420ab1d0c83 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sat, 1 Jul 2023 12:19:31 +0200 Subject: [PATCH 041/133] STASH --- src/combinators/power.jl | 15 +++++++ src/combinators/product.jl | 90 ++++++++++++++++---------------------- 2 files changed, 52 insertions(+), 53 deletions(-) diff --git a/src/combinators/power.jl b/src/combinators/power.jl index c0eb7a1c..bd24614b 100644 --- a/src/combinators/power.jl +++ b/src/combinators/power.jl @@ -1,5 +1,20 @@ import Base +""" + marginals(μ::AbstractMeasure) + +Returns the marginals measures of `μ` as a collection of measures. + +The kind of marginalization implied by `marginals` depends on the +type of `μ`. + +`μ` may be a power of a measure or a product of measures, but other +types of measures may support `marginals` as well. +""" +function marginals end +export marginals + + export PowerMeasure """ diff --git a/src/combinators/product.jl b/src/combinators/product.jl index 9555e807..00806cb1 100644 --- a/src/combinators/product.jl +++ b/src/combinators/product.jl @@ -4,9 +4,16 @@ using Base: @propagate_inbounds import Base using FillArrays -export AbstractProductMeasure +""" + abstract type AbstractProductMeasure + +Abstact type for products of measures. +[`marginals(μ::AbstractProductMeasure)`](@ref) returns the collection of +measures that `μ` is the product of. +""" abstract type AbstractProductMeasure <: AbstractMeasure end +export AbstractProductMeasure function Pretty.tile(μ::AbstractProductMeasure) result = Pretty.literal("ProductMeasure(") @@ -16,16 +23,12 @@ end massof(m::AbstractProductMeasure) = prod(massof, marginals(m)) -export marginals - function Base.:(==)(a::AbstractProductMeasure, b::AbstractProductMeasure) marginals(a) == marginals(b) end Base.length(μ::AbstractProductMeasure) = length(marginals(μ)) Base.size(μ::AbstractProductMeasure) = size(marginals(μ)) -basemeasure(d::AbstractProductMeasure) = productmeasure(map(basemeasure, marginals(d))) - function Base.rand(rng::AbstractRNG, ::Type{T}, d::AbstractProductMeasure) where {T} mar = marginals(d) _rand_product(rng, T, mar, eltype(mar)) @@ -106,6 +109,35 @@ end end +@inline basemeasure(μ::AbstractProductMeasure) =_marginals_basemeasure(marginals(μ)) + +_marginals_basemeasure(marginals_μ) = productmeasure(map(basemeasure, marginals_μ)) + +function _marginals_basemeasure(marginals_μ::Base.Generator{I,F}) where {I,F} + T = Core.Compiler.return_type(marginals_μ.f, Tuple{eltype(marginals_μ.iter)}) + B = Core.Compiler.return_type(basemeasure, Tuple{T}) + _marginals_basemeasure_impl(μ, B, static(Base.issingletontype(B))) +end + +function _marginals_basemeasure(marginals_μ::AbstractMappedArray{T}) where {T} + B = Core.Compiler.return_type(basemeasure, Tuple{T}) + _marginals_basemeasure_impl(marginals_μ, B, static(Base.issingletontype(B))) +end + +function _marginals_basemeasure_impl(marginals_μ, ::Type{B}, ::True) where {B} + instance(B)^axes(marginals_μ) +end + +function _marginals_basemeasure_impl(marginals_μ::AbstractMappedArray{T}, ::Type{B}, ::False) where {T,B} + productmeasure(mappedarray(basemeasure, marginals_μ)) +end + +function _marginals_basemeasure_impl(marginals_μ::Base.Generator{I,F}, ::Type{B}, ::False) where {I,F,B} + productmeasure(Base.Generator(basekernel(marginals_μ.f), marginals_μ.iter)) +end + + + """ struct MeasureBase.ProductMeasure{M} <: AbstractProductMeasure @@ -127,54 +159,6 @@ end -# @generated function basemeasure(d::ProductMeasure{NamedTuple{N,T}}, x) where {N,T} -# q = quote -# m = marginals(d) -# end -# for k in N -# qk = QuoteNode(k) -# push!(q.args, :($k = basemeasure(getproperty(m, $qk)))) -# end - -# vals = map(x -> Expr(:(=), x,x), N) -# push!(q.args, Expr(:tuple, vals...)) -# return q -# end - -function basemeasure(μ::ProductMeasure{Base.Generator{I,F}}) where {I,F} - mar = marginals(μ) - T = Core.Compiler.return_type(mar.f, Tuple{eltype(mar.iter)}) - B = Core.Compiler.return_type(basemeasure, Tuple{T}) - _basemeasure(μ, B, static(Base.issingletontype(B))) -end - -function basemeasure(μ::ProductMeasure{A}) where {T,A<:AbstractMappedArray{T}} - B = Core.Compiler.return_type(basemeasure, Tuple{T}) - _basemeasure(μ, B, static(Base.issingletontype(B))) -end - -function _basemeasure(μ::ProductMeasure, ::Type{B}, ::True) where {B} - return instance(B)^axes(marginals(μ)) -end - -function _basemeasure( - μ::ProductMeasure{A}, - ::Type{B}, - ::False, -) where {T,A<:AbstractMappedArray{T},B} - mar = marginals(μ) - productmeasure(mappedarray(basemeasure, mar)) -end - -function _basemeasure( - μ::ProductMeasure{Base.Generator{I,F}}, - ::Type{B}, - ::False, -) where {I,F,B} - mar = marginals(μ) - productmeasure(Base.Generator(basekernel(mar.f), mar.iter)) -end - marginals(μ::ProductMeasure) = μ.marginals # TODO: Better `map` support in MappedArrays From 166b648ab87f5130ab88fb949dbccd17b5ad223a Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sat, 1 Jul 2023 13:57:00 +0200 Subject: [PATCH 042/133] STASH --- src/MeasureBase.jl | 3 +- src/collection_utils.jl | 3 + src/combinators/combined.jl | 4 +- src/combinators/power.jl | 58 ++++++----- src/combinators/product.jl | 21 ++-- src/combinators/product_transport.jl | 139 +++++++++++++++++++++++++ src/standard/stdexponential.jl | 11 +- src/standard/stdlogistic.jl | 12 ++- src/standard/stdmeasure.jl | 147 +++++---------------------- src/standard/stdnormal.jl | 12 ++- src/standard/stduniform.jl | 14 ++- 11 files changed, 261 insertions(+), 163 deletions(-) create mode 100644 src/combinators/product_transport.jl diff --git a/src/MeasureBase.jl b/src/MeasureBase.jl index 263caefe..22b29926 100644 --- a/src/MeasureBase.jl +++ b/src/MeasureBase.jl @@ -140,7 +140,9 @@ include("primitives/lebesgue.jl") include("primitives/dirac.jl") include("primitives/trivial.jl") +include("combinators/product.jl") include("combinators/power.jl") +include("combinators/product_transport.jl") include("standard/stdmeasure.jl") include("standard/stduniform.jl") @@ -151,7 +153,6 @@ include("standard/stdnormal.jl") include("combinators/transformedmeasure.jl") include("combinators/weighted.jl") include("combinators/superpose.jl") -include("combinators/product.jl") include("combinators/combined.jl") include("combinators/bind.jl") include("combinators/spikemixture.jl") diff --git a/src/collection_utils.jl b/src/collection_utils.jl index 28c17d9b..6c3c5d08 100644 --- a/src/collection_utils.jl +++ b/src/collection_utils.jl @@ -32,6 +32,9 @@ Base.@propagate_inbounds function _as_tuple(v::AbstractVector, ::Val{N}) where { end +_empty_zero(::AbstractVector{T}) where {T<:Real} = Fill(zero(T), 0) + + struct _TupleNamer{names} <: Function end (::TupleNamer{names})(x::Tuple) where names = NamedTuple{names}(x) InverseFunctions.inverse(::TupleNamer{names}) where names = TupleUnNamer{names}() diff --git a/src/combinators/combined.jl b/src/combinators/combined.jl index cd7c1f42..d22fb35c 100644 --- a/src/combinators/combined.jl +++ b/src/combinators/combined.jl @@ -98,9 +98,9 @@ end # Bypass `checked_arg`, would require require splitting ab: @inline checked_arg(::CombinedMeasure, ab) = ab -rootmeasure(::CombinedMeasure) = mcombine(μ.f_c rootmeasure(μ), rootmeasure(ν)) +rootmeasure(::CombinedMeasure) = mcombine(μ.f_c, rootmeasure(μ), rootmeasure(ν)) -basemeasure(::CombinedMeasure) = mcombine(μ.f_c basemeasure(μ), basemeasure(ν)) +basemeasure(::CombinedMeasure) = mcombine(μ.f_c, basemeasure(μ), basemeasure(ν)) function logdensity_def(μ::CombinedMeasure, ab) # Use tpmeasure_split_combined to avoid duplicate calculation of transportmeasure(α): diff --git a/src/combinators/power.jl b/src/combinators/power.jl index bd24614b..d221f0ce 100644 --- a/src/combinators/power.jl +++ b/src/combinators/power.jl @@ -26,6 +26,8 @@ the product determines the dimensionality of the resulting support. Note that power measures are only well-defined for integer powers. The nth power of a measure μ can be written μ^n. + +See also [`pwr_base`](@ref), [`pwr_axes`](@ref) and [`pwr_size`](@ref). """ struct PowerMeasure{M,A} <: AbstractProductMeasure parent::M @@ -35,6 +37,31 @@ end maybestatic_length(μ::PowerMeasure) = prod(maybestatic_size(μ)) maybestatic_size(μ::PowerMeasure) = map(maybestatic_length, μ.axes) + +""" + MeasureBase.pwr_base(μ::PowerMeasure) + +Returns `ν` for `μ = ν^axs` +""" +pwr_base(μ::PowerMeasure) = μ.parent + + +""" + MeasureBase.pwr_axes(μ::PowerMeasure) + +Returns `axs` for `μ = ν^axs`, `axs` being a tuple of integer ranges. +""" +pwr_axes(μ::PowerMeasure) = μ.axes + + +""" + MeasureBase.pwr_size(μ::PowerMeasure) + +Returns `sz` for `μ = ν^sz`, `sz` being a tuple of integers. +""" +pwr_size(μ::PowerMeasure) = map(length, μ.axes) + + function Pretty.tile(μ::PowerMeasure) sz = length.(μ.axes) arg1 = Pretty.tile(μ.parent) @@ -130,7 +157,7 @@ end end end -@inline getdof(μ::PowerMeasure) = getdof(μ.parent) * prod(map(length, μ.axes)) +@inline getdof(μ::PowerMeasure) = getdof(μ.parent) * prod(pwr_size(μ)) @inline function getdof(::PowerMeasure{<:Any,NTuple{N,Static.SOneTo{0}}}) where {N} static(0) @@ -138,7 +165,7 @@ end @propagate_inbounds function checked_arg(μ::PowerMeasure, x::AbstractArray{<:Any}) @boundscheck begin - sz_μ = map(length, μ.axes) + sz_μ = pwr_size(μ) sz_x = size(x) if sz_μ != sz_x throw(ArgumentError("Size of variate doesn't match size of power measure")) @@ -164,26 +191,9 @@ function logdensity_def( end -# For transport, always pull back to one-dimensional PowerMeasure first: - -transport_origin(μ::PowerMeasure{<:Any,N}) where N = ν.parent^product(map(length, μ.axes)) - -function from_origin(μ::_PowerStdMeasure{<:Any,N}, x_origin) where N - # Sanity check, should never fail: - @assert x_origin isa AbstractVector - return reshape(x_origin, map(length, μ.axes)...) -end - - -# One-dimensional PowerMeasure has an origin iff it's parent has an origin: - -transport_origin(μ::PowerMeasure{<:AbstractMeasure,1}) = _origin_pwr(::typeof(μ), transport_origin(μ.parent), μ.axes) -_pwr_origin(::Type{MU}, parent_origin, axes) = parent_origin^axes -_pwr_origin(::Type{MU}, ::NoTransportOrigin, axes) = NoTransportOrigin{MU} - -function from_origin(μ::PowerMeasure{<:AbstractMeasure,1}, x_origin) - # Sanity check, should never fail: - @assert x_origin isa AbstractVector - from_origin.(Ref(μ.parent), x_origin) -end +""" + MeasureBase.StdPowerMeasure{MU<:StdMeasure,N} +Represents and N-dimensional power of the standard measure `MU()`. +""" +const StdPowerMeasure{N,MU<:StdMeasure} = PowerMeasure{MU,<:NTuple{N,Base.OneTo}} diff --git a/src/combinators/product.jl b/src/combinators/product.jl index 00806cb1..9d90202d 100644 --- a/src/combinators/product.jl +++ b/src/combinators/product.jl @@ -29,6 +29,16 @@ end Base.length(μ::AbstractProductMeasure) = length(marginals(μ)) Base.size(μ::AbstractProductMeasure) = size(marginals(μ)) + +# TODO: Better `map` support in MappedArrays +_map(f, args...) = map(f, args...) +_map(f, x::MappedArrays.ReadonlyMappedArray) = mappedarray(fchain((x.f, f)), x.data) + +function testvalue(::Type{T}, μ::AbstractProductMeasure) where {T} + _map(m -> testvalue(T, m), marginals(μ)) +end + + function Base.rand(rng::AbstractRNG, ::Type{T}, d::AbstractProductMeasure) where {T} mar = marginals(d) _rand_product(rng, T, mar, eltype(mar)) @@ -157,17 +167,8 @@ function Pretty.tile(d::ProductMeasure{T}) where {T<:Tuple} Pretty.list_layout(Pretty.tile.([marginals(d)...]), sep = " ⊗ ") end - - marginals(μ::ProductMeasure) = μ.marginals -# TODO: Better `map` support in MappedArrays -_map(f, args...) = map(f, args...) -_map(f, x::MappedArrays.ReadonlyMappedArray) = mappedarray(fchain((x.f, f)), x.data) - -function testvalue(::Type{T}, d::AbstractProductMeasure) where {T} - _map(m -> testvalue(T, m), marginals(d)) -end ############################################################################### # I <: Base.Generator @@ -229,6 +230,8 @@ function checked_arg( end + + function transport_to(ν::Pro) end diff --git a/src/combinators/product_transport.jl b/src/combinators/product_transport.jl new file mode 100644 index 00000000..69003a3b --- /dev/null +++ b/src/combinators/product_transport.jl @@ -0,0 +1,139 @@ +# For transport, always pull a PowerMeasure back to one-dimensional PowerMeasure first: + +transport_origin(μ::PowerMeasure{<:Any,N}) where N = ν.parent^product(pwr_size(μ)) + +function from_origin(μ::PowerMeasure{<:Any,N}, x_origin) where N + # Sanity check, should never fail: + @assert x_origin isa AbstractVector + return reshape(x_origin, pwr_size(μ)...) +end + + +# A one-dimensional PowerMeasure has an origin if it's parent has an origin: + +transport_origin(μ::PowerMeasure{<:AbstractMeasure,1}) = _origin_pwr(::typeof(μ), transport_origin(μ.parent), μ.axes) +_pwr_origin(::Type{MU}, parent_origin, axes) = parent_origin^axes +_pwr_origin(::Type{MU}, ::NoTransportOrigin, axes) = NoTransportOrigin{MU} + +function from_origin(μ::PowerMeasure{<:AbstractMeasure,1}, x_origin) + # Sanity check, should never fail: + @assert x_origin isa AbstractVector + from_origin.(Ref(μ.parent), x_origin) +end + + +# Transport between univariate standard measures and power measures of size one: + +function transport_def(ν::StdMeasure, μ::PowerMeasure{<:StdMeasure}, x) + return transport_def(ν, μ.parent, only(x)) +end + +function transport_def(ν::PowerMeasure{<:StdMeasure}, μ::StdMeasure, x) + return fill_with(transport_def(ν.parent, μ, only(x)), map(length, ν.axes)) +end + +function transport_def(ν::StdPowerMeasure{MU,1}, μ::StdPowerMeasure{NU,1}, x,) where {MU,NU} + return transport_to(ν.parent, μ.parent).(x) +end + + +# Transport to a multivariate standard measure from any measure: + +function transport_def(ν::StdPowerMeasure{MU,1}, μ::AbstractMeasure, ab) where MU + ν_inner = _inner_stdmeasure(ν) + transport_to_mvstd(ν_inner, μ, ab) +end + +function transport_to_mvstd(ν_inner::StdMeasure, μ::AbstractMeasure, x) + return _to_mvstd_withdof(ν_inner, μ, getdof(μ), x, origin) +end + +function _to_mvstd_withdof(ν_inner::StdMeasure, μ::AbstractMeasure, dof_μ, x) + y = transport_to(ν_inner^dof_μ, μ, x) + return y +end + +function _to_mvstd_withdof(ν_inner::StdMeasure, μ::AbstractMeasure, ::NoDOF, x) + _to_mvstd_withorigin(ν_inner, μ, transport_origin(μ), x) +end + +function _to_mvstd_withorigin(ν_inner::StdMeasure, ::AbstractMeasure, μ_origin, x) + x_origin = transport_to_mvstd(ν_inner, μ_origin, x) + from_origin(x_origin) +end + +function _to_mvstd_withorigin(ν_inner::StdMeasure, μ::AbstractMeasure, NoTransportOrigin, x) + throw(ArgumentError("Don't know how to transport values of type $(nameof(typeof(x))) from $(nameof(typeof(μ))) to a power of $(nameof(typeof(ν_inner)))")) +end + + +# Transport from a multivariate standard measure to any measure: + +function transport_def(ν::AbstractMeasure, μ::StdPowerMeasure{MU,1}, x) where MU + μ_inner = _inner_stdmeasure(μ) + _transport_from_mvstd(ν, μ_inner, x) +end + +function _transport_from_mvstd(ν::AbstractMeasure, μ_inner::StdMeasure, x) + # Sanity check, should be checked by transport machinery already: + @assert getdof(μ) == length(eachindex(x)) && x isa AbstractVector + y, x_rest = transport_from_mvstd_with_rest(ν, μ_inner, x) + if !isempty(x_rest) + throw(ArgumentError("Input value too long during transport")) + end + return y +end + +function transport_from_mvstd_with_rest(ν::AbstractMeasure, μ_inner::StdMeasure, x) + dof_ν = getdof(ν) + origin = transport_origin(ν) + return _from_mvstd_with_rest_withdof(ν, getdof(ν), μ_inner, x, dof_ν, origin) +end + +function _from_mvstd_with_rest_withdof(ν::AbstractMeasure, dof_ν, μ_inner::StdMeasure, x) + len_x = length(eachindex(x)) + + # Since we can't check DOF of original Bind, we could "run out x" if + # the original x was too short. `transport_to` below will detect this, but better + # throw a more informative exception here: + if len_x < dof_ν + throw(ArgumentError("Variate too short during transport involving Bind")) + end + + x_inner_dof, x_rest = _split_after(x, dof_ν) + y = transport_to(ν, μ_inner^dof_ν, x_inner_dof) + return y, x_rest +end + +function _from_mvstd_with_rest_withdof(ν::AbstractMeasure, ::NoDOF, μ_inner::StdMeasure, x) + _from_mvstd_with_rest_withorigin(ν, transport_origin(ν), μ_inner, x) +end + +function _from_mvstd_with_rest_withorigin(::AbstractMeasure, ν_origin, μ_inner::StdMeasure, x) + x_origin, x_rest = transport_from_mvstd_with_rest(ν_origin, x, μ_inner) + from_origin(x_origin), x_rest +end + +function _from_mvstd_with_rest_withorigin(ν::AbstractMeasure, NoTransportOrigin, μ_inner::StdMeasure, x) + throw(ArgumentError("Don't know how to transport value of type $(nameof(typeof(x))) from power of $(nameof(typeof(μ_inner))) to $(nameof(typeof(ν)))")) +end + + + + +# Implement transport_to(NU::Type{<:StdMeasure}, μ) and transport_to(ν, MU::Type{<:StdMeasure}) +# for user convenience: + +# ToDo: Handle combined/bind measures that don't have a fast getdof! + +_std_measure(::Type{M}, ::StaticInteger{1}) where {M<:StdMeasure} = M() +_std_measure(::Type{M}, dof::IntegerLike) where {M<:StdMeasure} = M()^dof +_std_measure_for(::Type{M}, μ::Any) where {M<:StdMeasure} = _std_measure(M, getdof(μ)) + +function transport_to(ν, ::Type{MU}) where {MU<:StdMeasure} + transport_to(ν, _std_measure_for(MU, ν)) +end + +function transport_to(::Type{NU}, μ) where {NU<:StdMeasure} + transport_to(_std_measure_for(NU, μ), μ) +end diff --git a/src/standard/stdexponential.jl b/src/standard/stdexponential.jl index a02c5765..c1aaa88c 100644 --- a/src/standard/stdexponential.jl +++ b/src/standard/stdexponential.jl @@ -1,8 +1,17 @@ +""" + StdExponential <: StdMeasure + +Represents the standard (rate of one) +[exponential](https://en.wikipedia.org/wiki/Exponential_distribution) probability measure. + +See [`StdMeasure`](@ref) for the semantics of standard measures in the +context of MeasureBase. +""" struct StdExponential <: StdMeasure end export StdExponential -insupport(d::StdExponential, x) = x ≥ zero(x) +insupport(::StdExponential, x) = x ≥ zero(x) @inline logdensity_def(::StdExponential, x) = -x @inline basemeasure(::StdExponential) = LebesgueBase() diff --git a/src/standard/stdlogistic.jl b/src/standard/stdlogistic.jl index 0d502ec6..705c153c 100644 --- a/src/standard/stdlogistic.jl +++ b/src/standard/stdlogistic.jl @@ -1,8 +1,16 @@ -struct StdLogistic <: StdMeasure end +""" + StdLogistic <: StdMeasure + +Represents the standard (centered, scale of one) +[logistic](https://en.wikipedia.org/wiki/Logistic_distribution) probability measure. +See [`StdMeasure`](@ref) for the semantics of standard measures in the +context of MeasureBase. +""" +struct StdLogistic <: StdMeasure end export StdLogistic -@inline insupport(d::StdLogistic, x) = true +@inline insupport(::StdLogistic, x) = true @inline logdensity_def(::StdLogistic, x) = (u = -abs(x); u - 2 * log1pexp(u)) @inline basemeasure(::StdLogistic) = LebesgueBase() diff --git a/src/standard/stdmeasure.jl b/src/standard/stdmeasure.jl index 6e5287c6..a571b54d 100644 --- a/src/standard/stdmeasure.jl +++ b/src/standard/stdmeasure.jl @@ -1,10 +1,31 @@ -abstract type StdMeasure <: AbstractMeasure end +""" + abstract type MeasureBase.StdMeasure + +Abstract supertype for standard measures. + +Standard measures must be singleton types that represent common fundamental +measures such as [`StdUniform`](@ref), [`StdExponential`](@ref), +[`StdNormal`](@ref) and [`StdLogistic`](@ref). +A standard measure ([`StdUniform`](@ref), [`StdExponential`](@ref) and +[`StdNormal`](@ref)) is defined for every common Julia random number +generation function: -const _PowerStdMeasure{N,MU<:StdMeasure} = PowerMeasure{MU,<:NTuple{N,Base.OneTo}} +``` +StdMeasure(rand) == StdUniform() +StdMeasure(randexp) == StdExponential() +StdMeasure(randn) == StdNormal() +``` -_get_inner_stdmeasure(::_PowerStdMeasure{N,MU}) where {N,MU} = M() +[`StdLogistic`](@ref) has no associated random number generation function. + +All standard measures must be normalized, i.e. [`massof`](@ref) always +returns one. +""" +abstract type StdMeasure <: AbstractMeasure end +@inline massof(::StdMeasure) = static(true) +@inline getdof(::StdMeasure) = static(1) StdMeasure(::typeof(rand)) = StdUniform() StdMeasure(::typeof(randexp)) = StdExponential() @@ -12,127 +33,13 @@ StdMeasure(::typeof(randn)) = StdNormal() @inline check_dof(::StdMeasure, ::StdMeasure) = nothing -@inline transport_def(::MU, μ::MU, x) where {MU<:StdMeasure} = x - -function transport_def(ν::StdMeasure, μ::PowerMeasure{<:StdMeasure}, x) - return transport_def(ν, μ.parent, only(x)) -end - -function transport_def(ν::PowerMeasure{<:StdMeasure}, μ::StdMeasure, x) - return fill_with(transport_def(ν.parent, μ, only(x)), map(length, ν.axes)) -end - -function transport_def(ν::_PowerStdMeasure{MU,1}, μ::_PowerStdMeasure{NU,1}, x,) where {MU,NU} - return transport_to(ν.parent, μ.parent).(x) -end - - - -# Transport to a multivariate standard measure from any measure: - -function transport_def(ν::_PowerStdMeasure{1}, μ::AbstractMeasure, ab) - ν_inner = _get_inner_stdmeasure(ν) - transport_to_mvstd(ν_inner, μ, ab) -end - -function transport_to_mvstd(ν_inner::StdMeasure, μ::AbstractMeasure, x) - return _to_mvstd_withdof(ν_inner, μ, getdof(μ), x, origin) -end - -function _to_mvstd_withdof(ν_inner::StdMeasure, μ::AbstractMeasure, dof_μ, x) - y = transport_to(ν_inner^dof_μ, μ, x) - return y -end - -function _to_mvstd_withdof(ν_inner::StdMeasure, μ::AbstractMeasure, ::NoDOF, x) - _to_mvstd_withorigin(ν_inner, μ, transport_origin(μ), x) -end - -function _to_mvstd_withorigin(ν_inner::StdMeasure, ::AbstractMeasure, μ_origin, x) - x_origin = transport_to_mvstd(ν_inner, μ_origin, x) - from_origin(x_origin) -end - -function _to_mvstd_withorigin(ν_inner::StdMeasure, μ::AbstractMeasure, NoTransportOrigin, x) - throw(ArgumentError("Don't know how to transport values of type $(nameof(typeof(x))) from $(nameof(typeof(μ))) to a power of $(nameof(typeof(ν_inner)))")) -end - -# Transport from a multivariate standard measure to any measure: - -function transport_def(ν::AbstractMeasure, μ::_PowerStdMeasure{1}, x) - μ_inner = _get_inner_stdmeasure(μ) - _transport_from_mvstd(ν, μ_inner, x) -end - -function _transport_from_mvstd(ν::AbstractMeasure, μ_inner::StdMeasure, x) - # Sanity check, should be checked by transport machinery already: - @assert getdof(μ) == length(eachindex(x)) && x isa AbstractVector - y, x_rest = transport_from_mvstd_with_rest(ν, μ_inner, x) - if !isempty(x_rest) - throw(ArgumentError("Input value too long during transport")) - end - return y -end - -function transport_from_mvstd_with_rest(ν::AbstractMeasure, μ_inner::StdMeasure, x) - dof_ν = getdof(ν) - origin = transport_origin(ν) - return _from_mvstd_with_rest_withdof(ν, getdof(ν), μ_inner, x, dof_ν, origin) -end - -function _from_mvstd_with_rest_withdof(ν::AbstractMeasure, dof_ν, μ_inner::StdMeasure, x) - len_x = length(eachindex(x)) - - # Since we can't check DOF of original Bind, we could "run out x" if - # the original x was too short. `transport_to` below will detect this, but better - # throw a more informative exception here: - if len_x < dof_ν - throw(ArgumentError("Variate too short during transport involving Bind")) - end - - x_inner_dof, x_rest = _split_after(x, dof_ν) - y = transport_to(ν, μ_inner^dof_ν, x_inner_dof) - return y, x_rest -end - -function _from_mvstd_with_rest_withdof(ν::AbstractMeasure, ::NoDOF, μ_inner::StdMeasure, x) - _from_mvstd_with_rest_withorigin(ν, transport_origin(ν), μ_inner, x) -end - -function _from_mvstd_with_rest_withorigin(::AbstractMeasure, ν_origin, μ_inner::StdMeasure, x) - x_origin, x_rest = transport_from_mvstd_with_rest(ν_origin, x, μ_inner) - from_origin(x_origin), x_rest -end - -function _from_mvstd_with_rest_withorigin(ν::AbstractMeasure, NoTransportOrigin, μ_inner::StdMeasure, x) - throw(ArgumentError("Don't know how to transport value of type $(nameof(typeof(x))) from power of $(nameof(typeof(μ_inner))) to $(nameof(typeof(ν)))")) -end - - -_empty_zero(::AbstractVector{T}) where {T<:Real} = Fill(zero(T), 0) - - -# Implement transport_to(NU::Type{<:StdMeasure}, μ) and transport_to(ν, MU::Type{<:StdMeasure}) -# for user convenience: - -# ToDo: Handle combined/bind measures that don't have a fast getdof! - -_std_measure(::Type{M}, ::StaticInteger{1}) where {M<:StdMeasure} = M() -_std_measure(::Type{M}, dof::IntegerLike) where {M<:StdMeasure} = M()^dof -_std_measure_for(::Type{M}, μ::Any) where {M<:StdMeasure} = _std_measure(M, getdof(μ)) - -function transport_to(ν, ::Type{MU}) where {MU<:StdMeasure} - transport_to(ν, _std_measure_for(MU, ν)) -end - -function transport_to(::Type{NU}, μ) where {NU<:StdMeasure} - transport_to(_std_measure_for(NU, μ), μ) -end +# Transport between two equal standard measures: +@inline transport_def(::MU, μ::MU, x) where {MU<:StdMeasure} = x -# Transform between standard measures and Dirac: +# Transport between a standard measure and Dirac: @inline transport_from_mvstd_with_rest(ν::Dirac, ::StdMeasure, x::Any) = ν.x, x diff --git a/src/standard/stdnormal.jl b/src/standard/stdnormal.jl index dc9cac74..b083606b 100644 --- a/src/standard/stdnormal.jl +++ b/src/standard/stdnormal.jl @@ -1,11 +1,19 @@ using SpecialFunctions: erfc, erfcinv using IrrationalConstants: invsqrt2 -struct StdNormal <: StdMeasure end +""" + StdNormal <: StdMeasure + +Represents the standard (mean of zero, variance of one) +[normal](https://en.wikipedia.org/wiki/Normal_distribution) probability measure. +See [`StdMeasure`](@ref) for the semantics of standard measures in the +context of MeasureBase. +""" +struct StdNormal <: StdMeasure end export StdNormal -@inline insupport(d::StdNormal, x) = true +@inline insupport(::StdNormal, x) = true @inline logdensity_def(::StdNormal, x) = -x^2 / 2 @inline basemeasure(::StdNormal) = WeightedMeasure(static(-0.5 * log2π), LebesgueBase()) diff --git a/src/standard/stduniform.jl b/src/standard/stduniform.jl index 8817561e..d443ce2e 100644 --- a/src/standard/stduniform.jl +++ b/src/standard/stduniform.jl @@ -1,8 +1,18 @@ -struct StdUniform <: StdMeasure end +""" + StdUniform <: StdMeasure + +Represents the standard +[uniform](https://en.wikipedia.org/wiki/Continuous_uniform_distribution) +probability measure (from zero to one). It is the +same as the Lebesgue measure restricted to the unit interval. +See [`StdMeasure`](@ref) for the semantics of standard measures in the +context of MeasureBase. +""" +struct StdUniform <: StdMeasure end export StdUniform -insupport(d::StdUniform, x) = zero(x) ≤ x ≤ one(x) +insupport(::StdUniform, x) = zero(x) ≤ x ≤ one(x) @inline logdensity_def(::StdUniform, x) = zero(x) @inline basemeasure(::StdUniform) = LebesgueBase() From 840c1660ef93486d0025c08adf9082f3a4c312e9 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sat, 1 Jul 2023 15:08:11 +0200 Subject: [PATCH 043/133] Remove remnants of pointwiseproduct --- src/combinators/smart-constructors.jl | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/src/combinators/smart-constructors.jl b/src/combinators/smart-constructors.jl index 26ba3948..45b1c5fa 100644 --- a/src/combinators/smart-constructors.jl +++ b/src/combinators/smart-constructors.jl @@ -4,18 +4,6 @@ half(μ::AbstractMeasure) = Half(μ) -############################################################################### -# PointwiseProductMeasure - -function pointwiseproduct(μ::AbstractMeasure, ℓ::Likelihood) - T = Core.Compiler.return_type(ℓ.k, Tuple{gentype(μ)}) - return pointwiseproduct(T, μ, ℓ) -end - -function pointwiseproduct(::Type{T}, μ::AbstractMeasure, ℓ::Likelihood) where {T} - return PointwiseProductMeasure(μ, ℓ) -end - ############################################################################### # PowerMeaure From c0b22e1bde68ebfd7e2392ac249912e516f70e9a Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sat, 1 Jul 2023 15:25:16 +0200 Subject: [PATCH 044/133] STASH --- src/MeasureBase.jl | 1 - src/collection_utils.jl | 4 + src/combinators/power.jl | 4 - src/combinators/product.jl | 129 +------------------------- src/combinators/product_transport.jl | 125 +++++++++++++++++++++++++ src/combinators/smart-constructors.jl | 40 ++++++-- 6 files changed, 166 insertions(+), 137 deletions(-) diff --git a/src/MeasureBase.jl b/src/MeasureBase.jl index 22b29926..846f4fc5 100644 --- a/src/MeasureBase.jl +++ b/src/MeasureBase.jl @@ -45,7 +45,6 @@ import IfElse: ifelse export logdensity_def export basemeasure export basekernel -export productmeasure export insupport export getdof diff --git a/src/collection_utils.jl b/src/collection_utils.jl index 6c3c5d08..3002259a 100644 --- a/src/collection_utils.jl +++ b/src/collection_utils.jl @@ -61,3 +61,7 @@ _reorder_nt(x::NamedTuple{names},::Val{names}) where {names} = x end # ToDo: Add custom rrule for _reorder_nt? + + +_fill_value(x::Fill) = x.value +_fill_axes(x::Fill) = x.axes diff --git a/src/combinators/power.jl b/src/combinators/power.jl index d221f0ce..31e2942d 100644 --- a/src/combinators/power.jl +++ b/src/combinators/power.jl @@ -94,10 +94,6 @@ end @inline _pm_axes(sz::Tuple{Vararg{IntegerLike,N}}) where {N} = map(one_to, sz) @inline _pm_axes(axs::Tuple{Vararg{AbstractUnitRange,N}}) where {N} = axs -@inline function powermeasure(x::T, sz::Tuple{Vararg{Any,N}}) where {T,N} - PowerMeasure(x, _pm_axes(sz)) -end - marginals(d::PowerMeasure) = fill_with(d.parent, d.axes) function Base.:^(μ::AbstractMeasure, dims::Tuple{Vararg{AbstractArray,N}}) where {N} diff --git a/src/combinators/product.jl b/src/combinators/product.jl index 9d90202d..11260be8 100644 --- a/src/combinators/product.jl +++ b/src/combinators/product.jl @@ -4,6 +4,7 @@ using Base: @propagate_inbounds import Base using FillArrays + """ abstract type AbstractProductMeasure @@ -169,6 +170,8 @@ end marginals(μ::ProductMeasure) = μ.marginals +proxy(μ::ProductMeasure{<:Fill}) = powermeasure(_fill_value(marginals(μ)), _fill_axes(marginals(μ))) + ############################################################################### # I <: Base.Generator @@ -231,129 +234,3 @@ end - -function transport_to(ν::Pro) -end - - -# ToDo - Possible improvement (breaking): For transport between -# NamedTuple-marginals Match names as far as possible, even if in -# different order, and transport between the remaining non-matching -# names in the order given? Direct transport between two -# non-standard measures will likely not be such a common use case, -# though, so may not be worth the effort. - - - - - -function _marginal_transport_def(marginals_ν::NamedTuple{names}, marginals_μ::NamedTuple, x) where names - NamedTuple{names}(transport_to.(values(marginals_ν), values(marginals_μ), x)) -end - -@inline function _marginal_transport_def(marginals_ν, marginals_μ, x) - marginal_transport_non_ntnt(marginals_ν, marginals_μ, x) -end - - - -function _marginal_transport_def(marginals_ν::AbstractVector{<:AbstractMeasure}, marginals_μ::AbstractVector{<:AbstractMeasure}, x) - @assert x isa AbstractVector # Sanity check, should not fail - transport_to.(marginals_ν, marginals_μ, x) -end - -function _marginal_transport_def(marginals_ν::Tuple{Vararg{AbstractMeasure,N}}, marginals_μ::Tuple{Vararg{AbstractMeasure,N}}, x) where N - @assert x isa Tuple{Vararg{AbstractMeasure,N}} # Sanity check, should not fail - transport_to.(marginals_ν, marginals_μ, x) -end - -function _marginal_transport_def(marginals_ν::NamedTuple{names}, marginals_μ::Tuple{Vararg{AbstractMeasure,N}}, x) where {names,N} - _marginal_transport_def(marginals_ν, NamedTuple{names}(marginals_μ), x) -end - -function _marginal_transport_def(marginals_ν::Tuple{Vararg{AbstractMeasure,N}}, marginals_μ::NamedTuple{names}, x) where {names,N} - _marginal_transport_def(marginals_ν, values(marginals), x) -end - -function _marginal_transport_def(marginals_ν::AbstractVector{<:AbstractMeasure}, marginals_μ::Tuple{Vararg{AbstractMeasure,N}}, x) where N - _marginal_transport_def(_as_tuple(marginals_ν, Val(N)), marginals_μ, x) -end - -function _marginal_transport_def(marginals_ν::Tuple{Vararg{AbstractMeasure,N}}, marginals_μ::AbstractVector{<:AbstractMeasure}, x) where N - _marginal_transport_def(marginals_ν, _as_tuple(marginals_μ, Val(N)), x) -end - - - -# Transport for products - - -#!!!!!!!!!!!!!!!!!!!!!! TODO: - -# Helpers for product transforms and similar: - -struct _TransportToStd{NU<:StdMeasure} <: Function end -_TransportToStd{NU}(μ, x) where {NU} = transport_to(NU()^getdof(μ), μ)(x) - -struct _TransportFromStd{MU<:StdMeasure} <: Function end -_TransportFromStd{MU}(ν, x) where {MU} = transport_to(ν, MU()^getdof(ν))(x) - -function _tuple_transport_def( - ν::PowerMeasure{NU}, - μs::Tuple, - xs::Tuple, -) where {NU<:StdMeasure} - reshape(vcat(map(_TransportToStd{NU}, μs, xs)...), ν.axes) -end - -function transport_def( - ν::PowerMeasure{NU}, - μ::ProductMeasure{<:Tuple}, - x, -) where {NU<:StdMeasure} - _tuple_transport_def(ν, marginals(μ), x) -end - -function transport_def( - ν::PowerMeasure{NU}, - μ::ProductMeasure{<:NamedTuple{names}}, - x, -) where {NU<:StdMeasure,names} - _tuple_transport_def(ν, values(marginals(μ)), values(x)) -end - -@inline _offset_cumsum(s, x, y, rest...) = (s, _offset_cumsum(s + x, y, rest...)...) -@inline _offset_cumsum(s, x) = (s,) -@inline _offset_cumsum(s) = () - -function _stdvar_viewranges(μs::Tuple, startidx::IntegerLike) - N = map(getdof, μs) - offs = _offset_cumsum(startidx, N...) - map((o, n) -> o:o+n-1, offs, N) -end - -function _tuple_transport_def( - νs::Tuple, - μ::PowerMeasure{MU}, - x::AbstractArray{<:Real}, -) where {MU<:StdMeasure} - vrs = _stdvar_viewranges(νs, firstindex(x)) - xs = map(r -> view(x, r), vrs) - map(_TransportFromStd{MU}, νs, xs) -end - -function transport_def( - ν::ProductMeasure{<:Tuple}, - μ::PowerMeasure{MU}, - x, -) where {MU<:StdMeasure} - _tuple_transport_def(marginals(ν), μ, x) -end - -function transport_def( - ν::ProductMeasure{<:NamedTuple{names}}, - μ::PowerMeasure{MU}, - x, -) where {MU<:StdMeasure,names} - NamedTuple{names}(_tuple_transport_def(values(marginals(ν)), μ, x)) -end diff --git a/src/combinators/product_transport.jl b/src/combinators/product_transport.jl index 69003a3b..9ff22b2f 100644 --- a/src/combinators/product_transport.jl +++ b/src/combinators/product_transport.jl @@ -137,3 +137,128 @@ end function transport_to(::Type{NU}, μ) where {NU<:StdMeasure} transport_to(_std_measure_for(NU, μ), μ) end + + + + + +@inline transport_origin(μ::ProductMeasure) = _marginals_tp_origin(marginals(μ)) +@inline from_origin(μ::ProductMeasure, x_origin) = _marginals_from_origin(marginals(μ), x_origin) + +_marginals_tp_origin(::Ms) where Ms = NoTransportOrigin{PowerMeasure{M}}() + + +# Pull back from a product over a Fill to a power measure: + +_marginals_tp_origin(marginals_μ::Fill) = marginals_μ.value^marginals_μ.axes +_marginals_from_origin(::Fill, x_origin) = x_origin + + +# Pull back from a NamedTuple product measure to a Tuple product measure: +# +# Maybe ToDo (breaking): For transport between NamedTuple-marginals we could +# match names where possible, even if given in different order, and transport +# between the remaining non-matching names in the order given. This may not +# be worth the additional complexity, though, since transport is typically +# used with a (power of a) standard measure on one side. + +_marginals_tp_origin(marginals_μ::NamedTuple{names}) where names = productmeasure(values(marginals_μ)) +_marginals_from_origin(::NamedTuple{names}, x_origin::NamedTuple) where names = _reorder_nt(x_origin, Val(names)) + + +# Transport between two instances of ProductMeasure: + +transport_def(ν::ProductMeasure, μ::ProductMeasure, x) = _marginal_transport_def(marginals(ν), marginals(μ), x) + +function _marginal_transport_def(marginals_ν, marginals_μ, x) + @assert size(marginals_ν) == size(marginals_μ) == size(x) # Sanity check, should not fail + transport_def.(marginals_ν, marginals_μ, x) +end + +function _marginal_transport_def(marginals_ν::Tuple{Vararg{AbstractMeasure,N}}, marginals_μ::Tuple{Vararg{AbstractMeasure,N}}, x) where N + @assert x isa Tuple{Vararg{AbstractMeasure,N}} # Sanity check, should not fail + map(transport_def, marginals_ν, marginals_μ, x) +end + +function _marginal_transport_def(marginals_ν::AbstractVector{<:AbstractMeasure}, marginals_μ::Tuple{Vararg{AbstractMeasure,N}}, x) where N + _marginal_transport_def(_as_tuple(marginals_ν, Val(N)), marginals_μ, x) +end + +function _marginal_transport_def(marginals_ν::Tuple{Vararg{AbstractMeasure,N}}, marginals_μ::AbstractVector{<:AbstractMeasure}, x) where N + _marginal_transport_def(marginals_ν, _as_tuple(marginals_μ, Val(N)), _as_tuple(x, Val(N))) +end + + + +# Transport for products + + +#!!!!!!!!!!!!!!!!!!!!!! TODO: + +# Helpers for product transforms and similar: + +struct _TransportToStd{NU<:StdMeasure} <: Function end +_TransportToStd{NU}(μ, x) where {NU} = transport_to(NU()^getdof(μ), μ)(x) + +struct _TransportFromStd{MU<:StdMeasure} <: Function end +_TransportFromStd{MU}(ν, x) where {MU} = transport_to(ν, MU()^getdof(ν))(x) + +function _tuple_transport_def( + ν::PowerMeasure{NU}, + μs::Tuple, + xs::Tuple, +) where {NU<:StdMeasure} + reshape(vcat(map(_TransportToStd{NU}, μs, xs)...), ν.axes) +end + +function transport_def( + ν::PowerMeasure{NU}, + μ::ProductMeasure{<:Tuple}, + x, +) where {NU<:StdMeasure} + _tuple_transport_def(ν, marginals(μ), x) +end + +function transport_def( + ν::PowerMeasure{NU}, + μ::ProductMeasure{<:NamedTuple{names}}, + x, +) where {NU<:StdMeasure,names} + _tuple_transport_def(ν, values(marginals(μ)), values(x)) +end + +@inline _offset_cumsum(s, x, y, rest...) = (s, _offset_cumsum(s + x, y, rest...)...) +@inline _offset_cumsum(s, x) = (s,) +@inline _offset_cumsum(s) = () + +function _stdvar_viewranges(μs::Tuple, startidx::IntegerLike) + N = map(getdof, μs) + offs = _offset_cumsum(startidx, N...) + map((o, n) -> o:o+n-1, offs, N) +end + +function _tuple_transport_def( + νs::Tuple, + μ::PowerMeasure{MU}, + x::AbstractArray{<:Real}, +) where {MU<:StdMeasure} + vrs = _stdvar_viewranges(νs, firstindex(x)) + xs = map(r -> view(x, r), vrs) + map(_TransportFromStd{MU}, νs, xs) +end + +function transport_def( + ν::ProductMeasure{<:Tuple}, + μ::PowerMeasure{MU}, + x, +) where {MU<:StdMeasure} + _tuple_transport_def(marginals(ν), μ, x) +end + +function transport_def( + ν::ProductMeasure{<:NamedTuple{names}}, + μ::PowerMeasure{MU}, + x, +) where {MU<:StdMeasure,names} + NamedTuple{names}(_tuple_transport_def(values(marginals(ν)), μ, x)) +end diff --git a/src/combinators/smart-constructors.jl b/src/combinators/smart-constructors.jl index 45b1c5fa..399d9471 100644 --- a/src/combinators/smart-constructors.jl +++ b/src/combinators/smart-constructors.jl @@ -7,7 +7,11 @@ half(μ::AbstractMeasure) = Half(μ) ############################################################################### # PowerMeaure -powermeasure(m::AbstractMeasure, ::Tuple{}) = m +powermeasure(m::AbstractMeasure, ::Tuple{}) = asmeasure(m) + +@inline function powermeasure(x::T, sz::Tuple{Vararg{<:Any,N}}) where {T,N} + PowerMeasure(asmeasure(x), _pm_axes(sz)) +end function powermeasure( μ::WeightedMeasure, @@ -25,24 +29,48 @@ end ############################################################################### # ProductMeasure -productmeasure(mar::FillArrays.Fill) = powermeasure(mar.value, mar.axes) +""" + productmeasure(μs) + +Constructs a product over a collection `μs` of measures. + +Examples: + +```julia +using MeasureBase, AffineMaps +productmeasure((StdNormal(), StdExponential())) +productmeasure(a = StdNormal(), b = StdExponential())) +productmeasure([pushfwd(Mul(scale), StdExponential()) for scale in 0.1:0.2:2]) +productmeasure((pushfwd(Mul(scale), StdExponential()) for scale in 0.1:0.2:2)) +""" +function productmeasure end +export productmeasure + +productmeasure(mar::Fill) = powermeasure(_fill_value(mar), _fill_axes(mar)) + +productmeasure(mar::Tuple{Vararg{<:AbstractMeasure}}) = ProductMeasure(mar) +productmeasure(mar::Tuple) = ProductMeasure(map(asmeasure, mar)) + +productmeasure(mar::NamedTuple{names,<:Tuple{Vararg{AbstractMeasure}}}) where names = ProductMeasure(mar) +productmeasure(mar::NamedTuple) = ProductMeasure(map(asmeasure, mar)) + +productmeasure(mar::AbstractArray{<:AbstractProductMeasure}) = ProductMeasure(mar) +productmeasure(mar::AbstractArray) = ProductMeasure(asmeasure.(mar)) function productmeasure(mar::ReadonlyMappedArray{T,N,A,Returns{M}}) where {T,N,A,M} return powermeasure(mar.f.value, axes(mar.data)) end productmeasure(mar::Base.Generator) = ProductMeasure(mar) -productmeasure(mar::AbstractArray) = ProductMeasure(mar) # TODO: Make this static when its length is static @inline function productmeasure( - mar::AbstractArray{WeightedMeasure{StaticFloat64{W},M}}, + mar::AbstractArray{<:WeightedMeasure{StaticFloat64{W},M}}, ) where {W,M} return weightedmeasure(W * length(mar), productmeasure(map(basemeasure, mar))) end -productmeasure(nt::NamedTuple) = ProductMeasure(nt) -productmeasure(tup::Tuple) = ProductMeasure(tup) +# ToDo: Remove or at least refactor this (ProductMeasure shouldn't take a kernel at it's argument). productmeasure(f, param_maps, pars) = ProductMeasure(kernel(f, param_maps), pars) From 11ebdd697904b31a21fd3c0a83c0dc72bf96f97d Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sat, 1 Jul 2023 15:41:02 +0200 Subject: [PATCH 045/133] STASH --- src/combinators/bind.jl | 13 +++++++------ src/combinators/likelihood.jl | 24 +++++++++++++----------- src/combinators/smart-constructors.jl | 12 ++++++++++++ 3 files changed, 32 insertions(+), 17 deletions(-) diff --git a/src/combinators/bind.jl b/src/combinators/bind.jl index 9e549e3f..bb62f801 100644 --- a/src/combinators/bind.jl +++ b/src/combinators/bind.jl @@ -112,12 +112,13 @@ of `x` in respect to density calculation and transport as well. """ function transportmeasure(μ::Bind, x) tpm_α, a, b = tpmeasure_split_combined(μ.α, x) - tpm_β_a = transportmeasure(μ.f_β(a), b) + tpm_β_a = transportmeasure(_get_β_a(μ, a), b) mcombine(μ.f_c, tpm_α, tpm_β_a) end localmeasure(μ::Bind, x) = transportmeasure(μ, x) +_get_β_a(μ::Bind, a) = asmeasure(μ.f_β(a)) tpmeasure_split_combined(f_c, μ::Bind, xy) = _bind_tsc(f_c, μ::Bind, xy) @@ -140,7 +141,7 @@ _bind_tsc(f_c::typeof(merge), μ::_CatBind{typeof{merge}}, xy::NamedTuple) = _bi function _bind_tsc_cat_lμabyxy(f_c, μ, xy) tpm_α, a, by = tpmeasure_split_combined(μ.f_c, μ.α, xy) - β_a = μ.f_β(a) + β_a = _get_β_a(μ, a) tpm_β_a, b, y = tpmeasure_split_combined(f_c, β_a, by) tpm_μ = mcombine(μ.f_c, tpm_α, tpm_β_a) return tpm_μ, a, b, y, xy @@ -177,28 +178,28 @@ logdensity_def(::Bind, x) = throw(ArgumentError("logdensity_def is not available # Specialize logdensityof to avoid duplicate calculations: function logdensityof(μ::Bind, x) tpm_α, a, b = tpmeasure_split_combined(μ.α, x) - β_a = μ.f_β(a) + β_a = _get_β_a(μ, a) logdensityof(tpm_α, a) + logdensityof(β_a, b) end # Specialize unsafe_logdensityof to avoid duplicate calculations: function unsafe_logdensityof(μ::Bind, x) tpm_α, a, b = tpmeasure_split_combined(μ.α, x) - β_a = μ.f_β(a) + β_a = _get_β_a(μ, a) unsafe_logdensityof(tpm_α, a) + unsafe_logdensityof(β_a, b) end function Base.rand(rng::Random.AbstractRNG, ::Type{T}, μ::Bind) where {T<:Real} a = rand(rng, T, μ.α) - b = rand(rng, T, μ.f_β(a)) + b = rand(rng, T, _get_β_a(μ, a)) return μ.f_c(a, b) end function transport_to_mvstd(ν_inner::StdMeasure, μ::Bind, ab) tpm_α, a, b = tpmeasure_split_combined(μ.f_c, μ.α, ab) - β_a = μ.f_β(a) + β_a = _get_β_a(μ, a) y1 = transport_to_mvstd(ν_inner, tpm_α, a) y2 = transport_to_mvstd(ν_inner, β_a, b) return vcat(y1, y2) diff --git a/src/combinators/likelihood.jl b/src/combinators/likelihood.jl index b244fd0f..8578040f 100644 --- a/src/combinators/likelihood.jl +++ b/src/combinators/likelihood.jl @@ -8,7 +8,7 @@ abstract type AbstractLikelihood end # ifelse(insupport(ℓ, p), t, f)() # end -# insupport(ℓ::AbstractLikelihood, p) = insupport(ℓ.k(p), ℓ.x) +# insupport(ℓ::AbstractLikelihood, p) = insupport(_eval_k(ℓ, p), ℓ.x) @doc raw""" Likelihood(k, x) @@ -98,7 +98,9 @@ end # For type stability, in case k is a type (resp. a constructor): Likelihood(k, x::X) where {X} = Likelihood{Core.Typeof(k),X}(k, x) -(lik::AbstractLikelihood)(p) = exp(ULogarithmic, logdensityof(lik.k(p), lik.x)) +(lik::AbstractLikelihood)(p) = exp(ULogarithmic, logdensityof(_eval_k(lik, p), lik.x)) + +_eval_k(ℓ::AbstractLikelihood, p) = asmeasure(_eval_k(ℓ, p)) DensityInterface.DensityKind(::AbstractLikelihood) = IsDensity() @@ -113,14 +115,14 @@ function Base.show(io::IO, ℓ::Likelihood) Pretty.pprint(io, ℓ) end -insupport(ℓ::AbstractLikelihood, p) = insupport(ℓ.k(p), ℓ.x) +insupport(ℓ::AbstractLikelihood, p) = insupport(_eval_k(ℓ, p), ℓ.x) @inline function logdensityof(ℓ::AbstractLikelihood, p) - logdensityof(ℓ.k(p), ℓ.x) + logdensityof(_eval_k(ℓ, p), ℓ.x) end @inline function unsafe_logdensityof(ℓ::AbstractLikelihood, p) - return unsafe_logdensityof(ℓ.k(p), ℓ.x) + return unsafe_logdensityof(_eval_k(ℓ, p), ℓ.x) end # basemeasure(ℓ::Likelihood) = @error "Likelihood requires local base measure" @@ -180,14 +182,14 @@ likelihoodof(k, x) = Likelihood(k, x) # Compute the log of the likelihood ratio, in order to compare two choices for # parameters. This is computed as -# logdensity_rel(ℓ.k(p), ℓ.k(q), ℓ.x) +# logdensity_rel(_eval_k(ℓ, p), ℓ.k(q), ℓ.x) # Since `logdensity_rel` can leave common base measure unevaluated, this can be # more efficient than -# logdensityof(ℓ.k(p), ℓ.x) - logdensityof(ℓ.k(q), ℓ.x) +# logdensityof(_eval_k(ℓ, p), ℓ.x) - logdensityof(ℓ.k(q), ℓ.x) # """ -# log_likelihood_ratio(ℓ::Likelihood, p, q) = logdensity_rel(ℓ.k(p), ℓ.k(q), ℓ.x) +# log_likelihood_ratio(ℓ::Likelihood, p, q) = logdensity_rel(_eval_k(ℓ, p), ℓ.k(q), ℓ.x) # # likelihoodof(k, x; kwargs...) = likelihoodof(k, x, NamedTuple(kwargs)) @@ -199,14 +201,14 @@ likelihoodof(k, x) = Likelihood(k, x) # Compute the log of the likelihood ratio, in order to compare two choices for # parameters. This is equal to -# density_rel(ℓ.k(p), ℓ.k(q), ℓ.x) +# density_rel(_eval_k(ℓ, p), ℓ.k(q), ℓ.x) # but is computed using LogarithmicNumbers.jl to avoid underflow and overflow. # Since `density_rel` can leave common base measure unevaluated, this can be # more efficient than -# logdensityof(ℓ.k(p), ℓ.x) - logdensityof(ℓ.k(q), ℓ.x) +# logdensityof(_eval_k(ℓ, p), ℓ.x) - logdensityof(ℓ.k(q), ℓ.x) # """ # function likelihood_ratio(ℓ::Likelihood, p, q) -# exp(ULogarithmic, logdensity_rel(ℓ.k(p), ℓ.k(q), ℓ.x)) +# exp(ULogarithmic, logdensity_rel(_eval_k(ℓ, p), ℓ.k(q), ℓ.x)) # end diff --git a/src/combinators/smart-constructors.jl b/src/combinators/smart-constructors.jl index 399d9471..2ba208d3 100644 --- a/src/combinators/smart-constructors.jl +++ b/src/combinators/smart-constructors.jl @@ -7,6 +7,18 @@ half(μ::AbstractMeasure) = Half(μ) ############################################################################### # PowerMeaure +""" + powermeasure(μ, dims) + powermeasure(μ, axes) + +Constructs a power of a measure `μ`. + +`powermeasure(μ, exponent)` is semantically equivalent to +`productmeasure(Fill(μ, exponent))`, but more efficient. +""" +function powermeasure end +export powermeasure + powermeasure(m::AbstractMeasure, ::Tuple{}) = asmeasure(m) @inline function powermeasure(x::T, sz::Tuple{Vararg{<:Any,N}}) where {T,N} From 58b34576db28f2275f2b437fa7712ed91cb26cb4 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sat, 1 Jul 2023 15:56:52 +0200 Subject: [PATCH 046/133] STASH --- src/transport.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/transport.jl b/src/transport.jl index ce8ce1fd..18b4461e 100644 --- a/src/transport.jl +++ b/src/transport.jl @@ -92,6 +92,8 @@ distribution itself or a power of it (e.g. `StdUniform()` or """ function transport_to end +@inline transport_to(ν, μ) = TransportFunction(asmeasure(ν), asmeasure(μ)) + """ transport_to(ν, μ, x) @@ -99,6 +101,7 @@ Transport `x` from the measure `μ` to the measure `ν` """ transport_to(ν, μ, x) = transport_to(ν, μ)(x) + """ transport_def(ν, μ, x) @@ -230,8 +233,6 @@ struct TransportFunction{NU,MU} <: Function end end -@inline transport_to(ν, μ) = TransportFunction(ν, μ) - function Base.:(==)(a::TransportFunction, b::TransportFunction) return a.ν == b.ν && a.μ == b.μ end From 83da5267789edc6231dece9cc8d378e889f86f32 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sat, 1 Jul 2023 17:03:44 +0200 Subject: [PATCH 047/133] STASH fast_dof --- src/MeasureBase.jl | 1 - src/combinators/combined.jl | 3 +- src/combinators/power.jl | 7 +- src/combinators/product.jl | 3 +- src/combinators/product_transport.jl | 28 +++----- src/combinators/transformedmeasure.jl | 13 +++- src/getdof.jl | 100 ++++++++++++++++++++++---- src/transport.jl | 8 +-- 8 files changed, 122 insertions(+), 41 deletions(-) diff --git a/src/MeasureBase.jl b/src/MeasureBase.jl index 846f4fc5..e0aed112 100644 --- a/src/MeasureBase.jl +++ b/src/MeasureBase.jl @@ -47,7 +47,6 @@ export basemeasure export basekernel export insupport -export getdof export transport_to include("insupport.jl") diff --git a/src/combinators/combined.jl b/src/combinators/combined.jl index d22fb35c..e17252e6 100644 --- a/src/combinators/combined.jl +++ b/src/combinators/combined.jl @@ -93,7 +93,8 @@ end # TODO: Could split `ab`` here, but would be wasteful. @inline insupport(::CombinedMeasure, ab) = NoFastInsupport() -@inline getdof(μ::CombinedMeasure) = getdof(μ.α) + getdof(μ.β) +@inline getdof(μ::CombinedMeasure) = _add_dof(getdof(μ.α), getdof(μ.β)) +@inline fast_dof(μ::CombinedMeasure) =_add_dof(fast_dof(μ.α), fast_dof(μ.β)) # Bypass `checked_arg`, would require require splitting ab: @inline checked_arg(::CombinedMeasure, ab) = ab diff --git a/src/combinators/power.jl b/src/combinators/power.jl index 31e2942d..475ce9d1 100644 --- a/src/combinators/power.jl +++ b/src/combinators/power.jl @@ -153,12 +153,17 @@ end end end -@inline getdof(μ::PowerMeasure) = getdof(μ.parent) * prod(pwr_size(μ)) +@inline getdof(μ::PowerMeasure) = _mul_dof(getdof(μ.parent), prod(pwr_size(μ))) +@inline fast_dof(μ::PowerMeasure) = _mul_dof(fast_dof(μ.parent), prod(pwr_size(μ))) @inline function getdof(::PowerMeasure{<:Any,NTuple{N,Static.SOneTo{0}}}) where {N} static(0) end +@inline function fast_dof(::PowerMeasure{<:Any,NTuple{N,Static.SOneTo{0}}}) where {N} + static(0) +end + @propagate_inbounds function checked_arg(μ::PowerMeasure, x::AbstractArray{<:Any}) @boundscheck begin sz_μ = pwr_size(μ) diff --git a/src/combinators/product.jl b/src/combinators/product.jl index 11260be8..d0349fbd 100644 --- a/src/combinators/product.jl +++ b/src/combinators/product.jl @@ -219,7 +219,8 @@ end return true end -getdof(d::AbstractProductMeasure) = mapreduce(getdof, +, marginals(d)) +getdof(d::AbstractProductMeasure) = mapreduce(getdof, _add_dof, marginals(d)) +fast_dof(d::AbstractProductMeasure) = mapreduce(fast_dof, _add_dof, marginals(d)) function checked_arg(μ::ProductMeasure{<:NTuple{N,Any}}, x::NTuple{N,Any}) where {N} map(checked_arg, marginals(μ), x) diff --git a/src/combinators/product_transport.jl b/src/combinators/product_transport.jl index 9ff22b2f..bd7feffd 100644 --- a/src/combinators/product_transport.jl +++ b/src/combinators/product_transport.jl @@ -45,15 +45,15 @@ function transport_def(ν::StdPowerMeasure{MU,1}, μ::AbstractMeasure, ab) where end function transport_to_mvstd(ν_inner::StdMeasure, μ::AbstractMeasure, x) - return _to_mvstd_withdof(ν_inner, μ, getdof(μ), x, origin) + return _to_mvstd_withdof(ν_inner, μ, fast_dof(μ), x, origin) end -function _to_mvstd_withdof(ν_inner::StdMeasure, μ::AbstractMeasure, dof_μ, x) +function _to_mvstd_withdof(ν_inner::StdMeasure, μ::AbstractMeasure, dof_μ::IntegerLike, x) y = transport_to(ν_inner^dof_μ, μ, x) return y end -function _to_mvstd_withdof(ν_inner::StdMeasure, μ::AbstractMeasure, ::NoDOF, x) +function _to_mvstd_withdof(ν_inner::StdMeasure, μ::AbstractMeasure, ::AbstractNoDOF, x) _to_mvstd_withorigin(ν_inner, μ, transport_origin(μ), x) end @@ -75,8 +75,6 @@ function transport_def(ν::AbstractMeasure, μ::StdPowerMeasure{MU,1}, x) where end function _transport_from_mvstd(ν::AbstractMeasure, μ_inner::StdMeasure, x) - # Sanity check, should be checked by transport machinery already: - @assert getdof(μ) == length(eachindex(x)) && x isa AbstractVector y, x_rest = transport_from_mvstd_with_rest(ν, μ_inner, x) if !isempty(x_rest) throw(ArgumentError("Input value too long during transport")) @@ -85,12 +83,12 @@ function _transport_from_mvstd(ν::AbstractMeasure, μ_inner::StdMeasure, x) end function transport_from_mvstd_with_rest(ν::AbstractMeasure, μ_inner::StdMeasure, x) - dof_ν = getdof(ν) + dof_ν = fast_dof(ν) origin = transport_origin(ν) - return _from_mvstd_with_rest_withdof(ν, getdof(ν), μ_inner, x, dof_ν, origin) + return _from_mvstd_with_rest_withdof(ν, dof_ν, μ_inner, x, dof_ν, origin) end -function _from_mvstd_with_rest_withdof(ν::AbstractMeasure, dof_ν, μ_inner::StdMeasure, x) +function _from_mvstd_with_rest_withdof(ν::AbstractMeasure, dof_ν::IntegerLike, μ_inner::StdMeasure, x) len_x = length(eachindex(x)) # Since we can't check DOF of original Bind, we could "run out x" if @@ -105,7 +103,7 @@ function _from_mvstd_with_rest_withdof(ν::AbstractMeasure, dof_ν, μ_inner::St return y, x_rest end -function _from_mvstd_with_rest_withdof(ν::AbstractMeasure, ::NoDOF, μ_inner::StdMeasure, x) +function _from_mvstd_with_rest_withdof(ν::AbstractMeasure, ::AbstractNoDOF, μ_inner::StdMeasure, x) _from_mvstd_with_rest_withorigin(ν, transport_origin(ν), μ_inner, x) end @@ -119,16 +117,13 @@ function _from_mvstd_with_rest_withorigin(ν::AbstractMeasure, NoTransportOrigin end - - # Implement transport_to(NU::Type{<:StdMeasure}, μ) and transport_to(ν, MU::Type{<:StdMeasure}) # for user convenience: -# ToDo: Handle combined/bind measures that don't have a fast getdof! +_std_measure_for(::Type{M}, μ::Any) where {M<:StdMeasure} = _std_measure_for_impl(M, some_dof(μ)) +_std_measure_for_impl(::Type{M}, ::StaticInteger{1}) where {M<:StdMeasure} = M() +_std_measure_for_impl(::Type{M}, dof::Integer) where {M<:StdMeasure} = M()^dof -_std_measure(::Type{M}, ::StaticInteger{1}) where {M<:StdMeasure} = M() -_std_measure(::Type{M}, dof::IntegerLike) where {M<:StdMeasure} = M()^dof -_std_measure_for(::Type{M}, μ::Any) where {M<:StdMeasure} = _std_measure(M, getdof(μ)) function transport_to(ν, ::Type{MU}) where {MU<:StdMeasure} transport_to(ν, _std_measure_for(MU, ν)) @@ -139,9 +134,6 @@ function transport_to(::Type{NU}, μ) where {NU<:StdMeasure} end - - - @inline transport_origin(μ::ProductMeasure) = _marginals_tp_origin(marginals(μ)) @inline from_origin(μ::ProductMeasure, x_origin) = _marginals_from_origin(marginals(μ), x_origin) diff --git a/src/combinators/transformedmeasure.jl b/src/combinators/transformedmeasure.jl index dab76d5f..65e58cfa 100644 --- a/src/combinators/transformedmeasure.jl +++ b/src/combinators/transformedmeasure.jl @@ -88,10 +88,17 @@ end pushfwd(ν.f, basemeasure(parent(ν)), NoVolCorr()) end -_pushfwd_dof(::Type{MU}, ::Type, dof) where {MU} = NoDOF{MU}() -_pushfwd_dof(::Type{MU}, ::Type{<:Tuple{Any,Real}}, dof) where {MU} = dof -@inline getdof(ν::MU) where {MU<:PushforwardMeasure} = getdof(ν.origin) +const _NonBijectivePushforward = Union{PushforwardMeasure{<:Any,<:NoInverse},PushforwardMeasure{<:NoInverse,<:Any},PushforwardMeasure{<:NoInverse,<:NoInverse}} + +@inline getdof(ν::PushforwardMeasure) = _pushfwd_dof(ν) +_pushfwd_dof(ν::PushforwardMeasure) = getdof(ν.origin) +_pushfwd_dof(ν::_NonBijectivePushforward) = NoDOF{typeof(ν)}() + +@inline fast_dof(ν::PushforwardMeasure) = _pushfwd_fastdof(ν) +_pushfwd_fastdof(ν::PushforwardMeasure) = fast_dof(ν.origin) +_pushfwd_fastdof(ν::_NonBijectivePushforward) = NoDOF{typeof(ν)}() + # Bypass `checked_arg`, would require potentially costly transformation: @inline checked_arg(::PushforwardMeasure, x) = x diff --git a/src/getdof.jl b/src/getdof.jl index 6a31fb4b..2ee16ea8 100644 --- a/src/getdof.jl +++ b/src/getdof.jl @@ -1,16 +1,24 @@ """ - MeasureBase.NoDOF{MU} + abstract type MeasureBase.AbstractNoDOF + +Abstract supertype for [`NoDOF`](@ref) and [`NoFastDOF`](@ref). +""" +abstract type AbstractNoDOF end + +_add_dof(dof_a::Real, dof_b::Real) = dof_a + dof_b +_add_dof(dof_a::AbstractNoDOF, ::Real) = dof_a +_add_dof(::Real, dof_b::AbstractNoDOF) = dof_b +_add_dof(dof_a::AbstractNoDOF, ::AbstractNoDOF) = dof_a + + +""" + MeasureBase.NoDOF{MU} <: AbstractNoDOF Indicates that there is no way to compute degrees of freedom of a measure of type `MU` with the given information, e.g. because the DOF are not a global property of the measure. """ -struct NoDOF{MU} end - -_add_dof(dof_a::Real, dof_b::Real) = dof_a + dof_b -_add_dof(dof_a::NoDOF, ::Real) = dof_a -_add_dof(::Real, dof_b::NoDOF) = dof_b -_add_dof(dof_a::NoDOF, ::NoDOF) = dof_a +struct NoDOF{MU} <: AbstractNoDOF end """ @@ -26,26 +34,93 @@ is `n - 1`. Also see [`check_dof`](@ref). """ function getdof end +export getdof # Prevent infinite recursion: -@inline _default_getdof(::Type{MU}, ::MU) where {MU} = NoDOF{MU} +@inline _default_getdof(::Type{MU}, ::MU) where {MU} = NoDOF{MU}() @inline _default_getdof(::Type{MU}, mu_base) where {MU} = getdof(mu_base) @inline getdof(μ::MU) where {MU} = _default_getdof(MU, basemeasure(μ)) + +""" + MeasureBase.NoFastDOF{MU} <: AbstractNoDOF + +Indicates that there is no way to compute degrees of freedom of a measure +of type `MU` with the given information, e.g. because the DOF are not +a global property of the measure. +""" +struct NoFastDOF{MU} <: AbstractNoDOF end + + +""" + fast_dof(μ::MU) + +Returns the effective number of degrees of freedom of variates of +measure `μ`, if it can be computed efficiently, otherwise +returns [`NoFastDOF{MU}()`](@ref). + +Defaults to `getdof(μ)` and should be specialized for measures for +wich DOF can't be computed instantly. + +The effective NDOF my differ from the length of the variates. For example, +the effective NDOF for a Dirichlet distribution with variates of length `n` +is `n - 1`. + +Also see [`check_dof`](@ref). +""" +function fast_dof end +export fast_dof + +fast_dof(μ) = getdof(μ) + +# Prevent infinite recursion: +@inline _default_fastdof(::Type{MU}, ::MU) where {MU} = NoFastDOF{MU}() +@inline _default_fastdof(::Type{MU}, mu_base) where {MU} = fast_dof(mu_base) + +@inline fast_dof(μ::MU) where {MU} = _default_fastdof(MU, basemeasure(μ)) + + +""" + MeasureBase.some_dof(μ::AbstractMeasure) + +Get the DOF at some unspecified point of measure `μ`. + +Use with caution! + +In general, use [`getdof(μ)`](@ref) instead. `some_dof` is useful +for measures are expected to have a constant DOF of their whole +space but for which there is no way to compute it (or prove that +the DOF is constant of the measurable space). +""" +function some_dof end + +function some_dof() + m = asmeasure(μ) + _try_direct_dof(m, getdof(m)) +end + +_try_direct_dof(::AbstractMeasure, dof::IntegerLike) = dof +_try_direct_dof(μ::AbstractMeasure, ::AbstractNoDOF) = _try_local_dof(μ::AbstractMeasure, some_dof(some_localmeasure(μ))) +_try_local_dof(::AbstractMeasure, dof::IntegerLike) = dof +_try_local_dof(μ::AbstractMeasure, ::AbstractNoDOF) = throw(ArgumentError("Can't determine DOF for measure of type $(nameof(typeof(μ)))")) + +_some_localmeasure(μ::AbstractMeasure) = localmeasure(μ, testvalue(μ)) + + """ MeasureBase.check_dof(ν, μ)::Nothing Check if `ν` and `μ` have the same effective number of degrees of freedom -according to [`MeasureBase.getdof`](@ref). +according to [`MeasureBase.fast_dof`](@ref). """ function check_dof end function check_dof(ν, μ) - n_ν = getdof(ν) - n_μ = getdof(μ) + n_ν = fast_dof(ν) + n_μ = fast_dof(μ) # TODO: How to handle this better if DOF is unclear e.g. for HierarchicalMeasures? - if n_ν isa NoDOF || n_μ isa NoDOF + if n_ν isa AbstractNoDOF || n_μ isa AbstractNoDOF return true end if n_ν != n_μ @@ -61,6 +136,7 @@ end _check_dof_pullback(ΔΩ) = NoTangent(), NoTangent(), NoTangent() ChainRulesCore.rrule(::typeof(check_dof), ν, μ) = check_dof(ν, μ), _check_dof_pullback + """ MeasureBase.NoArgCheck{MU,T} diff --git a/src/transport.jl b/src/transport.jl index 18b4461e..a8f2682b 100644 --- a/src/transport.jl +++ b/src/transport.jl @@ -76,7 +76,7 @@ and/or * `MeasureBase.from_origin(μ::MyMeasure, x) = y` * `MeasureBase.to_origin(μ::MyMeasure, y) = x` -and ensure `MeasureBase.getdof(μ::MyMeasure)` is defined correctly. +and ensure `MeasureBase.fast_dof(μ::MyMeasure)` is defined correctly. A standard measure type like `StdUniform`, `StdExponential` or `StdLogistic` may also be used as the source or target of the transform: @@ -86,8 +86,8 @@ f_to_uniform(StdUniform, μ) f_to_uniform(ν, StdUniform) ``` -Depending on [`getdof(μ)`](@ref) (resp. `ν`), an instance of the standard -distribution itself or a power of it (e.g. `StdUniform()` or +Depending on [`some_getdof(μ)`](@ref) (resp. `ν`), an instance of the +standard measure itself or a power of it (e.g. `StdUniform()` or `StdUniform()^dof`) will be chosen as the transformation partner. """ function transport_to end @@ -189,7 +189,7 @@ end return prog end -@inline _transport_intermediate(ν, μ) = _transport_intermediate(getdof(ν), getdof(μ)) +@inline _transport_intermediate(ν, μ) = _transport_intermediate(fast_dof(ν), fast_dof(μ)) @inline _transport_intermediate(::Integer, n_μ::Integer) = StdUniform()^n_μ @inline _transport_intermediate(::StaticInteger{1}, ::StaticInteger{1}) = StdUniform() From 10b7fd0bfd61dba0d2661fb82c88e6c6ca30393a Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sat, 1 Jul 2023 18:03:54 +0200 Subject: [PATCH 048/133] STASH --- src/MeasureBase.jl | 8 +-- src/combinators/abstract_product.jl | 42 +++++++++++++ src/combinators/power.jl | 14 ----- src/combinators/product.jl | 89 +++++++++------------------- src/combinators/product_transport.jl | 20 ++++++- src/standard/stdmeasure.jl | 7 --- 6 files changed, 92 insertions(+), 88 deletions(-) create mode 100644 src/combinators/abstract_product.jl diff --git a/src/MeasureBase.jl b/src/MeasureBase.jl index e0aed112..216d1f63 100644 --- a/src/MeasureBase.jl +++ b/src/MeasureBase.jl @@ -138,16 +138,16 @@ include("primitives/lebesgue.jl") include("primitives/dirac.jl") include("primitives/trivial.jl") -include("combinators/product.jl") -include("combinators/power.jl") -include("combinators/product_transport.jl") - include("standard/stdmeasure.jl") include("standard/stduniform.jl") include("standard/stdexponential.jl") include("standard/stdlogistic.jl") include("standard/stdnormal.jl") +include("combinators/abstract_product.jl") +include("combinators/power.jl") +include("combinators/product.jl") +include("combinators/product_transport.jl") include("combinators/transformedmeasure.jl") include("combinators/weighted.jl") include("combinators/superpose.jl") diff --git a/src/combinators/abstract_product.jl b/src/combinators/abstract_product.jl new file mode 100644 index 00000000..2d14cbce --- /dev/null +++ b/src/combinators/abstract_product.jl @@ -0,0 +1,42 @@ +""" + marginals(μ::AbstractMeasure) + +Returns the marginals measures of `μ` as a collection of measures. + +The kind of marginalization implied by `marginals` depends on the +type of `μ`. + +`μ` may be a power of a measure or a product of measures, but other +types of measures may support `marginals` as well. +""" +function marginals end +export marginals + + +""" + abstract type AbstractProductMeasure + +Abstact type for products of measures. + +[`marginals(μ::AbstractProductMeasure)`](@ref) returns the collection of +measures that `μ` is the product of. +""" +abstract type AbstractProductMeasure <: AbstractMeasure end +export AbstractProductMeasure + +function Pretty.tile(μ::AbstractProductMeasure) + result = Pretty.literal("ProductMeasure(") + result *= Pretty.tile(marginals(μ)) + result *= Pretty.literal(")") +end + +massof(m::AbstractProductMeasure) = prod(massof, marginals(m)) + +Base.:(==)(a::AbstractProductMeasure, b::AbstractProductMeasure) = marginals(a) == marginals(b) +Base.isapprox(a::AbstractProductMeasure, b::AbstractProductMeasure; kwargs...) = isapprox(marginals(a), marginals(b); kwargs...) + + +# # ToDo: Do we want this? It's not so clear what the semantics of `length` and `size` +# # for measures should be, in general: +# Base.length(μ::AbstractProductMeasure) = length(marginals(μ)) +# Base.size(μ::AbstractProductMeasure) = size(marginals(μ)) diff --git a/src/combinators/power.jl b/src/combinators/power.jl index 475ce9d1..bbbdc198 100644 --- a/src/combinators/power.jl +++ b/src/combinators/power.jl @@ -1,19 +1,5 @@ import Base -""" - marginals(μ::AbstractMeasure) - -Returns the marginals measures of `μ` as a collection of measures. - -The kind of marginalization implied by `marginals` depends on the -type of `μ`. - -`μ` may be a power of a measure or a product of measures, but other -types of measures may support `marginals` as well. -""" -function marginals end -export marginals - export PowerMeasure diff --git a/src/combinators/product.jl b/src/combinators/product.jl index d0349fbd..7d036d75 100644 --- a/src/combinators/product.jl +++ b/src/combinators/product.jl @@ -3,44 +3,43 @@ using MappedArrays: ReadonlyMultiMappedArray using Base: @propagate_inbounds import Base using FillArrays +using Random: rand!, GLOBAL_RNG, AbstractRNG """ - abstract type AbstractProductMeasure + struct MeasureBase.ProductMeasure{M} <: AbstractProductMeasure -Abstact type for products of measures. +Represents a products of measures. -[`marginals(μ::AbstractProductMeasure)`](@ref) returns the collection of -measures that `μ` is the product of. +´ProductMeasure` wraps a collection of measures, this collection then +becomes the collection of the marginal measures of the `ProductMeasure`. + +User code should not instantiate `ProductMeasure` directly, but should call +[`productmeasure`](@ref) instead. """ -abstract type AbstractProductMeasure <: AbstractMeasure end -export AbstractProductMeasure +struct ProductMeasure{M} <: AbstractProductMeasure + marginals::M +end -function Pretty.tile(μ::AbstractProductMeasure) - result = Pretty.literal("ProductMeasure(") - result *= Pretty.tile(marginals(μ)) - result *= Pretty.literal(")") +function Pretty.tile(d::ProductMeasure{T}) where {T<:Tuple} + Pretty.list_layout(Pretty.tile.([marginals(d)...]), sep = " ⊗ ") end -massof(m::AbstractProductMeasure) = prod(massof, marginals(m)) +marginals(μ::ProductMeasure) = μ.marginals -function Base.:(==)(a::AbstractProductMeasure, b::AbstractProductMeasure) - marginals(a) == marginals(b) -end -Base.length(μ::AbstractProductMeasure) = length(marginals(μ)) -Base.size(μ::AbstractProductMeasure) = size(marginals(μ)) +proxy(μ::ProductMeasure{<:Fill}) = powermeasure(_fill_value(marginals(μ)), _fill_axes(marginals(μ))) # TODO: Better `map` support in MappedArrays _map(f, args...) = map(f, args...) _map(f, x::MappedArrays.ReadonlyMappedArray) = mappedarray(fchain((x.f, f)), x.data) -function testvalue(::Type{T}, μ::AbstractProductMeasure) where {T} +function testvalue(::Type{T}, μ::ProductMeasure) where {T} _map(m -> testvalue(T, m), marginals(μ)) end -function Base.rand(rng::AbstractRNG, ::Type{T}, d::AbstractProductMeasure) where {T} +function Base.rand(rng::AbstractRNG, ::Type{T}, d::ProductMeasure) where {T} mar = marginals(d) _rand_product(rng, T, mar, eltype(mar)) end @@ -85,13 +84,13 @@ function _rand_product( end -@inline function logdensity_def(μ::AbstractProductMeasure, x) +@inline function logdensity_def(μ::ProductMeasure, x) marginals_density_op(logdensity_def, marginals(μ), x) end -@inline function unsafe_logdensityof(μ::AbstractProductMeasure, x) +@inline function unsafe_logdensityof(μ::ProductMeasure, x) marginals_density_op(unsafe_logdensityof, marginals(μ), x) end -@inline function logdensity_rel(μ::AbstractProductMeasure, ν::AbstractProductMeasure, x) +@inline function logdensity_rel(μ::ProductMeasure, ν::ProductMeasure, x) marginals_density_op(logdensity_rel, marginals(μ), marginals(ν), x) end @@ -120,10 +119,13 @@ end end -@inline basemeasure(μ::AbstractProductMeasure) =_marginals_basemeasure(marginals(μ)) +@inline basemeasure(μ::ProductMeasure) =_marginals_basemeasure(marginals(μ)) _marginals_basemeasure(marginals_μ) = productmeasure(map(basemeasure, marginals_μ)) + +# I <: Base.Generator + function _marginals_basemeasure(marginals_μ::Base.Generator{I,F}) where {I,F} T = Core.Compiler.return_type(marginals_μ.f, Tuple{eltype(marginals_μ.iter)}) B = Core.Compiler.return_type(basemeasure, Tuple{T}) @@ -148,37 +150,6 @@ function _marginals_basemeasure_impl(marginals_μ::Base.Generator{I,F}, ::Type{B end - -""" - struct MeasureBase.ProductMeasure{M} <: AbstractProductMeasure - -Represents a products of measures. - -´ProductMeasure` wraps a collection of measures, this collection then -becomes the collection of the marginal measures of the `ProductMeasure`. - -User code should not instantiate `ProductMeasure` directly, but should call -[`productmeasure`](@ref) instead. -""" -struct ProductMeasure{M} <: AbstractProductMeasure - marginals::M -end - -function Pretty.tile(d::ProductMeasure{T}) where {T<:Tuple} - Pretty.list_layout(Pretty.tile.([marginals(d)...]), sep = " ⊗ ") -end - -marginals(μ::ProductMeasure) = μ.marginals - -proxy(μ::ProductMeasure{<:Fill}) = powermeasure(_fill_value(marginals(μ)), _fill_axes(marginals(μ))) - - -############################################################################### -# I <: Base.Generator - -export rand! -using Random: rand!, GLOBAL_RNG, AbstractRNG - @propagate_inbounds function Random.rand!( rng::AbstractRNG, d::ProductMeasure, @@ -192,8 +163,6 @@ using Random: rand!, GLOBAL_RNG, AbstractRNG return x end -export rand! -using Random: rand!, GLOBAL_RNG function _rand(rng::AbstractRNG, ::Type{T}, d::ProductMeasure, mar::AbstractArray) where {T} elT = typeof(rand(rng, T, first(mar))) @@ -203,7 +172,7 @@ function _rand(rng::AbstractRNG, ::Type{T}, d::ProductMeasure, mar::AbstractArra rand!(rng, d, x) end -@inline function insupport(d::AbstractProductMeasure, x::AbstractArray) +@inline function insupport(d::ProductMeasure, x::AbstractArray) mar = marginals(d) # We might get lucky and know statically that everything is inbounds T = Core.Compiler.return_type(insupport, Tuple{eltype(mar),eltype(x)}) @@ -212,15 +181,15 @@ end end end -@inline function insupport(d::AbstractProductMeasure, x) +@inline function insupport(d::ProductMeasure, x) for (mj, xj) in zip(marginals(d), x) dynamic(insupport(mj, xj)) || return false end return true end -getdof(d::AbstractProductMeasure) = mapreduce(getdof, _add_dof, marginals(d)) -fast_dof(d::AbstractProductMeasure) = mapreduce(fast_dof, _add_dof, marginals(d)) +getdof(d::ProductMeasure) = mapreduce(getdof, _add_dof, marginals(d)) +fast_dof(d::ProductMeasure) = mapreduce(fast_dof, _add_dof, marginals(d)) function checked_arg(μ::ProductMeasure{<:NTuple{N,Any}}, x::NTuple{N,Any}) where {N} map(checked_arg, marginals(μ), x) @@ -233,5 +202,3 @@ function checked_arg( NamedTuple{names}(map(checked_arg, values(marginals(μ)), values(x))) end - - diff --git a/src/combinators/product_transport.jl b/src/combinators/product_transport.jl index bd7feffd..421fc5d6 100644 --- a/src/combinators/product_transport.jl +++ b/src/combinators/product_transport.jl @@ -182,10 +182,26 @@ end -# Transport for products +function transport_to_mvstd(ν_inner::StdMeasure, μ::ProductMeasure, ab) + _marginals_to_mvstd(ν_inner, marginals(μ), ab) +end + +function transport_from_mvstd_with_rest(ν::ProductMeasure, μ_inner::StdMeasure, x) + a, x2 = transport_from_mvstd_with_rest(ν.α, μ_inner, x) + b, x_rest = transport_from_mvstd_with_rest(ν.β, μ_inner, x2) + return ν.f_c(a, b), x_rest +end + +# Transport between a standard measure and Dirac: -#!!!!!!!!!!!!!!!!!!!!!! TODO: +@inline transport_from_mvstd_with_rest(ν::Dirac, ::StdMeasure, x::Any) = ν.x, x + +@inline transport_to_mvstd(::StdMeasure, ::Dirac, ::Any) = Zeros{Bool}(0) + + + +# Transport for products # Helpers for product transforms and similar: diff --git a/src/standard/stdmeasure.jl b/src/standard/stdmeasure.jl index a571b54d..05843ede 100644 --- a/src/standard/stdmeasure.jl +++ b/src/standard/stdmeasure.jl @@ -37,10 +37,3 @@ StdMeasure(::typeof(randn)) = StdNormal() # Transport between two equal standard measures: @inline transport_def(::MU, μ::MU, x) where {MU<:StdMeasure} = x - - -# Transport between a standard measure and Dirac: - -@inline transport_from_mvstd_with_rest(ν::Dirac, ::StdMeasure, x::Any) = ν.x, x - -@inline transport_to_mvstd(ν::PowerMeasure{<:StdMeasure}, ::Dirac, ::Any) = Zeros{Bool}(map(_ -> 0, ν.axes)) From a86823382dd48d0ef1a08add1a9b8778a835df5d Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sat, 1 Jul 2023 18:26:16 +0200 Subject: [PATCH 049/133] STASH --- src/combinators/product_transport.jl | 64 ++++++++++++++++++++-------- src/transport.jl | 12 ------ 2 files changed, 47 insertions(+), 29 deletions(-) diff --git a/src/combinators/product_transport.jl b/src/combinators/product_transport.jl index 421fc5d6..5dd8e004 100644 --- a/src/combinators/product_transport.jl +++ b/src/combinators/product_transport.jl @@ -1,3 +1,42 @@ +""" + transport_to(ν, ::Type{MU}) where {NU<:StdMeasure} + transport_to(::Type{NU}, μ) where {NU<:StdMeasure} + +As a user convencience, a standard measure type like [`StdUniform`](@ref), +[`StdExponential`](@ref), [`StdNormal`](@ref) or [`StdLogistic`](@ref) +may be used directly as the source or target a measure transport. + +Depending on [`some_getdof(μ)`](@ref) (resp. `ν`), an instance of the +standard measure itself or a power of it will be automatically chosen as +the transport partner. + +Example: + +```julia +transport_to(StdNormal, μ) +transport_to(ν, StdNormal) +``` +""" +function transport_to(ν, ::Type{MU}) where {MU<:StdMeasure} + transport_to(ν, _std_tp_partner(MU, ν)) +end + +function transport_to(::Type{NU}, μ) where {NU<:StdMeasure} + transport_to(_std_tp_partner(NU, μ), μ) +end + +function transport_to(::Type{NU}, ::Type{MU}) where {NU<:StdMeasure,MU<:StdMeasure} + throw(ArgumentError("Can't construct a transport function between the type of two standard measures, need a measure instance on one side")) +end + +_std_tp_partner(::Type{M}, μ) where {M<:StdMeasure} = _std_tp_partner_bydof(M, some_dof(μ)) +_std_tp_partner_bydof(::Type{M}, ::StaticInteger{1}) where {M<:StdMeasure} = M() +_std_tp_partner_bydof(::Type{M}, dof::IntegerLike) where {M<:StdMeasure} = M()^dof +function _std_tp_partner_bydof(::Type{M}, dof::AbstractNoDOF{MU}) where {M<:StdMeasure,MU} + throw(ArgumentError("Can't determine a standard transport partner for measures of type $(nameof(typeof(MU)))")) +end + + # For transport, always pull a PowerMeasure back to one-dimensional PowerMeasure first: transport_origin(μ::PowerMeasure{<:Any,N}) where N = ν.parent^product(pwr_size(μ)) @@ -117,21 +156,14 @@ function _from_mvstd_with_rest_withorigin(ν::AbstractMeasure, NoTransportOrigin end -# Implement transport_to(NU::Type{<:StdMeasure}, μ) and transport_to(ν, MU::Type{<:StdMeasure}) -# for user convenience: +# Transport between a standard measure and Dirac: -_std_measure_for(::Type{M}, μ::Any) where {M<:StdMeasure} = _std_measure_for_impl(M, some_dof(μ)) -_std_measure_for_impl(::Type{M}, ::StaticInteger{1}) where {M<:StdMeasure} = M() -_std_measure_for_impl(::Type{M}, dof::Integer) where {M<:StdMeasure} = M()^dof +@inline transport_from_mvstd_with_rest(ν::Dirac, ::StdMeasure, x::Any) = ν.x, x + +@inline transport_to_mvstd(::StdMeasure, ::Dirac, ::Any) = Zeros{Bool}(0) -function transport_to(ν, ::Type{MU}) where {MU<:StdMeasure} - transport_to(ν, _std_measure_for(MU, ν)) -end -function transport_to(::Type{NU}, μ) where {NU<:StdMeasure} - transport_to(_std_measure_for(NU, μ), μ) -end @inline transport_origin(μ::ProductMeasure) = _marginals_tp_origin(marginals(μ)) @@ -182,6 +214,10 @@ end + + + + function transport_to_mvstd(ν_inner::StdMeasure, μ::ProductMeasure, ab) _marginals_to_mvstd(ν_inner, marginals(μ), ab) end @@ -193,12 +229,6 @@ function transport_from_mvstd_with_rest(ν::ProductMeasure, μ_inner::StdMeasure end -# Transport between a standard measure and Dirac: - -@inline transport_from_mvstd_with_rest(ν::Dirac, ::StdMeasure, x::Any) = ν.x, x - -@inline transport_to_mvstd(::StdMeasure, ::Dirac, ::Any) = Zeros{Bool}(0) - # Transport for products diff --git a/src/transport.jl b/src/transport.jl index a8f2682b..ce736b72 100644 --- a/src/transport.jl +++ b/src/transport.jl @@ -77,18 +77,6 @@ and/or * `MeasureBase.to_origin(μ::MyMeasure, y) = x` and ensure `MeasureBase.fast_dof(μ::MyMeasure)` is defined correctly. - -A standard measure type like `StdUniform`, `StdExponential` or -`StdLogistic` may also be used as the source or target of the transform: - -```julia -f_to_uniform(StdUniform, μ) -f_to_uniform(ν, StdUniform) -``` - -Depending on [`some_getdof(μ)`](@ref) (resp. `ν`), an instance of the -standard measure itself or a power of it (e.g. `StdUniform()` or -`StdUniform()^dof`) will be chosen as the transformation partner. """ function transport_to end From 8d1a2113516be081c59249adc8fff37ce117dd79 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sat, 1 Jul 2023 18:27:23 +0200 Subject: [PATCH 050/133] STASH --- src/combinators/product_transport.jl | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/combinators/product_transport.jl b/src/combinators/product_transport.jl index 5dd8e004..9da9c0d2 100644 --- a/src/combinators/product_transport.jl +++ b/src/combinators/product_transport.jl @@ -223,9 +223,7 @@ function transport_to_mvstd(ν_inner::StdMeasure, μ::ProductMeasure, ab) end function transport_from_mvstd_with_rest(ν::ProductMeasure, μ_inner::StdMeasure, x) - a, x2 = transport_from_mvstd_with_rest(ν.α, μ_inner, x) - b, x_rest = transport_from_mvstd_with_rest(ν.β, μ_inner, x2) - return ν.f_c(a, b), x_rest + _marginals_from_mvstd_with_rest(marginals(ν), μ_inner, x) end From 2c21ced25d6685e2a9a570eb4bb1c183364b5653 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sat, 1 Jul 2023 18:27:36 +0200 Subject: [PATCH 051/133] STASH --- src/combinators/product_transport.jl | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/combinators/product_transport.jl b/src/combinators/product_transport.jl index 9da9c0d2..e99a39b6 100644 --- a/src/combinators/product_transport.jl +++ b/src/combinators/product_transport.jl @@ -215,9 +215,6 @@ end - - - function transport_to_mvstd(ν_inner::StdMeasure, μ::ProductMeasure, ab) _marginals_to_mvstd(ν_inner, marginals(μ), ab) end From 31c227c9af136102a636477d1b2f932b494b0c35 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sat, 1 Jul 2023 20:05:24 +0200 Subject: [PATCH 052/133] STASH --- src/collection_utils.jl | 22 ++++++- src/combinators/product_transport.jl | 86 ++++++++-------------------- src/proxies.jl | 13 ++++- 3 files changed, 54 insertions(+), 67 deletions(-) diff --git a/src/collection_utils.jl b/src/collection_utils.jl index 3002259a..02d836b5 100644 --- a/src/collection_utils.jl +++ b/src/collection_utils.jl @@ -1,10 +1,14 @@ +# ToDo: Specialize for StaticArray: +@inline _getindex_or_view(A, idxs...) = view(A, idxs...) + + # ToDo: Add custom rrules for _split_after? # ToDo: Use getindex instead of view for certain cases (array types)? -@inline function split_after(x::AbstractVector, n) +@inline function _split_after(x::AbstractVector, n) i_first = firstindex(x) i_last = lastindex(x) - view(x, i_first, i_first+n-1), view(x, n, i_last) + _getindex_or_view(x, i_first, i_first+n-1), _getindex_or_view(x, n, i_last) end @inline _split_after(x::Tuple, n) = _split_after(x::Tuple, Val{n}()) @@ -65,3 +69,17 @@ end _fill_value(x::Fill) = x.value _fill_axes(x::Fill) = x.axes + + +# ToDo: Flatten to SVector: +_flatten_to_rv(tpl::Tuple{Vararg{Number}}) = vcat(x...) + +# ToDo: +#_flatten_to_vector(tpl::Tuple{Vararg{StaticVector}}) = ... + +_flatten_to_flatten_to_numvectorvector(tpl::Tuple{AbstractVector{<:Number}}) = vcat(tpl...) + +# ToDo: Use faster implementation that handles large numbers of vectors efficiently: +_flatten_to_rv(V::AbstractVector{<:AbstractVector{<:Number}}) = vcat(V...) + +_flatten_to_rv(V::ArrayOfSimilarArray{<:Number}) = flatview(A) diff --git a/src/combinators/product_transport.jl b/src/combinators/product_transport.jl index e99a39b6..2e51aceb 100644 --- a/src/combinators/product_transport.jl +++ b/src/combinators/product_transport.jl @@ -214,84 +214,46 @@ end +# Transport from ProductMeasure to StdMeasure type: -function transport_to_mvstd(ν_inner::StdMeasure, μ::ProductMeasure, ab) - _marginals_to_mvstd(ν_inner, marginals(μ), ab) -end - -function transport_from_mvstd_with_rest(ν::ProductMeasure, μ_inner::StdMeasure, x) - _marginals_from_mvstd_with_rest(marginals(ν), μ_inner, x) +function transport_to_mvstd(ν_inner::StdMeasure, μ::ProductMeasure, x) + _marginals_to_mvstd(ν_inner, marginals(μ), x) end +struct _TransportToMvStd{NU<:StdMeasure} <: Function end +(::_TransportToMvStd{NU})(μ, x) where {NU} = transport_to_mvstd(NU(), μ, x) +function _marginals_to_mvstd(::StdMeasure{NU}, marginals_μ::Tuple, x::Tuple) where NU + _flatten_to_rv(map(_TransportToMvStd{NU}(), marginals_μ, x)) +end +function _marginals_to_mvstd(::StdMeasure{NU}, marginals_μ, x) where NU + _flatten_to_rv(broadcast(_TransportToMvStd{NU}(), marginals_μ, x)) +end -# Transport for products - -# Helpers for product transforms and similar: -struct _TransportToStd{NU<:StdMeasure} <: Function end -_TransportToStd{NU}(μ, x) where {NU} = transport_to(NU()^getdof(μ), μ)(x) -struct _TransportFromStd{MU<:StdMeasure} <: Function end -_TransportFromStd{MU}(ν, x) where {MU} = transport_to(ν, MU()^getdof(ν))(x) +# Transport StdMeasure type to ProductMeasure, with rest: -function _tuple_transport_def( - ν::PowerMeasure{NU}, - μs::Tuple, - xs::Tuple, -) where {NU<:StdMeasure} - reshape(vcat(map(_TransportToStd{NU}, μs, xs)...), ν.axes) +function transport_from_mvstd_with_rest(ν::ProductMeasure, μ_inner::StdMeasure, x) + marginals_μ = marginals(μ) + marg_dof = _marginals_dof(marginals_μ) + marg_offs = _marginal_offsets(marg_dof) + _marginals_from_mvstd_with_rest(marginals_ν, marg_dof, μ_inner, x) end -function transport_def( - ν::PowerMeasure{NU}, - μ::ProductMeasure{<:Tuple}, - x, -) where {NU<:StdMeasure} - _tuple_transport_def(ν, marginals(μ), x) -end -function transport_def( - ν::PowerMeasure{NU}, - μ::ProductMeasure{<:NamedTuple{names}}, - x, -) where {NU<:StdMeasure,names} - _tuple_transport_def(ν, values(marginals(μ)), values(x)) +function _marginals_dof(marginals_μ::Tuple{Vararg{AbstractMeasure,N}}) where N + map(fast_getdof, marginals_μ) end -@inline _offset_cumsum(s, x, y, rest...) = (s, _offset_cumsum(s + x, y, rest...)...) -@inline _offset_cumsum(s, x) = (s,) -@inline _offset_cumsum(s) = () -function _stdvar_viewranges(μs::Tuple, startidx::IntegerLike) - N = map(getdof, μs) - offs = _offset_cumsum(startidx, N...) - map((o, n) -> o:o+n-1, offs, N) -end - -function _tuple_transport_def( - νs::Tuple, - μ::PowerMeasure{MU}, - x::AbstractArray{<:Real}, -) where {MU<:StdMeasure} - vrs = _stdvar_viewranges(νs, firstindex(x)) - xs = map(r -> view(x, r), vrs) - map(_TransportFromStd{MU}, νs, xs) -end +# ToDo: Use static array for result: +_marginals_to_mvstd(ν_inner::StdMeasure, marginals_μ::Tuple, x) -function transport_def( - ν::ProductMeasure{<:Tuple}, - μ::PowerMeasure{MU}, - x, -) where {MU<:StdMeasure} - _tuple_transport_def(marginals(ν), μ, x) +function _marginals_to_mvstd_split_x(marg_dof::Tuple{Vararg{StaticInteger,N}}, x::Tuple{Vararg{Any,N}}) where N end -function transport_def( - ν::ProductMeasure{<:NamedTuple{names}}, - μ::PowerMeasure{MU}, - x, -) where {MU<:StdMeasure,names} - NamedTuple{names}(_tuple_transport_def(values(marginals(ν)), μ, x)) +function _marginal_offsets(marg_dof::Tuple{Vararg{StaticInteger,N}}) where N + _offset_cumsum(0, marg_dof...) end diff --git a/src/proxies.jl b/src/proxies.jl index 0304bbb3..aeb451ae 100644 --- a/src/proxies.jl +++ b/src/proxies.jl @@ -14,18 +14,25 @@ function proxy end macro useproxy(M) M = esc(M) quote - #!!!!!!!!!!! TODO add new API methods like localmeasure, transportmeasure, etc. !!!!!!!!!!!!! - @inline $MeasureBase.logdensity_def(μ::$M, x) = logdensity_def(proxy(μ), x) + @inline $MeasureBase.unsafe_logdensityof(μ::$M, x) = unsafe_logdensityof(proxy(μ), x) @inline $MeasureBase.basemeasure(μ::$M) = basemeasure(proxy(μ)) - @inline $MeasureBase.basemeasure_depth(μ::$M) = basemeasure_depth(proxy(μ)) + @inline $MeasureBase.rootmeasure(μ::$M) = rootmeasure(proxy(μ)) + @inline $MeasureBase.insupport(μ::$M) = insupport(proxy(μ)) + + @inline $MeasureBase.getdof(μ::$M) = getdof(proxy(μ)) + @inline $MeasureBase.fast_getdof(μ::$M) = fast_getdof(proxy(μ)) + @inline $MeasureBase.transport_origin(μ::$M) = transport_origin(proxy(μ)) @inline $MeasureBase.to_origin(μ::$M, y) = to_origin(proxy(μ), y) @inline $MeasureBase.from_origin(μ::$M, x) = from_origin(proxy(μ), x) + @inline $MeasureBase.localmeasure(μ::$M, x) = from_origin(localmeasure(μ), x) + @inline $MeasureBase.transportmeasure(μ::$M, x) = from_origin(transportmeasure(μ), x) + @inline $MeasureBase.massof(μ::$M) = massof(proxy(μ)) @inline $MeasureBase.massof(μ::$M, s) = massof(proxy(μ), s) From 8af47630cb73a31db5ab27e3894eab85c7899eaa Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sat, 1 Jul 2023 21:27:30 +0200 Subject: [PATCH 053/133] STASH --- src/collection_utils.jl | 18 ++++++++++++------ src/combinators/product_transport.jl | 10 ++++++---- 2 files changed, 18 insertions(+), 10 deletions(-) diff --git a/src/collection_utils.jl b/src/collection_utils.jl index 02d836b5..91015624 100644 --- a/src/collection_utils.jl +++ b/src/collection_utils.jl @@ -1,14 +1,19 @@ -# ToDo: Specialize for StaticArray: -@inline _getindex_or_view(A, idxs...) = view(A, idxs...) +# _get_n_at_offs counts it's offset from 0, not from 1! + +@inline function _get_n_at_offs(A, n::IntegerLike, offset::IntegerLike) + from = firstindex(A) + dynamic(offs) + view(A, from:(from+dynamic(n)-1)) +end +# ToDo: Specialize _get_n_at_offs for StaticArray. # ToDo: Add custom rrules for _split_after? -# ToDo: Use getindex instead of view for certain cases (array types)? -@inline function _split_after(x::AbstractVector, n) +# ToDo: Specialize for StaticVector: +@inline function _split_after(x::AbstractVector, n::IntegerLike) i_first = firstindex(x) i_last = lastindex(x) - _getindex_or_view(x, i_first, i_first+n-1), _getindex_or_view(x, n, i_last) + _get_n_at_offs(x, n, zero(n)), _getindex_or_view(x, n, i_last) end @inline _split_after(x::Tuple, n) = _split_after(x::Tuple, Val{n}()) @@ -32,7 +37,8 @@ end Base.@propagate_inbounds function _as_tuple(v::AbstractVector, ::Val{N}) where {N} @boundcheck @assert length(v) == N # ToDo: Throw proper exception - ntuple(i -> v[i], Val(N)) + i_offs = firstindex(v) - 1 + ntuple(i -> v[i_offs + i], Val(N)) end diff --git a/src/combinators/product_transport.jl b/src/combinators/product_transport.jl index 2e51aceb..de82df8f 100644 --- a/src/combinators/product_transport.jl +++ b/src/combinators/product_transport.jl @@ -238,20 +238,22 @@ end function transport_from_mvstd_with_rest(ν::ProductMeasure, μ_inner::StdMeasure, x) marginals_μ = marginals(μ) marg_dof = _marginals_dof(marginals_μ) - marg_offs = _marginal_offsets(marg_dof) _marginals_from_mvstd_with_rest(marginals_ν, marg_dof, μ_inner, x) end +@generated function _split_x_by_marginals_with_rest(marg_dof::Tuple{Vararg{IntegerLike,N}}, x::AbstractVector{<:Real}) where N + expr = () +end -function _marginals_dof(marginals_μ::Tuple{Vararg{AbstractMeasure,N}}) where N +function _marginals_dof(marginals_ν::Tuple{Vararg{AbstractMeasure,N}}) where N map(fast_getdof, marginals_μ) end # ToDo: Use static array for result: -_marginals_to_mvstd(ν_inner::StdMeasure, marginals_μ::Tuple, x) +_marginals_from_mvstd_with_rest(ν_inner::StdMeasure, marginals_μ::Tuple, x) -function _marginals_to_mvstd_split_x(marg_dof::Tuple{Vararg{StaticInteger,N}}, x::Tuple{Vararg{Any,N}}) where N +function _marginals_from_mvstd_with_rest_split_x(marg_dof::Tuple{Vararg{StaticInteger,N}}, x::Tuple{Vararg{Any,N}}) where N end function _marginal_offsets(marg_dof::Tuple{Vararg{StaticInteger,N}}) where N From c28e8542211b44d8dd106b2b6c82d7b69068e1af Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sat, 1 Jul 2023 21:32:00 +0200 Subject: [PATCH 054/133] Add StaticArrays to deps The load time of StaticArrays is rather high, be we can't take advantage of statically known measure sizes without a direct dependency on it. With StaticArrays we'll be able to handle many low-dimension problems without any heap allocations. --- Project.toml | 1 + src/MeasureBase.jl | 2 ++ test/Project.toml | 1 + 3 files changed, 4 insertions(+) diff --git a/Project.toml b/Project.toml index c5a75472..e6a9e002 100644 --- a/Project.toml +++ b/Project.toml @@ -25,6 +25,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Tricks = "410a4b4d-49e4-4fbc-ab6d-cb71b17b3775" diff --git a/src/MeasureBase.jl b/src/MeasureBase.jl index 216d1f63..c21b5e35 100644 --- a/src/MeasureBase.jl +++ b/src/MeasureBase.jl @@ -37,6 +37,8 @@ using Static using Static: StaticInteger using FunctionChains +using StaticArrays: StaticArray, StaticVector, StaticMatrix, SArray, SVector, SMatrix + export gentype export AbstractMeasure diff --git a/test/Project.toml b/test/Project.toml index f80fdd98..e30fa7a2 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -13,5 +13,6 @@ LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" From 5917a5c35edad4fc12be17cfa786cbac25a8d8d3 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sat, 1 Jul 2023 22:33:15 +0200 Subject: [PATCH 055/133] STASH --- src/collection_utils.jl | 41 ++++++++++++++++++++++++----------------- src/static.jl | 20 ++++++++++++++++++++ 2 files changed, 44 insertions(+), 17 deletions(-) diff --git a/src/collection_utils.jl b/src/collection_utils.jl index 91015624..83ec3b36 100644 --- a/src/collection_utils.jl +++ b/src/collection_utils.jl @@ -1,19 +1,33 @@ -# _get_n_at_offs counts it's offset from 0, not from 1! +# ToDo: Add custom rrules for the get/view/split/etc. functions defined here. -@inline function _get_n_at_offs(A, n::IntegerLike, offset::IntegerLike) - from = firstindex(A) + dynamic(offs) - view(A, from:(from+dynamic(n)-1)) +Base.@propagate_inbounds _as_tuple(v::AbstractVector, ::Val{N}) where {N} = Tuple(SVector{N}(v)) + +Base.Base.@propagate_inbounds function _get_or_view(A::AbstractVector, from::IntegerLike, until::IntegerLike) + (view(A, from:until)) +end +Base.Base.@propagate_inbounds function _get_or_view(A::AbstractVector, ::StaticInteger{from}, ::StaticInteger{until}) where {from,until} + SVector{until-from+1}(view(A, from:until)) end -# ToDo: Specialize _get_n_at_offs for StaticArray. +# ToDo: Specialize for StaticVector instead of SVector? +Base.Base.@propagate_inbounds function _get_or_view(A::SVector, ::StaticInteger{from}, ::StaticInteger{until}) where {from,until} + # ToDo: Improve implementation: + SVector(_get_or_view(Tuple(A), from, until)) +end + +Base.Base.@propagate_inbounds function _get_or_view(tpl::Tuple, from::IntegerLike, until::IntegerLike) + ntuple(i -> tpl[from + i - 1], Val(until - from + 1)) +end +# ToDo: Is this specialization necessary? +Base.Base.@propagate_inbounds function _get_or_view(tpl::Tuple, ::StaticInteger{from}, ::StaticInteger{until}) where {from,until} + ntuple(i -> tpl[from + i - 1], Val(until - from + 1)) +end -# ToDo: Add custom rrules for _split_after? -# ToDo: Specialize for StaticVector: @inline function _split_after(x::AbstractVector, n::IntegerLike) - i_first = firstindex(x) - i_last = lastindex(x) - _get_n_at_offs(x, n, zero(n)), _getindex_or_view(x, n, i_last) + i_first = _maybestatic_firstindex(x) + i_last = _maybestatic_lastindex(x) + _get_or_view(x, i_first, i_first + n - static(1)), _get_or_view(x, i_first + n, i_last) end @inline _split_after(x::Tuple, n) = _split_after(x::Tuple, Val{n}()) @@ -35,13 +49,6 @@ end end -Base.@propagate_inbounds function _as_tuple(v::AbstractVector, ::Val{N}) where {N} - @boundcheck @assert length(v) == N # ToDo: Throw proper exception - i_offs = firstindex(v) - 1 - ntuple(i -> v[i_offs + i], Val(N)) -end - - _empty_zero(::AbstractVector{T}) where {T<:Real} = Fill(zero(T), 0) diff --git a/src/static.jl b/src/static.jl index da471b62..05d211d2 100644 --- a/src/static.jl +++ b/src/static.jl @@ -61,3 +61,23 @@ end Returns the size of `x` as a tuple of dynamic or static integers. """ maybestatic_size(x) = size(x) + + +""" + MeasureBase.maybestatic_firstindex(x)::Integer + +Returns the first index `x` as a dynamic or static integer. +""" +maybestatic_firstindex(x::AbstractVector) = firstindex(x) +maybestatic_firstindex(::Tuple{Vararg{Any,N}}) where N = static(1) +maybestatic_firstindex(nt::NamedTuple) = maybestatic_firstindex(values(nt)) + + +""" + MeasureBase.maybestatic_lastindex(x)::Integer + +Returns the first index `x` as a dynamic or static integer. +""" +maybestatic_lastindex(x::AbstractVector) = lastindex(x) +maybestatic_lastindex(::Tuple{Vararg{Any,N}}) where N = static(N) +maybestatic_lastindex(nt::NamedTuple) = maybestatic_firstindex(values(nt)) From 94cfd3bc9b644393da83a10502220ce684297a35 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sat, 1 Jul 2023 22:54:04 +0200 Subject: [PATCH 056/133] STASH --- src/collection_utils.jl | 21 ++++++++++----------- src/combinators/product_transport.jl | 3 ++- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/src/collection_utils.jl b/src/collection_utils.jl index 83ec3b36..1b8cfd35 100644 --- a/src/collection_utils.jl +++ b/src/collection_utils.jl @@ -49,6 +49,9 @@ end end +# ToDo: Add static reshape for static arrays! + + _empty_zero(::AbstractVector{T}) where {T<:Real} = Fill(zero(T), 0) @@ -79,20 +82,16 @@ end # ToDo: Add custom rrule for _reorder_nt? - +# Field access functions for Fill: _fill_value(x::Fill) = x.value _fill_axes(x::Fill) = x.axes -# ToDo: Flatten to SVector: -_flatten_to_rv(tpl::Tuple{Vararg{Number}}) = vcat(x...) - -# ToDo: -#_flatten_to_vector(tpl::Tuple{Vararg{StaticVector}}) = ... - -_flatten_to_flatten_to_numvectorvector(tpl::Tuple{AbstractVector{<:Number}}) = vcat(tpl...) +_flatten_to_rv(VV::AbstractVector{<:AbstractVector{<:Real}}) = flatview(VectorOfArrays(VV)) +_flatten_to_rv(VV::AbstractVector{<:StaticVector{N,<:Real}}) where N = flatview(VectorOfSimilarArrays(VV)) -# ToDo: Use faster implementation that handles large numbers of vectors efficiently: -_flatten_to_rv(V::AbstractVector{<:AbstractVector{<:Number}}) = vcat(V...) +_flatten_to_rv(VV::VectorOfSimilarVectors{<:Real}) = flatview(VV) +_flatten_to_rv(VV::VectorOfVectors{<:Real}) = flatview(VV) -_flatten_to_rv(V::ArrayOfSimilarArray{<:Number}) = flatview(A) +_flatten_to_rv(tpl::Tuple{<:AbstractVector{<:Real}}) = vcat(tpl...) +_flatten_to_rv(tpl::Tuple{<:StaticVector{N,<:Real}}) where N = vcat(tpl...) diff --git a/src/combinators/product_transport.jl b/src/combinators/product_transport.jl index de82df8f..7fafbf73 100644 --- a/src/combinators/product_transport.jl +++ b/src/combinators/product_transport.jl @@ -250,7 +250,8 @@ function _marginals_dof(marginals_ν::Tuple{Vararg{AbstractMeasure,N}}) where N end -# ToDo: Use static array for result: +### !!!!!!!!!!!!!!!!!!!!!!!!!!!!! TODO ########################### + _marginals_from_mvstd_with_rest(ν_inner::StdMeasure, marginals_μ::Tuple, x) function _marginals_from_mvstd_with_rest_split_x(marg_dof::Tuple{Vararg{StaticInteger,N}}, x::Tuple{Vararg{Any,N}}) where N From d958c5185616cf6a941e722b150686e378c6ca98 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sat, 1 Jul 2023 23:05:54 +0200 Subject: [PATCH 057/133] STASH FIXES --- src/combinators/bind.jl | 2 +- src/combinators/combined.jl | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/combinators/bind.jl b/src/combinators/bind.jl index bb62f801..972da552 100644 --- a/src/combinators/bind.jl +++ b/src/combinators/bind.jl @@ -160,7 +160,7 @@ function _bind_tsc_cat(f_c::typeof(merge), μ::_CatBind{typeof{merge}}, xy::Name end -@inline insupport(::Bind, x) = NoFastInsupport() +@inline insupport(μ::Bind, ::Any) = NoFastInsupport{typeof(μ)}() @inline getdof(μ::Bind) = NoDOF{typeof(μ)}() diff --git a/src/combinators/combined.jl b/src/combinators/combined.jl index e17252e6..39b41c20 100644 --- a/src/combinators/combined.jl +++ b/src/combinators/combined.jl @@ -90,8 +90,7 @@ struct CombinedMeasure{FC,MA<:AbstractMeasure,MB<:AbstractMeasure} <: AbstractMe end -# TODO: Could split `ab`` here, but would be wasteful. -@inline insupport(::CombinedMeasure, ab) = NoFastInsupport() +@inline insupport(μ::CombinedMeasure, ab) = NoFastInsupport(typeof(μ)) @inline getdof(μ::CombinedMeasure) = _add_dof(getdof(μ.α), getdof(μ.β)) @inline fast_dof(μ::CombinedMeasure) =_add_dof(fast_dof(μ.α), fast_dof(μ.β)) From 4d8b1750d9e327e0dbbe2325f70caaa2bfa46296 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sun, 2 Jul 2023 19:54:48 +0200 Subject: [PATCH 058/133] STASH --- src/collection_utils.jl | 4 ++-- src/static.jl | 41 ++++++++++++++++++++++++++++++----------- 2 files changed, 32 insertions(+), 13 deletions(-) diff --git a/src/collection_utils.jl b/src/collection_utils.jl index 1b8cfd35..cbb7251e 100644 --- a/src/collection_utils.jl +++ b/src/collection_utils.jl @@ -25,8 +25,8 @@ end @inline function _split_after(x::AbstractVector, n::IntegerLike) - i_first = _maybestatic_firstindex(x) - i_last = _maybestatic_lastindex(x) + i_first = maybestatic_firstindex(x) + i_last = maybestatic_lastindex(x) _get_or_view(x, i_first, i_first + n - static(1)), _get_or_view(x, i_first + n, i_last) end diff --git a/src/static.jl b/src/static.jl index 05d211d2..9c170929 100644 --- a/src/static.jl +++ b/src/static.jl @@ -42,6 +42,7 @@ end FillArrays.Fill(x, dyn_axs) end + """ MeasureBase.maybestatic_length(x)::IntegerLike @@ -49,35 +50,53 @@ Returns the length of `x` as a dynamic or static integer. """ maybestatic_length(x) = length(x) maybestatic_length(x::AbstractUnitRange) = length(x) +maybestatic_length(::Tuple{Vararg{Any,N}}) where N = static(N) +maybestatic_length(nt::NamedTuple) = maybestatic_length(values(nt)) +maybestatic_length(x::StaticArray) = maybestatic_length(maybestatic_eachindex(x)) +maybestatic_length(::StaticArrays.SOneTo{N}) where {N} = static(N) function maybestatic_length( ::Static.OptionallyStaticUnitRange{<:StaticInteger{A},<:StaticInteger{B}}, ) where {A,B} StaticInt{B - A + 1}() end + """ MeasureBase.maybestatic_size(x)::Tuple{Vararg{IntegerLike}} Returns the size of `x` as a tuple of dynamic or static integers. """ -maybestatic_size(x) = size(x) +maybestatic_size(x) = map(maybestatic_length, axes(x)) + + +""" + MeasureBase.maybestatic_eachindex(x) + +Returns the the index range of `x` as a dynamic or static integer range +""" +maybestatic_eachindex(x::AbstractArray) = _conv_static_eachindex(eachindex(x)) +maybestatic_eachindex(::Tuple{Vararg{Any,N}}) where N = static(1):static(N) +maybestatic_eachindex(nt::NamedTuple) = maybestatic_eachindex(values(nt)) + +_conv_static_eachindex(idxs) = idxs +_conv_static_eachindex(::Static.SOneTo{N}) where {N} = static(1):static(N) """ - MeasureBase.maybestatic_firstindex(x)::Integer + MeasureBase.maybestatic_first(A) -Returns the first index `x` as a dynamic or static integer. +Returns the first element of `A` as a dynamic or static value. """ -maybestatic_firstindex(x::AbstractVector) = firstindex(x) -maybestatic_firstindex(::Tuple{Vararg{Any,N}}) where N = static(1) -maybestatic_firstindex(nt::NamedTuple) = maybestatic_firstindex(values(nt)) +maybestatic_first(A::AbstractArray) = first(A) +maybestatic_first(::StaticArrays.SOneTo{N}) where N = static(1) +maybestatic_first(::Static.OptionallyStaticUnitRange{<:Static.StaticInteger{from},<:Static.StaticInteger}) where from = static(from) """ - MeasureBase.maybestatic_lastindex(x)::Integer + MeasureBase.maybestatic_last(A) -Returns the first index `x` as a dynamic or static integer. +Returns the last element of `A` as a dynamic or static value. """ -maybestatic_lastindex(x::AbstractVector) = lastindex(x) -maybestatic_lastindex(::Tuple{Vararg{Any,N}}) where N = static(N) -maybestatic_lastindex(nt::NamedTuple) = maybestatic_firstindex(values(nt)) +maybestatic_last(A::AbstractArray) = last(A) +maybestatic_last(::StaticArrays.SOneTo{N}) where N = static(N) +maybestatic_last(::Static.OptionallyStaticUnitRange{<:Static.StaticInteger,<:Static.StaticInteger{until}}) where until = static(until) From ba6f70019d4918d342eef872730be38cfe3edc50 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sun, 2 Jul 2023 21:39:08 +0200 Subject: [PATCH 059/133] STASH --- src/collection_utils.jl | 2 +- src/combinators/product_transport.jl | 56 ++++++++++++++++++++++------ 2 files changed, 45 insertions(+), 13 deletions(-) diff --git a/src/collection_utils.jl b/src/collection_utils.jl index cbb7251e..6563477a 100644 --- a/src/collection_utils.jl +++ b/src/collection_utils.jl @@ -27,7 +27,7 @@ end @inline function _split_after(x::AbstractVector, n::IntegerLike) i_first = maybestatic_firstindex(x) i_last = maybestatic_lastindex(x) - _get_or_view(x, i_first, i_first + n - static(1)), _get_or_view(x, i_first + n, i_last) + _get_or_view(x, i_first, i_first + n - one(n)), _get_or_view(x, i_first + n, i_last) end @inline _split_after(x::Tuple, n) = _split_after(x::Tuple, Val{n}()) diff --git a/src/combinators/product_transport.jl b/src/combinators/product_transport.jl index 7fafbf73..83a5ae0d 100644 --- a/src/combinators/product_transport.jl +++ b/src/combinators/product_transport.jl @@ -235,28 +235,60 @@ end # Transport StdMeasure type to ProductMeasure, with rest: +const _MaybeUnkownDOF = Union{IntegerLike,AbstractNoDOF} + +const _KnownDOFs = Union{Tuple{Vararg{IntegerLike,N}} where N, StaticVector{<:IntegerLike}} +const _MaybeUnkownKnownDOFs = Union{Tuple{Vararg{_MaybeUnkownDOF,N}} where N, StaticVector{<:_MaybeUnkownDOF}} + function transport_from_mvstd_with_rest(ν::ProductMeasure, μ_inner::StdMeasure, x) - marginals_μ = marginals(μ) - marg_dof = _marginals_dof(marginals_μ) - _marginals_from_mvstd_with_rest(marginals_ν, marg_dof, μ_inner, x) + νs = marginals(ν) + dofs = map(fast_getdof, marginals_μ) + return _marginals_from_mvstd_with_rest(νs, dofs, μ_inner, x) end -@generated function _split_x_by_marginals_with_rest(marg_dof::Tuple{Vararg{IntegerLike,N}}, x::AbstractVector{<:Real}) where N - expr = () +function _dof_access_firstidxs(dofs::Tuple{Vararg{IntegerLike,N}}, first_idx) where N + cumsum((first_idx, dofs[begin:end-1]...)) end -function _marginals_dof(marginals_ν::Tuple{Vararg{AbstractMeasure,N}}) where N - map(fast_getdof, marginals_μ) +function _dof_access_firstidxs(dofs::AbstractVector{<:IntegerLike}, first_idx) where N + # ToDo: Improve imlementation (reduce memory allocations) + cumsum(vcat([eltype(dofs)(first_idx)], dofs[begin:end-1])) end +function _split_x_by_marginals_with_rest(dofs::Union{Tuple,AbstractVector}, x::AbstractVector{<:Real}) + x_idxs = maybestatic_eachindex(x) + first_idxs = _dof_access_firstidxs(dofs, maybestatic_first(x_idxs)) + xs = map((from, n) -> _get_or_view(x, from, from + n - one(n)), first_idxs, dofs) + x_rest = _get_or_view(x, first_idxs[end] + dofs[end], maybestatic_last(x_idxs)) + return xs, r_rest +end -### !!!!!!!!!!!!!!!!!!!!!!!!!!!!! TODO ########################### +function _marginals_from_mvstd_with_rest(νs, dofs::_KnownDOFs, μ_inner::StdMeasure, x::AbstractVector{<:Real}) + xs, x_rest = _split_x_by_marginals_with_rest + # ToDo: Is this ideal? + μs = map(n -> μ_inner^n, dofs) + ys = map(transport_def, νs, μs, xs) + return ys, x_rest +end -_marginals_from_mvstd_with_rest(ν_inner::StdMeasure, marginals_μ::Tuple, x) +function _marginals_from_mvstd_with_rest(νs::Tuple, dofs::_MaybeUnkownKnownDOFs, μ_inner::StdMeasure, x::AbstractVector{<:Real}) + _marginals_from_mvstd_with_rest_nodof(νs, μ_inner, x) +end -function _marginals_from_mvstd_with_rest_split_x(marg_dof::Tuple{Vararg{StaticInteger,N}}, x::Tuple{Vararg{Any,N}}) where N +function _marginals_from_mvstd_with_rest_nodof(νs::Tuple{Vararg{AbstractMeasure,N}}, μ_inner::StdMeasure, x::AbstractVector{<:Real}) where N + # ToDo: Check for type stability, may need generated function + y1, x_rest = transport_from_mvstd_with_rest(νs[1], μ_inner, x) + y2_end, x_final_rest = _marginals_from_mvstd_with_rest(νs[2:end], μ_inner, x_rest) + return (y1, y2_end...), x_final_rest end -function _marginal_offsets(marg_dof::Tuple{Vararg{StaticInteger,N}}) where N - _offset_cumsum(0, marg_dof...) +function _marginals_from_mvstd_with_rest_nodof(νs::AbstractVector{<:AbstractMeasure}, μ_inner::StdMeasure, x::AbstractVector{<:Real}) + # ToDo: Check for type stability, may need generated function + y1, x_rest = transport_from_mvstd_with_rest(νs[1], μ_inner, x) + ys = [y1] + for ν in νs[begin+1:end] + y_i, x_rest = transport_from_mvstd_with_rest(ν, μ_inner, x_rest) + ys = vcat(ys, y_i) + end + return ys, x_rest end From 79c72cb897cea281095f1a18867feda8a220e726 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sun, 2 Jul 2023 21:58:25 +0200 Subject: [PATCH 060/133] STASH FIXES --- src/MeasureBase.jl | 9 +++++---- src/combinators/bind.jl | 10 +++++----- src/combinators/combined.jl | 4 ++-- src/combinators/power.jl | 2 +- src/combinators/product.jl | 4 ++-- src/combinators/product_transport.jl | 18 +++++++++--------- src/getdof.jl | 6 ------ src/proxies.jl | 2 +- 8 files changed, 25 insertions(+), 30 deletions(-) diff --git a/src/MeasureBase.jl b/src/MeasureBase.jl index c21b5e35..07597215 100644 --- a/src/MeasureBase.jl +++ b/src/MeasureBase.jl @@ -37,6 +37,7 @@ using Static using Static: StaticInteger using FunctionChains +import StaticArrays using StaticArrays: StaticArray, StaticVector, StaticMatrix, SArray, SVector, SMatrix export gentype @@ -126,13 +127,16 @@ include("getdof.jl") include("transport.jl") include("schema.jl") include("splat.jl") -include("proxies.jl") include("kernel.jl") include("parameterized.jl") include("domains.jl") include("primitive.jl") include("utils.jl") include("mass-interface.jl") +include("density.jl") +include("density-core.jl") + +include("proxies.jl") # include("absolutecontinuity.jl") include("primitives/counting.jl") @@ -165,9 +169,6 @@ include("combinators/half.jl") include("rand.jl") include("fixedrng.jl") -include("density.jl") -include("density-core.jl") - include("interface.jl") include("measure_operators.jl") diff --git a/src/combinators/bind.jl b/src/combinators/bind.jl index 972da552..ba391bfc 100644 --- a/src/combinators/bind.jl +++ b/src/combinators/bind.jl @@ -134,10 +134,10 @@ function _bind_tsc(f_c::Type{Pair}, μ::Bind, xy::Pair) return tpm_μ, x, y end -const _CatBind{FC} = _BindBy{<:Any,<:Any,FC} +const _BindBy{FC} = Bind{<:Any,<:Any,FC} -_bind_tsc(f_c::typeof(vcat), μ::_CatBind{typeof{vcat}}, xy::AbstractVector) = _bind_tsc_cat(f_c, μ, xy) -_bind_tsc(f_c::typeof(merge), μ::_CatBind{typeof{merge}}, xy::NamedTuple) = _bind_tsc_cat(f_c, μ, xy) +_bind_tsc(f_c::typeof(vcat), μ::_BindBy{typeof{vcat}}, xy::AbstractVector) = _bind_tsc_cat(f_c, μ, xy) +_bind_tsc(f_c::typeof(merge), μ::_BindBy{typeof{merge}}, xy::NamedTuple) = _bind_tsc_cat(f_c, μ, xy) function _bind_tsc_cat_lμabyxy(f_c, μ, xy) tpm_α, a, by = tpmeasure_split_combined(μ.f_c, μ.α, xy) @@ -147,14 +147,14 @@ function _bind_tsc_cat_lμabyxy(f_c, μ, xy) return tpm_μ, a, b, y, xy end -function _bind_tsc_cat(f_c::typeof(vcat), μ::_CatBind{typeof{vcat}}, xy::AbstractVector) +function _bind_tsc_cat(f_c::typeof(vcat), μ::_BindBy{typeof{vcat}}, xy::AbstractVector) tpm_μ, a, b, y, xy = _bind_tsc_cat_lμabyxy(f_c, μ, xy) # Don't use `x = f_c(a, b)` here, would allocate, splitting xy can use views: x, y = _split_after(xy, length(a) + length(b)) return tpm_μ, x, y end -function _bind_tsc_cat(f_c::typeof(merge), μ::_CatBind{typeof{merge}}, xy::NamedTuple) +function _bind_tsc_cat(f_c::typeof(merge), μ::_BindBy{typeof{merge}}, xy::NamedTuple) tpm_μ, a, b, y, xy = _bind_tsc_cat_lμabyxy(f_c, μ, xy) return tpm_μ, f_c(a, b), y end diff --git a/src/combinators/combined.jl b/src/combinators/combined.jl index 39b41c20..a737b7ff 100644 --- a/src/combinators/combined.jl +++ b/src/combinators/combined.jl @@ -23,7 +23,7 @@ function tpmeasure_split_combined(f_c, α::AbstractMeasure, ab) return transportmeasure(α, a), a, b end -@inline _generic_split_combined(::typeof(tuple), ::AbstractMeasure, x::Tuple{Vararg{Any,2}}) +@inline _generic_split_combined(::typeof(tuple), ::AbstractMeasure, x::Tuple{Vararg{Any,2}}) = x @inline _generic_split_combined(::Type{Pair}, ::AbstractMeasure, ab::Pair) = (ab...,) function _generic_split_combined(f_c::FC, α::AbstractMeasure, ab) where FC @@ -39,7 +39,7 @@ function _split_variate_byvalue(::typeof(merge), ::NamedTuple{names_a}, ab::Name end -""" +@doc raw""" mcombine(f_c, α::AbstractMeasure, β::AbstractMeasure) Combines two measures `α` and `β` to a combined measure via a point combination diff --git a/src/combinators/power.jl b/src/combinators/power.jl index bbbdc198..603232a3 100644 --- a/src/combinators/power.jl +++ b/src/combinators/power.jl @@ -183,4 +183,4 @@ end Represents and N-dimensional power of the standard measure `MU()`. """ -const StdPowerMeasure{N,MU<:StdMeasure} = PowerMeasure{MU,<:NTuple{N,Base.OneTo}} +const StdPowerMeasure{MU<:StdMeasure,N} = PowerMeasure{MU,<:NTuple{N,Base.OneTo}} diff --git a/src/combinators/product.jl b/src/combinators/product.jl index 7d036d75..98e9681e 100644 --- a/src/combinators/product.jl +++ b/src/combinators/product.jl @@ -101,7 +101,7 @@ end # For tuples, `mapreduce` can have trouble with type inference sum(map(density_op, marginals_μ, x)) end -@inline function _marginals_density_op(density_op::F, marginals_μ::NameTuple{names}, x::NameTuple) where {F,names} +@inline function _marginals_density_op(density_op::F, marginals_μ::NamedTuple{names}, x::NamedTuple) where {F,names} nms = Val{names}() _marginals_density_op(density_op, marginals_μ, _reorder_nt(values(x), nms)) end @@ -113,7 +113,7 @@ end # For tuples, `mapreduce` can have trouble with type inference sum(map(density_op, marginals_μ, marginals_ν, x)) end -@inline function _marginals_density_op(density_op::F, marginals_μ::NameTuple{names}, marginals_ν::NameTuple, x::NameTuple) where {F,names} +@inline function _marginals_density_op(density_op::F, marginals_μ::NamedTuple{names}, marginals_ν::NamedTuple, x::NamedTuple) where {F,names} nms = Val{names}() _marginals_density_op(density_op, marginals_μ, _reorder_nt(values(marginals_ν), nms), _reorder_nt(values(x), nms)) end diff --git a/src/combinators/product_transport.jl b/src/combinators/product_transport.jl index 83a5ae0d..a80cb972 100644 --- a/src/combinators/product_transport.jl +++ b/src/combinators/product_transport.jl @@ -50,9 +50,9 @@ end # A one-dimensional PowerMeasure has an origin if it's parent has an origin: -transport_origin(μ::PowerMeasure{<:AbstractMeasure,1}) = _origin_pwr(::typeof(μ), transport_origin(μ.parent), μ.axes) -_pwr_origin(::Type{MU}, parent_origin, axes) = parent_origin^axes -_pwr_origin(::Type{MU}, ::NoTransportOrigin, axes) = NoTransportOrigin{MU} +transport_origin(μ::PowerMeasure{<:AbstractMeasure,1}) = _origin_pwr(typeof(μ), transport_origin(μ.parent), μ.axes) +_pwr_origin(::Type{MU}, parent_origin, axes) where MU = parent_origin^axes +_pwr_origin(::Type{MU}, ::NoTransportOrigin, axes) where MU = NoTransportOrigin{MU} function from_origin(μ::PowerMeasure{<:AbstractMeasure,1}, x_origin) # Sanity check, should never fail: @@ -101,7 +101,7 @@ function _to_mvstd_withorigin(ν_inner::StdMeasure, ::AbstractMeasure, μ_origin from_origin(x_origin) end -function _to_mvstd_withorigin(ν_inner::StdMeasure, μ::AbstractMeasure, NoTransportOrigin, x) +function _to_mvstd_withorigin(ν_inner::StdMeasure, μ::AbstractMeasure, ::NoTransportOrigin, x) throw(ArgumentError("Don't know how to transport values of type $(nameof(typeof(x))) from $(nameof(typeof(μ))) to a power of $(nameof(typeof(ν_inner)))")) end @@ -151,7 +151,7 @@ function _from_mvstd_with_rest_withorigin(::AbstractMeasure, ν_origin, μ_inner from_origin(x_origin), x_rest end -function _from_mvstd_with_rest_withorigin(ν::AbstractMeasure, NoTransportOrigin, μ_inner::StdMeasure, x) +function _from_mvstd_with_rest_withorigin(ν::AbstractMeasure, ::NoTransportOrigin, μ_inner::StdMeasure, x) throw(ArgumentError("Don't know how to transport value of type $(nameof(typeof(x))) from power of $(nameof(typeof(μ_inner))) to $(nameof(typeof(ν)))")) end @@ -223,11 +223,11 @@ end struct _TransportToMvStd{NU<:StdMeasure} <: Function end (::_TransportToMvStd{NU})(μ, x) where {NU} = transport_to_mvstd(NU(), μ, x) -function _marginals_to_mvstd(::StdMeasure{NU}, marginals_μ::Tuple, x::Tuple) where NU +function _marginals_to_mvstd(::NU, marginals_μ::Tuple, x::Tuple) where {NU<:StdMeasure} _flatten_to_rv(map(_TransportToMvStd{NU}(), marginals_μ, x)) end -function _marginals_to_mvstd(::StdMeasure{NU}, marginals_μ, x) where NU +function _marginals_to_mvstd(::NU, marginals_μ, x) where {NU<:StdMeasure} _flatten_to_rv(broadcast(_TransportToMvStd{NU}(), marginals_μ, x)) end @@ -242,7 +242,7 @@ const _MaybeUnkownKnownDOFs = Union{Tuple{Vararg{_MaybeUnkownDOF,N}} where N, St function transport_from_mvstd_with_rest(ν::ProductMeasure, μ_inner::StdMeasure, x) νs = marginals(ν) - dofs = map(fast_getdof, marginals_μ) + dofs = map(fast_dof, marginals_μ) return _marginals_from_mvstd_with_rest(νs, dofs, μ_inner, x) end @@ -250,7 +250,7 @@ function _dof_access_firstidxs(dofs::Tuple{Vararg{IntegerLike,N}}, first_idx) wh cumsum((first_idx, dofs[begin:end-1]...)) end -function _dof_access_firstidxs(dofs::AbstractVector{<:IntegerLike}, first_idx) where N +function _dof_access_firstidxs(dofs::AbstractVector{<:IntegerLike}, first_idx) # ToDo: Improve imlementation (reduce memory allocations) cumsum(vcat([eltype(dofs)(first_idx)], dofs[begin:end-1])) end diff --git a/src/getdof.jl b/src/getdof.jl index 2ee16ea8..9380fcb2 100644 --- a/src/getdof.jl +++ b/src/getdof.jl @@ -74,12 +74,6 @@ export fast_dof fast_dof(μ) = getdof(μ) -# Prevent infinite recursion: -@inline _default_fastdof(::Type{MU}, ::MU) where {MU} = NoFastDOF{MU}() -@inline _default_fastdof(::Type{MU}, mu_base) where {MU} = fast_dof(mu_base) - -@inline fast_dof(μ::MU) where {MU} = _default_fastdof(MU, basemeasure(μ)) - """ MeasureBase.some_dof(μ::AbstractMeasure) diff --git a/src/proxies.jl b/src/proxies.jl index aeb451ae..588b21a3 100644 --- a/src/proxies.jl +++ b/src/proxies.jl @@ -24,7 +24,7 @@ macro useproxy(M) @inline $MeasureBase.insupport(μ::$M) = insupport(proxy(μ)) @inline $MeasureBase.getdof(μ::$M) = getdof(proxy(μ)) - @inline $MeasureBase.fast_getdof(μ::$M) = fast_getdof(proxy(μ)) + @inline $MeasureBase.fast_dof(μ::$M) = fast_dof(proxy(μ)) @inline $MeasureBase.transport_origin(μ::$M) = transport_origin(proxy(μ)) @inline $MeasureBase.to_origin(μ::$M, y) = to_origin(proxy(μ), y) From 2da902971f4fa67879268e25dfd93d8c4ddf749c Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sun, 2 Jul 2023 22:09:53 +0200 Subject: [PATCH 061/133] STASH FIXES --- src/combinators/smart-constructors.jl | 4 ++-- src/getdof.jl | 12 ++++++------ 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/combinators/smart-constructors.jl b/src/combinators/smart-constructors.jl index 2ba208d3..19377cef 100644 --- a/src/combinators/smart-constructors.jl +++ b/src/combinators/smart-constructors.jl @@ -21,7 +21,7 @@ export powermeasure powermeasure(m::AbstractMeasure, ::Tuple{}) = asmeasure(m) -@inline function powermeasure(x::T, sz::Tuple{Vararg{<:Any,N}}) where {T,N} +@inline function powermeasure(x::T, sz::Tuple{Vararg{Any,N}}) where {T,N} PowerMeasure(asmeasure(x), _pm_axes(sz)) end @@ -60,7 +60,7 @@ export productmeasure productmeasure(mar::Fill) = powermeasure(_fill_value(mar), _fill_axes(mar)) -productmeasure(mar::Tuple{Vararg{<:AbstractMeasure}}) = ProductMeasure(mar) +productmeasure(mar::Tuple{Vararg{AbstractMeasure}}) = ProductMeasure(mar) productmeasure(mar::Tuple) = ProductMeasure(map(asmeasure, mar)) productmeasure(mar::NamedTuple{names,<:Tuple{Vararg{AbstractMeasure}}}) where names = ProductMeasure(mar) diff --git a/src/getdof.jl b/src/getdof.jl index 9380fcb2..996ba747 100644 --- a/src/getdof.jl +++ b/src/getdof.jl @@ -1,9 +1,9 @@ """ - abstract type MeasureBase.AbstractNoDOF + abstract type MeasureBase.AbstractNoDOF{MU} Abstract supertype for [`NoDOF`](@ref) and [`NoFastDOF`](@ref). """ -abstract type AbstractNoDOF end +abstract type AbstractNoDOF{MU} end _add_dof(dof_a::Real, dof_b::Real) = dof_a + dof_b _add_dof(dof_a::AbstractNoDOF, ::Real) = dof_a @@ -12,13 +12,13 @@ _add_dof(dof_a::AbstractNoDOF, ::AbstractNoDOF) = dof_a """ - MeasureBase.NoDOF{MU} <: AbstractNoDOF + MeasureBase.NoDOF{MU} <: AbstractNoDOF{MU} Indicates that there is no way to compute degrees of freedom of a measure of type `MU` with the given information, e.g. because the DOF are not a global property of the measure. """ -struct NoDOF{MU} <: AbstractNoDOF end +struct NoDOF{MU} <: AbstractNoDOF{MU} end """ @@ -44,13 +44,13 @@ export getdof """ - MeasureBase.NoFastDOF{MU} <: AbstractNoDOF + MeasureBase.NoFastDOF{MU} <: AbstractNoDOF{MU} Indicates that there is no way to compute degrees of freedom of a measure of type `MU` with the given information, e.g. because the DOF are not a global property of the measure. """ -struct NoFastDOF{MU} <: AbstractNoDOF end +struct NoFastDOF{MU} <: AbstractNoDOF{MU} end """ From 8c8d4e877c2bacc4a8234dd335e054e2c608efd1 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sun, 2 Jul 2023 22:19:00 +0200 Subject: [PATCH 062/133] FIXES --- src/combinators/bind.jl | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/src/combinators/bind.jl b/src/combinators/bind.jl index ba391bfc..5f16c473 100644 --- a/src/combinators/bind.jl +++ b/src/combinators/bind.jl @@ -120,26 +120,25 @@ localmeasure(μ::Bind, x) = transportmeasure(μ, x) _get_β_a(μ::Bind, a) = asmeasure(μ.f_β(a)) -tpmeasure_split_combined(f_c, μ::Bind, xy) = _bind_tsc(f_c, μ::Bind, xy) +tpmeasure_split_combined(f_c, μ::Bind, xy) = _bind_tpm_sc(f_c, μ::Bind, xy) -function _bind_tsc(f_c::typeof(tuple), μ::Bind, xy::Tuple{Vararg{Any,2}}) +function _bind_tpm_sc(::typeof(tuple), μ::Bind, xy::Tuple{Vararg{Any,2}}) x, y = x[1], y[1] tpm_μ = transportmeasure(μ, x) return tpm_μ, x, y end -function _bind_tsc(f_c::Type{Pair}, μ::Bind, xy::Pair) +function _bind_tpm_sc(::Type{Pair}, μ::Bind, xy::Pair) x, y = x.first, y.second tpm_μ = transportmeasure(μ, x) return tpm_μ, x, y end -const _BindBy{FC} = Bind{<:Any,<:Any,FC} +const _BindBy{FC} = Bind{<:Any,<:AbstractMeasure,FC} +_bind_tpm_sc(f_c::typeof(vcat), μ::_BindBy{typeof(vcat)}, xy::AbstractVector) = _bind_tpm_sc_cat(f_c, μ, xy) +_bind_tpm_sc(f_c::typeof(merge), μ::_BindBy{typeof(merge)}, xy::NamedTuple) = _bind_tpm_sc_cat(f_c, μ, xy) -_bind_tsc(f_c::typeof(vcat), μ::_BindBy{typeof{vcat}}, xy::AbstractVector) = _bind_tsc_cat(f_c, μ, xy) -_bind_tsc(f_c::typeof(merge), μ::_BindBy{typeof{merge}}, xy::NamedTuple) = _bind_tsc_cat(f_c, μ, xy) - -function _bind_tsc_cat_lμabyxy(f_c, μ, xy) +function _bind_tpm_sc_cat_lμabyxy(f_c, μ, xy) tpm_α, a, by = tpmeasure_split_combined(μ.f_c, μ.α, xy) β_a = _get_β_a(μ, a) tpm_β_a, b, y = tpmeasure_split_combined(f_c, β_a, by) @@ -147,15 +146,15 @@ function _bind_tsc_cat_lμabyxy(f_c, μ, xy) return tpm_μ, a, b, y, xy end -function _bind_tsc_cat(f_c::typeof(vcat), μ::_BindBy{typeof{vcat}}, xy::AbstractVector) - tpm_μ, a, b, y, xy = _bind_tsc_cat_lμabyxy(f_c, μ, xy) +function _bind_tpm_sc_cat(f_c::typeof(vcat), μ::_BindBy{typeof(vcat)}, xy::AbstractVector) + tpm_μ, a, b, y, xy = _bind_tpm_sc_cat_lμabyxy(f_c, μ, xy) # Don't use `x = f_c(a, b)` here, would allocate, splitting xy can use views: x, y = _split_after(xy, length(a) + length(b)) return tpm_μ, x, y end -function _bind_tsc_cat(f_c::typeof(merge), μ::_BindBy{typeof{merge}}, xy::NamedTuple) - tpm_μ, a, b, y, xy = _bind_tsc_cat_lμabyxy(f_c, μ, xy) +function _bind_tpm_sc_cat(f_c::typeof(merge), μ::_BindBy{typeof(merge)}, xy::NamedTuple) + tpm_μ, a, b, y, xy = _bind_tpm_sc_cat_lμabyxy(f_c, μ, xy) return tpm_μ, f_c(a, b), y end From 71674c3e08f7e4a1231eee8d4a2b725b97285549 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sun, 2 Jul 2023 22:20:54 +0200 Subject: [PATCH 063/133] FIXUP StaticArrays dep --- Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/Project.toml b/Project.toml index e6a9e002..98cca9d7 100644 --- a/Project.toml +++ b/Project.toml @@ -54,5 +54,6 @@ SpecialFunctions = "2" Static = "0.8, 1" Statistics = "1" Test = "1" +StaticArrays = "1.5" Tricks = "0.1" julia = "1.6" From 9a602ad615c4695798389b262fbb1cee2138502b Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sun, 2 Jul 2023 22:54:01 +0200 Subject: [PATCH 064/133] FIXES --- src/combinators/power.jl | 27 ++++++++++++++------------- src/getdof.jl | 10 +++++++--- 2 files changed, 21 insertions(+), 16 deletions(-) diff --git a/src/combinators/power.jl b/src/combinators/power.jl index 603232a3..41ebd75e 100644 --- a/src/combinators/power.jl +++ b/src/combinators/power.jl @@ -55,28 +55,29 @@ function Pretty.tile(μ::PowerMeasure) return Pretty.pair_layout(arg1, arg2; sep = " ^ ") end -# ToDo: Make rand return static arrays for statically-sized power measures. +# ToDo: Make rand and testvalue return static arrays for statically-sized power measures. function _cartidxs(axs::Tuple{Vararg{AbstractUnitRange,N}}) where {N} CartesianIndices(map(_dynamic, axs)) end -function Base.rand( - rng::AbstractRNG, - ::Type{T}, - d::PowerMeasure{M}, -) where {T,M<:AbstractMeasure} - map(_cartidxs(d.axes)) do _ - rand(rng, T, d.parent) - end +function Base.rand(rng::AbstractRNG, ::Type{T}, d::PowerMeasure) where {T} + map(_ -> rand(rng, T, d.parent), _cartidxs(d.axes)) end -function Base.rand(rng::AbstractRNG, ::Type{T}, d::PowerMeasure) where {T} - map(_cartidxs(d.axes)) do _ - rand(rng, d.parent) - end +function Base.rand(rng::AbstractRNG, d::PowerMeasure) + map(_ -> rand(rng, d.parent), _cartidxs(d.axes)) end +function testvalue(::Type{T}, d::PowerMeasure) where {T} + map(_ -> testvalue(T, d.parent), _cartidxs(d.axes)) +end + +function testvalue(d::PowerMeasure) + map(_ -> testvalue(d.parent), _cartidxs(d.axes)) +end + + @inline _pm_axes(sz::Tuple{Vararg{IntegerLike,N}}) where {N} = map(one_to, sz) @inline _pm_axes(axs::Tuple{Vararg{AbstractUnitRange,N}}) where {N} = axs diff --git a/src/getdof.jl b/src/getdof.jl index 996ba747..fc9c13e2 100644 --- a/src/getdof.jl +++ b/src/getdof.jl @@ -5,11 +5,15 @@ Abstract supertype for [`NoDOF`](@ref) and [`NoFastDOF`](@ref). """ abstract type AbstractNoDOF{MU} end -_add_dof(dof_a::Real, dof_b::Real) = dof_a + dof_b -_add_dof(dof_a::AbstractNoDOF, ::Real) = dof_a -_add_dof(::Real, dof_b::AbstractNoDOF) = dof_b +_add_dof(dof_a::Real) = dof_a +_add_dof(dof_a::Real, dof_b::IntegerLike) = dof_a + dof_b +_add_dof(dof_a::AbstractNoDOF, ::IntegerLike) = dof_a +_add_dof(::IntegerLike, dof_b::AbstractNoDOF) = dof_b _add_dof(dof_a::AbstractNoDOF, ::AbstractNoDOF) = dof_a +_mul_dof(dof_a::IntegerLike, n::IntegerLike) = dof_a * n +_mul_dof(dof_a::AbstractNoDOF, ::IntegerLike) = dof_a + """ MeasureBase.NoDOF{MU} <: AbstractNoDOF{MU} From 1d2cb8f92fc9326efb62cdafe9730428740249f2 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sun, 2 Jul 2023 23:03:17 +0200 Subject: [PATCH 065/133] FIXUP --- src/density-core.jl | 2 +- src/proxies.jl | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/density-core.jl b/src/density-core.jl index 849e636c..3cb7b119 100644 --- a/src/density-core.jl +++ b/src/density-core.jl @@ -136,7 +136,7 @@ known to be in the support of both, it can be more efficient to call @inline function logdensity_rel(μ::M, ν::N, x::X) where {M,N,X} inμ = insupport(μ, x) inν = insupport(ν, x) - return unsafe_logdensity_rel(μ, ν, x, inμ, inν) + return _logdensity_rel_impl(μ, ν, x, inμ, inν) end diff --git a/src/proxies.jl b/src/proxies.jl index 588b21a3..ffbf286f 100644 --- a/src/proxies.jl +++ b/src/proxies.jl @@ -30,8 +30,8 @@ macro useproxy(M) @inline $MeasureBase.to_origin(μ::$M, y) = to_origin(proxy(μ), y) @inline $MeasureBase.from_origin(μ::$M, x) = from_origin(proxy(μ), x) - @inline $MeasureBase.localmeasure(μ::$M, x) = from_origin(localmeasure(μ), x) - @inline $MeasureBase.transportmeasure(μ::$M, x) = from_origin(transportmeasure(μ), x) + @inline $MeasureBase.localmeasure(μ::$M, x) = localmeasure(proxy(μ), x) + @inline $MeasureBase.transportmeasure(μ::$M, x) = transportmeasure(proxy(μ), x) @inline $MeasureBase.massof(μ::$M) = massof(proxy(μ)) @inline $MeasureBase.massof(μ::$M, s) = massof(proxy(μ), s) From 074983f699b53acdfd92ea5d1d28ac5218573c07 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sun, 2 Jul 2023 23:42:33 +0200 Subject: [PATCH 066/133] STASH --- src/combinators/power.jl | 38 +++++++++------------------- src/combinators/product.jl | 4 +-- src/combinators/product_transport.jl | 2 +- src/getdof.jl | 17 +++++++------ src/mass-interface.jl | 5 +++- 5 files changed, 28 insertions(+), 38 deletions(-) diff --git a/src/combinators/power.jl b/src/combinators/power.jl index 41ebd75e..e7293f14 100644 --- a/src/combinators/power.jl +++ b/src/combinators/power.jl @@ -45,7 +45,7 @@ pwr_axes(μ::PowerMeasure) = μ.axes Returns `sz` for `μ = ν^sz`, `sz` being a tuple of integers. """ -pwr_size(μ::PowerMeasure) = map(length, μ.axes) +pwr_size(μ::PowerMeasure) = map(maybestatic_length, μ.axes) function Pretty.tile(μ::PowerMeasure) @@ -103,26 +103,22 @@ params(d::PowerMeasure) = params(first(marginals(d))) basemeasure(d.parent)^d.axes end -@inline function logdensity_def(d::PowerMeasure{M}, x) where {M} +@inline logdensity_def(d::PowerMeasure, x) = _pwr_logdensity_def(d.parent, x, prod(pwr_size(d))) + +@inline _pwr_logdensity_def(::PowerMeasure, x, ::Integer, ::StaticInteger{0}) = static(false) + +@inline function _pwr_logdensity_def(d::PowerMeasure, x, ::IntegerLike) parent = d.parent sum(x) do xj logdensity_def(parent, xj) end end -@inline function logdensity_def(d::PowerMeasure{M,Tuple{Static.SOneTo{N}}}, x) where {M,N} - parent = d.parent - sum(1:N) do j - @inbounds logdensity_def(parent, x[j]) - end -end +# ToDo: Specialized version of _pwr_logdensity_def for statically-sized power measures + +# ToDo: Re-enable this? +# _pwr_logdensity_def(::PowerMeasure{P}, x, ::IntegerLike) where {P<:PrimitiveMeasure} = static(0.0) -@inline function logdensity_def( - d::PowerMeasure{M,NTuple{N,Static.SOneTo{0}}}, - x, -) where {M,N} - static(0.0) -end @inline function insupport(μ::PowerMeasure, x) p = μ.parent @@ -140,8 +136,8 @@ end end end -@inline getdof(μ::PowerMeasure) = _mul_dof(getdof(μ.parent), prod(pwr_size(μ))) -@inline fast_dof(μ::PowerMeasure) = _mul_dof(fast_dof(μ.parent), prod(pwr_size(μ))) +@inline getdof(μ::PowerMeasure) = getdof(μ.parent) * prod(pwr_size(μ)) +@inline fast_dof(μ::PowerMeasure) = fast_dof(μ.parent) * prod(pwr_size(μ)) @inline function getdof(::PowerMeasure{<:Any,NTuple{N,Static.SOneTo{0}}}) where {N} static(0) @@ -168,16 +164,6 @@ end massof(m::PowerMeasure) = massof(m.parent)^prod(m.axes) -logdensity_def(::PowerMeasure{P}, x) where {P<:PrimitiveMeasure} = static(0.0) - -# To avoid ambiguities -function logdensity_def( - ::PowerMeasure{P,Tuple{Vararg{Static.SOneTo{0},N}}}, - x, -) where {P<:PrimitiveMeasure,N} - static(0.0) -end - """ MeasureBase.StdPowerMeasure{MU<:StdMeasure,N} diff --git a/src/combinators/product.jl b/src/combinators/product.jl index 98e9681e..9d019664 100644 --- a/src/combinators/product.jl +++ b/src/combinators/product.jl @@ -188,8 +188,8 @@ end return true end -getdof(d::ProductMeasure) = mapreduce(getdof, _add_dof, marginals(d)) -fast_dof(d::ProductMeasure) = mapreduce(fast_dof, _add_dof, marginals(d)) +getdof(d::ProductMeasure) = sum(getdof, marginals(d)) +fast_dof(d::ProductMeasure) = sum(fast_dof, marginals(d)) function checked_arg(μ::ProductMeasure{<:NTuple{N,Any}}, x::NTuple{N,Any}) where {N} map(checked_arg, marginals(μ), x) diff --git a/src/combinators/product_transport.jl b/src/combinators/product_transport.jl index a80cb972..144bb1f6 100644 --- a/src/combinators/product_transport.jl +++ b/src/combinators/product_transport.jl @@ -39,7 +39,7 @@ end # For transport, always pull a PowerMeasure back to one-dimensional PowerMeasure first: -transport_origin(μ::PowerMeasure{<:Any,N}) where N = ν.parent^product(pwr_size(μ)) +transport_origin(μ::PowerMeasure{<:Any,N}) where N = μ.parent^product(pwr_size(μ)) function from_origin(μ::PowerMeasure{<:Any,N}, x_origin) where N # Sanity check, should never fail: diff --git a/src/getdof.jl b/src/getdof.jl index fc9c13e2..1ee82a6f 100644 --- a/src/getdof.jl +++ b/src/getdof.jl @@ -5,14 +5,15 @@ Abstract supertype for [`NoDOF`](@ref) and [`NoFastDOF`](@ref). """ abstract type AbstractNoDOF{MU} end -_add_dof(dof_a::Real) = dof_a -_add_dof(dof_a::Real, dof_b::IntegerLike) = dof_a + dof_b -_add_dof(dof_a::AbstractNoDOF, ::IntegerLike) = dof_a -_add_dof(::IntegerLike, dof_b::AbstractNoDOF) = dof_b -_add_dof(dof_a::AbstractNoDOF, ::AbstractNoDOF) = dof_a - -_mul_dof(dof_a::IntegerLike, n::IntegerLike) = dof_a * n -_mul_dof(dof_a::AbstractNoDOF, ::IntegerLike) = dof_a +Base.:+(nodof::AbstractNoDOF) = nodof +Base.:+(::IntegerLike, nodof::AbstractNoDOF) = nodof +Base.:+(nodof::AbstractNoDOF, ::IntegerLike) = nodof +Base.:+(nodof::AbstractNoDOF, ::AbstractNoDOF) = nodof + +Base.:*(nodof::AbstractNoDOF) = nodof +Base.:*(::IntegerLike, nodof::AbstractNoDOF) = nodof +Base.:*(nodof::AbstractNoDOF, ::IntegerLike) = nodof +Base.:*(nodof::AbstractNoDOF, ::AbstractNoDOF) = nodof """ diff --git a/src/mass-interface.jl b/src/mass-interface.jl index 7b0518f9..2a96283f 100644 --- a/src/mass-interface.jl +++ b/src/mass-interface.jl @@ -22,7 +22,10 @@ for T in (:UnknownFiniteMass, :UnknownMass) @eval begin Base.:+(::$T, ::$T) = $T() Base.:*(::$T, ::$T) = $T() - Base.:^(::$T, k::Number) = isfinite(k) ? $T() : UnknownMass() + Base.:^(::$T, k::Real) = isfinite(k) ? $T() : UnknownMass() + # Disambiguation: + Base.:^(::$T, k::Integer) = isfinite(k) ? $T() : UnknownMass() + Base.:^(::$T, k::Rational) = isfinite(k) ? $T() : UnknownMass() end end From d6e8022682003df1f9c4134be64b2b3f74535701 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Mon, 3 Jul 2023 00:17:26 +0200 Subject: [PATCH 067/133] STASH --- src/combinators/power.jl | 1 + src/fixedrng.jl | 7 ++++--- src/mass-interface.jl | 12 ++++++++---- src/primitives/lebesgue.jl | 6 +++--- src/static.jl | 2 ++ 5 files changed, 18 insertions(+), 10 deletions(-) diff --git a/src/combinators/power.jl b/src/combinators/power.jl index e7293f14..7420276d 100644 --- a/src/combinators/power.jl +++ b/src/combinators/power.jl @@ -78,6 +78,7 @@ function testvalue(d::PowerMeasure) end +@inline _pm_axes(::Tuple{}) = () @inline _pm_axes(sz::Tuple{Vararg{IntegerLike,N}}) where {N} = map(one_to, sz) @inline _pm_axes(axs::Tuple{Vararg{AbstractUnitRange,N}}) where {N} = axs diff --git a/src/fixedrng.jl b/src/fixedrng.jl index 232b0891..31991418 100644 --- a/src/fixedrng.jl +++ b/src/fixedrng.jl @@ -5,9 +5,10 @@ Base.rand(::FixedRNG) = one(Float64) / 2 Random.randn(::FixedRNG) = zero(Float64) Random.randexp(::FixedRNG) = one(Float64) -Base.rand(::FixedRNG, ::Type{T}) where {T<:Real} = one(T) / 2 -Random.randn(::FixedRNG, ::Type{T}) where {T<:Real} = zero(T) -Random.randexp(::FixedRNG, ::Type{T}) where {T<:Real} = one(T) +# Use Random.BitFloatType instead of Real to avoid ambiguities: +Base.rand(::FixedRNG, ::Type{T}) where {T<:Random.BitFloatType} = one(T) / 2 +Random.randn(::FixedRNG, ::Type{T}) where {T<:Random.BitFloatType} = zero(T) +Random.randexp(::FixedRNG, ::Type{T}) where {T<:Random.BitFloatType} = one(T) # We need concrete type parameters to avoid amiguity for these cases for T in [Float16, Float32, Float64] diff --git a/src/mass-interface.jl b/src/mass-interface.jl index 2a96283f..55644a61 100644 --- a/src/mass-interface.jl +++ b/src/mass-interface.jl @@ -107,9 +107,6 @@ isnormalized(x, p::Real = 2) = isone(norm(x, p)) isone(::AbstractUnknownMass) = false -function massof(m, s) - _massof(m, s, rootmeasure(m)) -end """ (m::AbstractMeasure)(s) @@ -119,4 +116,11 @@ in this way, users should add the corresponding `massof` method. """ (m::AbstractMeasure)(s) = massof(m, s) -massof(μ, a_b::AbstractInterval) = smf(μ, rightendpoint(a_b)) - smf(μ, leftendpoint(a_b)) +function massof(m, s) + _default_massof_impl(m, s, rootmeasure(μ)) +end + +# # ToDo: Use smf if defined +#function _default_massof_impl(μ, a_b::AbstractInterval, ::LebesgueBase) +# smf(μ, rightendpoint(a_b)) - smf(μ, leftendpoint(a_b)) +#end diff --git a/src/primitives/lebesgue.jl b/src/primitives/lebesgue.jl index 8c42766f..2d2ed8dd 100644 --- a/src/primitives/lebesgue.jl +++ b/src/primitives/lebesgue.jl @@ -26,12 +26,12 @@ end massof(::LebesgueBase) = static(Inf) -function _massof(m, s::Interval, ::LebesgueBase) +function _default_massof_impl(m, s::AbstractInterval, ::LebesgueBase) mass = massof(m) nu = mass * StdUniform() f = transport_to(nu, m) - a = f(minimum(s)) - b = f(maximum(s)) + a = f(leftendpoint(s)) + b = f(rightendpoint(s)) return mass * abs(b - a) end diff --git a/src/static.jl b/src/static.jl index 9c170929..e0213096 100644 --- a/src/static.jl +++ b/src/static.jl @@ -29,6 +29,8 @@ Returns an instance of `FillArrays.Fill`. """ function fill_with end +@inline fill_with(x::T, ::Tuple{}) where T = FillArrays.Fill(x) + @inline function fill_with(x::T, sz::Tuple{Vararg{IntegerLike,N}}) where {T,N} fill_with(x, map(one_to, sz)) end From a708815508385637b2a03e600bd82494ef7b55ac Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Mon, 3 Jul 2023 00:45:14 +0200 Subject: [PATCH 068/133] FIXES --- src/combinators/bind.jl | 2 +- src/combinators/combined.jl | 6 +++--- src/combinators/product_transport.jl | 14 ++++++++------ src/density-core.jl | 4 ++-- src/getdof.jl | 4 ++-- src/insupport.jl | 8 +++++--- src/mass-interface.jl | 2 +- 7 files changed, 22 insertions(+), 18 deletions(-) diff --git a/src/combinators/bind.jl b/src/combinators/bind.jl index 5f16c473..84a815aa 100644 --- a/src/combinators/bind.jl +++ b/src/combinators/bind.jl @@ -111,7 +111,7 @@ The resulting measure behaves like `μ` in the infinitesimal neighborhood of `x` in respect to density calculation and transport as well. """ function transportmeasure(μ::Bind, x) - tpm_α, a, b = tpmeasure_split_combined(μ.α, x) + tpm_α, a, b = tpmeasure_split_combined(μ.f_c, μ.α, x) tpm_β_a = transportmeasure(_get_β_a(μ, a), b) mcombine(μ.f_c, tpm_α, tpm_β_a) end diff --git a/src/combinators/combined.jl b/src/combinators/combined.jl index a737b7ff..f85c33d2 100644 --- a/src/combinators/combined.jl +++ b/src/combinators/combined.jl @@ -90,10 +90,10 @@ struct CombinedMeasure{FC,MA<:AbstractMeasure,MB<:AbstractMeasure} <: AbstractMe end -@inline insupport(μ::CombinedMeasure, ab) = NoFastInsupport(typeof(μ)) +@inline insupport(μ::CombinedMeasure, ab) = NoFastInsupport{typeof(μ)}() -@inline getdof(μ::CombinedMeasure) = _add_dof(getdof(μ.α), getdof(μ.β)) -@inline fast_dof(μ::CombinedMeasure) =_add_dof(fast_dof(μ.α), fast_dof(μ.β)) +@inline getdof(μ::CombinedMeasure) = getdof(μ.α) + getdof(μ.β) +@inline fast_dof(μ::CombinedMeasure) = fast_dof(μ.α) + fast_dof(μ.β) # Bypass `checked_arg`, would require require splitting ab: @inline checked_arg(::CombinedMeasure, ab) = ab diff --git a/src/combinators/product_transport.jl b/src/combinators/product_transport.jl index 144bb1f6..5715c723 100644 --- a/src/combinators/product_transport.jl +++ b/src/combinators/product_transport.jl @@ -61,14 +61,16 @@ function from_origin(μ::PowerMeasure{<:AbstractMeasure,1}, x_origin) end -# Transport between univariate standard measures and power measures of size one: +# Transport between univariate standard measures and 1-dim power measures of size one: -function transport_def(ν::StdMeasure, μ::PowerMeasure{<:StdMeasure}, x) +function transport_def(ν::StdMeasure, μ::PowerMeasure{<:StdMeasure,1}, x) return transport_def(ν, μ.parent, only(x)) end -function transport_def(ν::PowerMeasure{<:StdMeasure}, μ::StdMeasure, x) - return fill_with(transport_def(ν.parent, μ, only(x)), map(length, ν.axes)) +function transport_def(ν::PowerMeasure{<:StdMeasure,1}, μ::StdMeasure, x) + axes_ν = pwr_axes(ν) + @assert prod(axes_ν) == 1 + return fill_with(transport_def(ν.parent, μ, x), map(maybestatic_length, ν.axes)) end function transport_def(ν::StdPowerMeasure{MU,1}, μ::StdPowerMeasure{NU,1}, x,) where {MU,NU} @@ -264,14 +266,14 @@ function _split_x_by_marginals_with_rest(dofs::Union{Tuple,AbstractVector}, x::A end function _marginals_from_mvstd_with_rest(νs, dofs::_KnownDOFs, μ_inner::StdMeasure, x::AbstractVector{<:Real}) - xs, x_rest = _split_x_by_marginals_with_rest + xs, x_rest = _split_x_by_marginals_with_rest(dofs, x) # ToDo: Is this ideal? μs = map(n -> μ_inner^n, dofs) ys = map(transport_def, νs, μs, xs) return ys, x_rest end -function _marginals_from_mvstd_with_rest(νs::Tuple, dofs::_MaybeUnkownKnownDOFs, μ_inner::StdMeasure, x::AbstractVector{<:Real}) +function _marginals_from_mvstd_with_rest(νs, ::_MaybeUnkownKnownDOFs, μ_inner::StdMeasure, x::AbstractVector{<:Real}) _marginals_from_mvstd_with_rest_nodof(νs, μ_inner, x) end diff --git a/src/density-core.jl b/src/density-core.jl index 3cb7b119..4a61e940 100644 --- a/src/density-core.jl +++ b/src/density-core.jl @@ -161,12 +161,12 @@ end @inline function _logdensity_rel_impl(μ::M, ν::N, x::X, inμ::Bool, @nospecialize(::NoFastInsupport)) where {M,N,X} logd = unsafe_logdensity_rel(μ, ν, x) - return istrue(inμ) ? logd : logd * oftypeof(logd, -Inf) + return istrue(inμ) ? logd : logd * oftype(logd, -Inf) end @inline function _logdensity_rel_impl(μ::M, ν::N, x::X, @nospecialize(::NoFastInsupport), inν::Bool) where {M,N,X} logd = unsafe_logdensity_rel(μ, ν, x) - return istrue(inν) ? logd : logd * oftypeof(logd, +Inf) + return istrue(inν) ? logd : logd * oftype(logd, +Inf) end diff --git a/src/getdof.jl b/src/getdof.jl index 1ee82a6f..fa60b8c0 100644 --- a/src/getdof.jl +++ b/src/getdof.jl @@ -94,13 +94,13 @@ the DOF is constant of the measurable space). """ function some_dof end -function some_dof() +function some_dof(μ) m = asmeasure(μ) _try_direct_dof(m, getdof(m)) end _try_direct_dof(::AbstractMeasure, dof::IntegerLike) = dof -_try_direct_dof(μ::AbstractMeasure, ::AbstractNoDOF) = _try_local_dof(μ::AbstractMeasure, some_dof(some_localmeasure(μ))) +_try_direct_dof(μ::AbstractMeasure, ::AbstractNoDOF) = _try_local_dof(μ::AbstractMeasure, some_dof(_some_localmeasure(μ))) _try_local_dof(::AbstractMeasure, dof::IntegerLike) = dof _try_local_dof(μ::AbstractMeasure, ::AbstractNoDOF) = throw(ArgumentError("Can't determine DOF for measure of type $(nameof(typeof(μ)))")) diff --git a/src/insupport.jl b/src/insupport.jl index e4f44ca8..4da0adcd 100644 --- a/src/insupport.jl +++ b/src/insupport.jl @@ -37,9 +37,11 @@ function ChainRulesCore.rrule(::typeof(require_insupport), μ, x) end function require_insupport(μ, x) - r = insupport(μ, x) - if !(r isa NoFastInsupport) || r - throw(ArgumentError("x is not within the support of μ")) + ins = insupport(μ, x) + if !(ins isa NoFastInsupport) + if !ins + throw(ArgumentError("x is not within the support of μ")) + end end return nothing end diff --git a/src/mass-interface.jl b/src/mass-interface.jl index 55644a61..c2cadfc2 100644 --- a/src/mass-interface.jl +++ b/src/mass-interface.jl @@ -117,7 +117,7 @@ in this way, users should add the corresponding `massof` method. (m::AbstractMeasure)(s) = massof(m, s) function massof(m, s) - _default_massof_impl(m, s, rootmeasure(μ)) + _default_massof_impl(m, s, rootmeasure(m)) end # # ToDo: Use smf if defined From 7422be11203d5dbfbb03b2894681ccc842e74067 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Mon, 3 Jul 2023 00:49:23 +0200 Subject: [PATCH 069/133] FIXES --- src/combinators/bind.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/combinators/bind.jl b/src/combinators/bind.jl index 84a815aa..3e90ddba 100644 --- a/src/combinators/bind.jl +++ b/src/combinators/bind.jl @@ -1,5 +1,5 @@ @doc raw""" - mbind(f_β, α::AbstractMeasure, f_c = second) + mbind(f_β, α::AbstractMeasure, f_c = x -> x[2]) Constructs a monadic bind, resp. a hierarchical measure, from a transition kernel function `f_β`, a primary measure `α` and a variate combination @@ -22,7 +22,7 @@ has the mathethematical interpretation (on sets $$A$$ and $$B$$) \mu(f_c(A, B)) = \int_A \beta_a(B)\, \mathrm{d}\, \alpha(a) ``` -When using the default `fc = second` (so `ab == b`) this simplies to +When using the default `fc = x -> x[2]` (so `ab == b`) this simplies to ```math \mu(B) = \int_A \beta_a(B)\, \mathrm{d}\, \alpha(a) @@ -81,7 +81,7 @@ logdensityof(posterior, θ) function mbind end export mbind -@inline function mbind(f_β, α::AbstractMeasure, f_c = second) +@inline function mbind(f_β, α::AbstractMeasure, f_c = x -> x[2]) F, M, G = Core.Typeof(f_β), Core.Typeof(α), Core.Typeof(f_c) Bind{F,M,G}(f_β, α, f_c) end From b062e0ad63d060b460bdee104cd52cc340186d50 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Mon, 3 Jul 2023 01:00:51 +0200 Subject: [PATCH 070/133] Add ArraysOfArrays to deps Necessary for efficient allocation of multivariate sample sets. --- Project.toml | 2 ++ src/MeasureBase.jl | 2 ++ 2 files changed, 4 insertions(+) diff --git a/Project.toml b/Project.toml index 98cca9d7..b8faa57f 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,7 @@ authors = ["Chad Scherrer ", "Oliver Schulz Date: Mon, 3 Jul 2023 01:00:58 +0200 Subject: [PATCH 071/133] STASH FIXES --- src/MeasureBase.jl | 1 + src/collection_utils.jl | 10 ++++++---- src/combinators/bind.jl | 4 ++-- src/combinators/combined.jl | 2 +- 4 files changed, 10 insertions(+), 7 deletions(-) diff --git a/src/MeasureBase.jl b/src/MeasureBase.jl index 890abcdd..3510846d 100644 --- a/src/MeasureBase.jl +++ b/src/MeasureBase.jl @@ -124,6 +124,7 @@ using Compat using IrrationalConstants include("static.jl") +include("collection_utils.jl") include("smf.jl") include("getdof.jl") include("transport.jl") diff --git a/src/collection_utils.jl b/src/collection_utils.jl index 6563477a..8b5ae3df 100644 --- a/src/collection_utils.jl +++ b/src/collection_utils.jl @@ -55,16 +55,18 @@ end _empty_zero(::AbstractVector{T}) where {T<:Real} = Fill(zero(T), 0) +#= struct _TupleNamer{names} <: Function end +struct _TupleUnNamer{names} <: Function end + (::TupleNamer{names})(x::Tuple) where names = NamedTuple{names}(x) InverseFunctions.inverse(::TupleNamer{names}) where names = TupleUnNamer{names}() ChangesOfVariables.with_logabsdet_jacobian(::TupleNamer{names}, x::Tuple) where names = static(false) -struct _TupleUnNamer{names} <: Function end (::TupleUnNamer{names})(x::NamedTuple{names}) where {names} = values(x) InverseFunctions.inverse(::TupleUnNamer{names}) where names = TupleNamer{names}() ChangesOfVariables.with_logabsdet_jacobian(::TupleUnNamer{names}, x::NamedTuple{names}) where names = static(false) - +=# _reorder_nt(x::NamedTuple{names},::Val{names}) where {names} = x @@ -83,8 +85,8 @@ end # ToDo: Add custom rrule for _reorder_nt? # Field access functions for Fill: -_fill_value(x::Fill) = x.value -_fill_axes(x::Fill) = x.axes +_fill_value(x::FillArrays.Fill) = x.value +_fill_axes(x::FillArrays.Fill) = x.axes _flatten_to_rv(VV::AbstractVector{<:AbstractVector{<:Real}}) = flatview(VectorOfArrays(VV)) diff --git a/src/combinators/bind.jl b/src/combinators/bind.jl index 3e90ddba..a7960ba9 100644 --- a/src/combinators/bind.jl +++ b/src/combinators/bind.jl @@ -176,14 +176,14 @@ logdensity_def(::Bind, x) = throw(ArgumentError("logdensity_def is not available # Specialize logdensityof to avoid duplicate calculations: function logdensityof(μ::Bind, x) - tpm_α, a, b = tpmeasure_split_combined(μ.α, x) + tpm_α, a, b = tpmeasure_split_combined(μ.f_c, μ.α, x) β_a = _get_β_a(μ, a) logdensityof(tpm_α, a) + logdensityof(β_a, b) end # Specialize unsafe_logdensityof to avoid duplicate calculations: function unsafe_logdensityof(μ::Bind, x) - tpm_α, a, b = tpmeasure_split_combined(μ.α, x) + tpm_α, a, b = tpmeasure_split_combined(μ.f_c, μ.α, x) β_a = _get_β_a(μ, a) unsafe_logdensityof(tpm_α, a) + unsafe_logdensityof(β_a, b) end diff --git a/src/combinators/combined.jl b/src/combinators/combined.jl index f85c33d2..72890d6c 100644 --- a/src/combinators/combined.jl +++ b/src/combinators/combined.jl @@ -27,7 +27,7 @@ end @inline _generic_split_combined(::Type{Pair}, ::AbstractMeasure, ab::Pair) = (ab...,) function _generic_split_combined(f_c::FC, α::AbstractMeasure, ab) where FC - _split_variate_byvalue(f_c, testvalue(μ), x) + _split_variate_byvalue(f_c, testvalue(μ), ab) end _split_variate_byvalue(::typeof(vcat), test_a::AbstractVector, ab::AbstractVector) = _split_after(ab, length(test_a)) From ddf1f2686d0f58caff61da80d38159aa1a74f073 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Mon, 3 Jul 2023 01:32:18 +0200 Subject: [PATCH 072/133] STASH FIXES --- src/collection_utils.jl | 7 ++++--- src/combinators/bind.jl | 6 ++++++ src/combinators/combined.jl | 23 ++++++++++++++--------- src/combinators/product.jl | 14 +++++++------- src/combinators/product_transport.jl | 15 +++++++-------- src/mass-interface.jl | 2 +- 6 files changed, 39 insertions(+), 28 deletions(-) diff --git a/src/collection_utils.jl b/src/collection_utils.jl index 8b5ae3df..ca43eb22 100644 --- a/src/collection_utils.jl +++ b/src/collection_utils.jl @@ -25,8 +25,9 @@ end @inline function _split_after(x::AbstractVector, n::IntegerLike) - i_first = maybestatic_firstindex(x) - i_last = maybestatic_lastindex(x) + idxs = maybestatic_eachindex(x) + i_first = maybestatic_first(idxs) + i_last = maybestatic_last(idxs) _get_or_view(x, i_first, i_first + n - one(n)), _get_or_view(x, i_first + n, i_last) end @@ -35,7 +36,7 @@ end @generated function _split_after(x::NamedTuple{names}, ::Val{names_a}) where {names, names_a} n = length(names_a) - if names_after[begin:begin+n-1] == names_a + if names[begin:begin+n-1] == names_a names_b = names[n:end] quote a, b = _split_after(x, Val(n)) diff --git a/src/combinators/bind.jl b/src/combinators/bind.jl index a7960ba9..6c40bba9 100644 --- a/src/combinators/bind.jl +++ b/src/combinators/bind.jl @@ -195,6 +195,12 @@ function Base.rand(rng::Random.AbstractRNG, ::Type{T}, μ::Bind) where {T<:Real} return μ.f_c(a, b) end +function Base.rand(rng::Random.AbstractRNG, μ::Bind) + a = rand(rng, μ.α) + b = rand(rng, _get_β_a(μ, a)) + return μ.f_c(a, b) +end + function transport_to_mvstd(ν_inner::StdMeasure, μ::Bind, ab) tpm_α, a, b = tpmeasure_split_combined(μ.f_c, μ.α, ab) diff --git a/src/combinators/combined.jl b/src/combinators/combined.jl index 72890d6c..21595c26 100644 --- a/src/combinators/combined.jl +++ b/src/combinators/combined.jl @@ -27,7 +27,7 @@ end @inline _generic_split_combined(::Type{Pair}, ::AbstractMeasure, ab::Pair) = (ab...,) function _generic_split_combined(f_c::FC, α::AbstractMeasure, ab) where FC - _split_variate_byvalue(f_c, testvalue(μ), ab) + _split_variate_byvalue(f_c, testvalue(α), ab) end _split_variate_byvalue(::typeof(vcat), test_a::AbstractVector, ab::AbstractVector) = _split_after(ab, length(test_a)) @@ -35,7 +35,7 @@ _split_variate_byvalue(::typeof(vcat), test_a::AbstractVector, ab::AbstractVecto _split_variate_byvalue(::typeof(vcat), ::Tuple{N}, ab::Tuple) where N = _split_after(ab, Val{N}()) function _split_variate_byvalue(::typeof(merge), ::NamedTuple{names_a}, ab::NamedTuple) where names_a - _split_after(ab, Val{names_a}) + _split_after(ab, Val(names_a)) end @@ -66,7 +66,7 @@ function mcombine(f_c, α::AbstractMeasure, β::AbstractMeasure) end function mcombine(::typeof(tuple), α::AbstractMeasure, β::AbstractMeasure) - productmeasure((a, b)) + productmeasure((α, β)) end function mcombine(f_c::Union{typeof(vcat),typeof(merge)}, α::AbstractProductMeasure, β::AbstractProductMeasure) @@ -98,9 +98,9 @@ end # Bypass `checked_arg`, would require require splitting ab: @inline checked_arg(::CombinedMeasure, ab) = ab -rootmeasure(::CombinedMeasure) = mcombine(μ.f_c, rootmeasure(μ), rootmeasure(ν)) +rootmeasure(μ::CombinedMeasure) = mcombine(μ.f_c, rootmeasure(μ.α), rootmeasure(μ.β)) -basemeasure(::CombinedMeasure) = mcombine(μ.f_c, basemeasure(μ), basemeasure(ν)) +basemeasure(μ::CombinedMeasure) = mcombine(μ.f_c, basemeasure(μ.α), basemeasure(μ.β)) function logdensity_def(μ::CombinedMeasure, ab) # Use tpmeasure_split_combined to avoid duplicate calculation of transportmeasure(α): @@ -115,12 +115,17 @@ function logdensityof(μ::CombinedMeasure, ab) end -function Base.rand(rng::Random.AbstractRNG, ::Type{T}, h::CombinedMeasure) where {T<:Real} - x_primary = rand(rng, T, h.m) - x_secondary = rand(rng, T, h.f(x_primary)) - return _combine_variates(h.flatten_mode, x_primary, x_secondary) +function Base.rand(rng::Random.AbstractRNG, ::Type{T}, μ::CombinedMeasure) where {T<:Real} + a = rand(rng, T, μ.α) + b = rand(rng, T, μ.β) + return μ.f_c(a, b) end +function Base.rand(rng::Random.AbstractRNG, μ::CombinedMeasure) + a = rand(rng, μ.α) + b = rand(rng, μ.β) + return μ.f_c(a, b) +end function transport_to_mvstd(ν_inner::StdMeasure, μ::CombinedMeasure, ab) diff --git a/src/combinators/product.jl b/src/combinators/product.jl index 9d019664..20084fd3 100644 --- a/src/combinators/product.jl +++ b/src/combinators/product.jl @@ -85,13 +85,13 @@ end @inline function logdensity_def(μ::ProductMeasure, x) - marginals_density_op(logdensity_def, marginals(μ), x) + _marginals_density_op(logdensity_def, marginals(μ), x) end @inline function unsafe_logdensityof(μ::ProductMeasure, x) - marginals_density_op(unsafe_logdensityof, marginals(μ), x) + _marginals_density_op(unsafe_logdensityof, marginals(μ), x) end @inline function logdensity_rel(μ::ProductMeasure, ν::ProductMeasure, x) - marginals_density_op(logdensity_rel, marginals(μ), marginals(ν), x) + _marginals_density_op(logdensity_rel, marginals(μ), marginals(ν), x) end function _marginals_density_op(density_op::F, marginals_μ, x) where F @@ -103,7 +103,7 @@ end end @inline function _marginals_density_op(density_op::F, marginals_μ::NamedTuple{names}, x::NamedTuple) where {F,names} nms = Val{names}() - _marginals_density_op(density_op, marginals_μ, _reorder_nt(values(x), nms)) + _marginals_density_op(density_op, values(marginals_μ), values(_reorder_nt(x, Val(nms)))) end function _marginals_density_op(density_op::F, marginals_μ, marginals_ν, x) where F @@ -115,11 +115,11 @@ end end @inline function _marginals_density_op(density_op::F, marginals_μ::NamedTuple{names}, marginals_ν::NamedTuple, x::NamedTuple) where {F,names} nms = Val{names}() - _marginals_density_op(density_op, marginals_μ, _reorder_nt(values(marginals_ν), nms), _reorder_nt(values(x), nms)) + _marginals_density_op(density_op, values(marginals_μ), values(_reorder_nt(marginals_ν, nms)), values(_reorder_nt(x, nms))) end -@inline basemeasure(μ::ProductMeasure) =_marginals_basemeasure(marginals(μ)) +@inline basemeasure(μ::ProductMeasure) = _marginals_basemeasure(marginals(μ)) _marginals_basemeasure(marginals_μ) = productmeasure(map(basemeasure, marginals_μ)) @@ -129,7 +129,7 @@ _marginals_basemeasure(marginals_μ) = productmeasure(map(basemeasure, marginals function _marginals_basemeasure(marginals_μ::Base.Generator{I,F}) where {I,F} T = Core.Compiler.return_type(marginals_μ.f, Tuple{eltype(marginals_μ.iter)}) B = Core.Compiler.return_type(basemeasure, Tuple{T}) - _marginals_basemeasure_impl(μ, B, static(Base.issingletontype(B))) + _marginals_basemeasure_impl(marginals_μ, B, static(Base.issingletontype(B))) end function _marginals_basemeasure(marginals_μ::AbstractMappedArray{T}) where {T} diff --git a/src/combinators/product_transport.jl b/src/combinators/product_transport.jl index 5715c723..fbd3fe69 100644 --- a/src/combinators/product_transport.jl +++ b/src/combinators/product_transport.jl @@ -39,7 +39,7 @@ end # For transport, always pull a PowerMeasure back to one-dimensional PowerMeasure first: -transport_origin(μ::PowerMeasure{<:Any,N}) where N = μ.parent^product(pwr_size(μ)) +transport_origin(μ::PowerMeasure{<:Any,N}) where N = μ.parent^prod(pwr_size(μ)) function from_origin(μ::PowerMeasure{<:Any,N}, x_origin) where N # Sanity check, should never fail: @@ -50,7 +50,7 @@ end # A one-dimensional PowerMeasure has an origin if it's parent has an origin: -transport_origin(μ::PowerMeasure{<:AbstractMeasure,1}) = _origin_pwr(typeof(μ), transport_origin(μ.parent), μ.axes) +transport_origin(μ::PowerMeasure{<:AbstractMeasure,1}) = _pwr_origin(typeof(μ), pwr_base(μ), pwr_axes(μ)) _pwr_origin(::Type{MU}, parent_origin, axes) where MU = parent_origin^axes _pwr_origin(::Type{MU}, ::NoTransportOrigin, axes) where MU = NoTransportOrigin{MU} @@ -81,7 +81,7 @@ end # Transport to a multivariate standard measure from any measure: function transport_def(ν::StdPowerMeasure{MU,1}, μ::AbstractMeasure, ab) where MU - ν_inner = _inner_stdmeasure(ν) + ν_inner = pwr_base(ν) transport_to_mvstd(ν_inner, μ, ab) end @@ -98,9 +98,9 @@ function _to_mvstd_withdof(ν_inner::StdMeasure, μ::AbstractMeasure, ::Abstract _to_mvstd_withorigin(ν_inner, μ, transport_origin(μ), x) end -function _to_mvstd_withorigin(ν_inner::StdMeasure, ::AbstractMeasure, μ_origin, x) +function _to_mvstd_withorigin(ν_inner::StdMeasure, μ::AbstractMeasure, μ_origin, x) x_origin = transport_to_mvstd(ν_inner, μ_origin, x) - from_origin(x_origin) + from_origin(μ, x_origin) end function _to_mvstd_withorigin(ν_inner::StdMeasure, μ::AbstractMeasure, ::NoTransportOrigin, x) @@ -111,7 +111,7 @@ end # Transport from a multivariate standard measure to any measure: function transport_def(ν::AbstractMeasure, μ::StdPowerMeasure{MU,1}, x) where MU - μ_inner = _inner_stdmeasure(μ) + μ_inner = pwr_base(μ) _transport_from_mvstd(ν, μ_inner, x) end @@ -125,8 +125,7 @@ end function transport_from_mvstd_with_rest(ν::AbstractMeasure, μ_inner::StdMeasure, x) dof_ν = fast_dof(ν) - origin = transport_origin(ν) - return _from_mvstd_with_rest_withdof(ν, dof_ν, μ_inner, x, dof_ν, origin) + return _from_mvstd_with_rest_withdof(ν, dof_ν, μ_inner, x) end function _from_mvstd_with_rest_withdof(ν::AbstractMeasure, dof_ν::IntegerLike, μ_inner::StdMeasure, x) diff --git a/src/mass-interface.jl b/src/mass-interface.jl index c2cadfc2..09a83b2c 100644 --- a/src/mass-interface.jl +++ b/src/mass-interface.jl @@ -68,7 +68,7 @@ finite, or we may know nothing at all about it. For these cases, it will return `UnknownFiniteMass` or `UnknownMass`, respectively. When no `massof` method exists, it defaults to `UnknownMass`. """ -massof(m::AbstractMeasure) = UnknownMass(m) +massof(::AbstractMeasure) = UnknownMass() struct NormalizedMeasure{P,M} <: AbstractMeasure parent::P From fadbd837a10380bba6f91fd99866d225955e342d Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Mon, 3 Jul 2023 01:37:06 +0200 Subject: [PATCH 073/133] STASH --- src/combinators/product_transport.jl | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/combinators/product_transport.jl b/src/combinators/product_transport.jl index fbd3fe69..8b61b5a1 100644 --- a/src/combinators/product_transport.jl +++ b/src/combinators/product_transport.jl @@ -86,7 +86,7 @@ function transport_def(ν::StdPowerMeasure{MU,1}, μ::AbstractMeasure, ab) where end function transport_to_mvstd(ν_inner::StdMeasure, μ::AbstractMeasure, x) - return _to_mvstd_withdof(ν_inner, μ, fast_dof(μ), x, origin) + return _to_mvstd_withdof(ν_inner, μ, fast_dof(μ), x) end function _to_mvstd_withdof(ν_inner::StdMeasure, μ::AbstractMeasure, dof_μ::IntegerLike, x) @@ -147,9 +147,9 @@ function _from_mvstd_with_rest_withdof(ν::AbstractMeasure, ::AbstractNoDOF, μ_ _from_mvstd_with_rest_withorigin(ν, transport_origin(ν), μ_inner, x) end -function _from_mvstd_with_rest_withorigin(::AbstractMeasure, ν_origin, μ_inner::StdMeasure, x) +function _from_mvstd_with_rest_withorigin(ν::AbstractMeasure, ν_origin, μ_inner::StdMeasure, x) x_origin, x_rest = transport_from_mvstd_with_rest(ν_origin, x, μ_inner) - from_origin(x_origin), x_rest + from_origin(ν, x_origin), x_rest end function _from_mvstd_with_rest_withorigin(ν::AbstractMeasure, ::NoTransportOrigin, μ_inner::StdMeasure, x) @@ -170,7 +170,7 @@ end @inline transport_origin(μ::ProductMeasure) = _marginals_tp_origin(marginals(μ)) @inline from_origin(μ::ProductMeasure, x_origin) = _marginals_from_origin(marginals(μ), x_origin) -_marginals_tp_origin(::Ms) where Ms = NoTransportOrigin{PowerMeasure{M}}() +_marginals_tp_origin(::Ms) where Ms = NoTransportOrigin{ProductMeasure{Ms}}() # Pull back from a product over a Fill to a power measure: @@ -243,7 +243,7 @@ const _MaybeUnkownKnownDOFs = Union{Tuple{Vararg{_MaybeUnkownDOF,N}} where N, St function transport_from_mvstd_with_rest(ν::ProductMeasure, μ_inner::StdMeasure, x) νs = marginals(ν) - dofs = map(fast_dof, marginals_μ) + dofs = map(fast_dof, νs) return _marginals_from_mvstd_with_rest(νs, dofs, μ_inner, x) end @@ -261,7 +261,7 @@ function _split_x_by_marginals_with_rest(dofs::Union{Tuple,AbstractVector}, x::A first_idxs = _dof_access_firstidxs(dofs, maybestatic_first(x_idxs)) xs = map((from, n) -> _get_or_view(x, from, from + n - one(n)), first_idxs, dofs) x_rest = _get_or_view(x, first_idxs[end] + dofs[end], maybestatic_last(x_idxs)) - return xs, r_rest + return xs, x_rest end function _marginals_from_mvstd_with_rest(νs, dofs::_KnownDOFs, μ_inner::StdMeasure, x::AbstractVector{<:Real}) @@ -279,7 +279,7 @@ end function _marginals_from_mvstd_with_rest_nodof(νs::Tuple{Vararg{AbstractMeasure,N}}, μ_inner::StdMeasure, x::AbstractVector{<:Real}) where N # ToDo: Check for type stability, may need generated function y1, x_rest = transport_from_mvstd_with_rest(νs[1], μ_inner, x) - y2_end, x_final_rest = _marginals_from_mvstd_with_rest(νs[2:end], μ_inner, x_rest) + y2_end, x_final_rest = _marginals_from_mvstd_with_rest_nodof(νs[2:end], μ_inner, x_rest) return (y1, y2_end...), x_final_rest end @@ -288,7 +288,7 @@ function _marginals_from_mvstd_with_rest_nodof(νs::AbstractVector{<:AbstractMea y1, x_rest = transport_from_mvstd_with_rest(νs[1], μ_inner, x) ys = [y1] for ν in νs[begin+1:end] - y_i, x_rest = transport_from_mvstd_with_rest(ν, μ_inner, x_rest) + y_i, x_rest = _marginals_from_mvstd_with_rest_nodof(ν, μ_inner, x_rest) ys = vcat(ys, y_i) end return ys, x_rest From fa58a3936ba8d54ab81e72c8b5a3f45e98d78283 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Mon, 3 Jul 2023 01:40:49 +0200 Subject: [PATCH 074/133] STASH --- src/combinators/power.jl | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/combinators/power.jl b/src/combinators/power.jl index 7420276d..1c51f9f6 100644 --- a/src/combinators/power.jl +++ b/src/combinators/power.jl @@ -104,14 +104,13 @@ params(d::PowerMeasure) = params(first(marginals(d))) basemeasure(d.parent)^d.axes end -@inline logdensity_def(d::PowerMeasure, x) = _pwr_logdensity_def(d.parent, x, prod(pwr_size(d))) +@inline logdensity_def(d::PowerMeasure, x) = _pwr_logdensity_def(pwr_base(d), x, prod(pwr_size(d))) -@inline _pwr_logdensity_def(::PowerMeasure, x, ::Integer, ::StaticInteger{0}) = static(false) +@inline _pwr_logdensity_def(d_base, x, ::Integer, ::StaticInteger{0}) = static(false) -@inline function _pwr_logdensity_def(d::PowerMeasure, x, ::IntegerLike) - parent = d.parent +@inline function _pwr_logdensity_def(d_base, x, ::IntegerLike) sum(x) do xj - logdensity_def(parent, xj) + logdensity_def(d_base, xj) end end From 9ad36f396e53f2a67bd7c81307a2b2d0e73860dd Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Tue, 4 Jul 2023 16:08:31 +0200 Subject: [PATCH 075/133] STASH --- src/combinators/power.jl | 2 +- src/combinators/product_transport.jl | 16 ++++++------ src/static.jl | 38 ++++++++++++++++++++++++++++ src/transport.jl | 12 ++++++++- 4 files changed, 58 insertions(+), 10 deletions(-) diff --git a/src/combinators/power.jl b/src/combinators/power.jl index 1c51f9f6..9a3fb3b7 100644 --- a/src/combinators/power.jl +++ b/src/combinators/power.jl @@ -170,4 +170,4 @@ massof(m::PowerMeasure) = massof(m.parent)^prod(m.axes) Represents and N-dimensional power of the standard measure `MU()`. """ -const StdPowerMeasure{MU<:StdMeasure,N} = PowerMeasure{MU,<:NTuple{N,Base.OneTo}} +const StdPowerMeasure{MU<:StdMeasure,N} = PowerMeasure{MU,<:NTuple{N,UnitRangeFromOne}} diff --git a/src/combinators/product_transport.jl b/src/combinators/product_transport.jl index 8b61b5a1..94f4aea6 100644 --- a/src/combinators/product_transport.jl +++ b/src/combinators/product_transport.jl @@ -39,7 +39,7 @@ end # For transport, always pull a PowerMeasure back to one-dimensional PowerMeasure first: -transport_origin(μ::PowerMeasure{<:Any,N}) where N = μ.parent^prod(pwr_size(μ)) +transport_origin(μ::PowerMeasure{<:Any,N}) where N = transport_origin(μ.parent)^prod(pwr_size(μ)) function from_origin(μ::PowerMeasure{<:Any,N}, x_origin) where N # Sanity check, should never fail: @@ -67,10 +67,10 @@ function transport_def(ν::StdMeasure, μ::PowerMeasure{<:StdMeasure,1}, x) return transport_def(ν, μ.parent, only(x)) end -function transport_def(ν::PowerMeasure{<:StdMeasure,1}, μ::StdMeasure, x) - axes_ν = pwr_axes(ν) - @assert prod(axes_ν) == 1 - return fill_with(transport_def(ν.parent, μ, x), map(maybestatic_length, ν.axes)) +function transport_def(ν::StdPowerMeasure{<:StdMeasure,1}, μ::StdMeasure, x) + sz_ν = pwr_size(ν) + @assert prod(sz_ν) == 1 + return fill_with(transport_def(ν.parent, μ, x), sz_ν) end function transport_def(ν::StdPowerMeasure{MU,1}, μ::StdPowerMeasure{NU,1}, x,) where {MU,NU} @@ -80,9 +80,9 @@ end # Transport to a multivariate standard measure from any measure: -function transport_def(ν::StdPowerMeasure{MU,1}, μ::AbstractMeasure, ab) where MU +function transport_def(ν::StdPowerMeasure{MU,1}, μ::AbstractMeasure, x) where MU ν_inner = pwr_base(ν) - transport_to_mvstd(ν_inner, μ, ab) + transport_to_mvstd(ν_inner, μ, x) end function transport_to_mvstd(ν_inner::StdMeasure, μ::AbstractMeasure, x) @@ -90,7 +90,7 @@ function transport_to_mvstd(ν_inner::StdMeasure, μ::AbstractMeasure, x) end function _to_mvstd_withdof(ν_inner::StdMeasure, μ::AbstractMeasure, dof_μ::IntegerLike, x) - y = transport_to(ν_inner^dof_μ, μ, x) + y = transport_def(ν_inner^dof_μ, μ, x) return y end diff --git a/src/static.jl b/src/static.jl index e0213096..58205ab5 100644 --- a/src/static.jl +++ b/src/static.jl @@ -44,6 +44,20 @@ end FillArrays.Fill(x, dyn_axs) end +@inline function fill_with(x::T, axs::Tuple{Vararg{StaticOneTo}}) where T + fill(x, _sarray_type(T, map(maybestatic_length, axs))) +end + +@inline function fill_with(x::T, sz::Tuple{Vararg{StaticInteger}}) where T + fill(x, _sarray_type(T, sz)) +end + +@inline @generated function _sarray_type(::Type{T}, sz::Tuple{Vararg{StaticInteger}}) where T + Ns = map(p -> p.parameters[1], sz.parameters) + :(SArray{Tuple{$Ns...}}) +end + + """ MeasureBase.maybestatic_length(x)::IntegerLike @@ -102,3 +116,27 @@ Returns the last element of `A` as a dynamic or static value. maybestatic_last(A::AbstractArray) = last(A) maybestatic_last(::StaticArrays.SOneTo{N}) where N = static(N) maybestatic_last(::Static.OptionallyStaticUnitRange{<:Static.StaticInteger,<:Static.StaticInteger{until}}) where until = static(until) + + +""" + const UnitRangeFromOne + +Alias for unit ranges that start at one. +""" +const UnitRangeFromOne = Union{Base.OneTo, Static.OptionallyStaticUnitRange, StaticArrays.SOneTo} + + +""" + const StaticOneTo{N} + +A static unit range from one to N. +""" +const StaticOneTo{N} = Union{Static.OptionallyStaticUnitRange{StaticInt{1},StaticInt{N}}, StaticArrays.SOneTo{N}} + + +""" + const StaticUnitRange + +A static unit range. +""" +const StaticUnitRange = Union{Static.OptionallyStaticUnitRange{<:StaticInt,<:StaticInt}, StaticArrays.SOneTo} diff --git a/src/transport.jl b/src/transport.jl index ce736b72..fa792080 100644 --- a/src/transport.jl +++ b/src/transport.jl @@ -8,6 +8,9 @@ See [`MeasureBase.transport_origin`](@ref). """ struct NoTransportOrigin{NU} end +Base.:^(origin::NoTransportOrigin, ::IntegerLike) = origin + + """ MeasureBase.transport_origin(ν) @@ -141,6 +144,13 @@ end μ, x, ) where {n_ν,n_μ} + if n_ν == 10 + return :(throw(ArgumentError("Transport to measure of type $(nameof(typeof(ν))) not supported, origin stack too deep."))) + end + if n_μ == 10 + return :(throw(ArgumentError("Transport from measure of type $(nameof(typeof(μ))) not supported, origin stack too deep."))) + end + prog = quote μ0 = μ x0 = x @@ -153,7 +163,7 @@ end end for i in 1:n_μ x_i = Symbol(:x, i) - x_last = Symbol(:x, i - 1) + x_last = Symbol(:x, i - 1)_origin_depth(ν)(ν)(ν) μ_last = Symbol(:μ, i - 1) push!(prog.args, :($x_i = to_origin($μ_last, $x_last))) end From 32c903ced71825d1e98568c3364a4d3f06e59eba Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Tue, 4 Jul 2023 16:25:15 +0200 Subject: [PATCH 076/133] STASH --- src/collection_utils.jl | 3 -- src/combinators/product_transport.jl | 2 +- src/static.jl | 61 +++++++++++++++++----------- 3 files changed, 38 insertions(+), 28 deletions(-) diff --git a/src/collection_utils.jl b/src/collection_utils.jl index ca43eb22..5698174a 100644 --- a/src/collection_utils.jl +++ b/src/collection_utils.jl @@ -50,9 +50,6 @@ end end -# ToDo: Add static reshape for static arrays! - - _empty_zero(::AbstractVector{T}) where {T<:Real} = Fill(zero(T), 0) diff --git a/src/combinators/product_transport.jl b/src/combinators/product_transport.jl index 94f4aea6..522ba02a 100644 --- a/src/combinators/product_transport.jl +++ b/src/combinators/product_transport.jl @@ -44,7 +44,7 @@ transport_origin(μ::PowerMeasure{<:Any,N}) where N = transport_origin(μ.parent function from_origin(μ::PowerMeasure{<:Any,N}, x_origin) where N # Sanity check, should never fail: @assert x_origin isa AbstractVector - return reshape(x_origin, pwr_size(μ)...) + return maybestatic_reshape(x_origin, pwr_size(μ)...) end diff --git a/src/static.jl b/src/static.jl index 58205ab5..fd7110dd 100644 --- a/src/static.jl +++ b/src/static.jl @@ -5,6 +5,31 @@ Equivalent to `Union{Integer,Static.StaticInteger}`. """ const IntegerLike = Union{Integer,Static.StaticInteger} + +""" + const UnitRangeFromOne + +Alias for unit ranges that start at one. +""" +const UnitRangeFromOne = Union{Base.OneTo, Static.OptionallyStaticUnitRange, StaticArrays.SOneTo} + + +""" + const StaticOneTo{N} + +A static unit range from one to N. +""" +const StaticOneTo{N} = Union{Static.OptionallyStaticUnitRange{StaticInt{1},StaticInt{N}}, StaticArrays.SOneTo{N}} + + +""" + const StaticUnitRange + +A static unit range. +""" +const StaticUnitRange = Union{Static.OptionallyStaticUnitRange{<:StaticInt,<:StaticInt}, StaticArrays.SOneTo} + + """ MeasureBase.one_to(n::IntegerLike) @@ -58,6 +83,18 @@ end end +""" + MeasureBase.maybestatic_reshape(A, sz) + +Reshapes array `A` to sizes `sz`. + +If `A` is a static array and `sz` is static, the result is a static array. +""" +function maybestatic_reshape end + +maybestatic_reshape(A, sz) = reshape(A, sz) +maybestatic_reshape(A::StaticArray, sz::Tuple{Vararg{StaticInteger}}) = _sarray_type(eltype(A), sz)(Tuple(A)) + """ MeasureBase.maybestatic_length(x)::IntegerLike @@ -116,27 +153,3 @@ Returns the last element of `A` as a dynamic or static value. maybestatic_last(A::AbstractArray) = last(A) maybestatic_last(::StaticArrays.SOneTo{N}) where N = static(N) maybestatic_last(::Static.OptionallyStaticUnitRange{<:Static.StaticInteger,<:Static.StaticInteger{until}}) where until = static(until) - - -""" - const UnitRangeFromOne - -Alias for unit ranges that start at one. -""" -const UnitRangeFromOne = Union{Base.OneTo, Static.OptionallyStaticUnitRange, StaticArrays.SOneTo} - - -""" - const StaticOneTo{N} - -A static unit range from one to N. -""" -const StaticOneTo{N} = Union{Static.OptionallyStaticUnitRange{StaticInt{1},StaticInt{N}}, StaticArrays.SOneTo{N}} - - -""" - const StaticUnitRange - -A static unit range. -""" -const StaticUnitRange = Union{Static.OptionallyStaticUnitRange{<:StaticInt,<:StaticInt}, StaticArrays.SOneTo} From cc7a724e994c3f7c56ddc9126b1df3e82cbbbb6a Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Tue, 4 Jul 2023 16:44:53 +0200 Subject: [PATCH 077/133] STASH --- ambiguity-fixes.jl.txt | 126 --------------------------------------- src/combinators/power.jl | 6 ++ src/static.jl | 10 ++-- 3 files changed, 11 insertions(+), 131 deletions(-) delete mode 100644 ambiguity-fixes.jl.txt diff --git a/ambiguity-fixes.jl.txt b/ambiguity-fixes.jl.txt deleted file mode 100644 index 7210bcc4..00000000 --- a/ambiguity-fixes.jl.txt +++ /dev/null @@ -1,126 +0,0 @@ -function superpose(::T, ::T) where {T<:SuperpositionMeasure} - @error "FIXME" -end - -function kernel(::Type{M}, ::NamedTuple{()}) where {M<:ParameterizedMeasure} - @error "FIXME" -end - -function logdensity_def( - ::T, - ::S, - ::Any, -) where { - T<:(MeasureBase.SuperpositionMeasure{Tuple{A,B}} where {A,B}), - S<:(MeasureBase.SuperpositionMeasure{Tuple{A,B}} where {A,B}), -} - @error "FIXME" -end - -function transport_def(::StdUniform, ::StdExponential, ::NoTransformOrigin) - @error "FIXME" -end - -@inline function transport_def(::StdExponential, ::StdUniform, ::NoTransformOrigin) - @error "FIXME" -end - -function transport_def(::StdUniform, ::StdExponential, ::NoTransport) - @error "FIXME" -end - -@inline function transport_def(::StdExponential, ::StdUniform, ::NoTransport) - @error "FIXME" -end - -function transport_def(::StdUniform, ::StdLogistic, ::NoTransformOrigin) - @error "FIXME" -end - -function transport_def(::StdUniform, ::StdLogistic, ::NoTransport) - @error "FIXME" -end - -function transport_def(::StdLogistic, ::StdUniform, ::NoTransport) - @error "FIXME" -end - -@inline function transport_def(::StdLogistic, ::StdUniform, ::NoTransformOrigin) - @error "FIXME" -end - -@inline function transport_def(::MU, ::MU, ::NoTransport) where {MU<:StdMeasure} - @error "FIXME" -end - -@inline function transport_def(::MU, ::MU, ::NoTransformOrigin) where {MU<:StdMeasure} - @error "FIXME" -end - -function transport_def(::StdMeasure, ::PowerMeasure{<:StdMeasure}, ::NoTransport) - @error "FIXME" -end - -function transport_def(::StdMeasure, ::PowerMeasure{<:StdMeasure}, ::NoTransformOrigin) - @error "FIXME" -end - -function transport_def(::PowerMeasure{<:StdMeasure}, ::StdMeasure, ::NoTransformOrigin) - @error "FIXME" -end - -function transport_def(::PowerMeasure{<:StdMeasure}, ::StdMeasure, ::NoTransport) - @error "FIXME" -end - -function transport_def( - ::PowerMeasure{<:StdMeasure,<:Tuple{Base.OneTo}}, - ::PowerMeasure{<:StdMeasure,<:Tuple{Base.OneTo}}, - ::NoTransport, -) - @error "FIXME" -end - -function transport_def( - ::PowerMeasure{<:StdMeasure,<:Tuple{Base.OneTo}}, - ::PowerMeasure{<:StdMeasure,<:Tuple{Base.OneTo}}, - ::NoTransformOrigin, -) - @error "FIXME" -end - -function transport_def( - ::PowerMeasure{<:StdMeasure,<:Tuple{Vararg{Base.OneTo,N}}}, - ::PowerMeasure{<:StdMeasure,<:Tuple{Vararg{Base.OneTo,M}}}, - ::NoTransport, -) where {N,M} - @error "FIXME" -end - -function transport_def( - ::PowerMeasure{<:StdMeasure,<:Tuple{Vararg{Base.OneTo,N}}}, - ::PowerMeasure{<:StdMeasure,<:Tuple{Vararg{Base.OneTo,M}}}, - ::NoTransformOrigin, -) where {N,M} - @error "FIXME" -end - -function transport_to(::Type{NU}, ::Type{MU}) where {MU<:StdMeasure,NU<:StdMeasure} - @error "FIXME" -end - -function transport_def(::Dirac, ::PowerMeasure{<:StdMeasure}, ::NoTransport) - @error "FIXME" -end - -function transport_def(::Dirac, ::PowerMeasure{<:StdMeasure}, ::NoTransformOrigin) - @error "FIXME" -end - -@inline function transport_def(::PowerMeasure{<:StdMeasure}, ::Dirac, ::NoTransport) - @error "FIXME" -end - -@inline function transport_def(::PowerMeasure{<:StdMeasure}, ::Dirac, ::NoTransformOrigin) - @error "FIXME" -end diff --git a/src/combinators/power.jl b/src/combinators/power.jl index 9a3fb3b7..15042182 100644 --- a/src/combinators/power.jl +++ b/src/combinators/power.jl @@ -69,6 +69,12 @@ function Base.rand(rng::AbstractRNG, d::PowerMeasure) map(_ -> rand(rng, d.parent), _cartidxs(d.axes)) end +function Base.rand(rng::AbstractRNG, ::Type{T}, d::PowerMeasure{M,<:Tuple{Vararg{StaticOneTo}}}) where {T,M} + sz = pwr_size(d) + base_d = pwr_base(d) + _sarray_type(sz)(ntuple(_ -> rand(rng, T, base_d), prod(sz))) +end + function testvalue(::Type{T}, d::PowerMeasure) where {T} map(_ -> testvalue(T, d.parent), _cartidxs(d.axes)) end diff --git a/src/static.jl b/src/static.jl index fd7110dd..a7fa9af5 100644 --- a/src/static.jl +++ b/src/static.jl @@ -69,15 +69,15 @@ end FillArrays.Fill(x, dyn_axs) end -@inline function fill_with(x::T, axs::Tuple{Vararg{StaticOneTo}}) where T - fill(x, _sarray_type(T, map(maybestatic_length, axs))) +@inline function fill_with(x, axs::Tuple{Vararg{StaticOneTo}}) + fill(x, _sarray_type(map(maybestatic_length, axs))) end -@inline function fill_with(x::T, sz::Tuple{Vararg{StaticInteger}}) where T - fill(x, _sarray_type(T, sz)) +@inline function fill_with(x, sz::Tuple{Vararg{StaticInteger}}) + fill(x, _sarray_type(sz)) end -@inline @generated function _sarray_type(::Type{T}, sz::Tuple{Vararg{StaticInteger}}) where T +@inline @generated function _sarray_type(sz::Tuple{Vararg{StaticInteger}}) Ns = map(p -> p.parameters[1], sz.parameters) :(SArray{Tuple{$Ns...}}) end From b3fb6fa6f328ff0c690022298d421a4877b89326 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Tue, 4 Jul 2023 17:59:28 +0200 Subject: [PATCH 078/133] STASH --- src/combinators/power.jl | 11 ++++++++- src/combinators/product_transport.jl | 2 +- src/domains.jl | 2 +- src/static.jl | 35 +++++++++++++++++----------- test/static.jl | 12 +++++----- 5 files changed, 40 insertions(+), 22 deletions(-) diff --git a/src/combinators/power.jl b/src/combinators/power.jl index 15042182..63d1093a 100644 --- a/src/combinators/power.jl +++ b/src/combinators/power.jl @@ -70,9 +70,18 @@ function Base.rand(rng::AbstractRNG, d::PowerMeasure) end function Base.rand(rng::AbstractRNG, ::Type{T}, d::PowerMeasure{M,<:Tuple{Vararg{StaticOneTo}}}) where {T,M} + #!!!!!!!!!!! sz = pwr_size(d) base_d = pwr_base(d) _sarray_type(sz)(ntuple(_ -> rand(rng, T, base_d), prod(sz))) + 0 +end + +function Base.rand(rng::AbstractRNG, d::PowerMeasure{M,<:Tuple{Vararg{StaticOneTo}}}) where M + #!!!!!!!!!!! + sz = pwr_size(d) + base_d = pwr_base(d) + broadcast(_ -> rand(rng, base_d), MeasureBase.maybestatic_fill(nothing, sz)) end function testvalue(::Type{T}, d::PowerMeasure) where {T} @@ -88,7 +97,7 @@ end @inline _pm_axes(sz::Tuple{Vararg{IntegerLike,N}}) where {N} = map(one_to, sz) @inline _pm_axes(axs::Tuple{Vararg{AbstractUnitRange,N}}) where {N} = axs -marginals(d::PowerMeasure) = fill_with(d.parent, d.axes) +marginals(d::PowerMeasure) = maybestatic_fill(d.parent, d.axes) function Base.:^(μ::AbstractMeasure, dims::Tuple{Vararg{AbstractArray,N}}) where {N} powermeasure(μ, dims) diff --git a/src/combinators/product_transport.jl b/src/combinators/product_transport.jl index 522ba02a..3f76a967 100644 --- a/src/combinators/product_transport.jl +++ b/src/combinators/product_transport.jl @@ -70,7 +70,7 @@ end function transport_def(ν::StdPowerMeasure{<:StdMeasure,1}, μ::StdMeasure, x) sz_ν = pwr_size(ν) @assert prod(sz_ν) == 1 - return fill_with(transport_def(ν.parent, μ, x), sz_ν) + return maybestatic_fill(transport_def(ν.parent, μ, x), sz_ν) end function transport_def(ν::StdPowerMeasure{MU,1}, μ::StdPowerMeasure{NU,1}, x,) where {MU,NU} diff --git a/src/domains.jl b/src/domains.jl index e03f753c..9579780b 100644 --- a/src/domains.jl +++ b/src/domains.jl @@ -116,7 +116,7 @@ struct Simplex <: CodimOne end function zeroset(::Simplex) f(x::AbstractArray{T}) where {T} = sum(x) - one(T) - ∇f(x::AbstractArray{T}) where {T} = fill_with(one(T), size(x)) + ∇f(x::AbstractArray{T}) where {T} = maybestatic_fill(one(T), size(x)) ZeroSet(f, ∇f) end diff --git a/src/static.jl b/src/static.jl index a7fa9af5..a38e7be5 100644 --- a/src/static.jl +++ b/src/static.jl @@ -46,21 +46,21 @@ _dynamic(::Static.SOneTo{N}) where {N} = Base.OneTo(N) _dynamic(r::AbstractUnitRange) = minimum(r):maximum(r) """ - MeasureBase.fill_with(x, sz::NTuple{N,<:IntegerLike}) where N + MeasureBase.maybestatic_fill(x, sz::NTuple{N,<:IntegerLike}) where N Creates an array of size `sz` filled with `x`. Returns an instance of `FillArrays.Fill`. """ -function fill_with end +function maybestatic_fill end -@inline fill_with(x::T, ::Tuple{}) where T = FillArrays.Fill(x) +@inline maybestatic_fill(x::T, ::Tuple{}) where T = FillArrays.Fill(x) -@inline function fill_with(x::T, sz::Tuple{Vararg{IntegerLike,N}}) where {T,N} - fill_with(x, map(one_to, sz)) +@inline function maybestatic_fill(x::T, sz::Tuple{Vararg{IntegerLike,N}}) where {T,N} + maybestatic_fill(x, map(one_to, sz)) end -@inline function fill_with(x::T, axs::Tuple{Vararg{AbstractUnitRange,N}}) where {T,N} +@inline function maybestatic_fill(x::T, axs::Tuple{Vararg{AbstractUnitRange,N}}) where {T,N} # While `FillArrays.Fill` (mostly?) works with axes that are static unit # ranges, some operations that automatic differentiation requires do fail # on such instances of `Fill` (e.g. `reshape` from dynamic to static size). @@ -69,17 +69,26 @@ end FillArrays.Fill(x, dyn_axs) end -@inline function fill_with(x, axs::Tuple{Vararg{StaticOneTo}}) - fill(x, _sarray_type(map(maybestatic_length, axs))) +@inline function maybestatic_fill(x::T, axs::Tuple{Vararg{StaticOneTo}}) where T + fill(x, staticarray_type(T, map(maybestatic_length, axs))) end -@inline function fill_with(x, sz::Tuple{Vararg{StaticInteger}}) - fill(x, _sarray_type(sz)) +@inline function maybestatic_fill(x::T, sz::Tuple{Vararg{StaticInteger}}) where T + fill(x, staticarray_type(T, sz)) end -@inline @generated function _sarray_type(sz::Tuple{Vararg{StaticInteger}}) - Ns = map(p -> p.parameters[1], sz.parameters) - :(SArray{Tuple{$Ns...}}) + +""" + staticarray_type(T, sz::Tuple{Vararg{StaticInteger}}) + +Returns the type of a static array with element type `T` and size `sz`. +""" +function staticarray_type end + +@inline @generated function staticarray_type(::Type{T}, sz::Tuple{Vararg{StaticInteger,N}}) where {T,N} + szs = map(p -> p.parameters[1], sz.parameters) + len = prod(szs) + :(SArray{Tuple{$szs...},T,$N,$len}) end diff --git a/test/static.jl b/test/static.jl index f618124b..83e4f930 100644 --- a/test/static.jl +++ b/test/static.jl @@ -17,14 +17,14 @@ import FillArrays @test @inferred(MeasureBase.one_to(static(7))) isa Static.SOneTo @test @inferred(MeasureBase.one_to(static(7))) == static(1):static(7) - @test @inferred(MeasureBase.fill_with(4.2, (7,))) == FillArrays.Fill(4.2, 7) - @test @inferred(MeasureBase.fill_with(4.2, (static(7),))) == FillArrays.Fill(4.2, 7) - @test @inferred(MeasureBase.fill_with(4.2, (3, static(7)))) == + @test @inferred(MeasureBase.maybestatic_fill(4.2, (7,))) == FillArrays.Fill(4.2, 7) + @test @inferred(MeasureBase.maybestatic_fill(4.2, (static(7),))) == FillArrays.Fill(4.2, 7) + @test @inferred(MeasureBase.maybestatic_fill(4.2, (3, static(7)))) == FillArrays.Fill(4.2, 3, 7) - @test @inferred(MeasureBase.fill_with(4.2, (3:7,))) == FillArrays.Fill(4.2, (3:7,)) - @test @inferred(MeasureBase.fill_with(4.2, (static(3):static(7),))) == + @test @inferred(MeasureBase.maybestatic_fill(4.2, (3:7,))) == FillArrays.Fill(4.2, (3:7,)) + @test @inferred(MeasureBase.maybestatic_fill(4.2, (static(3):static(7),))) == FillArrays.Fill(4.2, (3:7,)) - @test @inferred(MeasureBase.fill_with(4.2, (3:7, static(2):static(5)))) == + @test @inferred(MeasureBase.maybestatic_fill(4.2, (3:7, static(2):static(5)))) == FillArrays.Fill(4.2, (3:7, 2:5)) @test MeasureBase.maybestatic_length(MeasureBase.one_to(7)) isa Int From 0279e3dd5bd11e797db445bcc4821366835529d8 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Tue, 4 Jul 2023 18:05:06 +0200 Subject: [PATCH 079/133] STASH pwr rand incomplete --- src/combinators/power.jl | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/src/combinators/power.jl b/src/combinators/power.jl index 63d1093a..98385e63 100644 --- a/src/combinators/power.jl +++ b/src/combinators/power.jl @@ -61,6 +61,15 @@ function _cartidxs(axs::Tuple{Vararg{AbstractUnitRange,N}}) where {N} CartesianIndices(map(_dynamic, axs)) end +#!!!!!!!!!! +function Base.rand(rng::AbstractRNG, ::Type{T}, d::PowerMeasure) where {T} + _pwr_rand(rng, T, d.parent, d.axes) +end + +function Base.rand(rng::AbstractRNG, d::PowerMeasure) + _pwr_rand(rng, T, d.parent, d.axes) +end + function Base.rand(rng::AbstractRNG, ::Type{T}, d::PowerMeasure) where {T} map(_ -> rand(rng, T, d.parent), _cartidxs(d.axes)) end @@ -70,15 +79,12 @@ function Base.rand(rng::AbstractRNG, d::PowerMeasure) end function Base.rand(rng::AbstractRNG, ::Type{T}, d::PowerMeasure{M,<:Tuple{Vararg{StaticOneTo}}}) where {T,M} - #!!!!!!!!!!! sz = pwr_size(d) base_d = pwr_base(d) - _sarray_type(sz)(ntuple(_ -> rand(rng, T, base_d), prod(sz))) - 0 + broadcast(_ -> rand(rng, T, base_d), MeasureBase.maybestatic_fill(nothing, sz)) end function Base.rand(rng::AbstractRNG, d::PowerMeasure{M,<:Tuple{Vararg{StaticOneTo}}}) where M - #!!!!!!!!!!! sz = pwr_size(d) base_d = pwr_base(d) broadcast(_ -> rand(rng, base_d), MeasureBase.maybestatic_fill(nothing, sz)) @@ -186,3 +192,9 @@ massof(m::PowerMeasure) = massof(m.parent)^prod(m.axes) Represents and N-dimensional power of the standard measure `MU()`. """ const StdPowerMeasure{MU<:StdMeasure,N} = PowerMeasure{MU,<:NTuple{N,UnitRangeFromOne}} + +function Base.rand(rng::AbstractRNG, ::Type{T}, d::PowerMeasure{<:StdMeasure,<:Tuple{Vararg{StaticOneTo}}}) where {T,M} + sz = pwr_size(d) + base_d = pwr_base(d) + broadcast(_ -> rand(rng, T, base_d), MeasureBase.maybestatic_fill(nothing, sz)) +end From 826110724a99fc33987237f1e3913a07e7f2dcd5 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Tue, 4 Jul 2023 21:06:08 +0200 Subject: [PATCH 080/133] STASH rand --- src/combinators/power.jl | 17 +++-------------- 1 file changed, 3 insertions(+), 14 deletions(-) diff --git a/src/combinators/power.jl b/src/combinators/power.jl index 98385e63..7c428e27 100644 --- a/src/combinators/power.jl +++ b/src/combinators/power.jl @@ -61,15 +61,6 @@ function _cartidxs(axs::Tuple{Vararg{AbstractUnitRange,N}}) where {N} CartesianIndices(map(_dynamic, axs)) end -#!!!!!!!!!! -function Base.rand(rng::AbstractRNG, ::Type{T}, d::PowerMeasure) where {T} - _pwr_rand(rng, T, d.parent, d.axes) -end - -function Base.rand(rng::AbstractRNG, d::PowerMeasure) - _pwr_rand(rng, T, d.parent, d.axes) -end - function Base.rand(rng::AbstractRNG, ::Type{T}, d::PowerMeasure) where {T} map(_ -> rand(rng, T, d.parent), _cartidxs(d.axes)) end @@ -193,8 +184,6 @@ Represents and N-dimensional power of the standard measure `MU()`. """ const StdPowerMeasure{MU<:StdMeasure,N} = PowerMeasure{MU,<:NTuple{N,UnitRangeFromOne}} -function Base.rand(rng::AbstractRNG, ::Type{T}, d::PowerMeasure{<:StdMeasure,<:Tuple{Vararg{StaticOneTo}}}) where {T,M} - sz = pwr_size(d) - base_d = pwr_base(d) - broadcast(_ -> rand(rng, T, base_d), MeasureBase.maybestatic_fill(nothing, sz)) -end +# ToDo: Fast specialized rand for static and non-static StdPowerMeasure! + +# ToDo: Define mrand and dispatch Base.rand to mrand to burden Base.rand with less methods! From f73c1fe25d76d14a4828106f39c5659334d7e75e Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Tue, 4 Jul 2023 21:35:50 +0200 Subject: [PATCH 081/133] STASH FIXES --- src/static.jl | 2 +- src/transport.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/static.jl b/src/static.jl index a38e7be5..2529b093 100644 --- a/src/static.jl +++ b/src/static.jl @@ -102,7 +102,7 @@ If `A` is a static array and `sz` is static, the result is a static array. function maybestatic_reshape end maybestatic_reshape(A, sz) = reshape(A, sz) -maybestatic_reshape(A::StaticArray, sz::Tuple{Vararg{StaticInteger}}) = _sarray_type(eltype(A), sz)(Tuple(A)) +maybestatic_reshape(A::StaticArray, sz::Tuple{Vararg{StaticInteger}}) = staticarray_type(eltype(A), sz)(Tuple(A)) """ diff --git a/src/transport.jl b/src/transport.jl index fa792080..d11ce318 100644 --- a/src/transport.jl +++ b/src/transport.jl @@ -163,7 +163,7 @@ end end for i in 1:n_μ x_i = Symbol(:x, i) - x_last = Symbol(:x, i - 1)_origin_depth(ν)(ν)(ν) + x_last = Symbol(:x, i - 1) μ_last = Symbol(:μ, i - 1) push!(prog.args, :($x_i = to_origin($μ_last, $x_last))) end From 36956daf589cafc6ed38253f9340d10887dcda10 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sun, 16 Jul 2023 10:22:35 +0200 Subject: [PATCH 082/133] Specialize pushfwd and pullbck for DensityMeasure --- src/combinators/transformedmeasure.jl | 46 +++++++++++++++++++++------ 1 file changed, 37 insertions(+), 9 deletions(-) diff --git a/src/combinators/transformedmeasure.jl b/src/combinators/transformedmeasure.jl index 65e58cfa..6377beee 100644 --- a/src/combinators/transformedmeasure.jl +++ b/src/combinators/transformedmeasure.jl @@ -18,6 +18,8 @@ function parent(::AbstractTransformedMeasure) end export PushforwardMeasure +# ToDo: Store FunctionWithInverse instead of f and finv in PushforwardMeasure? + """ struct PushforwardMeasure{F,I,M,VC<:TransformVolCorr} <: AbstractPushforward f :: F @@ -126,23 +128,30 @@ measure](https://en.wikipedia.org/wiki/Pushforward_measure) from `μ` the To manually specify an inverse, call `pushfwd(InverseFunctions.setinverse(f, finv), μ, volcorr)`. """ -function pushfwd(f, μ, volcorr::TransformVolCorr = WithVolCorr()) +pushfwd(f, μ, volcorr::TransformVolCorr = WithVolCorr()) = _generic_pullbck_impl(f, μ, volcorr) + +function _generic_pushfwd_impl(f, μ, volcorr::TransformVolCorr = WithVolCorr()) PushforwardMeasure(f, inverse(f), μ, volcorr) end -function pushfwd(f, μ::PushforwardMeasure, volcorr::TransformVolCorr = WithVolCorr()) +function _generic_pushfwd_impl(f, μ::PushforwardMeasure, volcorr::TransformVolCorr = WithVolCorr()) _pushfwd_of_pushfwd(f, μ, μ.volcorr, volcorr) end # Either both WithVolCorr or both NoVolCorr, so we can merge them -function _pushfwd_of_pushfwd(f, μ::PushforwardMeasure, ::V, v::V) where {V} - pushfwd(fchain((μ.f, f)), μ.origin, v) +function _pushfwd_of_pushfwd(f, μ::PushforwardMeasure, ::V, volcorr::V) where {V} + pushfwd(f ∘ fchain(μ.f), μ.origin, volcorr) +end + +function _pushfwd_of_pushfwd(f, μ::PushforwardMeasure, _, volcorr) + PushforwardMeasure(f, inverse(f), μ, volcorr) end -function _pushfwd_of_pushfwd(f, μ::PushforwardMeasure, _, v) - PushforwardMeasure(f, inverse(f), μ, v) +function _generic_pushfwd_impl(f, μ::DensityMeasure, volcorr::TransformVolCorr = WithVolCorr()) + mintegrate(fchain(μ.f) ∘ inverse(f), pushfwd(f, μ.base, volcorr)) end + ############################################################################### # pullback @@ -161,9 +170,28 @@ some cases, we may be focusing on log-density (and not, for example, sampling). To manually specify an inverse, call `pullbck(InverseFunctions.setinverse(f, finv), μ, volcorr)`. """ -function pullbck(f, μ, volcorr::TransformVolCorr = WithVolCorr()) - PushforwardMeasure(inverse(f), f, μ, volcorr) -end +pullbck(f, μ, volcorr::TransformVolCorr = WithVolCorr()) = _generic_pullbck_impl(f, μ, volcorr) export pullbck @deprecate pullback(f, μ, volcorr::TransformVolCorr = WithVolCorr()) pullbck(f, μ, volcorr) + +function _generic_pullbck_impl(f, μ, volcorr::TransformVolCorr = WithVolCorr()) + PushforwardMeasure(inverse(f), f, μ, volcorr) +end + +function _generic_pushfwd_impl(f, μ::PushforwardMeasure, volcorr::TransformVolCorr = WithVolCorr()) + _pullbck_of_pushfwd(f, μ, μ.volcorr, volcorr) +end + +# Either both WithVolCorr or both NoVolCorr, so we can merge them +function _pullbck_of_pushfwd(f, μ::PushforwardMeasure, ::V, volcorr::V) where {V} + pullbck(fchain(μ.finv) ∘ f, μ.origin, volcorr) +end + +function _pullbck_of_pushfwd(f, μ::PushforwardMeasure, _, volcorr) + PushforwardMeasure(inverse(f), f, μ, volcorr) +end + +function _generic_pullbck_impl(f, μ::DensityMeasure, volcorr::TransformVolCorr = WithVolCorr()) + mintegrate(fchain(μ.f) ∘ f, pullbck(f, μ.base, volcorr)) +end From 39a70dc566b9f1a3364bebbb74f57fb2fd5087c6 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sun, 16 Jul 2023 10:51:24 +0200 Subject: [PATCH 083/133] Add OneTwoMany to deps --- Project.toml | 2 ++ src/MeasureBase.jl | 1 + test/Project.toml | 1 + 3 files changed, 4 insertions(+) diff --git a/Project.toml b/Project.toml index b8faa57f..1444d339 100644 --- a/Project.toml +++ b/Project.toml @@ -21,6 +21,7 @@ LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" LogarithmicNumbers = "aa2f6b4e-9042-5d33-9679-40d3a6b85899" MappedArrays = "dbb5928d-eab1-5f90-85c2-b9b0edb7c900" NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" +OneTwoMany = "762dc654-8631-413a-a342-372a7419ad9d" PrettyPrinting = "54e16d92-306c-5ea0-a30b-337be88ac337" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" @@ -49,6 +50,7 @@ LogExpFunctions = "0.3" LogarithmicNumbers = "1" MappedArrays = "0.4" NaNMath = "0.3, 1" +OneTwoMany = "0.1" PrettyPrinting = "0.3, 0.4" Random = "1" Reexport = "1" diff --git a/src/MeasureBase.jl b/src/MeasureBase.jl index 3510846d..f476e1e4 100644 --- a/src/MeasureBase.jl +++ b/src/MeasureBase.jl @@ -27,6 +27,7 @@ import Base.iterate import ConstructionBase using ConstructionBase: constructorof using IntervalSets +using OneTwoMany: getsecond using PrettyPrinting const Pretty = PrettyPrinting diff --git a/test/Project.toml b/test/Project.toml index e30fa7a2..3f04208a 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -10,6 +10,7 @@ IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogarithmicNumbers = "aa2f6b4e-9042-5d33-9679-40d3a6b85899" LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" +OneTwoMany = "762dc654-8631-413a-a342-372a7419ad9d" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" From a049b670e6edaadd28723ce11a0b56cdc6698c93 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sun, 16 Jul 2023 10:52:03 +0200 Subject: [PATCH 084/133] STASH furtther specialize bind and combined --- src/combinators/bind.jl | 8 +++++++- src/combinators/combined.jl | 21 ++++++++++++++++++--- 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/src/combinators/bind.jl b/src/combinators/bind.jl index 6c40bba9..8d3b7b1a 100644 --- a/src/combinators/bind.jl +++ b/src/combinators/bind.jl @@ -81,11 +81,17 @@ logdensityof(posterior, θ) function mbind end export mbind -@inline function mbind(f_β, α::AbstractMeasure, f_c = x -> x[2]) +@inline mbind(f_β, α::AbstractMeasure, f_c = getsecond) = _generic_mbind_impl(f_β, α, f_c) + +@inline function _generic_mbind_impl(f_β, α::AbstractMeasure, f_c) F, M, G = Core.Typeof(f_β), Core.Typeof(α), Core.Typeof(f_c) Bind{F,M,G}(f_β, α, f_c) end +function _generic_mbind_impl(f_β, α::Dirac, f_c) + mcombine(f_c, α, f_β(α.value)) +end + """ struct MeasureBase.Bind <: AbstractMeasure diff --git a/src/combinators/combined.jl b/src/combinators/combined.jl index 21595c26..b2ce1cb7 100644 --- a/src/combinators/combined.jl +++ b/src/combinators/combined.jl @@ -60,16 +60,31 @@ sets $$A$$ and $$B$$) function mcombine end export mcombine -function mcombine(f_c, α::AbstractMeasure, β::AbstractMeasure) +@inline mcombine(f_c, α::AbstractMeasure, β::AbstractMeasure) = _generic_mcombine_impl_stage3(f_c, α, β) + +@inline function _generic_mcombine_impl_stage1(f_c, α::AbstractMeasure, β::AbstractMeasure) + _generic_mcombine_impl_stage2(f_c, α, β) +end + +@inline function _generic_mcombine_impl_stage2(f_c, α::AbstractMeasure, β::AbstractMeasure) FC, MA, MB = Core.Typeof(f_c), Core.Typeof(α), Core.Typeof(β) CombinedMeasure{FC,MA,MB}(f_c, α, β) end -function mcombine(::typeof(tuple), α::AbstractMeasure, β::AbstractMeasure) +@inline function _generic_mcombine_impl_stage2(f_c, α::Dirac, β::Dirac) + Dirac(f_c(α.value, β.value)) +end + + +@inline _generic_mcombine_impl_stage1(::typeof(first), α::AbstractMeasure, β::AbstractMeasure) = α +@inline _generic_mcombine_impl_stage1(::typeof(getsecond), α::AbstractMeasure, β::AbstractMeasure) = β +@inline _generic_mcombine_impl_stage1(::typeof(last), α::AbstractMeasure, β::AbstractMeasure) = β + +@inline function _generic_mcombine_impl_stage1(::typeof(tuple), α::AbstractMeasure, β::AbstractMeasure) productmeasure((α, β)) end -function mcombine(f_c::Union{typeof(vcat),typeof(merge)}, α::AbstractProductMeasure, β::AbstractProductMeasure) +@inline function _generic_mcombine_impl_stage1(f_c::Union{typeof(vcat),typeof(merge)}, α::AbstractProductMeasure, β::AbstractProductMeasure) productmeasure(f_c(marginals(α), marginals(β))) end From 498e9676cd64f934c57421a98a0d680a1b82445d Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sun, 16 Jul 2023 11:30:31 +0200 Subject: [PATCH 085/133] STASH smart ctors, canonical measure nesting --- src/combinators/smart-constructors.jl | 56 ++++++++++++++++----------- 1 file changed, 34 insertions(+), 22 deletions(-) diff --git a/src/combinators/smart-constructors.jl b/src/combinators/smart-constructors.jl index 19377cef..e1b5617f 100644 --- a/src/combinators/smart-constructors.jl +++ b/src/combinators/smart-constructors.jl @@ -1,4 +1,9 @@ +# Canonical measure type nesting, outer to inner: +# +# WeightedMeasure, Dirac, PowerMeasure, ProductMeasure + + ############################################################################### # Half @@ -19,25 +24,27 @@ Constructs a power of a measure `μ`. function powermeasure end export powermeasure -powermeasure(m::AbstractMeasure, ::Tuple{}) = asmeasure(m) +@inline powermeasure(μ, exponent) = _generic_powermeasure_impl(asmeasure(μ), _pm_axes(exponent)) + +@inline _generic_powermeasure_stage1(μ::AbstractMeasure, ::Tuple{}) = μ -@inline function powermeasure(x::T, sz::Tuple{Vararg{Any,N}}) where {T,N} - PowerMeasure(asmeasure(x), _pm_axes(sz)) +@inline function _generic_powermeasure_stage1(μ::AbstractMeasure, exponent::Tuple) + _generic_powermeasure_stage2(μ, exponent) end -function powermeasure( - μ::WeightedMeasure, - dims::Tuple{<:AbstractArray,Vararg{AbstractArray}}, -) - k = mapreduce(length, *, dims) * μ.logweight - return weightedmeasure(k, μ.base^dims) +@inline _generic_powermeasure_stage2(μ::AbstractMeasure, exponent::Tuple) = PowerMeasure(μ, exponent) + +@inline function _generic_powermeasure_stage2(μ::Dirac, exponent::Tuple) + Dirac(maybestatic_fill(μ.value, exponent)) end -function powermeasure(μ::WeightedMeasure, dims::NonEmptyTuple) - k = prod(dims) * μ.logweight - return weightedmeasure(k, μ.base^dims) +@inline function _generic_powermeasure_stage2(μ::WeightedMeasure, exponent::Tuple) + ν = μ.base^exponent + k = maybestatic_length(ν) * μ.logweight + return weightedmeasure(k, ν) end + ############################################################################### # ProductMeasure @@ -58,25 +65,30 @@ productmeasure((pushfwd(Mul(scale), StdExponential()) for scale in 0.1:0.2:2)) function productmeasure end export productmeasure -productmeasure(mar::Fill) = powermeasure(_fill_value(mar), _fill_axes(mar)) +@inline productmeasure(mar) = _generic_procuctmeasure_impl(mar) + +@inline _generic_procuctmeasure_impl(mar::Fill) = powermeasure(_fill_value(mar), _fill_axes(mar)) -productmeasure(mar::Tuple{Vararg{AbstractMeasure}}) = ProductMeasure(mar) -productmeasure(mar::Tuple) = ProductMeasure(map(asmeasure, mar)) +@inline _generic_procuctmeasure_impl(mar::Tuple{Vararg{AbstractMeasure}}) = ProductMeasure(mar) +_generic_procuctmeasure_impl(mar::Tuple{Vararg{Dirac}}) = Dirac(map(m -> m.value), mar) +_generic_procuctmeasure_impl(mar::Tuple) = productmeasure(map(asmeasure, mar)) -productmeasure(mar::NamedTuple{names,<:Tuple{Vararg{AbstractMeasure}}}) where names = ProductMeasure(mar) -productmeasure(mar::NamedTuple) = ProductMeasure(map(asmeasure, mar)) +@inline _generic_procuctmeasure_impl(mar::NamedTuple{names,<:Tuple{Vararg{AbstractMeasure}}}) where names = ProductMeasure(mar) +_generic_procuctmeasure_impl(mar::NamedTuple{names,<:Tuple{Vararg{Dirac}}}) where names = Dirac(map(m -> m.value), mar) +_generic_procuctmeasure_impl(mar::NamedTuple) = productmeasure(map(asmeasure, mar)) -productmeasure(mar::AbstractArray{<:AbstractProductMeasure}) = ProductMeasure(mar) -productmeasure(mar::AbstractArray) = ProductMeasure(asmeasure.(mar)) +@inline _generic_procuctmeasure_impl(mar::AbstractArray{<:AbstractProductMeasure}) = ProductMeasure(mar) +_generic_procuctmeasure_impl(mar::AbstractArray{<:Dirac}) = Dirac((m -> m.value).(mar)) +_generic_procuctmeasure_impl(mar::AbstractArray) = ProductMeasure(asmeasure.(mar)) -function productmeasure(mar::ReadonlyMappedArray{T,N,A,Returns{M}}) where {T,N,A,M} +@inline function _generic_procuctmeasure_impl(mar::ReadonlyMappedArray{T,N,A,Returns{M}}) where {T,N,A,M} return powermeasure(mar.f.value, axes(mar.data)) end -productmeasure(mar::Base.Generator) = ProductMeasure(mar) +@inline _generic_procuctmeasure_impl(mar::Base.Generator) = ProductMeasure(mar) # TODO: Make this static when its length is static -@inline function productmeasure( +@inline function _generic_procuctmeasure_impl( mar::AbstractArray{<:WeightedMeasure{StaticFloat64{W},M}}, ) where {W,M} return weightedmeasure(W * length(mar), productmeasure(map(basemeasure, mar))) From 04c9a767fc77b89f2d2d9e97596203c2a98b5fb9 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Mon, 17 Jul 2023 15:33:26 +0200 Subject: [PATCH 086/133] STASH Change pushfwd/pullback specialialization for DensityMeasure --- src/combinators/transformedmeasure.jl | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/src/combinators/transformedmeasure.jl b/src/combinators/transformedmeasure.jl index 6377beee..ba31c05f 100644 --- a/src/combinators/transformedmeasure.jl +++ b/src/combinators/transformedmeasure.jl @@ -147,8 +147,11 @@ function _pushfwd_of_pushfwd(f, μ::PushforwardMeasure, _, volcorr) PushforwardMeasure(f, inverse(f), μ, volcorr) end -function _generic_pushfwd_impl(f, μ::DensityMeasure, volcorr::TransformVolCorr = WithVolCorr()) - mintegrate(fchain(μ.f) ∘ inverse(f), pushfwd(f, μ.base, volcorr)) +function _generic_pushfwd_impl(f::TransportFunction{NU,MU}, μ::DensityMeasure{F,MU}, ::WithVolCorr) where {NU,MU,F} + if !(f.μ === μ.base || f.μ === μ.base) + throw(ArgumentError("pushfwd on DensityMeasure with TransportFunction of same source measure type as the density base requires base and source to be equal.")) + end + mintegrate(fchain(μ.f) ∘ inverse(f), f.ν) end @@ -192,6 +195,9 @@ function _pullbck_of_pushfwd(f, μ::PushforwardMeasure, _, volcorr) PushforwardMeasure(inverse(f), f, μ, volcorr) end -function _generic_pullbck_impl(f, μ::DensityMeasure, volcorr::TransformVolCorr = WithVolCorr()) - mintegrate(fchain(μ.f) ∘ f, pullbck(f, μ.base, volcorr)) +function _generic_pullbck_impl(f::TransportFunction{NU,MU}, μ::DensityMeasure{F,NU}, ::WithVolCorr) where {NU,MU,F} + if !(f.ν === μ.base || f.ν === μ.base) + throw(ArgumentError("pushfwd on DensityMeasure with TransportFunction of same destination measure type as the density base requires base and destination to be equal.")) + end + mintegrate(fchain(μ.f) ∘ f, f.μ) end From 2f8c2c6597cc0abd3f3fd438ceb9f4f0d2ab6e1e Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Mon, 17 Jul 2023 15:38:02 +0200 Subject: [PATCH 087/133] STASH Allow single-arg mbind and friends --- src/combinators/bind.jl | 2 ++ src/combinators/transformedmeasure.jl | 17 +++++++++++++---- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/src/combinators/bind.jl b/src/combinators/bind.jl index 8d3b7b1a..5551a5c9 100644 --- a/src/combinators/bind.jl +++ b/src/combinators/bind.jl @@ -81,6 +81,8 @@ logdensityof(posterior, θ) function mbind end export mbind +@inline mbind(f_β) = Base.Fix1(mbind, f_β) + @inline mbind(f_β, α::AbstractMeasure, f_c = getsecond) = _generic_mbind_impl(f_β, α, f_c) @inline function _generic_mbind_impl(f_β, α::AbstractMeasure, f_c) diff --git a/src/combinators/transformedmeasure.jl b/src/combinators/transformedmeasure.jl index ba31c05f..2aedeb00 100644 --- a/src/combinators/transformedmeasure.jl +++ b/src/combinators/transformedmeasure.jl @@ -116,8 +116,6 @@ end ############################################################################### # pushfwd -export pushfwd - """ pushfwd(f, μ, volcorr = WithVolCorr()) @@ -128,6 +126,11 @@ measure](https://en.wikipedia.org/wiki/Pushforward_measure) from `μ` the To manually specify an inverse, call `pushfwd(InverseFunctions.setinverse(f, finv), μ, volcorr)`. """ +function pushfwd end +export pushfwd + +pushfwd(f) = Base.Fix1(pushfwd, f) + pushfwd(f, μ, volcorr::TransformVolCorr = WithVolCorr()) = _generic_pullbck_impl(f, μ, volcorr) function _generic_pushfwd_impl(f, μ, volcorr::TransformVolCorr = WithVolCorr()) @@ -173,10 +176,13 @@ some cases, we may be focusing on log-density (and not, for example, sampling). To manually specify an inverse, call `pullbck(InverseFunctions.setinverse(f, finv), μ, volcorr)`. """ -pullbck(f, μ, volcorr::TransformVolCorr = WithVolCorr()) = _generic_pullbck_impl(f, μ, volcorr) + +function pullbck end export pullbck -@deprecate pullback(f, μ, volcorr::TransformVolCorr = WithVolCorr()) pullbck(f, μ, volcorr) +pullbck(f) = Base.Fix1(pullbck, f) + +pullbck(f, μ, volcorr::TransformVolCorr = WithVolCorr()) = _generic_pullbck_impl(f, μ, volcorr) function _generic_pullbck_impl(f, μ, volcorr::TransformVolCorr = WithVolCorr()) PushforwardMeasure(inverse(f), f, μ, volcorr) @@ -201,3 +207,6 @@ function _generic_pullbck_impl(f::TransportFunction{NU,MU}, μ::DensityMeasure{F end mintegrate(fchain(μ.f) ∘ f, f.μ) end + + +@deprecate pullback(f, μ, volcorr::TransformVolCorr = WithVolCorr()) pullbck(f, μ, volcorr) From 09067d7824b7b0a876c51fd53893344521385a38 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Mon, 17 Jul 2023 15:41:05 +0200 Subject: [PATCH 088/133] Checking insupport in PushforwardMeasure would be to expensive --- src/combinators/transformedmeasure.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/combinators/transformedmeasure.jl b/src/combinators/transformedmeasure.jl index 2aedeb00..976f2c34 100644 --- a/src/combinators/transformedmeasure.jl +++ b/src/combinators/transformedmeasure.jl @@ -80,7 +80,8 @@ end return logdensity_def(ν.origin, x) end -insupport(ν::PushforwardMeasure, y) = insupport(ν.origin, ν.finv(y)) +# ToDo: How to handle this better? +insupport(ν::PushforwardMeasure, y) = NoFastInsupport{typeof(ν)}() function testvalue(::Type{T}, ν::PushforwardMeasure) where {T} ν.f(testvalue(T, parent(ν))) From e0f9980b81f82f846833ac519743182d314cc473 Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Tue, 29 Aug 2023 08:43:21 -0700 Subject: [PATCH 089/133] comment out duplicate --- src/combinators/transformedmeasure.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/combinators/transformedmeasure.jl b/src/combinators/transformedmeasure.jl index 976f2c34..672fed9f 100644 --- a/src/combinators/transformedmeasure.jl +++ b/src/combinators/transformedmeasure.jl @@ -189,9 +189,10 @@ function _generic_pullbck_impl(f, μ, volcorr::TransformVolCorr = WithVolCorr()) PushforwardMeasure(inverse(f), f, μ, volcorr) end -function _generic_pushfwd_impl(f, μ::PushforwardMeasure, volcorr::TransformVolCorr = WithVolCorr()) - _pullbck_of_pushfwd(f, μ, μ.volcorr, volcorr) -end +# TODO: Duplicated method - was this supposed to be `_generic_pullbck_impl`? +# function _generic_pushfwd_impl(f, μ::PushforwardMeasure, volcorr::TransformVolCorr = WithVolCorr()) +# _pullbck_of_pushfwd(f, μ, μ.volcorr, volcorr) +# end # Either both WithVolCorr or both NoVolCorr, so we can merge them function _pullbck_of_pushfwd(f, μ::PushforwardMeasure, ::V, volcorr::V) where {V} From 5d4f4583c47e72b89be84bbfa8bc101a285c433f Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Tue, 29 Aug 2023 08:43:47 -0700 Subject: [PATCH 090/133] fix typo --- src/combinators/product_transport.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/combinators/product_transport.jl b/src/combinators/product_transport.jl index 3f76a967..0a435edd 100644 --- a/src/combinators/product_transport.jl +++ b/src/combinators/product_transport.jl @@ -2,7 +2,7 @@ transport_to(ν, ::Type{MU}) where {NU<:StdMeasure} transport_to(::Type{NU}, μ) where {NU<:StdMeasure} -As a user convencience, a standard measure type like [`StdUniform`](@ref), +As a user convenience, a standard measure type like [`StdUniform`](@ref), [`StdExponential`](@ref), [`StdNormal`](@ref) or [`StdLogistic`](@ref) may be used directly as the source or target a measure transport. From 7438bc84adacdba0c64c2c97979b71be2363373e Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Tue, 29 Aug 2023 10:18:10 -0700 Subject: [PATCH 091/133] drop 2-argument `basemeasure` --- src/density-core.jl | 2 +- src/interface.jl | 2 +- src/utils.jl | 2 -- 3 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/density-core.jl b/src/density-core.jl index 4a61e940..bcd2e51e 100644 --- a/src/density-core.jl +++ b/src/density-core.jl @@ -101,7 +101,7 @@ end ℓ_0 = logdensity_def(μ, x) b_0 = μ Base.Cartesian.@nexprs 10 i -> begin # 10 is just some "big enough" number - b_{i} = basemeasure(b_{i - 1}, x) + b_{i} = basemeasure(b_{i - 1}) # The below makes the evaluated code shorter, but screws up Zygote # if b_{i} isa typeof(b_{i - 1}) diff --git a/src/interface.jl b/src/interface.jl index 18080ac7..f66c9893 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -64,7 +64,7 @@ function test_interface(μ::M) where {M} # testvalue, logdensityof x = @inferred testvalue(Float64, μ) - β = @inferred basemeasure(μ, x) + β = @inferred basemeasure(μ) ℓμ = @inferred logdensityof(μ, x) ℓβ = @inferred logdensityof(β, x) diff --git a/src/utils.jl b/src/utils.jl index 4a0c79a6..33fa9f96 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -19,8 +19,6 @@ testvalue(::Type{T}) where {T} = zero(T) export rootmeasure -basemeasure(μ, x) = basemeasure(μ) - """ rootmeasure(μ::AbstractMeasure) From ae2eeb5d6192f0b7a2eeb663e072d6ff898270a8 Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Tue, 29 Aug 2023 10:51:13 -0700 Subject: [PATCH 092/133] bugfixes --- src/combinators/smart-constructors.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/combinators/smart-constructors.jl b/src/combinators/smart-constructors.jl index e1b5617f..b3b741de 100644 --- a/src/combinators/smart-constructors.jl +++ b/src/combinators/smart-constructors.jl @@ -24,7 +24,7 @@ Constructs a power of a measure `μ`. function powermeasure end export powermeasure -@inline powermeasure(μ, exponent) = _generic_powermeasure_impl(asmeasure(μ), _pm_axes(exponent)) +@inline powermeasure(μ, exponent) = _generic_powermeasure_stage1(asmeasure(μ), _pm_axes(exponent)) @inline _generic_powermeasure_stage1(μ::AbstractMeasure, ::Tuple{}) = μ @@ -35,7 +35,7 @@ end @inline _generic_powermeasure_stage2(μ::AbstractMeasure, exponent::Tuple) = PowerMeasure(μ, exponent) @inline function _generic_powermeasure_stage2(μ::Dirac, exponent::Tuple) - Dirac(maybestatic_fill(μ.value, exponent)) + Dirac(maybestatic_fill(μ.x, exponent)) end @inline function _generic_powermeasure_stage2(μ::WeightedMeasure, exponent::Tuple) From af0c68f90a7fa63f606f3a69bf9975a29bc9c8a9 Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Tue, 29 Aug 2023 13:30:23 -0700 Subject: [PATCH 093/133] comment out questionable code --- src/combinators/smart-constructors.jl | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/combinators/smart-constructors.jl b/src/combinators/smart-constructors.jl index b3b741de..ea38f970 100644 --- a/src/combinators/smart-constructors.jl +++ b/src/combinators/smart-constructors.jl @@ -70,16 +70,18 @@ export productmeasure @inline _generic_procuctmeasure_impl(mar::Fill) = powermeasure(_fill_value(mar), _fill_axes(mar)) @inline _generic_procuctmeasure_impl(mar::Tuple{Vararg{AbstractMeasure}}) = ProductMeasure(mar) -_generic_procuctmeasure_impl(mar::Tuple{Vararg{Dirac}}) = Dirac(map(m -> m.value), mar) +_generic_procuctmeasure_impl(mar::Tuple{Vararg{Dirac}}) = Dirac(map(m -> m.x), mar) _generic_procuctmeasure_impl(mar::Tuple) = productmeasure(map(asmeasure, mar)) @inline _generic_procuctmeasure_impl(mar::NamedTuple{names,<:Tuple{Vararg{AbstractMeasure}}}) where names = ProductMeasure(mar) -_generic_procuctmeasure_impl(mar::NamedTuple{names,<:Tuple{Vararg{Dirac}}}) where names = Dirac(map(m -> m.value), mar) +_generic_procuctmeasure_impl(mar::NamedTuple{names,<:Tuple{Vararg{Dirac}}}) where names = Dirac(map(m -> m.x), mar) _generic_procuctmeasure_impl(mar::NamedTuple) = productmeasure(map(asmeasure, mar)) @inline _generic_procuctmeasure_impl(mar::AbstractArray{<:AbstractProductMeasure}) = ProductMeasure(mar) -_generic_procuctmeasure_impl(mar::AbstractArray{<:Dirac}) = Dirac((m -> m.value).(mar)) -_generic_procuctmeasure_impl(mar::AbstractArray) = ProductMeasure(asmeasure.(mar)) + +# TODO: These methods don't make sense. What are they supposed to do? +# _generic_procuctmeasure_impl(mar::AbstractArray{<:Dirac}) = Dirac((m -> m.value).(mar)) +# _generic_procuctmeasure_impl(mar::AbstractArray) = ProductMeasure(asmeasure.(mar)) @inline function _generic_procuctmeasure_impl(mar::ReadonlyMappedArray{T,N,A,Returns{M}}) where {T,N,A,M} return powermeasure(mar.f.value, axes(mar.data)) From f6a253986a16a1e913220985b9b008cd4c1b1b35 Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Tue, 29 Aug 2023 14:14:28 -0700 Subject: [PATCH 094/133] oops my mistake --- src/combinators/smart-constructors.jl | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/combinators/smart-constructors.jl b/src/combinators/smart-constructors.jl index ea38f970..9b339e9c 100644 --- a/src/combinators/smart-constructors.jl +++ b/src/combinators/smart-constructors.jl @@ -79,9 +79,8 @@ _generic_procuctmeasure_impl(mar::NamedTuple) = productmeasure(map(asmeasure, ma @inline _generic_procuctmeasure_impl(mar::AbstractArray{<:AbstractProductMeasure}) = ProductMeasure(mar) -# TODO: These methods don't make sense. What are they supposed to do? -# _generic_procuctmeasure_impl(mar::AbstractArray{<:Dirac}) = Dirac((m -> m.value).(mar)) -# _generic_procuctmeasure_impl(mar::AbstractArray) = ProductMeasure(asmeasure.(mar)) +_generic_procuctmeasure_impl(mar::AbstractArray{<:Dirac}) = Dirac((m -> m.value).(mar)) +_generic_procuctmeasure_impl(mar::AbstractArray) = ProductMeasure(asmeasure.(mar)) @inline function _generic_procuctmeasure_impl(mar::ReadonlyMappedArray{T,N,A,Returns{M}}) where {T,N,A,M} return powermeasure(mar.f.value, axes(mar.data)) From 7a48e344fd5fd2678c04328d624d92737ecc7a73 Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Fri, 1 Sep 2023 11:15:59 -0700 Subject: [PATCH 095/133] depend on ConstantRNGs.jl --- Project.toml | 2 ++ src/MeasureBase.jl | 1 + src/fixedrng.jl | 20 -------------------- src/utils.jl | 4 ++-- 4 files changed, 5 insertions(+), 22 deletions(-) delete mode 100644 src/fixedrng.jl diff --git a/Project.toml b/Project.toml index 1444d339..4f2320fe 100644 --- a/Project.toml +++ b/Project.toml @@ -8,6 +8,7 @@ ArraysOfArrays = "65a8f2f4-9b39-5baf-92e2-a9cc46fdf018" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" +ConstantRNGs = "aa9b60e7-6b1c-4c29-a6e5-e43521412437" ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" @@ -37,6 +38,7 @@ ArraysOfArrays = "0.6" ChainRulesCore = "1" ChangesOfVariables = "0.1.3" Compat = "3.35, 4" +ConstantRNGs = "0.1" ConstructionBase = "1.3" DensityInterface = "0.4" FillArrays = "0.12, 0.13, 1" diff --git a/src/MeasureBase.jl b/src/MeasureBase.jl index f476e1e4..3f6cd047 100644 --- a/src/MeasureBase.jl +++ b/src/MeasureBase.jl @@ -8,6 +8,7 @@ import Random: gentype using Statistics using LinearAlgebra +using ConstantRNGs import IntervalSets # This seems harder than it should be to get `IntervalSets.:(..)` @eval (using IntervalSets: $(Symbol(IntervalSets.:(..)))) diff --git a/src/fixedrng.jl b/src/fixedrng.jl deleted file mode 100644 index 31991418..00000000 --- a/src/fixedrng.jl +++ /dev/null @@ -1,20 +0,0 @@ -export FixedRNG -struct FixedRNG <: AbstractRNG end - -Base.rand(::FixedRNG) = one(Float64) / 2 -Random.randn(::FixedRNG) = zero(Float64) -Random.randexp(::FixedRNG) = one(Float64) - -# Use Random.BitFloatType instead of Real to avoid ambiguities: -Base.rand(::FixedRNG, ::Type{T}) where {T<:Random.BitFloatType} = one(T) / 2 -Random.randn(::FixedRNG, ::Type{T}) where {T<:Random.BitFloatType} = zero(T) -Random.randexp(::FixedRNG, ::Type{T}) where {T<:Random.BitFloatType} = one(T) - -# We need concrete type parameters to avoid amiguity for these cases -for T in [Float16, Float32, Float64] - @eval begin - Base.rand(::FixedRNG, ::Type{$T}) = one($T) / 2 - Random.randn(::FixedRNG, ::Type{$T}) = zero($T) - Random.randexp(::FixedRNG, ::Type{$T}) = one($T) - end -end diff --git a/src/utils.jl b/src/utils.jl index 33fa9f96..1d51be7d 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -11,9 +11,9 @@ showparams(io::IO, nt::NamedTuple) = print(io, nt) export testvalue -@inline testvalue(μ) = rand(FixedRNG(), μ) +@inline testvalue(μ) = rand(ConstantRNG(), μ) -@inline testvalue(::Type{T}, μ) where {T} = rand(FixedRNG(), T, μ) +@inline testvalue(::Type{T}, μ) where {T} = rand(ConstantRNG(), T, μ) testvalue(::Type{T}) where {T} = zero(T) From 865797a414869cc1c625f523218f341613f91ef9 Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Tue, 5 Sep 2023 14:59:37 -0700 Subject: [PATCH 096/133] fix typo --- src/combinators/smart-constructors.jl | 28 +++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/src/combinators/smart-constructors.jl b/src/combinators/smart-constructors.jl index 9b339e9c..d3ff5eb6 100644 --- a/src/combinators/smart-constructors.jl +++ b/src/combinators/smart-constructors.jl @@ -65,31 +65,31 @@ productmeasure((pushfwd(Mul(scale), StdExponential()) for scale in 0.1:0.2:2)) function productmeasure end export productmeasure -@inline productmeasure(mar) = _generic_procuctmeasure_impl(mar) +@inline productmeasure(mar) = _generic_productmeasure_impl(mar) -@inline _generic_procuctmeasure_impl(mar::Fill) = powermeasure(_fill_value(mar), _fill_axes(mar)) +@inline _generic_productmeasure_impl(mar::Fill) = powermeasure(_fill_value(mar), _fill_axes(mar)) -@inline _generic_procuctmeasure_impl(mar::Tuple{Vararg{AbstractMeasure}}) = ProductMeasure(mar) -_generic_procuctmeasure_impl(mar::Tuple{Vararg{Dirac}}) = Dirac(map(m -> m.x), mar) -_generic_procuctmeasure_impl(mar::Tuple) = productmeasure(map(asmeasure, mar)) +@inline _generic_productmeasure_impl(mar::Tuple{Vararg{AbstractMeasure}}) = ProductMeasure(mar) +_generic_productmeasure_impl(mar::Tuple{Vararg{Dirac}}) = Dirac(map(m -> m.x), mar) +_generic_productmeasure_impl(mar::Tuple) = productmeasure(map(asmeasure, mar)) -@inline _generic_procuctmeasure_impl(mar::NamedTuple{names,<:Tuple{Vararg{AbstractMeasure}}}) where names = ProductMeasure(mar) -_generic_procuctmeasure_impl(mar::NamedTuple{names,<:Tuple{Vararg{Dirac}}}) where names = Dirac(map(m -> m.x), mar) -_generic_procuctmeasure_impl(mar::NamedTuple) = productmeasure(map(asmeasure, mar)) +@inline _generic_productmeasure_impl(mar::NamedTuple{names,<:Tuple{Vararg{AbstractMeasure}}}) where names = ProductMeasure(mar) +_generic_productmeasure_impl(mar::NamedTuple{names,<:Tuple{Vararg{Dirac}}}) where names = Dirac(map(m -> m.x), mar) +_generic_productmeasure_impl(mar::NamedTuple) = productmeasure(map(asmeasure, mar)) -@inline _generic_procuctmeasure_impl(mar::AbstractArray{<:AbstractProductMeasure}) = ProductMeasure(mar) +@inline _generic_productmeasure_impl(mar::AbstractArray{<:AbstractProductMeasure}) = ProductMeasure(mar) -_generic_procuctmeasure_impl(mar::AbstractArray{<:Dirac}) = Dirac((m -> m.value).(mar)) -_generic_procuctmeasure_impl(mar::AbstractArray) = ProductMeasure(asmeasure.(mar)) +_generic_productmeasure_impl(mar::AbstractArray{<:Dirac}) = Dirac((m -> m.value).(mar)) +_generic_productmeasure_impl(mar::AbstractArray) = ProductMeasure(asmeasure.(mar)) -@inline function _generic_procuctmeasure_impl(mar::ReadonlyMappedArray{T,N,A,Returns{M}}) where {T,N,A,M} +@inline function _generic_productmeasure_impl(mar::ReadonlyMappedArray{T,N,A,Returns{M}}) where {T,N,A,M} return powermeasure(mar.f.value, axes(mar.data)) end -@inline _generic_procuctmeasure_impl(mar::Base.Generator) = ProductMeasure(mar) +@inline _generic_productmeasure_impl(mar::Base.Generator) = ProductMeasure(mar) # TODO: Make this static when its length is static -@inline function _generic_procuctmeasure_impl( +@inline function _generic_productmeasure_impl( mar::AbstractArray{<:WeightedMeasure{StaticFloat64{W},M}}, ) where {W,M} return weightedmeasure(W * length(mar), productmeasure(map(basemeasure, mar))) From 6746e47771a58f0b0de8a09ca943551046c54b32 Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Wed, 6 Sep 2023 13:35:02 -0700 Subject: [PATCH 097/133] drop `fixedrng.jl` (using ConstantRNGs.jl instead) --- src/MeasureBase.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/MeasureBase.jl b/src/MeasureBase.jl index 3f6cd047..8da912af 100644 --- a/src/MeasureBase.jl +++ b/src/MeasureBase.jl @@ -172,7 +172,6 @@ include("combinators/conditional.jl") include("combinators/half.jl") include("rand.jl") -include("fixedrng.jl") include("interface.jl") From 7a5f4e14ac10503dff46a2c7ecd3c15eecb05d56 Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Wed, 6 Sep 2023 13:38:27 -0700 Subject: [PATCH 098/133] Optimize product measures when `Base.issingltetontype` --- src/combinators/smart-constructors.jl | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/combinators/smart-constructors.jl b/src/combinators/smart-constructors.jl index d3ff5eb6..edd9f6f5 100644 --- a/src/combinators/smart-constructors.jl +++ b/src/combinators/smart-constructors.jl @@ -80,7 +80,15 @@ _generic_productmeasure_impl(mar::NamedTuple) = productmeasure(map(asmeasure, ma @inline _generic_productmeasure_impl(mar::AbstractArray{<:AbstractProductMeasure}) = ProductMeasure(mar) _generic_productmeasure_impl(mar::AbstractArray{<:Dirac}) = Dirac((m -> m.value).(mar)) -_generic_productmeasure_impl(mar::AbstractArray) = ProductMeasure(asmeasure.(mar)) + +# TODO: We should be able to further optimize this +function _generic_productmeasure_impl(mar::AbstractArray{T}) where {T} + if Base.issingletontype(T) + first(mar) ^ size(mar) + else + ProductMeasure(asmeasure.(mar)) + end +end @inline function _generic_productmeasure_impl(mar::ReadonlyMappedArray{T,N,A,Returns{M}}) where {T,N,A,M} return powermeasure(mar.f.value, axes(mar.data)) From 60509bde6a393c830f87893396ceb27ab6f41007 Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Wed, 6 Sep 2023 14:03:15 -0700 Subject: [PATCH 099/133] small change to make JET happy --- src/domains.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/domains.jl b/src/domains.jl index 9579780b..9c1f1a21 100644 --- a/src/domains.jl +++ b/src/domains.jl @@ -106,7 +106,7 @@ function tangentat( one(T) - Statistics.corm(g1, zero(T), g2, zero(T)) < tol end -function zeroset(::CodimOne)::ZeroSet end +function zeroset(::CodimOne) end ########################################################### # Simplex From adc729dc710dca0449e5f8d9ed4602554066572b Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Thu, 7 Sep 2023 09:53:02 -0700 Subject: [PATCH 100/133] bugfix --- src/combinators/transformedmeasure.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/combinators/transformedmeasure.jl b/src/combinators/transformedmeasure.jl index 672fed9f..17d19497 100644 --- a/src/combinators/transformedmeasure.jl +++ b/src/combinators/transformedmeasure.jl @@ -132,7 +132,7 @@ export pushfwd pushfwd(f) = Base.Fix1(pushfwd, f) -pushfwd(f, μ, volcorr::TransformVolCorr = WithVolCorr()) = _generic_pullbck_impl(f, μ, volcorr) +pushfwd(f, μ, volcorr::TransformVolCorr = WithVolCorr()) = _generic_pushfwd_impl(f, μ, volcorr) function _generic_pushfwd_impl(f, μ, volcorr::TransformVolCorr = WithVolCorr()) PushforwardMeasure(f, inverse(f), μ, volcorr) From a0ace73b4572ca4e219acfb767dc2a77b9ec1470 Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Wed, 20 Sep 2023 11:28:17 -0700 Subject: [PATCH 101/133] bugfix --- src/combinators/product_transport.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/combinators/product_transport.jl b/src/combinators/product_transport.jl index 0a435edd..fb1ff908 100644 --- a/src/combinators/product_transport.jl +++ b/src/combinators/product_transport.jl @@ -63,7 +63,7 @@ end # Transport between univariate standard measures and 1-dim power measures of size one: -function transport_def(ν::StdMeasure, μ::PowerMeasure{<:StdMeasure,1}, x) +function transport_def(ν::StdMeasure, μ::StdPowerMeasure{<:StdMeasure,1}, x) return transport_def(ν, μ.parent, only(x)) end @@ -110,7 +110,7 @@ end # Transport from a multivariate standard measure to any measure: -function transport_def(ν::AbstractMeasure, μ::StdPowerMeasure{MU,1}, x) where MU +function transport_def(ν::AbstractMeasure, μ::StdPowerMeasure{<:StdMeasure,1}, x) μ_inner = pwr_base(μ) _transport_from_mvstd(ν, μ_inner, x) end From b95b1a1510244653e9637fd561c9d1aeb1aa0b8a Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Wed, 20 Sep 2023 12:17:48 -0700 Subject: [PATCH 102/133] bugfix --- src/combinators/product_transport.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/combinators/product_transport.jl b/src/combinators/product_transport.jl index fb1ff908..3b5e24fd 100644 --- a/src/combinators/product_transport.jl +++ b/src/combinators/product_transport.jl @@ -73,7 +73,7 @@ function transport_def(ν::StdPowerMeasure{<:StdMeasure,1}, μ::StdMeasure, x) return maybestatic_fill(transport_def(ν.parent, μ, x), sz_ν) end -function transport_def(ν::StdPowerMeasure{MU,1}, μ::StdPowerMeasure{NU,1}, x,) where {MU,NU} +function transport_def(ν::StdPowerMeasure{<:StdMeasure,1}, μ::StdPowerMeasure{<:StdMeasure,1}, x,) return transport_to(ν.parent, μ.parent).(x) end From 303f8f162392ef0b2204ac8a2051f0ea5f483a3e Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Wed, 20 Sep 2023 12:46:50 -0700 Subject: [PATCH 103/133] make _dynamic work for empty ranges --- src/static.jl | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/static.jl b/src/static.jl index 2529b093..f8acfa90 100644 --- a/src/static.jl +++ b/src/static.jl @@ -43,7 +43,14 @@ on the type of `n`. _dynamic(x::Number) = dynamic(x) _dynamic(::Static.SOneTo{N}) where {N} = Base.OneTo(N) -_dynamic(r::AbstractUnitRange) = minimum(r):maximum(r) + +function _dynamic(r::AbstractUnitRange) + if isempty(r) + Base.OneTo(0) + else + minimum(r):maximum(r) + end +end """ MeasureBase.maybestatic_fill(x, sz::NTuple{N,<:IntegerLike}) where N From 999c39eb3de5892fabe7ebb197bee2bb84737b5a Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Wed, 20 Sep 2023 12:49:16 -0700 Subject: [PATCH 104/133] fix type instability --- src/static.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/static.jl b/src/static.jl index f8acfa90..3274a53e 100644 --- a/src/static.jl +++ b/src/static.jl @@ -46,7 +46,7 @@ _dynamic(::Static.SOneTo{N}) where {N} = Base.OneTo(N) function _dynamic(r::AbstractUnitRange) if isempty(r) - Base.OneTo(0) + 1:0 else minimum(r):maximum(r) end From 8fe3b1403d40fe6c0d1cdb9dc194b6c90e97d0af Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Wed, 20 Sep 2023 12:52:36 -0700 Subject: [PATCH 105/133] specialize _dynamic(::Base.OneTo) --- src/static.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/static.jl b/src/static.jl index 3274a53e..3f77cf3a 100644 --- a/src/static.jl +++ b/src/static.jl @@ -43,6 +43,7 @@ on the type of `n`. _dynamic(x::Number) = dynamic(x) _dynamic(::Static.SOneTo{N}) where {N} = Base.OneTo(N) +_dynamic(r::Base.OneTo) = Base.OneTo(dynamic(r.stop)) function _dynamic(r::AbstractUnitRange) if isempty(r) From 29f2c21870b785820f79d95172e0df7697e87c1f Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Wed, 20 Sep 2023 13:35:19 -0700 Subject: [PATCH 106/133] Integer => IntegerLike --- src/transport.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transport.jl b/src/transport.jl index d11ce318..93506b6b 100644 --- a/src/transport.jl +++ b/src/transport.jl @@ -188,7 +188,7 @@ end end @inline _transport_intermediate(ν, μ) = _transport_intermediate(fast_dof(ν), fast_dof(μ)) -@inline _transport_intermediate(::Integer, n_μ::Integer) = StdUniform()^n_μ +@inline _transport_intermediate(::IntegerLike, n_μ::IntegerLike) = StdUniform()^n_μ @inline _transport_intermediate(::StaticInteger{1}, ::StaticInteger{1}) = StdUniform() _call_transport_def(ν, μ, x) = transport_def(ν, μ, x) From e495c04e492728c05c543da0a87720d964bf997b Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Wed, 20 Sep 2023 14:38:17 -0700 Subject: [PATCH 107/133] Simplify, drop _reorder_nt --- src/collection_utils.jl | 16 ---------------- src/combinators/product.jl | 6 ++---- src/combinators/product_transport.jl | 2 +- 3 files changed, 3 insertions(+), 21 deletions(-) diff --git a/src/collection_utils.jl b/src/collection_utils.jl index 5698174a..c974b018 100644 --- a/src/collection_utils.jl +++ b/src/collection_utils.jl @@ -66,22 +66,6 @@ InverseFunctions.inverse(::TupleUnNamer{names}) where names = TupleNamer{names}( ChangesOfVariables.with_logabsdet_jacobian(::TupleUnNamer{names}, x::NamedTuple{names}) where names = static(false) =# -_reorder_nt(x::NamedTuple{names},::Val{names}) where {names} = x - -@generated function _reorder_nt(x::NamedTuple{names},::Val{new_names}) where {names,new_names} - if sort([names...]) != sort([new_names...]) - :(throw(ArgumentError("Can't reorder NamedTuple{$names} to NamedTuple{$new_names}"))) - else - expr = :(()) - for nm in new_names - push!(expr.args, :($nm = x.$nm)) - end - return expr - end -end - -# ToDo: Add custom rrule for _reorder_nt? - # Field access functions for Fill: _fill_value(x::FillArrays.Fill) = x.value _fill_axes(x::FillArrays.Fill) = x.axes diff --git a/src/combinators/product.jl b/src/combinators/product.jl index 20084fd3..70c554d0 100644 --- a/src/combinators/product.jl +++ b/src/combinators/product.jl @@ -102,8 +102,7 @@ end sum(map(density_op, marginals_μ, x)) end @inline function _marginals_density_op(density_op::F, marginals_μ::NamedTuple{names}, x::NamedTuple) where {F,names} - nms = Val{names}() - _marginals_density_op(density_op, values(marginals_μ), values(_reorder_nt(x, Val(nms)))) + _marginals_density_op(density_op, values(marginals_μ), values(NamedTuple{names}(x))) end function _marginals_density_op(density_op::F, marginals_μ, marginals_ν, x) where F @@ -114,8 +113,7 @@ end sum(map(density_op, marginals_μ, marginals_ν, x)) end @inline function _marginals_density_op(density_op::F, marginals_μ::NamedTuple{names}, marginals_ν::NamedTuple, x::NamedTuple) where {F,names} - nms = Val{names}() - _marginals_density_op(density_op, values(marginals_μ), values(_reorder_nt(marginals_ν, nms)), values(_reorder_nt(x, nms))) + _marginals_density_op(density_op, values(marginals_μ), values(NamedTuple{names}(marginals_ν)), values(NamedTuple{names}(x))) end diff --git a/src/combinators/product_transport.jl b/src/combinators/product_transport.jl index 3b5e24fd..d81832ac 100644 --- a/src/combinators/product_transport.jl +++ b/src/combinators/product_transport.jl @@ -188,7 +188,7 @@ _marginals_from_origin(::Fill, x_origin) = x_origin # used with a (power of a) standard measure on one side. _marginals_tp_origin(marginals_μ::NamedTuple{names}) where names = productmeasure(values(marginals_μ)) -_marginals_from_origin(::NamedTuple{names}, x_origin::NamedTuple) where names = _reorder_nt(x_origin, Val(names)) +_marginals_from_origin(::NamedTuple{names}, x_origin::NamedTuple) where names = NamedTuple{names}(x_origin) # Transport between two instances of ProductMeasure: From d49f63f62614f70e0908d31869f6b6a15080ff6d Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Thu, 21 Sep 2023 09:03:55 -0700 Subject: [PATCH 108/133] fix tuple methods --- src/collection_utils.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/collection_utils.jl b/src/collection_utils.jl index c974b018..a9f0765c 100644 --- a/src/collection_utils.jl +++ b/src/collection_utils.jl @@ -77,5 +77,6 @@ _flatten_to_rv(VV::AbstractVector{<:StaticVector{N,<:Real}}) where N = flatview( _flatten_to_rv(VV::VectorOfSimilarVectors{<:Real}) = flatview(VV) _flatten_to_rv(VV::VectorOfVectors{<:Real}) = flatview(VV) -_flatten_to_rv(tpl::Tuple{<:AbstractVector{<:Real}}) = vcat(tpl...) -_flatten_to_rv(tpl::Tuple{<:StaticVector{N,<:Real}}) where N = vcat(tpl...) +_flatten_to_rv(::Tuple{}) = [] +_flatten_to_rv(tpl::Tuple{Vararg{AbstractVector}}) = vcat(tpl...) +_flatten_to_rv(tpl::Tuple{Vararg{StaticVector}}) = vcat(tpl...) From 115e0c669d1508f934649dcf9d6b6b080b2d0ed7 Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Thu, 21 Sep 2023 09:23:50 -0700 Subject: [PATCH 109/133] small bugfixes (issues caught by JET) --- src/combinators/smart-constructors.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/combinators/smart-constructors.jl b/src/combinators/smart-constructors.jl index edd9f6f5..e803896b 100644 --- a/src/combinators/smart-constructors.jl +++ b/src/combinators/smart-constructors.jl @@ -70,11 +70,11 @@ export productmeasure @inline _generic_productmeasure_impl(mar::Fill) = powermeasure(_fill_value(mar), _fill_axes(mar)) @inline _generic_productmeasure_impl(mar::Tuple{Vararg{AbstractMeasure}}) = ProductMeasure(mar) -_generic_productmeasure_impl(mar::Tuple{Vararg{Dirac}}) = Dirac(map(m -> m.x), mar) +_generic_productmeasure_impl(mar::Tuple{Vararg{Dirac}}) = Dirac(map(m -> m.x, mar)) _generic_productmeasure_impl(mar::Tuple) = productmeasure(map(asmeasure, mar)) @inline _generic_productmeasure_impl(mar::NamedTuple{names,<:Tuple{Vararg{AbstractMeasure}}}) where names = ProductMeasure(mar) -_generic_productmeasure_impl(mar::NamedTuple{names,<:Tuple{Vararg{Dirac}}}) where names = Dirac(map(m -> m.x), mar) +_generic_productmeasure_impl(mar::NamedTuple{names,<:Tuple{Vararg{Dirac}}}) where names = Dirac(map(m -> m.x, mar)) _generic_productmeasure_impl(mar::NamedTuple) = productmeasure(map(asmeasure, mar)) @inline _generic_productmeasure_impl(mar::AbstractArray{<:AbstractProductMeasure}) = ProductMeasure(mar) @@ -105,7 +105,7 @@ end # ToDo: Remove or at least refactor this (ProductMeasure shouldn't take a kernel at it's argument). -productmeasure(f, param_maps, pars) = ProductMeasure(kernel(f, param_maps), pars) +productmeasure(f, param_maps, pars) = productmeasure(kernel(f, param_maps), pars) function productmeasure(k::ParameterizedTransitionKernel, pars) productmeasure(k.suff, k.param_maps, pars) From 42ec20f3457ddd5a8ab224b8651d9218c5d37139 Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Thu, 21 Sep 2023 09:27:53 -0700 Subject: [PATCH 110/133] Change `...stage3` call to `...stage1` (there is no stage 3) --- src/combinators/combined.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/combinators/combined.jl b/src/combinators/combined.jl index b2ce1cb7..485bc12e 100644 --- a/src/combinators/combined.jl +++ b/src/combinators/combined.jl @@ -60,7 +60,7 @@ sets $$A$$ and $$B$$) function mcombine end export mcombine -@inline mcombine(f_c, α::AbstractMeasure, β::AbstractMeasure) = _generic_mcombine_impl_stage3(f_c, α, β) +@inline mcombine(f_c, α::AbstractMeasure, β::AbstractMeasure) = _generic_mcombine_impl_stage1(f_c, α, β) @inline function _generic_mcombine_impl_stage1(f_c, α::AbstractMeasure, β::AbstractMeasure) _generic_mcombine_impl_stage2(f_c, α, β) From 760faff489173f3cafa89ed6082543a7d95c8686 Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Thu, 21 Sep 2023 09:29:27 -0700 Subject: [PATCH 111/133] fix a Dirac call (`.value` => `.x`) --- src/combinators/combined.jl | 26 ++++++++++++-------------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/src/combinators/combined.jl b/src/combinators/combined.jl index 485bc12e..11ebc558 100644 --- a/src/combinators/combined.jl +++ b/src/combinators/combined.jl @@ -62,20 +62,6 @@ export mcombine @inline mcombine(f_c, α::AbstractMeasure, β::AbstractMeasure) = _generic_mcombine_impl_stage1(f_c, α, β) -@inline function _generic_mcombine_impl_stage1(f_c, α::AbstractMeasure, β::AbstractMeasure) - _generic_mcombine_impl_stage2(f_c, α, β) -end - -@inline function _generic_mcombine_impl_stage2(f_c, α::AbstractMeasure, β::AbstractMeasure) - FC, MA, MB = Core.Typeof(f_c), Core.Typeof(α), Core.Typeof(β) - CombinedMeasure{FC,MA,MB}(f_c, α, β) -end - -@inline function _generic_mcombine_impl_stage2(f_c, α::Dirac, β::Dirac) - Dirac(f_c(α.value, β.value)) -end - - @inline _generic_mcombine_impl_stage1(::typeof(first), α::AbstractMeasure, β::AbstractMeasure) = α @inline _generic_mcombine_impl_stage1(::typeof(getsecond), α::AbstractMeasure, β::AbstractMeasure) = β @inline _generic_mcombine_impl_stage1(::typeof(last), α::AbstractMeasure, β::AbstractMeasure) = β @@ -88,6 +74,18 @@ end productmeasure(f_c(marginals(α), marginals(β))) end +@inline function _generic_mcombine_impl_stage1(f_c, α::AbstractMeasure, β::AbstractMeasure) + _generic_mcombine_impl_stage2(f_c, α, β) +end + +@inline function _generic_mcombine_impl_stage2(f_c, α::AbstractMeasure, β::AbstractMeasure) + FC, MA, MB = Core.Typeof(f_c), Core.Typeof(α), Core.Typeof(β) + CombinedMeasure{FC,MA,MB}(f_c, α, β) +end + +@inline function _generic_mcombine_impl_stage2(f_c, α::Dirac, β::Dirac) + Dirac(f_c(α.x, β.x)) +end """ struct CombinedMeasure <: AbstractMeasure From 24964bbfdf40e162498530b7c2ffaa065a947bda Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Thu, 21 Sep 2023 09:37:39 -0700 Subject: [PATCH 112/133] Fixing another Dirac --- src/combinators/bind.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/combinators/bind.jl b/src/combinators/bind.jl index 5551a5c9..9af12c93 100644 --- a/src/combinators/bind.jl +++ b/src/combinators/bind.jl @@ -91,7 +91,7 @@ export mbind end function _generic_mbind_impl(f_β, α::Dirac, f_c) - mcombine(f_c, α, f_β(α.value)) + mcombine(f_c, α, f_β(α.x)) end From 886f927cd6339573df614128ba8347e21cb22e12 Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Thu, 21 Sep 2023 10:02:55 -0700 Subject: [PATCH 113/133] nit-picking --- src/combinators/product_transport.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/combinators/product_transport.jl b/src/combinators/product_transport.jl index d81832ac..ac8c1045 100644 --- a/src/combinators/product_transport.jl +++ b/src/combinators/product_transport.jl @@ -48,7 +48,7 @@ function from_origin(μ::PowerMeasure{<:Any,N}, x_origin) where N end -# A one-dimensional PowerMeasure has an origin if it's parent has an origin: +# A one-dimensional PowerMeasure has an origin if its parent has an origin: transport_origin(μ::PowerMeasure{<:AbstractMeasure,1}) = _pwr_origin(typeof(μ), pwr_base(μ), pwr_axes(μ)) _pwr_origin(::Type{MU}, parent_origin, axes) where MU = parent_origin^axes From 1c9f69e44cb1dfce76e61d3937465b6430d3841f Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Thu, 21 Sep 2023 14:43:10 -0700 Subject: [PATCH 114/133] specialize transports between StdPowerMeasures with the same base --- src/combinators/product_transport.jl | 28 +++++++++++++++++++++++----- 1 file changed, 23 insertions(+), 5 deletions(-) diff --git a/src/combinators/product_transport.jl b/src/combinators/product_transport.jl index ac8c1045..f87219cf 100644 --- a/src/combinators/product_transport.jl +++ b/src/combinators/product_transport.jl @@ -60,21 +60,39 @@ function from_origin(μ::PowerMeasure{<:AbstractMeasure,1}, x_origin) from_origin.(Ref(μ.parent), x_origin) end +# Specialize for case of equal bases. Because of StdPowerMeasure methods below +# specify `,1`, we need extra methods to avoid ambiguity + +function transport_def(ν::StdPowerMeasure{MU}, μ::StdPowerMeasure{MU}, x) where MU + reshape(x, ν.axes) +end + +function transport_def(ν::StdPowerMeasure{MU,1}, μ::StdPowerMeasure{MU}, x) where MU + reshape(x, ν.axes) +end + +function transport_def(ν::StdPowerMeasure{MU}, μ::StdPowerMeasure{MU,1}, x) where MU + reshape(x, ν.axes) +end + +function transport_def(ν::StdPowerMeasure{MU,1}, μ::StdPowerMeasure{MU,1}, x) where MU + reshape(x, ν.axes) +end # Transport between univariate standard measures and 1-dim power measures of size one: -function transport_def(ν::StdMeasure, μ::StdPowerMeasure{<:StdMeasure,1}, x) +function transport_def(ν::StdMeasure, μ::StdPowerMeasure{MU,1}, x) where {MU} return transport_def(ν, μ.parent, only(x)) end -function transport_def(ν::StdPowerMeasure{<:StdMeasure,1}, μ::StdMeasure, x) +function transport_def(ν::StdPowerMeasure{NU,1}, μ::StdMeasure, x) where {NU} sz_ν = pwr_size(ν) @assert prod(sz_ν) == 1 return maybestatic_fill(transport_def(ν.parent, μ, x), sz_ν) end -function transport_def(ν::StdPowerMeasure{<:StdMeasure,1}, μ::StdPowerMeasure{<:StdMeasure,1}, x,) - return transport_to(ν.parent, μ.parent).(x) +function transport_def(ν::StdPowerMeasure{NU,1}, μ::StdPowerMeasure{MU,1}, x) where {NU,MU} + reshape(transport_to(ν.parent, μ.parent).(x), ν.axes) end @@ -110,7 +128,7 @@ end # Transport from a multivariate standard measure to any measure: -function transport_def(ν::AbstractMeasure, μ::StdPowerMeasure{<:StdMeasure,1}, x) +function transport_def(ν::AbstractMeasure, μ::StdPowerMeasure{MU,1}, x) where {MU} μ_inner = pwr_base(μ) _transport_from_mvstd(ν, μ_inner, x) end From f1274fc81afb1b974f596f6634f2b146d7c652b9 Mon Sep 17 00:00:00 2001 From: HannahMeilchen Date: Mon, 23 Oct 2023 11:26:08 +0200 Subject: [PATCH 115/133] addition to docstring in bind.jl --- src/combinators/bind.jl | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/src/combinators/bind.jl b/src/combinators/bind.jl index 9af12c93..cdbd7006 100644 --- a/src/combinators/bind.jl +++ b/src/combinators/bind.jl @@ -54,9 +54,9 @@ Bayesian example with a correlated prior, that models the ```julia using MeasureBase, AffineMaps -prior = mbind +prior = mbind( productmeasure(( - value => StdNormal() + position = StdNormal(), )), merge ) do a productmeasure(( @@ -64,7 +64,18 @@ prior = mbind )) end -model = θ -> pushfwd(MulAdd(θ.noise, θ.value), StdNormal())^10 +prior = mbind + a -> productmeasure(( + noise = pushfwd(sqrt ∘ Mul(abs(a.position)), StdExponential()) + )), + productmeasure(( + position = StdNormal(), + )), + merge +) + + +model = θ -> pushfwd(MulAdd(θ.noise, θ.position), StdNormal())^10 joint_θ_obs = mbind(model, prior, tuple) prior_predictive = mbind(model, prior) From 4ac8b781cd38188a64cac8f34616b7944875426e Mon Sep 17 00:00:00 2001 From: HannahMeilchen Date: Mon, 23 Oct 2023 12:51:17 +0200 Subject: [PATCH 116/133] fixed _split_after in collection_utils.jl --- src/collection_utils.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/collection_utils.jl b/src/collection_utils.jl index a9f0765c..1e3e4f76 100644 --- a/src/collection_utils.jl +++ b/src/collection_utils.jl @@ -32,15 +32,15 @@ end end @inline _split_after(x::Tuple, n) = _split_after(x::Tuple, Val{n}()) -@inline _split_after(x::Tuple, ::Val{N}) where N = x[begin:begin+N-1], x[N:end] +@inline _split_after(x::Tuple, ::Val{N}) where N = x[begin:begin+N-1], x[begin+N:end] @generated function _split_after(x::NamedTuple{names}, ::Val{names_a}) where {names, names_a} n = length(names_a) if names[begin:begin+n-1] == names_a - names_b = names[n:end] + names_b = names[begin+n:end] quote - a, b = _split_after(x, Val(n)) - NamedTuple{names_a}(a), NamedTuple{names_b}(b) + a, b = _split_after(values(x), Val($n)) + NamedTuple{$names_a}(a), NamedTuple{$names_b}(b) end else quote From 5c2bedbaf75fabec0320eea4c60e0ea5952a04f1 Mon Sep 17 00:00:00 2001 From: HannahMeilchen Date: Tue, 24 Oct 2023 16:19:05 +0200 Subject: [PATCH 117/133] Temporary change to mbind constructor: If no f_c is specified, default to get_second_tmp instead of getsecond until getsecond satisfies conditions in docstring --- src/combinators/bind.jl | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/src/combinators/bind.jl b/src/combinators/bind.jl index cdbd7006..d15397c0 100644 --- a/src/combinators/bind.jl +++ b/src/combinators/bind.jl @@ -94,7 +94,8 @@ export mbind @inline mbind(f_β) = Base.Fix1(mbind, f_β) -@inline mbind(f_β, α::AbstractMeasure, f_c = getsecond) = _generic_mbind_impl(f_β, α, f_c) +#@inline mbind(f_β, α::AbstractMeasure, f_c = getsecond) = _generic_mbind_impl(f_β, α, f_c) --- temporary --- +@inline mbind(f_β, α::AbstractMeasure, f_c = get_second_tmp) = _generic_mbind_impl(f_β, α, f_c) @inline function _generic_mbind_impl(f_β, α::AbstractMeasure, f_c) F, M, G = Core.Typeof(f_β), Core.Typeof(α), Core.Typeof(f_c) @@ -236,3 +237,13 @@ function transport_from_mvstd_with_rest(ν::Bind, μ_inner::StdMeasure, x) b, x_rest = transport_from_mvstd_with_rest(β_a, μ_inner, x2) return ν.f_c(a, b), x_rest end + +#temporary (getsecond does not satisfy condition on f_c as described in docstring) --- +function get_first_tmp(a, b) + return a +end + +function get_second_tmp(a, b) + return b +end +#--- \ No newline at end of file From dd5c9023817c3765c656c3e8d2d33fb49d046304 Mon Sep 17 00:00:00 2001 From: HannahMeilchen Date: Tue, 24 Oct 2023 16:21:08 +0200 Subject: [PATCH 118/133] added a rand function to lebesgue.jl, corresponding to uniform distribution --- src/primitives/lebesgue.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/primitives/lebesgue.jl b/src/primitives/lebesgue.jl index 2d2ed8dd..e6b72ee4 100644 --- a/src/primitives/lebesgue.jl +++ b/src/primitives/lebesgue.jl @@ -59,6 +59,8 @@ Base.:∘(::typeof(basemeasure), ::Type{Lebesgue}) = LebesgueBase() Base.show(io::IO, d::Lebesgue) = print(io, "Lebesgue(", d.support, ")") +Base.rand(rng::Random.AbstractRNG, ::Type{T}, ::Lebesgue) where {T} = rand(rng, T) + insupport(μ::Lebesgue, x) = x ∈ μ.support insupport(::Lebesgue{RealNumbers}, ::Real) = true From b6ae089a82de8915d8abcc22c41864c04d200e4b Mon Sep 17 00:00:00 2001 From: HannahMeilchen Date: Wed, 25 Oct 2023 16:05:22 +0200 Subject: [PATCH 119/133] fixed _eval_k --- src/combinators/likelihood.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/combinators/likelihood.jl b/src/combinators/likelihood.jl index 8578040f..863ea7d4 100644 --- a/src/combinators/likelihood.jl +++ b/src/combinators/likelihood.jl @@ -100,7 +100,7 @@ Likelihood(k, x::X) where {X} = Likelihood{Core.Typeof(k),X}(k, x) (lik::AbstractLikelihood)(p) = exp(ULogarithmic, logdensityof(_eval_k(lik, p), lik.x)) -_eval_k(ℓ::AbstractLikelihood, p) = asmeasure(_eval_k(ℓ, p)) +_eval_k(ℓ::AbstractLikelihood, p) = asmeasure(ℓ.k(p)) DensityInterface.DensityKind(::AbstractLikelihood) = IsDensity() From cd511f82714b3b973176a5fbb2ffaa01dc3b627f Mon Sep 17 00:00:00 2001 From: HannahMeilchen Date: Thu, 26 Oct 2023 10:01:32 +0200 Subject: [PATCH 120/133] rand on Lebesgue doesn't make sense in general --- src/primitives/lebesgue.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/primitives/lebesgue.jl b/src/primitives/lebesgue.jl index e6b72ee4..2d2ed8dd 100644 --- a/src/primitives/lebesgue.jl +++ b/src/primitives/lebesgue.jl @@ -59,8 +59,6 @@ Base.:∘(::typeof(basemeasure), ::Type{Lebesgue}) = LebesgueBase() Base.show(io::IO, d::Lebesgue) = print(io, "Lebesgue(", d.support, ")") -Base.rand(rng::Random.AbstractRNG, ::Type{T}, ::Lebesgue) where {T} = rand(rng, T) - insupport(μ::Lebesgue, x) = x ∈ μ.support insupport(::Lebesgue{RealNumbers}, ::Real) = true From 75b0caecfb544730b44262b5674cbb44f7d6cd6f Mon Sep 17 00:00:00 2001 From: HannahMeilchen Date: Thu, 26 Oct 2023 10:04:26 +0200 Subject: [PATCH 121/133] extension to insupport for PreoductMeasures in case there is NoFastInsupport --- src/combinators/product.jl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/combinators/product.jl b/src/combinators/product.jl index 70c554d0..2e8caee7 100644 --- a/src/combinators/product.jl +++ b/src/combinators/product.jl @@ -181,7 +181,10 @@ end @inline function insupport(d::ProductMeasure, x) for (mj, xj) in zip(marginals(d), x) - dynamic(insupport(mj, xj)) || return false + insup = dynamic(insupport(mj, xj)) + if insup isa NoFastInsupport || insup == false + return insup + end end return true end From 80d372e46225e80ceddff09fc576d87e0f7893dc Mon Sep 17 00:00:00 2001 From: HannahMeilchen Date: Thu, 26 Oct 2023 10:06:53 +0200 Subject: [PATCH 122/133] change to insupport for PowerMeasures, temporary workaround for _all if there is NoFastInsupport --- src/combinators/power.jl | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/combinators/power.jl b/src/combinators/power.jl index 7c428e27..d613cf60 100644 --- a/src/combinators/power.jl +++ b/src/combinators/power.jl @@ -140,12 +140,17 @@ end end end +_all(A) = all(A) +_all(::AbstractArray{NoFastInsupport{T}}) where T = NoFastInsupport{T}() + + @inline function insupport(μ::PowerMeasure, x::AbstractArray) p = μ.parent - all(x) do xj + insupp = broadcast(x) do xj # https://github.com/SciML/Static.jl/issues/36 dynamic(insupport(p, xj)) end + _all(insupp) end @inline getdof(μ::PowerMeasure) = getdof(μ.parent) * prod(pwr_size(μ)) From c5d5e115403c9400405153326e6156486034dab7 Mon Sep 17 00:00:00 2001 From: HannahMeilchen Date: Thu, 26 Oct 2023 10:14:47 +0200 Subject: [PATCH 123/133] fixed docstring about mintegral, mintegral_exp --- src/density.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/density.jl b/src/density.jl index 57367ec5..5f37401b 100644 --- a/src/density.jl +++ b/src/density.jl @@ -81,7 +81,7 @@ A `DensityMeasure` is a measure defined by a density or log-density with respect to some other "base" measure. Users should not instantiate `DensityMeasure` directly, but should instead -call `mintegral_exp(f, base)` (if `f` is a density function or +call `mintegral(f, base)` (if `f` is a density function or `DensityInterface.IsDensity` object) or `mintegral_exp(f, base)` (if `f` is a log-density function). """ @@ -111,6 +111,8 @@ logdensity_def(μ::DensityMeasure, x) = logdensityof(μ.f, x) density_def(μ::DensityMeasure, x) = densityof(μ.f, x) +localmeasure(μ::DensityMeasure, x) = DensityMeasure(μ.f, localmeasure(μ.base, x)) + @doc raw""" mintegrate(f, μ::AbstractMeasure)::AbstractMeasure From c3f8c5c19e00c3bb96fcd2c1683a3c1053c81ff2 Mon Sep 17 00:00:00 2001 From: HannahMeilchen Date: Thu, 26 Oct 2023 16:11:43 +0200 Subject: [PATCH 124/133] explained example in docstring fixed some typos in docstring --- src/combinators/bind.jl | 26 ++++++++++++-------------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/src/combinators/bind.jl b/src/combinators/bind.jl index d15397c0..11d58c52 100644 --- a/src/combinators/bind.jl +++ b/src/combinators/bind.jl @@ -28,7 +28,7 @@ When using the default `fc = x -> x[2]` (so `ab == b`) this simplies to \mu(B) = \int_A \beta_a(B)\, \mathrm{d}\, \alpha(a) ``` -which is equivalent to a monatic bind, viewing measures as monads. +which is equivalent to a monadic bind, viewing measures as monads. Computationally, `ab = rand(μ)` is equivalent to @@ -49,7 +49,16 @@ support other choices for `f_c`. # Extended help -Bayesian example with a correlated prior, that models the +Bayesian example with a correlated prior: Mathematically, let + +position = a1 ~ StdNormal(), +noise = a2 ~ pushforward(h(a1, .), StdExponential()) + +where `h(a1,a2) = √(abs(a1) * a2)`. Note that StdNormal() and StdExponential() +are StdMeasures, but it makes perfect sense to sample from them, as they have +base measures and therefore densities. +Because the prior on the space of `A = A1 × A2 = (position, noise)` is a +hierarchical measure, we can construct it using mbind by setting merge as f_c: ```julia using MeasureBase, AffineMaps @@ -60,21 +69,10 @@ prior = mbind( )), merge ) do a productmeasure(( - noise = pushfwd(sqrt ∘ Mul(abs(a.position)), StdExponential()) + noise = pushfwd(setinverse(sqrt, setladj(x -> x^2, x -> log(2))) ∘ Mul(abs(a.position)), StdExponential()), )) end -prior = mbind - a -> productmeasure(( - noise = pushfwd(sqrt ∘ Mul(abs(a.position)), StdExponential()) - )), - productmeasure(( - position = StdNormal(), - )), - merge -) - - model = θ -> pushfwd(MulAdd(θ.noise, θ.position), StdNormal())^10 joint_θ_obs = mbind(model, prior, tuple) From e18c0cc7cf12baa269db0410844ed1f63f9b6a02 Mon Sep 17 00:00:00 2001 From: HannahMeilchen Date: Fri, 27 Oct 2023 13:35:47 +0200 Subject: [PATCH 125/133] change to insupport: != false instead of == true to account for NoFastInsupport --- src/density.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/density.jl b/src/density.jl index 5f37401b..543e33d6 100644 --- a/src/density.jl +++ b/src/density.jl @@ -96,7 +96,8 @@ struct DensityMeasure{F,B} <: AbstractMeasure end @inline function insupport(d::DensityMeasure, x) - insupport(d.base, x) == true && isfinite(logdensityof(getfield(d, :f), x)) + # ToDo: should not evaluate f + insupport(d.base, x) != false && isfinite(logdensityof(getfield(d, :f), x)) end function Pretty.tile(μ::DensityMeasure{F,B}) where {F,B} From eae9a753d38aee5367f9127a38afea121510caf5 Mon Sep 17 00:00:00 2001 From: HannahMeilchen Date: Fri, 27 Oct 2023 13:41:05 +0200 Subject: [PATCH 126/133] completed docstring bind.jl --- src/combinators/bind.jl | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/combinators/bind.jl b/src/combinators/bind.jl index 11d58c52..83390690 100644 --- a/src/combinators/bind.jl +++ b/src/combinators/bind.jl @@ -54,11 +54,10 @@ Bayesian example with a correlated prior: Mathematically, let position = a1 ~ StdNormal(), noise = a2 ~ pushforward(h(a1, .), StdExponential()) -where `h(a1,a2) = √(abs(a1) * a2)`. Note that StdNormal() and StdExponential() -are StdMeasures, but it makes perfect sense to sample from them, as they have -base measures and therefore densities. -Because the prior on the space of `A = A1 × A2 = (position, noise)` is a -hierarchical measure, we can construct it using mbind by setting merge as f_c: +where `h(a1,a2) = √(abs(a1) * a2)`. +Because this prior on the space of `A = A1 × A2 = (position, noise)` is a +hierarchical measure (a2 depends on a1), we can construct it using mbind by +setting merge as f_c: ```julia using MeasureBase, AffineMaps From 0a6646b56b036d850cc3e58f779c080ad022c026 Mon Sep 17 00:00:00 2001 From: HannahMeilchen Date: Fri, 27 Oct 2023 14:09:19 +0200 Subject: [PATCH 127/133] typo in docstring --- src/measure_operators.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/measure_operators.jl b/src/measure_operators.jl index 5822d4de..90f973f0 100644 --- a/src/measure_operators.jl +++ b/src/measure_operators.jl @@ -58,7 +58,7 @@ export ⊙ The `\\triangleright` operator denotes a measure monadic bind operation. -A common operator choice for a monadics bind operator is `>>=` (e.g. in +A common operator choice for a monadic bind operator is `>>=` (e.g. in the Haskell programming language), but this has a different meaning in Julia and there is no close equivalent, so we use `▷`. From 5c4ae2c2bdeab3899bd2019fc1e5e04089c1d495 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Tue, 7 Nov 2023 10:08:45 +0100 Subject: [PATCH 128/133] Require OneTwoMany v0.1.2 for secondarg --- Project.toml | 2 +- src/MeasureBase.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 4f2320fe..4edbbfd9 100644 --- a/Project.toml +++ b/Project.toml @@ -52,7 +52,7 @@ LogExpFunctions = "0.3" LogarithmicNumbers = "1" MappedArrays = "0.4" NaNMath = "0.3, 1" -OneTwoMany = "0.1" +OneTwoMany = "0.1.2" PrettyPrinting = "0.3, 0.4" Random = "1" Reexport = "1" diff --git a/src/MeasureBase.jl b/src/MeasureBase.jl index 8da912af..093793c8 100644 --- a/src/MeasureBase.jl +++ b/src/MeasureBase.jl @@ -28,7 +28,7 @@ import Base.iterate import ConstructionBase using ConstructionBase: constructorof using IntervalSets -using OneTwoMany: getsecond +using OneTwoMany: secondarg using PrettyPrinting const Pretty = PrettyPrinting From 63ee36f06ff9d09c2eb3c0bd94f87f51bb9893cd Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Tue, 7 Nov 2023 10:09:03 +0100 Subject: [PATCH 129/133] STASH add mkernel and MKernel --- src/combinators/bind.jl | 57 ++++++++++++++++++++++++++++++++++++++--- 1 file changed, 54 insertions(+), 3 deletions(-) diff --git a/src/combinators/bind.jl b/src/combinators/bind.jl index 83390690..bc9aef1e 100644 --- a/src/combinators/bind.jl +++ b/src/combinators/bind.jl @@ -1,8 +1,56 @@ @doc raw""" - mbind(f_β, α::AbstractMeasure, f_c = x -> x[2]) + mkernel(f_β, f_c = OneTwoMany.secondarg)::Function + +Constructs generalized monadic transistion kernel from a primary transition +kernel function `f_β` and a value combination function `f_c`. + +`f_β` must behave like `β = f_β(a)`, taking a value `a` from a primary +measurable space and return a measure-like object `β`. + +`f_c` must behave like `c = f_c(a, b)`, taking a value `a` (like f_β) and a +value `b` from the measurable space of `β` and return a value `c`. + +`f_k = mkernel(f_β, f_c)` then acts like + +```julia +f_k(a) ≡ pushforward(c -> f_c(c[1], c[2]), productmeasure((Dirac(a), f_β(a)))) +``` + +(`≡` denoting pseudocode-equivalency here). So with the default +`f_c == OneTwoMany.secondarg`, we just have `f_k(a) ≡ f_β(a) + +Also, + +```julia +mbind(mkernel(f_β, f_c), α) == mbind(f_β, α, f_c) +``` + +See also [`mbind`](@ref). +""" +function mkernel end +export mkernel + + +""" + struct MeasureBase.MKernel <: Function + +Represents a generalized monatic transition kernel. + +User code should not create instances of `MKernel` directly, but should call +[`mkernel`](@ref) instead. +""" +struct MKernel + f_β::FK + f_c::FC +end + + +@doc raw""" + mbind(f_β, α::AbstractMeasure, f_c = OneTwoMany.secondarg) + mbind(f_β::MeasureBase.MKernel, α::AbstractMeasure) Constructs a monadic bind, resp. a hierarchical measure, from a transition -kernel function `f_β`, a primary measure `α` and a variate combination +kernel function `f_β`, a primary measure `α` and a value combination function `f_c`. `f_β` must be a function that maps a point `a` from the space of the primary @@ -22,7 +70,7 @@ has the mathethematical interpretation (on sets $$A$$ and $$B$$) \mu(f_c(A, B)) = \int_A \beta_a(B)\, \mathrm{d}\, \alpha(a) ``` -When using the default `fc = x -> x[2]` (so `ab == b`) this simplies to +When using the default `fc = OneTwoMany.secondarg` (so `ab == b`) this simplies to ```math \mu(B) = \int_A \beta_a(B)\, \mathrm{d}\, \alpha(a) @@ -91,6 +139,9 @@ export mbind @inline mbind(f_β) = Base.Fix1(mbind, f_β) +# ToDo: Store MKernel in Bind instead of separate fields f_β and f_c? +@inline mbind(f_k::MKernel, α::AbstractMeasure) = mbind(f_k.f_β, α, f_k.f_c) + #@inline mbind(f_β, α::AbstractMeasure, f_c = getsecond) = _generic_mbind_impl(f_β, α, f_c) --- temporary --- @inline mbind(f_β, α::AbstractMeasure, f_c = get_second_tmp) = _generic_mbind_impl(f_β, α, f_c) From 336c1d2a318a37639b452fe4b26b4edfa638a589 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Tue, 7 Nov 2023 10:14:23 +0100 Subject: [PATCH 130/133] FIXUP secondarg --- src/combinators/bind.jl | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/src/combinators/bind.jl b/src/combinators/bind.jl index bc9aef1e..273f99f0 100644 --- a/src/combinators/bind.jl +++ b/src/combinators/bind.jl @@ -143,7 +143,7 @@ export mbind @inline mbind(f_k::MKernel, α::AbstractMeasure) = mbind(f_k.f_β, α, f_k.f_c) #@inline mbind(f_β, α::AbstractMeasure, f_c = getsecond) = _generic_mbind_impl(f_β, α, f_c) --- temporary --- -@inline mbind(f_β, α::AbstractMeasure, f_c = get_second_tmp) = _generic_mbind_impl(f_β, α, f_c) +@inline mbind(f_β, α::AbstractMeasure, f_c = secondarg) = _generic_mbind_impl(f_β, α, f_c) @inline function _generic_mbind_impl(f_β, α::AbstractMeasure, f_c) F, M, G = Core.Typeof(f_β), Core.Typeof(α), Core.Typeof(f_c) @@ -285,13 +285,3 @@ function transport_from_mvstd_with_rest(ν::Bind, μ_inner::StdMeasure, x) b, x_rest = transport_from_mvstd_with_rest(β_a, μ_inner, x2) return ν.f_c(a, b), x_rest end - -#temporary (getsecond does not satisfy condition on f_c as described in docstring) --- -function get_first_tmp(a, b) - return a -end - -function get_second_tmp(a, b) - return b -end -#--- \ No newline at end of file From 8dbdc52e00494c2f057c41fc26383655c9463917 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Tue, 7 Nov 2023 10:21:15 +0100 Subject: [PATCH 131/133] STASH mkernel --- src/combinators/bind.jl | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/combinators/bind.jl b/src/combinators/bind.jl index 273f99f0..b97a7d90 100644 --- a/src/combinators/bind.jl +++ b/src/combinators/bind.jl @@ -30,6 +30,12 @@ See also [`mbind`](@ref). function mkernel end export mkernel +@inline mkernel(f_β::MKernel) = f_β +@inline mkernel(f_β, f_c = secondarg) = _generic_mkernel_impl(f_β, f_c) + +@inline _generic_mkernel_impl(f_β, f_c) = MKernel(f_β, f_c) +@inline _generic_mkernel_impl(f_β::MKernel, ::typeof(secondarg)) = f_β + """ struct MeasureBase.MKernel <: Function @@ -139,7 +145,6 @@ export mbind @inline mbind(f_β) = Base.Fix1(mbind, f_β) -# ToDo: Store MKernel in Bind instead of separate fields f_β and f_c? @inline mbind(f_k::MKernel, α::AbstractMeasure) = mbind(f_k.f_β, α, f_k.f_c) #@inline mbind(f_β, α::AbstractMeasure, f_c = getsecond) = _generic_mbind_impl(f_β, α, f_c) --- temporary --- @@ -169,6 +174,9 @@ struct Bind{FK,M<:AbstractMeasure,FC} <: AbstractMeasure f_c::FC end +# ToDo: Store MKernel in Bind instead of separate fields f_β and f_c? + + """ MeasureBase.transportmeasure(μ::Bind, x)::AbstractMeasure From 56157f94524da575e9f7cb29ff6722a1e1d84ca6 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Tue, 7 Nov 2023 10:42:07 +0100 Subject: [PATCH 132/133] STASH bindkernel and boundmeasure --- src/combinators/bind.jl | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/src/combinators/bind.jl b/src/combinators/bind.jl index b97a7d90..a70f27eb 100644 --- a/src/combinators/bind.jl +++ b/src/combinators/bind.jl @@ -93,6 +93,13 @@ b = rand(β_a) ab = f_c(a, b) ``` +The measure `α` that went into the bind can be retrieved via +`boundmeasure(mbind(f_β, α, ...)) == α`. + +`mbind(f_β, α, f_c)` is equivalent to `mbind(mkernel(f_β, f_c), α)` +(see [`mkernel`](@ref)) with +`bindkernel(mbind(mkernel(f_β, f_c), α)) == mbind(mkernel(f_β, f_c)`. + Densities on hierarchical measures can only be evaluated if `ab = f_c(a, b)` can be unambiguously split into `a` and `b` again, knowing `α`. This is currently implemented for `f_c` that is either tuple or `=>`/`Pair` (these @@ -177,6 +184,33 @@ end # ToDo: Store MKernel in Bind instead of separate fields f_β and f_c? +""" + bindkernel(μ::Bind)::MKernel + +Returns the monatic transition kernel of a monatic bind, so that +`bindkernel(mbind(f_k::MKernel, α)) == f_k`. + +See [`mbind`](@ref) and [`mkernel`](@ref) for details. +""" +function bindkernel end +export bindkernel + +bindkernel(μ::Bind) = mkernel(μ.f_β, μ.f_c) + + +""" + boundmeasure(μ::Bind)::MKernel + +Returns the measure that went into a monatic bind, so that +`boundmeasure(mbind(f_k, α)) == α`. + +See [`mbind`](@ref) and [`mkernel`](@ref) for details. +""" +function boundmeasure end +export boundmeasure + +boundmeasure(μ::Bind) = mkernel(μ.f_β, μ.f_c) + """ MeasureBase.transportmeasure(μ::Bind, x)::AbstractMeasure From 2d017c96fad3b7ce30b6ed0193784385c3da4803 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Tue, 7 Nov 2023 11:29:37 +0100 Subject: [PATCH 133/133] STASH bind --- src/MeasureBase.jl | 2 +- src/combinators/bind.jl | 39 ++++++++++++++++++++----------------- src/combinators/combined.jl | 5 ++--- 3 files changed, 24 insertions(+), 22 deletions(-) diff --git a/src/MeasureBase.jl b/src/MeasureBase.jl index 093793c8..4fc3bf6a 100644 --- a/src/MeasureBase.jl +++ b/src/MeasureBase.jl @@ -28,7 +28,7 @@ import Base.iterate import ConstructionBase using ConstructionBase: constructorof using IntervalSets -using OneTwoMany: secondarg +using OneTwoMany: firstarg, secondarg using PrettyPrinting const Pretty = PrettyPrinting diff --git a/src/combinators/bind.jl b/src/combinators/bind.jl index a70f27eb..fe6af2bf 100644 --- a/src/combinators/bind.jl +++ b/src/combinators/bind.jl @@ -30,12 +30,6 @@ See also [`mbind`](@ref). function mkernel end export mkernel -@inline mkernel(f_β::MKernel) = f_β -@inline mkernel(f_β, f_c = secondarg) = _generic_mkernel_impl(f_β, f_c) - -@inline _generic_mkernel_impl(f_β, f_c) = MKernel(f_β, f_c) -@inline _generic_mkernel_impl(f_β::MKernel, ::typeof(secondarg)) = f_β - """ struct MeasureBase.MKernel <: Function @@ -45,12 +39,20 @@ Represents a generalized monatic transition kernel. User code should not create instances of `MKernel` directly, but should call [`mkernel`](@ref) instead. """ -struct MKernel - f_β::FK +struct MKernel{FT,FC} <: Function + f_β::FT f_c::FC end +@inline mkernel(f_β::MKernel) = f_β +@inline mkernel(f_β, f_c = secondarg) = _generic_mkernel_impl(f_β, f_c) + +@inline _generic_mkernel_impl(f_β, f_c) = MKernel(f_β, f_c) +@inline _generic_mkernel_impl(f_β::MKernel, ::typeof(secondarg)) = f_β + + + @doc raw""" mbind(f_β, α::AbstractMeasure, f_c = OneTwoMany.secondarg) mbind(f_β::MeasureBase.MKernel, α::AbstractMeasure) @@ -102,7 +104,7 @@ The measure `α` that went into the bind can be retrieved via Densities on hierarchical measures can only be evaluated if `ab = f_c(a, b)` can be unambiguously split into `a` and `b` again, knowing `α`. This is -currently implemented for `f_c` that is either tuple or `=>`/`Pair` (these +currently implemented for `f_c` that is either `tuple` or `=>`/`Pair` (these work for any combination of variate types), `vcat` (for tuple- or vector-like variates) and `merge` (`NamedTuple` variates). [`MeasureBase.split_point(::typeof(f_c), α)`](@ref) can be specialized to @@ -152,19 +154,20 @@ export mbind @inline mbind(f_β) = Base.Fix1(mbind, f_β) -@inline mbind(f_k::MKernel, α::AbstractMeasure) = mbind(f_k.f_β, α, f_k.f_c) - -#@inline mbind(f_β, α::AbstractMeasure, f_c = getsecond) = _generic_mbind_impl(f_β, α, f_c) --- temporary --- -@inline mbind(f_β, α::AbstractMeasure, f_c = secondarg) = _generic_mbind_impl(f_β, α, f_c) +@inline mbind(f_β, α::AbstractMeasure, f_c = secondarg) = _generic_mbind_impl(f_β, asmeasure(α), f_c) @inline function _generic_mbind_impl(f_β, α::AbstractMeasure, f_c) F, M, G = Core.Typeof(f_β), Core.Typeof(α), Core.Typeof(f_c) Bind{F,M,G}(f_β, α, f_c) end -function _generic_mbind_impl(f_β, α::Dirac, f_c) - mcombine(f_c, α, f_β(α.x)) -end +@inline _generic_mbind_impl(f_β, α::Dirac, f_c) = mcombine(f_c, α, f_β(α.x)) + +@inline _generic_mbind_impl(@nospecialize(f_β), α::AbstractMeasure, ::typeof(firstarg)) = α +@inline _generic_mbind_impl(@nospecialize(f_β), α::Dirac, ::typeof(firstarg)) = α + +@inline _generic_mbind_impl(f_k::MKernel, α::AbstractMeasure, ::typeof(secondarg)) = mbind(f_k.f_β, α, f_k.f_c) +@inline _generic_mbind_impl(f_k::MKernel, α::Dirac, ::typeof(secondarg)) = mbind(f_k.f_β, α, f_k.f_c) """ @@ -175,8 +178,8 @@ Represents a monatic bind resp. a mbind in general. User code should not create instances of `Bind` directly, but should call [`mbind`](@ref) instead. """ -struct Bind{FK,M<:AbstractMeasure,FC} <: AbstractMeasure - f_β::FK +struct Bind{FT,M<:AbstractMeasure,FC} <: AbstractMeasure + f_β::FT α::M f_c::FC end diff --git a/src/combinators/combined.jl b/src/combinators/combined.jl index 11ebc558..4127e952 100644 --- a/src/combinators/combined.jl +++ b/src/combinators/combined.jl @@ -62,9 +62,8 @@ export mcombine @inline mcombine(f_c, α::AbstractMeasure, β::AbstractMeasure) = _generic_mcombine_impl_stage1(f_c, α, β) -@inline _generic_mcombine_impl_stage1(::typeof(first), α::AbstractMeasure, β::AbstractMeasure) = α -@inline _generic_mcombine_impl_stage1(::typeof(getsecond), α::AbstractMeasure, β::AbstractMeasure) = β -@inline _generic_mcombine_impl_stage1(::typeof(last), α::AbstractMeasure, β::AbstractMeasure) = β +@inline _generic_mcombine_impl_stage1(::typeof(firstarg), α::AbstractMeasure, β::AbstractMeasure) = α +@inline _generic_mcombine_impl_stage1(::typeof(secondarg), α::AbstractMeasure, β::AbstractMeasure) = β @inline function _generic_mcombine_impl_stage1(::typeof(tuple), α::AbstractMeasure, β::AbstractMeasure) productmeasure((α, β))