Local Execution: Op refactor #1389

Merged
lockshaw merged 38 commits into flexflow:repo-refactor from reyna-abhyankar:op-refactor
Jun 2, 2024
Collaborator

@reyna-abhyankar reyna-abhyankar commented May 14, 2024

Description of changes:

Small fixes to operators for local execution, and removal of Legion names.

Related Issues:

Linked Issues:

  • Issue #

Issues closed by this PR:

  • Closes #

@reyna-abhyankar reyna-abhyankar requested a review from lockshaw May 14, 2024 16:38
Collaborator

@lockshaw lockshaw left a comment

Reviewed 34 of 60 files at r1, 68 of 68 files at r2, all commit messages.
Reviewable status: all files reviewed, 10 unresolved discussions (waiting on @reyna-abhyankar)


lib/local-execution/include/serialization.h line 1 at r2 (raw file):

#ifndef _FLEXFLOW_LOCAL_EXECUTION_SERIALIZATION_H

Does serialization need to be in local execution or should it be in runtime? We don't need to copy anything between machines so I'm not sure functionality-wise it needs to be in local execution, though if it makes things easier I don't mind


lib/local-execution/include/task_argument_accessor.h line 1 at r2 (raw file):

#ifndef _FLEXFLOW_LOCAL_EXECUTION_TASK_ARGUMENT_ACCESSOR_H

Might be nice to split up this file a bit eventually


lib/utils/include/utils/type_index.h line 11 at r2 (raw file):

template <typename T>
std::type_index init_type_index() {

get_type_index might be better?


lib/local-execution/include/op_task_invocation.h line 62 at r2 (raw file):

  }

  void bind_args_from_fwd(OpTaskBinding const &fwd) {

What is this (and the next method) intended to do?


lib/local-execution/include/op_task_invocation.h line 101 at r2 (raw file):

OpTaskBinding infer_bwd_binding(OpTaskBinding const &fwd);

bool is_invocation_valid(OpTaskSignature sig, OpTaskInvocation inv);

Probably want to take args as const &


lib/local-execution/include/op_task_signature.h line 41 at r2 (raw file):

struct OpTaskSignature {
  OpTaskSignature() = delete;

Why not delete the default constructor?


lib/local-execution/include/op_task_signature.h line 65 at r2 (raw file):

  /* void add_input_slot(slot_id, SlotType, Legion::PrivilegeMode); */

  bool operator==(OpTaskSignature const &) const;

Why remove equality comparison?


lib/local-execution/include/profiling.h line 18 at r2 (raw file):

      profiling_wrapper<F, Ts...>(f, profiling, std::forward<Ts>(ts)...);
  if (elapsed.has_value()) {
    log_profile.debug(s, elapsed.value());

Why remove?


lib/local-execution/src/ops/layer_norm.cc line 32 at r2 (raw file):

  PROFILING,
  INPUT,
  INPUT_GRAD,

Don't we have the get_tensor_grad (or something similar) or am I misremembering?


lib/local-execution/src/ops/linear.cc line 149 at r2 (raw file):

                 "[Linear] backward_time = %.2lfms\n",
                 per_device_state,
                 (void *)input.get_float_ptr(),

Why does this need a cast to void *?

Collaborator Author

@reyna-abhyankar reyna-abhyankar left a comment

Reviewable status: 89 of 101 files reviewed, 10 unresolved discussions (waiting on @lockshaw)


lib/local-execution/include/serialization.h line 1 at r2 (raw file):

Previously, lockshaw (Colin Unger) wrote…

Does serialization need to be in local execution or should it be in runtime? We don't need to copy anything between machines so I'm not sure functionality-wise it needs to be in local execution, though if it makes things easier I don't mind

That's true, but it does make things easier, since ArgRefSpec, ConcreteArgSpec, and OpTaskSignature all check for serialization.


lib/local-execution/include/task_argument_accessor.h line 1 at r2 (raw file):

Previously, lockshaw (Colin Unger) wrote…

Might be nice to split up this file a bit eventually

Noted


lib/utils/include/utils/type_index.h line 11 at r2 (raw file):

Previously, lockshaw (Colin Unger) wrote…

get_type_index might be better?

We're already using get_type_index(). How about "construct"?
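
For readers following the naming discussion, here is a minimal sketch of what a helper like this presumably does. The body of the real function is not shown in this thread, so the implementation and the name `type_index_of` are assumptions for illustration only.

```cpp
#include <typeindex>
#include <typeinfo>

// Hypothetical stand-in for the helper being renamed in this thread.
// Assumption: it simply wraps typeid(T) in a std::type_index so that types
// can be compared or used as map keys.
template <typename T>
std::type_index type_index_of() {
  return std::type_index(typeid(T));
}
```

Two calls with the same `T` compare equal and calls with different types compare unequal, which is the property a registry keyed by type would rely on.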


lib/local-execution/include/op_task_invocation.h line 62 at r2 (raw file):

Previously, lockshaw (Colin Unger) wrote…

What is this (and the next method) intended to do?

This is called by infer_bwd_binding, but I guess we don't need a separate method for it and it can be folded directly into the other function. I've changed this.


lib/local-execution/include/op_task_invocation.h line 101 at r2 (raw file):

Previously, lockshaw (Colin Unger) wrote…

Probably want to take args as const &

Done.


lib/local-execution/include/op_task_signature.h line 41 at r2 (raw file):

Previously, lockshaw (Colin Unger) wrote…

Why not delete the default constructor?

Done.


lib/local-execution/include/op_task_signature.h line 65 at r2 (raw file):

Previously, lockshaw (Colin Unger) wrote…

Why remove equality comparison?

Don't we get this from visitable?


lib/local-execution/include/profiling.h line 18 at r2 (raw file):

Previously, lockshaw (Colin Unger) wrote…

Why remove?

It's the legion logger. What should we use instead for logging?


lib/local-execution/src/ops/layer_norm.cc line 32 at r2 (raw file):

Previously, lockshaw (Colin Unger) wrote…

Don't we have the get_tensor_grad (or something similar) or am I misremembering?

You're correct. Fixed


lib/local-execution/src/ops/linear.cc line 149 at r2 (raw file):

Previously, lockshaw (Colin Unger) wrote…

Why does this need a cast to void *?

I think this is because cublasGemmEx takes void* and then also the datatype for that tensor

Collaborator

@lockshaw lockshaw left a comment

Reviewed 10 of 10 files at r6, all commit messages.
Reviewable status: all files reviewed, 1 unresolved discussion (waiting on @reyna-abhyankar)


lib/local-execution/include/profiling.h line 17 at r6 (raw file):

      profiling_wrapper<F, Ts...>(f, profiling, std::forward<Ts>(ts)...);
  if (elapsed.has_value()) {
    spdlog::debug(elapsed.value());

Can we get an actual message here rather than just a float?

Collaborator Author

@reyna-abhyankar reyna-abhyankar left a comment

Reviewable status: 100 of 109 files reviewed, 1 unresolved discussion (waiting on @lockshaw)


lib/local-execution/include/profiling.h line 17 at r6 (raw file):

Previously, lockshaw (Colin Unger) wrote…

Can we get an actual message here rather than just a float?

Done? So s is something like "[MultiHeadAttention] backward_time = %.2lfms\n" which we can format with elapsed.value()
I think this is syntactically correct, but I could be wrong.

Collaborator

@lockshaw lockshaw left a comment

Reviewed 1 of 3 files at r7, 7 of 7 files at r8, 1 of 1 files at r9, all commit messages.
Reviewable status: all files reviewed, 10 unresolved discussions (waiting on @reyna-abhyankar)


lib/kernels/include/kernels/gather_kernels.h line 12 at r9 (raw file):

  PerDeviceFFHandle handle;
  int legion_dim;
};

Suggestion:

struct GatherPerDeviceState {
  PerDeviceFFHandle handle;
  legion_dim_t legion_dim;
};

lib/kernels/src/cuda/ops/gather_kernels.cu line 42 at r9 (raw file):

    // Therefore, input_index = outter_index * (stride * input_dim_size)
    //                        + index[0] * stride + left_over;
    coord_t outter_index = o / (stride * output_dim_size);

Suggestion:

    coord_t outer_index = o / (stride * output_dim_size);

lib/kernels/src/cuda/ops/gather_kernels.cu line 46 at r9 (raw file):

    coord_t left_over = o % stride;
    coord_t input_idx = outter_index * (stride * input_dim_size) +
                        index[o] * stride + left_over;

I assume this was just a typo?

Suggestion:

    coord_t input_idx = outer_index * (stride * input_dim_size) +
                        index[o] * stride + left_over;
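
The index arithmetic in the two suggestions above can be checked on the host with a small CPU reference. This is only a sketch: `gather_reference` and the flat `std::vector` layout are hypothetical, and `long long` stands in for `coord_t`.

```cpp
#include <cstddef>
#include <vector>

// CPU reference for the per-element index computation in the kernel.
// Assumption: tensors are flattened, and stride is the number of elements
// spanned by the dimensions below the gather dimension.
std::vector<float> gather_reference(std::vector<float> const &input,
                                    std::vector<long long> const &index,
                                    long long stride,
                                    long long input_dim_size,
                                    long long output_dim_size) {
  std::vector<float> output(index.size());
  for (std::size_t i = 0; i < index.size(); i++) {
    long long o = static_cast<long long>(i);
    // Position along the dimensions above the gather dimension.
    long long outer_index = o / (stride * output_dim_size);
    // Position along the dimensions below the gather dimension.
    long long left_over = o % stride;
    // The gather dimension coordinate comes from index[o], not index[0].
    long long input_idx = outer_index * (stride * input_dim_size) +
                          index[i] * stride + left_over;
    output[i] = input.at(static_cast<std::size_t>(input_idx));
  }
  return output;
}
```

With `stride = 1`, `input_dim_size = 3`, and `output_dim_size = 2`, each pair of output elements reads from its own block of three input elements.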

lib/kernels/src/cuda/ops/gather_kernels.cu line 142 at r9 (raw file):

  for (int i = 0; i < m.legion_dim; i++) {
    stride *= output.shape[legion_dim_t(i)] + 1;
  }

Suggestion:

  coord_t stride = output.shape.sub_shape(std::nullopt, m.legion_dim + 1).get_volume();

lib/kernels/src/cuda/ops/gather_kernels.cu line 144 at r9 (raw file):

  }

  coord_t output_dim_size = output.shape[legion_dim_t(m.legion_dim)] + 1;

Suggestion:

  coord_t output_dim_size = output.shape[legion_dim_t(m.legion_dim)];

lib/kernels/src/cuda/ops/gather_kernels.cu line 145 at r9 (raw file):

  coord_t output_dim_size = output.shape[legion_dim_t(m.legion_dim)] + 1;
  coord_t input_dim_size = input.shape[legion_dim_t(m.legion_dim)] + 1;

Suggestion:

  coord_t input_dim_size = input.shape[legion_dim_t(m.legion_dim)];
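
A rough host-side sketch of the three quantities touched by the suggestions above, with the `+ 1` terms dropped. It assumes shapes are plain extent vectors in legion ordering and follows the bounds of the original loop (dimensions strictly below the gather dimension); `GatherExtents` and `compute_gather_extents` are hypothetical names, and whether `sub_shape` uses the same bounds is not verified here.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical bundle of the values the kernel launch needs.
struct GatherExtents {
  long long stride;
  long long input_dim_size;
  long long output_dim_size;
};

// Assumption: shape vectors hold extents directly, so no "+ 1" correction
// (as would be needed for inclusive hi - lo bounds) is applied.
GatherExtents compute_gather_extents(std::vector<long long> const &input_shape,
                                     std::vector<long long> const &output_shape,
                                     std::size_t gather_dim) {
  long long stride = 1;
  // Product of the output extents below the gather dimension.
  for (std::size_t i = 0; i < gather_dim; i++) {
    stride *= output_shape.at(i);
  }
  return GatherExtents{stride,
                       input_shape.at(gather_dim),
                       output_shape.at(gather_dim)};
}
```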

lib/kernels/src/cuda/ops/gather_kernels.cu line 178 at r9 (raw file):

  coord_t stride = 1;
  for (int i = 0; i < m.legion_dim; i++) {

Same fixes as in forward kernel are needed here


lib/op-attrs/include/op-attrs/ops/gather.h line 12 at r9 (raw file):

struct GatherAttrs {
  req<int> legion_dim;

And make sure everything flips the dim. Definitely don't use raw (i.e., int) dims anywhere as it is unclear if they are in legion ordering or ff ordering. I'd lean toward ff_dim_t at this level as legion_dim_t is only really necessary when dealing with the concrete data layouts

Suggestion:

  ff_dim_t dim;
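
To illustrate why raw `int` dims are risky, here is a small sketch of two distinct wrapper types with an explicit conversion between them. The types `FFDim` and `LegionDim` are hypothetical stand-ins for `ff_dim_t` and `legion_dim_t`, and the conversion assumes that legion ordering is simply the reverse of ff ordering, which is what "flips the dim" suggests.

```cpp
// Hypothetical strong types: neither converts implicitly from int or from
// the other, so mixing orderings becomes a compile error.
struct FFDim {
  explicit FFDim(int v) : value(v) {}
  int value;
};

struct LegionDim {
  explicit LegionDim(int v) : value(v) {}
  int value;
};

// Assumption: legion ordering is the reverse of ff ordering.
LegionDim to_legion_dim(FFDim d, int num_dims) {
  return LegionDim(num_dims - d.value - 1);
}

FFDim to_ff_dim(LegionDim d, int num_dims) {
  return FFDim(num_dims - d.value - 1);
}
```

Keeping the attribute in ff ordering and converting only at the kernel boundary means the flip happens in exactly one place.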

lib/substitutions/src/operator_attributes.cc line 132 at r9 (raw file):

  switch (key) {
    case OperatorAttributeKey::AXIS:
      return p.legion_dim;

Suggestion:

      return p.dim;

lib/local-execution/include/profiling.h line 17 at r6 (raw file):

Previously, reyna-abhyankar (Reyna Abhyankar) wrote…

Done? So s is something like "[MultiHeadAttention] backward_time = %.2lfms\n" which we can format with elapsed.value()
I think this is syntactically correct, but I could be wrong.

I don't think spdlog and Legion's logger use the same syntax (spdlog uses fmt, Legion uses printf-style), so we'll need to choose one--considering that we already use fmt all over the place (and it has support for non-primitive data structures), probably that, in which case here I think you'd do spdlog::debug(s, elapsed.value());, and then we'd have to update the strings that get passed in to be "[MultiHeadAttention] backward_time = {:.2lf}ms\n" I think. You should look into vformat in https://fmt.dev/latest/api.html and at https://github.com/flexflow/FlexFlow/blob/repo-refactor/lib/utils/include/utils/exception.h#L10-L15, though it may be best done slightly different than there. Another option is to, instead of taking in an arbitrary format string, take in an operator name and a pass, and just use "[{}] {}_time = {:.2lf}ms\n", which is probably best as we really don't need the complexity of passing in arbitrary format strings.
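
A minimal sketch of the last option above: take an operator name and a pass instead of an arbitrary format string. To stay dependency-free it uses `std::snprintf` rather than fmt or spdlog, and `format_profile_message` is a hypothetical helper, not something in the codebase. As far as I know, fmt spells the float specifier `{:.2f}`; the `lf` length modifier is a printf convention, so the `{:.2lf}` spelling is worth double-checking.

```cpp
#include <cstdio>
#include <string>

// Builds "[<op>] <pass>_time = <elapsed>ms" from its parts, so callers never
// pass a format string. No trailing newline is added; the assumption is that
// the logger appends one.
std::string format_profile_message(std::string const &op_name,
                                   std::string const &pass,
                                   double elapsed_ms) {
  char buf[256];
  int n = std::snprintf(buf,
                        sizeof(buf),
                        "[%s] %s_time = %.2fms",
                        op_name.c_str(),
                        pass.c_str(),
                        elapsed_ms);
  if (n < 0) {
    return std::string(); // encoding error: return an empty message
  }
  // snprintf always null-terminates, truncating if the message is too long.
  return std::string(buf);
}
```

Because the pass name is an argument, a backward task cannot silently log a forward-time label unless the caller passes the wrong pass explicitly.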

Collaborator

@lockshaw lockshaw left a comment

Reviewed 8 of 8 files at r10, all commit messages.
Reviewable status: all files reviewed, 1 unresolved discussion (waiting on @reyna-abhyankar)


lib/local-execution/include/profiling.h line 17 at r6 (raw file):

Previously, lockshaw (Colin Unger) wrote…

I don't think spdlog and Legion's logger use the same syntax (spdlog uses fmt, Legion uses printf-style), so we'll need to choose one--considering that we already use fmt all over the place (and it has support for non-primitive data structures), probably that, in which case here I think you'd do spdlog::debug(s, elapsed.value());, and then we'd have to update the strings that get passed in to be "[MultiHeadAttention] backward_time = {:.2lf}ms\n" I think. You should look into vformat in https://fmt.dev/latest/api.html and at https://github.com/flexflow/FlexFlow/blob/repo-refactor/lib/utils/include/utils/exception.h#L10-L15, though it may be best done slightly different than there. Another option is to, instead of taking in an arbitrary format string, take in an operator name and a pass, and just use "[{}] {}_time = {:.2lf}ms\n", which is probably best as we really don't need the complexity of passing in arbitrary format strings.

You need to update the format strings in all of the operators due to the change in syntax

Collaborator Author

@reyna-abhyankar reyna-abhyankar left a comment

Reviewable status: all files reviewed, 1 unresolved discussion (waiting on @lockshaw)


lib/kernels/include/kernels/gather_kernels.h line 12 at r9 (raw file):

  PerDeviceFFHandle handle;
  int legion_dim;
};

Done.


lib/kernels/src/cuda/ops/gather_kernels.cu line 42 at r9 (raw file):

    // Therefore, input_index = outter_index * (stride * input_dim_size)
    //                        + index[0] * stride + left_over;
    coord_t outter_index = o / (stride * output_dim_size);

Done.


lib/kernels/src/cuda/ops/gather_kernels.cu line 46 at r9 (raw file):

Previously, lockshaw (Colin Unger) wrote…

I assume this was just a typo?

Done.


lib/kernels/src/cuda/ops/gather_kernels.cu line 142 at r9 (raw file):

  for (int i = 0; i < m.legion_dim; i++) {
    stride *= output.shape[legion_dim_t(i)] + 1;
  }

Done.


lib/kernels/src/cuda/ops/gather_kernels.cu line 144 at r9 (raw file):

  }

  coord_t output_dim_size = output.shape[legion_dim_t(m.legion_dim)] + 1;

Done.


lib/kernels/src/cuda/ops/gather_kernels.cu line 145 at r9 (raw file):

  coord_t output_dim_size = output.shape[legion_dim_t(m.legion_dim)] + 1;
  coord_t input_dim_size = input.shape[legion_dim_t(m.legion_dim)] + 1;

Done.


lib/kernels/src/cuda/ops/gather_kernels.cu line 178 at r9 (raw file):

Previously, lockshaw (Colin Unger) wrote…

Same fixes as in forward kernel are needed here

Done.


lib/op-attrs/include/op-attrs/ops/gather.h line 12 at r9 (raw file):

Previously, lockshaw (Colin Unger) wrote…

And make sure everything flips the dim. Definitely don't use raw (i.e., int) dims anywhere as it is unclear if they are in legion ordering or ff ordering. I'd lean toward ff_dim_t at this level as legion_dim_t is only really necessary when dealing with the concrete data layouts

Done.


lib/substitutions/src/operator_attributes.cc line 132 at r9 (raw file):

  switch (key) {
    case OperatorAttributeKey::AXIS:
      return p.legion_dim;

Done.


lib/local-execution/include/profiling.h line 17 at r6 (raw file):

Previously, lockshaw (Colin Unger) wrote…

I don't think spdlog and Legion's logger use the same syntax (spdlog uses fmt, Legion uses printf-style), so we'll need to choose one--considering that we already use fmt all over the place (and it has support for non-primitive data structures), probably that, in which case here I think you'd do spdlog::debug(s, elapsed.value());, and then we'd have to update the strings that get passed in to be "[MultiHeadAttention] backward_time = {:.2lf}ms\n" I think. You should look into vformat in https://fmt.dev/latest/api.html and at https://github.com/flexflow/FlexFlow/blob/repo-refactor/lib/utils/include/utils/exception.h#L10-L15, though it may be best done slightly different than there. Another option is to, instead of taking in an arbitrary format string, take in an operator name and a pass, and just use "[{}] {}_time = {:.2lf}ms\n", which is probably best as we really don't need the complexity of passing in arbitrary format strings.

Done.

Collaborator

@lockshaw lockshaw left a comment

Reviewed 28 of 28 files at r11, all commit messages.
Reviewable status: all files reviewed, 1 unresolved discussion (waiting on @reyna-abhyankar)


lib/local-execution/src/ops/gather.cc line 112 at r11 (raw file):

  auto input_grad = acc.get_tensor_grad<Permissions::WO>(INPUT);

  return profile(forward_kernel,

Really should not be calling forward_kernel in the backward task impl, I assume?


lib/local-execution/src/ops/combine.cc line 67 at r11 (raw file):

  return profile(backward_kernel,
                 profiling,
                 "[Combine] forward_time = {:.2lf}ms\n",

Someone should probably fix this to be backward time eventually...


lib/local-execution/src/ops/embedding.cc line 79 at r11 (raw file):

  return profile(backward_kernel,
                 profiling,
                 "[Embedding] forward_time = {:.2lf}ms\n",

Should eventually be fixed to be backward time


lib/local-execution/src/ops/flat.cc line 51 at r11 (raw file):

  return profile(backward_kernel,
                 profiling,
                 "[Flat] forward_time = {:.2lf}ms\n",

Should eventually be fixed to be backward time

Collaborator Author

@reyna-abhyankar reyna-abhyankar left a comment

Reviewable status: 43 of 111 files reviewed, 1 unresolved discussion (waiting on @lockshaw)


lib/local-execution/src/ops/gather.cc line 112 at r11 (raw file):

Previously, lockshaw (Colin Unger) wrote…

Really should not be calling forward_kernel in the backward task impl, I assume?

Done.


lib/local-execution/src/ops/combine.cc line 67 at r11 (raw file):

Previously, lockshaw (Colin Unger) wrote…

Someone should probably fix this to be backward time eventually...

Done.


lib/local-execution/src/ops/embedding.cc line 79 at r11 (raw file):

Previously, lockshaw (Colin Unger) wrote…

Should eventually be fixed to be backward time

Done.


lib/local-execution/src/ops/flat.cc line 51 at r11 (raw file):

Previously, lockshaw (Colin Unger) wrote…

Should eventually be fixed to be backward time

Done.

Collaborator

@lockshaw lockshaw left a comment

Reviewed 68 of 68 files at r12, all commit messages.
Reviewable status: :shipit: complete! all files reviewed, all discussions resolved (waiting on @reyna-abhyankar)

Labels

runtime Runtime library
