Hip refactor for attention, batch, combine, cast, conv #1402
reyna-abhyankar merged 7 commits into flexflow:repo-refactor from
Conversation
Codecov Report: All modified and coverable lines are covered by tests ✅
Additional details and impacted files
@@ Coverage Diff @@
## repo-refactor #1402 +/- ##
==============================================
Coverage 38.10% 38.10%
==============================================
Files 167 167
Lines 5026 5026
Branches 246 246
==============================================
Hits 1915 1915
Misses 3111 3111
Flags with carried forward coverage won't be shown.
reyna-abhyankar
left a comment
Reviewable status: 0 of 7 files reviewed, 3 unresolved discussions (waiting on @Bob-Chen222)
lib/kernels/src/hip/attention_kernels.cpp line 242 at r1 (raw file):
device_state.reserveSpaceSize, device_state.reserveSpace));
#endif
Delete
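For context (a hedged aside, not part of the review thread): `reserveSpaceSize`/`reserveSpace` belong to cuDNN's multi-head attention training path, and MIOpen has no equivalent, so code like this is dead in a HIP build. A minimal sketch of why deleting it beats carrying a dead guard (the macro name is an assumption, not the real one):

```cpp
#include <cassert>

// FF_USE_CUDNN_MHA is an illustrative macro name, not FlexFlow's actual
// guard. In the HIP build no cuDNN symbol is ever defined, so anything
// kept behind such a guard is unreachable and is better deleted outright.
#if defined(FF_USE_CUDNN_MHA)
constexpr bool has_reserve_space = true;   // cuDNN MHA training workspace
#else
constexpr bool has_reserve_space = false;  // HIP/MIOpen path: no reserveSpace API
#endif
```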
lib/kernels/src/hip/batch_norm_kernels.cpp line 119 at r1 (raw file):
checkCUDNN(miopenCreateTensorDescriptor(&outputTensor));
mode = miopenBNSpatial;
#if HIPDNN_VERSION >= 7000
Is this still true for HIP?
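(A hedged aside: the guard looks like a leftover from the cuDNN path, where `CUDNN_BATCHNORM_SPATIAL_PERSISTENT` only exists from cuDNN 7.0 onward; MIOpen's mode enum has no version-gated variants, so the HIP port can set the mode unconditionally. A sketch, mirroring the MIOpen enum values rather than including the real header:)

```cpp
#include <cassert>

// Sketch only: mimics MIOpen's miopenBatchNormMode_t values instead of
// including <miopen/miopen.h>, so the point compiles standalone.
enum miopenBatchNormMode_sketch {
  miopenBNPerActivation = 0,
  miopenBNSpatial = 1,
};

// Unlike cuDNN, where the persistent spatial mode is gated on version
// >= 7000, MIOpen exposes only the two modes above in every release,
// so no #if HIPDNN_VERSION guard is needed when choosing the mode.
miopenBatchNormMode_sketch pick_batch_norm_mode() {
  return miopenBNSpatial;
}
```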
lib/kernels/src/hip/cast_kernels.cpp line 75 at r1 (raw file):
};
void forward_kernel(PerDeviceFFHandle handle,
Actually just keep stream as the first parameter for both functions. I'll change this in the cuda kernel as well.
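To illustrate the convention being requested, a minimal sketch with stream first in both signatures (the type names and bodies are placeholders for illustration, not FlexFlow's actual types):

```cpp
#include <cassert>
#include <cstddef>

// Placeholder stand-ins for the real FlexFlow/HIP types; in the real
// kernels the stream would be a hipStream_t and the bodies would launch
// device kernels on it.
using ffStream_t = int;
struct PerDeviceFFHandle {};

// Both entry points take the stream as the FIRST parameter, so every
// kernel wrapper shares a single calling convention.
void forward_kernel(ffStream_t stream, PerDeviceFFHandle const &handle,
                    float const *input, float *output, size_t n) {
  (void)stream;
  (void)handle;
  for (size_t i = 0; i < n; i++) {
    output[i] = input[i]; // placeholder for the device-side copy/cast
  }
}

void backward_kernel(ffStream_t stream, PerDeviceFFHandle const &handle,
                     float const *grad_out, float *grad_in, size_t n) {
  (void)stream;
  (void)handle;
  for (size_t i = 0; i < n; i++) {
    grad_in[i] = grad_out[i]; // placeholder for the device-side gradient copy
  }
}
```

Keeping the stream first in both forward and backward avoids the asymmetry the review flags, and matches the plan to apply the same change to the CUDA kernels.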
Bob-Chen222
left a comment
Reviewable status: 0 of 7 files reviewed, 3 unresolved discussions (waiting on @reyna-abhyankar)
lib/kernels/src/hip/attention_kernels.cpp line 242 at r1 (raw file):
Previously, reyna-abhyankar (Reyna Abhyankar) wrote…
Delete
Done.
lib/kernels/src/hip/batch_norm_kernels.cpp line 119 at r1 (raw file):
Previously, reyna-abhyankar (Reyna Abhyankar) wrote…
Is this still true for HIP?
Done.
lib/kernels/src/hip/cast_kernels.cpp line 75 at r1 (raw file):
Previously, reyna-abhyankar (Reyna Abhyankar) wrote…
Actually just keep stream as the first parameter for both functions. I'll change this in the cuda kernel as well.
Done.
Description of changes:
Hip refactor for attention, batch, combine, cast, conv
Related Issues:
Linked Issues:
Issues closed by this PR: