Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 5 additions & 11 deletions lib/kernels/include/kernels/flat_kernels.h
Original file line number Diff line number Diff line change
@@ -1,26 +1,20 @@
#ifndef _FLEXFLOW_OPS_KERNELS_FLAT_KERNELS_H
#define _FLEXFLOW_OPS_KERNELS_FLAT_KERNELS_H

#include "kernels/accessor.h"
#include "kernels/device.h"

namespace FlexFlow {

class FlatPerDeviceState : public PerDeviceOpState {
public:
FlatPerDeviceState(FFHandler handle) : PerDeviceOpState(handle){};
};

namespace Kernels {
namespace Flat {

void forward_kernel(ffStream_t stream,
float const *input_ptr,
float *output_ptr,
size_t num_elements);
GenericTensorAccessorR input,
float *output_ptr);
void backward_kernel(ffStream_t stream,
GenericTensorAccessorR input,
float *input_grad_ptr,
float const *output_grad_ptr,
size_t num_elements);
float const *output_grad_ptr);

} // namespace Flat
} // namespace Kernels
Expand Down
27 changes: 11 additions & 16 deletions lib/kernels/src/cuda/flat_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -13,40 +13,35 @@
* limitations under the License.
*/

#include "kernels/cuda_helper.h"
#include "device.h"
#include "kernels/accessor.h"
#include "kernels/device.h"
#include "kernels/flat_kernels.h"

namespace FlexFlow {

namespace Kernels {
namespace Flat {

void forward_kernel(cudaStream_t stream,
float const *input_ptr,
float *output_ptr,
size_t num_elements) {
GenericTensorAccessorR input,
float *output_ptr) {

checkCUDA(cudaMemcpyAsync(output_ptr,
input_ptr,
num_elements * sizeof(float),
input.get_float_ptr(),
(input.shape.num_elements()) * sizeof(float),
cudaMemcpyDeviceToDevice,
stream));
// checkCUDA(cudaDeviceSynchronize());
}

void backward_kernel(cudaStream_t stream,
GenericTensorAccessorR input,
float *input_grad_ptr,
float const *output_grad_ptr,
size_t num_elements) {
float const *output_grad_ptr) {

float alpha = 1.0f;
apply_add_with_scale<float>
<<<GET_BLOCKS(num_elements), CUDA_NUM_THREADS, 0, stream>>>(
input_grad_ptr, output_grad_ptr, num_elements, alpha);
// checkCUDA(cudaMemcpyAsync(acc_input_grad.ptr, acc_output_grad.ptr,
// acc_input_grad.rect.volume() * sizeof(float),
// cudaMemcpyDeviceToDevice));
// checkCUDA(cudaDeviceSynchronize());
<<<GET_BLOCKS(input.shape.num_elements()), CUDA_NUM_THREADS, 0, stream>>>(
input_grad_ptr, output_grad_ptr, input.shape.num_elements(), alpha);
}

} // namespace Flat
Expand Down
17 changes: 8 additions & 9 deletions lib/kernels/src/hip/flat_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,32 +23,31 @@ namespace Kernels {
namespace Flat {

void forward_kernel(hipStream_t stream,
float const *input_ptr,
float *output_ptr,
size_t num_elements) {
GenericTensorAccessorR input,
float *output_ptr) {

checkCUDA(hipMemcpyAsync(output_ptr,
input_ptr,
num_elements * sizeof(float),
input.get_float_ptr(),
(input.shape.num_elements()) * sizeof(float),
hipMemcpyDeviceToDevice,
stream));
// checkCUDA(hipDeviceSynchronize());
}

void backward_kernel(hipStream_t stream,
GenericTensorAccessorR input,
float *input_grad_ptr,
float const *output_grad_ptr,
size_t num_elements) {
float const *output_grad_ptr) {

float alpha = 1.0f;
hipLaunchKernelGGL(HIP_KERNEL_NAME(apply_add_with_scale<float>),
GET_BLOCKS(num_elements),
GET_BLOCKS(input.shape.num_elements()),
CUDA_NUM_THREADS,
0,
stream,
input_grad_ptr,
output_grad_ptr,
num_elements,
input.shape.num_elements(),
alpha);
// checkCUDA(hipMemcpyAsync(acc_input_grad.ptr, acc_output_grad.ptr,
// acc_input_grad.rect.volume() * sizeof(float),
Expand Down
Loading