Skip to content

Added simple CuArray patch#44

Closed
jacob-m-wilson-42 wants to merge 1 commit intoJuliaDiff:mainfrom
jacob-m-wilson-42:CUDA-Support
Closed

Added simple CuArray patch#44
jacob-m-wilson-42 wants to merge 1 commit intoJuliaDiff:mainfrom
jacob-m-wilson-42:CUDA-Support

Conversation

@jacob-m-wilson-42
Copy link

This is my first ever pull request so please feel free to comment on any mistakes.

I wanted to get something up quick, there may be a better solution but I added another method to the derivatives source file to accept any CuArray type. I also added a simple test (which throws scalar index warnings). I also added the CUDA dependency.

Please triple check anything before merging; this is my first commit ever!

@jacob-m-wilson-42
Copy link
Author

jacob-m-wilson-42 commented Jul 29, 2023

Hmmm I didn't change much to the source. Not sure why it broke so many things...

It looks like there might be some incompatible packages, but I'm not experienced enough to say for certain. If nothing else, maybe the updated derivative() file will save someone a little bit of time.

@tansongchen
Copy link
Member

Hi Jacob, sorry for the late reply and thanks for your contribution. Sure, we should support CUDA data types, but as a generic AD library, it would be better if we don't explicitly write out very specific types like CuArray -- maybe some more abstract ones would also work?

Plus, someone just relaxed the input data type to AbstractArray{T, 1} in this PR #45 . Could you try the most recent version on the main branch?

@tansongchen
Copy link
Member

Hi Jacob, I just updated the main branch. Now the signature looks like

@inline function derivative(f, x::AbstractVector{T}, l::AbstractVector{S},
    order::Int64) where {T <: Number, S <: Number}
    derivative(f, x, l, Val{order + 1}())
end

Here x and l are allowed to be very different types, as long as they are both some AbstractVectors and S can be converted to T. I am not familiar with CuArrays but I believe this can solve your problem. Could you try the latest branch?

@jacob-m-wilson-42
Copy link
Author

Hello Songchen, also sorry for the late response. I will give it a shot when I get a chance and let you know how it goes! I think relaxing the type is probably better like you suggested.

@jacob-m-wilson-42
Copy link
Author

Songchen,

I tested the main branch code with the below

v, direction = CuArray([0f0, 0f0]), CuArray([1.0f0, 0.0f0])
derivative(x -> sum(exp.(x)), v, direction, 2) # directional derivative

and it worked! Excellent work, thanks so much for the addition! Sorry it took so long for me to see this.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants