diff --git a/lib/op-attrs/CMakeLists.txt b/lib/op-attrs/CMakeLists.txt index 778be53d7c..244958c76e 100644 --- a/lib/op-attrs/CMakeLists.txt +++ b/lib/op-attrs/CMakeLists.txt @@ -6,7 +6,7 @@ ff_add_library( PUBLIC_INCLUDE include/ PRIVATE_INCLUDE - src/ + src/ DEPS utils ) diff --git a/lib/op-attrs/include/op-attrs/datatype.h b/lib/op-attrs/include/op-attrs/datatype.h index 2417f37fdb..53ad2ae679 100644 --- a/lib/op-attrs/include/op-attrs/datatype.h +++ b/lib/op-attrs/include/op-attrs/datatype.h @@ -94,6 +94,32 @@ struct formatter<::FlexFlow::DataType> : formatter { } }; +template <> +struct formatter<::FlexFlow::DataTypeValue> : formatter { + template + auto format(::FlexFlow::DataTypeValue v, FormatContext &ctx) + -> decltype(ctx.out()) { + using namespace FlexFlow; + + string_view s = "unknown"; + if (auto const *f32 = get_if>(&v)) { + s = fmt::to_string(*f32); + } else if (auto const *f64 = get_if>(&v)) { + s = fmt::to_string(*f64); + } else if (auto const *i32 = get_if>(&v)) { + s = fmt::to_string(*i32); + } else if (auto const *i64 = get_if>(&v)) { + s = fmt::to_string(*i64); + } else if (auto const *h = get_if>(&v)) { + s = fmt::to_string(*h); + } else if (auto const *b = get_if>(&v)) { + s = fmt::to_string(*b); + } + return formatter::format(s, ctx); + } +}; + } // namespace fmt +// namespace fmt #endif diff --git a/lib/op-attrs/include/op-attrs/ops/aggregate.h b/lib/op-attrs/include/op-attrs/ops/aggregate.h index faf16472b1..224f65e19c 100644 --- a/lib/op-attrs/include/op-attrs/ops/aggregate.h +++ b/lib/op-attrs/include/op-attrs/ops/aggregate.h @@ -16,6 +16,9 @@ struct AggregateAttrs { req lambda_bal; }; FF_VISITABLE_STRUCT(AggregateAttrs, n, lambda_bal); +FF_VISIT_FMTABLE(AggregateAttrs); +CHECK_FMTABLE(AggregateAttrs); +CHECK_VALID_OP_ATTR(AggregateAttrs); DataType get_datatype(AggregateAttrs const &); bool is_valid(AggregateAttrs const &, @@ -32,7 +35,6 @@ ParallelTensorShape ParallelTensorShape const &full_gate_gradients, std::vector const &exp_preds); 
-CHECK_VALID_OP_ATTR(AggregateAttrs); } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/aggregate_spec.h b/lib/op-attrs/include/op-attrs/ops/aggregate_spec.h index 8373452dfa..76f472d922 100644 --- a/lib/op-attrs/include/op-attrs/ops/aggregate_spec.h +++ b/lib/op-attrs/include/op-attrs/ops/aggregate_spec.h @@ -12,6 +12,9 @@ struct AggregateSpecAttrs { req lambda_bal; }; FF_VISITABLE_STRUCT(AggregateSpecAttrs, n, lambda_bal); +FF_VISIT_FMTABLE(AggregateSpecAttrs); +CHECK_FMTABLE(AggregateSpecAttrs); +CHECK_VALID_OP_ATTR(AggregateSpecAttrs); ParallelTensorShape get_output_shape(AggregateSpecAttrs const &, @@ -21,7 +24,6 @@ ParallelTensorShape ParallelTensorShape const &gate_gradients_full, std::vector const &exp_preds); -CHECK_VALID_OP_ATTR(AggregateSpecAttrs); } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/attention.h b/lib/op-attrs/include/op-attrs/ops/attention.h index ec3e592607..e270cf86af 100644 --- a/lib/op-attrs/include/op-attrs/ops/attention.h +++ b/lib/op-attrs/include/op-attrs/ops/attention.h @@ -21,6 +21,9 @@ FF_VISITABLE_STRUCT(MultiHeadAttentionAttrs, bias, add_bias_kv, add_zero_attn); +FF_VISIT_FMTABLE(MultiHeadAttentionAttrs); +CHECK_FMTABLE(MultiHeadAttentionAttrs); +CHECK_VALID_OP_ATTR(MultiHeadAttentionAttrs); template struct MultiHeadAttentionInputs @@ -70,7 +73,6 @@ ParallelTensorShape TensorShape get_output_shape(MultiHeadAttentionAttrs const &, MultiHeadAttentionInputs const &); -CHECK_VALID_OP_ATTR(MultiHeadAttentionAttrs); } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/batch_matmul.h b/lib/op-attrs/include/op-attrs/ops/batch_matmul.h index c74824570c..ebd8d32c63 100644 --- a/lib/op-attrs/include/op-attrs/ops/batch_matmul.h +++ b/lib/op-attrs/include/op-attrs/ops/batch_matmul.h @@ -11,7 +11,8 @@ struct BatchMatmulAttrs { req a_seq_length_dim, b_seq_length_dim; }; FF_VISITABLE_STRUCT(BatchMatmulAttrs, a_seq_length_dim, b_seq_length_dim); - 
+FF_VISIT_FMTABLE(BatchMatmulAttrs); +CHECK_FMTABLE(BatchMatmulAttrs); CHECK_VALID_OP_ATTR(BatchMatmulAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/batch_norm.h b/lib/op-attrs/include/op-attrs/ops/batch_norm.h index 4ec823d4ae..200d679d97 100644 --- a/lib/op-attrs/include/op-attrs/ops/batch_norm.h +++ b/lib/op-attrs/include/op-attrs/ops/batch_norm.h @@ -11,11 +11,12 @@ struct BatchNormAttrs { req relu; }; FF_VISITABLE_STRUCT(BatchNormAttrs, relu); +FF_VISIT_FMTABLE(BatchNormAttrs); +CHECK_FMTABLE(BatchNormAttrs); +CHECK_VALID_OP_ATTR(BatchNormAttrs); ParallelTensorShape get_output_shape(BatchNormAttrs const &); -CHECK_VALID_OP_ATTR(BatchNormAttrs); - } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/broadcast.h b/lib/op-attrs/include/op-attrs/ops/broadcast.h index 433bf23241..c38fe667ac 100644 --- a/lib/op-attrs/include/op-attrs/ops/broadcast.h +++ b/lib/op-attrs/include/op-attrs/ops/broadcast.h @@ -11,7 +11,8 @@ struct BroadcastAttrs { req> target_dims; }; FF_VISITABLE_STRUCT(BroadcastAttrs, target_dims); - +FF_VISIT_FMTABLE(BroadcastAttrs); +CHECK_FMTABLE(BroadcastAttrs); CHECK_VALID_OP_ATTR(BroadcastAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/cast.h b/lib/op-attrs/include/op-attrs/ops/cast.h index 63563f8df8..2d02c920db 100644 --- a/lib/op-attrs/include/op-attrs/ops/cast.h +++ b/lib/op-attrs/include/op-attrs/ops/cast.h @@ -12,8 +12,10 @@ struct CastAttrs { req dtype; }; FF_VISITABLE_STRUCT(CastAttrs, dtype); - +FF_VISIT_FMTABLE(CastAttrs); +CHECK_FMTABLE(CastAttrs); CHECK_VALID_OP_ATTR(CastAttrs); + } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/combine.h b/lib/op-attrs/include/op-attrs/ops/combine.h index deaba9e093..c94fbb0dfa 100644 --- a/lib/op-attrs/include/op-attrs/ops/combine.h +++ b/lib/op-attrs/include/op-attrs/ops/combine.h @@ -13,6 +13,8 @@ struct CombineAttrs { req combine_degree; }; FF_VISITABLE_STRUCT(CombineAttrs, 
combine_dim, combine_degree); +FF_VISIT_FMTABLE(CombineAttrs); +CHECK_FMTABLE(CombineAttrs); CHECK_VALID_OP_ATTR(CombineAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/concat.h b/lib/op-attrs/include/op-attrs/ops/concat.h index b9bd14a231..fc55d3ea45 100644 --- a/lib/op-attrs/include/op-attrs/ops/concat.h +++ b/lib/op-attrs/include/op-attrs/ops/concat.h @@ -12,6 +12,8 @@ struct ConcatAttrs { ff_dim_t axis; }; FF_VISITABLE_STRUCT(ConcatAttrs, axis); +FF_VISIT_FMTABLE(ConcatAttrs); +CHECK_FMTABLE(ConcatAttrs); CHECK_VALID_OP_ATTR(ConcatAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/conv_2d.h b/lib/op-attrs/include/op-attrs/ops/conv_2d.h index 3034dc8c62..12437c94bc 100644 --- a/lib/op-attrs/include/op-attrs/ops/conv_2d.h +++ b/lib/op-attrs/include/op-attrs/ops/conv_2d.h @@ -27,6 +27,8 @@ FF_VISITABLE_STRUCT(Conv2DAttrs, groups, activation, use_bias); +FF_VISIT_FMTABLE(Conv2DAttrs); +CHECK_FMTABLE(Conv2DAttrs); CHECK_VALID_OP_ATTR(Conv2DAttrs); TensorShape get_kernel_shape(Conv2DAttrs const &, TensorShape const &); diff --git a/lib/op-attrs/include/op-attrs/ops/dropout.h b/lib/op-attrs/include/op-attrs/ops/dropout.h index 8e0049f526..34d76b69ee 100644 --- a/lib/op-attrs/include/op-attrs/ops/dropout.h +++ b/lib/op-attrs/include/op-attrs/ops/dropout.h @@ -12,6 +12,8 @@ struct DropoutAttrs { req seed; }; FF_VISITABLE_STRUCT(DropoutAttrs, rate, seed); +FF_VISIT_FMTABLE(DropoutAttrs); +CHECK_FMTABLE(DropoutAttrs); CHECK_VALID_OP_ATTR(DropoutAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/element_binary.h b/lib/op-attrs/include/op-attrs/ops/element_binary.h index c4a096166d..a8a48529ed 100644 --- a/lib/op-attrs/include/op-attrs/ops/element_binary.h +++ b/lib/op-attrs/include/op-attrs/ops/element_binary.h @@ -20,6 +20,8 @@ FF_VISITABLE_STRUCT(ElementBinaryAttrs, compute_type, should_broadcast_lhs, should_broadcast_rhs); +FF_VISIT_FMTABLE(ElementBinaryAttrs); 
+CHECK_FMTABLE(ElementBinaryAttrs); CHECK_VALID_OP_ATTR(ElementBinaryAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/element_unary.h b/lib/op-attrs/include/op-attrs/ops/element_unary.h index 1b72e83cb5..cbca9b01b3 100644 --- a/lib/op-attrs/include/op-attrs/ops/element_unary.h +++ b/lib/op-attrs/include/op-attrs/ops/element_unary.h @@ -14,12 +14,16 @@ struct ElementScalarUnaryAttrs { req scalar; }; FF_VISITABLE_STRUCT(ElementScalarUnaryAttrs, op, scalar); +FF_VISIT_FMTABLE(ElementScalarUnaryAttrs); +CHECK_FMTABLE(ElementScalarUnaryAttrs); CHECK_VALID_OP_ATTR(ElementScalarUnaryAttrs); struct ElementUnaryAttrs { req op; }; FF_VISITABLE_STRUCT(ElementUnaryAttrs, op); +FF_VISIT_FMTABLE(ElementUnaryAttrs); +CHECK_FMTABLE(ElementUnaryAttrs); CHECK_VALID_OP_ATTR(ElementUnaryAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/embedding.h b/lib/op-attrs/include/op-attrs/ops/embedding.h index 8b00fa22ce..c3cb2f2368 100644 --- a/lib/op-attrs/include/op-attrs/ops/embedding.h +++ b/lib/op-attrs/include/op-attrs/ops/embedding.h @@ -21,6 +21,8 @@ struct EmbeddingAttrs { req data_type; }; FF_VISITABLE_STRUCT(EmbeddingAttrs, num_entries, out_channels, aggr, data_type); +FF_VISIT_FMTABLE(EmbeddingAttrs); +CHECK_FMTABLE(EmbeddingAttrs); CHECK_VALID_OP_ATTR(EmbeddingAttrs); TensorShape get_weights_shape(EmbeddingAttrs const &, TensorShape const &); diff --git a/lib/op-attrs/include/op-attrs/ops/flat.h b/lib/op-attrs/include/op-attrs/ops/flat.h index 706689199d..1f01395d85 100644 --- a/lib/op-attrs/include/op-attrs/ops/flat.h +++ b/lib/op-attrs/include/op-attrs/ops/flat.h @@ -9,6 +9,8 @@ namespace FlexFlow { struct FlatAttrs {}; FF_VISITABLE_STRUCT(FlatAttrs); +FF_VISIT_FMTABLE(FlatAttrs); +CHECK_FMTABLE(FlatAttrs); CHECK_VALID_OP_ATTR(FlatAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/gather.h b/lib/op-attrs/include/op-attrs/ops/gather.h index ca2406ef75..73e96ae3d8 100644 --- 
a/lib/op-attrs/include/op-attrs/ops/gather.h +++ b/lib/op-attrs/include/op-attrs/ops/gather.h @@ -12,6 +12,8 @@ struct GatherAttrs { ff_dim_t dim; }; FF_VISITABLE_STRUCT(GatherAttrs, dim); +FF_VISIT_FMTABLE(GatherAttrs); +CHECK_FMTABLE(GatherAttrs); CHECK_VALID_OP_ATTR(GatherAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/groupby.h b/lib/op-attrs/include/op-attrs/ops/groupby.h index 174c40242e..80eabb5529 100644 --- a/lib/op-attrs/include/op-attrs/ops/groupby.h +++ b/lib/op-attrs/include/op-attrs/ops/groupby.h @@ -12,6 +12,8 @@ struct Group_byAttrs { req alpha; }; FF_VISITABLE_STRUCT(Group_byAttrs, n, alpha); +FF_VISIT_FMTABLE(Group_byAttrs); +CHECK_FMTABLE(Group_byAttrs); CHECK_VALID_OP_ATTR(Group_byAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/input.h b/lib/op-attrs/include/op-attrs/ops/input.h index 26c486c9ac..28522c7457 100644 --- a/lib/op-attrs/include/op-attrs/ops/input.h +++ b/lib/op-attrs/include/op-attrs/ops/input.h @@ -8,6 +8,8 @@ namespace FlexFlow { struct InputAttrs {}; FF_VISITABLE_STRUCT(InputAttrs); +FF_VISIT_FMTABLE(InputAttrs); +CHECK_FMTABLE(InputAttrs); CHECK_VALID_OP_ATTR(InputAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/layer_norm.h b/lib/op-attrs/include/op-attrs/ops/layer_norm.h index dab055b2c9..d78d0983b6 100644 --- a/lib/op-attrs/include/op-attrs/ops/layer_norm.h +++ b/lib/op-attrs/include/op-attrs/ops/layer_norm.h @@ -14,6 +14,8 @@ struct LayerNormAttrs { req eps; }; FF_VISITABLE_STRUCT(LayerNormAttrs, axes, elementwise_affine, eps); +FF_VISIT_FMTABLE(LayerNormAttrs); +CHECK_FMTABLE(LayerNormAttrs); CHECK_VALID_OP_ATTR(LayerNormAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/linear.h b/lib/op-attrs/include/op-attrs/ops/linear.h index 3be8be2040..aa3144b8ff 100644 --- a/lib/op-attrs/include/op-attrs/ops/linear.h +++ b/lib/op-attrs/include/op-attrs/ops/linear.h @@ -13,12 +13,16 @@ struct L1RegularizerAttrs { req 
lambda; }; FF_VISITABLE_STRUCT(L1RegularizerAttrs, lambda); +FF_VISIT_FMTABLE(L1RegularizerAttrs); +CHECK_FMTABLE(L1RegularizerAttrs); CHECK_VALID_OP_ATTR(L1RegularizerAttrs); struct L2RegularizerAttrs { req lambda; }; FF_VISITABLE_STRUCT(L2RegularizerAttrs, lambda); +FF_VISIT_FMTABLE(L2RegularizerAttrs); +CHECK_FMTABLE(L2RegularizerAttrs); CHECK_VALID_OP_ATTR(L2RegularizerAttrs); using RegularizerAttrs = variant; @@ -32,6 +36,8 @@ struct LinearAttrs { }; FF_VISITABLE_STRUCT( LinearAttrs, out_channels, use_bias, data_type, activation, regularizer); +FF_VISIT_FMTABLE(LinearAttrs); +CHECK_FMTABLE(LinearAttrs); CHECK_VALID_OP_ATTR(LinearAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/loss_functions.h b/lib/op-attrs/include/op-attrs/ops/loss_functions.h index 7a3db05329..0b1a2c12e5 100644 --- a/lib/op-attrs/include/op-attrs/ops/loss_functions.h +++ b/lib/op-attrs/include/op-attrs/ops/loss_functions.h @@ -22,12 +22,16 @@ struct SparseCategoricalCrossEntropyLossAttrs { req replace_labels; // for aggregate_spec: More predictions than labels }; FF_VISITABLE_STRUCT(SparseCategoricalCrossEntropyLossAttrs, replace_labels); +FF_VISIT_FMTABLE(SparseCategoricalCrossEntropyLossAttrs); +CHECK_FMTABLE(SparseCategoricalCrossEntropyLossAttrs); CHECK_VALID_OP_ATTR(SparseCategoricalCrossEntropyLossAttrs); struct OtherLossAttrs { req loss_type; }; FF_VISITABLE_STRUCT(OtherLossAttrs, loss_type); +FF_VISIT_FMTABLE(OtherLossAttrs); +CHECK_FMTABLE(OtherLossAttrs); CHECK_VALID_OP_ATTR(OtherLossAttrs); using LossAttrs = diff --git a/lib/op-attrs/include/op-attrs/ops/noop.h b/lib/op-attrs/include/op-attrs/ops/noop.h index 658e1b7d98..7b2eafa6f3 100644 --- a/lib/op-attrs/include/op-attrs/ops/noop.h +++ b/lib/op-attrs/include/op-attrs/ops/noop.h @@ -8,6 +8,8 @@ namespace FlexFlow { struct NoopAttrs {}; FF_VISITABLE_STRUCT(NoopAttrs); +FF_VISIT_FMTABLE(NoopAttrs); +CHECK_FMTABLE(NoopAttrs); CHECK_VALID_OP_ATTR(NoopAttrs); } // namespace FlexFlow diff --git 
a/lib/op-attrs/include/op-attrs/ops/pool_2d.h b/lib/op-attrs/include/op-attrs/ops/pool_2d.h index efe29b3b2e..62a7ee491e 100644 --- a/lib/op-attrs/include/op-attrs/ops/pool_2d.h +++ b/lib/op-attrs/include/op-attrs/ops/pool_2d.h @@ -27,6 +27,8 @@ FF_VISITABLE_STRUCT(Pool2DAttrs, padding_w, pool_type, activation); +FF_VISIT_FMTABLE(Pool2DAttrs); +CHECK_FMTABLE(Pool2DAttrs); CHECK_VALID_OP_ATTR(Pool2DAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/reduce.h b/lib/op-attrs/include/op-attrs/ops/reduce.h index 193d3b0dc8..39b3b2e329 100644 --- a/lib/op-attrs/include/op-attrs/ops/reduce.h +++ b/lib/op-attrs/include/op-attrs/ops/reduce.h @@ -16,6 +16,8 @@ struct ReduceAttrs { req keepdims; }; FF_VISITABLE_STRUCT(ReduceAttrs, axes, op_type, keepdims); +FF_VISIT_FMTABLE(ReduceAttrs); +CHECK_FMTABLE(ReduceAttrs); CHECK_VALID_OP_ATTR(ReduceAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/reduction.h b/lib/op-attrs/include/op-attrs/ops/reduction.h index f848f879fc..58f81dbccc 100644 --- a/lib/op-attrs/include/op-attrs/ops/reduction.h +++ b/lib/op-attrs/include/op-attrs/ops/reduction.h @@ -13,6 +13,8 @@ struct ReductionAttrs { req reduction_degree; }; FF_VISITABLE_STRUCT(ReductionAttrs, reduction_dim, reduction_degree); +FF_VISIT_FMTABLE(ReductionAttrs); +CHECK_FMTABLE(ReductionAttrs); CHECK_VALID_OP_ATTR(ReductionAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/repartition.h b/lib/op-attrs/include/op-attrs/ops/repartition.h index 83c4ae870b..078b0aa602 100644 --- a/lib/op-attrs/include/op-attrs/ops/repartition.h +++ b/lib/op-attrs/include/op-attrs/ops/repartition.h @@ -13,6 +13,8 @@ struct RepartitionAttrs { req repartition_degree; }; FF_VISITABLE_STRUCT(RepartitionAttrs, repartition_dim, repartition_degree); +FF_VISIT_FMTABLE(RepartitionAttrs); +CHECK_FMTABLE(RepartitionAttrs); CHECK_VALID_OP_ATTR(RepartitionAttrs); } // namespace FlexFlow diff --git 
a/lib/op-attrs/include/op-attrs/ops/replicate.h b/lib/op-attrs/include/op-attrs/ops/replicate.h index 92e64a4120..dfc3a44741 100644 --- a/lib/op-attrs/include/op-attrs/ops/replicate.h +++ b/lib/op-attrs/include/op-attrs/ops/replicate.h @@ -13,6 +13,8 @@ struct ReplicateAttrs { req replicate_degree; }; FF_VISITABLE_STRUCT(ReplicateAttrs, replicate_dim, replicate_degree); +FF_VISIT_FMTABLE(ReplicateAttrs); +CHECK_FMTABLE(ReplicateAttrs); CHECK_VALID_OP_ATTR(ReplicateAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/reshape.h b/lib/op-attrs/include/op-attrs/ops/reshape.h index b118482a2b..b4cdb84a27 100644 --- a/lib/op-attrs/include/op-attrs/ops/reshape.h +++ b/lib/op-attrs/include/op-attrs/ops/reshape.h @@ -11,6 +11,8 @@ struct ReshapeAttrs { TensorShape shape; }; FF_VISITABLE_STRUCT(ReshapeAttrs, shape); +FF_VISIT_FMTABLE(ReshapeAttrs); +CHECK_FMTABLE(ReshapeAttrs); CHECK_VALID_OP_ATTR(ReshapeAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/reverse.h b/lib/op-attrs/include/op-attrs/ops/reverse.h index 6030285f14..35582ef9c5 100644 --- a/lib/op-attrs/include/op-attrs/ops/reverse.h +++ b/lib/op-attrs/include/op-attrs/ops/reverse.h @@ -11,6 +11,8 @@ struct ReverseAttrs { ff_dim_t axis; }; FF_VISITABLE_STRUCT(ReverseAttrs, axis); +FF_VISIT_FMTABLE(ReverseAttrs); +CHECK_FMTABLE(ReverseAttrs); CHECK_VALID_OP_ATTR(ReverseAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/softmax.h b/lib/op-attrs/include/op-attrs/ops/softmax.h index 9a776737f5..9bf53254f6 100644 --- a/lib/op-attrs/include/op-attrs/ops/softmax.h +++ b/lib/op-attrs/include/op-attrs/ops/softmax.h @@ -12,6 +12,8 @@ struct SoftmaxAttrs { ff_dim_t dim; }; FF_VISITABLE_STRUCT(SoftmaxAttrs, dim); +FF_VISIT_FMTABLE(SoftmaxAttrs); +CHECK_FMTABLE(SoftmaxAttrs); CHECK_VALID_OP_ATTR(SoftmaxAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/split.h b/lib/op-attrs/include/op-attrs/ops/split.h index 
fa66bc46f5..232a996380 100644 --- a/lib/op-attrs/include/op-attrs/ops/split.h +++ b/lib/op-attrs/include/op-attrs/ops/split.h @@ -13,6 +13,8 @@ struct SplitAttrs { ff_dim_t axis; }; FF_VISITABLE_STRUCT(SplitAttrs, splits, axis); +FF_VISIT_FMTABLE(SplitAttrs); +CHECK_FMTABLE(SplitAttrs); CHECK_VALID_OP_ATTR(SplitAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/topk.h b/lib/op-attrs/include/op-attrs/ops/topk.h index 413855913c..963c6189a6 100644 --- a/lib/op-attrs/include/op-attrs/ops/topk.h +++ b/lib/op-attrs/include/op-attrs/ops/topk.h @@ -12,6 +12,8 @@ struct TopKAttrs { req sorted; }; FF_VISITABLE_STRUCT(TopKAttrs, k, sorted); +FF_VISIT_FMTABLE(TopKAttrs); +CHECK_FMTABLE(TopKAttrs); CHECK_VALID_OP_ATTR(TopKAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/transpose.h b/lib/op-attrs/include/op-attrs/ops/transpose.h index 87db435979..8b87fca08a 100644 --- a/lib/op-attrs/include/op-attrs/ops/transpose.h +++ b/lib/op-attrs/include/op-attrs/ops/transpose.h @@ -12,6 +12,8 @@ struct TransposeAttrs { req> perm; }; FF_VISITABLE_STRUCT(TransposeAttrs, perm); +FF_VISIT_FMTABLE(TransposeAttrs); +CHECK_FMTABLE(TransposeAttrs); CHECK_VALID_OP_ATTR(TransposeAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/parallel_dim.h b/lib/op-attrs/include/op-attrs/parallel_dim.h index b9df2d9037..ae05f39d61 100644 --- a/lib/op-attrs/include/op-attrs/parallel_dim.h +++ b/lib/op-attrs/include/op-attrs/parallel_dim.h @@ -12,6 +12,8 @@ struct ParallelDim { req is_replica_dim; }; FF_VISITABLE_STRUCT(ParallelDim, size, degree, is_replica_dim); +FF_VISIT_FMTABLE(ParallelDim); +CHECK_FMTABLE(ParallelDim); bool is_valid(ParallelDim const &); bool is_replica_dim(ParallelDim const &); diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_dims.h b/lib/op-attrs/include/op-attrs/parallel_tensor_dims.h index d38ba75232..4053804a97 100644 --- a/lib/op-attrs/include/op-attrs/parallel_tensor_dims.h +++ 
b/lib/op-attrs/include/op-attrs/parallel_tensor_dims.h @@ -2,13 +2,17 @@ #define _FLEXFLOW_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_TENSOR_DIMS_H #include "parallel_dim.h" +#include "tensor_shape.h" #include "utils/visitable.h" namespace FlexFlow { -struct ParallelTensorDims : public use_visitable_cmp { +struct ParallelTensorDims { explicit ParallelTensorDims(TensorDims const &); + template + ParallelTensorDims(std::vector const &dims) : data(dims) {} + size_t get_volume() const; size_t num_dims() const; @@ -38,16 +42,31 @@ struct ParallelTensorDims : public use_visitable_cmp { const_reverse_iterator crend() const; public: - FFOrdered data; + req> data; }; +FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(ParallelTensorDims, data); + bool is_valid(ParallelTensorDims const &); TensorDims get_piece_dims(ParallelTensorDims const &); TensorDims get_tensor_dims_unsafe(ParallelTensorDims const &); } // namespace FlexFlow -VISITABLE_STRUCT(::FlexFlow::ParallelTensorDims, data); -MAKE_VISIT_HASHABLE(::FlexFlow::ParallelTensorDims); +namespace fmt { + +template <> +struct formatter<::FlexFlow::ParallelTensorDims> : formatter { + template + auto format(::FlexFlow::ParallelTensorDims dims, FormatContext &ctx) const + -> decltype(ctx.out()) { + using namespace FlexFlow; + + std::vector v(dims.data.begin(), dims.data.end()); + return formatter::format(fmt::to_string(v), ctx); + } +}; + +} // namespace fmt #endif diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h index fd560352bb..5e78350ebe 100644 --- a/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h @@ -16,15 +16,13 @@ namespace FlexFlow { /** * @brief Represent the shape of a ParallelTensor. 
*/ -struct ParallelTensorShape : public use_visitable_cmp { - ParallelTensorShape() = delete; +struct ParallelTensorShape { + ParallelTensorShape(TensorShape const &); template ParallelTensorShape(Dims const &dims, DataType data_type) : dims(dims), data_type(data_type) {} - ParallelTensorShape(TensorShape const &); - int num_dims() const; ParallelDim const &at(ff_dim_t const &) const; @@ -33,9 +31,12 @@ struct ParallelTensorShape : public use_visitable_cmp { ParallelDim &operator[](ff_dim_t const &); public: - ParallelTensorDims dims; - DataType data_type; + req dims; + req data_type; }; +FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(ParallelTensorShape, dims, data_type); +FF_VISIT_FMTABLE(ParallelTensorShape); +CHECK_FMTABLE(ParallelTensorShape); TensorShape get_piece_shape(ParallelTensorShape const &); int get_num_replica_dims(ParallelTensorShape const &); @@ -49,7 +50,4 @@ std::vector } // namespace FlexFlow -VISITABLE_STRUCT(::FlexFlow::ParallelTensorShape, data_type, dims); -MAKE_VISIT_HASHABLE(::FlexFlow::ParallelTensorShape); - #endif diff --git a/lib/op-attrs/include/op-attrs/tensor_shape.h b/lib/op-attrs/include/op-attrs/tensor_shape.h index fa34860817..c8a5638ef6 100644 --- a/lib/op-attrs/include/op-attrs/tensor_shape.h +++ b/lib/op-attrs/include/op-attrs/tensor_shape.h @@ -11,24 +11,35 @@ namespace FlexFlow { using TensorDims = FFOrdered; -struct TensorShape : public use_visitable_cmp { - TensorShape() = delete; - - template - TensorShape(Dims const &dims, DataType data_type) - : dims(dims), data_type(data_type) {} - +struct TensorShape { size_t at(ff_dim_t) const; size_t operator[](ff_dim_t) const; public: - TensorDims dims; - DataType data_type; + req dims; + req data_type; }; +FF_VISITABLE_STRUCT(TensorShape, dims, data_type); +FF_VISIT_FMTABLE(TensorShape); +CHECK_FMTABLE(TensorShape); + } // namespace FlexFlow -VISITABLE_STRUCT(::FlexFlow::TensorShape, dims, data_type); -MAKE_VISIT_HASHABLE(::FlexFlow::TensorShape); +namespace fmt { + +template <> 
+struct formatter<::FlexFlow::TensorDims> : formatter { + template + auto format(::FlexFlow::TensorDims dims, FormatContext &ctx) const + -> decltype(ctx.out()) { + using namespace FlexFlow; + + std::vector v(dims.begin(), dims.end()); + return formatter::format(fmt::to_string(v), ctx); + } +}; + +} // namespace fmt #endif diff --git a/lib/op-attrs/src/attention.cc b/lib/op-attrs/src/attention.cc index e9ae6ec803..b15f787668 100644 --- a/lib/op-attrs/src/attention.cc +++ b/lib/op-attrs/src/attention.cc @@ -26,6 +26,22 @@ int get_oProjSize(MultiHeadAttentionAttrs const &attrs) { return attrs.embed_dim; } +int get_qSize(MultiHeadAttentionInputs const &) { + NOT_IMPLEMENTED(); +} + +int get_kSize(MultiHeadAttentionInputs const &) { + NOT_IMPLEMENTED(); +} + +int get_vSize(MultiHeadAttentionInputs const &) { + NOT_IMPLEMENTED(); +} + +int get_oSize(ParallelTensorShape const &) { + NOT_IMPLEMENTED(); +} + int get_qSize(TensorShape const &query_shape) { return query_shape.at(ff_dim_t(0)); } @@ -80,6 +96,11 @@ TensorShape get_output_shape(MultiHeadAttentionAttrs const &attrs, return get_tensor_shape_unsafe(parallel_shape); } +TensorShape get_output_shape(MultiHeadAttentionAttrs const &, + MultiHeadAttentionInputs const &) { + NOT_IMPLEMENTED(); +} + } // namespace FlexFlow // Tensor FFModel::multihead_attention(const Tensor query, diff --git a/lib/op-attrs/src/conv_2d.cc b/lib/op-attrs/src/conv_2d.cc index d000d31feb..40ba3c8b41 100644 --- a/lib/op-attrs/src/conv_2d.cc +++ b/lib/op-attrs/src/conv_2d.cc @@ -81,6 +81,14 @@ std::vector return mappings; } +TensorShape get_kernel_shape(Conv2DAttrs const &, TensorShape const &) { + NOT_IMPLEMENTED(); +} + +TensorShape get_bias_shape(Conv2DAttrs const &, TensorShape const &) { + NOT_IMPLEMENTED(); +} + /* bool Conv2DAttrs::is_valid(ParallelTensorShape const &input_shape) const { */ /* bool is_valid = true; */ /* is_valid &= input_shape.is_valid(); */ diff --git a/lib/op-attrs/src/embedding.cc b/lib/op-attrs/src/embedding.cc 
index 02cbfaa031..56014fcc67 100644 --- a/lib/op-attrs/src/embedding.cc +++ b/lib/op-attrs/src/embedding.cc @@ -1,3 +1,9 @@ #include "op-attrs/ops/embedding.h" -namespace FlexFlow {} // namespace FlexFlow +namespace FlexFlow { + +TensorShape get_weights_shape(EmbeddingAttrs const &, TensorShape const &) { + NOT_IMPLEMENTED(); +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/get_op_type.cc b/lib/op-attrs/src/get_op_type.cc index 7e3235fa9d..16f92f4b1e 100644 --- a/lib/op-attrs/src/get_op_type.cc +++ b/lib/op-attrs/src/get_op_type.cc @@ -32,6 +32,9 @@ OperatorType get_op_type(DropoutAttrs const &) { OperatorType get_op_type(ElementBinaryAttrs const &attrs) { return attrs.type; } +OperatorType get_op_type(ElementScalarUnaryAttrs const &) { + NOT_IMPLEMENTED(); +} OperatorType get_op_type(ElementUnaryAttrs const &attrs) { return attrs.op; } @@ -71,6 +74,9 @@ OperatorType get_op_type(ReduceAttrs const &) { OperatorType get_op_type(ReshapeAttrs const &) { return Op::RESHAPE; } +OperatorType get_op_type(ReverseAttrs const &) { + return Op::REVERSE; +} OperatorType get_op_type(SplitAttrs const &) { return Op::SPLIT; } diff --git a/lib/op-attrs/src/get_output_shapes.cc b/lib/op-attrs/src/get_output_shapes.cc index f44a677873..b3a38c167c 100644 --- a/lib/op-attrs/src/get_output_shapes.cc +++ b/lib/op-attrs/src/get_output_shapes.cc @@ -2,8 +2,13 @@ namespace FlexFlow { -ParallelTensorShape as_parallel(TensorShape const &); -std::vector as_parallel(std::vector const &); +ParallelTensorShape as_parallel(TensorShape const &) { + NOT_IMPLEMENTED(); +} + +std::vector as_parallel(std::vector const &) { + NOT_IMPLEMENTED(); +} TensorShape get_output_shape(AggregateAttrs const &attrs, TensorShape const &gate_preds, @@ -20,4 +25,61 @@ TensorShape get_output_shape(AggregateAttrs const &attrs, as_parallel(exp_preds))); } +ParallelTensorShape get_output_shape(ElementBinaryAttrs const &, + ParallelTensorShape const &, + ParallelTensorShape const &) { + NOT_IMPLEMENTED(); +} + 
+// FIXME: These are added to get rid of the linker errors about missing +// definitions. +template <> +TensorShape get_output_shape(BatchNormAttrs const &, TensorShape const &) { + NOT_IMPLEMENTED(); +} + +template <> +TensorShape get_output_shape(Conv2DAttrs const &, TensorShape const &) { + NOT_IMPLEMENTED(); +} + +template <> +TensorShape get_output_shape(DropoutAttrs const &, TensorShape const &) { + NOT_IMPLEMENTED(); +} + +template <> +TensorShape get_output_shape(ElementBinaryAttrs const &, + TensorShape const &, + TensorShape const &) { + NOT_IMPLEMENTED(); +} + +template <> +TensorShape get_output_shape(EmbeddingAttrs const &, TensorShape const &) { + NOT_IMPLEMENTED(); +} + +template <> +TensorShape FlexFlow::get_output_shape( + variant const &, + TensorShape const &) { + NOT_IMPLEMENTED(); +} + +template <> +std::vector get_output_shapes(ElementBinaryAttrs const &attrs, + TensorShape const &, + TensorShape const &) { + NOT_IMPLEMENTED(); +} + +template <> +std::vector get_output_shapes(GatherAttrs const &attrs, + TensorShape const &, + TensorShape const &) { + NOT_IMPLEMENTED(); +} + } // namespace FlexFlow diff --git a/lib/op-attrs/src/parallel_dim.cc b/lib/op-attrs/src/parallel_dim.cc index cb9c603508..7e961880c9 100644 --- a/lib/op-attrs/src/parallel_dim.cc +++ b/lib/op-attrs/src/parallel_dim.cc @@ -6,4 +6,8 @@ bool is_valid(ParallelDim const &dim) { return dim.size > 0 && dim.degree >= 1 && dim.size % dim.degree == 0; } +bool is_replica_dim(ParallelDim const &) { + NOT_IMPLEMENTED(); +} + } // namespace FlexFlow diff --git a/lib/op-attrs/src/parallel_dim_mapping_record_solver.cc b/lib/op-attrs/src/parallel_dim_mapping_record_solver.cc index 52df0fe0fa..c36cd8c59a 100644 --- a/lib/op-attrs/src/parallel_dim_mapping_record_solver.cc +++ b/lib/op-attrs/src/parallel_dim_mapping_record_solver.cc @@ -183,6 +183,22 @@ void construct_weight_parallel_dims( input_idx, input_dim, weight_idx, weight_dim, operation)); } +ParallelDimMappingRecord + 
construct_weight_parallel_dims(int input_dim, + int weight_dim, + int input_idx, + int weight_idx, + tl::optional operation) { + NOT_IMPLEMENTED(); +} + +std::vector construct_weight_parallel_dims( + std::vector> mappings, + int input_idx, + int weight_idx) { + NOT_IMPLEMENTED(); +} + /* void ParallelDimMappingRecordSolver::register_weight_parallel_dims( */ /* std::vector> mappings, int input_idx, int weight_idx) * { */ @@ -226,6 +242,15 @@ void construct_output_parallel_dims( } } +ParallelDimMappingRecord + construct_output_parallel_dims(int input_dim, + int output_dim, + int input_idx, + int output_idx, + tl::optional operation) { + NOT_IMPLEMENTED(); +} + void construct_output_parallel_dims( std::vector &records, std::vector> mappings, @@ -248,6 +273,18 @@ void construct_output_parallel_dims( input_idx, input_dim, output_idx, output_dim, operation)); } +std::vector construct_output_parallel_dims( + std::vector> mappings, int input_idx, int output_idx) { + NOT_IMPLEMENTED(); +} + +std::vector construct_output_parallel_dims( + std::vector> mappings, + int input_idx, + int output_idx) { + NOT_IMPLEMENTED(); +} + /* void register_output_parallel_dims( */ /* std::vector> mappings, int input_idx, int output_idx) * { */ @@ -320,4 +357,16 @@ void construct_output_parallel_dims( /* return solution; */ /* } */ +ParallelDimMappingSolution solve_parallel_dim_mappings( + std::vector const &mappings, + std::vector const &input, + int numWeights, + int numOutputs) { + // There is a definition of this function earlier in this file that might be + // the actual implementation, but is commented out for whatever reason. Rather + // than enabling that, just fail for now so someone else can take a look and + // decide whether to enable it. 
+ NOT_IMPLEMENTED(); +} + } // namespace FlexFlow diff --git a/lib/op-attrs/src/parallel_tensor_shape.cc b/lib/op-attrs/src/parallel_tensor_shape.cc index 9a36e7d11b..dca2dd50af 100644 --- a/lib/op-attrs/src/parallel_tensor_shape.cc +++ b/lib/op-attrs/src/parallel_tensor_shape.cc @@ -16,9 +16,97 @@ static std::vector lift_dims(TensorDims const &dims) { ParallelTensorDims::ParallelTensorDims(TensorDims const &dims) : data(lift_dims(dims)) {} +size_t ParallelTensorDims::get_volume() const { + NOT_IMPLEMENTED(); +} + +size_t ParallelTensorDims::num_dims() const { + NOT_IMPLEMENTED(); +} + +ParallelDim const &ParallelTensorDims::at(ff_dim_t const &) const { + NOT_IMPLEMENTED(); +} + +ParallelDim &ParallelTensorDims::at(ff_dim_t const &) { + NOT_IMPLEMENTED(); +} + +ParallelTensorDims::iterator ParallelTensorDims::begin() { + NOT_IMPLEMENTED(); +} + +ParallelTensorDims::const_iterator ParallelTensorDims::begin() const { + NOT_IMPLEMENTED(); +} + +ParallelTensorDims::const_iterator ParallelTensorDims::cbegin() const { + NOT_IMPLEMENTED(); +} + +ParallelTensorDims::iterator ParallelTensorDims::end() { + NOT_IMPLEMENTED(); +} + +ParallelTensorDims::const_iterator ParallelTensorDims::end() const { + NOT_IMPLEMENTED(); +} + +ParallelTensorDims::const_iterator ParallelTensorDims::cend() const { + NOT_IMPLEMENTED(); +} + +ParallelTensorDims::reverse_iterator ParallelTensorDims::rbegin() { + NOT_IMPLEMENTED(); +} + +ParallelTensorDims::const_reverse_iterator ParallelTensorDims::rbegin() const { + NOT_IMPLEMENTED(); +} + +ParallelTensorDims::const_reverse_iterator ParallelTensorDims::crbegin() const { + NOT_IMPLEMENTED(); +} + +ParallelTensorDims::reverse_iterator ParallelTensorDims::rend() { + NOT_IMPLEMENTED(); +} + +ParallelTensorDims::const_reverse_iterator ParallelTensorDims::rend() const { + NOT_IMPLEMENTED(); +} + +ParallelTensorDims::const_reverse_iterator ParallelTensorDims::crend() const { + NOT_IMPLEMENTED(); +} + ParallelTensorShape::ParallelTensorShape(TensorShape 
const &tensor_shape) : dims(tensor_shape.dims), data_type(tensor_shape.data_type) {} +int ParallelTensorShape::num_dims() const { + NOT_IMPLEMENTED(); +} + +ParallelDim const &ParallelTensorShape::at(ff_dim_t const &) const { + NOT_IMPLEMENTED(); +} + +ParallelDim &ParallelTensorShape::at(ff_dim_t const &) { + NOT_IMPLEMENTED(); +} + +ParallelDim const &ParallelTensorShape::operator[](ff_dim_t const &) const { + NOT_IMPLEMENTED(); +} + +ParallelDim &ParallelTensorShape::operator[](ff_dim_t const &) { + NOT_IMPLEMENTED(); +} + +TensorShape get_piece_shape(ParallelTensorShape const &) { + NOT_IMPLEMENTED(); +} + int get_num_replica_dims(ParallelTensorShape const &shape) { return count(shape.dims, is_replica_dim); } @@ -37,4 +125,13 @@ bool is_valid(ParallelTensorShape const &shape) { return is_valid(shape.dims); } +TensorShape get_tensor_shape_unsafe(ParallelTensorShape const &) { + NOT_IMPLEMENTED(); +} + +std::vector + get_tensor_shapes_unsafe(std::vector const &) { + NOT_IMPLEMENTED(); +} + } // namespace FlexFlow diff --git a/lib/op-attrs/src/tensor_shape.cc b/lib/op-attrs/src/tensor_shape.cc index d31fb7cf21..0a8b1d1e19 100644 --- a/lib/op-attrs/src/tensor_shape.cc +++ b/lib/op-attrs/src/tensor_shape.cc @@ -1,3 +1,13 @@ #include "op-attrs/tensor_shape.h" -namespace FlexFlow {} +namespace FlexFlow { + +size_t TensorShape::at(ff_dim_t) const { + NOT_IMPLEMENTED(); +} + +size_t TensorShape::operator[](ff_dim_t) const { + NOT_IMPLEMENTED(); +} + +} // namespace FlexFlow diff --git a/lib/pcg/CMakeLists.txt b/lib/pcg/CMakeLists.txt index 81009b0f1f..e1875ca694 100644 --- a/lib/pcg/CMakeLists.txt +++ b/lib/pcg/CMakeLists.txt @@ -13,3 +13,4 @@ ff_add_library( ) add_subdirectory(ffi) +add_subdirectory(test) diff --git a/lib/pcg/include/pcg/computation_graph_builder.h b/lib/pcg/include/pcg/computation_graph_builder.h index 7f01439712..2722214251 100644 --- a/lib/pcg/include/pcg/computation_graph_builder.h +++ b/lib/pcg/include/pcg/computation_graph_builder.h @@ -5,11 +5,8 
@@ namespace FlexFlow { -struct ComputationGraphBuilder - : public use_visitable_cmp { +struct ComputationGraphBuilder { public: - ComputationGraphBuilder(); - // C++ APIs for constructing models // Add an exp layer Tensor exp(Tensor const &, optional const &name = nullopt); @@ -280,16 +277,11 @@ struct ComputationGraphBuilder optional const &name = nullopt); public: - ComputationGraph computation_graph; + req computation_graph; }; -} // namespace FlexFlow - -VISITABLE_STRUCT(::FlexFlow::ComputationGraphBuilder, computation_graph); +FF_VISITABLE_STRUCT(ComputationGraphBuilder, computation_graph); -namespace FlexFlow { -static_assert( - is_well_behaved_value_type_no_hash::value, ""); -} +} // namespace FlexFlow #endif diff --git a/lib/pcg/include/pcg/file_format/v1/activation.h b/lib/pcg/include/pcg/file_format/v1/activation.h new file mode 100644 index 0000000000..ca12f27a56 --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/activation.h @@ -0,0 +1,22 @@ +#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_ACTIVATION_H +#define _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_ACTIVATION_H + +#include "op-attrs/activation.h" +#include "utils/json.h" + +namespace FlexFlow { + +enum class V1Activation { RELU, SIGMOID, TANH, GELU }; + +NLOHMANN_JSON_SERIALIZE_ENUM(V1Activation, + {{V1Activation::RELU, "RELU"}, + {V1Activation::SIGMOID, "SIGMOID"}, + {V1Activation::TANH, "TANH"}, + {V1Activation::GELU, "GELU"}}); + +V1Activation to_v1(Activation const &a); +Activation from_v1(V1Activation const &va); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/file_format/v1/create_grad.h b/lib/pcg/include/pcg/file_format/v1/create_grad.h new file mode 100644 index 0000000000..de6678f6e4 --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/create_grad.h @@ -0,0 +1,20 @@ +#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_CREATE_GRAD_H +#define _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_CREATE_GRAD_H + +#include "pcg/create_grad.h" +#include "utils/json.h" + +namespace 
FlexFlow { + +enum class V1CreateGrad { YES, NO }; + +NLOHMANN_JSON_SERIALIZE_ENUM(V1CreateGrad, + {{V1CreateGrad::YES, "YES"}, + {V1CreateGrad::NO, "NO"}}); + +V1CreateGrad to_v1(CreateGrad const &cg); +CreateGrad from_v1(V1CreateGrad const &vcg); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/file_format/v1/data_type.h b/lib/pcg/include/pcg/file_format/v1/datatype.h similarity index 86% rename from lib/pcg/include/pcg/file_format/v1/data_type.h rename to lib/pcg/include/pcg/file_format/v1/datatype.h index dad98e462d..a15a04945a 100644 --- a/lib/pcg/include/pcg/file_format/v1/data_type.h +++ b/lib/pcg/include/pcg/file_format/v1/datatype.h @@ -1,6 +1,7 @@ #ifndef _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_DATA_TYPE_H #define _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_DATA_TYPE_H +#include "op-attrs/datatype.h" #include "utils/fp16.h" #include "utils/json.h" #include "utils/variant.h" @@ -26,6 +27,12 @@ NLOHMANN_JSON_SERIALIZE_ENUM(V1DataType, {V1DataType::FLOAT, "FLOAT"}, {V1DataType::DOUBLE, "DOUBLE"}}); +V1DataType to_v1(DataType const &d); +DataType from_v1(V1DataType const &vd); + +V1DataTypeValue to_v1(DataTypeValue const &dv); +DataTypeValue from_v1(V1DataTypeValue const &vdv); + } // namespace FlexFlow namespace nlohmann { diff --git a/lib/pcg/include/pcg/file_format/v1/dim_ordered.h b/lib/pcg/include/pcg/file_format/v1/dim_ordered.h new file mode 100644 index 0000000000..d175005c5a --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/dim_ordered.h @@ -0,0 +1,35 @@ +#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_DIM_ORDERED_H +#define _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_DIM_ORDERED_H + +#include "op-attrs/dim_ordered.h" +#include "utils/json.h" + +namespace FlexFlow { + +template +struct V1DimOrdered { + std::vector contents; + + bool operator!=(V1DimOrdered const &o) { + return this->contents != o.contents; + } +}; + +template +V1DimOrdered to_v1(DimOrdered const &dim); + +template +DimOrdered from_v1(V1DimOrdered const &vdim); + 
+template +using V1FFOrdered = V1DimOrdered; + +template +V1FFOrdered to_v1(FFOrdered const &o); + +template +FFOrdered from_v1(V1FFOrdered const &vo); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/file_format/v1/ff_dim.h b/lib/pcg/include/pcg/file_format/v1/ff_dim.h new file mode 100644 index 0000000000..8e84599024 --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/ff_dim.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_FF_DIM_H +#define _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_FF_DIM_H + +#include "op-attrs/ff_dim.h" + +namespace FlexFlow { + +// ff_dim_t is a strong typedef of int. This is unlikely to change, but if it +// does, this signature will need to be updated. +int to_v1(ff_dim_t const &t); +ff_dim_t from_v1(int const &vt); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/file_format/v1/graphs.h b/lib/pcg/include/pcg/file_format/v1/graphs.h index 71a8adb344..c243f6613e 100644 --- a/lib/pcg/include/pcg/file_format/v1/graphs.h +++ b/lib/pcg/include/pcg/file_format/v1/graphs.h @@ -25,6 +25,7 @@ struct V1GraphEdge { req dstNode; req dstIdx; }; + FF_VISITABLE_STRUCT(V1GraphEdge, srcNode, srcIdx, dstNode, dstIdx); CHECK_IS_JSONABLE(V1GraphEdge); @@ -33,12 +34,17 @@ struct V1MultiDiGraph { req> ports; req> edges; }; + FF_VISITABLE_STRUCT(V1MultiDiGraph, nodes, ports, edges); CHECK_IS_JSONABLE(V1MultiDiGraph); + V1MultiDiGraph to_v1(MultiDiGraphView const &); +// FIXME: Add a from_v1 for a MultiDiGraph. + V1MultiDiGraph to_v1(MultiDiGraphView const &, std::unordered_map const &, std::unordered_map const &); +// FIXME: Do we need to add an equivalent from_v1 for this? 
template struct V1JsonableGraph { @@ -56,20 +62,35 @@ struct V1Layer { req> name; }; FF_VISITABLE_STRUCT(V1Layer, attrs, name); + V1Layer to_v1(Layer const &); +Layer from_v1(V1Layer const &); using V1ComputationGraph = V1JsonableGraph; FF_VISITABLE_STRUCT( V1ComputationGraph, node_labels, outputs, output_labels, graph); CHECK_IS_JSONABLE(V1ComputationGraph); + V1ComputationGraph to_v1(ComputationGraph const &); +ComputationGraph from_V1(V1ComputationGraph const &); + +struct V1Operator { + V1PCGOperatorAttrs attrs; + req> name; +}; +FF_VISITABLE_STRUCT(V1Operator, attrs, name); + +V1Operator to_v1(Operator const &); +Operator from_v1(V1Operator const &); using V1ParallelComputationGraph = - V1JsonableGraph; + V1JsonableGraph; FF_VISITABLE_STRUCT( V1ParallelComputationGraph, node_labels, outputs, output_labels, graph); CHECK_IS_JSONABLE(V1ParallelComputationGraph); + V1ParallelComputationGraph to_v1(ParallelComputationGraph const &); +ParallelComputationGraph from_v1(V1ParallelComputationGraph const &); } // namespace FlexFlow diff --git a/lib/pcg/include/pcg/file_format/v1/initializer.h b/lib/pcg/include/pcg/file_format/v1/initializer.h index 24f0320bd9..08f001e7a7 100644 --- a/lib/pcg/include/pcg/file_format/v1/initializer.h +++ b/lib/pcg/include/pcg/file_format/v1/initializer.h @@ -1,7 +1,8 @@ #ifndef _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_INITIALIZER_H #define _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_INITIALIZER_H -#include "data_type.h" +#include "datatype.h" +#include "pcg/initializer.h" #include "utils/json.h" #include "utils/required.h" #include "utils/variant.h" @@ -14,16 +15,27 @@ struct V1GlorotInitializer { req seed; }; FF_VISITABLE_STRUCT(V1GlorotInitializer, seed); +CHECK_IS_JSONABLE(V1GlorotInitializer); + +V1GlorotInitializer to_v1(GlorotUniform const &i); +GlorotUniform from_v1(V1GlorotInitializer const &vi); struct V1ZeroInitializer {}; FF_VISITABLE_STRUCT(V1ZeroInitializer); +V1ZeroInitializer to_v1(ZeroInitializer const &i); +ZeroInitializer 
from_v1(V1ZeroInitializer const &vi); + struct V1UniformInitializer { req seed; req min_val; req max_val; }; FF_VISITABLE_STRUCT(V1UniformInitializer, seed, min_val, max_val); +CHECK_IS_JSONABLE(V1UniformInitializer); + +V1UniformInitializer to_v1(UniformInitializer const &i); +UniformInitializer from_v1(V1UniformInitializer const &vi); struct V1NormInitializer { req seed; @@ -31,11 +43,19 @@ struct V1NormInitializer { req stddev; }; FF_VISITABLE_STRUCT(V1NormInitializer, seed, mean, stddev); +CHECK_IS_JSONABLE(V1NormInitializer); + +V1NormInitializer to_v1(NormInitializer const &i); +NormInitializer from_v1(V1NormInitializer const &vi); struct V1ConstantInitializer { req value; }; FF_VISITABLE_STRUCT(V1ConstantInitializer, value); +CHECK_IS_JSONABLE(V1ConstantInitializer); + +V1ConstantInitializer to_v1(ConstantInitializer const &i); +ConstantInitializer from_v1(V1ConstantInitializer const &vi); using V1Initializer = variant; -} // namespace FlexFlow +V1Initializer to_v1(Initializer const &i); +Initializer from_v1(V1Initializer const &vi); -namespace FlexFlow { -CHECK_IS_JSONABLE(V1GlorotInitializer); -CHECK_IS_JSONABLE(V1ZeroInitializer); -CHECK_IS_JSONABLE(V1UniformInitializer); -CHECK_IS_JSONABLE(V1NormInitializer); -CHECK_IS_JSONABLE(V1ConstantInitializer); -CHECK_IS_JSONABLE(V1Initializer); } // namespace FlexFlow #endif diff --git a/lib/pcg/include/pcg/file_format/v1/op.h b/lib/pcg/include/pcg/file_format/v1/op.h new file mode 100644 index 0000000000..e4010b10aa --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/op.h @@ -0,0 +1,193 @@ +#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_OP_H +#define _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_OP_H + +#include "op-attrs/op.h" +#include "utils/json.h" + +namespace FlexFlow { + +enum class V1Op { + NOOP, + INPUT, + WEIGHT, + CONV2D, + DROPOUT, + LINEAR, + BATCHMATMUL, + POOL2D, + SCALAR_MULTIPLY, + SCALAR_ADD, + SCALAR_FLOOR_DIV, + SCALAR_TRUE_DIV, + SCALAR_SUB, + RELU, + IDENTITY, + SIGMOID, + TANH, + ELU, + 
FLAT, + SOFTMAX, + BATCHNORM, + CONCAT, + SPLIT, + EMBEDDING, + GROUP_BY, + CACHE, + AGGREGATE, + AGG_SPEC, + // OP_ELEMENTWISE, + RESHAPE, + REVERSE, + TRANSPOSE, + EW_ADD, + EW_MUL, + MATMUL, + MUL, + ENLARGE, + SQUEEZE, + UNSQUEEZE, + EW_SUB, + EW_DIV, + EW_EQUAL, + EW_GREATER, + EW_LESS, + EW_MAX, + EW_MIN, + REDUCE_ARGMAX, + REDUCE_ARGMIN, + REDUCE_MAX, + REDUCE_MEAN, + REDUCE_MIN, + REDUCE_PROD, + REDUCE_SUM, + PAD, + SHAPE, + SIZE, + TOPK, + WHERE, + CEIL, + CAST, + EXP, + ROUND, + LOG, + LOGICAL_NOT, + SQRT, + SIN, + COS, + LEAKYRELU, + SLICE, + RESIZE, + PRELU, + GELU, + MULTIHEAD_ATTENTION, + FUSED, + RSQRT, + POW, + MEAN, + LAYERNORM, + GATHER, + BROADCAST, + REPARTITION, + COMBINE, + REPLICATE, + REDUCTION, + BATCH, + PIPELINE, + FUSED_PARALLEL, +}; + +NLOHMANN_JSON_SERIALIZE_ENUM(V1Op, + {{V1Op::NOOP, "NOOP"}, + {V1Op::INPUT, "INPUT"}, + {V1Op::WEIGHT, "WEIGHT"}, + {V1Op::CONV2D, "CONV2D"}, + {V1Op::DROPOUT, "DROPOUT"}, + {V1Op::LINEAR, "LINEAR"}, + {V1Op::BATCHMATMUL, "BATCHMATMUL"}, + {V1Op::POOL2D, "POOL2D"}, + {V1Op::SCALAR_MULTIPLY, "SCALAR_MULTIPLY"}, + {V1Op::SCALAR_ADD, "SCALAR_ADD"}, + {V1Op::SCALAR_FLOOR_DIV, "SCALAR_FLOOR_DIV"}, + {V1Op::SCALAR_TRUE_DIV, "SCALAR_TRUE_DIV"}, + {V1Op::SCALAR_SUB, "SCALAR_SUB"}, + {V1Op::RELU, "RELU"}, + {V1Op::IDENTITY, "IDENTITY"}, + {V1Op::SIGMOID, "SIGMOID"}, + {V1Op::TANH, "TANH"}, + {V1Op::ELU, "ELU"}, + {V1Op::FLAT, "FLAT"}, + {V1Op::SOFTMAX, "SOFTMAX"}, + {V1Op::BATCHNORM, "BATCHNORM"}, + {V1Op::CONCAT, "CONCAT"}, + {V1Op::SPLIT, "SPLIT"}, + {V1Op::EMBEDDING, "EMBEDDING"}, + {V1Op::GROUP_BY, "GROUP_BY"}, + {V1Op::CACHE, "CACHE"}, + {V1Op::AGGREGATE, "AGGREGATE"}, + {V1Op::AGG_SPEC, "AGG_SPEC"}, + {V1Op::RESHAPE, "RESHAPE"}, + {V1Op::REVERSE, "REVERSE"}, + {V1Op::TRANSPOSE, "TRANSPOSE"}, + {V1Op::EW_ADD, "EW_ADD"}, + {V1Op::EW_MUL, "EW_MUL"}, + {V1Op::MATMUL, "MATMUL"}, + {V1Op::MUL, "MUL"}, + {V1Op::ENLARGE, "ENLARGE"}, + {V1Op::SQUEEZE, "SQUEEZE"}, + {V1Op::UNSQUEEZE, "UNSQUEEZE"}, + {V1Op::EW_SUB, 
"EW_SUB"}, + {V1Op::EW_DIV, "EW_DIV"}, + {V1Op::EW_EQUAL, "EW_EQUAL"}, + {V1Op::EW_GREATER, "EW_GREATER"}, + {V1Op::EW_LESS, "EW_LESS"}, + {V1Op::EW_MAX, "EW_MAX"}, + {V1Op::EW_MIN, "EW_MIN"}, + {V1Op::REDUCE_ARGMAX, "REDUCE_ARGMAX"}, + {V1Op::REDUCE_ARGMIN, "REDUCE_ARGMIN"}, + {V1Op::REDUCE_MAX, "REDUCE_MAX"}, + {V1Op::REDUCE_MEAN, "REDUCE_MEAN"}, + {V1Op::REDUCE_MIN, "REDUCE_MIN"}, + {V1Op::REDUCE_PROD, "REDUCE_PROD"}, + {V1Op::REDUCE_SUM, "REDUCE_SUM"}, + {V1Op::PAD, "PAD"}, + {V1Op::SHAPE, "SHAPE"}, + {V1Op::SIZE, "SIZE"}, + {V1Op::TOPK, "TOPK"}, + {V1Op::WHERE, "WHERE"}, + {V1Op::CEIL, "CEIL"}, + {V1Op::CAST, "CAST"}, + {V1Op::EXP, "EXP"}, + {V1Op::ROUND, "ROUND"}, + {V1Op::LOG, "LOG"}, + {V1Op::LOGICAL_NOT, "LOGICAL_NOT"}, + {V1Op::SQRT, "SQRT"}, + {V1Op::SIN, "SIN"}, + {V1Op::COS, "COS"}, + {V1Op::LEAKYRELU, "LEAKYRELU"}, + {V1Op::SLICE, "SLICE"}, + {V1Op::RESIZE, "RESIZE"}, + {V1Op::PRELU, "PRELU"}, + {V1Op::GELU, "GELU"}, + {V1Op::MULTIHEAD_ATTENTION, + "MULTIHEAD_ATTENTION"}, + {V1Op::FUSED, "FUSED"}, + {V1Op::RSQRT, "RSQRT"}, + {V1Op::POW, "POW"}, + {V1Op::MEAN, "MEAN"}, + {V1Op::LAYERNORM, "LAYERNORM"}, + {V1Op::GATHER, "GATHER"}, + {V1Op::BROADCAST, "BROADCAST"}, + {V1Op::REPARTITION, "REPARTITION"}, + {V1Op::COMBINE, "COMBINE"}, + {V1Op::REPLICATE, "REPLICATE"}, + {V1Op::REDUCTION, "REDUCTION"}, + {V1Op::BATCH, "BATCH"}, + {V1Op::PIPELINE, "PIPELINE"}, + {V1Op::FUSED_PARALLEL, "FUSED_PARALLEL"}}); + +V1Op to_v1(Op const &op); +Op from_v1(V1Op const &vop); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/file_format/v1/operator_attrs.h b/lib/pcg/include/pcg/file_format/v1/operator_attrs.h index 2ea87cbf56..71dd4e4ce6 100644 --- a/lib/pcg/include/pcg/file_format/v1/operator_attrs.h +++ b/lib/pcg/include/pcg/file_format/v1/operator_attrs.h @@ -1,19 +1,90 @@ #ifndef _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_OPERATOR_ATTRS_H #define _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_OPERATOR_ATTRS_H +#include "op-attrs/operator_attrs.h" +#include 
"ops/aggregate.h" +#include "ops/aggregate_spec.h" +#include "ops/attention.h" +#include "ops/batch_matmul.h" +#include "ops/batch_norm.h" +#include "ops/broadcast.h" +#include "ops/cast.h" +#include "ops/combine.h" +#include "ops/concat.h" +#include "ops/conv_2d.h" +#include "ops/dropout.h" +#include "ops/element_binary.h" +#include "ops/element_unary.h" +#include "ops/embedding.h" +#include "ops/flat.h" +#include "ops/gather.h" +#include "ops/groupby.h" +#include "ops/input.h" +#include "ops/layer_norm.h" +#include "ops/linear.h" +#include "ops/noop.h" +#include "ops/pool_2d.h" +#include "ops/reduce.h" +#include "ops/reduction.h" +#include "ops/repartition.h" +#include "ops/replicate.h" +#include "ops/reshape.h" +#include "ops/reverse.h" +#include "ops/softmax.h" +#include "ops/split.h" +#include "ops/topk.h" +#include "ops/transpose.h" #include "utils/json.h" #include "utils/variant.h" namespace FlexFlow { -struct V1Conv2DAttrs {}; -FF_VISITABLE_STRUCT(V1Conv2DAttrs); +using V1SharedOperatorAttrs = variant; -static_assert( - std::is_same, std::tuple<>>::value, ""); +using V1ParallelOperatorAttrs = variant; -using V1CompGraphOperatorAttrs = variant; -using V1PCGOperatorAttrs = variant; +using V1ComputationGraphAttrs = + variant_join>; +using V1CompGraphOperatorAttrs = V1ComputationGraphAttrs; + +V1CompGraphOperatorAttrs to_v1(CompGraphOperatorAttrs const &attrs); +CompGraphOperatorAttrs from_v1(V1CompGraphOperatorAttrs const &attrs); + +using V1PCGOperatorAttrs = + variant_join; + +V1PCGOperatorAttrs to_v1(PCGOperatorAttrs const &attrs); +PCGOperatorAttrs from_v1(V1PCGOperatorAttrs const &attrs); } // namespace FlexFlow diff --git a/lib/pcg/include/pcg/file_format/v1/ops/aggregate.h b/lib/pcg/include/pcg/file_format/v1/ops/aggregate.h new file mode 100644 index 0000000000..f65350d70d --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/ops/aggregate.h @@ -0,0 +1,22 @@ +#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_OPS_AGGREGATE_H +#define 
_FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_OPS_AGGREGATE_H + +#include "op-attrs/ops/aggregate.h" +#include "utils/json.h" +#include "utils/visitable.h" + +namespace FlexFlow { + +struct V1AggregateAttrs { + req n; + req lambda_bal; +}; +FF_VISITABLE_STRUCT(V1AggregateAttrs, n, lambda_bal); +CHECK_IS_JSONABLE(V1AggregateAttrs); + +V1AggregateAttrs to_v1(AggregateAttrs const &a); +AggregateAttrs from_v1(V1AggregateAttrs const &va); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/file_format/v1/ops/aggregate_spec.h b/lib/pcg/include/pcg/file_format/v1/ops/aggregate_spec.h new file mode 100644 index 0000000000..1ad7c823bb --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/ops/aggregate_spec.h @@ -0,0 +1,22 @@ +#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_OPS_AGGREGATE_SPEC_ATTRS_H +#define _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_OPS_AGGREGATE_SPEC_ATTRS_H + +#include "op-attrs/ops/aggregate_spec.h" +#include "utils/json.h" +#include "utils/visitable.h" + +namespace FlexFlow { + +struct V1AggregateSpecAttrs { + req n; + req lambda_bal; +}; +FF_VISITABLE_STRUCT(V1AggregateSpecAttrs, n, lambda_bal); +CHECK_IS_JSONABLE(V1AggregateSpecAttrs); + +V1AggregateSpecAttrs to_v1(AggregateSpecAttrs const &a); +AggregateSpecAttrs from_v1(V1AggregateSpecAttrs const &va); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/file_format/v1/ops/attention.h b/lib/pcg/include/pcg/file_format/v1/ops/attention.h new file mode 100644 index 0000000000..25e66e7813 --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/ops/attention.h @@ -0,0 +1,30 @@ +#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_OPS_ATTENTION_ATTRS_H +#define _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_OPS_ATTENTION_ATTRS_H + +#include "op-attrs/ops/attention.h" +#include "utils/visitable.h" + +namespace FlexFlow { + +struct V1MultiHeadAttentionAttrs { + req embed_dim, num_heads, kdim, vdim; + req dropout; + req bias, add_bias_kv, add_zero_attn; +}; 
+FF_VISITABLE_STRUCT(V1MultiHeadAttentionAttrs, + embed_dim, + num_heads, + kdim, + vdim, + dropout, + bias, + add_bias_kv, + add_zero_attn); +CHECK_IS_JSONABLE(V1MultiHeadAttentionAttrs); + +V1MultiHeadAttentionAttrs to_v1(MultiHeadAttentionAttrs const &a); +MultiHeadAttentionAttrs from_v1(V1MultiHeadAttentionAttrs const &va); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/file_format/v1/ops/batch_matmul.h b/lib/pcg/include/pcg/file_format/v1/ops/batch_matmul.h new file mode 100644 index 0000000000..744cebba52 --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/ops/batch_matmul.h @@ -0,0 +1,20 @@ +#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_OPS_BATCH_MATMUL_ATTRS_H +#define _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_OPS_BATCH_MATMUL_ATTRS_H + +#include "op-attrs/ops/batch_matmul.h" +#include "utils/visitable.h" + +namespace FlexFlow { + +struct V1BatchMatmulAttrs { + req a_seq_length_dim, b_seq_length_dim; +}; +FF_VISITABLE_STRUCT(V1BatchMatmulAttrs, a_seq_length_dim, b_seq_length_dim); +CHECK_IS_JSONABLE(V1BatchMatmulAttrs); + +V1BatchMatmulAttrs to_v1(BatchMatmulAttrs const &a); +BatchMatmulAttrs from_v1(V1BatchMatmulAttrs const &va); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/file_format/v1/ops/batch_norm.h b/lib/pcg/include/pcg/file_format/v1/ops/batch_norm.h new file mode 100644 index 0000000000..b78fc8eace --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/ops/batch_norm.h @@ -0,0 +1,20 @@ +#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_OPS_BATCH_NORM_H +#define _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_OPS_BATCH_NORM_H + +#include "op-attrs/ops/batch_norm.h" +#include "utils/visitable.h" + +namespace FlexFlow { + +struct V1BatchNormAttrs { + req relu; +}; +FF_VISITABLE_STRUCT(V1BatchNormAttrs, relu); +CHECK_IS_JSONABLE(V1BatchNormAttrs); + +V1BatchNormAttrs to_v1(BatchNormAttrs const &a); +BatchNormAttrs from_v1(V1BatchNormAttrs const &va); + +} // namespace FlexFlow + +#endif diff --git 
a/lib/pcg/include/pcg/file_format/v1/ops/broadcast.h b/lib/pcg/include/pcg/file_format/v1/ops/broadcast.h new file mode 100644 index 0000000000..cb5b25bf12 --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/ops/broadcast.h @@ -0,0 +1,22 @@ +#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_OPS_BROADCAST_H +#define _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_OPS_BROADCAST_H + +#include "op-attrs/ops/broadcast.h" +#include "utils/json.h" +#include "utils/visitable.h" + +namespace FlexFlow { + +struct V1BroadcastAttrs { + // The size of this vector must be <= MAX_TENSOR_DIM + req> target_dims; +}; +FF_VISITABLE_STRUCT(V1BroadcastAttrs, target_dims); +CHECK_IS_JSONABLE(V1BroadcastAttrs); + +V1BroadcastAttrs to_v1(BroadcastAttrs const &a); +BroadcastAttrs from_v1(V1BroadcastAttrs const &va); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/file_format/v1/ops/cast.h b/lib/pcg/include/pcg/file_format/v1/ops/cast.h new file mode 100644 index 0000000000..54a343ef6d --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/ops/cast.h @@ -0,0 +1,22 @@ +#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_OPS_CAST_ATTRS_H +#define _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_OPS_CAST_ATTRS_H + +#include "op-attrs/ops/cast.h" +#include "pcg/file_format/v1/datatype.h" +#include "utils/json.h" +#include "utils/visitable.h" + +namespace FlexFlow { + +struct V1CastAttrs { + req dtype; +}; +FF_VISITABLE_STRUCT(V1CastAttrs, dtype); +CHECK_IS_JSONABLE(V1CastAttrs); + +V1CastAttrs to_v1(CastAttrs const &a); +CastAttrs from_v1(V1CastAttrs const &va); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/file_format/v1/ops/combine.h b/lib/pcg/include/pcg/file_format/v1/ops/combine.h new file mode 100644 index 0000000000..a029aa72c8 --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/ops/combine.h @@ -0,0 +1,21 @@ +#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_OPS_COMBINE_ATTRS_H +#define _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_OPS_COMBINE_ATTRS_H + 
+#include "op-attrs/ops/combine.h" +#include "utils/visitable.h" + +namespace FlexFlow { + +struct V1CombineAttrs { + int combine_dim; + req combine_degree; +}; +FF_VISITABLE_STRUCT(V1CombineAttrs, combine_dim, combine_degree); +CHECK_IS_JSONABLE(V1CombineAttrs); + +V1CombineAttrs to_v1(CombineAttrs const &a); +CombineAttrs from_v1(V1CombineAttrs const &va); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/file_format/v1/ops/concat.h b/lib/pcg/include/pcg/file_format/v1/ops/concat.h new file mode 100644 index 0000000000..a3d657ea33 --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/ops/concat.h @@ -0,0 +1,21 @@ +#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_OPS_CONCAT_ATTRS_H +#define _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_OPS_CONCAT_ATTRS_H + +#include "op-attrs/ops/concat.h" +#include "pcg/file_format/v1/ff_dim.h" +#include "utils/visitable.h" + +namespace FlexFlow { + +struct V1ConcatAttrs { + req axis; +}; +FF_VISITABLE_STRUCT(V1ConcatAttrs, axis); +CHECK_IS_JSONABLE(V1ConcatAttrs); + +V1ConcatAttrs to_v1(ConcatAttrs const &a); +ConcatAttrs from_v1(V1ConcatAttrs const &va); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/file_format/v1/ops/conv_2d.h b/lib/pcg/include/pcg/file_format/v1/ops/conv_2d.h new file mode 100644 index 0000000000..1d6b7fc7c7 --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/ops/conv_2d.h @@ -0,0 +1,35 @@ +#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_OPS_CONV_2D_ATTRS_H +#define _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_OPS_CONV_2D_ATTRS_H + +#include "op-attrs/ops/conv_2d.h" +#include "pcg/file_format/v1/activation.h" +#include "utils/visitable.h" + +namespace FlexFlow { + +struct V1Conv2DAttrs { + req out_channels, kernel_h, kernel_w, stride_h, stride_w, padding_h, + padding_w, groups; + req> activation; + req use_bias; +}; + +FF_VISITABLE_STRUCT(V1Conv2DAttrs, + out_channels, + kernel_h, + kernel_w, + stride_h, + stride_w, + padding_h, + padding_w, + groups, + activation, + 
use_bias); +CHECK_IS_JSONABLE(V1Conv2DAttrs); + +V1Conv2DAttrs to_v1(Conv2DAttrs const &a); +Conv2DAttrs from_v1(V1Conv2DAttrs const &va); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/file_format/v1/ops/dropout.h b/lib/pcg/include/pcg/file_format/v1/ops/dropout.h new file mode 100644 index 0000000000..73641d7a78 --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/ops/dropout.h @@ -0,0 +1,21 @@ +#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_OPS_DROPOUT_ATTRS_H +#define _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_OPS_DROPOUT_ATTRS_H + +#include "op-attrs/ops/dropout.h" +#include "utils/visitable.h" + +namespace FlexFlow { + +struct V1DropoutAttrs { + req rate; + req seed; +}; +FF_VISITABLE_STRUCT(V1DropoutAttrs, rate, seed); +CHECK_IS_JSONABLE(V1DropoutAttrs); + +V1DropoutAttrs to_v1(DropoutAttrs const &a); +DropoutAttrs from_v1(V1DropoutAttrs const &va); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/file_format/v1/ops/element_binary.h b/lib/pcg/include/pcg/file_format/v1/ops/element_binary.h new file mode 100644 index 0000000000..a22aadc370 --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/ops/element_binary.h @@ -0,0 +1,29 @@ +#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_OPS_ELEMENT_BINARY_ATTRS_H +#define _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_OPS_ELEMENT_BINARY_ATTRS_H + +#include "op-attrs/ops/element_binary.h" +#include "pcg/file_format/v1/datatype.h" +#include "pcg/file_format/v1/op.h" +#include "utils/visitable.h" + +namespace FlexFlow { + +struct V1ElementBinaryAttrs { + req type; + req compute_type; + req should_broadcast_lhs; + req should_broadcast_rhs; +}; +FF_VISITABLE_STRUCT(V1ElementBinaryAttrs, + type, + compute_type, + should_broadcast_lhs, + should_broadcast_rhs); +CHECK_IS_JSONABLE(V1ElementBinaryAttrs); + +V1ElementBinaryAttrs to_v1(ElementBinaryAttrs const &a); +ElementBinaryAttrs from_v1(V1ElementBinaryAttrs const &va); + +} // namespace FlexFlow + +#endif diff --git 
a/lib/pcg/include/pcg/file_format/v1/ops/element_unary.h b/lib/pcg/include/pcg/file_format/v1/ops/element_unary.h new file mode 100644 index 0000000000..46ac27921a --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/ops/element_unary.h @@ -0,0 +1,31 @@ +#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_OPS_ELEMENTARY_UNARY_ATTRS_H +#define _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_OPS_ELEMENTARY_UNARY_ATTRS_H + +#include "op-attrs/ops/element_unary.h" +#include "pcg/file_format/v1/op.h" +#include "utils/visitable.h" + +namespace FlexFlow { + +struct V1ElementScalarUnaryAttrs { + req op; + req scalar; +}; +FF_VISITABLE_STRUCT(V1ElementScalarUnaryAttrs, op, scalar); +CHECK_IS_JSONABLE(V1ElementScalarUnaryAttrs); + +V1ElementScalarUnaryAttrs to_v1(ElementScalarUnaryAttrs const &a); +ElementScalarUnaryAttrs from_v1(V1ElementScalarUnaryAttrs const &va); + +struct V1ElementUnaryAttrs { + req op; +}; +FF_VISITABLE_STRUCT(V1ElementUnaryAttrs, op); +CHECK_IS_JSONABLE(V1ElementUnaryAttrs); + +V1ElementUnaryAttrs to_v1(ElementUnaryAttrs const &a); +ElementUnaryAttrs from_v1(V1ElementUnaryAttrs const &va); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/file_format/v1/ops/embedding.h b/lib/pcg/include/pcg/file_format/v1/ops/embedding.h new file mode 100644 index 0000000000..8f1299b0e1 --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/ops/embedding.h @@ -0,0 +1,36 @@ +#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_OPS_EMBEDDING_ATTRS_H +#define _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_OPS_EMBEDDING_ATTRS_H + +#include "op-attrs/ops/embedding.h" +#include "pcg/file_format/v1/datatype.h" +#include "utils/visitable.h" + +namespace FlexFlow { + +enum class V1AggregateOp { + SUM, + AVG, +}; + +NLOHMANN_JSON_SERIALIZE_ENUM(V1AggregateOp, + {{V1AggregateOp::SUM, "SUM"}, + {V1AggregateOp::AVG, "AVG"}}); + +V1AggregateOp to_v1(AggregateOp const &op); +AggregateOp from_v1(V1AggregateOp const &vop); + +struct V1EmbeddingAttrs { + req num_entries, 
out_channels; + req aggr; + req data_type; +}; +FF_VISITABLE_STRUCT( + V1EmbeddingAttrs, num_entries, out_channels, aggr, data_type); +CHECK_IS_JSONABLE(V1EmbeddingAttrs); + +V1EmbeddingAttrs to_v1(EmbeddingAttrs const &a); +EmbeddingAttrs from_v1(V1EmbeddingAttrs const &va); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/file_format/v1/ops/flat.h b/lib/pcg/include/pcg/file_format/v1/ops/flat.h new file mode 100644 index 0000000000..e736b992ed --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/ops/flat.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_OPS_FLAT_ATTRS_H +#define _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_OPS_FLAT_ATTRS_H + +#include "op-attrs/ops/flat.h" +#include "utils/visitable.h" + +namespace FlexFlow { + +struct V1FlatAttrs {}; +FF_VISITABLE_STRUCT(V1FlatAttrs); +CHECK_IS_JSONABLE(V1FlatAttrs); + +V1FlatAttrs to_v1(FlatAttrs const &a); +FlatAttrs from_v1(V1FlatAttrs const &va); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/file_format/v1/ops/gather.h b/lib/pcg/include/pcg/file_format/v1/ops/gather.h new file mode 100644 index 0000000000..22944893ba --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/ops/gather.h @@ -0,0 +1,20 @@ +#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_OPS_GATHER_ATTRS_H +#define _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_OPS_GATHER_ATTRS_H + +#include "op-attrs/ops/gather.h" +#include "utils/visitable.h" + +namespace FlexFlow { + +struct V1GatherAttrs { + req dim; +}; +FF_VISITABLE_STRUCT(V1GatherAttrs, dim); +CHECK_IS_JSONABLE(V1GatherAttrs); + +V1GatherAttrs to_v1(GatherAttrs const &a); +GatherAttrs from_v1(V1GatherAttrs const &va); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/file_format/v1/ops/groupby.h b/lib/pcg/include/pcg/file_format/v1/ops/groupby.h new file mode 100644 index 0000000000..01b17876fe --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/ops/groupby.h @@ -0,0 +1,21 @@ +#ifndef 
_FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_OPS_GROUPBY_ATTRS_H +#define _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_OPS_GROUPBY_ATTRS_H + +#include "op-attrs/ops/groupby.h" +#include "utils/visitable.h" + +namespace FlexFlow { + +struct V1Group_byAttrs { + req n; + req alpha; +}; +FF_VISITABLE_STRUCT(V1Group_byAttrs, n, alpha); +CHECK_IS_JSONABLE(V1Group_byAttrs); + +V1Group_byAttrs to_v1(Group_byAttrs const &a); +Group_byAttrs from_v1(V1Group_byAttrs const &va); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/file_format/v1/ops/input.h b/lib/pcg/include/pcg/file_format/v1/ops/input.h new file mode 100644 index 0000000000..41cb178c84 --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/ops/input.h @@ -0,0 +1,19 @@ +#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_OPS_INPUT_H +#define _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_OPS_INPUT_H + +#include "op-attrs/ops/input.h" +#include "utils/json.h" +#include "utils/visitable.h" + +namespace FlexFlow { + +struct V1InputAttrs {}; +FF_VISITABLE_STRUCT(V1InputAttrs); +CHECK_IS_JSONABLE(V1InputAttrs); + +V1InputAttrs to_v1(InputAttrs const &a); +InputAttrs from_v1(V1InputAttrs const &va); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/file_format/v1/ops/layer_norm.h b/lib/pcg/include/pcg/file_format/v1/ops/layer_norm.h new file mode 100644 index 0000000000..15acc64b0c --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/ops/layer_norm.h @@ -0,0 +1,24 @@ +#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_OPS_LAYER_NORM_ATTRS_H +#define _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_OPS_LAYER_NORM_ATTRS_H + +#include "op-attrs/ops/layer_norm.h" +#include "pcg/file_format/v1/ff_dim.h" +#include "utils/visitable.h" + +namespace FlexFlow { + +struct V1LayerNormAttrs { + // The size of this vector must be <= MAX_TENSOR_DIMS + std::vector axes; + req elementwise_affine; + req eps; +}; +FF_VISITABLE_STRUCT(V1LayerNormAttrs, axes, elementwise_affine, eps); +CHECK_IS_JSONABLE(V1LayerNormAttrs); + 
+V1LayerNormAttrs to_v1(LayerNormAttrs const &a); +LayerNormAttrs from_v1(V1LayerNormAttrs const &va); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/file_format/v1/ops/linear.h b/lib/pcg/include/pcg/file_format/v1/ops/linear.h new file mode 100644 index 0000000000..6b28f6fd95 --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/ops/linear.h @@ -0,0 +1,51 @@ +#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_OPS_LINEAR_ATTRS_H +#define _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_OPS_LINEAR_ATTRS_H + +#include "op-attrs/ops/linear.h" +#include "pcg/file_format/v1/activation.h" +#include "pcg/file_format/v1/datatype.h" +#include "utils/visitable.h" + +namespace FlexFlow { + +struct V1L1RegularizerAttrs { + req lambda; +}; +FF_VISITABLE_STRUCT(V1L1RegularizerAttrs, lambda); +CHECK_IS_JSONABLE(V1L1RegularizerAttrs); + +V1L1RegularizerAttrs to_v1(L1RegularizerAttrs const &a); +L1RegularizerAttrs from_v1(V1L1RegularizerAttrs const &va); + +struct V1L2RegularizerAttrs { + req lambda; +}; +FF_VISITABLE_STRUCT(V1L2RegularizerAttrs, lambda); +CHECK_IS_JSONABLE(V1L2RegularizerAttrs); + +V1L2RegularizerAttrs to_v1(L2RegularizerAttrs const &a); +L2RegularizerAttrs from_v1(V1L2RegularizerAttrs const &va); + +using V1RegularizerAttrs = variant; +CHECK_IS_JSONABLE(V1RegularizerAttrs); + +V1RegularizerAttrs to_v1(RegularizerAttrs const &a); +RegularizerAttrs from_v1(V1RegularizerAttrs const &va); + +struct V1LinearAttrs { + req out_channels; + req use_bias; + req data_type; + req activation; + req> regularizer; +}; +FF_VISITABLE_STRUCT( + V1LinearAttrs, out_channels, use_bias, data_type, activation, regularizer); +CHECK_IS_JSONABLE(V1LinearAttrs); + +V1LinearAttrs to_v1(LinearAttrs const &a); +LinearAttrs from_v1(V1LinearAttrs const &va); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/file_format/v1/ops/noop.h b/lib/pcg/include/pcg/file_format/v1/ops/noop.h new file mode 100644 index 0000000000..de313467f7 --- /dev/null +++ 
b/lib/pcg/include/pcg/file_format/v1/ops/noop.h @@ -0,0 +1,19 @@ +#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_OPS_NOOP_H +#define _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_OPS_NOOP_H + +#include "op-attrs/ops/noop.h" +#include "utils/json.h" +#include "utils/visitable.h" + +namespace FlexFlow { + +struct V1NoopAttrs {}; +FF_VISITABLE_STRUCT(V1NoopAttrs); +CHECK_IS_JSONABLE(V1NoopAttrs); /* the JSON check targets the V1 struct declared in this header */ + +V1NoopAttrs to_v1(NoopAttrs const &a); +NoopAttrs from_v1(V1NoopAttrs const &va); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/file_format/v1/ops/pool_2d.h b/lib/pcg/include/pcg/file_format/v1/ops/pool_2d.h new file mode 100644 index 0000000000..0777f9e496 --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/ops/pool_2d.h @@ -0,0 +1,43 @@ +#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_OPS_POOL_2D_ATTRS_H +#define _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_OPS_POOL_2D_ATTRS_H + +#include "op-attrs/ops/pool_2d.h" +#include "pcg/file_format/v1/activation.h" +#include "utils/json.h" +#include "utils/visitable.h" + +namespace FlexFlow { + +enum class V1PoolOp { + MAX, + AVG, +}; + +NLOHMANN_JSON_SERIALIZE_ENUM(V1PoolOp, + {{V1PoolOp::MAX, "MAX"}, {V1PoolOp::AVG, "AVG"}}); + +V1PoolOp to_v1(PoolOp const &op); +PoolOp from_v1(V1PoolOp const &vop); + +struct V1Pool2DAttrs { + req kernel_h, kernel_w, stride_h, stride_w, padding_h, padding_w; + req pool_type; + req activation; +}; +FF_VISITABLE_STRUCT(V1Pool2DAttrs, + kernel_h, + kernel_w, + stride_h, + stride_w, + padding_h, + padding_w, + pool_type, + activation); +CHECK_IS_JSONABLE(V1Pool2DAttrs); + +V1Pool2DAttrs to_v1(Pool2DAttrs const &a); +Pool2DAttrs from_v1(V1Pool2DAttrs const &va); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/file_format/v1/ops/reduce.h b/lib/pcg/include/pcg/file_format/v1/ops/reduce.h new file mode 100644 index 0000000000..fbbaeb6594 --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/ops/reduce.h @@ -0,0 +1,25 @@ +#ifndef 
_FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_OPS_REDUCE_ATTRS_H +#define _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_OPS_REDUCE_ATTRS_H + +#include "op-attrs/ops/reduce.h" +#include "pcg/file_format/v1/ff_dim.h" +#include "pcg/file_format/v1/op.h" +#include "utils/visitable.h" + +namespace FlexFlow { + +struct V1ReduceAttrs { + // The size of this vector is <= MAX_TENSOR_DIMS. + req> axes; + req op_type; + req keepdims; +}; +FF_VISITABLE_STRUCT(V1ReduceAttrs, axes, op_type, keepdims); +CHECK_IS_JSONABLE(V1ReduceAttrs); + +V1ReduceAttrs to_v1(ReduceAttrs const &a); +ReduceAttrs from_v1(V1ReduceAttrs const &va); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/file_format/v1/ops/reduction.h b/lib/pcg/include/pcg/file_format/v1/ops/reduction.h new file mode 100644 index 0000000000..9afe9e9cfd --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/ops/reduction.h @@ -0,0 +1,22 @@ +#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_OPS_REDUCTION_ATTRS_H +#define _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_OPS_REDUCTION_ATTRS_H + +#include "op-attrs/ops/reduction.h" +#include "pcg/file_format/v1/ff_dim.h" +#include "utils/visitable.h" + +namespace FlexFlow { + +struct V1ReductionAttrs { + int reduction_dim; + req reduction_degree; +}; +FF_VISITABLE_STRUCT(V1ReductionAttrs, reduction_dim, reduction_degree); +CHECK_IS_JSONABLE(V1ReductionAttrs); + +V1ReductionAttrs to_v1(ReductionAttrs const &a); +ReductionAttrs from_v1(V1ReductionAttrs const &va); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/file_format/v1/ops/repartition.h b/lib/pcg/include/pcg/file_format/v1/ops/repartition.h new file mode 100644 index 0000000000..f1aa75ed54 --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/ops/repartition.h @@ -0,0 +1,22 @@ +#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_OPS_REPARTITION_ATTRS_H +#define _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_OPS_REPARTITION_ATTRS_H + +#include "op-attrs/ops/repartition.h" +#include 
"pcg/file_format/v1/ff_dim.h" +#include "utils/visitable.h" + +namespace FlexFlow { + +struct V1RepartitionAttrs { + int repartition_dim; + req repartition_degree; +}; +FF_VISITABLE_STRUCT(V1RepartitionAttrs, repartition_dim, repartition_degree); +CHECK_IS_JSONABLE(V1RepartitionAttrs); + +V1RepartitionAttrs to_v1(RepartitionAttrs const &a); +RepartitionAttrs from_v1(V1RepartitionAttrs const &va); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/file_format/v1/ops/replicate.h b/lib/pcg/include/pcg/file_format/v1/ops/replicate.h new file mode 100644 index 0000000000..6175a28074 --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/ops/replicate.h @@ -0,0 +1,22 @@ +#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_OPS_REPLICATE_ATTRS_H +#define _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_OPS_REPLICATE_ATTRS_H + +#include "op-attrs/ops/replicate.h" +#include "pcg/file_format/v1/ff_dim.h" +#include "utils/visitable.h" + +namespace FlexFlow { + +struct V1ReplicateAttrs { + int replicate_dim; + req replicate_degree; +}; +FF_VISITABLE_STRUCT(V1ReplicateAttrs, replicate_dim, replicate_degree); +CHECK_IS_JSONABLE(V1ReplicateAttrs); + +V1ReplicateAttrs to_v1(ReplicateAttrs const &a); +ReplicateAttrs from_v1(V1ReplicateAttrs const &va); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/file_format/v1/ops/reshape.h b/lib/pcg/include/pcg/file_format/v1/ops/reshape.h new file mode 100644 index 0000000000..1e2b289945 --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/ops/reshape.h @@ -0,0 +1,21 @@ +#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_OPS_RESHAPE_ATTRS_H +#define _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_OPS_RESHAPE_ATTRS_H + +#include "op-attrs/ops/reshape.h" +#include "pcg/file_format/v1/tensor_shape.h" +#include "utils/visitable.h" + +namespace FlexFlow { + +struct V1ReshapeAttrs { + V1TensorShape shape; +}; +FF_VISITABLE_STRUCT(V1ReshapeAttrs, shape); +CHECK_IS_JSONABLE(V1ReshapeAttrs); + +V1ReshapeAttrs to_v1(ReshapeAttrs 
const &a); +ReshapeAttrs from_v1(V1ReshapeAttrs const &va); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/file_format/v1/ops/reverse.h b/lib/pcg/include/pcg/file_format/v1/ops/reverse.h new file mode 100644 index 0000000000..5dfab7f8c9 --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/ops/reverse.h @@ -0,0 +1,22 @@ +#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_OPS_REVERSE_H +#define _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_OPS_REVERSE_H + +#include "op-attrs/ops/reverse.h" +#include "pcg/file_format/v1/ff_dim.h" +#include "utils/json.h" +#include "utils/visitable.h" + +namespace FlexFlow { + +struct V1ReverseAttrs { + req axis; +}; +FF_VISITABLE_STRUCT(V1ReverseAttrs, axis); +CHECK_IS_JSONABLE(V1ReverseAttrs); + +V1ReverseAttrs to_v1(ReverseAttrs const &a); +ReverseAttrs from_v1(V1ReverseAttrs const &va); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/file_format/v1/ops/softmax.h b/lib/pcg/include/pcg/file_format/v1/ops/softmax.h new file mode 100644 index 0000000000..1c7fc94326 --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/ops/softmax.h @@ -0,0 +1,20 @@ +#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_OPS_SOFTMAX_ATTRS_H +#define _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_OPS_SOFTMAX_ATTRS_H + +#include "op-attrs/ops/softmax.h" +#include "utils/visitable.h" + +namespace FlexFlow { + +struct V1SoftmaxAttrs { + req dim; +}; +FF_VISITABLE_STRUCT(V1SoftmaxAttrs, dim); +CHECK_IS_JSONABLE(V1SoftmaxAttrs); + +V1SoftmaxAttrs to_v1(SoftmaxAttrs const &a); +SoftmaxAttrs from_v1(V1SoftmaxAttrs const &va); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/file_format/v1/ops/split.h b/lib/pcg/include/pcg/file_format/v1/ops/split.h new file mode 100644 index 0000000000..cfd5e92d08 --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/ops/split.h @@ -0,0 +1,22 @@ +#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_OPS_SPLIT_ATTRS_H +#define 
_FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_OPS_SPLIT_ATTRS_H + +#include "op-attrs/ops/split.h" +#include "utils/visitable.h" + +namespace FlexFlow { + +struct V1SplitAttrs { + // The size of this vector must be <= MAX_TENSOR_DIM + std::vector splits; + req axis; +}; +FF_VISITABLE_STRUCT(V1SplitAttrs, splits, axis); +CHECK_IS_JSONABLE(V1SplitAttrs); + +V1SplitAttrs to_v1(SplitAttrs const &a); +SplitAttrs from_v1(V1SplitAttrs const &va); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/file_format/v1/ops/topk.h b/lib/pcg/include/pcg/file_format/v1/ops/topk.h new file mode 100644 index 0000000000..b02037c4dd --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/ops/topk.h @@ -0,0 +1,21 @@ +#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_OPS_TOPK_ATTRS_H +#define _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_OPS_TOPK_ATTRS_H + +#include "op-attrs/ops/topk.h" +#include "utils/visitable.h" + +namespace FlexFlow { + +struct V1TopKAttrs { + req k; + req sorted; +}; +FF_VISITABLE_STRUCT(V1TopKAttrs, k, sorted); +CHECK_IS_JSONABLE(V1TopKAttrs); + +V1TopKAttrs to_v1(TopKAttrs const &a); +TopKAttrs from_v1(V1TopKAttrs const &va); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/file_format/v1/ops/transpose.h b/lib/pcg/include/pcg/file_format/v1/ops/transpose.h new file mode 100644 index 0000000000..66e1a0c871 --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/ops/transpose.h @@ -0,0 +1,22 @@ +#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_OPS_TRANSPOSE_ATTRS_H +#define _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_OPS_TRANSPOSE_ATTRS_H + +#include "op-attrs/ops/transpose.h" +#include "pcg/file_format/v1/ff_dim.h" +#include "utils/visitable.h" + +namespace FlexFlow { + +struct V1TransposeAttrs { + // The size of this vector must be <= MAX_TENSOR_DIMS + req> perm; +}; +FF_VISITABLE_STRUCT(V1TransposeAttrs, perm); +CHECK_IS_JSONABLE(V1TransposeAttrs); + +V1TransposeAttrs to_v1(TransposeAttrs const &a); +TransposeAttrs 
from_v1(V1TransposeAttrs const &va); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/file_format/v1/parallel_dim.h b/lib/pcg/include/pcg/file_format/v1/parallel_dim.h new file mode 100644 index 0000000000..7f035f326f --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/parallel_dim.h @@ -0,0 +1,23 @@ +#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_PARALLEL_DIM_H +#define _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_PARALLEL_DIM_H + +#include "op-attrs/parallel_dim.h" +#include "utils/type_traits.h" +#include "utils/visitable.h" + +namespace FlexFlow { + +struct V1ParallelDim { + size_t size; + int degree; + req is_replica_dim; +}; +FF_VISITABLE_STRUCT(V1ParallelDim, size, degree, is_replica_dim); +CHECK_IS_JSONABLE(V1ParallelDim); + +V1ParallelDim to_v1(ParallelDim const &dim); +ParallelDim from_v1(V1ParallelDim const &vdim); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/file_format/v1/parallel_tensor.h b/lib/pcg/include/pcg/file_format/v1/parallel_tensor.h index 1ea4cd04de..bef2188ca6 100644 --- a/lib/pcg/include/pcg/file_format/v1/parallel_tensor.h +++ b/lib/pcg/include/pcg/file_format/v1/parallel_tensor.h @@ -1,36 +1,32 @@ #ifndef _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_PARALLEL_TENSOR_H #define _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_PARALLEL_TENSOR_H -#include "data_type.h" +#include "create_grad.h" +#include "datatype.h" #include "initializer.h" +#include "parallel_tensor_dims.h" +#include "parallel_tensor_shape.h" #include "param_sync.h" +#include "pcg/parallel_tensor.h" #include "utils/json.h" #include "utils/variant.h" #include "utils/visitable.h" namespace FlexFlow { -struct V1ParallelDim { - req size; - req degree; - req is_replica_dim; -}; -FF_VISITABLE_STRUCT(V1ParallelDim, size, degree, is_replica_dim); - -struct V1ParallelTensorShape { - req> dims; - req data_type; -}; -FF_VISITABLE_STRUCT(V1ParallelTensorShape, dims, data_type); - struct V1ParallelTensor { V1ParallelTensorShape shape; - req> sync_type; 
+ req create_gradients; req> initializer; - req create_grad; + req> sync_type; + req> name; }; FF_VISITABLE_STRUCT( - V1ParallelTensor, shape, sync_type, initializer, create_grad); + V1ParallelTensor, shape, create_gradients, initializer, sync_type, name); +CHECK_IS_JSONABLE(V1ParallelTensor); + +V1ParallelTensor to_v1(ParallelTensor const &t); +ParallelTensor from_v1(V1ParallelTensor const &vt); } // namespace FlexFlow diff --git a/lib/pcg/include/pcg/file_format/v1/parallel_tensor_dims.h b/lib/pcg/include/pcg/file_format/v1/parallel_tensor_dims.h new file mode 100644 index 0000000000..c4a44bf968 --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/parallel_tensor_dims.h @@ -0,0 +1,21 @@ +#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_PARALLEL_TENSOR_DIMS_H +#define _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_PARALLEL_TENSOR_DIMS_H + +#include "op-attrs/parallel_tensor_dims.h" +#include "parallel_dim.h" +#include "utils/visitable.h" + +namespace FlexFlow { + +struct V1ParallelTensorDims { + req> data; +}; +FF_VISITABLE_STRUCT(V1ParallelTensorDims, data); +CHECK_IS_JSONABLE(V1ParallelTensorDims); + +V1ParallelTensorDims to_v1(ParallelTensorDims const &dims); +ParallelTensorDims from_v1(V1ParallelTensorDims const &vdims); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/file_format/v1/parallel_tensor_shape.h b/lib/pcg/include/pcg/file_format/v1/parallel_tensor_shape.h new file mode 100644 index 0000000000..efad97279c --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/parallel_tensor_shape.h @@ -0,0 +1,23 @@ +#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_PARALLEL_TENSOR_SHAPE_H +#define _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_PARALLEL_TENSOR_SHAPE_H + +#include "datatype.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "parallel_tensor_dims.h" +#include "utils/visitable.h" + +namespace FlexFlow { + +struct V1ParallelTensorShape { + V1ParallelTensorDims dims; + req data_type; +}; +FF_VISITABLE_STRUCT(V1ParallelTensorShape, dims, 
data_type); +CHECK_IS_JSONABLE(V1ParallelTensorShape); + +V1ParallelTensorShape to_v1(ParallelTensorShape const &t); +ParallelTensorShape from_v1(V1ParallelTensorShape const &vt); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/file_format/v1/param_sync.h b/lib/pcg/include/pcg/file_format/v1/param_sync.h index 32769a8d20..e18cc15397 100644 --- a/lib/pcg/include/pcg/file_format/v1/param_sync.h +++ b/lib/pcg/include/pcg/file_format/v1/param_sync.h @@ -1,6 +1,7 @@ #ifndef _FLEXFLOW_PCG_FILE_FORMAT_V1_PARAM_SYNC_H #define _FLEXFLOW_PCG_FILE_FORMAT_V1_PARAM_SYNC_H +#include "op-attrs/param_sync.h" #include "utils/json.h" namespace FlexFlow { @@ -11,6 +12,9 @@ NLOHMANN_JSON_SERIALIZE_ENUM(V1ParamSync, {{V1ParamSync::PARAM_SERVER, "PARAM_SERVER"}, {V1ParamSync::NCCL, "NCCL"}}); +V1ParamSync to_v1(ParamSync const &p); +ParamSync from_v1(V1ParamSync const &vp); + } // namespace FlexFlow #endif diff --git a/lib/pcg/include/pcg/file_format/v1/tensor.h b/lib/pcg/include/pcg/file_format/v1/tensor.h index e1f6828186..f0311cd718 100644 --- a/lib/pcg/include/pcg/file_format/v1/tensor.h +++ b/lib/pcg/include/pcg/file_format/v1/tensor.h @@ -1,35 +1,28 @@ #ifndef _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_TENSOR_H #define _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_TENSOR_H -#include "data_type.h" +#include "create_grad.h" #include "initializer.h" -#include "op-attrs/tensor_shape.h" #include "param_sync.h" #include "pcg/tensor.h" +#include "tensor_shape.h" #include "utils/visitable.h" -#include namespace FlexFlow { -struct V1TensorShape { - req> dims; - req data_type; -}; -FF_VISITABLE_STRUCT(V1TensorShape, dims, data_type); -CHECK_IS_JSONABLE(V1TensorShape); -V1TensorShape to_v1(TensorShape const &); - struct V1Tensor { V1TensorShape shape; + req create_gradients; req> initializer; - req create_gradients; req> sync_type; req> name; }; FF_VISITABLE_STRUCT( - V1Tensor, shape, initializer, create_gradients, sync_type, name); + V1Tensor, shape, create_gradients, 
initializer, sync_type, name); CHECK_IS_JSONABLE(V1Tensor); + V1Tensor to_v1(Tensor const &); +Tensor from_v1(V1Tensor const &); } // namespace FlexFlow diff --git a/lib/pcg/include/pcg/file_format/v1/tensor_shape.h b/lib/pcg/include/pcg/file_format/v1/tensor_shape.h new file mode 100644 index 0000000000..edab0faaab --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/tensor_shape.h @@ -0,0 +1,26 @@ +#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_TENSOR_SHAPE_H +#define _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_TENSOR_SHAPE_H + +#include "datatype.h" +#include "op-attrs/tensor_shape.h" +#include "utils/visitable.h" + +namespace FlexFlow { + +using V1TensorDims = std::vector; +V1TensorDims from_v1(std::vector const &dims); + +struct V1TensorShape { +public: + std::vector dims; + req data_type; +}; +FF_VISITABLE_STRUCT(V1TensorShape, dims, data_type); +CHECK_IS_JSONABLE(V1TensorShape); + +V1TensorShape to_v1(TensorShape const &t); +TensorShape from_v1(V1TensorShape const &vt); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/file_format/v1/v1.h b/lib/pcg/include/pcg/file_format/v1/v1.h index e2557af4f5..7673fe0611 100644 --- a/lib/pcg/include/pcg/file_format/v1/v1.h +++ b/lib/pcg/include/pcg/file_format/v1/v1.h @@ -4,6 +4,41 @@ #include "graphs.h" #include "pcg/computation_graph.h" -namespace FlexFlow {} +namespace FlexFlow { + +template , int> = 0> +T to_v1(T const &t) { + return t; +} + +template , int> = 0> +T from_v1(T const &vt) { + return vt; +} + +std::string to_v1(std::string const &s); +std::string from_v1(std::string const &vs); + +template +optional to_v1(optional const &t) { + if (t.has_value()) { + return to_v1(t.value()); + } else { + return nullopt; + } +} + +template +optional from_v1(optional const &vt) { + if (vt.has_value()) { + return from_v1(vt.value()); + } else { + return nullopt; + } +} + +} // namespace FlexFlow #endif diff --git a/lib/pcg/include/pcg/initializer.h b/lib/pcg/include/pcg/initializer.h index 
58e4fcc242..55a9e1e8c5 100644 --- a/lib/pcg/include/pcg/initializer.h +++ b/lib/pcg/include/pcg/initializer.h @@ -13,11 +13,15 @@ struct GlorotUniform { /* DataType data_type; */ }; FF_VISITABLE_STRUCT(GlorotUniform, seed); +FF_VISIT_FMTABLE(GlorotUniform); +CHECK_FMTABLE(GlorotUniform); struct ZeroInitializer { ZeroInitializer() = default; }; FF_VISITABLE_STRUCT(ZeroInitializer); +FF_VISIT_FMTABLE(ZeroInitializer); +CHECK_FMTABLE(ZeroInitializer); struct UniformInitializer { req seed; @@ -25,6 +29,8 @@ struct UniformInitializer { req max_val; }; FF_VISITABLE_STRUCT(UniformInitializer, seed, min_val, max_val); +FF_VISIT_FMTABLE(UniformInitializer); +CHECK_FMTABLE(UniformInitializer); struct NormInitializer { req seed; @@ -32,11 +38,15 @@ struct NormInitializer { req stddev; }; FF_VISITABLE_STRUCT(NormInitializer, seed, mean, stddev); +FF_VISIT_FMTABLE(NormInitializer); +CHECK_FMTABLE(NormInitializer); struct ConstantInitializer { req value; }; FF_VISITABLE_STRUCT(ConstantInitializer, value); +FF_VISIT_FMTABLE(ConstantInitializer); +CHECK_FMTABLE(ConstantInitializer); using Initializer = variant +struct formatter<::FlexFlow::Initializer> : formatter { + template + auto format(::FlexFlow::Initializer initializer, FormatContext &ctx) const + -> decltype(ctx.out()) { + using namespace FlexFlow; + + std::string s = "unknown"; /* owning buffer: fmt::to_string() returns a temporary std::string, so a string_view assigned from it would dangle before format() reads it */ + if (auto const *g = get_if(&initializer)) { + s = fmt::to_string(*g); + } else if (auto const *z = get_if(&initializer)) { + s = fmt::to_string(*z); + } else if (auto const *u = get_if(&initializer)) { + s = fmt::to_string(*u); + } else if (auto const *n = get_if(&initializer)) { + s = fmt::to_string(*n); + } else if (auto const *c = get_if(&initializer)) { + s = fmt::to_string(*c); + } + return formatter::format(s, ctx); + } +}; + +} // namespace fmt + #endif diff --git a/lib/pcg/include/pcg/layer.h b/lib/pcg/include/pcg/layer.h index 6e9415a8fb..d218e9d8b9 100644 --- a/lib/pcg/include/pcg/layer.h +++ b/lib/pcg/include/pcg/layer.h @@ -7,20 +7,15 
@@ namespace FlexFlow { -struct Layer : public use_visitable_cmp { +struct Layer { public: - Layer() = delete; - Layer(CompGraphOperatorAttrs const &attrs, optional const &name); - -public: - optional> name; - CompGraphOperatorAttrs attrs; + req attrs; + req>> name; }; -} // namespace FlexFlow +FF_VISITABLE_STRUCT(Layer, attrs, name); -VISITABLE_STRUCT(::FlexFlow::Layer, attrs, name); -MAKE_VISIT_HASHABLE(::FlexFlow::Layer); +} // namespace FlexFlow namespace FlexFlow { diff --git a/lib/pcg/include/pcg/operator.h b/lib/pcg/include/pcg/operator.h index c7a49bb57e..1f5a004680 100644 --- a/lib/pcg/include/pcg/operator.h +++ b/lib/pcg/include/pcg/operator.h @@ -8,24 +8,17 @@ namespace FlexFlow { -struct Operator : public use_visitable_cmp { +struct Operator { public: - Operator() = delete; - Operator(PCGOperatorAttrs const &attrs, optional const &name); - operator PCGOperatorAttrs() const; public: - PCGOperatorAttrs attrs; + req attrs; + req>> name; }; -} // namespace FlexFlow - -VISITABLE_STRUCT(::FlexFlow::Operator, attrs); -MAKE_VISIT_HASHABLE(::FlexFlow::Operator); +FF_VISITABLE_STRUCT(Operator, attrs, name); -namespace FlexFlow { -static_assert(is_well_behaved_value_type::value, ""); -} +} // namespace FlexFlow #endif diff --git a/lib/pcg/include/pcg/parallel_tensor.h b/lib/pcg/include/pcg/parallel_tensor.h index eadc83d9fd..55656c6a8e 100644 --- a/lib/pcg/include/pcg/parallel_tensor.h +++ b/lib/pcg/include/pcg/parallel_tensor.h @@ -25,6 +25,7 @@ #include "initializer.h" #include "op-attrs/parallel_tensor_shape.h" #include "op-attrs/param_sync.h" +#include "utils/visitable.h" namespace FlexFlow { @@ -34,41 +35,33 @@ namespace FlexFlow { * @details Parallel tensor is the fundamental component to support the * representation and exploration of parallelization strategies. 
*/ -struct ParallelTensor : public use_visitable_cmp { - ParallelTensor() = delete; +struct ParallelTensor { + size_t get_volume() const; + ParallelTensorShape get_shape() const; + int num_dims() const; - ParallelTensor(ParallelTensorShape const &, - CreateGrad create_gradients, - optional sync_type = nullopt, - optional initializer = nullopt); - ParallelTensor(ParallelTensorDims const &, - DataType, - CreateGrad create_gradients, - optional sync_type = nullopt, - optional initializer = nullopt); + operator ParallelTensorShape() const; public: ParallelTensorDims dims; DataType data_type; - optional sync_type = nullopt; - optional initializer = nullopt; CreateGrad create_gradients; + req> initializer; + req> sync_type; + req> name; }; +FF_VISITABLE_STRUCT(ParallelTensor, + dims, + data_type, + create_gradients, + initializer, + sync_type, + name); +FF_VISIT_FMTABLE(ParallelTensor); +CHECK_FMTABLE(ParallelTensor); using ParallelParameter = ParallelTensor; } // namespace FlexFlow -VISITABLE_STRUCT(::FlexFlow::ParallelTensor, - dims, - data_type, - sync_type, - initializer, - create_gradients); -MAKE_VISIT_HASHABLE(::FlexFlow::ParallelTensor); - -namespace FlexFlow { -static_assert(is_well_behaved_value_type::value, ""); -} - #endif diff --git a/lib/pcg/include/pcg/strided_rectangle.h b/lib/pcg/include/pcg/strided_rectangle.h index 28331f441c..4fabf0f9f7 100644 --- a/lib/pcg/include/pcg/strided_rectangle.h +++ b/lib/pcg/include/pcg/strided_rectangle.h @@ -3,7 +3,6 @@ #include "op-attrs/dim_ordered.h" #include "op-attrs/ff_dim.h" -#include "utils/stack_vector.h" #include "utils/strong_typedef.h" #include "utils/visitable.h" @@ -12,23 +11,44 @@ namespace FlexFlow { struct num_points_t : public strong_typedef { using strong_typedef::strong_typedef; }; +FF_TYPEDEF_HASHABLE(num_points_t); +FF_TYPEDEF_PRINTABLE(num_points_t, "num_points"); struct side_size_t : public strong_typedef { using strong_typedef::strong_typedef; }; +FF_TYPEDEF_HASHABLE(side_size_t); 
+FF_TYPEDEF_PRINTABLE(side_size_t, "side_size"); struct StridedRectangleSide : public use_visitable_cmp { public: StridedRectangleSide() = delete; - StridedRectangleSide(num_points_t const &, int stride); - StridedRectangleSide(side_size_t const &, int stride); + StridedRectangleSide(num_points_t const &num_points, int stride) + : num_points(num_points), stride(stride) { + // FIXME: Move this definition elsewhere. + NOT_IMPLEMENTED(); + } + StridedRectangleSide(side_size_t const &num_points, int stride) + : num_points(num_points), stride(stride) { + // FIXME: Move this definition elsewhere. + NOT_IMPLEMENTED(); + } num_points_t get_num_points() const; - side_size_t get_size() const; + side_size_t get_size() const { + // FIXME: Move this definition elsewhere. + NOT_IMPLEMENTED(); + } int get_stride() const; - side_size_t at(num_points_t) const; - num_points_t at(side_size_t) const; + side_size_t at(num_points_t) const { + // FIXME: Move this definition elsewhere. + NOT_IMPLEMENTED(); + } + num_points_t at(side_size_t) const { + // FIXME: Move this definition elsewhere. + NOT_IMPLEMENTED(); + } public: num_points_t num_points; @@ -38,23 +58,24 @@ struct StridedRectangleSide : public use_visitable_cmp { struct StridedRectangle : public use_visitable_cmp { public: StridedRectangle() = delete; - StridedRectangle(std::vector const &); + StridedRectangle(std::vector const &sides) + : sides(sides) { + // FIXME: Move this definition elsewhere. + NOT_IMPLEMENTED(); + } size_t at(FFOrdered const &) const; StridedRectangleSide at(ff_dim_t const &) const; - size_t num_dims() const; + size_t num_dims() const { + // FIXME: Move this definition elsewhere. 
+ NOT_IMPLEMENTED(); + } public: FFOrdered sides; }; } // namespace FlexFlow -MAKE_TYPEDEF_HASHABLE(::FlexFlow::num_points_t); -MAKE_TYPEDEF_PRINTABLE(::FlexFlow::num_points_t, "num_points"); - -MAKE_TYPEDEF_HASHABLE(::FlexFlow::side_size_t); -MAKE_TYPEDEF_PRINTABLE(::FlexFlow::side_size_t, "side_size"); - VISITABLE_STRUCT(::FlexFlow::StridedRectangleSide, num_points, stride); MAKE_VISIT_HASHABLE(::FlexFlow::StridedRectangleSide); diff --git a/lib/pcg/include/pcg/tensor.h b/lib/pcg/include/pcg/tensor.h index cb79be245a..3e2b970a89 100644 --- a/lib/pcg/include/pcg/tensor.h +++ b/lib/pcg/include/pcg/tensor.h @@ -9,12 +9,6 @@ namespace FlexFlow { struct Tensor { - /* Tensor() = delete; */ - /* Tensor(TensorShape const &, */ - /* CreateGrad create_gradients, */ - /* optional initializer = nullopt, */ - /* optional sync_type = nullopt); */ - size_t get_volume() const; TensorShape get_shape() const; int num_dims() const; @@ -24,12 +18,15 @@ struct Tensor { public: TensorDims dims; DataType data_type; + CreateGrad create_gradients; req> initializer; - req create_gradients; req> sync_type; + req> name; }; FF_VISITABLE_STRUCT( - Tensor, dims, data_type, initializer, create_gradients, sync_type); + Tensor, dims, data_type, create_gradients, initializer, sync_type, name); +FF_VISIT_FMTABLE(Tensor); +CHECK_FMTABLE(Tensor); using Parameter = Tensor; diff --git a/lib/pcg/src/computation_graph_builder.cc b/lib/pcg/src/computation_graph_builder.cc index a17cd18f7c..cc1f47b68d 100644 --- a/lib/pcg/src/computation_graph_builder.cc +++ b/lib/pcg/src/computation_graph_builder.cc @@ -310,8 +310,16 @@ std::vector return this->add_layer(layer, {input}, {}, output_shapes); } -TensorShape get_shape(Tensor const &); -std::vector get_shape(std::vector const &); +TensorShape get_shape(Tensor const &t) { + return t.get_shape(); +} + +std::vector get_shape(std::vector const &ts) { + std::vector shps; + for (const Tensor& t : ts) + shps.emplace_back(t.get_shape()); + return shps; +} Tensor 
ComputationGraphBuilder::aggregate( Tensor const &gate_preds, @@ -351,4 +359,87 @@ Tensor ComputationGraphBuilder::batch_norm( return this->add_layer(layer, {input}, {}, output_shape); } +Tensor ComputationGraphBuilder::dense( + Tensor const &input, + int outDim, + optional activation, + bool use_bias, + DataType data_type, + optional kernel_initializer, + optional bias_initializer, + optional const &name) { + NOT_IMPLEMENTED(); +} + +Tensor ComputationGraphBuilder::cast(Tensor const &input, + DataType dtype, + optional const &name) { + NOT_IMPLEMENTED(); +} + +Tensor ComputationGraphBuilder::concat(int n, + std::vector const &tensors, + int axis, + optional const &name) { + NOT_IMPLEMENTED(); +} + +Tensor ComputationGraphBuilder::mean(Tensor const &input, + std::vector const &dims, + bool keepdims, + char const *name) { + NOT_IMPLEMENTED(); +} + +Tensor ComputationGraphBuilder::moe(Tensor const &input, + int num_exp, + int num_select, + int expert_hidden_size, + float alpha, + float lambda) { + NOT_IMPLEMENTED(); +} + +void ComputationGraphBuilder::split(Tensor const &input, + Tensor *outputs, + std::vector const &split, + int axis, + optional const &name) { + NOT_IMPLEMENTED(); +} + +Tensor ComputationGraphBuilder::broadcast(Tensor const &, TensorShape const &) { + NOT_IMPLEMENTED(); +} + +void ComputationGraphBuilder::add_layer(Layer const &layer, + std::vector const &inputs, + std::vector const &weights, + std::vector const &outputs) { + NOT_IMPLEMENTED(); +} + +Tensor ComputationGraphBuilder::add_layer( + Layer const &layer, + std::vector const &inputs, + std::vector>> const + &weight_shapes, + TensorShape const &output_shape) { + NOT_IMPLEMENTED(); +} + +std::vector ComputationGraphBuilder::add_layer( + Layer const &layer, + std::vector const &inputs, + std::vector>> const + &weight_shapes, + std::vector const &output_shapes) { + NOT_IMPLEMENTED(); +} + +TensorShape ComputationGraphBuilder::get_broadcast_target_shape( + std::vector const &) { + 
NOT_IMPLEMENTED(); +} + } // namespace FlexFlow diff --git a/lib/pcg/src/device_id.cc b/lib/pcg/src/device_id.cc index b669009f42..7bb1de0763 100644 --- a/lib/pcg/src/device_id.cc +++ b/lib/pcg/src/device_id.cc @@ -3,6 +3,10 @@ namespace FlexFlow { +device_id_t operator+(device_id_t, size_t) { + NOT_IMPLEMENTED(); +} + DeviceType get_device_type(device_id_t const &id) { if (holds_alternative(id)) { return DeviceType::GPU; diff --git a/lib/pcg/src/file_format/v1/activation.cc b/lib/pcg/src/file_format/v1/activation.cc new file mode 100644 index 0000000000..2b35f53d01 --- /dev/null +++ b/lib/pcg/src/file_format/v1/activation.cc @@ -0,0 +1,36 @@ +#include "pcg/file_format/v1/activation.h" + +namespace FlexFlow { + +V1Activation to_v1(Activation const &a) { + // There should be a better way of doing this. + switch (a) { + case Activation::RELU: + return V1Activation::RELU; + case Activation::SIGMOID: + return V1Activation::SIGMOID; + case Activation::TANH: + return V1Activation::TANH; + case Activation::GELU: + return V1Activation::GELU; + default: + NOT_REACHABLE(); + } +} + +Activation from_v1(V1Activation const &va) { + switch (va) { + case V1Activation::RELU: + return Activation::RELU; + case V1Activation::SIGMOID: + return Activation::SIGMOID; + case V1Activation::TANH: + return Activation::TANH; + case V1Activation::GELU: + return Activation::GELU; + default: + NOT_REACHABLE(); + } +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/file_format/v1/create_grad.cc b/lib/pcg/src/file_format/v1/create_grad.cc new file mode 100644 index 0000000000..00dbf0ced8 --- /dev/null +++ b/lib/pcg/src/file_format/v1/create_grad.cc @@ -0,0 +1,28 @@ +#include "pcg/file_format/v1/create_grad.h" +#include "pcg/file_format/v1/v1.h" + +namespace FlexFlow { + +V1CreateGrad to_v1(CreateGrad const &cg) { + switch (cg) { + case CreateGrad::YES: + return V1CreateGrad::YES; + case CreateGrad::NO: + return V1CreateGrad::NO; + default: + NOT_REACHABLE(); + } +} + +CreateGrad 
from_v1(V1CreateGrad const &vcg) { + switch (vcg) { + case V1CreateGrad::YES: + return CreateGrad::YES; + case V1CreateGrad::NO: + return CreateGrad::NO; + default: + NOT_REACHABLE(); + } +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/file_format/v1/datatype.cc b/lib/pcg/src/file_format/v1/datatype.cc new file mode 100644 index 0000000000..6b5c644f86 --- /dev/null +++ b/lib/pcg/src/file_format/v1/datatype.cc @@ -0,0 +1,89 @@ +#include "pcg/file_format/v1/datatype.h" + +namespace FlexFlow { + +V1DataType to_v1(DataType const &dt) { + switch (dt) { + case DataType::BOOL: + return V1DataType::BOOL; + case DataType::INT32: + return V1DataType::INT32; + case DataType::INT64: + return V1DataType::INT64; + case DataType::HALF: + return V1DataType::HALF; + case DataType::FLOAT: + return V1DataType::FLOAT; + case DataType::DOUBLE: + return V1DataType::DOUBLE; + default: + // Should never get here unless a new element was added to the DataType + // enum that was not handled here. + NOT_REACHABLE(); + } +} + +DataType from_v1(V1DataType const &vdt) { + switch (vdt) { + case V1DataType::BOOL: + return DataType::BOOL; + case V1DataType::INT32: + return DataType::INT32; + case V1DataType::INT64: + return DataType::INT64; + case V1DataType::HALF: + return DataType::HALF; + case V1DataType::FLOAT: + return DataType::FLOAT; + case V1DataType::DOUBLE: + return DataType::DOUBLE; + default: + // Should never get here unless a new element was added to the DataType + // enum that was not handled here. + NOT_REACHABLE(); + } +} + +V1DataTypeValue to_v1(DataTypeValue const &dv) { + // There has to be a better way of doing this. 
+ if (auto const *b = get_if(&dv)) { + return *b; + } else if (auto const *i32 = get_if(&dv)) { + return *i32; + } else if (auto const *i64 = get_if(&dv)) { + return *i64; + } else if (auto const *flt = get_if(&dv)) { + return *flt; + } else if (auto const *dbl = get_if(&dv)) { + return *dbl; + } else if (auto const *hlf = get_if(&dv)) { + return *hlf; + } else { + // Should never get here unless a new type was added into the DataTypeValue + // variant which was not handled here. + NOT_REACHABLE(); + } +} + +DataTypeValue from_v1(V1DataTypeValue const &vdv) { + // There has to be a better way of doing this. + if (auto const *b = get_if(&vdv)) { + return *b; + } else if (auto const *i32 = get_if(&vdv)) { + return *i32; + } else if (auto const *i64 = get_if(&vdv)) { + return *i64; + } else if (auto const *flt = get_if(&vdv)) { + return *flt; + } else if (auto const *dbl = get_if(&vdv)) { + return *dbl; + } else if (auto const *hlf = get_if(&vdv)) { + return *hlf; + } else { + // Should never get here unless a new type was added into the DataTypeValue + // variant which was not handled here. 
+ NOT_REACHABLE(); + } +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/file_format/v1/ff_dim.cc b/lib/pcg/src/file_format/v1/ff_dim.cc new file mode 100644 index 0000000000..c0627d7fbb --- /dev/null +++ b/lib/pcg/src/file_format/v1/ff_dim.cc @@ -0,0 +1,13 @@ +#include "pcg/file_format/v1/ff_dim.h" + +namespace FlexFlow { + +int to_v1(ff_dim_t const &t) { + return t.value(); +} + +ff_dim_t from_v1(int const &vt) { + return ff_dim_t(vt); +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/file_format/v1/graphs.cc b/lib/pcg/src/file_format/v1/graphs.cc index 519c38448b..13ea1124c3 100644 --- a/lib/pcg/src/file_format/v1/graphs.cc +++ b/lib/pcg/src/file_format/v1/graphs.cc @@ -3,6 +3,22 @@ namespace FlexFlow { +V1Operator to_v1(Operator const &op) { + return {to_v1(op.attrs), op.name}; +} + +Operator from_v1(V1Operator const &vop) { + return {from_v1(vop.attrs), vop.name}; +} + +V1Layer to_v1(Layer const &l) { + return {to_v1(l.attrs), l.name}; +} + +Layer from_v1(V1Layer const &vl) { + return {from_v1(vl.attrs), vl.name}; +} + V1MultiDiGraph to_v1(MultiDiGraphView const &g) { return to_v1(g, enumerate(get_nodes(g)).reversed(), @@ -55,4 +71,16 @@ V1ComputationGraph to_v1(ComputationGraph const &g) { return to_v1(g.value()); } +ComputationGraph from_v1(V1ComputationGraph const &vg) { + NOT_IMPLEMENTED(); +} + +V1ParallelComputationGraph to_v1(ParallelComputationGraph const &g) { + return to_v1(g.value()); +} + +ParallelComputationGraph from_v1(V1ParallelComputationGraph const &vg) { + NOT_IMPLEMENTED(); +} + } // namespace FlexFlow diff --git a/lib/pcg/src/file_format/v1/initializer.cc b/lib/pcg/src/file_format/v1/initializer.cc new file mode 100644 index 0000000000..c06f689ccd --- /dev/null +++ b/lib/pcg/src/file_format/v1/initializer.cc @@ -0,0 +1,83 @@ +#include "pcg/file_format/v1/initializer.h" + +namespace FlexFlow { + +V1GlorotInitializer to_v1(GlorotUniform const &i) { + return {i.seed}; +} + +GlorotUniform from_v1(V1GlorotInitializer const &vi) { + 
return {vi.seed}; +} + +V1ZeroInitializer to_v1(ZeroInitializer const &i) { + return { + // No fields in ZeroInitializer. + }; +} + +ZeroInitializer from_v1(V1ZeroInitializer const &vi) { + return { + // No fields in V1ZeroInitializer + }; +} + +V1UniformInitializer to_v1(UniformInitializer const &i) { + return {i.seed, i.min_val, i.max_val}; +} + +UniformInitializer from_v1(V1UniformInitializer const &vi) { + return {vi.seed, vi.min_val, vi.max_val}; +} + +V1NormInitializer to_v1(NormInitializer const &i) { + return {i.seed, i.mean, i.stddev}; +} + +NormInitializer from_v1(V1NormInitializer const &vi) { + return {vi.seed, vi.mean, vi.stddev}; +} + +V1ConstantInitializer to_v1(ConstantInitializer const &i) { + return {to_v1(i.value)}; +} + +ConstantInitializer from_v1(V1ConstantInitializer const &vi) { + return {from_v1(vi.value)}; +} + +V1Initializer to_v1(Initializer const &i) { + // There is surely a better way of doing this ... + if (auto const *glorot = get_if(&i)) { + return to_v1(*glorot); + } else if (auto const *zero = get_if(&i)) { + return to_v1(*zero); + } else if (auto const *uniform = get_if(&i)) { + return to_v1(*uniform); + } else if (auto const *norm = get_if(&i)) { + return to_v1(*norm); + } else if (auto const *constant = get_if(&i)) { + return to_v1(*constant); + } else { + NOT_REACHABLE(); + } +} + +Initializer from_v1(V1Initializer const &vi) { + // There is surely a better way of doing this ... 
+ if (auto const *glorot = get_if(&vi)) { + return from_v1(*glorot); + } else if (auto const *zero = get_if(&vi)) { + return from_v1(*zero); + } else if (auto const *uniform = get_if(&vi)) { + return from_v1(*uniform); + } else if (auto const *norm = get_if(&vi)) { + return from_v1(*norm); + } else if (auto const *constant = get_if(&vi)) { + return from_v1(*constant); + } else { + NOT_REACHABLE(); + } +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/file_format/v1/op.cc b/lib/pcg/src/file_format/v1/op.cc new file mode 100644 index 0000000000..6f9fa1b591 --- /dev/null +++ b/lib/pcg/src/file_format/v1/op.cc @@ -0,0 +1,365 @@ +#include "pcg/file_format/v1/op.h" + +namespace FlexFlow { + +V1Op to_v1(Op const &op) { + // There should be a better way of doing this. + switch (op) { + case Op::NOOP: + return V1Op::NOOP; + case Op::INPUT: + return V1Op::INPUT; + case Op::WEIGHT: + return V1Op::WEIGHT; + case Op::CONV2D: + return V1Op::CONV2D; + case Op::DROPOUT: + return V1Op::DROPOUT; + case Op::LINEAR: + return V1Op::LINEAR; + case Op::BATCHMATMUL: + return V1Op::BATCHMATMUL; + case Op::POOL2D: + return V1Op::POOL2D; + case Op::SCALAR_MULTIPLY: + return V1Op::SCALAR_MULTIPLY; + case Op::SCALAR_ADD: + return V1Op::SCALAR_ADD; + case Op::SCALAR_FLOOR_DIV: + return V1Op::SCALAR_FLOOR_DIV; + case Op::SCALAR_TRUE_DIV: + return V1Op::SCALAR_TRUE_DIV; + case Op::SCALAR_SUB: + return V1Op::SCALAR_SUB; + case Op::RELU: + return V1Op::RELU; + case Op::IDENTITY: + return V1Op::IDENTITY; + case Op::SIGMOID: + return V1Op::SIGMOID; + case Op::TANH: + return V1Op::TANH; + case Op::ELU: + return V1Op::ELU; + case Op::FLAT: + return V1Op::FLAT; + case Op::SOFTMAX: + return V1Op::SOFTMAX; + case Op::BATCHNORM: + return V1Op::BATCHNORM; + case Op::CONCAT: + return V1Op::CONCAT; + case Op::SPLIT: + return V1Op::SPLIT; + case Op::EMBEDDING: + return V1Op::EMBEDDING; + case Op::GROUP_BY: + return V1Op::GROUP_BY; + case Op::CACHE: + return V1Op::CACHE; + case Op::AGGREGATE: + return 
V1Op::AGGREGATE; + case Op::AGG_SPEC: + return V1Op::AGG_SPEC; + case Op::RESHAPE: + return V1Op::RESHAPE; + case Op::REVERSE: + return V1Op::REVERSE; + case Op::TRANSPOSE: + return V1Op::TRANSPOSE; + case Op::EW_ADD: + return V1Op::EW_ADD; + case Op::EW_MUL: + return V1Op::EW_MUL; + case Op::MATMUL: + return V1Op::MATMUL; + case Op::MUL: + return V1Op::MUL; + case Op::ENLARGE: + return V1Op::ENLARGE; + case Op::SQUEEZE: + return V1Op::SQUEEZE; + case Op::UNSQUEEZE: + return V1Op::UNSQUEEZE; + case Op::EW_SUB: + return V1Op::EW_SUB; + case Op::EW_DIV: + return V1Op::EW_DIV; + case Op::EW_EQUAL: + return V1Op::EW_EQUAL; + case Op::EW_GREATER: + return V1Op::EW_GREATER; + case Op::EW_LESS: + return V1Op::EW_LESS; + case Op::EW_MAX: + return V1Op::EW_MAX; + case Op::EW_MIN: + return V1Op::EW_MIN; + case Op::REDUCE_ARGMAX: + return V1Op::REDUCE_ARGMAX; + case Op::REDUCE_ARGMIN: + return V1Op::REDUCE_ARGMIN; + case Op::REDUCE_MAX: + return V1Op::REDUCE_MAX; + case Op::REDUCE_MEAN: + return V1Op::REDUCE_MEAN; + case Op::REDUCE_MIN: + return V1Op::REDUCE_MIN; + case Op::REDUCE_PROD: + return V1Op::REDUCE_PROD; + case Op::REDUCE_SUM: + return V1Op::REDUCE_SUM; + case Op::PAD: + return V1Op::PAD; + case Op::SHAPE: + return V1Op::SHAPE; + case Op::SIZE: + return V1Op::SIZE; + case Op::TOPK: + return V1Op::TOPK; + case Op::WHERE: + return V1Op::WHERE; + case Op::CEIL: + return V1Op::CEIL; + case Op::CAST: + return V1Op::CAST; + case Op::EXP: + return V1Op::EXP; + case Op::ROUND: + return V1Op::ROUND; + case Op::LOG: + return V1Op::LOG; + case Op::LOGICAL_NOT: + return V1Op::LOGICAL_NOT; + case Op::SQRT: + return V1Op::SQRT; + case Op::SIN: + return V1Op::SIN; + case Op::COS: + return V1Op::COS; + case Op::LEAKYRELU: + return V1Op::LEAKYRELU; + case Op::SLICE: + return V1Op::SLICE; + case Op::RESIZE: + return V1Op::RESIZE; + case Op::PRELU: + return V1Op::PRELU; + case Op::GELU: + return V1Op::GELU; + case Op::MULTIHEAD_ATTENTION: + return V1Op::MULTIHEAD_ATTENTION; + case 
Op::FUSED: + return V1Op::FUSED; + case Op::RSQRT: + return V1Op::RSQRT; + case Op::POW: + return V1Op::POW; + case Op::MEAN: + return V1Op::MEAN; + case Op::LAYERNORM: + return V1Op::LAYERNORM; + case Op::GATHER: + return V1Op::GATHER; + case Op::BROADCAST: + return V1Op::BROADCAST; + case Op::REPARTITION: + return V1Op::REPARTITION; + case Op::COMBINE: + return V1Op::COMBINE; + case Op::REPLICATE: + return V1Op::REPLICATE; + case Op::REDUCTION: + return V1Op::REDUCTION; + case Op::BATCH: + return V1Op::BATCH; + case Op::PIPELINE: + return V1Op::PIPELINE; + case Op::FUSED_PARALLEL: + return V1Op::FUSED_PARALLEL; + default: + NOT_REACHABLE(); + } +} + +Op from_v1(V1Op const &vop) { + // There should be a better way of doing this. + switch (vop) { + case V1Op::NOOP: + return Op::NOOP; + case V1Op::INPUT: + return Op::INPUT; + case V1Op::WEIGHT: + return Op::WEIGHT; + case V1Op::CONV2D: + return Op::CONV2D; + case V1Op::DROPOUT: + return Op::DROPOUT; + case V1Op::LINEAR: + return Op::LINEAR; + case V1Op::BATCHMATMUL: + return Op::BATCHMATMUL; + case V1Op::POOL2D: + return Op::POOL2D; + case V1Op::SCALAR_MULTIPLY: + return Op::SCALAR_MULTIPLY; + case V1Op::SCALAR_ADD: + return Op::SCALAR_ADD; + case V1Op::SCALAR_FLOOR_DIV: + return Op::SCALAR_FLOOR_DIV; + case V1Op::SCALAR_TRUE_DIV: + return Op::SCALAR_TRUE_DIV; + case V1Op::SCALAR_SUB: + return Op::SCALAR_SUB; + case V1Op::RELU: + return Op::RELU; + case V1Op::IDENTITY: + return Op::IDENTITY; + case V1Op::SIGMOID: + return Op::SIGMOID; + case V1Op::TANH: + return Op::TANH; + case V1Op::ELU: + return Op::ELU; + case V1Op::FLAT: + return Op::FLAT; + case V1Op::SOFTMAX: + return Op::SOFTMAX; + case V1Op::BATCHNORM: + return Op::BATCHNORM; + case V1Op::CONCAT: + return Op::CONCAT; + case V1Op::SPLIT: + return Op::SPLIT; + case V1Op::EMBEDDING: + return Op::EMBEDDING; + case V1Op::GROUP_BY: + return Op::GROUP_BY; + case V1Op::CACHE: + return Op::CACHE; + case V1Op::AGGREGATE: + return Op::AGGREGATE; + case V1Op::AGG_SPEC: 
+ return Op::AGG_SPEC; + case V1Op::RESHAPE: + return Op::RESHAPE; + case V1Op::REVERSE: + return Op::REVERSE; + case V1Op::TRANSPOSE: + return Op::TRANSPOSE; + case V1Op::EW_ADD: + return Op::EW_ADD; + case V1Op::EW_MUL: + return Op::EW_MUL; + case V1Op::MATMUL: + return Op::MATMUL; + case V1Op::MUL: + return Op::MUL; + case V1Op::ENLARGE: + return Op::ENLARGE; + case V1Op::SQUEEZE: + return Op::SQUEEZE; + case V1Op::UNSQUEEZE: + return Op::UNSQUEEZE; + case V1Op::EW_SUB: + return Op::EW_SUB; + case V1Op::EW_DIV: + return Op::EW_DIV; + case V1Op::EW_EQUAL: + return Op::EW_EQUAL; + case V1Op::EW_GREATER: + return Op::EW_GREATER; + case V1Op::EW_LESS: + return Op::EW_LESS; + case V1Op::EW_MAX: + return Op::EW_MAX; + case V1Op::EW_MIN: + return Op::EW_MIN; + case V1Op::REDUCE_ARGMAX: + return Op::REDUCE_ARGMAX; + case V1Op::REDUCE_ARGMIN: + return Op::REDUCE_ARGMIN; + case V1Op::REDUCE_MAX: + return Op::REDUCE_MAX; + case V1Op::REDUCE_MEAN: + return Op::REDUCE_MEAN; + case V1Op::REDUCE_MIN: + return Op::REDUCE_MIN; + case V1Op::REDUCE_PROD: + return Op::REDUCE_PROD; + case V1Op::REDUCE_SUM: + return Op::REDUCE_SUM; + case V1Op::PAD: + return Op::PAD; + case V1Op::SHAPE: + return Op::SHAPE; + case V1Op::SIZE: + return Op::SIZE; + case V1Op::TOPK: + return Op::TOPK; + case V1Op::WHERE: + return Op::WHERE; + case V1Op::CEIL: + return Op::CEIL; + case V1Op::CAST: + return Op::CAST; + case V1Op::EXP: + return Op::EXP; + case V1Op::ROUND: + return Op::ROUND; + case V1Op::LOG: + return Op::LOG; + case V1Op::LOGICAL_NOT: + return Op::LOGICAL_NOT; + case V1Op::SQRT: + return Op::SQRT; + case V1Op::SIN: + return Op::SIN; + case V1Op::COS: + return Op::COS; + case V1Op::LEAKYRELU: + return Op::LEAKYRELU; + case V1Op::SLICE: + return Op::SLICE; + case V1Op::RESIZE: + return Op::RESIZE; + case V1Op::PRELU: + return Op::PRELU; + case V1Op::GELU: + return Op::GELU; + case V1Op::MULTIHEAD_ATTENTION: + return Op::MULTIHEAD_ATTENTION; + case V1Op::FUSED: + return Op::FUSED; + case 
V1Op::RSQRT: + return Op::RSQRT; + case V1Op::POW: + return Op::POW; + case V1Op::MEAN: + return Op::MEAN; + case V1Op::LAYERNORM: + return Op::LAYERNORM; + case V1Op::GATHER: + return Op::GATHER; + case V1Op::BROADCAST: + return Op::BROADCAST; + case V1Op::REPARTITION: + return Op::REPARTITION; + case V1Op::COMBINE: + return Op::COMBINE; + case V1Op::REPLICATE: + return Op::REPLICATE; + case V1Op::REDUCTION: + return Op::REDUCTION; + case V1Op::BATCH: + return Op::BATCH; + case V1Op::PIPELINE: + return Op::PIPELINE; + case V1Op::FUSED_PARALLEL: + return Op::FUSED_PARALLEL; + default: + NOT_REACHABLE(); + } +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/file_format/v1/operator_attrs.cc b/lib/pcg/src/file_format/v1/operator_attrs.cc new file mode 100644 index 0000000000..fc4b455b62 --- /dev/null +++ b/lib/pcg/src/file_format/v1/operator_attrs.cc @@ -0,0 +1,274 @@ +#include "pcg/file_format/v1/operator_attrs.h" +#include "pcg/file_format/v1/v1.h" + +namespace FlexFlow { + +V1CompGraphOperatorAttrs to_v1(CompGraphOperatorAttrs const &attrs) { + if (auto const *aggr = get_if(&attrs)) { + return to_v1(*aggr); + } else if (auto const *aggrSpec = get_if(&attrs)) { + return to_v1(*aggrSpec); + } else if (auto const *batchMm = get_if(&attrs)) { + return to_v1(*batchMm); + } else if (auto const *batchNorm = get_if(&attrs)) { + return to_v1(*batchNorm); + } else if (auto const *cast = get_if(&attrs)) { + return to_v1(*cast); + } else if (auto const *concat = get_if(&attrs)) { + return to_v1(*concat); + } else if (auto const *conv2d = get_if(&attrs)) { + return to_v1(*conv2d); + } else if (auto const *dropout = get_if(&attrs)) { + return to_v1(*dropout); + } else if (auto const *elemBin = get_if(&attrs)) { + return to_v1(*elemBin); + } else if (auto const *elemUnSc = get_if(&attrs)) { + return to_v1(*elemUnSc); + } else if (auto const *elemUn = get_if(&attrs)) { + return to_v1(*elemUn); + } else if (auto const *emb = get_if(&attrs)) { + return to_v1(*emb); + } else if 
(auto const *flat = get_if(&attrs)) { + return to_v1(*flat); + } else if (auto const *gather = get_if(&attrs)) { + return to_v1(*gather); + } else if (auto const *group = get_if(&attrs)) { + return to_v1(*group); + } else if (auto const *inp = get_if(&attrs)) { + return to_v1(*inp); + } else if (auto const *layerNorm = get_if(&attrs)) { + return to_v1(*layerNorm); + } else if (auto const *lin = get_if(&attrs)) { + return to_v1(*lin); + } else if (auto const *att = get_if(&attrs)) { + return to_v1(*att); + } else if (auto const *noop = get_if(&attrs)) { + return to_v1(*noop); + } else if (auto const *pool2d = get_if(&attrs)) { + return to_v1(*pool2d); + } else if (auto const *reduce = get_if(&attrs)) { + return to_v1(*reduce); + } else if (auto const *reverse = get_if(&attrs)) { + return to_v1(*reverse); + } else if (auto const *reshape = get_if(&attrs)) { + return to_v1(*reshape); + } else if (auto const *split = get_if(&attrs)) { + return to_v1(*split); + } else if (auto const *soft = get_if(&attrs)) { + return to_v1(*soft); + } else if (auto const *topk = get_if(&attrs)) { + return to_v1(*topk); + } else if (auto const *trans = get_if(&attrs)) { + return to_v1(*trans); + } else if (auto const *bcast = get_if(&attrs)) { + return to_v1(*bcast); + } else { + NOT_REACHABLE(); + } +} + +CompGraphOperatorAttrs from_v1(V1CompGraphOperatorAttrs const &va) { + if (auto const *aggr = get_if(&va)) { + return from_v1(*aggr); + } else if (auto const *aggrSpec = get_if(&va)) { + return from_v1(*aggrSpec); + } else if (auto const *batchMm = get_if(&va)) { + return from_v1(*batchMm); + } else if (auto const *batchNorm = get_if(&va)) { + return from_v1(*batchNorm); + } else if (auto const *cast = get_if(&va)) { + return from_v1(*cast); + } else if (auto const *concat = get_if(&va)) { + return from_v1(*concat); + } else if (auto const *conv2d = get_if(&va)) { + return from_v1(*conv2d); + } else if (auto const *dropout = get_if(&va)) { + return from_v1(*dropout); + } else if (auto 
const *elemBin = get_if(&va)) { + return from_v1(*elemBin); + } else if (auto const *elemUnSc = get_if(&va)) { + return from_v1(*elemUnSc); + } else if (auto const *elemUn = get_if(&va)) { + return from_v1(*elemUn); + } else if (auto const *emb = get_if(&va)) { + return from_v1(*emb); + } else if (auto const *flat = get_if(&va)) { + return from_v1(*flat); + } else if (auto const *gather = get_if(&va)) { + return from_v1(*gather); + } else if (auto const *group = get_if(&va)) { + return from_v1(*group); + } else if (auto const *inp = get_if(&va)) { + return from_v1(*inp); + } else if (auto const *layerNorm = get_if(&va)) { + return from_v1(*layerNorm); + } else if (auto const *lin = get_if(&va)) { + return from_v1(*lin); + } else if (auto const *att = get_if(&va)) { + return from_v1(*att); + } else if (auto const *noop = get_if(&va)) { + return from_v1(*noop); + } else if (auto const *pool2d = get_if(&va)) { + return from_v1(*pool2d); + } else if (auto const *reduce = get_if(&va)) { + return from_v1(*reduce); + } else if (auto const *reverse = get_if(&va)) { + return from_v1(*reverse); + } else if (auto const *reshape = get_if(&va)) { + return from_v1(*reshape); + } else if (auto const *split = get_if(&va)) { + return from_v1(*split); + } else if (auto const *soft = get_if(&va)) { + return from_v1(*soft); + } else if (auto const *topk = get_if(&va)) { + return from_v1(*topk); + } else if (auto const *trans = get_if(&va)) { + return from_v1(*trans); + } else if (auto const *bcast = get_if(&va)) { + return from_v1(*bcast); + } else { + NOT_REACHABLE(); + } +} + +V1PCGOperatorAttrs to_v1(PCGOperatorAttrs const &attrs) { + if (auto const *aggr = get_if(&attrs)) { + return to_v1(*aggr); + } else if (auto const *aggrSpec = get_if(&attrs)) { + return to_v1(*aggrSpec); + } else if (auto const *batchMm = get_if(&attrs)) { + return to_v1(*batchMm); + } else if (auto const *batchNorm = get_if(&attrs)) { + return to_v1(*batchNorm); + } else if (auto const *cast = 
get_if(&attrs)) { + return to_v1(*cast); + } else if (auto const *concat = get_if(&attrs)) { + return to_v1(*concat); + } else if (auto const *conv2d = get_if(&attrs)) { + return to_v1(*conv2d); + } else if (auto const *dropout = get_if(&attrs)) { + return to_v1(*dropout); + } else if (auto const *elemBin = get_if(&attrs)) { + return to_v1(*elemBin); + } else if (auto const *elemUnSc = get_if(&attrs)) { + return to_v1(*elemUnSc); + } else if (auto const *elemUn = get_if(&attrs)) { + return to_v1(*elemUn); + } else if (auto const *emb = get_if(&attrs)) { + return to_v1(*emb); + } else if (auto const *flat = get_if(&attrs)) { + return to_v1(*flat); + } else if (auto const *gather = get_if(&attrs)) { + return to_v1(*gather); + } else if (auto const *group = get_if(&attrs)) { + return to_v1(*group); + } else if (auto const *inp = get_if(&attrs)) { + return to_v1(*inp); + } else if (auto const *layerNorm = get_if(&attrs)) { + return to_v1(*layerNorm); + } else if (auto const *lin = get_if(&attrs)) { + return to_v1(*lin); + } else if (auto const *att = get_if(&attrs)) { + return to_v1(*att); + } else if (auto const *noop = get_if(&attrs)) { + return to_v1(*noop); + } else if (auto const *pool2d = get_if(&attrs)) { + return to_v1(*pool2d); + } else if (auto const *reduce = get_if(&attrs)) { + return to_v1(*reduce); + } else if (auto const *reverse = get_if(&attrs)) { + return to_v1(*reverse); + } else if (auto const *reshape = get_if(&attrs)) { + return to_v1(*reshape); + } else if (auto const *split = get_if(&attrs)) { + return to_v1(*split); + } else if (auto const *soft = get_if(&attrs)) { + return to_v1(*soft); + } else if (auto const *topk = get_if(&attrs)) { + return to_v1(*topk); + } else if (auto const *trans = get_if(&attrs)) { + return to_v1(*trans); + } else if (auto const *combine = get_if(&attrs)) { + return to_v1(*combine); + } else if (auto const *red = get_if(&attrs)) { + return to_v1(*red); + } else if (auto const *repart = get_if(&attrs)) { + return 
to_v1(*repart); + } else if (auto const *repl = get_if(&attrs)) { + return to_v1(*repl); + } else { + NOT_REACHABLE(); + } +} + +PCGOperatorAttrs from_v1(V1PCGOperatorAttrs const &va) { + if (auto const *aggr = get_if(&va)) { + return from_v1(*aggr); + } else if (auto const *aggrSpec = get_if(&va)) { + return from_v1(*aggrSpec); + } else if (auto const *batchMm = get_if(&va)) { + return from_v1(*batchMm); + } else if (auto const *batchNorm = get_if(&va)) { + return from_v1(*batchNorm); + } else if (auto const *cast = get_if(&va)) { + return from_v1(*cast); + } else if (auto const *concat = get_if(&va)) { + return from_v1(*concat); + } else if (auto const *conv2d = get_if(&va)) { + return from_v1(*conv2d); + } else if (auto const *dropout = get_if(&va)) { + return from_v1(*dropout); + } else if (auto const *elemBin = get_if(&va)) { + return from_v1(*elemBin); + } else if (auto const *elemUnSc = get_if(&va)) { + return from_v1(*elemUnSc); + } else if (auto const *elemUn = get_if(&va)) { + return from_v1(*elemUn); + } else if (auto const *emb = get_if(&va)) { + return from_v1(*emb); + } else if (auto const *flat = get_if(&va)) { + return from_v1(*flat); + } else if (auto const *gather = get_if(&va)) { + return from_v1(*gather); + } else if (auto const *group = get_if(&va)) { + return from_v1(*group); + } else if (auto const *inp = get_if(&va)) { + return from_v1(*inp); + } else if (auto const *layerNorm = get_if(&va)) { + return from_v1(*layerNorm); + } else if (auto const *lin = get_if(&va)) { + return from_v1(*lin); + } else if (auto const *att = get_if(&va)) { + return from_v1(*att); + } else if (auto const *noop = get_if(&va)) { + return from_v1(*noop); + } else if (auto const *pool2d = get_if(&va)) { + return from_v1(*pool2d); + } else if (auto const *reduce = get_if(&va)) { + return from_v1(*reduce); + } else if (auto const *reverse = get_if(&va)) { + return from_v1(*reverse); + } else if (auto const *reshape = get_if(&va)) { + return from_v1(*reshape); + } else 
if (auto const *split = get_if(&va)) { + return from_v1(*split); + } else if (auto const *soft = get_if(&va)) { + return from_v1(*soft); + } else if (auto const *topk = get_if(&va)) { + return from_v1(*topk); + } else if (auto const *trans = get_if(&va)) { + return from_v1(*trans); + } else if (auto const *combine = get_if(&va)) { + return from_v1(*combine); + } else if (auto const *red = get_if(&va)) { + return from_v1(*red); + } else if (auto const *repart = get_if(&va)) { + return from_v1(*repart); + } else if (auto const *repl = get_if(&va)) { + return from_v1(*repl); + } else { + NOT_REACHABLE(); + } +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/file_format/v1/ops/aggregate.cc b/lib/pcg/src/file_format/v1/ops/aggregate.cc new file mode 100644 index 0000000000..47b5dfec6c --- /dev/null +++ b/lib/pcg/src/file_format/v1/ops/aggregate.cc @@ -0,0 +1,13 @@ +#include "pcg/file_format/v1/ops/aggregate.h" + +namespace FlexFlow { + +V1AggregateAttrs to_v1(AggregateAttrs const &a) { + return {a.n, a.lambda_bal}; +} + +AggregateAttrs from_v1(V1AggregateAttrs const &va) { + return {va.n, va.lambda_bal}; +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/file_format/v1/ops/aggregate_spec.cc b/lib/pcg/src/file_format/v1/ops/aggregate_spec.cc new file mode 100644 index 0000000000..a032124d50 --- /dev/null +++ b/lib/pcg/src/file_format/v1/ops/aggregate_spec.cc @@ -0,0 +1,13 @@ +#include "pcg/file_format/v1/ops/aggregate_spec.h" + +namespace FlexFlow { + +V1AggregateSpecAttrs to_v1(AggregateSpecAttrs const &a) { + return {a.n, a.lambda_bal}; +} + +AggregateSpecAttrs from_v1(V1AggregateSpecAttrs const &va) { + return {va.n, va.lambda_bal}; +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/file_format/v1/ops/attention.cc b/lib/pcg/src/file_format/v1/ops/attention.cc new file mode 100644 index 0000000000..8850aa5d00 --- /dev/null +++ b/lib/pcg/src/file_format/v1/ops/attention.cc @@ -0,0 +1,27 @@ +#include "pcg/file_format/v1/ops/attention.h" + +namespace FlexFlow 
{ + +V1MultiHeadAttentionAttrs to_v1(MultiHeadAttentionAttrs const &a) { + return {a.embed_dim, + a.num_heads, + a.kdim, + a.vdim, + a.dropout, + a.bias, + a.add_bias_kv, + a.add_zero_attn}; +} + +MultiHeadAttentionAttrs from_v1(V1MultiHeadAttentionAttrs const &va) { + return {va.embed_dim, + va.num_heads, + va.kdim, + va.vdim, + va.dropout, + va.bias, + va.add_bias_kv, + va.add_zero_attn}; +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/file_format/v1/ops/batch_matmul.cc b/lib/pcg/src/file_format/v1/ops/batch_matmul.cc new file mode 100644 index 0000000000..86a49caf3f --- /dev/null +++ b/lib/pcg/src/file_format/v1/ops/batch_matmul.cc @@ -0,0 +1,14 @@ +#include "pcg/file_format/v1/ops/batch_matmul.h" +#include "pcg/file_format/v1/v1.h" + +namespace FlexFlow { + +V1BatchMatmulAttrs to_v1(BatchMatmulAttrs const &a) { + return {a.a_seq_length_dim, a.b_seq_length_dim}; +} + +BatchMatmulAttrs from_v1(V1BatchMatmulAttrs const &va) { + return {va.a_seq_length_dim, va.b_seq_length_dim}; +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/file_format/v1/ops/batch_norm.cc b/lib/pcg/src/file_format/v1/ops/batch_norm.cc new file mode 100644 index 0000000000..d0deaf95a6 --- /dev/null +++ b/lib/pcg/src/file_format/v1/ops/batch_norm.cc @@ -0,0 +1,13 @@ +#include "pcg/file_format/v1/ops/batch_norm.h" + +namespace FlexFlow { + +V1BatchNormAttrs to_v1(BatchNormAttrs const &a) { + return {a.relu}; +} + +BatchNormAttrs from_v1(V1BatchNormAttrs const &va) { + return {va.relu}; +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/file_format/v1/ops/broadcast.cc b/lib/pcg/src/file_format/v1/ops/broadcast.cc new file mode 100644 index 0000000000..899d8de9cb --- /dev/null +++ b/lib/pcg/src/file_format/v1/ops/broadcast.cc @@ -0,0 +1,17 @@ +#include "pcg/file_format/v1/ops/broadcast.h" +#include "pcg/file_format/v1/v1.h" + +namespace FlexFlow { + +V1BroadcastAttrs to_v1(BroadcastAttrs const &a) { + return {std::vector(a.target_dims.begin(), a.target_dims.end())}; +} + 
+BroadcastAttrs from_v1(V1BroadcastAttrs const &va) { + stack_vector dims; + for (const int& dim : va.target_dims) + dims.emplace_back(dim); + return BroadcastAttrs{dims}; +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/file_format/v1/ops/cast.cc b/lib/pcg/src/file_format/v1/ops/cast.cc new file mode 100644 index 0000000000..c522d75171 --- /dev/null +++ b/lib/pcg/src/file_format/v1/ops/cast.cc @@ -0,0 +1,14 @@ +#include "pcg/file_format/v1/ops/cast.h" +#include "pcg/file_format/v1/v1.h" + +namespace FlexFlow { + +V1CastAttrs to_v1(CastAttrs const &a) { + return {to_v1(a.dtype)}; +} + +CastAttrs from_v1(V1CastAttrs const &va) { + return {from_v1(va.dtype)}; +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/file_format/v1/ops/combine.cc b/lib/pcg/src/file_format/v1/ops/combine.cc new file mode 100644 index 0000000000..62ce546edd --- /dev/null +++ b/lib/pcg/src/file_format/v1/ops/combine.cc @@ -0,0 +1,14 @@ +#include "pcg/file_format/v1/ops/combine.h" +#include "pcg/file_format/v1/ff_dim.h" + +namespace FlexFlow { + +V1CombineAttrs to_v1(CombineAttrs const &a) { + return {to_v1(a.combine_dim), a.combine_degree}; +} + +CombineAttrs from_v1(V1CombineAttrs const &va) { + return {from_v1(va.combine_dim), va.combine_degree}; +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/file_format/v1/ops/concat.cc b/lib/pcg/src/file_format/v1/ops/concat.cc new file mode 100644 index 0000000000..553f2fd11a --- /dev/null +++ b/lib/pcg/src/file_format/v1/ops/concat.cc @@ -0,0 +1,15 @@ +#include "pcg/file_format/v1/ops/concat.h" +#include "pcg/file_format/v1/ff_dim.h" +#include "pcg/file_format/v1/v1.h" + +namespace FlexFlow { + +V1ConcatAttrs to_v1(ConcatAttrs const &a) { + return {to_v1(a.axis)}; +} + +ConcatAttrs from_v1(V1ConcatAttrs const &va) { + return {ff_dim_t(va.axis)}; +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/file_format/v1/ops/conv_2d.cc b/lib/pcg/src/file_format/v1/ops/conv_2d.cc new file mode 100644 index 0000000000..270742b700 --- /dev/null 
+++ b/lib/pcg/src/file_format/v1/ops/conv_2d.cc @@ -0,0 +1,32 @@ +#include "pcg/file_format/v1/ops/conv_2d.h" +#include "pcg/file_format/v1/v1.h" + +namespace FlexFlow { + +V1Conv2DAttrs to_v1(Conv2DAttrs const &a) { + return {a.out_channels, + a.kernel_h, + a.kernel_w, + a.stride_h, + a.stride_w, + a.padding_h, + a.padding_w, + a.groups, + to_v1(a.activation), + a.use_bias}; +} + +Conv2DAttrs from_v1(V1Conv2DAttrs const &va) { + return {va.out_channels, + va.kernel_h, + va.kernel_w, + va.stride_h, + va.stride_w, + va.padding_h, + va.padding_w, + va.groups, + from_v1(va.activation), + va.use_bias}; +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/file_format/v1/ops/dropout.cc b/lib/pcg/src/file_format/v1/ops/dropout.cc new file mode 100644 index 0000000000..2c92916530 --- /dev/null +++ b/lib/pcg/src/file_format/v1/ops/dropout.cc @@ -0,0 +1,13 @@ +#include "pcg/file_format/v1/ops/dropout.h" + +namespace FlexFlow { + +V1DropoutAttrs to_v1(DropoutAttrs const &a) { + return {a.rate, a.seed}; +} + +DropoutAttrs from_v1(V1DropoutAttrs const &va) { + return {va.rate, va.seed}; +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/file_format/v1/ops/element_binary.cc b/lib/pcg/src/file_format/v1/ops/element_binary.cc new file mode 100644 index 0000000000..be2eb1a26d --- /dev/null +++ b/lib/pcg/src/file_format/v1/ops/element_binary.cc @@ -0,0 +1,19 @@ +#include "pcg/file_format/v1/ops/element_binary.h" + +namespace FlexFlow { + +V1ElementBinaryAttrs to_v1(ElementBinaryAttrs const &a) { + return {to_v1(a.type), + to_v1(a.compute_type), + a.should_broadcast_lhs, + a.should_broadcast_rhs}; +} + +ElementBinaryAttrs from_v1(V1ElementBinaryAttrs const &va) { + return {from_v1(va.type), + from_v1(va.compute_type), + va.should_broadcast_lhs, + va.should_broadcast_rhs}; +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/file_format/v1/ops/element_unary.cc b/lib/pcg/src/file_format/v1/ops/element_unary.cc new file mode 100644 index 0000000000..4aecc47c95 --- /dev/null +++ 
b/lib/pcg/src/file_format/v1/ops/element_unary.cc @@ -0,0 +1,21 @@ +#include "pcg/file_format/v1/ops/element_unary.h" + +namespace FlexFlow { + +V1ElementScalarUnaryAttrs to_v1(ElementScalarUnaryAttrs const &a) { + return {to_v1(a.op), a.scalar}; +} + +ElementScalarUnaryAttrs from_v1(V1ElementScalarUnaryAttrs const &va) { + return {from_v1(va.op), va.scalar}; +} + +V1ElementUnaryAttrs to_v1(ElementUnaryAttrs const &a) { + return {to_v1(a.op)}; +} + +ElementUnaryAttrs from_v1(V1ElementUnaryAttrs const &va) { + return {from_v1(va.op)}; +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/file_format/v1/ops/embedding.cc b/lib/pcg/src/file_format/v1/ops/embedding.cc new file mode 100644 index 0000000000..6fb54883c6 --- /dev/null +++ b/lib/pcg/src/file_format/v1/ops/embedding.cc @@ -0,0 +1,39 @@ +#include "pcg/file_format/v1/ops/embedding.h" +#include "pcg/file_format/v1/v1.h" + +namespace FlexFlow { + +V1AggregateOp to_v1(AggregateOp const &op) { + // There should be a better way of doing this. + switch (op) { + case AggregateOp::SUM: + return V1AggregateOp::SUM; + case AggregateOp::AVG: + return V1AggregateOp::AVG; + default: + NOT_REACHABLE(); + } +} + +AggregateOp from_v1(V1AggregateOp const &vop) { + // There should be a better way of doing this. 
+ switch (vop) { + case V1AggregateOp::SUM: + return AggregateOp::SUM; + case V1AggregateOp::AVG: + return AggregateOp::AVG; + default: + NOT_REACHABLE(); + } +} + +V1EmbeddingAttrs to_v1(EmbeddingAttrs const &a) { + return {a.num_entries, a.out_channels, to_v1(a.aggr), to_v1(a.data_type)}; +} + +EmbeddingAttrs from_v1(V1EmbeddingAttrs const &va) { + return { + va.num_entries, va.out_channels, from_v1(va.aggr), from_v1(va.data_type)}; +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/file_format/v1/ops/flat.cc b/lib/pcg/src/file_format/v1/ops/flat.cc new file mode 100644 index 0000000000..be0c056b1c --- /dev/null +++ b/lib/pcg/src/file_format/v1/ops/flat.cc @@ -0,0 +1,18 @@ +#include "pcg/file_format/v1/ops/flat.h" +#include "pcg/file_format/v1/v1.h" + +namespace FlexFlow { + +V1FlatAttrs to_v1(FlatAttrs const &a) { + return { + // No fields in FlatAttrs + }; +} + +FlatAttrs from_v1(V1FlatAttrs const &va) { + return { + // No fields in FlatAttrs + }; +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/file_format/v1/ops/gather.cc b/lib/pcg/src/file_format/v1/ops/gather.cc new file mode 100644 index 0000000000..c6ce769f5e --- /dev/null +++ b/lib/pcg/src/file_format/v1/ops/gather.cc @@ -0,0 +1,15 @@ +#include "pcg/file_format/v1/ops/gather.h" +#include "pcg/file_format/v1/ff_dim.h" +#include "pcg/file_format/v1/v1.h" + +namespace FlexFlow { + +V1GatherAttrs to_v1(GatherAttrs const &a) { + return {to_v1(a.dim)}; +} + +GatherAttrs from_v1(V1GatherAttrs const &va) { + return {ff_dim_t(va.dim)}; +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/file_format/v1/ops/groupby.cc b/lib/pcg/src/file_format/v1/ops/groupby.cc new file mode 100644 index 0000000000..11325e7a55 --- /dev/null +++ b/lib/pcg/src/file_format/v1/ops/groupby.cc @@ -0,0 +1,13 @@ +#include "pcg/file_format/v1/ops/groupby.h" + +namespace FlexFlow { + +V1Group_byAttrs to_v1(Group_byAttrs const &a) { + return {a.n, a.alpha}; +} + +Group_byAttrs from_v1(V1Group_byAttrs const &va) { + return {va.n, 
va.alpha}; +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/file_format/v1/ops/input.cc b/lib/pcg/src/file_format/v1/ops/input.cc new file mode 100644 index 0000000000..aad7b98add --- /dev/null +++ b/lib/pcg/src/file_format/v1/ops/input.cc @@ -0,0 +1,18 @@ +#include "pcg/file_format/v1/ops/input.h" +#include "pcg/file_format/v1/v1.h" + +namespace FlexFlow { + +V1InputAttrs to_v1(InputAttrs const &a) { + return { + // No fields in InputAttrs + }; +} + +InputAttrs from_v1(V1InputAttrs const &va) { + return { + // No fields in InputAttrs + }; +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/file_format/v1/ops/layer_norm.cc b/lib/pcg/src/file_format/v1/ops/layer_norm.cc new file mode 100644 index 0000000000..4fd52ed1c6 --- /dev/null +++ b/lib/pcg/src/file_format/v1/ops/layer_norm.cc @@ -0,0 +1,21 @@ +#include "pcg/file_format/v1/ops/layer_norm.h" +#include "pcg/file_format/v1/ff_dim.h" +#include "pcg/file_format/v1/v1.h" + +namespace FlexFlow { + +V1LayerNormAttrs to_v1(LayerNormAttrs const &a) { + return {std::vector(a.axes.begin(), a.axes.end()), + a.elementwise_affine, + a.eps}; +} + +LayerNormAttrs from_v1(V1LayerNormAttrs const &va) { + + return { + stack_vector(va.axes.begin(), va.axes.end()), + va.elementwise_affine, + va.eps}; +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/file_format/v1/ops/linear.cc b/lib/pcg/src/file_format/v1/ops/linear.cc new file mode 100644 index 0000000000..53549a3919 --- /dev/null +++ b/lib/pcg/src/file_format/v1/ops/linear.cc @@ -0,0 +1,60 @@ +#include "pcg/file_format/v1/ops/linear.h" +#include "pcg/file_format/v1/v1.h" + +namespace FlexFlow { + +V1L1RegularizerAttrs to_v1(L1RegularizerAttrs const &a) { + return {a.lambda}; +} + +L1RegularizerAttrs from_v1(V1L1RegularizerAttrs const &va) { + return {va.lambda}; +} + +V1L2RegularizerAttrs to_v1(L2RegularizerAttrs const &a) { + return {a.lambda}; +} + +L2RegularizerAttrs from_v1(V1L2RegularizerAttrs const &va) { + return {va.lambda}; +} + +V1RegularizerAttrs 
to_v1(RegularizerAttrs const &a) { + // There should be a better way of doing this. + if (auto const *l1 = get_if(&a)) { + return to_v1(*l1); + } else if (auto const *l2 = get_if(&a)) { + return to_v1(*l2); + } else { + NOT_REACHABLE(); + } +} + +RegularizerAttrs from_v1(V1RegularizerAttrs const &a) { + // There should be a better way of doing this. + if (auto const *l1 = get_if(&a)) { + return from_v1(*l1); + } else if (auto const *l2 = get_if(&a)) { + return from_v1(*l2); + } else { + NOT_REACHABLE(); + } +} + +V1LinearAttrs to_v1(LinearAttrs const &a) { + return {a.out_channels, + a.use_bias, + to_v1(a.data_type), + to_v1(a.activation), + to_v1(a.regularizer)}; +} + +LinearAttrs from_v1(V1LinearAttrs const &va) { + return {va.out_channels, + va.use_bias, + from_v1(va.data_type), + from_v1(va.activation), + from_v1(va.regularizer)}; +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/file_format/v1/ops/noop.cc b/lib/pcg/src/file_format/v1/ops/noop.cc new file mode 100644 index 0000000000..cbda4195ad --- /dev/null +++ b/lib/pcg/src/file_format/v1/ops/noop.cc @@ -0,0 +1,18 @@ +#include "pcg/file_format/v1/ops/noop.h" +#include "pcg/file_format/v1/v1.h" + +namespace FlexFlow { + +V1NoopAttrs to_v1(NoopAttrs const &a) { + return { + // No fields in NoopAttrs. + }; +} + +NoopAttrs from_v1(V1NoopAttrs const &va) { + return { + // No fields in NoopAttrs. + }; +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/file_format/v1/ops/pool_2d.cc b/lib/pcg/src/file_format/v1/ops/pool_2d.cc new file mode 100644 index 0000000000..1127d76c5e --- /dev/null +++ b/lib/pcg/src/file_format/v1/ops/pool_2d.cc @@ -0,0 +1,52 @@ +#include "pcg/file_format/v1/ops/pool_2d.h" +#include "pcg/file_format/v1/v1.h" + +namespace FlexFlow { + +V1PoolOp to_v1(PoolOp const &op) { + // There should be a better way of doing this. 
+ switch (op) { + case PoolOp::MAX: + return V1PoolOp::MAX; + case PoolOp::AVG: + return V1PoolOp::AVG; + default: + NOT_REACHABLE(); + } +} + +PoolOp from_v1(V1PoolOp const &vop) { + // There should be a better way of doing this. + switch (vop) { + case V1PoolOp::MAX: + return PoolOp::MAX; + case V1PoolOp::AVG: + return PoolOp::AVG; + default: + NOT_REACHABLE(); + } +} + +V1Pool2DAttrs to_v1(Pool2DAttrs const &a) { + return {a.kernel_h, + a.kernel_w, + a.stride_h, + a.stride_w, + a.padding_h, + a.padding_w, + to_v1(a.pool_type), + to_v1(a.activation)}; +} + +Pool2DAttrs from_v1(V1Pool2DAttrs const &va) { + return {va.kernel_h, + va.kernel_w, + va.stride_h, + va.stride_w, + va.padding_h, + va.padding_w, + from_v1(va.pool_type), + from_v1(va.activation)}; +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/file_format/v1/ops/reduce.cc b/lib/pcg/src/file_format/v1/ops/reduce.cc new file mode 100644 index 0000000000..0e4188a588 --- /dev/null +++ b/lib/pcg/src/file_format/v1/ops/reduce.cc @@ -0,0 +1,20 @@ +#include "pcg/file_format/v1/ops/reduce.h" +#include "pcg/file_format/v1/ff_dim.h" + +namespace FlexFlow { + +V1ReduceAttrs to_v1(ReduceAttrs const &a) { + return {std::vector(a.axes.begin(), a.axes.end()), + to_v1(a.op_type), + a.keepdims}; +} + +ReduceAttrs from_v1(V1ReduceAttrs const &va) { + stack_vector axes; + for (int const &d : va.axes) { + axes.push_back(ff_dim_t(d)); + } + return {axes, from_v1(va.op_type), va.keepdims}; +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/file_format/v1/ops/reduction.cc b/lib/pcg/src/file_format/v1/ops/reduction.cc new file mode 100644 index 0000000000..606ba14ba7 --- /dev/null +++ b/lib/pcg/src/file_format/v1/ops/reduction.cc @@ -0,0 +1,14 @@ +#include "pcg/file_format/v1/ops/reduction.h" +#include "pcg/file_format/v1/ff_dim.h" + +namespace FlexFlow { + +V1ReductionAttrs to_v1(ReductionAttrs const &a) { + return {to_v1(a.reduction_dim), a.reduction_degree}; +} + +ReductionAttrs from_v1(V1ReductionAttrs const &va) { + 
return {ff_dim_t(va.reduction_dim), va.reduction_degree}; +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/file_format/v1/ops/repartition.cc b/lib/pcg/src/file_format/v1/ops/repartition.cc new file mode 100644 index 0000000000..55c4fc175a --- /dev/null +++ b/lib/pcg/src/file_format/v1/ops/repartition.cc @@ -0,0 +1,14 @@ +#include "pcg/file_format/v1/ops/repartition.h" +#include "pcg/file_format/v1/ff_dim.h" + +namespace FlexFlow { + +V1RepartitionAttrs to_v1(RepartitionAttrs const &a) { + return {to_v1(a.repartition_dim), a.repartition_degree}; +} + +RepartitionAttrs from_v1(V1RepartitionAttrs const &va) { + return {from_v1(va.repartition_dim), va.repartition_degree}; +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/file_format/v1/ops/replicate.cc b/lib/pcg/src/file_format/v1/ops/replicate.cc new file mode 100644 index 0000000000..520f0710d7 --- /dev/null +++ b/lib/pcg/src/file_format/v1/ops/replicate.cc @@ -0,0 +1,14 @@ +#include "pcg/file_format/v1/ops/replicate.h" +#include "pcg/file_format/v1/ff_dim.h" + +namespace FlexFlow { + +V1ReplicateAttrs to_v1(ReplicateAttrs const &a) { + return {to_v1(a.replicate_dim), a.replicate_degree}; +} + +ReplicateAttrs from_v1(V1ReplicateAttrs const &va) { + return {ff_dim_t(va.replicate_dim), va.replicate_degree}; +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/file_format/v1/ops/reshape.cc b/lib/pcg/src/file_format/v1/ops/reshape.cc new file mode 100644 index 0000000000..e9832b1353 --- /dev/null +++ b/lib/pcg/src/file_format/v1/ops/reshape.cc @@ -0,0 +1,14 @@ +#include "pcg/file_format/v1/ops/reshape.h" +#include "pcg/file_format/v1/v1.h" + +namespace FlexFlow { + +V1ReshapeAttrs to_v1(ReshapeAttrs const &a) { + return {to_v1(a.shape)}; +} + +ReshapeAttrs from_v1(V1ReshapeAttrs const &va) { + return {from_v1(va.shape)}; +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/file_format/v1/ops/reverse.cc b/lib/pcg/src/file_format/v1/ops/reverse.cc new file mode 100644 index 0000000000..13c5e2cd43 --- 
/dev/null +++ b/lib/pcg/src/file_format/v1/ops/reverse.cc @@ -0,0 +1,14 @@ +#include "pcg/file_format/v1/ops/reverse.h" +#include "pcg/file_format/v1/ff_dim.h" + +namespace FlexFlow { + +V1ReverseAttrs to_v1(ReverseAttrs const &a) { + return {to_v1(a.axis)}; +} + +ReverseAttrs from_v1(V1ReverseAttrs const &va) { + return {from_v1(va.axis)}; +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/file_format/v1/ops/softmax.cc b/lib/pcg/src/file_format/v1/ops/softmax.cc new file mode 100644 index 0000000000..61e817512a --- /dev/null +++ b/lib/pcg/src/file_format/v1/ops/softmax.cc @@ -0,0 +1,15 @@ +#include "pcg/file_format/v1/ops/softmax.h" +#include "pcg/file_format/v1/ff_dim.h" +#include "pcg/file_format/v1/v1.h" + +namespace FlexFlow { + +V1SoftmaxAttrs to_v1(SoftmaxAttrs const &a) { + return {to_v1(a.dim)}; +} + +SoftmaxAttrs from_v1(V1SoftmaxAttrs const &va) { + return {ff_dim_t(va.dim)}; +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/file_format/v1/ops/split.cc b/lib/pcg/src/file_format/v1/ops/split.cc new file mode 100644 index 0000000000..478e175140 --- /dev/null +++ b/lib/pcg/src/file_format/v1/ops/split.cc @@ -0,0 +1,19 @@ +#include "pcg/file_format/v1/ops/split.h" +#include "pcg/file_format/v1/ff_dim.h" +#include "pcg/file_format/v1/v1.h" + +namespace FlexFlow { + +V1SplitAttrs to_v1(SplitAttrs const &a) { + return {std::vector(a.splits.begin(), a.splits.end()), to_v1(a.axis)}; +} + +SplitAttrs from_v1(V1SplitAttrs const &va) { + stack_vector splits; + for (int const &i : va.splits) { + splits.push_back(i); + } + return {splits, ff_dim_t(va.axis)}; +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/file_format/v1/ops/topk.cc b/lib/pcg/src/file_format/v1/ops/topk.cc new file mode 100644 index 0000000000..1a8ab444d6 --- /dev/null +++ b/lib/pcg/src/file_format/v1/ops/topk.cc @@ -0,0 +1,13 @@ +#include "pcg/file_format/v1/ops/topk.h" + +namespace FlexFlow { + +V1TopKAttrs to_v1(TopKAttrs const &a) { + return {a.k, a.sorted}; +} + +TopKAttrs 
from_v1(V1TopKAttrs const &va) { + return {va.k, va.sorted}; +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/file_format/v1/ops/transpose.cc b/lib/pcg/src/file_format/v1/ops/transpose.cc new file mode 100644 index 0000000000..6fe12331c7 --- /dev/null +++ b/lib/pcg/src/file_format/v1/ops/transpose.cc @@ -0,0 +1,19 @@ +#include "pcg/file_format/v1/ops/transpose.h" +#include "pcg/file_format/v1/ff_dim.h" +#include "pcg/file_format/v1/v1.h" + +namespace FlexFlow { + +V1TransposeAttrs to_v1(TransposeAttrs const &a) { + return {std::vector(a.perm.begin(), a.perm.end())}; +} + +TransposeAttrs from_v1(V1TransposeAttrs const &va) { + stack_vector perm; + for (int const &i : va.perm) { + perm.push_back(ff_dim_t(i)); + } + return {perm}; +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/file_format/v1/parallel_tensor.cc b/lib/pcg/src/file_format/v1/parallel_tensor.cc new file mode 100644 index 0000000000..61aa52a80f --- /dev/null +++ b/lib/pcg/src/file_format/v1/parallel_tensor.cc @@ -0,0 +1,44 @@ +#include "pcg/file_format/v1/parallel_tensor.h" +#include "pcg/file_format/v1/v1.h" + +namespace FlexFlow { + +V1ParallelDim to_v1(ParallelDim const &dim) { + return {dim.size, dim.degree, dim.is_replica_dim}; +} + +ParallelDim from_v1(V1ParallelDim const &vdim) { + return {vdim.size, vdim.degree, vdim.is_replica_dim}; +} + +V1ParallelTensorShape to_v1(ParallelTensorShape const &shp) { + std::vector pdims; + for (ParallelDim const &pdim : shp.dims.data) { + pdims.emplace_back(to_v1(pdim)); + } + return {pdims, to_v1(shp.data_type)}; +} + +ParallelTensorShape from_v1(V1ParallelTensorShape const &vshp) { + return ParallelTensorShape(from_v1(vshp.dims), from_v1(vshp.data_type)); +} + +V1ParallelTensor to_v1(ParallelTensor const &t) { + return {to_v1(t.get_shape()), + to_v1(t.create_gradients), + to_v1(t.initializer), + to_v1(t.sync_type), + t.name}; +} + +ParallelTensor from_v1(V1ParallelTensor const &vt) { + ParallelTensorShape shape = from_v1(vt.shape); + return 
{shape.dims, + shape.data_type, + from_v1(vt.create_gradients), + from_v1(vt.initializer), + from_v1(vt.sync_type), + vt.name}; +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/file_format/v1/parallel_tensor_dims.cc b/lib/pcg/src/file_format/v1/parallel_tensor_dims.cc new file mode 100644 index 0000000000..de8698e379 --- /dev/null +++ b/lib/pcg/src/file_format/v1/parallel_tensor_dims.cc @@ -0,0 +1,21 @@ +#include "pcg/file_format/v1/parallel_tensor_dims.h" + +namespace FlexFlow { + +V1ParallelTensorDims to_v1(ParallelTensorDims const &dims) { + std::vector data; + for (ParallelDim const &pdim : dims.data) { + data.emplace_back(to_v1(pdim)); + } + return {data}; +} + +ParallelTensorDims from_v1(V1ParallelTensorDims const &vdims) { + std::vector dims; + for (V1ParallelDim const &pdim : vdims.data) { + dims.emplace_back(from_v1(pdim)); + } + return ParallelTensorDims(dims); +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/file_format/v1/param_sync.cc b/lib/pcg/src/file_format/v1/param_sync.cc new file mode 100644 index 0000000000..26741c9b7a --- /dev/null +++ b/lib/pcg/src/file_format/v1/param_sync.cc @@ -0,0 +1,27 @@ +#include "pcg/file_format/v1/param_sync.h" + +namespace FlexFlow { + +V1ParamSync to_v1(ParamSync const &p) { + switch (p) { + case ParamSync::PS: + return V1ParamSync::PARAM_SERVER; + case ParamSync::NCCL: + return V1ParamSync::NCCL; + default: + NOT_REACHABLE(); + }; +} + +ParamSync from_v1(V1ParamSync const &vp) { + switch (vp) { + case V1ParamSync::PARAM_SERVER: + return ParamSync::PS; + case V1ParamSync::NCCL: + return ParamSync::NCCL; + default: + NOT_REACHABLE(); + }; +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/file_format/v1/tensor.cc b/lib/pcg/src/file_format/v1/tensor.cc new file mode 100644 index 0000000000..ff4c3d84c8 --- /dev/null +++ b/lib/pcg/src/file_format/v1/tensor.cc @@ -0,0 +1,24 @@ +#include "pcg/file_format/v1/tensor.h" +#include "pcg/file_format/v1/v1.h" + +namespace FlexFlow { + +V1Tensor to_v1(Tensor const 
&t) { + return {to_v1(t.get_shape()), + to_v1(t.create_gradients), + to_v1(t.initializer), + to_v1(t.sync_type), + t.name}; +} + +Tensor from_v1(V1Tensor const &vt) { + TensorShape shape = from_v1(vt.shape); + return {shape.dims, + shape.data_type, + from_v1(vt.create_gradients), + from_v1(vt.initializer), + from_v1(vt.sync_type), + vt.name}; +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/file_format/v1/tensor_shape.cc b/lib/pcg/src/file_format/v1/tensor_shape.cc new file mode 100644 index 0000000000..ec1419d65e --- /dev/null +++ b/lib/pcg/src/file_format/v1/tensor_shape.cc @@ -0,0 +1,15 @@ +#include "pcg/file_format/v1/tensor_shape.h" + +namespace FlexFlow { + +V1TensorShape to_v1(TensorShape const &shape) { + return {std::vector(shape.dims.begin(), shape.dims.end()), + to_v1(shape.data_type)}; +} + +TensorShape from_v1(V1TensorShape const &vshp) { + return {TensorDims(vshp.dims.begin(), vshp.dims.end()), + from_v1(vshp.data_type)}; +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/file_format/v1/v1.cc b/lib/pcg/src/file_format/v1/v1.cc index 530d3955ec..3f32837609 100644 --- a/lib/pcg/src/file_format/v1/v1.cc +++ b/lib/pcg/src/file_format/v1/v1.cc @@ -1,3 +1,13 @@ #include "pcg/file_format/v1/v1.h" -namespace FlexFlow {} +namespace FlexFlow { + +std::string to_v1(std::string const &s) { + return s; +} + +std::string from_v1(std::string const &vs) { + return vs; +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/layer.cc b/lib/pcg/src/layer.cc deleted file mode 100644 index 27d5b31003..0000000000 --- a/lib/pcg/src/layer.cc +++ /dev/null @@ -1,9 +0,0 @@ -#include "pcg/layer.h" - -namespace FlexFlow { - -Layer::Layer(CompGraphOperatorAttrs const &_attrs, - optional const &_name) - : attrs(_attrs), name(_name) {} - -} // namespace FlexFlow diff --git a/lib/pcg/src/machine_view.cc b/lib/pcg/src/machine_view.cc index 46f87833f0..9edfb09a8e 100644 --- a/lib/pcg/src/machine_view.cc +++ b/lib/pcg/src/machine_view.cc @@ -3,6 +3,9 @@ namespace FlexFlow { 
+MachineView::MachineView(device_id_t const &start, StridedRectangle const &rect)
+    : start(start), rect(rect) {} // parameters shadow the members they initialize
+
 static StridedRectangle make_1d_rect(int start, int stop, int stride) {
   assert(stop > start);
   assert(stride > 0);
diff --git a/lib/pcg/src/parallel_tensor.cc b/lib/pcg/src/parallel_tensor.cc
new file mode 100644
index 0000000000..f63f31224a
--- /dev/null
+++ b/lib/pcg/src/parallel_tensor.cc
@@ -0,0 +1,21 @@
+#include "pcg/parallel_tensor.h"
+
+namespace FlexFlow {
+// Stub: not implemented yet; calling it hits NOT_IMPLEMENTED().
+size_t ParallelTensor::get_volume() const {
+  NOT_IMPLEMENTED();
+}
+// Packs the tensor's dims and data_type members into a shape.
+ParallelTensorShape ParallelTensor::get_shape() const {
+  return {dims, data_type};
+}
+// Stub: not implemented yet; calling it hits NOT_IMPLEMENTED().
+int ParallelTensor::num_dims() const {
+  NOT_IMPLEMENTED();
+}
+// Conversion operator; forwards to get_shape().
+ParallelTensor::operator ParallelTensorShape() const {
+  return get_shape();
+}
+
+} // namespace FlexFlow
diff --git a/lib/pcg/src/tensor.cc b/lib/pcg/src/tensor.cc
new file mode 100644
index 0000000000..f5de632c71
--- /dev/null
+++ b/lib/pcg/src/tensor.cc
@@ -0,0 +1,20 @@
+#include "pcg/tensor.h"
+
+namespace FlexFlow {
+size_t Tensor::get_volume() const {
+  NOT_IMPLEMENTED(); // stub: not implemented yet
+}
+// Packs the tensor's dims and data_type members into a shape.
+TensorShape Tensor::get_shape() const {
+  return {dims, data_type};
+}
+// Stub: not implemented yet; calling it hits NOT_IMPLEMENTED().
+int Tensor::num_dims() const {
+  NOT_IMPLEMENTED();
+}
+// Conversion operator; forwards to get_shape().
+Tensor::operator TensorShape() const {
+  return get_shape();
+}
+
+} // namespace FlexFlow
diff --git a/lib/pcg/test/CMakeLists.txt b/lib/pcg/test/CMakeLists.txt
new file mode 100644
index 0000000000..6dc844a25b
--- /dev/null
+++ b/lib/pcg/test/CMakeLists.txt
@@ -0,0 +1,26 @@
+set(test_target pcg-test)
+project(${test_target})
+# Every *.cc under this directory becomes part of the test executable.
+file(GLOB_RECURSE SRC
+  CONFIGURE_DEPENDS
+  LIST_DIRECTORIES False
+  *.cc)
+
+add_executable(
+  ${test_target}
+  ${SRC})
+
+define_ff_vars(${test_target})
+# The tests link against the pcg and utils libraries and doctest.
+target_link_libraries(
+  ${test_target}
+  pcg
+  utils
+  doctest::doctest)
+# Compiler-specific language extensions are switched off for this target.
+set_target_properties(
+  ${test_target}
+  PROPERTIES
+  CXX_EXTENSIONS NO
+)
+doctest_discover_tests(${test_target})
diff --git a/lib/pcg/test/doctest.h b/lib/pcg/test/doctest.h
new file mode 100644
index 0000000000..1712a05a5f --- /dev/null +++ b/lib/pcg/test/doctest.h @@ -0,0 +1,70 @@ +#include "doctest/doctest.h" +#include "utils/containers.h" +#include +#include +#include +#include + +namespace doctest { + +template +std::string + doctest_print_container(InputIt first, + InputIt last, + std::string const &open, + std::string const &delimiter, + std::string const &close, + std::function const &f) { + if (first == last) { + return open + "(empty)" + close; + } else { + return open + FlexFlow::join_strings(first, last, delimiter, f) + close; + } +} + +template +std::string doctest_print_container(InputIt first, + InputIt last, + std::string const &open, + std::string const &delimiter, + std::string const &close) { + return doctest_print_container( + first, last, open, delimiter, close, [](InputIt ref) { return *ref; }); +} + +template +std::string doctest_print_container(Container const &c, + std::string const &open, + std::string const &delimiter, + std::string const &close) { + return doctest_print_container( + c.cbegin(), c.cend(), open, delimiter, close); +} + +template +struct StringMaker> { + static String convert(std::unordered_set const &s) { + return doctest_print_container(s, "{ ", ", ", " }").c_str(); + } +}; + +template +struct StringMaker> { + static String convert(std::unordered_map const &m) { + std::unordered_set entries; + for (auto const &kv : m) { + std::ostringstream oss; + oss << toString(kv.first) << " -> " << toString(kv.second); + entries.insert(oss.str()); + } + return toString(entries); + } +}; + +template +struct StringMaker> { + static String convert(std::vector const &vec) { + return doctest_print_container(vec, "[ ", ", ", " ]").c_str(); + } +}; +} // namespace doctest diff --git a/lib/pcg/test/main.cc b/lib/pcg/test/main.cc new file mode 100644 index 0000000000..9522fa7fdb --- /dev/null +++ b/lib/pcg/test/main.cc @@ -0,0 +1,2 @@ +#define DOCTEST_CONFIG_IMPLEMENT_WITH_MAIN +#include "doctest/doctest.h" diff --git 
a/lib/pcg/test/test_activation.cc b/lib/pcg/test/test_activation.cc new file mode 100644 index 0000000000..0e9210f7dd --- /dev/null +++ b/lib/pcg/test/test_activation.cc @@ -0,0 +1,23 @@ +#include "doctest.h" +#include "pcg/file_format/v1/activation.h" +#include "utils.h" + +using namespace FlexFlow; + +TEST_CASE("Activation") { + V1Activation v10 = to_v1(Activation::RELU); + CHECK(from_v1(v10) == Activation::RELU); + CHECK(str(json(v10)) == "\"RELU\""); + + V1Activation v11 = to_v1(Activation::SIGMOID); + CHECK(from_v1(v11) == Activation::SIGMOID); + CHECK(str(json(v11)) == "\"SIGMOID\""); + + V1Activation v12 = to_v1(Activation::TANH); + CHECK(from_v1(v12) == Activation::TANH); + CHECK(str(json(v12)) == "\"TANH\""); + + V1Activation v13 = to_v1(Activation::GELU); + CHECK(from_v1(v13) == Activation::GELU); + CHECK(str(json(v13)) == "\"GELU\""); +} diff --git a/lib/pcg/test/test_create_grad.cc b/lib/pcg/test/test_create_grad.cc new file mode 100644 index 0000000000..ffffafd3e7 --- /dev/null +++ b/lib/pcg/test/test_create_grad.cc @@ -0,0 +1,15 @@ +#include "doctest.h" +#include "pcg/file_format/v1/create_grad.h" +#include "utils.h" + +using namespace FlexFlow; + +TEST_CASE("CreateGrad") { + V1CreateGrad v10 = to_v1(CreateGrad::YES); + CHECK(from_v1(v10) == CreateGrad::YES); + CHECK(str(json(v10)) == "\"YES\""); + + V1CreateGrad v11 = to_v1(CreateGrad::NO); + CHECK(from_v1(v11) == CreateGrad::NO); + CHECK(str(json(v11)) == "\"NO\""); +} diff --git a/lib/pcg/test/test_datatype.cc b/lib/pcg/test/test_datatype.cc new file mode 100644 index 0000000000..b39483ce5b --- /dev/null +++ b/lib/pcg/test/test_datatype.cc @@ -0,0 +1,37 @@ +#include "doctest.h" +#include "pcg/file_format/v1/datatype.h" +#include "utils.h" + +using namespace FlexFlow; + +#define TEST_MEMBER(m, exp) \ + do { \ + V1DataType v10 = to_v1(m); \ + CHECK(from_v1(v10) == m); \ + CHECK(str(json(v10)) == "\"" exp "\""); \ + } while (0) + +TEST_CASE("DataType") { + TEST_MEMBER(DataType::BOOL, "BOOL"); + 
TEST_MEMBER(DataType::INT32, "INT32"); + TEST_MEMBER(DataType::INT64, "INT64"); + TEST_MEMBER(DataType::HALF, "HALF"); + TEST_MEMBER(DataType::FLOAT, "FLOAT"); + TEST_MEMBER(DataType::DOUBLE, "DOUBLE"); +} + +TEST_CASE("DataTypeValue") { + DataTypeValue b = true; + DataTypeValue i32 = (int32_t)32; + DataTypeValue i64 = (int64_t)64L; + DataTypeValue f16 = half(3.14); + DataTypeValue f32 = (float)2.71828; + DataTypeValue f64 = (double)1.414235; + + CHECK(from_v1(to_v1(b)) == b); + CHECK(from_v1(to_v1(i32)) == i32); + CHECK(from_v1(to_v1(i64)) == i64); + CHECK(from_v1(to_v1(f16)) == f16); + CHECK(from_v1(to_v1(f32)) == f32); + CHECK(from_v1(to_v1(f64)) == f64); +} diff --git a/lib/pcg/test/test_ff_dim_t.cc b/lib/pcg/test/test_ff_dim_t.cc new file mode 100644 index 0000000000..8c10173621 --- /dev/null +++ b/lib/pcg/test/test_ff_dim_t.cc @@ -0,0 +1,9 @@ +#include "doctest.h" +#include "pcg/file_format/v1/ff_dim.h" +#include "utils.h" + +using namespace FlexFlow; + +TEST_CASE("ff_dim_t") { + CHECK(from_v1(to_v1(ff_dim_t(11))) == 11); +} diff --git a/lib/pcg/test/test_initializer.cc b/lib/pcg/test/test_initializer.cc new file mode 100644 index 0000000000..ed94c8946e --- /dev/null +++ b/lib/pcg/test/test_initializer.cc @@ -0,0 +1,84 @@ +#include "doctest.h" +#include "pcg/file_format/v1/initializer.h" +#include "utils.h" +#include "utils/containers.h" +#include "utils/required.h" + +using namespace FlexFlow; + +TEST_CASE("GlorotInitializer") { + GlorotUniform i = {11}; + V1GlorotInitializer v1 = to_v1(i); + + // CHECK(from_v1(v1) == i); + + json j = v1; + check_fields(j, {{"seed", "11"}}); +} + +TEST_CASE("ZeroInitializer") { + ZeroInitializer i; + V1ZeroInitializer v1 = to_v1(i); + + // CHECK(from_v1(v1) == i); + + json j = v1; + check_fields(j, {}); +} + +TEST_CASE("UniformInitializer") { + UniformInitializer i = {77, 9.1, 4.3}; + V1UniformInitializer v1 = to_v1(i); + + // CHECK(from_v1(v1) == i); + + json j = v1; + check_fields(j, {{"seed", "77"}, {"min_val", "9.1"}, 
{"max_val", "4.3"}}); +} + +TEST_CASE("NormInitializer") { + NormInitializer i = {77, 9.1, 4.3}; + V1NormInitializer v1 = to_v1(i); + + // CHECK(from_v1(v1) == i); + + json j = v1; + check_fields(j, {{"seed", "77"}, {"mean", "9.1"}, {"stddev", "4.3"}}); +} + +TEST_CASE("ConstantInitializer") { + ConstantInitializer i = {32}; + V1ConstantInitializer v1 = to_v1(i); + + // CHECK(from_v1(v1) == i); + + json j = v1; + // The value field is a variant. Don't try to check for anything in the + // serialization because that will have been tested elsewhere. Just check that + // the value of "value" is an object which is good enough. + check_fields(j, {{"value", "{"}}); +} + +TEST_CASE("Initializer") { + Initializer ig = GlorotUniform{11}; + V1Initializer v1g = to_v1(ig); + // CHECK(from_v1(v1g) == ig); + + Initializer iz = ZeroInitializer{}; + V1Initializer v1z = to_v1(iz); + // CHECK(from_v1(v1z) == iz); + + Initializer iu = UniformInitializer{77, 9.1, 4.3}; + V1Initializer v1u = to_v1(iu); + // CHECK(from_v1(v1u) == iu); + + Initializer in = NormInitializer{77, 9.1, 4.3}; + V1Initializer v1n = to_v1(in); + // CHECK(from_v1(v1n) == in); + + Initializer ic = ConstantInitializer{32}; + V1Initializer v1c = to_v1(ic); + // CHECK(from_v1(v1c) == ic); + + // No need to check the JSON because Initializer is just a variant. 
+} diff --git a/lib/pcg/test/test_op.cc b/lib/pcg/test/test_op.cc new file mode 100644 index 0000000000..ddaa975936 --- /dev/null +++ b/lib/pcg/test/test_op.cc @@ -0,0 +1,107 @@ +#include "doctest.h" +#include "pcg/file_format/v1/op.h" +#include "utils.h" +#include "utils/json.h" + +using namespace FlexFlow; + +#define TEST(m) \ + do { \ + V1Op v1 = to_v1(Op::m); \ + \ + CHECK(from_v1(v1) == Op::m); \ + \ + json j = v1; \ + CHECK(str(j) == "\"" #m "\""); \ + /* TODO: Check deserialization.*/ \ + } while (0) + +TEST_CASE("V1Op") { + TEST(NOOP); + TEST(INPUT); + TEST(WEIGHT); + TEST(CONV2D); + TEST(DROPOUT); + TEST(LINEAR); + TEST(BATCHMATMUL); + TEST(POOL2D); + TEST(SCALAR_MULTIPLY); + TEST(SCALAR_ADD); + TEST(SCALAR_FLOOR_DIV); + TEST(SCALAR_TRUE_DIV); + TEST(SCALAR_SUB); + TEST(RELU); + TEST(IDENTITY); + TEST(SIGMOID); + TEST(TANH); + TEST(ELU); + TEST(FLAT); + TEST(SOFTMAX); + TEST(BATCHNORM); + TEST(CONCAT); + TEST(SPLIT); + TEST(EMBEDDING); + TEST(GROUP_BY); + TEST(CACHE); + TEST(AGGREGATE); + TEST(AGG_SPEC); + // TEST(OP_ELEMENTWISE) + TEST(RESHAPE); + TEST(REVERSE); + TEST(TRANSPOSE); + TEST(EW_ADD); + TEST(EW_MUL); + TEST(MATMUL); + TEST(MUL); + TEST(ENLARGE); + TEST(SQUEEZE); + TEST(UNSQUEEZE); + TEST(EW_SUB); + TEST(EW_DIV); + TEST(EW_EQUAL); + TEST(EW_GREATER); + TEST(EW_LESS); + TEST(EW_MAX); + TEST(EW_MIN); + TEST(REDUCE_ARGMAX); + TEST(REDUCE_ARGMIN); + TEST(REDUCE_MAX); + TEST(REDUCE_MEAN); + TEST(REDUCE_MIN); + TEST(REDUCE_PROD); + TEST(REDUCE_SUM); + TEST(PAD); + TEST(SHAPE); + TEST(SIZE); + TEST(TOPK); + TEST(WHERE); + TEST(CEIL); + TEST(CAST); + TEST(EXP); + TEST(ROUND); + TEST(LOG); + TEST(LOGICAL_NOT); + TEST(SQRT); + TEST(SIN); + TEST(COS); + TEST(LEAKYRELU); + TEST(SLICE); + TEST(RESIZE); + TEST(PRELU); + TEST(GELU); + TEST(MULTIHEAD_ATTENTION); + TEST(FUSED); + TEST(RSQRT); + TEST(POW); + TEST(MEAN); + TEST(LAYERNORM); + TEST(GATHER); + TEST(BROADCAST); + TEST(REPARTITION); + TEST(COMBINE); + TEST(REPLICATE); + TEST(REDUCTION); + TEST(BATCH); 
+ TEST(PIPELINE); + TEST(FUSED_PARALLEL); +} diff --git a/lib/pcg/test/test_operator_attrs.cc b/lib/pcg/test/test_operator_attrs.cc new file mode 100644 index 0000000000..8bfe7ec66c --- /dev/null +++ b/lib/pcg/test/test_operator_attrs.cc @@ -0,0 +1,479 @@ +#include "doctest.h" +#include "pcg/file_format/v1/operator_attrs.h" +#include "utils.h" +#include "utils/containers.h" +#include "utils/json.h" +#include "utils/required.h" + +using namespace FlexFlow; + +// FIXME: Check deserialization as well. This is currently not implemented +// because of a bug that prevents req from being properly deserialized. +// +// The comments below may apply to multiple test cases. +// +// The checks for the attributes compare the number of fields to ensure that if +// a field is added/removed, the check fails since it may be necessary to update +// the test in that case. +// +// Floating point numbers may be serialized with additional digits. Just check +// that the digits that were provided in the initialization are present. Is it +// even guaranteed that those digits will appear in the serialized result? + +TEST_CASE("AggregateAttrs") { + AggregateAttrs a = {42, 3.14}; + V1AggregateAttrs v1 = to_v1(a); + + CHECK(from_v1(v1) == a); + + json j = v1; + check_fields(j, {{"lambda_bal", "3.14"}, {"n", "42"}}); + // TODO: Check deserialization. +} + +TEST_CASE("AggregateSpec") { + AggregateSpecAttrs a = {42, 3.14}; + V1AggregateSpecAttrs v1 = to_v1(a); + + CHECK(from_v1(v1) == a); + + json j = v1; + check_fields(j, {{"lambda_bal", "3.14"}, {"n", "42"}}); + // TODO: Check deserialization. 
+} + +TEST_CASE("MultiHeadAttentionAttrs") { + MultiHeadAttentionAttrs a = {1, 2, 3, 4, 5.67, false, true, false}; + V1MultiHeadAttentionAttrs v1 = to_v1(a); + + CHECK(from_v1(v1) == a); + + json j = v1; + check_fields(j, + {{"embed_dim", "1"}, + {"num_heads", "2"}, + {"kdim", "3"}, + {"vdim", "4"}, + {"dropout", "5.67"}, + {"bias", "false"}, + {"add_bias_kv", "true"}, + {"add_zero_attn", "false"}}); + // TODO: Check deserialization. +} + +TEST_CASE("BatchMatmulAttrs") { + BatchMatmulAttrs a = {12, 34}; + V1BatchMatmulAttrs v1 = to_v1(a); + + CHECK(from_v1(v1) == a); + + json j = v1; + check_fields(j, {{"a_seq_length_dim", "12"}, {"b_seq_length_dim", "34"}}); + // TODO: Check deserialization. +} + +TEST_CASE("BatchNormAttrs") { + BatchNormAttrs a = {true}; + V1BatchNormAttrs v1 = to_v1(a); + + CHECK(from_v1(v1) == a); + + json j = v1; + check_fields(j, {{"relu", "true"}}); + // TODO: Check deserialization. +} + +TEST_CASE("BroadcastAttrs") { + BroadcastAttrs a = {stack_vector()}; + a.target_dims.push_back(1); + a.target_dims.push_back(2); + a.target_dims.push_back(3); + a.target_dims.push_back(4); + V1BroadcastAttrs v1 = to_v1(a); + + CHECK(from_v1(v1) == a); + + json j = v1; + check_fields(j, {{"target_dims", "[1,2,3,4]"}}); + // TODO: Check deserialization. +} + +TEST_CASE("CastAttrs") { + CastAttrs a = {DataType::HALF}; + V1CastAttrs v1 = to_v1(a); + + CHECK(from_v1(v1) == a); + + json j = v1; + check_fields(j, {{"dtype", "\"HALF\""}}); + // TODO: Check deserialization. +} + +TEST_CASE("CombineAttrs") { + CombineAttrs a = {ff_dim_t(1), 2}; + V1CombineAttrs v1 = to_v1(a); + + CHECK(from_v1(v1) == a); + + json j = v1; + check_fields(j, {{"combine_dim", "1"}, {"combine_degree", "2"}}); + // TODO: Check deserialization. +} + +TEST_CASE("ConcatAttrs") { + ConcatAttrs a = {ff_dim_t(43)}; + V1ConcatAttrs v1 = to_v1(a); + + CHECK(from_v1(v1) == a); + + json j = v1; + check_fields(j, {{"axis", "43"}}); + // TODO: Check deserialization. 
+} + +TEST_CASE("Conv2DAttrs") { + Conv2DAttrs a = {1, 2, 3, 4, 5, 6, 7, 8, Activation::SIGMOID, false}; + V1Conv2DAttrs v1 = to_v1(a); + + CHECK(from_v1(v1) == a); + + json j = v1; + check_fields(j, + {{"out_channels", "1"}, + {"kernel_h", "2"}, + {"kernel_w", "3"}, + {"stride_h", "4"}, + {"stride_w", "5"}, + {"padding_h", "6"}, + {"padding_w", "7"}, + {"groups", "8"}, + {"activation", "\"SIGMOID\""}, + {"use_bias", "false"}}); + // TODO: Check deserialization. +} + +TEST_CASE("DropoutAttrs") { + DropoutAttrs a = {3.14, 9823749238472398ULL}; + V1DropoutAttrs v1 = to_v1(a); + + CHECK(from_v1(v1) == a); + + json j = v1; + check_fields(j, {{"rate", "3.14"}, {"seed", "9823749238472398"}}); + // TODO: Check deserialization. +} + +TEST_CASE("ElementBinaryAttrs") { + ElementBinaryAttrs a = {Op::SQUEEZE, DataType::FLOAT, false, true}; + V1ElementBinaryAttrs v1 = to_v1(a); + + CHECK(from_v1(v1) == a); + + json j = v1; + check_fields( + j, {{"should_broadcast_lhs", "false"}, {"should_broadcast_rhs", "true"}}); + // TODO: Check deserialization. +} + +TEST_CASE("ElementUnaryAttrs") { + ElementUnaryAttrs a = {Op::LOGICAL_NOT}; + V1ElementUnaryAttrs v1 = to_v1(a); + + CHECK(from_v1(v1) == a); + + json j = v1; + check_fields(j, {{"op", "\"LOGICAL_NOT\""}}); + // TODO: Check deserialization. +} + +TEST_CASE("ElementUnaryScalarAttrs") { + ElementScalarUnaryAttrs a = {Op::SCALAR_FLOOR_DIV, 2.71828}; + V1ElementScalarUnaryAttrs v1 = to_v1(a); + + CHECK(from_v1(v1) == a); + + json j = v1; + check_fields(j, {{"op", "\"SCALAR_FLOOR_DIV\""}, {"scalar", "2.71828"}}); + // TODO: Check deserialization. 
+} + +TEST_CASE("AggregateOp") { + V1AggregateOp v1Sum = to_v1(AggregateOp::SUM); + CHECK(from_v1(v1Sum) == AggregateOp::SUM); + CHECK(str(json(v1Sum)) == "\"SUM\""); + + V1AggregateOp v1Avg = to_v1(AggregateOp::AVG); + CHECK(from_v1(v1Avg) == AggregateOp::AVG); + CHECK(str(json(v1Avg)) == "\"AVG\""); +} + +TEST_CASE("EmbeddingAttrs") { + EmbeddingAttrs a = {1, 2, AggregateOp::SUM, DataType::DOUBLE}; + V1EmbeddingAttrs v1 = to_v1(a); + + CHECK(from_v1(v1) == a); + + json j = v1; + check_fields(j, + {{"num_entries", "1"}, + {"out_channels", "2"}, + {"aggr", "\"SUM\""}, + {"data_type", "\"DOUBLE\""}}); + // TODO: Check deserialization. +} + +TEST_CASE("FlatAttrs") { + FlatAttrs a; + V1FlatAttrs v1 = to_v1(a); + + CHECK(from_v1(v1) == a); + + json j = v1; + check_fields(j, {}); + // TODO: Check deserialization. +} + +TEST_CASE("GatherAttrs") { + GatherAttrs a = {ff_dim_t(42)}; + V1GatherAttrs v1 = to_v1(a); + + CHECK(from_v1(v1) == a); + + json j = v1; + check_fields(j, {{"dim", "42"}}); + // TODO: Check deserialization. +} + +TEST_CASE("Group_byAttrs") { + Group_byAttrs a = {11, 3.14}; + V1Group_byAttrs v1 = to_v1(a); + + CHECK(from_v1(v1) == a); + + json j = v1; + check_fields(j, {{"n", "11"}, {"alpha", "3.14"}}); + // TODO: Check deserialization. +} + +TEST_CASE("InputAttrs") { + InputAttrs a; + V1InputAttrs v1 = to_v1(a); + + CHECK(from_v1(v1) == a); + + json j = v1; + check_fields(j, {}); + // TODO: Check deserialization. +} + +TEST_CASE("LayerNormAttrs") { + LayerNormAttrs a = {stack_vector(), false, 2.71828}; + a.axes.push_back(ff_dim_t(19)); + a.axes.push_back(ff_dim_t(29)); + a.axes.push_back(ff_dim_t(39)); + V1LayerNormAttrs v1 = to_v1(a); + + CHECK(from_v1(v1) == a); + + json j = v1; + check_fields(j, + {{"axes", "[19,29,39]"}, + {"elementwise_affine", "false"}, + {"eps", "2.71828"}}); + // TODO: Check deserialization. 
+} + +TEST_CASE("L1RegularizerAttrs") { + L1RegularizerAttrs a = {3.14159}; + V1L1RegularizerAttrs v1 = to_v1(a); + + CHECK(from_v1(v1) == a); + + json j = v1; + check_fields(j, {{"lambda", "3.14159"}}); + // TODO: Check deserialization. +} + +TEST_CASE("L2RegularizerAttrs") { + L2RegularizerAttrs a = {3.14159}; + V1L2RegularizerAttrs v1 = to_v1(a); + + CHECK(from_v1(v1) == a); + + json j = v1; + check_fields(j, {{"lambda", "3.14159"}}); + // TODO: Check deserialization. +} + +TEST_CASE("LinearAttrs") { + L1RegularizerAttrs r = {1234.567}; + LinearAttrs a = {11, false, DataType::HALF, Activation::TANH, r}; + V1LinearAttrs v1 = to_v1(a); + + CHECK(from_v1(v1) == a); + + json j = v1; + check_fields(j, + {{"out_channels", "11"}, + {"use_bias", "false"}, + {"data_type", "\"HALF\""}, + {"activation", "\"TANH\""}, + {"regularizer", + "{\"index\":0,\"type\":\"::FlexFlow::V1L1RegularizerAttrs\"," + "\"value\":{\"lambda\":1234.567"}}); + // TODO: Check deserialization. +} + +TEST_CASE("NoopAttrs") { + NoopAttrs a; + V1NoopAttrs v1 = to_v1(a); + + CHECK(from_v1(v1) == a); + + json j = v1; + check_fields(j, {}); + // TODO: Check deserialization. +} + +TEST_CASE("PoolOp") { + V1PoolOp v1Max = to_v1(PoolOp::MAX); + CHECK(from_v1(v1Max) == PoolOp::MAX); + CHECK(str(json(v1Max)) == "\"MAX\""); + + V1PoolOp v1Avg = to_v1(PoolOp::AVG); + CHECK(from_v1(v1Avg) == PoolOp::AVG); + CHECK(str(json(v1Avg)) == "\"AVG\""); +} + +TEST_CASE("Pool2DAttrs") { + Pool2DAttrs a = {1, 2, 3, 4, 5, 6, PoolOp::MAX, Activation::RELU}; + V1Pool2DAttrs v1 = to_v1(a); + + CHECK(from_v1(v1) == a); + + json j = v1; + check_fields(j, + {{"kernel_h", "1"}, + {"kernel_w", "2"}, + {"stride_h", "3"}, + {"stride_w", "4"}, + {"padding_h", "5"}, + {"padding_w", "6"}, + {"pool_type", "\"MAX\""}, + {"activation", "\"RELU\""}}); + // TODO: Check deserialization. 
+} + +TEST_CASE("ReduceAttrs") { + ReduceAttrs a = { + stack_vector(), Op::LEAKYRELU, true}; + a.axes.push_back(ff_dim_t(19)); + a.axes.push_back(ff_dim_t(29)); + a.axes.push_back(ff_dim_t(39)); + V1ReduceAttrs v1 = to_v1(a); + + CHECK(from_v1(v1) == a); + + json j = v1; + check_fields(j, + {{"axes", "[19,29,39]"}, + {"op_type", "\"LEAKYRELU\""}, + {"keepdims", "true"}}); + // TODO: Check deserialization. +} + +TEST_CASE("ReductionAttrs") { + ReductionAttrs a = {ff_dim_t(66), ff_dim_t(77)}; + V1ReductionAttrs v1 = to_v1(a); + + CHECK(from_v1(v1) == a); + + json j = v1; + check_fields(j, {{"reduction_dim", "66"}, {"reduction_degree", "77"}}); + // TODO: Check deserialization. +} + +TEST_CASE("RepartitionAttrs") { + RepartitionAttrs a = {ff_dim_t(66), ff_dim_t(77)}; + V1RepartitionAttrs v1 = to_v1(a); + + CHECK(from_v1(v1) == a); + + json j = v1; + check_fields(j, {{"repartition_dim", "66"}, {"repartition_degree", "77"}}); + // TODO: Check deserialization. +} + +TEST_CASE("ReplicateAttrs") { + ReplicateAttrs a = {ff_dim_t(66), ff_dim_t(77)}; + V1ReplicateAttrs v1 = to_v1(a); + + CHECK(from_v1(v1) == a); + + json j = v1; + check_fields(j, {{"replicate_dim", "66"}, {"replicate_degree", "77"}}); + // TODO: Check deserialization. +} + +TEST_CASE("ReshapeAttrs") { + // TODO: IMPLEMENT THIS. +} + +TEST_CASE("ReverseAttrs") { + ReverseAttrs a = {ff_dim_t(11)}; + V1ReverseAttrs v1 = to_v1(a); + + CHECK(from_v1(v1) == a); + + json j = v1; + check_fields(j, {{"axis", "11"}}); + // TODO: Check deserialization. +} + +TEST_CASE("SoftmaxAttrs") { + SoftmaxAttrs a = {ff_dim_t(37)}; + V1SoftmaxAttrs v1 = to_v1(a); + + CHECK(from_v1(v1) == a); + + json j = v1; + check_fields(j, {{"dim", "37"}}); + // TODO: Check deserialization. 
+} + +TEST_CASE("SplitAttrs") { + SplitAttrs a = {stack_vector(), ff_dim_t(97)}; + a.splits.push_back(53); + a.splits.push_back(67); + V1SplitAttrs v1 = to_v1(a); + + CHECK(from_v1(v1) == a); + + json j = v1; + check_fields(j, {{"splits", "[53,67]"}, {"axis", "97"}}); + // TODO: Check deserialization. +} + +TEST_CASE("TopKAttrs") { + TopKAttrs a = {17, true}; + V1TopKAttrs v1 = to_v1(a); + + CHECK(from_v1(v1) == a); + + json j = v1; + check_fields(j, {{"k", "17"}, {"sorted", "true"}}); + // TODO: Check deserialization. +} + +TEST_CASE("TransposeAttrs") { + TransposeAttrs a = {stack_vector()}; + a.perm.push_back(ff_dim_t(3)); + a.perm.push_back(ff_dim_t(43)); + V1TransposeAttrs v1 = to_v1(a); + + CHECK(from_v1(v1) == a); + + json j = v1; + check_fields(j, {{"perm", "[3,43]"}}); + // TODO: Check deserialization. +} diff --git a/lib/pcg/test/test_parallel_tensor.cc b/lib/pcg/test/test_parallel_tensor.cc new file mode 100644 index 0000000000..a0123b8cee --- /dev/null +++ b/lib/pcg/test/test_parallel_tensor.cc @@ -0,0 +1,102 @@ +#include "doctest.h" +#include "pcg/file_format/v1/parallel_tensor.h" +#include "pcg/parallel_tensor.h" +#include "utils.h" +#include "utils/json.h" +#include "utils/required.h" + +using namespace FlexFlow; + +TEST_CASE("ParallelDim") { + ParallelDim d{4, 11, false}; + V1ParallelDim v1 = to_v1(d); + + CHECK(from_v1(v1) == d); + + json j = v1; + check_fields(j, {{"size", "4"}, + {"degree", "11"}, + {"is_replica_dim", "false"}}); + // TODO: Check deserialization. +} + +TEST_CASE("ParallelTensorDims") { + ParallelTensorDims d(std::vector{{3, 11, false}, {4, 2, true}}); + V1ParallelTensorDims v1 = to_v1(d); + + CHECK(from_v1(v1) == d); + + json j = v1; + + // Currently, there isn't a great way to check the actual JSON since the + // elements of the array will also be objects whose fields may be in a + // different order. 
The test utilities are ok for checking that the fields of + // an object were serialized but they are currently not up to checking if + // an array of objects was serialized. So just check that it is an array of + // objects that was serialized. + std::string strj = str(j); + CHECK(strj.find("{\"data\":[{") == 0); + CHECK(strj.substr(strj.size() - 3, 3) == "}]}"); + + // TODO: Check deserialization. +} + +TEST_CASE("ParallelTensorShape") { + ParallelTensorShape t{std::vector{{3, 11, false}, {4, 2, true}}, + DataType::FLOAT}; + V1ParallelTensorShape v1 = to_v1(t); + + CHECK(from_v1(v1) == t); + + json j = v1; + + // We can't check the serialization of the dims (see comment in the test case + // for ParallelTensorDims) + check_fields(j, {{"dims", "{\"data\":[{"}, {"data_type", "\"FLOAT\""}}); + + // TODO: Check deserialization. +} + +TEST_CASE("ParallelTensor") { + ParallelTensor t{std::vector{{3, 11, false}, {4, 2, true}}, + DataType::FLOAT, + CreateGrad::NO, + GlorotUniform{932}, + ParamSync::PS, + std::string("tensor")}; + V1ParallelTensor v1 = to_v1(t); + + CHECK(from_v1(v1) == t); + + json j = v1; + // shape is itself an object. Since the order of the fields there may not be + // consistent, just check that the key exists. This is particularly relevant + // since there is no field named shape in the tensor object. + check_fields(j, + {{"shape", "{"}, + {"create_gradients", "\"NO\""}, + {"initializer", "{"}, + {"sync_type", "\"PARAM_SERVER\""}, + {"name", "\"tensor\""}}); + // TODO: Check deserialization. + + ParallelTensor t0{ParallelTensorDims( + std::vector{{3, 11, false}, {4, 2, true}}), + DataType::FLOAT, + CreateGrad::YES, + nullopt, + nullopt, + nullopt}; + V1ParallelTensor v10 = to_v1(t0); + + CHECK(from_v1(v10) == t0); + + json j0 = v10; + check_fields(j0, + {{"shape", "{"}, + {"create_gradients", "\"YES\""}, + {"initializer", "null"}, + {"sync_type", "null"}, + {"name", "null"}}); + // TODO: Check deserialization. 
+} diff --git a/lib/pcg/test/test_param_sync.cc b/lib/pcg/test/test_param_sync.cc new file mode 100644 index 0000000000..5115a28a2b --- /dev/null +++ b/lib/pcg/test/test_param_sync.cc @@ -0,0 +1,15 @@ +#include "doctest.h" +#include "pcg/file_format/v1/param_sync.h" +#include "utils.h" + +using namespace FlexFlow; + +TEST_CASE("ParamSync") { + V1ParamSync v10 = to_v1(ParamSync::PS); + CHECK(from_v1(v10) == ParamSync::PS); + CHECK(str(json(v10)) == "\"PARAM_SERVER\""); + + V1ParamSync v11 = to_v1(ParamSync::NCCL); + CHECK(from_v1(v11) == ParamSync::NCCL); + CHECK(str(json(v11)) == "\"NCCL\""); +} diff --git a/lib/pcg/test/test_tensor.cc b/lib/pcg/test/test_tensor.cc new file mode 100644 index 0000000000..48c6dc083a --- /dev/null +++ b/lib/pcg/test/test_tensor.cc @@ -0,0 +1,58 @@ +#include "doctest.h" +#include "pcg/file_format/v1/tensor.h" +#include "pcg/tensor.h" +#include "utils.h" +#include "utils/json.h" +#include "utils/required.h" + +using namespace FlexFlow; + +TEST_CASE("TensorShape") { + TensorShape t{{3, 4, 5}, DataType::FLOAT}; + V1TensorShape v1 = to_v1(t); + + CHECK(from_v1(v1) == t); + + json j = v1; + check_fields(j, {{"dims", "[3,4,5]"}, {"data_type", "\"FLOAT\""}}); + // TODO: Check deserialization. +} + +TEST_CASE("Tensor") { + Tensor t{{3, 4, 5}, + DataType::FLOAT, + CreateGrad::NO, + GlorotUniform{932}, + ParamSync::PS, + std::string("tensor")}; + V1Tensor v1 = to_v1(t); + + CHECK(from_v1(v1) == t); + + json j = v1; + // shape is itself an object. Since the order of the fields there may not be + // consistent, just check that the key exists. This is particularly relevant + // since there is no field named shape in the tensor object. + check_fields(j, + {{"shape", "{"}, + {"create_gradients", "\"NO\""}, + {"initializer", "{"}, + {"sync_type", "\"PARAM_SERVER\""}, + {"name", "\"tensor\""}}); + // TODO: Check deserialization. 
+ + Tensor t0{ + {3, 4, 5}, DataType::FLOAT, CreateGrad::YES, nullopt, nullopt, nullopt}; + V1Tensor v10 = to_v1(t0); + + CHECK(from_v1(v10) == t0); + + json j0 = v10; + check_fields(j0, + {{"shape", "{"}, + {"create_gradients", "\"YES\""}, + {"initializer", "null"}, + {"sync_type", "null"}, + {"name", "null"}}); + // TODO: Check deserialization. +} diff --git a/lib/pcg/test/utils.cc b/lib/pcg/test/utils.cc new file mode 100644 index 0000000000..fe6220a9c4 --- /dev/null +++ b/lib/pcg/test/utils.cc @@ -0,0 +1,28 @@ +#include "utils.h" +#include "doctest.h" +#include "utils/json.h" + +namespace FlexFlow { + +std::string str(json const &j) { + std::stringstream ss; + ss << j; + return ss.str(); +} + +void check_fields(json const &j, std::vector const &fields) { + std::string strj = str(j); + if (fields.size()) { + for (auto const &[key, val] : fields) { + std::stringstream fs; + fs << "\"" << key << "\":" << val; + std::string field = fs.str(); + + CHECK(strj.find(field) != std::string::npos); + } + } else { + CHECK(strj == "null"); + } +} + +} // namespace FlexFlow diff --git a/lib/pcg/test/utils.h b/lib/pcg/test/utils.h new file mode 100644 index 0000000000..49494662b1 --- /dev/null +++ b/lib/pcg/test/utils.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_PCG_TEST_UTILS_H +#define _FLEXFLOW_PCG_TEST_UTILS_H + +#include "utils/json.h" + +namespace FlexFlow { + +std::string str(json const &j); + +using Field = std::pair; +void check_fields(json const &j, std::vector const &fields); + +} // namespace FlexFlow + +#endif // _FLEXFLOW_PCG_TEST_UTILS_H diff --git a/lib/utils/include/utils/exception.decl.h b/lib/utils/include/utils/exception.decl.h index d27174f474..1e60e99987 100644 --- a/lib/utils/include/utils/exception.decl.h +++ b/lib/utils/include/utils/exception.decl.h @@ -9,12 +9,24 @@ namespace FlexFlow { #ifdef FF_REQUIRE_IMPLEMENTED #define NOT_IMPLEMENTED() static_assert(false, "Function not yet implemented"); #else -#define NOT_IMPLEMENTED() throw not_implemented(); 
+#define NOT_IMPLEMENTED() throw not_implemented(__FILE__, __LINE__); #endif class not_implemented : public std::logic_error { public: - not_implemented(); + not_implemented(char const *file, unsigned line); +}; + +// This macro should only be used when code is known to be unreachable under +// normal circumstances but may get hit because of a *developer* (not user) +// error. An example of this would be adding a member to an enum but not +// updating all switch statements that use the enum. It is primarily provided +// to squelch compiler warnings about such situations. +#define NOT_REACHABLE() throw not_reachable(__FILE__, __LINE__); + +class not_reachable : public std::logic_error { +public: + not_reachable(char const *file, unsigned line); }; template diff --git a/lib/utils/include/utils/fmt.h b/lib/utils/include/utils/fmt.h index 218f72d8af..bd2ea83c1f 100644 --- a/lib/utils/include/utils/fmt.h +++ b/lib/utils/include/utils/fmt.h @@ -67,7 +67,7 @@ auto formatter<::std::unordered_set>::format( -> decltype(ctx.out()) { CHECK_FMTABLE(T); - std::string result = join_strings( + std::string result = ::FlexFlow::join_strings( m.cbegin(), m.cend(), ", ", [](T const &t) { return fmt::to_string(t); }); return formatter::format(result, ctx); } @@ -78,7 +78,7 @@ auto formatter<::std::vector>::format(::std::vector const &m, FormatContext &ctx) -> decltype(ctx.out()) { CHECK_FMTABLE(T); - std::string result = join_strings( + std::string result = ::FlexFlow::join_strings( m.cbegin(), m.cend(), ", ", [](T const &t) { return fmt::to_string(t); }); return formatter::format(result, ctx); } diff --git a/lib/utils/include/utils/graph/algorithms.h b/lib/utils/include/utils/graph/algorithms.h index 9f1d2e12d5..ed18d6e1c5 100644 --- a/lib/utils/include/utils/graph/algorithms.h +++ b/lib/utils/include/utils/graph/algorithms.h @@ -233,6 +233,8 @@ tl::optional get_imm_post_dominator(MultiDiGraphView const &, Node const &); tl::optional get_imm_post_dominator(DiGraphView const &, 
std::unordered_set const &); +tl::optional get_imm_post_dominator(MultiDiGraphView const &, + std::unordered_set const &); std::vector get_dfs_ordering(DiGraphView const &, diff --git a/lib/utils/include/utils/graph/cow_ptr_t.h b/lib/utils/include/utils/graph/cow_ptr_t.h index 9a655ae072..3c5cf6ed10 100644 --- a/lib/utils/include/utils/graph/cow_ptr_t.h +++ b/lib/utils/include/utils/graph/cow_ptr_t.h @@ -11,8 +11,8 @@ namespace FlexFlow { template struct cow_ptr_t { - static_assert(is_clonable::value, - "cow_ptr_t requires the type to have a clone() method"); + // static_assert(is_clonable::value, + // "cow_ptr_t requires the type to have a clone() method"); cow_ptr_t(std::shared_ptr const &ptr) : ptr(ptr) {} cow_ptr_t(std::shared_ptr &&ptr) : ptr(std::move(ptr)) {} diff --git a/lib/utils/include/utils/graph/labelled/output_labelled.h b/lib/utils/include/utils/graph/labelled/output_labelled.h index ffb69b717d..54dc397d06 100644 --- a/lib/utils/include/utils/graph/labelled/output_labelled.h +++ b/lib/utils/include/utils/graph/labelled/output_labelled.h @@ -47,9 +47,17 @@ struct OutputLabelledMultiDiGraph { OutputLabelledMultiDiGraph & operator=(OutputLabelledMultiDiGraph const &other) = default; - operator LabelledMultiDiGraphView() const; - operator MultiDiGraphView() const; - operator GraphView() const; + operator LabelledMultiDiGraphView() const { + NOT_IMPLEMENTED(); + } + + operator MultiDiGraphView() const { + NOT_IMPLEMENTED(); + } + + operator GraphView() const { + NOT_IMPLEMENTED(); + } friend void swap(OutputLabelledMultiDiGraph &lhs, OutputLabelledMultiDiGraph &rhs) { diff --git a/lib/utils/include/utils/graph/multidigraph.h b/lib/utils/include/utils/graph/multidigraph.h index 77bd3aedea..deb32fceb0 100644 --- a/lib/utils/include/utils/graph/multidigraph.h +++ b/lib/utils/include/utils/graph/multidigraph.h @@ -47,6 +47,7 @@ struct MultiDiGraph { MultiDiGraph() = delete; MultiDiGraph(MultiDiGraph const &) = default; MultiDiGraph &operator=(MultiDiGraph 
const &) = default; + virtual ~MultiDiGraph() = default; operator MultiDiGraphView() const; diff --git a/lib/utils/include/utils/graph/multidigraph_interfaces.h b/lib/utils/include/utils/graph/multidigraph_interfaces.h index 291021ae6b..1fcd83903c 100644 --- a/lib/utils/include/utils/graph/multidigraph_interfaces.h +++ b/lib/utils/include/utils/graph/multidigraph_interfaces.h @@ -52,6 +52,7 @@ struct IMultiDiGraph : public IMultiDiGraphView, public IGraph { } virtual IMultiDiGraph *clone() const override = 0; + virtual ~IMultiDiGraph() = default; }; CHECK_RC_COPY_VIRTUAL_COMPLIANT(IMultiDiGraph); diff --git a/lib/utils/include/utils/json.h b/lib/utils/include/utils/json.h index a753c52daa..b5207d0d6d 100644 --- a/lib/utils/include/utils/json.h +++ b/lib/utils/include/utils/json.h @@ -10,6 +10,100 @@ namespace FlexFlow { +template +struct json_type_name { + static constexpr char const *name = "anonymous"; +}; + +template <> +struct json_type_name { + static constexpr char const *name = "bool"; +}; + +template <> +struct json_type_name { + static constexpr char const *name = "char"; +}; + +template <> +struct json_type_name { + static constexpr char const *name = "signed char"; +}; + +template <> +struct json_type_name { + static constexpr char const *name = "short"; +}; + +template <> +struct json_type_name { + static constexpr char const *name = "int"; +}; + +template <> +struct json_type_name { + static constexpr char const *name = "long"; +}; + +template <> +struct json_type_name { + static constexpr char const *name = "long long"; +}; + +template <> +struct json_type_name { + static constexpr char const *name = "unsigned char"; +}; + +template <> +struct json_type_name { + static constexpr char const *name = "unsigned short"; +}; + +template <> +struct json_type_name { + static constexpr char const *name = "unsigned int"; +}; + +template <> +struct json_type_name { + static constexpr char const *name = "unsigned long"; +}; + +template <> +struct json_type_name { 
+ static constexpr char const *name = "unsigned long long"; +}; + +template <> +struct json_type_name { + static constexpr char const *name = "float"; +}; + +template <> +struct json_type_name { + static constexpr char const *name = "double"; +}; + +template <> +struct json_type_name { + static constexpr char const *name = "long double"; +}; + +template +struct json_type_name< + T, + std::enable_if_t && is_visitable::value>> { + static constexpr char const *name = visit_struct::get_name(); +}; + +template +struct json_type_name< + T, + std::enable_if_t && !is_visitable::value>> { + static constexpr char const *name = "unnamed struct"; +}; + template struct is_json_serializable : std::false_type {}; @@ -143,39 +237,58 @@ struct VariantToJsonFunctor { void operator()(T const &t) { static_assert(is_jsonable::value, ""); - j["type"] = get_name(t); + // The type field is not used for deserialization. It is there primarily + // as a debugging aid if needed. + j["type"] = json_type_name::name; j["value"] = t; } }; template void variant_to_json(json &j, variant const &v) { - visit(::FlexFlow::VariantToJsonFunctor{j}, v.value); + // The index indicates the type of the variant that is serialized. The + // actual serialization of the type will be handled by the functor. The + // variant itself is lost in the visitor, so this needs to be done first. The + // type field is not used in deserialization - only the index is. 
+ j["index"] = v.index(); + visit(::FlexFlow::VariantToJsonFunctor{j}, v); } -template -struct VariantFromJsonFunctor { - VariantFromJsonFunctor(json const &j) : j(j) {} +template +optional variant_from_json_impl(json const &j) { + using Type = typename variant_alternative::type; - json const &j; - - template - optional operator()(std::integral_constant const &) const { - using Type = typename variant_alternative::type; + if (j.at("index").get() == Idx) { + return j.at("value").get(); + } + return nullopt; +} - if (visit_struct::get_name()) { - return j.at("value").get(); +template +optional variant_from_json_impl(json const &j, + std::index_sequence) { + // If there were no errors when parsing, all but one element of the array + // will be nullopt. This is because each call to variant_from_json_impl will + // have a unique index and exactly one of them will match the index in the + // json object. + std::array, sizeof...(Is)> results{ + variant_from_json_impl(j)...}; + for (optional &maybe : results) { + if (maybe) { + return maybe.value(); } } -}; + return nullopt; +} template variant variant_from_json(json const &j) { - ::FlexFlow::VariantFromJsonFunctor<::FlexFlow::variant> func(j); - auto result = seq_map(func, seq_enumerate_args_t{}); + using Variant = ::FlexFlow::variant; + optional result = variant_from_json_impl( + j, std::make_index_sequence()); if (!result.has_value()) { throw ::FlexFlow::mk_runtime_error("Invalid type {} found in json", - j.at("type").get()); + j.at("index").get()); } return result.value(); } @@ -223,7 +336,7 @@ struct adl_serializer< typename std::enable_if<::FlexFlow::is_jsonable::value>::type> { static void to_json(json &j, ::FlexFlow::optional const &t) { if (t.has_value()) { - to_json(j, t.value()); + j = t.value(); } else { j = nullptr; } diff --git a/lib/utils/include/utils/required.h b/lib/utils/include/utils/required.h index 499994770a..7af22db3d4 100644 --- a/lib/utils/include/utils/required.h +++ 
b/lib/utils/include/utils/required.h @@ -24,24 +24,32 @@ struct adl_serializer<::FlexFlow::req> { }; } // namespace nlohmann -namespace fmt { +/* namespace fmt { */ -template -struct formatter<::FlexFlow::req> : formatter { - template - auto format(::FlexFlow::req const &t, FormatContext &ctx) - -> decltype(ctx.out()) { - return formatter::format(static_cast(t), ctx); - } -}; +/* template */ +/* struct formatter<::FlexFlow::req> : formatter { */ +/* template */ +/* auto format(::FlexFlow::req const &t, FormatContext &ctx) */ +/* -> decltype(ctx.out()) { */ +/* return formatter::format(static_cast(t), ctx); */ +/* } */ +/* }; */ -} // namespace fmt +/* } // namespace fmt */ namespace FlexFlow { static_assert(is_json_serializable>::value, ""); static_assert(is_json_deserializable>::value, ""); static_assert(is_jsonable>::value, ""); -static_assert(is_fmtable>::value, ""); +CHECK_FMTABLE(req); +CHECK_FMTABLE(std::vector); +CHECK_FMTABLE(required_inheritance_impl>); +static_assert( + std::is_base_of>, + req>>::value, + ""); +CHECK_FMTABLE(req>); + } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/required_core.h b/lib/utils/include/utils/required_core.h index 3336e38243..5677e84b7d 100644 --- a/lib/utils/include/utils/required_core.h +++ b/lib/utils/include/utils/required_core.h @@ -1,12 +1,26 @@ #ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_REQUIRED_CORE_H #define _FLEXFLOW_UTILS_INCLUDE_UTILS_REQUIRED_CORE_H +#include "fmt.decl.h" #include "hash-utils-core.h" +#include "test_types.h" #include "type_traits_core.h" #include namespace FlexFlow { +template +struct enable_if_valid {}; + +template +struct enable_if_valid, Args...> : type_identity {}; + +/* required_wrapper_impl() + std::declval())>> */ +/* operator+(required_wrapper_impl const &lhs, required_wrapper_impl const + * &rhs) { */ +/* /1* return 1; *1/ */ +/* } */ + template struct required_wrapper_impl { public: @@ -34,6 +48,10 @@ struct required_wrapper_impl { return static_cast(this->m_value); } + 
friend T format_as(required_wrapper_impl const &r) { + return static_cast(r); + } + /* T const &operator*() const { */ /* return this->m_value; */ /* } */ @@ -42,12 +60,56 @@ struct required_wrapper_impl { /* return &this->m_value; */ /* } */ - /* bool operator==(T const &other) const { */ - /* return this->m_value == other; */ + template + enable_if_t::value, bool> + operator==(required_wrapper_impl const &rhs) const { + return this->m_value == rhs.m_value; + } + + template + enable_if_t::value, bool> + operator==(TT const &rhs) const { + return this->m_value == rhs; + } + + /* friend enable_if_t::value, bool> */ + /* operator==(required_wrapper_impl const &lhs, T const &rhs) { */ + /* return lhs.m_value == rhs; */ /* } */ - /* bool operator!=(T const &other) const { */ - /* return this->m_value != other; */ + /* friend enable_if_t::value, bool> */ + /* operator==(T const &lhs, required_wrapper_impl const &rhs) { */ + /* return lhs == rhs.m_value; */ + /* } */ + + template + enable_if_t::value, bool> + operator!=(required_wrapper_impl const &rhs) const { + return this->m_value != rhs.m_value; + } + + /* friend enable_if_t::value, + * required_wrapper_impl() + std::declval())>> */ + /* operator+(required_wrapper_impl const &lhs, required_wrapper_impl + * const &rhs) { */ + /* /1* return 1; *1/ */ + /* } */ + /* required_wrapper_impl */ + /* operator+(required_wrapper_impl const &rhs) { */ + /* Out o = this->m_value + rhs.m_value; */ + /* return required_wrapper_impl{o}; */ + /* } */ + + /* template ::value> = true> */ + /* required_wrapper_impl operator-(required_wrapper_impl const &rhs) { */ + /* return {this->m_value - rhs.m_value}; */ + /* } */ + + /* template ::value> = true> */ + /* required_wrapper_impl operator*(required_wrapper_impl const &rhs) { */ + /* return {this->m_value * rhs.m_value}; */ /* } */ /* bool operator<(T const &other) const { */ @@ -68,8 +130,35 @@ struct required_inheritance_impl : public T { using T::T; required_inheritance_impl() = 
delete; - required_inheritance_impl(T const &); - required_inheritance_impl(T &&t); + required_inheritance_impl(T const &t) : T(t) {} + required_inheritance_impl(T &&t) : T(t) {} + + required_inheritance_impl(required_inheritance_impl const &) = default; + required_inheritance_impl(required_inheritance_impl &&) = default; + + required_inheritance_impl & + operator=(required_inheritance_impl const &) = default; + required_inheritance_impl & + operator=(required_inheritance_impl &&) = default; + + friend enable_if_t::value, bool> + operator==(required_inheritance_impl const &lhs, + required_inheritance_impl const &rhs) { + return static_cast(lhs) == static_cast(rhs); + } + + friend enable_if_t::value, bool> + operator!=(required_inheritance_impl const &lhs, + required_inheritance_impl const &rhs) { + return static_cast(lhs) != static_cast(rhs); + } + + friend std::string format_as(required_inheritance_impl const &r) { + return ""; + /* static_assert(is_fmtable::value, ""); */ + + /* return static_cast(r); */ + } template required_inheritance_impl( @@ -78,8 +167,6 @@ struct required_inheritance_impl : public T { !std::is_same::value>::type * = 0) : required_inheritance_impl(static_cast(tt)) {} - operator T() const; - template ::value && !std::is_same::value), @@ -116,19 +203,22 @@ struct remove_req> { template using remove_req_t = typename remove_req::type; -static_assert(std::is_convertible, int>::value, ""); -static_assert(is_static_castable, int *>::value, ""); static_assert( - std::is_same< - void_t>() == std::declval())>, - void>::value, - ""); -static_assert(is_list_initializable, bool>::value, ""); -static_assert( - std::is_same< - void_t>() + std::declval())>, - void>::value, + is_equal_comparable>>::value, ""); +CHECK_WELL_BEHAVED_VALUE_TYPE_NO_HASH( + required_inheritance_impl); +CHECK_WELL_BEHAVED_VALUE_TYPE_NO_HASH( + required_wrapper_impl); + +/* static_assert(std::is_same>() + * + std::declval>()), + * required_wrapper_impl>::value, ""); */ + 
+static_assert(std::is_copy_constructible>::value, ""); + +static_assert(std::is_convertible, int>::value, ""); +static_assert(is_static_castable, int *>::value, ""); } // namespace FlexFlow diff --git a/lib/utils/include/utils/strong_typedef.h b/lib/utils/include/utils/strong_typedef.h index f700a20c79..cba8880b37 100644 --- a/lib/utils/include/utils/strong_typedef.h +++ b/lib/utils/include/utils/strong_typedef.h @@ -5,6 +5,7 @@ #include #include #include +#include "utils/json.h" namespace FlexFlow { @@ -193,8 +194,28 @@ struct numerical_typedef : strong_typedef { } }; +template struct is_strong_typedef : std::false_type {}; + +template +struct is_strong_typedef>> : std::true_type {}; + +template +inline constexpr bool is_strong_typedef_v = is_strong_typedef::value; } // namespace FlexFlow +namespace nlohmann { +template +struct adl_serializer>> { + static T from_json(json const &j) { + return {j.template get<::FlexFlow::underlying_type_t>()}; + } + + static void to_json(json &j, T const &t) { + j = static_cast<::FlexFlow::underlying_type_t>(t); + } +}; +} // namespace nlohmann + #define MAKE_TYPEDEF_HASHABLE(TYPEDEF_NAME) \ namespace std { \ template <> \ diff --git a/lib/utils/include/utils/test_types.h b/lib/utils/include/utils/test_types.h index 2cac547bb6..984c0bc60d 100644 --- a/lib/utils/include/utils/test_types.h +++ b/lib/utils/include/utils/test_types.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_TEST_TYPES_H #define _FLEXFLOW_UTILS_INCLUDE_UTILS_TEST_TYPES_H -#include "type_traits.h" +#include "type_traits_core.h" namespace FlexFlow { @@ -12,7 +12,10 @@ enum capability { EQ, CMP, DEFAULT_CONSTRUCTIBLE, - COPYABLE, + MOVE_CONSTRUCTIBLE, + MOVE_ASSIGNABLE, + COPY_CONSTRUCTIBLE, + COPY_ASSIGNABLE, PLUS, PLUSEQ, FMT @@ -51,14 +54,38 @@ struct test_type_t { typename std::enable_if::value, bool>::type = true> test_type_t() = delete; - template ::value, bool>::type = true> test_type_t(test_type_t const &); - template ::value, bool>::type = true> 
test_type_t(test_type_t const &) = delete; + template ::value, bool>::type = true> + test_type_t &operator=(test_type_t const &); + + template ::value, bool>::type = true> + test_type_t &operator=(test_type_t const &) = delete; + + template ::value, bool>::type = true> + test_type_t(test_type_t &&); + + template ::value, bool>::type = true> + test_type_t(test_type_t &&) = delete; + + template ::value, bool>::type = true> + test_type_t &operator=(test_type_t &&); + + template ::value, bool>::type = true> + test_type_t &operator=(test_type_t &&) = delete; + template typename std::enable_if::value, bool>::type operator==(test_type_t const &) const; @@ -102,6 +129,11 @@ using cmp = test_type_t; using hash_cmp = test_type_t; using plusable = test_type_t; using fmtable = test_type_t; +using well_behaved_value_type = test_type_t; } // namespace test_types } // namespace FlexFlow diff --git a/lib/utils/include/utils/type_traits.h b/lib/utils/include/utils/type_traits.h index dc8fe2cf57..ee45e8dc2e 100644 --- a/lib/utils/include/utils/type_traits.h +++ b/lib/utils/include/utils/type_traits.h @@ -65,24 +65,6 @@ template struct is_streamable())>> : std::true_type {}; -template -struct is_equal_comparable : std::false_type {}; - -template -struct is_equal_comparable< - T, - void_t() == std::declval()))>> - : std::true_type {}; - -template -struct is_neq_comparable : std::false_type {}; - -template -struct is_neq_comparable< - T, - void_t() != std::declval()))>> - : std::true_type {}; - template struct is_lt_comparable : std::false_type {}; @@ -92,19 +74,6 @@ struct is_lt_comparable< void_t() < std::declval()))>> : std::true_type {}; -template -struct is_hashable : std::false_type {}; - -template -struct is_hashable< - T, - void_t>()(std::declval())))>> - : std::true_type {}; - -#define CHECK_HASHABLE(...) \ - static_assert(is_hashable<__VA_ARGS__>::value, \ - #__VA_ARGS__ " should be hashable (but is not)"); - template