From 9d56105c6a534a1d3474730069c74617a8aa3b76 Mon Sep 17 00:00:00 2001 From: odow Date: Thu, 22 Jun 2023 15:42:02 +1200 Subject: [PATCH 1/4] [Utilities] fix various utilities for ScalarNonlinearFunction --- src/Utilities/functions.jl | 29 +++++++++++++++++++++++++++-- test/Utilities/functions.jl | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 60 insertions(+), 2 deletions(-) diff --git a/src/Utilities/functions.jl b/src/Utilities/functions.jl index fe87c8bce1..46df5f199b 100644 --- a/src/Utilities/functions.jl +++ b/src/Utilities/functions.jl @@ -92,6 +92,8 @@ function value_type( return MA.promote_operation(*, C, T) end +value_type(::Type{T}, ::Type{MOI.ScalarNonlinearFunction}) where {T} = T + function value_type( ::Type{T}, ::Type{F}, @@ -107,7 +109,7 @@ Returns the value of function `f` if each variable index `vi` is evaluated as Note that `value_fn` must return a Number. See [`substitute_variables`](@ref) for a similar function where `value_fn` returns an - [`MOI.AbstractScalarFunction`](@ref). +[`MOI.AbstractScalarFunction`](@ref). """ function eval_variables end @@ -294,6 +296,7 @@ function map_indices( index_map::F, f::MOI.ScalarNonlinearFunction, ) where {F<:Function} + # TODO(odow): this uses recursion. We should remove at some point. return MOI.ScalarNonlinearFunction( f.head, convert(Vector{Any}, map_indices(index_map, f.args)), @@ -435,6 +438,17 @@ function substitute_variables( return g end +function substitute_variables( + variable_map::F, + f::MOI.ScalarNonlinearFunction, +) where {F<:Function} + # TODO(odow): this uses recursion. We should remove at some point. + return MOI.ScalarNonlinearFunction( + f.head, + Any[substitute_variables(variable_map, a) for a in f.args], + ) +end + function substitute_variables( variable_map::F, f::MOI.VectorAffineFunction{T}, @@ -796,6 +810,10 @@ is_canonical(::MOI.AbstractFunction) = false is_canonical(::Union{MOI.VariableIndex,MOI.VectorOfVariables}) = true +function is_canonical(f::MOI.ScalarNonlinearFunction) + return all(is_canonical(arg) for arg in f.args) +end + """ is_canonical(f::Union{ScalarAffineFunction, VectorAffineFunction}) @@ -888,7 +906,14 @@ function canonicalize!( return f end -canonicalize!(f::MOI.ScalarNonlinearFunction) = f +function canonicalize!(f::MOI.ScalarNonlinearFunction) + for (i, arg) in enumerate(f.args) + if !is_canonical(arg) + f.args[i] = canonicalize!(arg) + end + end + return f +end """ canonicalize!(f::Union{ScalarQuadraticFunction, VectorQuadraticFunction}) diff --git a/test/Utilities/functions.jl b/test/Utilities/functions.jl index 77cd8fae4a..631f9f090b 100644 --- a/test/Utilities/functions.jl +++ b/test/Utilities/functions.jl @@ -1808,6 +1808,7 @@ function test_value_type() T, MOI.ScalarQuadraticFunction{Complex{Int}}, ) == Complex{T} + @test MOI.Utilities.value_type(T, MOI.ScalarNonlinearFunction) == T @test MOI.Utilities.value_type( T, MOI.ScalarQuadraticFunction{Complex{T}}, @@ -1873,6 +1874,38 @@ function test_ScalarNonlinearFunction_count_map_indices_and_print() return end +function test_ScalarNonlinearFunction_map_indices() + src = MOI.Utilities.Model{Float64}() + x = MOI.add_variable(src) + f = MOI.ScalarNonlinearFunction(:log, Any[x]) + c = MOI.add_constraint(src, f, MOI.LessThan(1.0)) + dest = MOI.Utilities.Model{Float64}() + index_map = MOI.copy_to(dest, src) + new_f = MOI.Utilities.map_indices(index_map, f) + @test new_f ≈ MOI.ScalarNonlinearFunction(:log, Any[index_map[x]]) + return +end + +function test_ScalarNonlinearFunction_substitute_variables() + x = MOI.VariableIndex(1) + f = MOI.ScalarNonlinearFunction(:log, Any[1.0 * x]) + new_f = MOI.Utilities.substitute_variables(x -> -2.0 * x, f) + @test new_f ≈ MOI.ScalarNonlinearFunction(:log, Any[-2.0 * x]) + return +end + +function test_ScalarNonlinearFunction_is_canonical() + x = MOI.VariableIndex(1) + f = MOI.ScalarNonlinearFunction(:log, Any[1.0 * x]) + @test MOI.Utilities.is_canonical(f) + g = MOI.ScalarNonlinearFunction(:log, Any[1.0 * x + 1.0 * x]) + @test !MOI.Utilities.is_canonical(g) + MOI.Utilities.canonicalize!(g) + @test MOI.Utilities.is_canonical(g) + @test g.args[1] ≈ 2.0 * x + return +end + function test_vector_type() for T in (Int, Float64) @test MOI.Utilities.vector_type(T) == Vector{T} From daa8b7ae6851f5471de9c6b231f3b5c0d6f8606f Mon Sep 17 00:00:00 2001 From: odow Date: Thu, 22 Jun 2023 16:27:19 +1200 Subject: [PATCH 2/4] Fix formattinng --- test/Utilities/functions.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/Utilities/functions.jl b/test/Utilities/functions.jl index 631f9f090b..32d5d62e6d 100644 --- a/test/Utilities/functions.jl +++ b/test/Utilities/functions.jl @@ -1888,17 +1888,17 @@ end function test_ScalarNonlinearFunction_substitute_variables() x = MOI.VariableIndex(1) - f = MOI.ScalarNonlinearFunction(:log, Any[1.0 * x]) + f = MOI.ScalarNonlinearFunction(:log, Any[1.0*x]) new_f = MOI.Utilities.substitute_variables(x -> -2.0 * x, f) - @test new_f ≈ MOI.ScalarNonlinearFunction(:log, Any[-2.0 * x]) + @test new_f ≈ MOI.ScalarNonlinearFunction(:log, Any[-2.0*x]) return end function test_ScalarNonlinearFunction_is_canonical() x = MOI.VariableIndex(1) - f = MOI.ScalarNonlinearFunction(:log, Any[1.0 * x]) + f = MOI.ScalarNonlinearFunction(:log, Any[1.0*x]) @test MOI.Utilities.is_canonical(f) - g = MOI.ScalarNonlinearFunction(:log, Any[1.0 * x + 1.0 * x]) + g = MOI.ScalarNonlinearFunction(:log, Any[1.0*x+1.0*x]) @test !MOI.Utilities.is_canonical(g) MOI.Utilities.canonicalize!(g) @test MOI.Utilities.is_canonical(g) From abbe789256b76e2d12371cd2eddc3eea67b30d6d Mon Sep 17 00:00:00 2001 From: odow Date: Thu, 22 Jun 2023 20:42:05 +1200 Subject: [PATCH 3/4] Update --- src/Utilities/functions.jl | 19 ++++++++++++++++++- test/Utilities/functions.jl | 7 +++++++ 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/src/Utilities/functions.jl b/src/Utilities/functions.jl index 46df5f199b..4aed8eb2b2 100644 --- a/src/Utilities/functions.jl +++ b/src/Utilities/functions.jl @@ -806,12 +806,29 @@ function unsafe_add( return T(t1.output_index, scalar_term) end +is_canonical(::Number) = true + is_canonical(::MOI.AbstractFunction) = false is_canonical(::Union{MOI.VariableIndex,MOI.VectorOfVariables}) = true function is_canonical(f::MOI.ScalarNonlinearFunction) - return all(is_canonical(arg) for arg in f.args) + # Don't use recursion here! This gets called for all scalar nonlinear + # constraints. + stack = Any[arg for arg in f.args] + while !isempty(stack) + arg = pop!(stack) + if arg isa MOI.ScalarNonlinearFunction + for a in arg.args + push!(stack, a) + end + else + if !is_canonical(arg) + return false + end + end + end + return true end """ diff --git a/test/Utilities/functions.jl b/test/Utilities/functions.jl index 32d5d62e6d..7327f6fb28 100644 --- a/test/Utilities/functions.jl +++ b/test/Utilities/functions.jl @@ -1903,6 +1903,13 @@ function test_ScalarNonlinearFunction_is_canonical() MOI.Utilities.canonicalize!(g) @test MOI.Utilities.is_canonical(g) @test g.args[1] ≈ 2.0 * x + f = MOI.ScalarNonlinearFunction(:^, Any[x, 2]) + @test MOI.Utilities.is_canonical(f) + # Test deep recursion + for _ in 1:100_000 + f = MOI.ScalarNonlinearFunction(:sin, Any[f]) + end + @test MOI.Utilities.is_canonical(f) return end From 274b7d01bf9b616583b67b6984fbb5be307f4ef6 Mon Sep 17 00:00:00 2001 From: odow Date: Fri, 23 Jun 2023 10:24:09 +1200 Subject: [PATCH 4/4] Update --- src/Utilities/functions.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/Utilities/functions.jl b/src/Utilities/functions.jl index 4aed8eb2b2..faf025c173 100644 --- a/src/Utilities/functions.jl +++ b/src/Utilities/functions.jl @@ -806,7 +806,8 @@ function unsafe_add( return T(t1.output_index, scalar_term) end -is_canonical(::Number) = true +# Generic fallback for items inside NonlinearFunctions like numbers. +is_canonical(::Any) = true is_canonical(::MOI.AbstractFunction) = false