From 91ca7b81073d93d04d8275dfec9299f36b96f1ba Mon Sep 17 00:00:00 2001 From: Akshay Date: Tue, 4 May 2021 21:18:27 +0530 Subject: [PATCH 1/5] add custom relu implementation --- examples/custom-relu-mnist.jl | 103 ++++++++++++++++++++++++++++++++++ 1 file changed, 103 insertions(+) create mode 100644 examples/custom-relu-mnist.jl diff --git a/examples/custom-relu-mnist.jl b/examples/custom-relu-mnist.jl new file mode 100644 index 000000000..75217d8e0 --- /dev/null +++ b/examples/custom-relu-mnist.jl @@ -0,0 +1,103 @@ +using Statistics +using DiffOpt +using Flux +using Flux: onehotbatch, onecold, crossentropy, throttle +using Base.Iterators: repeated +using OSQP +using JuMP +using ChainRulesCore + +## prepare data +imgs = Flux.Data.MNIST.images() +labels = Flux.Data.MNIST.labels(); + +# Preprocessing +X = hcat(float.(reshape.(imgs, :))...) #stack all the images +Y = onehotbatch(labels, 0:9); # just a common way to encode categorical variables + +test_X = hcat(float.(reshape.(Flux.Data.MNIST.images(:test), :))...) +test_Y = onehotbatch(Flux.Data.MNIST.labels(:test), 0:9) + +# float64 to float16, to save memory +X = convert(Array{Float16,2}, X) +test_X = convert(Array{Float16,2}, test_X) + +X = X[:, 1:1000] +Y = Y[:, 1:1000]; + +""" + relu method for a Matrix +""" +function myRelu(y::AbstractMatrix{T}; model = Model(() -> diff_optimizer(OSQP.Optimizer))) where {T} + x̂ = zero(y) + + # model init + N = length(y[:, 1]) + empty!(model) + set_optimizer_attribute(model, MOI.Silent(), true) + @variable(model, x[1:N] >= zero(T)) + + for i in 1:size(y)[2] + @objective( + model, + Min, + x'x -2x'y[:, i] + ) + optimize!(model) + x̂[:, i] = value.(x) + end + return x̂ +end + +function ChainRulesCore.rrule(::typeof(myRelu), y::AbstractArray{T}; model = Model(() -> diff_optimizer(OSQP.Optimizer))) where {T} + + pv = myRelu(y, model=model) + + function pullback_myRelu(dx) + x = model[:x] + dy = zero(dx) + + for i in 1:size(y)[2] + MOI.set.( + model, + DiffOpt.BackwardIn{MOI.VariablePrimal}(), + x, + dx[:, i] + ) + + DiffOpt.backward(model) # find grad + + dy[:, i] = MOI.get.( + model, + DiffOpt.BackwardOut{DiffOpt.LinearObjective}(), + x, + ) # coeff of `x` in -2x'y + dy[:, i] = -2 * dy[:, i] + end + + return (NO_FIELDS, dy) + end + return pv, pullback_myRelu +end + +m = Chain( + Dense(784, 64), + myRelu, + Dense(64, 10), + softmax +) + +loss(x, y) = crossentropy(m(x), y) +opt = ADAM(); # popular stochastic gradient descent variant + +accuracy(x, y) = mean(onecold(m(x)) .== onecold(y)) # cute way to find average of correct guesses + +dataset = repeated((X,Y), 20) # repeat the data set, very low accuracy on the orig dataset +evalcb = () -> @show(loss(X, Y)) # callback to show loss + +Flux.train!(loss, params(m), dataset, opt, cb = throttle(evalcb, 5)); #took me ~5 minutes to train on CPU + +@show accuracy(X,Y) +@show accuracy(test_X, test_Y); + + From e8a606a51535208698c7c84b904f6375c7c071fa Mon Sep 17 00:00:00 2001 From: Akshay Sharma Date: Wed, 5 May 2021 02:33:40 -0700 Subject: [PATCH 2/5] Apply suggestions from code review Co-authored-by: Oscar Dowson --- examples/custom-relu-mnist.jl | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/examples/custom-relu-mnist.jl b/examples/custom-relu-mnist.jl index 75217d8e0..0317c890b 100644 --- a/examples/custom-relu-mnist.jl +++ b/examples/custom-relu-mnist.jl @@ -34,10 +34,10 @@ function myRelu(y::AbstractMatrix{T}; model = Model(() -> diff_optimizer(OSQP.Op # model init N = length(y[:, 1]) empty!(model) - set_optimizer_attribute(model, MOI.Silent(), true) - @variable(model, x[1:N] >= zero(T)) + set_silent(model) + @variable(model, x[1:N] >= 0) - for i in 1:size(y)[2] + for i in 1:size(y, 2) @objective( model, Min, @@ -72,7 +72,7 @@ function ChainRulesCore.rrule(::typeof(myRelu), y::AbstractArray{T}; model = Mod DiffOpt.BackwardOut{DiffOpt.LinearObjective}(), x, ) # coeff of `x` in -2x'y - dy[:, i] = -2 * dy[:, i] + dy[:, i] .= -2 .* dy[:, i] end return (NO_FIELDS, dy) @@ -100,4 +100,3 @@ Flux.train!(loss, params(m), dataset, opt, cb = throttle(evalcb, 5)); #took me ~ @show accuracy(X,Y) @show accuracy(test_X, test_Y); - From e37f995e150c99700dbc2578dff8ac80db41cada Mon Sep 17 00:00:00 2001 From: Akshay Date: Thu, 6 May 2021 15:22:12 +0530 Subject: [PATCH 3/5] avoid jump warning --- examples/custom-relu-mnist.jl | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/examples/custom-relu-mnist.jl b/examples/custom-relu-mnist.jl index 0317c890b..9f07476f9 100644 --- a/examples/custom-relu-mnist.jl +++ b/examples/custom-relu-mnist.jl @@ -22,8 +22,8 @@ test_Y = onehotbatch(Flux.Data.MNIST.labels(:test), 0:9) X = convert(Array{Float16,2}, X) test_X = convert(Array{Float16,2}, test_X) -X = X[:, 1:1000] -Y = Y[:, 1:1000]; +X = X[:, 1:10000] +Y = Y[:, 1:10000]; """ relu method for a Matrix @@ -36,13 +36,14 @@ function myRelu(y::AbstractMatrix{T}; model = Model(() -> diff_optimizer(OSQP.Op empty!(model) set_silent(model) @variable(model, x[1:N] >= 0) + @objective( + model, + Min, + x'x -2x'y[:, 1] + ) for i in 1:size(y, 2) - @objective( - model, - Min, - x'x -2x'y[:, i] - ) + set_objective_coefficient.(model, x, -2y[:, i]) optimize!(model) x̂[:, i] = value.(x) end From 0db97ee16601dc36eff8d133e4445106b69ce90b Mon Sep 17 00:00:00 2001 From: Akshay Sharma Date: Thu, 6 May 2021 06:18:31 -0700 Subject: [PATCH 4/5] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Mathieu Besançon --- examples/custom-relu-mnist.jl | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/examples/custom-relu-mnist.jl b/examples/custom-relu-mnist.jl index 9f07476f9..bbe5e6121 100644 --- a/examples/custom-relu-mnist.jl +++ b/examples/custom-relu-mnist.jl @@ -26,7 +26,7 @@ X = X[:, 1:10000] Y = Y[:, 1:10000]; """ - relu method for a Matrix +relu method for a Matrix """ function myRelu(y::AbstractMatrix{T}; model = Model(() -> diff_optimizer(OSQP.Optimizer))) where {T} x̂ = zero(y) @@ -85,7 +85,7 @@ m = Chain( Dense(784, 64), myRelu, Dense(64, 10), - softmax + softmax, ) loss(x, y) = crossentropy(m(x), y) @@ -100,4 +100,3 @@ Flux.train!(loss, params(m), dataset, opt, cb = throttle(evalcb, 5)); #took me ~ @show accuracy(X,Y) @show accuracy(test_X, test_Y); - From 6eedeef72c804e2f360b27e840e5b494634fbc62 Mon Sep 17 00:00:00 2001 From: Akshay Date: Thu, 6 May 2021 18:50:46 +0530 Subject: [PATCH 5/5] rename --- examples/custom-relu-mnist.jl | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/examples/custom-relu-mnist.jl b/examples/custom-relu-mnist.jl index bbe5e6121..c47312e79 100644 --- a/examples/custom-relu-mnist.jl +++ b/examples/custom-relu-mnist.jl @@ -26,9 +26,11 @@ X = X[:, 1:10000] Y = Y[:, 1:10000]; """ + matrix_relu(y::AbstractMatrix{T}; model = Model(() -> diff_optimizer(OSQP.Optimizer)))::AbstractMatrix{T} + relu method for a Matrix """ -function myRelu(y::AbstractMatrix{T}; model = Model(() -> diff_optimizer(OSQP.Optimizer))) where {T} +function matrix_relu(y::AbstractMatrix{T}; model = Model(() -> diff_optimizer(OSQP.Optimizer))) where {T} x̂ = zero(y) # model init @@ -50,11 +52,11 @@ function myRelu(y::AbstractMatrix{T}; model = Model(() -> diff_optimizer(OSQP.Op return x̂ end -function ChainRulesCore.rrule(::typeof(myRelu), y::AbstractArray{T}; model = Model(() -> diff_optimizer(OSQP.Optimizer))) where {T} +function ChainRulesCore.rrule(::typeof(matrix_relu), y::AbstractArray{T}; model = Model(() -> diff_optimizer(OSQP.Optimizer))) where {T} - pv = myRelu(y, model=model) + pv = matrix_relu(y, model=model) - function pullback_myRelu(dx) + function pullback_matrix_relu(dx) x = model[:x] dy = zero(dx) @@ -78,12 +80,12 @@ function ChainRulesCore.rrule(::typeof(myRelu), y::AbstractArray{T}; model = Mod return (NO_FIELDS, dy) end - return pv, pullback_myRelu + return pv, pullback_matrix_relu end m = Chain( Dense(784, 64), - myRelu, + matrix_relu, Dense(64, 10), softmax, )