diff --git a/flake.lock b/flake.lock index b36a96ee80..1aad68ae29 100644 --- a/flake.lock +++ b/flake.lock @@ -43,11 +43,11 @@ ] }, "locked": { - "lastModified": 1722405648, - "narHash": "sha256-+9cRIT+bwo7qxI966HjwR2Sw37CcXD1JlG9nw+vq2lY=", + "lastModified": 1722923482, + "narHash": "sha256-myUec+oBcnKNCqLQqSiPCyXFsIsvlrsGoj/mQFlHVrY=", "owner": "lockshaw", "repo": "proj", - "rev": "3674de6208c52f3a022e8f00660ee01d580aa466", + "rev": "c650b0e52337652ea7190131988c0370e0ee7f25", "type": "github" }, "original": { diff --git a/lib/compiler/src/machine_mapping.cc b/lib/compiler/src/machine_mapping.cc index 12eacb2a30..af7756c635 100644 --- a/lib/compiler/src/machine_mapping.cc +++ b/lib/compiler/src/machine_mapping.cc @@ -12,10 +12,10 @@ #include "utils/containers/contains_key.h" #include "utils/containers/get_only.h" #include "utils/containers/keys.h" +#include "utils/containers/merge_maps.h" #include "utils/exception.h" #include "utils/graph/graph_split.dtg.h" #include "utils/graph/node/algorithms.h" -#include "utils/graph/open_dataflow_graph/algorithms.h" #include "utils/graph/open_dataflow_graph/algorithms/get_subgraph.h" #include "utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h" #include "utils/graph/serial_parallel/serial_parallel_decomposition.h" diff --git a/lib/local-execution/src/op_arg_spec.cc b/lib/local-execution/src/local-execution/op_arg_spec.cc similarity index 100% rename from lib/local-execution/src/op_arg_spec.cc rename to lib/local-execution/src/local-execution/op_arg_spec.cc diff --git a/lib/local-execution/src/ops/pool_2d.cc b/lib/local-execution/src/ops/pool_2d.cc index 789ed2cd63..33d62b713c 100644 --- a/lib/local-execution/src/ops/pool_2d.cc +++ b/lib/local-execution/src/ops/pool_2d.cc @@ -3,7 +3,6 @@ #include "op-attrs/get_output_shapes.h" #include "op-attrs/ops/pool_2d.h" -#include "utils/exception.decl.h" #include "utils/exception.h" #include "utils/hash-utils.h" diff --git a/lib/local-execution/src/ops/transpose.cc b/lib/local-execution/src/ops/transpose.cc index 5c3c1dd1ca..3e4ac15db3 100644 --- a/lib/local-execution/src/ops/transpose.cc +++ b/lib/local-execution/src/ops/transpose.cc @@ -17,7 +17,6 @@ #include "kernels/transpose_kernels.h" #include "op-attrs/get_output_shapes.h" #include "op-attrs/ops/transpose.h" -#include "utils/exception.decl.h" using namespace FlexFlow::Kernels::Transpose; diff --git a/lib/op-attrs/include/op-attrs/as_dot.h b/lib/op-attrs/include/op-attrs/as_dot.h deleted file mode 100644 index d92557c2f4..0000000000 --- a/lib/op-attrs/include/op-attrs/as_dot.h +++ /dev/null @@ -1,15 +0,0 @@ -#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_AS_DOT_H -#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_AS_DOT_H - -#include "op-attrs/computation_graph_op_attrs.dtg.h" -#include "op-attrs/pcg_operator_attrs.dtg.h" -#include "utils/record_formatter.h" - -namespace FlexFlow { - -RecordFormatter as_dot(ComputationGraphOpAttrs const &); -RecordFormatter as_dot(PCGOperatorAttrs const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/op-attrs/include/op-attrs/computation_graph_op_attrs.h b/lib/op-attrs/include/op-attrs/computation_graph_op_attrs.h index 4be17798f7..03f38bb8f9 100644 --- a/lib/op-attrs/include/op-attrs/computation_graph_op_attrs.h +++ b/lib/op-attrs/include/op-attrs/computation_graph_op_attrs.h @@ -2,10 +2,12 @@ #define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_COMPUTATION_GRAPH_OP_ATTRS_H #include "op-attrs/computation_graph_op_attrs.dtg.h" +#include "utils/record_formatter.h" namespace FlexFlow { OperatorType get_op_type(ComputationGraphOpAttrs const &); +RecordFormatter as_dot(ComputationGraphOpAttrs const &); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/dim_ordered/enumerate.h b/lib/op-attrs/include/op-attrs/dim_ordered/enumerate.h new file mode 100644 index 0000000000..f9f6d00532 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/dim_ordered/enumerate.h @@ -0,0 +1,30 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_ENUMERATE_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_ENUMERATE_H + +#include "op-attrs/dim_ordered.h" +#include "utils/bidict/bidict.h" +#include "utils/containers/count.h" + +namespace FlexFlow { + +/** + * @brief Generate a map from indices to elements of \p c. + * + * @note We return a std::map to prevent mixups of \ref ff_dim_t and + * \ref legion_dim_t. Note that std::map provides ordered iteration in + * increasing order, so iterating through the result of this function should + * function as expected. + */ +template +std::map enumerate(FFOrdered const &ff_ordered) { + std::map result; + for (int raw_ff_dim : count(ff_ordered.size())) { + ff_dim_t ff_dim = ff_dim_t{raw_ff_dim}; + result.insert({ff_dim, ff_ordered.at(ff_dim)}); + } + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/op-attrs/include/op-attrs/get_output_shapes.h b/lib/op-attrs/include/op-attrs/get_output_shapes.h index 612c226a13..c27bbb190f 100644 --- a/lib/op-attrs/include/op-attrs/get_output_shapes.h +++ b/lib/op-attrs/include/op-attrs/get_output_shapes.h @@ -1,228 +1,15 @@ #ifndef _FLEXFLOW_INCLUDE_OP_ATTRS_GET_OUTPUT_SHAPES_H #define _FLEXFLOW_INCLUDE_OP_ATTRS_GET_OUTPUT_SHAPES_H -#include "op-attrs/operator_attrs.h" -#include "op-attrs/parallel_tensor_shape.h" -#include "ops/reverse.h" -#include "tensor_shape.h" -#include "utils/containers/get_only.h" -#include "utils/optional.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" +#include "op-attrs/pcg_operator_attrs.dtg.h" +#include namespace FlexFlow { -template -struct has_unary_output_t : std::false_type {}; -template -struct has_unary_input_t : std::false_type {}; -template -struct has_binary_input_t : std::false_type {}; - -template -struct has_multi_output_t : std::true_type {}; -template -struct has_multi_input_t : std::true_type {}; - -template -struct has_multi_output_t< - T, - typename std::enable_if::value>::type> - : std::false_type {}; - -template -struct has_multi_input_t< - T, - typename std::enable_if<(has_unary_input_t::value || - has_binary_input_t::value)>::type> - : std::false_type {}; - -/* template struct output_type_t { using - * type = std::vector; }; */ - -template -typename std::enable_if::value, bool>::type - is_valid(T const &t, std::vector const &shapes) { - if (shapes.size() != 1) { - return false; - } - - return is_valid(t, get_only(shapes)); -} - -template -typename std::enable_if::value, bool>::type - is_valid(T const &t, std::vector const &shapes) { - if (shapes.size() != 2) { - return false; - } - - return is_valid(t, shapes.at(0), shapes.at(1)); -} - -template -typename std::enable_if<(has_unary_input_t::value && - has_unary_output_t::value), - ParallelTensorShape>::type - output_shapes(T const &t, std::vector const &shapes) { - return output_shape(t, get_only(shapes)); -} - -template -typename std::enable_if<(has_binary_input_t::value && - has_unary_output_t::value), - std::vector>::type - output_shapes(T const &t, std::vector const &shapes) { - assert(shapes.size() == 2); - - return {output_shape(t, shapes.at(0), shapes.at(1))}; -} - -TensorShape get_tensor_shape_unsafe(ParallelTensorShape const &); -std::vector - get_tensor_shapes_unsafe(std::vector const &); - -template -TensorShape get_output_shape(Attrs const &attrs, TensorShape const &shape) { - NOT_IMPLEMENTED(); -} - -template -TensorShape get_output_shape(Attrs const &attrs, - TensorShape const &, - TensorShape const &) { - NOT_IMPLEMENTED(); -} - -template -TensorShape get_output_shape(Attrs const &attrs, - std::vector const &) { - NOT_IMPLEMENTED(); -} -template -std::vector get_output_shapes(Attrs const &attrs, - TensorShape const &); -template -std::vector get_output_shapes(Attrs const &attrs, - TensorShape const &, - TensorShape const &) { - NOT_IMPLEMENTED(); -} -template -std::vector get_output_shapes(Attrs const &attrs, - std::vector const &); - -ParallelTensorShape get_output_shape(ConcatAttrs const &, - std::vector const &); -ParallelTensorShape get_output_shape(FlatAttrs const &, - ParallelTensorShape const &); -std::vector get_output_shapes(GatherAttrs const &, - ParallelTensorShape const &, - ParallelTensorShape const &); -ParallelTensorShape get_output_shape(Pool2DAttrs const &, - ParallelTensorShape const &); -ParallelTensorShape get_output_shape(ReduceAttrs const &, - ParallelTensorShape const &); -ParallelTensorShape get_output_shape(ReverseAttrs const &, - ParallelTensorShape const &); -std::vector get_output_shapes(SplitAttrs const &, - ParallelTensorShape const &); -ParallelTensorShape get_output_shape(TopKAttrs const &, - ParallelTensorShape const &); -ParallelTensorShape get_output_shape(TransposeAttrs const &, - std::vector const &); - -struct GetOutputShapesFunctor { - GetOutputShapesFunctor(std::vector const &s) : s(s) {} - - std::vector const &s; - - template - std::vector operator()(T const &t) { - return get_output_shapes(t, s); - } -}; - -template std::vector - get_output_shapes(std::variant const &t, - std::vector const &s) { - return get_output_shape(GetOutputShapesFunctor{s}, t); -} - -template -typename std::enable_if::value, std::optional>::type - get_num_outputs(T const &) { - return std::nullopt; -} - -template -typename std::enable_if::value, std::optional>::type - get_num_outputs(T const &) { - return 1; -} - -int get_num_outputs(SplitAttrs const &attrs); - -template -bool is_valid(T const &t, std::vector const &shapes) { - auto num_outputs = get_num_outputs(t); - if (num_outputs.has_value() && shapes.size() != num_outputs.value()) { - return false; - } - - for (ParallelTensorShape const &shape : shapes) { - if (!is_valid(shape)) { - return false; - } - } - - return is_valid_internal(t, shapes); -} - -template -typename std::enable_if::value, bool>::type - is_valid_internal(T const &t, - std::vector const &shapes) { - return is_valid_internal(t, get_only(shapes)); -} - -template -typename std::enable_if::value, bool>::type - is_valid_internal(T const &t, - std::vector const &shapes) { - return is_valid_internal(t, shapes.at(0), shapes.at(1)); -} - -bool is_valid_internal(MultiHeadAttentionAttrs const &, - std::vector const &); -bool is_valid_internal(BatchMatmulAttrs const &, - ParallelTensorShape const &, - ParallelTensorShape const &); -bool is_valid_internal(CastAttrs const &, ParallelTensorShape const &); -bool is_valid_internal(ConcatAttrs const &, - std::vector const &); -bool is_valid_internal(Conv2DAttrs const &, ParallelTensorShape const &); -bool is_valid_internal(DropoutAttrs const &, ParallelTensorShape const &); -bool is_valid_internal(ElementBinaryAttrs const &, - ParallelTensorShape const &, - ParallelTensorShape const &); -bool is_valid_internal(ElementUnaryAttrs const &, ParallelTensorShape const &); -bool is_valid_internal(EmbeddingAttrs const &, ParallelTensorShape const &); -bool is_valid_internal(FlatAttrs const &, ParallelTensorShape const &); -bool is_valid_internal(GatherAttrs const &, - ParallelTensorShape const &, - ParallelTensorShape const &); -bool is_valid_internal(LayerNormAttrs const &, ParallelTensorShape const &); -bool is_valid_internal(LinearAttrs const &, ParallelTensorShape const &); -bool is_valid_internal(Pool2DAttrs const &, ParallelTensorShape const &); -bool is_valid_internal(ReduceAttrs const &, ParallelTensorShape const &); -bool is_valid_internal(ReductionAttrs const &, ParallelTensorShape const &); -bool is_valid_internal(RepartitionAttrs const &, ParallelTensorShape const &); -bool is_valid_internal(ReplicateAttrs const &, ParallelTensorShape const &); -bool is_valid_internal(ReshapeAttrs const &, ParallelTensorShape const &); -bool is_valid_internal(SoftmaxAttrs const &, ParallelTensorShape const &); -bool is_valid_internal(SplitAttrs const &, ParallelTensorShape const &); -bool is_valid_internal(TopKAttrs const &, ParallelTensorShape const &); -bool is_valid_internal(TransposeAttrs const &, ParallelTensorShape const &); + get_output_shapes(PCGOperatorAttrs const &, + std::vector const &); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/is_valid.h b/lib/op-attrs/include/op-attrs/is_valid.h new file mode 100644 index 0000000000..2d91307e19 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/is_valid.h @@ -0,0 +1,59 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_IS_VALID_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_IS_VALID_H + +#include "op-attrs/parallel_tensor_shape.h" +#include "op-attrs/pcg_operator_attrs.dtg.h" + +namespace FlexFlow { + +template +bool is_valid(T const &t, std::vector const &shapes) { + auto num_outputs = get_num_outputs(t); + if (num_outputs.has_value() && shapes.size() != num_outputs.value()) { + return false; + } + + for (ParallelTensorShape const &shape : shapes) { + if (!is_valid(shape)) { + return false; + } + } + + return is_valid_internal(t, shapes); +} + +bool is_valid_internal(MultiHeadAttentionAttrs const &, + std::vector const &); +bool is_valid_internal(BatchMatmulAttrs const &, + ParallelTensorShape const &, + ParallelTensorShape const &); +bool is_valid_internal(CastAttrs const &, ParallelTensorShape const &); +bool is_valid_internal(ConcatAttrs const &, + std::vector const &); +bool is_valid_internal(Conv2DAttrs const &, ParallelTensorShape const &); +bool is_valid_internal(DropoutAttrs const &, ParallelTensorShape const &); +bool is_valid_internal(ElementBinaryAttrs const &, + ParallelTensorShape const &, + ParallelTensorShape const &); +bool is_valid_internal(ElementUnaryAttrs const &, ParallelTensorShape const &); +bool is_valid_internal(EmbeddingAttrs const &, ParallelTensorShape const &); +bool is_valid_internal(FlatAttrs const &, ParallelTensorShape const &); +bool is_valid_internal(GatherAttrs const &, + ParallelTensorShape const &, + ParallelTensorShape const &); +bool is_valid_internal(LayerNormAttrs const &, ParallelTensorShape const &); +bool is_valid_internal(LinearAttrs const &, ParallelTensorShape const &); +bool is_valid_internal(Pool2DAttrs const &, ParallelTensorShape const &); +bool is_valid_internal(ReduceAttrs const &, ParallelTensorShape const &); +bool is_valid_internal(ReductionAttrs const &, ParallelTensorShape const &); +bool is_valid_internal(RepartitionAttrs const &, ParallelTensorShape const &); +bool is_valid_internal(ReplicateAttrs const &, ParallelTensorShape const &); +bool is_valid_internal(ReshapeAttrs const &, ParallelTensorShape const &); +bool is_valid_internal(SoftmaxAttrs const &, ParallelTensorShape const &); +bool is_valid_internal(SplitAttrs const &, ParallelTensorShape const &); +bool is_valid_internal(TopKAttrs const &, ParallelTensorShape const &); +bool is_valid_internal(TransposeAttrs const &, ParallelTensorShape const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/op-attrs/include/op-attrs/ops/attention.h b/lib/op-attrs/include/op-attrs/ops/attention.h index e126c425dc..40f57d08af 100644 --- a/lib/op-attrs/include/op-attrs/ops/attention.h +++ b/lib/op-attrs/include/op-attrs/ops/attention.h @@ -1,10 +1,10 @@ #ifndef _FLEXFLOW_ATTENTION_ATTRS_H #define _FLEXFLOW_ATTENTION_ATTRS_H -#include "core.h" #include "op-attrs/ops/attention/multihead_attention_inputs.dtg.h" #include "op-attrs/ops/attention/multihead_attention_parallel_inputs.dtg.h" #include "op-attrs/ops/attention_attrs.dtg.h" +#include "op-attrs/ops/core.h" #include "op-attrs/parallel_tensor_shape.dtg.h" #include "op-attrs/tensor_shape.dtg.h" #include diff --git a/lib/op-attrs/include/op-attrs/ops/batch_norm.h b/lib/op-attrs/include/op-attrs/ops/batch_norm.h index b9a1d87a75..8afcbb06b1 100644 --- a/lib/op-attrs/include/op-attrs/ops/batch_norm.h +++ b/lib/op-attrs/include/op-attrs/ops/batch_norm.h @@ -1,12 +1,13 @@ #ifndef _FLEXFLOW_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_BATCH_NORM_H #define _FLEXFLOW_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_BATCH_NORM_H -#include "core.h" #include "op-attrs/ops/batch_norm_attrs.dtg.h" +#include "op-attrs/ops/core.h" #include "op-attrs/parallel_tensor_shape.h" namespace FlexFlow { +TensorShape get_output_shape(BatchNormAttrs const &, TensorShape const &); ParallelTensorShape get_output_shape(BatchNormAttrs const &, ParallelTensorShape const &); diff --git a/lib/op-attrs/include/op-attrs/ops/cast.h b/lib/op-attrs/include/op-attrs/ops/cast.h index ead779c553..f85481b45b 100644 --- a/lib/op-attrs/include/op-attrs/ops/cast.h +++ b/lib/op-attrs/include/op-attrs/ops/cast.h @@ -3,8 +3,8 @@ #include "op-attrs/ops/cast_attrs.dtg.h" #include "op-attrs/ops/core.h" -#include "op-attrs/parallel_tensor_shape.h" -#include "op-attrs/tensor_shape.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" +#include "op-attrs/tensor_shape.dtg.h" #include namespace FlexFlow { diff --git a/lib/op-attrs/include/op-attrs/ops/concat.h b/lib/op-attrs/include/op-attrs/ops/concat.h index 8a72708971..f3ac8494c0 100644 --- a/lib/op-attrs/include/op-attrs/ops/concat.h +++ b/lib/op-attrs/include/op-attrs/ops/concat.h @@ -1,13 +1,20 @@ -#ifndef _FLEXFLOW_CONCAT_ATTRS_H -#define _FLEXFLOW_CONCAT_ATTRS_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CONCAT_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CONCAT_H -#include "core.h" #include "op-attrs/ops/concat_attrs.dtg.h" +#include "op-attrs/ops/core.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" +#include "op-attrs/tensor_shape.dtg.h" namespace FlexFlow { CHECK_VALID_OP_ATTR(ConcatAttrs); +TensorShape get_output_shape(ConcatAttrs const &, + std::vector const &); +ParallelTensorShape get_output_shape(ConcatAttrs const &, + std::vector const &); + } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/flat.h b/lib/op-attrs/include/op-attrs/ops/flat.h index d5d9069f51..676d21c59b 100644 --- a/lib/op-attrs/include/op-attrs/ops/flat.h +++ b/lib/op-attrs/include/op-attrs/ops/flat.h @@ -1,14 +1,19 @@ #ifndef _FLEXFLOW_FLAT_ATTRS_H #define _FLEXFLOW_FLAT_ATTRS_H -#include "core.h" +#include "op-attrs/ops/core.h" #include "op-attrs/ops/flat_attrs.dtg.h" -#include "op-attrs/parallel_tensor_shape.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" +#include "op-attrs/tensor_shape.dtg.h" namespace FlexFlow { CHECK_VALID_OP_ATTR(FlatAttrs); +TensorShape get_output_shape(FlatAttrs const &, TensorShape const &); +ParallelTensorShape get_output_shape(FlatAttrs const &, + ParallelTensorShape const &); + } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/gather.h b/lib/op-attrs/include/op-attrs/ops/gather.h index 79516a8862..42efd13b60 100644 --- a/lib/op-attrs/include/op-attrs/ops/gather.h +++ b/lib/op-attrs/include/op-attrs/ops/gather.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_GATHER_ATTRS_H #define _FLEXFLOW_GATHER_ATTRS_H -#include "core.h" +#include "op-attrs/ops/core.h" #include "op-attrs/ops/gather_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.h" @@ -9,6 +9,13 @@ namespace FlexFlow { CHECK_VALID_OP_ATTR(GatherAttrs); +TensorShape get_output_shape(GatherAttrs const &, + TensorShape const &input, + TensorShape const &index); +ParallelTensorShape get_output_shape(GatherAttrs const &, + ParallelTensorShape const &input, + ParallelTensorShape const &index); + } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/input.h b/lib/op-attrs/include/op-attrs/ops/input.h index 9fe0ee2c2d..fe92c77a52 100644 --- a/lib/op-attrs/include/op-attrs/ops/input.h +++ b/lib/op-attrs/include/op-attrs/ops/input.h @@ -1,15 +1,17 @@ #ifndef _FLEXFLOW_OP_ATTRS_OPS_OP_ATTRS_INPUT_H #define _FLEXFLOW_OP_ATTRS_OPS_OP_ATTRS_INPUT_H -#include "core.h" +#include "op-attrs/ops/core.h" #include "op-attrs/ops/input_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" +#include "op-attrs/tensor_shape.dtg.h" namespace FlexFlow { CHECK_VALID_OP_ATTR(InputAttrs); -ParallelTensorShape get_output_shape(InputAttrs const &); +TensorShape get_output_shape(InputAttrs const &); +ParallelTensorShape get_output_parallel_tensor_shape(InputAttrs const &); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/layer_norm.h b/lib/op-attrs/include/op-attrs/ops/layer_norm.h index 94f9b9e147..29b0b2f514 100644 --- a/lib/op-attrs/include/op-attrs/ops/layer_norm.h +++ b/lib/op-attrs/include/op-attrs/ops/layer_norm.h @@ -1,9 +1,10 @@ #ifndef _FLEXFLOW_OP_META_OPS_LAYER_NORM_ATTRS_H #define _FLEXFLOW_OP_META_OPS_LAYER_NORM_ATTRS_H -#include "core.h" +#include "op-attrs/ops/core.h" #include "op-attrs/ops/layer_norm_attrs.dtg.h" -#include "op-attrs/parallel_tensor_shape.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" +#include "op-attrs/tensor_shape.dtg.h" namespace FlexFlow { diff --git a/lib/op-attrs/include/op-attrs/ops/linear.h b/lib/op-attrs/include/op-attrs/ops/linear.h index dd6948165e..795ba19ae8 100644 --- a/lib/op-attrs/include/op-attrs/ops/linear.h +++ b/lib/op-attrs/include/op-attrs/ops/linear.h @@ -5,12 +5,15 @@ #include "op-attrs/ops/linear_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" #include "op-attrs/tensor_shape.dtg.h" +#include "utils/record_formatter.h" #include namespace FlexFlow { CHECK_VALID_OP_ATTR(LinearAttrs); +RecordFormatter as_dot(LinearAttrs const &); + tl::expected get_kernel_shape(LinearAttrs const &attrs, TensorShape const &input); tl::expected get_bias_shape(LinearAttrs const &attrs, diff --git a/lib/op-attrs/include/op-attrs/ops/loss_functions.h b/lib/op-attrs/include/op-attrs/ops/loss_functions.h deleted file mode 100644 index 58d372d9e5..0000000000 --- a/lib/op-attrs/include/op-attrs/ops/loss_functions.h +++ /dev/null @@ -1,75 +0,0 @@ -#ifndef _FLEXFLOW_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_LOSS_FUNCTIONS_H -#define _FLEXFLOW_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_LOSS_FUNCTIONS_H - -#include "core.h" -#include "utils/exception.h" -#include "utils/visitable.h" -#include - -namespace FlexFlow { - -enum class LossFunction { - CATEGORICAL_CROSSENTROPY, - SPARSE_CATEGORICAL_CROSSENTROPY, - MEAN_SQUARED_ERROR_AVG_REDUCE, - MEAN_SQUARED_ERROR_SUM_REDUCE, - IDENTITY -}; - -LossFunction parse_loss_function_name(std::string const &); - -struct SparseCategoricalCrossEntropyLossAttrs { - req replace_labels; // for aggregate_spec: More predictions than labels -}; -FF_VISITABLE_STRUCT(SparseCategoricalCrossEntropyLossAttrs, replace_labels); -CHECK_VALID_OP_ATTR(SparseCategoricalCrossEntropyLossAttrs); - -struct OtherLossAttrs { - req loss_type; -}; -FF_VISITABLE_STRUCT(OtherLossAttrs, loss_type); -CHECK_VALID_OP_ATTR(OtherLossAttrs); - -using LossAttrs = - std::variant; - -LossFunction get_loss_function(OtherLossAttrs const &); -LossFunction get_loss_function(SparseCategoricalCrossEntropyLossAttrs const &); -LossFunction get_loss_function(LossAttrs const &); - -} // namespace FlexFlow - -namespace fmt { - -template <> -struct formatter<::FlexFlow::LossFunction> : formatter { - template - auto format(::FlexFlow::LossFunction d, FormatContext &ctx) const - -> decltype(ctx.out()) { - using namespace FlexFlow; - - string_view name = "unknown"; - switch (d) { - case LossFunction::CATEGORICAL_CROSSENTROPY: - name = "CategoricalCrossEntropy"; - break; - case LossFunction::SPARSE_CATEGORICAL_CROSSENTROPY: - name = "SparseCategoricalCrossEntropy"; - break; - case LossFunction::MEAN_SQUARED_ERROR_AVG_REDUCE: - name = "MeanSquaredErrorAvgReduce"; - break; - case LossFunction::MEAN_SQUARED_ERROR_SUM_REDUCE: - name = "MeanSquaredErrorSumReduce"; - break; - case LossFunction::IDENTITY: - name = "Identity"; - break; - } - return formatter::format(name, ctx); - } -}; - -} // namespace fmt - -#endif diff --git a/lib/op-attrs/include/op-attrs/ops/loss_functions/loss_attrs.variant.toml b/lib/op-attrs/include/op-attrs/ops/loss_functions/loss_attrs.variant.toml new file mode 100644 index 0000000000..17293095e4 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/loss_functions/loss_attrs.variant.toml @@ -0,0 +1,23 @@ +namespace = "FlexFlow" +name = "LossAttrs" +features = [ + "eq", + "ord", + "hash", + "fmt", + "json", + "rapidcheck", +] + +includes = [ + "op-attrs/ops/loss_functions/sparse_categorical_cross_entropy_loss_attrs.dtg.h", + "op-attrs/ops/loss_functions/other_loss_attrs.dtg.h", +] + +[[values]] +type = "::FlexFlow::SparseCategoricalCrossEntropyLossAttrs" +key = "sparse_categorical_cross_entropy_loss" + +[[values]] +type = "::FlexFlow::OtherLossAttrs" +key = "other_loss" diff --git a/lib/op-attrs/include/op-attrs/ops/loss_functions/loss_function.enum.toml b/lib/op-attrs/include/op-attrs/ops/loss_functions/loss_function.enum.toml new file mode 100644 index 0000000000..9658202a45 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/loss_functions/loss_function.enum.toml @@ -0,0 +1,23 @@ +namespace = "FlexFlow" +name = "LossFunction" +features = [ + "fmt", + "hash", + "rapidcheck", + "json", +] + +[[values]] +name = "CATEGORICAL_CROSSENTROPY" + +[[values]] +name = "SPARSE_CATEGORICAL_CROSSENTROPY" + +[[values]] +name = "MEAN_SQUARED_ERROR_AVG_REDUCE" + +[[values]] +name = "MEAN_SQUARED_ERROR_SUM_REDUCE" + +[[values]] +name = "IDENTITY" diff --git a/lib/op-attrs/include/op-attrs/ops/loss_functions/loss_functions.h b/lib/op-attrs/include/op-attrs/ops/loss_functions/loss_functions.h new file mode 100644 index 0000000000..ca8f3e6602 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/loss_functions/loss_functions.h @@ -0,0 +1,20 @@ +#ifndef _FLEXFLOW_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_LOSS_FUNCTIONS_H +#define _FLEXFLOW_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_LOSS_FUNCTIONS_H + +#include "op-attrs/ops/core.h" +#include "op-attrs/ops/loss_functions/loss_attrs.dtg.h" +#include "op-attrs/ops/loss_functions/loss_function.dtg.h" + +namespace FlexFlow { + +CHECK_VALID_OP_ATTR(LossAttrs); + +LossFunction parse_loss_function_name(std::string const &); + +LossFunction get_loss_function(OtherLossAttrs const &); +LossFunction get_loss_function(SparseCategoricalCrossEntropyLossAttrs const &); +LossFunction get_loss_function(LossAttrs const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/op-attrs/include/op-attrs/ops/loss_functions/other_loss_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/loss_functions/other_loss_attrs.struct.toml new file mode 100644 index 0000000000..284a4b1d7d --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/loss_functions/other_loss_attrs.struct.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "OtherLossAttrs" +features = [ + "eq", + "ord", + "hash", + "fmt", + "rapidcheck", + "json", +] + +includes = [ + "op-attrs/ops/loss_functions/loss_function.dtg.h", +] + +[[fields]] +name = "loss_type" +type = "::FlexFlow::LossFunction" diff --git a/lib/op-attrs/include/op-attrs/ops/loss_functions/sparse_categorical_cross_entropy_loss_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/loss_functions/sparse_categorical_cross_entropy_loss_attrs.struct.toml new file mode 100644 index 0000000000..c50b432ba2 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/loss_functions/sparse_categorical_cross_entropy_loss_attrs.struct.toml @@ -0,0 +1,15 @@ +namespace = "FlexFlow" +name = "SparseCategoricalCrossEntropyLossAttrs" +features = [ + "eq", + "ord", + "hash", + "fmt", + "rapidcheck", + "json", +] + +[[fields]] +# for aggregate_spec: More predictions than labels +name = "replace_labels" +type = "bool" diff --git a/lib/op-attrs/include/op-attrs/ops/noop.h b/lib/op-attrs/include/op-attrs/ops/noop.h index eb01009259..2c61dff886 100644 --- a/lib/op-attrs/include/op-attrs/ops/noop.h +++ b/lib/op-attrs/include/op-attrs/ops/noop.h @@ -1,14 +1,16 @@ #ifndef _FLEXFLOW_OP_ATTRS_OPS_NOOP_H #define _FLEXFLOW_OP_ATTRS_OPS_NOOP_H -#include "core.h" +#include "op-attrs/ops/core.h" #include "op-attrs/ops/noop_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" +#include "op-attrs/tensor_shape.dtg.h" namespace FlexFlow { CHECK_VALID_OP_ATTR(NoopAttrs); +TensorShape get_output_shape(NoopAttrs const &, TensorShape const &); ParallelTensorShape get_output_shape(NoopAttrs const &, ParallelTensorShape const &); diff --git a/lib/op-attrs/include/op-attrs/ops/pool_2d.h b/lib/op-attrs/include/op-attrs/ops/pool_2d.h index 162f9aef05..505fdd9f8c 100644 --- a/lib/op-attrs/include/op-attrs/ops/pool_2d.h +++ b/lib/op-attrs/include/op-attrs/ops/pool_2d.h @@ -1,14 +1,16 @@ #ifndef _FLEXFLOW_POOL_2D_ATTRS_H #define _FLEXFLOW_POOL_2D_ATTRS_H -#include "core.h" +#include "op-attrs/ops/core.h" #include "op-attrs/ops/pool_2d_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" +#include "op-attrs/tensor_shape.dtg.h" namespace FlexFlow { CHECK_VALID_OP_ATTR(Pool2DAttrs); +TensorShape get_output_shape(Pool2DAttrs const &, TensorShape const &); ParallelTensorShape get_output_shape(Pool2DAttrs const &, ParallelTensorShape const &); diff --git a/lib/op-attrs/include/op-attrs/ops/replicate.h b/lib/op-attrs/include/op-attrs/ops/replicate.h index 4c46bf88a9..9104a36155 100644 --- a/lib/op-attrs/include/op-attrs/ops/replicate.h +++ b/lib/op-attrs/include/op-attrs/ops/replicate.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_REPLICATE_ATTRS_H #define _FLEXFLOW_REPLICATE_ATTRS_H -#include "core.h" +#include "op-attrs/ops/core.h" #include "op-attrs/ops/replicate_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" diff --git a/lib/op-attrs/include/op-attrs/ops/reshape.h b/lib/op-attrs/include/op-attrs/ops/reshape.h index cd2ca80c3a..e87ca5c750 100644 --- a/lib/op-attrs/include/op-attrs/ops/reshape.h +++ b/lib/op-attrs/include/op-attrs/ops/reshape.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_RESHAPE_ATTRS_H #define _FLEXFLOW_RESHAPE_ATTRS_H -#include "core.h" +#include "op-attrs/ops/core.h" #include "op-attrs/ops/reshape_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" @@ -9,6 +9,8 @@ namespace FlexFlow { CHECK_VALID_OP_ATTR(ReshapeAttrs); +TensorShape get_output_shape(ReshapeAttrs const &attrs, + TensorShape const &input_shape); ParallelTensorShape get_output_shape(ReshapeAttrs const &attrs, ParallelTensorShape const &input_shape); diff --git a/lib/op-attrs/include/op-attrs/ops/reverse.h b/lib/op-attrs/include/op-attrs/ops/reverse.h index adc62dc9ae..023e714c20 100644 --- a/lib/op-attrs/include/op-attrs/ops/reverse.h +++ b/lib/op-attrs/include/op-attrs/ops/reverse.h @@ -1,14 +1,16 @@ #ifndef _FLEXFLOW_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REVERSE_H #define _FLEXFLOW_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REVERSE_H -#include "core.h" +#include "op-attrs/ops/core.h" #include "op-attrs/ops/reverse_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" +#include "op-attrs/tensor_shape.dtg.h" namespace FlexFlow { CHECK_VALID_OP_ATTR(ReverseAttrs); +TensorShape get_output_shape(ReverseAttrs const &, TensorShape const &); ParallelTensorShape get_output_shape(ReverseAttrs const &attrs, ParallelTensorShape const &input_shape); diff --git a/lib/op-attrs/include/op-attrs/ops/split.h b/lib/op-attrs/include/op-attrs/ops/split.h index 8fc2257760..e6a08d6e77 100644 --- a/lib/op-attrs/include/op-attrs/ops/split.h +++ b/lib/op-attrs/include/op-attrs/ops/split.h @@ -1,15 +1,18 @@ #ifndef _FLEXFLOW_SPLIT_ATTRS_H #define _FLEXFLOW_SPLIT_ATTRS_H -#include "core.h" +#include "op-attrs/ops/core.h" #include "op-attrs/ops/split_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" +#include "op-attrs/tensor_shape.dtg.h" #include namespace FlexFlow { CHECK_VALID_OP_ATTR(SplitAttrs); +std::vector get_output_shapes(SplitAttrs const &, + TensorShape const &); std::vector get_output_shapes(SplitAttrs const &attrs, ParallelTensorShape const &input_shape); diff --git a/lib/op-attrs/include/op-attrs/ops/topk.h b/lib/op-attrs/include/op-attrs/ops/topk.h index c6af40dd48..bd11f0ae91 100644 --- a/lib/op-attrs/include/op-attrs/ops/topk.h +++ b/lib/op-attrs/include/op-attrs/ops/topk.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_TOPK_ATTRS_H #define _FLEXFLOW_TOPK_ATTRS_H -#include "core.h" +#include "op-attrs/ops/core.h" #include "op-attrs/ops/topk_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" diff --git a/lib/op-attrs/include/op-attrs/ops/transpose.h b/lib/op-attrs/include/op-attrs/ops/transpose.h index 6e23d91d78..6de83ee414 100644 --- a/lib/op-attrs/include/op-attrs/ops/transpose.h +++ b/lib/op-attrs/include/op-attrs/ops/transpose.h @@ -1,16 +1,18 @@ #ifndef _FLEXFLOW_OP_META_OPS_TRANSPOSE_ATTRS_H #define _FLEXFLOW_OP_META_OPS_TRANSPOSE_ATTRS_H -#include "core.h" +#include "op-attrs/ops/core.h" #include "op-attrs/ops/transpose_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" +#include "op-attrs/tensor_shape.dtg.h" namespace FlexFlow { CHECK_VALID_OP_ATTR(TransposeAttrs); -ParallelTensorShape get_output_shape(TransposeAttrs const &op_attrs, - ParallelTensorShape const &input_shape); +TensorShape get_output_shape(TransposeAttrs const &, TensorShape const &); +ParallelTensorShape get_output_shape(TransposeAttrs const &, + ParallelTensorShape const &); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/weight.h b/lib/op-attrs/include/op-attrs/ops/weight.h new file mode 100644 index 0000000000..ab97b31012 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/weight.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_WEIGHT_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_WEIGHT_H + +#include "op-attrs/ops/core.h" +#include "op-attrs/ops/weight_attrs.dtg.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" +#include "op-attrs/tensor_shape.dtg.h" + +namespace FlexFlow { + +CHECK_VALID_OP_ATTR(WeightAttrs); + +TensorShape get_output_shape(WeightAttrs const &); +ParallelTensorShape get_output_parallel_tensor_shape(WeightAttrs const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/op-attrs/include/op-attrs/ops/weight_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/weight_attrs.struct.toml index 28810a437e..c4d22a006c 100644 --- a/lib/op-attrs/include/op-attrs/ops/weight_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/weight_attrs.struct.toml @@ -8,4 +8,11 @@ features = [ "rapidcheck", "fmt", ] -fields = [] + +includes = [ + "op-attrs/tensor_shape.dtg.h", +] + +[[fields]] +name = "tensor_shape" +type = "::FlexFlow::TensorShape" diff --git a/lib/op-attrs/include/op-attrs/parallel_op_attrs.h b/lib/op-attrs/include/op-attrs/parallel_op_attrs.h new file mode 100644 index 0000000000..8669669f09 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/parallel_op_attrs.h @@ -0,0 +1,17 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_OP_ATTRS_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_OP_ATTRS_H + +#include "op-attrs/parallel_op_attrs.dtg.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" +#include "op-attrs/pcg_operator_attrs.dtg.h" +#include "utils/record_formatter.h" + +namespace FlexFlow { + +ParallelTensorShape get_output_shape(ParallelOpAttrs const &, + ParallelTensorShape const &); +PCGOperatorAttrs pcg_op_attrs_from_parallel_op_attrs(ParallelOpAttrs const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/op-attrs/include/op-attrs/parallel_op_attrs.variant.toml b/lib/op-attrs/include/op-attrs/parallel_op_attrs.variant.toml new file mode 100644 index 0000000000..f1631a41f2 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/parallel_op_attrs.variant.toml @@ -0,0 +1,34 @@ +namespace = "FlexFlow" +name = "ParallelOpAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/ops/combine_attrs.dtg.h", + "op-attrs/ops/reduction_attrs.dtg.h", + "op-attrs/ops/repartition_attrs.dtg.h", + "op-attrs/ops/replicate_attrs.dtg.h", +] + +[[values]] +type = "::FlexFlow::CombineAttrs" +key = "combine_distributed" + +[[values]] +type = "::FlexFlow::ReductionAttrs" +key = "reduce_distributed" + +[[values]] +type = "::FlexFlow::RepartitionAttrs" +key = "partition_distributed" + +[[values]] +type = "::FlexFlow::ReplicateAttrs" +key = "replicate_distributed" + diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h index 99be635ffc..76356b39d4 100644 --- a/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h @@ -2,6 +2,7 @@ #define _OP_META_PARALLEL_TENSOR_SHAPE_H #include "op-attrs/parallel_tensor_shape.dtg.h" +#include "op-attrs/replica_parallel_dim.dtg.h" #include "op-attrs/tensor_shape.h" #include @@ -36,6 +37,7 @@ int get_total_parallel_degree(ParallelTensorShape const &); bool is_valid(ParallelTensorShape const &); +TensorShape require_not_parallel(ParallelTensorShape const &); TensorShape get_tensor_shape_unsafe(ParallelTensorShape const &); std::vector get_tensor_shapes_unsafe(std::vector const &); diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_shape.struct.toml b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.struct.toml index e6197bcd51..806af55cba 100644 --- a/lib/op-attrs/include/op-attrs/parallel_tensor_shape.struct.toml +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.struct.toml @@ -10,8 +10,8 @@ features = [ ] includes = [ - "op-attrs/parallel_tensor_dims.h", - "op-attrs/datatype.h", + "op-attrs/parallel_tensor_dims.dtg.h", + "op-attrs/datatype.dtg.h", ] [[fields]] diff --git a/lib/op-attrs/include/op-attrs/pcg_operator_attrs.h b/lib/op-attrs/include/op-attrs/pcg_operator_attrs.h index 25be926cbe..08167fe3d9 100644 --- a/lib/op-attrs/include/op-attrs/pcg_operator_attrs.h +++ b/lib/op-attrs/include/op-attrs/pcg_operator_attrs.h @@ -10,6 +10,7 @@ bool is_parallel_op(PCGOperatorAttrs const &); OperatorType get_op_type(PCGOperatorAttrs const &); ComputationGraphOpAttrs compgraph_op_attrs_from_pcg_op_attrs(PCGOperatorAttrs const &); +RecordFormatter as_dot(PCGOperatorAttrs const &); } // namespace FlexFlow diff --git a/lib/op-attrs/src/get_output_shapes.cc b/lib/op-attrs/src/get_output_shapes.cc deleted file mode 100644 index c20d4be34c..0000000000 --- a/lib/op-attrs/src/get_output_shapes.cc +++ /dev/null @@ -1,29 +0,0 @@ -#include "op-attrs/get_output_shapes.h" - -namespace FlexFlow { - -ParallelTensorShape as_parallel(TensorShape const &); -std::vector as_parallel(std::vector const &); - -std::vector get_output_shapes( - PCGOperatorAttrs const &op_params, - std::vector const &input_tensor_shapes) { - NOT_IMPLEMENTED(); -} - -// TensorShape get_output_shape(AggregateAttrs const &attrs, -// TensorShape const &gate_preds, -// TensorShape const &gate_assign, -// TensorShape const &true_gate_assign, -// TensorShape const &full_gate_gradients, -// std::vector const &exp_preds) { -// return get_tensor_shape_unsafe( -// get_output_shape(attrs, -// as_parallel(gate_preds), -// as_parallel(gate_assign), -// as_parallel(true_gate_assign), -// as_parallel(full_gate_gradients), -// as_parallel(exp_preds))); -// } - -} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/as_dot.cc b/lib/op-attrs/src/op-attrs/as_dot.cc deleted file mode 100644 index f8d05de941..0000000000 --- a/lib/op-attrs/src/op-attrs/as_dot.cc +++ /dev/null @@ -1,13 +0,0 @@ -#include "op-attrs/as_dot.h" - -namespace FlexFlow { - -RecordFormatter as_dot(ComputationGraphOpAttrs const &attrs) { - NOT_IMPLEMENTED(); -} - -RecordFormatter as_dot(PCGOperatorAttrs const &attrs) { - NOT_IMPLEMENTED(); -} - -} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/datatype.cc b/lib/op-attrs/src/op-attrs/datatype.cc index bd29c8033a..3bee05c253 100644 --- a/lib/op-attrs/src/op-attrs/datatype.cc +++ b/lib/op-attrs/src/op-attrs/datatype.cc @@ -19,7 +19,7 @@ size_t size_of_datatype(DataType data_type) { case DataType::DOUBLE: return sizeof(double); default: - throw mk_runtime_error("Unknown DataType {}", data_type); + throw mk_runtime_error(fmt::format("Unknown DataType {}", data_type)); } } diff --git a/lib/op-attrs/src/op-attrs/dim_ordered/enumerate.cc b/lib/op-attrs/src/op-attrs/dim_ordered/enumerate.cc new file mode 100644 index 0000000000..6edd5485af --- /dev/null +++ b/lib/op-attrs/src/op-attrs/dim_ordered/enumerate.cc @@ -0,0 +1 @@ +#include "op-attrs/dim_ordered/enumerate.h" diff --git a/lib/op-attrs/src/op-attrs/get_output_shapes.cc b/lib/op-attrs/src/op-attrs/get_output_shapes.cc new file mode 100644 index 0000000000..d91d1a1eca --- /dev/null +++ b/lib/op-attrs/src/op-attrs/get_output_shapes.cc @@ -0,0 +1,85 @@ +#include "op-attrs/get_output_shapes.h" +#include "op-attrs/ops/batch_matmul.h" +#include "op-attrs/ops/batch_norm.h" +#include "op-attrs/ops/cast.h" +#include "op-attrs/ops/combine.h" +#include "op-attrs/ops/concat.h" +#include "op-attrs/ops/conv_2d.h" +#include "op-attrs/ops/dropout.h" +#include "op-attrs/ops/element_binary.h" +#include "op-attrs/ops/element_unary.h" +#include "op-attrs/ops/embedding.h" +#include "op-attrs/ops/flat.h" +#include "op-attrs/ops/gather.h" +#include "op-attrs/ops/input.h" +#include "op-attrs/ops/layer_norm.h" +#include "op-attrs/ops/linear.h" +#include "op-attrs/ops/replicate.h" +#include "op-attrs/ops/weight.h" +#include "utils/overload.h" + +namespace FlexFlow { + +std::vector + get_output_shapes(PCGOperatorAttrs const &pcg_op_attrs, + std::vector const &inputs) { + return pcg_op_attrs.visit>(overload{ + [&](BatchMatmulAttrs const &attrs) -> std::vector { + return {throw_if_unexpected( + get_output_shape(attrs, inputs.at(0), inputs.at(1)))}; + }, + [&](BatchNormAttrs const &attrs) -> std::vector { + return {get_output_shape(attrs, inputs.at(0))}; + }, + [&](CastAttrs const &attrs) -> std::vector { + return {throw_if_unexpected(get_output_shape(attrs, inputs.at(0)))}; + }, + [&](CombineAttrs const &attrs) -> std::vector { + return {throw_if_unexpected(get_output_shape(attrs, inputs.at(0)))}; + }, + [&](ConcatAttrs const &attrs) -> std::vector { + return {get_output_shape(attrs, inputs)}; + }, + [&](Conv2DAttrs const &attrs) -> std::vector { + return {get_output_shape(attrs, inputs.at(0))}; + }, + [&](DropoutAttrs const &attrs) -> std::vector { + return {throw_if_unexpected(get_output_shape(attrs, inputs.at(0)))}; + }, + [&](ElementBinaryAttrs const &attrs) -> std::vector { + return {throw_if_unexpected( + get_output_shape(attrs, inputs.at(0), inputs.at(1)))}; + }, + [&](ElementUnaryAttrs const &attrs) -> std::vector { + return {throw_if_unexpected(get_output_shape(attrs, inputs.at(0)))}; + }, + [&](EmbeddingAttrs const &attrs) -> std::vector { + return {throw_if_unexpected(get_output_shape(attrs, inputs.at(0)))}; + }, + [&](FlatAttrs const &attrs) -> std::vector { + return {get_output_shape(attrs, inputs.at(0))}; + }, + [&](GatherAttrs const &attrs) -> std::vector { + return {get_output_shape(attrs, inputs.at(0), inputs.at(1))}; + }, + [&](InputAttrs const &attrs) -> std::vector { + return {get_output_parallel_tensor_shape(attrs)}; + }, + [&](LayerNormAttrs const &attrs) -> std::vector { + return {throw_if_unexpected(get_output_shape(attrs, inputs.at(0)))}; + }, + [&](LinearAttrs const &attrs) -> std::vector { + return {throw_if_unexpected(get_output_shape(attrs, inputs.at(0)))}; + }, + [&](ReplicateAttrs const &attrs) -> std::vector { + return {get_output_shape(attrs, inputs.at(0))}; + }, + [&](WeightAttrs const &attrs) -> std::vector { + return {get_output_parallel_tensor_shape(attrs)}; + }, + [&](auto const &attrs) -> std::vector { + NOT_IMPLEMENTED(); + }}); +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/is_valid.cc b/lib/op-attrs/src/op-attrs/is_valid.cc new file mode 100644 index 0000000000..14eae33b4b --- /dev/null +++ b/lib/op-attrs/src/op-attrs/is_valid.cc @@ -0,0 +1,3 @@ +#include "op-attrs/is_valid.h" + +namespace FlexFlow {} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/batch_norm.cc b/lib/op-attrs/src/op-attrs/ops/batch_norm.cc index 7be51efa22..b75c3521c6 100644 --- a/lib/op-attrs/src/op-attrs/ops/batch_norm.cc +++ b/lib/op-attrs/src/op-attrs/ops/batch_norm.cc @@ -2,6 +2,11 @@ namespace FlexFlow { +TensorShape get_output_shape(BatchNormAttrs const &, + TensorShape const &input_shape) { + return input_shape; +} + ParallelTensorShape get_output_shape(BatchNormAttrs const &, ParallelTensorShape const &) { NOT_IMPLEMENTED(); diff --git a/lib/op-attrs/src/op-attrs/ops/cast.cc b/lib/op-attrs/src/op-attrs/ops/cast.cc index 444409ffcb..cfbfd61ced 100644 --- a/lib/op-attrs/src/op-attrs/ops/cast.cc +++ b/lib/op-attrs/src/op-attrs/ops/cast.cc @@ -1,4 +1,5 @@ #include "op-attrs/ops/cast.h" +#include "op-attrs/datatype.h" namespace FlexFlow { diff --git a/lib/op-attrs/src/op-attrs/ops/concat.cc b/lib/op-attrs/src/op-attrs/ops/concat.cc index 065c58f365..02fee70bea 100644 --- a/lib/op-attrs/src/op-attrs/ops/concat.cc +++ b/lib/op-attrs/src/op-attrs/ops/concat.cc @@ -11,4 +11,14 @@ namespace FlexFlow { /* return valid; */ /* } */ +TensorShape get_output_shape(ConcatAttrs const &, + std::vector const &) { + NOT_IMPLEMENTED(); +} + +ParallelTensorShape get_output_shape(ConcatAttrs const &, + std::vector const &) { + NOT_IMPLEMENTED(); +} + } // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/embedding.cc b/lib/op-attrs/src/op-attrs/ops/embedding.cc index 4a7d4395b6..d10d52c6f5 100644 --- a/lib/op-attrs/src/op-attrs/ops/embedding.cc +++ b/lib/op-attrs/src/op-attrs/ops/embedding.cc @@ -1,6 +1,7 @@ #include "op-attrs/ops/embedding.h" #include "op-attrs/dim_ordered/slice.h" #include "op-attrs/dim_ordered/transform.h" +#include "op-attrs/parallel_tensor_dims.h" #include "utils/containers/product.h" #include "utils/integer_conversions.h" diff --git a/lib/op-attrs/src/op-attrs/ops/flat.cc b/lib/op-attrs/src/op-attrs/ops/flat.cc index b0683c5f08..5d318207ee 100644 --- a/lib/op-attrs/src/op-attrs/ops/flat.cc +++ b/lib/op-attrs/src/op-attrs/ops/flat.cc @@ -3,15 +3,24 @@ namespace FlexFlow { -namespace Input { -constexpr int NUMDIM = 5, WIDTH = 0, HEIGHT = 1, CHANNEL = 2, SAMPLE = 3, - REPLICA = 4; +TensorShape get_output_shape(FlatAttrs const &, TensorShape const &) { + NOT_IMPLEMENTED(); } -namespace Output { -constexpr int NUMDIM = 3, CHANNEL = 0, SAMPLE = 1, REPLICA = 2; +ParallelTensorShape get_output_shape(FlatAttrs const &, + ParallelTensorShape const &) { + NOT_IMPLEMENTED(); } +// namespace Input { +// constexpr int NUMDIM = 5, WIDTH = 0, HEIGHT = 1, CHANNEL = 2, SAMPLE = 3, +// REPLICA = 4; +// } +// +// namespace Output { +// constexpr int NUMDIM = 3, CHANNEL = 0, SAMPLE = 1, REPLICA = 2; +// } +// /* bool FlatAttrs::is_valid(ParallelTensorShape const &input) const { */ /* ParallelTensorShape output_shape = this->calculate_output_shape(input); */ diff --git a/lib/op-attrs/src/op-attrs/ops/gather.cc b/lib/op-attrs/src/op-attrs/ops/gather.cc index 4f2c13c794..4b1053aee1 100644 --- a/lib/op-attrs/src/op-attrs/ops/gather.cc +++ b/lib/op-attrs/src/op-attrs/ops/gather.cc @@ -2,6 +2,18 @@ namespace FlexFlow { +TensorShape get_output_shape(GatherAttrs const &, + TensorShape const &input, + TensorShape const &index) { + NOT_IMPLEMENTED(); +} + +ParallelTensorShape get_output_shape(GatherAttrs const &, + ParallelTensorShape const &input, + ParallelTensorShape const &index) { + NOT_IMPLEMENTED(); +} + /* bool GatherAttrs::is_valid(ParallelTensorShape const &lhs, * ParallelTensorShape const &rhs) const { */ /* if (lhs.num_dims() != rhs.num_dims()) { */ diff --git a/lib/op-attrs/src/op-attrs/ops/input.cc b/lib/op-attrs/src/op-attrs/ops/input.cc index 93606b603a..acc0b02e69 100644 --- a/lib/op-attrs/src/op-attrs/ops/input.cc +++ b/lib/op-attrs/src/op-attrs/ops/input.cc @@ -2,7 +2,11 @@ namespace FlexFlow { -ParallelTensorShape get_output_shape(InputAttrs const &) { +TensorShape get_output_shape(InputAttrs const &) { + NOT_IMPLEMENTED(); +} + +ParallelTensorShape get_output_parallel_tensor_shape(InputAttrs const &) { NOT_IMPLEMENTED(); } diff --git a/lib/op-attrs/src/op-attrs/ops/layer_norm.cc b/lib/op-attrs/src/op-attrs/ops/layer_norm.cc index 76a5e25dfc..b9603d7850 100644 --- a/lib/op-attrs/src/op-attrs/ops/layer_norm.cc +++ b/lib/op-attrs/src/op-attrs/ops/layer_norm.cc @@ -1,6 +1,8 @@ #include "op-attrs/ops/layer_norm.h" #include "op-attrs/dim_ordered/ff_ordered_of.h" #include "op-attrs/dim_ordered/get_idxs.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "op-attrs/tensor_shape.h" #include "utils/containers/all_of.h" #include "utils/containers/any_of.h" #include "utils/containers/contains.h" diff --git a/lib/op-attrs/src/op-attrs/ops/linear.cc b/lib/op-attrs/src/op-attrs/ops/linear.cc index beb944d1a0..24a8250690 100644 --- a/lib/op-attrs/src/op-attrs/ops/linear.cc +++ b/lib/op-attrs/src/op-attrs/ops/linear.cc @@ -8,6 +8,22 @@ namespace FlexFlow { +RecordFormatter as_dot(LinearAttrs const &attrs) { + RecordFormatter r; + + auto kv = [](std::string const &label, auto const &val) { + RecordFormatter rr; + rr << label << fmt::to_string(val); + return rr; + }; + + r << kv("out_channels", attrs.out_channels) << kv("use_bias", attrs.use_bias) + << kv("data_type", attrs.data_type) << kv("activation", attrs.activation) + << kv("regularizer", attrs.regularizer); + + return r; +} + tl::expected get_kernel_shape(LinearAttrs const &attrs, TensorShape const &input_shape) { size_t in_channels = dim_at_idx(input_shape, ff_dim_t{-1}); diff --git a/lib/op-attrs/src/loss_functions.cc b/lib/op-attrs/src/op-attrs/ops/loss_functions.cc similarity index 80% rename from lib/op-attrs/src/loss_functions.cc rename to lib/op-attrs/src/op-attrs/ops/loss_functions.cc index 094e117d77..e756d08547 100644 --- a/lib/op-attrs/src/loss_functions.cc +++ b/lib/op-attrs/src/op-attrs/ops/loss_functions.cc @@ -1,4 +1,4 @@ -#include "op-attrs/ops/loss_functions.h" +#include "op-attrs/ops/loss_functions/loss_functions.h" #include "utils/containers/transform.h" #include #include @@ -8,20 +8,15 @@ namespace FlexFlow { LossFunction get_loss_type(OtherLossAttrs const &attrs) { return attrs.loss_type; } + LossFunction get_loss_type(SparseCategoricalCrossEntropyLossAttrs const &attrs) { return LossFunction::SPARSE_CATEGORICAL_CROSSENTROPY; } -struct GetLossFunction { - template - LossFunction operator()(T const &t) { - return get_loss_type(t); - } -}; - LossFunction get_loss_type(LossAttrs const &attrs) { - return visit(GetLossFunction{}, attrs); + return attrs.visit( + [](auto const &t) { return get_loss_type(t); }); } LossFunction parse_loss_name(std::string const &raw_name) { @@ -37,8 +32,8 @@ LossFunction parse_loss_name(std::string const &raw_name) { } else if (name == "identity") { return LossFunction::IDENTITY; } else { - throw mk_runtime_error( - "Unknown loss type {}. Please report this as an issue.", name); + throw mk_runtime_error(fmt::format( + "Unknown loss type {}. Please report this as an issue.", name)); } } diff --git a/lib/op-attrs/src/op-attrs/ops/noop.cc b/lib/op-attrs/src/op-attrs/ops/noop.cc index b2b15d820c..6ba33146e4 100644 --- a/lib/op-attrs/src/op-attrs/ops/noop.cc +++ b/lib/op-attrs/src/op-attrs/ops/noop.cc @@ -2,6 +2,11 @@ namespace FlexFlow { +TensorShape get_output_shape(NoopAttrs const &, + TensorShape const &input_shape) { + return input_shape; +} + ParallelTensorShape get_output_shape(NoopAttrs const &, ParallelTensorShape const &input_shape) { return input_shape; diff --git a/lib/op-attrs/src/op-attrs/ops/pool_2d.cc b/lib/op-attrs/src/op-attrs/ops/pool_2d.cc index cf6ed177d3..e1917efd89 100644 --- a/lib/op-attrs/src/op-attrs/ops/pool_2d.cc +++ b/lib/op-attrs/src/op-attrs/ops/pool_2d.cc @@ -2,6 +2,10 @@ namespace FlexFlow { +TensorShape get_output_shape(Pool2DAttrs const &, TensorShape const &) { + NOT_IMPLEMENTED(); +} + ParallelTensorShape get_output_shape(Pool2DAttrs const &, ParallelTensorShape const &) { NOT_IMPLEMENTED(); diff --git a/lib/op-attrs/src/op-attrs/ops/reshape.cc b/lib/op-attrs/src/op-attrs/ops/reshape.cc index 7d0600550a..6216ad8c6c 100644 --- a/lib/op-attrs/src/op-attrs/ops/reshape.cc +++ b/lib/op-attrs/src/op-attrs/ops/reshape.cc @@ -2,6 +2,11 @@ namespace FlexFlow { +TensorShape get_output_shape(ReshapeAttrs const &attrs, + TensorShape const &input_shape) { + NOT_IMPLEMENTED(); +} + ParallelTensorShape get_output_shape(ReshapeAttrs const &attrs, ParallelTensorShape const &input_shape) { NOT_IMPLEMENTED(); diff --git a/lib/op-attrs/src/op-attrs/ops/reverse.cc b/lib/op-attrs/src/op-attrs/ops/reverse.cc index 79b5bd50fb..c38d7e4782 100644 --- a/lib/op-attrs/src/op-attrs/ops/reverse.cc +++ b/lib/op-attrs/src/op-attrs/ops/reverse.cc @@ -2,6 +2,10 @@ namespace FlexFlow { +TensorShape get_output_shape(ReverseAttrs const &, TensorShape const &) { + NOT_IMPLEMENTED(); +} + ParallelTensorShape get_output_shape(ReverseAttrs const &attrs, ParallelTensorShape const &input_shape) { NOT_IMPLEMENTED(); diff --git a/lib/op-attrs/src/op-attrs/ops/split.cc b/lib/op-attrs/src/op-attrs/ops/split.cc index cfb4071833..a9fe691584 100644 --- a/lib/op-attrs/src/op-attrs/ops/split.cc +++ b/lib/op-attrs/src/op-attrs/ops/split.cc @@ -2,6 +2,11 @@ namespace FlexFlow { +std::vector get_output_shapes(SplitAttrs const &, + TensorShape const &) { + NOT_IMPLEMENTED(); +} + std::vector get_output_shapes(SplitAttrs const &attrs, ParallelTensorShape const &input_shape) { diff --git a/lib/op-attrs/src/op-attrs/ops/transpose.cc b/lib/op-attrs/src/op-attrs/ops/transpose.cc index 75f7eb3c18..50e6fb35f5 100644 --- a/lib/op-attrs/src/op-attrs/ops/transpose.cc +++ b/lib/op-attrs/src/op-attrs/ops/transpose.cc @@ -2,6 +2,10 @@ namespace FlexFlow { +TensorShape get_output_shape(TransposeAttrs const &, TensorShape const &) { + NOT_IMPLEMENTED(); +} + ParallelTensorShape get_output_shape(TransposeAttrs const &op_attrs, ParallelTensorShape const &input_shape) { NOT_IMPLEMENTED(); diff --git a/lib/op-attrs/src/op-attrs/ops/weight.cc b/lib/op-attrs/src/op-attrs/ops/weight.cc new file mode 100644 index 0000000000..f8b6b7ec49 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/weight.cc @@ -0,0 +1,14 @@ +#include "op-attrs/ops/weight.h" +#include "op-attrs/parallel_tensor_shape.h" + +namespace FlexFlow { + +TensorShape get_output_shape(WeightAttrs const &attrs) { + return attrs.tensor_shape; +} + +ParallelTensorShape get_output_parallel_tensor_shape(WeightAttrs const &attrs) { + return lift_to_parallel(attrs.tensor_shape); +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/parallel_op_attrs.cc b/lib/op-attrs/src/op-attrs/parallel_op_attrs.cc new file mode 100644 index 0000000000..c458d4149d --- /dev/null +++ b/lib/op-attrs/src/op-attrs/parallel_op_attrs.cc @@ -0,0 +1,37 @@ +#include "op-attrs/parallel_op_attrs.h" +#include "op-attrs/ops/combine.h" +#include "op-attrs/ops/reduction.h" +#include "op-attrs/ops/repartition.h" +#include "op-attrs/ops/replicate.h" +#include "utils/overload.h" + +namespace FlexFlow { + +ParallelTensorShape get_output_shape(ParallelOpAttrs const &attrs, + ParallelTensorShape const &input_shape) { + return attrs.visit(overload{ + [&](CombineAttrs const &combine_attrs) { + return throw_if_unexpected( + get_output_shape(combine_attrs, input_shape)); + }, + [&](ReductionAttrs const &reduction_attrs) { + return throw_if_unexpected( + get_output_shape(reduction_attrs, input_shape)); + }, + [&](RepartitionAttrs const &repartition_attrs) { + return throw_if_unexpected( + get_output_shape(repartition_attrs, input_shape)); + }, + [&](ReplicateAttrs const &replicate_attrs) { + return get_output_shape(replicate_attrs, input_shape); + }, + }); +} + +PCGOperatorAttrs + pcg_op_attrs_from_parallel_op_attrs(ParallelOpAttrs const &attrs) { + return attrs.visit( + [](auto const &attrs) { return PCGOperatorAttrs{attrs}; }); +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc index 150fb6a76d..10bf5027a4 100644 --- a/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc +++ b/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc @@ -1,4 +1,5 @@ #include "op-attrs/parallel_tensor_shape.h" +#include "op-attrs/parallel_tensor_dims.h" #include "op-attrs/tensor_dims.h" #include "utils/containers/product.h" #include "utils/containers/transform.h" @@ -74,6 +75,19 @@ ParallelTensorShape }; } +TensorShape require_not_parallel(ParallelTensorShape const &s) { + int total_degree = get_total_parallel_degree(s); + if (total_degree != 1) { + throw mk_runtime_error( + fmt::format("Error: require_not_parallel received a parallel tensor " + "shape with parallel degree {}: {}", + total_degree, + s)); + } + + return get_reduced_shape(s); +} + TensorShape get_tensor_shape_unsafe(ParallelTensorShape const &) { NOT_IMPLEMENTED(); } diff --git a/lib/op-attrs/src/op-attrs/pcg_operator_attrs.cc b/lib/op-attrs/src/op-attrs/pcg_operator_attrs.cc index 74882fe9f2..0bb134da6b 100644 --- a/lib/op-attrs/src/op-attrs/pcg_operator_attrs.cc +++ b/lib/op-attrs/src/op-attrs/pcg_operator_attrs.cc @@ -1,5 +1,6 @@ #include "op-attrs/pcg_operator_attrs.h" #include "op-attrs/get_op_type.h" +#include "op-attrs/ops/linear.h" #include "utils/overload.h" namespace FlexFlow { @@ -64,4 +65,15 @@ ComputationGraphOpAttrs }); } +RecordFormatter as_dot(PCGOperatorAttrs const &attrs) { + return attrs.visit(overload{ + [](LinearAttrs const &l) { return as_dot(l); }, + [&](auto const &) { + RecordFormatter r; + r << fmt::to_string(get_op_type(attrs)); + return r; + }, + }); +} + } // namespace FlexFlow diff --git a/lib/op-attrs/src/op_attrs.cc b/lib/op-attrs/src/op_attrs.cc deleted file mode 100644 index 6125c03a59..0000000000 --- a/lib/op-attrs/src/op_attrs.cc +++ /dev/null @@ -1,10 +0,0 @@ -/* #include "op-attrs/ops/op_attrs.h" */ - -/* namespace FlexFlow { */ - -/* int OpAttrsInterface::num_outputs(std::vector const - * &inputs) const { */ -/* return this->output_shapes(inputs).size(); */ -/* } */ - -/* } */ diff --git a/lib/op-attrs/test/src/op-attrs/dim_ordered/enumerate.cc b/lib/op-attrs/test/src/op-attrs/dim_ordered/enumerate.cc new file mode 100644 index 0000000000..d2c758a05f --- /dev/null +++ b/lib/op-attrs/test/src/op-attrs/dim_ordered/enumerate.cc @@ -0,0 +1,20 @@ +#include "op-attrs/dim_ordered/enumerate.h" +#include "utils/fmt/map.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("enumerate(FFOrdered)") { + FFOrdered input = {"zero", "one", "two"}; + + std::map result = enumerate(input); + std::map correct = { + {ff_dim_t{0}, "zero"}, + {ff_dim_t{1}, "one"}, + {ff_dim_t{2}, "two"}, + }; + + CHECK(result == correct); + } +} diff --git a/lib/op-attrs/test/src/op-attrs/ops/layer_norm.cc b/lib/op-attrs/test/src/op-attrs/ops/layer_norm.cc index 8f5f4054d6..b9dd66df5d 100644 --- a/lib/op-attrs/test/src/op-attrs/ops/layer_norm.cc +++ b/lib/op-attrs/test/src/op-attrs/ops/layer_norm.cc @@ -1,4 +1,5 @@ #include "op-attrs/ops/layer_norm.h" +#include "op-attrs/parallel_tensor_shape.h" #include "test/utils/doctest.h" #include "utils/expected.h" #include "utils/fmt/expected.h" diff --git a/lib/op-attrs/test/src/ops/cast.cc b/lib/op-attrs/test/src/ops/cast.cc index 086d25d042..31030ca0f9 100644 --- a/lib/op-attrs/test/src/ops/cast.cc +++ b/lib/op-attrs/test/src/ops/cast.cc @@ -1,4 +1,5 @@ #include "op-attrs/ops/cast.h" +#include "op-attrs/parallel_tensor_shape.h" #include "test/utils/doctest.h" TEST_SUITE(FF_TEST_SUITE) { diff --git a/lib/pcg/include/pcg/computation_graph_builder.h b/lib/pcg/include/pcg/computation_graph_builder.h index 0ca330408e..c641aed6a4 100644 --- a/lib/pcg/include/pcg/computation_graph_builder.h +++ b/lib/pcg/include/pcg/computation_graph_builder.h @@ -114,11 +114,10 @@ struct ComputationGraphBuilder { std::optional const &kernel_initializer = std::nullopt, std::optional const &name = std::nullopt); // Add a gather layer - std::vector - gather(tensor_guid_t const &input, - tensor_guid_t const &index, - ff_dim_t dim, - std::optional const &name = std::nullopt); + tensor_guid_t gather(tensor_guid_t const &input, + tensor_guid_t const &index, + ff_dim_t dim, + std::optional const &name = std::nullopt); // Add a cache layer tensor_guid_t cache(tensor_guid_t const &input, diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_dataflow_graph.h b/lib/pcg/include/pcg/file_format/v1/graphs/v1_dataflow_graph.h index 0e547e7688..05c486f0f7 100644 --- a/lib/pcg/include/pcg/file_format/v1/graphs/v1_dataflow_graph.h +++ b/lib/pcg/include/pcg/file_format/v1/graphs/v1_dataflow_graph.h @@ -8,7 +8,7 @@ namespace FlexFlow { V1DataflowGraph to_v1(DataflowGraphView const &); V1DataflowGraph to_v1(DataflowGraphView const &, - std::unordered_map const &); + std::unordered_map const &); } // namespace FlexFlow diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_dataflow_graph.struct.toml b/lib/pcg/include/pcg/file_format/v1/graphs/v1_dataflow_graph.struct.toml index dc9dc96f29..d9aade739c 100644 --- a/lib/pcg/include/pcg/file_format/v1/graphs/v1_dataflow_graph.struct.toml +++ b/lib/pcg/include/pcg/file_format/v1/graphs/v1_dataflow_graph.struct.toml @@ -19,7 +19,7 @@ includes = [ [[fields]] name = "nodes" -type = "std::vector" +type = "std::vector" [[fields]] name = "edges" diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_graph_edge.struct.toml b/lib/pcg/include/pcg/file_format/v1/graphs/v1_graph_edge.struct.toml index b0d2546977..752706fe1d 100644 --- a/lib/pcg/include/pcg/file_format/v1/graphs/v1_graph_edge.struct.toml +++ b/lib/pcg/include/pcg/file_format/v1/graphs/v1_graph_edge.struct.toml @@ -11,16 +11,16 @@ features = [ [[fields]] name = "srcNode" -type = "size_t" +type = "int" [[fields]] name = "srcIdx" -type = "size_t" +type = "int" [[fields]] name = "dstNode" -type = "size_t" +type = "int" [[fields]] name = "dstIdx" -type = "size_t" +type = "int" diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.h b/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.h index b1f96c513b..48203d73ae 100644 --- a/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.h +++ b/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.h @@ -3,7 +3,7 @@ #include "pcg/file_format/v1/graphs/v1_dataflow_graph.h" #include "pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.dtg.h" -#include "utils/containers/enumerate.h" +#include "utils/bidict/algorithms/bidict_from_enumerating.h" #include "utils/containers/map_values.h" #include "utils/containers/transform.h" #include "utils/graph/dataflow_graph/algorithms.h" @@ -16,14 +16,14 @@ template V1LabelledDataflowGraph to_v1(LabelledDataflowGraphView const &g) { - bidict nodes = enumerate(get_nodes(g)); + bidict nodes = bidict_from_enumerating(get_nodes(g)); V1DataflowGraph unlabelled = to_v1(g, nodes.reversed()); - std::unordered_map node_labels = map_values( + std::unordered_map node_labels = map_values( nodes.as_unordered_map(), [&](Node const &n) { return g.at(n); }); - std::unordered_map> output_labels = + std::unordered_map> output_labels = map_values(nodes.as_unordered_map(), [&](Node const &n) { return transform(get_outputs(g, n), [&](DataflowOutput const &o) { return g.at(o); }); diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.struct.toml b/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.struct.toml index 0a6a148159..fd8d4c39c4 100644 --- a/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.struct.toml +++ b/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.struct.toml @@ -22,11 +22,11 @@ includes = [ [[fields]] name = "node_labels" -type = "std::unordered_map" +type = "std::unordered_map" [[fields]] name = "output_labels" -type = "std::unordered_map>" +type = "std::unordered_map>" [[fields]] name = "graph" diff --git a/lib/pcg/include/pcg/parallel_computation_graph/generate_weight_transform.h b/lib/pcg/include/pcg/parallel_computation_graph/generate_weight_transform.h new file mode 100644 index 0000000000..eb4928deaa --- /dev/null +++ b/lib/pcg/include/pcg/parallel_computation_graph/generate_weight_transform.h @@ -0,0 +1,16 @@ +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_GENERATE_WEIGHT_TRANSFORM_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_GENERATE_WEIGHT_TRANSFORM_H + +#include "op-attrs/parallel_op_attrs.dtg.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" +#include "op-attrs/tensor_shape.dtg.h" + +namespace FlexFlow { + +std::unordered_set + generate_weight_transform(TensorShape const ¤t, + ParallelTensorShape const &goal); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h index 4caaad06b2..9150681070 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h @@ -28,9 +28,6 @@ std::vector get_layer_outputs(ParallelComputationGraph const &, parallel_layer_guid_t const &); -parallel_layer_guid_t get_source_layer(ParallelComputationGraph const &, - parallel_tensor_guid_t const &); - ParallelLayerAttrs get_parallel_layer_attrs(ParallelComputationGraph const &, parallel_layer_guid_t const &); ParallelTensorAttrs get_parallel_tensor_attrs(ParallelComputationGraph const &, diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_builder.h b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_builder.h index 5b34ee641a..20e947ad58 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_builder.h +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_builder.h @@ -84,10 +84,35 @@ struct ParallelComputationGraphBuilder { std::optional output_bias_initializer = std::nullopt, std::optional const &name = std::nullopt); + parallel_tensor_guid_t + batch_norm(parallel_tensor_guid_t const &input, + bool relu = true, + std::optional const &name = std::nullopt); + parallel_tensor_guid_t relu(parallel_tensor_guid_t const &x, std::optional const &name = std::nullopt); + parallel_tensor_guid_t + identity(parallel_tensor_guid_t const &x, + std::optional const &name = std::nullopt); + + parallel_tensor_guid_t + gelu(parallel_tensor_guid_t const &x, + std::optional const &name = std::nullopt); + + parallel_tensor_guid_t + sigmoid(parallel_tensor_guid_t const &x, + std::optional const &name = std::nullopt); + + parallel_tensor_guid_t + tanh(parallel_tensor_guid_t const &x, + std::optional const &name = std::nullopt); + + parallel_tensor_guid_t + elu(parallel_tensor_guid_t const &x, + std::optional const &name = std::nullopt); + parallel_tensor_guid_t parallel_partition(parallel_tensor_guid_t const &x, ff_dim_t dim, @@ -137,6 +162,15 @@ struct ParallelComputationGraphBuilder { std::vector const &weights, ParallelTensorShape const &output); + parallel_tensor_guid_t + add_weight(ParallelTensorAttrs const &weight_tensor_attrs, + std::optional const &name = std::nullopt); + + parallel_tensor_guid_t + element_unary(ElementUnaryAttrs const &element_unary_attrs, + parallel_tensor_guid_t const &input, + std::optional const &name); + public: ParallelComputationGraph pcg; }; diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_edge.h b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_edge.h new file mode 100644 index 0000000000..7aac8558e4 --- /dev/null +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_edge.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_PARALLEL_COMPUTATION_GRAPH_EDGE_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_PARALLEL_COMPUTATION_GRAPH_EDGE_H + +#include "pcg/parallel_computation_graph/parallel_computation_graph_edge.dtg.h" +#include "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h" +#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h" + +namespace FlexFlow { + +parallel_tensor_guid_t + get_parallel_tensor(ParallelComputationGraphEdge const &); +parallel_layer_guid_t get_src_layer(ParallelComputationGraphEdge const &); +parallel_layer_guid_t get_dst_layer(ParallelComputationGraphEdge const &); +int get_dst_layer_input_idx(ParallelComputationGraphEdge const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_edge.struct.toml b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_edge.struct.toml new file mode 100644 index 0000000000..25ef3f5d27 --- /dev/null +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_edge.struct.toml @@ -0,0 +1,16 @@ +namespace = "FlexFlow" +name = "ParallelComputationGraphEdge" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/dataflow_graph/dataflow_edge.dtg.h", +] + +[[fields]] +name = "raw_edge" +type = "::FlexFlow::DataflowEdge" diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_guid_t.h b/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_guid_t.h new file mode 100644 index 0000000000..905a365b4b --- /dev/null +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_guid_t.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_PARALLEL_TENSOR_GUID_T_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_PARALLEL_TENSOR_GUID_T_H + +#include "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h" +#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h" + +namespace FlexFlow { + +parallel_layer_guid_t get_source_layer(parallel_tensor_guid_t const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_use_t.struct.toml b/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_use_t.struct.toml new file mode 100644 index 0000000000..6d5e007650 --- /dev/null +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_use_t.struct.toml @@ -0,0 +1,16 @@ +namespace = "FlexFlow" +name = "parallel_tensor_use_t" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/dataflow_graph/dataflow_input.dtg.h", +] + +[[fields]] +name = "raw_dataflow_input" +type = "::FlexFlow::DataflowInput" diff --git a/lib/pcg/src/pcg/computation_graph.cc b/lib/pcg/src/pcg/computation_graph.cc index e7f5f2b737..deaa440ef8 100644 --- a/lib/pcg/src/pcg/computation_graph.cc +++ b/lib/pcg/src/pcg/computation_graph.cc @@ -53,7 +53,7 @@ std::vector get_outgoing_tensors(ComputationGraph const &cg, std::vector get_incoming_tensors(ComputationGraph const &cg, layer_guid_t n) { - return transform(get_inputs(cg.raw_graph, n.raw_node), + return transform(get_input_values(cg.raw_graph, n.raw_node), [](DataflowOutput const &o) { return tensor_guid_t{o}; }); } diff --git a/lib/pcg/src/pcg/computation_graph_builder.cc b/lib/pcg/src/pcg/computation_graph_builder.cc index b6d0e7c890..3f2feaf619 100644 --- a/lib/pcg/src/pcg/computation_graph_builder.cc +++ b/lib/pcg/src/pcg/computation_graph_builder.cc @@ -2,13 +2,24 @@ #include "op-attrs/computation_graph_op_attrs.h" #include "op-attrs/get_op_type.h" #include "op-attrs/get_output_shapes.h" +#include "op-attrs/ops/attention.h" +#include "op-attrs/ops/batch_norm.h" +#include "op-attrs/ops/broadcast.h" +#include "op-attrs/ops/conv_2d.h" +#include "op-attrs/ops/dropout.h" #include "op-attrs/ops/element_binary.h" +#include "op-attrs/ops/element_unary.h" #include "op-attrs/ops/embedding.h" +#include "op-attrs/ops/gather.h" +#include "op-attrs/ops/layer_norm.h" +#include "op-attrs/ops/linear.h" +#include "op-attrs/ops/softmax.h" #include "op-attrs/ops/weight_attrs.dtg.h" #include "pcg/computation_graph.h" #include "utils/containers/any_of.h" #include "utils/containers/concat_vectors.h" #include "utils/containers/enumerate_vector.h" +#include "utils/containers/get_only.h" #include "utils/containers/transform.h" #include "utils/expected.h" #include @@ -49,7 +60,7 @@ std::vector ComputationGraphBuilder::add_layer( return fmt::format("{}.weights[{}]", layer_name, weight_idx); }); LayerAttrs weight_layer_attrs = LayerAttrs{ - ComputationGraphOpAttrs{WeightAttrs{}}, + ComputationGraphOpAttrs{WeightAttrs{weight_tensor_attrs.shape}}, weight_name, }; std::vector weight_layer_inputs = {}; @@ -451,7 +462,7 @@ tensor_guid_t ComputationGraphBuilder::embedding( return this->add_layer(layer, {input}, {weight_attrs}, output_shape); } -std::vector ComputationGraphBuilder::gather( +tensor_guid_t ComputationGraphBuilder::gather( tensor_guid_t const &input, tensor_guid_t const &index, ff_dim_t dim, @@ -469,10 +480,10 @@ std::vector ComputationGraphBuilder::gather( DataType::INT32, DataType::INT64); } - std::vector output_shapes = - get_output_shapes(attrs, this->get_shape(input), this->get_shape(index)); + TensorShape output_shape = + get_output_shape(attrs, this->get_shape(input), this->get_shape(index)); - return this->add_layer(layer, {input}, {}, output_shapes); + return this->add_layer(layer, {input}, {}, output_shape); } /* std::vector diff --git a/lib/pcg/src/pcg/file_format/v1/graphs/v1_dataflow_graph.cc b/lib/pcg/src/pcg/file_format/v1/graphs/v1_dataflow_graph.cc index 787ce5bf7d..cf150a339f 100644 --- a/lib/pcg/src/pcg/file_format/v1/graphs/v1_dataflow_graph.cc +++ b/lib/pcg/src/pcg/file_format/v1/graphs/v1_dataflow_graph.cc @@ -1,4 +1,5 @@ #include "pcg/file_format/v1/graphs/v1_dataflow_graph.h" +#include "utils/bidict/algorithms/bidict_from_enumerating.h" #include "utils/containers/enumerate.h" #include "utils/containers/sorted.h" #include "utils/containers/values.h" @@ -9,17 +10,19 @@ namespace FlexFlow { V1DataflowGraph to_v1(DataflowGraphView const &g) { - return to_v1(g, enumerate(get_nodes(g)).reversed()); + bidict node_enumeration_bidict = + bidict_from_enumerating(get_nodes(g)); + std::unordered_map node_enumeration = + node_enumeration_bidict.reversed().as_unordered_map(); + return to_v1(g, node_enumeration); } V1DataflowGraph to_v1(DataflowGraphView const &g, - std::unordered_map const &nodes) { + std::unordered_map const &nodes) { std::unordered_set edges; for (DataflowEdge const &e : get_edges(g)) { - edges.insert(V1GraphEdge{nodes.at(e.src.node), - size_t_from_int(e.src.idx), - nodes.at(e.dst.node), - size_t_from_int(e.dst.idx)}); + edges.insert(V1GraphEdge{ + nodes.at(e.src.node), e.src.idx, nodes.at(e.dst.node), e.dst.idx}); } return V1DataflowGraph{ diff --git a/lib/pcg/src/pcg/parallel_computation_graph/generate_weight_transform.cc b/lib/pcg/src/pcg/parallel_computation_graph/generate_weight_transform.cc new file mode 100644 index 0000000000..dadad6277f --- /dev/null +++ b/lib/pcg/src/pcg/parallel_computation_graph/generate_weight_transform.cc @@ -0,0 +1,35 @@ +#include "pcg/parallel_computation_graph/generate_weight_transform.h" +#include "op-attrs/dim_ordered/enumerate.h" +#include "op-attrs/parallel_tensor_shape.h" + +namespace FlexFlow { + +std::unordered_set + generate_weight_transform(TensorShape const ¤t, + ParallelTensorShape const &goal) { + std::unordered_set result; + + int sum_degree = get_sum_degree(goal); + if (sum_degree != 1) { + throw mk_runtime_error( + fmt::format("generate_weight_transform currently only supports " + "sum_degree = 1, but received {}", + sum_degree)); + } + + int discard_copy_degree = get_discard_copy_degree(goal); + if (discard_copy_degree != 1) { + result.insert(ParallelOpAttrs{ReplicateAttrs{discard_copy_degree}}); + } + + for (auto const &[shard_dim, shard_degree] : + enumerate(ff_ordered_shard_degrees(goal))) { + if (shard_degree != 1) { + result.insert(ParallelOpAttrs{RepartitionAttrs{shard_dim, shard_degree}}); + } + } + + return result; +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc index 831287567d..5b178160cd 100644 --- a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc @@ -44,7 +44,7 @@ std::vector get_layer_inputs(ParallelComputationGraph const &pcg, parallel_layer_guid_t const &l) { return transform( - get_inputs(pcg.raw_graph, l.raw_graph_node), + get_input_values(pcg.raw_graph, l.raw_graph_node), [](DataflowOutput const &o) { return parallel_tensor_guid_t{o}; }); } diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc index b632c984bc..8290a2ff94 100644 --- a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc @@ -1,6 +1,8 @@ #include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" #include "op-attrs/ops/weight_attrs.dtg.h" +#include "op-attrs/parallel_op_attrs.h" #include "op-attrs/pcg_operator_attrs.h" +#include "pcg/parallel_computation_graph/generate_weight_transform.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.h" #include "utils/containers/concat_vectors.h" #include "utils/containers/enumerate_vector.h" @@ -326,11 +328,28 @@ parallel_tensor_guid_t ParallelComputationGraphBuilder::multihead_attention( return this->add_layer(layer, {query, key, value}, weights, output_shape); } -parallel_tensor_guid_t ParallelComputationGraphBuilder::relu( +parallel_tensor_guid_t ParallelComputationGraphBuilder::batch_norm( parallel_tensor_guid_t const &input, + bool relu, std::optional const &maybe_name) { - ElementUnaryAttrs attrs = ElementUnaryAttrs{OperatorType::RELU, std::nullopt}; + BatchNormAttrs attrs = BatchNormAttrs{relu}; + + std::string name = + maybe_name.value_or(get_default_name(PCGOperatorAttrs{attrs})); + + ParallelLayerAttrs layer = ParallelLayerAttrs{PCGOperatorAttrs{attrs}, name}; + + ParallelTensorShape output_shape = + get_output_shape(attrs, this->get_shape(input)); + + return this->add_layer(layer, {input}, {}, {output_shape}); +} + +parallel_tensor_guid_t ParallelComputationGraphBuilder::element_unary( + ElementUnaryAttrs const &attrs, + parallel_tensor_guid_t const &input, + std::optional const &maybe_name) { std::string name = maybe_name.value_or(get_default_name(PCGOperatorAttrs{attrs})); @@ -343,6 +362,78 @@ parallel_tensor_guid_t ParallelComputationGraphBuilder::relu( return this->add_layer(layer, {input}, {}, {output_shape}); } +parallel_tensor_guid_t ParallelComputationGraphBuilder::relu( + parallel_tensor_guid_t const &input, + std::optional const &maybe_name) { + + ElementUnaryAttrs attrs = ElementUnaryAttrs{ + OperatorType::RELU, + std::nullopt, + }; + + return this->element_unary(attrs, input, maybe_name); +} + +parallel_tensor_guid_t ParallelComputationGraphBuilder::identity( + parallel_tensor_guid_t const &input, + std::optional const &maybe_name) { + + ElementUnaryAttrs attrs = ElementUnaryAttrs{ + OperatorType::IDENTITY, + std::nullopt, + }; + + return this->element_unary(attrs, input, maybe_name); +} + +parallel_tensor_guid_t ParallelComputationGraphBuilder::gelu( + parallel_tensor_guid_t const &input, + std::optional const &maybe_name) { + + ElementUnaryAttrs attrs = ElementUnaryAttrs{ + OperatorType::GELU, + std::nullopt, + }; + + return this->element_unary(attrs, input, maybe_name); +} + +parallel_tensor_guid_t ParallelComputationGraphBuilder::sigmoid( + parallel_tensor_guid_t const &input, + std::optional const &maybe_name) { + + ElementUnaryAttrs attrs = ElementUnaryAttrs{ + OperatorType::SIGMOID, + std::nullopt, + }; + + return this->element_unary(attrs, input, maybe_name); +} + +parallel_tensor_guid_t ParallelComputationGraphBuilder::tanh( + parallel_tensor_guid_t const &input, + std::optional const &maybe_name) { + + ElementUnaryAttrs attrs = ElementUnaryAttrs{ + OperatorType::TANH, + std::nullopt, + }; + + return this->element_unary(attrs, input, maybe_name); +} + +parallel_tensor_guid_t ParallelComputationGraphBuilder::elu( + parallel_tensor_guid_t const &input, + std::optional const &maybe_name) { + + ElementUnaryAttrs attrs = ElementUnaryAttrs{ + OperatorType::ELU, + std::nullopt, + }; + + return this->element_unary(attrs, input, maybe_name); +} + parallel_tensor_guid_t ParallelComputationGraphBuilder::parallel_partition( parallel_tensor_guid_t const &input, ff_dim_t dim, @@ -441,6 +532,54 @@ ParallelTensorShape ParallelComputationGraphBuilder::get_shape( return get_parallel_tensor_attrs(this->pcg, t).shape; } +parallel_tensor_guid_t ParallelComputationGraphBuilder::add_weight( + ParallelTensorAttrs const &weight_tensor_attrs, + std::optional const &weight_name) { + ParallelTensorShape par_weight_shape = weight_tensor_attrs.shape; + TensorShape unpar_weight_shape = get_reduced_shape(weight_tensor_attrs.shape); + + ParallelLayerAttrs weight_layer_attrs = ParallelLayerAttrs{ + PCGOperatorAttrs{WeightAttrs{unpar_weight_shape}}, + weight_name, + }; + + std::vector weight_layer_inputs = {}; + std::vector weight_output_attrs = {weight_tensor_attrs}; + + DataflowOutput current_raw_weight_tensor = get_only( + this->pcg.raw_graph + .add_node( + weight_layer_attrs, weight_layer_inputs, weight_output_attrs) + .outputs); + ParallelTensorShape current_shape = lift_to_parallel(unpar_weight_shape); + + for (ParallelOpAttrs const ¶llel_op_attr : + generate_weight_transform(unpar_weight_shape, par_weight_shape)) { + ParallelTensorShape output_shape = + get_output_shape(parallel_op_attr, current_shape); + ParallelTensorAttrs output_attrs = ParallelTensorAttrs{ + output_shape, + std::nullopt, + std::nullopt, + CreateGrad::YES, + }; + + ParallelLayerAttrs layer_attrs = ParallelLayerAttrs{ + pcg_op_attrs_from_parallel_op_attrs(parallel_op_attr), + std::nullopt, + }; + current_raw_weight_tensor = get_only( + this->pcg.raw_graph + .add_node(layer_attrs, {current_raw_weight_tensor}, {output_attrs}) + .outputs); + current_shape = output_shape; + } + + assert(current_shape == par_weight_shape); + + return parallel_tensor_guid_t{current_raw_weight_tensor}; +} + std::vector ParallelComputationGraphBuilder::add_layer( ParallelLayerAttrs const &layer, std::vector const &inputs, @@ -455,18 +594,9 @@ std::vector ParallelComputationGraphBuilder::add_layer( transform(layer.name, [&](std::string const &layer_name) { return fmt::format("{}.weights[{}]", layer_name, weight_idx); }); - ParallelLayerAttrs weight_layer_attrs = ParallelLayerAttrs{ - PCGOperatorAttrs{WeightAttrs{}}, - weight_name, - }; - std::vector weight_layer_inputs = {}; - std::vector weight_output_attrs = { - weight_tensor_attrs}; - raw_weight_tensors.push_back(get_only(this->pcg.raw_graph - .add_node(weight_layer_attrs, - weight_layer_inputs, - weight_output_attrs) - .outputs)); + + raw_weight_tensors.push_back( + this->add_weight(weight_tensor_attrs, weight_name).raw_graph_output); } std::vector raw_inputs = diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_edge.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_edge.cc new file mode 100644 index 0000000000..dca8154eb4 --- /dev/null +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_edge.cc @@ -0,0 +1,22 @@ +#include "pcg/parallel_computation_graph/parallel_computation_graph_edge.h" + +namespace FlexFlow { + +parallel_tensor_guid_t + get_parallel_tensor(ParallelComputationGraphEdge const &e) { + return parallel_tensor_guid_t{e.raw_edge.src}; +} + +parallel_layer_guid_t get_src_layer(ParallelComputationGraphEdge const &e) { + return parallel_layer_guid_t{e.raw_edge.src.node}; +} + +parallel_layer_guid_t get_dst_layer(ParallelComputationGraphEdge const &e) { + return parallel_layer_guid_t{e.raw_edge.dst.node}; +} + +int get_dst_layer_input_idx(ParallelComputationGraphEdge const &e) { + return e.raw_edge.dst.idx; +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_tensor_guid_t.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_tensor_guid_t.cc new file mode 100644 index 0000000000..ad4eae041f --- /dev/null +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_tensor_guid_t.cc @@ -0,0 +1,9 @@ +#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.h" + +namespace FlexFlow { + +parallel_layer_guid_t get_source_layer(parallel_tensor_guid_t const &t) { + return parallel_layer_guid_t{t.raw_graph_output.node}; +} + +} // namespace FlexFlow diff --git a/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc b/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc index db01728cf0..440f735e80 100644 --- a/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc +++ b/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc @@ -2,6 +2,7 @@ #include "op-attrs/parallel_tensor_shape.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.h" #include "pcg/parallel_computation_graph/parallel_layer_attrs.h" +#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.h" #include "test/utils/doctest.h" #include "utils/containers/count.h" #include "utils/containers/generate_map.h" @@ -39,7 +40,7 @@ TEST_SUITE(FF_TEST_SUITE) { parallel_tensor_guid_t rhs = b.create_input_tensor(rhs_shape); parallel_tensor_guid_t out = b.add(lhs, rhs); - parallel_layer_guid_t layer = get_source_layer(b.pcg, out); + parallel_layer_guid_t layer = get_source_layer(out); SUBCASE("inputs") { std::vector result = @@ -102,7 +103,7 @@ TEST_SUITE(FF_TEST_SUITE) { parallel_tensor_guid_t b_tensor = b.create_input_tensor(b_shape); parallel_tensor_guid_t out = b.batch_matmul(a_tensor, b_tensor); - parallel_layer_guid_t layer = get_source_layer(b.pcg, out); + parallel_layer_guid_t layer = get_source_layer(out); SUBCASE("inputs") { std::vector result = @@ -145,7 +146,7 @@ TEST_SUITE(FF_TEST_SUITE) { DataType output_datatype = DataType::DOUBLE; parallel_tensor_guid_t input = b.create_input_tensor(input_shape); parallel_tensor_guid_t output = b.cast(input, output_datatype); - parallel_layer_guid_t layer = get_source_layer(b.pcg, output); + parallel_layer_guid_t layer = get_source_layer(output); SUBCASE("inputs") { std::vector result = @@ -205,7 +206,7 @@ TEST_SUITE(FF_TEST_SUITE) { [&](parallel_layer_guid_t const &l) { return get_parallel_layer_attrs(b.pcg, l); }); - CHECK_MESSAGE(layers.size() == 4, "Incorrect layers ", layers); + CHECK_MESSAGE(layers.size() == 6, "Incorrect layers ", layers); auto num_attrs_of_type = [&](OperatorType op_type) -> int { return count(values(layers), [&](ParallelLayerAttrs const &l) { @@ -222,6 +223,9 @@ TEST_SUITE(FF_TEST_SUITE) { int num_conv_attrs = num_attrs_of_type(OperatorType::CONV2D); CHECK(num_conv_attrs == 1); + int num_replicate_attrs = num_attrs_of_type(OperatorType::REPLICATE); + CHECK(num_replicate_attrs == 2); + parallel_layer_guid_t conv_guid = get_only(without_nullopts(transform( as_vector(items(layers)), [](std::pair const &kv) @@ -307,7 +311,7 @@ TEST_SUITE(FF_TEST_SUITE) { Activation::RELU, /*use_bias=*/true, DataType::FLOAT); - parallel_layer_guid_t layer = get_source_layer(b.pcg, output); + parallel_layer_guid_t layer = get_source_layer(output); SUBCASE("inputs") { std::vector result = @@ -350,7 +354,7 @@ TEST_SUITE(FF_TEST_SUITE) { /*outDim=*/8, AggregateOp::SUM, DataType::FLOAT); - parallel_layer_guid_t layer = get_source_layer(b.pcg, output); + parallel_layer_guid_t layer = get_source_layer(output); SUBCASE("inputs") { std::vector result = @@ -400,7 +404,7 @@ TEST_SUITE(FF_TEST_SUITE) { parallel_tensor_guid_t value = b.create_input_tensor(value_shape); parallel_tensor_guid_t output = b.multihead_attention(query, key, value, embed_dim, num_heads); - parallel_layer_guid_t layer = get_source_layer(b.pcg, output); + parallel_layer_guid_t layer = get_source_layer(output); SUBCASE("inputs") { std::vector result = @@ -441,7 +445,7 @@ TEST_SUITE(FF_TEST_SUITE) { parallel_tensor_guid_t input = b.create_input_tensor(input_shape); parallel_tensor_guid_t output = b.relu(input); - parallel_layer_guid_t layer = get_source_layer(b.pcg, output); + parallel_layer_guid_t layer = get_source_layer(output); SUBCASE("inputs") { std::vector result = @@ -480,7 +484,7 @@ TEST_SUITE(FF_TEST_SUITE) { parallel_tensor_guid_t input = b.create_input_tensor(input_shape); parallel_tensor_guid_t output = b.parallel_partition(input, ff_dim_t{0}, 2); - parallel_layer_guid_t layer = get_source_layer(b.pcg, output); + parallel_layer_guid_t layer = get_source_layer(output); SUBCASE("inputs") { std::vector result = @@ -519,7 +523,7 @@ TEST_SUITE(FF_TEST_SUITE) { parallel_tensor_guid_t input = b.create_input_tensor(input_shape); parallel_tensor_guid_t output = b.parallel_combine(input, ff_dim_t{0}, 2); - parallel_layer_guid_t layer = get_source_layer(b.pcg, output); + parallel_layer_guid_t layer = get_source_layer(output); SUBCASE("inputs") { std::vector result = @@ -558,7 +562,7 @@ TEST_SUITE(FF_TEST_SUITE) { parallel_tensor_guid_t input = b.create_input_tensor(input_shape); parallel_tensor_guid_t output = b.parallel_replicate(input, 2); - parallel_layer_guid_t layer = get_source_layer(b.pcg, output); + parallel_layer_guid_t layer = get_source_layer(output); SUBCASE("inputs") { std::vector result = @@ -597,7 +601,7 @@ TEST_SUITE(FF_TEST_SUITE) { parallel_tensor_guid_t input = b.create_input_tensor(input_shape); parallel_tensor_guid_t output = b.parallel_reduce(input, 2); - parallel_layer_guid_t layer = get_source_layer(b.pcg, output); + parallel_layer_guid_t layer = get_source_layer(output); SUBCASE("inputs") { std::vector result = diff --git a/lib/substitution-generator/CMakeLists.txt b/lib/substitution-generator/CMakeLists.txt index 41005e6a4e..1db0d888ba 100644 --- a/lib/substitution-generator/CMakeLists.txt +++ b/lib/substitution-generator/CMakeLists.txt @@ -11,6 +11,7 @@ ff_add_library( utils op-attrs pcg + substitutions ) # add_subdirectory(ffi) diff --git a/lib/substitution-generator/include/substitution-generator/json.h b/lib/substitution-generator/include/substitution-generator/json.h deleted file mode 100644 index 5563d8a835..0000000000 --- a/lib/substitution-generator/include/substitution-generator/json.h +++ /dev/null @@ -1,59 +0,0 @@ -#ifndef _FLEXFLOW_SUBSTITUTION_LOADER_H -#define _FLEXFLOW_SUBSTITUTION_LOADER_H - -#include "substitution-generator/legacy_operator_type.dtg.h" -#include "substitution-generator/legacy_pm_parameter.dtg.h" -#include -#include -#include - -namespace FlexFlow { - -struct Parameter { - LegacyPMParameter key; - int value; -}; -void from_json(nlohmann::json const &j, Parameter &p); - -struct Tensor { - int opId; - int tsId; -}; -void from_json(nlohmann::json const &j, Tensor &t); - -struct Operator { - LegacyOperatorType op_type; - std::vector input; - std::vector para; - - std::optional at(LegacyPMParameter key) const; -}; -void from_json(nlohmann::json const &j, Operator &t); - -struct MapOutput { - int dstOpId; - int dstTsId; - int srcOpId; - int srcTsId; -}; -void from_json(nlohmann::json const &j, MapOutput &t); - -struct Rule { - std::string name; - std::vector srcOp; - std::vector dstOp; - std::vector mappedOutput; -}; -void from_json(nlohmann::json const &j, Rule &t); - -struct RuleCollection { - std::vector rules; -}; -void from_json(nlohmann::json const &j, RuleCollection &c); - -RuleCollection load_rule_collection(std::istream &s); -RuleCollection load_rule_collection_from_path(std::string const &path); - -} // namespace FlexFlow - -#endif // _FLEXFLOW_SUBSTITUTION_LOADER_H diff --git a/lib/substitution-generator/include/substitution-generator/legacy_rules.h b/lib/substitution-generator/include/substitution-generator/legacy_rules.h new file mode 100644 index 0000000000..a0e0a9790a --- /dev/null +++ b/lib/substitution-generator/include/substitution-generator/legacy_rules.h @@ -0,0 +1,59 @@ +#ifndef _FLEXFLOW_LIB_SUBSTITUTION_GENERATOR_INCLUDE_SUBSTITUTION_GENERATOR_LEGACY_RULES_H +#define _FLEXFLOW_LIB_SUBSTITUTION_GENERATOR_INCLUDE_SUBSTITUTION_GENERATOR_LEGACY_RULES_H + +#include "substitution-generator/legacy_operator_type.dtg.h" +#include "substitution-generator/legacy_pm_parameter.dtg.h" +#include +#include +#include + +namespace FlexFlow { + +struct LegacyParameter { + LegacyPMParameter key; + int value; +}; +void from_json(nlohmann::json const &j, LegacyParameter &p); + +struct LegacyTensor { + int opId; + int tsId; +}; +void from_json(nlohmann::json const &j, LegacyTensor &t); + +struct LegacyOperator { + LegacyOperatorType op_type; + std::vector input; + std::vector para; + + std::optional at(LegacyPMParameter key) const; +}; +void from_json(nlohmann::json const &j, LegacyOperator &t); + +struct LegacyMapOutput { + int dstOpId; + int dstTsId; + int srcOpId; + int srcTsId; +}; +void from_json(nlohmann::json const &j, LegacyMapOutput &t); + +struct LegacyRule { + std::string name; + std::vector srcOp; + std::vector dstOp; + std::vector mappedOutput; +}; +void from_json(nlohmann::json const &j, LegacyRule &t); + +struct LegacyRuleCollection { + std::vector rules; +}; +void from_json(nlohmann::json const &j, LegacyRuleCollection &c); + +LegacyRuleCollection load_rule_collection(std::istream &s); +LegacyRuleCollection load_rule_collection_from_path(std::string const &path); + +} // namespace FlexFlow + +#endif // _FLEXFLOW_SUBSTITUTION_LOADER_H diff --git a/lib/substitution-generator/src/substitution-generator/json.cc b/lib/substitution-generator/src/substitution-generator/legacy_rules.cc similarity index 67% rename from lib/substitution-generator/src/substitution-generator/json.cc rename to lib/substitution-generator/src/substitution-generator/legacy_rules.cc index 940ecb3e36..157f062cbf 100644 --- a/lib/substitution-generator/src/substitution-generator/json.cc +++ b/lib/substitution-generator/src/substitution-generator/legacy_rules.cc @@ -1,4 +1,4 @@ -#include "substitution-generator/json.h" +#include "substitution-generator/legacy_rules.h" #include #include #include @@ -7,12 +7,12 @@ using json = nlohmann::json; namespace FlexFlow { -void from_json(json const &j, Parameter &p) { +void from_json(json const &j, LegacyParameter &p) { j.at("key").get_to(p.key); j.at("value").get_to(p.value); } -void from_json(json const &j, Tensor &t) { +void from_json(json const &j, LegacyTensor &t) { j.at("opId").get_to(t.opId); j.at("tsId").get_to(t.tsId); } @@ -29,38 +29,38 @@ void from_json(json const &j, Tensor &t) { /* return value; */ /* } */ -void from_json(json const &j, Operator &o) { +void from_json(json const &j, LegacyOperator &o) { j.at("type").get_to(o.op_type); j.at("input").get_to(o.input); j.at("para").get_to(o.para); } -void from_json(json const &j, MapOutput &m) { +void from_json(json const &j, LegacyMapOutput &m) { j.at("dstOpId").get_to(m.dstOpId); j.at("dstTsId").get_to(m.dstTsId); j.at("srcOpId").get_to(m.srcOpId); j.at("srcTsId").get_to(m.srcTsId); } -void from_json(json const &j, Rule &r) { +void from_json(json const &j, LegacyRule &r) { j.at("name").get_to(r.name); j.at("srcOp").get_to(r.srcOp); j.at("dstOp").get_to(r.dstOp); j.at("mappedOutput").get_to(r.mappedOutput); } -void from_json(json const &j, RuleCollection &c) { +void from_json(json const &j, LegacyRuleCollection &c) { j.at("rule").get_to(c.rules); } -RuleCollection load_rule_collection(std::istream &s) { +LegacyRuleCollection load_rule_collection(std::istream &s) { json j; s >> j; - RuleCollection rule_collection = j; + LegacyRuleCollection rule_collection = j; return rule_collection; } -RuleCollection load_rule_collection_from_path(std::string const &path) { +LegacyRuleCollection load_rule_collection_from_path(std::string const &path) { std::ifstream input(path); return load_rule_collection(input); } diff --git a/lib/substitution-generator/test/substitution-generator/json.cc b/lib/substitution-generator/test/substitution-generator/legacy_rules.cc similarity index 88% rename from lib/substitution-generator/test/substitution-generator/json.cc rename to lib/substitution-generator/test/substitution-generator/legacy_rules.cc index befdaf1308..4dd9bb8cc4 100644 --- a/lib/substitution-generator/test/substitution-generator/json.cc +++ b/lib/substitution-generator/test/substitution-generator/legacy_rules.cc @@ -1,4 +1,4 @@ -#include "substitution-generator/json.h" +#include "substitution-generator/legacy_rules.h" #include "doctest/doctest.h" using namespace FlexFlow; @@ -15,7 +15,7 @@ TEST_SUITE(FF_TEST_SUITE) { {"type", "OP_EW_ADD"}, }; - Operator o; + LegacyOperator o; from_json(j, o); CHECK(o.op_type == LegacyOperatorType::EW_ADD); @@ -28,7 +28,7 @@ TEST_SUITE(FF_TEST_SUITE) { } TEST_CASE("deserialize full file") { - RuleCollection collection = + LegacyRuleCollection collection = load_rule_collection_from_path("graph_subst_3_v2.json"); CHECK(collection.rules.size() == 640); } diff --git a/lib/substitutions/include/substitutions/input_parallel_tensor_guid_t.struct.toml b/lib/substitutions/include/substitutions/input_parallel_tensor_guid_t.struct.toml new file mode 100644 index 0000000000..dd2e850aed --- /dev/null +++ b/lib/substitutions/include/substitutions/input_parallel_tensor_guid_t.struct.toml @@ -0,0 +1,16 @@ +namespace = "FlexFlow" +name = "input_parallel_tensor_guid_t" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/open_dataflow_graph/dataflow_graph_input.dtg.h", +] + +[[fields]] +name = "raw_dataflow_graph_input" +type = "::FlexFlow::DataflowGraphInput" diff --git a/lib/substitutions/include/substitutions/open_parallel_tensor_guid_t.h b/lib/substitutions/include/substitutions/open_parallel_tensor_guid_t.h new file mode 100644 index 0000000000..ad60d50db1 --- /dev/null +++ b/lib/substitutions/include/substitutions/open_parallel_tensor_guid_t.h @@ -0,0 +1,29 @@ +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPEN_PARALLEL_TENSOR_GUID_T_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPEN_PARALLEL_TENSOR_GUID_T_H + +#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h" +#include "substitutions/input_parallel_tensor_guid_t.dtg.h" +#include "substitutions/open_parallel_tensor_guid_t.dtg.h" +#include "utils/overload.h" + +namespace FlexFlow { + +open_parallel_tensor_guid_t + open_parallel_tensor_guid_from_closed(parallel_tensor_guid_t); +open_parallel_tensor_guid_t + open_parallel_tensor_guid_from_input(input_parallel_tensor_guid_t); + +template > +Ret visit_open_parallel_tensor_guid(open_parallel_tensor_guid_t t, F f) { + return t.raw_open_dataflow_value.visit(overload{ + [&](DataflowOutput const &o) { return f(parallel_tensor_guid_t{o}); }, + [&](DataflowGraphInput const &i) { + return f(input_parallel_tensor_guid_t{i}); + }, + }); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/substitutions/include/substitutions/open_parallel_tensor_guid_t.struct.toml b/lib/substitutions/include/substitutions/open_parallel_tensor_guid_t.struct.toml new file mode 100644 index 0000000000..f07dc12d62 --- /dev/null +++ b/lib/substitutions/include/substitutions/open_parallel_tensor_guid_t.struct.toml @@ -0,0 +1,16 @@ +namespace = "FlexFlow" +name = "open_parallel_tensor_guid_t" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/open_dataflow_graph/open_dataflow_value.dtg.h" +] + +[[fields]] +name = "raw_open_dataflow_value" +type = "::FlexFlow::OpenDataflowValue" diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_constraint.h b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_constraint.h new file mode 100644 index 0000000000..4affdd697f --- /dev/null +++ b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_constraint.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_CONSTRAINT_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_CONSTRAINT_H + +#include "substitutions/operator_pattern/operator_attribute_constraint.dtg.h" + +namespace FlexFlow { + +OperatorAttributeConstraint op_type_equals_constraint(OperatorType); + +OperatorAttributeConstraint op_attr_key_equals(OperatorAttributeKey, + OperatorAttributeValue const &); +OperatorAttributeConstraint + make_equals_constraint(OperatorAttributeExpr const &, + OperatorAttributeValue const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_expr.h b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_expr.h index e63c03207b..a6324863a6 100644 --- a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_expr.h +++ b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_expr.h @@ -9,8 +9,8 @@ namespace FlexFlow { std::optional - evaluate_attribute_expr(PCGOperatorAttrs const &attrs, - OperatorAttributeExpr const &expr); + evaluate_attribute_expr(OperatorAttributeExpr const &expr, + PCGOperatorAttrs const &attrs); } // namespace FlexFlow #endif diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.variant.toml b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.variant.toml index da2feb1903..7df65ef361 100644 --- a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.variant.toml +++ b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.variant.toml @@ -35,6 +35,9 @@ type = "int" [[values]] type = "bool" +[[values]] +type = "float" + [[values]] type = "std::vector" @@ -45,7 +48,7 @@ type = "std::vector<::FlexFlow::ff_dim_t>" type = "::FlexFlow::OperatorType" [[values]] -type = "::FlexFlow::Activation" +type = "std::optional<::FlexFlow::Activation>" [[values]] type = "::FlexFlow::ff_dim_t" diff --git a/lib/substitutions/include/substitutions/output_graph/materialize_operator_from_attrs_map.h b/lib/substitutions/include/substitutions/output_graph/materialize_operator_from_attrs_map.h new file mode 100644 index 0000000000..cc2fac4805 --- /dev/null +++ b/lib/substitutions/include/substitutions/output_graph/materialize_operator_from_attrs_map.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OUTPUT_GRAPH_MATERIALIZE_OPERATOR_FROM_ATTRS_MAP_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OUTPUT_GRAPH_MATERIALIZE_OPERATOR_FROM_ATTRS_MAP_H + +#include "op-attrs/pcg_operator_attrs.dtg.h" +#include "substitutions/operator_pattern/operator_attribute_key.dtg.h" +#include "substitutions/operator_pattern/operator_attribute_value.dtg.h" + +namespace FlexFlow { + +PCGOperatorAttrs materialize_operator_from_attrs_map( + std::unordered_map const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/substitutions/include/substitutions/output_graph/output_graph_expr.h b/lib/substitutions/include/substitutions/output_graph/output_graph_expr.h new file mode 100644 index 0000000000..e550767292 --- /dev/null +++ b/lib/substitutions/include/substitutions/output_graph/output_graph_expr.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OUTPUT_GRAPH_OUTPUT_GRAPH_EXPR_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OUTPUT_GRAPH_OUTPUT_GRAPH_EXPR_H + +#include "substitutions/output_graph/output_graph_expr.dtg.h" +#include "substitutions/output_graph/output_graph_expr_node.dtg.h" +#include "substitutions/output_graph/output_graph_expr_node_output.dtg.h" + +namespace FlexFlow { + +std::vector + get_node_outputs(OutputGraphExpr const &, OutputGraphExprNode const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/substitutions/include/substitutions/output_graph/output_graph_expr.struct.toml b/lib/substitutions/include/substitutions/output_graph/output_graph_expr.struct.toml index 5caeff92f5..9ad65369a9 100644 --- a/lib/substitutions/include/substitutions/output_graph/output_graph_expr.struct.toml +++ b/lib/substitutions/include/substitutions/output_graph/output_graph_expr.struct.toml @@ -5,8 +5,9 @@ features = [] includes = [ "utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph.h", "substitutions/output_graph/output_operator_attrs_assignment.dtg.h", + "", ] [[fields]] name = "raw_graph" -type = "::FlexFlow::LabelledOpenDataflowGraph<::FlexFlow::OutputOperatorAttrsAssignment, std::nullopt_t>" +type = "::FlexFlow::LabelledOpenDataflowGraph<::FlexFlow::OutputOperatorAttrsAssignment, std::monostate>" diff --git a/lib/substitutions/include/substitutions/output_graph/output_graph_expr_input.struct.toml b/lib/substitutions/include/substitutions/output_graph/output_graph_expr_input.struct.toml new file mode 100644 index 0000000000..fe7a861f0a --- /dev/null +++ b/lib/substitutions/include/substitutions/output_graph/output_graph_expr_input.struct.toml @@ -0,0 +1,16 @@ +namespace = "FlexFlow" +name = "OutputGraphExprInput" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/open_dataflow_graph/dataflow_graph_input.dtg.h", +] + +[[fields]] +name = "raw_dataflow_graph_input" +type = "::FlexFlow::DataflowGraphInput" diff --git a/lib/substitutions/include/substitutions/output_graph/output_graph_expr_node.struct.toml b/lib/substitutions/include/substitutions/output_graph/output_graph_expr_node.struct.toml new file mode 100644 index 0000000000..37c2a1f563 --- /dev/null +++ b/lib/substitutions/include/substitutions/output_graph/output_graph_expr_node.struct.toml @@ -0,0 +1,16 @@ +namespace = "FlexFlow" +name = "OutputGraphExprNode" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/node/node.dtg.h" +] + +[[fields]] +name = "raw_graph_node" +type = "::FlexFlow::Node" diff --git a/lib/substitutions/include/substitutions/output_graph/output_graph_expr_node_output.struct.toml b/lib/substitutions/include/substitutions/output_graph/output_graph_expr_node_output.struct.toml new file mode 100644 index 0000000000..7a2072e385 --- /dev/null +++ b/lib/substitutions/include/substitutions/output_graph/output_graph_expr_node_output.struct.toml @@ -0,0 +1,16 @@ +namespace = "FlexFlow" +name = "OutputGraphExprNodeOutput" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/dataflow_graph/dataflow_output.dtg.h", +] + +[[fields]] +name = "raw_dataflow_output" +type = "::FlexFlow::DataflowOutput" diff --git a/lib/substitutions/include/substitutions/output_graph/output_operator_attr_access.struct.toml b/lib/substitutions/include/substitutions/output_graph/output_operator_attr_access.struct.toml index 5527635a2e..e856249e50 100644 --- a/lib/substitutions/include/substitutions/output_graph/output_operator_attr_access.struct.toml +++ b/lib/substitutions/include/substitutions/output_graph/output_operator_attr_access.struct.toml @@ -8,13 +8,13 @@ features = [ ] includes = [ - "utils/graph/node/node.dtg.h", + "substitutions/unlabelled/pattern_node.dtg.h", "substitutions/operator_pattern/operator_attribute_expr.dtg.h", ] [[fields]] name = "node" -type = "::FlexFlow::Node" +type = "::FlexFlow::PatternNode" # NOTE(@wmdi) I am not sure whether these should be part of attribute expr. [[fields]] diff --git a/lib/substitutions/include/substitutions/output_graph/output_operator_attribute_expr.h b/lib/substitutions/include/substitutions/output_graph/output_operator_attribute_expr.h new file mode 100644 index 0000000000..cba095b444 --- /dev/null +++ b/lib/substitutions/include/substitutions/output_graph/output_operator_attribute_expr.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OUTPUT_GRAPH_OUTPUT_OPERATOR_ATTRIBUTE_EXPR_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OUTPUT_GRAPH_OUTPUT_OPERATOR_ATTRIBUTE_EXPR_H + +#include "output_operator_attribute_expr.dtg.h" +#include "pcg/parallel_computation_graph/parallel_layer_attrs.dtg.h" + +namespace FlexFlow { + +OperatorAttributeValue evaluate_output_operator_attribute_expr( + OutputOperatorAttributeExpr const &, + std::unordered_map const &node_match); + +} // namespace FlexFlow + +#endif diff --git a/lib/substitutions/include/substitutions/output_graph/output_operator_attrs_assignment.h b/lib/substitutions/include/substitutions/output_graph/output_operator_attrs_assignment.h new file mode 100644 index 0000000000..60540c0711 --- /dev/null +++ b/lib/substitutions/include/substitutions/output_graph/output_operator_attrs_assignment.h @@ -0,0 +1,25 @@ +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OUTPUT_GRAPH_OUTPUT_OPERATOR_ATTRS_ASSIGNMENT_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OUTPUT_GRAPH_OUTPUT_OPERATOR_ATTRS_ASSIGNMENT_H + +#include "op-attrs/pcg_operator_attrs.dtg.h" +#include "substitutions/output_graph/output_operator_attrs_assignment.dtg.h" +#include "substitutions/unlabelled/pattern_node.dtg.h" + +namespace FlexFlow { + +OutputOperatorAttrsAssignment output_operator_clone_node(PatternNode const &); + +PCGOperatorAttrs materialize_output_operator_from_attrs_assignment( + OutputOperatorAttrsAssignment const &attrs_assignment, + std::unordered_map const &node_match); + +std::pair + copy_attr_from_pattern_node(OperatorAttributeKey key, + PatternNode const &pattern_node); +std::pair + set_attr_to_constant(OperatorAttributeKey key, + OperatorAttributeValue const &value); + +} // namespace FlexFlow + +#endif diff --git a/lib/substitutions/include/substitutions/output_graph/output_operator_attrs_assignment.struct.toml b/lib/substitutions/include/substitutions/output_graph/output_operator_attrs_assignment.struct.toml index ac91e9f146..d712ea96f7 100644 --- a/lib/substitutions/include/substitutions/output_graph/output_operator_attrs_assignment.struct.toml +++ b/lib/substitutions/include/substitutions/output_graph/output_operator_attrs_assignment.struct.toml @@ -11,6 +11,7 @@ features = [ includes = [ "substitutions/operator_pattern/operator_attribute_key.dtg.h", "substitutions/output_graph/output_operator_attribute_expr.dtg.h", + "substitutions/unlabelled/pattern_node.dtg.h", "", ] @@ -19,6 +20,10 @@ src_includes = [ "utils/fmt/unordered_map.h", ] +# [[fields]] +# name = "clone_operator" +# type = "std::optional" + # NOTE(@wmdi): Not sure if it aligns with other design. Or alternatively we can # define the assignment for each operator type. [[fields]] diff --git a/lib/substitutions/include/substitutions/output_graph/output_pattern_value.struct.toml b/lib/substitutions/include/substitutions/output_graph/output_pattern_value.struct.toml new file mode 100644 index 0000000000..e29eef4cdd --- /dev/null +++ b/lib/substitutions/include/substitutions/output_graph/output_pattern_value.struct.toml @@ -0,0 +1,16 @@ +namespace = "FlexFlow" +name = "OutputPatternValue" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/open_dataflow_graph/open_dataflow_value.dtg.h", +] + +[[fields]] +name = "raw_dataflow_value" +type = "::FlexFlow::OpenDataflowValue" diff --git a/lib/substitutions/include/substitutions/pcg_pattern.h b/lib/substitutions/include/substitutions/pcg_pattern.h index 593f0ddc9e..7342e8169f 100644 --- a/lib/substitutions/include/substitutions/pcg_pattern.h +++ b/lib/substitutions/include/substitutions/pcg_pattern.h @@ -2,18 +2,18 @@ #define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_PCG_PATTERN_H #include "substitutions/pcg_pattern.dtg.h" +#include "substitutions/pcg_pattern_match.dtg.h" #include "substitutions/sub_parallel_computation_graph.dtg.h" #include "substitutions/unlabelled/pattern_matching.h" #include "substitutions/unlabelled/pattern_node.dtg.h" #include "substitutions/unlabelled/pattern_value.dtg.h" -#include "substitutions/unlabelled/unlabelled_graph_pattern.dtg.h" namespace FlexFlow { /** * @brief Find all locations in \p pcg that match \p pattern */ -std::vector +std::vector find_pattern_matches(PCGPattern const &pattern, SubParallelComputationGraph const &pcg); @@ -24,10 +24,12 @@ TensorAttributePattern get_tensor_pattern(PCGPattern const &, OperatorAttributePattern get_operator_pattern(PCGPattern const &, PatternNode const &); std::unordered_set get_inputs(PCGPattern const &); +std::vector get_pattern_node_outputs(PCGPattern const &, + PatternNode const &); bool assignment_satisfies(SubParallelComputationGraph const &, PCGPattern const &, - UnlabelledDataflowGraphPatternMatch const &); + PCGPatternMatch const &); } // namespace FlexFlow diff --git a/lib/substitutions/include/substitutions/pcg_pattern_builder.h b/lib/substitutions/include/substitutions/pcg_pattern_builder.h new file mode 100644 index 0000000000..4c91dd07af --- /dev/null +++ b/lib/substitutions/include/substitutions/pcg_pattern_builder.h @@ -0,0 +1,34 @@ +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_PCG_PATTERN_BUILDER_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_PCG_PATTERN_BUILDER_H + +#include "substitutions/operator_pattern/operator_attribute_pattern.dtg.h" +#include "substitutions/pcg_pattern.dtg.h" +#include "substitutions/tensor_pattern/tensor_attribute_pattern.dtg.h" +#include "substitutions/unlabelled/pattern_value.dtg.h" +#include "utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph.h" + +namespace FlexFlow { + +struct PCGPatternBuilder { + PCGPatternBuilder(); + + PatternValue add_input(); + PatternValue add_input(TensorAttributePattern const &); + + std::vector + add_operator(OperatorAttributePattern const &, + std::vector const &inputs, + std::vector const &outputs); + PatternValue add_operator(OperatorAttributePattern const &, + std::vector const &inputs, + TensorAttributePattern const &output); + + PCGPattern get_pattern() const; + +private: + LabelledOpenDataflowGraph g; +}; + +} // namespace FlexFlow + +#endif diff --git a/lib/substitutions/include/substitutions/pcg_pattern_match.h b/lib/substitutions/include/substitutions/pcg_pattern_match.h new file mode 100644 index 0000000000..388377d70c --- /dev/null +++ b/lib/substitutions/include/substitutions/pcg_pattern_match.h @@ -0,0 +1,24 @@ +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_PCG_PATTERN_MATCH_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_PCG_PATTERN_MATCH_H + +#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h" +#include "substitutions/pcg_pattern.dtg.h" +#include "substitutions/pcg_pattern_match.dtg.h" +#include "substitutions/sub_parallel_computation_graph.dtg.h" +#include "substitutions/unlabelled/pattern_node_output.dtg.h" +#include "substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.dtg.h" + +namespace FlexFlow { + +bidict + get_output_mapping_for_pcg_pattern_match( + PCGPatternMatch const &match, + PCGPattern const &pattern, + SubParallelComputationGraph const &spcg); + +UnlabelledDataflowGraphPatternMatch + get_unlabelled_pattern_match(PCGPatternMatch const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/substitutions/include/substitutions/pcg_pattern_match.struct.toml b/lib/substitutions/include/substitutions/pcg_pattern_match.struct.toml new file mode 100644 index 0000000000..f45bedd2be --- /dev/null +++ b/lib/substitutions/include/substitutions/pcg_pattern_match.struct.toml @@ -0,0 +1,29 @@ +namespace = "FlexFlow" +name = "PCGPatternMatch" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "utils/bidict/bidict.h", + "substitutions/unlabelled/pattern_node.dtg.h", + "substitutions/unlabelled/pattern_input.dtg.h", + "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h", + "substitutions/open_parallel_tensor_guid_t.dtg.h", + "", +] + +src_includes = [ + "utils/fmt/unordered_map.h", + "utils/hash/unordered_map.h", +] + +[[fields]] +name = "node_assignment" +type = "::FlexFlow::bidict<::FlexFlow::PatternNode, ::FlexFlow::parallel_layer_guid_t>" + +[[fields]] +name = "input_assignment" +type = "std::unordered_map<::FlexFlow::PatternInput, ::FlexFlow::open_parallel_tensor_guid_t>" diff --git a/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h b/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h index 42d85dc549..00032045c0 100644 --- a/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h +++ b/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h @@ -2,20 +2,26 @@ #define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_SUB_PARALLEL_COMPUTATION_GRAPH_H #include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_edge.dtg.h" #include "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h" +#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h" +#include "pcg/parallel_computation_graph/parallel_tensor_use_t.dtg.h" +#include "substitutions/open_parallel_tensor_guid_t.dtg.h" #include "substitutions/sub_parallel_computation_graph.dtg.h" +#include "substitutions/sub_parallel_computation_graph_data.dtg.h" +#include "substitutions/sub_parallel_computation_graph_edge.dtg.h" namespace FlexFlow { std::unordered_set get_parallel_layers(SubParallelComputationGraph const &sub_pcg); ParallelLayerAttrs get_parallel_layer_attrs(SubParallelComputationGraph const &, - Node const &); + parallel_layer_guid_t const &); PCGOperatorAttrs get_operator_attrs(SubParallelComputationGraph const &, - Node const &); + parallel_layer_guid_t const &); ParallelTensorAttrs get_parallel_tensor_attrs(SubParallelComputationGraph const &, - OpenDataflowValue const &); + open_parallel_tensor_guid_t const &); SubParallelComputationGraph sub_pcg_from_full_pcg(ParallelComputationGraph const &); ParallelComputationGraph @@ -25,6 +31,41 @@ parallel_layer_guid_t get_parallel_layer_by_name(SubParallelComputationGraph const &pcg, std::string const &name); +std::vector + get_layer_inputs(SubParallelComputationGraph const &, + parallel_layer_guid_t const &); +std::vector + get_layer_outputs(SubParallelComputationGraph const &, + parallel_layer_guid_t const &); + +std::unordered_set get_subgraph_incoming_edges( + SubParallelComputationGraph const &, + std::unordered_set const &); +std::unordered_set get_subgraph_outgoing_edges( + SubParallelComputationGraph const &, + std::unordered_set const &); + +std::unordered_set get_subgraph_incoming_edges( + SubParallelComputationGraph const &, + std::unordered_set const &); + +std::unordered_set + get_parallel_tensor_uses(SubParallelComputationGraph const &, + open_parallel_tensor_guid_t const &); + +SubParallelComputationGraphData + get_sub_pcg_data(SubParallelComputationGraph const &); +SubParallelComputationGraph + sub_pcg_from_graph_data(SubParallelComputationGraphData const &); +bool sub_pcgs_are_isomorphic(SubParallelComputationGraph const &, + SubParallelComputationGraph const &); + +SubParallelComputationGraph + without_layer_names(SubParallelComputationGraph const &); + +std::string as_dot(SubParallelComputationGraph const &); +void debug_print_dot(SubParallelComputationGraph const &); + } // namespace FlexFlow #endif diff --git a/lib/substitutions/include/substitutions/sub_parallel_computation_graph_data.struct.toml b/lib/substitutions/include/substitutions/sub_parallel_computation_graph_data.struct.toml new file mode 100644 index 0000000000..537af231bf --- /dev/null +++ b/lib/substitutions/include/substitutions/sub_parallel_computation_graph_data.struct.toml @@ -0,0 +1,41 @@ +namespace = "FlexFlow" +name = "SubParallelComputationGraphData" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "pcg/parallel_computation_graph/parallel_layer_attrs.dtg.h", + "pcg/parallel_computation_graph/parallel_tensor_attrs.dtg.h", + "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h", + "substitutions/open_parallel_tensor_guid_t.dtg.h", + "substitutions/input_parallel_tensor_guid_t.dtg.h", + "substitutions/sub_parallel_computation_graph_edge.dtg.h", + "", + "", +] + +src_includes = [ + "utils/hash/unordered_map.h", + "utils/hash/unordered_set.h", + "utils/fmt/unordered_map.h", + "utils/fmt/unordered_set.h", +] + +[[fields]] +name = "node_data" +type = "std::unordered_map<::FlexFlow::parallel_layer_guid_t, ::FlexFlow::ParallelLayerAttrs>" + +[[fields]] +name = "edges" +type = "std::unordered_set<::FlexFlow::SubParallelComputationGraphEdge>" + +[[fields]] +name = "inputs" +type = "std::unordered_set<::FlexFlow::input_parallel_tensor_guid_t>" + +[[fields]] +name = "value_data" +type = "std::unordered_map<::FlexFlow::open_parallel_tensor_guid_t, ::FlexFlow::ParallelTensorAttrs>" diff --git a/lib/substitutions/include/substitutions/sub_parallel_computation_graph_edge.h b/lib/substitutions/include/substitutions/sub_parallel_computation_graph_edge.h new file mode 100644 index 0000000000..15cbb6127c --- /dev/null +++ b/lib/substitutions/include/substitutions/sub_parallel_computation_graph_edge.h @@ -0,0 +1,24 @@ +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_SUB_PARALLEL_COMPUTATION_GRAPH_EDGE_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_SUB_PARALLEL_COMPUTATION_GRAPH_EDGE_H + +#include "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h" +#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h" +#include "pcg/parallel_computation_graph/parallel_tensor_use_t.dtg.h" +#include "substitutions/open_parallel_tensor_guid_t.dtg.h" +#include "substitutions/sub_parallel_computation_graph_edge.dtg.h" + +namespace FlexFlow { + +SubParallelComputationGraphEdge + subpcg_edge_from_tensor_and_dst(parallel_tensor_guid_t const &tensor, + parallel_layer_guid_t const &layer, + int input_idx); +SubParallelComputationGraphEdge + subpcg_edge_from_tensor_and_use(open_parallel_tensor_guid_t const &tensor, + parallel_tensor_use_t const &use); +open_parallel_tensor_guid_t + get_parallel_tensor(SubParallelComputationGraphEdge const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/substitutions/include/substitutions/sub_parallel_computation_graph_edge.struct.toml b/lib/substitutions/include/substitutions/sub_parallel_computation_graph_edge.struct.toml new file mode 100644 index 0000000000..6d8f72bae8 --- /dev/null +++ b/lib/substitutions/include/substitutions/sub_parallel_computation_graph_edge.struct.toml @@ -0,0 +1,16 @@ +namespace = "FlexFlow" +name = "SubParallelComputationGraphEdge" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/open_dataflow_graph/open_dataflow_edge.dtg.h", +] + +[[fields]] +name = "raw_edge" +type = "::FlexFlow::OpenDataflowEdge" diff --git a/lib/substitutions/include/substitutions/substitution.h b/lib/substitutions/include/substitutions/substitution.h index 4d3473997b..7b4e5e6912 100644 --- a/lib/substitutions/include/substitutions/substitution.h +++ b/lib/substitutions/include/substitutions/substitution.h @@ -1,9 +1,9 @@ #ifndef _FLEXFLOW_SUBSTITUTIONS_SUBSTITUTION_H #define _FLEXFLOW_SUBSTITUTIONS_SUBSTITUTION_H -#include "sub_parallel_computation_graph.dtg.h" +#include "substitutions/pcg_pattern_match.dtg.h" +#include "substitutions/sub_parallel_computation_graph.dtg.h" #include "substitutions/substitution.dtg.h" -#include "substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.dtg.h" namespace FlexFlow { @@ -23,8 +23,8 @@ namespace FlexFlow { bool is_valid_substitution(Substitution const &); /** - * @brief Applies substitution to sub_pcg at the location specified by match, - * returning the resulting SubParallelComputationGraph + * @brief Applies \p substitution to \p sub_pcg at the location specified by \p + * match, returning the resulting SubParallelComputationGraph * * @param sub_pcg * @param substitution @@ -39,7 +39,7 @@ bool is_valid_substitution(Substitution const &); SubParallelComputationGraph apply_substitution(SubParallelComputationGraph const &sub_pcg, Substitution const &substitution, - UnlabelledDataflowGraphPatternMatch const &match); + PCGPatternMatch const &match); } // namespace FlexFlow diff --git a/lib/substitutions/include/substitutions/substitution.struct.toml b/lib/substitutions/include/substitutions/substitution.struct.toml index f370ef80fd..49bef62747 100644 --- a/lib/substitutions/include/substitutions/substitution.struct.toml +++ b/lib/substitutions/include/substitutions/substitution.struct.toml @@ -5,6 +5,10 @@ features = [] includes = [ "substitutions/pcg_pattern.dtg.h", "substitutions/output_graph/output_graph_expr.dtg.h", + "substitutions/output_graph/output_graph_expr_input.dtg.h", + "substitutions/output_graph/output_graph_expr_node_output.dtg.h", + "substitutions/unlabelled/pattern_input.dtg.h", + "substitutions/unlabelled/pattern_node_output.dtg.h", ] [[fields]] @@ -16,9 +20,9 @@ name = "output_graph_expr" type = "::FlexFlow::OutputGraphExpr" [[fields]] -name = "input_edge_match_to_output" -type = "::FlexFlow::bidict<::FlexFlow::DataflowGraphInput, ::FlexFlow::OpenDataflowValue>" +name = "inputs_mapping" +type = "::FlexFlow::bidict<::FlexFlow::PatternInput, ::FlexFlow::OutputGraphExprInput>" [[fields]] -name = "output_edge_match_to_output" -type = "::FlexFlow::bidict<::FlexFlow::DataflowOutput, ::FlexFlow::DataflowOutput>" +name = "outputs_mapping" +type = "::FlexFlow::bidict<::FlexFlow::PatternNodeOutput, ::FlexFlow::OutputGraphExprNodeOutput>" diff --git a/lib/substitutions/include/substitutions/substitution_internal/evaluate_substitution_output.h b/lib/substitutions/include/substitutions/substitution_internal/evaluate_substitution_output.h new file mode 100644 index 0000000000..a0461b075b --- /dev/null +++ b/lib/substitutions/include/substitutions/substitution_internal/evaluate_substitution_output.h @@ -0,0 +1,31 @@ +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_SUBSTITUTION_INTERNAL_EVALUATE_SUBSTITUTION_OUTPUT_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_SUBSTITUTION_INTERNAL_EVALUATE_SUBSTITUTION_OUTPUT_H + +#include "substitutions/pcg_pattern_match.dtg.h" +#include "substitutions/sub_parallel_computation_graph.dtg.h" +#include "substitutions/substitution.dtg.h" +#include "substitutions/substitution_internal/output_expr_to_result_sub_pcg_mapping.dtg.h" +#include + +namespace FlexFlow { + +/** + * @brief Takes a SubParallelComputationGraph and a PCGPatternMatch where a + * Substitution applies and evaluates the Substitution's OutputGraphExpr + * (producing another SubParallelComputationGraph) using the information from + * the matched nodes. + * + * @details Exists only to enable apply_substitution(SubParallelComputationGraph + * const &, Substitution const &, PCGPatternMatch const &) + * + * @note The resulting SubParallelComputationGraph has new node ids, i.e., does + * not have the same node ids as the OutputGraphExpr + */ +std::pair + evaluate_substitution_output(SubParallelComputationGraph const &spcg, + Substitution const &sub, + PCGPatternMatch const &match); + +} // namespace FlexFlow + +#endif diff --git a/lib/substitutions/include/substitutions/substitution_internal/output_expr_to_result_sub_pcg_mapping.h b/lib/substitutions/include/substitutions/substitution_internal/output_expr_to_result_sub_pcg_mapping.h new file mode 100644 index 0000000000..603cb670bf --- /dev/null +++ b/lib/substitutions/include/substitutions/substitution_internal/output_expr_to_result_sub_pcg_mapping.h @@ -0,0 +1,20 @@ +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_SUBSTITUTION_INTERNAL_OUTPUT_EXPR_TO_RESULT_SUB_PCG_MAPPING_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_SUBSTITUTION_INTERNAL_OUTPUT_EXPR_TO_RESULT_SUB_PCG_MAPPING_H + +#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h" +#include "substitutions/output_graph/output_graph_expr.dtg.h" +#include "substitutions/output_graph/output_graph_expr_node_output.dtg.h" +#include "substitutions/sub_parallel_computation_graph.dtg.h" +#include "substitutions/substitution_internal/output_expr_to_result_sub_pcg_mapping.dtg.h" + +namespace FlexFlow { + +bidict + get_output_graph_expr_output_mapping( + OutputExprToResultSubPCGMapping const &, + OutputGraphExpr const &, + SubParallelComputationGraph const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/substitutions/include/substitutions/substitution_internal/output_expr_to_result_sub_pcg_mapping.struct.toml b/lib/substitutions/include/substitutions/substitution_internal/output_expr_to_result_sub_pcg_mapping.struct.toml new file mode 100644 index 0000000000..1fac79a91d --- /dev/null +++ b/lib/substitutions/include/substitutions/substitution_internal/output_expr_to_result_sub_pcg_mapping.struct.toml @@ -0,0 +1,23 @@ +namespace = "FlexFlow" +name = "OutputExprToResultSubPCGMapping" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "utils/bidict/bidict.h", + "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h", + "substitutions/input_parallel_tensor_guid_t.dtg.h", + "substitutions/output_graph/output_graph_expr_node.dtg.h", + "substitutions/output_graph/output_graph_expr_input.dtg.h", +] + +[[fields]] +name = "node_mapping" +type = "::FlexFlow::bidict<::FlexFlow::parallel_layer_guid_t, ::FlexFlow::OutputGraphExprNode>" + +[[fields]] +name = "input_mapping" +type = "::FlexFlow::bidict<::FlexFlow::input_parallel_tensor_guid_t, ::FlexFlow::OutputGraphExprInput>" diff --git a/lib/substitutions/include/substitutions/substitution_internal/perform_shape_inference.h b/lib/substitutions/include/substitutions/substitution_internal/perform_shape_inference.h new file mode 100644 index 0000000000..de9d1cd78a --- /dev/null +++ b/lib/substitutions/include/substitutions/substitution_internal/perform_shape_inference.h @@ -0,0 +1,33 @@ +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_SUBSTITUTION_INTERNAL_PERFORM_SHAPE_INFERENCE_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_SUBSTITUTION_INTERNAL_PERFORM_SHAPE_INFERENCE_H + +#include "pcg/parallel_computation_graph/parallel_layer_attrs.dtg.h" +#include "utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph_view.h" + +namespace FlexFlow { + +/** + * @brief Takes a SubParallelComputationGraph but without ParallelTensorShape + * annotations on its OpenDataflowValue%s and uses shape inference to fill them + * in. + * + * @details The OutputGraphExpr of a Substitution only computes + * PCGOperatorAttr%s, not ParallelTensorShape%s, under the theory that shapes + * can be inferred by parallel shape inference. The responsibility of this + * function is to traverse the result of evaluating the OutputGraphExpr + * (resulting from evaluate_substitution_output) + * and annotate each of the OpenDataflowValue%s with the inferred shape. + * + * Exists only to enable apply_substitution(SubParallelComputationGraph const &, + * Substitution const &, PCGPatternMatch const &) + */ +LabelledOpenDataflowGraphView + perform_shape_inference( + LabelledOpenDataflowGraphView const + &g, + std::unordered_map const + &input_shapes); + +} // namespace FlexFlow + +#endif diff --git a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_key.enum.toml b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_key.enum.toml index 3df36d13ac..541888038b 100644 --- a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_key.enum.toml +++ b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_key.enum.toml @@ -12,3 +12,9 @@ name = "DIM_SIZES" [[values]] name = "DIM_DEGREES" + +[[values]] +name = "DISCARD_COPY_DEGREE_DIM" + +[[values]] +name = "SUM_DEGREE_DIM" diff --git a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_pattern.h b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_pattern.h new file mode 100644 index 0000000000..5b7ebf4ef8 --- /dev/null +++ b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_pattern.h @@ -0,0 +1,12 @@ +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_TENSOR_ATTRIBUTE_PATTERN_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_TENSOR_ATTRIBUTE_PATTERN_H + +#include "substitutions/tensor_pattern/tensor_attribute_pattern.dtg.h" + +namespace FlexFlow { + +TensorAttributePattern tensor_attribute_pattern_match_all(); + +} // namespace FlexFlow + +#endif diff --git a/lib/substitutions/include/substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.h b/lib/substitutions/include/substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.h index 262ae64bf8..09d6a12716 100644 --- a/lib/substitutions/include/substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.h +++ b/lib/substitutions/include/substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.h @@ -1,6 +1,8 @@ #ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_UNLABELLED_DATAFLOW_GRAPH_PATTERN_MATCH_H #define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_UNLABELLED_DATAFLOW_GRAPH_PATTERN_MATCH_H +#include "substitutions/pcg_pattern.dtg.h" +#include "substitutions/sub_parallel_computation_graph.dtg.h" #include "substitutions/unlabelled/pattern_value.dtg.h" #include "substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.dtg.h" #include @@ -20,6 +22,11 @@ std::optional bidict const &merged_graph_values_to_inputs_of_2); +std::unordered_map + get_output_assignment(SubParallelComputationGraph const &, + PCGPattern const &, + UnlabelledDataflowGraphPatternMatch const &); + } // namespace FlexFlow #endif diff --git a/lib/substitutions/include/substitutions/unlabelled/unlabelled_graph_pattern.h b/lib/substitutions/include/substitutions/unlabelled/unlabelled_graph_pattern.h index 95277edfc3..949fbf455b 100644 --- a/lib/substitutions/include/substitutions/unlabelled/unlabelled_graph_pattern.h +++ b/lib/substitutions/include/substitutions/unlabelled/unlabelled_graph_pattern.h @@ -17,7 +17,8 @@ std::unordered_set get_values(UnlabelledGraphPattern const &); std::vector get_topological_ordering(UnlabelledGraphPattern const &); -std::unordered_set get_inputs(UnlabelledGraphPattern const &); +std::unordered_set + get_graph_inputs(UnlabelledGraphPattern const &); std::unordered_set get_edges(UnlabelledGraphPattern const &); diff --git a/lib/substitutions/src/substitution.cc b/lib/substitutions/src/substitution.cc deleted file mode 100644 index 94993f3c90..0000000000 --- a/lib/substitutions/src/substitution.cc +++ /dev/null @@ -1,387 +0,0 @@ -#include "substitutions/substitution.h" - -namespace FlexFlow { - -/* struct DeriveValidOperatorAttributeExpr { */ -/* template */ -/* std::unordered_set> */ -/* operator()(T const &t) { */ -/* return derive_valid_operator_attribute_expr(t); */ -/* } */ - -/* std::unordered_set> */ -/* derive_valid_operator_attribute_expr(OperatorAttributeKey const &key) { - */ -/* return {key}; */ -/* } */ - -/* std::unordered_set> */ -/* derive_valid_operator_attribute_expr( */ -/* ListIndexAccess const &access) { */ -/* return {access, access.attribute_key}; */ -/* } */ - -/* std::unordered_set> */ -/* derive_valid_operator_attribute_expr( */ -/* ListSize const &ls) { */ -/* return {ls, ls.attribute_key}; */ -/* } */ -/* }; */ - -/* std::unordered_set> */ -/* get_valid_operator_attribute_exprs(OperatorPattern const &pattern) { */ -/* return set_union(transform( */ -/* pattern.attribute_constraints, [](OperatorAttributeConstraint const &t) - * { */ -/* return visit(DeriveValidOperatorAttributeExpr{}, t.attribute_expr); - */ -/* })); */ -/* } */ - -/* bool is_valid_operator_attribute_expr( */ -/* OperatorPattern const &pattern, */ -/* AttributeExpr const &expr) { */ -/* return contains(get_valid_operator_attribute_exprs(pattern), expr); */ -/* } */ - -/* struct IsValidOperatorAttributeExprFunctor { */ -/* GraphPattern const &graph_pattern; */ - -/* template */ -/* bool operator()(T const &t) const { */ -/* return is_valid(t); */ -/* } */ - -/* bool is_valid(OperatorAttrAccess const &t) const { */ -/* return is_valid_operator_attribute_expr(graph_pattern.value().at(t.node), - */ -/* t.attr_expr); */ -/* } */ - -/* bool is_valid(AttrConstant const &t) const { */ -/* return true; */ -/* } */ -/* }; */ - -/* bool is_valid_operator_attribute_expr(GraphPattern const &pattern, */ -/* OperatorAttributeExpr const &expr) { */ -/* return visit(IsValidOperatorAttributeExprFunctor{pattern}, expr); */ -/* } */ - -/* bool is_valid_substitution(Substitution const &s) { */ -/* for (Node const &node : get_nodes(s.output_graph_expr.value())) { */ -/* for (OperatorAttributeExpr expr : */ -/* values(s.output_graph_expr.value().at(node).assignments)) { */ -/* if (!is_valid_operator_attribute_expr(s.input_graph, expr)) { */ -/* return false; */ -/* } */ -/* } */ -/* } */ -/* return true; */ -/* } */ - -/* struct EvaluateOperatorAttributeExpr { */ -/* SubParallelComputationGraph const &graph; */ -/* MultiDiGraphPatternMatch const &match; */ - -/* template */ -/* OperatorAttributeValue operator()(T const &t) { */ -/* return evaluate(t); */ -/* } */ - -/* OperatorAttributeValue evaluate(OperatorAttrAccess const &t) { */ -/* Node node_in_pattern = t.node; */ -/* Node node_in_pcg = match.node_assignment.at_l(node_in_pattern); */ -/* return evaluate_attribute_expr(graph.at(node_in_pcg), - * t.attr_expr).value(); */ -/* } */ - -/* OperatorAttributeValue evaluate(AttrConstant const &t) { */ -/* return t.value; */ -/* } */ -/* }; */ - -/* OperatorAttributeValue */ -/* evaluate_graph_attribute_expr(SubParallelComputationGraph const &g, */ -/* MultiDiGraphPatternMatch const &match, */ -/* OperatorAttributeExpr const &expr) { */ -/* return visit(EvaluateOperatorAttributeExpr{g, match}, expr); */ -/* } */ - -/* Operator get_operator_attrs(SubParallelComputationGraph const &graph, */ -/* MultiDiGraphPatternMatch const &match, */ -/* OperatorAttrAssignment const &assignment) { */ -/* std::unordered_map - * assignments; */ -/* for (auto const &[key, expr] : assignment.assignments) { */ -/* OperatorAttributeValue value = */ -/* evaluate_graph_attribute_expr(graph, match, expr); */ -/* assignments.emplace(key, value); */ -/* } */ -/* assert(contains_key(assignments, OperatorAttributeKey::OP_TYPE)); */ -/* assert(std::holds_alternative( */ -/* assignments.at(OperatorAttributeKey::OP_TYPE))); */ -/* OperatorType op_type = */ -/* std::get(assignments.at(OperatorAttributeKey::OP_TYPE)); - */ -/* switch (op_type) { */ -/* case OperatorType::BATCHMATMUL: */ -/* return Operator{ */ -/* BatchMatmulAttrs{std::get(assignments.at( */ -/* OperatorAttributeKey::A_SEQ_LENGTH_DIM)), */ -/* std::get(assignments.at( */ -/* OperatorAttributeKey::B_SEQ_LENGTH_DIM))}, */ -/* std::nullopt}; */ -/* case OperatorType::BATCHNORM: */ -/* return Operator{BatchNormAttrs{std::get( */ -/* assignments.at(OperatorAttributeKey::RELU))}, */ -/* std::nullopt}; */ -/* case OperatorType::CAST: */ -/* return Operator{CastAttrs{std::get( */ -/* assignments.at(OperatorAttributeKey::DATA_TYPE))}, - */ -/* std::nullopt}; */ -/* case OperatorType::CONCAT: */ -/* return Operator{ */ -/* ConcatAttrs{ */ -/* std::get(assignments.at(OperatorAttributeKey::AXIS)), - */ -/* std::get(assignments.at(OperatorAttributeKey::NUM_INPUTS))}, - */ -/* std::nullopt}; */ -/* case OperatorType::CONV2D: */ -/* return Operator{ */ -/* Conv2DAttrs{ */ -/* std::get(assignments.at(OperatorAttributeKey::OUT_CHANNELS)), - */ -/* std::get(assignments.at(OperatorAttributeKey::KERNEL_H)), - */ -/* std::get(assignments.at(OperatorAttributeKey::KERNEL_W)), - */ -/* std::get(assignments.at(OperatorAttributeKey::STRIDE_H)), - */ -/* std::get(assignments.at(OperatorAttributeKey::STRIDE_W)), - */ -/* std::get(assignments.at(OperatorAttributeKey::PADDING_H)), - */ -/* std::get(assignments.at(OperatorAttributeKey::PADDING_W)), - */ -/* std::get(assignments.at(OperatorAttributeKey::GROUPS)), */ -/* std::get( */ -/* assignments.at(OperatorAttributeKey::ACTIVATION)), */ -/* std::get(assignments.at(OperatorAttributeKey::USE_BIAS))}, - */ -/* std::nullopt}; */ -/* case OperatorType::DROPOUT: */ -/* return Operator{DropoutAttrs{std::get(assignments.at( */ -/* OperatorAttributeKey::RATE)), */ -/* std::get(assignments.at( */ -/* OperatorAttributeKey::SEED))}, */ -/* std::nullopt}; */ -/* case OperatorType::EW_ADD: */ -/* case OperatorType::EW_DIV: */ -/* case OperatorType::EW_EQUAL: */ -/* case OperatorType::EW_GREATER: */ -/* case OperatorType::EW_LESS: */ -/* case OperatorType::EW_MAX: */ -/* case OperatorType::EW_MIN: */ -/* case OperatorType::EW_MUL: */ -/* case OperatorType::EW_SUB: */ -/* return Operator{ */ -/* ElementBinaryAttrs{op_type, */ -/* std::get(assignments.at( */ -/* OperatorAttributeKey::DATA_TYPE)), */ -/* std::get(assignments.at( */ -/* OperatorAttributeKey::SHOULD_BROADCAST_LHS)), - */ -/* std::get(assignments.at( */ -/* OperatorAttributeKey::SHOULD_BROADCAST_RHS))}, - */ -/* std::nullopt}; */ -/* case OperatorType::SCALAR_ADD: */ -/* case OperatorType::SCALAR_FLOOR_DIV: */ -/* case OperatorType::SCALAR_MULTIPLY: */ -/* case OperatorType::SCALAR_SUB: */ -/* case OperatorType::SCALAR_TRUE_DIV: */ -/* return Operator{ */ -/* ElementScalarUnaryAttrs{ */ -/* op_type, */ -/* std::get(assignments.at(OperatorAttributeKey::SCALAR))}, - */ -/* std::nullopt}; */ -/* case OperatorType::EXP: */ -/* case OperatorType::IDENTITY: */ -/* case OperatorType::GELU: */ -/* case OperatorType::RSQRT: */ -/* case OperatorType::POW: */ -/* case OperatorType::SIN: */ -/* case OperatorType::COS: */ -/* return Operator{ElementUnaryAttrs{op_type}, std::nullopt}; */ -/* case OperatorType::EMBEDDING: */ -/* return Operator{ */ -/* EmbeddingAttrs{ */ -/* std::get(assignments.at(OperatorAttributeKey::NUM_ENTRIES)), - */ -/* std::get(assignments.at(OperatorAttributeKey::OUT_CHANNELS)), - */ -/* std::get(assignments.at(OperatorAttributeKey::AGGR)), - */ -/* std::get( */ -/* assignments.at(OperatorAttributeKey::OP_TYPE))}, */ -/* std::nullopt}; */ -/* case OperatorType::FLAT: */ -/* return Operator{FlatAttrs{}, std::nullopt}; */ -/* case OperatorType::GATHER: */ -/* return Operator{GatherAttrs{std::get( */ -/* assignments.at(OperatorAttributeKey::DIM))}, */ -/* std::nullopt}; */ -/* case OperatorType::INPUT: */ -/* return Operator{InputAttrs{}, std::nullopt}; */ -/* case OperatorType::LAYERNORM: */ -/* return Operator{ */ -/* LayerNormAttrs{ */ -/* std::get>( */ -/* assignments.at(OperatorAttributeKey::AXES)), */ -/* std::get( */ -/* assignments.at(OperatorAttributeKey::ELEMENTWISE_AFFINE)), - */ -/* std::get(assignments.at(OperatorAttributeKey::EPSILON))}, - */ -/* std::nullopt}; */ -/* case OperatorType::LINEAR: */ -/* return Operator{ */ -/* LinearAttrs{ */ -/* std::get(assignments.at(OperatorAttributeKey::OUT_CHANNELS)), - */ -/* std::get(assignments.at(OperatorAttributeKey::USE_BIAS)), - */ -/* std::get( */ -/* assignments.at(OperatorAttributeKey::DATA_TYPE)), */ -/* std::get( */ -/* assignments.at(OperatorAttributeKey::ACTIVATION)), */ -/* std::get>( */ -/* assignments.at(OperatorAttributeKey::REGULARIZER))}, */ -/* std::nullopt}; */ -/* case OperatorType::MULTIHEAD_ATTENTION: */ -/* return Operator{ */ -/* MultiHeadAttentionAttrs{ */ -/* std::get(assignments.at(OperatorAttributeKey::EMBED_DIM)), - */ -/* std::get(assignments.at(OperatorAttributeKey::NUM_HEADS)), - */ -/* std::get(assignments.at(OperatorAttributeKey::NUM_HEADS)), - */ -/* std::get(assignments.at(OperatorAttributeKey::VDIM)), */ -/* std::get(assignments.at(OperatorAttributeKey::DROPOUT)), - */ -/* std::get(assignments.at(OperatorAttributeKey::BIAS)), */ -/* std::get(assignments.at(OperatorAttributeKey::ADD_BIAS_KV)), - */ -/* std::get( */ -/* assignments.at(OperatorAttributeKey::ADD_ZERO_ATTN))}, */ -/* std::nullopt}; */ -/* case OperatorType::NOOP: */ -/* return Operator{NoopAttrs{}, std::nullopt}; */ -/* case OperatorType::POOL2D: */ -/* return Operator{ */ -/* Pool2DAttrs{ */ -/* std::get(assignments.at(OperatorAttributeKey::KERNEL_H)), - */ -/* std::get(assignments.at(OperatorAttributeKey::KERNEL_W)), - */ -/* std::get(assignments.at(OperatorAttributeKey::STRIDE_H)), - */ -/* std::get(assignments.at(OperatorAttributeKey::STRIDE_W)), - */ -/* std::get(assignments.at(OperatorAttributeKey::PADDING_H)), - */ -/* std::get(assignments.at(OperatorAttributeKey::PADDING_W)), - */ -/* std::get(assignments.at(OperatorAttributeKey::POOL_TYPE)), - */ -/* std::get( */ -/* assignments.at(OperatorAttributeKey::ACTIVATION))}, */ -/* std::nullopt}; */ -/* case OperatorType::REDUCE_ARGMAX: */ -/* case OperatorType::REDUCE_ARGMIN: */ -/* case OperatorType::REDUCE_MAX: */ -/* case OperatorType::REDUCE_MEAN: */ -/* case OperatorType::REDUCE_MIN: */ -/* case OperatorType::REDUCE_PROD: */ -/* case OperatorType::REDUCE_SUM: */ -/* return Operator{ */ -/* ReduceAttrs{ */ -/* std::get>( */ -/* assignments.at(OperatorAttributeKey::AXES)), */ -/* op_type, */ -/* std::get(assignments.at(OperatorAttributeKey::KEEP_DIMS))}, - */ -/* std::nullopt}; */ -/* case OperatorType::REVERSE: */ -/* return Operator{ReverseAttrs{std::get( */ -/* assignments.at(OperatorAttributeKey::AXIS))}, */ -/* std::nullopt}; */ -/* case OperatorType::RESHAPE: */ -/* return Operator{ReshapeAttrs{std::get( */ -/* assignments.at(OperatorAttributeKey::SHAPE))}, */ -/* std::nullopt}; */ -/* case OperatorType::SPLIT: */ -/* return Operator{ */ -/* SplitAttrs{ */ -/* std::get>( */ -/* assignments.at(OperatorAttributeKey::SPLITS)), */ -/* std::get(assignments.at(OperatorAttributeKey::AXIS))}, - */ -/* std::nullopt}; */ -/* case OperatorType::SOFTMAX: */ -/* return Operator{SoftmaxAttrs{std::get( */ -/* assignments.at(OperatorAttributeKey::DIM))}, */ -/* std::nullopt}; */ -/* case OperatorType::TOPK: */ -/* return Operator{ */ -/* TopKAttrs{ */ -/* std::get(assignments.at(OperatorAttributeKey::K)), */ -/* std::get(assignments.at(OperatorAttributeKey::SORTED))}, - */ -/* std::nullopt}; */ -/* case OperatorType::TRANSPOSE: */ -/* return Operator{ */ -/* TransposeAttrs{std::get>( */ -/* assignments.at(OperatorAttributeKey::PERMUTATION))}, */ -/* std::nullopt}; */ -/* case OperatorType::COMBINE: */ -/* return Operator{CombineAttrs{std::get(assignments.at( */ -/* OperatorAttributeKey::PARALLEL_DIM)), - */ -/* std::get(assignments.at( */ -/* OperatorAttributeKey::PARALLEL_DEGREE))}, - */ -/* std::nullopt}; */ -/* case OperatorType::REDUCTION: */ -/* return Operator{ */ -/* ReductionAttrs{std::get(assignments.at( */ -/* OperatorAttributeKey::PARALLEL_DIM)), */ -/* std::get(assignments.at( */ -/* OperatorAttributeKey::PARALLEL_DEGREE))}, */ -/* std::nullopt}; */ -/* case OperatorType::REPARTITION: */ -/* return Operator{ */ -/* RepartitionAttrs{std::get(assignments.at( */ -/* OperatorAttributeKey::PARALLEL_DIM)), */ -/* std::get(assignments.at( */ -/* OperatorAttributeKey::PARALLEL_DEGREE))}, */ -/* std::nullopt}; */ -/* case OperatorType::REPLICATE: */ -/* return Operator{ */ -/* ReplicateAttrs{std::get(assignments.at( */ -/* OperatorAttributeKey::PARALLEL_DIM)), */ -/* std::get(assignments.at( */ -/* OperatorAttributeKey::PARALLEL_DEGREE))}, */ -/* std::nullopt}; */ -/* default: */ -/* throw mk_runtime_error("Unknown Operator"); */ -/* } */ -/* } */ - -} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/open_parallel_tensor_guid_t.cc b/lib/substitutions/src/substitutions/open_parallel_tensor_guid_t.cc new file mode 100644 index 0000000000..76329229a4 --- /dev/null +++ b/lib/substitutions/src/substitutions/open_parallel_tensor_guid_t.cc @@ -0,0 +1,16 @@ +#include "substitutions/open_parallel_tensor_guid_t.h" + +namespace FlexFlow { + +open_parallel_tensor_guid_t + open_parallel_tensor_guid_from_closed(parallel_tensor_guid_t t) { + return open_parallel_tensor_guid_t{OpenDataflowValue{t.raw_graph_output}}; +} + +open_parallel_tensor_guid_t + open_parallel_tensor_guid_from_input(input_parallel_tensor_guid_t i) { + return open_parallel_tensor_guid_t{ + OpenDataflowValue{i.raw_dataflow_graph_input}}; +} + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc b/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc index b12564faf0..26f8ff5062 100644 --- a/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc +++ b/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc @@ -190,9 +190,7 @@ std::optional get_attribute(LinearAttrs const &p, case OperatorAttributeKey::OUT_CHANNELS: return p.out_channels; case OperatorAttributeKey::USE_BIAS: - return bool(p.use_bias); // NOTE(@wmd): Without casting to bool, it will - // return an OperatorAttributeValue with - // underlying type int. Might be a req issue. + return p.use_bias; case OperatorAttributeKey::DATA_TYPE: return p.data_type; case OperatorAttributeKey::ACTIVATION: @@ -213,6 +211,8 @@ std::optional return p.num_heads; case OperatorAttributeKey::USE_BIAS: return p.bias; + case OperatorAttributeKey::DROPOUT: + return p.dropout; default: return std::nullopt; } @@ -248,7 +248,7 @@ std::optional get_attribute(Pool2DAttrs const &p, case OperatorAttributeKey::POOL_TYPE: return p.pool_type; case OperatorAttributeKey::ACTIVATION: - return p.activation; + return std::optional{p.activation}; default: return std::nullopt; } diff --git a/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_constraint.cc b/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_constraint.cc new file mode 100644 index 0000000000..5ab528ed3d --- /dev/null +++ b/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_constraint.cc @@ -0,0 +1,33 @@ +#include "substitutions/operator_pattern/operator_attribute_constraint.h" + +namespace FlexFlow { + +OperatorAttributeConstraint op_type_equals_constraint(OperatorType op_type) { + return OperatorAttributeConstraint{ + ConstraintType::EQUAL, + OperatorAttributeExpr{OperatorAttributeKey::OP_TYPE}, + OperatorAttributeValue{op_type}, + }; +} + +OperatorAttributeConstraint + op_attr_key_equals(OperatorAttributeKey key, + OperatorAttributeValue const &val) { + return OperatorAttributeConstraint{ + ConstraintType::EQUAL, + OperatorAttributeExpr{key}, + OperatorAttributeValue{val}, + }; +} + +OperatorAttributeConstraint + make_equals_constraint(OperatorAttributeExpr const &expr, + OperatorAttributeValue const &val) { + return OperatorAttributeConstraint{ + ConstraintType::EQUAL, + expr, + val, + }; +} + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_expr.cc b/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_expr.cc index 4a55fa3de3..20f32b129f 100644 --- a/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_expr.cc +++ b/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_expr.cc @@ -7,8 +7,8 @@ namespace FlexFlow { std::optional - evaluate_attribute_expr(PCGOperatorAttrs const &attrs, - OperatorAttributeExpr const &expr) { + evaluate_attribute_expr(OperatorAttributeExpr const &expr, + PCGOperatorAttrs const &attrs) { return expr.visit>(overload{ [&](OperatorAttributeKey const &k) { return get_attribute(attrs, k); }, [&](OperatorAttributeListSize const &k) { diff --git a/lib/substitutions/src/substitutions/operator_pattern/satisfies_constraint.cc b/lib/substitutions/src/substitutions/operator_pattern/satisfies_constraint.cc index ae42515cc8..194ae49255 100644 --- a/lib/substitutions/src/substitutions/operator_pattern/satisfies_constraint.cc +++ b/lib/substitutions/src/substitutions/operator_pattern/satisfies_constraint.cc @@ -7,7 +7,7 @@ bool operator_satisfies_constraint( PCGOperatorAttrs const &attrs, OperatorAttributeConstraint const &constraint) { std::optional expr_val = - evaluate_attribute_expr(attrs, constraint.attribute_expr); + evaluate_attribute_expr(constraint.attribute_expr, attrs); if (!expr_val.has_value()) { return false; diff --git a/lib/substitutions/src/substitutions/output_graph/materialize_operator_from_attrs_map.cc b/lib/substitutions/src/substitutions/output_graph/materialize_operator_from_attrs_map.cc new file mode 100644 index 0000000000..7d65f687c8 --- /dev/null +++ b/lib/substitutions/src/substitutions/output_graph/materialize_operator_from_attrs_map.cc @@ -0,0 +1,155 @@ +#include "substitutions/output_graph/materialize_operator_from_attrs_map.h" +#include "utils/containers/contains_key.h" +#include "utils/fmt/unordered_map.h" + +namespace FlexFlow { + +struct Accessor { + Accessor( + std::unordered_map const &m) + : m(m) {} + + std::unordered_map const &m; + + template + T const &get(OperatorAttributeKey k) const { + if (contains_key(this->m, k)) { + return this->m.at(k).get(); + } else { + throw mk_runtime_error( + fmt::format("Could not find key {} in attrs map: {}", k, this->m)); + } + } +}; + +PCGOperatorAttrs materialize_operator_from_attrs_map( + std::unordered_map const + &attrs) { + OperatorType op_type = + attrs.at(OperatorAttributeKey::OP_TYPE).get(); + + Accessor acc = Accessor{attrs}; + + switch (op_type) { + case OperatorType::MULTIHEAD_ATTENTION: + return PCGOperatorAttrs{MultiHeadAttentionAttrs{ + /*embed_dim=*/acc.get(OperatorAttributeKey::EMBED_DIM), + /*num_heads=*/acc.get(OperatorAttributeKey::NUM_HEADS), + /*kdim=*/acc.get(OperatorAttributeKey::KDIM), + /*vdim=*/acc.get(OperatorAttributeKey::VDIM), + /*dropout=*/acc.get(OperatorAttributeKey::DROPOUT), + /*bias=*/acc.get(OperatorAttributeKey::BIAS), + /*add_bias_kv=*/acc.get(OperatorAttributeKey::ADD_BIAS_KV), + /*add_zero_attn=*/acc.get(OperatorAttributeKey::ADD_ZERO_ATTN), + }}; + case OperatorType::POOL2D: + return PCGOperatorAttrs{Pool2DAttrs{ + /*kernel_h=*/acc.get(OperatorAttributeKey::KERNEL_H), + /*kernel_w=*/acc.get(OperatorAttributeKey::KERNEL_W), + /*stride_h=*/acc.get(OperatorAttributeKey::STRIDE_H), + /*stride_w=*/acc.get(OperatorAttributeKey::STRIDE_W), + /*padding_h=*/acc.get(OperatorAttributeKey::PADDING_H), + /*padding_w=*/acc.get(OperatorAttributeKey::PADDING_W), + /*pool_type=*/acc.get(OperatorAttributeKey::POOL_TYPE), + /*activation=*/ + acc.get>(OperatorAttributeKey::ACTIVATION) + .value(), + }}; + case OperatorType::NOOP: + case OperatorType::INPUT: + case OperatorType::WEIGHT: + case OperatorType::CONV2D: + case OperatorType::DROPOUT: + case OperatorType::LINEAR: + return PCGOperatorAttrs{LinearAttrs{ + /*out_channels=*/acc.get(OperatorAttributeKey::OUT_CHANNELS), + /*use_bias=*/acc.get(OperatorAttributeKey::USE_BIAS), + /*data_type=*/acc.get(OperatorAttributeKey::DATA_TYPE), + /*activation=*/ + acc.get>(OperatorAttributeKey::ACTIVATION), + /*regularizer=*/ + acc.get>( + OperatorAttributeKey::REGULARIZER), + }}; + case OperatorType::BATCHMATMUL: + case OperatorType::SCALAR_MULTIPLY: + case OperatorType::SCALAR_ADD: + case OperatorType::SCALAR_FLOOR_DIV: + case OperatorType::SCALAR_TRUE_DIV: + case OperatorType::SCALAR_SUB: + case OperatorType::RELU: + case OperatorType::IDENTITY: + case OperatorType::SIGMOID: + case OperatorType::TANH: + case OperatorType::ELU: + case OperatorType::FLAT: + case OperatorType::SOFTMAX: + case OperatorType::BATCHNORM: + case OperatorType::CONCAT: + case OperatorType::SPLIT: + case OperatorType::EMBEDDING: + case OperatorType::CACHE: + case OperatorType::RESHAPE: + case OperatorType::REVERSE: + case OperatorType::TRANSPOSE: + case OperatorType::EW_ADD: + case OperatorType::EW_MUL: + case OperatorType::MATMUL: + case OperatorType::MUL: + case OperatorType::ENLARGE: + case OperatorType::SQUEEZE: + case OperatorType::UNSQUEEZE: + case OperatorType::EW_SUB: + case OperatorType::EW_DIV: + case OperatorType::EW_EQUAL: + case OperatorType::EW_GREATER: + case OperatorType::EW_LESS: + case OperatorType::EW_MAX: + case OperatorType::EW_MIN: + case OperatorType::REDUCE_ARGMAX: + case OperatorType::REDUCE_ARGMIN: + case OperatorType::REDUCE_MAX: + case OperatorType::REDUCE_MEAN: + case OperatorType::REDUCE_MIN: + case OperatorType::REDUCE_PROD: + case OperatorType::REDUCE_SUM: + case OperatorType::PAD: + case OperatorType::SHAPE: + case OperatorType::SIZE: + case OperatorType::TOPK: + case OperatorType::WHERE: + case OperatorType::CEIL: + case OperatorType::CAST: + case OperatorType::EXP: + case OperatorType::ROUND: + case OperatorType::LOG: + case OperatorType::LOGICAL_NOT: + case OperatorType::SQRT: + case OperatorType::SIN: + case OperatorType::COS: + case OperatorType::LEAKYRELU: + case OperatorType::SLICE: + case OperatorType::RESIZE: + case OperatorType::PRELU: + case OperatorType::GELU: + case OperatorType::FUSED: + case OperatorType::RSQRT: + case OperatorType::POW: + case OperatorType::MEAN: + case OperatorType::LAYERNORM: + case OperatorType::GATHER: + case OperatorType::BROADCAST: + case OperatorType::REPARTITION: + case OperatorType::COMBINE: + case OperatorType::REPLICATE: + case OperatorType::REDUCTION: + case OperatorType::BATCH: + case OperatorType::PIPELINE: + case OperatorType::FUSED_PARALLEL: + default: + throw mk_runtime_error( + fmt::format("Unsupported operator type {}", op_type)); + } +} + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/output_graph/output_graph_expr.cc b/lib/substitutions/src/substitutions/output_graph/output_graph_expr.cc new file mode 100644 index 0000000000..3d6aadc795 --- /dev/null +++ b/lib/substitutions/src/substitutions/output_graph/output_graph_expr.cc @@ -0,0 +1,17 @@ +#include "substitutions/output_graph/output_graph_expr.h" +#include "utils/containers/transform.h" +#include "utils/graph/dataflow_graph/algorithms.h" + +namespace FlexFlow { + +std::vector + get_node_outputs(OutputGraphExpr const &g, OutputGraphExprNode const &n) { + std::vector raw_outputs = + get_outputs(g.raw_graph, n.raw_graph_node); + + return transform(raw_outputs, [](DataflowOutput const &o) { + return OutputGraphExprNodeOutput{o}; + }); +} + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/output_graph/output_operator_attribute_expr.cc b/lib/substitutions/src/substitutions/output_graph/output_operator_attribute_expr.cc new file mode 100644 index 0000000000..e7cfcf232c --- /dev/null +++ b/lib/substitutions/src/substitutions/output_graph/output_operator_attribute_expr.cc @@ -0,0 +1,19 @@ +#include "substitutions/output_graph/output_operator_attribute_expr.h" +#include "substitutions/operator_pattern/operator_attribute_expr.h" +#include "utils/overload.h" + +namespace FlexFlow { + +OperatorAttributeValue evaluate_output_operator_attribute_expr( + OutputOperatorAttributeExpr const &expr, + std::unordered_map const &node_match) { + return expr.visit(overload{ + [&](OutputOperatorAttrAccess const &a) { + return evaluate_attribute_expr(a.attr_expr, node_match.at(a.node)) + .value(); + }, + [](AttrConstant const &c) { return c.value; }, + }); +} + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/output_graph/output_operator_attrs_assignment.cc b/lib/substitutions/src/substitutions/output_graph/output_operator_attrs_assignment.cc new file mode 100644 index 0000000000..fa247cd151 --- /dev/null +++ b/lib/substitutions/src/substitutions/output_graph/output_operator_attrs_assignment.cc @@ -0,0 +1,42 @@ +#include "substitutions/output_graph/output_operator_attrs_assignment.h" +#include "substitutions/output_graph/materialize_operator_from_attrs_map.h" +#include "substitutions/output_graph/output_operator_attribute_expr.h" +#include "utils/containers/map_values.h" + +namespace FlexFlow { + +OutputOperatorAttrsAssignment output_operator_clone_node(PatternNode const &) { + NOT_IMPLEMENTED(); +} + +PCGOperatorAttrs materialize_output_operator_from_attrs_assignment( + OutputOperatorAttrsAssignment const &attrs_assignment, + std::unordered_map const &node_match) { + std::unordered_map attr_map = + map_values(attrs_assignment.assignments, + [&](OutputOperatorAttributeExpr const &expr) { + return evaluate_output_operator_attribute_expr(expr, + node_match); + }); + + return materialize_operator_from_attrs_map(attr_map); +} + +std::pair + copy_attr_from_pattern_node(OperatorAttributeKey key, + PatternNode const &pattern_node) { + return {key, + OutputOperatorAttributeExpr{OutputOperatorAttrAccess{ + pattern_node, OperatorAttributeExpr{key}}}}; +} + +std::pair + set_attr_to_constant(OperatorAttributeKey key, + OperatorAttributeValue const &value) { + return { + key, + OutputOperatorAttributeExpr{AttrConstant{value}}, + }; +} + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/pcg_pattern.cc b/lib/substitutions/src/substitutions/pcg_pattern.cc index 4591e644bb..e53877006d 100644 --- a/lib/substitutions/src/substitutions/pcg_pattern.cc +++ b/lib/substitutions/src/substitutions/pcg_pattern.cc @@ -1,8 +1,12 @@ #include "substitutions/pcg_pattern.h" #include "substitutions/operator_pattern/satisfies_pattern.h" +#include "substitutions/pcg_pattern_match.h" #include "substitutions/sub_parallel_computation_graph.h" #include "substitutions/tensor_pattern/satisfies_pattern.h" #include "substitutions/unlabelled/pattern_value.h" +#include "utils/containers/map_values.h" +#include "utils/containers/transform.h" +#include "utils/graph/dataflow_graph/algorithms.h" namespace FlexFlow { @@ -12,22 +16,37 @@ static MatchAdditionalCriterion return MatchAdditionalCriterion{ [&](PatternNode const &patternNode, Node const &pcgNode) { return operator_satisfies_pattern( - get_operator_attrs(pcg, pcgNode), + get_operator_attrs(pcg, parallel_layer_guid_t{pcgNode}), get_operator_pattern(pattern, patternNode)); }, [&](PatternValue const &patternValue, OpenDataflowValue const &pcgValue) { return parallel_tensor_satisfies_pattern( - get_parallel_tensor_attrs(pcg, pcgValue), + get_parallel_tensor_attrs(pcg, + open_parallel_tensor_guid_t{pcgValue}), get_tensor_pattern(pattern, patternValue)); }}; } -std::vector +std::vector find_pattern_matches(PCGPattern const &pattern, SubParallelComputationGraph const &pcg) { - return find_pattern_matches(get_unlabelled_pattern(pattern), - pcg.raw_graph, - pcg_pattern_criteria(pattern, pcg)); + std::vector unlabelled_matches = + find_pattern_matches(get_unlabelled_pattern(pattern), + pcg.raw_graph, + pcg_pattern_criteria(pattern, pcg)); + auto pcg_match_from_unlabelled_match = + [](UnlabelledDataflowGraphPatternMatch const &m) { + return PCGPatternMatch{ + map_values(m.node_assignment, + [](Node const &n) { return parallel_layer_guid_t{n}; }), + map_values(m.input_assignment, + [](OpenDataflowValue const &i) { + return open_parallel_tensor_guid_t{i}; + }), + }; + }; + + return transform(unlabelled_matches, pcg_match_from_unlabelled_match); } UnlabelledGraphPattern get_unlabelled_pattern(PCGPattern const &p) { @@ -44,14 +63,25 @@ OperatorAttributePattern get_operator_pattern(PCGPattern const &p, return p.raw_graph.at(n.raw_node); } -bool assignment_satisfies( - SubParallelComputationGraph const &pcg, - PCGPattern const &pattern, - UnlabelledDataflowGraphPatternMatch const &patternMatch) { - return unlabelled_pattern_does_match(get_unlabelled_pattern(pattern), - pcg.raw_graph, - patternMatch, - pcg_pattern_criteria(pattern, pcg)); +std::vector + get_pattern_node_outputs(PCGPattern const &pattern, + PatternNode const &node) { + std::vector raw_outputs = + get_outputs(pattern.raw_graph, node.raw_node); + + return transform(raw_outputs, [](DataflowOutput const &o) { + return PatternNodeOutput{o}; + }); +} + +bool assignment_satisfies(SubParallelComputationGraph const &pcg, + PCGPattern const &pattern, + PCGPatternMatch const &pattern_match) { + return unlabelled_pattern_does_match( + get_unlabelled_pattern(pattern), + pcg.raw_graph, + get_unlabelled_pattern_match(pattern_match), + pcg_pattern_criteria(pattern, pcg)); } } // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/pcg_pattern_builder.cc b/lib/substitutions/src/substitutions/pcg_pattern_builder.cc new file mode 100644 index 0000000000..e81671f08a --- /dev/null +++ b/lib/substitutions/src/substitutions/pcg_pattern_builder.cc @@ -0,0 +1,52 @@ +#include "substitutions/pcg_pattern_builder.h" +#include "substitutions/tensor_pattern/tensor_attribute_pattern.h" +#include "substitutions/unlabelled/pattern_value.h" +#include "utils/containers/get_only.h" +#include "utils/containers/transform.h" +#include "utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h" + +namespace FlexFlow { + +PCGPatternBuilder::PCGPatternBuilder() + : g(LabelledOpenDataflowGraph:: + create>()) {} + +// PatternValue add_input() { +// return tensor_attribute_pattern_match_all(); +// } +// +// PatternValue PCGPatternBuilder::add_input(TensorAttributePattern const &p) { +// return PatternValue{PatternInput{this->g.add_input(p)}}; +// } +// +// std::vector +// PCGPatternBuilder::add_operator(OperatorAttributePattern const &p, +// std::vector const +// &inputs, +// std::vector +// const &outputs) { +// NodeAddedResult node_added_result = this->g.add_node(p, +// transform(inputs, +// raw_open_dataflow_value_from_pattern_value), +// outputs); +// return transform(node_added_result.outputs, +// pattern_value_from_raw_open_dataflow_value); +// } +// +// PatternValue PCGPatternBuilder::add_operator(OperatorAttributePattern const +// &p, +// std::vector const +// &inputs, TensorAttributePattern +// const &output) { +// return get_only(this->add_operator(p, inputs, {output})); +// } +// +// +// PCGPattern PCGPatternBuilder::get_pattern() const { +// return PCGPattern{this->g}; +// } + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/pcg_pattern_match.cc b/lib/substitutions/src/substitutions/pcg_pattern_match.cc new file mode 100644 index 0000000000..f1f4e31d57 --- /dev/null +++ b/lib/substitutions/src/substitutions/pcg_pattern_match.cc @@ -0,0 +1,49 @@ +#include "substitutions/pcg_pattern_match.h" +#include "substitutions/pcg_pattern.h" +#include "substitutions/sub_parallel_computation_graph.h" +#include "utils/bidict/algorithms/bidict_from_keys_and_values.h" +#include "utils/bidict/algorithms/merge_bidicts.h" +#include "utils/containers/map_values.h" +#include "utils/containers/zip.h" + +namespace FlexFlow { + +bidict + get_output_mapping_for_pcg_pattern_match( + PCGPatternMatch const &match, + PCGPattern const &pattern, + SubParallelComputationGraph const &spcg) { + bidict result; + + for (auto const &[pattern_node, matched_layer] : match.node_assignment) { + std::vector matched_layer_output_tensors = + get_layer_outputs(spcg, matched_layer); + std::vector pattern_node_outputs = + get_pattern_node_outputs(pattern, pattern_node); + + assert(matched_layer_output_tensors.size() == pattern_node_outputs.size()); + + bidict mapping = + bidict_from_keys_and_values(pattern_node_outputs, + matched_layer_output_tensors); + + result = merge_bidicts(result, mapping); + } + + return result; +} + +UnlabelledDataflowGraphPatternMatch + get_unlabelled_pattern_match(PCGPatternMatch const &match) { + return UnlabelledDataflowGraphPatternMatch{ + map_values( + match.node_assignment, + [](parallel_layer_guid_t const &l) { return l.raw_graph_node; }), + map_values(match.input_assignment, + [](open_parallel_tensor_guid_t const &i) { + return i.raw_open_dataflow_value; + }), + }; +} + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc b/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc index 2f050ce45e..0bbe0e97a7 100644 --- a/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc +++ b/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc @@ -1,32 +1,43 @@ #include "substitutions/sub_parallel_computation_graph.h" +#include "op-attrs/pcg_operator_attrs.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "utils/containers/values.h" +#include "utils/graph/dataflow_graph/algorithms/get_subgraph_outgoing_edges.h" #include "utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h" #include "utils/graph/labelled_dataflow_graph/algorithms/create_lazy_copy_of_labelled_open_dataflow_graph_view.h" #include "utils/graph/labelled_dataflow_graph/algorithms/view_as_labelled_open_dataflow_graph.h" +#include "utils/graph/labelled_open_dataflow_graph/algorithms/as_dot.h" +#include "utils/graph/labelled_open_dataflow_graph/algorithms/find_isomorphism.h" +#include "utils/graph/labelled_open_dataflow_graph/algorithms/from_labelled_open_dataflow_graph_data.h" +#include "utils/graph/labelled_open_dataflow_graph/algorithms/get_graph_data.h" +#include "utils/graph/labelled_open_dataflow_graph/algorithms/rewrite_node_labels.h" #include "utils/graph/node/algorithms.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_value_uses.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_subgraph_incoming_edges.h" namespace FlexFlow { std::unordered_set get_parallel_layers(SubParallelComputationGraph const &sub_pcg) { - return get_parallel_layers(pcg_from_sub_pcg_by_dropping_inputs(sub_pcg)); + return transform(get_nodes(sub_pcg.raw_graph), + [](Node const &n) { return parallel_layer_guid_t{n}; }); } ParallelLayerAttrs get_parallel_layer_attrs(SubParallelComputationGraph const &spcg, - Node const &n) { - return spcg.raw_graph.at(n); + parallel_layer_guid_t const &layer) { + return spcg.raw_graph.at(layer.raw_graph_node); } PCGOperatorAttrs get_operator_attrs(SubParallelComputationGraph const &spcg, - Node const &n) { + parallel_layer_guid_t const &n) { return get_parallel_layer_attrs(spcg, n).op_attrs; } ParallelTensorAttrs get_parallel_tensor_attrs(SubParallelComputationGraph const &spcg, - OpenDataflowValue const &v) { - return spcg.raw_graph.at(v); + open_parallel_tensor_guid_t const &v) { + return spcg.raw_graph.at(v.raw_open_dataflow_value); } SubParallelComputationGraph @@ -58,4 +69,162 @@ parallel_layer_guid_t name); } +std::vector + get_layer_inputs(SubParallelComputationGraph const &pcg, + parallel_layer_guid_t const &layer) { + return transform(get_inputs(pcg.raw_graph, layer.raw_graph_node), + [](OpenDataflowValue const &v) { + return open_parallel_tensor_guid_t{v}; + }); +} + +std::vector + get_layer_outputs(SubParallelComputationGraph const &pcg, + parallel_layer_guid_t const &layer) { + return transform( + get_outputs(pcg.raw_graph, layer.raw_graph_node), + [](DataflowOutput const &o) { return parallel_tensor_guid_t{o}; }); +} + +std::unordered_set get_subgraph_outgoing_edges( + SubParallelComputationGraph const &spcg, + std::unordered_set const &layers) { + std::unordered_set raw_edges = get_subgraph_outgoing_edges( + spcg.raw_graph, transform(layers, [](parallel_layer_guid_t const &l) { + return l.raw_graph_node; + })); + return transform(raw_edges, [](DataflowEdge const &e) { + return ParallelComputationGraphEdge{e}; + }); +} + +std::unordered_set get_subgraph_incoming_edges( + SubParallelComputationGraph const &spcg, + std::unordered_set const &subgraph) { + std::unordered_set raw_subgraph = + transform(subgraph, [](parallel_layer_guid_t const &l) { + return l.raw_graph_node; + }); + std::unordered_set raw_incoming_edges = + get_subgraph_incoming_edges(spcg.raw_graph, raw_subgraph); + + return transform(raw_incoming_edges, [](OpenDataflowEdge const &e) { + return SubParallelComputationGraphEdge{e}; + }); +} + +std::unordered_set + get_parallel_tensor_uses(SubParallelComputationGraph const &spcg, + open_parallel_tensor_guid_t const &t) { + std::unordered_set raw_uses = + get_open_dataflow_value_uses(spcg.raw_graph, t.raw_open_dataflow_value); + return transform(raw_uses, [](DataflowInput const &i) { + return parallel_tensor_use_t{i}; + }); +} + +SubParallelComputationGraphData + get_sub_pcg_data(SubParallelComputationGraph const &pcg) { + LabelledOpenDataflowGraphData + raw_data = get_graph_data(pcg.raw_graph); + + return SubParallelComputationGraphData{ + map_keys(raw_data.node_data, + [](Node const &n) { return parallel_layer_guid_t{n}; }), + transform(raw_data.edges, + [](OpenDataflowEdge const &e) { + return SubParallelComputationGraphEdge{e}; + }), + transform(raw_data.inputs, + [](DataflowGraphInput const &i) { + return input_parallel_tensor_guid_t{i}; + }), + map_keys(raw_data.value_data, + [](OpenDataflowValue const &v) { + return open_parallel_tensor_guid_t{v}; + }), + }; +} + +SubParallelComputationGraph + sub_pcg_from_graph_data(SubParallelComputationGraphData const &data) { + LabelledOpenDataflowGraphData + raw_data = LabelledOpenDataflowGraphData{ + map_keys( + data.node_data, + [](parallel_layer_guid_t const &l) { return l.raw_graph_node; }), + transform(data.edges, + [](SubParallelComputationGraphEdge const &e) { + return e.raw_edge; + }), + transform(data.inputs, + [](input_parallel_tensor_guid_t const &i) { + return i.raw_dataflow_graph_input; + }), + map_keys(data.value_data, + [](open_parallel_tensor_guid_t const &t) { + return t.raw_open_dataflow_value; + }), + }; + + return SubParallelComputationGraph{ + from_labelled_open_dataflow_graph_data(raw_data), + }; +} + +SubParallelComputationGraph + without_layer_names(SubParallelComputationGraph const &spcg) { + return SubParallelComputationGraph{ + rewrite_node_labels( + spcg.raw_graph, + [](Node const &n, ParallelLayerAttrs const &old_attrs) { + ParallelLayerAttrs new_attrs = old_attrs; + new_attrs.name = std::nullopt; + return new_attrs; + }), + }; +} + +bool sub_pcgs_are_isomorphic(SubParallelComputationGraph const &lhs, + SubParallelComputationGraph const &rhs) { + return find_isomorphism(without_layer_names(lhs).raw_graph, + without_layer_names(rhs).raw_graph) + .has_value(); +} + +std::string as_dot(SubParallelComputationGraph const &spcg) { + std::function get_node_label = + [](ParallelLayerAttrs const &a) -> std::string { + RecordFormatter r = as_dot(a.op_attrs); + + if (a.name.has_value()) { + RecordFormatter rr; + rr << "Name" << a.name.value(); + r << rr; + } + + std::ostringstream oss; + oss << r; + return oss.str(); + }; + + std::function get_input_label = + [](ParallelTensorAttrs const &a) -> std::string { + RecordFormatter r; + + r << fmt::to_string(a.shape); + + std::ostringstream oss; + oss << r; + return oss.str(); + }; + + return as_dot(spcg.raw_graph, get_node_label, get_input_label); +} + +void debug_print_dot(SubParallelComputationGraph const &spcg) { + std::cout << as_dot(spcg) << std::endl; +} + } // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/sub_parallel_computation_graph_edge.cc b/lib/substitutions/src/substitutions/sub_parallel_computation_graph_edge.cc new file mode 100644 index 0000000000..bb8cb449bc --- /dev/null +++ b/lib/substitutions/src/substitutions/sub_parallel_computation_graph_edge.cc @@ -0,0 +1,38 @@ +#include "substitutions/sub_parallel_computation_graph_edge.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_edge.h" + +namespace FlexFlow { + +SubParallelComputationGraphEdge + subpcg_edge_from_tensor_and_dst(parallel_tensor_guid_t const &tensor, + parallel_layer_guid_t const &layer, + int input_idx) { + return SubParallelComputationGraphEdge{ + OpenDataflowEdge{ + DataflowEdge{ + tensor.raw_graph_output, + DataflowInput{ + layer.raw_graph_node, + input_idx, + }, + }, + }, + }; +} + +SubParallelComputationGraphEdge + subpcg_edge_from_tensor_and_use(open_parallel_tensor_guid_t const &tensor, + parallel_tensor_use_t const &use) { + return SubParallelComputationGraphEdge{ + open_dataflow_edge_from_src_and_dst(tensor.raw_open_dataflow_value, + use.raw_dataflow_input), + }; +} + +open_parallel_tensor_guid_t + get_parallel_tensor(SubParallelComputationGraphEdge const &e) { + OpenDataflowValue raw_value = get_open_dataflow_edge_src(e.raw_edge); + return open_parallel_tensor_guid_t{raw_value}; +} + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/substitution.cc b/lib/substitutions/src/substitutions/substitution.cc index b4e6709a73..22e15cb01a 100644 --- a/lib/substitutions/src/substitutions/substitution.cc +++ b/lib/substitutions/src/substitutions/substitution.cc @@ -1,154 +1,169 @@ #include "substitutions/substitution.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_edge.h" +#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.h" +#include "substitutions/open_parallel_tensor_guid_t.h" +#include "substitutions/output_graph/output_operator_attrs_assignment.h" +#include "substitutions/pcg_pattern_match.h" +#include "substitutions/sub_parallel_computation_graph.h" +#include "substitutions/sub_parallel_computation_graph_edge.h" +#include "substitutions/substitution_internal/evaluate_substitution_output.h" +#include "substitutions/substitution_internal/output_expr_to_result_sub_pcg_mapping.h" +#include "utils/containers/merge_maps.h" +#include "utils/containers/restrict_keys.h" +#include "utils/containers/set_minus.h" +#include "utils/containers/values.h" +#include "utils/graph/dataflow_graph/algorithms/get_subgraph_outgoing_edges.h" +#include "utils/graph/node/algorithms.h" +#include "utils/overload.h" namespace FlexFlow { -/* struct AddMappedEdgeFunctor { */ -/* bidict const &node_mapping; */ -/* SubParallelComputationGraph &new_pcg; */ - -/* template */ -/* void operator()(T const &t) { */ -/* return add_mapped_edge(t); */ -/* } */ - -/* void add_mapped_edge(InputMultiDiEdge const &e) { */ -/* new_pcg.add_edge(InputMultiDiEdge{ */ -/* node_mapping.at_l(e.dst), new_pcg.add_node_port(), e.uid}); */ -/* } */ - -/* void add_mapped_edge(OutputMultiDiEdge const &e) { */ -/* new_pcg.add_edge(OutputMultiDiEdge{ */ -/* node_mapping.at_l(e.src), new_pcg.add_node_port(), e.uid}); */ -/* } */ - -/* void add_mapped_edge(MultiDiEdge const &e) { */ -/* new_pcg.add_edge(MultiDiEdge{node_mapping.at_l(e.dst), */ -/* new_pcg.add_node_port(), */ -/* node_mapping.at_l(e.src), */ -/* new_pcg.add_node_port()}); */ -/* } */ -/* }; */ - -/* struct AddNewEdgeFunctor { */ -/* SubParallelComputationGraph const &old_pcg; */ -/* SubParallelComputationGraph &new_pcg; */ -/* MultiDiGraphPatternMatch const &match; */ -/* bidict node_mapping; */ - -/* template */ -/* void operator()(TO const &old_edge, TN const &new_edge) { */ -/* return add_new_edge(old_edge, new_edge); */ -/* } */ - -/* void add_new_edge(InputMultiDiEdge const &old_edge, */ -/* InputMultiDiEdge const &new_edge) { */ -/* new_pcg.add_edge(InputMultiDiEdge{node_mapping.at_l(new_edge.dst), */ -/* new_pcg.add_node_port(), */ -/* old_edge.uid}); */ -/* } */ - -/* void add_new_edge(MultiDiEdge const &old_edge, */ -/* InputMultiDiEdge const &new_edge) { */ -/* new_pcg.add_edge(MultiDiEdge{node_mapping.at_l(new_edge.dst), */ -/* new_pcg.add_node_port(), */ -/* node_mapping.at_l(old_edge.src), */ -/* new_pcg.add_node_port()}); */ -/* } */ - -/* void add_new_edge(OutputMultiDiEdge const &old_edge, */ -/* OutputMultiDiEdge const &new_edge) { */ -/* new_pcg.add_edge(OutputMultiDiEdge{node_mapping.at_l(new_edge.src), */ -/* new_pcg.add_node_port(), */ -/* old_edge.uid}); */ -/* } */ - -/* void add_new_edge(MultiDiEdge const &old_edge, */ -/* OutputMultiDiEdge const &new_edge) { */ -/* new_pcg.add_edge(MultiDiEdge{node_mapping.at_l(old_edge.dst), */ -/* new_pcg.add_node_port(), */ -/* node_mapping.at_l(new_edge.src), */ -/* new_pcg.add_node_port()}); */ -/* } */ - -/* void add_new_edge(InputMultiDiEdge const &, OutputMultiDiEdge const &) { */ -/* assert(false); */ -/* } */ - -/* void add_new_edge(OpenMultiDiEdge const &, MultiDiEdge const &) { */ -/* assert(false); */ -/* } */ - -/* void add_new_edge(OutputMultiDiEdge const &, InputMultiDiEdge const &) { */ -/* assert(false); */ -/* } */ -/* }; */ - -/* SubParallelComputationGraph */ -/* apply_substitution(SubParallelComputationGraph const &pcg, */ -/* Substitution const &substitution, */ -/* MultiDiGraphPatternMatch const &match) { */ -/* SubParallelComputationGraph new_pcg = */ -/* OutputLabelledOpenMultiDiGraph::template - * create< */ -/* UnorderedOutputLabelledOpenMultiDiGraph>(); */ -/* bidict node_mapping; // Refactor it with global nodes */ -/* for (Node const &node : get_nodes(pcg)) { */ -/* if (!contains_r(match.node_assignment, node)) { */ -/* node_mapping.equate(node, new_pcg.add_node(pcg.at(node))); */ -/* } */ -/* } */ -/* for (OpenMultiDiEdge const &edge : get_edges(pcg)) { */ -/* if (!contains_r(match.edge_assignment, edge)) { */ -/* visit(AddMappedEdgeFunctor{node_mapping, new_pcg}, edge); */ -/* } */ -/* } */ -/* for (Node const &output_node : */ -/* get_nodes(substitution.output_graph_expr.value())) { */ -/* Operator new_op = get_operator_attrs( */ -/* pcg, match, substitution.output_graph_expr.value().at(output_node)); - */ -/* Node new_node = new_pcg.add_node(new_op); */ -/* node_mapping.equate(output_node, new_node); */ -/* } */ -/* for (OpenMultiDiEdge const &output_edge : */ -/* get_edges(substitution.output_graph_expr.value())) { */ -/* if (std::holds_alternative(output_edge)) { */ -/* InputMultiDiEdge e = std::get(output_edge); */ -/* OpenMultiDiEdge original_edge = */ -/* match.edge_assignment.at_l(substitution.input_mapping.at_r(e)); */ -/* visit(AddNewEdgeFunctor{pcg, new_pcg, match, node_mapping}, */ -/* original_edge, */ -/* output_edge); */ -/* } else if (std::holds_alternative(output_edge)) { */ -/* OutputMultiDiEdge e = std::get(output_edge); */ -/* OpenMultiDiEdge original_edge = */ -/* match.edge_assignment.at_l(substitution.output_mapping.at_r(e)); */ -/* visit(AddNewEdgeFunctor{pcg, new_pcg, match, node_mapping}, */ -/* original_edge, */ -/* output_edge); */ -/* } else { */ -/* assert(std::holds_alternative(output_edge)); */ -/* MultiDiEdge e = std::get(output_edge); */ -/* new_pcg.add_edge(MultiDiEdge{node_mapping.at_l(e.dst), */ -/* new_pcg.add_node_port(), */ -/* node_mapping.at_l(e.src), */ -/* new_pcg.add_node_port()}); */ -/* } */ -/* } */ - -/* return new_pcg; */ -/* } */ - bool is_valid_substitution(Substitution const &) { NOT_IMPLEMENTED(); } SubParallelComputationGraph - apply_substitution(SubParallelComputationGraph const &, - Substitution const &, - UnlabelledDataflowGraphPatternMatch const &) { - NOT_IMPLEMENTED(); + apply_substitution(SubParallelComputationGraph const &spcg, + Substitution const &sub, + PCGPatternMatch const &match) { + auto substitution_output_result = + evaluate_substitution_output(spcg, sub, match); + SubParallelComputationGraph substitution_output_graph = + substitution_output_result.first; + OutputExprToResultSubPCGMapping output_expr_to_result_sub_pcg_mapping = + substitution_output_result.second; + + SubParallelComputationGraphData output_graph_data = + get_sub_pcg_data(substitution_output_graph); + SubParallelComputationGraphData pre_data = get_sub_pcg_data(spcg); + + std::unordered_set pre_nodes = + keys(pre_data.node_data); + std::unordered_set matched_nodes = + unordered_set_of(values(match.node_assignment)); + std::unordered_set post_nodes_from_original_graph = + set_minus(pre_nodes, matched_nodes); + + std::unordered_map post_node_data = + [&] { + std::unordered_map + post_node_data_from_orig = restrict_keys( + pre_data.node_data, post_nodes_from_original_graph); + std::unordered_map + post_node_data_from_sub = output_graph_data.node_data; + + return merge_maps(post_node_data_from_orig, post_node_data_from_sub); + }(); + + std::unordered_set post_edges = [&] { + std::unordered_set post_edges_from_orig = + filter(pre_data.edges, [&](SubParallelComputationGraphEdge const &e) { + if (e.raw_edge.has()) { + return true; + } else { + DataflowEdge dfe = e.raw_edge.get(); + parallel_layer_guid_t src = parallel_layer_guid_t{dfe.src.node}; + parallel_layer_guid_t dst = parallel_layer_guid_t{dfe.dst.node}; + return !(contains(matched_nodes, src) || + contains(matched_nodes, dst)); + } + }); + + std::unordered_set post_edges_from_sub = + filter(output_graph_data.edges, + [&](SubParallelComputationGraphEdge const &e) { + return !e.raw_edge.has(); + }); + + bidict + output_orig_pattern_mapping = get_output_mapping_for_pcg_pattern_match( + match, sub.pcg_pattern, spcg); + bidict + output_post_outexpr_mapping = get_output_graph_expr_output_mapping( + output_expr_to_result_sub_pcg_mapping, + sub.output_graph_expr, + substitution_output_graph); + + std::unordered_set incoming_to_sub_edges; + for (auto const &[pattern_input, base_graph_tensor] : + match.input_assignment) { + OutputGraphExprInput output_expr_input = + sub.inputs_mapping.at_l(pattern_input); + input_parallel_tensor_guid_t output_graph_input = + output_expr_to_result_sub_pcg_mapping.input_mapping.at_r( + output_expr_input); + std::unordered_set uses = get_parallel_tensor_uses( + substitution_output_graph, + open_parallel_tensor_guid_from_input(output_graph_input)); + for (parallel_tensor_use_t const &use : uses) { + SubParallelComputationGraphEdge new_edge = + subpcg_edge_from_tensor_and_use(base_graph_tensor, use); + incoming_to_sub_edges.insert(new_edge); + } + } + + std::unordered_set outgoing_from_sub_edges; + for (ParallelComputationGraphEdge const &outgoing_edge : + get_subgraph_outgoing_edges(spcg, matched_nodes)) { + parallel_tensor_guid_t original_tensor = + get_parallel_tensor(outgoing_edge); + PatternNodeOutput pattern_tensor = + output_orig_pattern_mapping.at_r(original_tensor); + OutputGraphExprNodeOutput output_graph_tensor = + sub.outputs_mapping.at_l(pattern_tensor); + parallel_tensor_guid_t new_tensor = + output_post_outexpr_mapping.at_r(output_graph_tensor); + + SubParallelComputationGraphEdge new_edge = + subpcg_edge_from_tensor_and_dst( + new_tensor, + get_dst_layer(outgoing_edge), + get_dst_layer_input_idx(outgoing_edge)); + outgoing_from_sub_edges.insert(new_edge); + } + + return set_union(std::vector{ + post_edges_from_orig, + post_edges_from_sub, + incoming_to_sub_edges, + outgoing_from_sub_edges, + }); + }(); + + std::unordered_set post_inputs = + pre_data.inputs; + + std::unordered_map + post_value_data = [&] { + std::unordered_map + post_value_data_from_orig = filter_keys( + pre_data.value_data, [&](open_parallel_tensor_guid_t const &t) { + return visit_open_parallel_tensor_guid( + t, + overload{ + [&](parallel_tensor_guid_t const &t) { + return contains(post_nodes_from_original_graph, + get_source_layer(t)); + }, + [](input_parallel_tensor_guid_t const &) { + return true; + }, + }); + }); + + std::unordered_map + post_value_data_from_sub = output_graph_data.value_data; + return merge_maps(post_value_data_from_orig, post_value_data_from_sub); + }(); + + SubParallelComputationGraphData post_data = SubParallelComputationGraphData{ + post_node_data, + post_edges, + post_inputs, + post_value_data, + }; + + return sub_pcg_from_graph_data(post_data); } } // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/substitution_internal/evaluate_substitution_output.cc b/lib/substitutions/src/substitutions/substitution_internal/evaluate_substitution_output.cc new file mode 100644 index 0000000000..186e2fc03a --- /dev/null +++ b/lib/substitutions/src/substitutions/substitution_internal/evaluate_substitution_output.cc @@ -0,0 +1,94 @@ +#include "substitutions/substitution_internal/evaluate_substitution_output.h" +#include "substitutions/output_graph/output_operator_attrs_assignment.h" +#include "substitutions/sub_parallel_computation_graph.h" +#include "substitutions/substitution_internal/perform_shape_inference.h" +#include "utils/containers/map_keys.h" +#include "utils/containers/map_values.h" +#include "utils/graph/labelled_open_dataflow_graph/algorithms/permute_input_ids.h" +#include "utils/graph/labelled_open_dataflow_graph/algorithms/permute_node_ids.h" +#include "utils/graph/labelled_open_dataflow_graph/algorithms/rewrite_node_labels.h" +#include "utils/graph/labelled_open_dataflow_graph/algorithms/rewrite_value_labels.h" +#include "utils/graph/node/algorithms/generate_new_node_id_permutation.h" +#include "utils/graph/node/algorithms/new_node.dtg.h" +#include "utils/graph/open_dataflow_graph/algorithms/generate_new_input_id_permutation.h" +#include "utils/graph/open_dataflow_graph/algorithms/new_dataflow_graph_input.dtg.h" + +namespace FlexFlow { + +std::pair + evaluate_substitution_output(SubParallelComputationGraph const &spcg, + Substitution const &sub, + PCGPatternMatch const &match) { + std::unordered_map node_match = + map_values(match.node_assignment.as_unordered_map(), + [&](parallel_layer_guid_t const &n) { + return get_operator_attrs(spcg, n); + }); + + bidict new_node_id_permutation = + generate_new_node_id_permutation(sub.output_graph_expr.raw_graph); + bidict new_input_id_permutation = + generate_new_input_id_permutation(sub.output_graph_expr.raw_graph); + LabelledOpenDataflowGraphView + permuted = + permute_input_ids(permute_node_ids(sub.output_graph_expr.raw_graph, + new_node_id_permutation), + new_input_id_permutation); + + LabelledOpenDataflowGraphView + without_shapes = rewrite_node_labels( + permuted, + [&](Node const &n, OutputOperatorAttrsAssignment const &attrs) { + return ParallelLayerAttrs{ + materialize_output_operator_from_attrs_assignment(attrs, + node_match), + std::nullopt, + }; + }); + + bidict result_input_map = + map_keys(map_values(new_input_id_permutation, + [](DataflowGraphInput const &i) { + return OutputGraphExprInput{i}; + }), + [](NewDataflowGraphInput const &i) { + return input_parallel_tensor_guid_t{i.raw_input}; + }); + + bidict result_node_map = map_keys( + map_values(new_node_id_permutation, + [](Node const &n) { return OutputGraphExprNode{n}; }), + [](NewNode const &n) { return parallel_layer_guid_t{n.raw_node}; }); + + std::unordered_map input_shapes = + map_values(map_keys(match.input_assignment, + [&](PatternInput const &i) { + return result_input_map + .at_r(sub.inputs_mapping.at_l(i)) + .raw_dataflow_graph_input; + }), + [&](open_parallel_tensor_guid_t const &v) { + return spcg.raw_graph.at(v.raw_open_dataflow_value).shape; + }); + LabelledOpenDataflowGraphView + with_shapes = perform_shape_inference(without_shapes, input_shapes); + LabelledOpenDataflowGraphView + with_attrs = rewrite_value_labels( + with_shapes, + [](OpenDataflowValue const &, ParallelTensorShape const &s) { + return ParallelTensorAttrs{ + s, + std::nullopt, + std::nullopt, + CreateGrad::YES, + }; + }); + + return std::make_pair(SubParallelComputationGraph{with_attrs}, + OutputExprToResultSubPCGMapping{ + result_node_map, + result_input_map, + }); +} + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/substitution_internal/output_expr_to_result_sub_pcg_mapping.cc b/lib/substitutions/src/substitutions/substitution_internal/output_expr_to_result_sub_pcg_mapping.cc new file mode 100644 index 0000000000..083334f0db --- /dev/null +++ b/lib/substitutions/src/substitutions/substitution_internal/output_expr_to_result_sub_pcg_mapping.cc @@ -0,0 +1,32 @@ +#include "substitutions/substitution_internal/output_expr_to_result_sub_pcg_mapping.h" +#include "substitutions/output_graph/output_graph_expr.h" +#include "substitutions/sub_parallel_computation_graph.h" +#include "utils/bidict/algorithms/bidict_from_keys_and_values.h" +#include "utils/bidict/algorithms/merge_bidicts.h" + +namespace FlexFlow { + +bidict + get_output_graph_expr_output_mapping( + OutputExprToResultSubPCGMapping const &m, + OutputGraphExpr const &output_graph_expr, + SubParallelComputationGraph const &spcg) { + bidict result; + + for (auto const &[parallel_layer, output_graph_expr_node] : m.node_mapping) { + std::vector layer_outputs = + get_layer_outputs(spcg, parallel_layer); + std::vector output_graph_expr_outputs = + get_node_outputs(output_graph_expr, output_graph_expr_node); + + bidict + mapping_for_layer = bidict_from_keys_and_values( + layer_outputs, output_graph_expr_outputs); + + result = merge_bidicts(result, mapping_for_layer); + } + + return result; +} + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/substitution_internal/perform_shape_inference.cc b/lib/substitutions/src/substitutions/substitution_internal/perform_shape_inference.cc new file mode 100644 index 0000000000..0bde326bd1 --- /dev/null +++ b/lib/substitutions/src/substitutions/substitution_internal/perform_shape_inference.cc @@ -0,0 +1,45 @@ +#include "substitutions/substitution_internal/perform_shape_inference.h" +#include "utils/containers/map_keys.h" +#include "utils/containers/transform.h" +#include "utils/containers/zip.h" +#include "utils/graph/dataflow_graph/algorithms.h" +#include "utils/graph/digraph/algorithms/get_topological_ordering.h" +#include "utils/graph/labelled_open_dataflow_graph/algorithms/rewrite_value_labels.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_inputs.h" + +namespace FlexFlow { + +LabelledOpenDataflowGraphView + perform_shape_inference( + LabelledOpenDataflowGraphView const + &g, + std::unordered_map const + &input_shapes) { + + std::unordered_map inferred = + map_keys(input_shapes, [](DataflowGraphInput const &i) { + return OpenDataflowValue{i}; + }); + + for (Node const &n : get_topological_ordering(g)) { + std::vector input_shapes = + transform(get_inputs(g, n), + [&](OpenDataflowValue const &v) { return inferred.at(v); }); + + std::vector output_shapes = + get_output_shapes(g.at(n).op_attrs, input_shapes); + + std::vector outputs = get_outputs(g, n); + + for (auto const &[output, shape] : zip(outputs, output_shapes)) { + inferred.insert({OpenDataflowValue{output}, shape}); + } + } + + return rewrite_value_labels( + g, [&](OpenDataflowValue const &v, std::monostate const &) { + return inferred.at(v); + }); +} + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/tensor_pattern/get_attribute.cc b/lib/substitutions/src/substitutions/tensor_pattern/get_attribute.cc index 8a71d92e0e..05f21247c7 100644 --- a/lib/substitutions/src/substitutions/tensor_pattern/get_attribute.cc +++ b/lib/substitutions/src/substitutions/tensor_pattern/get_attribute.cc @@ -1,4 +1,5 @@ #include "substitutions/tensor_pattern/get_attribute.h" +#include "op-attrs/parallel_tensor_dims.h" #include "utils/containers/as_vector.h" #include "utils/containers/transform.h" #include "utils/integer_conversions.h" diff --git a/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_pattern.cc b/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_pattern.cc new file mode 100644 index 0000000000..794ab5abda --- /dev/null +++ b/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_pattern.cc @@ -0,0 +1,9 @@ +#include "substitutions/tensor_pattern/tensor_attribute_pattern.h" + +namespace FlexFlow { + +TensorAttributePattern tensor_attribute_pattern_match_all() { + return TensorAttributePattern{{}}; +} + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/unlabelled/find_pattern_matches.cc b/lib/substitutions/src/substitutions/unlabelled/find_pattern_matches.cc index fb01733bae..a7ebc0bff7 100644 --- a/lib/substitutions/src/substitutions/unlabelled/find_pattern_matches.cc +++ b/lib/substitutions/src/substitutions/unlabelled/find_pattern_matches.cc @@ -10,7 +10,7 @@ #include "utils/containers/zip.h" #include "utils/graph/dataflow_graph/algorithms.h" #include "utils/graph/node/algorithms.h" -#include "utils/graph/open_dataflow_graph/algorithms.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_inputs.h" namespace FlexFlow { @@ -37,7 +37,8 @@ static std::optional std::vector pattern_node_inputs = get_inputs_to_pattern_node(pattern, pattern_node); - std::unordered_set pattern_graph_inputs = get_inputs(pattern); + std::unordered_set pattern_graph_inputs = + get_graph_inputs(pattern); assert(unordered_set_of(pattern_node_inputs) == transform(pattern_graph_inputs, diff --git a/lib/substitutions/src/substitutions/unlabelled/pattern_matching.cc b/lib/substitutions/src/substitutions/unlabelled/pattern_matching.cc index 31c4a23e7e..304bb8cf46 100644 --- a/lib/substitutions/src/substitutions/unlabelled/pattern_matching.cc +++ b/lib/substitutions/src/substitutions/unlabelled/pattern_matching.cc @@ -5,11 +5,14 @@ #include "substitutions/unlabelled/pattern_split.h" #include "substitutions/unlabelled/standard_pattern_edge.h" #include "substitutions/unlabelled/unlabelled_graph_pattern.h" +#include "utils/bidict/algorithms/left_entries.h" +#include "utils/bidict/algorithms/right_entries.h" #include "utils/containers/keys.h" #include "utils/containers/transform.h" #include "utils/graph/dataflow_graph/algorithms.h" #include "utils/graph/node/algorithms.h" -#include "utils/graph/open_dataflow_graph/algorithms.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_edges.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_values.h" #include "utils/graph/open_dataflow_graph/algorithms/get_subgraph.h" #include "utils/graph/open_dataflow_graph/algorithms/get_subgraph_inputs.h" #include "utils/graph/open_dataflow_graph/open_dataflow_edge.dtg.h" @@ -22,8 +25,7 @@ namespace FlexFlow { OpenDataflowSubgraphResult subgraph_matched(OpenDataflowGraphView const &g, UnlabelledDataflowGraphPatternMatch const &match) { - std::unordered_set matched_nodes = - keys(match.node_assignment.reversed()); + std::unordered_set matched_nodes = right_entries(match.node_assignment); return get_subgraph(g, matched_nodes); } @@ -149,8 +151,8 @@ bool unlabelled_pattern_does_match( OpenDataflowSubgraphResult subgraph_result = subgraph_matched(graph, match); OpenDataflowGraphView matched_subgraph = subgraph_result.graph; - assert(keys(match.node_assignment) == get_nodes(pattern)); - assert(keys(match.node_assignment.reversed()) == get_nodes(matched_subgraph)); + assert(left_entries(match.node_assignment) == get_nodes(pattern)); + assert(right_entries(match.node_assignment) == get_nodes(matched_subgraph)); MatchAdditionalCriterion through_subgraph_operation = MatchAdditionalCriterion{ diff --git a/lib/substitutions/src/substitutions/unlabelled/unlabelled_graph_pattern.cc b/lib/substitutions/src/substitutions/unlabelled/unlabelled_graph_pattern.cc index db49e01611..84e0d91fee 100644 --- a/lib/substitutions/src/substitutions/unlabelled/unlabelled_graph_pattern.cc +++ b/lib/substitutions/src/substitutions/unlabelled/unlabelled_graph_pattern.cc @@ -5,7 +5,10 @@ #include "utils/graph/dataflow_graph/algorithms.h" #include "utils/graph/digraph/algorithms/get_topological_ordering.h" #include "utils/graph/node/algorithms.h" -#include "utils/graph/open_dataflow_graph/algorithms.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_edges.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_inputs.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_graph_inputs.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_values.h" #include "utils/graph/open_dataflow_graph/algorithms/get_subgraph.h" #include "utils/graph/open_dataflow_graph/algorithms/get_subgraph_inputs.h" @@ -29,8 +32,9 @@ std::unordered_set get_values(UnlabelledGraphPattern const &p) { pattern_value_from_raw_open_dataflow_value); } -std::unordered_set get_inputs(UnlabelledGraphPattern const &p) { - return transform(get_inputs(p.raw_graph), +std::unordered_set + get_graph_inputs(UnlabelledGraphPattern const &p) { + return transform(get_open_dataflow_graph_inputs(p.raw_graph), [](DataflowGraphInput const &i) { return PatternInput{i}; }); } diff --git a/lib/substitutions/test/src/substitutions/operator_pattern/get_attribute.cc b/lib/substitutions/test/src/substitutions/operator_pattern/get_attribute.cc new file mode 100644 index 0000000000..70e960bc73 --- /dev/null +++ b/lib/substitutions/test/src/substitutions/operator_pattern/get_attribute.cc @@ -0,0 +1,34 @@ +#include "substitutions/operator_pattern/get_attribute.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_attribute(LinearAttrs, OperatorAttributeKey)") { + int out_channels = 16; + bool use_bias = true; + std::optional activation = Activation::GELU; + std::optional regularizer = RegularizerAttrs{ + L1RegularizerAttrs{ + 0.5, + }, + }; + + LinearAttrs attrs = LinearAttrs{ + out_channels, + use_bias, + DataType::FLOAT, + activation, + regularizer, + }; + + SUBCASE("USE_BIAS") { + std::optional result = + get_attribute(attrs, OperatorAttributeKey::USE_BIAS); + std::optional correct = + OperatorAttributeValue{use_bias}; + CHECK(result == correct); + CHECK(result.value().has()); + } + } +} diff --git a/lib/substitutions/test/src/substitutions/pcg_pattern.cc b/lib/substitutions/test/src/substitutions/pcg_pattern.cc index 8631d574f8..6922798a97 100644 --- a/lib/substitutions/test/src/substitutions/pcg_pattern.cc +++ b/lib/substitutions/test/src/substitutions/pcg_pattern.cc @@ -1,10 +1,12 @@ -#include "utils/containers/get_only.h" -#define DOCTEST_CONFIG_NO_EXCEPTIONS_BUT_WITH_ALL_ASSERTS +#include "substitutions/pcg_pattern.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.h" #include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" -#include "substitutions/pcg_pattern.h" +#include "substitutions/open_parallel_tensor_guid_t.h" +#include "substitutions/operator_pattern/operator_attribute_constraint.h" #include "substitutions/sub_parallel_computation_graph.h" +#include "substitutions/tensor_pattern/tensor_attribute_pattern.h" #include "test/utils/doctest.h" +#include "utils/containers/get_only.h" #include "utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h" using namespace ::FlexFlow; @@ -79,18 +81,20 @@ TEST_SUITE(FF_TEST_SUITE) { OperatorAttributePattern, TensorAttributePattern>>(); - TensorAttributePattern pattern_tensor_a = TensorAttributePattern{{}}; - TensorAttributePattern pattern_tensor_b = TensorAttributePattern{{}}; - TensorAttributePattern pattern_tensor_c = TensorAttributePattern{{}}; - TensorAttributePattern pattern_tensor_x = TensorAttributePattern{{}}; - TensorAttributePattern pattern_tensor_y = TensorAttributePattern{{}}; - - OperatorAttributePattern op_pattern_1 = - OperatorAttributePattern{{OperatorAttributeConstraint{ - ConstraintType::EQUAL, - OperatorAttributeExpr{OperatorAttributeKey::OP_TYPE}, - OperatorAttributeValue{OperatorType::LINEAR}, - }}}; + TensorAttributePattern pattern_tensor_a = + tensor_attribute_pattern_match_all(); + TensorAttributePattern pattern_tensor_b = + tensor_attribute_pattern_match_all(); + TensorAttributePattern pattern_tensor_c = + tensor_attribute_pattern_match_all(); + TensorAttributePattern pattern_tensor_x = + tensor_attribute_pattern_match_all(); + TensorAttributePattern pattern_tensor_y = + tensor_attribute_pattern_match_all(); + + OperatorAttributePattern op_pattern_1 = OperatorAttributePattern{{ + op_type_equals_constraint(OperatorType::LINEAR), + }}; OperatorAttributePattern op_pattern_2 = op_pattern_1; @@ -116,42 +120,38 @@ TEST_SUITE(FF_TEST_SUITE) { PCGPattern pattern = PCGPattern{g}; - std::unordered_set result = - unordered_set_of( - find_pattern_matches(pattern, sub_pcg_from_full_pcg(pcg))); - - UnlabelledDataflowGraphPatternMatch match1 = - UnlabelledDataflowGraphPatternMatch{ - bidict{ - {op_pattern_1_node, x_matmul.raw_graph_node}, - {op_pattern_2_node, y_matmul.raw_graph_node}, - }, - bidict{ - {PatternInput{pt_a}, - OpenDataflowValue{a_tensor.raw_graph_output}}, - {PatternInput{pt_b}, - OpenDataflowValue{x_weights.raw_graph_output}}, - {PatternInput{pt_c}, - OpenDataflowValue{y_weights.raw_graph_output}}, - }}; - - UnlabelledDataflowGraphPatternMatch match2 = - UnlabelledDataflowGraphPatternMatch{ - bidict{ - {op_pattern_1_node, y_matmul.raw_graph_node}, - {op_pattern_2_node, x_matmul.raw_graph_node}, - }, - bidict{ - {PatternInput{pt_a}, - OpenDataflowValue{a_tensor.raw_graph_output}}, - {PatternInput{pt_b}, - OpenDataflowValue{y_weights.raw_graph_output}}, - {PatternInput{pt_c}, - OpenDataflowValue{x_weights.raw_graph_output}}, - }}; - - std::unordered_set correct = {match1, - match2}; + std::unordered_set result = unordered_set_of( + find_pattern_matches(pattern, sub_pcg_from_full_pcg(pcg))); + + PCGPatternMatch match1 = + PCGPatternMatch{bidict{ + {op_pattern_1_node, x_matmul}, + {op_pattern_2_node, y_matmul}, + }, + bidict{ + {PatternInput{pt_a}, + open_parallel_tensor_guid_from_closed(a_tensor)}, + {PatternInput{pt_b}, + open_parallel_tensor_guid_from_closed(x_weights)}, + {PatternInput{pt_c}, + open_parallel_tensor_guid_from_closed(y_weights)}, + }}; + + PCGPatternMatch match2 = + PCGPatternMatch{bidict{ + {op_pattern_1_node, y_matmul}, + {op_pattern_2_node, x_matmul}, + }, + bidict{ + {PatternInput{pt_a}, + open_parallel_tensor_guid_from_closed(a_tensor)}, + {PatternInput{pt_b}, + open_parallel_tensor_guid_from_closed(y_weights)}, + {PatternInput{pt_c}, + open_parallel_tensor_guid_from_closed(x_weights)}, + }}; + + std::unordered_set correct = {match1, match2}; CHECK(result == correct); } diff --git a/lib/substitutions/test/src/substitutions/substitution.cc b/lib/substitutions/test/src/substitutions/substitution.cc new file mode 100644 index 0000000000..87ffc01f0b --- /dev/null +++ b/lib/substitutions/test/src/substitutions/substitution.cc @@ -0,0 +1,229 @@ +#include "substitutions/substitution.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" +#include "substitutions/open_parallel_tensor_guid_t.h" +#include "substitutions/operator_pattern/operator_attribute_constraint.h" +#include "substitutions/output_graph/output_graph_expr_node.dtg.h" +#include "substitutions/output_graph/output_operator_attrs_assignment.h" +#include "substitutions/pcg_pattern_builder.h" +#include "substitutions/sub_parallel_computation_graph.h" +#include "substitutions/tensor_pattern/tensor_attribute_pattern.h" +#include "utils/containers/get_only.h" +#include "utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h" +#include "utils/graph/labelled_open_dataflow_graph/algorithms/get_graph_data.h" +#include "utils/integer_conversions.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + // TEST_CASE("is_valid_substitution") { + // FAIL("TODO"); + // } + + TEST_CASE("evaluate_substitution_output(SubParallelComputationGraph, " + "Substituion, PCGPatternMatch)") { + // Currently Substitution creation is very verbose. + // This is being addressed in + // https://github.com/flexflow/FlexFlow/issues/1473. + auto pattern_g = LabelledOpenDataflowGraph:: + create>(); + + PatternInput pattern_i_activation = + PatternInput{pattern_g.add_input(tensor_attribute_pattern_match_all())}; + PatternInput pattern_i_weights = + PatternInput{pattern_g.add_input(tensor_attribute_pattern_match_all())}; + + OperatorAttributePattern mm_pattern = OperatorAttributePattern{{ + op_type_equals_constraint(OperatorType::LINEAR), + op_attr_key_equals( + OperatorAttributeKey::ACTIVATION, + OperatorAttributeValue{std::optional{std::nullopt}}), + }}; + NodeAddedResult mm_added = pattern_g.add_node( + mm_pattern, + {OpenDataflowValue{pattern_i_activation.raw_dataflow_graph_input}, + OpenDataflowValue{pattern_i_weights.raw_dataflow_graph_input}}, + {tensor_attribute_pattern_match_all()}); + PatternNode pattern_mm_node = PatternNode{mm_added.node}; + DataflowOutput mm_output = get_only(mm_added.outputs); + + OperatorAttributePattern relu_pattern = OperatorAttributePattern{{ + op_type_equals_constraint(OperatorType::RELU), + }}; + NodeAddedResult relu_added = + pattern_g.add_node(relu_pattern, + {OpenDataflowValue{mm_output}}, + {tensor_attribute_pattern_match_all()}); + PatternNode pattern_relu_node = PatternNode{relu_added.node}; + DataflowOutput relu_output = get_only(relu_added.outputs); + + LabelledOpenDataflowGraph + output_g = LabelledOpenDataflowGraph:: + create>(); + + OutputGraphExprInput output_i_activation = + OutputGraphExprInput{output_g.add_input({})}; + OutputGraphExprInput output_i_weights = + OutputGraphExprInput{output_g.add_input({})}; + + OutputOperatorAttrsAssignment fused_mm_relu_attrs_assignment = + OutputOperatorAttrsAssignment{{ + set_attr_to_constant(OperatorAttributeKey::OP_TYPE, + OperatorAttributeValue{OperatorType::LINEAR}), + copy_attr_from_pattern_node(OperatorAttributeKey::OUT_CHANNELS, + pattern_mm_node), + copy_attr_from_pattern_node(OperatorAttributeKey::USE_BIAS, + pattern_mm_node), + copy_attr_from_pattern_node(OperatorAttributeKey::DATA_TYPE, + pattern_mm_node), + set_attr_to_constant(OperatorAttributeKey::ACTIVATION, + OperatorAttributeValue{Activation::RELU}), + copy_attr_from_pattern_node(OperatorAttributeKey::REGULARIZER, + pattern_mm_node), + }}; + NodeAddedResult fused_mm_relu_added = output_g.add_node( + fused_mm_relu_attrs_assignment, + {OpenDataflowValue{output_i_activation.raw_dataflow_graph_input}, + OpenDataflowValue{output_i_weights.raw_dataflow_graph_input}}, + {{}}); + OutputGraphExprNode fused_mm_relu_node = + OutputGraphExprNode{fused_mm_relu_added.node}; + DataflowOutput fused_mm_relu_output = get_only(fused_mm_relu_added.outputs); + + Substitution sub = Substitution{ + PCGPattern{pattern_g}, + OutputGraphExpr{output_g}, + bidict{ + { + pattern_i_activation, + output_i_activation, + }, + { + pattern_i_weights, + output_i_weights, + }, + }, + bidict{ + { + PatternNodeOutput{relu_output}, + OutputGraphExprNodeOutput{fused_mm_relu_output}, + }, + }, + }; + + int in_channels = 24; + int batch_size = 4; + int batch_degree = 2; + std::string mm_match = "mm_match"; + std::string relu_match = "relu_match"; + + SubParallelComputationGraph pcg = [&] { + ParallelComputationGraphBuilder b; + parallel_tensor_guid_t t = b.create_input_tensor(ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{size_t_from_int(batch_size), batch_degree}, + ShardParallelDim{size_t_from_int(in_channels), 1}, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT, + }); + t = b.dense(t, + /*outDim=*/16, + /*activation=*/std::nullopt); + t = b.gelu(t); + t = b.dense(t, + /*outDim=*/12, + /*activation=*/std::nullopt, + /*use_bias=*/false, + /*data_type=*/DataType::FLOAT, + /*kernel_initializer=*/std::nullopt, + /*bias_initializer=*/std::nullopt, + /*name=*/mm_match); + t = b.relu(t, + /*name=*/relu_match); + t = b.dense(t, + /*outDim=*/8, + /*activation=*/Activation::RELU); + + return sub_pcg_from_full_pcg(b.pcg); + }(); + + PCGPatternMatch match = [&] { + parallel_layer_guid_t mm_match_layer = + get_parallel_layer_by_name(pcg, mm_match); + parallel_layer_guid_t relu_match_layer = + get_parallel_layer_by_name(pcg, relu_match); + open_parallel_tensor_guid_t mm_match_layer_input_activations = + get_layer_inputs(pcg, mm_match_layer).at(0); + open_parallel_tensor_guid_t mm_match_layer_input_weights = + get_layer_inputs(pcg, mm_match_layer).at(1); + + return PCGPatternMatch{ + bidict{ + {pattern_mm_node, mm_match_layer}, + {pattern_relu_node, relu_match_layer}, + }, + std::unordered_map{ + { + PatternInput{pattern_i_activation}, + mm_match_layer_input_activations, + }, + { + PatternInput{pattern_i_weights}, + mm_match_layer_input_weights, + }}, + }; + }(); + + SubParallelComputationGraph result = apply_substitution(pcg, sub, match); + + SubParallelComputationGraph correct = [&] { + ParallelComputationGraphBuilder b; + parallel_tensor_guid_t t = b.create_input_tensor(ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{size_t_from_int(batch_size), batch_degree}, + ShardParallelDim{size_t_from_int(in_channels), 1}, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT, + }); + t = b.dense(t, + /*outDim=*/16, + /*activation=*/std::nullopt); + t = b.gelu(t); + t = b.dense(t, + /*outDim=*/12, + /*activation=*/Activation::RELU, + /*use_bias=*/false, + /*data_type=*/DataType::FLOAT, + /*kernel_initializer=*/std::nullopt, + /*bias_initializer=*/std::nullopt, + /*name=*/std::nullopt); + t = b.dense(t, + /*outDim=*/8, + /*activation=*/Activation::RELU); + + return sub_pcg_from_full_pcg(b.pcg); + }(); + + // since the new nodes produced by the substitution have new ids, it's + // easier/more correct to check that the graphs are isomorphic rather than + // checking their exact graph data + CHECK(sub_pcgs_are_isomorphic(result, correct)); + } +} diff --git a/lib/substitutions/test/src/substitutions/substitution_internal/evaluate_substitution_output.cc b/lib/substitutions/test/src/substitutions/substitution_internal/evaluate_substitution_output.cc new file mode 100644 index 0000000000..52b54b32fb --- /dev/null +++ b/lib/substitutions/test/src/substitutions/substitution_internal/evaluate_substitution_output.cc @@ -0,0 +1,274 @@ +#include "substitutions/substitution_internal/evaluate_substitution_output.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" +#include "substitutions/open_parallel_tensor_guid_t.h" +#include "substitutions/operator_pattern/operator_attribute_constraint.h" +#include "substitutions/output_graph/output_operator_attrs_assignment.h" +#include "substitutions/sub_parallel_computation_graph.h" +#include "substitutions/tensor_pattern/tensor_attribute_pattern.h" +#include "utils/containers/get_only.h" +#include "utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h" +#include "utils/integer_conversions.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("evaluate_substitution_output") { + // Currently Substitution creation is very verbose. + // This is being addressed in + // https://github.com/flexflow/FlexFlow/issues/1473. + auto pattern_g = LabelledOpenDataflowGraph:: + create>(); + + PatternInput pattern_i_activation = + PatternInput{pattern_g.add_input(tensor_attribute_pattern_match_all())}; + PatternInput pattern_i_weights = + PatternInput{pattern_g.add_input(tensor_attribute_pattern_match_all())}; + + OperatorAttributePattern mm_pattern = OperatorAttributePattern{{ + op_type_equals_constraint(OperatorType::LINEAR), + op_attr_key_equals( + OperatorAttributeKey::ACTIVATION, + OperatorAttributeValue{std::optional{std::nullopt}}), + }}; + NodeAddedResult mm_added = pattern_g.add_node( + mm_pattern, + {OpenDataflowValue{pattern_i_activation.raw_dataflow_graph_input}, + OpenDataflowValue{pattern_i_weights.raw_dataflow_graph_input}}, + {tensor_attribute_pattern_match_all()}); + PatternNode pattern_mm_node = PatternNode{mm_added.node}; + DataflowOutput mm_output = get_only(mm_added.outputs); + + OperatorAttributePattern relu_pattern = OperatorAttributePattern{{ + op_type_equals_constraint(OperatorType::RELU), + }}; + NodeAddedResult relu_added = + pattern_g.add_node(relu_pattern, + {OpenDataflowValue{mm_output}}, + {tensor_attribute_pattern_match_all()}); + PatternNode pattern_relu_node = PatternNode{relu_added.node}; + DataflowOutput relu_output = get_only(relu_added.outputs); + + LabelledOpenDataflowGraph + output_g = LabelledOpenDataflowGraph:: + create>(); + + OutputGraphExprInput output_i_activation = + OutputGraphExprInput{output_g.add_input({})}; + OutputGraphExprInput output_i_weights = + OutputGraphExprInput{output_g.add_input({})}; + + OutputOperatorAttrsAssignment fused_mm_relu_attrs_assignment = + OutputOperatorAttrsAssignment{{ + set_attr_to_constant(OperatorAttributeKey::OP_TYPE, + OperatorAttributeValue{OperatorType::LINEAR}), + copy_attr_from_pattern_node(OperatorAttributeKey::OUT_CHANNELS, + pattern_mm_node), + copy_attr_from_pattern_node(OperatorAttributeKey::USE_BIAS, + pattern_mm_node), + copy_attr_from_pattern_node(OperatorAttributeKey::DATA_TYPE, + pattern_mm_node), + set_attr_to_constant(OperatorAttributeKey::ACTIVATION, + OperatorAttributeValue{Activation::RELU}), + copy_attr_from_pattern_node(OperatorAttributeKey::REGULARIZER, + pattern_mm_node), + }}; + NodeAddedResult fused_mm_relu_added = output_g.add_node( + fused_mm_relu_attrs_assignment, + {OpenDataflowValue{output_i_activation.raw_dataflow_graph_input}, + OpenDataflowValue{output_i_weights.raw_dataflow_graph_input}}, + {{}}); + OutputGraphExprNode fused_mm_relu_node = + OutputGraphExprNode{fused_mm_relu_added.node}; + DataflowOutput fused_mm_relu_output = get_only(fused_mm_relu_added.outputs); + + Substitution sub = Substitution{ + PCGPattern{pattern_g}, + OutputGraphExpr{output_g}, + bidict{ + { + pattern_i_activation, + output_i_activation, + }, + { + pattern_i_weights, + output_i_weights, + }, + }, + bidict{ + { + PatternNodeOutput{relu_output}, + OutputGraphExprNodeOutput{fused_mm_relu_output}, + }, + }, + }; + + int in_channels = 24; + int batch_size = 4; + int batch_degree = 2; + std::string mm_match = "mm_match"; + std::string relu_match = "relu_match"; + + SubParallelComputationGraph pcg = [&] { + ParallelComputationGraphBuilder b; + parallel_tensor_guid_t t = b.create_input_tensor(ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{size_t_from_int(batch_size), batch_degree}, + ShardParallelDim{size_t_from_int(in_channels), 1}, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT, + }); + t = b.dense(t, + /*outDim=*/16, + /*activation=*/std::nullopt); + t = b.gelu(t); + t = b.dense(t, + /*outDim=*/12, + /*activation=*/std::nullopt, + /*use_bias=*/false, + /*data_type=*/DataType::FLOAT, + /*kernel_initializer=*/std::nullopt, + /*bias_initializer=*/std::nullopt, + /*name=*/mm_match); + t = b.relu(t, + /*name=*/relu_match); + t = b.dense(t, + /*outDim=*/8, + /*activation=*/Activation::RELU); + + return sub_pcg_from_full_pcg(b.pcg); + }(); + + parallel_layer_guid_t mm_match_layer = + get_parallel_layer_by_name(pcg, mm_match); + parallel_layer_guid_t relu_match_layer = + get_parallel_layer_by_name(pcg, relu_match); + open_parallel_tensor_guid_t mm_match_layer_input_activations = + get_layer_inputs(pcg, mm_match_layer).at(0); + open_parallel_tensor_guid_t mm_match_layer_input_weights = + get_layer_inputs(pcg, mm_match_layer).at(1); + + PCGPatternMatch match = PCGPatternMatch{ + bidict{ + {pattern_mm_node, mm_match_layer}, + {pattern_relu_node, relu_match_layer}, + }, + std::unordered_map{ + { + PatternInput{pattern_i_activation}, + mm_match_layer_input_activations, + }, + { + PatternInput{pattern_i_weights}, + mm_match_layer_input_weights, + }}, + }; + + SUBCASE("evaluate_substitution_output") { + std::pair + result = evaluate_substitution_output(pcg, sub, match); + + SubParallelComputationGraph result_graph = result.first; + bidict result_node_map = + result.second.node_mapping; + bidict + result_input_map = result.second.input_mapping; + + LinearAttrs correct_result_fused_mm_relu_attrs = LinearAttrs{ + 12, + /*use_bias=*/false, + DataType::FLOAT, + Activation::RELU, + /*regularizer=*/std::nullopt, + }; + + ParallelTensorAttrs correct_result_i_activation_attrs = + get_parallel_tensor_attrs(pcg, mm_match_layer_input_activations); + ParallelTensorAttrs correct_result_i_weights_attrs = + get_parallel_tensor_attrs(pcg, mm_match_layer_input_weights); + ParallelTensorAttrs correct_result_fused_mm_relu_output_attrs = + get_parallel_tensor_attrs( + pcg, + open_parallel_tensor_guid_from_closed( + get_only(get_layer_outputs(pcg, relu_match_layer)))); + + parallel_layer_guid_t result_fused_mm_relu_node = + result_node_map.at_r(fused_mm_relu_node); + parallel_tensor_guid_t result_fused_mm_relu_output = + get_only(get_layer_outputs(result_graph, result_fused_mm_relu_node)); + input_parallel_tensor_guid_t result_i_activation = + result_input_map.at_r(output_i_activation); + input_parallel_tensor_guid_t result_i_weights = + result_input_map.at_r(output_i_weights); + + SubParallelComputationGraphData correct_graph_data = + SubParallelComputationGraphData{ + std::unordered_map{{ + result_fused_mm_relu_node, + ParallelLayerAttrs{ + PCGOperatorAttrs{correct_result_fused_mm_relu_attrs}, + /*name=*/std::nullopt, + }, + }}, + std::unordered_set{ + SubParallelComputationGraphEdge{ + OpenDataflowEdge{ + DataflowInputEdge{ + result_i_activation.raw_dataflow_graph_input, + DataflowInput{ + result_fused_mm_relu_node.raw_graph_node, + 0, + }, + }, + }, + }, + SubParallelComputationGraphEdge{ + OpenDataflowEdge{ + DataflowInputEdge{ + result_i_weights.raw_dataflow_graph_input, + DataflowInput{ + result_fused_mm_relu_node.raw_graph_node, + 1, + }, + }, + }, + }, + }, + std::unordered_set{ + result_i_activation, + result_i_weights, + }, + std::unordered_map{ + { + open_parallel_tensor_guid_from_input(result_i_activation), + correct_result_i_activation_attrs, + }, + { + open_parallel_tensor_guid_from_input(result_i_weights), + correct_result_i_weights_attrs, + }, + { + open_parallel_tensor_guid_from_closed( + result_fused_mm_relu_output), + correct_result_fused_mm_relu_output_attrs, + }}}; + + SubParallelComputationGraphData result_graph_data = + get_sub_pcg_data(result_graph); + + CHECK(result_graph_data == correct_graph_data); + } + } +} diff --git a/lib/substitutions/test/src/substitutions/substitution_internal/perform_shape_inference.cc b/lib/substitutions/test/src/substitutions/substitution_internal/perform_shape_inference.cc new file mode 100644 index 0000000000..0bf1c21e7f --- /dev/null +++ b/lib/substitutions/test/src/substitutions/substitution_internal/perform_shape_inference.cc @@ -0,0 +1,173 @@ +#include "substitutions/substitution_internal/perform_shape_inference.h" +#include "utils/containers/get_only.h" +#include "utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h" +#include "utils/graph/labelled_open_dataflow_graph/algorithms/get_graph_data.h" +#include "utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph.h" +#include "utils/integer_conversions.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("perform_shape_inference") { + auto g = + LabelledOpenDataflowGraph::create< + UnorderedSetLabelledOpenDataflowGraph>(); + + int in_channels = 24; + int out_channels = 16; + int batch_size = 4; + int batch_degree = 2; + + DataflowGraphInput i0 = g.add_input({}); + ParallelTensorShape i0_shape = ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{size_t_from_int(batch_size), batch_degree}, + ShardParallelDim{size_t_from_int(in_channels), 1}, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT, + }; + + bool use_bias = false; + LinearAttrs n1_op_attrs = LinearAttrs{ + out_channels, + use_bias, + DataType::FLOAT, + std::nullopt, + std::nullopt, + }; + ParallelLayerAttrs n1_attrs = ParallelLayerAttrs{ + PCGOperatorAttrs{ + n1_op_attrs, + }, + std::nullopt, + }; + + ElementUnaryAttrs n2_op_attrs = ElementUnaryAttrs{ + OperatorType::RELU, + std::nullopt, + }; + ParallelLayerAttrs n2_attrs = ParallelLayerAttrs{ + PCGOperatorAttrs{ + n2_op_attrs, + }, + std::nullopt, + }; + + ParallelTensorShape n1_output_shape = + throw_if_unexpected(get_output_shape(n1_op_attrs, i0_shape)); + ParallelTensorShape n1_weight_shape = + throw_if_unexpected(get_kernel_shape(n1_op_attrs, i0_shape)); + ParallelTensorShape n2_output_shape = + throw_if_unexpected(get_output_shape(n2_op_attrs, n1_output_shape)); + + ParallelLayerAttrs n1_weight_attrs = ParallelLayerAttrs{ + PCGOperatorAttrs{ + WeightAttrs{get_reduced_shape(n1_weight_shape)}, + }, + std::nullopt, + }; + + ParallelLayerAttrs n1_weight_replicate_attrs = ParallelLayerAttrs{ + PCGOperatorAttrs{ + ReplicateAttrs{batch_degree}, + }, + std::nullopt, + }; + + NodeAddedResult n1_weight_added_result = + g.add_node(n1_weight_attrs, {}, {{}}); + Node n1_weight_node = n1_weight_added_result.node; + DataflowOutput n1_weight = get_only(n1_weight_added_result.outputs); + + NodeAddedResult n1_weight_replicate_added_result = g.add_node( + n1_weight_replicate_attrs, {OpenDataflowValue{n1_weight}}, {{}}); + Node n1_weight_replicate_node = n1_weight_replicate_added_result.node; + DataflowOutput n1_weight_replicated = + get_only(n1_weight_replicate_added_result.outputs); + + NodeAddedResult n1_added_result = g.add_node( + n1_attrs, + {OpenDataflowValue{i0}, OpenDataflowValue{n1_weight_replicated}}, + {{}}); + Node n1 = n1_added_result.node; + DataflowOutput o1 = get_only(n1_added_result.outputs); + + NodeAddedResult n2_added_result = + g.add_node(n2_attrs, {OpenDataflowValue{o1}}, {{}}); + Node n2 = n2_added_result.node; + DataflowOutput o2 = get_only(n2_added_result.outputs); + + std::unordered_map input_shapes = { + {i0, i0_shape}, + }; + + LabelledOpenDataflowGraphView + result = perform_shape_inference(g, input_shapes); + + LabelledOpenDataflowGraphData + result_data = get_graph_data(result); + + LabelledOpenDataflowGraphData + correct_data = LabelledOpenDataflowGraphData{ + { + {n1, n1_attrs}, + {n2, n2_attrs}, + {n1_weight_node, n1_weight_attrs}, + {n1_weight_replicate_node, n1_weight_replicate_attrs}, + }, + { + OpenDataflowEdge{ + DataflowInputEdge{ + i0, + DataflowInput{n1, 0}, + }, + }, + OpenDataflowEdge{DataflowEdge{ + DataflowOutput{n1_weight_node, 0}, + DataflowInput{n1_weight_replicate_node, 0}, + }}, + OpenDataflowEdge{ + DataflowEdge{ + DataflowOutput{n1_weight_replicate_node, 0}, + DataflowInput{n1, 1}, + }, + }, + OpenDataflowEdge{DataflowEdge{ + DataflowOutput{n1, 0}, + DataflowInput{n2, 0}, + }}, + }, + {i0}, + {{ + OpenDataflowValue{i0}, + i0_shape, + }, + { + OpenDataflowValue{DataflowOutput{n1_weight_node, 0}}, + lift_to_parallel(get_reduced_shape(n1_weight_shape)), + }, + { + OpenDataflowValue{DataflowOutput{n1_weight_replicate_node, 0}}, + n1_weight_shape, + }, + { + OpenDataflowValue{DataflowOutput{n1, 0}}, + n1_output_shape, + }, + { + OpenDataflowValue{DataflowOutput{n2, 0}}, + n2_output_shape, + }}}; + + CHECK(result_data == correct_data); + } +} diff --git a/lib/substitutions/test/src/substitutions/unlabelled/pattern_split.cc b/lib/substitutions/test/src/substitutions/unlabelled/pattern_split.cc index 341cb23c29..6621145d39 100644 --- a/lib/substitutions/test/src/substitutions/unlabelled/pattern_split.cc +++ b/lib/substitutions/test/src/substitutions/unlabelled/pattern_split.cc @@ -61,7 +61,7 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("full_pattern_values_to_subpattern_2_inputs") { bidict result = split_result.full_pattern_values_to_subpattern_2_inputs; - PatternInput i0 = get_only(get_inputs(split_result.subpattern_2)); + PatternInput i0 = get_only(get_graph_inputs(split_result.subpattern_2)); bidict correct = { {pv0, i0}, }; @@ -117,7 +117,7 @@ TEST_SUITE(FF_TEST_SUITE) { split_result.full_pattern_values_to_subpattern_1_inputs; bidict correct = { {PatternValue{pi0}, - get_only(get_inputs(split_result.subpattern_1))}, + get_only(get_graph_inputs(split_result.subpattern_1))}, }; CHECK(result == correct); } @@ -126,7 +126,7 @@ TEST_SUITE(FF_TEST_SUITE) { split_result.full_pattern_values_to_subpattern_2_inputs; bidict correct = { {PatternValue{pi1}, - get_only(get_inputs(split_result.subpattern_2))}, + get_only(get_graph_inputs(split_result.subpattern_2))}, }; CHECK(result == correct); } diff --git a/lib/substitutions/test/src/test_pattern_matches.cc b/lib/substitutions/test/src/test_pattern_matches.cc index b2f4103c6a..9478195523 100644 --- a/lib/substitutions/test/src/test_pattern_matches.cc +++ b/lib/substitutions/test/src/test_pattern_matches.cc @@ -7,7 +7,8 @@ #include "utils/containers/get_only.h" #include "utils/graph/instances/unordered_set_dataflow_graph.h" #include "utils/graph/node/algorithms.h" -#include "utils/graph/open_dataflow_graph/algorithms.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_incoming_edges.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_values.h" #include "utils/graph/open_dataflow_graph/algorithms/get_subgraph.h" #include "utils/graph/open_dataflow_graph/algorithms/get_subgraph_inputs.h" #include "utils/graph/open_dataflow_graph/open_dataflow_graph.h" diff --git a/lib/utils/include/utils/bidict/algorithms/bidict_from_enumerating.h b/lib/utils/include/utils/bidict/algorithms/bidict_from_enumerating.h new file mode 100644 index 0000000000..86ef6c4b4d --- /dev/null +++ b/lib/utils/include/utils/bidict/algorithms/bidict_from_enumerating.h @@ -0,0 +1,35 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_BIDICT_FROM_ENUMERATING_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_BIDICT_FROM_ENUMERATING_H + +#include "utils/bidict/bidict.h" +#include + +namespace FlexFlow { + +template +bidict bidict_from_enumerating(std::unordered_set const &s) { + bidict result; + int idx = 0; + for (T const &t : s) { + result.equate(idx, t); + idx++; + } + + return result; +} + +template +bidict bidict_from_enumerating(std::set const &s) { + bidict result; + int idx = 0; + for (T const &t : s) { + result.equate(idx, t); + idx++; + } + + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/bidict/algorithms/bidict_from_keys_and_values.h b/lib/utils/include/utils/bidict/algorithms/bidict_from_keys_and_values.h new file mode 100644 index 0000000000..47af03591a --- /dev/null +++ b/lib/utils/include/utils/bidict/algorithms/bidict_from_keys_and_values.h @@ -0,0 +1,28 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_BIDICT_FROM_KEYS_AND_VALUES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_BIDICT_FROM_KEYS_AND_VALUES_H + +#include "utils/bidict/algorithms/bidict_from_pairs.h" +#include "utils/bidict/bidict.h" +#include "utils/containers/zip.h" +#include "utils/exception.h" + +namespace FlexFlow { + +template +bidict bidict_from_keys_and_values(std::vector const &ls, + std::vector const &rs) { + size_t l_size = ls.size(); + size_t r_size = rs.size(); + if (l_size != r_size) { + throw mk_runtime_error(fmt::format( + "recieved keys (of size {}) not matching values (of size {})", + l_size, + r_size)); + } + + return bidict_from_pairs(zip(ls, rs)); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/bidict/algorithms/bidict_from_pairs.h b/lib/utils/include/utils/bidict/algorithms/bidict_from_pairs.h new file mode 100644 index 0000000000..e33ab68f60 --- /dev/null +++ b/lib/utils/include/utils/bidict/algorithms/bidict_from_pairs.h @@ -0,0 +1,17 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_BIDICT_FROM_PAIRS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_BIDICT_FROM_PAIRS_H + +#include "utils/bidict/bidict.h" + +namespace FlexFlow { + +template +bidict bidict_from_pairs(C const &c) { + return bidict{c.begin(), c.end()}; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/bidict/algorithms/left_entries.h b/lib/utils/include/utils/bidict/algorithms/left_entries.h new file mode 100644 index 0000000000..a3fab172b1 --- /dev/null +++ b/lib/utils/include/utils/bidict/algorithms/left_entries.h @@ -0,0 +1,20 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_LEFT_ENTRIES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_LEFT_ENTRIES_H + +#include "utils/bidict/bidict.h" +#include + +namespace FlexFlow { + +template +std::unordered_set left_entries(bidict const &b) { + std::unordered_set result; + for (auto const &[l, _] : b) { + result.insert(l); + } + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/bidict/algorithms/merge_bidicts.h b/lib/utils/include/utils/bidict/algorithms/merge_bidicts.h new file mode 100644 index 0000000000..d388e35d75 --- /dev/null +++ b/lib/utils/include/utils/bidict/algorithms/merge_bidicts.h @@ -0,0 +1,29 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_MERGE_BIDICTS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_MERGE_BIDICTS_H + +#include "utils/bidict/algorithms/left_entries.h" +#include "utils/bidict/algorithms/right_entries.h" +#include "utils/bidict/bidict.h" +#include "utils/containers/are_disjoint.h" + +namespace FlexFlow { + +template +bidict merge_bidicts(bidict const &lhs, bidict const &rhs) { + assert(are_disjoint(left_entries(lhs), left_entries(rhs))); + assert(are_disjoint(right_entries(lhs), right_entries(rhs))); + + bidict result; + for (auto const &kv : lhs) { + result.equate(kv.first, kv.second); + } + for (auto const &kv : rhs) { + result.equate(kv.first, kv.second); + } + + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/bidict/algorithms/right_entries.h b/lib/utils/include/utils/bidict/algorithms/right_entries.h new file mode 100644 index 0000000000..ec0e822c74 --- /dev/null +++ b/lib/utils/include/utils/bidict/algorithms/right_entries.h @@ -0,0 +1,20 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_RIGHT_ENTRIES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_RIGHT_ENTRIES_H + +#include "utils/bidict/bidict.h" +#include + +namespace FlexFlow { + +template +std::unordered_set right_entries(bidict const &b) { + std::unordered_set result; + for (auto const &[_, r] : b) { + result.insert(r); + } + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/bidict/bidict.h b/lib/utils/include/utils/bidict/bidict.h index eaecb6e405..8b19313002 100644 --- a/lib/utils/include/utils/bidict/bidict.h +++ b/lib/utils/include/utils/bidict/bidict.h @@ -2,6 +2,7 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_BIDICT_H #include "utils/fmt/unordered_map.h" +#include "utils/hash/unordered_map.h" #include #include #include @@ -22,6 +23,10 @@ struct bidict { } } + bool contains(L const &l, R const &r) const { + return this->contains_l(l) && this->at_l(l) == r; + } + bool contains_l(L const &l) const { return fwd_map.find(l) != fwd_map.end(); } @@ -85,6 +90,10 @@ struct bidict { return fwd_map.size(); } + bool empty() const { + return this->size() == 0; + } + using const_iterator = typename std::unordered_map::const_iterator; using value_type = std::pair; using reference = value_type &; diff --git a/lib/utils/include/utils/containers.decl.h b/lib/utils/include/utils/containers.decl.h index 81fdff8a40..937ed51af2 100644 --- a/lib/utils/include/utils/containers.decl.h +++ b/lib/utils/include/utils/containers.decl.h @@ -41,13 +41,6 @@ template std::unordered_map restrict_keys(std::unordered_map const &m, std::unordered_set const &mask); -template -std::unordered_map merge_maps(std::unordered_map const &lhs, - std::unordered_map const &rhs); - -template -bidict merge_maps(bidict const &lhs, bidict const &rhs); - template std::optional at_idx(std::vector const &v, size_t idx); diff --git a/lib/utils/include/utils/containers.h b/lib/utils/include/utils/containers.h index 6164699f2e..7c0490fa2a 100644 --- a/lib/utils/include/utils/containers.h +++ b/lib/utils/include/utils/containers.h @@ -88,37 +88,6 @@ std::optional index_of(Container const &c, Element const &e) { } } -template -std::unordered_map merge_maps(std::unordered_map const &lhs, - std::unordered_map const &rhs) { - assert(are_disjoint(keys(lhs), keys(rhs))); - - std::unordered_map result; - for (auto const &kv : lhs) { - result.insert(kv); - } - for (auto const &kv : rhs) { - result.insert(kv); - } - - return result; -} - -template -bidict merge_maps(bidict const &lhs, bidict const &rhs) { - assert(are_disjoint(keys(lhs), keys(rhs))); - - bidict result; - for (auto const &kv : lhs) { - result.equate(kv.first, kv.second); - } - for (auto const &kv : rhs) { - result.equate(kv.first, kv.second); - } - - return result; -} - template std::function lookup_in(std::unordered_map const &m) { return [&m](K const &k) -> V { return m.at(k); }; diff --git a/lib/utils/include/utils/containers/enumerate.h b/lib/utils/include/utils/containers/enumerate.h index c9c5f4e97b..e3722e52c6 100644 --- a/lib/utils/include/utils/containers/enumerate.h +++ b/lib/utils/include/utils/containers/enumerate.h @@ -1,25 +1,46 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ENUMERATE_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ENUMERATE_H -#include "utils/bidict/bidict.h" #include "utils/containers/enumerate_vector.h" +#include #include +#include namespace FlexFlow { +/** + * @brief Generate a map from indices to elements of \p c. + * + * @note We return a std::map rather than a + * std::vector> for consistency + * with enumerate(FFOrdered const &). Note that std::map + * provides ordered iteration in increasing order, so iterating through + * the result of this function should still function as expected. + */ template -bidict enumerate(std::vector const &c) { +std::map enumerate(std::vector const &c) { return enumerate_vector(c); } +/** + * @brief Choose an arbitrary ordering of the elements of \p c and + * return a map from indices of this ordering to elements of \p c. + + * + * @note We return a std::map rather than a + * std::vector> for consistency + * with enumerate(FFOrdered const &). Note that std::map + * provides ordered iteration in increasing order, so iterating through + * the result of this function should still function as expected. + */ template -bidict enumerate(std::unordered_set const &c) { - bidict m; - size_t idx = 0; +std::map enumerate(std::unordered_set const &c) { + std::map result; + int idx = 0; for (auto const &v : c) { - m.equate(idx++, v); + result.insert({idx++, v}); } - return m; + return result; } } // namespace FlexFlow diff --git a/lib/utils/include/utils/containers/enumerate_vector.h b/lib/utils/include/utils/containers/enumerate_vector.h index 8d36a5fe3b..11ee8d2352 100644 --- a/lib/utils/include/utils/containers/enumerate_vector.h +++ b/lib/utils/include/utils/containers/enumerate_vector.h @@ -1,16 +1,17 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ENUMERATE_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ENUMERATE_H +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ENUMERATE_VECTOR_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ENUMERATE_VECTOR_H +#include #include #include namespace FlexFlow { template -std::vector> enumerate_vector(std::vector const &v) { - std::vector> result; +std::map enumerate_vector(std::vector const &v) { + std::map result; for (int i = 0; i < v.size(); i++) { - result.push_back({i, v.at(i)}); + result.insert({i, v.at(i)}); } return result; } diff --git a/lib/utils/include/utils/containers/filtrans.h b/lib/utils/include/utils/containers/filtrans.h new file mode 100644 index 0000000000..be1b5093c9 --- /dev/null +++ b/lib/utils/include/utils/containers/filtrans.h @@ -0,0 +1,73 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_FILTRANS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_FILTRANS_H + +#include "utils/type_traits_core.h" +#include +#include +#include +#include + +namespace FlexFlow { + +template +struct unwrap_optional { + static_assert("T is not a std::optional!"); +}; + +template +struct unwrap_optional> : type_identity {}; + +template +using unwrap_optional_t = typename unwrap_optional::type; + +template >> +std::vector filtrans(std::vector const &v, F f) { + std::vector result; + + for (In const &i : v) { + std::optional o = f(i); + if (o.has_value()) { + result.push_back(o.value()); + } + } + + return result; +} + +template >> +std::unordered_set filtrans(std::unordered_set const &s, F f) { + std::unordered_set result; + + for (In const &i : s) { + std::optional o = f(i); + if (o.has_value()) { + result.insert(o.value()); + } + } + + return result; +} + +template >> +std::set filtrans(std::set const &s, F f) { + std::set result; + + for (In const &i : s) { + std::optional o = f(i); + if (o.has_value()) { + result.insert(o.value()); + } + } + + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/get_all_permutations.h b/lib/utils/include/utils/containers/get_all_permutations.h new file mode 100644 index 0000000000..b7e797dad2 --- /dev/null +++ b/lib/utils/include/utils/containers/get_all_permutations.h @@ -0,0 +1,106 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_GET_ALL_PERMUTATIONS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_GET_ALL_PERMUTATIONS_H + +#include "utils/containers/sorted.h" +#include +#include +#include + +namespace FlexFlow { + +template +struct permutations_container { +public: + template + permutations_container(It start, It end) : current(start, end) { + std::sort(this->current.begin(), this->current.end()); + } + + struct iterator { + public: + using difference_type = long; + using value_type = std::vector; + using pointer = std::vector const *; + using reference = std::vector const &; + using iterator_category = std::input_iterator_tag; + + public: + explicit iterator(permutations_container const &c, bool done) + : c(c), done(done) {} + + iterator &operator++() { + assert(!this->done); + + this->done = !std::next_permutation(this->c.current.begin(), + this->c.current.end()); + return *this; + } + + iterator operator++(int) { + iterator retval = *this; + ++(*this); + return retval; + } + + bool operator==(iterator other) const { + return &this->c == &other.c && this->done == other.done; + } + + bool operator!=(iterator other) const { + return &this->c != &other.c || this->done != other.done; + } + + reference operator*() const { + return this->c.current; + } + + private: + permutations_container const &c; + bool done; + }; + + using const_iterator = iterator; + using value_type = typename iterator::value_type; + using difference_type = typename iterator::difference_type; + using pointer = typename iterator::pointer; + using reference = typename iterator::reference; + using const_reference = typename iterator::reference; + + iterator begin() const { + return iterator(*this, false); + } + + iterator end() const { + return iterator(*this, true); + } + + const_iterator cbegin() const { + return iterator(*this, false); + } + + const_iterator cend() const { + return iterator(*this, true); + } + +private: + mutable std::vector current; +}; + +/** + * @brief Lazily compute all permutations of the elements of in the input + * container. + * + * @note In cases where an element appears multiple times in the input + * (e.g., std::vector{1, 2, 2}), duplicate permutations are removed + * (i.e., {2, 1, 2} is only returned once, not twice), so it is + * possible for this function to return fewer than (but no more than) + * n! permutations. + */ +template +permutations_container get_all_permutations(C const &c) { + return permutations_container(c.cbegin(), c.cend()); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/keys.h b/lib/utils/include/utils/containers/keys.h index c1c8af54cc..e14612541e 100644 --- a/lib/utils/include/utils/containers/keys.h +++ b/lib/utils/include/utils/containers/keys.h @@ -1,13 +1,24 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_KEYS_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_KEYS_H +#include +#include #include namespace FlexFlow { -template -std::unordered_set keys(C const &c) { - std::unordered_set result; +template +std::unordered_set keys(std::unordered_map const &c) { + std::unordered_set result; + for (auto const &kv : c) { + result.insert(kv.first); + } + return result; +} + +template +std::unordered_set keys(std::map const &c) { + std::unordered_set result; for (auto const &kv : c) { result.insert(kv.first); } diff --git a/lib/utils/include/utils/containers/merge_maps.h b/lib/utils/include/utils/containers/merge_maps.h new file mode 100644 index 0000000000..653c9d24f1 --- /dev/null +++ b/lib/utils/include/utils/containers/merge_maps.h @@ -0,0 +1,28 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MERGE_MAPS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MERGE_MAPS_H + +#include "utils/containers/are_disjoint.h" +#include "utils/containers/keys.h" +#include + +namespace FlexFlow { + +template +std::unordered_map merge_maps(std::unordered_map const &lhs, + std::unordered_map const &rhs) { + assert(are_disjoint(keys(lhs), keys(rhs))); + + std::unordered_map result; + for (auto const &kv : lhs) { + result.insert(kv); + } + for (auto const &kv : rhs) { + result.insert(kv); + } + + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/set_union.h b/lib/utils/include/utils/containers/set_union.h index 0f5d6d5157..0f7b895f7a 100644 --- a/lib/utils/include/utils/containers/set_union.h +++ b/lib/utils/include/utils/containers/set_union.h @@ -16,7 +16,7 @@ std::unordered_set set_union(std::unordered_set const &l, template std::unordered_set set_union(C const &sets) { std::unordered_set result; - for (std::unordered_set const &s : sets) { + for (auto const &s : sets) { for (T const &element : s) { result.insert(element); } diff --git a/lib/utils/include/utils/containers/transform.h b/lib/utils/include/utils/containers/transform.h index c4e561f059..ec3d5f5612 100644 --- a/lib/utils/include/utils/containers/transform.h +++ b/lib/utils/include/utils/containers/transform.h @@ -5,6 +5,7 @@ #include "utils/required_core.h" #include #include +#include #include #include @@ -32,6 +33,17 @@ std::unordered_set transform(std::unordered_set const &v, F const &f) { return result; } +template ()(std::declval()))> +std::set transform(std::set const &v, F const &f) { + std::set result; + for (auto const &e : v) { + result.insert(f(e)); + } + return result; +} + template std::string transform(std::string const &s, F const &f) { std::string result; diff --git a/lib/utils/include/utils/dot_file.h b/lib/utils/include/utils/dot_file.h index 1fd9813646..214e6eeddc 100644 --- a/lib/utils/include/utils/dot_file.h +++ b/lib/utils/include/utils/dot_file.h @@ -94,13 +94,29 @@ class DotFile { this->get_ostream() << "}" << std::endl; } - void add_edge(T const &src, T const &dst) { + void add_edge(T const &src, + T const &dst, + std::optional const &src_field = std::nullopt, + std::optional const &dst_field = std::nullopt) { this->reserve_node(src); this->reserve_node(dst); - auto src_name = this->get_node_name(this->node_ids.at(src)); - auto dst_name = this->get_node_name(this->node_ids.at(dst)); - this->get_ostream() << " " << src_name << " -> " << dst_name << ";" - << std::endl; + + auto get_field_suffix = + [](std::optional const &field) -> std::string { + if (field.has_value()) { + return (":" + field.value()); + } else { + return ""; + } + }; + + std::string src_name = this->get_node_name(this->node_ids.at(src)); + + std::string dst_name = this->get_node_name(this->node_ids.at(dst)); + + this->get_ostream() << " " << src_name << get_field_suffix(src_field) + << " -> " << dst_name << get_field_suffix(dst_field) + << ";" << std::endl; } void close() { for (size_t subgraph = 0; subgraph < this->subgraph_id; subgraph++) { diff --git a/lib/utils/include/utils/exception.decl.h b/lib/utils/include/utils/exception.decl.h deleted file mode 100644 index 93c450294b..0000000000 --- a/lib/utils/include/utils/exception.decl.h +++ /dev/null @@ -1,35 +0,0 @@ -#ifndef _FLEXFLOW_UTILS_INCLUDE_EXCEPTION_DECL_H -#define _FLEXFLOW_UTILS_INCLUDE_EXCEPTION_DECL_H - -#include "utils/fmt.decl.h" -#include -#include - -namespace FlexFlow { - -#ifdef FF_REQUIRE_IMPLEMENTED -#define NOT_IMPLEMENTED() \ - static_assert(false, \ - "Function " __FUNC__ " not yet implemented " __FILE__ \ - ":" __LINE__); -#else -#define NOT_IMPLEMENTED() \ - throw not_implemented(__PRETTY_FUNCTION__, __FILE__, __LINE__); -#endif - -class not_implemented : public std::logic_error { -public: - not_implemented(std::string const &function_name, - std::string const &file_name, - int line); -}; - -template -T throw_if_unexpected(tl::expected const &r); - -template -std::runtime_error mk_runtime_error(fmt::format_string fmt_str, - T &&...args); -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/exception.h b/lib/utils/include/utils/exception.h index a00d2dba2b..20a8098040 100644 --- a/lib/utils/include/utils/exception.h +++ b/lib/utils/include/utils/exception.h @@ -1,13 +1,30 @@ #ifndef _FLEXFLOW_UTILS_INCLUDE_EXCEPTION_H #define _FLEXFLOW_UTILS_INCLUDE_EXCEPTION_H -#include "utils/exception.decl.h" #include "utils/fmt.h" +#include #include #include namespace FlexFlow { +#ifdef FF_REQUIRE_IMPLEMENTED +#define NOT_IMPLEMENTED() \ + static_assert(false, \ + "Function " __FUNC__ " not yet implemented " __FILE__ \ + ":" __LINE__); +#else +#define NOT_IMPLEMENTED() \ + throw not_implemented(__PRETTY_FUNCTION__, __FILE__, __LINE__); +#endif + +class not_implemented : public std::logic_error { +public: + not_implemented(std::string const &function_name, + std::string const &file_name, + int line); +}; + template T throw_if_unexpected(tl::expected const &r) { if (r.has_value()) { diff --git a/lib/utils/include/utils/fmt.decl.h b/lib/utils/include/utils/fmt.decl.h deleted file mode 100644 index 26193ae416..0000000000 --- a/lib/utils/include/utils/fmt.decl.h +++ /dev/null @@ -1,27 +0,0 @@ -#ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_FMT_DECL_H -#define _FLEXFLOW_UTILS_INCLUDE_UTILS_FMT_DECL_H - -#include "fmt/format.h" -#include "utils/check_fmtable.h" -#include -#include -#include -#include - -#define DELEGATE_OSTREAM(...) \ - template <> \ - struct delegate_ostream_operator<__VA_ARGS__> : std::true_type {} - -namespace FlexFlow { - -template -struct delegate_ostream_operator : std::false_type {}; - -template -typename std::enable_if>::value, - std::ostream &>::type - operator<<(std::ostream &s, T); - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/fmt.h b/lib/utils/include/utils/fmt.h index f1d4a9f2d9..ee008f7bfe 100644 --- a/lib/utils/include/utils/fmt.h +++ b/lib/utils/include/utils/fmt.h @@ -1,16 +1,24 @@ #ifndef _FLEXFLOW_UTILS_INCLUDE_FMT_H #define _FLEXFLOW_UTILS_INCLUDE_FMT_H -#include "utils/fmt.decl.h" +#include "utils/check_fmtable.h" #include "utils/test_types.h" #include "utils/type_traits_core.h" +#include #include #include #include #include +#define DELEGATE_OSTREAM(...) \ + template <> \ + struct delegate_ostream_operator<__VA_ARGS__> : std::true_type {} + namespace FlexFlow { +template +struct delegate_ostream_operator : std::false_type {}; + template typename std::enable_if>::value, std::ostream &>::type diff --git a/lib/utils/include/utils/graph/dataflow_graph/algorithms.h b/lib/utils/include/utils/graph/dataflow_graph/algorithms.h index db868a59f4..d50facee57 100644 --- a/lib/utils/include/utils/graph/dataflow_graph/algorithms.h +++ b/lib/utils/include/utils/graph/dataflow_graph/algorithms.h @@ -7,9 +7,10 @@ namespace FlexFlow { std::unordered_set get_edges(DataflowGraphView const &); -std::vector get_incoming_edges(DataflowGraphView const &, +std::vector get_input_values(DataflowGraphView const &, Node const &); -std::vector get_inputs(DataflowGraphView const &, Node const &); +std::vector get_dataflow_inputs(DataflowGraphView const &, + Node const &); std::vector get_outputs(DataflowGraphView const &, Node const &); std::unordered_set diff --git a/lib/utils/include/utils/graph/dataflow_graph/algorithms/as_dot.h b/lib/utils/include/utils/graph/dataflow_graph/algorithms/as_dot.h new file mode 100644 index 0000000000..6c9626ce00 --- /dev/null +++ b/lib/utils/include/utils/graph/dataflow_graph/algorithms/as_dot.h @@ -0,0 +1,16 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_AS_DOT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_AS_DOT_H + +#include "utils/dot_file.h" +#include "utils/graph/dataflow_graph/dataflow_graph_view.h" + +namespace FlexFlow { + +std::string as_dot(DataflowGraphView const &); +void as_dot(DotFile &, + DataflowGraphView const &, + std::function const &get_node_label); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/dataflow_graph/algorithms/dataflow_graph_isomorphism.struct.toml b/lib/utils/include/utils/graph/dataflow_graph/algorithms/dataflow_graph_isomorphism.struct.toml new file mode 100644 index 0000000000..082c25f6ea --- /dev/null +++ b/lib/utils/include/utils/graph/dataflow_graph/algorithms/dataflow_graph_isomorphism.struct.toml @@ -0,0 +1,16 @@ +namespace = "FlexFlow" +name = "DataflowGraphIsomorphism" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "utils/bidict/bidict.h", + "utils/graph/node/node.dtg.h", +] + +[[fields]] +name = "node_mapping" +type = "::FlexFlow::bidict<::FlexFlow::Node, ::FlexFlow::Node>" diff --git a/lib/utils/include/utils/graph/dataflow_graph/algorithms/dataflow_graphs_are_isomorphic.h b/lib/utils/include/utils/graph/dataflow_graph/algorithms/dataflow_graphs_are_isomorphic.h new file mode 100644 index 0000000000..914f8553dc --- /dev/null +++ b/lib/utils/include/utils/graph/dataflow_graph/algorithms/dataflow_graphs_are_isomorphic.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_ARE_ISOMORPHIC_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_ARE_ISOMORPHIC_H + +#include "utils/graph/dataflow_graph/dataflow_graph_view.h" + +namespace FlexFlow { + +bool dataflow_graphs_are_isomorphic(DataflowGraphView const &, + DataflowGraphView const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/dataflow_graph/algorithms/find_isomorphism.h b/lib/utils/include/utils/graph/dataflow_graph/algorithms/find_isomorphism.h new file mode 100644 index 0000000000..de78f9bec3 --- /dev/null +++ b/lib/utils/include/utils/graph/dataflow_graph/algorithms/find_isomorphism.h @@ -0,0 +1,20 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_FIND_ISOMORPHISM_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_FIND_ISOMORPHISM_H + +#include "utils/graph/dataflow_graph/algorithms/dataflow_graph_isomorphism.dtg.h" +#include "utils/graph/dataflow_graph/dataflow_graph_view.h" +#include + +namespace FlexFlow { + +/** + * @brief Find a valid isomorphism between \p src and \p dst, if one exists + * + * @note If multiple isomorphisms exist, an arbitrary one is returned + */ +std::optional + find_isomorphism(DataflowGraphView const &, DataflowGraphView const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/dataflow_graph/algorithms/find_isomorphisms.h b/lib/utils/include/utils/graph/dataflow_graph/algorithms/find_isomorphisms.h new file mode 100644 index 0000000000..dda69ea69a --- /dev/null +++ b/lib/utils/include/utils/graph/dataflow_graph/algorithms/find_isomorphisms.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_FIND_ISOMORPHISMS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_FIND_ISOMORPHISMS_H + +#include "utils/graph/dataflow_graph/algorithms/dataflow_graph_isomorphism.dtg.h" +#include "utils/graph/dataflow_graph/dataflow_graph_view.h" + +namespace FlexFlow { + +std::unordered_set + find_isomorphisms(DataflowGraphView const &, DataflowGraphView const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/dataflow_graph/algorithms/get_incoming_edges.h b/lib/utils/include/utils/graph/dataflow_graph/algorithms/get_incoming_edges.h new file mode 100644 index 0000000000..a4cd27bf9d --- /dev/null +++ b/lib/utils/include/utils/graph/dataflow_graph/algorithms/get_incoming_edges.h @@ -0,0 +1,16 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_GET_INCOMING_EDGES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_GET_INCOMING_EDGES_H + +#include "utils/graph/dataflow_graph/dataflow_graph_view.h" + +namespace FlexFlow { + +std::vector get_incoming_edges(DataflowGraphView const &, + Node const &); +std::unordered_set + get_incoming_edges(DataflowGraphView const &, + std::unordered_set const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/dataflow_graph/algorithms/get_subgraph_outgoing_edges.h b/lib/utils/include/utils/graph/dataflow_graph/algorithms/get_subgraph_outgoing_edges.h new file mode 100644 index 0000000000..f26ea20473 --- /dev/null +++ b/lib/utils/include/utils/graph/dataflow_graph/algorithms/get_subgraph_outgoing_edges.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_GET_OUTGOING_EDGES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_GET_OUTGOING_EDGES_H + +#include "utils/graph/dataflow_graph/dataflow_graph_view.h" + +namespace FlexFlow { + +std::unordered_set + get_subgraph_outgoing_edges(DataflowGraphView const &, + std::unordered_set const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/dataflow_graph/dataflow_edge_query.h b/lib/utils/include/utils/graph/dataflow_graph/dataflow_edge_query.h index febec3d14d..b1bade4254 100644 --- a/lib/utils/include/utils/graph/dataflow_graph/dataflow_edge_query.h +++ b/lib/utils/include/utils/graph/dataflow_graph/dataflow_edge_query.h @@ -10,6 +10,9 @@ DataflowEdgeQuery dataflow_edge_query_all(); DataflowEdgeQuery dataflow_edge_query_none(); bool dataflow_edge_query_includes_dataflow_edge(DataflowEdgeQuery const &, DataflowEdge const &); +DataflowEdgeQuery dataflow_edge_query_for_edge(DataflowEdge const &); +DataflowEdgeQuery dataflow_edge_query_all_outgoing_from(DataflowOutput const &); +DataflowEdgeQuery dataflow_edge_query_all_incoming_to(DataflowInput const &); } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/dataflow_graph/dataflow_graph.h b/lib/utils/include/utils/graph/dataflow_graph/dataflow_graph.h index 7974c033c3..6a1898dd13 100644 --- a/lib/utils/include/utils/graph/dataflow_graph/dataflow_graph.h +++ b/lib/utils/include/utils/graph/dataflow_graph/dataflow_graph.h @@ -29,8 +29,7 @@ struct DataflowGraph : virtual public DataflowGraphView { } template - static typename std::enable_if::value, - DataflowGraph>::type + static std::enable_if_t, DataflowGraph> create_copy_of(DataflowGraphView const &view) { cow_ptr_t impl = make_cow_ptr(); impl.get_mutable()->inplace_materialize_from(view); diff --git a/lib/utils/include/utils/graph/dataflow_graph/dataflow_output_query.h b/lib/utils/include/utils/graph/dataflow_graph/dataflow_output_query.h index 7ed54a5c27..fc1a222f1e 100644 --- a/lib/utils/include/utils/graph/dataflow_graph/dataflow_output_query.h +++ b/lib/utils/include/utils/graph/dataflow_graph/dataflow_output_query.h @@ -10,6 +10,10 @@ DataflowOutputQuery dataflow_output_query_all(); DataflowOutputQuery dataflow_output_query_none(); bool dataflow_output_query_includes_dataflow_output(DataflowOutputQuery const &, DataflowOutput const &); +DataflowOutputQuery dataflow_output_query_for_output(DataflowOutput const &); +std::unordered_set + apply_dataflow_output_query(DataflowOutputQuery const &, + std::unordered_set const &); } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h b/lib/utils/include/utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h index ad1b5f3bf5..f1063c1f21 100644 --- a/lib/utils/include/utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h +++ b/lib/utils/include/utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h @@ -16,6 +16,10 @@ #include "utils/graph/labelled_open_dataflow_graph/i_labelled_open_dataflow_graph.h" #include "utils/graph/node/algorithms.h" #include "utils/graph/node/node_source.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_edges.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_inputs.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_graph_inputs.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_values.h" #include "utils/graph/open_dataflow_graph/dataflow_graph_input_source.h" #include "utils/graph/open_dataflow_graph/open_dataflow_edge.h" #include "utils/graph/open_dataflow_graph/open_dataflow_edge_query.h" @@ -109,11 +113,11 @@ struct UnorderedSetLabelledOpenDataflowGraph final return this->inputs; } - NodeLabel const &at(Node const &n) const override { + NodeLabel at(Node const &n) const override { return this->nodes.at(n); } - ValueLabel const &at(OpenDataflowValue const &v) const override { + ValueLabel at(OpenDataflowValue const &v) const override { return this->values.at(v); } @@ -136,6 +140,26 @@ struct UnorderedSetLabelledOpenDataflowGraph final }); } + virtual void inplace_materialize_from( + LabelledOpenDataflowGraphView const &view) + override { + + std::unordered_map nodes = generate_map( + get_nodes(view), [&](Node const &n) { return view.at(n); }); + std::unordered_set edges = get_edges(view); + std::unordered_set inputs = + ::FlexFlow::get_open_dataflow_graph_inputs(view); + + std::unordered_map values = + generate_map(get_open_dataflow_values(view), + [&](OpenDataflowValue const &v) { return view.at(v); }); + + this->inputs = inputs; + this->nodes = nodes; + this->edges = edges; + this->values = values; + } + UnorderedSetLabelledOpenDataflowGraph *clone() const override { return new UnorderedSetLabelledOpenDataflowGraph{ this->node_source, diff --git a/lib/utils/include/utils/graph/labelled_dataflow_graph/algorithms/find_isomorphism.h b/lib/utils/include/utils/graph/labelled_dataflow_graph/algorithms/find_isomorphism.h new file mode 100644 index 0000000000..2d4e6b11e9 --- /dev/null +++ b/lib/utils/include/utils/graph/labelled_dataflow_graph/algorithms/find_isomorphism.h @@ -0,0 +1,30 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_DATAFLOW_GRAPH_ALGORITHMS_FIND_ISOMORPHISM_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_DATAFLOW_GRAPH_ALGORITHMS_FIND_ISOMORPHISM_H + +#include "utils/graph/dataflow_graph/algorithms/dataflow_graph_isomorphism.dtg.h" +#include "utils/graph/dataflow_graph/algorithms/find_isomorphisms.h" +#include "utils/graph/labelled_dataflow_graph/algorithms/view_as_labelled_open_dataflow_graph.h" +#include "utils/graph/labelled_dataflow_graph/labelled_dataflow_graph_view.h" +#include "utils/graph/labelled_open_dataflow_graph/algorithms/find_isomorphism.h" +#include "utils/graph/open_dataflow_graph/algorithms/open_dataflow_graph_isomorphism.dtg.h" + +namespace FlexFlow { + +template +std::optional find_isomorphism( + LabelledDataflowGraphView const &src, + LabelledDataflowGraphView const &dst) { + std::optional open_isomorphism = + find_isomorphism(view_as_labelled_open_dataflow_graph(src), + view_as_labelled_open_dataflow_graph(dst)); + + return transform(open_isomorphism, + [](OpenDataflowGraphIsomorphism const &open) { + assert(open.input_mapping.empty()); + return DataflowGraphIsomorphism{open.node_mapping}; + }); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/labelled_dataflow_graph/algorithms/is_isomorphic_under.h b/lib/utils/include/utils/graph/labelled_dataflow_graph/algorithms/is_isomorphic_under.h new file mode 100644 index 0000000000..4f42653380 --- /dev/null +++ b/lib/utils/include/utils/graph/labelled_dataflow_graph/algorithms/is_isomorphic_under.h @@ -0,0 +1,25 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_DATAFLOW_GRAPH_ALGORITHMS_IS_ISOMORPHIC_UNDER_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_DATAFLOW_GRAPH_ALGORITHMS_IS_ISOMORPHIC_UNDER_H + +#include "utils/graph/dataflow_graph/algorithms/dataflow_graph_isomorphism.dtg.h" +#include "utils/graph/labelled_dataflow_graph/labelled_dataflow_graph_view.h" +#include "utils/graph/open_dataflow_graph/algorithms/open_dataflow_graph_isomorphism.dtg.h" + +namespace FlexFlow { + +template +bool is_isomorphic_under( + LabelledDataflowGraphView const &src, + LabelledDataflowGraphView const &dst, + DataflowGraphIsomorphism const &candidate_isomorphism) { + return is_isomorphic_under(view_as_labelled_open_dataflow_graph(src), + view_as_labelled_open_dataflow_graph(dst), + OpenDataflowGraphIsomorphism{ + candidate_isomorphism.node_mapping, + {}, + }); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/labelled_dataflow_graph/algorithms/labelled_dataflow_graphs_are_isomorphic.h b/lib/utils/include/utils/graph/labelled_dataflow_graph/algorithms/labelled_dataflow_graphs_are_isomorphic.h new file mode 100644 index 0000000000..d399c5fcdb --- /dev/null +++ b/lib/utils/include/utils/graph/labelled_dataflow_graph/algorithms/labelled_dataflow_graphs_are_isomorphic.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_DATAFLOW_GRAPH_ALGORITHMS_LABELLED_DATAFLOW_GRAPHS_ARE_ISOMORPHIC_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_DATAFLOW_GRAPH_ALGORITHMS_LABELLED_DATAFLOW_GRAPHS_ARE_ISOMORPHIC_H + +#include "utils/graph/labelled_dataflow_graph/algorithms/find_isomorphism.h" +#include "utils/graph/labelled_dataflow_graph/labelled_dataflow_graph.h" + +namespace FlexFlow { + +template +bool labelled_dataflow_graphs_are_isomorphic( + LabelledDataflowGraph const &src, + LabelledDataflowGraph const &dst) { + return find_isomorphism(src, dst).has_value(); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/labelled_dataflow_graph/algorithms/view_as_labelled_open_dataflow_graph.h b/lib/utils/include/utils/graph/labelled_dataflow_graph/algorithms/view_as_labelled_open_dataflow_graph.h index 13e75efdd6..f1cdfd9690 100644 --- a/lib/utils/include/utils/graph/labelled_dataflow_graph/algorithms/view_as_labelled_open_dataflow_graph.h +++ b/lib/utils/include/utils/graph/labelled_dataflow_graph/algorithms/view_as_labelled_open_dataflow_graph.h @@ -33,11 +33,11 @@ struct LabelledDataflowGraphAsOpenView final return {}; } - NodeLabel const &at(Node const &n) const override { + NodeLabel at(Node const &n) const override { return this->g.at(n); } - ValueLabel const &at(OpenDataflowValue const &v) const override { + ValueLabel at(OpenDataflowValue const &v) const override { return this->g.at(v.get()); } diff --git a/lib/utils/include/utils/graph/labelled_dataflow_graph/i_labelled_dataflow_graph_view.h b/lib/utils/include/utils/graph/labelled_dataflow_graph/i_labelled_dataflow_graph_view.h index 9f0fc0f30d..f7bbbd9964 100644 --- a/lib/utils/include/utils/graph/labelled_dataflow_graph/i_labelled_dataflow_graph_view.h +++ b/lib/utils/include/utils/graph/labelled_dataflow_graph/i_labelled_dataflow_graph_view.h @@ -8,8 +8,8 @@ namespace FlexFlow { template struct ILabelledDataflowGraphView : virtual public IDataflowGraphView { public: - virtual NodeLabel const &at(Node const &) const = 0; - virtual OutputLabel const &at(DataflowOutput const &) const = 0; + virtual NodeLabel at(Node const &) const = 0; + virtual OutputLabel at(DataflowOutput const &) const = 0; virtual ~ILabelledDataflowGraphView() = default; }; diff --git a/lib/utils/include/utils/graph/labelled_dataflow_graph/labelled_dataflow_graph_view.h b/lib/utils/include/utils/graph/labelled_dataflow_graph/labelled_dataflow_graph_view.h index a6a6b9d061..61e0677061 100644 --- a/lib/utils/include/utils/graph/labelled_dataflow_graph/labelled_dataflow_graph_view.h +++ b/lib/utils/include/utils/graph/labelled_dataflow_graph/labelled_dataflow_graph_view.h @@ -16,10 +16,10 @@ struct LabelledDataflowGraphView : virtual public DataflowGraphView { LabelledDataflowGraphView & operator=(LabelledDataflowGraphView const &) = default; - NodeLabel const &at(Node const &n) const { + NodeLabel at(Node const &n) const { return this->get_interface().at(n); } - OutputLabel const &at(DataflowOutput const &o) const { + OutputLabel at(DataflowOutput const &o) const { return this->get_interface().at(o); } diff --git a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/as_dot.h b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/as_dot.h new file mode 100644 index 0000000000..6faddcdfcb --- /dev/null +++ b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/as_dot.h @@ -0,0 +1,29 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_DATAFLOW_GRAPH_ALGORITHMS_AS_DOT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_DATAFLOW_GRAPH_ALGORITHMS_AS_DOT_H + +#include "utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph_view.h" +#include "utils/graph/open_dataflow_graph/algorithms/as_dot.h" + +namespace FlexFlow { + +template +std::string as_dot( + LabelledOpenDataflowGraphView const &g, + std::function const &get_node_label, + std::function const &get_input_label) { + std::function unlabelled_get_node_label = + [&](Node const &n) -> std::string { return get_node_label(g.at(n)); }; + + std::function + unlabelled_get_input_label = [&](DataflowGraphInput const &i) { + return get_input_label(g.at(OpenDataflowValue{i})); + }; + + return as_dot(static_cast(g), + unlabelled_get_node_label, + unlabelled_get_input_label); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/find_isomorphism.h b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/find_isomorphism.h new file mode 100644 index 0000000000..a1d6e9e37a --- /dev/null +++ b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/find_isomorphism.h @@ -0,0 +1,39 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_DATAFLOW_GRAPH_ALGORITHMS_FIND_ISOMORPHISM_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_DATAFLOW_GRAPH_ALGORITHMS_FIND_ISOMORPHISM_H + +#include "utils/containers/as_vector.h" +#include "utils/containers/get_all_permutations.h" +#include "utils/containers/zip.h" +#include "utils/graph/labelled_open_dataflow_graph/algorithms/is_isomorphic_under.h" +#include "utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph_view.h" +#include "utils/graph/open_dataflow_graph/algorithms/find_isomorphisms.h" +#include "utils/graph/open_dataflow_graph/algorithms/open_dataflow_graph_isomorphism.dtg.h" + +namespace FlexFlow { + +/** + * @brief Finds an isomorphism between \p src and \p dst, if one exists. + * + * @note If multiple isomorphisms exist, an arbitrary one is returned. + */ +template +std::optional find_isomorphism( + LabelledOpenDataflowGraphView const &src, + LabelledOpenDataflowGraphView const &dst) { + std::unordered_set unlabelled_isomorphisms = + find_isomorphisms(static_cast(src), + static_cast(dst)); + + for (OpenDataflowGraphIsomorphism const &candidate_isomorphism : + unlabelled_isomorphisms) { + if (is_isomorphic_under(src, dst, candidate_isomorphism)) { + return candidate_isomorphism; + } + } + + return std::nullopt; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/from_labelled_open_dataflow_graph_data.h b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/from_labelled_open_dataflow_graph_data.h new file mode 100644 index 0000000000..106d500464 --- /dev/null +++ b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/from_labelled_open_dataflow_graph_data.h @@ -0,0 +1,36 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_DATAFLOW_GRAPH_ALGORITHMS_FROM_LABELLED_OPEN_DATAFLOW_GRAPH_DATA_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_DATAFLOW_GRAPH_ALGORITHMS_FROM_LABELLED_OPEN_DATAFLOW_GRAPH_DATA_H + +#include "utils/containers/filtrans.h" +#include "utils/graph/labelled_open_dataflow_graph/algorithms/labelled_open_dataflow_graph_data.dtg.h" +#include "utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph_view.h" +#include "utils/graph/open_dataflow_graph/algorithms/from_open_dataflow_graph_data.h" +#include "utils/graph/open_dataflow_graph/algorithms/open_dataflow_graph_data.dtg.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_value.h" +#include "utils/overload.h" + +namespace FlexFlow { + +template +LabelledOpenDataflowGraphView + from_labelled_open_dataflow_graph_data( + LabelledOpenDataflowGraphData const &data) { + std::unordered_set values = keys(data.value_data); + std::unordered_set outputs = + filtrans(values, try_get_dataflow_output); + + OpenDataflowGraphData unlabelled_data = OpenDataflowGraphData{ + keys(data.node_data), + data.edges, + data.inputs, + outputs, + }; + + return with_labelling(from_open_dataflow_graph_data(unlabelled_data), + data.node_data, + data.value_data); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/get_graph_data.h b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/get_graph_data.h new file mode 100644 index 0000000000..ec8f025ac3 --- /dev/null +++ b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/get_graph_data.h @@ -0,0 +1,36 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_DATAFLOW_GRAPH_ALGORITHMS_GET_GRAPH_DATA_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_DATAFLOW_GRAPH_ALGORITHMS_GET_GRAPH_DATA_H + +#include "utils/graph/labelled_open_dataflow_graph/algorithms/labelled_open_dataflow_graph_data.dtg.h" +#include "utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph_view.h" +#include "utils/graph/node/algorithms.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_values.h" + +namespace FlexFlow { + +template +LabelledOpenDataflowGraphData get_graph_data( + LabelledOpenDataflowGraphView const &g) { + + std::unordered_map node_data = + generate_map(get_nodes(g), [&](Node const &n) { return g.at(n); }); + + std::unordered_set edges = get_edges(g); + + std::unordered_set inputs = g.get_inputs(); + + std::unordered_map value_data = + generate_map(get_open_dataflow_values(g), + [&](OpenDataflowValue const &v) { return g.at(v); }); + + return LabelledOpenDataflowGraphData{ + node_data, + edges, + inputs, + value_data, + }; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/is_isomorphic_under.h b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/is_isomorphic_under.h new file mode 100644 index 0000000000..ecf9c22143 --- /dev/null +++ b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/is_isomorphic_under.h @@ -0,0 +1,36 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_DATAFLOW_GRAPH_ALGORITHMS_IS_ISOMORPHIC_UNDER_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_DATAFLOW_GRAPH_ALGORITHMS_IS_ISOMORPHIC_UNDER_H + +#include "utils/graph/labelled_open_dataflow_graph/algorithms/get_graph_data.h" +#include "utils/graph/labelled_open_dataflow_graph/algorithms/permute_input_ids.h" +#include "utils/graph/labelled_open_dataflow_graph/algorithms/permute_node_ids.h" +#include "utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph_view.h" +#include "utils/graph/open_dataflow_graph/algorithms/new_dataflow_graph_input.dtg.h" +#include "utils/graph/open_dataflow_graph/algorithms/open_dataflow_graph_isomorphism.dtg.h" + +namespace FlexFlow { + +template +bool is_isomorphic_under( + LabelledOpenDataflowGraphView const &src, + LabelledOpenDataflowGraphView const &dst, + OpenDataflowGraphIsomorphism const &candidate_isomorphism) { + + bidict node_permutation = + map_values(candidate_isomorphism.node_mapping, [](Node const &dst_node) { + return NewNode{dst_node}; + }).reversed(); + bidict input_permutation = + map_values(candidate_isomorphism.input_mapping, + [](DataflowGraphInput const &dst_input) { + return NewDataflowGraphInput{dst_input}; + }) + .reversed(); + return get_graph_data(permute_input_ids( + permute_node_ids(src, node_permutation), input_permutation)) == + get_graph_data(dst); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/labelled_open_dataflow_graph_data.struct.toml b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/labelled_open_dataflow_graph_data.struct.toml new file mode 100644 index 0000000000..082b61e691 --- /dev/null +++ b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/labelled_open_dataflow_graph_data.struct.toml @@ -0,0 +1,41 @@ +namespace = "FlexFlow" +name = "LabelledOpenDataflowGraphData" +features = [ + "eq", + "hash", + "fmt", +] + +template_params = ["NodeLabel", "ValueLabel"] + +includes = [ + "utils/graph/node/node.dtg.h", + "utils/graph/open_dataflow_graph/open_dataflow_edge.dtg.h", + "utils/graph/open_dataflow_graph/dataflow_graph_input.dtg.h", + "utils/graph/open_dataflow_graph/open_dataflow_value.dtg.h", + "", + "", +] + +src_includes = [ + "utils/hash/unordered_map.h", + "utils/hash/unordered_set.h", + "utils/fmt/unordered_map.h", + "utils/fmt/unordered_set.h", +] + +[[fields]] +name = "node_data" +type = "std::unordered_map<::FlexFlow::Node, NodeLabel>" + +[[fields]] +name = "edges" +type = "std::unordered_set<::FlexFlow::OpenDataflowEdge>" + +[[fields]] +name = "inputs" +type = "std::unordered_set<::FlexFlow::DataflowGraphInput>" + +[[fields]] +name = "value_data" +type = "std::unordered_map<::FlexFlow::OpenDataflowValue, ValueLabel>" diff --git a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/labelled_open_dataflow_graphs_are_isomorphic.h b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/labelled_open_dataflow_graphs_are_isomorphic.h new file mode 100644 index 0000000000..b3a71235cc --- /dev/null +++ b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/labelled_open_dataflow_graphs_are_isomorphic.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_DATAFLOW_GRAPH_ALGORITHMS_LABELLED_OPEN_DATAFLOW_GRAPHS_ARE_ISOMORPHIC_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_DATAFLOW_GRAPH_ALGORITHMS_LABELLED_OPEN_DATAFLOW_GRAPHS_ARE_ISOMORPHIC_H + +#include "utils/graph/labelled_open_dataflow_graph/algorithms/find_isomorphism.h" +#include "utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph_view.h" + +namespace FlexFlow { + +template +bool labelled_open_dataflow_graphs_are_isomorphic( + LabelledOpenDataflowGraphView const &lhs, + LabelledOpenDataflowGraphView const &rhs) { + return find_isomorphism(lhs, rhs).has_value(); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/permute_input_ids.h b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/permute_input_ids.h new file mode 100644 index 0000000000..88132e0a79 --- /dev/null +++ b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/permute_input_ids.h @@ -0,0 +1,46 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_DATAFLOW_GRAPH_ALGORITHMS_PERMUTE_INPUT_IDS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_DATAFLOW_GRAPH_ALGORITHMS_PERMUTE_INPUT_IDS_H + +#include "utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph_view.h" +#include "utils/graph/node/algorithms.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_values.h" +#include "utils/graph/open_dataflow_graph/algorithms/new_dataflow_graph_input.dtg.h" +#include "utils/graph/open_dataflow_graph/algorithms/permute_input_ids.h" +#include "utils/overload.h" + +namespace FlexFlow { + +template +LabelledOpenDataflowGraphView permute_input_ids( + LabelledOpenDataflowGraphView const &g, + bidict const &input_mapping) { + + OpenDataflowGraphView permuted = + permute_input_ids(static_cast(g), input_mapping); + + auto old_value_from_new = [&](OpenDataflowValue const &new_value) { + return new_value.visit(overload{ + [](DataflowOutput const &o) { return OpenDataflowValue{o}; }, + [&](DataflowGraphInput const &new_i) { + return OpenDataflowValue{ + input_mapping.at_l(NewDataflowGraphInput{new_i}), + }; + }, + }); + }; + + std::unordered_map node_labels = + generate_map(get_nodes(permuted), [&](Node const &n) { return g.at(n); }); + + std::unordered_map value_labels = + generate_map(get_open_dataflow_values(permuted), + [&](OpenDataflowValue const &new_value) { + return g.at(old_value_from_new(new_value)); + }); + + return with_labelling(permuted, node_labels, value_labels); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/permute_node_ids.h b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/permute_node_ids.h new file mode 100644 index 0000000000..2d1dd03755 --- /dev/null +++ b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/permute_node_ids.h @@ -0,0 +1,54 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_DATAFLOW_GRAPH_ALGORITHMS_PERMUTE_NODE_IDS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_DATAFLOW_GRAPH_ALGORITHMS_PERMUTE_NODE_IDS_H + +#include "utils/graph/labelled_open_dataflow_graph/algorithms/with_labelling.h" +#include "utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph_view.h" +#include "utils/graph/node/algorithms.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_values.h" +#include "utils/graph/open_dataflow_graph/algorithms/permute_node_ids.h" +#include "utils/overload.h" + +namespace FlexFlow { + +template +LabelledOpenDataflowGraphView permute_node_ids( + LabelledOpenDataflowGraphView const &g, + bidict const &new_node_tofrom_old_node) { + OpenDataflowGraphView permuted = permute_node_ids( + static_cast(g), new_node_tofrom_old_node); + + auto old_node_from_new = [&](Node const &new_node) { + return new_node_tofrom_old_node.at_l(NewNode{new_node}); + }; + + auto old_value_from_new = [&](OpenDataflowValue const &new_value) { + return new_value.visit(overload{ + [&](DataflowOutput const &new_o) { + return OpenDataflowValue{ + DataflowOutput{ + old_node_from_new(new_o.node), + new_o.idx, + }, + }; + }, + [](DataflowGraphInput const &i) { return OpenDataflowValue{i}; }, + }); + }; + + std::unordered_map node_labels = + generate_map(get_nodes(permuted), [&](Node const &new_node) { + return g.at(old_node_from_new(new_node)); + }); + + std::unordered_map value_labels = + generate_map(get_open_dataflow_values(permuted), + [&](OpenDataflowValue const &new_value) { + return g.at(old_value_from_new(new_value)); + }); + + return with_labelling(permuted, node_labels, value_labels); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/rewrite_labels.h b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/rewrite_labels.h index 2849bfa72f..92938d7142 100644 --- a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/rewrite_labels.h +++ b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/rewrite_labels.h @@ -4,7 +4,7 @@ #include "utils/containers/generate_map.h" #include "utils/graph/labelled_open_dataflow_graph/algorithms/with_labelling.h" #include "utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph_view.h" -#include "utils/graph/open_dataflow_graph/algorithms.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_values.h" namespace FlexFlow { diff --git a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/rewrite_node_labels.h b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/rewrite_node_labels.h new file mode 100644 index 0000000000..eb39c4fe6a --- /dev/null +++ b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/rewrite_node_labels.h @@ -0,0 +1,26 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_DATAFLOW_GRAPH_ALGORITHMS_REWRITE_NODE_LABELS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_DATAFLOW_GRAPH_ALGORITHMS_REWRITE_NODE_LABELS_H + +#include "utils/graph/labelled_open_dataflow_graph/algorithms/rewrite_labels.h" +#include "utils/overload.h" + +namespace FlexFlow { + +template > +LabelledOpenDataflowGraphView rewrite_node_labels( + LabelledOpenDataflowGraphView const &g, F f) { + return rewrite_labels( + g, + overload{ + [&](Node const &n, NodeLabel const &l) { return f(n, l); }, + [](OpenDataflowValue const &v, ValueLabel const &l) { return l; }, + }); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/rewrite_value_labels.h b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/rewrite_value_labels.h new file mode 100644 index 0000000000..c0582d8e3d --- /dev/null +++ b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/rewrite_value_labels.h @@ -0,0 +1,27 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_DATAFLOW_GRAPH_ALGORITHMS_REWRITE_VALUE_LABELS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_DATAFLOW_GRAPH_ALGORITHMS_REWRITE_VALUE_LABELS_H + +#include "utils/graph/labelled_open_dataflow_graph/algorithms/rewrite_labels.h" +#include "utils/overload.h" + +namespace FlexFlow { + +template < + typename NodeLabel, + typename ValueLabel, + typename F, + typename NewValueLabel = + std::invoke_result_t> +LabelledOpenDataflowGraphView rewrite_value_labels( + LabelledOpenDataflowGraphView const &g, F f) { + return rewrite_labels(g, + overload{ + [](Node const &n, NodeLabel const &l) { return l; }, + [&](OpenDataflowValue const &v, + ValueLabel const &l) { return f(v, l); }, + }); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/with_labelling.h b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/with_labelling.h index e95781af6e..3697ab0f93 100644 --- a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/with_labelling.h +++ b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/with_labelling.h @@ -36,11 +36,11 @@ struct OpenDataflowGraphLabellingWrapper final return this->unlabelled.get_inputs(); } - NodeLabel const &at(Node const &n) const override { + NodeLabel at(Node const &n) const override { return this->node_labels.at(n); } - ValueLabel const &at(OpenDataflowValue const &v) const override { + ValueLabel at(OpenDataflowValue const &v) const override { return this->value_labels.at(v); } diff --git a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/i_labelled_open_dataflow_graph.h b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/i_labelled_open_dataflow_graph.h index a4a3fc0bea..01777909cd 100644 --- a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/i_labelled_open_dataflow_graph.h +++ b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/i_labelled_open_dataflow_graph.h @@ -4,6 +4,7 @@ #include "utils/graph/dataflow_graph/node_added_result.dtg.h" #include "utils/graph/labelled_dataflow_graph/i_labelled_dataflow_graph.h" #include "utils/graph/labelled_open_dataflow_graph/i_labelled_open_dataflow_graph_view.h" +#include "utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph_view.h" namespace FlexFlow { @@ -18,6 +19,9 @@ struct ILabelledOpenDataflowGraph virtual DataflowGraphInput add_input(ValueLabel const &value_label) = 0; + virtual void inplace_materialize_from( + LabelledOpenDataflowGraphView const &) = 0; + // NodeAddedResult add_node(NodeLabel const &node_label, // std::vector const &inputs, // std::vector const &output_labels) diff --git a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/i_labelled_open_dataflow_graph_view.h b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/i_labelled_open_dataflow_graph_view.h index 58137704e6..a59ce72896 100644 --- a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/i_labelled_open_dataflow_graph_view.h +++ b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/i_labelled_open_dataflow_graph_view.h @@ -12,10 +12,10 @@ struct ILabelledOpenDataflowGraphView : virtual public ILabelledDataflowGraphView, virtual public IOpenDataflowGraphView { public: - virtual NodeLabel const &at(Node const &) const override = 0; - virtual ValueLabel const &at(OpenDataflowValue const &) const = 0; + virtual NodeLabel at(Node const &) const override = 0; + virtual ValueLabel at(OpenDataflowValue const &) const = 0; - ValueLabel const &at(DataflowOutput const &o) const override final { + ValueLabel at(DataflowOutput const &o) const override final { return this->at(OpenDataflowValue{o}); } diff --git a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph.h b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph.h index 76877e245a..375e40d5ea 100644 --- a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph.h +++ b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph.h @@ -34,6 +34,16 @@ struct LabelledOpenDataflowGraph return LabelledOpenDataflowGraph(make_cow_ptr()); } + template + static std::enable_if_t, + LabelledOpenDataflowGraph> + create_copy_of( + LabelledOpenDataflowGraphView const &view) { + cow_ptr_t impl = make_cow_ptr(); + impl.get_mutable()->inplace_materialize_from(view); + return LabelledOpenDataflowGraph(std::move(impl)); + } + protected: using LabelledOpenDataflowGraphView:: LabelledOpenDataflowGraphView; diff --git a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph_view.h b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph_view.h index 6e08b10a29..935f615ec8 100644 --- a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph_view.h +++ b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph_view.h @@ -20,11 +20,11 @@ struct LabelledOpenDataflowGraphView LabelledOpenDataflowGraphView & operator=(LabelledOpenDataflowGraphView const &) = default; - NodeLabel const &at(Node const &n) const { + NodeLabel at(Node const &n) const { return this->get_interface().at(n); } - ValueLabel const &at(OpenDataflowValue const &v) const { + ValueLabel at(OpenDataflowValue const &v) const { return this->get_interface().at(v); } diff --git a/lib/utils/include/utils/graph/node/algorithms/generate_new_node_id_permutation.h b/lib/utils/include/utils/graph/node/algorithms/generate_new_node_id_permutation.h new file mode 100644 index 0000000000..42ae867883 --- /dev/null +++ b/lib/utils/include/utils/graph/node/algorithms/generate_new_node_id_permutation.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_NODE_ALGORITHMS_GENERATE_NEW_NODE_ID_PERMUTATION_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_NODE_ALGORITHMS_GENERATE_NEW_NODE_ID_PERMUTATION_H + +#include "utils/graph/node/algorithms/new_node.dtg.h" +#include "utils/graph/node/graph_view.h" + +namespace FlexFlow { + +bidict generate_new_node_id_permutation(GraphView const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/node/algorithms/new_node.struct.toml b/lib/utils/include/utils/graph/node/algorithms/new_node.struct.toml new file mode 100644 index 0000000000..f3b8244573 --- /dev/null +++ b/lib/utils/include/utils/graph/node/algorithms/new_node.struct.toml @@ -0,0 +1,16 @@ +namespace = "FlexFlow" +name = "NewNode" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/node/node.dtg.h", +] + +[[fields]] +name = "raw_node" +type = "::FlexFlow::Node" diff --git a/lib/utils/include/utils/graph/node/node_query.h b/lib/utils/include/utils/graph/node/node_query.h index b7d754ceac..2ec8958083 100644 --- a/lib/utils/include/utils/graph/node/node_query.h +++ b/lib/utils/include/utils/graph/node/node_query.h @@ -8,6 +8,8 @@ namespace FlexFlow { NodeQuery node_query_all(); NodeQuery query_intersection(NodeQuery const &, NodeQuery const &); NodeQuery query_union(NodeQuery const &, NodeQuery const &); +std::unordered_set apply_node_query(NodeQuery const &, + std::unordered_set const &); } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/as_dot.h b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/as_dot.h new file mode 100644 index 0000000000..4c600637aa --- /dev/null +++ b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/as_dot.h @@ -0,0 +1,17 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_AS_DOT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_AS_DOT_H + +#include "utils/graph/open_dataflow_graph/open_dataflow_graph_view.h" + +namespace FlexFlow { + +std::string as_dot(OpenDataflowGraphView const &); +std::string + as_dot(OpenDataflowGraphView const &, + std::function const &get_node_label, + std::function const + &get_input_label); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/find_isomorphism.h b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/find_isomorphism.h new file mode 100644 index 0000000000..4c1ec38b89 --- /dev/null +++ b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/find_isomorphism.h @@ -0,0 +1,20 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_FIND_ISOMORPHISM_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_FIND_ISOMORPHISM_H + +#include "utils/graph/open_dataflow_graph/algorithms/open_dataflow_graph_isomorphism.dtg.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_graph_view.h" + +namespace FlexFlow { + +/** + * @brief Find a valid isomorphism between \p src and \p dst, if one exists + * + * @note If multiple isomorphisms exist, an arbitrary one is returned + */ +std::optional + find_isomorphism(OpenDataflowGraphView const &src, + OpenDataflowGraphView const &dst); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/find_isomorphisms.h b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/find_isomorphisms.h new file mode 100644 index 0000000000..022fc5b9fd --- /dev/null +++ b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/find_isomorphisms.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_FIND_ISOMORPHISMS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_FIND_ISOMORPHISMS_H + +#include "utils/graph/open_dataflow_graph/algorithms/open_dataflow_graph_isomorphism.dtg.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_graph_view.h" + +namespace FlexFlow { + +std::unordered_set + find_isomorphisms(OpenDataflowGraphView const &, + OpenDataflowGraphView const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/from_open_dataflow_graph_data.h b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/from_open_dataflow_graph_data.h new file mode 100644 index 0000000000..1fbbea21b0 --- /dev/null +++ b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/from_open_dataflow_graph_data.h @@ -0,0 +1,31 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_FROM_OPEN_DATAFLOW_GRAPH_DATA_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_FROM_OPEN_DATAFLOW_GRAPH_DATA_H + +#include "utils/graph/open_dataflow_graph/algorithms/open_dataflow_graph_data.dtg.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_graph_view.h" + +namespace FlexFlow { + +struct FromOpenDataflowGraphDataView final + : virtual public IOpenDataflowGraphView { + FromOpenDataflowGraphDataView(OpenDataflowGraphData const &); + + std::unordered_set query_nodes(NodeQuery const &) const override; + std::unordered_set + query_edges(OpenDataflowEdgeQuery const &) const override; + std::unordered_set + query_outputs(DataflowOutputQuery const &) const override; + std::unordered_set get_inputs() const override; + + FromOpenDataflowGraphDataView *clone() const override; + +private: + OpenDataflowGraphData data; +}; + +OpenDataflowGraphView + from_open_dataflow_graph_data(OpenDataflowGraphData const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/generate_new_input_id_permutation.h b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/generate_new_input_id_permutation.h new file mode 100644 index 0000000000..803b5c849b --- /dev/null +++ b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/generate_new_input_id_permutation.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_GENERATE_NEW_INPUT_ID_PERMUTATION_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_GENERATE_NEW_INPUT_ID_PERMUTATION_H + +#include "utils/graph/open_dataflow_graph/algorithms/new_dataflow_graph_input.dtg.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_graph_view.h" + +namespace FlexFlow { + +bidict + generate_new_input_id_permutation(OpenDataflowGraphView const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_edges.h b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_edges.h new file mode 100644 index 0000000000..0710b3d970 --- /dev/null +++ b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_edges.h @@ -0,0 +1,12 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_GET_EDGES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_GET_EDGES_H + +#include "utils/graph/open_dataflow_graph/open_dataflow_graph_view.h" + +namespace FlexFlow { + +std::unordered_set get_edges(OpenDataflowGraphView const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_graph_data.h b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_graph_data.h new file mode 100644 index 0000000000..6bb4f123df --- /dev/null +++ b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_graph_data.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_GET_GRAPH_DATA_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_GET_GRAPH_DATA_H + +#include "utils/graph/open_dataflow_graph/algorithms/open_dataflow_graph_data.dtg.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_graph_view.h" + +namespace FlexFlow { + +OpenDataflowGraphData get_graph_data(OpenDataflowGraphView const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_incoming_edge.h b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_incoming_edge.h new file mode 100644 index 0000000000..84e0f57e3d --- /dev/null +++ b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_incoming_edge.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_GET_INCOMING_EDGE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_GET_INCOMING_EDGE_H + +#include "utils/graph/open_dataflow_graph/open_dataflow_graph_view.h" + +namespace FlexFlow { + +OpenDataflowEdge get_incoming_edge(OpenDataflowGraphView const &, + DataflowInput const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/algorithms.h b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_incoming_edges.h similarity index 53% rename from lib/utils/include/utils/graph/open_dataflow_graph/algorithms.h rename to lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_incoming_edges.h index 9ba22394b2..22d66a0c0f 100644 --- a/lib/utils/include/utils/graph/open_dataflow_graph/algorithms.h +++ b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_incoming_edges.h @@ -1,23 +1,17 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_H +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_GET_INCOMING_EDGES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_GET_INCOMING_EDGES_H #include "utils/graph/open_dataflow_graph/open_dataflow_graph_view.h" -#include "utils/graph/open_dataflow_graph/open_dataflow_value.dtg.h" namespace FlexFlow { -std::unordered_set get_edges(OpenDataflowGraphView const &); -std::unordered_set - get_inputs(OpenDataflowGraphView const &); -std::vector get_inputs(OpenDataflowGraphView const &, - Node const &); +std::unordered_set + get_incoming_edges(OpenDataflowGraphView const &); std::vector get_incoming_edges(OpenDataflowGraphView const &, Node const &); std::unordered_map> get_incoming_edges(OpenDataflowGraphView const &, std::unordered_set const &); -std::unordered_set - get_open_dataflow_values(OpenDataflowGraphView const &); } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_inputs.h b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_inputs.h new file mode 100644 index 0000000000..ae596010f8 --- /dev/null +++ b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_inputs.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_GET_INPUTS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_GET_INPUTS_H + +#include "utils/graph/open_dataflow_graph/open_dataflow_graph_view.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_value.dtg.h" + +namespace FlexFlow { + +std::vector get_inputs(OpenDataflowGraphView const &, + Node const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_graph_inputs.h b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_graph_inputs.h new file mode 100644 index 0000000000..98231c8f8c --- /dev/null +++ b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_graph_inputs.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_GET_OPEN_DATAFLOW_GRAPH_INPUTS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_GET_OPEN_DATAFLOW_GRAPH_INPUTS_H + +#include "utils/graph/open_dataflow_graph/open_dataflow_graph_view.h" + +namespace FlexFlow { + +std::unordered_set + get_open_dataflow_graph_inputs(OpenDataflowGraphView const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_value_uses.h b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_value_uses.h new file mode 100644 index 0000000000..bd7749a172 --- /dev/null +++ b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_value_uses.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_GET_OPEN_DATAFLOW_VALUE_USES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_GET_OPEN_DATAFLOW_VALUE_USES_H + +#include "utils/graph/open_dataflow_graph/open_dataflow_graph_view.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_value.dtg.h" + +namespace FlexFlow { + +std::unordered_set + get_open_dataflow_value_uses(OpenDataflowGraphView const &view, + OpenDataflowValue const &value); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_values.h b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_values.h new file mode 100644 index 0000000000..5d8f58540e --- /dev/null +++ b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_values.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_GET_OPEN_DATAFLOW_VALUES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_GET_OPEN_DATAFLOW_VALUES_H + +#include "utils/graph/open_dataflow_graph/open_dataflow_graph_view.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_value.dtg.h" + +namespace FlexFlow { + +std::unordered_set + get_open_dataflow_values(OpenDataflowGraphView const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_source_nodes.h b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_source_nodes.h new file mode 100644 index 0000000000..a89b4e1bc1 --- /dev/null +++ b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_source_nodes.h @@ -0,0 +1,12 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_GET_SOURCE_NODES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_GET_SOURCE_NODES_H + +#include "utils/graph/open_dataflow_graph/open_dataflow_graph_view.h" + +namespace FlexFlow { + +std::unordered_set get_source_nodes(OpenDataflowGraphView const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_subgraph_incoming_edges.h b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_subgraph_incoming_edges.h new file mode 100644 index 0000000000..0df5f8458c --- /dev/null +++ b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_subgraph_incoming_edges.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_GET_SUBGRAPH_INCOMING_EDGES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_GET_SUBGRAPH_INCOMING_EDGES_H + +#include "utils/graph/open_dataflow_graph/open_dataflow_graph_view.h" + +namespace FlexFlow { + +std::unordered_set + get_subgraph_incoming_edges(OpenDataflowGraphView const &, + std::unordered_set const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_unused_open_dataflow_graph_inputs.h b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_unused_open_dataflow_graph_inputs.h new file mode 100644 index 0000000000..2325dcfbda --- /dev/null +++ b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_unused_open_dataflow_graph_inputs.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_GET_UNUSED_OPEN_DATAFLOW_GRAPH_INPUTS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_GET_UNUSED_OPEN_DATAFLOW_GRAPH_INPUTS_H + +#include "utils/graph/open_dataflow_graph/open_dataflow_graph_view.h" + +namespace FlexFlow { + +std::unordered_set + get_unused_open_dataflow_graph_inputs(OpenDataflowGraphView const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/is_isomorphic_under.h b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/is_isomorphic_under.h new file mode 100644 index 0000000000..9ee5ac0790 --- /dev/null +++ b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/is_isomorphic_under.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_IS_ISOMORPHIC_UNDER_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_IS_ISOMORPHIC_UNDER_H + +#include "utils/graph/open_dataflow_graph/algorithms/open_dataflow_graph_isomorphism.dtg.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_graph_view.h" + +namespace FlexFlow { + +bool is_isomorphic_under(OpenDataflowGraphView const &, + OpenDataflowGraphView const &, + OpenDataflowGraphIsomorphism const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/new_dataflow_graph_input.struct.toml b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/new_dataflow_graph_input.struct.toml new file mode 100644 index 0000000000..76b062e211 --- /dev/null +++ b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/new_dataflow_graph_input.struct.toml @@ -0,0 +1,16 @@ +namespace = "FlexFlow" +name = "NewDataflowGraphInput" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/open_dataflow_graph/dataflow_graph_input.dtg.h", +] + +[[fields]] +name = "raw_input" +type = "::FlexFlow::DataflowGraphInput" diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/open_dataflow_graph_data.struct.toml b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/open_dataflow_graph_data.struct.toml new file mode 100644 index 0000000000..467ca73b3f --- /dev/null +++ b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/open_dataflow_graph_data.struct.toml @@ -0,0 +1,36 @@ +namespace = "FlexFlow" +name = "OpenDataflowGraphData" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "utils/graph/node/node.dtg.h", + "utils/graph/open_dataflow_graph/open_dataflow_edge.dtg.h", + "utils/graph/open_dataflow_graph/dataflow_graph_input.dtg.h", + "utils/graph/dataflow_graph/dataflow_output.dtg.h", + "", +] + +src_includes = [ + "utils/hash/unordered_set.h", + "utils/fmt/unordered_set.h", +] + +[[fields]] +name = "nodes" +type = "std::unordered_set<::FlexFlow::Node>" + +[[fields]] +name = "edges" +type = "std::unordered_set<::FlexFlow::OpenDataflowEdge>" + +[[fields]] +name = "inputs" +type = "std::unordered_set<::FlexFlow::DataflowGraphInput>" + +[[fields]] +name = "outputs" +type = "std::unordered_set<::FlexFlow::DataflowOutput>" diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/open_dataflow_graph_isomorphism.struct.toml b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/open_dataflow_graph_isomorphism.struct.toml new file mode 100644 index 0000000000..bafe3c7117 --- /dev/null +++ b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/open_dataflow_graph_isomorphism.struct.toml @@ -0,0 +1,21 @@ +namespace = "FlexFlow" +name = "OpenDataflowGraphIsomorphism" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "utils/bidict/bidict.h", + "utils/graph/node/node.dtg.h", + "utils/graph/open_dataflow_graph/dataflow_graph_input.dtg.h", +] + +[[fields]] +name = "node_mapping" +type = "::FlexFlow::bidict<::FlexFlow::Node, ::FlexFlow::Node>" + +[[fields]] +name = "input_mapping" +type = "::FlexFlow::bidict<::FlexFlow::DataflowGraphInput, ::FlexFlow::DataflowGraphInput>" diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/open_dataflow_graphs_are_isomorphic.h b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/open_dataflow_graphs_are_isomorphic.h new file mode 100644 index 0000000000..6e27e55802 --- /dev/null +++ b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/open_dataflow_graphs_are_isomorphic.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_OPEN_DATAFLOW_GRAPHS_ARE_ISOMORPHIC_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_OPEN_DATAFLOW_GRAPHS_ARE_ISOMORPHIC_H + +#include "utils/graph/open_dataflow_graph/open_dataflow_graph_view.h" + +namespace FlexFlow { + +bool open_dataflow_graphs_are_isomorphic(OpenDataflowGraphView const &, + OpenDataflowGraphView const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/permute_input_ids.h b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/permute_input_ids.h new file mode 100644 index 0000000000..36add91574 --- /dev/null +++ b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/permute_input_ids.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_PERMUTE_INPUT_IDS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_PERMUTE_INPUT_IDS_H + +#include "utils/graph/open_dataflow_graph/algorithms/new_dataflow_graph_input.dtg.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_graph_view.h" + +namespace FlexFlow { + +OpenDataflowGraphView permute_input_ids( + OpenDataflowGraphView const &, + bidict const &input_mapping); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/permute_node_ids.h b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/permute_node_ids.h new file mode 100644 index 0000000000..64293383c6 --- /dev/null +++ b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/permute_node_ids.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_PERMUTE_NODE_IDS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_PERMUTE_NODE_IDS_H + +#include "utils/graph/node/algorithms/new_node.dtg.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_graph_view.h" + +namespace FlexFlow { + +OpenDataflowGraphView + permute_node_ids(OpenDataflowGraphView const &, + bidict const &new_node_tofrom_old_node); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/dataflow_input_edge_query.h b/lib/utils/include/utils/graph/open_dataflow_graph/dataflow_input_edge_query.h index 1189757c0e..78099fec57 100644 --- a/lib/utils/include/utils/graph/open_dataflow_graph/dataflow_input_edge_query.h +++ b/lib/utils/include/utils/graph/open_dataflow_graph/dataflow_input_edge_query.h @@ -10,6 +10,12 @@ DataflowInputEdgeQuery dataflow_input_edge_query_all(); DataflowInputEdgeQuery dataflow_input_edge_query_none(); bool dataflow_input_edge_query_includes(DataflowInputEdgeQuery const &, DataflowInputEdge const &); +DataflowInputEdgeQuery + dataflow_input_edge_query_for_edge(DataflowInputEdge const &); +DataflowInputEdgeQuery + dataflow_input_edge_query_all_outgoing_from(DataflowGraphInput const &); +DataflowInputEdgeQuery + dataflow_input_edge_query_all_incoming_to(DataflowInput const &); } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_edge.h b/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_edge.h index 3289ea48ae..09499f8e5f 100644 --- a/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_edge.h +++ b/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_edge.h @@ -8,7 +8,8 @@ namespace FlexFlow { Node get_open_dataflow_edge_dst_node(OpenDataflowEdge const &); int get_open_dataflow_edge_dst_idx(OpenDataflowEdge const &); -OpenDataflowValue get_open_dataflow_edge_source(OpenDataflowEdge const &); +DataflowInput get_open_dataflow_edge_dst(OpenDataflowEdge const &); +OpenDataflowValue get_open_dataflow_edge_src(OpenDataflowEdge const &); OpenDataflowEdge open_dataflow_edge_from_src_and_dst(OpenDataflowValue const &src, DataflowInput const &dst); diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_edge_query.h b/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_edge_query.h index 46630a2625..ae6e30549b 100644 --- a/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_edge_query.h +++ b/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_edge_query.h @@ -3,6 +3,7 @@ #include "utils/graph/open_dataflow_graph/open_dataflow_edge.dtg.h" #include "utils/graph/open_dataflow_graph/open_dataflow_edge_query.dtg.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_value.dtg.h" namespace FlexFlow { @@ -10,6 +11,13 @@ OpenDataflowEdgeQuery open_dataflow_edge_query_all(); OpenDataflowEdgeQuery open_dataflow_edge_query_none(); bool open_dataflow_edge_query_includes(OpenDataflowEdgeQuery const &q, OpenDataflowEdge const &); +OpenDataflowEdgeQuery + open_dataflow_edge_query_all_outgoing_from(OpenDataflowValue const &); +OpenDataflowEdgeQuery + open_dataflow_edge_query_all_incoming_to(DataflowInput const &); +std::unordered_set apply_open_dataflow_edge_query( + OpenDataflowEdgeQuery const &, + std::unordered_set const &); } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_value.h b/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_value.h new file mode 100644 index 0000000000..d106205a07 --- /dev/null +++ b/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_value.h @@ -0,0 +1,16 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_OPEN_DATAFLOW_VALUE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_OPEN_DATAFLOW_VALUE_H + +#include "utils/graph/open_dataflow_graph/open_dataflow_value.dtg.h" +#include + +namespace FlexFlow { + +std::optional + try_get_dataflow_output(OpenDataflowValue const &); +std::optional + try_get_dataflow_graph_input(OpenDataflowValue const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/required_core.h b/lib/utils/include/utils/required_core.h index 76f03549a4..7a7abcd2c4 100644 --- a/lib/utils/include/utils/required_core.h +++ b/lib/utils/include/utils/required_core.h @@ -1,10 +1,10 @@ #ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_REQUIRED_CORE_H #define _FLEXFLOW_UTILS_INCLUDE_UTILS_REQUIRED_CORE_H -#include "fmt.decl.h" #include "hash-utils.h" #include "test_types.h" #include "type_traits_core.h" +#include #include namespace FlexFlow { @@ -191,7 +191,9 @@ template using req = required; template -struct delegate_ostream_operator> : std::true_type {}; +std::ostream &operator<<(std::ostream &s, required const &t) { + return (s << fmt::to_string(t)); +} template struct remove_req { diff --git a/lib/utils/include/utils/tuple.h b/lib/utils/include/utils/tuple.h index 71c369df6a..afc16d4c4b 100644 --- a/lib/utils/include/utils/tuple.h +++ b/lib/utils/include/utils/tuple.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_UTILS_TUPLE_H #define _FLEXFLOW_UTILS_TUPLE_H -#include "utils/exception.decl.h" +#include "utils/exception.h" #include "utils/type_traits_core.h" #include #include diff --git a/lib/utils/src/utils/bidict/algorithms/bidict_from_enumerating.cc b/lib/utils/src/utils/bidict/algorithms/bidict_from_enumerating.cc new file mode 100644 index 0000000000..350f08600c --- /dev/null +++ b/lib/utils/src/utils/bidict/algorithms/bidict_from_enumerating.cc @@ -0,0 +1 @@ +#include "utils/bidict/algorithms/bidict_from_enumerating.h" diff --git a/lib/utils/src/utils/bidict/algorithms/bidict_from_keys_and_values.cc b/lib/utils/src/utils/bidict/algorithms/bidict_from_keys_and_values.cc new file mode 100644 index 0000000000..34562f40c1 --- /dev/null +++ b/lib/utils/src/utils/bidict/algorithms/bidict_from_keys_and_values.cc @@ -0,0 +1 @@ +#include "utils/bidict/algorithms/bidict_from_keys_and_values.h" diff --git a/lib/utils/src/utils/bidict/algorithms/bidict_from_pairs.cc b/lib/utils/src/utils/bidict/algorithms/bidict_from_pairs.cc new file mode 100644 index 0000000000..c8a27b8143 --- /dev/null +++ b/lib/utils/src/utils/bidict/algorithms/bidict_from_pairs.cc @@ -0,0 +1 @@ +#include "utils/bidict/algorithms/bidict_from_pairs.h" diff --git a/lib/utils/src/utils/bidict/algorithms/left_entries.cc b/lib/utils/src/utils/bidict/algorithms/left_entries.cc new file mode 100644 index 0000000000..a2c19de124 --- /dev/null +++ b/lib/utils/src/utils/bidict/algorithms/left_entries.cc @@ -0,0 +1 @@ +#include "utils/bidict/algorithms/left_entries.h" diff --git a/lib/utils/src/utils/bidict/algorithms/merge_bidicts.cc b/lib/utils/src/utils/bidict/algorithms/merge_bidicts.cc new file mode 100644 index 0000000000..f70be2355f --- /dev/null +++ b/lib/utils/src/utils/bidict/algorithms/merge_bidicts.cc @@ -0,0 +1 @@ +#include "utils/bidict/algorithms/merge_bidicts.h" diff --git a/lib/utils/src/utils/bidict/algorithms/right_entries.cc b/lib/utils/src/utils/bidict/algorithms/right_entries.cc new file mode 100644 index 0000000000..2f517a0af6 --- /dev/null +++ b/lib/utils/src/utils/bidict/algorithms/right_entries.cc @@ -0,0 +1 @@ +#include "utils/bidict/algorithms/right_entries.h" diff --git a/lib/utils/src/utils/containers/filtrans.cc b/lib/utils/src/utils/containers/filtrans.cc new file mode 100644 index 0000000000..a57a743ef0 --- /dev/null +++ b/lib/utils/src/utils/containers/filtrans.cc @@ -0,0 +1 @@ +#include "utils/containers/filtrans.h" diff --git a/lib/utils/src/utils/containers/get_all_permutations.cc b/lib/utils/src/utils/containers/get_all_permutations.cc new file mode 100644 index 0000000000..0fa4e16f08 --- /dev/null +++ b/lib/utils/src/utils/containers/get_all_permutations.cc @@ -0,0 +1 @@ +#include "utils/containers/get_all_permutations.h" diff --git a/lib/utils/src/utils/containers/merge_maps.cc b/lib/utils/src/utils/containers/merge_maps.cc new file mode 100644 index 0000000000..a36217fbeb --- /dev/null +++ b/lib/utils/src/utils/containers/merge_maps.cc @@ -0,0 +1 @@ +#include "utils/containers/merge_maps.h" diff --git a/lib/utils/src/utils/exception.cc b/lib/utils/src/utils/exception.cc new file mode 100644 index 0000000000..9bbf780fd8 --- /dev/null +++ b/lib/utils/src/utils/exception.cc @@ -0,0 +1 @@ +#include "utils/exception.h" diff --git a/lib/utils/src/utils/graph/dataflow_graph/algorithms.cc b/lib/utils/src/utils/graph/dataflow_graph/algorithms.cc index 64af07636a..f0e52d6fc2 100644 --- a/lib/utils/src/utils/graph/dataflow_graph/algorithms.cc +++ b/lib/utils/src/utils/graph/dataflow_graph/algorithms.cc @@ -1,6 +1,7 @@ #include "utils/graph/dataflow_graph/algorithms.h" #include "utils/containers/sorted_by.h" #include "utils/containers/transform.h" +#include "utils/graph/dataflow_graph/algorithms/get_incoming_edges.h" #include "utils/graph/dataflow_graph/dataflow_edge_query.h" #include "utils/graph/dataflow_graph/dataflow_output_query.h" @@ -10,23 +11,16 @@ std::unordered_set get_edges(DataflowGraphView const &g) { return g.query_edges(dataflow_edge_query_all()); } -std::vector get_incoming_edges(DataflowGraphView const &g, +std::vector get_input_values(DataflowGraphView const &g, Node const &n) { - return sorted_by(g.query_edges(DataflowEdgeQuery{ - query_set::matchall(), - query_set::matchall(), - {n}, - query_set::matchall(), - }), - [](DataflowEdge const &l, DataflowEdge const &r) { - return l.dst.idx < r.dst.idx; - }); + return transform(get_incoming_edges(g, n), + [](DataflowEdge const &e) { return e.src; }); } -std::vector get_inputs(DataflowGraphView const &g, - Node const &n) { +std::vector get_dataflow_inputs(DataflowGraphView const &g, + Node const &n) { return transform(get_incoming_edges(g, n), - [](DataflowEdge const &e) { return e.src; }); + [](DataflowEdge const &e) { return e.dst; }); } std::vector get_outputs(DataflowGraphView const &g, diff --git a/lib/utils/src/utils/graph/dataflow_graph/algorithms/as_dot.cc b/lib/utils/src/utils/graph/dataflow_graph/algorithms/as_dot.cc new file mode 100644 index 0000000000..47c30ce998 --- /dev/null +++ b/lib/utils/src/utils/graph/dataflow_graph/algorithms/as_dot.cc @@ -0,0 +1,64 @@ +#include "utils/graph/dataflow_graph/algorithms/as_dot.h" +#include "utils/dot_file.h" +#include "utils/graph/dataflow_graph/algorithms.h" +#include "utils/graph/node/algorithms.h" +#include "utils/record_formatter.h" + +namespace FlexFlow { + +// WARN(@lockshaw): doing this all with string ids is ugly and error prone, +// as it requires duplicating the stringification logic across functions. +// +// Fixing this is tracked in issue +std::string as_dot(DataflowGraphView const &g) { + std::ostringstream oss; + DotFile dot = DotFile{oss}; + + std::function get_node_label = + [](Node const &n) -> std::string { + return fmt::format("n{}", n.raw_uid); + }; + as_dot(dot, g, get_node_label); + + dot.close(); + return oss.str(); +} + +void as_dot(DotFile &dot, + DataflowGraphView const &g, + std::function const &get_node_label) { + auto get_node_name = [](Node n) { return fmt::format("n{}", n.raw_uid); }; + + auto get_input_field = [](int idx) { return fmt::format("i{}", idx); }; + + auto get_output_field = [](int idx) { return fmt::format("o{}", idx); }; + + for (Node const &n : get_nodes(g)) { + std::vector n_inputs = get_dataflow_inputs(g, n); + std::vector n_outputs = get_outputs(g, n); + + RecordFormatter inputs_record; + for (DataflowInput const &i : n_inputs) { + inputs_record << fmt::format("<{}>{}", get_input_field(i.idx), i.idx); + } + + RecordFormatter outputs_record; + for (DataflowOutput const &o : n_outputs) { + outputs_record << fmt::format("<{}>{}", get_output_field(o.idx), o.idx); + } + + RecordFormatter rec; + rec << inputs_record << get_node_label(n) << outputs_record; + + dot.add_record_node(get_node_name(n), rec); + } + + for (DataflowEdge const &e : get_edges(g)) { + dot.add_edge(get_node_name(e.src.node), + get_node_name(e.dst.node), + get_output_field(e.src.idx), + get_input_field(e.dst.idx)); + } +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/dataflow_graph/algorithms/dataflow_graphs_are_isomorphic.cc b/lib/utils/src/utils/graph/dataflow_graph/algorithms/dataflow_graphs_are_isomorphic.cc new file mode 100644 index 0000000000..ac7f9967be --- /dev/null +++ b/lib/utils/src/utils/graph/dataflow_graph/algorithms/dataflow_graphs_are_isomorphic.cc @@ -0,0 +1,11 @@ +#include "utils/graph/dataflow_graph/algorithms/dataflow_graphs_are_isomorphic.h" +#include "utils/graph/dataflow_graph/algorithms/find_isomorphism.h" + +namespace FlexFlow { + +bool dataflow_graphs_are_isomorphic(DataflowGraphView const &src, + DataflowGraphView const &dst) { + return find_isomorphism(src, dst).has_value(); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/dataflow_graph/algorithms/find_isomorphism.cc b/lib/utils/src/utils/graph/dataflow_graph/algorithms/find_isomorphism.cc new file mode 100644 index 0000000000..d06a64597e --- /dev/null +++ b/lib/utils/src/utils/graph/dataflow_graph/algorithms/find_isomorphism.cc @@ -0,0 +1,20 @@ +#include "utils/graph/dataflow_graph/algorithms/find_isomorphism.h" +#include "utils/containers/get_first.h" +#include "utils/graph/dataflow_graph/algorithms/find_isomorphisms.h" + +namespace FlexFlow { + +std::optional + find_isomorphism(DataflowGraphView const &src, + DataflowGraphView const &dst) { + std::unordered_set all_isomorphisms = + find_isomorphisms(src, dst); + + if (all_isomorphisms.empty()) { + return std::nullopt; + } else { + return get_first(all_isomorphisms); + } +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/dataflow_graph/algorithms/find_isomorphisms.cc b/lib/utils/src/utils/graph/dataflow_graph/algorithms/find_isomorphisms.cc new file mode 100644 index 0000000000..0e0210e5a2 --- /dev/null +++ b/lib/utils/src/utils/graph/dataflow_graph/algorithms/find_isomorphisms.cc @@ -0,0 +1,22 @@ +#include "utils/graph/dataflow_graph/algorithms/find_isomorphisms.h" +#include "utils/containers/transform.h" +#include "utils/graph/dataflow_graph/algorithms/view_as_open_dataflow_graph.h" +#include "utils/graph/open_dataflow_graph/algorithms/find_isomorphisms.h" + +namespace FlexFlow { + +std::unordered_set + find_isomorphisms(DataflowGraphView const &src, + DataflowGraphView const &dst) { + std::unordered_set open_isomorphisms = + find_isomorphisms(view_as_open_dataflow_graph(src), + view_as_open_dataflow_graph(dst)); + + return transform(open_isomorphisms, + [](OpenDataflowGraphIsomorphism const &open) { + assert(open.input_mapping.empty()); + return DataflowGraphIsomorphism{open.node_mapping}; + }); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/dataflow_graph/algorithms/get_incoming_edges.cc b/lib/utils/src/utils/graph/dataflow_graph/algorithms/get_incoming_edges.cc new file mode 100644 index 0000000000..9500836db1 --- /dev/null +++ b/lib/utils/src/utils/graph/dataflow_graph/algorithms/get_incoming_edges.cc @@ -0,0 +1,31 @@ +#include "utils/graph/dataflow_graph/algorithms/get_incoming_edges.h" +#include "utils/containers/sorted_by.h" + +namespace FlexFlow { + +std::vector get_incoming_edges(DataflowGraphView const &g, + Node const &n) { + return sorted_by(g.query_edges(DataflowEdgeQuery{ + query_set::matchall(), + query_set::matchall(), + {n}, + query_set::matchall(), + }), + [](DataflowEdge const &l, DataflowEdge const &r) { + return l.dst.idx < r.dst.idx; + }); +} + +std::unordered_set + get_incoming_edges(DataflowGraphView const &g, + std::unordered_set const &ns) { + DataflowEdgeQuery query = DataflowEdgeQuery{ + query_set::matchall(), + query_set::matchall(), + query_set{ns}, + query_set::matchall(), + }; + return g.query_edges(query); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/dataflow_graph/algorithms/get_subgraph_outgoing_edges.cc b/lib/utils/src/utils/graph/dataflow_graph/algorithms/get_subgraph_outgoing_edges.cc new file mode 100644 index 0000000000..c442a26dab --- /dev/null +++ b/lib/utils/src/utils/graph/dataflow_graph/algorithms/get_subgraph_outgoing_edges.cc @@ -0,0 +1,24 @@ +#include "utils/graph/dataflow_graph/algorithms/get_subgraph_outgoing_edges.h" +#include "utils/containers/set_minus.h" +#include "utils/graph/node/algorithms.h" + +namespace FlexFlow { + +std::unordered_set + get_subgraph_outgoing_edges(DataflowGraphView const &g, + std::unordered_set const &ns) { + + std::unordered_set all_nodes = get_nodes(g); + query_set dst_query = query_set{set_minus(all_nodes, ns)}; + + DataflowEdgeQuery query = DataflowEdgeQuery{ + query_set{ns}, + query_set::matchall(), + dst_query, + query_set::matchall(), + }; + + return g.query_edges(query); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/dataflow_graph/algorithms/view_as_open_dataflow_graph.cc b/lib/utils/src/utils/graph/dataflow_graph/algorithms/view_as_open_dataflow_graph.cc new file mode 100644 index 0000000000..0fd0b85b71 --- /dev/null +++ b/lib/utils/src/utils/graph/dataflow_graph/algorithms/view_as_open_dataflow_graph.cc @@ -0,0 +1,41 @@ +#include "utils/graph/dataflow_graph/algorithms/view_as_open_dataflow_graph.h" +#include "utils/containers/transform.h" + +namespace FlexFlow { + +ViewDataflowGraphAsOpen::ViewDataflowGraphAsOpen(DataflowGraphView const &g) + : g(g) {} + +std::unordered_set + ViewDataflowGraphAsOpen::query_nodes(NodeQuery const &q) const { + return this->g.query_nodes(q); +} + +std::unordered_set + ViewDataflowGraphAsOpen::query_edges(OpenDataflowEdgeQuery const &q) const { + std::unordered_set closed_edges = + this->g.query_edges(q.standard_edge_query); + + return transform(closed_edges, + [](DataflowEdge const &e) { return OpenDataflowEdge{e}; }); +} + +std::unordered_set + ViewDataflowGraphAsOpen::query_outputs(DataflowOutputQuery const &q) const { + return this->g.query_outputs(q); +} + +std::unordered_set + ViewDataflowGraphAsOpen::get_inputs() const { + return {}; +} + +ViewDataflowGraphAsOpen *ViewDataflowGraphAsOpen::clone() const { + return new ViewDataflowGraphAsOpen{this->g}; +} + +OpenDataflowGraphView view_as_open_dataflow_graph(DataflowGraphView const &g) { + return OpenDataflowGraphView::create(g); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/dataflow_graph/algorithms/view_as_open_dataflow_graph.h b/lib/utils/src/utils/graph/dataflow_graph/algorithms/view_as_open_dataflow_graph.h new file mode 100644 index 0000000000..bec9d0e019 --- /dev/null +++ b/lib/utils/src/utils/graph/dataflow_graph/algorithms/view_as_open_dataflow_graph.h @@ -0,0 +1,32 @@ +#ifndef _FLEXFLOW_LIB_UTILS_SRC_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_VIEW_AS_OPEN_DATAFLOW_GRAPH_H +#define _FLEXFLOW_LIB_UTILS_SRC_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_VIEW_AS_OPEN_DATAFLOW_GRAPH_H + +#include "utils/graph/open_dataflow_graph/open_dataflow_graph_view.h" + +namespace FlexFlow { + +struct ViewDataflowGraphAsOpen final : public IOpenDataflowGraphView { +public: + ViewDataflowGraphAsOpen() = delete; + ViewDataflowGraphAsOpen(DataflowGraphView const &); + + std::unordered_set query_nodes(NodeQuery const &) const override; + std::unordered_set + query_edges(OpenDataflowEdgeQuery const &) const override; + std::unordered_set + query_outputs(DataflowOutputQuery const &) const override; + std::unordered_set get_inputs() const override; + + ViewDataflowGraphAsOpen *clone() const override; + + ~ViewDataflowGraphAsOpen() = default; + +private: + DataflowGraphView g; +}; + +OpenDataflowGraphView view_as_open_dataflow_graph(DataflowGraphView const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/src/utils/graph/dataflow_graph/dataflow_edge_query.cc b/lib/utils/src/utils/graph/dataflow_graph/dataflow_edge_query.cc index e30dc41c1f..2196f7a028 100644 --- a/lib/utils/src/utils/graph/dataflow_graph/dataflow_edge_query.cc +++ b/lib/utils/src/utils/graph/dataflow_graph/dataflow_edge_query.cc @@ -27,4 +27,33 @@ bool dataflow_edge_query_includes_dataflow_edge(DataflowEdgeQuery const &q, includes(q.dst_idxs, e.dst.idx); } +DataflowEdgeQuery dataflow_edge_query_for_edge(DataflowEdge const &e) { + return DataflowEdgeQuery{ + query_set{e.src.node}, + query_set{e.src.idx}, + query_set{e.dst.node}, + query_set{e.dst.idx}, + }; +} + +DataflowEdgeQuery + dataflow_edge_query_all_outgoing_from(DataflowOutput const &src) { + return DataflowEdgeQuery{ + query_set{src.node}, + query_set{src.idx}, + query_set::matchall(), + query_set::matchall(), + }; +} + +DataflowEdgeQuery + dataflow_edge_query_all_incoming_to(DataflowInput const &dst) { + return DataflowEdgeQuery{ + query_set::matchall(), + query_set::matchall(), + query_set{dst.node}, + query_set{dst.idx}, + }; +} + } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/dataflow_graph/dataflow_output_query.cc b/lib/utils/src/utils/graph/dataflow_graph/dataflow_output_query.cc index b8d89a250d..64df4c77f2 100644 --- a/lib/utils/src/utils/graph/dataflow_graph/dataflow_output_query.cc +++ b/lib/utils/src/utils/graph/dataflow_graph/dataflow_output_query.cc @@ -21,4 +21,19 @@ bool dataflow_output_query_includes_dataflow_output( return includes(q.nodes, o.node) && includes(q.output_idxs, o.idx); } +DataflowOutputQuery dataflow_output_query_for_output(DataflowOutput const &o) { + return DataflowOutputQuery{ + query_set{o.node}, + query_set{o.idx}, + }; +} + +std::unordered_set + apply_dataflow_output_query(DataflowOutputQuery const &q, + std::unordered_set const &os) { + return filter(os, [&](DataflowOutput const &o) { + return dataflow_output_query_includes_dataflow_output(q, o); + }); +} + } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/labelled_dataflow_graph/algorithms/is_isomorphic_under.cc b/lib/utils/src/utils/graph/labelled_dataflow_graph/algorithms/is_isomorphic_under.cc new file mode 100644 index 0000000000..88ec6d141a --- /dev/null +++ b/lib/utils/src/utils/graph/labelled_dataflow_graph/algorithms/is_isomorphic_under.cc @@ -0,0 +1 @@ +#include "utils/graph/labelled_dataflow_graph/algorithms/is_isomorphic_under.h" diff --git a/lib/utils/src/utils/graph/labelled_dataflow_graph/algorithms/labelled_dataflow_graphs_are_isomorphic.cc b/lib/utils/src/utils/graph/labelled_dataflow_graph/algorithms/labelled_dataflow_graphs_are_isomorphic.cc new file mode 100644 index 0000000000..9fa68e58b2 --- /dev/null +++ b/lib/utils/src/utils/graph/labelled_dataflow_graph/algorithms/labelled_dataflow_graphs_are_isomorphic.cc @@ -0,0 +1 @@ +#include "utils/graph/labelled_dataflow_graph/algorithms/labelled_dataflow_graphs_are_isomorphic.h" diff --git a/lib/utils/src/utils/graph/labelled_open_dataflow_graph/algorithms/as_dot.cc b/lib/utils/src/utils/graph/labelled_open_dataflow_graph/algorithms/as_dot.cc new file mode 100644 index 0000000000..78dbed5262 --- /dev/null +++ b/lib/utils/src/utils/graph/labelled_open_dataflow_graph/algorithms/as_dot.cc @@ -0,0 +1 @@ +#include "utils/graph/labelled_open_dataflow_graph/algorithms/as_dot.h" diff --git a/lib/utils/src/utils/graph/labelled_open_dataflow_graph/algorithms/find_isomorphism.cc b/lib/utils/src/utils/graph/labelled_open_dataflow_graph/algorithms/find_isomorphism.cc new file mode 100644 index 0000000000..c53cd4cd15 --- /dev/null +++ b/lib/utils/src/utils/graph/labelled_open_dataflow_graph/algorithms/find_isomorphism.cc @@ -0,0 +1 @@ +#include "utils/graph/labelled_open_dataflow_graph/algorithms/find_isomorphism.h" diff --git a/lib/utils/src/utils/graph/labelled_open_dataflow_graph/algorithms/from_labelled_open_dataflow_graph_data.cc b/lib/utils/src/utils/graph/labelled_open_dataflow_graph/algorithms/from_labelled_open_dataflow_graph_data.cc new file mode 100644 index 0000000000..49d3a663d9 --- /dev/null +++ b/lib/utils/src/utils/graph/labelled_open_dataflow_graph/algorithms/from_labelled_open_dataflow_graph_data.cc @@ -0,0 +1 @@ +#include "utils/graph/labelled_open_dataflow_graph/algorithms/from_labelled_open_dataflow_graph_data.h" diff --git a/lib/utils/src/utils/graph/labelled_open_dataflow_graph/algorithms/get_graph_data.cc b/lib/utils/src/utils/graph/labelled_open_dataflow_graph/algorithms/get_graph_data.cc new file mode 100644 index 0000000000..854f55732f --- /dev/null +++ b/lib/utils/src/utils/graph/labelled_open_dataflow_graph/algorithms/get_graph_data.cc @@ -0,0 +1 @@ +#include "utils/graph/labelled_open_dataflow_graph/algorithms/get_graph_data.h" diff --git a/lib/utils/src/utils/graph/labelled_open_dataflow_graph/algorithms/is_isomorphic_under.cc b/lib/utils/src/utils/graph/labelled_open_dataflow_graph/algorithms/is_isomorphic_under.cc new file mode 100644 index 0000000000..32a6da0bb5 --- /dev/null +++ b/lib/utils/src/utils/graph/labelled_open_dataflow_graph/algorithms/is_isomorphic_under.cc @@ -0,0 +1 @@ +#include "utils/graph/labelled_open_dataflow_graph/algorithms/is_isomorphic_under.h" diff --git a/lib/utils/src/utils/graph/labelled_open_dataflow_graph/algorithms/labelled_open_dataflow_graphs_are_isomorphic.cc b/lib/utils/src/utils/graph/labelled_open_dataflow_graph/algorithms/labelled_open_dataflow_graphs_are_isomorphic.cc new file mode 100644 index 0000000000..1dcbbdc1e6 --- /dev/null +++ b/lib/utils/src/utils/graph/labelled_open_dataflow_graph/algorithms/labelled_open_dataflow_graphs_are_isomorphic.cc @@ -0,0 +1 @@ +#include "utils/graph/labelled_open_dataflow_graph/algorithms/labelled_open_dataflow_graphs_are_isomorphic.h" diff --git a/lib/utils/src/utils/graph/labelled_open_dataflow_graph/algorithms/permute_input_ids.cc b/lib/utils/src/utils/graph/labelled_open_dataflow_graph/algorithms/permute_input_ids.cc new file mode 100644 index 0000000000..2a5fe55809 --- /dev/null +++ b/lib/utils/src/utils/graph/labelled_open_dataflow_graph/algorithms/permute_input_ids.cc @@ -0,0 +1 @@ +#include "utils/graph/labelled_open_dataflow_graph/algorithms/permute_input_ids.h" diff --git a/lib/utils/src/utils/graph/labelled_open_dataflow_graph/algorithms/rewrite_node_labels.cc b/lib/utils/src/utils/graph/labelled_open_dataflow_graph/algorithms/rewrite_node_labels.cc new file mode 100644 index 0000000000..d2252d91e9 --- /dev/null +++ b/lib/utils/src/utils/graph/labelled_open_dataflow_graph/algorithms/rewrite_node_labels.cc @@ -0,0 +1 @@ +#include "utils/graph/labelled_open_dataflow_graph/algorithms/rewrite_node_labels.h" diff --git a/lib/utils/src/utils/graph/labelled_open_dataflow_graph/algorithms/rewrite_value_labels.cc b/lib/utils/src/utils/graph/labelled_open_dataflow_graph/algorithms/rewrite_value_labels.cc new file mode 100644 index 0000000000..655988fb28 --- /dev/null +++ b/lib/utils/src/utils/graph/labelled_open_dataflow_graph/algorithms/rewrite_value_labels.cc @@ -0,0 +1 @@ +#include "utils/graph/labelled_open_dataflow_graph/algorithms/rewrite_value_labels.h" diff --git a/lib/utils/src/utils/graph/node/algorithms/generate_new_node_id_permutation.cc b/lib/utils/src/utils/graph/node/algorithms/generate_new_node_id_permutation.cc new file mode 100644 index 0000000000..256dccd185 --- /dev/null +++ b/lib/utils/src/utils/graph/node/algorithms/generate_new_node_id_permutation.cc @@ -0,0 +1,16 @@ +#include "utils/graph/node/algorithms/generate_new_node_id_permutation.h" +#include "utils/bidict/generate_bidict.h" +#include "utils/graph/node/algorithms.h" +#include "utils/graph/node/node_source.h" + +namespace FlexFlow { + +bidict generate_new_node_id_permutation(GraphView const &g) { + NodeSource node_source; + return generate_bidict( + get_nodes(g), + [&](Node const &) { return NewNode{node_source.new_node()}; }) + .reversed(); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/node/node_query.cc b/lib/utils/src/utils/graph/node/node_query.cc index c74457465c..834086a733 100644 --- a/lib/utils/src/utils/graph/node/node_query.cc +++ b/lib/utils/src/utils/graph/node/node_query.cc @@ -28,4 +28,9 @@ NodeQuery query_union(NodeQuery const &lhs, NodeQuery const &rhs) { NOT_IMPLEMENTED(); } +std::unordered_set apply_node_query(NodeQuery const &query, + std::unordered_set const &ns) { + return apply_query(query.nodes, ns); +} + } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/as_dot.cc b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/as_dot.cc new file mode 100644 index 0000000000..9077ea5f9a --- /dev/null +++ b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/as_dot.cc @@ -0,0 +1,63 @@ +#include "utils/graph/open_dataflow_graph/algorithms/as_dot.h" +#include "utils/dot_file.h" +#include "utils/graph/dataflow_graph/algorithms.h" +#include "utils/graph/dataflow_graph/algorithms/as_dot.h" +#include "utils/graph/node/algorithms.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_incoming_edges.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_graph_inputs.h" + +namespace FlexFlow { + +std::string as_dot(OpenDataflowGraphView const &g) { + std::function get_node_label = [](Node const &n) { + return fmt::format("n{}", n.raw_uid); + }; + + std::function get_input_label = + [](DataflowGraphInput const &i) { return fmt::format("i{}", i.idx); }; + + return as_dot(g, get_node_label, get_input_label); +} + +// WARN(@lockshaw): doing this all with string ids is ugly and error prone, +// as it requires duplicating the stringification logic across functions. +// +// Fixing this is tracked in issue +// https://github.com/flexflow/FlexFlow/issues/1476 +std::string + as_dot(OpenDataflowGraphView const &g, + std::function const &get_node_label, + std::function const + &get_input_label) { + std::ostringstream oss; + DotFile dot = DotFile{oss}; + + as_dot(dot, static_cast(g), get_node_label); + + auto get_node_name = [](Node n) { return fmt::format("n{}", n.raw_uid); }; + + auto get_input_field = [](int idx) { return fmt::format("i{}", idx); }; + + auto get_output_field = [](int idx) { return fmt::format("o{}", idx); }; + + auto get_graph_input_name = [](DataflowGraphInput i) { + return fmt::format("gi{}", i.idx); + }; + + for (DataflowGraphInput const &i : get_open_dataflow_graph_inputs(g)) { + dot.add_node(get_graph_input_name(i), + {{"style", "dashed"}, {"label", get_input_label(i)}}); + } + + for (DataflowInputEdge const &e : get_incoming_edges(g)) { + dot.add_edge(get_graph_input_name(e.src), + get_node_name(e.dst.node), + std::nullopt, + get_input_field(e.dst.idx)); + } + + dot.close(); + return oss.str(); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/find_isomorphism.cc b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/find_isomorphism.cc new file mode 100644 index 0000000000..d622497629 --- /dev/null +++ b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/find_isomorphism.cc @@ -0,0 +1,20 @@ +#include "utils/graph/open_dataflow_graph/algorithms/find_isomorphism.h" +#include "utils/containers/get_first.h" +#include "utils/graph/open_dataflow_graph/algorithms/find_isomorphisms.h" + +namespace FlexFlow { + +std::optional + find_isomorphism(OpenDataflowGraphView const &src, + OpenDataflowGraphView const &dst) { + std::unordered_set all_isomorphisms = + find_isomorphisms(src, dst); + + if (all_isomorphisms.empty()) { + return std::nullopt; + } else { + return get_first(all_isomorphisms); + } +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/find_isomorphisms.cc b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/find_isomorphisms.cc new file mode 100644 index 0000000000..d95a9b9565 --- /dev/null +++ b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/find_isomorphisms.cc @@ -0,0 +1,248 @@ +#include "utils/graph/open_dataflow_graph/algorithms/find_isomorphisms.h" +#include "utils/bidict/algorithms/bidict_from_keys_and_values.h" +#include "utils/bidict/algorithms/left_entries.h" +#include "utils/bidict/algorithms/right_entries.h" +#include "utils/containers/as_vector.h" +#include "utils/containers/get_all_permutations.h" +#include "utils/containers/get_first.h" +#include "utils/containers/is_subseteq_of.h" +#include "utils/containers/keys.h" +#include "utils/containers/values.h" +#include "utils/containers/zip.h" +#include "utils/graph/digraph/algorithms.h" +#include "utils/graph/node/algorithms.h" +#include "utils/graph/node/algorithms/new_node.dtg.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_incoming_edge.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_incoming_edges.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_inputs.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_source_nodes.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_unused_open_dataflow_graph_inputs.h" +#include "utils/graph/open_dataflow_graph/algorithms/is_isomorphic_under.h" +#include "utils/graph/open_dataflow_graph/algorithms/new_dataflow_graph_input.dtg.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_edge.h" +#include + +namespace FlexFlow { + +static std::optional + find_isomorphism_under_sink_node_mapping( + OpenDataflowGraphView const &src_g, + OpenDataflowGraphView const &dst_g, + bidict const &sink_node_mapping, + bidict const + &unused_graph_inputs_mapping) { + { + std::unordered_set already_mapped_src_nodes = + left_entries(sink_node_mapping); + std::unordered_set src_g_sink_nodes = get_sinks(src_g); + assert(already_mapped_src_nodes == src_g_sink_nodes); + } + + { + std::unordered_set already_mapped_dst_nodes = + right_entries(sink_node_mapping); + std::unordered_set dst_g_sink_nodes = get_sinks(dst_g); + assert(already_mapped_dst_nodes == dst_g_sink_nodes); + } + + { + std::unordered_set already_mapped_src_inputs = + right_entries(unused_graph_inputs_mapping); + std::unordered_set src_g_unused_inputs = + get_unused_open_dataflow_graph_inputs(src_g); + assert(already_mapped_src_inputs == src_g_unused_inputs); + } + + { + std::unordered_set already_mapped_dst_inputs = + right_entries(unused_graph_inputs_mapping); + std::unordered_set dst_g_unused_inputs = + get_unused_open_dataflow_graph_inputs(dst_g); + assert(already_mapped_dst_inputs == dst_g_unused_inputs); + } + + std::optional result = + OpenDataflowGraphIsomorphism{ + {}, + unused_graph_inputs_mapping, + }; + + auto fail = [&]() -> void { result = std::nullopt; }; + + auto has_failed = [&]() -> bool { return result == std::nullopt; }; + + std::function unify_nodes; + std::function + unify_edges; + std::function + unify_graph_inputs; + std::function + unify_values; + std::function + unify_outputs; + + unify_outputs = [&](DataflowOutput const &src_output, + DataflowOutput const &dst_output) { + if (has_failed()) { + return; + } + + if (src_output.idx != dst_output.idx) { + result = std::nullopt; + return; + } + + unify_nodes(src_output.node, dst_output.node); + }; + + unify_values = [&](OpenDataflowValue const &src_val, + OpenDataflowValue const &dst_val) { + if (has_failed()) { + return; + } + + if (src_val.index() != dst_val.index()) { + fail(); + return; + } + + if (src_val.has()) { + unify_outputs(src_val.get(), + dst_val.get()); + } else { + unify_graph_inputs(src_val.get(), + dst_val.get()); + } + }; + + unify_graph_inputs = [&](DataflowGraphInput const &src, + DataflowGraphInput const &dst) { + if (has_failed()) { + return; + } + + if (result->input_mapping.contains_l(src) && + result->input_mapping.at_l(src) != dst) { + fail(); + return; + } + if (result->input_mapping.contains_r(dst) && + result->input_mapping.at_r(dst) != src) { + fail(); + return; + } + + result->input_mapping.equate(src, dst); + }; + + unify_edges = [&](OpenDataflowEdge const &src_edge, + OpenDataflowEdge const &dst_edge) { + if (has_failed()) { + return; + } + + assert(get_open_dataflow_edge_dst(src_edge).idx == + get_open_dataflow_edge_dst(dst_edge).idx); + assert( + get_open_dataflow_edge_dst(src_edge).node == + result->node_mapping.at_r(get_open_dataflow_edge_dst(dst_edge).node)); + + unify_values(get_open_dataflow_edge_src(src_edge), + get_open_dataflow_edge_src(dst_edge)); + }; + + unify_nodes = [&](Node const &src_node, Node const &dst_node) { + if (has_failed()) { + return; + } + + if (result->node_mapping.contains(src_node, dst_node)) { + return; + } + + if (result->node_mapping.contains_l(src_node) && + result->node_mapping.at_l(src_node) != dst_node) { + fail(); + return; + } + if (result->node_mapping.contains_r(dst_node) && + result->node_mapping.at_r(dst_node) != src_node) { + fail(); + return; + } + + result->node_mapping.equate(src_node, dst_node); + + std::vector src_incoming_edges = + get_incoming_edges(src_g, src_node); + std::vector dst_incoming_edges = + get_incoming_edges(dst_g, dst_node); + + if (src_incoming_edges.size() != dst_incoming_edges.size()) { + fail(); + return; + } + + for (auto const &[src_edge, dst_edge] : + zip(src_incoming_edges, dst_incoming_edges)) { + unify_edges(src_edge, dst_edge); + } + }; + + for (auto const &[src_node, dst_node] : sink_node_mapping) { + unify_nodes(src_node, dst_node); + } + + return result; +} + +std::unordered_set + find_isomorphisms(OpenDataflowGraphView const &src, + OpenDataflowGraphView const &dst) { + std::unordered_set result; + + std::vector src_sink_nodes = as_vector(get_sinks(src)); + std::unordered_set dst_sink_nodes = get_sinks(dst); + + if (src_sink_nodes.size() != dst_sink_nodes.size()) { + return {}; + } + + std::vector src_unused_graph_inputs = + as_vector(get_unused_open_dataflow_graph_inputs(src)); + std::unordered_set dst_unused_graph_inputs = + get_unused_open_dataflow_graph_inputs(dst); + + if (src_unused_graph_inputs.size() != dst_unused_graph_inputs.size()) { + return {}; + } + + for (std::vector const &dst_sink_nodes : + get_all_permutations(dst_sink_nodes)) { + + bidict sink_node_mapping = + bidict_from_keys_and_values(src_sink_nodes, dst_sink_nodes); + + for (std::vector const &dst_unused_graph_inputs : + get_all_permutations(dst_unused_graph_inputs)) { + + bidict + unused_graph_inputs_mapping = bidict_from_keys_and_values( + src_unused_graph_inputs, dst_unused_graph_inputs); + + std::optional found = + find_isomorphism_under_sink_node_mapping( + src, dst, sink_node_mapping, unused_graph_inputs_mapping); + + if (found.has_value()) { + assert(is_isomorphic_under(src, dst, found.value())); + + result.insert(found.value()); + } + } + } + + return result; +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/from_open_dataflow_graph_data.cc b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/from_open_dataflow_graph_data.cc new file mode 100644 index 0000000000..c4b5befcbc --- /dev/null +++ b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/from_open_dataflow_graph_data.cc @@ -0,0 +1,41 @@ +#include "utils/graph/open_dataflow_graph/algorithms/from_open_dataflow_graph_data.h" +#include "utils/graph/dataflow_graph/dataflow_output_query.h" +#include "utils/graph/node/node_query.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_edge_query.h" + +namespace FlexFlow { + +FromOpenDataflowGraphDataView::FromOpenDataflowGraphDataView( + OpenDataflowGraphData const &data) + : data(data) {} + +std::unordered_set + FromOpenDataflowGraphDataView::query_nodes(NodeQuery const &q) const { + return apply_node_query(q, this->data.nodes); +} + +std::unordered_set FromOpenDataflowGraphDataView::query_edges( + OpenDataflowEdgeQuery const &q) const { + return apply_open_dataflow_edge_query(q, this->data.edges); +} + +std::unordered_set FromOpenDataflowGraphDataView::query_outputs( + DataflowOutputQuery const &q) const { + return apply_dataflow_output_query(q, this->data.outputs); +} + +std::unordered_set + FromOpenDataflowGraphDataView::get_inputs() const { + return this->data.inputs; +} + +FromOpenDataflowGraphDataView *FromOpenDataflowGraphDataView::clone() const { + return new FromOpenDataflowGraphDataView{this->data}; +} + +OpenDataflowGraphView + from_open_dataflow_graph_data(OpenDataflowGraphData const &data) { + return OpenDataflowGraphView::create(data); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/generate_new_input_id_permutation.cc b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/generate_new_input_id_permutation.cc new file mode 100644 index 0000000000..7d9a3e3a0e --- /dev/null +++ b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/generate_new_input_id_permutation.cc @@ -0,0 +1,19 @@ +#include "utils/graph/open_dataflow_graph/algorithms/generate_new_input_id_permutation.h" +#include "utils/bidict/generate_bidict.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_graph_inputs.h" +#include "utils/graph/open_dataflow_graph/dataflow_graph_input_source.h" + +namespace FlexFlow { + +bidict + generate_new_input_id_permutation(OpenDataflowGraphView const &g) { + DataflowGraphInputSource input_source; + return generate_bidict(get_open_dataflow_graph_inputs(g), + [&](DataflowGraphInput const &) { + return NewDataflowGraphInput{ + input_source.new_dataflow_graph_input()}; + }) + .reversed(); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_edges.cc b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_edges.cc new file mode 100644 index 0000000000..610239feff --- /dev/null +++ b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_edges.cc @@ -0,0 +1,10 @@ +#include "utils/graph/open_dataflow_graph/algorithms/get_edges.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_edge_query.h" + +namespace FlexFlow { + +std::unordered_set get_edges(OpenDataflowGraphView const &g) { + return g.query_edges(open_dataflow_edge_query_all()); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_graph_data.cc b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_graph_data.cc new file mode 100644 index 0000000000..3199be92f9 --- /dev/null +++ b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_graph_data.cc @@ -0,0 +1,17 @@ +#include "utils/graph/open_dataflow_graph/algorithms/get_graph_data.h" +#include "utils/graph/dataflow_graph/algorithms.h" +#include "utils/graph/node/algorithms.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_edges.h" + +namespace FlexFlow { + +OpenDataflowGraphData get_graph_data(OpenDataflowGraphView const &g) { + return OpenDataflowGraphData{ + get_nodes(g), + get_edges(g), + g.get_inputs(), + get_all_dataflow_outputs(g), + }; +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_incoming_edge.cc b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_incoming_edge.cc new file mode 100644 index 0000000000..ac1aae1168 --- /dev/null +++ b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_incoming_edge.cc @@ -0,0 +1,15 @@ +#include "utils/graph/open_dataflow_graph/algorithms/get_incoming_edge.h" +#include "utils/containers/get_only.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_edge_query.h" + +namespace FlexFlow { + +OpenDataflowEdge get_incoming_edge(OpenDataflowGraphView const &g, + DataflowInput const &i) { + OpenDataflowEdgeQuery query = open_dataflow_edge_query_all_incoming_to(i); + std::unordered_set query_result = g.query_edges(query); + + return get_only(query_result); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms.cc b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_incoming_edges.cc similarity index 55% rename from lib/utils/src/utils/graph/open_dataflow_graph/algorithms.cc rename to lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_incoming_edges.cc index ff5451d239..cad00c71e1 100644 --- a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms.cc +++ b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_incoming_edges.cc @@ -1,28 +1,23 @@ -#include "utils/graph/open_dataflow_graph/algorithms.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_incoming_edges.h" #include "utils/containers/generate_map.h" -#include "utils/containers/group_by.h" #include "utils/containers/sorted_by.h" #include "utils/containers/transform.h" -#include "utils/containers/unordered_set_of.h" -#include "utils/graph/dataflow_graph/algorithms.h" +#include "utils/graph/dataflow_graph/dataflow_edge_query.h" +#include "utils/graph/open_dataflow_graph/dataflow_input_edge_query.h" #include "utils/graph/open_dataflow_graph/open_dataflow_edge.h" -#include "utils/graph/open_dataflow_graph/open_dataflow_edge_query.h" namespace FlexFlow { -std::unordered_set get_edges(OpenDataflowGraphView const &g) { - return g.query_edges(open_dataflow_edge_query_all()); -} - -std::unordered_set - get_inputs(OpenDataflowGraphView const &g) { - return g.get_inputs(); -} +std::unordered_set + get_incoming_edges(OpenDataflowGraphView const &g) { + std::unordered_set raw_edges = + g.query_edges(OpenDataflowEdgeQuery{ + dataflow_input_edge_query_all(), + dataflow_edge_query_none(), + }); -std::vector get_inputs(OpenDataflowGraphView const &g, - Node const &n) { - return transform(get_incoming_edges(g, n), [](OpenDataflowEdge const &e) { - return get_open_dataflow_edge_source(e); + return transform(raw_edges, [](OpenDataflowEdge const &e) { + return e.get(); }); } @@ -54,14 +49,4 @@ std::unordered_map> [&](Node const &n) { return get_incoming_edges(g, n); }); } -std::unordered_set - get_open_dataflow_values(OpenDataflowGraphView const &g) { - return set_union( - transform( - unordered_set_of(g.get_inputs()), - [](DataflowGraphInput const &gi) { return OpenDataflowValue{gi}; }), - transform(get_all_dataflow_outputs(g), - [](DataflowOutput const &o) { return OpenDataflowValue{o}; })); -} - } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_inputs.cc b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_inputs.cc new file mode 100644 index 0000000000..f4e23e04f4 --- /dev/null +++ b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_inputs.cc @@ -0,0 +1,15 @@ +#include "utils/graph/open_dataflow_graph/algorithms/get_inputs.h" +#include "utils/containers/transform.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_incoming_edges.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_edge.h" + +namespace FlexFlow { + +std::vector get_inputs(OpenDataflowGraphView const &g, + Node const &n) { + return transform(get_incoming_edges(g, n), [](OpenDataflowEdge const &e) { + return get_open_dataflow_edge_src(e); + }); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_graph_inputs.cc b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_graph_inputs.cc new file mode 100644 index 0000000000..78c7677de9 --- /dev/null +++ b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_graph_inputs.cc @@ -0,0 +1,10 @@ +#include "utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_graph_inputs.h" + +namespace FlexFlow { + +std::unordered_set + get_open_dataflow_graph_inputs(OpenDataflowGraphView const &g) { + return g.get_inputs(); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_value_uses.cc b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_value_uses.cc new file mode 100644 index 0000000000..12795b8f7e --- /dev/null +++ b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_value_uses.cc @@ -0,0 +1,17 @@ +#include "utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_value_uses.h" +#include "utils/containers/transform.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_edge.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_edge_query.h" + +namespace FlexFlow { + +std::unordered_set + get_open_dataflow_value_uses(OpenDataflowGraphView const &view, + OpenDataflowValue const &value) { + std::unordered_set edges = + view.query_edges(open_dataflow_edge_query_all_outgoing_from(value)); + + return transform(edges, get_open_dataflow_edge_dst); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_values.cc b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_values.cc new file mode 100644 index 0000000000..0aa1bdb054 --- /dev/null +++ b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_values.cc @@ -0,0 +1,17 @@ +#include "utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_values.h" +#include "utils/containers/transform.h" +#include "utils/graph/dataflow_graph/algorithms.h" + +namespace FlexFlow { + +std::unordered_set + get_open_dataflow_values(OpenDataflowGraphView const &g) { + return set_union( + transform( + unordered_set_of(g.get_inputs()), + [](DataflowGraphInput const &gi) { return OpenDataflowValue{gi}; }), + transform(get_all_dataflow_outputs(g), + [](DataflowOutput const &o) { return OpenDataflowValue{o}; })); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_source_nodes.cc b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_source_nodes.cc new file mode 100644 index 0000000000..14099e1c64 --- /dev/null +++ b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_source_nodes.cc @@ -0,0 +1,16 @@ +#include "utils/graph/open_dataflow_graph/algorithms/get_source_nodes.h" +#include "utils/graph/node/algorithms.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_incoming_edges.h" + +namespace FlexFlow { + +std::unordered_set get_source_nodes(OpenDataflowGraphView const &g) { + auto is_source_node = [&](Node const &n) { + std::vector incoming_edges = get_incoming_edges(g, n); + return incoming_edges.empty(); + }; + + return filter(get_nodes(g), is_source_node); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_subgraph_incoming_edges.cc b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_subgraph_incoming_edges.cc new file mode 100644 index 0000000000..95a8e095fc --- /dev/null +++ b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_subgraph_incoming_edges.cc @@ -0,0 +1,29 @@ +#include "utils/graph/open_dataflow_graph/algorithms/get_subgraph_incoming_edges.h" +#include "utils/containers/set_minus.h" +#include "utils/graph/node/algorithms.h" + +namespace FlexFlow { + +std::unordered_set + get_subgraph_incoming_edges(OpenDataflowGraphView const &g, + std::unordered_set const &ns) { + std::unordered_set nodes_not_in_ns = set_minus(get_nodes(g), ns); + + OpenDataflowEdgeQuery query = OpenDataflowEdgeQuery{ + DataflowInputEdgeQuery{ + query_set::matchall(), + query_set{ns}, + query_set::matchall(), + }, + DataflowEdgeQuery{ + query_set{nodes_not_in_ns}, + query_set::matchall(), + query_set{ns}, + query_set::matchall(), + }, + }; + + return g.query_edges(query); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_subgraph_inputs.cc b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_subgraph_inputs.cc index 8c07f4bfdb..4ade34941c 100644 --- a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_subgraph_inputs.cc +++ b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_subgraph_inputs.cc @@ -3,7 +3,7 @@ #include "utils/containers/extend.h" #include "utils/containers/transform.h" #include "utils/containers/values.h" -#include "utils/graph/open_dataflow_graph/algorithms.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_incoming_edges.h" #include "utils/graph/open_dataflow_graph/open_dataflow_edge.h" #include "utils/overload.h" @@ -28,7 +28,7 @@ std::unordered_set extend(relevant_edges, filter(incoming, comes_from_outside_subgraph)); } - return transform(relevant_edges, get_open_dataflow_edge_source); + return transform(relevant_edges, get_open_dataflow_edge_src); } } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_unused_open_dataflow_graph_inputs.cc b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_unused_open_dataflow_graph_inputs.cc new file mode 100644 index 0000000000..8fbe7ae5bc --- /dev/null +++ b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_unused_open_dataflow_graph_inputs.cc @@ -0,0 +1,15 @@ +#include "utils/graph/open_dataflow_graph/algorithms/get_unused_open_dataflow_graph_inputs.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_graph_inputs.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_value_uses.h" + +namespace FlexFlow { + +std::unordered_set + get_unused_open_dataflow_graph_inputs(OpenDataflowGraphView const &g) { + return filter( + get_open_dataflow_graph_inputs(g), [&](DataflowGraphInput const &i) { + return get_open_dataflow_value_uses(g, OpenDataflowValue{i}).empty(); + }); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/is_isomorphic_under.cc b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/is_isomorphic_under.cc new file mode 100644 index 0000000000..77e23d9c87 --- /dev/null +++ b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/is_isomorphic_under.cc @@ -0,0 +1,30 @@ +#include "utils/graph/open_dataflow_graph/algorithms/is_isomorphic_under.h" +#include "utils/graph/node/algorithms/new_node.dtg.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_graph_data.h" +#include "utils/graph/open_dataflow_graph/algorithms/new_dataflow_graph_input.dtg.h" +#include "utils/graph/open_dataflow_graph/algorithms/permute_input_ids.h" +#include "utils/graph/open_dataflow_graph/algorithms/permute_node_ids.h" + +namespace FlexFlow { + +bool is_isomorphic_under( + OpenDataflowGraphView const &src, + OpenDataflowGraphView const &dst, + OpenDataflowGraphIsomorphism const &candidate_isomorphism) { + + bidict node_permutation = + map_values(candidate_isomorphism.node_mapping, [](Node const &dst_node) { + return NewNode{dst_node}; + }).reversed(); + bidict input_permutation = + map_values(candidate_isomorphism.input_mapping, + [](DataflowGraphInput const &dst_input) { + return NewDataflowGraphInput{dst_input}; + }) + .reversed(); + return get_graph_data(permute_input_ids( + permute_node_ids(src, node_permutation), input_permutation)) == + get_graph_data(dst); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/open_dataflow_graphs_are_isomorphic.cc b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/open_dataflow_graphs_are_isomorphic.cc new file mode 100644 index 0000000000..af56db2de3 --- /dev/null +++ b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/open_dataflow_graphs_are_isomorphic.cc @@ -0,0 +1,11 @@ +#include "utils/graph/open_dataflow_graph/algorithms/open_dataflow_graphs_are_isomorphic.h" +#include "utils/graph/open_dataflow_graph/algorithms/find_isomorphism.h" + +namespace FlexFlow { + +bool open_dataflow_graphs_are_isomorphic(OpenDataflowGraphView const &src, + OpenDataflowGraphView const &dst) { + return find_isomorphism(src, dst).has_value(); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/permute_input_ids.cc b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/permute_input_ids.cc new file mode 100644 index 0000000000..c9c60edae3 --- /dev/null +++ b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/permute_input_ids.cc @@ -0,0 +1,44 @@ +#include "utils/graph/open_dataflow_graph/algorithms/permute_input_ids.h" +#include "utils/containers/transform.h" +#include "utils/graph/open_dataflow_graph/algorithms/from_open_dataflow_graph_data.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_graph_data.h" +#include "utils/overload.h" + +namespace FlexFlow { + +OpenDataflowGraphView permute_input_ids( + OpenDataflowGraphView const &g, + bidict const &input_mapping) { + auto new_input_from_old = + [&](DataflowGraphInput const &old_input) -> DataflowGraphInput { + return input_mapping.at_r(old_input).raw_input; + }; + + auto new_edge_from_old = [&](OpenDataflowEdge const &e) { + return e.visit(overload{ + [&](DataflowInputEdge const &input_e) { + return OpenDataflowEdge{ + DataflowInputEdge{ + new_input_from_old(input_e.src), + input_e.dst, + }, + }; + }, + [&](DataflowEdge const &standard_e) { + return OpenDataflowEdge{standard_e}; + }, + }); + }; + + OpenDataflowGraphData old_data = get_graph_data(g); + OpenDataflowGraphData permuted_data = OpenDataflowGraphData{ + old_data.nodes, + transform(old_data.edges, new_edge_from_old), + transform(old_data.inputs, new_input_from_old), + old_data.outputs, + }; + + return from_open_dataflow_graph_data(permuted_data); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/permute_node_ids.cc b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/permute_node_ids.cc new file mode 100644 index 0000000000..ab05cbbdc3 --- /dev/null +++ b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/permute_node_ids.cc @@ -0,0 +1,72 @@ +#include "utils/graph/open_dataflow_graph/algorithms/permute_node_ids.h" +#include "utils/bidict/algorithms/right_entries.h" +#include "utils/bidict/bidict.h" +#include "utils/bidict/generate_bidict.h" +#include "utils/containers/contains_key.h" +#include "utils/containers/set_minus.h" +#include "utils/containers/transform.h" +#include "utils/graph/node/algorithms.h" +#include "utils/graph/node/node_query.h" +#include "utils/graph/node/node_source.h" +#include "utils/graph/open_dataflow_graph/algorithms/from_open_dataflow_graph_data.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_graph_data.h" +#include "utils/graph/query_set.h" +#include "utils/overload.h" + +namespace FlexFlow { + +OpenDataflowGraphView + permute_node_ids(OpenDataflowGraphView const &g, + bidict const &new_node_tofrom_old_node) { + auto new_node_from_old = [&](Node const &n) -> Node { + return new_node_tofrom_old_node.at_r(n).raw_node; + }; + + auto new_output_from_old = [&](DataflowOutput const &o) -> DataflowOutput { + return DataflowOutput{ + new_node_from_old(o.node), + o.idx, + }; + }; + + auto new_input_from_old = [&](DataflowInput const &i) -> DataflowInput { + return DataflowInput{ + new_node_from_old(i.node), + i.idx, + }; + }; + + auto new_edge_from_old = [&](OpenDataflowEdge const &e) { + return e.visit(overload{ + [&](DataflowInputEdge const &input_e) { + return OpenDataflowEdge{ + DataflowInputEdge{ + input_e.src, + new_input_from_old(input_e.dst), + }, + }; + }, + [&](DataflowEdge const &standard_e) { + return OpenDataflowEdge{ + DataflowEdge{ + new_output_from_old(standard_e.src), + new_input_from_old(standard_e.dst), + }, + }; + }, + }); + }; + + OpenDataflowGraphData old_data = get_graph_data(g); + + OpenDataflowGraphData permuted_data = OpenDataflowGraphData{ + transform(old_data.nodes, new_node_from_old), + transform(old_data.edges, new_edge_from_old), + old_data.inputs, + transform(old_data.outputs, new_output_from_old), + }; + + return from_open_dataflow_graph_data(permuted_data); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/dataflow_input_edge_query.cc b/lib/utils/src/utils/graph/open_dataflow_graph/dataflow_input_edge_query.cc index 19da98aabd..8736f2d157 100644 --- a/lib/utils/src/utils/graph/open_dataflow_graph/dataflow_input_edge_query.cc +++ b/lib/utils/src/utils/graph/open_dataflow_graph/dataflow_input_edge_query.cc @@ -23,4 +23,31 @@ bool dataflow_input_edge_query_includes(DataflowInputEdgeQuery const &q, includes(q.dst_idxs, e.dst.idx); } +DataflowInputEdgeQuery + dataflow_input_edge_query_for_edge(DataflowInputEdge const &e) { + return DataflowInputEdgeQuery{ + query_set{e.src}, + query_set{e.dst.node}, + query_set{e.dst.idx}, + }; +} + +DataflowInputEdgeQuery + dataflow_input_edge_query_all_outgoing_from(DataflowGraphInput const &src) { + return DataflowInputEdgeQuery{ + query_set{src}, + query_set::matchall(), + query_set::matchall(), + }; +} + +DataflowInputEdgeQuery + dataflow_input_edge_query_all_incoming_to(DataflowInput const &dst) { + return DataflowInputEdgeQuery{ + query_set::matchall(), + query_set{dst.node}, + query_set{dst.idx}, + }; +} + } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/open_dataflow_edge.cc b/lib/utils/src/utils/graph/open_dataflow_graph/open_dataflow_edge.cc index e3311e4d18..d5e5b614af 100644 --- a/lib/utils/src/utils/graph/open_dataflow_graph/open_dataflow_edge.cc +++ b/lib/utils/src/utils/graph/open_dataflow_graph/open_dataflow_edge.cc @@ -4,21 +4,21 @@ namespace FlexFlow { Node get_open_dataflow_edge_dst_node(OpenDataflowEdge const &e) { - return e.visit(overload{ - [](DataflowEdge const &e) { return e.dst.node; }, - [](DataflowInputEdge const &e) { return e.dst.node; }, - }); + return get_open_dataflow_edge_dst(e).node; } int get_open_dataflow_edge_dst_idx(OpenDataflowEdge const &e) { - return e.visit(overload{ - [](DataflowEdge const &e) { return e.dst.idx; }, - [](DataflowInputEdge const &e) { return e.dst.idx; }, + return get_open_dataflow_edge_dst(e).idx; +} + +DataflowInput get_open_dataflow_edge_dst(OpenDataflowEdge const &e) { + return e.visit(overload{ + [](DataflowEdge const &e) { return e.dst; }, + [](DataflowInputEdge const &e) { return e.dst; }, }); } -OpenDataflowValue - get_open_dataflow_edge_source(OpenDataflowEdge const &open_e) { +OpenDataflowValue get_open_dataflow_edge_src(OpenDataflowEdge const &open_e) { return open_e.visit(overload{ [](DataflowEdge const &e) { return OpenDataflowValue{e.src}; }, [](DataflowInputEdge const &e) { return OpenDataflowValue{e.src}; }, diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/open_dataflow_edge_query.cc b/lib/utils/src/utils/graph/open_dataflow_graph/open_dataflow_edge_query.cc index 4d12889a1e..4882c3e143 100644 --- a/lib/utils/src/utils/graph/open_dataflow_graph/open_dataflow_edge_query.cc +++ b/lib/utils/src/utils/graph/open_dataflow_graph/open_dataflow_edge_query.cc @@ -32,4 +32,38 @@ bool open_dataflow_edge_query_includes(OpenDataflowEdgeQuery const &q, }); } +OpenDataflowEdgeQuery + open_dataflow_edge_query_all_outgoing_from(OpenDataflowValue const &src) { + return src.visit(overload{ + [](DataflowOutput const &o) { + return OpenDataflowEdgeQuery{ + dataflow_input_edge_query_none(), + dataflow_edge_query_all_outgoing_from(o), + }; + }, + [](DataflowGraphInput const &i) { + return OpenDataflowEdgeQuery{ + dataflow_input_edge_query_all_outgoing_from(i), + dataflow_edge_query_none(), + }; + }, + }); +} + +OpenDataflowEdgeQuery + open_dataflow_edge_query_all_incoming_to(DataflowInput const &dst) { + return OpenDataflowEdgeQuery{ + dataflow_input_edge_query_all_incoming_to(dst), + dataflow_edge_query_all_incoming_to(dst), + }; +} + +std::unordered_set apply_open_dataflow_edge_query( + OpenDataflowEdgeQuery const &q, + std::unordered_set const &es) { + return filter(es, [&](OpenDataflowEdge const &e) { + return open_dataflow_edge_query_includes(q, e); + }); +} + } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/open_dataflow_value.cc b/lib/utils/src/utils/graph/open_dataflow_graph/open_dataflow_value.cc new file mode 100644 index 0000000000..25f13fd298 --- /dev/null +++ b/lib/utils/src/utils/graph/open_dataflow_graph/open_dataflow_value.cc @@ -0,0 +1,22 @@ +#include "utils/graph/open_dataflow_graph/open_dataflow_value.h" +#include "utils/overload.h" + +namespace FlexFlow { + +std::optional + try_get_dataflow_output(OpenDataflowValue const &v) { + return v.visit>(overload{ + [](DataflowOutput const &o) { return o; }, + [](DataflowGraphInput const &i) { return std::nullopt; }, + }); +} + +std::optional + try_get_dataflow_graph_input(OpenDataflowValue const &v) { + return v.visit>(overload{ + [](DataflowOutput const &o) { return std::nullopt; }, + [](DataflowGraphInput const &i) { return i; }, + }); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/serial_parallel/graph_generation.cc b/lib/utils/src/utils/graph/serial_parallel/graph_generation.cc index 347c906bd7..4c9eb9d3ef 100644 --- a/lib/utils/src/utils/graph/serial_parallel/graph_generation.cc +++ b/lib/utils/src/utils/graph/serial_parallel/graph_generation.cc @@ -7,7 +7,8 @@ namespace FlexFlow { void parallel_extend_unsafe(DataflowGraph &g, DataflowGraphView const &ext) { for (Node const &node : get_nodes(ext)) { - g.add_node_unsafe(node, get_inputs(ext, node), get_outputs(ext, node)); + g.add_node_unsafe( + node, get_input_values(ext, node), get_outputs(ext, node)); } } diff --git a/lib/utils/test/src/utils/bidict/algorithms/bidict_from_enumerating.cc b/lib/utils/test/src/utils/bidict/algorithms/bidict_from_enumerating.cc new file mode 100644 index 0000000000..6e3ac8c155 --- /dev/null +++ b/lib/utils/test/src/utils/bidict/algorithms/bidict_from_enumerating.cc @@ -0,0 +1,38 @@ +#include "utils/bidict/algorithms/bidict_from_enumerating.h" +#include "utils/bidict/algorithms/left_entries.h" +#include "utils/bidict/algorithms/right_entries.h" +#include "utils/fmt/unordered_set.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("bidict_from_enumerating(std::unordered_set)") { + std::unordered_set input = {"zero", "one", "two"}; + + bidict result = bidict_from_enumerating(input); + + std::unordered_set result_left_entries = left_entries(result); + std::unordered_set correct_left_entries = {0, 1, 2}; + CHECK(result_left_entries == correct_left_entries); + + std::unordered_set result_right_entries = + right_entries(result); + std::unordered_set correct_right_entries = input; + CHECK(result_right_entries == correct_right_entries); + } + + TEST_CASE("bidict_from_enumerating(std::set)") { + std::set input = {"a", "c", "b"}; + + bidict correct = { + {0, "a"}, + {1, "b"}, + {2, "c"}, + }; + + bidict result = bidict_from_enumerating(input); + + CHECK(result == correct); + } +} diff --git a/lib/utils/test/src/utils/containers/enumerate.cc b/lib/utils/test/src/utils/containers/enumerate.cc new file mode 100644 index 0000000000..2be5f1ef93 --- /dev/null +++ b/lib/utils/test/src/utils/containers/enumerate.cc @@ -0,0 +1,50 @@ +#include "utils/containers/enumerate.h" +#include "utils/containers/as_vector.h" +#include "utils/fmt/map.h" +#include "utils/fmt/pair.h" +#include "utils/fmt/vector.h" +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("enumerate(std::vector)") { + std::vector input = {"zero", "one", "two", "three"}; + + std::map correct = { + {0, "zero"}, + {1, "one"}, + {2, "two"}, + {3, "three"}, + }; + + std::map result = enumerate(input); + + CHECK(result == correct); + + SUBCASE("check iteration order") { + std::vector> iterated_result = + as_vector(result); + std::vector> correct_iteration_order = { + {0, "zero"}, + {1, "one"}, + {2, "two"}, + {3, "three"}, + }; + + CHECK(iterated_result == correct_iteration_order); + } + } + + TEST_CASE("enumerate(std::unordered_set)") { + std::unordered_set input = {"zero", "one", "two", "three"}; + + std::map correct = { + {0, "zero"}, + {1, "one"}, + {2, "two"}, + {3, "three"}, + }; + } +} diff --git a/lib/utils/test/src/utils/containers/filtrans.cc b/lib/utils/test/src/utils/containers/filtrans.cc new file mode 100644 index 0000000000..b8bb832b06 --- /dev/null +++ b/lib/utils/test/src/utils/containers/filtrans.cc @@ -0,0 +1,57 @@ +#include "utils/containers/filtrans.h" +#include "utils/fmt/set.h" +#include "utils/fmt/unordered_set.h" +#include "utils/fmt/vector.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("filtrans(std::vector, F)") { + std::vector input = {1, 2, 3, 2, 4}; + std::vector result = + filtrans(input, [](int x) -> std::optional { + if ((x % 2) == 0) { + return std::to_string(x); + } else { + return std::nullopt; + } + }); + + std::vector correct = {"2", "2", "4"}; + + CHECK(result == correct); + } + + TEST_CASE("filtrans(std::unordered_set, F)") { + std::unordered_set input = {1, 2, 3, 4}; + std::unordered_set result = + filtrans(input, [](int x) -> std::optional { + if ((x % 2) == 0) { + return std::to_string(x); + } else { + return std::nullopt; + } + }); + + std::unordered_set correct = {"2", "4"}; + + CHECK(result == correct); + } + + TEST_CASE("filtrans(std::set, F)") { + std::set input = {1, 2, 3, 4}; + std::set result = + filtrans(input, [](int x) -> std::optional { + if ((x % 2) == 0) { + return std::to_string(x); + } else { + return std::nullopt; + } + }); + + std::set correct = {"2", "4"}; + + CHECK(result == correct); + } +} diff --git a/lib/utils/test/src/utils/containers/get_all_permutations.cc b/lib/utils/test/src/utils/containers/get_all_permutations.cc new file mode 100644 index 0000000000..5f22266809 --- /dev/null +++ b/lib/utils/test/src/utils/containers/get_all_permutations.cc @@ -0,0 +1,54 @@ +#include "utils/containers/get_all_permutations.h" +#include "utils/containers/as_vector.h" +#include "utils/containers/unordered_multiset_of.h" +#include "utils/fmt/unordered_multiset.h" +#include "utils/fmt/vector.h" +#include "utils/hash/vector.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_all_permutations") { + SUBCASE("input size 1") { + std::vector input = {1}; + + std::unordered_multiset> result = + unordered_multiset_of(get_all_permutations(input)); + std::unordered_multiset> correct = {{1}}; + + CHECK(result == correct); + } + + SUBCASE("input size 3") { + std::vector input = {2, 1, 3}; + + std::unordered_multiset> result = + unordered_multiset_of(get_all_permutations(input)); + std::unordered_multiset> correct = { + {1, 2, 3}, + {1, 3, 2}, + {2, 1, 3}, + {2, 3, 1}, + {3, 1, 2}, + {3, 2, 1}, + }; + + CHECK(result == correct); + } + + SUBCASE("elements repeated") { + std::vector input = {1, 2, 2}; + + std::unordered_multiset> result = + unordered_multiset_of(get_all_permutations(input)); + std::unordered_multiset> correct = { + {1, 2, 2}, + {2, 1, 2}, + {2, 2, 1}, + }; + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/graph/dataflow_graph/algorithms.cc b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms.cc index f716f73a03..25f990f80e 100644 --- a/lib/utils/test/src/utils/graph/dataflow_graph/algorithms.cc +++ b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms.cc @@ -28,8 +28,8 @@ TEST_SUITE(FF_TEST_SUITE) { Node n4 = n4_added.node; DataflowOutput o4 = get_only(n4_added.outputs); - SUBCASE("get_inputs") { - std::vector result = get_inputs(g, n4); + SUBCASE("get_input_values") { + std::vector result = get_input_values(g, n4); std::vector correct = {o1, o2, o3}; CHECK(result == correct); } diff --git a/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/dataflow_graphs_are_isomorphic.cc b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/dataflow_graphs_are_isomorphic.cc new file mode 100644 index 0000000000..f991b4a65e --- /dev/null +++ b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/dataflow_graphs_are_isomorphic.cc @@ -0,0 +1,93 @@ +#include "utils/graph/dataflow_graph/algorithms/dataflow_graphs_are_isomorphic.h" +#include "utils/containers/get_only.h" +#include "utils/graph/dataflow_graph/dataflow_graph.h" +#include "utils/graph/instances/unordered_set_dataflow_graph.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE( + "dataflow_graphs_are_isomorphic(DataflowGraphView, DataflowGraphView)") { + auto g1 = DataflowGraph::create(); + + NodeAddedResult g1_n1_added = g1.add_node({}, 1); + Node g1_n1_node = g1_n1_added.node; + DataflowOutput g1_n1_output = get_only(g1_n1_added.outputs); + + NodeAddedResult g1_n2_added = g1.add_node({g1_n1_output}, 1); + Node g1_n2_node = g1_n2_added.node; + + auto g2 = DataflowGraph::create(); + + SUBCASE("input graphs are isomorphic") { + NodeAddedResult g2_n1_added = g2.add_node({}, 1); + Node g2_n1_node = g2_n1_added.node; + DataflowOutput g2_n1_output = get_only(g2_n1_added.outputs); + + NodeAddedResult g2_n2_added = g2.add_node({g2_n1_output}, 1); + Node g2_n2_node = g2_n2_added.node; + + bool correct = true; + + bool result = dataflow_graphs_are_isomorphic(g1, g2); + + CHECK(result == correct); + } + + SUBCASE("input graphs are not isomorphic (different connectivity)") { + NodeAddedResult g2_n1_added = g2.add_node({}, 1); + Node g2_n1_node = g2_n1_added.node; + DataflowOutput g2_n1_output = get_only(g2_n1_added.outputs); + + NodeAddedResult g2_n2_added = + g2.add_node({g2_n1_output, g2_n1_output}, 1); + Node g2_n2_node = g2_n2_added.node; + + bool correct = false; + + bool result = dataflow_graphs_are_isomorphic(g1, g2); + + CHECK(result == correct); + } + + SUBCASE("input graphs are not isomorphic (different number of src and sink " + "nodes)") { + NodeAddedResult g2_n1_added = g2.add_node({}, 1); + Node g2_n1_node = g2_n1_added.node; + DataflowOutput g2_n1_output = get_only(g2_n1_added.outputs); + + NodeAddedResult g2_n2_added = g2.add_node({g2_n1_output}, 1); + Node g2_n2_node = g2_n2_added.node; + + NodeAddedResult g2_n3_added = g2.add_node({}, 1); + Node g2_n3_node = g2_n3_added.node; + + bool correct = false; + + bool result = dataflow_graphs_are_isomorphic(g1, g2); + + CHECK(result == correct); + } + + SUBCASE("input graphs are not isomorphic (different number of internal " + "nodes)") { + NodeAddedResult g2_n1_added = g2.add_node({}, 1); + Node g2_n1_node = g2_n1_added.node; + DataflowOutput g2_n1_output = get_only(g2_n1_added.outputs); + + NodeAddedResult g2_n2_added = g2.add_node({g2_n1_output}, 1); + Node g2_n2_node = g2_n2_added.node; + DataflowOutput g2_n2_output = get_only(g2_n2_added.outputs); + + NodeAddedResult g2_n3_added = g2.add_node({g2_n2_output}, 1); + Node g2_n3_node = g2_n3_added.node; + + bool correct = false; + + bool result = dataflow_graphs_are_isomorphic(g1, g2); + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/find_isomorphism.cc b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/find_isomorphism.cc new file mode 100644 index 0000000000..160e4c4f73 --- /dev/null +++ b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/find_isomorphism.cc @@ -0,0 +1,101 @@ +#include "utils/graph/dataflow_graph/algorithms/find_isomorphism.h" +#include "utils/containers/get_only.h" +#include "utils/graph/dataflow_graph/dataflow_graph.h" +#include "utils/graph/instances/unordered_set_dataflow_graph.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("find_isomorphism(DataflowGraphView, DataflowGraphView)") { + auto g1 = DataflowGraph::create(); + + NodeAddedResult g1_n1_added = g1.add_node({}, 1); + Node g1_n1_node = g1_n1_added.node; + DataflowOutput g1_n1_output = get_only(g1_n1_added.outputs); + + NodeAddedResult g1_n2_added = g1.add_node({g1_n1_output}, 1); + Node g1_n2_node = g1_n2_added.node; + + auto g2 = DataflowGraph::create(); + + SUBCASE("input graphs are isomorphic") { + NodeAddedResult g2_n1_added = g2.add_node({}, 1); + Node g2_n1_node = g2_n1_added.node; + DataflowOutput g2_n1_output = get_only(g2_n1_added.outputs); + + NodeAddedResult g2_n2_added = g2.add_node({g2_n1_output}, 1); + Node g2_n2_node = g2_n2_added.node; + + std::optional correct_isomorphism = + DataflowGraphIsomorphism{ + bidict{ + {g1_n1_node, g2_n1_node}, + {g1_n2_node, g2_n2_node}, + }, + }; + + std::optional result = find_isomorphism(g1, g2); + + CHECK(result == correct_isomorphism); + } + + SUBCASE("input graphs are not isomorphic (different connectivity)") { + NodeAddedResult g2_n1_added = g2.add_node({}, 1); + Node g2_n1_node = g2_n1_added.node; + DataflowOutput g2_n1_output = get_only(g2_n1_added.outputs); + + NodeAddedResult g2_n2_added = + g2.add_node({g2_n1_output, g2_n1_output}, 1); + Node g2_n2_node = g2_n2_added.node; + + std::optional correct_isomorphism = + std::nullopt; + + std::optional result = find_isomorphism(g1, g2); + + CHECK(result == correct_isomorphism); + } + + SUBCASE("input graphs are not isomorphic (different number of src and sink " + "nodes)") { + NodeAddedResult g2_n1_added = g2.add_node({}, 1); + Node g2_n1_node = g2_n1_added.node; + DataflowOutput g2_n1_output = get_only(g2_n1_added.outputs); + + NodeAddedResult g2_n2_added = g2.add_node({g2_n1_output}, 1); + Node g2_n2_node = g2_n2_added.node; + + NodeAddedResult g2_n3_added = g2.add_node({}, 0); + Node g2_n3_node = g2_n3_added.node; + + std::optional correct_isomorphism = + std::nullopt; + + std::optional result = find_isomorphism(g1, g2); + + CHECK(result == correct_isomorphism); + } + + SUBCASE("input graphs are not isomorphic (different number of internal " + "nodes)") { + NodeAddedResult g2_n1_added = g2.add_node({}, 1); + Node g2_n1_node = g2_n1_added.node; + DataflowOutput g2_n1_output = get_only(g2_n1_added.outputs); + + NodeAddedResult g2_n2_added = g2.add_node({g2_n1_output}, 1); + Node g2_n2_node = g2_n2_added.node; + DataflowOutput g2_n2_output = get_only(g2_n2_added.outputs); + + NodeAddedResult g2_n3_added = g2.add_node({g2_n2_output}, 1); + Node g2_n3_node = g2_n3_added.node; + + std::optional correct_isomorphism = + std::nullopt; + + std::optional result = find_isomorphism(g1, g2); + + CHECK(result == correct_isomorphism); + } + } +} diff --git a/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/get_subgraph_outgoing_edges.cc b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/get_subgraph_outgoing_edges.cc new file mode 100644 index 0000000000..7e02686dde --- /dev/null +++ b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/get_subgraph_outgoing_edges.cc @@ -0,0 +1,41 @@ +#include "utils/graph/dataflow_graph/algorithms/get_subgraph_outgoing_edges.h" +#include "utils/containers/get_only.h" +#include "utils/graph/dataflow_graph/dataflow_graph.h" +#include "utils/graph/instances/unordered_set_dataflow_graph.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_outgoing_edges(DataflowGraphView, std::unordered_set") { + DataflowGraph g = DataflowGraph::create(); + + NodeAddedResult n1_added = g.add_node({}, 1); + Node n1 = n1_added.node; + DataflowOutput o1 = get_only(n1_added.outputs); + + NodeAddedResult n2_added = g.add_node({o1}, 1); + Node n2 = n2_added.node; + DataflowOutput o2 = get_only(n2_added.outputs); + + NodeAddedResult n3_added = g.add_node({o2}, 1); + Node n3 = n3_added.node; + DataflowOutput o3 = get_only(n3_added.outputs); + + NodeAddedResult n4_added = g.add_node({o1, o2, o3}, 1); + Node n4 = n4_added.node; + DataflowOutput o4 = get_only(n4_added.outputs); + + std::unordered_set input_node_set = {n2, n3}; + + std::unordered_set result = + get_subgraph_outgoing_edges(g, input_node_set); + + std::unordered_set correct = { + DataflowEdge{o2, DataflowInput{n4, 1}}, + DataflowEdge{o3, DataflowInput{n4, 2}}, + }; + + CHECK(result == correct); + } +} diff --git a/lib/utils/test/src/utils/graph/labelled_dataflow_graph/algorithms/find_isomorphism.cc b/lib/utils/test/src/utils/graph/labelled_dataflow_graph/algorithms/find_isomorphism.cc new file mode 100644 index 0000000000..1ac1b7ff01 --- /dev/null +++ b/lib/utils/test/src/utils/graph/labelled_dataflow_graph/algorithms/find_isomorphism.cc @@ -0,0 +1,187 @@ +#include "utils/graph/labelled_dataflow_graph/algorithms/find_isomorphism.h" +#include "utils/containers/get_only.h" +#include "utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h" +#include "utils/graph/labelled_dataflow_graph/labelled_dataflow_graph.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("find_isomorphism(LabelledDataflowGraphView, " + "LabelledDataflowGraphView)") { + auto g1 = LabelledDataflowGraph::create< + UnorderedSetLabelledOpenDataflowGraph>(); + auto g2 = LabelledDataflowGraph::create< + UnorderedSetLabelledOpenDataflowGraph>(); + + SUBCASE("duplicate labels") { + std::string node_label = "n"; + int value_label = 1; + + NodeAddedResult g1_n1_added = g1.add_node(node_label, {}, {value_label}); + Node g1_n1_node = g1_n1_added.node; + DataflowOutput g1_n1_output = get_only(g1_n1_added.outputs); + + NodeAddedResult g1_n2_added = g1.add_node(node_label, {}, {value_label}); + Node g1_n2_node = g1_n2_added.node; + DataflowOutput g1_n2_output = get_only(g1_n2_added.outputs); + + NodeAddedResult g1_n3_added = + g1.add_node(node_label, {g1_n1_output, g1_n2_output}, {value_label}); + Node g1_n3_node = g1_n3_added.node; + + NodeAddedResult g2_n1_added = g2.add_node(node_label, {}, {value_label}); + Node g2_n1_node = g2_n1_added.node; + DataflowOutput g2_n1_output = get_only(g2_n1_added.outputs); + + NodeAddedResult g2_n2_added = g2.add_node(node_label, {}, {value_label}); + Node g2_n2_node = g2_n2_added.node; + DataflowOutput g2_n2_output = get_only(g2_n2_added.outputs); + + NodeAddedResult g2_n3_added = + g2.add_node(node_label, {g2_n1_output, g2_n2_output}, {value_label}); + Node g2_n3_node = g2_n3_added.node; + + std::optional correct = + DataflowGraphIsomorphism{ + bidict{ + {g1_n1_node, g2_n1_node}, + {g1_n2_node, g2_n2_node}, + {g1_n3_node, g2_n3_node}, + }, + }; + + std::optional result = find_isomorphism(g1, g2); + + CHECK(result == correct); + } + + SUBCASE("non-duplicate labels") { + std::string n1_label = "n1"; + std::string n2_label = "n2"; + std::string n3_label = "n3"; + int i1_label = 1; + int n1_output_label = 2; + int n2_output_label = 3; + int n3_output_label = 4; + + NodeAddedResult g1_n1_added = + g1.add_node(n1_label, {}, {n1_output_label}); + Node g1_n1_node = g1_n1_added.node; + DataflowOutput g1_n1_output = get_only(g1_n1_added.outputs); + + NodeAddedResult g1_n2_added = + g1.add_node(n2_label, {}, {n2_output_label}); + Node g1_n2_node = g1_n2_added.node; + DataflowOutput g1_n2_output = get_only(g1_n2_added.outputs); + + NodeAddedResult g1_n3_added = g1.add_node( + n3_label, {g1_n1_output, g1_n2_output}, {n3_output_label}); + Node g1_n3_node = g1_n3_added.node; + + SUBCASE("input graphs are isomorphic") { + NodeAddedResult g2_n1_added = + g2.add_node(n1_label, {}, {n1_output_label}); + Node g2_n1_node = g2_n1_added.node; + DataflowOutput g2_n1_output = get_only(g2_n1_added.outputs); + + NodeAddedResult g2_n2_added = + g2.add_node(n2_label, {}, {n2_output_label}); + Node g2_n2_node = g2_n2_added.node; + DataflowOutput g2_n2_output = get_only(g2_n2_added.outputs); + + NodeAddedResult g2_n3_added = g2.add_node( + n3_label, {g2_n1_output, g2_n2_output}, {n3_output_label}); + Node g2_n3_node = g2_n3_added.node; + + std::optional correct = + DataflowGraphIsomorphism{ + bidict{ + {g1_n1_node, g2_n1_node}, + {g1_n2_node, g2_n2_node}, + {g1_n3_node, g2_n3_node}, + }, + }; + + std::optional result = + find_isomorphism(g1, g2); + + CHECK(result == correct); + } + + SUBCASE("input graphs are not isomorphic (mismatched node labels)") { + std::string mismatched_node_label = "mismatched_node_label"; + + NodeAddedResult g2_n1_added = + g2.add_node(n1_label, {}, {n1_output_label}); + Node g2_n1_node = g2_n1_added.node; + DataflowOutput g2_n1_output = get_only(g2_n1_added.outputs); + + NodeAddedResult g2_n2_added = + g2.add_node(mismatched_node_label, {}, {n2_output_label}); + Node g2_n2_node = g2_n2_added.node; + DataflowOutput g2_n2_output = get_only(g2_n2_added.outputs); + + NodeAddedResult g2_n3_added = g2.add_node( + n3_label, {g2_n1_output, g2_n2_output}, {n3_output_label}); + Node g2_n3_node = g2_n3_added.node; + + std::optional correct = std::nullopt; + + std::optional result = + find_isomorphism(g1, g2); + + CHECK(result == correct); + } + + SUBCASE("input graphs are not isomorphic (mismatched output label)") { + int mismatched_output_label = 20000; + + NodeAddedResult g2_n1_added = + g2.add_node(n1_label, {}, {n1_output_label}); + Node g2_n1_node = g2_n1_added.node; + DataflowOutput g2_n1_output = get_only(g2_n1_added.outputs); + + NodeAddedResult g2_n2_added = + g2.add_node(n2_label, {}, {mismatched_output_label}); + Node g2_n2_node = g2_n2_added.node; + DataflowOutput g2_n2_output = get_only(g2_n2_added.outputs); + + NodeAddedResult g2_n3_added = g2.add_node( + n3_label, {g2_n1_output, g2_n2_output}, {n3_output_label}); + Node g2_n3_node = g2_n3_added.node; + + std::optional correct = std::nullopt; + + std::optional result = + find_isomorphism(g1, g2); + + CHECK(result == correct); + } + + SUBCASE("input graphs are not isomorphic (underlying unlabelled graphs " + "not isomorphic)") { + NodeAddedResult g2_n1_added = + g2.add_node(n1_label, {}, {n1_output_label}); + Node g2_n1_node = g2_n1_added.node; + DataflowOutput g2_n1_output = get_only(g2_n1_added.outputs); + + NodeAddedResult g2_n2_added = + g2.add_node(n2_label, {}, {n2_output_label}); + Node g2_n2_node = g2_n2_added.node; + DataflowOutput g2_n2_output = get_only(g2_n2_added.outputs); + + NodeAddedResult g2_n3_added = g2.add_node( + n3_label, {g2_n2_output, g2_n1_output}, {n3_output_label}); + Node g2_n3_node = g2_n3_added.node; + + std::optional correct = std::nullopt; + + std::optional result = + find_isomorphism(g1, g2); + + CHECK(result == correct); + } + } + } +} diff --git a/lib/utils/test/src/utils/graph/labelled_dataflow_graph/algorithms/labelled_dataflow_graphs_are_isomorphic.cc b/lib/utils/test/src/utils/graph/labelled_dataflow_graph/algorithms/labelled_dataflow_graphs_are_isomorphic.cc new file mode 100644 index 0000000000..12950b8ad2 --- /dev/null +++ b/lib/utils/test/src/utils/graph/labelled_dataflow_graph/algorithms/labelled_dataflow_graphs_are_isomorphic.cc @@ -0,0 +1,169 @@ +#include "utils/graph/labelled_dataflow_graph/algorithms/labelled_dataflow_graphs_are_isomorphic.h" +#include "utils/containers/get_only.h" +#include "utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h" +#include "utils/graph/labelled_dataflow_graph/labelled_dataflow_graph.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("labelled_dataflow_graphs_are_isomorphic(LabelledDataflowGraphView," + " LabelledDataflowGraphView)") { + auto g1 = LabelledDataflowGraph::create< + UnorderedSetLabelledOpenDataflowGraph>(); + auto g2 = LabelledDataflowGraph::create< + UnorderedSetLabelledOpenDataflowGraph>(); + + SUBCASE("duplicate labels") { + std::string node_label = "n"; + int value_label = 1; + + NodeAddedResult g1_n1_added = g1.add_node(node_label, {}, {value_label}); + Node g1_n1_node = g1_n1_added.node; + DataflowOutput g1_n1_output = get_only(g1_n1_added.outputs); + + NodeAddedResult g1_n2_added = g1.add_node(node_label, {}, {value_label}); + Node g1_n2_node = g1_n2_added.node; + DataflowOutput g1_n2_output = get_only(g1_n2_added.outputs); + + NodeAddedResult g1_n3_added = + g1.add_node(node_label, {g1_n1_output, g1_n2_output}, {value_label}); + Node g1_n3_node = g1_n3_added.node; + + NodeAddedResult g2_n1_added = g2.add_node(node_label, {}, {value_label}); + Node g2_n1_node = g2_n1_added.node; + DataflowOutput g2_n1_output = get_only(g2_n1_added.outputs); + + NodeAddedResult g2_n2_added = g2.add_node(node_label, {}, {value_label}); + Node g2_n2_node = g2_n2_added.node; + DataflowOutput g2_n2_output = get_only(g2_n2_added.outputs); + + NodeAddedResult g2_n3_added = + g2.add_node(node_label, {g2_n1_output, g2_n2_output}, {value_label}); + Node g2_n3_node = g2_n3_added.node; + + bool correct = true; + + bool result = labelled_dataflow_graphs_are_isomorphic(g1, g2); + + CHECK(result == correct); + } + + SUBCASE("non-duplicate labels") { + std::string n1_label = "n1"; + std::string n2_label = "n2"; + std::string n3_label = "n3"; + int i1_label = 1; + int n1_output_label = 2; + int n2_output_label = 3; + int n3_output_label = 4; + + NodeAddedResult g1_n1_added = + g1.add_node(n1_label, {}, {n1_output_label}); + Node g1_n1_node = g1_n1_added.node; + DataflowOutput g1_n1_output = get_only(g1_n1_added.outputs); + + NodeAddedResult g1_n2_added = + g1.add_node(n2_label, {}, {n2_output_label}); + Node g1_n2_node = g1_n2_added.node; + DataflowOutput g1_n2_output = get_only(g1_n2_added.outputs); + + NodeAddedResult g1_n3_added = g1.add_node( + n3_label, {g1_n1_output, g1_n2_output}, {n3_output_label}); + Node g1_n3_node = g1_n3_added.node; + + SUBCASE("input graphs are isomorphic") { + NodeAddedResult g2_n1_added = + g2.add_node(n1_label, {}, {n1_output_label}); + Node g2_n1_node = g2_n1_added.node; + DataflowOutput g2_n1_output = get_only(g2_n1_added.outputs); + + NodeAddedResult g2_n2_added = + g2.add_node(n2_label, {}, {n2_output_label}); + Node g2_n2_node = g2_n2_added.node; + DataflowOutput g2_n2_output = get_only(g2_n2_added.outputs); + + NodeAddedResult g2_n3_added = g2.add_node( + n3_label, {g2_n1_output, g2_n2_output}, {n3_output_label}); + Node g2_n3_node = g2_n3_added.node; + + bool correct = true; + + bool result = labelled_dataflow_graphs_are_isomorphic(g1, g2); + + CHECK(result == correct); + } + + SUBCASE("input graphs are not isomorphic (mismatched node labels)") { + std::string mismatched_node_label = "mismatched_node_label"; + + NodeAddedResult g2_n1_added = + g2.add_node(n1_label, {}, {n1_output_label}); + Node g2_n1_node = g2_n1_added.node; + DataflowOutput g2_n1_output = get_only(g2_n1_added.outputs); + + NodeAddedResult g2_n2_added = + g2.add_node(mismatched_node_label, {}, {n2_output_label}); + Node g2_n2_node = g2_n2_added.node; + DataflowOutput g2_n2_output = get_only(g2_n2_added.outputs); + + NodeAddedResult g2_n3_added = g2.add_node( + n3_label, {g2_n1_output, g2_n2_output}, {n3_output_label}); + Node g2_n3_node = g2_n3_added.node; + + bool correct = false; + + bool result = labelled_dataflow_graphs_are_isomorphic(g1, g2); + + CHECK(result == correct); + } + + SUBCASE("input graphs are not isomorphic (mismatched output label)") { + int mismatched_output_label = 20000; + + NodeAddedResult g2_n1_added = + g2.add_node(n1_label, {}, {n1_output_label}); + Node g2_n1_node = g2_n1_added.node; + DataflowOutput g2_n1_output = get_only(g2_n1_added.outputs); + + NodeAddedResult g2_n2_added = + g2.add_node(n2_label, {}, {mismatched_output_label}); + Node g2_n2_node = g2_n2_added.node; + DataflowOutput g2_n2_output = get_only(g2_n2_added.outputs); + + NodeAddedResult g2_n3_added = g2.add_node( + n3_label, {g2_n1_output, g2_n2_output}, {n3_output_label}); + Node g2_n3_node = g2_n3_added.node; + + bool correct = false; + + bool result = labelled_dataflow_graphs_are_isomorphic(g1, g2); + + CHECK(result == correct); + } + + SUBCASE("input graphs are not isomorphic (underlying unlabelled graphs " + "not isomorphic)") { + NodeAddedResult g2_n1_added = + g2.add_node(n1_label, {}, {n1_output_label}); + Node g2_n1_node = g2_n1_added.node; + DataflowOutput g2_n1_output = get_only(g2_n1_added.outputs); + + NodeAddedResult g2_n2_added = + g2.add_node(n2_label, {}, {n2_output_label}); + Node g2_n2_node = g2_n2_added.node; + DataflowOutput g2_n2_output = get_only(g2_n2_added.outputs); + + NodeAddedResult g2_n3_added = g2.add_node( + n3_label, {g2_n2_output, g2_n1_output}, {n3_output_label}); + Node g2_n3_node = g2_n3_added.node; + + bool correct = false; + + bool result = labelled_dataflow_graphs_are_isomorphic(g1, g2); + + CHECK(result == correct); + } + } + } +} diff --git a/lib/utils/test/src/utils/graph/labelled_open_dataflow_graph/algorithms/find_isomorphism.cc b/lib/utils/test/src/utils/graph/labelled_open_dataflow_graph/algorithms/find_isomorphism.cc new file mode 100644 index 0000000000..c83366e78c --- /dev/null +++ b/lib/utils/test/src/utils/graph/labelled_open_dataflow_graph/algorithms/find_isomorphism.cc @@ -0,0 +1,185 @@ +#include "utils/graph/labelled_open_dataflow_graph/algorithms/find_isomorphism.h" +#include "utils/containers/get_only.h" +#include "utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h" +#include "utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("find_isomorphism") { + auto g1 = LabelledOpenDataflowGraph::create< + UnorderedSetLabelledOpenDataflowGraph>(); + + auto g2 = LabelledOpenDataflowGraph::create< + UnorderedSetLabelledOpenDataflowGraph>(); + + SUBCASE("duplicate labels") { + std::string node_label = "n"; + int value_label = 2; + + DataflowGraphInput g1_i1 = g1.add_input(value_label); + NodeAddedResult g1_n1_added = + g1.add_node(node_label, {OpenDataflowValue{g1_i1}}, {value_label}); + Node g1_n1_node = g1_n1_added.node; + DataflowOutput g1_n1_output = get_only(g1_n1_added.outputs); + NodeAddedResult g1_n2_added = g1.add_node( + node_label, + {OpenDataflowValue{g1_i1}, OpenDataflowValue{g1_n1_output}}, + {value_label}); + Node g1_n2_node = g1_n2_added.node; + + DataflowGraphInput g2_i1 = g2.add_input(value_label); + NodeAddedResult g2_n1_added = + g2.add_node(node_label, {OpenDataflowValue{g2_i1}}, {value_label}); + Node g2_n1_node = g2_n1_added.node; + DataflowOutput g2_n1_output = get_only(g2_n1_added.outputs); + NodeAddedResult g2_n2_added = g2.add_node( + node_label, + {OpenDataflowValue{g2_i1}, OpenDataflowValue{g2_n1_output}}, + {value_label}); + Node g2_n2_node = g2_n2_added.node; + + std::optional correct = + OpenDataflowGraphIsomorphism{ + bidict{ + {g1_n1_node, g2_n1_node}, + {g1_n2_node, g2_n2_node}, + }, + bidict{ + {g1_i1, g2_i1}, + }, + }; + + std::optional result = + find_isomorphism(g1, g2); + + CHECK(result == correct); + } + + SUBCASE("differing labels") { + std::string n1_label = "n1"; + std::string n2_label = "n2"; + int i1_label = 1; + int n1_output_label = 2; + int n2_output_label = 3; + + DataflowGraphInput g1_i1 = g1.add_input(i1_label); + NodeAddedResult g1_n1_added = + g1.add_node(n1_label, {OpenDataflowValue{g1_i1}}, {n1_output_label}); + Node g1_n1_node = g1_n1_added.node; + DataflowOutput g1_n1_output = get_only(g1_n1_added.outputs); + NodeAddedResult g1_n2_added = g1.add_node( + n2_label, + {OpenDataflowValue{g1_i1}, OpenDataflowValue{g1_n1_output}}, + {n2_output_label}); + Node g1_n2_node = g1_n2_added.node; + + SUBCASE("input graphs are isomorphic") { + DataflowGraphInput g2_i1 = g2.add_input(i1_label); + NodeAddedResult g2_n1_added = g2.add_node( + n1_label, {OpenDataflowValue{g2_i1}}, {n1_output_label}); + Node g2_n1_node = g2_n1_added.node; + DataflowOutput g2_n1_output = get_only(g2_n1_added.outputs); + NodeAddedResult g2_n2_added = g2.add_node( + n2_label, + {OpenDataflowValue{g2_i1}, OpenDataflowValue{g2_n1_output}}, + {n2_output_label}); + Node g2_n2_node = g2_n2_added.node; + + std::optional correct = + OpenDataflowGraphIsomorphism{ + bidict{ + {g1_n1_node, g2_n1_node}, + {g1_n2_node, g2_n2_node}, + }, + bidict{ + {g1_i1, g2_i1}, + }, + }; + + std::optional result = + find_isomorphism(g1, g2); + + CHECK(result == correct); + } + + SUBCASE("input graphs are not isomorphic (mismatched node labels)") { + std::string mismatched_node_label = "mismatched_node_label"; + + DataflowGraphInput g2_i1 = g2.add_input(i1_label); + NodeAddedResult g2_n1_added = g2.add_node( + n1_label, {OpenDataflowValue{g2_i1}}, {n1_output_label}); + DataflowOutput g2_n1_output = get_only(g2_n1_added.outputs); + NodeAddedResult g2_n2_added = g2.add_node( + mismatched_node_label, + {OpenDataflowValue{g2_i1}, OpenDataflowValue{g2_n1_output}}, + {n2_output_label}); + + std::optional correct = std::nullopt; + + std::optional result = + find_isomorphism(g1, g2); + + CHECK(result == correct); + } + + SUBCASE("input graphs are not isomorphic (mismatched output label)") { + int mismatched_output_label = 20000; + + DataflowGraphInput g2_i1 = g2.add_input(i1_label); + NodeAddedResult g2_n1_added = g2.add_node( + n1_label, {OpenDataflowValue{g2_i1}}, {mismatched_output_label}); + DataflowOutput g2_n1_output = get_only(g2_n1_added.outputs); + NodeAddedResult g2_n2_added = g2.add_node( + n2_label, + {OpenDataflowValue{g2_i1}, OpenDataflowValue{g2_n1_output}}, + {n2_output_label}); + + std::optional correct = std::nullopt; + + std::optional result = + find_isomorphism(g1, g2); + + CHECK(result == correct); + } + + SUBCASE("input graphs are not isomorphic (mismatched input label)") { + int mismatched_input_label = 10000; + + DataflowGraphInput g2_i1 = g2.add_input(mismatched_input_label); + NodeAddedResult g2_n1_added = g2.add_node( + n1_label, {OpenDataflowValue{g2_i1}}, {n1_output_label}); + DataflowOutput g2_n1_output = get_only(g2_n1_added.outputs); + NodeAddedResult g2_n2_added = g2.add_node( + n2_label, + {OpenDataflowValue{g2_i1}, OpenDataflowValue{g2_n1_output}}, + {n2_output_label}); + + std::optional correct = std::nullopt; + + std::optional result = + find_isomorphism(g1, g2); + + CHECK(result == correct); + } + + SUBCASE("input graphs are not isomorphic (underlying unlabelled graphs " + "not isomorphic)") { + DataflowGraphInput g2_i1 = g2.add_input(i1_label); + NodeAddedResult g2_n1_added = g2.add_node( + n1_label, {OpenDataflowValue{g2_i1}}, {n1_output_label}); + DataflowOutput g2_n1_output = get_only(g2_n1_added.outputs); + NodeAddedResult g2_n2_added = g2.add_node( + n2_label, {OpenDataflowValue{g2_n1_output}}, {n2_output_label}); + + std::optional correct = std::nullopt; + + std::optional result = + find_isomorphism(g1, g2); + + CHECK(result == correct); + } + } + } +} diff --git a/lib/utils/test/src/utils/graph/labelled_open_dataflow_graph/algorithms/is_isomorphic_under.cc b/lib/utils/test/src/utils/graph/labelled_open_dataflow_graph/algorithms/is_isomorphic_under.cc new file mode 100644 index 0000000000..0f59392fcc --- /dev/null +++ b/lib/utils/test/src/utils/graph/labelled_open_dataflow_graph/algorithms/is_isomorphic_under.cc @@ -0,0 +1,60 @@ +#include "utils/graph/labelled_open_dataflow_graph/algorithms/is_isomorphic_under.h" +#include "utils/containers/get_only.h" +#include "utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h" +#include "utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("is_isomorphic_under") { + auto g1 = LabelledOpenDataflowGraph::create< + UnorderedSetLabelledOpenDataflowGraph>(); + + std::string n1_label = "n1"; + std::string n2_label = "n2"; + int i1_label = 1; + int n1_output_label = 2; + int n2_output_label = 3; + + DataflowGraphInput g1_i1 = g1.add_input(i1_label); + NodeAddedResult g1_n1_added = + g1.add_node(n1_label, {OpenDataflowValue{g1_i1}}, {n1_output_label}); + Node g1_n1_node = g1_n1_added.node; + DataflowOutput g1_n1_output = get_only(g1_n1_added.outputs); + NodeAddedResult g1_n2_added = + g1.add_node(n2_label, + {OpenDataflowValue{g1_i1}, OpenDataflowValue{g1_n1_output}}, + {n2_output_label}); + Node g1_n2_node = g1_n2_added.node; + + auto g2 = LabelledOpenDataflowGraph::create< + UnorderedSetLabelledOpenDataflowGraph>(); + + DataflowGraphInput g2_i1 = g2.add_input(i1_label); + NodeAddedResult g2_n1_added = + g2.add_node(n1_label, {OpenDataflowValue{g2_i1}}, {n1_output_label}); + Node g2_n1_node = g2_n1_added.node; + DataflowOutput g2_n1_output = get_only(g2_n1_added.outputs); + NodeAddedResult g2_n2_added = + g2.add_node(n2_label, + {OpenDataflowValue{g2_i1}, OpenDataflowValue{g2_n1_output}}, + {n2_output_label}); + Node g2_n2_node = g2_n2_added.node; + + OpenDataflowGraphIsomorphism correct_isomorphism = + OpenDataflowGraphIsomorphism{ + bidict{ + {g1_n1_node, g2_n1_node}, + {g1_n2_node, g2_n2_node}, + }, + bidict{ + {g1_i1, g2_i1}, + }, + }; + + bool result = is_isomorphic_under(g1, g2, correct_isomorphism); + + CHECK(result); + } +} diff --git a/lib/utils/test/src/utils/graph/labelled_open_dataflow_graph/algorithms/labelled_open_dataflow_graphs_are_isomorphic.cc b/lib/utils/test/src/utils/graph/labelled_open_dataflow_graph/algorithms/labelled_open_dataflow_graphs_are_isomorphic.cc new file mode 100644 index 0000000000..9f8d5eb08a --- /dev/null +++ b/lib/utils/test/src/utils/graph/labelled_open_dataflow_graph/algorithms/labelled_open_dataflow_graphs_are_isomorphic.cc @@ -0,0 +1,121 @@ +#include "utils/graph/labelled_open_dataflow_graph/algorithms/labelled_open_dataflow_graphs_are_isomorphic.h" +#include "utils/containers/get_only.h" +#include "utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h" +#include "utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("labelled_open_dataflow_graphs_are_isomorphic") { + auto g1 = LabelledOpenDataflowGraph::create< + UnorderedSetLabelledOpenDataflowGraph>(); + + std::string n1_label = "n1"; + std::string n2_label = "n2"; + int i1_label = 1; + int n1_output_label = 2; + int n2_output_label = 3; + + DataflowGraphInput g1_i1 = g1.add_input(i1_label); + NodeAddedResult g1_n1_added = + g1.add_node(n1_label, {OpenDataflowValue{g1_i1}}, {n1_output_label}); + Node g1_n1_node = g1_n1_added.node; + DataflowOutput g1_n1_output = get_only(g1_n1_added.outputs); + NodeAddedResult g1_n2_added = + g1.add_node(n2_label, + {OpenDataflowValue{g1_i1}, OpenDataflowValue{g1_n1_output}}, + {n2_output_label}); + Node g1_n2_node = g1_n2_added.node; + + auto g2 = LabelledOpenDataflowGraph::create< + UnorderedSetLabelledOpenDataflowGraph>(); + + SUBCASE("input graphs are isomorphic") { + DataflowGraphInput g2_i1 = g2.add_input(i1_label); + NodeAddedResult g2_n1_added = + g2.add_node(n1_label, {OpenDataflowValue{g2_i1}}, {n1_output_label}); + Node g2_n1_node = g2_n1_added.node; + DataflowOutput g2_n1_output = get_only(g2_n1_added.outputs); + NodeAddedResult g2_n2_added = g2.add_node( + n2_label, + {OpenDataflowValue{g2_i1}, OpenDataflowValue{g2_n1_output}}, + {n2_output_label}); + Node g2_n2_node = g2_n2_added.node; + + bool correct = true; + bool result = labelled_open_dataflow_graphs_are_isomorphic(g1, g2); + + CHECK(result == correct); + } + + SUBCASE("input graphs are not isomorphic (mismatched node labels)") { + std::string mismatched_node_label = "mismatched_node_label"; + + DataflowGraphInput g2_i1 = g2.add_input(i1_label); + NodeAddedResult g2_n1_added = + g2.add_node(n1_label, {OpenDataflowValue{g2_i1}}, {n1_output_label}); + DataflowOutput g2_n1_output = get_only(g2_n1_added.outputs); + NodeAddedResult g2_n2_added = g2.add_node( + "mismatched_label", + {OpenDataflowValue{g2_i1}, OpenDataflowValue{g2_n1_output}}, + {n2_output_label}); + + bool correct = false; + bool result = labelled_open_dataflow_graphs_are_isomorphic(g1, g2); + + CHECK(result == correct); + } + + SUBCASE("input graphs are not isomorphic (mismatched output label)") { + int mismatched_output_label = 20000; + + DataflowGraphInput g2_i1 = g2.add_input(i1_label); + NodeAddedResult g2_n1_added = g2.add_node( + n1_label, {OpenDataflowValue{g2_i1}}, {mismatched_output_label}); + DataflowOutput g2_n1_output = get_only(g2_n1_added.outputs); + NodeAddedResult g2_n2_added = g2.add_node( + n2_label, + {OpenDataflowValue{g2_i1}, OpenDataflowValue{g2_n1_output}}, + {n2_output_label}); + + bool correct = false; + bool result = labelled_open_dataflow_graphs_are_isomorphic(g1, g2); + + CHECK(result == correct); + } + + SUBCASE("input graphs are not isomorphic (mismatched input label)") { + int mismatched_input_label = 10000; + + DataflowGraphInput g2_i1 = g2.add_input(mismatched_input_label); + NodeAddedResult g2_n1_added = + g2.add_node(n1_label, {OpenDataflowValue{g2_i1}}, {n1_output_label}); + DataflowOutput g2_n1_output = get_only(g2_n1_added.outputs); + NodeAddedResult g2_n2_added = g2.add_node( + n2_label, + {OpenDataflowValue{g2_i1}, OpenDataflowValue{g2_n1_output}}, + {n2_output_label}); + + bool correct = false; + bool result = labelled_open_dataflow_graphs_are_isomorphic(g1, g2); + + CHECK(result == correct); + } + + SUBCASE("input graphs are not isomorphic (underlying unlabelled graphs not " + "isomorphic)") { + DataflowGraphInput g2_i1 = g2.add_input(i1_label); + NodeAddedResult g2_n1_added = + g2.add_node(n1_label, {OpenDataflowValue{g2_i1}}, {n1_output_label}); + DataflowOutput g2_n1_output = get_only(g2_n1_added.outputs); + NodeAddedResult g2_n2_added = g2.add_node( + n2_label, {OpenDataflowValue{g2_n1_output}}, {n2_output_label}); + + bool correct = false; + bool result = labelled_open_dataflow_graphs_are_isomorphic(g1, g2); + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/graph/open_dataflow_graph/algorithms/find_isomorphism.cc b/lib/utils/test/src/utils/graph/open_dataflow_graph/algorithms/find_isomorphism.cc new file mode 100644 index 0000000000..78aaa8d9fc --- /dev/null +++ b/lib/utils/test/src/utils/graph/open_dataflow_graph/algorithms/find_isomorphism.cc @@ -0,0 +1,134 @@ +#include "utils/graph/open_dataflow_graph/algorithms/find_isomorphism.h" +#include "utils/containers/get_only.h" +#include "utils/graph/instances/unordered_set_dataflow_graph.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_graph.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("find_isomorphism(OpenDataflowGraphView, OpenDataflowGraphView)") { + auto g1 = OpenDataflowGraph::create(); + auto g2 = OpenDataflowGraph::create(); + + SUBCASE("input graphs are empty") { + std::optional correct = + OpenDataflowGraphIsomorphism{ + {}, + {}, + }; + + std::optional result = + find_isomorphism(g1, g2); + + CHECK(result == correct); + } + + SUBCASE("input graphs are not empty") { + DataflowGraphInput g1_i1 = g1.add_input(); + NodeAddedResult g1_n1_added = g1.add_node({OpenDataflowValue{g1_i1}}, 1); + Node g1_n1_node = g1_n1_added.node; + DataflowOutput g1_n1_output = get_only(g1_n1_added.outputs); + + NodeAddedResult g1_n2_added = g1.add_node( + {OpenDataflowValue{g1_i1}, OpenDataflowValue{g1_n1_output}}, 1); + Node g1_n2_node = g1_n2_added.node; + + SUBCASE("one graph is empty") { + std::optional correct = std::nullopt; + + std::optional result = + find_isomorphism(g1, g2); + + CHECK(result == correct); + } + + SUBCASE("input graphs are isomorphic") { + DataflowGraphInput g2_i1 = g2.add_input(); + NodeAddedResult g2_n1_added = + g2.add_node({OpenDataflowValue{g2_i1}}, 1); + Node g2_n1_node = g2_n1_added.node; + DataflowOutput g2_n1_output = get_only(g2_n1_added.outputs); + NodeAddedResult g2_n2_added = g2.add_node( + {OpenDataflowValue{g2_i1}, OpenDataflowValue{g2_n1_output}}, 1); + Node g2_n2_node = g2_n2_added.node; + + std::optional correct = + OpenDataflowGraphIsomorphism{ + bidict{ + {g1_n1_node, g2_n1_node}, + {g1_n2_node, g2_n2_node}, + }, + bidict{ + {g1_i1, g2_i1}, + }, + }; + + std::optional result = + find_isomorphism(g1, g2); + + CHECK(result == correct); + } + + SUBCASE("input graphs are not isomorphic (different number of graph " + "inputs)") { + DataflowGraphInput g2_i1 = g2.add_input(); + DataflowGraphInput g2_i2 = g2.add_input(); + NodeAddedResult g2_n1_added = + g2.add_node({OpenDataflowValue{g2_i1}}, 1); + Node g2_n1_node = g2_n1_added.node; + DataflowOutput g2_n1_output = get_only(g2_n1_added.outputs); + NodeAddedResult g2_n2_added = g2.add_node( + {OpenDataflowValue{g2_i1}, OpenDataflowValue{g2_n1_output}}, 1); + Node g2_n2_node = g2_n2_added.node; + + std::optional correct = std::nullopt; + + std::optional result = + find_isomorphism(g1, g2); + + CHECK(result == correct); + } + + SUBCASE("input graphs are not isomorphic (different connectivity)") { + DataflowGraphInput g2_i1 = g2.add_input(); + NodeAddedResult g2_n1_added = + g2.add_node({OpenDataflowValue{g2_i1}}, 1); + Node g2_n1_node = g2_n1_added.node; + DataflowOutput g2_n1_output = get_only(g2_n1_added.outputs); + NodeAddedResult g2_n2_added = g2.add_node( + {OpenDataflowValue{g2_n1_output}, OpenDataflowValue{g2_n1_output}}, + 1); + Node g2_n2_node = g2_n2_added.node; + + std::optional correct = std::nullopt; + + std::optional result = + find_isomorphism(g1, g2); + + CHECK(result == correct); + } + + SUBCASE("input graphs are not isomorphic (different numbers of nodes)") { + DataflowGraphInput g2_i1 = g2.add_input(); + NodeAddedResult g2_n1_added = + g2.add_node({OpenDataflowValue{g2_i1}}, 1); + Node g2_n1_node = g2_n1_added.node; + DataflowOutput g2_n1_output = get_only(g2_n1_added.outputs); + NodeAddedResult g2_n2_added = g2.add_node( + {OpenDataflowValue{g2_i1}, OpenDataflowValue{g2_n1_output}}, 1); + Node g2_n2_node = g2_n2_added.node; + + NodeAddedResult g2_n3_added = g2.add_node({}, 0); + Node g2_n3_node = g2_n3_added.node; + + std::optional correct = std::nullopt; + + std::optional result = + find_isomorphism(g1, g2); + + CHECK(result == correct); + } + } + } +} diff --git a/lib/utils/test/src/utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_graph_inputs.cc b/lib/utils/test/src/utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_graph_inputs.cc new file mode 100644 index 0000000000..ff75e8fe48 --- /dev/null +++ b/lib/utils/test/src/utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_graph_inputs.cc @@ -0,0 +1,24 @@ +#include "utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_graph_inputs.h" +#include "utils/graph/instances/unordered_set_dataflow_graph.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_graph.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_open_dataflow_graph_inputs(OpenDataflowGraphView)") { + OpenDataflowGraph g = + OpenDataflowGraph::create(); + + DataflowGraphInput i0 = g.add_input(); + DataflowGraphInput i1 = g.add_input(); + + NodeAddedResult n0_added = g.add_node({}, 1); + + std::unordered_set result = + get_open_dataflow_graph_inputs(g); + std::unordered_set correct = {i0, i1}; + + CHECK(result == correct); + } +} diff --git a/lib/utils/test/src/utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_value_uses.cc b/lib/utils/test/src/utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_value_uses.cc new file mode 100644 index 0000000000..7496c3009d --- /dev/null +++ b/lib/utils/test/src/utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_value_uses.cc @@ -0,0 +1,74 @@ +#include "utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_value_uses.h" +#include "utils/containers/get_only.h" +#include "utils/graph/instances/unordered_set_dataflow_graph.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_graph.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_open_dataflow_value_uses(OpenDataflowGraphView, " + "OpenDataflowValue)") { + SUBCASE("value is a DataflowGraphInput") { + OpenDataflowGraph g = + OpenDataflowGraph::create(); + + DataflowGraphInput i0 = g.add_input(); + DataflowGraphInput i1 = g.add_input(); + + NodeAddedResult n0_added = g.add_node( + {OpenDataflowValue{i0}, OpenDataflowValue{i1}, OpenDataflowValue{i0}}, + 1); + Node n0 = n0_added.node; + DataflowOutput o0 = get_only(n0_added.outputs); + + NodeAddedResult n1_added = g.add_node( + {OpenDataflowValue{i1}, OpenDataflowValue{o0}, OpenDataflowValue{i0}}, + 1); + Node n1 = n1_added.node; + + std::unordered_set correct = { + DataflowInput{n0, 0}, + DataflowInput{n0, 2}, + DataflowInput{n1, 2}, + }; + + std::unordered_set result = + get_open_dataflow_value_uses(g, OpenDataflowValue{i0}); + + CHECK(result == correct); + } + + SUBCASE("value is a DataflowOutput") { + OpenDataflowGraph g = + OpenDataflowGraph::create(); + + DataflowGraphInput i0 = g.add_input(); + + NodeAddedResult n0_added = g.add_node({OpenDataflowValue{i0}}, 2); + Node n0 = n0_added.node; + DataflowOutput o0_0 = n0_added.outputs.at(0); + DataflowOutput o0_1 = n0_added.outputs.at(1); + + NodeAddedResult n1_added = g.add_node({OpenDataflowValue{i0}, + OpenDataflowValue{o0_1}, + OpenDataflowValue{o0_0}}, + 1); + Node n1 = n1_added.node; + + NodeAddedResult n2_added = + g.add_node({OpenDataflowValue{o0_1}, OpenDataflowValue{i0}}, 1); + Node n2 = n2_added.node; + + std::unordered_set correct = { + DataflowInput{n1, 1}, + DataflowInput{n2, 0}, + }; + + std::unordered_set result = + get_open_dataflow_value_uses(g, OpenDataflowValue{o0_1}); + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/graph/open_dataflow_graph/algorithms/get_unused_open_dataflow_graph_inputs.cc b/lib/utils/test/src/utils/graph/open_dataflow_graph/algorithms/get_unused_open_dataflow_graph_inputs.cc new file mode 100644 index 0000000000..ddd6d74119 --- /dev/null +++ b/lib/utils/test/src/utils/graph/open_dataflow_graph/algorithms/get_unused_open_dataflow_graph_inputs.cc @@ -0,0 +1,41 @@ +#include "utils/graph/open_dataflow_graph/algorithms/get_unused_open_dataflow_graph_inputs.h" +#include "utils/graph/instances/unordered_set_dataflow_graph.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_graph.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_unused_open_dataflow_graph_inputs(OpenDataflowGraphView)") { + auto g = OpenDataflowGraph::create(); + SUBCASE("unused inputs exist") { + DataflowGraphInput g_i1 = g.add_input(); + DataflowGraphInput g_i2 = g.add_input(); + DataflowGraphInput g_i3 = g.add_input(); + + NodeAddedResult g_n1_added = g.add_node({OpenDataflowValue{g_i2}}, 1); + + std::unordered_set result = + get_unused_open_dataflow_graph_inputs(g); + + std::unordered_set correct = {g_i1, g_i3}; + + CHECK(result == correct); + } + + SUBCASE("unused inputs don't exist") { + DataflowGraphInput g_i1 = g.add_input(); + DataflowGraphInput g_i2 = g.add_input(); + + NodeAddedResult g_n1_added = + g.add_node({OpenDataflowValue{g_i1}, OpenDataflowValue{g_i2}}, 1); + + std::unordered_set result = + get_unused_open_dataflow_graph_inputs(g); + + std::unordered_set correct = {}; + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/graph/open_dataflow_graph/algorithms/open_dataflow_graphs_are_isomorphic.cc b/lib/utils/test/src/utils/graph/open_dataflow_graph/algorithms/open_dataflow_graphs_are_isomorphic.cc new file mode 100644 index 0000000000..bdb1bb4814 --- /dev/null +++ b/lib/utils/test/src/utils/graph/open_dataflow_graph/algorithms/open_dataflow_graphs_are_isomorphic.cc @@ -0,0 +1,110 @@ +#include "utils/graph/open_dataflow_graph/algorithms/open_dataflow_graphs_are_isomorphic.h" +#include "utils/containers/get_only.h" +#include "utils/graph/instances/unordered_set_dataflow_graph.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_graph.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("open_dataflow_graphs_are_isomorphic(OpenDataflowGraphView, " + "OpenDataflowGraphView)") { + auto g1 = OpenDataflowGraph::create(); + auto g2 = OpenDataflowGraph::create(); + + SUBCASE("input graphs are empty") { + bool correct = true; + bool result = open_dataflow_graphs_are_isomorphic(g1, g2); + + CHECK(result == correct); + } + + SUBCASE("input graphs are not empty") { + DataflowGraphInput g1_i1 = g1.add_input(); + NodeAddedResult g1_n1_added = g1.add_node({OpenDataflowValue{g1_i1}}, 1); + Node g1_n1_node = g1_n1_added.node; + DataflowOutput g1_n1_output = get_only(g1_n1_added.outputs); + + NodeAddedResult g1_n2_added = g1.add_node( + {OpenDataflowValue{g1_i1}, OpenDataflowValue{g1_n1_output}}, 1); + Node g1_n2_node = g1_n2_added.node; + + SUBCASE("one input graph is empty") { + bool correct = false; + bool result = open_dataflow_graphs_are_isomorphic(g1, g2); + + CHECK(result == correct); + } + + SUBCASE("input graphs are isomorphic") { + DataflowGraphInput g2_i1 = g2.add_input(); + NodeAddedResult g2_n1_added = + g2.add_node({OpenDataflowValue{g2_i1}}, 1); + Node g2_n1_node = g2_n1_added.node; + DataflowOutput g2_n1_output = get_only(g2_n1_added.outputs); + NodeAddedResult g2_n2_added = g2.add_node( + {OpenDataflowValue{g2_i1}, OpenDataflowValue{g2_n1_output}}, 1); + Node g2_n2_node = g2_n2_added.node; + + bool correct = true; + bool result = open_dataflow_graphs_are_isomorphic(g1, g2); + + CHECK(result == correct); + } + + SUBCASE("input graphs are not isomorphic (different number of graph " + "inputs)") { + DataflowGraphInput g2_i1 = g2.add_input(); + DataflowGraphInput g2_i2 = g2.add_input(); + NodeAddedResult g2_n1_added = + g2.add_node({OpenDataflowValue{g2_i1}}, 1); + Node g2_n1_node = g2_n1_added.node; + DataflowOutput g2_n1_output = get_only(g2_n1_added.outputs); + NodeAddedResult g2_n2_added = g2.add_node( + {OpenDataflowValue{g2_i1}, OpenDataflowValue{g2_n1_output}}, 1); + Node g2_n2_node = g2_n2_added.node; + + bool correct = false; + bool result = open_dataflow_graphs_are_isomorphic(g1, g2); + + CHECK(result == correct); + } + + SUBCASE("input graphs are not isomorphic (different connectivity)") { + DataflowGraphInput g2_i1 = g2.add_input(); + NodeAddedResult g2_n1_added = + g2.add_node({OpenDataflowValue{g2_i1}}, 1); + Node g2_n1_node = g2_n1_added.node; + DataflowOutput g2_n1_output = get_only(g2_n1_added.outputs); + NodeAddedResult g2_n2_added = g2.add_node( + {OpenDataflowValue{g2_n1_output}, OpenDataflowValue{g2_n1_output}}, + 1); + Node g2_n2_node = g2_n2_added.node; + + bool correct = false; + bool result = open_dataflow_graphs_are_isomorphic(g1, g2); + + CHECK(result == correct); + } + + SUBCASE("input graphs are not isomorphic (different numbers of nodes)") { + DataflowGraphInput g2_i1 = g2.add_input(); + NodeAddedResult g2_n1_added = + g2.add_node({OpenDataflowValue{g2_i1}}, 1); + Node g2_n1_node = g2_n1_added.node; + DataflowOutput g2_n1_output = get_only(g2_n1_added.outputs); + NodeAddedResult g2_n2_added = g2.add_node( + {OpenDataflowValue{g2_i1}, OpenDataflowValue{g2_n1_output}}, 1); + Node g2_n2_node = g2_n2_added.node; + + NodeAddedResult g2_n3_added = g2.add_node({}, 0); + Node g2_n3_node = g2_n3_added.node; + + bool correct = false; + bool result = open_dataflow_graphs_are_isomorphic(g1, g2); + + CHECK(result == correct); + } + } + } +} diff --git a/lib/utils/test/src/utils/graph/open_dataflow_graph/algorithms/permute_input_ids.cc b/lib/utils/test/src/utils/graph/open_dataflow_graph/algorithms/permute_input_ids.cc new file mode 100644 index 0000000000..b565e46e67 --- /dev/null +++ b/lib/utils/test/src/utils/graph/open_dataflow_graph/algorithms/permute_input_ids.cc @@ -0,0 +1,79 @@ +#include "utils/graph/open_dataflow_graph/algorithms/permute_input_ids.h" +#include "utils/containers/get_only.h" +#include "utils/graph/instances/unordered_set_dataflow_graph.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_graph_data.h" +#include "utils/graph/open_dataflow_graph/algorithms/open_dataflow_graph_data.dtg.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_graph.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("permute_input_ids(OpenDataflowGraphView, " + "bidict)") { + OpenDataflowGraph g = + OpenDataflowGraph::create(); + + DataflowGraphInput i0 = g.add_input(); + DataflowGraphInput i1 = g.add_input(); + + NodeAddedResult n0_added = g.add_node({OpenDataflowValue{i0}}, 1); + Node n0 = n0_added.node; + DataflowOutput n0_output = get_only(n0_added.outputs); + + NodeAddedResult n1_added = g.add_node({OpenDataflowValue{n0_output}}, 1); + Node n1 = n1_added.node; + DataflowOutput n1_output = get_only(n1_added.outputs); + + DataflowGraphInput new_i0 = DataflowGraphInput{6}; + DataflowGraphInput new_i1 = DataflowGraphInput{7}; + + bidict input_mapping = { + {NewDataflowGraphInput{new_i0}, i0}, + {NewDataflowGraphInput{new_i1}, i1}, + }; + + OpenDataflowGraphView result = permute_input_ids(g, input_mapping); + OpenDataflowGraphData result_data = get_graph_data(result); + + OpenDataflowGraphData correct_data = OpenDataflowGraphData{ + {n0, n1}, + { + OpenDataflowEdge{ + DataflowInputEdge{ + new_i0, + DataflowInput{ + n0, + 0, + }, + }, + }, + OpenDataflowEdge{ + DataflowEdge{ + DataflowOutput{ + n0, + 0, + }, + DataflowInput{ + n1, + 0, + }, + }, + }, + }, + {new_i0, new_i1}, + { + DataflowOutput{ + n0, + 0, + }, + DataflowOutput{ + n1, + 0, + }, + }, + }; + + CHECK(result_data == correct_data); + } +} diff --git a/lib/utils/test/src/utils/graph/open_dataflow_graph/algorithms/permute_node_ids.cc b/lib/utils/test/src/utils/graph/open_dataflow_graph/algorithms/permute_node_ids.cc new file mode 100644 index 0000000000..36bcd16dad --- /dev/null +++ b/lib/utils/test/src/utils/graph/open_dataflow_graph/algorithms/permute_node_ids.cc @@ -0,0 +1,175 @@ +#include "utils/graph/open_dataflow_graph/algorithms/permute_node_ids.h" +#include "utils/containers/get_only.h" +#include "utils/graph/dataflow_graph/dataflow_edge_query.h" +#include "utils/graph/dataflow_graph/dataflow_output_query.h" +#include "utils/graph/instances/unordered_set_dataflow_graph.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_graph_data.h" +#include "utils/graph/open_dataflow_graph/dataflow_input_edge_query.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_graph.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("permute_node_ids(OpenDataflowGraphView, bidict)") { + OpenDataflowGraph g = + OpenDataflowGraph::create(); + + DataflowGraphInput i0 = g.add_input(); + + NodeAddedResult n0_added = g.add_node({OpenDataflowValue{i0}}, 1); + Node n0 = n0_added.node; + DataflowOutput n0_output = get_only(n0_added.outputs); + + NodeAddedResult n1_added = + g.add_node({OpenDataflowValue{i0}, OpenDataflowValue{n0_output}}, 1); + Node n1 = n1_added.node; + DataflowOutput n1_output = get_only(n1_added.outputs); + + Node new_node0 = Node{5}; + Node new_node1 = Node{6}; + + bidict node_mapping = { + {NewNode{new_node0}, n0}, + {NewNode{new_node1}, n1}, + }; + + OpenDataflowGraphView result = permute_node_ids(g, node_mapping); + OpenDataflowGraphData result_data = get_graph_data(result); + + OpenDataflowGraphData correct_data = OpenDataflowGraphData{ + {new_node0, new_node1}, + { + OpenDataflowEdge{ + DataflowInputEdge{ + i0, + DataflowInput{ + new_node0, + 0, + }, + }, + }, + OpenDataflowEdge{ + DataflowInputEdge{ + i0, + DataflowInput{ + new_node1, + 0, + }, + }, + }, + OpenDataflowEdge{ + DataflowEdge{ + DataflowOutput{ + new_node0, + 0, + }, + DataflowInput{ + new_node1, + 1, + }, + }, + }, + }, + {i0}, + { + DataflowOutput{ + new_node0, + 0, + }, + DataflowOutput{ + new_node1, + 0, + }, + }, + }; + + CHECK(result_data == correct_data); + + // because get_graph_data only uses matchall nodes which don't require as + // much updating, we also add test cases for the query methods with concrete + // queries to check the through-node-permutation querying logic + SUBCASE("query_nodes(NodeQuery)") { + SUBCASE("check access to old nodes") { + std::unordered_set result_nodes = + result.query_nodes(NodeQuery{n0}); + std::unordered_set correct = {}; + CHECK(result_nodes == correct); + } + + SUBCASE("check access to new nodes") { + std::unordered_set result_nodes = + result.query_nodes(NodeQuery{new_node0}); + std::unordered_set correct = {new_node0}; + CHECK(result_nodes == correct); + } + } + + SUBCASE("query_edges(OpenDataflowEdgeQuery)") { + SUBCASE("check access to old edges") { + OpenDataflowEdgeQuery query = OpenDataflowEdgeQuery{ + dataflow_input_edge_query_for_edge( + DataflowInputEdge{i0, DataflowInput{n0, 0}}), + dataflow_edge_query_for_edge( + DataflowEdge{n0_output, DataflowInput{n1, 1}}), + }; + std::unordered_set result_nodes = + result.query_edges(query); + std::unordered_set correct = {}; + CHECK(result_nodes == correct); + } + + SUBCASE("check access to new edges") { + DataflowEdge new_standard_edge = DataflowEdge{ + DataflowOutput{new_node0, 0}, + DataflowInput{new_node1, 1}, + }; + DataflowInputEdge new_input_edge = DataflowInputEdge{ + i0, + DataflowInput{new_node0, 0}, + }; + OpenDataflowEdgeQuery query = OpenDataflowEdgeQuery{ + dataflow_input_edge_query_for_edge(new_input_edge), + dataflow_edge_query_for_edge(new_standard_edge), + }; + + std::unordered_set result_nodes = + result.query_edges(query); + std::unordered_set correct = { + OpenDataflowEdge{new_standard_edge}, + OpenDataflowEdge{new_input_edge}, + }; + + CHECK(result_nodes == correct); + } + } + + SUBCASE("query_outputs(DataflowOutputQuery)") { + SUBCASE("check access to old outputs") { + DataflowOutput old_output = n0_output; + + DataflowOutputQuery query = + dataflow_output_query_for_output(old_output); + std::unordered_set result_outputs = + result.query_outputs(query); + + std::unordered_set correct = {}; + + CHECK(result_outputs == correct); + } + + SUBCASE("check access to new outputs") { + DataflowOutput new_output = DataflowOutput{new_node0, 0}; + + DataflowOutputQuery query = + dataflow_output_query_for_output(new_output); + std::unordered_set result_outputs = + result.query_outputs(query); + + std::unordered_set correct = {new_output}; + + CHECK(result_outputs == correct); + } + } + } +}