diff --git a/python/tvm/driver/tvmc/runner.py b/python/tvm/driver/tvmc/runner.py index fec8224ceb17..916139874579 100644 --- a/python/tvm/driver/tvmc/runner.py +++ b/python/tvm/driver/tvmc/runner.py @@ -50,7 +50,7 @@ def add_run_parser(subparsers): # like 'webgpu', etc (@leandron) parser.add_argument( "--device", - choices=["cpu", "cuda", "cl", "metal"], + choices=["cpu", "cuda", "cl", "metal", "vulkan"], default="cpu", help="target device to run the compiled module. Defaults to 'cpu'", ) @@ -392,6 +392,8 @@ def run_module( dev = session.cl() elif device == "metal": dev = session.metal() + elif device == "vulkan": + dev = session.vulkan() else: assert device == "cpu" dev = session.cpu()