diff --git a/deps/fmt b/deps/fmt index f5e54359df..a33701196a 160000 --- a/deps/fmt +++ b/deps/fmt @@ -1 +1 @@ -Subproject commit f5e54359df4c26b6230fc61d38aa294581393084 +Subproject commit a33701196adfad74917046096bf5a2aa0ab0bb50 diff --git a/lib/kernels/src/cuda/aggregate_spec_kernels.cu b/lib/kernels/src/cuda/aggregate_spec_kernels.cu deleted file mode 100644 index 8a39b7f558..0000000000 --- a/lib/kernels/src/cuda/aggregate_spec_kernels.cu +++ /dev/null @@ -1,302 +0,0 @@ -/* Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical) - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "kernels/aggregate_spec_kernels.h" -#include "kernels/cuda_helper.h" - -namespace FlexFlow { - -AggregateSpecPerDeviceState::AggregateSpecPerDeviceState(FFHandler handler, - int n) - : PerDeviceOpState(handler) { - checkCUDA(cudaMalloc(&dev_region_ptrs, n * sizeof(float *))); -} -AggregateSpecPerDeviceState::~AggregateSpecPerDeviceState(void) { - checkCUDA(cudaFree(&dev_region_ptrs)); -} - -namespace Kernels { -namespace AggregateSpec { - -void forward_kernel(cudaStream_t stream, - AggregateSpecPerDeviceState const *m, - float **exp_preds, - int const *acc_gate_assign_ptr, - float *acc_output_ptr, - int n, - int const k, - int rows, - int const batch_size, - int out_dim) { - - checkCUDA(cublasSetStream(m->handle.blas, stream)); - checkCUDNN(cudnnSetStream(m->handle.dnn, stream)); - - // call forward kernel - cudaMemcpy(m->dev_region_ptrs, - exp_preds, - n * sizeof(float *), - cudaMemcpyHostToDevice); - - aggspec_forward_kernel<<>>(m->dev_region_ptrs, - acc_gate_assign_ptr, - acc_output_ptr, - n, - k, - rows, - batch_size, - out_dim); -} - -void backward_kernel(cudaStream_t stream, - AggregateSpecPerDeviceState const *m, - float **exp_grads, - int const *acc_gate_assign_ptr, - int const *acc_true_gate_assign_ptr, - float const *acc_gate_pred_ptr, - float *acc_full_gate_grad_ptr, - float const *acc_output_grad_ptr, - int n, - int const k, - int rows, - float lambda_bal, - int const batch_size, - int out_dim) { - - checkCUDA(cublasSetStream(m->handle.blas, stream)); - checkCUDNN(cudnnSetStream(m->handle.dnn, stream)); - - // call backward kernel - cudaMemcpy(m->dev_region_ptrs, - exp_grads, - n * sizeof(float *), - cudaMemcpyHostToDevice); - - aggspec_backward_kernel<<>>(m->dev_region_ptrs, - acc_gate_assign_ptr, - acc_true_gate_assign_ptr, - acc_gate_pred_ptr, - acc_full_gate_grad_ptr, - acc_output_grad_ptr, - n, - k, - rows, - lambda_bal, - batch_size, - out_dim); -} - -__global__ void - aggspec_forward_kernel(float **exp_preds, - int const 
*exp_assign, - float *output, - int n, // num experts - int const k, // num chosen experts - int exp_samples, // max samples per expert - int const batch_size, - int out_dim) { - __shared__ float - *chosen_exp_preds[AGGREGATE_SPEC_MAX_K * AGGREGATE_SPEC_MAX_BATCH_SIZE]; - - // Get pred pointers, single thread per block - if (threadIdx.x == 0) { - int expert_idx[AGGREGATE_SPEC_MAX_N] = {0}; - for (int i = 0; i < batch_size; i++) { - for (int j = 0; j < k; j++) { - // Get pointer to chosen expert predictions - int expert = exp_assign[i * k + j]; - if (expert_idx[expert] >= exp_samples) { - // dropped sample - chosen_exp_preds[i * k + j] = 0; - continue; - } - chosen_exp_preds[i * k + j] = - exp_preds[expert] + expert_idx[expert] * out_dim; - expert_idx[expert]++; - } - } - } - - __syncthreads(); - - // compute output - CUDA_KERNEL_LOOP(i, k * batch_size * out_dim) { - if (chosen_exp_preds[i / out_dim] != 0) { - output[i] = chosen_exp_preds[i / out_dim][i % out_dim]; - } else { - output[i] = 0.0f; - } - } -} - -__device__ void aggspec_backward_kernel_gate(float const *output_grad, - float *full_gate_grads, - int const *expert_assign, - bool const *cache_corr, - float const *gate_pred, - int *expert_bal, - float lambda_bal, - int batch_size, - int k, - int n, - int out_dim) { - - __shared__ float gate_grad_sum[AGGREGATE_SPEC_MAX_BATCH_SIZE]; - - // init gate_grad_sum to 0 - CUDA_KERNEL_LOOP(i, batch_size) { - gate_grad_sum[i] = 0.0f; - } - - __syncthreads(); - - // get sum of expert errors - /* NOTE: Errors just squared L2 norm of gradients. 
* batch_size because the - expert gradients are /= batch_size and then it would be /= batch_size^2 here -*/ - CUDA_KERNEL_LOOP(i, batch_size * k * out_dim) { - if (cache_corr[i / (k * out_dim)]) { - float res = output_grad[i] * output_grad[i] * batch_size; - float *gate_grad_idx = - full_gate_grads + (i / (out_dim * k)) * n + - expert_assign[(i / (out_dim * k)) * k + (i / out_dim) % k]; - atomicAdd(gate_grad_idx, res); - atomicAdd(gate_grad_sum + i / (k * out_dim), res); - } - } - - // Compute gate gradients: - // Assigned expert i, sample j: pred(i,j) - err_(i,j)/sum_l err(l,j) - __syncthreads(); - CUDA_KERNEL_LOOP(i, k * batch_size) { - if (cache_corr[i / k]) { - full_gate_grads[i / k * n + expert_assign[i]] /= gate_grad_sum[i / k]; - full_gate_grads[i / k * n + expert_assign[i]] -= (1.0f - gate_pred[i]); - } - } - - // balance term - __syncthreads(); - CUDA_KERNEL_LOOP(i, n * batch_size) { - full_gate_grads[i] += lambda_bal * expert_bal[i % n]; - } - - __syncthreads(); - - // make 0 mean - CUDA_KERNEL_LOOP(i, n * batch_size) { - int start = (i / n) * n; - float sub = -full_gate_grads[i] / n; - for (int j = 0; j < n; j++) { - atomicAdd(full_gate_grads + start + j, sub); - } - } -} - -__device__ void aggspec_backward_kernel_exp(float const *output_grad, - float const *gate_preds, - float **exp_grads, - int batch_size, - int k, - int out_dim) { - // compute expert gradients - CUDA_KERNEL_LOOP(i, k * out_dim * batch_size) { - if (exp_grads[i / out_dim] != 0) { - exp_grads[i / out_dim][i % out_dim] += - gate_preds[i / out_dim] * output_grad[i]; - } - } -} - -__global__ void - aggspec_backward_kernel(float **exp_grads, - int const *exp_assign, - int const *true_exp_assign, - float const *gating_net_preds, - float *full_gating_grads, - float const *output_grads, - int n, // num experts - int k, // num chosen experts - int exp_samples, // max samples per expert - float lambda_bal, - int batch_size, - int out_dim) { - __shared__ float - 
*chosen_exp_grads[AGGREGATE_SPEC_MAX_K * AGGREGATE_SPEC_MAX_BATCH_SIZE]; - __shared__ int expert_bal[AGGREGATE_SPEC_MAX_N]; - __shared__ bool cache_corr[AGGREGATE_SPEC_MAX_BATCH_SIZE]; - - // Get pred pointers, single thread per block - if (threadIdx.x == 0) { - // init arrays - for (int i = 0; i < n; i++) { - expert_bal[i] = 0; - } - for (int i = 0; i < batch_size; i++) { - cache_corr[i] = true; - } - - // Get pointer to chosen expert grads and expert counts - for (int i = 0; i < batch_size; i++) { - for (int j = 0; j < k; j++) { - int expert = true_exp_assign[k * i + j]; - if (expert != exp_assign[k * i + j]) { - cache_corr[i] = false; - } - if (expert_bal[expert] >= exp_samples) { - // dropped sample - chosen_exp_grads[i * k + j] = 0; - expert_bal[expert]++; - continue; - } - chosen_exp_grads[i * k + j] = - exp_grads[expert] + expert_bal[expert] * out_dim; - expert_bal[expert]++; - } - } - } - - __syncthreads(); - - // NOTE: These 2 functions could execute independently in parallel - // get expert gradients - aggspec_backward_kernel_exp( - output_grads, gating_net_preds, chosen_exp_grads, batch_size, k, out_dim); - - // get gating net gradients - aggspec_backward_kernel_gate(output_grads, - full_gating_grads, - exp_assign, - cache_corr, - gating_net_preds, - expert_bal, - (lambda_bal * n) / batch_size, - batch_size, - k, - n, - out_dim); -} - -} // namespace AggregateSpec -} // namespace Kernels -} // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/attention.h b/lib/op-attrs/include/op-attrs/ops/attention.h index ec3e592607..a80c579689 100644 --- a/lib/op-attrs/include/op-attrs/ops/attention.h +++ b/lib/op-attrs/include/op-attrs/ops/attention.h @@ -12,6 +12,7 @@ struct MultiHeadAttentionAttrs { req dropout; req bias, add_bias_kv, add_zero_attn; }; + FF_VISITABLE_STRUCT(MultiHeadAttentionAttrs, embed_dim, num_heads, diff --git a/lib/op-attrs/include/op-attrs/ops/batch_matmul.h b/lib/op-attrs/include/op-attrs/ops/batch_matmul.h index 
c74824570c..c9d81c98e4 100644 --- a/lib/op-attrs/include/op-attrs/ops/batch_matmul.h +++ b/lib/op-attrs/include/op-attrs/ops/batch_matmul.h @@ -14,6 +14,9 @@ FF_VISITABLE_STRUCT(BatchMatmulAttrs, a_seq_length_dim, b_seq_length_dim); CHECK_VALID_OP_ATTR(BatchMatmulAttrs); +ParallelTensorShape get_output_shape(BatchMatmulAttrs const &, + ParallelTensorShape const &, + ParallelTensorShape const &); } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/batch_norm.h b/lib/op-attrs/include/op-attrs/ops/batch_norm.h index 4ec823d4ae..29b76d96e9 100644 --- a/lib/op-attrs/include/op-attrs/ops/batch_norm.h +++ b/lib/op-attrs/include/op-attrs/ops/batch_norm.h @@ -12,7 +12,8 @@ struct BatchNormAttrs { }; FF_VISITABLE_STRUCT(BatchNormAttrs, relu); -ParallelTensorShape get_output_shape(BatchNormAttrs const &); +ParallelTensorShape get_output_shape(BatchNormAttrs const &, + ParallelTensorShape const &); CHECK_VALID_OP_ATTR(BatchNormAttrs); diff --git a/lib/op-attrs/include/op-attrs/ops/cast.h b/lib/op-attrs/include/op-attrs/ops/cast.h index 63563f8df8..403fcc21a6 100644 --- a/lib/op-attrs/include/op-attrs/ops/cast.h +++ b/lib/op-attrs/include/op-attrs/ops/cast.h @@ -13,6 +13,9 @@ struct CastAttrs { }; FF_VISITABLE_STRUCT(CastAttrs, dtype); +ParallelTensorShape get_output_shape(CastAttrs const &, + ParallelTensorShape const &); + CHECK_VALID_OP_ATTR(CastAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/combine.h b/lib/op-attrs/include/op-attrs/ops/combine.h index deaba9e093..49bea57a38 100644 --- a/lib/op-attrs/include/op-attrs/ops/combine.h +++ b/lib/op-attrs/include/op-attrs/ops/combine.h @@ -15,6 +15,9 @@ struct CombineAttrs { FF_VISITABLE_STRUCT(CombineAttrs, combine_dim, combine_degree); CHECK_VALID_OP_ATTR(CombineAttrs); +ParallelTensorShape get_output_shape(CombineAttrs const &, + ParallelTensorShape const &); + } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/concat.h 
b/lib/op-attrs/include/op-attrs/ops/concat.h index 78f848f18b..09b9f2f2ca 100644 --- a/lib/op-attrs/include/op-attrs/ops/concat.h +++ b/lib/op-attrs/include/op-attrs/ops/concat.h @@ -15,6 +15,8 @@ struct ConcatAttrs { FF_VISITABLE_STRUCT(ConcatAttrs, axis, num_inputs); CHECK_VALID_OP_ATTR(ConcatAttrs); +ParallelTensorShape get_output_shape(ConcatAttrs const &, + std::vector const &); } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/conv_2d.h b/lib/op-attrs/include/op-attrs/ops/conv_2d.h index 3034dc8c62..79233eb8fc 100644 --- a/lib/op-attrs/include/op-attrs/ops/conv_2d.h +++ b/lib/op-attrs/include/op-attrs/ops/conv_2d.h @@ -32,6 +32,9 @@ CHECK_VALID_OP_ATTR(Conv2DAttrs); TensorShape get_kernel_shape(Conv2DAttrs const &, TensorShape const &); TensorShape get_bias_shape(Conv2DAttrs const &, TensorShape const &); +ParallelTensorShape get_output_shape(Conv2DAttrs const &, + ParallelTensorShape const &); + } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/dropout.h b/lib/op-attrs/include/op-attrs/ops/dropout.h index 8e0049f526..edf6db9ea8 100644 --- a/lib/op-attrs/include/op-attrs/ops/dropout.h +++ b/lib/op-attrs/include/op-attrs/ops/dropout.h @@ -14,6 +14,9 @@ struct DropoutAttrs { FF_VISITABLE_STRUCT(DropoutAttrs, rate, seed); CHECK_VALID_OP_ATTR(DropoutAttrs); +ParallelTensorShape get_output_shape(DropoutAttrs const &, + ParallelTensorShape const &); + } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/element_binary.h b/lib/op-attrs/include/op-attrs/ops/element_binary.h index c4a096166d..377a03970a 100644 --- a/lib/op-attrs/include/op-attrs/ops/element_binary.h +++ b/lib/op-attrs/include/op-attrs/ops/element_binary.h @@ -2,7 +2,6 @@ #define _FLEXFLOW_ELEMENT_BINARY_ATTRS_H #include "core.h" -#include "op-attrs/datatype.h" #include "op-attrs/op.h" #include "op-attrs/parallel_tensor_shape.h" #include "utils/visitable.h" @@ -22,6 +21,10 @@ FF_VISITABLE_STRUCT(ElementBinaryAttrs, 
should_broadcast_rhs); CHECK_VALID_OP_ATTR(ElementBinaryAttrs); +ParallelTensorShape get_output_shape(ElementBinaryAttrs const &, + ParallelTensorShape const &, + ParallelTensorShape const &); + } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/element_unary.h b/lib/op-attrs/include/op-attrs/ops/element_unary.h index 1b72e83cb5..d0dbc3661c 100644 --- a/lib/op-attrs/include/op-attrs/ops/element_unary.h +++ b/lib/op-attrs/include/op-attrs/ops/element_unary.h @@ -22,6 +22,9 @@ struct ElementUnaryAttrs { FF_VISITABLE_STRUCT(ElementUnaryAttrs, op); CHECK_VALID_OP_ATTR(ElementUnaryAttrs); +ParallelTensorShape get_output_shape(ElementUnaryAttrs const &, + ParallelTensorShape const &); + } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/embedding.h b/lib/op-attrs/include/op-attrs/ops/embedding.h index 8b00fa22ce..52d22fe836 100644 --- a/lib/op-attrs/include/op-attrs/ops/embedding.h +++ b/lib/op-attrs/include/op-attrs/ops/embedding.h @@ -23,6 +23,9 @@ struct EmbeddingAttrs { FF_VISITABLE_STRUCT(EmbeddingAttrs, num_entries, out_channels, aggr, data_type); CHECK_VALID_OP_ATTR(EmbeddingAttrs); +ParallelTensorShape get_output_shape(EmbeddingAttrs const &, + ParallelTensorShape const &); + TensorShape get_weights_shape(EmbeddingAttrs const &, TensorShape const &); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/flat.h b/lib/op-attrs/include/op-attrs/ops/flat.h index 706689199d..88b0a6cb54 100644 --- a/lib/op-attrs/include/op-attrs/ops/flat.h +++ b/lib/op-attrs/include/op-attrs/ops/flat.h @@ -11,6 +11,9 @@ struct FlatAttrs {}; FF_VISITABLE_STRUCT(FlatAttrs); CHECK_VALID_OP_ATTR(FlatAttrs); +ParallelTensorShape get_output_shape(FlatAttrs const &attrs, + ParallelTensorShape const &input); + } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/gather.h b/lib/op-attrs/include/op-attrs/ops/gather.h index ca2406ef75..ad97e52556 100644 --- a/lib/op-attrs/include/op-attrs/ops/gather.h 
+++ b/lib/op-attrs/include/op-attrs/ops/gather.h @@ -14,6 +14,9 @@ struct GatherAttrs { FF_VISITABLE_STRUCT(GatherAttrs, dim); CHECK_VALID_OP_ATTR(GatherAttrs); +std::vector get_output_shape(GatherAttrs const &, + ParallelTensorShape const &, + ParallelTensorShape const &); } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/groupby.h b/lib/op-attrs/include/op-attrs/ops/groupby.h index 174c40242e..d2c1033b31 100644 --- a/lib/op-attrs/include/op-attrs/ops/groupby.h +++ b/lib/op-attrs/include/op-attrs/ops/groupby.h @@ -14,6 +14,10 @@ struct Group_byAttrs { FF_VISITABLE_STRUCT(Group_byAttrs, n, alpha); CHECK_VALID_OP_ATTR(Group_byAttrs); +ParallelTensorShape get_output_shape(Group_byAttrs const &, + ParallelTensorShape const &, + ParallelTensorShape const &); + } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/layer_norm.h b/lib/op-attrs/include/op-attrs/ops/layer_norm.h index dab055b2c9..f279b0650c 100644 --- a/lib/op-attrs/include/op-attrs/ops/layer_norm.h +++ b/lib/op-attrs/include/op-attrs/ops/layer_norm.h @@ -16,6 +16,9 @@ struct LayerNormAttrs { FF_VISITABLE_STRUCT(LayerNormAttrs, axes, elementwise_affine, eps); CHECK_VALID_OP_ATTR(LayerNormAttrs); +ParallelTensorShape get_output_shape(LayerNormAttrs const &, + ParallelTensorShape const &); + } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/linear.h b/lib/op-attrs/include/op-attrs/ops/linear.h index 3be8be2040..e696bb9fd0 100644 --- a/lib/op-attrs/include/op-attrs/ops/linear.h +++ b/lib/op-attrs/include/op-attrs/ops/linear.h @@ -34,6 +34,9 @@ FF_VISITABLE_STRUCT( LinearAttrs, out_channels, use_bias, data_type, activation, regularizer); CHECK_VALID_OP_ATTR(LinearAttrs); +ParallelTensorShape get_output_shape(LinearAttrs const &, + ParallelTensorShape const &); + } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/pool_2d.h b/lib/op-attrs/include/op-attrs/ops/pool_2d.h index efe29b3b2e..3bc862c481 100644 --- 
a/lib/op-attrs/include/op-attrs/ops/pool_2d.h +++ b/lib/op-attrs/include/op-attrs/ops/pool_2d.h @@ -29,6 +29,9 @@ FF_VISITABLE_STRUCT(Pool2DAttrs, activation); CHECK_VALID_OP_ATTR(Pool2DAttrs); +ParallelTensorShape get_output_shape(Pool2DAttrs const &, + ParallelTensorShape const &); + } // namespace FlexFlow namespace fmt { diff --git a/lib/op-attrs/include/op-attrs/ops/reduce.h b/lib/op-attrs/include/op-attrs/ops/reduce.h index 193d3b0dc8..96827a83cc 100644 --- a/lib/op-attrs/include/op-attrs/ops/reduce.h +++ b/lib/op-attrs/include/op-attrs/ops/reduce.h @@ -18,6 +18,9 @@ struct ReduceAttrs { FF_VISITABLE_STRUCT(ReduceAttrs, axes, op_type, keepdims); CHECK_VALID_OP_ATTR(ReduceAttrs); +ParallelTensorShape get_output_shape(ReduceAttrs const &, + ParallelTensorShape const &); + } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/reduction.h b/lib/op-attrs/include/op-attrs/ops/reduction.h index f848f879fc..70f268c97d 100644 --- a/lib/op-attrs/include/op-attrs/ops/reduction.h +++ b/lib/op-attrs/include/op-attrs/ops/reduction.h @@ -15,6 +15,9 @@ struct ReductionAttrs { FF_VISITABLE_STRUCT(ReductionAttrs, reduction_dim, reduction_degree); CHECK_VALID_OP_ATTR(ReductionAttrs); +ParallelTensorShape get_output_shape(ReductionAttrs const &, + ParallelTensorShape const &); + } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/repartition.h b/lib/op-attrs/include/op-attrs/ops/repartition.h index 83c4ae870b..8abdc6eb1c 100644 --- a/lib/op-attrs/include/op-attrs/ops/repartition.h +++ b/lib/op-attrs/include/op-attrs/ops/repartition.h @@ -15,6 +15,9 @@ struct RepartitionAttrs { FF_VISITABLE_STRUCT(RepartitionAttrs, repartition_dim, repartition_degree); CHECK_VALID_OP_ATTR(RepartitionAttrs); +ParallelTensorShape get_output_shape(RepartitionAttrs const &, + ParallelTensorShape const &); + } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/replicate.h b/lib/op-attrs/include/op-attrs/ops/replicate.h index 
92e64a4120..2bbcad9d95 100644 --- a/lib/op-attrs/include/op-attrs/ops/replicate.h +++ b/lib/op-attrs/include/op-attrs/ops/replicate.h @@ -15,6 +15,9 @@ struct ReplicateAttrs { FF_VISITABLE_STRUCT(ReplicateAttrs, replicate_dim, replicate_degree); CHECK_VALID_OP_ATTR(ReplicateAttrs); +ParallelTensorShape get_output_shape(ReplicateAttrs const &, + ParallelTensorShape const &); + } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/reshape.h b/lib/op-attrs/include/op-attrs/ops/reshape.h index b118482a2b..78b9806fe7 100644 --- a/lib/op-attrs/include/op-attrs/ops/reshape.h +++ b/lib/op-attrs/include/op-attrs/ops/reshape.h @@ -2,6 +2,7 @@ #define _FLEXFLOW_RESHAPE_ATTRS_H #include "core.h" +#include "op-attrs/parallel_tensor_shape.h" #include "op-attrs/tensor_shape.h" #include "utils/visitable.h" @@ -13,6 +14,9 @@ struct ReshapeAttrs { FF_VISITABLE_STRUCT(ReshapeAttrs, shape); CHECK_VALID_OP_ATTR(ReshapeAttrs); +ParallelTensorShape get_output_shape(ReshapeAttrs const &, + ParallelTensorShape const &); + } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/reverse.h b/lib/op-attrs/include/op-attrs/ops/reverse.h index 6030285f14..ce1295f437 100644 --- a/lib/op-attrs/include/op-attrs/ops/reverse.h +++ b/lib/op-attrs/include/op-attrs/ops/reverse.h @@ -3,6 +3,7 @@ #include "core.h" #include "op-attrs/ff_dim.h" +#include "op-attrs/parallel_tensor_shape.h" #include "utils/visitable.h" namespace FlexFlow { @@ -13,6 +14,9 @@ struct ReverseAttrs { FF_VISITABLE_STRUCT(ReverseAttrs, axis); CHECK_VALID_OP_ATTR(ReverseAttrs); +ParallelTensorShape get_output_shape(ReverseAttrs const &, + ParallelTensorShape const &); + } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/softmax.h b/lib/op-attrs/include/op-attrs/ops/softmax.h index 9a776737f5..8f31bccdef 100644 --- a/lib/op-attrs/include/op-attrs/ops/softmax.h +++ b/lib/op-attrs/include/op-attrs/ops/softmax.h @@ -14,6 +14,9 @@ struct SoftmaxAttrs { 
FF_VISITABLE_STRUCT(SoftmaxAttrs, dim); CHECK_VALID_OP_ATTR(SoftmaxAttrs); +ParallelTensorShape get_output_shape(SoftmaxAttrs const &, + ParallelTensorShape const &); + } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/split.h b/lib/op-attrs/include/op-attrs/ops/split.h index fa66bc46f5..f2f904a9f7 100644 --- a/lib/op-attrs/include/op-attrs/ops/split.h +++ b/lib/op-attrs/include/op-attrs/ops/split.h @@ -7,13 +7,14 @@ #include "utils/visitable.h" namespace FlexFlow { - struct SplitAttrs { req> splits; ff_dim_t axis; }; FF_VISITABLE_STRUCT(SplitAttrs, splits, axis); CHECK_VALID_OP_ATTR(SplitAttrs); +std::vector get_output_shape(SplitAttrs const &, + ParallelTensorShape const &); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/topk.h b/lib/op-attrs/include/op-attrs/ops/topk.h index 413855913c..3a3b49ab3b 100644 --- a/lib/op-attrs/include/op-attrs/ops/topk.h +++ b/lib/op-attrs/include/op-attrs/ops/topk.h @@ -7,13 +7,18 @@ namespace FlexFlow { +// pytorch code: torch.topk(input_tensor, k, largest=True, sorted=True, dim=dim) struct TopKAttrs { req k; req sorted; + req axis; }; -FF_VISITABLE_STRUCT(TopKAttrs, k, sorted); +FF_VISITABLE_STRUCT(TopKAttrs, k, sorted, axis); CHECK_VALID_OP_ATTR(TopKAttrs); +ParallelTensorShape get_output_shape(TopKAttrs const &, + ParallelTensorShape const &); + } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/transpose.h b/lib/op-attrs/include/op-attrs/ops/transpose.h index 87db435979..461aa0aacb 100644 --- a/lib/op-attrs/include/op-attrs/ops/transpose.h +++ b/lib/op-attrs/include/op-attrs/ops/transpose.h @@ -10,10 +10,14 @@ namespace FlexFlow { struct TransposeAttrs { req> perm; + bool is_valid(ParallelTensorShape const &) const; }; FF_VISITABLE_STRUCT(TransposeAttrs, perm); CHECK_VALID_OP_ATTR(TransposeAttrs); +ParallelTensorShape get_output_shape(TransposeAttrs const &, + ParallelTensorShape const &); + } // namespace FlexFlow #endif diff --git 
a/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h index fd560352bb..e7df3b72df 100644 --- a/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h @@ -27,11 +27,15 @@ struct ParallelTensorShape : public use_visitable_cmp { int num_dims() const; + int get_volume() const; + ParallelDim const &at(ff_dim_t const &) const; ParallelDim &at(ff_dim_t const &); ParallelDim const &operator[](ff_dim_t const &) const; ParallelDim &operator[](ff_dim_t const &); + bool is_valid() const; + public: ParallelTensorDims dims; DataType data_type; diff --git a/lib/op-attrs/src/attention.cc b/lib/op-attrs/src/attention.cc index e9ae6ec803..fb9ab0cd29 100644 --- a/lib/op-attrs/src/attention.cc +++ b/lib/op-attrs/src/attention.cc @@ -1,4 +1,8 @@ #include "op-attrs/ops/attention.h" +#include "op-attrs/ff_dim.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "utils/exception.decl.h" +#include "utils/exception.h" namespace FlexFlow { @@ -53,31 +57,201 @@ TensorShape return {dims, DataType::FLOAT}; } -ParallelTensorShape get_output_shape(MultiHeadAttentionAttrs const &attrs, - ParallelTensorShape const &query_shape, - ParallelTensorShape const &key_shape, - ParallelTensorShape const &value_shape) { - /* ParallelDim replica_dim = query_shape.at(ff_dim_t(query_shape.num_dims() - - * 2)); */ - /* replica_dim.size = replica_dim.degree; */ +// according to the pytorch +// https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html, +// we consider the batch size +// query: [replicate_num, seq_len, batch_size, embed_dim],4D, +// key: (replicate_num, seq_len, batch_size, embed_dim) +// value: (replicate_num ,seq_len, batch_size,embed_dim) +// multihead_attn = nn.MultiheadAttention(embed_dim, num_heads) +// output: (seq_len, batch_size, embed_dim) - /* ParallelDim */ +ParallelTensorShape get_output_shape( + MultiHeadAttentionAttrs const &attrs, + 
MultiHeadAttentionInputs const &input) { + if (input.query.num_dims() != 3 || input.key.num_dims() != 3 || + input.value.num_dims() != 3) { + throw mk_runtime_error("MultiHeadAttentionAttrs: num_dims != 3"); + } - ParallelTensorShape output_shape = query_shape; - output_shape.at(ff_dim_t(output_shape.num_dims() - 1)).size = attrs.embed_dim; - return output_shape; + if (input.query.at(ff_dim_t(0)).size != input.key.at(ff_dim_t(0)).size || + input.query.at(ff_dim_t(0)).size != input.value.at(ff_dim_t(0)).size || + input.key.at(ff_dim_t(0)).size != input.value.at(ff_dim_t(0)).size) { + throw mk_runtime_error("MultiHeadAttentionAttrs: seq_len not match"); + } + + if (input.query.at(ff_dim_t(1)).size != input.key.at(ff_dim_t(1)).size || + input.query.at(ff_dim_t(1)).size != input.value.at(ff_dim_t(1)).size || + input.key.at(ff_dim_t(1)).size != input.value.at(ff_dim_t(1)).size) { + throw mk_runtime_error("MultiHeadAttentionAttrs: batch_size not match"); + } + + if (input.query.at(ff_dim_t(2)).size != input.key.at(ff_dim_t(2)).size || + input.query.at(ff_dim_t(2)).size != input.value.at(ff_dim_t(2)).size || + input.key.at(ff_dim_t(2)).size != input.value.at(ff_dim_t(2)).size) { + throw mk_runtime_error("MultiHeadAttentionAttrs: embed_dim not match"); + } + + if (input.query.at(ff_dim_t(2)).size != attrs.embed_dim || + input.key.at(ff_dim_t(2)).size != attrs.embed_dim || + input.value.at(ff_dim_t(2)).size != attrs.embed_dim) { + throw mk_runtime_error( + "MultiHeadAttentionAttrs: input's embed_dim not match to attrs"); + } + + if (attrs.embed_dim != (attrs.num_heads * attrs.kdim)) { + throw mk_runtime_error( + "MultiHeadAttentionAttrs: embed_dim not match to num_heads * kdim"); + } + + // TODO: how to deal with the degree + // q = wq*x , k = wk*x, v = wv*x (seq_len, batch_size, embed_dim) + // k->(seq_len, num_head, batch_size, kdim) + // v->(seq_len, num_head, batch_size, vdim) + // q->(seq_len, num_head, batch_size, kdim) + // attn = q @k (seq_len, num_head, batch_size, 
batch_size) + // attn = attn @v (seq_len, num_head, batch_size, vdim) + // attn = attn.transpose(1,2) (seq_len, batch_size, num_head, vdim) + // + + // Note: we support tensor parallelism for seq_len/batch_size/embed_dim + ParallelTensorShape output = input.query; + for (int i = 0; i < output.num_dims(); i++) { + output.at(ff_dim_t(i)).degree = input.query.at(ff_dim_t(i)).degree; + output.at(ff_dim_t(i)).is_replica_dim = + input.query.at(ff_dim_t(i)).degree > 1; + } + return output; } -TensorShape get_output_shape(MultiHeadAttentionAttrs const &attrs, - TensorShape const &query_shape, - TensorShape const &key_shape, - TensorShape const &value_shape) { - ParallelTensorShape parallel_shape = - get_output_shape(attrs, - static_cast(query_shape), - static_cast(key_shape), - static_cast(value_shape)); - return get_tensor_shape_unsafe(parallel_shape); +// https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html, +// we consider the batch size +// query/key/value: 4D dimensions +// query:[, , , ] + +// key:[, , , ] + +// value:[, , , ] +// multihead_attn = nn.MultiheadAttention(embed_dim, num_heads) + +// ### k:(<, ,, ,, ) + +// k->(<, ,, ,, ) //num_head * kdim = embed_dim , dk2 = dk4 + +// v->(<, ,, ,, ) //num_head * vdim = embed_dim , dv2 = dv4 + +// q->(<, , . ,, ) //num_head * kdim = embed_dim , dq2 = dq4 + +// we have dk1 = dv1 = dq1 dk2 = dk4=dv2=dv4=dq2=dq4 + +// 1)/ attn = q @k (, , , +// , ) + +// how to decide the ra11//da10/da11/da12/da13/da14? ⇒ I think da11 =dk1, da12 = +// dk2, da13. 
= dk3, da14 = dq3 + +// rk * dk3 * dk4=rq * dq3 * dq4 = ra11 * da13 * da14 = ra11 * dk3 * dq3 + +// => ra11 = (rk * dk4) / dq3 = (rq * dq4) / dk3 , ra11/da10 = rq / dq0, + +// =>da10 = ra11 * dq0 / rq = dq0 * dq4 / dk3 + +// output attn: (< (rq * dq4) / dk3, dq0 * dq4 / dk3, t>, , +// , , ) + +// 2)attn = attn @v (seq_len, num_head, batch_size, vdim) + +// input attn:(< (rq * dq4) / dk3, dq0 * dq4 / dk3, t>, , +// , , ) + +// input v: ((<, , , ,, ) //num_head * vdim = embed_dim + +// output attn:(, , , , + +// how to decide ra21//da20/da21/da22/da23/da24? ⇒ da21 = dk1, da22 = dk2, da23 +// = dk3, da24 = dv4 + +// ra21 * da23 * da24 = rv * dv3 * dv4 ⇒ ra21 = (rv * dv3) / dk3 + +// ra21 / da20 = (rq * dq4) / dk3 / (dq0 * dq4 / dk3) ⇒ da20 = (rv * dv3 * dq0) +// / (rq * dk3) + +// output attn:(<(rv * dv3) / dk3 , (rv * dv3 * dq0) / (rq * dk3), t>, , , , ) + +// 3) attn = attn.transpose(1,2 ) (seq_len, batch_size, num_head, vdim) + +// input attn:(<(rv * dv3) / dk3 , (rv * dv3 * dq0) / (rq * dk3), t>, , , , + +// output attn:(<(rv * dv3) / dk3 , (rv * dv3 * dq0) / (rq * dk3), t>, , , , + +// 4)attn = attn.reshape(seq_len, batch_size, num_head*vdim) + +// input attn:(<(rv * dv3) / dk3 , (rv * dv3 * dq0) / (rq * dk3), t>, , , , + +// output attn:(<(rv * dv3) / dk3 , (rv * dv3 * dq0) / (rq * dk3), t>, , , , + +ParallelTensorShape get_output_shape( + MultiHeadAttentionAttrs const &attrs, + MultiHeadAttentionInputs const &input) { + + if (input.query.num_dims() != 4 || input.key.num_dims() != 4 || + input.value.num_dims() != 4) { + throw mk_runtime_error("MultiHeadAttentionAttrs: num_dims != 4"); + } + + if (input.query.at(ff_dim_t(1)).size != input.key.at(ff_dim_t(1)).size || + input.query.at(ff_dim_t(1)).size != input.value.at(ff_dim_t(1)).size || + input.key.at(ff_dim_t(1)).size != input.value.at(ff_dim_t(1)).size) { + throw mk_runtime_error("MultiHeadAttentionAttrs: seq_len not match"); + } + + if (input.query.at(ff_dim_t(2)).size != input.key.at(ff_dim_t(2)).size || + 
input.query.at(ff_dim_t(2)).size != input.value.at(ff_dim_t(2)).size || + input.key.at(ff_dim_t(2)).size != input.value.at(ff_dim_t(2)).size) { + throw mk_runtime_error("MultiHeadAttentionAttrs: batch_size not match"); + } + + if (input.query.at(ff_dim_t(3)).size != input.key.at(ff_dim_t(3)).size || + input.query.at(ff_dim_t(3)).size != input.value.at(ff_dim_t(3)).size || + input.key.at(ff_dim_t(3)).size != input.value.at(ff_dim_t(3)).size) { + throw mk_runtime_error("MultiHeadAttentionAttrs: embed_dim not match"); + } + + if (input.query.at(ff_dim_t(3)).size != attrs.embed_dim || + input.key.at(ff_dim_t(3)).size != attrs.embed_dim || + input.value.at(ff_dim_t(3)).size != attrs.embed_dim) { + throw mk_runtime_error( + "MultiHeadAttentionAttrs: input's embed_dim not match to attrs"); + } + + if (attrs.embed_dim != (attrs.num_heads * attrs.kdim)) { + throw mk_runtime_error( + "MultiHeadAttentionAttrs: embed_dim not match to num_heads * kdim"); + } + + ParallelTensorShape output = input.key; + + output.at(ff_dim_t(0)).size = + (input.value.at(ff_dim_t(0)).size * input.value.at(ff_dim_t(2)).degree) / + input.key.at(ff_dim_t(2)).degree; // rv3 * dv3 / dk3 + output.at(ff_dim_t(0)).degree = + (input.value.at(ff_dim_t(0)).size * input.value.at(ff_dim_t(2)).degree * + input.query.at(ff_dim_t(0)).degree) / + (input.query.at(ff_dim_t(0)).size * input.key.at(ff_dim_t(2)).degree); + // (rv * dv3 * dq0) / (rq * dk3) + output.at(ff_dim_t(0)).is_replica_dim = true; + + return output; } } // namespace FlexFlow diff --git a/lib/op-attrs/src/batch_matmul.cc b/lib/op-attrs/src/batch_matmul.cc index 1cc8c5cfda..62c0f525e8 100644 --- a/lib/op-attrs/src/batch_matmul.cc +++ b/lib/op-attrs/src/batch_matmul.cc @@ -1,7 +1,70 @@ #include "op-attrs/ops/batch_matmul.h" +#include "op-attrs/ff_dim.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "utils/exception.decl.h" +#include "utils/exception.h" namespace FlexFlow { +// lhs: [, ,, ] +// rhs:[, ,, ] +// in the original tensor, we 
assume the dl1/dr1 is 1
+// output:[, , , ]
+// how to decide the r3, do1, do3, do4
+// Note: Lsize = r1 * dl3 * dl4, Rsize = r2 * dr3 * dr4 , Rsize = Lsize
+// do3 = dl3, do4 = dr4
+// so, r3 = Lsize / do3 / do4
+// r3 / do1 = r1 / dl1
+ParallelTensorShape get_output_shape(BatchMatmulAttrs const &attrs,
+                                     ParallelTensorShape const &lhs,
+                                     ParallelTensorShape const &rhs) {
+  if (lhs.num_dims() != 4 || rhs.num_dims() != 4) {
+    throw mk_runtime_error("rhs or lhs dimension is not 4");
+  }
+
+  int rl = lhs.at(ff_dim_t(0)).size;    // replicate_num of lhs
+  int dl1 = lhs.at(ff_dim_t(0)).degree; // degree of dimension 1 (index 0)
+  int dl3 = lhs.at(ff_dim_t(2)).degree; // degree of dimension 3 (index 2)
+  int dr4 = rhs.at(ff_dim_t(3)).degree; // degree of dimension 4 (index 3)
+
+  int lsize = lhs.get_volume();
+  int rsize = rhs.get_volume();
+  if (lsize != rsize) {
+    throw mk_runtime_error("BatchMatmulAttrs::get_output_shape, the volume of "
+                           "lhs and rhs are not matched ");
+  }
+
+  if (lhs.at(ff_dim_t(1)).size != rhs.at(ff_dim_t(1)).size) {
+    throw mk_runtime_error(
+        "BatchMatmulAttrs::get_output_shape, batch size is not equal");
+  }
+
+  if (lhs.at(ff_dim_t(3)).size != rhs.at(ff_dim_t(2)).size) {
+    throw mk_runtime_error(
+        "BatchMatmulAttrs::get_output_shape: forth demension of lhs and third "
+        "dementions of rhs are not match");
+  }
+
+  // ff_dim_t is 0-based: a 4D shape only has indices 0..3 (0 is the replica)
+  ParallelTensorShape output_shape = lhs;
+
+  output_shape.at(ff_dim_t(0)).size = lsize / (dl3 * dr4);
+  output_shape.at(ff_dim_t(0)).degree =
+      output_shape.at(ff_dim_t(0)).size /
+      (rl / dl1); // this may have some problem
+  output_shape.at(ff_dim_t(0)).is_replica_dim = true;
+
+  output_shape.at(ff_dim_t(2)).size = lhs.at(ff_dim_t(2)).size;
+  output_shape.at(ff_dim_t(2)).degree = dl3;
+  output_shape.at(ff_dim_t(2)).is_replica_dim = false;
+
+  output_shape.at(ff_dim_t(3)).size = rhs.at(ff_dim_t(3)).size;
+  output_shape.at(ff_dim_t(3)).degree = dr4;
+  output_shape.at(ff_dim_t(3)).is_replica_dim = false;
+
+  return output_shape;
+}
+
 /* bool
BatchMatmulAttrs::is_valid( */ /* ParallelTensorShape const &lhs, ParallelTensorShape const &rhs) const { */ diff --git a/lib/op-attrs/src/batch_norm.cc b/lib/op-attrs/src/batch_norm.cc index 4e352d5f1c..5e22c8147d 100644 --- a/lib/op-attrs/src/batch_norm.cc +++ b/lib/op-attrs/src/batch_norm.cc @@ -1,3 +1,18 @@ #include "op-attrs/ops/batch_norm.h" +#include "utils/exception.h" -namespace FlexFlow {} // namespace FlexFlow +namespace FlexFlow { +// input_shape: [b, c, h, w] +// output: [b, c, h, w] +ParallelTensorShape get_output_shape(BatchNormAttrs const &attrs, + ParallelTensorShape const &input_shape) { + if (!input_shape.is_valid() || input_shape.num_dims() != 4) { + throw mk_runtime_error( + "BatchNormAttrs::get_output_shape: input_shape is invalid"); + } + + // the degree of the output is the same as the input_shape + return input_shape; +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/broadcast.cc b/lib/op-attrs/src/broadcast.cc index c69f480b84..f0de4cc807 100644 --- a/lib/op-attrs/src/broadcast.cc +++ b/lib/op-attrs/src/broadcast.cc @@ -1,3 +1,6 @@ #include "op-attrs/ops/broadcast.h" -namespace FlexFlow {} // namespace FlexFlow +namespace FlexFlow { + +// what's the definition of broadcast for get_output_shape +} // namespace FlexFlow diff --git a/lib/op-attrs/src/cast.cc b/lib/op-attrs/src/cast.cc index e4ab178a7e..60f899fed2 100644 --- a/lib/op-attrs/src/cast.cc +++ b/lib/op-attrs/src/cast.cc @@ -1,11 +1,17 @@ #include "op-attrs/ops/cast.h" +#include "utils/exception.h" namespace FlexFlow { -/* bool CastAttrs::is_valid(ParallelTensorShape const &input) const { */ -/* bool valid = input.is_valid(); */ -/* valid &= (input.at(input.num_dims() - 1).degree == 1); */ -/* return valid; */ -/* } */ +ParallelTensorShape get_output_shape(CastAttrs const &attrs, + ParallelTensorShape const &input_shape) { + if (!input_shape.is_valid()) { + throw mk_runtime_error( + "CastAttrs::get_output_shape: input_shape is invalid"); + } + ParallelTensorShape 
output_shape = input_shape; + output_shape.data_type = attrs.dtype; + return output_shape; +} } // namespace FlexFlow diff --git a/lib/op-attrs/src/combine.cc b/lib/op-attrs/src/combine.cc index cdca524538..ee77fd08b6 100644 --- a/lib/op-attrs/src/combine.cc +++ b/lib/op-attrs/src/combine.cc @@ -1,18 +1,18 @@ #include "op-attrs/ops/combine.h" +#include "utils/exception.decl.h" #include "utils/hash-utils.h" namespace FlexFlow { - -/* bool CombineAttrs::is_valid(ParallelTensorShape const &input) const { */ -/* return input.at(this->combine_legion_dim).degree % this->combine_degree == - * 0; */ -/* } */ - -/* ParallelTensorShape CombineAttrs::output_shape(ParallelTensorShape const - * &input_shape) const { */ -/* ParallelTensorShape output = input_shape; */ -/* output.at(this->combine_legion_dim).degree /= this->combine_degree; */ -/* return output; */ -/* } */ +ParallelTensorShape + get_output_shape_shape(CombineAttrs const &attrs, + ParallelTensorShape const &input_shape) { + ParallelTensorShape output_shape = input_shape; + /* + output_shape.at(attrs.combine_dim).degree /= attrs.combine_degree; + output_shape.at(attrs.combine_dim).is_replica_dim = + output_shape.at(attrs.combine_dim).degree > 1;*/ + NOT_IMPLEMENTED(); + return output_shape; +} } // namespace FlexFlow diff --git a/lib/op-attrs/src/concat.cc b/lib/op-attrs/src/concat.cc index 065c58f365..5efe8855d8 100644 --- a/lib/op-attrs/src/concat.cc +++ b/lib/op-attrs/src/concat.cc @@ -1,14 +1,35 @@ #include "op-attrs/ops/concat.h" +#include "utils/exception.decl.h" +#include "utils/exception.h" namespace FlexFlow { +ParallelTensorShape + get_output_shape(ConcatAttrs const &attrs, + std::vector const &inputs) { + ParallelTensorShape output = inputs[0]; + for (auto &i : inputs) { + if (attrs.axis >= i.num_dims() || i.is_valid() == false) { + throw mk_runtime_error("ConcatAttrs::get_output_shape: axis is out of " + "range or input is invalid"); + } + } -/* bool ConcatAttrs::is_valid( */ -/* std::vector const 
&input) const { */ -/* bool valid = true; */ -/* for (auto p : input) { */ -/* valid &= p.is_valid(); */ -/* } */ -/* return valid; */ -/* } */ + int dims = inputs[0].num_dims(); + for (int i = 1; i < inputs.size(); i++) { + if (inputs[i].num_dims() != dims) { + throw mk_runtime_error(" the input dims not matched at i:", i); + } + } + + for (auto &i : inputs) { + output.at(ff_dim_t(attrs.axis)).size += i.at(ff_dim_t(attrs.axis)).size; + } + output.at(ff_dim_t(0)).is_replica_dim = true; + // note: how to decide the degee? + for (int i = 1; i < output.num_dims(); i++) { + output.at(ff_dim_t(i)).is_replica_dim = false; + } + return output; +} } // namespace FlexFlow diff --git a/lib/op-attrs/src/conv_2d.cc b/lib/op-attrs/src/conv_2d.cc index d000d31feb..75c61a82af 100644 --- a/lib/op-attrs/src/conv_2d.cc +++ b/lib/op-attrs/src/conv_2d.cc @@ -1,6 +1,9 @@ #include "op-attrs/ops/conv_2d.h" +#include "op-attrs/ff_dim.h" +#include "op-attrs/parallel_tensor_shape.h" #include "parallel_dim_mapping_record.h" #include "parallel_dim_mapping_record_solver.h" +#include "utils/exception.h" #include "utils/vector.h" namespace FlexFlow { @@ -81,27 +84,47 @@ std::vector return mappings; } -/* bool Conv2DAttrs::is_valid(ParallelTensorShape const &input_shape) const { */ -/* bool is_valid = true; */ -/* is_valid &= input_shape.is_valid(); */ -/* is_valid &= this->calculate_output_shape(input_shape).is_valid(); */ -/* is_valid &= this->calculate_kernel_shape(input_shape).is_valid(); */ -/* if (use_bias) { */ -/* is_valid &= this->calculate_bias_shape(input_shape).is_valid(); */ -/* } */ +// input: (, , , < input_h, di4, f>, +// ) -/* // TODO FIXME: Currently disable parallelizing the height and width - * dimension */ -/* if (input_shape.at(0).degree > 1 || input_shape.at(1).degree > 1) { */ -/* return false; */ -/* } */ +// kernel(Conv2DAttrs): out_channels, kernel_h, kernel_w, stride_h, stride_w, +// padding_h, padding_w, -/* return is_valid; */ +// output shape:(, , , , ) -/* } */ +// 
output_h = (input_h + 2 * padding_h - kernel_h) / stride_h + 1 +// output_w = (input_w + 2 * padding_w - kernel_w) / stride_w + 1 +// assert: for the kernel, dk1 == dk2=dk4=dk4=dk5=1 +// question:how to decide the ro/do3/do4/do5? +// I think: do3= di3, di4= do4, di5 = do5, do1=di1, ro=ri +ParallelTensorShape get_output_shape(Conv2DAttrs const &attrs, + ParallelTensorShape const &input) { + if (input.num_dims() != 5) { + throw mk_runtime_error("Conv2DAttrs::get_output_shape: input is invalid"); + } + if (attrs.kernel_h > input.at(ff_dim_t(3)).size || + attrs.kernel_w > input.at(ff_dim_t(4)).size) { + throw mk_runtime_error( + "Conv2DAttrs::get_output_shape: kernel size is larger than input size"); + } -/* OperatorType Conv2DAttrs::op_type() const { */ -/* return OP_CONV2D; */ -/* } */ + ParallelTensorShape output = input; + output.at(ff_dim_t(0)).is_replica_dim = true; + output.at(ff_dim_t(2)).size = attrs.out_channels; + output.at(ff_dim_t(3)).size = + (input.at(ff_dim_t(3)).size + 2 * attrs.padding_h - attrs.kernel_h) / + attrs.stride_h + + 1; + output.at(ff_dim_t(4)).size = + (input.at(ff_dim_t(4)).size + 2 * attrs.padding_w - attrs.kernel_w) / + attrs.stride_w + + 1; + for (int i = 1; i < output.num_dims(); i++) { + output.at(ff_dim_t(i)).is_replica_dim = false; + } + + return output; +} } // namespace FlexFlow diff --git a/lib/op-attrs/src/dropout.cc b/lib/op-attrs/src/dropout.cc new file mode 100644 index 0000000000..7bdae67af9 --- /dev/null +++ b/lib/op-attrs/src/dropout.cc @@ -0,0 +1,11 @@ +#include "dropout.h" +#include "op-attrs/get_output_shapes.h" + +namespace FlexFlow { + +ParallelTensorShape get_output_shape(DropoutAttrs const &attrs, + ParallelTensorShape const &input_shape) { + return input_shape; +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/element_binary.cc b/lib/op-attrs/src/element_binary.cc index b713c6753f..57236bb04f 100644 --- a/lib/op-attrs/src/element_binary.cc +++ b/lib/op-attrs/src/element_binary.cc @@ -1,3 +1,35 @@ 
#include "op-attrs/ops/element_binary.h" +#include "op-attrs/ff_dim.h" +#include "utils/exception.h" -namespace FlexFlow {} // namespace FlexFlow +namespace FlexFlow { + +ParallelTensorShape get_output_shape(ElementBinaryAttrs const &atts, + ParallelTensorShape const &lhs, + ParallelTensorShape const &rhs) { + ParallelTensorShape output = lhs.num_dims() >= rhs.num_dims() ? lhs : rhs; + // how to decide its degree and size for replicate_num + output.at(ff_dim_t(0)).is_replica_dim = false; + for (int i = 1; i < output.num_dims(); i++) { + if (i >= lhs.num_dims()) { + output.at(ff_dim_t(i)) = rhs.at(ff_dim_t(i)); + } else if (i >= rhs.num_dims()) { + output.at(ff_dim_t(i)) = lhs.at(ff_dim_t(i)); + } else if (lhs.at(ff_dim_t(i)).size == rhs.at(ff_dim_t(i)).size) { + output.at(ff_dim_t(i)) = lhs.at(ff_dim_t(i)); + } else if (lhs.at(ff_dim_t(i)).size == 1) { + output.at(ff_dim_t(i)) = rhs.at(ff_dim_t(i)); + } else if (rhs.at(ff_dim_t(i)).size == 1) { + output.at(ff_dim_t(i)) = lhs.at(ff_dim_t(i)); + } else { + throw mk_runtime_error( + "Operands of shapes {} and {} could not be broadcast together", + lhs, + rhs); + } + } + + return output; +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/element_unary.cc b/lib/op-attrs/src/element_unary.cc index 481151fafb..08622e6f63 100644 --- a/lib/op-attrs/src/element_unary.cc +++ b/lib/op-attrs/src/element_unary.cc @@ -1,3 +1,10 @@ #include "op-attrs/ops/element_unary.h" -namespace FlexFlow {} // namespace FlexFlow +namespace FlexFlow { + +ParallelTensorShape get_output_shape(ElementUnaryAttrs const &atts, + ParallelTensorShape const &input_shape) { + return input_shape; +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/embedding.cc b/lib/op-attrs/src/embedding.cc index 02cbfaa031..8407eeebfd 100644 --- a/lib/op-attrs/src/embedding.cc +++ b/lib/op-attrs/src/embedding.cc @@ -1,3 +1,39 @@ #include "op-attrs/ops/embedding.h" +#include "op-attrs/ff_dim.h" +#include "op-attrs/parallel_dim.h" +#include 
"op-attrs/parallel_tensor_dims.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "op-attrs/tensor_shape.h" +#include "utils/exception.h" -namespace FlexFlow {} // namespace FlexFlow +namespace FlexFlow { + +// pytorch nn.Embedding +// Embedding OP: (num_embeddings, embedding_dim) (num_entries, out_channels) +// input:(, < b, di2, f>, < seq_len, di3, f>) +// EmbeddingAttrs:req num_entries, out_channels; +// output:(, , , ) +ParallelTensorShape get_output_shape(EmbeddingAttrs const &attrs, + ParallelTensorShape const &input) { + if (input.num_dims() != 3) { + throw mk_runtime_error("for embedding, input shape must be 3D"); + } + + std::vector data; + data.resize(4); + data[0] = input.at(ff_dim_t(0)); + data[0].is_replica_dim = true; + data[1] = input.at(ff_dim_t(1)); + data[2] = input.at(ff_dim_t(2)); + data[3].size = attrs.out_channels; // TODO:what's the embedding_dim? + data[3].is_replica_dim = false; + + ParallelTensorShape output = ParallelTensorShape( + ParallelTensorDims(TensorDims(data.begin(), data.end())), + attrs.data_type); + + return output; +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/flat.cc b/lib/op-attrs/src/flat.cc index 75d31beae4..0cfe71b398 100644 --- a/lib/op-attrs/src/flat.cc +++ b/lib/op-attrs/src/flat.cc @@ -1,6 +1,8 @@ #include "op-attrs/ops/flat.h" +#include "op-attrs/ff_dim.h" #include "parallel_dim_mapping_record.h" #include "parallel_dim_mapping_record_solver.h" +#include "utils/exception.h" #include namespace FlexFlow { @@ -14,6 +16,44 @@ namespace Output { constexpr int NUMDIM = 3, CHANNEL = 0, SAMPLE = 1, REPLICA = 2; } +// flat is like the pytorch view +// tensor = torch.randn(2, 3, 4) ,flattened_tensor = tensor.view(-1) #shape: +// (24) +// input: (, , , ......) 
+// assume d1=d2=d3 +// output: 2d dimention (, ) +ParallelTensorShape get_output_shape(FlatAttrs const &attrs, + ParallelTensorShape const &input) { + if (input.num_dims() < 2) { + throw mk_runtime_error("for flat,its dims must greater than 2"); + } + + int degree = input.at(ff_dim_t(1)).degree; + for (int i = 1; i < input.num_dims(); i++) { + if (degree != input.at(ff_dim_t(i)).degree) { + throw mk_runtime_error( + "for flat, all degree should be equal, but elemement ", i, " not"); + } + } + std::vector data; + data.resize(2); + data[0] = input.at(ff_dim_t(0)); + data[0].is_replica_dim = true; + data[1].degree = input.at(ff_dim_t(1)).degree; + data[1].size = input.at(ff_dim_t(1)).size; + data[1].is_replica_dim = false; + + for (int i = 2; i < input.num_dims(); i++) { + data[1].size *= input.at(ff_dim_t(i)).size; + } + + ParallelTensorShape output = ParallelTensorShape( + ParallelTensorDims(TensorDims(data.begin(), data.end())), + input.data_type); + + return output; +} + /* bool FlatAttrs::is_valid(ParallelTensorShape const &input) const { */ /* ParallelTensorShape output_shape = this->calculate_output_shape(input); */ diff --git a/lib/op-attrs/src/gather.cc b/lib/op-attrs/src/gather.cc index 4f2c13c794..7402bdc67c 100644 --- a/lib/op-attrs/src/gather.cc +++ b/lib/op-attrs/src/gather.cc @@ -1,7 +1,34 @@ #include "op-attrs/ops/gather.h" +#include "utils/exception.h" namespace FlexFlow { +// https://pytorch.org/docs/stable/generated/torch.gather.html +// todo: why return a vector? 
+std::vector<ParallelTensorShape>
+    get_output_shapes(GatherAttrs const &attrs,
+                      ParallelTensorShape const &input,
+                      ParallelTensorShape const &index) {
+  if (input.num_dims() != index.num_dims()) {
+    throw mk_runtime_error(
+        "for gather, the dimensions of input and index are not match");
+  }
+
+  // torch.gather needs index.size(d) <= input.size(d) for every d != dim;
+  // equal sizes are legal, so only a strictly larger index is rejected.
+  for (int i = 1; i < input.num_dims(); i++) {
+    if (i != attrs.dim &&
+        index.at(ff_dim_t(i)).size > input.at(ff_dim_t(i)).size) {
+      throw mk_runtime_error(
+          "Gather: index.size(d) <= input.size(d) for all dimensions d != dim");
+    }
+  }
+
+  // every dimension is validated above; output is index with input's dim 0
+  ParallelTensorShape output = index;
+  output.at(ff_dim_t(0)) = input.at(ff_dim_t(0));
+  return {output};
+}
 /* bool GatherAttrs::is_valid(ParallelTensorShape const &lhs,
  * ParallelTensorShape const &rhs) const { */
 /*   if (lhs.num_dims() != rhs.num_dims()) { */
diff --git a/lib/op-attrs/src/groupby.cc b/lib/op-attrs/src/groupby.cc
index 96c9db2838..17c091e02e 100644
--- a/lib/op-attrs/src/groupby.cc
+++ b/lib/op-attrs/src/groupby.cc
@@ -1,3 +1,42 @@
 #include "op-attrs/ops/groupby.h"
+#include "utils/exception.h"
 
-namespace FlexFlow {} // namespace FlexFlow
+namespace FlexFlow {
+
+// import torch
+// data = torch.tensor([10, 20, 30, 40, 50, 60, 70, 80])
+// # group index tensor group_indices
+// group_indices = torch.tensor([0, 1, 0, 2, 1, 2, 0, 1])
+
+// # groupby operator
+// unique_indices, unique_inverse_indices =
+// torch.unique(group_indices,return_inverse=True)
+
+// print(f"unique_indices: {unique_indices} and unique_inverse_indices:
+// {unique_inverse_indices}")
+
+// grouped_data = []
+
+// for i in unique_indices: # use unique_inverse_indices
+//     group_data = data[unique_inverse_indices == i]
+//     grouped_data.append(group_data)
+
+// for i, group in enumerate(grouped_data):
+//     print(f"Group {i}: {group}")
+
+// Group 0: tensor([10, 30, 70])
+// Group 1: tensor([20, 50, 80])
+// Group 2: tensor([40, 60])
+
+ParallelTensorShape
get_output_shape(Group_byAttrs const &attrs,
+                                     ParallelTensorShape const &input_shape,
+                                     ParallelTensorShape const &index) {
+  if (input_shape.num_dims() != index.num_dims()) {
+    throw mk_runtime_error(
+        "Group_by: input and index must have the same number of dimensions");
+  }
+  // Note: how to decide the groupby output shape?
+  return input_shape;
+}
+
+} // namespace FlexFlow
diff --git a/lib/op-attrs/src/layer_norm.cc b/lib/op-attrs/src/layer_norm.cc
index ab88de3622..43211fbf24 100644
--- a/lib/op-attrs/src/layer_norm.cc
+++ b/lib/op-attrs/src/layer_norm.cc
@@ -1,3 +1,18 @@
 #include "op-attrs/ops/layer_norm.h"
+#include "utils/exception.h"
 
-namespace FlexFlow {} // namespace FlexFlow
+namespace FlexFlow {
+
+// todo: maybe we need to set the degree of parallel_dim
+// Returns input_shape unchanged; throws if it has fewer than 2 dimensions.
+// attrs is not consulted.
+ParallelTensorShape get_output_shape(LayerNormAttrs const &attrs,
+                                     ParallelTensorShape const &input_shape) {
+  if (input_shape.num_dims() < 2) {
+    throw mk_runtime_error("LayerNorm: input must have at least 2 dimensions");
+  }
+  // output shape is same as input
+  return input_shape;
+}
+
+} // namespace FlexFlow
diff --git a/lib/op-attrs/src/linear.cc b/lib/op-attrs/src/linear.cc
index 16a94e7f6c..6c1748f517 100644
--- a/lib/op-attrs/src/linear.cc
+++ b/lib/op-attrs/src/linear.cc
@@ -1,3 +1,31 @@
 #include "op-attrs/ops/linear.h"
+#include "op-attrs/ff_dim.h"
+#include "utils/exception.h"
 
-namespace FlexFlow {} // namespace FlexFlow
+namespace FlexFlow {
+
+// https://pytorch.org/docs/stable/generated/torch.nn.Linear.html
+// torch.nn.Linear(in_features, out_features, bias=True, device=None,
+// dtype=None)
+// pytorch: input shape:{batch_size, input_channels}
+// pytorch linearattrs: should be {input_channels, output_channels}
+// pytorch: output shape:{batch_size, output_channels}
+// question: the Linearattrs doesn't have input_channels
+// input: (, , )
+// linearattrs: should be {input_channels, output_channels}
+// the Linearattrs doesn't have input_channels, just have output_channels
+// output:(,
, >
+// I think do1 = di1, do = ri, do2= di2, do3 = di3
+
+ParallelTensorShape get_output_shape(LinearAttrs const &attrs,
+                                     ParallelTensorShape const &input) {
+  ParallelTensorShape output_shape = input;
+  if (input.num_dims() != 3) {
+    throw mk_runtime_error("LinearAttrs: input shape should be 3D");
+  }
+
+  output_shape.at(ff_dim_t(2)).size = attrs.out_channels;
+  return output_shape;
+}
+
+} // namespace FlexFlow
diff --git a/lib/op-attrs/src/parallel_tensor_shape.cc b/lib/op-attrs/src/parallel_tensor_shape.cc
index 9a36e7d11b..5848991c13 100644
--- a/lib/op-attrs/src/parallel_tensor_shape.cc
+++ b/lib/op-attrs/src/parallel_tensor_shape.cc
@@ -1,4 +1,5 @@
 #include "op-attrs/parallel_tensor_shape.h"
+#include "op-attrs/ff_dim.h"
 #include "utils/containers.h"
 #include "utils/hash-utils.h"
 
@@ -13,6 +14,20 @@ static std::vector lift_dims(TensorDims const &dims) {
   return lifted_dims;
 }
 
+// Returns the size of dimension 0 multiplied by the degree of every
+// remaining dimension.
+// NOTE(review): assumes this product is the intended "volume" (it matches the
+// "Lsize = r1 * dl3 * dl4" note in batch_matmul.cc) -- TODO confirm.
+int ParallelTensorShape::get_volume() const {
+  int volume = this->at(ff_dim_t(0)).size;
+  for (int i = 1; i < num_dims(); i++) {
+    // index by i; indexing dim 0 here would ignore dims 1..n-1 entirely
+    volume *= this->at(ff_dim_t(i)).degree;
+  }
+
+  return volume;
+}
+
 ParallelTensorDims::ParallelTensorDims(TensorDims const &dims)
     : data(lift_dims(dims)) {}
 
diff --git a/lib/op-attrs/src/pool_2d.cc b/lib/op-attrs/src/pool_2d.cc
index 0867aeb344..23e1c6dd3d 100644
--- a/lib/op-attrs/src/pool_2d.cc
+++ b/lib/op-attrs/src/pool_2d.cc
@@ -1,6 +1,8 @@
 #include "op-attrs/ops/pool_2d.h"
+#include "op-attrs/ff_dim.h"
 #include "parallel_dim_mapping_record.h"
 #include "parallel_dim_mapping_record_solver.h"
+#include "utils/exception.h"
 
 namespace FlexFlow {
 
@@ -39,6 +41,55 @@ static ParallelDimMappingSolution
   return solve_parallel_dim_mappings(construct_mappings(input), {input}, 0, 1);
 }
 
+// https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html
+// https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html
+// input:(< ri, di1, t>, , , , )
+
+// Pool2DAttrs: req kernel_h, kernel_w, stride_h, stride_w, padding_h,
+// padding_w;
+
+// for avgpool2d:
output shape:(< ri, di1, t>, , , +// <1,1,f>, <1,1,f> ) + +// for maxpool2d, output shape:(< ri, di1, t>, , , +// , ) + +// output_height = (input_height + 2 * padding_h - kernel_h) / stride_h + 1 +// output_width = (input_width + 2 * padding_w - kernel_w) / stride_w + 1 +ParallelTensorShape get_output_shape(Pool2DAttrs const &attrs, + ParallelTensorShape const &input) { + if (input.num_dims() != 5) { + throw mk_runtime_error("Pool2DAttrs, input shape should be 5D"); + } + + if (attrs.pool_type == PoolOp::AVG) { + std::vector data; + data.resize(4); + data[0] = input.at(ff_dim_t(0)); + data[1] = input.at(ff_dim_t(1)); + data[2] = {1, 1, false}; + data[3] = {1, 1, false}; + ParallelTensorShape output = ParallelTensorShape( + ParallelTensorDims(TensorDims(data.begin(), data.end())), + input.data_type); + return output; + } else if (attrs.pool_type == PoolOp::MAX) { + ParallelTensorShape output_shape = input; + output_shape.at(ff_dim_t(3)).size = + (input.at(ff_dim_t(3)).size + 2 * attrs.padding_h - attrs.kernel_h) / + attrs.stride_h + + 1; + output_shape.at(ff_dim_t(4)).size = + (input.at(ff_dim_t(4)).size + 2 * attrs.padding_w - attrs.kernel_w) / + attrs.stride_w + + 1; + return output_shape; + } else { + throw mk_runtime_error("Pool2DAttrs: pool type is not supported"); + } +} + /* ParallelTensorShape Pool2DAttrs::calculate_output_shape(ParallelTensorShape * const &input) const { */ /* return solve_mappings(input).output_shapes.at(0); */ diff --git a/lib/op-attrs/src/reduce.cc b/lib/op-attrs/src/reduce.cc index 9d1770d5be..b28c722268 100644 --- a/lib/op-attrs/src/reduce.cc +++ b/lib/op-attrs/src/reduce.cc @@ -1,3 +1,22 @@ #include "op-attrs/ops/reduce.h" +#include "utils/exception.decl.h" +#include "utils/exception.h" -namespace FlexFlow {} // namespace FlexFlow +namespace FlexFlow { + +// +ParallelTensorShape get_output_shape(ReduceAttrs const &attrs, + ParallelTensorShape const &input) { + if (input.num_dims() - attrs.axes.size() == 1) { + throw 
mk_runtime_error(" for reduce, the input and attrs.axes must match"); + } + ParallelTensorShape output = input; + for (int i = 0; i < attrs.axes.size(); i++) { + output.at(attrs.axes.at(i)).size = 1; + output.at(attrs.axes.at(i)).is_replica_dim = false; + } + + return output; +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/reduction.cc b/lib/op-attrs/src/reduction.cc index 22fc9bab6a..48455bd706 100644 --- a/lib/op-attrs/src/reduction.cc +++ b/lib/op-attrs/src/reduction.cc @@ -10,4 +10,11 @@ namespace FlexFlow { /* return output; */ /* } */ +ParallelTensorShape get_output_shape(ReductionAttrs const &attrs, + ParallelTensorShape const &input_shape) { + ParallelTensorShape output(input_shape.dims, input_shape.data_type); + output.at(attrs.reduction_dim).size = 1; + return output; +} + } // namespace FlexFlow diff --git a/lib/op-attrs/src/repartition.cc b/lib/op-attrs/src/repartition.cc index 672e68b4f6..3046cf1ca7 100644 --- a/lib/op-attrs/src/repartition.cc +++ b/lib/op-attrs/src/repartition.cc @@ -1,11 +1,22 @@ #include "op-attrs/ops/repartition.h" +#include "op-attrs/parallel_dim.h" +#include "utils/exception.h" namespace FlexFlow { -/* bool RepartitionAttrs::is_valid(ParallelTensorShape const &input_shape) const - * { */ -/* ParallelDim dim = input_shape.at(this->repartition_legion_dim); */ -/* return (dim.size % this->repartition_degree * dim.degree == 0); */ -/* } */ +// this may be wrong partition by n multiplies degree by n and keeps shape the +// same +ParallelTensorShape get_output_shape(RepartitionAttrs const &attrs, + ParallelTensorShape const &input) { + ParallelDim dim = input.at(attrs.repartition_dim); + if (dim.size % attrs.repartition_degree * dim.degree != 0) { + throw mk_runtime_error("RepartitionAttrs: input.at(attrs.repartition_dim) " + "attrs.repartition_degree * dim.degree != 0"); + } + ParallelTensorShape output(input.dims, input.data_type); + output.at(attrs.repartition_dim).degree *= + attrs.repartition_degree; // NOTE: this may 
have some problem + return output; +} } // namespace FlexFlow diff --git a/lib/op-attrs/src/replicate.cc b/lib/op-attrs/src/replicate.cc index 73ad288d8c..b3c4e8e970 100644 --- a/lib/op-attrs/src/replicate.cc +++ b/lib/op-attrs/src/replicate.cc @@ -1,3 +1,28 @@ #include "op-attrs/ops/replicate.h" +#include "op-attrs/ff_dim.h" +#include "op-attrs/parallel_dim.h" +#include "utils/exception.h" -namespace FlexFlow {} // namespace FlexFlow +namespace FlexFlow { + +// replicate by n multiplies degree by n and shape by n +// seems it is like pytorch's repeat +// original_tensor = torch.tensor([1, 2, 3]) torch.Size([3]) +/// replicated_tensor = original_tensor.repeat(3) torch.Size([9]) + +// original_tensor = torch.randn(2, 3, 4) torch.Size([2, 3, 4]) +// repeated_tensor = original_tensor.repeat(3, 1, 1) torch.Size([6, 3, 4]) + +ParallelTensorShape get_output_shape(ReplicateAttrs const &attrs, + ParallelTensorShape const &input) { + if (attrs.replicate_dim >= input.num_dims() || attrs.replicate_degree <= 0) { + throw mk_runtime_error("ReplicateAttrs::get_output_shape: axis is out of " + "range or input is invalid"); + } + ParallelTensorShape output = input; + output.at(ff_dim_t(0)).is_replica_dim = true; + output.at(ff_dim_t(0)).size *= attrs.replicate_degree; + return output; +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/reshape.cc b/lib/op-attrs/src/reshape.cc index e8349e1f26..1cbbd36863 100644 --- a/lib/op-attrs/src/reshape.cc +++ b/lib/op-attrs/src/reshape.cc @@ -1,3 +1,60 @@ #include "op-attrs/ops/reshape.h" +#include "op-attrs/ff_dim.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "utils/exception.h" -namespace FlexFlow {} // namespace FlexFlow +namespace FlexFlow { + +// https://pytorch.org/docs/stable/generated/torch.reshape.html +ParallelTensorShape get_output_shape(ReshapeAttrs const &attrs, + ParallelTensorShape const &input) { + std::size_t input_volume = + input.dims.get_volume() / input.at(ff_dim_t(0)).size; + std::size_t attrs_volume 
= 1; + for (int i = 0; i < attrs.shape.dims.num_dims(); i++) { + attrs_volume *= attrs.shape.at(ff_dim_t(i)); + } + if (input_volume != attrs_volume) { + throw mk_runtime_error("ReshapeAttrs: input_volume != attrs_volume"); + } + + std::vector data; + + if (attrs.shape.dims.num_dims() == 1) { + // infer the shape + if (attrs.shape.at(ff_dim_t(0)) == -1) { + // the output shape will be (, ) + data.resize(2); + data[0] = input.at(ff_dim_t(0)); + data[1].size = input_volume; + // how to decide the degree? + ParallelTensorShape output = ParallelTensorShape( + ParallelTensorDims(TensorDims(data.begin(), data.end())), + input.data_type); + return output; + } else { + // i = attrs.shape.at(ff_dim_t(0) + // the output shape will be (, , ) + data.resize(3); + data[0] = input.at(ff_dim_t(0)); + data[1].size = attrs.shape.at(ff_dim_t(0)); + data[2].size = input_volume / attrs.shape.at(ff_dim_t(0)); + for (int i = 1; i < 3; i++) { + // how to decide the degree? + data[i].is_replica_dim = false; + } + ParallelTensorShape output = ParallelTensorShape( + ParallelTensorDims(TensorDims(data.begin(), data.end())), + input.data_type); + return output; + } + } + + ParallelTensorDims dims{attrs.shape.dims}; + ParallelTensorShape output = {dims, input.data_type}; + + return output; +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/reverse.cc b/lib/op-attrs/src/reverse.cc new file mode 100644 index 0000000000..f79fcfd0ed --- /dev/null +++ b/lib/op-attrs/src/reverse.cc @@ -0,0 +1,15 @@ +#include "op-attrs/ops/reverse.h" +#include "op-attrs/ff_dim.h" +#include "utils/exception.h" + +namespace FlexFlow { + +ParallelTensorShape get_output_shape(ReverseAttrs const &attrs, + ParallelTensorShape const &input_shape) { + if (attrs.axis < 0 || attrs.axis >= input_shape.num_dims()) { + throw mk_runtime_error("ReverseAttrs: axis is invalid"); + } + return input_shape; +} + +}; // namespace FlexFlow diff --git a/lib/op-attrs/src/softmax.cc b/lib/op-attrs/src/softmax.cc index 
9f95da4fb7..eff13aab59 100644 --- a/lib/op-attrs/src/softmax.cc +++ b/lib/op-attrs/src/softmax.cc @@ -1,3 +1,14 @@ #include "op-attrs/ops/softmax.h" +#include "utils/exception.h" -namespace FlexFlow {} // namespace FlexFlow +namespace FlexFlow { + +ParallelTensorShape get_output_shape(SoftmaxAttrs const &attrs, + ParallelTensorShape const &input_shape) { + if (input_shape.num_dims() < 2) { + throw mk_runtime_error("SoftmaxAttrs: input_shape.num_dims() < 2"); + } + return input_shape; +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/split.cc b/lib/op-attrs/src/split.cc index acda8f3262..8ab083c46b 100644 --- a/lib/op-attrs/src/split.cc +++ b/lib/op-attrs/src/split.cc @@ -1,3 +1,44 @@ #include "op-attrs/ops/split.h" +#include "op-attrs/ff_dim.h" +#include "utils/containers.h" +#include "utils/exception.h" -namespace FlexFlow {} // namespace FlexFlow +namespace FlexFlow { + +std::vector + get_output_shapes(SplitAttrs const &attrs, + ParallelTensorShape const &input) { + + std::size_t dims_sum = sum(attrs.splits); + if (dims_sum != input.at(ff_dim_t(attrs.axis)).size) { + throw mk_runtime_error( + "SplitAttrs: dims_sum != input.at(ff_dim_t(attrs.axis)).size"); + } + + std::vector outputs; + for (std::size_t i = 0; i < attrs.splits.size(); ++i) { + outputs.emplace_back(input); + outputs.back().at(ff_dim_t(attrs.axis)).size = attrs.splits[i]; + outputs.back().at(ff_dim_t(attrs.axis)).degree = + input.at(ff_dim_t(attrs.axis)).degree; + outputs.back().at(ff_dim_t(attrs.axis)).is_replica_dim = attrs.axis == 0; + } + return outputs; +} + +std::vector + get_output_shape(SplitAttrs const &attrs, + ParallelTensorShape const &input) { + std::size_t dims_sum = sum(attrs.splits); + if (dims_sum != input.at(ff_dim_t(attrs.axis)).size) { + throw mk_runtime_error( + "SplitAttrs: dims_sum != input.at(ff_dim_t(attrs.axis)).size"); + } + + std::vector outputs; + for (std::size_t i = 0; i < attrs.splits.size(); ++i) { + } + return outputs; +} + +} // namespace FlexFlow diff 
--git a/lib/op-attrs/src/topk.cc b/lib/op-attrs/src/topk.cc
index 9d701e4868..8c1d043a57 100644
--- a/lib/op-attrs/src/topk.cc
+++ b/lib/op-attrs/src/topk.cc
@@ -1,3 +1,22 @@
 #include "op-attrs/ops/topk.h"
+#include "utils/exception.h"
 
-namespace FlexFlow {} // namespace FlexFlow
+namespace FlexFlow {
+
+ParallelTensorShape get_output_shape(TopKAttrs const &attrs,
+                                     ParallelTensorShape const &input) {
+
+  if (attrs.k > input.at(ff_dim_t(attrs.axis)).size) {
+    throw mk_runtime_error(
+        "TopKAttrs: k > input.at(ff_dim_t(attrs.axis)).size");
+  }
+
+  ParallelTensorShape output = input;
+  output.at(ff_dim_t(attrs.axis)).size = attrs.k;
+  output.at(ff_dim_t(attrs.axis)).degree =
+      input.at(ff_dim_t(attrs.axis)).degree;
+  output.at(ff_dim_t(attrs.axis)).is_replica_dim = false;
+  return output;
+}
+
+} // namespace FlexFlow
diff --git a/lib/op-attrs/src/transpose.cc b/lib/op-attrs/src/transpose.cc
index ad4a84a3d5..ac853bd0ed 100644
--- a/lib/op-attrs/src/transpose.cc
+++ b/lib/op-attrs/src/transpose.cc
@@ -1,3 +1,40 @@
 #include "op-attrs/ops/transpose.h"
+#include "op-attrs/ff_dim.h"
+#include "utils/exception.h"
 
-namespace FlexFlow {} // namespace FlexFlow
+namespace FlexFlow {
+
+// assume input:[, , , ]
+// perm is [1,2]
+// output:[, , , ]
+//
+// Swaps the size and degree of the two dimensions listed in attrs.perm.
+// Throws unless perm has exactly two entries, both greater than 0 and less
+// than input.num_dims().
+ParallelTensorShape get_output_shape(TransposeAttrs const &attrs,
+                                     ParallelTensorShape const &input) {
+  if (attrs.perm.size() != 2) {
+    throw mk_runtime_error("TransposeAttrs: perm.size() != 2");
+  }
+
+  auto dim0 = attrs.perm[0]; // dim0 and dim1 should not be 0
+  auto dim1 = attrs.perm[1];
+  if (dim0 <= 0 || dim1 <= 0 || dim0 >= input.num_dims() ||
+      dim1 >= input.num_dims()) {
+    throw mk_runtime_error("TransposeAttrs: dim0 <= 0 || dim1 <= 0 || dim0 >= "
+                           "input.num_dims() || dim1 >= input.num_dims()");
+  }
+
+  ParallelTensorShape output = input;
+  int temp = input.at(ff_dim_t(dim0)).size;
+  int degree = input.at(ff_dim_t(dim0)).degree;
+  output.at(ff_dim_t(dim0)).size = input.at(ff_dim_t(dim1)).size;
+  output.at(ff_dim_t(dim1)).size = temp;
+  output.at(ff_dim_t(dim0)).degree = input.at(ff_dim_t(dim1)).degree;
+  output.at(ff_dim_t(dim1)).degree = degree;
+  output.at(ff_dim_t(dim0)).is_replica_dim = dim0 == 0;
+  output.at(ff_dim_t(dim1)).is_replica_dim = dim1 == 0;
+  return output;
+}
+
+} // namespace FlexFlow