diff --git a/lib/kernels/include/kernels/attention_kernels.h b/lib/kernels/include/kernels/attention_kernels.h index 7c530bc10c..d474ea4da5 100644 --- a/lib/kernels/include/kernels/attention_kernels.h +++ b/lib/kernels/include/kernels/attention_kernels.h @@ -46,7 +46,7 @@ namespace Kernels { namespace MultiHeadAttention { MHAPerDeviceState init_kernel(PerDeviceFFHandle const &, - Allocator const &, + Allocator &, int num_samples, int num_heads, int qSize, @@ -80,6 +80,9 @@ void backward_kernel(ffStream_t stream, float *weight_grad_ptr, float const *output_grad_ptr); +void cleanup_kernel(Allocator &allocator, + MHAPerDeviceState const &device_state); + } // namespace MultiHeadAttention } // namespace Kernels } // namespace FlexFlow diff --git a/lib/kernels/include/kernels/batch_matmul_kernels.h b/lib/kernels/include/kernels/batch_matmul_kernels.h index 0e4437bdb8..c1966309a4 100644 --- a/lib/kernels/include/kernels/batch_matmul_kernels.h +++ b/lib/kernels/include/kernels/batch_matmul_kernels.h @@ -1,43 +1,36 @@ #ifndef _FLEXFLOW_OPS_KERNELS_BATCH_MATMUL_KERNELS_H #define _FLEXFLOW_OPS_KERNELS_BATCH_MATMUL_KERNELS_H +#include "kernels/allocation.h" #include "kernels/device.h" #include "kernels/ff_handle.h" +#include "utils/visitable.h" namespace FlexFlow { - -class BatchMatmulPerDeviceState : public PerDeviceOpState { -public: - BatchMatmulPerDeviceState(FFHandler handler); - int a_seq_length_dim, b_seq_length_dim; -}; - namespace Kernels { namespace BatchMatmul { void forward_kernel(ffStream_t stream, - BatchMatmulPerDeviceState const *, - float *o_ptr, - float const *a_ptr, - float const *b_ptr, - float const *c_ptr, + PerDeviceFFHandle const &handle, + float *output_ptr, + float const *a_input_ptr, + float const *b_input_ptr, int m, int n, int k, int batch, - int a_seq_length_dim = -1, - int b_seq_length_dim = -1, - int seq_length = -1); + int a_seq_length_dim, + int b_seq_length_dim, + int seq_length); void backward_kernel(ffStream_t stream, - BatchMatmulPerDeviceState const *, + PerDeviceFFHandle const &handle, float const *o_ptr, float const *o_grad_ptr, float const *a_ptr, float *a_grad_ptr, float const *b_ptr, float *b_grad_ptr, - float *c_grad_ptr, int m, int n, int k, diff --git a/lib/kernels/src/cuda/attention_kernels.cu b/lib/kernels/src/cuda/attention_kernels.cu index 5981179395..c2225c13d4 100644 --- a/lib/kernels/src/cuda/attention_kernels.cu +++ b/lib/kernels/src/cuda/attention_kernels.cu @@ -13,53 +13,64 @@ * limitations under the License. 
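For orientation before the implementation hunks: the two headers above drop the per-device state objects in favor of explicit arguments, so a batch-matmul call site is now expected to look roughly like the sketch below. This is an illustrative sketch, not code from the PR; the stream, handle, sizes, and device pointers are assumed to come from the surrounding runtime.

```cpp
// Hedged sketch of the new stateless batch-matmul API (all variables assumed).
ffStream_t stream;
checkCUDA(get_legion_stream(&stream));
Kernels::BatchMatmul::forward_kernel(stream,
                                     handle,      // PerDeviceFFHandle
                                     output_ptr,  // O: (batch, n, m)
                                     a_input_ptr, // A: (batch, n, k)
                                     b_input_ptr, // B: (batch, k, m)
                                     m, n, k, batch,
                                     /*a_seq_length_dim=*/-1,
                                     /*b_seq_length_dim=*/-1,
                                     /*seq_length=*/-1);
```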
*/ +#include "device.h" #include "kernels/attention_kernels.h" -#include "kernels/cuda_helper.h" +#include "kernels/device.h" namespace FlexFlow { namespace Kernels { namespace MultiHeadAttention { -void init_kernel(MHAPerDeviceState *m, - int num_samples, - int num_heads, - int qSize, - int kSize, - int vSize, - int qProjSize, - int kProjSize, - int vProjSize, - int oProjSize, - int qoSeqLength, - int kvSeqLength, - bool add_bias_kv) { +MHAPerDeviceState init_kernel(PerDeviceFFHandle const &handle, + Allocator &allocator, + int num_samples, + int num_heads, + int qSize, + int kSize, + int vSize, + int qProjSize, + int kProjSize, + int vProjSize, + int oProjSize, + int qoSeqLength, + int kvSeqLength, + bool add_bias_kv) { cudaStream_t stream; + ffAttnDescriptor_t attnDesc; + ffSeqDataDescriptor_t qDesc; + ffSeqDataDescriptor_t kDesc; + ffSeqDataDescriptor_t vDesc; + ffSeqDataDescriptor_t oDesc; + void *reserveSpace; + void *dropoutStates; + int *devQoSeqArray; + int *devKvSeqArray; + size_t reserveSpaceSize; + size_t dropoutStateSize; + size_t weightSize; + checkCUDA(get_legion_stream(&stream)); - checkCUDNN(cudnnSetStream(m->handle.dnn, stream)); - checkCUDNN(cudnnCreateAttnDescriptor(&m->attnDesc)); - checkCUDNN(cudnnCreateSeqDataDescriptor(&m->qDesc)); - checkCUDNN(cudnnCreateSeqDataDescriptor(&m->kDesc)); - checkCUDNN(cudnnCreateSeqDataDescriptor(&m->vDesc)); - checkCUDNN(cudnnCreateSeqDataDescriptor(&m->oDesc)); + checkCUDNN(cudnnSetStream(handle.dnn, stream)); + checkCUDNN(cudnnCreateAttnDescriptor(&attnDesc)); + checkCUDNN(cudnnCreateSeqDataDescriptor(&qDesc)); + checkCUDNN(cudnnCreateSeqDataDescriptor(&kDesc)); + checkCUDNN(cudnnCreateSeqDataDescriptor(&vDesc)); + checkCUDNN(cudnnCreateSeqDataDescriptor(&oDesc)); + // Currently do not support adding bias to key/value projection assert(!add_bias_kv); cudnnAttnQueryMap_t attnMode = CUDNN_ATTN_QUERYMAP_ALL_TO_ONE; + // Assume no beam search for now int maxBeamSize = 1; - // printf("batchSize(%d) qSize(%d) kSize(%d) vSize(%d) qProjSize(%d) - // kProjSize(%d)\n", - // num_samples, attn->qSize, attn->kSize, attn->vSize, attn->qProjSize, - // attn->kProjSize); - // printf("vProjSize(%d) oProjSize(%d) qoSeqLength(%d) kvSeqLength(%d)\n", - // attn->vProjSize, attn->oProjSize, attn->qoSeqLength, - // attn->kvSeqLength); + cudnnMathType_t math_type; - if (m->handle.allowTensorOpMathConversion) { + if (handle.allowTensorOpMathConversion) { math_type = CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION; } else { math_type = CUDNN_TENSOR_OP_MATH; } - checkCUDNN(cudnnSetAttnDescriptor(m->attnDesc, + checkCUDNN(cudnnSetAttnDescriptor(attnDesc, attnMode, num_heads, 1.0f /*smScalar*/, @@ -80,14 +91,10 @@ void init_kernel(MHAPerDeviceState *m, num_samples, maxBeamSize)); size_t workSpaceSize; - checkCUDNN(cudnnGetMultiHeadAttnBuffers(m->handle.dnn, - m->attnDesc, - &m->weightSize, - &workSpaceSize, - &m->reserveSpaceSize)); - assert(workSpaceSize <= m->handle.workSpaceSize); - // printf("weightSize(%zu) workSpaceSize(%zu) reserveSpaceSize(%zu)\n", - // weightSize, workSpaceSize, reserveSpaceSize); + checkCUDNN(cudnnGetMultiHeadAttnBuffers( + handle.dnn, attnDesc, &weightSize, &workSpaceSize, &reserveSpaceSize)); + assert(workSpaceSize <= handle.workSpaceSize); + int dimA[CUDNN_SEQDATA_DIM_COUNT]; cudnnSeqDataAxis_t axes[CUDNN_SEQDATA_DIM_COUNT]; assert(CUDNN_SEQDATA_DIM_COUNT == 4); @@ -95,8 +102,8 @@ void init_kernel(MHAPerDeviceState *m, axes[2] = CUDNN_SEQDATA_BEAM_DIM; axes[1] = CUDNN_SEQDATA_TIME_DIM; axes[0] = CUDNN_SEQDATA_BATCH_DIM; - int *qoSeqArray = 
(int *)malloc(sizeof(int) * num_samples); - int *kvSeqArray = (int *)malloc(sizeof(int) * num_samples); + std::unique_ptr<int[]> qoSeqArray(new int[num_samples]); + std::unique_ptr<int[]> kvSeqArray(new int[num_samples]); for (int i = 0; i < num_samples; i++) { qoSeqArray[i] = qoSeqLength; kvSeqArray[i] = kvSeqLength; } @@ -107,13 +114,13 @@ dimA[CUDNN_SEQDATA_BATCH_DIM] = num_samples; dimA[CUDNN_SEQDATA_TIME_DIM] = qoSeqLength; dimA[CUDNN_SEQDATA_VECT_DIM] = qSize; - checkCUDNN(cudnnSetSeqDataDescriptor(m->qDesc, + checkCUDNN(cudnnSetSeqDataDescriptor(qDesc, CUDNN_DATA_FLOAT, CUDNN_SEQDATA_DIM_COUNT, dimA, axes, num_samples, - qoSeqArray, + qoSeqArray.get(), NULL)); } // Set kDesc @@ -122,13 +129,13 @@ dimA[CUDNN_SEQDATA_BATCH_DIM] = num_samples; dimA[CUDNN_SEQDATA_TIME_DIM] = kvSeqLength; dimA[CUDNN_SEQDATA_VECT_DIM] = kSize; - checkCUDNN(cudnnSetSeqDataDescriptor(m->kDesc, + checkCUDNN(cudnnSetSeqDataDescriptor(kDesc, CUDNN_DATA_FLOAT, CUDNN_SEQDATA_DIM_COUNT, dimA, axes, num_samples, - kvSeqArray, + kvSeqArray.get(), NULL)); } // Set vDesc @@ -137,13 +144,13 @@ dimA[CUDNN_SEQDATA_BATCH_DIM] = num_samples; dimA[CUDNN_SEQDATA_TIME_DIM] = kvSeqLength; dimA[CUDNN_SEQDATA_VECT_DIM] = vSize; - checkCUDNN(cudnnSetSeqDataDescriptor(m->vDesc, + checkCUDNN(cudnnSetSeqDataDescriptor(vDesc, CUDNN_DATA_FLOAT, CUDNN_SEQDATA_DIM_COUNT, dimA, axes, num_samples, - kvSeqArray, + kvSeqArray.get(), NULL)); } // Set oDesc @@ -152,155 +159,92 @@ dimA[CUDNN_SEQDATA_BATCH_DIM] = num_samples; dimA[CUDNN_SEQDATA_TIME_DIM] = qoSeqLength; dimA[CUDNN_SEQDATA_VECT_DIM] = oProjSize; - checkCUDNN(cudnnSetSeqDataDescriptor(m->oDesc, + checkCUDNN(cudnnSetSeqDataDescriptor(oDesc, CUDNN_DATA_FLOAT, CUDNN_SEQDATA_DIM_COUNT, dimA, axes, num_samples, - qoSeqArray, + qoSeqArray.get(), NULL)); } // allocate memory for the seqArray and reserve space { - size_t totalSize = m->reserveSpaceSize + sizeof(int) * num_samples * 2; + size_t totalSize = reserveSpaceSize + sizeof(int) * num_samples * 2; - m->devQoSeqArray = (int *)m->gpu_alloc(totalSize); - checkCUDA(cudaMemcpy(m->devQoSeqArray, - qoSeqArray, + devQoSeqArray = (int *)allocator.allocate(totalSize); + checkCUDA(cudaMemcpy(devQoSeqArray, + qoSeqArray.get(), sizeof(int) * num_samples, cudaMemcpyHostToDevice)); - m->devKvSeqArray = m->devQoSeqArray + num_samples; - checkCUDA(cudaMemcpy(m->devKvSeqArray, - kvSeqArray, + devKvSeqArray = devQoSeqArray + num_samples; + checkCUDA(cudaMemcpy(devKvSeqArray, + kvSeqArray.get(), sizeof(int) * num_samples, cudaMemcpyHostToDevice)); - m->reserveSpace = m->devKvSeqArray + num_samples; + reserveSpace = devKvSeqArray + num_samples; } // allocate memory for loWinIdx/hiWinIdx - m->loWinIdx = (int *)malloc(sizeof(int) * qoSeqLength); - m->hiWinIdx = (int *)malloc(sizeof(int) * qoSeqLength); + int *loWinIdx = (int *)malloc(sizeof(int) * qoSeqLength); + int *hiWinIdx = (int *)malloc(sizeof(int) * qoSeqLength); for (int i = 0; i < qoSeqLength; i++) { - m->loWinIdx[i] = 0; - m->hiWinIdx[i] = kvSeqLength; + loWinIdx[i] = 0; + hiWinIdx[i] = kvSeqLength; } - free(qoSeqArray); - free(kvSeqArray); -} - -/* void forward_kernel_wrapper(MHAPerDeviceState const *m, */ /* float const *query_ptr, */ /* float const *key_ptr, */ /* float const *value_ptr, */ /* float const *weight_ptr, */ /* float *output_ptr) { */ /* wrapper(Internal::forward_kernel, m->profiling, ) */ /* cudaStream_t stream; */ /* 
checkCUDA(get_legion_stream(&stream)); */ - -/* cudaEvent_t t_start, t_end; */ -/* if (m->profiling) { */ -/* cudaEventCreate(&t_start); */ -/* cudaEventCreate(&t_end); */ -/* cudaEventRecord(t_start, stream); */ -/* } */ -/* Internal::forward_kernel( */ -/* m, query_ptr, key_ptr, value_ptr, weight_ptr, output_ptr, stream); */ -/* if (m->profiling) { */ -/* cudaEventRecord(t_end, stream); */ -/* checkCUDA(cudaEventSynchronize(t_end)); */ -/* float elapsed = 0; */ -/* checkCUDA(cudaEventElapsedTime(&elapsed, t_start, t_end)); */ -/* cudaEventDestroy(t_start); */ -/* cudaEventDestroy(t_end); */ -/* printf("MultiHeadAttention forward time = %.2fms\n", elapsed); */ -/* // print_tensor<3, float>(acc_query.ptr, acc_query.rect, */ -/* // "[Attention:forward:query]"); print_tensor<3, float>(acc_output.ptr, - */ -/* // acc_output.rect, "[Attention:forward:output]"); */ -/* } */ -/* } */ - -/* void backward_kernel_wrapper( */ -/* MHAPerDeviceState const *m, */ -/* float const *query_ptr, */ -/* float *query_grad_ptr, */ -/* float const *key_ptr, */ -/* float *key_grad_ptr, */ -/* float const *value_ptr, */ -/* float *value_grad_ptr, */ -/* float const *weight_ptr, */ -/* float *weight_grad_ptr, */ -/* float const *output_grad_ptr) { */ -/* cudaStream_t stream; */ -/* checkCUDA(get_legion_stream(&stream)); */ -/* cudaEvent_t t_start, t_end; */ -/* if (m->profiling) { */ -/* cudaEventCreate(&t_start); */ -/* cudaEventCreate(&t_end); */ -/* cudaEventRecord(t_start, stream); */ -/* } */ + MHAPerDeviceState per_device_state = {handle, + weightSize, + reserveSpaceSize, + attnDesc, + qDesc, + kDesc, + vDesc, + oDesc, + devQoSeqArray, + devKvSeqArray, + loWinIdx, + hiWinIdx, + reserveSpace, + allocator}; -/* Internal::backward_kernel(m, */ -/* query_ptr, */ -/* query_grad_ptr, */ -/* key_ptr, */ -/* key_grad_ptr, */ -/* value_ptr, */ -/* value_grad_ptr, */ -/* weight_ptr, */ -/* weight_grad_ptr, */ -/* output_grad_ptr, */ -/* stream); */ -/* if (m->profiling) { */ -/* cudaEventRecord(t_end, stream); */ -/* checkCUDA(cudaEventSynchronize(t_end)); */ -/* float elapsed = 0; */ -/* checkCUDA(cudaEventElapsedTime(&elapsed, t_start, t_end)); */ -/* cudaEventDestroy(t_start); */ -/* cudaEventDestroy(t_end); */ -/* printf("MultiHeadAttention backward time = %.2fms\n", elapsed); */ -/* } */ -/* } */ - -/* namespace Internal { */ + return per_device_state; +} void forward_kernel(cudaStream_t stream, - MHAPerDeviceState *m, + MHAPerDeviceState const &device_state, float const *query_ptr, float const *key_ptr, float const *value_ptr, float const *weight_ptr, float *output_ptr) { - checkCUDNN(cudnnSetStream(m->handle.dnn, stream)); + checkCUDNN(cudnnSetStream(device_state.handle.dnn, stream)); - checkCUDNN(cudnnMultiHeadAttnForward(m->handle.dnn, - m->attnDesc, + checkCUDNN(cudnnMultiHeadAttnForward(device_state.handle.dnn, + device_state.attnDesc, -1, - m->loWinIdx, - m->hiWinIdx, - m->devQoSeqArray, - m->devKvSeqArray, - m->qDesc, + device_state.loWinIdx, + device_state.hiWinIdx, + device_state.devQoSeqArray, + device_state.devKvSeqArray, + device_state.qDesc, query_ptr, nullptr /*residual*/, - m->kDesc, + device_state.kDesc, key_ptr, - m->vDesc, + device_state.vDesc, value_ptr, - m->oDesc, + device_state.oDesc, output_ptr, - m->weightSize, + device_state.weightSize, weight_ptr, - m->handle.workSpaceSize, - m->handle.workSpace, - m->reserveSpaceSize, - m->reserveSpace)); + device_state.handle.workSpaceSize, + device_state.handle.workSpace, + device_state.reserveSpaceSize, + device_state.reserveSpace)); } void 
backward_kernel(cudaStream_t stream, - MHAPerDeviceState *m, + MHAPerDeviceState const &device_state, float const *query_ptr, float *query_grad_ptr, float const *key_ptr, @@ -310,101 +254,63 @@ void backward_kernel(cudaStream_t stream, float const *weight_ptr, float *weight_grad_ptr, float const *output_grad_ptr) { - checkCUDNN(cudnnSetStream(m->handle.dnn, stream)); + checkCUDNN(cudnnSetStream(device_state.handle.dnn, stream)); - checkCUDNN(cudnnMultiHeadAttnBackwardData(m->handle.dnn, - m->attnDesc, - m->loWinIdx, - m->hiWinIdx, - m->devQoSeqArray, - m->devKvSeqArray, - m->oDesc, + checkCUDNN(cudnnMultiHeadAttnBackwardData(device_state.handle.dnn, + device_state.attnDesc, + device_state.loWinIdx, + device_state.hiWinIdx, + device_state.devQoSeqArray, + device_state.devKvSeqArray, + device_state.oDesc, output_grad_ptr, - m->qDesc, + device_state.qDesc, query_grad_ptr, query_ptr, - m->kDesc, + device_state.kDesc, key_grad_ptr, key_ptr, - m->vDesc, + device_state.vDesc, value_grad_ptr, value_ptr, - m->weightSize, + device_state.weightSize, weight_ptr, - m->handle.workSpaceSize, - m->handle.workSpace, - m->reserveSpaceSize, - m->reserveSpace)); - checkCUDNN(cudnnMultiHeadAttnBackwardWeights(m->handle.dnn, - m->attnDesc, - CUDNN_WGRAD_MODE_ADD, - m->qDesc, - query_ptr, - m->kDesc, - key_ptr, - m->vDesc, - value_ptr, - m->oDesc, - output_grad_ptr, - m->weightSize, - weight_ptr, - weight_grad_ptr, - m->handle.workSpaceSize, - m->handle.workSpace, - m->reserveSpaceSize, - m->reserveSpace)); + device_state.handle.workSpaceSize, + device_state.handle.workSpace, + device_state.reserveSpaceSize, + device_state.reserveSpace)); + checkCUDNN( + cudnnMultiHeadAttnBackwardWeights(device_state.handle.dnn, + device_state.attnDesc, + CUDNN_WGRAD_MODE_ADD, + device_state.qDesc, + query_ptr, + device_state.kDesc, + key_ptr, + device_state.vDesc, + value_ptr, + device_state.oDesc, + output_grad_ptr, + device_state.weightSize, + weight_ptr, + weight_grad_ptr, + device_state.handle.workSpaceSize, + device_state.handle.workSpace, + device_state.reserveSpaceSize, + device_state.reserveSpace)); +} + +void cleanup_kernel(Allocator &allocator, + MHAPerDeviceState const &device_state) { + // loWinIdx/hiWinIdx are host arrays obtained from malloc in init_kernel, + // so release them with free rather than through the device allocator + free(device_state.loWinIdx); + free(device_state.hiWinIdx); + // devQoSeqArray heads the single device block allocated in init_kernel, + // which also backs devKvSeqArray and reserveSpace + allocator.deallocate(device_state.devQoSeqArray); + checkCUDNN(cudnnDestroyAttnDescriptor(device_state.attnDesc)); + checkCUDNN(cudnnDestroySeqDataDescriptor(device_state.qDesc)); + checkCUDNN(cudnnDestroySeqDataDescriptor(device_state.kDesc)); + checkCUDNN(cudnnDestroySeqDataDescriptor(device_state.vDesc)); + checkCUDNN(cudnnDestroySeqDataDescriptor(device_state.oDesc)); } -/* } // namespace Internal */ } // namespace MultiHeadAttention } // namespace Kernels - -MHAPerDeviceState::MHAPerDeviceState(FFHandler handler, - Memory gpu_mem, - int num_samples, - int num_heads, - int qSize, - int kSize, - int vSize, - int qProjSize, - int kProjSize, - int vProjSize, - int oProjSize, - int qoSeqLength, - int kvSeqLength, - bool add_bias_kv) - : PerDeviceOpState(handler) {} - -MHAPerDeviceState::MHAPerDeviceState(FFHandler handler, - std::unique_ptr allocator, - MultiHeadAttentionAttrs const &attrs, - ArrayShape const &query_shape, - ArrayShape const &key_shape, - ArrayShape const &value_shape) { - : MHAPerDeviceState(handler, - allocator, - query_shape[2], - attrs.num_heads, - query_shape[0], - key_shape[0], - value_shape[0], - qProjSize(attrs), - kProjSize(attrs), - vProjSize(attrs), - oProjSize(attrs), - query_shape[1], - key_shape[1], - attrs.add_bias_kv) -{ } - - 
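Taken together, init_kernel, the forward/backward kernels, and cleanup_kernel above replace the constructor/destructor lifecycle that the removed MHAPerDeviceState members provided. A rough sketch of the intended call sequence (the handle, allocator, dimensions, and pointers are assumptions, not code from the PR):

```cpp
// Hedged lifecycle sketch for the reworked attention kernels.
MHAPerDeviceState state =
    Kernels::MultiHeadAttention::init_kernel(handle,
                                             allocator,
                                             num_samples,
                                             num_heads,
                                             qSize, kSize, vSize,
                                             qProjSize, kProjSize,
                                             vProjSize, oProjSize,
                                             qoSeqLength, kvSeqLength,
                                             /*add_bias_kv=*/false);
Kernels::MultiHeadAttention::forward_kernel(
    stream, state, query_ptr, key_ptr, value_ptr, weight_ptr, output_ptr);
// ... backward_kernel(stream, state, ...) during training ...
Kernels::MultiHeadAttention::cleanup_kernel(allocator, state);
```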
MHAPerDeviceState::~MHAPerDeviceState(void) { - free(loWinIdx); - free(hiWinIdx); - checkCUDNN(cudnnDestroyAttnDescriptor(attnDesc)); - checkCUDNN(cudnnDestroySeqDataDescriptor(qDesc)); - checkCUDNN(cudnnDestroySeqDataDescriptor(kDesc)); - checkCUDNN(cudnnDestroySeqDataDescriptor(vDesc)); - checkCUDNN(cudnnDestroySeqDataDescriptor(oDesc)); - } - } // namespace FlexFlow diff --git a/lib/kernels/src/cuda/batch_matmul_kernels.cu b/lib/kernels/src/cuda/batch_matmul_kernels.cu index 3593ac4ab2..9d35cb6c1a 100644 --- a/lib/kernels/src/cuda/batch_matmul_kernels.cu +++ b/lib/kernels/src/cuda/batch_matmul_kernels.cu @@ -13,136 +13,28 @@ * limitations under the License. */ +#include "device.h" #include "kernels/batch_matmul_kernels.h" -#include "kernels/cuda_helper.h" +#include "kernels/device.h" namespace FlexFlow { - -BatchMatmulPerDeviceState::BatchMatmulPerDeviceState(FFHandler handler) - : PerDeviceOpState(handler) {} - namespace Kernels { namespace BatchMatmul { -/* void forward_kernel_wrapper(BatchMatmulPerDeviceState const *meta, */ -/* float *o_ptr, */ -/* float const *a_ptr, */ -/* float const *b_ptr, */ -/* float const *c_ptr, */ -/* int m, */ -/* int n, */ -/* int k, */ -/* int batch, */ -/* int a_seq_length_dim, */ -/* int b_seq_length_dim, */ -/* int seq_length) { */ -/* cudaStream_t stream; */ -/* */ - -/* cudaEvent_t t_start, t_end; */ -/* if (meta->profiling) { */ -/* cudaEventCreate(&t_start); */ -/* cudaEventCreate(&t_end); */ -/* cudaEventRecord(t_start, stream); */ -/* } */ -/* Internal::forward_kernel(meta, */ -/* o_ptr, */ -/* a_ptr, */ -/* b_ptr, */ -/* c_ptr, */ -/* m, */ -/* n, */ -/* k, */ -/* batch, */ -/* stream, */ -/* a_seq_length_dim, */ -/* b_seq_length_dim, */ -/* seq_length); */ -/* if (meta->profiling) { */ -/* cudaEventRecord(t_end, stream); */ -/* checkCUDA(cudaEventSynchronize(t_end)); */ -/* float elapsed = 0; */ -/* checkCUDA(cudaEventElapsedTime(&elapsed, t_start, t_end)); */ -/* cudaEventDestroy(t_start); */ -/* cudaEventDestroy(t_end); */ -/* printf("BatchMatmul forward time = %.2lfms\n", elapsed); */ -/* } */ -/* } */ - -/* void backward_kernel_wrapper(BatchMatmulPerDeviceState const *meta, */ -/* float const *o_ptr, */ -/* float const *o_grad_ptr, */ -/* float const *a_ptr, */ -/* float *a_grad_ptr, */ -/* float const *b_ptr, */ -/* float *b_grad_ptr, */ -/* float *c_grad_ptr, */ -/* int m, */ -/* int n, */ -/* int k, */ -/* int batch) { */ -/* cudaStream_t stream; */ -/* */ - -/* cudaEvent_t t_start, t_end; */ -/* if (meta->profiling) { */ -/* cudaEventCreate(&t_start); */ -/* cudaEventCreate(&t_end); */ -/* cudaEventRecord(t_start, stream); */ -/* } */ -/* Internal::backward_kernel(meta, */ -/* o_ptr, */ -/* o_grad_ptr, */ -/* a_ptr, */ -/* a_grad_ptr, */ -/* b_ptr, */ -/* b_grad_ptr, */ -/* c_grad_ptr, */ -/* m, */ -/* n, */ -/* k, */ -/* batch, */ -/* stream); */ -/* if (meta->profiling) { */ -/* cudaEventRecord(t_end, stream); */ -/* checkCUDA(cudaEventSynchronize(t_end)); */ -/* float elapsed = 0; */ -/* checkCUDA(cudaEventElapsedTime(&elapsed, t_start, t_end)); */ -/* cudaEventDestroy(t_start); */ -/* cudaEventDestroy(t_end); */ -/* printf("BatchMatmul backward time = %.2lfms\n", elapsed); */ -/* } */ -/* } */ - -/* namespace Internal { */ - -/* -A: (batch, n, k) -B: (batch, k, m) -O: (batch, n, m) -O = A * B -*/ - void forward_kernel(cudaStream_t stream, - BatchMatmulPerDeviceState const *meta, - float *o_ptr, - float const *a_ptr, - float const *b_ptr, - float const *c_ptr, + PerDeviceFFHandle const &handle, + float *output_ptr, + float const 
*a_input_ptr, + float const *b_input_ptr, int m, int n, int k, int batch, - cudaStream_t stream, int a_seq_length_dim, int b_seq_length_dim, int seq_length) { - checkCUDA(cublasSetStream(meta->handle.blas, stream)); - checkCUDNN(cudnnSetStream(meta->handle.dnn, stream)); - - // int a_stride = n * k; - // int b_stride = m * k; - // int o_stride = n * m; + checkCUDA(cublasSetStream(handle.blas, stream)); + checkCUDNN(cudnnSetStream(handle.dnn, stream)); int lda = k; int ldb = m; int ldo = m; @@ -172,56 +64,46 @@ void forward_kernel(cudaStream_t stream, } float alpha = 1.0f, beta = 0.0f; - checkCUDA(cublasSgemmStridedBatched(meta->handle.blas, + checkCUDA(cublasSgemmStridedBatched(handle.blas, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, &alpha, - b_ptr, + b_input_ptr, ldb, strideB, - a_ptr, + a_input_ptr, lda, strideA, &beta, - o_ptr, + output_ptr, ldo, strideO, batch)); - // current assume c is null - assert(c_ptr == NULL); } -/* -A, AGrad: (batch, n, k) -B, BGrad: (batch, k, m) -O, OGrad: (batch, n, m) -AGrad = OGrad * B^T -BGrad = A^T * OGrad -*/ void backward_kernel(cudaStream_t stream, - BatchMatmulPerDeviceState const *meta, + PerDeviceFFHandle const &handle, float const *o_ptr, float const *o_grad_ptr, float const *a_ptr, float *a_grad_ptr, float const *b_ptr, float *b_grad_ptr, - float *c_grad_ptr, int m, int n, int k, int batch) { - checkCUDA(cublasSetStream(meta->handle.blas, stream)); - checkCUDNN(cudnnSetStream(meta->handle.dnn, stream)); + checkCUDA(cublasSetStream(handle.blas, stream)); + checkCUDNN(cudnnSetStream(handle.dnn, stream)); int a_stride = n * k; int b_stride = m * k; int o_stride = n * m; float alpha = 1.0f; - checkCUDA(cublasSgemmStridedBatched(meta->handle.blas, + checkCUDA(cublasSgemmStridedBatched(handle.blas, CUBLAS_OP_T, CUBLAS_OP_N, k, @@ -239,7 +121,7 @@ void backward_kernel(cudaStream_t stream, k, a_stride, batch)); - checkCUDA(cublasSgemmStridedBatched(meta->handle.blas, + checkCUDA(cublasSgemmStridedBatched(handle.blas, CUBLAS_OP_N, CUBLAS_OP_T, m, @@ -257,10 +139,8 @@ void backward_kernel(cudaStream_t stream, m, b_stride, batch)); - assert(c_grad_ptr == NULL); } -/* } // namespace Internal */ } // namespace BatchMatmul } // namespace Kernels } // namespace FlexFlow diff --git a/lib/kernels/src/hip/batch_matmul_kernels.cpp b/lib/kernels/src/hip/batch_matmul_kernels.cpp index d8b6500326..cbfd669e0f 100644 --- a/lib/kernels/src/hip/batch_matmul_kernels.cpp +++ b/lib/kernels/src/hip/batch_matmul_kernels.cpp @@ -18,10 +18,6 @@ #include namespace FlexFlow { - -BatchMatmulPerDeviceState::BatchMatmulPerDeviceState(FFHandler handler) - : PerDeviceOpState(handler) {} - namespace Kernels { namespace BatchMatmul { @@ -32,21 +28,19 @@ O: (batch, n, m) O = A * B */ void forward_kernel(hipStream_t stream, - BatchMatmulPerDeviceState const *meta, - float *o_ptr, - float const *a_ptr, - float const *b_ptr, - float const *c_ptr, + PerDeviceFFHandle const &handle, + float *output_ptr, + float const *a_input_ptr, + float const *b_input_ptr, int m, int n, int k, int batch, - hipStream_t stream, int a_seq_length_dim, int b_seq_length_dim, - int seq_length) { - checkCUDA(hipblasSetStream(meta->handle.blas, stream)); - checkCUDNN(miopenSetStream(meta->handle.dnn, stream)); + int seq_length = -1) { + checkCUDA(hipblasSetStream(handle.blas, stream)); + checkCUDNN(miopenSetStream(handle.dnn, stream)); // int a_stride = n * k; // int b_stride = m * k; @@ -80,7 +74,7 @@ void forward_kernel(hipStream_t stream, } float alpha = 1.0f, beta = 0.0f; - 
checkCUDA(hipblasSgemmStridedBatched(meta->handle.blas, + checkCUDA(hipblasSgemmStridedBatched(handle.blas, HIPBLAS_OP_N, HIPBLAS_OP_N, m, @@ -98,8 +92,6 @@ void forward_kernel(hipStream_t stream, ldo, strideO, batch)); - // current assume c is null - assert(c_ptr == NULL); } /* @@ -110,26 +102,25 @@ AGrad = OGrad * B^T BGrad = A^T * OGrad */ void backward_kernel(hipStream_t stream, - BatchMatmulPerDeviceState const *meta, + PerDeviceFFHandle const &handle, float const *o_ptr, float const *o_grad_ptr, float const *a_ptr, float *a_grad_ptr, float const *b_ptr, float *b_grad_ptr, - float *c_grad_ptr, int m, int n, int k, int batch) { - checkCUDA(hipblasSetStream(meta->handle.blas, stream)); - checkCUDNN(miopenSetStream(meta->handle.dnn, stream)); + checkCUDA(hipblasSetStream(handle.blas, stream)); + checkCUDNN(miopenSetStream(handle.dnn, stream)); int a_stride = n * k; int b_stride = m * k; int o_stride = n * m; float alpha = 1.0f; - checkCUDA(hipblasSgemmStridedBatched(meta->handle.blas, + checkCUDA(hipblasSgemmStridedBatched(handle.blas, HIPBLAS_OP_T, HIPBLAS_OP_N, k, @@ -147,7 +138,7 @@ void backward_kernel(hipStream_t stream, k, a_stride, batch)); - checkCUDA(hipblasSgemmStridedBatched(meta->handle.blas, + checkCUDA(hipblasSgemmStridedBatched(handle.blas, HIPBLAS_OP_N, HIPBLAS_OP_T, m, @@ -165,10 +156,8 @@ void backward_kernel(hipStream_t stream, m, b_stride, batch)); - assert(c_grad_ptr == NULL); } -} // namespace Internal } // namespace BatchMatmul } // namespace Kernels } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/batch_matmul.h b/lib/op-attrs/include/op-attrs/ops/batch_matmul.h index c74824570c..b05a5eb022 100644 --- a/lib/op-attrs/include/op-attrs/ops/batch_matmul.h +++ b/lib/op-attrs/include/op-attrs/ops/batch_matmul.h @@ -13,7 +13,6 @@ struct BatchMatmulAttrs { FF_VISITABLE_STRUCT(BatchMatmulAttrs, a_seq_length_dim, b_seq_length_dim); CHECK_VALID_OP_ATTR(BatchMatmulAttrs); - } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/parallel_dim.h b/lib/op-attrs/include/op-attrs/parallel_dim.h index b9df2d9037..9d407ec469 100644 --- a/lib/op-attrs/include/op-attrs/parallel_dim.h +++ b/lib/op-attrs/include/op-attrs/parallel_dim.h @@ -11,7 +11,10 @@ struct ParallelDim { int degree; req is_replica_dim; }; -FF_VISITABLE_STRUCT(ParallelDim, size, degree, is_replica_dim); +FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(ParallelDim, + size, + degree, + is_replica_dim); bool is_valid(ParallelDim const &); bool is_replica_dim(ParallelDim const &); diff --git a/lib/runtime/src/ops/attention.cc b/lib/runtime/src/ops/attention.cc index bca87bdb53..41905f9014 100644 --- a/lib/runtime/src/ops/attention.cc +++ b/lib/runtime/src/ops/attention.cc @@ -17,6 +17,7 @@ #include "kernels/attention_kernels.h" #include "legion.h" #include "op-attrs/ops/attention.h" +#include "task_spec/op_task_signature.h" namespace FlexFlow { @@ -121,18 +122,6 @@ static DeviceSpecific int num_samples = get_piece_shape(query_parallel_tensor_shape)[ff_dim_t(2)]; int num_heads = get_piece_shape(weight_parallel_tensor_shape)[ff_dim_t(1)]; - assert(qoSeqLength == query.shape[legion_dim_t(1)]); - assert(qSize == query.shape[legion_dim_t(0)]); - assert(num_samples == key.shape[legion_dim_t(2)]); - assert(kvSeqLength == key.shape[legion_dim_t(1)]); - assert(kSize == key.shape[legion_dim_t(0)]); - assert(num_samples == value.shape[legion_dim_t(2)]); - assert(kvSeqLength == value.shape[legion_dim_t(1)]); - assert(vSize == value.shape[legion_dim_t(0)]); - assert(num_samples == 
output.shape[legion_dim_t(2)]); - assert(qoSeqLength == output.shape[legion_dim_t(1)]); - assert(oProjSize == output.shape[legion_dim_t(0)]); - DeviceSpecific per_device_state = acc.create_device_specific( init_kernel(handle, @@ -149,9 +138,6 @@ static DeviceSpecific qoSeqLength, kvSeqLength, attrs.add_bias_kv)); - - assert(weight.shape.get_volume() * sizeof(float) == - acc.unwrap(per_device_state)->weightSize); return per_device_state; } @@ -299,7 +285,7 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim, } template <> -void register_task() { +OpTaskSignature init_signature() { OpTaskSignature init(OpTaskType::INIT); init.add_arg_slot(QUERY_PARALLEL_TENSOR_SHAPE); init.add_arg_slot(KEY_PARALLEL_TENSOR_SHAPE); @@ -313,12 +299,19 @@ void register_task() { init.add_return_value(); - register_task( - ATTENTION_INIT_TASK_ID, "MultiHeadAttention Init", init, init_task); + return init; } template <> -void register_task() { +void register_task() { + register_task(ATTENTION_INIT_TASK_ID, + "Attention Init", + init_signature(), + init_task); +} + +template <> +OpTaskSignature fwd_signature() { OpTaskSignature fwd(OpTaskType::FWD); fwd.add_input_slot(QUERY); @@ -330,17 +323,31 @@ void register_task() { fwd.add_arg_slot(PROFILING); fwd.add_unchecked_arg_slot(PER_DEVICE_STATE); - register_task( - ATTENTION_FWD_TASK_ID, "MultiHeadAttention Fwd", fwd, forward_task); + return fwd; } template <> -void register_task() { +void register_task() { + register_task(ATTENTION_FWD_TASK_ID, + "Attention Fwd", + fwd_signature(), + forward_task); +} + +template <> +OpTaskSignature bwd_signature() { OpTaskSignature bwd = infer_bwd_signature(get_op_signature(ATTENTION_FWD_TASK_ID)); - register_task( - ATTENTION_BWD_TASK_ID, "MultiHeadAttention Bwd", bwd, backward_task); + return bwd; +} + +template <> +void register_task() { + register_task(ATTENTION_BWD_TASK_ID, + "Attention Bwd", + bwd_signature(), + backward_task); } } // namespace FlexFlow diff --git a/lib/runtime/src/ops/attention.h b/lib/runtime/src/ops/attention.h index f0a5e0abc3..09a4ef036f 100644 --- a/lib/runtime/src/ops/attention.h +++ b/lib/runtime/src/ops/attention.h @@ -20,9 +20,9 @@ OpTaskInvocation backward(MultiHeadAttentionAttrs const &); CostMetrics measure_operator_cost(SimEnvFactory const &sim, MultiHeadAttentionAttrs const &attrs, - ParallelTensorShape const &query_shape, - ParallelTensorShape const &key_shape, - ParallelTensorShape const &value_shape, + InputParallelTensorDesc const &query_shape, + InputParallelTensorDesc const &key_shape, + InputParallelTensorDesc const &value_shape, ProfilingSettings const &settings, MachineView const &mv); } // namespace FlexFlow diff --git a/lib/runtime/src/ops/batch_matmul.cc b/lib/runtime/src/ops/batch_matmul.cc index 3e860bd413..5f40def699 100644 --- a/lib/runtime/src/ops/batch_matmul.cc +++ b/lib/runtime/src/ops/batch_matmul.cc @@ -15,752 +15,233 @@ #include "batch_matmul.h" #include "kernels/batch_matmul_kernels.h" -#include "kernels/profiling.h" -#include "legion/legion_utilities.h" -#include "tasks.h" +#include "legion.h" +#include "op-attrs/get_output_shapes.h" +#include "op-attrs/ops/batch_matmul.h" +#include "task_spec/op_task_signature.h" namespace FlexFlow { using namespace FlexFlow::Kernels::BatchMatmul; +using Legion::Context; +using Legion::PhysicalRegion; +using Legion::Runtime; +using Legion::Task; + enum Slots { - A_INPUT, - B_INPUT, - OUTPUT, - A_INPUT_GRAD, - B_INPUT_GRAD, - OUTPUT_GRAD, + A_INPUT, // tensor + B_INPUT, // tensor ATTRS, - PROFILING + OUTPUT, // tensor + 
PROFILING, + HANDLE, + ITERATION_CONFIG }; -OpTaskInvocation init(BatchMatmulAttrs const &attrs) { - OpTaskBinding b; - - b.bind_arg(ATTRS, attrs); - b.bind_arg(PROFILING, enable_profiling()); - - return {BATCHMATMUL_INIT_TASK_ID, b}; -} - OpTaskInvocation forward(BatchMatmulAttrs const &attrs) { - OpTaskBinding b; - - b.bind(A_INPUT, input_tensor(0)); - b.bind(B_INPUT, input_tensor(1)); - b.bind(OUTPUT, output_tensor(0)); + OpTaskBinding fwd; - return {BATCHMATMUL_FWD_TASK_ID, b}; -} - -OpTaskInvocation backward(BatchMatmulAttrs const &attrs) { - OpTaskBinding b = infer_bwd_binding(forward(attrs).binding); - - return {BATCHMATMUL_BWD_TASK_ID, b}; -} + fwd.bind(A_INPUT, input_tensor(0)); + fwd.bind(B_INPUT, input_tensor(1)); + fwd.bind(OUTPUT, output_tensor(0)); -BatchMatmulParams BatchMatmul::get_params() const { - BatchMatmulParams params; - params.a_seq_length_dim = inputs[0]->num_dims - 1 - this->a_seq_length_dim; - params.b_seq_length_dim = inputs[1]->num_dims - 1 - this->b_seq_length_dim; - return params; -} + fwd.bind_arg(ATTRS, attrs); + fwd.bind_arg(HANDLE, ff_handle()); + fwd.bind_arg(PROFILING, profiling_settings()); + fwd.bind_arg(ITERATION_CONFIG, iteration_config()); -Tensor FFModel::batch_matmul(const Tensor A, - const Tensor B, - int a_seq_length_dim, - int b_seq_length_dim, - char const *name) { - Layer *bmm = new Layer(this, - OP_BATCHMATMUL, - DT_FLOAT, - name, - 2 /*inputs*/, - 0 /*weights*/, - 1 /*outputs*/, - A, - B); - assert((a_seq_length_dim <= 1) && - "FlexFlow currently only supports seq_length_dim of 0 or 1 (in " - "Fortran ordering)."); - assert((b_seq_length_dim <= 1) && - "FlexFlow currently only supports seq_length_dim of 0 or 1 (in " - "Fortran ordering)."); - assert(A->num_dims == B->num_dims); - for (int i = A->num_dims - 1; i >= 2; i--) { - assert(A->dims[i] == B->dims[i]); - } - assert(A->dims[0] == B->dims[1]); - int dims[MAX_TENSOR_DIM]; - int numdim = A->num_dims; - for (int i = 0; i < numdim; i++) { - dims[i] = A->dims[i]; - } - dims[0] = B->dims[0]; - bmm->outputs[0] = create_tensor_legion_ordering( - numdim, dims, A->data_type, bmm, 0, true /*create_grad*/); - bmm->add_int_property("a_seq_length_dim", a_seq_length_dim); - bmm->add_int_property("b_seq_length_dim", b_seq_length_dim); - layers.push_back(bmm); - return bmm->outputs[0]; + return {BATCHMATMUL_FWD_TASK_ID, fwd}; } -Op *BatchMatmul::create_operator_from_layer( - FFModel &model, - Layer const *layer, - std::vector const &inputs) { - long long value; - layer->get_int_property("a_seq_length_dim", value); - int a_seq_length_dim = value; - layer->get_int_property("b_seq_length_dim", value); - int b_seq_length_dim = value; - return new BatchMatmul(model, - inputs[0], - inputs[1], - a_seq_length_dim, - b_seq_length_dim, - layer->name); -} +OpTaskInvocation backward(BatchMatmulAttrs const &attrs) { + OpTaskBinding bwd = infer_bwd_binding(forward(attrs).binding); -BatchMatmul::BatchMatmul( - FFModel &model, - BatchMatmulParams const ¶ms, - std::pair const &inputs, - char const *name) - : BatchMatmul(model, - inputs.first, - inputs.second, - params.a_seq_length_dim, - params.b_seq_length_dim, - name) {} - -// return A*B -BatchMatmul::BatchMatmul(FFModel &model, - const ParallelTensor A, - const ParallelTensor B, - int _a_seq_length_dim, - int _b_seq_length_dim, - char const *name) - : Op(model, - OP_BATCHMATMUL, - DT_FLOAT, - name, - 2 /*inputs*/, - 0 /*weights*/, - 1 /*outputs*/, - A, - B), - a_seq_length_dim(A->num_dims - 1 - _a_seq_length_dim), - b_seq_length_dim(B->num_dims - 1 - 
_b_seq_length_dim) { - assert((_a_seq_length_dim <= 1) && - "FlexFlow currently only supports seq_length_dim of 0 or 1 (in " - "Fortran ordering)."); - assert((_b_seq_length_dim <= 1) && - "FlexFlow currently only supports seq_length_dim of 0 or 1 (in " - "Fortran ordering)."); - assert(A->num_dims == B->num_dims); - for (int i = A->num_dims - 1; i >= 2; i--) { - assert(A->dims[i] == B->dims[i]); - } - assert(A->dims[0] == B->dims[1]); - ParallelDim dims[MAX_TENSOR_DIM]; - for (int i = 0; i < A->num_dims; i++) { - dims[i] = A->dims[i]; - } - dims[0] = B->dims[0]; - numOutputs = 1; - outputs[0] = model.create_parallel_tensor_legion_ordering( - A->num_dims, dims, DT_FLOAT, this); - // C is not none - // if (C != Tensor::NO_TENSOR) { - // numInputs = 3; - // assert(C.num_dims == outputs[0].num_dims); - // for (int i = 0; i < C.num_dims; i++) - // assert(C.adim[i] == outputs[0].adim[i]); - //} + return {BATCHMATMUL_BWD_TASK_ID, bwd}; } -void BatchMatmul::serialize(Legion::Serializer &sez) const { - BatchMatmulParams params = get_params(); - sez.serialize(params.a_seq_length_dim); - sez.serialize(params.b_seq_length_dim); -} +static optional forward_task_impl(TaskArgumentAccessor const &acc) { + auto a_input = acc.get_tensor(A_INPUT); + auto b_input = acc.get_tensor(B_INPUT); + auto output = acc.get_tensor(OUTPUT); + auto attrs = acc.get_argument(ATTRS); + PerDeviceFFHandle handle = acc.get_argument(HANDLE); -using PCG::Node; -/*static*/ -Node BatchMatmul::deserialize(FFModel &ff, - Legion::Deserializer &dez, - ParallelTensor inputs[], - int num_inputs) { - assert(num_inputs == 2); - int a_seq_length_dim, b_seq_length_dim; - dez.deserialize(a_seq_length_dim); - dez.deserialize(b_seq_length_dim); - - BatchMatmulParams params; - params.a_seq_length_dim = a_seq_length_dim; - params.b_seq_length_dim = b_seq_length_dim; - return ff.get_or_create_node({inputs[0], inputs[1]}, params); -} + ProfilingSettings profiling = acc.get_argument(PROFILING); + FFIterationConfig iter_config = + acc.get_argument(ITERATION_CONFIG); -Op *BatchMatmul::materialize(FFModel &ff, - ParallelTensor inputs[], - int num_inputs) const { - BatchMatmulParams params = get_params(); - return new BatchMatmul(ff, params, {inputs[0], inputs[1]}, this->name); -} + int m = b_input.shape[legion_dim_t(0)]; + assert(m == output.shape[legion_dim_t(0)]); + int n = a_input.shape[legion_dim_t(1)]; + assert(n == output.shape[legion_dim_t(1)]); + int k = a_input.shape[legion_dim_t(0)]; + assert(k == b_input.shape[legion_dim_t(1)]); -template <> -void register_task() { - OpTaskSignature sig(OpTaskType::INIT); + assert(a_input.shape.get_volume() == b_input.shape.get_volume()); + assert(a_input.shape.get_volume() == output.shape.get_volume()); - sig.add_arg_slot(ATTRS); - sig.add_arg_slot(PROFILING); + int batch = 1; + for (int i = 2; i < a_input.shape.get_dim(); i++) { + int dim_size = a_input.shape[legion_dim_t(i)]; + assert(dim_size == b_input.shape[legion_dim_t(i)]); + assert(dim_size == output.shape[legion_dim_t(i)]); + batch *= dim_size; + } - register_task(BATCHMATMUL_INIT_TASK_ID, "BatchMatmul Init", sig, init_task); + return profile(forward_kernel, + profiling, + "[BatchMatmul] forward_time = %.2lfms\n", + handle, + output.get_float_ptr(), + a_input.get_float_ptr(), + b_input.get_float_ptr(), + m, + n, + k, + batch, + attrs.a_seq_length_dim, + attrs.b_seq_length_dim, + iter_config.seq_length); +} + +static void forward_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { + TaskArgumentAccessor acc(task, 
regions, ctx, runtime); + forward_task_impl(acc); } -static OpTaskSignature get_fwd_task_signature() { - OpTaskSignature fwd(OpTaskType::FWD); +static optional backward_task_impl(TaskArgumentAccessor const &acc) { + // BatchMatmul* bmm = (BatchMatmul*) task->args; + FFIterationConfig iter_config = + acc.get_argument(ITERATION_CONFIG); + ProfilingSettings profiling = acc.get_argument(PROFILING); + PerDeviceFFHandle handle = acc.get_argument(HANDLE); - fwd.add_input_slot(A_INPUT, READ_WRITE); - fwd.add_input_slot(B_INPUT, READ_WRITE); - fwd.add_output_slot(OUTPUT); + auto output = acc.get_tensor(OUTPUT); + auto output_grad = acc.get_tensor_grad(OUTPUT); + assert(output.shape == output_grad.shape); - return fwd; -} + auto a_input = acc.get_tensor(A_INPUT); + auto a_input_grad = acc.get_tensor_grad(A_INPUT); + assert(a_input.shape == a_input_grad.shape); -static OpTaskSignature get_bwd_task_signature() { - OpTaskSignature bwd(OpTaskType::BWD); + auto b_input = acc.get_tensor(B_INPUT); + auto b_input_grad = acc.get_tensor_grad(B_INPUT); + assert(b_input.shape == b_input_grad.shape); - bwd.add_input_slot(A_INPUT); - bwd.add_input_slot(B_INPUT); - bwd.add_input_grad_slot(A_INPUT_GRAD); - bwd.add_input_grad_slot(B_INPUT_GRAD); - bwd.add_output_slot(OUTPUT); - bwd.add_output_grad_slot(OUTPUT_GRAD); + // check dins + int m = b_input.shape[legion_dim_t(0)]; + assert(m == output.shape[legion_dim_t(0)]); + int n = a_input.shape[legion_dim_t(1)]; + assert(n == output.shape[legion_dim_t(1)]); + int k = a_input.shape[legion_dim_t(0)]; + assert(k == b_input.shape[legion_dim_t(1)]); + assert(a_input.shape.get_volume() == b_input.shape.get_volume()); + assert(a_input.shape.get_volume() == output.shape.get_volume()); + int batch = 1; + for (int i = 2; i < a_input.shape.dims.num_dims(); i++) { + int dim_size = a_input.shape[legion_dim_t(i)]; + assert(dim_size == b_input.shape[legion_dim_t(i)]); + assert(dim_size == output.shape[legion_dim_t(i)]); + batch *= dim_size; + } - return bwd; + return profile(backward_kernel, + profiling, + "[BatchMatmul] backward_time = %.2lfms\n", + handle, + output.get_float_ptr(), + output_grad.get_float_ptr(), + a_input.get_float_ptr(), + a_input_grad.get_float_ptr(), + b_input.get_float_ptr(), + b_input_grad.get_float_ptr(), + m, + n, + k, + batch); +} + +static void backward_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { + TaskArgumentAccessor acc(task, regions, ctx, runtime); + backward_task_impl(acc); } -OpTaskBinding BatchMatmul::get_init_task_binding() const { - OpTaskBinding binding; - - binding.bind_arg(ATTRS, this->attrs); - binding.bind_arg(PROFILING, this->profiling); +CostMetrics measure_operator_cost(SimEnvFactory const &sim, + BatchMatmulAttrs const &attrs, + InputParallelTensorDesc const &a_input, + InputParallelTensorDesc const &b_input, + ProfilingSettings const &settings, + MachineView const &pc) { + auto env = sim.new_environment(); - return binding; -} - -OpTaskBinding BatchMatmul::get_fwd_task_binding() const { - OpTaskBinding binding; + ParallelTensorShape output_shape = + get_output_shape(attrs, a_input.shape, b_input.shape); - binding.bind(A_INPUT, input_tensor(0)); - binding.bind(B_INPUT, input_tensor(1)); - binding.bind(OUTPUT, output_tensor(0)); + SimTaskBinding fwd_binding; + fwd_binding.bind(A_INPUT, a_input); + fwd_binding.bind(B_INPUT, b_input); + fwd_binding.bind(OUTPUT, output_shape); + fwd_binding.bind_arg(ATTRS, attrs); + fwd_binding.bind_arg(PROFILING, settings); + fwd_binding.bind_arg(HANDLE, 
ff_handle()); - binding.bind_arg(ATTRS, this->attrs); - return binding; -} + SimTaskBinding bwd_binding = infer_bwd_binding(fwd_binding); -OpTaskBinding BatchMatmul::get_bwd_task_binding() const { - OpTaskBinding binding; - binding.bind(A_INPUT, input_tensor(0)); - binding.bind(B_INPUT, input_tensor(1)); - binding.bind_grad(A_INPUT_GRAD, input_tensor(0).grad()); - binding.bind_grad(B_INPUT_GRAD, input_tensor(1).grad()); + auto fwd_accessor = + env.get_fwd_accessor(BATCHMATMUL_FWD_TASK_ID, fwd_binding); + auto bwd_accessor = + env.get_bwd_accessor(BATCHMATMUL_BWD_TASK_ID, bwd_binding); - binding.bind(OUTPUT, output_tensor(0)); - binding.bind_grad(OUTPUT_GRAD, output_tensor(0).grad()); + float forward_time = forward_task_impl(fwd_accessor).value(); + float backward_time = backward_task_impl(bwd_accessor).value(); - binding.bind_arg(ATTRS, this->attrs); - return binding; + float sync_time = default_estimate_sync_time(env); + return make_metrics(forward_time, backward_time, sync_time, env); } -void BatchMatmul::init(FFModel const &ff) { - int dim = outputs[0]->num_dims; - switch (dim) { -#define DIMFUNC(DIM) \ - case DIM: { \ - // init_with_dim(ff); - this->execute_task(ff, BATCHMATMUL_INIT_TASK_ID, get_init_task_signature()); - break; - } - LEGION_FOREACH_N(DIMFUNC) -#undef DIMFUNC - default: - assert(false); -} -} // namespace FlexFlow -// / -// template -// void BatchMatmul::init_with_dim(FFModel const &ff) { -// assert(check_output_input_weight_same_parallel_is()); -// parallel_is = outputs[0]->parallel_is; -// ArgumentMap argmap; -// Context ctx = ff.config.lg_ctx; -// Runtime *runtime = ff.config.lg_hlr; -// set_argumentmap_for_init(ff, argmap); -// IndexLauncher launcher(BATCHMATMUL_INIT_TASK_ID, -// parallel_is, -// TaskArgument(this, sizeof(BatchMatmul)), -// argmap, -// Predicate::TRUE_PRED, -// false /*must*/, -// 0 /*mapper_id*/, -// outputs[0]->machine_view.hash()); -// launcher.add_region_requirement(RegionRequirement(outputs[0]->part, -// 0 /*projection id*/, -// WRITE_ONLY, -// EXCLUSIVE, -// outputs[0]->region)); -// launcher.add_field(0, FID_DATA); -// for (int i = 0; i < numInputs; i++) { -// launcher.add_region_requirement(RegionRequirement(inputs[i]->part, -// 0 /*projection id*/, -// READ_ONLY, -// EXCLUSIVE, -// inputs[i]->region)); -// launcher.add_field(i + 1, FID_DATA); -// } -// FutureMap fm = runtime->execute_index_space(ctx, launcher); -// fm.wait_all_results(); -// set_opmeta_from_futuremap(ff, fm); -// } - -PerDeviceOpState * - BatchMatmul::init_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - TaskArgumentAccessor acc(task, regions, ctx, runtime); - auto const &attrs = acc.get_argument(ATTRS); - bool profiling = acc.get_argument(PROFILING); - - FFHandler handle = *((FFHandler const *)task->local_args); - BatchMatmulPerDeviceState *m = new BatchMatmulPerDeviceState(handle); - m->profiling = profiling; - m->a_seq_length_dim = attrs.a_seq_length_dim; - m->b_seq_length_dim = attrs.b_seq_length_dim; - return m; -} - -void BatchMatmul::forward(FFModel const &ff) { - int dim = outputs[0]->num_dims; - switch (dim) { -#define DIMFUNC(DIM) \ - case DIM: { \ - // forward_with_dim(ff); - this->execute_task(ff, BATCHMATMUL_FWD_TASK_ID, get_fwd_task_signature()); - break; - } - LEGION_FOREACH_N(DIMFUNC) -#undef DIMFUNC - default: - assert(false); -} -} - -// template -// void BatchMatmul::forward_with_dim(FFModel const &ff) { -// ArgumentMap argmap; -// Context ctx = ff.config.lg_ctx; -// Runtime *runtime = ff.config.lg_hlr; -// 
set_argumentmap_for_forward(ff, argmap); -// IndexLauncher launcher( -// BATCHMATMUL_FWD_TASK_ID, -// parallel_is, -// TaskArgument(&ff.iter_config, sizeof(FFIterationConfig)), -// argmap, -// Predicate::TRUE_PRED, -// false /*must*/, -// 0 /*mapper_id*/, -// outputs[0]->machine_view.hash()); -// launcher.add_region_requirement(RegionRequirement(outputs[0]->part, -// 0 /*projection id*/, -// WRITE_ONLY, -// EXCLUSIVE, -// outputs[0]->region)); -// launcher.add_field(0, FID_DATA); -// for (int i = 0; i < numInputs; i++) { -// launcher.add_region_requirement(RegionRequirement(inputs[i]->part, -// 0 /*projection id*/, -// READ_ONLY, -// EXCLUSIVE, -// inputs[i]->region)); -// launcher.add_field(i + 1, FID_DATA); -// } -// runtime->execute_index_space(ctx, launcher); -// } - -/* - regions[0](O): output - regions[1](I): A - regions[2](I): B - ////////////////////(optional) regions[3](I): C -- TODO: is C deprecated? - output = A * B /////////+ C -*/ -void BatchMatmul::forward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - assert(regions.size() == 3); - assert(task->regions.size() == 3); +template <> +OpTaskSignature fwd_signature() { + OpTaskSignature fwd(OpTaskType::FWD); - TaskArgumentAccessor acc(task, regions, ctx, runtime); + fwd.add_input_slot(A_INPUT); + fwd.add_input_slot(B_INPUT); + fwd.add_output_slot(OUTPUT); + fwd.add_arg_slot(ATTRS); + fwd.add_arg_slot(PROFILING); + fwd.add_unchecked_arg_slot(HANDLE); - // const BatchMatmul* bmm = (const BatchMatmul*) task->args; - FFIterationConfig const *iter_config = (FFIterationConfig const *)task->args; - // BatchMatmulMeta const *meta = *((BatchMatmulMeta **)task->local_args); - BatchMatmulPerDeviceState const *meta = - *((BatchMatmulPerDeviceState **)task->local_args); - - auto a_input = acc.get_tensor(A_INPUT); - auto b_input = acc.get_tensor(B_INPUT); - auto output = acc.get_tensor(OUTPUT); - - int m = b_input.shape[0]; - assert(m == output.shape[0]); - int n = a_input.shape[1]; - assert(n == output.shape[1]); - int k = a_input.shape[0]; - assert(k == b_input.shape[1]); - - assert(a_input.shape.size() == b_input.shape.size()); - assert(a_input.shape.size() == output.shape.size()); - int batch = 1; - for (int i = 2; i < a_input.shape.size(); i++) { - int dim_size = a_input.shape[i]; - assert(dim_size == b_input.shape[i]); - assert(dim_size == output.shape[i]); - batch *= dim_size; - } - float *out_ptr = output.get_float_ptr(); - c float const *a_ptr = a_input.get_float_ptr(); - float const *b_ptr = b_input.get_float_ptr(); - float const *c_ptr = NULL; - // if (regions.size() == 4) { - // Domain c_domain = runtime->get_index_space_domain( - // ctx, task->regions[3].region.get_index_space()); - // assert(c_domain == a_domain); - // c_ptr = helperGetTensorPointerRO( - // regions[3], task->regions[3], FID_DATA, ctx, runtime); - // } - - profile(forward_kernel, - meta->profiling, - "[BatchMatmul] forward_time = %.2lfms\n", - out_ptr, - a_ptr, - b_ptr, - c_ptr, - m, - n, - k, - batch, - meta->a_seq_length_dim, - meta->b_seq_length_dim, - iter_config->seq_length); + return fwd; } -void BatchMatmul::backward(FFModel const &ff) { - int dim = outputs[0]->num_dims; - switch (dim) { -#define DIMFUNC(DIM) \ - case DIM: { \ - backward_with_dim(ff); \ - break; \ - } - LEGION_FOREACH_N(DIMFUNC) -#undef DIMFUNC - default: - assert(false); - } +template <> +void register_task() { + register_task(BATCHMATMUL_FWD_TASK_ID, + "BatchMatmul Fwd", + fwd_signature(), + forward_task); } -/* - regions[0](I): output - 
regions[1](I): output_grad - regions[2](I): A - regions[3](I/O): A_grad - regions[4](I): B - regions[5](I/O): B_grad - regions[6](I/O): C_grad -*/ -// template -// void BatchMatmul::backward_with_dim(FFModel const &ff) { -// ArgumentMap argmap; -// Context ctx = ff.config.lg_ctx; -// Runtime *runtime = ff.config.lg_hlr; -// set_argumentmap_for_backward(ff, argmap); -// IndexLauncher launcher( -// BATCHMATMUL_BWD_TASK_ID, -// parallel_is, -// TaskArgument(&ff.iter_config, sizeof(FFIterationConfig)), -// argmap, -// Predicate::TRUE_PRED, -// false /*must*/, -// 0 /*mapper_id*/, -// outputs[0]->machine_view.hash()); -// // regions[0](I): output -// launcher.add_region_requirement(RegionRequirement(outputs[0]->part, -// 0 /*projection id*/, -// READ_ONLY, -// EXCLUSIVE, -// outputs[0]->region)); -// launcher.add_field(0, FID_DATA); -// // regions[1](I): output_grad -// launcher.add_region_requirement(RegionRequirement(outputs[0]->part_grad, -// 0 /*projection id*/, -// READ_ONLY, -// EXCLUSIVE, -// outputs[0]->region_grad)); -// launcher.add_field(1, FID_DATA); -// // regions[2](I): A -// launcher.add_region_requirement(RegionRequirement(inputs[0]->part, -// 0 /*projection id*/, -// READ_ONLY, -// EXCLUSIVE, -// inputs[0]->region)); -// launcher.add_field(2, FID_DATA); -// // regions[3](I/O): A_grad -// launcher.add_region_requirement(RegionRequirement(inputs[0]->part_grad, -// 0 /*projection id*/, -// READ_WRITE, -// EXCLUSIVE, -// inputs[0]->region_grad)); -// launcher.add_field(3, FID_DATA); -// // regions[4](I): B -// launcher.add_region_requirement(RegionRequirement(inputs[1]->part, -// 0 /*projection id*/, -// READ_ONLY, -// EXCLUSIVE, -// inputs[1]->region)); -// launcher.add_field(4, FID_DATA); -// // regions[5](I/O): B_grad -// launcher.add_region_requirement(RegionRequirement(inputs[1]->part_grad, -// 0 /*projection id*/, -// READ_WRITE, -// EXCLUSIVE, -// inputs[1]->region_grad)); -// launcher.add_field(5, FID_DATA); -// runtime->execute_index_space(ctx, launcher); -// } - -/* - regions[0](I): output - regions[1](I): output_grad - regions[2](I): A - regions[3](I/O): A_grad - regions[4](I): B - regions[5](I/O): B_grad - regions[6](I/O): C_grad -*/ -__host__ void - BatchMatmul::backward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - // Currently assume C is NULL - assert(regions.size() == 6); - assert(task->regions.size() == 6); - // BatchMatmul* bmm = (BatchMatmul*) task->args; - TaskArgumentAccessor acc(task, regions, ctx, runtime); - FFIterationConfig const *iter_config = (FFIterationConfig const *)task->args; - BatchMatmulPerDeviceState const *meta = - *((BatchMatmulPerDeviceState **)task->local_args); - // output domains - auto output = acc.get_tensor(OUTPUT); - auto output_grad = acc.get_tensor(OUTPUT_GRAD); - assert(output == - output_grad); // is this equivalent to checking `Domain` equality? 
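Both the legacy task being deleted here and its replacement backward_task_impl recover the GEMM sizes from the innermost tensor dimensions. As a worked example of that convention (illustrative values, not from the PR):

```cpp
// A: (batch=4, n=8, k=16)  -> a.shape[0] == 16 (k), a.shape[1] == 8 (n)
// B: (batch=4, k=16, m=32) -> b.shape[0] == 32 (m), b.shape[1] == 16 (k)
// O: (batch=4, n=8, m=32)  -> o.shape[0] == 32 (m), o.shape[1] == 8 (n)
// Dimensions from index 2 upward must match across A, B, and O and are
// folded into the batch count passed to the strided-batched GEMMs.
int m = 32, n = 8, k = 16, batch = 4;
```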
- // A domains - auto a_input = acc.get_tensor(A_INPUT); - auto a_input_grad = acc.get_tensor(A_INPUT_GRAD); - assert(a_input == a_input_grad); - // B domains - auto b_input = acc.get_tensor(B_INPUT); - auto b_input_grad = acc.get_tensor(B_INPUT_GRAD); - assert(b_input == b_input_grad); +template <> +OpTaskSignature bwd_signature() { + OpTaskSignature bwd = + infer_bwd_signature(fwd_signature()); - // check dins - int m = b_input.shape[0]; - assert(m == output.shape[0]); - int n = a_input.shape[1]; - assert(n == output.shape[1]); - int k = a_input.shape[0]; - assert(k == b_input.shape[1]); - assert(a_input.shape.size() == b_input.shape.size()); - assert(a_input.shape.size() == output.shape.size()); - int batch = 1; - for (int i = 2; i < a_input.shape.size(); i++) { - int dim_size = a_input.shape[i]; - assert(dim_size == b_input.shape[i]); - assert(dim_size == output.shape[i]); - batch *= dim_size; - } - // get pointers - float const *out_ptr = output.get_float_ptr(); - float const *out_grad_ptr = output_grad.get_float_ptr(); - float const *a_ptr = a_input.get_float_ptr(); - float *a_grad_ptr = a_input_grad.get_float_ptr(); - float const *b_ptr = b_input.get_float_ptr(); - float *b_grad_ptr = b_input_grad.get_float_ptr(); - - float *c_grad_ptr = NULL; - - // TODO: add support for meta->a_seq_length_dim >= 0 - // or meta->b_seq_length_dim >= 0 - assert((meta->a_seq_length_dim >= a_len) || (iter_config->seq_length == 0)); - assert((meta->b_seq_length_dim >= b_len) || (iter_config->seq_length == 0)); - - profile(backward_kernel, - meta->profiling, - "[BatchMatmul] backward_time = %.2lfms\n", - meta, - out_ptr, - out_grad_ptr, - a_ptr, - a_grad_ptr, - b_ptr, - b_grad_ptr, - c_grad_ptr, - m, - n, - k, - batch); + return bwd; } -void BatchMatmul::print_layer(FFModel const &ff) { - return; +template <> +void register_task() { + register_task(BATCHMATMUL_BWD_TASK_ID, + "BatchMatmul Bwd", + bwd_signature(), + backward_task); } -bool BatchMatmul::measure_operator_cost(Simulator *sim, - MachineView const &pc, - CostMetrics &cost_metrics) const { - ParallelTensorBase sub_output, sub_input0, sub_input1; - if (!outputs[0]->get_sub_tensor(pc, sub_output)) { - return false; - } - if (!inputs[0]->get_sub_tensor(pc, sub_input0)) { - return false; - } - if (!inputs[1]->get_sub_tensor(pc, sub_input1)) { - return false; - } - - int input0_c = sub_input0.dims[0].size; - int input0_r = sub_input0.dims[1].size; - int input1_c = sub_input1.dims[0].size; - int input1_r = sub_input1.dims[1].size; - int output_c = sub_output.dims[0].size; - int output_r = sub_output.dims[1].size; - - assert(input0_c == input1_r); - assert(input0_r == output_r); - assert(input1_c == output_c); - - assert(sub_input0.dims[2] == sub_input1.dims[2]); - assert(sub_input1.dims[2] == sub_output.dims[2]); - int batch = 1; - assert(sub_input0.num_dims == sub_input1.num_dims); - for (int i = 2; i < sub_input0.num_dims; i++) { - assert(sub_input0.dims[i] == sub_input1.dims[i]); - assert(sub_input0.dims[i] == sub_output.dims[i]); - batch *= sub_input0.dims[i].size; - } - - BatchMatmulPerDeviceState *meta = sim->batch_matmul_meta; - - // allocate tensors in simulator - sim->free_all(); - float *a_ptr = (float *)sim->allocate(sub_input0.get_volume(), DT_FLOAT); - assert(a_ptr != NULL); - float *b_ptr = (float *)sim->allocate(sub_input1.get_volume(), DT_FLOAT); - assert(b_ptr != NULL); - float *c_ptr = NULL; - cost_metrics.inputs_memory += cost_metrics.total_mem_diff_from(sim->offset); - - float *out_ptr = (float 
*)sim->allocate(sub_output.get_volume(), DT_FLOAT); - assert(out_ptr != NULL); - cost_metrics.outputs_memory += cost_metrics.total_mem_diff_from(sim->offset); - - int m = input1_c; - int n = input0_r; - int k = input0_c; - - assert(meta->profiling == false); - - std::function forward, backward; - forward = [&](ffStream_t stream) { - forward_kernel(stream, meta, out_ptr, a_ptr, b_ptr, c_ptr, m, n, k, batch); - }; - - if (sim->computationMode == COMP_MODE_TRAINING) { - float *a_grad_ptr = - (float *)sim->allocate(sub_input0.get_volume(), DT_FLOAT); - float *b_grad_ptr = - (float *)sim->allocate(sub_input1.get_volume(), DT_FLOAT); - float *c_grad_ptr = NULL; - cost_metrics.inputs_memory += cost_metrics.total_mem_diff_from(sim->offset); - - float *out_grad_ptr = - (float *)sim->allocate(sub_output.get_volume(), DT_FLOAT); - assert(out_grad_ptr != NULL); - cost_metrics.outputs_memory += - cost_metrics.total_mem_diff_from(sim->offset); - - backward = [&](ffStream_t stream) { - backward_kernel(stream, - meta, - out_ptr, - out_grad_ptr, - a_ptr, - a_grad_ptr, - b_ptr, - b_grad_ptr, - c_grad_ptr, - m, - n, - k, - batch); - }; - } - - inner_measure_operator_cost(sim, forward, backward, cost_metrics); - - if (sim->computationMode == COMP_MODE_TRAINING) { - printf("[Measure BatchMatmul] name(%s) adim(%d %d %d) bdim(%d %d %d) " - "odim(%d %d %d) forward_time(%.4lf) backward_time(%.4lf)\n", - name, - batch, - input0_r, - input0_c, - batch, - input1_r, - input1_c, - batch, - output_r, - output_c, - cost_metrics.forward_time, - cost_metrics.backward_time); - } else { - printf("[Measure BatchMatmul] name(%s) adim(%d %d %d) bdim(%d %d %d) " - "odim(%d %d %d) forward_time(%.4lf)\n", - name, - batch, - input0_r, - input0_c, - batch, - input1_r, - input1_c, - batch, - output_r, - output_c, - cost_metrics.forward_time); - } - - return true; -} -} -; // namespace FlexFlow +}; // namespace FlexFlow diff --git a/lib/runtime/src/ops/batch_matmul.h b/lib/runtime/src/ops/batch_matmul.h index c133c2a875..7d3f2308da 100644 --- a/lib/runtime/src/ops/batch_matmul.h +++ b/lib/runtime/src/ops/batch_matmul.h @@ -2,84 +2,26 @@ #define _FLEXFLOW_BATCH_MATMUL_H #include "op-attrs/ops/batch_matmul.h" -#include "op_task_invocation.h" -#include "op_task_signature.h" #include "sim_environment.h" +#include "task_spec/op_task_invocation.h" +#include "task_spec/op_task_signature.h" namespace FlexFlow { -template <> -void register_task(); template <> void register_task(); template <> void register_task(); -OpTaskInvocation init(BatchMatmulAttrs const &); OpTaskInvocation forward(BatchMatmulAttrs const &); OpTaskInvocation backward(BatchMatmulAttrs const &); -CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, +CostMetrics measure_operator_cost(SimEnvFactory const &sim, BatchMatmulAttrs const &attrs, - ParallelTensorShape const &lhs_input_shape, - ParallelTensorShape const &rhs_input_shape, + InputParallelTensorDesc const &a_input, + InputParallelTensorDesc const &b_input, ProfilingSettings const &settings, - MachineView const &); - -/* class BatchMatmul : public Op { */ -/* public: */ -/* BatchMatmul(FFModel &model, */ -/* const ParallelTensor A, */ -/* const ParallelTensor B, */ -/* int a_seq_length_dim, */ -/* int b_seq_length_dim, */ -/* char const *name = nullptr); */ -/* static Op * */ -/* create_operator_from_layer(FFModel &model, */ -/* Layer const *layer, */ -/* std::vector const &inputs); - */ - -/* void init(FFModel const &) override; */ -/* void forward(FFModel const &) override; */ -/* void backward(FFModel 
const &) override; */ -/* /1* static PCG::Node deserialize(FFModel &ff, *1/ */ -/* /1* Legion::Deserializer &d, *1/ */ -/* /1* ParallelTensor inputs[], *1/ */ -/* /1* int num_inputs); *1/ */ -/* static PerDeviceOpState *init_task(Legion::Task const *task, */ -/* std::vector const - * &regions, */ -/* Legion::Context ctx, */ -/* Legion::Runtime *runtime); */ -/* static void forward_task(Legion::Task const *task, */ -/* std::vector const - * &regions, */ -/* Legion::Context ctx, */ -/* Legion::Runtime *runtime); */ -/* static void backward_task(Legion::Task const *task, */ -/* std::vector const - * &regions, */ -/* Legion::Context ctx, */ -/* Legion::Runtime *runtime); */ -/* bool measure_operator_cost(Simulator *sim, */ -/* MachineView const &pc, */ -/* CostMetrics &cost_metrics) const override; */ - -/* OpTaskBinding get_init_task_binding() const override; */ -/* OpTaskBinding get_fwd_task_binding() const override; */ -/* OpTaskBinding get_bwd_task_binding() const override; */ -/* private: */ -/* template */ -/* void init_with_dim(FFModel const &ff); */ -/* template */ -/* void forward_with_dim(FFModel const &ff); */ -/* template */ -/* void backward_with_dim(FFModel const &ff); */ - -/* public: */ -/* int a_seq_length_dim, b_seq_length_dim; */ -/* }; */ + MachineView const &pc); } // namespace FlexFlow diff --git a/lib/runtime/src/task_spec/runtime_arg_ref.h b/lib/runtime/src/task_spec/runtime_arg_ref.h index 033c2bcfbc..655300e692 100644 --- a/lib/runtime/src/task_spec/runtime_arg_ref.h +++ b/lib/runtime/src/task_spec/runtime_arg_ref.h @@ -7,7 +7,11 @@ namespace FlexFlow { -enum class RuntimeArgRefType { FF_HANDLE, PROFILING_SETTINGS }; +enum class RuntimeArgRefType { + FF_HANDLE, + PROFILING_SETTINGS, + FF_ITERATION_CONFIG +}; template <typename T> using RuntimeArgRef = ArgRef<RuntimeArgRefType, T>;
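The new FF_ITERATION_CONFIG enumerator pairs with the iteration_config() binding used in batch_matmul.cc. Its declaration is outside this diff; following the pattern of ff_handle() and profiling_settings(), it presumably looks something like the sketch below (an assumption, not code from the PR):

```cpp
// Hypothetical accessor implied by the iteration_config() calls above;
// the actual declaration lives outside this diff and may differ.
RuntimeArgRef<FFIterationConfig> iteration_config();
```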