From 0c20f1f71edfdcf044c4b134f95df8a6c989d2e1 Mon Sep 17 00:00:00 2001 From: jwilson Date: Fri, 28 Jul 2023 23:19:02 -0600 Subject: [PATCH] Added simple CuArray patch --- Project.toml | 1 + src/derivative.jl | 7 +++++++ test/cuda.jl | 19 +++++++++++++++++++ 3 files changed, 27 insertions(+) create mode 100644 test/cuda.jl diff --git a/Project.toml b/Project.toml index 7ad1b28c..a41e00e7 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,7 @@ authors = ["Songchen Tan "] version = "0.2.1" [deps] +CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ChainRulesOverloadGeneration = "f51149dc-2911-5acf-81fc-2076a2a81d4f" diff --git a/src/derivative.jl b/src/derivative.jl index ef46a1e0..4b442bdb 100644 --- a/src/derivative.jl +++ b/src/derivative.jl @@ -1,5 +1,6 @@ export derivative +using CUDA """ derivative(f, x::T, order::Int64) @@ -23,6 +24,12 @@ end derivative(f, x, l, Val{order + 1}()) end +# add CUDA support +@inline function derivative(f, x::CuArray, l::CuArray, + order::Int64) +derivative(f, x, l, Val{order + 1}()) +end + @inline function derivative(f, x::T, ::Val{N}) where {T <: Number, N} t = TaylorScalar{T, N}(x, one(x)) return extract_derivative(f(t), N) diff --git a/test/cuda.jl b/test/cuda.jl new file mode 100644 index 00000000..0ebf6699 --- /dev/null +++ b/test/cuda.jl @@ -0,0 +1,19 @@ +using CUDA + +# cpu +f(x,y) = x^2 * y^2 +input_cpu = [1e0,1e0] +derivative(temp -> f(temp[1],temp[2]),input_cpu,[1e0,0e0],2) + +# gpu +f(x,y) = x^2 * y^2 +input_gpu = CuArray([1e0,1e0]) +direction_gpu = CuArray([1e0,0e0]) +derivative(temp -> f(temp[1],temp[2]),input_gpu,direction_gpu,2) + + + + + + +