Skip to content
20 changes: 8 additions & 12 deletions lib/kernels/include/kernels/cast_kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,28 +3,24 @@

#include "kernels/accessor.h"
#include "kernels/device.h"
#include "op-attrs/ffconst.h"
#include "kernels/ff_handle.h"
#include "op-attrs/activation.h"

namespace FlexFlow {

// Per-device state for the Cast operator; stores the element types that the
// cast kernels convert between.
// NOTE(review): this diff removes the constructor definitions in both the
// CUDA and HIP implementation files, and the kernels now receive the data
// types as explicit parameters — this class appears to be on its way out;
// confirm whether it should be deleted from the header as well.
class CastPerDeviceState : public PerDeviceOpState {
public:
CastPerDeviceState(FFHandler handle);
DataType input_data_type, output_data_type;
};

namespace Kernels {
namespace Cast {

void forward_kernel(ffStream_t stream,
CastPerDeviceState const *,
GenericTensorAccessorR const &input,
GenericTensorAccessorW const &output);
GenericTensorAccessorW const &output,
DataType input_type,
DataType output_type);

void backward_kernel(ffStream_t stream,
CastPerDeviceState const *,
GenericTensorAccessorR const &input,
GenericTensorAccessorW const &output);
GenericTensorAccessorW const &output,
DataType input_type,
DataType output_type);

} // namespace Cast
} // namespace Kernels
Expand Down
23 changes: 10 additions & 13 deletions lib/kernels/src/cuda/cast_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,12 @@
* limitations under the License.
*/

#include "device.h"
#include "kernels/cast_kernels.h"
#include "kernels/cuda_helper.h"
#include "kernels/datatype_dispatch.h"
#include "kernels/device.h"

namespace FlexFlow {

CastPerDeviceState::CastPerDeviceState(FFHandler handle)
: PerDeviceOpState(handle) {}

namespace Kernels {
namespace Cast {

Expand All @@ -43,7 +40,6 @@ __global__ void
template <DataType IDT, DataType ODT>
struct ForwardKernel {
void operator()(ffStream_t stream,
CastPerDeviceState const *m,
GenericTensorAccessorR const &input,
GenericTensorAccessorW const &output) {
size_t volume = input.shape.get_volume();
Expand All @@ -55,7 +51,6 @@ struct ForwardKernel {
template <DataType IDT, DataType ODT>
struct BackwardKernel {
void operator()(ffStream_t stream,
CastPerDeviceState const *m,
GenericTensorAccessorR const &input,
GenericTensorAccessorW const &output) {
size_t volume = input.shape.get_volume();
Expand All @@ -65,19 +60,21 @@ struct BackwardKernel {
};

void forward_kernel(ffStream_t stream,
CastPerDeviceState const *m,
GenericTensorAccessorR const &input,
GenericTensorAccessorW const &output) {
GenericTensorAccessorW const &output,
DataType input_type,
DataType output_type) {
DataTypeDispatch2<ForwardKernel>{}(
m->input_data_type, m->output_data_type, stream, m, input, output);
input_type, output_type, stream, handle, input, output);
}

void backward_kernel(ffStream_t stream,
CastPerDeviceState const *m,
GenericTensorAccessorR const &input,
GenericTensorAccessorW const &output) {
GenericTensorAccessorW const &output,
DataType input_type,
DataType output_type) {
DataTypeDispatch2<BackwardKernel>{}(
m->input_data_type, m->output_data_type, stream, m, input, output);
input_type, output_type, stream, handle, input, output);
}

} // namespace Cast
Expand Down
20 changes: 8 additions & 12 deletions lib/kernels/src/hip/cast_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,6 @@
#include <hip/hip_runtime.h>

namespace FlexFlow {

CastPerDeviceState::CastPerDeviceState(FFHandler handle)
: PerDeviceOpState(handle) {}

namespace Kernels {
namespace Cast {

Expand All @@ -44,7 +40,6 @@ __global__ void
template <DataType IDT, DataType ODT>
struct ForwardKernel {
void operator()(ffStream_t stream,
CastPerDeviceState const *m,
GenericTensorAccessorR const &input,
GenericTensorAccessorW const &output) {
size_t volume = input.shape.get_volume();
Expand All @@ -62,7 +57,6 @@ struct ForwardKernel {
template <DataType IDT, DataType ODT>
struct BackwardKernel {
void operator()(ffStream_t stream,
CastPerDeviceState const *m,
GenericTensorAccessorR const &input,
GenericTensorAccessorW const &output) {
size_t volume = input.shape.get_volume();
Expand All @@ -79,19 +73,21 @@ struct BackwardKernel {
};

void forward_kernel(ffStream_t stream,
CastPerDeviceState const *m,
GenericTensorAccessorR const &input,
GenericTensorAccessorW const &output) {
GenericTensorAccessorW const &output,
DataType input_type,
DataType output_type) {
DataTypeDispatch2<ForwardKernel>{}(
m->input_data_type, m->output_data_type, stream, m, input, output);
input_type, output_type, stream, input, output);
}

void backward_kernel(ffStream_t stream,
CastPerDeviceState const *m,
GenericTensorAccessorR const &input,
GenericTensorAccessorW const &output) {
GenericTensorAccessorW const &output,
DataType input_type,
DataType output_type) {
DataTypeDispatch2<BackwardKernel>{}(
m->input_data_type, m->output_data_type, stream, m, input, output);
input_type, output_type, stream, input, output);
}

} // namespace Cast
Expand Down
Loading