diff --git a/nemo/collections/nlp/modules/common/text_generation_server.py b/nemo/collections/nlp/modules/common/text_generation_server.py index 40c9dc385e5e..3939f82f3e0d 100644 --- a/nemo/collections/nlp/modules/common/text_generation_server.py +++ b/nemo/collections/nlp/modules/common/text_generation_server.py @@ -158,6 +158,9 @@ def put(self): repetition_penalty, min_tokens_to_generate, ) + for k in output: + if isinstance(output[k], torch.Tensor): + output[k] = output[k].tolist() if not all_probs: del output['full_logprob'] return jsonify(output)