ref JuliaDiff/ChainRulesTestUtils.jl#258
# Fallback method for `to_vec`. Won't always do what you want, but should be fine a decent
# chunk of the time.
# Generic `to_vec` for struct types: flattens `x` field by field.
#
# Arguments:
# - `x::T`: a value whose type `T` satisfies `Base.isstructtype(T)`.
#
# Returns a tuple `(v, structtype_from_vec)`. `v` is the result of calling `to_vec` on the
# collection of per-field vectors, and `structtype_from_vec(v)` rebuilds a `T` from such a
# vector. For a type without fields the result is `(Bool[], _ -> x)`.
#
# Throws an `ErrorException` when `T` is not a struct type.
function to_vec(x::T) where {T}
    # `error` already throws an `ErrorException`; wrapping it in `throw` was redundant.
    Base.isstructtype(T) || error("Expected a struct type")
    isempty(fieldnames(T)) && return (Bool[], _ -> x) # Singleton types

    # One `(vector, reconstructor)` pair per field, in `fieldnames(T)` order.
    field_vecs_and_backs = map(name -> to_vec(getfield(x, name)), fieldnames(T))
    field_vecs = first.(field_vecs_and_backs)
    field_backs = last.(field_vecs_and_backs)

    v, field_vecs_from_vec = to_vec(field_vecs)
    function structtype_from_vec(flat::Vector{<:Real})
        new_field_vecs = field_vecs_from_vec(flat)
        field_values = map((back, fv) -> back(fv), field_backs, new_field_vecs)
        try
            return T(field_values...)
        catch err
            # `catch MethodError` would bind *every* exception to a local variable named
            # `MethodError`. Check the type explicitly so that only a failed method lookup
            # triggers the fallback and every other error propagates with its backtrace.
            err isa MethodError || rethrow()
            # NOTE(review): `_force_construct` is defined elsewhere; presumably it builds
            # a `T` without going through its constructors -- confirm.
            return _force_construct(T, field_values...)
        end
    end
    return v, structtype_from_vec
end
ref JuliaDiff/ChainRulesTestUtils.jl#258
FiniteDifferences.jl/src/to_vec.jl
Lines 36 to 57 in 5c2979e