diff --git a/flake.lock b/flake.lock
index b36a96ee80..1aad68ae29 100644
--- a/flake.lock
+++ b/flake.lock
@@ -43,11 +43,11 @@
]
},
"locked": {
- "lastModified": 1722405648,
- "narHash": "sha256-+9cRIT+bwo7qxI966HjwR2Sw37CcXD1JlG9nw+vq2lY=",
+ "lastModified": 1722923482,
+ "narHash": "sha256-myUec+oBcnKNCqLQqSiPCyXFsIsvlrsGoj/mQFlHVrY=",
"owner": "lockshaw",
"repo": "proj",
- "rev": "3674de6208c52f3a022e8f00660ee01d580aa466",
+ "rev": "c650b0e52337652ea7190131988c0370e0ee7f25",
"type": "github"
},
"original": {
diff --git a/lib/compiler/src/machine_mapping.cc b/lib/compiler/src/machine_mapping.cc
index 12eacb2a30..af7756c635 100644
--- a/lib/compiler/src/machine_mapping.cc
+++ b/lib/compiler/src/machine_mapping.cc
@@ -12,10 +12,10 @@
#include "utils/containers/contains_key.h"
#include "utils/containers/get_only.h"
#include "utils/containers/keys.h"
+#include "utils/containers/merge_maps.h"
#include "utils/exception.h"
#include "utils/graph/graph_split.dtg.h"
#include "utils/graph/node/algorithms.h"
-#include "utils/graph/open_dataflow_graph/algorithms.h"
#include "utils/graph/open_dataflow_graph/algorithms/get_subgraph.h"
#include "utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h"
#include "utils/graph/serial_parallel/serial_parallel_decomposition.h"
diff --git a/lib/local-execution/src/op_arg_spec.cc b/lib/local-execution/src/local-execution/op_arg_spec.cc
similarity index 100%
rename from lib/local-execution/src/op_arg_spec.cc
rename to lib/local-execution/src/local-execution/op_arg_spec.cc
diff --git a/lib/local-execution/src/ops/pool_2d.cc b/lib/local-execution/src/ops/pool_2d.cc
index 789ed2cd63..33d62b713c 100644
--- a/lib/local-execution/src/ops/pool_2d.cc
+++ b/lib/local-execution/src/ops/pool_2d.cc
@@ -3,7 +3,6 @@
#include "op-attrs/get_output_shapes.h"
#include "op-attrs/ops/pool_2d.h"
-#include "utils/exception.decl.h"
#include "utils/exception.h"
#include "utils/hash-utils.h"
diff --git a/lib/local-execution/src/ops/transpose.cc b/lib/local-execution/src/ops/transpose.cc
index 5c3c1dd1ca..3e4ac15db3 100644
--- a/lib/local-execution/src/ops/transpose.cc
+++ b/lib/local-execution/src/ops/transpose.cc
@@ -17,7 +17,6 @@
#include "kernels/transpose_kernels.h"
#include "op-attrs/get_output_shapes.h"
#include "op-attrs/ops/transpose.h"
-#include "utils/exception.decl.h"
using namespace FlexFlow::Kernels::Transpose;
diff --git a/lib/op-attrs/include/op-attrs/as_dot.h b/lib/op-attrs/include/op-attrs/as_dot.h
deleted file mode 100644
index d92557c2f4..0000000000
--- a/lib/op-attrs/include/op-attrs/as_dot.h
+++ /dev/null
@@ -1,15 +0,0 @@
-#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_AS_DOT_H
-#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_AS_DOT_H
-
-#include "op-attrs/computation_graph_op_attrs.dtg.h"
-#include "op-attrs/pcg_operator_attrs.dtg.h"
-#include "utils/record_formatter.h"
-
-namespace FlexFlow {
-
-RecordFormatter as_dot(ComputationGraphOpAttrs const &);
-RecordFormatter as_dot(PCGOperatorAttrs const &);
-
-} // namespace FlexFlow
-
-#endif
diff --git a/lib/op-attrs/include/op-attrs/computation_graph_op_attrs.h b/lib/op-attrs/include/op-attrs/computation_graph_op_attrs.h
index 4be17798f7..03f38bb8f9 100644
--- a/lib/op-attrs/include/op-attrs/computation_graph_op_attrs.h
+++ b/lib/op-attrs/include/op-attrs/computation_graph_op_attrs.h
@@ -2,10 +2,12 @@
#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_COMPUTATION_GRAPH_OP_ATTRS_H
#include "op-attrs/computation_graph_op_attrs.dtg.h"
+#include "utils/record_formatter.h"
namespace FlexFlow {
OperatorType get_op_type(ComputationGraphOpAttrs const &);
+RecordFormatter as_dot(ComputationGraphOpAttrs const &);
} // namespace FlexFlow
diff --git a/lib/op-attrs/include/op-attrs/dim_ordered/enumerate.h b/lib/op-attrs/include/op-attrs/dim_ordered/enumerate.h
new file mode 100644
index 0000000000..f9f6d00532
--- /dev/null
+++ b/lib/op-attrs/include/op-attrs/dim_ordered/enumerate.h
@@ -0,0 +1,30 @@
+#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_ENUMERATE_H
+#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_ENUMERATE_H
+
+#include "op-attrs/dim_ordered.h"
+#include "utils/bidict/bidict.h"
+#include "utils/containers/count.h"
+
+namespace FlexFlow {
+
+/**
+ * @brief Generate a map from indices to elements of \p c.
+ *
+ * @note We return a std::map to prevent mixups of \ref ff_dim_t and
+ * \ref legion_dim_t. Note that std::map provides ordered iteration in
+ * increasing order, so iterating through the result of this function should
+ * function as expected.
+ */
+template <typename T>
+std::map<ff_dim_t, T> enumerate(FFOrdered<T> const &ff_ordered) {
+  std::map<ff_dim_t, T> result;
+ for (int raw_ff_dim : count(ff_ordered.size())) {
+ ff_dim_t ff_dim = ff_dim_t{raw_ff_dim};
+ result.insert({ff_dim, ff_ordered.at(ff_dim)});
+ }
+ return result;
+}
+
+} // namespace FlexFlow
+
+#endif
diff --git a/lib/op-attrs/include/op-attrs/get_output_shapes.h b/lib/op-attrs/include/op-attrs/get_output_shapes.h
index 612c226a13..c27bbb190f 100644
--- a/lib/op-attrs/include/op-attrs/get_output_shapes.h
+++ b/lib/op-attrs/include/op-attrs/get_output_shapes.h
@@ -1,228 +1,15 @@
#ifndef _FLEXFLOW_INCLUDE_OP_ATTRS_GET_OUTPUT_SHAPES_H
#define _FLEXFLOW_INCLUDE_OP_ATTRS_GET_OUTPUT_SHAPES_H
-#include "op-attrs/operator_attrs.h"
-#include "op-attrs/parallel_tensor_shape.h"
-#include "ops/reverse.h"
-#include "tensor_shape.h"
-#include "utils/containers/get_only.h"
-#include "utils/optional.h"
+#include "op-attrs/parallel_tensor_shape.dtg.h"
+#include "op-attrs/pcg_operator_attrs.dtg.h"
+#include <vector>
namespace FlexFlow {
-template
-struct has_unary_output_t : std::false_type {};
-template
-struct has_unary_input_t : std::false_type {};
-template
-struct has_binary_input_t : std::false_type {};
-
-template
-struct has_multi_output_t : std::true_type {};
-template
-struct has_multi_input_t : std::true_type {};
-
-template
-struct has_multi_output_t<
- T,
- typename std::enable_if::value>::type>
- : std::false_type {};
-
-template
-struct has_multi_input_t<
- T,
- typename std::enable_if<(has_unary_input_t::value ||
- has_binary_input_t::value)>::type>
- : std::false_type {};
-
-/* template struct output_type_t { using
- * type = std::vector; }; */
-
-template
-typename std::enable_if::value, bool>::type
- is_valid(T const &t, std::vector const &shapes) {
- if (shapes.size() != 1) {
- return false;
- }
-
- return is_valid(t, get_only(shapes));
-}
-
-template
-typename std::enable_if::value, bool>::type
- is_valid(T const &t, std::vector const &shapes) {
- if (shapes.size() != 2) {
- return false;
- }
-
- return is_valid(t, shapes.at(0), shapes.at(1));
-}
-
-template
-typename std::enable_if<(has_unary_input_t::value &&
- has_unary_output_t::value),
- ParallelTensorShape>::type
- output_shapes(T const &t, std::vector const &shapes) {
- return output_shape(t, get_only(shapes));
-}
-
-template
-typename std::enable_if<(has_binary_input_t::value &&
- has_unary_output_t::value),
- std::vector>::type
- output_shapes(T const &t, std::vector const &shapes) {
- assert(shapes.size() == 2);
-
- return {output_shape(t, shapes.at(0), shapes.at(1))};
-}
-
-TensorShape get_tensor_shape_unsafe(ParallelTensorShape const &);
-std::vector
- get_tensor_shapes_unsafe(std::vector const &);
-
-template
-TensorShape get_output_shape(Attrs const &attrs, TensorShape const &shape) {
- NOT_IMPLEMENTED();
-}
-
-template
-TensorShape get_output_shape(Attrs const &attrs,
- TensorShape const &,
- TensorShape const &) {
- NOT_IMPLEMENTED();
-}
-
-template
-TensorShape get_output_shape(Attrs const &attrs,
- std::vector const &) {
- NOT_IMPLEMENTED();
-}
-template
-std::vector get_output_shapes(Attrs const &attrs,
- TensorShape const &);
-template
-std::vector get_output_shapes(Attrs const &attrs,
- TensorShape const &,
- TensorShape const &) {
- NOT_IMPLEMENTED();
-}
-template
-std::vector get_output_shapes(Attrs const &attrs,
- std::vector const &);
-
-ParallelTensorShape get_output_shape(ConcatAttrs const &,
- std::vector const &);
-ParallelTensorShape get_output_shape(FlatAttrs const &,
- ParallelTensorShape const &);
-std::vector get_output_shapes(GatherAttrs const &,
- ParallelTensorShape const &,
- ParallelTensorShape const &);
-ParallelTensorShape get_output_shape(Pool2DAttrs const &,
- ParallelTensorShape const &);
-ParallelTensorShape get_output_shape(ReduceAttrs const &,
- ParallelTensorShape const &);
-ParallelTensorShape get_output_shape(ReverseAttrs const &,
- ParallelTensorShape const &);
-std::vector get_output_shapes(SplitAttrs const &,
- ParallelTensorShape const &);
-ParallelTensorShape get_output_shape(TopKAttrs const &,
- ParallelTensorShape const &);
-ParallelTensorShape get_output_shape(TransposeAttrs const &,
- std::vector const &);
-
-struct GetOutputShapesFunctor {
- GetOutputShapesFunctor(std::vector const &s) : s(s) {}
-
- std::vector const &s;
-
- template
- std::vector operator()(T const &t) {
- return get_output_shapes(t, s);
- }
-};
-
-template
std::vector<ParallelTensorShape>
- get_output_shapes(std::variant const &t,
- std::vector const &s) {
- return get_output_shape(GetOutputShapesFunctor{s}, t);
-}
-
-template
-typename std::enable_if::value, std::optional>::type
- get_num_outputs(T const &) {
- return std::nullopt;
-}
-
-template
-typename std::enable_if::value, std::optional>::type
- get_num_outputs(T const &) {
- return 1;
-}
-
-int get_num_outputs(SplitAttrs const &attrs);
-
-template
-bool is_valid(T const &t, std::vector const &shapes) {
- auto num_outputs = get_num_outputs(t);
- if (num_outputs.has_value() && shapes.size() != num_outputs.value()) {
- return false;
- }
-
- for (ParallelTensorShape const &shape : shapes) {
- if (!is_valid(shape)) {
- return false;
- }
- }
-
- return is_valid_internal(t, shapes);
-}
-
-template
-typename std::enable_if::value, bool>::type
- is_valid_internal(T const &t,
- std::vector const &shapes) {
- return is_valid_internal(t, get_only(shapes));
-}
-
-template
-typename std::enable_if::value, bool>::type
- is_valid_internal(T const &t,
- std::vector const &shapes) {
- return is_valid_internal(t, shapes.at(0), shapes.at(1));
-}
-
-bool is_valid_internal(MultiHeadAttentionAttrs const &,
- std::vector const &);
-bool is_valid_internal(BatchMatmulAttrs const &,
- ParallelTensorShape const &,
- ParallelTensorShape const &);
-bool is_valid_internal(CastAttrs const &, ParallelTensorShape const &);
-bool is_valid_internal(ConcatAttrs const &,
- std::vector const &);
-bool is_valid_internal(Conv2DAttrs const &, ParallelTensorShape const &);
-bool is_valid_internal(DropoutAttrs const &, ParallelTensorShape const &);
-bool is_valid_internal(ElementBinaryAttrs const &,
- ParallelTensorShape const &,
- ParallelTensorShape const &);
-bool is_valid_internal(ElementUnaryAttrs const &, ParallelTensorShape const &);
-bool is_valid_internal(EmbeddingAttrs const &, ParallelTensorShape const &);
-bool is_valid_internal(FlatAttrs const &, ParallelTensorShape const &);
-bool is_valid_internal(GatherAttrs const &,
- ParallelTensorShape const &,
- ParallelTensorShape const &);
-bool is_valid_internal(LayerNormAttrs const &, ParallelTensorShape const &);
-bool is_valid_internal(LinearAttrs const &, ParallelTensorShape const &);
-bool is_valid_internal(Pool2DAttrs const &, ParallelTensorShape const &);
-bool is_valid_internal(ReduceAttrs const &, ParallelTensorShape const &);
-bool is_valid_internal(ReductionAttrs const &, ParallelTensorShape const &);
-bool is_valid_internal(RepartitionAttrs const &, ParallelTensorShape const &);
-bool is_valid_internal(ReplicateAttrs const &, ParallelTensorShape const &);
-bool is_valid_internal(ReshapeAttrs const &, ParallelTensorShape const &);
-bool is_valid_internal(SoftmaxAttrs const &, ParallelTensorShape const &);
-bool is_valid_internal(SplitAttrs const &, ParallelTensorShape const &);
-bool is_valid_internal(TopKAttrs const &, ParallelTensorShape const &);
-bool is_valid_internal(TransposeAttrs const &, ParallelTensorShape const &);
+ get_output_shapes(PCGOperatorAttrs const &,
+                      std::vector<ParallelTensorShape> const &);
} // namespace FlexFlow
diff --git a/lib/op-attrs/include/op-attrs/is_valid.h b/lib/op-attrs/include/op-attrs/is_valid.h
new file mode 100644
index 0000000000..2d91307e19
--- /dev/null
+++ b/lib/op-attrs/include/op-attrs/is_valid.h
@@ -0,0 +1,59 @@
+#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_IS_VALID_H
+#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_IS_VALID_H
+
+#include "op-attrs/parallel_tensor_shape.h"
+#include "op-attrs/pcg_operator_attrs.dtg.h"
+
+namespace FlexFlow {
+
+template <typename T>
+bool is_valid(T const &t, std::vector<ParallelTensorShape> const &shapes) {
+ auto num_outputs = get_num_outputs(t);
+ if (num_outputs.has_value() && shapes.size() != num_outputs.value()) {
+ return false;
+ }
+
+ for (ParallelTensorShape const &shape : shapes) {
+ if (!is_valid(shape)) {
+ return false;
+ }
+ }
+
+ return is_valid_internal(t, shapes);
+}
+
+bool is_valid_internal(MultiHeadAttentionAttrs const &,
+                       std::vector<ParallelTensorShape> const &);
+bool is_valid_internal(BatchMatmulAttrs const &,
+ ParallelTensorShape const &,
+ ParallelTensorShape const &);
+bool is_valid_internal(CastAttrs const &, ParallelTensorShape const &);
+bool is_valid_internal(ConcatAttrs const &,
+                       std::vector<ParallelTensorShape> const &);
+bool is_valid_internal(Conv2DAttrs const &, ParallelTensorShape const &);
+bool is_valid_internal(DropoutAttrs const &, ParallelTensorShape const &);
+bool is_valid_internal(ElementBinaryAttrs const &,
+ ParallelTensorShape const &,
+ ParallelTensorShape const &);
+bool is_valid_internal(ElementUnaryAttrs const &, ParallelTensorShape const &);
+bool is_valid_internal(EmbeddingAttrs const &, ParallelTensorShape const &);
+bool is_valid_internal(FlatAttrs const &, ParallelTensorShape const &);
+bool is_valid_internal(GatherAttrs const &,
+ ParallelTensorShape const &,
+ ParallelTensorShape const &);
+bool is_valid_internal(LayerNormAttrs const &, ParallelTensorShape const &);
+bool is_valid_internal(LinearAttrs const &, ParallelTensorShape const &);
+bool is_valid_internal(Pool2DAttrs const &, ParallelTensorShape const &);
+bool is_valid_internal(ReduceAttrs const &, ParallelTensorShape const &);
+bool is_valid_internal(ReductionAttrs const &, ParallelTensorShape const &);
+bool is_valid_internal(RepartitionAttrs const &, ParallelTensorShape const &);
+bool is_valid_internal(ReplicateAttrs const &, ParallelTensorShape const &);
+bool is_valid_internal(ReshapeAttrs const &, ParallelTensorShape const &);
+bool is_valid_internal(SoftmaxAttrs const &, ParallelTensorShape const &);
+bool is_valid_internal(SplitAttrs const &, ParallelTensorShape const &);
+bool is_valid_internal(TopKAttrs const &, ParallelTensorShape const &);
+bool is_valid_internal(TransposeAttrs const &, ParallelTensorShape const &);
+
+} // namespace FlexFlow
+
+#endif
diff --git a/lib/op-attrs/include/op-attrs/ops/attention.h b/lib/op-attrs/include/op-attrs/ops/attention.h
index e126c425dc..40f57d08af 100644
--- a/lib/op-attrs/include/op-attrs/ops/attention.h
+++ b/lib/op-attrs/include/op-attrs/ops/attention.h
@@ -1,10 +1,10 @@
#ifndef _FLEXFLOW_ATTENTION_ATTRS_H
#define _FLEXFLOW_ATTENTION_ATTRS_H
-#include "core.h"
#include "op-attrs/ops/attention/multihead_attention_inputs.dtg.h"
#include "op-attrs/ops/attention/multihead_attention_parallel_inputs.dtg.h"
#include "op-attrs/ops/attention_attrs.dtg.h"
+#include "op-attrs/ops/core.h"
#include "op-attrs/parallel_tensor_shape.dtg.h"
#include "op-attrs/tensor_shape.dtg.h"
#include <tl/expected.hpp>
diff --git a/lib/op-attrs/include/op-attrs/ops/batch_norm.h b/lib/op-attrs/include/op-attrs/ops/batch_norm.h
index b9a1d87a75..8afcbb06b1 100644
--- a/lib/op-attrs/include/op-attrs/ops/batch_norm.h
+++ b/lib/op-attrs/include/op-attrs/ops/batch_norm.h
@@ -1,12 +1,13 @@
#ifndef _FLEXFLOW_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_BATCH_NORM_H
#define _FLEXFLOW_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_BATCH_NORM_H
-#include "core.h"
#include "op-attrs/ops/batch_norm_attrs.dtg.h"
+#include "op-attrs/ops/core.h"
#include "op-attrs/parallel_tensor_shape.h"
namespace FlexFlow {
+TensorShape get_output_shape(BatchNormAttrs const &, TensorShape const &);
ParallelTensorShape get_output_shape(BatchNormAttrs const &,
ParallelTensorShape const &);
diff --git a/lib/op-attrs/include/op-attrs/ops/cast.h b/lib/op-attrs/include/op-attrs/ops/cast.h
index ead779c553..f85481b45b 100644
--- a/lib/op-attrs/include/op-attrs/ops/cast.h
+++ b/lib/op-attrs/include/op-attrs/ops/cast.h
@@ -3,8 +3,8 @@
#include "op-attrs/ops/cast_attrs.dtg.h"
#include "op-attrs/ops/core.h"
-#include "op-attrs/parallel_tensor_shape.h"
-#include "op-attrs/tensor_shape.h"
+#include "op-attrs/parallel_tensor_shape.dtg.h"
+#include "op-attrs/tensor_shape.dtg.h"
#include <tl/expected.hpp>
namespace FlexFlow {
diff --git a/lib/op-attrs/include/op-attrs/ops/concat.h b/lib/op-attrs/include/op-attrs/ops/concat.h
index 8a72708971..f3ac8494c0 100644
--- a/lib/op-attrs/include/op-attrs/ops/concat.h
+++ b/lib/op-attrs/include/op-attrs/ops/concat.h
@@ -1,13 +1,20 @@
-#ifndef _FLEXFLOW_CONCAT_ATTRS_H
-#define _FLEXFLOW_CONCAT_ATTRS_H
+#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CONCAT_H
+#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CONCAT_H
-#include "core.h"
#include "op-attrs/ops/concat_attrs.dtg.h"
+#include "op-attrs/ops/core.h"
+#include "op-attrs/parallel_tensor_shape.dtg.h"
+#include "op-attrs/tensor_shape.dtg.h"
namespace FlexFlow {
CHECK_VALID_OP_ATTR(ConcatAttrs);
+TensorShape get_output_shape(ConcatAttrs const &,
+                             std::vector<TensorShape> const &);
+ParallelTensorShape get_output_shape(ConcatAttrs const &,
+                                     std::vector<ParallelTensorShape> const &);
+
} // namespace FlexFlow
#endif
diff --git a/lib/op-attrs/include/op-attrs/ops/flat.h b/lib/op-attrs/include/op-attrs/ops/flat.h
index d5d9069f51..676d21c59b 100644
--- a/lib/op-attrs/include/op-attrs/ops/flat.h
+++ b/lib/op-attrs/include/op-attrs/ops/flat.h
@@ -1,14 +1,19 @@
#ifndef _FLEXFLOW_FLAT_ATTRS_H
#define _FLEXFLOW_FLAT_ATTRS_H
-#include "core.h"
+#include "op-attrs/ops/core.h"
#include "op-attrs/ops/flat_attrs.dtg.h"
-#include "op-attrs/parallel_tensor_shape.h"
+#include "op-attrs/parallel_tensor_shape.dtg.h"
+#include "op-attrs/tensor_shape.dtg.h"
namespace FlexFlow {
CHECK_VALID_OP_ATTR(FlatAttrs);
+TensorShape get_output_shape(FlatAttrs const &, TensorShape const &);
+ParallelTensorShape get_output_shape(FlatAttrs const &,
+ ParallelTensorShape const &);
+
} // namespace FlexFlow
#endif
diff --git a/lib/op-attrs/include/op-attrs/ops/gather.h b/lib/op-attrs/include/op-attrs/ops/gather.h
index 79516a8862..42efd13b60 100644
--- a/lib/op-attrs/include/op-attrs/ops/gather.h
+++ b/lib/op-attrs/include/op-attrs/ops/gather.h
@@ -1,7 +1,7 @@
#ifndef _FLEXFLOW_GATHER_ATTRS_H
#define _FLEXFLOW_GATHER_ATTRS_H
-#include "core.h"
+#include "op-attrs/ops/core.h"
#include "op-attrs/ops/gather_attrs.dtg.h"
#include "op-attrs/parallel_tensor_shape.h"
@@ -9,6 +9,13 @@ namespace FlexFlow {
CHECK_VALID_OP_ATTR(GatherAttrs);
+TensorShape get_output_shape(GatherAttrs const &,
+ TensorShape const &input,
+ TensorShape const &index);
+ParallelTensorShape get_output_shape(GatherAttrs const &,
+ ParallelTensorShape const &input,
+ ParallelTensorShape const &index);
+
} // namespace FlexFlow
#endif
diff --git a/lib/op-attrs/include/op-attrs/ops/input.h b/lib/op-attrs/include/op-attrs/ops/input.h
index 9fe0ee2c2d..fe92c77a52 100644
--- a/lib/op-attrs/include/op-attrs/ops/input.h
+++ b/lib/op-attrs/include/op-attrs/ops/input.h
@@ -1,15 +1,17 @@
#ifndef _FLEXFLOW_OP_ATTRS_OPS_OP_ATTRS_INPUT_H
#define _FLEXFLOW_OP_ATTRS_OPS_OP_ATTRS_INPUT_H
-#include "core.h"
+#include "op-attrs/ops/core.h"
#include "op-attrs/ops/input_attrs.dtg.h"
#include "op-attrs/parallel_tensor_shape.dtg.h"
+#include "op-attrs/tensor_shape.dtg.h"
namespace FlexFlow {
CHECK_VALID_OP_ATTR(InputAttrs);
-ParallelTensorShape get_output_shape(InputAttrs const &);
+TensorShape get_output_shape(InputAttrs const &);
+ParallelTensorShape get_output_parallel_tensor_shape(InputAttrs const &);
} // namespace FlexFlow
diff --git a/lib/op-attrs/include/op-attrs/ops/layer_norm.h b/lib/op-attrs/include/op-attrs/ops/layer_norm.h
index 94f9b9e147..29b0b2f514 100644
--- a/lib/op-attrs/include/op-attrs/ops/layer_norm.h
+++ b/lib/op-attrs/include/op-attrs/ops/layer_norm.h
@@ -1,9 +1,10 @@
#ifndef _FLEXFLOW_OP_META_OPS_LAYER_NORM_ATTRS_H
#define _FLEXFLOW_OP_META_OPS_LAYER_NORM_ATTRS_H
-#include "core.h"
+#include "op-attrs/ops/core.h"
#include "op-attrs/ops/layer_norm_attrs.dtg.h"
-#include "op-attrs/parallel_tensor_shape.h"
+#include "op-attrs/parallel_tensor_shape.dtg.h"
+#include "op-attrs/tensor_shape.dtg.h"
namespace FlexFlow {
diff --git a/lib/op-attrs/include/op-attrs/ops/linear.h b/lib/op-attrs/include/op-attrs/ops/linear.h
index dd6948165e..795ba19ae8 100644
--- a/lib/op-attrs/include/op-attrs/ops/linear.h
+++ b/lib/op-attrs/include/op-attrs/ops/linear.h
@@ -5,12 +5,15 @@
#include "op-attrs/ops/linear_attrs.dtg.h"
#include "op-attrs/parallel_tensor_shape.dtg.h"
#include "op-attrs/tensor_shape.dtg.h"
+#include "utils/record_formatter.h"
#include <tl/expected.hpp>
namespace FlexFlow {
CHECK_VALID_OP_ATTR(LinearAttrs);
+RecordFormatter as_dot(LinearAttrs const &);
+
tl::expected<TensorShape, std::string>
get_kernel_shape(LinearAttrs const &attrs, TensorShape const &input);
tl::expected<TensorShape, std::string> get_bias_shape(LinearAttrs const &attrs,
diff --git a/lib/op-attrs/include/op-attrs/ops/loss_functions.h b/lib/op-attrs/include/op-attrs/ops/loss_functions.h
deleted file mode 100644
index 58d372d9e5..0000000000
--- a/lib/op-attrs/include/op-attrs/ops/loss_functions.h
+++ /dev/null
@@ -1,75 +0,0 @@
-#ifndef _FLEXFLOW_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_LOSS_FUNCTIONS_H
-#define _FLEXFLOW_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_LOSS_FUNCTIONS_H
-
-#include "core.h"
-#include "utils/exception.h"
-#include "utils/visitable.h"
-#include <variant>
-
-namespace FlexFlow {
-
-enum class LossFunction {
- CATEGORICAL_CROSSENTROPY,
- SPARSE_CATEGORICAL_CROSSENTROPY,
- MEAN_SQUARED_ERROR_AVG_REDUCE,
- MEAN_SQUARED_ERROR_SUM_REDUCE,
- IDENTITY
-};
-
-LossFunction parse_loss_function_name(std::string const &);
-
-struct SparseCategoricalCrossEntropyLossAttrs {
-  req<bool> replace_labels; // for aggregate_spec: More predictions than labels
-};
-FF_VISITABLE_STRUCT(SparseCategoricalCrossEntropyLossAttrs, replace_labels);
-CHECK_VALID_OP_ATTR(SparseCategoricalCrossEntropyLossAttrs);
-
-struct OtherLossAttrs {
- req loss_type;
-};
-FF_VISITABLE_STRUCT(OtherLossAttrs, loss_type);
-CHECK_VALID_OP_ATTR(OtherLossAttrs);
-
-using LossAttrs =
-    std::variant<SparseCategoricalCrossEntropyLossAttrs, OtherLossAttrs>;
-
-LossFunction get_loss_function(OtherLossAttrs const &);
-LossFunction get_loss_function(SparseCategoricalCrossEntropyLossAttrs const &);
-LossFunction get_loss_function(LossAttrs const &);
-
-} // namespace FlexFlow
-
-namespace fmt {
-
-template <>
-struct formatter<::FlexFlow::LossFunction> : formatter<string_view> {
-  template <typename FormatContext>
- auto format(::FlexFlow::LossFunction d, FormatContext &ctx) const
- -> decltype(ctx.out()) {
- using namespace FlexFlow;
-
- string_view name = "unknown";
- switch (d) {
- case LossFunction::CATEGORICAL_CROSSENTROPY:
- name = "CategoricalCrossEntropy";
- break;
- case LossFunction::SPARSE_CATEGORICAL_CROSSENTROPY:
- name = "SparseCategoricalCrossEntropy";
- break;
- case LossFunction::MEAN_SQUARED_ERROR_AVG_REDUCE:
- name = "MeanSquaredErrorAvgReduce";
- break;
- case LossFunction::MEAN_SQUARED_ERROR_SUM_REDUCE:
- name = "MeanSquaredErrorSumReduce";
- break;
- case LossFunction::IDENTITY:
- name = "Identity";
- break;
- }
- return formatter::format(name, ctx);
- }
-};
-
-} // namespace fmt
-
-#endif
diff --git a/lib/op-attrs/include/op-attrs/ops/loss_functions/loss_attrs.variant.toml b/lib/op-attrs/include/op-attrs/ops/loss_functions/loss_attrs.variant.toml
new file mode 100644
index 0000000000..17293095e4
--- /dev/null
+++ b/lib/op-attrs/include/op-attrs/ops/loss_functions/loss_attrs.variant.toml
@@ -0,0 +1,23 @@
+namespace = "FlexFlow"
+name = "LossAttrs"
+features = [
+ "eq",
+ "ord",
+ "hash",
+ "fmt",
+ "json",
+ "rapidcheck",
+]
+
+includes = [
+ "op-attrs/ops/loss_functions/sparse_categorical_cross_entropy_loss_attrs.dtg.h",
+ "op-attrs/ops/loss_functions/other_loss_attrs.dtg.h",
+]
+
+[[values]]
+type = "::FlexFlow::SparseCategoricalCrossEntropyLossAttrs"
+key = "sparse_categorical_cross_entropy_loss"
+
+[[values]]
+type = "::FlexFlow::OtherLossAttrs"
+key = "other_loss"
diff --git a/lib/op-attrs/include/op-attrs/ops/loss_functions/loss_function.enum.toml b/lib/op-attrs/include/op-attrs/ops/loss_functions/loss_function.enum.toml
new file mode 100644
index 0000000000..9658202a45
--- /dev/null
+++ b/lib/op-attrs/include/op-attrs/ops/loss_functions/loss_function.enum.toml
@@ -0,0 +1,23 @@
+namespace = "FlexFlow"
+name = "LossFunction"
+features = [
+ "fmt",
+ "hash",
+ "rapidcheck",
+ "json",
+]
+
+[[values]]
+name = "CATEGORICAL_CROSSENTROPY"
+
+[[values]]
+name = "SPARSE_CATEGORICAL_CROSSENTROPY"
+
+[[values]]
+name = "MEAN_SQUARED_ERROR_AVG_REDUCE"
+
+[[values]]
+name = "MEAN_SQUARED_ERROR_SUM_REDUCE"
+
+[[values]]
+name = "IDENTITY"
diff --git a/lib/op-attrs/include/op-attrs/ops/loss_functions/loss_functions.h b/lib/op-attrs/include/op-attrs/ops/loss_functions/loss_functions.h
new file mode 100644
index 0000000000..ca8f3e6602
--- /dev/null
+++ b/lib/op-attrs/include/op-attrs/ops/loss_functions/loss_functions.h
@@ -0,0 +1,20 @@
+#ifndef _FLEXFLOW_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_LOSS_FUNCTIONS_H
+#define _FLEXFLOW_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_LOSS_FUNCTIONS_H
+
+#include "op-attrs/ops/core.h"
+#include "op-attrs/ops/loss_functions/loss_attrs.dtg.h"
+#include "op-attrs/ops/loss_functions/loss_function.dtg.h"
+
+namespace FlexFlow {
+
+CHECK_VALID_OP_ATTR(LossAttrs);
+
+LossFunction parse_loss_function_name(std::string const &);
+
+LossFunction get_loss_function(OtherLossAttrs const &);
+LossFunction get_loss_function(SparseCategoricalCrossEntropyLossAttrs const &);
+LossFunction get_loss_function(LossAttrs const &);
+
+} // namespace FlexFlow
+
+#endif
diff --git a/lib/op-attrs/include/op-attrs/ops/loss_functions/other_loss_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/loss_functions/other_loss_attrs.struct.toml
new file mode 100644
index 0000000000..284a4b1d7d
--- /dev/null
+++ b/lib/op-attrs/include/op-attrs/ops/loss_functions/other_loss_attrs.struct.toml
@@ -0,0 +1,18 @@
+namespace = "FlexFlow"
+name = "OtherLossAttrs"
+features = [
+ "eq",
+ "ord",
+ "hash",
+ "fmt",
+ "rapidcheck",
+ "json",
+]
+
+includes = [
+ "op-attrs/ops/loss_functions/loss_function.dtg.h",
+]
+
+[[fields]]
+name = "loss_type"
+type = "::FlexFlow::LossFunction"
diff --git a/lib/op-attrs/include/op-attrs/ops/loss_functions/sparse_categorical_cross_entropy_loss_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/loss_functions/sparse_categorical_cross_entropy_loss_attrs.struct.toml
new file mode 100644
index 0000000000..c50b432ba2
--- /dev/null
+++ b/lib/op-attrs/include/op-attrs/ops/loss_functions/sparse_categorical_cross_entropy_loss_attrs.struct.toml
@@ -0,0 +1,15 @@
+namespace = "FlexFlow"
+name = "SparseCategoricalCrossEntropyLossAttrs"
+features = [
+ "eq",
+ "ord",
+ "hash",
+ "fmt",
+ "rapidcheck",
+ "json",
+]
+
+[[fields]]
+# for aggregate_spec: More predictions than labels
+name = "replace_labels"
+type = "bool"
diff --git a/lib/op-attrs/include/op-attrs/ops/noop.h b/lib/op-attrs/include/op-attrs/ops/noop.h
index eb01009259..2c61dff886 100644
--- a/lib/op-attrs/include/op-attrs/ops/noop.h
+++ b/lib/op-attrs/include/op-attrs/ops/noop.h
@@ -1,14 +1,16 @@
#ifndef _FLEXFLOW_OP_ATTRS_OPS_NOOP_H
#define _FLEXFLOW_OP_ATTRS_OPS_NOOP_H
-#include "core.h"
+#include "op-attrs/ops/core.h"
#include "op-attrs/ops/noop_attrs.dtg.h"
#include "op-attrs/parallel_tensor_shape.dtg.h"
+#include "op-attrs/tensor_shape.dtg.h"
namespace FlexFlow {
CHECK_VALID_OP_ATTR(NoopAttrs);
+TensorShape get_output_shape(NoopAttrs const &, TensorShape const &);
ParallelTensorShape get_output_shape(NoopAttrs const &,
ParallelTensorShape const &);
diff --git a/lib/op-attrs/include/op-attrs/ops/pool_2d.h b/lib/op-attrs/include/op-attrs/ops/pool_2d.h
index 162f9aef05..505fdd9f8c 100644
--- a/lib/op-attrs/include/op-attrs/ops/pool_2d.h
+++ b/lib/op-attrs/include/op-attrs/ops/pool_2d.h
@@ -1,14 +1,16 @@
#ifndef _FLEXFLOW_POOL_2D_ATTRS_H
#define _FLEXFLOW_POOL_2D_ATTRS_H
-#include "core.h"
+#include "op-attrs/ops/core.h"
#include "op-attrs/ops/pool_2d_attrs.dtg.h"
#include "op-attrs/parallel_tensor_shape.dtg.h"
+#include "op-attrs/tensor_shape.dtg.h"
namespace FlexFlow {
CHECK_VALID_OP_ATTR(Pool2DAttrs);
+TensorShape get_output_shape(Pool2DAttrs const &, TensorShape const &);
ParallelTensorShape get_output_shape(Pool2DAttrs const &,
ParallelTensorShape const &);
diff --git a/lib/op-attrs/include/op-attrs/ops/replicate.h b/lib/op-attrs/include/op-attrs/ops/replicate.h
index 4c46bf88a9..9104a36155 100644
--- a/lib/op-attrs/include/op-attrs/ops/replicate.h
+++ b/lib/op-attrs/include/op-attrs/ops/replicate.h
@@ -1,7 +1,7 @@
#ifndef _FLEXFLOW_REPLICATE_ATTRS_H
#define _FLEXFLOW_REPLICATE_ATTRS_H
-#include "core.h"
+#include "op-attrs/ops/core.h"
#include "op-attrs/ops/replicate_attrs.dtg.h"
#include "op-attrs/parallel_tensor_shape.dtg.h"
diff --git a/lib/op-attrs/include/op-attrs/ops/reshape.h b/lib/op-attrs/include/op-attrs/ops/reshape.h
index cd2ca80c3a..e87ca5c750 100644
--- a/lib/op-attrs/include/op-attrs/ops/reshape.h
+++ b/lib/op-attrs/include/op-attrs/ops/reshape.h
@@ -1,7 +1,7 @@
#ifndef _FLEXFLOW_RESHAPE_ATTRS_H
#define _FLEXFLOW_RESHAPE_ATTRS_H
-#include "core.h"
+#include "op-attrs/ops/core.h"
#include "op-attrs/ops/reshape_attrs.dtg.h"
#include "op-attrs/parallel_tensor_shape.dtg.h"
@@ -9,6 +9,8 @@ namespace FlexFlow {
CHECK_VALID_OP_ATTR(ReshapeAttrs);
+TensorShape get_output_shape(ReshapeAttrs const &attrs,
+ TensorShape const &input_shape);
ParallelTensorShape get_output_shape(ReshapeAttrs const &attrs,
ParallelTensorShape const &input_shape);
diff --git a/lib/op-attrs/include/op-attrs/ops/reverse.h b/lib/op-attrs/include/op-attrs/ops/reverse.h
index adc62dc9ae..023e714c20 100644
--- a/lib/op-attrs/include/op-attrs/ops/reverse.h
+++ b/lib/op-attrs/include/op-attrs/ops/reverse.h
@@ -1,14 +1,16 @@
#ifndef _FLEXFLOW_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REVERSE_H
#define _FLEXFLOW_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REVERSE_H
-#include "core.h"
+#include "op-attrs/ops/core.h"
#include "op-attrs/ops/reverse_attrs.dtg.h"
#include "op-attrs/parallel_tensor_shape.dtg.h"
+#include "op-attrs/tensor_shape.dtg.h"
namespace FlexFlow {
CHECK_VALID_OP_ATTR(ReverseAttrs);
+TensorShape get_output_shape(ReverseAttrs const &, TensorShape const &);
ParallelTensorShape get_output_shape(ReverseAttrs const &attrs,
ParallelTensorShape const &input_shape);
diff --git a/lib/op-attrs/include/op-attrs/ops/split.h b/lib/op-attrs/include/op-attrs/ops/split.h
index 8fc2257760..e6a08d6e77 100644
--- a/lib/op-attrs/include/op-attrs/ops/split.h
+++ b/lib/op-attrs/include/op-attrs/ops/split.h
@@ -1,15 +1,18 @@
#ifndef _FLEXFLOW_SPLIT_ATTRS_H
#define _FLEXFLOW_SPLIT_ATTRS_H
-#include "core.h"
+#include "op-attrs/ops/core.h"
#include "op-attrs/ops/split_attrs.dtg.h"
#include "op-attrs/parallel_tensor_shape.dtg.h"
+#include "op-attrs/tensor_shape.dtg.h"
#include <vector>
namespace FlexFlow {
CHECK_VALID_OP_ATTR(SplitAttrs);
+std::vector<TensorShape> get_output_shapes(SplitAttrs const &,
+                                           TensorShape const &);
std::vector<ParallelTensorShape>
get_output_shapes(SplitAttrs const &attrs,
ParallelTensorShape const &input_shape);
diff --git a/lib/op-attrs/include/op-attrs/ops/topk.h b/lib/op-attrs/include/op-attrs/ops/topk.h
index c6af40dd48..bd11f0ae91 100644
--- a/lib/op-attrs/include/op-attrs/ops/topk.h
+++ b/lib/op-attrs/include/op-attrs/ops/topk.h
@@ -1,7 +1,7 @@
#ifndef _FLEXFLOW_TOPK_ATTRS_H
#define _FLEXFLOW_TOPK_ATTRS_H
-#include "core.h"
+#include "op-attrs/ops/core.h"
#include "op-attrs/ops/topk_attrs.dtg.h"
#include "op-attrs/parallel_tensor_shape.dtg.h"
diff --git a/lib/op-attrs/include/op-attrs/ops/transpose.h b/lib/op-attrs/include/op-attrs/ops/transpose.h
index 6e23d91d78..6de83ee414 100644
--- a/lib/op-attrs/include/op-attrs/ops/transpose.h
+++ b/lib/op-attrs/include/op-attrs/ops/transpose.h
@@ -1,16 +1,18 @@
#ifndef _FLEXFLOW_OP_META_OPS_TRANSPOSE_ATTRS_H
#define _FLEXFLOW_OP_META_OPS_TRANSPOSE_ATTRS_H
-#include "core.h"
+#include "op-attrs/ops/core.h"
#include "op-attrs/ops/transpose_attrs.dtg.h"
#include "op-attrs/parallel_tensor_shape.dtg.h"
+#include "op-attrs/tensor_shape.dtg.h"
namespace FlexFlow {
CHECK_VALID_OP_ATTR(TransposeAttrs);
-ParallelTensorShape get_output_shape(TransposeAttrs const &op_attrs,
- ParallelTensorShape const &input_shape);
+TensorShape get_output_shape(TransposeAttrs const &, TensorShape const &);
+ParallelTensorShape get_output_shape(TransposeAttrs const &,
+ ParallelTensorShape const &);
} // namespace FlexFlow
diff --git a/lib/op-attrs/include/op-attrs/ops/weight.h b/lib/op-attrs/include/op-attrs/ops/weight.h
new file mode 100644
index 0000000000..ab97b31012
--- /dev/null
+++ b/lib/op-attrs/include/op-attrs/ops/weight.h
@@ -0,0 +1,18 @@
+#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_WEIGHT_H
+#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_WEIGHT_H
+
+#include "op-attrs/ops/core.h"
+#include "op-attrs/ops/weight_attrs.dtg.h"
+#include "op-attrs/parallel_tensor_shape.dtg.h"
+#include "op-attrs/tensor_shape.dtg.h"
+
+namespace FlexFlow {
+
+CHECK_VALID_OP_ATTR(WeightAttrs);
+
+TensorShape get_output_shape(WeightAttrs const &);
+ParallelTensorShape get_output_parallel_tensor_shape(WeightAttrs const &);
+
+} // namespace FlexFlow
+
+#endif
diff --git a/lib/op-attrs/include/op-attrs/ops/weight_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/weight_attrs.struct.toml
index 28810a437e..c4d22a006c 100644
--- a/lib/op-attrs/include/op-attrs/ops/weight_attrs.struct.toml
+++ b/lib/op-attrs/include/op-attrs/ops/weight_attrs.struct.toml
@@ -8,4 +8,11 @@ features = [
"rapidcheck",
"fmt",
]
-fields = []
+
+includes = [
+ "op-attrs/tensor_shape.dtg.h",
+]
+
+[[fields]]
+name = "tensor_shape"
+type = "::FlexFlow::TensorShape"
diff --git a/lib/op-attrs/include/op-attrs/parallel_op_attrs.h b/lib/op-attrs/include/op-attrs/parallel_op_attrs.h
new file mode 100644
index 0000000000..8669669f09
--- /dev/null
+++ b/lib/op-attrs/include/op-attrs/parallel_op_attrs.h
@@ -0,0 +1,17 @@
+#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_OP_ATTRS_H
+#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_OP_ATTRS_H
+
+#include "op-attrs/parallel_op_attrs.dtg.h"
+#include "op-attrs/parallel_tensor_shape.dtg.h"
+#include "op-attrs/pcg_operator_attrs.dtg.h"
+#include "utils/record_formatter.h"
+
+namespace FlexFlow {
+
+ParallelTensorShape get_output_shape(ParallelOpAttrs const &,
+ ParallelTensorShape const &);
+PCGOperatorAttrs pcg_op_attrs_from_parallel_op_attrs(ParallelOpAttrs const &);
+
+} // namespace FlexFlow
+
+#endif
diff --git a/lib/op-attrs/include/op-attrs/parallel_op_attrs.variant.toml b/lib/op-attrs/include/op-attrs/parallel_op_attrs.variant.toml
new file mode 100644
index 0000000000..f1631a41f2
--- /dev/null
+++ b/lib/op-attrs/include/op-attrs/parallel_op_attrs.variant.toml
@@ -0,0 +1,34 @@
+namespace = "FlexFlow"
+name = "ParallelOpAttrs"
+features = [
+ "eq",
+ "ord",
+ "hash",
+ "json",
+ "rapidcheck",
+ "fmt",
+]
+
+includes = [
+ "op-attrs/ops/combine_attrs.dtg.h",
+ "op-attrs/ops/reduction_attrs.dtg.h",
+ "op-attrs/ops/repartition_attrs.dtg.h",
+ "op-attrs/ops/replicate_attrs.dtg.h",
+]
+
+[[values]]
+type = "::FlexFlow::CombineAttrs"
+key = "combine_distributed"
+
+[[values]]
+type = "::FlexFlow::ReductionAttrs"
+key = "reduce_distributed"
+
+[[values]]
+type = "::FlexFlow::RepartitionAttrs"
+key = "partition_distributed"
+
+[[values]]
+type = "::FlexFlow::ReplicateAttrs"
+key = "replicate_distributed"
+
diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h
index 99be635ffc..76356b39d4 100644
--- a/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h
+++ b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h
@@ -2,6 +2,7 @@
#define _OP_META_PARALLEL_TENSOR_SHAPE_H
#include "op-attrs/parallel_tensor_shape.dtg.h"
+#include "op-attrs/replica_parallel_dim.dtg.h"
#include "op-attrs/tensor_shape.h"
#include <vector>
@@ -36,6 +37,7 @@ int get_total_parallel_degree(ParallelTensorShape const &);
bool is_valid(ParallelTensorShape const &);
+TensorShape require_not_parallel(ParallelTensorShape const &);
TensorShape get_tensor_shape_unsafe(ParallelTensorShape const &);
std::vector<TensorShape>
    get_tensor_shapes_unsafe(std::vector<ParallelTensorShape> const &);
diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_shape.struct.toml b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.struct.toml
index e6197bcd51..806af55cba 100644
--- a/lib/op-attrs/include/op-attrs/parallel_tensor_shape.struct.toml
+++ b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.struct.toml
@@ -10,8 +10,8 @@ features = [
]
includes = [
- "op-attrs/parallel_tensor_dims.h",
- "op-attrs/datatype.h",
+ "op-attrs/parallel_tensor_dims.dtg.h",
+ "op-attrs/datatype.dtg.h",
]
[[fields]]
diff --git a/lib/op-attrs/include/op-attrs/pcg_operator_attrs.h b/lib/op-attrs/include/op-attrs/pcg_operator_attrs.h
index 25be926cbe..08167fe3d9 100644
--- a/lib/op-attrs/include/op-attrs/pcg_operator_attrs.h
+++ b/lib/op-attrs/include/op-attrs/pcg_operator_attrs.h
@@ -10,6 +10,7 @@ bool is_parallel_op(PCGOperatorAttrs const &);
OperatorType get_op_type(PCGOperatorAttrs const &);
ComputationGraphOpAttrs
compgraph_op_attrs_from_pcg_op_attrs(PCGOperatorAttrs const &);
+RecordFormatter as_dot(PCGOperatorAttrs const &);
} // namespace FlexFlow
diff --git a/lib/op-attrs/src/get_output_shapes.cc b/lib/op-attrs/src/get_output_shapes.cc
deleted file mode 100644
index c20d4be34c..0000000000
--- a/lib/op-attrs/src/get_output_shapes.cc
+++ /dev/null
@@ -1,29 +0,0 @@
-#include "op-attrs/get_output_shapes.h"
-
-namespace FlexFlow {
-
-ParallelTensorShape as_parallel(TensorShape const &);
-std::vector<ParallelTensorShape> as_parallel(std::vector<TensorShape> const &);
-
-std::vector<ParallelTensorShape> get_output_shapes(
-    PCGOperatorAttrs const &op_params,
-    std::vector<ParallelTensorShape> const &input_tensor_shapes) {
- NOT_IMPLEMENTED();
-}
-
-// TensorShape get_output_shape(AggregateAttrs const &attrs,
-// TensorShape const &gate_preds,
-// TensorShape const &gate_assign,
-// TensorShape const &true_gate_assign,
-// TensorShape const &full_gate_gradients,
-// std::vector const &exp_preds) {
-// return get_tensor_shape_unsafe(
-// get_output_shape(attrs,
-// as_parallel(gate_preds),
-// as_parallel(gate_assign),
-// as_parallel(true_gate_assign),
-// as_parallel(full_gate_gradients),
-// as_parallel(exp_preds)));
-// }
-
-} // namespace FlexFlow
diff --git a/lib/op-attrs/src/op-attrs/as_dot.cc b/lib/op-attrs/src/op-attrs/as_dot.cc
deleted file mode 100644
index f8d05de941..0000000000
--- a/lib/op-attrs/src/op-attrs/as_dot.cc
+++ /dev/null
@@ -1,13 +0,0 @@
-#include "op-attrs/as_dot.h"
-
-namespace FlexFlow {
-
-RecordFormatter as_dot(ComputationGraphOpAttrs const &attrs) {
- NOT_IMPLEMENTED();
-}
-
-RecordFormatter as_dot(PCGOperatorAttrs const &attrs) {
- NOT_IMPLEMENTED();
-}
-
-} // namespace FlexFlow
diff --git a/lib/op-attrs/src/op-attrs/datatype.cc b/lib/op-attrs/src/op-attrs/datatype.cc
index bd29c8033a..3bee05c253 100644
--- a/lib/op-attrs/src/op-attrs/datatype.cc
+++ b/lib/op-attrs/src/op-attrs/datatype.cc
@@ -19,7 +19,7 @@ size_t size_of_datatype(DataType data_type) {
case DataType::DOUBLE:
return sizeof(double);
default:
- throw mk_runtime_error("Unknown DataType {}", data_type);
+ throw mk_runtime_error(fmt::format("Unknown DataType {}", data_type));
}
}
diff --git a/lib/op-attrs/src/op-attrs/dim_ordered/enumerate.cc b/lib/op-attrs/src/op-attrs/dim_ordered/enumerate.cc
new file mode 100644
index 0000000000..6edd5485af
--- /dev/null
+++ b/lib/op-attrs/src/op-attrs/dim_ordered/enumerate.cc
@@ -0,0 +1 @@
+#include "op-attrs/dim_ordered/enumerate.h"
diff --git a/lib/op-attrs/src/op-attrs/get_output_shapes.cc b/lib/op-attrs/src/op-attrs/get_output_shapes.cc
new file mode 100644
index 0000000000..d91d1a1eca
--- /dev/null
+++ b/lib/op-attrs/src/op-attrs/get_output_shapes.cc
@@ -0,0 +1,85 @@
+#include "op-attrs/get_output_shapes.h"
+#include "op-attrs/ops/batch_matmul.h"
+#include "op-attrs/ops/batch_norm.h"
+#include "op-attrs/ops/cast.h"
+#include "op-attrs/ops/combine.h"
+#include "op-attrs/ops/concat.h"
+#include "op-attrs/ops/conv_2d.h"
+#include "op-attrs/ops/dropout.h"
+#include "op-attrs/ops/element_binary.h"
+#include "op-attrs/ops/element_unary.h"
+#include "op-attrs/ops/embedding.h"
+#include "op-attrs/ops/flat.h"
+#include "op-attrs/ops/gather.h"
+#include "op-attrs/ops/input.h"
+#include "op-attrs/ops/layer_norm.h"
+#include "op-attrs/ops/linear.h"
+#include "op-attrs/ops/replicate.h"
+#include "op-attrs/ops/weight.h"
+#include "utils/overload.h"
+
+namespace FlexFlow {
+
+std::vector<ParallelTensorShape>
+    get_output_shapes(PCGOperatorAttrs const &pcg_op_attrs,
+                      std::vector<ParallelTensorShape> const &inputs) {
+  return pcg_op_attrs.visit<std::vector<ParallelTensorShape>>(overload{
+      [&](BatchMatmulAttrs const &attrs) -> std::vector<ParallelTensorShape> {
+        return {throw_if_unexpected(
+            get_output_shape(attrs, inputs.at(0), inputs.at(1)))};
+      },
+      [&](BatchNormAttrs const &attrs) -> std::vector<ParallelTensorShape> {
+        return {get_output_shape(attrs, inputs.at(0))};
+      },
+      [&](CastAttrs const &attrs) -> std::vector<ParallelTensorShape> {
+        return {throw_if_unexpected(get_output_shape(attrs, inputs.at(0)))};
+      },
+      [&](CombineAttrs const &attrs) -> std::vector<ParallelTensorShape> {
+        return {throw_if_unexpected(get_output_shape(attrs, inputs.at(0)))};
+      },
+      [&](ConcatAttrs const &attrs) -> std::vector<ParallelTensorShape> {
+        return {get_output_shape(attrs, inputs)};
+      },
+      [&](Conv2DAttrs const &attrs) -> std::vector<ParallelTensorShape> {
+        return {get_output_shape(attrs, inputs.at(0))};
+      },
+      [&](DropoutAttrs const &attrs) -> std::vector<ParallelTensorShape> {
+        return {throw_if_unexpected(get_output_shape(attrs, inputs.at(0)))};
+      },
+      [&](ElementBinaryAttrs const &attrs) -> std::vector<ParallelTensorShape> {
+        return {throw_if_unexpected(
+            get_output_shape(attrs, inputs.at(0), inputs.at(1)))};
+      },
+      [&](ElementUnaryAttrs const &attrs) -> std::vector<ParallelTensorShape> {
+        return {throw_if_unexpected(get_output_shape(attrs, inputs.at(0)))};
+      },
+      [&](EmbeddingAttrs const &attrs) -> std::vector<ParallelTensorShape> {
+        return {throw_if_unexpected(get_output_shape(attrs, inputs.at(0)))};
+      },
+      [&](FlatAttrs const &attrs) -> std::vector<ParallelTensorShape> {
+        return {get_output_shape(attrs, inputs.at(0))};
+      },
+      [&](GatherAttrs const &attrs) -> std::vector<ParallelTensorShape> {
+        return {get_output_shape(attrs, inputs.at(0), inputs.at(1))};
+      },
+      [&](InputAttrs const &attrs) -> std::vector<ParallelTensorShape> {
+        return {get_output_parallel_tensor_shape(attrs)};
+      },
+      [&](LayerNormAttrs const &attrs) -> std::vector<ParallelTensorShape> {
+        return {throw_if_unexpected(get_output_shape(attrs, inputs.at(0)))};
+      },
+      [&](LinearAttrs const &attrs) -> std::vector<ParallelTensorShape> {
+        return {throw_if_unexpected(get_output_shape(attrs, inputs.at(0)))};
+      },
+      [&](ReplicateAttrs const &attrs) -> std::vector<ParallelTensorShape> {
+        return {get_output_shape(attrs, inputs.at(0))};
+      },
+      [&](WeightAttrs const &attrs) -> std::vector<ParallelTensorShape> {
+        return {get_output_parallel_tensor_shape(attrs)};
+      },
+      [&](auto const &attrs) -> std::vector<ParallelTensorShape> {
+        NOT_IMPLEMENTED();
+      }});
+}
+
+} // namespace FlexFlow
diff --git a/lib/op-attrs/src/op-attrs/is_valid.cc b/lib/op-attrs/src/op-attrs/is_valid.cc
new file mode 100644
index 0000000000..14eae33b4b
--- /dev/null
+++ b/lib/op-attrs/src/op-attrs/is_valid.cc
@@ -0,0 +1,3 @@
+#include "op-attrs/is_valid.h"
+
+namespace FlexFlow {} // namespace FlexFlow
diff --git a/lib/op-attrs/src/op-attrs/ops/batch_norm.cc b/lib/op-attrs/src/op-attrs/ops/batch_norm.cc
index 7be51efa22..b75c3521c6 100644
--- a/lib/op-attrs/src/op-attrs/ops/batch_norm.cc
+++ b/lib/op-attrs/src/op-attrs/ops/batch_norm.cc
@@ -2,6 +2,11 @@
namespace FlexFlow {
+TensorShape get_output_shape(BatchNormAttrs const &,
+ TensorShape const &input_shape) {
+ return input_shape;
+}
+
ParallelTensorShape get_output_shape(BatchNormAttrs const &,
ParallelTensorShape const &) {
NOT_IMPLEMENTED();
diff --git a/lib/op-attrs/src/op-attrs/ops/cast.cc b/lib/op-attrs/src/op-attrs/ops/cast.cc
index 444409ffcb..cfbfd61ced 100644
--- a/lib/op-attrs/src/op-attrs/ops/cast.cc
+++ b/lib/op-attrs/src/op-attrs/ops/cast.cc
@@ -1,4 +1,5 @@
#include "op-attrs/ops/cast.h"
+#include "op-attrs/datatype.h"
namespace FlexFlow {
diff --git a/lib/op-attrs/src/op-attrs/ops/concat.cc b/lib/op-attrs/src/op-attrs/ops/concat.cc
index 065c58f365..02fee70bea 100644
--- a/lib/op-attrs/src/op-attrs/ops/concat.cc
+++ b/lib/op-attrs/src/op-attrs/ops/concat.cc
@@ -11,4 +11,14 @@ namespace FlexFlow {
/* return valid; */
/* } */
+TensorShape get_output_shape(ConcatAttrs const &,
+                             std::vector<TensorShape> const &) {
+ NOT_IMPLEMENTED();
+}
+
+ParallelTensorShape get_output_shape(ConcatAttrs const &,
+                                     std::vector<ParallelTensorShape> const &) {
+ NOT_IMPLEMENTED();
+}
+
} // namespace FlexFlow
diff --git a/lib/op-attrs/src/op-attrs/ops/embedding.cc b/lib/op-attrs/src/op-attrs/ops/embedding.cc
index 4a7d4395b6..d10d52c6f5 100644
--- a/lib/op-attrs/src/op-attrs/ops/embedding.cc
+++ b/lib/op-attrs/src/op-attrs/ops/embedding.cc
@@ -1,6 +1,7 @@
#include "op-attrs/ops/embedding.h"
#include "op-attrs/dim_ordered/slice.h"
#include "op-attrs/dim_ordered/transform.h"
+#include "op-attrs/parallel_tensor_dims.h"
#include "utils/containers/product.h"
#include "utils/integer_conversions.h"
diff --git a/lib/op-attrs/src/op-attrs/ops/flat.cc b/lib/op-attrs/src/op-attrs/ops/flat.cc
index b0683c5f08..5d318207ee 100644
--- a/lib/op-attrs/src/op-attrs/ops/flat.cc
+++ b/lib/op-attrs/src/op-attrs/ops/flat.cc
@@ -3,15 +3,24 @@
namespace FlexFlow {
-namespace Input {
-constexpr int NUMDIM = 5, WIDTH = 0, HEIGHT = 1, CHANNEL = 2, SAMPLE = 3,
- REPLICA = 4;
+TensorShape get_output_shape(FlatAttrs const &, TensorShape const &) {
+ NOT_IMPLEMENTED();
}
-namespace Output {
-constexpr int NUMDIM = 3, CHANNEL = 0, SAMPLE = 1, REPLICA = 2;
+ParallelTensorShape get_output_shape(FlatAttrs const &,
+ ParallelTensorShape const &) {
+ NOT_IMPLEMENTED();
}
+// namespace Input {
+// constexpr int NUMDIM = 5, WIDTH = 0, HEIGHT = 1, CHANNEL = 2, SAMPLE = 3,
+// REPLICA = 4;
+// }
+//
+// namespace Output {
+// constexpr int NUMDIM = 3, CHANNEL = 0, SAMPLE = 1, REPLICA = 2;
+// }
+//
/* bool FlatAttrs::is_valid(ParallelTensorShape const &input) const { */
/* ParallelTensorShape output_shape = this->calculate_output_shape(input); */
diff --git a/lib/op-attrs/src/op-attrs/ops/gather.cc b/lib/op-attrs/src/op-attrs/ops/gather.cc
index 4f2c13c794..4b1053aee1 100644
--- a/lib/op-attrs/src/op-attrs/ops/gather.cc
+++ b/lib/op-attrs/src/op-attrs/ops/gather.cc
@@ -2,6 +2,18 @@
namespace FlexFlow {
+TensorShape get_output_shape(GatherAttrs const &,
+ TensorShape const &input,
+ TensorShape const &index) {
+ NOT_IMPLEMENTED();
+}
+
+ParallelTensorShape get_output_shape(GatherAttrs const &,
+ ParallelTensorShape const &input,
+ ParallelTensorShape const &index) {
+ NOT_IMPLEMENTED();
+}
+
/* bool GatherAttrs::is_valid(ParallelTensorShape const &lhs,
* ParallelTensorShape const &rhs) const { */
/* if (lhs.num_dims() != rhs.num_dims()) { */
diff --git a/lib/op-attrs/src/op-attrs/ops/input.cc b/lib/op-attrs/src/op-attrs/ops/input.cc
index 93606b603a..acc0b02e69 100644
--- a/lib/op-attrs/src/op-attrs/ops/input.cc
+++ b/lib/op-attrs/src/op-attrs/ops/input.cc
@@ -2,7 +2,11 @@
namespace FlexFlow {
-ParallelTensorShape get_output_shape(InputAttrs const &) {
+TensorShape get_output_shape(InputAttrs const &) {
+ NOT_IMPLEMENTED();
+}
+
+ParallelTensorShape get_output_parallel_tensor_shape(InputAttrs const &) {
NOT_IMPLEMENTED();
}
diff --git a/lib/op-attrs/src/op-attrs/ops/layer_norm.cc b/lib/op-attrs/src/op-attrs/ops/layer_norm.cc
index 76a5e25dfc..b9603d7850 100644
--- a/lib/op-attrs/src/op-attrs/ops/layer_norm.cc
+++ b/lib/op-attrs/src/op-attrs/ops/layer_norm.cc
@@ -1,6 +1,8 @@
#include "op-attrs/ops/layer_norm.h"
#include "op-attrs/dim_ordered/ff_ordered_of.h"
#include "op-attrs/dim_ordered/get_idxs.h"
+#include "op-attrs/parallel_tensor_shape.h"
+#include "op-attrs/tensor_shape.h"
#include "utils/containers/all_of.h"
#include "utils/containers/any_of.h"
#include "utils/containers/contains.h"
diff --git a/lib/op-attrs/src/op-attrs/ops/linear.cc b/lib/op-attrs/src/op-attrs/ops/linear.cc
index beb944d1a0..24a8250690 100644
--- a/lib/op-attrs/src/op-attrs/ops/linear.cc
+++ b/lib/op-attrs/src/op-attrs/ops/linear.cc
@@ -8,6 +8,22 @@
namespace FlexFlow {
+RecordFormatter as_dot(LinearAttrs const &attrs) {
+ RecordFormatter r;
+
+ auto kv = [](std::string const &label, auto const &val) {
+ RecordFormatter rr;
+ rr << label << fmt::to_string(val);
+ return rr;
+ };
+
+ r << kv("out_channels", attrs.out_channels) << kv("use_bias", attrs.use_bias)
+ << kv("data_type", attrs.data_type) << kv("activation", attrs.activation)
+ << kv("regularizer", attrs.regularizer);
+
+ return r;
+}
+
tl::expected<TensorShape, std::string>
    get_kernel_shape(LinearAttrs const &attrs, TensorShape const &input_shape) {
size_t in_channels = dim_at_idx(input_shape, ff_dim_t{-1});
diff --git a/lib/op-attrs/src/loss_functions.cc b/lib/op-attrs/src/op-attrs/ops/loss_functions.cc
similarity index 80%
rename from lib/op-attrs/src/loss_functions.cc
rename to lib/op-attrs/src/op-attrs/ops/loss_functions.cc
index 094e117d77..e756d08547 100644
--- a/lib/op-attrs/src/loss_functions.cc
+++ b/lib/op-attrs/src/op-attrs/ops/loss_functions.cc
@@ -1,4 +1,4 @@
-#include "op-attrs/ops/loss_functions.h"
+#include "op-attrs/ops/loss_functions/loss_functions.h"
#include "utils/containers/transform.h"
#include
#include
@@ -8,20 +8,15 @@ namespace FlexFlow {
LossFunction get_loss_type(OtherLossAttrs const &attrs) {
return attrs.loss_type;
}
+
LossFunction
get_loss_type(SparseCategoricalCrossEntropyLossAttrs const &attrs) {
return LossFunction::SPARSE_CATEGORICAL_CROSSENTROPY;
}
-struct GetLossFunction {
-  template <typename T>
- LossFunction operator()(T const &t) {
- return get_loss_type(t);
- }
-};
-
LossFunction get_loss_type(LossAttrs const &attrs) {
- return visit(GetLossFunction{}, attrs);
+  return attrs.visit<LossFunction>(
+ [](auto const &t) { return get_loss_type(t); });
}
LossFunction parse_loss_name(std::string const &raw_name) {
@@ -37,8 +32,8 @@ LossFunction parse_loss_name(std::string const &raw_name) {
} else if (name == "identity") {
return LossFunction::IDENTITY;
} else {
- throw mk_runtime_error(
- "Unknown loss type {}. Please report this as an issue.", name);
+ throw mk_runtime_error(fmt::format(
+ "Unknown loss type {}. Please report this as an issue.", name));
}
}
diff --git a/lib/op-attrs/src/op-attrs/ops/noop.cc b/lib/op-attrs/src/op-attrs/ops/noop.cc
index b2b15d820c..6ba33146e4 100644
--- a/lib/op-attrs/src/op-attrs/ops/noop.cc
+++ b/lib/op-attrs/src/op-attrs/ops/noop.cc
@@ -2,6 +2,11 @@
namespace FlexFlow {
+TensorShape get_output_shape(NoopAttrs const &,
+ TensorShape const &input_shape) {
+ return input_shape;
+}
+
ParallelTensorShape get_output_shape(NoopAttrs const &,
ParallelTensorShape const &input_shape) {
return input_shape;
diff --git a/lib/op-attrs/src/op-attrs/ops/pool_2d.cc b/lib/op-attrs/src/op-attrs/ops/pool_2d.cc
index cf6ed177d3..e1917efd89 100644
--- a/lib/op-attrs/src/op-attrs/ops/pool_2d.cc
+++ b/lib/op-attrs/src/op-attrs/ops/pool_2d.cc
@@ -2,6 +2,10 @@
namespace FlexFlow {
+TensorShape get_output_shape(Pool2DAttrs const &, TensorShape const &) {
+ NOT_IMPLEMENTED();
+}
+
ParallelTensorShape get_output_shape(Pool2DAttrs const &,
ParallelTensorShape const &) {
NOT_IMPLEMENTED();
diff --git a/lib/op-attrs/src/op-attrs/ops/reshape.cc b/lib/op-attrs/src/op-attrs/ops/reshape.cc
index 7d0600550a..6216ad8c6c 100644
--- a/lib/op-attrs/src/op-attrs/ops/reshape.cc
+++ b/lib/op-attrs/src/op-attrs/ops/reshape.cc
@@ -2,6 +2,11 @@
namespace FlexFlow {
+TensorShape get_output_shape(ReshapeAttrs const &attrs,
+ TensorShape const &input_shape) {
+ NOT_IMPLEMENTED();
+}
+
ParallelTensorShape get_output_shape(ReshapeAttrs const &attrs,
ParallelTensorShape const &input_shape) {
NOT_IMPLEMENTED();
diff --git a/lib/op-attrs/src/op-attrs/ops/reverse.cc b/lib/op-attrs/src/op-attrs/ops/reverse.cc
index 79b5bd50fb..c38d7e4782 100644
--- a/lib/op-attrs/src/op-attrs/ops/reverse.cc
+++ b/lib/op-attrs/src/op-attrs/ops/reverse.cc
@@ -2,6 +2,10 @@
namespace FlexFlow {
+TensorShape get_output_shape(ReverseAttrs const &, TensorShape const &) {
+ NOT_IMPLEMENTED();
+}
+
ParallelTensorShape get_output_shape(ReverseAttrs const &attrs,
ParallelTensorShape const &input_shape) {
NOT_IMPLEMENTED();
diff --git a/lib/op-attrs/src/op-attrs/ops/split.cc b/lib/op-attrs/src/op-attrs/ops/split.cc
index cfb4071833..a9fe691584 100644
--- a/lib/op-attrs/src/op-attrs/ops/split.cc
+++ b/lib/op-attrs/src/op-attrs/ops/split.cc
@@ -2,6 +2,11 @@
namespace FlexFlow {
+std::vector<TensorShape> get_output_shapes(SplitAttrs const &,
+                                           TensorShape const &) {
+ NOT_IMPLEMENTED();
+}
+
std::vector<ParallelTensorShape>
get_output_shapes(SplitAttrs const &attrs,
ParallelTensorShape const &input_shape) {
diff --git a/lib/op-attrs/src/op-attrs/ops/transpose.cc b/lib/op-attrs/src/op-attrs/ops/transpose.cc
index 75f7eb3c18..50e6fb35f5 100644
--- a/lib/op-attrs/src/op-attrs/ops/transpose.cc
+++ b/lib/op-attrs/src/op-attrs/ops/transpose.cc
@@ -2,6 +2,10 @@
namespace FlexFlow {
+TensorShape get_output_shape(TransposeAttrs const &, TensorShape const &) {
+ NOT_IMPLEMENTED();
+}
+
ParallelTensorShape get_output_shape(TransposeAttrs const &op_attrs,
ParallelTensorShape const &input_shape) {
NOT_IMPLEMENTED();
diff --git a/lib/op-attrs/src/op-attrs/ops/weight.cc b/lib/op-attrs/src/op-attrs/ops/weight.cc
new file mode 100644
index 0000000000..f8b6b7ec49
--- /dev/null
+++ b/lib/op-attrs/src/op-attrs/ops/weight.cc
@@ -0,0 +1,14 @@
+#include "op-attrs/ops/weight.h"
+#include "op-attrs/parallel_tensor_shape.h"
+
+namespace FlexFlow {
+
+TensorShape get_output_shape(WeightAttrs const &attrs) {
+ return attrs.tensor_shape;
+}
+
+ParallelTensorShape get_output_parallel_tensor_shape(WeightAttrs const &attrs) {
+ return lift_to_parallel(attrs.tensor_shape);
+}
+
+} // namespace FlexFlow
diff --git a/lib/op-attrs/src/op-attrs/parallel_op_attrs.cc b/lib/op-attrs/src/op-attrs/parallel_op_attrs.cc
new file mode 100644
index 0000000000..c458d4149d
--- /dev/null
+++ b/lib/op-attrs/src/op-attrs/parallel_op_attrs.cc
@@ -0,0 +1,37 @@
+#include "op-attrs/parallel_op_attrs.h"
+#include "op-attrs/ops/combine.h"
+#include "op-attrs/ops/reduction.h"
+#include "op-attrs/ops/repartition.h"
+#include "op-attrs/ops/replicate.h"
+#include "utils/overload.h"
+
+namespace FlexFlow {
+
+ParallelTensorShape get_output_shape(ParallelOpAttrs const &attrs,
+ ParallelTensorShape const &input_shape) {
+  return attrs.visit<ParallelTensorShape>(overload{
+ [&](CombineAttrs const &combine_attrs) {
+ return throw_if_unexpected(
+ get_output_shape(combine_attrs, input_shape));
+ },
+ [&](ReductionAttrs const &reduction_attrs) {
+ return throw_if_unexpected(
+ get_output_shape(reduction_attrs, input_shape));
+ },
+ [&](RepartitionAttrs const &repartition_attrs) {
+ return throw_if_unexpected(
+ get_output_shape(repartition_attrs, input_shape));
+ },
+ [&](ReplicateAttrs const &replicate_attrs) {
+ return get_output_shape(replicate_attrs, input_shape);
+ },
+ });
+}
+
+PCGOperatorAttrs
+ pcg_op_attrs_from_parallel_op_attrs(ParallelOpAttrs const &attrs) {
+  return attrs.visit<PCGOperatorAttrs>(
+ [](auto const &attrs) { return PCGOperatorAttrs{attrs}; });
+}
+
+} // namespace FlexFlow
diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc
index 150fb6a76d..10bf5027a4 100644
--- a/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc
+++ b/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc
@@ -1,4 +1,5 @@
#include "op-attrs/parallel_tensor_shape.h"
+#include "op-attrs/parallel_tensor_dims.h"
#include "op-attrs/tensor_dims.h"
#include "utils/containers/product.h"
#include "utils/containers/transform.h"
@@ -74,6 +75,19 @@ ParallelTensorShape
};
}
+TensorShape require_not_parallel(ParallelTensorShape const &s) {
+ int total_degree = get_total_parallel_degree(s);
+ if (total_degree != 1) {
+ throw mk_runtime_error(
+ fmt::format("Error: require_not_parallel received a parallel tensor "
+ "shape with parallel degree {}: {}",
+ total_degree,
+ s));
+ }
+
+ return get_reduced_shape(s);
+}
+
TensorShape get_tensor_shape_unsafe(ParallelTensorShape const &) {
NOT_IMPLEMENTED();
}
diff --git a/lib/op-attrs/src/op-attrs/pcg_operator_attrs.cc b/lib/op-attrs/src/op-attrs/pcg_operator_attrs.cc
index 74882fe9f2..0bb134da6b 100644
--- a/lib/op-attrs/src/op-attrs/pcg_operator_attrs.cc
+++ b/lib/op-attrs/src/op-attrs/pcg_operator_attrs.cc
@@ -1,5 +1,6 @@
#include "op-attrs/pcg_operator_attrs.h"
#include "op-attrs/get_op_type.h"
+#include "op-attrs/ops/linear.h"
#include "utils/overload.h"
namespace FlexFlow {
@@ -64,4 +65,15 @@ ComputationGraphOpAttrs
});
}
+RecordFormatter as_dot(PCGOperatorAttrs const &attrs) {
+  return attrs.visit<RecordFormatter>(overload{
+ [](LinearAttrs const &l) { return as_dot(l); },
+ [&](auto const &) {
+ RecordFormatter r;
+ r << fmt::to_string(get_op_type(attrs));
+ return r;
+ },
+ });
+}
+
} // namespace FlexFlow
diff --git a/lib/op-attrs/src/op_attrs.cc b/lib/op-attrs/src/op_attrs.cc
deleted file mode 100644
index 6125c03a59..0000000000
--- a/lib/op-attrs/src/op_attrs.cc
+++ /dev/null
@@ -1,10 +0,0 @@
-/* #include "op-attrs/ops/op_attrs.h" */
-
-/* namespace FlexFlow { */
-
-/* int OpAttrsInterface::num_outputs(std::vector const
- * &inputs) const { */
-/* return this->output_shapes(inputs).size(); */
-/* } */
-
-/* } */
diff --git a/lib/op-attrs/test/src/op-attrs/dim_ordered/enumerate.cc b/lib/op-attrs/test/src/op-attrs/dim_ordered/enumerate.cc
new file mode 100644
index 0000000000..d2c758a05f
--- /dev/null
+++ b/lib/op-attrs/test/src/op-attrs/dim_ordered/enumerate.cc
@@ -0,0 +1,20 @@
+#include "op-attrs/dim_ordered/enumerate.h"
+#include "utils/fmt/map.h"
+#include <doctest/doctest.h>
+
+using namespace ::FlexFlow;
+
+TEST_SUITE(FF_TEST_SUITE) {
+ TEST_CASE("enumerate(FFOrdered)") {
+    FFOrdered<std::string> input = {"zero", "one", "two"};
+
+    std::map<ff_dim_t, std::string> result = enumerate(input);
+    std::map<ff_dim_t, std::string> correct = {
+ {ff_dim_t{0}, "zero"},
+ {ff_dim_t{1}, "one"},
+ {ff_dim_t{2}, "two"},
+ };
+
+ CHECK(result == correct);
+ }
+}
diff --git a/lib/op-attrs/test/src/op-attrs/ops/layer_norm.cc b/lib/op-attrs/test/src/op-attrs/ops/layer_norm.cc
index 8f5f4054d6..b9dd66df5d 100644
--- a/lib/op-attrs/test/src/op-attrs/ops/layer_norm.cc
+++ b/lib/op-attrs/test/src/op-attrs/ops/layer_norm.cc
@@ -1,4 +1,5 @@
#include "op-attrs/ops/layer_norm.h"
+#include "op-attrs/parallel_tensor_shape.h"
#include "test/utils/doctest.h"
#include "utils/expected.h"
#include "utils/fmt/expected.h"
diff --git a/lib/op-attrs/test/src/ops/cast.cc b/lib/op-attrs/test/src/ops/cast.cc
index 086d25d042..31030ca0f9 100644
--- a/lib/op-attrs/test/src/ops/cast.cc
+++ b/lib/op-attrs/test/src/ops/cast.cc
@@ -1,4 +1,5 @@
#include "op-attrs/ops/cast.h"
+#include "op-attrs/parallel_tensor_shape.h"
#include "test/utils/doctest.h"
TEST_SUITE(FF_TEST_SUITE) {
diff --git a/lib/pcg/include/pcg/computation_graph_builder.h b/lib/pcg/include/pcg/computation_graph_builder.h
index 0ca330408e..c641aed6a4 100644
--- a/lib/pcg/include/pcg/computation_graph_builder.h
+++ b/lib/pcg/include/pcg/computation_graph_builder.h
@@ -114,11 +114,10 @@ struct ComputationGraphBuilder {
      std::optional<InitializerAttrs> const &kernel_initializer = std::nullopt,
      std::optional<std::string> const &name = std::nullopt);
// Add a gather layer
-  std::vector<tensor_guid_t>
-      gather(tensor_guid_t const &input,
-             tensor_guid_t const &index,
-             ff_dim_t dim,
-             std::optional<std::string> const &name = std::nullopt);
+  tensor_guid_t gather(tensor_guid_t const &input,
+                       tensor_guid_t const &index,
+                       ff_dim_t dim,
+                       std::optional<std::string> const &name = std::nullopt);
// Add a cache layer
tensor_guid_t
cache(tensor_guid_t const &input,
diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_dataflow_graph.h b/lib/pcg/include/pcg/file_format/v1/graphs/v1_dataflow_graph.h
index 0e547e7688..05c486f0f7 100644
--- a/lib/pcg/include/pcg/file_format/v1/graphs/v1_dataflow_graph.h
+++ b/lib/pcg/include/pcg/file_format/v1/graphs/v1_dataflow_graph.h
@@ -8,7 +8,7 @@ namespace FlexFlow {
V1DataflowGraph to_v1(DataflowGraphView const &);
V1DataflowGraph to_v1(DataflowGraphView const &,
-                       std::unordered_map<Node, size_t> const &);
+                       std::unordered_map<Node, int> const &);
} // namespace FlexFlow
diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_dataflow_graph.struct.toml b/lib/pcg/include/pcg/file_format/v1/graphs/v1_dataflow_graph.struct.toml
index dc9dc96f29..d9aade739c 100644
--- a/lib/pcg/include/pcg/file_format/v1/graphs/v1_dataflow_graph.struct.toml
+++ b/lib/pcg/include/pcg/file_format/v1/graphs/v1_dataflow_graph.struct.toml
@@ -19,7 +19,7 @@ includes = [
[[fields]]
name = "nodes"
-type = "std::vector<size_t>"
+type = "std::vector<int>"
[[fields]]
name = "edges"
diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_graph_edge.struct.toml b/lib/pcg/include/pcg/file_format/v1/graphs/v1_graph_edge.struct.toml
index b0d2546977..752706fe1d 100644
--- a/lib/pcg/include/pcg/file_format/v1/graphs/v1_graph_edge.struct.toml
+++ b/lib/pcg/include/pcg/file_format/v1/graphs/v1_graph_edge.struct.toml
@@ -11,16 +11,16 @@ features = [
[[fields]]
name = "srcNode"
-type = "size_t"
+type = "int"
[[fields]]
name = "srcIdx"
-type = "size_t"
+type = "int"
[[fields]]
name = "dstNode"
-type = "size_t"
+type = "int"
[[fields]]
name = "dstIdx"
-type = "size_t"
+type = "int"
diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.h b/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.h
index b1f96c513b..48203d73ae 100644
--- a/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.h
+++ b/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.h
@@ -3,7 +3,7 @@
#include "pcg/file_format/v1/graphs/v1_dataflow_graph.h"
#include "pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.dtg.h"
-#include "utils/containers/enumerate.h"
+#include "utils/bidict/algorithms/bidict_from_enumerating.h"
#include "utils/containers/map_values.h"
#include "utils/containers/transform.h"
#include "utils/graph/dataflow_graph/algorithms.h"
@@ -16,14 +16,14 @@ template
V1LabelledDataflowGraph<NodeLabel, OutputLabel>
    to_v1(LabelledDataflowGraphView<NodeLabel, OutputLabel> const &g) {
-  bidict<size_t, Node> nodes = enumerate(get_nodes(g));
+  bidict<int, Node> nodes = bidict_from_enumerating(get_nodes(g));
  V1DataflowGraph unlabelled = to_v1(g, nodes.reversed());
-  std::unordered_map<size_t, NodeLabel> node_labels = map_values(
+  std::unordered_map<int, NodeLabel> node_labels = map_values(
      nodes.as_unordered_map(), [&](Node const &n) { return g.at(n); });
-  std::unordered_map<size_t, std::vector<OutputLabel>> output_labels =
+  std::unordered_map<int, std::vector<OutputLabel>> output_labels =
map_values(nodes.as_unordered_map(), [&](Node const &n) {
return transform(get_outputs(g, n),
[&](DataflowOutput const &o) { return g.at(o); });
diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.struct.toml b/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.struct.toml
index 0a6a148159..fd8d4c39c4 100644
--- a/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.struct.toml
+++ b/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.struct.toml
@@ -22,11 +22,11 @@ includes = [
[[fields]]
name = "node_labels"
-type = "std::unordered_map<size_t, NodeLabel>"
+type = "std::unordered_map<int, NodeLabel>"
[[fields]]
name = "output_labels"
-type = "std::unordered_map<size_t, std::vector<OutputLabel>>"
+type = "std::unordered_map<int, std::vector<OutputLabel>>"
[[fields]]
name = "graph"
diff --git a/lib/pcg/include/pcg/parallel_computation_graph/generate_weight_transform.h b/lib/pcg/include/pcg/parallel_computation_graph/generate_weight_transform.h
new file mode 100644
index 0000000000..eb4928deaa
--- /dev/null
+++ b/lib/pcg/include/pcg/parallel_computation_graph/generate_weight_transform.h
@@ -0,0 +1,16 @@
+#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_GENERATE_WEIGHT_TRANSFORM_H
+#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_GENERATE_WEIGHT_TRANSFORM_H
+
+#include "op-attrs/parallel_op_attrs.dtg.h"
+#include "op-attrs/parallel_tensor_shape.dtg.h"
+#include "op-attrs/tensor_shape.dtg.h"
+
+namespace FlexFlow {
+
+std::unordered_set<ParallelOpAttrs>
+    generate_weight_transform(TensorShape const &current,
+                              ParallelTensorShape const &goal);
+
+} // namespace FlexFlow
+
+#endif
diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h
index 4caaad06b2..9150681070 100644
--- a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h
+++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h
@@ -28,9 +28,6 @@ std::vector
get_layer_outputs(ParallelComputationGraph const &,
parallel_layer_guid_t const &);
-parallel_layer_guid_t get_source_layer(ParallelComputationGraph const &,
- parallel_tensor_guid_t const &);
-
ParallelLayerAttrs get_parallel_layer_attrs(ParallelComputationGraph const &,
parallel_layer_guid_t const &);
ParallelTensorAttrs get_parallel_tensor_attrs(ParallelComputationGraph const &,
diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_builder.h b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_builder.h
index 5b34ee641a..20e947ad58 100644
--- a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_builder.h
+++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_builder.h
@@ -84,10 +84,35 @@ struct ParallelComputationGraphBuilder {
std::optional output_bias_initializer = std::nullopt,
std::optional const &name = std::nullopt);
+  parallel_tensor_guid_t
+      batch_norm(parallel_tensor_guid_t const &input,
+                 bool relu = true,
+                 std::optional<std::string> const &name = std::nullopt);
+
   parallel_tensor_guid_t
       relu(parallel_tensor_guid_t const &x,
            std::optional<std::string> const &name = std::nullopt);
+  parallel_tensor_guid_t
+      identity(parallel_tensor_guid_t const &x,
+               std::optional<std::string> const &name = std::nullopt);
+
+  parallel_tensor_guid_t
+      gelu(parallel_tensor_guid_t const &x,
+           std::optional<std::string> const &name = std::nullopt);
+
+  parallel_tensor_guid_t
+      sigmoid(parallel_tensor_guid_t const &x,
+              std::optional<std::string> const &name = std::nullopt);
+
+  parallel_tensor_guid_t
+      tanh(parallel_tensor_guid_t const &x,
+           std::optional<std::string> const &name = std::nullopt);
+
+  parallel_tensor_guid_t
+      elu(parallel_tensor_guid_t const &x,
+          std::optional<std::string> const &name = std::nullopt);
+
parallel_tensor_guid_t
parallel_partition(parallel_tensor_guid_t const &x,
ff_dim_t dim,
@@ -137,6 +162,15 @@ struct ParallelComputationGraphBuilder {
std::vector const &weights,
ParallelTensorShape const &output);
+  parallel_tensor_guid_t
+      add_weight(ParallelTensorAttrs const &weight_tensor_attrs,
+                 std::optional<std::string> const &name = std::nullopt);
+
+  parallel_tensor_guid_t
+      element_unary(ElementUnaryAttrs const &element_unary_attrs,
+                    parallel_tensor_guid_t const &input,
+                    std::optional<std::string> const &name);
+
public:
ParallelComputationGraph pcg;
};
diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_edge.h b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_edge.h
new file mode 100644
index 0000000000..7aac8558e4
--- /dev/null
+++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_edge.h
@@ -0,0 +1,18 @@
+#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_PARALLEL_COMPUTATION_GRAPH_EDGE_H
+#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_PARALLEL_COMPUTATION_GRAPH_EDGE_H
+
+#include "pcg/parallel_computation_graph/parallel_computation_graph_edge.dtg.h"
+#include "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h"
+#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h"
+
+namespace FlexFlow {
+
+parallel_tensor_guid_t
+ get_parallel_tensor(ParallelComputationGraphEdge const &);
+parallel_layer_guid_t get_src_layer(ParallelComputationGraphEdge const &);
+parallel_layer_guid_t get_dst_layer(ParallelComputationGraphEdge const &);
+int get_dst_layer_input_idx(ParallelComputationGraphEdge const &);
+
+} // namespace FlexFlow
+
+#endif
diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_edge.struct.toml b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_edge.struct.toml
new file mode 100644
index 0000000000..25ef3f5d27
--- /dev/null
+++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_edge.struct.toml
@@ -0,0 +1,16 @@
+namespace = "FlexFlow"
+name = "ParallelComputationGraphEdge"
+features = [
+ "eq",
+ "ord",
+ "hash",
+ "fmt",
+]
+
+includes = [
+ "utils/graph/dataflow_graph/dataflow_edge.dtg.h",
+]
+
+[[fields]]
+name = "raw_edge"
+type = "::FlexFlow::DataflowEdge"
diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_guid_t.h b/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_guid_t.h
new file mode 100644
index 0000000000..905a365b4b
--- /dev/null
+++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_guid_t.h
@@ -0,0 +1,13 @@
+#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_PARALLEL_TENSOR_GUID_T_H
+#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_PARALLEL_TENSOR_GUID_T_H
+
+#include "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h"
+#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h"
+
+namespace FlexFlow {
+
+parallel_layer_guid_t get_source_layer(parallel_tensor_guid_t const &);
+
+} // namespace FlexFlow
+
+#endif
diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_use_t.struct.toml b/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_use_t.struct.toml
new file mode 100644
index 0000000000..6d5e007650
--- /dev/null
+++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_use_t.struct.toml
@@ -0,0 +1,16 @@
+namespace = "FlexFlow"
+name = "parallel_tensor_use_t"
+features = [
+ "eq",
+ "ord",
+ "hash",
+ "fmt",
+]
+
+includes = [
+ "utils/graph/dataflow_graph/dataflow_input.dtg.h",
+]
+
+[[fields]]
+name = "raw_dataflow_input"
+type = "::FlexFlow::DataflowInput"
diff --git a/lib/pcg/src/pcg/computation_graph.cc b/lib/pcg/src/pcg/computation_graph.cc
index e7f5f2b737..deaa440ef8 100644
--- a/lib/pcg/src/pcg/computation_graph.cc
+++ b/lib/pcg/src/pcg/computation_graph.cc
@@ -53,7 +53,7 @@ std::vector get_outgoing_tensors(ComputationGraph const &cg,
std::vector get_incoming_tensors(ComputationGraph const &cg,
layer_guid_t n) {
- return transform(get_inputs(cg.raw_graph, n.raw_node),
+ return transform(get_input_values(cg.raw_graph, n.raw_node),
[](DataflowOutput const &o) { return tensor_guid_t{o}; });
}
diff --git a/lib/pcg/src/pcg/computation_graph_builder.cc b/lib/pcg/src/pcg/computation_graph_builder.cc
index b6d0e7c890..3f2feaf619 100644
--- a/lib/pcg/src/pcg/computation_graph_builder.cc
+++ b/lib/pcg/src/pcg/computation_graph_builder.cc
@@ -2,13 +2,24 @@
#include "op-attrs/computation_graph_op_attrs.h"
#include "op-attrs/get_op_type.h"
#include "op-attrs/get_output_shapes.h"
+#include "op-attrs/ops/attention.h"
+#include "op-attrs/ops/batch_norm.h"
+#include "op-attrs/ops/broadcast.h"
+#include "op-attrs/ops/conv_2d.h"
+#include "op-attrs/ops/dropout.h"
#include "op-attrs/ops/element_binary.h"
+#include "op-attrs/ops/element_unary.h"
#include "op-attrs/ops/embedding.h"
+#include "op-attrs/ops/gather.h"
+#include "op-attrs/ops/layer_norm.h"
+#include "op-attrs/ops/linear.h"
+#include "op-attrs/ops/softmax.h"
#include "op-attrs/ops/weight_attrs.dtg.h"
#include "pcg/computation_graph.h"
#include "utils/containers/any_of.h"
#include "utils/containers/concat_vectors.h"
#include "utils/containers/enumerate_vector.h"
+#include "utils/containers/get_only.h"
#include "utils/containers/transform.h"
#include "utils/expected.h"
#include
@@ -49,7 +60,7 @@ std::vector ComputationGraphBuilder::add_layer(
return fmt::format("{}.weights[{}]", layer_name, weight_idx);
});
LayerAttrs weight_layer_attrs = LayerAttrs{
- ComputationGraphOpAttrs{WeightAttrs{}},
+ ComputationGraphOpAttrs{WeightAttrs{weight_tensor_attrs.shape}},
weight_name,
};
std::vector weight_layer_inputs = {};
@@ -451,7 +462,7 @@ tensor_guid_t ComputationGraphBuilder::embedding(
return this->add_layer(layer, {input}, {weight_attrs}, output_shape);
}
-std::vector ComputationGraphBuilder::gather(
+tensor_guid_t ComputationGraphBuilder::gather(
tensor_guid_t const &input,
tensor_guid_t const &index,
ff_dim_t dim,
@@ -469,10 +480,10 @@ std::vector ComputationGraphBuilder::gather(
DataType::INT32,
DataType::INT64);
}
- std::vector output_shapes =
- get_output_shapes(attrs, this->get_shape(input), this->get_shape(index));
+ TensorShape output_shape =
+ get_output_shape(attrs, this->get_shape(input), this->get_shape(index));
- return this->add_layer(layer, {input}, {}, output_shapes);
+ return this->add_layer(layer, {input}, {}, output_shape);
}
/* std::vector
diff --git a/lib/pcg/src/pcg/file_format/v1/graphs/v1_dataflow_graph.cc b/lib/pcg/src/pcg/file_format/v1/graphs/v1_dataflow_graph.cc
index 787ce5bf7d..cf150a339f 100644
--- a/lib/pcg/src/pcg/file_format/v1/graphs/v1_dataflow_graph.cc
+++ b/lib/pcg/src/pcg/file_format/v1/graphs/v1_dataflow_graph.cc
@@ -1,4 +1,5 @@
#include "pcg/file_format/v1/graphs/v1_dataflow_graph.h"
+#include "utils/bidict/algorithms/bidict_from_enumerating.h"
#include "utils/containers/enumerate.h"
#include "utils/containers/sorted.h"
#include "utils/containers/values.h"
@@ -9,17 +10,19 @@
namespace FlexFlow {
V1DataflowGraph to_v1(DataflowGraphView const &g) {
- return to_v1(g, enumerate(get_nodes(g)).reversed());
+ bidict node_enumeration_bidict =
+ bidict_from_enumerating(get_nodes(g));
+ std::unordered_map node_enumeration =
+ node_enumeration_bidict.reversed().as_unordered_map();
+ return to_v1(g, node_enumeration);
}
V1DataflowGraph to_v1(DataflowGraphView const &g,
- std::unordered_map const &nodes) {
+ std::unordered_map const &nodes) {
std::unordered_set edges;
for (DataflowEdge const &e : get_edges(g)) {
- edges.insert(V1GraphEdge{nodes.at(e.src.node),
- size_t_from_int(e.src.idx),
- nodes.at(e.dst.node),
- size_t_from_int(e.dst.idx)});
+ edges.insert(V1GraphEdge{
+ nodes.at(e.src.node), e.src.idx, nodes.at(e.dst.node), e.dst.idx});
}
return V1DataflowGraph{
diff --git a/lib/pcg/src/pcg/parallel_computation_graph/generate_weight_transform.cc b/lib/pcg/src/pcg/parallel_computation_graph/generate_weight_transform.cc
new file mode 100644
index 0000000000..dadad6277f
--- /dev/null
+++ b/lib/pcg/src/pcg/parallel_computation_graph/generate_weight_transform.cc
@@ -0,0 +1,35 @@
+#include "pcg/parallel_computation_graph/generate_weight_transform.h"
+#include "op-attrs/dim_ordered/enumerate.h"
+#include "op-attrs/parallel_tensor_shape.h"
+
+namespace FlexFlow {
+
+std::unordered_set<ParallelOpAttrs>
+    generate_weight_transform(TensorShape const &current,
+                              ParallelTensorShape const &goal) {
+  std::unordered_set<ParallelOpAttrs> result;
+
+ int sum_degree = get_sum_degree(goal);
+ if (sum_degree != 1) {
+ throw mk_runtime_error(
+ fmt::format("generate_weight_transform currently only supports "
+ "sum_degree = 1, but received {}",
+ sum_degree));
+ }
+
+ int discard_copy_degree = get_discard_copy_degree(goal);
+ if (discard_copy_degree != 1) {
+ result.insert(ParallelOpAttrs{ReplicateAttrs{discard_copy_degree}});
+ }
+
+ for (auto const &[shard_dim, shard_degree] :
+ enumerate(ff_ordered_shard_degrees(goal))) {
+ if (shard_degree != 1) {
+ result.insert(ParallelOpAttrs{RepartitionAttrs{shard_dim, shard_degree}});
+ }
+ }
+
+ return result;
+}
+
+} // namespace FlexFlow
diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc
index 831287567d..5b178160cd 100644
--- a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc
+++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc
@@ -44,7 +44,7 @@ std::vector
get_layer_inputs(ParallelComputationGraph const &pcg,
parallel_layer_guid_t const &l) {
return transform(
- get_inputs(pcg.raw_graph, l.raw_graph_node),
+ get_input_values(pcg.raw_graph, l.raw_graph_node),
[](DataflowOutput const &o) { return parallel_tensor_guid_t{o}; });
}
diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc
index b632c984bc..8290a2ff94 100644
--- a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc
+++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc
@@ -1,6 +1,8 @@
#include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h"
#include "op-attrs/ops/weight_attrs.dtg.h"
+#include "op-attrs/parallel_op_attrs.h"
#include "op-attrs/pcg_operator_attrs.h"
+#include "pcg/parallel_computation_graph/generate_weight_transform.h"
#include "pcg/parallel_computation_graph/parallel_computation_graph.h"
#include "utils/containers/concat_vectors.h"
#include "utils/containers/enumerate_vector.h"
@@ -326,11 +328,28 @@ parallel_tensor_guid_t ParallelComputationGraphBuilder::multihead_attention(
return this->add_layer(layer, {query, key, value}, weights, output_shape);
}
-parallel_tensor_guid_t ParallelComputationGraphBuilder::relu(
+parallel_tensor_guid_t ParallelComputationGraphBuilder::batch_norm(
parallel_tensor_guid_t const &input,
+ bool relu,
     std::optional<std::string> const &maybe_name) {
- ElementUnaryAttrs attrs = ElementUnaryAttrs{OperatorType::RELU, std::nullopt};
+ BatchNormAttrs attrs = BatchNormAttrs{relu};
+
+ std::string name =
+ maybe_name.value_or(get_default_name(PCGOperatorAttrs{attrs}));
+
+ ParallelLayerAttrs layer = ParallelLayerAttrs{PCGOperatorAttrs{attrs}, name};
+
+ ParallelTensorShape output_shape =
+ get_output_shape(attrs, this->get_shape(input));
+
+ return this->add_layer(layer, {input}, {}, {output_shape});
+}
+
+parallel_tensor_guid_t ParallelComputationGraphBuilder::element_unary(
+    ElementUnaryAttrs const &attrs,
+    parallel_tensor_guid_t const &input,
+    std::optional<std::string> const &maybe_name) {
std::string name =
maybe_name.value_or(get_default_name(PCGOperatorAttrs{attrs}));
@@ -343,6 +362,78 @@ parallel_tensor_guid_t ParallelComputationGraphBuilder::relu(
return this->add_layer(layer, {input}, {}, {output_shape});
}
+parallel_tensor_guid_t ParallelComputationGraphBuilder::relu(
+ parallel_tensor_guid_t const &input,
+    std::optional<std::string> const &maybe_name) {
+
+ ElementUnaryAttrs attrs = ElementUnaryAttrs{
+ OperatorType::RELU,
+ std::nullopt,
+ };
+
+ return this->element_unary(attrs, input, maybe_name);
+}
+
+parallel_tensor_guid_t ParallelComputationGraphBuilder::identity(
+ parallel_tensor_guid_t const &input,
+    std::optional<std::string> const &maybe_name) {
+
+ ElementUnaryAttrs attrs = ElementUnaryAttrs{
+ OperatorType::IDENTITY,
+ std::nullopt,
+ };
+
+ return this->element_unary(attrs, input, maybe_name);
+}
+
+parallel_tensor_guid_t ParallelComputationGraphBuilder::gelu(
+ parallel_tensor_guid_t const &input,
+    std::optional<std::string> const &maybe_name) {
+
+ ElementUnaryAttrs attrs = ElementUnaryAttrs{
+ OperatorType::GELU,
+ std::nullopt,
+ };
+
+ return this->element_unary(attrs, input, maybe_name);
+}
+
+parallel_tensor_guid_t ParallelComputationGraphBuilder::sigmoid(
+ parallel_tensor_guid_t const &input,
+    std::optional<std::string> const &maybe_name) {
+
+ ElementUnaryAttrs attrs = ElementUnaryAttrs{
+ OperatorType::SIGMOID,
+ std::nullopt,
+ };
+
+ return this->element_unary(attrs, input, maybe_name);
+}
+
+parallel_tensor_guid_t ParallelComputationGraphBuilder::tanh(
+ parallel_tensor_guid_t const &input,
+    std::optional<std::string> const &maybe_name) {
+
+ ElementUnaryAttrs attrs = ElementUnaryAttrs{
+ OperatorType::TANH,
+ std::nullopt,
+ };
+
+ return this->element_unary(attrs, input, maybe_name);
+}
+
+parallel_tensor_guid_t ParallelComputationGraphBuilder::elu(
+ parallel_tensor_guid_t const &input,
+    std::optional<std::string> const &maybe_name) {
+
+ ElementUnaryAttrs attrs = ElementUnaryAttrs{
+ OperatorType::ELU,
+ std::nullopt,
+ };
+
+ return this->element_unary(attrs, input, maybe_name);
+}
+
parallel_tensor_guid_t ParallelComputationGraphBuilder::parallel_partition(
parallel_tensor_guid_t const &input,
ff_dim_t dim,
@@ -441,6 +532,54 @@ ParallelTensorShape ParallelComputationGraphBuilder::get_shape(
return get_parallel_tensor_attrs(this->pcg, t).shape;
}
+parallel_tensor_guid_t ParallelComputationGraphBuilder::add_weight(
+ ParallelTensorAttrs const &weight_tensor_attrs,
+    std::optional<std::string> const &weight_name) {
+ ParallelTensorShape par_weight_shape = weight_tensor_attrs.shape;
+ TensorShape unpar_weight_shape = get_reduced_shape(weight_tensor_attrs.shape);
+
+ ParallelLayerAttrs weight_layer_attrs = ParallelLayerAttrs{
+ PCGOperatorAttrs{WeightAttrs{unpar_weight_shape}},
+ weight_name,
+ };
+
+  std::vector<DataflowOutput> weight_layer_inputs = {};
+  std::vector<ParallelTensorAttrs> weight_output_attrs = {weight_tensor_attrs};
+
+ DataflowOutput current_raw_weight_tensor = get_only(
+ this->pcg.raw_graph
+ .add_node(
+ weight_layer_attrs, weight_layer_inputs, weight_output_attrs)
+ .outputs);
+ ParallelTensorShape current_shape = lift_to_parallel(unpar_weight_shape);
+
+  for (ParallelOpAttrs const &parallel_op_attr :
+ generate_weight_transform(unpar_weight_shape, par_weight_shape)) {
+ ParallelTensorShape output_shape =
+ get_output_shape(parallel_op_attr, current_shape);
+ ParallelTensorAttrs output_attrs = ParallelTensorAttrs{
+ output_shape,
+ std::nullopt,
+ std::nullopt,
+ CreateGrad::YES,
+ };
+
+ ParallelLayerAttrs layer_attrs = ParallelLayerAttrs{
+ pcg_op_attrs_from_parallel_op_attrs(parallel_op_attr),
+ std::nullopt,
+ };
+ current_raw_weight_tensor = get_only(
+ this->pcg.raw_graph
+ .add_node(layer_attrs, {current_raw_weight_tensor}, {output_attrs})
+ .outputs);
+ current_shape = output_shape;
+ }
+
+ assert(current_shape == par_weight_shape);
+
+ return parallel_tensor_guid_t{current_raw_weight_tensor};
+}
+
std::vector ParallelComputationGraphBuilder::add_layer(
ParallelLayerAttrs const &layer,
std::vector const &inputs,
@@ -455,18 +594,9 @@ std::vector ParallelComputationGraphBuilder::add_layer(
transform(layer.name, [&](std::string const &layer_name) {
return fmt::format("{}.weights[{}]", layer_name, weight_idx);
});
- ParallelLayerAttrs weight_layer_attrs = ParallelLayerAttrs{
- PCGOperatorAttrs{WeightAttrs{}},
- weight_name,
- };
- std::vector weight_layer_inputs = {};
- std::vector weight_output_attrs = {
- weight_tensor_attrs};
- raw_weight_tensors.push_back(get_only(this->pcg.raw_graph
- .add_node(weight_layer_attrs,
- weight_layer_inputs,
- weight_output_attrs)
- .outputs));
+
+ raw_weight_tensors.push_back(
+ this->add_weight(weight_tensor_attrs, weight_name).raw_graph_output);
}
std::vector raw_inputs =
diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_edge.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_edge.cc
new file mode 100644
index 0000000000..dca8154eb4
--- /dev/null
+++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_edge.cc
@@ -0,0 +1,22 @@
+#include "pcg/parallel_computation_graph/parallel_computation_graph_edge.h"
+
+namespace FlexFlow {
+
+parallel_tensor_guid_t
+ get_parallel_tensor(ParallelComputationGraphEdge const &e) {
+ return parallel_tensor_guid_t{e.raw_edge.src};
+}
+
+parallel_layer_guid_t get_src_layer(ParallelComputationGraphEdge const &e) {
+ return parallel_layer_guid_t{e.raw_edge.src.node};
+}
+
+parallel_layer_guid_t get_dst_layer(ParallelComputationGraphEdge const &e) {
+ return parallel_layer_guid_t{e.raw_edge.dst.node};
+}
+
+int get_dst_layer_input_idx(ParallelComputationGraphEdge const &e) {
+ return e.raw_edge.dst.idx;
+}
+
+} // namespace FlexFlow
diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_tensor_guid_t.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_tensor_guid_t.cc
new file mode 100644
index 0000000000..ad4eae041f
--- /dev/null
+++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_tensor_guid_t.cc
@@ -0,0 +1,9 @@
+#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.h"
+
+namespace FlexFlow {
+
+parallel_layer_guid_t get_source_layer(parallel_tensor_guid_t const &t) {
+ return parallel_layer_guid_t{t.raw_graph_output.node};
+}
+
+} // namespace FlexFlow
diff --git a/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc b/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc
index db01728cf0..440f735e80 100644
--- a/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc
+++ b/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc
@@ -2,6 +2,7 @@
#include "op-attrs/parallel_tensor_shape.h"
#include "pcg/parallel_computation_graph/parallel_computation_graph.h"
#include "pcg/parallel_computation_graph/parallel_layer_attrs.h"
+#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.h"
#include "test/utils/doctest.h"
#include "utils/containers/count.h"
#include "utils/containers/generate_map.h"
@@ -39,7 +40,7 @@ TEST_SUITE(FF_TEST_SUITE) {
parallel_tensor_guid_t rhs = b.create_input_tensor(rhs_shape);
parallel_tensor_guid_t out = b.add(lhs, rhs);
- parallel_layer_guid_t layer = get_source_layer(b.pcg, out);
+ parallel_layer_guid_t layer = get_source_layer(out);
SUBCASE("inputs") {
std::vector result =
@@ -102,7 +103,7 @@ TEST_SUITE(FF_TEST_SUITE) {
parallel_tensor_guid_t b_tensor = b.create_input_tensor(b_shape);
parallel_tensor_guid_t out = b.batch_matmul(a_tensor, b_tensor);
- parallel_layer_guid_t layer = get_source_layer(b.pcg, out);
+ parallel_layer_guid_t layer = get_source_layer(out);
SUBCASE("inputs") {
std::vector result =
@@ -145,7 +146,7 @@ TEST_SUITE(FF_TEST_SUITE) {
DataType output_datatype = DataType::DOUBLE;
parallel_tensor_guid_t input = b.create_input_tensor(input_shape);
parallel_tensor_guid_t output = b.cast(input, output_datatype);
- parallel_layer_guid_t layer = get_source_layer(b.pcg, output);
+ parallel_layer_guid_t layer = get_source_layer(output);
SUBCASE("inputs") {
std::vector result =
@@ -205,7 +206,7 @@ TEST_SUITE(FF_TEST_SUITE) {
[&](parallel_layer_guid_t const &l) {
return get_parallel_layer_attrs(b.pcg, l);
});
- CHECK_MESSAGE(layers.size() == 4, "Incorrect layers ", layers);
+ CHECK_MESSAGE(layers.size() == 6, "Incorrect layers ", layers);
auto num_attrs_of_type = [&](OperatorType op_type) -> int {
return count(values(layers), [&](ParallelLayerAttrs const &l) {
@@ -222,6 +223,9 @@ TEST_SUITE(FF_TEST_SUITE) {
int num_conv_attrs = num_attrs_of_type(OperatorType::CONV2D);
CHECK(num_conv_attrs == 1);
+ int num_replicate_attrs = num_attrs_of_type(OperatorType::REPLICATE);
+ CHECK(num_replicate_attrs == 2);
+
parallel_layer_guid_t conv_guid = get_only(without_nullopts(transform(
as_vector(items(layers)),
[](std::pair const &kv)
@@ -307,7 +311,7 @@ TEST_SUITE(FF_TEST_SUITE) {
Activation::RELU,
/*use_bias=*/true,
DataType::FLOAT);
- parallel_layer_guid_t layer = get_source_layer(b.pcg, output);
+ parallel_layer_guid_t layer = get_source_layer(output);
SUBCASE("inputs") {
std::vector result =
@@ -350,7 +354,7 @@ TEST_SUITE(FF_TEST_SUITE) {
/*outDim=*/8,
AggregateOp::SUM,
DataType::FLOAT);
- parallel_layer_guid_t layer = get_source_layer(b.pcg, output);
+ parallel_layer_guid_t layer = get_source_layer(output);
SUBCASE("inputs") {
std::vector result =
@@ -400,7 +404,7 @@ TEST_SUITE(FF_TEST_SUITE) {
parallel_tensor_guid_t value = b.create_input_tensor(value_shape);
parallel_tensor_guid_t output =
b.multihead_attention(query, key, value, embed_dim, num_heads);
- parallel_layer_guid_t layer = get_source_layer(b.pcg, output);
+ parallel_layer_guid_t layer = get_source_layer(output);
SUBCASE("inputs") {
std::vector result =
@@ -441,7 +445,7 @@ TEST_SUITE(FF_TEST_SUITE) {
parallel_tensor_guid_t input = b.create_input_tensor(input_shape);
parallel_tensor_guid_t output = b.relu(input);
- parallel_layer_guid_t layer = get_source_layer(b.pcg, output);
+ parallel_layer_guid_t layer = get_source_layer(output);
SUBCASE("inputs") {
std::vector result =
@@ -480,7 +484,7 @@ TEST_SUITE(FF_TEST_SUITE) {
parallel_tensor_guid_t input = b.create_input_tensor(input_shape);
parallel_tensor_guid_t output = b.parallel_partition(input, ff_dim_t{0}, 2);
- parallel_layer_guid_t layer = get_source_layer(b.pcg, output);
+ parallel_layer_guid_t layer = get_source_layer(output);
SUBCASE("inputs") {
std::vector result =
@@ -519,7 +523,7 @@ TEST_SUITE(FF_TEST_SUITE) {
parallel_tensor_guid_t input = b.create_input_tensor(input_shape);
parallel_tensor_guid_t output = b.parallel_combine(input, ff_dim_t{0}, 2);
- parallel_layer_guid_t layer = get_source_layer(b.pcg, output);
+ parallel_layer_guid_t layer = get_source_layer(output);
SUBCASE("inputs") {
std::vector result =
@@ -558,7 +562,7 @@ TEST_SUITE(FF_TEST_SUITE) {
parallel_tensor_guid_t input = b.create_input_tensor(input_shape);
parallel_tensor_guid_t output = b.parallel_replicate(input, 2);
- parallel_layer_guid_t layer = get_source_layer(b.pcg, output);
+ parallel_layer_guid_t layer = get_source_layer(output);
SUBCASE("inputs") {
std::vector result =
@@ -597,7 +601,7 @@ TEST_SUITE(FF_TEST_SUITE) {
parallel_tensor_guid_t input = b.create_input_tensor(input_shape);
parallel_tensor_guid_t output = b.parallel_reduce(input, 2);
- parallel_layer_guid_t layer = get_source_layer(b.pcg, output);
+ parallel_layer_guid_t layer = get_source_layer(output);
SUBCASE("inputs") {
std::vector result =
diff --git a/lib/substitution-generator/CMakeLists.txt b/lib/substitution-generator/CMakeLists.txt
index 41005e6a4e..1db0d888ba 100644
--- a/lib/substitution-generator/CMakeLists.txt
+++ b/lib/substitution-generator/CMakeLists.txt
@@ -11,6 +11,7 @@ ff_add_library(
utils
op-attrs
pcg
+ substitutions
)
# add_subdirectory(ffi)
diff --git a/lib/substitution-generator/include/substitution-generator/json.h b/lib/substitution-generator/include/substitution-generator/json.h
deleted file mode 100644
index 5563d8a835..0000000000
--- a/lib/substitution-generator/include/substitution-generator/json.h
+++ /dev/null
@@ -1,59 +0,0 @@
-#ifndef _FLEXFLOW_SUBSTITUTION_LOADER_H
-#define _FLEXFLOW_SUBSTITUTION_LOADER_H
-
-#include "substitution-generator/legacy_operator_type.dtg.h"
-#include "substitution-generator/legacy_pm_parameter.dtg.h"
-#include
-#include
-#include
-
-namespace FlexFlow {
-
-struct Parameter {
- LegacyPMParameter key;
- int value;
-};
-void from_json(nlohmann::json const &j, Parameter &p);
-
-struct Tensor {
- int opId;
- int tsId;
-};
-void from_json(nlohmann::json const &j, Tensor &t);
-
-struct Operator {
- LegacyOperatorType op_type;
-  std::vector<Tensor> input;
-  std::vector<Parameter> para;
-
- std::optional at(LegacyPMParameter key) const;
-};
-void from_json(nlohmann::json const &j, Operator &t);
-
-struct MapOutput {
- int dstOpId;
- int dstTsId;
- int srcOpId;
- int srcTsId;
-};
-void from_json(nlohmann::json const &j, MapOutput &t);
-
-struct Rule {
- std::string name;
-  std::vector<Operator> srcOp;
-  std::vector<Operator> dstOp;
-  std::vector<MapOutput> mappedOutput;
-};
-void from_json(nlohmann::json const &j, Rule &t);
-
-struct RuleCollection {
-  std::vector<Rule> rules;
-};
-void from_json(nlohmann::json const &j, RuleCollection &c);
-
-RuleCollection load_rule_collection(std::istream &s);
-RuleCollection load_rule_collection_from_path(std::string const &path);
-
-} // namespace FlexFlow
-
-#endif // _FLEXFLOW_SUBSTITUTION_LOADER_H
diff --git a/lib/substitution-generator/include/substitution-generator/legacy_rules.h b/lib/substitution-generator/include/substitution-generator/legacy_rules.h
new file mode 100644
index 0000000000..a0e0a9790a
--- /dev/null
+++ b/lib/substitution-generator/include/substitution-generator/legacy_rules.h
@@ -0,0 +1,59 @@
+#ifndef _FLEXFLOW_LIB_SUBSTITUTION_GENERATOR_INCLUDE_SUBSTITUTION_GENERATOR_LEGACY_RULES_H
+#define _FLEXFLOW_LIB_SUBSTITUTION_GENERATOR_INCLUDE_SUBSTITUTION_GENERATOR_LEGACY_RULES_H
+
+#include "substitution-generator/legacy_operator_type.dtg.h"
+#include "substitution-generator/legacy_pm_parameter.dtg.h"
+#include
+#include
+#include
+
+namespace FlexFlow {
+
+struct LegacyParameter {
+ LegacyPMParameter key;
+ int value;
+};
+void from_json(nlohmann::json const &j, LegacyParameter &p);
+
+struct LegacyTensor {
+ int opId;
+ int tsId;
+};
+void from_json(nlohmann::json const &j, LegacyTensor &t);
+
+struct LegacyOperator {
+ LegacyOperatorType op_type;
+  std::vector<LegacyTensor> input;
+  std::vector<LegacyParameter> para;
+
+ std::optional at(LegacyPMParameter key) const;
+};
+void from_json(nlohmann::json const &j, LegacyOperator &t);
+
+struct LegacyMapOutput {
+ int dstOpId;
+ int dstTsId;
+ int srcOpId;
+ int srcTsId;
+};
+void from_json(nlohmann::json const &j, LegacyMapOutput &t);
+
+struct LegacyRule {
+ std::string name;
+  std::vector<LegacyOperator> srcOp;
+  std::vector<LegacyOperator> dstOp;
+  std::vector<LegacyMapOutput> mappedOutput;
+};
+void from_json(nlohmann::json const &j, LegacyRule &t);
+
+struct LegacyRuleCollection {
+  std::vector<LegacyRule> rules;
+};
+void from_json(nlohmann::json const &j, LegacyRuleCollection &c);
+
+LegacyRuleCollection load_rule_collection(std::istream &s);
+LegacyRuleCollection load_rule_collection_from_path(std::string const &path);
+
+} // namespace FlexFlow
+
+#endif // _FLEXFLOW_SUBSTITUTION_LOADER_H
diff --git a/lib/substitution-generator/src/substitution-generator/json.cc b/lib/substitution-generator/src/substitution-generator/legacy_rules.cc
similarity index 67%
rename from lib/substitution-generator/src/substitution-generator/json.cc
rename to lib/substitution-generator/src/substitution-generator/legacy_rules.cc
index 940ecb3e36..157f062cbf 100644
--- a/lib/substitution-generator/src/substitution-generator/json.cc
+++ b/lib/substitution-generator/src/substitution-generator/legacy_rules.cc
@@ -1,4 +1,4 @@
-#include "substitution-generator/json.h"
+#include "substitution-generator/legacy_rules.h"
#include
#include
#include
@@ -7,12 +7,12 @@ using json = nlohmann::json;
namespace FlexFlow {
-void from_json(json const &j, Parameter &p) {
+void from_json(json const &j, LegacyParameter &p) {
j.at("key").get_to(p.key);
j.at("value").get_to(p.value);
}
-void from_json(json const &j, Tensor &t) {
+void from_json(json const &j, LegacyTensor &t) {
j.at("opId").get_to(t.opId);
j.at("tsId").get_to(t.tsId);
}
@@ -29,38 +29,38 @@ void from_json(json const &j, Tensor &t) {
/* return value; */
/* } */
-void from_json(json const &j, Operator &o) {
+void from_json(json const &j, LegacyOperator &o) {
j.at("type").get_to(o.op_type);
j.at("input").get_to(o.input);
j.at("para").get_to(o.para);
}
-void from_json(json const &j, MapOutput &m) {
+void from_json(json const &j, LegacyMapOutput &m) {
j.at("dstOpId").get_to(m.dstOpId);
j.at("dstTsId").get_to(m.dstTsId);
j.at("srcOpId").get_to(m.srcOpId);
j.at("srcTsId").get_to(m.srcTsId);
}
-void from_json(json const &j, Rule &r) {
+void from_json(json const &j, LegacyRule &r) {
j.at("name").get_to(r.name);
j.at("srcOp").get_to(r.srcOp);
j.at("dstOp").get_to(r.dstOp);
j.at("mappedOutput").get_to(r.mappedOutput);
}
-void from_json(json const &j, RuleCollection &c) {
+void from_json(json const &j, LegacyRuleCollection &c) {
j.at("rule").get_to(c.rules);
}
-RuleCollection load_rule_collection(std::istream &s) {
+LegacyRuleCollection load_rule_collection(std::istream &s) {
json j;
s >> j;
- RuleCollection rule_collection = j;
+ LegacyRuleCollection rule_collection = j;
return rule_collection;
}
-RuleCollection load_rule_collection_from_path(std::string const &path) {
+LegacyRuleCollection load_rule_collection_from_path(std::string const &path) {
std::ifstream input(path);
return load_rule_collection(input);
}
diff --git a/lib/substitution-generator/test/substitution-generator/json.cc b/lib/substitution-generator/test/substitution-generator/legacy_rules.cc
similarity index 88%
rename from lib/substitution-generator/test/substitution-generator/json.cc
rename to lib/substitution-generator/test/substitution-generator/legacy_rules.cc
index befdaf1308..4dd9bb8cc4 100644
--- a/lib/substitution-generator/test/substitution-generator/json.cc
+++ b/lib/substitution-generator/test/substitution-generator/legacy_rules.cc
@@ -1,4 +1,4 @@
-#include "substitution-generator/json.h"
+#include "substitution-generator/legacy_rules.h"
#include "doctest/doctest.h"
using namespace FlexFlow;
@@ -15,7 +15,7 @@ TEST_SUITE(FF_TEST_SUITE) {
{"type", "OP_EW_ADD"},
};
- Operator o;
+ LegacyOperator o;
from_json(j, o);
CHECK(o.op_type == LegacyOperatorType::EW_ADD);
@@ -28,7 +28,7 @@ TEST_SUITE(FF_TEST_SUITE) {
}
TEST_CASE("deserialize full file") {
- RuleCollection collection =
+ LegacyRuleCollection collection =
load_rule_collection_from_path("graph_subst_3_v2.json");
CHECK(collection.rules.size() == 640);
}
diff --git a/lib/substitutions/include/substitutions/input_parallel_tensor_guid_t.struct.toml b/lib/substitutions/include/substitutions/input_parallel_tensor_guid_t.struct.toml
new file mode 100644
index 0000000000..dd2e850aed
--- /dev/null
+++ b/lib/substitutions/include/substitutions/input_parallel_tensor_guid_t.struct.toml
@@ -0,0 +1,16 @@
+namespace = "FlexFlow"
+name = "input_parallel_tensor_guid_t"
+features = [
+ "eq",
+ "ord",
+ "hash",
+ "fmt",
+]
+
+includes = [
+ "utils/graph/open_dataflow_graph/dataflow_graph_input.dtg.h",
+]
+
+[[fields]]
+name = "raw_dataflow_graph_input"
+type = "::FlexFlow::DataflowGraphInput"
diff --git a/lib/substitutions/include/substitutions/open_parallel_tensor_guid_t.h b/lib/substitutions/include/substitutions/open_parallel_tensor_guid_t.h
new file mode 100644
index 0000000000..ad60d50db1
--- /dev/null
+++ b/lib/substitutions/include/substitutions/open_parallel_tensor_guid_t.h
@@ -0,0 +1,29 @@
+#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPEN_PARALLEL_TENSOR_GUID_T_H
+#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPEN_PARALLEL_TENSOR_GUID_T_H
+
+#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h"
+#include "substitutions/input_parallel_tensor_guid_t.dtg.h"
+#include "substitutions/open_parallel_tensor_guid_t.dtg.h"
+#include "utils/overload.h"
+
+namespace FlexFlow {
+
+open_parallel_tensor_guid_t
+ open_parallel_tensor_guid_from_closed(parallel_tensor_guid_t);
+open_parallel_tensor_guid_t
+ open_parallel_tensor_guid_from_input(input_parallel_tensor_guid_t);
+
+template <typename F,
+          typename Ret = std::invoke_result_t<F, parallel_tensor_guid_t>>
+Ret visit_open_parallel_tensor_guid(open_parallel_tensor_guid_t t, F f) {
+ return t.raw_open_dataflow_value.visit(overload{
+ [&](DataflowOutput const &o) { return f(parallel_tensor_guid_t{o}); },
+ [&](DataflowGraphInput const &i) {
+ return f(input_parallel_tensor_guid_t{i});
+ },
+ });
+}
+
+} // namespace FlexFlow
+
+#endif
diff --git a/lib/substitutions/include/substitutions/open_parallel_tensor_guid_t.struct.toml b/lib/substitutions/include/substitutions/open_parallel_tensor_guid_t.struct.toml
new file mode 100644
index 0000000000..f07dc12d62
--- /dev/null
+++ b/lib/substitutions/include/substitutions/open_parallel_tensor_guid_t.struct.toml
@@ -0,0 +1,16 @@
+namespace = "FlexFlow"
+name = "open_parallel_tensor_guid_t"
+features = [
+ "eq",
+ "ord",
+ "hash",
+ "fmt",
+]
+
+includes = [
+ "utils/graph/open_dataflow_graph/open_dataflow_value.dtg.h"
+]
+
+[[fields]]
+name = "raw_open_dataflow_value"
+type = "::FlexFlow::OpenDataflowValue"
diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_constraint.h b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_constraint.h
new file mode 100644
index 0000000000..4affdd697f
--- /dev/null
+++ b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_constraint.h
@@ -0,0 +1,18 @@
+#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_CONSTRAINT_H
+#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_CONSTRAINT_H
+
+#include "substitutions/operator_pattern/operator_attribute_constraint.dtg.h"
+
+namespace FlexFlow {
+
+OperatorAttributeConstraint op_type_equals_constraint(OperatorType);
+
+OperatorAttributeConstraint op_attr_key_equals(OperatorAttributeKey,
+ OperatorAttributeValue const &);
+OperatorAttributeConstraint
+ make_equals_constraint(OperatorAttributeExpr const &,
+ OperatorAttributeValue const &);
+
+} // namespace FlexFlow
+
+#endif
diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_expr.h b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_expr.h
index e63c03207b..a6324863a6 100644
--- a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_expr.h
+++ b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_expr.h
@@ -9,8 +9,8 @@
namespace FlexFlow {
std::optional<OperatorAttributeValue>
- evaluate_attribute_expr(PCGOperatorAttrs const &attrs,
- OperatorAttributeExpr const &expr);
+ evaluate_attribute_expr(OperatorAttributeExpr const &expr,
+ PCGOperatorAttrs const &attrs);
} // namespace FlexFlow
#endif
diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.variant.toml b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.variant.toml
index da2feb1903..7df65ef361 100644
--- a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.variant.toml
+++ b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.variant.toml
@@ -35,6 +35,9 @@ type = "int"
[[values]]
type = "bool"
+[[values]]
+type = "float"
+
[[values]]
type = "std::vector<int>"
@@ -45,7 +48,7 @@ type = "std::vector<::FlexFlow::ff_dim_t>"
type = "::FlexFlow::OperatorType"
[[values]]
-type = "::FlexFlow::Activation"
+type = "std::optional<::FlexFlow::Activation>"
[[values]]
type = "::FlexFlow::ff_dim_t"
diff --git a/lib/substitutions/include/substitutions/output_graph/materialize_operator_from_attrs_map.h b/lib/substitutions/include/substitutions/output_graph/materialize_operator_from_attrs_map.h
new file mode 100644
index 0000000000..cc2fac4805
--- /dev/null
+++ b/lib/substitutions/include/substitutions/output_graph/materialize_operator_from_attrs_map.h
@@ -0,0 +1,15 @@
+#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OUTPUT_GRAPH_MATERIALIZE_OPERATOR_FROM_ATTRS_MAP_H
+#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OUTPUT_GRAPH_MATERIALIZE_OPERATOR_FROM_ATTRS_MAP_H
+
+#include "op-attrs/pcg_operator_attrs.dtg.h"
+#include "substitutions/operator_pattern/operator_attribute_key.dtg.h"
+#include "substitutions/operator_pattern/operator_attribute_value.dtg.h"
+
+namespace FlexFlow {
+
+PCGOperatorAttrs materialize_operator_from_attrs_map(
+ std::unordered_map<OperatorAttributeKey, OperatorAttributeValue> const &);
+
+} // namespace FlexFlow
+
+#endif
diff --git a/lib/substitutions/include/substitutions/output_graph/output_graph_expr.h b/lib/substitutions/include/substitutions/output_graph/output_graph_expr.h
new file mode 100644
index 0000000000..e550767292
--- /dev/null
+++ b/lib/substitutions/include/substitutions/output_graph/output_graph_expr.h
@@ -0,0 +1,15 @@
+#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OUTPUT_GRAPH_OUTPUT_GRAPH_EXPR_H
+#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OUTPUT_GRAPH_OUTPUT_GRAPH_EXPR_H
+
+#include "substitutions/output_graph/output_graph_expr.dtg.h"
+#include "substitutions/output_graph/output_graph_expr_node.dtg.h"
+#include "substitutions/output_graph/output_graph_expr_node_output.dtg.h"
+
+namespace FlexFlow {
+
+std::vector<OutputGraphExprNodeOutput>
+ get_node_outputs(OutputGraphExpr const &, OutputGraphExprNode const &);
+
+} // namespace FlexFlow
+
+#endif
diff --git a/lib/substitutions/include/substitutions/output_graph/output_graph_expr.struct.toml b/lib/substitutions/include/substitutions/output_graph/output_graph_expr.struct.toml
index 5caeff92f5..9ad65369a9 100644
--- a/lib/substitutions/include/substitutions/output_graph/output_graph_expr.struct.toml
+++ b/lib/substitutions/include/substitutions/output_graph/output_graph_expr.struct.toml
@@ -5,8 +5,9 @@ features = []
includes = [
"utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph.h",
"substitutions/output_graph/output_operator_attrs_assignment.dtg.h",
+ "<variant>",
]
[[fields]]
name = "raw_graph"
-type = "::FlexFlow::LabelledOpenDataflowGraph<::FlexFlow::OutputOperatorAttrsAssignment, std::nullopt_t>"
+type = "::FlexFlow::LabelledOpenDataflowGraph<::FlexFlow::OutputOperatorAttrsAssignment, std::monostate>"
diff --git a/lib/substitutions/include/substitutions/output_graph/output_graph_expr_input.struct.toml b/lib/substitutions/include/substitutions/output_graph/output_graph_expr_input.struct.toml
new file mode 100644
index 0000000000..fe7a861f0a
--- /dev/null
+++ b/lib/substitutions/include/substitutions/output_graph/output_graph_expr_input.struct.toml
@@ -0,0 +1,16 @@
+namespace = "FlexFlow"
+name = "OutputGraphExprInput"
+features = [
+ "eq",
+ "ord",
+ "hash",
+ "fmt",
+]
+
+includes = [
+ "utils/graph/open_dataflow_graph/dataflow_graph_input.dtg.h",
+]
+
+[[fields]]
+name = "raw_dataflow_graph_input"
+type = "::FlexFlow::DataflowGraphInput"
diff --git a/lib/substitutions/include/substitutions/output_graph/output_graph_expr_node.struct.toml b/lib/substitutions/include/substitutions/output_graph/output_graph_expr_node.struct.toml
new file mode 100644
index 0000000000..37c2a1f563
--- /dev/null
+++ b/lib/substitutions/include/substitutions/output_graph/output_graph_expr_node.struct.toml
@@ -0,0 +1,16 @@
+namespace = "FlexFlow"
+name = "OutputGraphExprNode"
+features = [
+ "eq",
+ "ord",
+ "hash",
+ "fmt",
+]
+
+includes = [
+ "utils/graph/node/node.dtg.h"
+]
+
+[[fields]]
+name = "raw_graph_node"
+type = "::FlexFlow::Node"
diff --git a/lib/substitutions/include/substitutions/output_graph/output_graph_expr_node_output.struct.toml b/lib/substitutions/include/substitutions/output_graph/output_graph_expr_node_output.struct.toml
new file mode 100644
index 0000000000..7a2072e385
--- /dev/null
+++ b/lib/substitutions/include/substitutions/output_graph/output_graph_expr_node_output.struct.toml
@@ -0,0 +1,16 @@
+namespace = "FlexFlow"
+name = "OutputGraphExprNodeOutput"
+features = [
+ "eq",
+ "ord",
+ "hash",
+ "fmt",
+]
+
+includes = [
+ "utils/graph/dataflow_graph/dataflow_output.dtg.h",
+]
+
+[[fields]]
+name = "raw_dataflow_output"
+type = "::FlexFlow::DataflowOutput"
diff --git a/lib/substitutions/include/substitutions/output_graph/output_operator_attr_access.struct.toml b/lib/substitutions/include/substitutions/output_graph/output_operator_attr_access.struct.toml
index 5527635a2e..e856249e50 100644
--- a/lib/substitutions/include/substitutions/output_graph/output_operator_attr_access.struct.toml
+++ b/lib/substitutions/include/substitutions/output_graph/output_operator_attr_access.struct.toml
@@ -8,13 +8,13 @@ features = [
]
includes = [
- "utils/graph/node/node.dtg.h",
+ "substitutions/unlabelled/pattern_node.dtg.h",
"substitutions/operator_pattern/operator_attribute_expr.dtg.h",
]
[[fields]]
name = "node"
-type = "::FlexFlow::Node"
+type = "::FlexFlow::PatternNode"
# NOTE(@wmdi) I am not sure whether these should be part of attribute expr.
[[fields]]
diff --git a/lib/substitutions/include/substitutions/output_graph/output_operator_attribute_expr.h b/lib/substitutions/include/substitutions/output_graph/output_operator_attribute_expr.h
new file mode 100644
index 0000000000..cba095b444
--- /dev/null
+++ b/lib/substitutions/include/substitutions/output_graph/output_operator_attribute_expr.h
@@ -0,0 +1,15 @@
+#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OUTPUT_GRAPH_OUTPUT_OPERATOR_ATTRIBUTE_EXPR_H
+#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OUTPUT_GRAPH_OUTPUT_OPERATOR_ATTRIBUTE_EXPR_H
+
+#include "output_operator_attribute_expr.dtg.h"
+#include "pcg/parallel_computation_graph/parallel_layer_attrs.dtg.h"
+
+namespace FlexFlow {
+
+OperatorAttributeValue evaluate_output_operator_attribute_expr(
+ OutputOperatorAttributeExpr const &,
+ std::unordered_map<PatternNode, ParallelLayerAttrs> const &node_match);
+
+} // namespace FlexFlow
+
+#endif
diff --git a/lib/substitutions/include/substitutions/output_graph/output_operator_attrs_assignment.h b/lib/substitutions/include/substitutions/output_graph/output_operator_attrs_assignment.h
new file mode 100644
index 0000000000..60540c0711
--- /dev/null
+++ b/lib/substitutions/include/substitutions/output_graph/output_operator_attrs_assignment.h
@@ -0,0 +1,25 @@
+#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OUTPUT_GRAPH_OUTPUT_OPERATOR_ATTRS_ASSIGNMENT_H
+#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OUTPUT_GRAPH_OUTPUT_OPERATOR_ATTRS_ASSIGNMENT_H
+
+#include "op-attrs/pcg_operator_attrs.dtg.h"
+#include "substitutions/output_graph/output_operator_attrs_assignment.dtg.h"
+#include "substitutions/unlabelled/pattern_node.dtg.h"
+
+namespace FlexFlow {
+
+OutputOperatorAttrsAssignment output_operator_clone_node(PatternNode const &);
+
+PCGOperatorAttrs materialize_output_operator_from_attrs_assignment(
+ OutputOperatorAttrsAssignment const &attrs_assignment,
+ std::unordered_map