diff --git a/nemo/collections/asr/modules/rnnt.py b/nemo/collections/asr/modules/rnnt.py
index e8d5a8efdc56..4604bc27fe9f 100644
--- a/nemo/collections/asr/modules/rnnt.py
+++ b/nemo/collections/asr/modules/rnnt.py
@@ -845,8 +845,6 @@ def forward(
         )
 
         losses = []
-        wer_numer_list = []
-        wer_denom_list = []
         batch_size = int(encoder_outputs.size(0))  # actual batch size
 
         # Iterate over batch using fused_batch_size steps
@@ -914,31 +912,14 @@ def forward(
             else:
                 losses = None
 
-            # Compute WER for sub batch
+            # Update WER for sub batch
             if compute_wer:
                 sub_enc = sub_enc.transpose(1, 2)  # [B, T, D] -> [B, D, T]
                 sub_enc = sub_enc.detach()
                 sub_transcripts = sub_transcripts.detach()
 
-                original_log_prediction = self.wer.log_prediction
-                if original_log_prediction and batch_idx == 0:
-                    self.wer.log_prediction = True
-                else:
-                    self.wer.log_prediction = False
-
-                # Compute the wer (with logging for just 1st sub-batch)
+                # Update WER on each process without syncing
                 self.wer.update(sub_enc, sub_enc_lens, sub_transcripts, sub_transcript_lens)
-                wer, wer_num, wer_denom = self.wer.compute()
-                self.wer.reset()
-
-                wer_numer_list.append(wer_num)
-                wer_denom_list.append(wer_denom)
-
-                # Reset logging default
-                self.wer.log_prediction = original_log_prediction
-
-            else:
-                wer = None
 
             del sub_enc, sub_transcripts, sub_enc_lens, sub_transcript_lens
 
@@ -951,12 +932,11 @@ def forward(
 
         # Collect sub batch wer results
        if compute_wer:
-            wer_num = torch.tensor(wer_numer_list, dtype=torch.long)
-            wer_denom = torch.tensor(wer_denom_list, dtype=torch.long)
-
-            wer_num = wer_num.sum()  # global sum of correct words/chars
-            wer_denom = wer_denom.sum()  # global sum of all words/chars
+            # Sync and all_reduce on all processes, compute global WER
+            wer, wer_num, wer_denom = self.wer.compute()
+            self.wer.reset()
         else:
+            wer = None
             wer_num = None
             wer_denom = None
 
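For reference, the pattern this diff moves to is the standard torchmetrics accumulate-then-compute flow: cheap local `update()` calls per sub-batch, and a single synced `compute()` (followed by `reset()`) once the whole batch is done. Below is a minimal sketch, assuming a torchmetrics-style metric; the class and state names (`SubBatchWER`, `errors`, `total`) are illustrative, not NeMo's actual `WER` implementation:

```python
import torch
from torchmetrics import Metric


class SubBatchWER(Metric):
    """Accumulates error/word counts locally; syncs only in compute()."""

    def __init__(self):
        super().__init__()
        # dist_reduce_fx="sum" -> states are all-reduced across processes
        # when compute() is called, not on every update().
        self.add_state("errors", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, errors: torch.Tensor, total: torch.Tensor) -> None:
        # Per-sub-batch accumulation; no collective ops here.
        self.errors += errors
        self.total += total

    def compute(self):
        # Runs once per batch, after torchmetrics syncs the states; mirrors
        # the (wer, wer_num, wer_denom) tuple returned in the diff above.
        return self.errors.float() / self.total, self.errors, self.total
```

With this shape, the fused loop only pays for tensor additions per sub-batch, and the distributed all-reduce happens once per batch inside `compute()` rather than once per sub-batch, which is what dropping the per-sub-batch `compute()`/`reset()` calls buys.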