Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lib/op-attrs/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ ff_add_library(
PUBLIC_INCLUDE
include/
PRIVATE_INCLUDE
src/
src/
DEPS
utils
)
Expand Down
26 changes: 26 additions & 0 deletions lib/op-attrs/include/op-attrs/datatype.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,32 @@ struct formatter<::FlexFlow::DataType> : formatter<string_view> {
}
};

template <>
struct formatter<::FlexFlow::DataTypeValue> : formatter<string_view> {
template <typename FormatContext>
auto format(::FlexFlow::DataTypeValue v, FormatContext &ctx)
-> decltype(ctx.out()) {
using namespace FlexFlow;

string_view s = "unknown";
if (auto const *f32 = get_if<real_type<DataType::FLOAT>>(&v)) {
s = fmt::to_string(*f32);
} else if (auto const *f64 = get_if<real_type<DataType::DOUBLE>>(&v)) {
s = fmt::to_string(*f64);
} else if (auto const *i32 = get_if<real_type<DataType::INT32>>(&v)) {
s = fmt::to_string(*i32);
} else if (auto const *i64 = get_if<real_type<DataType::INT64>>(&v)) {
s = fmt::to_string(*i64);
} else if (auto const *h = get_if<real_type<DataType::HALF>>(&v)) {
s = fmt::to_string(*h);
} else if (auto const *b = get_if<real_type<DataType::BOOL>>(&v)) {
s = fmt::to_string(*b);
}
return formatter<string_view>::format(s, ctx);
}
};

} // namespace fmt
// namespace fmt

#endif
4 changes: 3 additions & 1 deletion lib/op-attrs/include/op-attrs/ops/aggregate.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ struct AggregateAttrs {
req<float> lambda_bal;
};
FF_VISITABLE_STRUCT(AggregateAttrs, n, lambda_bal);
FF_VISIT_FMTABLE(AggregateAttrs);
CHECK_FMTABLE(AggregateAttrs);
CHECK_VALID_OP_ATTR(AggregateAttrs);

DataType get_datatype(AggregateAttrs const &);
bool is_valid(AggregateAttrs const &,
Expand All @@ -32,7 +35,6 @@ ParallelTensorShape
ParallelTensorShape const &full_gate_gradients,
std::vector<ParallelTensorShape> const &exp_preds);

CHECK_VALID_OP_ATTR(AggregateAttrs);
} // namespace FlexFlow

#endif
4 changes: 3 additions & 1 deletion lib/op-attrs/include/op-attrs/ops/aggregate_spec.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ struct AggregateSpecAttrs {
req<float> lambda_bal;
};
FF_VISITABLE_STRUCT(AggregateSpecAttrs, n, lambda_bal);
FF_VISIT_FMTABLE(AggregateSpecAttrs);
CHECK_FMTABLE(AggregateSpecAttrs);
CHECK_VALID_OP_ATTR(AggregateSpecAttrs);

ParallelTensorShape
get_output_shape(AggregateSpecAttrs const &,
Expand All @@ -21,7 +24,6 @@ ParallelTensorShape
ParallelTensorShape const &gate_gradients_full,
std::vector<ParallelTensorShape> const &exp_preds);

CHECK_VALID_OP_ATTR(AggregateSpecAttrs);
} // namespace FlexFlow

#endif
4 changes: 3 additions & 1 deletion lib/op-attrs/include/op-attrs/ops/attention.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ FF_VISITABLE_STRUCT(MultiHeadAttentionAttrs,
bias,
add_bias_kv,
add_zero_attn);
FF_VISIT_FMTABLE(MultiHeadAttentionAttrs);
CHECK_FMTABLE(MultiHeadAttentionAttrs);
CHECK_VALID_OP_ATTR(MultiHeadAttentionAttrs);

template <typename TensorType>
struct MultiHeadAttentionInputs
Expand Down Expand Up @@ -70,7 +73,6 @@ ParallelTensorShape
TensorShape get_output_shape(MultiHeadAttentionAttrs const &,
MultiHeadAttentionInputs<TensorShape> const &);

CHECK_VALID_OP_ATTR(MultiHeadAttentionAttrs);
} // namespace FlexFlow

#endif
3 changes: 2 additions & 1 deletion lib/op-attrs/include/op-attrs/ops/batch_matmul.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ struct BatchMatmulAttrs {
req<int> a_seq_length_dim, b_seq_length_dim;
};
FF_VISITABLE_STRUCT(BatchMatmulAttrs, a_seq_length_dim, b_seq_length_dim);

FF_VISIT_FMTABLE(BatchMatmulAttrs);
CHECK_FMTABLE(BatchMatmulAttrs);
CHECK_VALID_OP_ATTR(BatchMatmulAttrs);

} // namespace FlexFlow
Expand Down
5 changes: 3 additions & 2 deletions lib/op-attrs/include/op-attrs/ops/batch_norm.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@ struct BatchNormAttrs {
req<bool> relu;
};
FF_VISITABLE_STRUCT(BatchNormAttrs, relu);
FF_VISIT_FMTABLE(BatchNormAttrs);
CHECK_FMTABLE(BatchNormAttrs);
CHECK_VALID_OP_ATTR(BatchNormAttrs);

ParallelTensorShape get_output_shape(BatchNormAttrs const &);

CHECK_VALID_OP_ATTR(BatchNormAttrs);

} // namespace FlexFlow

#endif
3 changes: 2 additions & 1 deletion lib/op-attrs/include/op-attrs/ops/broadcast.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ struct BroadcastAttrs {
req<stack_vector<int, MAX_TENSOR_DIM>> target_dims;
};
FF_VISITABLE_STRUCT(BroadcastAttrs, target_dims);

FF_VISIT_FMTABLE(BroadcastAttrs);
CHECK_FMTABLE(BroadcastAttrs);
CHECK_VALID_OP_ATTR(BroadcastAttrs);

} // namespace FlexFlow
Expand Down
4 changes: 3 additions & 1 deletion lib/op-attrs/include/op-attrs/ops/cast.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@ struct CastAttrs {
req<DataType> dtype;
};
FF_VISITABLE_STRUCT(CastAttrs, dtype);

FF_VISIT_FMTABLE(CastAttrs);
CHECK_FMTABLE(CastAttrs);
CHECK_VALID_OP_ATTR(CastAttrs);

} // namespace FlexFlow

#endif
2 changes: 2 additions & 0 deletions lib/op-attrs/include/op-attrs/ops/combine.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ struct CombineAttrs {
req<int> combine_degree;
};
FF_VISITABLE_STRUCT(CombineAttrs, combine_dim, combine_degree);
FF_VISIT_FMTABLE(CombineAttrs);
CHECK_FMTABLE(CombineAttrs);
CHECK_VALID_OP_ATTR(CombineAttrs);

} // namespace FlexFlow
Expand Down
2 changes: 2 additions & 0 deletions lib/op-attrs/include/op-attrs/ops/concat.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ struct ConcatAttrs {
ff_dim_t axis;
};
FF_VISITABLE_STRUCT(ConcatAttrs, axis);
FF_VISIT_FMTABLE(ConcatAttrs);
CHECK_FMTABLE(ConcatAttrs);
CHECK_VALID_OP_ATTR(ConcatAttrs);

} // namespace FlexFlow
Expand Down
2 changes: 2 additions & 0 deletions lib/op-attrs/include/op-attrs/ops/conv_2d.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ FF_VISITABLE_STRUCT(Conv2DAttrs,
groups,
activation,
use_bias);
FF_VISIT_FMTABLE(Conv2DAttrs);
CHECK_FMTABLE(Conv2DAttrs);
CHECK_VALID_OP_ATTR(Conv2DAttrs);

TensorShape get_kernel_shape(Conv2DAttrs const &, TensorShape const &);
Expand Down
2 changes: 2 additions & 0 deletions lib/op-attrs/include/op-attrs/ops/dropout.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ struct DropoutAttrs {
req<unsigned long long> seed;
};
FF_VISITABLE_STRUCT(DropoutAttrs, rate, seed);
FF_VISIT_FMTABLE(DropoutAttrs);
CHECK_FMTABLE(DropoutAttrs);
CHECK_VALID_OP_ATTR(DropoutAttrs);

} // namespace FlexFlow
Expand Down
2 changes: 2 additions & 0 deletions lib/op-attrs/include/op-attrs/ops/element_binary.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ FF_VISITABLE_STRUCT(ElementBinaryAttrs,
compute_type,
should_broadcast_lhs,
should_broadcast_rhs);
FF_VISIT_FMTABLE(ElementBinaryAttrs);
CHECK_FMTABLE(ElementBinaryAttrs);
CHECK_VALID_OP_ATTR(ElementBinaryAttrs);

} // namespace FlexFlow
Expand Down
4 changes: 4 additions & 0 deletions lib/op-attrs/include/op-attrs/ops/element_unary.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,16 @@ struct ElementScalarUnaryAttrs {
req<float> scalar;
};
FF_VISITABLE_STRUCT(ElementScalarUnaryAttrs, op, scalar);
FF_VISIT_FMTABLE(ElementScalarUnaryAttrs);
CHECK_FMTABLE(ElementScalarUnaryAttrs);
CHECK_VALID_OP_ATTR(ElementScalarUnaryAttrs);

struct ElementUnaryAttrs {
req<Op> op;
};
FF_VISITABLE_STRUCT(ElementUnaryAttrs, op);
FF_VISIT_FMTABLE(ElementUnaryAttrs);
CHECK_FMTABLE(ElementUnaryAttrs);
CHECK_VALID_OP_ATTR(ElementUnaryAttrs);

} // namespace FlexFlow
Expand Down
2 changes: 2 additions & 0 deletions lib/op-attrs/include/op-attrs/ops/embedding.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ struct EmbeddingAttrs {
req<DataType> data_type;
};
FF_VISITABLE_STRUCT(EmbeddingAttrs, num_entries, out_channels, aggr, data_type);
FF_VISIT_FMTABLE(EmbeddingAttrs);
CHECK_FMTABLE(EmbeddingAttrs);
CHECK_VALID_OP_ATTR(EmbeddingAttrs);

TensorShape get_weights_shape(EmbeddingAttrs const &, TensorShape const &);
Expand Down
2 changes: 2 additions & 0 deletions lib/op-attrs/include/op-attrs/ops/flat.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ namespace FlexFlow {

struct FlatAttrs {};
FF_VISITABLE_STRUCT(FlatAttrs);
FF_VISIT_FMTABLE(FlatAttrs);
CHECK_FMTABLE(FlatAttrs);
CHECK_VALID_OP_ATTR(FlatAttrs);

} // namespace FlexFlow
Expand Down
2 changes: 2 additions & 0 deletions lib/op-attrs/include/op-attrs/ops/gather.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ struct GatherAttrs {
ff_dim_t dim;
};
FF_VISITABLE_STRUCT(GatherAttrs, dim);
FF_VISIT_FMTABLE(GatherAttrs);
CHECK_FMTABLE(GatherAttrs);
CHECK_VALID_OP_ATTR(GatherAttrs);

} // namespace FlexFlow
Expand Down
2 changes: 2 additions & 0 deletions lib/op-attrs/include/op-attrs/ops/groupby.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ struct Group_byAttrs {
req<float> alpha;
};
FF_VISITABLE_STRUCT(Group_byAttrs, n, alpha);
FF_VISIT_FMTABLE(Group_byAttrs);
CHECK_FMTABLE(Group_byAttrs);
CHECK_VALID_OP_ATTR(Group_byAttrs);

} // namespace FlexFlow
Expand Down
2 changes: 2 additions & 0 deletions lib/op-attrs/include/op-attrs/ops/input.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ namespace FlexFlow {

struct InputAttrs {};
FF_VISITABLE_STRUCT(InputAttrs);
FF_VISIT_FMTABLE(InputAttrs);
CHECK_FMTABLE(InputAttrs);
CHECK_VALID_OP_ATTR(InputAttrs);

} // namespace FlexFlow
Expand Down
2 changes: 2 additions & 0 deletions lib/op-attrs/include/op-attrs/ops/layer_norm.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ struct LayerNormAttrs {
req<float> eps;
};
FF_VISITABLE_STRUCT(LayerNormAttrs, axes, elementwise_affine, eps);
FF_VISIT_FMTABLE(LayerNormAttrs);
CHECK_FMTABLE(LayerNormAttrs);
CHECK_VALID_OP_ATTR(LayerNormAttrs);

} // namespace FlexFlow
Expand Down
6 changes: 6 additions & 0 deletions lib/op-attrs/include/op-attrs/ops/linear.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,16 @@ struct L1RegularizerAttrs {
req<float> lambda;
};
FF_VISITABLE_STRUCT(L1RegularizerAttrs, lambda);
FF_VISIT_FMTABLE(L1RegularizerAttrs);
CHECK_FMTABLE(L1RegularizerAttrs);
CHECK_VALID_OP_ATTR(L1RegularizerAttrs);

struct L2RegularizerAttrs {
req<float> lambda;
};
FF_VISITABLE_STRUCT(L2RegularizerAttrs, lambda);
FF_VISIT_FMTABLE(L2RegularizerAttrs);
CHECK_FMTABLE(L2RegularizerAttrs);
CHECK_VALID_OP_ATTR(L2RegularizerAttrs);

using RegularizerAttrs = variant<L1RegularizerAttrs, L2RegularizerAttrs>;
Expand All @@ -32,6 +36,8 @@ struct LinearAttrs {
};
FF_VISITABLE_STRUCT(
LinearAttrs, out_channels, use_bias, data_type, activation, regularizer);
FF_VISIT_FMTABLE(LinearAttrs);
CHECK_FMTABLE(LinearAttrs);
CHECK_VALID_OP_ATTR(LinearAttrs);

} // namespace FlexFlow
Expand Down
4 changes: 4 additions & 0 deletions lib/op-attrs/include/op-attrs/ops/loss_functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,16 @@ struct SparseCategoricalCrossEntropyLossAttrs {
req<bool> replace_labels; // for aggregate_spec: More predictions than labels
};
FF_VISITABLE_STRUCT(SparseCategoricalCrossEntropyLossAttrs, replace_labels);
FF_VISIT_FMTABLE(SparseCategoricalCrossEntropyLossAttrs);
CHECK_FMTABLE(SparseCategoricalCrossEntropyLossAttrs);
CHECK_VALID_OP_ATTR(SparseCategoricalCrossEntropyLossAttrs);

struct OtherLossAttrs {
req<LossFunction> loss_type;
};
FF_VISITABLE_STRUCT(OtherLossAttrs, loss_type);
FF_VISIT_FMTABLE(OtherLossAttrs);
CHECK_FMTABLE(OtherLossAttrs);
CHECK_VALID_OP_ATTR(OtherLossAttrs);

using LossAttrs =
Expand Down
2 changes: 2 additions & 0 deletions lib/op-attrs/include/op-attrs/ops/noop.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ namespace FlexFlow {

struct NoopAttrs {};
FF_VISITABLE_STRUCT(NoopAttrs);
FF_VISIT_FMTABLE(NoopAttrs);
CHECK_FMTABLE(NoopAttrs);
CHECK_VALID_OP_ATTR(NoopAttrs);

} // namespace FlexFlow
Expand Down
2 changes: 2 additions & 0 deletions lib/op-attrs/include/op-attrs/ops/pool_2d.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ FF_VISITABLE_STRUCT(Pool2DAttrs,
padding_w,
pool_type,
activation);
FF_VISIT_FMTABLE(Pool2DAttrs);
CHECK_FMTABLE(Pool2DAttrs);
CHECK_VALID_OP_ATTR(Pool2DAttrs);

} // namespace FlexFlow
Expand Down
2 changes: 2 additions & 0 deletions lib/op-attrs/include/op-attrs/ops/reduce.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ struct ReduceAttrs {
req<bool> keepdims;
};
FF_VISITABLE_STRUCT(ReduceAttrs, axes, op_type, keepdims);
FF_VISIT_FMTABLE(ReduceAttrs);
CHECK_FMTABLE(ReduceAttrs);
CHECK_VALID_OP_ATTR(ReduceAttrs);

} // namespace FlexFlow
Expand Down
2 changes: 2 additions & 0 deletions lib/op-attrs/include/op-attrs/ops/reduction.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ struct ReductionAttrs {
req<int> reduction_degree;
};
FF_VISITABLE_STRUCT(ReductionAttrs, reduction_dim, reduction_degree);
FF_VISIT_FMTABLE(ReductionAttrs);
CHECK_FMTABLE(ReductionAttrs);
CHECK_VALID_OP_ATTR(ReductionAttrs);

} // namespace FlexFlow
Expand Down
2 changes: 2 additions & 0 deletions lib/op-attrs/include/op-attrs/ops/repartition.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ struct RepartitionAttrs {
req<int> repartition_degree;
};
FF_VISITABLE_STRUCT(RepartitionAttrs, repartition_dim, repartition_degree);
FF_VISIT_FMTABLE(RepartitionAttrs);
CHECK_FMTABLE(RepartitionAttrs);
CHECK_VALID_OP_ATTR(RepartitionAttrs);

} // namespace FlexFlow
Expand Down
2 changes: 2 additions & 0 deletions lib/op-attrs/include/op-attrs/ops/replicate.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ struct ReplicateAttrs {
req<int> replicate_degree;
};
FF_VISITABLE_STRUCT(ReplicateAttrs, replicate_dim, replicate_degree);
FF_VISIT_FMTABLE(ReplicateAttrs);
CHECK_FMTABLE(ReplicateAttrs);
CHECK_VALID_OP_ATTR(ReplicateAttrs);

} // namespace FlexFlow
Expand Down
2 changes: 2 additions & 0 deletions lib/op-attrs/include/op-attrs/ops/reshape.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ struct ReshapeAttrs {
TensorShape shape;
};
FF_VISITABLE_STRUCT(ReshapeAttrs, shape);
FF_VISIT_FMTABLE(ReshapeAttrs);
CHECK_FMTABLE(ReshapeAttrs);
CHECK_VALID_OP_ATTR(ReshapeAttrs);

} // namespace FlexFlow
Expand Down
2 changes: 2 additions & 0 deletions lib/op-attrs/include/op-attrs/ops/reverse.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ struct ReverseAttrs {
ff_dim_t axis;
};
FF_VISITABLE_STRUCT(ReverseAttrs, axis);
FF_VISIT_FMTABLE(ReverseAttrs);
CHECK_FMTABLE(ReverseAttrs);
CHECK_VALID_OP_ATTR(ReverseAttrs);

} // namespace FlexFlow
Expand Down
2 changes: 2 additions & 0 deletions lib/op-attrs/include/op-attrs/ops/softmax.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ struct SoftmaxAttrs {
ff_dim_t dim;
};
FF_VISITABLE_STRUCT(SoftmaxAttrs, dim);
FF_VISIT_FMTABLE(SoftmaxAttrs);
CHECK_FMTABLE(SoftmaxAttrs);
CHECK_VALID_OP_ATTR(SoftmaxAttrs);

} // namespace FlexFlow
Expand Down
2 changes: 2 additions & 0 deletions lib/op-attrs/include/op-attrs/ops/split.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ struct SplitAttrs {
ff_dim_t axis;
};
FF_VISITABLE_STRUCT(SplitAttrs, splits, axis);
FF_VISIT_FMTABLE(SplitAttrs);
CHECK_FMTABLE(SplitAttrs);
CHECK_VALID_OP_ATTR(SplitAttrs);

} // namespace FlexFlow
Expand Down
Loading