diff --git a/include/tvm/runtime/contrib/libtorch_runtime.h b/include/tvm/runtime/contrib/libtorch_runtime.h new file mode 100644 index 000000000000..2645fb94d10d --- /dev/null +++ b/include/tvm/runtime/contrib/libtorch_runtime.h @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \brief runtime implementation for LibTorch/TorchScript. + */ +#ifndef TVM_RUNTIME_CONTRIB_LIBTORCH_RUNTIME_H_ +#define TVM_RUNTIME_CONTRIB_LIBTORCH_RUNTIME_H_ +#include + +#include + +namespace tvm { +namespace runtime { +namespace contrib { + +runtime::Module TorchRuntimeCreate(const String& symbol_name, + const std::string& serialized_function); + +} // namespace contrib +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_CONTRIB_LIBTORCH_RUNTIME_H_ diff --git a/src/relay/backend/contrib/libtorch/libtorch_codegen.cc b/src/relay/backend/contrib/libtorch/libtorch_codegen.cc index 25bfbfad4443..f70466f00eed 100644 --- a/src/relay/backend/contrib/libtorch/libtorch_codegen.cc +++ b/src/relay/backend/contrib/libtorch/libtorch_codegen.cc @@ -32,7 +32,7 @@ #include #include #include -#include +#include #include #include #include diff --git a/src/runtime/contrib/libtorch/libtorch_runtime.cc b/src/runtime/contrib/libtorch/libtorch_runtime.cc index 5076b967a1de..e76d04389ec7 100644 --- a/src/runtime/contrib/libtorch/libtorch_runtime.cc +++ b/src/runtime/contrib/libtorch/libtorch_runtime.cc @@ -27,6 +27,7 @@ #include #include #include +#include #include #include diff --git a/tests/python/contrib/test_libtorch_ops.py b/tests/python/contrib/test_libtorch_ops.py index 751a547f94f5..28ae39c329f5 100644 --- a/tests/python/contrib/test_libtorch_ops.py +++ b/tests/python/contrib/test_libtorch_ops.py @@ -20,13 +20,16 @@ import tvm.relay from tvm.relay.op.contrib import torchop +import_torch_error = None + try: import torch -except ImportError as _: +except ImportError as e: torch = None + import_torch_error = str(e) -@pytest.mark.skipif(torch is None, reason="PyTorch is not available") +@pytest.mark.skipif(torch is None, reason=f"PyTorch is not available: {import_torch_error}") def test_backend(): @torch.jit.script def script_fn(x, y):