From 1a1d7d893a11a8d609b2ae2b82037179b75934dd Mon Sep 17 00:00:00 2001 From: KakaruHayate <97896816+KakaruHayate@users.noreply.github.com> Date: Tue, 5 Nov 2024 12:41:44 +0800 Subject: [PATCH] Add Pytorch version check when export onnx --- scripts/export.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/scripts/export.py b/scripts/export.py index 537cdad9f..613b3ee45 100644 --- a/scripts/export.py +++ b/scripts/export.py @@ -14,6 +14,16 @@ from utils.hparams import set_hparams, hparams +def check_pytorch_version(): + version = torch.__version__ + print(f"PyTorch version: {version}") + major, minor, _ = version.split('.') + if major != '1' and minor != '13': + raise RuntimeError(f"Unsupported PyTorch Version: {version}. need 1.13.x.") + else: + pass + + def find_exp(exp): if not (root_dir / 'checkpoints' / exp).exists(): for subdir in (root_dir / 'checkpoints').iterdir(): @@ -291,4 +301,5 @@ def nsf_hifigan( if __name__ == '__main__': + check_pytorch_version() main()