Simplify cat implementation #3215

Merged: naoyam merged 3 commits into main from cat_opt on Oct 19, 2024
Conversation

naoyam (Collaborator) commented Oct 18, 2024:

cat is translated to CUDA code with an if-then-else block:

if (idx < input0_ext) {
  out[idx] = input0[idx];
} else if (idx < input0_ext + input1_ext) {
  out[idx] = input1[idx];
} else if ...

This is correct, but since all of the inputs are padded with zeros to the output extent, it can be simplified to just out[idx] = input0[idx] + input1[idx] + ..., which produces an equivalent result.
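The zero-padding equivalence can be sketched on the host. This is a minimal illustration, not nvFuser code; pad_to and cat_by_sum are hypothetical helper names, and plain 1-D float vectors stand in for tensors:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Zero-pad an input to the full output extent, placing its values at the
// given offset and zeros everywhere else.
std::vector<float> pad_to(const std::vector<float>& in, std::size_t offset,
                          std::size_t out_len) {
  std::vector<float> out(out_len, 0.0f);
  for (std::size_t i = 0; i < in.size(); ++i) {
    out[offset + i] = in[i];
  }
  return out;
}

// Concatenate two inputs by summing their zero-padded forms: at every
// output position at most one padded input is nonzero, so the sum equals
// the branching cat. No if-then-else is needed in the per-element loop.
std::vector<float> cat_by_sum(const std::vector<float>& a,
                              const std::vector<float>& b) {
  const std::size_t n = a.size() + b.size();
  std::vector<float> pa = pad_to(a, 0, n);
  std::vector<float> pb = pad_to(b, a.size(), n);
  std::vector<float> out(n);
  for (std::size_t i = 0; i < n; ++i) {
    out[i] = pa[i] + pb[i];  // out[idx] = input0[idx] + input1[idx]
  }
  return out;
}
```

The same reasoning extends to any number of inputs, since each output element receives a real value from exactly one padded input and zeros from the rest.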

Since + is not defined for some low-precision types, bitwise-or is used instead when addition is not available.

On A100, this simplification yielded about 5% perf improvement of a RoPE module.

I didn't add any specific test since I'm not sure a new test would be beneficial. The half, bfloat16 and fp8 types are covered by ResizeTest.CatMemoryPromotionReducedFloating.

Diff context for the review thread below:

    return val;
    }

    __device__ __inline__ __e4m3 operator|(const __e4m3 x, const __e4m3 y) {
naoyam (Collaborator, Author):

Not exactly sure if this is reasonable. I don't know why memcpy is extensively used for these fp8 types.

Pinging @jjsjann123

Collaborator:

This does look cleaner :)

naoyam (Collaborator, Author):

I thought there may be some magic with memcpy.

naoyam (Collaborator, Author) commented Oct 18, 2024:

!build

@naoyam naoyam requested a review from jacobhinkle October 18, 2024 07:02
jacobhinkle (Collaborator) left a comment:

LGTM. I wonder if we could handle all types (other than ComplexDouble) by bitcasting to an integer of the right size then doing bitwise or. For this we would need Int8 and Int16, but we would not need support in the runtime files for bitwise ops on floats then.
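The bitcast idea above can be sketched in host C++. This is an illustrative assumption, not the PR's implementation; bitwise_or_f32 is a hypothetical helper, and a 32-bit float stands in for the 8/16-bit types discussed:

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Sketch of the reviewer's suggestion: bitcast a floating-point value to a
// same-width unsigned integer, do the bitwise-or there, and bitcast back.
// A zero-padded element is all-zero bits, so or-ing it with a real element
// reproduces the real element's bit pattern exactly. With this approach no
// operator| is needed on the floating-point type itself.
float bitwise_or_f32(float a, float b) {
  std::uint32_t a_bits;
  std::uint32_t b_bits;
  std::memcpy(&a_bits, &a, sizeof(float));  // safe type-punning via memcpy
  std::memcpy(&b_bits, &b, sizeof(float));
  const std::uint32_t r_bits = a_bits | b_bits;
  float r;
  std::memcpy(&r, &r_bits, sizeof(float));
  return r;
}
```

Applying the same pattern in the kernel would need integer types matching each element width (hence the Int8/Int16 requirement mentioned above).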

naoyam (Collaborator, Author) commented Oct 18, 2024:

> LGTM. I wonder if we could handle all types (other than ComplexDouble) by bitcasting to an integer of the right size then doing bitwise or. For this we would need Int8 and Int16, but we would not need support in the runtime files for bitwise ops on floats then.

Yeah, I did think about it, but I found it's just easier to handle these low-precision types separately than to add the new integer types, at least for now.

@naoyam naoyam merged commit c7818bd into main Oct 19, 2024
@naoyam naoyam deleted the cat_opt branch October 19, 2024 01:21