
Embedding Op #1179

Closed
reyna-abhyankar wants to merge 8 commits into flexflow:repo-refactor from reyna-abhyankar:emb-eleu-layno

Conversation

@reyna-abhyankar
Collaborator

@reyna-abhyankar reyna-abhyankar commented Oct 7, 2023

Description of changes:

Ignore branch name

Related Issues:

Linked Issues:

  • Issue #

Issues closed by this PR:



Comment on lines +65 to +70
DeviceSpecific<ElementUnaryPerDeviceState> per_device_state =
acc.create_device_specific<ElementUnaryPerDeviceState>(
init_kernel(handle,
{input_shape.dims},
{output_shape.dims},
input_shape.data_type));
Collaborator Author

@lambda7xx the kernel takes ArrayShape and we have ParallelTensorShape

@lambda7xx
Contributor

lib/kernels/src/cuda/layer_norm_kernels.cu line 36 at r4 (raw file):

  checkCUDA(cudaMalloc(&rstd, sizeof(float) * batch_size));
  checkCUDA(cudaMalloc(&ds, sizeof(float) * batch_size));
  checkCUDA(cudaMalloc(&db, sizeof(float) * batch_size));

How about using the Allocator to allocate this memory?

@lambda7xx
Contributor

lib/runtime/src/ops/element_unary.cc line 70 at r2 (raw file):

Previously, reyna-abhyankar (Reyna Abhyankar) wrote…

@lambda7xx the kernel takes ArrayShape and we have ParallelTensorShape

I think we can use input to get its ArrayShape

@lambda7xx
Contributor

lib/runtime/src/ops/element_unary.cc line 70 at r2 (raw file):

Previously, lambda7xx (Lambda(Xiaoxiang) Shi ) wrote…

I think we can use input to get its ArrayShape

This call forgets to pass the op_type. The definition of init_kernel is below.

ElementUnaryPerDeviceState init_kernel(PerDeviceFFHandle const &handle,
                                       ArrayShape const &input_shape,
                                       ArrayShape const &output_shape,
                                       OperatorType op_type,
                                       DataType data_type)

@lambda7xx
Contributor

lib/runtime/src/ops/element_unary.cc line 216 at r4 (raw file):

  SimTaskBinding init_binding;
  init_binding.bind_arg(HANDLE, ff_handle());
  init_binding.bind_arg(ATTRS, attrs);

Can we bind ElementScalarUnaryAttrs const &attrs, so that we can then get the ElementUnaryPerDeviceState?

ElementScalarUnaryAttrs and ElementUnaryAttrs are different classes.

Contributor

@lambda7xx lambda7xx left a comment


Reviewed 1 of 5 files at r1.
Reviewable status: 1 of 13 files reviewed, 3 unresolved discussions (waiting on @lockshaw, @reyna-abhyankar, and @wmdi)

@lambda7xx
Contributor

lib/runtime/src/ops/embedding.h line 20 at r4 (raw file):

CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory,
                                  EmbeddingAttrs const &attrs,
                                  InputParallelTensorDesc const &input_shape,

Why InputParallelTensorDesc? The original code uses ParallelTensorShape.

@lambda7xx
Contributor

lib/runtime/src/ops/embedding.cc line 71 at r4 (raw file):

                 input.shape.get_dim(),
                 output.shape.get_dim(),
                 input.shape[legion_dim_t(1)]);

the original code:

    int out_dim = output.domain.hi()[0] - output.domain.lo()[0] + 1;
    int effective_batch_size = output.domain.get_volume() / out_dim;

so I think the batch_size should be:

    int out_dim = output.shape.at(ff_dim_t{0}) + 1;
    int batch_size = output.shape.get_volume() / out_dim;

@lambda7xx
Contributor

lib/runtime/src/ops/embedding.cc line 71 at r4 (raw file):

Previously, lambda7xx (Lambda(Xiaoxiang) Shi) wrote…

the original code:

    int out_dim = output.domain.hi()[0] - output.domain.lo()[0] + 1;
    int effective_batch_size = output.domain.get_volume() / out_dim;

so I think the batch_size should be:

    int out_dim = output.shape.at(ff_dim_t{0}) + 1;
    int batch_size = output.shape.get_volume() / out_dim;

@lambda7xx
Contributor

lib/runtime/src/ops/embedding.cc line 85 at r4 (raw file):

  auto input = acc.get_tensor<Permissions::RO>(INPUT);
  auto output = acc.get_tensor<Permissions::RO>(OUTPUT);
  auto weight_grad = acc.get_tensor_grad<Permissions::RO>(WEIGHT);

auto weight_grad = acc.get_tensor_grad<Permissions::RW>(WEIGHT);

@lambda7xx
Contributor

lib/runtime/src/ops/embedding.cc line 101 at r4 (raw file):

                 input.shape.get_dim(),
                 output.shape.get_dim(),
                 input.shape[legion_dim_t(1)]);

Is input.shape[legion_dim_t(1)] the batch_size? I don't think so.

@lambda7xx
Contributor

lib/kernels/src/cuda/element_unary_kernels.cu line 78 at r4 (raw file):

  ElementUnaryPerDeviceState per_device_state = {
      handle, inputTensor, outputTensor, actiDesc, op_type, data_type, scalar};

Where is scalar?

@lambda7xx
Contributor

lib/kernels/src/cuda/layer_norm_kernels.cu line 36 at r4 (raw file):

Previously, lambda7xx (Lambda(Xiaoxiang) Shi ) wrote…

how about use Allocator to allocate memory? and

I opened a PR about LayerNorm: LayerNorm OP draft by lambda7xx · Pull Request #1186 · flexflow/FlexFlow (github.com)

Collaborator Author

@reyna-abhyankar reyna-abhyankar left a comment


Reviewable status: 1 of 13 files reviewed, 7 unresolved discussions (waiting on @lambda7xx, @lockshaw, and @wmdi)


lib/kernels/src/cuda/element_unary_kernels.cu line 78 at r4 (raw file):

Previously, lambda7xx (Lambda(Xiaoxiang) Shi ) wrote…

where is scalar

Done. See other comment (I think we should merge the two attrs classes)


lib/kernels/src/cuda/layer_norm_kernels.cu line 36 at r4 (raw file):

Previously, lambda7xx (Lambda(Xiaoxiang) Shi ) wrote…

I opened a PR about LayerNorm: LayerNorm OP draft by lambda7xx · Pull Request #1186 · flexflow/FlexFlow (github.com)

Ok, we can use your layer norm PR
I think you can do something like this:

Code snippet:

int n = 6; // number of pointers
mean = (float *) allocator.allocate(sizeof(float) * batch_size * n);
rstd = mean + batch_size; // already a float *, no cast needed
...

lib/runtime/src/ops/element_unary.cc line 216 at r4 (raw file):

Previously, lambda7xx (Lambda(Xiaoxiang) Shi ) wrote…

can we bind ElementScalarUnaryAttrs const &attrs , then we can get ElementUnaryPerDeviceState?

The ElementScalarUnaryAttrs const &attrs and ElementUnaryAttrs const &attrs are different class.

I actually think we should merge them. @lockshaw the op type will determine what is executed in the kernel anyway


lib/runtime/src/ops/embedding.h line 20 at r4 (raw file):

Previously, lambda7xx (Lambda(Xiaoxiang) Shi ) wrote…

why InputParallelTensorDesc ? original code is ParallelTensorShape

InputParallelTensorDesc tells us if an input is trainable or not. Useful for the binding


lib/runtime/src/ops/embedding.cc line 71 at r4 (raw file):

Previously, lambda7xx (Lambda(Xiaoxiang) Shi ) wrote…

the original code

    int out_dim = output.domain.hi()[0] - output.domain.lo()[0] + 1;
    int effective_batch_size = output.domain.get_volume() / out_dim;


so I think the batch_size should be


int out_dim = output.shape.at(ff_dim_t{0}) + 1;
int batch_size = output.shape.get_volume() / out_dim;

@lockshaw


lib/runtime/src/ops/embedding.cc line 85 at r4 (raw file):

Previously, lambda7xx (Lambda(Xiaoxiang) Shi ) wrote…

auto weight_grad = acc.get_tensor_grad<Permissions::RW>(WEIGHT);

Done.


lib/runtime/src/ops/embedding.cc line 101 at r4 (raw file):

Previously, lambda7xx (Lambda(Xiaoxiang) Shi ) wrote…

the input.shape[legion_dim_t(1)]) is batch_size? I don't think so.

Done. See previous comment.

Collaborator

@lockshaw lockshaw left a comment


Reviewed 4 of 5 files at r1, 3 of 4 files at r2, 2 of 2 files at r3, 3 of 3 files at r4, all commit messages.
Reviewable status: all files reviewed, 13 unresolved discussions (waiting on @lambda7xx, @reyna-abhyankar, and @wmdi)


lib/kernels/include/kernels/element_unary_kernels.h line 19 at r4 (raw file):

  OperatorType op_type;
  DataType data_type;
  float scalar;

Does this really compile without the req?

Suggestion:

  req<float> scalar;

lib/kernels/include/kernels/layer_norm_kernels.h line 38 at r4 (raw file):

                     GenericTensorAccessorW const &beta_grad,
                     DataType data_type,
                     int64_t batch_size,

Isn't this part of the shape so it can just be accessed via input?


lib/kernels/include/kernels/layer_norm_kernels.h line 39 at r4 (raw file):

                     DataType data_type,
                     int64_t batch_size,
                     int64_t num_elements,

Isn't this part of the shape so it can jut be accessed via one of the weights?


lib/kernels/src/cuda/element_unary_kernels.cu line 78 at r4 (raw file):

Previously, reyna-abhyankar (Reyna Abhyankar) wrote…

Done. See other comment (I think we should merge the two attrs classes)

Are you still in favor of this after the meeting yesterday? I'd like to keep the attrs separate, though I don't really care what happens with them in kernels


lib/kernels/src/cuda/layer_norm_kernels.cu line 36 at r4 (raw file):

Previously, reyna-abhyankar (Reyna Abhyankar) wrote…

Ok, we can use your layer norm PR
I think you can do something like this:

Yeah cudaMalloc should be replaced by Allocator


lib/kernels/src/hip/layer_norm_kernels.cpp line 33 at r4 (raw file):

                                    int64_t effective_batch_size) {
  float *mean, *rstd, *ds, *db, *scale, *bias;
  checkCUDA(cudaMalloc(&mean, sizeof(float) * batch_size));

Use Allocator here


lib/kernels/src/hip/layer_norm_kernels.cpp line 191 at r4 (raw file):

                       GenericTensorAccessorW const &beta_grad,
                       DataType data_type,
                       int64_t batch_size,

Isn't this accessible through the tensor shapes?


lib/runtime/src/ops/element_unary.cc line 178 at r4 (raw file):

  init_binding.bind_arg(HANDLE, ff_handle());
  init_binding.bind_arg(ATTRS, attrs);
  init_binding.bind_arg(INPUT_SHAPE, input_parallel_tensor_shape(0));

Suggestion:

init_binding.bind_arg(INPUT_SHAPE, input_shape);

lib/runtime/src/ops/element_unary.cc line 216 at r4 (raw file):

Previously, reyna-abhyankar (Reyna Abhyankar) wrote…

I actually think we should merge them. @lockshaw the op type will determine what is executed in the kernel anyway

I'd like to keep them separate at the op-attrs level, but I don't care what happens to them at the runtime/ops and kernels levels


lib/runtime/src/ops/embedding.cc line 71 at r4 (raw file):

Previously, reyna-abhyankar (Reyna Abhyankar) wrote…

@lockshaw

Should be input.shape[ff_dim_t(0)] I think, as this is just a TensorShape and not a ParallelTensorShape and so there shouldn't be a parallel dimension present


lib/runtime/src/ops/embedding.cc line 85 at r4 (raw file):

Previously, reyna-abhyankar (Reyna Abhyankar) wrote…

Done.

I'm still seeing Permissions::RO here...


lib/runtime/src/ops/embedding.cc line 101 at r4 (raw file):

Previously, reyna-abhyankar (Reyna Abhyankar) wrote…

Done. See previous comment.

Should be ff_dim_t(0)

@lambda7xx
Contributor

lib/runtime/src/ops/embedding.cc line 71 at r4 (raw file):

Previously, lockshaw (Colin Unger) wrote…

Should be input.shape[ff_dim_t(0)] I think, as this is just a TensorShape and not a ParallelTensorShape and so there shouldn't be a parallel dimension present

What would the implementation look like, @lockshaw?

@reyna-abhyankar reyna-abhyankar deleted the emb-eleu-layno branch January 1, 2024 20:04