Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions lib/kernels/include/kernels/accessor.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,10 @@ class GenericTensorAccessorW {
ArrayShape shape;
req<void *> ptr;
};
FF_VISITABLE_STRUCT(GenericTensorAccessorW, data_type, shape, ptr);
FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(GenericTensorAccessorW,
data_type,
shape,
ptr);

class GenericTensorAccessorR {
public:
Expand All @@ -59,7 +62,10 @@ class GenericTensorAccessorR {
ArrayShape shape;
req<void const *> ptr;
};
FF_VISITABLE_STRUCT(GenericTensorAccessorR, data_type, shape, ptr);
FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(GenericTensorAccessorR,
data_type,
shape,
ptr);

int32_t *get_int32_ptr(GenericTensorAccessorW const &);
int64_t *get_int64_ptr(GenericTensorAccessorW const &);
Expand Down
78 changes: 43 additions & 35 deletions lib/kernels/include/kernels/conv_2d_kernels.h
Original file line number Diff line number Diff line change
@@ -1,66 +1,74 @@
#ifndef _FLEXFLOW_OPS_KERNELS_CONV_2D_KERNELS_H
#define _FLEXFLOW_OPS_KERNELS_CONV_2D_KERNELS_H

#include "kernels/accessor.h"
#include "kernels/device.h"
#include "kernels/ff_handle.h"
#include "op-attrs/activation.h"
#include "utils/visitable.h"

namespace FlexFlow {

class Conv2DPerDeviceState : public PerDeviceOpState {
public:
Conv2DPerDeviceState(FFHandler handler);
ffTensorDescriptor_t inputTensor, biasTensor, outputTensor;
struct Conv2DPerDeviceState {
PerDeviceFFHandle handle;
ffTensorDescriptor_t inputTensor;
ffTensorDescriptor_t biasTensor;
ffTensorDescriptor_t outputTensor;
ffFilterDescriptor_t filterDesc;
ffActivationDescriptor_t actiDesc;
ffConvolutionDescriptor_t convDesc;
ffConvolutionFwdAlgo_t fwdAlgo;
ffConvolutionBwdFilterAlgo_t bwdFilterAlgo;
ffConvolutionBwdDataAlgo_t bwdDataAlgo;
bool relu, use_bias;
char op_name[MAX_OPNAME];
req<ffConvolutionBwdDataAlgo_t> bwdDataAlgo;
};

FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(Conv2DPerDeviceState,
handle,
inputTensor,
biasTensor,
outputTensor,
filterDesc,
actiDesc,
convDesc,
fwdAlgo,
bwdFilterAlgo,
bwdDataAlgo);

namespace Kernels {
namespace Conv2D {

void init_kernel(Conv2DPerDeviceState *m,
int input_w,
int input_h,
int input_c,
int input_n,
int output_w,
int output_h,
int output_c,
int output_n,
int kernel_h,
int kernel_w,
int groups,
int stride_h,
int stride_w,
int pad_h,
int pad_w,
float const *input_ptr,
float *output_ptr,
float const *kernel_ptr,
float *kernel_grad_ptr,
float *forward_time = nullptr,
float *backward_time = nullptr);
Conv2DPerDeviceState init_kernel(PerDeviceFFHandle handle,
optional<Activation> activation,
int kernel_h,
int kernel_w,
int groups,
int padding_h,
int padding_w,
int stride_h,
int stride_w,
GenericTensorAccessorR const &input,
GenericTensorAccessorW const &output,
float const *filter_ptr,
float *filter_grad_ptr);

void forward_kernel(ffStream_t stream,
Conv2DPerDeviceState const *m,
Conv2DPerDeviceState const &m,
float const *input_ptr,
float *output_ptr,
float const *filter_ptr,
float const *bias_ptr);
float const *bias_ptr,
optional<Activation> activation);

void backward_kernel(ffStream_t stream,
Conv2DPerDeviceState const *m,
Conv2DPerDeviceState const &m,
float const *input_ptr,
float *input_grad_ptr,
float const *output_ptr,
float *output_grad_ptr,
float const *kernel_ptr,
float *kernel_grad_ptr,
float *bias_grad_ptr);
float const *filter_ptr,
float *filter_grad_ptr,
float *bias_grad_ptr,
optional<Activation> activation);

} // namespace Conv2D
} // namespace Kernels
Expand Down
4 changes: 2 additions & 2 deletions lib/kernels/include/kernels/device.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@
#include <iostream>
#include <sstream>

namespace FlexFlow {

#if defined(FF_USE_CUDA) || defined(FF_USE_HIP_CUDA)
typedef cudaStream_t ffStream_t;
cudaError_t get_legion_stream(cudaStream_t *stream);
Expand Down Expand Up @@ -79,6 +77,8 @@ typedef hipError_t ffError_t;
#error "Unknown device"
#endif

namespace FlexFlow {

#define FatalError(s) \
do { \
std::stringstream _where, _message; \
Expand Down
18 changes: 18 additions & 0 deletions lib/kernels/include/kernels/ff_handle.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#endif

#include "kernels/device.h"
#include "utils/visitable.h"

namespace FlexFlow {

Expand All @@ -22,6 +23,23 @@ struct PerDeviceFFHandle {
#endif
};

#ifdef FF_USE_NCCL
FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(PerDeviceFFHandle,
dnn,
blas,
workSpace,
workSpaceSize,
allowTensorOpMathConversion,
ncclComm);
#else
FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(PerDeviceFFHandle,
dnn,
blas,
workSpace,
workSpaceSize,
allowTensorOpMathConversion);
#endif

} // namespace FlexFlow

#endif
Loading