diff --git a/CMakeLists.txt b/CMakeLists.txt index 6ce0cdc129db..d67f7fd59aee 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -79,6 +79,7 @@ tvm_option(USE_ARM_COMPUTE_LIB "Build with Arm Compute Library" OFF) tvm_option(USE_ARM_COMPUTE_LIB_GRAPH_RUNTIME "Build with Arm Compute Library graph runtime" OFF) tvm_option(USE_TENSORRT_CODEGEN "Build with TensorRT Codegen support" OFF) tvm_option(USE_TENSORRT_RUNTIME "Build with TensorRT runtime" OFF) +tvm_option(USE_RUST_EXT "Build with Rust based compiler extensions, STATIC, DYNAMIC, or OFF" OFF) # include directories include_directories(${CMAKE_INCLUDE_PATH}) @@ -352,6 +353,7 @@ include(cmake/modules/contrib/ArmComputeLib.cmake) include(cmake/modules/contrib/TensorRT.cmake) include(cmake/modules/Git.cmake) include(cmake/modules/LibInfo.cmake) +include(cmake/modules/RustExt.cmake) include(CheckCXXCompilerFlag) if(NOT MSVC) diff --git a/cmake/modules/LLVM.cmake b/cmake/modules/LLVM.cmake index 5f8ace17111f..ac870b17faeb 100644 --- a/cmake/modules/LLVM.cmake +++ b/cmake/modules/LLVM.cmake @@ -16,7 +16,14 @@ # under the License. # LLVM rules -add_definitions(-DDMLC_USE_FOPEN64=0) +# Due to LLVM debug symbols you can sometimes face linking issues on +# certain compiler, platform combinations if you don't set NDEBUG. +# +# See https://github.com/imageworks/OpenShadingLanguage/issues/1069 +# for more discussion. +add_definitions(-DDMLC_USE_FOPEN64=0 -DNDEBUG=1) +# TODO(@jroesch, @tkonolige): if we actually use targets we can do this. +# target_compile_definitions(tvm PRIVATE NDEBUG=1) # Test if ${USE_LLVM} is not an explicit boolean false # It may be a boolean or a string diff --git a/cmake/modules/RustExt.cmake b/cmake/modules/RustExt.cmake new file mode 100644 index 000000000000..2922bc48dee2 --- /dev/null +++ b/cmake/modules/RustExt.cmake @@ -0,0 +1,43 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +if(USE_RUST_EXT) + set(RUST_SRC_DIR "${CMAKE_SOURCE_DIR}/rust") + set(CARGO_OUT_DIR "${CMAKE_SOURCE_DIR}/rust/target") + + if(USE_RUST_EXT STREQUAL "STATIC") + set(COMPILER_EXT_PATH "${CARGO_OUT_DIR}/release/libcompiler_ext.a") + elseif(USE_RUST_EXT STREQUAL "DYNAMIC") + set(COMPILER_EXT_PATH "${CARGO_OUT_DIR}/release/libcompiler_ext.so") + else() + message(FATAL_ERROR "invalid setting for USE_RUST_EXT, STATIC, DYNAMIC or OFF") + endif() + + add_custom_command( + OUTPUT "${COMPILER_EXT_PATH}" + COMMAND cargo build --release + MAIN_DEPENDENCY "${RUST_SRC_DIR}" + WORKING_DIRECTORY "${RUST_SRC_DIR}/compiler-ext") + + add_custom_target(rust_ext ALL DEPENDS "${COMPILER_EXT_PATH}") + + # TODO(@jroesch, @tkonolige): move this to CMake target + # target_link_libraries(tvm "${COMPILER_EXT_PATH}" PRIVATE) + list(APPEND TVM_LINKER_LIBS ${COMPILER_EXT_PATH}) + + add_definitions(-DRUST_COMPILER_EXT=1) +endif() diff --git a/include/tvm/parser/source_map.h b/include/tvm/parser/source_map.h index 5595574265c6..5316c8bd2b33 100644 --- a/include/tvm/parser/source_map.h +++ b/include/tvm/parser/source_map.h @@ -103,8 +103,6 @@ class SourceMap : public ObjectRef { TVM_DLL SourceMap() : SourceMap({}) {} - TVM_DLL static SourceMap Global(); - void Add(const Source& source); SourceMapNode* operator->() { diff --git a/python/tvm/ir/diagnostics/__init__.py b/python/tvm/ir/diagnostics/__init__.py index 6503743aaa51..3a6402c0359d 100644 --- a/python/tvm/ir/diagnostics/__init__.py +++ b/python/tvm/ir/diagnostics/__init__.py @@ -38,6 +38,7 @@ def get_renderer(): return _ffi_api.GetRenderer() +@tvm.register_func("diagnostics.override_renderer") def override_renderer(render_func): """ Sets a custom renderer for diagnostics. diff --git a/rust/Cargo.toml b/rust/Cargo.toml index 9935ce7c8b9f..6e14c2b02c2b 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -28,4 +28,5 @@ members = [ "tvm-graph-rt/tests/test_tvm_dso", "tvm-graph-rt/tests/test_wasm32", "tvm-graph-rt/tests/test_nn", + "compiler-ext", ] diff --git a/rust/compiler-ext/Cargo.toml b/rust/compiler-ext/Cargo.toml new file mode 100644 index 000000000000..b830b7a84135 --- /dev/null +++ b/rust/compiler-ext/Cargo.toml @@ -0,0 +1,32 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "compiler-ext" +version = "0.1.0" +authors = ["TVM Contributors"] +edition = "2018" + +[lib] +crate-type = ["staticlib", "cdylib"] + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +tvm = { path = "../tvm", default-features = false, features = ["static-linking"] } +log = "*" +env_logger = "*" diff --git a/rust/compiler-ext/src/lib.rs b/rust/compiler-ext/src/lib.rs new file mode 100644 index 000000000000..278060ef4897 --- /dev/null +++ b/rust/compiler-ext/src/lib.rs @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use env_logger; +use tvm::export; + +fn diagnostics() -> Result<(), tvm::Error> { + tvm::ir::diagnostics::codespan::init() +} + +export!(diagnostics); + +#[no_mangle] +extern "C" fn compiler_ext_initialize() -> i32 { + let _ = env_logger::try_init(); + tvm_export("rust_ext").expect("failed to initialize the Rust compiler extensions."); + log::debug!("Loaded the Rust compiler extension."); + return 0; +} diff --git a/rust/tvm-rt/Cargo.toml b/rust/tvm-rt/Cargo.toml index acece5aeec48..9660943da50d 100644 --- a/rust/tvm-rt/Cargo.toml +++ b/rust/tvm-rt/Cargo.toml @@ -28,19 +28,26 @@ categories = ["api-bindings", "science"] authors = ["TVM Contributors"] edition = "2018" +[features] +default = ["dynamic-linking"] +dynamic-linking = ["tvm-sys/bindings"] +static-linking = [] +blas = ["ndarray/blas"] + [dependencies] thiserror = "^1.0" ndarray = "0.12" num-traits = "0.2" -tvm-sys = { version = "0.1", path = "../tvm-sys/", features = ["bindings"] } tvm-macros = { version = "0.1", path = "../tvm-macros" } paste = "0.1" mashup = "0.1" once_cell = "^1.3.1" memoffset = "0.5.6" +[dependencies.tvm-sys] +version = "0.1" +default-features = false +path = "../tvm-sys/" + [dev-dependencies] anyhow = "^1.0" - -[features] -blas = ["ndarray/blas"] diff --git a/rust/tvm-rt/src/array.rs b/rust/tvm-rt/src/array.rs index 5e19cefd8e97..98414f9c5b34 100644 --- a/rust/tvm-rt/src/array.rs +++ b/rust/tvm-rt/src/array.rs @@ -18,6 +18,7 @@ */ use std::convert::{TryFrom, TryInto}; +use std::iter::{IntoIterator, Iterator}; use std::marker::PhantomData; use crate::errors::Error; @@ -81,6 +82,42 @@ impl Array { } } +pub struct IntoIter { + array: Array, + pos: isize, + size: isize, +} + +impl Iterator for IntoIter { + type Item = T; + + fn next(&mut self) -> Option { + if self.pos < self.size { + let item = + self.array.get(self.pos) + .expect("Can not index as in-bounds position after bounds checking.\nNote: this error can only be do to an uncaught issue with API bindings."); + self.pos += 1; + Some(item) + } else { + None + } + } +} + +impl IntoIterator for Array { + type Item = T; + type IntoIter = IntoIter; + + fn into_iter(self) -> Self::IntoIter { + let size = self.len() as isize; + IntoIter { + array: self, + pos: 0, + size: size, + } + } +} + impl From> for ArgValue<'static> { fn from(array: Array) -> ArgValue<'static> { array.object.into() diff --git a/rust/tvm-rt/src/errors.rs b/rust/tvm-rt/src/errors.rs index c884c56fed44..31ce385ef662 100644 --- a/rust/tvm-rt/src/errors.rs +++ b/rust/tvm-rt/src/errors.rs @@ -68,6 +68,23 @@ pub enum Error { Infallible(#[from] std::convert::Infallible), #[error("a panic occurred while executing a Rust packed function")] Panic, + #[error( + "one or more error diagnostics were emitted, please check diagnostic render for output." + )] + DiagnosticError(String), + #[error("{0}")] + Raw(String), +} + +impl Error { + pub fn from_raw_tvm(raw: &str) -> Error { + let err_header = raw.find(":").unwrap_or(0); + let (err_ty, err_content) = raw.split_at(err_header); + match err_ty { + "DiagnosticError" => Error::DiagnosticError((&err_content[1..]).into()), + _ => Error::Raw(raw.into()), + } + } } impl Error { diff --git a/rust/tvm-rt/src/function.rs b/rust/tvm-rt/src/function.rs index bae06e929361..aec4a8ad44de 100644 --- a/rust/tvm-rt/src/function.rs +++ b/rust/tvm-rt/src/function.rs @@ -120,24 +120,27 @@ impl Function { let mut ret_val = ffi::TVMValue { v_int64: 0 }; let mut ret_type_code = 0i32; - check_call!(ffi::TVMFuncCall( - self.handle, - values.as_mut_ptr() as *mut ffi::TVMValue, - type_codes.as_mut_ptr() as *mut c_int, - num_args as c_int, - &mut ret_val as *mut _, - &mut ret_type_code as *mut _ - )); + let ret_code = unsafe { + ffi::TVMFuncCall( + self.handle, + values.as_mut_ptr() as *mut ffi::TVMValue, + type_codes.as_mut_ptr() as *mut c_int, + num_args as c_int, + &mut ret_val as *mut _, + &mut ret_type_code as *mut _, + ) + }; + + if ret_code != 0 { + let raw_error = crate::get_last_error(); + let error = match Error::from_raw_tvm(raw_error) { + Error::Raw(string) => Error::CallFailed(string), + e => e, + }; + return Err(error); + } let rv = RetValue::from_tvm_value(ret_val, ret_type_code as u32); - match rv { - RetValue::ObjectHandle(object) => { - let optr = crate::object::ObjectPtr::from_raw(object as _).unwrap(); - // println!("after wrapped call: {}", optr.count()); - crate::object::ObjectPtr::leak(optr); - } - _ => {} - }; Ok(rv) } diff --git a/rust/tvm-rt/src/object/object_ptr.rs b/rust/tvm-rt/src/object/object_ptr.rs index 77254d2fbca2..8d535368c352 100644 --- a/rust/tvm-rt/src/object/object_ptr.rs +++ b/rust/tvm-rt/src/object/object_ptr.rs @@ -125,7 +125,7 @@ impl Object { /// By using associated constants and generics we can provide a /// type indexed abstraction over allocating objects with the /// correct index and deleter. - pub fn base_object() -> Object { + pub fn base() -> Object { let index = Object::get_type_index::(); Object::new(index, delete::) } @@ -351,7 +351,7 @@ mod tests { #[test] fn test_new_object() -> anyhow::Result<()> { - let object = Object::base_object::(); + let object = Object::base::(); let ptr = ObjectPtr::new(object); assert_eq!(ptr.count(), 1); Ok(()) @@ -359,7 +359,7 @@ mod tests { #[test] fn test_leak() -> anyhow::Result<()> { - let ptr = ObjectPtr::new(Object::base_object::()); + let ptr = ObjectPtr::new(Object::base::()); assert_eq!(ptr.count(), 1); let object = ObjectPtr::leak(ptr); assert_eq!(object.count(), 1); @@ -368,7 +368,7 @@ mod tests { #[test] fn test_clone() -> anyhow::Result<()> { - let ptr = ObjectPtr::new(Object::base_object::()); + let ptr = ObjectPtr::new(Object::base::()); assert_eq!(ptr.count(), 1); let ptr2 = ptr.clone(); assert_eq!(ptr2.count(), 2); @@ -379,7 +379,7 @@ mod tests { #[test] fn roundtrip_retvalue() -> Result<()> { - let ptr = ObjectPtr::new(Object::base_object::()); + let ptr = ObjectPtr::new(Object::base::()); assert_eq!(ptr.count(), 1); let ret_value: RetValue = ptr.clone().into(); let ptr2: ObjectPtr = ret_value.try_into()?; @@ -401,7 +401,7 @@ mod tests { #[test] fn roundtrip_argvalue() -> Result<()> { - let ptr = ObjectPtr::new(Object::base_object::()); + let ptr = ObjectPtr::new(Object::base::()); assert_eq!(ptr.count(), 1); let ptr_clone = ptr.clone(); assert_eq!(ptr.count(), 2); @@ -435,7 +435,7 @@ mod tests { fn test_ref_count_boundary3() { use super::*; use crate::function::{register, Function}; - let ptr = ObjectPtr::new(Object::base_object::()); + let ptr = ObjectPtr::new(Object::base::()); assert_eq!(ptr.count(), 1); let stay = ptr.clone(); assert_eq!(ptr.count(), 2); diff --git a/rust/tvm-rt/src/string.rs b/rust/tvm-rt/src/string.rs index 6ff24bef3a60..3cd33a226d44 100644 --- a/rust/tvm-rt/src/string.rs +++ b/rust/tvm-rt/src/string.rs @@ -38,7 +38,7 @@ impl From for String { fn from(s: std::string::String) -> Self { let size = s.len() as u64; let data = Box::into_raw(s.into_boxed_str()).cast(); - let base = Object::base_object::(); + let base = Object::base::(); StringObj { base, data, size }.into() } } @@ -47,7 +47,7 @@ impl From<&'static str> for String { fn from(s: &'static str) -> Self { let size = s.len() as u64; let data = s.as_bytes().as_ptr(); - let base = Object::base_object::(); + let base = Object::base::(); StringObj { base, data, size }.into() } } diff --git a/rust/tvm-sys/Cargo.toml b/rust/tvm-sys/Cargo.toml index 4e3fc98b4e75..2952aa4938d7 100644 --- a/rust/tvm-sys/Cargo.toml +++ b/rust/tvm-sys/Cargo.toml @@ -23,6 +23,7 @@ license = "Apache-2.0" edition = "2018" [features] +default = [] bindings = [] [dependencies] diff --git a/rust/tvm-sys/build.rs b/rust/tvm-sys/build.rs index 05806c0d5ce0..159023463e8d 100644 --- a/rust/tvm-sys/build.rs +++ b/rust/tvm-sys/build.rs @@ -60,7 +60,7 @@ fn main() -> Result<()> { if cfg!(feature = "bindings") { println!("cargo:rerun-if-env-changed=TVM_HOME"); println!("cargo:rustc-link-lib=dylib=tvm"); - println!("cargo:rustc-link-search={}/build", tvm_home); + println!("cargo:rustc-link-search=native={}/build", tvm_home); } // @see rust-bindgen#550 for `blacklist_type` diff --git a/rust/tvm-sys/src/packed_func.rs b/rust/tvm-sys/src/packed_func.rs index f7b289c59675..7b8d5296d641 100644 --- a/rust/tvm-sys/src/packed_func.rs +++ b/rust/tvm-sys/src/packed_func.rs @@ -101,6 +101,7 @@ macro_rules! TVMPODValue { TVMArgTypeCode_kTVMOpaqueHandle => Handle($value.v_handle), TVMArgTypeCode_kTVMDLTensorHandle => ArrayHandle($value.v_handle as TVMArrayHandle), TVMArgTypeCode_kTVMObjectHandle => ObjectHandle($value.v_handle), + TVMArgTypeCode_kTVMObjectRValueRefArg => ObjectHandle(*($value.v_handle as *mut *mut c_void)), TVMArgTypeCode_kTVMModuleHandle => ModuleHandle($value.v_handle), TVMArgTypeCode_kTVMPackedFuncHandle => FuncHandle($value.v_handle), TVMArgTypeCode_kTVMNDArrayHandle => NDArrayHandle($value.v_handle), diff --git a/rust/tvm/Cargo.toml b/rust/tvm/Cargo.toml index 55fc1790604e..153a1950e46b 100644 --- a/rust/tvm/Cargo.toml +++ b/rust/tvm/Cargo.toml @@ -28,22 +28,32 @@ categories = ["api-bindings", "science"] authors = ["TVM Contributors"] edition = "2018" +[features] +default = ["python", "dynamic-linking"] +dynamic-linking = ["tvm-rt/dynamic-linking"] +static-linking = ["tvm-rt/static-linking"] +blas = ["ndarray/blas"] +python = ["pyo3"] + +[dependencies.tvm-rt] +version = "0.1" +default-features = false +path = "../tvm-rt/" + [dependencies] thiserror = "^1.0" anyhow = "^1.0" lazy_static = "1.1" ndarray = "0.12" num-traits = "0.2" -tvm-rt = { version = "0.1", path = "../tvm-rt/" } -tvm-sys = { version = "0.1", path = "../tvm-sys/" } tvm-macros = { version = "*", path = "../tvm-macros/" } paste = "0.1" mashup = "0.1" once_cell = "^1.3.1" pyo3 = { version = "0.11.1", optional = true } +codespan-reporting = "0.9.5" +structopt = { version = "0.3" } -[features] -default = ["python"] - -blas = ["ndarray/blas"] -python = ["pyo3"] +[[bin]] +name = "tyck" +required-features = ["dynamic-linking"] diff --git a/rust/tvm/src/bin/tyck.rs b/rust/tvm/src/bin/tyck.rs new file mode 100644 index 000000000000..839a6bd1c17f --- /dev/null +++ b/rust/tvm/src/bin/tyck.rs @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use std::path::PathBuf; + +use anyhow::Result; +use structopt::StructOpt; + +use tvm::ir::diagnostics::codespan; +use tvm::ir::{self, IRModule}; +use tvm::runtime::Error; + +#[derive(Debug, StructOpt)] +#[structopt(name = "tyck", about = "Parse and type check a Relay program.")] +struct Opt { + /// Input file + #[structopt(parse(from_os_str))] + input: PathBuf, +} + +fn main() -> Result<()> { + codespan::init().expect("Failed to initialize Rust based diagnostics."); + let opt = Opt::from_args(); + let _module = match IRModule::parse_file(opt.input) { + Err(ir::module::Error::TVM(Error::DiagnosticError(_))) => return Ok(()), + Err(e) => { + return Err(e.into()); + } + Ok(module) => module, + }; + + Ok(()) +} diff --git a/rust/tvm/src/ir/arith.rs b/rust/tvm/src/ir/arith.rs index f589f2ac25c6..92a1de69ff78 100644 --- a/rust/tvm/src/ir/arith.rs +++ b/rust/tvm/src/ir/arith.rs @@ -34,7 +34,7 @@ macro_rules! define_node { impl $name { pub fn new($($id : $t,)*) -> $name { - let base = Object::base_object::<$node>(); + let base = Object::base::<$node>(); let node = $node { base, $($id),* }; $name(Some(ObjectPtr::new(node))) } diff --git a/rust/tvm/src/ir/diagnostics/codespan.rs b/rust/tvm/src/ir/diagnostics/codespan.rs new file mode 100644 index 000000000000..c411c0cd31a7 --- /dev/null +++ b/rust/tvm/src/ir/diagnostics/codespan.rs @@ -0,0 +1,216 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +//! A TVM diagnostics renderer which uses the Rust `codespan` library +//! to produce error messages. +//! +//! This is an example of using the exposed API surface of TVM to +//! customize the compiler behavior. +use std::collections::HashMap; +use std::sync::{Arc, Mutex}; + +use codespan_reporting::diagnostic::{Diagnostic as CDiagnostic, Label, Severity}; +use codespan_reporting::files::SimpleFiles; +use codespan_reporting::term::termcolor::{ColorChoice, StandardStream}; +use codespan_reporting::term::{self}; + +use super::*; +use crate::ir::source_map::*; + +/// A representation of a TVM Span as a range of bytes in a file. +struct ByteRange { + /// The file in which the range occurs. + #[allow(dead_code)] + file_id: FileId, + /// The range start. + start_pos: usize, + /// The range end. + end_pos: usize, +} + +/// A mapping from Span to ByteRange for a single file. +enum FileSpanToByteRange { + AsciiSource(Vec), + #[allow(dead_code)] + Utf8 { + /// Map character regions which are larger then 1-byte to length. + lengths: HashMap, + /// The source of the program. + source: String, + }, +} + +impl FileSpanToByteRange { + /// Construct a span to byte range mapping from the program source. + fn new(source: String) -> FileSpanToByteRange { + if source.is_ascii() { + let line_lengths = source.lines().map(|line| line.len()).collect(); + FileSpanToByteRange::AsciiSource(line_lengths) + } else { + panic!() + } + } + + /// Lookup the corresponding ByteRange for a given Span. + fn lookup(&self, span: &Span) -> ByteRange { + use FileSpanToByteRange::*; + + let source_name: String = span.source_name.name.as_str().unwrap().into(); + + match self { + AsciiSource(ref line_lengths) => { + let start_pos = (&line_lengths[0..(span.line - 1) as usize]) + .into_iter() + .sum::() + + (span.column) as usize; + let end_pos = (&line_lengths[0..(span.end_line - 1) as usize]) + .into_iter() + .sum::() + + (span.end_column) as usize; + ByteRange { + file_id: source_name, + start_pos, + end_pos, + } + } + _ => panic!(), + } + } +} + +/// A mapping for all files in a source map to byte ranges. +struct SpanToByteRange { + map: HashMap, +} + +impl SpanToByteRange { + fn new() -> SpanToByteRange { + SpanToByteRange { + map: HashMap::new(), + } + } + + /// Add a source file to the span mapping. + pub fn add_source(&mut self, source: Source) { + let source_name: String = source.source_name.name.as_str().expect("foo").into(); + + if self.map.contains_key(&source_name) { + panic!() + } else { + let source = source.source.as_str().expect("fpp").into(); + self.map + .insert(source_name, FileSpanToByteRange::new(source)); + } + } + + /// Lookup a span to byte range mapping. + /// + /// First resolves the Span to a file, and then maps the span to a byte range in the file. + pub fn lookup(&self, span: &Span) -> ByteRange { + let source_name: String = span.source_name.name.as_str().expect("foo").into(); + + match self.map.get(&source_name) { + Some(file_span_to_bytes) => file_span_to_bytes.lookup(span), + None => panic!(), + } + } +} + +/// The state of the `codespan` based diagnostics. +struct DiagnosticState { + files: SimpleFiles, + span_map: SpanToByteRange, + // todo unify wih source name + source_to_id: HashMap, +} + +impl DiagnosticState { + fn new() -> DiagnosticState { + DiagnosticState { + files: SimpleFiles::new(), + span_map: SpanToByteRange::new(), + source_to_id: HashMap::new(), + } + } + + fn add_source(&mut self, source: Source) { + let source_str: String = source.source.as_str().unwrap().into(); + let source_name: String = source.source_name.name.as_str().unwrap().into(); + self.span_map.add_source(source); + let file_id = self.files.add(source_name.clone(), source_str); + self.source_to_id.insert(source_name, file_id); + } + + fn to_diagnostic(&self, diag: super::Diagnostic) -> CDiagnostic { + let severity = match diag.level { + DiagnosticLevel::Error => Severity::Error, + DiagnosticLevel::Warning => Severity::Warning, + DiagnosticLevel::Note => Severity::Note, + DiagnosticLevel::Help => Severity::Help, + DiagnosticLevel::Bug => Severity::Bug, + }; + + let source_name: String = diag.span.source_name.name.as_str().unwrap().into(); + let file_id = *self.source_to_id.get(&source_name).unwrap(); + + let message: String = diag.message.as_str().unwrap().into(); + + let byte_range = self.span_map.lookup(&diag.span); + + let diagnostic = CDiagnostic::new(severity) + .with_message(message) + .with_code("EXXX") + .with_labels(vec![Label::primary( + file_id, + byte_range.start_pos..byte_range.end_pos, + )]); + + diagnostic + } +} + +fn renderer(state: &mut DiagnosticState, diag_ctx: DiagnosticContext) { + let source_map = diag_ctx.module.source_map.clone(); + let writer = StandardStream::stderr(ColorChoice::Always); + let config = codespan_reporting::term::Config::default(); + for diagnostic in diag_ctx.diagnostics.clone() { + match source_map.source_map.get(&diagnostic.span.source_name) { + Err(err) => panic!(err), + Ok(source) => { + state.add_source(source); + let diagnostic = state.to_diagnostic(diagnostic); + term::emit(&mut writer.lock(), &config, &state.files, &diagnostic).unwrap(); + } + } + } +} + +/// Initialize the `codespan` based diagnostics. +/// +/// Calling this function will globally override the TVM diagnostics renderer. +pub fn init() -> Result<()> { + let diag_state = Arc::new(Mutex::new(DiagnosticState::new())); + let render_fn = move |diag_ctx: DiagnosticContext| { + let mut guard = diag_state.lock().unwrap(); + renderer(&mut *guard, diag_ctx); + }; + + override_renderer(Some(render_fn))?; + Ok(()) +} diff --git a/rust/tvm/src/ir/diagnostics/mod.rs b/rust/tvm/src/ir/diagnostics/mod.rs new file mode 100644 index 000000000000..051bb9eb16c4 --- /dev/null +++ b/rust/tvm/src/ir/diagnostics/mod.rs @@ -0,0 +1,245 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use super::module::IRModule; +use super::span::*; +use crate::runtime::function::Result; +use crate::runtime::object::{Object, ObjectPtr}; +use crate::runtime::{ + array::Array, + function::{self, Function, ToFunction}, + string::String as TString, +}; +/// The diagnostic interface to TVM, used for reporting and rendering +/// diagnostic information by the compiler. This module exposes +/// three key abstractions: a Diagnostic, the DiagnosticContext, +/// and the DiagnosticRenderer. +use tvm_macros::{external, Object}; + +pub mod codespan; + +external! { + #[name("node.ArrayGetItem")] + fn get_renderer() -> DiagnosticRenderer; + + #[name("diagnostics.DiagnosticRenderer")] + fn diagnostic_renderer(func: Function) -> DiagnosticRenderer; + + #[name("diagnostics.Emit")] + fn emit(ctx: DiagnosticContext, diagnostic: Diagnostic) -> (); + + #[name("diagnostics.DiagnosticContextDefault")] + fn diagnostic_context_default(module: IRModule) -> DiagnosticContext; + + #[name("diagnostics.DiagnosticContextRender")] + fn diagnostic_context_render(ctx: DiagnosticContext) -> (); + + #[name("diagnostics.DiagnosticRendererRender")] + fn diagnositc_renderer_render(renderer: DiagnosticRenderer,ctx: DiagnosticContext) -> (); + + #[name("diagnostics.ClearRenderer")] + fn clear_renderer() -> (); +} + +/// The diagnostic level, controls the printing of the message. +#[repr(C)] +pub enum DiagnosticLevel { + Bug = 10, + Error = 20, + Warning = 30, + Note = 40, + Help = 50, +} + +/// A compiler diagnostic. +#[repr(C)] +#[derive(Object)] +#[ref_name = "Diagnostic"] +#[type_key = "Diagnostic"] +pub struct DiagnosticNode { + pub base: Object, + /// The level. + pub level: DiagnosticLevel, + /// The span at which to report an error. + pub span: Span, + /// The diagnostic message. + pub message: TString, +} + +impl Diagnostic { + pub fn new(level: DiagnosticLevel, span: Span, message: TString) -> Diagnostic { + let node = DiagnosticNode { + base: Object::base::(), + level, + span, + message, + }; + ObjectPtr::new(node).into() + } + + pub fn bug(span: Span) -> DiagnosticBuilder { + DiagnosticBuilder::new(DiagnosticLevel::Bug, span) + } + + pub fn error(span: Span) -> DiagnosticBuilder { + DiagnosticBuilder::new(DiagnosticLevel::Error, span) + } + + pub fn warning(span: Span) -> DiagnosticBuilder { + DiagnosticBuilder::new(DiagnosticLevel::Warning, span) + } + + pub fn note(span: Span) -> DiagnosticBuilder { + DiagnosticBuilder::new(DiagnosticLevel::Note, span) + } + + pub fn help(span: Span) -> DiagnosticBuilder { + DiagnosticBuilder::new(DiagnosticLevel::Help, span) + } +} + +/// A wrapper around std::stringstream to build a diagnostic. +pub struct DiagnosticBuilder { + /// The level. + pub level: DiagnosticLevel, + + /// The span of the diagnostic. + pub span: Span, + + /// The in progress message. + pub message: String, +} + +impl DiagnosticBuilder { + pub fn new(level: DiagnosticLevel, span: Span) -> DiagnosticBuilder { + DiagnosticBuilder { + level, + span, + message: "".into(), + } + } +} + +/// Display diagnostics in a given display format. +/// +/// A diagnostic renderer is responsible for converting the +/// raw diagnostics into consumable output. +/// +/// For example the terminal renderer will render a sequence +/// of compiler diagnostics to std::out and std::err in +/// a human readable form. +#[repr(C)] +#[derive(Object)] +#[ref_name = "DiagnosticRenderer"] +#[type_key = "DiagnosticRenderer"] +/// A diagnostic renderer, which given a diagnostic context produces a "rendered" +/// form of the diagnostics for either human or computer consumption. +pub struct DiagnosticRendererNode { + /// The base type. + pub base: Object, + // TODO(@jroesch): we can't easily exposed packed functions due to + // memory layout + // missing field here +} + +impl DiagnosticRenderer { + /// Render the provided context. + pub fn render(&self, ctx: DiagnosticContext) -> Result<()> { + diagnositc_renderer_render(self.clone(), ctx) + } +} + +#[repr(C)] +#[derive(Object)] +#[ref_name = "DiagnosticContext"] +#[type_key = "DiagnosticContext"] +/// A diagnostic context for recording errors against a source file. +pub struct DiagnosticContextNode { + // The base type. + pub base: Object, + + /// The Module to report against. + pub module: IRModule, + + /// The set of diagnostics to report. + pub diagnostics: Array, + + /// The renderer set for the context. + pub renderer: DiagnosticRenderer, +} + +/// A diagnostic context which records active errors +/// and contains a renderer. +impl DiagnosticContext { + pub fn new(module: IRModule, render_func: F) -> DiagnosticContext + where + F: Fn(DiagnosticContext) -> () + 'static, + { + let renderer = diagnostic_renderer(render_func.to_function()).unwrap(); + let node = DiagnosticContextNode { + base: Object::base::(), + module, + diagnostics: Array::from_vec(vec![]).unwrap(), + renderer, + }; + DiagnosticContext(Some(ObjectPtr::new(node))) + } + + pub fn default(module: IRModule) -> DiagnosticContext { + diagnostic_context_default(module).unwrap() + } + + /// Emit a diagnostic. + pub fn emit(&mut self, diagnostic: Diagnostic) -> Result<()> { + emit(self.clone(), diagnostic) + } + + /// Render the errors and raise a DiagnosticError exception. + pub fn render(&mut self) -> Result<()> { + diagnostic_context_render(self.clone()) + } + + /// Emit a diagnostic and then immediately attempt to render all errors. + pub fn emit_fatal(&mut self, diagnostic: Diagnostic) -> Result<()> { + self.emit(diagnostic)?; + self.render()?; + Ok(()) + } +} + +/// Override the global diagnostics renderer. +// render_func: Option[Callable[[DiagnosticContext], None]] +// If the render_func is None it will remove the current custom renderer +// and return to default behavior. +fn override_renderer(opt_func: Option) -> Result<()> +where + F: Fn(DiagnosticContext) -> () + 'static, +{ + match opt_func { + None => clear_renderer(), + Some(func) => { + let func = func.to_function(); + let render_factory = move || diagnostic_renderer(func.clone()).unwrap(); + + function::register_override(render_factory, "diagnostics.OverrideRenderer", true)?; + + Ok(()) + } + } +} diff --git a/rust/tvm/src/ir/expr.rs b/rust/tvm/src/ir/expr.rs index 91c42f0edbcf..f74522d91c70 100644 --- a/rust/tvm/src/ir/expr.rs +++ b/rust/tvm/src/ir/expr.rs @@ -35,7 +35,7 @@ pub struct BaseExprNode { impl BaseExprNode { pub fn base() -> BaseExprNode { BaseExprNode { - base: Object::base_object::(), + base: Object::base::(), } } } diff --git a/rust/tvm/src/ir/mod.rs b/rust/tvm/src/ir/mod.rs index 126d0faccabb..6d5158005497 100644 --- a/rust/tvm/src/ir/mod.rs +++ b/rust/tvm/src/ir/mod.rs @@ -19,11 +19,13 @@ pub mod arith; pub mod attrs; +pub mod diagnostics; pub mod expr; pub mod function; pub mod module; pub mod op; pub mod relay; +pub mod source_map; pub mod span; pub mod tir; pub mod ty; diff --git a/rust/tvm/src/ir/module.rs b/rust/tvm/src/ir/module.rs index e0444b3101da..190b477b98f2 100644 --- a/rust/tvm/src/ir/module.rs +++ b/rust/tvm/src/ir/module.rs @@ -16,6 +16,10 @@ * specific language governing permissions and limitations * under the License. */ +use std::path::Path; + +use thiserror::Error; +use tvm_macros::Object; use crate::runtime::array::Array; use crate::runtime::function::Result; @@ -25,16 +29,20 @@ use crate::runtime::{external, Object, ObjectRef}; use super::expr::GlobalVar; use super::function::BaseFunc; - -use std::io::Result as IOResult; -use std::path::Path; - -use tvm_macros::Object; +use super::source_map::SourceMap; // TODO(@jroesch): define type type TypeData = ObjectRef; type GlobalTypeVar = ObjectRef; +#[derive(Error, Debug)] +pub enum Error { + #[error("{0}")] + IO(#[from] std::io::Error), + #[error("{0}")] + TVM(#[from] crate::runtime::Error), +} + #[repr(C)] #[derive(Object)] #[ref_name = "IRModule"] @@ -43,6 +51,8 @@ pub struct IRModuleNode { pub base: Object, pub functions: Map, pub type_definitions: Map, + pub source_map: SourceMap, + // TODO(@jroesch): this is missing some fields } external! { @@ -113,19 +123,21 @@ external! { // }); impl IRModule { - pub fn parse(file_name: N, source: S) -> IRModule + pub fn parse(file_name: N, source: S) -> Result where N: Into, S: Into, { - parse_module(file_name.into(), source.into()).expect("failed to call parser") + parse_module(file_name.into(), source.into()) } - pub fn parse_file>(file_path: P) -> IOResult { + pub fn parse_file>( + file_path: P, + ) -> std::result::Result { let file_path = file_path.as_ref(); let file_path_as_str = file_path.to_str().unwrap().to_string(); let source = std::fs::read_to_string(file_path)?; - let module = IRModule::parse(file_path_as_str, source); + let module = IRModule::parse(file_path_as_str, source)?; Ok(module) } diff --git a/rust/tvm/src/ir/relay/mod.rs b/rust/tvm/src/ir/relay/mod.rs index e539221d1db6..cc1a76bef7e3 100644 --- a/rust/tvm/src/ir/relay/mod.rs +++ b/rust/tvm/src/ir/relay/mod.rs @@ -22,11 +22,12 @@ pub mod attrs; use std::hash::Hash; use crate::runtime::array::Array; -use crate::runtime::{object::*, String as TString}; +use crate::runtime::{object::*, IsObjectRef, String as TString}; use super::attrs::Attrs; use super::expr::BaseExprNode; use super::function::BaseFuncNode; +use super::span::Span; use super::ty::{Type, TypeNode}; use tvm_macros::Object; @@ -50,8 +51,8 @@ impl ExprNode { base: BaseExprNode::base::(), span: ObjectRef::null(), checked_type: Type::from(TypeNode { - base: Object::base_object::(), - span: ObjectRef::null(), + base: Object::base::(), + span: Span::null(), }), } } @@ -83,7 +84,7 @@ pub struct IdNode { impl Id { fn new(name_hint: TString) -> Id { let node = IdNode { - base: Object::base_object::(), + base: Object::base::(), name_hint: name_hint, }; Id(Some(ObjectPtr::new(node))) @@ -351,7 +352,7 @@ pub struct PatternNode { impl PatternNode { pub fn base() -> PatternNode { PatternNode { - base: Object::base_object::(), + base: Object::base::(), span: ObjectRef::null(), } } @@ -450,7 +451,7 @@ pub struct ClauseNode { impl Clause { pub fn new(lhs: Pattern, rhs: Expr, _span: ObjectRef) -> Clause { let node = ClauseNode { - base: Object::base_object::(), + base: Object::base::(), lhs, rhs, }; @@ -553,7 +554,8 @@ def @main() -> float32 { 0.01639530062675476f } "#, - ); + ) + .unwrap(); let main = module .lookup(module.get_global_var("main".to_string().into()).unwrap()) .unwrap(); diff --git a/rust/tvm/src/ir/source_map.rs b/rust/tvm/src/ir/source_map.rs new file mode 100644 index 000000000000..54e16dac62ac --- /dev/null +++ b/rust/tvm/src/ir/source_map.rs @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either exprss or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use crate::runtime::map::Map; +use crate::runtime::object::Object; +use crate::runtime::string::String as TString; + +use super::span::SourceName; + +use tvm_macros::Object; + +/// A program source in any language. +/// +/// Could represent the source from an ML framework or a source of an IRModule. +#[repr(C)] +#[derive(Object)] +#[type_key = "Source"] +#[ref_name = "Source"] +pub struct SourceNode { + pub base: Object, + /// The source name. + pub source_name: SourceName, + + /// The raw source. + pub source: TString, + // TODO(@jroesch): Non-ABI compat field + // A mapping of line breaks into the raw source. + // std::vector> line_map; +} + +/// A mapping from a unique source name to source fragments. +#[repr(C)] +#[derive(Object)] +#[type_key = "SourceMap"] +#[ref_name = "SourceMap"] +pub struct SourceMapNode { + /// The base object. + pub base: Object, + /// The source mapping. + pub source_map: Map, +} diff --git a/rust/tvm/src/ir/span.rs b/rust/tvm/src/ir/span.rs index d2e19a25a950..eb6821af69dc 100644 --- a/rust/tvm/src/ir/span.rs +++ b/rust/tvm/src/ir/span.rs @@ -1,22 +1,71 @@ /* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ +* Licensed to the Apache Software Foundation (ASF) under one +* or more contributor license agreements. See the NOTICE file +* distributed with this work for additional information +* regarding copyright ownership. The ASF licenses this file +* to you under the Apache License, Version 2.0 (the +* "License"); you may not use this file except in compliance +* with the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an +* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +* KIND, either express or implied. See the License for the -use crate::runtime::ObjectRef; +* specific language governing permissions and limitations +* under the License. +*/ -pub type Span = ObjectRef; +use crate::runtime::{Object, ObjectPtr, String as TString}; +use tvm_macros::Object; + +/// A source file name, contained in a Span. +#[repr(C)] +#[derive(Object)] +#[type_key = "SourceName"] +#[ref_name = "SourceName"] +pub struct SourceNameNode { + pub base: Object, + pub name: TString, +} + +/// Span information for diagnostic purposes. +#[repr(C)] +#[derive(Object)] +#[type_key = "Span"] +#[ref_name = "Span"] +pub struct SpanNode { + pub base: Object, + /// The source name. + pub source_name: SourceName, + /// The line number. + pub line: i32, + /// The column offset. + pub column: i32, + /// The end line number. + pub end_line: i32, + /// The end column number. + pub end_column: i32, +} + +impl Span { + pub fn new( + source_name: SourceName, + line: i32, + end_line: i32, + column: i32, + end_column: i32, + ) -> Span { + let span_node = SpanNode { + base: Object::base::(), + source_name, + line, + end_line, + column, + end_column, + }; + Span(Some(ObjectPtr::new(span_node))) + } +} diff --git a/rust/tvm/src/ir/ty.rs b/rust/tvm/src/ir/ty.rs index b6a47f553da4..d12f094a63ea 100644 --- a/rust/tvm/src/ir/ty.rs +++ b/rust/tvm/src/ir/ty.rs @@ -36,7 +36,7 @@ pub struct TypeNode { impl TypeNode { fn base(span: Span) -> Self { TypeNode { - base: Object::base_object::(), + base: Object::base::(), span, } } diff --git a/rust/tvm/src/lib.rs b/rust/tvm/src/lib.rs index 36c750328249..7e0682b86b33 100644 --- a/rust/tvm/src/lib.rs +++ b/rust/tvm/src/lib.rs @@ -24,7 +24,7 @@ //! One particular use case is that given optimized deep learning model artifacts, //! (compiled with TVM) which include a shared library //! `lib.so`, `graph.json` and a byte-array `param.params`, one can load them -//! in Rust idomatically to create a TVM Graph Runtime and +//! in Rust idiomatically to create a TVM Graph Runtime and //! run the model for some inputs and get the //! desired predictions *all in Rust*. //! @@ -47,3 +47,28 @@ pub mod runtime; pub mod transform; pub use runtime::version; + +#[macro_export] +macro_rules! export { + ($($fn_name:expr),*) => { + pub fn tvm_export(ns: &str) -> Result<(), tvm::Error> { + $( + let name = String::fromwe(ns) + ::std::stringify!($fn_name); + tvm::runtime::function::register_override($fn_name, name, true)?; + )* + Ok(()) + } + } +} + +#[macro_export] +macro_rules! export_mod { + ($ns:expr, $($mod_name:expr),*) => { + pub fn tvm_mod_export() -> Result<(), tvm::Error> { + $( + $mod_name::tvm_export($ns)?; + )* + Ok(()) + } + } +} diff --git a/rust/tvm/src/transform.rs b/rust/tvm/src/transform.rs index 59fc60450825..c5a65c417c93 100644 --- a/rust/tvm/src/transform.rs +++ b/rust/tvm/src/transform.rs @@ -50,7 +50,7 @@ impl PassInfo { let required = Array::from_vec(required)?; let node = PassInfoNode { - base: Object::base_object::(), + base: Object::base::(), opt_level, name: name.into(), required, diff --git a/src/contrib/rust_extension.cc b/src/contrib/rust_extension.cc new file mode 100644 index 000000000000..46e94fffdf55 --- /dev/null +++ b/src/contrib/rust_extension.cc @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/contrib/rust_extension.cc + * \brief Expose Rust extensions initialization. + */ +#ifdef RUST_COMPILER_EXT + +extern "C" { +int compiler_ext_initialize(); +static int test = compiler_ext_initialize(); +} + +#endif diff --git a/src/ir/diagnostic.cc b/src/ir/diagnostic.cc index 148831dc3ab6..e533972cc71a 100644 --- a/src/ir/diagnostic.cc +++ b/src/ir/diagnostic.cc @@ -113,6 +113,7 @@ TVM_REGISTER_GLOBAL("diagnostics.DiagnosticRendererRender") }); DiagnosticContext::DiagnosticContext(const IRModule& module, const DiagnosticRenderer& renderer) { + CHECK(renderer.defined()) << "can not initialize a diagnostic renderer with a null function"; auto n = make_object(); n->module = module; n->renderer = renderer; @@ -167,6 +168,10 @@ DiagnosticContext DiagnosticContext::Default(const IRModule& module) { return DiagnosticContext(module, renderer); } +TVM_REGISTER_GLOBAL("diagnostics.Default").set_body_typed([](const IRModule& module) { + return DiagnosticContext::Default(module); +}); + std::ostream& EmitDiagnosticHeader(std::ostream& out, const Span& span, DiagnosticLevel level, std::string msg) { rang::fg diagnostic_color = rang::fg::reset; diff --git a/src/parser/source_map.cc b/src/parser/source_map.cc index 40998b0c9dc4..c6ea808733e1 100644 --- a/src/parser/source_map.cc +++ b/src/parser/source_map.cc @@ -77,12 +77,6 @@ tvm::String Source::GetLine(int line) { return line_text; } -// TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -// .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { -// auto* node = static_cast(ref.get()); -// p->stream << "SourceName(" << node->name << ", " << node << ")"; -// }); - TVM_REGISTER_NODE_TYPE(SourceMapNode); SourceMap::SourceMap(Map source_map) { @@ -91,11 +85,6 @@ SourceMap::SourceMap(Map source_map) { data_ = std::move(n); } -// TODO(@jroesch): fix this -static SourceMap global_source_map = SourceMap(Map()); - -SourceMap SourceMap::Global() { return global_source_map; } - void SourceMap::Add(const Source& source) { (*this)->source_map.Set(source->source_name, source); } TVM_REGISTER_GLOBAL("SourceMapAdd").set_body_typed([](SourceMap map, String name, String content) {