Skip to content
  •  
  •  
  •  
1 change: 1 addition & 0 deletions .proj.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ build_targets = [
# "substitutions",
# "compiler",
"substitution-generator",
"local-execution",
]
test_targets = [
"utils-tests",
Expand Down
6 changes: 3 additions & 3 deletions flake.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion flake.nix
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@
"-DFF_USE_EXTERNAL_TYPE_INDEX=ON"
];

RC_PARAMS = "max_discard_ratio=100";

buildInputs = builtins.concatLists [
(with pkgs; [
zlib
Expand Down Expand Up @@ -110,7 +112,7 @@

default = mkShell {
inputsFrom = [ ci ];
inherit (ci) CMAKE_FLAGS;
inherit (ci) CMAKE_FLAGS RC_PARAMS;

VIMPLUGINS = lib.strings.concatStringsSep "," [
"${proj-repo.packages.${system}.proj-nvim}"
Expand Down
4 changes: 2 additions & 2 deletions lib/compiler/test/src/test_machine_mapping.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

TEST_SUITE(FF_TEST_SUITE) {
// TEST_CASE("MachineMapping::combine") {
// rc::check([](MachineMapping const &m0, MachineMapping const &m1) {
// RC_SUBCASE([](MachineMapping const &m0, MachineMapping const &m1) {
// RC_PRE(MachineMapping::nodes_are_disjoint(m0, m1));

// MachineMapping comb = MachineMapping::combine(m0, m1);
Expand All @@ -16,7 +16,7 @@ TEST_SUITE(FF_TEST_SUITE) {
// }

// TEST_CASE("OptimalCostResult::infinity") {
// rc::check([](OptimalCostResult const &c) {
// RC_SUBCASE([](OptimalCostResult const &c) {
// RC_ASSERT(c.runtime <= OptimalCostResult::infinity().runtime);
// });
// }
Expand Down
2 changes: 1 addition & 1 deletion lib/compiler/test/src/test_optimal_cost.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ TEST_SUITE(FF_TEST_SUITE) {
// MachineSpecification const &) {
// return std::unordered_set<MachineView>{make_1d_machine_view(0, 1, 1)};
// };
// rc::check([](ParallelComputationGraph const &g,
// RC_SUBCASE([](ParallelComputationGraph const &g,
// MachineSpecification const &machine_spec) {
// OptimalCostCache cached_subgraph_costs;
// OptimalCostResult result = optimal_cost(g,
Expand Down
2 changes: 1 addition & 1 deletion lib/compiler/test/src/test_unity_algorithm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
TEST_SUITE(FF_TEST_SUITE) {
// Rapidcheck does not work for now
// TEST_CASE("graph_optimize") {
// rc::check([](ComputationGraph const &g,
// RC_SUBCASE([](ComputationGraph const &g,
// float alpha,
// int budget,
// float threshold,
Expand Down
12 changes: 6 additions & 6 deletions lib/kernels/include/kernels/legion_dim_t.dtg.h

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 6 additions & 6 deletions lib/kernels/src/kernels/legion_dim_t.dtg.cc

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

32 changes: 13 additions & 19 deletions lib/local-execution/src/ops/attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "attention.h"
#include "kernels/attention_kernels.h"
#include "local-execution/op_task_signature.h"
#include "op-attrs/ops/attention/multihead_attention_parallel_inputs.h"

namespace FlexFlow {

Expand Down Expand Up @@ -95,31 +96,24 @@ static DeviceSpecific<DeviceStates>
ParallelTensorShape value_parallel_tensor_shape =
acc.get_argument<ParallelTensorShape>(VALUE_PARALLEL_TENSOR_SHAPE);

MultiHeadAttentionInputs inputs = {
shard_dim_at_idx(query_parallel_tensor_shape, ff_dim_t{0}).size,
shard_dim_at_idx(query_parallel_tensor_shape, ff_dim_t{1}).size,
qProjSize,
kProjSize,
vProjSize,
query_parallel_tensor_shape.data_type};
;
MultiHeadAttentionParallelInputs parsed = throw_if_unexpected(
parse_attention_parallel_input_shape(query_parallel_tensor_shape,
key_parallel_tensor_shape,
value_parallel_tensor_shape));
ParallelTensorShape weight_parallel_tensor_shape =
throw_if_unexpected(get_weights_shape(attrs,
query_parallel_tensor_shape,
key_parallel_tensor_shape,
value_parallel_tensor_shape));

int kvSeqLength = get_kvSeqLength(inputs);
int qSize = get_qSize(inputs);
int kSize = get_kSize(inputs);
int vSize = get_vSize(inputs);

int qoSeqLength =
dim_at_idx(get_piece_shape(query_parallel_tensor_shape), ff_dim_t(1));
int num_samples =
dim_at_idx(get_piece_shape(query_parallel_tensor_shape), ff_dim_t(2));
int num_heads =
dim_at_idx(get_piece_shape(weight_parallel_tensor_shape), ff_dim_t(1));
int kvSeqLength = get_kvSeqLength(parsed);
int qSize = get_qSize(parsed);
int kSize = get_kSize(parsed);
int vSize = get_vSize(parsed);

int qoSeqLength = get_qoSeqLength(parsed);
int num_samples = get_num_samples(parsed);
int num_heads = attrs.num_heads;

MHAPerDeviceState per_device_state = init_kernel(handle,
allocator,
Expand Down
2 changes: 2 additions & 0 deletions lib/op-attrs/include/op-attrs/datatype.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ using DataTypeValue = std::variant<real_type<DataType::FLOAT>,

size_t size_of_datatype(DataType);

// Boolean predicate over an ordered pair of DataType values; only the
// declaration is visible here.
// NOTE(review): the name suggests it reports whether the first DataType can be
// promoted to the second, with "strictly" presumably excluding same-type
// pairs -- confirm against the definition.
bool can_strictly_promote_datatype_from_to(DataType, DataType);

} // namespace FlexFlow

#endif
12 changes: 6 additions & 6 deletions lib/op-attrs/include/op-attrs/ff_dim.dtg.h

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 0 additions & 6 deletions lib/op-attrs/include/op-attrs/get_output_shapes.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,12 +112,8 @@ std::vector<TensorShape> get_output_shapes(Attrs const &attrs,

ParallelTensorShape get_output_shape(MultiHeadAttentionAttrs const &,
std::vector<ParallelTensorShape> const &);
ParallelTensorShape get_output_shape(CastAttrs const &,
ParallelTensorShape const &);
ParallelTensorShape get_output_shape(ConcatAttrs const &,
std::vector<ParallelTensorShape> const &);
ParallelTensorShape get_output_shape(Conv2DAttrs const &,
ParallelTensorShape const &);
ParallelTensorShape get_output_shape(DropoutAttrs const &,
ParallelTensorShape const &);
ParallelTensorShape get_output_shape(FlatAttrs const &,
Expand All @@ -131,8 +127,6 @@ ParallelTensorShape get_output_shape(Pool2DAttrs const &,
ParallelTensorShape const &);
ParallelTensorShape get_output_shape(ReduceAttrs const &,
ParallelTensorShape const &);
ParallelTensorShape get_output_shape(ReplicateAttrs const &,
ParallelTensorShape const &);
ParallelTensorShape get_output_shape(ReverseAttrs const &,
ParallelTensorShape const &);
std::vector<ParallelTensorShape> get_output_shapes(SplitAttrs const &,
Expand Down
16 changes: 8 additions & 8 deletions lib/op-attrs/include/op-attrs/l1_regularizer_attrs.dtg.h

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 8 additions & 8 deletions lib/op-attrs/include/op-attrs/l2_regularizer_attrs.dtg.h

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

32 changes: 26 additions & 6 deletions lib/op-attrs/include/op-attrs/ops/attention.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,17 +42,37 @@ tl::expected<TensorShape, std::string>
TensorShape const &input_q,
TensorShape const &input_k,
TensorShape const &input_v);
tl::expected<ParallelTensorShape, std::string>
get_weights_shape(MultiHeadAttentionAttrs const &,
ParallelTensorShape const &input_q,
ParallelTensorShape const &input_k,
ParallelTensorShape const &input_v);

// TensorShape overloads of the multi-head attention shape queries.
// Each takes the MultiHeadAttentionAttrs plus the three input shapes
// (input_q, input_k, input_v) and returns a tl::expected that holds either a
// TensorShape or a std::string in its error slot.
// NOTE(review): the std::string is presumably a human-readable reason the
// input shapes were rejected -- confirm against the definitions.

// NOTE(review): by its name, yields the shape of the input bias tensor --
// confirm against the definition.
tl::expected<TensorShape, std::string>
get_input_bias_shape(MultiHeadAttentionAttrs const &,
TensorShape const &input_q,
TensorShape const &input_k,
TensorShape const &input_v);
// NOTE(review): by its name, yields the shape of the output bias tensor --
// confirm against the definition.
tl::expected<TensorShape, std::string>
get_output_bias_shape(MultiHeadAttentionAttrs const &,
TensorShape const &input_q,
TensorShape const &input_k,
TensorShape const &input_v);
// NOTE(review): by its name, yields the shape of the operator's output
// tensor -- confirm against the definition.
tl::expected<TensorShape, std::string>
get_output_shape(MultiHeadAttentionAttrs const &,
TensorShape const &input_q,
TensorShape const &input_k,
TensorShape const &input_v);

// ParallelTensorShape overloads of the multi-head attention shape queries.
// Each takes the MultiHeadAttentionAttrs plus the three input shapes
// (input_q, input_k, input_v) and returns a tl::expected that holds either a
// ParallelTensorShape or a std::string in its error slot.
// NOTE(review): the std::string is presumably a human-readable reason the
// input shapes were rejected -- confirm against the definitions.

// NOTE(review): by its name, yields the shape of the weights tensor --
// confirm against the definition.
tl::expected<ParallelTensorShape, std::string>
get_weights_shape(MultiHeadAttentionAttrs const &,
ParallelTensorShape const &input_q,
ParallelTensorShape const &input_k,
ParallelTensorShape const &input_v);
// NOTE(review): by its name, yields the shape of the input bias tensor --
// confirm against the definition.
tl::expected<ParallelTensorShape, std::string>
get_input_bias_shape(MultiHeadAttentionAttrs const &,
ParallelTensorShape const &input_q,
ParallelTensorShape const &input_k,
ParallelTensorShape const &input_v);
// NOTE(review): by its name, yields the shape of the output bias tensor --
// confirm against the definition.
tl::expected<ParallelTensorShape, std::string>
get_output_bias_shape(MultiHeadAttentionAttrs const &,
ParallelTensorShape const &input_q,
ParallelTensorShape const &input_k,
ParallelTensorShape const &input_v);
tl::expected<ParallelTensorShape, std::string>
get_output_shape(MultiHeadAttentionAttrs const &,
ParallelTensorShape const &input_q,
Expand Down
Loading