The current CP implementation gathers the full output tensor, of shape [b, s, v], before calculating the loss. Because the vocabulary dimension v is large, this all-gather spends a lot of time on communication. It is easier to implement, but it costs some performance.
Solution:
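Allow the sequence-sharded DTensor, whose sequence chunks may be out of order, to be passed directly to the loss function so that the loss is calculated in parallel. Cross-entropy is computed per token, so the sequence order does not change the result as long as the labels are sharded with the same layout. Each rank then only needs to reduce its partial loss sum and token count.
This is a follow-up to issue 41.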
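To show why out-of-order shards are safe for the loss, here is a plain-Python sketch. It is not DTensor code, and it models a single sequence with the batch dimension omitted. The load-balanced layout, in which rank r holds chunks r and 2·world_size−1−r, is an assumption made for illustration. The function names (`token_nll`, `shard_out_of_order`, `parallel_loss`) are hypothetical. Each simulated rank computes a local (sum, count) pair on its shard. The final reduction stands in for an all-reduce of two scalars, which replaces the all-gather of the [s, v] logits.

```python
import math
import random


def token_nll(logits, label):
    # Negative log-likelihood of one token via a numerically stable log-softmax.
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_z - logits[label]


def full_loss(logits, labels):
    # Baseline: the whole sequence is available (as after an all-gather).
    return sum(token_nll(l, y) for l, y in zip(logits, labels)) / len(labels)


def shard_out_of_order(seq, world_size):
    # Hypothetical load-balanced layout: split into 2 * world_size chunks and
    # give rank r chunks r and (2 * world_size - 1 - r), so positions are out of order.
    # Assumes len(seq) is divisible by 2 * world_size.
    n_chunks = 2 * world_size
    size = len(seq) // n_chunks
    chunks = [seq[i * size:(i + 1) * size] for i in range(n_chunks)]
    return [chunks[r] + chunks[n_chunks - 1 - r] for r in range(world_size)]


def parallel_loss(logits, labels, world_size):
    # Labels use the same layout as logits, so every row keeps its own label.
    logit_shards = shard_out_of_order(logits, world_size)
    label_shards = shard_out_of_order(labels, world_size)
    partials = [
        (sum(token_nll(l, y) for l, y in zip(ls, ys)), len(ys))
        for ls, ys in zip(logit_shards, label_shards)
    ]
    # Stand-in for an all-reduce of two scalars per rank.
    total = sum(p[0] for p in partials)
    count = sum(p[1] for p in partials)
    return total / count


if __name__ == "__main__":
    random.seed(0)
    s, v = 16, 5
    logits = [[random.gauss(0.0, 1.0) for _ in range(v)] for _ in range(s)]
    labels = [random.randrange(v) for _ in range(s)]
    print(full_loss(logits, labels), parallel_loss(logits, labels, world_size=2))
```

The two printed values agree up to floating-point summation order. In a real CP setup, the per-rank communication would shrink from gathering b·s·v logits to reducing two scalars.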