Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 44 additions & 1 deletion src/Utilities/functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ function value_type(
return MA.promote_operation(*, C, T)
end

value_type(::Type{T}, ::Type{MOI.ScalarNonlinearFunction}) where {T} = T

function value_type(
::Type{T},
::Type{F},
Expand Down Expand Up @@ -325,6 +327,7 @@ function map_indices(
index_map::F,
f::MOI.ScalarNonlinearFunction,
) where {F<:Function}
# TODO(odow): this uses recursion. We should remove at some point.
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know these use recursion. We can fix in future. Very few models should exceed the recursion depth.

return MOI.ScalarNonlinearFunction(
f.head,
convert(Vector{Any}, map_indices(index_map, f.args)),
Expand Down Expand Up @@ -466,6 +469,17 @@ function substitute_variables(
return g
end

function substitute_variables(
variable_map::F,
f::MOI.ScalarNonlinearFunction,
) where {F<:Function}
# TODO(odow): this uses recursion. We should remove at some point.
return MOI.ScalarNonlinearFunction(
f.head,
Any[substitute_variables(variable_map, a) for a in f.args],
)
end

function substitute_variables(
variable_map::F,
f::MOI.VectorAffineFunction{T},
Expand Down Expand Up @@ -823,10 +837,32 @@ function unsafe_add(
return T(t1.output_index, scalar_term)
end

# Generic fallback for items inside NonlinearFunctions like numbers.
is_canonical(::Any) = true

is_canonical(::MOI.AbstractFunction) = false

is_canonical(::Union{MOI.VariableIndex,MOI.VectorOfVariables}) = true

function is_canonical(f::MOI.ScalarNonlinearFunction)
# Don't use recursion here! This gets called for all scalar nonlinear
# constraints.
stack = Any[arg for arg in f.args]
while !isempty(stack)
arg = pop!(stack)
if arg isa MOI.ScalarNonlinearFunction
for a in arg.args
push!(stack, a)
end
else
if !is_canonical(arg)
return false
end
end
end
return true
end

"""
is_canonical(f::Union{ScalarAffineFunction, VectorAffineFunction})

Expand Down Expand Up @@ -919,7 +955,14 @@ function canonicalize!(
return f
end

canonicalize!(f::MOI.ScalarNonlinearFunction) = f
function canonicalize!(f::MOI.ScalarNonlinearFunction)
for (i, arg) in enumerate(f.args)
if !is_canonical(arg)
f.args[i] = canonicalize!(arg)
end
end
return f
end

"""
canonicalize!(f::Union{ScalarQuadraticFunction, VectorQuadraticFunction})
Expand Down
40 changes: 40 additions & 0 deletions test/Utilities/functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1844,6 +1844,7 @@ function test_value_type()
T,
MOI.ScalarQuadraticFunction{Complex{Int}},
) == Complex{T}
@test MOI.Utilities.value_type(T, MOI.ScalarNonlinearFunction) == T
@test MOI.Utilities.value_type(
T,
MOI.ScalarQuadraticFunction{Complex{T}},
Expand Down Expand Up @@ -1909,6 +1910,45 @@ function test_ScalarNonlinearFunction_count_map_indices_and_print()
return
end

function test_ScalarNonlinearFunction_map_indices()
src = MOI.Utilities.Model{Float64}()
x = MOI.add_variable(src)
f = MOI.ScalarNonlinearFunction(:log, Any[x])
c = MOI.add_constraint(src, f, MOI.LessThan(1.0))
dest = MOI.Utilities.Model{Float64}()
index_map = MOI.copy_to(dest, src)
new_f = MOI.Utilities.map_indices(index_map, f)
@test new_f ≈ MOI.ScalarNonlinearFunction(:log, Any[index_map[x]])
return
end

function test_ScalarNonlinearFunction_substitute_variables()
x = MOI.VariableIndex(1)
f = MOI.ScalarNonlinearFunction(:log, Any[1.0*x])
new_f = MOI.Utilities.substitute_variables(x -> -2.0 * x, f)
@test new_f ≈ MOI.ScalarNonlinearFunction(:log, Any[-2.0*x])
return
end

function test_ScalarNonlinearFunction_is_canonical()
x = MOI.VariableIndex(1)
f = MOI.ScalarNonlinearFunction(:log, Any[1.0*x])
@test MOI.Utilities.is_canonical(f)
g = MOI.ScalarNonlinearFunction(:log, Any[1.0*x+1.0*x])
@test !MOI.Utilities.is_canonical(g)
MOI.Utilities.canonicalize!(g)
@test MOI.Utilities.is_canonical(g)
@test g.args[1] ≈ 2.0 * x
f = MOI.ScalarNonlinearFunction(:^, Any[x, 2])
@test MOI.Utilities.is_canonical(f)
# Test deep recursion
for _ in 1:100_000
f = MOI.ScalarNonlinearFunction(:sin, Any[f])
end
@test MOI.Utilities.is_canonical(f)
return
end

function test_vector_type()
for T in (Int, Float64)
@test MOI.Utilities.vector_type(T) == Vector{T}
Expand Down