diff --git a/.proj.toml b/.proj.toml index 01ae36eddd..8898cda5d5 100644 --- a/.proj.toml +++ b/.proj.toml @@ -11,8 +11,9 @@ build_targets = [ # "substitutions", # "compiler", "substitution-generator", - "local-execution", + "local-execution", ] + test_targets = [ "utils-tests", "op-attrs-tests", diff --git a/flake.lock b/flake.lock index 2d1157ba40..6dce7855cb 100644 --- a/flake.lock +++ b/flake.lock @@ -81,4 +81,4 @@ }, "root": "root", "version": 7 -} +} \ No newline at end of file diff --git a/flake.nix b/flake.nix index 3599304ed0..38d6740b81 100644 --- a/flake.nix +++ b/flake.nix @@ -152,4 +152,4 @@ }; } ); -} +} \ No newline at end of file diff --git a/lib/kernels/CMakeLists.txt b/lib/kernels/CMakeLists.txt index b2b81c85bd..f166dd027c 100644 --- a/lib/kernels/CMakeLists.txt +++ b/lib/kernels/CMakeLists.txt @@ -7,6 +7,7 @@ file(GLOB_RECURSE SRC CONFIGURE_DEPENDS LIST_DIRECTORIES False src/*.cc + src/cuda/cuda_helper.cu src/cuda/ops/*.cu ) @@ -28,6 +29,7 @@ target_link_libraries( cuda cudnn nccl + utils ) define_ff_vars(${project_target}) @@ -37,3 +39,5 @@ set_target_properties( PROPERTIES CUDA_STANDARD 17 ) + +add_subdirectory(test) \ No newline at end of file diff --git a/lib/kernels/include/kernels/accessor.h b/lib/kernels/include/kernels/accessor.h index c65c2befb8..1ef121fb2a 100644 --- a/lib/kernels/include/kernels/accessor.h +++ b/lib/kernels/include/kernels/accessor.h @@ -142,6 +142,9 @@ std::vector const *> return out; } +GenericTensorAccessorR read_only_accessor_from_write_accessor( + GenericTensorAccessorW const &write_accessor); + } // namespace FlexFlow namespace FlexFlow { diff --git a/lib/kernels/include/kernels/array_shape.h b/lib/kernels/include/kernels/array_shape.h index 6d6f5bf260..5427d25bc3 100644 --- a/lib/kernels/include/kernels/array_shape.h +++ b/lib/kernels/include/kernels/array_shape.h @@ -41,8 +41,6 @@ struct ArrayShape { std::optional at_maybe(std::size_t) const; - ArrayShape reversed_dim_order() const; - ArrayShape 
sub_shape(std::optional> start, std::optional> end) const; diff --git a/lib/kernels/include/kernels/attention_kernels.h b/lib/kernels/include/kernels/attention_kernels.h index 3f6f0daabc..de37b4169f 100644 --- a/lib/kernels/include/kernels/attention_kernels.h +++ b/lib/kernels/include/kernels/attention_kernels.h @@ -3,6 +3,7 @@ #include "device.h" #include "kernels/allocation.h" +#include "kernels/device.h" #include "kernels/ff_handle.h" #include "op-attrs/ops/attention.h" #include diff --git a/lib/kernels/include/kernels/conv_2d_kernels.h b/lib/kernels/include/kernels/conv_2d_kernels.h index 0a93125367..cfc64f963d 100644 --- a/lib/kernels/include/kernels/conv_2d_kernels.h +++ b/lib/kernels/include/kernels/conv_2d_kernels.h @@ -46,7 +46,7 @@ Conv2DPerDeviceState init_kernel(PerDeviceFFHandle handle, int padding_w, int stride_h, int stride_w, - GenericTensorAccessorR const &input, + GenericTensorAccessorW const &input, GenericTensorAccessorW const &output, float const *filter_ptr, float *filter_grad_ptr); diff --git a/lib/kernels/include/kernels/device.h b/lib/kernels/include/kernels/device.h index 439937177a..c4e78821dc 100644 --- a/lib/kernels/include/kernels/device.h +++ b/lib/kernels/include/kernels/device.h @@ -26,9 +26,12 @@ #include #include +namespace FlexFlow { +cudaError_t get_legion_stream(cudaStream_t *stream); +} // namespace FlexFlow + #if defined(FF_USE_CUDA) || defined(FF_USE_HIP_CUDA) typedef cudaStream_t ffStream_t; -cudaError_t get_legion_stream(cudaStream_t *stream); typedef cudnnTensorDescriptor_t ffTensorDescriptor_t; typedef cudnnActivationDescriptor_t ffActivationDescriptor_t; typedef cudnnPoolingDescriptor_t ffPoolingDescriptor_t; @@ -96,7 +99,8 @@ using coord_t = long long; do { \ std::stringstream _error; \ if (status != 0) { \ - _error << "Cuda failure: " << status; \ + _error << "CUDA failure: " << cudaGetErrorString(status) << " (" \ + << status << ")"; \ FatalError(_error.str()); \ } \ } while (0) diff --git 
a/lib/kernels/include/kernels/element_unary_kernels.h b/lib/kernels/include/kernels/element_unary_kernels.h index 50b20cb80d..8c6864b2d9 100644 --- a/lib/kernels/include/kernels/element_unary_kernels.h +++ b/lib/kernels/include/kernels/element_unary_kernels.h @@ -29,14 +29,14 @@ ElementUnaryPerDeviceState init_kernel(ArrayShape const &input_shape, void forward_kernel(ffStream_t stream, ElementUnaryPerDeviceState const &device_state, ElementUnaryAttrs const &attrs, - PerDeviceFFHandle &handle, + PerDeviceFFHandle const &handle, GenericTensorAccessorR const &input, GenericTensorAccessorW const &output); void backward_kernel(ffStream_t stream, ElementUnaryPerDeviceState const &device_state, ElementUnaryAttrs const &attrs, - PerDeviceFFHandle &handle, + PerDeviceFFHandle const &handle, GenericTensorAccessorR const &input, GenericTensorAccessorW const &input_grad, GenericTensorAccessorR const &output, diff --git a/lib/kernels/include/kernels/layer_norm_kernels.h b/lib/kernels/include/kernels/layer_norm_kernels.h index 52b450d3f5..be13d32879 100644 --- a/lib/kernels/include/kernels/layer_norm_kernels.h +++ b/lib/kernels/include/kernels/layer_norm_kernels.h @@ -34,8 +34,8 @@ namespace Kernels { namespace LayerNorm { // todo: this may have some problem. 
-LayerNormPerDeviceState init_kernel(PerDeviceFFHandle const &, - Allocator const &, +LayerNormPerDeviceState init_kernel(PerDeviceFFHandle const &handle, + Allocator &allocator, bool elementwise_affine, int64_t effective_batch_size, int64_t effective_num_elements, diff --git a/lib/kernels/include/kernels/legion_dim.h b/lib/kernels/include/kernels/legion_dim.h index cf6ebfc2d4..d8ffd91489 100644 --- a/lib/kernels/include/kernels/legion_dim.h +++ b/lib/kernels/include/kernels/legion_dim.h @@ -6,7 +6,7 @@ namespace FlexFlow { -legion_dim_t add_to_legion_dim(legion_dim_t, int); +legion_dim_t add_to_legion_dim(legion_dim_t legion_dim, int value); legion_dim_t legion_dim_from_ff_dim(ff_dim_t, int num_dimensions); diff --git a/lib/kernels/include/kernels/linear_kernels.h b/lib/kernels/include/kernels/linear_kernels.h index dc7f09a02a..c761eaf1d9 100644 --- a/lib/kernels/include/kernels/linear_kernels.h +++ b/lib/kernels/include/kernels/linear_kernels.h @@ -38,6 +38,7 @@ namespace Linear { LinearPerDeviceState init_kernel(PerDeviceFFHandle handle, float *one_ptr, + std::optional activation, std::optional regularizer, bool use_bias, DataType input_type, @@ -57,6 +58,7 @@ void forward_kernel(ffStream_t stream, int in_dim, int out_dim, int batch_size); + void backward_kernel(ffStream_t stream, LinearPerDeviceState const &m, void const *input_ptr, diff --git a/lib/kernels/include/kernels/local_cuda_allocator.h b/lib/kernels/include/kernels/local_cuda_allocator.h new file mode 100644 index 0000000000..18a4b6e78a --- /dev/null +++ b/lib/kernels/include/kernels/local_cuda_allocator.h @@ -0,0 +1,22 @@ +#include "kernels/allocation.h" +#include + +namespace FlexFlow { + +struct LocalCudaAllocator : public IAllocator { + LocalCudaAllocator() = default; + LocalCudaAllocator(LocalCudaAllocator const &) = delete; + LocalCudaAllocator(LocalCudaAllocator &&) = delete; + ~LocalCudaAllocator() override; + + void *allocate(size_t) override; + void deallocate(void *) override; + +private: + 
std::unordered_set ptrs; +}; +CHECK_RC_COPY_VIRTUAL_COMPLIANT(LocalCudaAllocator); + +Allocator create_local_cuda_memory_allocator(); + +} // namespace FlexFlow diff --git a/lib/kernels/include/kernels/managed_ff_stream.h b/lib/kernels/include/kernels/managed_ff_stream.h new file mode 100644 index 0000000000..2f690b2eb3 --- /dev/null +++ b/lib/kernels/include/kernels/managed_ff_stream.h @@ -0,0 +1,28 @@ +#ifndef _FLEXFLOW_KERNELS_MANAGED_FF_STREAM_H +#define _FLEXFLOW_KERNELS_MANAGED_FF_STREAM_H + +#include "device.h" + +namespace FlexFlow { + +struct ManagedFFStream { +public: + ManagedFFStream(); + + ManagedFFStream(ManagedFFStream const &) = delete; + ManagedFFStream &operator=(ManagedFFStream const &) = delete; + + ManagedFFStream(ManagedFFStream &&other) noexcept; + ManagedFFStream &operator=(ManagedFFStream &&other) noexcept; + + ~ManagedFFStream(); + + ffStream_t const &raw_stream() const; + +private: + ffStream_t *stream; +}; + +} // namespace FlexFlow + +#endif diff --git a/lib/kernels/include/kernels/managed_per_device_ff_handle.h b/lib/kernels/include/kernels/managed_per_device_ff_handle.h new file mode 100644 index 0000000000..0a83a5eecb --- /dev/null +++ b/lib/kernels/include/kernels/managed_per_device_ff_handle.h @@ -0,0 +1,30 @@ +#ifndef _FLEXFLOW_KERNELS_MANAGED_HANDLE_H +#define _FLEXFLOW_KERNELS_MANAGED_HANDLE_H + +#include "kernels/ff_handle.h" + +namespace FlexFlow { + +struct ManagedPerDeviceFFHandle { +public: + ManagedPerDeviceFFHandle(); + + ManagedPerDeviceFFHandle(ManagedPerDeviceFFHandle const &) = delete; + ManagedPerDeviceFFHandle & + operator=(ManagedPerDeviceFFHandle const &) = delete; + + ManagedPerDeviceFFHandle(ManagedPerDeviceFFHandle &&other) noexcept; + ManagedPerDeviceFFHandle & + operator=(ManagedPerDeviceFFHandle &&other) noexcept; + + ~ManagedPerDeviceFFHandle(); + + PerDeviceFFHandle const &raw_handle() const; + +private: + PerDeviceFFHandle *handle; +}; + +} // namespace FlexFlow + +#endif diff --git 
a/lib/kernels/include/kernels/reduce_kernels.h b/lib/kernels/include/kernels/reduce_kernels.h index 56241b73ce..4287472875 100644 --- a/lib/kernels/include/kernels/reduce_kernels.h +++ b/lib/kernels/include/kernels/reduce_kernels.h @@ -31,8 +31,8 @@ namespace Reduce { ReducePerDeviceState init_kernel(PerDeviceFFHandle const &, OperatorType const &, size_t const &, - ArrayShape input_shape, - ArrayShape output_shape); + ArrayShape const &input_shape, + ArrayShape const &output_shape); void forward_kernel(ffStream_t stream, ReducePerDeviceState const &m, diff --git a/lib/kernels/include/kernels/replicate_kernels.h b/lib/kernels/include/kernels/replicate_kernels.h index 30d7bc5d90..409fc81f44 100644 --- a/lib/kernels/include/kernels/replicate_kernels.h +++ b/lib/kernels/include/kernels/replicate_kernels.h @@ -13,8 +13,8 @@ void forward_kernel(ffStream_t stream, GenericTensorAccessorW const &output); void backward_kernel(ffStream_t stream, - GenericTensorAccessorR const &input, - GenericTensorAccessorW const &output, + GenericTensorAccessorW const &input, + GenericTensorAccessorR const &output, size_t num_replicas); } // namespace Replicate diff --git a/lib/kernels/include/kernels/softmax_kernels.h b/lib/kernels/include/kernels/softmax_kernels.h index 9831e55589..061230ec52 100644 --- a/lib/kernels/include/kernels/softmax_kernels.h +++ b/lib/kernels/include/kernels/softmax_kernels.h @@ -18,12 +18,18 @@ FF_VISITABLE_STRUCT(SoftmaxPerDeviceState, handle, inputTensor, dim); namespace Kernels { namespace Softmax { -SoftmaxPerDeviceState init_kernel(PerDeviceFFHandle const &, int); +SoftmaxPerDeviceState init_kernel(PerDeviceFFHandle const &handle, + int dim, + int input_n, + int input_c, + int input_h, + int input_w); void forward_kernel(ffStream_t stream, SoftmaxPerDeviceState const &m, float const *input_ptr, float *output_ptr); + void backward_kernel(ffStream_t stream, float *input_grad_ptr, float const *output_grad_ptr, diff --git 
a/lib/kernels/include/kernels/transpose_kernels.h b/lib/kernels/include/kernels/transpose_kernels.h index fa087fada3..56da81ba2b 100644 --- a/lib/kernels/include/kernels/transpose_kernels.h +++ b/lib/kernels/include/kernels/transpose_kernels.h @@ -2,6 +2,7 @@ #define _FLEXFLOW_OPS_KERNELS_TRANSPOSE_KERNELS_H #include "device.h" +#include "kernels/accessor.h" #include namespace FlexFlow { diff --git a/lib/kernels/src/accessor.cc b/lib/kernels/src/accessor.cc index f4ee2580d3..56002718b1 100644 --- a/lib/kernels/src/accessor.cc +++ b/lib/kernels/src/accessor.cc @@ -2,6 +2,46 @@ namespace FlexFlow { +int32_t *GenericTensorAccessorW::get_int32_ptr() const { + return this->get(); +} + +int64_t *GenericTensorAccessorW::get_int64_ptr() const { + return this->get(); +} + +float *GenericTensorAccessorW::get_float_ptr() const { + return this->get(); +} + +double *GenericTensorAccessorW::get_double_ptr() const { + return this->get(); +} + +half *GenericTensorAccessorW::get_half_ptr() const { + return this->get(); +} + +int32_t const *GenericTensorAccessorR::get_int32_ptr() const { + return this->get(); +} + +int64_t const *GenericTensorAccessorR::get_int64_ptr() const { + return this->get(); +} + +float const *GenericTensorAccessorR::get_float_ptr() const { + return this->get(); +} + +double const *GenericTensorAccessorR::get_double_ptr() const { + return this->get(); +} + +half const *GenericTensorAccessorR::get_half_ptr() const { + return get(); +} + int32_t *get_int32_ptr(GenericTensorAccessorW const &a) { return get(a); } @@ -92,4 +132,10 @@ std::vector return get(a); } +GenericTensorAccessorR read_only_accessor_from_write_accessor( + GenericTensorAccessorW const &writable) { + return GenericTensorAccessorR{ + writable.data_type, writable.shape, req(writable.ptr)}; +} + } // namespace FlexFlow diff --git a/lib/kernels/src/array_shape.cc b/lib/kernels/src/array_shape.cc index 44507c14c4..5410726e0a 100644 --- a/lib/kernels/src/array_shape.cc +++ 
b/lib/kernels/src/array_shape.cc @@ -3,11 +3,61 @@ namespace FlexFlow { +static LegionTensorDims + legion_dims_from_ff_dims(FFOrdered const &ff_ordered) { + std::vector sizes(ff_ordered.size()); + std::reverse_copy(ff_ordered.begin(), ff_ordered.end(), sizes.begin()); + return LegionTensorDims(sizes.begin(), sizes.end()); +} + ArrayShape::ArrayShape(size_t *_dims, size_t num_dims) : dims(_dims, _dims + num_dims) {} +ArrayShape::ArrayShape(TensorShape const &shape) + : dims(legion_dims_from_ff_dims(shape.dims.ff_ordered)) {} + +ArrayShape::ArrayShape(std::vector const &input_dims) + : dims(input_dims) {} + std::size_t ArrayShape::get_volume() const { + return this->num_elements(); +} + +std::size_t ArrayShape::num_dims() const { + return this->dims.size(); +} + +std::size_t ArrayShape::get_dim() const { + return this->num_dims(); +} + +std::size_t ArrayShape::num_elements() const { + if (dims.size() == 0) { + return 0; + } return product(this->dims); } +std::size_t ArrayShape::operator[](legion_dim_t idx) const { + return dims[idx]; +} + +ArrayShape ArrayShape::sub_shape( + std::optional> start, + std::optional> end) const { + NOT_IMPLEMENTED(); +} + +std::optional ArrayShape::at_maybe(std::size_t index) const { + if (index < dims.size()) { + return dims.at(legion_dim_t(index)); + } else { + return std::nullopt; + } +} + +size_t get_volume(ArrayShape const &shape) { + return shape.get_volume(); +} + } // namespace FlexFlow diff --git a/lib/kernels/src/cpu/initializer_kernels.cc b/lib/kernels/src/cpu/initializer_kernels.cc index 0ba04304e1..f3b4c9b8fd 100644 --- a/lib/kernels/src/cpu/initializer_kernels.cc +++ b/lib/kernels/src/cpu/initializer_kernels.cc @@ -1,6 +1,7 @@ #include "kernels/initializer_kernels.h" #include "kernels/accessor.h" #include "kernels/datatype_dispatch.h" +#include "kernels/device.h" namespace FlexFlow { @@ -44,4 +45,8 @@ void zero_init_kernel(TaskLocation const &loc, } } +void zero_init_kernel_gpu(GenericTensorAccessorW const &tensor) { + 
NOT_IMPLEMENTED(); +} + } // namespace FlexFlow diff --git a/lib/kernels/src/cuda/cuda_helper.cu b/lib/kernels/src/cuda/cuda_helper.cu index 316b8ed9ec..2b46ef890a 100644 --- a/lib/kernels/src/cuda/cuda_helper.cu +++ b/lib/kernels/src/cuda/cuda_helper.cu @@ -1,5 +1,5 @@ -#include "flexflow/model.h" -#include "flexflow/utils/cuda_helper.h" +#include "device.h" +#include "kernels/datatype_dispatch.h" namespace FlexFlow { @@ -45,7 +45,7 @@ __global__ void ones_kernel(float *ptr, coord_t size) { } template -__global__ void assign_kernel(DT *ptr, coord_t size, DT value) { +__global__ void assign_kernel(DT *ptr, size_t size, DT value) { CUDA_KERNEL_LOOP(i, size) { ptr[i] = value; } @@ -70,11 +70,11 @@ __host__ void relu_backward_kernel(DataType data_type, void const *output_ptr, size_t output_size, cudaStream_t stream) { - if (data_type == DT_FLOAT) { + if (data_type == DataType::FLOAT) { reluBackward <<>>( (float *)output_grad_ptr, (float const *)output_ptr, output_size); - } else if (data_type == DT_DOUBLE) { + } else if (data_type == DataType::DOUBLE) { reluBackward <<>>( (double *)output_grad_ptr, (double const *)output_ptr, output_size); @@ -97,11 +97,11 @@ __host__ void sigmoid_backward_kernel(DataType data_type, void const *output_ptr, size_t output_size, cudaStream_t stream) { - if (data_type == DT_FLOAT) { + if (data_type == DataType::FLOAT) { sigmoid_backward_function <<>>( (float *)output_grad_ptr, (float const *)output_ptr, output_size); - } else if (data_type == DT_DOUBLE) { + } else if (data_type == DataType::DOUBLE) { sigmoid_backward_function <<>>( (double *)output_grad_ptr, (double const *)output_ptr, output_size); @@ -220,14 +220,16 @@ __host__ void checkCUDA(cudaFreeHost(host_ptr)); } -cudnnStatus_t +ffStatus_t cudnnSetTensorDescriptorFromArrayShape(cudnnTensorDescriptor_t tensor, ArrayShape const &shape) { - ArrayShape flipped = shape.reversed_dim_order(); + std::vector reversed_dims(shape.dims.begin(), shape.dims.end()); + reversed(reversed_dims); + 
ArrayShape flipped(reversed_dims); if (flipped.get_dim() == 5) { - assert(flipped[0] == 1); - flipped = flipped.sub_shape(1, std::nullopt); + assert(flipped[legion_dim_t(0)] == 1); + flipped = flipped.sub_shape(legion_dim_t(1), std::nullopt); } assert(flipped.get_dim() > 0); @@ -244,11 +246,11 @@ cudnnStatus_t cudnnDataType_t ff_to_cudnn_datatype(DataType type) { switch (type) { - case DT_FLOAT: + case DataType::FLOAT: return CUDNN_DATA_FLOAT; - case DT_DOUBLE: + case DataType::DOUBLE: return CUDNN_DATA_DOUBLE; - case DT_INT32: + case DataType::INT32: return CUDNN_DATA_INT32; default: assert(false && "Unsupported cudnn data type"); @@ -258,11 +260,11 @@ cudnnDataType_t ff_to_cudnn_datatype(DataType type) { cudaDataType_t ff_to_cuda_datatype(DataType type) { switch (type) { - case DT_FLOAT: + case DataType::FLOAT: return CUDA_R_32F; - case DT_DOUBLE: + case DataType::DOUBLE: return CUDA_R_64F; - case DT_INT32: + case DataType::INT32: return CUDA_R_32I; default: assert(false && "Unspoorted cuda data type"); @@ -271,15 +273,15 @@ cudaDataType_t ff_to_cuda_datatype(DataType type) { } template __global__ void - assign_kernel(half *ptr, coord_t size, half value); + assign_kernel(half *ptr, size_t size, half value); template __global__ void - assign_kernel(float *ptr, coord_t size, float value); + assign_kernel(float *ptr, size_t size, float value); template __global__ void - assign_kernel(double *ptr, coord_t size, double value); + assign_kernel(double *ptr, size_t size, double value); template __global__ void - assign_kernel(int32_t *ptr, coord_t size, int32_t value); + assign_kernel(int32_t *ptr, size_t size, int32_t value); template __global__ void - assign_kernel(int64_t *ptr, coord_t size, int64_t value); + assign_kernel(int64_t *ptr, size_t size, int64_t value); template __global__ void add_kernel(float *dst, float const *src, size_t size); @@ -289,6 +291,8 @@ template __global__ void add_kernel(int32_t *dst, int32_t const *src, size_t size); template __global__ 
void add_kernel(int64_t *dst, int64_t const *src, size_t size); +template __global__ void + add_kernel(bool *dst, bool const *src, unsigned long size); template __global__ void copy_kernel(float *dst, float const *src, coord_t size); @@ -314,6 +318,11 @@ template __global__ void apply_add_with_scale(int64_t *data_ptr, size_t size, int64_t scale); +template __global__ void apply_add_with_scale(bool *data_ptr, + bool const *grad_ptr, + unsigned long size, + bool scale); + template __host__ void print_tensor(float const *ptr, size_t rect, char const *prefix); template __host__ void diff --git a/lib/kernels/src/cuda/ops/attention_kernels.cu b/lib/kernels/src/cuda/ops/attention_kernels.cu index 57809f043b..e50f3983cc 100644 --- a/lib/kernels/src/cuda/ops/attention_kernels.cu +++ b/lib/kernels/src/cuda/ops/attention_kernels.cu @@ -15,6 +15,7 @@ #include "device.h" #include "kernels/attention_kernels.h" +#include "kernels/device.h" namespace FlexFlow { namespace Kernels { @@ -41,11 +42,9 @@ MHAPerDeviceState init_kernel(PerDeviceFFHandle const &handle, ffSeqDataDescriptor_t vDesc; ffSeqDataDescriptor_t oDesc; void *reserveSpace; - void *dropoutStates; int *devQoSeqArray; int *devKvSeqArray; size_t reserveSpaceSize; - size_t dropoutStateSize; size_t weightSize; checkCUDA(get_legion_stream(&stream)); @@ -301,8 +300,8 @@ void backward_kernel(cudaStream_t stream, void cleanup_kernel(Allocator &allocator, MHAPerDeviceState const &device_state) { - allocator.deallocate(device_state.loWinIdx); - allocator.deallocate(device_state.hiWinIdx); + free(device_state.loWinIdx); + free(device_state.hiWinIdx); checkCUDNN(cudnnDestroyAttnDescriptor(device_state.attnDesc)); checkCUDNN(cudnnDestroySeqDataDescriptor(device_state.qDesc)); checkCUDNN(cudnnDestroySeqDataDescriptor(device_state.kDesc)); diff --git a/lib/kernels/src/cuda/ops/batch_matmul_kernels.cu b/lib/kernels/src/cuda/ops/batch_matmul_kernels.cu index bdf1e0fe0c..eb23514c5f 100644 --- 
a/lib/kernels/src/cuda/ops/batch_matmul_kernels.cu +++ b/lib/kernels/src/cuda/ops/batch_matmul_kernels.cu @@ -32,7 +32,7 @@ void forward_kernel(cudaStream_t stream, int a_seq_length_dim, int b_seq_length_dim, int seq_length) { - checkCUDA(cublasSetStream(handle.blas, stream)); + checkCUBLAS(cublasSetStream(handle.blas, stream)); checkCUDNN(cudnnSetStream(handle.dnn, stream)); int lda = k; int ldb = m; @@ -63,24 +63,24 @@ void forward_kernel(cudaStream_t stream, } float alpha = 1.0f, beta = 0.0f; - checkCUDA(cublasSgemmStridedBatched(handle.blas, - CUBLAS_OP_N, - CUBLAS_OP_N, - m, - n, - k, - &alpha, - b_input_ptr, - ldb, - strideB, - a_input_ptr, - lda, - strideA, - &beta, - output_ptr, - ldo, - strideO, - batch)); + checkCUBLAS(cublasSgemmStridedBatched(handle.blas, + CUBLAS_OP_N, + CUBLAS_OP_N, + m, + n, + k, + &alpha, + b_input_ptr, + ldb, + strideB, + a_input_ptr, + lda, + strideA, + &beta, + output_ptr, + ldo, + strideO, + batch)); } void backward_kernel(cudaStream_t stream, @@ -95,49 +95,49 @@ void backward_kernel(cudaStream_t stream, int n, int k, int batch) { - checkCUDA(cublasSetStream(handle.blas, stream)); + checkCUBLAS(cublasSetStream(handle.blas, stream)); checkCUDNN(cudnnSetStream(handle.dnn, stream)); int a_stride = n * k; int b_stride = m * k; int o_stride = n * m; float alpha = 1.0f; - checkCUDA(cublasSgemmStridedBatched(handle.blas, - CUBLAS_OP_T, - CUBLAS_OP_N, - k, - n, - m, - &alpha, - b_ptr, - m, - b_stride, - o_grad_ptr, - m, - o_stride, - &alpha, - a_grad_ptr, - k, - a_stride, - batch)); - checkCUDA(cublasSgemmStridedBatched(handle.blas, - CUBLAS_OP_N, - CUBLAS_OP_T, - m, - k, - n, - &alpha, - o_grad_ptr, - m, - o_stride, - a_ptr, - k, - a_stride, - &alpha, - b_grad_ptr, - m, - b_stride, - batch)); + checkCUBLAS(cublasSgemmStridedBatched(handle.blas, + CUBLAS_OP_T, + CUBLAS_OP_N, + k, + n, + m, + &alpha, + b_ptr, + m, + b_stride, + o_grad_ptr, + m, + o_stride, + &alpha, + a_grad_ptr, + k, + a_stride, + batch)); + 
checkCUBLAS(cublasSgemmStridedBatched(handle.blas, + CUBLAS_OP_N, + CUBLAS_OP_T, + m, + k, + n, + &alpha, + o_grad_ptr, + m, + o_stride, + a_ptr, + k, + a_stride, + &alpha, + b_grad_ptr, + m, + b_stride, + batch)); } } // namespace BatchMatmul diff --git a/lib/kernels/src/cuda/ops/batch_norm_kernels.cu b/lib/kernels/src/cuda/ops/batch_norm_kernels.cu index 6f08001965..6c6e17a181 100644 --- a/lib/kernels/src/cuda/ops/batch_norm_kernels.cu +++ b/lib/kernels/src/cuda/ops/batch_norm_kernels.cu @@ -17,6 +17,7 @@ #include "kernels/allocation.h" #include "kernels/batch_norm_kernels.h" #include "kernels/ff_handle.h" +#include "utils/integer_conversions.h" namespace FlexFlow { namespace Kernels { @@ -108,8 +109,6 @@ BatchNormPerDeviceState init_kernel(PerDeviceFFHandle handle, #if CUDNN_VERSION >= 7000 mode = CUDNN_BATCHNORM_SPATIAL_PERSISTENT; #endif - fprintf( - stderr, "output(%d,%d,%d,%d)\n", output_n, output_c, output_h, output_w); checkCUDNN(cudnnSetTensor4dDescriptor(inputTensor, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, @@ -133,11 +132,12 @@ BatchNormPerDeviceState init_kernel(PerDeviceFFHandle handle, float *saveMean = (float *)runningVar + output_c; float *saveVar = (float *)saveMean + output_c; cudaStream_t stream; + checkCUDA(get_legion_stream(&stream)); assign_kernel<<>>( - runningMean, output_c, 0.0f); + runningMean, size_t_from_int(output_c), 0.0f); assign_kernel<<>>( - runningVar, output_c, 0.0f); + runningVar, size_t_from_int(output_c), 0.0f); if (relu) { checkCUDNN(cudnnCreateActivationDescriptor(&actiDesc)); @@ -160,6 +160,8 @@ BatchNormPerDeviceState init_kernel(PerDeviceFFHandle handle, output_h, output_w, relu}; + + checkCUDA(cudaStreamDestroy(stream)); return per_device_state; } diff --git a/lib/kernels/src/cuda/ops/conv_2d_kernels.cu b/lib/kernels/src/cuda/ops/conv_2d_kernels.cu index 462e8a294b..e3a4c97a31 100644 --- a/lib/kernels/src/cuda/ops/conv_2d_kernels.cu +++ b/lib/kernels/src/cuda/ops/conv_2d_kernels.cu @@ -207,46 +207,47 @@ Conv2DPerDeviceState 
init_kernel(PerDeviceFFHandle handle, checkCUDNN(cudnnSetTensor4dDescriptor( outputTensor, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, n, c, h, w)); - float time; // select forward algorithm - fwdAlgo = selectConvolutionForwardAlgorithm(handle.dnn, - inputTensor, - input.get_float_ptr(), - filterDesc, - filter_ptr, - convDesc, - handle.workSpace, - handle.workSpaceSize, - outputTensor, - output.get_float_ptr(), - nullptr); + fwdAlgo = selectConvolutionForwardAlgorithm( + handle.dnn, + inputTensor, + static_cast(input.get_float_ptr()), + filterDesc, + filter_ptr, + convDesc, + handle.workSpace, + handle.workSpaceSize, + outputTensor, + output.get_float_ptr(), + nullptr); // select backward filter algorithm - bwdFilterAlgo = - selectConvolutionBackwardFilterAlgorithm(handle.dnn, - inputTensor, - input.get_float_ptr(), - outputTensor, - output.get_float_ptr(), - convDesc, - handle.workSpace, - handle.workSpaceSize, - filterDesc, - filter_grad_ptr, - nullptr); + bwdFilterAlgo = selectConvolutionBackwardFilterAlgorithm( + handle.dnn, + inputTensor, + static_cast(input.get_float_ptr()), + outputTensor, + output.get_float_ptr(), + convDesc, + handle.workSpace, + handle.workSpaceSize, + filterDesc, + filter_grad_ptr, + nullptr); // select backward data algorithm - bwdDataAlgo = selectConvolutionBackwardDataAlgorithm(handle.dnn, - filterDesc, - filter_ptr, - outputTensor, - output.get_float_ptr(), - convDesc, - handle.workSpace, - handle.workSpaceSize, - inputTensor, - input.get_float_ptr(), - nullptr); + bwdDataAlgo = selectConvolutionBackwardDataAlgorithm( + handle.dnn, + filterDesc, + filter_ptr, + outputTensor, + output.get_float_ptr(), + convDesc, + handle.workSpace, + handle.workSpaceSize, + inputTensor, + static_cast(input.get_float_ptr()), + nullptr); if (activation.has_value()) { checkCUDNN(cudnnSetActivationDescriptor( actiDesc, CUDNN_ACTIVATION_RELU, CUDNN_PROPAGATE_NAN, 0.0)); @@ -265,7 +266,7 @@ Conv2DPerDeviceState init_kernel(PerDeviceFFHandle handle, return 
per_device_state; } -void forward_kernel(cudaStream_t stream, +void forward_kernel(ffStream_t stream, Conv2DPerDeviceState const &m, float const *input_ptr, float *output_ptr, @@ -310,7 +311,7 @@ void forward_kernel(cudaStream_t stream, } } -void backward_kernel(cudaStream_t stream, +void backward_kernel(ffStream_t stream, Conv2DPerDeviceState const &m, float const *input_ptr, float *input_grad_ptr, diff --git a/lib/kernels/src/cuda/ops/dropout_kernels.cu b/lib/kernels/src/cuda/ops/dropout_kernels.cu index 746656f409..adf0cd8e89 100644 --- a/lib/kernels/src/cuda/ops/dropout_kernels.cu +++ b/lib/kernels/src/cuda/ops/dropout_kernels.cu @@ -24,7 +24,7 @@ namespace Dropout { DropoutPerDeviceState init_kernel(PerDeviceFFHandle handle, float rate, unsigned long long seed, - ArrayShape output_shape, + ArrayShape const &output_shape, Allocator allocator) { ffTensorDescriptor_t inputTensor; ffTensorDescriptor_t outputTensor; diff --git a/lib/kernels/src/cuda/ops/element_binary_kernels.cu b/lib/kernels/src/cuda/ops/element_binary_kernels.cu index 45b4d43006..44273a323f 100644 --- a/lib/kernels/src/cuda/ops/element_binary_kernels.cu +++ b/lib/kernels/src/cuda/ops/element_binary_kernels.cu @@ -146,7 +146,7 @@ void forward_kernel(cudaStream_t stream, OperatorType op_type, bool broadcast_inputLHS, PerDeviceFFHandle handle) { - checkCUDA(cublasSetStream(handle.blas, stream)); + checkCUBLAS(cublasSetStream(handle.blas, stream)); checkCUDNN(cudnnSetStream(handle.dnn, stream)); float alpha1 = 1.0f, alpha2 = 1.0f, beta = 0.0f; switch (op_type) { @@ -253,7 +253,7 @@ void backward_kernel(cudaStream_t stream, bool broadcast_inputLHS, bool broadcast_inputRHS, PerDeviceFFHandle handle) { - checkCUDA(cublasSetStream(handle.blas, stream)); + checkCUBLAS(cublasSetStream(handle.blas, stream)); checkCUDNN(cudnnSetStream(handle.dnn, stream)); if (op_type == OperatorType::EW_ADD || op_type == OperatorType::EW_SUB) { diff --git a/lib/kernels/src/cuda/ops/gather_kernels.cu 
b/lib/kernels/src/cuda/ops/gather_kernels.cu index e002cf7e71..11c0a1a5e7 100644 --- a/lib/kernels/src/cuda/ops/gather_kernels.cu +++ b/lib/kernels/src/cuda/ops/gather_kernels.cu @@ -15,6 +15,7 @@ #include "device.h" #include "kernels/datatype_dispatch.h" +#include "kernels/device.h" #include "kernels/gather_kernels.h" namespace FlexFlow { diff --git a/lib/kernels/src/cuda/ops/linear_kernels.cu b/lib/kernels/src/cuda/ops/linear_kernels.cu index 9a36534a1b..ca51f0d216 100644 --- a/lib/kernels/src/cuda/ops/linear_kernels.cu +++ b/lib/kernels/src/cuda/ops/linear_kernels.cu @@ -16,6 +16,7 @@ #include "device.h" #include "kernels/allocation.h" #include "kernels/linear_kernels.h" +#include "utils/integer_conversions.h" namespace FlexFlow { @@ -115,7 +116,7 @@ void forward_kernel(cudaStream_t stream, int out_dim, int batch_size) { - checkCUDA(cublasSetStream(m.handle.blas, stream)); + checkCUBLAS(cublasSetStream(m.handle.blas, stream)); checkCUDNN(cudnnSetStream(m.handle.dnn, stream)); float alpha = 1.0f, beta = 0.0f; cudaDataType_t input_type = ff_to_cuda_datatype(m.input_type); @@ -127,46 +128,46 @@ void forward_kernel(cudaStream_t stream, #else cudaDataType_t compute_type = CUDA_R_32F; #endif - checkCUDA(cublasGemmEx(m.handle.blas, - CUBLAS_OP_T, - CUBLAS_OP_N, - out_dim, - batch_size, - in_dim, - &alpha, - weight_ptr, - weight_type, - in_dim, - input_ptr, - input_type, - in_dim, - &beta, - output_ptr, - output_type, - out_dim, - compute_type, - CUBLAS_GEMM_DEFAULT_TENSOR_OP)); - // use_bias = True - if (bias_ptr != NULL) { - checkCUDA(cublasGemmEx(m.handle.blas, + checkCUBLAS(cublasGemmEx(m.handle.blas, CUBLAS_OP_T, CUBLAS_OP_N, out_dim, batch_size, - 1, + in_dim, &alpha, - bias_ptr, + weight_ptr, weight_type, - 1, - m.one_ptr, - CUDA_R_32F, - 1, - &alpha, + in_dim, + input_ptr, + input_type, + in_dim, + &beta, output_ptr, output_type, out_dim, compute_type, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + // use_bias = True + if (bias_ptr != NULL) { + 
checkCUBLAS(cublasGemmEx(m.handle.blas, + CUBLAS_OP_T, + CUBLAS_OP_N, + out_dim, + batch_size, + 1, + &alpha, + bias_ptr, + weight_type, + 1, + m.one_ptr, + CUDA_R_32F, + 1, + &alpha, + output_ptr, + output_type, + out_dim, + compute_type, + CUBLAS_GEMM_DEFAULT_TENSOR_OP)); } if (use_activation(m.activation)) { checkCUDNN(cudnnActivationForward(m.handle.dnn, @@ -178,7 +179,7 @@ void forward_kernel(cudaStream_t stream, m.outputTensor, output_ptr)); } else if (m.activation == Activation::GELU) { - size_t elements = (size_t)out_dim * (size_t)batch_size; + size_t elements = size_t_from_int(out_dim) * size_t_from_int(batch_size); constexpr float B = 0.7978845608028654f; // sqrt(2.0/M_PI) constexpr float C = 0.035677408136300125f; // 0.044715 * sqrt(2.0/M_PI) gelu_forward_kernel<<>>( @@ -200,10 +201,8 @@ void backward_kernel(cudaStream_t stream, int in_dim, int out_dim, int batch_size) { - - checkCUDA(cublasSetStream(m.handle.blas, stream)); + checkCUBLAS(cublasSetStream(m.handle.blas, stream)); checkCUDNN(cudnnSetStream(m.handle.dnn, stream)); - float alpha = 1.0f; cudaDataType_t input_type = ff_to_cuda_datatype(m.input_type); cudaDataType_t weight_type = ff_to_cuda_datatype(m.weight_type); @@ -229,25 +228,25 @@ void backward_kernel(cudaStream_t stream, } // Compute weight gradiant // NOTE: we use alpha=1 for kernel_grad to accumulate gradients - checkCUDA(cublasGemmEx(m.handle.blas, - CUBLAS_OP_N, - CUBLAS_OP_T, - in_dim, - out_dim, - batch_size, - &alpha, - input_ptr, - input_type, - in_dim, - output_grad_ptr, - output_type, - out_dim, - &alpha, - kernel_grad_ptr, - weight_type, - in_dim, - compute_type, - CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + checkCUBLAS(cublasGemmEx(m.handle.blas, + CUBLAS_OP_N, + CUBLAS_OP_T, + in_dim, + out_dim, + batch_size, + &alpha, + input_ptr, + input_type, + in_dim, + output_grad_ptr, + output_type, + out_dim, + &alpha, + kernel_grad_ptr, + weight_type, + in_dim, + compute_type, + CUBLAS_GEMM_DEFAULT_TENSOR_OP)); if (m.regularizer == 
std::nullopt) { // do nothing @@ -256,19 +255,19 @@ void backward_kernel(cudaStream_t stream, if (regularizer_attrs.has()) { L2RegularizerAttrs l2_attrs = regularizer_attrs.get(); float lambda = l2_attrs.lambda; - checkCUDA(cublasSgeam(m.handle.blas, - CUBLAS_OP_N, - CUBLAS_OP_N, - in_dim, - out_dim, - &alpha, - (float *)kernel_grad_ptr, - in_dim, - &lambda, - (float *)kernel_ptr, - in_dim, - (float *)kernel_grad_ptr, - in_dim)); + checkCUBLAS(cublasSgeam(m.handle.blas, + CUBLAS_OP_N, + CUBLAS_OP_N, + in_dim, + out_dim, + &alpha, + (float *)kernel_grad_ptr, + in_dim, + &lambda, + (float *)kernel_ptr, + in_dim, + (float *)kernel_grad_ptr, + in_dim)); } else { assert(false && "Only L2 regularization is supported"); } @@ -278,48 +277,48 @@ void backward_kernel(cudaStream_t stream, // NOTE: we use alpha=1 for bias_grad to accumulate gradients // use_bias = True if (bias_grad_ptr != NULL) { - checkCUDA(cublasGemmEx(m.handle.blas, - CUBLAS_OP_N, - CUBLAS_OP_T, - 1, - out_dim, - batch_size, - &alpha, - m.one_ptr, - CUDA_R_32F, - 1, - output_grad_ptr, - output_type, - out_dim, - &alpha, - bias_grad_ptr, - weight_type, - 1, - compute_type, - CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + checkCUBLAS(cublasGemmEx(m.handle.blas, + CUBLAS_OP_N, + CUBLAS_OP_T, + 1, + out_dim, + batch_size, + &alpha, + m.one_ptr, + CUDA_R_32F, + 1, + output_grad_ptr, + output_type, + out_dim, + &alpha, + bias_grad_ptr, + weight_type, + 1, + compute_type, + CUBLAS_GEMM_DEFAULT_TENSOR_OP)); } // Compute data gradiant // NOTE: we use alpha=1 for input_grad to accumulate gradients if (input_grad_ptr != NULL) { - checkCUDA(cublasGemmEx(m.handle.blas, - CUBLAS_OP_N, - CUBLAS_OP_N, - in_dim, - batch_size, - out_dim, - &alpha, - kernel_ptr, - weight_type, - in_dim, - output_grad_ptr, - output_type, - out_dim, - &alpha, - input_grad_ptr, - input_type, - in_dim, - compute_type, - CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + checkCUBLAS(cublasGemmEx(m.handle.blas, + CUBLAS_OP_N, + CUBLAS_OP_N, + in_dim, + batch_size, + out_dim, 
+ &alpha, + kernel_ptr, + weight_type, + in_dim, + output_grad_ptr, + output_type, + out_dim, + &alpha, + input_grad_ptr, + input_type, + in_dim, + compute_type, + CUBLAS_GEMM_DEFAULT_TENSOR_OP)); } } diff --git a/lib/kernels/src/cuda/ops/partition_kernels.cu b/lib/kernels/src/cuda/ops/partition_kernels.cu index 24f16f903e..e356f83d2a 100644 --- a/lib/kernels/src/cuda/ops/partition_kernels.cu +++ b/lib/kernels/src/cuda/ops/partition_kernels.cu @@ -39,8 +39,8 @@ template struct BackwardKernel { void operator()(cudaStream_t stream, RepartitionPerDeviceState const &m, - GenericTensorAccessorR const &output_grad, - GenericTensorAccessorW const &input_grad) { + GenericTensorAccessorW const &input_grad, + GenericTensorAccessorR const &output_grad) { add_kernel><<{}( - m.data_type, stream, m, output_grad, input_grad); + m.data_type, stream, m, input_grad, output_grad); } } // namespace Repartition diff --git a/lib/kernels/src/cuda/ops/replicate_kernels.cu b/lib/kernels/src/cuda/ops/replicate_kernels.cu index 2787f78916..0c87418f58 100644 --- a/lib/kernels/src/cuda/ops/replicate_kernels.cu +++ b/lib/kernels/src/cuda/ops/replicate_kernels.cu @@ -22,13 +22,13 @@ namespace Kernels { namespace Replicate { template -__global__ void replicate_backward_kernel(T const *input_ptr, - T *output_ptr, +__global__ void replicate_backward_kernel(T *input_ptr, + T const *output_ptr, size_t num_elements, size_t num_replicas) { CUDA_KERNEL_LOOP(i, num_elements) { for (size_t j = 0; j < num_replicas; j++) { - output_ptr[i] += input_ptr[i + j * num_elements]; + input_ptr[i] += output_ptr[i + j * num_elements]; } } } @@ -39,8 +39,8 @@ struct ForwardKernel { GenericTensorAccessorR const &input, GenericTensorAccessorW const &output) { - checkCUDA(cudaMemcpyAsync((void *)input.get(), - (void *)output.get(), + checkCUDA(cudaMemcpyAsync((void *)output.get(), + (void *)input.get(), input.shape.num_elements() * size_of_datatype(T), cudaMemcpyDeviceToDevice, stream)); @@ -50,8 +50,8 @@ struct 
ForwardKernel { template struct BackwardKernel { void operator()(cudaStream_t stream, - GenericTensorAccessorR const &input, - GenericTensorAccessorW const &output, + GenericTensorAccessorW const &input, + GenericTensorAccessorR const &output, size_t num_replicas) { size_t total_elements = input.shape.num_elements() * num_replicas; replicate_backward_kernel> @@ -70,8 +70,8 @@ void forward_kernel(cudaStream_t stream, } void backward_kernel(cudaStream_t stream, - GenericTensorAccessorR const &input, - GenericTensorAccessorW const &output, + GenericTensorAccessorW const &input, + GenericTensorAccessorR const &output, size_t num_replicas) { DataTypeDispatch1{}( input.data_type, stream, input, output, num_replicas); diff --git a/lib/kernels/src/cuda/ops/softmax_kernels.cu b/lib/kernels/src/cuda/ops/softmax_kernels.cu index 34f29243d3..93ed85de18 100644 --- a/lib/kernels/src/cuda/ops/softmax_kernels.cu +++ b/lib/kernels/src/cuda/ops/softmax_kernels.cu @@ -21,10 +21,22 @@ namespace FlexFlow { namespace Kernels { namespace Softmax { -SoftmaxPerDeviceState init_kernel(PerDeviceFFHandle const &handle, int dim) { +SoftmaxPerDeviceState init_kernel(PerDeviceFFHandle const &handle, + int dim, + int input_n, + int input_c, + int input_h, + int input_w) { ffTensorDescriptor_t inputTensor; checkCUDNN(cudnnCreateTensorDescriptor(&inputTensor)); + checkCUDNN(cudnnSetTensor4dDescriptor(inputTensor, + CUDNN_TENSOR_NCHW, + CUDNN_DATA_FLOAT, + input_n, + input_c, + input_h, + input_w)); SoftmaxPerDeviceState per_device_state = {handle, inputTensor, dim}; return per_device_state; diff --git a/lib/kernels/src/device.h b/lib/kernels/src/device.h index 173cd14557..96670f712f 100644 --- a/lib/kernels/src/device.h +++ b/lib/kernels/src/device.h @@ -17,7 +17,8 @@ #if defined(FF_USE_CUDA) || defined(FF_USE_HIP_CUDA) #define FF_CUDNN_STATUS_SUCCESS CUDNN_STATUS_SUCCESS -#define FF_CURAND_STATUS_SUCESS CURAND_STATUS_SUCCESS +#define FF_CURAND_STATUS_SUCCESS CURAND_STATUS_SUCCESS +#define 
FF_CUBLAS_STATUS_SUCCESS CUBLAS_STATUS_SUCCESS #elif defined(FF_USE_HIP_ROCM) #define FF_CUDNN_STATUS_SUCCESS miopenStatusSuccess #define FF_CURAND_STATUS_SUCESS HIPRAND_STATUS_SUCCESS @@ -40,12 +41,21 @@ using ::FlexFlow::OperatorType; #define checkCURAND(status) \ do { \ std::stringstream _error; \ - if (status != FF_CURAND_STATUS_SUCESS) { \ + if (status != FF_CURAND_STATUS_SUCCESS) { \ _error << "CURAND failure: " << status; \ FatalError(_error.str()); \ } \ } while (0) +#define checkCUBLAS(status) \ + do { \ + std::stringstream _error; \ + if (status != FF_CUBLAS_STATUS_SUCCESS) { \ + _error << "CUBLAS failure: " << status; \ + FatalError(_error.str()); \ + } \ + } while (0) + // CUDA: grid stride looping #define CUDA_KERNEL_LOOP(i, n) \ for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ diff --git a/lib/kernels/src/legion_dim.cc b/lib/kernels/src/legion_dim.cc new file mode 100644 index 0000000000..9ef47d40ae --- /dev/null +++ b/lib/kernels/src/legion_dim.cc @@ -0,0 +1,13 @@ +#include "kernels/legion_dim.h" + +namespace FlexFlow { + +legion_dim_t add_to_legion_dim(legion_dim_t legion_dim, int value) { + return legion_dim_t(legion_dim.value + value); +} + +legion_dim_t legion_dim_from_ff_dim(ff_dim_t ff_dim, int num_dimensions) { + return legion_dim_t(num_dimensions - ff_dim.value - 1); +} + +} // namespace FlexFlow diff --git a/lib/kernels/src/local_cuda_allocator.cc b/lib/kernels/src/local_cuda_allocator.cc new file mode 100644 index 0000000000..931e81c0b8 --- /dev/null +++ b/lib/kernels/src/local_cuda_allocator.cc @@ -0,0 +1,32 @@ +#include "kernels/local_cuda_allocator.h" +#include "kernels/device.h" + +namespace FlexFlow { +void *LocalCudaAllocator::allocate(size_t requested_memory_size) { + void *ptr; + checkCUDA(cudaMalloc(&ptr, requested_memory_size)); + this->ptrs.insert(ptr); + return ptr; +} + +void LocalCudaAllocator::deallocate(void *ptr) { + if (contains(this->ptrs, ptr)) { + checkCUDA(cudaFree(ptr)); + this->ptrs.erase(ptr); + } 
else { + throw std::runtime_error( + "Deallocating a pointer that was not allocated by this Allocator"); + } +} + +LocalCudaAllocator::~LocalCudaAllocator() { + for (auto ptr : ptrs) { + checkCUDA(cudaFree(ptr)); + } +} + +Allocator create_local_cuda_memory_allocator() { + return Allocator::create(); +} + +} // namespace FlexFlow diff --git a/lib/kernels/src/managed_ff_stream.cc b/lib/kernels/src/managed_ff_stream.cc new file mode 100644 index 0000000000..7385b6cc3e --- /dev/null +++ b/lib/kernels/src/managed_ff_stream.cc @@ -0,0 +1,28 @@ +#include "kernels/managed_ff_stream.h" + +namespace FlexFlow { + +ManagedFFStream::ManagedFFStream() : stream(new ffStream_t) { + checkCUDA(cudaStreamCreate(stream)); +} + +ManagedFFStream::ManagedFFStream(ManagedFFStream &&other) noexcept + : stream(std::exchange(other.stream, nullptr)) {} + +ManagedFFStream &ManagedFFStream::operator=(ManagedFFStream &&other) noexcept { + std::swap(this->stream, other.stream); + return *this; +} + +ManagedFFStream::~ManagedFFStream() { + if (stream != nullptr) { + checkCUDA(cudaStreamDestroy(*stream)); + delete stream; + } +} + +ffStream_t const &ManagedFFStream::raw_stream() const { + return *stream; +} + +} // namespace FlexFlow diff --git a/lib/kernels/src/managed_per_device_ff_handle.cc b/lib/kernels/src/managed_per_device_ff_handle.cc new file mode 100644 index 0000000000..c050e887b6 --- /dev/null +++ b/lib/kernels/src/managed_per_device_ff_handle.cc @@ -0,0 +1,39 @@ +#include "kernels/managed_per_device_ff_handle.h" +#include "device.h" + +namespace FlexFlow { + +ManagedPerDeviceFFHandle::ManagedPerDeviceFFHandle() { + handle = new PerDeviceFFHandle; + handle->workSpaceSize = 1024 * 1024; + handle->allowTensorOpMathConversion = true; + + checkCUDNN(cudnnCreate(&handle->dnn)); + checkCUBLAS(cublasCreate(&handle->blas)); + checkCUDA(cudaMalloc(&handle->workSpace, handle->workSpaceSize)); +} + +ManagedPerDeviceFFHandle::ManagedPerDeviceFFHandle( + ManagedPerDeviceFFHandle &&other) noexcept + 
: handle(std::exchange(other.handle, nullptr)) {} + +ManagedPerDeviceFFHandle &ManagedPerDeviceFFHandle::operator=( + ManagedPerDeviceFFHandle &&other) noexcept { + std::swap(this->handle, other.handle); + return *this; +} + +ManagedPerDeviceFFHandle::~ManagedPerDeviceFFHandle() { + if (handle != nullptr) { + checkCUDNN(cudnnDestroy(handle->dnn)); + checkCUBLAS(cublasDestroy(handle->blas)); + checkCUDA(cudaFree(handle->workSpace)); + delete handle; + } +} + +PerDeviceFFHandle const &ManagedPerDeviceFFHandle::raw_handle() const { + return *handle; +} + +} // namespace FlexFlow diff --git a/lib/kernels/test/CMakeLists.txt b/lib/kernels/test/CMakeLists.txt new file mode 100644 index 0000000000..007740b510 --- /dev/null +++ b/lib/kernels/test/CMakeLists.txt @@ -0,0 +1,17 @@ +ff_add_test_executable( + NAME + kernels-tests + SRC_PATTERNS + src/*.cc + PRIVATE_INCLUDE + src/ + DEPS + doctest + utils-test-common + kernels + op-attrs + cuda + cudnn + cudart + cublas +) diff --git a/lib/kernels/test/src/test_attention_kernel.cc b/lib/kernels/test/src/test_attention_kernel.cc new file mode 100644 index 0000000000..d44129ece1 --- /dev/null +++ b/lib/kernels/test/src/test_attention_kernel.cc @@ -0,0 +1,102 @@ +#include "doctest/doctest.h" +#include "kernels/attention_kernels.h" +#include "test_utils.h" + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("Test multi-head attention kernel") { + size_t num_samples = 10; + size_t num_heads = 4; + size_t qSize = 64, kSize = 64, vSize = 64; + size_t qProjSize = 64, kProjSize = 64, vProjSize = 64, oProjSize = 64; + size_t qoSeqLength = 20, kvSeqLength = 20; + + ManagedFFStream managed_stream{}; + ManagedPerDeviceFFHandle managed_handle{}; + + Allocator allocator = create_local_cuda_memory_allocator(); + + MHAPerDeviceState state = + Kernels::MultiHeadAttention::init_kernel(managed_handle.raw_handle(), + allocator, + num_samples, + num_heads, + qSize, + kSize, + vSize, + qProjSize, + kProjSize, + vProjSize, + 
oProjSize, + qoSeqLength, + kvSeqLength, + false); + + TensorShape query_shape = make_float_tensor_shape_from_legion_dims( + {qoSeqLength, num_samples, qSize}); + TensorShape key_shape = make_float_tensor_shape_from_legion_dims( + {kvSeqLength, num_samples, kSize}); + TensorShape value_shape = make_float_tensor_shape_from_legion_dims( + {kvSeqLength, num_samples, vSize}); + TensorShape output_shape = make_float_tensor_shape_from_legion_dims( + {qoSeqLength, num_samples, oProjSize}); + TensorShape weight_shape = + make_float_tensor_shape_from_legion_dims({state.weightSize}); + + GenericTensorAccessorW query_accessor = + create_random_filled_accessor_w(query_shape, allocator); + GenericTensorAccessorW key_accessor = + create_random_filled_accessor_w(key_shape, allocator); + GenericTensorAccessorW value_accessor = + create_random_filled_accessor_w(value_shape, allocator); + GenericTensorAccessorW weight_accessor = + create_random_filled_accessor_w(weight_shape, allocator); + + SUBCASE("forward_kernel") { + GenericTensorAccessorW output_accessor = + allocator.allocate_tensor(output_shape); + + Kernels::MultiHeadAttention::forward_kernel( + managed_stream.raw_stream(), + state, + query_accessor.get_float_ptr(), + key_accessor.get_float_ptr(), + value_accessor.get_float_ptr(), + weight_accessor.get_float_ptr(), + output_accessor.get_float_ptr()); + + std::vector host_output = load_data_to_host_from_device( + read_only_accessor_from_write_accessor(output_accessor)); + CHECK(contains_non_zero(host_output)); + } + + SUBCASE("backward_kernel") { + GenericTensorAccessorW query_grad_accessor = + create_random_filled_accessor_w(query_shape, allocator); + GenericTensorAccessorW key_grad_accessor = + create_random_filled_accessor_w(key_shape, allocator); + GenericTensorAccessorW value_grad_accessor = + create_random_filled_accessor_w(value_shape, allocator); + GenericTensorAccessorW weight_grad_accessor = + create_random_filled_accessor_w(weight_shape, allocator); + 
GenericTensorAccessorW output_grad_accessor = + create_random_filled_accessor_w(output_shape, allocator); + + Kernels::MultiHeadAttention::backward_kernel( + managed_stream.raw_stream(), + state, + query_accessor.get_float_ptr(), + query_grad_accessor.get_float_ptr(), + key_accessor.get_float_ptr(), + key_grad_accessor.get_float_ptr(), + value_accessor.get_float_ptr(), + value_grad_accessor.get_float_ptr(), + weight_accessor.get_float_ptr(), + weight_grad_accessor.get_float_ptr(), + output_grad_accessor.get_float_ptr()); + } + + Kernels::MultiHeadAttention::cleanup_kernel(allocator, state); + } +} diff --git a/lib/kernels/test/src/test_batch_matmul_kernel.cc b/lib/kernels/test/src/test_batch_matmul_kernel.cc new file mode 100644 index 0000000000..18e6977148 --- /dev/null +++ b/lib/kernels/test/src/test_batch_matmul_kernel.cc @@ -0,0 +1,73 @@ +#include "doctest/doctest.h" +#include "kernels/batch_matmul_kernels.h" +#include "test_utils.h" + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("Test BatchMatmul Kernel") { + size_t m = 10; + size_t n = 10; + size_t k = 10; + size_t batch = 5; + size_t a_seq_length_dim = -1; + size_t b_seq_length_dim = -1; + size_t seq_length = -1; + + ManagedFFStream managed_stream{}; + ManagedPerDeviceFFHandle managed_handle{}; + + Allocator allocator = create_local_cuda_memory_allocator(); + + TensorShape input_shape_a = + make_float_tensor_shape_from_legion_dims({m, k, batch}); + TensorShape input_shape_b = + make_float_tensor_shape_from_legion_dims({k, n, batch}); + TensorShape output_shape = + make_float_tensor_shape_from_legion_dims({m, n, batch}); + + GenericTensorAccessorW a_accessor = + create_random_filled_accessor_w(input_shape_a, allocator); + GenericTensorAccessorW b_accessor = + create_random_filled_accessor_w(input_shape_b, allocator); + GenericTensorAccessorW output_accessor = + create_random_filled_accessor_w(output_shape, allocator); + + SUBCASE("forward_kernel") { + 
Kernels::BatchMatmul::forward_kernel(managed_stream.raw_stream(), + managed_handle.raw_handle(), + output_accessor.get_float_ptr(), + a_accessor.get_float_ptr(), + b_accessor.get_float_ptr(), + m, + n, + k, + batch, + a_seq_length_dim, + b_seq_length_dim, + seq_length); + } + + SUBCASE("backward_kernel") { + GenericTensorAccessorW o_grad_accessor = + create_random_filled_accessor_w(output_shape, allocator); + GenericTensorAccessorW a_grad_accessor = + allocator.allocate_tensor(input_shape_a); + GenericTensorAccessorW b_grad_accessor = + allocator.allocate_tensor(input_shape_b); + + Kernels::BatchMatmul::backward_kernel(managed_stream.raw_stream(), + managed_handle.raw_handle(), + output_accessor.get_float_ptr(), + o_grad_accessor.get_float_ptr(), + a_accessor.get_float_ptr(), + a_grad_accessor.get_float_ptr(), + b_accessor.get_float_ptr(), + b_grad_accessor.get_float_ptr(), + m, + n, + k, + batch); + } + } +} diff --git a/lib/kernels/test/src/test_batch_norm_kernel.cc b/lib/kernels/test/src/test_batch_norm_kernel.cc new file mode 100644 index 0000000000..8487bbda6a --- /dev/null +++ b/lib/kernels/test/src/test_batch_norm_kernel.cc @@ -0,0 +1,103 @@ +#include "doctest/doctest.h" +#include "kernels/batch_norm_kernels.h" +#include "test_utils.h" + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("Test BatchNorm Kernel") { + size_t output_n = 1, output_c = 10, output_h = 10, output_w = 10; + + ManagedFFStream managed_stream{}; + ManagedPerDeviceFFHandle managed_handle{}; + + Allocator allocator = create_local_cuda_memory_allocator(); + + BatchNormPerDeviceState state = + Kernels::BatchNorm::init_kernel(managed_handle.raw_handle(), + allocator, + nullptr, + output_n, + output_c, + output_h, + output_w, + true); + + TensorShape input_shape = make_float_tensor_shape_from_legion_dims( + {output_n, output_c, output_h, output_w}); + TensorShape output_shape = make_float_tensor_shape_from_legion_dims( + {output_n, output_c, output_h, output_w}); + 
TensorShape scale_shape = make_float_tensor_shape_from_legion_dims( + {output_n, output_c, output_h, output_w}); + TensorShape bias_shape = make_float_tensor_shape_from_legion_dims( + {output_n, output_c, output_h, output_w}); + + GenericTensorAccessorW input_accessor = + create_random_filled_accessor_w(input_shape, allocator); + GenericTensorAccessorW output_accessor = + create_random_filled_accessor_w(output_shape, allocator); + GenericTensorAccessorW scale_accessor = + create_filled_accessor_w(scale_shape, allocator, 1.0f); + + SUBCASE("forward_kernel") { + GenericTensorAccessorW bias_accessor = + create_filled_accessor_w(bias_shape, allocator, 0.0f); + + Kernels::BatchNorm::forward_kernel(managed_stream.raw_stream(), + state, + input_accessor.get_float_ptr(), + output_accessor.get_float_ptr(), + scale_accessor.get_float_ptr(), + bias_accessor.get_float_ptr()); + + std::vector host_output_data = + load_data_to_host_from_device( + read_only_accessor_from_write_accessor(output_accessor)); + CHECK(contains_non_zero(host_output_data)); + } + + SUBCASE("backward_kernel") { + GenericTensorAccessorW output_grad_accessor = + create_random_filled_accessor_w(output_shape, allocator); + GenericTensorAccessorW input_grad_accessor = + create_random_filled_accessor_w(input_shape, allocator); + GenericTensorAccessorW scale_grad_accessor = + create_random_filled_accessor_w(scale_shape, allocator); + GenericTensorAccessorW bias_grad_accessor = + create_random_filled_accessor_w(bias_shape, allocator); + + Kernels::BatchNorm::backward_kernel(managed_stream.raw_stream(), + state, + input_accessor.get_float_ptr(), + output_grad_accessor.get_float_ptr(), + output_accessor.get_float_ptr(), + input_grad_accessor.get_float_ptr(), + scale_accessor.get_float_ptr(), + scale_grad_accessor.get_float_ptr(), + bias_grad_accessor.get_float_ptr(), + input_accessor.shape.num_elements()); + + std::vector host_input_grad_data = + load_data_to_host_from_device( + 
read_only_accessor_from_write_accessor(input_grad_accessor)); + std::vector host_scale_grad_data = + load_data_to_host_from_device( + read_only_accessor_from_write_accessor(scale_grad_accessor)); + std::vector host_bias_grad_data = + load_data_to_host_from_device( + read_only_accessor_from_write_accessor(bias_grad_accessor)); + + CHECK(contains_non_zero(host_input_grad_data)); + CHECK(contains_non_zero(host_scale_grad_data)); + CHECK(contains_non_zero(host_bias_grad_data)); + } + + Kernels::BatchNorm::cleanup_kernel(allocator, + state.inputTensor, + state.biasTensor, + state.outputTensor, + state.actiDesc, + true, + state.runningMean); + } +} diff --git a/lib/kernels/test/src/test_cast_kernel.cc b/lib/kernels/test/src/test_cast_kernel.cc new file mode 100644 index 0000000000..004bc9c32f --- /dev/null +++ b/lib/kernels/test/src/test_cast_kernel.cc @@ -0,0 +1,57 @@ +#include "doctest/doctest.h" +#include "kernels/cast_kernels.h" +#include "kernels/cast_kernels_cpu.h" +#include "test_utils.h" +#include + +using namespace ::FlexFlow; +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("Call Cast Forward and Backward Kernels") { + ManagedFFStream managed_stream{}; + + Allocator allocator = create_local_cuda_memory_allocator(); + + TensorShape input_shape = + make_float_tensor_shape_from_legion_dims({100, 100}); + TensorShape output_shape = + make_double_tensor_shape_from_legion_dims({100, 100}); + + GenericTensorAccessorW output_accessor = + create_random_filled_accessor_w(output_shape, allocator); + + SUBCASE("forward_kernel") { + GenericTensorAccessorR input_accessor = + read_only_accessor_from_write_accessor( + create_random_filled_accessor_w(input_shape, allocator)); + + Kernels::Cast::forward_kernel(managed_stream.raw_stream(), + input_accessor, + output_accessor, + DataType::FLOAT, + DataType::DOUBLE); + + std::vector host_double_data = + load_data_to_host_from_device( + read_only_accessor_from_write_accessor(output_accessor)); + + CHECK(contains_non_zero(host_double_data)); 
+ } + + SUBCASE("backward_kernel") { + GenericTensorAccessorW grad_input_accessor = + allocator.allocate_tensor(input_shape); + + Kernels::Cast::backward_kernel( + managed_stream.raw_stream(), + read_only_accessor_from_write_accessor(output_accessor), + grad_input_accessor, + DataType::DOUBLE, + DataType::FLOAT); + + std::vector host_grad_float_data = + load_data_to_host_from_device( + read_only_accessor_from_write_accessor(grad_input_accessor)); + CHECK(contains_non_zero(host_grad_float_data)); + } + } +} diff --git a/lib/kernels/test/src/test_combine_kernel.cc b/lib/kernels/test/src/test_combine_kernel.cc new file mode 100644 index 0000000000..2e1000cb95 --- /dev/null +++ b/lib/kernels/test/src/test_combine_kernel.cc @@ -0,0 +1,49 @@ +#include "doctest/doctest.h" +#include "kernels/combine_kernels.h" +#include "test_utils.h" + +using namespace ::FlexFlow; +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("Test combine kernel") { + ManagedPerDeviceFFHandle managed_handle{}; + ManagedFFStream managed_stream{}; + + Allocator allocator = create_local_cuda_memory_allocator(); + + TensorShape input_shape = + make_float_tensor_shape_from_legion_dims({100, 100}); + TensorShape output_shape = input_shape; + + SUBCASE("forward_kernel") { + GenericTensorAccessorR input_accessor = + read_only_accessor_from_write_accessor( + create_random_filled_accessor_w(input_shape, allocator)); + GenericTensorAccessorW output_accessor = + allocator.allocate_tensor(output_shape); + + Kernels::Combine::forward_kernel( + managed_stream.raw_stream(), input_accessor, output_accessor); + + std::vector host_output_data = + load_data_to_host_from_device( + read_only_accessor_from_write_accessor(output_accessor)); + CHECK(contains_non_zero(host_output_data)); + } + + SUBCASE("backward_kernel") { + GenericTensorAccessorR output_grad_accessor = + read_only_accessor_from_write_accessor( + create_random_filled_accessor_w(output_shape, allocator)); + GenericTensorAccessorW input_grad_accessor = + 
allocator.allocate_tensor(input_shape); + + Kernels::Combine::backward_kernel(managed_stream.raw_stream(), + output_grad_accessor, + input_grad_accessor); + + std::vector host_input_grad = load_data_to_host_from_device( + read_only_accessor_from_write_accessor(input_grad_accessor)); + CHECK(contains_non_zero(host_input_grad)); + } + } +} diff --git a/lib/kernels/test/src/test_concat_kernel.cc b/lib/kernels/test/src/test_concat_kernel.cc new file mode 100644 index 0000000000..bf2a521b4e --- /dev/null +++ b/lib/kernels/test/src/test_concat_kernel.cc @@ -0,0 +1,56 @@ +#include "doctest/doctest.h" +#include "kernels/concat_kernels.h" +#include "test_utils.h" + +using namespace ::FlexFlow; +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("Test concat kernel forward and backward") { + size_t num_inputs = 3; + size_t size_per_input = 100; + ff_dim_t concat_axis = ff_dim_t(0); + + ManagedPerDeviceFFHandle managed_handle{}; + ManagedFFStream managed_stream{}; + + TensorShape input_shape = + make_float_tensor_shape_from_legion_dims({size_per_input}); + TensorShape output_shape = + make_float_tensor_shape_from_legion_dims({size_per_input, num_inputs}); + + Allocator allocator = create_local_cuda_memory_allocator(); + + SUBCASE("forward_kernel") { + std::vector input_accessors = + repeat(num_inputs, [&]() { + return read_only_accessor_from_write_accessor( + create_random_filled_accessor_w(input_shape, allocator)); + }); + GenericTensorAccessorW output_accessor = + allocator.allocate_tensor(output_shape); + + Kernels::Concat::forward_kernel(managed_stream.raw_stream(), + output_accessor, + input_accessors, + concat_axis); + + std::vector host_output_data = + load_data_to_host_from_device( + read_only_accessor_from_write_accessor(output_accessor)); + + CHECK(contains_non_zero(host_output_data)); + } + + SUBCASE("backward_kernel") { + GenericTensorAccessorR output_grad_accessor = + read_only_accessor_from_write_accessor( + create_random_filled_accessor_w(output_shape, allocator)); + 
std::vector input_grad_accessors = repeat( + num_inputs, [&]() { return allocator.allocate_tensor(input_shape); }); + + Kernels::Concat::backward_kernel(managed_stream.raw_stream(), + output_grad_accessor, + input_grad_accessors, + concat_axis); + } + } +} diff --git a/lib/kernels/test/src/test_cuda.cc b/lib/kernels/test/src/test_cuda.cc new file mode 100644 index 0000000000..ed5852bc31 --- /dev/null +++ b/lib/kernels/test/src/test_cuda.cc @@ -0,0 +1,32 @@ +#include "doctest/doctest.h" +#include "test_utils.h" + +#include + +namespace FlexFlow { +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("Test CUDA") { + int deviceCount = 0; + + cudaError_t device_error = cudaGetDeviceCount(&deviceCount); + CHECK(device_error == cudaSuccess); + CHECK(deviceCount > 0); + + int driverVersion = 0; + cudaError_t driver_error = cudaDriverGetVersion(&driverVersion); + CHECK(driver_error == cudaSuccess); + CHECK(driverVersion > 0); + + int runtimeVersion = 0; + cudaError_t runtime_error = cudaRuntimeGetVersion(&runtimeVersion); + CHECK(runtime_error == cudaSuccess); + CHECK(runtimeVersion > 0); + + if (device_error == cudaSuccess) { + void *ptr; + checkCUDA(cudaMalloc(&ptr, 1)); + checkCUDA(cudaFree(ptr)); + } + } +} +} // namespace FlexFlow diff --git a/lib/kernels/test/src/test_dropout.cc b/lib/kernels/test/src/test_dropout.cc new file mode 100644 index 0000000000..981bc611d8 --- /dev/null +++ b/lib/kernels/test/src/test_dropout.cc @@ -0,0 +1,69 @@ +#include "doctest/doctest.h" +#include "kernels/dropout_kernels.h" +#include "test_utils.h" +#include "utils/containers.h" + +using namespace ::FlexFlow; +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("Test Dropout Kernels") { + unsigned long long seed = 12345; + float dropout_rate = 0.1; + + ArrayShape shape = ArrayShape{ + std::vector{10, 10}, + }; + + TensorShape input_shape = + make_float_tensor_shape_from_legion_dims({10, 10}); + TensorShape output_shape = input_shape; + + ManagedFFStream managed_stream{}; + ManagedPerDeviceFFHandle 
managed_handle{}; + + Allocator allocator = create_local_cuda_memory_allocator(); + + DropoutPerDeviceState state = Kernels::Dropout::init_kernel( + managed_handle.raw_handle(), dropout_rate, seed, shape, allocator); + + auto get_zero_count = [](std::vector const &data) { + return count(data, [](float x) { return x == 0.0f; }); + }; + + SUBCASE("forward_kernel") { + GenericTensorAccessorR input_accessor = + read_only_accessor_from_write_accessor( + create_random_filled_accessor_w(input_shape, allocator)); + GenericTensorAccessorW output_accessor = + allocator.allocate_tensor(output_shape); + + Kernels::Dropout::forward_kernel(managed_stream.raw_stream(), + state, + input_accessor.get_float_ptr(), + output_accessor.get_float_ptr()); + + std::vector host_output_accessor = + load_data_to_host_from_device( + read_only_accessor_from_write_accessor(output_accessor)); + + CHECK(contains_non_zero(host_output_accessor)); + } + + SUBCASE("backward_kernel") { + GenericTensorAccessorW output_grad_data = + create_random_filled_accessor_w(output_shape, allocator); + GenericTensorAccessorW input_grad_data = + create_random_filled_accessor_w(input_shape, allocator); + + Kernels::Dropout::backward_kernel(managed_stream.raw_stream(), + state, + output_grad_data.get_float_ptr(), + input_grad_data.get_float_ptr()); + } + + Kernels::Dropout::cleanup_kernel(allocator, + state.inputTensor, + state.outputTensor, + state.dropoutDesc, + state.dropoutStates); + } +} diff --git a/lib/kernels/test/src/test_flat_kernel.cc b/lib/kernels/test/src/test_flat_kernel.cc new file mode 100644 index 0000000000..70894858e3 --- /dev/null +++ b/lib/kernels/test/src/test_flat_kernel.cc @@ -0,0 +1,57 @@ +#include "doctest/doctest.h" +#include "kernels/flat_kernels.h" +#include "test_utils.h" + +using namespace ::FlexFlow; +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("Test Flat Kernel") { + Allocator allocator = create_local_cuda_memory_allocator(); + + ManagedPerDeviceFFHandle managed_handle{}; + ManagedFFStream 
managed_stream{}; + + TensorShape input_shape = make_float_tensor_shape_from_legion_dims({100}); + TensorShape output_shape = input_shape; + + GenericTensorAccessorR input_accessor = + read_only_accessor_from_write_accessor( + create_filled_accessor_w(input_shape, allocator, 2.0f)); + + SUBCASE("forward_kernel") { + GenericTensorAccessorW output_accessor = + allocator.allocate_tensor(output_shape); + + Kernels::Flat::forward_kernel(managed_stream.raw_stream(), + input_accessor, + output_accessor.get_float_ptr()); + + std::vector check_output_data = + load_data_to_host_from_device( + read_only_accessor_from_write_accessor(output_accessor)); + + std::vector expected_output_data( + input_accessor.shape.num_elements(), 2.0f); + CHECK(check_output_data == expected_output_data); + } + + SUBCASE("backward_kernel") { + GenericTensorAccessorW output_grad_accessor = + create_filled_accessor_w(output_shape, allocator, 0.0f); + GenericTensorAccessorW input_grad_accessor = + create_filled_accessor_w(input_shape, allocator, 1.0f); + + Kernels::Flat::backward_kernel(managed_stream.raw_stream(), + input_accessor, + input_grad_accessor.get_float_ptr(), + output_grad_accessor.get_float_ptr()); + + std::vector backward_output_data = + load_data_to_host_from_device( + read_only_accessor_from_write_accessor(input_grad_accessor)); + + std::vector expected_output_data( + input_accessor.shape.num_elements(), 1.0f); + CHECK(backward_output_data == expected_output_data); + } + } +} diff --git a/lib/kernels/test/src/test_gather_kernels.cc b/lib/kernels/test/src/test_gather_kernels.cc new file mode 100644 index 0000000000..88ac2f6889 --- /dev/null +++ b/lib/kernels/test/src/test_gather_kernels.cc @@ -0,0 +1,60 @@ +#include "doctest/doctest.h" +#include "kernels/gather_kernels.h" +#include "test_utils.h" + +using namespace ::FlexFlow; +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("Test Gather Forward and Backward Kernel") { + ManagedPerDeviceFFHandle managed_handle{}; + ManagedFFStream 
managed_stream{}; + + Allocator allocator = create_local_cuda_memory_allocator(); + + GatherPerDeviceState state = {managed_handle.raw_handle(), legion_dim_t(2)}; /* NOTE(review): legion_dim_t(2) is used although both shapes below are 1-D; confirm this is intended. */ + + TensorShape input_shape = make_float_tensor_shape_from_legion_dims({100}); + TensorShape output_shape = make_float_tensor_shape_from_legion_dims({50}); + + GenericTensorAccessorR index_accessor = + read_only_accessor_from_write_accessor( + create_random_filled_accessor_w(output_shape, allocator)); /* NOTE(review): the index tensor is a FLOAT tensor filled with random values in [-1, 1); confirm the Gather kernels accept such data as indices. */ + + SUBCASE("forward_kernel") { + GenericTensorAccessorR input_accessor = + read_only_accessor_from_write_accessor( + create_random_filled_accessor_w(input_shape, allocator)); + GenericTensorAccessorW output_accessor = + allocator.allocate_tensor(output_shape); + + Kernels::Gather::forward_kernel(managed_stream.raw_stream(), + state, + input_accessor, + index_accessor, + output_accessor); + + std::vector host_output_data = + load_data_to_host_from_device( + read_only_accessor_from_write_accessor(output_accessor)); + CHECK(contains_non_zero(host_output_data)); + } + + SUBCASE("backward_kernel") { /* NOTE(review): input_grad starts out random-filled, so the non-zero CHECK below would also pass if the kernel wrote nothing. */ + GenericTensorAccessorR output_grad_accessor = + read_only_accessor_from_write_accessor( + create_random_filled_accessor_w(output_shape, allocator)); + GenericTensorAccessorW input_grad_accessor = + create_random_filled_accessor_w(input_shape, allocator); + + Kernels::Gather::backward_kernel(managed_stream.raw_stream(), + state, + output_grad_accessor, + index_accessor, + input_grad_accessor); + + std::vector host_input_grad_data = + load_data_to_host_from_device( + read_only_accessor_from_write_accessor(input_grad_accessor)); + CHECK(contains_non_zero(host_input_grad_data)); + } + } +} diff --git a/lib/kernels/test/src/test_layer_norm_kernels.cc b/lib/kernels/test/src/test_layer_norm_kernels.cc new file mode 100644 index 0000000000..03b2f56bb9 --- /dev/null +++ b/lib/kernels/test/src/test_layer_norm_kernels.cc @@ -0,0 +1,75 @@ +#include "doctest/doctest.h" +#include "kernels/layer_norm_kernels.h" +#include 
"test_utils.h" + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("Test LayerNorm Forward and Backward Kernel") { /* Smoke test: the subcases launch the kernels but contain no CHECK on the results. */ + size_t batch_size = 10; + size_t feature_size = 10; + float epsilon = 1e-5f; + bool elementwise_affine = true; + + TensorShape input_shape = + make_float_tensor_shape_from_legion_dims({batch_size, feature_size}); + TensorShape output_shape = input_shape; + TensorShape feature_shape = + make_float_tensor_shape_from_legion_dims({feature_size}); + + ManagedPerDeviceFFHandle managed_handle{}; + ManagedFFStream managed_stream{}; + + Allocator allocator = create_local_cuda_memory_allocator(); + + LayerNormPerDeviceState state = + Kernels::LayerNorm::init_kernel(managed_handle.raw_handle(), + allocator, + elementwise_affine, + batch_size, + feature_size, + epsilon); + + GenericTensorAccessorR input_accessor = + read_only_accessor_from_write_accessor( + create_random_filled_accessor_w(input_shape, allocator)); + GenericTensorAccessorW gamma_accessor = + create_filled_accessor_w(feature_shape, allocator, 1.0f); /* gamma is all ones; it is shared by both subcases. */ + + SUBCASE("forward_kernel") { + GenericTensorAccessorW output_accessor = + allocator.allocate_tensor(output_shape); + GenericTensorAccessorW beta_accessor = + create_filled_accessor_w(feature_shape, allocator, 0.0f); + + Kernels::LayerNorm::forward_kernel(managed_stream.raw_stream(), + state, + input_accessor, + output_accessor, + gamma_accessor, + beta_accessor); + } + + SUBCASE("backward_kernel") { + GenericTensorAccessorR output_grad_accessor = + read_only_accessor_from_write_accessor( + create_random_filled_accessor_w(output_shape, allocator)); + GenericTensorAccessorW input_grad_accessor = + create_random_filled_accessor_w(input_shape, allocator); + GenericTensorAccessorW gamma_grad_accessor = + allocator.allocate_tensor(feature_shape); + GenericTensorAccessorW beta_grad_accessor = + allocator.allocate_tensor(feature_shape); + + Kernels::LayerNorm::backward_kernel( + managed_stream.raw_stream(), + state, + 
output_grad_accessor, + input_accessor, + input_grad_accessor, + read_only_accessor_from_write_accessor(gamma_accessor), + gamma_grad_accessor, + beta_grad_accessor); + } + } +} diff --git a/lib/kernels/test/src/test_partition_kernel.cc b/lib/kernels/test/src/test_partition_kernel.cc new file mode 100644 index 0000000000..437b37e954 --- /dev/null +++ b/lib/kernels/test/src/test_partition_kernel.cc @@ -0,0 +1,61 @@ +#include "doctest/doctest.h" +#include "kernels/partition_kernels.h" +#include "test_utils.h" + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("Test Partition Forward and Backward") { + ManagedPerDeviceFFHandle managed_handle{}; + ManagedFFStream managed_stream{}; + + Allocator allocator = create_local_cuda_memory_allocator(); + + RepartitionPerDeviceState state = Kernels::Repartition::init_kernel( + managed_handle.raw_handle(), DataType::FLOAT); + + TensorShape input_shape = + make_float_tensor_shape_from_legion_dims({10, 10}); + TensorShape output_shape = input_shape; + + SUBCASE("forward_kernel") { /* Expects the output to equal the 1.0f-filled input element for element. */ + GenericTensorAccessorR input_accessor = + read_only_accessor_from_write_accessor( + create_filled_accessor_w(input_shape, allocator, 1.0f)); + GenericTensorAccessorW output_accessor = + allocator.allocate_tensor(output_shape); + + Kernels::Repartition::forward_kernel( + managed_stream.raw_stream(), state, input_accessor, output_accessor); + + std::vector check_output_data = + load_data_to_host_from_device( + read_only_accessor_from_write_accessor(output_accessor)); + + std::vector expected_output_data( + input_accessor.shape.num_elements(), 1.0f); + CHECK(check_output_data == expected_output_data); + } + + SUBCASE("backward_kernel") { /* output_grad is 1.0f everywhere and input_grad starts at 2.0f. */ + GenericTensorAccessorR output_grad_accessor = + read_only_accessor_from_write_accessor( + create_filled_accessor_w(output_shape, allocator, 1.0f)); + GenericTensorAccessorW input_grad_accessor = + create_filled_accessor_w(input_shape, allocator, 2.0f); + + 
Kernels::Repartition::backward_kernel(managed_stream.raw_stream(), + state, + input_grad_accessor, + output_grad_accessor); + + std::vector host_grad_input_data = + load_data_to_host_from_device( + read_only_accessor_from_write_accessor(input_grad_accessor)); + + std::vector expected_grad_input_data( + input_grad_accessor.shape.num_elements(), 3.0f); /* The test expects 3.0f in every element after the backward call. */ + CHECK(host_grad_input_data == expected_grad_input_data); + } + } +} diff --git a/lib/kernels/test/src/test_pool_2d_kernels.cc b/lib/kernels/test/src/test_pool_2d_kernels.cc new file mode 100644 index 0000000000..ebb92d39db --- /dev/null +++ b/lib/kernels/test/src/test_pool_2d_kernels.cc @@ -0,0 +1,79 @@ +#include "doctest/doctest.h" +#include "kernels/pool_2d_kernels.h" +#include "test_utils.h" + +using namespace ::FlexFlow; +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("Test Pool2D Forward and Backward Kernel") { /* Configuration: 2x2 kernel, stride 2, no padding, 10x10x3x1 input and 5x5x3x1 output, PoolOp::MAX. */ + size_t input_w = 10, input_h = 10, input_c = 3, input_n = 1; + size_t output_w = 5, output_h = 5, output_c = 3, output_n = 1; + size_t pad_h = 0, pad_w = 0, kernel_h = 2, kernel_w = 2, stride_h = 2, + stride_w = 2; + + PoolOp pool_type = PoolOp::MAX; + + ManagedPerDeviceFFHandle managed_handle{}; + ManagedFFStream managed_stream{}; + + Allocator allocator = create_local_cuda_memory_allocator(); + + Pool2DPerDeviceState state = + Kernels::Pool2D::init_kernel(managed_handle.raw_handle(), + std::nullopt, + input_w, + input_h, + input_c, + input_n, + output_w, + output_h, + output_c, + output_n, + pad_h, + pad_w, + kernel_h, + kernel_w, + stride_h, + stride_w, + pool_type); + + TensorShape input_shape = make_float_tensor_shape_from_legion_dims( + {input_w, input_h, input_c, input_n}); + TensorShape output_shape = make_float_tensor_shape_from_legion_dims( + {output_w, output_h, output_c, output_n}); + + GenericTensorAccessorW input_accessor = + create_random_filled_accessor_w(input_shape, allocator); + GenericTensorAccessorW output_accessor = + create_random_filled_accessor_w(output_shape, allocator); /* NOTE(review): output_accessor is random-filled before any kernel runs, so the forward subcase's non-zero CHECK cannot detect a kernel that writes nothing. */ + + 
SUBCASE("forward_kernel") { /* Passes the raw accessor pointers to the kernel, then checks the output buffer for any non-zero value. */ + Kernels::Pool2D::forward_kernel(managed_stream.raw_stream(), + state, + input_accessor.ptr, + output_accessor.ptr); + + std::vector host_output_data = + load_data_to_host_from_device( + read_only_accessor_from_write_accessor(output_accessor)); + CHECK(contains_non_zero(host_output_data)); + } + + SUBCASE("backward_kernel") { /* NOTE(review): input_grad_accessor comes straight from allocate_tensor and is not explicitly initialised before the kernel call; output_accessor still holds random data here, not a forward result. */ + GenericTensorAccessorW output_grad_accessor = + create_filled_accessor_w(output_shape, allocator, 1.0f); + GenericTensorAccessorW input_grad_accessor = + allocator.allocate_tensor(input_shape); + + Kernels::Pool2D::backward_kernel(managed_stream.raw_stream(), + state, + input_accessor.ptr, + input_grad_accessor.ptr, + output_accessor.ptr, + output_grad_accessor.ptr); + + std::vector host_input_grad = load_data_to_host_from_device( + read_only_accessor_from_write_accessor(input_grad_accessor)); + CHECK(contains_non_zero(host_input_grad)); + } + } +} diff --git a/lib/kernels/test/src/test_reduction_kernel.cc b/lib/kernels/test/src/test_reduction_kernel.cc new file mode 100644 index 0000000000..1ea740f336 --- /dev/null +++ b/lib/kernels/test/src/test_reduction_kernel.cc @@ -0,0 +1,58 @@ +#include "doctest/doctest.h" +#include "kernels/reduction_kernels.h" +#include "test_utils.h" + +using namespace ::FlexFlow; +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("Test Reduction Forward and Backward Kernel") { + std::size_t num_replicas = 5; + + TensorShape input_shape = + make_float_tensor_shape_from_legion_dims({10, 10, 10, 10, 10}); + + ManagedPerDeviceFFHandle managed_handle{}; + ManagedFFStream managed_stream{}; + + Allocator allocator = create_local_cuda_memory_allocator(); + + SUBCASE("forward_kernel") { /* NOTE(review): the input holds 100000 elements while the output holds 10 and num_replicas is 5; confirm these sizes match what forward_kernel reads and writes. */ + TensorShape output_shape = make_float_tensor_shape_from_legion_dims({10}); + + GenericTensorAccessorR input_accessor = + read_only_accessor_from_write_accessor( + create_random_filled_accessor_w(input_shape, allocator)); + GenericTensorAccessorW output_accessor = + allocator.allocate_tensor(output_shape); + + 
Kernels::Reduction::forward_kernel(managed_stream.raw_stream(), + input_accessor, + output_accessor, + num_replicas); + + std::vector host_output_data = + load_data_to_host_from_device( + read_only_accessor_from_write_accessor(output_accessor)); + CHECK(contains_non_zero(host_output_data)); + } + + SUBCASE("backward_kernel") { /* Expects input_grad to equal the 1.0f-filled output_grad element for element; both use input_shape. */ + TensorShape output_shape = input_shape; + + GenericTensorAccessorR output_grad_accessor = + read_only_accessor_from_write_accessor( + create_filled_accessor_w(output_shape, allocator, 1.0f)); + GenericTensorAccessorW input_grad_accessor = + allocator.allocate_tensor(input_shape); + + Kernels::Reduction::backward_kernel(managed_stream.raw_stream(), + input_grad_accessor, + output_grad_accessor); + + std::vector expected_grad_input_data( + input_grad_accessor.shape.num_elements(), 1.0f); + std::vector host_grad_data = load_data_to_host_from_device( + read_only_accessor_from_write_accessor(input_grad_accessor)); + CHECK(host_grad_data == expected_grad_input_data); + } + } +} diff --git a/lib/kernels/test/src/test_replicate_kernel.cc b/lib/kernels/test/src/test_replicate_kernel.cc new file mode 100644 index 0000000000..86d790f03c --- /dev/null +++ b/lib/kernels/test/src/test_replicate_kernel.cc @@ -0,0 +1,55 @@ +#include "doctest/doctest.h" +#include "kernels/replicate_kernels.h" +#include "test_utils.h" + +using namespace ::FlexFlow; +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("Test Replicate Kernel") { + std::size_t num_replicas = 10; /* Only the backward_kernel call receives num_replicas. */ + + TensorShape input_shape = make_float_tensor_shape_from_legion_dims({100}); + TensorShape output_shape = input_shape; + + ManagedPerDeviceFFHandle managed_handle{}; + ManagedFFStream managed_stream{}; + + Allocator allocator = create_local_cuda_memory_allocator(); + + SUBCASE("forward_kernel") { /* Expects the output to equal the 1.0f-filled input element for element. */ + GenericTensorAccessorR input_accessor = + read_only_accessor_from_write_accessor( + create_filled_accessor_w(input_shape, allocator, 1.0f)); + GenericTensorAccessorW output_accessor = + 
allocator.allocate_tensor(output_shape); + + Kernels::Replicate::forward_kernel( + managed_stream.raw_stream(), input_accessor, output_accessor); + + std::vector check_output_data = + load_data_to_host_from_device( + read_only_accessor_from_write_accessor(output_accessor)); + + std::vector expected_output_data( + input_accessor.shape.num_elements(), 1.0f); + CHECK(check_output_data == expected_output_data); + } + + SUBCASE("backward_kernel") { /* NOTE(review): input_grad starts at 1.0f, so the non-zero CHECK below passes regardless of what the kernel writes. NOTE(review): output_grad has the same 100-element shape as input_grad although num_replicas is 10; confirm the kernel does not read beyond 100 elements. */ + GenericTensorAccessorW input_grad_accessor = + create_filled_accessor_w(input_shape, allocator, 1.0f); + GenericTensorAccessorR output_grad_accessor = + read_only_accessor_from_write_accessor( + create_filled_accessor_w(output_shape, allocator, 1.0f)); + + Kernels::Replicate::backward_kernel(managed_stream.raw_stream(), + input_grad_accessor, + output_grad_accessor, + num_replicas); + + std::vector check_aggregated_data = + load_data_to_host_from_device( + read_only_accessor_from_write_accessor(input_grad_accessor)); + CHECK(contains_non_zero(check_aggregated_data)); + } + } +} diff --git a/lib/kernels/test/src/test_reshape_kernel.cc b/lib/kernels/test/src/test_reshape_kernel.cc new file mode 100644 index 0000000000..f56bfacc2b --- /dev/null +++ b/lib/kernels/test/src/test_reshape_kernel.cc @@ -0,0 +1,59 @@ +#include "doctest/doctest.h" +#include "kernels/reshape_kernels.h" +#include "test_utils.h" + +using namespace ::FlexFlow; +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("Test Reshape Forward and Backward") { /* Input and output share the same 100-element shape, so no change of shape is exercised here. */ + ManagedPerDeviceFFHandle managed_handle{}; + ManagedFFStream managed_stream{}; + + Allocator allocator = create_local_cuda_memory_allocator(); + + TensorShape input_shape = make_float_tensor_shape_from_legion_dims({100}); + TensorShape output_shape = input_shape; + + ReshapePerDeviceState state = + Kernels::Reshape::init_kernel(DataType::FLOAT); + + SUBCASE("forward_kernel") { + GenericTensorAccessorR input_accessor = + read_only_accessor_from_write_accessor( + create_filled_accessor_w(input_shape, allocator, 1.0f)); + 
GenericTensorAccessorW output_accessor = + allocator.allocate_tensor(output_shape); + + Kernels::Reshape::forward_kernel( + managed_stream.raw_stream(), state, input_accessor, output_accessor); + + std::vector check_output_data = + load_data_to_host_from_device( + read_only_accessor_from_write_accessor(output_accessor)); + + std::vector expected_output_data( + input_accessor.shape.num_elements(), 1.0f); /* Expects the output to equal the 1.0f-filled input. */ + CHECK(check_output_data == expected_output_data); + } + + SUBCASE("backward_kernel") { /* output_grad is 1.0f everywhere and input_grad starts at 2.0f; the CHECK below expects 3.0f in every element. */ + GenericTensorAccessorR output_grad_accessor = + read_only_accessor_from_write_accessor( + create_filled_accessor_w(output_shape, allocator, 1.0f)); + GenericTensorAccessorW input_grad_accessor = + create_filled_accessor_w(input_shape, allocator, 2.0f); + + Kernels::Reshape::backward_kernel(managed_stream.raw_stream(), + state, + input_grad_accessor, + output_grad_accessor); + + std::vector host_grad_input_data = + load_data_to_host_from_device( + read_only_accessor_from_write_accessor(input_grad_accessor)); + + std::vector expected_grad_input_data( + input_grad_accessor.shape.num_elements(), 3.0f); + CHECK(host_grad_input_data == expected_grad_input_data); + } + } +} diff --git a/lib/kernels/test/src/test_reverse_kernels.cc b/lib/kernels/test/src/test_reverse_kernels.cc new file mode 100644 index 0000000000..cdaf65a305 --- /dev/null +++ b/lib/kernels/test/src/test_reverse_kernels.cc @@ -0,0 +1,62 @@ +#include "doctest/doctest.h" +#include "kernels/reverse_kernels.h" +#include "test_utils.h" + +using namespace ::FlexFlow; +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("Call Reverse Forward and Backward Kernels") { + std::size_t reverse_dim_size = 10; + std::size_t in_blk_size = 10; + std::size_t num_out_blks = 1; /* num_out_blks * reverse_dim_size * in_blk_size = 1 * 10 * 10 = 100, the element count of the tensors below. */ + + TensorShape input_shape = make_float_tensor_shape_from_legion_dims({100}); + TensorShape output_shape = input_shape; + + ManagedPerDeviceFFHandle managed_handle{}; + ManagedFFStream managed_stream{}; + + Allocator allocator = create_local_cuda_memory_allocator(); + + 
SUBCASE("forward_kernel") { /* Reverses a 1.0f-filled input and checks that the output holds a non-zero value. */ + GenericTensorAccessorR input_accessor = + read_only_accessor_from_write_accessor( + create_filled_accessor_w(input_shape, allocator, 1.0f)); + GenericTensorAccessorW output_accessor = + allocator.allocate_tensor(output_shape); + + Kernels::Reverse::forward_kernel(managed_stream.raw_stream(), + input_accessor.get_float_ptr(), + output_accessor.get_float_ptr(), + num_out_blks, + reverse_dim_size, + in_blk_size, + input_accessor.shape.num_elements()); + + std::vector check_output_data = + load_data_to_host_from_device( + read_only_accessor_from_write_accessor(output_accessor)); + CHECK(contains_non_zero(check_output_data)); + } + + SUBCASE("backward_kernel") { + GenericTensorAccessorW output_grad_accessor = + create_random_filled_accessor_w(output_shape, allocator); + GenericTensorAccessorW input_grad_accessor = + create_filled_accessor_w(input_shape, allocator, 0.0f); /* Start from zeros so the non-zero CHECK below can only pass if the kernel actually wrote gradient data. */ + + Kernels::Reverse::backward_kernel( + managed_stream.raw_stream(), + output_grad_accessor.get_float_ptr(), + input_grad_accessor.get_float_ptr(), + num_out_blks, + reverse_dim_size, + in_blk_size, + input_grad_accessor.shape.num_elements()); + + std::vector host_grad_input_data = + load_data_to_host_from_device( + read_only_accessor_from_write_accessor(input_grad_accessor)); + CHECK(contains_non_zero(host_grad_input_data)); + } + } +} diff --git a/lib/kernels/test/src/test_softmax_kernel.cc b/lib/kernels/test/src/test_softmax_kernel.cc new file mode 100644 index 0000000000..f49c1ebbcc --- /dev/null +++ b/lib/kernels/test/src/test_softmax_kernel.cc @@ -0,0 +1,60 @@ +#include "doctest/doctest.h" +#include "kernels/softmax_kernels.h" +#include "test_utils.h" + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("Test Softmax Kernel Operations") { + int input_n = 1, input_h = 1, input_w = 100, channels = 100; /* The unused input_c local was dropped; init_kernel receives channels in that position. */ + + ManagedPerDeviceFFHandle managed_handle{}; + ManagedFFStream managed_stream{}; + + Allocator allocator = 
create_local_cuda_memory_allocator(); + + TensorShape input_shape = make_float_tensor_shape_from_legion_dims({100}); + TensorShape output_shape = input_shape; + + SoftmaxPerDeviceState state = Kernels::Softmax::init_kernel( + managed_handle.raw_handle(), 0, input_n, channels, input_h, input_w); /* NOTE(review): init_kernel is given n=1, c=100, h=1, w=100 while the tensors hold 100 floats; confirm the descriptor it builds matches the buffer size. */ + + GenericTensorAccessorW output_accessor = + create_filled_accessor_w(output_shape, allocator, 0.0f); /* Zero-filled so the forward subcase's non-zero CHECK can only pass if the kernel wrote results. */ + + SUBCASE("forward_kernel") { + GenericTensorAccessorW input_accessor = + create_random_filled_accessor_w(input_shape, allocator); + + Kernels::Softmax::forward_kernel(managed_stream.raw_stream(), + state, + input_accessor.get_float_ptr(), + output_accessor.get_float_ptr()); + + std::vector host_output_data = + load_data_to_host_from_device( + read_only_accessor_from_write_accessor(output_accessor)); + CHECK(contains_non_zero(host_output_data)); + } + + SUBCASE("backward_kernel") { /* Expects input_grad to equal the 1.0f-filled output_grad element for element. */ + GenericTensorAccessorW output_grad_accessor = + create_filled_accessor_w(output_shape, allocator, 1.0f); + GenericTensorAccessorW input_grad_accessor = + allocator.allocate_tensor(input_shape); + + Kernels::Softmax::backward_kernel( + managed_stream.raw_stream(), + input_grad_accessor.get_float_ptr(), + output_grad_accessor.get_float_ptr(), + output_grad_accessor.shape.num_elements()); + + std::vector expected_input_grad_data = + std::vector(input_grad_accessor.shape.num_elements(), 1.0f); + std::vector host_input_grad_data = + load_data_to_host_from_device( + read_only_accessor_from_write_accessor(input_grad_accessor)); + CHECK(host_input_grad_data == expected_input_grad_data); + } + } +} diff --git a/lib/kernels/test/src/test_split_kernel.cc b/lib/kernels/test/src/test_split_kernel.cc new file mode 100644 index 0000000000..7cc2b28c9e --- /dev/null +++ b/lib/kernels/test/src/test_split_kernel.cc @@ -0,0 +1,61 @@ +#include "doctest/doctest.h" +#include "kernels/split_kernels.h" +#include "test_utils.h" + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("Test Split 
Forward and Backward Kernel") { /* Smoke test: splits a 100-element input into two 50-element outputs; neither subcase asserts on the results. */ + size_t num_outputs = 2; + coord_t out_blk_sizes[] = {50, 50}; + coord_t in_blk_size = 100; + coord_t num_blks = 1; + + ManagedPerDeviceFFHandle managed_handle{}; + ManagedFFStream managed_stream{}; + + Allocator allocator = create_local_cuda_memory_allocator(); + + TensorShape input_shape = make_float_tensor_shape_from_legion_dims({100}); + TensorShape output_shape = make_float_tensor_shape_from_legion_dims({50}); + + SUBCASE("forward_kernel") { + GenericTensorAccessorW input_accessor = + create_random_filled_accessor_w(input_shape, allocator); + + std::vector output_ptrs = repeat(num_outputs, [&]() { + GenericTensorAccessorW output_accessor = + allocator.allocate_tensor(output_shape); + return output_accessor.get_float_ptr(); + }); + + Kernels::Split::forward_kernel(managed_stream.raw_stream(), + output_ptrs.data(), + input_accessor.get_float_ptr(), + out_blk_sizes, + in_blk_size, + num_blks, + num_outputs); + } + + SUBCASE("backward_kernel") { + std::vector output_grad_ptrs(num_outputs); + for (size_t i = 0; i < num_outputs; i++) { /* size_t index matches the type of num_outputs, avoiding a signed/unsigned comparison. */ + GenericTensorAccessorW output_grad_accessor = + create_random_filled_accessor_w(output_shape, allocator); + output_grad_ptrs[i] = output_grad_accessor.get_float_ptr(); + } + + GenericTensorAccessorW input_grad_accessor = + create_filled_accessor_w(input_shape, allocator, 0.0f); + + Kernels::Split::backward_kernel(managed_stream.raw_stream(), + input_grad_accessor.get_float_ptr(), + const_cast<float const **>(output_grad_ptrs.data()), /* Named cast: only adds const to the pointee type. */ + out_blk_sizes, + in_blk_size, + num_blks, + num_outputs); + } + } +} diff --git a/lib/kernels/test/src/test_transpose_kernel.cc b/lib/kernels/test/src/test_transpose_kernel.cc new file mode 100644 index 0000000000..2fc186a257 --- /dev/null +++ b/lib/kernels/test/src/test_transpose_kernel.cc @@ -0,0 +1,58 @@ +#include "doctest/doctest.h" +#include "kernels/transpose_kernels.h" +#include "test_utils.h" + +using namespace ::FlexFlow; +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("Test 
Transpose Kernel Operations") { + std::size_t num_dims = 2; + + std::vector perm = {ff_dim_t(0), ff_dim_t(1)}; /* NOTE(review): this permutation lists the dims in their original order; confirm whether a non-identity permutation should be exercised. */ + + ManagedPerDeviceFFHandle managed_handle{}; + ManagedFFStream managed_stream{}; + + Allocator allocator = create_local_cuda_memory_allocator(); + + TransposePerDeviceState state = + Kernels::Transpose::init_kernel(num_dims, perm); + + TensorShape input_shape = + make_float_tensor_shape_from_legion_dims({10, 10}); + TensorShape output_shape = input_shape; + + SUBCASE("forward_kernel") { + GenericTensorAccessorR input_accessor = + read_only_accessor_from_write_accessor( + create_random_filled_accessor_w(input_shape, allocator)); + GenericTensorAccessorW output_accessor = + allocator.allocate_tensor(output_shape); + + Kernels::Transpose::forward_kernel( + managed_stream.raw_stream(), state, input_accessor, output_accessor); + + std::vector host_output_data = + load_data_to_host_from_device( + read_only_accessor_from_write_accessor(output_accessor)); + CHECK(contains_non_zero(host_output_data)); + } + + SUBCASE("backward_kernel") { + GenericTensorAccessorR output_grad_accessor = + read_only_accessor_from_write_accessor( + create_random_filled_accessor_w(output_shape, allocator)); + GenericTensorAccessorW input_grad_accessor = + create_filled_accessor_w(input_shape, allocator, 0.0f); /* Start from zeros so the non-zero CHECK below can only pass if the kernel actually wrote gradient data. */ + + Kernels::Transpose::backward_kernel(managed_stream.raw_stream(), + state, + input_grad_accessor, + output_grad_accessor); + + std::vector host_grad_input_data = + load_data_to_host_from_device( + read_only_accessor_from_write_accessor(input_grad_accessor)); + CHECK(contains_non_zero(host_grad_input_data)); + } + } +} diff --git a/lib/kernels/test/src/test_utils.cc b/lib/kernels/test/src/test_utils.cc new file mode 100644 index 0000000000..b591642570 --- /dev/null +++ b/lib/kernels/test/src/test_utils.cc @@ -0,0 +1,105 @@ +#include "test_utils.h" + +GenericTensorAccessorW create_random_filled_accessor_w(TensorShape const &shape, + Allocator &allocator, + bool cpu_fill) { + 
GenericTensorAccessorW accessor = allocator.allocate_tensor(shape); + size_t volume = accessor.shape.num_elements(); + std::vector host_data(volume); + std::random_device rd; /* Seeded from std::random_device, so the generated values differ between runs. */ + std::mt19937 gen(rd()); + std::uniform_real_distribution dist(-1.0f, 1.0f); + + for (auto &val : host_data) { + val = dist(gen); + } + + if (cpu_fill) { /* cpu_fill selects a plain host memcpy; otherwise the data is copied host-to-device with cudaMemcpy. */ + memcpy(accessor.ptr, host_data.data(), host_data.size() * sizeof(float)); + } else { + checkCUDA(cudaMemcpy(accessor.ptr, + host_data.data(), + host_data.size() * sizeof(float), + cudaMemcpyHostToDevice)); + } + + return accessor; +} + +GenericTensorAccessorW create_filled_accessor_w(TensorShape const &shape, + Allocator &allocator, + float val, + bool cpu_fill) { /* Allocates a tensor of the given shape and sets every element to val. */ + GenericTensorAccessorW accessor = allocator.allocate_tensor(shape); + size_t volume = accessor.shape.num_elements(); + std::vector host_data(volume, val); + + if (cpu_fill) { + memcpy(accessor.ptr, host_data.data(), host_data.size() * sizeof(float)); + } else { + checkCUDA(cudaMemcpy(accessor.ptr, + host_data.data(), + host_data.size() * sizeof(float), + cudaMemcpyHostToDevice)); + } + + return accessor; +} + +GenericTensorAccessorW create_iota_filled_accessor_w(TensorShape const &shape, + Allocator &allocator, + bool cpu_fill) { /* Allocates a tensor of the given shape and fills it with 0, 1, 2, ... in element order. */ + GenericTensorAccessorW accessor = allocator.allocate_tensor(shape); + size_t volume = accessor.shape.num_elements(); + std::vector host_data(volume); + + for (size_t i = 0; i < volume; i++) { + host_data[i] = static_cast<float>(i); /* Explicit conversion: the buffer is copied as float (sizeof(float) below). */ + } + + if (cpu_fill) { + memcpy(accessor.ptr, host_data.data(), host_data.size() * sizeof(float)); + } else { + checkCUDA(cudaMemcpy(accessor.ptr, + host_data.data(), + host_data.size() * sizeof(float), + cudaMemcpyHostToDevice)); + } + + return accessor; +} + +void fill_tensor_accessor_w(GenericTensorAccessorW accessor, + float val, + bool cpu_fill) { /* Overwrites every element of an already-allocated accessor with val. */ + LegionTensorDims dims = accessor.shape.dims; /* NOTE(review): dims is not used in this function. */ + size_t volume = accessor.shape.num_elements(); + std::vector host_data(volume, val); + + if (cpu_fill) { + memcpy(accessor.ptr, 
host_data.data(), host_data.size() * sizeof(float)); + } else { + checkCUDA(cudaMemcpy(accessor.ptr, + host_data.data(), + host_data.size() * sizeof(float), + cudaMemcpyHostToDevice)); + } +} + +TensorShape make_float_tensor_shape_from_legion_dims(FFOrdered dims) { /* Builds a TensorShape with the given dims and DataType::FLOAT. */ + return TensorShape{ + TensorDims{ + dims, + }, + DataType::FLOAT, + }; +} + +TensorShape make_double_tensor_shape_from_legion_dims(FFOrdered dims) { /* Same as above, with DataType::DOUBLE. */ + return TensorShape{ + TensorDims{ + dims, + }, + DataType::DOUBLE, + }; +} diff --git a/lib/kernels/test/src/test_utils.h b/lib/kernels/test/src/test_utils.h new file mode 100644 index 0000000000..abce3fd444 --- /dev/null +++ b/lib/kernels/test/src/test_utils.h @@ -0,0 +1,48 @@ +#ifndef _FLEXFLOW_KERNELS_TEST_UTILS +#define _FLEXFLOW_KERNELS_TEST_UTILS + +#include "kernels/device.h" +#include "kernels/local_cuda_allocator.h" +#include "kernels/managed_ff_stream.h" +#include "kernels/managed_per_device_ff_handle.h" +#include + +GenericTensorAccessorW create_random_filled_accessor_w(TensorShape const &shape, + Allocator &allocator, + bool cpu_fill = false); + +GenericTensorAccessorW create_filled_accessor_w(TensorShape const &shape, + Allocator &allocator, + float val, + bool cpu_fill = false); + +GenericTensorAccessorW create_iota_filled_accessor_w(TensorShape const &shape, + Allocator &allocator, + bool cpu_fill = false); + +void fill_tensor_accessor_w(GenericTensorAccessorW accessor, + float val, + bool cpu_fill = false); + +TensorShape make_float_tensor_shape_from_legion_dims(FFOrdered dims); + +TensorShape make_double_tensor_shape_from_legion_dims(FFOrdered dims); + +template +std::vector load_data_to_host_from_device(GenericTensorAccessorR accessor) { /* Copies the accessor's contents into a host vector with cudaMemcpyDeviceToHost. */ + size_t volume = accessor.shape.get_volume(); /* size_t avoids narrowing the element count before it sizes the vector. */ + + std::vector local_data(volume); + checkCUDA(cudaMemcpy(local_data.data(), + accessor.ptr, + local_data.size() * sizeof(T), + cudaMemcpyDeviceToHost)); + return local_data; +} + +template +bool contains_non_zero(std::vector<T> const &data) { /* Read-only parameter: returns true when at least one element differs from 0. */ + return !all_of(data, 
[](T const &val) { return val == 0; }); +} + +#endif diff --git a/lib/local-execution/include/local-execution/local_allocator.h b/lib/local-execution/include/local-execution/local_allocator.h deleted file mode 100644 index 9b38b50ed5..0000000000 --- a/lib/local-execution/include/local-execution/local_allocator.h +++ /dev/null @@ -1,23 +0,0 @@ -#ifndef _FLEXFLOW_LOCAL_EXECUTION_LOCAL_ALLOCATOR_H -#define _FLEXFLOW_LOCAL_EXECUTION_LOCAL_ALLOCATOR_H - -#include "kernels/allocation.h" - -namespace FlexFlow { - -struct LocalAllocator : public IAllocator { - LocalAllocator() = default; - LocalAllocator(LocalAllocator const &) = delete; - LocalAllocator(LocalAllocator &&) = delete; - ~LocalAllocator() = default; - - void *allocate(size_t) override; - void deallocate(void *) override; -}; -CHECK_RC_COPY_VIRTUAL_COMPLIANT(LocalAllocator); - -Allocator get_local_memory_allocator(); - -} // namespace FlexFlow - -#endif diff --git a/lib/local-execution/include/local-execution/tracked_allocator.h b/lib/local-execution/include/local-execution/tracked_allocator.h index 49708954b4..ae7bd076ce 100644 --- a/lib/local-execution/include/local-execution/tracked_allocator.h +++ b/lib/local-execution/include/local-execution/tracked_allocator.h @@ -2,7 +2,6 @@ #define _FLEXFLOW_LOCAL_EXECUTION_TRACKED_ALLOCATOR_H #include "kernels/allocation.h" -#include "local-execution/local_allocator.h" namespace FlexFlow { diff --git a/lib/local-execution/src/local_allocator.cc b/lib/local-execution/src/local_allocator.cc deleted file mode 100644 index d393643ead..0000000000 --- a/lib/local-execution/src/local_allocator.cc +++ /dev/null @@ -1,20 +0,0 @@ -#include "local-execution/local_allocator.h" -#include "kernels/device.h" - -namespace FlexFlow { - -void *LocalAllocator::allocate(size_t requested_memory_size) { - void *ptr; - checkCUDA(cudaMalloc(&ptr, requested_memory_size)); - return ptr; -} - -void LocalAllocator::deallocate(void *ptr) { - checkCUDA(cudaFree(ptr)); -} - -Allocator 
get_local_memory_allocator() { - return Allocator::create(); -} - -} // namespace FlexFlow diff --git a/lib/local-execution/src/local_cost_estimator.cc b/lib/local-execution/src/local_cost_estimator.cc index 51deb23d22..9cb1d9913a 100644 --- a/lib/local-execution/src/local_cost_estimator.cc +++ b/lib/local-execution/src/local_cost_estimator.cc @@ -1,5 +1,6 @@ #include "local-execution/local_cost_estimator.h" #include "kernels/device.h" +#include "kernels/local_cuda_allocator.h" #include "local-execution/tracked_allocator.h" #include "op-attrs/computation_graph_op_attrs.h" #include "op-attrs/pcg_operator_attrs.h" @@ -39,7 +40,7 @@ CostDetails LocalCostEstimator::estimate_cost( // allocate memory for inputs std::shared_ptr tracked_allocator_ptr = - std::make_shared(get_local_memory_allocator()); + std::make_shared(create_local_cuda_memory_allocator()); Allocator allocator = Allocator(tracked_allocator_ptr); TensorBackingMap tensor_backing_map; std::vector input_tensor_ids; diff --git a/lib/local-execution/src/ops/attention.cc b/lib/local-execution/src/ops/attention.cc index be1fae475c..fc3627404d 100644 --- a/lib/local-execution/src/ops/attention.cc +++ b/lib/local-execution/src/ops/attention.cc @@ -84,10 +84,10 @@ static DeviceSpecific init_task_impl(TaskArgumentAccessor const &acc) { auto const &attrs = acc.get_argument(ATTRS); Allocator allocator = acc.get_allocator(); - int qProjSize = acc.get_argument(QPROJSIZE); - int kProjSize = acc.get_argument(KPROJSIZE); - int vProjSize = acc.get_argument(VPROJSIZE); - int oProjSize = acc.get_argument(OPROJSIZE); + size_t qProjSize = acc.get_argument(QPROJSIZE); + size_t kProjSize = acc.get_argument(KPROJSIZE); + size_t vProjSize = acc.get_argument(VPROJSIZE); + size_t oProjSize = acc.get_argument(OPROJSIZE); PerDeviceFFHandle handle = acc.get_argument(HANDLE); ParallelTensorShape query_parallel_tensor_shape = acc.get_argument(QUERY_PARALLEL_TENSOR_SHAPE); diff --git a/lib/local-execution/src/ops/batch_norm.cc 
b/lib/local-execution/src/ops/batch_norm.cc index 831e42fad9..5eaa264541 100644 --- a/lib/local-execution/src/ops/batch_norm.cc +++ b/lib/local-execution/src/ops/batch_norm.cc @@ -71,6 +71,7 @@ static DeviceSpecific Allocator allocator = acc.get_allocator(); PerDeviceFFHandle handle = acc.get_argument(HANDLE); ProfilingSettings profiling = acc.get_argument(PROFILING); + auto output = acc.get_tensor(OUTPUT); auto const &attrs = acc.get_argument(ATTRS); diff --git a/lib/local-execution/src/ops/conv_2d.cc b/lib/local-execution/src/ops/conv_2d.cc index bc3e66f60f..59b2feaee0 100644 --- a/lib/local-execution/src/ops/conv_2d.cc +++ b/lib/local-execution/src/ops/conv_2d.cc @@ -57,7 +57,7 @@ static DeviceSpecific PerDeviceFFHandle handle = acc.get_argument(HANDLE); auto attrs = acc.get_argument(ATTRS); - auto input = acc.get_tensor(INPUT); + auto input = acc.get_tensor(INPUT); auto output = acc.get_tensor(OUTPUT); auto filter = acc.get_tensor(FILTER); auto filter_grad = acc.get_tensor_grad(FILTER); diff --git a/lib/local-execution/src/ops/linear.cc b/lib/local-execution/src/ops/linear.cc index 27277a2b74..91146e3f6c 100644 --- a/lib/local-execution/src/ops/linear.cc +++ b/lib/local-execution/src/ops/linear.cc @@ -74,6 +74,7 @@ static DeviceSpecific LinearPerDeviceState state = init_kernel(handle, one_ptr, + attrs.activation, attrs.regularizer, attrs.use_bias, input.data_type, diff --git a/lib/local-execution/src/ops/replicate.cc b/lib/local-execution/src/ops/replicate.cc index 983248ac1a..fa20be7383 100644 --- a/lib/local-execution/src/ops/replicate.cc +++ b/lib/local-execution/src/ops/replicate.cc @@ -62,8 +62,8 @@ static std::optional backward_task_impl(TaskArgumentAccessor const &acc) { ProfilingSettings profiling = acc.get_argument(PROFILING); - auto input_grad = acc.get_tensor_grad(INPUT); - auto output_grad = acc.get_tensor_grad(OUTPUT); + auto input_grad = acc.get_tensor_grad(INPUT); + auto output_grad = acc.get_tensor_grad(OUTPUT); auto const &attrs = 
acc.get_argument(ATTRS); return profile(backward_kernel, diff --git a/lib/local-execution/src/ops/softmax.cc b/lib/local-execution/src/ops/softmax.cc index 9c919626cc..a0b3a047a7 100644 --- a/lib/local-execution/src/ops/softmax.cc +++ b/lib/local-execution/src/ops/softmax.cc @@ -56,9 +56,17 @@ static DeviceSpecific init_task_impl(TaskArgumentAccessor const &acc) { PerDeviceFFHandle handle = acc.get_argument(HANDLE); + auto output = acc.get_tensor(OUTPUT); auto const &attrs = acc.get_argument(ATTRS); - SoftmaxPerDeviceState per_device_state = init_kernel(handle, attrs.dim.value); + int output_w = output.shape.at(legion_dim_t(0)); + int output_h = output.shape.at(legion_dim_t(1)); + int output_c = output.shape.at(legion_dim_t(2)); + int output_n = output.shape.at(legion_dim_t(3)); + + SoftmaxPerDeviceState per_device_state = init_kernel( + handle, attrs.dim.value, output_n, output_c, output_h, output_w); + return DeviceSpecific::create(per_device_state); }