diff --git a/src/Utilities/functions.jl b/src/Utilities/functions.jl index 80d618b741..e384587147 100644 --- a/src/Utilities/functions.jl +++ b/src/Utilities/functions.jl @@ -92,6 +92,8 @@ function value_type( return MA.promote_operation(*, C, T) end +value_type(::Type{T}, ::Type{MOI.ScalarNonlinearFunction}) where {T} = T + function value_type( ::Type{T}, ::Type{F}, @@ -325,6 +327,7 @@ function map_indices( index_map::F, f::MOI.ScalarNonlinearFunction, ) where {F<:Function} + # TODO(odow): this uses recursion. We should remove at some point. return MOI.ScalarNonlinearFunction( f.head, convert(Vector{Any}, map_indices(index_map, f.args)), @@ -466,6 +469,17 @@ function substitute_variables( return g end +function substitute_variables( + variable_map::F, + f::MOI.ScalarNonlinearFunction, +) where {F<:Function} + # TODO(odow): this uses recursion. We should remove at some point. + return MOI.ScalarNonlinearFunction( + f.head, + Any[substitute_variables(variable_map, a) for a in f.args], + ) +end + function substitute_variables( variable_map::F, f::MOI.VectorAffineFunction{T}, @@ -823,10 +837,32 @@ function unsafe_add( return T(t1.output_index, scalar_term) end +# Generic fallback for items inside NonlinearFunctions like numbers. +is_canonical(::Any) = true + is_canonical(::MOI.AbstractFunction) = false is_canonical(::Union{MOI.VariableIndex,MOI.VectorOfVariables}) = true +function is_canonical(f::MOI.ScalarNonlinearFunction) + # Don't use recursion here! This gets called for all scalar nonlinear + # constraints. + stack = Any[arg for arg in f.args] + while !isempty(stack) + arg = pop!(stack) + if arg isa MOI.ScalarNonlinearFunction + for a in arg.args + push!(stack, a) + end + else + if !is_canonical(arg) + return false + end + end + end + return true +end + """ is_canonical(f::Union{ScalarAffineFunction, VectorAffineFunction}) @@ -919,7 +955,14 @@ function canonicalize!( return f end -canonicalize!(f::MOI.ScalarNonlinearFunction) = f +function canonicalize!(f::MOI.ScalarNonlinearFunction) + for (i, arg) in enumerate(f.args) + if !is_canonical(arg) + f.args[i] = canonicalize!(arg) + end + end + return f +end """ canonicalize!(f::Union{ScalarQuadraticFunction, VectorQuadraticFunction}) diff --git a/test/Utilities/functions.jl b/test/Utilities/functions.jl index 79b214f5fc..3374c04b46 100644 --- a/test/Utilities/functions.jl +++ b/test/Utilities/functions.jl @@ -1844,6 +1844,7 @@ function test_value_type() T, MOI.ScalarQuadraticFunction{Complex{Int}}, ) == Complex{T} + @test MOI.Utilities.value_type(T, MOI.ScalarNonlinearFunction) == T @test MOI.Utilities.value_type( T, MOI.ScalarQuadraticFunction{Complex{T}}, @@ -1909,6 +1910,45 @@ function test_ScalarNonlinearFunction_count_map_indices_and_print() return end +function test_ScalarNonlinearFunction_map_indices() + src = MOI.Utilities.Model{Float64}() + x = MOI.add_variable(src) + f = MOI.ScalarNonlinearFunction(:log, Any[x]) + c = MOI.add_constraint(src, f, MOI.LessThan(1.0)) + dest = MOI.Utilities.Model{Float64}() + index_map = MOI.copy_to(dest, src) + new_f = MOI.Utilities.map_indices(index_map, f) + @test new_f ≈ MOI.ScalarNonlinearFunction(:log, Any[index_map[x]]) + return +end + +function test_ScalarNonlinearFunction_substitute_variables() + x = MOI.VariableIndex(1) + f = MOI.ScalarNonlinearFunction(:log, Any[1.0*x]) + new_f = MOI.Utilities.substitute_variables(x -> -2.0 * x, f) + @test new_f ≈ MOI.ScalarNonlinearFunction(:log, Any[-2.0*x]) + return +end + +function test_ScalarNonlinearFunction_is_canonical() + x = MOI.VariableIndex(1) + f = MOI.ScalarNonlinearFunction(:log, Any[1.0*x]) + @test MOI.Utilities.is_canonical(f) + g = MOI.ScalarNonlinearFunction(:log, Any[1.0*x+1.0*x]) + @test !MOI.Utilities.is_canonical(g) + MOI.Utilities.canonicalize!(g) + @test MOI.Utilities.is_canonical(g) + @test g.args[1] ≈ 2.0 * x + f = MOI.ScalarNonlinearFunction(:^, Any[x, 2]) + @test MOI.Utilities.is_canonical(f) + # Test deep recursion + for _ in 1:100_000 + f = MOI.ScalarNonlinearFunction(:sin, Any[f]) + end + @test MOI.Utilities.is_canonical(f) + return +end + function test_vector_type() for T in (Int, Float64) @test MOI.Utilities.vector_type(T) == Vector{T}